当前位置: 首页 > news >正文

【Numpy基础知识】Broadcasting广播

Numpy广播

来源:Numpy官网:https://numpy.org/doc/stable/user/basics.html

在这里插入图片描述

广播描述了 NumPy 在算术运算期间如何处理具有不同形状的数组。根据某些约束,较小的数组将“广播”到较大的阵列,以便它们具有兼容的形状。

导包

import numpy as np

NumPy 操作通常是逐个元素地对数组对完成的。在最简单的情况下,两个数组必须具有完全相同的形状,如以下示例所示:

a = np.array([1.0, 2.0, 3.0])
print(a)
[1. 2. 3.]
b = np.array([2.0, 2.0, 2.0])
print(b)
[2. 2. 2.]
print(a * b)
[2. 4. 6.]

当数组的形状满足某些约束时,NumPy 的广播规则放宽了此约束。最简单的广播示例发生在操作中组合数组和标量值时:

a = np.array([1.0, 2.0, 3.0])
b = 2.0
print(a)
[1. 2. 3.]
print(a * b)
[2. 4. 6.]

结果等效于前面的示例,其中 b 是一个数组。我们可以认为标量 b 在算术运算期间被拉伸成一个与 a 形状相同的数组。

在这里插入图片描述

在最简单的广播示例中,标量 b 被拉伸为与 a 形状相同的数组a,因此这些形状与逐个元素的乘法兼容。

【1】一般广播规则

在两个数组上操作时,NumPy 会逐个元素比较它们的形状。它从尾随(即最右边)维度开始,然后向左工作。两个维度兼容以下情况

  • 它们是相等的,或者
  • 其中之一是 1

如果不满足这些条件,则会引发 ValueError: operands could not be broadcast together异常,指示数组具有不兼容的形状。生成的数组的大小是沿输入的每个轴的大小不是 1。

数组不需要具有相同数量的维度。例如,如果您有一个 256x256x3 的 RGB 值数组,并且您希望将图像中的每种颜色按不同的值缩放,则可以将图像乘以具有 3 个值的一维数组。根据广播规则排列这些数组的尾轴的大小,表明它们是兼容的:

Image (3d array): 256 x 256 x 3
Scale (1d array): 3
Result (3d array): 256 x 256 x 3

当比较的维度中的任何一个是一个时,将使用另一个维度。换句话说,尺寸为 1 的尺寸被拉伸或“复制”以匹配另一个尺寸。

在以下示例中,A 和 B 数组的轴长度为 1,在广播操作期间扩展为更大的尺寸:

A (4d array): 8 x 1 x 6 x 1
B (3d array): 7 x 1 x 5
Result (4d array): 8 x 7 x 6 x 5

【2】可广播数组

一些要广播的示例

A (2d array): 5 x 4
B (1d array): 1
Result (2d array): 5 x 4

A (2d array): 5 x 4
B (1d array): 4
Result (2d array): 5 x 4

A (3d array): 15 x 3 x 5
B (3d array): 15 x 1 x 5
Result (3d array): 15 x 3 x 5

A (3d array): 15 x 3 x 5
B (2d array): 3 x 5
Result (3d array): 15 x 3 x 5

A (3d array): 15 x 3 x 5
B (2d array): 3 x 1
Result (3d array): 15 x 3 x 5

下面是不广播的形状数组示例

A (1d array): 3
B (1d array): 4

A (2d array): 2 x 1
B (3d array): 8 x 4 x 3

将一维数组添加到二维数组时的广播示例:

a = np.array([[0.0, 0.0, 0.0],
              [10.0, 10.0, 10.0],
              [20.0, 20.0, 20.0],
              [30.0, 30.0, 30.0]])
print(a)
[[ 0.  0.  0.]
 [10. 10. 10.]
 [20. 20. 20.]
 [30. 30. 30.]]
b = np.array([1.0, 2.0, 3.0])
print(b)
[1. 2. 3.]
print(a + b)
[[ 1.  2.  3.]
 [11. 12. 13.]
 [21. 22. 23.]
 [31. 32. 33.]]

可以看到,结果将 b 添加到 a 的每一行。

在这里插入图片描述

如果b 的形状不匹配

在这里插入图片描述

当数组的尾随维度不相等时,广播将失败,因为无法将第一个数组的行中的值与第二个数组的元素对齐以进行逐个元素的添加。

广播提供了一种获取两个阵列的外部乘积(或任何其他外部操作)的便捷方法。以下示例显示了两个一维数组的外部加法操作:

a = np.array([0.0, 10.0, 20.0, 30.0])
b = np.array([1.0, 2.0, 3.0])
print(a)
[ 0. 10. 20. 30.]
print(b)
[1. 2. 3.]
print(a[:, np.newaxis] + b)
[[ 1.  2.  3.]
 [11. 12. 13.]
 [21. 22. 23.]
 [31. 32. 33.]]

在这里插入图片描述

在某些情况下,广播会拉伸两个数组以形成大于任一初始数组的输出数组。

在上面那个栗子里面,newaxis 索引运算符将一个新轴插入a 中,使其成为二维 4x1 数组。将 4x1 数组与形状为 (3,) 的 b 组合,将生成一个 4x3 数组。

【3】一个实际的例子:矢量量化

广播经常出现在现实世界的问题中。一个典型的例子出现在信息论、分类和其他相关领域使用的矢量量化 (VQ) 算法中。

在下面显示的非常简单的二维情况下,observation值描述了要分类的运动员的体重和身高。codes代表不同级别的运动员。1 找到最近的点需要计算观测值与每个代码之间的距离。最短的距离提供最佳匹配。在此示例中,codes[0] 是最接近的类,表示运动员可能是篮球运动员。

from numpy import array, argmin, sqrt, sum

observation = array([111.0, 188.0])
print(observation)
[111. 188.]
codes = array([[102.0, 203.0],
               [132.0, 193.0],
               [45.0, 155.0],
               [57.0, 173.0]])
print(codes)
[[102. 203.]
 [132. 193.]
 [ 45. 155.]
 [ 57. 173.]]
diff = codes - observation
print(diff)
[[ -9.  15.]
 [ 21.   5.]
 [-66. -33.]
 [-54. -15.]]
dist = sqrt(sum(diff ** 2, axis=-1))
print(dist)
[17.49285568 21.58703314 73.79024326 56.04462508]
argmin(dist)
0

在此示例中,observation数组被拉伸以匹配codes数组的形状:

Observation (1d array): 2
Codes (2d array): 4 x 2
Diff (2d array): 4 x 2

在这里插入图片描述

矢量量化的基本操作计算要分类的对象、深色方块和多个已知代码(灰色圆圈)之间的距离。在这个简单的情况下,代码表示各个类。更复杂的情况每个类使用多个代码。

通常,将大量observations值(可能从数据库中读取)与一组codes进行比较。请可以考虑以下方案:

Observation (2d array): 10 x 3
Codes (2d array): 5 x 3
Diff (3d array): 5 x 10 x 3

三维数组 diff 是广播的结果,而不是计算的必要条件。大型数据集将生成一个计算效率低下的大型中间数组。相反,如果使用围绕上述二维示例中代码的 Python 循环单独计算每个观察值,则使用更小的数组。

广播是一种强大的工具,用于编写简短且通常直观的代码,在 C 语言中非常有效地进行计算。但是,在某些情况下,广播会为特定算法使用不必要的大量内存。在这些情况下,最好用 Python 编写算法的外部循环。这也可能产生更具可读性的代码,因为随着广播中维度数量的增加,使用广播的算法往往变得更加难以解释。

相关文章:

  • 华为云CDN助力企业用户体验全面优化,让企业“惠”加速
  • Java项目:springboot+vue电影院会员管理系统
  • Android实现一维二维码扫描生成功能(一)-zxing导入现有项目
  • 企业经常会问到的软件测试面试题及答案,一定要好好记住
  • 转互联网好难,如何避免无效转行?
  • 试卷的安全方案
  • 真香啊,这招可以轻松抓取某音短视频数据(附 Python 代码)
  • ETHERCAT从站设计与FOC伺服马达电流环控制
  • nginx 解决跨域问题——(CORS)
  • Freemodbus启动流程分析
  • Java项目:springboot网上点餐系统
  • 全国职业院校技能大赛中职组网络安全竞赛试题 —XSS漏洞(笔记文档)
  • 4、常用类和对象
  • 机器学习模型-BUPA liver disorders-探索饮酒与肝炎关系
  • 【架构师(第五十三篇)】 性能优化之 HTTP 缓存
  • 博德宝闪耀回归,九牧国际化提速
  • 20岁电竞选手自学编程转行程序员,轻松拿下大厂offer
  • 转行IT,你需要了解的真实项目研发流程是怎样的?
  • 优优聚:美团成立机器人研究院!
  • 自己个人拥有一个可以支付功能的网站?当然可以了!保姆级演示!