mirror of
https://github.com/luzhisheng/js_reverse.git
synced 2025-04-20 21:55:07 +08:00
03数据加载-模型等API介绍
This commit is contained in:
parent
dce3f5177e
commit
07e12b7625
10
机器学习/03数据加载-模型等API介绍/图片数据处理.py
Normal file
10
机器学习/03数据加载-模型等API介绍/图片数据处理.py
Normal file
@ -0,0 +1,10 @@
|
|||||||
|
from torchvision.datasets import MNIST
|
||||||
|
from torchvision import transforms
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
|
||||||
|
mnist_train = MNIST(root="./MNIST_data", train=True, download=True, transform=transforms.PILToTensor())
|
||||||
|
|
||||||
|
a = DataLoader(mnist_train, batch_size=1, shuffle=True)
|
||||||
|
for i in a:
|
||||||
|
print(i)
|
18
机器学习/03数据加载-模型等API介绍/图片数据标准化处理.py
Normal file
18
机器学习/03数据加载-模型等API介绍/图片数据标准化处理.py
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
from torchvision.datasets import MNIST
|
||||||
|
from torchvision import transforms
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
|
# 图片标准化处理目的就是提高识别准确度
|
||||||
|
my_transforms = transforms.Compose(
|
||||||
|
[
|
||||||
|
transforms.PILToTensor(),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
mnist_train = MNIST(root="./MNIST_data", train=True, download=True, transform=my_transforms)
|
||||||
|
|
||||||
|
a = DataLoader(mnist_train, batch_size=1, shuffle=True)
|
||||||
|
for img, labels in a:
|
||||||
|
print(img)
|
||||||
|
print(labels)
|
||||||
|
exit()
|
22
机器学习/03数据加载-模型等API介绍/图片数据融合处理.py
Normal file
22
机器学习/03数据加载-模型等API介绍/图片数据融合处理.py
Normal file
@ -0,0 +1,22 @@
|
|||||||
|
from torchvision.datasets import MNIST
|
||||||
|
from torchvision import transforms
|
||||||
|
from torch.utils.data import DataLoader
|
||||||
|
from torchvision.utils import make_grid
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
|
||||||
|
# 图片标准化处理目的就是提高识别准确度,比如灰度,旋转,等等
|
||||||
|
my_transforms = transforms.Compose(
|
||||||
|
[
|
||||||
|
transforms.PILToTensor(),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
mnist_train = MNIST(root="./MNIST_data", train=True, download=True, transform=my_transforms)
|
||||||
|
|
||||||
|
a = DataLoader(mnist_train, batch_size=2, shuffle=True)
|
||||||
|
for img, labels in a:
|
||||||
|
print(labels)
|
||||||
|
image = make_grid(img).permute(1, 2, 0).numpy()
|
||||||
|
plt.imshow(image)
|
||||||
|
plt.show()
|
||||||
|
exit()
|
21
机器学习/03数据加载-模型等API介绍/模型和优化器.py
Normal file
21
机器学习/03数据加载-模型等API介绍/模型和优化器.py
Normal file
@ -0,0 +1,21 @@
|
|||||||
|
from torch import optim
|
||||||
|
from torch import nn
|
||||||
|
|
||||||
|
|
||||||
|
# 全连接层
|
||||||
|
class MnistModel(nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super(MnistModel, self).__init__()
|
||||||
|
self.fc2 = nn.Linear(1 * 28 * 28, 100) # 最终为什么是 10,因为手写数字识别最终是 10分类的,分类任务中有多少,就分几类。 0-9
|
||||||
|
self.relu = nn.ReLU()
|
||||||
|
|
||||||
|
def forward(self, image):
|
||||||
|
image_viwed = image.view(-1, 1 * 28 * 28) # 此处需要拍平
|
||||||
|
out = self.fc2(image_viwed)
|
||||||
|
fc1_out = self.relu(out)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
model = MnistModel()
|
||||||
|
optimizer = optim.Adam(model.parameters(), lr=1e-4)
|
||||||
|
LOST = nn.CTCLoss()
|
Loading…
x
Reference in New Issue
Block a user