新增__init__.py文件,fed_util.py 结构优化

This commit is contained in:
TY1667
2025-10-19 21:29:58 +08:00
parent 3f4dd07572
commit 0343a0fd30
3 changed files with 91 additions and 1 deletions

View File

@@ -1,10 +1,15 @@
import os
import re
import random
import matplotlib.pyplot as plt
from utils.dataset import Dataset
import numpy as np
import torch
from collections import defaultdict
from typing import Dict, List, Optional, Set, Any
from nets import nn
from nets import YOLO
def _image_to_label_path(img_path: str) -> str:
@@ -59,6 +64,14 @@ def _parse_yolo_label_file(label_path: str) -> Set[int]:
return class_ids
def _read_list_file(txt_path: str):
"""Read one path per line; keep as-is (absolute or relative)."""
if not txt_path or not os.path.exists(txt_path):
return []
with open(txt_path, "r", encoding="utf-8") as f:
return [ln.strip() for ln in f if ln.strip()]
def divide_trainset(
trainset_path: str,
num_local_class: int,
@@ -230,7 +243,7 @@ def divide_trainset(
return result
def init_model(model_name, num_classes):
def init_model(model_name, num_classes) -> YOLO:
"""
Initialize the model for a specific learning task
Args:
@@ -252,3 +265,74 @@ def init_model(model_name, num_classes):
raise ValueError("Model {} is not supported.".format(model_name))
return model
def build_valset_if_available(cfg, params, args=None) -> Optional[Dataset]:
"""
Try to build a validation Dataset.
- If cfg['val_txt'] exists, use it.
- Else if <dataset_path>/val.txt exists, use it.
- Else return None (testing will be skipped).
Args:
cfg: config dict
params: params dict for Dataset
Returns:
Dataset or None
"""
input_size = args.input_size if args and hasattr(args, "input_size") else 640
val_txt = cfg.get("val_txt", "")
if not val_txt:
ds_root = cfg.get("dataset_path", "")
guess = os.path.join(ds_root, "val.txt") if ds_root else ""
val_txt = guess if os.path.exists(guess) else ""
val_files = _read_list_file(val_txt)
if not val_files:
import warnings
warnings.warn("No validation dataset found.")
return None
return Dataset(
filenames=val_files,
input_size=input_size,
params=params,
augment=True,
)
def seed_everything(seed: int):
np.random.seed(seed)
torch.manual_seed(seed)
random.seed(seed)
def plot_curves(save_dir, hist, savename="fed_yolo_curves.png"):
"""
Plot mAP50-95, mAP50, precision, recall, and (optional) summed train loss per round.
Args:
save_dir: directory to save the plot
hist: history dict with keys "mAP", "mAP50", "precision", "recall", "train_loss"
savename: output filename
"""
os.makedirs(save_dir, exist_ok=True)
rounds = np.arange(1, len(hist["mAP"]) + 1)
plt.figure()
if hist["mAP"]:
plt.plot(rounds, hist["mAP"], label="mAP50-95")
if hist["mAP50"]:
plt.plot(rounds, hist["mAP50"], label="mAP50")
if hist["precision"]:
plt.plot(rounds, hist["precision"], label="precision")
if hist["recall"]:
plt.plot(rounds, hist["recall"], label="recall")
if hist["train_loss"]:
plt.plot(rounds, hist["train_loss"], label="train_loss (sum of components)")
plt.xlabel("Global Round")
plt.ylabel("Metric")
plt.title("Federated YOLO - Server Metrics")
plt.legend()
out_png = os.path.join(save_dir, savename)
plt.savefig(out_png, dpi=150, bbox_inches="tight")
print(f"[plot] saved: {out_png}")