新增__init__.py文件,fed_util.py 结构优化
This commit is contained in:
3
utils/__init__.py
Normal file
3
utils/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from .dataset import *
|
||||
from .fed_util import *
|
||||
from .util import *
|
@@ -1,10 +1,15 @@
|
||||
import os
|
||||
import re
|
||||
import random
|
||||
import matplotlib.pyplot as plt
|
||||
from utils.dataset import Dataset
|
||||
import numpy as np
|
||||
import torch
|
||||
from collections import defaultdict
|
||||
from typing import Dict, List, Optional, Set, Any
|
||||
|
||||
from nets import nn
|
||||
from nets import YOLO
|
||||
|
||||
|
||||
def _image_to_label_path(img_path: str) -> str:
|
||||
@@ -59,6 +64,14 @@ def _parse_yolo_label_file(label_path: str) -> Set[int]:
|
||||
return class_ids
|
||||
|
||||
|
||||
def _read_list_file(txt_path: str):
|
||||
"""Read one path per line; keep as-is (absolute or relative)."""
|
||||
if not txt_path or not os.path.exists(txt_path):
|
||||
return []
|
||||
with open(txt_path, "r", encoding="utf-8") as f:
|
||||
return [ln.strip() for ln in f if ln.strip()]
|
||||
|
||||
|
||||
def divide_trainset(
|
||||
trainset_path: str,
|
||||
num_local_class: int,
|
||||
@@ -230,7 +243,7 @@ def divide_trainset(
|
||||
return result
|
||||
|
||||
|
||||
def init_model(model_name, num_classes):
|
||||
def init_model(model_name, num_classes) -> YOLO:
|
||||
"""
|
||||
Initialize the model for a specific learning task
|
||||
Args:
|
||||
@@ -252,3 +265,74 @@ def init_model(model_name, num_classes):
|
||||
raise ValueError("Model {} is not supported.".format(model_name))
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def build_valset_if_available(cfg, params, args=None) -> Optional[Dataset]:
|
||||
"""
|
||||
Try to build a validation Dataset.
|
||||
- If cfg['val_txt'] exists, use it.
|
||||
- Else if <dataset_path>/val.txt exists, use it.
|
||||
- Else return None (testing will be skipped).
|
||||
Args:
|
||||
cfg: config dict
|
||||
params: params dict for Dataset
|
||||
Returns:
|
||||
Dataset or None
|
||||
"""
|
||||
input_size = args.input_size if args and hasattr(args, "input_size") else 640
|
||||
val_txt = cfg.get("val_txt", "")
|
||||
if not val_txt:
|
||||
ds_root = cfg.get("dataset_path", "")
|
||||
guess = os.path.join(ds_root, "val.txt") if ds_root else ""
|
||||
val_txt = guess if os.path.exists(guess) else ""
|
||||
|
||||
val_files = _read_list_file(val_txt)
|
||||
if not val_files:
|
||||
import warnings
|
||||
|
||||
warnings.warn("No validation dataset found.")
|
||||
return None
|
||||
|
||||
return Dataset(
|
||||
filenames=val_files,
|
||||
input_size=input_size,
|
||||
params=params,
|
||||
augment=True,
|
||||
)
|
||||
|
||||
|
||||
def seed_everything(seed: int):
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
random.seed(seed)
|
||||
|
||||
|
||||
def plot_curves(save_dir, hist, savename="fed_yolo_curves.png"):
|
||||
"""
|
||||
Plot mAP50-95, mAP50, precision, recall, and (optional) summed train loss per round.
|
||||
Args:
|
||||
save_dir: directory to save the plot
|
||||
hist: history dict with keys "mAP", "mAP50", "precision", "recall", "train_loss"
|
||||
savename: output filename
|
||||
"""
|
||||
os.makedirs(save_dir, exist_ok=True)
|
||||
rounds = np.arange(1, len(hist["mAP"]) + 1)
|
||||
|
||||
plt.figure()
|
||||
if hist["mAP"]:
|
||||
plt.plot(rounds, hist["mAP"], label="mAP50-95")
|
||||
if hist["mAP50"]:
|
||||
plt.plot(rounds, hist["mAP50"], label="mAP50")
|
||||
if hist["precision"]:
|
||||
plt.plot(rounds, hist["precision"], label="precision")
|
||||
if hist["recall"]:
|
||||
plt.plot(rounds, hist["recall"], label="recall")
|
||||
if hist["train_loss"]:
|
||||
plt.plot(rounds, hist["train_loss"], label="train_loss (sum of components)")
|
||||
plt.xlabel("Global Round")
|
||||
plt.ylabel("Metric")
|
||||
plt.title("Federated YOLO - Server Metrics")
|
||||
plt.legend()
|
||||
out_png = os.path.join(save_dir, savename)
|
||||
plt.savefig(out_png, dpi=150, bbox_inches="tight")
|
||||
print(f"[plot] saved: {out_png}")
|
||||
|
Reference in New Issue
Block a user