import warnings

import numpy as np
import torch
from torch.utils.data import DataLoader

from utils.fed_util import init_model
from utils.dataset import Dataset
from utils import util
from nets import YOLO


class FedYoloServer(object):
    def __init__(self, client_list, model_name, params):
        """
        Federated YOLO server.

        Args:
            client_list: list of connected clients
            model_name: YOLO model architecture name
            params: dict of hyperparameters (must include 'names')
        """
        # Track client updates
        self.client_state = {}
        self.client_loss = {}
        self.client_n_data = {}
        self.selected_clients = []
        self._batch_size = params.get("val_batch_size", 200)
        self.client_list = client_list
        self.valset = None

        # Federated bookkeeping
        self.round = 0
        # Total number of data points across selected clients
        self.n_data = 0

        # Device
        gpu = 0
        self._device = torch.device(f"cuda:{gpu}" if torch.cuda.is_available() else "cpu")

        # Global model
        self._num_classes = len(params["names"])  # total number of classes
        self.model_name = model_name
        self.model = init_model(model_name, self._num_classes)
        self.params = params

    def load_valset(self, valset: Dataset):
        """Server loads the validation dataset."""
        self.valset = valset

    def state_dict(self):
        """Return global model weights."""
        return self.model.state_dict()

    def select_clients(self, connection_ratio=1.0):
        """
        Randomly select a fraction of clients.

        Args:
            connection_ratio: fraction of clients to select (0 < connection_ratio <= 1)
        """
        self.selected_clients = []
        self.n_data = 0
        for client_id in self.client_list:
            # Bernoulli trial: include this client with probability connection_ratio
            if np.random.binomial(1, connection_ratio) == 1:
                self.selected_clients.append(client_id)
                # Data counts are only known for clients that have already reported (see rec)
                self.n_data += self.client_n_data.get(client_id, 0)

    @torch.no_grad()
    def agg(self):
        """
        Aggregate the local updates from selected clients using FedAvg.

        Returns:
            model_state: aggregated model weights
            avg_loss: weighted average training loss across selected clients
            n_data: total number of data points across selected clients
        """
        if len(self.selected_clients) == 0 or self.n_data == 0:
            warnings.warn("No clients selected or no data available for aggregation.")
            return self.model.state_dict(), 0, 0

        # Initialize a model to hold the aggregated weights
        model = init_model(model_name=self.model_name, num_classes=self._num_classes)
        model_state = model.state_dict()
        avg_loss = 0

        # FedAvg: w = sum_k (n_k / n) * w_k
        first = True
        for name in self.selected_clients:
            if name not in self.client_state:
                continue
            weight = self.client_n_data[name] / self.n_data
            for key in self.client_state[name]:
                if first:
                    # First contributing client initializes the aggregate
                    model_state[key] = self.client_state[name][key] * weight
                else:
                    model_state[key] = model_state[key] + self.client_state[name][key] * weight
            avg_loss = avg_loss + self.client_loss[name] * weight
            first = False

        self.model.load_state_dict(model_state, strict=True)
        self.round += 1
        n_data = self.n_data
        return model_state, avg_loss, n_data

    def rec(self, name, state_dict, n_data, loss):
        """
        Receive a local update from a client.

        - Store all floating tensors as CPU float32
        - Store non-floating tensors (e.g., BN counters) as CPU in their original dtype
        """
        self.n_data += int(n_data)

        client_state = {}
        for key, tensor in state_dict.items():
            if torch.is_floating_point(tensor):
                client_state[key] = tensor.detach().to(device="cpu", dtype=torch.float32)
            else:
                client_state[key] = tensor.detach().to(device="cpu")
        self.client_state[name] = client_state
        self.client_n_data[name] = int(n_data)
        self.client_loss[name] = loss

    def flush(self):
        """Clear stored client updates."""
        self.n_data = 0
        self.client_state.clear()
        self.client_n_data.clear()
        self.client_loss.clear()

    def test(self):
        """Evaluate the global model on the server's validation dataset."""
        if self.valset is None:
            warnings.warn("No validation dataset available for testing.")
            return 0.0, 0.0, 0.0, 0.0
        return test(self.valset, self.params, self.model, batch_size=self._batch_size)


@torch.no_grad()
def test(valset: Dataset, params, model: YOLO, batch_size: int = 200) -> tuple[float, float, float, float]:
    """
    Evaluate the model on the validation dataset.

    Args:
        valset: validation dataset
        params: dict of parameters (must include 'names')
        model: YOLO model to evaluate
        batch_size: batch size for evaluation
    Returns:
        (mean_ap, map50, m_rec, m_pre)
    """
    loader = DataLoader(
        dataset=valset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=4,
        pin_memory=True,
        collate_fn=Dataset.collate_fn,
    )

    # Evaluation assumes a CUDA device is available
    model.cuda()
    model.half()
    model.eval()

    # Configure
    iou_v = torch.linspace(start=0.5, end=0.95, steps=10).cuda()  # IoU vector for mAP@0.5:0.95
    n_iou = iou_v.numel()

    m_pre = 0.0
    m_rec = 0.0
    map50 = 0.0
    mean_ap = 0.0
    metrics = []

    for samples, targets in loader:
        samples = samples.cuda()
        samples = samples.half()  # uint8 to fp16
        samples = samples / 255.0  # 0 - 255 to 0.0 - 1.0
        _, _, h, w = samples.shape  # batch size, channels, height, width
        scale = torch.tensor((w, h, w, h)).cuda()

        # Inference
        outputs = model(samples)

        # NMS
        outputs = util.non_max_suppression(outputs)

        # Metrics
        for i, output in enumerate(outputs):
            # Select the targets belonging to image i of the batch
            idx = targets["idx"]
            if idx.dim() > 1:
                idx = idx.squeeze(-1)
            idx = idx == i
            cls = targets["cls"][idx]
            box = targets["box"][idx]
            cls = cls.cuda()
            box = box.cuda()

            metric = torch.zeros(output.shape[0], n_iou, dtype=torch.bool).cuda()

            if output.shape[0] == 0:
                if cls.shape[0]:
                    metrics.append((metric, *torch.zeros((2, 0)).cuda(), cls.squeeze(-1)))
                continue

            # Evaluate
            if cls.shape[0]:
                target = torch.cat(tensors=(cls, util.wh2xy(box) * scale), dim=1)
                metric = util.compute_metric(output[:, :6], target, iou_v)

            # Append
            metrics.append((metric, output[:, 4], output[:, 5], cls.squeeze(-1)))

    # Compute metrics
    metrics = [torch.cat(x, dim=0).cpu().numpy() for x in zip(*metrics)]  # to numpy
    if len(metrics) and metrics[0].any():
        tp, fp, m_pre, m_rec, map50, mean_ap = util.compute_ap(*metrics, plot=False, names=params["names"])

    # Print results
    # print(("%10s" + "%10.3g" * 4) % ("", m_pre, m_rec, map50, mean_ap))

    # Return results
    model.float()  # back to fp32 for training
    return mean_ap, map50, m_rec, m_pre
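# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): a minimal, self-contained demonstration of
# the weighted FedAvg rule used in FedYoloServer.agg(), w = sum_k (n_k / n) * w_k.
# The client names, tensors, and sample counts below are made-up toy values and
# are not part of this repository's client code.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    client_states = {
        "client_a": {"w": torch.ones(2, 2) * 1.0},
        "client_b": {"w": torch.ones(2, 2) * 3.0},
    }
    client_n_data = {"client_a": 10, "client_b": 30}

    n_total = sum(client_n_data.values())
    aggregated = {}
    for client, state in client_states.items():
        fraction = client_n_data[client] / n_total
        for key, tensor in state.items():
            if key not in aggregated:
                aggregated[key] = tensor * fraction
            else:
                aggregated[key] = aggregated[key] + tensor * fraction

    # Fractions are 0.25 and 0.75, so every aggregated entry equals 2.5
    print(aggregated["w"])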