Compare commits
7 Commits
101ffa51eb
...
76d7149512
Author | SHA1 | Date | |
---|---|---|---|
![]() |
76d7149512 | ||
![]() |
f23a22632f | ||
![]() |
314f46d542 | ||
![]() |
0343a0fd30 | ||
![]() |
3f4dd07572 | ||
![]() |
300ce2e93f | ||
![]() |
40de29591b |
3
.gitignore
vendored
3
.gitignore
vendored
@@ -296,5 +296,8 @@ Network Trash Folder
|
|||||||
Temporary Items
|
Temporary Items
|
||||||
.apdisk
|
.apdisk
|
||||||
|
|
||||||
|
# ---> Custom
|
||||||
results/
|
results/
|
||||||
*.log
|
*.log
|
||||||
|
*.txt
|
||||||
|
weights/
|
||||||
|
@@ -22,3 +22,9 @@ nohup python fed_run.py > train.log 2>&1 &
|
|||||||
- Add more YOLO versions (e.g., YOLOv8, YOLOv5, etc.)
|
- Add more YOLO versions (e.g., YOLOv8, YOLOv5, etc.)
|
||||||
- Implement YOLOv8
|
- Implement YOLOv8
|
||||||
- Implement YOLOv5
|
- Implement YOLOv5
|
||||||
|
|
||||||
|
# references
|
||||||
|
[PyTorch Federated Learning](https://github.com/rruisong/pytorch_federated_learning)
|
||||||
|
|
||||||
|
|
||||||
|
[YOLOv11-pt](https://github.com/jahongir7174/YOLOv11-pt)
|
@@ -3,11 +3,11 @@ import torch
|
|||||||
from torch import nn
|
from torch import nn
|
||||||
from torch.utils import data
|
from torch.utils import data
|
||||||
from torch.amp.autocast_mode import autocast
|
from torch.amp.autocast_mode import autocast
|
||||||
from tqdm import tqdm
|
|
||||||
from utils.fed_util import init_model
|
from utils.fed_util import init_model
|
||||||
from utils import util
|
from utils import util
|
||||||
from utils.dataset import Dataset
|
from utils.dataset import Dataset
|
||||||
from typing import cast
|
from typing import cast
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
|
|
||||||
class FedYoloClient(object):
|
class FedYoloClient(object):
|
||||||
@@ -82,52 +82,48 @@ class FedYoloClient(object):
|
|||||||
# load the global model parameters
|
# load the global model parameters
|
||||||
self.model.load_state_dict(Global_model_state_dict, strict=True)
|
self.model.load_state_dict(Global_model_state_dict, strict=True)
|
||||||
|
|
||||||
def train(self, args):
|
def train(self, args) -> tuple[dict[str, torch.Tensor], int, float]:
|
||||||
"""
|
"""
|
||||||
Train the local model
|
Train the local model.
|
||||||
Args:
|
Returns: (state_dict, n_data, avg_loss_per_image)
|
||||||
:param args: Command line arguments
|
|
||||||
- local_rank: Local rank for distributed training
|
|
||||||
- world_size: World size for distributed training
|
|
||||||
- distributed: Whether to use distributed training
|
|
||||||
- input_size: Input size for the model
|
|
||||||
Returns:
|
|
||||||
:return: Local updated model, number of local data points, training loss
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
# ---- Dist init (if any) ----
|
||||||
if args.distributed:
|
if args.distributed:
|
||||||
torch.cuda.set_device(device=args.local_rank)
|
torch.cuda.set_device(device=args.local_rank)
|
||||||
torch.distributed.init_process_group(backend="nccl", init_method="env://")
|
torch.distributed.init_process_group(backend="nccl", init_method="env://")
|
||||||
# print(f"Client {self.name} - distributed training on {world_size} GPUs, local rank: {local_rank}")
|
|
||||||
# self._device = torch.device("cuda", local_rank)
|
|
||||||
|
|
||||||
if args.local_rank == 0:
|
|
||||||
pass
|
|
||||||
# if not os.path.exists("weights"):
|
|
||||||
# os.makedirs("weights")
|
|
||||||
|
|
||||||
util.setup_seed()
|
util.setup_seed()
|
||||||
util.setup_multi_processes()
|
util.setup_multi_processes()
|
||||||
|
|
||||||
# model
|
# device = torch.device(f"cuda:{args.local_rank}" if torch.cuda.is_available() else "cpu")
|
||||||
# init model have been done in __init__()
|
# self.model.to(device)
|
||||||
self.model.to(self._device)
|
self.model.cuda()
|
||||||
|
# show model architecture
|
||||||
|
# print(self.model)
|
||||||
|
|
||||||
# Optimizer
|
# ---- Optimizer / WD scaling & LR warmup/schedule ----
|
||||||
accumulate = max(round(64 / (self._batch_size * args.world_size)), 1)
|
# accumulate = effective grad-accumulation steps to emulate global batch 64
|
||||||
self._weight_decay = self._batch_size * args.world_size * accumulate / 64
|
world_size = getattr(args, "world_size", 1)
|
||||||
|
accumulate = max(round(64 / (self._batch_size * max(world_size, 1))), 1)
|
||||||
|
|
||||||
|
# scale weight_decay like YOLO recipes
|
||||||
|
scaled_wd = self._weight_decay * self._batch_size * max(world_size, 1) * accumulate / 64
|
||||||
optimizer = torch.optim.SGD(
|
optimizer = torch.optim.SGD(
|
||||||
util.set_params(self.model, self._weight_decay),
|
util.set_params(self.model, scaled_wd),
|
||||||
lr=self._min_lr,
|
lr=self._min_lr,
|
||||||
momentum=self._momentum,
|
momentum=self._momentum,
|
||||||
nesterov=True,
|
nesterov=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
# EMA
|
# ---- EMA (track the underlying module if DDP) ----
|
||||||
|
# track_model = self.model.module if is_ddp else self.model
|
||||||
ema = util.EMA(self.model) if args.local_rank == 0 else None
|
ema = util.EMA(self.model) if args.local_rank == 0 else None
|
||||||
|
|
||||||
data_set = Dataset(
|
print(type(self.train_dataset))
|
||||||
|
|
||||||
|
# ---- Data ----
|
||||||
|
dataset = Dataset(
|
||||||
filenames=self.train_dataset,
|
filenames=self.train_dataset,
|
||||||
input_size=args.input_size,
|
input_size=args.input_size,
|
||||||
params=self.params,
|
params=self.params,
|
||||||
@@ -136,26 +132,28 @@ class FedYoloClient(object):
|
|||||||
|
|
||||||
if args.distributed:
|
if args.distributed:
|
||||||
train_sampler = data.DistributedSampler(
|
train_sampler = data.DistributedSampler(
|
||||||
data_set, num_replicas=args.world_size, rank=args.local_rank, shuffle=True
|
dataset, num_replicas=args.world_size, rank=args.local_rank, shuffle=True
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
train_sampler = None
|
train_sampler = None
|
||||||
|
|
||||||
loader = data.DataLoader(
|
loader = data.DataLoader(
|
||||||
data_set,
|
dataset,
|
||||||
batch_size=self._batch_size,
|
batch_size=self._batch_size,
|
||||||
shuffle=train_sampler is None,
|
shuffle=(train_sampler is None),
|
||||||
sampler=train_sampler,
|
sampler=train_sampler,
|
||||||
num_workers=self.num_workers,
|
num_workers=self.num_workers,
|
||||||
pin_memory=True,
|
pin_memory=True,
|
||||||
collate_fn=Dataset.collate_fn,
|
collate_fn=Dataset.collate_fn,
|
||||||
|
drop_last=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Scheduler
|
|
||||||
num_steps = max(1, len(loader))
|
num_steps = max(1, len(loader))
|
||||||
scheduler = util.LinearLR(args=args, params=self.params, num_steps=num_steps)
|
scheduler = util.LinearLR(args=args, params=self.params, num_steps=num_steps)
|
||||||
# DDP mode
|
|
||||||
if args.distributed:
|
# ---- SyncBN + DDP (if any) ----
|
||||||
|
is_ddp = bool(args.distributed)
|
||||||
|
if is_ddp:
|
||||||
self.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.model)
|
self.model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self.model)
|
||||||
self.model = nn.parallel.DistributedDataParallel(
|
self.model = nn.parallel.DistributedDataParallel(
|
||||||
module=self.model,
|
module=self.model,
|
||||||
@@ -164,102 +162,133 @@ class FedYoloClient(object):
|
|||||||
find_unused_parameters=False,
|
find_unused_parameters=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
amp_scale = torch.amp.grad_scaler.GradScaler(enabled=True)
|
# ---- AMP + loss ----
|
||||||
|
scaler = torch.amp.grad_scaler.GradScaler(enabled=True)
|
||||||
|
# criterion = util.ComputeLoss(
|
||||||
|
# self.model.module if isinstance(self.model, nn.parallel.DistributedDataParallel) else self.model,
|
||||||
|
# self.params,
|
||||||
|
# )
|
||||||
criterion = util.ComputeLoss(self.model, self.params)
|
criterion = util.ComputeLoss(self.model, self.params)
|
||||||
|
|
||||||
# log
|
# ---- Training ----
|
||||||
# if args.local_rank == 0:
|
|
||||||
# header = ("%10s" * 5) % ("client", "memory", "box", "cls", "dfl")
|
|
||||||
# print("\n" + header)
|
|
||||||
# p_bar = tqdm(total=args.epochs * num_steps, ncols=120)
|
|
||||||
# p_bar.set_description(f"{self.name:>10}")
|
|
||||||
|
|
||||||
for epoch in range(args.epochs):
|
for epoch in range(args.epochs):
|
||||||
|
# (self.model.module if isinstance(self.model, nn.parallel.DistributedDataParallel) else self.model).train()
|
||||||
self.model.train()
|
self.model.train()
|
||||||
# when distributed, set epoch for shuffling
|
if is_ddp and train_sampler is not None:
|
||||||
if args.distributed and train_sampler is not None:
|
|
||||||
train_sampler.set_epoch(epoch)
|
train_sampler.set_epoch(epoch)
|
||||||
|
|
||||||
if args.epochs - epoch == 10:
|
# disable mosaic in the last 10 epochs (if dataset supports it)
|
||||||
# disable mosaic augmentation in the last 10 epochs
|
if args.epochs - epoch == 10 and hasattr(loader.dataset, "mosaic"):
|
||||||
ds = cast(Dataset, loader.dataset)
|
ds = cast(Dataset, loader.dataset)
|
||||||
ds.mosaic = False
|
ds.mosaic = False
|
||||||
|
|
||||||
optimizer.zero_grad(set_to_none=True)
|
optimizer.zero_grad(set_to_none=True)
|
||||||
avg_box_loss = util.AverageMeter()
|
loss_box_meter = util.AverageMeter()
|
||||||
avg_cls_loss = util.AverageMeter()
|
loss_cls_meter = util.AverageMeter()
|
||||||
avg_dfl_loss = util.AverageMeter()
|
loss_dfl_meter = util.AverageMeter()
|
||||||
|
|
||||||
# # --- header (once per epoch, YOLO-style) ---
|
for i, (images, targets) in enumerate(loader):
|
||||||
# if args.local_rank == 0:
|
print(f"Client {self.name} - Epoch {epoch + 1}/{args.epochs} - Step {i + 1}/{num_steps}")
|
||||||
# header = ("%10s" * 5) % ("client", "memory", "box", "cls", "dfl")
|
step = i + epoch * num_steps
|
||||||
# print("\n" + header)
|
|
||||||
|
|
||||||
# p_bar = enumerate(loader)
|
# scheduler per-step (your util.LinearLR expects step)
|
||||||
# if args.local_rank == 0:
|
scheduler.step(step=step, optimizer=optimizer)
|
||||||
# p_bar = tqdm(p_bar, total=num_steps, ncols=120)
|
|
||||||
|
|
||||||
for i, (samples, targets) in enumerate(loader):
|
# images = images.to(device, non_blocking=True).float() / 255.0
|
||||||
global_step = i + num_steps * epoch
|
images = images.cuda().float() / 255.0
|
||||||
scheduler.step(step=global_step, optimizer=optimizer)
|
bs = images.size(0)
|
||||||
|
# total_imgs_seen += bs
|
||||||
|
|
||||||
samples = samples.cuda(non_blocking=True).float() / 255.0
|
# targets: keep as your ComputeLoss expects (often CPU lists/tensors).
|
||||||
|
# Move to GPU here only if your loss requires it.
|
||||||
|
|
||||||
# Forward
|
with autocast(device_type="cuda", enabled=True):
|
||||||
with autocast("cuda", enabled=True):
|
outputs = self.model(images) # DDP wraps forward
|
||||||
outputs = self.model(samples)
|
|
||||||
box_loss, cls_loss, dfl_loss = criterion(outputs, targets)
|
box_loss, cls_loss, dfl_loss = criterion(outputs, targets)
|
||||||
|
|
||||||
# meters (use the *unscaled* values)
|
# total_loss = box_loss + cls_loss + dfl_loss
|
||||||
bs = samples.size(0)
|
# Gradient accumulation: normalize by 'accumulate' so LR stays effective
|
||||||
avg_box_loss.update(box_loss.item(), bs)
|
# total_loss = total_loss / accumulate
|
||||||
avg_cls_loss.update(cls_loss.item(), bs)
|
|
||||||
avg_dfl_loss.update(dfl_loss.item(), bs)
|
|
||||||
|
|
||||||
# scale losses by batch/world if your loss is averaged internally per-sample/device
|
# IMPORTANT: assume criterion returns **average per image** in the batch.
|
||||||
# box_loss = box_loss * self._batch_size * args.world_size
|
# Keep logging on the true (unscaled) values:
|
||||||
# cls_loss = cls_loss * self._batch_size * args.world_size
|
loss_box_meter.update(box_loss.item(), bs)
|
||||||
# dfl_loss = dfl_loss * self._batch_size * args.world_size
|
loss_cls_meter.update(cls_loss.item(), bs)
|
||||||
|
loss_dfl_meter.update(dfl_loss.item(), bs)
|
||||||
|
|
||||||
|
box_loss *= self._batch_size
|
||||||
|
cls_loss *= self._batch_size
|
||||||
|
dfl_loss *= self._batch_size
|
||||||
|
box_loss *= args.world_size
|
||||||
|
cls_loss *= args.world_size
|
||||||
|
dfl_loss *= args.world_size
|
||||||
total_loss = box_loss + cls_loss + dfl_loss
|
total_loss = box_loss + cls_loss + dfl_loss
|
||||||
|
|
||||||
# Backward
|
scaler.scale(total_loss).backward()
|
||||||
amp_scale.scale(total_loss).backward()
|
|
||||||
|
|
||||||
# Optimize
|
# optimize
|
||||||
if (i + 1) % accumulate == 0:
|
if step % accumulate == 0:
|
||||||
amp_scale.unscale_(optimizer) # unscale gradients
|
# scaler.unscale_(optimizer)
|
||||||
util.clip_gradients(model=self.model, max_norm=10.0) # clip gradients
|
# util.clip_gradients(self.model)
|
||||||
amp_scale.step(optimizer)
|
scaler.step(optimizer)
|
||||||
amp_scale.update()
|
scaler.update()
|
||||||
optimizer.zero_grad(set_to_none=True)
|
optimizer.zero_grad(set_to_none=True)
|
||||||
if ema:
|
|
||||||
ema.update(self.model)
|
|
||||||
|
|
||||||
# torch.cuda.synchronize()
|
# # Step when we have 'accumulate' micro-batches, or at the end
|
||||||
|
# if ((i + 1) % accumulate == 0) or (i + 1 == len(loader)):
|
||||||
# tqdm update
|
# scaler.unscale_(optimizer)
|
||||||
# if args.local_rank == 0:
|
# util.clip_gradients(
|
||||||
# mem = f"{torch.cuda.memory_reserved() / 1e9:.2f}G" if torch.cuda.is_available() else "0.00G"
|
# model=(
|
||||||
# desc = ("%10s" * 2 + "%10.4g" * 3) % (
|
# self.model.module
|
||||||
# self.name,
|
# if isinstance(self.model, nn.parallel.DistributedDataParallel)
|
||||||
# mem,
|
# else self.model
|
||||||
# avg_box_loss.avg,
|
# ),
|
||||||
# avg_cls_loss.avg,
|
# max_norm=10.0,
|
||||||
# avg_dfl_loss.avg,
|
|
||||||
# )
|
# )
|
||||||
# cast(tqdm, p_bar).set_description(desc)
|
# scaler.step(optimizer)
|
||||||
# p_bar.update(1)
|
# scaler.update()
|
||||||
|
# optimizer.zero_grad(set_to_none=True)
|
||||||
|
|
||||||
# p_bar.close()
|
if ema:
|
||||||
|
# Update EMA from the underlying module
|
||||||
|
ema.update(
|
||||||
|
self.model.module
|
||||||
|
if isinstance(self.model, nn.parallel.DistributedDataParallel)
|
||||||
|
else self.model
|
||||||
|
)
|
||||||
|
# print loss to test
|
||||||
|
print(
|
||||||
|
f"loss: {total_loss.item() * accumulate:.4f}, box: {box_loss.item():.4f}, cls: {cls_loss.item():.4f}, dfl: {dfl_loss.item():.4f}"
|
||||||
|
)
|
||||||
|
torch.cuda.synchronize()
|
||||||
|
|
||||||
# clean
|
# ---- Final average loss (per image) over the whole epoch span ----
|
||||||
if args.distributed:
|
avg_loss_per_image = loss_box_meter.avg + loss_cls_meter.avg + loss_dfl_meter.avg
|
||||||
|
|
||||||
|
# ---- Cleanup DDP ----
|
||||||
|
if is_ddp:
|
||||||
torch.distributed.destroy_process_group()
|
torch.distributed.destroy_process_group()
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
|
|
||||||
return (
|
# ---- Choose which weights to return ----
|
||||||
self.model.state_dict() if not ema else ema.ema.state_dict(),
|
# - If EMA exists, return EMA weights (common YOLO eval practice)
|
||||||
self.n_data,
|
# - Be careful with DDP: grab state_dict from the underlying module / EMA model
|
||||||
{"box_loss": avg_box_loss.avg, "cls_loss": avg_cls_loss.avg, "dfl_loss": avg_dfl_loss.avg},
|
if ema:
|
||||||
)
|
# print("Using EMA weights")
|
||||||
|
return (ema.ema.state_dict(), self.n_data, avg_loss_per_image)
|
||||||
|
else:
|
||||||
|
# Safely get the underlying module if wrapped by DDP; getattr returns the module or the original object.
|
||||||
|
model_obj = getattr(self.model, "module", self.model)
|
||||||
|
# If it's a proper nn.Module, call state_dict(); if it's already a state dict, use it;
|
||||||
|
# otherwise try to call state_dict() and finally fall back to wrapping the object.
|
||||||
|
if isinstance(model_obj, torch.nn.Module):
|
||||||
|
model_to_return = model_obj.state_dict()
|
||||||
|
elif isinstance(model_obj, dict):
|
||||||
|
model_to_return = model_obj
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
model_to_return = model_obj.state_dict()
|
||||||
|
except Exception:
|
||||||
|
# fallback: if model_obj is a tensor or unexpected object, wrap it in a dict
|
||||||
|
model_to_return = {"state": model_obj}
|
||||||
|
return model_to_return, self.n_data, avg_loss_per_image
|
||||||
|
@@ -4,6 +4,7 @@ from torch.utils.data import DataLoader
|
|||||||
from utils.fed_util import init_model
|
from utils.fed_util import init_model
|
||||||
from utils.dataset import Dataset
|
from utils.dataset import Dataset
|
||||||
from utils import util
|
from utils import util
|
||||||
|
from nets import YOLO
|
||||||
|
|
||||||
|
|
||||||
class FedYoloServer(object):
|
class FedYoloServer(object):
|
||||||
@@ -21,7 +22,7 @@ class FedYoloServer(object):
|
|||||||
self.client_n_data = {}
|
self.client_n_data = {}
|
||||||
self.selected_clients = []
|
self.selected_clients = []
|
||||||
|
|
||||||
self._batch_size = params.get("val_batch_size", 4)
|
self._batch_size = params.get("val_batch_size", 200)
|
||||||
self.client_list = client_list
|
self.client_list = client_list
|
||||||
self.valset = None
|
self.valset = None
|
||||||
|
|
||||||
@@ -40,7 +41,7 @@ class FedYoloServer(object):
|
|||||||
self.model = init_model(model_name, self._num_classes)
|
self.model = init_model(model_name, self._num_classes)
|
||||||
self.params = params
|
self.params = params
|
||||||
|
|
||||||
def load_valset(self, valset):
|
def load_valset(self, valset: Dataset):
|
||||||
"""Server loads the validation dataset."""
|
"""Server loads the validation dataset."""
|
||||||
self.valset = valset
|
self.valset = valset
|
||||||
|
|
||||||
@@ -48,78 +49,6 @@ class FedYoloServer(object):
|
|||||||
"""Return global model weights."""
|
"""Return global model weights."""
|
||||||
return self.model.state_dict()
|
return self.model.state_dict()
|
||||||
|
|
||||||
@torch.no_grad()
|
|
||||||
def test(self, args) -> dict:
|
|
||||||
"""
|
|
||||||
Test the global model on the server's validation set.
|
|
||||||
Returns:
|
|
||||||
dict with keys: mAP, mAP50, precision, recall
|
|
||||||
"""
|
|
||||||
if self.valset is None:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
loader = DataLoader(
|
|
||||||
self.valset,
|
|
||||||
batch_size=self._batch_size,
|
|
||||||
shuffle=False,
|
|
||||||
num_workers=4,
|
|
||||||
pin_memory=True,
|
|
||||||
collate_fn=Dataset.collate_fn,
|
|
||||||
)
|
|
||||||
|
|
||||||
dev = self._device
|
|
||||||
# move to device for eval; keep in float32 for stability
|
|
||||||
self.model.eval().to(dev).float()
|
|
||||||
|
|
||||||
iou_v = torch.linspace(0.5, 0.95, 10, device=dev)
|
|
||||||
n_iou = iou_v.numel()
|
|
||||||
metrics = []
|
|
||||||
|
|
||||||
for samples, targets in loader:
|
|
||||||
samples = samples.to(dev, non_blocking=True).float() / 255.0
|
|
||||||
_, _, h, w = samples.shape
|
|
||||||
scale = torch.tensor((w, h, w, h), device=dev)
|
|
||||||
|
|
||||||
outputs = self.model(samples)
|
|
||||||
outputs = util.non_max_suppression(outputs)
|
|
||||||
|
|
||||||
for i, output in enumerate(outputs):
|
|
||||||
idx = targets["idx"] == i
|
|
||||||
cls = targets["cls"][idx].to(dev)
|
|
||||||
box = targets["box"][idx].to(dev)
|
|
||||||
|
|
||||||
metric = torch.zeros((output.shape[0], n_iou), dtype=torch.bool, device=dev)
|
|
||||||
if output.shape[0] == 0:
|
|
||||||
if cls.shape[0]:
|
|
||||||
metrics.append((metric, *torch.zeros((2, 0), device=dev), cls.squeeze(-1)))
|
|
||||||
continue
|
|
||||||
|
|
||||||
if cls.shape[0]:
|
|
||||||
if cls.dim() == 1:
|
|
||||||
cls = cls.unsqueeze(1)
|
|
||||||
box_xy = util.wh2xy(box)
|
|
||||||
if not isinstance(box_xy, torch.Tensor):
|
|
||||||
box_xy = torch.tensor(box_xy, device=dev)
|
|
||||||
target = torch.cat((cls, box_xy * scale), dim=1)
|
|
||||||
metric = util.compute_metric(output[:, :6], target, iou_v)
|
|
||||||
|
|
||||||
metrics.append((metric, output[:, 4], output[:, 5], cls.squeeze(-1)))
|
|
||||||
|
|
||||||
if not metrics:
|
|
||||||
# move back to CPU before returning
|
|
||||||
self.model.to("cpu").float()
|
|
||||||
return {"mAP": 0, "mAP50": 0, "precision": 0, "recall": 0}
|
|
||||||
|
|
||||||
metrics = [torch.cat(x, dim=0).cpu().numpy() for x in zip(*metrics)]
|
|
||||||
if len(metrics) and metrics[0].any():
|
|
||||||
_, _, prec, rec, map50, mean_ap = util.compute_ap(*metrics, names=self.params["names"], plot=False)
|
|
||||||
else:
|
|
||||||
prec, rec, map50, mean_ap = 0, 0, 0, 0
|
|
||||||
|
|
||||||
# return model to CPU so next agg() stays device-consistent
|
|
||||||
self.model.to("cpu").float()
|
|
||||||
return {"mAP": float(mean_ap), "mAP50": float(map50), "precision": float(prec), "recall": float(rec)}
|
|
||||||
|
|
||||||
def select_clients(self, connection_ratio=1.0):
|
def select_clients(self, connection_ratio=1.0):
|
||||||
"""
|
"""
|
||||||
Randomly select a fraction of clients.
|
Randomly select a fraction of clients.
|
||||||
@@ -130,80 +59,69 @@ class FedYoloServer(object):
|
|||||||
self.n_data = 0
|
self.n_data = 0
|
||||||
for client_id in self.client_list:
|
for client_id in self.client_list:
|
||||||
# Random selection based on connection ratio
|
# Random selection based on connection ratio
|
||||||
if np.random.rand() <= connection_ratio:
|
s = np.random.binomial(np.ones(1).astype(int), connection_ratio)
|
||||||
|
if s[0] == 1:
|
||||||
self.selected_clients.append(client_id)
|
self.selected_clients.append(client_id)
|
||||||
self.n_data += self.client_n_data.get(client_id, 0)
|
self.n_data += self.client_n_data[client_id]
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
def agg(self):
|
def agg(self):
|
||||||
"""Aggregate client updates (FedAvg) on CPU/FP32, preserving non-float buffers."""
|
"""
|
||||||
|
Server aggregates the local updates from selected clients using FedAvg.
|
||||||
|
|
||||||
|
:return: model_state: aggregated model weights
|
||||||
|
:return: avg_loss: weighted average training loss across selected clients
|
||||||
|
:return: n_data: total number of data points across selected clients
|
||||||
|
"""
|
||||||
if len(self.selected_clients) == 0 or self.n_data == 0:
|
if len(self.selected_clients) == 0 or self.n_data == 0:
|
||||||
return self.model.state_dict(), {}, 0
|
import warnings
|
||||||
|
|
||||||
# Ensure global model is on CPU for safe load later
|
warnings.warn("No clients selected or no data available for aggregation.")
|
||||||
self.model.to("cpu")
|
return self.model.state_dict(), 0, 0
|
||||||
global_state = self.model.state_dict() # may hold CPU or CUDA refs; we’re on CPU now
|
|
||||||
|
|
||||||
avg_loss = {}
|
# Initialize a model for aggregation
|
||||||
total_n = float(self.n_data)
|
model = init_model(model_name=self.model_name, num_classes=self._num_classes)
|
||||||
|
model_state = model.state_dict()
|
||||||
|
|
||||||
# Prepare accumulators on CPU. For floating tensors, use float32 zeros.
|
avg_loss = 0
|
||||||
# For non-floating tensors (e.g., BN num_batches_tracked int64), we’ll copy from the first client.
|
|
||||||
new_state = {}
|
|
||||||
first_client = None
|
|
||||||
for cid in self.selected_clients:
|
|
||||||
if cid in self.client_state:
|
|
||||||
first_client = cid
|
|
||||||
break
|
|
||||||
|
|
||||||
assert first_client is not None, "No client states available to aggregate."
|
# Aggregate the local updated models from selected clients
|
||||||
|
for i, name in enumerate(self.selected_clients):
|
||||||
for k, v in global_state.items():
|
if name not in self.client_state:
|
||||||
if v.is_floating_point():
|
|
||||||
new_state[k] = torch.zeros_like(v.detach().cpu(), dtype=torch.float32)
|
|
||||||
else:
|
|
||||||
# For non-float buffers, just copy from the first client (or keep global)
|
|
||||||
new_state[k] = self.client_state[first_client][k].clone()
|
|
||||||
|
|
||||||
# Accumulate floating tensors with weights; keep non-floats as assigned above
|
|
||||||
for cid in self.selected_clients:
|
|
||||||
if cid not in self.client_state:
|
|
||||||
continue
|
continue
|
||||||
weight = self.client_n_data[cid] / total_n
|
for key in self.client_state[name]:
|
||||||
cst = self.client_state[cid]
|
if i == 0:
|
||||||
for k in new_state.keys():
|
# First client, initialize the model_state
|
||||||
if new_state[k].is_floating_point():
|
model_state[key] = self.client_state[name][key] * (self.client_n_data[name] / self.n_data)
|
||||||
# cst[k] is CPU; ensure float32 for accumulation
|
else:
|
||||||
new_state[k].add_(cst[k].to(torch.float32), alpha=weight)
|
# math equation: w = sum(n_k / n * w_k)
|
||||||
|
model_state[key] = model_state[key] + self.client_state[name][key] * (
|
||||||
# weighted average losses
|
self.client_n_data[name] / self.n_data
|
||||||
for lk, lv in self.client_loss[cid].items():
|
)
|
||||||
avg_loss[lk] = avg_loss.get(lk, 0.0) + float(lv) * weight
|
avg_loss = avg_loss + self.client_loss[name] * (self.client_n_data[name] / self.n_data)
|
||||||
|
|
||||||
# Load aggregated state back into the global model (model is on CPU)
|
|
||||||
with torch.no_grad():
|
|
||||||
self.model.load_state_dict(new_state, strict=True)
|
|
||||||
|
|
||||||
|
self.model.load_state_dict(model_state, strict=True)
|
||||||
self.round += 1
|
self.round += 1
|
||||||
# Return CPU state_dict (good for broadcasting to clients)
|
|
||||||
return {k: v.clone() for k, v in self.model.state_dict().items()}, avg_loss, int(self.n_data)
|
|
||||||
|
|
||||||
def rec(self, name, state_dict, n_data, loss_dict):
|
n_data = self.n_data
|
||||||
|
|
||||||
|
return model_state, avg_loss, n_data
|
||||||
|
|
||||||
|
def rec(self, name, state_dict, n_data, loss):
|
||||||
"""
|
"""
|
||||||
Receive local update from a client.
|
Receive local update from a client.
|
||||||
- Store all floating tensors as CPU float32
|
- Store all floating tensors as CPU float32
|
||||||
- Store non-floating tensors (e.g., BN counters) as CPU in original dtype
|
- Store non-floating tensors (e.g., BN counters) as CPU in original dtype
|
||||||
"""
|
"""
|
||||||
self.n_data += n_data
|
self.n_data += n_data
|
||||||
safe_state = {}
|
|
||||||
with torch.no_grad():
|
self.client_state[name] = {}
|
||||||
for k, v in state_dict.items():
|
self.client_n_data[name] = {}
|
||||||
t = v.detach().cpu()
|
self.client_loss[name] = {}
|
||||||
if t.is_floating_point():
|
|
||||||
t = t.to(torch.float32)
|
self.client_state[name].update(state_dict)
|
||||||
safe_state[k] = t
|
|
||||||
self.client_state[name] = safe_state
|
|
||||||
self.client_n_data[name] = int(n_data)
|
self.client_n_data[name] = int(n_data)
|
||||||
self.client_loss[name] = {k: float(v) for k, v in loss_dict.items()}
|
self.client_loss[name] = loss
|
||||||
|
|
||||||
def flush(self):
|
def flush(self):
|
||||||
"""Clear stored client updates."""
|
"""Clear stored client updates."""
|
||||||
@@ -211,3 +129,94 @@ class FedYoloServer(object):
|
|||||||
self.client_state.clear()
|
self.client_state.clear()
|
||||||
self.client_n_data.clear()
|
self.client_n_data.clear()
|
||||||
self.client_loss.clear()
|
self.client_loss.clear()
|
||||||
|
|
||||||
|
def test(self):
|
||||||
|
"""Evaluate the global model on the server's validation dataset."""
|
||||||
|
if self.valset is None:
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
warnings.warn("No validation dataset available for testing.")
|
||||||
|
return {}
|
||||||
|
return test(self.valset, self.params, self.model)
|
||||||
|
|
||||||
|
|
||||||
|
@torch.no_grad()
|
||||||
|
def test(valset: Dataset, params, model: YOLO, batch_size: int = 200) -> tuple[float, float, float, float]:
|
||||||
|
"""
|
||||||
|
Evaluate the model on the validation dataset.
|
||||||
|
Args:
|
||||||
|
valset: validation dataset
|
||||||
|
params: dict of parameters (must include 'names')
|
||||||
|
model: YOLO model to evaluate
|
||||||
|
batch_size: batch size for evaluation
|
||||||
|
Returns:
|
||||||
|
dict with evaluation metrics (tp, fp, m_pre, m_rec, map50, mean_ap)
|
||||||
|
"""
|
||||||
|
loader = DataLoader(
|
||||||
|
dataset=valset,
|
||||||
|
batch_size=batch_size,
|
||||||
|
shuffle=False,
|
||||||
|
num_workers=4,
|
||||||
|
pin_memory=True,
|
||||||
|
collate_fn=Dataset.collate_fn,
|
||||||
|
)
|
||||||
|
|
||||||
|
model.cuda()
|
||||||
|
model.half()
|
||||||
|
model.eval()
|
||||||
|
|
||||||
|
# Configure
|
||||||
|
iou_v = torch.linspace(start=0.5, end=0.95, steps=10).cuda() # iou vector for mAP@0.5:0.95
|
||||||
|
n_iou = iou_v.numel()
|
||||||
|
|
||||||
|
m_pre = 0
|
||||||
|
m_rec = 0
|
||||||
|
map50 = 0
|
||||||
|
mean_ap = 0
|
||||||
|
metrics = []
|
||||||
|
|
||||||
|
for samples, targets in loader:
|
||||||
|
samples = samples.cuda()
|
||||||
|
samples = samples.half() # uint8 to fp16/32
|
||||||
|
samples = samples / 255.0 # 0 - 255 to 0.0 - 1.0
|
||||||
|
_, _, h, w = samples.shape # batch-size, channels, height, width
|
||||||
|
scale = torch.tensor((w, h, w, h)).cuda()
|
||||||
|
# Inference
|
||||||
|
outputs = model(samples)
|
||||||
|
# NMS
|
||||||
|
outputs = util.non_max_suppression(outputs)
|
||||||
|
# Metrics
|
||||||
|
for i, output in enumerate(outputs):
|
||||||
|
idx = targets["idx"]
|
||||||
|
if idx.dim() > 1:
|
||||||
|
idx = idx.squeeze(-1)
|
||||||
|
idx = idx == i
|
||||||
|
# idx = targets["idx"] == i
|
||||||
|
cls = targets["cls"][idx]
|
||||||
|
box = targets["box"][idx]
|
||||||
|
|
||||||
|
cls = cls.cuda()
|
||||||
|
box = box.cuda()
|
||||||
|
|
||||||
|
metric = torch.zeros(output.shape[0], n_iou, dtype=torch.bool).cuda()
|
||||||
|
|
||||||
|
if output.shape[0] == 0:
|
||||||
|
if cls.shape[0]:
|
||||||
|
metrics.append((metric, *torch.zeros((2, 0)).cuda(), cls.squeeze(-1)))
|
||||||
|
continue
|
||||||
|
# Evaluate
|
||||||
|
if cls.shape[0]:
|
||||||
|
target = torch.cat(tensors=(cls, util.wh2xy(box) * scale), dim=1)
|
||||||
|
metric = util.compute_metric(output[:, :6], target, iou_v)
|
||||||
|
# Append
|
||||||
|
metrics.append((metric, output[:, 4], output[:, 5], cls.squeeze(-1)))
|
||||||
|
|
||||||
|
# Compute metrics
|
||||||
|
metrics = [torch.cat(x, dim=0).cpu().numpy() for x in zip(*metrics)] # to numpy
|
||||||
|
if len(metrics) and metrics[0].any():
|
||||||
|
tp, fp, m_pre, m_rec, map50, mean_ap = util.compute_ap(*metrics, plot=False, names=params["names"])
|
||||||
|
# Print results
|
||||||
|
# print(("%10s" + "%10.3g" * 4) % ("", m_pre, m_rec, map50, mean_ap))
|
||||||
|
# Return results
|
||||||
|
model.float() # for training
|
||||||
|
return mean_ap, map50, m_rec, m_pre
|
||||||
|
123
fed_run.py
123
fed_run.py
@@ -3,92 +3,16 @@ import os
|
|||||||
import json
|
import json
|
||||||
import yaml
|
import yaml
|
||||||
import time
|
import time
|
||||||
import random
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
|
|
||||||
from utils.dataset import Dataset
|
from utils.fed_util import build_valset_if_available, seed_everything, plot_curves
|
||||||
from fed_algo_cs.client_base import FedYoloClient
|
from fed_algo_cs.client_base import FedYoloClient
|
||||||
from fed_algo_cs.server_base import FedYoloServer
|
from fed_algo_cs.server_base import FedYoloServer
|
||||||
from utils.args import args_parser # args parser
|
from utils.args import args_parser # args parser
|
||||||
from utils.fed_util import divide_trainset # divide_trainset
|
from utils.fed_util import divide_trainset # divide_trainset
|
||||||
|
|
||||||
|
|
||||||
def _read_list_file(txt_path: str):
|
|
||||||
"""Read one path per line; keep as-is (absolute or relative)."""
|
|
||||||
if not txt_path or not os.path.exists(txt_path):
|
|
||||||
return []
|
|
||||||
with open(txt_path, "r", encoding="utf-8") as f:
|
|
||||||
return [ln.strip() for ln in f if ln.strip()]
|
|
||||||
|
|
||||||
|
|
||||||
def _build_valset_if_available(cfg, params):
|
|
||||||
"""
|
|
||||||
Try to build a validation Dataset.
|
|
||||||
- If cfg['val_txt'] exists, use it.
|
|
||||||
- Else if <dataset_path>/val.txt exists, use it.
|
|
||||||
- Else return None (testing will be skipped).
|
|
||||||
Args:
|
|
||||||
cfg: config dict
|
|
||||||
params: params dict for Dataset
|
|
||||||
Returns:
|
|
||||||
Dataset or None
|
|
||||||
"""
|
|
||||||
input_size = int(cfg.get("input_size", 640))
|
|
||||||
val_txt = cfg.get("val_txt", "")
|
|
||||||
if not val_txt:
|
|
||||||
ds_root = cfg.get("dataset_path", "")
|
|
||||||
guess = os.path.join(ds_root, "val.txt") if ds_root else ""
|
|
||||||
val_txt = guess if os.path.exists(guess) else ""
|
|
||||||
|
|
||||||
val_files = _read_list_file(val_txt)
|
|
||||||
if not val_files:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return Dataset(
|
|
||||||
filenames=val_files,
|
|
||||||
input_size=input_size,
|
|
||||||
params=params,
|
|
||||||
augment=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _seed_everything(seed: int):
|
|
||||||
np.random.seed(seed)
|
|
||||||
torch.manual_seed(seed)
|
|
||||||
random.seed(seed)
|
|
||||||
|
|
||||||
|
|
||||||
def _plot_curves(save_dir, hist):
|
|
||||||
"""
|
|
||||||
Plot mAP50-95, mAP50, precision, recall, and (optional) summed train loss per round.
|
|
||||||
"""
|
|
||||||
os.makedirs(save_dir, exist_ok=True)
|
|
||||||
rounds = np.arange(1, len(hist["mAP"]) + 1)
|
|
||||||
|
|
||||||
plt.figure()
|
|
||||||
if hist["mAP"]:
|
|
||||||
plt.plot(rounds, hist["mAP"], label="mAP50-95")
|
|
||||||
if hist["mAP50"]:
|
|
||||||
plt.plot(rounds, hist["mAP50"], label="mAP50")
|
|
||||||
if hist["precision"]:
|
|
||||||
plt.plot(rounds, hist["precision"], label="precision")
|
|
||||||
if hist["recall"]:
|
|
||||||
plt.plot(rounds, hist["recall"], label="recall")
|
|
||||||
if hist["train_loss"]:
|
|
||||||
plt.plot(rounds, hist["train_loss"], label="train_loss (sum of components)")
|
|
||||||
plt.xlabel("Global Round")
|
|
||||||
plt.ylabel("Metric")
|
|
||||||
plt.title("Federated YOLO - Server Metrics")
|
|
||||||
plt.legend()
|
|
||||||
out_png = os.path.join(save_dir, "fed_yolo_curves.png")
|
|
||||||
plt.savefig(out_png, dpi=150, bbox_inches="tight")
|
|
||||||
print(f"[plot] saved: {out_png}")
|
|
||||||
|
|
||||||
|
|
||||||
def fed_run():
|
def fed_run():
|
||||||
"""
|
"""
|
||||||
Main FL process:
|
Main FL process:
|
||||||
@@ -98,20 +22,22 @@ def fed_run():
|
|||||||
- Record & save results, plot curves
|
- Record & save results, plot curves
|
||||||
"""
|
"""
|
||||||
args_cli = args_parser()
|
args_cli = args_parser()
|
||||||
|
# TODO: cfg and params should not be separately defined
|
||||||
with open(args_cli.config, "r", encoding="utf-8") as f:
|
with open(args_cli.config, "r", encoding="utf-8") as f:
|
||||||
cfg = yaml.safe_load(f)
|
cfg = yaml.safe_load(f)
|
||||||
|
|
||||||
# --- params / config normalization ---
|
# --- params / config normalization ---
|
||||||
# For convenience we pass the same `params` dict used by Dataset/model/loss.
|
# For convenience we pass the same `params` dict used by Dataset/model/loss.
|
||||||
# Here we re-use the top-level cfg directly as params.
|
# Here we re-use the top-level cfg directly as params.
|
||||||
params = dict(cfg)
|
# params = dict(cfg)
|
||||||
|
|
||||||
if "names" in cfg and isinstance(cfg["names"], dict):
|
if "names" in cfg and isinstance(cfg["names"], dict):
|
||||||
# Convert {0: 'uav', 1: 'car', ...} to list if you prefer list
|
# Convert {0: 'uav', 1: 'car', ...} to list if you prefer list
|
||||||
# but we can leave dict; your utils appear to accept dict
|
# but we can leave dict; your utils appear to accept dict
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# seeds
|
# seeds
|
||||||
_seed_everything(int(cfg.get("i_seed", 0)))
|
seed_everything(int(cfg.get("i_seed", 0)))
|
||||||
|
|
||||||
# --- split clients' train data from a global train list ---
|
# --- split clients' train data from a global train list ---
|
||||||
# Expect either cfg["train_txt"] or <dataset_path>/train.txt
|
# Expect either cfg["train_txt"] or <dataset_path>/train.txt
|
||||||
@@ -144,13 +70,13 @@ def fed_run():
|
|||||||
clients = {}
|
clients = {}
|
||||||
|
|
||||||
for uid in users:
|
for uid in users:
|
||||||
c = FedYoloClient(name=uid, model_name=model_name, params=params)
|
c = FedYoloClient(name=uid, model_name=model_name, params=cfg)
|
||||||
c.load_trainset(user_data[uid]["filename"])
|
c.load_trainset(user_data[uid]["filename"])
|
||||||
clients[uid] = c
|
clients[uid] = c
|
||||||
|
|
||||||
# --- build server & optional validation set ---
|
# --- build server & optional validation set ---
|
||||||
server = FedYoloServer(client_list=users, model_name=model_name, params=params)
|
server = FedYoloServer(client_list=users, model_name=model_name, params=cfg)
|
||||||
valset = _build_valset_if_available(cfg, params)
|
valset = build_valset_if_available(cfg, params=cfg, args=args_cli)
|
||||||
# valset is a Dataset class, not data loader
|
# valset is a Dataset class, not data loader
|
||||||
if valset is not None:
|
if valset is not None:
|
||||||
server.load_valset(valset)
|
server.load_valset(valset)
|
||||||
@@ -186,27 +112,25 @@ def fed_run():
|
|||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
# Local training (sequential over all users)
|
# Local training (sequential over all users)
|
||||||
for uid in users:
|
for uid in users:
|
||||||
|
# tqdm desc update
|
||||||
p_bar.set_description_str(("%10s" * 2) % (f"{rnd + 1}/{num_round}", f"{uid}"))
|
p_bar.set_description_str(("%10s" * 2) % (f"{rnd + 1}/{num_round}", f"{uid}"))
|
||||||
|
|
||||||
client = clients[uid] # FedYoloClient instance
|
client = clients[uid] # FedYoloClient instance
|
||||||
client.update(global_state) # load global weights
|
client.update(global_state) # load global weights
|
||||||
state_dict, n_data, loss_dict = client.train(args_cli) # local training
|
state_dict, n_data, train_loss = client.train(args_cli) # local training
|
||||||
server.rec(uid, state_dict, n_data, loss_dict)
|
server.rec(uid, state_dict, n_data, train_loss)
|
||||||
|
|
||||||
# Select a fraction for aggregation (FedAvg subset if desired)
|
# Select a fraction for aggregation (FedAvg subset if desired)
|
||||||
server.select_clients(connection_ratio=connection_ratio)
|
server.select_clients(connection_ratio=connection_ratio)
|
||||||
|
|
||||||
# Aggregate
|
# Aggregate
|
||||||
global_state, avg_loss_dict, _ = server.agg()
|
global_state, avg_loss, _ = server.agg()
|
||||||
|
|
||||||
# Compute a scalar train loss for plotting (sum of components)
|
# Compute a scalar train loss for plotting (sum of components)
|
||||||
scalar_train_loss = float(sum(avg_loss_dict.values())) if avg_loss_dict else 0.0
|
scalar_train_loss = avg_loss if avg_loss else 0.0
|
||||||
|
|
||||||
# Test (if valset provided)
|
# Test (if valset provided)
|
||||||
test_metrics = server.test(args_cli) if server.valset is not None else {}
|
mAP, mAP50, recall, precision = server.test() if server.valset is not None else (0.0, 0.0, 0.0, 0.0)
|
||||||
mAP = float(test_metrics.get("mAP", 0.0))
|
|
||||||
mAP50 = float(test_metrics.get("mAP50", 0.0))
|
|
||||||
precision = float(test_metrics.get("precision", 0.0))
|
|
||||||
recall = float(test_metrics.get("recall", 0.0))
|
|
||||||
|
|
||||||
# Flush per-round client caches
|
# Flush per-round client caches
|
||||||
server.flush()
|
server.flush()
|
||||||
@@ -233,22 +157,23 @@ def fed_run():
|
|||||||
p_bar.set_postfix(desc)
|
p_bar.set_postfix(desc)
|
||||||
|
|
||||||
# Save running JSON (resumable logs)
|
# Save running JSON (resumable logs)
|
||||||
save_name = (
|
save_name = f"{cfg.get('fed_algo', 'FedAvg')}_{[cfg.get('model_name', 'yolo')]}_{cfg.get('num_client', 0)}c_{cfg.get('num_local_class', 1)}cls_{cfg.get('num_round', 0)}r_{cfg.get('connection_ratio', 1):.2f}cr_{cfg.get('i_seed', 0)}s"
|
||||||
f"[{cfg.get('fed_algo', 'FedAvg')},{cfg.get('model_name', 'yolo')},"
|
|
||||||
f"{cfg.get('num_local_epoch', cfg.get('client', {}).get('num_local_epoch', 1))},"
|
|
||||||
f"{cfg.get('num_local_class', 2)},"
|
|
||||||
f"{cfg.get('i_seed', 0)}]"
|
|
||||||
)
|
|
||||||
out_json = os.path.join(res_root, save_name + ".json")
|
out_json = os.path.join(res_root, save_name + ".json")
|
||||||
with open(out_json, "w", encoding="utf-8") as f:
|
with open(out_json, "w", encoding="utf-8") as f:
|
||||||
json.dump(history, f, indent=2)
|
json.dump(history, f, indent=4)
|
||||||
|
|
||||||
p_bar.update(1)
|
p_bar.update(1)
|
||||||
|
|
||||||
p_bar.close()
|
p_bar.close()
|
||||||
|
|
||||||
|
# Save final global model weights
|
||||||
|
if not os.path.exists("./weights"):
|
||||||
|
os.makedirs("./weights", exist_ok=True)
|
||||||
|
torch.save(global_state, f"./weights/{save_name}_final.pth")
|
||||||
|
print(f"[save] final global model weights: ./weights/{save_name}_final.pth")
|
||||||
|
|
||||||
# --- final plot ---
|
# --- final plot ---
|
||||||
_plot_curves(res_root, history)
|
plot_curves(res_root, history, savename=f"{save_name}_curve.png")
|
||||||
print("[done] training complete.")
|
print("[done] training complete.")
|
||||||
|
|
||||||
|
|
||||||
|
2
fed_run.sh
Normal file
2
fed_run.sh
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
#!/usr/bin/env bash
# Launch fed_run.py through torch.distributed.run.
# Usage: bash fed_run.sh <num_gpus> [fed_run.py arguments...]

# Fail early with a usage message when the process count is missing, instead
# of handing an empty --nproc_per_node= to the launcher.
GPUS="${1:?usage: fed_run.sh <num_gpus> [fed_run.py args...]}"

# "${@:2}" forwards every remaining argument unchanged; the quotes keep
# arguments that contain spaces (e.g. paths) as single words.
python3 -m torch.distributed.run --nproc_per_node="$GPUS" fed_run.py "${@:2}"
|
3
nets/__init__.py
Normal file
3
nets/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
from .nn import YOLO, yolo_v11_l, yolo_v11_m, yolo_v11_s, yolo_v11_t, yolo_v11_x, yolo_v11_n
|
||||||
|
|
||||||
|
__all__ = ["YOLO", "yolo_v11_l", "yolo_v11_m", "yolo_v11_s", "yolo_v11_t", "yolo_v11_x", "yolo_v11_n"]
|
77
testcode.py
Normal file
77
testcode.py
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
from utils.fed_util import init_model
|
||||||
|
from fed_algo_cs.server_base import test
|
||||||
|
import os
|
||||||
|
import yaml
|
||||||
|
from utils.args import args_parser # args parser
|
||||||
|
from fed_algo_cs.client_base import FedYoloClient # FedYoloClient
|
||||||
|
from fed_algo_cs.server_base import FedYoloServer # FedYoloServer
|
||||||
|
from utils import Dataset # Dataset
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Ad-hoc smoke test with three independent stages:
    #   1) dump the model structure and state-dict shapes to text files,
    #   2) run server aggregation on synthetic, constant-valued client states,
    #   3) train one client locally and evaluate the resulting weights.
    import torch  # only needed for the synthetic states in stage 2

    # --- 1) model structure dump ---
    model = init_model("yolo_v11_n", num_classes=1)

    with open("model.txt", "w", encoding="utf-8") as f:
        print(model, file=f)

    # One line per state-dict entry: "<key>: <shape>".
    with open("model_key_value.txt", "w", encoding="utf-8") as f:
        for k, v in model.state_dict().items():
            print(f"{k}: {v.shape}", file=f)

    # --- 2) aggregation check ---
    with open("./config/coco128_cfg.yaml", "r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)
    params = dict(cfg)

    server = FedYoloServer(client_list=["c1", "c2", "c3"], model_name="yolo_v11_n", params=params)

    # Fake client states filled with 1, 2 and 3 so the aggregated mean of each
    # tensor can be checked by eye in agg_model.txt.
    # NOTE(review): if agg() weights by n_data (20/30/50), every float entry
    # should come out near 2.3 -- confirm against FedYoloServer.agg.
    state1 = {k: torch.ones_like(v) for k, v in server.model.state_dict().items()}
    state2 = {k: torch.ones_like(v) * 2 for k, v in server.model.state_dict().items()}
    state3 = {k: torch.ones_like(v) * 3 for k, v in server.model.state_dict().items()}
    server.rec("c1", state1, n_data=20, loss=0.1)
    server.rec("c2", state2, n_data=30, loss=0.2)
    server.rec("c3", state3, n_data=50, loss=0.3)
    server.select_clients(connection_ratio=1.0)
    model_state, avg_loss, n_data = server.agg()

    with open("agg_model.txt", "w", encoding="utf-8") as f:
        for k, v in model_state.items():
            print(f"{k}: {v.float().mean()}", file=f)
    print(f"avg_loss: {avg_loss}, n_data: {n_data}")

    # --- 3) single client training (should be the same as standalone training) ---
    args = args_parser()
    with open(args.config, "r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)
    client = FedYoloClient(name="c1", params=cfg, model_name="yolo_v11_n")

    # NOTE: machine-specific dataset location; adjust before running elsewhere.
    data_dir = "/home/image1325/ssd1/dataset/COCO128"
    with open(f"{data_dir}/train.txt") as f:
        # Keep only the file name of each listed image and re-root it under the
        # local images directory. Blank lines are skipped so they do not turn
        # into a bogus "<images dir>/" entry.
        filenames = [
            f"{data_dir}/images/train2017/" + os.path.basename(line.rstrip())
            for line in f
            if line.strip()
        ]

    client.load_trainset(train_dataset=filenames)
    model_state, n_data, avg_loss = client.train(args=args)

    model = init_model("yolo_v11_n", num_classes=80)
    model.load_state_dict(model_state)

    # Evaluate on the training images without augmentation.
    valset = Dataset(
        filenames=filenames,
        input_size=640,
        params=cfg,
        augment=False,
    )

    # The evaluation routine returns (mAP50-95, mAP50, recall, precision) --
    # the same order fed_run.py unpacks from server.test(). Unpacking it as
    # (precision, recall, map50, map) would print every metric under the
    # wrong label. `mean_ap` also avoids shadowing the builtin `map`.
    mean_ap, map50, recall, precision = test(valset=valset, params=cfg, model=model, batch_size=128)
    print(
        f"precision: {precision}, recall: {recall}, map50: {map50}, map: {mean_ap}, loss: {avg_loss}, n_data: {n_data}"
    )
|
3
utils/__init__.py
Normal file
3
utils/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
from .dataset import *
|
||||||
|
from .fed_util import *
|
||||||
|
from .util import *
|
@@ -8,16 +8,11 @@ import torch
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
from torch.utils import data
|
from torch.utils import data
|
||||||
|
|
||||||
FORMATS = "bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp", "JPEG", "JPG", "PNG", "TIFF"
|
FORMATS = "bmp", "dng", "jpeg", "jpg", "mpo", "png", "tif", "tiff", "webp"
|
||||||
|
|
||||||
|
|
||||||
class Dataset(data.Dataset):
|
class Dataset(data.Dataset):
|
||||||
params: dict
|
def __init__(self, filenames, input_size, params, augment):
|
||||||
mosaic: bool
|
|
||||||
augment: bool
|
|
||||||
input_size: int
|
|
||||||
|
|
||||||
def __init__(self, filenames, input_size: int, params: dict, augment: bool):
|
|
||||||
self.params = params
|
self.params = params
|
||||||
self.mosaic = augment
|
self.mosaic = augment
|
||||||
self.augment = augment
|
self.augment = augment
|
||||||
@@ -48,8 +43,6 @@ class Dataset(data.Dataset):
|
|||||||
else:
|
else:
|
||||||
# Load image
|
# Load image
|
||||||
image, shape = self.load_image(index)
|
image, shape = self.load_image(index)
|
||||||
if image is None:
|
|
||||||
raise ValueError(f"Failed to load image at index {index}: {self.filenames[index]}")
|
|
||||||
h, w = image.shape[:2]
|
h, w = image.shape[:2]
|
||||||
|
|
||||||
# Resize
|
# Resize
|
||||||
@@ -57,7 +50,7 @@ class Dataset(data.Dataset):
|
|||||||
|
|
||||||
label = self.labels[index].copy()
|
label = self.labels[index].copy()
|
||||||
if label.size:
|
if label.size:
|
||||||
label[:, 1:] = wh2xy(label[:, 1:], ratio[0] * w, ratio[1] * h, int(pad[0]), int(pad[1]))
|
label[:, 1:] = wh2xy(label[:, 1:], ratio[0] * w, ratio[1] * h, pad[0], pad[1])
|
||||||
if self.augment:
|
if self.augment:
|
||||||
image, label = random_perspective(image, label, self.params)
|
image, label = random_perspective(image, label, self.params)
|
||||||
|
|
||||||
@@ -84,25 +77,24 @@ class Dataset(data.Dataset):
|
|||||||
if nl:
|
if nl:
|
||||||
box[:, 0] = 1 - box[:, 0]
|
box[:, 0] = 1 - box[:, 0]
|
||||||
|
|
||||||
# target_cls = torch.zeros((nl, 1))
|
# XXX: when nl=0, torch.from_numpy(box) will error
|
||||||
# target_box = torch.zeros((nl, 4))
|
|
||||||
# if nl:
|
|
||||||
# target_cls = torch.from_numpy(cls)
|
|
||||||
# target_box = torch.from_numpy(box)
|
|
||||||
|
|
||||||
# fix [cls, box] empty bug. e.g. [0,1] is illegal in DataLoader collate_fn cat operation
|
|
||||||
if nl:
|
if nl:
|
||||||
target_cls = torch.from_numpy(cls).view(-1, 1).float() # always (N,1)
|
target_cls = torch.from_numpy(cls).view(-1, 1).float() # always (N,1)
|
||||||
target_box = torch.from_numpy(box).reshape(-1, 4).float() # always (N,4)
|
target_box = torch.from_numpy(box).reshape(-1, 4).float() # always (N,4)
|
||||||
else:
|
else:
|
||||||
target_cls = torch.zeros((0, 1), dtype=torch.float32)
|
target_cls = torch.zeros((0, 1), dtype=torch.float32)
|
||||||
target_box = torch.zeros((0, 4), dtype=torch.float32)
|
target_box = torch.zeros((0, 4), dtype=torch.float32)
|
||||||
|
# target_cls = torch.zeros((nl, 1))
|
||||||
|
# target_box = torch.zeros((nl, 4))
|
||||||
|
# if nl:
|
||||||
|
# target_cls = torch.from_numpy(cls)
|
||||||
|
# target_box = torch.from_numpy(box)
|
||||||
|
|
||||||
# Convert HWC to CHW, BGR to RGB
|
# Convert HWC to CHW, BGR to RGB
|
||||||
sample = image.transpose((2, 0, 1))[::-1]
|
sample = image.transpose((2, 0, 1))[::-1]
|
||||||
sample = numpy.ascontiguousarray(sample)
|
sample = numpy.ascontiguousarray(sample)
|
||||||
|
|
||||||
# init: return torch.from_numpy(sample), target_cls, target_box, torch.zeros(nl)
|
# return torch.from_numpy(sample), target_cls, target_box, torch.zeros(nl)
|
||||||
return torch.from_numpy(sample), target_cls, target_box, torch.zeros((nl, 1), dtype=torch.long)
|
return torch.from_numpy(sample), target_cls, target_box, torch.zeros((nl, 1), dtype=torch.long)
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
@@ -111,7 +103,7 @@ class Dataset(data.Dataset):
|
|||||||
def load_image(self, i):
|
def load_image(self, i):
|
||||||
image = cv2.imread(self.filenames[i])
|
image = cv2.imread(self.filenames[i])
|
||||||
if image is None:
|
if image is None:
|
||||||
raise ValueError(f"Image not found or unable to open: {self.filenames[i]}")
|
raise FileNotFoundError(f"Image Not Found {self.filenames[i]}")
|
||||||
h, w = image.shape[:2]
|
h, w = image.shape[:2]
|
||||||
r = self.input_size / max(h, w)
|
r = self.input_size / max(h, w)
|
||||||
if r != 1:
|
if r != 1:
|
||||||
@@ -173,8 +165,8 @@ class Dataset(data.Dataset):
|
|||||||
x2b = min(shape[1], x2a - x1a)
|
x2b = min(shape[1], x2a - x1a)
|
||||||
y2b = min(y2a - y1a, shape[0])
|
y2b = min(y2a - y1a, shape[0])
|
||||||
|
|
||||||
pad_w = (x1a if x1a is not None else 0) - (x1b if x1b is not None else 0)
|
pad_w = x1a - x1b
|
||||||
pad_h = (y1a if y1a is not None else 0) - (y1b if y1b is not None else 0)
|
pad_h = y1a - y1b
|
||||||
image4[y1a:y2a, x1a:x2a] = image[y1b:y2b, x1b:x2b]
|
image4[y1a:y2a, x1a:x2a] = image[y1b:y2b, x1b:x2b]
|
||||||
|
|
||||||
# Labels
|
# Labels
|
||||||
@@ -197,14 +189,8 @@ class Dataset(data.Dataset):
|
|||||||
def collate_fn(batch):
|
def collate_fn(batch):
|
||||||
samples, cls, box, indices = zip(*batch)
|
samples, cls, box, indices = zip(*batch)
|
||||||
|
|
||||||
# ensure empty tensor shape is correct
|
cls = torch.cat(cls, dim=0)
|
||||||
cls = [c.view(-1, 1) for c in cls]
|
box = torch.cat(box, dim=0)
|
||||||
box = [b.reshape(-1, 4) for b in box]
|
|
||||||
indices = [i for i in indices]
|
|
||||||
|
|
||||||
cls = torch.cat(cls, dim=0) if cls else torch.zeros((0, 1))
|
|
||||||
box = torch.cat(box, dim=0) if box else torch.zeros((0, 4))
|
|
||||||
indices = torch.cat(indices, dim=0) if indices else torch.zeros((0,), dtype=torch.long)
|
|
||||||
|
|
||||||
new_indices = list(indices)
|
new_indices = list(indices)
|
||||||
for i in range(len(indices)):
|
for i in range(len(indices)):
|
||||||
@@ -215,7 +201,7 @@ class Dataset(data.Dataset):
|
|||||||
return torch.stack(samples, dim=0), targets
|
return torch.stack(samples, dim=0), targets
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def load_label_use_cache(filenames):
|
def load_label(filenames):
|
||||||
path = f"{os.path.dirname(filenames[0])}.cache"
|
path = f"{os.path.dirname(filenames[0])}.cache"
|
||||||
if os.path.exists(path):
|
if os.path.exists(path):
|
||||||
return torch.load(path, weights_only=False)
|
return torch.load(path, weights_only=False)
|
||||||
@@ -228,14 +214,11 @@ class Dataset(data.Dataset):
|
|||||||
image.verify() # PIL verify
|
image.verify() # PIL verify
|
||||||
shape = image.size # image size
|
shape = image.size # image size
|
||||||
assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
|
assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
|
||||||
assert image.format is not None and image.format.lower() in FORMATS, (
|
assert image.format.lower() in FORMATS, f"invalid image format {image.format}"
|
||||||
f"invalid image format {image.format}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# verify labels
|
# verify labels
|
||||||
a = f"{os.sep}images{os.sep}"
|
a = f"{os.sep}images{os.sep}"
|
||||||
b = f"{os.sep}labels{os.sep}"
|
b = f"{os.sep}labels{os.sep}"
|
||||||
|
|
||||||
if os.path.isfile(b.join(filename.rsplit(a, 1)).rsplit(".", 1)[0] + ".txt"):
|
if os.path.isfile(b.join(filename.rsplit(a, 1)).rsplit(".", 1)[0] + ".txt"):
|
||||||
with open(b.join(filename.rsplit(a, 1)).rsplit(".", 1)[0] + ".txt") as f:
|
with open(b.join(filename.rsplit(a, 1)).rsplit(".", 1)[0] + ".txt") as f:
|
||||||
label = [x.split() for x in f.read().strip().splitlines() if len(x)]
|
label = [x.split() for x in f.read().strip().splitlines() if len(x)]
|
||||||
@@ -260,50 +243,6 @@ class Dataset(data.Dataset):
|
|||||||
torch.save(x, path)
|
torch.save(x, path)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def load_label(filenames):
|
|
||||||
x = {}
|
|
||||||
for filename in filenames:
|
|
||||||
try:
|
|
||||||
# verify images
|
|
||||||
with open(filename, "rb") as f:
|
|
||||||
image = Image.open(f)
|
|
||||||
image.verify()
|
|
||||||
shape = image.size
|
|
||||||
assert (shape[0] > 9) & (shape[1] > 9), f"image size {shape} <10 pixels"
|
|
||||||
assert image.format is not None and image.format.lower() in FORMATS, (
|
|
||||||
f"invalid image format {image.format}"
|
|
||||||
)
|
|
||||||
|
|
||||||
# verify labels
|
|
||||||
a = f"{os.sep}images{os.sep}"
|
|
||||||
b = f"{os.sep}labels{os.sep}"
|
|
||||||
label_path = b.join(filename.rsplit(a, 1)).rsplit(".", 1)[0] + ".txt"
|
|
||||||
|
|
||||||
if os.path.isfile(label_path):
|
|
||||||
rows = []
|
|
||||||
with open(label_path) as f:
|
|
||||||
for line in f:
|
|
||||||
parts = line.strip().split()
|
|
||||||
if len(parts) == 5: # YOLO format
|
|
||||||
rows.append([float(x) for x in parts])
|
|
||||||
label = numpy.array(rows, dtype=numpy.float32) if rows else numpy.zeros((0, 5), dtype=numpy.float32)
|
|
||||||
|
|
||||||
if label.shape[0]:
|
|
||||||
assert (label >= 0).all()
|
|
||||||
assert label.shape[1] == 5
|
|
||||||
assert (label[:, 1:] <= 1.0001).all()
|
|
||||||
_, i = numpy.unique(label, axis=0, return_index=True)
|
|
||||||
label = label[i]
|
|
||||||
else:
|
|
||||||
label = numpy.zeros((0, 5), dtype=numpy.float32)
|
|
||||||
|
|
||||||
except (FileNotFoundError, AssertionError):
|
|
||||||
label = numpy.zeros((0, 5), dtype=numpy.float32)
|
|
||||||
|
|
||||||
x[filename] = label
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
def wh2xy(x, w=640, h=640, pad_w=0, pad_h=0):
|
def wh2xy(x, w=640, h=640, pad_w=0, pad_h=0):
|
||||||
# Convert nx4 boxes
|
# Convert nx4 boxes
|
||||||
@@ -462,9 +401,7 @@ class Albumentations:
|
|||||||
albumentations.ToGray(p=0.01),
|
albumentations.ToGray(p=0.01),
|
||||||
albumentations.MedianBlur(p=0.01),
|
albumentations.MedianBlur(p=0.01),
|
||||||
]
|
]
|
||||||
self.transform = albumentations.Compose(
|
self.transform = albumentations.Compose(transforms, albumentations.BboxParams("yolo", ["class_labels"]))
|
||||||
transforms, albumentations.BboxParams(format="yolo", label_fields=["class_labels"])
|
|
||||||
)
|
|
||||||
|
|
||||||
except ImportError: # package not installed, skip
|
except ImportError: # package not installed, skip
|
||||||
pass
|
pass
|
||||||
|
@@ -1,10 +1,15 @@
|
|||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import random
|
import random
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
from utils.dataset import Dataset
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from typing import Dict, List, Optional, Set, Any
|
from typing import Dict, List, Optional, Set, Any
|
||||||
|
|
||||||
from nets import nn
|
from nets import nn
|
||||||
|
from nets import YOLO
|
||||||
|
|
||||||
|
|
||||||
def _image_to_label_path(img_path: str) -> str:
|
def _image_to_label_path(img_path: str) -> str:
|
||||||
@@ -59,6 +64,14 @@ def _parse_yolo_label_file(label_path: str) -> Set[int]:
|
|||||||
return class_ids
|
return class_ids
|
||||||
|
|
||||||
|
|
||||||
|
def _read_list_file(txt_path: str) -> List[str]:
    """Read one path per line; keep as-is (absolute or relative).

    Args:
        txt_path: path of a UTF-8 text file that lists one entry per line.

    Returns:
        The non-blank lines with surrounding whitespace removed. An empty list
        is returned when ``txt_path`` is empty or does not name a regular file.
    """
    # isfile (rather than exists) so that a directory path yields [] instead
    # of raising when it is passed to open().
    if not txt_path or not os.path.isfile(txt_path):
        return []
    with open(txt_path, "r", encoding="utf-8") as f:
        stripped = (line.strip() for line in f)  # strip each line only once
        return [entry for entry in stripped if entry]
|
||||||
|
|
||||||
|
|
||||||
def divide_trainset(
|
def divide_trainset(
|
||||||
trainset_path: str,
|
trainset_path: str,
|
||||||
num_local_class: int,
|
num_local_class: int,
|
||||||
@@ -230,7 +243,7 @@ def divide_trainset(
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
def init_model(model_name, num_classes):
|
def init_model(model_name, num_classes) -> YOLO:
|
||||||
"""
|
"""
|
||||||
Initialize the model for a specific learning task
|
Initialize the model for a specific learning task
|
||||||
Args:
|
Args:
|
||||||
@@ -252,3 +265,74 @@ def init_model(model_name, num_classes):
|
|||||||
raise ValueError("Model {} is not supported.".format(model_name))
|
raise ValueError("Model {} is not supported.".format(model_name))
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def build_valset_if_available(cfg, params, args=None) -> Optional[Dataset]:
    """
    Try to build a validation Dataset.
    - If cfg['val_txt'] exists, use it.
    - Else if <dataset_path>/val.txt exists, use it.
    - Else return None (testing will be skipped).
    Args:
        cfg: config dict; the keys "val_txt" and "dataset_path" are read
        params: params dict for Dataset
        args: optional object with an ``input_size`` attribute (e.g. the parsed
            CLI arguments); 640 is used when it is absent
    Returns:
        Dataset or None
    """
    input_size = args.input_size if args and hasattr(args, "input_size") else 640

    val_txt = cfg.get("val_txt", "")
    if not val_txt:
        # Fall back to the conventional location under the dataset root.
        ds_root = cfg.get("dataset_path", "")
        guess = os.path.join(ds_root, "val.txt") if ds_root else ""
        val_txt = guess if os.path.exists(guess) else ""

    val_files = _read_list_file(val_txt)
    if not val_files:
        import warnings

        warnings.warn("No validation dataset found.")
        return None

    # Validation must not be augmented: Dataset turns on mosaic and random
    # perspective when augment is true, which would make the reported metrics
    # random and not comparable between rounds.
    return Dataset(
        filenames=val_files,
        input_size=input_size,
        params=params,
        augment=False,
    )
|
||||||
|
|
||||||
|
|
||||||
|
def seed_everything(seed: int):
    """Seed Python's ``random``, NumPy's global RNG and PyTorch with ``seed``."""
    # The three generators are independent, so the seeding order is irrelevant.
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
|
||||||
|
|
||||||
|
|
||||||
|
def plot_curves(save_dir, hist, savename="fed_yolo_curves.png"):
    """
    Plot mAP50-95, mAP50, precision, recall, and (optional) summed train loss per round.
    Args:
        save_dir: directory to save the plot (created if missing)
        hist: history dict with keys "mAP", "mAP50", "precision", "recall", "train_loss";
            each maps to a list with one value per round. Missing or empty
            entries are skipped.
        savename: output filename
    """
    # (history key, legend label) in drawing order.
    series_labels = (
        ("mAP", "mAP50-95"),
        ("mAP50", "mAP50"),
        ("precision", "precision"),
        ("recall", "recall"),
        ("train_loss", "train_loss (sum of components)"),
    )

    os.makedirs(save_dir, exist_ok=True)

    fig = plt.figure()
    try:
        plotted = False
        for key, label in series_labels:
            values = hist.get(key)
            if not values:
                continue
            # Size the x-axis from the series itself so that lists of different
            # lengths do not raise a shape mismatch.
            plt.plot(np.arange(1, len(values) + 1), values, label=label)
            plotted = True

        plt.xlabel("Global Round")
        plt.ylabel("Metric")
        plt.title("Federated YOLO - Server Metrics")
        if plotted:
            # legend() with no labelled artists only emits a warning
            plt.legend()

        out_png = os.path.join(save_dir, savename)
        plt.savefig(out_png, dpi=150, bbox_inches="tight")
        print(f"[plot] saved: {out_png}")
    finally:
        # Release the figure; pyplot keeps every figure alive until closed.
        plt.close(fig)
|
||||||
|
@@ -1,7 +1,3 @@
|
|||||||
"""
|
|
||||||
Utility functions for yolo.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
import random
|
import random
|
||||||
from time import time
|
from time import time
|
||||||
@@ -97,7 +93,7 @@ def make_anchors(x, strides, offset=0.5):
|
|||||||
_, _, h, w = x[i].shape
|
_, _, h, w = x[i].shape
|
||||||
sx = torch.arange(end=w, device=device, dtype=dtype) + offset # shift x
|
sx = torch.arange(end=w, device=device, dtype=dtype) + offset # shift x
|
||||||
sy = torch.arange(end=h, device=device, dtype=dtype) + offset # shift y
|
sy = torch.arange(end=h, device=device, dtype=dtype) + offset # shift y
|
||||||
sy, sx = torch.meshgrid(sy, sx, indexing="ij")
|
sy, sx = torch.meshgrid(sy, sx)
|
||||||
anchor_tensor.append(torch.stack((sx, sy), -1).view(-1, 2))
|
anchor_tensor.append(torch.stack((sx, sy), -1).view(-1, 2))
|
||||||
stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device))
|
stride_tensor.append(torch.full((h * w, 1), stride, dtype=dtype, device=device))
|
||||||
return torch.cat(anchor_tensor), torch.cat(stride_tensor)
|
return torch.cat(anchor_tensor), torch.cat(stride_tensor)
|
||||||
@@ -151,7 +147,7 @@ def non_max_suppression(outputs, confidence_threshold=0.001, iou_threshold=0.65)
|
|||||||
box = wh2xy(box) # (cx, cy, w, h) to (x1, y1, x2, y2)
|
box = wh2xy(box) # (cx, cy, w, h) to (x1, y1, x2, y2)
|
||||||
if nc > 1:
|
if nc > 1:
|
||||||
i, j = (cls > confidence_threshold).nonzero(as_tuple=False).T
|
i, j = (cls > confidence_threshold).nonzero(as_tuple=False).T
|
||||||
x = torch.cat((box[i], x[i, 4 + j].unsqueeze(1), j[:, None].float()), dim=1)
|
x = torch.cat((box[i], x[i, 4 + j, None], j[:, None].float()), 1)
|
||||||
else: # best class only
|
else: # best class only
|
||||||
conf, j = cls.max(1, keepdim=True)
|
conf, j = cls.max(1, keepdim=True)
|
||||||
x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > confidence_threshold]
|
x = torch.cat((box, conf, j.float()), 1)[conf.view(-1) > confidence_threshold]
|
||||||
@@ -195,13 +191,7 @@ def plot_pr_curve(px, py, ap, names, save_dir):
|
|||||||
else:
|
else:
|
||||||
ax.plot(px, py, linewidth=1, color="grey") # plot(recall, precision)
|
ax.plot(px, py, linewidth=1, color="grey") # plot(recall, precision)
|
||||||
|
|
||||||
ax.plot(
|
ax.plot(px, py.mean(1), linewidth=3, color="blue", label="all classes %.3f mAP@0.5" % ap[:, 0].mean())
|
||||||
px,
|
|
||||||
py.mean(1),
|
|
||||||
linewidth=3,
|
|
||||||
color="blue",
|
|
||||||
label="all classes %.3f mAP@0.5" % ap[:, 0].mean(),
|
|
||||||
)
|
|
||||||
ax.set_xlabel("Recall")
|
ax.set_xlabel("Recall")
|
||||||
ax.set_ylabel("Precision")
|
ax.set_ylabel("Precision")
|
||||||
ax.set_xlim(0, 1)
|
ax.set_xlim(0, 1)
|
||||||
@@ -224,13 +214,7 @@ def plot_curve(px, py, names, save_dir, x_label="Confidence", y_label="Metric"):
|
|||||||
ax.plot(px, py.T, linewidth=1, color="grey") # plot(confidence, metric)
|
ax.plot(px, py.T, linewidth=1, color="grey") # plot(confidence, metric)
|
||||||
|
|
||||||
y = smooth(py.mean(0), f=0.05)
|
y = smooth(py.mean(0), f=0.05)
|
||||||
ax.plot(
|
ax.plot(px, y, linewidth=3, color="blue", label=f"all classes {y.max():.3f} at {px[y.argmax()]:.3f}")
|
||||||
px,
|
|
||||||
y,
|
|
||||||
linewidth=3,
|
|
||||||
color="blue",
|
|
||||||
label=f"all classes {y.max():.3f} at {px[y.argmax()]:.3f}",
|
|
||||||
)
|
|
||||||
ax.set_xlabel(x_label)
|
ax.set_xlabel(x_label)
|
||||||
ax.set_ylabel(y_label)
|
ax.set_ylabel(y_label)
|
||||||
ax.set_xlim(0, 1)
|
ax.set_xlim(0, 1)
|
||||||
@@ -296,8 +280,7 @@ def compute_ap(tp, conf, output, target, plot=False, names=(), eps=1e-16):
|
|||||||
|
|
||||||
# Integrate area under curve
|
# Integrate area under curve
|
||||||
x = numpy.linspace(start=0, stop=1, num=101) # 101-point interp (COCO)
|
x = numpy.linspace(start=0, stop=1, num=101) # 101-point interp (COCO)
|
||||||
# numpy.trapz is deprecated in numpy 2.0.0 or after version, use numpy.trapezoid instead
|
ap[ci, j] = numpy.trapz(numpy.interp(x, m_rec, m_pre), x) # integrate
|
||||||
ap[ci, j] = numpy.trapezoid(numpy.interp(x, m_rec, m_pre), x) # integrate
|
|
||||||
if plot and j == 0:
|
if plot and j == 0:
|
||||||
py.append(numpy.interp(px, m_rec, m_pre)) # precision at mAP@0.5
|
py.append(numpy.interp(px, m_rec, m_pre)) # precision at mAP@0.5
|
||||||
|
|
||||||
@@ -443,7 +426,7 @@ class LinearLR:
|
|||||||
min_lr = params["min_lr"]
|
min_lr = params["min_lr"]
|
||||||
|
|
||||||
warmup_steps = int(max(params["warmup_epochs"] * num_steps, 100))
|
warmup_steps = int(max(params["warmup_epochs"] * num_steps, 100))
|
||||||
decay_steps = max(1, int(args.epochs * num_steps - warmup_steps))
|
decay_steps = int(args.epochs * num_steps - warmup_steps)
|
||||||
|
|
||||||
warmup_lr = numpy.linspace(min_lr, max_lr, int(warmup_steps), endpoint=False)
|
warmup_lr = numpy.linspace(min_lr, max_lr, int(warmup_steps), endpoint=False)
|
||||||
decay_lr = numpy.linspace(max_lr, min_lr, decay_steps)
|
decay_lr = numpy.linspace(max_lr, min_lr, decay_steps)
|
||||||
@@ -528,16 +511,8 @@ class Assigner(torch.nn.Module):
|
|||||||
mask_in_gts = mask_in_gts.view(shape[0], shape[1], num_anchors, -1).amin(3).gt_(self.eps)
|
mask_in_gts = mask_in_gts.view(shape[0], shape[1], num_anchors, -1).amin(3).gt_(self.eps)
|
||||||
na = pd_bboxes.shape[-2]
|
na = pd_bboxes.shape[-2]
|
||||||
gt_mask = (mask_in_gts * mask_gt).bool() # b, max_num_obj, h*w
|
gt_mask = (mask_in_gts * mask_gt).bool() # b, max_num_obj, h*w
|
||||||
overlaps = torch.zeros(
|
overlaps = torch.zeros([batch_size, num_max_boxes, na], dtype=pd_bboxes.dtype, device=pd_bboxes.device)
|
||||||
[batch_size, num_max_boxes, na],
|
bbox_scores = torch.zeros([batch_size, num_max_boxes, na], dtype=pd_scores.dtype, device=pd_scores.device)
|
||||||
dtype=pd_bboxes.dtype,
|
|
||||||
device=pd_bboxes.device,
|
|
||||||
)
|
|
||||||
bbox_scores = torch.zeros(
|
|
||||||
[batch_size, num_max_boxes, na],
|
|
||||||
dtype=pd_scores.dtype,
|
|
||||||
device=pd_scores.device,
|
|
||||||
)
|
|
||||||
|
|
||||||
ind = torch.zeros([2, batch_size, num_max_boxes], dtype=torch.long) # 2, b, max_num_obj
|
ind = torch.zeros([2, batch_size, num_max_boxes], dtype=torch.long) # 2, b, max_num_obj
|
||||||
ind[0] = torch.arange(end=batch_size).view(-1, 1).expand(-1, num_max_boxes) # b, max_num_obj
|
ind[0] = torch.arange(end=batch_size).view(-1, 1).expand(-1, num_max_boxes) # b, max_num_obj
|
||||||
@@ -587,9 +562,7 @@ class Assigner(torch.nn.Module):
|
|||||||
target_labels.clamp_(0)
|
target_labels.clamp_(0)
|
||||||
|
|
||||||
target_scores = torch.zeros(
|
target_scores = torch.zeros(
|
||||||
(target_labels.shape[0], target_labels.shape[1], self.nc),
|
(target_labels.shape[0], target_labels.shape[1], self.nc), dtype=torch.int64, device=target_labels.device
|
||||||
dtype=torch.int64,
|
|
||||||
device=target_labels.device,
|
|
||||||
)
|
)
|
||||||
target_scores.scatter_(2, target_labels.unsqueeze(-1), 1)
|
target_scores.scatter_(2, target_labels.unsqueeze(-1), 1)
|
||||||
|
|
||||||
@@ -672,16 +645,7 @@ class BoxLoss(torch.nn.Module):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.dfl_ch = dfl_ch
|
self.dfl_ch = dfl_ch
|
||||||
|
|
||||||
def forward(
|
def forward(self, pred_dist, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask):
|
||||||
self,
|
|
||||||
pred_dist,
|
|
||||||
pred_bboxes,
|
|
||||||
anchor_points,
|
|
||||||
target_bboxes,
|
|
||||||
target_scores,
|
|
||||||
target_scores_sum,
|
|
||||||
fg_mask,
|
|
||||||
):
|
|
||||||
# IoU loss
|
# IoU loss
|
||||||
weight = torch.masked_select(target_scores.sum(-1), fg_mask).unsqueeze(-1)
|
weight = torch.masked_select(target_scores.sum(-1), fg_mask).unsqueeze(-1)
|
||||||
iou = compute_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask])
|
iou = compute_iou(pred_bboxes[fg_mask], target_bboxes[fg_mask])
|
||||||
@@ -803,13 +767,7 @@ class ComputeLoss:
|
|||||||
if fg_mask.sum():
|
if fg_mask.sum():
|
||||||
target_bboxes /= stride_tensor
|
target_bboxes /= stride_tensor
|
||||||
loss_box, loss_dfl = self.box_loss(
|
loss_box, loss_dfl = self.box_loss(
|
||||||
pred_distri,
|
pred_distri, pred_bboxes, anchor_points, target_bboxes, target_scores, target_scores_sum, fg_mask
|
||||||
pred_bboxes,
|
|
||||||
anchor_points,
|
|
||||||
target_bboxes,
|
|
||||||
target_scores,
|
|
||||||
target_scores_sum,
|
|
||||||
fg_mask,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
loss_box *= self.params["box"] # box gain
|
loss_box *= self.params["box"] # box gain
|
||||||
|
Reference in New Issue
Block a user