Fedavg and YOLOv11 training
This commit is contained in:
233
fed_algo_cs/client_base.py
Normal file
233
fed_algo_cs/client_base.py
Normal file
@@ -0,0 +1,233 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.utils import data
|
||||
from torch.amp.autocast_mode import autocast
|
||||
from utils.fed_util import init_model
|
||||
from utils import util
|
||||
from utils.dataset import Dataset
|
||||
from typing import cast
|
||||
|
||||
|
||||
class FedYoloClient(object):
|
||||
def __init__(self, name, model_name, params):
|
||||
"""
|
||||
Initialize the client k for federated learning
|
||||
Args:
|
||||
:param name: Name of the client k
|
||||
:param model_name: Name of the model
|
||||
:param params: config file including the hyperparameters for local training
|
||||
- batch_size: Local training batch size in the client k
|
||||
- num_workers: Number of data loader workers
|
||||
|
||||
- min_lr: Minimum learning rate
|
||||
- max_lr: Maximum learning rate
|
||||
- momentum: Momentum for local training
|
||||
- weight_decay: Weight decay for local training
|
||||
"""
|
||||
self.params = params
|
||||
# initialize the metadata in local client k
|
||||
self.target_ip = "127.0.0.3"
|
||||
self.port = 9999
|
||||
self.name = name
|
||||
|
||||
# initialize the parameters in local client k
|
||||
self._batch_size = self.params["local_batch_size"]
|
||||
self._min_lr = self.params["min_lr"]
|
||||
self._max_lr = self.params["max_lr"]
|
||||
self._momentum = self.params["momentum"]
|
||||
self.num_workers = self.params["num_workers"]
|
||||
|
||||
self.loss_record = []
|
||||
# train set length
|
||||
self.n_data = 0
|
||||
|
||||
# initialize the local training and testing dataset
|
||||
self.train_dataset = None
|
||||
self.val_dataset = None
|
||||
|
||||
# initialize the local model
|
||||
self._num_classes = len(self.params["names"])
|
||||
self._weight_decay = self.params["weight_decay"]
|
||||
|
||||
self.model_name = model_name
|
||||
self.model = init_model(model_name, self._num_classes)
|
||||
|
||||
model_parameters = filter(lambda p: p.requires_grad, self.model.parameters())
|
||||
self.parameter_number = sum([np.prod(p.size()) for p in model_parameters])
|
||||
|
||||
# GPU
|
||||
self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
def load_trainset(self, train_dataset: list[str]):
|
||||
"""
|
||||
Load the local training dataset
|
||||
Args:
|
||||
:param train_dataset: Training dataset
|
||||
"""
|
||||
self.train_dataset = train_dataset
|
||||
self.n_data = len(self.train_dataset)
|
||||
|
||||
def update(self, Global_model_state_dict):
|
||||
"""
|
||||
Update the local model with the global model parameters
|
||||
Args:
|
||||
:param Global_model_state_dict: State dictionary of the global model
|
||||
"""
|
||||
|
||||
if not hasattr(self, "model") or self.model is None:
|
||||
self.model = init_model(self.model_name, self._num_classes)
|
||||
|
||||
# load the global model parameters
|
||||
self.model.load_state_dict(Global_model_state_dict, strict=True)
|
||||
|
||||
def train(self, args):
|
||||
"""
|
||||
Train the local model
|
||||
Args:
|
||||
:param args: Command line arguments
|
||||
- local_rank: Local rank for distributed training
|
||||
- world_size: World size for distributed training
|
||||
- distributed: Whether to use distributed training
|
||||
- input_size: Input size for the model
|
||||
Returns:
|
||||
:return: Local updated model, number of local data points, training loss
|
||||
"""
|
||||
|
||||
if args.distributed:
|
||||
torch.cuda.set_device(device=args.local_rank)
|
||||
torch.distributed.init_process_group(backend="nccl", init_method="env://")
|
||||
# print(f"Client {self.name} - distributed training on {world_size} GPUs, local rank: {local_rank}")
|
||||
# self._device = torch.device("cuda", local_rank)
|
||||
|
||||
if args.local_rank == 0:
|
||||
pass
|
||||
# if not os.path.exists("weights"):
|
||||
# os.makedirs("weights")
|
||||
|
||||
util.setup_seed()
|
||||
util.setup_multi_processes()
|
||||
|
||||
# model
|
||||
# init model have been done in __init__()
|
||||
self.model.to(self._device)
|
||||
|
||||
# Optimizer
|
||||
accumulate = max(round(64 / (self._batch_size * args.world_size)), 1)
|
||||
self._weight_decay = self._batch_size * args.world_size * accumulate / 64
|
||||
|
||||
optimizer = torch.optim.SGD(
|
||||
util.set_params(self.model, self._weight_decay),
|
||||
lr=self._min_lr,
|
||||
momentum=self._momentum,
|
||||
nesterov=True,
|
||||
)
|
||||
|
||||
# EMA
|
||||
ema = util.EMA(self.model) if args.local_rank == 0 else None
|
||||
|
||||
data_set = Dataset(
|
||||
filenames=self.train_dataset,
|
||||
input_size=args.input_size,
|
||||
params=self.params,
|
||||
augment=True,
|
||||
)
|
||||
|
||||
if args.distributed:
|
||||
train_sampler = data.DistributedSampler(
|
||||
data_set, num_replicas=args.world_size, rank=args.local_rank, shuffle=True
|
||||
)
|
||||
else:
|
||||
train_sampler = None
|
||||
|
||||
loader = data.DataLoader(
|
||||
data_set,
|
||||
batch_size=self._batch_size,
|
||||
shuffle=train_sampler is None,
|
||||
sampler=train_sampler,
|
||||
num_workers=self.num_workers,
|
||||
pin_memory=True,
|
||||
collate_fn=Dataset.collate_fn,
|
||||
)
|
||||
|
||||
# Scheduler
|
||||
num_steps = max(1, len(loader))
|
||||
# print(len(loader))
|
||||
scheduler = util.LinearLR(args=args, params=self.params, num_steps=num_steps)
|
||||
# DDP mode
|
||||
if args.distributed:
|
||||
self.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.model)
|
||||
self.model = nn.parallel.DistributedDataParallel(
|
||||
module=self.model,
|
||||
device_ids=[args.local_rank],
|
||||
output_device=args.local_rank,
|
||||
find_unused_parameters=False,
|
||||
)
|
||||
|
||||
amp_scale = torch.amp.grad_scaler.GradScaler(enabled=True)
|
||||
criterion = util.ComputeLoss(self.model, self.params)
|
||||
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
|
||||
for epoch in range(args.epochs):
|
||||
self.model.train()
|
||||
# when distributed, set epoch for shuffling
|
||||
if args.distributed and train_sampler is not None:
|
||||
train_sampler.set_epoch(epoch)
|
||||
|
||||
if args.epochs - epoch == 10:
|
||||
# disable mosaic augmentation in the last 10 epochs
|
||||
ds = cast(Dataset, loader.dataset)
|
||||
ds.mosaic = False
|
||||
|
||||
avg_box_loss = util.AverageMeter()
|
||||
avg_cls_loss = util.AverageMeter()
|
||||
avg_dfl_loss = util.AverageMeter()
|
||||
|
||||
for i, (samples, targets) in enumerate(loader):
|
||||
global_step = i + num_steps * epoch
|
||||
scheduler.step(step=global_step, optimizer=optimizer)
|
||||
|
||||
samples = samples.cuda(non_blocking=True).float() / 255.0
|
||||
|
||||
# Forward
|
||||
with autocast("cuda", enabled=True):
|
||||
outputs = self.model(samples)
|
||||
box_loss, cls_loss, dfl_loss = criterion(outputs, targets)
|
||||
|
||||
# meters (use the *unscaled* values)
|
||||
bs = samples.size(0)
|
||||
avg_box_loss.update(box_loss.item(), bs)
|
||||
avg_cls_loss.update(cls_loss.item(), bs)
|
||||
avg_dfl_loss.update(dfl_loss.item(), bs)
|
||||
|
||||
# scale losses by batch/world if your loss is averaged internally per-sample/device
|
||||
box_loss = box_loss * self._batch_size * args.world_size
|
||||
cls_loss = cls_loss * self._batch_size * args.world_size
|
||||
dfl_loss = dfl_loss * self._batch_size * args.world_size
|
||||
|
||||
total_loss = box_loss + cls_loss + dfl_loss
|
||||
|
||||
# Backward
|
||||
amp_scale.scale(total_loss).backward()
|
||||
|
||||
# Optimize
|
||||
if (i + 1) % accumulate == 0:
|
||||
amp_scale.step(optimizer)
|
||||
amp_scale.update()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
if ema:
|
||||
ema.update(self.model)
|
||||
|
||||
# torch.cuda.synchronize()
|
||||
|
||||
# clean
|
||||
if args.distributed:
|
||||
torch.distributed.destroy_process_group()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return (
|
||||
self.model.state_dict(),
|
||||
self.n_data,
|
||||
{"box_loss": avg_box_loss.avg, "cls_loss": avg_cls_loss.avg, "dfl_loss": avg_dfl_loss.avg},
|
||||
)
|
178
fed_algo_cs/server_base.py
Normal file
178
fed_algo_cs/server_base.py
Normal file
@@ -0,0 +1,178 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import DataLoader
|
||||
from utils.fed_util import init_model
|
||||
from utils.dataset import Dataset
|
||||
from utils import util
|
||||
|
||||
|
||||
class FedYoloServer(object):
|
||||
def __init__(self, client_list, model_name, params):
|
||||
"""
|
||||
Federated YOLO Server
|
||||
Args:
|
||||
client_list: list of connected clients
|
||||
model_name: YOLO model architecture name
|
||||
params: dict of hyperparameters (must include 'names')
|
||||
"""
|
||||
# Track client updates
|
||||
self.client_state = {}
|
||||
self.client_loss = {}
|
||||
self.client_n_data = {}
|
||||
self.selected_clients = []
|
||||
|
||||
self._batch_size = params.get("val_batch_size", 4)
|
||||
self.client_list = client_list
|
||||
self.valset = None
|
||||
|
||||
# Federated bookkeeping
|
||||
self.round = 0
|
||||
# Total number of classes
|
||||
self.n_data = 0
|
||||
|
||||
# Device
|
||||
gpu = 0
|
||||
self._device = torch.device(f"cuda:{gpu}" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
# Global model
|
||||
self._num_classes = len(params["names"])
|
||||
self.model_name = model_name
|
||||
self.model = init_model(model_name, self._num_classes)
|
||||
self.params = params
|
||||
|
||||
def load_valset(self, valset):
|
||||
"""Server loads the validation dataset."""
|
||||
self.valset = valset
|
||||
|
||||
def state_dict(self):
|
||||
"""Return global model weights."""
|
||||
return self.model.state_dict()
|
||||
|
||||
@torch.no_grad()
|
||||
def test(self, args):
|
||||
"""
|
||||
Evaluate global model on validation set using YOLO metrics (mAP, precision, recall).
|
||||
Returns:
|
||||
dict with {"mAP": ..., "mAP50": ..., "precision": ..., "recall": ...}
|
||||
"""
|
||||
if self.valset is None:
|
||||
return {}
|
||||
|
||||
loader = DataLoader(
|
||||
self.valset,
|
||||
batch_size=self._batch_size,
|
||||
shuffle=False,
|
||||
num_workers=4,
|
||||
pin_memory=True,
|
||||
collate_fn=Dataset.collate_fn,
|
||||
)
|
||||
|
||||
self.model.to(self._device).eval().half()
|
||||
|
||||
iou_v = torch.linspace(0.5, 0.95, 10).to(self._device) # IoU thresholds
|
||||
n_iou = iou_v.numel()
|
||||
metrics = []
|
||||
|
||||
for samples, targets in loader:
|
||||
samples = samples.to(self._device).half() / 255.0
|
||||
_, _, h, w = samples.shape
|
||||
scale = torch.tensor((w, h, w, h)).to(self._device)
|
||||
|
||||
outputs = self.model(samples)
|
||||
outputs = util.non_max_suppression(outputs)
|
||||
|
||||
for i, output in enumerate(outputs):
|
||||
idx = targets["idx"] == i
|
||||
cls = targets["cls"][idx].to(self._device)
|
||||
box = targets["box"][idx].to(self._device)
|
||||
|
||||
metric = torch.zeros((output.shape[0], n_iou), dtype=torch.bool, device=self._device)
|
||||
|
||||
if output.shape[0] == 0:
|
||||
if cls.shape[0]:
|
||||
metrics.append((metric, *torch.zeros((2, 0), device=self._device), cls.squeeze(-1)))
|
||||
continue
|
||||
|
||||
if cls.shape[0]:
|
||||
cls_tensor = cls if isinstance(cls, torch.Tensor) else torch.tensor(cls, device=self._device)
|
||||
if cls_tensor.dim() == 1:
|
||||
cls_tensor = cls_tensor.unsqueeze(1)
|
||||
box_xy = util.wh2xy(box)
|
||||
if not isinstance(box_xy, torch.Tensor):
|
||||
box_xy = torch.tensor(box_xy, device=self._device)
|
||||
target = torch.cat((cls_tensor, box_xy * scale), dim=1)
|
||||
metric = util.compute_metric(output[:, :6], target, iou_v)
|
||||
|
||||
metrics.append((metric, output[:, 4], output[:, 5], cls.squeeze(-1)))
|
||||
|
||||
# Compute metrics
|
||||
if not metrics:
|
||||
return {"mAP": 0, "mAP50": 0, "precision": 0, "recall": 0}
|
||||
|
||||
metrics = [torch.cat(x, dim=0).cpu().numpy() for x in zip(*metrics)]
|
||||
if len(metrics) and metrics[0].any():
|
||||
_, _, prec, rec, map50, mean_ap = util.compute_ap(*metrics, names=self.params["names"], plot=False)
|
||||
else:
|
||||
prec, rec, map50, mean_ap = 0, 0, 0, 0
|
||||
|
||||
# Back to float32 for further training
|
||||
self.model.float()
|
||||
|
||||
return {"mAP": float(mean_ap), "mAP50": float(map50), "precision": float(prec), "recall": float(rec)}
|
||||
|
||||
def select_clients(self, connection_ratio=1.0):
|
||||
"""Randomly select a fraction of clients."""
|
||||
self.selected_clients = []
|
||||
self.n_data = 0
|
||||
for client_id in self.client_list:
|
||||
if np.random.rand() <= connection_ratio:
|
||||
self.selected_clients.append(client_id)
|
||||
self.n_data += self.client_n_data.get(client_id, 0)
|
||||
|
||||
def agg(self):
|
||||
"""Aggregate client updates (FedAvg)."""
|
||||
if len(self.selected_clients) == 0 or self.n_data == 0:
|
||||
return self.model.state_dict(), {}, 0
|
||||
|
||||
model = init_model(self.model_name, self._num_classes)
|
||||
model_state = model.state_dict()
|
||||
|
||||
avg_loss = {}
|
||||
for i, name in enumerate(self.selected_clients):
|
||||
if name not in self.client_state:
|
||||
continue
|
||||
weight = self.client_n_data[name] / self.n_data
|
||||
for key in model_state.keys():
|
||||
if i == 0:
|
||||
model_state[key] = self.client_state[name][key] * weight
|
||||
else:
|
||||
model_state[key] += self.client_state[name][key] * weight
|
||||
|
||||
# Weighted average losses
|
||||
for k, v in self.client_loss[name].items():
|
||||
avg_loss[k] = avg_loss.get(k, 0.0) + v * weight
|
||||
|
||||
self.model.load_state_dict(model_state, strict=True)
|
||||
self.round += 1
|
||||
return model_state, avg_loss, self.n_data
|
||||
|
||||
def rec(self, name, state_dict, n_data, loss_dict):
|
||||
"""
|
||||
Receive local update from a client.
|
||||
Args:
|
||||
name: client ID
|
||||
state_dict: state dictionary of the local model
|
||||
n_data: number of data samples used in local training
|
||||
loss_dict: dict of losses from local training
|
||||
"""
|
||||
self.n_data += n_data
|
||||
self.client_state[name] = {k: v.cpu() for k, v in state_dict.items()}
|
||||
self.client_n_data[name] = n_data
|
||||
self.client_loss[name] = loss_dict
|
||||
|
||||
def flush(self):
|
||||
"""Clear stored client updates."""
|
||||
self.n_data = 0
|
||||
self.client_state.clear()
|
||||
self.client_n_data.clear()
|
||||
self.client_loss.clear()
|
Reference in New Issue
Block a user