From 1ae76d0aedd4db1baedfad32386dc1e6f5f010ad Mon Sep 17 00:00:00 2001 From: TY1667 Date: Thu, 2 Oct 2025 16:26:27 +0800 Subject: [PATCH] Fedavg and YOLOv11 training --- config/coco_cfg.yaml | 126 ++++++ config/uav_cfg.yaml | 47 +++ fed_algo_cs/client_base.py | 233 +++++++++++ fed_algo_cs/server_base.py | 178 ++++++++ fed_run.py | 239 +++++++++++ nets/nn.py | 362 ++++++++++++++++ utils/args.py | 18 + utils/dataset.py | 478 ++++++++++++++++++++++ utils/fed_util.py | 250 ++++++++++++ utils/util.py | 818 +++++++++++++++++++++++++++++++++++++ 10 files changed, 2749 insertions(+) create mode 100644 config/coco_cfg.yaml create mode 100644 config/uav_cfg.yaml create mode 100644 fed_algo_cs/client_base.py create mode 100644 fed_algo_cs/server_base.py create mode 100644 fed_run.py create mode 100644 nets/nn.py create mode 100644 utils/args.py create mode 100644 utils/dataset.py create mode 100644 utils/fed_util.py create mode 100644 utils/util.py diff --git a/config/coco_cfg.yaml b/config/coco_cfg.yaml new file mode 100644 index 0000000..34175e9 --- /dev/null +++ b/config/coco_cfg.yaml @@ -0,0 +1,126 @@ +# global system: +fed_algo: "FedAvg" # federated learning algorithm +model_name: "yolo_v11_n" # yolo_v11_n, yolo_v11_t, yolo_v11_s, yolo_v11_m, yolo_v11_l, yolo_v11_x +i_seed: 202509 # initial random seed + +num_client: 64 # total number of clients +num_round: 5 # total number of communication rounds +num_local_class: 80 # number of classes per client + +res_root: "results" # root directory for results +dataset_path: "/home/image1325/ssd1/dataset/COCO128/" +# train_txt: "train.txt" # path to training set txt file +# val_txt: "val.txt" # path to validation set txt file +# test_txt: "test.txt" # path to test set txt file + +local_batch_size: 32 # local training batch size +val_batch_size: 4 # validation batch size + +num_workers: 4 # number of data loader workers +min_data: 128 # minimum number of images per client +max_data: 128 # maximum number of images per client +partition_mode: "overlap" # "overlap" or "disjoint" +connection_ratio: 1 # connection ratio, e.g., 1.0 means all clients + +# local training: +min_lr: 0.000100000000 # initial learning rate +max_lr: 0.010000000000 # maximum learning rate +momentum: 0.9370000000 # SGD momentum/Adam beta1 +weight_decay: 0.000500 # optimizer weight decay + +warmup_epochs: 3.00000 # warmup epochs +box: 7.500000000000000 # box loss gain +cls: 0.500000000000000 # cls loss gain +dfl: 1.500000000000000 # dfl loss gain +hsv_h: 0.0150000000000 # image HSV-Hue augmentation (fraction) +hsv_s: 0.7000000000000 # image HSV-Saturation augmentation (fraction) +hsv_v: 0.4000000000000 # image HSV-Value augmentation (fraction) +degrees: 0.00000000000 # image rotation (+/- deg) +translate: 0.100000000 # image translation (+/- fraction) +scale: 0.5000000000000 # image scale (+/- gain) +shear: 0.0000000000000 # image shear (+/- deg) +flip_ud: 0.00000000000 # image flip up-down (probability) +flip_lr: 0.50000000000 # image flip left-right (probability) +mosaic: 1.000000000000 # image mosaic (probability) +mix_up: 0.000000000000 # image mix-up (probability) +names: + 0: person + 1: bicycle + 2: car + 3: motorcycle + 4: airplane + 5: bus + 6: train + 7: truck + 8: boat + 9: traffic light + 10: fire hydrant + 11: stop sign + 12: parking meter + 13: bench + 14: bird + 15: cat + 16: dog + 17: horse + 18: sheep + 19: cow + 20: elephant + 21: bear + 22: zebra + 23: giraffe + 24: backpack + 25: umbrella + 26: handbag + 27: tie + 28: suitcase + 29: frisbee + 30: skis + 31: snowboard + 
32: sports ball + 33: kite + 34: baseball bat + 35: baseball glove + 36: skateboard + 37: surfboard + 38: tennis racket + 39: bottle + 40: wine glass + 41: cup + 42: fork + 43: knife + 44: spoon + 45: bowl + 46: banana + 47: apple + 48: sandwich + 49: orange + 50: broccoli + 51: carrot + 52: hot dog + 53: pizza + 54: donut + 55: cake + 56: chair + 57: couch + 58: potted plant + 59: bed + 60: dining table + 61: toilet + 62: tv + 63: laptop + 64: mouse + 65: remote + 66: keyboard + 67: cell phone + 68: microwave + 69: oven + 70: toaster + 71: sink + 72: refrigerator + 73: book + 74: clock + 75: vase + 76: scissors + 77: teddy bear + 78: hair drier + 79: toothbrush diff --git a/config/uav_cfg.yaml b/config/uav_cfg.yaml new file mode 100644 index 0000000..74e8f77 --- /dev/null +++ b/config/uav_cfg.yaml @@ -0,0 +1,47 @@ +# global system: +fed_algo: "FedAvg" # federated learning algorithm +model_name: "yolo_v11_n" # yolo_v11_n, yolo_v11_t, yolo_v11_s, yolo_v11_m, yolo_v11_l, yolo_v11_x +i_seed: 202509 # initial random seed + +num_client: 100 # total number of clients +num_round: 500 # total number of communication rounds +num_local_class: 1 # number of classes per client + +res_root: "results" # root directory for results +dataset_path: "/home/image1325/ssd1/dataset/uav/" +# train_txt: "train.txt" # path to training set txt file +# val_txt: "val.txt" # path to validation set txt file +# test_txt: "test.txt" # path to test set txt file + +local_batch_size: 32 # local training batch size +val_batch_size: 16 # validation batch size + +num_workers: 4 # number of data loader workers +min_data: 640 # minimum number of images per client +max_data: 720 # maximum number of images per client +partition_mode: "overlap" # "overlap" or "disjoint" +connection_ratio: 1 # connection ratio, e.g., 1.0 means all clients + +# local training: +min_lr: 0.000100000000 # initial learning rate +max_lr: 0.010000000000 # maximum learning rate +momentum: 0.9370000000 # SGD momentum/Adam beta1 +weight_decay: 0.000500 # optimizer weight decay + +warmup_epochs: 3.00000 # warmup epochs +box: 7.500000000000000 # box loss gain +cls: 0.500000000000000 # cls loss gain +dfl: 1.500000000000000 # dfl loss gain +hsv_h: 0.0150000000000 # image HSV-Hue augmentation (fraction) +hsv_s: 0.7000000000000 # image HSV-Saturation augmentation (fraction) +hsv_v: 0.4000000000000 # image HSV-Value augmentation (fraction) +degrees: 0.00000000000 # image rotation (+/- deg) +translate: 0.100000000 # image translation (+/- fraction) +scale: 0.5000000000000 # image scale (+/- gain) +shear: 0.0000000000000 # image shear (+/- deg) +flip_ud: 0.00000000000 # image flip up-down (probability) +flip_lr: 0.50000000000 # image flip left-right (probability) +mosaic: 1.000000000000 # image mosaic (probability) +mix_up: 0.000000000000 # image mix-up (probability) +names: + 0: uav diff --git a/fed_algo_cs/client_base.py b/fed_algo_cs/client_base.py new file mode 100644 index 0000000..433d6b3 --- /dev/null +++ b/fed_algo_cs/client_base.py @@ -0,0 +1,233 @@ +import numpy as np +import torch +from torch import nn +from torch.utils import data +from torch.amp.autocast_mode import autocast +from utils.fed_util import init_model +from utils import util +from utils.dataset import Dataset +from typing import cast + + +class FedYoloClient(object): + def __init__(self, name, model_name, params): + """ + Initialize the client k for federated learning + Args: + :param name: Name of the client k + :param model_name: Name of the model + :param params: config file including the 
hyperparameters for local training + - batch_size: Local training batch size in the client k + - num_workers: Number of data loader workers + + - min_lr: Minimum learning rate + - max_lr: Maximum learning rate + - momentum: Momentum for local training + - weight_decay: Weight decay for local training + """ + self.params = params + # initialize the metadata in local client k + self.target_ip = "127.0.0.3" + self.port = 9999 + self.name = name + + # initialize the parameters in local client k + self._batch_size = self.params["local_batch_size"] + self._min_lr = self.params["min_lr"] + self._max_lr = self.params["max_lr"] + self._momentum = self.params["momentum"] + self.num_workers = self.params["num_workers"] + + self.loss_record = [] + # train set length + self.n_data = 0 + + # initialize the local training and testing dataset + self.train_dataset = None + self.val_dataset = None + + # initialize the local model + self._num_classes = len(self.params["names"]) + self._weight_decay = self.params["weight_decay"] + + self.model_name = model_name + self.model = init_model(model_name, self._num_classes) + + model_parameters = filter(lambda p: p.requires_grad, self.model.parameters()) + self.parameter_number = sum([np.prod(p.size()) for p in model_parameters]) + + # GPU + self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + def load_trainset(self, train_dataset: list[str]): + """ + Load the local training dataset + Args: + :param train_dataset: Training dataset + """ + self.train_dataset = train_dataset + self.n_data = len(self.train_dataset) + + def update(self, Global_model_state_dict): + """ + Update the local model with the global model parameters + Args: + :param Global_model_state_dict: State dictionary of the global model + """ + + if not hasattr(self, "model") or self.model is None: + self.model = init_model(self.model_name, self._num_classes) + + # load the global model parameters + self.model.load_state_dict(Global_model_state_dict, strict=True) + + def train(self, args): + """ + Train the local model + Args: + :param args: Command line arguments + - local_rank: Local rank for distributed training + - world_size: World size for distributed training + - distributed: Whether to use distributed training + - input_size: Input size for the model + Returns: + :return: Local updated model, number of local data points, training loss + """ + + if args.distributed: + torch.cuda.set_device(device=args.local_rank) + torch.distributed.init_process_group(backend="nccl", init_method="env://") + # print(f"Client {self.name} - distributed training on {world_size} GPUs, local rank: {local_rank}") + # self._device = torch.device("cuda", local_rank) + + if args.local_rank == 0: + pass + # if not os.path.exists("weights"): + # os.makedirs("weights") + + util.setup_seed() + util.setup_multi_processes() + + # model + # init model have been done in __init__() + self.model.to(self._device) + + # Optimizer + accumulate = max(round(64 / (self._batch_size * args.world_size)), 1) + self._weight_decay = self._batch_size * args.world_size * accumulate / 64 + + optimizer = torch.optim.SGD( + util.set_params(self.model, self._weight_decay), + lr=self._min_lr, + momentum=self._momentum, + nesterov=True, + ) + + # EMA + ema = util.EMA(self.model) if args.local_rank == 0 else None + + data_set = Dataset( + filenames=self.train_dataset, + input_size=args.input_size, + params=self.params, + augment=True, + ) + + if args.distributed: + train_sampler = data.DistributedSampler( + data_set, 
num_replicas=args.world_size, rank=args.local_rank, shuffle=True + ) + else: + train_sampler = None + + loader = data.DataLoader( + data_set, + batch_size=self._batch_size, + shuffle=train_sampler is None, + sampler=train_sampler, + num_workers=self.num_workers, + pin_memory=True, + collate_fn=Dataset.collate_fn, + ) + + # Scheduler + num_steps = max(1, len(loader)) + # print(len(loader)) + scheduler = util.LinearLR(args=args, params=self.params, num_steps=num_steps) + # DDP mode + if args.distributed: + self.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.model) + self.model = nn.parallel.DistributedDataParallel( + module=self.model, + device_ids=[args.local_rank], + output_device=args.local_rank, + find_unused_parameters=False, + ) + + amp_scale = torch.amp.grad_scaler.GradScaler(enabled=True) + criterion = util.ComputeLoss(self.model, self.params) + + optimizer.zero_grad(set_to_none=True) + + for epoch in range(args.epochs): + self.model.train() + # when distributed, set epoch for shuffling + if args.distributed and train_sampler is not None: + train_sampler.set_epoch(epoch) + + if args.epochs - epoch == 10: + # disable mosaic augmentation in the last 10 epochs + ds = cast(Dataset, loader.dataset) + ds.mosaic = False + + avg_box_loss = util.AverageMeter() + avg_cls_loss = util.AverageMeter() + avg_dfl_loss = util.AverageMeter() + + for i, (samples, targets) in enumerate(loader): + global_step = i + num_steps * epoch + scheduler.step(step=global_step, optimizer=optimizer) + + samples = samples.cuda(non_blocking=True).float() / 255.0 + + # Forward + with autocast("cuda", enabled=True): + outputs = self.model(samples) + box_loss, cls_loss, dfl_loss = criterion(outputs, targets) + + # meters (use the *unscaled* values) + bs = samples.size(0) + avg_box_loss.update(box_loss.item(), bs) + avg_cls_loss.update(cls_loss.item(), bs) + avg_dfl_loss.update(dfl_loss.item(), bs) + + # scale losses by batch/world if your loss is averaged internally per-sample/device + box_loss = box_loss * self._batch_size * args.world_size + cls_loss = cls_loss * self._batch_size * args.world_size + dfl_loss = dfl_loss * self._batch_size * args.world_size + + total_loss = box_loss + cls_loss + dfl_loss + + # Backward + amp_scale.scale(total_loss).backward() + + # Optimize + if (i + 1) % accumulate == 0: + amp_scale.step(optimizer) + amp_scale.update() + optimizer.zero_grad(set_to_none=True) + if ema: + ema.update(self.model) + + # torch.cuda.synchronize() + + # clean + if args.distributed: + torch.distributed.destroy_process_group() + torch.cuda.empty_cache() + + return ( + self.model.state_dict(), + self.n_data, + {"box_loss": avg_box_loss.avg, "cls_loss": avg_cls_loss.avg, "dfl_loss": avg_dfl_loss.avg}, + ) diff --git a/fed_algo_cs/server_base.py b/fed_algo_cs/server_base.py new file mode 100644 index 0000000..5dd82b2 --- /dev/null +++ b/fed_algo_cs/server_base.py @@ -0,0 +1,178 @@ +import numpy as np +import torch +from torch.utils.data import DataLoader +from utils.fed_util import init_model +from utils.dataset import Dataset +from utils import util + + +class FedYoloServer(object): + def __init__(self, client_list, model_name, params): + """ + Federated YOLO Server + Args: + client_list: list of connected clients + model_name: YOLO model architecture name + params: dict of hyperparameters (must include 'names') + """ + # Track client updates + self.client_state = {} + self.client_loss = {} + self.client_n_data = {} + self.selected_clients = [] + + self._batch_size = params.get("val_batch_size", 4) + 
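+ # client_list below is the full roster of client IDs; select_clients() re-draws the
+ # participating subset into self.selected_clients each round, while client_state,
+ # client_n_data and client_loss cache the per-client updates received via rec()
+ # until flush() clears them.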
self.client_list = client_list + self.valset = None + + # Federated bookkeeping + self.round = 0 + # Total number of classes + self.n_data = 0 + + # Device + gpu = 0 + self._device = torch.device(f"cuda:{gpu}" if torch.cuda.is_available() else "cpu") + + # Global model + self._num_classes = len(params["names"]) + self.model_name = model_name + self.model = init_model(model_name, self._num_classes) + self.params = params + + def load_valset(self, valset): + """Server loads the validation dataset.""" + self.valset = valset + + def state_dict(self): + """Return global model weights.""" + return self.model.state_dict() + + @torch.no_grad() + def test(self, args): + """ + Evaluate global model on validation set using YOLO metrics (mAP, precision, recall). + Returns: + dict with {"mAP": ..., "mAP50": ..., "precision": ..., "recall": ...} + """ + if self.valset is None: + return {} + + loader = DataLoader( + self.valset, + batch_size=self._batch_size, + shuffle=False, + num_workers=4, + pin_memory=True, + collate_fn=Dataset.collate_fn, + ) + + self.model.to(self._device).eval().half() + + iou_v = torch.linspace(0.5, 0.95, 10).to(self._device) # IoU thresholds + n_iou = iou_v.numel() + metrics = [] + + for samples, targets in loader: + samples = samples.to(self._device).half() / 255.0 + _, _, h, w = samples.shape + scale = torch.tensor((w, h, w, h)).to(self._device) + + outputs = self.model(samples) + outputs = util.non_max_suppression(outputs) + + for i, output in enumerate(outputs): + idx = targets["idx"] == i + cls = targets["cls"][idx].to(self._device) + box = targets["box"][idx].to(self._device) + + metric = torch.zeros((output.shape[0], n_iou), dtype=torch.bool, device=self._device) + + if output.shape[0] == 0: + if cls.shape[0]: + metrics.append((metric, *torch.zeros((2, 0), device=self._device), cls.squeeze(-1))) + continue + + if cls.shape[0]: + cls_tensor = cls if isinstance(cls, torch.Tensor) else torch.tensor(cls, device=self._device) + if cls_tensor.dim() == 1: + cls_tensor = cls_tensor.unsqueeze(1) + box_xy = util.wh2xy(box) + if not isinstance(box_xy, torch.Tensor): + box_xy = torch.tensor(box_xy, device=self._device) + target = torch.cat((cls_tensor, box_xy * scale), dim=1) + metric = util.compute_metric(output[:, :6], target, iou_v) + + metrics.append((metric, output[:, 4], output[:, 5], cls.squeeze(-1))) + + # Compute metrics + if not metrics: + return {"mAP": 0, "mAP50": 0, "precision": 0, "recall": 0} + + metrics = [torch.cat(x, dim=0).cpu().numpy() for x in zip(*metrics)] + if len(metrics) and metrics[0].any(): + _, _, prec, rec, map50, mean_ap = util.compute_ap(*metrics, names=self.params["names"], plot=False) + else: + prec, rec, map50, mean_ap = 0, 0, 0, 0 + + # Back to float32 for further training + self.model.float() + + return {"mAP": float(mean_ap), "mAP50": float(map50), "precision": float(prec), "recall": float(rec)} + + def select_clients(self, connection_ratio=1.0): + """Randomly select a fraction of clients.""" + self.selected_clients = [] + self.n_data = 0 + for client_id in self.client_list: + if np.random.rand() <= connection_ratio: + self.selected_clients.append(client_id) + self.n_data += self.client_n_data.get(client_id, 0) + + def agg(self): + """Aggregate client updates (FedAvg).""" + if len(self.selected_clients) == 0 or self.n_data == 0: + return self.model.state_dict(), {}, 0 + + model = init_model(self.model_name, self._num_classes) + model_state = model.state_dict() + + avg_loss = {} + for i, name in enumerate(self.selected_clients): + if name not in 
self.client_state: + continue + weight = self.client_n_data[name] / self.n_data + for key in model_state.keys(): + if i == 0: + model_state[key] = self.client_state[name][key] * weight + else: + model_state[key] += self.client_state[name][key] * weight + + # Weighted average losses + for k, v in self.client_loss[name].items(): + avg_loss[k] = avg_loss.get(k, 0.0) + v * weight + + self.model.load_state_dict(model_state, strict=True) + self.round += 1 + return model_state, avg_loss, self.n_data + + def rec(self, name, state_dict, n_data, loss_dict): + """ + Receive local update from a client. + Args: + name: client ID + state_dict: state dictionary of the local model + n_data: number of data samples used in local training + loss_dict: dict of losses from local training + """ + self.n_data += n_data + self.client_state[name] = {k: v.cpu() for k, v in state_dict.items()} + self.client_n_data[name] = n_data + self.client_loss[name] = loss_dict + + def flush(self): + """Clear stored client updates.""" + self.n_data = 0 + self.client_state.clear() + self.client_n_data.clear() + self.client_loss.clear() diff --git a/fed_run.py b/fed_run.py new file mode 100644 index 0000000..b41c4de --- /dev/null +++ b/fed_run.py @@ -0,0 +1,239 @@ +#!/usr/bin/env python3 +import os +import json +import yaml +import time +import random +from tqdm import tqdm + +import numpy as np +import torch +import matplotlib.pyplot as plt + +from utils.dataset import Dataset +from fed_algo_cs.client_base import FedYoloClient +from fed_algo_cs.server_base import FedYoloServer +from utils.args import args_parser # your args parser +from utils.fed_util import divide_trainset # divide_trainset is yours + + +def _read_list_file(txt_path: str): + """Read one path per line; keep as-is (absolute or relative).""" + if not txt_path or not os.path.exists(txt_path): + return [] + with open(txt_path, "r", encoding="utf-8") as f: + return [ln.strip() for ln in f if ln.strip()] + + +def _build_valset_if_available(cfg, params): + """ + Try to build a validation Dataset. + - If cfg['val_txt'] exists, use it. + - Else if /val.txt exists, use it. + - Else return None (testing will be skipped). + Args: + cfg: config dict + params: params dict for Dataset + Returns: + Dataset or None + """ + input_size = int(cfg.get("input_size", 640)) + val_txt = cfg.get("val_txt", "") + if not val_txt: + ds_root = cfg.get("dataset_path", "") + guess = os.path.join(ds_root, "val.txt") if ds_root else "" + val_txt = guess if os.path.exists(guess) else "" + + val_files = _read_list_file(val_txt) + if not val_files: + return None + + return Dataset( + filenames=val_files, + input_size=input_size, + params=params, + augment=True, + ) + + +def _seed_everything(seed: int): + np.random.seed(seed) + torch.manual_seed(seed) + random.seed(seed) + + +def _plot_curves(save_dir, hist): + """ + Plot mAP50-95, mAP50, precision, recall, and (optional) summed train loss per round. 
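+ Args:
+ save_dir: directory the curve PNG is written into
+ hist: dict of per-round lists keyed by "mAP", "mAP50", "precision", "recall", and "train_loss"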
+ """ + os.makedirs(save_dir, exist_ok=True) + rounds = np.arange(1, len(hist["mAP"]) + 1) + + plt.figure() + if hist["mAP"]: + plt.plot(rounds, hist["mAP"], label="mAP50-95") + if hist["mAP50"]: + plt.plot(rounds, hist["mAP50"], label="mAP50") + if hist["precision"]: + plt.plot(rounds, hist["precision"], label="precision") + if hist["recall"]: + plt.plot(rounds, hist["recall"], label="recall") + if hist["train_loss"]: + plt.plot(rounds, hist["train_loss"], label="train_loss (sum of components)") + plt.xlabel("Global Round") + plt.ylabel("Metric") + plt.title("Federated YOLO - Server Metrics") + plt.legend() + out_png = os.path.join(save_dir, "fed_yolo_curves.png") + plt.savefig(out_png, dpi=150, bbox_inches="tight") + print(f"[plot] saved: {out_png}") + + +def fed_run(): + """ + Main FL process: + - Initialize clients & server + - For each round: sequential local training -> record -> select -> aggregate + - Test & flush + - Record & save results, plot curves + """ + args_cli = args_parser() + with open(args_cli.config, "r", encoding="utf-8") as f: + cfg = yaml.safe_load(f) + + # --- params / config normalization --- + # For convenience we pass the same `params` dict used by Dataset/model/loss. + # Here we re-use the top-level cfg directly as params. + params = dict(cfg) + if "names" in cfg and isinstance(cfg["names"], dict): + # Convert {0: 'uav', 1: 'car', ...} to list if you prefer list + # but we can leave dict; your utils appear to accept dict + pass + + # seeds + _seed_everything(int(cfg.get("i_seed", 0))) + + # --- split clients' train data from a global train list --- + # Expect either cfg["train_txt"] or /train.txt + train_txt = cfg.get("train_txt", "") + if not train_txt: + ds_root = cfg.get("dataset_path", "") + guess = os.path.join(ds_root, "train.txt") if ds_root else "" + train_txt = guess + + if not train_txt or not os.path.exists(train_txt): + raise FileNotFoundError( + f"train.txt not found. Provide --config with 'train_txt' or ensure '{train_txt}' exists." 
+ ) + + split = divide_trainset( + trainset_path=train_txt, + num_local_class=int(cfg.get("num_local_class", 1)), + num_client=int(cfg.get("num_client", 64)), + min_data=int(cfg.get("min_data", 100)), + max_data=int(cfg.get("max_data", 100)), + mode=str(cfg.get("partition_mode", "disjoint")), # "overlap" or "disjoint" + seed=int(cfg.get("i_seed", 0)), + ) + + users = split["users"] + user_data = split["user_data"] # mapping: id -> {"filename": [...]} + + # --- build clients --- + model_name = cfg.get("model_name", "yolo_v11_n") + clients = {} + for uid in users: + c = FedYoloClient(name=uid, model_name=model_name, params=params) + c.load_trainset(user_data[uid]["filename"]) + clients[uid] = c + + # --- build server & optional validation set --- + server = FedYoloServer(client_list=users, model_name=model_name, params=params) + valset = _build_valset_if_available(cfg, params) + # valset is a Dataset class, not data loader + if valset is not None: + server.load_valset(valset) + + # --- push initial global weights --- + global_state = server.state_dict() + + # --- args object for client.train() --- + # args_train = _make_args_for_client(cfg, args_cli) + + # --- history recorder --- + history = { + "mAP": [], + "mAP50": [], + "precision": [], + "recall": [], + "train_loss": [], # scalar sum of client-weighted dict losses + "round_time_sec": [], + } + + # --- main FL loop --- + num_round = int(cfg.get("num_round", 50)) + connection_ratio = float(cfg.get("connection_ratio", 1.0)) # e.g., 1.0 = all clients + res_root = cfg.get("res_root", "results") + os.makedirs(res_root, exist_ok=True) + + for rnd in tqdm(range(num_round), desc="main federal loop round"): + t0 = time.time() + + # Local training (sequential over all users) + for uid in tqdm(users, desc=f"Round {rnd + 1} local training", leave=False): + client = clients[uid] # FedYoloClient instance + client.update(global_state) # load global weights + state_dict, n_data, loss_dict = client.train(args_cli) # local training + server.rec(uid, state_dict, n_data, loss_dict) + + # Select a fraction for aggregation (FedAvg subset if desired) + server.select_clients(connection_ratio=connection_ratio) + + # Aggregate + global_state, avg_loss_dict, _ = server.agg() + + # Compute a scalar train loss for plotting (sum of components) + scalar_train_loss = float(sum(avg_loss_dict.values())) if avg_loss_dict else 0.0 + + # Test (if valset provided) + test_metrics = server.test(args_cli) if server.valset is not None else {} + mAP = float(test_metrics.get("mAP", 0.0)) + mAP50 = float(test_metrics.get("mAP50", 0.0)) + precision = float(test_metrics.get("precision", 0.0)) + recall = float(test_metrics.get("recall", 0.0)) + + # Flush per-round client caches + server.flush() + + # Record & log + history["mAP"].append(mAP) + history["mAP50"].append(mAP50) + history["precision"].append(precision) + history["recall"].append(recall) + history["train_loss"].append(scalar_train_loss) + history["round_time_sec"].append(time.time() - t0) + + print( + f"[round {rnd + 1:04d}] " + f"loss={scalar_train_loss:.4f} mAP50-95={mAP:.4f} mAP50={mAP50:.4f} " + f"P={precision:.4f} R={recall:.4f}" + ) + + # Save running JSON (resumable logs) + save_name = ( + f"[{cfg.get('fed_algo', 'FedAvg')},{cfg.get('model_name', 'yolo')}," + f"{cfg.get('num_local_epoch', cfg.get('client', {}).get('num_local_epoch', 1))}," + f"{cfg.get('num_local_class', 2)}," + f"{cfg.get('i_seed', 0)}]" + ) + out_json = os.path.join(res_root, save_name + ".json") + with open(out_json, "w", encoding="utf-8") as f: + 
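+ # rewritten every round so the log stays resumable if training is interrupted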
json.dump(history, f, indent=2) + + # --- final plot --- + _plot_curves(res_root, history) + print("[done] training complete.") + + +if __name__ == "__main__": + fed_run() diff --git a/nets/nn.py b/nets/nn.py new file mode 100644 index 0000000..caafb81 --- /dev/null +++ b/nets/nn.py @@ -0,0 +1,362 @@ +""" +This file contains the model definition of YOLOv11 +""" + +import math + +import torch + +from utils.util import make_anchors + + +def fuse_conv(conv, norm): + fused_conv = ( + torch.nn.Conv2d( + conv.in_channels, + conv.out_channels, + kernel_size=conv.kernel_size, + stride=conv.stride, + padding=conv.padding, + groups=conv.groups, + bias=True, + ) + .requires_grad_(False) + .to(conv.weight.device) + ) + + w_conv = conv.weight.clone().view(conv.out_channels, -1) + w_norm = torch.diag(norm.weight.div(torch.sqrt(norm.eps + norm.running_var))) + fused_conv.weight.copy_(torch.mm(w_norm, w_conv).view(fused_conv.weight.size())) + + b_conv = torch.zeros(conv.weight.size(0), device=conv.weight.device) if conv.bias is None else conv.bias + b_norm = norm.bias - norm.weight.mul(norm.running_mean).div(torch.sqrt(norm.running_var + norm.eps)) + if fused_conv.bias is not None: + fused_conv.bias.copy_(torch.mm(w_norm, b_conv.reshape(-1, 1)).reshape(-1) + b_norm) + + return fused_conv + + +class Conv(torch.nn.Module): + def __init__(self, in_ch, out_ch, activation, k=1, s=1, p=0, g=1): + super().__init__() + self.conv = torch.nn.Conv2d(in_ch, out_ch, k, s, p, groups=g, bias=False) + self.norm = torch.nn.BatchNorm2d(out_ch, eps=0.001, momentum=0.03) + self.relu = activation + + def forward(self, x): + return self.relu(self.norm(self.conv(x))) + + def fuse_forward(self, x): + return self.relu(self.conv(x)) + + +class Residual(torch.nn.Module): + def __init__(self, ch, e=0.5): + super().__init__() + self.conv1 = Conv(ch, int(ch * e), torch.nn.SiLU(), k=3, p=1) + self.conv2 = Conv(int(ch * e), ch, torch.nn.SiLU(), k=3, p=1) + + def forward(self, x): + return x + self.conv2(self.conv1(x)) + + +class CSPModule(torch.nn.Module): + def __init__(self, in_ch, out_ch): + super().__init__() + self.conv1 = Conv(in_ch, out_ch // 2, torch.nn.SiLU()) + self.conv2 = Conv(in_ch, out_ch // 2, torch.nn.SiLU()) + self.conv3 = Conv(2 * (out_ch // 2), out_ch, torch.nn.SiLU()) + self.res_m = torch.nn.Sequential(Residual(out_ch // 2, e=1.0), Residual(out_ch // 2, e=1.0)) + + def forward(self, x): + y = self.res_m(self.conv1(x)) + return self.conv3(torch.cat((y, self.conv2(x)), dim=1)) + + +class CSP(torch.nn.Module): + def __init__(self, in_ch, out_ch, n, csp, r): + super().__init__() + self.conv1 = Conv(in_ch, 2 * (out_ch // r), torch.nn.SiLU()) + self.conv2 = Conv((2 + n) * (out_ch // r), out_ch, torch.nn.SiLU()) + + if not csp: + self.res_m = torch.nn.ModuleList(Residual(out_ch // r) for _ in range(n)) + else: + self.res_m = torch.nn.ModuleList(CSPModule(out_ch // r, out_ch // r) for _ in range(n)) + + def forward(self, x): + y = list(self.conv1(x).chunk(2, 1)) + y.extend(m(y[-1]) for m in self.res_m) + return self.conv2(torch.cat(y, dim=1)) + + +class SPP(torch.nn.Module): + def __init__(self, in_ch, out_ch, k=5): + super().__init__() + self.conv1 = Conv(in_ch, in_ch // 2, torch.nn.SiLU()) + self.conv2 = Conv(in_ch * 2, out_ch, torch.nn.SiLU()) + self.res_m = torch.nn.MaxPool2d(k, stride=1, padding=k // 2) + + def forward(self, x): + x = self.conv1(x) + y1 = self.res_m(x) + y2 = self.res_m(y1) + return self.conv2(torch.cat(tensors=[x, y1, y2, self.res_m(y2)], dim=1)) + + +class Attention(torch.nn.Module): + def 
__init__(self, ch, num_head): + super().__init__() + self.num_head = num_head + self.dim_head = ch // num_head + self.dim_key = self.dim_head // 2 + self.scale = self.dim_key**-0.5 + + self.qkv = Conv(ch, ch + self.dim_key * num_head * 2, torch.nn.Identity()) + + self.conv1 = Conv(ch, ch, torch.nn.Identity(), k=3, p=1, g=ch) + self.conv2 = Conv(ch, ch, torch.nn.Identity()) + + def forward(self, x): + b, c, h, w = x.shape + + qkv = self.qkv(x) + qkv = qkv.view(b, self.num_head, self.dim_key * 2 + self.dim_head, h * w) + + q, k, v = qkv.split([self.dim_key, self.dim_key, self.dim_head], dim=2) + + attn = (q.transpose(-2, -1) @ k) * self.scale + attn = attn.softmax(dim=-1) + + x = (v @ attn.transpose(-2, -1)).view(b, c, h, w) + self.conv1(v.reshape(b, c, h, w)) + return self.conv2(x) + + +class PSABlock(torch.nn.Module): + def __init__(self, ch, num_head): + super().__init__() + self.conv1 = Attention(ch, num_head) + self.conv2 = torch.nn.Sequential(Conv(ch, ch * 2, torch.nn.SiLU()), Conv(ch * 2, ch, torch.nn.Identity())) + + def forward(self, x): + x = x + self.conv1(x) + return x + self.conv2(x) + + +class PSA(torch.nn.Module): + def __init__(self, ch, n): + super().__init__() + self.conv1 = Conv(ch, 2 * (ch // 2), torch.nn.SiLU()) + self.conv2 = Conv(2 * (ch // 2), ch, torch.nn.SiLU()) + self.res_m = torch.nn.Sequential(*(PSABlock(ch // 2, ch // 128) for _ in range(n))) + + def forward(self, x): + x, y = self.conv1(x).chunk(2, 1) + return self.conv2(torch.cat(tensors=(x, self.res_m(y)), dim=1)) + + +class DarkNet(torch.nn.Module): + def __init__(self, width, depth, csp): + super().__init__() + self.p1 = [] + self.p2 = [] + self.p3 = [] + self.p4 = [] + self.p5 = [] + + # p1/2 + self.p1.append(Conv(width[0], width[1], torch.nn.SiLU(), k=3, s=2, p=1)) + # p2/4 + self.p2.append(Conv(width[1], width[2], torch.nn.SiLU(), k=3, s=2, p=1)) + self.p2.append(CSP(width[2], width[3], depth[0], csp[0], r=4)) + # p3/8 + self.p3.append(Conv(width[3], width[3], torch.nn.SiLU(), k=3, s=2, p=1)) + self.p3.append(CSP(width[3], width[4], depth[1], csp[0], r=4)) + # p4/16 + self.p4.append(Conv(width[4], width[4], torch.nn.SiLU(), k=3, s=2, p=1)) + self.p4.append(CSP(width[4], width[4], depth[2], csp[1], r=2)) + # p5/32 + self.p5.append(Conv(width[4], width[5], torch.nn.SiLU(), k=3, s=2, p=1)) + self.p5.append(CSP(width[5], width[5], depth[3], csp[1], r=2)) + self.p5.append(SPP(width[5], width[5])) + self.p5.append(PSA(width[5], depth[4])) + + self.p1 = torch.nn.Sequential(*self.p1) + self.p2 = torch.nn.Sequential(*self.p2) + self.p3 = torch.nn.Sequential(*self.p3) + self.p4 = torch.nn.Sequential(*self.p4) + self.p5 = torch.nn.Sequential(*self.p5) + + def forward(self, x): + p1 = self.p1(x) + p2 = self.p2(p1) + p3 = self.p3(p2) + p4 = self.p4(p3) + p5 = self.p5(p4) + return p3, p4, p5 + + +class DarkFPN(torch.nn.Module): + def __init__(self, width, depth, csp): + super().__init__() + self.up = torch.nn.Upsample(scale_factor=2) + self.h1 = CSP(width[4] + width[5], width[4], depth[5], csp[0], r=2) + self.h2 = CSP(width[4] + width[4], width[3], depth[5], csp[0], r=2) + self.h3 = Conv(width[3], width[3], torch.nn.SiLU(), k=3, s=2, p=1) + self.h4 = CSP(width[3] + width[4], width[4], depth[5], csp[0], r=2) + self.h5 = Conv(width[4], width[4], torch.nn.SiLU(), k=3, s=2, p=1) + self.h6 = CSP(width[4] + width[5], width[5], depth[5], csp[1], r=2) + + def forward(self, x): + p3, p4, p5 = x + p4 = self.h1(torch.cat(tensors=[self.up(p5), p4], dim=1)) + p3 = self.h2(torch.cat(tensors=[self.up(p4), p3], dim=1)) + p4 = 
self.h4(torch.cat(tensors=[self.h3(p3), p4], dim=1)) + p5 = self.h6(torch.cat(tensors=[self.h5(p4), p5], dim=1)) + return p3, p4, p5 + + +class DFL(torch.nn.Module): + # Generalized Focal Loss + # https://ieeexplore.ieee.org/document/9792391 + def __init__(self, ch=16): + super().__init__() + self.ch = ch + self.conv = torch.nn.Conv2d(ch, out_channels=1, kernel_size=1, bias=False).requires_grad_(False) + x = torch.arange(ch, dtype=torch.float).view(1, ch, 1, 1) + self.conv.weight.data[:] = torch.nn.Parameter(x) + + def forward(self, x): + b, c, a = x.shape + x = x.view(b, 4, self.ch, a).transpose(2, 1) + return self.conv(x.softmax(1)).view(b, 4, a) + + +class Head(torch.nn.Module): + anchors = torch.empty(0) + strides = torch.empty(0) + + def __init__(self, nc=80, filters=()): + super().__init__() + self.ch = 16 # DFL channels + self.nc = nc # number of classes + self.nl = len(filters) # number of detection layers + self.no = nc + self.ch * 4 # number of outputs per anchor + self.stride = torch.zeros(self.nl) # strides computed during build + + box = max(64, filters[0] // 4) + cls = max(80, filters[0], self.nc) + + self.dfl = DFL(self.ch) + self.box = torch.nn.ModuleList( + torch.nn.Sequential( + Conv(x, box, torch.nn.SiLU(), k=3, p=1), + Conv(box, box, torch.nn.SiLU(), k=3, p=1), + torch.nn.Conv2d(box, out_channels=4 * self.ch, kernel_size=1), + ) + for x in filters + ) + self.cls = torch.nn.ModuleList( + torch.nn.Sequential( + Conv(x, x, torch.nn.SiLU(), k=3, p=1, g=x), + Conv(x, cls, torch.nn.SiLU()), + Conv(cls, cls, torch.nn.SiLU(), k=3, p=1, g=cls), + Conv(cls, cls, torch.nn.SiLU()), + torch.nn.Conv2d(cls, out_channels=self.nc, kernel_size=1), + ) + for x in filters + ) + + def forward(self, x): + for i, (box, cls) in enumerate(zip(self.box, self.cls)): + x[i] = torch.cat(tensors=(box(x[i]), cls(x[i])), dim=1) + if self.training: + return x + + self.anchors, self.strides = (i.transpose(0, 1) for i in make_anchors(x, self.stride)) + x = torch.cat([i.view(x[0].shape[0], self.no, -1) for i in x], dim=2) + box, cls = x.split(split_size=(4 * self.ch, self.nc), dim=1) + + a, b = self.dfl(box).chunk(2, 1) + a = self.anchors.unsqueeze(0) - a + b = self.anchors.unsqueeze(0) + b + box = torch.cat(tensors=((a + b) / 2, b - a), dim=1) + + return torch.cat(tensors=(box * self.strides, cls.sigmoid()), dim=1) + + def initialize_biases(self): + # Initialize biases + # WARNING: requires stride availability + for box, cls, s in zip(self.box, self.cls, self.stride): + # box + box[-1].bias.data[:] = 1.0 + # cls (.01 objects, 80 classes, 640 image) + cls[-1].bias.data[: self.nc] = math.log(5 / self.nc / (640 / s) ** 2) + + +class YOLO(torch.nn.Module): + def __init__(self, width, depth, csp, num_classes): + super().__init__() + self.net = DarkNet(width, depth, csp) + self.fpn = DarkFPN(width, depth, csp) + + img_dummy = torch.zeros(1, width[0], 256, 256) + self.head = Head(num_classes, (width[3], width[4], width[5])) + self.head.stride = torch.tensor([256 / x.shape[-2] for x in self.forward(img_dummy)]) + self.stride = self.head.stride + self.head.initialize_biases() + + def forward(self, x): + x = self.net(x) + x = self.fpn(x) + return self.head(list(x)) + + def fuse(self): + for m in self.modules(): + if type(m) is Conv and hasattr(m, "norm"): + m.conv = fuse_conv(m.conv, m.norm) + m.forward = m.fuse_forward + delattr(m, "norm") + return self + + +def yolo_v11_n(num_classes: int = 80): + csp = [False, True] + depth = [1, 1, 1, 1, 1, 1] + width = [3, 16, 32, 64, 128, 256] + return YOLO(width, depth, 
csp, num_classes) + + +def yolo_v11_t(num_classes: int = 80): + csp = [False, True] + depth = [1, 1, 1, 1, 1, 1] + width = [3, 24, 48, 96, 192, 384] + return YOLO(width, depth, csp, num_classes) + + +def yolo_v11_s(num_classes: int = 80): + csp = [False, True] + depth = [1, 1, 1, 1, 1, 1] + width = [3, 32, 64, 128, 256, 512] + return YOLO(width, depth, csp, num_classes) + + +def yolo_v11_m(num_classes: int = 80): + csp = [True, True] + depth = [1, 1, 1, 1, 1, 1] + width = [3, 64, 128, 256, 512, 512] + return YOLO(width, depth, csp, num_classes) + + +def yolo_v11_l(num_classes: int = 80): + csp = [True, True] + depth = [2, 2, 2, 2, 2, 2] + width = [3, 64, 128, 256, 512, 512] + return YOLO(width, depth, csp, num_classes) + + +def yolo_v11_x(num_classes: int = 80): + csp = [True, True] + depth = [2, 2, 2, 2, 2, 2] + width = [3, 96, 192, 384, 768, 768] + return YOLO(width, depth, csp, num_classes) diff --git a/utils/args.py b/utils/args.py new file mode 100644 index 0000000..b965f7f --- /dev/null +++ b/utils/args.py @@ -0,0 +1,18 @@ +import argparse +import os + + +def args_parser(): + parser = argparse.ArgumentParser() + + parser.add_argument("--epochs", type=int, default=10, help="number of rounds of local training") + parser.add_argument("--input_size", type=int, default=640, help="image input size") + parser.add_argument("--config", type=str, default="./config/uav_cfg.yaml", help="Path to YAML config") + + args = parser.parse_args() + + args.local_rank = int(os.getenv("LOCAL_RANK", 0)) + args.world_size = int(os.getenv("WORLD_SIZE", 1)) + args.distributed = int(os.getenv("WORLD_SIZE", 1)) > 1 + + return args diff --git a/utils/dataset.py b/utils/dataset.py new file mode 100644 index 0000000..e3d71f6 --- /dev/null +++ b/utils/dataset.py @@ -0,0 +1,478 @@ +import math +import os +import random + +import cv2 +import numpy +import torch +from PIL import Image +from torch.utils import data + +FORMATS = "bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "JPEG", "JPG", "PNG", "TIFF" + + +class Dataset(data.Dataset): + params: dict + mosaic: bool + augment: bool + input_size: int + + def __init__(self, filenames, input_size: int, params: dict, augment: bool): + self.params = params + self.mosaic = augment + self.augment = augment + self.input_size = input_size + + # Read labels + labels = self.load_label(filenames) + self.labels = list(labels.values()) + self.filenames = list(labels.keys()) # update + self.n = len(self.filenames) # number of samples + self.indices = range(self.n) + # Albumentations (optional, only used if package is installed) + self.albumentations = Albumentations() + + def __getitem__(self, index): + index = self.indices[index] + + if self.mosaic and random.random() < self.params["mosaic"]: + # Load MOSAIC + image, label = self.load_mosaic(index, self.params) + # MixUp augmentation + if random.random() < self.params["mix_up"]: + index = random.choice(self.indices) + mix_image1, mix_label1 = image, label + mix_image2, mix_label2 = self.load_mosaic(index, self.params) + + image, label = mix_up(mix_image1, mix_label1, mix_image2, mix_label2) + else: + # Load image + image, shape = self.load_image(index) + if image is None: + raise ValueError(f"Failed to load image at index {index}: {self.filenames[index]}") + h, w = image.shape[:2] + + # Resize + image, ratio, pad = resize(image, self.input_size, self.augment) + + label = self.labels[index].copy() + if label.size: + label[:, 1:] = wh2xy(label[:, 1:], ratio[0] * w, ratio[1] * h, int(pad[0]), int(pad[1])) + if 
self.augment: + image, label = random_perspective(image, label, self.params) + + nl = len(label) # number of labels + h, w = image.shape[:2] + cls = label[:, 0:1] + box = label[:, 1:5] + box = xy2wh(box, w, h) + + if self.augment: + # Albumentations + image, box, cls = self.albumentations(image, box, cls) + nl = len(box) # update after albumentations + # HSV color-space + augment_hsv(image, self.params) + # Flip up-down + if random.random() < self.params["flip_ud"]: + image = numpy.flipud(image) + if nl: + box[:, 1] = 1 - box[:, 1] + # Flip left-right + if random.random() < self.params["flip_lr"]: + image = numpy.fliplr(image) + if nl: + box[:, 0] = 1 - box[:, 0] + + # target_cls = torch.zeros((nl, 1)) + # target_box = torch.zeros((nl, 4)) + # if nl: + # target_cls = torch.from_numpy(cls) + # target_box = torch.from_numpy(box) + + # fix [cls, box] empty bug. e.g. [0,1] is illegal in DataLoader collate_fn cat operation + if nl: + target_cls = torch.from_numpy(cls).view(-1, 1).float() # always (N,1) + target_box = torch.from_numpy(box).reshape(-1, 4).float() # always (N,4) + else: + target_cls = torch.zeros((0, 1), dtype=torch.float32) + target_box = torch.zeros((0, 4), dtype=torch.float32) + + # Convert HWC to CHW, BGR to RGB + sample = image.transpose((2, 0, 1))[::-1] + sample = numpy.ascontiguousarray(sample) + + # init: return torch.from_numpy(sample), target_cls, target_box, torch.zeros(nl) + return torch.from_numpy(sample), target_cls, target_box, torch.zeros((nl, 1), dtype=torch.long) + + def __len__(self): + return len(self.filenames) + + def load_image(self, i): + image = cv2.imread(self.filenames[i]) + if image is None: + raise ValueError(f"Image not found or unable to open: {self.filenames[i]}") + h, w = image.shape[:2] + r = self.input_size / max(h, w) + if r != 1: + image = cv2.resize( + image, dsize=(int(w * r), int(h * r)), interpolation=resample() if self.augment else cv2.INTER_LINEAR + ) + return image, (h, w) + + def load_mosaic(self, index, params): + label4 = [] + border = [-self.input_size // 2, -self.input_size // 2] + image4 = numpy.full((self.input_size * 2, self.input_size * 2, 3), 0, dtype=numpy.uint8) + y1a, y2a, x1a, x2a, y1b, y2b, x1b, x2b = (None, None, None, None, None, None, None, None) + + xc = int(random.uniform(-border[0], 2 * self.input_size + border[1])) + yc = int(random.uniform(-border[0], 2 * self.input_size + border[1])) + + indices = [index] + random.choices(self.indices, k=3) + random.shuffle(indices) + + for i, index in enumerate(indices): + # Load image + image, _ = self.load_image(index) + shape = image.shape + if i == 0: # top left + x1a = max(xc - shape[1], 0) + y1a = max(yc - shape[0], 0) + x2a = xc + y2a = yc + x1b = shape[1] - (x2a - x1a) + y1b = shape[0] - (y2a - y1a) + x2b = shape[1] + y2b = shape[0] + if i == 1: # top right + x1a = xc + y1a = max(yc - shape[0], 0) + x2a = min(xc + shape[1], self.input_size * 2) + y2a = yc + x1b = 0 + y1b = shape[0] - (y2a - y1a) + x2b = min(shape[1], x2a - x1a) + y2b = shape[0] + if i == 2: # bottom left + x1a = max(xc - shape[1], 0) + y1a = yc + x2a = xc + y2a = min(self.input_size * 2, yc + shape[0]) + x1b = shape[1] - (x2a - x1a) + y1b = 0 + x2b = shape[1] + y2b = min(y2a - y1a, shape[0]) + if i == 3: # bottom right + x1a = xc + y1a = yc + x2a = min(xc + shape[1], self.input_size * 2) + y2a = min(self.input_size * 2, yc + shape[0]) + x1b = 0 + y1b = 0 + x2b = min(shape[1], x2a - x1a) + y2b = min(y2a - y1a, shape[0]) + + pad_w = (x1a if x1a is not None else 0) - (x1b if x1b is not None else 0) + pad_h = 
(y1a if y1a is not None else 0) - (y1b if y1b is not None else 0) + image4[y1a:y2a, x1a:x2a] = image[y1b:y2b, x1b:x2b] + + # Labels + label = self.labels[index].copy() + if len(label): + label[:, 1:] = wh2xy(label[:, 1:], shape[1], shape[0], pad_w, pad_h) + label4.append(label) + + # Concat/clip labels + label4 = numpy.concatenate(label4, 0) + for x in label4[:, 1:]: + numpy.clip(x, 0, 2 * self.input_size, out=x) + + # Augment + image4, label4 = random_perspective(image4, label4, params, border) + + return image4, label4 + + @staticmethod + def collate_fn(batch): + samples, cls, box, indices = zip(*batch) + + # ensure empty tensor shape is correct + cls = [c.view(-1, 1) for c in cls] + box = [b.reshape(-1, 4) for b in box] + indices = [i for i in indices] + + cls = torch.cat(cls, dim=0) if cls else torch.zeros((0, 1)) + box = torch.cat(box, dim=0) if box else torch.zeros((0, 4)) + indices = torch.cat(indices, dim=0) if indices else torch.zeros((0,), dtype=torch.long) + + new_indices = list(indices) + for i in range(len(indices)): + new_indices[i] += i + indices = torch.cat(new_indices, dim=0) + + targets = {"cls": cls, "box": box, "idx": indices} + return torch.stack(samples, dim=0), targets + + @staticmethod + def load_label_use_cache(filenames): + path = f"{os.path.dirname(filenames[0])}.cache" + if os.path.exists(path): + return torch.load(path, weights_only=False) + x = {} + for filename in filenames: + try: + # verify images + with open(filename, "rb") as f: + image = Image.open(f) + image.verify() # PIL verify + shape = image.size # image size + assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels" + assert image.format is not None and image.format.lower() in FORMATS, ( + f"invalid image format {image.format}" + ) + + # verify labels + a = f"{os.sep}images{os.sep}" + b = f"{os.sep}labels{os.sep}" + + if os.path.isfile(b.join(filename.rsplit(a, 1)).rsplit(".", 1)[0] + ".txt"): + with open(b.join(filename.rsplit(a, 1)).rsplit(".", 1)[0] + ".txt") as f: + label = [x.split() for x in f.read().strip().splitlines() if len(x)] + label = numpy.array(label, dtype=numpy.float32) + nl = len(label) + if nl: + assert (label >= 0).all() + assert label.shape[1] == 5 + assert (label[:, 1:] <= 1).all() + _, i = numpy.unique(label, axis=0, return_index=True) + if len(i) < nl: # duplicate row check + label = label[i] # remove duplicates + else: + label = numpy.zeros((0, 5), dtype=numpy.float32) + else: + label = numpy.zeros((0, 5), dtype=numpy.float32) + except FileNotFoundError: + label = numpy.zeros((0, 5), dtype=numpy.float32) + except AssertionError: + continue + x[filename] = label + torch.save(x, path) + return x + + @staticmethod + def load_label(filenames): + x = {} + for filename in filenames: + try: + # verify images + with open(filename, "rb") as f: + image = Image.open(f) + image.verify() + shape = image.size + assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels" + assert image.format is not None and image.format.lower() in FORMATS, ( + f"invalid image format {image.format}" + ) + + # verify labels + a = f"{os.sep}images{os.sep}" + b = f"{os.sep}labels{os.sep}" + label_path = b.join(filename.rsplit(a, 1)).rsplit(".", 1)[0] + ".txt" + + if os.path.isfile(label_path): + rows = [] + with open(label_path) as f: + for line in f: + parts = line.strip().split() + if len(parts) == 5: # YOLO format + rows.append([float(x) for x in parts]) + label = numpy.array(rows, dtype=numpy.float32) if rows else numpy.zeros((0, 5), dtype=numpy.float32) + + if 
label.shape[0]: + assert (label >= 0).all() + assert label.shape[1] == 5 + assert (label[:, 1:] <= 1.0001).all() + _, i = numpy.unique(label, axis=0, return_index=True) + label = label[i] + else: + label = numpy.zeros((0, 5), dtype=numpy.float32) + + except (FileNotFoundError, AssertionError): + label = numpy.zeros((0, 5), dtype=numpy.float32) + + x[filename] = label + return x + + +def wh2xy(x, w=640, h=640, pad_w=0, pad_h=0): + # Convert nx4 boxes + # from [x, y, w, h] normalized to [x1, y1, x2, y2] where xy1=top-left, xy2=bottom-right + y = numpy.copy(x) + y[:, 0] = w * (x[:, 0] - x[:, 2] / 2) + pad_w # top left x + y[:, 1] = h * (x[:, 1] - x[:, 3] / 2) + pad_h # top left y + y[:, 2] = w * (x[:, 0] + x[:, 2] / 2) + pad_w # bottom right x + y[:, 3] = h * (x[:, 1] + x[:, 3] / 2) + pad_h # bottom right y + return y + + +def xy2wh(x, w, h): + # warning: inplace clip + x[:, [0, 2]] = x[:, [0, 2]].clip(0, w - 1e-3) # x1, x2 + x[:, [1, 3]] = x[:, [1, 3]].clip(0, h - 1e-3) # y1, y2 + + # Convert nx4 boxes + # from [x1, y1, x2, y2] to [x, y, w, h] normalized where xy1=top-left, xy2=bottom-right + y = numpy.copy(x) + y[:, 0] = ((x[:, 0] + x[:, 2]) / 2) / w # x center + y[:, 1] = ((x[:, 1] + x[:, 3]) / 2) / h # y center + y[:, 2] = (x[:, 2] - x[:, 0]) / w # width + y[:, 3] = (x[:, 3] - x[:, 1]) / h # height + return y + + +def resample(): + choices = (cv2.INTER_AREA, cv2.INTER_CUBIC, cv2.INTER_LINEAR, cv2.INTER_NEAREST, cv2.INTER_LANCZOS4) + return random.choice(seq=choices) + + +def augment_hsv(image, params): + # HSV color-space augmentation + h = params["hsv_h"] + s = params["hsv_s"] + v = params["hsv_v"] + + r = numpy.random.uniform(-1, 1, 3) * [h, s, v] + 1 + h, s, v = cv2.split(cv2.cvtColor(image, cv2.COLOR_BGR2HSV)) + + x = numpy.arange(0, 256, dtype=r.dtype) + lut_h = ((x * r[0]) % 180).astype("uint8") + lut_s = numpy.clip(x * r[1], 0, 255).astype("uint8") + lut_v = numpy.clip(x * r[2], 0, 255).astype("uint8") + + hsv = cv2.merge((cv2.LUT(h, lut_h), cv2.LUT(s, lut_s), cv2.LUT(v, lut_v))) + cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR, dst=image) # no return needed + + +def resize(image, input_size, augment): + # Resize and pad image while meeting stride-multiple constraints + shape = image.shape[:2] # current shape [height, width] + + # Scale ratio (new / old) + r = min(input_size / shape[0], input_size / shape[1]) + if not augment: # only scale down, do not scale up (for better val mAP) + r = min(r, 1.0) + + # Compute padding + pad = int(round(shape[1] * r)), int(round(shape[0] * r)) + w = (input_size - pad[0]) / 2 + h = (input_size - pad[1]) / 2 + + if shape[::-1] != pad: # resize + image = cv2.resize(image, dsize=pad, interpolation=resample() if augment else cv2.INTER_LINEAR) + top, bottom = int(round(h - 0.1)), int(round(h + 0.1)) + left, right = int(round(w - 0.1)), int(round(w + 0.1)) + image = cv2.copyMakeBorder(image, top, bottom, left, right, cv2.BORDER_CONSTANT) # add border + return image, (r, r), (w, h) + + +def candidates(box1, box2): + # box1(4,n), box2(4,n) + w1, h1 = box1[2] - box1[0], box1[3] - box1[1] + w2, h2 = box2[2] - box2[0], box2[3] - box2[1] + aspect_ratio = numpy.maximum(w2 / (h2 + 1e-16), h2 / (w2 + 1e-16)) # aspect ratio + return (w2 > 2) & (h2 > 2) & (w2 * h2 / (w1 * h1 + 1e-16) > 0.1) & (aspect_ratio < 100) + + +def random_perspective(image, label, params, border=(0, 0)): + h = image.shape[0] + border[0] * 2 + w = image.shape[1] + border[1] * 2 + + # Center + center = numpy.eye(3) + center[0, 2] = -image.shape[1] / 2 # x translation (pixels) + center[1, 2] = 
-image.shape[0] / 2 # y translation (pixels) + + # Perspective + perspective = numpy.eye(3) + + # Rotation and Scale + rotate = numpy.eye(3) + a = random.uniform(-params["degrees"], params["degrees"]) + s = random.uniform(1 - params["scale"], 1 + params["scale"]) + rotate[:2] = cv2.getRotationMatrix2D(angle=a, center=(0, 0), scale=s) + + # Shear + shear = numpy.eye(3) + shear[0, 1] = math.tan(random.uniform(-params["shear"], params["shear"]) * math.pi / 180) + shear[1, 0] = math.tan(random.uniform(-params["shear"], params["shear"]) * math.pi / 180) + + # Translation + translate = numpy.eye(3) + translate[0, 2] = random.uniform(0.5 - params["translate"], 0.5 + params["translate"]) * w + translate[1, 2] = random.uniform(0.5 - params["translate"], 0.5 + params["translate"]) * h + + # Combined rotation matrix, order of operations (right to left) is IMPORTANT + matrix = translate @ shear @ rotate @ perspective @ center + if (border[0] != 0) or (border[1] != 0) or (matrix != numpy.eye(3)).any(): # image changed + image = cv2.warpAffine(image, matrix[:2], dsize=(w, h), borderValue=(0, 0, 0)) + + # Transform label coordinates + n = len(label) + if n: + xy = numpy.ones((n * 4, 3)) + xy[:, :2] = label[:, [1, 2, 3, 4, 1, 4, 3, 2]].reshape(n * 4, 2) # x1y1, x2y2, x1y2, x2y1 + xy = xy @ matrix.T # transform + xy = xy[:, :2].reshape(n, 8) # perspective rescale or affine + + # create new boxes + x = xy[:, [0, 2, 4, 6]] + y = xy[:, [1, 3, 5, 7]] + box = numpy.concatenate((x.min(1), y.min(1), x.max(1), y.max(1))).reshape(4, n).T + + # clip + box[:, [0, 2]] = box[:, [0, 2]].clip(0, w) + box[:, [1, 3]] = box[:, [1, 3]].clip(0, h) + # filter candidates + indices = candidates(box1=label[:, 1:5].T * s, box2=box.T) + + label = label[indices] + label[:, 1:5] = box[indices] + + return image, label + + +def mix_up(image1, label1, image2, label2): + # Applies MixUp augmentation https://arxiv.org/pdf/1710.09412.pdf + alpha = numpy.random.beta(a=32.0, b=32.0) # mix-up ratio, alpha=beta=32.0 + image = (image1 * alpha + image2 * (1 - alpha)).astype(numpy.uint8) + label = numpy.concatenate((label1, label2), 0) + return image, label + + +class Albumentations: + def __init__(self): + self.transform = None + try: + import albumentations + + transforms = [ + albumentations.Blur(p=0.01), + albumentations.CLAHE(p=0.01), + albumentations.ToGray(p=0.01), + albumentations.MedianBlur(p=0.01), + ] + self.transform = albumentations.Compose( + transforms, albumentations.BboxParams(format="yolo", label_fields=["class_labels"]) + ) + + except ImportError: # package not installed, skip + pass + + def __call__(self, image, box, cls): + if self.transform: + x = self.transform(image=image, bboxes=box, class_labels=cls) + image = x["image"] + box = numpy.array(x["bboxes"]) + cls = numpy.array(x["class_labels"]) + return image, box, cls diff --git a/utils/fed_util.py b/utils/fed_util.py new file mode 100644 index 0000000..e462407 --- /dev/null +++ b/utils/fed_util.py @@ -0,0 +1,250 @@ +import os +import re +import random +from collections import defaultdict +from typing import Dict, List, Optional, Set, Any + +from nets import nn + + +def _image_to_label_path(img_path: str) -> str: + """ + Convert an image path like ".../images/train2017/xxx.jpg" + to the corresponding label path ".../labels/train2017/xxx.txt". + Works for POSIX/Windows separators. 
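+ If the path contains no "images" directory component, only the extension is swapped to ".txt".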
+ """ + # swap "/images/" (or "\images\") to "/labels/" + label_path = re.sub(r"([/\\])images([/\\])", r"\1labels\2", img_path) + # swap extension to .txt + root, _ = os.path.splitext(label_path) + return root + ".txt" + + +def _parse_yolo_label_file(label_path: str) -> Set[int]: + """ + Return a set of class_ids found in a YOLO .txt label file. + Empty file -> empty set. Missing file -> empty set. + Robust to blank lines / trailing spaces. + """ + class_ids: Set[int] = set() + if not os.path.exists(label_path): + return class_ids + try: + with open(label_path, "r", encoding="utf-8") as f: + for line in f: + line = line.strip() + if not line: + continue + # YOLO format: cls cx cy w h + parts = line.split() + if not parts: + continue + try: + cls = int(parts[0]) + except ValueError: + # handle weird case like '23.0' + try: + cls = int(float(parts[0])) + except ValueError: + # skip malformed line + continue + class_ids.add(cls) + except Exception: + # If the file can't be read for some reason, treat as no labels + return set() + return class_ids + + +def divide_trainset( + trainset_path: str, + num_local_class: int, + num_client: int, + min_data: int, + max_data: int, + mode: str = "overlap", # "overlap" or "disjoint" + seed: Optional[int] = None, +) -> Dict[str, Any]: + """ + Build a federated split from a YOLO dataset list file. + + Args: + trainset_path: path to a .txt file containing one image path per line + e.g. /COCO/images/train2017/1111.jpg + num_local_class: how many distinct classes to sample for each client + num_client: number of clients + min_data: minimum number of images per client + max_data: maximum number of images per client + mode: "overlap" -> images may be shared across clients + "disjoint" -> each image is used by at most one client + seed: optional random seed for reproducibility + + Returns: + trainset_divided = { + "users": ["c_00001", ...], + "user_data": { + "c_00001": {"filename": [img_path, ...]}, + ... + }, + "num_samples": [len(list_for_user1), len(list_for_user2), ...] + } + + Example: + dataset = divide_trainset( + trainset_path="/COCO/train2017.txt", + num_local_class=3, + num_client=5, + min_data=10, + max_data=20, + mode="disjoint", # or "overlap" + seed=42 + ) + + print(dataset["users"]) # ['c_00001', ..., 'c_00005'] + print(dataset["num_samples"]) # e.g. 
[10, 12, 18, 9, 15] + print(dataset["user_data"]["c_00001"]["filename"][:3]) + """ + if seed is not None: + random.seed(seed) + + # ---- Basic validations (defensive programming) ---- + if num_client <= 0: + raise ValueError("num_client must be > 0") + if num_local_class <= 0: + raise ValueError("num_local_class must be > 0") + if min_data < 0 or max_data < 0: + raise ValueError("min_data/max_data must be >= 0") + if max_data < min_data: + raise ValueError("max_data must be >= min_data") + if mode not in {"overlap", "disjoint"}: + raise ValueError('mode must be "overlap" or "disjoint"') + + # ---- 1) Read image list ---- + with open(trainset_path, "r", encoding="utf-8") as f: + all_images_raw = [ln.strip() for ln in f if ln.strip()] + + # Normalize and deduplicate image paths (safe) + all_images: List[str] = [] + seen = set() + for p in all_images_raw: + # keep exact string (don’t join with cwd), just normalize slashes + norm = os.path.normpath(p) + if norm not in seen: + seen.add(norm) + all_images.append(norm) + + # ---- 2) Build mappings from labels ---- + class_to_images: Dict[int, Set[str]] = defaultdict(set) + image_to_classes: Dict[str, Set[int]] = {} + + missing_label_files = 0 + empty_label_files = 0 + parsed_images = 0 + + for img in all_images: + lbl = _image_to_label_path(img) + if not os.path.exists(lbl): + # Missing labels: skip image (no class info) + missing_label_files += 1 + continue + + classes = _parse_yolo_label_file(lbl) + if not classes: + # No objects in this image -> skip (no class bucket) + empty_label_files += 1 + continue + + image_to_classes[img] = classes + for c in classes: + class_to_images[c].add(img) + parsed_images += 1 + + if not class_to_images: + # No usable images found + return { + "users": [f"c_{i + 1:05d}" for i in range(num_client)], + "user_data": {f"c_{i + 1:05d}": {"filename": []} for i in range(num_client)}, + "num_samples": [0 for _ in range(num_client)], + } + + all_classes: List[int] = sorted(class_to_images.keys()) + # Available pool for disjoint mode (only images with labels) + available_images: Set[str] = set(image_to_classes.keys()) + + # ---- 3) Allocate to clients ---- + result = {"users": [], "user_data": {}, "num_samples": []} + + for cid in range(num_client): + user_id = f"c_{cid + 1:05d}" + result["users"].append(user_id) + + # Pick the classes for this client (sample without replacement from global class set) + k = min(num_local_class, len(all_classes)) + chosen_classes = random.sample(all_classes, k) if k > 0 else [] + + # Decide how many samples for this client + need = min_data if min_data == max_data else random.randint(min_data, max_data) + + # Build the candidate pool for this client + if mode == "overlap": + pool_set: Set[str] = set() + for c in chosen_classes: + pool_set.update(class_to_images[c]) + else: # "disjoint": restrict to currently available images + pool_set = set() + for c in chosen_classes: + # intersect with available images + pool_set.update(class_to_images[c] & available_images) + + # Deduplicate and sample + pool_list = list(pool_set) + if len(pool_list) <= need: + chosen_imgs = pool_list[:] # take all (can be fewer than need) + else: + chosen_imgs = random.sample(pool_list, need) + + # Record for the user + result["user_data"][user_id] = {"filename": chosen_imgs} + result["num_samples"].append(len(chosen_imgs)) + + # If disjoint, remove selected images from availability everywhere + if mode == "disjoint" and chosen_imgs: + for img in chosen_imgs: + if img in available_images: + 
available_images.remove(img) + # remove from every class bucket this image belongs to + for c in image_to_classes.get(img, []): + if img in class_to_images[c]: + class_to_images[c].remove(img) + # Optional: prune empty classes from all_classes to speed up later loops + # (keep list stable; just skip empties naturally) + + # (Optional) You can print some quick diagnostics if helpful: + # print(f"[INFO] Parsed images with labels: {parsed_images}") + # print(f"[INFO] Missing label files: {missing_label_files}") + # print(f"[INFO] Empty label files: {empty_label_files}") + + return result + + +def init_model(model_name, num_classes): + """ + Initialize the model for a specific learning task + Args: + :param model_name: Name of the model + :param num_classes: Number of classes + """ + model = None + if model_name == "yolo_v11_n": + model = nn.yolo_v11_n(num_classes=num_classes) + elif model_name == "yolo_v11_s": + model = nn.yolo_v11_s(num_classes=num_classes) + elif model_name == "yolo_v11_m": + model = nn.yolo_v11_m(num_classes=num_classes) + elif model_name == "yolo_v11_l": + model = nn.yolo_v11_l(num_classes=num_classes) + elif model_name == "yolo_v11_x": + model = nn.yolo_v11_x(num_classes=num_classes) + else: + raise ValueError("Model {} is not supported.".format(model_name)) + + return model diff --git a/utils/util.py b/utils/util.py new file mode 100644 index 0000000..08ea61b --- /dev/null +++ b/utils/util.py @@ -0,0 +1,818 @@ +""" +Utility functions for yolo. +""" + +import copy +import random +from time import time + +import math +import numpy +import torch +import torchvision +from torch.nn.functional import cross_entropy + + +def setup_seed(): + """ + Setup random seed. + """ + random.seed(0) + numpy.random.seed(0) + torch.manual_seed(0) + torch.backends.cudnn.benchmark = False + torch.backends.cudnn.deterministic = True + + +def setup_multi_processes(): + """ + Setup multi-processing environment variables. 
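+    Uses the `fork` start method on non-Windows platforms and limits OpenCV, OMP and
+    MKL to a single thread each, so that DataLoader workers do not oversubscribe CPU cores.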
+ """ + import cv2 + from os import environ + from platform import system + + # set multiprocess start method as `fork` to speed up the training + if system() != "Windows": + torch.multiprocessing.set_start_method("fork", force=True) + + # disable opencv multithreading to avoid system being overloaded + cv2.setNumThreads(0) + + # setup OMP threads + if "OMP_NUM_THREADS" not in environ: + environ["OMP_NUM_THREADS"] = "1" + + # setup MKL threads + if "MKL_NUM_THREADS" not in environ: + environ["MKL_NUM_THREADS"] = "1" + + +def export_onnx(args): + import onnx # noqa + + inputs = ["images"] + outputs = ["outputs"] + dynamic = {"outputs": {0: "batch", 1: "anchors"}} + + m = torch.load("./weights/best.pt", weights_only=False)["model"].float() + x = torch.zeros((1, 3, args.input_size, args.input_size)) + + torch.onnx.export( + m.cpu(), + (x.cpu(),), + f="./weights/best.onnx", + verbose=False, + opset_version=12, + # WARNING: DNN inference with torch>=1.12 may require do_constant_folding=False + do_constant_folding=True, + input_names=inputs, + output_names=outputs, + dynamic_axes=dynamic or None, + ) + + # Checks + model_onnx = onnx.load("./weights/best.onnx") # load onnx model + onnx.checker.check_model(model_onnx) # check onnx model + + onnx.save(model_onnx, "./weights/best.onnx") + # Inference example + # https://github.com/ultralytics/ultralytics/blob/main/ultralytics/nn/autobackend.py + + +def wh2xy(x): + y = x.clone() if isinstance(x, torch.Tensor) else numpy.copy(x) + y[:, 0] = x[:, 0] - x[:, 2] / 2 # top left x + y[:, 1] = x[:, 1] - x[:, 3] / 2 # top left y + y[:, 2] = x[:, 0] + x[:, 2] / 2 # bottom right x + y[:, 3] = x[:, 1] + x[:, 3] / 2 # bottom right y + return y + + +def make_anchors(x, strides, offset=0.5): + assert x is not None + anchor_tensor, stride_tensor = [], [] + dtype, device = x[0].dtype, x[0].device + for i, stride in enumerate(strides): + _, _, h, w = x[i].shape + sx = torch.arange(end=w, device=device, dtype=dtype) + offset # shift x + sy = torch.arange(end=h, device=device, dtype=dtype) + offset # shift y + sy, sx = torch.meshgrid(sy, sx, indexing="ij") + anchor_tensor.append(torch.stack((sx, sy), -1).view(-1, 2)) + stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device)) + return torch.cat(anchor_tensor), torch.cat(stride_tensor) + + +def compute_metric(output, target, iou_v): + # intersection(N,M) = (rb(N,M,2) - lt(N,M,2)).clamp(0).prod(2) + (a1, a2) = target[:, 1:].unsqueeze(1).chunk(2, 2) + (b1, b2) = output[:, :4].unsqueeze(0).chunk(2, 2) + intersection = (torch.min(a2, b2) - torch.max(a1, b1)).clamp(0).prod(2) + # IoU = intersection / (area1 + area2 - intersection) + iou = intersection / ((a2 - a1).prod(2) + (b2 - b1).prod(2) - intersection + 1e-7) + + correct = numpy.zeros((output.shape[0], iou_v.shape[0])) + correct = correct.astype(bool) + for i in range(len(iou_v)): + # IoU > threshold and classes match + x = torch.where((iou >= iou_v[i]) & (target[:, 0:1] == output[:, 5])) + if x[0].shape[0]: + matches = torch.cat((torch.stack(x, 1), iou[x[0], x[1]][:, None]), 1).cpu().numpy() # [label, detect, iou] + if x[0].shape[0] > 1: + matches = matches[matches[:, 2].argsort()[::-1]] + matches = matches[numpy.unique(matches[:, 1], return_index=True)[1]] + matches = matches[numpy.unique(matches[:, 0], return_index=True)[1]] + correct[matches[:, 1].astype(int), i] = True + return torch.tensor(correct, dtype=torch.bool, device=output.device) + + +def non_max_suppression(outputs, confidence_threshold=0.001, iou_threshold=0.65): + max_wh = 7680 + 
max_det = 300 + max_nms = 30000 + + bs = outputs.shape[0] # batch size + nc = outputs.shape[1] - 4 # number of classes + xc = outputs[:, 4 : 4 + nc].amax(1) > confidence_threshold # candidates + + # Settings + start = time() + limit = 0.5 + 0.05 * bs # seconds to quit after + output = [torch.zeros((0, 6), device=outputs.device)] * bs + for index, x in enumerate(outputs): # image index, image inference + x = x.transpose(0, -1)[xc[index]] # confidence + + # If none remain process next image + if not x.shape[0]: + continue + + # matrix nx6 (box, confidence, cls) + box, cls = x.split((4, nc), 1) + box = wh2xy(box) # (cx, cy, w, h) to (x1, y1, x2, y2) + if nc > 1: + i, j = (cls > confidence_threshold).nonzero(as_tuple=False).T + x = torch.cat((box[i], x[i, 4 + j, None], j[:, None].float()), dim=1) + else: # best class only + conf, j = cls.max(1, keepdim=True) + x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > confidence_threshold] + + # Check shape + n = x.shape[0] # number of boxes + if not n: # no boxes + continue + x = x[x[:, 4].argsort(descending=True)[:max_nms]] # sort by confidence and remove excess boxes + + # Batched NMS + c = x[:, 5:6] * max_wh # classes + boxes, scores = x[:, :4] + c, x[:, 4] # boxes, scores + indices = torchvision.ops.nms(boxes, scores, iou_threshold) # NMS + indices = indices[:max_det] # limit detections + + output[index] = x[indices] + if (time() - start) > limit: + break # time limit exceeded + + return output + + +def smooth(y, f=0.1): + # Box filter of fraction f + nf = round(len(y) * f * 2) // 2 + 1 # number of filter elements (must be odd) + p = numpy.ones(nf // 2) # ones padding + yp = numpy.concatenate((p * y[0], y, p * y[-1]), 0) # y padded + return numpy.convolve(yp, numpy.ones(nf) / nf, mode="valid") # y-smoothed + + +def plot_pr_curve(px, py, ap, names, save_dir): + from matplotlib import pyplot + + fig, ax = pyplot.subplots(1, 1, figsize=(9, 6), tight_layout=True) + py = numpy.stack(py, axis=1) + + if 0 < len(names) < 21: # display per-class legend if < 21 classes + for i, y in enumerate(py.T): + ax.plot(px, y, linewidth=1, label=f"{names[i]} {ap[i, 0]:.3f}") # plot(recall, precision) + else: + ax.plot(px, py, linewidth=1, color="grey") # plot(recall, precision) + + ax.plot( + px, + py.mean(1), + linewidth=3, + color="blue", + label="all classes %.3f mAP@0.5" % ap[:, 0].mean(), + ) + ax.set_xlabel("Recall") + ax.set_ylabel("Precision") + ax.set_xlim(0, 1) + ax.set_ylim(0, 1) + ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left") + ax.set_title("Precision-Recall Curve") + fig.savefig(save_dir, dpi=250) + pyplot.close(fig) + + +def plot_curve(px, py, names, save_dir, x_label="Confidence", y_label="Metric"): + from matplotlib import pyplot + + figure, ax = pyplot.subplots(1, 1, figsize=(9, 6), tight_layout=True) + + if 0 < len(names) < 21: # display per-class legend if < 21 classes + for i, y in enumerate(py): + ax.plot(px, y, linewidth=1, label=f"{names[i]}") # plot(confidence, metric) + else: + ax.plot(px, py.T, linewidth=1, color="grey") # plot(confidence, metric) + + y = smooth(py.mean(0), f=0.05) + ax.plot( + px, + y, + linewidth=3, + color="blue", + label=f"all classes {y.max():.3f} at {px[y.argmax()]:.3f}", + ) + ax.set_xlabel(x_label) + ax.set_ylabel(y_label) + ax.set_xlim(0, 1) + ax.set_ylim(0, 1) + ax.legend(bbox_to_anchor=(1.04, 1), loc="upper left") + ax.set_title(f"{y_label}-Confidence Curve") + figure.savefig(save_dir, dpi=250) + pyplot.close(figure) + + +def compute_ap(tp, conf, output, target, plot=False, names=(), eps=1e-16): + 
""" + Compute the average precision, given the recall and precision curves. + Source: https://github.com/rafaelpadilla/Object-Detection-Metrics. + # Arguments + tp: True positives (nparray, nx1 or nx10). + conf: Object-ness value from 0-1 (nparray). + output: Predicted object classes (nparray). + target: True object classes (nparray). + # Returns + The average precision + """ + # Sort by object-ness + i = numpy.argsort(-conf) + tp, conf, output = tp[i], conf[i], output[i] + + # Find unique classes + unique_classes, nt = numpy.unique(target, return_counts=True) + nc = unique_classes.shape[0] # number of classes, number of detections + + # Create Precision-Recall curve and compute AP for each class + p = numpy.zeros((nc, 1000)) + r = numpy.zeros((nc, 1000)) + ap = numpy.zeros((nc, tp.shape[1])) + px, py = numpy.linspace(start=0, stop=1, num=1000), [] # for plotting + for ci, c in enumerate(unique_classes): + i = output == c + nl = nt[ci] # number of labels + no = i.sum() # number of outputs + if no == 0 or nl == 0: + continue + + # Accumulate FPs and TPs + fpc = (1 - tp[i]).cumsum(0) + tpc = tp[i].cumsum(0) + + # Recall + recall = tpc / (nl + eps) # recall curve + # negative x, xp because xp decreases + r[ci] = numpy.interp(-px, -conf[i], recall[:, 0], left=0) + + # Precision + precision = tpc / (tpc + fpc) # precision curve + p[ci] = numpy.interp(-px, -conf[i], precision[:, 0], left=1) # p at pr_score + + # AP from recall-precision curve + for j in range(tp.shape[1]): + m_rec = numpy.concatenate(([0.0], recall[:, j], [1.0])) + m_pre = numpy.concatenate(([1.0], precision[:, j], [0.0])) + + # Compute the precision envelope + m_pre = numpy.flip(numpy.maximum.accumulate(numpy.flip(m_pre))) + + # Integrate area under curve + x = numpy.linspace(start=0, stop=1, num=101) # 101-point interp (COCO) + ap[ci, j] = numpy.trapz(numpy.interp(x, m_rec, m_pre), x) # integrate + if plot and j == 0: + py.append(numpy.interp(px, m_rec, m_pre)) # precision at mAP@0.5 + + # Compute F1 (harmonic mean of precision and recall) + f1 = 2 * p * r / (p + r + eps) + if plot: + names = dict(enumerate(names)) # to dict + names = [v for k, v in names.items() if k in unique_classes] # list: only classes that have data + plot_pr_curve(px, py, ap, names, save_dir="./weights/PR_curve.png") + plot_curve(px, f1, names, save_dir="./weights/F1_curve.png", y_label="F1") + plot_curve(px, p, names, save_dir="./weights/P_curve.png", y_label="Precision") + plot_curve(px, r, names, save_dir="./weights/R_curve.png", y_label="Recall") + i = smooth(f1.mean(0), 0.1).argmax() # max F1 index + p, r, f1 = p[:, i], r[:, i], f1[:, i] + tp = (r * nt).round() # true positives + fp = (tp / (p + eps) - tp).round() # false positives + ap50, ap = ap[:, 0], ap.mean(1) # AP@0.5, AP@0.5:0.95 + m_pre, m_rec = p.mean(), r.mean() + map50, mean_ap = ap50.mean(), ap.mean() + return tp, fp, m_pre, m_rec, map50, mean_ap + + +def compute_iou(box1, box2, eps=1e-7): + # Returns Intersection over Union (IoU) of box1(1,4) to box2(n,4) + + # Get the coordinates of bounding boxes + b1_x1, b1_y1, b1_x2, b1_y2 = box1.chunk(4, -1) + b2_x1, b2_y1, b2_x2, b2_y2 = box2.chunk(4, -1) + w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps + w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps + + # Intersection area + inter = (b1_x2.minimum(b2_x2) - b1_x1.maximum(b2_x1)).clamp(0) * ( + b1_y2.minimum(b2_y2) - b1_y1.maximum(b2_y1) + ).clamp(0) + + # Union Area + union = w1 * h1 + w2 * h2 - inter + eps + + # IoU + iou = inter / union + cw = b1_x2.maximum(b2_x2) - b1_x1.minimum(b2_x1) # convex 
(smallest enclosing box) width + ch = b1_y2.maximum(b2_y2) - b1_y1.minimum(b2_y1) # convex height + c2 = cw**2 + ch**2 + eps # convex diagonal squared + rho2 = ((b2_x1 + b2_x2 - b1_x1 - b1_x2) ** 2 + (b2_y1 + b2_y2 - b1_y1 - b1_y2) ** 2) / 4 # center dist ** 2 + # https://github.com/Zzh-tju/DIoU-SSD-pytorch/blob/master/utils/box/box_utils.py#L47 + v = (4 / math.pi**2) * (torch.atan(w2 / h2) - torch.atan(w1 / h1)).pow(2) + with torch.no_grad(): + alpha = v / (v - iou + (1 + eps)) + return iou - (rho2 / c2 + v * alpha) # CIoU + + +def strip_optimizer(filename): + x = torch.load(filename, map_location="cpu", weights_only=False) + x["model"].half() # to FP16 + for p in x["model"].parameters(): + p.requires_grad = False + torch.save(x, f=filename) + + +def clip_gradients(model, max_norm=10.0): + parameters = model.parameters() + torch.nn.utils.clip_grad_norm_(parameters, max_norm=max_norm) + + +def load_weight(model, ckpt): + dst = model.state_dict() + src = torch.load(ckpt, weights_only=False)["model"].float().cpu() + + ckpt = {} + for k, v in src.state_dict().items(): + if k in dst and v.shape == dst[k].shape: + ckpt[k] = v + + model.load_state_dict(state_dict=ckpt, strict=False) + return model + + +def set_params(model, decay): + p1 = [] + p2 = [] + norm = tuple(v for k, v in torch.nn.__dict__.items() if "Norm" in k) + for m in model.modules(): + for n, p in m.named_parameters(recurse=0): + if not p.requires_grad: + continue + if n == "bias": # bias (no decay) + p1.append(p) + elif n == "weight" and isinstance(m, norm): # norm-weight (no decay) + p1.append(p) + else: + p2.append(p) # weight (with decay) + return [{"params": p1, "weight_decay": 0.00}, {"params": p2, "weight_decay": decay}] + + +def plot_lr(args, optimizer, scheduler, num_steps): + from matplotlib import pyplot + + optimizer = copy.copy(optimizer) + scheduler = copy.copy(scheduler) + + y = [] + for epoch in range(args.epochs): + for i in range(num_steps): + step = i + num_steps * epoch + scheduler.step(step, optimizer) + y.append(optimizer.param_groups[0]["lr"]) + pyplot.plot(y, ".-", label="LR") + pyplot.xlabel("step") + pyplot.ylabel("LR") + pyplot.grid() + pyplot.xlim(0, args.epochs * num_steps) + pyplot.ylim(0) + pyplot.savefig("./weights/lr.png", dpi=200) + pyplot.close() + + +class CosineLR: + def __init__(self, args, params, num_steps): + max_lr = params["max_lr"] + min_lr = params["min_lr"] + + warmup_steps = int(max(params["warmup_epochs"] * num_steps, 100)) + decay_steps = int(args.epochs * num_steps - warmup_steps) + + warmup_lr = numpy.linspace(min_lr, max_lr, int(warmup_steps)) + + decay_lr = [] + for step in range(1, decay_steps + 1): + alpha = math.cos(math.pi * step / decay_steps) + decay_lr.append(min_lr + 0.5 * (max_lr - min_lr) * (1 + alpha)) + + self.total_lr = numpy.concatenate((warmup_lr, decay_lr)) + + def step(self, step, optimizer): + for param_group in optimizer.param_groups: + param_group["lr"] = self.total_lr[step] + + +class LinearLR: + def __init__(self, args, params, num_steps): + max_lr = params["max_lr"] + min_lr = params["min_lr"] + + warmup_steps = int(max(params["warmup_epochs"] * num_steps, 100)) + decay_steps = max(1, int(args.epochs * num_steps - warmup_steps)) + + warmup_lr = numpy.linspace(min_lr, max_lr, int(warmup_steps), endpoint=False) + decay_lr = numpy.linspace(max_lr, min_lr, decay_steps) + + self.total_lr = numpy.concatenate((warmup_lr, decay_lr)) + + def step(self, step, optimizer): + for param_group in optimizer.param_groups: + param_group["lr"] = self.total_lr[step] + + 
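+# Example (sketch): CosineLR/LinearLR are stepped once per optimizer update, not per
+# epoch. Assuming `args.epochs`, a `params` dict holding "min_lr", "max_lr",
+# "warmup_epochs", "momentum" and "weight_decay", and `num_steps = len(loader)`:
+#
+#   optimizer = torch.optim.SGD(set_params(model, params["weight_decay"]),
+#                               lr=params["min_lr"], momentum=params["momentum"],
+#                               nesterov=True)
+#   scheduler = CosineLR(args, params, num_steps)
+#   for epoch in range(args.epochs):
+#       for i, batch in enumerate(loader):
+#           scheduler.step(i + num_steps * epoch, optimizer)
+#           ...  # forward / backward / optimizer.step()
+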
+class EMA: + """ + Updated Exponential Moving Average (EMA) from https://github.com/rwightman/pytorch-image-models + Keeps a moving average of everything in the model state_dict (parameters and buffers) + For EMA details see https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage + """ + + def __init__(self, model, decay=0.9999, tau=2000, updates=0): + # Create EMA + self.ema = copy.deepcopy(model).eval() # FP32 EMA + self.updates = updates # number of EMA updates + # decay exponential ramp (to help early epochs) + self.decay = lambda x: decay * (1 - math.exp(-x / tau)) + for p in self.ema.parameters(): + p.requires_grad_(False) + + def update(self, model): + if hasattr(model, "module"): + model = model.module + # Update EMA parameters + with torch.no_grad(): + self.updates += 1 + d = self.decay(self.updates) + + msd = model.state_dict() # model state_dict + for k, v in self.ema.state_dict().items(): + if v.dtype.is_floating_point: + v *= d + v += (1 - d) * msd[k].detach() + + +class AverageMeter: + def __init__(self): + self.num = 0 + self.sum = 0 + self.avg = 0 + + def update(self, v, n): + if not math.isnan(float(v)): + self.num = self.num + n + self.sum = self.sum + v * n + self.avg = self.sum / self.num + + +class Assigner(torch.nn.Module): + def __init__(self, nc=80, top_k=13, alpha=1.0, beta=6.0, eps=1e-9): + super().__init__() + self.top_k = top_k + self.nc = nc + self.alpha = alpha + self.beta = beta + self.eps = eps + + @torch.no_grad() + def forward(self, pd_scores, pd_bboxes, anc_points, gt_labels, gt_bboxes, mask_gt): + batch_size = pd_scores.size(0) + num_max_boxes = gt_bboxes.size(1) + + if num_max_boxes == 0: + device = gt_bboxes.device + return ( + torch.zeros_like(pd_bboxes).to(device), + torch.zeros_like(pd_scores).to(device), + torch.zeros_like(pd_scores[..., 0]).to(device), + ) + + num_anchors = anc_points.shape[0] + shape = gt_bboxes.shape + lt, rb = gt_bboxes.view(-1, 1, 4).chunk(2, 2) + mask_in_gts = torch.cat((anc_points[None] - lt, rb - anc_points[None]), dim=2) + mask_in_gts = mask_in_gts.view(shape[0], shape[1], num_anchors, -1).amin(3).gt_(self.eps) + na = pd_bboxes.shape[-2] + gt_mask = (mask_in_gts * mask_gt).bool() # b, max_num_obj, h*w + overlaps = torch.zeros( + [batch_size, num_max_boxes, na], + dtype=pd_bboxes.dtype, + device=pd_bboxes.device, + ) + bbox_scores = torch.zeros( + [batch_size, num_max_boxes, na], + dtype=pd_scores.dtype, + device=pd_scores.device, + ) + + ind = torch.zeros([2, batch_size, num_max_boxes], dtype=torch.long) # 2, b, max_num_obj + ind[0] = torch.arange(end=batch_size).view(-1, 1).expand(-1, num_max_boxes) # b, max_num_obj + ind[1] = gt_labels.squeeze(-1) # b, max_num_obj + bbox_scores[gt_mask] = pd_scores[ind[0], :, ind[1]][gt_mask] # b, max_num_obj, h*w + + pd_boxes = pd_bboxes.unsqueeze(1).expand(-1, num_max_boxes, -1, -1)[gt_mask] + gt_boxes = gt_bboxes.unsqueeze(2).expand(-1, -1, na, -1)[gt_mask] + overlaps[gt_mask] = compute_iou(gt_boxes, pd_boxes).squeeze(-1).clamp_(0) + + align_metric = bbox_scores.pow(self.alpha) * overlaps.pow(self.beta) + + top_k_mask = mask_gt.expand(-1, -1, self.top_k).bool() + top_k_metrics, top_k_indices = torch.topk(align_metric, self.top_k, dim=-1, largest=True) + if top_k_mask is None: + top_k_mask = (top_k_metrics.max(-1, keepdim=True)[0] > self.eps).expand_as(top_k_indices) + top_k_indices.masked_fill_(~top_k_mask, 0) + + mask_top_k = torch.zeros(align_metric.shape, dtype=torch.int8, device=top_k_indices.device) + ones = torch.ones_like(top_k_indices[:, :, :1], 
dtype=torch.int8, device=top_k_indices.device) + for k in range(self.top_k): + mask_top_k.scatter_add_(-1, top_k_indices[:, :, k : k + 1], ones) + mask_top_k.masked_fill_(mask_top_k > 1, 0) + mask_top_k = mask_top_k.to(align_metric.dtype) + mask_pos = mask_top_k * mask_in_gts * mask_gt + + fg_mask = mask_pos.sum(-2) + if fg_mask.max() > 1: + mask_multi_gts = (fg_mask.unsqueeze(1) > 1).expand(-1, num_max_boxes, -1) + max_overlaps_idx = overlaps.argmax(1) + + is_max_overlaps = torch.zeros(mask_pos.shape, dtype=mask_pos.dtype, device=mask_pos.device) + is_max_overlaps.scatter_(1, max_overlaps_idx.unsqueeze(1), 1) + + mask_pos = torch.where(mask_multi_gts, is_max_overlaps, mask_pos).float() + fg_mask = mask_pos.sum(-2) + target_gt_idx = mask_pos.argmax(-2) + + # Assigned target + index = torch.arange(end=batch_size, dtype=torch.int64, device=gt_labels.device)[..., None] + target_index = target_gt_idx + index * num_max_boxes + target_labels = gt_labels.long().flatten()[target_index] + + target_bboxes = gt_bboxes.view(-1, gt_bboxes.shape[-1])[target_index] + + # Assigned target scores + target_labels.clamp_(0) + + target_scores = torch.zeros( + (target_labels.shape[0], target_labels.shape[1], self.nc), + dtype=torch.int64, + device=target_labels.device, + ) + target_scores.scatter_(2, target_labels.unsqueeze(-1), 1) + + fg_scores_mask = fg_mask[:, :, None].repeat(1, 1, self.nc) + target_scores = torch.where(fg_scores_mask > 0, target_scores, 0) + + # Normalize + align_metric *= mask_pos + pos_align_metrics = align_metric.amax(dim=-1, keepdim=True) + pos_overlaps = (overlaps * mask_pos).amax(dim=-1, keepdim=True) + norm_align_metric = (align_metric * pos_overlaps / (pos_align_metrics + self.eps)).amax(-2).unsqueeze(-1) + target_scores = target_scores * norm_align_metric + + return target_bboxes, target_scores, fg_mask.bool() + + +class QFL(torch.nn.Module): + def __init__(self, beta=2.0): + super().__init__() + self.beta = beta + self.bce_loss = torch.nn.BCEWithLogitsLoss(reduction="none") + + def forward(self, outputs, targets): + bce_loss = self.bce_loss(outputs, targets) + return torch.pow(torch.abs(targets - outputs.sigmoid()), self.beta) * bce_loss + + +class VFL(torch.nn.Module): + def __init__(self, alpha=0.75, gamma=2.00, iou_weighted=True): + super().__init__() + assert alpha >= 0.0 + self.alpha = alpha + self.gamma = gamma + self.iou_weighted = iou_weighted + self.bce_loss = torch.nn.BCEWithLogitsLoss(reduction="none") + + def forward(self, outputs, targets): + assert outputs.size() == targets.size() + targets = targets.type_as(outputs) + + if self.iou_weighted: + focal_weight = ( + targets * (targets > 0.0).float() + + self.alpha * (outputs.sigmoid() - targets).abs().pow(self.gamma) * (targets <= 0.0).float() + ) + + else: + focal_weight = (targets > 0.0).float() + self.alpha * (outputs.sigmoid() - targets).abs().pow( + self.gamma + ) * (targets <= 0.0).float() + + return self.bce_loss(outputs, targets) * focal_weight + + +class FocalLoss(torch.nn.Module): + def __init__(self, alpha=0.25, gamma=1.5): + super().__init__() + self.alpha = alpha + self.gamma = gamma + self.bce_loss = torch.nn.BCEWithLogitsLoss(reduction="none") + + def forward(self, outputs, targets): + loss = self.bce_loss(outputs, targets) + + if self.alpha > 0: + alpha_factor = targets * self.alpha + (1 - targets) * (1 - self.alpha) + loss *= alpha_factor + + if self.gamma > 0: + outputs_sigmoid = outputs.sigmoid() + p_t = targets * outputs_sigmoid + (1 - targets) * (1 - outputs_sigmoid) + gamma_factor = (1.0 - p_t) ** 
self.gamma + loss *= gamma_factor + + return loss + + +class BoxLoss(torch.nn.Module): + def __init__(self, dfl_ch): + super().__init__() + self.dfl_ch = dfl_ch + + def forward( + self, + pred_dist, + pred_bboxes, + anchor_points, + target_bboxes, + target_scores, + target_scores_sum, + fg_mask, + ): + # IoU loss + weight = torch.masked_select(target_scores.sum(-1), fg_mask).unsqueeze(-1) + iou = compute_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask]) + loss_box = ((1.0 - iou) * weight).sum() / target_scores_sum + + # DFL loss + a, b = target_bboxes.chunk(2, -1) + target = torch.cat((anchor_points - a, b - anchor_points), -1) + target = target.clamp(0, self.dfl_ch - 0.01) + loss_dfl = self.df_loss(pred_dist[fg_mask].view(-1, self.dfl_ch + 1), target[fg_mask]) + loss_dfl = (loss_dfl * weight).sum() / target_scores_sum + + return loss_box, loss_dfl + + @staticmethod + def df_loss(pred_dist, target): + # Distribution Focal Loss (DFL) + # https://ieeexplore.ieee.org/document/9792391 + tl = target.long() # target left + tr = tl + 1 # target right + wl = tr - target # weight left + wr = 1 - wl # weight right + left_loss = cross_entropy(pred_dist, tl.view(-1), reduction="none").view(tl.shape) + right_loss = cross_entropy(pred_dist, tr.view(-1), reduction="none").view(tl.shape) + return (left_loss * wl + right_loss * wr).mean(-1, keepdim=True) + + +class ComputeLoss: + def __init__(self, model, params): + if hasattr(model, "module"): + model = model.module + + device = next(model.parameters()).device + + m = model.head # Head() module + + self.params = params + self.stride = m.stride + self.nc = m.nc + self.no = m.no + self.reg_max = m.ch + self.device = device + + self.box_loss = BoxLoss(m.ch - 1).to(device) + self.cls_loss = torch.nn.BCEWithLogitsLoss(reduction="none") + self.assigner = Assigner(nc=self.nc, top_k=10, alpha=0.5, beta=6.0) + + self.project = torch.arange(m.ch, dtype=torch.float, device=device) + + def box_decode(self, anchor_points, pred_dist): + b, a, c = pred_dist.shape + pred_dist = pred_dist.view(b, a, 4, c // 4) + pred_dist = pred_dist.softmax(3) + pred_dist = pred_dist.matmul(self.project.type(pred_dist.dtype)) + lt, rb = pred_dist.chunk(2, -1) + x1y1 = anchor_points - lt + x2y2 = anchor_points + rb + return torch.cat(tensors=(x1y1, x2y2), dim=-1) + + def __call__(self, outputs, targets): + x = torch.cat([i.view(outputs[0].shape[0], self.no, -1) for i in outputs], dim=2) + pred_distri, pred_scores = x.split(split_size=(self.reg_max * 4, self.nc), dim=1) + + pred_scores = pred_scores.permute(0, 2, 1).contiguous() + pred_distri = pred_distri.permute(0, 2, 1).contiguous() + + data_type = pred_scores.dtype + batch_size = pred_scores.shape[0] + input_size = torch.tensor(outputs[0].shape[2:], device=self.device, dtype=data_type) * self.stride[0] + anchor_points, stride_tensor = make_anchors(outputs, self.stride, offset=0.5) + + idx = targets["idx"].view(-1, 1) + cls = targets["cls"].view(-1, 1) + box = targets["box"] + + targets = torch.cat((idx, cls, box), dim=1).to(self.device) + if targets.shape[0] == 0: + gt = torch.zeros(batch_size, 0, 5, device=self.device) + else: + i = targets[:, 0] + _, counts = i.unique(return_counts=True) + counts = counts.to(dtype=torch.int32) + gt = torch.zeros(batch_size, counts.max(), 5, device=self.device) + for j in range(batch_size): + matches = i == j + n = matches.sum() + if n: + gt[j, :n] = targets[matches, 1:] + x = gt[..., 1:5].mul_(input_size[[1, 0, 1, 0]]) + y = torch.empty_like(x) + dw = x[..., 2] / 2 # half-width + dh = x[..., 3] / 2 # 
half-height + y[..., 0] = x[..., 0] - dw # top left x + y[..., 1] = x[..., 1] - dh # top left y + y[..., 2] = x[..., 0] + dw # bottom right x + y[..., 3] = x[..., 1] + dh # bottom right y + gt[..., 1:5] = y + gt_labels, gt_bboxes = gt.split((1, 4), 2) + mask_gt = gt_bboxes.sum(2, keepdim=True).gt_(0) + + pred_bboxes = self.box_decode(anchor_points, pred_distri) + assigned_targets = self.assigner( + pred_scores.detach().sigmoid(), + (pred_bboxes.detach() * stride_tensor).type(gt_bboxes.dtype), + anchor_points * stride_tensor, + gt_labels, + gt_bboxes, + mask_gt, + ) + target_bboxes, target_scores, fg_mask = assigned_targets + + target_scores_sum = max(target_scores.sum(), 1) + + loss_cls = self.cls_loss(pred_scores, target_scores.to(data_type)).sum() / target_scores_sum # BCE + + # Box loss + loss_box = torch.zeros(1, device=self.device) + loss_dfl = torch.zeros(1, device=self.device) + if fg_mask.sum(): + target_bboxes /= stride_tensor + loss_box, loss_dfl = self.box_loss( + pred_distri, + pred_bboxes, + anchor_points, + target_bboxes, + target_scores, + target_scores_sum, + fg_mask, + ) + + loss_box *= self.params["box"] # box gain + loss_cls *= self.params["cls"] # cls gain + loss_dfl *= self.params["dfl"] # dfl gain + + return loss_box, loss_cls, loss_dfl
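+
+
+# Example (sketch): wiring ComputeLoss into a local training step. Assumes a model
+# from nets.nn exposing a `head` module, a `params` dict with "box"/"cls"/"dfl" gains
+# (the gains are applied inside __call__), and a batch dict `targets` providing "idx"
+# (image index per box), "cls" and "box" (normalized cx, cy, w, h) tensors:
+#
+#   criterion = ComputeLoss(model, params)
+#   outputs = model(images)  # training-mode head outputs (one tensor per stride)
+#   loss_box, loss_cls, loss_dfl = criterion(outputs, targets)
+#   (loss_box + loss_cls + loss_dfl).backward()
+#   clip_gradients(model)    # defaults to max_norm=10.0
+#   optimizer.step()
+#   optimizer.zero_grad()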