FedAvg and YOLOv11 training
fed_algo_cs/client_base.py (new file, +233 lines)
@@ -0,0 +1,233 @@
import numpy as np
import torch
from torch import nn
from torch.utils import data
from torch.amp.autocast_mode import autocast
from utils.fed_util import init_model
from utils import util
from utils.dataset import Dataset
from typing import cast


class FedYoloClient(object):
    def __init__(self, name, model_name, params):
        """
        Initialize client k for federated learning.

        Args:
        :param name: Name of client k
        :param model_name: Name of the model
        :param params: Config dict with the hyperparameters for local training
            - local_batch_size: Local training batch size on client k
            - num_workers: Number of data loader workers
            - min_lr: Minimum learning rate
            - max_lr: Maximum learning rate
            - momentum: Momentum for local training
            - weight_decay: Weight decay for local training
            - names: Class names (their count sets the number of detection classes)
        """
        self.params = params
        # initialize the metadata of local client k
        self.target_ip = "127.0.0.3"
        self.port = 9999
        self.name = name

        # initialize the hyperparameters of local client k
        self._batch_size = self.params["local_batch_size"]
        self._min_lr = self.params["min_lr"]
        self._max_lr = self.params["max_lr"]
        self._momentum = self.params["momentum"]
        self.num_workers = self.params["num_workers"]

        self.loss_record = []
        # number of local training samples
        self.n_data = 0

        # initialize the local training and validation datasets
        self.train_dataset = None
        self.val_dataset = None

        # initialize the local model
        self._num_classes = len(self.params["names"])
        self._weight_decay = self.params["weight_decay"]

        self.model_name = model_name
        self.model = init_model(model_name, self._num_classes)

        # count the trainable parameters
        model_parameters = filter(lambda p: p.requires_grad, self.model.parameters())
        self.parameter_number = sum(np.prod(p.size()) for p in model_parameters)

        # GPU
        self._device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    def load_trainset(self, train_dataset: list[str]):
        """
        Load the local training dataset.

        Args:
        :param train_dataset: List of training image file paths
        """
        self.train_dataset = train_dataset
        self.n_data = len(self.train_dataset)

    def update(self, Global_model_state_dict):
        """
        Update the local model with the global model parameters.

        Args:
        :param Global_model_state_dict: State dictionary of the global model
        """
        if not hasattr(self, "model") or self.model is None:
            self.model = init_model(self.model_name, self._num_classes)

        # load the global model parameters
        self.model.load_state_dict(Global_model_state_dict, strict=True)

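    # Typical use in one FedAvg round (see the aggregation sketch at the end of
    # this file): the server calls client.update(global_state_dict) to push the
    # current global weights, then client.train(args) to run local training and
    # collect the updated state dict, the local sample count, and the losses.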
    def train(self, args):
        """
        Train the local model.

        Args:
        :param args: Command line arguments
            - local_rank: Local rank for distributed training
            - world_size: World size for distributed training
            - distributed: Whether to use distributed training
            - input_size: Input size for the model
            - epochs: Number of local training epochs
        Returns:
        :return: Local updated model state dict, number of local data points,
            and a dict of average training losses
        """

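        # Illustrative single-GPU `args` (attribute names are the ones read below;
        # the values are examples only):
        #   args = argparse.Namespace(local_rank=0, world_size=1, distributed=False,
        #                             input_size=640, epochs=50)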
        if args.distributed:
            torch.cuda.set_device(device=args.local_rank)
            torch.distributed.init_process_group(backend="nccl", init_method="env://")
            # print(f"Client {self.name} - distributed training on {world_size} GPUs, local rank: {local_rank}")
            # self._device = torch.device("cuda", local_rank)

        if args.local_rank == 0:
            pass
            # if not os.path.exists("weights"):
            #     os.makedirs("weights")

        util.setup_seed()
        util.setup_multi_processes()

        # model
        # the model has already been initialized in __init__()
        self.model.to(self._device)

        # Optimizer
        accumulate = max(round(64 / (self._batch_size * args.world_size)), 1)
        # scale weight decay by the effective batch size (batch * world * accumulate) / 64
        self._weight_decay *= self._batch_size * args.world_size * accumulate / 64
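        # Worked example (illustrative numbers): with local_batch_size=16 and
        # world_size=1, accumulate = max(round(64 / 16), 1) = 4, so gradients are
        # accumulated over 4 mini-batches to emulate a nominal batch size of 64,
        # and weight decay is scaled by 16 * 1 * 4 / 64 = 1.0 (unchanged here).
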
        optimizer = torch.optim.SGD(
            util.set_params(self.model, self._weight_decay),
            lr=self._min_lr,
            momentum=self._momentum,
            nesterov=True,
        )

        # EMA
        ema = util.EMA(self.model) if args.local_rank == 0 else None

        data_set = Dataset(
            filenames=self.train_dataset,
            input_size=args.input_size,
            params=self.params,
            augment=True,
        )

        if args.distributed:
            train_sampler = data.DistributedSampler(
                data_set, num_replicas=args.world_size, rank=args.local_rank, shuffle=True
            )
        else:
            train_sampler = None

        loader = data.DataLoader(
            data_set,
            batch_size=self._batch_size,
            shuffle=train_sampler is None,
            sampler=train_sampler,
            num_workers=self.num_workers,
            pin_memory=True,
            collate_fn=Dataset.collate_fn,
        )

        # Scheduler
        num_steps = max(1, len(loader))
        # print(len(loader))
        scheduler = util.LinearLR(args=args, params=self.params, num_steps=num_steps)
        # DDP mode
        if args.distributed:
            self.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.model)
            self.model = nn.parallel.DistributedDataParallel(
                module=self.model,
                device_ids=[args.local_rank],
                output_device=args.local_rank,
                find_unused_parameters=False,
            )

        amp_scale = torch.amp.grad_scaler.GradScaler(enabled=True)
        criterion = util.ComputeLoss(self.model, self.params)

        optimizer.zero_grad(set_to_none=True)

        for epoch in range(args.epochs):
            self.model.train()
            # when distributed, set the epoch for shuffling
            if args.distributed and train_sampler is not None:
                train_sampler.set_epoch(epoch)

            if args.epochs - epoch == 10:
                # disable mosaic augmentation in the last 10 epochs
                ds = cast(Dataset, loader.dataset)
                ds.mosaic = False

            avg_box_loss = util.AverageMeter()
            avg_cls_loss = util.AverageMeter()
            avg_dfl_loss = util.AverageMeter()

            for i, (samples, targets) in enumerate(loader):
                global_step = i + num_steps * epoch
                scheduler.step(step=global_step, optimizer=optimizer)

                samples = samples.cuda(non_blocking=True).float() / 255.0

                # Forward
                with autocast("cuda", enabled=True):
                    outputs = self.model(samples)
                    box_loss, cls_loss, dfl_loss = criterion(outputs, targets)

                # meters (use the *unscaled* values)
                bs = samples.size(0)
                avg_box_loss.update(box_loss.item(), bs)
                avg_cls_loss.update(cls_loss.item(), bs)
                avg_dfl_loss.update(dfl_loss.item(), bs)

                # scale losses by batch/world if the loss is averaged internally per sample/device
                box_loss = box_loss * self._batch_size * args.world_size
                cls_loss = cls_loss * self._batch_size * args.world_size
                dfl_loss = dfl_loss * self._batch_size * args.world_size

                total_loss = box_loss + cls_loss + dfl_loss

                # Backward
                amp_scale.scale(total_loss).backward()

                # Optimize
                if (i + 1) % accumulate == 0:
                    amp_scale.step(optimizer)
                    amp_scale.update()
                    optimizer.zero_grad(set_to_none=True)
                    if ema:
                        ema.update(self.model)

                # torch.cuda.synchronize()

        # clean up the distributed process group and cached GPU memory
        if args.distributed:
            torch.distributed.destroy_process_group()
        torch.cuda.empty_cache()

        # return the local weights, the local sample count, and the average losses
        return (
            self.model.state_dict(),
            self.n_data,
            {"box_loss": avg_box_loss.avg, "cls_loss": avg_cls_loss.avg, "dfl_loss": avg_dfl_loss.avg},
        )
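

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the client API): how a FedAvg server could
# aggregate the state dicts returned by FedYoloClient.train(). The helper below
# and the commented driver loop are assumptions for illustration only;
# `clients`, `global_model`, and `args` are hypothetical objects.
# ---------------------------------------------------------------------------
def fedavg_aggregate(state_dicts, n_data):
    """Weighted-average a list of state dicts by their local sample counts."""
    total = float(sum(n_data))
    averaged = {}
    for key in state_dicts[0].keys():
        # stack the per-client tensors, each weighted by its data share n_k / total
        stacked = torch.stack(
            [sd[key].float() * (n / total) for sd, n in zip(state_dicts, n_data)]
        )
        averaged[key] = stacked.sum(dim=0).to(state_dicts[0][key].dtype)
    return averaged


# One communication round might look like:
#   global_sd = global_model.state_dict()
#   updates, counts = [], []
#   for client in clients:                      # each client is a FedYoloClient
#       client.update(global_sd)                # push the current global weights
#       sd, n_k, losses = client.train(args)    # local YOLOv11 training
#       updates.append(sd)
#       counts.append(n_k)
#   global_model.load_state_dict(fedavg_aggregate(updates, counts))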