From 314f46d54220d7bfa097e09dd61fe8e910c40ac6 Mon Sep 17 00:00:00 2001
From: TY1667
Date: Sun, 19 Oct 2025 21:30:45 +0800
Subject: [PATCH] Refactor fed_run.py: remove redundant functions, fix an
 argument-passing bug, and update the model weight saving logic; add a
 fed_run.sh script to support distributed training
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 fed_run.py | 123 +++++++++++------------------------------------------
 fed_run.sh |   2 +
 2 files changed, 26 insertions(+), 99 deletions(-)
 create mode 100644 fed_run.sh

diff --git a/fed_run.py b/fed_run.py
index 103d5df..c853f0a 100644
--- a/fed_run.py
+++ b/fed_run.py
@@ -3,92 +3,16 @@
 import os
 import json
 import yaml
 import time
-import random
 from tqdm import tqdm
-
-import numpy as np
 import torch
-import matplotlib.pyplot as plt
-from utils.dataset import Dataset
+from utils.fed_util import build_valset_if_available, seed_everything, plot_curves
 from fed_algo_cs.client_base import FedYoloClient
 from fed_algo_cs.server_base import FedYoloServer
 from utils.args import args_parser  # args parser
 from utils.fed_util import divide_trainset  # divide_trainset
 
 
-def _read_list_file(txt_path: str):
-    """Read one path per line; keep as-is (absolute or relative)."""
-    if not txt_path or not os.path.exists(txt_path):
-        return []
-    with open(txt_path, "r", encoding="utf-8") as f:
-        return [ln.strip() for ln in f if ln.strip()]
-
-
-def _build_valset_if_available(cfg, params):
-    """
-    Try to build a validation Dataset.
-    - If cfg['val_txt'] exists, use it.
-    - Else if <dataset_path>/val.txt exists, use it.
-    - Else return None (testing will be skipped).
-    Args:
-        cfg: config dict
-        params: params dict for Dataset
-    Returns:
-        Dataset or None
-    """
-    input_size = int(cfg.get("input_size", 640))
-    val_txt = cfg.get("val_txt", "")
-    if not val_txt:
-        ds_root = cfg.get("dataset_path", "")
-        guess = os.path.join(ds_root, "val.txt") if ds_root else ""
-        val_txt = guess if os.path.exists(guess) else ""
-
-    val_files = _read_list_file(val_txt)
-    if not val_files:
-        return None
-
-    return Dataset(
-        filenames=val_files,
-        input_size=input_size,
-        params=params,
-        augment=True,
-    )
-
-
-def _seed_everything(seed: int):
-    np.random.seed(seed)
-    torch.manual_seed(seed)
-    random.seed(seed)
-
-
-def _plot_curves(save_dir, hist):
-    """
-    Plot mAP50-95, mAP50, precision, recall, and (optional) summed train loss per round.
-    """
-    os.makedirs(save_dir, exist_ok=True)
-    rounds = np.arange(1, len(hist["mAP"]) + 1)
-
-    plt.figure()
-    if hist["mAP"]:
-        plt.plot(rounds, hist["mAP"], label="mAP50-95")
-    if hist["mAP50"]:
-        plt.plot(rounds, hist["mAP50"], label="mAP50")
-    if hist["precision"]:
-        plt.plot(rounds, hist["precision"], label="precision")
-    if hist["recall"]:
-        plt.plot(rounds, hist["recall"], label="recall")
-    if hist["train_loss"]:
-        plt.plot(rounds, hist["train_loss"], label="train_loss (sum of components)")
-    plt.xlabel("Global Round")
-    plt.ylabel("Metric")
-    plt.title("Federated YOLO - Server Metrics")
-    plt.legend()
-    out_png = os.path.join(save_dir, "fed_yolo_curves.png")
-    plt.savefig(out_png, dpi=150, bbox_inches="tight")
-    print(f"[plot] saved: {out_png}")
-
-
 def fed_run():
     """
     Main FL process:
@@ -98,20 +22,22 @@ def fed_run():
     - Record & save results, plot curves
     """
     args_cli = args_parser()
+    # TODO: cfg and params should not be separately defined
     with open(args_cli.config, "r", encoding="utf-8") as f:
         cfg = yaml.safe_load(f)
 
     # --- params / config normalization ---
     # For convenience we pass the same `params` dict used by Dataset/model/loss.
     # Here we re-use the top-level cfg directly as params.
-    params = dict(cfg)
+    # params = dict(cfg)
+
     if "names" in cfg and isinstance(cfg["names"], dict):
         # Convert {0: 'uav', 1: 'car', ...} to list if you prefer list
         # but we can leave dict; your utils appear to accept dict
         pass
 
     # seeds
-    _seed_everything(int(cfg.get("i_seed", 0)))
+    seed_everything(int(cfg.get("i_seed", 0)))
 
     # --- split clients' train data from a global train list ---
     # Expect either cfg["train_txt"] or <dataset_path>/train.txt
@@ -144,13 +70,13 @@ def fed_run():
 
     clients = {}
     for uid in users:
-        c = FedYoloClient(name=uid, model_name=model_name, params=params)
+        c = FedYoloClient(name=uid, model_name=model_name, params=cfg)
         c.load_trainset(user_data[uid]["filename"])
        clients[uid] = c
 
     # --- build server & optional validation set ---
-    server = FedYoloServer(client_list=users, model_name=model_name, params=params)
-    valset = _build_valset_if_available(cfg, params)
+    server = FedYoloServer(client_list=users, model_name=model_name, params=cfg)
+    valset = build_valset_if_available(cfg, params=cfg, args=args_cli)  # valset is a Dataset instance, not a DataLoader
     if valset is not None:
         server.load_valset(valset)
 
@@ -186,27 +112,25 @@ def fed_run():
         t0 = time.time()
         # Local training (sequential over all users)
         for uid in users:
+            # tqdm desc update
             p_bar.set_description_str(("%10s" * 2) % (f"{rnd + 1}/{num_round}", f"{uid}"))
+
             client = clients[uid]  # FedYoloClient instance
             client.update(global_state)  # load global weights
-            state_dict, n_data, loss_dict = client.train(args_cli)  # local training
-            server.rec(uid, state_dict, n_data, loss_dict)
+            state_dict, n_data, train_loss = client.train(args_cli)  # local training
+            server.rec(uid, state_dict, n_data, train_loss)
 
         # Select a fraction for aggregation (FedAvg subset if desired)
         server.select_clients(connection_ratio=connection_ratio)
 
         # Aggregate
-        global_state, avg_loss_dict, _ = server.agg()
+        global_state, avg_loss, _ = server.agg()
 
         # Compute a scalar train loss for plotting (sum of components)
-        scalar_train_loss = float(sum(avg_loss_dict.values())) if avg_loss_dict else 0.0
+        scalar_train_loss = avg_loss if avg_loss else 0.0
 
         # Test (if valset provided)
-        test_metrics = server.test(args_cli) if server.valset is not None else {}
-        mAP = float(test_metrics.get("mAP", 0.0))
-        mAP50 = float(test_metrics.get("mAP50", 0.0))
-        precision = float(test_metrics.get("precision", 0.0))
-        recall = float(test_metrics.get("recall", 0.0))
+        mAP, mAP50, recall, precision = server.test() if server.valset is not None else (0.0, 0.0, 0.0, 0.0)
 
         # Flush per-round client caches
         server.flush()
@@ -233,22 +157,23 @@ def fed_run():
             p_bar.set_postfix(desc)
 
         # Save running JSON (resumable logs)
-        save_name = (
-            f"[{cfg.get('fed_algo', 'FedAvg')},{cfg.get('model_name', 'yolo')},"
-            f"{cfg.get('num_local_epoch', cfg.get('client', {}).get('num_local_epoch', 1))},"
-            f"{cfg.get('num_local_class', 2)},"
-            f"{cfg.get('i_seed', 0)}]"
-        )
+        save_name = f"{cfg.get('fed_algo', 'FedAvg')}_{cfg.get('model_name', 'yolo')}_{cfg.get('num_client', 0)}c_{cfg.get('num_local_class', 1)}cls_{cfg.get('num_round', 0)}r_{cfg.get('connection_ratio', 1):.2f}cr_{cfg.get('i_seed', 0)}s"
         out_json = os.path.join(res_root, save_name + ".json")
         with open(out_json, "w", encoding="utf-8") as f:
-            json.dump(history, f, indent=2)
+            json.dump(history, f, indent=4)
 
         p_bar.update(1)
 
     p_bar.close()
 
+    # Save final global model weights
+    if not os.path.exists("./weights"):
+        os.makedirs("./weights", exist_ok=True)
+    torch.save(global_state, f"./weights/{save_name}_final.pth")
+    print(f"[save] final global model weights: ./weights/{save_name}_final.pth")
+
     # --- final plot ---
-    _plot_curves(res_root, history)
+    plot_curves(res_root, history, savename=f"{save_name}_curve.png")
 
     print("[done] training complete.")
 
diff --git a/fed_run.sh b/fed_run.sh
new file mode 100644
index 0000000..5f17abd
--- /dev/null
+++ b/fed_run.sh
@@ -0,0 +1,2 @@
+GPUS=$1
+python3 -m torch.distributed.run --nproc_per_node=$GPUS fed_run.py ${@:2}
\ No newline at end of file