# fed-yolo/fed_algo_cs/client_base.py
import numpy as np
import torch
from torch import nn
from torch.utils import data
from torch.amp.autocast_mode import autocast
from utils.fed_util import init_model
from utils import util
from utils.dataset import Dataset
from typing import cast
from tqdm import tqdm
class FedYoloClient(object):
def __init__(self, name, model_name, params):
"""
Initialize the client k for federated learning
Args:
:param name: Name of the client k
:param model_name: Name of the model
:param params: config file including the hyperparameters for local training
- batch_size: Local training batch size in the client k
- num_workers: Number of data loader workers
- min_lr: Minimum learning rate
- max_lr: Maximum learning rate
- momentum: Momentum for local training
- weight_decay: Weight decay for local training
"""
self.params = params
# initialize the metadata in local client k
self.target_ip = "127.0.0.3"
self.port = 9999
self.name = name
# initialize the parameters in local client k
self._batch_size = self.params["local_batch_size"]
self._min_lr = self.params["min_lr"]
self._max_lr = self.params["max_lr"]
self._momentum = self.params["momentum"]
self.num_workers = self.params["num_workers"]
self.loss_record = []
# train set length
self.n_data = 0
# initialize the local training and testing dataset
self.train_dataset = None
self.val_dataset = None
# initialize the local model
self._num_classes = len(self.params["names"])
self._weight_decay = self.params["weight_decay"]
self.model_name = model_name
self.model = init_model(model_name, self._num_classes)
model_parameters = filter(lambda p: p.requires_grad, self.model.parameters())
self.parameter_number = sum([np.prod(p.size()) for p in model_parameters])
# GPU
self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
def load_trainset(self, train_dataset: list[str]):
"""
Load the local training dataset
Args:
:param train_dataset: Training dataset
"""
self.train_dataset = train_dataset
self.n_data = len(self.train_dataset)
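# Example usage (file paths are hypothetical):
# client.load_trainset(["data/client_0/images/0001.jpg", "data/client_0/images/0002.jpg"])
# client.n_data  # -> 2; the server can use this to weight the client's update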
def update(self, Global_model_state_dict):
"""
Update the local model with the global model parameters
Args:
:param Global_model_state_dict: State dictionary of the global model
"""
if not hasattr(self, "model") or self.model is None:
self.model = init_model(self.model_name, self._num_classes)
# load the global model parameters
self.model.load_state_dict(Global_model_state_dict, strict=True)
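# Example usage: push the current global weights before a round of local training.
# `global_model` is a hypothetical server-side nn.Module of the same architecture:
# client.update(global_model.state_dict())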
def train(self, args) -> tuple[dict[str, torch.Tensor], int, float]:
"""
Train the local model.
Returns: (state_dict, n_data, avg_loss_per_image)
"""
# ---- Dist init (if any) ----
if args.distributed:
torch.cuda.set_device(device=args.local_rank)
torch.distributed.init_process_group(backend="nccl", init_method="env://")
util.setup_seed()
util.setup_multi_processes()
# device = torch.device(f"cuda:{args.local_rank}" if torch.cuda.is_available() else "cpu")
# self.model.to(device)
self.model.cuda()
# show model architecture
# print(self.model)
# ---- Optimizer / WD scaling & LR warmup/schedule ----
# accumulate = effective grad-accumulation steps to emulate global batch 64
world_size = getattr(args, "world_size", 1)
accumulate = max(round(64 / (self._batch_size * max(world_size, 1))), 1)
# scale weight_decay like YOLO recipes
scaled_wd = self._weight_decay * self._batch_size * max(world_size, 1) * accumulate / 64
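# Worked example (illustrative numbers): with local_batch_size=16 and world_size=1,
#   accumulate = max(round(64 / (16 * 1)), 1) = 4   -> 4 micro-batches emulate a global batch of 64
#   scaled_wd  = weight_decay * 16 * 1 * 4 / 64     = weight_decay (unchanged)
# i.e. weight decay stays consistent with the nominal batch-64 recipe.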
optimizer = torch.optim.SGD(
util.set_params(self.model, scaled_wd),
lr=self._min_lr,
momentum=self._momentum,
nesterov=True,
)
# ---- EMA (track the underlying module if DDP) ----
# track_model = self.model.module if is_ddp else self.model
ema = util.EMA(self.model) if args.local_rank == 0 else None
print(type(self.train_dataset))
# ---- Data ----
dataset = Dataset(
filenames=self.train_dataset,
input_size=args.input_size,
params=self.params,
augment=True,
)
if args.distributed:
train_sampler = data.DistributedSampler(
dataset, num_replicas=args.world_size, rank=args.local_rank, shuffle=True
)
else:
train_sampler = None
loader = data.DataLoader(
dataset,
batch_size=self._batch_size,
shuffle=(train_sampler is None),
sampler=train_sampler,
num_workers=self.num_workers,
pin_memory=True,
collate_fn=Dataset.collate_fn,
drop_last=False,
)
num_steps = max(1, len(loader))
scheduler = util.LinearLR(args=args, params=self.params, num_steps=num_steps)
# ---- SyncBN + DDP (if any) ----
is_ddp = bool(args.distributed)
if is_ddp:
self.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.model)
self.model = nn.parallel.DistributedDataParallel(
module=self.model,
device_ids=[args.local_rank],
output_device=args.local_rank,
find_unused_parameters=False,
)
# ---- AMP + loss ----
scaler = torch.amp.grad_scaler.GradScaler(enabled=True)
# criterion = util.ComputeLoss(
# self.model.module if isinstance(self.model, nn.parallel.DistributedDataParallel) else self.model,
# self.params,
# )
criterion = util.ComputeLoss(self.model, self.params)
# ---- Training ----
for epoch in range(args.epochs):
# (self.model.module if isinstance(self.model, nn.parallel.DistributedDataParallel) else self.model).train()
self.model.train()
if is_ddp and train_sampler is not None:
train_sampler.set_epoch(epoch)
# disable mosaic in the last 10 epochs (if dataset supports it)
if args.epochs - epoch == 10 and hasattr(loader.dataset, "mosaic"):
ds = cast(Dataset, loader.dataset)
ds.mosaic = False
optimizer.zero_grad(set_to_none=True)
loss_box_meter = util.AverageMeter()
loss_cls_meter = util.AverageMeter()
loss_dfl_meter = util.AverageMeter()
for i, (images, targets) in enumerate(loader):
print(f"Client {self.name} - Epoch {epoch + 1}/{args.epochs} - Step {i + 1}/{num_steps}")
step = i + epoch * num_steps
# per-step LR schedule (util.LinearLR expects the global step index)
scheduler.step(step=step, optimizer=optimizer)
# images = images.to(device, non_blocking=True).float() / 255.0
images = images.cuda().float() / 255.0
bs = images.size(0)
# total_imgs_seen += bs
# targets: keep them in the format ComputeLoss expects (often CPU lists/tensors);
# move them to the GPU here only if the loss implementation requires it.
with autocast(device_type="cuda", enabled=True):
outputs = self.model(images) # DDP wraps forward
box_loss, cls_loss, dfl_loss = criterion(outputs, targets)
# total_loss = box_loss + cls_loss + dfl_loss
# Gradient accumulation: normalize by 'accumulate' so LR stays effective
# total_loss = total_loss / accumulate
# IMPORTANT: assume criterion returns the average loss per image in the batch.
# Keep logging on the true (unscaled) values:
loss_box_meter.update(box_loss.item(), bs)
loss_cls_meter.update(cls_loss.item(), bs)
loss_dfl_meter.update(dfl_loss.item(), bs)
# scale back to a per-batch, per-world total (the YOLO recipe convention);
# use the local `world_size` (defaults to 1) so this also works when args has no world_size
box_loss *= self._batch_size * world_size
cls_loss *= self._batch_size * world_size
dfl_loss *= self._batch_size * world_size
total_loss = box_loss + cls_loss + dfl_loss
scaler.scale(total_loss).backward()
# optimize
if step % accumulate == 0:
# scaler.unscale_(optimizer)
# util.clip_gradients(self.model)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad(set_to_none=True)
# # Step when we have 'accumulate' micro-batches, or at the end
# if ((i + 1) % accumulate == 0) or (i + 1 == len(loader)):
# scaler.unscale_(optimizer)
# util.clip_gradients(
# model=(
# self.model.module
# if isinstance(self.model, nn.parallel.DistributedDataParallel)
# else self.model
# ),
# max_norm=10.0,
# )
# scaler.step(optimizer)
# scaler.update()
# optimizer.zero_grad(set_to_none=True)
if ema:
# Update EMA from the underlying module
ema.update(
self.model.module
if isinstance(self.model, nn.parallel.DistributedDataParallel)
else self.model
)
# print loss to test
print(
f"loss: {total_loss.item() * accumulate:.4f}, box: {box_loss.item():.4f}, cls: {cls_loss.item():.4f}, dfl: {dfl_loss.item():.4f}"
)
torch.cuda.synchronize()
# ---- Final average loss (per image) over the whole epoch span ----
avg_loss_per_image = loss_box_meter.avg + loss_cls_meter.avg + loss_dfl_meter.avg
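# Assuming util.AverageMeter follows the usual pattern (update(val, n) keeps a running
# sum(val * n) / sum(n)), each meter.avg is the per-image mean of its loss term over the
# whole local run, so the sum of the three averages is the per-image total loss reported
# back to the server.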
# ---- Cleanup DDP ----
if is_ddp:
torch.distributed.destroy_process_group()
torch.cuda.empty_cache()
# ---- Choose which weights to return ----
# - If EMA exists, return EMA weights (common YOLO eval practice)
# - Be careful with DDP: grab state_dict from the underlying module / EMA model
if ema:
# print("Using EMA weights")
return (ema.ema.state_dict(), self.n_data, avg_loss_per_image)
else:
# Safely get the underlying module if wrapped by DDP; getattr returns the module or the original object.
model_obj = getattr(self.model, "module", self.model)
# If it's a proper nn.Module, call state_dict(); if it's already a state dict, use it;
# otherwise try to call state_dict() and finally fall back to wrapping the object.
if isinstance(model_obj, torch.nn.Module):
model_to_return = model_obj.state_dict()
elif isinstance(model_obj, dict):
model_to_return = model_obj
else:
try:
model_to_return = model_obj.state_dict()
except Exception:
# fallback: if model_obj is a tensor or unexpected object, wrap it in a dict
model_to_return = {"state": model_obj}
return model_to_return, self.n_data, avg_loss_per_image
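
# ---------------------------------------------------------------------------
# Server-side sketch (illustrative only, not part of this client's API).
# It shows how the tuples returned by FedYoloClient.train() could be combined
# with plain FedAvg, weighting each client's state_dict by its n_data.
# The helper name `fedavg_aggregate` and the choice of FedAvg are assumptions;
# the real aggregation logic lives in the server code of fed_algo_cs.
# ---------------------------------------------------------------------------
def fedavg_aggregate(
    client_results: list[tuple[dict[str, torch.Tensor], int, float]],
) -> dict[str, torch.Tensor]:
    """Weighted average of client state_dicts by number of local samples."""
    total = sum(n_data for _, n_data, _ in client_results)
    keys = client_results[0][0].keys()
    global_state: dict[str, torch.Tensor] = {}
    for k in keys:
        # cast to float so integer buffers (e.g. BatchNorm's num_batches_tracked) can be averaged
        global_state[k] = sum(
            sd[k].float() * (n_data / total) for sd, n_data, _ in client_results
        )
    return global_state

# Example round (hypothetical `clients` list and `args` namespace):
# results = [client.train(args) for client in clients]
# new_global_state = fedavg_aggregate(results)
# for client in clients:
#     client.update(new_global_state)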