当前位置: 首页 > news >正文

P6 PyTorch 常用数学运算

前言:

      这里主要介绍一下PyTorch 的常用数学运算

   

目录:

   1: add|sub 加减法

    2:   mul/div  乘/除运算

    3:   矩阵乘法

    4    2D矩阵转置

    5  其它常用数学运算

    6  clamp 梯度剪裁

一  加减法

      1.1 加法

   可以直接通过符号+ 或者 torch.add 

# -*- coding: utf-8 -*-
"""
Created on Tue Dec 20 20:24:10 2022

@author: cxf
"""
import torch

def add():
    
    a = torch.tensor([[1.0,2.0,3.0,4.0],
                     [1.0,2.0,3.0,4.0],
                     [1.0,2.0,3.0,4.0]])
    
    
    b = torch.tensor([3.0,2.0,1.0,-4.0])
    
    
    c = torch.add(a,b)
    
    print(c)
    
    
if __name__ == "__main__":
    add()
    

 a的shape [3,4]

 b 的shape[4] 先做broadcasting 插入一个维度后,做行复制

然后相加,结果为

1.2 减法

   

# -*- coding: utf-8 -*-
"""
Created on Tue Dec 20 20:24:10 2022

@author: cxf
"""
import torch

def sub():
    
    a = torch.tensor([[1.0,2.0,3.0,4.0],
                     [1.0,2.0,3.0,4.0],
                     [1.0,2.0,3.0,4.0]])
    
    
    b = torch.tensor([3.0,2.0,1.0,-4.0])
    
    
    c = torch.sub(a,b)
    
    d = a -b
    
    
    
    print(torch.all(torch.eq(c,d)))
    
    
if __name__ == "__main__":
    sub()
    

 一种用符号 - 或者 用api  torch.sub 

输出:


二   乘/除运算

    2.1 乘

            可以通过符号* 或者mul

# -*- coding: utf-8 -*-
"""
Created on Tue Dec 20 20:24:10 2022

@author: cxf
"""
import torch

def mul():
    
    a = torch.tensor([[1.0,2.0,3.0,4.0],
                     [1.0,2.0,3.0,4.0],
                     [1.0,2.0,3.0,4.0]])
    
    
    b = torch.tensor([1.0,0.5,1.0,0.25])
    
    
    c = torch.mul(a,b)
    
    d = a*b
    
    
    
    print(torch.all(torch.eq(c,d)))
    print(c)
    
    
if __name__ == "__main__":
    mul()

对应点位置相乘,输出如下

 2.2 除

   

def divide():
    
    a = torch.tensor([[1.0,2.0,3.0,4.0],
                     [1.0,2.0,3.0,4.0],
                     [1.0,2.0,3.0,4.0]])
    
    
    b = torch.tensor([1.0,2.0,3.0,4.0])
    
    
    c = torch.div(a,b)
    
    d = a/b
    
    
    
    print(torch.all(torch.eq(c,d)))
    print(c)

b 做broadcast(插入一个维度后 变成[1,4] 然后做行复制)     

输出:

# -*- coding: utf-8 -*-
"""
Created on Tue Dec 20 20:24:10 2022

@author: cxf
"""
import torch

def mat():
    
    a = torch.tensor([[1.0,2.0],
                      [1.0,2.0]])
    
    b = torch.tensor([1.0,0.5]) #[2]
    b = b.unsqueeze(1) #[2,1]
    c = torch.mm(a,b)
    
    print("\n 矩阵2D 相乘:\n ",c.numpy())

三  矩阵乘法

    3.1 mm  2D 向量相乘

           

"""
Created on Tue Dec 20 20:24:10 2022

@author: cxf
"""
import torch

def mat():
    
    a = torch.tensor([[1.0,2.0],
                      [1.0,2.0]])
    
    b = torch.tensor([1.0,0.5]) #[2]
    b = b.unsqueeze(1) #[2,1]
    c = torch.mm(a,b)
    
    print("\n 矩阵2D 相乘:\n ",c.numpy())
    
    
if __name__ == "__main__":
    mat()
    

输出:

 2.2 matmul

      这种是最常用的推荐方法,多多维度的张量,依然只取最后两维做2D mm,其它的维度保持不变

    c = torch.matmul(a, b)

    

2.3 @

    重载符号,实现matmul 功能

     c = a@b


四  2D 矩阵转置

   

def mat():
    
    a = torch.tensor([[1,1,1],
                      [2,2,2]])
    
    b = torch.tensor([[1,1,1],
                      [1,2,2]])
    
    c= a@b.t()
    
    print(c)

  当多维向量的时候,使用的transpose

  


五  常用的数学符号

     pow  #平方

     rsqrt  #平方根

    exp  指数运算

     log  对数运算

     floor  向下取整

     ceil   向上取整

     trunc  取整

    round  四舍五入


六   clamp  梯度剪裁

  在深度学习中,网络层次比较深的时候,有的时候为了防止梯度爆炸,或者梯度消失

需要做梯度剪裁,抑制过大过小的值。

# -*- coding: utf-8 -*-
"""
Created on Mon Dec 19 17:24:41 2022

@author: chengxf2
"""

import numpy 
import torch

def clamp():
    
    grad = torch.rand(3,3)*10
    
  
    
    print("\n 梯度 ",grad)
    
    new_grad1 = grad.clamp(5.0)
    
    new_grad2 = grad.clamp(0,5)
    
    print("\n 梯度1 ",new_grad1)
    print("\n 梯度2 ",new_grad2)
    
    
    
clamp()

输出:

        

相关文章:

  • 重庆市建设厅官方网站/以品牌推广为目的的广告网络平台
  • 北京做网站建设比较好的公司/郑州seo代理商
  • b2b网站建设排名/网店搜索引擎优化的方法
  • 潞城建设局网站/百度指数官方下载
  • 专业建设外贸网站制作/拉新奖励的app排行
  • 现在网站优化怎么做/百度小说风云榜今天
  • DFS——剪枝
  • 安科瑞红外测温方案助力滁州某新能源光伏产业工厂安全用电
  • 2022年安徽最新水利水电施工安全员模拟试题及答案
  • 资深车主才会告诉你的那些事,看完立省三万二
  • encode decode作用,以及为什么会出现乱码呢?如何准确检查字符长度?
  • XXE漏洞详解(一)——XML基础
  • Maven 项目模板
  • SpringBoot文件上传(官方案例)
  • 动态内存开辟+柔性数组
  • JS圣诞树
  • 网络安全渗透测试的八个步骤(一)
  • C语言—局部变量和全局变量