当前位置: 首页 > news >正文

搭建线性网络对MNIST数据集进行训练、测试,并且预测图片

目录

1. 介绍

2. 搭建网络(model)

3. 训练网络(train)

3.1 数据预处理

3.2 加载数据集

3.3 实例化网络  + 构建优化器

3.4 训练网络

3.5 保存网络参数

4. 预测图片(predict)

5. code


1. 介绍

这次要完成的任务是搭建一个线性的模型去对手写数字集MNIST进行训练,对我们的网络在测试集上进行测试防止过拟合,最后在网上随便找一张手写数字的图片进行预测

通过之前的学习,大概了解到,训练神经网络大概需要如下几个步骤

1. 准备数据集 Data sets

2. 搭建网络 Net + 优化器 optimizer

3. 训练网络 Train

4. 测试网络 Test

这次的流程也是类似,但是因为这次的数据集和网络都较大,所以采用模块化的方式去实现

准备工作如下所示,model 里面存放我们搭建的网络, train 里面存在训练过程 ,将训练后的网络保存到 Net.pth 中,最后将训练完成的网络在predict 中去预测图片 2.png

 

2. 搭建网络(model)

首先看搭建网络需要的库

其中nn是神经网络库,可以方便搭建神经网络

nn.functional 是函数库,例如relu 、 sigmoid 函数等等

网络的搭建部分较为简单, 因为MNIST数据集是图像数据,每个图像的大小是 28 * 28 的灰度图像

而我们搭建的网络是线性网络(非CNN网络),所以第一层的的维度是(784,512)

这里做下简单介绍:

例如我们将28 * 28 的图片压缩成一维的就是 784列 的行向量,因为输入是采用batch输入的,假设batch的size 是n。所以每次输入的维度是 n行,784列的一个(n,784)矩阵大小。而我们想要数据经过第一层网络后输出是512列(这里我们只改变单个图像的特征,而不是batch),所以输出就是(n,512)所以根据矩阵乘法(n,784)*(784,512)=(n,512)

这里只要满足矩阵乘法就行,例如y = w * x +b或者 y = x * wT + b都是为了满足矩阵乘法

然后根据同样的方法去构建第二层、第三层即可。因为最后要输出图片的数字(0-9),所以最后一层应该是10

前向传播就是按照信号流通的方法即可。首先需要将一副图片变成(1,784)的形式,然后经过第一层+激活函数(这里是relu)后流向第二层等等...

因为这是个多分类的问题,采用的是交叉熵函数,网络最后的输出应该经过softmax计算概率

但是因为pytorch里面的交叉熵损失函数已经包含softmax层了,所以最后之间将最后一层输出即可

 

3. 训练网络(train)

3.1 数据预处理

对数据进行预处理之前需要导入一些库

 数据的预处理如图

 ToTensor 的作用是改变图像的通道顺序,然后的图像的灰度值归一化的(0-1)之间

 Normalize 的作用是使用平均值和标准偏差规范化张量图像,第一个括号为mean,第一个括号为std 。因为要处理的为灰度图像是单通道的,所以只需要设置一个数字

3.2 加载数据集

torchvision 提供了很多数据集,这里我们下载MNIST数据集

root 指定下载的路径,如果路径中包含就不需要下载,train = True 为导入训练集 ,download 设置为True ,会下载到root中,如果路径中有就会之间读取,transform 数据预处理

Dataloader 是一个可以迭代的对象,它能将dataset返回的每一个数据样本拼接成一个batch,并且提供了多线程加速(num_workers)和打乱(shuffle)的功能

 这里依次读入train 和 test 集 ,设置batch_size 为64 

3.3 实例化网络  + 构建优化器

这部分较为简单,因为这个网络已经很大了,所以建议SGD梯度下降的时候加个momentum

3.4 训练网络

我们让epoch 训练五次

enumerate 会返回每一批的data 和 index ,所以我们定义的变量data里面包含两个value,i 是第i组数据

将第一个value 赋值给inputs(每一个样本的28*28) ,第二个value 赋值给 labels(0-9之间)

然后就是梯度清零->网络前向传播->计算损失->反向传播->梯度更新

 接下来就是打印loss,我们设置成每300次打印一次。因为每次batch会拿出64,这样6w张训练数据 / (64 * 300)= 3.125 ,所以一个epoch只会打印三次

 我们再在每次训练的时候计算网络的accuracy。

因为计算正确率的时候,不需要反向传播,所以这里不需要计算梯度with torch.no_grad():代表里面的代码不会计算梯度。

然后从testloader里面取出数据,经过网络计算网络的预测值。通过max函数计算输出的最大值

torch.max 函数会返回两个值(value,index),我们对网络的输出最大值不感兴趣,或者说我们对网络哪一个位置输出最大值感兴趣,因为网络输出的10维度正好对应我们的(0-9)。而网络流通的维度为(n,784)->(784,512)->(512,256)->(256,128)->(128,64)->(64,10),所以n个图像784(28*28)经过网络输出成了维度为10的特征,而10是列,所以dim设置成1,要横着取最大值

 然后总共的数量是一次拿的batch_size (本章代码是64),当预测和label一致的时候正确率+1,最后打印就行了

3.5 保存网络参数

 将网络的参数保存到Net.pth文件中

4. 预测图片(predict)

这里在网上找了一个手写数字的图片,用来测试我们的网络

 在对网络预测前有许多准备工作,比如灰度化、调整大小等等,

这里图像灰度化可以用opencv简单操作就行了,调整大小的话可以用transforms里面的Resize函数

然后加载网络训练好的参数,将图像预处理后就可以丢到网络中了

 这里也可以像train里面写成下面这样,但是因为我们已经知道了max会返回value和index,之间用下标[1]取第二个值就行了

_, predicted = torch.max(outputs, dim=1) # 取出每一行最大值,返回值为(值,索引)

反思:

网络存在的问题还有很多,比如明明test 上面的精度已经90多了,然后从网上随便找张图确很容易出错

虽然net的test精度很高,但是不能排除过拟合的问题,可能图片不是同分布的等等....

5. code

1. model 模块

import torch
import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):    # 定义网络
    def __init__(self):
        super(Net,self).__init__()     # 继承父类
        self.linear1 = nn.Linear(784,512)
        self.linear2 = nn.Linear(512,256)
        self.linear3 = nn.Linear(256,128)
        self.linear4 = nn.Linear(128,64)
        self.linear5 = nn.Linear(64,10)
    def forward(self,x):
        x = x.view(-1,784)
        x = F.relu(self.linear1(x))
        x = F.relu(self.linear2(x))
        x = F.relu(self.linear3(x))
        x = F.relu(self.linear4(x))
        x = self.linear5(x)
        return x

2. train 模块

import torch
import torchvision as tv    # 提供数据集
import torchvision.transforms as transforms   # 图像处理包
import torch.nn as nn
import torch.nn.functional as F
from torch import optim   # 优化器包
from model import Net    # 导入建立的网络

transform = transforms.Compose([transforms.ToTensor(),  # 改变通道顺序、归一化
                                transforms.Normalize((0.5,),(0.5,))])  # 归一化


batch_size = 64
# 训练集 6W张图片
trainset =tv.datasets.MNIST(root= './dataset/mnist',train=True,download=True,transform=transform)
trainloader = torch.utils.data.DataLoader(trainset,shuffle = True,batch_size = batch_size)
# 测试集 1W张图片
testset =tv.datasets.MNIST(root= './dataset/mnist',train=False,download=True,transform=transform)
testloader = torch.utils.data.DataLoader(testset,shuffle = False,batch_size = batch_size)


net = Net()   # 实例化网络
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(),lr = 0.01,momentum=0.5)

for epoch in range(5):
    running_loss = 0.0
    for i,data in enumerate(trainloader,0):
        inputs,labels =data       # 取出对应的数据
        optimizer.zero_grad()       # 梯度清零

        outputs = net(inputs)       # 计算预测值
        loss = criterion(outputs,labels)        # 计算损失值
        loss.backward()         # 反向传播
        optimizer.step()        # 梯度更新

        running_loss += loss.item()    # 取出loss值

        if i % 300 ==299:
            correct = 0
            total = 0
            with torch.no_grad():   # 不会计算梯度
                for data in testloader: # 从测试集里面取出数据
                    images, labels = data
                    outputs = net(images)       # 计算预测值
                    _, predicted = torch.max(outputs, dim=1) # 取出每一行最大值,返回值为(值,索引)
                    total += labels.size(0)   # batch_size
                    correct += (predicted == labels).sum().item()

                print('[%d,%5d] train_loss: %.3f test_accuracy:%.3f'
                      % (epoch+1,i +1,running_loss/300,correct / total))
                running_loss = 0.0

print('Finished Training....')

save_path = './Net.pth'
torch.save(net.state_dict(),save_path)

打印结果为:

 

3. predict 模块

import torch
import torchvision.transforms as transforms
from PIL import Image
from model import Net

transforms = transforms.Compose([transforms.Resize((28,28)),   # 预处理
                                 transforms.ToTensor(),
                                 transforms.Normalize((0.5,),(0.5,))])
net = Net()
net.load_state_dict(torch.load('Net.pth'))   # 读取网络参数

im = Image.open('./2.png')
im = transforms(im)

with torch.no_grad():
    outputs  = net(im)
    predict = torch.max(outputs,dim = 1)[1]  # 取最大值的index
print(predict)

输出为:

 

 

相关文章:

  • 网站建设的几大原则/成都seo招聘
  • 连锁网站开发/搜索引擎的使用方法和技巧
  • 海口网站运营托管公司/导航网站怎么推广
  • wordpress会影响网速吗/长尾关键词举例
  • 西安网站建设项目/百度前三推广
  • 做网站 报价 需要了解/搜索引擎google
  • 【HTML+CSS】静态网页设计期末大作业——我的家乡无锡印象
  • @Cacheable和@CacheEvict的学习使用
  • 宝塔面板安装部署Vue项目,Vue项目从打包到上线
  • 《MLB棒球创造营》:走近棒球运动·迈阿密马林鱼队
  • 【设计模式】行为型模式-第 3 章第 3 讲【解释器模式】
  • 【云原生之Docker实战】使用Docker部署pdf2htmlEX文件转换工具
  • Observability:集群监控 (一) - Elastic Stack 8.x
  • 【maven】什么是坐标(依赖)继承与模块、web项目启动访问
  • Solidity 基础知识
  • CDH大数据平台 18Cloudera Manager Console之Sentry权限kafka测试(markdown新版)
  • 【ROS】如何在ROS中使用anaconda虚拟环境?
  • 什么是 PowerShell?