Skip to content
28 changes: 4 additions & 24 deletions configs/_model/llama/base_model.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
defaults:
- /ff_layer@model.encoder.block_fn.ff_layer_fn: dense
- _self_

common:
_target_: src.definitions.Common
Expand Down Expand Up @@ -75,29 +78,6 @@ model:
rope_base: 500000
rope_scale_freqs: true

ff_layer_fn:
_target_: src.projected_compression.model.ProjectedLlamaFeedForward
_partial_: true
ff_pre_act_fn:
_target_: src.projected_compression.model.Linear
_partial_: true
in_features: ${common.dmodel}
out_features: ${common.dff}
partial_init_fn:
_target_: src.projected_compression.model.llm_random_weight_init
_partial_: true
scale: 1
ff_post_act_fn:
_target_: src.projected_compression.model.Linear
_partial_: true
in_features: ${common.dff}
out_features: ${common.dmodel}
partial_init_fn:
_target_: src.projected_compression.model.llm_random_weight_init
_partial_: true
scale: 1
gate_fn: ${model.encoder.block_fn.ff_layer_fn.ff_pre_act_fn}

head:
_target_: src.projected_compression.model.TransformerHead
linear_fn:
Expand All @@ -113,4 +93,4 @@ model:
_target_: src.core.model.RMSNorm
_partial_: true
eps: 1e-5
normalized_shape: ${common.dmodel}
normalized_shape: ${common.dmodel}
11 changes: 11 additions & 0 deletions configs/_model/llama/small.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
defaults:
- base_model
- _self_

common:
dmodel: 1024
dff: 2816
dhead: 64
n_blocks: 16
q_heads: 16
kv_heads: 16
25 changes: 25 additions & 0 deletions configs/_model/llama/small_moe.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
defaults:
- base_model
- override /ff_layer@model.encoder.block_fn.ff_layer_fn: moe
- _self_

common:
dmodel: 1024
dff: 2816
dhead: 64
n_blocks: 16
q_heads: 16
kv_heads: 16

model:
encoder:
block_fn:
ff_layer_fn:
num_experts: 16
topk: 1
capacity_factor: 1.25
moe_load_balancing_loss_factor: 0.01
moe_router_z_loss_factor: 0.001
normalize_router_logits: false
activation_function: swiglu
init_scale: 1.0
21 changes: 21 additions & 0 deletions configs/ff_layer/dense.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
_target_: src.projected_compression.model.ProjectedLlamaFeedForward
_partial_: true
ff_pre_act_fn:
_target_: src.projected_compression.model.Linear
_partial_: true
in_features: ${common.dmodel}
out_features: ${common.dff}
partial_init_fn:
_target_: src.projected_compression.model.llm_random_weight_init
_partial_: true
scale: 1
ff_post_act_fn:
_target_: src.projected_compression.model.Linear
_partial_: true
in_features: ${common.dff}
out_features: ${common.dmodel}
partial_init_fn:
_target_: src.projected_compression.model.llm_random_weight_init
_partial_: true
scale: 1
gate_fn: ${model.encoder.block_fn.ff_layer_fn.ff_pre_act_fn}
12 changes: 12 additions & 0 deletions configs/ff_layer/moe.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
_target_: src.core.moe.MoE
_partial_: true
dmodel: ${common.dmodel}
dff: ${common.dff}
num_experts: ???
topk: ???
capacity_factor: ???
moe_load_balancing_loss_factor: ???
moe_router_z_loss_factor: ???
normalize_router_logits: ???
activation_function: ???
init_scale: ???
53 changes: 53 additions & 0 deletions configs/moe_example_run.yaml
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove / rename

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Renamed

Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
defaults:
- _cluster@_here_: entropy
- _model/llama@_here_: small_moe
- _trainer@_here_: llama
- _dataset@_here_: c4
- _checkpoints@_here_: none
- _misc@_here_: default
- _eval@_here_: basic
- _self_

common:
sequence_length: 1024
batch_size: 64

model:
embedding:
vocab_size: 50257

trainer:
gradient_accumulation_steps: 1
n_steps: 1000
learning_rate: 5e-4

train_dataloader:
dataset:
tokenize_fn:
_target_: src.core.datasets.gpt2_tokenize_fn

eval_dataloader:
dataset:
tokenize_fn:
_target_: src.core.datasets.gpt2_tokenize_fn

infrastructure:
max_concurrent_jobs: 1

metric_logger:
type: wandb
wandb_entity: ideas_cv
project_name: llm-random-test
name: moe_2gpu
tags:
- nano
- remote
- small
- moe

slurm:
time: "0-02:00:00"
gres: gpu:2
job-name: ${infrastructure.metric_logger.name}

evaluator: null
9 changes: 7 additions & 2 deletions src/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,8 +188,13 @@ def forward(self, x):
class TransformerTower(nn.Module):
def get_model_dimensions(self):
# Works only for llama3 transforermer architecture
dmodel = self.blocks[0].ff_layer.layer.ff_pre_act.weight.shape[1]
dff = self.blocks[0].ff_layer.layer.ff_pre_act.weight.shape[0]
ff_layer = self.blocks[0].ff_layer.layer
if getattr(ff_layer, "is_moe", False):
dmodel = ff_layer.dmodel
dff = ff_layer.dff
else:
dmodel = ff_layer.ff_pre_act.weight.shape[1]
dff = ff_layer.ff_pre_act.weight.shape[0]
datt = self.blocks[0].attention_layer.layer.q_proj.weight.shape[0]
n_att_heads = self.blocks[0].attention_layer.layer.q_heads
n_kvatt_heads = self.blocks[0].attention_layer.layer.kv_heads
Expand Down
169 changes: 169 additions & 0 deletions src/core/moe.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,169 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.init import trunc_normal_
import logging
import math


logger = logging.getLogger(__name__)


@torch.no_grad()
def _truncated_normal_(weight: torch.Tensor, fan_in: int, scale: float) -> None:
std = scale * (1 / fan_in) ** 0.5
trunc_normal_(weight, mean=0.0, std=std, a=-2 * std, b=2 * std)


class MoE(nn.Module):
def __init__(
self,
dmodel: int,
dff: int,
num_experts: int,
topk: int,
capacity_factor: float = 1.25,
moe_load_balancing_loss_factor: float = 0.0,
moe_router_z_loss_factor: float = 0.0,
normalize_router_logits: bool = False,
activation_function: str = "swiglu",
init_scale: float = 1.0,
):
super().__init__()

if activation_function != "swiglu":
raise ValueError(f"MoE supports only swiglu, got {activation_function}.")
if topk > num_experts:
raise ValueError(f"topk={topk} must be <= num_experts={num_experts}.")
if capacity_factor <= 0:
raise ValueError(f"capacity_factor must be > 0, got {capacity_factor}.")
if normalize_router_logits and topk == 1:
raise AssertionError("normalize_router_logits requires topk > 1.")

self.dmodel = dmodel
self.dff = dff
self.num_experts = num_experts
self.topk = topk
self.capacity_factor = capacity_factor
self.moe_load_balancing_loss_factor = moe_load_balancing_loss_factor
self.moe_router_z_loss_factor = moe_router_z_loss_factor
self.normalize_router_logits = normalize_router_logits
self.is_moe = True
self.moe_load_balancing_loss = None
self.router_z_loss = None

self.router_weight = nn.Parameter(torch.empty(num_experts, dmodel))
self.ff_pre_act_weight = nn.Parameter(torch.empty(num_experts, dff, dmodel))
self.gate_weight = nn.Parameter(torch.empty(num_experts, dff, dmodel))
self.ff_post_act_weight = nn.Parameter(torch.empty(num_experts, dmodel, dff))

_truncated_normal_(self.router_weight, dmodel, init_scale)
_truncated_normal_(self.ff_pre_act_weight, dmodel, init_scale)
_truncated_normal_(self.gate_weight, dmodel, init_scale)
_truncated_normal_(self.ff_post_act_weight, dff, init_scale)

def forward(self, x: torch.Tensor) -> torch.Tensor:
original_shape = x.shape
hidden_states = x.reshape(-1, self.dmodel)
num_tokens = hidden_states.size(0)

# Router
router_logits = torch.einsum(
"th,eh->te",
hidden_states,
self.router_weight,
)
router_logits = router_logits.to(dtype=torch.float32)
router_probs = F.softmax(router_logits, dim=-1)
# For each token, keep only the top-k experts and their routing probabilities
topk_probs, selected_experts = torch.topk(
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

question: should the routing weights sum to 1, when num_experts_per_tok > 1?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added the option to normalize

router_probs,
k=self.topk,
dim=-1,
)

# Keep only the highest-gated assignments per expert up to its capacity
flat_tokens = torch.arange(
num_tokens, device=hidden_states.device, dtype=torch.long
).repeat_interleave(self.topk)
flat_experts = selected_experts.reshape(-1)
flat_weights = topk_probs.reshape(-1)
total_assignments = flat_experts.numel()
capacity = max(
1,
math.ceil(self.capacity_factor * total_assignments / self.num_experts),
)
weight_order = torch.argsort(flat_weights, descending=True, stable=True)
grouped_order = torch.argsort(flat_experts[weight_order], stable=True)
sort_order = weight_order[grouped_order]
sorted_experts = flat_experts[sort_order]
sorted_tokens = flat_tokens[sort_order]
sorted_weights = flat_weights[sort_order]
expert_counts = sorted_experts.bincount(minlength=self.num_experts)
expert_offsets = expert_counts.cumsum(0) - expert_counts
slot_in_expert = (
torch.arange(total_assignments, device=hidden_states.device)
- expert_offsets[sorted_experts]
)
keep = slot_in_expert < capacity
kept_experts = sorted_experts[keep]
kept_tokens = sorted_tokens[keep]
kept_slots = slot_in_expert[keep]
kept_weights = sorted_weights[keep]
if self.normalize_router_logits and kept_weights.numel() > 0:
# Renormalize only the surviving expert weights so each token sums to 1 after capacity pruning.
token_weight_sums = kept_weights.new_zeros(num_tokens)
token_weight_sums.index_add_(0, kept_tokens, kept_weights)
kept_weights = kept_weights / token_weight_sums.index_select(0, kept_tokens)

# Dispatch the surviving tokens into expert-capacity slots and run the expert MLP batched per expert
flat_capacity = self.num_experts * capacity
dispatch_index = kept_experts * capacity + kept_slots
expert_inputs = hidden_states.new_zeros(flat_capacity, self.dmodel)
expert_inputs.index_copy_(0, dispatch_index, hidden_states[kept_tokens])
expert_inputs = expert_inputs.view(
self.num_experts,
capacity,
self.dmodel,
)
ff_pre_act = torch.einsum(
"ech,edh->ecd",
expert_inputs,
self.ff_pre_act_weight,
)
gate = torch.einsum(
"ech,edh->ecd",
expert_inputs,
self.gate_weight,
)
expert_outputs = torch.einsum(
"ecd,ehd->ech",
ff_pre_act * F.silu(gate),
self.ff_post_act_weight,
)

# Gather only the kept expert outputs back to tokens and sum the top-k contributions
token_updates = expert_outputs.view(flat_capacity, self.dmodel).index_select(
0, dispatch_index
)
token_updates = token_updates * kept_weights.to(hidden_states.dtype).unsqueeze(
-1
)
output = hidden_states.new_zeros(num_tokens, self.dmodel)
output = output.index_add(0, kept_tokens, token_updates)
output = output.reshape(original_shape)

# Match the switch-style load-balancing term using pre-capacity routing statistics
if self.training:
expert_frequency = flat_experts.bincount(minlength=self.num_experts)
expert_frequency = expert_frequency.to(router_probs.dtype)
expert_frequency = expert_frequency / expert_frequency.sum().clamp_min(1)
self.moe_load_balancing_loss = (
self.num_experts * (router_probs.mean(dim=0) * expert_frequency).sum()
)
self.router_z_loss = torch.logsumexp(router_logits, dim=-1).square().mean()
else:
self.moe_load_balancing_loss = None
self.router_z_loss = None

return output
Loading
Loading