2024-12-24 00:14:35 +08:00

109 lines
3.6 KiB
Python

import copy
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms
# ResNet model definition (ResNet-18 used as the backbone)
class ResNetModel(nn.Module):
    """ResNet-18 classifier whose final fully connected layer has ``num_classes`` outputs.

    Args:
        num_classes: number of output logits produced by ``forward``.
        pretrained: when True (the default, matching the previous behaviour)
            the backbone starts from torchvision's ImageNet weights; pass False
            to skip loading them, e.g. for a model that will immediately
            receive weights through ``load_state_dict``.
    """

    def __init__(self, num_classes, pretrained=True):
        super().__init__()
        if hasattr(models, "ResNet18_Weights"):
            # torchvision >= 0.13 deprecates the ``pretrained`` flag in favour
            # of ``weights``; DEFAULT is the same ImageNet checkpoint that
            # ``pretrained=True`` used to load.
            weights = models.ResNet18_Weights.DEFAULT if pretrained else None
            self.resnet = models.resnet18(weights=weights)
        else:
            # Older torchvision only understands the boolean flag.
            self.resnet = models.resnet18(pretrained=pretrained)
        # Replace the ImageNet classification head with one sized for this task.
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, num_classes)

    def forward(self, x):
        """Return the raw class logits that the wrapped ResNet produces for ``x``."""
        return self.resnet(x)
# Custom dataset class
class WebpageDataset(Dataset):
    """Dataset of the ``.png`` images in one directory, labelled by filename.

    Each file is expected to be named ``<label>_<anything>.png`` or simply
    ``<label>.png``, where ``<label>`` is an integer class index.

    Args:
        image_dir: directory scanned (non-recursively) for ``.png`` files.
        transform: optional callable applied to each loaded RGB PIL image.
    """

    def __init__(self, image_dir, transform=None):
        self.image_dir = image_dir
        self.transform = transform
        # os.listdir returns entries in arbitrary order; sort so that the
        # index -> file mapping is reproducible across runs and machines.
        self.image_files = sorted(
            f for f in os.listdir(image_dir) if f.endswith('.png')
        )

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the ``idx``-th file."""
        filename = self.image_files[idx]
        img_path = os.path.join(self.image_dir, filename)
        # Close the underlying file immediately: convert() returns a new,
        # fully loaded image that does not need the file handle.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        label = self.get_label_from_filename(filename)
        if self.transform is not None:
            image = self.transform(image)
        return image, label

    def get_label_from_filename(self, filename):
        """Parse the integer class label at the start of ``filename``.

        ``'3_home.png'`` and ``'3.png'`` both yield ``3``.

        Raises:
            ValueError: if the text before the first ``_`` (or before the
                extension, when there is no ``_``) is not an integer.
        """
        stem = os.path.splitext(filename)[0]
        label_text = stem.split('_', 1)[0]
        try:
            return int(label_text)
        except ValueError as err:
            raise ValueError(
                f"cannot parse class label from filename {filename!r}; "
                "expected '<label>_<name>.png'"
            ) from err
# Image preprocessing shared by every client: resize to 224x224, convert the
# PIL image to a tensor (ToTensor scales 8-bit pixels into [0, 1]), then
# normalise each channel with the standard ImageNet mean and std.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
# Client-side training routine
def train_local_model(model, dataloader, criterion, optimizer, epochs=5):
    """Train ``model`` in place on one client's data and return its weights.

    Args:
        model: the client's network; it is switched to training mode.
        dataloader: iterable of ``(images, labels)`` batches.
        criterion: loss called as ``criterion(outputs, labels)``.
        optimizer: optimizer already bound to ``model``'s parameters.
        epochs: number of full passes over ``dataloader``.

    Returns:
        ``model.state_dict()`` after training. Its tensors share storage with
        the model's parameters and buffers.
    """
    model.train()
    # Move every batch to wherever the model lives, so the routine works
    # unchanged whether the caller keeps the model on the CPU or a GPU.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else None
    for _ in range(epochs):
        for images, labels in dataloader:
            if device is not None:
                images = images.to(device)
                labels = labels.to(device)
            optimizer.zero_grad()
            outputs = model(images)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
    return model.state_dict()
# Federated averaging algorithm
def federated_average(models_state_dicts):
    """Return the element-wise mean of several model state dicts.

    Every client counts equally; there is no weighting by local dataset size.
    All state dicts must have the same keys, and for each key, tensors of the
    same shape on the same device.

    Floating-point tensors are averaged normally. Integer tensors (e.g.
    BatchNorm's ``num_batches_tracked`` counter) are floor-averaged so they
    keep their integer dtype.

    Args:
        models_state_dicts: non-empty sequence of state dicts, such as the
            values returned by ``train_local_model``.

    Returns:
        A new state dict of the same mapping type as the first input. Neither
        the input mappings nor their tensors are modified.

    Raises:
        ValueError: if ``models_state_dicts`` is empty.
    """
    if not models_state_dicts:
        raise ValueError("federated_average() needs at least one state dict")
    num_clients = len(models_state_dicts)
    # A shallow copy keeps the mapping type (OrderedDict) and any attributes
    # PyTorch attaches to it (``_metadata``) without touching the caller's
    # dict. Every value is replaced below by a newly computed tensor.
    avg_state_dict = copy.copy(models_state_dicts[0])
    for key in list(avg_state_dict):
        stacked = torch.stack([sd[key] for sd in models_state_dicts])
        if stacked.is_floating_point():
            avg_state_dict[key] = stacked.mean(dim=0)
        else:
            avg_state_dict[key] = torch.div(
                stacked.sum(dim=0), num_clients, rounding_mode="floor"
            )
    return avg_state_dict
# Simulate data held by several clients
client_data_dirs = ['client1_data', 'client2_data', 'client3_data']  # one data directory per client
num_classes = 10  # set according to the actual task
# Initialise the global (server) model
global_model = ResNetModel(num_classes=num_classes)
# Loss function shared by all clients
criterion = nn.CrossEntropyLoss()
# Client data does not change between rounds, so build each client's dataset
# and loader once instead of re-listing its directory every round.
# shuffle=True still reshuffles on every pass over a loader.
client_loaders = [
    DataLoader(
        WebpageDataset(image_dir=client_dir, transform=transform),
        batch_size=32,
        shuffle=True,
    )
    for client_dir in client_data_dirs
]
# Federated learning loop (``round_idx`` avoids shadowing the builtin round)
num_rounds = 10
for round_idx in range(num_rounds):
    local_models = []
    for dataloader in client_loaders:
        # Each client starts from an exact copy of the current global model.
        # deepcopy avoids constructing a new ResNetModel, which would reload
        # the pretrained weights only to overwrite them immediately.
        local_model = copy.deepcopy(global_model)
        optimizer = optim.SGD(local_model.parameters(), lr=0.01, momentum=0.9)
        # Train locally and keep the resulting weights
        local_state_dict = train_local_model(local_model, dataloader, criterion, optimizer)
        local_models.append(local_state_dict)
    # Aggregate the client weights into the new global model
    global_state_dict = federated_average(local_models)
    global_model.load_state_dict(global_state_dict)
    print(f'Round {round_idx+1}/{num_rounds} completed.')
# Save the global model
torch.save(global_model.state_dict(), 'federated_resnet_model.pth')