Graduation-Project/federated_learning/yolov8_fed.py

253 lines
9.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import glob
import os
from pathlib import Path
import json
from pydoc import cli
from threading import local
import yaml
from ultralytics import YOLO
import copy
import torch
# ------------ 新增联邦学习工具函数 ------------
def federated_avg(
    global_model,
    client_weights,
    log_path="aggregation_check.txt",
    monitor_keys=None,
):
    """Load the sample-weighted average of client weights into ``global_model``.

    Implements the FedAvg aggregation step: every ``state_dict`` entry of
    ``global_model.model`` that all clients provide is replaced by
    ``sum_i (n_i / N) * w_i``, where ``n_i`` is client ``i``'s sample count and
    ``N = sum_i n_i``. Entries missing from at least one client keep their
    current global value.

    Args:
        global_model: Wrapper exposing a ``torch.nn.Module`` as ``.model``
            (e.g. an ultralytics ``YOLO`` object). Modified in place.
        client_weights: Sequence of ``(state_dict, num_samples)`` pairs, one per
            client. Client tensors may be on any device and of any dtype.
        log_path: Text file to which per-layer diagnostics (aggregated mean,
            per-client means and their spread) are appended. ``None`` disables
            the diagnostics.
        monitor_keys: ``state_dict`` keys to report; defaults to a fixed list of
            YOLOv8n convolution weights. Keys absent from the model or from any
            client are skipped.

    Returns:
        ``global_model`` itself, after the aggregated weights have been loaded.

    Raises:
        ValueError: If the summed sample count is not positive (this includes
            an empty ``client_weights``).
        RuntimeError: Propagated from torch when client tensors disagree in
            shape with each other or with the model.
    """
    total_samples = sum(n for _, n in client_weights)
    if total_samples <= 0:
        raise ValueError("Total number of samples must be positive.")

    # Parameters and buffers of the underlying PyTorch model.
    global_dict = global_model.model.state_dict()
    state_dicts, sample_counts = zip(*client_weights)
    client_shares = [n / total_samples for n in sample_counts]

    aggregated = {}
    for key, global_tensor in global_dict.items():
        if not all(key in sd for sd in state_dicts):
            # At least one client lacks this entry: keep the global value.
            aggregated[key] = global_tensor.detach().clone()
            continue
        # Average in float32 on the global tensor's device: client models may
        # live on another device (e.g. CPU) or use another dtype.
        weighted = [
            sd[key].detach().to(device=global_tensor.device, dtype=torch.float32)
            * share
            for sd, share in zip(state_dicts, client_shares)
        ]
        # torch.stack (rather than a running sum) makes mismatched client shapes
        # raise instead of silently broadcasting. Casting back to the model's
        # dtype mirrors the conversion load_state_dict's copy would perform.
        aggregated[key] = (
            torch.stack(weighted, dim=0).sum(dim=0).to(global_tensor.dtype)
        )

    global_model.model.load_state_dict(aggregated, strict=True)

    if log_path is None:
        return global_model

    if monitor_keys is None:
        monitor_keys = (
            "model.0.conv.weight",
            "model.1.conv.weight",
            "model.3.conv.weight",
            "model.5.conv.weight",
            "model.7.conv.weight",
            "model.9.cv1.conv.weight",
            "model.12.cv1.conv.weight",
            "model.15.cv1.conv.weight",
            "model.18.cv1.conv.weight",
            "model.21.cv1.conv.weight",
            "model.22.dfl.conv.weight",
        )
    # Read the weights back from the model so the report reflects what was
    # actually loaded.
    loaded = global_model.model.state_dict()
    report = ["\n=== 参数聚合检查 ===\n"]
    for key in monitor_keys:
        if key not in loaded or not all(key in sd for sd in state_dicts):
            continue
        aggregated_mean = loaded[key].float().mean().item()
        client_means = [sd[key].float().mean().item() for sd in state_dicts]
        report.append(f"'{key}' 聚合后均值: {aggregated_mean:.6f}\n")
        report.append(f"各客户端该层均值差异: {[f'{cm:.6f}' for cm in client_means]}\n")
        report.append(f"客户端最大差异: {max(client_means) - min(client_means):.6f}\n\n")
    # Explicit UTF-8: the report contains non-ASCII text, so the locale's
    # default encoding may not be able to represent it.
    with open(log_path, "a", encoding="utf-8") as f:
        f.writelines(report)
    return global_model
# ------------ 修改训练流程 ------------
def _count_train_images(data_path):
    """Return the number of training images referenced by a dataset YAML.

    The YAML's ``train`` entry (one directory or a list of directories) is
    joined onto the YAML file's own directory (absolute entries are used
    as-is). Files directly inside each directory whose extension, compared
    case-insensitively, is ``.jpg``, ``.png`` or ``.jpeg`` are counted.
    NOTE(review): a top-level ``path:`` key in the YAML is not taken into
    account — confirm the client YAMLs use paths relative to the YAML file.
    """
    image_exts = (".jpg", ".png", ".jpeg")
    with open(data_path, "r", encoding="utf-8") as f:
        config = yaml.safe_load(f) or {}
    yaml_dir = os.path.dirname(data_path)
    train_entry = config.get("train", data_path)
    train_dirs = train_entry if isinstance(train_entry, (list, tuple)) else [train_entry]
    count = 0
    for train_dir in train_dirs:
        img_dir = os.path.join(yaml_dir, train_dir)
        count += sum(
            1
            for path in glob.glob(os.path.join(img_dir, "*"))
            if path.lower().endswith(image_exts)
        )
    return count


def _snapshot_state_dict(module):
    """Return a detached CPU copy of ``module.state_dict()``.

    ``state_dict()`` hands out tensors that share storage with the live module;
    copying them prevents a later in-place ``load_state_dict`` on the same
    module from overwriting weights that were already collected.
    """
    return {k: v.detach().to("cpu", copy=True) for k, v in module.state_dict().items()}


def _state_dict_megabytes(module):
    """Return the total size of ``module.state_dict()`` tensors in MiB."""
    return sum(t.numel() * t.element_size() for t in module.state_dict().values()) / 1024**2


def federated_train(
    num_rounds,
    clients_data,
    *,
    model_cfg="/home/image1325/DATA/Graduation-Project/federated_learning/yolov8n.yaml",
    weights="/home/image1325/DATA/Graduation-Project/federated_learning/yolov8n.pt",
    val_data="/mnt/DATA/uav_dataset_old/UAVdataset/fed_data.yaml",
    local_epochs=16,
    imgsz=768,
):
    """Train a YOLOv8 model with sample-weighted federated averaging.

    In every round each client starts from the current global weights, trains
    locally on its own dataset YAML, and the resulting weights are combined by
    ``federated_avg`` weighted by the client's training-image count. The new
    global model is then validated on ``val_data`` and per-round metrics are
    appended to ``aggregation_check.txt``.

    Args:
        num_rounds: Number of federated rounds.
        clients_data: Iterable of dataset YAML paths, one per client.
        model_cfg: Model definition YAML used to build the global model.
        weights: Pretrained weights loaded into the global model.
        val_data: Dataset YAML used to validate the global model each round.
        local_epochs: Local training epochs per client per round.
        imgsz: Image size used for both training and validation.

    Returns:
        ``(global_model, metrics)`` where ``metrics`` maps ``"round"``,
        ``"val_mAP"`` (mAP@0.5:0.95), ``"val_mAP50"`` (mAP@0.5),
        ``"client_mAPs"`` and ``"communication_cost"`` (MiB) to per-round lists.
    """
    metrics = {
        "round": [],
        "val_mAP": [],  # global model mAP@0.5:0.95 on the validation set
        "val_mAP50": [],  # global model mAP@0.5 on the validation set
        "client_mAPs": [],  # reserved; not populated by this function
        "communication_cost": [],  # size of the exchanged state_dict, MiB
    }

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    global_model = YOLO(model_cfg).load(weights).to(device)
    # NOTE(review): assigning ``nc`` only sets the detection head's attribute;
    # it presumably relies on ``model_cfg`` already declaring a single class — confirm.
    global_model.model.model[-1].nc = 1
    local_model = copy.deepcopy(global_model)

    # Materialize once so a generator argument works for every round, and
    # count each client's samples once since the datasets do not change.
    clients_data = list(clients_data)
    client_sizes = [_count_train_images(path) for path in clients_data]

    for round_idx in range(num_rounds):
        round_no = round_idx + 1
        client_weights = []
        for client_no, (data_path, num_samples) in enumerate(
            zip(clients_data, client_sizes), start=1
        ):
            # Every client starts the round from the current global weights.
            local_model.model.load_state_dict(global_model.model.state_dict(), strict=True)
            local_model.train(
                name=f"round{round_no}_client{client_no}",  # unique run dir per client/round
                data=data_path,
                epochs=local_epochs,
                imgsz=imgsz,
                verbose=False,
                batch=-1,  # -1: let ultralytics choose the batch size
                workers=6,
            )
            # Snapshot: the next client's load_state_dict writes in place into
            # the same module and would otherwise clobber these weights.
            client_weights.append((_snapshot_state_dict(local_model.model), num_samples))

        global_model = federated_avg(global_model, client_weights)

        # Validate a copy so validation cannot modify the global weights.
        val_model = copy.deepcopy(global_model)
        with torch.no_grad():
            val_results = val_model.val(
                data=val_data,
                imgsz=imgsz,
                batch=16,
                verbose=False,
            )
        del val_model

        val_map = float(val_results.box.map)  # mAP@0.5:0.95
        val_map50 = float(val_results.box.map50)  # mAP@0.5
        # Size of the full state_dict (parameters + buffers) that clients exchange.
        model_size = _state_dict_megabytes(global_model.model)

        metrics["round"].append(round_no)
        metrics["val_mAP"].append(val_map)
        metrics["val_mAP50"].append(val_map50)
        metrics["communication_cost"].append(model_size)

        with open("aggregation_check.txt", "a", encoding="utf-8") as f:
            f.write(f"\n[Round {round_no}/{num_rounds}]\n")
            f.write(f"Validation mAP@0.5: {val_map50:.4f}\n")
            f.write(f"Validation mAP@0.5:0.95: {val_map:.4f}\n")
            f.write(f"Communication Cost: {model_size:.2f} MB\n\n")
    return global_model, metrics
if __name__ == "__main__":
    # One dataset YAML per federated client.
    client_yamls = [
        "/mnt/DATA/uav_fed/train1/train1.yaml",  # client 1
        "/mnt/DATA/uav_fed/train2/train2.yaml",  # client 2
    ]
    # For a local smoke test, point the clients at the repository copies instead:
    #     "/home/image1325/DATA/Graduation-Project/dataset/train1/train1.yaml",
    #     "/home/image1325/DATA/Graduation-Project/dataset/train2/train2.yaml",

    trained_model, run_metrics = federated_train(num_rounds=10, clients_data=client_yamls)

    # Persist the aggregated weights, then the per-round metrics.
    trained_model.save("yolov8n_federated.pt")
    # trained_model.export(format="onnx")  # optional ONNX export
    with open("metrics.json", "w") as metrics_file:
        json.dump(run_metrics, metrics_file, indent=4)