优化fed_run函数中的进度条显示和训练过程中的日志记录
This commit is contained in:
@@ -3,6 +3,7 @@ import torch
|
||||
from torch import nn
|
||||
from torch.utils import data
|
||||
from torch.amp.autocast_mode import autocast
|
||||
from tqdm import tqdm
|
||||
from utils.fed_util import init_model
|
||||
from utils import util
|
||||
from utils.dataset import Dataset
|
||||
@@ -152,7 +153,6 @@ class FedYoloClient(object):
|
||||
|
||||
# Scheduler
|
||||
num_steps = max(1, len(loader))
|
||||
# print(len(loader))
|
||||
scheduler = util.LinearLR(args=args, params=self.params, num_steps=num_steps)
|
||||
# DDP mode
|
||||
if args.distributed:
|
||||
@@ -167,7 +167,12 @@ class FedYoloClient(object):
|
||||
amp_scale = torch.amp.grad_scaler.GradScaler(enabled=True)
|
||||
criterion = util.ComputeLoss(self.model, self.params)
|
||||
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
# log
|
||||
# if args.local_rank == 0:
|
||||
# header = ("%10s" * 5) % ("client", "memory", "box", "cls", "dfl")
|
||||
# print("\n" + header)
|
||||
# p_bar = tqdm(total=args.epochs * num_steps, ncols=120)
|
||||
# p_bar.set_description(f"{self.name:>10}")
|
||||
|
||||
for epoch in range(args.epochs):
|
||||
self.model.train()
|
||||
@@ -180,10 +185,20 @@ class FedYoloClient(object):
|
||||
ds = cast(Dataset, loader.dataset)
|
||||
ds.mosaic = False
|
||||
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
avg_box_loss = util.AverageMeter()
|
||||
avg_cls_loss = util.AverageMeter()
|
||||
avg_dfl_loss = util.AverageMeter()
|
||||
|
||||
# # --- header (once per epoch, YOLO-style) ---
|
||||
# if args.local_rank == 0:
|
||||
# header = ("%10s" * 5) % ("client", "memory", "box", "cls", "dfl")
|
||||
# print("\n" + header)
|
||||
|
||||
# p_bar = enumerate(loader)
|
||||
# if args.local_rank == 0:
|
||||
# p_bar = tqdm(p_bar, total=num_steps, ncols=120)
|
||||
|
||||
for i, (samples, targets) in enumerate(loader):
|
||||
global_step = i + num_steps * epoch
|
||||
scheduler.step(step=global_step, optimizer=optimizer)
|
||||
@@ -195,24 +210,26 @@ class FedYoloClient(object):
|
||||
outputs = self.model(samples)
|
||||
box_loss, cls_loss, dfl_loss = criterion(outputs, targets)
|
||||
|
||||
# meters (use the *unscaled* values)
|
||||
bs = samples.size(0)
|
||||
avg_box_loss.update(box_loss.item(), bs)
|
||||
avg_cls_loss.update(cls_loss.item(), bs)
|
||||
avg_dfl_loss.update(dfl_loss.item(), bs)
|
||||
# meters (use the *unscaled* values)
|
||||
bs = samples.size(0)
|
||||
avg_box_loss.update(box_loss.item(), bs)
|
||||
avg_cls_loss.update(cls_loss.item(), bs)
|
||||
avg_dfl_loss.update(dfl_loss.item(), bs)
|
||||
|
||||
# scale losses by batch/world if your loss is averaged internally per-sample/device
|
||||
box_loss = box_loss * self._batch_size * args.world_size
|
||||
cls_loss = cls_loss * self._batch_size * args.world_size
|
||||
dfl_loss = dfl_loss * self._batch_size * args.world_size
|
||||
# scale losses by batch/world if your loss is averaged internally per-sample/device
|
||||
# box_loss = box_loss * self._batch_size * args.world_size
|
||||
# cls_loss = cls_loss * self._batch_size * args.world_size
|
||||
# dfl_loss = dfl_loss * self._batch_size * args.world_size
|
||||
|
||||
total_loss = box_loss + cls_loss + dfl_loss
|
||||
total_loss = box_loss + cls_loss + dfl_loss
|
||||
|
||||
# Backward
|
||||
amp_scale.scale(total_loss).backward()
|
||||
|
||||
# Optimize
|
||||
if (i + 1) % accumulate == 0:
|
||||
amp_scale.unscale_(optimizer) # unscale gradients
|
||||
util.clip_gradients(model=self.model, max_norm=10.0) # clip gradients
|
||||
amp_scale.step(optimizer)
|
||||
amp_scale.update()
|
||||
optimizer.zero_grad(set_to_none=True)
|
||||
@@ -221,13 +238,28 @@ class FedYoloClient(object):
|
||||
|
||||
# torch.cuda.synchronize()
|
||||
|
||||
# tqdm update
|
||||
# if args.local_rank == 0:
|
||||
# mem = f"{torch.cuda.memory_reserved() / 1e9:.2f}G" if torch.cuda.is_available() else "0.00G"
|
||||
# desc = ("%10s" * 2 + "%10.4g" * 3) % (
|
||||
# self.name,
|
||||
# mem,
|
||||
# avg_box_loss.avg,
|
||||
# avg_cls_loss.avg,
|
||||
# avg_dfl_loss.avg,
|
||||
# )
|
||||
# cast(tqdm, p_bar).set_description(desc)
|
||||
# p_bar.update(1)
|
||||
|
||||
# p_bar.close()
|
||||
|
||||
# clean
|
||||
if args.distributed:
|
||||
torch.distributed.destroy_process_group()
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
return (
|
||||
self.model.state_dict(),
|
||||
self.model.state_dict() if not ema else ema.ema.state_dict(),
|
||||
self.n_data,
|
||||
{"box_loss": avg_box_loss.avg, "cls_loss": avg_cls_loss.avg, "dfl_loss": avg_dfl_loss.avg},
|
||||
)
|
||||
|
Reference in New Issue
Block a user