
每个专家网络都需要生成整个输出结果。这意味着专家网络的独立性强,权重更新不再需要考虑其他专家网络的影响。更重要的是,在这种损失函数的训练下,当一个专家网络的误差小于所有专家网络误差的加权平均值时,它的权重就会增大,而当它的误差大于此加权平均值时,它的权重就会减小。所以,使用这种损失函数训练出来的模型,各专家网络之间是竞争关系,而不是合作关系。正是这种“竞争上岗”的模式,形成了动态加载的效果。
每个专家独立计算损失,从而鼓励每个数据样本尽可能被一个专家处理--这种结构不仅提高了模型的效率,还使模型在推理时可以只激活部分专家,从而大幅减少了计算资源的消耗。如同唐僧师徒团队:唐僧名气大、面子大,遇到社交场合,就由唐僧去谈;孙空擅长降妖除,遇到妖怪就请孙悟空出战;沙和尚任劳任怨,脏活累活由沙和尚干;猪八戒好吃懒做,就在团队搞搞气氛。这就是模块解耦要达到的效果。
这里有2个公式,先看看怎么理解:
有一个开源的项目,有对应的代码:
import torch
import torch.nn as nn
import torch.nn.functional as F
class NaiveMoELayer(nn.Module):
def __init__(self, input_size, output_size, num_experts):
super().__init__()
self.num_experts = num_experts
self.experts = nn.ModuleList([
nn.Linear(input_size, output_size) for _ in range(num_experts)
])
self.gate = nn.Linear(input_size, num_experts)
def forward(self, x):
gating_weights = F.softmax(self.gate(x), dim=-1)
expert_outputs = []
for expert in self.experts:
expert_outputs.append(expert(x))
expert_outputs = torch.stack(expert_outputs, dim=1)
final_output = torch.einsum('bn, bno -> bo', gating_weights, expert_outputs)
return final_output, gating_weights
model = NaiveMoELayer(input_size=100, output_size=10, num_experts=4)
input_data = torch.randn(32, 100)
target = torch.randn(32, 10)
output, gates = model(input_data)
loss = nn.MSELoss()(output, target)
loss.backward()
那么再看第2个公式:

在pytorch中有对应的代码
计算加权 MSE 损失
import torch
import torch.nn as nn
def weighted_mse_loss(gating_weights, expert_outputs, targets):
"""
Args:
gating_weights (torch.Tensor): [batch_size, num_experts] 专家权重
expert_outputs (torch.Tensor): [batch_size, num_experts, output_dim] 各专家预测
targets (torch.Tensor): [batch_size, output_dim] 真实值
Returns:
torch.Tensor: 加权 MSE 损失
"""
mse_per_expert = (expert_outputs - targets.unsqueeze(1)) ** 2
mse_per_expert = mse_per_expert.mean(dim=-1)
weighted_mse = (gating_weights * mse_per_expert).sum(dim=-1)
return weighted_mse.mean()
在 MoE 模型中使用
class MoE(nn.Module):
def __init__(self, input_dim, output_dim, num_experts):
super().__init__()
self.experts = nn.ModuleList([nn.Linear(input_dim, output_dim) for _ in range(num_experts)])
self.gate = nn.Linear(input_dim, num_experts)
def forward(self, x, targets=None):
gating_weights = torch.softmax(self.gate(x), dim=-1)
expert_outputs = torch.stack([expert(x) for expert in self.experts], dim=1)
output = (gating_weights.unsqueeze(-1) * expert_outputs).sum(dim=1)
loss = None
if targets is not None:
loss = weighted_mse_loss(gating_weights, expert_outputs, targets)
return output, loss
公式的作用
- 专家选择优化
- 门控网络通过最小化 Ec 学习分配权重 pic,使误差小的专家获得更高权重。
- 稀疏 MoE 训练
- 结合 **Top-K Gating ** (如 Switch Transformer),仅激活部分专家,减少计算量。
- 多任务学习
- 如果 c 代表不同任务,该公式可用于 **任务感知的专家分配 ** 。