Optimize the FedYoloClient and FedYoloServer classes
@@ -3,11 +3,11 @@ import torch
 from torch import nn
 from torch.utils import data
 from torch.amp.autocast_mode import autocast
-from tqdm import tqdm
 from utils.fed_util import init_model
 from utils import util
 from utils.dataset import Dataset
 from typing import cast
+from tqdm import tqdm


 class FedYoloClient(object):
@@ -82,52 +82,48 @@ class FedYoloClient(object):
         # load the global model parameters
         self.model.load_state_dict(Global_model_state_dict, strict=True)

-    def train(self, args):
+    def train(self, args) -> tuple[dict[str, torch.Tensor], int, float]:
         """
-        Train the local model
-        Args:
-        :param args: Command line arguments
-            - local_rank: Local rank for distributed training
-            - world_size: World size for distributed training
-            - distributed: Whether to use distributed training
-            - input_size: Input size for the model
-        Returns:
-        :return: Local updated model, number of local data points, training loss
+        Train the local model.
+        Returns: (state_dict, n_data, avg_loss_per_image)
         """

+        # ---- Dist init (if any) ----
         if args.distributed:
             torch.cuda.set_device(device=args.local_rank)
             torch.distributed.init_process_group(backend="nccl", init_method="env://")
-            # print(f"Client {self.name} - distributed training on {world_size} GPUs, local rank: {local_rank}")
-            # self._device = torch.device("cuda", local_rank)

-        if args.local_rank == 0:
-            pass
-            # if not os.path.exists("weights"):
-            #     os.makedirs("weights")

         util.setup_seed()
         util.setup_multi_processes()

-        # model
-        # init model have been done in __init__()
-        self.model.to(self._device)
+        # device = torch.device(f"cuda:{args.local_rank}" if torch.cuda.is_available() else "cpu")
+        # self.model.to(device)
+        self.model.cuda()
+        # show model architecture
+        # print(self.model)

-        # Optimizer
-        accumulate = max(round(64 / (self._batch_size * args.world_size)), 1)
-        self._weight_decay = self._batch_size * args.world_size * accumulate / 64
-
+        # ---- Optimizer / WD scaling & LR warmup/schedule ----
+        # accumulate = effective grad-accumulation steps to emulate global batch 64
+        world_size = getattr(args, "world_size", 1)
+        accumulate = max(round(64 / (self._batch_size * max(world_size, 1))), 1)
+        # scale weight_decay like YOLO recipes
+        scaled_wd = self._weight_decay * self._batch_size * max(world_size, 1) * accumulate / 64
         optimizer = torch.optim.SGD(
-            util.set_params(self.model, self._weight_decay),
+            util.set_params(self.model, scaled_wd),
             lr=self._min_lr,
             momentum=self._momentum,
             nesterov=True,
         )

-        # EMA
+        # ---- EMA (track the underlying module if DDP) ----
+        # track_model = self.model.module if is_ddp else self.model
         ema = util.EMA(self.model) if args.local_rank == 0 else None

-        data_set = Dataset(
+        print(type(self.train_dataset))
+
+        # ---- Data ----
+        dataset = Dataset(
             filenames=self.train_dataset,
             input_size=args.input_size,
             params=self.params,
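A note on the optimizer hunk above: the new code emulates a global batch of 64 through gradient accumulation and rescales weight decay by the effective global batch, as in common YOLO training recipes. A minimal standalone sketch of that arithmetic with illustrative values (the concrete `batch_size`, `world_size`, and base weight decay below are assumptions, not values from this repository):

```python
# Illustrative numbers only; the real values come from the params file / CLI args.
batch_size = 8   # assumed per-GPU batch size
world_size = 2   # assumed number of GPUs
base_wd = 5e-4   # assumed base weight decay

# Micro-batches to accumulate so that batch_size * world_size * accumulate ~= 64
accumulate = max(round(64 / (batch_size * max(world_size, 1))), 1)   # -> 4

# Weight decay scaled by the effective global batch relative to 64
scaled_wd = base_wd * batch_size * max(world_size, 1) * accumulate / 64   # -> 5e-4

print(accumulate, scaled_wd)
```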
@@ -136,26 +132,28 @@ class FedYoloClient(object):

         if args.distributed:
             train_sampler = data.DistributedSampler(
-                data_set, num_replicas=args.world_size, rank=args.local_rank, shuffle=True
+                dataset, num_replicas=args.world_size, rank=args.local_rank, shuffle=True
             )
         else:
             train_sampler = None

         loader = data.DataLoader(
-            data_set,
+            dataset,
             batch_size=self._batch_size,
-            shuffle=train_sampler is None,
+            shuffle=(train_sampler is None),
             sampler=train_sampler,
             num_workers=self.num_workers,
             pin_memory=True,
             collate_fn=Dataset.collate_fn,
+            drop_last=False,
         )

-        # Scheduler
         num_steps = max(1, len(loader))
         scheduler = util.LinearLR(args=args, params=self.params, num_steps=num_steps)
-        # DDP mode
-        if args.distributed:
+
+        # ---- SyncBN + DDP (if any) ----
+        is_ddp = bool(args.distributed)
+        if is_ddp:
             self.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.model)
             self.model = nn.parallel.DistributedDataParallel(
                 module=self.model,
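For context on the SyncBN/DDP wrapping that starts here and continues in the next hunk: BatchNorm layers are converted to SyncBatchNorm first, and only then is the model wrapped in DistributedDataParallel. A minimal sketch of that order, assuming the process group was already initialised (as in the `init_process_group` call earlier); the helper name and the `device_ids`/`output_device` arguments are illustrative assumptions, not code from this repository:

```python
import torch
from torch import nn


def wrap_for_ddp(model: nn.Module, local_rank: int, distributed: bool) -> nn.Module:
    """Convert BN -> SyncBN, then wrap in DDP; a no-op for single-GPU runs."""
    if not distributed:
        return model
    model = nn.SyncBatchNorm.convert_sync_batchnorm(model)  # sync BN statistics across ranks
    return nn.parallel.DistributedDataParallel(
        module=model.cuda(local_rank),
        device_ids=[local_rank],
        output_device=local_rank,
        find_unused_parameters=False,
    )
```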
@@ -164,102 +162,133 @@ class FedYoloClient(object):
                 find_unused_parameters=False,
             )

-        amp_scale = torch.amp.grad_scaler.GradScaler(enabled=True)
+        # ---- AMP + loss ----
+        scaler = torch.amp.grad_scaler.GradScaler(enabled=True)
+        # criterion = util.ComputeLoss(
+        #     self.model.module if isinstance(self.model, nn.parallel.DistributedDataParallel) else self.model,
+        #     self.params,
+        # )
         criterion = util.ComputeLoss(self.model, self.params)

-        # log
-        # if args.local_rank == 0:
-        #     header = ("%10s" * 5) % ("client", "memory", "box", "cls", "dfl")
-        #     print("\n" + header)
-        #     p_bar = tqdm(total=args.epochs * num_steps, ncols=120)
-        #     p_bar.set_description(f"{self.name:>10}")
+        # ---- Training ----

         for epoch in range(args.epochs):
+            # (self.model.module if isinstance(self.model, nn.parallel.DistributedDataParallel) else self.model).train()
             self.model.train()
-            # when distributed, set epoch for shuffling
-            if args.distributed and train_sampler is not None:
+            if is_ddp and train_sampler is not None:
                 train_sampler.set_epoch(epoch)

-            if args.epochs - epoch == 10:
-                # disable mosaic augmentation in the last 10 epochs
+            # disable mosaic in the last 10 epochs (if dataset supports it)
+            if args.epochs - epoch == 10 and hasattr(loader.dataset, "mosaic"):
                 ds = cast(Dataset, loader.dataset)
                 ds.mosaic = False

             optimizer.zero_grad(set_to_none=True)
-            avg_box_loss = util.AverageMeter()
-            avg_cls_loss = util.AverageMeter()
-            avg_dfl_loss = util.AverageMeter()
+            loss_box_meter = util.AverageMeter()
+            loss_cls_meter = util.AverageMeter()
+            loss_dfl_meter = util.AverageMeter()

-            # # --- header (once per epoch, YOLO-style) ---
-            # if args.local_rank == 0:
-            #     header = ("%10s" * 5) % ("client", "memory", "box", "cls", "dfl")
-            #     print("\n" + header)
+            for i, (images, targets) in enumerate(loader):
+                print(f"Client {self.name} - Epoch {epoch + 1}/{args.epochs} - Step {i + 1}/{num_steps}")
+                step = i + epoch * num_steps

-            # p_bar = enumerate(loader)
-            # if args.local_rank == 0:
-            #     p_bar = tqdm(p_bar, total=num_steps, ncols=120)
+                # scheduler per-step (your util.LinearLR expects step)
+                scheduler.step(step=step, optimizer=optimizer)

-            for i, (samples, targets) in enumerate(loader):
-                global_step = i + num_steps * epoch
-                scheduler.step(step=global_step, optimizer=optimizer)
+                # images = images.to(device, non_blocking=True).float() / 255.0
+                images = images.cuda().float() / 255.0
+                bs = images.size(0)
+                # total_imgs_seen += bs

-                samples = samples.cuda(non_blocking=True).float() / 255.0
+                # targets: keep as your ComputeLoss expects (often CPU lists/tensors).
+                # Move to GPU here only if your loss requires it.

-                # Forward
-                with autocast("cuda", enabled=True):
-                    outputs = self.model(samples)
+                with autocast(device_type="cuda", enabled=True):
+                    outputs = self.model(images)  # DDP wraps forward
                     box_loss, cls_loss, dfl_loss = criterion(outputs, targets)

-                # meters (use the *unscaled* values)
-                bs = samples.size(0)
-                avg_box_loss.update(box_loss.item(), bs)
-                avg_cls_loss.update(cls_loss.item(), bs)
-                avg_dfl_loss.update(dfl_loss.item(), bs)
+                # total_loss = box_loss + cls_loss + dfl_loss
+                # Gradient accumulation: normalize by 'accumulate' so LR stays effective
+                # total_loss = total_loss / accumulate

-                # scale losses by batch/world if your loss is averaged internally per-sample/device
-                # box_loss = box_loss * self._batch_size * args.world_size
-                # cls_loss = cls_loss * self._batch_size * args.world_size
-                # dfl_loss = dfl_loss * self._batch_size * args.world_size
+                # IMPORTANT: assume criterion returns **average per image** in the batch.
+                # Keep logging on the true (unscaled) values:
+                loss_box_meter.update(box_loss.item(), bs)
+                loss_cls_meter.update(cls_loss.item(), bs)
+                loss_dfl_meter.update(dfl_loss.item(), bs)

+                box_loss *= self._batch_size
+                cls_loss *= self._batch_size
+                dfl_loss *= self._batch_size
+                box_loss *= args.world_size
+                cls_loss *= args.world_size
+                dfl_loss *= args.world_size
                 total_loss = box_loss + cls_loss + dfl_loss

-                # Backward
-                amp_scale.scale(total_loss).backward()
+                scaler.scale(total_loss).backward()

-                # Optimize
-                if (i + 1) % accumulate == 0:
-                    amp_scale.unscale_(optimizer)  # unscale gradients
-                    util.clip_gradients(model=self.model, max_norm=10.0)  # clip gradients
-                    amp_scale.step(optimizer)
-                    amp_scale.update()
+                # optimize
+                if step % accumulate == 0:
+                    # scaler.unscale_(optimizer)
+                    # util.clip_gradients(self.model)
+                    scaler.step(optimizer)
+                    scaler.update()
                     optimizer.zero_grad(set_to_none=True)
-                    if ema:
-                        ema.update(self.model)

-                # torch.cuda.synchronize()
-                # tqdm update
-                # if args.local_rank == 0:
-                #     mem = f"{torch.cuda.memory_reserved() / 1e9:.2f}G" if torch.cuda.is_available() else "0.00G"
-                #     desc = ("%10s" * 2 + "%10.4g" * 3) % (
-                #         self.name,
-                #         mem,
-                #         avg_box_loss.avg,
-                #         avg_cls_loss.avg,
-                #         avg_dfl_loss.avg,
+                # # Step when we have 'accumulate' micro-batches, or at the end
+                # if ((i + 1) % accumulate == 0) or (i + 1 == len(loader)):
+                #     scaler.unscale_(optimizer)
+                #     util.clip_gradients(
+                #         model=(
+                #             self.model.module
+                #             if isinstance(self.model, nn.parallel.DistributedDataParallel)
+                #             else self.model
+                #         ),
+                #         max_norm=10.0,
                 #     )
-                #     cast(tqdm, p_bar).set_description(desc)
-                #     p_bar.update(1)
+                #     scaler.step(optimizer)
+                #     scaler.update()
+                #     optimizer.zero_grad(set_to_none=True)

-            # p_bar.close()
+                if ema:
+                    # Update EMA from the underlying module
+                    ema.update(
+                        self.model.module
+                        if isinstance(self.model, nn.parallel.DistributedDataParallel)
+                        else self.model
+                    )
+                # print loss to test
+                print(
+                    f"loss: {total_loss.item() * accumulate:.4f}, box: {box_loss.item():.4f}, cls: {cls_loss.item():.4f}, dfl: {dfl_loss.item():.4f}"
+                )
+                torch.cuda.synchronize()

-        # clean
-        if args.distributed:
+        # ---- Final average loss (per image) over the whole epoch span ----
+        avg_loss_per_image = loss_box_meter.avg + loss_cls_meter.avg + loss_dfl_meter.avg
+
+        # ---- Cleanup DDP ----
+        if is_ddp:
             torch.distributed.destroy_process_group()
             torch.cuda.empty_cache()

-        return (
-            self.model.state_dict() if not ema else ema.ema.state_dict(),
-            self.n_data,
-            {"box_loss": avg_box_loss.avg, "cls_loss": avg_cls_loss.avg, "dfl_loss": avg_dfl_loss.avg},
-        )
+        # ---- Choose which weights to return ----
+        # - If EMA exists, return EMA weights (common YOLO eval practice)
+        # - Be careful with DDP: grab state_dict from the underlying module / EMA model
+        if ema:
+            # print("Using EMA weights")
+            return (ema.ema.state_dict(), self.n_data, avg_loss_per_image)
+        else:
+            # Safely get the underlying module if wrapped by DDP; getattr returns the module or the original object.
+            model_obj = getattr(self.model, "module", self.model)
+            # If it's a proper nn.Module, call state_dict(); if it's already a state dict, use it;
+            # otherwise try to call state_dict() and finally fall back to wrapping the object.
+            if isinstance(model_obj, torch.nn.Module):
+                model_to_return = model_obj.state_dict()
+            elif isinstance(model_obj, dict):
+                model_to_return = model_obj
+            else:
+                try:
+                    model_to_return = model_obj.state_dict()
+                except Exception:
+                    # fallback: if model_obj is a tensor or unexpected object, wrap it in a dict
+                    model_to_return = {"state": model_obj}
+            return model_to_return, self.n_data, avg_loss_per_image
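The end of `train()` now returns either the EMA weights or the weights of the (possibly DDP-wrapped) module. The essential pattern, as a standalone sketch (the `ema` object exposing an `.ema` module mirrors the `util.EMA` usage above and is an assumption here):

```python
import torch
from torch import nn


def weights_to_return(model: nn.Module, ema=None) -> dict[str, torch.Tensor]:
    """Prefer EMA weights when available, else the DDP-unwrapped module weights."""
    if ema is not None:
        return ema.ema.state_dict()            # EMA keeps a shadow copy of the model
    module = getattr(model, "module", model)   # unwrap DistributedDataParallel if present
    return module.state_dict()
```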
@@ -4,6 +4,7 @@ from torch.utils.data import DataLoader
 from utils.fed_util import init_model
 from utils.dataset import Dataset
 from utils import util
+from nets import YOLO


 class FedYoloServer(object):
@@ -21,7 +22,7 @@ class FedYoloServer(object):
         self.client_n_data = {}
         self.selected_clients = []

-        self._batch_size = params.get("val_batch_size", 4)
+        self._batch_size = params.get("val_batch_size", 200)
         self.client_list = client_list
         self.valset = None

@@ -40,7 +41,7 @@ class FedYoloServer(object):
         self.model = init_model(model_name, self._num_classes)
         self.params = params

-    def load_valset(self, valset):
+    def load_valset(self, valset: Dataset):
         """Server loads the validation dataset."""
         self.valset = valset

@@ -48,78 +49,6 @@ class FedYoloServer(object):
         """Return global model weights."""
         return self.model.state_dict()

-    @torch.no_grad()
-    def test(self, args) -> dict:
-        """
-        Test the global model on the server's validation set.
-        Returns:
-            dict with keys: mAP, mAP50, precision, recall
-        """
-        if self.valset is None:
-            return {}
-
-        loader = DataLoader(
-            self.valset,
-            batch_size=self._batch_size,
-            shuffle=False,
-            num_workers=4,
-            pin_memory=True,
-            collate_fn=Dataset.collate_fn,
-        )
-
-        dev = self._device
-        # move to device for eval; keep in float32 for stability
-        self.model.eval().to(dev).float()
-
-        iou_v = torch.linspace(0.5, 0.95, 10, device=dev)
-        n_iou = iou_v.numel()
-        metrics = []
-
-        for samples, targets in loader:
-            samples = samples.to(dev, non_blocking=True).float() / 255.0
-            _, _, h, w = samples.shape
-            scale = torch.tensor((w, h, w, h), device=dev)
-
-            outputs = self.model(samples)
-            outputs = util.non_max_suppression(outputs)
-
-            for i, output in enumerate(outputs):
-                idx = targets["idx"] == i
-                cls = targets["cls"][idx].to(dev)
-                box = targets["box"][idx].to(dev)
-
-                metric = torch.zeros((output.shape[0], n_iou), dtype=torch.bool, device=dev)
-                if output.shape[0] == 0:
-                    if cls.shape[0]:
-                        metrics.append((metric, *torch.zeros((2, 0), device=dev), cls.squeeze(-1)))
-                    continue
-
-                if cls.shape[0]:
-                    if cls.dim() == 1:
-                        cls = cls.unsqueeze(1)
-                    box_xy = util.wh2xy(box)
-                    if not isinstance(box_xy, torch.Tensor):
-                        box_xy = torch.tensor(box_xy, device=dev)
-                    target = torch.cat((cls, box_xy * scale), dim=1)
-                    metric = util.compute_metric(output[:, :6], target, iou_v)
-
-                metrics.append((metric, output[:, 4], output[:, 5], cls.squeeze(-1)))
-
-        if not metrics:
-            # move back to CPU before returning
-            self.model.to("cpu").float()
-            return {"mAP": 0, "mAP50": 0, "precision": 0, "recall": 0}
-
-        metrics = [torch.cat(x, dim=0).cpu().numpy() for x in zip(*metrics)]
-        if len(metrics) and metrics[0].any():
-            _, _, prec, rec, map50, mean_ap = util.compute_ap(*metrics, names=self.params["names"], plot=False)
-        else:
-            prec, rec, map50, mean_ap = 0, 0, 0, 0
-
-        # return model to CPU so next agg() stays device-consistent
-        self.model.to("cpu").float()
-        return {"mAP": float(mean_ap), "mAP50": float(map50), "precision": float(prec), "recall": float(rec)}
-
     def select_clients(self, connection_ratio=1.0):
         """
         Randomly select a fraction of clients.
@@ -130,80 +59,69 @@ class FedYoloServer(object):
         self.n_data = 0
         for client_id in self.client_list:
             # Random selection based on connection ratio
-            if np.random.rand() <= connection_ratio:
+            s = np.random.binomial(np.ones(1).astype(int), connection_ratio)
+            if s[0] == 1:
                 self.selected_clients.append(client_id)
-                self.n_data += self.client_n_data.get(client_id, 0)
+                self.n_data += self.client_n_data[client_id]

+    @torch.no_grad()
     def agg(self):
-        """Aggregate client updates (FedAvg) on CPU/FP32, preserving non-float buffers."""
+        """
+        Server aggregates the local updates from selected clients using FedAvg.
+
+        :return: model_state: aggregated model weights
+        :return: avg_loss: weighted average training loss across selected clients
+        :return: n_data: total number of data points across selected clients
+        """
         if len(self.selected_clients) == 0 or self.n_data == 0:
-            return self.model.state_dict(), {}, 0
+            import warnings

-        # Ensure global model is on CPU for safe load later
-        self.model.to("cpu")
-        global_state = self.model.state_dict()  # may hold CPU or CUDA refs; we’re on CPU now
+            warnings.warn("No clients selected or no data available for aggregation.")
+            return self.model.state_dict(), 0, 0

-        avg_loss = {}
-        total_n = float(self.n_data)
+        # Initialize a model for aggregation
+        model = init_model(model_name=self.model_name, num_classes=self._num_classes)
+        model_state = model.state_dict()

-        # Prepare accumulators on CPU. For floating tensors, use float32 zeros.
-        # For non-floating tensors (e.g., BN num_batches_tracked int64), we’ll copy from the first client.
-        new_state = {}
-        first_client = None
-        for cid in self.selected_clients:
-            if cid in self.client_state:
-                first_client = cid
-                break
+        avg_loss = 0

-        assert first_client is not None, "No client states available to aggregate."
-        for k, v in global_state.items():
-            if v.is_floating_point():
-                new_state[k] = torch.zeros_like(v.detach().cpu(), dtype=torch.float32)
-            else:
-                # For non-float buffers, just copy from the first client (or keep global)
-                new_state[k] = self.client_state[first_client][k].clone()
+        # Aggregate the local updated models from selected clients
+        for i, name in enumerate(self.selected_clients):
+            if name not in self.client_state:

-        # Accumulate floating tensors with weights; keep non-floats as assigned above
-        for cid in self.selected_clients:
-            if cid not in self.client_state:
                 continue
-            weight = self.client_n_data[cid] / total_n
-            cst = self.client_state[cid]
-            for k in new_state.keys():
-                if new_state[k].is_floating_point():
-                    # cst[k] is CPU; ensure float32 for accumulation
-                    new_state[k].add_(cst[k].to(torch.float32), alpha=weight)
-            # weighted average losses
-            for lk, lv in self.client_loss[cid].items():
-                avg_loss[lk] = avg_loss.get(lk, 0.0) + float(lv) * weight
+            for key in self.client_state[name]:
+                if i == 0:
+                    # First client, initialize the model_state
+                    model_state[key] = self.client_state[name][key] * (self.client_n_data[name] / self.n_data)
+                else:
+                    # math equation: w = sum(n_k / n * w_k)
+                    model_state[key] = model_state[key] + self.client_state[name][key] * (
+                        self.client_n_data[name] / self.n_data
+                    )
+            avg_loss = avg_loss + self.client_loss[name] * (self.client_n_data[name] / self.n_data)

-        # Load aggregated state back into the global model (model is on CPU)
-        with torch.no_grad():
-            self.model.load_state_dict(new_state, strict=True)

+        self.model.load_state_dict(model_state, strict=True)
         self.round += 1
-        # Return CPU state_dict (good for broadcasting to clients)
-        return {k: v.clone() for k, v in self.model.state_dict().items()}, avg_loss, int(self.n_data)

-    def rec(self, name, state_dict, n_data, loss_dict):
+        n_data = self.n_data

+        return model_state, avg_loss, n_data

+    def rec(self, name, state_dict, n_data, loss):
         """
         Receive local update from a client.
         - Store all floating tensors as CPU float32
         - Store non-floating tensors (e.g., BN counters) as CPU in original dtype
         """
         self.n_data += n_data
-        safe_state = {}
-        with torch.no_grad():
-            for k, v in state_dict.items():
-                t = v.detach().cpu()
-                if t.is_floating_point():
-                    t = t.to(torch.float32)
-                safe_state[k] = t
-        self.client_state[name] = safe_state
+        self.client_state[name] = {}
+        self.client_n_data[name] = {}
+        self.client_loss[name] = {}
+        self.client_state[name].update(state_dict)
         self.client_n_data[name] = int(n_data)
-        self.client_loss[name] = {k: float(v) for k, v in loss_dict.items()}
+        self.client_loss[name] = loss

     def flush(self):
         """Clear stored client updates."""
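The rewritten `agg()` above is plain FedAvg: every parameter becomes the data-weighted average of the client parameters, w = sum_k (n_k / n) * w_k, and the reported loss is weighted the same way. A minimal standalone sketch of that aggregation over state dicts (the client names and tensors below are made up for illustration):

```python
import torch


def fedavg(client_states: dict[str, dict[str, torch.Tensor]],
           client_n_data: dict[str, int]) -> dict[str, torch.Tensor]:
    """Data-weighted average of client state dicts: w = sum_k (n_k / n) * w_k."""
    total = sum(client_n_data.values())
    agg: dict[str, torch.Tensor] = {}
    for name, state in client_states.items():
        weight = client_n_data[name] / total
        for key, tensor in state.items():
            contrib = tensor.float() * weight
            agg[key] = contrib if key not in agg else agg[key] + contrib
    return agg


# Tiny usage example with fake one-parameter "models"
states = {"c1": {"w": torch.tensor([1.0])}, "c2": {"w": torch.tensor([3.0])}}
sizes = {"c1": 1, "c2": 3}
print(fedavg(states, sizes))  # {'w': tensor([2.5000])}
```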
@@ -211,3 +129,94 @@ class FedYoloServer(object):
         self.client_state.clear()
         self.client_n_data.clear()
         self.client_loss.clear()
+
+    def test(self):
+        """Evaluate the global model on the server's validation dataset."""
+        if self.valset is None:
+            import warnings
+
+            warnings.warn("No validation dataset available for testing.")
+            return {}
+        return test(self.valset, self.params, self.model)
+
+
+@torch.no_grad()
+def test(valset: Dataset, params, model: YOLO, batch_size: int = 200) -> tuple[float, float, float, float]:
+    """
+    Evaluate the model on the validation dataset.
+    Args:
+        valset: validation dataset
+        params: dict of parameters (must include 'names')
+        model: YOLO model to evaluate
+        batch_size: batch size for evaluation
+    Returns:
+        dict with evaluation metrics (tp, fp, m_pre, m_rec, map50, mean_ap)
+    """
+    loader = DataLoader(
+        dataset=valset,
+        batch_size=batch_size,
+        shuffle=False,
+        num_workers=4,
+        pin_memory=True,
+        collate_fn=Dataset.collate_fn,
+    )
+
+    model.cuda()
+    model.half()
+    model.eval()
+
+    # Configure
+    iou_v = torch.linspace(start=0.5, end=0.95, steps=10).cuda()  # iou vector for mAP@0.5:0.95
+    n_iou = iou_v.numel()
+
+    m_pre = 0
+    m_rec = 0
+    map50 = 0
+    mean_ap = 0
+    metrics = []
+
+    for samples, targets in loader:
+        samples = samples.cuda()
+        samples = samples.half()  # uint8 to fp16/32
+        samples = samples / 255.0  # 0 - 255 to 0.0 - 1.0
+        _, _, h, w = samples.shape  # batch-size, channels, height, width
+        scale = torch.tensor((w, h, w, h)).cuda()
+        # Inference
+        outputs = model(samples)
+        # NMS
+        outputs = util.non_max_suppression(outputs)
+        # Metrics
+        for i, output in enumerate(outputs):
+            idx = targets["idx"]
+            if idx.dim() > 1:
+                idx = idx.squeeze(-1)
+            idx = idx == i
+            # idx = targets["idx"] == i
+            cls = targets["cls"][idx]
+            box = targets["box"][idx]
+
+            cls = cls.cuda()
+            box = box.cuda()
+
+            metric = torch.zeros(output.shape[0], n_iou, dtype=torch.bool).cuda()
+
+            if output.shape[0] == 0:
+                if cls.shape[0]:
+                    metrics.append((metric, *torch.zeros((2, 0)).cuda(), cls.squeeze(-1)))
+                continue
+            # Evaluate
+            if cls.shape[0]:
+                target = torch.cat(tensors=(cls, util.wh2xy(box) * scale), dim=1)
+                metric = util.compute_metric(output[:, :6], target, iou_v)
+            # Append
+            metrics.append((metric, output[:, 4], output[:, 5], cls.squeeze(-1)))
+
+    # Compute metrics
+    metrics = [torch.cat(x, dim=0).cpu().numpy() for x in zip(*metrics)]  # to numpy
+    if len(metrics) and metrics[0].any():
+        tp, fp, m_pre, m_rec, map50, mean_ap = util.compute_ap(*metrics, plot=False, names=params["names"])
+    # Print results
+    # print(("%10s" + "%10.3g" * 4) % ("", m_pre, m_rec, map50, mean_ap))
+    # Return results
+    model.float()  # for training
+    return mean_ap, map50, m_rec, m_pre