diff --git a/机器学习/03数据加载-模型等API介绍/图片数据处理.py b/机器学习/03数据加载-模型等API介绍/图片数据处理.py new file mode 100644 index 0000000..299e07d --- /dev/null +++ b/机器学习/03数据加载-模型等API介绍/图片数据处理.py @@ -0,0 +1,10 @@ +from torchvision.datasets import MNIST +from torchvision import transforms +from torch.utils.data import DataLoader + + +mnist_train = MNIST(root="./MNIST_data", train=True, download=True, transform=transforms.PILToTensor()) + +a = DataLoader(mnist_train, batch_size=1, shuffle=True) +for i in a: + print(i) diff --git a/机器学习/03数据加载-模型等API介绍/图片数据标准化处理.py b/机器学习/03数据加载-模型等API介绍/图片数据标准化处理.py new file mode 100644 index 0000000..0cddb88 --- /dev/null +++ b/机器学习/03数据加载-模型等API介绍/图片数据标准化处理.py @@ -0,0 +1,18 @@ +from torchvision.datasets import MNIST +from torchvision import transforms +from torch.utils.data import DataLoader + +# 图片标准化处理目的就是提高识别准确度 +my_transforms = transforms.Compose( + [ + transforms.PILToTensor(), + ] +) + +mnist_train = MNIST(root="./MNIST_data", train=True, download=True, transform=my_transforms) + +a = DataLoader(mnist_train, batch_size=1, shuffle=True) +for img, labels in a: + print(img) + print(labels) + exit() diff --git a/机器学习/03数据加载-模型等API介绍/图片数据融合处理.py b/机器学习/03数据加载-模型等API介绍/图片数据融合处理.py new file mode 100644 index 0000000..f946271 --- /dev/null +++ b/机器学习/03数据加载-模型等API介绍/图片数据融合处理.py @@ -0,0 +1,22 @@ +from torchvision.datasets import MNIST +from torchvision import transforms +from torch.utils.data import DataLoader +from torchvision.utils import make_grid +import matplotlib.pyplot as plt + +# 图片标准化处理目的就是提高识别准确度,比如灰度,旋转,等等 +my_transforms = transforms.Compose( + [ + transforms.PILToTensor(), + ] +) + +mnist_train = MNIST(root="./MNIST_data", train=True, download=True, transform=my_transforms) + +a = DataLoader(mnist_train, batch_size=2, shuffle=True) +for img, labels in a: + print(labels) + image = make_grid(img).permute(1, 2, 0).numpy() + plt.imshow(image) + plt.show() + exit() diff --git a/机器学习/03数据加载-模型等API介绍/模型和优化器.py b/机器学习/03数据加载-模型等API介绍/模型和优化器.py new file mode 100644 index 0000000..a5afdea --- /dev/null +++ b/机器学习/03数据加载-模型等API介绍/模型和优化器.py @@ -0,0 +1,21 @@ +from torch import optim +from torch import nn + + +# 全连接层 +class MnistModel(nn.Module): + def __init__(self): + super(MnistModel, self).__init__() + self.fc2 = nn.Linear(1 * 28 * 28, 100) # 最终为什么是 10,因为手写数字识别最终是 10分类的,分类任务中有多少,就分几类。 0-9 + self.relu = nn.ReLU() + + def forward(self, image): + image_viwed = image.view(-1, 1 * 28 * 28) # 此处需要拍平 + out = self.fc2(image_viwed) + fc1_out = self.relu(out) + return out + + +model = MnistModel() +optimizer = optim.Adam(model.parameters(), lr=1e-4) +LOST = nn.CTCLoss()