当前位置: 首页 > news >正文

Pytorch自定义数据集模型训练流程

文章目录

  • Pytorch模型自定义数据集训练流程
    • 1、任务描述
    • 2、导入各种需要用到的包
    • 3、分割数据集
    • 4、将数据转成pytorch标准的DataLoader输入格式
    • 5、导入预训练模型,并修改分类层
    • 6、开始模型训练
    • 7、利用训好的模型做预测

Pytorch模型自定义数据集训练流程

我们以kaggle竞赛中的猫狗大战数据集为例搭建Pytorch自定义数据集模型训练的完整流程。

1、任务描述

Cats vs. Dogs(猫狗大战)数据集是Kaggle大数据竞赛某一年的一道赛题,利用给定的数据集,用算法实现猫和狗的识别。 其中包含了训练集和测试集,训练集中猫和狗的图片数量都是12500张且按顺序排序,测试集中猫和狗混合乱序图片一共12500张。
下载地址:https://www.kaggle.com/c/dogs-vs-cats/data

在这里插入图片描述

卷积神经网络(CNN)是一类包含卷积计算且具有深度结构的前馈神经网络,是深度学习的代表算法之一。卷积神经网络具有表征学习能力,能够按其阶层结构对输入信息进行平移不变分类,因此也被称为“平移不变人工神经网络”。
默认对图像分类各种算法已经熟悉,卷积、池化、批量归一化、全连接等各种结构具体细节这里不讨论,有不懂的需自行学习。

2、导入各种需要用到的包

import torch
import torchvision
from torchvision import datasets, transforms
import torch.utils.data
import matplotlib.pyplot as plt
from torch.utils.data import TensorDataset,DataLoader,Dataset
from torch.utils.tensorboard import SummaryWriter
import torch.nn.functional as F
from torch import nn
import numpy as np
import os
import shutil
from PIL import Image
import warnings
warnings.filterwarnings("ignore")

3、分割数据集

下载猫狗大战数据集,并解压。
解压完成后,通过以下代码实现数据集预处理(剔除不能正常打开的图片,打乱数据集等);然后对数据集进行分割,其中90%的数据集作为train训练,另外10%的数据集作为test测试。

# 分割数据集,将全部数据分成0.9的Train和0.1的Test
source_path = r"./kagglecatsanddogs_5340/PetImages/"
# 如果不存在文件夹要新建一个
if not os.path.exists(os.path.join(source_path, "train")):
    os.mkdir(os.path.join(source_path, "train"))
train_dir = os.path.join(source_path, "train")

if not os.path.exists(os.path.join(source_path, "test")):
    os.mkdir(os.path.join(source_path, "test"))
test_dir = os.path.join(source_path,"test")

## 将Cat和Dog文件夹全部移到train目录下,然后再从train目录下移动10%到test目录下
for category_dir in os.listdir(source_path):
    if category_dir not in ["train", "test"]:
        shutil.move(os.path.join(source_path,category_dir), os.path.join(source_path,"train"))
            
## 开始移动,移动前先剔除不能正常打开的图片
for dir in os.listdir(train_dir):
    category_dir_path = os.path.join(train_dir, dir)
    image_file_list = os.listdir(category_dir_path)   # 取出全部图片文件
    for file in image_file_list:
        try:
            Image.open(os.path.join(category_dir_path, file))
        except:
            os.remove(os.path.join(category_dir_path, file))
            image_file_list.remove(file)
    np.random.shuffle(image_file_list)
    test_num = int(0.1*len(image_file_list))
 
    #移动10%文件到对应目录
    if not os.path.exists(os.path.join(test_dir,dir)):
        os.mkdir(os.path.join(test_dir,dir))
    if len(os.listdir(os.path.join(test_dir,dir))) < test_num:  # 只有未移动过才需要移动,否则每运行一次都会移动一下
        for i in range(test_num):
            shutil.move(os.path.join(category_dir_path,image_file_list[i]), os.path.join(test_dir,dir,image_file_list[i]))

4、将数据转成pytorch标准的DataLoader输入格式

1、先对数据集进行预处理,包括resize成224*224的尺寸,因为vgg_net模型需要的输入尺寸为[N, 224, 224, 3];随机翻转,随机旋转等,另外对数据集做Normalize标准化,其中的mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.2]是从ImageNet数据集上的百万张图片中随机抽样计算得到的,以上这些内容主要是数据增强,增强模型的泛化性,有更好的预测效果。
2、然后将预处理好的数据转成pytorch标准的DataLoader输入格式,。

# 数据预处理
transform = transforms.Compose([
    transforms.RandomResizedCrop(224),# 对图像进行随机的crop以后再resize成固定大小
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456,
0.406],std=[0.229, 0.224, 0.2]),  # ImageNet全部图片的平均值和标准差
    transforms.RandomRotation(20), # 随机旋转角度
    transforms.RandomHorizontalFlip(p=0.5), # 随机水平翻转
])
 
# 读取数据
root = source_path
train_dataset = datasets.ImageFolder(root + '/train', transform)
test_dataset = datasets.ImageFolder(root + '/test', transform)
 
# 导入数据
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=16, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=16, shuffle=False)

5、导入预训练模型,并修改分类层

1、定义device,如果有GPU模型训练会自动用GPU训练,否则会使用CPU;使用GPU训练,只需在模型、数据、损失函数上使用cuda()就行。
2、这边默认对分类图像算法都熟悉,可以自己构建vgg16的完整网络,在猫狗数据集上重新训练。也可以下载预训练模型,由于原网络的分类输出是1000类别的,但是我们的图片只有两类,所以需要修改分类层,让模型能够适配我们的训练数据集。

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
vgg16 = torchvision.models.vgg16(pretrained=True).to(device)
print(vgg16)

inputs = torch.rand(1, 3, 224, 224)  # 拿一个随机tensor测试一下网络的输出是否满足预期
output = vgg16(inputs.to(device))
print("原始VGG网络的输出:",output.size())

# 构建新的全连接层
vgg16.classifier = torch.nn.Sequential(torch.nn.Linear(25088, 100),
                                       torch.nn.ReLU(),
                                       torch.nn.Dropout(p=0.5),
                                       torch.nn.Linear(100, 2)).to(device)
inputs = torch.rand(1, 3, 224, 224)
output = vgg16(inputs.to(device))
print("新构建的VGG网络的输出:",output.size())

6、开始模型训练

开始模型训练,我们这里只训练全连接分类层,将特征层的梯度requires_grad设置为False,特征层的参数将不参与训练。
训练过程中保存效果最好的网络模型,以防掉线,可以从断点开始继续训练,同时也可以用来做预测。
训练完成后,保存训练好的网络和参数,后面可以加载模型做预测。

writer = SummaryWriter("./logs/model")
loss_func = nn.CrossEntropyLoss().to(device)
learning_rate = 0.0001

#如果我们想只训练模型的全连接层
for param in vgg16.features.parameters():
    param.requires_grad = False
optimizer = torch.optim.Adam(vgg16.parameters(),lr=learning_rate)

##训练开始
total_train_step = 0
total_test_step = 0
min_acc = 100.0
for epoch in range(10):
    print("-----------train epoch {} start---------------".format(epoch))
    vgg16.train()
    for data in train_loader:
        optimizer.zero_grad()
        img, label = data
        output = vgg16(img.to(device))
        loss = loss_func(output, label.to(device))
        loss.backward()
        optimizer.step()
        total_train_step += 1
        
        if total_train_step % 10 == 0:
            print("steps: {}, train_loss: {}".format(total_train_step, loss.item()))
            writer.add_scalar("train_loss", loss.item(), total_train_step)


    ## 测试开始,看训练效果是否满足预期
    total_test_loss = 0
    total_acc = 0.0
    vgg16.eval()
    with torch.no_grad():
        for data in test_loader:
            optimizer.zero_grad()
            img, label = data
            output = vgg16(img.to(device))
            loss = loss_func(output, label.to(device))
            total_test_loss += loss
            accuary = torch.sum(output.argmax(1) == label.to(device))
            total_acc += accuary
    total_test_step += 1
    val_acc = total_acc.item() / len(test_dataset)
    
    ## 保存Acc最小的模型
    if val_acc < min_acc:
        min_acc = val_acc
        torch.save(vgg16.state_dict(), "./models/2classes_vgg16_weight_{}_{}.pth".format(epoch, round(val_acc,4)))
        torch.save(vgg16, "./models/2classes_vgg16_{}_{}.pth".format(epoch, round(val_acc,4)))

    print("测试loss: {}".format(total_test_loss.item()))
    print("测试Acc: {}".format(val_acc))
    writer.add_scalar("test_loss", total_test_loss.item(), total_test_step)
    writer.add_scalar("test_Acc", val_acc, total_test_step)

torch.save(vgg16.state_dict(), "./models/2classes_vgg16_latest_{}.pth".format(val_acc))

7、利用训好的模型做预测

拿出一张图片做预测,首先导入预训练模型,同样改掉分类层,然后导入预训练权重,预测图片类别,输出标签值和预测类别。

import matplotlib.pyplot as plt
img_path = r"./kagglecatsanddogs_5340/PetImages/test/Cat/1381.jpg"   # 拿出要预测的图片
image = Image.open(img_path).convert("RGB")
image.show()
    
vgg16_pred = torchvision.models.vgg16(pretrained=True)
vgg16_pred.classifier = torch.nn.Sequential(torch.nn.Linear(25088, 100),
                                       torch.nn.ReLU(),
                                       torch.nn.Dropout(p=0.5),
                                       torch.nn.Linear(100, 2))

transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize((224,224), interpolation=2),
    torchvision.transforms.ToTensor()
])
vgg16_pred.load_state_dict(torch.load("./models/2classes_vgg16_weight_15_0.9467513434294089.pth", map_location=torch.device('cpu')))
print(vgg16_pred)

image = transform(image)
print(image.size())
image = torch.reshape(image, [1,3,224,224])
vgg16_pred.eval()
with torch.no_grad():
    output = vgg16_pred(image)
# print("预测值为:",output)
print("预测标签为:",output.argmax(1).item())
print("预测动物为:",train_dataset.classes[output.argmax(1)])

相关文章:

  • 自适应产品网站模板/泰安网站制作推广
  • 现在还用dw做网站设计么/海外营销方案
  • 温州建站软件/重庆企业站seo
  • 宜昌视频网站建设/公众号seo排名软件
  • 网站建设费用的会计/吸引顾客的营销策略
  • 网站开发范本/广告文案经典范例200字
  • QEMU零知识学习3 —— QEMU配置
  • k8s之挂载本地磁盘到POD中
  • Spring国际化详解,Spring国家化实例及源码详解
  • 解决Windows Server远程断开后自动锁屏问题
  • 系分 - 案例分析 - 系统设计
  • 基于有向图的邻接矩阵计算其割点、割边、压缩图,并用networkx可视化绘制
  • 【进阶】Spring更简单的读取和存储对象
  • C++内存分配方法new与placement new使用方法详解
  • [ACTF2020 新生赛]BackupFile
  • 自动化测试 | 这些常用测试平台,你们公司在用的是哪些呢?
  • Android大厂面试100题,涵盖测试技术、环境搭建、人力资源
  • 【QT5 实现“上图下文”,带图标的按键样式-toolbutton-学习笔记-记录-基础样例】实现方式之一