import copy
import glob
import json
import os

import torch
import yaml
from ultralytics import YOLO


# ------------ Federated learning utilities ------------
def federated_avg(global_model, client_weights):
    """Core federated averaging (FedAvg) step.

    client_weights is a list of (state_dict, num_samples) tuples, one per client.
    """
    # Total number of training samples across all clients
    total_samples = sum(n for _, n in client_weights)
    if total_samples == 0:
        raise ValueError("Total number of samples must be positive.")

    # Get the parameters of the underlying PyTorch model inside the YOLO wrapper
    global_dict = global_model.model.state_dict()

    # Unzip the clients' state_dicts and their sample counts
    state_dicts, sample_counts = zip(*client_weights)

    # Clone the parameters and detach them from the computation graph
    global_dict_copy = {
        k: v.clone().detach().requires_grad_(False) for k, v in global_dict.items()
    }

    # Aggregate every parameter that is present in all clients.
    # (BatchNorm running stats are averaged together with the weights.)
    for key in global_dict_copy:
        if all(key in sd for sd in state_dicts):
            # Weight each client's tensor by its share of the total samples
            weighted_tensors = []
            for client_state, sample_count in zip(state_dicts, sample_counts):
                weight = sample_count / total_samples
                weighted_tensors.append(client_state[key].float() * weight)

            # Sum the weighted tensors to obtain the new global parameter
            global_dict_copy[key] = torch.stack(weighted_tensors, dim=0).sum(dim=0)
        # Keys missing from some clients keep the previous global value.

    # Load the aggregated parameters back into the YOLO model
    global_model.model.load_state_dict(global_dict_copy, strict=True)

    # A few representative layers used to sanity-check the aggregation
    MONITOR_KEYS = [
        "model.0.conv.weight",
        "model.1.conv.weight",
        "model.3.conv.weight",
        "model.5.conv.weight",
        "model.7.conv.weight",
        "model.9.cv1.conv.weight",
        "model.12.cv1.conv.weight",
        "model.15.cv1.conv.weight",
        "model.18.cv1.conv.weight",
        "model.21.cv1.conv.weight",
        "model.22.dfl.conv.weight",
    ]
    with open("aggregation_check.txt", "a") as f:
        f.write("\n=== Parameter aggregation check ===\n")
    for key in MONITOR_KEYS:
        # Mean of the aggregated parameter and of each client's copy
        aggregated_mean = global_dict_copy[key].mean().item()
        client_means = [sd[key].float().mean().item() for sd in state_dicts]
        with open("aggregation_check.txt", "a") as f:
            f.write(f"Layer '{key}' aggregated mean: {aggregated_mean:.6f}\n")
            f.write(f"Per-client means: {[f'{cm:.6f}' for cm in client_means]}\n")
            f.write(f"Max client spread: {max(client_means) - min(client_means):.6f}\n\n")

    return global_model
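
# federated_avg above implements the standard FedAvg update: each aggregated
# parameter is the sample-weighted mean of the client parameters,
#     w_global = sum_k (n_k / N) * w_k,   where N = sum_k n_k.
# For example, with two clients holding 100 and 300 samples, their
# contributions are weighted 0.25 and 0.75 respectively.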

# ------------ Federated training loop ------------
def federated_train(num_rounds, clients_data):
    # ========== Metric containers ==========
    metrics = {
        "round": [],
        "val_mAP": [],  # global-model validation mAP per round
        "client_mAPs": [],  # reserved: per-client validation mAPs (not populated yet)
        "communication_cost": [],  # per-round communication cost (MB)
    }

    # Initialize the global model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    global_model = (
        YOLO("/home/image1325/DATA/Graduation-Project/federated_learning/yolov8n.yaml")
        .load("/home/image1325/DATA/Graduation-Project/federated_learning/yolov8n.pt")
        .to(device)
    )
    global_model.model.model[-1].nc = 1  # single detection class

    # Client-side working copy of the global model
    local_model = copy.deepcopy(global_model)

    for round_idx in range(num_rounds):
        client_weights = []

        # Local training on each client
        for data_path in clients_data:
            # Count the client's training samples
            with open(data_path, "r") as f:
                config = yaml.safe_load(f)
            # Resolve the image directory relative to the YAML file's location
            yaml_dir = os.path.dirname(data_path)
            img_dir = os.path.join(yaml_dir, config.get("train", data_path))
            num_samples = (
                len(glob.glob(os.path.join(img_dir, "*.jpg")))
                + len(glob.glob(os.path.join(img_dir, "*.png")))
                + len(glob.glob(os.path.join(img_dir, "*.jpeg")))
            )

            # Start each client from the current global parameters
            local_model.model.load_state_dict(
                global_model.model.state_dict(), strict=True
            )

            # Local training (original hyperparameters kept)
            local_model.train(
                name=f"train{round_idx + 1}",  # run name for this round
                data=data_path,
                epochs=16,  # local epochs per round
                imgsz=768,  # image size
                verbose=False,  # suppress verbose output
                batch=-1,  # auto batch size
                workers=6,  # dataloader workers
            )

            # Collect the client's parameters and its sample count
            client_weights.append((local_model.model.state_dict(), num_samples))

        # Aggregate the client updates into the global model
        global_model = federated_avg(global_model, client_weights)

        # ========== Evaluate the global model ==========
        # Validate on a copy so evaluation cannot mutate the global parameters
        val_model = copy.deepcopy(global_model)
        with torch.no_grad():
            val_results = val_model.val(
                data="/mnt/DATA/uav_dataset_old/UAVdataset/fed_data.yaml",  # validation set config
                imgsz=768,
                batch=16,
                verbose=False,
            )
        del val_model

        val_mAP = val_results.box.map  # mAP@0.5:0.95 (use box.map50 for mAP@0.5)

        # Communication cost, assuming all fp32 parameters are transmitted
        model_size = sum(p.numel() * 4 for p in global_model.model.parameters()) / (
            1024**2
        )  # MB

        # Record this round's metrics
        metrics["round"].append(round_idx + 1)
        metrics["val_mAP"].append(val_mAP)
        metrics["communication_cost"].append(model_size)

        # Log this round's results
        with open("aggregation_check.txt", "a") as f:
            f.write(f"\n[Round {round_idx + 1}/{num_rounds}]\n")
            f.write(f"Validation mAP@0.5:0.95: {val_mAP:.4f}\n")
            f.write(f"Communication Cost: {model_size:.2f} MB\n\n")

    return global_model, metrics


if __name__ == "__main__":
    # Federated training configuration: one dataset YAML per client
    clients_config = [
        "/mnt/DATA/uav_fed/train1/train1.yaml",  # client 1 data
        "/mnt/DATA/uav_fed/train2/train2.yaml",  # client 2 data
    ]
    # Local datasets for testing:
    # clients_config = [
    #     "/home/image1325/DATA/Graduation-Project/dataset/train1/train1.yaml",
    #     "/home/image1325/DATA/Graduation-Project/dataset/train2/train2.yaml",
    # ]

    # Run federated training
    final_model, metrics = federated_train(num_rounds=10, clients_data=clients_config)

    # Save the final model
    final_model.save("yolov8n_federated.pt")
    # final_model.export(format="onnx")  # optional ONNX export

    with open("metrics.json", "w") as f:
        json.dump(metrics, f, indent=4)
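
# A minimal sketch of reusing the aggregated weights for inference,
# assuming the save path above ("test.jpg" is a hypothetical image):
#     from ultralytics import YOLO
#     model = YOLO("yolov8n_federated.pt")
#     results = model.predict("test.jpg", imgsz=768)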