联邦学习示例项目:更改结构
This commit is contained in:
65
fed_example/utils/model_utils.py
Normal file
65
fed_example/utils/model_utils.py
Normal file
@@ -0,0 +1,65 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from torchvision import models
|
||||
|
||||
from Deeplab.deeplab import DeepLab_F
|
||||
from Deeplab.resnet_psa import BasicBlockWithPSA
|
||||
from Deeplab.resnet_psa_v2 import ResNet
|
||||
from model_base.efNet_base_model import DeepLab
|
||||
from model_base.efficientnet import EfficientNet
|
||||
from model_base.resnet_more import CustomResNet
|
||||
from model_base.xcption import Xception
|
||||
|
||||
|
||||
def get_model(name, number_class, device, backbone):
|
||||
"""
|
||||
根据指定的模型名称加载模型,并根据任务类别数调整最后的分类层。
|
||||
|
||||
Args:
|
||||
name (str): 模型名称 ('Vgg', 'ResNet', 'EfficientNet', 'Xception')。
|
||||
number_class (int): 分类类别数。
|
||||
device (torch.device): 设备 ('cuda' or 'cpu')。
|
||||
resnet_type (str): ResNet类型 ('resnet18', 'resnet34', 'resnet50', 'resnet101', etc.)。
|
||||
|
||||
Returns:
|
||||
nn.Module: 经过修改的模型。
|
||||
"""
|
||||
if name == 'Vgg':
|
||||
model = models.vgg16_bn(pretrained=True).to(device)
|
||||
model.classifier[6] = nn.Linear(model.classifier[6].in_features, number_class)
|
||||
elif name == 'ResNet18':
|
||||
model = CustomResNet(resnet_type='resnet18', num_classes=number_class, pretrained=True).to(device)
|
||||
elif name == 'ResNet34':
|
||||
model = CustomResNet(resnet_type='resnet34', num_classes=number_class, pretrained=True).to(device)
|
||||
elif name == 'ResNet50':
|
||||
model = CustomResNet(resnet_type='resnet50', num_classes=number_class, pretrained=True).to(device)
|
||||
elif name == 'ResNet101':
|
||||
model = CustomResNet(resnet_type='resnet101', num_classes=number_class, pretrained=True).to(device)
|
||||
elif name == 'ResNet152':
|
||||
model = CustomResNet(resnet_type='resnet152', num_classes=number_class, pretrained=True).to(device)
|
||||
elif name == 'EfficientNet':
|
||||
# 使用自定义的 DeepLab 类加载 EfficientNet
|
||||
model = DeepLab(backbone='efficientnet', num_classes=number_class).to(device)
|
||||
elif name == 'Xception':
|
||||
model = Xception(
|
||||
in_planes=3,
|
||||
num_classes=number_class,
|
||||
pretrained=True,
|
||||
pretrained_path="/home/terminator/1325/yhs/fedLeaning/pre_model/xception-43020ad28.pth"
|
||||
).to(device)
|
||||
elif name == 'DeepLab':
|
||||
# 使用自定义的 DeepLab 类加载 EfficientNet
|
||||
model = DeepLab_F(num_classes=1, backbone=backbone).to(device)
|
||||
elif name == 'resnet18_psa':
|
||||
model = ResNet(BasicBlockWithPSA, [2, 2, 2, 2], number_class)
|
||||
else:
|
||||
raise ValueError(f"Model {name} is not supported.")
|
||||
return model
|
||||
|
||||
def get_federated_model(device):
|
||||
"""初始化客户端模型和全局模型"""
|
||||
client_models = [
|
||||
get_model("resnet18_psa", 1, device, "*") for _ in range(3)
|
||||
]
|
||||
global_model = get_model("resnet18_psa", 1, device, "*")
|
||||
return client_models, global_model
|
Reference in New Issue
Block a user