|
1 | 1 | --- |
2 | 2 | title: Auxiliary-loss Load Balancing in MoEs (1) |
3 | 3 | date: 2025-07-06 |
4 | | -tags: [cs, finance, ai] |
| 4 | +tags: [cs, ai, notes] |
5 | 5 | author: R |
6 | 6 | location: Above Ilulissat, Greenland while on a plane from New York to Hong Kong |
7 | 7 | --- |
|
17 | 17 | - $K$ the number of experts selected per token, |
18 | 18 | - $T$ the total number of tokens in the batch, |
19 | 19 | - $\mathbf{u}_t\in\R^d$ the input for token $t$, |
20 | | -- $\mathrm{FFN}_i: \R^d\to\R^d$ the $i$-th expert network, |
| 20 | +- $\text{FFN}_i: \R^d\to\R^d$ the $i$-th expert network, |
21 | 21 | - $e_i\in\R^d$ the centroid (parameter) of expert $i$, and |
22 | | -- $G\colon\R\to\R_{>0}$ a positive gating function (e.g. $\exp$, $\mathrm{sigmoid}$, or $\mathrm{softmax}$). |
| 22 | +- $G\colon\R\to\R_{>0}$ a positive gating function (e.g. $\exp$, $\text{sigmoid}$, or $\text{softmax}$). |
23 | 23 |
|
24 | 24 | Compute for each token $t$ and expert $i$: |
25 | 25 | $$ |
|
30 | 30 | & (\text{if $s_{i,t}$ is among the top-$K$ scores}) |
31 | 31 | \\ |
32 | 32 | 0, & \text{otherwise} |
33 | | - \end{cases}\\ |
| 33 | + \end{cases} \\ |
34 | 34 | s_{i,t} &= G(\textbf{u}_t^\top e_i) |
35 | 35 | \end{align*} |
36 | 36 | $$ |
|
40 | 40 | \textbf{h}_t = \textbf{u}_t + \sum^N_{i=1} g_{i,t} \text{FFN}_i (\textbf{u}_t) \\ |
41 | 41 | $$ |
42 | 42 |
|
43 | | -So, here $G$ could be something that is $\R \to \R_{>0}$, some conventional ones could be a $\exp$, softmax or sigmoid (TBH I have to search these up to see what they actually are). In this paper they have use the latter 2. |
| 43 | +Here $G$ can be any function $\R \to \R_{>0}$; conventional choices are $\exp$, softmax, or sigmoid (TBH I have to look the latter two up to see what they are exactly). The paper uses the latter two. |
44 | 44 |
|
45 | 45 | The selected experts are then consulted, each weighted by its gating value $g_{i,t}$. |
46 | 46 |
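A minimal sketch of the routing step above, in plain Python. It assumes $G$ is the sigmoid and represents each $\text{FFN}_i$ as an ordinary function on lists; the names `route_token` and `moe_layer` are made up for this note and are not from the paper.

```python
import math


def sigmoid(x):
    """One possible positive gating function G."""
    return 1.0 / (1.0 + math.exp(-x))


def route_token(u, centroids, k, gate=sigmoid):
    """Return g_i for one token: s_i if expert i is in the top-k, else 0."""
    # s_i = G(u . e_i) for every expert centroid e_i
    scores = [gate(sum(a * b for a, b in zip(u, e))) for e in centroids]
    top = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]
    chosen = set(top)
    return [s if i in chosen else 0.0 for i, s in enumerate(scores)]


def moe_layer(u, centroids, experts, k, gate=sigmoid):
    """h = u + sum_i g_i * FFN_i(u), evaluating only the selected experts."""
    g = route_token(u, centroids, k, gate)
    h = list(u)
    for g_i, ffn in zip(g, experts):
        if g_i != 0.0:  # unselected experts are skipped entirely
            out = ffn(u)
            h = [h_j + g_i * o_j for h_j, o_j in zip(h, out)]
    return h
```

With `k=1`, only the expert whose centroid is best aligned with `u` contributes to `h`; every other expert is never evaluated, which is where the compute saving of an MoE comes from.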
|
47 | | -### Problem: Inbalanced routing |
| 47 | +## Problem: Imbalanced routing |
48 | 48 | But one problem MoEs often experience is imbalanced routing (a small number of experts receive most of the tokens), which brings *a risk of routing collapse (Shazeer et al., 2017), where the model consistently selects only a few experts, hindering sufficient training of the other experts*; or a *computational bottleneck by load imbalance*. |
49 | 49 |
|
50 | 50 | I was wondering how it could cause a computational bottleneck, but then I realized that my assumption, that it could easily scale through parallelism or some other means, is not easily achievable. Since the experts are hosted on different machines, the speed depends more on the load given to a particular expert. |
51 | 51 |
|
52 | 52 | Plus, the training loop would need a substantial redesign to put the idle compute to use. Even if I create replicas of the "hot" experts on more devices, they need to stay in sync, which is costly in itself. Merging gradients across replicas requires collective operations every step, so if one replica slows down it just recreates the problem I was trying to overcome... |
53 | 53 |
|
54 | | -#### Solution: Auxiliary-loss |
55 | | -To address this issue, there is auxiliary-loss encourage balanced load thus avoids inbalanced routing in training MoEs. To do this, it penalized the use of only a few number of agents. Its mostly within the process of the gating function. Defined as such: |
| 54 | +### Solution: Auxiliary-loss |
| 55 | +To address this issue, an auxiliary loss encourages a balanced load and thus avoids imbalanced routing when training MoEs. To do this, it penalizes routing that concentrates on only a few experts. It mostly acts on the gating function. |
56 | 56 |
|
| 57 | +**Key variables:** |
| 58 | +- $N$: number of experts in the MoE layer |
| 59 | +- $K$: number of experts selected per token (top-K) |
| 60 | +- $T$: total number of tokens in the batch |
| 61 | +- $\mathbb{1}$: indicator function (equals 1 if condition is true, 0 otherwise) |
| 62 | +- $\alpha$: balancing‐loss weight (manually set hyperparameter) |
57 | 63 |
|
| 64 | +Defined as such: |
58 | 65 |
|
59 | | -- **Normalized load** |
| 66 | +- **Normalized load** |
| 67 | + $f_i$ := the fraction of tokens routed to expert $i$, scaled by $N/K$ so that a perfectly uniform load gives $f_i = 1$: |
60 | 68 | $$ |
61 | | - f_i = \frac{N}{KT} \sum_{t=1}^T \mathbb{1} (i \in \mathrm{Topk} \mid \mathbf{u}_t ) |
62 | | - \quad (\text{fraction of tokens routed to expert }i ). |
| 69 | + f_i = \frac{N}{KT} \sum_{t=1}^T \mathbb{1} (i \in \text{Topk} \mid \mathbf{u}_t ) |
63 | 70 | $$ |
64 | | -- **Average gating weight** |
| 71 | + |
| 72 | +- **Average gating weight** |
| 73 | + $P_i$:= the mean score assigned by the gate to expert $i$: |
65 | 74 | $$ |
66 | | - P_i = \frac{1}{T}\sum_{t=1}^T s_{i,t} |
67 | | - \quad (\text{mean score assigned by the gate to expert }i ). |
| 75 | + P_i = \frac{1}{T} \sum_{t=1}^T s_{i,t} |
68 | 76 | $$ |
69 | | -` |
70 | | -##### Balance loss |
71 | 77 |
|
72 | 78 | Combine these into a single penalty term: |
| 79 | + $$ |
| 80 | + \mathcal{L}_{\text{balance}} = \alpha \sum_{i=1}^N f_i P_i. |
| 81 | + $$ |
| 82 | + |
| 83 | + |
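To make $f_i$, $P_i$ and the penalty concrete, here is a small sketch that computes them from a table of scores. It assumes `scores[t][i]` already holds $s_{i,t}$; the function name `balance_loss` and any value passed for `alpha` are illustrative only.

```python
def balance_loss(scores, k, alpha):
    """Return (f, P, L_balance) for scores[t][i] = s_{i,t}."""
    T = len(scores)
    N = len(scores[0])
    counts = [0] * N        # how many tokens picked expert i in their top-k
    score_sums = [0.0] * N  # running sum of s_{i,t} over tokens
    for row in scores:
        top = sorted(range(N), key=lambda i: row[i], reverse=True)[:k]
        for i in top:
            counts[i] += 1
        for i in range(N):
            score_sums[i] += row[i]
    f = [N / (k * T) * c for c in counts]  # equals 1.0 under uniform load
    P = [s / T for s in score_sums]        # mean gate score per expert
    loss = alpha * sum(f_i * p_i for f_i, p_i in zip(f, P))
    return f, P, loss
```

For two experts and `k=1`, a batch that alternates between the experts gives `f = [1.0, 1.0]`, while a batch that always picks expert 0 gives `f = [2.0, 0.0]` and a larger penalty, because the large $f_0$ multiplies the large $P_0$.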
| 84 | +**Regularization terms:** |
| 85 | +Introduce two small-weight penalties on the imbalance of $\{P_i\}$ and $\{f_i\}$: |
| 86 | +$$ |
| 87 | +\begin{align*} |
| 88 | + \mathcal{L}_{P} &= \lambda_{P} \text{CV}^2 (\{P_i\} )\\ |
| 89 | + \mathcal{L}_{f} &= \lambda_{f} \text{CV}^2 (\{f_i\} ) |
| 90 | +\end{align*} |
| 91 | +$$ |
| 92 | +where typically $\lambda_{P} \approx\lambda_{f} \sim10^{-2}$. |
| 93 | + |
| 94 | +> This is actually optional; for a simpler setup, just use $\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{task}} + \mathcal{L}_{\text{balance}}$. |
| 95 | +> I write it this way just to follow the [original MoE auxiliary-loss formulation paper (Shazeer et al. (2017))](https://arxiv.org/pdf/1701.06538). |
| 96 | +
|
| 97 | + |
| 98 | +**Imbalance metric: coefficient of variation squared** |
| 99 | +For any set of scalars $\{z_i\}_{i=1}^N$, define |
| 100 | +$$ |
| 101 | + \text{CV}^2(\{z_i\}) = |
| 102 | + \frac{\frac{1}{N} \sum_{i=1}^N z_i^2 - (\frac{1}{N} \sum_{i=1}^N z_i )^2} |
| 103 | + {(\frac{1}{N} \sum_{i=1}^N z_i )^2}, |
| 104 | +$$ |
| 105 | +which satisfies $\text{CV}^2=0$ exactly when all $z_i$ are equal. |
| 106 | + |
| 107 | +> Btw, this looks a lot like the variance. |
| 108 | +> Write $\mu = \tfrac1N \sum_i z_i$ and $\nu = \tfrac1N \sum_i z_i^2$. Then |
| 109 | +> $$ |
| 110 | +> \text{CV}^2 = \frac{\nu - \mu^2}{\mu^2} = \frac{\text{Var}}{(\text{Mean})^2} |
| 111 | +> $$ |
| 112 | +> Its partial derivative w.r.t. one coordinate $z_k$ is |
| 113 | +> $$ |
| 114 | +> \frac{\partial \text{CV}^2}{\partial z_k} |
| 115 | +> = \frac{2}{N}\Bigl(\frac{z_k}{\mu^2} - \frac{\nu}{\mu^3}\Bigr). |
| 116 | +> $$ |
| 117 | +> > Details: |
| 118 | +> > $$ |
| 119 | +> > \begin{align*} |
| 120 | +> > \frac{\partial}{\partial z_k} (\tfrac{\nu - \mu^2}{\mu^2}) |
| 121 | +> > &= \frac{1}{\mu^2} \Bigl(\frac{\partial\nu}{\partial z_k} - 2\mu \frac{\partial\mu}{\partial z_k}\Bigr) |
| 122 | +> > - \frac{\nu - \mu^2}{\mu^4} 2\mu \frac{\partial\mu}{\partial z_k}\\ |
| 123 | +> > &= \frac{1}{\mu^2} \Bigl(\frac{2z_k}{N} - \frac{2\mu}{N}\Bigr) |
| 124 | +> > - \frac{\nu - \mu^2}{\mu^4} \frac{2\mu}{N}\\ |
| 125 | +> > &= \frac{2}{N}\Bigl(\frac{z_k}{\mu^2} - \frac{\nu}{\mu^3}\Bigr). |
| 126 | +> > \end{align*} |
| 127 | +> > $$ |
| 128 | +> |
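| 129 | +> Because $\nu/\mu^3$ is the same constant for all $k$, the sign of this gradient depends only on $z_k$: gradient descent pushes down any $z_k > \nu/\mu$ (overloaded expert) and pushes up any $z_k < \nu/\mu$ (underloaded expert). Since $\nu/\mu = \mu(1 + \text{CV}^2)$, the threshold sits at the mean when the load is balanced and slightly above it otherwise. In other words, this is the derivative of a variance term normalized by $\mu^2$. |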
| 130 | +
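As a sanity check on the derivative above, here is a short sketch that compares the closed form against a central finite difference. The helper names (`cv_squared`, `cv_squared_grad`, `numeric_grad`) and the step size `eps` are my own choices for illustration.

```python
def cv_squared(z):
    """CV^2 = (nu - mu^2) / mu^2 with mu = mean(z), nu = mean(z^2)."""
    n = len(z)
    mu = sum(z) / n
    nu = sum(x * x for x in z) / n
    return (nu - mu * mu) / (mu * mu)


def cv_squared_grad(z):
    """Closed form: dCV^2/dz_k = (2/N) * (z_k / mu^2 - nu / mu^3)."""
    n = len(z)
    mu = sum(z) / n
    nu = sum(x * x for x in z) / n
    return [2.0 / n * (x / mu**2 - nu / mu**3) for x in z]


def numeric_grad(fn, z, eps=1e-6):
    """Central finite-difference approximation of the gradient of fn at z."""
    grads = []
    for k in range(len(z)):
        up = list(z)
        down = list(z)
        up[k] += eps
        down[k] -= eps
        grads.append((fn(up) - fn(down)) / (2 * eps))
    return grads
```

For a load vector such as `[3.0, 1.0, 1.0, 1.0]`, the first entry of `cv_squared_grad` is positive and the rest are negative, so a gradient-descent step lowers the overloaded coordinate and raises the others.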
|
| 131 | + |
| 132 | +**Total training objective** |
| 133 | +Combine with the primary task loss $\mathcal{L}_{\text{task}}$: |
73 | 134 | $$ |
74 | | - \mathcal{L}_{\mathrm{balance}} |
75 | | - = |
76 | | - \alpha \sum_{i=1}^N f_i P_i. |
| 135 | + \mathcal{L}_{\text{total}} |
| 136 | + = \mathcal{L}_{\text{task}} |
| 137 | + + \mathcal{L}_{P} |
| 138 | + + \mathcal{L}_{f}. |
77 | 139 | $$ |
78 | 140 |
|
79 | | -##### Why this encourages balanced routing |
| 141 | +#### Intuition |
| 142 | +- The penalty $f_i P_i$ grows as either $f_i$ or $P_i$ grows (since it's a product). Backpropagating it through the gating parameters reduces routing to over-used experts and increases routing to under-used ones, driving the routing distribution toward uniformity. |
| 143 | +- Minimizing $\text{CV}^2$ drives the variance of $\{P_i\}$ or $\{f_i\}$ (the importance $\text{Imp}_i$ and load $\text{Load}_i$ in Shazeer et al.'s terms) toward zero relative to their mean (see the derivation of $\partial \text{CV}^2/\partial z_k$ above). |
| 144 | +- Any expert $i$ with above-average usage raises its own $P_i$ or $f_i$, increasing the penalty. |
80 | 145 |
|
81 | | -- $f_i$ captures how heavily expert $i$ is used, while $P_i$ captures its average gate score. |
82 | | -- If an expert is over-selected ($f_i$ large), the product $f_iP_i$ grows, increasing the penalty. |
83 | | -- Gradients then adjust the gating parameters to **decrease** routing to over-used experts and **increase** routing to under-used ones, driving the distribution toward uniformity. |
84 | 146 |
|
85 | | -##### Why this balances load |
| 147 | +#### Drawbacks |
| 148 | +The ICLR 2025 paper mentions that the auxiliary loss might introduce unwanted gradients, as the MoE models perform worse on some metrics. |
86 | 149 |
|
87 | | -- Minimizing $\mathrm{CV}^2$ drives the variance of $\{\mathrm{Imp}_i\}$ or $\{\mathrm{Load}_i\}$ toward zero *relative* to their mean. |
88 | | -- Any expert $i$ with above-average usage raises its own $\mathrm{Imp}_i$ or $\mathrm{Load}_i$, increasing the penalty. |
89 | | -- Backpropagation through the gating parameters encourages **reduced** routing to overloaded experts and **increased** routing to underutilized ones, leading to a more uniform expert selection distribution. |
| 150 | +However, I wasn't really convinced by this reasoning, because the validation perplexity did not improve that much (I was expecting a larger gap). The load-balance results look OK, and that's the main point of the paper, so it's good. |
90 | 151 |
|
91 | | -Basically, divide by the variance. |
| 152 | +The true drawback, in my opinion, comes from the act of rebalancing through the auxiliary loss itself. |
| 153 | +- The idea of MoE is to have a lot of highly specialized experts, but the auxiliary loss fights any concentration of weight, even if that concentration is beneficial for modeling those tokens. |
| 154 | +- The balancing gradient for one expert involves all experts' totals, so updating the logit for one expert now depends on every other expert’s load. This can drown out more specialized signals. |
| 155 | +- Naturally, experts that are good at certain tokens are expected to get those; making the router equalize loads regardless of quality can route a token to a weaker expert simply because the "best" expert is already slightly busier. |
92 | 156 |
|
93 | 157 | (TBC) |