diff --git a/fed_algo_cs/server_base.py b/fed_algo_cs/server_base.py index 5043675..b59dc43 100644 --- a/fed_algo_cs/server_base.py +++ b/fed_algo_cs/server_base.py @@ -49,11 +49,11 @@ class FedYoloServer(object): return self.model.state_dict() @torch.no_grad() - def test(self, args): + def test(self, args) -> dict: """ - Evaluate global model on validation set using YOLO metrics (mAP, precision, recall). + Test the global model on the server's validation set. Returns: - dict with {"mAP": ..., "mAP50": ..., "precision": ..., "recall": ...} + dict with keys: mAP, mAP50, precision, recall """ if self.valset is None: return {} @@ -67,46 +67,47 @@ class FedYoloServer(object): collate_fn=Dataset.collate_fn, ) - self.model.to(self._device).eval().half() + dev = self._device + # move to device for eval; keep in float32 for stability + self.model.eval().to(dev).float() - iou_v = torch.linspace(0.5, 0.95, 10).to(self._device) # IoU thresholds + iou_v = torch.linspace(0.5, 0.95, 10, device=dev) n_iou = iou_v.numel() metrics = [] for samples, targets in loader: - samples = samples.to(self._device).half() / 255.0 + samples = samples.to(dev, non_blocking=True).float() / 255.0 _, _, h, w = samples.shape - scale = torch.tensor((w, h, w, h)).to(self._device) + scale = torch.tensor((w, h, w, h), device=dev) outputs = self.model(samples) outputs = util.non_max_suppression(outputs) for i, output in enumerate(outputs): idx = targets["idx"] == i - cls = targets["cls"][idx].to(self._device) - box = targets["box"][idx].to(self._device) - - metric = torch.zeros((output.shape[0], n_iou), dtype=torch.bool, device=self._device) + cls = targets["cls"][idx].to(dev) + box = targets["box"][idx].to(dev) + metric = torch.zeros((output.shape[0], n_iou), dtype=torch.bool, device=dev) if output.shape[0] == 0: if cls.shape[0]: - metrics.append((metric, *torch.zeros((2, 0), device=self._device), cls.squeeze(-1))) + metrics.append((metric, *torch.zeros((2, 0), device=dev), cls.squeeze(-1))) continue if cls.shape[0]: - cls_tensor = cls if isinstance(cls, torch.Tensor) else torch.tensor(cls, device=self._device) - if cls_tensor.dim() == 1: - cls_tensor = cls_tensor.unsqueeze(1) + if cls.dim() == 1: + cls = cls.unsqueeze(1) box_xy = util.wh2xy(box) if not isinstance(box_xy, torch.Tensor): - box_xy = torch.tensor(box_xy, device=self._device) - target = torch.cat((cls_tensor, box_xy * scale), dim=1) + box_xy = torch.tensor(box_xy, device=dev) + target = torch.cat((cls, box_xy * scale), dim=1) metric = util.compute_metric(output[:, :6], target, iou_v) metrics.append((metric, output[:, 4], output[:, 5], cls.squeeze(-1))) - # Compute metrics if not metrics: + # move back to CPU before returning + self.model.to("cpu").float() return {"mAP": 0, "mAP50": 0, "precision": 0, "recall": 0} metrics = [torch.cat(x, dim=0).cpu().numpy() for x in zip(*metrics)] @@ -115,9 +116,8 @@ class FedYoloServer(object): else: prec, rec, map50, mean_ap = 0, 0, 0, 0 - # Back to float32 for further training - self.model.float() - + # return model to CPU so next agg() stays device-consistent + self.model.to("cpu").float() return {"mAP": float(mean_ap), "mAP50": float(map50), "precision": float(prec), "recall": float(rec)} def select_clients(self, connection_ratio=1.0): @@ -135,53 +135,75 @@ class FedYoloServer(object): self.n_data += self.client_n_data.get(client_id, 0) def agg(self): - """ - Aggregate client updates (FedAvg). - Returns: - global_state: aggregated model state dictionary - avg_loss: dict of averaged losses - n_data: total number of data classes samples used in this round - """ + """Aggregate client updates (FedAvg) on CPU/FP32, preserving non-float buffers.""" if len(self.selected_clients) == 0 or self.n_data == 0: return self.model.state_dict(), {}, 0 - # start from current global model - global_state = self.model.state_dict() - - # zero buffer for accumulation - new_state = {k: torch.zeros_like(v, dtype=torch.float32) for k, v in global_state.items()} + # Ensure global model is on CPU for safe load later + self.model.to("cpu") + global_state = self.model.state_dict() # may hold CPU or CUDA refs; we’re on CPU now avg_loss = {} - for name in self.selected_clients: - if name not in self.client_state: + total_n = float(self.n_data) + + # Prepare accumulators on CPU. For floating tensors, use float32 zeros. + # For non-floating tensors (e.g., BN num_batches_tracked int64), we’ll copy from the first client. + new_state = {} + first_client = None + for cid in self.selected_clients: + if cid in self.client_state: + first_client = cid + break + + assert first_client is not None, "No client states available to aggregate." + + for k, v in global_state.items(): + if v.is_floating_point(): + new_state[k] = torch.zeros_like(v.detach().cpu(), dtype=torch.float32) + else: + # For non-float buffers, just copy from the first client (or keep global) + new_state[k] = self.client_state[first_client][k].clone() + + # Accumulate floating tensors with weights; keep non-floats as assigned above + for cid in self.selected_clients: + if cid not in self.client_state: continue - weight = self.client_n_data[name] / self.n_data + weight = self.client_n_data[cid] / total_n + cst = self.client_state[cid] for k in new_state.keys(): - # accumulate in float32 to avoid fp16 issues - new_state[k] += self.client_state[name][k].to(torch.float32) * weight + if new_state[k].is_floating_point(): + # cst[k] is CPU; ensure float32 for accumulation + new_state[k].add_(cst[k].to(torch.float32), alpha=weight) - # losses - for k, v in self.client_loss[name].items(): - avg_loss[k] = avg_loss.get(k, 0.0) + v * weight + # weighted average losses + for lk, lv in self.client_loss[cid].items(): + avg_loss[lk] = avg_loss.get(lk, 0.0) + float(lv) * weight + + # Load aggregated state back into the global model (model is on CPU) + with torch.no_grad(): + self.model.load_state_dict(new_state, strict=True) - # load aggregated params back into global model - self.model.load_state_dict(new_state, strict=True) self.round += 1 - return self.model.state_dict(), avg_loss, self.n_data + # Return CPU state_dict (good for broadcasting to clients) + return {k: v.clone() for k, v in self.model.state_dict().items()}, avg_loss, int(self.n_data) def rec(self, name, state_dict, n_data, loss_dict): """ Receive local update from a client. - Args: - name: client ID - state_dict: state dictionary of the local model - n_data: number of data samples used in local training - loss_dict: dict of losses from local training + - Store all floating tensors as CPU float32 + - Store non-floating tensors (e.g., BN counters) as CPU in original dtype """ self.n_data += n_data - self.client_state[name] = {k: v.cpu() for k, v in state_dict.items()} - self.client_n_data[name] = n_data - self.client_loss[name] = loss_dict + safe_state = {} + with torch.no_grad(): + for k, v in state_dict.items(): + t = v.detach().cpu() + if t.is_floating_point(): + t = t.to(torch.float32) + safe_state[k] = t + self.client_state[name] = safe_state + self.client_n_data[name] = int(n_data) + self.client_loss[name] = {k: float(v) for k, v in loss_dict.items()} def flush(self): """Clear stored client updates."""