
(刘二大人) PyTorch Deep Learning Practice: Convolutional Networks (Advanced)

1. The role of the 1x1 convolution kernel

  1. Changing the number of channels while keeping the width and height unchanged
  2. Reducing the amount of computation (see the sketch below)
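As a hedged illustration of the second point (the channel counts and feature-map size below are chosen for illustration and do not appear in the code that follows), a 1x1 bottleneck produces an output of the same shape as a direct 5x5 convolution while needing roughly a tenth of the multiplications:

import torch

x = torch.randn(1, 192, 28, 28)

# Direct 5x5 convolution: 192 -> 32 channels
direct = torch.nn.Conv2d(192, 32, kernel_size=5, padding=2)

# 1x1 bottleneck down to 16 channels, then the same 5x5 convolution
bottleneck = torch.nn.Sequential(
    torch.nn.Conv2d(192, 16, kernel_size=1),
    torch.nn.Conv2d(16, 32, kernel_size=5, padding=2),
)

print(direct(x).shape, bottleneck(x).shape)  # both torch.Size([1, 32, 28, 28])

# Multiplications (ignoring bias):
#   direct:     5*5*28*28*192*32                    ~ 120 million
#   bottleneck: 1*1*28*28*192*16 + 5*5*28*28*16*32  ~  12 million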

2. Implementing the Inception Module from GoogLeNet

2.1 Code for the Inception block

import torch
import torch.nn.functional as F

class InceptionA(torch.nn.Module):
    def __init__(self, channels):
        super(InceptionA, self).__init__()
        self.branch_pool = torch.nn.Conv2d(channels, 24, kernel_size=1)
        self.branch1x1 = torch.nn.Conv2d(channels, 16, kernel_size=1)
        self.branch5x5_1 = torch.nn.Conv2d(channels, 16, kernel_size=1)
        self.branch5x5_2 = torch.nn.Conv2d(16, 24, kernel_size=5, padding=2)  # 5x5 kernel: padding=2 keeps W and H unchanged
        self.branch3x3_1 = torch.nn.Conv2d(channels, 16, kernel_size=1)
        self.branch3x3_2 = torch.nn.Conv2d(16, 24, kernel_size=3, padding=1)
        self.branch3x3_3 = torch.nn.Conv2d(24, 24, kernel_size=3, padding=1)

    def forward(self, x):
        # avg_pool2d defaults stride to kernel_size, so stride=1 (with padding=1) is needed to keep W and H
        branch_pool = F.avg_pool2d(x, kernel_size=3, padding=1, stride=1)
        branch_pool = self.branch_pool(branch_pool)

        branch1x1 = self.branch1x1(x)

        branch5x5 = self.branch5x5_2(self.branch5x5_1(x))

        branch3x3 = self.branch3x3_3(self.branch3x3_2(self.branch3x3_1(x)))

        outputs = [branch_pool, branch1x1, branch5x5, branch3x3]
        return torch.cat(outputs, dim=1)  # tensors are B x C x H x W, so dim=1 concatenates along the channel axis
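A quick sanity check of the block (a hedged usage sketch, not part of the original file): the four branches contribute 24 + 16 + 24 + 24 = 88 output channels, and every branch preserves the spatial size, so the concatenation along dim=1 is valid.

# Hypothetical quick test of InceptionA (assumes the class and imports above)
block = InceptionA(channels=10)
x = torch.randn(1, 10, 12, 12)
print(block(x).shape)  # torch.Size([1, 88, 12, 12]) -- 24 + 16 + 24 + 24 = 88 channels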

2.2 Using the module to build a convolutional network for the MNIST dataset

The network below alternates plain convolution + max-pooling stages with two InceptionA blocks and ends with a single fully connected layer.

2.3 Full code implementation

import torch
from Inception import InceptionA
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter
from torch.utils.data import DataLoader
from torchvision import datasets,transforms

# TensorBoard logging
writer = SummaryWriter(log_dir='../LEDR')

# Prepare the datasets
trans = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3801,))])
train_set = datasets.MNIST(root=r'E:\learn_pytorch\LE', train=True, transform=trans, download=True)
test_set = datasets.MNIST(root=r'E:\learn_pytorch\LE', train=False, transform=trans, download=True)

# Wrap the datasets in DataLoaders
train_data = DataLoader(dataset=train_set, batch_size=64, shuffle=True)
test_data = DataLoader(dataset=test_set, batch_size=64, shuffle=False)

# Build the model
class Net(torch.nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv_1 = torch.nn.Conv2d(1, 10, kernel_size=5)   # 1x28x28 -> 10x24x24
        self.conv_2 = torch.nn.Conv2d(88, 20, kernel_size=5)  # 88x12x12 -> 20x8x8
        self.mp = torch.nn.MaxPool2d(2)

        self.incept1 = InceptionA(channels=10)
        self.incept2 = InceptionA(channels=20)

        self.fc = torch.nn.Linear(1408,10)

    def forward(self, x):
        x = F.relu(self.mp(self.conv_1(x)))  # 10x24x24, pooled to 10x12x12
        x = self.incept1(x)                  # 88x12x12
        x = F.relu(self.mp(self.conv_2(x)))  # 20x8x8, pooled to 20x4x4
        x = self.incept2(x)                  # 88x4x4
        x = x.view(-1, 1408)                 # flatten: 88 * 4 * 4 = 1408
        x = self.fc(x)
        return x

# Instantiate the model
huihui = Net()

# Define the loss function and the optimizer
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(params=huihui.parameters(),lr=0.01,momentum=0.5)

# Training loop
def train(epoch):
    run_loss = 0.0
    for batch_id , data in enumerate(train_data,0):
        inputs , targets = data
        outputs = huihui(inputs)
        loss = criterion(outputs, targets)

        # Zero the gradients, backpropagate, update the weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        run_loss += loss.item()
        if batch_id % 300 == 299:
            print("[%d,%d] loss:%.3f" %(epoch+1,batch_id+1,run_loss/300))
            run_loss = 0.0

def test(epoch):  # epoch is passed in explicitly so add_scalar does not rely on a module-level variable
    total = 0
    correct = 0
    with torch.no_grad():
        for data in test_data:
            inputs , labels = data
            outputs = huihui(inputs)
            _,predict = torch.max(outputs,dim=1)
            total += labels.size(0)
            correct += (predict==labels).sum().item()
        writer.add_scalar("The Accuracy1",correct/total,epoch)
        print('[Accuracy] %d %%' % (100*correct/total))

if __name__ == '__main__':
    for epoch in range(10):
        train(epoch)
        test(epoch)

    writer.close()
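Where does the 1408 in the fully connected layer come from? Tracing a dummy MNIST-sized input through the layers (a hedged sketch reusing Net, torch and F from the script above) shows that the second InceptionA block outputs 88 channels on a 4x4 map, and 88 * 4 * 4 = 1408:

# Hypothetical shape trace, mirroring Net.forward step by step
net = Net()
x = torch.randn(1, 1, 28, 28)       # one MNIST-sized image
x = F.relu(net.mp(net.conv_1(x)))   # -> [1, 10, 12, 12]
x = net.incept1(x)                  # -> [1, 88, 12, 12]
x = F.relu(net.mp(net.conv_2(x)))   # -> [1, 20, 4, 4]
x = net.incept2(x)                  # -> [1, 88, 4, 4]
print(x.view(-1, 1408).shape)       # torch.Size([1, 1408])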



2.4 Results (accuracy still 98%)

D:\Anaconda3\envs\pytorch\python.exe E:/learn_pytorch/LE/Inception_model.py
[1,300] loss:0.961
[1,600] loss:0.207
[1,900] loss:0.143
[Accuracy] 96 %
[2,300] loss:0.115
[2,600] loss:0.095
[2,900] loss:0.098
[Accuracy] 97 %
[3,300] loss:0.083
[3,600] loss:0.081
[3,900] loss:0.071
[Accuracy] 98 %
[4,300] loss:0.068
[4,600] loss:0.066
[4,900] loss:0.069
[Accuracy] 98 %
[5,300] loss:0.063
[5,600] loss:0.055
[5,900] loss:0.054
[Accuracy] 98 %
[6,300] loss:0.054
[6,600] loss:0.053
[6,900] loss:0.050
[Accuracy] 98 %
[7,300] loss:0.047
[7,600] loss:0.050
[7,900] loss:0.048
[Accuracy] 98 %
[8,300] loss:0.043
[8,600] loss:0.041
[8,900] loss:0.050
[Accuracy] 98 %
[9,300] loss:0.041
[9,600] loss:0.040
[9,900] loss:0.043
[Accuracy] 98 %
[10,300] loss:0.037
[10,600] loss:0.038
[10,900] loss:0.040
[Accuracy] 98 %

Process finished with exit code 0
 

2.5 Images
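The accuracy values logged with writer.add_scalar (tag "The Accuracy1") can be viewed in TensorBoard, e.g. by running tensorboard --logdir ../LEDR and opening the page it serves.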

 
