From f23a22632fd7de02155b4a063bd9559d85f9be3d Mon Sep 17 00:00:00 2001 From: TY1667 Date: Sun, 19 Oct 2025 21:31:08 +0800 Subject: [PATCH] =?UTF-8?q?=E6=96=B0=E5=A2=9Etestcode.py=E6=96=87=E4=BB=B6?= =?UTF-8?q?=EF=BC=8C=E5=8C=85=E5=90=AB=E6=A8=A1=E5=9E=8B=E7=BB=93=E6=9E=84?= =?UTF-8?q?=E6=B5=8B=E8=AF=95=E3=80=81=E8=81=9A=E5=90=88=E5=87=BD=E6=95=B0?= =?UTF-8?q?=E6=B5=8B=E8=AF=95=E5=8F=8A=E5=8D=95=E5=AE=A2=E6=88=B7=E7=AB=AF?= =?UTF-8?q?=E8=AE=AD=E7=BB=83=E9=80=BB=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- testcode.py | 77 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 77 insertions(+) create mode 100644 testcode.py diff --git a/testcode.py b/testcode.py new file mode 100644 index 0000000..8c0c67d --- /dev/null +++ b/testcode.py @@ -0,0 +1,77 @@ +from utils.fed_util import init_model +from fed_algo_cs.server_base import test +import os +import yaml +from utils.args import args_parser # args parser +from fed_algo_cs.client_base import FedYoloClient # FedYoloClient +from fed_algo_cs.server_base import FedYoloServer # FedYoloServer +from utils import Dataset # Dataset + +if __name__ == "__main__": + # model structure test + model = init_model("yolo_v11_n", num_classes=1) + + with open("model.txt", "w", encoding="utf-8") as f: + print(model, file=f) + + # loop over model key and values + with open("model_key_value.txt", "w", encoding="utf-8") as f: + for k, v in model.state_dict().items(): + print(f"{k}: {v.shape}", file=f) + + # test agg function + from fed_algo_cs.server_base import FedYoloServer + import torch + import yaml + + with open("./config/coco128_cfg.yaml", "r", encoding="utf-8") as f: + cfg = yaml.safe_load(f) + params = dict(cfg) + + server = FedYoloServer(client_list=["c1", "c2", "c3"], model_name="yolo_v11_n", params=params) + state1 = {k: torch.ones_like(v) for k, v in server.model.state_dict().items()} + state2 = {k: torch.ones_like(v) * 2 for k, v in server.model.state_dict().items()} + state3 = {k: torch.ones_like(v) * 3 for k, v in server.model.state_dict().items()} + server.rec("c1", state1, n_data=20, loss=0.1) + server.rec("c2", state2, n_data=30, loss=0.2) + server.rec("c3", state3, n_data=50, loss=0.3) + server.select_clients(connection_ratio=1.0) + model_state, avg_loss, n_data = server.agg() + with open("agg_model.txt", "w", encoding="utf-8") as f: + for k, v in model_state.items(): + print(f"{k}: {v.float().mean()}", file=f) + print(f"avg_loss: {avg_loss}, n_data: {n_data}") + + # test single client training (should be the same as standalone training) + args = args_parser() + with open(args.config, "r", encoding="utf-8") as f: + cfg = yaml.safe_load(f) + # params = dict(cfg) + client = FedYoloClient(name="c1", params=cfg, model_name="yolo_v11_n") + + filenames = [] + data_dir = "/home/image1325/ssd1/dataset/COCO128" + with open(f"{data_dir}/train.txt") as f: + for filename in f.readlines(): + filename = os.path.basename(filename.rstrip()) + filenames.append(f"{data_dir}/images/train2017/" + filename) + + client.load_trainset(train_dataset=filenames) + model_state, n_data, avg_loss = client.train(args=args) + + model = init_model("yolo_v11_n", num_classes=80) + model.load_state_dict(model_state) + + valset = Dataset( + filenames=filenames, + input_size=640, + params=cfg, + augment=False, + ) + if valset is not None: + precision, recall, map50, map = test(valset=valset, params=cfg, model=model, batch_size=128) + print( + f"precision: {precision}, recall: {recall}, map50: {map50}, map: {map}, loss: {avg_loss}, n_data: {n_data}" + ) + else: + raise ValueError("valset is None, please provide a valid valset in config file.")