Compare commits

...

4 Commits

8 changed files with 157 additions and 80 deletions

View File

@@ -19,9 +19,6 @@ nohup bash fed_run.sh 1 > train.log 2>&1 &
 - Implement FedProx
 - Implement SCAFFOLD
 - Implement FedNova
-- Add more YOLO versions (e.g., YOLOv8, YOLOv5, etc.)
-- Implement YOLOv8
-- Implement YOLOv5

 # references
 [PyTorch Federated Learning](https://github.com/rruisong/pytorch_federated_learning)

View File

@@ -17,8 +17,8 @@ local_batch_size: 32 # local training batch size
 val_batch_size: 128 # validation batch size
 num_workers: 8 # number of data loader workers
-min_data: 1700 # minimum number of images per client
-max_data: 1800 # maximum number of images per client
+min_data: 1800 # minimum number of images per client
+max_data: 1900 # maximum number of images per client
 partition_mode: "overlap" # "overlap" or "disjoint"
 connection_ratio: 1 # connection ratio, e.g., 1.0 means all clients

View File

@@ -3,22 +3,22 @@ fed_algo: "FedAvg" # federated learning algorithm
 model_name: "yolo_v11_n" # yolo_v11_n, yolo_v11_t, yolo_v11_s, yolo_v11_m, yolo_v11_l, yolo_v11_x
 i_seed: 202509 # initial random seed
-num_client: 100 # total number of clients
-num_round: 500 # total number of communication rounds
+num_client: 36 # total number of clients
+num_round: 50 # total number of communication rounds
 num_local_class: 1 # number of classes per client
 res_root: "results" # root directory for results
-dataset_path: "/home/image1325/ssd1/dataset/uav/"
+dataset_path: "/mnt/DATA/uav/"
 # train_txt: "train.txt" # path to training set txt file
 # val_txt: "val.txt" # path to validation set txt file
 # test_txt: "test.txt" # path to test set txt file
-local_batch_size: 32 # local training batch size
-val_batch_size: 16 # validation batch size
-num_workers: 4 # number of data loader workers
-min_data: 640 # minimum number of images per client
-max_data: 720 # maximum number of images per client
+local_batch_size: 36 # local training batch size
+val_batch_size: 128 # validation batch size
+num_workers: 8 # number of data loader workers
+min_data: 385 # minimum number of images per client
+max_data: 400 # maximum number of images per client
 partition_mode: "overlap" # "overlap" or "disjoint"
 connection_ratio: 1 # connection ratio, e.g., 1.0 means all clients
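
A quick check of the data budget these new values imply (a sketch with numbers copied from the config above; the `disjoint` reading follows the `partition_mode` comment and is otherwise an assumption about the splitter's behavior):

```python
# Sketch: data budget implied by the new uav_cfg.yaml values
num_client = 36
min_data, max_data = 385, 400

# "overlap" lets clients share images, so any train.txt length works;
# "disjoint" uses each image at most once, so train.txt needs at least
# num_client * min_data entries.
print(num_client * min_data, num_client * max_data)  # 13860 14400
```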

View File

@@ -64,7 +64,7 @@ class FedYoloClient(object):
         """
         Load the local training dataset
         Args:
-            :param train_dataset: Training dataset
+            train_dataset: Training dataset
         """
         self.train_dataset = train_dataset
         self.n_data = len(self.train_dataset)
@@ -72,8 +72,9 @@ class FedYoloClient(object):
     def update(self, Global_model_state_dict):
         """
         Update the local model with the global model parameters
+
         Args:
-            :param Global_model_state_dict: State dictionary of the global model
+            Global_model_state_dict: State dictionary of the global model
         """
         if not hasattr(self, "model") or self.model is None:
@@ -85,7 +86,15 @@ class FedYoloClient(object):
     def train(self, args) -> tuple[dict[str, torch.Tensor], int, float]:
         """
         Train the local model.
-        Returns: (state_dict, n_data, avg_loss_per_image)
+
+        Args:
+            args: training arguments including
+
+        Returns:
+            (state_dict, n_data, avg_loss_per_image): A tuple including:
+                - state_dict: State dictionary of the trained local model
+                - n_data: Number of training data samples
+                - avg_loss_per_image: Average training loss per image over all epochs
         """
         # ---- Dist init (if any) ----
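
The tuple documented here is exactly what `server.rec(uid, state_dict, n_data, train_loss)` consumes in the training script. As a reading aid, a minimal sketch of how FedAvg weights such updates by `n_data` (illustrative only; the repo's `FedYoloServer.agg` is the authoritative implementation):

```python
import torch

def fedavg(updates: list[tuple[dict[str, torch.Tensor], int, float]]):
    """updates: one (state_dict, n_data, avg_loss_per_image) tuple per client."""
    total = sum(n for _, n, _ in updates)
    # Data-weighted average of every tensor in the state dict
    agg_state = {
        k: sum(sd[k].float() * (n / total) for sd, n, _ in updates)
        for k in updates[0][0]
    }
    # Matching data-weighted average of the per-image losses
    avg_loss = sum(loss * (n / total) for _, n, loss in updates)
    return agg_state, total, avg_loss
```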

View File

@@ -11,13 +11,13 @@ class FedYoloServer(object):
     def __init__(self, client_list, model_name, params):
         """
         Federated YOLO Server
-        Args:
+        Attributes:
             client_list: list of connected clients
             model_name: YOLO model architecture name
             params: dict of hyperparameters (must include 'names')
         """
         # Track client updates
-        self.client_state = {}
+        self.client_state: dict[str, dict[str, torch.Tensor]] = {}
         self.client_loss = {}
         self.client_n_data = {}
         self.selected_clients = []
@@ -64,14 +64,19 @@ class FedYoloServer(object):
             self.selected_clients.append(client_id)
             self.n_data += self.client_n_data[client_id]

+    # TODO: skip the layer which can not be learned locally
     @torch.no_grad()
-    def agg(self):
+    def agg(self, skip_bn_layer: bool = False):
         """
         Server aggregates the local updates from selected clients using FedAvg.
-        :return: model_state: aggregated model weights
-        :return: avg_loss: weighted average training loss across selected clients
-        :return: n_data: total number of data points across selected clients
+
+        Args:
+            skip_bn_layer: whether to skip batch normalization layers during aggregation
+
+        Returns:
+            model_state: aggregated model weights
+            avg_loss: weighted average training loss across selected clients
+            n_data: total number of data points across selected clients
         """
         if len(self.selected_clients) == 0 or self.n_data == 0:
             import warnings
@@ -144,11 +149,13 @@ class FedYoloServer(object):
 def test(valset: Dataset, params, model: YOLO, batch_size: int = 200) -> tuple[float, float, float, float]:
     """
     Evaluate the model on the validation dataset.
+
     Args:
         valset: validation dataset
         params: dict of parameters (must include 'names')
         model: YOLO model to evaluate
         batch_size: batch size for evaluation
+
     Returns:
         dict with evaluation metrics (tp, fp, m_pre, m_rec, map50, mean_ap)
     """
@@ -214,7 +221,9 @@ def test(valset: Dataset, params, model: YOLO, batch_size: int = 200) -> tuple[float, float, float, float]:
         # Compute metrics
         metrics = [torch.cat(x, dim=0).cpu().numpy() for x in zip(*metrics)]  # to numpy
         if len(metrics) and metrics[0].any():
-            tp, fp, m_pre, m_rec, map50, mean_ap = util.compute_ap(*metrics, plot=False, names=params["names"])
+            tp, fp, m_pre, m_rec, map50, mean_ap = util.compute_ap(
+                *metrics, plot=False, names=params["names"]
+            )  # set plot=True to plot metric curve
         # Print results
         # print(("%10s" + "%10.3g" * 4) % ("", m_pre, m_rec, map50, mean_ap))
         # Return results
# Return results # Return results

View File

@@ -5,12 +5,16 @@ import yaml
 import time
 from tqdm import tqdm
 import torch
+import csv
+import copy
 from utils.fed_util import build_valset_if_available, seed_everything, plot_curves
 from fed_algo_cs.client_base import FedYoloClient
 from fed_algo_cs.server_base import FedYoloServer
 from utils.args import args_parser  # args parser
 from utils.fed_util import divide_trainset  # divide_trainset
+from utils import util
+from utils.fed_util import prepare_result_dir


 def fed_run():
@@ -26,11 +30,6 @@ def fed_run():
     with open(args_cli.config, "r", encoding="utf-8") as f:
         cfg = yaml.safe_load(f)

-    # --- params / config normalization ---
-    # For convenience we pass the same `params` dict used by Dataset/model/loss.
-    # Here we re-use the top-level cfg directly as params.
-    # params = dict(cfg)
-
     if "names" in cfg and isinstance(cfg["names"], dict):
         # Convert {0: 'uav', 1: 'car', ...} to list if you prefer list
         # but we can leave dict; your utils appear to accept dict
@@ -39,6 +38,9 @@ def fed_run():
     # seeds
     seed_everything(int(cfg.get("i_seed", 0)))

+    # result directory
+    res_root, weights_root = prepare_result_dir(base_root=cfg.get("res_root", "results"))
+
     # --- split clients' train data from a global train list ---
     # Expect either cfg["train_txt"] or <dataset_path>/train.txt
     train_txt = cfg.get("train_txt", "")
@@ -67,7 +69,7 @@ def fed_run():
     # --- build clients ---
     model_name = cfg.get("model_name", "yolo_v11_n")
-    clients = {}
+    clients: dict[str, FedYoloClient] = {}
     for uid in users:
         c = FedYoloClient(name=uid, model_name=model_name, params=cfg)
@@ -84,9 +86,6 @@ def fed_run():
     # --- push initial global weights ---
     global_state = server.state_dict()

-    # --- args object for client.train() ---
-    # args_train = _make_args_for_client(cfg, args_cli)
-
     # --- history recorder ---
     history = {
         "mAP": [],
@@ -98,16 +97,16 @@ def fed_run():
     }

     # --- main FL loop ---
+    best = 0.0  # best mAP
     num_round = int(cfg.get("num_round", 50))
     connection_ratio = float(cfg.get("connection_ratio", 1.0))  # e.g., 1.0 = all clients
-    res_root = cfg.get("res_root", "results")
-    os.makedirs(res_root, exist_ok=True)

     # tqdm logging
     header = ("%10s" * 2) % ("Round", "client")
     tqdm.write("\n" + header)
     p_bar = tqdm(total=num_round, ncols=160, ascii="->>")

+    # train loop
     for rnd in range(num_round):
         t0 = time.time()
         # Local training (sequential over all users)
@@ -115,7 +114,7 @@ def fed_run():
             # tqdm desc update
             p_bar.set_description_str(("%10s" * 2) % (f"{rnd + 1}/{num_round}", f"{uid}"))
-            client = clients[uid]  # FedYoloClient instance
+            client: FedYoloClient = clients[uid]  # FedYoloClient instance
             client.update(global_state)  # load global weights
             state_dict, n_data, train_loss = client.train(args_cli)  # local training
             server.rec(uid, state_dict, n_data, train_loss)
@@ -129,51 +128,82 @@ def fed_run():
         # Compute a scalar train loss for plotting (sum of components)
         scalar_train_loss = avg_loss if avg_loss else 0.0

-        # Test (if valset provided)
-        mAP, mAP50, recall, precision = server.test() if server.valset is not None else (0.0, 0.0, 0.0, 0.0)
-
-        # Flush per-round client caches
-        server.flush()
-
-        # Record & log
-        history["mAP"].append(mAP)
-        history["mAP50"].append(mAP50)
-        history["precision"].append(precision)
-        history["recall"].append(recall)
-        history["train_loss"].append(scalar_train_loss)
-        history["round_time_sec"].append(time.time() - t0)
-
-        # Log GPU memory usage
-        # gpu_mem = f"{torch.cuda.memory_reserved() / 1e9:.2f}G" if torch.cuda.is_available() else "0.00G"
-        # tqdm update
-        desc = {
-            "loss": f"{scalar_train_loss:.6g}",
-            "mAP50": f"{mAP50:.6g}",
-            "mAP": f"{mAP:.6g}",
-            "precision": f"{precision:.6g}",
-            "recall": f"{recall:.6g}",
-            # "gpu_mem": gpu_mem,
-        }
-        p_bar.set_postfix(desc)
-
-        # Save running JSON (resumable logs)
-        save_name = f"{cfg.get('fed_algo', 'FedAvg')}_{[cfg.get('model_name', 'yolo')]}_{cfg.get('num_client', 0)}c_{cfg.get('num_local_class', 1)}cls_{cfg.get('num_round', 0)}r_{cfg.get('connection_ratio', 1):.2f}cr_{cfg.get('i_seed', 0)}s"
-        out_json = os.path.join(res_root, save_name + ".json")
-        with open(out_json, "w", encoding="utf-8") as f:
-            json.dump(history, f, indent=4)
+        if args_cli.local_rank == 0:
+            # Test (if valset provided)
+            mAP, mAP50, recall, precision = server.test() if server.valset is not None else (0.0, 0.0, 0.0, 0.0)
+            if mAP > best:
+                best = mAP
+
+            # Flush per-round client caches
+            server.flush()
+
+            # Record & log
+            history["mAP"].append(mAP)
+            history["mAP50"].append(mAP50)
+            history["precision"].append(precision)
+            history["recall"].append(recall)
+            history["train_loss"].append(scalar_train_loss)
+            history["round_time_sec"].append(time.time() - t0)
+
+            # Log GPU memory usage
+            # gpu_mem = f"{torch.cuda.memory_reserved() / 1e9:.2f}G" if torch.cuda.is_available() else "0.00G"
+            # tqdm update
+            desc = {
+                "loss": f"{scalar_train_loss:.6g}",
+                "mAP50": f"{mAP50:.6g}",
+                "mAP": f"{mAP:.6g}",
+                "precision": f"{precision:.6g}",
+                "recall": f"{recall:.6g}",
+                # "gpu_mem": gpu_mem,
+            }
+            p_bar.set_postfix(desc)
+
+            # Save running JSON (resumable logs)
+            # save_name = f"{cfg.get('fed_algo', 'FedAvg')}_{[cfg.get('model_name', 'yolo')]}_{cfg.get('num_client', 0)}c_{cfg.get('num_local_class', 1)}cls_{cfg.get('num_round', 0)}r_{cfg.get('connection_ratio', 1):.2f}cr_{cfg.get('i_seed', 0)}s"
+            # out_json = os.path.join(res_root, save_name + ".json")
+            # with open(out_json, "w", encoding="utf-8") as f:
+            #     json.dump(history, f, indent=4)
+
+            # Use csv file to save running metrics
+            row = {
+                "round": rnd + 1,
+                "loss": f"{scalar_train_loss:.3f}",
+                "mAP": f"{mAP:.3f}",
+                "mAP50": f"{mAP50:.3f}",
+                "precision": f"{precision:.3f}",
+                "recall": f"{recall:.3f}",
+                "sec": f"{time.time() - t0:.1f}",
+            }
+            # log to csv
+            out_csv = os.path.join(res_root, "step.csv")
+            fieldnames = ["round", "loss", "mAP", "mAP50", "precision", "recall", "sec"]
+            mode = "w" if rnd == 0 else "a"
+            with open(file=out_csv, mode=mode, newline="", encoding="utf-8") as f:
+                writer = csv.DictWriter(f, fieldnames=fieldnames)
+                if rnd == 0:
+                    writer.writeheader()  # write header only once
+                writer.writerow(row)
+
+            # Save final global model weights
+            # FIXME: model saving is not adaptive; YOLOv11-pt specific
+            save_model = {"config": cfg, "model": copy.deepcopy(global_state if global_state else None)}
+            torch.save(save_model, f"{weights_root}/last.pt")
+            if best == mAP:
+                torch.save(save_model, f"{weights_root}/best.pt")
+            del save_model
+            # print(f"[save] final global model weights: {weights_root}/last.pt")

         p_bar.update(1)
     p_bar.close()

-    # Save final global model weights
-    if not os.path.exists("./weights"):
-        os.makedirs("./weights", exist_ok=True)
-    torch.save(global_state, f"./weights/{save_name}_final.pth")
-    print(f"[save] final global model weights: ./weights/{save_name}_final.pth")
+    if args_cli.local_rank == 0:
+        util.strip_optimizer(f"{weights_root}/best.pt")
+        util.strip_optimizer(f"{weights_root}/last.pt")

     # --- final plot ---
-    plot_curves(res_root, history, savename=f"{save_name}_curve.png")
+    plot_curves(res_root, history, savename="train_curve.png")
     print("[done] training complete.")

View File

@@ -7,7 +7,7 @@ def args_parser():
     parser.add_argument("--epochs", type=int, default=16, help="number of rounds of local training")
     parser.add_argument("--input_size", type=int, default=640, help="image input size")
-    parser.add_argument("--config", type=str, default="./config/coco_cfg.yaml", help="Path to YAML config")
+    parser.add_argument("--config", type=str, default="./config/uav_cfg.yaml", help="Path to YAML config")

     args = parser.parse_args()

View File

@@ -7,6 +7,7 @@ import numpy as np
 import torch
 from collections import defaultdict
 from typing import Dict, List, Optional, Set, Any
+import time

 from nets import nn
 from nets import YOLO
@@ -30,8 +31,10 @@ def _parse_yolo_label_file(label_path: str) -> Set[int]:
     Return a set of class_ids found in a YOLO .txt label file.
     Empty file -> empty set. Missing file -> empty set.
     Robust to blank lines / trailing spaces.
+
     Args:
         label_path: path to the label file
+
     Returns:
         set of class IDs (integers) found in the file
     """
@@ -85,7 +88,7 @@ def divide_trainset(
     Build a federated split from a YOLO dataset list file.

     Args:
-        trainset_path: path to a .txt file containing one image path per line
+        trainset_path (str): path to a .txt file containing one image path per line
            e.g. /COCO/images/train2017/1111.jpg
         num_local_class: how many distinct classes to sample for each client
         num_client: number of clients
@@ -95,7 +98,9 @@
             "disjoint" -> each image is used by at most one client
         seed: optional random seed for reproducibility

-    Returns:
+    Returns::
+        >>> \
         trainset_divided = {
             "users": ["c_00001", ...],
             "user_data": {
@@ -105,7 +110,9 @@
             "num_samples": [len(list_for_user1), len(list_for_user2), ...]
         }

-    Example:
+    Example::
+        >>> \
         dataset = divide_trainset(
             trainset_path="/COCO/train2017.txt",
             num_local_class=3,
@@ -114,11 +121,11 @@
             max_data=20,
             mode="disjoint",  # or "overlap"
             seed=42
         )
-        print(dataset["users"])  # ['c_00001', ..., 'c_00005']
-        print(dataset["num_samples"])  # e.g. [10, 12, 18, 9, 15]
-        print(dataset["user_data"]["c_00001"]["filename"][:3])
+        >>> print(dataset["users"])  # ['c_00001', ..., 'c_00005']
+        >>> print(dataset["num_samples"])  # e.g. [10, 12, 18, 9, 15]
+        >>> print(dataset["user_data"]["c_00001"]["filename"][:3])
     """
     if seed is not None:
         random.seed(seed)
@@ -247,8 +254,11 @@ def init_model(model_name, num_classes) -> YOLO:
     """
     Initialize the model for a specific learning task

     Args:
-        :param model_name: Name of the model
-        :param num_classes: Number of classes
+        model_name: Name of the model
+        num_classes: Number of classes
+
+    Returns:
+        model: YOLO model instance
     """
     model = None
     if model_name == "yolo_v11_n":
@@ -273,11 +283,13 @@ def build_valset_if_available(cfg, params, args=None, val_name: str = "val2017")
     - If cfg['val_txt'] exists, use it.
     - Else if <dataset_path>/val.txt exists, use it.
     - Else return None (testing will be skipped).
+
     Args:
         cfg: config dict
         params: params dict for Dataset
         args: optional args object (for input_size)
         val_name: name of the validation set folder with no prefix (default: "val2017")
+
     Returns:
         Dataset or None
     """
@@ -344,3 +356,23 @@ def plot_curves(save_dir, hist, savename="fed_yolo_curves.png"):
     out_png = os.path.join(save_dir, savename)
     plt.savefig(out_png, dpi=150, bbox_inches="tight")
     print(f"[plot] saved: {out_png}")
+
+
+def prepare_result_dir(base_root: str = "results"):
+    """
+    Prepare result directories for saving outputs.
+
+    Args:
+        base_root (str): base directory for results.
+
+    Returns:
+        (res_dir, weights_dir) (str, str): paths to the result directory and the weights directory.
+    """
+    os.makedirs(base_root, exist_ok=True)
+    timestamp = time.strftime("%Y%m%d_%H%M%S")
+    res_dir = os.path.join(base_root, f"result_{timestamp}")
+    weights_dir = os.path.join(res_dir, f"weight_{timestamp}")
+    os.makedirs(res_dir, exist_ok=True)
+    os.makedirs(weights_dir, exist_ok=True)
+    print(f"[INFO] Saving results to: {res_dir}")
+    return res_dir, weights_dir
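
Usage sketch (the timestamps below are illustrative):

```python
res_dir, weights_dir = prepare_result_dir("results")
# res_dir     -> results/result_20250901_120000
# weights_dir -> results/result_20250901_120000/weight_20250901_120000
# fed_run writes step.csv and train_curve.png into res_dir,
# and last.pt / best.pt into weights_dir.
```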