From 40de29591b605951e8d67a632914e67f867fae8a Mon Sep 17 00:00:00 2001
From: TY1667
Date: Sun, 19 Oct 2025 21:27:19 +0800
Subject: [PATCH] Optimize FedYoloClient and FedYoloServer classes
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fed_algo_cs/client_base.py | 235 ++++++++++++++++++--------------
 fed_algo_cs/server_base.py | 267 +++++++++++++++++++------------------
 2 files changed, 270 insertions(+), 232 deletions(-)

diff --git a/fed_algo_cs/client_base.py b/fed_algo_cs/client_base.py
index d8777ff..7d86735 100644
--- a/fed_algo_cs/client_base.py
+++ b/fed_algo_cs/client_base.py
@@ -3,11 +3,11 @@ import torch
 from torch import nn
 from torch.utils import data
 from torch.amp.autocast_mode import autocast
-from tqdm import tqdm
 from utils.fed_util import init_model
 from utils import util
 from utils.dataset import Dataset
 from typing import cast
+from tqdm import tqdm
 
 
 class FedYoloClient(object):
@@ -82,52 +82,48 @@ class FedYoloClient(object):
         # load the global model parameters
         self.model.load_state_dict(Global_model_state_dict, strict=True)
 
-    def train(self, args):
+    def train(self, args) -> tuple[dict[str, torch.Tensor], int, float]:
         """
-        Train the local model
-        Args:
-            :param args: Command line arguments
-                - local_rank: Local rank for distributed training
-                - world_size: World size for distributed training
-                - distributed: Whether to use distributed training
-                - input_size: Input size for the model
-        Returns:
-            :return: Local updated model, number of local data points, training loss
+        Train the local model.
+        Returns: (state_dict, n_data, avg_loss_per_image)
         """
+        # ---- Dist init (if any) ----
         if args.distributed:
             torch.cuda.set_device(device=args.local_rank)
             torch.distributed.init_process_group(backend="nccl", init_method="env://")
-            # print(f"Client {self.name} - distributed training on {world_size} GPUs, local rank: {local_rank}")
-            # self._device = torch.device("cuda", local_rank)
-
-        if args.local_rank == 0:
-            pass
-            # if not os.path.exists("weights"):
-            #     os.makedirs("weights")
 
         util.setup_seed()
         util.setup_multi_processes()
 
-        # model
-        # init model have been done in __init__()
-        self.model.to(self._device)
+        # device = torch.device(f"cuda:{args.local_rank}" if torch.cuda.is_available() else "cpu")
+        # self.model.to(device)
+        self.model.cuda()
+        # show model architecture
+        # print(self.model)
 
-        # Optimizer
-        accumulate = max(round(64 / (self._batch_size * args.world_size)), 1)
-        self._weight_decay = self._batch_size * args.world_size * accumulate / 64
+        # ---- Optimizer / WD scaling & LR warmup/schedule ----
+        # accumulate = effective grad-accumulation steps to emulate global batch 64
+        world_size = getattr(args, "world_size", 1)
+        accumulate = max(round(64 / (self._batch_size * max(world_size, 1))), 1)
+        # scale weight_decay like YOLO recipes
+        scaled_wd = self._weight_decay * self._batch_size * max(world_size, 1) * accumulate / 64
         optimizer = torch.optim.SGD(
-            util.set_params(self.model, self._weight_decay),
+            util.set_params(self.model, scaled_wd),
             lr=self._min_lr,
             momentum=self._momentum,
             nesterov=True,
         )
 
-        # EMA
+        # ---- EMA (track the underlying module if DDP) ----
+        # track_model = self.model.module if is_ddp else self.model
         ema = util.EMA(self.model) if args.local_rank == 0 else None
 
-        data_set = Dataset(
+        print(type(self.train_dataset))
+
+        # ---- Data ----
+        dataset = Dataset(
             filenames=self.train_dataset,
             input_size=args.input_size,
             params=self.params,
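The `accumulate` / `scaled_wd` lines above implement the usual YOLO rule of emulating a global batch of 64 images per optimizer step. A minimal sketch of that arithmetic, using illustrative values (a local batch of 16 on one GPU; none of these numbers come from the patch):

```python
# Illustrative only: batch_size, world_size, and weight_decay are made-up values.
batch_size, world_size, weight_decay = 16, 1, 5e-4

# Accumulate micro-batches until ~64 images contribute to one optimizer step.
accumulate = max(round(64 / (batch_size * max(world_size, 1))), 1)  # -> 4

# Scale weight decay so its per-step effect matches a true batch of 64.
scaled_wd = weight_decay * batch_size * max(world_size, 1) * accumulate / 64  # -> 5e-4

# When batch_size * world_size * accumulate == 64, the decay is unchanged.
print(accumulate, scaled_wd)  # 4 0.0005
```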
@@ -136,26 +132,28 @@ class FedYoloClient(object):
 
         if args.distributed:
             train_sampler = data.DistributedSampler(
-                data_set, num_replicas=args.world_size, rank=args.local_rank, shuffle=True
+                dataset, num_replicas=args.world_size, rank=args.local_rank, shuffle=True
             )
         else:
             train_sampler = None
 
         loader = data.DataLoader(
-            data_set,
+            dataset,
             batch_size=self._batch_size,
-            shuffle=train_sampler is None,
+            shuffle=(train_sampler is None),
             sampler=train_sampler,
             num_workers=self.num_workers,
             pin_memory=True,
             collate_fn=Dataset.collate_fn,
+            drop_last=False,
         )
 
-        # Scheduler
         num_steps = max(1, len(loader))
         scheduler = util.LinearLR(args=args, params=self.params, num_steps=num_steps)
 
-        # DDP mode
-        if args.distributed:
+
+        # ---- SyncBN + DDP (if any) ----
+        is_ddp = bool(args.distributed)
+        if is_ddp:
             self.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.model)
             self.model = nn.parallel.DistributedDataParallel(
                 module=self.model,
@@ -164,102 +162,133 @@ class FedYoloClient(object):
                 find_unused_parameters=False,
             )
 
-        amp_scale = torch.amp.grad_scaler.GradScaler(enabled=True)
+        # ---- AMP + loss ----
+        scaler = torch.amp.grad_scaler.GradScaler(enabled=True)
+        # criterion = util.ComputeLoss(
+        #     self.model.module if isinstance(self.model, nn.parallel.DistributedDataParallel) else self.model,
+        #     self.params,
+        # )
         criterion = util.ComputeLoss(self.model, self.params)
 
-        # log
-        # if args.local_rank == 0:
-        #     header = ("%10s" * 5) % ("client", "memory", "box", "cls", "dfl")
-        #     print("\n" + header)
-        #     p_bar = tqdm(total=args.epochs * num_steps, ncols=120)
-        #     p_bar.set_description(f"{self.name:>10}")
-
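For reference, the commented-out block in the training loop below sketches the more conventional AMP-plus-accumulation stepping rule: step on every `accumulate`-th micro-batch and once more at the end, with unscale and clip before the step. A self-contained toy version of that pattern, with `torch.nn.utils.clip_grad_norm_` standing in for the project's `util.clip_gradients`:

```python
import torch
from torch import nn

model = nn.Linear(10, 1)  # toy stand-ins, illustrative only
opt = torch.optim.SGD(model.parameters(), lr=0.01)
use_cuda = torch.cuda.is_available()
scaler = torch.amp.grad_scaler.GradScaler(enabled=use_cuda)
accumulate = 4
batches = [(torch.randn(8, 10), torch.randn(8, 1)) for _ in range(10)]

for i, (x, y) in enumerate(batches):
    with torch.autocast(device_type="cuda" if use_cuda else "cpu", enabled=use_cuda):
        loss = nn.functional.mse_loss(model(x), y) / accumulate  # normalize per micro-batch
    scaler.scale(loss).backward()
    # Step on every 'accumulate'-th micro-batch, and once at the end so the
    # final partial accumulation is not dropped.
    if (i + 1) % accumulate == 0 or (i + 1) == len(batches):
        scaler.unscale_(opt)  # unscale before clipping
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=10.0)
        scaler.step(opt)
        scaler.update()
        opt.zero_grad(set_to_none=True)
```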
+        # ---- Training ----
         for epoch in range(args.epochs):
+            # (self.model.module if isinstance(self.model, nn.parallel.DistributedDataParallel) else self.model).train()
             self.model.train()
-            # when distributed, set epoch for shuffling
-            if args.distributed and train_sampler is not None:
+            if is_ddp and train_sampler is not None:
                 train_sampler.set_epoch(epoch)
 
-            if args.epochs - epoch == 10:
-                # disable mosaic augmentation in the last 10 epochs
+            # disable mosaic in the last 10 epochs (if dataset supports it)
+            if args.epochs - epoch == 10 and hasattr(loader.dataset, "mosaic"):
                 ds = cast(Dataset, loader.dataset)
                 ds.mosaic = False
 
             optimizer.zero_grad(set_to_none=True)
 
-            avg_box_loss = util.AverageMeter()
-            avg_cls_loss = util.AverageMeter()
-            avg_dfl_loss = util.AverageMeter()
+            loss_box_meter = util.AverageMeter()
+            loss_cls_meter = util.AverageMeter()
+            loss_dfl_meter = util.AverageMeter()
 
-            # # --- header (once per epoch, YOLO-style) ---
-            # if args.local_rank == 0:
-            #     header = ("%10s" * 5) % ("client", "memory", "box", "cls", "dfl")
-            #     print("\n" + header)
+            for i, (images, targets) in enumerate(loader):
+                print(f"Client {self.name} - Epoch {epoch + 1}/{args.epochs} - Step {i + 1}/{num_steps}")
+                step = i + epoch * num_steps
 
-            # p_bar = enumerate(loader)
-            # if args.local_rank == 0:
-            #     p_bar = tqdm(p_bar, total=num_steps, ncols=120)
+                # scheduler per-step (your util.LinearLR expects step)
+                scheduler.step(step=step, optimizer=optimizer)
 
-            for i, (samples, targets) in enumerate(loader):
-                global_step = i + num_steps * epoch
-                scheduler.step(step=global_step, optimizer=optimizer)
+                # images = images.to(device, non_blocking=True).float() / 255.0
+                images = images.cuda().float() / 255.0
+                bs = images.size(0)
+                # total_imgs_seen += bs
 
-                samples = samples.cuda(non_blocking=True).float() / 255.0
+                # targets: keep as your ComputeLoss expects (often CPU lists/tensors).
+                # Move to GPU here only if your loss requires it.
 
-                # Forward
-                with autocast("cuda", enabled=True):
-                    outputs = self.model(samples)
+                with autocast(device_type="cuda", enabled=True):
+                    outputs = self.model(images)  # DDP wraps forward
                     box_loss, cls_loss, dfl_loss = criterion(outputs, targets)
 
-                # meters (use the *unscaled* values)
-                bs = samples.size(0)
-                avg_box_loss.update(box_loss.item(), bs)
-                avg_cls_loss.update(cls_loss.item(), bs)
-                avg_dfl_loss.update(dfl_loss.item(), bs)
+                # total_loss = box_loss + cls_loss + dfl_loss
+                # Gradient accumulation: normalize by 'accumulate' so LR stays effective
+                # total_loss = total_loss / accumulate
 
-                # scale losses by batch/world if your loss is averaged internally per-sample/device
-                # box_loss = box_loss * self._batch_size * args.world_size
-                # cls_loss = cls_loss * self._batch_size * args.world_size
-                # dfl_loss = dfl_loss * self._batch_size * args.world_size
+                # IMPORTANT: assume criterion returns **average per image** in the batch.
+                # Keep logging on the true (unscaled) values:
+                loss_box_meter.update(box_loss.item(), bs)
+                loss_cls_meter.update(cls_loss.item(), bs)
+                loss_dfl_meter.update(dfl_loss.item(), bs)
 
+                box_loss *= self._batch_size
+                cls_loss *= self._batch_size
+                dfl_loss *= self._batch_size
+                box_loss *= args.world_size
+                cls_loss *= args.world_size
+                dfl_loss *= args.world_size
                 total_loss = box_loss + cls_loss + dfl_loss
 
-                # Backward
-                amp_scale.scale(total_loss).backward()
+                scaler.scale(total_loss).backward()
 
-                # Optimize
-                if (i + 1) % accumulate == 0:
-                    amp_scale.unscale_(optimizer)  # unscale gradients
-                    util.clip_gradients(model=self.model, max_norm=10.0)  # clip gradients
-                    amp_scale.step(optimizer)
-                    amp_scale.update()
+                # optimize
+                if step % accumulate == 0:
+                    # scaler.unscale_(optimizer)
+                    # util.clip_gradients(self.model)
+                    scaler.step(optimizer)
+                    scaler.update()
                     optimizer.zero_grad(set_to_none=True)
+
+                # # Step when we have 'accumulate' micro-batches, or at the end
+                # if ((i + 1) % accumulate == 0) or (i + 1 == len(loader)):
+                #     scaler.unscale_(optimizer)
+                #     util.clip_gradients(
+                #         model=(
+                #             self.model.module
+                #             if isinstance(self.model, nn.parallel.DistributedDataParallel)
+                #             else self.model
+                #         ),
+                #         max_norm=10.0,
+                #     )
+                #     scaler.step(optimizer)
+                #     scaler.update()
+                #     optimizer.zero_grad(set_to_none=True)
+
                 if ema:
-                    ema.update(self.model)
+                    # Update EMA from the underlying module
+                    ema.update(
+                        self.model.module
+                        if isinstance(self.model, nn.parallel.DistributedDataParallel)
+                        else self.model
+                    )
 
+                # print loss to test
+                print(
+                    f"loss: {total_loss.item() * accumulate:.4f}, box: {box_loss.item():.4f}, cls: {cls_loss.item():.4f}, dfl: {dfl_loss.item():.4f}"
+                )
+                torch.cuda.synchronize()
 
-                # torch.cuda.synchronize()
+        # ---- Final average loss (per image) over the whole epoch span ----
+        avg_loss_per_image = loss_box_meter.avg + loss_cls_meter.avg + loss_dfl_meter.avg
 
-                # tqdm update
-                # if args.local_rank == 0:
-                #     mem = f"{torch.cuda.memory_reserved() / 1e9:.2f}G" if torch.cuda.is_available() else "0.00G"
-                #     desc = ("%10s" * 2 + "%10.4g" * 3) % (
-                #         self.name,
-                #         mem,
-                #         avg_box_loss.avg,
-                #         avg_cls_loss.avg,
-                #         avg_dfl_loss.avg,
-                #     )
-                #     cast(tqdm, p_bar).set_description(desc)
-                #     p_bar.update(1)
-
-            # p_bar.close()
-
-        # clean
-        if args.distributed:
+        # ---- Cleanup DDP ----
+        if is_ddp:
             torch.distributed.destroy_process_group()
             torch.cuda.empty_cache()
 
-        return (
-            self.model.state_dict() if not ema else ema.ema.state_dict(),
-            self.n_data,
-            {"box_loss": avg_box_loss.avg, "cls_loss": avg_cls_loss.avg, "dfl_loss": avg_dfl_loss.avg},
-        )
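`util.AverageMeter` is project code; assuming it keeps a count-weighted running mean, the `avg_loss_per_image` returned below is just the sum of three such means. A minimal equivalent with illustrative numbers:

```python
class AverageMeter:
    """Count-weighted running mean (minimal stand-in for util.AverageMeter)."""

    def __init__(self):
        self.sum, self.count, self.avg = 0.0, 0, 0.0

    def update(self, value, n=1):
        self.sum += value * n  # 'value' is a per-image mean over a batch of n images
        self.count += n
        self.avg = self.sum / max(self.count, 1)

meter = AverageMeter()
meter.update(1.5, n=8)  # batch of 8 images, mean box loss 1.5
meter.update(1.1, n=4)  # smaller last batch carries less weight
print(meter.avg)        # (1.5*8 + 1.1*4) / 12 = 1.3666...
```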
+
+        # ---- Choose which weights to return ----
+        # - If EMA exists, return EMA weights (common YOLO eval practice)
+        # - Be careful with DDP: grab state_dict from the underlying module / EMA model
+        if ema:
+            # print("Using EMA weights")
+            return (ema.ema.state_dict(), self.n_data, avg_loss_per_image)
+        else:
+            # Safely get the underlying module if wrapped by DDP; getattr returns the module or the original object.
+            model_obj = getattr(self.model, "module", self.model)
+            # If it's a proper nn.Module, call state_dict(); if it's already a state dict, use it;
+            # otherwise try to call state_dict() and finally fall back to wrapping the object.
+            if isinstance(model_obj, torch.nn.Module):
+                model_to_return = model_obj.state_dict()
+            elif isinstance(model_obj, dict):
+                model_to_return = model_obj
+            else:
+                try:
+                    model_to_return = model_obj.state_dict()
+                except Exception:
+                    # fallback: if model_obj is a tensor or unexpected object, wrap it in a dict
+                    model_to_return = {"state": model_obj}
+            return model_to_return, self.n_data, avg_loss_per_image
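Downstream, the server only needs this tuple shape to be stable. An illustrative consumer-side check of the new return contract (`client`, `server`, `args`, and the `torch` import are assumed to exist in the caller):

```python
# Illustrative: sanity-check what FedYoloClient.train() hands back before
# forwarding it to the server. All names here are assumed to exist.
state_dict, n_data, avg_loss = client.train(args)

assert isinstance(state_dict, dict) and len(state_dict) > 0
assert all(isinstance(v, torch.Tensor) for v in state_dict.values())
assert isinstance(n_data, int) and n_data > 0
assert isinstance(avg_loss, float)  # a scalar now, not a per-term loss dict

server.rec(client.name, state_dict, n_data, avg_loss)
```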
diff --git a/fed_algo_cs/server_base.py b/fed_algo_cs/server_base.py
index b59dc43..ad213a1 100644
--- a/fed_algo_cs/server_base.py
+++ b/fed_algo_cs/server_base.py
@@ -4,6 +4,7 @@ from torch.utils.data import DataLoader
 from utils.fed_util import init_model
 from utils.dataset import Dataset
 from utils import util
+from nets import YOLO
 
 
 class FedYoloServer(object):
@@ -21,7 +22,7 @@ class FedYoloServer(object):
         self.client_n_data = {}
         self.selected_clients = []
 
-        self._batch_size = params.get("val_batch_size", 4)
+        self._batch_size = params.get("val_batch_size", 200)
 
         self.client_list = client_list
         self.valset = None
@@ -40,7 +41,7 @@ class FedYoloServer(object):
         self.model = init_model(model_name, self._num_classes)
         self.params = params
 
-    def load_valset(self, valset):
+    def load_valset(self, valset: Dataset):
         """Server loads the validation dataset."""
         self.valset = valset
 
@@ -48,78 +49,6 @@ class FedYoloServer(object):
         """Return global model weights."""
         return self.model.state_dict()
 
-    @torch.no_grad()
-    def test(self, args) -> dict:
-        """
-        Test the global model on the server's validation set.
-        Returns:
-            dict with keys: mAP, mAP50, precision, recall
-        """
-        if self.valset is None:
-            return {}
-
-        loader = DataLoader(
-            self.valset,
-            batch_size=self._batch_size,
-            shuffle=False,
-            num_workers=4,
-            pin_memory=True,
-            collate_fn=Dataset.collate_fn,
-        )
-
-        dev = self._device
-        # move to device for eval; keep in float32 for stability
-        self.model.eval().to(dev).float()
-
-        iou_v = torch.linspace(0.5, 0.95, 10, device=dev)
-        n_iou = iou_v.numel()
-        metrics = []
-
-        for samples, targets in loader:
-            samples = samples.to(dev, non_blocking=True).float() / 255.0
-            _, _, h, w = samples.shape
-            scale = torch.tensor((w, h, w, h), device=dev)
-
-            outputs = self.model(samples)
-            outputs = util.non_max_suppression(outputs)
-
-            for i, output in enumerate(outputs):
-                idx = targets["idx"] == i
-                cls = targets["cls"][idx].to(dev)
-                box = targets["box"][idx].to(dev)
-
-                metric = torch.zeros((output.shape[0], n_iou), dtype=torch.bool, device=dev)
-                if output.shape[0] == 0:
-                    if cls.shape[0]:
-                        metrics.append((metric, *torch.zeros((2, 0), device=dev), cls.squeeze(-1)))
-                    continue
-
-                if cls.shape[0]:
-                    if cls.dim() == 1:
-                        cls = cls.unsqueeze(1)
-                    box_xy = util.wh2xy(box)
-                    if not isinstance(box_xy, torch.Tensor):
-                        box_xy = torch.tensor(box_xy, device=dev)
-                    target = torch.cat((cls, box_xy * scale), dim=1)
-                    metric = util.compute_metric(output[:, :6], target, iou_v)
-
-                metrics.append((metric, output[:, 4], output[:, 5], cls.squeeze(-1)))
-
-        if not metrics:
-            # move back to CPU before returning
-            self.model.to("cpu").float()
-            return {"mAP": 0, "mAP50": 0, "precision": 0, "recall": 0}
-
-        metrics = [torch.cat(x, dim=0).cpu().numpy() for x in zip(*metrics)]
-        if len(metrics) and metrics[0].any():
-            _, _, prec, rec, map50, mean_ap = util.compute_ap(*metrics, names=self.params["names"], plot=False)
-        else:
-            prec, rec, map50, mean_ap = 0, 0, 0, 0
-
-        # return model to CPU so next agg() stays device-consistent
-        self.model.to("cpu").float()
-        return {"mAP": float(mean_ap), "mAP50": float(map50), "precision": float(prec), "recall": float(rec)}
-
     def select_clients(self, connection_ratio=1.0):
         """
         Randomly select a fraction of clients.
@@ -130,80 +59,69 @@ class FedYoloServer(object):
         self.n_data = 0
         for client_id in self.client_list:
             # Random selection based on connection ratio
-            if np.random.rand() <= connection_ratio:
+            s = np.random.binomial(np.ones(1).astype(int), connection_ratio)
+            if s[0] == 1:
                 self.selected_clients.append(client_id)
-                self.n_data += self.client_n_data.get(client_id, 0)
+                self.n_data += self.client_n_data[client_id]
 
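The rewritten selection draw is still a Bernoulli trial with success probability `connection_ratio`; `np.random.binomial(1, p)` and `np.random.rand() <= p` pick clients at the same rate. A quick sketch (simplified to a scalar draw rather than the array form used above):

```python
import numpy as np

p, trials = 0.5, 100_000
old_style = sum(np.random.rand() <= p for _ in range(trials))
new_style = sum(np.random.binomial(1, p) for _ in range(trials))
print(old_style / trials, new_style / trials)  # both hover around 0.5
```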
+    @torch.no_grad()
     def agg(self):
-        """Aggregate client updates (FedAvg) on CPU/FP32, preserving non-float buffers."""
+        """
+        Server aggregates the local updates from selected clients using FedAvg.
+
+        :return: model_state: aggregated model weights
+        :return: avg_loss: weighted average training loss across selected clients
+        :return: n_data: total number of data points across selected clients
+        """
         if len(self.selected_clients) == 0 or self.n_data == 0:
-            return self.model.state_dict(), {}, 0
+            import warnings
 
-        # Ensure global model is on CPU for safe load later
-        self.model.to("cpu")
-        global_state = self.model.state_dict()  # may hold CPU or CUDA refs; we're on CPU now
+            warnings.warn("No clients selected or no data available for aggregation.")
+            return self.model.state_dict(), 0, 0
 
-        avg_loss = {}
-        total_n = float(self.n_data)
+        # Initialize a model for aggregation
+        model = init_model(model_name=self.model_name, num_classes=self._num_classes)
+        model_state = model.state_dict()
 
-        # Prepare accumulators on CPU. For floating tensors, use float32 zeros.
-        # For non-floating tensors (e.g., BN num_batches_tracked int64), we'll copy from the first client.
-        new_state = {}
-        first_client = None
-        for cid in self.selected_clients:
-            if cid in self.client_state:
-                first_client = cid
-                break
+        avg_loss = 0
 
-        assert first_client is not None, "No client states available to aggregate."
-
-        for k, v in global_state.items():
-            if v.is_floating_point():
-                new_state[k] = torch.zeros_like(v.detach().cpu(), dtype=torch.float32)
-            else:
-                # For non-float buffers, just copy from the first client (or keep global)
-                new_state[k] = self.client_state[first_client][k].clone()
-
-        # Accumulate floating tensors with weights; keep non-floats as assigned above
-        for cid in self.selected_clients:
-            if cid not in self.client_state:
+        # Aggregate the local updated models from selected clients
+        for i, name in enumerate(self.selected_clients):
+            if name not in self.client_state:
                 continue
-            weight = self.client_n_data[cid] / total_n
-            cst = self.client_state[cid]
-            for k in new_state.keys():
-                if new_state[k].is_floating_point():
-                    # cst[k] is CPU; ensure float32 for accumulation
-                    new_state[k].add_(cst[k].to(torch.float32), alpha=weight)
-
-            # weighted average losses
-            for lk, lv in self.client_loss[cid].items():
-                avg_loss[lk] = avg_loss.get(lk, 0.0) + float(lv) * weight
-
-        # Load aggregated state back into the global model (model is on CPU)
-        with torch.no_grad():
-            self.model.load_state_dict(new_state, strict=True)
+            for key in self.client_state[name]:
+                if i == 0:
+                    # First client, initialize the model_state
+                    model_state[key] = self.client_state[name][key] * (self.client_n_data[name] / self.n_data)
+                else:
+                    # math equation: w = sum(n_k / n * w_k)
+                    model_state[key] = model_state[key] + self.client_state[name][key] * (
+                        self.client_n_data[name] / self.n_data
+                    )
+            avg_loss = avg_loss + self.client_loss[name] * (self.client_n_data[name] / self.n_data)
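One behavior the deleted implementation handled explicitly and the new loop does not: non-floating buffers such as BatchNorm's `num_batches_tracked` are promoted to float by `tensor * (n_k / n)` and only cast back when `load_state_dict` copies them. A dtype-aware variant of the weighted sum, shown here only as a sketch rather than part of the patch:

```python
import torch

def fedavg(client_states, client_sizes):
    """Weighted FedAvg that averages float tensors and keeps integer buffers."""
    total = float(sum(client_sizes.values()))
    first = next(iter(client_states.values()))
    out = {}
    for key, ref in first.items():
        if ref.is_floating_point():
            out[key] = sum(
                client_states[c][key].to(torch.float32) * (client_sizes[c] / total)
                for c in client_states
            )
        else:
            out[key] = ref.clone()  # e.g., BN num_batches_tracked: copy, don't average
    return out
```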
+        self.model.load_state_dict(model_state, strict=True)
         self.round += 1
 
-        # Return CPU state_dict (good for broadcasting to clients)
-        return {k: v.clone() for k, v in self.model.state_dict().items()}, avg_loss, int(self.n_data)
-
-    def rec(self, name, state_dict, n_data, loss_dict):
+        n_data = self.n_data
+
+        return model_state, avg_loss, n_data
+
+    def rec(self, name, state_dict, n_data, loss):
         """
         Receive local update from a client.
-        - Store all floating tensors as CPU float32
-        - Store non-floating tensors (e.g., BN counters) as CPU in original dtype
         """
         self.n_data += n_data
-        safe_state = {}
-        with torch.no_grad():
-            for k, v in state_dict.items():
-                t = v.detach().cpu()
-                if t.is_floating_point():
-                    t = t.to(torch.float32)
-                safe_state[k] = t
-        self.client_state[name] = safe_state
+
+        self.client_state[name] = {}
+        self.client_n_data[name] = {}
+        self.client_loss[name] = {}
+
+        self.client_state[name].update(state_dict)
         self.client_n_data[name] = int(n_data)
-        self.client_loss[name] = {k: float(v) for k, v in loss_dict.items()}
+        self.client_loss[name] = loss
 
     def flush(self):
         """Clear stored client updates."""
@@ -211,3 +129,94 @@ class FedYoloServer(object):
         self.client_state.clear()
         self.client_n_data.clear()
         self.client_loss.clear()
+
+    def test(self):
+        """Evaluate the global model on the server's validation dataset."""
+        if self.valset is None:
+            import warnings
+
+            warnings.warn("No validation dataset available for testing.")
+            return {}
+        return test(self.valset, self.params, self.model)
+
+
+@torch.no_grad()
+def test(valset: Dataset, params, model: YOLO, batch_size: int = 200) -> tuple[float, float, float, float]:
+    """
+    Evaluate the model on the validation dataset.
+    Args:
+        valset: validation dataset
+        params: dict of parameters (must include 'names')
+        model: YOLO model to evaluate
+        batch_size: batch size for evaluation
+    Returns:
+        tuple of evaluation metrics (mean_ap, map50, m_rec, m_pre)
+    """
+    loader = DataLoader(
+        dataset=valset,
+        batch_size=batch_size,
+        shuffle=False,
+        num_workers=4,
+        pin_memory=True,
+        collate_fn=Dataset.collate_fn,
+    )
+
+    model.cuda()
+    model.half()
+    model.eval()
+
+    # Configure
+    iou_v = torch.linspace(start=0.5, end=0.95, steps=10).cuda()  # iou vector for mAP@0.5:0.95
+    n_iou = iou_v.numel()
+
+    m_pre = 0
+    m_rec = 0
+    map50 = 0
+    mean_ap = 0
+    metrics = []
+
+    for samples, targets in loader:
+        samples = samples.cuda()
+        samples = samples.half()  # uint8 to fp16/32
+        samples = samples / 255.0  # 0 - 255 to 0.0 - 1.0
+        _, _, h, w = samples.shape  # batch-size, channels, height, width
+        scale = torch.tensor((w, h, w, h)).cuda()
+        # Inference
+        outputs = model(samples)
+        # NMS
+        outputs = util.non_max_suppression(outputs)
+        # Metrics
+        for i, output in enumerate(outputs):
+            idx = targets["idx"]
+            if idx.dim() > 1:
+                idx = idx.squeeze(-1)
+            idx = idx == i
+            # idx = targets["idx"] == i
+            cls = targets["cls"][idx]
+            box = targets["box"][idx]
+
+            cls = cls.cuda()
+            box = box.cuda()
+
+            metric = torch.zeros(output.shape[0], n_iou, dtype=torch.bool).cuda()
+
+            if output.shape[0] == 0:
+                if cls.shape[0]:
+                    metrics.append((metric, *torch.zeros((2, 0)).cuda(), cls.squeeze(-1)))
+                continue
+            # Evaluate
+            if cls.shape[0]:
+                target = torch.cat(tensors=(cls, util.wh2xy(box) * scale), dim=1)
+                metric = util.compute_metric(output[:, :6], target, iou_v)
+            # Append
+            metrics.append((metric, output[:, 4], output[:, 5], cls.squeeze(-1)))
+
+    # Compute metrics
+    metrics = [torch.cat(x, dim=0).cpu().numpy() for x in zip(*metrics)]  # to numpy
+    if len(metrics) and metrics[0].any():
+        tp, fp, m_pre, m_rec, map50, mean_ap = util.compute_ap(*metrics, plot=False, names=params["names"])
+    # Print results
+    # print(("%10s" + "%10.3g" * 4) % ("", m_pre, m_rec, map50, mean_ap))
+    # Return results
+    model.float()  # for training
+    return mean_ap, map50, m_rec, m_pre
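Putting both classes together, one federated round under the new interfaces looks roughly like the following. The `clients` list, `server` instance, `args`, `rounds`, and the `load_global` method name are assumptions inferred from attributes visible in this patch, not confirmed signatures:

```python
# Hypothetical wiring of a single round with the patched classes.
for rnd in range(rounds):
    global_state = server.model.state_dict()
    for client in clients:
        client.load_global(global_state)  # hypothetical name for the loader at the top of client_base.py
        state_dict, n_data, avg_loss = client.train(args)
        server.rec(client.name, state_dict, n_data, avg_loss)

    # select_clients() now indexes self.client_n_data directly (no .get fallback),
    # so it must run only after every participating client has called rec().
    server.select_clients(connection_ratio=1.0)
    model_state, avg_loss, n_data = server.agg()
    mean_ap, map50, m_rec, m_pre = server.test()
    server.flush()  # drop cached client updates before the next round
```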