Pytorch 张量操作

前言

感觉每次看 pytorch 的代码,都会被其中的张量的维度给绕晕,尤其是遇到一些操作改变维度的时候,自己写的代码也是经常报维度的错误。正好最近学弟在看李沐老师的《动手学深度学习》的预备知识部分,为了解答一些问题,我也去重温了这一块知识,并对其中最重要也是最容易绕进去的张量降维与广播机制进行整理。

张量

通过一维张量表示向量

1
2
x = torch.arange(4)
x

输出

1
tensor([0, 1, 2, 3])

大量文献认为列向量是向量的默认方向,本文中也是如此

降维

这一部分非常容易搞混淆!!!建议多多画图理解,多多推演形成直觉。

默认情况下,调用求和函数会沿所有的轴降低张量的维度,使它变为一个标量。 我们还可以指定张量沿哪一个轴来通过求和降低维度。 以矩阵为例,为了通过求和所有行的元素来降维(轴 0),可以在调用函数时指定 axis = 0。 由于输入矩阵沿 0 轴降维以生成输出向量,因此输入轴 0 的维数在输出形状中消失。

1
2
x = torch.arange(6).reshape(2,3)
x
1
2
[[0,1,2],
[3,4,5]]
1
2
3
x.sum() # 15
x.sum(axis=0) # [3,5,7]
x.sum(axis=1) # [3,12]
1
2
3
4
5
6
7
x = torch.arange(24).reshape(2, 3, 4)
print(x)

print(x.sum())
print(x.sum(axis=0))
print(x.sum(axis=1))
print(x.sum(axis=2))
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
tensor([[[ 0,  1,  2,  3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]],

[[12, 13, 14, 15],
[16, 17, 18, 19],
[20, 21, 22, 23]]])
tensor(276)
tensor([[12, 14, 16, 18],
[20, 22, 24, 26],
[28, 30, 32, 34]])
tensor([[12, 15, 18, 21],
[48, 51, 54, 57]])
tensor([[ 6, 22, 38],
[54, 70, 86]])

基础记忆口诀

  1. “沿谁求和谁消失”:指定 axis = n 时,张量会沿着该轴方向压缩(该轴的维度在输出中消失)。
  2. “不指定轴全坍缩”:不指定 axis 时,所有维度求和,最终输出标量。

注意,沿轴 0 求和实际上是指轴 0 的维度消失掉。 “沿某个轴求和 = 该轴的方向被压缩消失,求和操作是跨该轴的切片进行的”

  • 将 axis 对应的维度想象为” 被折叠的方向”。
  • 输出形状 = 原始形状去掉该轴(如 (a,b,c) 沿 axis = 1 → 去掉 b 得 (a,c))。

维度保持

保持维度:使用 keepdim = True 可保留被压缩的轴(如 (2,2) 沿 axis = 0 求和得 (1,2))。

把张量想象成洋葱,axis = n 就是剥掉第 n 层皮,把这一层的所有切片相加。

输入形状:(d0, d1, …, dn, …) axis = n 求和后 → 该维度消失,形状变为 (d0, d1, …, dn-1, dn + 1, …)

广播机制

广播 (Broadcasting) 是 PyTorch 中一种重要的机制,它允许在不同形状的张量之间进行逐元素操作,通过适当复制元素来扩展一个或两个数组,以便在转换之后,两个张量具有相同的形状。

基本规则

  1. 从尾部开始对齐: 比较两个张量的形状时,从最右边的维度开始向左比较。
  2. 维度兼容条件:
    • 两个维度相等
    • 其中一个维度为 1
    • 其中一个维度不存在 (即一个张量的维度比另一个少)
  3. 扩展操作: 在满足兼容条件的维度上,PyTorch 会自动将大小为 1 的维度复制扩展为另一个张量对应维度的大小。

维度

numpy 中的维度都是用元组来指定的,比如 np.zeros((2,3,2)) 的维度数量是三维的。np.zeros((3,)) 维度数量这是 1 维的,因为 (3) 不是元组它只能算 3 加个括号而已,加个逗号 (3,) 才是元组。

1
2
x = np.zeros((2,3,4)) # 维度数量为3,这个数组第一维的维度大小是2,第二维的维度大小是3,第三维的维度大小是4.
y = np.zeros((3,)) # 维度数量为1,第一维大小为3

相似写法比较:

1
2
3
4
x = np.zeros(2,3,4) # 错误!  TypeError: Cannot interpret '3' as a data type

y1 = np.zeros(3) # 正确,等价于y - NumPy专门为向量提供的简化形式
y2 = np.zeros(3,) # 正确,等价于y - 因为Python允许省略单元素元组的括号

np.zeros(2,3,4) 的错误原因: 这不是语法错误,而是参数传递错误 它会被解释为 np.zeros(2, dtype = 3, other_kwarg = 4),因为 NumPy 会把后面的参数当作关键字参数 实际报错会是 TypeError: data type not understood

就记住表示维度用元组,一维向量的表示可以简化比较特殊

广播示例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import torch

# 示例1:标量与张量
a = torch.tensor([1, 2, 3]) # shape (3,)
b = 2 # shape ()
c = a + b # b被广播为[2, 2, 2]
print(c) # tensor([3, 4, 5])

# 示例2:不同形状的张量
x = torch.ones(2, 3) # shape (2, 3)
y = torch.tensor([1, 2, 3]) # shape (3,)
z = x + y # y被广播为[[1,2,3], [1,2,3]]
print(z)

# 示例3:更复杂的广播
m = torch.ones(4, 1, 3) # shape (4, 1, 3)
n = torch.ones(5, 3) # shape (5, 3)
o = m + n # 广播后形状为(4,5,3)
print(o.shape)

在示例 3 中,广播过程遵循了 PyTorch 的维度兼容条件,三点兼备,具体分析如下:

初始形状: m: (4, 1, 3)

n: (5, 3)

广播步骤:

  1. 从右向左对齐维度:
1
2
m: (4, 1, 3)
n: (5, 3)
  1. 逐维度比较:

    • 最右边维度(第 2 维): 3 == 3 → 兼容
    • 中间维度(第 1 维):
      • m 是 1,n 是 5 → 其中一个为 1 → 兼容(m 的这个维度会被扩展为 5)
    • 最左边维度(第 0 维):
      • m 有维度 4,n 没有这个维度 → 兼容(n 会被视为在额外维度上大小为 1,即 (1,5,3),然后扩展为 (4,5,3))
  2. 最终广播形状:

    • m 从 (4,1,3) 扩展为 (4,5,3)(复制第 1 维)
    • n 从 (5,3) 先被视为 (1,5,3),然后扩展为 (4,5,3)(复制第 0 维)

满足的兼容条件: 维度大小相等:最右边的 3 == 3。 其中一个维度为 1:m 的中间维度是 1,与 n 的 5 兼容。 其中一个张量缺少维度:n 比 m 少一个维度,在比较时自动对齐并在前面补 1。 因此,这个例子同时满足了广播的所有三种兼容条件。

正在加载今日诗词....
欢迎关注我的其它发布渠道