[논문리뷰]Learning to Compress Prompts with Gist Tokens (NeurIPS, 2023)

Date:     Updated:

카테고리:

Jesse Mu, Xiang Li, and Noah Goodman. 2023. Learning to Compress Prompts with Gist Tokens. In Advances in Neural Information Processing Systems, A. Oh, T. Naumann, A. Globerson, K. Saenko, M. Hardt, and S. Levine (Eds.), Vol. 36. Curran Associates, Inc., 19327–19352.

Problem Statement

[프롬프트 재사용의 비효율성] Transformer LM은 동일한 프롬프트를 수백·수천 번 재사용할 때마다 전체를 다시 self-attention으로 처리해야 하므로, 계산량과 메모리 사용량이 커진다.

[Finetuning/Distillation의 한계] 특정 프롬프트를 모델 내부로 통합하기 위해선 finetuning이나 distillation이 필요하지만, 이는 새로운 프롬프트마다 별도의 재학습과 weight 저장을 요구하여 확장성이 떨어진다.

[Soft/Prefix Tuning의 제약] Prefix-tuning이나 HyperTuning은 특정 task 데이터셋이 있어야 하고, unseen task에 zero-shot 일반화가 어렵다.

따라서, 이 논문은 프롬프트의 의미를 보존하면서도 압축된 “Gist tokens”로 대체하여 효율성을 확보하는 방법을 제안한다.



Methodology

Overview

1

이 논문은 Gisting이라는 방법론을 제안한다. 핵심은 프롬프트를 소수의 가상 토큰(Gist tokens) 위의 activation으로 압축하고, 이를 캐싱하여 재사용하는 것이다. 위의 그림은 Gist Masking의 핵심 아이디어를 시각적으로 보여주는 그림이다.LM이 프롬프트 \(t\)를 그대로 사용하는 대신, \(t\)를 Gist tokens로 압축하도록 강제하는 어텐션 마스크 수정 과정을 나타낸다.

  • Decoder-only (LLM 계열, Figure 2a): 원래 causal mask(하삼각 행렬)에서는 입력 \(x\)와 output \(y\)가 프롬프트 \(t\)까지 attend할 수 있다. 하지만 Gist Masking에서는 Gist tokens \(g_1, g_2, …\)만 프롬프트 \(t\)에 접근 가능</spam>하고, \(x\)와 \(y\)는 \(t\)에 접근하지 못한다. 이렇게 하면 프롬프트 정보는 반드시 Gist tokens을 통해서만 전달된다.
  • Encoder-Decoder (T5 계열, Figure 2b): Encoder 단계에서 입력 \(x\)가 프롬프트 \(t\)를 직접 참조하지 못하도록 차단하고, 동시에 프롬프트 \(t\)와 Gist tokens이 입력 \(x\)를 참조하지 못하게 한다. Decoder cross-attention에서도 프롬프트 \(t\)는 차단된다. 따라서 Encoder-Decoder 구조에서는 양방향 attention을 더 강하게 제약하여 프롬프트 정보가 Gist tokens에 집중되도록 유도한다.

Gisting Framework

  • 입력: Task 프롬프트 \(t\), Input context \(x\)
  • 출력: LM output \(y\)
  • 기존 모델: \(p_{LM}(y \mid t,x)\)
  • 제안 모델: \(p_G(y \mid G(t), x)\)

기존 LM은 프롬프트를 그대로 받아 출력 \(y\)을 예측하지만 (\(p_{LM}(y \mid t,x)\)), Gisting은 \(t\)를 압축한 \(G(t)\)를 사용하여 효율성을 확보한다 (\(p_G(y \mid G(t), x)\)). 즉, LM이 \(G(t)\)를 통해 동일한 instruction-following 동작을 수행할 수 있도록 설계하는 것이다. 이를 통해 LM은 프롬프트 길이와 상관없이 소수의 Gist tokens만 처리하면 되므로, 연산량과 메모리 사용량이 줄어든다.

A Context Distillation Perspective

  • 기존 Context Distillation Loss
$$L_{CD}(p_{t}^{CD}, t) = \mathbb{E}x \left[ D{KL}(p_{LM}(y|t,x) | p_t^{CD}(y|x)) \right]$$
  • Gisting의 Meta Distillation Loss
$$L_G(p_G, T) = \mathbb{E}{t \sim T, x} \left[ D{KL}(p_{LM}(y|t,x) | p_G(y|G(t),x)) \right]$$

즉, task 분포 \(T\) 전체에 대해 압축된 Gist token을 예측하도록 학습한다. 기존 distillation은 특정 프롬프트 \(t\)에 한정된 모델을 학습하지만, Gisting은 분포 전체 \(T\)에 대해 일반화 가능한 \(G(t)\)를 학습한다. 즉, unseen task \(t\)가 들어와도 별도의 학습 없이 Gist tokens만 생성하면 되므로, 학습 비용과 메모리 요구가 크게 줄어든다.

Learning Gisting by Masking

1

\(x\)프롬프트 \(t\)와 input \(x\)사이에 \(k\)개의 Gist token \([g_1, \cdots, g_k]\)를 삽입하고, 어텐션 마스크를 수정하여 \(x\)와 \(y\)가 \(t\)에 직접 attend하지 못하게 한다.

  • Decoder-only LM (LLaMA-7B): causal mask의 하삼각 행렬 일부를 차단 (프롬프트 → input attention 차단).
  • Encoder-Decoder LM (FLAN-T5-XXL):
    • Encoder에서 input \(x\)가 프롬프트 \(t\)에 attend 불가
    • Decoder cross-attention에서 \(t\)에 attend 불가
  • 출력: \(y\)가 프롬프트 \(t\) 대신 Gist tokens \(g\)를 통해 간접적으로 정보 획득.

핵심은 어텐션 마스크를 수정하여 프롬프트 \(t\)가 직접적으로 downstream input/output에 영향을 주지 못하게 막는 것이다. 이렇게 하면 LM은 프롬프트 정보를 Gist tokens에 반드시 압축해야 하며, 이후 \(x\)와 \(y\)는 Gist tokens을 통해서만 \(t\)의 의미를 간접적으로 활용한다. 이는 단순하지만 강력한 방식으로, 불과 수 줄의 코드 변경만으로 기존 LM 훈련 과정에 통합될 수 있다.



Experiments

Experiment Setup

Dataset

  • 실험에는 저자들이 만든 Alpaca+ 데이터셋을 사용하였다
  • Alpaca+는 Self-InstructStanford Alpaca를 합쳐 구성되었으며, 총 130,321개 예시를 포함한다.
    • 104,664개의 고유 task 프롬프트s \(t\)
    • 48,530개의 unique inputs \(x\)
    • 평균적으로 task당 0.64개의 input (즉, 약 59%의 task는 input이 없음)
  • Validation split은 세 가지로 나뉜다:
    • Seen: 학습 시 본 적 있는 프롬프트지만 input은 새로움 (1000개)
    • Unseen: 학습 데이터에 없는 프롬프트 (1000개)
    • Human: 사람이 직접 작성한 252개의 프롬프트 (OOD, Out-of-distribution 평가용)

Evaluation Metric

  • ROUGE-L: 텍스트 겹침 기반 지표, lexical overlap 측정
  • ChatGPT-3.5 평가: 두 모델 응답을 비교해 어느 쪽이 더 좋은지 투표 (win rate 측정)
  • Human 평가: Prolific annotators 3명이 모델 응답을 평가, Cohen’s κ로 annotator agreement 분석

Main Results 1. Sinle Gist Token Performance

1

  • Seen split:
    • LLaMA-7B Gist: ROUGE-L 57.8 (vs Positive Control 58.0), ChatGPT winrate 48.6%
    • FLAN-T5-XXL Gist: ROUGE-L 48.9 (vs Positive Control 50.6), ChatGPT winrate 50.8%
    • → 즉, 압축 후에도 성능 손실이 거의 없음.
  • Unseen split:
    • 성능 약간 하락 (LLaMA 49.7%, FLAN-T5 46.2% winrate)
    • 하지만 여전히 Positive Control과 비슷한 수준 유지.
  • Human split (OOD):
    • LLaMA: winrate 45.8% (약간 성능 저하)
    • FLAN-T5: winrate 42.5% (더 큰 성능 저하)
    • TF-IDF 기반 단순 키워드 압축은 거의 Negative Control 수준으로 성능 저조.

Table 1은 Gist 모델이 다양한 평가 split에서 Positive Control과 비교했을 때 얼마나 성능을 유지하는지를 보여준다. 특히 Seen split에서는 사실상 동일한 성능을 기록했고, Unseen split에서도 큰 성능 저하 없이 일반화가 가능함을 보였다. Human split에서는 성능 저하가 다소 두드러졌지만, 여전히 Negative Control이나 TF-IDF 기반 압축보다는 훨씬 우수했다. 이는 gist tokens이 실제로 프롬프트의 핵심 의미를 효과적으로 압축하고, 최소한의 정보 손실만으로도 원래 모델과 유사한 품질의 출력을 생성할 수 있음을 입증한다.

Main Results 2. Human Evaluation

1

  • Human Evaluation:
    • LLaMA gist: 평균 winrate 52.3% (Positive Control과 사실상 동등)
    • FLAN-T5 gist: 평균 winrate 40.6% (Positive Control보다 낮음)
  • Cohen’s \(\mathcal{K}\) (annotator agreement):
    • LLaMA: 0.24 (낮지만 일관성 존재)
    • FLAN-T5: 0.33
  • ChatGPT와 Human 평가 결과가 매우 유사 → ChatGPT 평가가 Human 평가의 대체 가능성을 보여줌.

Table 2는 Human annotator들이 직접 평가한 결과를 통해 Gist 모델의 품질을 검증한 것이다. LLaMA-7B는 Positive Control과 유사하거나 오히려 조금 더 나은 평가를 받았으며, FLAN-T5는 상대적으로 낮은 성능을 보였다. 그러나 전반적으로 ChatGPT 기반 자동 평가와 Human 평가 간 결과 일치도가 높아, ChatGPT가 저비용 대체 지표로 활용 가능함을 확인할 수 있다. 이 결과는 Gist 모델이 단순 자동화 지표뿐 아니라 실제 인간 평가에서도 경쟁력 있는 성능을 유지함을 보여준다.

Number of Gist Tokens

1

1

  • Gist token 개수 \(k \in {1, 2, 5, 10}\) 변화에 따른 성능 확인
  • 결과:
    • 성능은 \(k=1\)일 때와 \(k=5,10\)일 때 거의 동일.
    • LLaMA-7B의 경우 \(k=10\)에서는 성능 저하 (overfitting).
    • Compression Factor는 평균 프롬프트 길이 대비 \(k\) 토큰 비율로 계산됨. 예: Human split에서 26x 압축 가능.

Figure 3은 gist token 수를 늘렸을 때 성능 변화가 어떻게 나타나는지를 보여준다. 흥미롭게도 \(k=1\)일 때 이미 Positive Control과 동등한 수준의 성능을 기록하며, \(k\)를 더 늘리더라도 별다른 성능 향상이 없다. 오히려 token 수가 많아질수록 학습 분포에 과적합되는 경향이 나타나면서 성능이 하락한다. 이는 Gisting의 핵심 장점이자 효율성을 잘 보여주는 결과로, 불필요하게 많은 토큰을 쓸 필요 없이 최소한의 토큰만으로도 프롬프트 정보를 효과적으로 압축할 수 있음을 입증한다.

Gist Efficiency Improvements

1

  • 실험 조건: CUDA wall time(ms), GFLOPs를 prompt caching 전략별로 측정
  • LLaMA-7B:
    • No caching vs Gist caching → FLOPs 40% 감소, wall time 6.8% 감소.
    • Instruction caching vs Gist caching → FLOPs 이득은 0.11%에 불과, wall time 1% 단축.
  • FLAN-T5-XXL:
    • No caching vs Gist caching → FLOPs 40% 감소, wall time 4.2% 감소.
    • Instruction caching 불가능(encoder 구조 때문에).

Table 3은 실제 연산 효율성 측면에서 Gisting이 제공하는 이점을 구체적으로 수치로 보여준다. FLOPs는 크게 줄었지만 실제 실행 시간인 wall time의 단축 효과는 상대적으로 작게 나타난다. 이는 Transformer의 연산 병목이 단순 연산량이 아니라 메모리 접근에 의해 지배되기 때문이다. 그러나 중요한 점은 gist caching이 instruction caching보다 훨씬 효율적이며, 특히 encoder-decoder 모델처럼 instruction caching이 불가능한 경우에도 적용할 수 있다는 것이다. 따라서 Gisting은 단순한 성능 개선뿐 아니라 모델 구조 전반에서 활용 가능한 범용적 효율화 방법이라는 점에서 의의가 크다.



Conclusion

Contribution

  • Instruction finetuning 과정에서 attention mask만 수정하여 gist tokens을 학습하는 간단하면서도 비용 없는 방법을 제안함
  • Gist 모델이 unseen OOD prompts도 최대 26배까지 압축하면서도 출력 품질을 유지할 수 있음을 보임
  • 최대 40% FLOPs 감소, 4.2% wall time 단축, 그리고 기존 prompt caching 대비 26배 더 많은 프롬프트 캐싱 가능을 실험적으로 입증함
  • Gisting은 instruction finetuning의 변형일 뿐 아니라, LM의 meta-context distillation 방법으로 해석될 수 있음을 제시함

Limitations

  • Prompt를 압축하는 과정에서 원래 instruction의 nuance가 손실될 수 있으며, 세부 구문이나 특정 표현을 정확히 보존하지 못하는 경우 발생
  • Encoder-decoder 구조(T5)에서는 gist masking이 bidirectional encoder의 attention 흐름을 제한하기 때문에 OOD 성능이 Decoder-only 구조(LLaMA)에 비해 더 크게 저하됨
  • FLOPs는 크게 줄지만, 실제 wall time latency 개선은 제한적으로, 특히 Decoder-only LM에서는 효과가 작음
  • LM의 edge cases에 대한 동작이 본래도 잘 이해되지 않는데, gist 모델에서는 압축으로 인해 예측 불안정성이 추가될 수 있어 실제 배포 시 안전성 검증이 필요함

NR 카테고리 내 다른 글 보러가기

댓글 남기기