Compare commits

..

7 Commits

12 changed files with 504 additions and 468 deletions

3
.gitignore vendored
View File

@@ -296,5 +296,8 @@ Network Trash Folder
Temporary Items
.apdisk
# ---> Custom
results/
*.log
*.txt
weights/

View File

@@ -22,3 +22,9 @@ nohup python fed_run.py > train.log 2>&1 &
- Add more YOLO versions (e.g., YOLOv8, YOLOv5, etc.)
- Implement YOLOv8
- Implement YOLOv5
# references
[PyTorch Federated Learning](https://github.com/rruisong/pytorch_federated_learning)
[YOLOv11-pt](https://github.com/jahongir7174/YOLOv11-pt)

View File

@@ -3,11 +3,11 @@ import torch
from torch import nn
from torch.utils import data
from torch.amp.autocast_mode import autocast
from tqdm import tqdm
from utils.fed_util import init_model
from utils import util
from utils.dataset import Dataset
from typing import cast
from tqdm import tqdm
class FedYoloClient(object):
@@ -82,52 +82,48 @@ class FedYoloClient(object):
# load the global model parameters
self.model.load_state_dict(Global_model_state_dict, strict=True)
def train(self, args):
def train(self, args) -> tuple[dict[str, torch.Tensor], int, float]:
"""
Train the local model
Args:
:param args: Command line arguments
- local_rank: Local rank for distributed training
- world_size: World size for distributed training
- distributed: Whether to use distributed training
- input_size: Input size for the model
Returns:
:return: Local updated model, number of local data points, training loss
Train the local model.
Returns: (state_dict, n_data, avg_loss_per_image)
"""
# ---- Dist init (if any) ----
if args.distributed:
torch.cuda.set_device(device=args.local_rank)
torch.distributed.init_process_group(backend="nccl", init_method="env://")
# print(f"Client {self.name} - distributed training on {world_size} GPUs, local rank: {local_rank}")
# self._device = torch.device("cuda", local_rank)
if args.local_rank == 0:
pass
# if not os.path.exists("weights"):
# os.makedirs("weights")
util.setup_seed()
util.setup_multi_processes()
# model
# init model have been done in __init__()
self.model.to(self._device)
# device = torch.device(f"cuda:{args.local_rank}" if torch.cuda.is_available() else "cpu")
# self.model.to(device)
self.model.cuda()
# show model architecture
# print(self.model)
# Optimizer
accumulate = max(round(64 / (self._batch_size * args.world_size)), 1)
self._weight_decay = self._batch_size * args.world_size * accumulate / 64
# ---- Optimizer / WD scaling & LR warmup/schedule ----
# accumulate = effective grad-accumulation steps to emulate global batch 64
world_size = getattr(args, "world_size", 1)
accumulate = max(round(64 / (self._batch_size * max(world_size, 1))), 1)
# scale weight_decay like YOLO recipes
scaled_wd = self._weight_decay * self._batch_size * max(world_size, 1) * accumulate / 64
optimizer = torch.optim.SGD(
util.set_params(self.model, self._weight_decay),
util.set_params(self.model, scaled_wd),
lr=self._min_lr,
momentum=self._momentum,
nesterov=True,
)
# EMA
# ---- EMA (track the underlying module if DDP) ----
# track_model = self.model.module if is_ddp else self.model
ema = util.EMA(self.model) if args.local_rank == 0 else None
data_set = Dataset(
print(type(self.train_dataset))
# ---- Data ----
dataset = Dataset(
filenames=self.train_dataset,
input_size=args.input_size,
params=self.params,
@@ -136,26 +132,28 @@ class FedYoloClient(object):
if args.distributed:
train_sampler = data.DistributedSampler(
data_set, num_replicas=args.world_size, rank=args.local_rank, shuffle=True
dataset, num_replicas=args.world_size, rank=args.local_rank, shuffle=True
)
else:
train_sampler = None
loader = data.DataLoader(
data_set,
dataset,
batch_size=self._batch_size,
shuffle=train_sampler is None,
shuffle=(train_sampler is None),
sampler=train_sampler,
num_workers=self.num_workers,
pin_memory=True,
collate_fn=Dataset.collate_fn,
drop_last=False,
)
# Scheduler
num_steps = max(1, len(loader))
scheduler = util.LinearLR(args=args, params=self.params, num_steps=num_steps)
# DDP mode
if args.distributed:
# ---- SyncBN + DDP (if any) ----
is_ddp = bool(args.distributed)
if is_ddp:
self.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.model)
self.model = nn.parallel.DistributedDataParallel(
module=self.model,
@@ -164,102 +162,133 @@ class FedYoloClient(object):
find_unused_parameters=False,
)
amp_scale = torch.amp.grad_scaler.GradScaler(enabled=True)
# ---- AMP + loss ----
scaler = torch.amp.grad_scaler.GradScaler(enabled=True)
# criterion = util.ComputeLoss(
# self.model.module if isinstance(self.model, nn.parallel.DistributedDataParallel) else self.model,
# self.params,
# )
criterion = util.ComputeLoss(self.model, self.params)
# log
# if args.local_rank == 0:
# header = ("%10s" * 5) % ("client", "memory", "box", "cls", "dfl")
# print("\n" + header)
# p_bar = tqdm(total=args.epochs * num_steps, ncols=120)
# p_bar.set_description(f"{self.name:>10}")
# ---- Training ----
for epoch in range(args.epochs):
# (self.model.module if isinstance(self.model, nn.parallel.DistributedDataParallel) else self.model).train()
self.model.train()
# when distributed, set epoch for shuffling
if args.distributed and train_sampler is not None:
if is_ddp and train_sampler is not None:
train_sampler.set_epoch(epoch)
if args.epochs - epoch == 10:
# disable mosaic augmentation in the last 10 epochs
# disable mosaic in the last 10 epochs (if dataset supports it)
if args.epochs - epoch == 10 and hasattr(loader.dataset, "mosaic"):
ds = cast(Dataset, loader.dataset)
ds.mosaic = False
optimizer.zero_grad(set_to_none=True)
avg_box_loss = util.AverageMeter()
avg_cls_loss = util.AverageMeter()
avg_dfl_loss = util.AverageMeter()
loss_box_meter = util.AverageMeter()
loss_cls_meter = util.AverageMeter()
loss_dfl_meter = util.AverageMeter()
# # --- header (once per epoch, YOLO-style) ---
# if args.local_rank == 0:
# header = ("%10s" * 5) % ("client", "memory", "box", "cls", "dfl")
# print("\n" + header)
for i, (images, targets) in enumerate(loader):
print(f"Client {self.name} - Epoch {epoch + 1}/{args.epochs} - Step {i + 1}/{num_steps}")
step = i + epoch * num_steps
# p_bar = enumerate(loader)
# if args.local_rank == 0:
# p_bar = tqdm(p_bar, total=num_steps, ncols=120)
# scheduler per-step (your util.LinearLR expects step)
scheduler.step(step=step, optimizer=optimizer)
for i, (samples, targets) in enumerate(loader):
global_step = i + num_steps * epoch
scheduler.step(step=global_step, optimizer=optimizer)
# images = images.to(device, non_blocking=True).float() / 255.0
images = images.cuda().float() / 255.0
bs = images.size(0)
# total_imgs_seen += bs
samples = samples.cuda(non_blocking=True).float() / 255.0
# targets: keep as your ComputeLoss expects (often CPU lists/tensors).
# Move to GPU here only if your loss requires it.
# Forward
with autocast("cuda", enabled=True):
outputs = self.model(samples)
with autocast(device_type="cuda", enabled=True):
outputs = self.model(images) # DDP wraps forward
box_loss, cls_loss, dfl_loss = criterion(outputs, targets)
# meters (use the *unscaled* values)
bs = samples.size(0)
avg_box_loss.update(box_loss.item(), bs)
avg_cls_loss.update(cls_loss.item(), bs)
avg_dfl_loss.update(dfl_loss.item(), bs)
# total_loss = box_loss + cls_loss + dfl_loss
# Gradient accumulation: normalize by 'accumulate' so LR stays effective
# total_loss = total_loss / accumulate
# scale losses by batch/world if your loss is averaged internally per-sample/device
# box_loss = box_loss * self._batch_size * args.world_size
# cls_loss = cls_loss * self._batch_size * args.world_size
# dfl_loss = dfl_loss * self._batch_size * args.world_size
# IMPORTANT: assume criterion returns **average per image** in the batch.
# Keep logging on the true (unscaled) values:
loss_box_meter.update(box_loss.item(), bs)
loss_cls_meter.update(cls_loss.item(), bs)
loss_dfl_meter.update(dfl_loss.item(), bs)
box_loss *= self._batch_size
cls_loss *= self._batch_size
dfl_loss *= self._batch_size
box_loss *= args.world_size
cls_loss *= args.world_size
dfl_loss *= args.world_size
total_loss = box_loss + cls_loss + dfl_loss
# Backward
amp_scale.scale(total_loss).backward()
scaler.scale(total_loss).backward()
# Optimize
if (i + 1) % accumulate == 0:
amp_scale.unscale_(optimizer) # unscale gradients
util.clip_gradients(model=self.model, max_norm=10.0) # clip gradients
amp_scale.step(optimizer)
amp_scale.update()
# optimize
if step % accumulate == 0:
# scaler.unscale_(optimizer)
# util.clip_gradients(self.model)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)
if ema:
ema.update(self.model)
# torch.cuda.synchronize()
# tqdm update
# if args.local_rank == 0:
# mem = f"{torch.cuda.memory_reserved() / 1e9:.2f}G" if torch.cuda.is_available() else "0.00G"
# desc = ("%10s" * 2 + "%10.4g" * 3) % (
# self.name,
# mem,
# avg_box_loss.avg,
# avg_cls_loss.avg,
# avg_dfl_loss.avg,
# # Step when we have 'accumulate' micro-batches, or at the end
# if ((i + 1) % accumulate == 0) or (i + 1 == len(loader)):
# scaler.unscale_(optimizer)
# util.clip_gradients(
# model=(
# self.model.module
# if isinstance(self.model, nn.parallel.DistributedDataParallel)
# else self.model
# ),
# max_norm=10.0,
# )
# cast(tqdm, p_bar).set_description(desc)
# p_bar.update(1)
# scaler.step(optimizer)
# scaler.update()
# optimizer.zero_grad(set_to_none=True)
# p_bar.close()
if ema:
# Update EMA from the underlying module
ema.update(
self.model.module
if isinstance(self.model, nn.parallel.DistributedDataParallel)
else self.model
)
# print loss to test
print(
f"loss: {total_loss.item() * accumulate:.4f}, box: {box_loss.item():.4f}, cls: {cls_loss.item():.4f}, dfl: {dfl_loss.item():.4f}"
)
torch.cuda.synchronize()
# clean
if args.distributed:
# ---- Final average loss (per image) over the whole epoch span ----
avg_loss_per_image = loss_box_meter.avg + loss_cls_meter.avg + loss_dfl_meter.avg
# ---- Cleanup DDP ----
if is_ddp:
torch.distributed.destroy_process_group()
torch.cuda.empty_cache()
return (
self.model.state_dict() if not ema else ema.ema.state_dict(),
self.n_data,
{"box_loss": avg_box_loss.avg, "cls_loss": avg_cls_loss.avg, "dfl_loss": avg_dfl_loss.avg},
)
# ---- Choose which weights to return ----
# - If EMA exists, return EMA weights (common YOLO eval practice)
# - Be careful with DDP: grab state_dict from the underlying module / EMA model
if ema:
# print("Using EMA weights")
return (ema.ema.state_dict(), self.n_data, avg_loss_per_image)
else:
# Safely get the underlying module if wrapped by DDP; getattr returns the module or the original object.
model_obj = getattr(self.model, "module", self.model)
# If it's a proper nn.Module, call state_dict(); if it's already a state dict, use it;
# otherwise try to call state_dict() and finally fall back to wrapping the object.
if isinstance(model_obj, torch.nn.Module):
model_to_return = model_obj.state_dict()
elif isinstance(model_obj, dict):
model_to_return = model_obj
else:
try:
model_to_return = model_obj.state_dict()
except Exception:
# fallback: if model_obj is a tensor or unexpected object, wrap it in a dict
model_to_return = {"state": model_obj}
return model_to_return, self.n_data, avg_loss_per_image

View File

@@ -4,6 +4,7 @@ from torch.utils.data import DataLoader
from utils.fed_util import init_model
from utils.dataset import Dataset
from utils import util
from nets import YOLO
class FedYoloServer(object):
@@ -21,7 +22,7 @@ class FedYoloServer(object):
self.client_n_data = {}
self.selected_clients = []
self._batch_size = params.get("val_batch_size", 4)
self._batch_size = params.get("val_batch_size", 200)
self.client_list = client_list
self.valset = None
@@ -40,7 +41,7 @@ class FedYoloServer(object):
self.model = init_model(model_name, self._num_classes)
self.params = params
def load_valset(self, valset):
def load_valset(self, valset: Dataset):
"""Server loads the validation dataset."""
self.valset = valset
@@ -48,78 +49,6 @@ class FedYoloServer(object):
"""Return global model weights."""
return self.model.state_dict()
@torch.no_grad()
def test(self, args) -> dict:
"""
Test the global model on the server's validation set.
Returns:
dict with keys: mAP, mAP50, precision, recall
"""
if self.valset is None:
return {}
loader = DataLoader(
self.valset,
batch_size=self._batch_size,
shuffle=False,
num_workers=4,
pin_memory=True,
collate_fn=Dataset.collate_fn,
)
dev = self._device
# move to device for eval; keep in float32 for stability
self.model.eval().to(dev).float()
iou_v = torch.linspace(0.5, 0.95, 10, device=dev)
n_iou = iou_v.numel()
metrics = []
for samples, targets in loader:
samples = samples.to(dev, non_blocking=True).float() / 255.0
_, _, h, w = samples.shape
scale = torch.tensor((w, h, w, h), device=dev)
outputs = self.model(samples)
outputs = util.non_max_suppression(outputs)
for i, output in enumerate(outputs):
idx = targets["idx"] == i
cls = targets["cls"][idx].to(dev)
box = targets["box"][idx].to(dev)
metric = torch.zeros((output.shape[0], n_iou), dtype=torch.bool, device=dev)
if output.shape[0] == 0:
if cls.shape[0]:
metrics.append((metric, *torch.zeros((2, 0), device=dev), cls.squeeze(-1)))
continue
if cls.shape[0]:
if cls.dim() == 1:
cls = cls.unsqueeze(1)
box_xy = util.wh2xy(box)
if not isinstance(box_xy, torch.Tensor):
box_xy = torch.tensor(box_xy, device=dev)
target = torch.cat((cls, box_xy * scale), dim=1)
metric = util.compute_metric(output[:, :6], target, iou_v)
metrics.append((metric, output[:, 4], output[:, 5], cls.squeeze(-1)))
if not metrics:
# move back to CPU before returning
self.model.to("cpu").float()
return {"mAP": 0, "mAP50": 0, "precision": 0, "recall": 0}
metrics = [torch.cat(x, dim=0).cpu().numpy() for x in zip(*metrics)]
if len(metrics) and metrics[0].any():
_, _, prec, rec, map50, mean_ap = util.compute_ap(*metrics, names=self.params["names"], plot=False)
else:
prec, rec, map50, mean_ap = 0, 0, 0, 0
# return model to CPU so next agg() stays device-consistent
self.model.to("cpu").float()
return {"mAP": float(mean_ap), "mAP50": float(map50), "precision": float(prec), "recall": float(rec)}
def select_clients(self, connection_ratio=1.0):
"""
Randomly select a fraction of clients.
@@ -130,80 +59,69 @@ class FedYoloServer(object):
self.n_data = 0
for client_id in self.client_list:
# Random selection based on connection ratio
if np.random.rand() <= connection_ratio:
s = np.random.binomial(np.ones(1).astype(int), connection_ratio)
if s[0] == 1:
self.selected_clients.append(client_id)
self.n_data += self.client_n_data.get(client_id, 0)
self.n_data += self.client_n_data[client_id]
@torch.no_grad()
def agg(self):
"""Aggregate client updates (FedAvg) on CPU/FP32, preserving non-float buffers."""
"""
Server aggregates the local updates from selected clients using FedAvg.
:return: model_state: aggregated model weights
:return: avg_loss: weighted average training loss across selected clients
:return: n_data: total number of data points across selected clients
"""
if len(self.selected_clients) == 0 or self.n_data == 0:
return self.model.state_dict(), {}, 0
import warnings
# Ensure global model is on CPU for safe load later
self.model.to("cpu")
global_state = self.model.state_dict() # may hold CPU or CUDA refs; were on CPU now
warnings.warn("No clients selected or no data available for aggregation.")
return self.model.state_dict(), 0, 0
avg_loss = {}
total_n = float(self.n_data)
# Initialize a model for aggregation
model = init_model(model_name=self.model_name, num_classes=self._num_classes)
model_state = model.state_dict()
# Prepare accumulators on CPU. For floating tensors, use float32 zeros.
# For non-floating tensors (e.g., BN num_batches_tracked int64), well copy from the first client.
new_state = {}
first_client = None
for cid in self.selected_clients:
if cid in self.client_state:
first_client = cid
break
avg_loss = 0
assert first_client is not None, "No client states available to aggregate."
for k, v in global_state.items():
if v.is_floating_point():
new_state[k] = torch.zeros_like(v.detach().cpu(), dtype=torch.float32)
else:
# For non-float buffers, just copy from the first client (or keep global)
new_state[k] = self.client_state[first_client][k].clone()
# Accumulate floating tensors with weights; keep non-floats as assigned above
for cid in self.selected_clients:
if cid not in self.client_state:
# Aggregate the local updated models from selected clients
for i, name in enumerate(self.selected_clients):
if name not in self.client_state:
continue
weight = self.client_n_data[cid] / total_n
cst = self.client_state[cid]
for k in new_state.keys():
if new_state[k].is_floating_point():
# cst[k] is CPU; ensure float32 for accumulation
new_state[k].add_(cst[k].to(torch.float32), alpha=weight)
# weighted average losses
for lk, lv in self.client_loss[cid].items():
avg_loss[lk] = avg_loss.get(lk, 0.0) + float(lv) * weight
# Load aggregated state back into the global model (model is on CPU)
with torch.no_grad():
self.model.load_state_dict(new_state, strict=True)
for key in self.client_state[name]:
if i == 0:
# First client, initialize the model_state
model_state[key] = self.client_state[name][key] * (self.client_n_data[name] / self.n_data)
else:
# math equation: w = sum(n_k / n * w_k)
model_state[key] = model_state[key] + self.client_state[name][key] * (
self.client_n_data[name] / self.n_data
)
avg_loss = avg_loss + self.client_loss[name] * (self.client_n_data[name] / self.n_data)
self.model.load_state_dict(model_state, strict=True)
self.round += 1
# Return CPU state_dict (good for broadcasting to clients)
return {k: v.clone() for k, v in self.model.state_dict().items()}, avg_loss, int(self.n_data)
def rec(self, name, state_dict, n_data, loss_dict):
n_data = self.n_data
return model_state, avg_loss, n_data
def rec(self, name, state_dict, n_data, loss):
"""
Receive local update from a client.
- Store all floating tensors as CPU float32
- Store non-floating tensors (e.g., BN counters) as CPU in original dtype
"""
self.n_data += n_data
safe_state = {}
with torch.no_grad():
for k, v in state_dict.items():
t = v.detach().cpu()
if t.is_floating_point():
t = t.to(torch.float32)
safe_state[k] = t
self.client_state[name] = safe_state
self.client_state[name] = {}
self.client_n_data[name] = {}
self.client_loss[name] = {}
self.client_state[name].update(state_dict)
self.client_n_data[name] = int(n_data)
self.client_loss[name] = {k: float(v) for k, v in loss_dict.items()}
self.client_loss[name] = loss
def flush(self):
"""Clear stored client updates."""
@@ -211,3 +129,94 @@ class FedYoloServer(object):
self.client_state.clear()
self.client_n_data.clear()
self.client_loss.clear()
def test(self):
"""Evaluate the global model on the server's validation dataset."""
if self.valset is None:
import warnings
warnings.warn("No validation dataset available for testing.")
return {}
return test(self.valset, self.params, self.model)
@torch.no_grad()
def test(valset: Dataset, params, model: YOLO, batch_size: int = 200) -> tuple[float, float, float, float]:
"""
Evaluate the model on the validation dataset.
Args:
valset: validation dataset
params: dict of parameters (must include 'names')
model: YOLO model to evaluate
batch_size: batch size for evaluation
Returns:
dict with evaluation metrics (tp, fp, m_pre, m_rec, map50, mean_ap)
"""
loader = DataLoader(
dataset=valset,
batch_size=batch_size,
shuffle=False,
num_workers=4,
pin_memory=True,
collate_fn=Dataset.collate_fn,
)
model.cuda()
model.half()
model.eval()
# Configure
iou_v = torch.linspace(start=0.5, end=0.95, steps=10).cuda() # iou vector for mAP@0.5:0.95
n_iou = iou_v.numel()
m_pre = 0
m_rec = 0
map50 = 0
mean_ap = 0
metrics = []
for samples, targets in loader:
samples = samples.cuda()
samples = samples.half() # uint8 to fp16/32
samples = samples / 255.0 # 0 - 255 to 0.0 - 1.0
_, _, h, w = samples.shape # batch-size, channels, height, width
scale = torch.tensor((w, h, w, h)).cuda()
# Inference
outputs = model(samples)
# NMS
outputs = util.non_max_suppression(outputs)
# Metrics
for i, output in enumerate(outputs):
idx = targets["idx"]
if idx.dim() > 1:
idx = idx.squeeze(-1)
idx = idx == i
# idx = targets["idx"] == i
cls = targets["cls"][idx]
box = targets["box"][idx]
cls = cls.cuda()
box = box.cuda()
metric = torch.zeros(output.shape[0], n_iou, dtype=torch.bool).cuda()
if output.shape[0] == 0:
if cls.shape[0]:
metrics.append((metric, *torch.zeros((2, 0)).cuda(), cls.squeeze(-1)))
continue
# Evaluate
if cls.shape[0]:
target = torch.cat(tensors=(cls, util.wh2xy(box) * scale), dim=1)
metric = util.compute_metric(output[:, :6], target, iou_v)
# Append
metrics.append((metric, output[:, 4], output[:, 5], cls.squeeze(-1)))
# Compute metrics
metrics = [torch.cat(x, dim=0).cpu().numpy() for x in zip(*metrics)] # to numpy
if len(metrics) and metrics[0].any():
tp, fp, m_pre, m_rec, map50, mean_ap = util.compute_ap(*metrics, plot=False, names=params["names"])
# Print results
# print(("%10s" + "%10.3g" * 4) % ("", m_pre, m_rec, map50, mean_ap))
# Return results
model.float() # for training
return mean_ap, map50, m_rec, m_pre

View File

@@ -3,92 +3,16 @@ import os
import json
import yaml
import time
import random
from tqdm import tqdm
import numpy as np
import torch
import matplotlib.pyplot as plt
from utils.dataset import Dataset
from utils.fed_util import build_valset_if_available, seed_everything, plot_curves
from fed_algo_cs.client_base import FedYoloClient
from fed_algo_cs.server_base import FedYoloServer
from utils.args import args_parser # args parser
from utils.fed_util import divide_trainset # divide_trainset
def _read_list_file(txt_path: str):
"""Read one path per line; keep as-is (absolute or relative)."""
if not txt_path or not os.path.exists(txt_path):
return []
with open(txt_path, "r", encoding="utf-8") as f:
return [ln.strip() for ln in f if ln.strip()]
def _build_valset_if_available(cfg, params):
"""
Try to build a validation Dataset.
- If cfg['val_txt'] exists, use it.
- Else if <dataset_path>/val.txt exists, use it.
- Else return None (testing will be skipped).
Args:
cfg: config dict
params: params dict for Dataset
Returns:
Dataset or None
"""
input_size = int(cfg.get("input_size", 640))
val_txt = cfg.get("val_txt", "")
if not val_txt:
ds_root = cfg.get("dataset_path", "")
guess = os.path.join(ds_root, "val.txt") if ds_root else ""
val_txt = guess if os.path.exists(guess) else ""
val_files = _read_list_file(val_txt)
if not val_files:
return None
return Dataset(
filenames=val_files,
input_size=input_size,
params=params,
augment=True,
)
def _seed_everything(seed: int):
np.random.seed(seed)
torch.manual_seed(seed)
random.seed(seed)
def _plot_curves(save_dir, hist):
"""
Plot mAP50-95, mAP50, precision, recall, and (optional) summed train loss per round.
"""
os.makedirs(save_dir, exist_ok=True)
rounds = np.arange(1, len(hist["mAP"]) + 1)
plt.figure()
if hist["mAP"]:
plt.plot(rounds, hist["mAP"], label="mAP50-95")
if hist["mAP50"]:
plt.plot(rounds, hist["mAP50"], label="mAP50")
if hist["precision"]:
plt.plot(rounds, hist["precision"], label="precision")
if hist["recall"]:
plt.plot(rounds, hist["recall"], label="recall")
if hist["train_loss"]:
plt.plot(rounds, hist["train_loss"], label="train_loss (sum of components)")
plt.xlabel("Global Round")
plt.ylabel("Metric")
plt.title("Federated YOLO - Server Metrics")
plt.legend()
out_png = os.path.join(save_dir, "fed_yolo_curves.png")
plt.savefig(out_png, dpi=150, bbox_inches="tight")
print(f"[plot] saved: {out_png}")
def fed_run():
"""
Main FL process:
@@ -98,20 +22,22 @@ def fed_run():
- Record & save results, plot curves
"""
args_cli = args_parser()
# TODO: cfg and params should not be separately defined
with open(args_cli.config, "r", encoding="utf-8") as f:
cfg = yaml.safe_load(f)
# --- params / config normalization ---
# For convenience we pass the same `params` dict used by Dataset/model/loss.
# Here we re-use the top-level cfg directly as params.
params = dict(cfg)
# params = dict(cfg)
if "names" in cfg and isinstance(cfg["names"], dict):
# Convert {0: 'uav', 1: 'car', ...} to list if you prefer list
# but we can leave dict; your utils appear to accept dict
pass
# seeds
_seed_everything(int(cfg.get("i_seed", 0)))
seed_everything(int(cfg.get("i_seed", 0)))
# --- split clients' train data from a global train list ---
# Expect either cfg["train_txt"] or <dataset_path>/train.txt
@@ -144,13 +70,13 @@ def fed_run():
clients = {}
for uid in users:
c = FedYoloClient(name=uid, model_name=model_name, params=params)
c = FedYoloClient(name=uid, model_name=model_name, params=cfg)
c.load_trainset(user_data[uid]["filename"])
clients[uid] = c
# --- build server & optional validation set ---
server = FedYoloServer(client_list=users, model_name=model_name, params=params)
valset = _build_valset_if_available(cfg, params)
server = FedYoloServer(client_list=users, model_name=model_name, params=cfg)
valset = build_valset_if_available(cfg, params=cfg, args=args_cli)
# valset is a Dataset class, not data loader
if valset is not None:
server.load_valset(valset)
@@ -186,27 +112,25 @@ def fed_run():
t0 = time.time()
# Local training (sequential over all users)
for uid in users:
# tqdm desc update
p_bar.set_description_str(("%10s" * 2) % (f"{rnd + 1}/{num_round}", f"{uid}"))
client = clients[uid] # FedYoloClient instance
client.update(global_state) # load global weights
state_dict, n_data, loss_dict = client.train(args_cli) # local training
server.rec(uid, state_dict, n_data, loss_dict)
state_dict, n_data, train_loss = client.train(args_cli) # local training
server.rec(uid, state_dict, n_data, train_loss)
# Select a fraction for aggregation (FedAvg subset if desired)
server.select_clients(connection_ratio=connection_ratio)
# Aggregate
global_state, avg_loss_dict, _ = server.agg()
global_state, avg_loss, _ = server.agg()
# Compute a scalar train loss for plotting (sum of components)
scalar_train_loss = float(sum(avg_loss_dict.values())) if avg_loss_dict else 0.0
scalar_train_loss = avg_loss if avg_loss else 0.0
# Test (if valset provided)
test_metrics = server.test(args_cli) if server.valset is not None else {}
mAP = float(test_metrics.get("mAP", 0.0))
mAP50 = float(test_metrics.get("mAP50", 0.0))
precision = float(test_metrics.get("precision", 0.0))
recall = float(test_metrics.get("recall", 0.0))
mAP, mAP50, recall, precision = server.test() if server.valset is not None else (0.0, 0.0, 0.0, 0.0)
# Flush per-round client caches
server.flush()
@@ -233,22 +157,23 @@ def fed_run():
p_bar.set_postfix(desc)
# Save running JSON (resumable logs)
save_name = (
f"[{cfg.get('fed_algo', 'FedAvg')},{cfg.get('model_name', 'yolo')},"
f"{cfg.get('num_local_epoch', cfg.get('client', {}).get('num_local_epoch', 1))},"
f"{cfg.get('num_local_class', 2)},"
f"{cfg.get('i_seed', 0)}]"
)
save_name = f"{cfg.get('fed_algo', 'FedAvg')}_{[cfg.get('model_name', 'yolo')]}_{cfg.get('num_client', 0)}c_{cfg.get('num_local_class', 1)}cls_{cfg.get('num_round', 0)}r_{cfg.get('connection_ratio', 1):.2f}cr_{cfg.get('i_seed', 0)}s"
out_json = os.path.join(res_root, save_name + ".json")
with open(out_json, "w", encoding="utf-8") as f:
json.dump(history, f, indent=2)
json.dump(history, f, indent=4)
p_bar.update(1)
p_bar.close()
# Save final global model weights
if not os.path.exists("./weights"):
os.makedirs("./weights", exist_ok=True)
torch.save(global_state, f"./weights/{save_name}_final.pth")
print(f"[save] final global model weights: ./weights/{save_name}_final.pth")
# --- final plot ---
_plot_curves(res_root, history)
plot_curves(res_root, history, savename=f"{save_name}_curve.png")
print("[done] training complete.")

2
fed_run.sh Normal file
View File

@@ -0,0 +1,2 @@
GPUS=$1
python3 -m torch.distributed.run --nproc_per_node=$GPUS fed_run.py ${@:2}

3
nets/__init__.py Normal file
View File

@@ -0,0 +1,3 @@
from .nn import YOLO, yolo_v11_l, yolo_v11_m, yolo_v11_s, yolo_v11_t, yolo_v11_x, yolo_v11_n
__all__ = ["YOLO", "yolo_v11_l", "yolo_v11_m", "yolo_v11_s", "yolo_v11_t", "yolo_v11_x", "yolo_v11_n"]

77
testcode.py Normal file
View File

@@ -0,0 +1,77 @@
from utils.fed_util import init_model
from fed_algo_cs.server_base import test
import os
import yaml
from utils.args import args_parser # args parser
from fed_algo_cs.client_base import FedYoloClient # FedYoloClient
from fed_algo_cs.server_base import FedYoloServer # FedYoloServer
from utils import Dataset # Dataset
if __name__ == "__main__":
# model structure test
model = init_model("yolo_v11_n", num_classes=1)
with open("model.txt", "w", encoding="utf-8") as f:
print(model, file=f)
# loop over model key and values
with open("model_key_value.txt", "w", encoding="utf-8") as f:
for k, v in model.state_dict().items():
print(f"{k}: {v.shape}", file=f)
# test agg function
from fed_algo_cs.server_base import FedYoloServer
import torch
import yaml
with open("./config/coco128_cfg.yaml", "r", encoding="utf-8") as f:
cfg = yaml.safe_load(f)
params = dict(cfg)
server = FedYoloServer(client_list=["c1", "c2", "c3"], model_name="yolo_v11_n", params=params)
state1 = {k: torch.ones_like(v) for k, v in server.model.state_dict().items()}
state2 = {k: torch.ones_like(v) * 2 for k, v in server.model.state_dict().items()}
state3 = {k: torch.ones_like(v) * 3 for k, v in server.model.state_dict().items()}
server.rec("c1", state1, n_data=20, loss=0.1)
server.rec("c2", state2, n_data=30, loss=0.2)
server.rec("c3", state3, n_data=50, loss=0.3)
server.select_clients(connection_ratio=1.0)
model_state, avg_loss, n_data = server.agg()
with open("agg_model.txt", "w", encoding="utf-8") as f:
for k, v in model_state.items():
print(f"{k}: {v.float().mean()}", file=f)
print(f"avg_loss: {avg_loss}, n_data: {n_data}")
# test single client training (should be the same as standalone training)
args = args_parser()
with open(args.config, "r", encoding="utf-8") as f:
cfg = yaml.safe_load(f)
# params = dict(cfg)
client = FedYoloClient(name="c1", params=cfg, model_name="yolo_v11_n")
filenames = []
data_dir = "/home/image1325/ssd1/dataset/COCO128"
with open(f"{data_dir}/train.txt") as f:
for filename in f.readlines():
filename = os.path.basename(filename.rstrip())
filenames.append(f"{data_dir}/images/train2017/" + filename)
client.load_trainset(train_dataset=filenames)
model_state, n_data, avg_loss = client.train(args=args)
model = init_model("yolo_v11_n", num_classes=80)
model.load_state_dict(model_state)
valset = Dataset(
filenames=filenames,
input_size=640,
params=cfg,
augment=False,
)
if valset is not None:
precision, recall, map50, map = test(valset=valset, params=cfg, model=model, batch_size=128)
print(
f"precision: {precision}, recall: {recall}, map50: {map50}, map: {map}, loss: {avg_loss}, n_data: {n_data}"
)
else:
raise ValueError("valset is None, please provide a valid valset in config file.")

3
utils/__init__.py Normal file
View File

@@ -0,0 +1,3 @@
from .dataset import *
from .fed_util import *
from .util import *

View File

@@ -8,16 +8,11 @@ import torch
from PIL import Image
from torch.utils import data
FORMATS = "bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "JPEG", "JPG", "PNG", "TIFF"
FORMATS = "bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp"
class Dataset(data.Dataset):
params: dict
mosaic: bool
augment: bool
input_size: int
def __init__(self, filenames, input_size: int, params: dict, augment: bool):
def __init__(self, filenames, input_size, params, augment):
self.params = params
self.mosaic = augment
self.augment = augment
@@ -48,8 +43,6 @@ class Dataset(data.Dataset):
else:
# Load image
image, shape = self.load_image(index)
if image is None:
raise ValueError(f"Failed to load image at index {index}: {self.filenames[index]}")
h, w = image.shape[:2]
# Resize
@@ -57,7 +50,7 @@ class Dataset(data.Dataset):
label = self.labels[index].copy()
if label.size:
label[:, 1:] = wh2xy(label[:, 1:], ratio[0] * w, ratio[1] * h, int(pad[0]), int(pad[1]))
label[:, 1:] = wh2xy(label[:, 1:], ratio[0] * w, ratio[1] * h, pad[0], pad[1])
if self.augment:
image, label = random_perspective(image, label, self.params)
@@ -84,25 +77,24 @@ class Dataset(data.Dataset):
if nl:
box[:, 0] = 1 - box[:, 0]
# target_cls = torch.zeros((nl, 1))
# target_box = torch.zeros((nl, 4))
# if nl:
# target_cls = torch.from_numpy(cls)
# target_box = torch.from_numpy(box)
# fix [cls, box] empty bug. e.g. [0,1] is illegal in DataLoader collate_fn cat operation
# XXX: when nl=0, torch.from_numpy(box) will error
if nl:
target_cls = torch.from_numpy(cls).view(-1, 1).float() # always (N,1)
target_box = torch.from_numpy(box).reshape(-1, 4).float() # always (N,4)
else:
target_cls = torch.zeros((0, 1), dtype=torch.float32)
target_box = torch.zeros((0, 4), dtype=torch.float32)
# target_cls = torch.zeros((nl, 1))
# target_box = torch.zeros((nl, 4))
# if nl:
# target_cls = torch.from_numpy(cls)
# target_box = torch.from_numpy(box)
# Convert HWC to CHW, BGR to RGB
sample = image.transpose((2, 0, 1))[::-1]
sample = numpy.ascontiguousarray(sample)
# init: return torch.from_numpy(sample), target_cls, target_box, torch.zeros(nl)
# return torch.from_numpy(sample), target_cls, target_box, torch.zeros(nl)
return torch.from_numpy(sample), target_cls, target_box, torch.zeros((nl, 1), dtype=torch.long)
def __len__(self):
@@ -111,7 +103,7 @@ class Dataset(data.Dataset):
def load_image(self, i):
image = cv2.imread(self.filenames[i])
if image is None:
raise ValueError(f"Image not found or unable to open: {self.filenames[i]}")
raise FileNotFoundError(f"Image Not Found {self.filenames[i]}")
h, w = image.shape[:2]
r = self.input_size / max(h, w)
if r != 1:
@@ -173,8 +165,8 @@ class Dataset(data.Dataset):
x2b = min(shape[1], x2a - x1a)
y2b = min(y2a - y1a, shape[0])
pad_w = (x1a if x1a is not None else 0) - (x1b if x1b is not None else 0)
pad_h = (y1a if y1a is not None else 0) - (y1b if y1b is not None else 0)
pad_w = x1a - x1b
pad_h = y1a - y1b
image4[y1a:y2a, x1a:x2a] = image[y1b:y2b, x1b:x2b]
# Labels
@@ -197,14 +189,8 @@ class Dataset(data.Dataset):
def collate_fn(batch):
samples, cls, box, indices = zip(*batch)
# ensure empty tensor shape is correct
cls = [c.view(-1, 1) for c in cls]
box = [b.reshape(-1, 4) for b in box]
indices = [i for i in indices]
cls = torch.cat(cls, dim=0) if cls else torch.zeros((0, 1))
box = torch.cat(box, dim=0) if box else torch.zeros((0, 4))
indices = torch.cat(indices, dim=0) if indices else torch.zeros((0,), dtype=torch.long)
cls = torch.cat(cls, dim=0)
box = torch.cat(box, dim=0)
new_indices = list(indices)
for i in range(len(indices)):
@@ -215,7 +201,7 @@ class Dataset(data.Dataset):
return torch.stack(samples, dim=0), targets
@staticmethod
def load_label_use_cache(filenames):
def load_label(filenames):
path = f"{os.path.dirname(filenames[0])}.cache"
if os.path.exists(path):
return torch.load(path, weights_only=False)
@@ -228,14 +214,11 @@ class Dataset(data.Dataset):
image.verify() # PIL verify
shape = image.size # image size
assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
assert image.format is not None and image.format.lower() in FORMATS, (
f"invalid image format {image.format}"
)
assert image.format.lower() in FORMATS, f"invalid image format {image.format}"
# verify labels
a = f"{os.sep}images{os.sep}"
b = f"{os.sep}labels{os.sep}"
if os.path.isfile(b.join(filename.rsplit(a, 1)).rsplit(".", 1)[0] + ".txt"):
with open(b.join(filename.rsplit(a, 1)).rsplit(".", 1)[0] + ".txt") as f:
label = [x.split() for x in f.read().strip().splitlines() if len(x)]
@@ -260,50 +243,6 @@ class Dataset(data.Dataset):
torch.save(x, path)
return x
@staticmethod
def load_label(filenames):
    """
    Build a {image_path: label_array} mapping for the given image files.

    Each value is a float32 numpy array of shape (N, 5) with YOLO-format
    rows [class, cx, cy, w, h] (coordinates normalized to <= ~1.0).
    Images that fail verification, or have no label file, map to an
    empty (0, 5) array rather than being dropped from the dict.
    """
    x = {}
    for filename in filenames:
        try:
            # verify images: PIL verify() detects truncated/corrupt files
            with open(filename, "rb") as f:
                image = Image.open(f)
                image.verify()
            shape = image.size
            # reject degenerate thumbnails and unsupported container formats
            assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
            assert image.format is not None and image.format.lower() in FORMATS, (
                f"invalid image format {image.format}"
            )
            # verify labels: map .../images/xxx.jpg -> .../labels/xxx.txt
            # (only the LAST "images" path component is swapped, via rsplit)
            a = f"{os.sep}images{os.sep}"
            b = f"{os.sep}labels{os.sep}"
            label_path = b.join(filename.rsplit(a, 1)).rsplit(".", 1)[0] + ".txt"
            if os.path.isfile(label_path):
                rows = []
                with open(label_path) as f:
                    for line in f:
                        parts = line.strip().split()
                        if len(parts) == 5:  # YOLO format: class cx cy w h
                            rows.append([float(x) for x in parts])
                label = numpy.array(rows, dtype=numpy.float32) if rows else numpy.zeros((0, 5), dtype=numpy.float32)
                if label.shape[0]:
                    # sanity-check value ranges, then drop exact duplicate rows
                    assert (label >= 0).all()
                    assert label.shape[1] == 5
                    assert (label[:, 1:] <= 1.0001).all()
                    _, i = numpy.unique(label, axis=0, return_index=True)
                    label = label[i]
            else:
                label = numpy.zeros((0, 5), dtype=numpy.float32)
        # NOTE(review): only FileNotFoundError/AssertionError fall back to an
        # empty label; a corrupt image raising PIL.UnidentifiedImageError or
        # OSError would propagate — confirm that is intended.
        except (FileNotFoundError, AssertionError):
            label = numpy.zeros((0, 5), dtype=numpy.float32)
        x[filename] = label
    return x
def wh2xy(x, w=640, h=640, pad_w=0, pad_h=0):
# Convert nx4 boxes
@@ -462,9 +401,7 @@ class Albumentations:
albumentations.ToGray(p=0.01),
albumentations.MedianBlur(p=0.01),
]
self.transform = albumentations.Compose(
transforms, albumentations.BboxParams(format="yolo", label_fields=["class_labels"])
)
self.transform = albumentations.Compose(transforms, albumentations.BboxParams("yolo", ["class_labels"]))
except ImportError: # package not installed, skip
pass

View File

@@ -1,10 +1,15 @@
import os
import re
import random
import matplotlib.pyplot as plt
from utils.dataset import Dataset
import numpy as np
import torch
from collections import defaultdict
from typing import Dict, List, Optional, Set, Any
from nets import nn
from nets import YOLO
def _image_to_label_path(img_path: str) -> str:
@@ -59,6 +64,14 @@ def _parse_yolo_label_file(label_path: str) -> Set[int]:
return class_ids
def _read_list_file(txt_path: str):
"""Read one path per line; keep as-is (absolute or relative)."""
if not txt_path or not os.path.exists(txt_path):
return []
with open(txt_path, "r", encoding="utf-8") as f:
return [ln.strip() for ln in f if ln.strip()]
def divide_trainset(
trainset_path: str,
num_local_class: int,
@@ -230,7 +243,7 @@ def divide_trainset(
return result
def init_model(model_name, num_classes):
def init_model(model_name, num_classes) -> YOLO:
"""
Initialize the model for a specific learning task
Args:
@@ -252,3 +265,74 @@ def init_model(model_name, num_classes):
raise ValueError("Model {} is not supported.".format(model_name))
return model
def build_valset_if_available(cfg, params, args=None) -> Optional[Dataset]:
    """
    Try to build a validation Dataset.

    Lookup order:
    - If cfg['val_txt'] is set, use it.
    - Else if <dataset_path>/val.txt exists, use it.
    - Else return None (testing will be skipped).

    Args:
        cfg: config dict (reads 'val_txt' and 'dataset_path')
        params: params dict forwarded to Dataset
        args: optional namespace; args.input_size overrides the default 640

    Returns:
        Dataset or None
    """
    input_size = args.input_size if args and hasattr(args, "input_size") else 640
    val_txt = cfg.get("val_txt", "")
    if not val_txt:
        ds_root = cfg.get("dataset_path", "")
        guess = os.path.join(ds_root, "val.txt") if ds_root else ""
        val_txt = guess if os.path.exists(guess) else ""
    val_files = _read_list_file(val_txt)
    if not val_files:
        import warnings

        warnings.warn("No validation dataset found.")
        return None
    # BUGFIX: validation must not use training-time augmentation. augment=True
    # here randomized images/boxes during evaluation and skewed mAP; the test
    # path elsewhere in this repo builds its eval Dataset with augment=False.
    return Dataset(
        filenames=val_files,
        input_size=input_size,
        params=params,
        augment=False,
    )
def seed_everything(seed: int):
    """Seed the Python, NumPy, and PyTorch RNGs for reproducible runs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
def plot_curves(save_dir, hist, savename="fed_yolo_curves.png"):
    """
    Plot mAP50-95, mAP50, precision, recall, and (optional) summed train loss
    per global round.

    Args:
        save_dir: directory to save the plot (created if missing)
        hist: history dict with list values under keys
              "mAP", "mAP50", "precision", "recall", "train_loss"
        savename: output filename
    """
    os.makedirs(save_dir, exist_ok=True)
    plt.figure()
    series = (
        ("mAP", "mAP50-95"),
        ("mAP50", "mAP50"),
        ("precision", "precision"),
        ("recall", "recall"),
        ("train_loss", "train_loss (sum of components)"),
    )
    for key, label in series:
        values = hist.get(key) or []
        if values:
            # BUGFIX: the x-axis was previously derived only from
            # len(hist["mAP"]) and reused for every series; any series with a
            # different length (or a nonempty series while mAP was empty)
            # raised a shape-mismatch in plt.plot. Each series now gets an
            # x-axis of its own length.
            plt.plot(np.arange(1, len(values) + 1), values, label=label)
    plt.xlabel("Global Round")
    plt.ylabel("Metric")
    plt.title("Federated YOLO - Server Metrics")
    plt.legend()
    out_png = os.path.join(save_dir, savename)
    plt.savefig(out_png, dpi=150, bbox_inches="tight")
    print(f"[plot] saved: {out_png}")

View File

@@ -1,7 +1,3 @@
"""
Utility functions for yolo.
"""
import copy
import random
from time import time
@@ -97,7 +93,7 @@ def make_anchors(x, strides, offset=0.5):
_, _, h, w = x[i].shape
sx = torch.arange(end=w, device=device, dtype=dtype) + offset # shift x
sy = torch.arange(end=h, device=device, dtype=dtype) + offset # shift y
sy, sx = torch.meshgrid(sy, sx, indexing="ij")
sy, sx = torch.meshgrid(sy, sx)
anchor_tensor.append(torch.stack((sx, sy), -1).view(-1, 2))
stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device))
return torch.cat(anchor_tensor), torch.cat(stride_tensor)
@@ -151,7 +147,7 @@ def non_max_suppression(outputs, confidence_threshold=0.001, iou_threshold=0.65)
box = wh2xy(box) # (cx, cy, w, h) to (x1, y1, x2, y2)
if nc > 1:
i, j = (cls > confidence_threshold).nonzero(as_tuple=False).T
x = torch.cat((box[i], x[i, 4 + j].unsqueeze(1), j[:, None].float()), dim=1)
x = torch.cat((box[i], x[i, 4 + j, None], j[:, None].float()), 1)
else: # best class only
conf, j = cls.max(1, keepdim=True)
x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > confidence_threshold]
@@ -195,13 +191,7 @@ def plot_pr_curve(px, py, ap, names, save_dir):
else:
ax.plot(px, py, linewidth=1, color="grey") # plot(recall, precision)
ax.plot(
px,
py.mean(1),
linewidth=3,
color="blue",
label="all classes %.3f mAP@0.5" % ap[:, 0].mean(),
)
ax.plot(px, py.mean(1), linewidth=3, color="blue", label="all classes %.3f mAP@0.5" % ap[:, 0].mean())
ax.set_xlabel("Recall")
ax.set_ylabel("Precision")
ax.set_xlim(0, 1)
@@ -224,13 +214,7 @@ def plot_curve(px, py, names, save_dir, x_label="Confidence", y_label="Metric"):
ax.plot(px, py.T, linewidth=1, color="grey") # plot(confidence, metric)
y = smooth(py.mean(0), f=0.05)
ax.plot(
px,
y,
linewidth=3,
color="blue",
label=f"all classes {y.max():.3f} at {px[y.argmax()]:.3f}",
)
ax.plot(px, y, linewidth=3, color="blue", label=f"all classes {y.max():.3f} at {px[y.argmax()]:.3f}")
ax.set_xlabel(x_label)
ax.set_ylabel(y_label)
ax.set_xlim(0, 1)
@@ -296,8 +280,7 @@ def compute_ap(tp, conf, output, target, plot=False, names=(), eps=1e-16):
# Integrate area under curve
x = numpy.linspace(start=0, stop=1, num=101) # 101-point interp (COCO)
# numpy.trapz is deprecated in numpy 2.0.0 or after version, use numpy.trapezoid instead
ap[ci, j] = numpy.trapezoid(numpy.interp(x, m_rec, m_pre), x) # integrate
ap[ci, j] = numpy.trapz(numpy.interp(x, m_rec, m_pre), x) # integrate
if plot and j == 0:
py.append(numpy.interp(px, m_rec, m_pre)) # precision at mAP@0.5
@@ -443,7 +426,7 @@ class LinearLR:
min_lr = params["min_lr"]
warmup_steps = int(max(params["warmup_epochs"] * num_steps, 100))
decay_steps = max(1, int(args.epochs * num_steps - warmup_steps))
decay_steps = int(args.epochs * num_steps - warmup_steps)
warmup_lr = numpy.linspace(min_lr, max_lr, int(warmup_steps), endpoint=False)
decay_lr = numpy.linspace(max_lr, min_lr, decay_steps)
@@ -528,16 +511,8 @@ class Assigner(torch.nn.Module):
mask_in_gts = mask_in_gts.view(shape[0], shape[1], num_anchors, -1).amin(3).gt_(self.eps)
na = pd_bboxes.shape[-2]
gt_mask = (mask_in_gts * mask_gt).bool() # b, max_num_obj, h*w
overlaps = torch.zeros(
[batch_size, num_max_boxes, na],
dtype=pd_bboxes.dtype,
device=pd_bboxes.device,
)
bbox_scores = torch.zeros(
[batch_size, num_max_boxes, na],
dtype=pd_scores.dtype,
device=pd_scores.device,
)
overlaps = torch.zeros([batch_size, num_max_boxes, na], dtype=pd_bboxes.dtype, device=pd_bboxes.device)
bbox_scores = torch.zeros([batch_size, num_max_boxes, na], dtype=pd_scores.dtype, device=pd_scores.device)
ind = torch.zeros([2, batch_size, num_max_boxes], dtype=torch.long) # 2, b, max_num_obj
ind[0] = torch.arange(end=batch_size).view(-1, 1).expand(-1, num_max_boxes) # b, max_num_obj
@@ -587,9 +562,7 @@ class Assigner(torch.nn.Module):
target_labels.clamp_(0)
target_scores = torch.zeros(
(target_labels.shape[0], target_labels.shape[1], self.nc),
dtype=torch.int64,
device=target_labels.device,
(target_labels.shape[0], target_labels.shape[1], self.nc), dtype=torch.int64, device=target_labels.device
)
target_scores.scatter_(2, target_labels.unsqueeze(-1), 1)
@@ -672,16 +645,7 @@ class BoxLoss(torch.nn.Module):
super().__init__()
self.dfl_ch = dfl_ch
def forward(
self,
pred_dist,
pred_bboxes,
anchor_points,
target_bboxes,
target_scores,
target_scores_sum,
fg_mask,
):
def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
# IoU loss
weight = torch.masked_select(target_scores.sum(-1), fg_mask).unsqueeze(-1)
iou = compute_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask])
@@ -803,13 +767,7 @@ class ComputeLoss:
if fg_mask.sum():
target_bboxes /= stride_tensor
loss_box, loss_dfl = self.box_loss(
pred_distri,
pred_bboxes,
anchor_points,
target_bboxes,
target_scores,
target_scores_sum,
fg_mask,
pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
)
loss_box *= self.params["box"] # box gain