Optimize the FedYoloClient and FedYoloServer classes
@@ -3,11 +3,11 @@ import torch
from torch import nn
from torch.utils import data
from torch.amp.autocast_mode import autocast
from tqdm import tqdm
from utils.fed_util import init_model
from utils import util
from utils.dataset import Dataset
from typing import cast
from tqdm import tqdm


class FedYoloClient(object):
@@ -82,52 +82,48 @@ class FedYoloClient(object):
        # load the global model parameters
        self.model.load_state_dict(Global_model_state_dict, strict=True)

    def train(self, args):
    def train(self, args) -> tuple[dict[str, torch.Tensor], int, float]:
        """
        Train the local model
        Args:
        :param args: Command line arguments
            - local_rank: Local rank for distributed training
            - world_size: World size for distributed training
            - distributed: Whether to use distributed training
            - input_size: Input size for the model
        Returns:
        :return: Local updated model, number of local data points, training loss
        Train the local model.
        Returns: (state_dict, n_data, avg_loss_per_image)
        """

        # ---- Dist init (if any) ----
        if args.distributed:
            torch.cuda.set_device(device=args.local_rank)
            torch.distributed.init_process_group(backend="nccl", init_method="env://")
            # print(f"Client {self.name} - distributed training on {world_size} GPUs, local rank: {local_rank}")
            # self._device = torch.device("cuda", local_rank)

        if args.local_rank == 0:
            pass
            # if not os.path.exists("weights"):
            # os.makedirs("weights")

        util.setup_seed()
        util.setup_multi_processes()

        # model
        # model init has already been done in __init__()
        self.model.to(self._device)
        # device = torch.device(f"cuda:{args.local_rank}" if torch.cuda.is_available() else "cpu")
        # self.model.to(device)
        self.model.cuda()
        # show model architecture
        # print(self.model)

        # Optimizer
        accumulate = max(round(64 / (self._batch_size * args.world_size)), 1)
        self._weight_decay = self._batch_size * args.world_size * accumulate / 64
        # ---- Optimizer / WD scaling & LR warmup/schedule ----
        # accumulate = effective grad-accumulation steps to emulate global batch 64
        world_size = getattr(args, "world_size", 1)
        accumulate = max(round(64 / (self._batch_size * max(world_size, 1))), 1)

        # scale weight_decay like YOLO recipes
        scaled_wd = self._weight_decay * self._batch_size * max(world_size, 1) * accumulate / 64
        optimizer = torch.optim.SGD(
            util.set_params(self.model, self._weight_decay),
            util.set_params(self.model, scaled_wd),
            lr=self._min_lr,
            momentum=self._momentum,
            nesterov=True,
        )

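        # Illustrative note (not part of the original diff): with the assumed values
        # batch_size=16 and world_size=2, accumulate = max(round(64 / 32), 1) = 2, so the
        # optimizer steps once every 2 micro-batches and the scaled weight decay becomes
        # weight_decay * 16 * 2 * 2 / 64 = weight_decay * 1.0, matching a single-GPU batch-64 run.
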
        # EMA
        # ---- EMA (track the underlying module if DDP) ----
        # track_model = self.model.module if is_ddp else self.model
        ema = util.EMA(self.model) if args.local_rank == 0 else None

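        # Illustrative sketch (an assumption, not the actual util.EMA implementation): a typical
        # YOLO-style EMA keeps a shadow copy of the weights and, after every optimizer step, applies
        #     ema_param = decay * ema_param + (1 - decay) * model_param
        # with a decay that warms up toward roughly 0.9999 as the number of updates grows.
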
        data_set = Dataset(
        print(type(self.train_dataset))

        # ---- Data ----
        dataset = Dataset(
            filenames=self.train_dataset,
            input_size=args.input_size,
            params=self.params,
@@ -136,26 +132,28 @@ class FedYoloClient(object):

        if args.distributed:
            train_sampler = data.DistributedSampler(
                data_set, num_replicas=args.world_size, rank=args.local_rank, shuffle=True
                dataset, num_replicas=args.world_size, rank=args.local_rank, shuffle=True
            )
        else:
            train_sampler = None

        loader = data.DataLoader(
            data_set,
            dataset,
            batch_size=self._batch_size,
            shuffle=train_sampler is None,
            shuffle=(train_sampler is None),
            sampler=train_sampler,
            num_workers=self.num_workers,
            pin_memory=True,
            collate_fn=Dataset.collate_fn,
            drop_last=False,
        )

        # Scheduler
        num_steps = max(1, len(loader))
        scheduler = util.LinearLR(args=args, params=self.params, num_steps=num_steps)
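        # Illustrative note (an assumption about util.LinearLR, whose internals are not shown here):
        # a typical per-step linear schedule warms the LR up from a small value over the first steps
        # and then decays it linearly over epochs * num_steps total steps, which is why
        # scheduler.step() is called with the global step inside the batch loop below.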
        # DDP mode
        if args.distributed:

        # ---- SyncBN + DDP (if any) ----
        is_ddp = bool(args.distributed)
        if is_ddp:
            self.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.model)
            self.model = nn.parallel.DistributedDataParallel(
                module=self.model,
@@ -164,102 +162,133 @@ class FedYoloClient(object):
                find_unused_parameters=False,
            )

        amp_scale = torch.amp.grad_scaler.GradScaler(enabled=True)
        # ---- AMP + loss ----
        scaler = torch.amp.grad_scaler.GradScaler(enabled=True)
        # criterion = util.ComputeLoss(
        #     self.model.module if isinstance(self.model, nn.parallel.DistributedDataParallel) else self.model,
        #     self.params,
        # )
        criterion = util.ComputeLoss(self.model, self.params)

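        # Illustrative note (not part of the original diff): once the model is wrapped in
        # DistributedDataParallel, the raw YOLO module lives at self.model.module, so helpers that
        # need model attributes (loss construction, EMA, state_dict) are usually handed
        # `self.model.module if is_ddp else self.model`, as the commented-out lines above sketch.
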
        # log
        # if args.local_rank == 0:
        # header = ("%10s" * 5) % ("client", "memory", "box", "cls", "dfl")
        # print("\n" + header)
        # p_bar = tqdm(total=args.epochs * num_steps, ncols=120)
        # p_bar.set_description(f"{self.name:>10}")

        # ---- Training ----
        for epoch in range(args.epochs):
            # (self.model.module if isinstance(self.model, nn.parallel.DistributedDataParallel) else self.model).train()
            self.model.train()
            # when distributed, set epoch for shuffling
            if args.distributed and train_sampler is not None:
            if is_ddp and train_sampler is not None:
                train_sampler.set_epoch(epoch)

            if args.epochs - epoch == 10:
                # disable mosaic augmentation in the last 10 epochs
            # disable mosaic in the last 10 epochs (if dataset supports it)
            if args.epochs - epoch == 10 and hasattr(loader.dataset, "mosaic"):
                ds = cast(Dataset, loader.dataset)
                ds.mosaic = False

            optimizer.zero_grad(set_to_none=True)
            avg_box_loss = util.AverageMeter()
            avg_cls_loss = util.AverageMeter()
            avg_dfl_loss = util.AverageMeter()
            loss_box_meter = util.AverageMeter()
            loss_cls_meter = util.AverageMeter()
            loss_dfl_meter = util.AverageMeter()

            # # --- header (once per epoch, YOLO-style) ---
            # if args.local_rank == 0:
            # header = ("%10s" * 5) % ("client", "memory", "box", "cls", "dfl")
            # print("\n" + header)
            for i, (images, targets) in enumerate(loader):
                print(f"Client {self.name} - Epoch {epoch + 1}/{args.epochs} - Step {i + 1}/{num_steps}")
                step = i + epoch * num_steps

                # p_bar = enumerate(loader)
                # if args.local_rank == 0:
                # p_bar = tqdm(p_bar, total=num_steps, ncols=120)
                # scheduler per-step (your util.LinearLR expects step)
                scheduler.step(step=step, optimizer=optimizer)

            for i, (samples, targets) in enumerate(loader):
                global_step = i + num_steps * epoch
                scheduler.step(step=global_step, optimizer=optimizer)
                # images = images.to(device, non_blocking=True).float() / 255.0
                images = images.cuda().float() / 255.0
                bs = images.size(0)
                # total_imgs_seen += bs

                samples = samples.cuda(non_blocking=True).float() / 255.0
                # targets: keep as your ComputeLoss expects (often CPU lists/tensors).
                # Move to GPU here only if your loss requires it.

                # Forward
                with autocast("cuda", enabled=True):
                    outputs = self.model(samples)
                with autocast(device_type="cuda", enabled=True):
                    outputs = self.model(images)  # DDP wraps forward
                    box_loss, cls_loss, dfl_loss = criterion(outputs, targets)

                # meters (use the *unscaled* values)
                bs = samples.size(0)
                avg_box_loss.update(box_loss.item(), bs)
                avg_cls_loss.update(cls_loss.item(), bs)
                avg_dfl_loss.update(dfl_loss.item(), bs)
                # total_loss = box_loss + cls_loss + dfl_loss
                # Gradient accumulation: normalize by 'accumulate' so LR stays effective
                # total_loss = total_loss / accumulate

                # scale losses by batch/world if your loss is averaged internally per-sample/device
                # box_loss = box_loss * self._batch_size * args.world_size
                # cls_loss = cls_loss * self._batch_size * args.world_size
                # dfl_loss = dfl_loss * self._batch_size * args.world_size
                # IMPORTANT: assume criterion returns **average per image** in the batch.
                # Keep logging on the true (unscaled) values:
                loss_box_meter.update(box_loss.item(), bs)
                loss_cls_meter.update(cls_loss.item(), bs)
                loss_dfl_meter.update(dfl_loss.item(), bs)

                box_loss *= self._batch_size
                cls_loss *= self._batch_size
                dfl_loss *= self._batch_size
                box_loss *= args.world_size
                cls_loss *= args.world_size
                dfl_loss *= args.world_size
                total_loss = box_loss + cls_loss + dfl_loss

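                # Illustrative note (not part of the original diff): if criterion really does return
                # a per-image average, multiplying by batch_size recovers a per-batch sum, and
                # multiplying by world_size offsets DDP's averaging of gradients across ranks, so the
                # effective gradient matches a single-process run on the global batch.
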
                # Backward
                amp_scale.scale(total_loss).backward()
                scaler.scale(total_loss).backward()

                # Optimize
                if (i + 1) % accumulate == 0:
                    amp_scale.unscale_(optimizer)  # unscale gradients
                    util.clip_gradients(model=self.model, max_norm=10.0)  # clip gradients
                    amp_scale.step(optimizer)
                    amp_scale.update()
                # optimize
                if step % accumulate == 0:
                    # scaler.unscale_(optimizer)
                    # util.clip_gradients(self.model)
                    scaler.step(optimizer)
                    scaler.update()
                    optimizer.zero_grad(set_to_none=True)

                # # Step when we have 'accumulate' micro-batches, or at the end
                # if ((i + 1) % accumulate == 0) or (i + 1 == len(loader)):
                #     scaler.unscale_(optimizer)
                #     util.clip_gradients(
                #         model=(
                #             self.model.module
                #             if isinstance(self.model, nn.parallel.DistributedDataParallel)
                #             else self.model
                #         ),
                #         max_norm=10.0,
                #     )
                #     scaler.step(optimizer)
                #     scaler.update()
                #     optimizer.zero_grad(set_to_none=True)

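                # Illustrative note (not part of the original diff): the usual AMP-with-accumulation
                # order is scale(loss).backward() every micro-batch, then every `accumulate` steps
                # unscale_() -> clip gradients -> scaler.step() -> scaler.update() -> zero_grad().
                # Also note that `step % accumulate == 0` fires on the very first batch (step 0),
                # unlike the `(i + 1) % accumulate == 0` form it replaces.
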
                if ema:
                    ema.update(self.model)
                    # Update EMA from the underlying module
                    ema.update(
                        self.model.module
                        if isinstance(self.model, nn.parallel.DistributedDataParallel)
                        else self.model
                    )
                # print loss to test
                print(
                    f"loss: {total_loss.item() * accumulate:.4f}, box: {box_loss.item():.4f}, cls: {cls_loss.item():.4f}, dfl: {dfl_loss.item():.4f}"
                )
                torch.cuda.synchronize()

                # torch.cuda.synchronize()
        # ---- Final average loss (per image) over the whole epoch span ----
        avg_loss_per_image = loss_box_meter.avg + loss_cls_meter.avg + loss_dfl_meter.avg

        # tqdm update
        # if args.local_rank == 0:
        # mem = f"{torch.cuda.memory_reserved() / 1e9:.2f}G" if torch.cuda.is_available() else "0.00G"
        # desc = ("%10s" * 2 + "%10.4g" * 3) % (
        #     self.name,
        #     mem,
        #     avg_box_loss.avg,
        #     avg_cls_loss.avg,
        #     avg_dfl_loss.avg,
        # )
        # cast(tqdm, p_bar).set_description(desc)
        # p_bar.update(1)

        # p_bar.close()

        # clean
        if args.distributed:
        # ---- Cleanup DDP ----
        if is_ddp:
            torch.distributed.destroy_process_group()
        torch.cuda.empty_cache()

        return (
            self.model.state_dict() if not ema else ema.ema.state_dict(),
            self.n_data,
            {"box_loss": avg_box_loss.avg, "cls_loss": avg_cls_loss.avg, "dfl_loss": avg_dfl_loss.avg},
        )
        # ---- Choose which weights to return ----
        # - If EMA exists, return EMA weights (common YOLO eval practice)
        # - Be careful with DDP: grab state_dict from the underlying module / EMA model
        if ema:
            # print("Using EMA weights")
            return (ema.ema.state_dict(), self.n_data, avg_loss_per_image)
        else:
            # Safely get the underlying module if wrapped by DDP; getattr returns the module or the original object.
            model_obj = getattr(self.model, "module", self.model)
            # If it's a proper nn.Module, call state_dict(); if it's already a state dict, use it;
            # otherwise try to call state_dict() and finally fall back to wrapping the object.
            if isinstance(model_obj, torch.nn.Module):
                model_to_return = model_obj.state_dict()
            elif isinstance(model_obj, dict):
                model_to_return = model_obj
            else:
                try:
                    model_to_return = model_obj.state_dict()
                except Exception:
                    # fallback: if model_obj is a tensor or unexpected object, wrap it in a dict
                    model_to_return = {"state": model_obj}
            return model_to_return, self.n_data, avg_loss_per_image
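
The tuple returned above, (state_dict, n_data, avg_loss_per_image), is shaped for weighted federated averaging on the server. FedYoloServer itself is not shown in this excerpt, so the snippet below is only a minimal sketch of how such client returns could be aggregated; the fed_avg name and the plain-list input are assumptions, not the repository's actual API, and it assumes every client shares the same model keys.

import torch

def fed_avg(client_updates):
    """Weighted average of client state dicts.

    client_updates: list of (state_dict, n_data, avg_loss) tuples, as returned by
    FedYoloClient.train(); each client is weighted by its n_data.
    """
    total = sum(n for _, n, _ in client_updates)
    template, _, _ = client_updates[0]
    averaged = {}
    for key, value in template.items():
        if not torch.is_floating_point(value):
            # keep integer buffers (e.g. num_batches_tracked) from the first client
            averaged[key] = value.clone()
            continue
        acc = torch.zeros_like(value, dtype=torch.float32)
        for state, n, _ in client_updates:
            acc += state[key].float() * (n / total)
        averaged[key] = acc.to(value.dtype)
    return averaged

The aggregated weights would then be pushed back to clients with model.load_state_dict(averaged, strict=True), mirroring the client-side load at the top of this diff.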