增加联邦学习评价指标。bugfix: 修复训练模型参数聚合问题

This commit is contained in:
Yunhao Meng 2025-05-10 17:22:56 +08:00
parent 98321aa7d5
commit 76240a12e6

View File

@@ -2,6 +2,8 @@ import glob
import os import os
from pathlib import Path from pathlib import Path
import json import json
from pydoc import cli
from threading import local
import yaml import yaml
from ultralytics import YOLO from ultralytics import YOLO
@@ -17,66 +19,83 @@ def federated_avg(global_model, client_weights):
if total_samples == 0: if total_samples == 0:
raise ValueError("Total number of samples must be positive.") raise ValueError("Total number of samples must be positive.")
# DEBUG: global_dict
# print(global_model)
# 获取YOLO底层PyTorch模型参数 # 获取YOLO底层PyTorch模型参数
global_dict = global_model.model.state_dict() global_dict = global_model.model.state_dict()
# 提取所有客户端的 state_dict 和对应样本数 # 提取所有客户端的 state_dict 和对应样本数
state_dicts, sample_counts = zip(*client_weights) state_dicts, sample_counts = zip(*client_weights)
for key in global_dict: # 克隆参数并脱离计算图
# 对每一层参数取平均 global_dict_copy = {
# if global_dict[key].data.dtype == torch.float32: k: v.clone().detach().requires_grad_(False) for k, v in global_dict.items()
# global_dict[key].data = torch.stack( }
# [w[key].float() for w in client_weights], 0
# ).mean(0)
# 加权平均 # 聚合可训练且存在的参数
if global_dict[key].dtype == torch.float32: # 只聚合浮点型参数 for key in global_dict_copy:
# 跳过 BatchNorm 层的统计量 # if global_dict_copy[key].dtype != torch.float32:
if any( # continue
x in key for x in ["running_mean", "running_var", "num_batches_tracked"] # if any(
): # x in key for x in ["running_mean", "running_var", "num_batches_tracked"]
continue # ):
# 按照样本数加权求和 # continue
weighted_tensors = [ # 检查所有客户端是否包含当前键
sd[key].float() * (n / total_samples) all_clients_have_key = all(key in sd for sd in state_dicts)
for sd, n in zip(state_dicts, sample_counts) if all_clients_have_key:
] # 计算每个客户端的加权张量
global_dict[key] = torch.stack(weighted_tensors, dim=0).sum(dim=0) # weighted_tensors = [
# client_state[key].float() * (sample_count / total_samples)
# for client_state, sample_count in zip(state_dicts, sample_counts)
# ]
weighted_tensors = []
for client_state, sample_count in zip(state_dicts, sample_counts):
weight = sample_count / total_samples # 计算权重
weighted_tensor = client_state[key].float() * weight # 加权张量
weighted_tensors.append(weighted_tensor)
# 聚合加权张量并更新全局参数
global_dict_copy[key] = torch.stack(weighted_tensors, dim=0).sum(dim=0)
# 解决模型参数不匹配问题 # else:
# try: # print(f"错误: 键 {key} 在部分客户端缺失,已保留全局参数")
# # 加载回YOLO模型 # 终止训练或记录日志
# global_model.model.load_state_dict(global_dict) # raise KeyError(f"键 {key} 缺失")
# except RuntimeError as e:
# print('Ignoring "' + str(e) + '"')
# 加载回YOLO模型 # 加载回YOLO模型
global_model.model.load_state_dict(global_dict) global_model.model.load_state_dict(global_dict_copy, strict=True)
# 随机选取一个非统计量层进行对比 # global_model.model.train()
# sample_key = next(k for k in global_dict if 'running_' not in k) # with torch.no_grad():
# aggregated_mean = global_dict[sample_key].mean().item() # global_model.model.load_state_dict(global_dict_copy, strict=True)
# client_means = [sd[sample_key].float().mean().item() for sd in state_dicts]
# print(f"layer: '{sample_key}' Mean after aggregation: {aggregated_mean:.6f}")
# print(f"The average value of the layer for each client: {client_means}")
# 定义多个关键层 # 定义多个关键层
MONITOR_KEYS = [ MONITOR_KEYS = [
"model.0.conv.weight", # 输入层卷积 "model.0.conv.weight",
"model.10.conv.weight", # 中间层卷积 "model.1.conv.weight",
"model.22.dfl.conv.weight", # 输出层分类头 "model.3.conv.weight",
"model.5.conv.weight",
"model.7.conv.weight",
"model.9.cv1.conv.weight",
"model.12.cv1.conv.weight",
"model.15.cv1.conv.weight",
"model.18.cv1.conv.weight",
"model.21.cv1.conv.weight",
"model.22.dfl.conv.weight",
] ]
with open("aggregation_check.txt", "a") as f: with open("aggregation_check.txt", "a") as f:
f.write("\n=== 参数聚合检查 ===\n") f.write("\n=== 参数聚合检查 ===\n")
for key in MONITOR_KEYS: for key in MONITOR_KEYS:
if key not in global_dict: # if key not in global_dict:
continue # continue
# if not all(key in sd for sd in state_dicts):
# continue
# 计算聚合后均值 # 计算聚合后均值
aggregated_mean = global_dict[key].mean().item() aggregated_mean = global_dict[key].mean().item()
# 计算各客户端均值 # 计算各客户端均值
client_means = [sd[key].float().mean().item() for sd in state_dicts] client_means = [sd[key].float().mean().item() for sd in state_dicts]
with open("aggregation_check.txt", "a") as f: with open("aggregation_check.txt", "a") as f:
f.write(f"'{key}' 聚合后均值: {aggregated_mean:.6f}\n") f.write(f"'{key}' 聚合后均值: {aggregated_mean:.6f}\n")
f.write(f"各客户端该层均值差异: {[f'{cm:.6f}' for cm in client_means]}\n") f.write(f"各客户端该层均值差异: {[f'{cm:.6f}' for cm in client_means]}\n")
@@ -87,24 +106,35 @@ def federated_avg(global_model, client_weights):
# ------------ 修改训练流程 ------------ # ------------ 修改训练流程 ------------
def federated_train(num_rounds, clients_data): def federated_train(num_rounds, clients_data):
# ========== 新增:初始化指标记录 ========== # ========== 初始化指标记录 ==========
metrics = { metrics = {
"round": [], "round": [],
"val_mAP": [], # 每轮验证集mAP "val_mAP": [], # 每轮验证集mAP
"train_loss": [], # 每轮平均训练损失 # "train_loss": [], # 每轮平均训练损失
"client_mAPs": [], # 各客户端本地模型在验证集上的mAP "client_mAPs": [], # 各客户端本地模型在验证集上的mAP
"communication_cost": [], # 每轮通信开销MB "communication_cost": [], # 每轮通信开销MB
} }
# 初始化全局模型 # 初始化全局模型
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
global_model = YOLO("../yolov8n.yaml").to(device) global_model = (
# 设置类别数 YOLO("/home/image1325/DATA/Graduation-Project/federated_learning/yolov8n.yaml")
# global_model.model.nc = 1 .load("/home/image1325/DATA/Graduation-Project/federated_learning/yolov8n.pt")
.to(device)
)
global_model.model.model[-1].nc = 1 # 设置检测类别数为1
# global_model.model.train.ema.enabled = False
# 克隆全局模型
local_model = copy.deepcopy(global_model)
for _ in range(num_rounds): for _ in range(num_rounds):
client_weights = [] client_weights = []
client_losses = [] # 记录各客户端的训练损失 # 各客户端的训练损失
# client_losses = []
# DEBUG: 检查全局模型参数
# global_dict = global_model.model.state_dict()
# print(global_dict.keys())
# 每个客户端本地训练 # 每个客户端本地训练
for data_path in clients_data: for data_path in clients_data:
@@ -118,23 +148,28 @@ def federated_train(num_rounds, clients_data):
) # 从配置文件中获取图像目录 ) # 从配置文件中获取图像目录
# print(f"Image directory: {img_dir}") # print(f"Image directory: {img_dir}")
num_samples = (len(glob.glob(os.path.join(img_dir, "*.jpg"))) num_samples = (
len(glob.glob(os.path.join(img_dir, "*.jpg")))
+ len(glob.glob(os.path.join(img_dir, "*.png"))) + len(glob.glob(os.path.join(img_dir, "*.png")))
+ len(glob.glob(os.path.join(img_dir, "*.jpeg"))) + len(glob.glob(os.path.join(img_dir, "*.jpeg")))
) )
# print(f"Number of images: {num_samples}") # print(f"Number of images: {num_samples}")
# 克隆全局模型 local_model.model.load_state_dict(
local_model = copy.deepcopy(global_model) global_model.model.state_dict(), strict=True
)
# 本地训练(保持你的原有参数设置) # 本地训练(保持你的原有参数设置)
results = local_model.train( local_model.train(
name=f"train{_ + 1}", # 当前轮次
data=data_path, data=data_path,
epochs=4, # 每轮本地训练多少个epoch # model=local_model,
epochs=16, # 每轮本地训练多少个epoch
# save_period=16, # save_period=16,
imgsz=640, # 图像大小 imgsz=768, # 图像大小
verbose=False, # 关闭冗余输出 verbose=False, # 关闭冗余输出
batch=-1, batch=-1, # 批大小
workers=6, # 工作线程数
) )
# 记录客户端训练损失 # 记录客户端训练损失
@@ -142,21 +177,32 @@ def federated_train(num_rounds, clients_data):
# client_losses.append(client_loss) # client_losses.append(client_loss)
# 收集模型参数及样本数 # 收集模型参数及样本数
client_weights.append( client_weights.append((local_model.model.state_dict(), num_samples))
(copy.deepcopy(local_model.model.state_dict()), num_samples)
)
# 聚合参数更新全局模型 # 聚合参数更新全局模型
global_model = federated_avg(global_model, client_weights) global_model = federated_avg(global_model, client_weights)
# DEBUG: 检查全局模型参数
# keys = global_model.model.state_dict().keys()
# ========== 评估全局模型 ========== # ========== 评估全局模型 ==========
# 复制全局模型以避免在评估时修改参数
val_model = copy.deepcopy(global_model)
# 评估全局模型在验证集上的性能 # 评估全局模型在验证集上的性能
val_results = global_model.val( with torch.no_grad():
data="/mnt/DATA/UAVdataset/data.yaml", # 指定验证集配置文件 val_results = val_model.val(
imgsz=640, data="/mnt/DATA/uav_dataset_old/UAVdataset/fed_data.yaml", # 指定验证集配置文件
batch=-1, imgsz=768, # 图像大小
verbose=False, batch=16, # 批大小
verbose=False, # 关闭冗余输出
) )
# 丢弃评估模型
del val_model
# DEBUG: 检查全局模型参数
# if keys != global_model.model.state_dict().keys():
# print("模型参数不一致!")
val_mAP = val_results.box.map # 获取mAP@0.5 val_mAP = val_results.box.map # 获取mAP@0.5
# 计算平均训练损失 # 计算平均训练损失
@@ -174,24 +220,29 @@ def federated_train(num_rounds, clients_data):
metrics["communication_cost"].append(model_size) metrics["communication_cost"].append(model_size)
# 打印当前轮次结果 # 打印当前轮次结果
with open("aggregation_check.txt", "a") as f: with open("aggregation_check.txt", "a") as f:
f.write(f"\n[Round {_ + 1}/{num_rounds}]") f.write(f"\n[Round {_ + 1}/{num_rounds}]\n")
f.write(f"Validation mAP@0.5: {val_mAP:.4f}") f.write(f"Validation mAP@0.5: {val_mAP:.4f}\n")
# f.write(f"Average Train Loss: {avg_train_loss:.4f}") # f.write(f"Average Train Loss: {avg_train_loss:.4f}")
f.write(f"Communication Cost: {model_size:.2f} MB\n") f.write(f"Communication Cost: {model_size:.2f} MB\n\n")
return global_model, metrics return global_model, metrics
# ------------ 使用示例 ------------
if __name__ == "__main__": if __name__ == "__main__":
# 联邦训练配置 # 联邦训练配置
clients_config = [ clients_config = [
"/mnt/DATA/uav_dataset_fed/train1/train1.yaml", # 客户端1数据路径 "/mnt/DATA/uav_fed/train1/train1.yaml", # 客户端1数据路径
"/mnt/DATA/uav_dataset_fed/train2/train2.yaml", # 客户端2数据路径 "/mnt/DATA/uav_fed/train2/train2.yaml", # 客户端2数据路径
] ]
# 使用本地数据集进行测试
# clients_config = [
# "/home/image1325/DATA/Graduation-Project/dataset/train1/train1.yaml",
# "/home/image1325/DATA/Graduation-Project/dataset/train2/train2.yaml",
# ]
# 运行联邦训练 # 运行联邦训练
final_model, metrics = federated_train(num_rounds=40, clients_data=clients_config) final_model, metrics = federated_train(num_rounds=10, clients_data=clients_config)
# 保存最终模型 # 保存最终模型
final_model.save("yolov8n_federated.pt") final_model.save("yolov8n_federated.pt")