diff --git a/nets/__init__.py b/nets/__init__.py new file mode 100644 index 0000000..926ff2a --- /dev/null +++ b/nets/__init__.py @@ -0,0 +1,3 @@ +from .nn import YOLO, yolo_v11_l, yolo_v11_m, yolo_v11_s, yolo_v11_t, yolo_v11_x, yolo_v11_n + +__all__ = ["YOLO", "yolo_v11_l", "yolo_v11_m", "yolo_v11_s", "yolo_v11_t", "yolo_v11_x", "yolo_v11_n"] diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000..07a7db3 --- /dev/null +++ b/utils/__init__.py @@ -0,0 +1,3 @@ +from .dataset import * +from .fed_util import * +from .util import * \ No newline at end of file diff --git a/utils/fed_util.py b/utils/fed_util.py index fea2bc5..d0715b2 100644 --- a/utils/fed_util.py +++ b/utils/fed_util.py @@ -1,10 +1,15 @@ import os import re import random +import matplotlib.pyplot as plt +from utils.dataset import Dataset +import numpy as np +import torch from collections import defaultdict from typing import Dict, List, Optional, Set, Any from nets import nn +from nets import YOLO def _image_to_label_path(img_path: str) -> str: @@ -59,6 +64,14 @@ def _parse_yolo_label_file(label_path: str) -> Set[int]: return class_ids +def _read_list_file(txt_path: str): + """Read one path per line; keep as-is (absolute or relative).""" + if not txt_path or not os.path.exists(txt_path): + return [] + with open(txt_path, "r", encoding="utf-8") as f: + return [ln.strip() for ln in f if ln.strip()] + + def divide_trainset( trainset_path: str, num_local_class: int, @@ -230,7 +243,7 @@ def divide_trainset( return result -def init_model(model_name, num_classes): +def init_model(model_name, num_classes) -> YOLO: """ Initialize the model for a specific learning task Args: @@ -252,3 +265,74 @@ def init_model(model_name, num_classes): raise ValueError("Model {} is not supported.".format(model_name)) return model + + +def build_valset_if_available(cfg, params, args=None) -> Optional[Dataset]: + """ + Try to build a validation Dataset. + - If cfg['val_txt'] exists, use it. + - Else if /val.txt exists, use it. + - Else return None (testing will be skipped). + Args: + cfg: config dict + params: params dict for Dataset + Returns: + Dataset or None + """ + input_size = args.input_size if args and hasattr(args, "input_size") else 640 + val_txt = cfg.get("val_txt", "") + if not val_txt: + ds_root = cfg.get("dataset_path", "") + guess = os.path.join(ds_root, "val.txt") if ds_root else "" + val_txt = guess if os.path.exists(guess) else "" + + val_files = _read_list_file(val_txt) + if not val_files: + import warnings + + warnings.warn("No validation dataset found.") + return None + + return Dataset( + filenames=val_files, + input_size=input_size, + params=params, + augment=True, + ) + + +def seed_everything(seed: int): + np.random.seed(seed) + torch.manual_seed(seed) + random.seed(seed) + + +def plot_curves(save_dir, hist, savename="fed_yolo_curves.png"): + """ + Plot mAP50-95, mAP50, precision, recall, and (optional) summed train loss per round. + Args: + save_dir: directory to save the plot + hist: history dict with keys "mAP", "mAP50", "precision", "recall", "train_loss" + savename: output filename + """ + os.makedirs(save_dir, exist_ok=True) + rounds = np.arange(1, len(hist["mAP"]) + 1) + + plt.figure() + if hist["mAP"]: + plt.plot(rounds, hist["mAP"], label="mAP50-95") + if hist["mAP50"]: + plt.plot(rounds, hist["mAP50"], label="mAP50") + if hist["precision"]: + plt.plot(rounds, hist["precision"], label="precision") + if hist["recall"]: + plt.plot(rounds, hist["recall"], label="recall") + if hist["train_loss"]: + plt.plot(rounds, hist["train_loss"], label="train_loss (sum of components)") + plt.xlabel("Global Round") + plt.ylabel("Metric") + plt.title("Federated YOLO - Server Metrics") + plt.legend() + out_png = os.path.join(save_dir, savename) + plt.savefig(out_png, dpi=150, bbox_inches="tight") + print(f"[plot] saved: {out_png}")