[Deep Learning]KAN Network


MLP vs KAN

Multi-Layer Perceptron (MLP) is the foundation of modern deep learning. It approximates arbitrary nonlinear functions by combining a ‘linear transformation’ with a weight matrix \(W\) and ‘non-linear activation functions’ such as ReLU.

This can be explained by the ‘Universal Approximation Theorem (UAT).’ The UAT states that when an artificial neural network (ANN) satisfies certain conditions, it can approximate any continuous function to the desired accuracy.

A neural network such as an MLP can approximate any continuous function if it has enough nodes and an appropriate activation function. In other words, any such nonlinear function can be approximated within a given error margin using a fully-connected network (FCN) with a few hidden layers.
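As a small illustration of how a linear map followed by ReLU yields a nonlinear function, the sketch below hand-sets the weights of a one-hidden-layer network so that it reproduces \(|x|\) exactly. The function names and weights are chosen for this example only and are not taken from any library.

```python
def relu(z):
    """ReLU activation: max(0, z)."""
    return max(0.0, z)

def mlp_forward(x, W1, b1, W2, b2):
    """One hidden layer for a scalar input: y = W2 . relu(W1 * x + b1) + b2."""
    hidden = [relu(w * x + b) for w, b in zip(W1, b1)]
    return sum(w * h for w, h in zip(W2, hidden)) + b2

# Hand-picked weights (illustrative): relu(x) + relu(-x) equals |x|.
W1, b1 = [1.0, -1.0], [0.0, 0.0]
W2, b2 = [1.0, 1.0], 0.0

for x in (-2.0, -0.5, 0.0, 1.5):
    print(x, mlp_forward(x, W1, b1, W2, b2))
```

A trained MLP finds such weights by gradient descent instead of by hand, and more hidden nodes allow more kinks in the resulting piecewise-linear function.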

On the other hand, KAN is developed based on the Kolmogorov-Arnold Representation Theorem. While KAN may sound complex, it is essentially a fusion of spline models and MLP.

KAN Network

Kolmogorov-Arnold Representation Theorem(KAT)

First, to define KAT, two assumptions are needed.

1) (Multivariate) smooth function: \(f : [0, 1]^n \rightarrow \mathbb{R}\)
2) Univariate functions: \(\phi_{q, p}: [0, 1] \rightarrow \mathbb{R}, \;\; \Phi_q: \mathbb{R} \rightarrow \mathbb{R}\)

$$f(x) = f(x_1, \cdots, x_n) = \displaystyle\sum_{q=1}^{2n+1} \Phi_q \left ( \displaystyle\sum_{p=1}^n \phi_{q, p} (x_p) \right )$$

The Kolmogorov-Arnold representation theorem is stated as above. It means that a multivariate function of \(n\) variables can be expressed using only univariate functions, \(\phi_{q,p}\) and \(\Phi_q\), and addition. From a machine learning perspective, learning a high-dimensional function reduces to learning \(n(2n+1)\) inner and \(2n+1\) outer univariate functions, which can be expected to avoid the curse of dimensionality.
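To make the shape of this representation concrete, here is a minimal sketch for one easy special case: on positive inputs, the product \(x_1 x_2 \cdots x_n\) equals \(\exp(\sum_p \log x_p)\), so a single outer function \(\Phi = \exp\) and inner functions \(\phi = \log\) suffice. This is a hand-picked example, not the construction used in the proof of the theorem, and the function names are illustrative.

```python
import math

def phi(x):
    """Inner univariate function, applied to each input coordinate separately."""
    return math.log(x)

def Phi(s):
    """Outer univariate function, applied to the sum of the inner values."""
    return math.exp(s)

def f_decomposed(xs):
    """Evaluate Phi(sum_p phi(x_p)): univariate functions plus addition only."""
    return Phi(sum(phi(x) for x in xs))

def f_direct(xs):
    """The same function written as a genuinely multivariate product."""
    return math.prod(xs)

point = [0.3, 0.7, 0.5]
print(f_decomposed(point), f_direct(point))
```

The two values agree up to floating-point rounding, even though `f_decomposed` never multiplies two different inputs together.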

Conceptually, KAT resembles the kernel trick in Support Vector Machines (SVM): both tackle a high-dimensional problem by re-expressing a complex multivariate function as a combination of simpler functions.

Simple KAN

[Figure: comparison of MLP and KAN]

The figure above compares MLP and KAN. When an MLP is drawn as a graph, each node corresponds to an activation function and each edge to a weight in the matrix \(W\), as shown on the left in (a). In contrast, in KAN, which is explained below, each node corresponds to a function value and each edge to a function \(\phi\). Moreover, unlike MLP, which learns weights, KAN learns functions.

Assuming the input data \(\mathbf{x_0} = \{ x_{0,1}, x_{0,2} \}\) is two-dimensional (\(n = 2\), so \(2n + 1 = 5\)), the KAT representation becomes the following equation.

$$x_{2,1} = \displaystyle\sum_{q=1}^5 \Phi_q(\phi_{q,1}(x_{0,1}) \; + \; \phi_{q,2}(x_{0,2}))$$

In this case, the computation graph of KAN is as follows.

[Figure: computation graph of the KAN for a two-dimensional input]

First, each node in the graph (black vertices) holds a univariate function value or an input value. The components of the input data, \(x_{0,1}, x_{0,2}\), are given to the first nodes of the computational graph. As the equation above shows, each component \(x_{0,p}\) is passed through \(\phi_{q,p}\) for \(q=1,\ldots,5\), so five edges extend from each of \(x_{0,1}\) and \(x_{0,2}\) (represented as \(\phi_{0,1,1}\ldots\phi_{0,2,5}\) in the graph). In the inner sum of the equation, the values with the same \(q\) are added together, giving \(\phi_{0,1,1} + \phi_{0,2,1}\), and so on.

From this, a hidden layer with five components, \(x_{1,1}, \ldots, x_{1,5}\), is formed. Finally, to perform the operation \(\sum_{q}\Phi_{q}\), function values from each hidden layer component are added, which constitutes the output \(x_{2,1}\). (In the figure above, \(\phi\) and \(\Phi\) seem to be represented as \(\phi\) without distinction for the sake of consistency in notation.)

Furthermore, each function \(\phi_{q, p}\) and \(\Phi_q\) must be learnable, so each is parameterized as a spline whose coefficients are learned. KAN uses B-splines for this.

1) KAN Layer

The KAN structure examined above corresponds to a single KAT representation, using a hidden layer with a width of \(2n + 1\). This is a very shallow and simple structure compared to MLP. Therefore, to match the capacity of an MLP, KAN also needs to stack multiple hidden layers.

$$\Phi = \{\phi_{q, p}\}, p=1,2,\dots, N, \;\;\;q=1,2,\dots, M$$

The example examined above is sequentially transformed from dimensions \(2 \rightarrow 5 \rightarrow 1\), so it can be said to consist of two KAN layers. In the paper, this is represented as [n, 2n+1, 1]-KAN. Based on this, KAN layers can be stacked sequentially to form a Deep KAN.
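The idea of a KAN layer as a matrix of univariate functions can be sketched as follows. Here `phi[j][i]` is the function on the edge from input `i` to output `j`, and a node simply sums its incoming edge values. The edge functions are fixed by hand (not learned) to build a [2, 2, 1]-KAN that computes \(xy = ((x+y)^2 - (x-y)^2)/4\); all names are illustrative.

```python
def kan_layer(phi, x):
    """Apply one KAN layer: output j is the sum over i of phi[j][i](x[i])."""
    return [sum(edge(xi) for edge, xi in zip(row, x)) for row in phi]

def kan_forward(layers, x):
    """Compose KAN layers in order, feeding each layer's output to the next."""
    for phi in layers:
        x = kan_layer(phi, x)
    return x

# Layer 0 (2 -> 2): h1 = x + y, h2 = x - y
layer0 = [
    [lambda t: t, lambda t: t],
    [lambda t: t, lambda t: -t],
]
# Layer 1 (2 -> 1): out = h1^2 / 4 - h2^2 / 4
layer1 = [
    [lambda t: t * t / 4, lambda t: -t * t / 4],
]

print(kan_forward([layer0, layer1], [0.3, 0.8]))
```

In an actual KAN, each of these edge functions would be a learnable spline rather than a fixed lambda.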


2) Activation Function

KAN learns the non-linear activation functions themselves. To make these functions learnable, KAN uses B-splines. In KAN, the setup is as follows.

S1) Residual activation function

$$\phi(x) = w(b(x) + \text{Spline}(x))$$
$$b(x) = \text{SiLU}(x) = \frac{x}{1 + e^{-x}} \;\;\; \text{Spline}(x) = \displaystyle\sum_ic_iB_i(x)$$

In this equation, \(b(x)\) is the basis function, which acts as a kind of residual connection. In the paper, the SiLU function is used as the basis function, and the spline is given as a linear combination of B-splines. The scale factor \(w\) controls the output magnitude of the activation function.
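The residual activation above can be sketched in plain Python using the Cox-de Boor recursion for the B-spline basis. The uniform grid on \([-1, 1]\), extended by \(k\) knots on each side, is an assumption made for this example, as are the function names; this is a slow reference sketch, not the paper's implementation.

```python
import math

def bspline_basis(i, k, knots, x):
    """Cox-de Boor recursion: value of the i-th B-spline of degree k at x."""
    if k == 0:
        return 1.0 if knots[i] <= x < knots[i + 1] else 0.0
    left = right = 0.0
    d1 = knots[i + k] - knots[i]
    if d1 > 0:
        left = (x - knots[i]) / d1 * bspline_basis(i, k - 1, knots, x)
    d2 = knots[i + k + 1] - knots[i + 1]
    if d2 > 0:
        right = (knots[i + k + 1] - x) / d2 * bspline_basis(i + 1, k - 1, knots, x)
    return left + right

def make_knots(G, k, lo=-1.0, hi=1.0):
    """Uniform grid with G intervals on [lo, hi], extended by k knots per side (assumed layout)."""
    h = (hi - lo) / G
    return [lo + (j - k) * h for j in range(G + 2 * k + 1)]

def silu(x):
    """SiLU basis function b(x) = x / (1 + exp(-x))."""
    return x / (1.0 + math.exp(-x))

def kan_activation(x, coeffs, knots, k, w=1.0):
    """phi(x) = w * (SiLU(x) + sum_i c_i B_i(x)), with one coefficient per basis function."""
    spline = sum(c * bspline_basis(i, k, knots, x) for i, c in enumerate(coeffs))
    return w * (silu(x) + spline)

G, k = 5, 3
knots = make_knots(G, k)
coeffs = [0.0] * (G + k)  # G + k basis functions; zero coefficients give phi(x) = SiLU(x)
print(kan_activation(0.3, coeffs, knots, k))
```

With this knot layout there are exactly \(G + k\) basis functions per activation, which is where the \(G + k\) factor in the parameter count comes from.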

S2) Initialization

$$\text{Initialize }c_{i} \overset{\mathrm{iid}}{\sim} N(0,\sigma^{2})$$

Each scale \(w\) is initialized using Xavier initialization, and each activation function \(\phi\) is initialized such that \(\text{Spline}(x) \approx 0\). This is done by sampling each coefficient \(c_i\) from the distribution \(N(0, \sigma^2)\), with \(\sigma = 0.1\) set in KAN.

S3) Spline Grid Update

Each spline grid is updated during training according to its input activations, because a spline is defined on a bounded region while the activation values can drift outside that fixed range.

In this way, a KAN of depth \(L\) and width \(N\), with splines of order \(k\) on \(G\) intervals, has a total of \(N^{2}L(G+k) \sim N^{2}LG\) parameters. While KAN may seem to have more parameters than an MLP with \(N^2L\) parameters, the \(N\) required for MLP is different from the \(N\) required for KAN. Ultimately, since KAN requires a smaller \(N\), the difference in parameter count is reduced.
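The parameter counts quoted above can be compared with a few lines of arithmetic. The widths, depth, and grid size below are hypothetical values picked only to show how a narrower KAN can end up with fewer parameters than a wider MLP.

```python
def kan_param_count(N, L, G, k):
    """N^2 * L * (G + k): each of the N * N edges per layer holds G + k spline coefficients."""
    return N * N * L * (G + k)

def mlp_param_count(N, L):
    """N^2 * L: one scalar weight per edge, biases ignored."""
    return N * N * L

# Hypothetical sizes: a narrow KAN versus a wider MLP of the same depth.
kan_params = kan_param_count(N=10, L=2, G=5, k=3)
mlp_params = mlp_param_count(N=40, L=2)
print(kan_params, mlp_params)
```

At equal width the KAN is larger by the factor \(G + k\), so the comparison only favors KAN when the required width really is smaller.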

In conclusion, the shape of the activation function is ultimately determined by the \(c_i\) that make up the spline. Since each basis function \(B_i(x)\) is fixed, the final form of the function depends on the values of \(c_i\), which determine how strongly each \(B_i(x)\) contributes. Therefore, apart from the scale \(w\), the weights learned by KAN are ultimately the \(c_i\).


3) Approximation for Deep KAN

When defining a multivariate function \(f(x)\) with 4 variables as follows, it can be represented by a 3-layer [4, 2, 1, 1] KAN.

$$f(x_{1},x_{2},x_{3},x_{4})=\exp(\sin(x_{1}^{2}+x_{2}^{2})+\sin(x_{3}^{2}+x_{4}^{2}))$$

In this case, each layer is computed as \(x^2, \sin(x), \exp(x)\). However, in the case of a 2-layer KAN, it is impossible to express the activation functions as simply as \(x^2\), \(\sin(x)\), or \(\exp(x)\). Based on this, we can see that as the layers of KAN increase in depth, the representation of the activation function corresponding to each layer becomes simpler.
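The layer-by-layer structure of this [4, 2, 1, 1] KAN can be checked numerically. In the sketch below the edge functions are written by hand rather than learned, and edges that do not appear (for example from \(x_3\) to the first hidden node) are taken to be the zero function.

```python
import math

def f_direct(x1, x2, x3, x4):
    """The target function evaluated directly."""
    return math.exp(math.sin(x1**2 + x2**2) + math.sin(x3**2 + x4**2))

def f_kan(x1, x2, x3, x4):
    """The same function as three KAN layers of univariate functions and sums."""
    # Layer 0 (4 -> 2): active edges apply t -> t^2, each node sums its inputs
    h0 = [x1**2 + x2**2, x3**2 + x4**2]
    # Layer 1 (2 -> 1): both edges apply sin, the node sums them
    h1 = math.sin(h0[0]) + math.sin(h0[1])
    # Layer 2 (1 -> 1): the single edge applies exp
    return math.exp(h1)

point = (0.2, -0.4, 0.9, 0.1)
print(f_direct(*point), f_kan(*point))
```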

(KAT) Given input data \(\mathbf{x}\in \mathbb{R}^{n}\), let the function \(f(\mathbf{x})\) be represented by KAN as follows.

$$f = (\Phi_{L-1} \circ \cdots \circ \Phi_{0}) \mathbf{x}$$

Then, there exists a constant \(C\) such that the following approximation bound holds.

$$\left\Vert f - (\Phi_{L-1}^{G} \circ \cdots \circ \Phi_{0}^{G}) \mathbf{x}\right\Vert_{C^{m}} \le C G^{-k-1+m} \tag{KAT}$$

Here, \(G\) is the grid size (number of intervals) of the B-spline, \(k\) is the order of the spline, and \(m\) is the highest derivative order included in the norm, which is defined as follows.

$$\left\Vert g\right\Vert_{C^{m}} = \max_{\vert\beta\vert\le m} \sup_{x\in[0,1]^{n}} \left\vert D^{\beta}g(x)\right\vert$$


The core of KAT is that the approximation bound does not depend on the input dimension \(n\). In other words, any function that admits such a representation can be approximated well by refining the spline grid. The paper suggests that using \(k=3\), i.e., a cubic spline, is appropriate, while making the grid size \(G\) too fine can lead to overfitting, where the test loss starts to rise again. If there are \(n_\mathrm{train}\) training data points and the total number of parameters is \(mG\) (here \(m\) counts parameters per grid interval, not the derivative order above), it is recommended to keep \(G \leq \dfrac{n_\mathrm{train}}{m}\) to avoid this problem (interpolation threshold).
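Two quick calculations make these statements more tangible: how much the bound \(C G^{-k-1+m}\) shrinks when the grid is refined, and the largest grid size that keeps the parameter count at or below the number of training points. The sample numbers (1000 training points, 15 parameters per grid interval) are illustrative, and the function names are hypothetical.

```python
def bound_shrink_factor(G_coarse, G_fine, k=3, m=0):
    """Ratio of C * G_coarse^(-k-1+m) to C * G_fine^(-k-1+m); the constant C cancels."""
    return (G_fine / G_coarse) ** (k + 1 - m)

def interpolation_threshold(num_samples, params_per_grid_interval):
    """Largest integer G with params_per_grid_interval * G <= num_samples."""
    return num_samples // params_per_grid_interval

print(bound_shrink_factor(5, 10))         # cubic spline, function values (m = 0)
print(interpolation_threshold(1000, 15))  # illustrative sample and parameter counts
```

With \(k = 3\) and \(m = 0\), doubling \(G\) shrinks the bound by a factor of \(2^4 = 16\), while the threshold caps how far that refinement should go for a given dataset size.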


Interpretability of KAN

KAN is described as interpretable, unlike MLP. However, the internal structure of KAN alone does not guarantee interpretability, so the paper presents an additional procedure to obtain it.

[Figure: the process of making KAN interpretable]

This figure shows the process of making KAN interpretable. As a simple example, suppose there are data generated from \(f(x, y) = \exp(\sin(\pi x) + y^2)\) with a 2D input.

If the function \(f\) is known beforehand, it can be represented by a [2, 1, 1]-KAN, but in the current situation, this is unknown. In this case, a sufficiently large KAN is initially assumed, and then the model is trained based on sparsity regularization to prune unnecessary nodes and edges, representing the original function.

S1) Sparsification

The first step is training with sparsification, which prepares for the subsequent pruning step. Pruning needs a way to tell which weights matter, so training should keep only the most important parameters large while driving the others toward zero. To achieve this, L1 regularization and entropy regularization are added to the loss function.

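  • L1 regularization: the L1 norm of a layer is the sum of the L1 norms of its activation functions, where \(\left\vert \phi_{i,j} \right\vert_{1}\) is the average magnitude of \(\phi_{i,j}\) over the input samples.
$$\left\vert \Phi \right\vert_{1} := \sum_{i=1}^{n_\mathrm{in}} \sum_{j=1}^{n_\mathrm{out}} \left\vert \phi_{i,j} \right\vert_{1}$$
  • Entropy regularization: the entropy of the normalized edge norms, written with a leading minus sign so that it is non-negative and small when only a few edges dominate.
$$S(\Phi) := -\sum_{i=1}^{n_\mathrm{in}} \sum_{j=1}^{n_\mathrm{out}} \frac{\left\vert \phi_{i,j} \right\vert_{1}}{\left\vert \Phi \right\vert_{1}} \log \left( \frac{\left\vert \phi_{i,j} \right\vert_{1}}{\left\vert \Phi \right\vert_{1}} \right)$$
  • Total loss: the prediction loss plus both regularizers summed over all layers, where \(\lambda\) sets the overall regularization strength and \(\mu_1, \mu_2\) weight the two terms.
$$l_\mathrm{total} = l_\mathrm{pred} + \lambda \left( \mu_{1} \sum_{l=0}^{L-1} \left\vert \Phi_{l} \right\vert_{1} + \mu_{2} \sum_{l=0}^{L-1} S(\Phi_{l}) \right)$$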
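The three quantities above can be computed directly from a table of per-edge L1 norms. In this sketch each layer is given as a nested list of \(\left\vert \phi_{i,j} \right\vert_{1}\) values, the entropy is written with the conventional leading minus sign so that it is non-negative, and the default values of `lam`, `mu1`, and `mu2` are placeholders, not recommended settings.

```python
import math

def l1_norm(edge_norms):
    """|Phi|_1: sum of the per-edge L1 norms of one layer."""
    return sum(sum(row) for row in edge_norms)

def entropy(edge_norms):
    """S(Phi) = -sum p * log(p), with p = |phi_ij|_1 / |Phi|_1 (zero entries skipped)."""
    total = l1_norm(edge_norms)
    ps = [v / total for row in edge_norms for v in row if v > 0]
    return -sum(p * math.log(p) for p in ps)

def total_loss(pred_loss, layers, lam=0.01, mu1=1.0, mu2=1.0):
    """l_total = l_pred + lam * sum over layers of (mu1 * |Phi_l|_1 + mu2 * S(Phi_l))."""
    reg = sum(mu1 * l1_norm(n) + mu2 * entropy(n) for n in layers)
    return pred_loss + lam * reg

dense = [[1.0, 1.0], [1.0, 1.0]]   # all edges equally active
sparse = [[4.0, 0.0], [0.0, 0.0]]  # same L1 norm, one dominant edge
print(entropy(dense), entropy(sparse))
```

Both layers have the same L1 norm, but the sparse one has lower entropy, which is why the entropy term pushes training toward a few dominant edges.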

S2) Pruning

The second step is pruning, which is performed based on the magnitudes learned in the previous step; sparsification was introduced for this purpose. In the pruning step, nodes and edges with small magnitudes are removed so that only the important branches remain. In the figure above, once pruning is complete, each input is connected to the output by a single path. In other words, the process of obtaining the output value from the inputs is now represented by only a few univariate functions and additions.
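A minimal sketch of magnitude-based pruning for one hidden layer is shown below. It assumes a simplified importance rule: a hidden node is kept only if its strongest incoming edge and its strongest outgoing edge both exceed a threshold. The rule, the threshold value, and the edge norms are assumptions for illustration and may differ from the exact criterion in the paper.

```python
def prune_hidden_nodes(incoming, outgoing, threshold=0.01):
    """Return indices of hidden nodes to keep.

    incoming[i]: L1 norms of the edges entering hidden node i.
    outgoing[i]: L1 norms of the edges leaving hidden node i.
    A node survives only if both its largest incoming and largest
    outgoing norm exceed the threshold (assumed rule).
    """
    return [
        i
        for i, (inc, out) in enumerate(zip(incoming, outgoing))
        if max(inc) > threshold and max(out) > threshold
    ]

# Illustrative norms for a [2, 5, 1] KAN after sparsified training.
incoming = [[0.002, 0.001], [0.004, 0.003], [0.9, 0.7], [0.001, 0.005], [0.003, 0.002]]
outgoing = [[0.001], [0.006], [1.2], [0.002], [0.004]]
kept = prune_hidden_nodes(incoming, outgoing)
print(kept)
```

Here only one of the five hidden nodes survives, which turns the [2, 5, 1] network into a [2, 1, 1] network.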

S3) Expression


Finally, the obtained activation functions are expressed in a form that we can understand, such as \(x^2, \sin(x), \exp(x)\).
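One simple way to match a learned activation to a known formula is to sample it and score a few candidate functions by how well an affine transform of each candidate fits the samples. The sketch below does this with the \(R^2\) of the fit \(y \approx a\,g(x) + b\). The candidate list and function names are hypothetical, and this version does not fit a shift or scale inside \(g\), so it is a simplification of a full symbolic regression step.

```python
import math

def r_squared(us, ys):
    """R^2 of the best affine fit ys ~ a * us + b (squared Pearson correlation)."""
    n = len(us)
    mu, my = sum(us) / n, sum(ys) / n
    suy = sum((u - mu) * (y - my) for u, y in zip(us, ys))
    suu = sum((u - mu) ** 2 for u in us)
    syy = sum((y - my) ** 2 for y in ys)
    if suu == 0 or syy == 0:
        return 0.0
    return suy * suy / (suu * syy)

# Hypothetical candidate library of simple symbolic forms.
CANDIDATES = {"x^2": lambda x: x * x, "sin": math.sin, "exp": math.exp}

def best_symbolic(xs, ys):
    """Return the name of the best-scoring candidate and all candidate scores."""
    scores = {name: r_squared([g(x) for x in xs], ys) for name, g in CANDIDATES.items()}
    return max(scores, key=scores.get), scores

xs = [-1.0 + 0.1 * i for i in range(21)]
ys = [3.0 * math.sin(x) + 0.5 for x in xs]  # stand-in for a sampled learned activation
name, scores = best_symbolic(xs, ys)
print(name, scores)
```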

However, the interpretability proposed by KAN is based on a rather simplistic assumption that the activation functions can be easily expressed in such simple forms, so further research is needed to evaluate this assumption.

Pros and Cons of KAN Network

  • Pros
    • Enhanced expressiveness
    • Proposal of interpretability (partially solving the black-box problem)
    • Efficient training
  • Cons
    • Complex implementation: different activation functions must be used for each node
    • Although interpretability is proposed through methods such as sparsification, pruning, and symbolification, this leads to high computational costs
    • Lack of experimental validation

Reference

[1] Blog: KAN : Kolmogorov-Arnold Network
[2] Blog: [Paper Review] KAN: Kolmogorov–Arnold Networks
[3] Paper: KAN: Kolmogorov-Arnold Networks
[4] Github: Awesome KAN(Kolmogorov-Arnold Network)
