Compare commits
3 Commits
b19f11125d
...
964a8024c0
Author | SHA1 | Date | |
---|---|---|---|
964a8024c0 | |||
0b52cfc4f5 | |||
c2e538898c |
@@ -121,40 +121,53 @@ class FedYoloServer(object):
|
|||||||
return {"mAP": float(mean_ap), "mAP50": float(map50), "precision": float(prec), "recall": float(rec)}
|
return {"mAP": float(mean_ap), "mAP50": float(map50), "precision": float(prec), "recall": float(rec)}
|
||||||
|
|
||||||
def select_clients(self, connection_ratio=1.0):
|
def select_clients(self, connection_ratio=1.0):
|
||||||
"""Randomly select a fraction of clients."""
|
"""
|
||||||
|
Randomly select a fraction of clients.
|
||||||
|
Args:
|
||||||
|
connection_ratio: fraction of clients to select (0 < connection_ratio <= 1)
|
||||||
|
"""
|
||||||
self.selected_clients = []
|
self.selected_clients = []
|
||||||
self.n_data = 0
|
self.n_data = 0
|
||||||
for client_id in self.client_list:
|
for client_id in self.client_list:
|
||||||
|
# Random selection based on connection ratio
|
||||||
if np.random.rand() <= connection_ratio:
|
if np.random.rand() <= connection_ratio:
|
||||||
self.selected_clients.append(client_id)
|
self.selected_clients.append(client_id)
|
||||||
self.n_data += self.client_n_data.get(client_id, 0)
|
self.n_data += self.client_n_data.get(client_id, 0)
|
||||||
|
|
||||||
def agg(self):
|
def agg(self):
|
||||||
"""Aggregate client updates (FedAvg)."""
|
"""
|
||||||
|
Aggregate client updates (FedAvg).
|
||||||
|
Returns:
|
||||||
|
global_state: aggregated model state dictionary
|
||||||
|
avg_loss: dict of averaged losses
|
||||||
|
n_data: total number of data classes samples used in this round
|
||||||
|
"""
|
||||||
if len(self.selected_clients) == 0 or self.n_data == 0:
|
if len(self.selected_clients) == 0 or self.n_data == 0:
|
||||||
return self.model.state_dict(), {}, 0
|
return self.model.state_dict(), {}, 0
|
||||||
|
|
||||||
model = init_model(self.model_name, self._num_classes)
|
# start from current global model
|
||||||
model_state = model.state_dict()
|
global_state = self.model.state_dict()
|
||||||
|
|
||||||
|
# zero buffer for accumulation
|
||||||
|
new_state = {k: torch.zeros_like(v, dtype=torch.float32) for k, v in global_state.items()}
|
||||||
|
|
||||||
avg_loss = {}
|
avg_loss = {}
|
||||||
for i, name in enumerate(self.selected_clients):
|
for name in self.selected_clients:
|
||||||
if name not in self.client_state:
|
if name not in self.client_state:
|
||||||
continue
|
continue
|
||||||
weight = self.client_n_data[name] / self.n_data
|
weight = self.client_n_data[name] / self.n_data
|
||||||
for key in model_state.keys():
|
for k in new_state.keys():
|
||||||
if i == 0:
|
# accumulate in float32 to avoid fp16 issues
|
||||||
model_state[key] = self.client_state[name][key] * weight
|
new_state[k] += self.client_state[name][k].to(torch.float32) * weight
|
||||||
else:
|
|
||||||
model_state[key] += self.client_state[name][key] * weight
|
|
||||||
|
|
||||||
# Weighted average losses
|
# losses
|
||||||
for k, v in self.client_loss[name].items():
|
for k, v in self.client_loss[name].items():
|
||||||
avg_loss[k] = avg_loss.get(k, 0.0) + v * weight
|
avg_loss[k] = avg_loss.get(k, 0.0) + v * weight
|
||||||
|
|
||||||
self.model.load_state_dict(model_state, strict=True)
|
# load aggregated params back into global model
|
||||||
|
self.model.load_state_dict(new_state, strict=True)
|
||||||
self.round += 1
|
self.round += 1
|
||||||
return model_state, avg_loss, self.n_data
|
return self.model.state_dict(), avg_loss, self.n_data
|
||||||
|
|
||||||
def rec(self, name, state_dict, n_data, loss_dict):
|
def rec(self, name, state_dict, n_data, loss_dict):
|
||||||
"""
|
"""
|
||||||
|
10
fed_run.py
10
fed_run.py
@@ -142,7 +142,8 @@ def fed_run():
|
|||||||
# --- build clients ---
|
# --- build clients ---
|
||||||
model_name = cfg.get("model_name", "yolo_v11_n")
|
model_name = cfg.get("model_name", "yolo_v11_n")
|
||||||
clients = {}
|
clients = {}
|
||||||
for uid in users:
|
|
||||||
|
for uid in tqdm(users, desc="Building clients", leave=True, unit="client"):
|
||||||
c = FedYoloClient(name=uid, model_name=model_name, params=params)
|
c = FedYoloClient(name=uid, model_name=model_name, params=params)
|
||||||
c.load_trainset(user_data[uid]["filename"])
|
c.load_trainset(user_data[uid]["filename"])
|
||||||
clients[uid] = c
|
clients[uid] = c
|
||||||
@@ -176,11 +177,11 @@ def fed_run():
|
|||||||
res_root = cfg.get("res_root", "results")
|
res_root = cfg.get("res_root", "results")
|
||||||
os.makedirs(res_root, exist_ok=True)
|
os.makedirs(res_root, exist_ok=True)
|
||||||
|
|
||||||
for rnd in tqdm(range(num_round), desc="main federal loop round"):
|
for rnd in tqdm(range(num_round), desc="main federal loop round:"):
|
||||||
t0 = time.time()
|
t0 = time.time()
|
||||||
|
|
||||||
# Local training (sequential over all users)
|
# Local training (sequential over all users)
|
||||||
for uid in tqdm(users, desc=f"Round {rnd + 1} local training", leave=False):
|
for uid in tqdm(users, desc=f"Round {rnd + 1} local training: ", leave=False):
|
||||||
client = clients[uid] # FedYoloClient instance
|
client = clients[uid] # FedYoloClient instance
|
||||||
client.update(global_state) # load global weights
|
client.update(global_state) # load global weights
|
||||||
state_dict, n_data, loss_dict = client.train(args_cli) # local training
|
state_dict, n_data, loss_dict = client.train(args_cli) # local training
|
||||||
@@ -213,10 +214,11 @@ def fed_run():
|
|||||||
history["train_loss"].append(scalar_train_loss)
|
history["train_loss"].append(scalar_train_loss)
|
||||||
history["round_time_sec"].append(time.time() - t0)
|
history["round_time_sec"].append(time.time() - t0)
|
||||||
|
|
||||||
print(
|
tqdm.write(
|
||||||
f"[round {rnd + 1:04d}] "
|
f"[round {rnd + 1:04d}] "
|
||||||
f"loss={scalar_train_loss:.4f} mAP50-95={mAP:.4f} mAP50={mAP50:.4f} "
|
f"loss={scalar_train_loss:.4f} mAP50-95={mAP:.4f} mAP50={mAP50:.4f} "
|
||||||
f"P={precision:.4f} R={recall:.4f}"
|
f"P={precision:.4f} R={recall:.4f}"
|
||||||
|
f"\n"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Save running JSON (resumable logs)
|
# Save running JSON (resumable logs)
|
||||||
|
@@ -25,6 +25,10 @@ def _parse_yolo_label_file(label_path: str) -> Set[int]:
|
|||||||
Return a set of class_ids found in a YOLO .txt label file.
|
Return a set of class_ids found in a YOLO .txt label file.
|
||||||
Empty file -> empty set. Missing file -> empty set.
|
Empty file -> empty set. Missing file -> empty set.
|
||||||
Robust to blank lines / trailing spaces.
|
Robust to blank lines / trailing spaces.
|
||||||
|
Args:
|
||||||
|
label_path: path to the label file
|
||||||
|
Returns:
|
||||||
|
set of class IDs (integers) found in the file
|
||||||
"""
|
"""
|
||||||
class_ids: Set[int] = set()
|
class_ids: Set[int] = set()
|
||||||
if not os.path.exists(label_path):
|
if not os.path.exists(label_path):
|
||||||
|
Reference in New Issue
Block a user