FedAvg and YOLOv11 training
config/coco_cfg.yaml (new file, 126 lines)
@@ -0,0 +1,126 @@
# global system:
fed_algo: "FedAvg"        # federated learning algorithm
model_name: "yolo_v11_n"  # yolo_v11_n, yolo_v11_t, yolo_v11_s, yolo_v11_m, yolo_v11_l, yolo_v11_x
i_seed: 202509            # initial random seed

num_client: 64            # total number of clients
num_round: 5              # total number of communication rounds
num_local_class: 80       # number of classes per client

res_root: "results"       # root directory for results
dataset_path: "/home/image1325/ssd1/dataset/COCO128/"
# train_txt: "train.txt"  # path to training set txt file
# val_txt: "val.txt"      # path to validation set txt file
# test_txt: "test.txt"    # path to test set txt file

local_batch_size: 32      # local training batch size
val_batch_size: 4         # validation batch size

num_workers: 4            # number of data loader workers
min_data: 128             # minimum number of images per client
max_data: 128             # maximum number of images per client
partition_mode: "overlap" # "overlap" or "disjoint"
connection_ratio: 1       # fraction of clients selected per round, e.g. 1.0 means all clients

# local training:
min_lr: 0.0001            # initial (minimum) learning rate
max_lr: 0.01              # maximum learning rate
momentum: 0.937           # SGD momentum/Adam beta1
weight_decay: 0.0005      # optimizer weight decay

warmup_epochs: 3.0        # warmup epochs
box: 7.5                  # box loss gain
cls: 0.5                  # cls loss gain
dfl: 1.5                  # dfl loss gain
hsv_h: 0.015              # image HSV-Hue augmentation (fraction)
hsv_s: 0.7                # image HSV-Saturation augmentation (fraction)
hsv_v: 0.4                # image HSV-Value augmentation (fraction)
degrees: 0.0              # image rotation (+/- deg)
translate: 0.1            # image translation (+/- fraction)
scale: 0.5                # image scale (+/- gain)
shear: 0.0                # image shear (+/- deg)
flip_ud: 0.0              # image flip up-down (probability)
flip_lr: 0.5              # image flip left-right (probability)
mosaic: 1.0               # image mosaic (probability)
mix_up: 0.0               # image mix-up (probability)
names:
  0: person
  1: bicycle
  2: car
  3: motorcycle
  4: airplane
  5: bus
  6: train
  7: truck
  8: boat
  9: traffic light
  10: fire hydrant
  11: stop sign
  12: parking meter
  13: bench
  14: bird
  15: cat
  16: dog
  17: horse
  18: sheep
  19: cow
  20: elephant
  21: bear
  22: zebra
  23: giraffe
  24: backpack
  25: umbrella
  26: handbag
  27: tie
  28: suitcase
  29: frisbee
  30: skis
  31: snowboard
  32: sports ball
  33: kite
  34: baseball bat
  35: baseball glove
  36: skateboard
  37: surfboard
  38: tennis racket
  39: bottle
  40: wine glass
  41: cup
  42: fork
  43: knife
  44: spoon
  45: bowl
  46: banana
  47: apple
  48: sandwich
  49: orange
  50: broccoli
  51: carrot
  52: hot dog
  53: pizza
  54: donut
  55: cake
  56: chair
  57: couch
  58: potted plant
  59: bed
  60: dining table
  61: toilet
  62: tv
  63: laptop
  64: mouse
  65: remote
  66: keyboard
  67: cell phone
  68: microwave
  69: oven
  70: toaster
  71: sink
  72: refrigerator
  73: book
  74: clock
  75: vase
  76: scissors
  77: teddy bear
  78: hair drier
  79: toothbrush
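
Note: fed_run.py (below) consumes this file wholesale as the `params` dict shared by the Dataset, the loss, and the model builders. A minimal loading sketch, assuming PyYAML; the variable names are illustrative:

    import yaml

    with open("config/coco_cfg.yaml", "r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)

    params = dict(cfg)                  # reused as Dataset/model/loss params
    num_classes = len(params["names"])  # 80 for this file
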
config/uav_cfg.yaml (new file, 47 lines)
@@ -0,0 +1,47 @@
# global system:
fed_algo: "FedAvg"        # federated learning algorithm
model_name: "yolo_v11_n"  # yolo_v11_n, yolo_v11_t, yolo_v11_s, yolo_v11_m, yolo_v11_l, yolo_v11_x
i_seed: 202509            # initial random seed

num_client: 100           # total number of clients
num_round: 500            # total number of communication rounds
num_local_class: 1        # number of classes per client

res_root: "results"       # root directory for results
dataset_path: "/home/image1325/ssd1/dataset/uav/"
# train_txt: "train.txt"  # path to training set txt file
# val_txt: "val.txt"      # path to validation set txt file
# test_txt: "test.txt"    # path to test set txt file

local_batch_size: 32      # local training batch size
val_batch_size: 16        # validation batch size

num_workers: 4            # number of data loader workers
min_data: 640             # minimum number of images per client
max_data: 720             # maximum number of images per client
partition_mode: "overlap" # "overlap" or "disjoint"
connection_ratio: 1       # fraction of clients selected per round, e.g. 1.0 means all clients

# local training:
min_lr: 0.0001            # initial (minimum) learning rate
max_lr: 0.01              # maximum learning rate
momentum: 0.937           # SGD momentum/Adam beta1
weight_decay: 0.0005      # optimizer weight decay

warmup_epochs: 3.0        # warmup epochs
box: 7.5                  # box loss gain
cls: 0.5                  # cls loss gain
dfl: 1.5                  # dfl loss gain
hsv_h: 0.015              # image HSV-Hue augmentation (fraction)
hsv_s: 0.7                # image HSV-Saturation augmentation (fraction)
hsv_v: 0.4                # image HSV-Value augmentation (fraction)
degrees: 0.0              # image rotation (+/- deg)
translate: 0.1            # image translation (+/- fraction)
scale: 0.5                # image scale (+/- gain)
shear: 0.0                # image shear (+/- deg)
flip_ud: 0.0              # image flip up-down (probability)
flip_lr: 0.5              # image flip left-right (probability)
mosaic: 1.0               # image mosaic (probability)
mix_up: 0.0               # image mix-up (probability)
names:
  0: uav
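
Note: the partition fields above map one-to-one onto the divide_trainset call in fed_run.py (below). A sketch of the resulting split for this config; the function and the train.txt fallback location are taken from fed_run.py, the rest is illustrative:

    from utils.fed_util import divide_trainset

    split = divide_trainset(
        trainset_path="/home/image1325/ssd1/dataset/uav/train.txt",  # <dataset_path>/train.txt
        num_local_class=1,   # num_local_class
        num_client=100,      # num_client
        min_data=640,        # min_data
        max_data=720,        # max_data
        mode="overlap",      # partition_mode
        seed=202509,         # i_seed
    )
    users = split["users"]                             # client ids
    files = split["user_data"][users[0]]["filename"]   # image paths of one client
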
fed_algo_cs/client_base.py (new file, 233 lines)
@@ -0,0 +1,233 @@
import numpy as np
import torch
from torch import nn
from torch.utils import data
from torch.amp.autocast_mode import autocast

from utils.fed_util import init_model
from utils import util
from utils.dataset import Dataset
from typing import cast


class FedYoloClient(object):
    def __init__(self, name, model_name, params):
        """
        Initialize client k for federated learning.

        Args:
            :param name: name of client k
            :param model_name: name of the model
            :param params: config dict with the hyperparameters for local training
                - local_batch_size: local training batch size on client k
                - num_workers: number of data loader workers
                - min_lr: minimum learning rate
                - max_lr: maximum learning rate
                - momentum: momentum for local training
                - weight_decay: weight decay for local training
        """
        self.params = params
        # metadata of local client k
        self.target_ip = "127.0.0.3"
        self.port = 9999
        self.name = name

        # hyperparameters of local client k
        self._batch_size = self.params["local_batch_size"]
        self._min_lr = self.params["min_lr"]
        self._max_lr = self.params["max_lr"]
        self._momentum = self.params["momentum"]
        self.num_workers = self.params["num_workers"]

        self.loss_record = []
        # train set length
        self.n_data = 0

        # local training and validation datasets
        self.train_dataset = None
        self.val_dataset = None

        # local model
        self._num_classes = len(self.params["names"])
        self._weight_decay = self.params["weight_decay"]

        self.model_name = model_name
        self.model = init_model(model_name, self._num_classes)

        model_parameters = filter(lambda p: p.requires_grad, self.model.parameters())
        self.parameter_number = sum(np.prod(p.size()) for p in model_parameters)

        # GPU
        self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    def load_trainset(self, train_dataset: list[str]):
        """
        Load the local training dataset.

        Args:
            :param train_dataset: list of training image paths
        """
        self.train_dataset = train_dataset
        self.n_data = len(self.train_dataset)

    def update(self, global_model_state_dict):
        """
        Update the local model with the global model parameters.

        Args:
            :param global_model_state_dict: state dictionary of the global model
        """
        if not hasattr(self, "model") or self.model is None:
            self.model = init_model(self.model_name, self._num_classes)

        # load the global model parameters
        self.model.load_state_dict(global_model_state_dict, strict=True)

    def train(self, args):
        """
        Train the local model.

        Args:
            :param args: command line arguments
                - local_rank: local rank for distributed training
                - world_size: world size for distributed training
                - distributed: whether to use distributed training
                - input_size: input size for the model
                - epochs: number of local epochs
        Returns:
            :return: local updated model state dict, number of local data points, training losses
        """
        if args.distributed:
            torch.cuda.set_device(device=args.local_rank)
            torch.distributed.init_process_group(backend="nccl", init_method="env://")
            # print(f"Client {self.name} - distributed training on {args.world_size} GPUs, local rank: {args.local_rank}")
            # self._device = torch.device("cuda", args.local_rank)

        if args.local_rank == 0:
            pass
            # if not os.path.exists("weights"):
            #     os.makedirs("weights")

        util.setup_seed()
        util.setup_multi_processes()

        # model (already initialized in __init__)
        self.model.to(self._device)

        # optimizer: accumulate gradients so the effective batch size is ~64
        accumulate = max(round(64 / (self._batch_size * args.world_size)), 1)
        # scale weight decay to the nominal batch size of 64; `*=` keeps the
        # configured value, which the original plain assignment discarded
        self._weight_decay *= self._batch_size * args.world_size * accumulate / 64

        optimizer = torch.optim.SGD(
            util.set_params(self.model, self._weight_decay),
            lr=self._min_lr,
            momentum=self._momentum,
            nesterov=True,
        )

        # EMA
        ema = util.EMA(self.model) if args.local_rank == 0 else None

        data_set = Dataset(
            filenames=self.train_dataset,
            input_size=args.input_size,
            params=self.params,
            augment=True,
        )

        if args.distributed:
            train_sampler = data.DistributedSampler(
                data_set, num_replicas=args.world_size, rank=args.local_rank, shuffle=True
            )
        else:
            train_sampler = None

        loader = data.DataLoader(
            data_set,
            batch_size=self._batch_size,
            shuffle=train_sampler is None,
            sampler=train_sampler,
            num_workers=self.num_workers,
            pin_memory=True,
            collate_fn=Dataset.collate_fn,
        )

        # scheduler
        num_steps = max(1, len(loader))
        scheduler = util.LinearLR(args=args, params=self.params, num_steps=num_steps)

        # DDP mode
        if args.distributed:
            self.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.model)
            self.model = nn.parallel.DistributedDataParallel(
                module=self.model,
                device_ids=[args.local_rank],
                output_device=args.local_rank,
                find_unused_parameters=False,
            )

        amp_scale = torch.amp.grad_scaler.GradScaler(enabled=True)
        criterion = util.ComputeLoss(self.model, self.params)

        optimizer.zero_grad(set_to_none=True)

        for epoch in range(args.epochs):
            self.model.train()
            # when distributed, set the epoch so each replica reshuffles
            if args.distributed and train_sampler is not None:
                train_sampler.set_epoch(epoch)

            if args.epochs - epoch == 10:
                # disable mosaic augmentation for the last 10 epochs
                ds = cast(Dataset, loader.dataset)
                ds.mosaic = False

            avg_box_loss = util.AverageMeter()
            avg_cls_loss = util.AverageMeter()
            avg_dfl_loss = util.AverageMeter()

            for i, (samples, targets) in enumerate(loader):
                global_step = i + num_steps * epoch
                scheduler.step(step=global_step, optimizer=optimizer)

                # move to the client's device (also works on CPU-only machines)
                samples = samples.to(self._device, non_blocking=True).float() / 255.0

                # forward
                with autocast("cuda", enabled=True):
                    outputs = self.model(samples)
                    box_loss, cls_loss, dfl_loss = criterion(outputs, targets)

                # meters (use the *unscaled* values)
                bs = samples.size(0)
                avg_box_loss.update(box_loss.item(), bs)
                avg_cls_loss.update(cls_loss.item(), bs)
                avg_dfl_loss.update(dfl_loss.item(), bs)

                # scale losses by batch/world size since the loss is averaged per sample/device
                box_loss = box_loss * self._batch_size * args.world_size
                cls_loss = cls_loss * self._batch_size * args.world_size
                dfl_loss = dfl_loss * self._batch_size * args.world_size

                total_loss = box_loss + cls_loss + dfl_loss

                # backward
                amp_scale.scale(total_loss).backward()

                # optimize every `accumulate` steps
                if (i + 1) % accumulate == 0:
                    amp_scale.step(optimizer)
                    amp_scale.update()
                    optimizer.zero_grad(set_to_none=True)
                    if ema:
                        ema.update(self.model)

                # torch.cuda.synchronize()

        # clean up
        if args.distributed:
            torch.distributed.destroy_process_group()
        torch.cuda.empty_cache()

        # unwrap DDP so the returned keys match the global model's state_dict
        model_to_send = self.model.module if args.distributed else self.model
        return (
            model_to_send.state_dict(),
            self.n_data,
            {"box_loss": avg_box_loss.avg, "cls_loss": avg_cls_loss.avg, "dfl_loss": avg_dfl_loss.avg},
        )
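
Note: the client is driven entirely from the outside, as in fed_run.py: the server pushes weights via update(), train() runs the local epochs, and the returned triple is what server.rec() stores. A minimal single-client sketch; `params`, `train_files`, and `global_state` are assumed to be prepared as in fed_run.py:

    from fed_algo_cs.client_base import FedYoloClient
    from utils.args import args_parser

    args = args_parser()  # provides epochs, input_size, DDP flags
    client = FedYoloClient(name="client_0", model_name="yolo_v11_n", params=params)
    client.load_trainset(train_files)    # list of image paths for this client
    client.update(global_state)          # weights from FedYoloServer.state_dict()
    state_dict, n_data, loss_dict = client.train(args)
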
fed_algo_cs/server_base.py (new file, 178 lines)
@@ -0,0 +1,178 @@
import numpy as np
import torch
from torch.utils.data import DataLoader

from utils.fed_util import init_model
from utils.dataset import Dataset
from utils import util


class FedYoloServer(object):
    def __init__(self, client_list, model_name, params):
        """
        Federated YOLO server.

        Args:
            client_list: list of connected clients
            model_name: YOLO model architecture name
            params: dict of hyperparameters (must include 'names')
        """
        # track client updates
        self.client_state = {}
        self.client_loss = {}
        self.client_n_data = {}
        self.selected_clients = []

        self._batch_size = params.get("val_batch_size", 4)
        self.client_list = client_list
        self.valset = None

        # federated bookkeeping
        self.round = 0
        # total number of data points across the selected clients
        self.n_data = 0

        # device
        gpu = 0
        self._device = torch.device(f"cuda:{gpu}" if torch.cuda.is_available() else "cpu")

        # global model
        self._num_classes = len(params["names"])
        self.model_name = model_name
        self.model = init_model(model_name, self._num_classes)
        self.params = params

    def load_valset(self, valset):
        """Load the validation dataset on the server."""
        self.valset = valset

    def state_dict(self):
        """Return the global model weights."""
        return self.model.state_dict()

    @torch.no_grad()
    def test(self, args):
        """
        Evaluate the global model on the validation set using YOLO metrics (mAP, precision, recall).

        Returns:
            dict with {"mAP": ..., "mAP50": ..., "precision": ..., "recall": ...}
        """
        if self.valset is None:
            return {}

        loader = DataLoader(
            self.valset,
            batch_size=self._batch_size,
            shuffle=False,
            num_workers=4,
            pin_memory=True,
            collate_fn=Dataset.collate_fn,
        )

        self.model.to(self._device).eval().half()

        iou_v = torch.linspace(0.5, 0.95, 10).to(self._device)  # IoU thresholds for mAP50-95
        n_iou = iou_v.numel()
        metrics = []

        for samples, targets in loader:
            samples = samples.to(self._device).half() / 255.0
            _, _, h, w = samples.shape
            scale = torch.tensor((w, h, w, h)).to(self._device)

            outputs = self.model(samples)
            outputs = util.non_max_suppression(outputs)

            for i, output in enumerate(outputs):
                idx = targets["idx"] == i
                cls = targets["cls"][idx].to(self._device)
                box = targets["box"][idx].to(self._device)

                metric = torch.zeros((output.shape[0], n_iou), dtype=torch.bool, device=self._device)

                if output.shape[0] == 0:
                    if cls.shape[0]:
                        metrics.append((metric, *torch.zeros((2, 0), device=self._device), cls.squeeze(-1)))
                    continue

                if cls.shape[0]:
                    cls_tensor = cls if isinstance(cls, torch.Tensor) else torch.tensor(cls, device=self._device)
                    if cls_tensor.dim() == 1:
                        cls_tensor = cls_tensor.unsqueeze(1)
                    box_xy = util.wh2xy(box)
                    if not isinstance(box_xy, torch.Tensor):
                        box_xy = torch.tensor(box_xy, device=self._device)
                    target = torch.cat((cls_tensor, box_xy * scale), dim=1)
                    metric = util.compute_metric(output[:, :6], target, iou_v)

                metrics.append((metric, output[:, 4], output[:, 5], cls.squeeze(-1)))

        # compute metrics
        if not metrics:
            return {"mAP": 0, "mAP50": 0, "precision": 0, "recall": 0}

        metrics = [torch.cat(x, dim=0).cpu().numpy() for x in zip(*metrics)]
        if len(metrics) and metrics[0].any():
            _, _, prec, rec, map50, mean_ap = util.compute_ap(*metrics, names=self.params["names"], plot=False)
        else:
            prec, rec, map50, mean_ap = 0, 0, 0, 0

        # back to float32 for further training
        self.model.float()

        return {"mAP": float(mean_ap), "mAP50": float(map50), "precision": float(prec), "recall": float(rec)}

    def select_clients(self, connection_ratio=1.0):
        """Select each client independently with probability `connection_ratio`."""
        self.selected_clients = []
        self.n_data = 0
        for client_id in self.client_list:
            if np.random.rand() <= connection_ratio:
                self.selected_clients.append(client_id)
                self.n_data += self.client_n_data.get(client_id, 0)

    def agg(self):
        """Aggregate client updates with FedAvg (weights proportional to local data size)."""
        if len(self.selected_clients) == 0 or self.n_data == 0:
            return self.model.state_dict(), {}, 0

        model = init_model(self.model_name, self._num_classes)
        model_state = model.state_dict()

        # only average over selected clients that actually reported an update,
        # so the weights sum to 1 even if some selected clients are missing
        received = [name for name in self.selected_clients if name in self.client_state]
        total = sum(self.client_n_data[name] for name in received)
        if total == 0:
            return self.model.state_dict(), {}, 0

        avg_loss = {}
        for i, name in enumerate(received):
            weight = self.client_n_data[name] / total
            for key in model_state.keys():
                if i == 0:
                    model_state[key] = self.client_state[name][key] * weight
                else:
                    model_state[key] += self.client_state[name][key] * weight

            # weighted average of the reported losses
            for k, v in self.client_loss[name].items():
                avg_loss[k] = avg_loss.get(k, 0.0) + v * weight

        self.model.load_state_dict(model_state, strict=True)
        self.round += 1
        return model_state, avg_loss, total

    def rec(self, name, state_dict, n_data, loss_dict):
        """
        Receive a local update from a client.

        Args:
            name: client ID
            state_dict: state dictionary of the local model
            n_data: number of data samples used in local training
            loss_dict: dict of losses from local training
        """
        self.n_data += n_data
        self.client_state[name] = {k: v.cpu() for k, v in state_dict.items()}
        self.client_n_data[name] = n_data
        self.client_loss[name] = loss_dict

    def flush(self):
        """Clear stored client updates."""
        self.n_data = 0
        self.client_state.clear()
        self.client_n_data.clear()
        self.client_loss.clear()
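
Note: agg() is plain FedAvg: every parameter becomes the data-weighted mean of the client parameters. A toy check of the weighting, independent of the classes above:

    import torch

    client_w = {"a": torch.tensor(1.0), "b": torch.tensor(4.0)}  # one-parameter "models"
    n_data = {"a": 100, "b": 300}                                # local sample counts
    total = sum(n_data.values())

    agg_w = sum(client_w[k] * (n_data[k] / total) for k in client_w)
    print(agg_w)  # 0.25 * 1.0 + 0.75 * 4.0 = tensor(3.2500)
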
fed_run.py (new file, 239 lines)
@@ -0,0 +1,239 @@
#!/usr/bin/env python3
import os
import json
import yaml
import time
import random
from tqdm import tqdm

import numpy as np
import torch
import matplotlib.pyplot as plt

from utils.dataset import Dataset
from fed_algo_cs.client_base import FedYoloClient
from fed_algo_cs.server_base import FedYoloServer
from utils.args import args_parser
from utils.fed_util import divide_trainset


def _read_list_file(txt_path: str):
    """Read one path per line; keep entries as-is (absolute or relative)."""
    if not txt_path or not os.path.exists(txt_path):
        return []
    with open(txt_path, "r", encoding="utf-8") as f:
        return [ln.strip() for ln in f if ln.strip()]


def _build_valset_if_available(cfg, params):
    """
    Try to build a validation Dataset.
    - If cfg['val_txt'] exists, use it.
    - Else if <dataset_path>/val.txt exists, use it.
    - Else return None (testing will be skipped).
    Args:
        cfg: config dict
        params: params dict for Dataset
    Returns:
        Dataset or None
    """
    input_size = int(cfg.get("input_size", 640))
    val_txt = cfg.get("val_txt", "")
    if not val_txt:
        ds_root = cfg.get("dataset_path", "")
        guess = os.path.join(ds_root, "val.txt") if ds_root else ""
        val_txt = guess if os.path.exists(guess) else ""

    val_files = _read_list_file(val_txt)
    if not val_files:
        return None

    return Dataset(
        filenames=val_files,
        input_size=input_size,
        params=params,
        augment=False,  # validation data is not augmented
    )


def _seed_everything(seed: int):
    np.random.seed(seed)
    torch.manual_seed(seed)
    random.seed(seed)


def _plot_curves(save_dir, hist):
    """
    Plot mAP50-95, mAP50, precision, recall, and (optional) summed train loss per round.
    """
    os.makedirs(save_dir, exist_ok=True)
    rounds = np.arange(1, len(hist["mAP"]) + 1)

    plt.figure()
    if hist["mAP"]:
        plt.plot(rounds, hist["mAP"], label="mAP50-95")
    if hist["mAP50"]:
        plt.plot(rounds, hist["mAP50"], label="mAP50")
    if hist["precision"]:
        plt.plot(rounds, hist["precision"], label="precision")
    if hist["recall"]:
        plt.plot(rounds, hist["recall"], label="recall")
    if hist["train_loss"]:
        plt.plot(rounds, hist["train_loss"], label="train_loss (sum of components)")
    plt.xlabel("Global Round")
    plt.ylabel("Metric")
    plt.title("Federated YOLO - Server Metrics")
    plt.legend()
    out_png = os.path.join(save_dir, "fed_yolo_curves.png")
    plt.savefig(out_png, dpi=150, bbox_inches="tight")
    print(f"[plot] saved: {out_png}")


def fed_run():
    """
    Main FL process:
    - Initialize clients & server
    - For each round: sequential local training -> record -> select -> aggregate
    - Test & flush
    - Record & save results, plot curves
    """
    args_cli = args_parser()
    with open(args_cli.config, "r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)

    # --- params / config normalization ---
    # The same `params` dict is passed to Dataset, the model, and the loss, so
    # the top-level cfg is reused directly. The 'names' entry may stay a dict
    # ({0: 'uav', ...}); the utils accept it as-is.
    params = dict(cfg)

    # seeds
    _seed_everything(int(cfg.get("i_seed", 0)))

    # --- split the clients' train data from a global train list ---
    # expect either cfg["train_txt"] or <dataset_path>/train.txt
    train_txt = cfg.get("train_txt", "")
    if not train_txt:
        ds_root = cfg.get("dataset_path", "")
        guess = os.path.join(ds_root, "train.txt") if ds_root else ""
        train_txt = guess

    if not train_txt or not os.path.exists(train_txt):
        raise FileNotFoundError(
            f"train.txt not found. Provide --config with 'train_txt' or ensure '{train_txt}' exists."
        )

    split = divide_trainset(
        trainset_path=train_txt,
        num_local_class=int(cfg.get("num_local_class", 1)),
        num_client=int(cfg.get("num_client", 64)),
        min_data=int(cfg.get("min_data", 100)),
        max_data=int(cfg.get("max_data", 100)),
        mode=str(cfg.get("partition_mode", "disjoint")),  # "overlap" or "disjoint"
        seed=int(cfg.get("i_seed", 0)),
    )

    users = split["users"]
    user_data = split["user_data"]  # mapping: id -> {"filename": [...]}

    # --- build clients ---
    model_name = cfg.get("model_name", "yolo_v11_n")
    clients = {}
    for uid in users:
        c = FedYoloClient(name=uid, model_name=model_name, params=params)
        c.load_trainset(user_data[uid]["filename"])
        clients[uid] = c

    # --- build server & optional validation set ---
    server = FedYoloServer(client_list=users, model_name=model_name, params=params)
    valset = _build_valset_if_available(cfg, params)
    # valset is a Dataset instance, not a data loader
    if valset is not None:
        server.load_valset(valset)

    # --- push initial global weights ---
    global_state = server.state_dict()

    # --- args object for client.train() ---
    # args_train = _make_args_for_client(cfg, args_cli)

    # --- history recorder ---
    history = {
        "mAP": [],
        "mAP50": [],
        "precision": [],
        "recall": [],
        "train_loss": [],  # scalar sum of the client-weighted loss components
        "round_time_sec": [],
    }

    # --- main FL loop ---
    num_round = int(cfg.get("num_round", 50))
    connection_ratio = float(cfg.get("connection_ratio", 1.0))  # e.g., 1.0 = all clients
    res_root = cfg.get("res_root", "results")
    os.makedirs(res_root, exist_ok=True)

    for rnd in tqdm(range(num_round), desc="federated rounds"):
        t0 = time.time()

        # local training (sequential over all users)
        for uid in tqdm(users, desc=f"Round {rnd + 1} local training", leave=False):
            client = clients[uid]  # FedYoloClient instance
            client.update(global_state)  # load global weights
            state_dict, n_data, loss_dict = client.train(args_cli)  # local training
            server.rec(uid, state_dict, n_data, loss_dict)

        # select a fraction of clients for aggregation (FedAvg subset if desired)
        server.select_clients(connection_ratio=connection_ratio)

        # aggregate
        global_state, avg_loss_dict, _ = server.agg()

        # compute a scalar train loss for plotting (sum of components)
        scalar_train_loss = float(sum(avg_loss_dict.values())) if avg_loss_dict else 0.0

        # test (if a valset was provided)
        test_metrics = server.test(args_cli) if server.valset is not None else {}
        mAP = float(test_metrics.get("mAP", 0.0))
        mAP50 = float(test_metrics.get("mAP50", 0.0))
        precision = float(test_metrics.get("precision", 0.0))
        recall = float(test_metrics.get("recall", 0.0))

        # flush per-round client caches
        server.flush()

        # record & log
        history["mAP"].append(mAP)
        history["mAP50"].append(mAP50)
        history["precision"].append(precision)
        history["recall"].append(recall)
        history["train_loss"].append(scalar_train_loss)
        history["round_time_sec"].append(time.time() - t0)

        print(
            f"[round {rnd + 1:04d}] "
            f"loss={scalar_train_loss:.4f} mAP50-95={mAP:.4f} mAP50={mAP50:.4f} "
            f"P={precision:.4f} R={recall:.4f}"
        )

        # save running JSON (resumable logs)
        save_name = (
            f"[{cfg.get('fed_algo', 'FedAvg')},{cfg.get('model_name', 'yolo')},"
            f"{cfg.get('num_local_epoch', cfg.get('client', {}).get('num_local_epoch', 1))},"
            f"{cfg.get('num_local_class', 2)},"
            f"{cfg.get('i_seed', 0)}]"
        )
        out_json = os.path.join(res_root, save_name + ".json")
        with open(out_json, "w", encoding="utf-8") as f:
            json.dump(history, f, indent=2)

    # --- final plot ---
    _plot_curves(res_root, history)
    print("[done] training complete.")


if __name__ == "__main__":
    fed_run()
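
Note: each round appends one value per key, so the JSON written above can be re-plotted or inspected at any time. A sketch of the file after two rounds (all numbers illustrative):

    {
      "mAP": [0.031, 0.058],
      "mAP50": [0.072, 0.121],
      "precision": [0.41, 0.46],
      "recall": [0.09, 0.14],
      "train_loss": [9.83, 8.97],
      "round_time_sec": [512.4, 508.9]
    }
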
nets/nn.py (new file, 362 lines)
@@ -0,0 +1,362 @@
"""
This file contains the model definition of YOLOv11.
"""

import math

import torch

from utils.util import make_anchors


def fuse_conv(conv, norm):
    # fold a BatchNorm layer into the preceding convolution (inference-time speedup)
    fused_conv = (
        torch.nn.Conv2d(
            conv.in_channels,
            conv.out_channels,
            kernel_size=conv.kernel_size,
            stride=conv.stride,
            padding=conv.padding,
            groups=conv.groups,
            bias=True,
        )
        .requires_grad_(False)
        .to(conv.weight.device)
    )

    w_conv = conv.weight.clone().view(conv.out_channels, -1)
    w_norm = torch.diag(norm.weight.div(torch.sqrt(norm.eps + norm.running_var)))
    fused_conv.weight.copy_(torch.mm(w_norm, w_conv).view(fused_conv.weight.size()))

    b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
    b_norm = norm.bias - norm.weight.mul(norm.running_mean).div(torch.sqrt(norm.running_var + norm.eps))
    if fused_conv.bias is not None:
        fused_conv.bias.copy_(torch.mm(w_norm, b_conv.reshape(-1, 1)).reshape(-1) + b_norm)

    return fused_conv


class Conv(torch.nn.Module):
    def __init__(self, in_ch, out_ch, activation, k=1, s=1, p=0, g=1):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_ch, out_ch, k, s, p, groups=g, bias=False)
        self.norm = torch.nn.BatchNorm2d(out_ch, eps=0.001, momentum=0.03)
        self.relu = activation

    def forward(self, x):
        return self.relu(self.norm(self.conv(x)))

    def fuse_forward(self, x):
        return self.relu(self.conv(x))


class Residual(torch.nn.Module):
    def __init__(self, ch, e=0.5):
        super().__init__()
        self.conv1 = Conv(ch, int(ch * e), torch.nn.SiLU(), k=3, p=1)
        self.conv2 = Conv(int(ch * e), ch, torch.nn.SiLU(), k=3, p=1)

    def forward(self, x):
        return x + self.conv2(self.conv1(x))


class CSPModule(torch.nn.Module):
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.conv1 = Conv(in_ch, out_ch // 2, torch.nn.SiLU())
        self.conv2 = Conv(in_ch, out_ch // 2, torch.nn.SiLU())
        self.conv3 = Conv(2 * (out_ch // 2), out_ch, torch.nn.SiLU())
        self.res_m = torch.nn.Sequential(Residual(out_ch // 2, e=1.0), Residual(out_ch // 2, e=1.0))

    def forward(self, x):
        y = self.res_m(self.conv1(x))
        return self.conv3(torch.cat((y, self.conv2(x)), dim=1))


class CSP(torch.nn.Module):
    def __init__(self, in_ch, out_ch, n, csp, r):
        super().__init__()
        self.conv1 = Conv(in_ch, 2 * (out_ch // r), torch.nn.SiLU())
        self.conv2 = Conv((2 + n) * (out_ch // r), out_ch, torch.nn.SiLU())

        if not csp:
            self.res_m = torch.nn.ModuleList(Residual(out_ch // r) for _ in range(n))
        else:
            self.res_m = torch.nn.ModuleList(CSPModule(out_ch // r, out_ch // r) for _ in range(n))

    def forward(self, x):
        y = list(self.conv1(x).chunk(2, 1))
        y.extend(m(y[-1]) for m in self.res_m)
        return self.conv2(torch.cat(y, dim=1))


class SPP(torch.nn.Module):
    def __init__(self, in_ch, out_ch, k=5):
        super().__init__()
        self.conv1 = Conv(in_ch, in_ch // 2, torch.nn.SiLU())
        self.conv2 = Conv(in_ch * 2, out_ch, torch.nn.SiLU())
        self.res_m = torch.nn.MaxPool2d(k, stride=1, padding=k // 2)

    def forward(self, x):
        x = self.conv1(x)
        y1 = self.res_m(x)
        y2 = self.res_m(y1)
        return self.conv2(torch.cat(tensors=[x, y1, y2, self.res_m(y2)], dim=1))


class Attention(torch.nn.Module):
    def __init__(self, ch, num_head):
        super().__init__()
        self.num_head = num_head
        self.dim_head = ch // num_head
        self.dim_key = self.dim_head // 2
        self.scale = self.dim_key**-0.5

        self.qkv = Conv(ch, ch + self.dim_key * num_head * 2, torch.nn.Identity())

        self.conv1 = Conv(ch, ch, torch.nn.Identity(), k=3, p=1, g=ch)
        self.conv2 = Conv(ch, ch, torch.nn.Identity())

    def forward(self, x):
        b, c, h, w = x.shape

        qkv = self.qkv(x)
        qkv = qkv.view(b, self.num_head, self.dim_key * 2 + self.dim_head, h * w)

        q, k, v = qkv.split([self.dim_key, self.dim_key, self.dim_head], dim=2)

        attn = (q.transpose(-2, -1) @ k) * self.scale
        attn = attn.softmax(dim=-1)

        x = (v @ attn.transpose(-2, -1)).view(b, c, h, w) + self.conv1(v.reshape(b, c, h, w))
        return self.conv2(x)


class PSABlock(torch.nn.Module):
    def __init__(self, ch, num_head):
        super().__init__()
        self.conv1 = Attention(ch, num_head)
        self.conv2 = torch.nn.Sequential(Conv(ch, ch * 2, torch.nn.SiLU()), Conv(ch * 2, ch, torch.nn.Identity()))

    def forward(self, x):
        x = x + self.conv1(x)
        return x + self.conv2(x)


class PSA(torch.nn.Module):
    def __init__(self, ch, n):
        super().__init__()
        self.conv1 = Conv(ch, 2 * (ch // 2), torch.nn.SiLU())
        self.conv2 = Conv(2 * (ch // 2), ch, torch.nn.SiLU())
        self.res_m = torch.nn.Sequential(*(PSABlock(ch // 2, ch // 128) for _ in range(n)))

    def forward(self, x):
        x, y = self.conv1(x).chunk(2, 1)
        return self.conv2(torch.cat(tensors=(x, self.res_m(y)), dim=1))


class DarkNet(torch.nn.Module):
    def __init__(self, width, depth, csp):
        super().__init__()
        self.p1 = []
        self.p2 = []
        self.p3 = []
        self.p4 = []
        self.p5 = []

        # p1/2
        self.p1.append(Conv(width[0], width[1], torch.nn.SiLU(), k=3, s=2, p=1))
        # p2/4
        self.p2.append(Conv(width[1], width[2], torch.nn.SiLU(), k=3, s=2, p=1))
        self.p2.append(CSP(width[2], width[3], depth[0], csp[0], r=4))
        # p3/8
        self.p3.append(Conv(width[3], width[3], torch.nn.SiLU(), k=3, s=2, p=1))
        self.p3.append(CSP(width[3], width[4], depth[1], csp[0], r=4))
        # p4/16
        self.p4.append(Conv(width[4], width[4], torch.nn.SiLU(), k=3, s=2, p=1))
        self.p4.append(CSP(width[4], width[4], depth[2], csp[1], r=2))
        # p5/32
        self.p5.append(Conv(width[4], width[5], torch.nn.SiLU(), k=3, s=2, p=1))
        self.p5.append(CSP(width[5], width[5], depth[3], csp[1], r=2))
        self.p5.append(SPP(width[5], width[5]))
        self.p5.append(PSA(width[5], depth[4]))

        self.p1 = torch.nn.Sequential(*self.p1)
        self.p2 = torch.nn.Sequential(*self.p2)
        self.p3 = torch.nn.Sequential(*self.p3)
        self.p4 = torch.nn.Sequential(*self.p4)
        self.p5 = torch.nn.Sequential(*self.p5)

    def forward(self, x):
        p1 = self.p1(x)
        p2 = self.p2(p1)
        p3 = self.p3(p2)
        p4 = self.p4(p3)
        p5 = self.p5(p4)
        return p3, p4, p5


class DarkFPN(torch.nn.Module):
    def __init__(self, width, depth, csp):
        super().__init__()
        self.up = torch.nn.Upsample(scale_factor=2)
        self.h1 = CSP(width[4] + width[5], width[4], depth[5], csp[0], r=2)
        self.h2 = CSP(width[4] + width[4], width[3], depth[5], csp[0], r=2)
        self.h3 = Conv(width[3], width[3], torch.nn.SiLU(), k=3, s=2, p=1)
        self.h4 = CSP(width[3] + width[4], width[4], depth[5], csp[0], r=2)
        self.h5 = Conv(width[4], width[4], torch.nn.SiLU(), k=3, s=2, p=1)
        self.h6 = CSP(width[4] + width[5], width[5], depth[5], csp[1], r=2)

    def forward(self, x):
        p3, p4, p5 = x
        p4 = self.h1(torch.cat(tensors=[self.up(p5), p4], dim=1))
        p3 = self.h2(torch.cat(tensors=[self.up(p4), p3], dim=1))
        p4 = self.h4(torch.cat(tensors=[self.h3(p3), p4], dim=1))
        p5 = self.h6(torch.cat(tensors=[self.h5(p4), p5], dim=1))
        return p3, p4, p5


class DFL(torch.nn.Module):
    # Distribution Focal Loss (DFL) head, from the Generalized Focal Loss paper
    # https://ieeexplore.ieee.org/document/9792391
    def __init__(self, ch=16):
        super().__init__()
        self.ch = ch
        self.conv = torch.nn.Conv2d(ch, out_channels=1, kernel_size=1, bias=False).requires_grad_(False)
        x = torch.arange(ch, dtype=torch.float).view(1, ch, 1, 1)
        self.conv.weight.data[:] = x  # frozen weights 0..ch-1 compute the expectation

    def forward(self, x):
        b, c, a = x.shape
        x = x.view(b, 4, self.ch, a).transpose(2, 1)
        return self.conv(x.softmax(1)).view(b, 4, a)
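
# A hedged numeric note on the DFL decoding above: each box side is predicted as
# a softmax distribution over ch=16 bins, and the frozen 1x1 conv with weights
# 0..15 computes its expectation. A distribution peaked equally on bins 3 and 4
# therefore decodes to ~3.5, i.e. the head regresses continuous distances
# (in units of the feature stride) rather than a single bin index.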


class Head(torch.nn.Module):
    anchors = torch.empty(0)
    strides = torch.empty(0)

    def __init__(self, nc=80, filters=()):
        super().__init__()
        self.ch = 16  # DFL channels
        self.nc = nc  # number of classes
        self.nl = len(filters)  # number of detection layers
        self.no = nc + self.ch * 4  # number of outputs per anchor
        self.stride = torch.zeros(self.nl)  # strides computed during build

        box = max(64, filters[0] // 4)
        cls = max(80, filters[0], self.nc)

        self.dfl = DFL(self.ch)
        self.box = torch.nn.ModuleList(
            torch.nn.Sequential(
                Conv(x, box, torch.nn.SiLU(), k=3, p=1),
                Conv(box, box, torch.nn.SiLU(), k=3, p=1),
                torch.nn.Conv2d(box, out_channels=4 * self.ch, kernel_size=1),
            )
            for x in filters
        )
        self.cls = torch.nn.ModuleList(
            torch.nn.Sequential(
                Conv(x, x, torch.nn.SiLU(), k=3, p=1, g=x),
                Conv(x, cls, torch.nn.SiLU()),
                Conv(cls, cls, torch.nn.SiLU(), k=3, p=1, g=cls),
                Conv(cls, cls, torch.nn.SiLU()),
                torch.nn.Conv2d(cls, out_channels=self.nc, kernel_size=1),
            )
            for x in filters
        )

    def forward(self, x):
        for i, (box, cls) in enumerate(zip(self.box, self.cls)):
            x[i] = torch.cat(tensors=(box(x[i]), cls(x[i])), dim=1)
        if self.training:
            return x

        self.anchors, self.strides = (i.transpose(0, 1) for i in make_anchors(x, self.stride))
        x = torch.cat([i.view(x[0].shape[0], self.no, -1) for i in x], dim=2)
        box, cls = x.split(split_size=(4 * self.ch, self.nc), dim=1)

        a, b = self.dfl(box).chunk(2, 1)
        a = self.anchors.unsqueeze(0) - a
        b = self.anchors.unsqueeze(0) + b
        box = torch.cat(tensors=((a + b) / 2, b - a), dim=1)

        return torch.cat(tensors=(box * self.strides, cls.sigmoid()), dim=1)

    def initialize_biases(self):
        # initialize biases
        # WARNING: requires stride availability
        for box, cls, s in zip(self.box, self.cls, self.stride):
            # box
            box[-1].bias.data[:] = 1.0
            # cls (0.01 objects, 80 classes, 640 image)
            cls[-1].bias.data[: self.nc] = math.log(5 / self.nc / (640 / s) ** 2)


class YOLO(torch.nn.Module):
    def __init__(self, width, depth, csp, num_classes):
        super().__init__()
        self.net = DarkNet(width, depth, csp)
        self.fpn = DarkFPN(width, depth, csp)

        img_dummy = torch.zeros(1, width[0], 256, 256)
        self.head = Head(num_classes, (width[3], width[4], width[5]))
        self.head.stride = torch.tensor([256 / x.shape[-2] for x in self.forward(img_dummy)])
        self.stride = self.head.stride
        self.head.initialize_biases()

    def forward(self, x):
        x = self.net(x)
        x = self.fpn(x)
        return self.head(list(x))

    def fuse(self):
        for m in self.modules():
            if type(m) is Conv and hasattr(m, "norm"):
                m.conv = fuse_conv(m.conv, m.norm)
                m.forward = m.fuse_forward
                delattr(m, "norm")
        return self


def yolo_v11_n(num_classes: int = 80):
    csp = [False, True]
    depth = [1, 1, 1, 1, 1, 1]
    width = [3, 16, 32, 64, 128, 256]
    return YOLO(width, depth, csp, num_classes)


def yolo_v11_t(num_classes: int = 80):
    csp = [False, True]
    depth = [1, 1, 1, 1, 1, 1]
    width = [3, 24, 48, 96, 192, 384]
    return YOLO(width, depth, csp, num_classes)


def yolo_v11_s(num_classes: int = 80):
    csp = [False, True]
    depth = [1, 1, 1, 1, 1, 1]
    width = [3, 32, 64, 128, 256, 512]
    return YOLO(width, depth, csp, num_classes)


def yolo_v11_m(num_classes: int = 80):
    csp = [True, True]
    depth = [1, 1, 1, 1, 1, 1]
    width = [3, 64, 128, 256, 512, 512]
    return YOLO(width, depth, csp, num_classes)


def yolo_v11_l(num_classes: int = 80):
    csp = [True, True]
    depth = [2, 2, 2, 2, 2, 2]
    width = [3, 64, 128, 256, 512, 512]
    return YOLO(width, depth, csp, num_classes)


def yolo_v11_x(num_classes: int = 80):
    csp = [True, True]
    depth = [2, 2, 2, 2, 2, 2]
    width = [3, 96, 192, 384, 768, 768]
    return YOLO(width, depth, csp, num_classes)
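
Note: the constructors take only num_classes; in train mode Head returns the three raw feature maps, in eval mode a single decoded tensor. A hedged shape check (expected values assume the standard stride-8/16/32 anchor layout):

    import torch
    from nets.nn import yolo_v11_n

    model = yolo_v11_n(num_classes=80)
    x = torch.zeros(1, 3, 640, 640)

    model.train()
    p3, p4, p5 = model(x)  # raw maps at strides 8/16/32, each with 4*16 + 80 channels

    model.eval()
    with torch.no_grad():
        y = model(x)       # expected shape (1, 4 + 80, 8400) at 640x640
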
utils/args.py (new file, 18 lines)
@@ -0,0 +1,18 @@
import argparse
import os


def args_parser():
    parser = argparse.ArgumentParser()

    parser.add_argument("--epochs", type=int, default=10, help="number of local training epochs per round")
    parser.add_argument("--input_size", type=int, default=640, help="image input size")
    parser.add_argument("--config", type=str, default="./config/uav_cfg.yaml", help="path to YAML config")

    args = parser.parse_args()

    # DDP flags come from the environment (set by the launcher), not the CLI
    args.local_rank = int(os.getenv("LOCAL_RANK", 0))
    args.world_size = int(os.getenv("WORLD_SIZE", 1))
    args.distributed = args.world_size > 1

    return args
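
Note: since the DDP switches come from the environment rather than argparse, a plain `python fed_run.py` run stays single-process; a torchrun-style launcher that sets LOCAL_RANK/WORLD_SIZE per worker is the assumed multi-GPU path. A quick check:

    import os
    from utils.args import args_parser

    # without launcher-provided env vars this parses as a single-process run
    os.environ.setdefault("WORLD_SIZE", "1")
    args = args_parser()
    assert args.world_size == 1 and args.distributed is False
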
utils/dataset.py (new file, 478 lines)
@@ -0,0 +1,478 @@
import math
import os
import random

import cv2
import numpy
import torch
from PIL import Image
from torch.utils import data

FORMATS = "bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "JPEG", "JPG", "PNG", "TIFF"


class Dataset(data.Dataset):
    params: dict
    mosaic: bool
    augment: bool
    input_size: int

    def __init__(self, filenames, input_size: int, params: dict, augment: bool):
        self.params = params
        self.mosaic = augment
        self.augment = augment
        self.input_size = input_size

        # read labels
        labels = self.load_label(filenames)
        self.labels = list(labels.values())
        self.filenames = list(labels.keys())  # update (files that fail verification are dropped)
        self.n = len(self.filenames)  # number of samples
        self.indices = range(self.n)
        # Albumentations (optional, only used if the package is installed)
        self.albumentations = Albumentations()

    def __getitem__(self, index):
        index = self.indices[index]

        if self.mosaic and random.random() < self.params["mosaic"]:
            # load mosaic
            image, label = self.load_mosaic(index, self.params)
            # MixUp augmentation
            if random.random() < self.params["mix_up"]:
                index = random.choice(self.indices)
                mix_image1, mix_label1 = image, label
                mix_image2, mix_label2 = self.load_mosaic(index, self.params)

                image, label = mix_up(mix_image1, mix_label1, mix_image2, mix_label2)
        else:
            # load image
            image, shape = self.load_image(index)
            if image is None:
                raise ValueError(f"Failed to load image at index {index}: {self.filenames[index]}")
            h, w = image.shape[:2]

            # resize
            image, ratio, pad = resize(image, self.input_size, self.augment)

            label = self.labels[index].copy()
            if label.size:
                label[:, 1:] = wh2xy(label[:, 1:], ratio[0] * w, ratio[1] * h, int(pad[0]), int(pad[1]))
            if self.augment:
                image, label = random_perspective(image, label, self.params)

        nl = len(label)  # number of labels
        h, w = image.shape[:2]
        cls = label[:, 0:1]
        box = label[:, 1:5]
        box = xy2wh(box, w, h)

        if self.augment:
            # Albumentations
            image, box, cls = self.albumentations(image, box, cls)
            nl = len(box)  # update after albumentations
            # HSV color-space
            augment_hsv(image, self.params)
            # flip up-down
            if random.random() < self.params["flip_ud"]:
                image = numpy.flipud(image)
                if nl:
                    box[:, 1] = 1 - box[:, 1]
            # flip left-right
            if random.random() < self.params["flip_lr"]:
                image = numpy.fliplr(image)
                if nl:
                    box[:, 0] = 1 - box[:, 0]

        # keep (N, 1) and (N, 4) shapes even when empty; e.g. a bare shape like
        # (0,) is illegal in the DataLoader collate_fn cat operation
        if nl:
            target_cls = torch.from_numpy(cls).view(-1, 1).float()  # always (N, 1)
            target_box = torch.from_numpy(box).reshape(-1, 4).float()  # always (N, 4)
        else:
            target_cls = torch.zeros((0, 1), dtype=torch.float32)
            target_box = torch.zeros((0, 4), dtype=torch.float32)

        # convert HWC to CHW, BGR to RGB
        sample = image.transpose((2, 0, 1))[::-1]
        sample = numpy.ascontiguousarray(sample)

        return torch.from_numpy(sample), target_cls, target_box, torch.zeros((nl, 1), dtype=torch.long)

    def __len__(self):
        return len(self.filenames)

    def load_image(self, i):
        image = cv2.imread(self.filenames[i])
        if image is None:
            raise ValueError(f"Image not found or unable to open: {self.filenames[i]}")
        h, w = image.shape[:2]
        r = self.input_size / max(h, w)
        if r != 1:
            image = cv2.resize(
                image, dsize=(int(w * r), int(h * r)), interpolation=resample() if self.augment else cv2.INTER_LINEAR
            )
        return image, (h, w)

    def load_mosaic(self, index, params):
        label4 = []
        border = [-self.input_size // 2, -self.input_size // 2]
        image4 = numpy.full((self.input_size * 2, self.input_size * 2, 3), 0, dtype=numpy.uint8)
        y1a, y2a, x1a, x2a, y1b, y2b, x1b, x2b = (None, None, None, None, None, None, None, None)

        # random mosaic center
        xc = int(random.uniform(-border[0], 2 * self.input_size + border[1]))
        yc = int(random.uniform(-border[0], 2 * self.input_size + border[1]))

        indices = [index] + random.choices(self.indices, k=3)
        random.shuffle(indices)

        for i, index in enumerate(indices):
            # load image
            image, _ = self.load_image(index)
            shape = image.shape
            if i == 0:  # top left
                x1a = max(xc - shape[1], 0)
                y1a = max(yc - shape[0], 0)
                x2a = xc
                y2a = yc
                x1b = shape[1] - (x2a - x1a)
                y1b = shape[0] - (y2a - y1a)
                x2b = shape[1]
                y2b = shape[0]
            if i == 1:  # top right
                x1a = xc
                y1a = max(yc - shape[0], 0)
                x2a = min(xc + shape[1], self.input_size * 2)
                y2a = yc
                x1b = 0
                y1b = shape[0] - (y2a - y1a)
                x2b = min(shape[1], x2a - x1a)
                y2b = shape[0]
            if i == 2:  # bottom left
                x1a = max(xc - shape[1], 0)
                y1a = yc
                x2a = xc
                y2a = min(self.input_size * 2, yc + shape[0])
                x1b = shape[1] - (x2a - x1a)
                y1b = 0
                x2b = shape[1]
                y2b = min(y2a - y1a, shape[0])
            if i == 3:  # bottom right
                x1a = xc
                y1a = yc
                x2a = min(xc + shape[1], self.input_size * 2)
                y2a = min(self.input_size * 2, yc + shape[0])
                x1b = 0
                y1b = 0
                x2b = min(shape[1], x2a - x1a)
                y2b = min(y2a - y1a, shape[0])

            pad_w = (x1a if x1a is not None else 0) - (x1b if x1b is not None else 0)
            pad_h = (y1a if y1a is not None else 0) - (y1b if y1b is not None else 0)
            image4[y1a:y2a, x1a:x2a] = image[y1b:y2b, x1b:x2b]

            # labels
            label = self.labels[index].copy()
            if len(label):
                label[:, 1:] = wh2xy(label[:, 1:], shape[1], shape[0], pad_w, pad_h)
            label4.append(label)

        # concat/clip labels
        label4 = numpy.concatenate(label4, 0)
        for x in label4[:, 1:]:
            numpy.clip(x, 0, 2 * self.input_size, out=x)

        # augment
        image4, label4 = random_perspective(image4, label4, params, border)

        return image4, label4

    @staticmethod
    def collate_fn(batch):
        samples, cls, box, indices = zip(*batch)

        # ensure empty tensors keep the right shape
        cls = [c.view(-1, 1) for c in cls]
        box = [b.reshape(-1, 4) for b in box]

        # assign the batch image index per sample *before* concatenating, so
        # every object row points back to its source image (the original code
        # concatenated first and then numbered individual rows)
        indices = [idx + i for i, idx in enumerate(indices)]

        cls = torch.cat(cls, dim=0) if cls else torch.zeros((0, 1))
        box = torch.cat(box, dim=0) if box else torch.zeros((0, 4))
        indices = torch.cat(indices, dim=0).view(-1) if indices else torch.zeros((0,), dtype=torch.long)

        targets = {"cls": cls, "box": box, "idx": indices}
        return torch.stack(samples, dim=0), targets

    @staticmethod
    def load_label_use_cache(filenames):
        path = f"{os.path.dirname(filenames[0])}.cache"
        if os.path.exists(path):
            return torch.load(path, weights_only=False)
        x = {}
        for filename in filenames:
            try:
                # verify images
                with open(filename, "rb") as f:
                    image = Image.open(f)
                    image.verify()  # PIL verify
                    shape = image.size  # image size
                    assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
                    assert image.format is not None and image.format.lower() in FORMATS, (
                        f"invalid image format {image.format}"
                    )

                # verify labels: map .../images/xxx.jpg -> .../labels/xxx.txt
                a = f"{os.sep}images{os.sep}"
                b = f"{os.sep}labels{os.sep}"
                label_path = b.join(filename.rsplit(a, 1)).rsplit(".", 1)[0] + ".txt"

                if os.path.isfile(label_path):
                    with open(label_path) as f:
                        label = [line.split() for line in f.read().strip().splitlines() if len(line)]
                        label = numpy.array(label, dtype=numpy.float32)
                    nl = len(label)
                    if nl:
                        assert (label >= 0).all()
                        assert label.shape[1] == 5
                        assert (label[:, 1:] <= 1).all()
                        _, i = numpy.unique(label, axis=0, return_index=True)
                        if len(i) < nl:  # duplicate row check
                            label = label[i]  # remove duplicates
                    else:
                        label = numpy.zeros((0, 5), dtype=numpy.float32)
                else:
                    label = numpy.zeros((0, 5), dtype=numpy.float32)
            except FileNotFoundError:
                label = numpy.zeros((0, 5), dtype=numpy.float32)
            except AssertionError:
                continue
            x[filename] = label
        torch.save(x, path)
        return x

    @staticmethod
    def load_label(filenames):
        x = {}
        for filename in filenames:
            try:
                # verify images
                with open(filename, "rb") as f:
                    image = Image.open(f)
                    image.verify()
                    shape = image.size
                    assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
                    assert image.format is not None and image.format.lower() in FORMATS, (
                        f"invalid image format {image.format}"
                    )

                # verify labels
                a = f"{os.sep}images{os.sep}"
                b = f"{os.sep}labels{os.sep}"
                label_path = b.join(filename.rsplit(a, 1)).rsplit(".", 1)[0] + ".txt"

                if os.path.isfile(label_path):
                    rows = []
                    with open(label_path) as f:
                        for line in f:
                            parts = line.strip().split()
                            if len(parts) == 5:  # YOLO format: cls cx cy w h
                                rows.append([float(v) for v in parts])
                    label = numpy.array(rows, dtype=numpy.float32) if rows else numpy.zeros((0, 5), dtype=numpy.float32)

                    if label.shape[0]:
                        assert (label >= 0).all()
                        assert label.shape[1] == 5
                        assert (label[:, 1:] <= 1.0001).all()
                        _, i = numpy.unique(label, axis=0, return_index=True)
                        label = label[i]
                else:
                    label = numpy.zeros((0, 5), dtype=numpy.float32)

            except (FileNotFoundError, AssertionError):
                label = numpy.zeros((0, 5), dtype=numpy.float32)

            x[filename] = label
        return x


def wh2xy(x, w=640, h=640, pad_w=0, pad_h=0):
    # Convert nx4 boxes from normalized [x, y, w, h] to pixel [x1, y1, x2, y2],
    # where xy1 = top-left and xy2 = bottom-right
    y = numpy.copy(x)
    y[:, 0] = w * (x[:, 0] - x[:, 2] / 2) + pad_w  # top left x
    y[:, 1] = h * (x[:, 1] - x[:, 3] / 2) + pad_h  # top left y
    y[:, 2] = w * (x[:, 0] + x[:, 2] / 2) + pad_w  # bottom right x
    y[:, 3] = h * (x[:, 1] + x[:, 3] / 2) + pad_h  # bottom right y
    return y


def xy2wh(x, w, h):
    # warning: in-place clip
    x[:, [0, 2]] = x[:, [0, 2]].clip(0, w - 1e-3)  # x1, x2
    x[:, [1, 3]] = x[:, [1, 3]].clip(0, h - 1e-3)  # y1, y2

    # Convert nx4 boxes from pixel [x1, y1, x2, y2] to normalized [x, y, w, h]
    y = numpy.copy(x)
    y[:, 0] = ((x[:, 0] + x[:, 2]) / 2) / w  # x center
    y[:, 1] = ((x[:, 1] + x[:, 3]) / 2) / h  # y center
    y[:, 2] = (x[:, 2] - x[:, 0]) / w  # width
    y[:, 3] = (x[:, 3] - x[:, 1]) / h  # height
    return y
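
# A hedged round-trip example for the two converters above:
#   normalized YOLO box [cx=0.5, cy=0.5, w=0.2, h=0.4] on a 640x640 image
#   wh2xy -> pixel corners [256.0, 192.0, 384.0, 448.0]
#   xy2wh -> back to [0.5, 0.5, 0.2, 0.4]
# wh2xy is applied before the geometric augmentations (pixel space), xy2wh after.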


def resample():
    choices = (cv2.INTER_AREA, cv2.INTER_CUBIC, cv2.INTER_LINEAR, cv2.INTER_NEAREST, cv2.INTER_LANCZOS4)
    return random.choice(seq=choices)


def augment_hsv(image, params):
    # HSV color-space augmentation
    h = params["hsv_h"]
    s = params["hsv_s"]
    v = params["hsv_v"]

    r = numpy.random.uniform(-1, 1, 3) * [h, s, v] + 1
    h, s, v = cv2.split(cv2.cvtColor(image, cv2.COLOR_BGR2HSV))

    x = numpy.arange(0, 256, dtype=r.dtype)
    lut_h = ((x * r[0]) % 180).astype("uint8")
    lut_s = numpy.clip(x * r[1], 0, 255).astype("uint8")
    lut_v = numpy.clip(x * r[2], 0, 255).astype("uint8")

    hsv = cv2.merge((cv2.LUT(h, lut_h), cv2.LUT(s, lut_s), cv2.LUT(v, lut_v)))
    cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR, dst=image)  # in-place, no return needed


def resize(image, input_size, augment):
    # resize and pad the image while meeting stride-multiple constraints
    shape = image.shape[:2]  # current shape [height, width]

    # scale ratio (new / old)
    r = min(input_size / shape[0], input_size / shape[1])
    if not augment:  # only scale down, do not scale up (for better val mAP)
        r = min(r, 1.0)

    # compute padding
    pad = int(round(shape[1] * r)), int(round(shape[0] * r))
    w = (input_size - pad[0]) / 2
    h = (input_size - pad[1]) / 2

    if shape[::-1] != pad:  # resize
        image = cv2.resize(image, dsize=pad, interpolation=resample() if augment else cv2.INTER_LINEAR)
    top, bottom = int(round(h - 0.1)), int(round(h + 0.1))
    left, right = int(round(w - 0.1)), int(round(w + 0.1))
    image = cv2.copyMakeBorder(image, top, bottom, left, right, cv2.BORDER_CONSTANT)  # add border
    return image, (r, r), (w, h)
|
||||
|
||||
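
# Illustrative letterbox result (assumed values, not part of the original file):
# a 1280x720 image with input_size=640 and augment=False gives r = 0.5, an
# intermediate 640x360 resize, and 140 px of border on top and bottom, so
# resize(...) returns (640x640 image, (0.5, 0.5), (0.0, 140.0)).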

def candidates(box1, box2):
    # box1(4,n), box2(4,n)
    w1, h1 = box1[2] - box1[0], box1[3] - box1[1]
    w2, h2 = box2[2] - box2[0], box2[3] - box2[1]
    aspect_ratio = numpy.maximum(w2 / (h2 + 1e-16), h2 / (w2 + 1e-16))  # aspect ratio
    return (w2 > 2) & (h2 > 2) & (w2 * h2 / (w1 * h1 + 1e-16) > 0.1) & (aspect_ratio < 100)

def random_perspective(image, label, params, border=(0, 0)):
    h = image.shape[0] + border[0] * 2
    w = image.shape[1] + border[1] * 2

    # Center
    center = numpy.eye(3)
    center[0, 2] = -image.shape[1] / 2  # x translation (pixels)
    center[1, 2] = -image.shape[0] / 2  # y translation (pixels)

    # Perspective
    perspective = numpy.eye(3)

    # Rotation and Scale
    rotate = numpy.eye(3)
    a = random.uniform(-params["degrees"], params["degrees"])
    s = random.uniform(1 - params["scale"], 1 + params["scale"])
    rotate[:2] = cv2.getRotationMatrix2D(angle=a, center=(0, 0), scale=s)

    # Shear
    shear = numpy.eye(3)
    shear[0, 1] = math.tan(random.uniform(-params["shear"], params["shear"]) * math.pi / 180)
    shear[1, 0] = math.tan(random.uniform(-params["shear"], params["shear"]) * math.pi / 180)

    # Translation
    translate = numpy.eye(3)
    translate[0, 2] = random.uniform(0.5 - params["translate"], 0.5 + params["translate"]) * w
    translate[1, 2] = random.uniform(0.5 - params["translate"], 0.5 + params["translate"]) * h

    # Combined rotation matrix, order of operations (right to left) is IMPORTANT
    matrix = translate @ shear @ rotate @ perspective @ center
    if (border[0] != 0) or (border[1] != 0) or (matrix != numpy.eye(3)).any():  # image changed
        image = cv2.warpAffine(image, matrix[:2], dsize=(w, h), borderValue=(0, 0, 0))

    # Transform label coordinates
    n = len(label)
    if n:
        xy = numpy.ones((n * 4, 3))
        xy[:, :2] = label[:, [1, 2, 3, 4, 1, 4, 3, 2]].reshape(n * 4, 2)  # x1y1, x2y2, x1y2, x2y1
        xy = xy @ matrix.T  # transform
        xy = xy[:, :2].reshape(n, 8)  # perspective rescale or affine

        # create new boxes
        x = xy[:, [0, 2, 4, 6]]
        y = xy[:, [1, 3, 5, 7]]
        box = numpy.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T

        # clip
        box[:, [0, 2]] = box[:, [0, 2]].clip(0, w)
        box[:, [1, 3]] = box[:, [1, 3]].clip(0, h)
        # filter candidates
        indices = candidates(box1=label[:, 1:5].T * s, box2=box.T)

        label = label[indices]
        label[:, 1:5] = box[indices]

    return image, label

def mix_up(image1, label1, image2, label2):
    # Applies MixUp augmentation https://arxiv.org/pdf/1710.09412.pdf
    alpha = numpy.random.beta(a=32.0, b=32.0)  # mix-up ratio, alpha=beta=32.0
    image = (image1 * alpha + image2 * (1 - alpha)).astype(numpy.uint8)
    label = numpy.concatenate((label1, label2), 0)
    return image, label
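
# Note: Beta(32, 32) is symmetric with mean 0.5 and small variance, so the two
# images are blended in nearly equal proportion while both label sets are kept.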

class Albumentations:
    def __init__(self):
        self.transform = None
        try:
            import albumentations

            transforms = [
                albumentations.Blur(p=0.01),
                albumentations.CLAHE(p=0.01),
                albumentations.ToGray(p=0.01),
                albumentations.MedianBlur(p=0.01),
            ]
            self.transform = albumentations.Compose(
                transforms, albumentations.BboxParams(format="yolo", label_fields=["class_labels"])
            )

        except ImportError:  # package not installed, skip
            pass

    def __call__(self, image, box, cls):
        if self.transform:
            x = self.transform(image=image, bboxes=box, class_labels=cls)
            image = x["image"]
            box = numpy.array(x["bboxes"])
            cls = numpy.array(x["class_labels"])
        return image, box, cls
250
utils/fed_util.py
Normal file
@@ -0,0 +1,250 @@
import os
import re
import random
from collections import defaultdict
from typing import Dict, List, Optional, Set, Any

from nets import nn

def _image_to_label_path(img_path: str) -> str:
    """
    Convert an image path like ".../images/train2017/xxx.jpg"
    to the corresponding label path ".../labels/train2017/xxx.txt".
    Works for POSIX/Windows separators.
    """
    # swap "/images/" (or "\images\") to "/labels/"
    label_path = re.sub(r"([/\\])images([/\\])", r"\1labels\2", img_path)
    # swap extension to .txt
    root, _ = os.path.splitext(label_path)
    return root + ".txt"
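
# Illustrative (assumed path, not from the original repo):
#   _image_to_label_path("/data/COCO/images/train2017/000000000009.jpg")
#   -> "/data/COCO/labels/train2017/000000000009.txt"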

def _parse_yolo_label_file(label_path: str) -> Set[int]:
    """
    Return the set of class ids found in a YOLO .txt label file.
    Empty file -> empty set. Missing file -> empty set.
    Robust to blank lines / trailing spaces.
    """
    class_ids: Set[int] = set()
    if not os.path.exists(label_path):
        return class_ids
    try:
        with open(label_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                # YOLO format: cls cx cy w h
                parts = line.split()
                if not parts:
                    continue
                try:
                    cls = int(parts[0])
                except ValueError:
                    # handle a non-integer class field such as '23.0'
                    try:
                        cls = int(float(parts[0]))
                    except ValueError:
                        # skip malformed line
                        continue
                class_ids.add(cls)
    except Exception:
        # If the file can't be read for some reason, treat it as having no labels
        return set()
    return class_ids
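
# Illustrative (assumed file contents, not from the original repo): a label
# file holding the lines "0 0.5 0.5 0.2 0.2" and "17 0.1 0.1 0.05 0.05"
# yields _parse_yolo_label_file(path) == {0, 17}.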

def divide_trainset(
    trainset_path: str,
    num_local_class: int,
    num_client: int,
    min_data: int,
    max_data: int,
    mode: str = "overlap",  # "overlap" or "disjoint"
    seed: Optional[int] = None,
) -> Dict[str, Any]:
    """
    Build a federated split from a YOLO dataset list file.

    Args:
        trainset_path: path to a .txt file containing one image path per line,
            e.g. /COCO/images/train2017/1111.jpg
        num_local_class: how many distinct classes to sample for each client
        num_client: number of clients
        min_data: minimum number of images per client
        max_data: maximum number of images per client
        mode: "overlap" -> images may be shared across clients
              "disjoint" -> each image is used by at most one client
        seed: optional random seed for reproducibility

    Returns:
        trainset_divided = {
            "users": ["c_00001", ...],
            "user_data": {
                "c_00001": {"filename": [img_path, ...]},
                ...
            },
            "num_samples": [len(list_for_user1), len(list_for_user2), ...]
        }

    Example:
        dataset = divide_trainset(
            trainset_path="/COCO/train2017.txt",
            num_local_class=3,
            num_client=5,
            min_data=10,
            max_data=20,
            mode="disjoint",  # or "overlap"
            seed=42,
        )

        print(dataset["users"])        # ['c_00001', ..., 'c_00005']
        print(dataset["num_samples"])  # e.g. [10, 12, 18, 9, 15]
        print(dataset["user_data"]["c_00001"]["filename"][:3])
    """
    if seed is not None:
        random.seed(seed)

    # ---- Basic validations (defensive programming) ----
    if num_client <= 0:
        raise ValueError("num_client must be > 0")
    if num_local_class <= 0:
        raise ValueError("num_local_class must be > 0")
    if min_data < 0 or max_data < 0:
        raise ValueError("min_data/max_data must be >= 0")
    if max_data < min_data:
        raise ValueError("max_data must be >= min_data")
    if mode not in {"overlap", "disjoint"}:
        raise ValueError('mode must be "overlap" or "disjoint"')

    # ---- 1) Read image list ----
    with open(trainset_path, "r", encoding="utf-8") as f:
        all_images_raw = [ln.strip() for ln in f if ln.strip()]

    # Normalize and deduplicate image paths (safe)
    all_images: List[str] = []
    seen = set()
    for p in all_images_raw:
        # keep the exact string (don't join with cwd), just normalize slashes
        norm = os.path.normpath(p)
        if norm not in seen:
            seen.add(norm)
            all_images.append(norm)

    # ---- 2) Build mappings from labels ----
    class_to_images: Dict[int, Set[str]] = defaultdict(set)
    image_to_classes: Dict[str, Set[int]] = {}

    missing_label_files = 0
    empty_label_files = 0
    parsed_images = 0

    for img in all_images:
        lbl = _image_to_label_path(img)
        if not os.path.exists(lbl):
            # Missing label file: skip the image (no class info)
            missing_label_files += 1
            continue

        classes = _parse_yolo_label_file(lbl)
        if not classes:
            # No objects in this image -> skip (no class bucket)
            empty_label_files += 1
            continue

        image_to_classes[img] = classes
        for c in classes:
            class_to_images[c].add(img)
        parsed_images += 1

    if not class_to_images:
        # No usable images found
        return {
            "users": [f"c_{i + 1:05d}" for i in range(num_client)],
            "user_data": {f"c_{i + 1:05d}": {"filename": []} for i in range(num_client)},
            "num_samples": [0 for _ in range(num_client)],
        }

    all_classes: List[int] = sorted(class_to_images.keys())
    # Available pool for disjoint mode (only images with labels)
    available_images: Set[str] = set(image_to_classes.keys())

    # ---- 3) Allocate to clients ----
    result = {"users": [], "user_data": {}, "num_samples": []}

    for cid in range(num_client):
        user_id = f"c_{cid + 1:05d}"
        result["users"].append(user_id)

        # Pick the classes for this client (sample without replacement from the global class set)
        k = min(num_local_class, len(all_classes))
        chosen_classes = random.sample(all_classes, k) if k > 0 else []

        # Decide how many samples this client gets
        need = min_data if min_data == max_data else random.randint(min_data, max_data)

        # Build the candidate pool for this client
        if mode == "overlap":
            pool_set: Set[str] = set()
            for c in chosen_classes:
                pool_set.update(class_to_images[c])
        else:  # "disjoint": restrict to currently available images
            pool_set = set()
            for c in chosen_classes:
                # intersect with available images
                pool_set.update(class_to_images[c] & available_images)

        # Deduplicate and sample
        pool_list = list(pool_set)
        if len(pool_list) <= need:
            chosen_imgs = pool_list[:]  # take all (can be fewer than need)
        else:
            chosen_imgs = random.sample(pool_list, need)

        # Record for the user
        result["user_data"][user_id] = {"filename": chosen_imgs}
        result["num_samples"].append(len(chosen_imgs))

        # If disjoint, remove the selected images from availability everywhere
        if mode == "disjoint" and chosen_imgs:
            for img in chosen_imgs:
                if img in available_images:
                    available_images.remove(img)
                # remove from every class bucket this image belongs to
                for c in image_to_classes.get(img, []):
                    if img in class_to_images[c]:
                        class_to_images[c].remove(img)
                # Optional: prune empty classes from all_classes to speed up later loops
                # (keep the list stable; empty buckets are simply skipped)

    # (Optional) Quick diagnostics if helpful:
    # print(f"[INFO] Parsed images with labels: {parsed_images}")
    # print(f"[INFO] Missing label files: {missing_label_files}")
    # print(f"[INFO] Empty label files: {empty_label_files}")

    return result

def init_model(model_name, num_classes):
    """
    Initialize the model for a specific learning task.

    Args:
        model_name: name of the model variant
        num_classes: number of classes
    """
    if model_name == "yolo_v11_n":
        model = nn.yolo_v11_n(num_classes=num_classes)
    elif model_name == "yolo_v11_s":
        model = nn.yolo_v11_s(num_classes=num_classes)
    elif model_name == "yolo_v11_m":
        model = nn.yolo_v11_m(num_classes=num_classes)
    elif model_name == "yolo_v11_l":
        model = nn.yolo_v11_l(num_classes=num_classes)
    elif model_name == "yolo_v11_x":
        model = nn.yolo_v11_x(num_classes=num_classes)
    else:
        raise ValueError("Model {} is not supported.".format(model_name))

    return model
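
# Illustrative usage (class count is an assumption matching COCO):
#   model = init_model("yolo_v11_n", num_classes=80)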
818
utils/util.py
Normal file
@@ -0,0 +1,818 @@
"""
|
||||
Utility functions for yolo.
|
||||
"""
|
||||
|
||||
import copy
|
||||
import random
|
||||
from time import time
|
||||
|
||||
import math
|
||||
import numpy
|
||||
import torch
|
||||
import torchvision
|
||||
from torch.nn.functional import cross_entropy
|
||||
|
||||
|
||||

def setup_seed():
    """
    Set up random seeds for reproducibility.
    """
    random.seed(0)
    numpy.random.seed(0)
    torch.manual_seed(0)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

def setup_multi_processes():
    """
    Set up multi-processing environment variables.
    """
    import cv2
    from os import environ
    from platform import system

    # set the multiprocess start method to `fork` to speed up training
    if system() != "Windows":
        torch.multiprocessing.set_start_method("fork", force=True)

    # disable opencv multithreading to avoid overloading the system
    cv2.setNumThreads(0)

    # set up OMP threads
    if "OMP_NUM_THREADS" not in environ:
        environ["OMP_NUM_THREADS"] = "1"

    # set up MKL threads
    if "MKL_NUM_THREADS" not in environ:
        environ["MKL_NUM_THREADS"] = "1"

def export_onnx(args):
    import onnx  # noqa

    inputs = ["images"]
    outputs = ["outputs"]
    dynamic = {"outputs": {0: "batch", 1: "anchors"}}

    m = torch.load("./weights/best.pt", weights_only=False)["model"].float()
    x = torch.zeros((1, 3, args.input_size, args.input_size))

    torch.onnx.export(
        m.cpu(),
        (x.cpu(),),
        f="./weights/best.onnx",
        verbose=False,
        opset_version=12,
        # WARNING: DNN inference with torch>=1.12 may require do_constant_folding=False
        do_constant_folding=True,
        input_names=inputs,
        output_names=outputs,
        dynamic_axes=dynamic or None,
    )

    # Checks
    model_onnx = onnx.load("./weights/best.onnx")  # load onnx model
    onnx.checker.check_model(model_onnx)  # check onnx model

    onnx.save(model_onnx, "./weights/best.onnx")
    # Inference example
    # https://github.com/ultralytics/ultralytics/blob/main/ultralytics/nn/autobackend.py

def wh2xy(x):
    y = x.clone() if isinstance(x, torch.Tensor) else numpy.copy(x)
    y[:, 0] = x[:, 0] - x[:, 2] / 2  # top left x
    y[:, 1] = x[:, 1] - x[:, 3] / 2  # top left y
    y[:, 2] = x[:, 0] + x[:, 2] / 2  # bottom right x
    y[:, 3] = x[:, 1] + x[:, 3] / 2  # bottom right y
    return y

def make_anchors(x, strides, offset=0.5):
    assert x is not None
    anchor_tensor, stride_tensor = [], []
    dtype, device = x[0].dtype, x[0].device
    for i, stride in enumerate(strides):
        _, _, h, w = x[i].shape
        sx = torch.arange(end=w, device=device, dtype=dtype) + offset  # shift x
        sy = torch.arange(end=h, device=device, dtype=dtype) + offset  # shift y
        sy, sx = torch.meshgrid(sy, sx, indexing="ij")
        anchor_tensor.append(torch.stack((sx, sy), -1).view(-1, 2))
        stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device))
    return torch.cat(anchor_tensor), torch.cat(stride_tensor)
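
# Illustrative (assumed shapes, not from the original file): an 80x80 feature
# map with stride 8 contributes 6400 anchor points ((0.5, 0.5), (1.5, 0.5), ...)
# in grid units, each paired with a stride entry of 8.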

def compute_metric(output, target, iou_v):
    # intersection(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
    (a1, a2) = target[:, 1:].unsqueeze(1).chunk(2, 2)
    (b1, b2) = output[:, :4].unsqueeze(0).chunk(2, 2)
    intersection = (torch.min(a2, b2) - torch.max(a1, b1)).clamp(0).prod(2)
    # IoU = intersection / (area1 + area2 - intersection)
    iou = intersection / ((a2 - a1).prod(2) + (b2 - b1).prod(2) - intersection + 1e-7)

    correct = numpy.zeros((output.shape[0], iou_v.shape[0]))
    correct = correct.astype(bool)
    for i in range(len(iou_v)):
        # IoU > threshold and classes match
        x = torch.where((iou >= iou_v[i]) & (target[:, 0:1] == output[:, 5]))
        if x[0].shape[0]:
            matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy()  # [label, detect, iou]
            if x[0].shape[0] > 1:
                matches = matches[matches[:, 2].argsort()[::-1]]
                matches = matches[numpy.unique(matches[:, 1], return_index=True)[1]]
                matches = matches[numpy.unique(matches[:, 0], return_index=True)[1]]
            correct[matches[:, 1].astype(int), i] = True
    return torch.tensor(correct, dtype=torch.bool, device=output.device)

def non_max_suppression(outputs, confidence_threshold=0.001, iou_threshold=0.65):
    max_wh = 7680
    max_det = 300
    max_nms = 30000

    bs = outputs.shape[0]  # batch size
    nc = outputs.shape[1] - 4  # number of classes
    xc = outputs[:, 4 : 4 + nc].amax(1) > confidence_threshold  # candidates

    # Settings
    start = time()
    limit = 0.5 + 0.05 * bs  # seconds to quit after
    output = [torch.zeros((0, 6), device=outputs.device)] * bs
    for index, x in enumerate(outputs):  # image index, image inference
        x = x.transpose(0, -1)[xc[index]]  # confidence

        # If none remain, process the next image
        if not x.shape[0]:
            continue

        # matrix nx6 (box, confidence, cls)
        box, cls = x.split((4, nc), 1)
        box = wh2xy(box)  # (cx, cy, w, h) to (x1, y1, x2, y2)
        if nc > 1:
            i, j = (cls > confidence_threshold).nonzero(as_tuple=False).T
            x = torch.cat((box[i], x[i, 4 + j, None], j[:, None].float()), dim=1)
        else:  # best class only
            conf, j = cls.max(1, keepdim=True)
            x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > confidence_threshold]

        # Check shape
        n = x.shape[0]  # number of boxes
        if not n:  # no boxes
            continue
        x = x[x[:, 4].argsort(descending=True)[:max_nms]]  # sort by confidence and remove excess boxes

        # Batched NMS
        c = x[:, 5:6] * max_wh  # class offsets
        boxes, scores = x[:, :4] + c, x[:, 4]  # boxes, scores
        indices = torchvision.ops.nms(boxes, scores, iou_threshold)  # NMS
        indices = indices[:max_det]  # limit detections

        output[index] = x[indices]
        if (time() - start) > limit:
            break  # time limit exceeded

    return output
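
# Output format note: non_max_suppression returns one tensor per image, each of
# shape (n, 6) laid out as [x1, y1, x2, y2, confidence, class]. The class offset
# `c = x[:, 5:6] * max_wh` shifts boxes of different classes far apart, so a
# single torchvision NMS call never suppresses across classes.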

def smooth(y, f=0.1):
    # Box filter of fraction f
    nf = round(len(y) * f * 2) // 2 + 1  # number of filter elements (must be odd)
    p = numpy.ones(nf // 2)  # ones padding
    yp = numpy.concatenate((p * y[0], y, p * y[-1]), 0)  # y padded
    return numpy.convolve(yp, numpy.ones(nf) / nf, mode="valid")  # y-smoothed
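
# Illustrative (assumed length, not from the original file): for a curve with
# 100 points and f=0.1, nf = round(100 * 0.2) // 2 + 1 = 11, i.e. an 11-point
# moving average with the edge values repeated as padding.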

def plot_pr_curve(px, py, ap, names, save_dir):
    from matplotlib import pyplot

    fig, ax = pyplot.subplots(1, 1, figsize=(9, 6), tight_layout=True)
    py = numpy.stack(py, axis=1)

    if 0 < len(names) < 21:  # display per-class legend if < 21 classes
        for i, y in enumerate(py.T):
            ax.plot(px, y, linewidth=1, label=f"{names[i]} {ap[i, 0]:.3f}")  # plot(recall, precision)
    else:
        ax.plot(px, py, linewidth=1, color="grey")  # plot(recall, precision)

    ax.plot(
        px,
        py.mean(1),
        linewidth=3,
        color="blue",
        label="all classes %.3f mAP@0.5" % ap[:, 0].mean(),
    )
    ax.set_xlabel("Recall")
    ax.set_ylabel("Precision")
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
    ax.set_title("Precision-Recall Curve")
    fig.savefig(save_dir, dpi=250)
    pyplot.close(fig)

def plot_curve(px, py, names, save_dir, x_label="Confidence", y_label="Metric"):
    from matplotlib import pyplot

    figure, ax = pyplot.subplots(1, 1, figsize=(9, 6), tight_layout=True)

    if 0 < len(names) < 21:  # display per-class legend if < 21 classes
        for i, y in enumerate(py):
            ax.plot(px, y, linewidth=1, label=f"{names[i]}")  # plot(confidence, metric)
    else:
        ax.plot(px, py.T, linewidth=1, color="grey")  # plot(confidence, metric)

    y = smooth(py.mean(0), f=0.05)
    ax.plot(
        px,
        y,
        linewidth=3,
        color="blue",
        label=f"all classes {y.max():.3f} at {px[y.argmax()]:.3f}",
    )
    ax.set_xlabel(x_label)
    ax.set_ylabel(y_label)
    ax.set_xlim(0, 1)
    ax.set_ylim(0, 1)
    ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
    ax.set_title(f"{y_label}-Confidence Curve")
    figure.savefig(save_dir, dpi=250)
    pyplot.close(figure)

def compute_ap(tp, conf, output, target, plot=False, names=(), eps=1e-16):
    """
    Compute the average precision, given the recall and precision curves.
    Source: https://github.com/rafaelpadilla/Object-Detection-Metrics.
    # Arguments
        tp: True positives (nparray, nx1 or nx10).
        conf: Object-ness value from 0-1 (nparray).
        output: Predicted object classes (nparray).
        target: True object classes (nparray).
    # Returns
        The average precision
    """
    # Sort by object-ness
    i = numpy.argsort(-conf)
    tp, conf, output = tp[i], conf[i], output[i]

    # Find unique classes
    unique_classes, nt = numpy.unique(target, return_counts=True)
    nc = unique_classes.shape[0]  # number of classes

    # Create Precision-Recall curve and compute AP for each class
    p = numpy.zeros((nc, 1000))
    r = numpy.zeros((nc, 1000))
    ap = numpy.zeros((nc, tp.shape[1]))
    px, py = numpy.linspace(start=0, stop=1, num=1000), []  # for plotting
    for ci, c in enumerate(unique_classes):
        i = output == c
        nl = nt[ci]  # number of labels
        no = i.sum()  # number of outputs
        if no == 0 or nl == 0:
            continue

        # Accumulate FPs and TPs
        fpc = (1 - tp[i]).cumsum(0)
        tpc = tp[i].cumsum(0)

        # Recall
        recall = tpc / (nl + eps)  # recall curve
        # negative x, xp because xp decreases
        r[ci] = numpy.interp(-px, -conf[i], recall[:, 0], left=0)

        # Precision
        precision = tpc / (tpc + fpc)  # precision curve
        p[ci] = numpy.interp(-px, -conf[i], precision[:, 0], left=1)  # p at pr_score

        # AP from recall-precision curve
        for j in range(tp.shape[1]):
            m_rec = numpy.concatenate(([0.0], recall[:, j], [1.0]))
            m_pre = numpy.concatenate(([1.0], precision[:, j], [0.0]))

            # Compute the precision envelope
            m_pre = numpy.flip(numpy.maximum.accumulate(numpy.flip(m_pre)))

            # Integrate area under curve
            x = numpy.linspace(start=0, stop=1, num=101)  # 101-point interp (COCO)
            ap[ci, j] = numpy.trapz(numpy.interp(x, m_rec, m_pre), x)  # integrate
            if plot and j == 0:
                py.append(numpy.interp(px, m_rec, m_pre))  # precision at mAP@0.5

    # Compute F1 (harmonic mean of precision and recall)
    f1 = 2 * p * r / (p + r + eps)
    if plot:
        names = dict(enumerate(names))  # to dict
        names = [v for k, v in names.items() if k in unique_classes]  # list: only classes that have data
        plot_pr_curve(px, py, ap, names, save_dir="./weights/PR_curve.png")
        plot_curve(px, f1, names, save_dir="./weights/F1_curve.png", y_label="F1")
        plot_curve(px, p, names, save_dir="./weights/P_curve.png", y_label="Precision")
        plot_curve(px, r, names, save_dir="./weights/R_curve.png", y_label="Recall")
    i = smooth(f1.mean(0), 0.1).argmax()  # max F1 index
    p, r, f1 = p[:, i], r[:, i], f1[:, i]
    tp = (r * nt).round()  # true positives
    fp = (tp / (p + eps) - tp).round()  # false positives
    ap50, ap = ap[:, 0], ap.mean(1)  # AP@0.5, AP@0.5:0.95
    m_pre, m_rec = p.mean(), r.mean()
    map50, mean_ap = ap50.mean(), ap.mean()
    return tp, fp, m_pre, m_rec, map50, mean_ap

def compute_iou(box1, box2, eps=1e-7):
    # Returns Complete IoU (CIoU) of box1(1,4) to box2(n,4)

    # Get the coordinates of bounding boxes
    b1_x1, b1_y1, b1_x2, b1_y2 = box1.chunk(4, -1)
    b2_x1, b2_y1, b2_x2, b2_y2 = box2.chunk(4, -1)
    w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
    w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps

    # Intersection area
    inter = (b1_x2.minimum(b2_x2) - b1_x1.maximum(b2_x1)).clamp(0) * (
        b1_y2.minimum(b2_y2) - b1_y1.maximum(b2_y1)
    ).clamp(0)

    # Union Area
    union = w1 * h1 + w2 * h2 - inter + eps

    # IoU
    iou = inter / union
    cw = b1_x2.maximum(b2_x2) - b1_x1.minimum(b2_x1)  # convex (smallest enclosing box) width
    ch = b1_y2.maximum(b2_y2) - b1_y1.minimum(b2_y1)  # convex height
    c2 = cw**2 + ch**2 + eps  # convex diagonal squared
    rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4  # center dist ** 2
    # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
    v = (4 / math.pi**2) * (torch.atan(w2 / h2) - torch.atan(w1 / h1)).pow(2)
    with torch.no_grad():
        alpha = v / (v - iou + (1 + eps))
    return iou - (rho2 / c2 + v * alpha)  # CIoU
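
# Note: for identical boxes the center-distance and aspect-ratio penalties both
# vanish, so CIoU ~= IoU ~= 1; for far-apart, dissimilar boxes the penalty terms
# can push the returned value below zero.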

def strip_optimizer(filename):
    x = torch.load(filename, map_location="cpu", weights_only=False)
    x["model"].half()  # to FP16
    for p in x["model"].parameters():
        p.requires_grad = False
    torch.save(x, f=filename)

def clip_gradients(model, max_norm=10.0):
    parameters = model.parameters()
    torch.nn.utils.clip_grad_norm_(parameters, max_norm=max_norm)

def load_weight(model, ckpt):
    dst = model.state_dict()
    src = torch.load(ckpt, weights_only=False)["model"].float().cpu()

    ckpt = {}
    for k, v in src.state_dict().items():
        if k in dst and v.shape == dst[k].shape:
            ckpt[k] = v

    model.load_state_dict(state_dict=ckpt, strict=False)
    return model

def set_params(model, decay):
    p1 = []
    p2 = []
    norm = tuple(v for k, v in torch.nn.__dict__.items() if "Norm" in k)
    for m in model.modules():
        for n, p in m.named_parameters(recurse=0):
            if not p.requires_grad:
                continue
            if n == "bias":  # bias (no decay)
                p1.append(p)
            elif n == "weight" and isinstance(m, norm):  # norm-weight (no decay)
                p1.append(p)
            else:
                p2.append(p)  # weight (with decay)
    return [{"params": p1, "weight_decay": 0.00}, {"params": p2, "weight_decay": decay}]
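
# Illustrative usage (hyper-parameter values are assumptions):
#   groups = set_params(model, decay=0.0005)
#   optimizer = torch.optim.SGD(groups, lr=0.01, momentum=0.937, nesterov=True)
# Biases and normalization weights are excluded from weight decay, a common
# trick to avoid shrinking parameters that gain nothing from regularization.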

def plot_lr(args, optimizer, scheduler, num_steps):
    from matplotlib import pyplot

    optimizer = copy.copy(optimizer)
    scheduler = copy.copy(scheduler)

    y = []
    for epoch in range(args.epochs):
        for i in range(num_steps):
            step = i + num_steps * epoch
            scheduler.step(step, optimizer)
            y.append(optimizer.param_groups[0]["lr"])
    pyplot.plot(y, ".-", label="LR")
    pyplot.xlabel("step")
    pyplot.ylabel("LR")
    pyplot.grid()
    pyplot.xlim(0, args.epochs * num_steps)
    pyplot.ylim(0)
    pyplot.savefig("./weights/lr.png", dpi=200)
    pyplot.close()

class CosineLR:
    def __init__(self, args, params, num_steps):
        max_lr = params["max_lr"]
        min_lr = params["min_lr"]

        warmup_steps = int(max(params["warmup_epochs"] * num_steps, 100))
        decay_steps = int(args.epochs * num_steps - warmup_steps)

        warmup_lr = numpy.linspace(min_lr, max_lr, int(warmup_steps))

        decay_lr = []
        for step in range(1, decay_steps + 1):
            alpha = math.cos(math.pi * step / decay_steps)
            decay_lr.append(min_lr + 0.5 * (max_lr - min_lr) * (1 + alpha))

        self.total_lr = numpy.concatenate((warmup_lr, decay_lr))

    def step(self, step, optimizer):
        for param_group in optimizer.param_groups:
            param_group["lr"] = self.total_lr[step]
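
# Schedule shape: a linear warmup from min_lr to max_lr over warmup_steps, then
# cosine annealing lr(t) = min_lr + 0.5 * (max_lr - min_lr) * (1 + cos(pi * t / decay_steps)),
# which starts near max_lr and reaches exactly min_lr at t = decay_steps.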

class LinearLR:
    def __init__(self, args, params, num_steps):
        max_lr = params["max_lr"]
        min_lr = params["min_lr"]

        warmup_steps = int(max(params["warmup_epochs"] * num_steps, 100))
        decay_steps = max(1, int(args.epochs * num_steps - warmup_steps))

        warmup_lr = numpy.linspace(min_lr, max_lr, int(warmup_steps), endpoint=False)
        decay_lr = numpy.linspace(max_lr, min_lr, decay_steps)

        self.total_lr = numpy.concatenate((warmup_lr, decay_lr))

    def step(self, step, optimizer):
        for param_group in optimizer.param_groups:
            param_group["lr"] = self.total_lr[step]

class EMA:
    """
    Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models
    Keeps a moving average of everything in the model state_dict (parameters and buffers)
    For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
    """

    def __init__(self, model, decay=0.9999, tau=2000, updates=0):
        # Create EMA
        self.ema = copy.deepcopy(model).eval()  # FP32 EMA
        self.updates = updates  # number of EMA updates
        # decay exponential ramp (to help early epochs)
        self.decay = lambda x: decay * (1 - math.exp(-x / tau))
        for p in self.ema.parameters():
            p.requires_grad_(False)

    def update(self, model):
        if hasattr(model, "module"):
            model = model.module
        # Update EMA parameters
        with torch.no_grad():
            self.updates += 1
            d = self.decay(self.updates)

            msd = model.state_dict()  # model state_dict
            for k, v in self.ema.state_dict().items():
                if v.dtype.is_floating_point:
                    v *= d
                    v += (1 - d) * msd[k].detach()
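
# Decay ramp example (values follow from decay=0.9999, tau=2000): at update
# 2000 the effective decay is 0.9999 * (1 - e**-1) ~= 0.632, and by update
# 20000 it has ramped to ~0.9998, so early updates track the live model
# closely while later updates average over a long history.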

class AverageMeter:
    def __init__(self):
        self.num = 0
        self.sum = 0
        self.avg = 0

    def update(self, v, n):
        if not math.isnan(float(v)):
            self.num = self.num + n
            self.sum = self.sum + v * n
            self.avg = self.sum / self.num

class Assigner(torch.nn.Module):
    def __init__(self, nc=80, top_k=13, alpha=1.0, beta=6.0, eps=1e-9):
        super().__init__()
        self.top_k = top_k
        self.nc = nc
        self.alpha = alpha
        self.beta = beta
        self.eps = eps

    @torch.no_grad()
    def forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
        batch_size = pd_scores.size(0)
        num_max_boxes = gt_bboxes.size(1)

        if num_max_boxes == 0:
            device = gt_bboxes.device
            return (
                torch.zeros_like(pd_bboxes).to(device),
                torch.zeros_like(pd_scores).to(device),
                torch.zeros_like(pd_scores[..., 0]).to(device),
            )

        num_anchors = anc_points.shape[0]
        shape = gt_bboxes.shape
        lt, rb = gt_bboxes.view(-1, 1, 4).chunk(2, 2)
        mask_in_gts = torch.cat((anc_points[None] - lt, rb - anc_points[None]), dim=2)
        mask_in_gts = mask_in_gts.view(shape[0], shape[1], num_anchors, -1).amin(3).gt_(self.eps)
        na = pd_bboxes.shape[-2]
        gt_mask = (mask_in_gts * mask_gt).bool()  # b, max_num_obj, h*w
        overlaps = torch.zeros(
            [batch_size, num_max_boxes, na],
            dtype=pd_bboxes.dtype,
            device=pd_bboxes.device,
        )
        bbox_scores = torch.zeros(
            [batch_size, num_max_boxes, na],
            dtype=pd_scores.dtype,
            device=pd_scores.device,
        )

        ind = torch.zeros([2, batch_size, num_max_boxes], dtype=torch.long)  # 2, b, max_num_obj
        ind[0] = torch.arange(end=batch_size).view(-1, 1).expand(-1, num_max_boxes)  # b, max_num_obj
        ind[1] = gt_labels.squeeze(-1)  # b, max_num_obj
        bbox_scores[gt_mask] = pd_scores[ind[0], :, ind[1]][gt_mask]  # b, max_num_obj, h*w

        pd_boxes = pd_bboxes.unsqueeze(1).expand(-1, num_max_boxes, -1, -1)[gt_mask]
        gt_boxes = gt_bboxes.unsqueeze(2).expand(-1, -1, na, -1)[gt_mask]
        overlaps[gt_mask] = compute_iou(gt_boxes, pd_boxes).squeeze(-1).clamp_(0)

        align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta)

        top_k_mask = mask_gt.expand(-1, -1, self.top_k).bool()
        top_k_metrics, top_k_indices = torch.topk(align_metric, self.top_k, dim=-1, largest=True)
        if top_k_mask is None:  # note: unreachable here, since top_k_mask is assigned just above
            top_k_mask = (top_k_metrics.max(-1, keepdim=True)[0] > self.eps).expand_as(top_k_indices)
        top_k_indices.masked_fill_(~top_k_mask, 0)

        mask_top_k = torch.zeros(align_metric.shape, dtype=torch.int8, device=top_k_indices.device)
        ones = torch.ones_like(top_k_indices[:, :, :1], dtype=torch.int8, device=top_k_indices.device)
        for k in range(self.top_k):
            mask_top_k.scatter_add_(-1, top_k_indices[:, :, k : k + 1], ones)
        mask_top_k.masked_fill_(mask_top_k > 1, 0)
        mask_top_k = mask_top_k.to(align_metric.dtype)
        mask_pos = mask_top_k * mask_in_gts * mask_gt

        fg_mask = mask_pos.sum(-2)
        if fg_mask.max() > 1:
            mask_multi_gts = (fg_mask.unsqueeze(1) > 1).expand(-1, num_max_boxes, -1)
            max_overlaps_idx = overlaps.argmax(1)

            is_max_overlaps = torch.zeros(mask_pos.shape, dtype=mask_pos.dtype, device=mask_pos.device)
            is_max_overlaps.scatter_(1, max_overlaps_idx.unsqueeze(1), 1)

            mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos).float()
            fg_mask = mask_pos.sum(-2)
        target_gt_idx = mask_pos.argmax(-2)

        # Assigned target
        index = torch.arange(end=batch_size, dtype=torch.int64, device=gt_labels.device)[..., None]
        target_index = target_gt_idx + index * num_max_boxes
        target_labels = gt_labels.long().flatten()[target_index]

        target_bboxes = gt_bboxes.view(-1, gt_bboxes.shape[-1])[target_index]

        # Assigned target scores
        target_labels.clamp_(0)

        target_scores = torch.zeros(
            (target_labels.shape[0], target_labels.shape[1], self.nc),
            dtype=torch.int64,
            device=target_labels.device,
        )
        target_scores.scatter_(2, target_labels.unsqueeze(-1), 1)

        fg_scores_mask = fg_mask[:, :, None].repeat(1, 1, self.nc)
        target_scores = torch.where(fg_scores_mask > 0, target_scores, 0)

        # Normalize
        align_metric *= mask_pos
        pos_align_metrics = align_metric.amax(dim=-1, keepdim=True)
        pos_overlaps = (overlaps * mask_pos).amax(dim=-1, keepdim=True)
        norm_align_metric = (align_metric * pos_overlaps / (pos_align_metrics + self.eps)).amax(-2).unsqueeze(-1)
        target_scores = target_scores * norm_align_metric

        return target_bboxes, target_scores, fg_mask.bool()
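
# Assignment note: each anchor is scored with align_metric = score**alpha * IoU**beta,
# a task-aligned metric in the style of TOOD/YOLOv8 assigners; the top_k anchors
# inside each ground-truth box become positives, and an anchor claimed by several
# boxes is resolved in favor of the box with the highest IoU.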

class QFL(torch.nn.Module):
    # Quality Focal Loss
    def __init__(self, beta=2.0):
        super().__init__()
        self.beta = beta
        self.bce_loss = torch.nn.BCEWithLogitsLoss(reduction="none")

    def forward(self, outputs, targets):
        bce_loss = self.bce_loss(outputs, targets)
        return torch.pow(torch.abs(targets - outputs.sigmoid()), self.beta) * bce_loss

class VFL(torch.nn.Module):
    # Varifocal Loss
    def __init__(self, alpha=0.75, gamma=2.00, iou_weighted=True):
        super().__init__()
        assert alpha >= 0.0
        self.alpha = alpha
        self.gamma = gamma
        self.iou_weighted = iou_weighted
        self.bce_loss = torch.nn.BCEWithLogitsLoss(reduction="none")

    def forward(self, outputs, targets):
        assert outputs.size() == targets.size()
        targets = targets.type_as(outputs)

        if self.iou_weighted:
            focal_weight = (
                targets * (targets > 0.0).float()
                + self.alpha * (outputs.sigmoid() - targets).abs().pow(self.gamma) * (targets <= 0.0).float()
            )
        else:
            focal_weight = (targets > 0.0).float() + self.alpha * (outputs.sigmoid() - targets).abs().pow(
                self.gamma
            ) * (targets <= 0.0).float()

        return self.bce_loss(outputs, targets) * focal_weight

class FocalLoss(torch.nn.Module):
    def __init__(self, alpha=0.25, gamma=1.5):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.bce_loss = torch.nn.BCEWithLogitsLoss(reduction="none")

    def forward(self, outputs, targets):
        loss = self.bce_loss(outputs, targets)

        if self.alpha > 0:
            alpha_factor = targets * self.alpha + (1 - targets) * (1 - self.alpha)
            loss *= alpha_factor

        if self.gamma > 0:
            outputs_sigmoid = outputs.sigmoid()
            p_t = targets * outputs_sigmoid + (1 - targets) * (1 - outputs_sigmoid)
            gamma_factor = (1.0 - p_t) ** self.gamma
            loss *= gamma_factor

        return loss

class BoxLoss(torch.nn.Module):
    def __init__(self, dfl_ch):
        super().__init__()
        self.dfl_ch = dfl_ch

    def forward(
        self,
        pred_dist,
        pred_bboxes,
        anchor_points,
        target_bboxes,
        target_scores,
        target_scores_sum,
        fg_mask,
    ):
        # IoU loss
        weight = torch.masked_select(target_scores.sum(-1), fg_mask).unsqueeze(-1)
        iou = compute_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask])
        loss_box = ((1.0 - iou) * weight).sum() / target_scores_sum

        # DFL loss
        a, b = target_bboxes.chunk(2, -1)
        target = torch.cat((anchor_points - a, b - anchor_points), -1)
        target = target.clamp(0, self.dfl_ch - 0.01)
        loss_dfl = self.df_loss(pred_dist[fg_mask].view(-1, self.dfl_ch + 1), target[fg_mask])
        loss_dfl = (loss_dfl * weight).sum() / target_scores_sum

        return loss_box, loss_dfl

    @staticmethod
    def df_loss(pred_dist, target):
        # Distribution Focal Loss (DFL)
        # https://ieeexplore.ieee.org/document/9792391
        tl = target.long()  # target left
        tr = tl + 1  # target right
        wl = tr - target  # weight left
        wr = 1 - wl  # weight right
        left_loss = cross_entropy(pred_dist, tl.view(-1), reduction="none").view(tl.shape)
        right_loss = cross_entropy(pred_dist, tr.view(-1), reduction="none").view(tl.shape)
        return (left_loss * wl + right_loss * wr).mean(-1, keepdim=True)
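
# DFL worked example (values are illustrative): for a target offset of 2.7 grid
# units, tl = 2 and tr = 3 with weights wl = 0.3 and wr = 0.7, so the loss pulls
# the predicted distribution toward the two integer bins that bracket the
# continuous regression target.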

class ComputeLoss:
    def __init__(self, model, params):
        if hasattr(model, "module"):
            model = model.module

        device = next(model.parameters()).device

        m = model.head  # Head() module

        self.params = params
        self.stride = m.stride
        self.nc = m.nc
        self.no = m.no
        self.reg_max = m.ch
        self.device = device

        self.box_loss = BoxLoss(m.ch - 1).to(device)
        self.cls_loss = torch.nn.BCEWithLogitsLoss(reduction="none")
        self.assigner = Assigner(nc=self.nc, top_k=10, alpha=0.5, beta=6.0)

        self.project = torch.arange(m.ch, dtype=torch.float, device=device)

    def box_decode(self, anchor_points, pred_dist):
        b, a, c = pred_dist.shape
        pred_dist = pred_dist.view(b, a, 4, c // 4)
        pred_dist = pred_dist.softmax(3)
        pred_dist = pred_dist.matmul(self.project.type(pred_dist.dtype))
        lt, rb = pred_dist.chunk(2, -1)
        x1y1 = anchor_points - lt
        x2y2 = anchor_points + rb
        return torch.cat(tensors=(x1y1, x2y2), dim=-1)

    def __call__(self, outputs, targets):
        x = torch.cat([i.view(outputs[0].shape[0], self.no, -1) for i in outputs], dim=2)
        pred_distri, pred_scores = x.split(split_size=(self.reg_max * 4, self.nc), dim=1)

        pred_scores = pred_scores.permute(0, 2, 1).contiguous()
        pred_distri = pred_distri.permute(0, 2, 1).contiguous()

        data_type = pred_scores.dtype
        batch_size = pred_scores.shape[0]
        input_size = torch.tensor(outputs[0].shape[2:], device=self.device, dtype=data_type) * self.stride[0]
        anchor_points, stride_tensor = make_anchors(outputs, self.stride, offset=0.5)

        idx = targets["idx"].view(-1, 1)
        cls = targets["cls"].view(-1, 1)
        box = targets["box"]

        targets = torch.cat((idx, cls, box), dim=1).to(self.device)
        if targets.shape[0] == 0:
            gt = torch.zeros(batch_size, 0, 5, device=self.device)
        else:
            i = targets[:, 0]
            _, counts = i.unique(return_counts=True)
            counts = counts.to(dtype=torch.int32)
            gt = torch.zeros(batch_size, counts.max(), 5, device=self.device)
            for j in range(batch_size):
                matches = i == j
                n = matches.sum()
                if n:
                    gt[j, :n] = targets[matches, 1:]
            x = gt[..., 1:5].mul_(input_size[[1, 0, 1, 0]])
            y = torch.empty_like(x)
            dw = x[..., 2] / 2  # half-width
            dh = x[..., 3] / 2  # half-height
            y[..., 0] = x[..., 0] - dw  # top left x
            y[..., 1] = x[..., 1] - dh  # top left y
            y[..., 2] = x[..., 0] + dw  # bottom right x
            y[..., 3] = x[..., 1] + dh  # bottom right y
            gt[..., 1:5] = y
        gt_labels, gt_bboxes = gt.split((1, 4), 2)
        mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)

        pred_bboxes = self.box_decode(anchor_points, pred_distri)
        assigned_targets = self.assigner(
            pred_scores.detach().sigmoid(),
            (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
            anchor_points * stride_tensor,
            gt_labels,
            gt_bboxes,
            mask_gt,
        )
        target_bboxes, target_scores, fg_mask = assigned_targets

        target_scores_sum = max(target_scores.sum(), 1)

        loss_cls = self.cls_loss(pred_scores, target_scores.to(data_type)).sum() / target_scores_sum  # BCE

        # Box loss
        loss_box = torch.zeros(1, device=self.device)
        loss_dfl = torch.zeros(1, device=self.device)
        if fg_mask.sum():
            target_bboxes /= stride_tensor
            loss_box, loss_dfl = self.box_loss(
                pred_distri,
                pred_bboxes,
                anchor_points,
                target_bboxes,
                target_scores,
                target_scores_sum,
                fg_mask,
            )

        loss_box *= self.params["box"]  # box gain
        loss_cls *= self.params["cls"]  # cls gain
        loss_dfl *= self.params["dfl"]  # dfl gain

        return loss_box, loss_cls, loss_dfl