Optimize the test and agg methods to improve the stability of model evaluation and aggregation logic
@@ -49,11 +49,11 @@ class FedYoloServer(object):
         return self.model.state_dict()

     @torch.no_grad()
-    def test(self, args):
+    def test(self, args) -> dict:
         """
-        Evaluate global model on validation set using YOLO metrics (mAP, precision, recall).
+        Test the global model on the server's validation set.
         Returns:
-            dict with {"mAP": ..., "mAP50": ..., "precision": ..., "recall": ...}
+            dict with keys: mAP, mAP50, precision, recall
         """
         if self.valset is None:
             return {}
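Note: the annotated signature above makes the return contract explicit. A minimal usage sketch (the `server` instance and `args` namespace are hypothetical, not part of this diff):

    metrics = server.test(args)
    if metrics:  # an empty dict means no validation set was configured
        print(f"mAP={metrics['mAP']:.4f}  mAP50={metrics['mAP50']:.4f}  "
              f"precision={metrics['precision']:.4f}  recall={metrics['recall']:.4f}")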
@@ -67,46 +67,47 @@ class FedYoloServer(object):
             collate_fn=Dataset.collate_fn,
         )

-        self.model.to(self._device).eval().half()
+        dev = self._device
+        # move to device for eval; keep in float32 for stability
+        self.model.eval().to(dev).float()

-        iou_v = torch.linspace(0.5, 0.95, 10).to(self._device)  # IoU thresholds
+        iou_v = torch.linspace(0.5, 0.95, 10, device=dev)
         n_iou = iou_v.numel()
         metrics = []

         for samples, targets in loader:
-            samples = samples.to(self._device).half() / 255.0
+            samples = samples.to(dev, non_blocking=True).float() / 255.0
             _, _, h, w = samples.shape
-            scale = torch.tensor((w, h, w, h)).to(self._device)
+            scale = torch.tensor((w, h, w, h), device=dev)

             outputs = self.model(samples)
             outputs = util.non_max_suppression(outputs)

             for i, output in enumerate(outputs):
                 idx = targets["idx"] == i
-                cls = targets["cls"][idx].to(self._device)
-                box = targets["box"][idx].to(self._device)
+                cls = targets["cls"][idx].to(dev)
+                box = targets["box"][idx].to(dev)

-                metric = torch.zeros((output.shape[0], n_iou), dtype=torch.bool, device=self._device)
+                metric = torch.zeros((output.shape[0], n_iou), dtype=torch.bool, device=dev)
                 if output.shape[0] == 0:
                     if cls.shape[0]:
-                        metrics.append((metric, *torch.zeros((2, 0), device=self._device), cls.squeeze(-1)))
+                        metrics.append((metric, *torch.zeros((2, 0), device=dev), cls.squeeze(-1)))
                     continue

                 if cls.shape[0]:
-                    cls_tensor = cls if isinstance(cls, torch.Tensor) else torch.tensor(cls, device=self._device)
-                    if cls_tensor.dim() == 1:
-                        cls_tensor = cls_tensor.unsqueeze(1)
+                    if cls.dim() == 1:
+                        cls = cls.unsqueeze(1)
                     box_xy = util.wh2xy(box)
                     if not isinstance(box_xy, torch.Tensor):
-                        box_xy = torch.tensor(box_xy, device=self._device)
-                    target = torch.cat((cls_tensor, box_xy * scale), dim=1)
+                        box_xy = torch.tensor(box_xy, device=dev)
+                    target = torch.cat((cls, box_xy * scale), dim=1)
                     metric = util.compute_metric(output[:, :6], target, iou_v)

                 metrics.append((metric, output[:, 4], output[:, 5], cls.squeeze(-1)))

-        # Compute metrics
         if not metrics:
+            # move back to CPU before returning
+            self.model.to("cpu").float()
             return {"mAP": 0, "mAP50": 0, "precision": 0, "recall": 0}

         metrics = [torch.cat(x, dim=0).cpu().numpy() for x in zip(*metrics)]
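Note: the new code evaluates in float32 instead of float16. This can be slower on some GPUs but avoids half-precision rounding during preprocessing and metric computation. A self-contained illustration of the failure mode (plain PyTorch, independent of this repo):

    import torch

    # float16 has a 10-bit significand: at magnitude 2048 the spacing between
    # representable values is 2.0, so adding 1.0 is rounded away entirely.
    acc_fp16 = torch.tensor(2048.0, dtype=torch.float16)
    print(acc_fp16 + 1.0)  # tensor(2048., dtype=torch.float16) -- the update is lost
    acc_fp32 = torch.tensor(2048.0, dtype=torch.float32)
    print(acc_fp32 + 1.0)  # tensor(2049.)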
@@ -115,9 +116,8 @@ class FedYoloServer(object):
         else:
             prec, rec, map50, mean_ap = 0, 0, 0, 0

-        # Back to float32 for further training
-        self.model.float()
+        # return model to CPU so next agg() stays device-consistent
+        self.model.to("cpu").float()

         return {"mAP": float(mean_ap), "mAP50": float(map50), "precision": float(prec), "recall": float(rec)}

     def select_clients(self, connection_ratio=1.0):
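Note: test() now ends by moving the model back to the CPU, matching the assumption agg() makes when it builds accumulators from self.model.state_dict(). A one-line check of that invariant (the `model` name is illustrative):

    assert all(not v.is_cuda for v in model.state_dict().values()), \
        "global model must be on CPU before aggregation"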
@@ -135,53 +135,75 @@ class FedYoloServer(object):
         self.n_data += self.client_n_data.get(client_id, 0)

     def agg(self):
-        """
-        Aggregate client updates (FedAvg).
-        Returns:
-            global_state: aggregated model state dictionary
-            avg_loss: dict of averaged losses
-            n_data: total number of data classes samples used in this round
-        """
+        """Aggregate client updates (FedAvg) on CPU/FP32, preserving non-float buffers."""
         if len(self.selected_clients) == 0 or self.n_data == 0:
             return self.model.state_dict(), {}, 0

-        # start from current global model
-        global_state = self.model.state_dict()
-        # zero buffer for accumulation
-        new_state = {k: torch.zeros_like(v, dtype=torch.float32) for k, v in global_state.items()}
+        # Ensure global model is on CPU for safe load later
+        self.model.to("cpu")
+        global_state = self.model.state_dict()  # tensors are on CPU after the move above

         avg_loss = {}
-        for name in self.selected_clients:
-            if name not in self.client_state:
+        total_n = float(self.n_data)
+
+        # Prepare accumulators on CPU. For floating tensors, use float32 zeros.
+        # For non-floating tensors (e.g., BN num_batches_tracked int64), we'll copy from the first client.
+        new_state = {}
+        first_client = None
+        for cid in self.selected_clients:
+            if cid in self.client_state:
+                first_client = cid
+                break
+
+        assert first_client is not None, "No client states available to aggregate."
+
+        for k, v in global_state.items():
+            if v.is_floating_point():
+                new_state[k] = torch.zeros_like(v.detach().cpu(), dtype=torch.float32)
+            else:
+                # For non-float buffers, just copy from the first client (or keep global)
+                new_state[k] = self.client_state[first_client][k].clone()
+
+        # Accumulate floating tensors with weights; keep non-floats as assigned above
+        for cid in self.selected_clients:
+            if cid not in self.client_state:
                 continue
-            weight = self.client_n_data[name] / self.n_data
+            weight = self.client_n_data[cid] / total_n
+            cst = self.client_state[cid]
             for k in new_state.keys():
-                # accumulate in float32 to avoid fp16 issues
-                new_state[k] += self.client_state[name][k].to(torch.float32) * weight
+                if new_state[k].is_floating_point():
+                    # cst[k] is on CPU; ensure float32 for accumulation
+                    new_state[k].add_(cst[k].to(torch.float32), alpha=weight)

-            # losses
-            for k, v in self.client_loss[name].items():
-                avg_loss[k] = avg_loss.get(k, 0.0) + v * weight
+            # weighted average losses
+            for lk, lv in self.client_loss[cid].items():
+                avg_loss[lk] = avg_loss.get(lk, 0.0) + float(lv) * weight

-        # load aggregated params back into global model
-        self.model.load_state_dict(new_state, strict=True)
+        # Load aggregated state back into the global model (model is on CPU)
+        with torch.no_grad():
+            self.model.load_state_dict(new_state, strict=True)

         self.round += 1
-        return self.model.state_dict(), avg_loss, self.n_data
+        # Return CPU state_dict (good for broadcasting to clients)
+        return {k: v.clone() for k, v in self.model.state_dict().items()}, avg_loss, int(self.n_data)

     def rec(self, name, state_dict, n_data, loss_dict):
         """
         Receive local update from a client.
-        Args:
-            name: client ID
-            state_dict: state dictionary of the local model
-            n_data: number of data samples used in local training
-            loss_dict: dict of losses from local training
+        - Store all floating tensors as CPU float32
+        - Store non-floating tensors (e.g., BN counters) as CPU in original dtype
         """
         self.n_data += n_data
-        self.client_state[name] = {k: v.cpu() for k, v in state_dict.items()}
-        self.client_n_data[name] = n_data
-        self.client_loss[name] = loss_dict
+        safe_state = {}
+        with torch.no_grad():
+            for k, v in state_dict.items():
+                t = v.detach().cpu()
+                if t.is_floating_point():
+                    t = t.to(torch.float32)
+                safe_state[k] = t
+        self.client_state[name] = safe_state
+        self.client_n_data[name] = int(n_data)
+        self.client_loss[name] = {k: float(v) for k, v in loss_dict.items()}

     def flush(self):
         """Clear stored client updates."""
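Note: the rewritten agg()/rec() pair implements standard FedAvg: the global weights become the sum over selected clients of (n_i / n_total) * w_i, computed on CPU in float32, while integer buffers (e.g., BatchNorm num_batches_tracked) are passed through unaveraged. A self-contained sketch of the same weighting scheme (names are illustrative, not this repo's API):

    import torch

    def fedavg(client_states, client_sizes):
        """Weighted-average float tensors; copy non-float buffers from the first client."""
        total = float(sum(client_sizes))
        first = client_states[0]
        out = {}
        for k, v in first.items():
            if v.is_floating_point():
                acc = torch.zeros_like(v, dtype=torch.float32)
                for state, n in zip(client_states, client_sizes):
                    acc.add_(state[k].to(torch.float32), alpha=n / total)
                out[k] = acc
            else:
                out[k] = v.clone()  # e.g., BN num_batches_tracked stays int64
        return out

    # two toy "clients" holding a single parameter tensor each
    a = {"w": torch.ones(2)}
    b = {"w": torch.zeros(2)}
    print(fedavg([a, b], [3, 1])["w"])  # tensor([0.7500, 0.7500])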
|