pytorch tensor 操作:合并、分割、维度变换
查看维度:
1 | a=torch.randn(3,4) |
对 tensor 进行 reshape: tensor.view
定义:将 tensor 中的元素,按照顺序逐个选取,凑成 (shape) 的大小
1 | a = torch.randn(3, 4) |
把原先tensor中的数据按照行优先的顺序排成一个一维的数据(这里应该是因为要求地址是连续存储的),然后按照参数组合成其他维度的tensor。比如说是不管你原先的数据是[[[1,2,3],[4,5,6]]] 还是 [1,2,3,4,5,6],因为它们排成一维向量都是 6 个元素,所以只要 view 后面的参数一致,得到的结果都是一样的。
1 | a=torch.Tensor([[[1,2,3],[4,5,6]]]) |
tensor 交换维度:tensor.permute
定义:将tensor的维度换位。
1 | a = torch.randn(2, 3, 4) # torch.Size([2, 3, 4]) |
对 tensor 维度进行压缩:tensor.squeeze
定义:对数据的维度进行压缩,去掉维数为 1 的的维度
1 | a=torch.randn(1, 2, 1, 3, 4) |
对 tensor 维度进行扩充:tensor.unsqueeze
定义:给指定位置加上维数为 1 的维度
1 | a=torch.randn(2, 3, 4) |
tensor 维度扩张:tensor.expand
定义:对 tensor 的维度进行扩张。
如果某个维度参数是 -1,代表这个维度不改变。
tensor 可以被 expand 到更大的维度,新的维度的只是前面的值的重复。新的维度参数不能为 -1。
expand 一个 tensor 并不会分配新的内存,而只是生成一个已存在的 tensor 的 view。
1 | x = torch.tensor([[1], [2], [3]]) |
经过实践发现,使用 expand 可以增加新的一个维度,但是只能在第 0 维增加一个维度,增加的维度大小可以大于 1,比如原始 t = tensor(X,Y),可以 t.expand(Z,X,Y),不能在其他维度上增加;expand 拓展某个已经存在的维度的时候,如果原始 t = tensor(X,Y),则必须要求 X 或者 Y 中至少有 1 个维度为 1,且只能 expand 维度为 1 的那一维。
将 tensor 的指定维度合并为一个维度:torch.flatten
1 | torch.flatten(input, start_dim=0, end_dim=-1) |
start_dim: flatten 的起始维度。
end_dim: flatten 的结束维度。
1 | a=torch.randn(2, 3, 4) |
1 | a: tensor([[[-0.2499, -0.0171, 1.1206, 0.5814], |
将两个 tensor 拼接起来:torch.cat
定义:把2个 tensor 按照特定的维度连接起来。
要求:除被拼接的维度外,其他维度必须相同
1 | import torch |
将两个 tensor 堆叠起来:torch.stack
定义:增加一个新的维度,来表示拼接后的2个 tensor。
直观些理解的话,咱们不妨把一个 2 维的 tensor 理解成一张长方形的纸张,cat
相当于是把两张纸缝合在一起,形成一张更大的纸,而stack
相当于是把两张纸上下堆叠在一起。
要求:两个tensor拼接前的形状完全一致
1 | a=torch.randn(3,4) |
将 tensor 进行分割:torch.split
定义:根据长度去拆分 tensor
1 | a=torch.randn(3,4) |
将 tensor 均等分割:torch.chunk
定义:均等分的 split,但是当维度长度不能被等分份数整除时,虽然不会报错,但可能结果与预期的不一样,建议只在可以被整除的情况下运用。
1 | a=torch.randn(4,6) |