[논문리뷰]HittER: Hierarchical transformers for knowledge graph embeddings

Date:     Updated:

카테고리:

Chen, S., Liu, X., Gao, J., Jiao, J., Zhang, R., & Ji, Y. (2021, October 6). HittER: Hierarchical Transformers for Knowledge Graph Embeddings. ArXiv.org. https://doi.org/10.48550/arXiv.2008.12813

Problem Statement

1

1. Knowledge Graph는 여전히 불완전(Incomplete)하고 Noisy하다.

Knowledge Graph는 Heterogeneous graph의 일종이다. 그리고 그래프 형식으로 정보를 저장할 때 사람이 직접 수동으로(manually) 넣어줘야하고, 이는 데이터 유실로 이어질 수 있다.

The first major problem is the incompleteness and noise in the knowledge graph.
This issue arises from the fact that much of the information in the graph is
manually inputted by humans so that it leads to potential loss of data.

2. 그래프 임베딩 모델(Graph Embedding Method, KGE)의 한계

그래프 임베딩은 기본적으로 하나의 벡터 공간에 그래의 구조 정보를 표현하는 것에 초점을 맞춘다. 예를 들어, head와 relation의 합을 tail이라 정의할 수 있다(TranE 모델). 그래프 임베딩 방식은 그래프의 구조 정보만을 활용한다. 한 노드를 중심으로 이웃 노드들의 정보를 취합(aggregation)한다. 다시 말해, 저차원의 벡터 공간에서 지리적인, 구조적인 특성을 이용하는 것이 KGE이다.

이렇게 할 경우 취합된 모든 정보를 벡터 공간에서 하나의 단일 벡터에 저장해야 하기떄문에 정보의 유실이 생길 수 있다(Seq2Seq 모델의 단점과 비슷: context vector에 정보를 압축). 정리하면 그래프 임베딩 방식(KGE, Knowledge Graph Embedding)의 경우 1. 오직 그래프의 구조 정보만을 활용하며 2. 하나의 단일 벡터에 정보를 압축해서 정보의 손실이 발생하는 한계점들이 존재한다.

2

위의 Sub-graph로 Link prediction을 진행한다고 가정해보자. Incomplete Triple의 head가 Sunnyvale이고 relation이 country일 때 tail을 찾는다고 가정해보자. 그래프 임베딩의 경우 Sunnyvale에 저장된 구조적인 정보만을 활용하여 추론을 진행한다. 다시 말해, Sunnyvale 엔티티를 기준으로 이웃의 정보를 모으는 것이다. 반면 이웃 노드들의 정보와 더불어 context 정보를 활용하면 state와 California에 대한 정보도 활용할 수 있다. 즉, 이웃 노드들의 정보를 활용하기위해 완전한 트리플을 사용하고, 여기서 California라는 context정보를 읽어내어 이용하여 좀 더 수월하고 정확학 추론이 가능하다.

Knowledge Graph에는 구조적인 정보뿐만 아니라 Context information도 존재한다. 이를 활용하기 위해 Graph Neural Network(GNN)이나 Attention-based 방식도 연구가 진행되었다. 하지만 Graph의 경우 Layer수를 늘리면 Oversmoothing현상이 발생하기 때문에 제약 조건이 존재한다. HittER 모델의 시작은 “어떻게 하면 이웃 노드의 정보(KGE)와 Context정보(Attention-based)를 모두 이용하면서 Deep한 모델을 만들수 있을까?” 라는 질물에서 출발한다.



Method

1. Overview

3

HittER는 “Deep Hierachical Transformer“모델로 엔티티와 릴레이션의 representation을 이웃들로부터 정보를 취합해 동시에 학습하는 모델이다. Attention을 하는 Transformer를 이용해 기본적으로 그래프의 context정보를 학습할 수 있으며 동시에 이웃 엔티티들의 정보또한 학습할 수 있게 만든 모델이다. HittER는 두 가지의 Transformer가 계층적으로 쌓여진 형태이다.

2 Model Architecture

2.1 Simple Context-Independent Transformer

4

먼저 간단하게 Notation을 살펴보면, 저자는 head, relation, tail을 각각 subject, predictate, object라고 표현을 했다. 또한 Incomplete Triple에서는 내가 알고 있는 엔티티를 Source entity라고 지칭하며 Knowledge Graph Completion을 통해 찾아야 하는 엔티티를 Target entity라고 지칭한다.

  • Triple: <\(head, relation, tail\)> = <\(source, predicate, object\)>

Simple Context-Independent Transformer는 Bottom block의 맨 앞단 Transformer 하나 만을 이용해서도 Link prediction을 할 수 있다. 즉, Bottom block의 점선으로 된 박스 부분만을 사용해 Link prediction을 진행한다. 이 Transformer의 입력은 [CLS]토큰의 임베딩과 source entity의 임베딩(\(e_{src}\)), predicate의 임베딩(\(r_p\))이 된다. 모델은 Link preiction에 필요한 scoring function으로 Transformer의 encoder방식(Multilayer + Bidirectional)을 사용하여 간단한 방법으로 수행할 수 있다.

[CLS] 토큰 임베딩과 \(e_{src}\), \(r_p\)를 각각 랜덤하게 초기화 한 후 세 개의 임베딩을 BERT와 마찬가지로 Transformer Encoder의 입력으로 집어넣는다. 그러면 Triple이 Plausible한지 아닌지를 판단해주는 결과값인 \(M_{e_{src}}\)가 출력된다. 같은 방식으로 이웃들 중 모든 후보 엔티티에 대하여 plausible score인 \(M_{e_{src}}\)를 구한 후 softmax 함수를 이용해 nomalization한다. 학습과정을 정리하자면 다음과 같다.

  1. True triple의 target entity와 \(M_{e_{src}}\)의 내적(Dot-product)를 구한다.
  2. 같은 방식으로 다른 모든 후보 엔티티들의 score를 계산하고 SoftMax 함수를 이용해 normalize한다.
  3. \(\mathcal{L_{LP}} = -log\,p(e_{tgt} \vert M_{e_{src}})\)를 사용하여 모델을 훈련한다.


2.2 Bottom Block: Entity Transformer

5

위의 Simple link prediction의 경우 BERT와 다를바가 없다. 다시 말해, Graph의 Context정보를 이용하지 않는다. 또한 스텝마다 하나의 트리플의 임베딩만 학습하므로 구조 정보또한 제대로 반영이 되지 않는다. 따라서 Source entity의 이웃 정보들을 활용해야한다. Bottom Block은 가능한 모든 엔티티-릴레이션 쌍으로부터 유용한 feature들을 모두 뽑아낸다.

6

먼저 Incomplete triple의 \(e_{src}\)와 \(r_p\)쌍을 첫 번째 transformer의 입력으로 넣는다. 다음으로 소스 엔티티의 이웃 엔티티들과 relation의 임베딩을 다음 transformer의 입력으로 넣는다. 이렇게 함으로써, 엔티티-릴레이션 쌍을 직접 top block으로 전파시키지 않고 두 개의 input을 하나로 변환함으로써 running time을 줄일 수 있다.


2.3 Top Block: Context Transformer

7

Top block에서는 bottom block의 출력값과 [GCLS]라는 special token의 임베딩 정보를 취합한다. Top block의 입력은 총 세 가지이다. 1) [GCLS] Special Token embedding : \(E_{[GCLS]}\) 2) Intermedjate source entity embedding : \(M_{e_{src}}\) 3) Other neighborhood entity embedding : \(M_{e_1}, M_{e_2}, \cdots\)

Loss function또한 바뀐다. 기본적으로 Bottom block과 마찬가지로 NLL의 형태를 띄지만 디테일이 다르다. Loss: \(\mathcal{L_{LP}} = -log \, p(e_{tgt} \vert T_{[GCLS]})\)


2.4 Balanced Contextualization: Masked Entity Prediction

위의 방식대로 context information을 주는 것은 종종 여러가지 문제를 유발한다.

1) Source entity가 link prediction을 하기에 충분한 정보를 가지고 있는 경우 추가적으로 들어가는 contextual information은 노이즈가 된다. 다시 말해, 이미 많은 정보를 내포한 Source entity는 다른 정보를 필요로 하지 않는다는 것이다. 2) 많은 context 정보가 Source entity로부터 온 정보를 downgrade시키거나 쓸데없는 상관관계 정보를 포함하기에 overfitting이 발생할 수 있다.

따라서 Contextual information과 Source entity information이 균형을 이루어야 한다. 논문에서는 두 문제를 해결하는 방법을 각각 제시한다.

Solution of Problem 1

8

첫 번째 문제를 해결하기 위해 모델을 훈련하는 동안 특정 확률로 입력 Source entity를 1)MASK 토큰으로 바꾸거나, 2)랜덤하게 선택된 엔티티로 바꾸거나, 3)바꾸지 않고 그대로 둔다. 이 특정 확률(Certain Probability)는 dataset마다 특화된 Hyperparameter이다. 이 과정을 통해 모델은 contextual representation을 학습할 수 있다.

Solution of Problem 2

9

두 번째 문제를 해결하기 위해서는 모델이 혼동을 일으키는 엔티티를 발견하도록 훈련시켜야 한다. 따라서 Source entity에 상응하는 출력 임베딩, \(T_{e_{src}}\)에 하나의 classification layer를 두어 Correct source entity인지 예측하도록 한다.

Loss function또한 정의를 다시해야 한다. Source entity가 맞는지 아닌지를 판단해야 하므로 Loss function은 \(\mathcal(L_{MEP})\)로 정의되며 학습을 위한 최종 Loss function은 \(\mathcal{L_{LP}} + \mathcal{L_{MEP}}\) 형태로 바뀐다.

첫 번째 Soulution은 항상 Beneficial하다. 하지만 두 번째 Soulution은 Source entity의 정보를 강조할 때만 필요하지, 양질의 contextual information이 있을 때에는 불필요하다. 따라서 dataset마다 다른 전략을 취해야 한다.

저자는 Masked Entity Preidiction에 두 가지 추가적인 전략을 제시한다.

10

1) Uniform neighborhood sampling
먼저 학습 샘플에 나타나는 이웃들 중 일부에 uniform neighborhood sampling을 하는 것이다. 이렇게 함으로써 일종의 data augmentor처럼 작동할 수 있으며 edge dropout regualarization과 같은 효과를 낼 수 있다.

2) Removing ground truth target entity from the source entity’s neighborhood during training 두 번째로는 학습 중 소스 엔티티의 이웃으로부터 나타나는 정답 타겟 엔티티를 삭제하는 것이다. 그렇지 않으면 train-test 불일치를 극악으로 만들어 낼 수 있다. 왜냐하면 학습 중 소스 엔티티의 이웃에 정답 타겟 엔티티가 있지만, 테스팅중에는 거의 찾기 힘들기 때문이다.



Experiment & Result

1. Dataset

11

Knowledge Graph Completion에서 자주 쓰이는 Banchmark dataset이 쓰였다.

12

FB15K-237과 WN18RR의 두 가지 dataset에 대한 link prediction 실험 결과이다. 실험에 사용한 HittER은 3 layers entity transformer 와 6 layers context transformer 로 구성되어 있다. 균등하게 샘플링된 neighbor entities의 maximum number는 dataset 각각 50개, 12개이다. 평가 지표로 MRR과 Hits@k를 사용하였다. 실험 결과는 HittER이 가장 좋은 성능을 보이고 있다.

3. Ablation Study

13

contextual information과 Balancing techniques 모두 사용하지 않은 경우와(None), contextual information은 사용하였으나 Balancing technique은 사용하지 않은 경우(Unbalanced) 그리고 두가지 다 사용한 경우(Balanced)를 비교한 실험을 진행하였고, 실험 결과 두가지 모두 사용한 경우의 결과가 가장 좋았다. None은 다시 말해 Simple Link prediction(Context-independent transformer)을 말한다. 또한 Unbalanced의 경우 Bottom block과 top block을 모두 사용하지만 MEP를 하지 않은 경우이고, Balanced는 MEP까지 한 경우를 말한다.이 실험의 결과로 알 수 있는 것은 Context information을 Unbalanced에서와 같이 직접적으로 모델에 가하는 것은 큰 의미가 없으며 반면 contextual information과 Source entity information의 균형을 맞춘, MEP를 푼 Balanced에서는 유의미한 MRR score를 얻을 수 있다는 것이다.

두 번째 실험은 Context information을 주었을때와 안주었을 때, 그리고 hop수에 따른 MRR score를 비교한 것이다. WN18RR dataset을 hop수에 따라 그룹화하여 분류하고 모델을 학습시킨 후 link prediction의 MRR을 측정한 것이다. 실험 결과, Long path를 가질 수록 정보를 취합하기 어려워 MRR 점수가 떨어지는 것을 알 수 있다. 또한 낮은 hop수에서는 context정보가 유의미한 성능 차이를 만들어내지만 Long tail이 될 수록 context가 성능에 미치는 영향력이 따라서 감소한다.

14

다음으로는 Question Answering에 관한 ablation study이다. 기존의 BERT는 구조 정보가 반영이되지 않는다. 반면 HittER는 이웃의 정보또한 bottom block을 통해 취합할 수 있다. KG-BERT만으로 학습을 시켰을때와, Attention시 HittER의 출력값을 가미해 학습시켰을 때의 성능을 비교한 것으로 HittER를 추가적으로 사용했을 때 Text정보와 Structure 정보 모두 활용하기에 더 좋은 성능이 나오는 것을 확인할 수 있다.

BERT의 각 레이에서 기존의 Self-attention module 이후에 cross attention module을 추가하였다. cross attention에서 Query는 BERT의 이전 레이어에서 온 값을 사용하고, KEY와 VALUE는 HittER layer의 output을 사용한다. 두 가지 QA datasets에 대하여 실험을 진행하였다. HittER의 INPUT으로 사용하기 위해 두 데이터셋에 있는 각각의 질문은 context entity와 inferred relation(between context entity and answer entity)으로 라벨링되었다.

FB15K-237 DATASET을 사전 학습한 HittER모델을 사용하였는데, QA datasets에 있는 대부분의 질문들이 FB15K-237의 knowledge와 관련이 없다. 따라서 논문에서는 context 와 answer entity가 FB15K-237과 QA datasets 모두에 존재하는 filtered setting 에서의 실험을 추가로 진행하였다. 실험 결과는 BERT만 사용했을 때 보다 HittER을 함께 사용했을 때 QA accuracy가 더 좋게 나오고 있다.



Contriubution

  1. A model that applies Transformer to a multi-relational knowledge graph is proposed.
    • 트랜스포머를 활용해 이웃의 구조 정보와 context정보를 모두 활용할 수 있는 모델을 제안
  2. The natural language processing model was applied to the KG and showed good performance on the benchmark dataset.



GR 카테고리 내 다른 글 보러가기

댓글 남기기