import warnings

import numpy as np
import torch
from torch.utils.data import DataLoader

from nets import YOLO
from utils import util
from utils.dataset import Dataset
from utils.fed_util import init_model


class FedYoloServer(object):
    def __init__(self, client_list, model_name, params):
        """
        Federated YOLO Server.

        Args:
            client_list: list of connected clients
            model_name: YOLO model architecture name
            params: dict of hyperparameters (must include 'names')
        """
        # Track client updates
        self.client_state = {}
        self.client_loss = {}
        self.client_n_data = {}
        self.selected_clients = []

        self._batch_size = params.get("val_batch_size", 200)
        self.client_list = client_list
        self.valset = None

        # Federated bookkeeping
        self.round = 0
        # Total number of data points across the selected clients
        self.n_data = 0

        # Device
        gpu = 0
        self._device = torch.device(f"cuda:{gpu}" if torch.cuda.is_available() else "cpu")

        # Global model
        self._num_classes = len(params["names"])
        self.model_name = model_name
        self.model = init_model(model_name, self._num_classes)
        self.params = params

    def load_valset(self, valset: Dataset):
        """Server loads the validation dataset."""
        self.valset = valset

    def state_dict(self):
        """Return the global model weights."""
        return self.model.state_dict()

    def select_clients(self, connection_ratio=1.0):
        """
        Randomly select a fraction of clients.

        Args:
            connection_ratio: fraction of clients to select (0 < connection_ratio <= 1)
        """
        self.selected_clients = []
        self.n_data = 0
        for client_id in self.client_list:
            # Random selection based on the connection ratio
            if np.random.binomial(1, connection_ratio) == 1:
                self.selected_clients.append(client_id)
                # Clients that have not reported an update yet contribute 0 here
                self.n_data += self.client_n_data.get(client_id, 0)

    @torch.no_grad()
    def agg(self):
        """
        Server aggregates the local updates from selected clients using FedAvg.

        :return: model_state: aggregated model weights
        :return: avg_loss: weighted average training loss across selected clients
        :return: n_data: total number of data points across selected clients
        """
        if len(self.selected_clients) == 0 or self.n_data == 0:
            warnings.warn("No clients selected or no data available for aggregation.")
            return self.model.state_dict(), 0, 0

        # Initialize a model whose state_dict holds the aggregated weights
        model = init_model(model_name=self.model_name, num_classes=self._num_classes)
        model_state = model.state_dict()

        avg_loss = 0

        # Aggregate the local updates from the contributing clients
        # math equation: w = sum(n_k / n * w_k)
        first = True
        for name in self.selected_clients:
            if name not in self.client_state:
                continue
            weight = self.client_n_data[name] / self.n_data
            for key in self.client_state[name]:
                if first:
                    # First contributing client initializes model_state
                    model_state[key] = self.client_state[name][key] * weight
                else:
                    model_state[key] = model_state[key] + self.client_state[name][key] * weight
            avg_loss = avg_loss + self.client_loss[name] * weight
            first = False

        self.model.load_state_dict(model_state, strict=True)
        self.round += 1

        n_data = self.n_data

        return model_state, avg_loss, n_data

    def rec(self, name, state_dict, n_data, loss):
        """
        Receive a local update from a client.

        - Floating-point tensors are stored on CPU as float32
        - Non-floating tensors (e.g., BN counters) are stored on CPU in their original dtype
        """
        self.n_data += n_data

        self.client_state[name] = {}
        for key, value in state_dict.items():
            if torch.is_tensor(value) and torch.is_floating_point(value):
                self.client_state[name][key] = value.detach().to(device="cpu", dtype=torch.float32)
            elif torch.is_tensor(value):
                self.client_state[name][key] = value.detach().cpu()
            else:
                self.client_state[name][key] = value

        self.client_n_data[name] = int(n_data)
        self.client_loss[name] = loss

    def flush(self):
        """Clear stored client updates."""
        self.n_data = 0
        self.client_state.clear()
        self.client_n_data.clear()
        self.client_loss.clear()

    def test(self):
        """Evaluate the global model on the server's validation dataset."""
        if self.valset is None:
            warnings.warn("No validation dataset available for testing.")
            return 0.0, 0.0, 0.0, 0.0
        return test(self.valset, self.params, self.model, self._batch_size)


@torch.no_grad()
def test(valset: Dataset, params, model: YOLO, batch_size: int = 200) -> tuple[float, float, float, float]:
    """
    Evaluate the model on the validation dataset.

    Args:
        valset: validation dataset
        params: dict of parameters (must include 'names')
        model: YOLO model to evaluate
        batch_size: batch size for evaluation
    Returns:
        tuple of (mean_ap, map50, m_rec, m_pre)
    """
    loader = DataLoader(
        dataset=valset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=4,
        pin_memory=True,
        collate_fn=Dataset.collate_fn,
    )

    model.cuda()
    model.half()
    model.eval()

    # Configure
    iou_v = torch.linspace(start=0.5, end=0.95, steps=10).cuda()  # IoU vector for mAP@0.5:0.95
    n_iou = iou_v.numel()

    m_pre = 0
    m_rec = 0
    map50 = 0
    mean_ap = 0
    metrics = []

    for samples, targets in loader:
        samples = samples.cuda()
        samples = samples.half()  # uint8 to fp16
        samples = samples / 255.0  # 0 - 255 to 0.0 - 1.0
        _, _, h, w = samples.shape  # batch size, channels, height, width
        scale = torch.tensor((w, h, w, h)).cuda()
        # Inference
        outputs = model(samples)
        # NMS
        outputs = util.non_max_suppression(outputs)
        # Metrics
        for i, output in enumerate(outputs):
            # Select the targets belonging to image i of the current batch
            idx = targets["idx"]
            if idx.dim() > 1:
                idx = idx.squeeze(-1)
            idx = idx == i
            cls = targets["cls"][idx]
            box = targets["box"][idx]

            cls = cls.cuda()
            box = box.cuda()

            metric = torch.zeros(output.shape[0], n_iou, dtype=torch.bool).cuda()

            if output.shape[0] == 0:
                if cls.shape[0]:
                    metrics.append((metric, *torch.zeros((2, 0)).cuda(), cls.squeeze(-1)))
                continue
            # Evaluate
            if cls.shape[0]:
                target = torch.cat(tensors=(cls, util.wh2xy(box) * scale), dim=1)
                metric = util.compute_metric(output[:, :6], target, iou_v)
            # Append (matches per IoU threshold, confidence, predicted class, target class)
            metrics.append((metric, output[:, 4], output[:, 5], cls.squeeze(-1)))

    # Compute metrics
    metrics = [torch.cat(x, dim=0).cpu().numpy() for x in zip(*metrics)]  # to numpy
    if len(metrics) and metrics[0].any():
        tp, fp, m_pre, m_rec, map50, mean_ap = util.compute_ap(*metrics, plot=False, names=params["names"])
    # Print results
    # print(("%10s" + "%10.3g" * 4) % ("", m_pre, m_rec, map50, mean_ap))
    # Return results
    model.float()  # back to fp32 for training
    return mean_ap, map50, m_rec, m_pre
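

# Minimal usage sketch for a federated round (illustrative only).  The client
# identifiers, the model name, the `params` values, the datasets, and the
# `local_train` helper below are hypothetical and not part of this module;
# only the FedYoloServer API calls are real.
#
#   server = FedYoloServer(client_list=["client_0", "client_1"],
#                          model_name="yolo_v8_n",
#                          params={"names": {0: "person"}, "val_batch_size": 16})
#   server.load_valset(val_set)
#   for _ in range(rounds):
#       server.select_clients(connection_ratio=0.5)
#       for client_id in server.selected_clients:
#           # Client-side training (done elsewhere) returns a state_dict, the
#           # number of local samples, and the training loss.
#           state, n_data, loss = local_train(client_id, server.state_dict())
#           server.rec(client_id, state, n_data, loss)
#       model_state, avg_loss, n_data = server.agg()
#       mean_ap, map50, m_rec, m_pre = server.test()
#       server.flush()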