代码拉取完成,页面将自动刷新
# Benchmark script for LightGlue on real images
import argparse
import time
from collections import defaultdict
from pathlib import Path
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch._dynamo
from lightglue import LightGlue, SuperPoint
from lightglue.utils import load_image
# Nothing in this script calls backward(), so autograd is switched off globally
# to avoid recording graphs during feature extraction and matching.
torch.set_grad_enabled(False)
def measure(matcher, data, device="cuda", r=100, warmup=10):
    """Time repeated calls of ``matcher(data)`` and summarise the latency.

    Args:
        matcher: Callable invoked as ``matcher(data)``; its output is discarded.
        data: The single argument passed to ``matcher`` on every call.
        device: ``torch.device`` or device string (``"cuda"``, ``"cpu"``,
            ``"mps"``, ...). On ``cuda`` each call is timed with CUDA events;
            otherwise wall-clock ``time.perf_counter`` is used, with an explicit
            synchronize on ``mps`` so queued GPU work is included.
        r: Number of timed repetitions; must be positive.
        warmup: Number of untimed calls made before measuring.

    Returns:
        dict with ``"mean"`` and ``"std"`` of the per-call latency in milliseconds.

    Raises:
        ValueError: If ``r`` is not positive.
    """
    if r <= 0:
        raise ValueError(f"r must be a positive number of repetitions, got {r}")
    # Accept plain strings as well: the default "cuda" has no ``.type`` attribute.
    device = torch.device(device)
    use_cuda_events = device.type == "cuda"
    if use_cuda_events:
        starter = torch.cuda.Event(enable_timing=True)
        ender = torch.cuda.Event(enable_timing=True)

    def _sync():
        # GPU backends execute asynchronously; wait so wall-clock timing covers
        # the actual work and not just the kernel launches.
        if device.type == "cuda":
            torch.cuda.synchronize()
        elif device.type == "mps":
            torch.mps.synchronize()

    timings = np.zeros(r)
    with torch.no_grad():
        for _ in range(warmup):
            matcher(data)
        _sync()
        for rep in range(r):
            if use_cuda_events:
                starter.record()
                matcher(data)
                ender.record()
                # elapsed_time is only valid once both events have completed.
                torch.cuda.synchronize()
                timings[rep] = starter.elapsed_time(ender)
            else:
                start = time.perf_counter()
                matcher(data)
                _sync()
                timings[rep] = (time.perf_counter() - start) * 1e3
    return {"mean": timings.mean(), "std": timings.std()}
def print_as_table(d, title, cnames):
    """Print ``d`` as a fixed-width text table, preceded by a blank line.

    The header holds ``title`` in a 30-character column followed by each entry
    of ``cnames`` right-aligned in 7 characters. A dashed rule as wide as the
    header follows, then one row per key of ``d`` whose values are printed
    right-aligned with one decimal place.
    """
    print()
    column_titles = " ".join(f"{column:>7}" for column in cnames)
    header = f"{title:30} {column_titles}"
    print(header)
    print("-" * len(header))
    for row_label, values in d.items():
        cells = " ".join(f"{value:>7.1f}" for value in values)
        print(f"{row_label:30} {cells}")
def _parse_args():
    """Define the benchmark command-line interface and return the parsed args."""
    parser = argparse.ArgumentParser(description="Benchmark script for LightGlue")
    parser.add_argument(
        "--device",
        choices=["auto", "cuda", "cpu", "mps"],
        default="auto",
        help="device to benchmark on",
    )
    parser.add_argument("--compile", action="store_true", help="Compile LightGlue runs")
    parser.add_argument(
        "--no_flash", action="store_true", help="disable FlashAttention"
    )
    parser.add_argument(
        "--no_prune_thresholds",
        action="store_true",
        help="disable pruning thresholds (i.e. always do pruning)",
    )
    parser.add_argument(
        "--add_superglue",
        action="store_true",
        help="add SuperGlue to the benchmark (requires hloc)",
    )
    parser.add_argument(
        "--measure", default="time", choices=["time", "log-time", "throughput"]
    )
    parser.add_argument(
        "--repeat", "--r", type=int, default=100, help="repetitions of measurements"
    )
    parser.add_argument(
        "--num_keypoints",
        nargs="+",
        type=int,
        default=[256, 512, 1024, 2048, 4096],
        help="number of keypoints (list separated by spaces)",
    )
    parser.add_argument(
        "--matmul_precision", default="highest", choices=["highest", "high", "medium"]
    )
    parser.add_argument(
        "--save", default=None, type=str, help="path where figure should be saved"
    )
    return parser.parse_intermixed_args()


def _setup_axis(ax, title, measure_kind):
    """Style one subplot: log2 keypoint x-axis, grids, title and y label.

    Args:
        ax: Matplotlib axes configured in place.
        title: Subplot title (the image-pair name).
        measure_kind: Value of ``--measure``; ``"log-time"`` switches the y axis
            to log scale with minor ticks, ``"throughput"`` changes the y label.
    """
    ax.set_xscale("log", base=2)
    bases = [2**x for x in range(7, 16)]
    ax.set_xticks(bases, bases)
    ax.grid(which="major")
    if measure_kind == "log-time":
        ax.set_yscale("log")
        yticks = [10**x for x in range(6)]
        ax.set_yticks(yticks, yticks)
        mpos = [10**x * i for x in range(6) for i in range(2, 10)]
        # Only label the 2x and 5x minor ticks to keep the axis readable.
        mlabel = [
            10**x * i if i in [2, 5] else None
            for x in range(6)
            for i in range(2, 10)
        ]
        ax.set_yticks(mpos, mlabel, minor=True)
        ax.grid(which="minor", linewidth=0.2)
    ax.set_title(title)
    ax.set_xlabel("# keypoints")
    if measure_kind == "throughput":
        ax.set_ylabel("Throughput [pairs/s]")
    else:
        ax.set_ylabel("Latency [ms]")


def _lightglue_inputs(image0, image1, feats0, feats1):
    """Build the LightGlue input dict; only the extracted features are passed."""
    return {"image0": feats0, "image1": feats1}


def _superglue_inputs(image0, image1, feats0, feats1):
    """Build the flat, index-suffixed input dict used for hloc's SuperGlue.

    Feature keys get a ``0``/``1`` suffix, ``keypoint_scores`` is aliased to
    ``scores`` and descriptors are transposed on their last two dimensions.
    """
    data = {
        "image0": image0[None],
        "image1": image1[None],
        **{k + "0": v for k, v in feats0.items()},
        **{k + "1": v for k, v in feats1.items()},
    }
    data["scores0"] = data["keypoint_scores0"]
    data["scores1"] = data["keypoint_scores1"]
    data["descriptors0"] = data["descriptors0"].transpose(-1, -2).contiguous()
    data["descriptors1"] = data["descriptors1"].transpose(-1, -2).contiguous()
    return data


def _benchmark_matcher(
    name, matcher, build_inputs, *, inputs, axes, extractor, args, device, results
):
    """Time ``matcher`` on every image pair and keypoint budget, then plot it.

    For each pair, features are re-extracted with ``max_num_keypoints`` set to
    each value of ``args.num_keypoints``. The mean latency in ms (or
    ``1000 / latency`` in throughput mode) is appended to
    ``results[pair_name][name]`` and the curve is drawn on that pair's axes.
    """
    for pair_name, ax in zip(inputs.keys(), axes):
        image0, image1 = [x.to(device) for x in inputs[pair_name]]
        for num_kpts in args.num_keypoints:
            extractor.conf.max_num_keypoints = num_kpts
            feats0 = extractor.extract(image0)
            feats1 = extractor.extract(image1)
            data = build_inputs(image0, image1, feats0, feats1)
            runtime = measure(matcher, data, device=device, r=args.repeat)["mean"]
            results[pair_name][name].append(
                1000 / runtime if args.measure == "throughput" else runtime
            )
        ax.plot(args.num_keypoints, results[pair_name][name], label=name, marker="o")


if __name__ == "__main__":
    args = _parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if args.device != "auto":
        device = torch.device(args.device)
    print("Running benchmark on device:", device)

    images = Path("assets")
    inputs = {
        "easy": (
            load_image(images / "DSC_0411.JPG"),
            load_image(images / "DSC_0410.JPG"),
        ),
        "difficult": (
            load_image(images / "sacre_coeur1.jpg"),
            load_image(images / "sacre_coeur2.jpg"),
        ),
    }

    configs = {
        "LightGlue-full": {
            "depth_confidence": -1,
            "width_confidence": -1,
        },
        # 'LG-prune': {
        #     'width_confidence': -1,
        # },
        # 'LG-depth': {
        #     'depth_confidence': -1,
        # },
        "LightGlue-adaptive": {},
    }
    if args.compile:
        configs = {**configs, **{k + "-compile": v for k, v in configs.items()}}

    sg_configs = {
        # 'SuperGlue': {},
        "SuperGlue-fast": {"sinkhorn_iterations": 5}
    }

    torch.set_float32_matmul_precision(args.matmul_precision)

    results = {k: defaultdict(list) for k in inputs}

    extractor = SuperPoint(max_num_keypoints=None, detection_threshold=-1)
    extractor = extractor.eval().to(device)

    figsize = (len(inputs) * 4.5, 4.5)
    fig, axes = plt.subplots(1, len(inputs), sharey=True, figsize=figsize)
    axes = axes if len(inputs) > 1 else [axes]
    fig.canvas.manager.set_window_title(f"LightGlue benchmark ({device.type})")
    for title, ax in zip(inputs.keys(), axes):
        _setup_axis(ax, title, args.measure)

    bench_kwargs = dict(
        inputs=inputs,
        axes=axes,
        extractor=extractor,
        args=args,
        device=device,
        results=results,
    )

    for name, conf in configs.items():
        print("Run benchmark for:", name)
        if device.type == "cuda":
            torch.cuda.empty_cache()
        matcher = LightGlue(features="superpoint", flash=not args.no_flash, **conf)
        if args.no_prune_thresholds:
            # A threshold of -1 makes every keypoint count eligible for pruning.
            matcher.pruning_keypoint_thresholds = {
                k: -1 for k in matcher.pruning_keypoint_thresholds
            }
        matcher = matcher.eval().to(device)
        if name.endswith("compile"):
            torch._dynamo.reset()  # avoid buffer overflow
            matcher.compile()
        _benchmark_matcher(name, matcher, _lightglue_inputs, **bench_kwargs)
        del matcher  # release the model before the next configuration is built

    if args.add_superglue:
        from hloc.matchers.superglue import SuperGlue

        for name, conf in sg_configs.items():
            print("Run benchmark for:", name)
            matcher = SuperGlue(conf).eval().to(device)
            _benchmark_matcher(name, matcher, _superglue_inputs, **bench_kwargs)
            del matcher

    for pair_name, series in results.items():
        print_as_table(series, pair_name, args.num_keypoints)

    axes[0].legend()
    fig.tight_layout()
    if args.save:
        plt.savefig(args.save, dpi=fig.dpi)
    plt.show()
此处可能存在不合适展示的内容,页面不予展示。您可通过相关编辑功能自查并修改。
如您确认内容无涉及 不当用语 / 纯广告导流 / 暴力 / 低俗色情 / 侵权 / 盗版 / 虚假 / 无价值内容或违法国家有关法律法规的内容,可点击提交进行申诉,我们将尽快为您处理。