From f9588b74a82225509a18d1800277842de93daf26 Mon Sep 17 00:00:00 2001 From: Yunhao Meng Date: Thu, 23 Oct 2025 13:07:34 +0800 Subject: [PATCH] =?UTF-8?q?=E6=B5=8B=E8=AF=95=E4=BB=A3=E7=A0=81=E9=80=BB?= =?UTF-8?q?=E8=BE=91=E4=BC=98=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- testcode.py | 58 +++++++++++++++++++++++++++-------------------------- 1 file changed, 30 insertions(+), 28 deletions(-) diff --git a/testcode.py b/testcode.py index 8c0c67d..c9a44d8 100644 --- a/testcode.py +++ b/testcode.py @@ -8,39 +8,41 @@ from fed_algo_cs.server_base import FedYoloServer # FedYoloServer from utils import Dataset # Dataset if __name__ == "__main__": - # model structure test - model = init_model("yolo_v11_n", num_classes=1) + if not os.path.exists("model.txt"): + # model structure test + model = init_model("yolo_v11_n", num_classes=1) - with open("model.txt", "w", encoding="utf-8") as f: - print(model, file=f) + with open("model.txt", "w", encoding="utf-8") as f: + print(model, file=f) - # loop over model key and values - with open("model_key_value.txt", "w", encoding="utf-8") as f: - for k, v in model.state_dict().items(): - print(f"{k}: {v.shape}", file=f) + if not os.path.exists("model_key_value.txt"): + # loop over model key and values + with open("model_key_value.txt", "w", encoding="utf-8") as f: + for k, v in model.state_dict().items(): + print(f"{k}: {v.shape}", file=f) # test agg function - from fed_algo_cs.server_base import FedYoloServer - import torch - import yaml + # from fed_algo_cs.server_base import FedYoloServer + # import torch + # import yaml - with open("./config/coco128_cfg.yaml", "r", encoding="utf-8") as f: - cfg = yaml.safe_load(f) - params = dict(cfg) + # with open("./config/coco128_cfg.yaml", "r", encoding="utf-8") as f: + # cfg = yaml.safe_load(f) + # # params = dict(cfg) - server = FedYoloServer(client_list=["c1", "c2", "c3"], model_name="yolo_v11_n", params=params) - state1 = {k: torch.ones_like(v) for k, v in server.model.state_dict().items()} - state2 = {k: torch.ones_like(v) * 2 for k, v in server.model.state_dict().items()} - state3 = {k: torch.ones_like(v) * 3 for k, v in server.model.state_dict().items()} - server.rec("c1", state1, n_data=20, loss=0.1) - server.rec("c2", state2, n_data=30, loss=0.2) - server.rec("c3", state3, n_data=50, loss=0.3) - server.select_clients(connection_ratio=1.0) - model_state, avg_loss, n_data = server.agg() - with open("agg_model.txt", "w", encoding="utf-8") as f: - for k, v in model_state.items(): - print(f"{k}: {v.float().mean()}", file=f) - print(f"avg_loss: {avg_loss}, n_data: {n_data}") + # server = FedYoloServer(client_list=["c1", "c2", "c3"], model_name="yolo_v11_n", params=cfg) + # state1 = {k: torch.ones_like(v) for k, v in server.model.state_dict().items()} + # state2 = {k: torch.ones_like(v) * 2 for k, v in server.model.state_dict().items()} + # state3 = {k: torch.ones_like(v) * 3 for k, v in server.model.state_dict().items()} + # server.rec("c1", state1, n_data=20, loss=0.1) + # server.rec("c2", state2, n_data=30, loss=0.2) + # server.rec("c3", state3, n_data=50, loss=0.3) + # server.select_clients(connection_ratio=1.0) + # model_state, avg_loss, n_data = server.agg() + # with open("agg_model.txt", "w", encoding="utf-8") as f: + # for k, v in model_state.items(): + # print(f"{k}: {v.float().mean()}", file=f) + # print(f"avg_loss: {avg_loss}, n_data: {n_data}") # test single client training (should be the same as standalone training) args = args_parser() @@ -50,7 +52,7 @@ if __name__ == "__main__": client = FedYoloClient(name="c1", params=cfg, model_name="yolo_v11_n") filenames = [] - data_dir = "/home/image1325/ssd1/dataset/COCO128" + data_dir = "/mnt/DATA/COCO128" with open(f"{data_dir}/train.txt") as f: for filename in f.readlines(): filename = os.path.basename(filename.rstrip())