python broadcast

Posted by Meng Cao on 2019-08-07

广播(broadcasting)

指的是不同形状的数组之间的算术运算的执行方式。

广播的原则:

两个数组从后向前对比,每个维度数组兼容,当:

  1. 数值相等
  2. 或其中一个数值为1
    如果较短的数组维度遍历完毕,则可以广播,否则抛出ValueError: frames are not aligned异常

当比较的任何一个维度为1时,则使用另一个。换句话说,大小为1的维被拉伸或“复制”以匹配另一维。

Example

可以构成广播的例子:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
A      (2d array):  5 x 4
B (1d array): 1
Result (2d array): 5 x 4

A (2d array): 5 x 4
B (1d array): 4
Result (2d array): 5 x 4

A (3d array): 15 x 3 x 5
B (3d array): 15 x 1 x 5
Result (3d array): 15 x 3 x 5

A (3d array): 15 x 3 x 5
B (2d array): 3 x 5
Result (3d array): 15 x 3 x 5

A (3d array): 15 x 3 x 5
B (2d array): 3 x 1
Result (3d array): 15 x 3 x 5

以下是不能broadcast的形状示例:

1
2
3
4
5
A      (1d array):  3
B (1d array): 4 # trailing dimensions do not match

A (2d array): 2 x 1
B (3d array): 8 x 4 x 3 # second from last dimensions mismatched

所以经常会有这样的需求:在数组的某一维度上增加一个1,例如维度(4,3)转换成(4,1,3), 我们可以使用np.newaxis属性以及全切片来插入新轴:

1
2
3
arr = np.zeros((4, 4))
arr_3d = arr[:, np.newaxis, :]
arr_3d.shape # [4,1,4]