import os
import random
import re
import warnings
from collections import defaultdict
from typing import Any, Dict, List, Optional, Set

import matplotlib.pyplot as plt
import numpy as np
import torch

from nets import YOLO
from nets import nn
from utils.dataset import Dataset


def _image_to_label_path(img_path: str) -> str:
    """
    Convert an image path like ".../images/train2017/xxx.jpg"
    to the corresponding label path ".../labels/train2017/xxx.txt".

    Both forward-slash and backslash separators are recognised. Every
    path component named exactly "images" that is enclosed by separators
    is rewritten to "labels".
    """
    # Swap the "images" directory component for "labels", keeping whichever
    # separator characters surrounded it.
    label_path = re.sub(r"([/\\])images([/\\])", r"\1labels\2", img_path)
    # Swap the extension for .txt
    root, _ = os.path.splitext(label_path)
    return root + ".txt"


def _parse_yolo_label_file(label_path: str) -> Set[int]:
    """
    Return the set of class ids found in a YOLO .txt label file.

    Each non-blank line is expected to look like "cls cx cy w h"; only the
    first token is inspected. Malformed lines are skipped.

    Args:
        label_path: path to the label file

    Returns:
        Set of class ids (integers). Empty when the file is missing,
        unreadable, empty, or contains no parsable line.
    """
    class_ids: Set[int] = set()
    try:
        with open(label_path, "r", encoding="utf-8") as f:
            for line in f:
                parts = line.split()
                if not parts:
                    continue
                try:
                    cls = int(parts[0])
                except ValueError:
                    # Handle ids written as floats, e.g. '23.0'.
                    try:
                        cls = int(float(parts[0]))
                    except (ValueError, OverflowError):
                        # 'nan' raises ValueError and 'inf' raises
                        # OverflowError; either way skip just this line
                        # instead of discarding the whole file.
                        continue
                class_ids.add(cls)
    except (OSError, ValueError):
        # Missing/unreadable file or undecodable bytes: treat as no labels.
        return set()
    return class_ids


def _read_list_file(txt_path: str) -> List[str]:
    """
    Read one path per line from a list file.

    Paths are returned as written (absolute or relative) with surrounding
    whitespace stripped; blank lines are dropped. An empty/None path or a
    non-existent file yields an empty list.
    """
    if not txt_path or not os.path.exists(txt_path):
        return []
    with open(txt_path, "r", encoding="utf-8") as f:
        return [ln.strip() for ln in f if ln.strip()]


def divide_trainset(
    trainset_path: str,
    num_local_class: int,
    num_client: int,
    min_data: int,
    max_data: int,
    mode: str = "overlap",  # "overlap" or "disjoint"
    seed: Optional[int] = None,
) -> Dict[str, Any]:
    """
    Build a federated split from a YOLO dataset list file.

    Args:
        trainset_path: path to a .txt file containing one image path per line
            e.g. /COCO/images/train2017/1111.jpg
        num_local_class: how many distinct classes to sample for each client
        num_client: number of clients
        min_data: minimum number of images per client
        max_data: maximum number of images per client
        mode:
            "overlap"  -> images may be shared across clients
            "disjoint" -> each image is used by at most one client
        seed: optional random seed for reproducibility. When given, the
            global ``random`` module is seeded with it.

    Returns:
        trainset_divided = {
            "users": ["c_00001", ...],
            "user_data": {
                "c_00001": {"filename": [img_path, ...]},
                ...
            },
            "num_samples": [len(list_for_user1), len(list_for_user2), ...]
        }
        A client receives fewer than the requested number of images (possibly
        none) when its candidate pool is smaller than the request. Images
        whose label file is missing or has no parsable class are never
        assigned.

    Raises:
        ValueError: if a numeric argument is out of range or ``mode`` is not
            one of the two supported values.

    Example:
        dataset = divide_trainset(
            trainset_path="/COCO/train2017.txt",
            num_local_class=3,
            num_client=5,
            min_data=10,
            max_data=20,
            mode="disjoint",  # or "overlap"
            seed=42
        )

        print(dataset["users"])        # ['c_00001', ..., 'c_00005']
        print(dataset["num_samples"])  # e.g. [10, 12, 18, 9, 15]
        print(dataset["user_data"]["c_00001"]["filename"][:3])
    """
    if seed is not None:
        random.seed(seed)

    # ---- Basic validations (defensive programming) ----
    if num_client <= 0:
        raise ValueError("num_client must be > 0")
    if num_local_class <= 0:
        raise ValueError("num_local_class must be > 0")
    if min_data < 0 or max_data < 0:
        raise ValueError("min_data/max_data must be >= 0")
    if max_data < min_data:
        raise ValueError("max_data must be >= min_data")
    if mode not in {"overlap", "disjoint"}:
        raise ValueError('mode must be "overlap" or "disjoint"')

    # ---- 1) Read image list ----
    with open(trainset_path, "r", encoding="utf-8") as f:
        all_images_raw = [ln.strip() for ln in f if ln.strip()]

    # Normalize and deduplicate image paths, preserving first-seen order.
    # The exact string is kept (not joined with cwd); only normalized.
    all_images: List[str] = []
    seen: Set[str] = set()
    for p in all_images_raw:
        norm = os.path.normpath(p)
        if norm not in seen:
            seen.add(norm)
            all_images.append(norm)

    # ---- 2) Build class -> images mapping from the label files ----
    class_to_images: Dict[int, Set[str]] = defaultdict(set)
    # Images that carry at least one class; also the pool for disjoint mode.
    available_images: Set[str] = set()
    for img in all_images:
        classes = _parse_yolo_label_file(_image_to_label_path(img))
        if not classes:
            # Missing label file or no objects -> no class bucket, skip.
            continue
        available_images.add(img)
        for c in classes:
            class_to_images[c].add(img)

    users = [f"c_{i + 1:05d}" for i in range(num_client)]

    if not class_to_images:
        # No usable images found
        return {
            "users": users,
            "user_data": {u: {"filename": []} for u in users},
            "num_samples": [0 for _ in users],
        }

    all_classes: List[int] = sorted(class_to_images.keys())
    k = min(num_local_class, len(all_classes))

    # ---- 3) Allocate to clients ----
    result: Dict[str, Any] = {"users": [], "user_data": {}, "num_samples": []}

    for user_id in users:
        result["users"].append(user_id)

        # Pick the classes for this client (without replacement).
        chosen_classes = random.sample(all_classes, k)

        # Decide how many samples this client should get.
        need = min_data if min_data == max_data else random.randint(min_data, max_data)

        # Candidate pool: every image containing any of the chosen classes.
        pool_set: Set[str] = set()
        for c in chosen_classes:
            pool_set.update(class_to_images[c])
        if mode == "disjoint":
            # Restrict to images not yet handed to an earlier client.
            pool_set &= available_images

        # Sort before sampling: iteration order of a set of strings varies
        # between interpreter runs (hash randomization), which would make
        # the split irreproducible even with a fixed seed.
        pool_list = sorted(pool_set)
        if len(pool_list) <= need:
            chosen_imgs = pool_list  # take all (can be fewer than need)
        else:
            chosen_imgs = random.sample(pool_list, need)

        # Record for the user
        result["user_data"][user_id] = {"filename": chosen_imgs}
        result["num_samples"].append(len(chosen_imgs))

        if mode == "disjoint":
            available_images.difference_update(chosen_imgs)

    return result


def init_model(model_name, num_classes) -> YOLO:
    """
    Initialize the model for a specific learning task.

    Args:
        model_name: name of the model; one of "yolo_v11_n", "yolo_v11_s",
            "yolo_v11_m", "yolo_v11_l", "yolo_v11_x"
        num_classes: number of classes, forwarded to the model factory

    Returns:
        The model built by the matching factory function in ``nets.nn``.

    Raises:
        ValueError: if ``model_name`` is not one of the supported names.
    """
    supported = (
        "yolo_v11_n",
        "yolo_v11_s",
        "yolo_v11_m",
        "yolo_v11_l",
        "yolo_v11_x",
    )
    if model_name not in supported:
        raise ValueError("Model {} is not supported.".format(model_name))
    # Each supported name is also the name of its factory function in nn.
    return getattr(nn, model_name)(num_classes=num_classes)


def build_valset_if_available(cfg, params, args=None) -> Optional[Dataset]:
    """
    Try to build a validation Dataset.

    - If cfg['val_txt'] is set, use it.
    - Else if <cfg['dataset_path']>/val.txt exists, use it.
    - Else return None (testing will be skipped).

    Args:
        cfg: config dict (keys read: "val_txt", "dataset_path")
        params: params dict for Dataset
        args: optional object; its ``input_size`` attribute is used when
            present, otherwise 640

    Returns:
        Dataset or None. A warning is issued when no validation file list
        could be read.
    """
    input_size = args.input_size if args and hasattr(args, "input_size") else 640

    val_txt = cfg.get("val_txt", "")
    if not val_txt:
        ds_root = cfg.get("dataset_path", "")
        guess = os.path.join(ds_root, "val.txt") if ds_root else ""
        val_txt = guess if os.path.exists(guess) else ""

    val_files = _read_list_file(val_txt)
    if not val_files:
        warnings.warn("No validation dataset found.")
        return None

    return Dataset(
        filenames=val_files,
        input_size=input_size,
        params=params,
        # Evaluation data must not be randomly augmented, otherwise the
        # reported metrics are noisy and not comparable between rounds.
        # NOTE(review): assumes Dataset's `augment` flag only toggles
        # train-time augmentation -- confirm against utils.dataset.
        augment=False,
    )


def seed_everything(seed: int):
    """Seed NumPy's, PyTorch's and Python's global random generators."""
    np.random.seed(seed)
    torch.manual_seed(seed)
    random.seed(seed)


def plot_curves(save_dir, hist, savename="fed_yolo_curves.png"):
    """
    Plot mAP50-95, mAP50, precision, recall, and (optional) summed train loss per round.

    Each series is plotted against its own 1-based round index, so series of
    different lengths (e.g. train loss recorded while validation was
    skipped) can be drawn together. Missing or empty series are omitted.

    Args:
        save_dir: directory to save the plot (created if needed)
        hist: history dict with keys "mAP", "mAP50", "precision", "recall", "train_loss"
        savename: output filename
    """
    os.makedirs(save_dir, exist_ok=True)

    series = (
        ("mAP", "mAP50-95"),
        ("mAP50", "mAP50"),
        ("precision", "precision"),
        ("recall", "recall"),
        ("train_loss", "train_loss (sum of components)"),
    )

    fig = plt.figure()
    try:
        ax = fig.add_subplot(111)
        plotted = False
        for key, label in series:
            values = hist.get(key)
            if values is None or len(values) == 0:
                continue
            ax.plot(np.arange(1, len(values) + 1), values, label=label)
            plotted = True

        ax.set_xlabel("Global Round")
        ax.set_ylabel("Metric")
        ax.set_title("Federated YOLO - Server Metrics")
        if plotted:
            # legend() on an empty axes only emits a warning.
            ax.legend()

        out_png = os.path.join(save_dir, savename)
        fig.savefig(out_png, dpi=150, bbox_inches="tight")
    finally:
        # Release the figure; pyplot otherwise keeps it alive across calls.
        plt.close(fig)
    print(f"[plot] saved: {out_png}")