优化FedYoloClient和FedYoloServer类
This commit is contained in:
@@ -4,6 +4,7 @@ from torch.utils.data import DataLoader
|
||||
from utils.fed_util import init_model
|
||||
from utils.dataset import Dataset
|
||||
from utils import util
|
||||
from nets import YOLO
|
||||
|
||||
|
||||
class FedYoloServer(object):
|
||||
@@ -21,7 +22,7 @@ class FedYoloServer(object):
|
||||
self.client_n_data = {}
|
||||
self.selected_clients = []
|
||||
|
||||
self._batch_size = params.get("val_batch_size", 4)
|
||||
self._batch_size = params.get("val_batch_size", 200)
|
||||
self.client_list = client_list
|
||||
self.valset = None
|
||||
|
||||
@@ -40,7 +41,7 @@ class FedYoloServer(object):
|
||||
self.model = init_model(model_name, self._num_classes)
|
||||
self.params = params
|
||||
|
||||
def load_valset(self, valset):
|
||||
def load_valset(self, valset: Dataset):
|
||||
"""Server loads the validation dataset."""
|
||||
self.valset = valset
|
||||
|
||||
@@ -48,78 +49,6 @@ class FedYoloServer(object):
|
||||
"""Return global model weights."""
|
||||
return self.model.state_dict()
|
||||
|
||||
@torch.no_grad()
def test(self, args) -> dict:
    """
    Evaluate the global model on the server's validation set.

    Args:
        args: not read by this method; kept so existing call sites keep working.

    Returns:
        ``{}`` when no validation set has been loaded. Otherwise a dict with
        the keys ``mAP``, ``mAP50``, ``precision`` and ``recall``; all four are
        zero when no per-image record was collected or no prediction matched.
    """
    if self.valset is None:
        return {}

    loader = DataLoader(
        self.valset,
        batch_size=self._batch_size,
        shuffle=False,
        num_workers=4,
        pin_memory=True,
        collate_fn=Dataset.collate_fn,
    )

    dev = self._device
    try:
        # Move to the eval device; keep float32 for stability.
        self.model.eval().to(dev).float()

        # Ten IoU thresholds from 0.5 to 0.95.
        iou_v = torch.linspace(0.5, 0.95, 10, device=dev)
        n_iou = iou_v.numel()
        metrics = []

        for samples, targets in loader:
            samples = samples.to(dev, non_blocking=True).float() / 255.0
            _, _, h, w = samples.shape
            # Multiplier applied to the converted boxes below
            # (presumably normalized -> pixel coordinates; confirm against the dataset).
            scale = torch.tensor((w, h, w, h), device=dev)

            outputs = util.non_max_suppression(self.model(samples))

            for i, output in enumerate(outputs):
                # Select the labels that belong to the i-th image of the batch.
                idx = targets["idx"] == i
                cls = targets["cls"][idx].to(dev)
                box = targets["box"][idx].to(dev)

                metric = torch.zeros((output.shape[0], n_iou), dtype=torch.bool, device=dev)
                if output.shape[0] == 0:
                    # No detections: record the labels only when there are any.
                    if cls.shape[0]:
                        # reshape(-1) instead of squeeze(-1): squeezing a 1-D
                        # single-element tensor yields a 0-d tensor, which
                        # torch.cat cannot concatenate later on.
                        metrics.append((metric, *torch.zeros((2, 0), device=dev), cls.reshape(-1)))
                    continue

                if cls.shape[0]:
                    if cls.dim() == 1:
                        cls = cls.unsqueeze(1)
                    box_xy = util.wh2xy(box)
                    if not isinstance(box_xy, torch.Tensor):
                        box_xy = torch.tensor(box_xy, device=dev)
                    target = torch.cat((cls, box_xy * scale), dim=1)
                    metric = util.compute_metric(output[:, :6], target, iou_v)

                # Columns 4 and 5 of each detection are passed on to compute_ap
                # (presumably confidence and predicted class - TODO confirm).
                metrics.append((metric, output[:, 4], output[:, 5], cls.reshape(-1)))

        if not metrics:
            return {"mAP": 0, "mAP50": 0, "precision": 0, "recall": 0}

        # Concatenate each field across images and convert to numpy.
        metrics = [torch.cat(x, dim=0).cpu().numpy() for x in zip(*metrics)]
        if len(metrics) and metrics[0].any():
            _, _, prec, rec, map50, mean_ap = util.compute_ap(*metrics, names=self.params["names"], plot=False)
        else:
            prec, rec, map50, mean_ap = 0, 0, 0, 0

        return {"mAP": float(mean_ap), "mAP50": float(map50), "precision": float(prec), "recall": float(rec)}
    finally:
        # Return the model to CPU/float32 on every exit path (including errors)
        # so the next agg() stays device-consistent.
        self.model.to("cpu").float()
|
||||
|
||||
def select_clients(self, connection_ratio=1.0):
|
||||
"""
|
||||
Randomly select a fraction of clients.
|
||||
@@ -130,80 +59,69 @@ class FedYoloServer(object):
|
||||
self.n_data = 0
|
||||
for client_id in self.client_list:
|
||||
# Random selection based on connection ratio
|
||||
if np.random.rand() <= connection_ratio:
|
||||
s = np.random.binomial(np.ones(1).astype(int), connection_ratio)
|
||||
if s[0] == 1:
|
||||
self.selected_clients.append(client_id)
|
||||
self.n_data += self.client_n_data.get(client_id, 0)
|
||||
self.n_data += self.client_n_data[client_id]
|
||||
|
||||
@torch.no_grad()
|
||||
def agg(self):
|
||||
"""Aggregate client updates (FedAvg) on CPU/FP32, preserving non-float buffers."""
|
||||
"""
|
||||
Server aggregates the local updates from selected clients using FedAvg.
|
||||
|
||||
:return: model_state: aggregated model weights
|
||||
:return: avg_loss: weighted average training loss across selected clients
|
||||
:return: n_data: total number of data points across selected clients
|
||||
"""
|
||||
if len(self.selected_clients) == 0 or self.n_data == 0:
|
||||
return self.model.state_dict(), {}, 0
|
||||
import warnings
|
||||
|
||||
# Ensure global model is on CPU for safe load later
|
||||
self.model.to("cpu")
|
||||
global_state = self.model.state_dict() # may hold CPU or CUDA refs; we’re on CPU now
|
||||
warnings.warn("No clients selected or no data available for aggregation.")
|
||||
return self.model.state_dict(), 0, 0
|
||||
|
||||
avg_loss = {}
|
||||
total_n = float(self.n_data)
|
||||
# Initialize a model for aggregation
|
||||
model = init_model(model_name=self.model_name, num_classes=self._num_classes)
|
||||
model_state = model.state_dict()
|
||||
|
||||
# Prepare accumulators on CPU. For floating tensors, use float32 zeros.
|
||||
# For non-floating tensors (e.g., BN num_batches_tracked int64), we’ll copy from the first client.
|
||||
new_state = {}
|
||||
first_client = None
|
||||
for cid in self.selected_clients:
|
||||
if cid in self.client_state:
|
||||
first_client = cid
|
||||
break
|
||||
avg_loss = 0
|
||||
|
||||
assert first_client is not None, "No client states available to aggregate."
|
||||
|
||||
for k, v in global_state.items():
|
||||
if v.is_floating_point():
|
||||
new_state[k] = torch.zeros_like(v.detach().cpu(), dtype=torch.float32)
|
||||
else:
|
||||
# For non-float buffers, just copy from the first client (or keep global)
|
||||
new_state[k] = self.client_state[first_client][k].clone()
|
||||
|
||||
# Accumulate floating tensors with weights; keep non-floats as assigned above
|
||||
for cid in self.selected_clients:
|
||||
if cid not in self.client_state:
|
||||
# Aggregate the local updated models from selected clients
|
||||
for i, name in enumerate(self.selected_clients):
|
||||
if name not in self.client_state:
|
||||
continue
|
||||
weight = self.client_n_data[cid] / total_n
|
||||
cst = self.client_state[cid]
|
||||
for k in new_state.keys():
|
||||
if new_state[k].is_floating_point():
|
||||
# cst[k] is CPU; ensure float32 for accumulation
|
||||
new_state[k].add_(cst[k].to(torch.float32), alpha=weight)
|
||||
|
||||
# weighted average losses
|
||||
for lk, lv in self.client_loss[cid].items():
|
||||
avg_loss[lk] = avg_loss.get(lk, 0.0) + float(lv) * weight
|
||||
|
||||
# Load aggregated state back into the global model (model is on CPU)
|
||||
with torch.no_grad():
|
||||
self.model.load_state_dict(new_state, strict=True)
|
||||
for key in self.client_state[name]:
|
||||
if i == 0:
|
||||
# First client, initialize the model_state
|
||||
model_state[key] = self.client_state[name][key] * (self.client_n_data[name] / self.n_data)
|
||||
else:
|
||||
# math equation: w = sum(n_k / n * w_k)
|
||||
model_state[key] = model_state[key] + self.client_state[name][key] * (
|
||||
self.client_n_data[name] / self.n_data
|
||||
)
|
||||
avg_loss = avg_loss + self.client_loss[name] * (self.client_n_data[name] / self.n_data)
|
||||
|
||||
self.model.load_state_dict(model_state, strict=True)
|
||||
self.round += 1
|
||||
# Return CPU state_dict (good for broadcasting to clients)
|
||||
return {k: v.clone() for k, v in self.model.state_dict().items()}, avg_loss, int(self.n_data)
|
||||
|
||||
def rec(self, name, state_dict, n_data, loss_dict):
|
||||
n_data = self.n_data
|
||||
|
||||
return model_state, avg_loss, n_data
|
||||
|
||||
def rec(self, name, state_dict, n_data, loss):
|
||||
"""
|
||||
Receive local update from a client.
|
||||
- Store all floating tensors as CPU float32
|
||||
- Store non-floating tensors (e.g., BN counters) as CPU in original dtype
|
||||
"""
|
||||
self.n_data += n_data
|
||||
safe_state = {}
|
||||
with torch.no_grad():
|
||||
for k, v in state_dict.items():
|
||||
t = v.detach().cpu()
|
||||
if t.is_floating_point():
|
||||
t = t.to(torch.float32)
|
||||
safe_state[k] = t
|
||||
self.client_state[name] = safe_state
|
||||
|
||||
self.client_state[name] = {}
|
||||
self.client_n_data[name] = {}
|
||||
self.client_loss[name] = {}
|
||||
|
||||
self.client_state[name].update(state_dict)
|
||||
self.client_n_data[name] = int(n_data)
|
||||
self.client_loss[name] = {k: float(v) for k, v in loss_dict.items()}
|
||||
self.client_loss[name] = loss
|
||||
|
||||
def flush(self):
|
||||
"""Clear stored client updates."""
|
||||
@@ -211,3 +129,94 @@ class FedYoloServer(object):
|
||||
self.client_state.clear()
|
||||
self.client_n_data.clear()
|
||||
self.client_loss.clear()
|
||||
|
||||
def test(self):
    """Evaluate the global model on the server's validation dataset.

    Returns:
        ``{}`` (after emitting a warning) when no validation set is loaded;
        otherwise the result of the module-level ``test`` function called with
        the stored validation set, parameters, global model and the server's
        configured validation batch size.
    """
    if self.valset is None:
        import warnings

        warnings.warn("No validation dataset available for testing.")
        return {}
    # Inside a method body the bare name ``test`` resolves to the module-level
    # function, not to this method. Forward the configured batch size so the
    # ``val_batch_size`` setting is honoured instead of the helper's default.
    return test(self.valset, self.params, self.model, batch_size=self._batch_size)
|
||||
|
||||
|
||||
@torch.no_grad()
def test(valset: Dataset, params, model: YOLO, batch_size: int = 200) -> tuple[float, float, float, float]:
    """
    Evaluate the model on the validation dataset.

    The model is moved to CUDA and run in half precision when a GPU is
    available; otherwise evaluation runs on CPU in float32. The model is
    switched to eval mode and is converted back to float32 before returning
    (it is left on the evaluation device and in eval mode).

    Args:
        valset: validation dataset
        params: dict of parameters (must include 'names')
        model: YOLO model to evaluate
        batch_size: batch size for evaluation

    Returns:
        Tuple ``(mean_ap, map50, m_rec, m_pre)``. All four are 0 when no
        per-image record was collected or no prediction matched a label.
    """
    loader = DataLoader(
        dataset=valset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=4,
        pin_memory=True,
        collate_fn=Dataset.collate_fn,
    )

    # Fall back to CPU/float32 instead of crashing on hosts without CUDA.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_half = device.type == "cuda"

    model.to(device)
    if use_half:
        model.half()
    else:
        model.float()
    model.eval()

    # IoU vector for mAP@0.5:0.95
    iou_v = torch.linspace(start=0.5, end=0.95, steps=10).to(device)
    n_iou = iou_v.numel()

    m_pre = 0
    m_rec = 0
    map50 = 0
    mean_ap = 0
    metrics = []

    try:
        for samples, targets in loader:
            samples = samples.to(device)
            samples = samples.half() if use_half else samples.float()
            samples = samples / 255.0  # 0 - 255 to 0.0 - 1.0
            _, _, h, w = samples.shape  # batch-size, channels, height, width
            scale = torch.tensor((w, h, w, h), device=device)

            # Inference followed by NMS
            outputs = util.non_max_suppression(model(samples))

            for i, output in enumerate(outputs):
                # Build a 1-D mask selecting the labels of the i-th image.
                idx = targets["idx"]
                if idx.dim() > 1:
                    idx = idx.squeeze(-1)
                idx = idx == i

                cls = targets["cls"][idx].to(device)
                box = targets["box"][idx].to(device)

                metric = torch.zeros(output.shape[0], n_iou, dtype=torch.bool, device=device)

                if output.shape[0] == 0:
                    # No detections: record the labels only when there are any.
                    if cls.shape[0]:
                        # reshape(-1) rather than squeeze(-1): squeezing a 1-D
                        # single-element tensor gives a 0-d tensor that the
                        # final torch.cat cannot concatenate.
                        metrics.append((metric, *torch.zeros((2, 0), device=device), cls.reshape(-1)))
                    continue

                if cls.shape[0]:
                    # torch.cat along dim=1 needs a column of labels.
                    if cls.dim() == 1:
                        cls = cls.unsqueeze(1)
                    target = torch.cat(tensors=(cls, util.wh2xy(box) * scale), dim=1)
                    metric = util.compute_metric(output[:, :6], target, iou_v)

                metrics.append((metric, output[:, 4], output[:, 5], cls.reshape(-1)))

        # Concatenate each field across images and convert to numpy.
        metrics = [torch.cat(x, dim=0).cpu().numpy() for x in zip(*metrics)]
        if len(metrics) and metrics[0].any():
            _, _, m_pre, m_rec, map50, mean_ap = util.compute_ap(*metrics, plot=False, names=params["names"])
    finally:
        # Restore full precision for training, even if evaluation failed.
        model.float()

    return mean_ap, map50, m_rec, m_pre
|
||||
|
Reference in New Issue
Block a user