2023-07-12 23:26:55 +08:00

59 lines
1.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader
from torch import nn
from tqdm import tqdm
import torch
import numpy as np
import os
class MnistModel(nn.Module):
    """Two-layer fully connected classifier for 28x28 single-channel images.

    Architecture: flatten -> Linear(784, 100) -> ReLU -> Linear(100, 10).
    ``forward`` returns the raw output of the last linear layer (logits, no
    softmax).

    The submodule attribute names ``fc1``, ``relu`` and ``fc2`` determine the
    ``state_dict`` keys, so they must stay stable for saved weights to load.
    """

    def __init__(self):
        super().__init__()
        in_features = 1 * 28 * 28  # one channel, 28x28 pixels, flattened
        hidden_units = 100
        # 10 outputs because handwritten-digit recognition is a 10-class
        # classification task (digits 0-9): one output per class.
        num_classes = 10
        self.fc1 = nn.Linear(in_features, hidden_units)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_units, num_classes)

    def forward(self, image):
        """Map a batch of images to per-class logits of shape (N, 10).

        ``image`` is flattened with ``view(-1, 784)``. Presumably the input is
        (N, 1, 28, 28), but any view-compatible tensor whose element count is a
        multiple of 784 is accepted.
        """
        # The linear layers need flat vectors, so flatten each image first.
        flattened = image.view(-1, 1 * 28 * 28)
        hidden = self.relu(self.fc1(flattened))
        return self.fc2(hidden)
def test_success(model_path="./models/model.pkl", data_root="../MNIST_data", batch_size=8):
    """Evaluate ``MnistModel`` on the MNIST test split and return its accuracy.

    If ``model_path`` exists, its weights are loaded into the model. Otherwise
    the freshly initialised (untrained) model is evaluated. The mean
    per-sample cross-entropy loss is printed.

    Args:
        model_path: Path to a saved ``state_dict`` for ``MnistModel``.
        data_root: Directory where torchvision stores / downloads MNIST.
        batch_size: Number of images per evaluation batch.

    Returns:
        Fraction of correctly classified test images in [0, 1] as a float, or
        ``nan`` if the dataset yields no samples.
    """
    model = MnistModel()
    if os.path.exists(model_path):
        # The model is evaluated on CPU. map_location remaps a checkpoint saved
        # from a GPU run so loading does not fail on CPU-only machines.
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        model.load_state_dict(torch.load(model_path, map_location="cpu"))

    loss_function = nn.CrossEntropyLoss()
    my_transforms = transforms.Compose(
        [
            transforms.ToTensor(),
            # Presumably the standard MNIST training-set mean/std.
            transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
        ]
    )
    # train=False selects the held-out test split.
    mnist_test = MNIST(root=data_root, train=False, download=True, transform=my_transforms)
    # Aggregate metrics do not depend on order, so no shuffling is needed here.
    dataloader = DataLoader(mnist_test, batch_size=batch_size, shuffle=False)

    loss_sum = 0.0
    correct = 0
    seen = 0
    model.eval()
    with torch.no_grad():
        for images, labels in tqdm(dataloader, total=len(dataloader)):
            output = model(images)
            # The predicted class is the index of the largest logit.
            predictions = output.max(dim=1).indices
            correct += predictions.eq(labels).sum().item()
            # CrossEntropyLoss averages over the batch. Re-weight by batch size
            # so a short final batch is not over-counted in the overall mean.
            n_in_batch = labels.size(0)
            loss_sum += loss_function(output, labels).item() * n_in_batch
            seen += n_in_batch

    mean_loss = loss_sum / seen if seen else float("nan")
    accuracy = correct / seen if seen else float("nan")
    print(mean_loss)
    return accuracy