mirror of
https://github.com/NaiboWang/EasySpider.git
synced 2025-04-12 11:37:11 +08:00
109 lines
3.6 KiB
Python
109 lines
3.6 KiB
Python
import copy
import os

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import models, transforms
|
|
|
|
# 定义 ResNet 模型(以 ResNet18 为例)
|
|
class ResNetModel(nn.Module):
    """ResNet-18 classifier whose final layer is resized to ``num_classes``.

    The torchvision ResNet-18 is used as the backbone and its last fully
    connected layer is replaced so the network emits ``num_classes`` logits.
    """

    def __init__(self, num_classes, pretrained=True):
        """Build the network.

        Args:
            num_classes: Number of output classes for the replaced final layer.
            pretrained: When true (the default, matching the previous
                behaviour) the backbone starts from torchvision's pretrained
                weights; when false it is randomly initialised.
        """
        super().__init__()
        # torchvision >= 0.13 replaced the boolean ``pretrained`` argument with
        # ``weights=`` and warns when the old spelling is used. Prefer the new
        # API when it exists and fall back for older installations.
        weights_enum = getattr(models, "ResNet18_Weights", None)
        if weights_enum is not None:
            self.resnet = models.resnet18(
                weights=weights_enum.DEFAULT if pretrained else None
            )
        else:
            self.resnet = models.resnet18(pretrained=pretrained)
        # Replace the final fully connected layer to fit the target task.
        self.resnet.fc = nn.Linear(self.resnet.fc.in_features, num_classes)

    def forward(self, x):
        """Return the backbone's output (class logits) for the batch ``x``."""
        return self.resnet(x)
|
|
|
|
# 自定义数据集类
|
|
class WebpageDataset(Dataset):
    """Dataset of PNG images labelled by their filename prefix.

    Every file in ``image_dir`` whose name ends with ``.png`` is one sample.
    The label is the integer before the first underscore of the filename,
    e.g. ``3_homepage.png`` has label ``3``.
    """

    def __init__(self, image_dir, transform=None):
        """Index the PNG files found directly inside ``image_dir``.

        Args:
            image_dir: Directory containing the PNG images.
            transform: Optional callable applied to each loaded RGB image.
        """
        self.image_dir = image_dir
        self.transform = transform
        # os.listdir returns entries in arbitrary order; sort so that a given
        # index refers to the same file on every run and platform.
        self.image_files = sorted(
            f for f in os.listdir(image_dir) if f.endswith('.png')
        )

    def __len__(self):
        """Return the number of indexed PNG files."""
        return len(self.image_files)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``.

        The image is converted to RGB and, if a transform was supplied,
        passed through it. The label is an ``int`` parsed from the filename.
        """
        filename = self.image_files[idx]
        # The context manager closes the underlying file as soon as the pixel
        # data has been converted, instead of leaving it to the garbage
        # collector.
        with Image.open(os.path.join(self.image_dir, filename)) as raw_image:
            image = raw_image.convert('RGB')
        label = self.get_label_from_filename(filename)
        # Compare against None explicitly so a transform object that happens
        # to be falsy (e.g. defines __len__) is still applied.
        if self.transform is not None:
            image = self.transform(image)
        return image, label

    def get_label_from_filename(self, filename):
        """Return the integer class label encoded at the start of ``filename``.

        Filenames are expected to look like ``'<label>_<anything>.png'``.

        Raises:
            ValueError: If the text before the first underscore is not an
                integer.
        """
        prefix = filename.split('_')[0]
        try:
            return int(prefix)
        except ValueError as err:
            raise ValueError(
                f"cannot parse a class label from {filename!r}; "
                "expected a name like '<int>_<name>.png'"
            ) from err
|
|
|
|
# 图像预处理
|
|
# Image preprocessing pipeline applied to every sample, in order:
#   1. resize to a fixed 224x224 input size,
#   2. convert the PIL image to a tensor,
#   3. normalise each RGB channel with the given mean / std (these are the
#      channel statistics commonly used with ImageNet-pretrained torchvision
#      models).
_preprocessing_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
]
transform = transforms.Compose(_preprocessing_steps)
|
|
|
|
# 定义客户端训练函数
|
|
def train_local_model(model, dataloader, criterion, optimizer, epochs=5):
    """Train ``model`` on one client's data and return its weights.

    Args:
        model: Module to train; it is switched to training mode and updated
            in place.
        dataloader: Iterable yielding ``(inputs, targets)`` batches.
        criterion: Callable ``criterion(outputs, targets)`` returning the loss.
        optimizer: Optimizer that owns ``model``'s parameters.
        epochs: Number of full passes over ``dataloader``.

    Returns:
        ``model.state_dict()`` after training.
    """
    model.train()
    for _ in range(epochs):
        for batch_inputs, batch_targets in dataloader:
            # Clear gradients left over from the previous step before
            # computing new ones for this batch.
            optimizer.zero_grad()
            batch_loss = criterion(model(batch_inputs), batch_targets)
            batch_loss.backward()
            optimizer.step()
    return model.state_dict()
|
|
|
|
# 联邦平均算法
|
|
def federated_average(models_state_dicts):
    """Return the unweighted element-wise mean of several state dicts.

    Args:
        models_state_dicts: Non-empty sequence of state dicts, one per client,
            that all share the same keys and tensor shapes.

    Returns:
        A new mapping with the same keys as the first state dict. Each value
        is the mean of that entry over all clients, stored in the dtype of the
        first client's tensor. The input state dicts and their tensors are not
        modified.

    Raises:
        ValueError: If ``models_state_dicts`` is empty.
    """
    if not models_state_dicts:
        raise ValueError("federated_average requires at least one state dict")

    first = models_state_dicts[0]
    num_models = len(models_state_dicts)
    # Shallow-copy the mapping so the result is a distinct object; every entry
    # is replaced below with a freshly allocated tensor, so nothing in the
    # inputs (which may alias live model parameters) is written to.
    averaged = first.copy()
    for key, reference in first.items():
        # Floating-point tensors accumulate in their own dtype. Integer
        # entries (e.g. BatchNorm's num_batches_tracked counter) accumulate in
        # float64 and are cast back afterwards, truncating any fraction.
        work_dtype = reference.dtype if reference.is_floating_point() else torch.float64
        total = torch.zeros_like(reference, dtype=work_dtype)
        for state_dict in models_state_dicts:
            total += state_dict[key].to(work_dtype)
        averaged[key] = (total / num_models).to(reference.dtype)
    return averaged
|
|
|
|
# 模拟多个客户端的数据
|
|
# Simulated clients: one data directory per client.
client_data_dirs = ['client1_data', 'client2_data', 'client3_data']
num_classes = 10  # set to match the actual number of classes

# Initialise the global model.
global_model = ResNetModel(num_classes=num_classes)

# Loss function shared by all clients.
criterion = nn.CrossEntropyLoss()

# The client directories are fixed for the whole run, so build each DataLoader
# once instead of re-listing the directory every round. shuffle=True still
# reshuffles the samples on every pass over the loader.
client_loaders = [
    DataLoader(
        WebpageDataset(image_dir=client_dir, transform=transform),
        batch_size=32,
        shuffle=True,
    )
    for client_dir in client_data_dirs
]

# Federated learning loop.
num_rounds = 10
# Named round_number (1-based) so the builtin round() is not shadowed.
for round_number in range(1, num_rounds + 1):
    local_models = []
    for dataloader in client_loaders:
        # Start each client from an independent copy of the current global
        # model. Deep-copying avoids constructing a fresh pretrained ResNet
        # whose weights would be overwritten immediately.
        local_model = copy.deepcopy(global_model)

        # Each client gets its own optimizer (and therefore its own momentum
        # state) for this round.
        optimizer = optim.SGD(local_model.parameters(), lr=0.01, momentum=0.9)

        # Train locally and collect the resulting weights.
        local_state_dict = train_local_model(local_model, dataloader, criterion, optimizer)
        local_models.append(local_state_dict)

    # Aggregate the client weights into the global model.
    global_state_dict = federated_average(local_models)
    global_model.load_state_dict(global_state_dict)

    print(f'Round {round_number}/{num_rounds} completed.')

# Save the final global model weights.
torch.save(global_model.state_dict(), 'federated_resnet_model.pth')
|