[Pytorch]다차원 Tensor의 쌓기, cat & stack
카테고리: Pytorch
Pytorch에서 다차원 텐서 곱
1. torch.cat VS torch.stack
import torch
# (2, 3) 사이즈 2차원 텐서 2개 생성
a = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])
b = torch.tensor([[7, 8, 9],
                  [10, 11, 12]])
1) cat함수
cat함수는 원하는 dimension 방향으로 텐서를 나란히 쌓아준다. 따라서 torch.cat()을 선언할 때 괄호 안에 dimenstion을 설정해 줘야 한다. 
예를 들어, 크기가 (x,b,c)인 텐서와 (y,b,c)인 텐서가 있다. 이 때 dim = 1인 b와 dim = 2인 c는 동일해야 한다. 이 경우 dim = 0 방향으로 concatenation을 진행할 수 있다.
torch.cat([(x,b,c), (y,b,c)], dim=0)을 선언해주면 결론적으로 크기가 (x+y, b, c)인 텐서가 출력된다.
torch.cat([a, b], dim = 0)
'''
tensor([[ 1,  2,  3],
        [ 4,  5,  6],
        [ 7,  8,  9],
        [10, 11, 12]]), size = (2+2, 3) = (4, 3)
'''
torch.cat([a, b], dim = 1)
'''
tensor([[ 1,  2,  3,  7,  8,  9],
        [ 4,  5,  6, 10, 11, 12]]), size = (2, 3+3) = (2, 6)
'''
 
2) Stack함수
cat함수는 원하는 dimension방향으로 텐서를 나란히 쌓아준다. 반면 stack함수는 텐서를 새로운 차원에 차곡차곡쌓아준다. 예를 들어, (x,y,z)사이즈를 가지는 텐서에 dim = 2에 텐서 3개를 쌓는다면 (x,y,3,z)가 된다.(같은 사이즈의 텐서끼리만 쌓을 수 있다.)
import torch
a = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])
b = torch.tensor([[7, 8, 9],
                  [10, 11, 12]])
print(torch.stack([a, b], dim = 0))
'''
tensor([[[ 1,  2,  3],
         [ 4,  5,  6]],
        [[ 7,  8,  9],
         [10, 11, 12]]]), size = (2, 2, 3)'''
print(torch.stack([a, b], dim = 1))
'''
tensor([[[ 1,  2,  3],
         [ 7,  8,  9]],
        [[ 4,  5,  6],
         [10, 11, 12]]]), size = (2, 2, 3)'''
print(torch.stack([a, b], dim = 2))
'''
tensor([[[ 1,  7],
         [ 2,  8],
         [ 3,  9]],
        [[ 4, 10],
         [ 5, 11],
         [ 6, 12]]]), size = (2, 3, 2)'''
 
 
      
    
댓글 남기기