Compare commits

4 commits: c7a3d34d88 ... fdb70869f9

| SHA1 | Author | Date |
|---|---|---|
| fdb70869f9 | | |
| 291b82bec3 | | |
| c7afef2dc2 | | |
| 194ca8ee31 | | |
```diff
@@ -19,9 +19,6 @@ nohup bash fed_run.sh 1 > train.log 2>&1 &
 - Implement FedProx
 - Implement SCAFFOLD
 - Implement FedNova
-- Add more YOLO versions (e.g., YOLOv8, YOLOv5, etc.)
-- Implement YOLOv8
-- Implement YOLOv5
 
 # references
 [PyTorch Federated Learning](https://github.com/rruisong/pytorch_federated_learning)
```
```diff
@@ -17,8 +17,8 @@ local_batch_size: 32 # local training batch size
 val_batch_size: 128 # validation batch size
 
 num_workers: 8 # number of data loader workers
-min_data: 1700 # minimum number of images per client
-max_data: 1800 # maximum number of images per client
+min_data: 1800 # minimum number of images per client
+max_data: 1900 # maximum number of images per client
 partition_mode: "overlap" # "overlap" or "disjoint"
 connection_ratio: 1 # connection ratio, e.g., 1.0 means all clients
 
```
```diff
@@ -3,22 +3,22 @@ fed_algo: "FedAvg" # federated learning algorithm
 model_name: "yolo_v11_n" # yolo_v11_n, yolo_v11_t, yolo_v11_s, yolo_v11_m, yolo_v11_l, yolo_v11_x
 i_seed: 202509 # initial random seed
 
-num_client: 100 # total number of clients
-num_round: 500 # total number of communication rounds
+num_client: 36 # total number of clients
+num_round: 50 # total number of communication rounds
 num_local_class: 1 # number of classes per client
 
 res_root: "results" # root directory for results
-dataset_path: "/home/image1325/ssd1/dataset/uav/"
+dataset_path: "/mnt/DATA/uav/"
 # train_txt: "train.txt" # path to training set txt file
 # val_txt: "val.txt" # path to validation set txt file
 # test_txt: "test.txt" # path to test set txt file
 
-local_batch_size: 32 # local training batch size
-val_batch_size: 16 # validation batch size
+local_batch_size: 36 # local training batch size
+val_batch_size: 128 # validation batch size
 
-num_workers: 4 # number of data loader workers
-min_data: 640 # minimum number of images per client
-max_data: 720 # maximum number of images per client
+num_workers: 8 # number of data loader workers
+min_data: 385 # minimum number of images per client
+max_data: 400 # maximum number of images per client
 partition_mode: "overlap" # "overlap" or "disjoint"
 connection_ratio: 1 # connection ratio, e.g., 1.0 means all clients
 
```
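Note on the split settings above: with `num_client: 36` and per-client caps of `min_data: 385` / `max_data: 400`, a `"disjoint"` partition would need up to 36 × 400 = 14,400 distinct images; the configured `"overlap"` mode lets clients share images, so it has no such bound. A standalone sketch for sanity-checking a config before launching (assumes only the YAML keys shown above and the `<dataset_path>/train.txt` convention mentioned later in `fed_run.py`; not repo code):

```python
# Standalone sanity check for the split config (illustrative sketch, not repo
# code). Uses only keys visible in the diff above and <dataset_path>/train.txt.
import os
import yaml

with open("./config/uav_cfg.yaml", "r", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)

# count non-empty lines (one image path per line)
with open(os.path.join(cfg["dataset_path"], "train.txt"), encoding="utf-8") as f:
    n_images = sum(1 for line in f if line.strip())

need_lo = cfg["num_client"] * cfg["min_data"]  # 36 * 385 = 13,860
need_hi = cfg["num_client"] * cfg["max_data"]  # 36 * 400 = 14,400
print(f"train.txt: {n_images} images; disjoint mode needs {need_lo}..{need_hi}")
if cfg["partition_mode"] == "disjoint":
    assert n_images >= need_lo, "not enough images for a disjoint split"
```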
```diff
@@ -64,7 +64,7 @@ class FedYoloClient(object):
         """
         Load the local training dataset
         Args:
-        :param train_dataset: Training dataset
+            train_dataset: Training dataset
         """
         self.train_dataset = train_dataset
         self.n_data = len(self.train_dataset)
```
```diff
@@ -72,8 +72,9 @@ class FedYoloClient(object):
     def update(self, Global_model_state_dict):
         """
         Update the local model with the global model parameters
 
         Args:
-        :param Global_model_state_dict: State dictionary of the global model
+            Global_model_state_dict: State dictionary of the global model
         """
+
         if not hasattr(self, "model") or self.model is None:
```
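The `hasattr` guard in the context above implies the client builds its model lazily the first time it receives global weights; the body of `update()` is outside this hunk. A minimal sketch of the implied pattern (assuming `init_model` from `utils.fed_util` and a plain `load_state_dict`; the real method may differ):

```python
# Illustrative sketch of the lazy-init pattern implied by the guard above;
# the actual body of FedYoloClient.update() is not shown in this diff.
from utils.fed_util import init_model

def update(self, Global_model_state_dict):
    # build the local model on first use (assumed attributes: model_name, params)
    if not hasattr(self, "model") or self.model is None:
        self.model = init_model(self.model_name, len(self.params["names"]))
    # sync local weights to the aggregated global state
    self.model.load_state_dict(Global_model_state_dict)
```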
```diff
@@ -85,7 +86,15 @@ class FedYoloClient(object):
     def train(self, args) -> tuple[dict[str, torch.Tensor], int, float]:
         """
         Train the local model.
-        Returns: (state_dict, n_data, avg_loss_per_image)
+
+        Args:
+            args: training arguments including
+
+        Returns:
+            (state_dict, n_data, avg_loss_per_image): A tuple including:
+            - state_dict: State dictionary of the trained local model
+            - n_data: Number of training data samples
+            - avg_loss_per_image: Average training loss per image over all epochs
         """
 
         # ---- Dist init (if any) ----
```
```diff
@@ -11,13 +11,13 @@ class FedYoloServer(object):
     def __init__(self, client_list, model_name, params):
         """
         Federated YOLO Server
-        Args:
+        Attributes:
             client_list: list of connected clients
             model_name: YOLO model architecture name
             params: dict of hyperparameters (must include 'names')
         """
         # Track client updates
-        self.client_state = {}
+        self.client_state: dict[str, dict[str, torch.Tensor]] = {}
         self.client_loss = {}
         self.client_n_data = {}
         self.selected_clients = []
```
```diff
@@ -64,14 +64,19 @@
         self.selected_clients.append(client_id)
         self.n_data += self.client_n_data[client_id]
 
+    # TODO: skip the layers which cannot be learned locally
     @torch.no_grad()
-    def agg(self):
+    def agg(self, skip_bn_layer: bool = False):
         """
         Server aggregates the local updates from selected clients using FedAvg.
 
-        :return: model_state: aggregated model weights
-        :return: avg_loss: weighted average training loss across selected clients
-        :return: n_data: total number of data points across selected clients
+        Args:
+            skip_bn_layer: whether to skip batch normalization layers during aggregation
+
+        Returns:
+            :model_state: aggregated model weights
+            :avg_loss: weighted average training loss across selected clients
+            :n_data: total number of data points across selected clients
         """
         if len(self.selected_clients) == 0 or self.n_data == 0:
             import warnings
```
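The new `skip_bn_layer` flag only appears in the signature and docstring here; the aggregation body is outside the hunk. A minimal sketch of data-weighted FedAvg with an optional BatchNorm skip, written as a free function over assumed inputs (`client_states` / `client_n_data` / `client_loss` as populated by `rec()`), not the repo's actual `agg()`:

```python
import torch

def fedavg_agg(global_state, client_states, client_n_data, client_loss,
               skip_bn_layer=False):
    """Data-weighted FedAvg over client state_dicts; with skip_bn_layer=True,
    BatchNorm buffers keep their current global values instead of averaging."""
    n_total = sum(client_n_data.values())

    def is_bn(k):  # BN running statistics / counters
        return ("running_mean" in k or "running_var" in k
                or "num_batches_tracked" in k)

    # start from the current global weights so skipped keys keep their values
    out = {k: v.clone().float() for k, v in global_state.items()}
    avg_keys = [k for k in out if not (skip_bn_layer and is_bn(k))]
    for k in avg_keys:
        out[k].zero_()

    avg_loss = 0.0
    for cid, state in client_states.items():
        w = client_n_data[cid] / n_total  # FedAvg weight: n_k / n
        for k in avg_keys:
            out[k] += w * state[k].float()
        avg_loss += w * client_loss[cid]
    return out, avg_loss, n_total
```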
```diff
@@ -144,11 +149,13 @@
 def test(valset: Dataset, params, model: YOLO, batch_size: int = 200) -> tuple[float, float, float, float]:
     """
     Evaluate the model on the validation dataset.
+
     Args:
         valset: validation dataset
         params: dict of parameters (must include 'names')
         model: YOLO model to evaluate
         batch_size: batch size for evaluation
+
     Returns:
         dict with evaluation metrics (tp, fp, m_pre, m_rec, map50, mean_ap)
     """
```
```diff
@@ -214,7 +221,9 @@ def test(valset: Dataset, params, model: YOLO, batch_size: int = 200) -> tuple[float, float, float, float]:
     # Compute metrics
     metrics = [torch.cat(x, dim=0).cpu().numpy() for x in zip(*metrics)]  # to numpy
     if len(metrics) and metrics[0].any():
-        tp, fp, m_pre, m_rec, map50, mean_ap = util.compute_ap(*metrics, plot=False, names=params["names"])
+        tp, fp, m_pre, m_rec, map50, mean_ap = util.compute_ap(
+            *metrics, plot=False, names=params["names"]
+        )  # set plot=True to plot metric curve
     # Print results
     # print(("%10s" + "%10.3g" * 4) % ("", m_pre, m_rec, map50, mean_ap))
     # Return results
```
fed_run.py (124 lines changed)
```diff
@@ -5,12 +5,16 @@ import yaml
 import time
 from tqdm import tqdm
 import torch
+import csv
+import copy
 
 from utils.fed_util import build_valset_if_available, seed_everything, plot_curves
 from fed_algo_cs.client_base import FedYoloClient
 from fed_algo_cs.server_base import FedYoloServer
 from utils.args import args_parser  # args parser
 from utils.fed_util import divide_trainset  # divide_trainset
+from utils import util
+from utils.fed_util import prepare_result_dir
 
 
 def fed_run():
```
```diff
@@ -26,11 +30,6 @@ def fed_run():
     with open(args_cli.config, "r", encoding="utf-8") as f:
         cfg = yaml.safe_load(f)
 
-    # --- params / config normalization ---
-    # For convenience we pass the same `params` dict used by Dataset/model/loss.
-    # Here we re-use the top-level cfg directly as params.
-    # params = dict(cfg)
-
     if "names" in cfg and isinstance(cfg["names"], dict):
         # Convert {0: 'uav', 1: 'car', ...} to list if you prefer list
         # but we can leave dict; your utils appear to accept dict
```
```diff
@@ -39,6 +38,9 @@ def fed_run():
     # seeds
     seed_everything(int(cfg.get("i_seed", 0)))
 
+    # result directory
+    res_root, weights_root = prepare_result_dir(base_root=cfg.get("res_root", "results"))
+
     # --- split clients' train data from a global train list ---
     # Expect either cfg["train_txt"] or <dataset_path>/train.txt
     train_txt = cfg.get("train_txt", "")
```
```diff
@@ -67,7 +69,7 @@ def fed_run():
 
     # --- build clients ---
     model_name = cfg.get("model_name", "yolo_v11_n")
-    clients = {}
+    clients: dict[str, FedYoloClient] = {}
 
     for uid in users:
         c = FedYoloClient(name=uid, model_name=model_name, params=cfg)
```
```diff
@@ -84,9 +86,6 @@ def fed_run():
     # --- push initial global weights ---
     global_state = server.state_dict()
 
-    # --- args object for client.train() ---
-    # args_train = _make_args_for_client(cfg, args_cli)
-
     # --- history recorder ---
     history = {
         "mAP": [],
```
```diff
@@ -98,16 +97,16 @@
     }
 
     # --- main FL loop ---
+    best = 0.0  # best mAP
     num_round = int(cfg.get("num_round", 50))
     connection_ratio = float(cfg.get("connection_ratio", 1.0))  # e.g., 1.0 = all clients
-    res_root = cfg.get("res_root", "results")
-    os.makedirs(res_root, exist_ok=True)
 
     # tqdm logging
     header = ("%10s" * 2) % ("Round", "client")
     tqdm.write("\n" + header)
     p_bar = tqdm(total=num_round, ncols=160, ascii="->>")
 
+    # train loop
     for rnd in range(num_round):
         t0 = time.time()
         # Local training (sequential over all users)
```
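`connection_ratio` is read here, but the training loop below still visits every client in sequence. A sketch of the per-round sampling that a ratio below 1.0 would imply (hypothetical helper, not in this diff):

```python
# Hypothetical helper: sample a fraction of clients each round.
# connection_ratio=1.0 reproduces the current behaviour (all clients).
import random

def sample_clients(users, connection_ratio, seed=None):
    k = max(1, round(len(users) * connection_ratio))
    return random.Random(seed).sample(list(users), k)
```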
```diff
@@ -115,7 +114,7 @@
         # tqdm desc update
         p_bar.set_description_str(("%10s" * 2) % (f"{rnd + 1}/{num_round}", f"{uid}"))
 
-        client = clients[uid]  # FedYoloClient instance
+        client: FedYoloClient = clients[uid]  # FedYoloClient instance
         client.update(global_state)  # load global weights
         state_dict, n_data, train_loss = client.train(args_cli)  # local training
         server.rec(uid, state_dict, n_data, train_loss)
```
```diff
@@ -129,51 +128,82 @@
         # Compute a scalar train loss for plotting (sum of components)
         scalar_train_loss = avg_loss if avg_loss else 0.0
 
-        # Test (if valset provided)
-        mAP, mAP50, recall, precision = server.test() if server.valset is not None else (0.0, 0.0, 0.0, 0.0)
-
-        # Flush per-round client caches
-        server.flush()
-
-        # Record & log
-        history["mAP"].append(mAP)
-        history["mAP50"].append(mAP50)
-        history["precision"].append(precision)
-        history["recall"].append(recall)
-        history["train_loss"].append(scalar_train_loss)
-        history["round_time_sec"].append(time.time() - t0)
-
-        # Log GPU memory usage
-        # gpu_mem = f"{torch.cuda.memory_reserved() / 1e9:.2f}G" if torch.cuda.is_available() else "0.00G"
-        # tqdm update
-        desc = {
-            "loss": f"{scalar_train_loss:.6g}",
-            "mAP50": f"{mAP50:.6g}",
-            "mAP": f"{mAP:.6g}",
-            "precision": f"{precision:.6g}",
-            "recall": f"{recall:.6g}",
-            # "gpu_mem": gpu_mem,
-        }
-        p_bar.set_postfix(desc)
-
-        # Save running JSON (resumable logs)
-        save_name = f"{cfg.get('fed_algo', 'FedAvg')}_{[cfg.get('model_name', 'yolo')]}_{cfg.get('num_client', 0)}c_{cfg.get('num_local_class', 1)}cls_{cfg.get('num_round', 0)}r_{cfg.get('connection_ratio', 1):.2f}cr_{cfg.get('i_seed', 0)}s"
-        out_json = os.path.join(res_root, save_name + ".json")
-        with open(out_json, "w", encoding="utf-8") as f:
-            json.dump(history, f, indent=4)
+        if args_cli.local_rank == 0:
+            # Test (if valset provided)
+            mAP, mAP50, recall, precision = server.test() if server.valset is not None else (0.0, 0.0, 0.0, 0.0)
+
+            if mAP > best:
+                best = mAP
+
+            # Flush per-round client caches
+            server.flush()
+
+            # Record & log
+            history["mAP"].append(mAP)
+            history["mAP50"].append(mAP50)
+            history["precision"].append(precision)
+            history["recall"].append(recall)
+            history["train_loss"].append(scalar_train_loss)
+            history["round_time_sec"].append(time.time() - t0)
+
+            # Log GPU memory usage
+            # gpu_mem = f"{torch.cuda.memory_reserved() / 1e9:.2f}G" if torch.cuda.is_available() else "0.00G"
+            # tqdm update
+            desc = {
+                "loss": f"{scalar_train_loss:.6g}",
+                "mAP50": f"{mAP50:.6g}",
+                "mAP": f"{mAP:.6g}",
+                "precision": f"{precision:.6g}",
+                "recall": f"{recall:.6g}",
+                # "gpu_mem": gpu_mem,
+            }
+            p_bar.set_postfix(desc)
+
+            # Save running JSON (resumable logs)
+            # save_name = f"{cfg.get('fed_algo', 'FedAvg')}_{[cfg.get('model_name', 'yolo')]}_{cfg.get('num_client', 0)}c_{cfg.get('num_local_class', 1)}cls_{cfg.get('num_round', 0)}r_{cfg.get('connection_ratio', 1):.2f}cr_{cfg.get('i_seed', 0)}s"
+            # out_json = os.path.join(res_root, save_name + ".json")
+            # with open(out_json, "w", encoding="utf-8") as f:
+            #     json.dump(history, f, indent=4)
+
+            # Use csv file to save running metrics
+            row = {
+                "round": rnd + 1,
+                "loss": f"{scalar_train_loss:.3f}",
+                "mAP": f"{mAP:.3f}",
+                "mAP50": f"{mAP50:.3f}",
+                "precision": f"{precision:.3f}",
+                "recall": f"{recall:.3f}",
+                "sec": f"{time.time() - t0:.1f}",
+            }
+
+            # log to csv
+            out_csv = os.path.join(res_root, "step.csv")
+            fieldnames = ["round", "loss", "mAP", "mAP50", "precision", "recall", "sec"]
+            mode = "w" if rnd == 0 else "a"
+            with open(file=out_csv, mode=mode, newline="", encoding="utf-8") as f:
+                writer = csv.DictWriter(f, fieldnames=fieldnames)
+                if rnd == 0:
+                    writer.writeheader()  # write header only once
+                writer.writerow(row)
+
+            # Save final global model weights
+            # FIXME: save model not adaptive YOLOv11-pt specific
+            save_model = {"config": cfg, "model": copy.deepcopy(global_state if global_state else None)}
+            torch.save(save_model, f"{weights_root}/last.pt")
+            if best == mAP:
+                torch.save(save_model, f"{weights_root}/best.pt")
+            del save_model
+            # print(f"[save] final global model weights: {weights_root}/last.pt")
         p_bar.update(1)
 
     p_bar.close()
 
-    # Save final global model weights
-    if not os.path.exists("./weights"):
-        os.makedirs("./weights", exist_ok=True)
-    torch.save(global_state, f"./weights/{save_name}_final.pth")
-    print(f"[save] final global model weights: ./weights/{save_name}_final.pth")
+    if args_cli.local_rank == 0:
+        util.strip_optimizer(f"{weights_root}/best.pt")
+        util.strip_optimizer(f"{weights_root}/last.pt")
 
     # --- final plot ---
-    plot_curves(res_root, history, savename=f"{save_name}_curve.png")
+    plot_curves(res_root, history, savename="train_curve.png")
     print("[done] training complete.")
```
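With per-round metrics now appended to `step.csv` instead of a rewritten JSON blob, curves can be re-plotted at any time from the file alone. A standalone sketch (pandas and matplotlib are assumptions of this example, and the run directory name is hypothetical; `utils.fed_util.plot_curves` may differ):

```python
# Standalone sketch: re-plot training curves from the per-round step.csv.
import pandas as pd
import matplotlib.pyplot as plt

df = pd.read_csv("results/result_20250101_000000/step.csv")  # hypothetical run dir

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 4))
ax1.plot(df["round"], df["loss"])
ax1.set_xlabel("round"); ax1.set_ylabel("train loss")
ax2.plot(df["round"], df["mAP"], label="mAP")
ax2.plot(df["round"], df["mAP50"], label="mAP50")
ax2.set_xlabel("round"); ax2.legend()
fig.savefig("step_curves.png", dpi=150, bbox_inches="tight")
```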
```diff
@@ -7,7 +7,7 @@ def args_parser():
 
     parser.add_argument("--epochs", type=int, default=16, help="number of rounds of local training")
     parser.add_argument("--input_size", type=int, default=640, help="image input size")
-    parser.add_argument("--config", type=str, default="./config/coco_cfg.yaml", help="Path to YAML config")
+    parser.add_argument("--config", type=str, default="./config/uav_cfg.yaml", help="Path to YAML config")
 
     args = parser.parse_args()
 
```
```diff
@@ -7,6 +7,7 @@ import numpy as np
 import torch
 from collections import defaultdict
 from typing import Dict, List, Optional, Set, Any
+import time
 
 from nets import nn
 from nets import YOLO
```
```diff
@@ -30,8 +31,10 @@ def _parse_yolo_label_file(label_path: str) -> Set[int]:
     Return a set of class_ids found in a YOLO .txt label file.
     Empty file -> empty set. Missing file -> empty set.
     Robust to blank lines / trailing spaces.
+
     Args:
         label_path: path to the label file
+
     Returns:
         set of class IDs (integers) found in the file
     """
```
```diff
@@ -85,7 +88,7 @@ def divide_trainset(
     Build a federated split from a YOLO dataset list file.
 
     Args:
-        trainset_path: path to a .txt file containing one image path per line
+        trainset_path (str): path to a .txt file containing one image path per line
             e.g. /COCO/images/train2017/1111.jpg
         num_local_class: how many distinct classes to sample for each client
         num_client: number of clients
```
```diff
@@ -95,7 +98,9 @@ def divide_trainset(
             "disjoint" -> each image is used by at most one client
         seed: optional random seed for reproducibility
 
-    Returns:
+    Returns::
+
+        >>> \
         trainset_divided = {
             "users": ["c_00001", ...],
             "user_data": {
```
```diff
@@ -105,7 +110,9 @@ def divide_trainset(
             "num_samples": [len(list_for_user1), len(list_for_user2), ...]
         }
 
-    Example:
+    Example::
+
+        >>> \
         dataset = divide_trainset(
             trainset_path="/COCO/train2017.txt",
             num_local_class=3,
```
```diff
@@ -114,11 +121,11 @@ def divide_trainset(
             max_data=20,
             mode="disjoint",  # or "overlap"
             seed=42
         )
 
-        print(dataset["users"])  # ['c_00001', ..., 'c_00005']
-        print(dataset["num_samples"])  # e.g. [10, 12, 18, 9, 15]
-        print(dataset["user_data"]["c_00001"]["filename"][:3])
+        >>> print(dataset["users"])  # ['c_00001', ..., 'c_00005']
+        >>> print(dataset["num_samples"])  # e.g. [10, 12, 18, 9, 15]
+        >>> print(dataset["user_data"]["c_00001"]["filename"][:3])
     """
     if seed is not None:
         random.seed(seed)
```
```diff
@@ -247,8 +254,11 @@ def init_model(model_name, num_classes) -> YOLO:
     """
     Initialize the model for a specific learning task
     Args:
-    :param model_name: Name of the model
-    :param num_classes: Number of classes
+        model_name: Name of the model
+        num_classes: Number of classes
+
+    Returns:
+        model: YOLO model instance
     """
     model = None
     if model_name == "yolo_v11_n":
```
```diff
@@ -273,11 +283,13 @@ def build_valset_if_available(cfg, params, args=None, val_name: str = "val2017"):
     - If cfg['val_txt'] exists, use it.
     - Else if <dataset_path>/val.txt exists, use it.
     - Else return None (testing will be skipped).
+
     Args:
         cfg: config dict
         params: params dict for Dataset
         args: optional args object (for input_size)
         val_name: name of the validation set folder with no prefix (default: "val2017")
+
     Returns:
         Dataset or None
     """
```
```diff
@@ -344,3 +356,23 @@ def plot_curves(save_dir, hist, savename="fed_yolo_curves.png"):
     out_png = os.path.join(save_dir, savename)
     plt.savefig(out_png, dpi=150, bbox_inches="tight")
     print(f"[plot] saved: {out_png}")
+
+
+def prepare_result_dir(base_root: str = "results"):
+    """
+    Prepare result directories for saving outputs.
+
+    Args:
+        base_root (str): base directory for results.
+
+    Returns:
+        (res_dir, weights_dir) (str, str): Path to the result directory and the weights directory.
+    """
+    os.makedirs(base_root, exist_ok=True)
+    timestamp = time.strftime("%Y%m%d_%H%M%S")
+    res_dir = os.path.join(base_root, f"result_{timestamp}")
+    weights_dir = os.path.join(res_dir, f"weight_{timestamp}")
+    os.makedirs(res_dir, exist_ok=True)
+    os.makedirs(weights_dir, exist_ok=True)
+    print(f"[INFO] Saving results to: {res_dir}")
+    return res_dir, weights_dir
```