Compare commits


7 Commits

12 changed files with 504 additions and 468 deletions

.gitignore

@@ -296,5 +296,8 @@ Network Trash Folder
 Temporary Items
 .apdisk
+# ---> Custom
 results/
 *.log
+*.txt
+weights/

README.md

@@ -21,4 +21,10 @@ nohup python fed_run.py > train.log 2>&1 &
 - Implement FedNova
 - Add more YOLO versions (e.g., YOLOv8, YOLOv5, etc.)
 - Implement YOLOv8
 - Implement YOLOv5
+# references
+[PyTorch Federated Learning](https://github.com/rruisong/pytorch_federated_learning)
+[YOLOv11-pt](https://github.com/jahongir7174/YOLOv11-pt)

fed_algo_cs/client_base.py

@@ -3,11 +3,11 @@ import torch
 from torch import nn
 from torch.utils import data
 from torch.amp.autocast_mode import autocast
-from tqdm import tqdm
 from utils.fed_util import init_model
 from utils import util
 from utils.dataset import Dataset
 from typing import cast
+from tqdm import tqdm


 class FedYoloClient(object):
@@ -82,52 +82,48 @@ class FedYoloClient(object):
         # load the global model parameters
         self.model.load_state_dict(Global_model_state_dict, strict=True)

-    def train(self, args):
+    def train(self, args) -> tuple[dict[str, torch.Tensor], int, float]:
         """
-        Train the local model
-        Args:
-        :param args: Command line arguments
-            - local_rank: Local rank for distributed training
-            - world_size: World size for distributed training
-            - distributed: Whether to use distributed training
-            - input_size: Input size for the model
-        Returns:
-        :return: Local updated model, number of local data points, training loss
+        Train the local model.
+        Returns: (state_dict, n_data, avg_loss_per_image)
         """
+        # ---- Dist init (if any) ----
         if args.distributed:
             torch.cuda.set_device(device=args.local_rank)
             torch.distributed.init_process_group(backend="nccl", init_method="env://")
-            # print(f"Client {self.name} - distributed training on {world_size} GPUs, local rank: {local_rank}")
-            # self._device = torch.device("cuda", local_rank)
-        if args.local_rank == 0:
-            pass
-            # if not os.path.exists("weights"):
-            #     os.makedirs("weights")
         util.setup_seed()
         util.setup_multi_processes()

-        # model
-        # init model have been done in __init__()
-        self.model.to(self._device)
-        # show model architecture
-        # print(self.model)
+        # device = torch.device(f"cuda:{args.local_rank}" if torch.cuda.is_available() else "cpu")
+        # self.model.to(device)
+        self.model.cuda()

-        # Optimizer
-        accumulate = max(round(64 / (self._batch_size * args.world_size)), 1)
-        self._weight_decay = self._batch_size * args.world_size * accumulate / 64
+        # ---- Optimizer / WD scaling & LR warmup/schedule ----
+        # accumulate = effective grad-accumulation steps to emulate global batch 64
+        world_size = getattr(args, "world_size", 1)
+        accumulate = max(round(64 / (self._batch_size * max(world_size, 1))), 1)
+        # scale weight_decay like YOLO recipes
+        scaled_wd = self._weight_decay * self._batch_size * max(world_size, 1) * accumulate / 64
         optimizer = torch.optim.SGD(
-            util.set_params(self.model, self._weight_decay),
+            util.set_params(self.model, scaled_wd),
             lr=self._min_lr,
             momentum=self._momentum,
             nesterov=True,
         )

-        # EMA
+        # ---- EMA (track the underlying module if DDP) ----
+        # track_model = self.model.module if is_ddp else self.model
         ema = util.EMA(self.model) if args.local_rank == 0 else None

-        data_set = Dataset(
+        print(type(self.train_dataset))
+        # ---- Data ----
+        dataset = Dataset(
             filenames=self.train_dataset,
             input_size=args.input_size,
             params=self.params,
@@ -136,26 +132,28 @@ class FedYoloClient(object):
         if args.distributed:
             train_sampler = data.DistributedSampler(
-                data_set, num_replicas=args.world_size, rank=args.local_rank, shuffle=True
+                dataset, num_replicas=args.world_size, rank=args.local_rank, shuffle=True
             )
         else:
             train_sampler = None
         loader = data.DataLoader(
-            data_set,
+            dataset,
             batch_size=self._batch_size,
-            shuffle=train_sampler is None,
+            shuffle=(train_sampler is None),
             sampler=train_sampler,
             num_workers=self.num_workers,
             pin_memory=True,
             collate_fn=Dataset.collate_fn,
+            drop_last=False,
         )
-        # Scheduler
         num_steps = max(1, len(loader))
         scheduler = util.LinearLR(args=args, params=self.params, num_steps=num_steps)
-        # DDP mode
-        if args.distributed:
+        # ---- SyncBN + DDP (if any) ----
+        is_ddp = bool(args.distributed)
+        if is_ddp:
             self.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.model)
             self.model = nn.parallel.DistributedDataParallel(
                 module=self.model,
@@ -164,102 +162,133 @@ class FedYoloClient(object):
                 find_unused_parameters=False,
             )

-        amp_scale = torch.amp.grad_scaler.GradScaler(enabled=True)
+        # ---- AMP + loss ----
+        scaler = torch.amp.grad_scaler.GradScaler(enabled=True)
+        # criterion = util.ComputeLoss(
+        #     self.model.module if isinstance(self.model, nn.parallel.DistributedDataParallel) else self.model,
+        #     self.params,
+        # )
         criterion = util.ComputeLoss(self.model, self.params)

-        # log
-        # if args.local_rank == 0:
-        #     header = ("%10s" * 5) % ("client", "memory", "box", "cls", "dfl")
-        #     print("\n" + header)
-        #     p_bar = tqdm(total=args.epochs * num_steps, ncols=120)
-        #     p_bar.set_description(f"{self.name:>10}")
+        # ---- Training ----
         for epoch in range(args.epochs):
+            # (self.model.module if isinstance(self.model, nn.parallel.DistributedDataParallel) else self.model).train()
             self.model.train()
-            # when distributed, set epoch for shuffling
-            if args.distributed and train_sampler is not None:
+            if is_ddp and train_sampler is not None:
                 train_sampler.set_epoch(epoch)
-            if args.epochs - epoch == 10:
-                # disable mosaic augmentation in the last 10 epochs
+            # disable mosaic in the last 10 epochs (if dataset supports it)
+            if args.epochs - epoch == 10 and hasattr(loader.dataset, "mosaic"):
                 ds = cast(Dataset, loader.dataset)
                 ds.mosaic = False

             optimizer.zero_grad(set_to_none=True)
-            avg_box_loss = util.AverageMeter()
-            avg_cls_loss = util.AverageMeter()
-            avg_dfl_loss = util.AverageMeter()
+            loss_box_meter = util.AverageMeter()
+            loss_cls_meter = util.AverageMeter()
+            loss_dfl_meter = util.AverageMeter()

-            # # --- header (once per epoch, YOLO-style) ---
-            # if args.local_rank == 0:
-            #     header = ("%10s" * 5) % ("client", "memory", "box", "cls", "dfl")
-            #     print("\n" + header)
-            # p_bar = enumerate(loader)
-            # if args.local_rank == 0:
-            #     p_bar = tqdm(p_bar, total=num_steps, ncols=120)
-            for i, (samples, targets) in enumerate(loader):
-                global_step = i + num_steps * epoch
-                scheduler.step(step=global_step, optimizer=optimizer)
-
-                samples = samples.cuda(non_blocking=True).float() / 255.0
+            for i, (images, targets) in enumerate(loader):
+                print(f"Client {self.name} - Epoch {epoch + 1}/{args.epochs} - Step {i + 1}/{num_steps}")
+                step = i + epoch * num_steps
+                # scheduler per-step (your util.LinearLR expects step)
+                scheduler.step(step=step, optimizer=optimizer)
+
+                # images = images.to(device, non_blocking=True).float() / 255.0
+                images = images.cuda().float() / 255.0
+                bs = images.size(0)
+                # total_imgs_seen += bs
+                # targets: keep as your ComputeLoss expects (often CPU lists/tensors).
+                # Move to GPU here only if your loss requires it.

-                # Forward
-                with autocast("cuda", enabled=True):
-                    outputs = self.model(samples)
+                with autocast(device_type="cuda", enabled=True):
+                    outputs = self.model(images)  # DDP wraps forward
                     box_loss, cls_loss, dfl_loss = criterion(outputs, targets)

-                # meters (use the *unscaled* values)
-                bs = samples.size(0)
-                avg_box_loss.update(box_loss.item(), bs)
-                avg_cls_loss.update(cls_loss.item(), bs)
-                avg_dfl_loss.update(dfl_loss.item(), bs)
-
-                # scale losses by batch/world if your loss is averaged internally per-sample/device
-                # box_loss = box_loss * self._batch_size * args.world_size
-                # cls_loss = cls_loss * self._batch_size * args.world_size
-                # dfl_loss = dfl_loss * self._batch_size * args.world_size
-                box_loss *= self._batch_size
-                cls_loss *= self._batch_size
-                dfl_loss *= self._batch_size
-                box_loss *= args.world_size
-                cls_loss *= args.world_size
-                dfl_loss *= args.world_size
+                # total_loss = box_loss + cls_loss + dfl_loss
+                # Gradient accumulation: normalize by 'accumulate' so LR stays effective
+                # total_loss = total_loss / accumulate
+
+                # IMPORTANT: assume criterion returns **average per image** in the batch.
+                # Keep logging on the true (unscaled) values:
+                loss_box_meter.update(box_loss.item(), bs)
+                loss_cls_meter.update(cls_loss.item(), bs)
+                loss_dfl_meter.update(dfl_loss.item(), bs)

                 total_loss = box_loss + cls_loss + dfl_loss
-                # Backward
-                amp_scale.scale(total_loss).backward()
+                scaler.scale(total_loss).backward()

-                # Optimize
-                if (i + 1) % accumulate == 0:
-                    amp_scale.unscale_(optimizer)  # unscale gradients
-                    util.clip_gradients(model=self.model, max_norm=10.0)  # clip gradients
-                    amp_scale.step(optimizer)
-                    amp_scale.update()
+                # optimize
+                if step % accumulate == 0:
+                    # scaler.unscale_(optimizer)
+                    # util.clip_gradients(self.model)
+                    scaler.step(optimizer)
+                    scaler.update()
                     optimizer.zero_grad(set_to_none=True)
+                # # Step when we have 'accumulate' micro-batches, or at the end
+                # if ((i + 1) % accumulate == 0) or (i + 1 == len(loader)):
+                #     scaler.unscale_(optimizer)
+                #     util.clip_gradients(
+                #         model=(
+                #             self.model.module
+                #             if isinstance(self.model, nn.parallel.DistributedDataParallel)
+                #             else self.model
+                #         ),
+                #         max_norm=10.0,
+                #     )
+                #     scaler.step(optimizer)
+                #     scaler.update()
+                #     optimizer.zero_grad(set_to_none=True)

                 if ema:
-                    ema.update(self.model)
+                    # Update EMA from the underlying module
+                    ema.update(
+                        self.model.module
+                        if isinstance(self.model, nn.parallel.DistributedDataParallel)
+                        else self.model
+                    )
+                # print loss to test
+                print(
+                    f"loss: {total_loss.item() * accumulate:.4f}, box: {box_loss.item():.4f}, cls: {cls_loss.item():.4f}, dfl: {dfl_loss.item():.4f}"
+                )
+            torch.cuda.synchronize()

-        # torch.cuda.synchronize()
+        # ---- Final average loss (per image) over the whole epoch span ----
+        avg_loss_per_image = loss_box_meter.avg + loss_cls_meter.avg + loss_dfl_meter.avg

-        # tqdm update
-        # if args.local_rank == 0:
-        #     mem = f"{torch.cuda.memory_reserved() / 1e9:.2f}G" if torch.cuda.is_available() else "0.00G"
-        #     desc = ("%10s" * 2 + "%10.4g" * 3) % (
-        #         self.name,
-        #         mem,
-        #         avg_box_loss.avg,
-        #         avg_cls_loss.avg,
-        #         avg_dfl_loss.avg,
-        #     )
-        #     cast(tqdm, p_bar).set_description(desc)
-        #     p_bar.update(1)
-        # p_bar.close()
-        # clean
-        if args.distributed:
+        # ---- Cleanup DDP ----
+        if is_ddp:
             torch.distributed.destroy_process_group()
         torch.cuda.empty_cache()

-        return (
-            self.model.state_dict() if not ema else ema.ema.state_dict(),
-            self.n_data,
-            {"box_loss": avg_box_loss.avg, "cls_loss": avg_cls_loss.avg, "dfl_loss": avg_dfl_loss.avg},
-        )
+        # ---- Choose which weights to return ----
+        # - If EMA exists, return EMA weights (common YOLO eval practice)
+        # - Be careful with DDP: grab state_dict from the underlying module / EMA model
+        if ema:
+            # print("Using EMA weights")
+            return (ema.ema.state_dict(), self.n_data, avg_loss_per_image)
+        else:
+            # Safely get the underlying module if wrapped by DDP; getattr returns the module or the original object.
+            model_obj = getattr(self.model, "module", self.model)
+            # If it's a proper nn.Module, call state_dict(); if it's already a state dict, use it;
+            # otherwise try to call state_dict() and finally fall back to wrapping the object.
+            if isinstance(model_obj, torch.nn.Module):
+                model_to_return = model_obj.state_dict()
+            elif isinstance(model_obj, dict):
+                model_to_return = model_obj
+            else:
+                try:
+                    model_to_return = model_obj.state_dict()
+                except Exception:
+                    # fallback: if model_obj is a tensor or unexpected object, wrap it in a dict
+                    model_to_return = {"state": model_obj}
+            return model_to_return, self.n_data, avg_loss_per_image

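Note: the accumulate / scaled_wd lines above implement the usual YOLO trick of emulating a nominal global batch of 64: the optimizer steps once every accumulate micro-batches, and weight decay is coupled to the effective batch size. A minimal sketch of that arithmetic (the base weight decay of 5e-4 is illustrative, not taken from this repo's config):

# batch-64 emulation as in FedYoloClient.train (sketch; base weight decay assumed 5e-4)
def yolo_wd_schedule(batch_size: int, world_size: int, weight_decay: float = 5e-4):
    # step the optimizer once every `accumulate` micro-batches so that
    # batch_size * world_size * accumulate ~= 64
    accumulate = max(round(64 / (batch_size * max(world_size, 1))), 1)
    # couple weight decay to the effective batch size (YOLO recipe)
    scaled_wd = weight_decay * batch_size * max(world_size, 1) * accumulate / 64
    return accumulate, scaled_wd

print(yolo_wd_schedule(8, 1))   # (8, 0.0005): 8 micro-batches of 8 -> effective batch 64
print(yolo_wd_schedule(16, 2))  # (2, 0.0005): 2 micro-batches of 32 -> effective batch 64

Also note that the new step test `step % accumulate == 0` fires at step 0, i.e. after a single micro-batch; the commented-out alternative (`(i + 1) % accumulate == 0 or i + 1 == len(loader)`) avoids that edge case and also flushes leftover gradients at the end of the loader.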
fed_algo_cs/server_base.py

@@ -4,6 +4,7 @@ from torch.utils.data import DataLoader
 from utils.fed_util import init_model
 from utils.dataset import Dataset
 from utils import util
+from nets import YOLO


 class FedYoloServer(object):
@@ -21,7 +22,7 @@ class FedYoloServer(object):
         self.client_n_data = {}
         self.selected_clients = []
-        self._batch_size = params.get("val_batch_size", 4)
+        self._batch_size = params.get("val_batch_size", 200)
         self.client_list = client_list
         self.valset = None
@@ -40,7 +41,7 @@ class FedYoloServer(object):
         self.model = init_model(model_name, self._num_classes)
         self.params = params

-    def load_valset(self, valset):
+    def load_valset(self, valset: Dataset):
         """Server loads the validation dataset."""
         self.valset = valset
@@ -48,78 +49,6 @@ class FedYoloServer(object):
         """Return global model weights."""
         return self.model.state_dict()

-    @torch.no_grad()
-    def test(self, args) -> dict:
-        """
-        Test the global model on the server's validation set.
-        Returns:
-            dict with keys: mAP, mAP50, precision, recall
-        """
-        if self.valset is None:
-            return {}
-        loader = DataLoader(
-            self.valset,
-            batch_size=self._batch_size,
-            shuffle=False,
-            num_workers=4,
-            pin_memory=True,
-            collate_fn=Dataset.collate_fn,
-        )
-        dev = self._device
-        # move to device for eval; keep in float32 for stability
-        self.model.eval().to(dev).float()
-        iou_v = torch.linspace(0.5, 0.95, 10, device=dev)
-        n_iou = iou_v.numel()
-        metrics = []
-        for samples, targets in loader:
-            samples = samples.to(dev, non_blocking=True).float() / 255.0
-            _, _, h, w = samples.shape
-            scale = torch.tensor((w, h, w, h), device=dev)
-            outputs = self.model(samples)
-            outputs = util.non_max_suppression(outputs)
-            for i, output in enumerate(outputs):
-                idx = targets["idx"] == i
-                cls = targets["cls"][idx].to(dev)
-                box = targets["box"][idx].to(dev)
-                metric = torch.zeros((output.shape[0], n_iou), dtype=torch.bool, device=dev)
-                if output.shape[0] == 0:
-                    if cls.shape[0]:
-                        metrics.append((metric, *torch.zeros((2, 0), device=dev), cls.squeeze(-1)))
-                    continue
-                if cls.shape[0]:
-                    if cls.dim() == 1:
-                        cls = cls.unsqueeze(1)
-                    box_xy = util.wh2xy(box)
-                    if not isinstance(box_xy, torch.Tensor):
-                        box_xy = torch.tensor(box_xy, device=dev)
-                    target = torch.cat((cls, box_xy * scale), dim=1)
-                    metric = util.compute_metric(output[:, :6], target, iou_v)
-                metrics.append((metric, output[:, 4], output[:, 5], cls.squeeze(-1)))
-        if not metrics:
-            # move back to CPU before returning
-            self.model.to("cpu").float()
-            return {"mAP": 0, "mAP50": 0, "precision": 0, "recall": 0}
-        metrics = [torch.cat(x, dim=0).cpu().numpy() for x in zip(*metrics)]
-        if len(metrics) and metrics[0].any():
-            _, _, prec, rec, map50, mean_ap = util.compute_ap(*metrics, names=self.params["names"], plot=False)
-        else:
-            prec, rec, map50, mean_ap = 0, 0, 0, 0
-        # return model to CPU so next agg() stays device-consistent
-        self.model.to("cpu").float()
-        return {"mAP": float(mean_ap), "mAP50": float(map50), "precision": float(prec), "recall": float(rec)}
-
     def select_clients(self, connection_ratio=1.0):
         """
         Randomly select a fraction of clients.
@@ -130,80 +59,69 @@ class FedYoloServer(object):
         self.n_data = 0
         for client_id in self.client_list:
             # Random selection based on connection ratio
-            if np.random.rand() <= connection_ratio:
+            s = np.random.binomial(np.ones(1).astype(int), connection_ratio)
+            if s[0] == 1:
                 self.selected_clients.append(client_id)
-                self.n_data += self.client_n_data.get(client_id, 0)
+                self.n_data += self.client_n_data[client_id]

-    @torch.no_grad()
     def agg(self):
-        """Aggregate client updates (FedAvg) on CPU/FP32, preserving non-float buffers."""
+        """
+        Server aggregates the local updates from selected clients using FedAvg.
+        :return: model_state: aggregated model weights
+        :return: avg_loss: weighted average training loss across selected clients
+        :return: n_data: total number of data points across selected clients
+        """
         if len(self.selected_clients) == 0 or self.n_data == 0:
-            return self.model.state_dict(), {}, 0
-        # Ensure global model is on CPU for safe load later
-        self.model.to("cpu")
-        global_state = self.model.state_dict()
-
-        avg_loss = {}
-        total_n = float(self.n_data)
-
-        # Prepare accumulators on CPU. For floating tensors, use float32 zeros.
-        # For non-floating tensors (e.g., BN num_batches_tracked int64), we'll copy from the first client.
-        new_state = {}
-        first_client = None
-        for cid in self.selected_clients:
-            if cid in self.client_state:
-                first_client = cid
-                break
-        assert first_client is not None, "No client states available to aggregate."
-
-        for k, v in global_state.items():
-            if v.is_floating_point():
-                new_state[k] = torch.zeros_like(v.detach().cpu(), dtype=torch.float32)
-            else:
-                # For non-float buffers, just copy from the first client (or keep global)
-                new_state[k] = self.client_state[first_client][k].clone()
-
-        # Accumulate floating tensors with weights; keep non-floats as assigned above
-        for cid in self.selected_clients:
-            if cid not in self.client_state:
-                continue
-            weight = self.client_n_data[cid] / total_n
-            cst = self.client_state[cid]
-            for k in new_state.keys():
-                if new_state[k].is_floating_point():
-                    # cst[k] is CPU; ensure float32 for accumulation
-                    new_state[k].add_(cst[k].to(torch.float32), alpha=weight)
-
-            # weighted average losses
-            for lk, lv in self.client_loss[cid].items():
-                avg_loss[lk] = avg_loss.get(lk, 0.0) + float(lv) * weight
-
-        # Load aggregated state back into the global model (model is on CPU)
-        with torch.no_grad():
-            self.model.load_state_dict(new_state, strict=True)
+            import warnings
+            warnings.warn("No clients selected or no data available for aggregation.")
+            return self.model.state_dict(), 0, 0

+        # Initialize a model for aggregation
+        model = init_model(model_name=self.model_name, num_classes=self._num_classes)
+        model_state = model.state_dict()
+        avg_loss = 0
+
+        # Aggregate the local updated models from selected clients
+        for i, name in enumerate(self.selected_clients):
+            if name not in self.client_state:
+                continue
+            for key in self.client_state[name]:
+                if i == 0:
+                    # First client, initialize the model_state
+                    model_state[key] = self.client_state[name][key] * (self.client_n_data[name] / self.n_data)
+                else:
+                    # math equation: w = sum(n_k / n * w_k)
+                    model_state[key] = model_state[key] + self.client_state[name][key] * (
+                        self.client_n_data[name] / self.n_data
+                    )
+            avg_loss = avg_loss + self.client_loss[name] * (self.client_n_data[name] / self.n_data)
+
+        self.model.load_state_dict(model_state, strict=True)
         self.round += 1
-        # Return CPU state_dict (good for broadcasting to clients)
-        return {k: v.clone() for k, v in self.model.state_dict().items()}, avg_loss, int(self.n_data)
+        n_data = self.n_data
+        return model_state, avg_loss, n_data

-    def rec(self, name, state_dict, n_data, loss_dict):
+    def rec(self, name, state_dict, n_data, loss):
         """
         Receive local update from a client.
         - Store all floating tensors as CPU float32
         - Store non-floating tensors (e.g., BN counters) as CPU in original dtype
         """
         self.n_data += n_data
-        safe_state = {}
-        with torch.no_grad():
-            for k, v in state_dict.items():
-                t = v.detach().cpu()
-                if t.is_floating_point():
-                    t = t.to(torch.float32)
-                safe_state[k] = t
-        self.client_state[name] = safe_state
+        self.client_state[name] = {}
+        self.client_n_data[name] = {}
+        self.client_loss[name] = {}
+        self.client_state[name].update(state_dict)
         self.client_n_data[name] = int(n_data)
-        self.client_loss[name] = {k: float(v) for k, v in loss_dict.items()}
+        self.client_loss[name] = loss

     def flush(self):
         """Clear stored client updates."""
@@ -211,3 +129,94 @@ class FedYoloServer(object):
         self.client_state.clear()
         self.client_n_data.clear()
         self.client_loss.clear()
+
+    def test(self):
+        """Evaluate the global model on the server's validation dataset."""
+        if self.valset is None:
+            import warnings
+            warnings.warn("No validation dataset available for testing.")
+            return {}
+        return test(self.valset, self.params, self.model)
+
+
+@torch.no_grad()
+def test(valset: Dataset, params, model: YOLO, batch_size: int = 200) -> tuple[float, float, float, float]:
+    """
+    Evaluate the model on the validation dataset.
+    Args:
+        valset: validation dataset
+        params: dict of parameters (must include 'names')
+        model: YOLO model to evaluate
+        batch_size: batch size for evaluation
+    Returns:
+        tuple of evaluation metrics (mean_ap, map50, m_rec, m_pre)
+    """
+    loader = DataLoader(
+        dataset=valset,
+        batch_size=batch_size,
+        shuffle=False,
+        num_workers=4,
+        pin_memory=True,
+        collate_fn=Dataset.collate_fn,
+    )
+    model.cuda()
+    model.half()
+    model.eval()
+    # Configure
+    iou_v = torch.linspace(start=0.5, end=0.95, steps=10).cuda()  # iou vector for mAP@0.5:0.95
+    n_iou = iou_v.numel()
+    m_pre = 0
+    m_rec = 0
+    map50 = 0
+    mean_ap = 0
+    metrics = []
+    for samples, targets in loader:
+        samples = samples.cuda()
+        samples = samples.half()  # uint8 to fp16/32
+        samples = samples / 255.0  # 0 - 255 to 0.0 - 1.0
+        _, _, h, w = samples.shape  # batch-size, channels, height, width
+        scale = torch.tensor((w, h, w, h)).cuda()
+        # Inference
+        outputs = model(samples)
+        # NMS
+        outputs = util.non_max_suppression(outputs)
+        # Metrics
+        for i, output in enumerate(outputs):
+            idx = targets["idx"]
+            if idx.dim() > 1:
+                idx = idx.squeeze(-1)
+            idx = idx == i
+            # idx = targets["idx"] == i
+            cls = targets["cls"][idx]
+            box = targets["box"][idx]
+            cls = cls.cuda()
+            box = box.cuda()
+            metric = torch.zeros(output.shape[0], n_iou, dtype=torch.bool).cuda()
+            if output.shape[0] == 0:
+                if cls.shape[0]:
+                    metrics.append((metric, *torch.zeros((2, 0)).cuda(), cls.squeeze(-1)))
+                continue
+            # Evaluate
+            if cls.shape[0]:
+                target = torch.cat(tensors=(cls, util.wh2xy(box) * scale), dim=1)
+                metric = util.compute_metric(output[:, :6], target, iou_v)
+            # Append
+            metrics.append((metric, output[:, 4], output[:, 5], cls.squeeze(-1)))
+    # Compute metrics
+    metrics = [torch.cat(x, dim=0).cpu().numpy() for x in zip(*metrics)]  # to numpy
+    if len(metrics) and metrics[0].any():
+        tp, fp, m_pre, m_rec, map50, mean_ap = util.compute_ap(*metrics, plot=False, names=params["names"])
+    # Print results
+    # print(("%10s" + "%10.3g" * 4) % ("", m_pre, m_rec, map50, mean_ap))
+    # Return results
+    model.float()  # for training
+    return mean_ap, map50, m_rec, m_pre

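Note: agg() is plain FedAvg, w = sum_k (n_k / n) * w_k over client state dicts. A self-contained sketch of that weighting (hypothetical tensors, not this repo's model):

import torch

def fedavg(client_states: list[dict[str, torch.Tensor]], n_data: list[int]) -> dict[str, torch.Tensor]:
    """Weighted average of state dicts: w = sum_k (n_k / n) * w_k."""
    total = float(sum(n_data))
    avg: dict[str, torch.Tensor] = {}
    for state, n_k in zip(client_states, n_data):
        weight = n_k / total
        for key, tensor in state.items():
            contrib = tensor.float() * weight
            avg[key] = contrib if key not in avg else avg[key] + contrib
    return avg

# two clients holding constant weights 0 and 3, with 25 vs 75 samples:
states = [{"w": torch.zeros(2, 2)}, {"w": torch.full((2, 2), 3.0)}]
print(fedavg(states, [25, 75])["w"])  # every entry 0.25 * 0 + 0.75 * 3 = 2.25

The caveat from the removed implementation still applies: non-floating buffers such as BN num_batches_tracked cannot be meaningfully averaged — multiplying an int64 tensor by a Python float promotes it to float, and load_state_dict then silently truncates it back to int64 on copy.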
fed_run.py

@@ -3,92 +3,16 @@ import os
 import json
 import yaml
 import time
-import random
 from tqdm import tqdm
-import numpy as np
 import torch
-import matplotlib.pyplot as plt

-from utils.dataset import Dataset
+from utils.fed_util import build_valset_if_available, seed_everything, plot_curves
 from fed_algo_cs.client_base import FedYoloClient
 from fed_algo_cs.server_base import FedYoloServer
 from utils.args import args_parser  # args parser
 from utils.fed_util import divide_trainset  # divide_trainset

-
-def _read_list_file(txt_path: str):
-    """Read one path per line; keep as-is (absolute or relative)."""
-    if not txt_path or not os.path.exists(txt_path):
-        return []
-    with open(txt_path, "r", encoding="utf-8") as f:
-        return [ln.strip() for ln in f if ln.strip()]
-
-
-def _build_valset_if_available(cfg, params):
-    """
-    Try to build a validation Dataset.
-    - If cfg['val_txt'] exists, use it.
-    - Else if <dataset_path>/val.txt exists, use it.
-    - Else return None (testing will be skipped).
-    Args:
-        cfg: config dict
-        params: params dict for Dataset
-    Returns:
-        Dataset or None
-    """
-    input_size = int(cfg.get("input_size", 640))
-    val_txt = cfg.get("val_txt", "")
-    if not val_txt:
-        ds_root = cfg.get("dataset_path", "")
-        guess = os.path.join(ds_root, "val.txt") if ds_root else ""
-        val_txt = guess if os.path.exists(guess) else ""
-    val_files = _read_list_file(val_txt)
-    if not val_files:
-        return None
-    return Dataset(
-        filenames=val_files,
-        input_size=input_size,
-        params=params,
-        augment=True,
-    )
-
-
-def _seed_everything(seed: int):
-    np.random.seed(seed)
-    torch.manual_seed(seed)
-    random.seed(seed)
-
-
-def _plot_curves(save_dir, hist):
-    """
-    Plot mAP50-95, mAP50, precision, recall, and (optional) summed train loss per round.
-    """
-    os.makedirs(save_dir, exist_ok=True)
-    rounds = np.arange(1, len(hist["mAP"]) + 1)
-    plt.figure()
-    if hist["mAP"]:
-        plt.plot(rounds, hist["mAP"], label="mAP50-95")
-    if hist["mAP50"]:
-        plt.plot(rounds, hist["mAP50"], label="mAP50")
-    if hist["precision"]:
-        plt.plot(rounds, hist["precision"], label="precision")
-    if hist["recall"]:
-        plt.plot(rounds, hist["recall"], label="recall")
-    if hist["train_loss"]:
-        plt.plot(rounds, hist["train_loss"], label="train_loss (sum of components)")
-    plt.xlabel("Global Round")
-    plt.ylabel("Metric")
-    plt.title("Federated YOLO - Server Metrics")
-    plt.legend()
-    out_png = os.path.join(save_dir, "fed_yolo_curves.png")
-    plt.savefig(out_png, dpi=150, bbox_inches="tight")
-    print(f"[plot] saved: {out_png}")
-

 def fed_run():
     """
     Main FL process:
@@ -98,20 +22,22 @@ def fed_run():
     - Record & save results, plot curves
     """
     args_cli = args_parser()
+    # TODO: cfg and params should not be separately defined
     with open(args_cli.config, "r", encoding="utf-8") as f:
         cfg = yaml.safe_load(f)

     # --- params / config normalization ---
     # For convenience we pass the same `params` dict used by Dataset/model/loss.
     # Here we re-use the top-level cfg directly as params.
-    params = dict(cfg)
+    # params = dict(cfg)
     if "names" in cfg and isinstance(cfg["names"], dict):
         # Convert {0: 'uav', 1: 'car', ...} to list if you prefer list
         # but we can leave dict; your utils appear to accept dict
         pass

     # seeds
-    _seed_everything(int(cfg.get("i_seed", 0)))
+    seed_everything(int(cfg.get("i_seed", 0)))

     # --- split clients' train data from a global train list ---
     # Expect either cfg["train_txt"] or <dataset_path>/train.txt
@@ -144,13 +70,13 @@ def fed_run():
     clients = {}
     for uid in users:
-        c = FedYoloClient(name=uid, model_name=model_name, params=params)
+        c = FedYoloClient(name=uid, model_name=model_name, params=cfg)
         c.load_trainset(user_data[uid]["filename"])
         clients[uid] = c

     # --- build server & optional validation set ---
-    server = FedYoloServer(client_list=users, model_name=model_name, params=params)
-    valset = _build_valset_if_available(cfg, params)
+    server = FedYoloServer(client_list=users, model_name=model_name, params=cfg)
+    valset = build_valset_if_available(cfg, params=cfg, args=args_cli)
     # valset is a Dataset class, not data loader
     if valset is not None:
         server.load_valset(valset)
@@ -186,27 +112,25 @@ def fed_run():
         t0 = time.time()
         # Local training (sequential over all users)
         for uid in users:
-            # tqdm desc update
             p_bar.set_description_str(("%10s" * 2) % (f"{rnd + 1}/{num_round}", f"{uid}"))
             client = clients[uid]  # FedYoloClient instance
             client.update(global_state)  # load global weights
-            state_dict, n_data, loss_dict = client.train(args_cli)  # local training
-            server.rec(uid, state_dict, n_data, loss_dict)
+            state_dict, n_data, train_loss = client.train(args_cli)  # local training
+            server.rec(uid, state_dict, n_data, train_loss)

         # Select a fraction for aggregation (FedAvg subset if desired)
         server.select_clients(connection_ratio=connection_ratio)
         # Aggregate
-        global_state, avg_loss_dict, _ = server.agg()
+        global_state, avg_loss, _ = server.agg()

         # Compute a scalar train loss for plotting (sum of components)
-        scalar_train_loss = float(sum(avg_loss_dict.values())) if avg_loss_dict else 0.0
+        scalar_train_loss = avg_loss if avg_loss else 0.0

         # Test (if valset provided)
-        test_metrics = server.test(args_cli) if server.valset is not None else {}
-        mAP = float(test_metrics.get("mAP", 0.0))
-        mAP50 = float(test_metrics.get("mAP50", 0.0))
-        precision = float(test_metrics.get("precision", 0.0))
-        recall = float(test_metrics.get("recall", 0.0))
+        mAP, mAP50, recall, precision = server.test() if server.valset is not None else (0.0, 0.0, 0.0, 0.0)

         # Flush per-round client caches
         server.flush()
@@ -233,22 +157,23 @@ def fed_run():
         p_bar.set_postfix(desc)

         # Save running JSON (resumable logs)
-        save_name = (
-            f"[{cfg.get('fed_algo', 'FedAvg')},{cfg.get('model_name', 'yolo')},"
-            f"{cfg.get('num_local_epoch', cfg.get('client', {}).get('num_local_epoch', 1))},"
-            f"{cfg.get('num_local_class', 2)},"
-            f"{cfg.get('i_seed', 0)}]"
-        )
+        save_name = f"{cfg.get('fed_algo', 'FedAvg')}_{[cfg.get('model_name', 'yolo')]}_{cfg.get('num_client', 0)}c_{cfg.get('num_local_class', 1)}cls_{cfg.get('num_round', 0)}r_{cfg.get('connection_ratio', 1):.2f}cr_{cfg.get('i_seed', 0)}s"
         out_json = os.path.join(res_root, save_name + ".json")
         with open(out_json, "w", encoding="utf-8") as f:
-            json.dump(history, f, indent=2)
+            json.dump(history, f, indent=4)

         p_bar.update(1)
     p_bar.close()

+    # Save final global model weights
+    if not os.path.exists("./weights"):
+        os.makedirs("./weights", exist_ok=True)
+    torch.save(global_state, f"./weights/{save_name}_final.pth")
+    print(f"[save] final global model weights: ./weights/{save_name}_final.pth")
+
     # --- final plot ---
-    _plot_curves(res_root, history)
+    plot_curves(res_root, history, savename=f"{save_name}_curve.png")
     print("[done] training complete.")

fed_run.sh (new file)

@@ -0,0 +1,2 @@
GPUS=$1
python3 -m torch.distributed.run --nproc_per_node=$GPUS fed_run.py ${@:2}

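Note: fed_run.sh is a thin wrapper around torch.distributed.run; $1 is the GPU count and ${@:2} forwards everything else to fed_run.py. A typical invocation would be `bash fed_run.sh 2 --config ./config/coco128_cfg.yaml` (the --config flag and YAML path mirror the args_parser usage in testcode.py below; adjust to your setup).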
nets/__init__.py (new file)

@@ -0,0 +1,3 @@
from .nn import YOLO, yolo_v11_l, yolo_v11_m, yolo_v11_s, yolo_v11_t, yolo_v11_x, yolo_v11_n
__all__ = ["YOLO", "yolo_v11_l", "yolo_v11_m", "yolo_v11_s", "yolo_v11_t", "yolo_v11_x", "yolo_v11_n"]

testcode.py (new file)

@@ -0,0 +1,77 @@
from utils.fed_util import init_model
from fed_algo_cs.server_base import test
import os
import yaml
from utils.args import args_parser # args parser
from fed_algo_cs.client_base import FedYoloClient # FedYoloClient
from fed_algo_cs.server_base import FedYoloServer # FedYoloServer
from utils import Dataset # Dataset
if __name__ == "__main__":
    # model structure test
    model = init_model("yolo_v11_n", num_classes=1)
    with open("model.txt", "w", encoding="utf-8") as f:
        print(model, file=f)
    # loop over model key and values
    with open("model_key_value.txt", "w", encoding="utf-8") as f:
        for k, v in model.state_dict().items():
            print(f"{k}: {v.shape}", file=f)

    # test agg function
    from fed_algo_cs.server_base import FedYoloServer
    import torch
    import yaml

    with open("./config/coco128_cfg.yaml", "r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)
    params = dict(cfg)
    server = FedYoloServer(client_list=["c1", "c2", "c3"], model_name="yolo_v11_n", params=params)
    state1 = {k: torch.ones_like(v) for k, v in server.model.state_dict().items()}
    state2 = {k: torch.ones_like(v) * 2 for k, v in server.model.state_dict().items()}
    state3 = {k: torch.ones_like(v) * 3 for k, v in server.model.state_dict().items()}
    server.rec("c1", state1, n_data=20, loss=0.1)
    server.rec("c2", state2, n_data=30, loss=0.2)
    server.rec("c3", state3, n_data=50, loss=0.3)
    server.select_clients(connection_ratio=1.0)
    model_state, avg_loss, n_data = server.agg()
    with open("agg_model.txt", "w", encoding="utf-8") as f:
        for k, v in model_state.items():
            print(f"{k}: {v.float().mean()}", file=f)
    print(f"avg_loss: {avg_loss}, n_data: {n_data}")

    # test single client training (should be the same as standalone training)
    args = args_parser()
    with open(args.config, "r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)
    # params = dict(cfg)
    client = FedYoloClient(name="c1", params=cfg, model_name="yolo_v11_n")
    filenames = []
    data_dir = "/home/image1325/ssd1/dataset/COCO128"
    with open(f"{data_dir}/train.txt") as f:
        for filename in f.readlines():
            filename = os.path.basename(filename.rstrip())
            filenames.append(f"{data_dir}/images/train2017/" + filename)
    client.load_trainset(train_dataset=filenames)
    model_state, n_data, avg_loss = client.train(args=args)

    model = init_model("yolo_v11_n", num_classes=80)
    model.load_state_dict(model_state)
    valset = Dataset(
        filenames=filenames,
        input_size=640,
        params=cfg,
        augment=False,
    )
    if valset is not None:
        # note: test() returns (mean_ap, map50, m_rec, m_pre) — unpack in that order
        map, map50, recall, precision = test(valset=valset, params=cfg, model=model, batch_size=128)
        print(
            f"precision: {precision}, recall: {recall}, map50: {map50}, map: {map}, loss: {avg_loss}, n_data: {n_data}"
        )
    else:
        raise ValueError("valset is None, please provide a valid valset in config file.")

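Note: the agg sanity check above has a closed-form answer. With all-ones states scaled by 1, 2, 3 and data counts 20/30/50, FedAvg yields 0.2*1 + 0.3*2 + 0.5*3 = 2.3 for every floating entry (and avg_loss = 0.2*0.1 + 0.3*0.2 + 0.5*0.3 = 0.23), so agg_model.txt should show means of 2.3 throughout; integer buffers such as BN step counters will deviate, since they are truncated when loaded back into the model.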
utils/__init__.py (new file)

@@ -0,0 +1,3 @@
from .dataset import *
from .fed_util import *
from .util import *

utils/dataset.py

@@ -8,16 +8,11 @@ import torch
 from PIL import Image
 from torch.utils import data

-FORMATS = "bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "JPEG", "JPG", "PNG", "TIFF"
+FORMATS = "bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp"


 class Dataset(data.Dataset):
-    params: dict
-    mosaic: bool
-    augment: bool
-    input_size: int
-
-    def __init__(self, filenames, input_size: int, params: dict, augment: bool):
+    def __init__(self, filenames, input_size, params, augment):
         self.params = params
         self.mosaic = augment
         self.augment = augment
@@ -48,8 +43,6 @@ class Dataset(data.Dataset):
         else:
             # Load image
             image, shape = self.load_image(index)
-            if image is None:
-                raise ValueError(f"Failed to load image at index {index}: {self.filenames[index]}")
             h, w = image.shape[:2]

             # Resize
@@ -57,7 +50,7 @@
             label = self.labels[index].copy()
             if label.size:
-                label[:, 1:] = wh2xy(label[:, 1:], ratio[0] * w, ratio[1] * h, int(pad[0]), int(pad[1]))
+                label[:, 1:] = wh2xy(label[:, 1:], ratio[0] * w, ratio[1] * h, pad[0], pad[1])
             if self.augment:
                 image, label = random_perspective(image, label, self.params)
@@ -84,25 +77,24 @@
             if nl:
                 box[:, 0] = 1 - box[:, 0]

-        # target_cls = torch.zeros((nl, 1))
-        # target_box = torch.zeros((nl, 4))
-        # if nl:
-        #     target_cls = torch.from_numpy(cls)
-        #     target_box = torch.from_numpy(box)
-        # fix [cls, box] empty bug. e.g. [0,1] is illegal in DataLoader collate_fn cat operation
+        # XXX: when nl=0, torch.from_numpy(box) will error
         if nl:
             target_cls = torch.from_numpy(cls).view(-1, 1).float()  # always (N,1)
             target_box = torch.from_numpy(box).reshape(-1, 4).float()  # always (N,4)
         else:
             target_cls = torch.zeros((0, 1), dtype=torch.float32)
             target_box = torch.zeros((0, 4), dtype=torch.float32)
+        # target_cls = torch.zeros((nl, 1))
+        # target_box = torch.zeros((nl, 4))
+        # if nl:
+        #     target_cls = torch.from_numpy(cls)
+        #     target_box = torch.from_numpy(box)

         # Convert HWC to CHW, BGR to RGB
         sample = image.transpose((2, 0, 1))[::-1]
         sample = numpy.ascontiguousarray(sample)

-        # init: return torch.from_numpy(sample), target_cls, target_box, torch.zeros(nl)
+        # return torch.from_numpy(sample), target_cls, target_box, torch.zeros(nl)
         return torch.from_numpy(sample), target_cls, target_box, torch.zeros((nl, 1), dtype=torch.long)

     def __len__(self):
@@ -111,7 +103,7 @@ class Dataset(data.Dataset):
     def load_image(self, i):
         image = cv2.imread(self.filenames[i])
         if image is None:
-            raise ValueError(f"Image not found or unable to open: {self.filenames[i]}")
+            raise FileNotFoundError(f"Image Not Found {self.filenames[i]}")
         h, w = image.shape[:2]
         r = self.input_size / max(h, w)
         if r != 1:
@@ -173,8 +165,8 @@ class Dataset(data.Dataset):
             x2b = min(shape[1], x2a - x1a)
             y2b = min(y2a - y1a, shape[0])

-            pad_w = (x1a if x1a is not None else 0) - (x1b if x1b is not None else 0)
-            pad_h = (y1a if y1a is not None else 0) - (y1b if y1b is not None else 0)
+            pad_w = x1a - x1b
+            pad_h = y1a - y1b
             image4[y1a:y2a, x1a:x2a] = image[y1b:y2b, x1b:x2b]

             # Labels
@@ -197,14 +189,8 @@ class Dataset(data.Dataset):
     def collate_fn(batch):
         samples, cls, box, indices = zip(*batch)

-        # ensure empty tensor shape is correct
-        cls = [c.view(-1, 1) for c in cls]
-        box = [b.reshape(-1, 4) for b in box]
-        indices = [i for i in indices]
-        cls = torch.cat(cls, dim=0) if cls else torch.zeros((0, 1))
-        box = torch.cat(box, dim=0) if box else torch.zeros((0, 4))
-        indices = torch.cat(indices, dim=0) if indices else torch.zeros((0,), dtype=torch.long)
+        cls = torch.cat(cls, dim=0)
+        box = torch.cat(box, dim=0)

         new_indices = list(indices)
         for i in range(len(indices)):
@@ -215,7 +201,7 @@ class Dataset(data.Dataset):
         return torch.stack(samples, dim=0), targets

     @staticmethod
-    def load_label_use_cache(filenames):
+    def load_label(filenames):
         path = f"{os.path.dirname(filenames[0])}.cache"
         if os.path.exists(path):
             return torch.load(path, weights_only=False)
@@ -228,14 +214,11 @@ class Dataset(data.Dataset):
                 image.verify()  # PIL verify
                 shape = image.size  # image size
                 assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
-                assert image.format is not None and image.format.lower() in FORMATS, (
-                    f"invalid image format {image.format}"
-                )
+                assert image.format.lower() in FORMATS, f"invalid image format {image.format}"

                 # verify labels
                 a = f"{os.sep}images{os.sep}"
                 b = f"{os.sep}labels{os.sep}"
                 if os.path.isfile(b.join(filename.rsplit(a, 1)).rsplit(".", 1)[0] + ".txt"):
                     with open(b.join(filename.rsplit(a, 1)).rsplit(".", 1)[0] + ".txt") as f:
                         label = [x.split() for x in f.read().strip().splitlines() if len(x)]
@@ -260,50 +243,6 @@ class Dataset(data.Dataset):
             torch.save(x, path)
         return x

-    @staticmethod
-    def load_label(filenames):
-        x = {}
-        for filename in filenames:
-            try:
-                # verify images
-                with open(filename, "rb") as f:
-                    image = Image.open(f)
-                    image.verify()
-                    shape = image.size
-                    assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
-                    assert image.format is not None and image.format.lower() in FORMATS, (
-                        f"invalid image format {image.format}"
-                    )
-                # verify labels
-                a = f"{os.sep}images{os.sep}"
-                b = f"{os.sep}labels{os.sep}"
-                label_path = b.join(filename.rsplit(a, 1)).rsplit(".", 1)[0] + ".txt"
-                if os.path.isfile(label_path):
-                    rows = []
-                    with open(label_path) as f:
-                        for line in f:
-                            parts = line.strip().split()
-                            if len(parts) == 5:  # YOLO format
-                                rows.append([float(x) for x in parts])
-                    label = numpy.array(rows, dtype=numpy.float32) if rows else numpy.zeros((0, 5), dtype=numpy.float32)
-                    if label.shape[0]:
-                        assert (label >= 0).all()
-                        assert label.shape[1] == 5
-                        assert (label[:, 1:] <= 1.0001).all()
-                        _, i = numpy.unique(label, axis=0, return_index=True)
-                        label = label[i]
-                else:
-                    label = numpy.zeros((0, 5), dtype=numpy.float32)
-            except (FileNotFoundError, AssertionError):
-                label = numpy.zeros((0, 5), dtype=numpy.float32)
-            x[filename] = label
-        return x
-

 def wh2xy(x, w=640, h=640, pad_w=0, pad_h=0):
     # Convert nx4 boxes
@@ -462,9 +401,7 @@ class Albumentations:
                 albumentations.ToGray(p=0.01),
                 albumentations.MedianBlur(p=0.01),
             ]
-            self.transform = albumentations.Compose(
-                transforms, albumentations.BboxParams(format="yolo", label_fields=["class_labels"])
-            )
+            self.transform = albumentations.Compose(transforms, albumentations.BboxParams("yolo", ["class_labels"]))
         except ImportError:  # package not installed, skip
             pass

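Note: the XXX comment and the (N,1)/(N,4) reshapes in __getitem__ guard the empty-label case — collate_fn concatenates per-image class/box tensors, so every image must contribute tensors with the same trailing dimensions even when it has no boxes. A condensed sketch of that guard, with hypothetical label arrays:

import numpy as np
import torch

def to_targets(cls: np.ndarray, box: np.ndarray):
    """Always return (N, 1) / (N, 4) float tensors, even when there are no labels."""
    nl = len(cls)
    if nl:
        return (torch.from_numpy(cls).view(-1, 1).float(),
                torch.from_numpy(box).reshape(-1, 4).float())
    return torch.zeros((0, 1)), torch.zeros((0, 4))

# a batch of two images, the second without labels: cat stays well-shaped
c1, b1 = to_targets(np.array([3.0, 7.0], dtype=np.float32), np.ones((2, 4), dtype=np.float32))
c2, b2 = to_targets(np.zeros(0, dtype=np.float32), np.zeros((0, 4), dtype=np.float32))
print(torch.cat([c1, c2]).shape, torch.cat([b1, b2]).shape)  # torch.Size([2, 1]) torch.Size([2, 4])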
utils/fed_util.py

@@ -1,10 +1,15 @@
 import os
 import re
 import random
+import matplotlib.pyplot as plt
+from utils.dataset import Dataset
+import numpy as np
+import torch
 from collections import defaultdict
 from typing import Dict, List, Optional, Set, Any
 from nets import nn
+from nets import YOLO


 def _image_to_label_path(img_path: str) -> str:
@@ -59,6 +64,14 @@ def _parse_yolo_label_file(label_path: str) -> Set[int]:
     return class_ids


+def _read_list_file(txt_path: str):
+    """Read one path per line; keep as-is (absolute or relative)."""
+    if not txt_path or not os.path.exists(txt_path):
+        return []
+    with open(txt_path, "r", encoding="utf-8") as f:
+        return [ln.strip() for ln in f if ln.strip()]
+
+
 def divide_trainset(
     trainset_path: str,
     num_local_class: int,
@@ -230,7 +243,7 @@ def divide_trainset(
     return result


-def init_model(model_name, num_classes):
+def init_model(model_name, num_classes) -> YOLO:
     """
     Initialize the model for a specific learning task
     Args:
@@ -252,3 +265,74 @@ def init_model(model_name, num_classes):
         raise ValueError("Model {} is not supported.".format(model_name))
     return model
+
+
+def build_valset_if_available(cfg, params, args=None) -> Optional[Dataset]:
+    """
+    Try to build a validation Dataset.
+    - If cfg['val_txt'] exists, use it.
+    - Else if <dataset_path>/val.txt exists, use it.
+    - Else return None (testing will be skipped).
+    Args:
+        cfg: config dict
+        params: params dict for Dataset
+        args: optional CLI args; input_size is read from it when present (default 640)
+    Returns:
+        Dataset or None
+    """
+    input_size = args.input_size if args and hasattr(args, "input_size") else 640
+    val_txt = cfg.get("val_txt", "")
+    if not val_txt:
+        ds_root = cfg.get("dataset_path", "")
+        guess = os.path.join(ds_root, "val.txt") if ds_root else ""
+        val_txt = guess if os.path.exists(guess) else ""
+    val_files = _read_list_file(val_txt)
+    if not val_files:
+        import warnings
+        warnings.warn("No validation dataset found.")
+        return None
+    return Dataset(
+        filenames=val_files,
+        input_size=input_size,
+        params=params,
+        augment=True,
+    )
+
+
+def seed_everything(seed: int):
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    random.seed(seed)
+
+
+def plot_curves(save_dir, hist, savename="fed_yolo_curves.png"):
+    """
+    Plot mAP50-95, mAP50, precision, recall, and (optional) summed train loss per round.
+    Args:
+        save_dir: directory to save the plot
+        hist: history dict with keys "mAP", "mAP50", "precision", "recall", "train_loss"
+        savename: output filename
+    """
+    os.makedirs(save_dir, exist_ok=True)
+    rounds = np.arange(1, len(hist["mAP"]) + 1)
+    plt.figure()
+    if hist["mAP"]:
+        plt.plot(rounds, hist["mAP"], label="mAP50-95")
+    if hist["mAP50"]:
+        plt.plot(rounds, hist["mAP50"], label="mAP50")
+    if hist["precision"]:
+        plt.plot(rounds, hist["precision"], label="precision")
+    if hist["recall"]:
+        plt.plot(rounds, hist["recall"], label="recall")
+    if hist["train_loss"]:
+        plt.plot(rounds, hist["train_loss"], label="train_loss (sum of components)")
+    plt.xlabel("Global Round")
+    plt.ylabel("Metric")
+    plt.title("Federated YOLO - Server Metrics")
+    plt.legend()
+    out_png = os.path.join(save_dir, savename)
+    plt.savefig(out_png, dpi=150, bbox_inches="tight")
+    print(f"[plot] saved: {out_png}")

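Note: build_valset_if_available resolves the validation list in a fixed order — cfg['val_txt'] wins, then <dataset_path>/val.txt, else None and testing is skipped. A condensed sketch of just that lookup (paths hypothetical):

import os

def resolve_val_txt(cfg: dict) -> str:
    """Return the val-list path, or '' if none is available."""
    val_txt = cfg.get("val_txt", "")
    if not val_txt:
        root = cfg.get("dataset_path", "")
        guess = os.path.join(root, "val.txt") if root else ""
        val_txt = guess if os.path.exists(guess) else ""
    return val_txt

print(resolve_val_txt({"dataset_path": "/data/COCO128"}))  # "/data/COCO128/val.txt" if present, else ""

One oddity worth flagging: the valset here is built with augment=True, while testcode.py evaluates with augment=False; evaluation is normally run without augmentation.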
utils/util.py

@@ -1,7 +1,3 @@
-"""
-Utility functions for yolo.
-"""
 import copy
 import random
 from time import time
@@ -97,7 +93,7 @@ def make_anchors(x, strides, offset=0.5):
         _, _, h, w = x[i].shape
         sx = torch.arange(end=w, device=device, dtype=dtype) + offset  # shift x
         sy = torch.arange(end=h, device=device, dtype=dtype) + offset  # shift y
-        sy, sx = torch.meshgrid(sy, sx, indexing="ij")
+        sy, sx = torch.meshgrid(sy, sx)
         anchor_tensor.append(torch.stack((sx, sy), -1).view(-1, 2))
         stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device))
     return torch.cat(anchor_tensor), torch.cat(stride_tensor)
@@ -151,7 +147,7 @@ def non_max_suppression(outputs, confidence_threshold=0.001, iou_threshold=0.65)
         box = wh2xy(box)  # (cx, cy, w, h) to (x1, y1, x2, y2)
         if nc > 1:
             i, j = (cls > confidence_threshold).nonzero(as_tuple=False).T
-            x = torch.cat((box[i], x[i, 4 + j].unsqueeze(1), j[:, None].float()), dim=1)
+            x = torch.cat((box[i], x[i, 4 + j, None], j[:, None].float()), 1)
         else:  # best class only
             conf, j = cls.max(1, keepdim=True)
             x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > confidence_threshold]
@@ -195,13 +191,7 @@ def plot_pr_curve(px, py, ap, names, save_dir):
     else:
         ax.plot(px, py, linewidth=1, color="grey")  # plot(recall, precision)

-    ax.plot(
-        px,
-        py.mean(1),
-        linewidth=3,
-        color="blue",
-        label="all classes %.3f mAP@0.5" % ap[:, 0].mean(),
-    )
+    ax.plot(px, py.mean(1), linewidth=3, color="blue", label="all classes %.3f mAP@0.5" % ap[:, 0].mean())
     ax.set_xlabel("Recall")
     ax.set_ylabel("Precision")
     ax.set_xlim(0, 1)
@@ -224,13 +214,7 @@ def plot_curve(px, py, names, save_dir, x_label="Confidence", y_label="Metric"):
         ax.plot(px, py.T, linewidth=1, color="grey")  # plot(confidence, metric)

     y = smooth(py.mean(0), f=0.05)
-    ax.plot(
-        px,
-        y,
-        linewidth=3,
-        color="blue",
-        label=f"all classes {y.max():.3f} at {px[y.argmax()]:.3f}",
-    )
+    ax.plot(px, y, linewidth=3, color="blue", label=f"all classes {y.max():.3f} at {px[y.argmax()]:.3f}")
     ax.set_xlabel(x_label)
     ax.set_ylabel(y_label)
     ax.set_xlim(0, 1)
@@ -296,8 +280,7 @@ def compute_ap(tp, conf, output, target, plot=False, names=(), eps=1e-16):
             # Integrate area under curve
             x = numpy.linspace(start=0, stop=1, num=101)  # 101-point interp (COCO)
-            # numpy.trapz is deprecated in numpy 2.0.0 or after version, use numpy.trapezoid instead
-            ap[ci, j] = numpy.trapezoid(numpy.interp(x, m_rec, m_pre), x)  # integrate
+            ap[ci, j] = numpy.trapz(numpy.interp(x, m_rec, m_pre), x)  # integrate
             if plot and j == 0:
                 py.append(numpy.interp(px, m_rec, m_pre))  # precision at mAP@0.5
@@ -443,7 +426,7 @@ class LinearLR:
         min_lr = params["min_lr"]
         warmup_steps = int(max(params["warmup_epochs"] * num_steps, 100))
-        decay_steps = max(1, int(args.epochs * num_steps - warmup_steps))
+        decay_steps = int(args.epochs * num_steps - warmup_steps)

         warmup_lr = numpy.linspace(min_lr, max_lr, int(warmup_steps), endpoint=False)
         decay_lr = numpy.linspace(max_lr, min_lr, decay_steps)
@@ -528,16 +511,8 @@ class Assigner(torch.nn.Module):
         mask_in_gts = mask_in_gts.view(shape[0], shape[1], num_anchors, -1).amin(3).gt_(self.eps)
         na = pd_bboxes.shape[-2]
         gt_mask = (mask_in_gts * mask_gt).bool()  # b, max_num_obj, h*w
-        overlaps = torch.zeros(
-            [batch_size, num_max_boxes, na],
-            dtype=pd_bboxes.dtype,
-            device=pd_bboxes.device,
-        )
-        bbox_scores = torch.zeros(
-            [batch_size, num_max_boxes, na],
-            dtype=pd_scores.dtype,
-            device=pd_scores.device,
-        )
+        overlaps = torch.zeros([batch_size, num_max_boxes, na], dtype=pd_bboxes.dtype, device=pd_bboxes.device)
+        bbox_scores = torch.zeros([batch_size, num_max_boxes, na], dtype=pd_scores.dtype, device=pd_scores.device)

         ind = torch.zeros([2, batch_size, num_max_boxes], dtype=torch.long)  # 2, b, max_num_obj
         ind[0] = torch.arange(end=batch_size).view(-1, 1).expand(-1, num_max_boxes)  # b, max_num_obj
@@ -587,9 +562,7 @@
         target_labels.clamp_(0)

         target_scores = torch.zeros(
-            (target_labels.shape[0], target_labels.shape[1], self.nc),
-            dtype=torch.int64,
-            device=target_labels.device,
+            (target_labels.shape[0], target_labels.shape[1], self.nc), dtype=torch.int64, device=target_labels.device
         )
         target_scores.scatter_(2, target_labels.unsqueeze(-1), 1)
@@ -672,16 +645,7 @@
         super().__init__()
         self.dfl_ch = dfl_ch

-    def forward(
-        self,
-        pred_dist,
-        pred_bboxes,
-        anchor_points,
-        target_bboxes,
-        target_scores,
-        target_scores_sum,
-        fg_mask,
-    ):
+    def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
         # IoU loss
         weight = torch.masked_select(target_scores.sum(-1), fg_mask).unsqueeze(-1)
         iou = compute_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask])
@@ -803,13 +767,7 @@ class ComputeLoss:
         if fg_mask.sum():
             target_bboxes /= stride_tensor
             loss_box, loss_dfl = self.box_loss(
-                pred_distri,
-                pred_bboxes,
-                anchor_points,
-                target_bboxes,
-                target_scores,
-                target_scores_sum,
-                fg_mask,
+                pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
             )
             loss_box *= self.params["box"]  # box gain
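
Note: the LinearLR change drops the max(1, ...) floor on decay_steps. The schedule itself is a linear warmup from min_lr to max_lr over max(warmup_epochs * num_steps, 100) steps, followed by a linear decay back to min_lr. A standalone sketch (names mirror the diff; numbers illustrative):

import numpy

def linear_lr_table(epochs: int, num_steps: int, min_lr=1e-4, max_lr=1e-2, warmup_epochs=3):
    warmup_steps = int(max(warmup_epochs * num_steps, 100))
    # without the removed max(1, ...) guard, runs with epochs * num_steps <= warmup_steps
    # give decay_steps <= 0: an empty decay segment (== 0) or a ValueError from linspace (< 0)
    decay_steps = int(epochs * num_steps - warmup_steps)
    warmup = numpy.linspace(min_lr, max_lr, warmup_steps, endpoint=False)
    decay = numpy.linspace(max_lr, min_lr, decay_steps)
    return numpy.concatenate([warmup, decay])

lrs = linear_lr_table(epochs=10, num_steps=100)
print(len(lrs), lrs[0], lrs.max(), lrs[-1])  # 1000 0.0001 0.01 0.0001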