[논문리뷰]GritLM: Generative Representational Instruction Tuning (ICLR, 2025)
카테고리: NR
Niklas Muennighoff, Hongjin Su, Liang Wang, Nan Yang, Furu Wei, Tao Yu, Amanpreet Singh, and Douwe Kiela. 2025. Generative Representational Instruction Tuning. arXiv:2402.09906 [cs.CL] https://arxiv.org/abs/2402.09906
Problem Statement
이 논문의 주요 목표는 텍스트 기반 문제를 모두 포괄할 수 있는 단일 LLM을 구축하는 것이다. 기존 연구들은 생성(generation)과 임베딩(embedding) 중 하나에만 특화되어 있었으나, 현실적 응용에서는 두 능력이 동시에 요구된다. 예를 들어, RAG에서는 검색(retrieval)과 생성(generation)을 모두 수행해야 하며, 현재는 이를 위해 별도의 모델이 필요하다.
- [생성 기반 접근의 한계] LLM의 hidden state를 임베딩 표현으로 사용하는 것은 쉽지만, Generation이 목적인 LLM을 임베딩 태스크 (e.g., retrieval, clustering)에 맞추려면 Contrastive learning과 같은 별도의 학습이 필요하지만, 이 과정에서 파라미터가 바뀌어 “좋은 생성 분포를 유지하는 것”과 “좋은 임베딩을 만드는 것”이 충돌함. 그 결과 임베딩 성능은 올라가지만, 생성 성능은 떨어지는 문제가 발생.
- [임베딩·생성 모델 분리로 인한 성능 및 효율성 저하] 현재는 임베딩 전용 모델과 생성 전용 모델이 각각 존재하여 한쪽 작업에 특화될 뿐 다른 작업에는 약하다. 따라서 두 모델을 함께 사용해야 하고, 이는 성능 저하뿐 아니라 중복 연산으로 인해 latency와 비용이 커진다.
- [복잡한 시스템 구조의 비효율성] Retrieval-Augmented Generation(RAG) 같은 응용에서는 검색 모델과 생성 모델을 별도로 사용해야 하므로 쿼리와 문맥을 두 번씩 전달해야 하고, API 제공업체도 생성·임베딩 엔드포인트를 각각 운영해야 한다. 이로 인해 인프라 구조가 복잡해지고 운영 효율성이 떨어진다.
GRIT은 이를 해결하기 i) Dual Instruction Tuning, ii) Unifying Models, iii) Query/Doc Caching in RAG 세 가지 방법을 도입하였다. 위의 그림은 GritLM이 다른 baseline들 대비 생성과 임베딩 활용에서 모두 우수한 성능을 보여주었음을 나타낸다.
Methodology
Overview
Figure 2는 GRIT에서 하나의 모델이 생성과 임베딩이라는 서로 다른 두 태스크를 어떻게 구분하는지를 보여주는 구조도이다. 입력 포맷에는 특별한 토큰을 사용하여 태스크를 지정하는데, 생성 태스크의 경우 <|assistant|>
토큰을 포함시켜 모델이 Causal Attention과 LM Head를 통해 연속적인 텍스트를 출력하도록 한다. 반면 임베딩 태스크에서는 <|embed|>
토큰을 추가해 입력 전체를 Bidirectional Attention으로 인코딩한 뒤 Mean Pooling을 적용하여 고정 차원의 벡터를 얻는다. 이를 통해 동일한 언어 모델이 instruction에 따라 텍스트 생성 모드와 임베딩 표현 학습 모드를 전환할 수 있으며, 결국 GRIT은 이 단일 아키텍처 안에서 두 가지 기능을 모두 효과적으로 학습하도록 설계되었음을 Figure 2가 시각적으로 보여준다.
Figure 3는 GRIT의 전체 학습 프레임워크를 개괄적으로 보여준다. GRIT은 입력 instruction과 special token을 통해 생성 태스크와 임베딩 태스크를 구분하고, 각 태스크별로 다른 학습 경로를 갖는다. 생성 태스크는 Causal Attention 기반의 LM Head를 통해 텍스트 시퀀스를 예측하며, 임베딩 태스크는 Bidirectional Attention과 Mean Pooling을 통해 표현 벡터를 추출한다. 이 두 학습 과정은 각각 Language Modeling Loss와 Contrastive Representation Loss로 최적화되며, 최종적으로는 이 두 손실을 가중합한 GRIT Loss로 모델을 학습한다.
Embedding Task
- 입력:
<s><|user|>{instruction}<|embed|>{sample}</s>
- 출력: Mean Pooling을 통해 얻은 고차원 벡터 (4096차원)
모델은 입력된 context를 bidirectional attention을 사용하여 시퀀스 전체 맥락을 고려하고, 최종 hidden state에 mean pooling을 적용하여 임베딩 표현을 생성한다.
- \(q^{(i)}\): \(i\)번째 query
- \(d^{(i)}\): \(i\)번째 positive document
- \(d^{(j)}\): in-batch negatives
- \(f_\theta\): encoder 함수
- \(\sigma\): cosine similarity
- \(\tau\): temperature scaling
Embedding task 학습을 위해서는 contrastive loss를 사용하며, in-batch내의 다른 샘플들을 negative로 활용한다 (in-batch negative). 모델 파라미터 \(\theta\)로 매개변수화된 GritLM 모델 \(f\)에 대해, \(q\)는 query이고 \(d\)는 document이다. \(\sigma\) 는 pooling 연산 후 코사인 유사도를 나타내고, 명령어와 포맷 토큰은 무시하고 입력 샘플의 final hidden state만 평균화한다.
Generative Task
- 입력:
<s><∣user∣>instruction<∣assistant∣>response</s>
- 출력: 토큰 시퀀스 \({x^{(1)}, \dots, x^{(N)}}\)
모델은 입력에 대해 Causal Attention을 사용하고, LM Head(Language Modeling Head, 언어 모델링 헤드)를 사용하여 순차적으로 텍스트를 생성한다.
- \(x^{(i)}\): \(i\)번째 토큰
- \(f_{\theta,\eta}\): LLM backbone과 LM Head
생성 데이터에 대해서는 Next token prediction을 위한 language modeling objective을 사용하며, 모델이 다음 토큰을 예측하도록 학습한다. 모델 파라미터 \(\theta\)와 언어 모델링 헤드 \(\eta\)로 매개변수화된 GritLM 모델 \(f\)에 대해 \(i\)번째 토큰 \(x_i\)는 이전 토큰들의 누적 확률분포에 기반하여 다음 토큰을 예측하게 된다.
Training Objective
최종적으로 두 손실함수를 가중합하여 training objective로 활용한다.
여기서 \(\lambda_{\text{Rep}} > \lambda_{\text{Gen}}\)으로 설정하여, 임베딩 학습에 더 큰 비중을 둔다.
Inference
GRIT은 학습 단계에서 생성과 임베딩을 단일 모델로 통합했기 때문에, 추론 단계에서는 효율적인 RAG 실행을 가능하게 한다. 전통적인 RAG에서는 쿼리와 문서를 각각 별도의 모델로 인코딩해야 하므로, 질의 \(q\)와 문서 집합 \({d_1, d_2, \dots, d_n}\)이 주어졌을 때 총 네 번의 forward pass가 필요하다. 이로 인해 latency와 비용이 증가하며, 실제 응용에서 비효율성이 크다.
이를 해결하기 위해 GRIT은 Query Caching과 Doc Caching이라는 두 가지 최적화 기법을 도입한다. Query Caching은 동일한 질의 \(q\)에 대해 한 번만 임베딩 \(f_\theta(q)\)를 계산하고, 이후 반복적으로 재사용하는 방식이다. 반대로 Doc Caching은 문서 집합 \({d_i}\)를 미리 임베딩 벡터 \(f_\theta(d_i)\)로 변환해 캐시에 저장하고, 추론 시 즉시 불러오는 방식이다. 이 접근은 특히 긴 문서 처리에서 큰 장점을 가지며, 문서 인코딩의 중복 계산을 제거해 latency를 크게 줄인다.
Experiments
Main Result 1. Embedding Peformance
임베딩 성능은 MTEB 벤치마크에서 평가되었다. GritLM-7B는 평균 66.8점을 기록하여 오픈 모델 중 새로운 SOTA 성능을 달성하였다. 이는 기존의 대규모 언어 모델인 LLaMA-2-70B가 35.6점에 불과한 것과 비교할 때 매우 큰 격차를 보여주며, 임베딩 전용 모델인 E5-Mistral-7B의 66.6과 동등한 수준이다. 즉, GRIT은 임베딩 학습을 위해 설계된 모델과 비슷한 성능을 달성하면서도 동시에 생성 능력을 유지한다는 점에서 큰 의미를 가진다.
Main Result 2. Generation Performance
생성 성능은 MMLU, GSM8K, BBH, TyDi QA, HumanEval, AlpacaEval 등 다양한 벤치마크에서 측정되었다. GritLM-7B는 평균 55.5점을 기록하여 동급 크기의 오픈모델들보다 성능이 뛰어났으며, GritLM-8x7B는 평균 65.7점을 기록하여 공개된 생성 모델 중 가장 높은 성능을 달성하였다. 특히 LLaMA-2-70B의 46.4점보다도 월등히 높은 결과를 보임으로써, GRIT이 단일 모델임에도 불구하고 파라미터 효율성과 생성 능력 모두에서 강점을 지닌다는 사실이 입증되었다.
Ablation Study
Attention 메커니즘과 Pooling 방식의 영향을 분석한 결과, 임베딩 태스크에서는 Bidirectional Attention과 Mean Pooling을 조합했을 때 가장 우수한 성능(평균 64.0)을 보였다. 반대로 Causal Attention이나 Last Token Pooling을 사용한 경우 성능이 떨어졌다. 이는 임베딩 표현 학습에서 양방향 맥락 정보와 평균 기반 집계가 중요한 역할을 함을 보여준다.
Analysis 1. Reranking
GritLM은 Cross-Encoder와 Bi-Encoder 모두로 활용할 수 있음을 확인하였다. Retrieval 단계에서 Top-10 결과에 대해 reranking을 적용한 경우, 성능이 57.4에서 57.9로 소폭 향상되었다. 이는 GRIT이 단순한 임베딩 생성 모델에 그치지 않고, reranker로서도 적용 가능한 유연성을 가진다는 것을 보여준다.
Analysis 2. RAG Optimization (Query/Doc Caching)
Table 7은 GRIT을 활용한 Retrieval-Augmented Generation(RAG) 환경에서 캐싱 기법의 효과를 분석한 결과를 담고 있다. 전통적인 RAG 구조에서는 질의(query)와 문서(document)가 각각 모델에 두 번씩 전달되어 총 네 번의 forward pass가 필요하기 때문에 지연(latency)과 비용이 크게 증가한다. 이를 해결하기 위해 저자들은 Query Caching과 Doc Caching을 제안하였다. Query Caching은 질의 인코딩을 한 번만 수행한 뒤 재활용하는 방식이고, Doc Caching은 문서 인코딩을 캐시하여 재사용하는 방식이다.
실험 결과, 이러한 캐싱 기법을 적용하면 RAG의 추론 속도가 크게 향상되었다. CPU 환경에서는 최대 63%, GPU 환경에서는 최대 33%의 속도 개선이 이루어졌으며, 특히 4000 토큰 이상의 긴 문서를 처리할 때 latency 감소 효과가 두드러졌다. 중요한 점은 이러한 최적화에도 불구하고 RAG의 성능(Match score)이 유지되거나 일부 데이터셋에서는 오히려 약간 개선되었다는 사실이다. 즉, Query/Doc Caching은 단순히 효율성을 높이는 데 그치지 않고 실제 성능을 손상시키지 않으며, 대규모 문서 검색이 필요한 응용에서 실질적인 이점을 제공한다.
댓글 남기기