[Pytorch]다차원 Tensor의 쌓기, cat & stack
카테고리: Pytorch
Pytorch에서 다차원 텐서 곱
1. torch.cat VS torch.stack
import torch
# (2, 3) 사이즈 2차원 텐서 2개 생성
a = torch.tensor([[1, 2, 3],
[4, 5, 6]])
b = torch.tensor([[7, 8, 9],
[10, 11, 12]])
1) cat함수
cat함수는 원하는 dimension 방향으로 텐서를 나란히 쌓아준다. 따라서 torch.cat()
을 선언할 때 괄호 안에 dimenstion을 설정해 줘야 한다.
예를 들어, 크기가 (x,b,c)인 텐서와 (y,b,c)인 텐서가 있다. 이 때 dim = 1인 b와 dim = 2인 c는 동일해야 한다. 이 경우 dim = 0 방향으로 concatenation을 진행할 수 있다.
torch.cat([(x,b,c), (y,b,c)], dim=0)을 선언해주면 결론적으로 크기가 (x+y, b, c)인 텐서가 출력된다.
torch.cat([a, b], dim = 0)
'''
tensor([[ 1, 2, 3],
[ 4, 5, 6],
[ 7, 8, 9],
[10, 11, 12]]), size = (2+2, 3) = (4, 3)
'''
torch.cat([a, b], dim = 1)
'''
tensor([[ 1, 2, 3, 7, 8, 9],
[ 4, 5, 6, 10, 11, 12]]), size = (2, 3+3) = (2, 6)
'''
2) Stack함수
cat함수는 원하는 dimension방향으로 텐서를 나란히 쌓아준다. 반면 stack함수는 텐서를 새로운 차원에 차곡차곡쌓아준다. 예를 들어, (x,y,z)사이즈를 가지는 텐서에 dim = 2에 텐서 3개를 쌓는다면 (x,y,3,z)가 된다.(같은 사이즈의 텐서끼리만 쌓을 수 있다.)
import torch
a = torch.tensor([[1, 2, 3],
[4, 5, 6]])
b = torch.tensor([[7, 8, 9],
[10, 11, 12]])
print(torch.stack([a, b], dim = 0))
'''
tensor([[[ 1, 2, 3],
[ 4, 5, 6]],
[[ 7, 8, 9],
[10, 11, 12]]]), size = (2, 2, 3)'''
print(torch.stack([a, b], dim = 1))
'''
tensor([[[ 1, 2, 3],
[ 7, 8, 9]],
[[ 4, 5, 6],
[10, 11, 12]]]), size = (2, 2, 3)'''
print(torch.stack([a, b], dim = 2))
'''
tensor([[[ 1, 7],
[ 2, 8],
[ 3, 9]],
[[ 4, 10],
[ 5, 11],
[ 6, 12]]]), size = (2, 3, 2)'''
댓글 남기기