基于深度学习的手写数字识别
数据准备:MNIST手写数字数据集,可以到官网下载,也可从其他地方下载。
任务:使用深度学习对 MNIST 手写数字数据集完成手写数字(0~9)的分类, 首先使用LeNet网络模型在MNIST数据集上得到一个Baseline的准确度, 在此基础上尝试调整模型结构将准确度提升到99.5%以上(四舍五入不算,如类似99.45%不算)。
不限框架,Pytorch
、 Tensorflow
、 Keras
或者 飞桨
等均可。
参考:https://zhuanlan.zhihu.com/p/544161254
摘要
本次实验使用了 Pytorch
作为深度学习框架,并使用 matplotlib
作为可视化工具,实现了一个卷积神经网络(Convolutional Neural Network,CNN)对手写数字(MNIST数据集)进行分类的过程。
概述
MNIST是一个手写数字的图像数据集,其中包含了60000张训练图像和10000张测试图像。这些图像是由真实的手写数字扫描而来,图像分辨率为28x28像素,每个像素的灰度值介于0和255之间。MNIST数据集是机器学习领域的经典数据集之一,它被 广泛用于图像分类和模式识别等任务的研究和评估。
PyTorch是一个基于Python的开源深度学习框架,它提供了一系列工具和接口,可以方便地构建、训练和部署深度学习模型。PyTorch的优势在于它具有动态图的特性,这使得模型的构建和调试变得更加灵活和直观。此外,PyTorch还提供了许多预训练的模型和数据集,可以帮助用户更快地实现自己的深度学习应用。
在使用PyTorch进行深度学习任务时,用户需要定义一个模型的结构、损失函数和优化器,并对数据进行预处理和加载。然后,用户可以使用训练数据对模型进行训练,并使用测试数据进行验证和评估。PyTorch提供了许多工具和接口,可以方便地进行这些操作,并且具有高效的计算性能。
我的工作
为了减少第三方库的依赖,我使用了 Pytorch
作为深度学习框架,并使用 matplotlib
作为可视化工具。
本次实验需要导入以下依赖库:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision import datasets
from torchvision.transforms import Compose
import matplotlib.pyplot as plt
设定超参数
这里设定了一些超参数,用于控制训练过程。
train_batch_size: int = 256 # 训练集的batch_size
test_batch_size: int = 100 # 测试集的batch_size
learning_rate: float = 0.06 # 学习率
epoch: int = 100 # 迭代次数
random_seed: int = 2 # 随机种子
设置推理设备
如果有可用的GPU,就使用GPU,可以加快训练速度。
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
if torch.cuda.is_available():
torch.backends.cudnn.benchmark = True # 自动寻找最优算法
device = torch.device('cuda')
torch.cuda.empty_cache() # 释放显存
torch.cuda.manual_seed_all(random_seed) # 为gpu提供随机数
else:
device = torch.device('cpu')
torch.manual_seed(random_seed) # 为cpu提供随机数
print(f'Use device:{device}')
定义 Compose 对象
Compose 对象可以将多个转换组合在一起,例如将 ToTensor
和 Normalize
组合在一起,可以将图像转换为张量并进行标准化,方便后续调用。
# 构建 compose 类的实例,包含转张量和标准化
train_transform: Compose = Compose([
transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)), # 对照片进行随机平移
transforms.RandomRotation((-10, 10)), # 随机旋转
transforms.ToTensor(), # 转张量
transforms.Normalize((0.1307,), (0.3081,)) # 标准化 (均值, 标准差)
])
# 标准化 (均值, 标准差)
test_transform: Compose = Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])