Files
fed-yolo/utils/util.py

820 lines
29 KiB
Python

"""
Utility functions for yolo.
"""
import copy
import random
from time import time
import math
import numpy
import torch
import torchvision
from torch.nn.functional import cross_entropy
def setup_seed():
    """
    Seed every random number generator used in training, for reproducibility.

    Python's ``random``, NumPy's global generator and PyTorch are all seeded
    with 0, and cuDNN is switched to deterministic, non-benchmarking mode.
    """
    seed = 0
    for seeder in (random.seed, numpy.random.seed, torch.manual_seed):
        seeder(seed)
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True
def setup_multi_processes():
    """
    Configure the process start method and thread counts for data loading.

    On non-Windows platforms the multiprocessing start method is forced to
    ``fork`` (faster worker start-up). OpenCV's internal threading is disabled.
    ``OMP_NUM_THREADS`` and ``MKL_NUM_THREADS`` default to ``"1"`` unless they
    are already set, which keeps worker processes from oversubscribing the CPU.
    """
    import cv2
    from os import environ
    from platform import system

    if system() != "Windows":
        torch.multiprocessing.set_start_method("fork", force=True)

    cv2.setNumThreads(0)

    for variable in ("OMP_NUM_THREADS", "MKL_NUM_THREADS"):
        environ.setdefault(variable, "1")
def export_onnx(args):
    """
    Export ``./weights/best.pt`` to ``./weights/best.onnx`` and validate it.

    :param args: namespace providing ``input_size`` (side of the square input).

    The module stored under the checkpoint's ``"model"`` key is cast to FP32.
    It is traced on CPU with a zero image of shape
    ``(1, 3, input_size, input_size)`` and exported with opset 12. The ONNX
    file is then re-loaded, checked with ``onnx.checker`` and saved back.
    """
    import onnx  # noqa

    onnx_path = "./weights/best.onnx"
    # NOTE: weights_only=False unpickles arbitrary objects - only use trusted checkpoints.
    model = torch.load("./weights/best.pt", weights_only=False)["model"].float()
    dummy_input = torch.zeros((1, 3, args.input_size, args.input_size))
    # NOTE(review): non_max_suppression in this file reads dim 1 of the model
    # output as (4 + num_classes) and dim 2 as anchors, so naming dim 1
    # "anchors" here looks inconsistent - confirm against the detection head.
    dynamic_axes = {"outputs": {0: "batch", 1: "anchors"}}

    torch.onnx.export(
        model.cpu(),
        (dummy_input.cpu(),),
        f=onnx_path,
        verbose=False,
        opset_version=12,
        # WARNING: DNN inference with torch>=1.12 may require do_constant_folding=False
        do_constant_folding=True,
        input_names=["images"],
        output_names=["outputs"],
        dynamic_axes=dynamic_axes or None,
    )

    # Validate the exported graph, then persist the checked model.
    exported = onnx.load(onnx_path)
    onnx.checker.check_model(exported)
    onnx.save(exported, onnx_path)
    # Inference example
    # https://github.com/ultralytics/ultralytics/blob/main/ultralytics/nn/autobackend.py
def wh2xy(x):
    """
    Convert boxes from centre format ``(cx, cy, w, h)`` to corners ``(x1, y1, x2, y2)``.

    Accepts a ``torch.Tensor`` (cloned) or a NumPy array (copied). The first
    four columns are rewritten, any further columns pass through unchanged,
    and the input itself is never modified.
    """
    out = x.clone() if isinstance(x, torch.Tensor) else numpy.copy(x)
    half_w = x[:, 2] / 2
    half_h = x[:, 3] / 2
    out[:, 0] = x[:, 0] - half_w  # top-left x
    out[:, 1] = x[:, 1] - half_h  # top-left y
    out[:, 2] = x[:, 0] + half_w  # bottom-right x
    out[:, 3] = x[:, 1] + half_h  # bottom-right y
    return out
def make_anchors(x, strides, offset=0.5):
    """
    Build anchor-point centres and per-anchor strides for a list of feature maps.

    :param x: sequence of feature maps; ``x[i]`` is unpacked as ``(_, _, h, w)``.
    :param strides: stride of each feature map, in the same order as ``x``.
    :param offset: sub-cell offset added to the integer grid coordinates.
    :return: ``(anchors, stride_per_anchor)``. ``anchors`` has shape
        ``(sum(h * w), 2)`` holding ``(x, y)`` grid coordinates in row-major
        order. ``stride_per_anchor`` has shape ``(sum(h * w), 1)``. Both
        follow the dtype and device of ``x[0]``.
    """
    assert x is not None
    dtype, device = x[0].dtype, x[0].device
    points, stride_list = [], []
    for level, stride in enumerate(strides):
        _, _, height, width = x[level].shape
        grid_x = torch.arange(end=width, device=device, dtype=dtype) + offset
        grid_y = torch.arange(end=height, device=device, dtype=dtype) + offset
        grid_y, grid_x = torch.meshgrid(grid_y, grid_x, indexing="ij")
        points.append(torch.stack((grid_x, grid_y), -1).view(-1, 2))
        stride_list.append(torch.full((height * width, 1), stride, dtype=dtype, device=device))
    return torch.cat(points), torch.cat(stride_list)
def compute_metric(output, target, iou_v):
    """
    Flag which detections are true positives at each IoU threshold.

    :param output: detections with rows ``(x1, y1, x2, y2, conf, cls)``.
    :param target: labels with rows ``(cls, x1, y1, x2, y2)``.
    :param iou_v: 1-D tensor of IoU thresholds.
    :return: bool tensor ``(len(output), len(iou_v))`` on ``output.device``.

    At each threshold, (label, detection) pairs with equal class and
    IoU >= threshold are matched greedily by descending IoU. Each label and
    each detection is used at most once.
    """
    # Pairwise IoU: labels along dim 0, detections along dim 1.
    lbl_tl, lbl_br = target[:, 1:].unsqueeze(1).chunk(2, 2)
    det_tl, det_br = output[:, :4].unsqueeze(0).chunk(2, 2)
    inter = (torch.min(lbl_br, det_br) - torch.max(lbl_tl, det_tl)).clamp(0).prod(2)
    iou = inter / ((lbl_br - lbl_tl).prod(2) + (det_br - det_tl).prod(2) - inter + 1e-7)

    same_class = target[:, 0:1] == output[:, 5]
    correct = numpy.zeros((output.shape[0], iou_v.shape[0]), dtype=bool)
    for col, threshold in enumerate(iou_v):
        hits = torch.where((iou >= threshold) & same_class)
        if not hits[0].shape[0]:
            continue
        # Rows of [label, detection, iou].
        matches = torch.cat((torch.stack(hits, 1), iou[hits[0], hits[1]][:, None]), 1).cpu().numpy()
        if hits[0].shape[0] > 1:
            matches = matches[matches[:, 2].argsort()[::-1]]
            matches = matches[numpy.unique(matches[:, 1], return_index=True)[1]]
            matches = matches[numpy.unique(matches[:, 0], return_index=True)[1]]
        correct[matches[:, 1].astype(int), col] = True
    return torch.tensor(correct, dtype=torch.bool, device=output.device)
def non_max_suppression(outputs, confidence_threshold=0.001, iou_threshold=0.65):
    """
    Run class-aware non-maximum suppression on raw detector outputs.

    :param outputs: tensor ``(batch, 4 + num_classes, num_anchors)``. The first
        four channels are boxes ``(cx, cy, w, h)``; the rest are class scores.
    :param confidence_threshold: minimum class score for a candidate.
    :param iou_threshold: IoU above which lower-scoring boxes of the same class
        are suppressed.
    :return: list with one tensor ``(k, 6)`` per image, rows
        ``(x1, y1, x2, y2, score, class)``, at most 300 rows each.

    Processing stops once ``0.5 + 0.05 * batch`` seconds have elapsed. Images
    not yet processed at that point keep an empty ``(0, 6)`` result.
    """
    max_wh = 7680  # per-class coordinate offset so a single NMS call keeps classes apart
    max_det = 300  # detections kept per image
    max_nms = 30000  # most candidates handed to torchvision NMS

    batch_size = outputs.shape[0]
    num_classes = outputs.shape[1] - 4
    candidates = outputs[:, 4 : 4 + num_classes].amax(1) > confidence_threshold

    start = time()
    time_limit = 0.5 + 0.05 * batch_size  # seconds

    results = [torch.zeros((0, 6), device=outputs.device)] * batch_size
    for image_idx, preds in enumerate(outputs):
        preds = preds.transpose(0, -1)[candidates[image_idx]]  # (n, 4 + num_classes)
        if not preds.shape[0]:
            continue

        box, cls = preds.split((4, num_classes), 1)
        box = wh2xy(box)
        if num_classes > 1:
            # One row per (box, class) pair above the threshold.
            rows, cols = (cls > confidence_threshold).nonzero(as_tuple=False).T
            dets = torch.cat((box[rows], preds[rows, 4 + cols].unsqueeze(1), cols[:, None].float()), dim=1)
        else:
            conf, cols = cls.max(1, keepdim=True)
            dets = torch.cat((box, conf, cols.float()), 1)[conf.view(-1) > confidence_threshold]

        if not dets.shape[0]:
            continue

        # Keep only the highest-scoring candidates.
        dets = dets[dets[:, 4].argsort(descending=True)[:max_nms]]

        offsets = dets[:, 5:6] * max_wh
        keep = torchvision.ops.nms(dets[:, :4] + offsets, dets[:, 4], iou_threshold)
        results[image_idx] = dets[keep[:max_det]]

        if (time() - start) > time_limit:
            break  # time budget exhausted
    return results
def smooth(y, f=0.1):
    """
    Smooth a 1-D sequence with a centred box (moving-average) filter.

    :param y: 1-D array-like of values.
    :param f: filter width as a fraction of ``len(y)``.
    :return: NumPy array with the same length as ``y``. The edges are padded
        with the first/last value before filtering. An empty input gives an
        empty float array.
    """
    if len(y) == 0:
        return numpy.zeros(0)
    nf = round(len(y) * f * 2) // 2 + 1  # number of filter elements
    if nf % 2 == 0:
        # An even window makes the "valid" convolution below return
        # len(y) + 1 samples; widen it by one so the output length is len(y).
        nf += 1
    p = numpy.ones(nf // 2)  # ones padding
    yp = numpy.concatenate((p * y[0], y, p * y[-1]), 0)  # y padded
    return numpy.convolve(yp, numpy.ones(nf) / nf, mode="valid")  # y-smoothed
def plot_pr_curve(px, py, ap, names, save_dir):
    """
    Save a precision-recall plot to ``save_dir``.

    :param px: recall sample points (x axis).
    :param py: list of per-class precision arrays, each sampled at ``px``.
    :param ap: per-class AP array; column 0 is shown in the labels (AP@0.5).
    :param names: class names. Per-class labelled curves are drawn only for
        1 to 20 names; otherwise all curves are drawn in grey.
    :param save_dir: output image path.
    """
    from matplotlib import pyplot

    figure, axis = pyplot.subplots(1, 1, figsize=(9, 6), tight_layout=True)
    precision = numpy.stack(py, axis=1)  # (len(px), num_curves)

    if 0 < len(names) < 21:
        for idx, curve in enumerate(precision.T):
            axis.plot(px, curve, linewidth=1, label=f"{names[idx]} {ap[idx, 0]:.3f}")  # plot(recall, precision)
    else:
        axis.plot(px, precision, linewidth=1, color="grey")  # plot(recall, precision)

    mean_label = "all classes %.3f mAP@0.5" % ap[:, 0].mean()
    axis.plot(px, precision.mean(1), linewidth=3, color="blue", label=mean_label)

    axis.set_xlabel("Recall")
    axis.set_ylabel("Precision")
    axis.set_xlim(0, 1)
    axis.set_ylim(0, 1)
    axis.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
    axis.set_title("Precision-Recall Curve")
    figure.savefig(save_dir, dpi=250)
    pyplot.close(figure)
def plot_curve(px, py, names, save_dir, x_label="Confidence", y_label="Metric"):
    """
    Save a per-class metric-versus-confidence plot to ``save_dir``.

    :param px: x-axis sample points.
    :param py: 2-D array with one row per class, sampled at ``px``.
    :param names: class names. A per-class legend is drawn only for 1 to 20
        names; otherwise the curves are drawn in grey.
    :param save_dir: output image path.
    :param x_label: x-axis label.
    :param y_label: y-axis label (also used in the title).

    The bold blue curve is the class mean, smoothed. Its legend reports the
    maximum and the ``px`` value where that maximum occurs.
    """
    from matplotlib import pyplot

    figure, axis = pyplot.subplots(1, 1, figsize=(9, 6), tight_layout=True)
    if 0 < len(names) < 21:
        for idx, curve in enumerate(py):
            axis.plot(px, curve, linewidth=1, label=f"{names[idx]}")  # plot(confidence, metric)
    else:
        axis.plot(px, py.T, linewidth=1, color="grey")  # plot(confidence, metric)

    mean_curve = smooth(py.mean(0), f=0.05)
    best = mean_curve.argmax()
    summary = f"all classes {mean_curve.max():.3f} at {px[best]:.3f}"
    axis.plot(px, mean_curve, linewidth=3, color="blue", label=summary)

    axis.set_xlabel(x_label)
    axis.set_ylabel(y_label)
    axis.set_xlim(0, 1)
    axis.set_ylim(0, 1)
    axis.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
    axis.set_title(f"{y_label}-Confidence Curve")
    figure.savefig(save_dir, dpi=250)
    pyplot.close(figure)
def compute_ap(tp, conf, output, target, plot=False, names=(), eps=1e-16):
    """
    Compute per-class average precision and summary detection metrics.

    Source: https://github.com/rafaelpadilla/Object-Detection-Metrics.

    :param tp: true-positive flags, NumPy array ``(n, t)`` with one column per
        IoU threshold. Column 0 is reported as AP@0.5.
    :param conf: detection confidences in ``[0, 1]``, NumPy array ``(n,)``.
    :param output: predicted class of each detection, NumPy array ``(n,)``.
    :param target: class of each ground-truth label, NumPy array.
    :param plot: when True, write PR/F1/P/R curves under ``./weights/``.
    :param names: class names indexed by class id (used in plot legends).
    :param eps: small constant guarding divisions.
    :return: ``(tp, fp, m_pre, m_rec, map50, mean_ap)``:
        - ``tp``/``fp``: per-class counts at the confidence that maximises the
          smoothed mean F1.
        - ``m_pre``/``m_rec``: mean precision/recall at that confidence.
        - ``map50``: mean AP over column 0.
        - ``mean_ap``: mean AP over all columns.
    """
    # Sort detections by decreasing confidence.
    order = numpy.argsort(-conf)
    tp, conf, output = tp[order], conf[order], output[order]

    # Classes present in the ground truth and the number of labels of each.
    unique_classes, nt = numpy.unique(target, return_counts=True)
    nc = unique_classes.shape[0]

    # Precision/recall per class sampled at 1000 confidences; AP per IoU column.
    p = numpy.zeros((nc, 1000))
    r = numpy.zeros((nc, 1000))
    ap = numpy.zeros((nc, tp.shape[1]))
    px, py = numpy.linspace(start=0, stop=1, num=1000), []  # for plotting
    for ci, c in enumerate(unique_classes):
        is_c = output == c
        nl = nt[ci]  # number of labels
        no = is_c.sum()  # number of outputs
        if no == 0 or nl == 0:
            if plot:
                # No detections for this class, so its AP stays 0. Add a zero
                # precision curve so py stays aligned with the ap rows and names.
                py.append(numpy.zeros(px.shape[0]))
            continue

        # Accumulate FPs and TPs
        fpc = (1 - tp[is_c]).cumsum(0)
        tpc = tp[is_c].cumsum(0)

        # Recall curve. numpy.interp needs increasing xp, but confidences are
        # sorted in descending order, so negate both x and xp.
        recall = tpc / (nl + eps)
        r[ci] = numpy.interp(-px, -conf[is_c], recall[:, 0], left=0)

        # Precision curve (p at pr_score)
        precision = tpc / (tpc + fpc)
        p[ci] = numpy.interp(-px, -conf[is_c], precision[:, 0], left=1)

        # AP from the recall-precision curve, one value per IoU column
        for j in range(tp.shape[1]):
            m_rec = numpy.concatenate(([0.0], recall[:, j], [1.0]))
            m_pre = numpy.concatenate(([1.0], precision[:, j], [0.0]))
            # Precision envelope: make precision non-increasing in recall
            m_pre = numpy.flip(numpy.maximum.accumulate(numpy.flip(m_pre)))
            x = numpy.linspace(start=0, stop=1, num=101)  # 101-point interp (COCO)
            # numpy.trapz is deprecated in numpy 2.0.0 or after version, use numpy.trapezoid instead
            ap[ci, j] = numpy.trapezoid(numpy.interp(x, m_rec, m_pre), x)  # integrate
            if plot and j == 0:
                py.append(numpy.interp(px, m_rec, m_pre))  # precision at mAP@0.5

    # Compute F1 (harmonic mean of precision and recall)
    f1 = 2 * p * r / (p + r + eps)
    # Without any ground-truth class there is nothing to plot (and stacking an
    # empty curve list would raise).
    if plot and nc:
        names = dict(enumerate(names))  # to dict
        names = [v for k, v in names.items() if k in unique_classes]  # list: only classes that have data
        plot_pr_curve(px, py, ap, names, save_dir="./weights/PR_curve.png")
        plot_curve(px, f1, names, save_dir="./weights/F1_curve.png", y_label="F1")
        plot_curve(px, p, names, save_dir="./weights/P_curve.png", y_label="Precision")
        plot_curve(px, r, names, save_dir="./weights/R_curve.png", y_label="Recall")

    best = smooth(f1.mean(0), 0.1).argmax()  # max F1 index
    p, r, f1 = p[:, best], r[:, best], f1[:, best]
    tp = (r * nt).round()  # true positives
    fp = (tp / (p + eps) - tp).round()  # false positives
    ap50, ap = ap[:, 0], ap.mean(1)  # AP@0.5, AP@0.5:0.95
    m_pre, m_rec = p.mean(), r.mean()
    map50, mean_ap = ap50.mean(), ap.mean()
    return tp, fp, m_pre, m_rec, map50, mean_ap
def compute_iou(box1, box2, eps=1e-7):
    """
    Complete-IoU (CIoU) between boxes in ``(x1, y1, x2, y2)`` format.

    ``box1`` and ``box2`` must broadcast against each other, with the four
    coordinates in the last dimension. The result keeps a trailing dimension
    of size 1 and equals ``IoU - (rho^2 / c^2 + v * alpha)``:
    - ``rho``: distance between the box centres.
    - ``c``: diagonal of the smallest enclosing box.
    - ``v``: aspect-ratio consistency term.
    - ``alpha``: trade-off weight, computed without gradient.
    """
    ax1, ay1, ax2, ay2 = box1.chunk(4, -1)
    bx1, by1, bx2, by2 = box2.chunk(4, -1)
    wa, ha = ax2 - ax1, ay2 - ay1 + eps
    wb, hb = bx2 - bx1, by2 - by1 + eps

    # Plain IoU.
    overlap_w = (ax2.minimum(bx2) - ax1.maximum(bx1)).clamp(0)
    overlap_h = (ay2.minimum(by2) - ay1.maximum(by1)).clamp(0)
    inter = overlap_w * overlap_h
    union = wa * ha + wb * hb - inter + eps
    iou = inter / union

    # Squared diagonal of the smallest enclosing box.
    enclose_w = ax2.maximum(bx2) - ax1.minimum(bx1)
    enclose_h = ay2.maximum(by2) - ay1.minimum(by1)
    diag2 = enclose_w**2 + enclose_h**2 + eps
    # Squared centre distance (coordinate sums are twice the centres, hence / 4).
    centre2 = ((bx1 + bx2 - ax1 - ax2) ** 2 + (by1 + by2 - ay1 - ay2) ** 2) / 4
    # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
    v = (4 / math.pi**2) * (torch.atan(wb / hb) - torch.atan(wa / ha)).pow(2)
    with torch.no_grad():
        alpha = v / (v - iou + (1 + eps))
    return iou - (centre2 / diag2 + v * alpha)  # CIoU
def strip_optimizer(filename):
    """
    Convert a checkpoint's model to frozen FP16 weights, rewriting the file in place.

    The checkpoint is loaded on CPU. ``checkpoint["model"]`` is cast to half
    precision and all of its parameters get ``requires_grad=False``. The whole
    checkpoint (other entries unchanged) is then saved back to ``filename``.
    """
    # NOTE: weights_only=False unpickles arbitrary objects - only use trusted files.
    checkpoint = torch.load(filename, map_location="cpu", weights_only=False)
    model = checkpoint["model"]
    model.half()  # to FP16
    for parameter in model.parameters():
        parameter.requires_grad = False
    torch.save(checkpoint, f=filename)
def clip_gradients(model, max_norm=10.0):
    """Rescale ``model``'s gradients in place so their global L2 norm is at most ``max_norm``."""
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)
def load_weight(model, ckpt):
    """
    Copy shape-compatible weights from a checkpoint file into ``model``.

    :param model: ``torch.nn.Module`` updated in place.
    :param ckpt: path to a checkpoint whose ``"model"`` entry is a module.
    :return: ``model``.

    Only tensors whose key exists in ``model`` with an identical shape are
    loaded (non-strict). Every other parameter of ``model`` is left unchanged.
    """
    dst = model.state_dict()
    # Unpickle onto CPU so checkpoints saved from a GPU also load on CPU-only
    # hosts; the module is moved to CPU right after anyway.
    # NOTE: weights_only=False unpickles arbitrary objects - only load trusted files.
    src = torch.load(ckpt, map_location="cpu", weights_only=False)["model"].float().cpu()
    compatible = {k: v for k, v in src.state_dict().items() if k in dst and v.shape == dst[k].shape}
    model.load_state_dict(state_dict=compatible, strict=False)
    return model
def set_params(model, decay):
    """
    Split trainable parameters into two optimizer groups.

    - No weight decay (``weight_decay=0``): every ``bias``, and the ``weight``
      of normalisation layers (any ``torch.nn`` attribute whose name contains
      ``"Norm"``).
    - ``weight_decay=decay``: all other trainable parameters.

    Parameters with ``requires_grad=False`` are left out.
    """
    norm_types = tuple(value for name, value in torch.nn.__dict__.items() if "Norm" in name)
    no_decay, with_decay = [], []
    for module in model.modules():
        for name, param in module.named_parameters(recurse=False):
            if not param.requires_grad:
                continue
            exempt = name == "bias" or (name == "weight" and isinstance(module, norm_types))
            if exempt:
                no_decay.append(param)
            else:
                with_decay.append(param)
    return [{"params": no_decay, "weight_decay": 0.00}, {"params": with_decay, "weight_decay": decay}]
def plot_lr(args, optimizer, scheduler, num_steps):
    """
    Simulate the learning-rate schedule and save it to ``./weights/lr.png``.

    :param args: namespace providing ``epochs``.
    :param optimizer: optimizer whose ``param_groups`` the scheduler writes to.
    :param scheduler: object exposing ``step(step, optimizer)``, e.g. ``CosineLR``.
    :param num_steps: optimizer steps per epoch.

    The learning rates in ``optimizer.param_groups`` are restored afterwards,
    even if the simulation raises, so plotting leaves the optimizer unchanged.
    """
    from matplotlib import pyplot

    # A shallow copy of the optimizer would still share ``param_groups`` with
    # the real optimizer, so stepping it overwrote the real learning rates.
    # Snapshot the rates instead and put them back when done.
    scheduler = copy.copy(scheduler)
    saved_lrs = [group["lr"] for group in optimizer.param_groups]
    y = []
    try:
        for epoch in range(args.epochs):
            for i in range(num_steps):
                step = i + num_steps * epoch
                scheduler.step(step, optimizer)
                y.append(optimizer.param_groups[0]["lr"])
    finally:
        for group, lr in zip(optimizer.param_groups, saved_lrs):
            group["lr"] = lr

    pyplot.plot(y, ".-", label="LR")
    pyplot.xlabel("step")
    pyplot.ylabel("LR")
    pyplot.grid()
    pyplot.xlim(0, args.epochs * num_steps)
    pyplot.ylim(0)
    pyplot.savefig("./weights/lr.png", dpi=200)
    pyplot.close()
class CosineLR:
    """
    Per-step learning-rate table: linear warm-up followed by cosine decay.

    Warm-up lasts ``int(max(warmup_epochs * num_steps, 100))`` steps and rises
    linearly from ``min_lr`` to ``max_lr`` (inclusive). The remaining
    ``epochs * num_steps - warmup`` steps follow a half cosine from just below
    ``max_lr`` down to ``min_lr``.
    """

    def __init__(self, args, params, num_steps):
        max_lr, min_lr = params["max_lr"], params["min_lr"]
        warmup_steps = int(max(params["warmup_epochs"] * num_steps, 100))
        decay_steps = int(args.epochs * num_steps - warmup_steps)

        warmup = numpy.linspace(min_lr, max_lr, int(warmup_steps))
        amplitude = 0.5 * (max_lr - min_lr)
        decay = [
            min_lr + amplitude * (1 + math.cos(math.pi * s / decay_steps))
            for s in range(1, decay_steps + 1)
        ]
        self.total_lr = numpy.concatenate((warmup, decay))

    def step(self, step, optimizer):
        """Write the learning rate of global step ``step`` into every param group."""
        for group in optimizer.param_groups:
            group["lr"] = self.total_lr[step]
class LinearLR:
    """
    Per-step learning-rate table: linear warm-up followed by linear decay.

    Warm-up lasts ``int(max(warmup_epochs * num_steps, 100))`` steps, going
    from ``min_lr`` towards ``max_lr`` (``max_lr`` itself excluded). Decay
    runs linearly from ``max_lr`` to ``min_lr`` (both included) over the
    remaining steps, and always lasts at least one step.
    """

    def __init__(self, args, params, num_steps):
        max_lr, min_lr = params["max_lr"], params["min_lr"]
        warmup_steps = int(max(params["warmup_epochs"] * num_steps, 100))
        remaining = int(args.epochs * num_steps - warmup_steps)
        decay_steps = max(1, remaining)

        warmup = numpy.linspace(min_lr, max_lr, int(warmup_steps), endpoint=False)
        decay = numpy.linspace(max_lr, min_lr, decay_steps)
        self.total_lr = numpy.concatenate((warmup, decay))

    def step(self, step, optimizer):
        """Write the learning rate of global step ``step`` into every param group."""
        for group in optimizer.param_groups:
            group["lr"] = self.total_lr[step]
class EMA:
    """
    Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models
    Keeps a moving average of everything in the model state_dict (parameters and buffers)
    For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage

    The decay ramps up with the number of updates: ``decay * (1 - exp(-updates / tau))``.
    Models exposing a ``.module`` attribute (DataParallel/DDP-style wrappers)
    are unwrapped, both at construction and in ``update``.
    """

    def __init__(self, model, decay=0.9999, tau=2000, updates=0):
        # Unwrap exactly as update() does. Otherwise the EMA copy's state_dict
        # keys carry a "module." prefix and update() fails with a KeyError.
        if hasattr(model, "module"):
            model = model.module
        # Create EMA
        self.ema = copy.deepcopy(model).eval()  # FP32 EMA
        self.updates = updates  # number of EMA updates
        # decay exponential ramp (to help early epochs)
        self.decay = lambda x: decay * (1 - math.exp(-x / tau))
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def update(self, model):
        """Blend ``model``'s floating-point state into the EMA copy: ``ema = d * ema + (1 - d) * model``."""
        if hasattr(model, "module"):
            model = model.module
        # Update EMA parameters
        with torch.no_grad():
            self.updates += 1
            d = self.decay(self.updates)
            msd = model.state_dict()  # model state_dict
            for k, v in self.ema.state_dict().items():
                if v.dtype.is_floating_point:
                    v *= d
                    v += (1 - d) * msd[k].detach()
class AverageMeter:
    """Running weighted mean that silently skips NaN samples."""

    def __init__(self):
        self.num = 0  # accumulated weight
        self.sum = 0  # weighted sum of the values
        self.avg = 0  # sum / num

    def update(self, v, n):
        """Add value ``v`` with weight ``n``, unless ``float(v)`` is NaN."""
        if math.isnan(float(v)):
            return
        self.num = self.num + n
        self.sum = self.sum + v * n
        self.avg = self.sum / self.num
class Assigner(torch.nn.Module):
    """
    Assign ground-truth boxes to anchor points for loss computation.

    Selection rule implemented in ``forward``:
    - An anchor is a candidate for a ground-truth box when its point lies
      inside that box by more than ``eps`` on every side.
    - Candidates are ranked by ``score ** alpha * iou ** beta``. Here ``score``
      is the predicted probability of the box's class and ``iou`` is the CIoU
      from ``compute_iou``, clamped at 0.
    - The ``top_k`` best candidates per box are kept.
    - An anchor picked by several boxes keeps only the box it overlaps most.

    NOTE(review): this appears to follow the task-aligned assignment of TOOD -
    confirm before citing that name.

    :param nc: number of classes (last dimension of the returned score targets).
    :param top_k: anchors selected per ground-truth box.
    :param alpha: exponent on the class score in the alignment metric.
    :param beta: exponent on the IoU in the alignment metric.
    :param eps: tolerance for the inside-box test and for normalisation.
    """

    def __init__(self, nc=80, top_k=13, alpha=1.0, beta=6.0, eps=1e-9):
        super().__init__()
        self.top_k = top_k  # anchors selected per ground-truth box
        self.nc = nc  # number of classes
        self.alpha = alpha  # exponent on the class score
        self.beta = beta  # exponent on the IoU
        self.eps = eps

    @torch.no_grad()
    def forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
        """
        Build per-anchor regression and classification targets.

        Shapes below are those used by ``ComputeLoss`` in this file (b = batch
        size, A = number of anchors, M = max ground-truth boxes per image).

        :param pd_scores: predicted class probabilities ``(b, A, nc)``.
        :param pd_bboxes: predicted boxes ``(b, A, 4)`` as ``(x1, y1, x2, y2)``,
            in the same units as ``gt_bboxes``.
        :param anc_points: anchor points ``(A, 2)``, same units.
        :param gt_labels: ground-truth class ids ``(b, M, 1)``.
        :param gt_bboxes: ground-truth boxes ``(b, M, 4)`` as ``(x1, y1, x2, y2)``.
        :param mask_gt: ``(b, M, 1)``; non-zero for real boxes, 0 for padding rows.
        :return: ``(target_bboxes, target_scores, fg_mask)``:
            - ``target_bboxes``: ``(b, A, 4)``, the assigned box per anchor.
            - ``target_scores``: ``(b, A, nc)``, one-hot on the assigned class,
              scaled by the normalised alignment metric; zero for background.
            - ``fg_mask``: ``(b, A)`` bool, True for foreground anchors.
            When ``M == 0``, all three are zeros and ``fg_mask`` has
            ``pd_scores``' dtype instead of bool.
        """
        batch_size = pd_scores.size(0)
        num_max_boxes = gt_bboxes.size(1)
        # No ground-truth boxes at all: every anchor is background.
        if num_max_boxes == 0:
            device = gt_bboxes.device
            return (
                torch.zeros_like(pd_bboxes).to(device),
                torch.zeros_like(pd_scores).to(device),
                torch.zeros_like(pd_scores[..., 0]).to(device),
            )
        num_anchors = anc_points.shape[0]
        shape = gt_bboxes.shape
        # Inside-box test: the smallest of the four anchor-to-side distances must exceed eps.
        lt, rb = gt_bboxes.view(-1, 1, 4).chunk(2, 2)
        mask_in_gts = torch.cat((anc_points[None] - lt, rb - anc_points[None]), dim=2)
        mask_in_gts = mask_in_gts.view(shape[0], shape[1], num_anchors, -1).amin(3).gt_(self.eps)
        na = pd_bboxes.shape[-2]
        gt_mask = (mask_in_gts * mask_gt).bool()  # b, max_num_obj, h*w
        overlaps = torch.zeros(
            [batch_size, num_max_boxes, na],
            dtype=pd_bboxes.dtype,
            device=pd_bboxes.device,
        )
        bbox_scores = torch.zeros(
            [batch_size, num_max_boxes, na],
            dtype=pd_scores.dtype,
            device=pd_scores.device,
        )
        # For every (image, gt box), gather each anchor's predicted score for that box's class.
        ind = torch.zeros([2, batch_size, num_max_boxes], dtype=torch.long)  # 2, b, max_num_obj
        ind[0] = torch.arange(end=batch_size).view(-1, 1).expand(-1, num_max_boxes)  # b, max_num_obj
        ind[1] = gt_labels.squeeze(-1)  # b, max_num_obj
        bbox_scores[gt_mask] = pd_scores[ind[0], :, ind[1]][gt_mask]  # b, max_num_obj, h*w
        # CIoU between each gt box and the predicted box of each candidate anchor (negatives clamped to 0).
        pd_boxes = pd_bboxes.unsqueeze(1).expand(-1, num_max_boxes, -1, -1)[gt_mask]
        gt_boxes = gt_bboxes.unsqueeze(2).expand(-1, -1, na, -1)[gt_mask]
        overlaps[gt_mask] = compute_iou(gt_boxes, pd_boxes).squeeze(-1).clamp_(0)
        align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta)
        # Top-k anchors per gt box by alignment metric; rows of padded gt boxes are reset to index 0.
        top_k_mask = mask_gt.expand(-1, -1, self.top_k).bool()
        top_k_metrics, top_k_indices = torch.topk(align_metric, self.top_k, dim=-1, largest=True)
        # NOTE(review): top_k_mask is assigned just above, so this branch never runs.
        if top_k_mask is None:
            top_k_mask = (top_k_metrics.max(-1, keepdim=True)[0] > self.eps).expand_as(top_k_indices)
        top_k_indices.masked_fill_(~top_k_mask, 0)
        # Turn the top-k indices into a per-anchor mask. Indices repeated by the
        # fill above reach counts > 1 and are zeroed.
        mask_top_k = torch.zeros(align_metric.shape, dtype=torch.int8, device=top_k_indices.device)
        ones = torch.ones_like(top_k_indices[:, :, :1], dtype=torch.int8, device=top_k_indices.device)
        for k in range(self.top_k):
            mask_top_k.scatter_add_(-1, top_k_indices[:, :, k : k + 1], ones)
        mask_top_k.masked_fill_(mask_top_k > 1, 0)
        mask_top_k = mask_top_k.to(align_metric.dtype)
        # Positive (gt, anchor) pairs: selected in top-k, anchor inside the box, gt row not padding.
        mask_pos = mask_top_k * mask_in_gts * mask_gt
        fg_mask = mask_pos.sum(-2)
        # An anchor claimed by more than one gt box keeps only the gt with the highest IoU.
        if fg_mask.max() > 1:
            mask_multi_gts = (fg_mask.unsqueeze(1) > 1).expand(-1, num_max_boxes, -1)
            max_overlaps_idx = overlaps.argmax(1)
            is_max_overlaps = torch.zeros(mask_pos.shape, dtype=mask_pos.dtype, device=mask_pos.device)
            is_max_overlaps.scatter_(1, max_overlaps_idx.unsqueeze(1), 1)
            mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos).float()
            fg_mask = mask_pos.sum(-2)
        target_gt_idx = mask_pos.argmax(-2)  # assigned gt index per anchor (0 for background anchors)
        # Assigned target
        index = torch.arange(end=batch_size, dtype=torch.int64, device=gt_labels.device)[..., None]
        target_index = target_gt_idx + index * num_max_boxes  # flat index into the (b * max_num_obj) rows
        target_labels = gt_labels.long().flatten()[target_index]
        target_bboxes = gt_bboxes.view(-1, gt_bboxes.shape[-1])[target_index]
        # Assigned target scores
        target_labels.clamp_(0)
        target_scores = torch.zeros(
            (target_labels.shape[0], target_labels.shape[1], self.nc),
            dtype=torch.int64,
            device=target_labels.device,
        )
        target_scores.scatter_(2, target_labels.unsqueeze(-1), 1)  # one-hot on the assigned class
        fg_scores_mask = fg_mask[:, :, None].repeat(1, 1, self.nc)
        target_scores = torch.where(fg_scores_mask > 0, target_scores, 0)  # background anchors -> all zeros
        # Normalize
        # Rescale each positive anchor's alignment metric so that, for every gt
        # box, its best-aligned anchor receives the highest IoU among that box's
        # positives. The one-hot score target is then multiplied by this value.
        align_metric *= mask_pos
        pos_align_metrics = align_metric.amax(dim=-1, keepdim=True)
        pos_overlaps = (overlaps * mask_pos).amax(dim=-1, keepdim=True)
        norm_align_metric = (align_metric * pos_overlaps / (pos_align_metrics + self.eps)).amax(-2).unsqueeze(-1)
        target_scores = target_scores * norm_align_metric
        return target_bboxes, target_scores, fg_mask.bool()
class QFL(torch.nn.Module):
    """
    Quality Focal Loss, computed element-wise with no reduction.

    The BCE-with-logits loss is multiplied by ``|targets - sigmoid(outputs)| ** beta``.
    """

    def __init__(self, beta=2.0):
        super().__init__()
        self.beta = beta
        self.bce_loss = torch.nn.BCEWithLogitsLoss(reduction="none")

    def forward(self, outputs, targets):
        modulator = (targets - outputs.sigmoid()).abs().pow(self.beta)
        return modulator * self.bce_loss(outputs, targets)
class VFL(torch.nn.Module):
    """
    Varifocal-style loss, computed element-wise with no reduction.

    The BCE-with-logits loss is multiplied by a per-element weight:
    - Positives (``targets > 0``): weighted by the target value itself when
      ``iou_weighted``, else by 1.
    - Negatives: weighted by ``alpha * |sigmoid(outputs) - targets| ** gamma``.
    """

    def __init__(self, alpha=0.75, gamma=2.00, iou_weighted=True):
        super().__init__()
        assert alpha >= 0.0
        self.alpha = alpha
        self.gamma = gamma
        self.iou_weighted = iou_weighted
        self.bce_loss = torch.nn.BCEWithLogitsLoss(reduction="none")

    def forward(self, outputs, targets):
        assert outputs.size() == targets.size()
        targets = targets.type_as(outputs)
        positive = (targets > 0.0).float()
        negative = (targets <= 0.0).float()
        negative_weight = self.alpha * (outputs.sigmoid() - targets).abs().pow(self.gamma) * negative
        positive_weight = targets * positive if self.iou_weighted else positive
        focal_weight = positive_weight + negative_weight
        return self.bce_loss(outputs, targets) * focal_weight
class FocalLoss(torch.nn.Module):
    """
    Sigmoid focal loss, computed element-wise with no reduction.

    The BCE-with-logits loss is multiplied by:
    - the class-balance factor ``alpha`` (positives) or ``1 - alpha``
      (negatives), when ``alpha > 0``;
    - ``(1 - p_t) ** gamma``, when ``gamma > 0``.
    """

    def __init__(self, alpha=0.25, gamma=1.5):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.bce_loss = torch.nn.BCEWithLogitsLoss(reduction="none")

    def forward(self, outputs, targets):
        loss = self.bce_loss(outputs, targets)
        if self.alpha > 0:
            balance = targets * self.alpha + (1 - targets) * (1 - self.alpha)
            loss *= balance
        if self.gamma > 0:
            prob = outputs.sigmoid()
            p_t = targets * prob + (1 - targets) * (1 - prob)  # probability of the true label
            loss *= (1.0 - p_t) ** self.gamma
        return loss
class BoxLoss(torch.nn.Module):
    """
    CIoU box loss plus Distribution Focal Loss (DFL) over foreground anchors.

    :param dfl_ch: highest DFL bin index. Each box side is predicted as a
        distribution over ``dfl_ch + 1`` bins.
    """

    def __init__(self, dfl_ch):
        super().__init__()
        self.dfl_ch = dfl_ch

    def forward(
        self,
        pred_dist,
        pred_bboxes,
        anchor_points,
        target_bboxes,
        target_scores,
        target_scores_sum,
        fg_mask,
    ):
        """
        Return ``(loss_box, loss_dfl)``.

        Each foreground anchor is weighted by the sum of its target scores, and
        both totals are divided by ``target_scores_sum``.
        """
        weight = torch.masked_select(target_scores.sum(-1), fg_mask).unsqueeze(-1)

        # CIoU loss on foreground anchors
        ciou = compute_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask])
        loss_box = ((1.0 - ciou) * weight).sum() / target_scores_sum

        # DFL: distances from each anchor to the four target box sides. They are
        # clamped just below the last bin so the right neighbour bin stays valid.
        top_left, bottom_right = target_bboxes.chunk(2, -1)
        distances = torch.cat((anchor_points - top_left, bottom_right - anchor_points), -1)
        distances = distances.clamp(0, self.dfl_ch - 0.01)
        loss_dfl = self.df_loss(pred_dist[fg_mask].view(-1, self.dfl_ch + 1), distances[fg_mask])
        loss_dfl = (loss_dfl * weight).sum() / target_scores_sum
        return loss_box, loss_dfl

    @staticmethod
    def df_loss(pred_dist, target):
        """
        Distribution Focal Loss (DFL), https://ieeexplore.ieee.org/document/9792391.

        A continuous target is split between its two neighbouring integer bins,
        with linear-interpolation weights, and each bin gets a cross-entropy
        term. Losses are averaged over the last dimension of ``target``.
        """
        left = target.long()  # target left
        right = left + 1  # target right
        w_left = right - target  # weight left
        w_right = 1 - w_left  # weight right
        loss_left = cross_entropy(pred_dist, left.view(-1), reduction="none").view(left.shape)
        loss_right = cross_entropy(pred_dist, right.view(-1), reduction="none").view(left.shape)
        return (loss_left * w_left + loss_right * w_right).mean(-1, keepdim=True)
class ComputeLoss:
    """
    Detection training loss: class BCE, CIoU box loss and DFL.

    :param model: detection model, optionally wrapped (exposing ``.module``).
        Its ``head`` provides ``stride``, ``nc``, ``no`` and ``ch``.
    :param params: mapping with the gains ``"box"``, ``"cls"`` and ``"dfl"``.
    """

    def __init__(self, model, params):
        if hasattr(model, "module"):
            model = model.module
        device = next(model.parameters()).device
        head = model.head  # Head() module

        self.params = params
        self.stride = head.stride
        self.nc = head.nc
        self.no = head.no
        self.reg_max = head.ch
        self.device = device

        self.box_loss = BoxLoss(head.ch - 1).to(device)
        self.cls_loss = torch.nn.BCEWithLogitsLoss(reduction="none")
        self.assigner = Assigner(nc=self.nc, top_k=10, alpha=0.5, beta=6.0)
        self.project = torch.arange(head.ch, dtype=torch.float, device=device)  # bin indices 0..ch-1

    def box_decode(self, anchor_points, pred_dist):
        """
        Decode DFL logits ``(b, a, 4 * reg_max)`` into boxes ``(x1, y1, x2, y2)``.

        Each side's distance is the expected bin index under the softmax, and
        the resulting boxes are in the same units as ``anchor_points``.
        """
        batch, anchors, channels = pred_dist.shape
        probs = pred_dist.view(batch, anchors, 4, channels // 4).softmax(3)
        distances = probs.matmul(self.project.type(probs.dtype))
        lt, rb = distances.chunk(2, -1)
        return torch.cat(tensors=(anchor_points - lt, anchor_points + rb), dim=-1)

    def _build_targets(self, targets, batch_size, input_size):
        """
        Pack flat labels into a zero-padded ``(batch_size, max_labels, 5)`` tensor.

        Each row is ``(cls, x1, y1, x2, y2)``. The ``"box"`` entries are
        ``(cx, cy, w, h)`` scaled by the input width/height, and padding rows
        stay all-zero.
        """
        idx = targets["idx"].view(-1, 1)
        cls = targets["cls"].view(-1, 1)
        labels = torch.cat((idx, cls, targets["box"]), dim=1).to(self.device)
        if labels.shape[0] == 0:
            return torch.zeros(batch_size, 0, 5, device=self.device)

        image_ids = labels[:, 0]
        _, counts = image_ids.unique(return_counts=True)
        counts = counts.to(dtype=torch.int32)
        gt = torch.zeros(batch_size, counts.max(), 5, device=self.device)
        for b in range(batch_size):
            selected = image_ids == b
            count = selected.sum()
            if count:
                gt[b, :count] = labels[selected, 1:]

        # Scale (cx, cy, w, h) to pixels, then convert to (x1, y1, x2, y2).
        xywh = gt[..., 1:5].mul_(input_size[[1, 0, 1, 0]])
        half_w = xywh[..., 2] / 2
        half_h = xywh[..., 3] / 2
        xyxy = torch.empty_like(xywh)
        xyxy[..., 0] = xywh[..., 0] - half_w  # top left x
        xyxy[..., 1] = xywh[..., 1] - half_h  # top left y
        xyxy[..., 2] = xywh[..., 0] + half_w  # bottom right x
        xyxy[..., 3] = xywh[..., 1] + half_h  # bottom right y
        gt[..., 1:5] = xyxy
        return gt

    def __call__(self, outputs, targets):
        """
        Return ``(loss_box, loss_cls, loss_dfl)``, each multiplied by its gain.

        :param outputs: list of per-level head outputs, each of shape
            ``(b, no, h, w)``.
        :param targets: dict with ``"idx"`` (image index), ``"cls"`` and
            ``"box"`` (normalised ``(cx, cy, w, h)``) per label.
        """
        batch = outputs[0].shape[0]
        merged = torch.cat([level.view(batch, self.no, -1) for level in outputs], dim=2)
        pred_distri, pred_scores = merged.split(split_size=(self.reg_max * 4, self.nc), dim=1)
        pred_scores = pred_scores.permute(0, 2, 1).contiguous()
        pred_distri = pred_distri.permute(0, 2, 1).contiguous()

        data_type = pred_scores.dtype
        batch_size = pred_scores.shape[0]
        input_size = torch.tensor(outputs[0].shape[2:], device=self.device, dtype=data_type) * self.stride[0]
        anchor_points, stride_tensor = make_anchors(outputs, self.stride, offset=0.5)

        gt = self._build_targets(targets, batch_size, input_size)
        gt_labels, gt_bboxes = gt.split((1, 4), 2)
        mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)  # real boxes vs zero padding

        pred_bboxes = self.box_decode(anchor_points, pred_distri)
        target_bboxes, target_scores, fg_mask = self.assigner(
            pred_scores.detach().sigmoid(),
            (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
            anchor_points * stride_tensor,
            gt_labels,
            gt_bboxes,
            mask_gt,
        )
        target_scores_sum = max(target_scores.sum(), 1)

        loss_cls = self.cls_loss(pred_scores, target_scores.to(data_type)).sum() / target_scores_sum  # BCE

        loss_box = torch.zeros(1, device=self.device)
        loss_dfl = torch.zeros(1, device=self.device)
        if fg_mask.sum():
            target_bboxes /= stride_tensor  # back to grid units
            loss_box, loss_dfl = self.box_loss(
                pred_distri,
                pred_bboxes,
                anchor_points,
                target_bboxes,
                target_scores,
                target_scores_sum,
                fg_mask,
            )

        loss_box *= self.params["box"]  # box gain
        loss_cls *= self.params["cls"]  # cls gain
        loss_dfl *= self.params["dfl"]  # dfl gain
        return loss_box, loss_cls, loss_dfl