当前位置: 首页 > news >正文

Pytorch DataLoader中的num_workers (选择最合适的num_workers值)

一、概念

num_workers是Dataloader的概念,默认值是0。是告诉DataLoader实例要使用多少个子进程进行数据加载(和CPU有关,和GPU无关)
如果num_worker设为0,意味着每一轮迭代时,dataloader不再有自主加载数据到RAM这一步骤(因为没有worker了),而是在RAM中找batch,找不到时再加载相应的batch。缺点当然是速度慢。

当num_worker不为0时,每轮到dataloader加载数据时,dataloader一次性创建num_worker个worker,并用batch_sampler将指定batch分配给指定worker,worker将它负责的batch加载进RAM。

num_worker设置得大,好处是寻batch速度快,因为下一轮迭代的batch很可能在上一轮/上上一轮…迭代时已经加载好了。坏处是内存开销大,也加重了CPU负担(worker加载数据到RAM的进程是CPU复制的嘛)。num_workers的经验设置值是自己电脑/服务器的CPU核心数,如果CPU很强、RAM也很充足,就可以设置得更大些。

num_worker小了的情况,主进程采集完最后一个worker的batch。此时需要回去采集第一个worker产生的第二个batch。如果该worker此时没有采集完,主线程会卡在这里等。(这种情况出现在,num_works数量少或者batchsize
比较小,显卡很快就计算完了,CPU对GPU供不应求。)

即,num_workers的值和模型训练快慢有关,和训练出的模型的performance无关

Detectron2的num_workers默认是4

二、选择最合适的num_workers值

最合适的num_works值与数据集有关
最好是跑代码之前先用这段script跑一下,选择最合适的num_workers值

from time import time
import multiprocessing as mp
import torch
import torchvision
from torchvision import transforms
 
 
transform = transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,))
])
 
trainset = torchvision.datasets.MNIST(
    root='dataset/',
    train=True,  #如果为True,从 training.pt 创建数据,否则从 test.pt 创建数据。
    download=True, #如果为true,则从 Internet 下载数据集并将其放在根目录中。 如果已下载数据集,则不会再次下载。
    transform=transform
)
 
print(f"num of CPU: {mp.cpu_count()}")
for num_workers in range(2, mp.cpu_count(), 2):  
    train_loader = torch.utils.data.DataLoader(trainset, shuffle=True, num_workers=num_workers, batch_size=64, pin_memory=True)
    start = time()
    for epoch in range(1, 3):
        for i, data in enumerate(train_loader, 0):
            pass
    end = time()
    print("Finish with:{} second, num_workers={}".format(end - start, num_workers))

在这里插入图片描述
可以看到,这个服务器24个CPU, 最合适的num_workers值是14

三、可能出现的问题

在这里插入图片描述
linux系统中可以使用多个子进程加载数据,windows系统里是不可以的,可以发现报错时产生在DataLoader文件中的。我们找到自己调用DataLoader的文件中num_workers的设置,设置为0或者采用默认为0的设置。

相关文章:

  • 用来做微网站的/友情链接英文
  • 网站建设招标模板/目前最靠谱的推广平台
  • 政府类网站建设/网络广告代理
  • 建设银行广州社会招聘网站/环球网今日疫情消息
  • 做网站是不是很麻烦/苏州网站维护
  • wordpress博客推荐/制作网站的公司有哪些
  • 第328场周赛2537. 统计好子数组的数目
  • JDK1.8使用的垃圾回收器和执行GC的时长以及GC的频率
  • 【Django项目开发】django的信号机制(八)
  • JUC面试(一)——JUCJMMvolatile 1.0
  • Xinlinx zynq7020国产替代 FMQL20S400 全国产化 ARM 核心板+扩展板
  • Go语言开发小技巧易错点100例(五)
  • 算法leetcode|31. 下一个排列(rust重拳出击)
  • SpringCloud-Netflix学习笔记03——什么是Eureka
  • 测试篇(二): 如何合理的创建bug、bug的级别、bug的生命周期、跟开发产生争执怎么办
  • springboot 项目自定义log日志文件提示系统找不到指定的文件
  • 【数据结构与算法】顺序表的原理及实现
  • 【C进阶】动态内存管理