[논문리뷰]D2LLM:Decomposed and Distilled Large Language Models for Semantic Search(ACL, 2024)

Date:     Updated:

카테고리:

Zihan Liao, Hang Yu, Jianguo Li, Jun Wang, and Wei Zhang. 2024. D2LLM:Decomposed and Distilled Large Language Models for Semantic Search. In Proceedings of the 62nd Annual Meeting of the Association for Computational Linguistics (Volume 1: Long Papers), Lun-Wei Ku, Andre Martins, and Vivek Srikumar(Eds.). Association for Computational Linguistics, Bangkok, Thailand, 14798–14814.https://doi.org/10.18653/v1/2024.acl-long.791

Problem Statement

[BERT Style Bi-Encoder]: 쿼리와 문서에 대해 독립적으로 벡터를 생성하여 효율성이 뛰어나지만, 미묘한 의미 차이를 포착하지 못해 정확도가 떨어진다. 또한, 방대한 양의 데이터에 대한 복잡한 다단계 학습 과정이 필요하다는 한계가 존재한다.

[GPT Style Cross-Encoder]: 쿼리와 문단을 하나의 입력으로 결합하여 처리함으로써 정확도가 높다. 특히, 방대한 양의 지식을 사전 학습하여 도메인별 사전 학습이 필요 없고, 새로운 도메인에도 강건한 zero-shot 학습 능력을 보인다. 하지만, 문서의 벡터를 미리 계산할 수 없어 새로운 쿼리-문서 쌍마다 계산을 다시 수행해야하므로 실시간 처리에는 비효율적이다.



Methodology

Overview

1

D2LLM은 Cross-Encoder의 강점인 정확도와 Bi-Encoder의 효율성을 결합한 시맨틱 검색 모델이다. 먼저 LLM 기반의 Cross-Encoder를 교사 모델로 설정한다. 교사 모델은 쿼리와 문서를 하나의 입력으로 결합하고, 특별히 설계된 프롬프트(symmetric, asymmetric)를 활용하여 정교한 의미적 관계를 파악한다. 그런 다음, 효율성을 극대화하기 위해 교사 모델을 Bi-EncoderInteraction Emulation Module (IEM)을 포함하는 학생 모델(Student Model)로 분해한다. Bi-Encoder는 PMA(Pooling by Multihead Attention)를 사용하여 쿼리와 문단에 대한 개별 벡터 임베딩을 효율적으로 생성한다. 이와 동시에, IEM은 이 두 임베딩을 결합하여 교사 모델의 복잡한 상호 작용 방식을 모방한다.

학생 모델은 지식 증류(Knowledge Distillation) 과정을 통해 학습된다. 이 과정에는 교사 모델의 점수를 활용하여 관련 샘플에 가중치를 부여하는 Contrastive Imitation, 긍정 및 부정 샘플 간의 미묘한 순위 차이를 맞추는 Rank Imitation, 그리고 교사 모델의 임베딩 관계 패턴을 모방하는 Feature Imitation이 있다. 이처럼 D2LLM은 교사 모델의 지식을 흡수하며 효율성과 정확성 사이의 균형을 효과적으로 달성한다.

Teacher Model (LLM)

1

  • Input: Query-Passage Pair (\(\mathbf{X}_i, \mathbf{X}_j\)) + Prompt(\(\mathbf{P}\))
  • Output: 분류 토큰 임베딩 \(y_{ij}^{\mathcal{T}}\)

먼저 LLM의 zero-shot 학습 능력을 활용하기 위해, 프롬프트 엔지니어링을 적용해 query-passage 쌍을 분석하도록 LLM을 유도하는 특정 프롬프트를 설계한다. 이 프롬프트 \(\mathbf P \in (\mathbf P^{\text{sym}},\mathbf P^{\text{asym}})\)는 대칭(symmetric)과 비대칭(asymmetric) 검색을 위한 두 가지 프롬프트로 구성된다.

  • 대칭 검색(Symmetric Search)
    • 정의: Query와 Passage가 동일한 역할/형태를 가지며, 상호 교환이 가능할 때
    • 예시:
      • “What are the symptoms of the flu?” ↔ “What are the flu symptoms?”
      • 두 문장은 서로 같은 의미를 담고 있으며, 단순히 문장 유사도 측정(semantic similarity)이 목적임
    • 특징: Query와 Passage가 모두 질문(question) 혹은 짧은 진술문(statement)일 수 있음
    • Task 예시: NLI(Natural Language Inference), STS(Semantic Textual Similarity)
  • Asymmetric Search
    • 정의: Query와 Passage가 서로 다른 역할을 가지며, 교환 불가능할 때
    • 예시:
      • Query: “What are the symptoms of the flu?”
      • Passage: “The flu typically causes fever, cough, sore throat, runny nose, muscle aches, and fatigue.”
      • Query는 질문, Passage는 답변/설명 문서로 역할이 다름
    • 특징: Query는 짧고 정보 요청 중심, Passage는 길고 정보 제공 중심
    • Task 예시: 정보 검색(IR), QA Retrieval

LLM에 쿼리-문서 쌍과 프롬프트를 입력시키고, 프롬프트의 마지막 토큰의 hidden state embedding을 classification token embedding으로 사용한다.

$$y_{ij}^{\mathcal{T}} = \text{LLM}(\mathbf{X}_i, \mathbf{X}_j, \mathbf{P})$$

이 마지막 토큰의 히든 스테이트 임베딩 \(y_{ij}^{\mathcal T}\)가 분류 토큰으로 기능하여 쿼리-문서 쌍이 관련 있는지 여부를 표시하며, 프롬프트는 검색 유형(대칭, 비대칭)에 맞도록 적응을 돕는다.

“yes” 혹은 “no”의 확률을 계산하기 위해 LLM의 마지막 레이어의 projection weight matrix \(W^{\mathcal T} \in \mathbb R^{\vert V \vert \times d}\)를 이용한다. 이 가중치 행렬은 모든 vocabulary집합에 대한 임베딩 행렬로, 타겟으로 하는 것은 “yes”, “no” 두 개이기 때문에 실제로 사용되는 것은 \(W^{\mathcal T}[\text{“yes”, “no”}] \in \mathbb R^{2 \times d}\) 이다.

$$z_{ij}^{\mathcal T} = W^{\mathcal T}[\text{“yes”, “no”}]y_{ij}^{\mathcal T}$$
$$s_{ij}^{\mathcal T} = \text{SoftMax}(z_{ij}^{\mathcal T})$$

분류를 위해 얻었던 분류 토큰 임베딩과 가중치 행렬을 곱해줌으로써 최종적으로 “yes”, “no”에 대한 각각의 score를 얻을 수 있다. 최종적으로 score에 softmax를 취해줌으로써 “yes”, “no”에 대한 각각의 확률값을 구할 수 있다. 스코어와 확률값은 student 모델 학습 시 supervision으로 사용된다.

Student Model

1

  • Teacher와 달리 Student는 쿼리와 문서를 독립적으로 인코딩함.
  • Cross-encoder를 사용하는 Teacher와 달리 Bi-encoder 구조를 사용함.
  • 단, 구조만 쿼리와 문서를 각각 임베딩한다는 의미에서 Bi-encoder 구조인 것이지, 실제로 BERT 계열의 encoder를 사용하는 것이 아니라, LLM으로 임베딩을 추출하는 것.

Student Model은 Bi-encoder 구조 + PMA + IEM 조합으로 구성되어 있다. 먼저 쿼리와 문서의 텍스트가 bi-encoder에 각각 입력되면 각각에 대해 토큰 단위 임베딩을 얻을 수 있다. (쿼리를 기준으로 설명)

$$y_i^{\text{agg}} = \text{PMA}_q (\mathbf Y_i) = \text{LN}(h +FFN(h))$$
$$h = \text{LN}(\text{MHA}(q, \mathbf Y_i, \mathbf Y_i) + q)$$

길이가 L인 쿼리 \(\mathbf X_i =[x_i(1), \cdots, x_i(L)]\)을 LLM에 입력시키면 토큰 단위의 임베딩 집합 \(\mathbf Y_i = [y_i(1), \cdots, y_i(L)]\)을 얻을 수 있다. PMA (Pooling by Multihead Attention)모듈은 \(\mathbf Y_i\)를 입력받아 토큰 단위 임베딩들을 aggregation한다. 이를 통해 LLM의 원래 훈력 목적(다음 토큰을 예측)과 충돌하지 않으면서도, 문장 전체의 의미를 유연하게 반영하는 임베딩을 생성할 수 있다. 위의 수식에서 MHA는 멀티 헤드 어텐션을 의미하고, LN은 Layer Norm, \(q\)는 PMA의 쿼리를 의미한다. \(q\)는 learnable vector로써, 앵커로서 토큰과의 유사도에 따라 가중합을 학습적으로 정한다.

$$y_{ij}^{\mathcal S} = f_2(f_1([y_i^{\text{agg}}, y_j^{\text{agg}}]))$$

IEM (Interation Emulation Module)은 PMA를 통해 생성된 쿼리와 문서의 개별 임베딩 벡터를 입력 받는다. 이 임베딩들을 병합한 후 선형 레이어에 입력하여 쿼리와 문서 간의 관계를 모델링한다. 기존의 Bi-Encoder는 쿼리와 문서 임베딩 간의 코사인 유사도를 사용하여 관련성을 판단하는 반면, IEM은 MLP를 통해 이러한 관계를 더 유연하고 정교하게 포착한다. 특히, IEM은 대칭형 검색(e.g., NLI, STS)과 비대칭형 검색(e.g., IR)의 미묘한 차이를 처리하기 위해 두 개의 전용 브랜치를 가지고 있다.

요약하자면, IEM은 쿼리와 문서 임베딩을 독립적으로 생성하여 효율성을 유지하면서도, Cross-Encoder처럼 두 입력 간의 복잡한 상호작용을 에뮬레이션함으로써 정확도를 높이는 역할을 수행한다. 이는 Bi-Encoder의 효율성과 Cross-Encoder의 정확도라는 두 가지 장점을 모두 결합하려는 시도이다.

Training (Knowledge Distillation)

Contrastive Imitation Loss

대조 학습을 위해서는 negative 문서를 정의해야하고, 논문에서는 좀 더 어려운 “하드 네거티브”를 사용하기 위해 BM25를 기반으로 \(k\)개의 문서를 검색해 negative로 활용한다. Contrastive Imitation Loss는 InfoNCE를 기반으로 하며 수식은 다음과 같다.

$$\mathcal L^{CL} = -\frac{1}{|\mathbb D^+|}\sum_{j\in \mathbb D^+}\log\frac{\exp (s_{ij}^{\mathcal T}z_{ij}^{\mathcal S}/\tau)}{\sum_{k \in \mathbb D^-}\exp((1-s_{ik}^{\mathcal T})z_{ik}^{\mathcal S}/\tau)}$$

이 때, \(s_{ij}^{\mathcal T}\)는 쿼리-문서 쌍에 대한 교사 모델의 “yes” 확률값이고, \(z_{ij}^{\mathcal S}\)는 학생 모델의 해당 쌍에 대한 logit이다.

[긍정 샘플 가중치 부여] 이 손실을 통해 교사 모델이 중요하다고 판단한 긍정 샘플에 더 높은 가중치를 부여할 수 있다. \(s_{ij}^{\mathcal T}\) 값은 교사 모델이 쿼리-문서 쌍이 관련 있다고 판단할수록 커지기 때문에, 학생 모델은 해당 샘플을 더 잘 모방하도록 훈련된다.

[잠재적 정답 샘플에 대한 강건성] 부정 샘플 중 정답 문서와 일부 내용이 일치하거나 유사하지만 정답이 아닌 하드 네거티브가 존재한다. 하지만, 이는 다시 말해 잠재적 긍정 샘플이라고 말할 수 있다. 일반적인 대조학습을 위한 손실 함수는 이러한 하드 네거티브를 잘못된 부정으로 간주하여 학습에 혼란을 주지만, CL 손실의 경우 \(s_{ik}^{\mathcal T}\)점수를 부여하면, 손실 함수의 분모에 있는 (\(1-s_{ij}^{\mathcal T}\)) 값이 작아져 손실이 낮아지기 때문에 학생 모델이 이 샘플을 완전히 관련없는(irrelevant)한 부정 문서로 분류하지 않도록 유도한다.

Rank Imitation Loss

Rank Imiation 손실은 샘플 간의 미묘한 순위 차이를 학생 모델이 모방하도록 유도하며, 이는 CL손실이 다루기 어려운 부분이다. 이 손실함수를 통해 정답 문서와 하드 네거티브 문서사이의 미묘한 뉘앙스 차이를 포착해 학생 모델이 교사 모델의 랭킹을 따라가도록 학습시키는 것이다.

$$\mathcal L^{RI}_{PH} = 1 -\text{corr}(\mathbf{z}_i^{\mathcal T}, \mathbf{z}_i^{\mathcal S}), \quad \mathbf{z}_i^{\mathcal T} = [z_{ij}^{\mathcal T}] \; \text{for} \; j \in\mathbb D^+ \cup \mathbb D_H^-$$

먼저 정답 및 하드 네거티브 샘플의 순위를 모방하기위해 피어슨 상관관계(Pearson Correlation)를 이용한다. 정답 샘플과 하드 네커티브 샘플을 대상으로 교사와 학생 모델의 로짓 간의 피어슨 상관계수를 최대화하여, 학생 모델은 교사 모델이 판단한 중요한 쌍들의 순위 매김을 배우게 된다.

$$\mathcal L_{HI}^{RI} = -\frac{1}{|\mathbb D_{H}^-||\mathbb D_{I}^-|}\sum_{j \in \mathbb D_H^-}\sum_{k \in \mathbb D_I^-}\lambda_{jk}\log\big( \sigma(z_{ij}^{\mathcal S} - z_{ik}^{\mathcal S}) \big)$$

또한, 하드 네거티브와 비교적 쉬운 이지 네거티프의 순위를 모방하기 위한 손실도 정의한다. 즉, 이 손실의 목적은 하드 네거티브(\(\mathbb D_H^-\))와 이지 네거티브(\(\mathbb D_I^-\))을 명확하게 구분하기 위함이다. \(\lambda_{jk}\)는 하드 네거티브 \(j\)와 이지 네거티브 \(k\)사이의 NDCG (normalized discounted cumulative gain)에 기반한 순위 비교 지표이다. 이 손실은 교사가 하드 네거티브에 대해 이지 네거티브보다 더 높은 점수를 부여했을 때만 학생 모델의 순위 정렬을 유도한다.

Feature Imitation Loss

$$\mathcal L^{FI} = \vert\vert \mathbf{r}_i^{\mathcal T} - \mathbf{r}_i^{\mathcal S}\vert\vert$$
$$r_{ijk}^{\mathcal T}= \text{sim}(\mathbf(y_{ij}^{\mathcal T}, y_{ik}^{\mathcal T})), \quad \forall j,k \in \mathbb D^+ \cup \mathbb D^-_{H}$$

Feature Imiataion 손실은 교사 모델의 풍부한 내부 표현(feature)을 활용하요, 교사 모델과 학생 모델의 유사도 행렬간의 차이를 최소화하는 것을 목표로 한다 (유사도 분포를 매칭). \(r_{ijk}^{\mathcal T}\)는 교사 모델의 분류 토큰 임베딩 간의 코사인 유사도를 기반으로 한 유사도 행렬이고, \(r_{ijk}^{\mathcal S}\)는 학생 모델의 유사도 행렬이다.

$$\mathcal L = \mathcal L^{CI} +\alpha\mathcal L^{RI}_{PH} + \beta\mathcal L^{RI}_{HI} + \gamma\mathcal L^{FI}$$

최종 손실 함수를 기반으로 학생 모델의 PMA, IEM, 선형 레이어, LLM(for Bi-Encoder)가 훈련된다. LLM은 peft와 LoRA를 활용하여 학습한다.



Experiments

Main Results

1

자연어 추론(NLI) 과제에서 D2LLM은 소규모 데이터(0.3M)로 학습했음에도 불구하고, 동일 조건의 다른 베이스라인보다 월등히 높은 성능을 기록했다. 예를 들어 OCNLI와 CMNLI에서 D2LLM은 각각 ACC 0.7889 / 0.8014를 기록했으며, 이는 LLaRRA나 Udever보다 10% 이상 향상된 결과이다. 특히 BGE가 100M 대규모 데이터로 학습해 얻은 0.7266(OCNLI)보다도 더 높은 점수를 보였으며, 결국 소규모 데이터로도 기존 강력한 모델을 능가하는 효율성을 입증했다

Ablation Study

1

  • -CI+CL: Contrastive Imitation Loss를 InfoNCE로 교체
  • -RI_PH: Rank Imitation Loss 중 positive/hard negative 사이의 순위 모방 손실 제거
  • -RI_HI: Rank Imitation Loss 중 easy/hard negative 사이의 순위 모방 손실 제거
  • -FI: Feature Imitation Loss 제거
  • -PMA+mean: PMA를 Mean Pooling으로 대체
  • -PMA+[EOS]: PMA를 [EOS] 토큰 사용으로 대체
  • -IEM+cos: IEM을 cosine 유사도로 대체
  • D2LLM-1.8B: 교사 모델을 7B에서 1.8B로 축소

Table 3의 Ablation Study 결과를 보면, 제안된 각 손실 함수와 모듈이 성능에 기여하는 정도가 뚜렷하게 드러난다. Contrastive Imitation(CI)를 표준 contrastive loss로 대체했을 때 성능이 3.59% 감소하였으며, 특히 Rank Imitation에서 양성과 하드 네거티브의 순위를 맞추는 RI_PH를 제거했을 때는 6.57%라는 가장 큰 성능 저하가 나타났다. 또한 하드와 쉬운 네거티브를 구분하는 RI_HI를 제거하면 4.92%의 성능 감소가 뒤따랐고, Feature Imitation을 제거했을 때도 2.17%의 하락이 발생하였다. Pooling 모듈을 단순 mean pooling이나 [EOS] 기반으로 바꿨을 때도 각각 1.71%, 1.51% 감소하였으며, Interaction Emulation Module(IEM)을 cosine similarity로 대체했을 경우에는 3.83%의 성능 저하가 관찰되었다. 특히, 모델의 파라미터를 바꾼 D2LLM-1.8B를 제외하고는 -RI_PH 실험에서 성능 하락폭이 6.57% 가장 컸다. 이는 학생 모델이 교사 모델의 핵심적인 순위 매김을 모방하는 것이 매우 중요함을 말해준다.

문장 의미 유사성(STS) 실험

1

문장 의미 유사성(STS) 과제에서는 D2LLM이 대부분의 경우(14개 중 10개) 다른 베이스라인을 능가했지만, 일부 데이터셋에서는 BGE-ft가 앞서기도 했다. 또한 원래 교사 모델(LLM-ce)은 STS에 특화된 파인튜닝을 하지 않아 BGE보다 낮은 성능을 보였다. 그러나 LLM-ce를 소규모(0.3M)로 STS에 맞춰 파인튜닝한 LLM-ce-ft를 교사로 삼아 D2LLM-ft를 학습하자, 기존 D2LLM 대비 1.69% 향상되었고, 모든 경쟁 모델보다 평균 최소 17.42% 이상 우위를 달성했다. 이는 소규모 데이터 기반의 적응(finetuning)만으로도 교사-학생 모두 성능을 크게 개선할 수 있음을 보여준다.

Runtime Analysis

1

Figure 2는 각 모델의 런타임 효율성을 분석한 결과입니다. 이 분석은 쿼리 벡터화, 관련성 점수 계산, 문단 순위 매기기 시간으로 나뉜다.

  • Cross-encoder (LLM-ce): LLM-ce는 각 문단을 쿼리와 연결하여 개별적으로 모델을 통과시켜야 하므로, 관련성 점수 계산에 매우 긴 시간이 소요된다.
  • Bi-encoders (BGE, RocketQAv2 등): 이 모델들은 문단 임베딩 벡터를 미리 계산하여 데이터베이스에 저장하므로, 쿼리 벡터화와 코사인 유사도 계산에만 시간을 사용한다.
  • D2LLM: D2LLM은 Bi-encoder와 유사하게 쿼리 벡터화 시간을 가지지만, 유사도 계산에 MLP를 사용하기 때문에 코사인 유사도만을 사용하는 Bi-encoder보다 관련성 점수 계산 시간이 약간 더 길 수 있다.

결과적으로 D2LLM은 Cross-encoder에 비해 압도적으로 효율적이며, Bi-encoder에 비해 약간의 시간만 추가로 소요된다. 이는 D2LLM이 Cross-encoder의 정확성과 Bi-encoder의 효율성 사이의 균형을 효과적으로 달성했음을 보여준다.

Retrieval Performance

1

정보검색(IR) 과제에서는 D2LLM이 다양한 도메인(T2Retrieval, DuRetrieval, mMARCO 등)에서 대부분 Udever, LLaRRA, RocketQAv2보다 우수한 성능을 기록했다. 특히 CovidRetrieval과 MedicalRetrieval 같은 특수 도메인에서는 8% 이상 성능 개선을 보였다. 다만, 일부 세부 지표에서는 BGE(100M 학습)보다 낮은 결과를 보였으며, 이는 대규모 데이터의 이점 때문으로 해석된다. 하지만 같은 0.9M 데이터 조건에서 비교했을 때는 D2LLM이 가장 안정적인 성능을 보였다.



NR 카테고리 내 다른 글 보러가기

댓글 남기기