[그래프 AI]GNN이 푸는 Task와 Pooling
카테고리: Graph
Prediction GNN
GNN Training Pipeline
GNN(Graph Neural Network)은 그래프 형태의 데이터를 처리하여 각 노드, 엣지, 또는 그래프 전체에 대한 예측을 수행하는 모델이다. GNN Pipeline을 보면, 입력 그래프가 GNN 모델로 들어가서 노드 임베딩을 생성하는 과정을 거친다. 이 노드 임베딩들은 Prediction Head로 전달되어 최종 예측값을 산출한다. 이 예측값은 평가 지표를 통해 평가되고, Loss 함수를 통해 학습 과정에서의 오차가 계산된다. 이 과정에서 중요한 부분은 노드 임베딩이 생성된 후 이를 기반으로 하는 예측 방식이다.
GNN을 사용하여 풀 수 있는 Task는 크게 세 가지 카테고리로 구분된다.
- Node-level prediction: 각 노드에 대한 예측을 수행하는 작업이다. 예를 들어, 노드가 속한 카테고리를 분류하는 Classification이나, 노드의 속성 값을 예측하는 Regression이 가능하다.
- Edge-level prediction: 두 노드 간의 관계, 즉 엣지의 유무를 예측하는 작업이다. 노드 임베딩의 조합을 통해 엣지를 예측할 수 있으며, 이때 다양한 방법이 사용된다.
- Graph-level prediction: 전체 그래프에 대한 예측을 수행하는 작업이다. 주로 그래프 자체의 속성을 예측하거나 분류할 때 사용된다.
노드 임베딩을 기반으로 각 Task는 다르게 정의된다. 특히 Node-level과 Edge-level에서 K-way prediction이 사용되며, 이는 노드나 엣지를 여러 카테고리로 분류하는 방식이다.
1. Node level prediction
GNN 학습 이후, 각 노드는 d차원의 노드 임베딩을 얻게 된다. 이 노드 임베딩은 해당 노드의 속성을 반영하는 벡터 표현이다. Node-level prediction에서는 이 노드 임베딩을 이용해 K-way prediction을 수행한다. 예를 들어, 다음과 같은 두 가지 예측을 할 수 있다:
- Classification: 주어진 노드가 여러 카테고리 중 하나에 속할 때, k개의 카테고리 중 하나로 노드를 분류하는 작업이다. 예를 들어, 각 노드를 특정 클래스에 배정하는 것이다.
- Regression: 노드에 대한 연속적인 값을 예측할 때, 예를 들어, k개의 타겟 값을 예측할 수 있다.
이를 수식으로 나타내면 위와 같다. 여기서 \(W^{(H)}\)는 노드 임베딩을 예측 값으로 매핑하는 파라미터이며, 이를 통해 Loss를 계산하여 학습이 가능하다.
2. Edge level prediction
Edge-level prediction은 두 노드 간의 관계, 즉 엣지의 존재 여부를 예측하는 작업이다. 이때, 두 노드 \(u\)와 \(v\)의 임베딩 \(h_u^{(L)}\)와 \(h_v^{(L)}\)를 이용하여 K-way prediction을 수행할 수 있다. 즉 \(\text{Head}_{edge}(h_u^{(L)}, h_v^{(L)})\) 와 같이 정의되며, 여기서 두 임베딩(\(h_u^{(L)}, h_v^{(L)}\))을 어떻게 결합할 것인지에 대해 여러 가지 방법으로 나뉜다. 대표적으로 Combination + Linear(선형 변환) 방식과 Dot product(내적)방식이 있다.
1) Concatenation + Linear 방식
- Concatenation: 두 노드 임베딩 \(h_u^{(L)}\)와 \(h_v^{(L)}\)를 이어 붙인다.
- Linear: 이어 붙인 임베딩을 Linear layer를 통해 k차원의 임베딩으로 매핑한다.(k-way prediction)
이를 수식으로 나타내면 다음과 같다.
여기서 Concatenation을 통해 얻은 임베딩을 선형 변환하여 엣지의 존재 여부를 예측할 수 있다.
2) Dot Product(내적) 방식
- 두 노드의 임베딩을 내적으로 결합하여 예측을 수행하는 방식이다. 이 방식은 주로 1-way prediction(즉, 엣지의 존재 유무 예측)에 사용된다.
- K-way prediction에 확장하려면, Multi-Head를 사용하는 방식으로 확장할 수 있다. 수식은 다음과 같다.
여기서 K-way prediction의 경우, Multi-head attention과 유사하게 각 Head에서 별도의 가중치를 학습하고 이를 내적하는 방식을 적용할 수 있다. 두 방식 모두 Edge-level prediction에서 중요한 역할을 하며, 각각의 장단점이 있다. Concatenation은 좀 더 일반적인 방식이며, Dot Product는 계산이 단순하여 효율적일 수 있다.
3. Graph level prediction
모든 노드 임베딩을 이용해 prediction을 한다. 여기서 Graph classification을 수행하는 Head는 GNN의 Aggregation function과 비슷하다. 하지만 GNN에서는 target 노드의 이웃들만 aggregation input으로 받는 반면, Graph-level prediction의 Head는 모든 노드의 임베딩을 입력으로 받는다. 두 함수 모두 정보를 Aggregation한다는 공통점이 있다.
Global Pooling
- \(\text{Head}_{\text{graph}}(\cdot)\)는 GNN layer의 \(\text{AGG}(\cdot)\)와 동일하다.
- Global mean/max/sum pooling을 이용한다.
- Global mean pooling: \(y_G = Mean\left( \{ h_v^{(L)} \in \mathbb{R}^d, \forall v \in G \} \right)\)
- Global max pooling: \(y_G = Max\left( \{ h_v^{(L)} \in \mathbb{R}^d, \forall v \in G \} \right)\)
- Global sum pooling: \(y_G = Sum\left( \{ h_v^{(L)} \in \mathbb{R}^d, \forall v \in G \} \right)\)
하지만 이러한 Pooling 방식은 크기가 작은 Small-Scale 그래프에서만 적용 가능하다. 크기가 큰 Large-Scale 그래프에서는 다른 방식을 사용한다. 계층적 mean pooling 혹은 Hierarchical Global Pooling이라는 방식을 사용하며, 모든 노드의 임베딩을 계층적으로 취합하는 방식이다. 예를 들어, ReLU(Sum())을 통해 취합하며, 첫 두 노드와 마지막 두 노드를 각각 취합한 후 층을 올려가며 반복적으로 진행한다.
크기가 큰 그래프에서 Global Pooling을 사용하지 않는 이유는 대표적으로 정보 손실(Lose information)을 유발할 수 있기 때문이다. GNN에서 Global Pooling을 하는 이유는 출력 값으로 나오는 임베딩의 크기를 고정시켜주기 위함이다. 하지만 이 과정에서 mean, max, sum 같은 연산을 취하는데, 이러한 방식은 노드 수가 많은 Large-Scale 그래프에서 문제가 발생할 수 있다.
- Ex
- G1 \(\to \{-5,0,2,3\} \to Sum \to 0\)
- G2 \(\to \{-50,0,20,30\} \to Sum \to 0\)
두 그래프의 구조는 다르지만, Pooling 값이 0으로 동일해질 수 있다. 이는 정보 손실을 유발하는 중요한 문제이다.
Hierarchical Pooling
Hierarchical Pooling은 Large-Scale 그래프에서 사용되는 방식으로, 모든 노드 임베딩 값을 계층적으로 Aggregation하여 예측을 진행한다. 이는 전통적인 Global Pooling 방식이 Large-Scale 그래프에서 발생하는 정보 손실 문제를 해결하기 위한 방법이다.
Hierarchical Pooling의 주요 아이디어는 노드 임베딩을 계층적으로 결합하여 최종 예측값을 얻는 것이다. 이를 위해 각 노드의 임베딩 값을 Sum, Mean, Max와 같은 연산을 통해 취합한 뒤, 그 결과를 ReLU 활성화 함수에 적용하여 정보 손실을 줄이면서 예측을 수행한다. Hierarchical Pooling의 단계별 과정은 다음과 같다.
[예시 1: 노드 임베딩 값이 \(\{-5, 0, 2, 3\}\)인 경우]
1) 첫 두 노드와 마지막 두 노드를 각각 Aggregate한다.
- 식: \(y_a = ReLU(Sum(\{-5, 0\})) = 0\)
- 식: \(y_b = ReLU(Sum(\{2, 3\})) = 5\)
2) 두 결과를 다시 Aggregate한다.
- 식: \(y_G1 = ReLU(Sum(\{y_a, y_b\})) = 5\)
[예시 2: 노드 임베딩 값이 \(\{-50, 0, 20, 30\}\)인 경우]
1) 첫 두 노드와 마지막 두 노드를 각각 Aggregate한다.
- 식: \(y_a = ReLU(Sum(\{-50, 0\})) = 0\)
- 식: \(y_b = ReLU(Sum(\{20, 30\})) = 50\)
2) 두 결과를 다시 Aggregate한다.
- 식: \(y_G2 = ReLU(Sum(\{y_a, y_b\})) = 50\)
이제 두 그래프 \(G1\)과 \(G2\)는 각기 다른 Hierarchical Pooling 결과를 가진다. 이 방식은 단순 Global Pooling에서 발생할 수 있는 정보 손실 문제를 해결하며, 보다 복잡한 구조에서도 그래프를 명확히 구분할 수 있다. 예를 들어,
- 식: \(G1 \to \{-5, 0, 2, 3\} \to y_G1 = 5\)
- 식: \(G2 \to \{-50, 0, 20, 30\} \to y_G2 = 50\)
이와 같이 Hierarchical Pooling을 통해 두 그래프의 차이를 명확하게 구분할 수 있다.
Hierarchical Pooling in Practice(DiffPool)
DiffPool의 아이디어는 간단핟하다. 먼저 밀집된 노드 임베딩끼리 Aggregation해서 clustering해서 하나의 clustering임베딩을 만들고 Graph pooling을 하고, 이 과정을 반복해서 하나의 임베딩으로 정보가 압축될 때까지 반복한다.
GNN A가 노드 임베딩을 병렬적으로 구하고, GNN B가 clustering을 병렬적으로 해준다. Graph Pooling을 hierarchical하게 할 수 있고, 성능이 많이 개선된 것을 볼 수 있다. 그래프가 커지면, 그래프를 분간하는 것 자체가 매우 어려운 문제가 된다. 따라서, Hierarchical한 방식 등 여러 가지 방법으로 그래프를 classification할 수 있는 방법들이 연구되고 있다.
Reference
[1] CS224W 강의
댓글 남기기