基于ModelArts的手写数字识别丨【我与ModelArts的故事】
模型训练与保存
首先在目录下建立 PyTorchModel.py
文件(这个文件名可以自己命名,主要是用来构建网络和模型训练)
下面是 PyTorchModel.py
源码
# 导入相关依赖
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
# 选择使用的硬件
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# 超参数
num_epochs = 5
num_classes = 10
batch_size = 100
learning_rate = 0.001
# 获取训练集
train_dataset = torchvision.datasets.MNIST(root='./MNIST_data',
train=True,
transform=transforms.ToTensor(),
download=True)
# 获取测试集
test_dataset = torchvision.datasets.MNIST(root='./MNIST_data',
train=False,
transform=transforms.ToTensor(),
download=True)
# 加载训练集
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
batch_size=batch_size,
shuffle=True)
# 加载测试集
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
batch_size=batch_size,
shuffle=False)
# 创建一个模型类
class ConvNet(nn.Module):
def __init__(self, num_classes=10):
super(ConvNet, self).__init__()
self.layer1 = nn.Sequential(
nn.Conv2d(1, 6, kernel_size=5, stride=1, padding=2),
nn.BatchNorm2d(6),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2)
)
self.layer2 = nn.Sequential(
nn.Conv2d(6, 16, kernel_size=5, stride=1, padding=2),
nn.BatchNorm2d(16),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2, stride=2)
)
self.fc = nn.Linear(7 * 7 * 16, num_classes)
def forward(self, x):
out = self.layer1(x)
out = self.layer2(out)
out = out.reshape(out.size(0), -1)
out = self.fc(out)
return out
def model_train():
model = ConvNet(num_classes).to(device) # ConvNet类实例化成一个对象
criterion = nn.CrossEntropyLoss() # 交叉熵损失函数
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) # Adam优化器
total_step = len(train_loader) # 训练总步长
# 开始训练
for epoch in range(num_epochs):
for i, (images, labels) in enumerate(train_loader):
images = images.to(device)
labels = labels.to(device)
outputs = model(images)
loss = criterion(outputs, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
if (i + 1) % 100 == 0:
print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'
.format(epoch + 1, num_epochs, i + 1, total_step, loss.item()))
# 模型测试
model.eval()
with torch.no_grad():
correct = 0
total = 0
for images, labels in test_loader:
images = images.to(device)
labels = labels.to(device)
outputs = model(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print('Test Accuracy of tne model on the 10000 test iamges: {} %'.format(100 * correct / total))
# 保存模型
torch.save(model.state_dict(), 'model.ckpt')
if __name__ == "__main__":
model_train()
PyTorchModel.py
文件位置及代码展示。
首先是模型训练,创建一个 .ipynb
文件,在该文件下运行 PyTorchModel.py
进行模型训练。
代码如下:
%run PyTorchModel.py
运行结果
运行完成后会在当前目录下生成一个模型文件:model.ckpt
本地手写数字识别
在上一步骤中,完成模型训练后就要开始就行本地手写数字预测。
主要有五个步骤,导入相关依赖、硬件选择、本地模型加载、本地图片处理和图像预测。
导入相关依赖
import torch
import PyTorchModel
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image, ImageOps
注意:在导入相关依赖时,PyTorchModel
是我们自己创建的 .py
文件,注意命名。
硬件选择
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
本地模型加载
model = PyTorchModel.ConvNet().to(device)
model.load_state_dict(torch.load('model.ckpt'))
本地图片处理
img = Image.open('./num-draft/' + image_name) # 导入本地图片
img_gray = np.array(ImageOps.grayscale(img)) # 图片灰度化
img_inv = (255 - img_gray) / 255.0 # 转成白底黑字,以适应MNIST数据集中的数据(黑底白字)。再进行归一化处理
image = np.float32(img_inv.reshape((1, 28 * 28))) # 转换成浮点型
image_array_2_tensor = torch.from_numpy(image.reshape(1, 1, 28, 28)).to(device) # array转换成tensor
图像预测
# 预测
predict = model(image_array_2_tensor)
prediction = torch.max(predict.data, 1)
print('预测的图片是: ', image_name, 'AI判断的数字是{}'.format(prediction[1])) # 打印预测结果
预测结果
从预测的结果来看,模型的效果还是不错的(每次训练的模型效果都不一样,需要调调参数,使得模型的效果更好)。
总结
使用PyTorch实现基于CNN的手写数字识别,主要由网络构建、模型训练和手写数字预测三个部分组成。网络构件中主要使用了卷积神经网络,之后进行模型训练,交叉熵损失函数进行评价,最后保存训练好的模型。在预测过程中,关键点是对模型的加载和本地图像的处理。
其中,在导入相关依赖时,需要使用单独的一个cell,也就是说在单独的一个cell里面导入相关依赖,不要和其他代码合并到一起,否则会报错,导致无法导入相关模块。
在使用ModelArts进行训练和预测的过程中还是比较流畅的,一是硬件GPU的支持,二是开发环境的简洁清爽。这也是喜欢使用ModelArts的原因,在使用的过程中,也发现ModelArts在悄悄地不断升级,界面更加美观,对开发者更加友好。
参考文献
——END——
我正在参加【有奖征文第21期】说说你和ModelArts的故事,输出优质产品体验文章,赢开发者大礼包!
https://bbs.huaweicloud.cn/blogs/395149
- 点赞
- 收藏
- 关注作者
评论(0)