diff --git a/federated_learning/yolov8_fed.py b/federated_learning/yolov8_fed.py
index dbadbe8..4c66766 100644
--- a/federated_learning/yolov8_fed.py
+++ b/federated_learning/yolov8_fed.py
@@ -2,6 +2,8 @@ import glob
 import os
 from pathlib import Path
 import json
+from pydoc import cli
+from threading import local
 import yaml
 
 from ultralytics import YOLO
@@ -17,66 +19,83 @@ def federated_avg(global_model, client_weights):
     if total_samples == 0:
         raise ValueError("Total number of samples must be positive.")
 
+    # DEBUG: global_dict
+    # print(global_model)
+    # Get the underlying PyTorch model parameters from the YOLO wrapper
     global_dict = global_model.model.state_dict()
 
     # Extract each client's state_dict and its sample count
     state_dicts, sample_counts = zip(*client_weights)
 
-    for key in global_dict:
-        # Average each layer's parameters
-        # if global_dict[key].data.dtype == torch.float32:
-        #     global_dict[key].data = torch.stack(
-        #         [w[key].float() for w in client_weights], 0
-        #     ).mean(0)
+    # Clone the parameters and detach them from the computation graph
+    global_dict_copy = {
+        k: v.clone().detach().requires_grad_(False) for k, v in global_dict.items()
+    }
 
-        # Weighted average
-        if global_dict[key].dtype == torch.float32:  # only aggregate float parameters
-            # Skip BatchNorm statistics
-            if any(
-                x in key for x in ["running_mean", "running_var", "num_batches_tracked"]
-            ):
-                continue
-            # Weighted sum by sample count
-            weighted_tensors = [
-                sd[key].float() * (n / total_samples)
-                for sd, n in zip(state_dicts, sample_counts)
-            ]
-            global_dict[key] = torch.stack(weighted_tensors, dim=0).sum(dim=0)
+    # Aggregate parameters that are trainable and present on every client
+    for key in global_dict_copy:
+        # if global_dict_copy[key].dtype != torch.float32:
+        #     continue
+        # if any(
+        #     x in key for x in ["running_mean", "running_var", "num_batches_tracked"]
+        # ):
+        #     continue
+        # Check that every client has the current key
+        all_clients_have_key = all(key in sd for sd in state_dicts)
+        if all_clients_have_key:
+            # Compute each client's weighted tensor
+            # weighted_tensors = [
+            #     client_state[key].float() * (sample_count / total_samples)
+            #     for client_state, sample_count in zip(state_dicts, sample_counts)
+            # ]
+            weighted_tensors = []
+            for client_state, sample_count in zip(state_dicts, sample_counts):
+                weight = sample_count / total_samples  # weighting factor
+                weighted_tensor = client_state[key].float() * weight  # weighted tensor
+                weighted_tensors.append(weighted_tensor)
+            # Sum the weighted tensors to update the global parameter
+            global_dict_copy[key] = torch.stack(weighted_tensors, dim=0).sum(dim=0)
 
-    # Work around mismatched model parameters
-    # try:
-    #     # Load back into the YOLO model
-    #     global_model.model.load_state_dict(global_dict)
-    # except RuntimeError as e:
-    #     print('Ignoring "' + str(e) + '"')
+        # else:
+        #     print(f"Error: key {key} is missing on some clients; keeping the global parameter")
+        #     # Abort training or log the event
+        #     raise KeyError(f"Key {key} is missing")
 
-    # Load back into the YOLO model
-    global_model.model.load_state_dict(global_dict)
+    # Load back into the YOLO model
+    global_model.model.load_state_dict(global_dict_copy, strict=True)
 
-    # Pick one non-statistics layer at random for comparison
-    # sample_key = next(k for k in global_dict if 'running_' not in k)
-    # aggregated_mean = global_dict[sample_key].mean().item()
-    # client_means = [sd[sample_key].float().mean().item() for sd in state_dicts]
-    # print(f"layer: '{sample_key}' Mean after aggregation: {aggregated_mean:.6f}")
-    # print(f"The average value of the layer for each client: {client_means}")
+    # global_model.model.train()
+    # with torch.no_grad():
+    #     global_model.model.load_state_dict(global_dict_copy, strict=True)
 
     # Define several key layers to monitor
     MONITOR_KEYS = [
-        "model.0.conv.weight",  # input convolution
-        "model.10.conv.weight",  # middle convolution
-        "model.22.dfl.conv.weight",  # detection head
+        "model.0.conv.weight",
+        "model.1.conv.weight",
+        "model.3.conv.weight",
+        "model.5.conv.weight",
+        "model.7.conv.weight",
+        "model.9.cv1.conv.weight",
+        "model.12.cv1.conv.weight",
+        "model.15.cv1.conv.weight",
+        "model.18.cv1.conv.weight",
+        "model.21.cv1.conv.weight",
+        "model.22.dfl.conv.weight",
    ]
 
     with open("aggregation_check.txt", "a") as f:
         f.write("\n=== Parameter aggregation check ===\n")
         for key in MONITOR_KEYS:
-            if key not in global_dict:
-                continue
+            # if key not in global_dict:
+            #     continue
+            # if not all(key in sd for sd in state_dicts):
+            #     continue
+            # Mean after aggregation
             aggregated_mean = global_dict[key].mean().item()
+            # Per-client means
             client_means = [sd[key].float().mean().item() for sd in state_dicts]
-            with open("aggregation_check.txt", "a") as f:
             f.write(f"Layer '{key}' mean after aggregation: {aggregated_mean:.6f}\n")
             f.write(f"Per-client means for this layer: {[f'{cm:.6f}' for cm in client_means]}\n")
@@ -87,24 +106,35 @@ def federated_avg(global_model, client_weights):
 
 # ------------ Modified training procedure ------------
 def federated_train(num_rounds, clients_data):
-    # ========== New: initialize metric tracking ==========
+    # ========== Initialize metric tracking ==========
     metrics = {
         "round": [],
         "val_mAP": [],  # validation-set mAP per round
-        "train_loss": [],  # average training loss per round
+        # "train_loss": [],  # average training loss per round
         "client_mAPs": [],  # each client's local-model mAP on the validation set
         "communication_cost": [],  # communication cost per round (MB)
     }
 
-    # Initialize the global model
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-    global_model = YOLO("../yolov8n.yaml").to(device)
-    # Set the number of classes
-    # global_model.model.nc = 1
+    global_model = (
+        YOLO("/home/image1325/DATA/Graduation-Project/federated_learning/yolov8n.yaml")
+        .load("/home/image1325/DATA/Graduation-Project/federated_learning/yolov8n.pt")
+        .to(device)
+    )
+    global_model.model.model[-1].nc = 1  # set the number of detection classes to 1
+    # global_model.model.train.ema.enabled = False
+
+    # Clone the global model
+    local_model = copy.deepcopy(global_model)
 
     for _ in range(num_rounds):
         client_weights = []
-        client_losses = []  # track each client's training loss
+        # Each client's training loss
+        # client_losses = []
+
+        # DEBUG: inspect the global model parameters
+        # global_dict = global_model.model.state_dict()
+        # print(global_dict.keys())
 
         # Local training on each client
         for data_path in clients_data:
@@ -118,23 +148,28 @@ def federated_train(num_rounds, clients_data):
             )  # get the image directory from the config file
             # print(f"Image directory: {img_dir}")
 
-            num_samples = (len(glob.glob(os.path.join(img_dir, "*.jpg")))
+            num_samples = (
+                len(glob.glob(os.path.join(img_dir, "*.jpg")))
                 + len(glob.glob(os.path.join(img_dir, "*.png")))
                 + len(glob.glob(os.path.join(img_dir, "*.jpeg")))
             )
             # print(f"Number of images: {num_samples}")
 
-            # Clone the global model
-            local_model = copy.deepcopy(global_model)
+            local_model.model.load_state_dict(
+                global_model.model.state_dict(), strict=True
+            )
 
             # Local training (keeps your original parameter settings)
-            results = local_model.train(
+            local_model.train(
+                name=f"train{_ + 1}",  # current round
                 data=data_path,
-                epochs=4,  # local training epochs per round
+                # model=local_model,
+                epochs=16,  # local training epochs per round
                 # save_period=16,
-                imgsz=640,  # image size
+                imgsz=768,  # image size
                 verbose=False,  # suppress verbose output
-                batch=-1,
+                batch=-1,  # batch size
+                workers=6,  # number of dataloader workers
             )
 
             # Track the client's training loss
@@ -142,21 +177,32 @@ def federated_train(num_rounds, clients_data):
             # client_losses.append(client_loss)
 
             # Collect model parameters and sample count
-            client_weights.append(
-                (copy.deepcopy(local_model.model.state_dict()), num_samples)
-            )
+            client_weights.append((local_model.model.state_dict(), num_samples))
 
         # Aggregate parameters to update the global model
         global_model = federated_avg(global_model, client_weights)
 
+        # DEBUG: inspect the global model parameters
+        # keys = global_model.model.state_dict().keys()
+
         # ========== Evaluate the global model ==========
+        # Copy the global model so evaluation does not modify its parameters
+        val_model = copy.deepcopy(global_model)
         # Evaluate the global model on the validation set
-        val_results = global_model.val(
-            data="/mnt/DATA/UAVdataset/data.yaml",  # validation-set config file
-            imgsz=640,
-            batch=-1,
-            verbose=False,
-        )
+        with torch.no_grad():
+            val_results = val_model.val(
+                data="/mnt/DATA/uav_dataset_old/UAVdataset/fed_data.yaml",  # validation-set config file
+                imgsz=768,  # image size
+                batch=16,  # batch size
+                verbose=False,  # suppress verbose output
+            )
+        # Discard the evaluation model
+        del val_model
+
+        # DEBUG: inspect the global model parameters
+        # if keys != global_model.model.state_dict().keys():
+        #     print("Model parameters are inconsistent!")
+
         val_mAP = val_results.box.map  # box.map is mAP50-95 (use box.map50 for mAP@0.5)
 
         # Compute the average training loss
@@ -174,24 +220,29 @@ def federated_train(num_rounds, clients_data):
         metrics["communication_cost"].append(model_size)
 
         # Print this round's results
         with open("aggregation_check.txt", "a") as f:
-            f.write(f"\n[Round {_ + 1}/{num_rounds}]")
-            f.write(f"Validation mAP@0.5: {val_mAP:.4f}")
+            f.write(f"\n[Round {_ + 1}/{num_rounds}]\n")
+            f.write(f"Validation mAP50-95: {val_mAP:.4f}\n")
             # f.write(f"Average Train Loss: {avg_train_loss:.4f}")
-            f.write(f"Communication Cost: {model_size:.2f} MB\n")
+            f.write(f"Communication Cost: {model_size:.2f} MB\n\n")
 
     return global_model, metrics
 
 
-# ------------ Usage example ------------
 if __name__ == "__main__":
     # Federated training configuration
     clients_config = [
-        "/mnt/DATA/uav_dataset_fed/train1/train1.yaml",  # client 1 data path
-        "/mnt/DATA/uav_dataset_fed/train2/train2.yaml",  # client 2 data path
+        "/mnt/DATA/uav_fed/train1/train1.yaml",  # client 1 data path
+        "/mnt/DATA/uav_fed/train2/train2.yaml",  # client 2 data path
     ]
+    # Local dataset for testing
+    # clients_config = [
+    #     "/home/image1325/DATA/Graduation-Project/dataset/train1/train1.yaml",
+    #     "/home/image1325/DATA/Graduation-Project/dataset/train2/train2.yaml",
+    # ]
+
     # Run federated training
-    final_model, metrics = federated_train(num_rounds=40, clients_data=clients_config)
+    final_model, metrics = federated_train(num_rounds=10, clients_data=clients_config)
 
     # Save the final model
     final_model.save("yolov8n_federated.pt")
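
For reviewers who want to sanity-check the aggregation outside the YOLO stack: the sketch below reproduces the same sample-weighted FedAvg rule that federated_avg applies per key. It is illustrative only; fedavg_state_dicts and the nn.Linear stand-in clients are assumptions for the demo, not part of this patch.

# Minimal sketch of the sample-weighted FedAvg rule used by federated_avg,
# verified on plain nn.Linear models instead of YOLO. Illustrative only.
import torch
import torch.nn as nn


def fedavg_state_dicts(state_dicts, sample_counts):
    """Sample-weighted average of matching keys across client state dicts."""
    total = sum(sample_counts)
    if total == 0:
        raise ValueError("Total number of samples must be positive.")
    averaged = {}
    for key in state_dicts[0]:
        weighted = [
            sd[key].float() * (n / total)
            for sd, n in zip(state_dicts, sample_counts)
        ]
        averaged[key] = torch.stack(weighted, dim=0).sum(dim=0)
    return averaged


if __name__ == "__main__":
    torch.manual_seed(0)
    clients = [nn.Linear(4, 2) for _ in range(2)]
    counts = [300, 100]  # client 1 contributes 3x the samples of client 2
    merged = fedavg_state_dicts([c.state_dict() for c in clients], counts)
    # Weights are 0.75 / 0.25, so the merged tensor leans toward client 1.
    expected = 0.75 * clients[0].weight + 0.25 * clients[1].weight
    assert torch.allclose(merged["weight"], expected)
    print("Weighted FedAvg check passed.")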