[논문리뷰]Reformer: The Efficient Transformer

Date:     Updated:

카테고리:

Kitaev, Nikita, et al. “Reformer: The Efficient Transformer.” ArXiv:2001.04451 [Cs, Stat], 18 Feb. 2020, arxiv.org/abs/2001.04451.

Problem Statement

Drawbacks of Vanilla 트랜스포머

1. Attention 구조에 의한 메모리 문제

입력으로 길이가 \(L\)인 Sequence를 받는데, 시간 복잡도와 공간 복잡도는 \(O(L^2)\)이 된다. 트랜스포머에서 Attention 연산은 Dot-Product Attnetion이다. 수식으로 좀 더 구체화하면 다음과 같다.

\[Attention(Q, K, V) = softmax(\frac{QK^T}{\sqrt{d_k}})V\]
  • Query(\(Q\)): 영향을 받는 단어 (객체)
  • Key(\(K\)): 영향을 주는 단어 (주체)
  • Value(\(V\)): Key에 대응되는 값

시간 복잡도(Time Complexity)가 \(Q\)의 크기와 \(K\)의 크기에 곱에 비례한다. 이는 데이터의 길이가 10배 길어지면 연산 자체는 100배 더 많아지는 것과 같다.

2. N-stacked Residual Connection에 의한 메모리 문제
Vanilla 트랜스포머의 Encoder와 Decoder의 Layer수가 너무 많다는 문제점이 있다. N개의 Layer 층을 이루기 때문에 N배 많은 메모리를 필요로 한다. 마지막 Layer에서부터 시작 Layer까지 backpropagation을 하면서 미분값을 구하고, parameter를 업데이트 히는데, 층 수가 많아지면 많아질수록 연산에 더 많은 메모리가 요구된다. 이는 N개의 Layer가 쌓이면 그만큼 Residual Connection도 늘어나므로, 연산이 증가함을 의미한다.

3. Feed Forward Layer에 의한 메모리 문제
Feed Forward Layer(FFN)이 attention activation의 깊이보다 더 클 수 있다. 실제로 FFN이 각 Attention Layer의 출력에 모두 적용이 되어야 한다. 이 구조가 차지하는 메모리는 따라서 Sequence의 길이(\(L\))와 해당 layer의 입출력 차원의 곱에 비례한다. 문제는 보통 이 때 사용되는 입출력 차원은 대개 모델 임베딩의 차원에 비해 크다는 것이다.

\[FFN(x) = max(0, x \cdot W_1 + b_1) \; \cdot \; W_2 + b_2\]

트랜스포머에서도 입력 시퀀스는 512개의 토큰을 maximum으로 하지만, 실제로 입력 차원수는 2048이다. 따라서 데이터의 길이가 충분히 길면 이 FFN구조가 차지하는 메모리도 무시할 수 없게 된다. 즉, Sequence의 길이에 보통 4배 정도로 Layer의 차원수를 설정하고, Sequence가 길어지면 그에 따라 계산해야하는 파라미터수가 늘어나므로, 이것의 계산 복잡도가 시퀀스에 비례해 커지는 것을 말한다.



Related Work

1. Understanding about Memory

1

데이터를 load하고, 모델에 구조에 맞게 forward-propagation(순전파), back-propagation(역전파)의 과정을 거친 후 모델의 파라미터가 업데이트된다. 이 때, 그림에서와 같이 역전파 이전에는 연산의 중간 결과물인 \(b_1, b_2\)등을 저장하고 있는다. 이는 메모리 측면에서 단점으로 작용할 수 있으며, 트랜스포머에서 여실히 보여주게된다.

1

트랜스포에서 학습 시 Memory를 증가시키는 요인은 역전파하기 전까지 중간결과물(\(b_n\))을 저장해야하기 때문이다. 즉, batch size가 메모리 사용량에 지대한 영향을 미친다.

1

또한 트랜스포머에서 batch size이외에도 메모리 사용량을 증가시키는 요소는 여러가지가 더 있다. 바로 Layer 수에 영향을 받는 모델 깊이, Hidden size에 해당하는 모델 넓이, 그리고 입력 시퀀스의 길이에 해당하는 문장 길이 이다. 따라서 메모리의 총량은 (모델깊이 X 모델넓이 X 문장길이 X Batch)에 의해 결정된다.

1

이를 그림으로 표현하면 위와 같다. 이처럼, 학습 시 메모리 사용량을 줄이려면 (모델깊이 X 모델넓이 X 문장길이 X Batch)중 파라미터를 일부 조정하면서 떨어트릴 수 있다. 하지만, Reformer에서 말하는 것은 단순히 파라미터 조정을 통해 메모리 사용량을 줄이는 것이 아니다. Reformer가 보여주는 것은 바로 메모리 효율성이다. 그러면 하이퍼파라미터를 조절하지않고 어떻게 하면 효율적으로 만들 수 있는지가 관건이다. 그 답은 그리고 앞에서 나온 것과 같이 연산의 중간 산물과 관련이 있다.



Method

1. Locality-Sensitive Hashing (LSH)

1

Hashing은 해시 function(해시 함수) algorithm을 말하며 임의의 길이의 데이터를 고정된 길이의 데이터로 매핑하는 함수를 해시 함수이라고 한다. 데이터를 미리 Hashing해두면 해시값만으로 쉽게 데이터를 찾을 수 있다. 보통 해시값은 연결된 데이터와 전혀 관련이 없을 때가 만고, 그렇기 때문에 전체 데이터 분포에서 상대적 위치를 확인하거나 한 데이터와 가장 가까운 다른 데이터를 찾는 등 데이터에 대한 비교 분석을 할 때 반드시 실제 데이터 값을 비교하는 연산이 필요하다.

이 논문에서 Hashing을 사용함에 있어 핵심은, 가까운 거리에 위치한 데이터들은 가까운 해시값을 갖도록 구성하는 것이다. 이러면 비교 연산을 해시값에 대한 연산으로 근사할 수 있다. 이렇게 가까운 값들을 가까운 해시값을 가지도록 Hashing하는 것을 Locality-Sensitive Hashing(LSH)라 한다. LSH를 보기 전 기존 트랜스포머의 Attention을 먼저 알아야 한다.

1

Vanilla 트랜스포머에서는 세 가지 종류의 Attention이 존재한다. 바로 Encoder의 Self-Attention과 Decoder의 Masked Self-Attention, 그리고 Encoder의 최종 레이어의 출력과 디코더의 Masked Self-Attention을 거친 출력과 Cross Attention을 하는 Encoder-Decoder Attention이다. 또한 이 세 가지 Attention은 기본적으로 Dot-Product Attention이다.

1

이때 it의 Attention 중 5개 (‘The’, ‘animal’, ‘street’, ‘it’, ‘.’ ) 를 제외하면 제대로 그 영향력이 전달되지 않은 것을 볼 수 있다. 그래서 didn'tcross등과 같은 단어들은 연산 결과 Attention을 받지 않았기 때문에 실제로는 Sparse하게 된다. 이를 바꿔말하면, 우리는 Query(Q)에 대해 밀접한 연관을 가진 K에 대해서만 Attention을 하면 된다. 하지만 기존 트랜스포머는 이 부분을 비효율적으로 찾고, 모든 단어와 단어 사이 Attention을 계산하기 때문에 비효율 적인 것이다. Reformer에서는 LSH를 통해 이 문제를 해결하였다. 그러면 여기서 문제는 어떻게 LSH를 기계적으로 수행할 수 있을지이다.

1) Simple LSH

1

LSH는 고차원 데이터에 대해 nearest neighbor 알고리즘을 이용해 효율적으로 근사하는 방법이다. 두 점 \(p, \; q\)가 충분히 가깝다면 해시 함수을 거친 결과가 \(해시(p) == 해시(q)\)일 것이다. 임의의 직선 \(h1, \; h2, \; h3\)를 그어 영역에 따라 0과 1로 분류 한다. 각 포인트는 3개의 해시 값을 가지게 되고, 그 해시값이 버켓이라고 보면 된다. 이때 각 충분히 가까운 포인트는 거의 같은 버켓에 들어갔을음 볼 수있다.


2) Angular LSH

Reformer에서 실제로 사용한 방법은 Angular LSH이다. 이 방법은 방향 성분만을 활용하여 해시값을 생성한다. 전체적인 과정은 다음과 같다.

1

  1. 전체 데이터 포인트들의 벡터를 단위 구면에 맵핑한다. 이러면 전체 데이터 포인트를 오직 각도만 사용해서 기술할 수 있다. (\(r=1\)인 구면 좌표계)
  2. 이제 각 각도가 어느 사분면에 있는지 확인한다. 비슷한 데이터들은 같은 사분면에 있다. 따라서 사분면의 번호를 해시값으로 사용한다면, 비슷한 데이터들을 가깝게 구성할 수 있다.
  3. 이제 사상한 구면을 필요한 만큼 임의로 회전시킨다. 데이터가 가까우면 가까울수록 전체 해값을 공유할 가능성이 높아지고, 충분히 많은 해시값을 사용하면 데이터를 구별하는 변별력이 생긴다.

1

예를 들어, 2차원 임베딩 벡터가 \(X_1 = (3, 4)\), \(X_2 = (-12,5)\)가 있다고 가정하고, 이를 반지금 1인 구에 맵핑하면 각각의 점들은 \(X_1^{'} = (\frac{3}{5}, \frac{4}{5})\)와 \(X_2^{'} = ( - \frac{12}{13}, \frac{5}{13})\)이 된다.(그림에서는 이해를 위해 축을 회전) 다음으로 임의의 벡터 \(Y = (4,3)\)이 주어졌다고 가정하면 이 벡터는 원 위에서 \(Y^{'} = (\frac{4}{5}, \frac{3}{5})\)으로 맵핑된다. 똑같이 원을 돌리면서 해시값을 표현해보면 \(Y\)의 해시값은 (1, 4, 2)가 된다.

이 값은 \(X_1\)의 해시갑과 같고, \(X_2\)의 해시값과는 다르므로 데이터 간의 직접 비교 연산 없이 해시값이 일치하는지 보는 것만으로도 가까운 점들을 선별할 수 있다. 하지만 ‘이런 방식으로 데이터를 저장하면 방향 값은 살아있어도 크기에 대한 값은 손실되는 것이 아닌가?’ 라는 의문이 생기고 이에 대한 실험은 Ablation Study에서 기술하였다.


3) LSH Attention

1

기존의 트랜스포머에서 길이가 \(L\)인 시퀀스가 입력으로 들어오면 그에 따른 \((Q, K, V)\)값들은 \((batch, length, d_{model})\)의 차원수를 가진다. 그리고 \(QK^T\)는 \((batch, length, length)\)의 shape을 가지기 때문에 결론적으로 시간 복잡도와 공간 복잡도가 모두 \(O(L^2)\)이 된다. 이로 인해 Long-Sequence가 들어오면 학습에 드는 비용이 기하급수적으로 증가한다. 하지만, LSH등을 사용하면 가까운 임베딩끼리만 Attention을 진행할 수 있기 때문에 수식이 다음과 같이 바뀌며, 이는 메모리 효율적이다.

1

But it is important to note that the QKT matrix does not need to be fully materialized in memory. The attention can indeed be computed for each query qi separately, only calculating \(softmax(\frac{q_iK^T}{\sqrt{d_k}})\) once in memory, and then re-computing it on the backward pass when needed for gradients. This way of computing attention may be less efficient but it only uses memory proportional to length.

최종적으로 LSH attention의 softmax안에 들어가는 식은 다음과 같다. 여기서 \(\mathcal{P_i}\)는 query의 position이고, \(z\)함수는 partition function이다. 여기서 batch까지 고려하면 다시 식은 아래와 같이 바뀌고, 최종적인 attention식이 된다.

1

기존의 Attention에서는 입력으로 들어온 Sequence 내 모든 토큰들이 서로 서로 attention을 진행하기 때문에 히트맵에서 모든 부분에 영향력을 나타내는 정도가, 비록 그 어텐션 값이 작을지라도 표시된다. 반면, LSH Attention의 히트맵 같은 경우 같은 bucket안에 들어있는 단어들끼리만 어텐션을 진행하므로 다음과 같이 히트맵이 변한다.

1


4) Summary

1

1) 각 토큰의 (Key, Query), Value를 생성한다. 이 때, Key == Query가 되게하기 위해 같은 layer에서 생성한다. LSH를 적용하기 위해 Key, Query를 같은 공간에 projection하는 것이다. 2) Locality-Sensitive Hashing에 적용한다. 즉, 각 토큰에 Index를 부여한다. 3) 이후 Index가 같은 것 끼리 Sorting을 한다. 4) 시퀀스를 Chunking한다.

1

5) 여기서 조건을 만족하는 key에 attention을 적용한다.

  • Query와 Key 토큰이 같은 Index를 보유
  • Query 토큰은 같은 구역 또는 바로 앞 구역의 key 토큰에만 Attention을 한다.

1

2. Reverisble Transformer

1

Reverseible 트랜스포머는 기존에 연구된 The Reversible Residual Network (RevNet)의 아이디어를 적용한 것이다. 구조를 요약하면 Encoder, Decoder에서 attention layer와 feed-forward layer를 하나로 묶어버린 것이다. 이렇게하면, 각 layer안에 activation을 저장할 필요가 없다. 따라서 저장 공간을 줄어든다.

1

1

RevNet은 원래 이미지 처리에 사용되는 ResNet 구조에서 메모리를 효율적으로 사용하기 위해 고안되었다. 기존의 ResNet에서 사용되는 계산, 일반적인 Residual Connection의 연산은 위와 같다. 특징적인 것은, 하나의 입력에 대해 하나의 출력이 나온다는 구조이다. 하지만 이러한 방식은 \(x\)에서 \(y\)를 계산할 수는 있어도 \(y\)에서 \(x\)를 역으로 계산해낼 수는 없다. 따라서 들어가는 입력 \(x\)와 출력 \(y\)를 (\(x_1, x_2\)), (\(y_1, y_2\))쌍 형태로 바꾼다. 이를 도식화하면 다음과 같다.

1

여기서 \(F(\cdot)\)은 Attention Layer를, \(G(\cdot)\) Feed-Forward Layer이다. 이렇게 식을 나눠서 입출력을 pair로 쓰면 \(y_1\)과 \(y_2\)가 주어졌을 때, \(x_2 = y_2 - G(y_1)\)로 역산할 수 있고, \(x_1 = y_1 = F(x_2)\)로 역산할 수 있다. 이렇게 함으로써 임의의 시점의 출력값을 토대로 그 출력에 대한 입력값을 표현할 수 있다. 따라서, 중간 결과를 저장할 필요가 없이 Forward연산을 반복적으로 적용해 수치적 미분값을 얻을 수 있게 된다. 또한 두 단계의 Sequential한 연산을 하나의 과정으로 합친 것과 같아, activation function을 하나만 사용해도 된다. 즉, Activation function에 대한 파라미터를 공유하여 메모리 효율적이다.

1

최종적으로 Reverisble Transformer의 식을 표현하면 위와 같다. Reverisble Transformer에서는 activation을 각 layer마다 저장할 필요가 없기 때문에 \(n_l\)에 관한 부분이 수식에서 없어진다. 이로서 Reversible Network 적용 후 memory의 최대 사용량은 다음과 같다.

1

3. Chuncking

1

Reverisble Transformer를 통해서 \(n_l\) 에 관한 부분이 사라졌더라고해도, 여전히 많은 메모리를 요구한다. FFN은 특히 벡터의 차원수로 $$d_{ff} = 4K$ 또는 이보다 높은 차원수를 요구한다. 즉, Sequence가 길어질수록 FFN을 연산하는데도 많은 메모리가 소요된다. Chuncking에서 제안하는 방식은 이런 Computational cost를 줄이기 위해 하나의 FFN을 시퀀스에 위치와는 독립적으로 만들어 쪼갠 후 그걸 각각 연산하자는 것이다.

1

1

computations in feed-forward layers are completely independent across positions in a sequence, so the computation can be split into c chunks

1

이 Layer는 일반적으로 모든 위치에 대해 병렬로 연산을 수행하지만, 한 번에 하나의 Chunk에 대해 연산하여 메모리를 줄일 수 있다. 또한 Reverisble Transformer의 연산에더도 마찬가지로 Chunking을 할 수 있다. FFN외에도, 많은 단어를 요구하는(\(d_{model}\) word type보다 많은 단어를 요구) 모델의 경우 출력에서 log-likelihoo로 chuncking하고 한 번에 Sequence의 섹션에 대한 loss를 계산한다.

1

Final Model Structure

1) LSH Attention

앞서 각각의 특징을 살펴봤으므로 이제, 최종적인 모델 구조를 봐야한다. 논문에서는 각 데이터 포인트에 Locality Sensitive Hashing을 적용한다. 요점은 Hash값이 일치하는 데이터에 대해서만 Attention을 계산하는 것이다. 이를 위해 논문에서는 \(Query(Q) = Key(K)\)라는 가설을 세운다. 각 데이터 포인트를 Q로 projection하는 행렬과 K로 projection하는 행렬을 동일하게 설정한 것이다. 이 가설에서 Query와 Key는 본직적으로 같은 값은 값으로 설정한다.

\[Query(Q) == Key(K)\]

이러한 가설은 직관적으로 납득되지 않으나, 논문에서는 매우 큰 데이터셋에서는 이 가설을 사용해도 성능 저하가 일어나지 않음을 증명하였다. 이러한 가설을 세울 수 있는 근간은, 한 문장 내에서 중요한 단어는 몇 단어 안되고 나머지는 모두 불용어에 속하기 때문이다.

1

LSH Attention을 적용하는 절차는 첫 번째는 Q = K이므로 각 데이터 포인트에 대한 Q = K값을 일려로 된 벡터 형태로 나타낸 후, 각 데이터 포인트에 LSH를 적용하는 것이다. 그 이후 같은 Hash값을 가진 데이터 포인트끼리 버킷(Bucket)으로 묶는다. 각 해시 버킷에 임의로 순서를 매겨서 정렬한다.

1

각 버킷에서 높은 확률로 데이터 포인터들이 불균형하게 할당될 것이다. 따라서 데이터 포인터들을 고정된 크기의 구역으로 분할한다.(Chuncking) 이후 Attention Weight을 계산하는데, 다음 조건을 만족하는 쌍들에 대해서만 계산한다.

  1. 두 데이터 포인트가 같은 버킷에 있어야 한다.
  2. 두 데이터 포인트는 서로 같은 chunk에 있거나, attention 연산의 마지막 데이터 포인트는 시작 데이토 포인트가 있는 구역 바로 앞에 있어야 한다.

1

일반적인 트랜스포머는 각 데이터 포인트들이 스스로를 attend할 수 있는 구조이지만, 이 구조에서는 Q=K를 따르기 때문에 Self-Attention의 내적값이 다른 Attention 내적값보다 언제나 매우 크다. 따라서 조건을 만족하는 쌍이 자신으로만 구성되는 경우를 제외하면 스스로를 attend하는 것을 허용하지 않으며 이를 그림으로 표현하면 위와 같다.

같은 부분과 앞 부분에만 접근하므로 각 데이터 포인트가 최대로 attend하는 수는 각 chunking된 크기의 2배 만큼이다. 만약 이 시퀀스의 길이를 \(l\)이러고 하고, chunking 수를 \(c\)라 한다면, partition의 크기는 \(\frac{l}{c}\)이다. 또한 Attention의연산 수는 \(l \cdot (2 \frac{l}{c})^2\)에 비례한다. 만약 \(c\)가 충분히 크다면 \(l\)에 선형적으로 비례하는 구조로 간주할 수 있고, 원래 트랜스포머가 모든 쌍을 attend하기 때문에 \(l^2\)에 비례하는 복잡도를 가지는 것을 생각하면 상당히 개선된 것이다.


2) FFN with Reversible Network

Attention Layer와 Feed Forward Layer가 이루는 블록을 Residual Network로 간주할 수 있으며 따라서 Reversible Network에서와 같이 입출력을 둘로 나눌 수 있음을 보였다.

1

이 구조로 변형된 트랜스포머는 N층 블록의 결과물을 모두 저장할 필요 없이 한 층에 대해서만 메모리를 사용하여 연산을 수행한다. PyTorch에서 backward연산을 forward 함수로 구현하는 형태를 생각해보면 이해가 쉽다.


3) FFN with Chunking

1

FFN은 데이터 포인트의 위치와 무관하다. 각 구역으로 나뉜 부분 \(c\)에 FFN을 순차적으로 적용하면 하나의 부분에 대한 FFN 분량의 메모리만 필요하다. 논문의 공간 복잡도와 시간 복잡도를 비교한 표는 이 구조가 얼마나 효율적으로 트랜스포머구조를 간소화하는지를 보여준다. 표에서와 같이 시공간 복잡도의 이득 대부분은 LSH Attention에서 기인한다.


Experiment

1. Dataset

  • imagenet64: 이미지넷 데이터셋은 이미지에 등장한 객체의 종류를 분류하는 작업과 관련있으며 20,000개 이상의 종류 아래 14,000,000 개 이상의 이미지로 구성된다.
  • enwiki8-64K: enwiki8은 전체 영어 위키피디아 데이터를 압축하는 작업으로 본 논문에서 사용하는 enwiki8-64K는 각 부분이 64K 토큰으로 구성된 데이터이다.

1

1

실험은 트랜스포머로 입력 데이터를 인코딩한 뒤 다시 디코딩하는 압축 작업으로 이루어지며, 제안한 구조의 성능은 bit-per-dim으로 측정되었다. 데이터를 온전하게 표현하기 위한 인코딩 비트가 적으면 적을수록 압축이 효과적으로 되었음을 의미한다.

2. 가설 검증

1) Reformer 구조에서 Query는 Key와 같다고 주장해도 무방

첫 번째 가설은 ‘Reformer 구조에서 Query는 Key와 같다고 간주해도 무방하다.’이다. 앞서 모델 구조를 살펴볼 때 직관적으로는 납득하기 어려우나, 실험을 통해 증명하였다고 한 부분이다. 다시 말해 Q와 K를 공유하는 것의 효과(Effect of Sharing QK)를 알아보는 실험이다.

1

실험 결과, Query와 Key 행렬을 분리한 경우와 그렇지 않은 경우의 성능 차이는 미미했다. 오히려 enwiki8 데이터셋에서 Query와 Key를 공유했을 때 더 빠르게 수렴함을 확인할 수 있다. 즉, 위의 Query-Key 가설에서 비롯된 구조 변경은 성능에 영향을 주지 않는다.


2) Reformer의 Attention Block은 Reversible Layer 형태로 중첩

두 번째 가설은 ‘Reformer의 Attention Block은 Reversible Layer 형태로 중첩할 수 있다.’ 이다. 이번 실험에서는 실제로 Trasformer 모델에 Reversible Layer를 적용하면 어느 정도 성능에 기여할 수 있는지를 본다.(Effect of Reverisble Layer)

1

실험 결과, Reversible 형태로 구성한 구조와 그렇지 않은 구조 사이의 성능 차이는 매우 작다. 즉, Reversible 구조의 활용은 성능에 영향을 주지 않는다.


3) LSH Attention을 활용

마지막 가설은 ‘LSH Attention을 활용하면 기존 구조의 성능을 크게 저하하지 않으면서 입력 길이에 선형인 시간 복잡도를 보이게 개선할 수 있다.’이다. 이 가설이 이 논문의 핵심이며 가장 중요한 Contribution이다.

1

Imagenet64 데이터셋을 사용한 실험 결과에서 병렬적인 Hash를 많이 둘수록 Full Attention을 한 경우와 성능 차이가 줄어드는 것을 확인할 수 있다. 아울러 Hash를 8개 이상 두면 Full Attention할 때와 거의 비등한 절대 성능을 보인다. 저자들은 정확도와 계산량 사이의 경중에 따라 Hash값을 가감할 수 있다고 주장한다.

1

LSH Attention의 효과는 실제 Training Time을 측정할 때 더 부각된다. enwiki8에서 LSH Attention으로 계산하면 데이터 포인트의 길이가 길어져도 매 단계마다 소요시간이 거의 일정하나, Full Attention을 사용한 기존 모델에서는 선형으로 증가함을 알 수 있다. LSH Attention은 Layer를 많이 둘수록 더 높은 성능을 기록했지만, 개수가 12개를 넘어가면 성능 향상폭은 미미하였다.



Contribution

  1. Locality-Sensitive-Hashing (LSH)를 통해 서로 비슷한 임베딩이 비슷한 단어 쌍만 Attention을 진행한다.
    • 영향력이 높은 단어 쌍끼리만 가중치를 계산할 수 있다면 성능 저하 없이 복잡도가 대폭 줄어든 구조를 얻을 수 있다.
    • 영향력이 높은 단어 쌍은 임베딩 공간에서 서로 비슷한 pair이며, 이런 pair들은 LSH를 이용해서 빠르게 찾을 수 있다.
    • 결론적으로 attention computation이 \(O(L^2)\)에서 \(O(LlogL)\)로 감소한다.
  2. Chunking을 이용해 메모리를 줄인다.
    • Feed-Forward Layer (FFN)는 Attention Layer와 다르게 데이터 포인트의 위치에 무관하게 계산된다.
    • 따라서 데이터 포인트들을 묶어줄 수 있다면 계산하는 단위를 나눌 수 있고, 전체 데이터 포인트에 대한 FFN의 가중치를 한 번에 메모리에 올리지 않아도 된다.
  3. Reversible Layer를 이용해 메모리를 줄인다.
    • 트랜스포머는 Attention Layer와 FFN의 residual connection으로 되어 있다.
    • 본 논문에서는 기존 수식을 살짝 변형해 출려을 이용해서 입력을 다시 복원할 수 있는 형태로 기술했다.
    • 출력으로 입력을 복원할 수 있으므로, 각 중간 단계의 입출력을 저장할 필요 없이 바로 출력에서 미분값을 계산하여 훈련할 수 있다.

이로써 기존의 Naive 트랜스포머 모델이 Attention을 효과적으로 사용하지 못하므로 속도/크기/성능 면에서 최적화될 여지가 많다는 점을 보여주며 이를 효과적으로 해결하였다. 또 하나는 문장이나 문단 등의 짧은 단위에서는 기존 모델을 대체하기 어렵지만 문서나 책, 논문 등 매우 긴 문단 또는 책자 단위로 학습을 할 때 효과적으로 적용될 potential이 충분하다.



Reference

Blog: Reformer, The Efficient 트랜스포머
Blog: 꼼꼼하고 이해하기 쉬운 Reformer 리뷰
Reformer: The Efficient Transformer

NR 카테고리 내 다른 글 보러가기

댓글 남기기