Optimize util.py: simplify the code structure and fix tensor-dimension handling in the non-max suppression and compute-AP functions
@@ -1,7 +1,3 @@
-"""
-Utility functions for yolo.
-"""
-
 import copy
 import random
 from time import time
@@ -97,7 +93,7 @@ def make_anchors(x, strides, offset=0.5):
         _, _, h, w = x[i].shape
         sx = torch.arange(end=w, device=device, dtype=dtype) + offset  # shift x
         sy = torch.arange(end=h, device=device, dtype=dtype) + offset  # shift y
-        sy, sx = torch.meshgrid(sy, sx, indexing="ij")
+        sy, sx = torch.meshgrid(sy, sx)
         anchor_tensor.append(torch.stack((sx, sy), -1).view(-1, 2))
         stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device))
     return torch.cat(anchor_tensor), torch.cat(stride_tensor)
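
Note: dropping `indexing="ij"` does not change the anchor grid, because `torch.meshgrid` still defaults to "ij" (matrix) indexing; recent PyTorch versions merely emit a UserWarning when the argument is omitted. A minimal check, with toy feature-map sizes that are my own assumption:

    import torch

    # Toy 2x3 feature map (assumed sizes): both calls produce the same anchor grid.
    h, w, offset = 2, 3, 0.5
    sx = torch.arange(end=w, dtype=torch.float32) + offset
    sy = torch.arange(end=h, dtype=torch.float32) + offset
    sy_a, sx_a = torch.meshgrid(sy, sx, indexing="ij")
    sy_b, sx_b = torch.meshgrid(sy, sx)  # default indexing, may emit a UserWarning
    assert torch.equal(torch.stack((sx_a, sy_a), -1).view(-1, 2),
                       torch.stack((sx_b, sy_b), -1).view(-1, 2))
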
@@ -151,7 +147,7 @@ def non_max_suppression(outputs, confidence_threshold=0.001, iou_threshold=0.65)
         box = wh2xy(box)  # (cx, cy, w, h) to (x1, y1, x2, y2)
         if nc > 1:
             i, j = (cls > confidence_threshold).nonzero(as_tuple=False).T
-            x = torch.cat((box[i], x[i, 4 + j].unsqueeze(1), j[:, None].float()), dim=1)
+            x = torch.cat((box[i], x[i, 4 + j, None], j[:, None].float()), 1)
         else:  # best class only
             conf, j = cls.max(1, keepdim=True)
             x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > confidence_threshold]
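
Note: `x[i, 4 + j, None]` and `x[i, 4 + j].unsqueeze(1)` are equivalent spellings; both add a size-1 column axis so the concatenated detection matrix keeps its (n, 6) layout of (x1, y1, x2, y2, score, class). A small sketch with made-up shapes:

    import torch

    # Shapes below are assumptions, not taken from the repo.
    nc = 3
    x = torch.rand(10, 4 + nc)            # rows of (box, class scores)
    cls = x[:, 4:]
    i, j = (cls > 0.5).nonzero(as_tuple=False).T
    a = x[i, 4 + j].unsqueeze(1)          # old spelling
    b = x[i, 4 + j, None]                 # new spelling
    assert torch.equal(a, b) and a.shape == (i.numel(), 1)
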
@@ -195,13 +191,7 @@ def plot_pr_curve(px, py, ap, names, save_dir):
     else:
         ax.plot(px, py, linewidth=1, color="grey")  # plot(recall, precision)

-    ax.plot(
-        px,
-        py.mean(1),
-        linewidth=3,
-        color="blue",
-        label="all classes %.3f mAP@0.5" % ap[:, 0].mean(),
-    )
+    ax.plot(px, py.mean(1), linewidth=3, color="blue", label="all classes %.3f mAP@0.5" % ap[:, 0].mean())
     ax.set_xlabel("Recall")
     ax.set_ylabel("Precision")
     ax.set_xlim(0, 1)
@@ -224,13 +214,7 @@ def plot_curve(px, py, names, save_dir, x_label="Confidence", y_label="Metric"):
         ax.plot(px, py.T, linewidth=1, color="grey")  # plot(confidence, metric)

     y = smooth(py.mean(0), f=0.05)
-    ax.plot(
-        px,
-        y,
-        linewidth=3,
-        color="blue",
-        label=f"all classes {y.max():.3f} at {px[y.argmax()]:.3f}",
-    )
+    ax.plot(px, y, linewidth=3, color="blue", label=f"all classes {y.max():.3f} at {px[y.argmax()]:.3f}")
     ax.set_xlabel(x_label)
     ax.set_ylabel(y_label)
     ax.set_xlim(0, 1)
@@ -296,8 +280,7 @@ def compute_ap(tp, conf, output, target, plot=False, names=(), eps=1e-16):

             # Integrate area under curve
             x = numpy.linspace(start=0, stop=1, num=101)  # 101-point interp (COCO)
-            # numpy.trapz is deprecated in numpy 2.0.0 or after version, use numpy.trapezoid instead
-            ap[ci, j] = numpy.trapezoid(numpy.interp(x, m_rec, m_pre), x)  # integrate
+            ap[ci, j] = numpy.trapz(numpy.interp(x, m_rec, m_pre), x)  # integrate
             if plot and j == 0:
                 py.append(numpy.interp(px, m_rec, m_pre))  # precision at mAP@0.5

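
Note: this hunk drops the deprecation comment and switches the AP integration from `numpy.trapezoid` back to `numpy.trapz`. Both implement the same trapezoidal rule; `numpy.trapezoid` only exists from NumPy 2.0, while `numpy.trapz` is deprecated there but still available. A hedged compatibility sketch (not part of this commit) that picks whichever name the installed NumPy provides:

    import numpy

    # Resolve the trapezoidal-rule function once, whichever NumPy version is installed.
    trapezoid = getattr(numpy, "trapezoid", numpy.trapz)

    x = numpy.linspace(start=0, stop=1, num=101)          # 101-point recall grid (COCO-style)
    m_rec = numpy.array([0.0, 0.5, 1.0])                  # toy recall/precision curve (assumed values)
    m_pre = numpy.array([1.0, 0.8, 0.2])
    ap = trapezoid(numpy.interp(x, m_rec, m_pre), x)      # area under the precision-recall curve
    print(f"toy AP: {ap:.3f}")
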
@@ -443,7 +426,7 @@ class LinearLR:
         min_lr = params["min_lr"]

         warmup_steps = int(max(params["warmup_epochs"] * num_steps, 100))
-        decay_steps = max(1, int(args.epochs * num_steps - warmup_steps))
+        decay_steps = int(args.epochs * num_steps - warmup_steps)

         warmup_lr = numpy.linspace(min_lr, max_lr, int(warmup_steps), endpoint=False)
         decay_lr = numpy.linspace(max_lr, min_lr, decay_steps)
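
Note: with the `max(1, ...)` guard removed, `decay_steps` is simply the steps remaining after warmup, which assumes `args.epochs * num_steps` always exceeds `warmup_steps`. A standalone sketch of how such a warmup-then-linear-decay schedule can be assembled; all values here are assumptions, not the repo's configuration:

    import numpy

    # Hypothetical run length and learning-rate bounds.
    epochs, num_steps = 10, 100
    min_lr, max_lr, warmup_epochs = 1e-5, 1e-2, 1

    warmup_steps = int(max(warmup_epochs * num_steps, 100))
    decay_steps = int(epochs * num_steps - warmup_steps)   # assumes total steps > warmup_steps

    warmup_lr = numpy.linspace(min_lr, max_lr, int(warmup_steps), endpoint=False)
    decay_lr = numpy.linspace(max_lr, min_lr, decay_steps)
    schedule = numpy.concatenate((warmup_lr, decay_lr))    # one lr value per optimizer step
    print(len(schedule), schedule[0], schedule[warmup_steps], schedule[-1])
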
@@ -528,16 +511,8 @@ class Assigner(torch.nn.Module):
         mask_in_gts = mask_in_gts.view(shape[0], shape[1], num_anchors, -1).amin(3).gt_(self.eps)
         na = pd_bboxes.shape[-2]
         gt_mask = (mask_in_gts * mask_gt).bool()  # b, max_num_obj, h*w
-        overlaps = torch.zeros(
-            [batch_size, num_max_boxes, na],
-            dtype=pd_bboxes.dtype,
-            device=pd_bboxes.device,
-        )
-        bbox_scores = torch.zeros(
-            [batch_size, num_max_boxes, na],
-            dtype=pd_scores.dtype,
-            device=pd_scores.device,
-        )
+        overlaps = torch.zeros([batch_size, num_max_boxes, na], dtype=pd_bboxes.dtype, device=pd_bboxes.device)
+        bbox_scores = torch.zeros([batch_size, num_max_boxes, na], dtype=pd_scores.dtype, device=pd_scores.device)

         ind = torch.zeros([2, batch_size, num_max_boxes], dtype=torch.long)  # 2, b, max_num_obj
         ind[0] = torch.arange(end=batch_size).view(-1, 1).expand(-1, num_max_boxes)  # b, max_num_obj
@@ -587,9 +562,7 @@ class Assigner(torch.nn.Module):
         target_labels.clamp_(0)

         target_scores = torch.zeros(
-            (target_labels.shape[0], target_labels.shape[1], self.nc),
-            dtype=torch.int64,
-            device=target_labels.device,
+            (target_labels.shape[0], target_labels.shape[1], self.nc), dtype=torch.int64, device=target_labels.device
         )
         target_scores.scatter_(2, target_labels.unsqueeze(-1), 1)

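
Note: `scatter_(2, target_labels.unsqueeze(-1), 1)` writes a 1 at each assigned class index, turning the clamped per-anchor labels into one-hot target scores of shape (batch, anchors, nc). A tiny illustration with assumed sizes:

    import torch

    # Toy sizes (assumed): batch of 2, 3 anchors, 4 classes.
    nc = 4
    target_labels = torch.tensor([[0, 2, 3], [1, 1, 0]])
    target_scores = torch.zeros(
        (target_labels.shape[0], target_labels.shape[1], nc), dtype=torch.int64, device=target_labels.device
    )
    target_scores.scatter_(2, target_labels.unsqueeze(-1), 1)
    print(target_scores[0])  # rows are one-hot: class 0, class 2, class 3
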
@@ -672,16 +645,7 @@ class BoxLoss(torch.nn.Module):
         super().__init__()
         self.dfl_ch = dfl_ch

-    def forward(
-        self,
-        pred_dist,
-        pred_bboxes,
-        anchor_points,
-        target_bboxes,
-        target_scores,
-        target_scores_sum,
-        fg_mask,
-    ):
+    def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
         # IoU loss
         weight = torch.masked_select(target_scores.sum(-1), fg_mask).unsqueeze(-1)
         iou = compute_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask])
@@ -803,13 +767,7 @@ class ComputeLoss:
         if fg_mask.sum():
             target_bboxes /= stride_tensor
             loss_box, loss_dfl = self.box_loss(
-                pred_distri,
-                pred_bboxes,
-                anchor_points,
-                target_bboxes,
-                target_scores,
-                target_scores_sum,
-                fg_mask,
+                pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
             )

             loss_box *= self.params["box"]  # box gain