Skip to content

Commit 8f76f2d

Browse files
committed
Tag handelling and new post.
1 parent d7d9505 commit 8f76f2d

2 files changed

Lines changed: 100 additions & 30 deletions

File tree

posts/entries/011-MoE-1.md

Lines changed: 92 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
---
22
title: Auxiliary-loss Load Balancing in MoEs (1)
33
date: 2025-07-06
4-
tags: [cs, finance, ai]
4+
tags: [cs, ai, notes]
55
author: R
66
location: Above Illulissat, Greenland while on a plane from New York to Hong Kong
77
---
@@ -17,9 +17,9 @@ Let
1717
- $K$ the number of experts selected per token,
1818
- $T$ the total number of tokens in the batch,
1919
- $\mathbf{u}_t\in\R^d$ the input for token $t$,
20-
- $\mathrm{FFN}_i: \R^d\to\R^d$ the $i$-th expert network,
20+
- $\text{FFN}_i: \R^d\to\R^d$ the $i$-th expert network,
2121
- $e_i\in\R^d$ the centroid (parameter) of expert $i$, and
22-
- $G\colon\R\to\R_{>0}$ a positive gating function (e.g. $\exp$, $\mathrm{sigmoid}$, or $\mathrm{softmax}$).
22+
- $G\colon\R\to\R_{>0}$ a positive gating function (e.g. $\exp$, $\text{sigmoid}$, or $\text{softmax}$).
2323

2424
Compute for each token $t$ and expert $i$:
2525
$$
@@ -30,7 +30,7 @@ $$
3030
& (\text{if $s_{i,t}$ is among the top-$K$ scores})
3131
\\
3232
0, & \text{otherwise}
33-
\end{cases}\\
33+
\end{cases} \\
3434
s_{i,t} &= G(\textbf{u}_t^\top e_i)
3535
\end{align*}
3636
$$
@@ -40,54 +40,118 @@ $$
4040
\textbf{h}_t = \textbf{u}_t + \sum^N_{i=1} g_{i,t} \text{FFN}_i (\textbf{u}_t) \\
4141
$$
4242

43-
So, here $G$ could be something that is $\R \to \R_{>0}$, some conventional ones could be a $\exp$, softmax or sigmoid (TBH I have to search these up to see what they actually are). In this paper they have use the latter 2.
43+
So, here $G$ could be something that is $\R \to \R_{>0}$, some conventional ones could be a $\exp$, softmax or sigmoid (TBH I have to search these two up to see what they are exactly). In this paper they have use the latter 2.
4444

4545
And there is the Expert consulted following the gating function.
4646

47-
### Problem: Inbalanced routing
47+
## Problem: Inbalanced routing
4848
But one problem MoEs often experience is inbalanced routing (a few number of experts recieve the most token), thus *a risk of routing collapse (Shazeer et al., 2017), where the model consistently selects only a few experts, hindering sufficient training of the other experts*; or, a *computational bottleneck by load inbalance*.
4949

5050
I was wondering how it could cause a computational bottleneck, but then I realized the way I thought about it that it could easily scale through parallelism or some other ways is not easily achievable. Since there are different machines hosting each model, it depends more on the load given to a certain expert.
5151

5252
Plus, the training loop should undergo a substantial redesign for it to use the idle computational power to catch up. Even if I create replicas for the "hot" experts on more hosting devices, they need to be in sync, therefore creating a lot of cost by itself. Merging gradients across replicas requires collective operations every step, at that point it will just recreate the original problem trying to overcome if 1 of these slowed down...
5353

54-
#### Solution: Auxiliary-loss
55-
To address this issue, there is auxiliary-loss encourage balanced load thus avoids inbalanced routing in training MoEs. To do this, it penalized the use of only a few number of agents. Its mostly within the process of the gating function. Defined as such:
54+
### Solution: Auxiliary-loss
55+
To address this issue, there is auxiliary-loss encourage balanced load thus avoids inbalanced routing in training MoEs. To do this, it penalized the use of only a few number of agents. Its mostly within the process of the gating function.
5656

57+
**Key variables:**
58+
- $N$: number of experts in the MoE layer
59+
- $K$: number of experts selected per token (top-K)
60+
- $T$: total number of tokens in the batch
61+
- $\mathbb{1}$: indicator function (equals 1 if condition is true, 0 otherwise)
62+
- $\alpha$: balancing‐loss weight (manually set hyperparameter)
5763

64+
Defined as such:
5865

59-
- **Normalized load**
66+
- **Normalized load**
67+
$f_i$:= the fraction of tokens routed to expert $i$:
6068
$$
61-
f_i = \frac{N}{KT} \sum_{t=1}^T \mathbb{1} (i \in \mathrm{Topk} \mid \mathbf{u}_t )
62-
\quad (\text{fraction of tokens routed to expert }i ).
69+
f_i = \frac{N}{KT} \sum_{t=1}^T \mathbb{1} (i \in \text{Topk} \mid \mathbf{u}_t )
6370
$$
64-
- **Average gating weight**
71+
72+
- **Average gating weight**
73+
$P_i$:= the mean score assigned by the gate to expert $i$:
6574
$$
66-
P_i = \frac{1}{T}\sum_{t=1}^T s_{i,t}
67-
\quad (\text{mean score assigned by the gate to expert }i ).
75+
P_i = \frac{1}{T} \sum_{t=1}^T s_{i,t}
6876
$$
69-
`
70-
##### Balance loss
7177

7278
Combine these into a single penalty term:
79+
$$
80+
\mathcal{L}_{\text{balance}} = \alpha \sum_{i=1}^N f_i P_i.
81+
$$
82+
83+
84+
**Regularization terms:**
85+
Introduce two small-weight penalties on the imbalance of $\{P_i\}$ and $\{f_i\}$:
86+
$$
87+
\begin{align*}
88+
\mathcal{L}_{P} &= \lambda_{P} \text{CV}^2 (\{P_i\} )\\
89+
\mathcal{L}_{f} &= \lambda_{f} \text{CV}^2 (\{f_i\} )
90+
\end{align*}
91+
$$
92+
where typically $\lambda_{P} \approx\lambda_{f} \sim10^{-2}$.
93+
94+
> This is actually optional, for simpler just use $\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{task}} + \mathcal{L}_{\text{balance}}$.
95+
> I write it this way just to follow the [original MoE auxiliary-loss formulation paper (Shazeer et al. (2017))](https://arxiv.org/pdf/1701.06538).
96+
97+
98+
**Imbalance metric: coefficient of variation squared**
99+
For any set of scalars $\{z_i\}_{i=1}^N$, define
100+
$$
101+
\text{CV}^2(\{z_i\}) =
102+
\frac{\frac{1}{N} \sum_{i=1}^N z_i^2 - (\frac{1}{N} \sum_{i=1}^N z_i )^2}
103+
{(\frac{1}{N} \sum_{i=1}^N z_i )^2},
104+
$$
105+
which satisfies $\text{CV}^2=0$ exactly when all $z_i$ are equal.
106+
107+
> Btw, this looks like the variance so much.
108+
> Write $\mu = \tfrac1N \sum_i z_i$ and $\nu = \tfrac1N \sum_i z_i^2$. Then
109+
> $$
110+
> \text{CV}^2 = \frac{\nu - \mu^2}{\mu^2} = \frac{\text{Var}}{(\text{Mean})^2}
111+
> $$
112+
> Its partial derivative w. one coordinate $z_k$ is
113+
> $$
114+
> \frac{\partial \text{CV}^2}{\partial z_k}
115+
> = \frac{2}{N}\Bigl(\frac{z_k}{\mu^2} - \frac{\nu}{\mu^3}\Bigr).
116+
> $$
117+
> > Details:
118+
> > $$
119+
> > \begin{align*}
120+
> > \frac{\partial}{\partial z_k} (\tfrac{\nu - \mu^2}{\mu^2})
121+
> > &= \frac{1}{\mu^2} \frac{\partial\nu}{\partial z_k}
122+
> > - \frac{\nu - \mu^2}{\mu^4} 2\mu \frac{\partial\mu}{\partial z_k}\\
123+
> > &= \frac{1}{\mu^2} \frac{2z_k}{N}
124+
> > - \frac{\nu - \mu^2}{\mu^4} \frac{2\mu}{N}\\
125+
> > &= \frac{2}{N}\Bigl(\frac{z_k}{\mu^2} - \frac{\nu}{\mu^3}\Bigr).
126+
> > \end{align*}
127+
> > $$
128+
>
129+
> Because $\nu/\mu^3$ is the same constant for all $k$, this gradient pushes down any $z_k > \mu$ (overloaded expert) and pushes up any $z_k < \mu$ (underloaded expert). In other words, the derivative of a variance term normalized by $\mu^2$.
130+
131+
132+
**Total training objective**
133+
Combine with the primary task loss $L_{\text{task}}$:
73134
$$
74-
\mathcal{L}_{\mathrm{balance}}
75-
=
76-
\alpha \sum_{i=1}^N f_i P_i.
135+
\mathcal{L}_{\text{total}}
136+
= \mathcal{L}_{\text{task}}
137+
+ \mathcal{L}_{P}
138+
+ \mathcal{L}_{f}.
77139
$$
78140

79-
##### Why this encourages balanced routing
141+
#### Intuition
142+
- The penalty grows, as either $f_i$ or $P_i$ grows (since it's a product). Then the routing distribution is driven toward uniformity by the penalties. Backpropagation through the parameters plays a role in this process.
143+
- Minimizing $\text{CV}^2$ drives the variance of $\{\text{Imp}_i\}$ or $\{\text{Load}_i\}$ toward zero relative to their mean. (see derivation of $\partial \text{CV}^2/\partial z_k$ above)
144+
- Any expert $i$ with above-average usage raises its own $\text{Imp}_i$ or $\text{Load}_i$, increasing the penalty.
80145

81-
- $f_i$ captures how heavily expert $i$ is used, while $P_i$ captures its average gate score.
82-
- If an expert is over-selected ($f_i$ large), the product $f_iP_i$ grows, increasing the penalty.
83-
- Gradients then adjust the gating parameters to **decrease** routing to over-used experts and **increase** routing to under-used ones, driving the distribution toward uniformity.
84146

85-
##### Why this balances load
147+
#### Drawbacks
148+
The ICLR 2025 mentioned that auxiliary loss might introduce unwanted gradients, as the MoE models performs worse on some metrics.
86149

87-
- Minimizing $\mathrm{CV}^2$ drives the variance of $\{\mathrm{Imp}_i\}$ or $\{\mathrm{Load}_i\}$ toward zero *relative* to their mean.
88-
- Any expert $i$ with above-average usage raises its own $\mathrm{Imp}_i$ or $\mathrm{Load}_i$, increasing the penalty.
89-
- Backpropagation through the gating parameters encourages **reduced** routing to overloaded experts and **increased** routing to underutilized ones, leading to a more uniform expert selection distribution.
150+
However, I wasn't really convinced about this reasoning. Because the performance was not improved that significantly (I was expecting a larger gap) for the Validation Perplexity. The load balance one sound ok, and that's the main point of the paper, so it's good.
90151

91-
Basically, divide by the variance.
152+
The true drawback, in my opinion comes with the act of rebalancing through auxiliary-loss itself.
153+
- The idea of MoE is having a lot of highly specialized experts, auxiliary-loss fights any concentration of weight, even if that concentration was beneficial for modeling those tokens.
154+
- The balancing gradient for expert involves all experts' totals. So updating the logit for one expert now depends on every other expert’s load. It's obvious it can drown out more specialized signals.
155+
- Naturally, experts that are good at certain tokens are expected to get those; trying to make the router equalize loads regardless of quality can route a token to a weaker expert, simply because the "best" expert is already slightly busier.
92156

93157
(TBC)

posts/js/blog.js

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,8 +97,14 @@ searchBox.addEventListener("input", e => {
9797
// Handle tag button clicks
9898
function handleTagClick(tag) {
9999
const params = new URLSearchParams(location.search);
100-
params.set("tag", tag);
101-
history.pushState({ tag }, "", `?${params.toString()}`);
100+
if (tag) {
101+
params.set("tag", tag);
102+
} else {
103+
params.delete("tag");
104+
}
105+
const newUrl = window.location.pathname +
106+
(params.toString() ? `?${params.toString()}` : "");
107+
history.pushState({ tag }, "", newUrl);
102108
applyTagFilter(tag);
103109
}
104110

0 commit comments

Comments
 (0)