#!/usr/bin/env python3 import os import json import yaml import time import random from tqdm import tqdm import numpy as np import torch import matplotlib.pyplot as plt from utils.dataset import Dataset from fed_algo_cs.client_base import FedYoloClient from fed_algo_cs.server_base import FedYoloServer from utils.args import args_parser # your args parser from utils.fed_util import divide_trainset # divide_trainset is yours def _read_list_file(txt_path: str): """Read one path per line; keep as-is (absolute or relative).""" if not txt_path or not os.path.exists(txt_path): return [] with open(txt_path, "r", encoding="utf-8") as f: return [ln.strip() for ln in f if ln.strip()] def _build_valset_if_available(cfg, params): """ Try to build a validation Dataset. - If cfg['val_txt'] exists, use it. - Else if /val.txt exists, use it. - Else return None (testing will be skipped). Args: cfg: config dict params: params dict for Dataset Returns: Dataset or None """ input_size = int(cfg.get("input_size", 640)) val_txt = cfg.get("val_txt", "") if not val_txt: ds_root = cfg.get("dataset_path", "") guess = os.path.join(ds_root, "val.txt") if ds_root else "" val_txt = guess if os.path.exists(guess) else "" val_files = _read_list_file(val_txt) if not val_files: return None return Dataset( filenames=val_files, input_size=input_size, params=params, augment=True, ) def _seed_everything(seed: int): np.random.seed(seed) torch.manual_seed(seed) random.seed(seed) def _plot_curves(save_dir, hist): """ Plot mAP50-95, mAP50, precision, recall, and (optional) summed train loss per round. """ os.makedirs(save_dir, exist_ok=True) rounds = np.arange(1, len(hist["mAP"]) + 1) plt.figure() if hist["mAP"]: plt.plot(rounds, hist["mAP"], label="mAP50-95") if hist["mAP50"]: plt.plot(rounds, hist["mAP50"], label="mAP50") if hist["precision"]: plt.plot(rounds, hist["precision"], label="precision") if hist["recall"]: plt.plot(rounds, hist["recall"], label="recall") if hist["train_loss"]: plt.plot(rounds, hist["train_loss"], label="train_loss (sum of components)") plt.xlabel("Global Round") plt.ylabel("Metric") plt.title("Federated YOLO - Server Metrics") plt.legend() out_png = os.path.join(save_dir, "fed_yolo_curves.png") plt.savefig(out_png, dpi=150, bbox_inches="tight") print(f"[plot] saved: {out_png}") def fed_run(): """ Main FL process: - Initialize clients & server - For each round: sequential local training -> record -> select -> aggregate - Test & flush - Record & save results, plot curves """ args_cli = args_parser() with open(args_cli.config, "r", encoding="utf-8") as f: cfg = yaml.safe_load(f) # --- params / config normalization --- # For convenience we pass the same `params` dict used by Dataset/model/loss. # Here we re-use the top-level cfg directly as params. params = dict(cfg) if "names" in cfg and isinstance(cfg["names"], dict): # Convert {0: 'uav', 1: 'car', ...} to list if you prefer list # but we can leave dict; your utils appear to accept dict pass # seeds _seed_everything(int(cfg.get("i_seed", 0))) # --- split clients' train data from a global train list --- # Expect either cfg["train_txt"] or /train.txt train_txt = cfg.get("train_txt", "") if not train_txt: ds_root = cfg.get("dataset_path", "") guess = os.path.join(ds_root, "train.txt") if ds_root else "" train_txt = guess if not train_txt or not os.path.exists(train_txt): raise FileNotFoundError( f"train.txt not found. Provide --config with 'train_txt' or ensure '{train_txt}' exists." ) split = divide_trainset( trainset_path=train_txt, num_local_class=int(cfg.get("num_local_class", 1)), num_client=int(cfg.get("num_client", 64)), min_data=int(cfg.get("min_data", 100)), max_data=int(cfg.get("max_data", 100)), mode=str(cfg.get("partition_mode", "disjoint")), # "overlap" or "disjoint" seed=int(cfg.get("i_seed", 0)), ) users = split["users"] user_data = split["user_data"] # mapping: id -> {"filename": [...]} # --- build clients --- model_name = cfg.get("model_name", "yolo_v11_n") clients = {} for uid in users: c = FedYoloClient(name=uid, model_name=model_name, params=params) c.load_trainset(user_data[uid]["filename"]) clients[uid] = c # --- build server & optional validation set --- server = FedYoloServer(client_list=users, model_name=model_name, params=params) valset = _build_valset_if_available(cfg, params) # valset is a Dataset class, not data loader if valset is not None: server.load_valset(valset) # --- push initial global weights --- global_state = server.state_dict() # --- args object for client.train() --- # args_train = _make_args_for_client(cfg, args_cli) # --- history recorder --- history = { "mAP": [], "mAP50": [], "precision": [], "recall": [], "train_loss": [], # scalar sum of client-weighted dict losses "round_time_sec": [], } # --- main FL loop --- num_round = int(cfg.get("num_round", 50)) connection_ratio = float(cfg.get("connection_ratio", 1.0)) # e.g., 1.0 = all clients res_root = cfg.get("res_root", "results") os.makedirs(res_root, exist_ok=True) for rnd in tqdm(range(num_round), desc="main federal loop round"): t0 = time.time() # Local training (sequential over all users) for uid in tqdm(users, desc=f"Round {rnd + 1} local training", leave=False): client = clients[uid] # FedYoloClient instance client.update(global_state) # load global weights state_dict, n_data, loss_dict = client.train(args_cli) # local training server.rec(uid, state_dict, n_data, loss_dict) # Select a fraction for aggregation (FedAvg subset if desired) server.select_clients(connection_ratio=connection_ratio) # Aggregate global_state, avg_loss_dict, _ = server.agg() # Compute a scalar train loss for plotting (sum of components) scalar_train_loss = float(sum(avg_loss_dict.values())) if avg_loss_dict else 0.0 # Test (if valset provided) test_metrics = server.test(args_cli) if server.valset is not None else {} mAP = float(test_metrics.get("mAP", 0.0)) mAP50 = float(test_metrics.get("mAP50", 0.0)) precision = float(test_metrics.get("precision", 0.0)) recall = float(test_metrics.get("recall", 0.0)) # Flush per-round client caches server.flush() # Record & log history["mAP"].append(mAP) history["mAP50"].append(mAP50) history["precision"].append(precision) history["recall"].append(recall) history["train_loss"].append(scalar_train_loss) history["round_time_sec"].append(time.time() - t0) print( f"[round {rnd + 1:04d}] " f"loss={scalar_train_loss:.4f} mAP50-95={mAP:.4f} mAP50={mAP50:.4f} " f"P={precision:.4f} R={recall:.4f}" ) # Save running JSON (resumable logs) save_name = ( f"[{cfg.get('fed_algo', 'FedAvg')},{cfg.get('model_name', 'yolo')}," f"{cfg.get('num_local_epoch', cfg.get('client', {}).get('num_local_epoch', 1))}," f"{cfg.get('num_local_class', 2)}," f"{cfg.get('i_seed', 0)}]" ) out_json = os.path.join(res_root, save_name + ".json") with open(out_json, "w", encoding="utf-8") as f: json.dump(history, f, indent=2) # --- final plot --- _plot_curves(res_root, history) print("[done] training complete.") if __name__ == "__main__": fed_run()