fed-yolo/fed_algo_cs/server_base.py

import warnings

import numpy as np
import torch
from torch.utils.data import DataLoader

from utils.fed_util import init_model
from utils.dataset import Dataset
from utils import util
from nets import YOLO


class FedYoloServer(object):
    def __init__(self, client_list, model_name, params):
        """
        Federated YOLO server.

        Args:
            client_list: list of connected client names
            model_name: YOLO model architecture name
            params: dict of hyperparameters (must include 'names')
        """
        # Per-client updates received in the current round
        self.client_state = {}
        self.client_loss = {}
        self.client_n_data = {}
        self.selected_clients = []
        self._batch_size = params.get("val_batch_size", 200)
        self.client_list = client_list
        self.valset = None

        # Federated bookkeeping
        self.round = 0
        # Total number of training samples across the selected clients
        self.n_data = 0

        # Device
        gpu = 0
        self._device = torch.device(f"cuda:{gpu}" if torch.cuda.is_available() else "cpu")

        # Global model
        self._num_classes = len(params["names"])
        self.model_name = model_name
        self.model = init_model(model_name, self._num_classes)
        self.params = params
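
    # Example construction (model name, client ids, and params values below are
    # illustrative, not taken from this repository):
    #   server = FedYoloServer(
    #       client_list=["client_0", "client_1"],
    #       model_name="yolo_v8_n",
    #       params={"names": {0: "person", 1: "car"}, "val_batch_size": 32},
    #   )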

    def load_valset(self, valset: Dataset):
        """Server loads the validation dataset."""
        self.valset = valset

    def state_dict(self):
        """Return the global model weights."""
        return self.model.state_dict()

    def select_clients(self, connection_ratio=1.0):
        """
        Randomly select a subset of clients.

        Args:
            connection_ratio: probability of selecting each client (0 < connection_ratio <= 1)
        """
        self.selected_clients = []
        self.n_data = 0
        for client_id in self.client_list:
            # Each client is kept independently with probability connection_ratio
            s = np.random.binomial(1, connection_ratio)
            if s == 1:
                self.selected_clients.append(client_id)
                self.n_data += self.client_n_data.get(client_id, 0)

    @torch.no_grad()
    def agg(self):
        """
        Aggregate the local updates from the selected clients using FedAvg.

        Returns:
            model_state: aggregated model weights
            avg_loss: weighted average training loss across the selected clients
            n_data: total number of data points across the selected clients
        """
        if len(self.selected_clients) == 0 or self.n_data == 0:
            warnings.warn("No clients selected or no data available for aggregation.")
            return self.model.state_dict(), 0, 0

        # Initialize a fresh model whose state_dict holds the running aggregate
        model = init_model(model_name=self.model_name, num_classes=self._num_classes)
        model_state = model.state_dict()
        avg_loss = 0

        # Only aggregate clients that actually sent an update this round
        clients = [name for name in self.selected_clients if name in self.client_state]

        # FedAvg: w = sum_k (n_k / n) * w_k
        for i, name in enumerate(clients):
            weight = self.client_n_data[name] / self.n_data
            for key in self.client_state[name]:
                if i == 0:
                    # First client initializes the running sum
                    model_state[key] = self.client_state[name][key] * weight
                else:
                    model_state[key] = model_state[key] + self.client_state[name][key] * weight
            avg_loss = avg_loss + self.client_loss[name] * weight

        self.model.load_state_dict(model_state, strict=True)
        self.round += 1
        n_data = self.n_data
        return model_state, avg_loss, n_data
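
    # Worked example of the weighting above (illustrative numbers): with two
    # participating clients holding n_1 = 100 and n_2 = 300 samples (n = 400),
    # every parameter tensor is combined as
    #   w = (100 / 400) * w_1 + (300 / 400) * w_2 = 0.25 * w_1 + 0.75 * w_2,
    # and avg_loss is weighted the same way.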

    def rec(self, name, state_dict, n_data, loss):
        """
        Receive a local update from a client.

        Floating-point tensors are stored on CPU as float32; non-floating
        tensors (e.g. BN counters) are stored on CPU in their original dtype.
        """
        self.n_data += n_data
        self.client_state[name] = {
            k: (v.detach().cpu().float() if v.is_floating_point() else v.detach().cpu())
            for k, v in state_dict.items()
        }
        self.client_n_data[name] = int(n_data)
        self.client_loss[name] = loss
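
    # Typical call from the coordinating training loop (names are illustrative):
    #   server.rec("client_0", local_model.state_dict(), n_data=len(local_set), loss=train_loss)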

    def flush(self):
        """Clear stored client updates."""
        self.n_data = 0
        self.client_state.clear()
        self.client_n_data.clear()
        self.client_loss.clear()

    def test(self):
        """Evaluate the global model on the server's validation dataset."""
        if self.valset is None:
            warnings.warn("No validation dataset available for testing.")
            return 0.0, 0.0, 0.0, 0.0
        return test(self.valset, self.params, self.model, batch_size=self._batch_size)


@torch.no_grad()
def test(valset: Dataset, params, model: YOLO, batch_size: int = 200) -> tuple[float, float, float, float]:
    """
    Evaluate the model on the validation dataset.

    Args:
        valset: validation dataset
        params: dict of parameters (must include 'names')
        model: YOLO model to evaluate
        batch_size: batch size for evaluation

    Returns:
        tuple of (mean_ap, map50, m_rec, m_pre)
    """
    loader = DataLoader(
        dataset=valset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=4,
        pin_memory=True,
        collate_fn=Dataset.collate_fn,
    )

    model.cuda()
    model.half()
    model.eval()

    # Configure
    iou_v = torch.linspace(start=0.5, end=0.95, steps=10).cuda()  # IoU vector for mAP@0.5:0.95
    n_iou = iou_v.numel()

    m_pre = 0
    m_rec = 0
    map50 = 0
    mean_ap = 0
    metrics = []
    for samples, targets in loader:
        samples = samples.cuda()
        samples = samples.half()  # uint8 to fp16
        samples = samples / 255.0  # 0 - 255 to 0.0 - 1.0
        _, _, h, w = samples.shape  # batch size, channels, height, width
        scale = torch.tensor((w, h, w, h)).cuda()

        # Inference
        outputs = model(samples)

        # NMS
        outputs = util.non_max_suppression(outputs)

        # Metrics
        for i, output in enumerate(outputs):
            # Select the targets belonging to image i in the batch
            idx = targets["idx"]
            if idx.dim() > 1:
                idx = idx.squeeze(-1)
            idx = idx == i
            cls = targets["cls"][idx]
            box = targets["box"][idx]
            cls = cls.cuda()
            box = box.cuda()

            metric = torch.zeros(output.shape[0], n_iou, dtype=torch.bool).cuda()
            if output.shape[0] == 0:
                if cls.shape[0]:
                    metrics.append((metric, *torch.zeros((2, 0)).cuda(), cls.squeeze(-1)))
                continue

            # Evaluate
            if cls.shape[0]:
                target = torch.cat(tensors=(cls, util.wh2xy(box) * scale), dim=1)
                metric = util.compute_metric(output[:, :6], target, iou_v)

            # Append
            metrics.append((metric, output[:, 4], output[:, 5], cls.squeeze(-1)))

    # Compute metrics
    metrics = [torch.cat(x, dim=0).cpu().numpy() for x in zip(*metrics)]  # to numpy
    if len(metrics) and metrics[0].any():
        tp, fp, m_pre, m_rec, map50, mean_ap = util.compute_ap(*metrics, plot=False, names=params["names"])

    # Print results
    # print(("%10s" + "%10.3g" * 4) % ("", m_pre, m_rec, map50, mean_ap))

    # Restore float32 weights for subsequent training
    model.float()
    return mean_ap, map50, m_rec, m_pre
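

# -----------------------------------------------------------------------------
# Minimal FedAvg sanity check (illustrative only, not part of the training
# pipeline). It mimics the per-client weighting used in FedYoloServer.agg()
# on toy tensors, so it runs without any dataset or YOLO weights.
# -----------------------------------------------------------------------------
if __name__ == "__main__":
    client_states = {
        "client_a": {"w": torch.full((2, 2), 1.0)},
        "client_b": {"w": torch.full((2, 2), 5.0)},
    }
    client_n_data = {"client_a": 100, "client_b": 300}
    n_total = sum(client_n_data.values())

    agg_state = {}
    for i, (name, state) in enumerate(client_states.items()):
        frac = client_n_data[name] / n_total
        for key, value in state.items():
            agg_state[key] = value * frac if i == 0 else agg_state[key] + value * frac

    # Expected: 0.25 * 1.0 + 0.75 * 5.0 = 4.0 for every element
    print(agg_state["w"])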