Files
fed-yolo/fed_run.py

182 lines
6.4 KiB
Python

#!/usr/bin/env python3
import os
import json
import yaml
import time
from tqdm import tqdm
import torch
from utils.fed_util import build_valset_if_available, seed_everything, plot_curves
from fed_algo_cs.client_base import FedYoloClient
from fed_algo_cs.server_base import FedYoloServer
from utils.args import args_parser # args parser
from utils.fed_util import divide_trainset # divide_trainset
def _resolve_train_txt(cfg):
    """Return the path of the global training list described by ``cfg``.

    Uses ``cfg["train_txt"]`` when set, otherwise falls back to
    ``<cfg["dataset_path"]>/train.txt``.

    Raises:
        FileNotFoundError: if no path can be derived or the file does not exist.
    """
    train_txt = cfg.get("train_txt", "")
    if not train_txt:
        ds_root = cfg.get("dataset_path", "")
        train_txt = os.path.join(ds_root, "train.txt") if ds_root else ""
    if not train_txt or not os.path.exists(train_txt):
        raise FileNotFoundError(
            f"train.txt not found. Provide --config with 'train_txt' or ensure '{train_txt}' exists."
        )
    return train_txt


def _make_save_name(fed_algo, model_name, num_client, num_local_class, num_round, connection_ratio, seed):
    """Build the run identifier shared by the JSON log, final weights file and curve plot."""
    return (
        f"{fed_algo}_{model_name}_{num_client}c_{num_local_class}cls_"
        f"{num_round}r_{float(connection_ratio):.2f}cr_{seed}s"
    )


def _scalar_loss(avg_loss):
    """Collapse the aggregated training loss returned by ``server.agg()`` into a float.

    The exact return type of ``server.agg()`` is not visible here. The run log
    describes it as a sum of loss components, so this accepts:

    - ``None`` or another falsy value, which gives ``0.0``;
    - a dict of component losses, whose values are summed;
    - anything ``float()`` accepts (Python number, 0-d tensor, NumPy scalar).
    """
    if not avg_loss:
        return 0.0
    if isinstance(avg_loss, dict):
        return float(sum(float(v) for v in avg_loss.values()))
    return float(avg_loss)


def fed_run():
    """
    Main FL process:
    - Initialize clients & server
    - For each round: sequential local training -> record -> select -> aggregate
    - Test & flush
    - Record & save results, plot curves

    Side effects:
    - Rewrites ``<res_root>/<save_name>.json`` with the metric history after every round.
    - Saves the final global state dict to ``./weights/<save_name>_final.pth``.
    - Calls ``plot_curves(res_root, history, savename="<save_name>_curve.png")``.

    Raises:
        FileNotFoundError: if the global train list cannot be located.
    """
    args_cli = args_parser()
    # TODO: cfg and params should not be separately defined
    with open(args_cli.config, "r", encoding="utf-8") as f:
        # The same cfg dict is passed as `params` to the client/server/valset code.
        cfg = yaml.safe_load(f)
    # NOTE: cfg["names"] may be a dict such as {0: 'uav', 1: 'car'}; it is passed
    # through unchanged (the original author noted the utils appear to accept dicts).

    seed = int(cfg.get("i_seed", 0))
    seed_everything(seed)

    # --- split clients' train data from a global train list ---
    train_txt = _resolve_train_txt(cfg)
    num_client = int(cfg.get("num_client", 64))
    num_local_class = int(cfg.get("num_local_class", 1))
    split = divide_trainset(
        trainset_path=train_txt,
        num_local_class=num_local_class,
        num_client=num_client,
        min_data=int(cfg.get("min_data", 100)),
        max_data=int(cfg.get("max_data", 100)),
        mode=str(cfg.get("partition_mode", "overlap")),  # "overlap" or "disjoint"
        seed=seed,
    )
    users = split["users"]
    user_data = split["user_data"]  # mapping: id -> {"filename": [...]}

    # --- build clients ---
    model_name = cfg.get("model_name", "yolo_v11_n")
    clients = {}
    for uid in users:
        client = FedYoloClient(name=uid, model_name=model_name, params=cfg)
        client.load_trainset(user_data[uid]["filename"])
        clients[uid] = client

    # --- build server & optional validation set ---
    server = FedYoloServer(client_list=users, model_name=model_name, params=cfg)
    valset = build_valset_if_available(cfg, params=cfg, args=args_cli)
    # valset is a Dataset, not a DataLoader
    if valset is not None:
        server.load_valset(valset)

    # Initial global weights, pushed to every client at the start of round 1.
    global_state = server.state_dict()

    # --- history recorder (plain floats so json.dump always succeeds) ---
    history = {
        "mAP": [],
        "mAP50": [],
        "precision": [],
        "recall": [],
        "train_loss": [],  # scalar aggregated training loss per round
        "round_time_sec": [],
    }

    # --- run configuration ---
    num_round = int(cfg.get("num_round", 50))
    connection_ratio = float(cfg.get("connection_ratio", 1.0))  # e.g., 1.0 = all clients
    res_root = cfg.get("res_root", "results")
    os.makedirs(res_root, exist_ok=True)

    # Computed once, before the loop, from the values actually used, so it exists
    # even when num_round == 0 and the filename matches the run.
    save_name = _make_save_name(
        fed_algo=cfg.get("fed_algo", "FedAvg"),
        model_name=model_name,
        num_client=num_client,
        num_local_class=num_local_class,
        num_round=num_round,
        connection_ratio=connection_ratio,
        seed=seed,
    )
    out_json = os.path.join(res_root, save_name + ".json")

    # --- main FL loop ---
    header = ("%10s" * 2) % ("Round", "client")
    tqdm.write("\n" + header)
    with tqdm(total=num_round, ncols=160, ascii="->>") as p_bar:
        for rnd in range(num_round):
            t0 = time.time()

            # Local training (sequential over all users)
            for uid in users:
                p_bar.set_description_str(("%10s" * 2) % (f"{rnd + 1}/{num_round}", f"{uid}"))
                client = clients[uid]
                client.update(global_state)  # load global weights
                state_dict, n_data, train_loss = client.train(args_cli)  # local training
                server.rec(uid, state_dict, n_data, train_loss)

            # Select a fraction of the recorded clients for aggregation.
            server.select_clients(connection_ratio=connection_ratio)
            global_state, avg_loss, _ = server.agg()
            scalar_train_loss = _scalar_loss(avg_loss)

            # Evaluate only when a validation set is available.
            if server.valset is not None:
                mAP, mAP50, recall, precision = (float(m) for m in server.test())
            else:
                mAP = mAP50 = recall = precision = 0.0

            # Flush per-round client caches
            server.flush()

            history["mAP"].append(mAP)
            history["mAP50"].append(mAP50)
            history["precision"].append(precision)
            history["recall"].append(recall)
            history["train_loss"].append(scalar_train_loss)
            history["round_time_sec"].append(time.time() - t0)

            p_bar.set_postfix(
                {
                    "loss": f"{scalar_train_loss:.6g}",
                    "mAP50": f"{mAP50:.6g}",
                    "mAP": f"{mAP:.6g}",
                    "precision": f"{precision:.6g}",
                    "recall": f"{recall:.6g}",
                }
            )

            # Rewrite the JSON log every round so partial results survive an interrupted run.
            with open(out_json, "w", encoding="utf-8") as f:
                json.dump(history, f, indent=4)
            p_bar.update(1)

    # Save final global model weights
    os.makedirs("./weights", exist_ok=True)
    weights_path = f"./weights/{save_name}_final.pth"
    torch.save(global_state, weights_path)
    print(f"[save] final global model weights: {weights_path}")

    # --- final plot ---
    plot_curves(res_root, history, savename=f"{save_name}_curve.png")
    print("[done] training complete.")
# Script entry point: run the complete federated training pipeline when this
# file is executed directly (not when it is imported).
if __name__ == "__main__":
    fed_run()