[NLP] What is Mixture-of-Experts (MoE)?

What is MoE?

Recent LLM studies suggest that one of the key factors in model performance is model size, which is determined by the number of learnable parameters. When computing resources are limited, training a Large Language Model (LLM) for fewer steps is better than training a smaller Language Model (LM) for more steps.

Mixture of Experts (MoE) allows for pretraining with fewer computational resources, making it possible to scale up the model or dataset size while staying within the same resource constraints. In particular, MoE models can reach the same level of performance as dense models in a much shorter time during pretraining.

Concepts of MoE

1) Sparse MoE layers

  • In a transformer, the dense feed-forward network (FFN), a fully connected network in which every input unit is connected to every output unit, is a key component.
  • A sparse MoE layer replaces this single FFN with \(N\) separate FFNs, each of which is referred to as an ‘expert’. Because the computation is split this way, each expert is responsible only for the tokens routed to it.
Figure: MoE layer, from [1] Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity
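
To make the idea of experts a bit more concrete, here is a minimal sketch that builds a handful of experts and compares how many parameters the layer stores with how many a token touches when it visits a single expert. The helper names `make_expert` and `count_params` and the sizes (dim=16, hidden=32, 4 experts) are assumptions chosen for illustration; they are not taken from the Switch Transformers paper.

```python
from torch import nn


def make_expert(dim: int, hidden: int) -> nn.Module:
    """Hypothetical expert: a plain two-layer feed-forward network."""
    return nn.Sequential(nn.Linear(dim, hidden), nn.ReLU(), nn.Linear(hidden, dim))


def count_params(module: nn.Module) -> int:
    """Total number of learnable parameters in a module."""
    return sum(p.numel() for p in module.parameters())


# Illustrative sizes only
dim, hidden, num_experts = 16, 32, 4
experts = nn.ModuleList([make_expert(dim, hidden) for _ in range(num_experts)])

total_params = count_params(experts)      # parameters stored in the MoE layer
active_params = count_params(experts[0])  # parameters used when a token visits one expert
print(total_params, active_params)
```

The stored parameter count grows with the number of experts, while the per-token compute depends only on how many experts that token is sent to.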

2) Gate network or router

  • The gating mechanism, often referred to as the router, is responsible for determining which expert should handle each token. For instance, in the figure above, the router assigns the token “More” to the second expert, while directing the token “Parameter” to the first expert.
  • The main role of this router is to distribute the incoming tokens to the most appropriate experts. As we’ll cover in more detail later, it’s possible for a single token to be routed to multiple experts simultaneously. A key challenge in Mixture of Experts (MoE) models is figuring out the optimal strategy for assigning tokens to experts.
  • The router itself consists of learnable parameters that are trained alongside the rest of the network’s components during the pre-training phase.

What is ‘Sparsity’?

Sparsity can be described through the weight graph in deep learning models. Typically, transformers are dense models in the sense that every weight takes part in the forward pass for every input, so all weights also receive updates during backpropagation.

Mathematically speaking, ‘Sparsity’ utilizes the concept of conditional computation. In dense models, all parameters \(W_i\) are used for every input, resulting in a calculation such as \(f(x) = W_i x\). However, in sparse models, only a subset of the parameters is activated depending on the input, and the operation becomes \(f(x) = W_{selected} x\). This approach allows the model to scale efficiently without a proportional increase in computational cost.
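
The difference between the dense and the conditional computation can be sketched in a few lines. This is only an illustration under assumed names: the parameter blocks `W`, the selection weights `W_g`, and the argmax selection rule are hypothetical stand-ins for whatever a real model uses.

```python
import torch

torch.manual_seed(0)
dim, num_blocks = 4, 3
W = torch.randn(num_blocks, dim, dim)  # hypothetical parameter blocks W_0..W_2
W_g = torch.randn(dim, num_blocks)     # hypothetical selection weights
x = torch.randn(dim)

# Dense computation: every block W_i is multiplied with the input.
dense_out = sum(W[i] @ x for i in range(num_blocks))

# Conditional computation: the input itself decides which block runs.
selected = int((x @ W_g).argmax())
sparse_out = W[selected] @ x  # f(x) = W_selected x

# Matrix-vector products needed in each case
matmuls_dense, matmuls_sparse = num_blocks, 1
```

Adding more blocks increases `matmuls_dense` but leaves `matmuls_sparse` unchanged, which is the point of conditional computation.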

Shazeer’s work applying Mixture of Experts (MoE) to translation tasks is a great example of this conditional computation. In MoE, only a subset of experts is activated for a given input, enabling the model to grow in size while minimizing computational overhead. Thanks to this method, anywhere from dozens to even thousands of experts can be employed in MoE layers.

First, assuming there are 8 experts, routing a token to one expert can be implemented quite simply in code. By combining a Linear layer + Softmax, the specific expert can be easily determined, and the token can be straightforwardly forwarded to the selected expert.

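import torch
import torch.nn as nn
import torch.nn.functional as F

# inputs is a tensor with dim features (a single token; random values for illustration)
dim, num_experts = 512, 8
inputs = torch.randn(dim)
gate = nn.Linear(dim, num_experts)
# index of the expert with the highest gate probability
nth_expert = F.softmax(gate(inputs), dim=-1).argmax(dim=-1)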

However, this structure can lead to several issues, which become even more pronounced when gating across K experts. In a standard transformer model, all data contributes to parameter updates, maintaining a consistent batch size.

On the other hand, in MoE, data is allocated differently to each expert, causing the batch size for each expert to vary and often become smaller than the fixed batch size. For instance, if 10 tokens are provided, 5 may be assigned to one expert, while the remaining 5 are distributed across multiple experts. In such cases, the batch sizes could become (5, 2, 1, 1, 1), resulting in imbalanced batch sizes and inefficient use of resources.

Furthermore, as tokens are routed to multiple experts, some experts may become overloaded while others remain underutilized, leading to inefficient resource utilization and potentially degraded training performance.
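
The imbalance is easy to see by counting how many tokens land on each expert. The assignments below are made up to mirror the (5, 2, 1, 1, 1) example above; in a real model they would come from the router.

```python
import torch

num_experts = 8
# Hypothetical top-1 expert index for each of 10 tokens
assignments = torch.tensor([0, 0, 0, 0, 0, 3, 3, 1, 5, 7])

# Effective batch size seen by each expert
tokens_per_expert = torch.bincount(assignments, minlength=num_experts)
print(tokens_per_expert.tolist())  # [5, 1, 0, 2, 0, 1, 0, 1]
```

Expert 0 processes half of the tokens, while experts 2, 4, and 6 receive none and get no training signal from this batch.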

To address this imbalance, the gate network (\(G\)) adopts the ‘Top-K Gating’ method. First, the trained gating network decides which experts (\(E\)) each input should be sent to, and the layer output is the gate-weighted sum of the expert outputs.

$$y = \displaystyle\sum_{i=1}^n G(x)_iE_i(x)$$

As written, this sum evaluates every expert for every input. However, whenever \(G(x)_i\) is 0, the corresponding \(E_i(x)\) does not need to be computed, which saves computational resources. A simple choice for the gating function, as mentioned earlier, is a Softmax function.

$$G_{\sigma}(x) = \text{Softmax}(x \cdot W_g)$$
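
The two formulas above can be sketched as follows. This is a minimal illustration, assuming each expert is a single linear layer and using small made-up sizes; with a plain Softmax gate no weight is exactly zero, so every expert still has to run.

```python
import torch
import torch.nn.functional as F
from torch import nn

torch.manual_seed(0)
dim, num_experts, num_tokens = 8, 4, 3  # illustrative sizes
experts = nn.ModuleList([nn.Linear(dim, dim) for _ in range(num_experts)])
W_g = nn.Linear(dim, num_experts, bias=False)  # plays the role of W_g

x = torch.randn(num_tokens, dim)

# G(x) = Softmax(x · W_g), one weight per expert for every token
G = F.softmax(W_g(x), dim=-1)                              # (num_tokens, num_experts)

# E_i(x) for every expert, stacked along a new expert axis
expert_outs = torch.stack([E(x) for E in experts], dim=1)  # (num_tokens, num_experts, dim)

# y = sum_i G(x)_i * E_i(x)
y = (G.unsqueeze(-1) * expert_outs).sum(dim=1)             # (num_tokens, dim)
```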

As the name suggests, the ‘Top-K Gating’ method keeps only the top K values. In this process, Shazeer applied a ‘Noisy Top-K Gating’ approach, where noise is added.

  1. First, noise is added to the gating function.
$$H(x)_i = (x \cdot W_g)_i \;+ \; \text{StandardNormal()} \cdot \text{Softplus}((x \cdot W_{\text{noise}})_i)$$
  2. Then, the Top-K values are selected.
$$\text{KeepTopK}(v, k)_i = \begin{cases} v_i & \text{if } v_i \text{ is in the top } k \text{ elements of } v, \\ -\infty & \text{otherwise.} \end{cases}$$
  3. Finally, by applying the Softmax function, a new gating function is defined.
$$G(x) = \text{Softmax}(\text{KeepTopK}(H(x), k))$$
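
The three steps above translate fairly directly into PyTorch. The function below is a sketch under assumptions rather than Shazeer's original code: the name `noisy_top_k_gating`, the random weights, and the sizes are all chosen for illustration.

```python
import torch
import torch.nn.functional as F


def noisy_top_k_gating(x, W_g, W_noise, k):
    """Return gate values G(x) with k non-zero entries per row."""
    # Step 1: H(x) = x·W_g + StandardNormal() * Softplus(x·W_noise)
    clean_logits = x @ W_g
    noise_scale = F.softplus(x @ W_noise)
    H = clean_logits + torch.randn_like(clean_logits) * noise_scale

    # Step 2: KeepTopK sets every entry outside the top k to -inf
    top_vals, top_idx = torch.topk(H, k, dim=-1)
    kept = torch.full_like(H, float("-inf")).scatter(-1, top_idx, top_vals)

    # Step 3: Softmax maps the -inf entries to exact zeros
    return F.softmax(kept, dim=-1)


torch.manual_seed(0)
num_tokens, dim, num_experts, k = 5, 16, 8, 2
x = torch.randn(num_tokens, dim)
W_g = torch.randn(dim, num_experts) * 0.1      # illustrative random weights
W_noise = torch.randn(dim, num_experts) * 0.1  # illustrative random weights
G = noisy_top_k_gating(x, W_g, W_noise, k)
```

Each row of `G` sums to 1 and has non-zero weight only on the k selected experts, so only those experts need to be evaluated for that token.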

By using a low enough k (e.g. one or two), training and inference can be performed much faster than if many experts were activated. Shazeer et al. conjectured that routing to more than one expert is necessary for the gate to learn how to route to different experts, so at least two experts were selected. (Switch Transformers [1] later showed that routing to a single expert can also work.)

The concern is that, if only one expert is selected, the gate may collapse into routing most inputs to the same expert. Inputs would then be directed to one specific expert, preventing the other experts from participating in the learning process. As a result, the model becomes biased toward certain patterns, fails to capture diversity, and does not fully utilize the learning capacity of all experts.

On the other hand, when at least two experts are selected, the gate can learn to route inputs to different experts. For instance, with Expert A and B, some inputs will go to A, while others will go to B, allowing both experts to participate in the learning process. By diversifying routing, the model can learn a wider range of data patterns and achieve more generalized performance.

In short, under this view, at least two experts should be selected for the gate to effectively utilize different experts and for each expert to learn according to distinct data patterns.

MoE Implementation (by Mistral-7B)

import dataclasses
from typing import List

import torch
import torch.nn.functional as F
from simple_parsing.helpers import Serializable
from torch import nn


@dataclasses.dataclass
class MoeArgs(Serializable):
    num_experts: int
    num_experts_per_tok: int


class MoeLayer(nn.Module):
    def __init__(self, experts: List[nn.Module], gate: nn.Module, moe_args: MoeArgs):
        super().__init__()
        assert len(experts) > 0
        self.experts = nn.ModuleList(experts)
        self.gate = gate
        self.args = moe_args

    def forward(self, inputs: torch.Tensor):
        # Step 1: Pass the input through the gate linear layer to route it to an expert
        gate_logits = self.gate(inputs)
        
        # Step 2: Select the Top-K experts from the gate logits
        weights, selected_experts = torch.topk(gate_logits, self.args.num_experts_per_tok)

        # Step 3: Compute weights for the Top-K experts (using softmax)
        weights = F.softmax(weights, dim=1, dtype=torch.float).to(inputs.dtype)
        results = torch.zeros_like(inputs)

        # Iterate over the N experts
        for i, expert in enumerate(self.experts):
            # Step 4: Extract tokens corresponding to the i_th expert
            batch_idx, nth_expert = torch.where(selected_experts == i)
            
            # Step 5: Pass the tokens through the i_th expert
            # Step 6: Apply the expert weight to the output of the i_th expert
            results[batch_idx] += weights[batch_idx, nth_expert, None] * expert(
                inputs[batch_idx]
            )
        return results
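
The least obvious line in the forward pass is the `torch.where` call in Step 4. The small example below shows what it returns, using a made-up `selected_experts` tensor for 3 tokens with 2 experts per token.

```python
import torch

# Hypothetical routing result: row = token, column = that token's 1st / 2nd choice
selected_experts = torch.tensor([[0, 2],
                                 [1, 0],
                                 [2, 3]])

# Positions where expert 0 was selected
batch_idx, nth_expert = torch.where(selected_experts == 0)

print(batch_idx.tolist())   # [0, 1] -> tokens 0 and 1 are routed to expert 0
print(nth_expert.tolist())  # [0, 1] -> 1st choice for token 0, 2nd choice for token 1
```

`batch_idx` picks the tokens to send through the expert, and `(batch_idx, nth_expert)` together pick the matching softmax weight for each of those tokens.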

Reference

[1] Switch Transformers: Scaling to Trillion Parameter Models with Simple and Efficient Sparsity
[2] FasterMoE: modeling and optimizing training of large-scale dynamic pre-trained models
[3] Blog: Mixture of Experts Explained
[4] Blog: What is MoE?
[5] Github: Mistral
