Compare commits

..

7 Commits

9 changed files with 66 additions and 53 deletions

View File

@@ -9,7 +9,7 @@ pip install -r requirements.txt
 ## how to run
 ```bash
-nohup python fed_run.py > train.log 2>&1 &
+nohup bash fed_run.sh 1 > train.log 2>&1 &
 ```
 ## results

View File

@@ -4,21 +4,21 @@ model_name: "yolo_v11_n" # yolo_v11_n, yolo_v11_t, yolo_v11_s, yolo_v11
 i_seed: 202509 # initial random seed
 num_client: 64 # total number of clients
-num_round: 5 # total number of communication rounds
+num_round: 50 # total number of communication rounds
 num_local_class: 80 # number of classes per client
 res_root: "results" # root directory for results
-dataset_path: "/home/image1325/ssd1/dataset/COCO128/"
+dataset_path: "/mnt/DATA/coco" # root directory for dataset
 # train_txt: "train.txt" # path to training set txt file
 # val_txt: "val.txt" # path to validation set txt file
 # test_txt: "test.txt" # path to test set txt file
 local_batch_size: 32 # local training batch size
-val_batch_size: 4 # validation batch size
-num_workers: 4 # number of data loader workers
-min_data: 128 # minimum number of images per client
-max_data: 128 # maximum number of images per client
+val_batch_size: 128 # validation batch size
+num_workers: 8 # number of data loader workers
+min_data: 1700 # minimum number of images per client
+max_data: 1800 # maximum number of images per client
 partition_mode: "overlap" # "overlap" or "disjoint"
 connection_ratio: 1 # connection ratio, e.g., 1.0 means all clients
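
For scale: 64 clients at 1,700-1,800 images each is roughly 109k-115k samples, which just fits COCO train2017 (~118k images) even under "disjoint" partitioning. The `divide_trainset` called from fed_run.py is not part of this diff; below is a minimal Python sketch, assuming it draws a per-client count between min_data and max_data and only reuses images in "overlap" mode (the `_sketch` suffix marks it as hypothetical):

import random

def divide_trainset_sketch(files, num_client=64, min_data=1700, max_data=1800,
                           partition_mode="overlap", seed=202509):
    # Hypothetical sketch only: the real divide_trainset is not in this diff.
    # Each client gets between min_data and max_data images; "overlap" samples
    # from the full pool every time, "disjoint" consumes it without reuse.
    rng = random.Random(seed)
    split, pool = {}, list(files)
    for c in range(num_client):
        n = rng.randint(min_data, max_data)
        if partition_mode == "overlap":
            split[f"client_{c}"] = rng.sample(pool, min(n, len(pool)))
        else:
            split[f"client_{c}"], pool = pool[:n], pool[n:]
    return split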

View File

@@ -120,7 +120,7 @@ class FedYoloClient(object):
         # track_model = self.model.module if is_ddp else self.model
         ema = util.EMA(self.model) if args.local_rank == 0 else None
-        print(type(self.train_dataset))
+        # print(type(self.train_dataset))
         # ---- Data ----
         dataset = Dataset(
@@ -188,7 +188,7 @@ class FedYoloClient(object):
             loss_dfl_meter = util.AverageMeter()
             for i, (images, targets) in enumerate(loader):
-                print(f"Client {self.name} - Epoch {epoch + 1}/{args.epochs} - Step {i + 1}/{num_steps}")
+                # print(f"Client {self.name} - Epoch {epoch + 1}/{args.epochs} - Step {i + 1}/{num_steps}")
                 step = i + epoch * num_steps
                 # scheduler per-step (your util.LinearLR expects step)
@@ -257,9 +257,9 @@ class FedYoloClient(object):
                     else self.model
                 )
                 # print loss to test
-                print(
-                    f"loss: {total_loss.item() * accumulate:.4f}, box: {box_loss.item():.4f}, cls: {cls_loss.item():.4f}, dfl: {dfl_loss.item():.4f}"
-                )
+                # print(
+                #     f"loss: {total_loss.item() * accumulate:.4f}, box: {box_loss.item():.4f}, cls: {cls_loss.item():.4f}, dfl: {dfl_loss.item():.4f}"
+                # )
                 torch.cuda.synchronize()
             # ---- Final average loss (per image) over the whole epoch span ----
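
The first hunk above retains `ema = util.EMA(self.model) if args.local_rank == 0 else None`, i.e. only DDP rank 0 maintains the averaged weights. `util.EMA` itself is outside this diff; a minimal sketch of the usual YOLO-style weight EMA it presumably follows (class name and the decay/tau constants are assumptions, not this repo's code):

import copy
import math

import torch

class EMASketch:
    """Hypothetical stand-in for util.EMA (not shown in this diff)."""

    def __init__(self, model, decay=0.9999, tau=2000):
        # The EMA copy never receives gradients and stays in eval mode.
        self.ema = copy.deepcopy(model).eval()
        self.updates = 0
        self.decay = lambda x: decay * (1 - math.exp(-x / tau))  # warmup ramp
        for p in self.ema.parameters():
            p.requires_grad_(False)

    @torch.no_grad()
    def update(self, model):
        # Blend current weights into the running average after each step.
        self.updates += 1
        d = self.decay(self.updates)
        msd = model.state_dict()
        for k, v in self.ema.state_dict().items():
            if v.dtype.is_floating_point:
                v.mul_(d).add_(msd[k].detach(), alpha=1 - d)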

View File

@@ -44,12 +44,12 @@ def fed_run():
     train_txt = cfg.get("train_txt", "")
     if not train_txt:
         ds_root = cfg.get("dataset_path", "")
-        guess = os.path.join(ds_root, "train.txt") if ds_root else ""
+        guess = os.path.join(ds_root, "train2017.txt") if ds_root else ""
         train_txt = guess
     if not train_txt or not os.path.exists(train_txt):
         raise FileNotFoundError(
-            f"train.txt not found. Provide --config with 'train_txt' or ensure '{train_txt}' exists."
+            f"train2017.txt not found. Provide --config with 'train_txt' or ensure '{train_txt}' exists."
         )
     split = divide_trainset(
@@ -76,7 +76,7 @@ def fed_run():
     # --- build server & optional validation set ---
     server = FedYoloServer(client_list=users, model_name=model_name, params=cfg)
-    valset = build_valset_if_available(cfg, params=cfg, args=args_cli)
+    valset = build_valset_if_available(cfg, params=cfg, args=args_cli, val_name="val2017")
     # valset is a Dataset class, not data loader
     if valset is not None:
         server.load_valset(valset)

View File

@@ -1,2 +1,2 @@
 GPUS=$1
-python3 -m torch.distributed.run --nproc_per_node=$GPUS fed_run.py ${@:2}
+nohup python3 -m torch.distributed.run --nproc_per_node=$GPUS fed_run.py ${@:2} > train.log 2>&1 & disown

View File

@@ -8,39 +8,41 @@ from fed_algo_cs.server_base import FedYoloServer # FedYoloServer
 from utils import Dataset  # Dataset
 if __name__ == "__main__":
-    # model structure test
-    model = init_model("yolo_v11_n", num_classes=1)
-    with open("model.txt", "w", encoding="utf-8") as f:
-        print(model, file=f)
+    if not os.path.exists("model.txt"):
+        # model structure test
+        model = init_model("yolo_v11_n", num_classes=1)
+        with open("model.txt", "w", encoding="utf-8") as f:
+            print(model, file=f)
-    # loop over model key and values
-    with open("model_key_value.txt", "w", encoding="utf-8") as f:
-        for k, v in model.state_dict().items():
-            print(f"{k}: {v.shape}", file=f)
+    if not os.path.exists("model_key_value.txt"):
+        # loop over model key and values
+        with open("model_key_value.txt", "w", encoding="utf-8") as f:
+            for k, v in model.state_dict().items():
+                print(f"{k}: {v.shape}", file=f)
     # test agg function
-    from fed_algo_cs.server_base import FedYoloServer
-    import torch
-    import yaml
-    with open("./config/coco128_cfg.yaml", "r", encoding="utf-8") as f:
-        cfg = yaml.safe_load(f)
-    params = dict(cfg)
-    server = FedYoloServer(client_list=["c1", "c2", "c3"], model_name="yolo_v11_n", params=params)
-    state1 = {k: torch.ones_like(v) for k, v in server.model.state_dict().items()}
-    state2 = {k: torch.ones_like(v) * 2 for k, v in server.model.state_dict().items()}
-    state3 = {k: torch.ones_like(v) * 3 for k, v in server.model.state_dict().items()}
-    server.rec("c1", state1, n_data=20, loss=0.1)
-    server.rec("c2", state2, n_data=30, loss=0.2)
-    server.rec("c3", state3, n_data=50, loss=0.3)
-    server.select_clients(connection_ratio=1.0)
-    model_state, avg_loss, n_data = server.agg()
-    with open("agg_model.txt", "w", encoding="utf-8") as f:
-        for k, v in model_state.items():
-            print(f"{k}: {v.float().mean()}", file=f)
-    print(f"avg_loss: {avg_loss}, n_data: {n_data}")
+    # from fed_algo_cs.server_base import FedYoloServer
+    # import torch
+    # import yaml
+    # with open("./config/coco128_cfg.yaml", "r", encoding="utf-8") as f:
+    #     cfg = yaml.safe_load(f)
+    # # params = dict(cfg)
+    # server = FedYoloServer(client_list=["c1", "c2", "c3"], model_name="yolo_v11_n", params=cfg)
+    # state1 = {k: torch.ones_like(v) for k, v in server.model.state_dict().items()}
+    # state2 = {k: torch.ones_like(v) * 2 for k, v in server.model.state_dict().items()}
+    # state3 = {k: torch.ones_like(v) * 3 for k, v in server.model.state_dict().items()}
+    # server.rec("c1", state1, n_data=20, loss=0.1)
+    # server.rec("c2", state2, n_data=30, loss=0.2)
+    # server.rec("c3", state3, n_data=50, loss=0.3)
+    # server.select_clients(connection_ratio=1.0)
+    # model_state, avg_loss, n_data = server.agg()
+    # with open("agg_model.txt", "w", encoding="utf-8") as f:
+    #     for k, v in model_state.items():
+    #         print(f"{k}: {v.float().mean()}", file=f)
+    # print(f"avg_loss: {avg_loss}, n_data: {n_data}")
     # test single client training (should be the same as standalone training)
     args = args_parser()
@@ -50,7 +52,7 @@ if __name__ == "__main__":
     client = FedYoloClient(name="c1", params=cfg, model_name="yolo_v11_n")
     filenames = []
-    data_dir = "/home/image1325/ssd1/dataset/COCO128"
+    data_dir = "/mnt/DATA/COCO128"
     with open(f"{data_dir}/train.txt") as f:
         for filename in f.readlines():
             filename = os.path.basename(filename.rstrip())
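
The commented-out aggregation test still documents the expected behaviour: with client weights of all 1s, 2s, and 3s and n_data of 20/30/50, a data-weighted FedAvg yields 2.3 everywhere. A standalone check of that arithmetic, assuming `server.agg()` implements size-weighted averaging (the server code itself is not in this diff):

import torch

# Size-weighted FedAvg over three fake client states.
states = [torch.ones(4), torch.ones(4) * 2, torch.ones(4) * 3]
n_data = [20, 30, 50]
total = sum(n_data)
avg = sum(s * n / total for s, n in zip(states, n_data))
print(avg)  # tensor([2.3, 2.3, 2.3, 2.3]) = (1*20 + 2*30 + 3*50) / 100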

View File

@@ -5,9 +5,9 @@ import os
 def args_parser():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--epochs", type=int, default=10, help="number of rounds of local training")
+    parser.add_argument("--epochs", type=int, default=16, help="number of rounds of local training")
     parser.add_argument("--input_size", type=int, default=640, help="image input size")
-    parser.add_argument("--config", type=str, default="./config/uav_cfg.yaml", help="Path to YAML config")
+    parser.add_argument("--config", type=str, default="./config/coco_cfg.yaml", help="Path to YAML config")
     args = parser.parse_args()

View File

@@ -204,7 +204,10 @@ class Dataset(data.Dataset):
     def load_label(filenames):
         path = f"{os.path.dirname(filenames[0])}.cache"
         if os.path.exists(path):
-            return torch.load(path, weights_only=False)
+            # XXX: temporarily disable cache
+            os.remove(path)
+            pass
+            # return torch.load(path, weights_only=False)
         x = {}
         for filename in filenames:
             try:
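
This hunk deletes a stale `.cache` instead of loading it, which matters right after switching `dataset_path` from COCO128 to the full COCO tree: a cache built for the old file list would otherwise be returned silently. For context, a minimal sketch of the save/load round-trip such label caches usually implement (`parse_label_file` is a hypothetical placeholder, not this repo's helper):

import os

import torch

def load_label_sketch(filenames):
    # The cache sits next to the image folder, e.g. ".../train2017.cache".
    path = f"{os.path.dirname(filenames[0])}.cache"
    if os.path.exists(path):
        # weights_only=False: the cache holds plain Python dicts, which newer
        # PyTorch releases refuse to load under the weights_only default.
        return torch.load(path, weights_only=False)
    labels = {}
    for filename in filenames:
        labels[filename] = parse_label_file(filename)  # hypothetical parser
    torch.save(labels, path)  # rebuild the cache for the next run
    return labels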

View File

@@ -267,7 +267,7 @@ def init_model(model_name, num_classes) -> YOLO:
     return model
-def build_valset_if_available(cfg, params, args=None) -> Optional[Dataset]:
+def build_valset_if_available(cfg, params, args=None, val_name: str = "val2017") -> Optional[Dataset]:
     """
     Try to build a validation Dataset.
     - If cfg['val_txt'] exists, use it.
@@ -276,6 +276,8 @@ def build_valset_if_available(cfg, params, args=None) -> Optional[Dataset]:
     Args:
         cfg: config dict
         params: params dict for Dataset
+        args: optional args object (for input_size)
+        val_name: name of the validation set folder with no prefix (default: "val2017")
     Returns:
         Dataset or None
     """
@@ -283,18 +285,24 @@ def build_valset_if_available(cfg, params, args=None) -> Optional[Dataset]:
     val_txt = cfg.get("val_txt", "")
     if not val_txt:
         ds_root = cfg.get("dataset_path", "")
-        guess = os.path.join(ds_root, "val.txt") if ds_root else ""
+        guess = os.path.join(ds_root, f"{val_name}.txt") if ds_root else ""
         val_txt = guess if os.path.exists(guess) else ""
-    val_files = _read_list_file(val_txt)
-    if not val_files:
+    # val_files = _read_list_file(val_txt)
+    filenames = []
+    with open(val_txt, "r", encoding="utf-8") as f:
+        for filename in f.readlines():
+            filename = os.path.basename(filename.rstrip())
+            filenames.append(f"{ds_root}/images/{val_name}/" + filename)
+    if not filenames:
         import warnings
         warnings.warn("No validation dataset found.")
         return None
     return Dataset(
-        filenames=val_files,
+        filenames=filenames,
         input_size=input_size,
         params=params,
         augment=True,
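
The rewritten loop rebases each basename from the list file onto `{ds_root}/images/{val_name}/`, which presumes the usual YOLO-style COCO layout under `dataset_path`. An illustrative trace (paths are assumptions, not from the diff):

import os

# Assumed layout under dataset_path ("/mnt/DATA/coco"):
#   /mnt/DATA/coco/val2017.txt                  # one image path per line
#   /mnt/DATA/coco/images/val2017/<name>.jpg
ds_root, val_name = "/mnt/DATA/coco", "val2017"
line = "./images/val2017/000000000139.jpg\n"    # illustrative list entry
filename = os.path.basename(line.rstrip())
print(f"{ds_root}/images/{val_name}/" + filename)
# -> /mnt/DATA/coco/images/val2017/000000000139.jpg

Two carry-overs worth checking: `ds_root` is only bound inside the `if not val_txt:` branch, so configuring `val_txt` directly would hit a NameError in the new loop, and `augment=True` is unusual for a validation split.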