"""Ad-hoc smoke tests for the federated YOLO client/server code.

When run as a script this:

1. Writes the ``yolo_v11_n`` model repr and its state-dict parameter shapes to
   ``model.txt`` / ``model_key_value.txt`` (each only if it does not exist).
2. Trains a single ``FedYoloClient`` on the images listed in COCO128's
   ``train.txt`` and evaluates the returned weights on those same images,
   printing precision / recall / mAP50 / mAP together with the client's
   reported loss and sample count.
"""

import os

import yaml

from utils.fed_util import init_model
from fed_algo_cs.server_base import test
from utils.args import args_parser  # args parser
from fed_algo_cs.client_base import FedYoloClient  # FedYoloClient
from fed_algo_cs.server_base import FedYoloServer  # noqa: F401  # only used by the commented-out aggregation test
from utils import Dataset  # Dataset

MODEL_NAME = "yolo_v11_n"
DATA_DIR = "/mnt/DATA/COCO128"


def _dump_model_structure(
    model_name=MODEL_NAME,
    structure_path="model.txt",
    params_path="model_key_value.txt",
):
    """Write a model's repr and its parameter shapes to text files, if missing.

    The model is built at most once and only when at least one file must be
    written. This lets each file be regenerated independently of whether the
    other one already exists.

    Args:
        model_name: Name passed to ``init_model``.
        structure_path: File receiving ``print(model)``.
        params_path: File receiving one ``"<key>: <shape>"`` line per
            state-dict entry.
    """
    model = None

    def _get_model():
        nonlocal model
        if model is None:
            model = init_model(model_name, num_classes=1)
        return model

    if not os.path.exists(structure_path):
        with open(structure_path, "w", encoding="utf-8") as f:
            print(_get_model(), file=f)

    if not os.path.exists(params_path):
        with open(params_path, "w", encoding="utf-8") as f:
            for key, value in _get_model().state_dict().items():
                print(f"{key}: {value.shape}", file=f)


def _read_image_list(data_dir):
    """Return the image paths listed in ``<data_dir>/train.txt``.

    Only the basename of each listed entry is kept; it is re-rooted under
    ``<data_dir>/images/train2017/``. Lines that yield an empty basename
    (blank lines, entries ending in a separator) are skipped rather than
    turned into a bare directory path.

    Args:
        data_dir: Dataset root directory.

    Returns:
        List of image file paths, in file order.
    """
    image_dir = os.path.join(data_dir, "images", "train2017")
    filenames = []
    with open(os.path.join(data_dir, "train.txt"), encoding="utf-8") as f:
        for line in f:
            name = os.path.basename(line.rstrip())
            if name:
                filenames.append(os.path.join(image_dir, name))
    return filenames


if __name__ == "__main__":
    # Model structure test: dump the architecture / parameter shapes once.
    _dump_model_structure()

    # test agg function
    # from fed_algo_cs.server_base import FedYoloServer
    # import torch
    # import yaml
    # with open("./config/coco128_cfg.yaml", "r", encoding="utf-8") as f:
    #     cfg = yaml.safe_load(f)
    # # params = dict(cfg)
    # server = FedYoloServer(client_list=["c1", "c2", "c3"], model_name="yolo_v11_n", params=cfg)
    # state1 = {k: torch.ones_like(v) for k, v in server.model.state_dict().items()}
    # state2 = {k: torch.ones_like(v) * 2 for k, v in server.model.state_dict().items()}
    # state3 = {k: torch.ones_like(v) * 3 for k, v in server.model.state_dict().items()}
    # server.rec("c1", state1, n_data=20, loss=0.1)
    # server.rec("c2", state2, n_data=30, loss=0.2)
    # server.rec("c3", state3, n_data=50, loss=0.3)
    # server.select_clients(connection_ratio=1.0)
    # model_state, avg_loss, n_data = server.agg()
    # with open("agg_model.txt", "w", encoding="utf-8") as f:
    #     for k, v in model_state.items():
    #         print(f"{k}: {v.float().mean()}", file=f)
    # print(f"avg_loss: {avg_loss}, n_data: {n_data}")

    # Test single-client training (should match standalone training).
    args = args_parser()
    with open(args.config, "r", encoding="utf-8") as f:
        cfg = yaml.safe_load(f)

    client = FedYoloClient(name="c1", params=cfg, model_name=MODEL_NAME)

    filenames = _read_image_list(DATA_DIR)
    if not filenames:
        raise ValueError(f"no image entries found in {os.path.join(DATA_DIR, 'train.txt')}")
    client.load_trainset(train_dataset=filenames)

    # NOTE(review): unpack order (state, n_data, loss) differs from server.agg()
    # (state, loss, n_data) above -- confirm against FedYoloClient.train.
    model_state, n_data, avg_loss = client.train(args=args)

    # NOTE(review): num_classes=80 is hard-coded; presumably it must match the
    # class count in cfg -- verify.
    model = init_model(MODEL_NAME, num_classes=80)
    model.load_state_dict(model_state)

    # Evaluation deliberately reuses the training image list.
    valset = Dataset(
        filenames=filenames,
        input_size=640,
        params=cfg,
        augment=False,
    )
    precision, recall, map50, map_ = test(valset=valset, params=cfg, model=model, batch_size=128)
    print(
        f"precision: {precision}, recall: {recall}, map50: {map50}, map: {map_}, loss: {avg_loss}, n_data: {n_data}"
    )