LeNet5实现-pytorch
完整代码:PyNet/pytorch/lenet5_test.py
加载数据
pytorch
提供模块torchvision
,用于数据的加载、预处理和批量化
torchvision.datasets
内置类MNIST
用于mnist
数据集下载和加载torchvision.transforms
对数据进行预处理torchvision.DataLoader
对数据进行批量化
1 | def load_mnist_data(batch_size=128, shuffle=False): |
网络定义
LeNet-5
模型定义参考卷积神经网络推导-单张图片矩阵计算
torch.nn
模块实现了网络层类,包括卷积层(Conv2d
)、最大池化层(MaxPool2d
)、全连接层(Linear
)和其他激活层
torch.nn
模块提供functional
类用于网络层类的实现
1 | class LeNet5(nn.Module): |
训练
训练参数如下
- 学习率
lr = 1e-3
- 批量大小
batch_size = 128
- 迭代次数
epochs = 500
训练结果
训练时间
CPU | GPU | 单次迭代时间 |
---|---|---|
8核 Intel(R) Core(TM) i7-7700HQ CPU @ 2.80GHz | GeForce 940MX | 约13秒 |
迭代500
次训练结果
训练集精度 | 测试集精度 |
---|---|
99.40% | 98.63% |