[딥러닝]KAN Network에 대하여

Date:     Updated:

카테고리:

MLP vs KAN

흔히 Multi-layer Perceptron(MLP)를 딥러닝의 근간이라고 한다. MLP는 가중치 행렬 \(W\)를 사용한 ‘선형 변환’과 ‘비선형 활성 함수’를 이용해 임의의 비선형 함수를 근사한다. 이는 ‘보편적 근사 정리(Universal Approximation Theorem, UAT)’로 설명된다. UAT는 인공 신경망(Artificial Neural Network, ANN)이 특정 조건을 만족할 때, 임의의 연속 함수를 원하는 정확도로 근사할 수 있다는 이론이다.

다층 퍼셉트론(MLP)과 같은 신경망 구조는 충분히 많은 노드와 적절한 활성화 함수를 사용하면, 어떠한 함수도 학습할 수 있다. 다시 말해, 임의의 비선형 함수를 주어진 오차 내에서 몇 개의 은닉층으로 구성된 완전 연결 네트워크(fully-connected network, FCN)로 근사할 수 있음을 의미한다.

반면, KANKolmogorov-Arnold Representation Theorem를 기반으로 모델을 전개한다. KAN은 거창해 보이지만, 사실 ‘splineMLP의 조합에 불과하다’고 할 수 있다. 즉, spline 모델과 MLP를 결합한 것에 지나지 않는다.

KAN Network

Kolmogorov-Arnold Representation Theorem(KAT)

먼저 KAT를 정의하기 위해서는 두 가지 가정이 필요하다.

1)(Multivariate) Smooth function: \(f : [0, 1]^n \rightarrow \mathbb{R}\)
2)\(\phi_{q, p}: [0, 1] \rightarrow \mathbb{R}, \;\; \Phi: \mathbb{R} \rightarrow \mathbb{R}\)

$$f(x) = f(x_1, \cdots, x_n) = \displaystyle\sum_{q=1}^{2n+1} \Phi_q \left ( \displaystyle\sum_{p=1}^n \phi_{q, p} (x_p) \right )$$

Kolmogorov-Arnold Representation Theroem(KAT)는 위와 같이 주어진다. 이 정리는 \(n\)차원의 다변량 함수를 \(n\)개의 단변량 함수 \(\phi\)들의 합으로 표현할 수 있다는 것을 의미한다. 머신러닝 관점에서 고차원 함수를 학습하려고 할 때, \(n\)개의 단변량 함수를 학습하는 것으로 귀결되며, 이는 차원의 저주(curse of dimensionality)를 피할 수 있을 것으로 기대할 수 있다. 수식적으로만 보면, KAT는 SVM의 커널 트릭과 의미상 유사하다. KAT는 복잡한 다변수 함수를 단순한 함수들의 조합으로 분해해 고차원 문제를 해결할 수 있다는 점에서 공통점을 가지고 있다.

Simple KAN

1

위 표는 MLP와 KAN을 비교한 것이다. MLP를 그래프 구조로 표현하면, 왼쪽의 (a)와 같이 각 노드는 활성 함수에 대응되고, 각 엣지에는 가중치 행렬 \(W\)가 대응된다. 반면, 아래에서 설명할 KAN의 경우 각 노드는 함수값에 대응되고, 각 엣지에는 함수 \(\phi\)가 대응된다. 또한, MLP는 가중치를 학습하는 반면, KAN은 활성함수를 학습한다는 차이점이 있다.

입력 데이터 \(\mathbf{x_0} = \{x_{0,1}, x_{0,2}\}\)가 2차원으로 주어진다고 가정하면 다음과 같은 수식으로 표현 가능하다.

$$x_{2,1} = \displaystyle\sum_{q=1}^5 \Phi_q(\phi_{q,1}(x_{0,1}) \; + \; \phi_{q,2}(x_{0,2}))$$

이 경우 KAN의 computation graph는 다음과 같다.

1

우선, 그래프의 각 노드(검은색 정점)는 단변량 함수값 혹은 데이터에 대응된다. 입력 데이터의 각 성분 \(x_{0,1}, x_{0,2}\)는 계산 그래프의 첫 번째 노드들로 주어진다. 이후 식 \((1)\)에서 확인할 수 있듯이 각 성분에는 \(\phi_{q,1}: q=1,\ldots,5\)가 대응되기 때문에, \(x_{0,1}, x_{0,2}\)에서 각각 5개의 엣지가 나가는 것을 확인할 수 있다 (그래프의 \(\phi_{0,1,1}\ldots\phi_{0,2,5}\)).

이후, 이들이 합해질 때에는 (식에서 덧셈 부분) \(q\) 값이 동일한 함수값들끼리 더해지기 때문에 \(\phi_{0,1,1} + \phi_{0,2,1} \ldots\) 끼리 각각 더해진다. 이를 통해 \(x_{1,1}, \ldots, x_{1,5}\)의 5개 성분을 갖는 은닉층이 구성된다.

마지막으로, \(\sum_{q}\Phi_{q}\) 부분의 연산이 이루어지기 위해 각각의 은닉층 성분에 함수값을 더하고, 이것이 출력 \(x_{2,1}\)로 구성된다. (위 그림에서는 표기의 일관성을 위해 \(\phi, \Phi\)를 구분하지 않고 \(\phi\)로 나타낸 것 같다.)

$$x_{2,1} = \displaystyle\sum_{q=1}^5 \Phi_q(\phi_{q,1}(x_{0,1}) \; + \; \phi_{q,2}(x_{0,2}))$$

이를 통해, 이 수식이 결론적으로 2-layer KAN에 대응된다는 것을 알 수 있다. 또한, 각각의 함수 \(\phi_{q, p}\), \(\Phi\)들을 학습이 가능해야 하므로 spline을 이용하여, 각 함수들의 계수(coefficient)를 학습한다. KAN에서는 B-spline을 이용한다.

1) KAN Layer

앞서 살펴본 KAN의 구조는 단일한 KRT에 대응되는 것으로, \(2n + 1\)개의 너비를 갖는 은닉층(hidden layer)이 사용되었다. 이는 MLP에 비하면 매우 얕고 단순한 구조이다. 따라서 MLP에 대응하기 위해 KAN도 여러 개의 은닉층을 쌓아야 한다. 입력 벡터가 \(N\)차원이고 출력 벡터가 \(M\)차원인 KAN Layer는 다음과 같이 1차원 함수 \(\phi_{q, p}\)로 구성된 행렬로 정의할 수 있다.

$$\Phi = \{\phi_{q, p}\}, p=1,2,\dots, N, \;\;\;q=1,2,\dots, M$$

앞서 살펴본 예시는 \(2 \rightarrow 5 \rightarrow 1\)차원으로 순차적으로 변환되므로, 2개의 KAN layer로 구성된다고 할 수 있다. 이는 논문에서 [n, 2n+1, 1]-KAN으로 표현된다. 이를 바탕으로 KAN layer를 순차적으로 쌓아 Deep KAN을 구성할 수 있다.


2) Activation Function

KAN은 비선형 활성화 함수 (non-linear activation function) 자체를 학습한다. 그리고 이 함수들을 학습 가능하게 하기 위해 B-spline을 이용한다. KAN에서는 이를 위해 다음과 같이 설정한다.

S1) Residual activation function

$$\phi(x) = w(b(x) + \text{Spline}(x))$$
$$b(x) = \text{SiLU}(x) = \frac{x}{1 + e^{-x}} \;\;\; \text{Spline}(x) = \displaystyle\sum_ic_iB_i(x)$$

이 식에서 \(b(x)\)는 기저 함수(basis function)이다. 일종의 residual connection 역할을 수행한다. 논문에서는 이 basis function으로 SiLU 함수를 사용한다. 그리고 spline의 경우 B-Spline의 선형 결합으로 주어진다. \(w\)는 활성 함수의 출력값의 스케일을 조절하기 위해 사용되는 factor이다.

S2) Initialization

$$\text{Initialize }c_{i} \overset{\mathrm{iid}}{\sim} N(0,\sigma^{2})$$

각 스케일 \(w\)는 Xavier Initialization으로 초기화하며, 각 활성 함수 \(\phi\)는 \(\text{Spline}(x) \approx 0\)이 되도록 초기화한다. 이는 각 계수 \(c_i\)들을 \(N(0, \sigma^2)\) 분포에서 샘플링하는 것으로 수행하며 KAN에서는 \(\sigma = 0.1\)로 설정하였다.

S3)Spline Grid Update
각 spline의 격자점(grid)을 업데이트하여, 활성함수의 함수값이 학습 중 고정된 범위를 넘어서는 것을 방지한다. 이러한 방식으로 KAN은 총 \(N^{2}L(G+k) \sim N^{2}LG\)의 파라미터 수를 가지게 된다. 일반적으로 MLP는 \(N^2L\)개의 파라미터수를 갖는 것과 비교하였을 때 KAN의 파라미터 수가 더 많아 보이지만, MLP에서 요구하는 \(N\)의 크기와 KAN에서 요구하는 N의 크기는 다르다. 결론적으로 KAN에서 더 작은 \(N\)값을 요구하기 때문에 파라미터 수의 차이는 작다.

결론적으로 활성 함수의 형태는 결국 Spline을 이루는 \(c_i\)에의해 결정된다. 기저 함수 \(B(x)\)는 고정된 함수이기 때문에 각각의 \(B(x)\)들을 결합하는 강도를 의미하는 \(c_i\)들이 어떤 값을 갖느냐에 따라 최종 함수의 모습이 결정되기 떄문이다. KAN이 학습하는 가중치는 결론적으로 \(c_i\)다.


3) Approximation for Deep KAN

4개의 변수를 가지는 다변량 함수 \(f(x)\)를 다음과 같이 정의할 때, 이는 KAN에서는 3개의 레이어 [4,2,1,1] KAN으로 표현할 수 있다.

$$f(x_{1},x_{2},x_{3},x_{4}) = \exp(\sin(x_{1}^{2}+x_{2}^{2})+\sin(x_{3}^{2}+ x_{4}^ {2}))$$

이 경우 각 레이어는 \(x^2, \sin(x), \exp(x)\)로 계산된다. 하지만, 2-layer KAN의 경우는 위와 같이 활성함수를 간단하게 \(x^2\) , \(\sin(x)\) 혹은 \(\exp(x)\)로 표현하는 것은 불가능하다. 이 사실을 바탕으로 KAN의 레이어가 깊어질수록 각 레이어에 대응되는 활성 함수의 표현이 간단해 진다는 것을 알 수 있다.

이를 수학적으로 표현하기 위한 근사식(Approximation Theory) 또한 존재한다. 입력 데이터 \(\mathbf{x}\in \mathbb{R}^{n}\)에 대해 함수 \(f(\mathbf{x})\)가 다음과 같이 KAN으로 표현된다고 하자.

$$f = (\Phi_{L-1} \circ \cdots \circ \Phi_{0}) \mathbf{x} $$

그러면 상수 \(C\)가 존재하여 다음의 근사 경계(approximation bound)가 성립한다.

$$\left\Vert f - (\Phi_{L-1}^{G} \circ \cdots \circ \Phi_{0}^{G}) \mathbf{x}\right\Vert_{C^{m}} \le C G^{-k-1+m}$$

여기서 \(G\)는 B-spline에서 격자 크기(구간의 수), \(k\)는 smoothness order를 각각 의미하며, norm은 다음과 같이 정의한다

$$\left\Vert g\right\Vert_{C^{m}} = \max_{\vert\beta\vert\le m} \sup_{x\in[0,1]^{n}} \left\vert D^{\beta}g(x)\right\vert$$

1

KAT의 핵심은 근사 경계 (approximation bound)가 입력 차원 \(N\)에 의존하지 않는다는 점이다. 즉, spline의 형태만 적절히 조절하면, 임의의 함수를 잘 근사할 수 있게 된다. 논문에서는 \(k=3\), 즉 cubic spline을 사용하는 것이 적절하다고 제안한 반면, 격자 크기 \(G\)의 경우 너무 세밀하게 설정하면 테스트 손실이 발산하는 overfitting 문제가 발생할 수 있음을 보였다. 만약 학습 데이터가 \(n\)개 있고, 전체 파라미터 수가 \(mG\)개로 계산된다면, \(G = \dfrac{n}{m}\)을 넘지 않도록 설정하는 것이 적절하다고 한다 (interpolation threshold).


Interpretability of KAN

KAN은 MLP와는 달리 해석 가능성 (interpretability)의 능력을 가지고 있다고 서술한다. 그러나 지금까지의 KAN의 내부 구성만으로는 해석 가능성에 대한 타당성을 입증할 수 없다. 따라서 논문에서는 타당성을 입증하기 위해 추가적인 과정을 제시한다.

1

이 그림은 KAN의 해석 가능하게 만드는 과정을 보여준다. 간단한 예시로 2D input을 갖는 \(f(x,y)=\exp(\sin(\pi x)+ {y^2})\)로부터 생성된 데이터들이 있다고 가정한다.

함수 $f$를 사전에 알고 있다면 이는 [2,1,1]-KAN으로 표현된다는 것을 알 수 있지만, 당면한 상황은 이를 모르는 상황이다. 이 경우 충분히 큰 KAN을 먼저 가정한 후, sparsity regularization을 바탕으로 모델을 학습하여 불필요한 노드와 엣지를 제거하고 (prune), 이를 원래 함수로 나타내는 것이다.

S1) Sparsification
먼저 첫 번째 과정은 Sparsification을 포함한 학습이다. 이는 이후에 과정인 Pruning을 하기 위함이다. Pruning을 위해서는 가중치 판단이 가능해야 한다. 가중치 판단이란 학습할 때 최대한 중요한 파라미터만 남기고, 나머지는 작은 값을 갖도록 하는 것이다. 이를 위해 L1 RegularizationEntropy Regularization을 손실함수에 추가한다.

  • Regularization → L1 norm
$$\left\vert \Phi\right\vert_{1}:=\sum_{i=1}^{n_\mathrm{in}}\sum_{j=1}^{n_\mathrm{out}}\left\vert \phi_{i,j}\right\vert_{1}$$
  • Entropy Regularization term
$$S(\Phi) := \sum_{i=1}^{n_\mathrm{in}}\sum_{j=1}^{n_\mathrm{out}}\frac{\left\vert \phi_{i,j}\right\vert_{1}}{\left\vert \Phi\right\vert_{1}}\log \left(\frac{\left\vert \phi_{i,j}\right\vert_{1}}{\left\vert \Phi\right\vert_{1}}\right)$$
  • Total Loss
$$l_\mathrm{total} = l_\mathrm{pred} + \lambda \left(\mu_{1}\sum_{l=0}^{L-1}\left\vert \Phi_{l}\right\vert_{1}+ \mu_{2}\sum_{l=0}^{L-1}S(\Phi_{l})\right).$$

S2) Pruning 두 번째 과정은 Pruning이다. 이때 Pruning은 학습된 가중치의 값을 기준으로 수행된다. 앞선 단계에서 이를 위해 희소화를 도입했으며, Pruning 단계에서는 한 개의 가지를 제외하고 모두 제거한다. 위 그림을 보면 Pruning이 완료된 후 입력과 출력이 한 개의 라인으로 연결된 모습을 볼 수 있다. 즉, 입력값 \(x\)로부터 출력값 \(y\)가 나오는 과정이 곱하기와 더하기만으로 표현될 수 있는 상태가 된 것이다.

S3) Expression

1

마지막으로 이렇게 얻은 활성함수들을 우리가 이해할 수 있는 함수로 표현해 주는 것이다. 예를 들어, \(x^2, \sin(x), \exp(x)\)등으로 표현하는 것이다. 하지만, KAN에서 제안한 해석 가능성은 너무나 활성 함수가 쉽게 표현 가능한 함수들이라는 너무도 단순한 가정을 하고 있기 때문에 이 가정에 대해서는 추후 연구를 지켜볼 필요가 있다.

KAN의 장단점

  • 장점
    • 향상된 표현력
    • 해석 가능성 제안(블랙박스 문제를 어느정도 해결)
    • 효율적인 학습
  • 단점
    • 복잡한 구현. 각 노드별로 다른 활성 함수를 사용해야하기 때문
    • Sparsification, Pruning, Symbolification등 다양한 방법을 통해 해석 가능성을 제안하지만 이는 컴퓨팅 자원의 높은 계산 비용으로 이어짐.
    • 실험적 검증의 부재

Reference

[1] Blog: KAN : Kolmogorov-Arnold Network
[2] Blog: [논문 리뷰] KAN: Kolmogorov–Arnold Networks
[3] Paper: KAN: Kolmogorov-Arnold Networks
[4] Github: Awesome KAN(Kolmogorov-Arnold Network)

DeepLearning 카테고리 내 다른 글 보러가기

댓글 남기기