import numpy as np
import torch
from torch import nn
from torch.utils import data
from torch.amp.autocast_mode import autocast
from typing import cast

from utils.fed_util import init_model
from utils import util
from utils.dataset import Dataset


class FedYoloClient(object):
    def __init__(self, name, model_name, params):
        """
        Initialize client k for federated learning.

        Args:
            :param name: Name of the client k
            :param model_name: Name of the model
            :param params: Config dict with the hyperparameters for local training
                - local_batch_size: Local training batch size on client k
                - num_workers: Number of data-loader workers
                - min_lr: Minimum learning rate
                - max_lr: Maximum learning rate
                - momentum: Momentum for local training
                - weight_decay: Weight decay for local training
                - names: Class names (used to derive the number of classes)
        """
        self.params = params

        # metadata of local client k
        self.target_ip = "127.0.0.3"
        self.port = 9999
        self.name = name

        # local training hyperparameters
        self._batch_size = self.params["local_batch_size"]
        self._min_lr = self.params["min_lr"]
        self._max_lr = self.params["max_lr"]
        self._momentum = self.params["momentum"]
        self.num_workers = self.params["num_workers"]
        self.loss_record = []

        # train-set length
        self.n_data = 0

        # local training and validation datasets
        self.train_dataset = None
        self.val_dataset = None

        # local model
        self._num_classes = len(self.params["names"])
        self._weight_decay = self.params["weight_decay"]
        self.model_name = model_name
        self.model = init_model(model_name, self._num_classes)
        model_parameters = filter(lambda p: p.requires_grad, self.model.parameters())
        self.parameter_number = sum(np.prod(p.size()) for p in model_parameters)

        # GPU
        self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    def load_trainset(self, train_dataset: list[str]):
        """
        Load the local training dataset.

        Args:
            :param train_dataset: List of training image filenames
        """
        self.train_dataset = train_dataset
        self.n_data = len(self.train_dataset)

    def update(self, global_model_state_dict):
        """
        Update the local model with the global model parameters.

        Args:
            :param global_model_state_dict: State dictionary of the global model
        """
        if not hasattr(self, "model") or self.model is None:
            self.model = init_model(self.model_name, self._num_classes)
        # load the global model parameters
        self.model.load_state_dict(global_model_state_dict, strict=True)
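
    # ------------------------------------------------------------------
    # Round-trip illustration (assumed server-side driver, not part of
    # this class): a FedAvg-style server would push the global weights,
    # let the client train on its local shard, and collect the returned
    # tuple for aggregation:
    #
    #   client.update(global_state_dict)
    #   client.load_trainset(local_image_files)
    #   state_dict, n_k, loss_k = client.train(args)
    #
    # A sketch of the aggregation itself is at the bottom of this file.
    # ------------------------------------------------------------------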

    def train(self, args) -> tuple[dict[str, torch.Tensor], int, float]:
        """
        Train the local model.

        Returns:
            (state_dict, n_data, avg_loss_per_image)
        """
        # ---- Distributed init (if any) ----
        if args.distributed:
            torch.cuda.set_device(device=args.local_rank)
            torch.distributed.init_process_group(backend="nccl", init_method="env://")
        util.setup_seed()
        util.setup_multi_processes()

        self.model.cuda()

        # ---- Optimizer / WD scaling & LR warmup/schedule ----
        # accumulate = grad-accumulation steps needed to emulate a global batch of 64
        world_size = getattr(args, "world_size", 1)
        accumulate = max(round(64 / (self._batch_size * max(world_size, 1))), 1)
        # scale weight decay with the effective batch size, as in YOLO recipes
        scaled_wd = self._weight_decay * self._batch_size * max(world_size, 1) * accumulate / 64
        optimizer = torch.optim.SGD(
            util.set_params(self.model, scaled_wd),
            lr=self._min_lr,
            momentum=self._momentum,
            nesterov=True,
        )

        # ---- EMA (only on rank 0) ----
        ema = util.EMA(self.model) if args.local_rank == 0 else None

        # ---- Data ----
        dataset = Dataset(
            filenames=self.train_dataset,
            input_size=args.input_size,
            params=self.params,
            augment=True,
        )
        if args.distributed:
            train_sampler = data.DistributedSampler(
                dataset, num_replicas=args.world_size, rank=args.local_rank, shuffle=True
            )
        else:
            train_sampler = None
        loader = data.DataLoader(
            dataset,
            batch_size=self._batch_size,
            shuffle=(train_sampler is None),
            sampler=train_sampler,
            num_workers=self.num_workers,
            pin_memory=True,
            collate_fn=Dataset.collate_fn,
            drop_last=False,
        )

        num_steps = max(1, len(loader))
        scheduler = util.LinearLR(args=args, params=self.params, num_steps=num_steps)

        # ---- SyncBN + DDP (if any) ----
        is_ddp = bool(args.distributed)
        if is_ddp:
            self.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.model)
            self.model = nn.parallel.DistributedDataParallel(
                module=self.model,
                device_ids=[args.local_rank],
                output_device=args.local_rank,
                find_unused_parameters=False,
            )

        # ---- AMP + loss ----
        scaler = torch.amp.grad_scaler.GradScaler(enabled=True)
        # ComputeLoss needs the underlying module, not the DDP wrapper
        criterion = util.ComputeLoss(
            self.model.module if isinstance(self.model, nn.parallel.DistributedDataParallel) else self.model,
            self.params,
        )
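
        # Worked example of the effective-batch arithmetic above (the numbers
        # are illustrative assumptions, not values from the config): with
        # local_batch_size=16 and world_size=1,
        #   accumulate = round(64 / 16) = 4
        # so the optimizer steps once every 4 micro-batches (effective batch
        # 64), and
        #   scaled_wd = weight_decay * 16 * 1 * 4 / 64 = weight_decay
        # i.e. weight decay stays tuned for the nominal batch of 64 no matter
        # how small the local batch is.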

        # ---- Training ----
        for epoch in range(args.epochs):
            self.model.train()
            if is_ddp and train_sampler is not None:
                train_sampler.set_epoch(epoch)

            # disable mosaic augmentation for the last 10 epochs (if the dataset supports it)
            if args.epochs - epoch == 10 and hasattr(loader.dataset, "mosaic"):
                ds = cast(Dataset, loader.dataset)
                ds.mosaic = False

            optimizer.zero_grad(set_to_none=True)

            loss_box_meter = util.AverageMeter()
            loss_cls_meter = util.AverageMeter()
            loss_dfl_meter = util.AverageMeter()

            for i, (images, targets) in enumerate(loader):
                step = i + epoch * num_steps
                # per-step LR schedule (util.LinearLR expects the global step)
                scheduler.step(step=step, optimizer=optimizer)

                images = images.cuda().float() / 255.0
                bs = images.size(0)
                # targets stay in whatever form ComputeLoss expects (often CPU
                # tensors); move them to the GPU here only if the loss requires it.

                with autocast(device_type="cuda", enabled=True):
                    outputs = self.model(images)  # DDP wraps forward
                    box_loss, cls_loss, dfl_loss = criterion(outputs, targets)

                # criterion returns the average loss per image in the batch,
                # so log the true (unscaled) values:
                loss_box_meter.update(box_loss.item(), bs)
                loss_cls_meter.update(cls_loss.item(), bs)
                loss_dfl_meter.update(dfl_loss.item(), bs)

                # rescale for backward: per-image losses -> per-batch, then by
                # world size, matching the YOLO multi-GPU recipe
                box_loss *= self._batch_size * world_size
                cls_loss *= self._batch_size * world_size
                dfl_loss *= self._batch_size * world_size
                total_loss = box_loss + cls_loss + dfl_loss

                scaler.scale(total_loss).backward()

                # step every `accumulate` micro-batches, and flush at the end
                # of the epoch so no gradients are left unapplied
                if (i + 1) % accumulate == 0 or (i + 1) == len(loader):
                    # optional gradient clipping:
                    # scaler.unscale_(optimizer)
                    # util.clip_gradients(self.model)
                    scaler.step(optimizer)
                    scaler.update()
                    optimizer.zero_grad(set_to_none=True)

                if ema:
                    # update EMA from the underlying module
                    ema.update(
                        self.model.module
                        if isinstance(self.model, nn.parallel.DistributedDataParallel)
                        else self.model
                    )

            torch.cuda.synchronize()

        # ---- Final average loss (per image) over the last local epoch ----
        avg_loss_per_image = loss_box_meter.avg + loss_cls_meter.avg + loss_dfl_meter.avg

        # ---- Cleanup DDP ----
        if is_ddp:
            torch.distributed.destroy_process_group()
        torch.cuda.empty_cache()

        # ---- Choose which weights to return ----
        # If EMA exists, return the EMA weights (common YOLO eval practice).
        # With DDP, take the state dict from the underlying module.
        if ema:
            return ema.ema.state_dict(), self.n_data, avg_loss_per_image
        model_obj = getattr(self.model, "module", self.model)
        return model_obj.state_dict(), self.n_data, avg_loss_per_image
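

# ----------------------------------------------------------------------
# Server-side FedAvg sketch (illustration only; `fedavg_aggregate` is a
# hypothetical helper, not part of the original client API). It shows how
# the (state_dict, n_data, avg_loss) tuples returned by
# FedYoloClient.train() could be combined: every floating-point parameter
# is averaged with weight n_k / sum_j n_j.
def fedavg_aggregate(
    results: list[tuple[dict[str, torch.Tensor], int, float]],
) -> dict[str, torch.Tensor]:
    total = sum(n for _, n, _ in results)
    aggregated: dict[str, torch.Tensor] = {}
    for key, value in results[0][0].items():
        if not torch.is_floating_point(value):
            # integer buffers (e.g. BatchNorm's num_batches_tracked) are not
            # averaged; keep the first client's copy
            aggregated[key] = value.clone()
            continue
        acc = torch.zeros_like(value, dtype=torch.float32)
        for state, n, _ in results:
            # accumulate in float32 so fp16 EMA weights do not lose precision
            acc += state[key].float() * (n / total)
        aggregated[key] = acc.to(value.dtype)
    return aggregated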