FedAvg and YOLOv11 training

TY1667
2025-10-02 16:26:27 +08:00
parent a60e002733
commit 1ae76d0aed
10 changed files with 2749 additions and 0 deletions

config/coco_cfg.yaml Normal file

@@ -0,0 +1,126 @@
# global system:
fed_algo: "FedAvg" # federated learning algorithm
model_name: "yolo_v11_n" # yolo_v11_n, yolo_v11_t, yolo_v11_s, yolo_v11_m, yolo_v11_l, yolo_v11_x
i_seed: 202509 # initial random seed
num_client: 64 # total number of clients
num_round: 5 # total number of communication rounds
num_local_class: 80 # number of classes per client
res_root: "results" # root directory for results
dataset_path: "/home/image1325/ssd1/dataset/COCO128/"
# train_txt: "train.txt" # path to training set txt file
# val_txt: "val.txt" # path to validation set txt file
# test_txt: "test.txt" # path to test set txt file
local_batch_size: 32 # local training batch size
val_batch_size: 4 # validation batch size
num_workers: 4 # number of data loader workers
min_data: 128 # minimum number of images per client
max_data: 128 # maximum number of images per client
partition_mode: "overlap" # "overlap" or "disjoint"
connection_ratio: 1 # connection ratio, e.g., 1.0 means all clients
# local training:
min_lr: 0.0001 # initial learning rate
max_lr: 0.01 # maximum learning rate
momentum: 0.937 # SGD momentum/Adam beta1
weight_decay: 0.0005 # optimizer weight decay
warmup_epochs: 3.0 # warmup epochs
box: 7.5 # box loss gain
cls: 0.5 # cls loss gain
dfl: 1.5 # dfl loss gain
hsv_h: 0.015 # image HSV-Hue augmentation (fraction)
hsv_s: 0.7 # image HSV-Saturation augmentation (fraction)
hsv_v: 0.4 # image HSV-Value augmentation (fraction)
degrees: 0.0 # image rotation (+/- deg)
translate: 0.1 # image translation (+/- fraction)
scale: 0.5 # image scale (+/- gain)
shear: 0.0 # image shear (+/- deg)
flip_ud: 0.0 # image flip up-down (probability)
flip_lr: 0.5 # image flip left-right (probability)
mosaic: 1.0 # image mosaic (probability)
mix_up: 0.0 # image mix-up (probability)
names:
0: person
1: bicycle
2: car
3: motorcycle
4: airplane
5: bus
6: train
7: truck
8: boat
9: traffic light
10: fire hydrant
11: stop sign
12: parking meter
13: bench
14: bird
15: cat
16: dog
17: horse
18: sheep
19: cow
20: elephant
21: bear
22: zebra
23: giraffe
24: backpack
25: umbrella
26: handbag
27: tie
28: suitcase
29: frisbee
30: skis
31: snowboard
32: sports ball
33: kite
34: baseball bat
35: baseball glove
36: skateboard
37: surfboard
38: tennis racket
39: bottle
40: wine glass
41: cup
42: fork
43: knife
44: spoon
45: bowl
46: banana
47: apple
48: sandwich
49: orange
50: broccoli
51: carrot
52: hot dog
53: pizza
54: donut
55: cake
56: chair
57: couch
58: potted plant
59: bed
60: dining table
61: toilet
62: tv
63: laptop
64: mouse
65: remote
66: keyboard
67: cell phone
68: microwave
69: oven
70: toaster
71: sink
72: refrigerator
73: book
74: clock
75: vase
76: scissors
77: teddy bear
78: hair drier
79: toothbrush
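
As a quick sanity check, this file can be loaded and used directly as the `params` dict passed around the rest of the code (a minimal sketch, assuming PyYAML is installed):

import yaml

with open("config/coco_cfg.yaml", "r", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

num_classes = len(cfg["names"])  # 80 for COCO
print(cfg["fed_algo"], cfg["model_name"], num_classes)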

config/uav_cfg.yaml Normal file

@@ -0,0 +1,47 @@
# global system:
fed_algo: "FedAvg" # federated learning algorithm
model_name: "yolo_v11_n" # yolo_v11_n, yolo_v11_t, yolo_v11_s, yolo_v11_m, yolo_v11_l, yolo_v11_x
i_seed: 202509 # initial random seed
num_client: 100 # total number of clients
num_round: 500 # total number of communication rounds
num_local_class: 1 # number of classes per client
res_root: "results" # root directory for results
dataset_path: "/home/image1325/ssd1/dataset/uav/"
# train_txt: "train.txt" # path to training set txt file
# val_txt: "val.txt" # path to validation set txt file
# test_txt: "test.txt" # path to test set txt file
local_batch_size: 32 # local training batch size
val_batch_size: 16 # validation batch size
num_workers: 4 # number of data loader workers
min_data: 640 # minimum number of images per client
max_data: 720 # maximum number of images per client
partition_mode: "overlap" # "overlap" or "disjoint"
connection_ratio: 1 # connection ratio, e.g., 1.0 means all clients
# local training:
min_lr: 0.0001 # initial learning rate
max_lr: 0.01 # maximum learning rate
momentum: 0.937 # SGD momentum/Adam beta1
weight_decay: 0.0005 # optimizer weight decay
warmup_epochs: 3.0 # warmup epochs
box: 7.5 # box loss gain
cls: 0.5 # cls loss gain
dfl: 1.5 # dfl loss gain
hsv_h: 0.015 # image HSV-Hue augmentation (fraction)
hsv_s: 0.7 # image HSV-Saturation augmentation (fraction)
hsv_v: 0.4 # image HSV-Value augmentation (fraction)
degrees: 0.0 # image rotation (+/- deg)
translate: 0.1 # image translation (+/- fraction)
scale: 0.5 # image scale (+/- gain)
shear: 0.0 # image shear (+/- deg)
flip_ud: 0.0 # image flip up-down (probability)
flip_lr: 0.5 # image flip left-right (probability)
mosaic: 1.0 # image mosaic (probability)
mix_up: 0.0 # image mix-up (probability)
names:
0: uav

fed_algo_cs/client_base.py Normal file

@@ -0,0 +1,233 @@
import numpy as np
import torch
from torch import nn
from torch.utils import data
from torch.amp.autocast_mode import autocast
from utils.fed_util import init_model
from utils import util
from utils.dataset import Dataset
from typing import cast
class FedYoloClient(object):
def __init__(self, name, model_name, params):
"""
Initialize the client k for federated learning
Args:
:param name: Name of the client k
:param model_name: Name of the model
:param params: config file including the hyperparameters for local training
- batch_size: Local training batch size in the client k
- num_workers: Number of data loader workers
- min_lr: Minimum learning rate
- max_lr: Maximum learning rate
- momentum: Momentum for local training
- weight_decay: Weight decay for local training
"""
self.params = params
# initialize the metadata in local client k
self.target_ip = "127.0.0.3"
self.port = 9999
self.name = name
# initialize the parameters in local client k
self._batch_size = self.params["local_batch_size"]
self._min_lr = self.params["min_lr"]
self._max_lr = self.params["max_lr"]
self._momentum = self.params["momentum"]
self.num_workers = self.params["num_workers"]
self.loss_record = []
# train set length
self.n_data = 0
# initialize the local training and testing dataset
self.train_dataset = None
self.val_dataset = None
# initialize the local model
self._num_classes = len(self.params["names"])
self._weight_decay = self.params["weight_decay"]
self.model_name = model_name
self.model = init_model(model_name, self._num_classes)
model_parameters = filter(lambda p: p.requires_grad, self.model.parameters())
self.parameter_number = sum([np.prod(p.size()) for p in model_parameters])
# GPU
self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
def load_trainset(self, train_dataset: list[str]):
"""
Load the local training dataset
Args:
:param train_dataset: Training dataset
"""
self.train_dataset = train_dataset
self.n_data = len(self.train_dataset)
def update(self, Global_model_state_dict):
"""
Update the local model with the global model parameters
Args:
:param Global_model_state_dict: State dictionary of the global model
"""
if not hasattr(self, "model") or self.model is None:
self.model = init_model(self.model_name, self._num_classes)
# load the global model parameters
self.model.load_state_dict(Global_model_state_dict, strict=True)
def train(self, args):
"""
Train the local model
Args:
        :param args: Command line arguments
            - local_rank: Local rank for distributed training
            - world_size: World size for distributed training
            - distributed: Whether to use distributed training
            - input_size: Input size for the model
            - epochs: Number of local training epochs
Returns:
:return: Local updated model, number of local data points, training loss
"""
if args.distributed:
torch.cuda.set_device(device=args.local_rank)
torch.distributed.init_process_group(backend="nccl", init_method="env://")
# print(f"Client {self.name} - distributed training on {world_size} GPUs, local rank: {local_rank}")
# self._device = torch.device("cuda", local_rank)
        if args.local_rank == 0:
            # rank-0-only setup (e.g. creating a "weights" directory) would go here
            pass
util.setup_seed()
util.setup_multi_processes()
# model
# init model have been done in __init__()
self.model.to(self._device)
# Optimizer
        accumulate = max(round(64 / (self._batch_size * args.world_size)), 1)
        # scale the configured weight decay by the effective batch size (nominal batch = 64)
        self._weight_decay = self.params["weight_decay"] * self._batch_size * args.world_size * accumulate / 64
optimizer = torch.optim.SGD(
util.set_params(self.model, self._weight_decay),
lr=self._min_lr,
momentum=self._momentum,
nesterov=True,
)
# EMA
ema = util.EMA(self.model) if args.local_rank == 0 else None
data_set = Dataset(
filenames=self.train_dataset,
input_size=args.input_size,
params=self.params,
augment=True,
)
if args.distributed:
train_sampler = data.DistributedSampler(
data_set, num_replicas=args.world_size, rank=args.local_rank, shuffle=True
)
else:
train_sampler = None
loader = data.DataLoader(
data_set,
batch_size=self._batch_size,
shuffle=train_sampler is None,
sampler=train_sampler,
num_workers=self.num_workers,
pin_memory=True,
collate_fn=Dataset.collate_fn,
)
# Scheduler
num_steps = max(1, len(loader))
# print(len(loader))
scheduler = util.LinearLR(args=args, params=self.params, num_steps=num_steps)
# DDP mode
if args.distributed:
self.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.model)
self.model = nn.parallel.DistributedDataParallel(
module=self.model,
device_ids=[args.local_rank],
output_device=args.local_rank,
find_unused_parameters=False,
)
amp_scale = torch.amp.grad_scaler.GradScaler(enabled=True)
criterion = util.ComputeLoss(self.model, self.params)
optimizer.zero_grad(set_to_none=True)
for epoch in range(args.epochs):
self.model.train()
# when distributed, set epoch for shuffling
if args.distributed and train_sampler is not None:
train_sampler.set_epoch(epoch)
if args.epochs - epoch == 10:
# disable mosaic augmentation in the last 10 epochs
ds = cast(Dataset, loader.dataset)
ds.mosaic = False
avg_box_loss = util.AverageMeter()
avg_cls_loss = util.AverageMeter()
avg_dfl_loss = util.AverageMeter()
for i, (samples, targets) in enumerate(loader):
global_step = i + num_steps * epoch
scheduler.step(step=global_step, optimizer=optimizer)
                samples = samples.to(self._device, non_blocking=True).float() / 255.0
# Forward
with autocast("cuda", enabled=True):
outputs = self.model(samples)
box_loss, cls_loss, dfl_loss = criterion(outputs, targets)
# meters (use the *unscaled* values)
bs = samples.size(0)
avg_box_loss.update(box_loss.item(), bs)
avg_cls_loss.update(cls_loss.item(), bs)
avg_dfl_loss.update(dfl_loss.item(), bs)
# scale losses by batch/world if your loss is averaged internally per-sample/device
box_loss = box_loss * self._batch_size * args.world_size
cls_loss = cls_loss * self._batch_size * args.world_size
dfl_loss = dfl_loss * self._batch_size * args.world_size
total_loss = box_loss + cls_loss + dfl_loss
# Backward
amp_scale.scale(total_loss).backward()
# Optimize
if (i + 1) % accumulate == 0:
amp_scale.step(optimizer)
amp_scale.update()
optimizer.zero_grad(set_to_none=True)
if ema:
ema.update(self.model)
# torch.cuda.synchronize()
# clean
if args.distributed:
torch.distributed.destroy_process_group()
torch.cuda.empty_cache()
        # unwrap DDP (if used) so the exported state_dict keys match the global model
        model_to_save = self.model.module if hasattr(self.model, "module") else self.model
        return (
            model_to_save.state_dict(),
            self.n_data,
            {"box_loss": avg_box_loss.avg, "cls_loss": avg_cls_loss.avg, "dfl_loss": avg_dfl_loss.avg},
        )
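
The accumulation rule in train() targets a nominal batch size of 64: gradients are accumulated over `accumulate` optimizer-free steps so that local_batch_size * world_size * accumulate stays close to 64. A small sketch of the arithmetic (batch/world values hypothetical):

for bs, ws in [(8, 1), (32, 1), (32, 2), (64, 1)]:
    accumulate = max(round(64 / (bs * ws)), 1)
    print(f"batch={bs} world={ws} accumulate={accumulate} effective={bs * ws * accumulate}")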

fed_algo_cs/server_base.py Normal file

@@ -0,0 +1,178 @@
import numpy as np
import torch
from torch.utils.data import DataLoader
from utils.fed_util import init_model
from utils.dataset import Dataset
from utils import util
class FedYoloServer(object):
def __init__(self, client_list, model_name, params):
"""
Federated YOLO Server
Args:
client_list: list of connected clients
model_name: YOLO model architecture name
params: dict of hyperparameters (must include 'names')
"""
# Track client updates
self.client_state = {}
self.client_loss = {}
self.client_n_data = {}
self.selected_clients = []
self._batch_size = params.get("val_batch_size", 4)
self.client_list = client_list
self.valset = None
# Federated bookkeeping
self.round = 0
        # Total number of training samples across received client updates
self.n_data = 0
# Device
gpu = 0
self._device = torch.device(f"cuda:{gpu}" if torch.cuda.is_available() else "cpu")
# Global model
self._num_classes = len(params["names"])
self.model_name = model_name
self.model = init_model(model_name, self._num_classes)
self.params = params
def load_valset(self, valset):
"""Server loads the validation dataset."""
self.valset = valset
def state_dict(self):
"""Return global model weights."""
return self.model.state_dict()
@torch.no_grad()
def test(self, args):
"""
Evaluate global model on validation set using YOLO metrics (mAP, precision, recall).
Returns:
dict with {"mAP": ..., "mAP50": ..., "precision": ..., "recall": ...}
"""
if self.valset is None:
return {}
loader = DataLoader(
self.valset,
batch_size=self._batch_size,
shuffle=False,
num_workers=4,
pin_memory=True,
collate_fn=Dataset.collate_fn,
)
        self.model.to(self._device).eval()
        use_half = self._device.type == "cuda"  # half precision is only safe on GPU
        if use_half:
            self.model.half()
iou_v = torch.linspace(0.5, 0.95, 10).to(self._device) # IoU thresholds
n_iou = iou_v.numel()
metrics = []
for samples, targets in loader:
            samples = samples.to(self._device)
            samples = (samples.half() if use_half else samples.float()) / 255.0
_, _, h, w = samples.shape
scale = torch.tensor((w, h, w, h)).to(self._device)
outputs = self.model(samples)
outputs = util.non_max_suppression(outputs)
for i, output in enumerate(outputs):
idx = targets["idx"] == i
cls = targets["cls"][idx].to(self._device)
box = targets["box"][idx].to(self._device)
metric = torch.zeros((output.shape[0], n_iou), dtype=torch.bool, device=self._device)
if output.shape[0] == 0:
if cls.shape[0]:
metrics.append((metric, *torch.zeros((2, 0), device=self._device), cls.squeeze(-1)))
continue
if cls.shape[0]:
cls_tensor = cls if isinstance(cls, torch.Tensor) else torch.tensor(cls, device=self._device)
if cls_tensor.dim() == 1:
cls_tensor = cls_tensor.unsqueeze(1)
box_xy = util.wh2xy(box)
if not isinstance(box_xy, torch.Tensor):
box_xy = torch.tensor(box_xy, device=self._device)
target = torch.cat((cls_tensor, box_xy * scale), dim=1)
metric = util.compute_metric(output[:, :6], target, iou_v)
metrics.append((metric, output[:, 4], output[:, 5], cls.squeeze(-1)))
# Compute metrics
if not metrics:
return {"mAP": 0, "mAP50": 0, "precision": 0, "recall": 0}
metrics = [torch.cat(x, dim=0).cpu().numpy() for x in zip(*metrics)]
if len(metrics) and metrics[0].any():
_, _, prec, rec, map50, mean_ap = util.compute_ap(*metrics, names=self.params["names"], plot=False)
else:
prec, rec, map50, mean_ap = 0, 0, 0, 0
# Back to float32 for further training
self.model.float()
return {"mAP": float(mean_ap), "mAP50": float(map50), "precision": float(prec), "recall": float(rec)}
def select_clients(self, connection_ratio=1.0):
"""Randomly select a fraction of clients."""
self.selected_clients = []
self.n_data = 0
for client_id in self.client_list:
if np.random.rand() <= connection_ratio:
self.selected_clients.append(client_id)
self.n_data += self.client_n_data.get(client_id, 0)
def agg(self):
"""Aggregate client updates (FedAvg)."""
if len(self.selected_clients) == 0 or self.n_data == 0:
return self.model.state_dict(), {}, 0
model = init_model(self.model_name, self._num_classes)
model_state = model.state_dict()
avg_loss = {}
        first = True
        for name in self.selected_clients:
            if name not in self.client_state:
                continue
            weight = self.client_n_data[name] / self.n_data
            for key in model_state.keys():
                # initialize with the first *received* update, then accumulate;
                # keying on loop index i would break if client 0 never reported
                if first:
                    model_state[key] = self.client_state[name][key] * weight
                else:
                    model_state[key] += self.client_state[name][key] * weight
            # Weighted average of the reported losses
            for k, v in self.client_loss[name].items():
                avg_loss[k] = avg_loss.get(k, 0.0) + v * weight
            first = False
self.model.load_state_dict(model_state, strict=True)
self.round += 1
return model_state, avg_loss, self.n_data
def rec(self, name, state_dict, n_data, loss_dict):
"""
Receive local update from a client.
Args:
name: client ID
state_dict: state dictionary of the local model
n_data: number of data samples used in local training
loss_dict: dict of losses from local training
"""
self.n_data += n_data
self.client_state[name] = {k: v.cpu() for k, v in state_dict.items()}
self.client_n_data[name] = n_data
self.client_loss[name] = loss_dict
def flush(self):
"""Clear stored client updates."""
self.n_data = 0
self.client_state.clear()
self.client_n_data.clear()
self.client_loss.clear()
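
agg() implements plain FedAvg: each received state dict is weighted by that client's share of the selected data, n_k / n. A self-contained sketch of the same weighting on a toy one-parameter "model" (names and sizes hypothetical):

import torch

states = {"c1": {"w": torch.tensor([1.0])}, "c2": {"w": torch.tensor([5.0])}}
n_data = {"c1": 100, "c2": 300}
total = sum(n_data.values())

agg_state = {}
for name, sd in states.items():
    weight = n_data[name] / total
    for key, value in sd.items():
        agg_state[key] = agg_state.get(key, 0.0) + value * weight

print(agg_state["w"])  # tensor([4.]) = 0.25 * 1.0 + 0.75 * 5.0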

fed_run.py Normal file

@@ -0,0 +1,239 @@
#!/usr/bin/env python3
import os
import json
import yaml
import time
import random
from tqdm import tqdm
import numpy as np
import torch
import matplotlib.pyplot as plt
from utils.dataset import Dataset
from fed_algo_cs.client_base import FedYoloClient
from fed_algo_cs.server_base import FedYoloServer
from utils.args import args_parser # your args parser
from utils.fed_util import divide_trainset # divide_trainset is yours
def _read_list_file(txt_path: str):
"""Read one path per line; keep as-is (absolute or relative)."""
if not txt_path or not os.path.exists(txt_path):
return []
with open(txt_path, "r", encoding="utf-8") as f:
return [ln.strip() for ln in f if ln.strip()]
def _build_valset_if_available(cfg, params):
"""
Try to build a validation Dataset.
- If cfg['val_txt'] exists, use it.
- Else if <dataset_path>/val.txt exists, use it.
- Else return None (testing will be skipped).
Args:
cfg: config dict
params: params dict for Dataset
Returns:
Dataset or None
"""
input_size = int(cfg.get("input_size", 640))
val_txt = cfg.get("val_txt", "")
if not val_txt:
ds_root = cfg.get("dataset_path", "")
guess = os.path.join(ds_root, "val.txt") if ds_root else ""
val_txt = guess if os.path.exists(guess) else ""
val_files = _read_list_file(val_txt)
if not val_files:
return None
    return Dataset(
        filenames=val_files,
        input_size=input_size,
        params=params,
        augment=False,  # validation data should not be augmented
    )
def _seed_everything(seed: int):
np.random.seed(seed)
torch.manual_seed(seed)
random.seed(seed)
def _plot_curves(save_dir, hist):
"""
Plot mAP50-95, mAP50, precision, recall, and (optional) summed train loss per round.
"""
os.makedirs(save_dir, exist_ok=True)
rounds = np.arange(1, len(hist["mAP"]) + 1)
plt.figure()
if hist["mAP"]:
plt.plot(rounds, hist["mAP"], label="mAP50-95")
if hist["mAP50"]:
plt.plot(rounds, hist["mAP50"], label="mAP50")
if hist["precision"]:
plt.plot(rounds, hist["precision"], label="precision")
if hist["recall"]:
plt.plot(rounds, hist["recall"], label="recall")
if hist["train_loss"]:
plt.plot(rounds, hist["train_loss"], label="train_loss (sum of components)")
plt.xlabel("Global Round")
plt.ylabel("Metric")
plt.title("Federated YOLO - Server Metrics")
plt.legend()
out_png = os.path.join(save_dir, "fed_yolo_curves.png")
plt.savefig(out_png, dpi=150, bbox_inches="tight")
print(f"[plot] saved: {out_png}")
def fed_run():
"""
Main FL process:
- Initialize clients & server
- For each round: sequential local training -> record -> select -> aggregate
- Test & flush
- Record & save results, plot curves
"""
args_cli = args_parser()
with open(args_cli.config, "r", encoding="utf-8") as f:
cfg = yaml.safe_load(f)
# --- params / config normalization ---
# For convenience we pass the same `params` dict used by Dataset/model/loss.
# Here we re-use the top-level cfg directly as params.
params = dict(cfg)
if "names" in cfg and isinstance(cfg["names"], dict):
# Convert {0: 'uav', 1: 'car', ...} to list if you prefer list
# but we can leave dict; your utils appear to accept dict
pass
# seeds
_seed_everything(int(cfg.get("i_seed", 0)))
# --- split clients' train data from a global train list ---
# Expect either cfg["train_txt"] or <dataset_path>/train.txt
train_txt = cfg.get("train_txt", "")
if not train_txt:
ds_root = cfg.get("dataset_path", "")
guess = os.path.join(ds_root, "train.txt") if ds_root else ""
train_txt = guess
if not train_txt or not os.path.exists(train_txt):
raise FileNotFoundError(
f"train.txt not found. Provide --config with 'train_txt' or ensure '{train_txt}' exists."
)
split = divide_trainset(
trainset_path=train_txt,
num_local_class=int(cfg.get("num_local_class", 1)),
num_client=int(cfg.get("num_client", 64)),
min_data=int(cfg.get("min_data", 100)),
max_data=int(cfg.get("max_data", 100)),
mode=str(cfg.get("partition_mode", "disjoint")), # "overlap" or "disjoint"
seed=int(cfg.get("i_seed", 0)),
)
users = split["users"]
user_data = split["user_data"] # mapping: id -> {"filename": [...]}
# --- build clients ---
model_name = cfg.get("model_name", "yolo_v11_n")
clients = {}
for uid in users:
c = FedYoloClient(name=uid, model_name=model_name, params=params)
c.load_trainset(user_data[uid]["filename"])
clients[uid] = c
# --- build server & optional validation set ---
server = FedYoloServer(client_list=users, model_name=model_name, params=params)
valset = _build_valset_if_available(cfg, params)
    # valset is a Dataset instance, not a DataLoader
if valset is not None:
server.load_valset(valset)
# --- push initial global weights ---
global_state = server.state_dict()
# --- args object for client.train() ---
# args_train = _make_args_for_client(cfg, args_cli)
# --- history recorder ---
history = {
"mAP": [],
"mAP50": [],
"precision": [],
"recall": [],
"train_loss": [], # scalar sum of client-weighted dict losses
"round_time_sec": [],
}
# --- main FL loop ---
num_round = int(cfg.get("num_round", 50))
connection_ratio = float(cfg.get("connection_ratio", 1.0)) # e.g., 1.0 = all clients
res_root = cfg.get("res_root", "results")
os.makedirs(res_root, exist_ok=True)
for rnd in tqdm(range(num_round), desc="main federal loop round"):
t0 = time.time()
# Local training (sequential over all users)
for uid in tqdm(users, desc=f"Round {rnd + 1} local training", leave=False):
client = clients[uid] # FedYoloClient instance
client.update(global_state) # load global weights
state_dict, n_data, loss_dict = client.train(args_cli) # local training
server.rec(uid, state_dict, n_data, loss_dict)
# Select a fraction for aggregation (FedAvg subset if desired)
server.select_clients(connection_ratio=connection_ratio)
# Aggregate
global_state, avg_loss_dict, _ = server.agg()
# Compute a scalar train loss for plotting (sum of components)
scalar_train_loss = float(sum(avg_loss_dict.values())) if avg_loss_dict else 0.0
# Test (if valset provided)
test_metrics = server.test(args_cli) if server.valset is not None else {}
mAP = float(test_metrics.get("mAP", 0.0))
mAP50 = float(test_metrics.get("mAP50", 0.0))
precision = float(test_metrics.get("precision", 0.0))
recall = float(test_metrics.get("recall", 0.0))
# Flush per-round client caches
server.flush()
# Record & log
history["mAP"].append(mAP)
history["mAP50"].append(mAP50)
history["precision"].append(precision)
history["recall"].append(recall)
history["train_loss"].append(scalar_train_loss)
history["round_time_sec"].append(time.time() - t0)
print(
f"[round {rnd + 1:04d}] "
f"loss={scalar_train_loss:.4f} mAP50-95={mAP:.4f} mAP50={mAP50:.4f} "
f"P={precision:.4f} R={recall:.4f}"
)
# Save running JSON (resumable logs)
        save_name = (
            f"[{cfg.get('fed_algo', 'FedAvg')},{cfg.get('model_name', 'yolo')},"
            f"{args_cli.epochs},"  # local epochs come from the CLI args, not the YAML
            f"{cfg.get('num_local_class', 2)},"
            f"{cfg.get('i_seed', 0)}]"
        )
out_json = os.path.join(res_root, save_name + ".json")
with open(out_json, "w", encoding="utf-8") as f:
json.dump(history, f, indent=2)
# --- final plot ---
_plot_curves(res_root, history)
print("[done] training complete.")
if __name__ == "__main__":
fed_run()
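
Example invocation (single GPU; assumes train.txt and val.txt list files exist under dataset_path):

python fed_run.py --config config/coco_cfg.yaml --epochs 10 --input_size 640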

nets/nn.py Normal file

@@ -0,0 +1,362 @@
"""
This file contains the model definition of YOLOv11
"""
import math
import torch
from utils.util import make_anchors
def fuse_conv(conv, norm):
fused_conv = (
torch.nn.Conv2d(
conv.in_channels,
conv.out_channels,
kernel_size=conv.kernel_size,
stride=conv.stride,
padding=conv.padding,
groups=conv.groups,
bias=True,
)
.requires_grad_(False)
.to(conv.weight.device)
)
w_conv = conv.weight.clone().view(conv.out_channels, -1)
w_norm = torch.diag(norm.weight.div(torch.sqrt(norm.eps + norm.running_var)))
fused_conv.weight.copy_(torch.mm(w_norm, w_conv).view(fused_conv.weight.size()))
b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias
b_norm = norm.bias - norm.weight.mul(norm.running_mean).div(torch.sqrt(norm.running_var + norm.eps))
if fused_conv.bias is not None:
fused_conv.bias.copy_(torch.mm(w_norm, b_conv.reshape(-1, 1)).reshape(-1) + b_norm)
return fused_conv
class Conv(torch.nn.Module):
def __init__(self, in_ch, out_ch, activation, k=1, s=1, p=0, g=1):
super().__init__()
self.conv = torch.nn.Conv2d(in_ch, out_ch, k, s, p, groups=g, bias=False)
self.norm = torch.nn.BatchNorm2d(out_ch, eps=0.001, momentum=0.03)
self.relu = activation
def forward(self, x):
return self.relu(self.norm(self.conv(x)))
def fuse_forward(self, x):
return self.relu(self.conv(x))
class Residual(torch.nn.Module):
def __init__(self, ch, e=0.5):
super().__init__()
self.conv1 = Conv(ch, int(ch * e), torch.nn.SiLU(), k=3, p=1)
self.conv2 = Conv(int(ch * e), ch, torch.nn.SiLU(), k=3, p=1)
def forward(self, x):
return x + self.conv2(self.conv1(x))
class CSPModule(torch.nn.Module):
def __init__(self, in_ch, out_ch):
super().__init__()
self.conv1 = Conv(in_ch, out_ch // 2, torch.nn.SiLU())
self.conv2 = Conv(in_ch, out_ch // 2, torch.nn.SiLU())
self.conv3 = Conv(2 * (out_ch // 2), out_ch, torch.nn.SiLU())
self.res_m = torch.nn.Sequential(Residual(out_ch // 2, e=1.0), Residual(out_ch // 2, e=1.0))
def forward(self, x):
y = self.res_m(self.conv1(x))
return self.conv3(torch.cat((y, self.conv2(x)), dim=1))
class CSP(torch.nn.Module):
def __init__(self, in_ch, out_ch, n, csp, r):
super().__init__()
self.conv1 = Conv(in_ch, 2 * (out_ch // r), torch.nn.SiLU())
self.conv2 = Conv((2 + n) * (out_ch // r), out_ch, torch.nn.SiLU())
if not csp:
self.res_m = torch.nn.ModuleList(Residual(out_ch // r) for _ in range(n))
else:
self.res_m = torch.nn.ModuleList(CSPModule(out_ch // r, out_ch // r) for _ in range(n))
def forward(self, x):
y = list(self.conv1(x).chunk(2, 1))
y.extend(m(y[-1]) for m in self.res_m)
return self.conv2(torch.cat(y, dim=1))
class SPP(torch.nn.Module):
def __init__(self, in_ch, out_ch, k=5):
super().__init__()
self.conv1 = Conv(in_ch, in_ch // 2, torch.nn.SiLU())
self.conv2 = Conv(in_ch * 2, out_ch, torch.nn.SiLU())
self.res_m = torch.nn.MaxPool2d(k, stride=1, padding=k // 2)
def forward(self, x):
x = self.conv1(x)
y1 = self.res_m(x)
y2 = self.res_m(y1)
return self.conv2(torch.cat(tensors=[x, y1, y2, self.res_m(y2)], dim=1))
class Attention(torch.nn.Module):
def __init__(self, ch, num_head):
super().__init__()
self.num_head = num_head
self.dim_head = ch // num_head
self.dim_key = self.dim_head // 2
self.scale = self.dim_key**-0.5
self.qkv = Conv(ch, ch + self.dim_key * num_head * 2, torch.nn.Identity())
self.conv1 = Conv(ch, ch, torch.nn.Identity(), k=3, p=1, g=ch)
self.conv2 = Conv(ch, ch, torch.nn.Identity())
def forward(self, x):
b, c, h, w = x.shape
qkv = self.qkv(x)
qkv = qkv.view(b, self.num_head, self.dim_key * 2 + self.dim_head, h * w)
q, k, v = qkv.split([self.dim_key, self.dim_key, self.dim_head], dim=2)
attn = (q.transpose(-2, -1) @ k) * self.scale
attn = attn.softmax(dim=-1)
x = (v @ attn.transpose(-2, -1)).view(b, c, h, w) + self.conv1(v.reshape(b, c, h, w))
return self.conv2(x)
class PSABlock(torch.nn.Module):
def __init__(self, ch, num_head):
super().__init__()
self.conv1 = Attention(ch, num_head)
self.conv2 = torch.nn.Sequential(Conv(ch, ch * 2, torch.nn.SiLU()), Conv(ch * 2, ch, torch.nn.Identity()))
def forward(self, x):
x = x + self.conv1(x)
return x + self.conv2(x)
class PSA(torch.nn.Module):
def __init__(self, ch, n):
super().__init__()
self.conv1 = Conv(ch, 2 * (ch // 2), torch.nn.SiLU())
self.conv2 = Conv(2 * (ch // 2), ch, torch.nn.SiLU())
self.res_m = torch.nn.Sequential(*(PSABlock(ch // 2, ch // 128) for _ in range(n)))
def forward(self, x):
x, y = self.conv1(x).chunk(2, 1)
return self.conv2(torch.cat(tensors=(x, self.res_m(y)), dim=1))
class DarkNet(torch.nn.Module):
def __init__(self, width, depth, csp):
super().__init__()
self.p1 = []
self.p2 = []
self.p3 = []
self.p4 = []
self.p5 = []
# p1/2
self.p1.append(Conv(width[0], width[1], torch.nn.SiLU(), k=3, s=2, p=1))
# p2/4
self.p2.append(Conv(width[1], width[2], torch.nn.SiLU(), k=3, s=2, p=1))
self.p2.append(CSP(width[2], width[3], depth[0], csp[0], r=4))
# p3/8
self.p3.append(Conv(width[3], width[3], torch.nn.SiLU(), k=3, s=2, p=1))
self.p3.append(CSP(width[3], width[4], depth[1], csp[0], r=4))
# p4/16
self.p4.append(Conv(width[4], width[4], torch.nn.SiLU(), k=3, s=2, p=1))
self.p4.append(CSP(width[4], width[4], depth[2], csp[1], r=2))
# p5/32
self.p5.append(Conv(width[4], width[5], torch.nn.SiLU(), k=3, s=2, p=1))
self.p5.append(CSP(width[5], width[5], depth[3], csp[1], r=2))
self.p5.append(SPP(width[5], width[5]))
self.p5.append(PSA(width[5], depth[4]))
self.p1 = torch.nn.Sequential(*self.p1)
self.p2 = torch.nn.Sequential(*self.p2)
self.p3 = torch.nn.Sequential(*self.p3)
self.p4 = torch.nn.Sequential(*self.p4)
self.p5 = torch.nn.Sequential(*self.p5)
def forward(self, x):
p1 = self.p1(x)
p2 = self.p2(p1)
p3 = self.p3(p2)
p4 = self.p4(p3)
p5 = self.p5(p4)
return p3, p4, p5
class DarkFPN(torch.nn.Module):
def __init__(self, width, depth, csp):
super().__init__()
self.up = torch.nn.Upsample(scale_factor=2)
self.h1 = CSP(width[4] + width[5], width[4], depth[5], csp[0], r=2)
self.h2 = CSP(width[4] + width[4], width[3], depth[5], csp[0], r=2)
self.h3 = Conv(width[3], width[3], torch.nn.SiLU(), k=3, s=2, p=1)
self.h4 = CSP(width[3] + width[4], width[4], depth[5], csp[0], r=2)
self.h5 = Conv(width[4], width[4], torch.nn.SiLU(), k=3, s=2, p=1)
self.h6 = CSP(width[4] + width[5], width[5], depth[5], csp[1], r=2)
def forward(self, x):
p3, p4, p5 = x
p4 = self.h1(torch.cat(tensors=[self.up(p5), p4], dim=1))
p3 = self.h2(torch.cat(tensors=[self.up(p4), p3], dim=1))
p4 = self.h4(torch.cat(tensors=[self.h3(p3), p4], dim=1))
p5 = self.h6(torch.cat(tensors=[self.h5(p4), p5], dim=1))
return p3, p4, p5
class DFL(torch.nn.Module):
    # Distribution Focal Loss (DFL)
    # proposed in Generalized Focal Loss: https://ieeexplore.ieee.org/document/9792391
def __init__(self, ch=16):
super().__init__()
self.ch = ch
self.conv = torch.nn.Conv2d(ch, out_channels=1, kernel_size=1, bias=False).requires_grad_(False)
x = torch.arange(ch, dtype=torch.float).view(1, ch, 1, 1)
self.conv.weight.data[:] = torch.nn.Parameter(x)
def forward(self, x):
b, c, a = x.shape
x = x.view(b, 4, self.ch, a).transpose(2, 1)
return self.conv(x.softmax(1)).view(b, 4, a)
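# Note on DFL.forward: the softmax over the 16 bins followed by the frozen 1x1
# conv with weights [0, 1, ..., 15] computes the distribution's expectation,
# i.e. a differentiable soft-argmax for each of the four box sides.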
class Head(torch.nn.Module):
anchors = torch.empty(0)
strides = torch.empty(0)
def __init__(self, nc=80, filters=()):
super().__init__()
self.ch = 16 # DFL channels
self.nc = nc # number of classes
self.nl = len(filters) # number of detection layers
self.no = nc + self.ch * 4 # number of outputs per anchor
self.stride = torch.zeros(self.nl) # strides computed during build
box = max(64, filters[0] // 4)
cls = max(80, filters[0], self.nc)
self.dfl = DFL(self.ch)
self.box = torch.nn.ModuleList(
torch.nn.Sequential(
Conv(x, box, torch.nn.SiLU(), k=3, p=1),
Conv(box, box, torch.nn.SiLU(), k=3, p=1),
torch.nn.Conv2d(box, out_channels=4 * self.ch, kernel_size=1),
)
for x in filters
)
self.cls = torch.nn.ModuleList(
torch.nn.Sequential(
Conv(x, x, torch.nn.SiLU(), k=3, p=1, g=x),
Conv(x, cls, torch.nn.SiLU()),
Conv(cls, cls, torch.nn.SiLU(), k=3, p=1, g=cls),
Conv(cls, cls, torch.nn.SiLU()),
torch.nn.Conv2d(cls, out_channels=self.nc, kernel_size=1),
)
for x in filters
)
def forward(self, x):
for i, (box, cls) in enumerate(zip(self.box, self.cls)):
x[i] = torch.cat(tensors=(box(x[i]), cls(x[i])), dim=1)
if self.training:
return x
self.anchors, self.strides = (i.transpose(0, 1) for i in make_anchors(x, self.stride))
x = torch.cat([i.view(x[0].shape[0], self.no, -1) for i in x], dim=2)
box, cls = x.split(split_size=(4 * self.ch, self.nc), dim=1)
a, b = self.dfl(box).chunk(2, 1)
a = self.anchors.unsqueeze(0) - a
b = self.anchors.unsqueeze(0) + b
box = torch.cat(tensors=((a + b) / 2, b - a), dim=1)
return torch.cat(tensors=(box * self.strides, cls.sigmoid()), dim=1)
def initialize_biases(self):
# Initialize biases
# WARNING: requires stride availability
for box, cls, s in zip(self.box, self.cls, self.stride):
# box
box[-1].bias.data[:] = 1.0
# cls (.01 objects, 80 classes, 640 image)
cls[-1].bias.data[: self.nc] = math.log(5 / self.nc / (640 / s) ** 2)
class YOLO(torch.nn.Module):
def __init__(self, width, depth, csp, num_classes):
super().__init__()
self.net = DarkNet(width, depth, csp)
self.fpn = DarkFPN(width, depth, csp)
img_dummy = torch.zeros(1, width[0], 256, 256)
self.head = Head(num_classes, (width[3], width[4], width[5]))
self.head.stride = torch.tensor([256 / x.shape[-2] for x in self.forward(img_dummy)])
self.stride = self.head.stride
self.head.initialize_biases()
def forward(self, x):
x = self.net(x)
x = self.fpn(x)
return self.head(list(x))
def fuse(self):
for m in self.modules():
if type(m) is Conv and hasattr(m, "norm"):
m.conv = fuse_conv(m.conv, m.norm)
m.forward = m.fuse_forward
delattr(m, "norm")
return self
def yolo_v11_n(num_classes: int = 80):
csp = [False, True]
depth = [1, 1, 1, 1, 1, 1]
width = [3, 16, 32, 64, 128, 256]
return YOLO(width, depth, csp, num_classes)
def yolo_v11_t(num_classes: int = 80):
csp = [False, True]
depth = [1, 1, 1, 1, 1, 1]
width = [3, 24, 48, 96, 192, 384]
return YOLO(width, depth, csp, num_classes)
def yolo_v11_s(num_classes: int = 80):
csp = [False, True]
depth = [1, 1, 1, 1, 1, 1]
width = [3, 32, 64, 128, 256, 512]
return YOLO(width, depth, csp, num_classes)
def yolo_v11_m(num_classes: int = 80):
csp = [True, True]
depth = [1, 1, 1, 1, 1, 1]
width = [3, 64, 128, 256, 512, 512]
return YOLO(width, depth, csp, num_classes)
def yolo_v11_l(num_classes: int = 80):
csp = [True, True]
depth = [2, 2, 2, 2, 2, 2]
width = [3, 64, 128, 256, 512, 512]
return YOLO(width, depth, csp, num_classes)
def yolo_v11_x(num_classes: int = 80):
csp = [True, True]
depth = [2, 2, 2, 2, 2, 2]
width = [3, 96, 192, 384, 768, 768]
return YOLO(width, depth, csp, num_classes)
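
A minimal smoke test for any of the variants (output shapes below assume a 640x640 input):

import torch
from nets.nn import yolo_v11_n

model = yolo_v11_n(num_classes=80).train()
outputs = model(torch.zeros(1, 3, 640, 640))
# training mode returns the three raw head maps at strides 8/16/32,
# each with nc + 4 * 16 = 144 channels
print([o.shape for o in outputs])
# -> [torch.Size([1, 144, 80, 80]), torch.Size([1, 144, 40, 40]), torch.Size([1, 144, 20, 20])]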

utils/args.py Normal file

@@ -0,0 +1,18 @@
import argparse
import os
def args_parser():
parser = argparse.ArgumentParser()
parser.add_argument("--epochs", type=int, default=10, help="number of rounds of local training")
parser.add_argument("--input_size", type=int, default=640, help="image input size")
parser.add_argument("--config", type=str, default="./config/uav_cfg.yaml", help="Path to YAML config")
args = parser.parse_args()
args.local_rank = int(os.getenv("LOCAL_RANK", 0))
args.world_size = int(os.getenv("WORLD_SIZE", 1))
args.distributed = int(os.getenv("WORLD_SIZE", 1)) > 1
return args
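
Note that LOCAL_RANK and WORLD_SIZE are read from the environment, so distributed mode is switched on by the launcher rather than by a flag, e.g. with torchrun (example command, two GPUs assumed):

torchrun --nproc_per_node=2 fed_run.py --config config/uav_cfg.yaml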

utils/dataset.py Normal file

@@ -0,0 +1,478 @@
import math
import os
import random
import cv2
import numpy
import torch
from PIL import Image
from torch.utils import data
FORMATS = "bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "JPEG", "JPG", "PNG", "TIFF"
class Dataset(data.Dataset):
params: dict
mosaic: bool
augment: bool
input_size: int
def __init__(self, filenames, input_size: int, params: dict, augment: bool):
self.params = params
self.mosaic = augment
self.augment = augment
self.input_size = input_size
# Read labels
labels = self.load_label(filenames)
self.labels = list(labels.values())
self.filenames = list(labels.keys()) # update
self.n = len(self.filenames) # number of samples
self.indices = range(self.n)
# Albumentations (optional, only used if package is installed)
self.albumentations = Albumentations()
def __getitem__(self, index):
index = self.indices[index]
if self.mosaic and random.random() < self.params["mosaic"]:
# Load MOSAIC
image, label = self.load_mosaic(index, self.params)
# MixUp augmentation
if random.random() < self.params["mix_up"]:
index = random.choice(self.indices)
mix_image1, mix_label1 = image, label
mix_image2, mix_label2 = self.load_mosaic(index, self.params)
image, label = mix_up(mix_image1, mix_label1, mix_image2, mix_label2)
else:
# Load image
image, shape = self.load_image(index)
if image is None:
raise ValueError(f"Failed to load image at index {index}: {self.filenames[index]}")
h, w = image.shape[:2]
# Resize
image, ratio, pad = resize(image, self.input_size, self.augment)
label = self.labels[index].copy()
if label.size:
label[:, 1:] = wh2xy(label[:, 1:], ratio[0] * w, ratio[1] * h, int(pad[0]), int(pad[1]))
if self.augment:
image, label = random_perspective(image, label, self.params)
nl = len(label) # number of labels
h, w = image.shape[:2]
cls = label[:, 0:1]
box = label[:, 1:5]
box = xy2wh(box, w, h)
if self.augment:
# Albumentations
image, box, cls = self.albumentations(image, box, cls)
nl = len(box) # update after albumentations
# HSV color-space
augment_hsv(image, self.params)
# Flip up-down
if random.random() < self.params["flip_ud"]:
image = numpy.flipud(image)
if nl:
box[:, 1] = 1 - box[:, 1]
# Flip left-right
if random.random() < self.params["flip_lr"]:
image = numpy.fliplr(image)
if nl:
box[:, 0] = 1 - box[:, 0]
# target_cls = torch.zeros((nl, 1))
# target_box = torch.zeros((nl, 4))
# if nl:
# target_cls = torch.from_numpy(cls)
# target_box = torch.from_numpy(box)
# fix [cls, box] empty bug. e.g. [0,1] is illegal in DataLoader collate_fn cat operation
if nl:
target_cls = torch.from_numpy(cls).view(-1, 1).float() # always (N,1)
target_box = torch.from_numpy(box).reshape(-1, 4).float() # always (N,4)
else:
target_cls = torch.zeros((0, 1), dtype=torch.float32)
target_box = torch.zeros((0, 4), dtype=torch.float32)
# Convert HWC to CHW, BGR to RGB
sample = image.transpose((2, 0, 1))[::-1]
sample = numpy.ascontiguousarray(sample)
# init: return torch.from_numpy(sample), target_cls, target_box, torch.zeros(nl)
return torch.from_numpy(sample), target_cls, target_box, torch.zeros((nl, 1), dtype=torch.long)
def __len__(self):
return len(self.filenames)
def load_image(self, i):
image = cv2.imread(self.filenames[i])
if image is None:
raise ValueError(f"Image not found or unable to open: {self.filenames[i]}")
h, w = image.shape[:2]
r = self.input_size / max(h, w)
if r != 1:
image = cv2.resize(
image, dsize=(int(w * r), int(h * r)), interpolation=resample() if self.augment else cv2.INTER_LINEAR
)
return image, (h, w)
def load_mosaic(self, index, params):
label4 = []
border = [-self.input_size // 2, -self.input_size // 2]
image4 = numpy.full((self.input_size * 2, self.input_size * 2, 3), 0, dtype=numpy.uint8)
y1a, y2a, x1a, x2a, y1b, y2b, x1b, x2b = (None, None, None, None, None, None, None, None)
xc = int(random.uniform(-border[0], 2 * self.input_size + border[1]))
yc = int(random.uniform(-border[0], 2 * self.input_size + border[1]))
indices = [index] + random.choices(self.indices, k=3)
random.shuffle(indices)
for i, index in enumerate(indices):
# Load image
image, _ = self.load_image(index)
shape = image.shape
if i == 0: # top left
x1a = max(xc - shape[1], 0)
y1a = max(yc - shape[0], 0)
x2a = xc
y2a = yc
x1b = shape[1] - (x2a - x1a)
y1b = shape[0] - (y2a - y1a)
x2b = shape[1]
y2b = shape[0]
if i == 1: # top right
x1a = xc
y1a = max(yc - shape[0], 0)
x2a = min(xc + shape[1], self.input_size * 2)
y2a = yc
x1b = 0
y1b = shape[0] - (y2a - y1a)
x2b = min(shape[1], x2a - x1a)
y2b = shape[0]
if i == 2: # bottom left
x1a = max(xc - shape[1], 0)
y1a = yc
x2a = xc
y2a = min(self.input_size * 2, yc + shape[0])
x1b = shape[1] - (x2a - x1a)
y1b = 0
x2b = shape[1]
y2b = min(y2a - y1a, shape[0])
if i == 3: # bottom right
x1a = xc
y1a = yc
x2a = min(xc + shape[1], self.input_size * 2)
y2a = min(self.input_size * 2, yc + shape[0])
x1b = 0
y1b = 0
x2b = min(shape[1], x2a - x1a)
y2b = min(y2a - y1a, shape[0])
pad_w = (x1a if x1a is not None else 0) - (x1b if x1b is not None else 0)
pad_h = (y1a if y1a is not None else 0) - (y1b if y1b is not None else 0)
image4[y1a:y2a, x1a:x2a] = image[y1b:y2b, x1b:x2b]
# Labels
label = self.labels[index].copy()
if len(label):
label[:, 1:] = wh2xy(label[:, 1:], shape[1], shape[0], pad_w, pad_h)
label4.append(label)
# Concat/clip labels
label4 = numpy.concatenate(label4, 0)
for x in label4[:, 1:]:
numpy.clip(x, 0, 2 * self.input_size, out=x)
# Augment
image4, label4 = random_perspective(image4, label4, params, border)
return image4, label4
    @staticmethod
    def collate_fn(batch):
        samples, cls, box, indices = zip(*batch)
        # ensure empty tensors keep the expected shapes
        cls = [c.view(-1, 1) for c in cls]
        box = [b.reshape(-1, 4) for b in box]
        # offset each sample's target indices by its position in the batch *before*
        # concatenating, so targets["idx"] maps every box to its image in the batch
        indices = [idx.view(-1) + i for i, idx in enumerate(indices)]
        cls = torch.cat(cls, dim=0) if cls else torch.zeros((0, 1))
        box = torch.cat(box, dim=0) if box else torch.zeros((0, 4))
        indices = torch.cat(indices, dim=0) if indices else torch.zeros((0,), dtype=torch.long)
        targets = {"cls": cls, "box": box, "idx": indices}
        return torch.stack(samples, dim=0), targets
@staticmethod
def load_label_use_cache(filenames):
path = f"{os.path.dirname(filenames[0])}.cache"
if os.path.exists(path):
return torch.load(path, weights_only=False)
x = {}
for filename in filenames:
try:
# verify images
with open(filename, "rb") as f:
image = Image.open(f)
image.verify() # PIL verify
shape = image.size # image size
assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
assert image.format is not None and image.format.lower() in FORMATS, (
f"invalid image format {image.format}"
)
# verify labels
a = f"{os.sep}images{os.sep}"
b = f"{os.sep}labels{os.sep}"
if os.path.isfile(b.join(filename.rsplit(a, 1)).rsplit(".", 1)[0] + ".txt"):
with open(b.join(filename.rsplit(a, 1)).rsplit(".", 1)[0] + ".txt") as f:
label = [x.split() for x in f.read().strip().splitlines() if len(x)]
label = numpy.array(label, dtype=numpy.float32)
nl = len(label)
if nl:
assert (label >= 0).all()
assert label.shape[1] == 5
assert (label[:, 1:] <= 1).all()
_, i = numpy.unique(label, axis=0, return_index=True)
if len(i) < nl: # duplicate row check
label = label[i] # remove duplicates
else:
label = numpy.zeros((0, 5), dtype=numpy.float32)
else:
label = numpy.zeros((0, 5), dtype=numpy.float32)
except FileNotFoundError:
label = numpy.zeros((0, 5), dtype=numpy.float32)
except AssertionError:
continue
x[filename] = label
torch.save(x, path)
return x
@staticmethod
def load_label(filenames):
x = {}
for filename in filenames:
try:
# verify images
with open(filename, "rb") as f:
image = Image.open(f)
image.verify()
shape = image.size
assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
assert image.format is not None and image.format.lower() in FORMATS, (
f"invalid image format {image.format}"
)
# verify labels
a = f"{os.sep}images{os.sep}"
b = f"{os.sep}labels{os.sep}"
label_path = b.join(filename.rsplit(a, 1)).rsplit(".", 1)[0] + ".txt"
if os.path.isfile(label_path):
rows = []
with open(label_path) as f:
for line in f:
parts = line.strip().split()
if len(parts) == 5: # YOLO format
rows.append([float(x) for x in parts])
label = numpy.array(rows, dtype=numpy.float32) if rows else numpy.zeros((0, 5), dtype=numpy.float32)
if label.shape[0]:
assert (label >= 0).all()
assert label.shape[1] == 5
assert (label[:, 1:] <= 1.0001).all()
_, i = numpy.unique(label, axis=0, return_index=True)
label = label[i]
else:
label = numpy.zeros((0, 5), dtype=numpy.float32)
except (FileNotFoundError, AssertionError):
label = numpy.zeros((0, 5), dtype=numpy.float32)
x[filename] = label
return x
def wh2xy(x, w=640, h=640, pad_w=0, pad_h=0):
# Convert nx4 boxes
# from [x, y, w, h] normalized to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right
y = numpy.copy(x)
y[:, 0] = w * (x[:, 0] - x[:, 2] / 2) + pad_w # top left x
y[:, 1] = h * (x[:, 1] - x[:, 3] / 2) + pad_h # top left y
y[:, 2] = w * (x[:, 0] + x[:, 2] / 2) + pad_w # bottom right x
y[:, 3] = h * (x[:, 1] + x[:, 3] / 2) + pad_h # bottom right y
return y
def xy2wh(x, w, h):
# warning: inplace clip
x[:, [0, 2]] = x[:, [0, 2]].clip(0, w - 1e-3) # x1, x2
x[:, [1, 3]] = x[:, [1, 3]].clip(0, h - 1e-3) # y1, y2
# Convert nx4 boxes
# from [x1, y1, x2, y2] to [x, y, w, h] normalized where xy1=top-left, xy2=bottom-right
y = numpy.copy(x)
y[:, 0] = ((x[:, 0] + x[:, 2]) / 2) / w # x center
y[:, 1] = ((x[:, 1] + x[:, 3]) / 2) / h # y center
y[:, 2] = (x[:, 2] - x[:, 0]) / w # width
y[:, 3] = (x[:, 3] - x[:, 1]) / h # height
return y
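# Round-trip example (conceptually, with numpy array inputs): a centered,
# half-size box in a 640x640 image
#   wh2xy([[0.5, 0.5, 0.5, 0.5]], w=640, h=640) -> [[160., 160., 480., 480.]]
#   xy2wh([[160., 160., 480., 480.]], w=640, h=640) -> [[0.5, 0.5, 0.5, 0.5]]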
def resample():
choices = (cv2.INTER_AREA, cv2.INTER_CUBIC, cv2.INTER_LINEAR, cv2.INTER_NEAREST, cv2.INTER_LANCZOS4)
return random.choice(seq=choices)
def augment_hsv(image, params):
# HSV color-space augmentation
h = params["hsv_h"]
s = params["hsv_s"]
v = params["hsv_v"]
r = numpy.random.uniform(-1, 1, 3) * [h, s, v] + 1
h, s, v = cv2.split(cv2.cvtColor(image, cv2.COLOR_BGR2HSV))
x = numpy.arange(0, 256, dtype=r.dtype)
lut_h = ((x * r[0]) % 180).astype("uint8")
lut_s = numpy.clip(x * r[1], 0, 255).astype("uint8")
lut_v = numpy.clip(x * r[2], 0, 255).astype("uint8")
hsv = cv2.merge((cv2.LUT(h, lut_h), cv2.LUT(s, lut_s), cv2.LUT(v, lut_v)))
cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR, dst=image) # no return needed
def resize(image, input_size, augment):
# Resize and pad image while meeting stride-multiple constraints
shape = image.shape[:2] # current shape [height, width]
# Scale ratio (new / old)
r = min(input_size / shape[0], input_size / shape[1])
if not augment: # only scale down, do not scale up (for better val mAP)
r = min(r, 1.0)
# Compute padding
pad = int(round(shape[1] * r)), int(round(shape[0] * r))
w = (input_size - pad[0]) / 2
h = (input_size - pad[1]) / 2
if shape[::-1] != pad: # resize
image = cv2.resize(image, dsize=pad, interpolation=resample() if augment else cv2.INTER_LINEAR)
top, bottom = int(round(h - 0.1)), int(round(h + 0.1))
left, right = int(round(w - 0.1)), int(round(w + 0.1))
image = cv2.copyMakeBorder(image, top, bottom, left, right, cv2.BORDER_CONSTANT) # add border
return image, (r, r), (w, h)
def candidates(box1, box2):
# box1(4,n), box2(4,n)
w1, h1 = box1[2] - box1[0], box1[3] - box1[1]
w2, h2 = box2[2] - box2[0], box2[3] - box2[1]
aspect_ratio = numpy.maximum(w2 / (h2 + 1e-16), h2 / (w2 + 1e-16)) # aspect ratio
return (w2 > 2) & (h2 > 2) & (w2 * h2 / (w1 * h1 + 1e-16) > 0.1) & (aspect_ratio < 100)
def random_perspective(image, label, params, border=(0, 0)):
h = image.shape[0] + border[0] * 2
w = image.shape[1] + border[1] * 2
# Center
center = numpy.eye(3)
center[0, 2] = -image.shape[1] / 2 # x translation (pixels)
center[1, 2] = -image.shape[0] / 2 # y translation (pixels)
# Perspective
perspective = numpy.eye(3)
# Rotation and Scale
rotate = numpy.eye(3)
a = random.uniform(-params["degrees"], params["degrees"])
s = random.uniform(1 - params["scale"], 1 + params["scale"])
rotate[:2] = cv2.getRotationMatrix2D(angle=a, center=(0, 0), scale=s)
# Shear
shear = numpy.eye(3)
shear[0, 1] = math.tan(random.uniform(-params["shear"], params["shear"]) * math.pi / 180)
shear[1, 0] = math.tan(random.uniform(-params["shear"], params["shear"]) * math.pi / 180)
# Translation
translate = numpy.eye(3)
translate[0, 2] = random.uniform(0.5 - params["translate"], 0.5 + params["translate"]) * w
translate[1, 2] = random.uniform(0.5 - params["translate"], 0.5 + params["translate"]) * h
# Combined rotation matrix, order of operations (right to left) is IMPORTANT
matrix = translate @ shear @ rotate @ perspective @ center
if (border[0] != 0) or (border[1] != 0) or (matrix != numpy.eye(3)).any(): # image changed
image = cv2.warpAffine(image, matrix[:2], dsize=(w, h), borderValue=(0, 0, 0))
# Transform label coordinates
n = len(label)
if n:
xy = numpy.ones((n * 4, 3))
xy[:, :2] = label[:, [1, 2, 3, 4, 1, 4, 3, 2]].reshape(n * 4, 2) # x1y1, x2y2, x1y2, x2y1
xy = xy @ matrix.T # transform
xy = xy[:, :2].reshape(n, 8) # perspective rescale or affine
# create new boxes
x = xy[:, [0, 2, 4, 6]]
y = xy[:, [1, 3, 5, 7]]
box = numpy.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T
# clip
box[:, [0, 2]] = box[:, [0, 2]].clip(0, w)
box[:, [1, 3]] = box[:, [1, 3]].clip(0, h)
# filter candidates
indices = candidates(box1=label[:, 1:5].T * s, box2=box.T)
label = label[indices]
label[:, 1:5] = box[indices]
return image, label
def mix_up(image1, label1, image2, label2):
# Applies MixUp augmentation https://arxiv.org/pdf/1710.09412.pdf
alpha = numpy.random.beta(a=32.0, b=32.0) # mix-up ratio, alpha=beta=32.0
image = (image1 * alpha + image2 * (1 - alpha)).astype(numpy.uint8)
label = numpy.concatenate((label1, label2), 0)
return image, label
class Albumentations:
def __init__(self):
self.transform = None
try:
import albumentations
transforms = [
albumentations.Blur(p=0.01),
albumentations.CLAHE(p=0.01),
albumentations.ToGray(p=0.01),
albumentations.MedianBlur(p=0.01),
]
self.transform = albumentations.Compose(
transforms, albumentations.BboxParams(format="yolo", label_fields=["class_labels"])
)
except ImportError: # package not installed, skip
pass
def __call__(self, image, box, cls):
if self.transform:
x = self.transform(image=image, bboxes=box, class_labels=cls)
image = x["image"]
box = numpy.array(x["bboxes"])
cls = numpy.array(x["class_labels"])
return image, box, cls

utils/fed_util.py Normal file

@@ -0,0 +1,250 @@
import os
import re
import random
from collections import defaultdict
from typing import Dict, List, Optional, Set, Any
from nets import nn
def _image_to_label_path(img_path: str) -> str:
"""
Convert an image path like ".../images/train2017/xxx.jpg"
to the corresponding label path ".../labels/train2017/xxx.txt".
Works for POSIX/Windows separators.
"""
# swap "/images/" (or "\images\") to "/labels/"
label_path = re.sub(r"([/\\])images([/\\])", r"\1labels\2", img_path)
# swap extension to .txt
root, _ = os.path.splitext(label_path)
return root + ".txt"
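# Example (path hypothetical):
#   _image_to_label_path("/data/COCO/images/train2017/0001.jpg")
#   -> "/data/COCO/labels/train2017/0001.txt"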
def _parse_yolo_label_file(label_path: str) -> Set[int]:
"""
Return a set of class_ids found in a YOLO .txt label file.
Empty file -> empty set. Missing file -> empty set.
Robust to blank lines / trailing spaces.
"""
class_ids: Set[int] = set()
if not os.path.exists(label_path):
return class_ids
try:
with open(label_path, "r", encoding="utf-8") as f:
for line in f:
line = line.strip()
if not line:
continue
# YOLO format: cls cx cy w h
parts = line.split()
if not parts:
continue
try:
cls = int(parts[0])
except ValueError:
# handle weird case like '23.0'
try:
cls = int(float(parts[0]))
except ValueError:
# skip malformed line
continue
class_ids.add(cls)
except Exception:
# If the file can't be read for some reason, treat as no labels
return set()
return class_ids
def divide_trainset(
trainset_path: str,
num_local_class: int,
num_client: int,
min_data: int,
max_data: int,
mode: str = "overlap", # "overlap" or "disjoint"
seed: Optional[int] = None,
) -> Dict[str, Any]:
"""
Build a federated split from a YOLO dataset list file.
Args:
trainset_path: path to a .txt file containing one image path per line
e.g. /COCO/images/train2017/1111.jpg
num_local_class: how many distinct classes to sample for each client
num_client: number of clients
min_data: minimum number of images per client
max_data: maximum number of images per client
mode: "overlap" -> images may be shared across clients
"disjoint" -> each image is used by at most one client
seed: optional random seed for reproducibility
Returns:
trainset_divided = {
"users": ["c_00001", ...],
"user_data": {
"c_00001": {"filename": [img_path, ...]},
...
},
"num_samples": [len(list_for_user1), len(list_for_user2), ...]
}
Example:
dataset = divide_trainset(
trainset_path="/COCO/train2017.txt",
num_local_class=3,
num_client=5,
min_data=10,
max_data=20,
mode="disjoint", # or "overlap"
seed=42
)
print(dataset["users"]) # ['c_00001', ..., 'c_00005']
print(dataset["num_samples"]) # e.g. [10, 12, 18, 9, 15]
print(dataset["user_data"]["c_00001"]["filename"][:3])
"""
if seed is not None:
random.seed(seed)
# ---- Basic validations (defensive programming) ----
if num_client <= 0:
raise ValueError("num_client must be > 0")
if num_local_class <= 0:
raise ValueError("num_local_class must be > 0")
if min_data < 0 or max_data < 0:
raise ValueError("min_data/max_data must be >= 0")
if max_data < min_data:
raise ValueError("max_data must be >= min_data")
if mode not in {"overlap", "disjoint"}:
raise ValueError('mode must be "overlap" or "disjoint"')
# ---- 1) Read image list ----
with open(trainset_path, "r", encoding="utf-8") as f:
all_images_raw = [ln.strip() for ln in f if ln.strip()]
# Normalize and deduplicate image paths (safe)
all_images: List[str] = []
seen = set()
for p in all_images_raw:
        # keep the exact string (don't join with cwd), just normalize slashes
norm = os.path.normpath(p)
if norm not in seen:
seen.add(norm)
all_images.append(norm)
# ---- 2) Build mappings from labels ----
class_to_images: Dict[int, Set[str]] = defaultdict(set)
image_to_classes: Dict[str, Set[int]] = {}
missing_label_files = 0
empty_label_files = 0
parsed_images = 0
for img in all_images:
lbl = _image_to_label_path(img)
if not os.path.exists(lbl):
# Missing labels: skip image (no class info)
missing_label_files += 1
continue
classes = _parse_yolo_label_file(lbl)
if not classes:
# No objects in this image -> skip (no class bucket)
empty_label_files += 1
continue
image_to_classes[img] = classes
for c in classes:
class_to_images[c].add(img)
parsed_images += 1
if not class_to_images:
# No usable images found
return {
"users": [f"c_{i + 1:05d}" for i in range(num_client)],
"user_data": {f"c_{i + 1:05d}": {"filename": []} for i in range(num_client)},
"num_samples": [0 for _ in range(num_client)],
}
all_classes: List[int] = sorted(class_to_images.keys())
# Available pool for disjoint mode (only images with labels)
available_images: Set[str] = set(image_to_classes.keys())
# ---- 3) Allocate to clients ----
result = {"users": [], "user_data": {}, "num_samples": []}
for cid in range(num_client):
user_id = f"c_{cid + 1:05d}"
result["users"].append(user_id)
# Pick the classes for this client (sample without replacement from global class set)
k = min(num_local_class, len(all_classes))
chosen_classes = random.sample(all_classes, k) if k > 0 else []
# Decide how many samples for this client
need = min_data if min_data == max_data else random.randint(min_data, max_data)
# Build the candidate pool for this client
if mode == "overlap":
pool_set: Set[str] = set()
for c in chosen_classes:
pool_set.update(class_to_images[c])
else: # "disjoint": restrict to currently available images
pool_set = set()
for c in chosen_classes:
# intersect with available images
pool_set.update(class_to_images[c] & available_images)
# Deduplicate and sample
pool_list = list(pool_set)
if len(pool_list) <= need:
chosen_imgs = pool_list[:] # take all (can be fewer than need)
else:
chosen_imgs = random.sample(pool_list, need)
# Record for the user
result["user_data"][user_id] = {"filename": chosen_imgs}
result["num_samples"].append(len(chosen_imgs))
# If disjoint, remove selected images from availability everywhere
if mode == "disjoint" and chosen_imgs:
for img in chosen_imgs:
if img in available_images:
available_images.remove(img)
# remove from every class bucket this image belongs to
for c in image_to_classes.get(img, []):
if img in class_to_images[c]:
class_to_images[c].remove(img)
# Optional: prune empty classes from all_classes to speed up later loops
# (keep list stable; just skip empties naturally)
# (Optional) You can print some quick diagnostics if helpful:
# print(f"[INFO] Parsed images with labels: {parsed_images}")
# print(f"[INFO] Missing label files: {missing_label_files}")
# print(f"[INFO] Empty label files: {empty_label_files}")
return result
def init_model(model_name, num_classes):
"""
Initialize the model for a specific learning task
Args:
:param model_name: Name of the model
:param num_classes: Number of classes
"""
model = None
    if model_name == "yolo_v11_n":
        model = nn.yolo_v11_n(num_classes=num_classes)
    elif model_name == "yolo_v11_t":
        model = nn.yolo_v11_t(num_classes=num_classes)
elif model_name == "yolo_v11_s":
model = nn.yolo_v11_s(num_classes=num_classes)
elif model_name == "yolo_v11_m":
model = nn.yolo_v11_m(num_classes=num_classes)
elif model_name == "yolo_v11_l":
model = nn.yolo_v11_l(num_classes=num_classes)
elif model_name == "yolo_v11_x":
model = nn.yolo_v11_x(num_classes=num_classes)
else:
raise ValueError("Model {} is not supported.".format(model_name))
return model

utils/util.py Normal file

@@ -0,0 +1,818 @@
"""
Utility functions for yolo.
"""
import copy
import random
from time import time
import math
import numpy
import torch
import torchvision
from torch.nn.functional import cross_entropy
def setup_seed():
"""
Setup random seed.
"""
random.seed(0)
numpy.random.seed(0)
torch.manual_seed(0)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
def setup_multi_processes():
"""
Setup multi-processing environment variables.
"""
import cv2
from os import environ
from platform import system
    # use the `fork` start method (not available on Windows) to speed up dataloader workers
if system() != "Windows":
torch.multiprocessing.set_start_method("fork", force=True)
# disable opencv multithreading to avoid system being overloaded
cv2.setNumThreads(0)
# setup OMP threads
if "OMP_NUM_THREADS" not in environ:
environ["OMP_NUM_THREADS"] = "1"
# setup MKL threads
if "MKL_NUM_THREADS" not in environ:
environ["MKL_NUM_THREADS"] = "1"
def export_onnx(args):
import onnx # noqa
inputs = ["images"]
outputs = ["outputs"]
dynamic = {"outputs": {0: "batch", 1: "anchors"}}
m = torch.load("./weights/best.pt", weights_only=False)["model"].float()
x = torch.zeros((1, 3, args.input_size, args.input_size))
torch.onnx.export(
m.cpu(),
(x.cpu(),),
f="./weights/best.onnx",
verbose=False,
opset_version=12,
# WARNING: DNN inference with torch>=1.12 may require do_constant_folding=False
do_constant_folding=True,
input_names=inputs,
output_names=outputs,
dynamic_axes=dynamic or None,
)
# Checks
model_onnx = onnx.load("./weights/best.onnx") # load onnx model
onnx.checker.check_model(model_onnx) # check onnx model
onnx.save(model_onnx, "./weights/best.onnx")
# Inference example
# https://github.com/ultralytics/ultralytics/blob/main/ultralytics/nn/autobackend.py
def wh2xy(x):
y = x.clone() if isinstance(x, torch.Tensor) else numpy.copy(x)
y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x
y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y
y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x
y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y
return y
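# Worked example: a (cx, cy, w, h) box of (10, 10, 4, 6) maps to the
# (x1, y1, x2, y2) corners (8, 7, 12, 13):
#   wh2xy(numpy.array([[10.0, 10.0, 4.0, 6.0]]))  # -> [[8., 7., 12., 13.]]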
def make_anchors(x, strides, offset=0.5):
assert x is not None
anchor_tensor, stride_tensor = [], []
dtype, device = x[0].dtype, x[0].device
for i, stride in enumerate(strides):
_, _, h, w = x[i].shape
sx = torch.arange(end=w, device=device, dtype=dtype) + offset # shift x
sy = torch.arange(end=h, device=device, dtype=dtype) + offset # shift y
sy, sx = torch.meshgrid(sy, sx, indexing="ij")
anchor_tensor.append(torch.stack((sx, sy), -1).view(-1, 2))
stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device))
return torch.cat(anchor_tensor), torch.cat(stride_tensor)
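# Shape sketch: for a 640x640 input and strides (8, 16, 32), the feature maps are
# 80x80, 40x40, and 20x20, so make_anchors returns 6400 + 1600 + 400 = 8400 anchor
# centers (offset by 0.5 cell) plus a matching (8400, 1) stride tensor:
#   feats = [torch.zeros(1, 64, s, s) for s in (80, 40, 20)]
#   anchors, strides = make_anchors(feats, (8, 16, 32))  # anchors.shape == (8400, 2)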
def compute_metric(output, target, iou_v):
# intersection(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2)
(a1, a2) = target[:, 1:].unsqueeze(1).chunk(2, 2)
(b1, b2) = output[:, :4].unsqueeze(0).chunk(2, 2)
intersection = (torch.min(a2, b2) - torch.max(a1, b1)).clamp(0).prod(2)
# IoU = intersection / (area1 + area2 - intersection)
iou = intersection / ((a2 - a1).prod(2) + (b2 - b1).prod(2) - intersection + 1e-7)
correct = numpy.zeros((output.shape[0], iou_v.shape[0]))
correct = correct.astype(bool)
for i in range(len(iou_v)):
        # IoU >= threshold and classes match
x = torch.where((iou >= iou_v[i]) & (target[:, 0:1] == output[:, 5]))
if x[0].shape[0]:
matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy() # [label, detect, iou]
if x[0].shape[0] > 1:
matches = matches[matches[:, 2].argsort()[::-1]]
matches = matches[numpy.unique(matches[:, 1], return_index=True)[1]]
matches = matches[numpy.unique(matches[:, 0], return_index=True)[1]]
correct[matches[:, 1].astype(int), i] = True
return torch.tensor(correct, dtype=torch.bool, device=output.device)
def non_max_suppression(outputs, confidence_threshold=0.001, iou_threshold=0.65):
max_wh = 7680
max_det = 300
max_nms = 30000
bs = outputs.shape[0] # batch size
nc = outputs.shape[1] - 4 # number of classes
xc = outputs[:, 4 : 4 + nc].amax(1) > confidence_threshold # candidates
# Settings
start = time()
limit = 0.5 + 0.05 * bs # seconds to quit after
output = [torch.zeros((0, 6), device=outputs.device)] * bs
for index, x in enumerate(outputs): # image index, image inference
x = x.transpose(0, -1)[xc[index]] # confidence
# If none remain process next image
if not x.shape[0]:
continue
# matrix nx6 (box, confidence, cls)
box, cls = x.split((4, nc), 1)
box = wh2xy(box) # (cx, cy, w, h) to (x1, y1, x2, y2)
if nc > 1:
i, j = (cls > confidence_threshold).nonzero(as_tuple=False).T
x = torch.cat((box[i], x[i, 4 + j, None], j[:, None].float()), dim=1)
else: # best class only
conf, j = cls.max(1, keepdim=True)
x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > confidence_threshold]
# Check shape
n = x.shape[0] # number of boxes
if not n: # no boxes
continue
x = x[x[:, 4].argsort(descending=True)[:max_nms]] # sort by confidence and remove excess boxes
# Batched NMS
c = x[:, 5:6] * max_wh # classes
boxes, scores = x[:, :4] + c, x[:, 4] # boxes, scores
indices = torchvision.ops.nms(boxes, scores, iou_threshold) # NMS
indices = indices[:max_det] # limit detections
output[index] = x[indices]
if (time() - start) > limit:
break # time limit exceeded
return output
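# Usage sketch: `outputs` is the raw head prediction of shape (batch, 4 + nc, anchors),
# e.g. (1, 84, 8400) for 80 classes; each returned tensor is (n, 6) rows of
# (x1, y1, x2, y2, confidence, class):
#   detections = non_max_suppression(outputs, confidence_threshold=0.25, iou_threshold=0.65)
#   boxes = detections[0]  # detections for the first image in the batch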
def smooth(y, f=0.1):
# Box filter of fraction f
nf = round(len(y) * f * 2) // 2 + 1 # number of filter elements (must be odd)
p = numpy.ones(nf // 2) # ones padding
yp = numpy.concatenate((p * y[0], y, p * y[-1]), 0) # y padded
return numpy.convolve(yp, numpy.ones(nf) / nf, mode="valid") # y-smoothed
def plot_pr_curve(px, py, ap, names, save_dir):
from matplotlib import pyplot
fig, ax = pyplot.subplots(1, 1, figsize=(9, 6), tight_layout=True)
py = numpy.stack(py, axis=1)
if 0 < len(names) < 21: # display per-class legend if < 21 classes
for i, y in enumerate(py.T):
ax.plot(px, y, linewidth=1, label=f"{names[i]} {ap[i, 0]:.3f}") # plot(recall, precision)
else:
ax.plot(px, py, linewidth=1, color="grey") # plot(recall, precision)
ax.plot(
px,
py.mean(1),
linewidth=3,
color="blue",
label="all classes %.3f mAP@0.5" % ap[:, 0].mean(),
)
ax.set_xlabel("Recall")
ax.set_ylabel("Precision")
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
ax.set_title("Precision-Recall Curve")
fig.savefig(save_dir, dpi=250)
pyplot.close(fig)
def plot_curve(px, py, names, save_dir, x_label="Confidence", y_label="Metric"):
from matplotlib import pyplot
figure, ax = pyplot.subplots(1, 1, figsize=(9, 6), tight_layout=True)
if 0 < len(names) < 21: # display per-class legend if < 21 classes
for i, y in enumerate(py):
ax.plot(px, y, linewidth=1, label=f"{names[i]}") # plot(confidence, metric)
else:
ax.plot(px, py.T, linewidth=1, color="grey") # plot(confidence, metric)
y = smooth(py.mean(0), f=0.05)
ax.plot(
px,
y,
linewidth=3,
color="blue",
label=f"all classes {y.max():.3f} at {px[y.argmax()]:.3f}",
)
ax.set_xlabel(x_label)
ax.set_ylabel(y_label)
ax.set_xlim(0, 1)
ax.set_ylim(0, 1)
ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
ax.set_title(f"{y_label}-Confidence Curve")
figure.savefig(save_dir, dpi=250)
pyplot.close(figure)
def compute_ap(tp, conf, output, target, plot=False, names=(), eps=1e-16):
"""
Compute the average precision, given the recall and precision curves.
Source: https://github.com/rafaelpadilla/Object-Detection-Metrics.
# Arguments
tp: True positives (nparray, nx1 or nx10).
conf: Object-ness value from 0-1 (nparray).
output: Predicted object classes (nparray).
target: True object classes (nparray).
# Returns
        Per-class TP and FP counts, mean precision, mean recall, mAP@0.5, and mAP@0.5:0.95
"""
# Sort by object-ness
i = numpy.argsort(-conf)
tp, conf, output = tp[i], conf[i], output[i]
# Find unique classes
unique_classes, nt = numpy.unique(target, return_counts=True)
    nc = unique_classes.shape[0]  # number of unique classes with labels
# Create Precision-Recall curve and compute AP for each class
p = numpy.zeros((nc, 1000))
r = numpy.zeros((nc, 1000))
ap = numpy.zeros((nc, tp.shape[1]))
px, py = numpy.linspace(start=0, stop=1, num=1000), [] # for plotting
for ci, c in enumerate(unique_classes):
i = output == c
nl = nt[ci] # number of labels
no = i.sum() # number of outputs
if no == 0 or nl == 0:
continue
# Accumulate FPs and TPs
fpc = (1 - tp[i]).cumsum(0)
tpc = tp[i].cumsum(0)
# Recall
recall = tpc / (nl + eps) # recall curve
        # negate x and xp so numpy.interp sees increasing values (conf is sorted descending)
r[ci] = numpy.interp(-px, -conf[i], recall[:, 0], left=0)
# Precision
precision = tpc / (tpc + fpc) # precision curve
p[ci] = numpy.interp(-px, -conf[i], precision[:, 0], left=1) # p at pr_score
# AP from recall-precision curve
for j in range(tp.shape[1]):
m_rec = numpy.concatenate(([0.0], recall[:, j], [1.0]))
m_pre = numpy.concatenate(([1.0], precision[:, j], [0.0]))
# Compute the precision envelope
m_pre = numpy.flip(numpy.maximum.accumulate(numpy.flip(m_pre)))
# Integrate area under curve
x = numpy.linspace(start=0, stop=1, num=101) # 101-point interp (COCO)
ap[ci, j] = numpy.trapz(numpy.interp(x, m_rec, m_pre), x) # integrate
if plot and j == 0:
py.append(numpy.interp(px, m_rec, m_pre)) # precision at mAP@0.5
# Compute F1 (harmonic mean of precision and recall)
f1 = 2 * p * r / (p + r + eps)
if plot:
names = dict(enumerate(names)) # to dict
names = [v for k, v in names.items() if k in unique_classes] # list: only classes that have data
plot_pr_curve(px, py, ap, names, save_dir="./weights/PR_curve.png")
plot_curve(px, f1, names, save_dir="./weights/F1_curve.png", y_label="F1")
plot_curve(px, p, names, save_dir="./weights/P_curve.png", y_label="Precision")
plot_curve(px, r, names, save_dir="./weights/R_curve.png", y_label="Recall")
i = smooth(f1.mean(0), 0.1).argmax() # max F1 index
p, r, f1 = p[:, i], r[:, i], f1[:, i]
tp = (r * nt).round() # true positives
fp = (tp / (p + eps) - tp).round() # false positives
ap50, ap = ap[:, 0], ap.mean(1) # AP@0.5, AP@0.5:0.95
m_pre, m_rec = p.mean(), r.mean()
map50, mean_ap = ap50.mean(), ap.mean()
return tp, fp, m_pre, m_rec, map50, mean_ap
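# Usage sketch: `tp` stacks per-detection hit flags, typically at the 10 IoU
# thresholds 0.50:0.95, so `ap[:, 0]` is AP@0.5 and `ap.mean(1)` is AP@0.5:0.95;
# the returned scalars are means over classes:
#   tp, fp, m_pre, m_rec, map50, mean_ap = compute_ap(tp, conf, output, target)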
def compute_iou(box1, box2, eps=1e-7):
# Returns Intersection over Union (IoU) of box1(1,4) to box2(n,4)
# Get the coordinates of bounding boxes
b1_x1, b1_y1, b1_x2, b1_y2 = box1.chunk(4, -1)
b2_x1, b2_y1, b2_x2, b2_y2 = box2.chunk(4, -1)
w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
# Intersection area
inter = (b1_x2.minimum(b2_x2) - b1_x1.maximum(b2_x1)).clamp(0) * (
b1_y2.minimum(b2_y2) - b1_y1.maximum(b2_y1)
).clamp(0)
# Union Area
union = w1 * h1 + w2 * h2 - inter + eps
# IoU
iou = inter / union
cw = b1_x2.maximum(b2_x2) - b1_x1.minimum(b2_x1) # convex (smallest enclosing box) width
ch = b1_y2.maximum(b2_y2) - b1_y1.minimum(b2_y1) # convex height
c2 = cw**2 + ch**2 + eps # convex diagonal squared
rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4 # center dist ** 2
# https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47
v = (4 / math.pi**2) * (torch.atan(w2 / h2) - torch.atan(w1 / h1)).pow(2)
with torch.no_grad():
alpha = v / (v - iou + (1 + eps))
return iou - (rho2 / c2 + v * alpha) # CIoU
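# Sanity checks: identical boxes give CIoU ~= 1, and the distance/aspect penalty
# pulls the score below plain IoU for misaligned boxes:
#   a = torch.tensor([[0.0, 0.0, 10.0, 10.0]])
#   compute_iou(a, a)                                       # ~1.0
#   compute_iou(a, torch.tensor([[5.0, 5.0, 15.0, 15.0]]))  # < 25/175 (the plain IoU)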
def strip_optimizer(filename):
x = torch.load(filename, map_location="cpu", weights_only=False)
x["model"].half() # to FP16
for p in x["model"].parameters():
p.requires_grad = False
torch.save(x, f=filename)
def clip_gradients(model, max_norm=10.0):
parameters = model.parameters()
torch.nn.utils.clip_grad_norm_(parameters, max_norm=max_norm)
def load_weight(model, ckpt):
dst = model.state_dict()
src = torch.load(ckpt, weights_only=False)["model"].float().cpu()
ckpt = {}
for k, v in src.state_dict().items():
if k in dst and v.shape == dst[k].shape:
ckpt[k] = v
model.load_state_dict(state_dict=ckpt, strict=False)
return model
def set_params(model, decay):
p1 = []
p2 = []
norm = tuple(v for k, v in torch.nn.__dict__.items() if "Norm" in k)
for m in model.modules():
for n, p in m.named_parameters(recurse=0):
if not p.requires_grad:
continue
if n == "bias": # bias (no decay)
p1.append(p)
elif n == "weight" and isinstance(m, norm): # norm-weight (no decay)
p1.append(p)
else:
p2.append(p) # weight (with decay)
return [{"params": p1, "weight_decay": 0.00}, {"params": p2, "weight_decay": decay}]
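# Usage sketch: biases and normalization weights land in the no-decay group,
# all remaining weights get `weight_decay`:
#   groups = set_params(model, decay=0.0005)
#   optimizer = torch.optim.SGD(groups, lr=0.01, momentum=0.937, nesterov=True)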
def plot_lr(args, optimizer, scheduler, num_steps):
from matplotlib import pyplot
optimizer = copy.copy(optimizer)
scheduler = copy.copy(scheduler)
y = []
for epoch in range(args.epochs):
for i in range(num_steps):
step = i + num_steps * epoch
scheduler.step(step, optimizer)
y.append(optimizer.param_groups[0]["lr"])
pyplot.plot(y, ".-", label="LR")
pyplot.xlabel("step")
pyplot.ylabel("LR")
pyplot.grid()
pyplot.xlim(0, args.epochs * num_steps)
pyplot.ylim(0)
pyplot.savefig("./weights/lr.png", dpi=200)
pyplot.close()
class CosineLR:
def __init__(self, args, params, num_steps):
max_lr = params["max_lr"]
min_lr = params["min_lr"]
warmup_steps = int(max(params["warmup_epochs"] * num_steps, 100))
        decay_steps = max(1, int(args.epochs * num_steps - warmup_steps))
warmup_lr = numpy.linspace(min_lr, max_lr, int(warmup_steps))
decay_lr = []
for step in range(1, decay_steps + 1):
alpha = math.cos(math.pi * step / decay_steps)
decay_lr.append(min_lr + 0.5 * (max_lr - min_lr) * (1 + alpha))
self.total_lr = numpy.concatenate((warmup_lr, decay_lr))
def step(self, step, optimizer):
for param_group in optimizer.param_groups:
param_group["lr"] = self.total_lr[step]
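# Schedule sketch: with max_lr=0.01, min_lr=0.0001, warmup_epochs=3 and
# num_steps=100, the LR ramps linearly from 1e-4 to 1e-2 over the first 300
# steps, then follows a half-cosine back down to 1e-4:
#   scheduler = CosineLR(args, params, num_steps=100)
#   scheduler.step(step, optimizer)  # call once per global optimizer step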
class LinearLR:
def __init__(self, args, params, num_steps):
max_lr = params["max_lr"]
min_lr = params["min_lr"]
warmup_steps = int(max(params["warmup_epochs"] * num_steps, 100))
decay_steps = max(1, int(args.epochs * num_steps - warmup_steps))
warmup_lr = numpy.linspace(min_lr, max_lr, int(warmup_steps), endpoint=False)
decay_lr = numpy.linspace(max_lr, min_lr, decay_steps)
self.total_lr = numpy.concatenate((warmup_lr, decay_lr))
def step(self, step, optimizer):
for param_group in optimizer.param_groups:
param_group["lr"] = self.total_lr[step]
class EMA:
"""
Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models
Keeps a moving average of everything in the model state_dict (parameters and buffers)
For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage
"""
def __init__(self, model, decay=0.9999, tau=2000, updates=0):
# Create EMA
self.ema = copy.deepcopy(model).eval() # FP32 EMA
self.updates = updates # number of EMA updates
# decay exponential ramp (to help early epochs)
self.decay = lambda x: decay * (1 - math.exp(-x / tau))
for p in self.ema.parameters():
p.requires_grad_(False)
def update(self, model):
if hasattr(model, "module"):
model = model.module
# Update EMA parameters
with torch.no_grad():
self.updates += 1
d = self.decay(self.updates)
msd = model.state_dict() # model state_dict
for k, v in self.ema.state_dict().items():
if v.dtype.is_floating_point:
v *= d
v += (1 - d) * msd[k].detach()
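# Usage sketch: keep a shadow copy of the model and refresh it after every
# optimizer step; validate and save with `ema.ema` instead of the live model:
#   ema = EMA(model)
#   for batch in loader:
#       ...  # forward / backward / optimizer.step()
#       ema.update(model)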
class AverageMeter:
def __init__(self):
self.num = 0
self.sum = 0
self.avg = 0
def update(self, v, n):
if not math.isnan(float(v)):
self.num = self.num + n
self.sum = self.sum + v * n
self.avg = self.sum / self.num
class Assigner(torch.nn.Module):
def __init__(self, nc=80, top_k=13, alpha=1.0, beta=6.0, eps=1e-9):
super().__init__()
self.top_k = top_k
self.nc = nc
self.alpha = alpha
self.beta = beta
self.eps = eps
@torch.no_grad()
def forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt):
batch_size = pd_scores.size(0)
num_max_boxes = gt_bboxes.size(1)
if num_max_boxes == 0:
device = gt_bboxes.device
return (
torch.zeros_like(pd_bboxes).to(device),
torch.zeros_like(pd_scores).to(device),
torch.zeros_like(pd_scores[..., 0]).to(device),
)
num_anchors = anc_points.shape[0]
shape = gt_bboxes.shape
lt, rb = gt_bboxes.view(-1, 1, 4).chunk(2, 2)
mask_in_gts = torch.cat((anc_points[None] - lt, rb - anc_points[None]), dim=2)
mask_in_gts = mask_in_gts.view(shape[0], shape[1], num_anchors, -1).amin(3).gt_(self.eps)
na = pd_bboxes.shape[-2]
gt_mask = (mask_in_gts * mask_gt).bool() # b, max_num_obj, h*w
overlaps = torch.zeros(
[batch_size, num_max_boxes, na],
dtype=pd_bboxes.dtype,
device=pd_bboxes.device,
)
bbox_scores = torch.zeros(
[batch_size, num_max_boxes, na],
dtype=pd_scores.dtype,
device=pd_scores.device,
)
ind = torch.zeros([2, batch_size, num_max_boxes], dtype=torch.long) # 2, b, max_num_obj
ind[0] = torch.arange(end=batch_size).view(-1, 1).expand(-1, num_max_boxes) # b, max_num_obj
ind[1] = gt_labels.squeeze(-1) # b, max_num_obj
bbox_scores[gt_mask] = pd_scores[ind[0], :, ind[1]][gt_mask] # b, max_num_obj, h*w
pd_boxes = pd_bboxes.unsqueeze(1).expand(-1, num_max_boxes, -1, -1)[gt_mask]
gt_boxes = gt_bboxes.unsqueeze(2).expand(-1, -1, na, -1)[gt_mask]
overlaps[gt_mask] = compute_iou(gt_boxes, pd_boxes).squeeze(-1).clamp_(0)
align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta)
        top_k_mask = mask_gt.expand(-1, -1, self.top_k).bool()
        # top_k_mask is always built above, so the usual `if top_k_mask is None` fallback is unnecessary
        _, top_k_indices = torch.topk(align_metric, self.top_k, dim=-1, largest=True)
        top_k_indices.masked_fill_(~top_k_mask, 0)
mask_top_k = torch.zeros(align_metric.shape, dtype=torch.int8, device=top_k_indices.device)
ones = torch.ones_like(top_k_indices[:, :, :1], dtype=torch.int8, device=top_k_indices.device)
for k in range(self.top_k):
mask_top_k.scatter_add_(-1, top_k_indices[:, :, k : k + 1], ones)
mask_top_k.masked_fill_(mask_top_k > 1, 0)
mask_top_k = mask_top_k.to(align_metric.dtype)
mask_pos = mask_top_k * mask_in_gts * mask_gt
fg_mask = mask_pos.sum(-2)
if fg_mask.max() > 1:
mask_multi_gts = (fg_mask.unsqueeze(1) > 1).expand(-1, num_max_boxes, -1)
max_overlaps_idx = overlaps.argmax(1)
is_max_overlaps = torch.zeros(mask_pos.shape, dtype=mask_pos.dtype, device=mask_pos.device)
is_max_overlaps.scatter_(1, max_overlaps_idx.unsqueeze(1), 1)
mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos).float()
fg_mask = mask_pos.sum(-2)
target_gt_idx = mask_pos.argmax(-2)
# Assigned target
index = torch.arange(end=batch_size, dtype=torch.int64, device=gt_labels.device)[..., None]
target_index = target_gt_idx + index * num_max_boxes
target_labels = gt_labels.long().flatten()[target_index]
target_bboxes = gt_bboxes.view(-1, gt_bboxes.shape[-1])[target_index]
# Assigned target scores
target_labels.clamp_(0)
target_scores = torch.zeros(
(target_labels.shape[0], target_labels.shape[1], self.nc),
dtype=torch.int64,
device=target_labels.device,
)
target_scores.scatter_(2, target_labels.unsqueeze(-1), 1)
fg_scores_mask = fg_mask[:, :, None].repeat(1, 1, self.nc)
target_scores = torch.where(fg_scores_mask > 0, target_scores, 0)
# Normalize
align_metric *= mask_pos
pos_align_metrics = align_metric.amax(dim=-1, keepdim=True)
pos_overlaps = (overlaps * mask_pos).amax(dim=-1, keepdim=True)
norm_align_metric = (align_metric * pos_overlaps / (pos_align_metrics + self.eps)).amax(-2).unsqueeze(-1)
target_scores = target_scores * norm_align_metric
return target_bboxes, target_scores, fg_mask.bool()
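# Shape sketch (batch b, up to m GT boxes, n anchors, nc classes): the assigner
# consumes pd_scores (b, n, nc), pd_bboxes (b, n, 4), anc_points (n, 2),
# gt_labels (b, m, 1), gt_bboxes (b, m, 4), mask_gt (b, m, 1) and returns
# target_bboxes (b, n, 4), target_scores (b, n, nc), fg_mask (b, n).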
class QFL(torch.nn.Module):
def __init__(self, beta=2.0):
super().__init__()
self.beta = beta
self.bce_loss = torch.nn.BCEWithLogitsLoss(reduction="none")
def forward(self, outputs, targets):
bce_loss = self.bce_loss(outputs, targets)
return torch.pow(torch.abs(targets - outputs.sigmoid()), self.beta) * bce_loss
class VFL(torch.nn.Module):
def __init__(self, alpha=0.75, gamma=2.00, iou_weighted=True):
super().__init__()
assert alpha >= 0.0
self.alpha = alpha
self.gamma = gamma
self.iou_weighted = iou_weighted
self.bce_loss = torch.nn.BCEWithLogitsLoss(reduction="none")
def forward(self, outputs, targets):
assert outputs.size() == targets.size()
targets = targets.type_as(outputs)
if self.iou_weighted:
focal_weight = (
targets * (targets > 0.0).float()
+ self.alpha * (outputs.sigmoid() - targets).abs().pow(self.gamma) * (targets <= 0.0).float()
)
else:
focal_weight = (targets > 0.0).float() + self.alpha * (outputs.sigmoid() - targets).abs().pow(
self.gamma
) * (targets <= 0.0).float()
return self.bce_loss(outputs, targets) * focal_weight
class FocalLoss(torch.nn.Module):
def __init__(self, alpha=0.25, gamma=1.5):
super().__init__()
self.alpha = alpha
self.gamma = gamma
self.bce_loss = torch.nn.BCEWithLogitsLoss(reduction="none")
def forward(self, outputs, targets):
loss = self.bce_loss(outputs, targets)
if self.alpha > 0:
alpha_factor = targets * self.alpha + (1 - targets) * (1 - self.alpha)
loss *= alpha_factor
if self.gamma > 0:
outputs_sigmoid = outputs.sigmoid()
p_t = targets * outputs_sigmoid + (1 - targets) * (1 - outputs_sigmoid)
gamma_factor = (1.0 - p_t) ** self.gamma
loss *= gamma_factor
return loss
class BoxLoss(torch.nn.Module):
def __init__(self, dfl_ch):
super().__init__()
self.dfl_ch = dfl_ch
def forward(
self,
pred_dist,
pred_bboxes,
anchor_points,
target_bboxes,
target_scores,
target_scores_sum,
fg_mask,
):
# IoU loss
weight = torch.masked_select(target_scores.sum(-1), fg_mask).unsqueeze(-1)
iou = compute_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask])
loss_box = ((1.0 - iou) * weight).sum() / target_scores_sum
# DFL loss
a, b = target_bboxes.chunk(2, -1)
target = torch.cat((anchor_points - a, b - anchor_points), -1)
target = target.clamp(0, self.dfl_ch - 0.01)
loss_dfl = self.df_loss(pred_dist[fg_mask].view(-1, self.dfl_ch + 1), target[fg_mask])
loss_dfl = (loss_dfl * weight).sum() / target_scores_sum
return loss_box, loss_dfl
@staticmethod
def df_loss(pred_dist, target):
# Distribution Focal Loss (DFL)
# https://ieeexplore.ieee.org/document/9792391
tl = target.long() # target left
tr = tl + 1 # target right
wl = tr - target # weight left
wr = 1 - wl # weight right
left_loss = cross_entropy(pred_dist, tl.view(-1), reduction="none").view(tl.shape)
right_loss = cross_entropy(pred_dist, tr.view(-1), reduction="none").view(tl.shape)
return (left_loss * wl + right_loss * wr).mean(-1, keepdim=True)
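# DFL worked example: a continuous target of 2.7 is split between its two
# neighbouring integer bins, tl=2 with weight wl=0.3 and tr=3 with weight
# wr=0.7, so the loss is the matching weighted sum of the two cross-entropies.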
class ComputeLoss:
def __init__(self, model, params):
if hasattr(model, "module"):
model = model.module
device = next(model.parameters()).device
m = model.head # Head() module
self.params = params
self.stride = m.stride
self.nc = m.nc
self.no = m.no
self.reg_max = m.ch
self.device = device
self.box_loss = BoxLoss(m.ch - 1).to(device)
self.cls_loss = torch.nn.BCEWithLogitsLoss(reduction="none")
self.assigner = Assigner(nc=self.nc, top_k=10, alpha=0.5, beta=6.0)
self.project = torch.arange(m.ch, dtype=torch.float, device=device)
def box_decode(self, anchor_points, pred_dist):
b, a, c = pred_dist.shape
pred_dist = pred_dist.view(b, a, 4, c // 4)
pred_dist = pred_dist.softmax(3)
pred_dist = pred_dist.matmul(self.project.type(pred_dist.dtype))
lt, rb = pred_dist.chunk(2, -1)
x1y1 = anchor_points - lt
x2y2 = anchor_points + rb
return torch.cat(tensors=(x1y1, x2y2), dim=-1)
def __call__(self, outputs, targets):
x = torch.cat([i.view(outputs[0].shape[0], self.no, -1) for i in outputs], dim=2)
pred_distri, pred_scores = x.split(split_size=(self.reg_max * 4, self.nc), dim=1)
pred_scores = pred_scores.permute(0, 2, 1).contiguous()
pred_distri = pred_distri.permute(0, 2, 1).contiguous()
data_type = pred_scores.dtype
batch_size = pred_scores.shape[0]
input_size = torch.tensor(outputs[0].shape[2:], device=self.device, dtype=data_type) * self.stride[0]
anchor_points, stride_tensor = make_anchors(outputs, self.stride, offset=0.5)
idx = targets["idx"].view(-1, 1)
cls = targets["cls"].view(-1, 1)
box = targets["box"]
targets = torch.cat((idx, cls, box), dim=1).to(self.device)
if targets.shape[0] == 0:
gt = torch.zeros(batch_size, 0, 5, device=self.device)
else:
i = targets[:, 0]
_, counts = i.unique(return_counts=True)
counts = counts.to(dtype=torch.int32)
gt = torch.zeros(batch_size, counts.max(), 5, device=self.device)
for j in range(batch_size):
matches = i == j
n = matches.sum()
if n:
gt[j, :n] = targets[matches, 1:]
x = gt[..., 1:5].mul_(input_size[[1, 0, 1, 0]])
y = torch.empty_like(x)
dw = x[..., 2] / 2 # half-width
dh = x[..., 3] / 2 # half-height
y[..., 0] = x[..., 0] - dw # top left x
y[..., 1] = x[..., 1] - dh # top left y
y[..., 2] = x[..., 0] + dw # bottom right x
y[..., 3] = x[..., 1] + dh # bottom right y
gt[..., 1:5] = y
gt_labels, gt_bboxes = gt.split((1, 4), 2)
mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0)
pred_bboxes = self.box_decode(anchor_points, pred_distri)
assigned_targets = self.assigner(
pred_scores.detach().sigmoid(),
(pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype),
anchor_points * stride_tensor,
gt_labels,
gt_bboxes,
mask_gt,
)
target_bboxes, target_scores, fg_mask = assigned_targets
target_scores_sum = max(target_scores.sum(), 1)
loss_cls = self.cls_loss(pred_scores, target_scores.to(data_type)).sum() / target_scores_sum # BCE
# Box loss
loss_box = torch.zeros(1, device=self.device)
loss_dfl = torch.zeros(1, device=self.device)
if fg_mask.sum():
target_bboxes /= stride_tensor
loss_box, loss_dfl = self.box_loss(
pred_distri,
pred_bboxes,
anchor_points,
target_bboxes,
target_scores,
target_scores_sum,
fg_mask,
)
loss_box *= self.params["box"] # box gain
loss_cls *= self.params["cls"] # cls gain
loss_dfl *= self.params["dfl"] # dfl gain
return loss_box, loss_cls, loss_dfl
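# Usage sketch: `targets` is a dict with per-box image indices, classes, and
# normalized (cx, cy, w, h) boxes, matching the "idx"/"cls"/"box" keys read above:
#   criterion = ComputeLoss(model, params)
#   loss_box, loss_cls, loss_dfl = criterion(outputs, targets)
#   loss = loss_box + loss_cls + loss_dfl  # each term already scaled by its gain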