#!/usr/bin/env python3
"""Entry point for sequential federated training of a YOLO model."""
import json
import os
import time

import torch
import yaml
from tqdm import tqdm

from utils.fed_util import (
    build_valset_if_available,
    divide_trainset,
    plot_curves,
    seed_everything,
)
from fed_algo_cs.client_base import FedYoloClient
from fed_algo_cs.server_base import FedYoloServer
from utils.args import args_parser


def _load_config(path: str) -> dict:
    """Read the YAML config at *path* and return it as a dict.

    An empty YAML document parses to ``None``; it is treated as an empty
    config so later ``cfg.get`` calls fall back to their defaults.

    Raises:
        ValueError: if the top-level YAML value is not a mapping.
    """
    with open(path, "r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f) or {}
    if not isinstance(cfg, dict):
        raise ValueError(
            f"Config '{path}' must be a YAML mapping, got {type(cfg).__name__}."
        )
    return cfg


def _resolve_train_txt(cfg: dict) -> str:
    """Return the path of the global training list.

    Uses ``cfg["train_txt"]`` when set, otherwise
    ``<cfg["dataset_path"]>/train2017.txt``.

    Raises:
        FileNotFoundError: if no path can be determined or it does not exist.
    """
    train_txt = cfg.get("train_txt", "")
    if not train_txt:
        ds_root = cfg.get("dataset_path", "")
        train_txt = os.path.join(ds_root, "train2017.txt") if ds_root else ""
    if not train_txt or not os.path.exists(train_txt):
        raise FileNotFoundError(
            f"train2017.txt not found. Provide --config with 'train_txt' or ensure '{train_txt}' exists."
        )
    return train_txt


def _scalar_loss(avg_loss) -> float:
    """Collapse the aggregated training loss into one float for logging.

    ``None``/zero/empty values become ``0.0``. A dict is summed over its
    values (the history is documented as holding a sum of loss components);
    anything else is converted with ``float`` so the value is always
    JSON-serializable and formattable.
    """
    if not avg_loss:
        return 0.0
    if isinstance(avg_loss, dict):
        return float(sum(float(v) for v in avg_loss.values()))
    return float(avg_loss)


def fed_run() -> None:
    """
    Main FL process:
    - Initialize clients & server
    - For each round: sequential local training -> record -> select -> aggregate
    - Test & flush
    - Record & save results, plot curves

    Reads the config path from the CLI arguments (``args_parser().config``).
    Side effects: writes a per-round JSON history into ``cfg["res_root"]``
    (default ``results``), saves the final global weights under ``./weights``
    and calls ``plot_curves`` on the collected history.
    """
    args_cli = args_parser()
    # TODO: cfg and params should not be separately defined
    # The top-level config dict is re-used directly as the `params` object
    # handed to the clients, the server and the validation-set builder.
    cfg = _load_config(args_cli.config)

    # Resolve every setting once so that training and the output file names
    # are guaranteed to describe the same run.
    seed = int(cfg.get("i_seed", 0))
    num_client = int(cfg.get("num_client", 64))
    num_local_class = int(cfg.get("num_local_class", 1))
    num_round = int(cfg.get("num_round", 50))
    connection_ratio = float(cfg.get("connection_ratio", 1.0))  # e.g., 1.0 = all clients
    model_name = cfg.get("model_name", "yolo_v11_n")
    fed_algo = cfg.get("fed_algo", "FedAvg")
    res_root = cfg.get("res_root", "results")

    seed_everything(seed)

    # --- split clients' train data from a global train list ---
    split = divide_trainset(
        trainset_path=_resolve_train_txt(cfg),
        num_local_class=num_local_class,
        num_client=num_client,
        min_data=int(cfg.get("min_data", 100)),
        max_data=int(cfg.get("max_data", 100)),
        mode=str(cfg.get("partition_mode", "overlap")),  # "overlap" or "disjoint"
        seed=seed,
    )
    users = split["users"]
    user_data = split["user_data"]  # mapping: id -> {"filename": [...]}

    # --- build clients ---
    clients = {}
    for uid in users:
        client = FedYoloClient(name=uid, model_name=model_name, params=cfg)
        client.load_trainset(user_data[uid]["filename"])
        clients[uid] = client

    # --- build server & optional validation set ---
    server = FedYoloServer(client_list=users, model_name=model_name, params=cfg)
    # valset is a Dataset class, not data loader
    valset = build_valset_if_available(cfg, params=cfg, args=args_cli, val_name="val2017")
    if valset is not None:
        server.load_valset(valset)

    # --- initial global weights ---
    global_state = server.state_dict()

    # --- history recorder ---
    history = {
        "mAP": [],
        "mAP50": [],
        "precision": [],
        "recall": [],
        "train_loss": [],  # scalar train loss per round
        "round_time_sec": [],
    }

    # --- output locations ---
    # Built before the loop: it does not change between rounds and must exist
    # even when num_round is 0. The model name is interpolated as plain text
    # so the file name contains no list brackets or quotes.
    os.makedirs(res_root, exist_ok=True)
    save_name = (
        f"{fed_algo}_{model_name}_{num_client}c_{num_local_class}cls_"
        f"{num_round}r_{connection_ratio:.2f}cr_{seed}s"
    )
    out_json = os.path.join(res_root, save_name + ".json")

    # tqdm logging
    col_fmt = "%10s" * 2
    tqdm.write("\n" + col_fmt % ("Round", "client"))

    # --- main FL loop ---
    # The context manager closes the bar even if a round raises.
    with tqdm(total=num_round, ncols=160, ascii="->>") as p_bar:
        for rnd in range(num_round):
            t0 = time.time()

            # Local training (sequential over all users)
            for uid in users:
                p_bar.set_description_str(col_fmt % (f"{rnd + 1}/{num_round}", f"{uid}"))
                client = clients[uid]  # FedYoloClient instance
                client.update(global_state)  # load global weights
                state_dict, n_data, train_loss = client.train(args_cli)  # local training
                server.rec(uid, state_dict, n_data, train_loss)

            # Select a fraction for aggregation (FedAvg subset if desired)
            server.select_clients(connection_ratio=connection_ratio)

            # Aggregate
            global_state, avg_loss, _ = server.agg()
            scalar_train_loss = _scalar_loss(avg_loss)

            # Test (if valset provided)
            if server.valset is not None:
                mAP, mAP50, recall, precision = server.test()
            else:
                mAP, mAP50, recall, precision = 0.0, 0.0, 0.0, 0.0

            # Flush per-round client caches
            server.flush()

            # Record & log
            history["mAP"].append(mAP)
            history["mAP50"].append(mAP50)
            history["precision"].append(precision)
            history["recall"].append(recall)
            history["train_loss"].append(scalar_train_loss)
            history["round_time_sec"].append(time.time() - t0)

            p_bar.set_postfix(
                {
                    "loss": f"{scalar_train_loss:.6g}",
                    "mAP50": f"{mAP50:.6g}",
                    "mAP": f"{mAP:.6g}",
                    "precision": f"{precision:.6g}",
                    "recall": f"{recall:.6g}",
                }
            )

            # Rewrite the running JSON log every round so partial results
            # survive an interrupted run.
            with open(out_json, "w", encoding="utf-8") as f:
                json.dump(history, f, indent=4)

            p_bar.update(1)

    # Save final global model weights
    os.makedirs("./weights", exist_ok=True)
    torch.save(global_state, f"./weights/{save_name}_final.pth")
    print(f"[save] final global model weights: ./weights/{save_name}_final.pth")

    # --- final plot ---
    plot_curves(res_root, history, savename=f"{save_name}_curve.png")
    print("[done] training complete.")


if __name__ == "__main__":
    fed_run()