[논문리뷰]DyPRAG: Dynamic Parametric Retrieval Augmented Generation for Test-time Knowledge Enhancement
카테고리: NR
Yuqiao Tan, Shizhu He, Huanxuan Liao, Jun Zhao, and Kang Liu. “Dynamic Parametric Retrieval Augmented Generation for Test-time Knowledge Enhancement.” In arXiv preprint arXiv:2503.23895 [cs.CL], March.
Problem Statement
- Standard RAG는 검색된 문서들을 입력 컨텍스트에 추가하는 방식인 in-context injection을 사용하는데, 이는 문서의 수와 길이가 증가함에 따라 추론 비용이 크게 증가한다는 문제가 있다.
- 외부 컨텍스트가 LLM의 내부 지식과 충돌할 때 RAG 환각 (= knowledge conflict)이 발생하여, 정확한 문서가 제공되더라도 LLM이 잘못된 답변을 생성하는 문제가 있다.
- 기존의 Parametric RAG는 문서를 LLM 파라미터에 직접 통합하는 방식(parameter injection)으로 추론 비용을 줄이지만, 각 문서마다 별도의 훈련과 저장이 필요하여 높은 훈련 비용과 저장 비용이 발생하며, 일반화 능력이 제한적이다.
Methodology
DyPRAG는 문서를 파라미터로 동적으로 변환하는 프레임워크이다. 전체 프로세스는 오프라인 단계와 온라인 단계로 구성된다.
Stage1: Doc-Param Pairs Collections
Document Augmentation
- 입력: 원본 문서 \(d_i\)
- 출력: 증강된 문서 집합 \(D_i = {(d_i^k, q_i^j, a_i^j)}\)
이 단계에서는 각 문서의 정보를 LLM이 더 잘 기억하고 조작할 수 있도록 증강한다. LLM을 사용하여 원본 문서 \(d_i\)를 \(n\)개의 다양한 언어적 변형으로 다시 작성하여 \({d_i^1, d_i^2, ..., d_i^n}\)을 생성한다. 또한 각 원본 문서에 대해 \(m\)개의 질문-정답 쌍을 생성한다. 이렇게 생성된 증강 문서들은 원본 문서의 사실적 내용을 보존하면서도 다양한 언어적 변형을 포함하게 된다. \(n\), \(m\)은 하이퍼파라미터이다.
Document Parameterization
- 입력: 증강된 문서 집합 \(D_i\)
- 출력: Fine-tuned LoRA 파라미터 \(p_i\)
증강된 문서들을 LoRA (Low-Rank Adaptation) 방식을 통해 파라미터로 인코딩한다. 각 증강 문서 트리플 \((d_i^k, q_i^j, a_i^j)\)을 \([d_i^k ⊕ q_i^j ⊕ a_i^j]\) 형태로 연결하여 훈련 샘플을 만든다.
증강된 문서들을 LoRA (Low-Rank Adaptation) 방식을 통해 파라미터로 인코딩한다. 위의 식을 목적 함수로 하고, \(\Delta \Theta\)는 학습 가능한 low-rank matrix이고, \(\Theta\)는 LLM의 고정된 파라미터이다. LoRA는 feed-forward network (FFN)에만 적용되며, 이를 통해 문서의 지식을 LLM의 파라미터에 직접 인코딩한다. 이 과정을 parameter injection이라고 부르며, 이를 통해 각 문서 \(d_i\)는 파라미터 표현 \(p_i = F(d_i)\)로 변환된다. \(F\)는 underlying mapping function이다.
Stage2: DyPRAG Training
Document Encoding
- 입력: 문서 \(d_i\)
- 출력: 문서 임베딩 \(s_i ∈ \mathbb{R}^h\)
Parameter Translator를 학습하기 위해 먼저 원본 LLM M을 사용하여 문서를 인코딩한다. 문서 \(d_i\)를 LLM에 입력하고, vocabulary space로 변환하기 전의 마지막 토큰 위치에서 last hidden state \(s_i\)를 추출한다. 여기서 \(h\)는 LLM의 hidden dimension을 나타낸다. 이 임베딩은 문서의 의미적 정보를 압축하여 담고 있다.
Parameter Translation
- 입력: 문서 임베딩 \(s_i\)와 레이어 인덱스 \(idx^l\)
- 출력: Dynamic LoRA 파라미터 \((B^l, A^l)\)
Parameter Translator \(F'_{\phi}\)는 문서 임베딩을 LoRA 파라미터로 변환하는 경량 하이퍼네트워크이다. 여러 개의 선형 레이어로 구성되며, 베이스 파라미터 \(\phi\)로 매개변수화된다. FFN의 각 모듈(up-proj, down-proj, gate-proj)에 대해 별도의 translator를 구성한다. 각 레이어 l에서 이 프로세스가 적용되므로, 문서 임베딩 \(s_i\)와 레이어 인덱스를 연결하여 입력으로 사용한다.
표준 LoRA는 가중치를 \(W + BA\)로 분해하는데, 여기서 \(B \in \mathbb{R}^{h \times r}\), \(A \in \mathbb{R}^{h \times k}\)이다. \(r\)과 \(k\)는 각각 LoRA rank와 FFN 레이어의 중간 차원이다. Parameter Translator는 이 B와 A행렬을 생성한다. 예를 들어, FFN의 up-project 모듈에서 B 행렬은 아래와 같다.
이 때, \(W_{down}^{l, B} \in \mathbb{R}^{p \times (h+1)}\), \(W_{up}^{l, B} \in \mathbb{R}^{hr \times p}\)이고, \(p\)는 중간 차원으로 하이퍼 파라미터이다(tunable intermediate dimension).
Training Objectives
- 입력: Doc-Param 쌍 \((d_i, p_i)\)의 집합 \(K\)
- 출력: 학습된 Parameter Translator \(F'_{\phi}\)
세 가지 손실 함수를 조합하여 \(F'_{\phi}\)를 학습한다. \(\mathcal{L_{\text{pred}}}\)는 증강된 데이터셋을 사용한 표준 언어 모델링 손실로, 생성된 파라미터가 올바른 답변을 생성하도록 한다. \(\mathcal{L_{\text{mse}}}\)는 생성된 파라미터와 타겟 파라미터 간의 평균 제곱 오차로, 파라미터 수준에서의 정렬을 보장한다. \(\mathcal{L_{\text{kl}}}\)은 두 모델의 단어 확률 분포 간의 KL divergence로, 생성된 모델이 타겟 모델의 동작을 모방하도록 한다. 이 세 손실을 가중합하여 최종 alignment loss L_align을 구성한다.
Stage3: DyPRAG Inference
Document Retrieval and Encoding
- 입력: 테스트 질문 \(q^t\)
- 출력: 검색된 문서의 임베딩 \({s_i^t}\)
추론 시에는 먼저 retrieval module \(R\)을 사용하여 질문과 가장 관련성 높은 \(c\)개의 문서를 검색한다. 각 검색된 문서 \(d_i^t\)에 대해 Stage 2와 동일한 방식으로 임베딩 \(s_i^t\)를 추출한다.
Dynamic Parameter Generation
- 입력: 문서 임베딩 \(s_i^t\)
- 출력: Dynamic LoRA adapter \(p_i^{(t,')}\)
학습된 Parameter Translator \(F'_{\phi}\)에 각 문서 임베딩을 입력하여 dynamic LoRA adapter를 생성한다. 이 adapter는 해당 문서의 관련 정보를 parameter modality로 인코딩한다. 생성된 LoRA 파라미터들을 병합하여 추론에 사용함으로써, 긴 컨텍스트 없이도 낮은 추론 비용을 달성한다.
Hybrid Knowledge Injection (DyPRAG-Combine)
- 입력: 검색된 문서들과 질문
- 출력: 통합된 지식을 활용한 최종 답변
DyPRAG-Combine은 parameter injection과 in-context injection을 결합한 방식이다. 먼저 검색된 문서들을 Parameter Translator를 통해 dynamic LoRA로 변환하여 LLM의 파라미터에 주입한다. 이를 통해 LLM은 문서와 관련된 parametric knowledge를 내재화한다. 그 다음 동일한 문서들을 전통적인 RAG 방식처럼 입력 컨텍스트에도 추가한다. 이 접근법은 parametric knowledge와 contextual knowledge의 상호보완적 융합을 가능하게 하며, 특히 knowledge conflict를 완화하는 데 효과적이다. 실험 결과 이 방식이 가장 우수한 성능을 보였으며, RAG hallucination 문제를 효과적으로 해결했다.
Experiments
Experiment Settings
실험은 네 가지 QA 데이터셋에서 수행되었다.
- Multi-hop QA: 2WikiMultihopQA (2WQA), HotpotQA (HQA)
- KGQA: PopQA (PQA), ComplexWebQuestions (CWQ).
평가 지표로는 F1 score (%)를 사용했다. 베이스 모델로는 Qwen2.5-1.5B-Instruct, LLaMA-3.2-1B-Instruct, LLaMA3-8B-Instruct를 선택했다. Doc-Param pairs 수집을 위해 각 하위 데이터셋에서 중복되지 않는 200개의 추가 질문을 사용했으며, 검색 문서 수 \(c\)는 3으로 설정했다. 이를 통해 총 4,800개의 alignment set \(K\)를 구성했다. Parameter Translator의 intermediate size \(p\)는 32로 설정했다.
Main Results (QA Performance)
DyPRAG는 다양한 모델과 데이터셋에서 일관되게 우수한 성능을 보였다. LLaMA3.2-1B 모델에서 DyPRAG는 평균 27.57%의 F1 점수를 달성하여 PRAG (26.51%)를 1.06%, standard RAG (26.99%)를 0.58%, vanilla model (22.39%)을 5.18% 능가했다. 특히 주목할 만한 성과는 2WQA의 bridge 하위 작업에서 나타났는데, DyPRAG가 48.15%를 달성하여 RAG보다 21.37%, PRAG보다 23.81% 높은 성능을 보였다. 이는 DyPRAG가 다양한 데이터셋에서 학습할 때 더 유용한 정보를 학습한다는 것을 보여준다.
DyPRAG-Combine은 모든 모델에서 최고 성능을 달성했다. LLaMA3.2-1B에서 평균 31.80%로 PRAG-Combine (29.94%)을 1.86% 능가했으며, Qwen2.5-1.5B에서는 27.60%로 PRAG-Combine (27.05%)을 0.55%, LLaMA3-8B에서는 43.69%로 PRAG-Combine (42.61%)을 1.08% 능가했다. 특히 LLaMA3-8B의 2WQA total에서 50.24%라는 높은 성능을 달성했는데, 이는 vanilla model (33.02%)보다 17.22% 향상된 결과이다.
Out-of-Distribution Performance
DyPRAG의 일반화 능력을 평가하기 위해 StrategyQA (SQA)와 IIRC 데이터셋에서 out-of-distribution (OOD) 성능을 측정했다. Vanilla model은 관련 지식 부족으로 인해 매우 낮은 성능을 보였다 (예: Qwen2.5-1.5B에서 SQA 1.00%, IIRC 8.78%). 반면 DyPRAG는 parametric knowledge를 성공적으로 주입하여 SQA에서 비교 가능한 성능을 달성했다.
DyPRAG-Combine은 모든 시나리오에서 최고 성능을 보였다. LLaMA3.2-1B에서 SQA 50.33%를 달성하여 standard RAG (27.67%)보다 22.66% 향상되었고, IIRC에서는 41.91%로 standard RAG (40.38%)보다 향상되었다. LLaMA3-8B에서는 IIRC에서 57.90%를 달성하여 standard RAG (43.27%)보다 13.63% 높은 성능을 보였다. 이러한 결과는 DyPRAG가 PRAG와 달리 추가적인 오프라인 훈련 없이도 OOD 시나리오를 처리할 수 있는 뛰어난 일반화 능력을 가지고 있음을 보여준다.
Ablation Study
Effect of Intermediate Dimension \(p\)
Parameter Translator의 중간 차원 p가 성능과 비용에 미치는 영향을 분석했다. 놀랍게도 p=2에서도 DyPRAG는 32.66%의 F1 점수로 standard RAG (28.32%)와 PRAG (30.82%)를 능가했다. p=4에서 33.26%로 최고 성능을 달성했으며, 이후 p가 증가해도 성능 향상은 미미했다. 저장 비용 측면에서 p=2일 때 7.71MB로 PRAG (672MB)의 0.011%에 불과했다. 추론 시간은 0.625초 (encode: 0.13s, translate: 0.060s 포함)로 standard RAG (1.20초)보다 빠르고 PRAG (0.56초)와 비슷한 수준이었다.
Effect of Training Dataset Size
Doc-Param pairs의 크기를 480에서 4,800까지 변경하며 실험했다. DyPRAG는 단 480개의 훈련 예제로도 강력한 성능을 달성했다. LLaMA3.2-1B에서 2WQA는 480개로 0.25, 4,800개로 0.30의 F1 점수를 보였으며, 대부분의 데이터셋에서 안정적인 성능을 유지했다. 이는 Parameter Translator가 최소한의 데이터로도 문서와 파라미터 간의 underlying mapping을 효과적으로 학습할 수 있음을 보여준다.
Effect of Alignment Loss
세 가지 손실 함수 구성 요소의 기여도를 분석했다. \(\mathcal{L}_{kl}\)을 제거했을 때 성능이 평균 25.28%에서 18.68%로 크게 하락했는데, 이는 타겟 모델의 출력 분포와 정렬하는 것이 효과적인 전략임을 보여준다. \(\mathcal{L}_{mse}\)만 제거했을 때는 23.38%로 상대적으로 작은 하락을 보였지만, \(F'_\phi\)가 생성하는 \(p'\) 값이 훈련된 \(p\)에 최대한 가깝도록 하는 것이 여전히 유익함을 확인했다. \(\mathcal{L}_{pred}\)만 사용했을 때도 (w/o \(\mathcal{L}_{kl}. \mathcal{L}_{mse}\)) 22.54%로 안정적인 성능을 유지하여, 표준 언어 모델링 손실이 전체 정렬 과정에서 중심적인 역할을 한다는 것을 보여준다.
Analysis
Knowledge Conflict and Fusion
DyPRAG-Combine이 contextual knowledge와 parametric knowledge 간의 충돌을 어떻게 해결하는지 분석했다. Table 4의 사례에서 vanilla LLM은 잘못된 parametric knowledge를 가지고 있었고, standard RAG와 DyPRAG는 모두 hallucination으로 인해 틀린 답을 생성했다. 반면 DyPRAG-Combine은 contextual knowledge와 변환된 parametric knowledge를 효과적으로 통합하여 정답을 생성했다. RAGTruth 벤치마크에서 GPT-4o를 사용한 평가 결과 (Figure 3), DyPRAG-Combine이 standard RAG보다 knowledge internalization에서 크게 우수했다. Qwen-1.5B에서 89% win rate, LLaMA-1B에서 85% win rate, LLaMA-8B에서 47% win rate를 달성했다.
Response Length Comparison
DyPRAG-Combine은 응답 길이를 크게 감소시켰다. Qwen2.5-1.5B 모델에서 2WQA의 경우 RAG (245.40 tokens)에서 DyPRAG-Combine (196.23 tokens)로 20% 감소했고, CWQ에서는 RAG (212.14 tokens)에서 DyPRAG-Combine (22.59 tokens)로 약 90% 감소했다. 이는 DyPRAG-Combine이 더 적은 토큰으로 정답을 생성할 수 있어 추론 비용을 낮추고 중복 정보를 피할 수 있음을 보여준다.
Comparison with Effective RAG Methods
FLARE와 DRAGIN과 같은 효과적인 RAG 방법들과 비교했다. LLaMA3-8B에서 2WQA total 성능은 standard RAG 34.55%, DRAGIN 35.69%, FLARE 34.62%인 반면, DyPRAG는 37.25%를 달성했다. DyPRAG-Combine은 45.17%로 모든 방법을 크게 능가했다. 이는 DyPRAG가 test-time knowledge enhancement에서 다른 효율적인 RAG 방법들보다 우수함을 보여준다.
댓글 남기기