pytorch tensor 操作:合并、分割、维度变换

torch.Tensor — PyTorch 1.9.0 documentation

查看维度:

1
2
3
a=torch.randn(3,4)
a.size()
# torch.Size([3, 4])

对 tensor 进行 reshape: tensor.view

定义:将 tensor 中的元素,按照顺序逐个选取,凑成 (shape) 的大小

1
2
a = torch.randn(3, 4)
b = a.view(2,6)

把原先tensor中的数据按照行优先的顺序排成一个一维的数据(这里应该是因为要求地址是连续存储的),然后按照参数组合成其他维度的tensor。比如说是不管你原先的数据是[[[1,2,3],[4,5,6]]] 还是 [1,2,3,4,5,6],因为它们排成一维向量都是 6 个元素,所以只要 view 后面的参数一致,得到的结果都是一样的。

1
2
3
4
5
6
7
8
9
10
11
a=torch.Tensor([[[1,2,3],[4,5,6]]])
b=torch.Tensor([1,2,3,4,5,6])

print(a.view(1,6)) # tensor([[1., 2., 3., 4., 5., 6.]])
print(b.view(1,6)) # tensor([[1., 2., 3., 4., 5., 6.]])

a=torch.Tensor([[[1,2,3],[4,5,6]]])
print(a.view(3,2))
# tensor([[1., 2.],
# [3., 4.],
# [5., 6.]]) 相当于就是从 1,2,3,4,5,6 顺序的拿数组来填充需要的形状。

tensor 交换维度:tensor.permute

定义:将tensor的维度换位。

1
2
a = torch.randn(2, 3, 4) # torch.Size([2, 3, 4])
b = a.permute(2, 0, 1) # torch.Size([4, 2, 3])

对 tensor 维度进行压缩:tensor.squeeze

定义:对数据的维度进行压缩,去掉维数为 1 的的维度

1
2
3
4
a=torch.randn(1, 2, 1, 3, 4)

x = a.squeeze() # 去掉所有为 1 的维度:torch.Size([2, 3, 4])
y = a.squeeze(dim=2) # 去掉维度为 1 的 dim 维度:torch.Size([1, 2, 3, 4])

对 tensor 维度进行扩充:tensor.unsqueeze

定义:给指定位置加上维数为 1 的维度

1
2
3
a=torch.randn(2, 3, 4)

x = a.unsqueeze(dim=1) # torch.Size([2, 1, 3, 4])

tensor 维度扩张:tensor.expand

定义:对 tensor 的维度进行扩张。

如果某个维度参数是 -1,代表这个维度不改变。

tensor 可以被 expand 到更大的维度,新的维度的只是前面的值的重复。新的维度参数不能为 -1。

expand 一个 tensor 并不会分配新的内存,而只是生成一个已存在的 tensor 的 view。

1
2
3
4
5
6
7
8
9
10
11
x = torch.tensor([[1], [2], [3]])
x.size()
#torch.Size([3, 1])
x.expand(3, 4)
# tensor([[ 1, 1, 1, 1],
# [ 2, 2, 2, 2],
# [ 3, 3, 3, 3]])
x.expand(-1, 4) # -1 means not changing the size of that dimension
# tensor([[ 1, 1, 1, 1],
# [ 2, 2, 2, 2],
# [ 3, 3, 3, 3]])

经过实践发现,使用 expand 可以增加新的一个维度,但是只能在第 0 维增加一个维度,增加的维度大小可以大于 1,比如原始 t = tensor(X,Y),可以 t.expand(Z,X,Y),不能在其他维度上增加;expand 拓展某个已经存在的维度的时候,如果原始 t = tensor(X,Y),则必须要求 X 或者 Y 中至少有 1 个维度为 1,且只能 expand 维度为 1 的那一维。

将 tensor 的指定维度合并为一个维度:torch.flatten

1
torch.flatten(input, start_dim=0, end_dim=-1)

start_dim: flatten 的起始维度。

end_dim: flatten 的结束维度。

1
2
3
4
a=torch.randn(2, 3, 4)

x = torch.flatten(a, start_dim=1) # torch.Size([2, 12])
y = torch.flatten(a, start_dim=0, end_dim=1) # torch.Size([6, 4])
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
a: tensor([[[-0.2499, -0.0171,  1.1206,  0.5814],
[ 0.9040, -0.6853, 0.1916, -1.3254],
[ 0.3648, -1.7594, -0.2376, -0.0100]],

[[-0.4587, 0.4200, -0.3693, 1.5906],
[ 0.8472, -1.4564, 0.8263, -2.2202],
[ 0.9074, 2.2779, -0.2543, 0.2251]]])

x: tensor([[-0.2499, -0.0171, 1.1206, 0.5814, 0.9040, -0.6853, 0.1916, -1.3254,
0.3648, -1.7594, -0.2376, -0.0100],
[-0.4587, 0.4200, -0.3693, 1.5906, 0.8472, -1.4564, 0.8263, -2.2202,
0.9074, 2.2779, -0.2543, 0.2251]])

y: tensor([[-0.2499, -0.0171, 1.1206, 0.5814],
[ 0.9040, -0.6853, 0.1916, -1.3254],
[ 0.3648, -1.7594, -0.2376, -0.0100],
[-0.4587, 0.4200, -0.3693, 1.5906],
[ 0.8472, -1.4564, 0.8263, -2.2202],
[ 0.9074, 2.2779, -0.2543, 0.2251]])

将两个 tensor 拼接起来:torch.cat

定义:把2个 tensor 按照特定的维度连接起来。

要求:除被拼接的维度外,其他维度必须相同

1
2
3
4
5
6
7
import torch
a=torch.randn(3,4) #随机生成一个shape(3,4)的tensor
b=torch.randn(2,4) #随机生成一个shape(2,4)的tensor

torch.cat([a,b],dim=0)
#返回一个shape(5,4)的tensor
#把a和b拼接成一个shape(5,4)的tensor,

将两个 tensor 堆叠起来:torch.stack

定义:增加一个新的维度,来表示拼接后的2个 tensor。

直观些理解的话,咱们不妨把一个 2 维的 tensor 理解成一张长方形的纸张,cat相当于是把两张纸缝合在一起,形成一张更大的纸,而stack相当于是把两张纸上下堆叠在一起。
要求:两个tensor拼接前的形状完全一致

1
2
3
4
5
6
7
8
a=torch.randn(3,4)
b=torch.randn(3,4)

c=torch.stack([a,b],dim=0)
#返回一个shape(2,3,4)的tensor,新增的维度2分别指向a和b

d=torch.stack([a,b],dim=1)
#返回一个shape(3,2,4)的tensor,新增的维度2分别指向相应的a的第i行和b的第i行

将 tensor 进行分割:torch.split

定义:根据长度去拆分 tensor

1
2
3
4
5
6
7
a=torch.randn(3,4)

a.split([1,2],dim=0)
#把维度0按照长度[1,2]拆分,形成2个tensor,shape(1,4)和 shape(2,4)

a.split([2,2],dim=1)
#把维度1按照长度[2,2]拆分,形成2个tensor,shape(3,2)和shape(3,2)

将 tensor 均等分割:torch.chunk

定义:均等分的 split,但是当维度长度不能被等分份数整除时,虽然不会报错,但可能结果与预期的不一样,建议只在可以被整除的情况下运用。

1
2
3
4
5
6
a=torch.randn(4,6)

a.chunk(2,dim=0)
#返回一个shape(2,6)的tensor
a.chunk(2,dim=1)
#返回一个shape(4,3)的tensor

Pytorch:Tensor的合并与分割 - 简书 (jianshu.com)