mirror of
https://github.com/luzhisheng/js_reverse.git
synced 2025-04-19 09:24:46 +08:00
59 lines
1.9 KiB
Python
59 lines
1.9 KiB
Python
from torchvision.datasets import MNIST
|
||
from torchvision import transforms
|
||
from torch.utils.data import DataLoader
|
||
from torch import nn
|
||
from tqdm import tqdm
|
||
import torch
|
||
import numpy as np
|
||
import os
|
||
|
||
|
||
class MnistModel(nn.Module):
|
||
def __init__(self):
|
||
super(MnistModel, self).__init__()
|
||
self.fc1 = nn.Linear(1 * 28 * 28, 100) # 最终为什么是 10,因为手写数字识别最终是 10分类的,分类任务中有多少,就分几类。 0-9
|
||
self.relu = nn.ReLU()
|
||
self.fc2 = nn.Linear(100, 10)
|
||
|
||
def forward(self, image):
|
||
image_viwed = image.view(-1, 1 * 28 * 28) # 此处需要拍平
|
||
out_1 = self.fc1(image_viwed)
|
||
fc1 = self.relu(out_1)
|
||
out_2 = self.fc2(fc1)
|
||
return out_2
|
||
|
||
|
||
def test_success():
|
||
# 实例化模型
|
||
total_loss = []
|
||
model = MnistModel()
|
||
if os.path.exists("./models/model.pkl"):
|
||
model.load_state_dict(torch.load("./models/model.pkl"))
|
||
loss_function = nn.CrossEntropyLoss()
|
||
my_transforms = transforms.Compose(
|
||
[
|
||
transforms.ToTensor(),
|
||
transforms.Normalize(mean=(0.1307,), std=(0.3081,))
|
||
]
|
||
)
|
||
mnist_train = MNIST(root="../MNIST_data", train=False, download=True, transform=my_transforms)
|
||
dataloader = DataLoader(mnist_train, batch_size=8, shuffle=True)
|
||
dataloader = tqdm(dataloader, total=len(dataloader))
|
||
|
||
succeed = []
|
||
|
||
model.eval()
|
||
with torch.no_grad():
|
||
for images, labels in dataloader:
|
||
# 获取结果
|
||
output = model(images)
|
||
result = output.max(dim=1).indices
|
||
# print(labels)
|
||
# print(result)
|
||
succeed.append(result.eq(labels).float().mean().item())
|
||
# 通过结果计算损失
|
||
loss = loss_function(output, labels)
|
||
total_loss.append(loss.item())
|
||
print(np.mean(total_loss))
|
||
return np.mean(succeed)
|