Involution Hell

Theory of MoE

Theory of MoE

基础公式定义

对于一个向量 ww,记 w2\|w\|_2w\|w\|_\infty 分别为它的 2\ell_2 范数和 \ell_\infty 范数。

给定正的常数 c1,c2c_1, c_2,我们定义:

  • x=Ω(y)x = \Omega(y),如果 x>c2yx > c_2 |y|
  • x=Θ(y)x = \Theta(y),如果 c_1 |y| < x < c_2 |y|
  • x=O(y)x = O(y),如果 x < c_1 |y|
  • x=o(y)x = o(y),如果 xy0\frac{x}{y} \to 0
  • O(y):上界,表示“不会比 y 增长得更快”。
  • Ω(y):下界,表示“至少和 y 一样快”。
  • Θ(y):上下界都在 y 的数量级内,表示“和 y 同阶”。
  • o(y):严格比 y 小得多,最终会趋近于 0。

重要假设

  1. 这个文章只想给出闭式遗忘公式,所以直接简化成线性模型。f(X)=Xw,wRdf(X)=X^⊤w,w∈R^d
  2. 这个文章只讨论task-wised的路由方法,数据生成的时候每份数据只加入了一个信号数据,其余都是正态分布噪声。目的也是为了简化模型,然后在实际工程应用中,token会被隐式的送到各个experts,而不采用人为设定的方式。

> ### 数据集生成规则

在每一轮训练 t[T]t \in [T],新的任务 ntn_t 到来时,数据集 Dt=(Xt,yt)\mathcal{D}_t = (X_t, y_t) 的生成步骤如下:

  1. 抽取任务真值向量
    • 从任务池 W={w1,,wN}\mathcal{W} = \{w_1, \dots, w_N\} 中均匀采样一个真值向量 wntw_{n_t},并设定 wntw_{n_t} 为当前任务的 ground truth。
  2. 生成缩放系数
    • 独立采样一个随机变量 βt(0,C)\beta_t \in (0, C),其中 C=O(1)C = \mathcal{O}(1)
  3. 构造输入特征矩阵 XtX_t
    • sts_t 个样本中生成:
      • 其中 一个样本定义为 βtvnt\beta_t v_{n_t},其中 vntv_{n_t} 是任务 ntn_t 的特征信号。
      • 其余 st1s_t - 1 个样本来自正态分布:N(0,σt2Id)\mathcal{N}(0, \sigma_t^2 I_d),其中 σt0\sigma_t \ge 0 是噪声水平。
  4. 生成输出标签 yty_t
    • 使用线性回归生成: yt=Xtwnty_t = X_t^\top w_{n_t}

最终得到
数据集 Dt=(Xt,yt)\mathcal{D}_t = (X_t, y_t),对应一个线性回归任务。

  1. 这个文章只采用Top-1的experts指定方式

公式理论讲解:

专家参数更新: 当router命中某个experts时,其他experts保持不变,只更新命中的experts,其更新公式为:

wt(mt)=wt1(mt)+Xt(XtXt)1(ytXtwt1(mt))w_t^{(m_t)} = w_{t-1}^{(m_t)} + X_t (X_t^\top X_t)^{-1}(y_t - X_t^\top w_{t-1}^{(m_t)})

专家参数更新公式的由来

目标:在第 tt 轮,专家 mtm_t 要拟合任务数据集 (Xt,yt)(X_t, y_t)
minw Xtwyt22\min_{w}\ \|X_t^\top w - y_t\|_2^2

问题:过参数化 (s_t < d) 时解不唯一,直接算最小二乘解会丢掉历史信息。
> 所以论文改成 约束优化

minw wwt1(mt)22s.t.  Xtw=yt\min_w \ \|w - w_{t-1}^{(m_t)}\|_2^2 \quad s.t.\ \ X_t^\top w = y_t

解法:用拉格朗日乘子或残差投影,可得更新:

wt(mt)=wt1(mt)+Xt(XtXt)1(ytXtwt1(mt))w_t^{(m_t)} = w_{t-1}^{(m_t)} + X_t (X_t^\top X_t)^{-1}\,(y_t - X_t^\top w_{t-1}^{(m_t)})

解释

  • (ytXtwt1)(y_t - X_t^\top w_{t-1}) = 残差 = 真实输出 - 旧预测
  • Xt(XtXt)1X_t (X_t^\top X_t)^{-1} = 把残差投影回参数空间的修正项
  • 整个式子 = 在旧参数附近做一次最小二乘修正

性质

  • 保证 Xtwt=ytX_t^\top w_t = y_t → 新参数能完美拟合当前任务
  • 同时尽量靠近 wt1w_{t-1} → 避免遗忘过大

辅助损失:(这里经常也被称作load balance)

Ltaux(Θt,Dt)=αMm[M]ft(m)Pt(m)L_t^{\text{aux}}(\Theta_t, \mathcal{D}_t) = \alpha \cdot M \cdot \sum_{m \in [M]} f_t^{(m)} \cdot P_t^{(m)}

辅助损失 (Auxiliary Loss)

参数解释

  • α\alpha:权重系数,控制辅助损失在总 loss 中的比重
  • MM:专家数量
  • ft(m)f_t^{(m)}:专家 mm 在前 tt 轮中被选择的频率(历史使用情况)
  • Pt(m)P_t^{(m)}:router 在第 tt 轮给专家 mm 的平均分配概率

作用

  • 惩罚历史上频繁被使用且当前仍高概率被选的专家
  • 鼓励 router 多利用未充分使用的专家
  • 实现 负载均衡,避免专家“过度/稀少”使用
  • 这里尾部项理解起来非常简单,当某个专家m历史使用的次数越多,并且当前轮数依然分配到了较大的logits的时候这个损失项就会变得极大,从而抑制router只会对几个专家的偏好性。进而避免路由坍塌。

局部性损失

Ltloc(Θt,Dt)=m[M]πm(Xt,Θt)wt(m)wt1(m)2L_t^{\text{loc}}(\Theta_t, \mathcal{D}_t) = \sum_{m \in [M]} \pi_m(X_t,\Theta_t)\, \|w_t^{(m)} - w_{t-1}^{(m)}\|_2

局部性损失 (Locality Loss)

参数解释

  • πm(Xt,Θt)\pi_m(X_t,\Theta_t):router 给专家 mm 的概率 (softmax 输出)
  • wt(m)w_t^{(m)}:专家 mm 在当前任务下的参数
  • wt1(m)w_{t-1}^{(m)}:专家 mm 在上一轮的参数

作用

  • 约束专家参数更新不能偏离历史太远
  • 让相似任务被路由到同一专家,从而减小 loss
  • 减少遗忘(新任务更新不会把旧知识完全覆盖)
  • 提高专家的 专精性:每个专家逐渐固定在某类任务上

训练误差(损失):

Lttr(wt(mt),Dt)=1stXtwt(mt)yt22L_t^{\text{tr}}(w_t^{(m_t)}, \mathcal{D}_t) = \frac{1}{s_t}\,\|X_t^\top w_t^{(m_t)} - y_t\|_2^2

训练损失 (Training Loss)

参数解释

  • sts_t:当前任务的数据样本数
  • XtX_t:特征矩阵
  • yty_t:输出标签向量
  • wt(mt)w_t^{(m_t)}:在第 tt 轮被选中的专家的参数

作用

  • 本质是最小二乘回归的均方误差 (MSE)
  • 让选中的专家拟合当前任务数据
  • 保证专家能捕捉任务的真实信号 (ground truth)

总损失:

Lttask=Lttr+Ltloc+LtauxL_t^{\text{task}} = L_t^{\text{tr}} + L_t^{\text{loc}} + L_t^{\text{aux}}

有了上述的总损失函数后,就可以在训练中,进行路由的参数更新了

路由更新公式:

θt+1(m)=θt(m)ηθ(m)Lttask(Θt,wt(mt),Dt),m[M]\theta_{t+1}^{(m)} = \theta_t^{(m)} - \eta \cdot \nabla_{\theta^{(m)}} L_t^{\text{task}}(\Theta_t, w_t^{(m_t)}, \mathcal{D}_t), \quad \forall m \in [M]

Tricks:

Early Termination

在持续学习 (CL) 的场景下,如果 gating network 一直持续更新,随着任务到达轮数的增加,不同专家的分配概率可能逐渐趋于一致,最终导致 专家分化消失错误路由。为了解决这一问题,需要引入 早停机制 (Early Termination)

  • 基本思想 在经过足够轮数的任务探索 (T1T_1 轮) 后,MoE 的专家分配应当逐渐收敛。此时继续训练 gating network 不再带来收益,反而会导致过拟合和任务边界模糊。因此,需要在合适时机 终止路由器参数 Θt\Theta_t 的更新,保持专家划分的稳定性。

  • 收敛判据 定义一个收敛标志 I(m)I^{(m)} 来衡量专家 mm 是否收敛:

    I(m)=hm(Xt,θt)hmt(Xt,θt)I^{(m)} = \big| h_m(X_t, \theta_t) - h_{m_t}(X_t, \theta_t) \big|

    其中,hm(Xt,θt)h_m(X_t,\theta_t) 表示专家 mm 在当前输入上的 gating 输出,hmt(Xt,θt)h_{m_t}(X_t,\theta_t) 表示被 router 实际选择的专家的输出。

    • 若该差距 大于阈值 Γ\Gamma,说明专家 mm 尚未收敛,需要继续更新 Θt\Theta_t
    • 若该差距 小于阈值 Γ\Gamma,则认为 gating network 已经收敛,停止对 Θt\Theta_t 的更新。
    • 如此,则可以避免 router 在已收敛后仍然更新,导致专家划分被破坏。也能确保不同专家能够稳定服务于各自的任务簇。结合 LlocL^{loc}LauxL^{aux} 的约束,早停机制使得系统能在 CL 环境下长期保持平衡和低遗忘。

局部性损失的多种可能性

  • 参数连续性 (Parameter Locality)
Lparamloc=m[M]πm(Xt,Θt)wt(m)wt1(m)2 L^{loc}_{param} = \sum_{m \in [M]} \pi_m(X_t,\Theta_t)\,\|w_t^{(m)} - w_{t-1}^{(m)}\|_2
- 在前章节使用的方法
- 保证同一专家在相邻任务上的参数差异不要太大。
    
  • 表示相似性 (Representation Locality)

    • 可以直接对专家输出的表示(hidden states)施加约束。

    • 比如:

Lreprloc=m[M]πm(Xt,Θt)fm(Xt)fm(Xt1)2L^{loc}_{repr} = \sum_{m \in [M]} \pi_m(X_t,\Theta_t)\,\|f_m(X_t) - f_m(X_{t-1})\|_2
- 让相似输入在同一专家上输出保持稳定。
    
  • 路由概率连续性 (Routing Locality)

    • 约束 router 的分配概率不要随任务跳跃太大。

    • 形式类似:

Lrouteloc=m[M]πm(Xt,Θt)πm(Xt1,Θt1)2L^{loc}_{route} = \sum_{m \in [M]} \|\pi_m(X_t,\Theta_t) - \pi_m(X_{t-1},\Theta_{t-1})\|_2
  • 语义/任务嵌入的相似性 (Task Embedding Locality)

    • 如果能为任务构建一个 task embedding(比如通过元学习或对比学习),可以定义:

      • 相似任务 → 路由到同一专家

      • 不相似任务 → 尽量区分