重构fed_run.py,移除冗余函数,传参BUG修复,更新模型权重保存逻辑;新增fed_run.sh脚本以支持分布式训练
This commit is contained in:
123
fed_run.py
123
fed_run.py
@@ -3,92 +3,16 @@ import os
|
|||||||
import json
|
import json
|
||||||
import yaml
|
import yaml
|
||||||
import time
|
import time
|
||||||
import random
|
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
import matplotlib.pyplot as plt
|
|
||||||
|
|
||||||
from utils.dataset import Dataset
|
from utils.fed_util import build_valset_if_available, seed_everything, plot_curves
|
||||||
from fed_algo_cs.client_base import FedYoloClient
|
from fed_algo_cs.client_base import FedYoloClient
|
||||||
from fed_algo_cs.server_base import FedYoloServer
|
from fed_algo_cs.server_base import FedYoloServer
|
||||||
from utils.args import args_parser # args parser
|
from utils.args import args_parser # args parser
|
||||||
from utils.fed_util import divide_trainset # divide_trainset
|
from utils.fed_util import divide_trainset # divide_trainset
|
||||||
|
|
||||||
|
|
||||||
def _read_list_file(txt_path: str):
|
|
||||||
"""Read one path per line; keep as-is (absolute or relative)."""
|
|
||||||
if not txt_path or not os.path.exists(txt_path):
|
|
||||||
return []
|
|
||||||
with open(txt_path, "r", encoding="utf-8") as f:
|
|
||||||
return [ln.strip() for ln in f if ln.strip()]
|
|
||||||
|
|
||||||
|
|
||||||
def _build_valset_if_available(cfg, params):
|
|
||||||
"""
|
|
||||||
Try to build a validation Dataset.
|
|
||||||
- If cfg['val_txt'] exists, use it.
|
|
||||||
- Else if <dataset_path>/val.txt exists, use it.
|
|
||||||
- Else return None (testing will be skipped).
|
|
||||||
Args:
|
|
||||||
cfg: config dict
|
|
||||||
params: params dict for Dataset
|
|
||||||
Returns:
|
|
||||||
Dataset or None
|
|
||||||
"""
|
|
||||||
input_size = int(cfg.get("input_size", 640))
|
|
||||||
val_txt = cfg.get("val_txt", "")
|
|
||||||
if not val_txt:
|
|
||||||
ds_root = cfg.get("dataset_path", "")
|
|
||||||
guess = os.path.join(ds_root, "val.txt") if ds_root else ""
|
|
||||||
val_txt = guess if os.path.exists(guess) else ""
|
|
||||||
|
|
||||||
val_files = _read_list_file(val_txt)
|
|
||||||
if not val_files:
|
|
||||||
return None
|
|
||||||
|
|
||||||
return Dataset(
|
|
||||||
filenames=val_files,
|
|
||||||
input_size=input_size,
|
|
||||||
params=params,
|
|
||||||
augment=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _seed_everything(seed: int):
|
|
||||||
np.random.seed(seed)
|
|
||||||
torch.manual_seed(seed)
|
|
||||||
random.seed(seed)
|
|
||||||
|
|
||||||
|
|
||||||
def _plot_curves(save_dir, hist):
|
|
||||||
"""
|
|
||||||
Plot mAP50-95, mAP50, precision, recall, and (optional) summed train loss per round.
|
|
||||||
"""
|
|
||||||
os.makedirs(save_dir, exist_ok=True)
|
|
||||||
rounds = np.arange(1, len(hist["mAP"]) + 1)
|
|
||||||
|
|
||||||
plt.figure()
|
|
||||||
if hist["mAP"]:
|
|
||||||
plt.plot(rounds, hist["mAP"], label="mAP50-95")
|
|
||||||
if hist["mAP50"]:
|
|
||||||
plt.plot(rounds, hist["mAP50"], label="mAP50")
|
|
||||||
if hist["precision"]:
|
|
||||||
plt.plot(rounds, hist["precision"], label="precision")
|
|
||||||
if hist["recall"]:
|
|
||||||
plt.plot(rounds, hist["recall"], label="recall")
|
|
||||||
if hist["train_loss"]:
|
|
||||||
plt.plot(rounds, hist["train_loss"], label="train_loss (sum of components)")
|
|
||||||
plt.xlabel("Global Round")
|
|
||||||
plt.ylabel("Metric")
|
|
||||||
plt.title("Federated YOLO - Server Metrics")
|
|
||||||
plt.legend()
|
|
||||||
out_png = os.path.join(save_dir, "fed_yolo_curves.png")
|
|
||||||
plt.savefig(out_png, dpi=150, bbox_inches="tight")
|
|
||||||
print(f"[plot] saved: {out_png}")
|
|
||||||
|
|
||||||
|
|
||||||
def fed_run():
|
def fed_run():
|
||||||
"""
|
"""
|
||||||
Main FL process:
|
Main FL process:
|
||||||
@@ -98,20 +22,22 @@ def fed_run():
|
|||||||
- Record & save results, plot curves
|
- Record & save results, plot curves
|
||||||
"""
|
"""
|
||||||
args_cli = args_parser()
|
args_cli = args_parser()
|
||||||
|
# TODO: cfg and params should not be separately defined
|
||||||
with open(args_cli.config, "r", encoding="utf-8") as f:
|
with open(args_cli.config, "r", encoding="utf-8") as f:
|
||||||
cfg = yaml.safe_load(f)
|
cfg = yaml.safe_load(f)
|
||||||
|
|
||||||
# --- params / config normalization ---
|
# --- params / config normalization ---
|
||||||
# For convenience we pass the same `params` dict used by Dataset/model/loss.
|
# For convenience we pass the same `params` dict used by Dataset/model/loss.
|
||||||
# Here we re-use the top-level cfg directly as params.
|
# Here we re-use the top-level cfg directly as params.
|
||||||
params = dict(cfg)
|
# params = dict(cfg)
|
||||||
|
|
||||||
if "names" in cfg and isinstance(cfg["names"], dict):
|
if "names" in cfg and isinstance(cfg["names"], dict):
|
||||||
# Convert {0: 'uav', 1: 'car', ...} to list if you prefer list
|
# Convert {0: 'uav', 1: 'car', ...} to list if you prefer list
|
||||||
# but we can leave dict; your utils appear to accept dict
|
# but we can leave dict; your utils appear to accept dict
|
||||||
pass
|
pass
|
||||||
|
|
||||||
# seeds
|
# seeds
|
||||||
_seed_everything(int(cfg.get("i_seed", 0)))
|
seed_everything(int(cfg.get("i_seed", 0)))
|
||||||
|
|
||||||
# --- split clients' train data from a global train list ---
|
# --- split clients' train data from a global train list ---
|
||||||
# Expect either cfg["train_txt"] or <dataset_path>/train.txt
|
# Expect either cfg["train_txt"] or <dataset_path>/train.txt
|
||||||
@@ -144,13 +70,13 @@ def fed_run():
|
|||||||
clients = {}
|
clients = {}
|
||||||
|
|
||||||
for uid in users:
|
for uid in users:
|
||||||
c = FedYoloClient(name=uid, model_name=model_name, params=params)
|
c = FedYoloClient(name=uid, model_name=model_name, params=cfg)
|
||||||
c.load_trainset(user_data[uid]["filename"])
|
c.load_trainset(user_data[uid]["filename"])
|
||||||
clients[uid] = c
|
clients[uid] = c
|
||||||
|
|
||||||
# --- build server & optional validation set ---
|
# --- build server & optional validation set ---
|
||||||
server = FedYoloServer(client_list=users, model_name=model_name, params=params)
|
server = FedYoloServer(client_list=users, model_name=model_name, params=cfg)
|
||||||
valset = _build_valset_if_available(cfg, params)
|
valset = build_valset_if_available(cfg, params=cfg, args=args_cli)
|
||||||
# valset is a Dataset class, not data loader
|
# valset is a Dataset class, not data loader
|
||||||
if valset is not None:
|
if valset is not None:
|
||||||
server.load_valset(valset)
|
server.load_valset(valset)
|
||||||
@@ -186,27 +112,25 @@ def fed_run():
|
|||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
# Local training (sequential over all users)
|
# Local training (sequential over all users)
|
||||||
for uid in users:
|
for uid in users:
|
||||||
|
# tqdm desc update
|
||||||
p_bar.set_description_str(("%10s" * 2) % (f"{rnd + 1}/{num_round}", f"{uid}"))
|
p_bar.set_description_str(("%10s" * 2) % (f"{rnd + 1}/{num_round}", f"{uid}"))
|
||||||
|
|
||||||
client = clients[uid] # FedYoloClient instance
|
client = clients[uid] # FedYoloClient instance
|
||||||
client.update(global_state) # load global weights
|
client.update(global_state) # load global weights
|
||||||
state_dict, n_data, loss_dict = client.train(args_cli) # local training
|
state_dict, n_data, train_loss = client.train(args_cli) # local training
|
||||||
server.rec(uid, state_dict, n_data, loss_dict)
|
server.rec(uid, state_dict, n_data, train_loss)
|
||||||
|
|
||||||
# Select a fraction for aggregation (FedAvg subset if desired)
|
# Select a fraction for aggregation (FedAvg subset if desired)
|
||||||
server.select_clients(connection_ratio=connection_ratio)
|
server.select_clients(connection_ratio=connection_ratio)
|
||||||
|
|
||||||
# Aggregate
|
# Aggregate
|
||||||
global_state, avg_loss_dict, _ = server.agg()
|
global_state, avg_loss, _ = server.agg()
|
||||||
|
|
||||||
# Compute a scalar train loss for plotting (sum of components)
|
# Compute a scalar train loss for plotting (sum of components)
|
||||||
scalar_train_loss = float(sum(avg_loss_dict.values())) if avg_loss_dict else 0.0
|
scalar_train_loss = avg_loss if avg_loss else 0.0
|
||||||
|
|
||||||
# Test (if valset provided)
|
# Test (if valset provided)
|
||||||
test_metrics = server.test(args_cli) if server.valset is not None else {}
|
mAP, mAP50, recall, precision = server.test() if server.valset is not None else (0.0, 0.0, 0.0, 0.0)
|
||||||
mAP = float(test_metrics.get("mAP", 0.0))
|
|
||||||
mAP50 = float(test_metrics.get("mAP50", 0.0))
|
|
||||||
precision = float(test_metrics.get("precision", 0.0))
|
|
||||||
recall = float(test_metrics.get("recall", 0.0))
|
|
||||||
|
|
||||||
# Flush per-round client caches
|
# Flush per-round client caches
|
||||||
server.flush()
|
server.flush()
|
||||||
@@ -233,22 +157,23 @@ def fed_run():
|
|||||||
p_bar.set_postfix(desc)
|
p_bar.set_postfix(desc)
|
||||||
|
|
||||||
# Save running JSON (resumable logs)
|
# Save running JSON (resumable logs)
|
||||||
save_name = (
|
save_name = f"{cfg.get('fed_algo', 'FedAvg')}_{[cfg.get('model_name', 'yolo')]}_{cfg.get('num_client', 0)}c_{cfg.get('num_local_class', 1)}cls_{cfg.get('num_round', 0)}r_{cfg.get('connection_ratio', 1):.2f}cr_{cfg.get('i_seed', 0)}s"
|
||||||
f"[{cfg.get('fed_algo', 'FedAvg')},{cfg.get('model_name', 'yolo')},"
|
|
||||||
f"{cfg.get('num_local_epoch', cfg.get('client', {}).get('num_local_epoch', 1))},"
|
|
||||||
f"{cfg.get('num_local_class', 2)},"
|
|
||||||
f"{cfg.get('i_seed', 0)}]"
|
|
||||||
)
|
|
||||||
out_json = os.path.join(res_root, save_name + ".json")
|
out_json = os.path.join(res_root, save_name + ".json")
|
||||||
with open(out_json, "w", encoding="utf-8") as f:
|
with open(out_json, "w", encoding="utf-8") as f:
|
||||||
json.dump(history, f, indent=2)
|
json.dump(history, f, indent=4)
|
||||||
|
|
||||||
p_bar.update(1)
|
p_bar.update(1)
|
||||||
|
|
||||||
p_bar.close()
|
p_bar.close()
|
||||||
|
|
||||||
|
# Save final global model weights
|
||||||
|
if not os.path.exists("./weights"):
|
||||||
|
os.makedirs("./weights", exist_ok=True)
|
||||||
|
torch.save(global_state, f"./weights/{save_name}_final.pth")
|
||||||
|
print(f"[save] final global model weights: ./weights/{save_name}_final.pth")
|
||||||
|
|
||||||
# --- final plot ---
|
# --- final plot ---
|
||||||
_plot_curves(res_root, history)
|
plot_curves(res_root, history, savename=f"{save_name}_curve.png")
|
||||||
print("[done] training complete.")
|
print("[done] training complete.")
|
||||||
|
|
||||||
|
|
||||||
|
2
fed_run.sh
Normal file
2
fed_run.sh
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
GPUS=$1
|
||||||
|
python3 -m torch.distributed.run --nproc_per_node=$GPUS fed_run.py ${@:2}
|
Reference in New Issue
Block a user