From 7982e2086ff151d9c3a5a63c3acc9e275d67ac7d Mon Sep 17 00:00:00 2001 From: Jakub Date: Thu, 26 Mar 2026 13:39:14 +0100 Subject: [PATCH 01/16] Add base config --- configs/jk_test.yaml | 45 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 45 insertions(+) create mode 100644 configs/jk_test.yaml diff --git a/configs/jk_test.yaml b/configs/jk_test.yaml new file mode 100644 index 00000000..2694c811 --- /dev/null +++ b/configs/jk_test.yaml @@ -0,0 +1,45 @@ +defaults: + - _cluster@_here_: entropy + - _model/jk_moe_check@_here_: tiny + - _trainer@_here_: llama + - _dataset@_here_: c4 + - _checkpoints@_here_: none + - _misc@_here_: default + - _eval@_here_: basic + - _self_ + +common: + sequence_length: 128 + batch_size: 128 + +trainer: + gradient_accumulation_steps: 1 + n_steps: 100 + learning_rate: 1e-3 + + checkpoint: + save: + type: nano + path: checkpoint + +infrastructure: + max_concurrent_jobs: 1 + + metric_logger: + type: wandb + wandb_entity: ideas_cv + project_name: llm-random-test + name: moe + tags: + - nano + - remote + - tiny + - moe + + slurm: + time: "00:10:00" + gres: gpu:1 + job-name: ${infrastructure.metric_logger.name} + +evaluator: + limit: 10 From 312d871adb9ae6eddf65cf9cab183772960fad26 Mon Sep 17 00:00:00 2001 From: Jakub Date: Thu, 26 Mar 2026 19:56:31 +0100 Subject: [PATCH 02/16] Initial impl --- configs/_cluster/entropy.yaml | 2 +- configs/_model/llama/small.yaml | 12 ++ configs/_model/llama/small_moe.yaml | 27 ++++ configs/jk_test.yaml | 13 +- src/core/model.py | 9 +- src/core/moe.py | 186 ++++++++++++++++++++++++++++ src/core/trainer.py | 16 ++- 7 files changed, 254 insertions(+), 11 deletions(-) create mode 100644 configs/_model/llama/small.yaml create mode 100644 configs/_model/llama/small_moe.yaml create mode 100644 src/core/moe.py diff --git a/configs/_cluster/entropy.yaml b/configs/_cluster/entropy.yaml index f341770e..3fdbf108 100644 --- a/configs/_cluster/entropy.yaml +++ b/configs/_cluster/entropy.yaml @@ -19,7 +19,7 @@ infrastructure: - 'export HYDRA_FULL_ERROR=1' # export pixi variables - - 'export PIXI_HOME=/storage_nvme_4/nano/pixi_new' + - 'export PIXI_HOME=/storage_nvme_4/nano/pixi_jk' - 'export PATH="$PIXI_HOME/bin:$PATH"' - 'export XDG_DATA_HOME="$PIXI_HOME/data"' - 'export XDG_CACHE_HOME="$PIXI_HOME/cache"' diff --git a/configs/_model/llama/small.yaml b/configs/_model/llama/small.yaml new file mode 100644 index 00000000..63da4707 --- /dev/null +++ b/configs/_model/llama/small.yaml @@ -0,0 +1,12 @@ +defaults: + - base_model + - _self_ + +common: + dmodel: 512 + dff: 2048 + dhead: 64 + sequence_length: 2048 + n_blocks: 2 + q_heads: 8 + kv_heads: 8 diff --git a/configs/_model/llama/small_moe.yaml b/configs/_model/llama/small_moe.yaml new file mode 100644 index 00000000..37dc18f7 --- /dev/null +++ b/configs/_model/llama/small_moe.yaml @@ -0,0 +1,27 @@ +defaults: + - base_model + - _self_ + +common: + dmodel: 512 + dff: 2048 + dhead: 64 + sequence_length: 2048 + n_blocks: 2 + q_heads: 8 + kv_heads: 8 + +model: + encoder: + block_fn: + ff_layer_fn: + _target_: src.core.moe.MoE + _partial_: true + dmodel: ${common.dmodel} + dff: ${common.dff} + num_experts: 4 + num_experts_per_tok: 2 + capacity_factor: 1.25 + moe_load_balancing_loss_factor: 0.01 + activation_function: swiglu + init_scale: 1.0 diff --git a/configs/jk_test.yaml b/configs/jk_test.yaml index 2694c811..7a42b0dd 100644 --- a/configs/jk_test.yaml +++ b/configs/jk_test.yaml @@ -1,6 +1,6 @@ defaults: - _cluster@_here_: entropy - - _model/jk_moe_check@_here_: tiny + - _model/llama@_here_: small_moe - _trainer@_here_: llama - _dataset@_here_: c4 - _checkpoints@_here_: none @@ -9,7 +9,7 @@ defaults: - _self_ common: - sequence_length: 128 + sequence_length: 1024 batch_size: 128 trainer: @@ -29,17 +29,16 @@ infrastructure: type: wandb wandb_entity: ideas_cv project_name: llm-random-test - name: moe + name: moe_fixed tags: - nano - remote - - tiny + - small - moe slurm: - time: "00:10:00" + time: "1-00:00:00" gres: gpu:1 job-name: ${infrastructure.metric_logger.name} -evaluator: - limit: 10 +evaluator: null diff --git a/src/core/model.py b/src/core/model.py index 0d20118e..4de083c8 100644 --- a/src/core/model.py +++ b/src/core/model.py @@ -188,8 +188,13 @@ def forward(self, x): class TransformerTower(nn.Module): def get_model_dimensions(self): # Works only for llama3 transforermer architecture - dmodel = self.blocks[0].ff_layer.layer.ff_pre_act.weight.shape[1] - dff = self.blocks[0].ff_layer.layer.ff_pre_act.weight.shape[0] + ff_layer = self.blocks[0].ff_layer.layer + if getattr(ff_layer, "is_moe", False): + dmodel = ff_layer.dmodel + dff = ff_layer.dff + else: + dmodel = ff_layer.ff_pre_act.weight.shape[1] + dff = ff_layer.ff_pre_act.weight.shape[0] datt = self.blocks[0].attention_layer.layer.q_proj.weight.shape[0] n_att_heads = self.blocks[0].attention_layer.layer.q_heads n_kvatt_heads = self.blocks[0].attention_layer.layer.kv_heads diff --git a/src/core/moe.py b/src/core/moe.py new file mode 100644 index 00000000..05f301fe --- /dev/null +++ b/src/core/moe.py @@ -0,0 +1,186 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.init import trunc_normal_ +import logging +import math + + +logger = logging.getLogger(__name__) + + +@torch.no_grad() +def _truncated_normal_(weight: torch.Tensor, fan_in: int, scale: float) -> None: + std = scale * (1 / fan_in) ** 0.5 + trunc_normal_(weight, mean=0.0, std=std, a=-2 * std, b=2 * std) + + +class MoE(nn.Module): + def __init__( + self, + dmodel: int, + dff: int, + num_experts: int, + num_experts_per_tok: int, + capacity_factor: float = 1.25, + moe_load_balancing_loss_factor: float = 0.0, + activation_function: str = "swiglu", + init_scale: float = 1.0, + **_ignored_kwargs, + ): + super().__init__() + + if activation_function != "swiglu": + raise ValueError( + f"MoE supports only swiglu, got {activation_function}." + ) + if num_experts_per_tok > num_experts: + raise ValueError( + f"num_experts_per_tok={num_experts_per_tok} must be <= num_experts={num_experts}." + ) + if capacity_factor <= 0: + raise ValueError( + f"capacity_factor must be > 0, got {capacity_factor}." + ) + + self.dmodel = dmodel + self.dff = dff + self.num_experts = num_experts + self.num_experts_per_tok = num_experts_per_tok + self.capacity_factor = capacity_factor + self.moe_load_balancing_loss_factor = moe_load_balancing_loss_factor + self.is_moe = True + self.aux_loss = None + + self.router_weight = nn.Parameter(torch.empty(num_experts, dmodel)) + self.ff_pre_act_weight = nn.Parameter( + torch.empty(num_experts, dff, dmodel) + ) + self.gate_weight = nn.Parameter( + torch.empty(num_experts, dff, dmodel) + ) + self.ff_post_act_weight = nn.Parameter( + torch.empty(num_experts, dmodel, dff) + ) + + _truncated_normal_(self.router_weight, dmodel, init_scale) + _truncated_normal_(self.ff_pre_act_weight, dmodel, init_scale) + _truncated_normal_(self.gate_weight, dmodel, init_scale) + _truncated_normal_(self.ff_post_act_weight, dff, init_scale) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + original_shape = x.shape + hidden_states = x.reshape(-1, self.dmodel) + num_tokens = hidden_states.size(0) + + # Router + router_logits = torch.einsum( + "th,eh->te", + hidden_states, + self.router_weight, + ) + router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32) + router_weights, selected_experts = torch.topk( + router_probs, + k=self.num_experts_per_tok, + dim=-1, + ) + router_weights = router_weights / router_weights.sum( + dim=-1, keepdim=True + ).clamp_min(torch.finfo(router_weights.dtype).eps) + + # Keep only the highest-gated assignments per expert up to its capacity. + flat_tokens = torch.arange( + num_tokens, device=hidden_states.device, dtype=torch.long + ).repeat_interleave(self.num_experts_per_tok) + flat_experts = selected_experts.reshape(-1) + flat_weights = router_weights.reshape(-1) + total_assignments = flat_experts.numel() + capacity = max( + 1, + math.ceil( + self.capacity_factor * total_assignments / self.num_experts + ), + ) + weight_order = torch.argsort(flat_weights, descending=True, stable=True) + grouped_order = torch.argsort( + flat_experts[weight_order], stable=True + ) + sort_order = weight_order[grouped_order] + sorted_experts = flat_experts[sort_order] + sorted_tokens = flat_tokens[sort_order] + sorted_weights = flat_weights[sort_order] + expert_counts = sorted_experts.bincount(minlength=self.num_experts) + expert_offsets = expert_counts.cumsum(0) - expert_counts + slot_in_expert = ( + torch.arange(total_assignments, device=hidden_states.device) + - expert_offsets[sorted_experts] + ) + keep = slot_in_expert < capacity + kept_experts = sorted_experts[keep] + kept_tokens = sorted_tokens[keep] + kept_slots = slot_in_expert[keep] + kept_weights = sorted_weights[keep] + token_weight_sums = torch.zeros( + num_tokens, + dtype=kept_weights.dtype, + device=hidden_states.device, + ) + token_weight_sums = token_weight_sums.index_add( + 0, + kept_tokens, + kept_weights, + ) + kept_weights = kept_weights / token_weight_sums[kept_tokens].clamp_min( + torch.finfo(kept_weights.dtype).eps + ) + + # Dispatch the surviving tokens into expert-capacity slots and run the expert MLP batched per expert. + flat_capacity = self.num_experts * capacity + dispatch_index = kept_experts * capacity + kept_slots + expert_inputs = hidden_states.new_zeros(flat_capacity, self.dmodel) + expert_inputs.index_copy_(0, dispatch_index, hidden_states[kept_tokens]) + expert_inputs = expert_inputs.view( + self.num_experts, + capacity, + self.dmodel, + ) + ff_pre_act = torch.einsum( + "ech,edh->ecd", + expert_inputs, + self.ff_pre_act_weight, + ) + gate = torch.einsum( + "ech,edh->ecd", + expert_inputs, + self.gate_weight, + ) + expert_outputs = torch.einsum( + "ecd,ehd->ech", + ff_pre_act * F.silu(gate), + self.ff_post_act_weight, + ) + + # Gather only the kept expert outputs back to tokens and sum the top-k contributions. + token_updates = expert_outputs.view(flat_capacity, self.dmodel).index_select( + 0, dispatch_index + ) + token_updates = token_updates * kept_weights.to(hidden_states.dtype).unsqueeze(-1) + output = hidden_states.new_zeros(num_tokens, self.dmodel) + output = output.index_add(0, kept_tokens, token_updates) + output = output.reshape(original_shape) + + # Match the switch-style load-balancing term using pre-capacity routing statistics. + if self.training: + expert_frequency = flat_experts.bincount( + minlength=self.num_experts + ) + expert_frequency = expert_frequency.to(router_probs.dtype) + expert_frequency = expert_frequency / expert_frequency.sum().clamp_min(1) + self.aux_loss = self.num_experts * ( + router_probs.mean(dim=0) * expert_frequency + ).sum() + else: + self.aux_loss = None + + return output diff --git a/src/core/trainer.py b/src/core/trainer.py index 18fab093..ee95e8ec 100644 --- a/src/core/trainer.py +++ b/src/core/trainer.py @@ -168,7 +168,12 @@ def _hack_for_python_garbage_collection(input_ids, target_ids): target_ids.reshape(-1).long(), reduction="none", ) - loss = mask_loss.mean() / self.gradient_accumulation_steps + loss = mask_loss.mean() + if self.model.training: + loss = loss + self._calculate_moe_load_balancing_loss( + device=predicted_ids.device, + ) + loss = loss / self.gradient_accumulation_steps return loss losses = [] @@ -190,6 +195,15 @@ def _hack_for_python_garbage_collection(input_ids, target_ids): return avg_loss / float(os.environ["WORLD_SIZE"]) + def _calculate_moe_load_balancing_loss(self, device): + loss = torch.zeros((), device=device) + for module in self.model.modules(): + aux_loss = getattr(module, "aux_loss", None) + factor = getattr(module, "moe_load_balancing_loss_factor", 0.0) + if aux_loss is not None and factor: + loss = loss + aux_loss.to(device=device) * factor + return loss + def eval(self): self.model.eval() saved_step = self.step From add5718c67a33573388c9d7dd1ccbe6fa8d6fac1 Mon Sep 17 00:00:00 2001 From: Jakub Date: Thu, 26 Mar 2026 20:06:33 +0100 Subject: [PATCH 03/16] logging and readability --- configs/_model/llama/small_moe.yaml | 4 ++-- src/core/moe.py | 17 +++++++++-------- src/core/trainer.py | 28 +++++++++++++++++++++++++--- 3 files changed, 36 insertions(+), 13 deletions(-) diff --git a/configs/_model/llama/small_moe.yaml b/configs/_model/llama/small_moe.yaml index 37dc18f7..ce7db0f5 100644 --- a/configs/_model/llama/small_moe.yaml +++ b/configs/_model/llama/small_moe.yaml @@ -19,8 +19,8 @@ model: _partial_: true dmodel: ${common.dmodel} dff: ${common.dff} - num_experts: 4 - num_experts_per_tok: 2 + num_experts: 16 + num_experts_per_tok: 1 capacity_factor: 1.25 moe_load_balancing_loss_factor: 0.01 activation_function: swiglu diff --git a/src/core/moe.py b/src/core/moe.py index 05f301fe..bc55b643 100644 --- a/src/core/moe.py +++ b/src/core/moe.py @@ -80,21 +80,22 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: self.router_weight, ) router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32) - router_weights, selected_experts = torch.topk( + # For each token, keep only the top-k experts and their routing probabilities + topk_probs, selected_experts = torch.topk( router_probs, k=self.num_experts_per_tok, dim=-1, ) - router_weights = router_weights / router_weights.sum( + topk_probs = topk_probs / topk_probs.sum( dim=-1, keepdim=True - ).clamp_min(torch.finfo(router_weights.dtype).eps) + ).clamp_min(torch.finfo(topk_probs.dtype).eps) - # Keep only the highest-gated assignments per expert up to its capacity. + # Keep only the highest-gated assignments per expert up to its capacity flat_tokens = torch.arange( num_tokens, device=hidden_states.device, dtype=torch.long ).repeat_interleave(self.num_experts_per_tok) flat_experts = selected_experts.reshape(-1) - flat_weights = router_weights.reshape(-1) + flat_weights = topk_probs.reshape(-1) total_assignments = flat_experts.numel() capacity = max( 1, @@ -135,7 +136,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: torch.finfo(kept_weights.dtype).eps ) - # Dispatch the surviving tokens into expert-capacity slots and run the expert MLP batched per expert. + # Dispatch the surviving tokens into expert-capacity slots and run the expert MLP batched per expert flat_capacity = self.num_experts * capacity dispatch_index = kept_experts * capacity + kept_slots expert_inputs = hidden_states.new_zeros(flat_capacity, self.dmodel) @@ -161,7 +162,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: self.ff_post_act_weight, ) - # Gather only the kept expert outputs back to tokens and sum the top-k contributions. + # Gather only the kept expert outputs back to tokens and sum the top-k contributions token_updates = expert_outputs.view(flat_capacity, self.dmodel).index_select( 0, dispatch_index ) @@ -170,7 +171,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: output = output.index_add(0, kept_tokens, token_updates) output = output.reshape(original_shape) - # Match the switch-style load-balancing term using pre-capacity routing statistics. + # Match the switch-style load-balancing term using pre-capacity routing statistics if self.training: expert_frequency = flat_experts.bincount( minlength=self.num_experts diff --git a/src/core/trainer.py b/src/core/trainer.py index ee95e8ec..8eff212a 100644 --- a/src/core/trainer.py +++ b/src/core/trainer.py @@ -49,6 +49,7 @@ def __attrs_post_init__(self): self.start_step = self.training_state["next_step"] self.device = next(self.model.parameters()).device self.loss_interval_100 = 0.0 + self._last_moe_load_balancing_loss = torch.zeros((), device=self.device) if self.eval_dataloader is not None and hasattr( self.eval_dataloader, "__iter__" @@ -169,30 +170,47 @@ def _hack_for_python_garbage_collection(input_ids, target_ids): reduction="none", ) loss = mask_loss.mean() + moe_load_balancing_loss = torch.zeros((), device=predicted_ids.device) if self.model.training: - loss = loss + self._calculate_moe_load_balancing_loss( + moe_load_balancing_loss = self._calculate_moe_load_balancing_loss( device=predicted_ids.device, ) + loss = loss + moe_load_balancing_loss loss = loss / self.gradient_accumulation_steps - return loss + moe_load_balancing_loss = ( + moe_load_balancing_loss / self.gradient_accumulation_steps + ) + return loss, moe_load_balancing_loss losses = [] + moe_load_balancing_losses = [] for batch_chunk in batch.chunk(self.gradient_accumulation_steps): input_ids, target_ids = self._preprocess_input(batch_chunk) input_ids = input_ids.to(self.device) if self.model.training: self._update_processed_tokens(input_ids) - loss = _hack_for_python_garbage_collection(input_ids, target_ids) + loss, moe_load_balancing_loss = _hack_for_python_garbage_collection( + input_ids, target_ids + ) if self.model.training: loss.backward() losses.append(loss.item()) + moe_load_balancing_losses.append(moe_load_balancing_loss.item()) # gloo backend supports only sum reduce operation, therfore we first divide by world size and then sum avg_loss = torch.tensor(losses, device=loss.device).sum() + avg_moe_load_balancing_loss = torch.tensor( + moe_load_balancing_losses, + device=loss.device, + ).sum() if dist.is_initialized(): dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM) + dist.all_reduce(avg_moe_load_balancing_loss, op=dist.ReduceOp.SUM) + self._last_moe_load_balancing_loss = avg_moe_load_balancing_loss / float( + os.environ["WORLD_SIZE"] + ) return avg_loss / float(os.environ["WORLD_SIZE"]) def _calculate_moe_load_balancing_loss(self, device): @@ -243,6 +261,10 @@ def _update_processed_tokens(self, batch): def log_metrics(self, loss, grad_norm): self.metric_logger.set_tokens(self.processed_tokens) self.metric_logger.log("train/loss", loss.item()) + self.metric_logger.log( + "train/moe_load_balancing_loss", + self._last_moe_load_balancing_loss.item(), + ) self.metric_logger.log("train/lr", self.scheduler.get_last_lr()[0]) self.metric_logger.log("train/grad_norm", grad_norm.item()) From d968ae378ec39ae0b39a509123334fb6927bf20e Mon Sep 17 00:00:00 2001 From: Jakub Date: Thu, 26 Mar 2026 20:39:44 +0100 Subject: [PATCH 04/16] fix configs for test --- configs/_model/llama/small.yaml | 9 ++++----- configs/_model/llama/small_moe.yaml | 11 +++++------ 2 files changed, 9 insertions(+), 11 deletions(-) diff --git a/configs/_model/llama/small.yaml b/configs/_model/llama/small.yaml index 63da4707..39b22712 100644 --- a/configs/_model/llama/small.yaml +++ b/configs/_model/llama/small.yaml @@ -3,10 +3,9 @@ defaults: - _self_ common: - dmodel: 512 + dmodel: 768 dff: 2048 dhead: 64 - sequence_length: 2048 - n_blocks: 2 - q_heads: 8 - kv_heads: 8 + n_blocks: 12 + q_heads: 32 + kv_heads: 32 diff --git a/configs/_model/llama/small_moe.yaml b/configs/_model/llama/small_moe.yaml index ce7db0f5..7c607f7e 100644 --- a/configs/_model/llama/small_moe.yaml +++ b/configs/_model/llama/small_moe.yaml @@ -3,13 +3,12 @@ defaults: - _self_ common: - dmodel: 512 + dmodel: 768 dff: 2048 dhead: 64 - sequence_length: 2048 - n_blocks: 2 - q_heads: 8 - kv_heads: 8 + n_blocks: 12 + q_heads: 32 + kv_heads: 32 model: encoder: @@ -24,4 +23,4 @@ model: capacity_factor: 1.25 moe_load_balancing_loss_factor: 0.01 activation_function: swiglu - init_scale: 1.0 + init_scale: 0.02 From 9d30a8717c3c9f450b32187374c053bd69a04179 Mon Sep 17 00:00:00 2001 From: Jakub Date: Thu, 26 Mar 2026 21:10:53 +0100 Subject: [PATCH 05/16] format --- configs/jk_test.yaml | 8 ++++---- src/core/moe.py | 48 ++++++++++++++++---------------------------- 2 files changed, 21 insertions(+), 35 deletions(-) diff --git a/configs/jk_test.yaml b/configs/jk_test.yaml index 7a42b0dd..3fb351ca 100644 --- a/configs/jk_test.yaml +++ b/configs/jk_test.yaml @@ -1,6 +1,6 @@ defaults: - _cluster@_here_: entropy - - _model/llama@_here_: small_moe + - _model/llama@_here_: small - _trainer@_here_: llama - _dataset@_here_: c4 - _checkpoints@_here_: none @@ -10,11 +10,11 @@ defaults: common: sequence_length: 1024 - batch_size: 128 + batch_size: 32 trainer: gradient_accumulation_steps: 1 - n_steps: 100 + n_steps: 10000 learning_rate: 1e-3 checkpoint: @@ -29,7 +29,7 @@ infrastructure: type: wandb wandb_entity: ideas_cv project_name: llm-random-test - name: moe_fixed + name: test_dense_32 tags: - nano - remote diff --git a/src/core/moe.py b/src/core/moe.py index bc55b643..f1243b8c 100644 --- a/src/core/moe.py +++ b/src/core/moe.py @@ -31,17 +31,13 @@ def __init__( super().__init__() if activation_function != "swiglu": - raise ValueError( - f"MoE supports only swiglu, got {activation_function}." - ) + raise ValueError(f"MoE supports only swiglu, got {activation_function}.") if num_experts_per_tok > num_experts: raise ValueError( f"num_experts_per_tok={num_experts_per_tok} must be <= num_experts={num_experts}." ) if capacity_factor <= 0: - raise ValueError( - f"capacity_factor must be > 0, got {capacity_factor}." - ) + raise ValueError(f"capacity_factor must be > 0, got {capacity_factor}.") self.dmodel = dmodel self.dff = dff @@ -53,15 +49,9 @@ def __init__( self.aux_loss = None self.router_weight = nn.Parameter(torch.empty(num_experts, dmodel)) - self.ff_pre_act_weight = nn.Parameter( - torch.empty(num_experts, dff, dmodel) - ) - self.gate_weight = nn.Parameter( - torch.empty(num_experts, dff, dmodel) - ) - self.ff_post_act_weight = nn.Parameter( - torch.empty(num_experts, dmodel, dff) - ) + self.ff_pre_act_weight = nn.Parameter(torch.empty(num_experts, dff, dmodel)) + self.gate_weight = nn.Parameter(torch.empty(num_experts, dff, dmodel)) + self.ff_post_act_weight = nn.Parameter(torch.empty(num_experts, dmodel, dff)) _truncated_normal_(self.router_weight, dmodel, init_scale) _truncated_normal_(self.ff_pre_act_weight, dmodel, init_scale) @@ -86,9 +76,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: k=self.num_experts_per_tok, dim=-1, ) - topk_probs = topk_probs / topk_probs.sum( - dim=-1, keepdim=True - ).clamp_min(torch.finfo(topk_probs.dtype).eps) + topk_probs = topk_probs / topk_probs.sum(dim=-1, keepdim=True).clamp_min( + torch.finfo(topk_probs.dtype).eps + ) # Keep only the highest-gated assignments per expert up to its capacity flat_tokens = torch.arange( @@ -99,14 +89,10 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: total_assignments = flat_experts.numel() capacity = max( 1, - math.ceil( - self.capacity_factor * total_assignments / self.num_experts - ), + math.ceil(self.capacity_factor * total_assignments / self.num_experts), ) weight_order = torch.argsort(flat_weights, descending=True, stable=True) - grouped_order = torch.argsort( - flat_experts[weight_order], stable=True - ) + grouped_order = torch.argsort(flat_experts[weight_order], stable=True) sort_order = weight_order[grouped_order] sorted_experts = flat_experts[sort_order] sorted_tokens = flat_tokens[sort_order] @@ -166,21 +152,21 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: token_updates = expert_outputs.view(flat_capacity, self.dmodel).index_select( 0, dispatch_index ) - token_updates = token_updates * kept_weights.to(hidden_states.dtype).unsqueeze(-1) + token_updates = token_updates * kept_weights.to(hidden_states.dtype).unsqueeze( + -1 + ) output = hidden_states.new_zeros(num_tokens, self.dmodel) output = output.index_add(0, kept_tokens, token_updates) output = output.reshape(original_shape) # Match the switch-style load-balancing term using pre-capacity routing statistics if self.training: - expert_frequency = flat_experts.bincount( - minlength=self.num_experts - ) + expert_frequency = flat_experts.bincount(minlength=self.num_experts) expert_frequency = expert_frequency.to(router_probs.dtype) expert_frequency = expert_frequency / expert_frequency.sum().clamp_min(1) - self.aux_loss = self.num_experts * ( - router_probs.mean(dim=0) * expert_frequency - ).sum() + self.aux_loss = ( + self.num_experts * (router_probs.mean(dim=0) * expert_frequency).sum() + ) else: self.aux_loss = None From 58ecff8aca13b06a670dd017f46582b5d3d84f64 Mon Sep 17 00:00:00 2001 From: Jakub Date: Sat, 28 Mar 2026 10:07:51 +0100 Subject: [PATCH 06/16] fix bug in logging --- src/core/moe.py | 16 ---------------- src/core/trainer.py | 25 +++++++++++++++++++------ 2 files changed, 19 insertions(+), 22 deletions(-) diff --git a/src/core/moe.py b/src/core/moe.py index f1243b8c..8c1e76e0 100644 --- a/src/core/moe.py +++ b/src/core/moe.py @@ -76,9 +76,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: k=self.num_experts_per_tok, dim=-1, ) - topk_probs = topk_probs / topk_probs.sum(dim=-1, keepdim=True).clamp_min( - torch.finfo(topk_probs.dtype).eps - ) # Keep only the highest-gated assignments per expert up to its capacity flat_tokens = torch.arange( @@ -108,19 +105,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: kept_tokens = sorted_tokens[keep] kept_slots = slot_in_expert[keep] kept_weights = sorted_weights[keep] - token_weight_sums = torch.zeros( - num_tokens, - dtype=kept_weights.dtype, - device=hidden_states.device, - ) - token_weight_sums = token_weight_sums.index_add( - 0, - kept_tokens, - kept_weights, - ) - kept_weights = kept_weights / token_weight_sums[kept_tokens].clamp_min( - torch.finfo(kept_weights.dtype).eps - ) # Dispatch the surviving tokens into expert-capacity slots and run the expert MLP batched per expert flat_capacity = self.num_experts * capacity diff --git a/src/core/trainer.py b/src/core/trainer.py index 8eff212a..953d5a13 100644 --- a/src/core/trainer.py +++ b/src/core/trainer.py @@ -49,6 +49,7 @@ def __attrs_post_init__(self): self.start_step = self.training_state["next_step"] self.device = next(self.model.parameters()).device self.loss_interval_100 = 0.0 + self._last_reported_loss = torch.zeros((), device=self.device) self._last_moe_load_balancing_loss = torch.zeros((), device=self.device) if self.eval_dataloader is not None and hasattr( @@ -169,20 +170,25 @@ def _hack_for_python_garbage_collection(input_ids, target_ids): target_ids.reshape(-1).long(), reduction="none", ) - loss = mask_loss.mean() + # Keep the reported loss as pure CE so train/loss stays comparable to eval/loss; + # the MoE load-balancing term is optimized and logged separately. + reported_loss = mask_loss.mean() + loss = reported_loss moe_load_balancing_loss = torch.zeros((), device=predicted_ids.device) if self.model.training: moe_load_balancing_loss = self._calculate_moe_load_balancing_loss( device=predicted_ids.device, ) loss = loss + moe_load_balancing_loss + reported_loss = reported_loss / self.gradient_accumulation_steps loss = loss / self.gradient_accumulation_steps moe_load_balancing_loss = ( moe_load_balancing_loss / self.gradient_accumulation_steps ) - return loss, moe_load_balancing_loss + return loss, reported_loss, moe_load_balancing_loss losses = [] + reported_losses = [] moe_load_balancing_losses = [] for batch_chunk in batch.chunk(self.gradient_accumulation_steps): input_ids, target_ids = self._preprocess_input(batch_chunk) @@ -190,24 +196,31 @@ def _hack_for_python_garbage_collection(input_ids, target_ids): if self.model.training: self._update_processed_tokens(input_ids) - loss, moe_load_balancing_loss = _hack_for_python_garbage_collection( - input_ids, target_ids + loss, reported_loss, moe_load_balancing_loss = ( + _hack_for_python_garbage_collection(input_ids, target_ids) ) if self.model.training: loss.backward() losses.append(loss.item()) + reported_losses.append(reported_loss.item()) moe_load_balancing_losses.append(moe_load_balancing_loss.item()) # gloo backend supports only sum reduce operation, therfore we first divide by world size and then sum avg_loss = torch.tensor(losses, device=loss.device).sum() + avg_reported_loss = torch.tensor( + reported_losses, + device=loss.device, + ).sum() avg_moe_load_balancing_loss = torch.tensor( moe_load_balancing_losses, device=loss.device, ).sum() if dist.is_initialized(): dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM) + dist.all_reduce(avg_reported_loss, op=dist.ReduceOp.SUM) dist.all_reduce(avg_moe_load_balancing_loss, op=dist.ReduceOp.SUM) + self._last_reported_loss = avg_reported_loss / float(os.environ["WORLD_SIZE"]) self._last_moe_load_balancing_loss = avg_moe_load_balancing_loss / float( os.environ["WORLD_SIZE"] ) @@ -260,7 +273,7 @@ def _update_processed_tokens(self, batch): def log_metrics(self, loss, grad_norm): self.metric_logger.set_tokens(self.processed_tokens) - self.metric_logger.log("train/loss", loss.item()) + self.metric_logger.log("train/loss", self._last_reported_loss.item()) self.metric_logger.log( "train/moe_load_balancing_loss", self._last_moe_load_balancing_loss.item(), @@ -268,7 +281,7 @@ def log_metrics(self, loss, grad_norm): self.metric_logger.log("train/lr", self.scheduler.get_last_lr()[0]) self.metric_logger.log("train/grad_norm", grad_norm.item()) - self.loss_averaged_100.log(self.metric_logger, loss.item()) + self.loss_averaged_100.log(self.metric_logger, self._last_reported_loss.item()) self.time_diff_averaged_100.log(self.metric_logger, time.time()) self.metric_logger.flush_accumulated_metrics() From 397fb10450f939abbcdb1e725aa85c6e128dc9f9 Mon Sep 17 00:00:00 2001 From: Jakub Date: Sat, 28 Mar 2026 10:09:25 +0100 Subject: [PATCH 07/16] update configs --- configs/_model/llama/small.yaml | 10 +++++----- configs/_model/llama/small_moe.yaml | 11 +++++------ configs/jk_test.yaml | 25 +++++++++++++++++-------- 3 files changed, 27 insertions(+), 19 deletions(-) diff --git a/configs/_model/llama/small.yaml b/configs/_model/llama/small.yaml index 39b22712..7c50ab66 100644 --- a/configs/_model/llama/small.yaml +++ b/configs/_model/llama/small.yaml @@ -3,9 +3,9 @@ defaults: - _self_ common: - dmodel: 768 - dff: 2048 + dmodel: 1024 + dff: 2816 dhead: 64 - n_blocks: 12 - q_heads: 32 - kv_heads: 32 + n_blocks: 16 + q_heads: 16 + kv_heads: 16 diff --git a/configs/_model/llama/small_moe.yaml b/configs/_model/llama/small_moe.yaml index 7c607f7e..e24c81fd 100644 --- a/configs/_model/llama/small_moe.yaml +++ b/configs/_model/llama/small_moe.yaml @@ -3,12 +3,12 @@ defaults: - _self_ common: - dmodel: 768 - dff: 2048 + dmodel: 1024 + dff: 2816 dhead: 64 - n_blocks: 12 - q_heads: 32 - kv_heads: 32 + n_blocks: 16 + q_heads: 16 + kv_heads: 16 model: encoder: @@ -23,4 +23,3 @@ model: capacity_factor: 1.25 moe_load_balancing_loss_factor: 0.01 activation_function: swiglu - init_scale: 0.02 diff --git a/configs/jk_test.yaml b/configs/jk_test.yaml index 3fb351ca..7b47024b 100644 --- a/configs/jk_test.yaml +++ b/configs/jk_test.yaml @@ -1,6 +1,6 @@ defaults: - _cluster@_here_: entropy - - _model/llama@_here_: small + - _model/llama@_here_: small_moe - _trainer@_here_: llama - _dataset@_here_: c4 - _checkpoints@_here_: none @@ -12,15 +12,24 @@ common: sequence_length: 1024 batch_size: 32 +model: + embedding: + vocab_size: 50257 + trainer: gradient_accumulation_steps: 1 - n_steps: 10000 - learning_rate: 1e-3 + n_steps: 40000 + learning_rate: 5e-4 + + train_dataloader: + dataset: + tokenize_fn: + _target_: src.core.datasets.gpt2_tokenize_fn - checkpoint: - save: - type: nano - path: checkpoint + eval_dataloader: + dataset: + tokenize_fn: + _target_: src.core.datasets.gpt2_tokenize_fn infrastructure: max_concurrent_jobs: 1 @@ -29,7 +38,7 @@ infrastructure: type: wandb wandb_entity: ideas_cv project_name: llm-random-test - name: test_dense_32 + name: test_long_fix2_moe_16 tags: - nano - remote From ad26d7cca59cbdf89e3e17d695fdb7204da5e50c Mon Sep 17 00:00:00 2001 From: Jakub Date: Sat, 28 Mar 2026 10:14:54 +0100 Subject: [PATCH 08/16] add z loss --- configs/_model/llama/small_moe.yaml | 1 + src/core/moe.py | 13 ++++++- src/core/trainer.py | 57 +++++++++++++++++++++++++---- 3 files changed, 62 insertions(+), 9 deletions(-) diff --git a/configs/_model/llama/small_moe.yaml b/configs/_model/llama/small_moe.yaml index e24c81fd..9374710a 100644 --- a/configs/_model/llama/small_moe.yaml +++ b/configs/_model/llama/small_moe.yaml @@ -22,4 +22,5 @@ model: num_experts_per_tok: 1 capacity_factor: 1.25 moe_load_balancing_loss_factor: 0.01 + moe_router_z_loss_factor: 0.001 activation_function: swiglu diff --git a/src/core/moe.py b/src/core/moe.py index 8c1e76e0..f28bd12b 100644 --- a/src/core/moe.py +++ b/src/core/moe.py @@ -24,6 +24,7 @@ def __init__( num_experts_per_tok: int, capacity_factor: float = 1.25, moe_load_balancing_loss_factor: float = 0.0, + moe_router_z_loss_factor: float = 0.0, activation_function: str = "swiglu", init_scale: float = 1.0, **_ignored_kwargs, @@ -45,8 +46,11 @@ def __init__( self.num_experts_per_tok = num_experts_per_tok self.capacity_factor = capacity_factor self.moe_load_balancing_loss_factor = moe_load_balancing_loss_factor + self.moe_router_z_loss_factor = moe_router_z_loss_factor self.is_moe = True self.aux_loss = None + self.moe_load_balancing_loss = None + self.router_z_loss = None self.router_weight = nn.Parameter(torch.empty(num_experts, dmodel)) self.ff_pre_act_weight = nn.Parameter(torch.empty(num_experts, dff, dmodel)) @@ -69,7 +73,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: hidden_states, self.router_weight, ) - router_probs = F.softmax(router_logits, dim=-1, dtype=torch.float32) + router_logits = router_logits.to(dtype=torch.float32) + router_probs = F.softmax(router_logits, dim=-1) # For each token, keep only the top-k experts and their routing probabilities topk_probs, selected_experts = torch.topk( router_probs, @@ -148,10 +153,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: expert_frequency = flat_experts.bincount(minlength=self.num_experts) expert_frequency = expert_frequency.to(router_probs.dtype) expert_frequency = expert_frequency / expert_frequency.sum().clamp_min(1) - self.aux_loss = ( + self.moe_load_balancing_loss = ( self.num_experts * (router_probs.mean(dim=0) * expert_frequency).sum() ) + self.aux_loss = self.moe_load_balancing_loss + self.router_z_loss = torch.logsumexp(router_logits, dim=-1).square().mean() else: self.aux_loss = None + self.moe_load_balancing_loss = None + self.router_z_loss = None return output diff --git a/src/core/trainer.py b/src/core/trainer.py index 953d5a13..3d46f30f 100644 --- a/src/core/trainer.py +++ b/src/core/trainer.py @@ -51,6 +51,7 @@ def __attrs_post_init__(self): self.loss_interval_100 = 0.0 self._last_reported_loss = torch.zeros((), device=self.device) self._last_moe_load_balancing_loss = torch.zeros((), device=self.device) + self._last_moe_router_z_loss = torch.zeros((), device=self.device) if self.eval_dataloader is not None and hasattr( self.eval_dataloader, "__iter__" @@ -171,32 +172,38 @@ def _hack_for_python_garbage_collection(input_ids, target_ids): reduction="none", ) # Keep the reported loss as pure CE so train/loss stays comparable to eval/loss; - # the MoE load-balancing term is optimized and logged separately. + # MoE auxiliary terms are optimized and logged separately. reported_loss = mask_loss.mean() loss = reported_loss moe_load_balancing_loss = torch.zeros((), device=predicted_ids.device) + moe_router_z_loss = torch.zeros((), device=predicted_ids.device) if self.model.training: moe_load_balancing_loss = self._calculate_moe_load_balancing_loss( device=predicted_ids.device, ) - loss = loss + moe_load_balancing_loss + moe_router_z_loss = self._calculate_moe_router_z_loss( + device=predicted_ids.device, + ) + loss = loss + moe_load_balancing_loss + moe_router_z_loss reported_loss = reported_loss / self.gradient_accumulation_steps loss = loss / self.gradient_accumulation_steps moe_load_balancing_loss = ( moe_load_balancing_loss / self.gradient_accumulation_steps ) - return loss, reported_loss, moe_load_balancing_loss + moe_router_z_loss = moe_router_z_loss / self.gradient_accumulation_steps + return loss, reported_loss, moe_load_balancing_loss, moe_router_z_loss losses = [] reported_losses = [] moe_load_balancing_losses = [] + moe_router_z_losses = [] for batch_chunk in batch.chunk(self.gradient_accumulation_steps): input_ids, target_ids = self._preprocess_input(batch_chunk) input_ids = input_ids.to(self.device) if self.model.training: self._update_processed_tokens(input_ids) - loss, reported_loss, moe_load_balancing_loss = ( + loss, reported_loss, moe_load_balancing_loss, moe_router_z_loss = ( _hack_for_python_garbage_collection(input_ids, target_ids) ) if self.model.training: @@ -204,6 +211,7 @@ def _hack_for_python_garbage_collection(input_ids, target_ids): losses.append(loss.item()) reported_losses.append(reported_loss.item()) moe_load_balancing_losses.append(moe_load_balancing_loss.item()) + moe_router_z_losses.append(moe_router_z_loss.item()) # gloo backend supports only sum reduce operation, therfore we first divide by world size and then sum avg_loss = torch.tensor(losses, device=loss.device).sum() @@ -215,26 +223,57 @@ def _hack_for_python_garbage_collection(input_ids, target_ids): moe_load_balancing_losses, device=loss.device, ).sum() + avg_moe_router_z_loss = torch.tensor( + moe_router_z_losses, + device=loss.device, + ).sum() if dist.is_initialized(): dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM) dist.all_reduce(avg_reported_loss, op=dist.ReduceOp.SUM) dist.all_reduce(avg_moe_load_balancing_loss, op=dist.ReduceOp.SUM) + dist.all_reduce(avg_moe_router_z_loss, op=dist.ReduceOp.SUM) self._last_reported_loss = avg_reported_loss / float(os.environ["WORLD_SIZE"]) self._last_moe_load_balancing_loss = avg_moe_load_balancing_loss / float( os.environ["WORLD_SIZE"] ) + self._last_moe_router_z_loss = avg_moe_router_z_loss / float( + os.environ["WORLD_SIZE"] + ) return avg_loss / float(os.environ["WORLD_SIZE"]) - def _calculate_moe_load_balancing_loss(self, device): + def _calculate_weighted_moe_loss( + self, + device, + loss_attr, + factor_attr, + fallback_loss_attr=None, + ): loss = torch.zeros((), device=device) for module in self.model.modules(): - aux_loss = getattr(module, "aux_loss", None) - factor = getattr(module, "moe_load_balancing_loss_factor", 0.0) + aux_loss = getattr(module, loss_attr, None) + if aux_loss is None and fallback_loss_attr is not None: + aux_loss = getattr(module, fallback_loss_attr, None) + factor = getattr(module, factor_attr, 0.0) if aux_loss is not None and factor: loss = loss + aux_loss.to(device=device) * factor return loss + def _calculate_moe_load_balancing_loss(self, device): + return self._calculate_weighted_moe_loss( + device=device, + loss_attr="moe_load_balancing_loss", + factor_attr="moe_load_balancing_loss_factor", + fallback_loss_attr="aux_loss", + ) + + def _calculate_moe_router_z_loss(self, device): + return self._calculate_weighted_moe_loss( + device=device, + loss_attr="router_z_loss", + factor_attr="moe_router_z_loss_factor", + ) + def eval(self): self.model.eval() saved_step = self.step @@ -278,6 +317,10 @@ def log_metrics(self, loss, grad_norm): "train/moe_load_balancing_loss", self._last_moe_load_balancing_loss.item(), ) + self.metric_logger.log( + "train/moe_router_z_loss", + self._last_moe_router_z_loss.item(), + ) self.metric_logger.log("train/lr", self.scheduler.get_last_lr()[0]) self.metric_logger.log("train/grad_norm", grad_norm.item()) From c49bc50f38d438cef41c2004212161607288241a Mon Sep 17 00:00:00 2001 From: Jakub Date: Thu, 2 Apr 2026 13:05:15 +0200 Subject: [PATCH 09/16] rename config --- configs/{jk_test.yaml => moe_example_run.yaml} | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) rename configs/{jk_test.yaml => moe_example_run.yaml} (88%) diff --git a/configs/jk_test.yaml b/configs/moe_example_run.yaml similarity index 88% rename from configs/jk_test.yaml rename to configs/moe_example_run.yaml index 7b47024b..d0e187b4 100644 --- a/configs/jk_test.yaml +++ b/configs/moe_example_run.yaml @@ -10,7 +10,7 @@ defaults: common: sequence_length: 1024 - batch_size: 32 + batch_size: 64 model: embedding: @@ -18,7 +18,7 @@ model: trainer: gradient_accumulation_steps: 1 - n_steps: 40000 + n_steps: 1000 learning_rate: 5e-4 train_dataloader: @@ -38,7 +38,7 @@ infrastructure: type: wandb wandb_entity: ideas_cv project_name: llm-random-test - name: test_long_fix2_moe_16 + name: moe_2gpu tags: - nano - remote @@ -46,8 +46,8 @@ infrastructure: - moe slurm: - time: "1-00:00:00" - gres: gpu:1 + time: "0-02:00:00" + gres: gpu:2 job-name: ${infrastructure.metric_logger.name} evaluator: null From 88f617a2933a64111b2a1632c1c1016541c4d80c Mon Sep 17 00:00:00 2001 From: Jakub Date: Thu, 2 Apr 2026 22:31:34 +0200 Subject: [PATCH 10/16] Revert entropy cluster config change --- configs/_cluster/entropy.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/configs/_cluster/entropy.yaml b/configs/_cluster/entropy.yaml index 3fdbf108..f341770e 100644 --- a/configs/_cluster/entropy.yaml +++ b/configs/_cluster/entropy.yaml @@ -19,7 +19,7 @@ infrastructure: - 'export HYDRA_FULL_ERROR=1' # export pixi variables - - 'export PIXI_HOME=/storage_nvme_4/nano/pixi_jk' + - 'export PIXI_HOME=/storage_nvme_4/nano/pixi_new' - 'export PATH="$PIXI_HOME/bin:$PATH"' - 'export XDG_DATA_HOME="$PIXI_HOME/data"' - 'export XDG_CACHE_HOME="$PIXI_HOME/cache"' From 4a0134fcab8eedf6abca26e282e92651d5945c7c Mon Sep 17 00:00:00 2001 From: Jakub Date: Thu, 2 Apr 2026 23:16:54 +0200 Subject: [PATCH 11/16] remove aux_loss --- src/core/moe.py | 3 --- src/core/trainer.py | 10 +++------- 2 files changed, 3 insertions(+), 10 deletions(-) diff --git a/src/core/moe.py b/src/core/moe.py index f28bd12b..8b57d7cf 100644 --- a/src/core/moe.py +++ b/src/core/moe.py @@ -48,7 +48,6 @@ def __init__( self.moe_load_balancing_loss_factor = moe_load_balancing_loss_factor self.moe_router_z_loss_factor = moe_router_z_loss_factor self.is_moe = True - self.aux_loss = None self.moe_load_balancing_loss = None self.router_z_loss = None @@ -156,10 +155,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: self.moe_load_balancing_loss = ( self.num_experts * (router_probs.mean(dim=0) * expert_frequency).sum() ) - self.aux_loss = self.moe_load_balancing_loss self.router_z_loss = torch.logsumexp(router_logits, dim=-1).square().mean() else: - self.aux_loss = None self.moe_load_balancing_loss = None self.router_z_loss = None diff --git a/src/core/trainer.py b/src/core/trainer.py index 3d46f30f..eb746e71 100644 --- a/src/core/trainer.py +++ b/src/core/trainer.py @@ -247,16 +247,13 @@ def _calculate_weighted_moe_loss( device, loss_attr, factor_attr, - fallback_loss_attr=None, ): loss = torch.zeros((), device=device) for module in self.model.modules(): - aux_loss = getattr(module, loss_attr, None) - if aux_loss is None and fallback_loss_attr is not None: - aux_loss = getattr(module, fallback_loss_attr, None) + module_loss = getattr(module, loss_attr, None) factor = getattr(module, factor_attr, 0.0) - if aux_loss is not None and factor: - loss = loss + aux_loss.to(device=device) * factor + if module_loss is not None and factor: + loss = loss + module_loss.to(device=device) * factor return loss def _calculate_moe_load_balancing_loss(self, device): @@ -264,7 +261,6 @@ def _calculate_moe_load_balancing_loss(self, device): device=device, loss_attr="moe_load_balancing_loss", factor_attr="moe_load_balancing_loss_factor", - fallback_loss_attr="aux_loss", ) def _calculate_moe_router_z_loss(self, device): From b63a6b7c94c4395e3ff90b504bab2c84199ed910 Mon Sep 17 00:00:00 2001 From: Jakub Date: Thu, 2 Apr 2026 23:29:51 +0200 Subject: [PATCH 12/16] fix logging and add moe normalization --- src/core/moe.py | 11 +++++++++++ src/core/trainer.py | 22 +++++++++++++--------- 2 files changed, 24 insertions(+), 9 deletions(-) diff --git a/src/core/moe.py b/src/core/moe.py index 8b57d7cf..734d7fde 100644 --- a/src/core/moe.py +++ b/src/core/moe.py @@ -25,6 +25,7 @@ def __init__( capacity_factor: float = 1.25, moe_load_balancing_loss_factor: float = 0.0, moe_router_z_loss_factor: float = 0.0, + normalize_router_logits: bool = False, activation_function: str = "swiglu", init_scale: float = 1.0, **_ignored_kwargs, @@ -39,6 +40,10 @@ def __init__( ) if capacity_factor <= 0: raise ValueError(f"capacity_factor must be > 0, got {capacity_factor}.") + if normalize_router_logits and num_experts_per_tok == 1: + raise AssertionError( + "normalize_router_logits requires num_experts_per_tok > 1." + ) self.dmodel = dmodel self.dff = dff @@ -47,6 +52,7 @@ def __init__( self.capacity_factor = capacity_factor self.moe_load_balancing_loss_factor = moe_load_balancing_loss_factor self.moe_router_z_loss_factor = moe_router_z_loss_factor + self.normalize_router_logits = normalize_router_logits self.is_moe = True self.moe_load_balancing_loss = None self.router_z_loss = None @@ -109,6 +115,11 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: kept_tokens = sorted_tokens[keep] kept_slots = slot_in_expert[keep] kept_weights = sorted_weights[keep] + if self.normalize_router_logits and kept_weights.numel() > 0: + # Renormalize only the surviving expert weights so each token sums to 1 after capacity pruning. + token_weight_sums = kept_weights.new_zeros(num_tokens) + token_weight_sums.index_add_(0, kept_tokens, kept_weights) + kept_weights = kept_weights / token_weight_sums.index_select(0, kept_tokens) # Dispatch the surviving tokens into expert-capacity slots and run the expert MLP batched per expert flat_capacity = self.num_experts * capacity diff --git a/src/core/trainer.py b/src/core/trainer.py index eb746e71..d2ad72fe 100644 --- a/src/core/trainer.py +++ b/src/core/trainer.py @@ -48,6 +48,9 @@ def __attrs_post_init__(self): self.processed_tokens = self.training_state["processed_tokens"] self.start_step = self.training_state["next_step"] self.device = next(self.model.parameters()).device + self._has_moe_modules = any( + getattr(module, "is_moe", False) for module in self.model.modules() + ) self.loss_interval_100 = 0.0 self._last_reported_loss = torch.zeros((), device=self.device) self._last_moe_load_balancing_loss = torch.zeros((), device=self.device) @@ -177,7 +180,7 @@ def _hack_for_python_garbage_collection(input_ids, target_ids): loss = reported_loss moe_load_balancing_loss = torch.zeros((), device=predicted_ids.device) moe_router_z_loss = torch.zeros((), device=predicted_ids.device) - if self.model.training: + if self.model.training and self._has_moe_modules: moe_load_balancing_loss = self._calculate_moe_load_balancing_loss( device=predicted_ids.device, ) @@ -309,14 +312,15 @@ def _update_processed_tokens(self, batch): def log_metrics(self, loss, grad_norm): self.metric_logger.set_tokens(self.processed_tokens) self.metric_logger.log("train/loss", self._last_reported_loss.item()) - self.metric_logger.log( - "train/moe_load_balancing_loss", - self._last_moe_load_balancing_loss.item(), - ) - self.metric_logger.log( - "train/moe_router_z_loss", - self._last_moe_router_z_loss.item(), - ) + if self._has_moe_modules: + self.metric_logger.log( + "train/moe_load_balancing_loss", + self._last_moe_load_balancing_loss.item(), + ) + self.metric_logger.log( + "train/moe_router_z_loss", + self._last_moe_router_z_loss.item(), + ) self.metric_logger.log("train/lr", self.scheduler.get_last_lr()[0]) self.metric_logger.log("train/grad_norm", grad_norm.item()) From b5d722ca3bc7f7becf1647860daf47beec550f10 Mon Sep 17 00:00:00 2001 From: Jakub Date: Fri, 3 Apr 2026 14:46:05 +0200 Subject: [PATCH 13/16] refactor logging --- src/core/trainer.py | 87 +++++++++--------- src/core/trainer_distillation.py | 91 +++++++++---------- src/product_keys/trainer.py | 11 ++- src/projected_compression/trainer.py | 4 +- .../trainer_distillation.py | 4 +- 5 files changed, 96 insertions(+), 101 deletions(-) diff --git a/src/core/trainer.py b/src/core/trainer.py index d2ad72fe..dea62e7d 100644 --- a/src/core/trainer.py +++ b/src/core/trainer.py @@ -25,6 +25,15 @@ logger = logging.getLogger(__name__) +@define(frozen=True) +class LossMetrics: + total_loss: torch.Tensor + reported_loss: torch.Tensor + moe_load_balancing_loss: Optional[torch.Tensor] = None + moe_router_z_loss: Optional[torch.Tensor] = None + distill_loss: Optional[torch.Tensor] = None + + @define(slots=False) class Trainer: model: torch.nn.Module @@ -51,10 +60,6 @@ def __attrs_post_init__(self): self._has_moe_modules = any( getattr(module, "is_moe", False) for module in self.model.modules() ) - self.loss_interval_100 = 0.0 - self._last_reported_loss = torch.zeros((), device=self.device) - self._last_moe_load_balancing_loss = torch.zeros((), device=self.device) - self._last_moe_router_z_loss = torch.zeros((), device=self.device) if self.eval_dataloader is not None and hasattr( self.eval_dataloader, "__iter__" @@ -71,6 +76,8 @@ def __attrs_post_init__(self): next(self.eval_iterator) self.loss_averaged_100 = AveMetric(100, "100/train/loss") + if self._has_moe_modules: + self.total_loss_averaged_100 = AveMetric(100, "100/train/total_loss") self.time_diff_averaged_100 = AveDiffMetric(100, "100/time", time.time()) @property @@ -110,11 +117,11 @@ def train(self): self.metric_logger.set_step(step) self.metric_logger.set_tokens(self.processed_tokens) self.model.train() - loss = self.calculate_loss(batch) + loss_metrics = self.calculate_loss(batch) grad_norm = self.clip_gradient() - self.log_metrics(loss, grad_norm) + self.log_metrics(loss_metrics, grad_norm) self.optimizer.step() self.optimizer.zero_grad() @@ -160,7 +167,7 @@ def _preprocess_input(self, batch): # TODO test it return input_ids, target_ids - def calculate_loss(self, batch): + def calculate_loss(self, batch) -> LossMetrics: def _hack_for_python_garbage_collection(input_ids, target_ids): """we want to have no reference to model output while backpropagating to allow torch to free memory, so we wrap loss calculation in a function""" @@ -196,7 +203,7 @@ def _hack_for_python_garbage_collection(input_ids, target_ids): moe_router_z_loss = moe_router_z_loss / self.gradient_accumulation_steps return loss, reported_loss, moe_load_balancing_loss, moe_router_z_loss - losses = [] + total_losses = [] reported_losses = [] moe_load_balancing_losses = [] moe_router_z_losses = [] @@ -211,39 +218,29 @@ def _hack_for_python_garbage_collection(input_ids, target_ids): ) if self.model.training: loss.backward() - losses.append(loss.item()) - reported_losses.append(reported_loss.item()) - moe_load_balancing_losses.append(moe_load_balancing_loss.item()) - moe_router_z_losses.append(moe_router_z_loss.item()) + total_losses.append(loss.detach()) + reported_losses.append(reported_loss.detach()) + moe_load_balancing_losses.append(moe_load_balancing_loss.detach()) + moe_router_z_losses.append(moe_router_z_loss.detach()) # gloo backend supports only sum reduce operation, therfore we first divide by world size and then sum - avg_loss = torch.tensor(losses, device=loss.device).sum() - avg_reported_loss = torch.tensor( - reported_losses, - device=loss.device, - ).sum() - avg_moe_load_balancing_loss = torch.tensor( - moe_load_balancing_losses, - device=loss.device, - ).sum() - avg_moe_router_z_loss = torch.tensor( - moe_router_z_losses, - device=loss.device, - ).sum() + avg_total_loss = torch.stack(total_losses).sum() + avg_reported_loss = torch.stack(reported_losses).sum() + avg_moe_load_balancing_loss = torch.stack(moe_load_balancing_losses).sum() + avg_moe_router_z_loss = torch.stack(moe_router_z_losses).sum() if dist.is_initialized(): - dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM) + dist.all_reduce(avg_total_loss, op=dist.ReduceOp.SUM) dist.all_reduce(avg_reported_loss, op=dist.ReduceOp.SUM) dist.all_reduce(avg_moe_load_balancing_loss, op=dist.ReduceOp.SUM) dist.all_reduce(avg_moe_router_z_loss, op=dist.ReduceOp.SUM) - self._last_reported_loss = avg_reported_loss / float(os.environ["WORLD_SIZE"]) - self._last_moe_load_balancing_loss = avg_moe_load_balancing_loss / float( - os.environ["WORLD_SIZE"] + world_size = float(os.environ["WORLD_SIZE"]) + return LossMetrics( + total_loss=avg_total_loss / world_size, + reported_loss=avg_reported_loss / world_size, + moe_load_balancing_loss=avg_moe_load_balancing_loss / world_size, + moe_router_z_loss=avg_moe_router_z_loss / world_size, ) - self._last_moe_router_z_loss = avg_moe_router_z_loss / float( - os.environ["WORLD_SIZE"] - ) - return avg_loss / float(os.environ["WORLD_SIZE"]) def _calculate_weighted_moe_loss( self, @@ -285,10 +282,10 @@ def eval(self): batch_fingerprint = create_batch_fingerprint(batch) eval_fingerprint.extend(batch_fingerprint) batch = batch.to(self.device) - loss = self.calculate_loss(batch) - losses.append(loss.item()) + loss_metrics = self.calculate_loss(batch) + losses.append(loss_metrics.reported_loss.detach()) self.metric_logger.flush_accumulated_metrics() - avg_loss = torch.tensor(losses).mean() + avg_loss = torch.stack(losses).mean() self.metric_logger.log("eval/loss", avg_loss.item()) if self._should_log_eval_input: @@ -309,22 +306,28 @@ def clip_gradient(self): def _update_processed_tokens(self, batch): self.processed_tokens += batch.numel() * int(os.environ["WORLD_SIZE"]) - def log_metrics(self, loss, grad_norm): + def log_metrics(self, loss_metrics: LossMetrics, grad_norm): self.metric_logger.set_tokens(self.processed_tokens) - self.metric_logger.log("train/loss", self._last_reported_loss.item()) + self.metric_logger.log("train/loss", loss_metrics.reported_loss.item()) if self._has_moe_modules: + self.metric_logger.log("train/total_loss", loss_metrics.total_loss.item()) self.metric_logger.log( "train/moe_load_balancing_loss", - self._last_moe_load_balancing_loss.item(), + loss_metrics.moe_load_balancing_loss.item(), ) self.metric_logger.log( "train/moe_router_z_loss", - self._last_moe_router_z_loss.item(), + loss_metrics.moe_router_z_loss.item(), ) self.metric_logger.log("train/lr", self.scheduler.get_last_lr()[0]) - self.metric_logger.log("train/grad_norm", grad_norm.item()) + if grad_norm is not None: + self.metric_logger.log("train/grad_norm", grad_norm.item()) - self.loss_averaged_100.log(self.metric_logger, self._last_reported_loss.item()) + self.loss_averaged_100.log(self.metric_logger, loss_metrics.reported_loss.item()) + if self._has_moe_modules: + self.total_loss_averaged_100.log( + self.metric_logger, loss_metrics.total_loss.item() + ) self.time_diff_averaged_100.log(self.metric_logger, time.time()) self.metric_logger.flush_accumulated_metrics() diff --git a/src/core/trainer_distillation.py b/src/core/trainer_distillation.py index 08c57f2d..e403d8c2 100644 --- a/src/core/trainer_distillation.py +++ b/src/core/trainer_distillation.py @@ -7,7 +7,7 @@ import torch.distributed as dist import logging -from src.core.trainer import Trainer +from src.core.trainer import LossMetrics, Trainer from src.core.metric_loggers import AveMetric from src.core.utils import create_batch_fingerprint @@ -83,7 +83,7 @@ def compute_distillation_loss(self, student_logits, teacher_logits): # Scale by temperature^2 to normalize return kl_loss * (self.distillation_temperature**2) - def calculate_loss(self, batch): + def calculate_loss(self, batch) -> LossMetrics: """Override to compute both CE loss and distillation loss""" def _compute_losses(input_ids, target_ids): @@ -122,7 +122,7 @@ def _compute_losses(input_ids, target_ids): return total_loss, ce_loss, distill_loss total_losses = [] - ce_losses = [] + reported_losses = [] distill_losses = [] for batch_chunk in batch.chunk(self.gradient_accumulation_steps): @@ -137,51 +137,48 @@ def _compute_losses(input_ids, target_ids): if self.model.training: total_loss.backward() - total_losses.append(total_loss.item()) - ce_losses.append(ce_loss.item()) - distill_losses.append(distill_loss.item()) + total_losses.append(total_loss.detach()) + reported_losses.append(ce_loss.detach()) + distill_losses.append(distill_loss.detach()) # Average and synchronize across devices - avg_total_loss = torch.tensor(total_losses, device=self.device).sum() - avg_ce_loss = torch.tensor(ce_losses, device=self.device).sum() - avg_distill_loss = torch.tensor(distill_losses, device=self.device).sum() + avg_total_loss = torch.stack(total_losses).sum() + avg_reported_loss = torch.stack(reported_losses).sum() + avg_distill_loss = torch.stack(distill_losses).sum() if dist.is_initialized(): dist.all_reduce(avg_total_loss, op=dist.ReduceOp.SUM) - dist.all_reduce(avg_ce_loss, op=dist.ReduceOp.SUM) + dist.all_reduce(avg_reported_loss, op=dist.ReduceOp.SUM) dist.all_reduce(avg_distill_loss, op=dist.ReduceOp.SUM) world_size = float(os.environ["WORLD_SIZE"]) + return LossMetrics( + total_loss=avg_total_loss / world_size, + reported_loss=avg_reported_loss / world_size, + distill_loss=avg_distill_loss / world_size, + ) - # Store individual losses for logging - self._last_ce_loss = avg_ce_loss / world_size - self._last_distill_loss = avg_distill_loss / world_size - - return avg_total_loss / world_size - - def log_metrics(self, loss, grad_norm): + def log_metrics(self, loss_metrics: LossMetrics, grad_norm): """Override to add distillation-specific metrics""" self.metric_logger.set_tokens(self.processed_tokens) self.metric_logger.log("train/lr", self.scheduler.get_last_lr()[0]) - self.metric_logger.log("train/grad_norm", grad_norm.item()) + if grad_norm is not None: + self.metric_logger.log("train/grad_norm", grad_norm.item()) self.time_diff_averaged_100.log(self.metric_logger, time.time()) - # Add distillation-specific metrics - if hasattr(self, "_last_ce_loss"): - # `loss` is cross entropy loss on language modeling; `total_loss` is training loss! - self.metric_logger.log("train/loss", self._last_ce_loss.item()) - self.metric_logger.log("train/total_loss", loss.item()) - self.metric_logger.log("train/distill_loss", self._last_distill_loss.item()) - - self.loss_averaged_100.log(self.metric_logger, self._last_ce_loss.item()) - self.total_loss_averaged_100.log(self.metric_logger, loss.item()) - self.distill_loss_averaged_100.log( - self.metric_logger, self._last_distill_loss.item() - ) - else: - self.metric_logger.log("train/loss", loss.item()) - self.loss_averaged_100.log(self.metric_logger, loss.item()) + # `reported_loss` is cross entropy loss on language modeling; `total_loss` is training loss. + self.metric_logger.log("train/loss", loss_metrics.reported_loss.item()) + self.metric_logger.log("train/total_loss", loss_metrics.total_loss.item()) + self.metric_logger.log("train/distill_loss", loss_metrics.distill_loss.item()) + + self.loss_averaged_100.log(self.metric_logger, loss_metrics.reported_loss.item()) + self.total_loss_averaged_100.log( + self.metric_logger, loss_metrics.total_loss.item() + ) + self.distill_loss_averaged_100.log( + self.metric_logger, loss_metrics.distill_loss.item() + ) self.metric_logger.flush_accumulated_metrics() @@ -203,26 +200,20 @@ def eval(self): eval_fingerprint.extend(batch_fingerprint) batch = batch.to(self.device) - loss = self.calculate_loss(batch) - losses.append(loss.item()) - - if hasattr(self, "_last_ce_loss"): - ce_losses.append(self._last_ce_loss.item()) - distill_losses.append(self._last_distill_loss.item()) + loss_metrics = self.calculate_loss(batch) + losses.append(loss_metrics.total_loss.detach()) + ce_losses.append(loss_metrics.reported_loss.detach()) + distill_losses.append(loss_metrics.distill_loss.detach()) self.metric_logger.flush_accumulated_metrics() - avg_loss = torch.tensor(losses).mean() - - if ce_losses: - avg_ce_loss = torch.tensor(ce_losses).mean() - avg_distill_loss = torch.tensor(distill_losses).mean() - # `loss` is cross entropy loss on language modeling; `total_loss` is training loss! - self.metric_logger.log("eval/loss", avg_ce_loss.item()) - self.metric_logger.log("eval/distill_loss", avg_distill_loss.item()) - self.metric_logger.log("eval/total_loss", avg_loss.item()) - else: - self.metric_logger.log("eval/loss", avg_loss.item()) + avg_loss = torch.stack(losses).mean() + avg_ce_loss = torch.stack(ce_losses).mean() + avg_distill_loss = torch.stack(distill_losses).mean() + # `reported_loss` is cross entropy loss on language modeling; `total_loss` is training loss. + self.metric_logger.log("eval/loss", avg_ce_loss.item()) + self.metric_logger.log("eval/distill_loss", avg_distill_loss.item()) + self.metric_logger.log("eval/total_loss", avg_loss.item()) if self._should_log_eval_input: self.metric_logger.log("eval/batch", str(eval_fingerprint)) diff --git a/src/product_keys/trainer.py b/src/product_keys/trainer.py index 83b0e106..1a7e441f 100644 --- a/src/product_keys/trainer.py +++ b/src/product_keys/trainer.py @@ -5,7 +5,7 @@ import torch.distributed as dist from typing import List -from src.core.trainer import Trainer +from src.core.trainer import LossMetrics, Trainer @define(slots=False) @@ -45,7 +45,7 @@ def _preprocess_and_mask_input(self, batch): return input_ids, labels - def calculate_loss(self, batch): + def calculate_loss(self, batch) -> LossMetrics: """ Calculates the MLM loss. The loss is calculated only for the masked tokens. @@ -78,11 +78,12 @@ def _mlm_loss_calculation(input_ids, target_ids): loss = _mlm_loss_calculation(input_ids, target_ids) if self.model.training: loss.backward() - losses.append(loss.item()) + losses.append(loss.detach()) # gloo backend supports only sum reduce operation, therfore we first divide by world size and then sum - avg_loss = torch.tensor(losses, device=loss.device).sum() + avg_loss = torch.stack(losses).sum() if dist.is_initialized(): dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM) - return avg_loss / float(os.environ["WORLD_SIZE"]) + avg_loss = avg_loss / float(os.environ["WORLD_SIZE"]) + return LossMetrics(total_loss=avg_loss, reported_loss=avg_loss) diff --git a/src/projected_compression/trainer.py b/src/projected_compression/trainer.py index 16e54842..3601b631 100644 --- a/src/projected_compression/trainer.py +++ b/src/projected_compression/trainer.py @@ -33,7 +33,7 @@ def train(self): self.model.train() self.model.prepare_compressed_weights() - loss = self.calculate_loss(batch) + loss_metrics = self.calculate_loss(batch) if self.only_compress_model_gradient_clipping: grad_norm = torch.nn.utils.clip_grad_norm_( @@ -56,7 +56,7 @@ def train(self): self.model.parameters(), self.gradient_clipping, grad_norm ) - self.log_metrics(loss, grad_norm) + self.log_metrics(loss_metrics, grad_norm) self.optimizer.step() self.optimizer.zero_grad() self.scheduler.step() diff --git a/src/projected_compression/trainer_distillation.py b/src/projected_compression/trainer_distillation.py index 2d62c448..00aaee7b 100644 --- a/src/projected_compression/trainer_distillation.py +++ b/src/projected_compression/trainer_distillation.py @@ -35,7 +35,7 @@ def train(self): self.model.train() self.model.prepare_compressed_weights() - loss = self.calculate_loss(batch) + loss_metrics = self.calculate_loss(batch) if self.only_compress_model_gradient_clipping: grad_norm = torch.nn.utils.clip_grad_norm_( @@ -58,7 +58,7 @@ def train(self): self.model.parameters(), self.gradient_clipping, grad_norm ) - self.log_metrics(loss, grad_norm) + self.log_metrics(loss_metrics, grad_norm) self.optimizer.step() self.optimizer.zero_grad() self.scheduler.step() From d299fca12a0123f836d3ede05d83eb6af2f0ffb9 Mon Sep 17 00:00:00 2001 From: Jakub Date: Fri, 3 Apr 2026 15:01:38 +0200 Subject: [PATCH 14/16] Reformat --- src/core/trainer.py | 4 +++- src/core/trainer_distillation.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/core/trainer.py b/src/core/trainer.py index dea62e7d..b8f6ee8d 100644 --- a/src/core/trainer.py +++ b/src/core/trainer.py @@ -323,7 +323,9 @@ def log_metrics(self, loss_metrics: LossMetrics, grad_norm): if grad_norm is not None: self.metric_logger.log("train/grad_norm", grad_norm.item()) - self.loss_averaged_100.log(self.metric_logger, loss_metrics.reported_loss.item()) + self.loss_averaged_100.log( + self.metric_logger, loss_metrics.reported_loss.item() + ) if self._has_moe_modules: self.total_loss_averaged_100.log( self.metric_logger, loss_metrics.total_loss.item() diff --git a/src/core/trainer_distillation.py b/src/core/trainer_distillation.py index e403d8c2..2a8ac5df 100644 --- a/src/core/trainer_distillation.py +++ b/src/core/trainer_distillation.py @@ -172,7 +172,9 @@ def log_metrics(self, loss_metrics: LossMetrics, grad_norm): self.metric_logger.log("train/total_loss", loss_metrics.total_loss.item()) self.metric_logger.log("train/distill_loss", loss_metrics.distill_loss.item()) - self.loss_averaged_100.log(self.metric_logger, loss_metrics.reported_loss.item()) + self.loss_averaged_100.log( + self.metric_logger, loss_metrics.reported_loss.item() + ) self.total_loss_averaged_100.log( self.metric_logger, loss_metrics.total_loss.item() ) From 488ec0b9455390b5896ccddff612f57b30e66c9d Mon Sep 17 00:00:00 2001 From: Jakub Date: Tue, 14 Apr 2026 13:41:41 +0200 Subject: [PATCH 15/16] modify configs --- configs/_model/llama/base_moe_model.yaml | 106 +++++++++++++++++++++++ configs/_model/llama/small_moe.yaml | 6 +- src/core/moe.py | 1 - 3 files changed, 107 insertions(+), 6 deletions(-) create mode 100644 configs/_model/llama/base_moe_model.yaml diff --git a/configs/_model/llama/base_moe_model.yaml b/configs/_model/llama/base_moe_model.yaml new file mode 100644 index 00000000..e0022941 --- /dev/null +++ b/configs/_model/llama/base_moe_model.yaml @@ -0,0 +1,106 @@ +common: + _target_: src.definitions.Common + dmodel: ??? + dff: ??? + dhead: ??? + sequence_length: ??? + n_blocks: ??? + q_heads: ??? + kv_heads: ??? + + +model: + _target_: src.projected_compression.model.LLM + + embedding: + _target_: src.projected_compression.model.TransformerEmbedding + vocab_size: 128256 + dmodel: ${common.dmodel} + init_fn: + _target_: src.projected_compression.model.trunc_normal_ + _partial_: true + + encoder: + _target_: src.projected_compression.model.TransformerEncoder + n_blocks: ${common.n_blocks} + block_fn: + _target_: src.projected_compression.model.TransformerBlock + _partial_: true + norm_fn: + _target_: src.core.model.RMSNorm + _partial_: true + eps: 1e-5 + normalized_shape: ${common.dmodel} + + attention_fn: + _target_: src.projected_compression.model.RoPEAttention + _partial_: true + dmodel: ${common.dmodel} + q_heads: ${common.q_heads} + kv_heads: ${common.kv_heads} + seq_len: ${common.sequence_length} + q_proj_fn: + _target_: src.projected_compression.model.Linear + _partial_: true + in_features: ${common.dmodel} + out_features: ${eval:'${common.dhead} * ${model.encoder.block_fn.attention_fn.q_heads}'} + partial_init_fn: + _target_: src.projected_compression.model.llm_random_weight_init + _partial_: true + scale: 1 + + k_proj_fn: + _target_: src.projected_compression.model.Linear + _partial_: true + in_features: ${common.dmodel} + out_features: ${eval:'${common.dhead} * ${model.encoder.block_fn.attention_fn.kv_heads}'} + partial_init_fn: + _target_: src.projected_compression.model.llm_random_weight_init + _partial_: true + scale: 1 + + v_proj_fn: ${model.encoder.block_fn.attention_fn.k_proj_fn} + + o_proj_fn: + _target_: src.projected_compression.model.Linear + _partial_: true + in_features: ${eval:'${common.dhead} * ${model.encoder.block_fn.attention_fn.q_heads}'} + out_features: ${common.dmodel} + partial_init_fn: + _target_: src.projected_compression.model.llm_random_weight_init + _partial_: true + scale: 1 + + rope_base: 500000 + rope_scale_freqs: true + + ff_layer_fn: + _target_: src.core.moe.MoE + _partial_: true + dmodel: ${common.dmodel} + dff: ${common.dff} + num_experts: ??? + num_experts_per_tok: ??? + capacity_factor: 1.25 + moe_load_balancing_loss_factor: 0.0 + moe_router_z_loss_factor: 0.0 + normalize_router_logits: false + activation_function: swiglu + init_scale: 1.0 + + head: + _target_: src.projected_compression.model.TransformerHead + linear_fn: + _target_: src.projected_compression.model.Linear + _partial_: true + in_features: ${common.dmodel} + out_features: ${model.embedding.vocab_size} + partial_init_fn: + _target_: src.projected_compression.model.llm_random_weight_init + _partial_: true + scale: 1 + norm_fn: + _target_: src.core.model.RMSNorm + _partial_: true + eps: 1e-5 + normalized_shape: ${common.dmodel} diff --git a/configs/_model/llama/small_moe.yaml b/configs/_model/llama/small_moe.yaml index 9374710a..22140255 100644 --- a/configs/_model/llama/small_moe.yaml +++ b/configs/_model/llama/small_moe.yaml @@ -1,5 +1,5 @@ defaults: - - base_model + - base_moe_model - _self_ common: @@ -14,10 +14,6 @@ model: encoder: block_fn: ff_layer_fn: - _target_: src.core.moe.MoE - _partial_: true - dmodel: ${common.dmodel} - dff: ${common.dff} num_experts: 16 num_experts_per_tok: 1 capacity_factor: 1.25 diff --git a/src/core/moe.py b/src/core/moe.py index 734d7fde..e129f69f 100644 --- a/src/core/moe.py +++ b/src/core/moe.py @@ -28,7 +28,6 @@ def __init__( normalize_router_logits: bool = False, activation_function: str = "swiglu", init_scale: float = 1.0, - **_ignored_kwargs, ): super().__init__() From 10c48c202a9ff948250c9ac6a76bc0a0538f76b2 Mon Sep 17 00:00:00 2001 From: Jakub Date: Tue, 14 Apr 2026 14:20:32 +0200 Subject: [PATCH 16/16] update configs --- configs/_model/llama/base_model.yaml | 28 +----- configs/_model/llama/base_moe_model.yaml | 106 ----------------------- configs/_model/llama/small_moe.yaml | 7 +- configs/ff_layer/dense.yaml | 21 +++++ configs/ff_layer/moe.yaml | 12 +++ src/core/moe.py | 20 ++--- 6 files changed, 50 insertions(+), 144 deletions(-) delete mode 100644 configs/_model/llama/base_moe_model.yaml create mode 100644 configs/ff_layer/dense.yaml create mode 100644 configs/ff_layer/moe.yaml diff --git a/configs/_model/llama/base_model.yaml b/configs/_model/llama/base_model.yaml index 30d8feaf..b18ce01c 100644 --- a/configs/_model/llama/base_model.yaml +++ b/configs/_model/llama/base_model.yaml @@ -1,3 +1,6 @@ +defaults: + - /ff_layer@model.encoder.block_fn.ff_layer_fn: dense + - _self_ common: _target_: src.definitions.Common @@ -75,29 +78,6 @@ model: rope_base: 500000 rope_scale_freqs: true - ff_layer_fn: - _target_: src.projected_compression.model.ProjectedLlamaFeedForward - _partial_: true - ff_pre_act_fn: - _target_: src.projected_compression.model.Linear - _partial_: true - in_features: ${common.dmodel} - out_features: ${common.dff} - partial_init_fn: - _target_: src.projected_compression.model.llm_random_weight_init - _partial_: true - scale: 1 - ff_post_act_fn: - _target_: src.projected_compression.model.Linear - _partial_: true - in_features: ${common.dff} - out_features: ${common.dmodel} - partial_init_fn: - _target_: src.projected_compression.model.llm_random_weight_init - _partial_: true - scale: 1 - gate_fn: ${model.encoder.block_fn.ff_layer_fn.ff_pre_act_fn} - head: _target_: src.projected_compression.model.TransformerHead linear_fn: @@ -113,4 +93,4 @@ model: _target_: src.core.model.RMSNorm _partial_: true eps: 1e-5 - normalized_shape: ${common.dmodel} \ No newline at end of file + normalized_shape: ${common.dmodel} diff --git a/configs/_model/llama/base_moe_model.yaml b/configs/_model/llama/base_moe_model.yaml deleted file mode 100644 index e0022941..00000000 --- a/configs/_model/llama/base_moe_model.yaml +++ /dev/null @@ -1,106 +0,0 @@ -common: - _target_: src.definitions.Common - dmodel: ??? - dff: ??? - dhead: ??? - sequence_length: ??? - n_blocks: ??? - q_heads: ??? - kv_heads: ??? - - -model: - _target_: src.projected_compression.model.LLM - - embedding: - _target_: src.projected_compression.model.TransformerEmbedding - vocab_size: 128256 - dmodel: ${common.dmodel} - init_fn: - _target_: src.projected_compression.model.trunc_normal_ - _partial_: true - - encoder: - _target_: src.projected_compression.model.TransformerEncoder - n_blocks: ${common.n_blocks} - block_fn: - _target_: src.projected_compression.model.TransformerBlock - _partial_: true - norm_fn: - _target_: src.core.model.RMSNorm - _partial_: true - eps: 1e-5 - normalized_shape: ${common.dmodel} - - attention_fn: - _target_: src.projected_compression.model.RoPEAttention - _partial_: true - dmodel: ${common.dmodel} - q_heads: ${common.q_heads} - kv_heads: ${common.kv_heads} - seq_len: ${common.sequence_length} - q_proj_fn: - _target_: src.projected_compression.model.Linear - _partial_: true - in_features: ${common.dmodel} - out_features: ${eval:'${common.dhead} * ${model.encoder.block_fn.attention_fn.q_heads}'} - partial_init_fn: - _target_: src.projected_compression.model.llm_random_weight_init - _partial_: true - scale: 1 - - k_proj_fn: - _target_: src.projected_compression.model.Linear - _partial_: true - in_features: ${common.dmodel} - out_features: ${eval:'${common.dhead} * ${model.encoder.block_fn.attention_fn.kv_heads}'} - partial_init_fn: - _target_: src.projected_compression.model.llm_random_weight_init - _partial_: true - scale: 1 - - v_proj_fn: ${model.encoder.block_fn.attention_fn.k_proj_fn} - - o_proj_fn: - _target_: src.projected_compression.model.Linear - _partial_: true - in_features: ${eval:'${common.dhead} * ${model.encoder.block_fn.attention_fn.q_heads}'} - out_features: ${common.dmodel} - partial_init_fn: - _target_: src.projected_compression.model.llm_random_weight_init - _partial_: true - scale: 1 - - rope_base: 500000 - rope_scale_freqs: true - - ff_layer_fn: - _target_: src.core.moe.MoE - _partial_: true - dmodel: ${common.dmodel} - dff: ${common.dff} - num_experts: ??? - num_experts_per_tok: ??? - capacity_factor: 1.25 - moe_load_balancing_loss_factor: 0.0 - moe_router_z_loss_factor: 0.0 - normalize_router_logits: false - activation_function: swiglu - init_scale: 1.0 - - head: - _target_: src.projected_compression.model.TransformerHead - linear_fn: - _target_: src.projected_compression.model.Linear - _partial_: true - in_features: ${common.dmodel} - out_features: ${model.embedding.vocab_size} - partial_init_fn: - _target_: src.projected_compression.model.llm_random_weight_init - _partial_: true - scale: 1 - norm_fn: - _target_: src.core.model.RMSNorm - _partial_: true - eps: 1e-5 - normalized_shape: ${common.dmodel} diff --git a/configs/_model/llama/small_moe.yaml b/configs/_model/llama/small_moe.yaml index 22140255..2a05cd99 100644 --- a/configs/_model/llama/small_moe.yaml +++ b/configs/_model/llama/small_moe.yaml @@ -1,5 +1,6 @@ defaults: - - base_moe_model + - base_model + - override /ff_layer@model.encoder.block_fn.ff_layer_fn: moe - _self_ common: @@ -15,8 +16,10 @@ model: block_fn: ff_layer_fn: num_experts: 16 - num_experts_per_tok: 1 + topk: 1 capacity_factor: 1.25 moe_load_balancing_loss_factor: 0.01 moe_router_z_loss_factor: 0.001 + normalize_router_logits: false activation_function: swiglu + init_scale: 1.0 diff --git a/configs/ff_layer/dense.yaml b/configs/ff_layer/dense.yaml new file mode 100644 index 00000000..8b7f3582 --- /dev/null +++ b/configs/ff_layer/dense.yaml @@ -0,0 +1,21 @@ +_target_: src.projected_compression.model.ProjectedLlamaFeedForward +_partial_: true +ff_pre_act_fn: + _target_: src.projected_compression.model.Linear + _partial_: true + in_features: ${common.dmodel} + out_features: ${common.dff} + partial_init_fn: + _target_: src.projected_compression.model.llm_random_weight_init + _partial_: true + scale: 1 +ff_post_act_fn: + _target_: src.projected_compression.model.Linear + _partial_: true + in_features: ${common.dff} + out_features: ${common.dmodel} + partial_init_fn: + _target_: src.projected_compression.model.llm_random_weight_init + _partial_: true + scale: 1 +gate_fn: ${model.encoder.block_fn.ff_layer_fn.ff_pre_act_fn} diff --git a/configs/ff_layer/moe.yaml b/configs/ff_layer/moe.yaml new file mode 100644 index 00000000..964a55de --- /dev/null +++ b/configs/ff_layer/moe.yaml @@ -0,0 +1,12 @@ +_target_: src.core.moe.MoE +_partial_: true +dmodel: ${common.dmodel} +dff: ${common.dff} +num_experts: ??? +topk: ??? +capacity_factor: ??? +moe_load_balancing_loss_factor: ??? +moe_router_z_loss_factor: ??? +normalize_router_logits: ??? +activation_function: ??? +init_scale: ??? diff --git a/src/core/moe.py b/src/core/moe.py index e129f69f..4405c865 100644 --- a/src/core/moe.py +++ b/src/core/moe.py @@ -21,7 +21,7 @@ def __init__( dmodel: int, dff: int, num_experts: int, - num_experts_per_tok: int, + topk: int, capacity_factor: float = 1.25, moe_load_balancing_loss_factor: float = 0.0, moe_router_z_loss_factor: float = 0.0, @@ -33,21 +33,17 @@ def __init__( if activation_function != "swiglu": raise ValueError(f"MoE supports only swiglu, got {activation_function}.") - if num_experts_per_tok > num_experts: - raise ValueError( - f"num_experts_per_tok={num_experts_per_tok} must be <= num_experts={num_experts}." - ) + if topk > num_experts: + raise ValueError(f"topk={topk} must be <= num_experts={num_experts}.") if capacity_factor <= 0: raise ValueError(f"capacity_factor must be > 0, got {capacity_factor}.") - if normalize_router_logits and num_experts_per_tok == 1: - raise AssertionError( - "normalize_router_logits requires num_experts_per_tok > 1." - ) + if normalize_router_logits and topk == 1: + raise AssertionError("normalize_router_logits requires topk > 1.") self.dmodel = dmodel self.dff = dff self.num_experts = num_experts - self.num_experts_per_tok = num_experts_per_tok + self.topk = topk self.capacity_factor = capacity_factor self.moe_load_balancing_loss_factor = moe_load_balancing_loss_factor self.moe_router_z_loss_factor = moe_router_z_loss_factor @@ -82,14 +78,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: # For each token, keep only the top-k experts and their routing probabilities topk_probs, selected_experts = torch.topk( router_probs, - k=self.num_experts_per_tok, + k=self.topk, dim=-1, ) # Keep only the highest-gated assignments per expert up to its capacity flat_tokens = torch.arange( num_tokens, device=hidden_states.device, dtype=torch.long - ).repeat_interleave(self.num_experts_per_tok) + ).repeat_interleave(self.topk) flat_experts = selected_experts.reshape(-1) flat_weights = topk_probs.reshape(-1) total_assignments = flat_experts.numel()