2023-05-26 13:34:44 +08:00

22 lines
643 B
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from torch import optim
from torch import nn
# Fully connected layer model.
class MnistModel(nn.Module):
    """Single linear layer applied to flattened 1x28x28 images.

    Args:
        out_features: Number of outputs produced by the linear layer.
            Defaults to 100, the value that used to be hard-coded.
    """

    def __init__(self, out_features: int = 100) -> None:
        super().__init__()
        # NOTE(review): the original comment here explained that the final
        # layer should have 10 outputs because handwritten-digit recognition
        # is a 10-way classification task (digits 0-9), yet the code used 100.
        # The default stays at 100 so existing callers see the same output
        # shape; pass out_features=10 for a strict 10-way head. Confirm which
        # value is actually intended.
        self.fc2 = nn.Linear(1 * 28 * 28, out_features)
        # forward() does not apply this activation; the attribute is kept so
        # code that accesses ``model.relu`` keeps working.
        self.relu = nn.ReLU()

    def forward(self, image):
        """Flatten ``image`` into rows of 784 values and return ``fc2``'s output.

        The returned tensor is the raw linear output (no activation applied),
        exactly as before; the ReLU result that used to be computed and then
        discarded is no longer calculated.
        """
        # Flattening is required before the linear layer. ``reshape`` behaves
        # like ``view`` for contiguous inputs but also accepts non-contiguous
        # tensors (e.g. transposed images) instead of raising.
        flattened = image.reshape(-1, 1 * 28 * 28)
        return self.fc2(flattened)
# Module-level training objects: the network, its optimizer and the loss.
model = MnistModel()
# Adam over every parameter of the model, learning rate 1e-4.
optimizer = optim.Adam(params=model.parameters(), lr=0.0001)
# NOTE(review): CTCLoss is a sequence-alignment loss whose forward expects
# (log_probs, targets, input_lengths, target_lengths). For per-image class
# scores nn.CrossEntropyLoss is the usual choice -- confirm against the
# training loop before changing it.
LOST = nn.CTCLoss()