当前位置: 首页 > news >正文

【6s965-fall2022】深度学习的效率指标

在这里插入图片描述

  • 两个核心指标是计算和内存(Computation and Memory)。
  • 需要考虑的三个维度是存储、延迟和能耗(Storage, Latency, and Energy)。

延迟 Latency

在这里插入图片描述
Latency = m a x ( T o p e r a t i o n , T m e m o r y ) max(T_{operation}, T_{memory}) max(Toperation,Tmemory)

能耗 Energy

在这里插入图片描述

  • 内存访问比计算更消耗能量。以下是能耗排名:

  • DRAM Access > SRAM Access > FP Mult > INT Mult > Register File > FP Add > INT Add

  • 因此,我们应该避免数据移动,因为数据移动越多,内存引用就会导致更多的能量消耗。

内存相关的指标

模型参数 Number of parameters (#Parameters)

  • Linear: c 0 × c i c_0 \times c_i c0×ci

  • Convolution: c o c i k h k w c_o c_i k_h k_w cocikhkw

  • Grouped convolution: c o c i k h k w / g c_o c_i k_h k_w / g cocikhkw/g

  • Depthwise convolution: c o k w k h c_o k_w k_h cokwkh

模型大小 Model size

  • M o d e l S i z e = N u m O f P a r a m e t e r s × B i t W i d t h = 模型参数 × 位宽 Model Size = NumOfParameters \times BitWidth=模型参数 \times 位宽 ModelSize=NumOfParameters×BitWidth=模型参数×位宽

例如,AlexNet有61M参数,因此它的模型大小将是244MB (FP32)和61MB (INT8)。

激活函数的个数 Number of Activations(#Activations)

  • 激活函数的个数是IoT推理中的内存瓶颈,而不是模型参数。
    在这里插入图片描述
  • 在训练过程中,内存瓶颈不是参数,而是激活函数的个数。
    在这里插入图片描述- MCUNet:从输入层到输出层,激活占的比例越来越小,权重占的比例越来越大,因为通道在增加。
    在这里插入图片描述

计算相关的指标

MACs: multiply-accumulate operations 乘法累加操作

  • 一次乘法累加(MAC)操作是 a = a + b × c a = a + b \times c a=a+b×c
  • 以下是一些常见的MACs的计算方式:
    • Matrix-vector multiplication (MV): m × n m\times n m×n

    • General matrix-matrix multiplication (GEMM): m × n × k m\times n\times k m×n×k

    • Linear layer: c o × c i c_o\times c_i co×ci

    • Convolution: c i × k w × k h × h o × w o × c o c_i\times k_w \times k_h \times h_o \times w_o \times c_o ci×kw×kh×ho×wo×co

    • Grouped convolution: c i × k w × k h × h o × w o × c o / g c_i\times k_w \times k_h \times h_o \times w_o \times c_o / g ci×kw×kh×ho×wo×co/g

    • Depthwise convolution: k w × k h × h o × w o × c o k_w \times k_h \times h_o \times w_o \times c_o kw×kh×ho×wo×co

FLOP: floating point operation

  • 1 M A C = 2 F L O P 1MAC = 2 FLOP 1MAC=2FLOP

    • 例如,AlexNet有724M mac,对应1.4G FLOP。
  • Floating point operation per second (FLOPS)

    • F L O P S = F L O P s e c o n d FLOPS = \frac{FLOP}{second} FLOPS=secondFLOP

用python 实现这些指标

from torchprofile import profile_macs

def get_model_macs(model, inputs) -> int:
    return profile_macs(model, inputs)

def get_num_parameters(model: nn.Module, count_nonzero_only=False) -> int:
    """
    calculate the total number of parameters of model
    :param count_nonzero_only: only count nonzero weights
    """
    num_counted_elements = 0
    for param in model.parameters():
        if count_nonzero_only:
            num_counted_elements += param.count_nonzero()
        else:
            num_counted_elements += param.numel()
    return num_counted_elements


def get_model_size(model: nn.Module, data_width=32, count_nonzero_only=False) -> int:
    """
    calculate the model size in bits
    :param data_width: #bits per element
    :param count_nonzero_only: only count nonzero weights
    """
    return get_num_parameters(model, count_nonzero_only) * data_width


Byte = 8
KiB = 1024 * Byte
MiB = 1024 * KiB
GiB = 1024 * MiB

相关文章:

  • 响应式网站简单模板/北京高端网站建设
  • 在线收费视频网站开发/亚马逊关键词
  • 做资讯类网站/找网站公司制作网站
  • python培训机构哪个好/兰州网站优化
  • 移动版网站建设的必要性/国外搜索引擎排名百鸣
  • 加强政府网站建设推进会/太原网站优化
  • 【markdown】语法 添加`emoji`表情
  • Java——全排列
  • 威纶通触摸屏配方功能的使用方法示例
  • IB地理学什么?适合什么人学习?
  • MS SQL Server 日志审核工具
  • ES6 课程概述⑦
  • 19/365 java 多线程
  • Java设计模式-观察者模式Observer
  • 开源项目介绍
  • 开源PPP软件PRIDE-PPPAR使用记录(二)解算网友发来的GNSS观测文件
  • 数据分析面试题--SQL面试题
  • gitlab-runner搭建CI/CD