Refactor fed_run.py: remove redundant functions, fix an argument-passing bug, and update the model-weight saving logic; add a fed_run.sh script to support distributed training

TY1667
2025-10-19 21:30:45 +08:00
parent 0343a0fd30
commit 314f46d542
2 changed files with 26 additions and 99 deletions

fed_run.py

@@ -3,92 +3,16 @@ import os
 import json
 import yaml
 import time
-import random
 from tqdm import tqdm
-import numpy as np
 import torch
-import matplotlib.pyplot as plt
-from utils.dataset import Dataset
+from utils.fed_util import build_valset_if_available, seed_everything, plot_curves
 from fed_algo_cs.client_base import FedYoloClient
 from fed_algo_cs.server_base import FedYoloServer
 from utils.args import args_parser # args parser
 from utils.fed_util import divide_trainset # divide_trainset
-
-def _read_list_file(txt_path: str):
-    """Read one path per line; keep as-is (absolute or relative)."""
-    if not txt_path or not os.path.exists(txt_path):
-        return []
-    with open(txt_path, "r", encoding="utf-8") as f:
-        return [ln.strip() for ln in f if ln.strip()]
-
-def _build_valset_if_available(cfg, params):
-    """
-    Try to build a validation Dataset.
-    - If cfg['val_txt'] exists, use it.
-    - Else if <dataset_path>/val.txt exists, use it.
-    - Else return None (testing will be skipped).
-    Args:
-        cfg: config dict
-        params: params dict for Dataset
-    Returns:
-        Dataset or None
-    """
-    input_size = int(cfg.get("input_size", 640))
-    val_txt = cfg.get("val_txt", "")
-    if not val_txt:
-        ds_root = cfg.get("dataset_path", "")
-        guess = os.path.join(ds_root, "val.txt") if ds_root else ""
-        val_txt = guess if os.path.exists(guess) else ""
-    val_files = _read_list_file(val_txt)
-    if not val_files:
-        return None
-    return Dataset(
-        filenames=val_files,
-        input_size=input_size,
-        params=params,
-        augment=True,
-    )
-
-def _seed_everything(seed: int):
-    np.random.seed(seed)
-    torch.manual_seed(seed)
-    random.seed(seed)
-
-def _plot_curves(save_dir, hist):
-    """
-    Plot mAP50-95, mAP50, precision, recall, and (optional) summed train loss per round.
-    """
-    os.makedirs(save_dir, exist_ok=True)
-    rounds = np.arange(1, len(hist["mAP"]) + 1)
-    plt.figure()
-    if hist["mAP"]:
-        plt.plot(rounds, hist["mAP"], label="mAP50-95")
-    if hist["mAP50"]:
-        plt.plot(rounds, hist["mAP50"], label="mAP50")
-    if hist["precision"]:
-        plt.plot(rounds, hist["precision"], label="precision")
-    if hist["recall"]:
-        plt.plot(rounds, hist["recall"], label="recall")
-    if hist["train_loss"]:
-        plt.plot(rounds, hist["train_loss"], label="train_loss (sum of components)")
-    plt.xlabel("Global Round")
-    plt.ylabel("Metric")
-    plt.title("Federated YOLO - Server Metrics")
-    plt.legend()
-    out_png = os.path.join(save_dir, "fed_yolo_curves.png")
-    plt.savefig(out_png, dpi=150, bbox_inches="tight")
-    print(f"[plot] saved: {out_png}")
-
 def fed_run():
     """
     Main FL process:
@@ -98,20 +22,22 @@ def fed_run():
     - Record & save results, plot curves
     """
     args_cli = args_parser()
+    # TODO: cfg and params should not be separately defined
     with open(args_cli.config, "r", encoding="utf-8") as f:
         cfg = yaml.safe_load(f)

     # --- params / config normalization ---
     # For convenience we pass the same `params` dict used by Dataset/model/loss.
     # Here we re-use the top-level cfg directly as params.
-    params = dict(cfg)
+    # params = dict(cfg)
     if "names" in cfg and isinstance(cfg["names"], dict):
         # Convert {0: 'uav', 1: 'car', ...} to list if you prefer list
         # but we can leave dict; your utils appear to accept dict
         pass

     # seeds
-    _seed_everything(int(cfg.get("i_seed", 0)))
+    seed_everything(int(cfg.get("i_seed", 0)))

     # --- split clients' train data from a global train list ---
     # Expect either cfg["train_txt"] or <dataset_path>/train.txt
@@ -144,13 +70,13 @@ def fed_run():
     clients = {}
     for uid in users:
-        c = FedYoloClient(name=uid, model_name=model_name, params=params)
+        c = FedYoloClient(name=uid, model_name=model_name, params=cfg)
         c.load_trainset(user_data[uid]["filename"])
         clients[uid] = c

     # --- build server & optional validation set ---
-    server = FedYoloServer(client_list=users, model_name=model_name, params=params)
-    valset = _build_valset_if_available(cfg, params)
+    server = FedYoloServer(client_list=users, model_name=model_name, params=cfg)
+    valset = build_valset_if_available(cfg, params=cfg, args=args_cli)
     # valset is a Dataset class, not data loader
     if valset is not None:
         server.load_valset(valset)
@@ -186,27 +112,25 @@ def fed_run():
         t0 = time.time()

         # Local training (sequential over all users)
         for uid in users:
+            # tqdm desc update
             p_bar.set_description_str(("%10s" * 2) % (f"{rnd + 1}/{num_round}", f"{uid}"))
             client = clients[uid]  # FedYoloClient instance
             client.update(global_state)  # load global weights
-            state_dict, n_data, loss_dict = client.train(args_cli)  # local training
-            server.rec(uid, state_dict, n_data, loss_dict)
+            state_dict, n_data, train_loss = client.train(args_cli)  # local training
+            server.rec(uid, state_dict, n_data, train_loss)

         # Select a fraction for aggregation (FedAvg subset if desired)
         server.select_clients(connection_ratio=connection_ratio)

         # Aggregate
-        global_state, avg_loss_dict, _ = server.agg()
+        global_state, avg_loss, _ = server.agg()

         # Compute a scalar train loss for plotting (sum of components)
-        scalar_train_loss = float(sum(avg_loss_dict.values())) if avg_loss_dict else 0.0
+        scalar_train_loss = avg_loss if avg_loss else 0.0

         # Test (if valset provided)
-        test_metrics = server.test(args_cli) if server.valset is not None else {}
-        mAP = float(test_metrics.get("mAP", 0.0))
-        mAP50 = float(test_metrics.get("mAP50", 0.0))
-        precision = float(test_metrics.get("precision", 0.0))
-        recall = float(test_metrics.get("recall", 0.0))
+        mAP, mAP50, recall, precision = server.test() if server.valset is not None else (0.0, 0.0, 0.0, 0.0)

         # Flush per-round client caches
         server.flush()
@@ -233,22 +157,23 @@ def fed_run():
         p_bar.set_postfix(desc)

         # Save running JSON (resumable logs)
-        save_name = (
-            f"[{cfg.get('fed_algo', 'FedAvg')},{cfg.get('model_name', 'yolo')},"
-            f"{cfg.get('num_local_epoch', cfg.get('client', {}).get('num_local_epoch', 1))},"
-            f"{cfg.get('num_local_class', 2)},"
-            f"{cfg.get('i_seed', 0)}]"
-        )
+        save_name = f"{cfg.get('fed_algo', 'FedAvg')}_{cfg.get('model_name', 'yolo')}_{cfg.get('num_client', 0)}c_{cfg.get('num_local_class', 1)}cls_{cfg.get('num_round', 0)}r_{cfg.get('connection_ratio', 1):.2f}cr_{cfg.get('i_seed', 0)}s"
         out_json = os.path.join(res_root, save_name + ".json")
         with open(out_json, "w", encoding="utf-8") as f:
-            json.dump(history, f, indent=2)
+            json.dump(history, f, indent=4)

         p_bar.update(1)
     p_bar.close()

+    # Save final global model weights
+    if not os.path.exists("./weights"):
+        os.makedirs("./weights", exist_ok=True)
+    torch.save(global_state, f"./weights/{save_name}_final.pth")
+    print(f"[save] final global model weights: ./weights/{save_name}_final.pth")
+
     # --- final plot ---
-    _plot_curves(res_root, history)
+    plot_curves(res_root, history, savename=f"{save_name}_curve.png")
     print("[done] training complete.")

fed_run.sh Normal file

@@ -0,0 +1,2 @@
+GPUS=$1
+python3 -m torch.distributed.run --nproc_per_node=$GPUS fed_run.py ${@:2}
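
Usage sketch, assuming the script is run from the repository root: e.g. `bash fed_run.sh 2 --config <your_config>.yaml` (the config path is illustrative) launches fed_run.py across 2 GPUs. `${@:2}` forwards every argument after the GPU count to fed_run.py unchanged, while torch.distributed.run spawns one process per GPU and passes each its rank via environment variables such as LOCAL_RANK, which the training entrypoint is assumed to consume.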