mirror of
https://github.com/luzhisheng/js_reverse.git
synced 2025-04-20 21:55:07 +08:00
03数据加载-模型等API介绍
This commit is contained in:
parent
dce3f5177e
commit
07e12b7625
10
机器学习/03数据加载-模型等API介绍/图片数据处理.py
Normal file
10
机器学习/03数据加载-模型等API介绍/图片数据处理.py
Normal file
@ -0,0 +1,10 @@
|
||||
from torchvision.datasets import MNIST
|
||||
from torchvision import transforms
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
|
||||
mnist_train = MNIST(root="./MNIST_data", train=True, download=True, transform=transforms.PILToTensor())
|
||||
|
||||
a = DataLoader(mnist_train, batch_size=1, shuffle=True)
|
||||
for i in a:
|
||||
print(i)
|
18
机器学习/03数据加载-模型等API介绍/图片数据标准化处理.py
Normal file
18
机器学习/03数据加载-模型等API介绍/图片数据标准化处理.py
Normal file
@ -0,0 +1,18 @@
|
||||
from torchvision.datasets import MNIST
|
||||
from torchvision import transforms
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
# 图片标准化处理目的就是提高识别准确度
|
||||
my_transforms = transforms.Compose(
|
||||
[
|
||||
transforms.PILToTensor(),
|
||||
]
|
||||
)
|
||||
|
||||
mnist_train = MNIST(root="./MNIST_data", train=True, download=True, transform=my_transforms)
|
||||
|
||||
a = DataLoader(mnist_train, batch_size=1, shuffle=True)
|
||||
for img, labels in a:
|
||||
print(img)
|
||||
print(labels)
|
||||
exit()
|
22
机器学习/03数据加载-模型等API介绍/图片数据融合处理.py
Normal file
22
机器学习/03数据加载-模型等API介绍/图片数据融合处理.py
Normal file
@ -0,0 +1,22 @@
|
||||
from torchvision.datasets import MNIST
|
||||
from torchvision import transforms
|
||||
from torch.utils.data import DataLoader
|
||||
from torchvision.utils import make_grid
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# 图片标准化处理目的就是提高识别准确度,比如灰度,旋转,等等
|
||||
my_transforms = transforms.Compose(
|
||||
[
|
||||
transforms.PILToTensor(),
|
||||
]
|
||||
)
|
||||
|
||||
mnist_train = MNIST(root="./MNIST_data", train=True, download=True, transform=my_transforms)
|
||||
|
||||
a = DataLoader(mnist_train, batch_size=2, shuffle=True)
|
||||
for img, labels in a:
|
||||
print(labels)
|
||||
image = make_grid(img).permute(1, 2, 0).numpy()
|
||||
plt.imshow(image)
|
||||
plt.show()
|
||||
exit()
|
21
机器学习/03数据加载-模型等API介绍/模型和优化器.py
Normal file
21
机器学习/03数据加载-模型等API介绍/模型和优化器.py
Normal file
@ -0,0 +1,21 @@
|
||||
from torch import optim
|
||||
from torch import nn
|
||||
|
||||
|
||||
# 全连接层
|
||||
class MnistModel(nn.Module):
|
||||
def __init__(self):
|
||||
super(MnistModel, self).__init__()
|
||||
self.fc2 = nn.Linear(1 * 28 * 28, 100) # 最终为什么是 10,因为手写数字识别最终是 10分类的,分类任务中有多少,就分几类。 0-9
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
def forward(self, image):
|
||||
image_viwed = image.view(-1, 1 * 28 * 28) # 此处需要拍平
|
||||
out = self.fc2(image_viwed)
|
||||
fc1_out = self.relu(out)
|
||||
return out
|
||||
|
||||
|
||||
model = MnistModel()
|
||||
optimizer = optim.Adam(model.parameters(), lr=1e-4)
|
||||
LOST = nn.CTCLoss()
|
Loading…
x
Reference in New Issue
Block a user