mirror of
https://github.com/luzhisheng/js_reverse.git
synced 2025-04-21 12:15:16 +08:00
04第一次训练与手写数字
This commit is contained in:
parent
6667a6960b
commit
c53395f97a
0
机器学习/04第一次训练与手写数字/__init__.py
Normal file
0
机器学习/04第一次训练与手写数字/__init__.py
Normal file
47
机器学习/04第一次训练与手写数字/推理.py
Normal file
47
机器学习/04第一次训练与手写数字/推理.py
Normal file
@ -0,0 +1,47 @@
|
||||
from torchvision import transforms
|
||||
from torch import nn
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
|
||||
class MnistModel(nn.Module):
|
||||
def __init__(self):
|
||||
super(MnistModel, self).__init__()
|
||||
self.fc1 = nn.Linear(1 * 28 * 28, 100) # 最终为什么是 10,因为手写数字识别最终是 10分类的,分类任务中有多少,就分几类。 0-9
|
||||
self.relu = nn.ReLU()
|
||||
self.fc2 = nn.Linear(100, 10)
|
||||
|
||||
def forward(self, image):
|
||||
image_viwed = image.view(-1, 1 * 28 * 28) # 此处需要拍平
|
||||
out_1 = self.fc1(image_viwed)
|
||||
fc1 = self.relu(out_1)
|
||||
out_2 = self.fc2(fc1)
|
||||
return out_2
|
||||
|
||||
|
||||
model = MnistModel()
|
||||
model.load_state_dict(torch.load("./models/model.pkl"))
|
||||
|
||||
my_transforms = transforms.Compose(
|
||||
[
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=(0.1307,), std=(0.3081,))
|
||||
]
|
||||
)
|
||||
|
||||
image = Image.open('./img/test.jpg')
|
||||
|
||||
my_transforms = transforms.Compose(
|
||||
[
|
||||
transforms.Grayscale(1),
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=(0.1307,), std=(0.3081,))
|
||||
]
|
||||
)
|
||||
image = my_transforms(image)
|
||||
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
output = model(image)
|
||||
result = output.max(dim=1).indices
|
||||
print(result)
|
58
机器学习/04第一次训练与手写数字/测试.py
Normal file
58
机器学习/04第一次训练与手写数字/测试.py
Normal file
@ -0,0 +1,58 @@
|
||||
from torchvision.datasets import MNIST
|
||||
from torchvision import transforms
|
||||
from torch.utils.data import DataLoader
|
||||
from torch import nn
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
|
||||
class MnistModel(nn.Module):
|
||||
def __init__(self):
|
||||
super(MnistModel, self).__init__()
|
||||
self.fc1 = nn.Linear(1 * 28 * 28, 100) # 最终为什么是 10,因为手写数字识别最终是 10分类的,分类任务中有多少,就分几类。 0-9
|
||||
self.relu = nn.ReLU()
|
||||
self.fc2 = nn.Linear(100, 10)
|
||||
|
||||
def forward(self, image):
|
||||
image_viwed = image.view(-1, 1 * 28 * 28) # 此处需要拍平
|
||||
out_1 = self.fc1(image_viwed)
|
||||
fc1 = self.relu(out_1)
|
||||
out_2 = self.fc2(fc1)
|
||||
return out_2
|
||||
|
||||
|
||||
def test_success():
|
||||
# 实例化模型
|
||||
total_loss = []
|
||||
model = MnistModel()
|
||||
if os.path.exists("./models/model.pkl"):
|
||||
model.load_state_dict(torch.load("./models/model.pkl"))
|
||||
loss_function = nn.CrossEntropyLoss()
|
||||
my_transforms = transforms.Compose(
|
||||
[
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=(0.1307,), std=(0.3081,))
|
||||
]
|
||||
)
|
||||
mnist_train = MNIST(root="../MNIST_data", train=False, download=True, transform=my_transforms)
|
||||
dataloader = DataLoader(mnist_train, batch_size=8, shuffle=True)
|
||||
dataloader = tqdm(dataloader, total=len(dataloader))
|
||||
|
||||
succeed = []
|
||||
|
||||
model.eval()
|
||||
with torch.no_grad():
|
||||
for images, labels in dataloader:
|
||||
# 获取结果
|
||||
output = model(images)
|
||||
result = output.max(dim=1).indices
|
||||
# print(labels)
|
||||
# print(result)
|
||||
succeed.append(result.eq(labels).float().mean().item())
|
||||
# 通过结果计算损失
|
||||
loss = loss_function(output, labels)
|
||||
total_loss.append(loss.item())
|
||||
print(np.mean(total_loss))
|
||||
return np.mean(succeed)
|
72
机器学习/04第一次训练与手写数字/训练.py
Normal file
72
机器学习/04第一次训练与手写数字/训练.py
Normal file
@ -0,0 +1,72 @@
|
||||
from torch import save, load
|
||||
from torchvision.datasets import MNIST
|
||||
from torchvision import transforms
|
||||
from torch.utils.data import DataLoader
|
||||
from torch import nn
|
||||
from torch import optim
|
||||
from tqdm import tqdm
|
||||
import 测试
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
|
||||
class MnistModel(nn.Module):
|
||||
def __init__(self):
|
||||
super(MnistModel, self).__init__()
|
||||
self.fc1 = nn.Linear(1 * 28 * 28, 100) # 最终为什么是 10,因为手写数字识别最终是 10分类的,分类任务中有多少,就分几类。 0-9
|
||||
self.relu = nn.ReLU()
|
||||
self.fc2 = nn.Linear(100, 10)
|
||||
|
||||
def forward(self, image):
|
||||
image_viwed = image.view(-1, 1 * 28 * 28) # 此处需要拍平
|
||||
out_1 = self.fc1(image_viwed)
|
||||
fc1 = self.relu(out_1)
|
||||
out_2 = self.fc2(fc1)
|
||||
return out_2
|
||||
|
||||
|
||||
# 实例化模型
|
||||
model = MnistModel()
|
||||
optimizer = optim.Adam(model.parameters())
|
||||
|
||||
# 加载已经训练好的模型和优化器继续进行训练
|
||||
if os.path.exists('./models/model.pkl'):
|
||||
model.load_state_dict(load("./models/model.pkl"))
|
||||
optimizer.load_state_dict(load("./models/optimizer.pkl"))
|
||||
|
||||
loss_function = nn.CrossEntropyLoss()
|
||||
my_transforms = transforms.Compose(
|
||||
[
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize(mean=(0.1307,), std=(0.3081,))
|
||||
]
|
||||
)
|
||||
mnist_train = MNIST(root="../MNIST_data", train=True, download=True, transform=my_transforms)
|
||||
|
||||
|
||||
def train(epoch):
|
||||
total_loss = []
|
||||
dataloader = DataLoader(mnist_train, batch_size=8, shuffle=True)
|
||||
dataloader = tqdm(dataloader, total=len(dataloader))
|
||||
model.train()
|
||||
for images, labels in dataloader:
|
||||
# 梯度置0
|
||||
optimizer.zero_grad()
|
||||
# 前向传播
|
||||
output = model(images)
|
||||
# 通过结果计算损失
|
||||
loss = loss_function(output, labels)
|
||||
total_loss.append(loss.item())
|
||||
# 反向传播
|
||||
loss.backward()
|
||||
# 优化器更新
|
||||
optimizer.step()
|
||||
|
||||
save(model.state_dict(), './models/model.pkl')
|
||||
save(optimizer.state_dict(), './models/optimizer.pkl')
|
||||
# 打印一下训练成功率
|
||||
print('第{}个epoch,成功率, 损失为{}'.format(epoch, np.mean(total_loss)), 测试.test_success())
|
||||
|
||||
|
||||
for i in range(10):
|
||||
train(i)
|
7
机器学习/04第一次训练与手写数字/进度条.py
Normal file
7
机器学习/04第一次训练与手写数字/进度条.py
Normal file
@ -0,0 +1,7 @@
|
||||
import time
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
c = tqdm(range(0, 1000), total=1000)
|
||||
for i in c:
|
||||
time.sleep(0.1)
|
3
机器学习/README.md
Normal file
3
机器学习/README.md
Normal file
@ -0,0 +1,3 @@
|
||||
pytorch文档
|
||||
|
||||
https://pytorch.org/docs/stable/index.html
|
Loading…
x
Reference in New Issue
Block a user