测试代码逻辑优化
This commit is contained in:
42
testcode.py
42
testcode.py
@@ -8,39 +8,41 @@ from fed_algo_cs.server_base import FedYoloServer # FedYoloServer
|
|||||||
from utils import Dataset # Dataset
|
from utils import Dataset # Dataset
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
if not os.path.exists("model.txt"):
|
||||||
# model structure test
|
# model structure test
|
||||||
model = init_model("yolo_v11_n", num_classes=1)
|
model = init_model("yolo_v11_n", num_classes=1)
|
||||||
|
|
||||||
with open("model.txt", "w", encoding="utf-8") as f:
|
with open("model.txt", "w", encoding="utf-8") as f:
|
||||||
print(model, file=f)
|
print(model, file=f)
|
||||||
|
|
||||||
|
if not os.path.exists("model_key_value.txt"):
|
||||||
# loop over model key and values
|
# loop over model key and values
|
||||||
with open("model_key_value.txt", "w", encoding="utf-8") as f:
|
with open("model_key_value.txt", "w", encoding="utf-8") as f:
|
||||||
for k, v in model.state_dict().items():
|
for k, v in model.state_dict().items():
|
||||||
print(f"{k}: {v.shape}", file=f)
|
print(f"{k}: {v.shape}", file=f)
|
||||||
|
|
||||||
# test agg function
|
# test agg function
|
||||||
from fed_algo_cs.server_base import FedYoloServer
|
# from fed_algo_cs.server_base import FedYoloServer
|
||||||
import torch
|
# import torch
|
||||||
import yaml
|
# import yaml
|
||||||
|
|
||||||
with open("./config/coco128_cfg.yaml", "r", encoding="utf-8") as f:
|
# with open("./config/coco128_cfg.yaml", "r", encoding="utf-8") as f:
|
||||||
cfg = yaml.safe_load(f)
|
# cfg = yaml.safe_load(f)
|
||||||
params = dict(cfg)
|
# # params = dict(cfg)
|
||||||
|
|
||||||
server = FedYoloServer(client_list=["c1", "c2", "c3"], model_name="yolo_v11_n", params=params)
|
# server = FedYoloServer(client_list=["c1", "c2", "c3"], model_name="yolo_v11_n", params=cfg)
|
||||||
state1 = {k: torch.ones_like(v) for k, v in server.model.state_dict().items()}
|
# state1 = {k: torch.ones_like(v) for k, v in server.model.state_dict().items()}
|
||||||
state2 = {k: torch.ones_like(v) * 2 for k, v in server.model.state_dict().items()}
|
# state2 = {k: torch.ones_like(v) * 2 for k, v in server.model.state_dict().items()}
|
||||||
state3 = {k: torch.ones_like(v) * 3 for k, v in server.model.state_dict().items()}
|
# state3 = {k: torch.ones_like(v) * 3 for k, v in server.model.state_dict().items()}
|
||||||
server.rec("c1", state1, n_data=20, loss=0.1)
|
# server.rec("c1", state1, n_data=20, loss=0.1)
|
||||||
server.rec("c2", state2, n_data=30, loss=0.2)
|
# server.rec("c2", state2, n_data=30, loss=0.2)
|
||||||
server.rec("c3", state3, n_data=50, loss=0.3)
|
# server.rec("c3", state3, n_data=50, loss=0.3)
|
||||||
server.select_clients(connection_ratio=1.0)
|
# server.select_clients(connection_ratio=1.0)
|
||||||
model_state, avg_loss, n_data = server.agg()
|
# model_state, avg_loss, n_data = server.agg()
|
||||||
with open("agg_model.txt", "w", encoding="utf-8") as f:
|
# with open("agg_model.txt", "w", encoding="utf-8") as f:
|
||||||
for k, v in model_state.items():
|
# for k, v in model_state.items():
|
||||||
print(f"{k}: {v.float().mean()}", file=f)
|
# print(f"{k}: {v.float().mean()}", file=f)
|
||||||
print(f"avg_loss: {avg_loss}, n_data: {n_data}")
|
# print(f"avg_loss: {avg_loss}, n_data: {n_data}")
|
||||||
|
|
||||||
# test single client training (should be the same as standalone training)
|
# test single client training (should be the same as standalone training)
|
||||||
args = args_parser()
|
args = args_parser()
|
||||||
@@ -50,7 +52,7 @@ if __name__ == "__main__":
|
|||||||
client = FedYoloClient(name="c1", params=cfg, model_name="yolo_v11_n")
|
client = FedYoloClient(name="c1", params=cfg, model_name="yolo_v11_n")
|
||||||
|
|
||||||
filenames = []
|
filenames = []
|
||||||
data_dir = "/home/image1325/ssd1/dataset/COCO128"
|
data_dir = "/mnt/DATA/COCO128"
|
||||||
with open(f"{data_dir}/train.txt") as f:
|
with open(f"{data_dir}/train.txt") as f:
|
||||||
for filename in f.readlines():
|
for filename in f.readlines():
|
||||||
filename = os.path.basename(filename.rstrip())
|
filename = os.path.basename(filename.rstrip())
|
||||||
|
Reference in New Issue
Block a user