Compare commits: 76d7149512...main

7 commits:

| SHA1 |
| --- |
| c7a3d34d88 |
| b81f33ad28 |
| f9588b74a8 |
| 2fbb741d3f |
| 1822aca36b |
| ac4af34802 |
| 52382e460d |
````diff
@@ -9,7 +9,7 @@ pip install -r requirements.txt
 
 ## how to run
 ```bash
-nohup python fed_run.py > train.log 2>&1 &
+nohup bash fed_run.sh 1 > train.log 2>&1 &
 ```
 
 ## results
````
```diff
@@ -4,21 +4,21 @@ model_name: "yolo_v11_n" # yolo_v11_n, yolo_v11_t, yolo_v11_s, yolo_v11
 i_seed: 202509 # initial random seed
 
 num_client: 64 # total number of clients
-num_round: 5 # total number of communication rounds
+num_round: 50 # total number of communication rounds
 num_local_class: 80 # number of classes per client
 
 res_root: "results" # root directory for results
-dataset_path: "/home/image1325/ssd1/dataset/COCO128/"
+dataset_path: "/mnt/DATA/coco" # root directory for dataset
 # train_txt: "train.txt" # path to training set txt file
 # val_txt: "val.txt" # path to validation set txt file
 # test_txt: "test.txt" # path to test set txt file
 
 local_batch_size: 32 # local training batch size
-val_batch_size: 4 # validation batch size
+val_batch_size: 128 # validation batch size
 
-num_workers: 4 # number of data loader workers
-min_data: 128 # minimum number of images per client
-max_data: 128 # maximum number of images per client
+num_workers: 8 # number of data loader workers
+min_data: 1700 # minimum number of images per client
+max_data: 1800 # maximum number of images per client
 partition_mode: "overlap" # "overlap" or "disjoint"
 connection_ratio: 1 # connection ratio, e.g., 1.0 means all clients
 
```
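For orientation, here is a minimal sketch of how these partition knobs could be consumed: each client gets a target image count drawn from `[min_data, max_data]` under the initial seed. `load_cfg` and `sample_client_sizes` are hypothetical helpers for illustration, not the repo's `divide_trainset`:

```python
# Sketch only: hypothetical helpers, not the repo's divide_trainset.
import random

import yaml

def load_cfg(path: str) -> dict:
    with open(path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)

def sample_client_sizes(cfg: dict) -> list[int]:
    # One target image count per client, bounded by min_data/max_data;
    # in "overlap" mode clients may then sample the same images.
    rng = random.Random(cfg["i_seed"])
    return [rng.randint(cfg["min_data"], cfg["max_data"]) for _ in range(cfg["num_client"])]
```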
```diff
@@ -120,7 +120,7 @@ class FedYoloClient(object):
         # track_model = self.model.module if is_ddp else self.model
         ema = util.EMA(self.model) if args.local_rank == 0 else None
 
-        print(type(self.train_dataset))
+        # print(type(self.train_dataset))
 
         # ---- Data ----
         dataset = Dataset(
```
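The `util.EMA` tracker above is only instantiated on rank 0. As a rough sketch of what an exponential moving average of model weights typically looks like (an assumption about `util.EMA`, not its actual implementation):

```python
import copy

import torch

class EMA:
    """Exponential moving average of model parameters (sketch, assumed behavior)."""

    def __init__(self, model: torch.nn.Module, decay: float = 0.9999):
        self.ema = copy.deepcopy(model).eval()  # shadow copy, never trained directly
        self.decay = decay
        for p in self.ema.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def update(self, model: torch.nn.Module):
        # shadow <- decay * shadow + (1 - decay) * live weights
        msd = model.state_dict()
        for k, v in self.ema.state_dict().items():
            if v.dtype.is_floating_point:
                v.mul_(self.decay).add_(msd[k].detach(), alpha=1 - self.decay)
```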
```diff
@@ -188,7 +188,7 @@ class FedYoloClient(object):
         loss_dfl_meter = util.AverageMeter()
 
         for i, (images, targets) in enumerate(loader):
-            print(f"Client {self.name} - Epoch {epoch + 1}/{args.epochs} - Step {i + 1}/{num_steps}")
+            # print(f"Client {self.name} - Epoch {epoch + 1}/{args.epochs} - Step {i + 1}/{num_steps}")
             step = i + epoch * num_steps
 
             # scheduler per-step (your util.LinearLR expects step)
```
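`util.AverageMeter` is the usual running-average helper for per-batch losses. A plausible minimal version, assumed rather than taken from the repo:

```python
class AverageMeter:
    """Tracks a running sum and count to report the mean (sketch)."""

    def __init__(self):
        self.sum = 0.0
        self.count = 0

    def update(self, value: float, n: int = 1):
        self.sum += value * n  # accumulate the batch total, weighted by batch size
        self.count += n

    @property
    def avg(self) -> float:
        return self.sum / max(self.count, 1)  # guard against division by zero
```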
```diff
@@ -257,9 +257,9 @@ class FedYoloClient(object):
                 else self.model
             )
             # print loss to test
-            print(
-                f"loss: {total_loss.item() * accumulate:.4f}, box: {box_loss.item():.4f}, cls: {cls_loss.item():.4f}, dfl: {dfl_loss.item():.4f}"
-            )
+            # print(
+            # f"loss: {total_loss.item() * accumulate:.4f}, box: {box_loss.item():.4f}, cls: {cls_loss.item():.4f}, dfl: {dfl_loss.item():.4f}"
+            # )
             torch.cuda.synchronize()
 
         # ---- Final average loss (per image) over the whole epoch span ----
```
```diff
@@ -44,12 +44,12 @@ def fed_run():
     train_txt = cfg.get("train_txt", "")
     if not train_txt:
         ds_root = cfg.get("dataset_path", "")
-        guess = os.path.join(ds_root, "train.txt") if ds_root else ""
+        guess = os.path.join(ds_root, "train2017.txt") if ds_root else ""
         train_txt = guess
 
     if not train_txt or not os.path.exists(train_txt):
         raise FileNotFoundError(
-            f"train.txt not found. Provide --config with 'train_txt' or ensure '{train_txt}' exists."
+            f"train2017.txt not found. Provide --config with 'train_txt' or ensure '{train_txt}' exists."
         )
 
     split = divide_trainset(
```
```diff
@@ -76,7 +76,7 @@ def fed_run():
 
     # --- build server & optional validation set ---
     server = FedYoloServer(client_list=users, model_name=model_name, params=cfg)
-    valset = build_valset_if_available(cfg, params=cfg, args=args_cli)
+    valset = build_valset_if_available(cfg, params=cfg, args=args_cli, val_name="val2017")
     # valset is a Dataset class, not data loader
     if valset is not None:
         server.load_valset(valset)
```
```diff
@@ -1,2 +1,2 @@
 GPUS=$1
-python3 -m torch.distributed.run --nproc_per_node=$GPUS fed_run.py ${@:2}
+nohup python3 -m torch.distributed.run --nproc_per_node=$GPUS fed_run.py ${@:2} > train.log 2>&1 & disown
```
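The new launcher line detaches training from the shell: `nohup` ignores SIGHUP, output is redirected to `train.log`, and `disown` removes the job from the shell's job table so a logout cannot kill it. For illustration, a rough Python equivalent of the same detached launch (paths and arguments assumed from the script above):

```python
import subprocess

# Launch the distributed trainer detached from this process:
# start_new_session=True puts the child in its own session (akin to disown),
# and redirecting stdout/stderr into train.log mirrors `> train.log 2>&1`.
with open("train.log", "ab") as log:
    subprocess.Popen(
        ["python3", "-m", "torch.distributed.run", "--nproc_per_node=1", "fed_run.py"],
        stdout=log,
        stderr=subprocess.STDOUT,
        start_new_session=True,
    )
```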
testcode.py
```diff
@@ -8,39 +8,41 @@ from fed_algo_cs.server_base import FedYoloServer # FedYoloServer
 from utils import Dataset # Dataset
 
 if __name__ == "__main__":
+    if not os.path.exists("model.txt"):
         # model structure test
         model = init_model("yolo_v11_n", num_classes=1)
 
         with open("model.txt", "w", encoding="utf-8") as f:
             print(model, file=f)
 
+    if not os.path.exists("model_key_value.txt"):
         # loop over model key and values
         with open("model_key_value.txt", "w", encoding="utf-8") as f:
             for k, v in model.state_dict().items():
                 print(f"{k}: {v.shape}", file=f)
 
     # test agg function
-    from fed_algo_cs.server_base import FedYoloServer
-    import torch
-    import yaml
+    # from fed_algo_cs.server_base import FedYoloServer
+    # import torch
+    # import yaml
 
-    with open("./config/coco128_cfg.yaml", "r", encoding="utf-8") as f:
-        cfg = yaml.safe_load(f)
-    params = dict(cfg)
+    # with open("./config/coco128_cfg.yaml", "r", encoding="utf-8") as f:
+    # cfg = yaml.safe_load(f)
+    # # params = dict(cfg)
 
-    server = FedYoloServer(client_list=["c1", "c2", "c3"], model_name="yolo_v11_n", params=params)
-    state1 = {k: torch.ones_like(v) for k, v in server.model.state_dict().items()}
-    state2 = {k: torch.ones_like(v) * 2 for k, v in server.model.state_dict().items()}
-    state3 = {k: torch.ones_like(v) * 3 for k, v in server.model.state_dict().items()}
-    server.rec("c1", state1, n_data=20, loss=0.1)
-    server.rec("c2", state2, n_data=30, loss=0.2)
-    server.rec("c3", state3, n_data=50, loss=0.3)
-    server.select_clients(connection_ratio=1.0)
-    model_state, avg_loss, n_data = server.agg()
-    with open("agg_model.txt", "w", encoding="utf-8") as f:
-        for k, v in model_state.items():
-            print(f"{k}: {v.float().mean()}", file=f)
-    print(f"avg_loss: {avg_loss}, n_data: {n_data}")
+    # server = FedYoloServer(client_list=["c1", "c2", "c3"], model_name="yolo_v11_n", params=cfg)
+    # state1 = {k: torch.ones_like(v) for k, v in server.model.state_dict().items()}
+    # state2 = {k: torch.ones_like(v) * 2 for k, v in server.model.state_dict().items()}
+    # state3 = {k: torch.ones_like(v) * 3 for k, v in server.model.state_dict().items()}
+    # server.rec("c1", state1, n_data=20, loss=0.1)
+    # server.rec("c2", state2, n_data=30, loss=0.2)
+    # server.rec("c3", state3, n_data=50, loss=0.3)
+    # server.select_clients(connection_ratio=1.0)
+    # model_state, avg_loss, n_data = server.agg()
+    # with open("agg_model.txt", "w", encoding="utf-8") as f:
+    # for k, v in model_state.items():
+    # print(f"{k}: {v.float().mean()}", file=f)
+    # print(f"avg_loss: {avg_loss}, n_data: {n_data}")
 
     # test single client training (should be the same as standalone training)
     args = args_parser()
```
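The block commented out above was a sanity check for the server's aggregation: three synthetic client states (all ones, twos, and threes) weighted by `n_data` of 20, 30, and 50 should average to 2.3 in every tensor. A self-contained sketch of that FedAvg-style weighted average, using a hypothetical `fedavg` helper rather than the repo's `FedYoloServer.agg`:

```python
import torch

def fedavg(states: list[dict], weights: list[int]) -> dict:
    """Data-size-weighted average of client state dicts (sketch)."""
    total = sum(weights)
    avg = {}
    for k in states[0]:
        # Scale each client's tensor by its data share, then sum.
        stacked = torch.stack([s[k].float() * (w / total) for s, w in zip(states, weights)])
        avg[k] = stacked.sum(dim=0)
    return avg

# Mirrors the commented-out test: (1*20 + 2*30 + 3*50) / 100 = 2.3
states = [{"w": torch.ones(2, 2) * i} for i in (1, 2, 3)]
print(fedavg(states, [20, 30, 50])["w"])  # every entry is 2.3
```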
```diff
@@ -50,7 +52,7 @@ if __name__ == "__main__":
     client = FedYoloClient(name="c1", params=cfg, model_name="yolo_v11_n")
 
     filenames = []
-    data_dir = "/home/image1325/ssd1/dataset/COCO128"
+    data_dir = "/mnt/DATA/COCO128"
     with open(f"{data_dir}/train.txt") as f:
         for filename in f.readlines():
             filename = os.path.basename(filename.rstrip())
```
```diff
@@ -5,9 +5,9 @@ import os
 def args_parser():
     parser = argparse.ArgumentParser()
 
-    parser.add_argument("--epochs", type=int, default=10, help="number of rounds of local training")
+    parser.add_argument("--epochs", type=int, default=16, help="number of rounds of local training")
     parser.add_argument("--input_size", type=int, default=640, help="image input size")
-    parser.add_argument("--config", type=str, default="./config/uav_cfg.yaml", help="Path to YAML config")
+    parser.add_argument("--config", type=str, default="./config/coco_cfg.yaml", help="Path to YAML config")
 
     args = parser.parse_args()
 
```
```diff
@@ -204,7 +204,10 @@ class Dataset(data.Dataset):
     def load_label(filenames):
         path = f"{os.path.dirname(filenames[0])}.cache"
         if os.path.exists(path):
-            return torch.load(path, weights_only=False)
+            # XXX: temporarily disable cache
+            os.remove(path)
+            pass
+            # return torch.load(path, weights_only=False)
         x = {}
         for filename in filenames:
             try:
```
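This change deletes any stale `.cache` file and falls through to re-parsing labels from disk. For context, a minimal sketch of the cache round-trip being short-circuited, with an assumed cache layout and a placeholder parser:

```python
import os

import torch

def load_cached_labels(filenames: list[str]) -> dict:
    """Parse labels once, then reuse a torch-serialized cache (sketch)."""
    path = f"{os.path.dirname(filenames[0])}.cache"
    if os.path.exists(path):
        return torch.load(path, weights_only=False)  # cache hit: reuse parsed labels
    labels = {name: parse_label_file(name) for name in filenames}  # cache miss: parse
    torch.save(labels, path)  # persist for the next run
    return labels

def parse_label_file(name: str):
    raise NotImplementedError  # placeholder: the real code reads YOLO txt labels
```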
```diff
@@ -267,7 +267,7 @@ def init_model(model_name, num_classes) -> YOLO:
     return model
 
 
-def build_valset_if_available(cfg, params, args=None) -> Optional[Dataset]:
+def build_valset_if_available(cfg, params, args=None, val_name: str = "val2017") -> Optional[Dataset]:
     """
     Try to build a validation Dataset.
     - If cfg['val_txt'] exists, use it.
```
```diff
@@ -276,6 +276,8 @@ def build_valset_if_available(cfg, params, args=None) -> Optional[Dataset]:
     Args:
         cfg: config dict
         params: params dict for Dataset
+        args: optional args object (for input_size)
+        val_name: name of the validation set folder with no prefix (default: "val2017")
     Returns:
         Dataset or None
     """
```
```diff
@@ -283,18 +285,24 @@ def build_valset_if_available(cfg, params, args=None) -> Optional[Dataset]:
     val_txt = cfg.get("val_txt", "")
     if not val_txt:
         ds_root = cfg.get("dataset_path", "")
-        guess = os.path.join(ds_root, "val.txt") if ds_root else ""
+        guess = os.path.join(ds_root, f"{val_name}.txt") if ds_root else ""
         val_txt = guess if os.path.exists(guess) else ""
 
-    val_files = _read_list_file(val_txt)
-    if not val_files:
+    # val_files = _read_list_file(val_txt)
+    filenames = []
+    with open(val_txt, "r", encoding="utf-8") as f:
+        for filename in f.readlines():
+            filename = os.path.basename(filename.rstrip())
+            filenames.append(f"{ds_root}/images/{val_name}/" + filename)
+    if not filenames:
         import warnings
 
         warnings.warn("No validation dataset found.")
         return None
 
     return Dataset(
-        filenames=val_files,
+        filenames=filenames,
         input_size=input_size,
         params=params,
         augment=True,
```
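A hedged usage sketch of the updated signature, with the import path and config file assumed. One caveat visible in the new body: if neither `cfg['val_txt']` nor `{val_name}.txt` under `dataset_path` exists, `val_txt` stays empty and `open(val_txt)` raises `FileNotFoundError` before the `if not filenames` warning is reached:

```python
import yaml

from utils import build_valset_if_available  # assumed module path

with open("./config/coco_cfg.yaml", "r", encoding="utf-8") as f:  # assumed config path
    cfg = yaml.safe_load(f)

valset = build_valset_if_available(cfg, params=cfg, args=None, val_name="val2017")
if valset is not None:
    print(f"validation images: {len(valset)}")
```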