修复非极大值抑制函数中的张量维度问题,并更新compute_ap函数以使用numpy.trapezoid替代已弃用的numpy.trapz

This commit is contained in:
2025-10-03 20:23:59 +08:00
parent 33586e0c0c
commit 86c7579b42

View File

@@ -151,7 +151,7 @@ def non_max_suppression(outputs, confidence_threshold=0.001, iou_threshold=0.65)
box = wh2xy(box) # (cx, cy, w, h) to (x1, y1, x2, y2) box = wh2xy(box) # (cx, cy, w, h) to (x1, y1, x2, y2)
if nc > 1: if nc > 1:
i, j = (cls > confidence_threshold).nonzero(as_tuple=False).T i, j = (cls > confidence_threshold).nonzero(as_tuple=False).T
x = torch.cat((box[i], x[i, 4 + j, None], j[:, None].float()), dim=1) x = torch.cat((box[i], x[i, 4 + j].unsqueeze(1), j[:, None].float()), dim=1)
else: # best class only else: # best class only
conf, j = cls.max(1, keepdim=True) conf, j = cls.max(1, keepdim=True)
x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > confidence_threshold] x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > confidence_threshold]
@@ -296,7 +296,8 @@ def compute_ap(tp, conf, output, target, plot=False, names=(), eps=1e-16):
# Integrate area under curve # Integrate area under curve
x = numpy.linspace(start=0, stop=1, num=101) # 101-point interp (COCO) x = numpy.linspace(start=0, stop=1, num=101) # 101-point interp (COCO)
ap[ci, j] = numpy.trapz(numpy.interp(x, m_rec, m_pre), x) # integrate # numpy.trapz is deprecated in numpy 2.0.0 or after version, use numpy.trapezoid instead
ap[ci, j] = numpy.trapezoid(numpy.interp(x, m_rec, m_pre), x) # integrate
if plot and j == 0: if plot and j == 0:
py.append(numpy.interp(px, m_rec, m_pre)) # precision at mAP@0.5 py.append(numpy.interp(px, m_rec, m_pre)) # precision at mAP@0.5