Files
fed-yolo/fed_algo_cs/server_base.py

214 lines
7.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import numpy as np
import torch
from torch.utils.data import DataLoader
from utils.fed_util import init_model
from utils.dataset import Dataset
from utils import util
class FedYoloServer(object):
    """Central server for federated YOLO training.

    Holds the global model, collects client updates via ``rec``, aggregates
    them with sample-count-weighted FedAvg in ``agg``, and evaluates the
    global model on a server-side validation set in ``test``.
    """

    def __init__(self, client_list, model_name, params):
        """
        Federated YOLO Server
        Args:
            client_list: list of connected clients
            model_name: YOLO model architecture name
            params: dict of hyperparameters (must include 'names'; optional
                'val_batch_size' (default 4), 'val_num_workers' (default 4)
                and 'gpu' (CUDA device index, default 0))
        """
        # Per-client updates received during the current round (filled by rec()).
        self.client_state = {}
        self.client_loss = {}
        self.client_n_data = {}
        self.selected_clients = []
        self._batch_size = params.get("val_batch_size", 4)
        self.client_list = client_list
        self.valset = None
        # Federated bookkeeping
        self.round = 0
        # Running count of training samples (updated by select_clients() and rec()).
        self.n_data = 0
        # Evaluation device; test() and agg() move the model back to CPU afterwards.
        gpu = params.get("gpu", 0)
        self._device = torch.device(f"cuda:{gpu}" if torch.cuda.is_available() else "cpu")
        # Global model
        self._num_classes = len(params["names"])
        self.model_name = model_name
        self.model = init_model(model_name, self._num_classes)
        self.params = params

    def load_valset(self, valset):
        """Server loads the validation dataset."""
        self.valset = valset

    def state_dict(self):
        """Return global model weights."""
        return self.model.state_dict()

    @staticmethod
    def _image_stats(output, cls, box, scale, iou_v):
        """Build the (tp, conf, pred_cls, target_cls) tuple for one image.

        Returns None when the image has neither predictions nor labels.
        """
        dev = iou_v.device
        # Normalise labels to (n, 1) so a single label never collapses to a
        # 0-d tensor (which torch.cat cannot concatenate later).
        cls = cls.to(dev).reshape(-1, 1)
        labels = cls[:, 0]
        metric = torch.zeros((output.shape[0], iou_v.numel()), dtype=torch.bool, device=dev)
        if output.shape[0] == 0:
            if labels.numel():
                # No predictions: record the missed labels with empty conf/cls.
                return (metric, *torch.zeros((2, 0), device=dev), labels)
            return None
        if labels.numel():
            box_xy = util.wh2xy(box.to(dev))
            if not isinstance(box_xy, torch.Tensor):
                box_xy = torch.tensor(box_xy, device=dev)
            # Labels are in normalised xywh; scale converts to pixel xyxy.
            target = torch.cat((cls, box_xy * scale), dim=1)
            metric = util.compute_metric(output[:, :6], target, iou_v)
        return (metric, output[:, 4], output[:, 5], labels)

    @torch.no_grad()
    def test(self, args) -> dict:
        """
        Test the global model on the server's validation set.
        Args:
            args: unused; kept for interface compatibility.
        Returns:
            dict with keys: mAP, mAP50, precision, recall
            ({} if no validation set has been loaded).
        """
        if self.valset is None:
            return {}
        dev = self._device
        loader = DataLoader(
            self.valset,
            batch_size=self._batch_size,
            shuffle=False,
            num_workers=self.params.get("val_num_workers", 4),
            # Pinned memory only helps host->GPU copies.
            pin_memory=dev.type == "cuda",
            collate_fn=Dataset.collate_fn,
        )
        iou_v = torch.linspace(0.5, 0.95, 10, device=dev)
        metrics = []
        # Move to device for eval; keep in float32 for stability.
        self.model.eval().to(dev).float()
        try:
            for samples, targets in loader:
                samples = samples.to(dev, non_blocking=True).float() / 255.0
                _, _, h, w = samples.shape
                scale = torch.tensor((w, h, w, h), device=dev)
                outputs = util.non_max_suppression(self.model(samples))
                for i, output in enumerate(outputs):
                    idx = targets["idx"] == i
                    stats = self._image_stats(
                        output, targets["cls"][idx], targets["box"][idx], scale, iou_v
                    )
                    if stats is not None:
                        metrics.append(stats)
        finally:
            # Always return the model to CPU/FP32 so the next agg() stays
            # device-consistent, even if evaluation raised.
            self.model.to("cpu").float()
        if not metrics:
            return {"mAP": 0, "mAP50": 0, "precision": 0, "recall": 0}
        metrics = [torch.cat(x, dim=0).cpu().numpy() for x in zip(*metrics)]
        if metrics[0].any():
            _, _, precision, recall, map50, mean_ap = util.compute_ap(
                *metrics, names=self.params["names"], plot=False
            )
        else:
            precision, recall, map50, mean_ap = 0, 0, 0, 0
        return {
            "mAP": float(mean_ap),
            "mAP50": float(map50),
            "precision": float(precision),
            "recall": float(recall),
        }

    def select_clients(self, connection_ratio=1.0):
        """
        Randomly select a fraction of clients.
        Each client is kept independently with probability ``connection_ratio``.
        Also resets ``self.n_data`` to the sum of the sample counts last reported
        by the selected clients (0 for clients that never reported).
        Args:
            connection_ratio: fraction of clients to select (0 < connection_ratio <= 1)
        """
        self.selected_clients = []
        self.n_data = 0
        for client_id in self.client_list:
            # Random selection based on connection ratio
            if np.random.rand() <= connection_ratio:
                self.selected_clients.append(client_id)
                self.n_data += self.client_n_data.get(client_id, 0)

    def agg(self):
        """Aggregate client updates (FedAvg) on CPU/FP32, preserving non-float buffers.

        Only selected clients that have a stored update take part. Each one is
        weighted by ``n_k / sum(n)`` over those clients, so the weights always
        sum to 1. The result does not depend on the running ``self.n_data``
        counter, which is pre-filled by select_clients() and incremented again
        by rec().

        Returns:
            (state_dict, avg_loss, n_data): a CPU copy of the aggregated
            weights, the sample-weighted average of each reported loss, and
            the number of samples aggregated. If there is nothing to
            aggregate, returns (current state_dict, {}, 0).
        """
        reporting = [cid for cid in self.selected_clients if cid in self.client_state]
        total_n = float(sum(self.client_n_data[cid] for cid in reporting))
        if not reporting or total_n <= 0:
            return self.model.state_dict(), {}, 0
        # Ensure global model is on CPU so its state matches the CPU client states.
        self.model.to("cpu")
        global_state = self.model.state_dict()
        # Floating tensors: float32 zero accumulators.
        # Non-floating tensors (e.g. BN num_batches_tracked): copied from the first reporter.
        first_state = self.client_state[reporting[0]]
        new_state = {}
        for k, v in global_state.items():
            if v.is_floating_point():
                new_state[k] = torch.zeros_like(v, dtype=torch.float32, device="cpu")
            else:
                new_state[k] = first_state[k].clone()
        avg_loss = {}
        for cid in reporting:
            weight = self.client_n_data[cid] / total_n
            cst = self.client_state[cid]
            for k, acc in new_state.items():
                if acc.is_floating_point():
                    acc.add_(cst[k].to(torch.float32), alpha=weight)
            # Weighted average of losses with the same FedAvg weights.
            for lk, lv in self.client_loss.get(cid, {}).items():
                avg_loss[lk] = avg_loss.get(lk, 0.0) + float(lv) * weight
        # Load aggregated state back into the global model (model is on CPU).
        with torch.no_grad():
            self.model.load_state_dict(new_state, strict=True)
        self.round += 1
        # Return a CPU copy of the state_dict (safe for broadcasting to clients).
        return {k: v.clone() for k, v in self.model.state_dict().items()}, avg_loss, int(total_n)

    def rec(self, name, state_dict, n_data, loss_dict):
        """
        Receive local update from a client.
        - Store all floating tensors as CPU float32 copies
        - Store non-floating tensors (e.g., BN counters) as CPU copies in original dtype
        Copies are always made, so later in-place training on the client's
        tensors cannot alter the stored update.
        """
        n_data = int(n_data)
        self.n_data += n_data
        safe_state = {}
        with torch.no_grad():
            for k, v in state_dict.items():
                t = v.detach()
                if t.is_floating_point():
                    safe_state[k] = t.to("cpu", dtype=torch.float32, copy=True)
                else:
                    safe_state[k] = t.to("cpu", copy=True)
        self.client_state[name] = safe_state
        self.client_n_data[name] = n_data
        self.client_loss[name] = {k: float(v) for k, v in loss_dict.items()}

    def flush(self):
        """Clear stored client updates and reset the sample counter."""
        self.n_data = 0
        self.client_state.clear()
        self.client_n_data.clear()
        self.client_loss.clear()