新增testcode.py文件,包含模型结构测试、聚合函数测试及单客户端训练逻辑
This commit is contained in:
77
testcode.py
Normal file
77
testcode.py
Normal file
@@ -0,0 +1,77 @@
|
||||
from utils.fed_util import init_model
|
||||
from fed_algo_cs.server_base import test
|
||||
import os
|
||||
import yaml
|
||||
from utils.args import args_parser # args parser
|
||||
from fed_algo_cs.client_base import FedYoloClient # FedYoloClient
|
||||
from fed_algo_cs.server_base import FedYoloServer # FedYoloServer
|
||||
from utils import Dataset # Dataset
|
||||
|
||||
if __name__ == "__main__":
|
||||
# model structure test
|
||||
model = init_model("yolo_v11_n", num_classes=1)
|
||||
|
||||
with open("model.txt", "w", encoding="utf-8") as f:
|
||||
print(model, file=f)
|
||||
|
||||
# loop over model key and values
|
||||
with open("model_key_value.txt", "w", encoding="utf-8") as f:
|
||||
for k, v in model.state_dict().items():
|
||||
print(f"{k}: {v.shape}", file=f)
|
||||
|
||||
# test agg function
|
||||
from fed_algo_cs.server_base import FedYoloServer
|
||||
import torch
|
||||
import yaml
|
||||
|
||||
with open("./config/coco128_cfg.yaml", "r", encoding="utf-8") as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
params = dict(cfg)
|
||||
|
||||
server = FedYoloServer(client_list=["c1", "c2", "c3"], model_name="yolo_v11_n", params=params)
|
||||
state1 = {k: torch.ones_like(v) for k, v in server.model.state_dict().items()}
|
||||
state2 = {k: torch.ones_like(v) * 2 for k, v in server.model.state_dict().items()}
|
||||
state3 = {k: torch.ones_like(v) * 3 for k, v in server.model.state_dict().items()}
|
||||
server.rec("c1", state1, n_data=20, loss=0.1)
|
||||
server.rec("c2", state2, n_data=30, loss=0.2)
|
||||
server.rec("c3", state3, n_data=50, loss=0.3)
|
||||
server.select_clients(connection_ratio=1.0)
|
||||
model_state, avg_loss, n_data = server.agg()
|
||||
with open("agg_model.txt", "w", encoding="utf-8") as f:
|
||||
for k, v in model_state.items():
|
||||
print(f"{k}: {v.float().mean()}", file=f)
|
||||
print(f"avg_loss: {avg_loss}, n_data: {n_data}")
|
||||
|
||||
# test single client training (should be the same as standalone training)
|
||||
args = args_parser()
|
||||
with open(args.config, "r", encoding="utf-8") as f:
|
||||
cfg = yaml.safe_load(f)
|
||||
# params = dict(cfg)
|
||||
client = FedYoloClient(name="c1", params=cfg, model_name="yolo_v11_n")
|
||||
|
||||
filenames = []
|
||||
data_dir = "/home/image1325/ssd1/dataset/COCO128"
|
||||
with open(f"{data_dir}/train.txt") as f:
|
||||
for filename in f.readlines():
|
||||
filename = os.path.basename(filename.rstrip())
|
||||
filenames.append(f"{data_dir}/images/train2017/" + filename)
|
||||
|
||||
client.load_trainset(train_dataset=filenames)
|
||||
model_state, n_data, avg_loss = client.train(args=args)
|
||||
|
||||
model = init_model("yolo_v11_n", num_classes=80)
|
||||
model.load_state_dict(model_state)
|
||||
|
||||
valset = Dataset(
|
||||
filenames=filenames,
|
||||
input_size=640,
|
||||
params=cfg,
|
||||
augment=False,
|
||||
)
|
||||
if valset is not None:
|
||||
precision, recall, map50, map = test(valset=valset, params=cfg, model=model, batch_size=128)
|
||||
print(
|
||||
f"precision: {precision}, recall: {recall}, map50: {map50}, map: {map}, loss: {avg_loss}, n_data: {n_data}"
|
||||
)
|
||||
else:
|
||||
raise ValueError("valset is None, please provide a valid valset in config file.")
|
Reference in New Issue
Block a user