[Pytorch]다차원 Tensor의 곱(Matmul)
카테고리: Pytorch
Pytorch 텐서의 연산
1. Broadcasting
크기가 서로 다른 텐서끼리 연산할 때 자동으로 크기를 맞춰주는 기능이다.
1) 덧셈 & 뺄셈
두 행렬을 서로 더하거나 뺄 때 크기각 같아야 한다. Pytorch에서는 크기가 다르더라도 브로드캐스팅(Broadcasting) 통해 크기가 서로 다른 두 텐서의 연산을 가능케 해준다.
먼저 크기가 같은 경우이다.
t1 = torch.FloatTensor([[10, 100],
[1000, 10000]])
t2 = torch.FloatTensor([[1, 2],
[3, 4]]) # 둘다 크기가 (1, 2)인 텐서
print(t1 + t2)
# tensor([[ 11., 102.],
# [ 1003., 10004.]])
크기가 다른 경우는 Broadcasting을 한다.
t1 = torch.FloatTensor([[1, 4]]) # 크기가 (1, 2)인 텐서
t2 = torch.FloatTensor([[3], [4]]) # 크기가 (2, 1)인 텐서
print(t1 + t2) # tensor([[4., 7.],
# [5., 8.]])
print((t1 + t2).size()) # torch.Size([2, 2])
2) 텐서 곱(matmul)
- 규칙!!
- 두 Tensor가 모두 1차원이면 dot product이고 Scalar를 리턴
- 두 Tensor가 모두 2차원이면 matrix-matrix product를 리턴
- 첫 번째 인수가 1차원, 두 번째 인수 2차원인 Tensor를 곱하면 행렬 곱셈을 위해 첫 번째 인수(1차원) Tesntor이 첫 번째 차원에 추가되고, 곱셈 후 첫 번째 차원이 제거됨
- 첫 번째 인수가 2차원, 두 번재 인수가 1차원 Tensor를 곱하면 matrix-vector product가 리턴
- 두 인수 모두 3차원 이상의 차원을 가진 경우 Batch matrix 곱셈이 리턴.
- 첫 번째 인수가 1차원이라면 배치 행렬 곱셈을 위해 1이 첫 번째 차원에 추가되고, 곱셈 후 첫 번째 차원이 제거됨
- 두 번째 인수가 1차원이라면 배치 행렬 곱셈을 위해 1이 두 번째 차원에 추가되고, 곱셈 후 두 번째 차원이 제거됨
- Non-matrix 차원은 브로드캐스팅된다.(따라서 broadcastable 해야 함.)
- 예를 들어, input: (\(j \times 1 \times n \times n\)) 와 (\(k \times n \times n\)) 일 때 출력은 (\(j \times k \times n \times n\)) 이다.
- 브로드캐스팅 로직은 브로드캐스팅이 가능한지를 판단할 때, 행렬 차원이 아니라 배치 차원만을 고려한다.
- Batch dimension(배치 차원)이란 Tensor의 첫 번째 차원이다.
- 예를 들어 input: (\(j \times 1 \times n \times m\)) 와 (\(k \times m \times p\))일 때 출력은 (\(j \times k \times n \times p\)) 이다.
batch dimension이 Tensor의 첫 번째 dimension인거 중요!!
1. 첫 번째 argument가 1차원 텐서
tensor1 = torch.randn(10)
tensor2 = torch.ones(4, 10, 5)
out = torch.matmul(tensor1, tensor2).size()
print(out) # torch.Size([4, 5])
두 번째 인수가 N차원 텐서의 두 번째 dimension이 일치해야 계산이 가능하다. 두 번째 인수는 3차원 텐서인데 첫 번째 dimension은 batch 축이므로, 그 다음 dimension(두 번째 차원)이 벡터와 일치해야 계산이 가능하다.
2. 두 번째 argument가 1차원 텐서
tensor1 = torch.randn(10, 3, 4)
tensor2 = torch.randn(4)
out = torch.matmul(tensor1, tensor2).size()
print(out) # torch.Size([10, 3])
이번에는, 마지막 dimenstion이 일치해야 계산이 가능하다. 그 외에는 모두 error가 발생한다.
일반적인 rule
batch size dimension을 제외하고는 브로드캐스팅 할 수 있는 차원은 브로드캐스팅을 적용
import torch
t1 = torch.randn(3)
t2 = torch.randn(3)
out1 = torch.matmul(t1, t2) # size: 1
print(out1.size()) # torch.Size([])
#-------------------------------------------------#
t3 = torch.randn(3, 4)
t4 = torch.randn(4)
out2 = torch.matmul(t3, t4) # size: (3, )
print(out2.size()) # torch.Size([3])
#-------------------------------------------------#
t5 = torch.randn(10, 3, 4)
t6 = torch.randn(4)
out3 = torch.matmul(t5, t6) # size: (10, 3)
print(out3.size()) # torch.Size([10, 3])
#-------------------------------------------------#
t7 = torch.randn(10, 3, 4)
t8 = torch.randn(10, 4, 5)
out4 = torch.matmul(t7, t8) # size: (10, 3, 5)
print(out4.size()) # torch.Size([10, 3, 5])
#-------------------------------------------------#
x1 = torch.randn(10, 3, 4)
x2 = torch.randn(4, 5)
out5 = torch.matmul(x1, x2) # size: (10, 3, 5)
print(out5.size()) # torch.Size([10, 3, 5])
참고로 Torch.mm은 연산 자체는 matmul과 비슷하지만, 브로드캐스팅이 불가능하다.
Reference
Torch.mm과 Torch.matmul
Pytorch manual
PyTorch 다차원 텐서 곱(matmul)
댓글 남기기