diff --git a/fed_algo_cs/client_base.py b/fed_algo_cs/client_base.py
index 433d6b3..d8777ff 100644
--- a/fed_algo_cs/client_base.py
+++ b/fed_algo_cs/client_base.py
@@ -3,6 +3,7 @@ import torch
 from torch import nn
 from torch.utils import data
 from torch.amp.autocast_mode import autocast
+from tqdm import tqdm
 from utils.fed_util import init_model
 from utils import util
 from utils.dataset import Dataset
@@ -152,7 +153,6 @@ class FedYoloClient(object):
         # Scheduler
         num_steps = max(1, len(loader))
-        # print(len(loader))
         scheduler = util.LinearLR(args=args, params=self.params, num_steps=num_steps)
         # DDP mode
         if args.distributed:
@@ -167,7 +167,12 @@
         amp_scale = torch.amp.grad_scaler.GradScaler(enabled=True)
         criterion = util.ComputeLoss(self.model, self.params)
 
-        optimizer.zero_grad(set_to_none=True)
+        # log
+        # if args.local_rank == 0:
+        #     header = ("%10s" * 5) % ("client", "memory", "box", "cls", "dfl")
+        #     print("\n" + header)
+        #     p_bar = tqdm(total=args.epochs * num_steps, ncols=120)
+        #     p_bar.set_description(f"{self.name:>10}")
 
         for epoch in range(args.epochs):
             self.model.train()
@@ -180,10 +185,20 @@
                 ds = cast(Dataset, loader.dataset)
                 ds.mosaic = False
 
+            optimizer.zero_grad(set_to_none=True)
             avg_box_loss = util.AverageMeter()
             avg_cls_loss = util.AverageMeter()
             avg_dfl_loss = util.AverageMeter()
 
+            # # --- header (once per epoch, YOLO-style) ---
+            # if args.local_rank == 0:
+            #     header = ("%10s" * 5) % ("client", "memory", "box", "cls", "dfl")
+            #     print("\n" + header)
+
+            # p_bar = enumerate(loader)
+            # if args.local_rank == 0:
+            #     p_bar = tqdm(p_bar, total=num_steps, ncols=120)
+
             for i, (samples, targets) in enumerate(loader):
                 global_step = i + num_steps * epoch
                 scheduler.step(step=global_step, optimizer=optimizer)
@@ -195,24 +210,26 @@
                 outputs = self.model(samples)
                 box_loss, cls_loss, dfl_loss = criterion(outputs, targets)
 
-                # meters (use the *unscaled* values)
-                bs = samples.size(0)
-                avg_box_loss.update(box_loss.item(), bs)
-                avg_cls_loss.update(cls_loss.item(), bs)
-                avg_dfl_loss.update(dfl_loss.item(), bs)
+            # meters (use the *unscaled* values)
+            bs = samples.size(0)
+            avg_box_loss.update(box_loss.item(), bs)
+            avg_cls_loss.update(cls_loss.item(), bs)
+            avg_dfl_loss.update(dfl_loss.item(), bs)
 
-                # scale losses by batch/world if your loss is averaged internally per-sample/device
-                box_loss = box_loss * self._batch_size * args.world_size
-                cls_loss = cls_loss * self._batch_size * args.world_size
-                dfl_loss = dfl_loss * self._batch_size * args.world_size
+            # scale losses by batch/world if your loss is averaged internally per-sample/device
+            # box_loss = box_loss * self._batch_size * args.world_size
+            # cls_loss = cls_loss * self._batch_size * args.world_size
+            # dfl_loss = dfl_loss * self._batch_size * args.world_size
 
-                total_loss = box_loss + cls_loss + dfl_loss
+            total_loss = box_loss + cls_loss + dfl_loss
 
             # Backward
             amp_scale.scale(total_loss).backward()
 
             # Optimize
             if (i + 1) % accumulate == 0:
+                amp_scale.unscale_(optimizer)  # unscale gradients
+                util.clip_gradients(model=self.model, max_norm=10.0)  # clip gradients
                 amp_scale.step(optimizer)
                 amp_scale.update()
                 optimizer.zero_grad(set_to_none=True)
@@ -221,13 +238,28 @@
 
                 # torch.cuda.synchronize()
 
 
+                # tqdm update
+                # if args.local_rank == 0:
+                #     mem = f"{torch.cuda.memory_reserved() / 1e9:.2f}G" if torch.cuda.is_available() else "0.00G"
+                #     desc = ("%10s" * 2 + "%10.4g" * 3) % (
+                #         self.name,
+                #         mem,
+                #         avg_box_loss.avg,
+                #         avg_cls_loss.avg,
+                #         avg_dfl_loss.avg,
+                #     )
+                #     cast(tqdm, p_bar).set_description(desc)
+                #     p_bar.update(1)
+
+        # p_bar.close()
+
+        # clean
         if args.distributed:
             torch.distributed.destroy_process_group()
         torch.cuda.empty_cache()
 
         return (
-            self.model.state_dict(),
+            self.model.state_dict() if not ema else ema.ema.state_dict(),
             self.n_data,
             {"box_loss": avg_box_loss.avg, "cls_loss": avg_cls_loss.avg, "dfl_loss": avg_dfl_loss.avg},
         )
diff --git a/fed_run.py b/fed_run.py
index 51b35ce..103d5df 100644
--- a/fed_run.py
+++ b/fed_run.py
@@ -13,8 +13,8 @@ import matplotlib.pyplot as plt
 from utils.dataset import Dataset
 from fed_algo_cs.client_base import FedYoloClient
 from fed_algo_cs.server_base import FedYoloServer
-from utils.args import args_parser  # your args parser
-from utils.fed_util import divide_trainset  # divide_trainset is yours
+from utils.args import args_parser  # args parser
+from utils.fed_util import divide_trainset  # divide_trainset
 
 
 def _read_list_file(txt_path: str):
@@ -132,7 +132,7 @@
         num_client=int(cfg.get("num_client", 64)),
         min_data=int(cfg.get("min_data", 100)),
         max_data=int(cfg.get("max_data", 100)),
-        mode=str(cfg.get("partition_mode", "disjoint")),  # "overlap" or "disjoint"
+        mode=str(cfg.get("partition_mode", "overlap")),  # "overlap" or "disjoint"
         seed=int(cfg.get("i_seed", 0)),
     )
 
@@ -143,7 +143,7 @@
     model_name = cfg.get("model_name", "yolo_v11_n")
 
     clients = {}
-    for uid in tqdm(users, desc="Building clients", leave=True, unit="client"):
+    for uid in users:
         c = FedYoloClient(name=uid, model_name=model_name, params=params)
         c.load_trainset(user_data[uid]["filename"])
         clients[uid] = c
@@ -177,11 +177,16 @@
     res_root = cfg.get("res_root", "results")
     os.makedirs(res_root, exist_ok=True)
 
-    for rnd in tqdm(range(num_round), desc="main federal loop round:"):
-        t0 = time.time()
+    # tqdm logging
+    header = ("%10s" * 2) % ("Round", "client")
+    tqdm.write("\n" + header)
+    p_bar = tqdm(total=num_round, ncols=160, ascii="->>")
+    for rnd in range(num_round):
+        t0 = time.time()
 
         # Local training (sequential over all users)
-        for uid in tqdm(users, desc=f"Round {rnd + 1} local training: ", leave=False):
+        for uid in users:
+            p_bar.set_description_str(("%10s" * 2) % (f"{rnd + 1}/{num_round}", f"{uid}"))
             client = clients[uid]  # FedYoloClient instance
             client.update(global_state)  # load global weights
             state_dict, n_data, loss_dict = client.train(args_cli)  # local training
@@ -214,12 +219,18 @@
         history["train_loss"].append(scalar_train_loss)
         history["round_time_sec"].append(time.time() - t0)
 
-        tqdm.write(
-            f"[round {rnd + 1:04d}] "
-            f"loss={scalar_train_loss:.4f} mAP50-95={mAP:.4f} mAP50={mAP50:.4f} "
-            f"P={precision:.4f} R={recall:.4f}"
-            f"\n"
-        )
+        # Log GPU memory usage
+        # gpu_mem = f"{torch.cuda.memory_reserved() / 1e9:.2f}G" if torch.cuda.is_available() else "0.00G"
+        # tqdm update
+        desc = {
+            "loss": f"{scalar_train_loss:.6g}",
+            "mAP50": f"{mAP50:.6g}",
+            "mAP": f"{mAP:.6g}",
+            "precision": f"{precision:.6g}",
+            "recall": f"{recall:.6g}",
+            # "gpu_mem": gpu_mem,
+        }
+        p_bar.set_postfix(desc)
 
         # Save running JSON (resumable logs)
         save_name = (
@@ -232,6 +243,10 @@
         with open(out_json, "w", encoding="utf-8") as f:
             json.dump(history, f, indent=2)
 
+        p_bar.update(1)
+
+    p_bar.close()
+
     # --- final plot ---
     _plot_curves(res_root, history)
     print("[done] training complete.")