Graduation-Project/federated_learning/yolov8_fed.py

202 lines
7.4 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import glob
import os
from pathlib import Path
import json
import yaml
from ultralytics import YOLO
import copy
import torch
# ------------ 新增联邦学习工具函数 ------------
def federated_avg(global_model, client_weights):
    """Aggregate client models into the global model via weighted FedAvg.

    Args:
        global_model: YOLO wrapper whose ``.model`` attribute is the
            underlying ``torch.nn.Module`` to update in place.
        client_weights: list of ``(state_dict, num_samples)`` pairs, one per
            client; ``num_samples`` weights that client's contribution.

    Returns:
        The same ``global_model`` object with the aggregated parameters
        loaded back into it.

    Raises:
        ValueError: if the clients report zero samples in total.
    """
    total_samples = sum(n for _, n in client_weights)
    if total_samples == 0:
        raise ValueError("Total number of samples must be positive.")

    # Parameters of the underlying PyTorch model.
    global_dict = global_model.model.state_dict()
    state_dicts, sample_counts = zip(*client_weights)

    # BatchNorm running statistics are kept from the global model rather
    # than averaged across clients.
    bn_stat_markers = ("running_mean", "running_var", "num_batches_tracked")

    for key in global_dict:
        # Only aggregate floating-point tensors (f32/f16/bf16); integer
        # buffers such as num_batches_tracked are left untouched.
        if not global_dict[key].dtype.is_floating_point:
            continue
        if any(marker in key for marker in bn_stat_markers):
            continue
        # Sample-count-weighted average across clients.
        global_dict[key] = torch.stack(
            [
                sd[key].float() * (n / total_samples)
                for sd, n in zip(state_dicts, sample_counts)
            ],
            dim=0,
        ).sum(dim=0)

    # Load the aggregated parameters back into the YOLO model.
    global_model.model.load_state_dict(global_dict)

    # Sanity log: compare a few representative layers after aggregation with
    # the per-client values. Keys absent from this model are skipped.
    monitor_keys = [
        "model.0.conv.weight",       # stem convolution
        "model.10.conv.weight",      # mid-network convolution
        "model.22.dfl.conv.weight",  # detection head (DFL)
    ]
    # Open the log once instead of re-opening it for every monitored key.
    with open("aggregation_check.txt", "a") as f:
        f.write("\n=== 参数聚合检查 ===\n")
        for key in monitor_keys:
            if key not in global_dict:
                continue
            aggregated_mean = global_dict[key].mean().item()
            client_means = [sd[key].float().mean().item() for sd in state_dicts]
            f.write(f"'{key}' 聚合后均值: {aggregated_mean:.6f}\n")
            f.write(f"各客户端该层均值差异: {[f'{cm:.6f}' for cm in client_means]}\n")
            f.write(f"客户端最大差异: {max(client_means) - min(client_means):.6f}\n\n")
    return global_model
# ------------ 修改训练流程 ------------
def federated_train(num_rounds, clients_data):
    """Run federated YOLO training for a number of communication rounds.

    Each round, every client trains a copy of the global model on its own
    dataset, then the copies are merged with ``federated_avg`` and the
    aggregated model is evaluated on a shared validation set.

    Args:
        num_rounds: number of federated communication rounds.
        clients_data: list of per-client dataset YAML config paths.

    Returns:
        ``(global_model, metrics)`` where ``metrics`` records per-round
        validation mAP and communication cost.
    """
    # Per-round metric containers. "train_loss" and "client_mAPs" are not
    # currently populated anywhere; they are kept so the returned dict's
    # schema stays backward compatible.
    metrics = {
        "round": [],
        "val_mAP": [],
        "train_loss": [],
        "client_mAPs": [],
        "communication_cost": [],  # MB transferred per round
    }

    # Initialize the global model from the YOLOv8n architecture config.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    global_model = YOLO("../yolov8n.yaml").to(device)

    for round_idx in range(num_rounds):
        client_weights = []

        # Local training on every client.
        for data_path in clients_data:
            num_samples = _count_client_samples(data_path)

            # Each client starts from a copy of the current global model.
            local_model = copy.deepcopy(global_model)
            local_model.train(
                data=data_path,
                epochs=4,        # local epochs per communication round
                imgsz=640,
                verbose=False,
                batch=-1,        # auto batch size
            )
            # Collect the trained parameters together with the sample count
            # used to weight this client's contribution.
            client_weights.append(
                (copy.deepcopy(local_model.model.state_dict()), num_samples)
            )

        # FedAvg aggregation into the global model.
        global_model = federated_avg(global_model, client_weights)

        # Evaluate the aggregated model on the shared validation set.
        val_results = global_model.val(
            data="/mnt/DATA/UAVdataset/data.yaml",
            imgsz=640,
            batch=-1,
            verbose=False,
        )
        # NOTE(review): ultralytics' box.map is mAP50-95, though the log
        # below labels it mAP@0.5 — confirm which metric is intended.
        val_mAP = val_results.box.map

        # Communication cost assuming the full float32 parameter set is
        # transferred each round (4 bytes per parameter, reported in MB).
        model_size = sum(
            p.numel() * 4 for p in global_model.model.parameters()
        ) / (1024 ** 2)

        metrics["round"].append(round_idx + 1)
        metrics["val_mAP"].append(val_mAP)
        metrics["communication_cost"].append(model_size)

        # Append this round's summary to the shared log file.
        with open("aggregation_check.txt", "a") as f:
            f.write(f"\n[Round {round_idx + 1}/{num_rounds}]")
            f.write(f"Validation mAP@0.5: {val_mAP:.4f}")
            f.write(f"Communication Cost: {model_size:.2f} MB\n")

    return global_model, metrics


def _count_client_samples(data_path):
    """Count a client's training images from its dataset YAML config."""
    with open(data_path, "r") as f:
        config = yaml.safe_load(f)
    # Resolve the train-image directory relative to the YAML file location.
    yaml_dir = os.path.dirname(data_path)
    img_dir = os.path.join(yaml_dir, config.get("train", data_path))
    return sum(
        len(glob.glob(os.path.join(img_dir, pattern)))
        for pattern in ("*.jpg", "*.png", "*.jpeg")
    )
# ------------ 使用示例 ------------
if __name__ == "__main__":
    # Per-client dataset configuration files for the federated run.
    clients_config = [
        "/mnt/DATA/uav_dataset_fed/train1/train1.yaml",  # client 1
        "/mnt/DATA/uav_dataset_fed/train2/train2.yaml",  # client 2
    ]

    # Execute the federated training loop.
    trained_model, run_metrics = federated_train(
        num_rounds=40, clients_data=clients_config
    )

    # Persist the final aggregated model and the collected metrics.
    trained_model.save("yolov8n_federated.pt")
    with open("metrics.json", "w") as f:
        json.dump(run_metrics, f, indent=4)