优化fed_run函数中的进度条显示
This commit is contained in:
10
fed_run.py
10
fed_run.py
@@ -142,7 +142,8 @@ def fed_run():
|
|||||||
# --- build clients ---
|
# --- build clients ---
|
||||||
model_name = cfg.get("model_name", "yolo_v11_n")
|
model_name = cfg.get("model_name", "yolo_v11_n")
|
||||||
clients = {}
|
clients = {}
|
||||||
for uid in users:
|
|
||||||
|
for uid in tqdm(users, desc="Building clients", leave=True, unit="client"):
|
||||||
c = FedYoloClient(name=uid, model_name=model_name, params=params)
|
c = FedYoloClient(name=uid, model_name=model_name, params=params)
|
||||||
c.load_trainset(user_data[uid]["filename"])
|
c.load_trainset(user_data[uid]["filename"])
|
||||||
clients[uid] = c
|
clients[uid] = c
|
||||||
@@ -176,11 +177,11 @@ def fed_run():
|
|||||||
res_root = cfg.get("res_root", "results")
|
res_root = cfg.get("res_root", "results")
|
||||||
os.makedirs(res_root, exist_ok=True)
|
os.makedirs(res_root, exist_ok=True)
|
||||||
|
|
||||||
for rnd in tqdm(range(num_round), desc="main federal loop round"):
|
for rnd in tqdm(range(num_round), desc="main federal loop round:"):
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
|
|
||||||
# Local training (sequential over all users)
|
# Local training (sequential over all users)
|
||||||
for uid in tqdm(users, desc=f"Round {rnd + 1} local training", leave=False):
|
for uid in tqdm(users, desc=f"Round {rnd + 1} local training: ", leave=False):
|
||||||
client = clients[uid] # FedYoloClient instance
|
client = clients[uid] # FedYoloClient instance
|
||||||
client.update(global_state) # load global weights
|
client.update(global_state) # load global weights
|
||||||
state_dict, n_data, loss_dict = client.train(args_cli) # local training
|
state_dict, n_data, loss_dict = client.train(args_cli) # local training
|
||||||
@@ -213,10 +214,11 @@ def fed_run():
|
|||||||
history["train_loss"].append(scalar_train_loss)
|
history["train_loss"].append(scalar_train_loss)
|
||||||
history["round_time_sec"].append(time.time() - t0)
|
history["round_time_sec"].append(time.time() - t0)
|
||||||
|
|
||||||
print(
|
tqdm.write(
|
||||||
f"[round {rnd + 1:04d}] "
|
f"[round {rnd + 1:04d}] "
|
||||||
f"loss={scalar_train_loss:.4f} mAP50-95={mAP:.4f} mAP50={mAP50:.4f} "
|
f"loss={scalar_train_loss:.4f} mAP50-95={mAP:.4f} mAP50={mAP50:.4f} "
|
||||||
f"P={precision:.4f} R={recall:.4f}"
|
f"P={precision:.4f} R={recall:.4f}"
|
||||||
|
f"\n"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Save running JSON (resumable logs)
|
# Save running JSON (resumable logs)
|
||||||
|
Reference in New Issue
Block a user