03数据加载-模型等API介绍

This commit is contained in:
luzhisheng 2023-04-21 12:58:49 +08:00
parent dce3f5177e
commit 07e12b7625
4 changed files with 71 additions and 0 deletions

View File

@ -0,0 +1,10 @@
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader
mnist_train = MNIST(root="./MNIST_data", train=True, download=True, transform=transforms.PILToTensor())
a = DataLoader(mnist_train, batch_size=1, shuffle=True)
for i in a:
print(i)

View File

@ -0,0 +1,18 @@
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader
# 图片标准化处理目的就是提高识别准确度
my_transforms = transforms.Compose(
[
transforms.PILToTensor(),
]
)
mnist_train = MNIST(root="./MNIST_data", train=True, download=True, transform=my_transforms)
a = DataLoader(mnist_train, batch_size=1, shuffle=True)
for img, labels in a:
print(img)
print(labels)
exit()

View File

@ -0,0 +1,22 @@
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
import matplotlib.pyplot as plt
# 图片标准化处理目的就是提高识别准确度,比如灰度,旋转,等等
my_transforms = transforms.Compose(
[
transforms.PILToTensor(),
]
)
mnist_train = MNIST(root="./MNIST_data", train=True, download=True, transform=my_transforms)
a = DataLoader(mnist_train, batch_size=2, shuffle=True)
for img, labels in a:
print(labels)
image = make_grid(img).permute(1, 2, 0).numpy()
plt.imshow(image)
plt.show()
exit()

View File

@ -0,0 +1,21 @@
from torch import optim
from torch import nn
# 全连接层
class MnistModel(nn.Module):
def __init__(self):
super(MnistModel, self).__init__()
self.fc2 = nn.Linear(1 * 28 * 28, 100) # 最终为什么是 10因为手写数字识别最终是 10分类的分类任务中有多少就分几类。 0-9
self.relu = nn.ReLU()
def forward(self, image):
image_viwed = image.view(-1, 1 * 28 * 28) # 此处需要拍平
out = self.fc2(image_viwed)
fc1_out = self.relu(out)
return out
model = MnistModel()
optimizer = optim.Adam(model.parameters(), lr=1e-4)
LOST = nn.CTCLoss()