Pytorch 张量操作
前言
感觉每次看 pytorch 的代码,都会被其中的张量的维度给绕晕,尤其是遇到一些操作改变维度的时候,自己写的代码也是经常报维度的错误。正好最近学弟在看李沐老师的《动手学深度学习》的预备知识部分,为了解答一些问题,我也去重温了这一块知识,并对其中最重要也是最容易绕进去的张量降维与广播机制进行整理。
张量
通过一维张量表示向量
1 | x = torch.arange(4) |
输出
1 | tensor([0, 1, 2, 3]) |
大量文献认为列向量是向量的默认方向,本文中也是如此
降维
这一部分非常容易搞混淆!!!建议多多画图理解,多多推演形成直觉。
默认情况下,调用求和函数会沿所有的轴降低张量的维度,使它变为一个标量。 我们还可以指定张量沿哪一个轴来通过求和降低维度。 以矩阵为例,为了通过求和所有行的元素来降维(轴 0),可以在调用函数时指定 axis = 0。 由于输入矩阵沿 0 轴降维以生成输出向量,因此输入轴 0 的维数在输出形状中消失。
1 | x = torch.arange(6).reshape(2,3) |
1 | [[0,1,2], |
1 | x.sum() # 15 |
1 | x = torch.arange(24).reshape(2, 3, 4) |
1 | tensor([[[ 0, 1, 2, 3], |
基础记忆口诀
- “沿谁求和谁消失”:指定 axis = n 时,张量会沿着该轴方向压缩(该轴的维度在输出中消失)。
- “不指定轴全坍缩”:不指定 axis 时,所有维度求和,最终输出标量。
注意,沿轴 0 求和实际上是指轴 0 的维度消失掉。 “沿某个轴求和 = 该轴的方向被压缩消失,求和操作是跨该轴的切片进行的”
- 将 axis 对应的维度想象为” 被折叠的方向”。
- 输出形状 = 原始形状去掉该轴(如 (a,b,c) 沿 axis = 1 → 去掉 b 得 (a,c))。
维度保持
保持维度:使用 keepdim = True 可保留被压缩的轴(如 (2,2) 沿 axis = 0 求和得 (1,2))。
把张量想象成洋葱,axis = n 就是剥掉第 n 层皮,把这一层的所有切片相加。
输入形状:(d0, d1, …, dn, …) axis = n 求和后 → 该维度消失,形状变为 (d0, d1, …, dn-1, dn + 1, …)
广播机制
广播 (Broadcasting) 是 PyTorch 中一种重要的机制,它允许在不同形状的张量之间进行逐元素操作,通过适当复制元素来扩展一个或两个数组,以便在转换之后,两个张量具有相同的形状。
基本规则
- 从尾部开始对齐: 比较两个张量的形状时,从最右边的维度开始向左比较。
- 维度兼容条件:
- 两个维度相等
- 其中一个维度为 1
- 其中一个维度不存在 (即一个张量的维度比另一个少)
- 扩展操作: 在满足兼容条件的维度上,PyTorch 会自动将大小为 1 的维度复制扩展为另一个张量对应维度的大小。
维度
numpy 中的维度都是用元组来指定的,比如 np.zeros((2,3,2)) 的维度数量是三维的。np.zeros((3,)) 维度数量这是 1 维的,因为 (3) 不是元组它只能算 3 加个括号而已,加个逗号 (3,) 才是元组。
1 | x = np.zeros((2,3,4)) # 维度数量为3,这个数组第一维的维度大小是2,第二维的维度大小是3,第三维的维度大小是4. |
相似写法比较:
1 | x = np.zeros(2,3,4) # 错误! TypeError: Cannot interpret '3' as a data type |
np.zeros(2,3,4) 的错误原因: 这不是语法错误,而是参数传递错误 它会被解释为 np.zeros(2, dtype = 3, other_kwarg = 4),因为 NumPy 会把后面的参数当作关键字参数 实际报错会是 TypeError: data type not understood
就记住表示维度用元组,一维向量的表示可以简化比较特殊
广播示例
1 | import torch |
在示例 3 中,广播过程遵循了 PyTorch 的维度兼容条件,三点兼备,具体分析如下:
初始形状: m: (4, 1, 3)
n: (5, 3)
广播步骤:
- 从右向左对齐维度:
1 | m: (4, 1, 3) |
逐维度比较:
- 最右边维度(第 2 维): 3 == 3 → 兼容
- 中间维度(第 1 维):
- m 是 1,n 是 5 → 其中一个为 1 → 兼容(m 的这个维度会被扩展为 5)
- 最左边维度(第 0 维):
- m 有维度 4,n 没有这个维度 → 兼容(n 会被视为在额外维度上大小为 1,即 (1,5,3),然后扩展为 (4,5,3))
最终广播形状:
- m 从 (4,1,3) 扩展为 (4,5,3)(复制第 1 维)
- n 从 (5,3) 先被视为 (1,5,3),然后扩展为 (4,5,3)(复制第 0 维)
满足的兼容条件: 维度大小相等:最右边的 3 == 3。 其中一个维度为 1:m 的中间维度是 1,与 n 的 5 兼容。 其中一个张量缺少维度:n 比 m 少一个维度,在比较时自动对齐并在前面补 1。 因此,这个例子同时满足了广播的所有三种兼容条件。