diff --git a/configs/_model/llama/base_model.yaml b/configs/_model/llama/base_model.yaml index 30d8feaf..b18ce01c 100644 --- a/configs/_model/llama/base_model.yaml +++ b/configs/_model/llama/base_model.yaml @@ -1,3 +1,6 @@ +defaults: + - /ff_layer@model.encoder.block_fn.ff_layer_fn: dense + - _self_ common: _target_: src.definitions.Common @@ -75,29 +78,6 @@ model: rope_base: 500000 rope_scale_freqs: true - ff_layer_fn: - _target_: src.projected_compression.model.ProjectedLlamaFeedForward - _partial_: true - ff_pre_act_fn: - _target_: src.projected_compression.model.Linear - _partial_: true - in_features: ${common.dmodel} - out_features: ${common.dff} - partial_init_fn: - _target_: src.projected_compression.model.llm_random_weight_init - _partial_: true - scale: 1 - ff_post_act_fn: - _target_: src.projected_compression.model.Linear - _partial_: true - in_features: ${common.dff} - out_features: ${common.dmodel} - partial_init_fn: - _target_: src.projected_compression.model.llm_random_weight_init - _partial_: true - scale: 1 - gate_fn: ${model.encoder.block_fn.ff_layer_fn.ff_pre_act_fn} - head: _target_: src.projected_compression.model.TransformerHead linear_fn: @@ -113,4 +93,4 @@ model: _target_: src.core.model.RMSNorm _partial_: true eps: 1e-5 - normalized_shape: ${common.dmodel} \ No newline at end of file + normalized_shape: ${common.dmodel} diff --git a/configs/_model/llama/small.yaml b/configs/_model/llama/small.yaml new file mode 100644 index 00000000..7c50ab66 --- /dev/null +++ b/configs/_model/llama/small.yaml @@ -0,0 +1,11 @@ +defaults: + - base_model + - _self_ + +common: + dmodel: 1024 + dff: 2816 + dhead: 64 + n_blocks: 16 + q_heads: 16 + kv_heads: 16 diff --git a/configs/_model/llama/small_moe.yaml b/configs/_model/llama/small_moe.yaml new file mode 100644 index 00000000..2a05cd99 --- /dev/null +++ b/configs/_model/llama/small_moe.yaml @@ -0,0 +1,25 @@ +defaults: + - base_model + - override /ff_layer@model.encoder.block_fn.ff_layer_fn: moe + - _self_ + +common: + dmodel: 1024 + dff: 2816 + dhead: 64 + n_blocks: 16 + q_heads: 16 + kv_heads: 16 + +model: + encoder: + block_fn: + ff_layer_fn: + num_experts: 16 + topk: 1 + capacity_factor: 1.25 + moe_load_balancing_loss_factor: 0.01 + moe_router_z_loss_factor: 0.001 + normalize_router_logits: false + activation_function: swiglu + init_scale: 1.0 diff --git a/configs/ff_layer/dense.yaml b/configs/ff_layer/dense.yaml new file mode 100644 index 00000000..8b7f3582 --- /dev/null +++ b/configs/ff_layer/dense.yaml @@ -0,0 +1,21 @@ +_target_: src.projected_compression.model.ProjectedLlamaFeedForward +_partial_: true +ff_pre_act_fn: + _target_: src.projected_compression.model.Linear + _partial_: true + in_features: ${common.dmodel} + out_features: ${common.dff} + partial_init_fn: + _target_: src.projected_compression.model.llm_random_weight_init + _partial_: true + scale: 1 +ff_post_act_fn: + _target_: src.projected_compression.model.Linear + _partial_: true + in_features: ${common.dff} + out_features: ${common.dmodel} + partial_init_fn: + _target_: src.projected_compression.model.llm_random_weight_init + _partial_: true + scale: 1 +gate_fn: ${model.encoder.block_fn.ff_layer_fn.ff_pre_act_fn} diff --git a/configs/ff_layer/moe.yaml b/configs/ff_layer/moe.yaml new file mode 100644 index 00000000..964a55de --- /dev/null +++ b/configs/ff_layer/moe.yaml @@ -0,0 +1,12 @@ +_target_: src.core.moe.MoE +_partial_: true +dmodel: ${common.dmodel} +dff: ${common.dff} +num_experts: ??? +topk: ??? +capacity_factor: ??? +moe_load_balancing_loss_factor: ??? +moe_router_z_loss_factor: ??? +normalize_router_logits: ??? +activation_function: ??? +init_scale: ??? diff --git a/configs/moe_example_run.yaml b/configs/moe_example_run.yaml new file mode 100644 index 00000000..d0e187b4 --- /dev/null +++ b/configs/moe_example_run.yaml @@ -0,0 +1,53 @@ +defaults: + - _cluster@_here_: entropy + - _model/llama@_here_: small_moe + - _trainer@_here_: llama + - _dataset@_here_: c4 + - _checkpoints@_here_: none + - _misc@_here_: default + - _eval@_here_: basic + - _self_ + +common: + sequence_length: 1024 + batch_size: 64 + +model: + embedding: + vocab_size: 50257 + +trainer: + gradient_accumulation_steps: 1 + n_steps: 1000 + learning_rate: 5e-4 + + train_dataloader: + dataset: + tokenize_fn: + _target_: src.core.datasets.gpt2_tokenize_fn + + eval_dataloader: + dataset: + tokenize_fn: + _target_: src.core.datasets.gpt2_tokenize_fn + +infrastructure: + max_concurrent_jobs: 1 + + metric_logger: + type: wandb + wandb_entity: ideas_cv + project_name: llm-random-test + name: moe_2gpu + tags: + - nano + - remote + - small + - moe + + slurm: + time: "0-02:00:00" + gres: gpu:2 + job-name: ${infrastructure.metric_logger.name} + +evaluator: null diff --git a/src/core/model.py b/src/core/model.py index 0d20118e..4de083c8 100644 --- a/src/core/model.py +++ b/src/core/model.py @@ -188,8 +188,13 @@ def forward(self, x): class TransformerTower(nn.Module): def get_model_dimensions(self): # Works only for llama3 transforermer architecture - dmodel = self.blocks[0].ff_layer.layer.ff_pre_act.weight.shape[1] - dff = self.blocks[0].ff_layer.layer.ff_pre_act.weight.shape[0] + ff_layer = self.blocks[0].ff_layer.layer + if getattr(ff_layer, "is_moe", False): + dmodel = ff_layer.dmodel + dff = ff_layer.dff + else: + dmodel = ff_layer.ff_pre_act.weight.shape[1] + dff = ff_layer.ff_pre_act.weight.shape[0] datt = self.blocks[0].attention_layer.layer.q_proj.weight.shape[0] n_att_heads = self.blocks[0].attention_layer.layer.q_heads n_kvatt_heads = self.blocks[0].attention_layer.layer.kv_heads diff --git a/src/core/moe.py b/src/core/moe.py new file mode 100644 index 00000000..4405c865 --- /dev/null +++ b/src/core/moe.py @@ -0,0 +1,169 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn.init import trunc_normal_ +import logging +import math + + +logger = logging.getLogger(__name__) + + +@torch.no_grad() +def _truncated_normal_(weight: torch.Tensor, fan_in: int, scale: float) -> None: + std = scale * (1 / fan_in) ** 0.5 + trunc_normal_(weight, mean=0.0, std=std, a=-2 * std, b=2 * std) + + +class MoE(nn.Module): + def __init__( + self, + dmodel: int, + dff: int, + num_experts: int, + topk: int, + capacity_factor: float = 1.25, + moe_load_balancing_loss_factor: float = 0.0, + moe_router_z_loss_factor: float = 0.0, + normalize_router_logits: bool = False, + activation_function: str = "swiglu", + init_scale: float = 1.0, + ): + super().__init__() + + if activation_function != "swiglu": + raise ValueError(f"MoE supports only swiglu, got {activation_function}.") + if topk > num_experts: + raise ValueError(f"topk={topk} must be <= num_experts={num_experts}.") + if capacity_factor <= 0: + raise ValueError(f"capacity_factor must be > 0, got {capacity_factor}.") + if normalize_router_logits and topk == 1: + raise AssertionError("normalize_router_logits requires topk > 1.") + + self.dmodel = dmodel + self.dff = dff + self.num_experts = num_experts + self.topk = topk + self.capacity_factor = capacity_factor + self.moe_load_balancing_loss_factor = moe_load_balancing_loss_factor + self.moe_router_z_loss_factor = moe_router_z_loss_factor + self.normalize_router_logits = normalize_router_logits + self.is_moe = True + self.moe_load_balancing_loss = None + self.router_z_loss = None + + self.router_weight = nn.Parameter(torch.empty(num_experts, dmodel)) + self.ff_pre_act_weight = nn.Parameter(torch.empty(num_experts, dff, dmodel)) + self.gate_weight = nn.Parameter(torch.empty(num_experts, dff, dmodel)) + self.ff_post_act_weight = nn.Parameter(torch.empty(num_experts, dmodel, dff)) + + _truncated_normal_(self.router_weight, dmodel, init_scale) + _truncated_normal_(self.ff_pre_act_weight, dmodel, init_scale) + _truncated_normal_(self.gate_weight, dmodel, init_scale) + _truncated_normal_(self.ff_post_act_weight, dff, init_scale) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + original_shape = x.shape + hidden_states = x.reshape(-1, self.dmodel) + num_tokens = hidden_states.size(0) + + # Router + router_logits = torch.einsum( + "th,eh->te", + hidden_states, + self.router_weight, + ) + router_logits = router_logits.to(dtype=torch.float32) + router_probs = F.softmax(router_logits, dim=-1) + # For each token, keep only the top-k experts and their routing probabilities + topk_probs, selected_experts = torch.topk( + router_probs, + k=self.topk, + dim=-1, + ) + + # Keep only the highest-gated assignments per expert up to its capacity + flat_tokens = torch.arange( + num_tokens, device=hidden_states.device, dtype=torch.long + ).repeat_interleave(self.topk) + flat_experts = selected_experts.reshape(-1) + flat_weights = topk_probs.reshape(-1) + total_assignments = flat_experts.numel() + capacity = max( + 1, + math.ceil(self.capacity_factor * total_assignments / self.num_experts), + ) + weight_order = torch.argsort(flat_weights, descending=True, stable=True) + grouped_order = torch.argsort(flat_experts[weight_order], stable=True) + sort_order = weight_order[grouped_order] + sorted_experts = flat_experts[sort_order] + sorted_tokens = flat_tokens[sort_order] + sorted_weights = flat_weights[sort_order] + expert_counts = sorted_experts.bincount(minlength=self.num_experts) + expert_offsets = expert_counts.cumsum(0) - expert_counts + slot_in_expert = ( + torch.arange(total_assignments, device=hidden_states.device) + - expert_offsets[sorted_experts] + ) + keep = slot_in_expert < capacity + kept_experts = sorted_experts[keep] + kept_tokens = sorted_tokens[keep] + kept_slots = slot_in_expert[keep] + kept_weights = sorted_weights[keep] + if self.normalize_router_logits and kept_weights.numel() > 0: + # Renormalize only the surviving expert weights so each token sums to 1 after capacity pruning. + token_weight_sums = kept_weights.new_zeros(num_tokens) + token_weight_sums.index_add_(0, kept_tokens, kept_weights) + kept_weights = kept_weights / token_weight_sums.index_select(0, kept_tokens) + + # Dispatch the surviving tokens into expert-capacity slots and run the expert MLP batched per expert + flat_capacity = self.num_experts * capacity + dispatch_index = kept_experts * capacity + kept_slots + expert_inputs = hidden_states.new_zeros(flat_capacity, self.dmodel) + expert_inputs.index_copy_(0, dispatch_index, hidden_states[kept_tokens]) + expert_inputs = expert_inputs.view( + self.num_experts, + capacity, + self.dmodel, + ) + ff_pre_act = torch.einsum( + "ech,edh->ecd", + expert_inputs, + self.ff_pre_act_weight, + ) + gate = torch.einsum( + "ech,edh->ecd", + expert_inputs, + self.gate_weight, + ) + expert_outputs = torch.einsum( + "ecd,ehd->ech", + ff_pre_act * F.silu(gate), + self.ff_post_act_weight, + ) + + # Gather only the kept expert outputs back to tokens and sum the top-k contributions + token_updates = expert_outputs.view(flat_capacity, self.dmodel).index_select( + 0, dispatch_index + ) + token_updates = token_updates * kept_weights.to(hidden_states.dtype).unsqueeze( + -1 + ) + output = hidden_states.new_zeros(num_tokens, self.dmodel) + output = output.index_add(0, kept_tokens, token_updates) + output = output.reshape(original_shape) + + # Match the switch-style load-balancing term using pre-capacity routing statistics + if self.training: + expert_frequency = flat_experts.bincount(minlength=self.num_experts) + expert_frequency = expert_frequency.to(router_probs.dtype) + expert_frequency = expert_frequency / expert_frequency.sum().clamp_min(1) + self.moe_load_balancing_loss = ( + self.num_experts * (router_probs.mean(dim=0) * expert_frequency).sum() + ) + self.router_z_loss = torch.logsumexp(router_logits, dim=-1).square().mean() + else: + self.moe_load_balancing_loss = None + self.router_z_loss = None + + return output diff --git a/src/core/trainer.py b/src/core/trainer.py index 18fab093..b8f6ee8d 100644 --- a/src/core/trainer.py +++ b/src/core/trainer.py @@ -25,6 +25,15 @@ logger = logging.getLogger(__name__) +@define(frozen=True) +class LossMetrics: + total_loss: torch.Tensor + reported_loss: torch.Tensor + moe_load_balancing_loss: Optional[torch.Tensor] = None + moe_router_z_loss: Optional[torch.Tensor] = None + distill_loss: Optional[torch.Tensor] = None + + @define(slots=False) class Trainer: model: torch.nn.Module @@ -48,7 +57,9 @@ def __attrs_post_init__(self): self.processed_tokens = self.training_state["processed_tokens"] self.start_step = self.training_state["next_step"] self.device = next(self.model.parameters()).device - self.loss_interval_100 = 0.0 + self._has_moe_modules = any( + getattr(module, "is_moe", False) for module in self.model.modules() + ) if self.eval_dataloader is not None and hasattr( self.eval_dataloader, "__iter__" @@ -65,6 +76,8 @@ def __attrs_post_init__(self): next(self.eval_iterator) self.loss_averaged_100 = AveMetric(100, "100/train/loss") + if self._has_moe_modules: + self.total_loss_averaged_100 = AveMetric(100, "100/train/total_loss") self.time_diff_averaged_100 = AveDiffMetric(100, "100/time", time.time()) @property @@ -104,11 +117,11 @@ def train(self): self.metric_logger.set_step(step) self.metric_logger.set_tokens(self.processed_tokens) self.model.train() - loss = self.calculate_loss(batch) + loss_metrics = self.calculate_loss(batch) grad_norm = self.clip_gradient() - self.log_metrics(loss, grad_norm) + self.log_metrics(loss_metrics, grad_norm) self.optimizer.step() self.optimizer.zero_grad() @@ -154,7 +167,7 @@ def _preprocess_input(self, batch): # TODO test it return input_ids, target_ids - def calculate_loss(self, batch): + def calculate_loss(self, batch) -> LossMetrics: def _hack_for_python_garbage_collection(input_ids, target_ids): """we want to have no reference to model output while backpropagating to allow torch to free memory, so we wrap loss calculation in a function""" @@ -168,27 +181,94 @@ def _hack_for_python_garbage_collection(input_ids, target_ids): target_ids.reshape(-1).long(), reduction="none", ) - loss = mask_loss.mean() / self.gradient_accumulation_steps - return loss + # Keep the reported loss as pure CE so train/loss stays comparable to eval/loss; + # MoE auxiliary terms are optimized and logged separately. + reported_loss = mask_loss.mean() + loss = reported_loss + moe_load_balancing_loss = torch.zeros((), device=predicted_ids.device) + moe_router_z_loss = torch.zeros((), device=predicted_ids.device) + if self.model.training and self._has_moe_modules: + moe_load_balancing_loss = self._calculate_moe_load_balancing_loss( + device=predicted_ids.device, + ) + moe_router_z_loss = self._calculate_moe_router_z_loss( + device=predicted_ids.device, + ) + loss = loss + moe_load_balancing_loss + moe_router_z_loss + reported_loss = reported_loss / self.gradient_accumulation_steps + loss = loss / self.gradient_accumulation_steps + moe_load_balancing_loss = ( + moe_load_balancing_loss / self.gradient_accumulation_steps + ) + moe_router_z_loss = moe_router_z_loss / self.gradient_accumulation_steps + return loss, reported_loss, moe_load_balancing_loss, moe_router_z_loss - losses = [] + total_losses = [] + reported_losses = [] + moe_load_balancing_losses = [] + moe_router_z_losses = [] for batch_chunk in batch.chunk(self.gradient_accumulation_steps): input_ids, target_ids = self._preprocess_input(batch_chunk) input_ids = input_ids.to(self.device) if self.model.training: self._update_processed_tokens(input_ids) - loss = _hack_for_python_garbage_collection(input_ids, target_ids) + loss, reported_loss, moe_load_balancing_loss, moe_router_z_loss = ( + _hack_for_python_garbage_collection(input_ids, target_ids) + ) if self.model.training: loss.backward() - losses.append(loss.item()) + total_losses.append(loss.detach()) + reported_losses.append(reported_loss.detach()) + moe_load_balancing_losses.append(moe_load_balancing_loss.detach()) + moe_router_z_losses.append(moe_router_z_loss.detach()) # gloo backend supports only sum reduce operation, therfore we first divide by world size and then sum - avg_loss = torch.tensor(losses, device=loss.device).sum() + avg_total_loss = torch.stack(total_losses).sum() + avg_reported_loss = torch.stack(reported_losses).sum() + avg_moe_load_balancing_loss = torch.stack(moe_load_balancing_losses).sum() + avg_moe_router_z_loss = torch.stack(moe_router_z_losses).sum() if dist.is_initialized(): - dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM) + dist.all_reduce(avg_total_loss, op=dist.ReduceOp.SUM) + dist.all_reduce(avg_reported_loss, op=dist.ReduceOp.SUM) + dist.all_reduce(avg_moe_load_balancing_loss, op=dist.ReduceOp.SUM) + dist.all_reduce(avg_moe_router_z_loss, op=dist.ReduceOp.SUM) + + world_size = float(os.environ["WORLD_SIZE"]) + return LossMetrics( + total_loss=avg_total_loss / world_size, + reported_loss=avg_reported_loss / world_size, + moe_load_balancing_loss=avg_moe_load_balancing_loss / world_size, + moe_router_z_loss=avg_moe_router_z_loss / world_size, + ) + + def _calculate_weighted_moe_loss( + self, + device, + loss_attr, + factor_attr, + ): + loss = torch.zeros((), device=device) + for module in self.model.modules(): + module_loss = getattr(module, loss_attr, None) + factor = getattr(module, factor_attr, 0.0) + if module_loss is not None and factor: + loss = loss + module_loss.to(device=device) * factor + return loss + + def _calculate_moe_load_balancing_loss(self, device): + return self._calculate_weighted_moe_loss( + device=device, + loss_attr="moe_load_balancing_loss", + factor_attr="moe_load_balancing_loss_factor", + ) - return avg_loss / float(os.environ["WORLD_SIZE"]) + def _calculate_moe_router_z_loss(self, device): + return self._calculate_weighted_moe_loss( + device=device, + loss_attr="router_z_loss", + factor_attr="moe_router_z_loss_factor", + ) def eval(self): self.model.eval() @@ -202,10 +282,10 @@ def eval(self): batch_fingerprint = create_batch_fingerprint(batch) eval_fingerprint.extend(batch_fingerprint) batch = batch.to(self.device) - loss = self.calculate_loss(batch) - losses.append(loss.item()) + loss_metrics = self.calculate_loss(batch) + losses.append(loss_metrics.reported_loss.detach()) self.metric_logger.flush_accumulated_metrics() - avg_loss = torch.tensor(losses).mean() + avg_loss = torch.stack(losses).mean() self.metric_logger.log("eval/loss", avg_loss.item()) if self._should_log_eval_input: @@ -226,13 +306,30 @@ def clip_gradient(self): def _update_processed_tokens(self, batch): self.processed_tokens += batch.numel() * int(os.environ["WORLD_SIZE"]) - def log_metrics(self, loss, grad_norm): + def log_metrics(self, loss_metrics: LossMetrics, grad_norm): self.metric_logger.set_tokens(self.processed_tokens) - self.metric_logger.log("train/loss", loss.item()) + self.metric_logger.log("train/loss", loss_metrics.reported_loss.item()) + if self._has_moe_modules: + self.metric_logger.log("train/total_loss", loss_metrics.total_loss.item()) + self.metric_logger.log( + "train/moe_load_balancing_loss", + loss_metrics.moe_load_balancing_loss.item(), + ) + self.metric_logger.log( + "train/moe_router_z_loss", + loss_metrics.moe_router_z_loss.item(), + ) self.metric_logger.log("train/lr", self.scheduler.get_last_lr()[0]) - self.metric_logger.log("train/grad_norm", grad_norm.item()) + if grad_norm is not None: + self.metric_logger.log("train/grad_norm", grad_norm.item()) - self.loss_averaged_100.log(self.metric_logger, loss.item()) + self.loss_averaged_100.log( + self.metric_logger, loss_metrics.reported_loss.item() + ) + if self._has_moe_modules: + self.total_loss_averaged_100.log( + self.metric_logger, loss_metrics.total_loss.item() + ) self.time_diff_averaged_100.log(self.metric_logger, time.time()) self.metric_logger.flush_accumulated_metrics() diff --git a/src/core/trainer_distillation.py b/src/core/trainer_distillation.py index 08c57f2d..2a8ac5df 100644 --- a/src/core/trainer_distillation.py +++ b/src/core/trainer_distillation.py @@ -7,7 +7,7 @@ import torch.distributed as dist import logging -from src.core.trainer import Trainer +from src.core.trainer import LossMetrics, Trainer from src.core.metric_loggers import AveMetric from src.core.utils import create_batch_fingerprint @@ -83,7 +83,7 @@ def compute_distillation_loss(self, student_logits, teacher_logits): # Scale by temperature^2 to normalize return kl_loss * (self.distillation_temperature**2) - def calculate_loss(self, batch): + def calculate_loss(self, batch) -> LossMetrics: """Override to compute both CE loss and distillation loss""" def _compute_losses(input_ids, target_ids): @@ -122,7 +122,7 @@ def _compute_losses(input_ids, target_ids): return total_loss, ce_loss, distill_loss total_losses = [] - ce_losses = [] + reported_losses = [] distill_losses = [] for batch_chunk in batch.chunk(self.gradient_accumulation_steps): @@ -137,51 +137,50 @@ def _compute_losses(input_ids, target_ids): if self.model.training: total_loss.backward() - total_losses.append(total_loss.item()) - ce_losses.append(ce_loss.item()) - distill_losses.append(distill_loss.item()) + total_losses.append(total_loss.detach()) + reported_losses.append(ce_loss.detach()) + distill_losses.append(distill_loss.detach()) # Average and synchronize across devices - avg_total_loss = torch.tensor(total_losses, device=self.device).sum() - avg_ce_loss = torch.tensor(ce_losses, device=self.device).sum() - avg_distill_loss = torch.tensor(distill_losses, device=self.device).sum() + avg_total_loss = torch.stack(total_losses).sum() + avg_reported_loss = torch.stack(reported_losses).sum() + avg_distill_loss = torch.stack(distill_losses).sum() if dist.is_initialized(): dist.all_reduce(avg_total_loss, op=dist.ReduceOp.SUM) - dist.all_reduce(avg_ce_loss, op=dist.ReduceOp.SUM) + dist.all_reduce(avg_reported_loss, op=dist.ReduceOp.SUM) dist.all_reduce(avg_distill_loss, op=dist.ReduceOp.SUM) world_size = float(os.environ["WORLD_SIZE"]) + return LossMetrics( + total_loss=avg_total_loss / world_size, + reported_loss=avg_reported_loss / world_size, + distill_loss=avg_distill_loss / world_size, + ) - # Store individual losses for logging - self._last_ce_loss = avg_ce_loss / world_size - self._last_distill_loss = avg_distill_loss / world_size - - return avg_total_loss / world_size - - def log_metrics(self, loss, grad_norm): + def log_metrics(self, loss_metrics: LossMetrics, grad_norm): """Override to add distillation-specific metrics""" self.metric_logger.set_tokens(self.processed_tokens) self.metric_logger.log("train/lr", self.scheduler.get_last_lr()[0]) - self.metric_logger.log("train/grad_norm", grad_norm.item()) + if grad_norm is not None: + self.metric_logger.log("train/grad_norm", grad_norm.item()) self.time_diff_averaged_100.log(self.metric_logger, time.time()) - # Add distillation-specific metrics - if hasattr(self, "_last_ce_loss"): - # `loss` is cross entropy loss on language modeling; `total_loss` is training loss! - self.metric_logger.log("train/loss", self._last_ce_loss.item()) - self.metric_logger.log("train/total_loss", loss.item()) - self.metric_logger.log("train/distill_loss", self._last_distill_loss.item()) - - self.loss_averaged_100.log(self.metric_logger, self._last_ce_loss.item()) - self.total_loss_averaged_100.log(self.metric_logger, loss.item()) - self.distill_loss_averaged_100.log( - self.metric_logger, self._last_distill_loss.item() - ) - else: - self.metric_logger.log("train/loss", loss.item()) - self.loss_averaged_100.log(self.metric_logger, loss.item()) + # `reported_loss` is cross entropy loss on language modeling; `total_loss` is training loss. + self.metric_logger.log("train/loss", loss_metrics.reported_loss.item()) + self.metric_logger.log("train/total_loss", loss_metrics.total_loss.item()) + self.metric_logger.log("train/distill_loss", loss_metrics.distill_loss.item()) + + self.loss_averaged_100.log( + self.metric_logger, loss_metrics.reported_loss.item() + ) + self.total_loss_averaged_100.log( + self.metric_logger, loss_metrics.total_loss.item() + ) + self.distill_loss_averaged_100.log( + self.metric_logger, loss_metrics.distill_loss.item() + ) self.metric_logger.flush_accumulated_metrics() @@ -203,26 +202,20 @@ def eval(self): eval_fingerprint.extend(batch_fingerprint) batch = batch.to(self.device) - loss = self.calculate_loss(batch) - losses.append(loss.item()) - - if hasattr(self, "_last_ce_loss"): - ce_losses.append(self._last_ce_loss.item()) - distill_losses.append(self._last_distill_loss.item()) + loss_metrics = self.calculate_loss(batch) + losses.append(loss_metrics.total_loss.detach()) + ce_losses.append(loss_metrics.reported_loss.detach()) + distill_losses.append(loss_metrics.distill_loss.detach()) self.metric_logger.flush_accumulated_metrics() - avg_loss = torch.tensor(losses).mean() - - if ce_losses: - avg_ce_loss = torch.tensor(ce_losses).mean() - avg_distill_loss = torch.tensor(distill_losses).mean() - # `loss` is cross entropy loss on language modeling; `total_loss` is training loss! - self.metric_logger.log("eval/loss", avg_ce_loss.item()) - self.metric_logger.log("eval/distill_loss", avg_distill_loss.item()) - self.metric_logger.log("eval/total_loss", avg_loss.item()) - else: - self.metric_logger.log("eval/loss", avg_loss.item()) + avg_loss = torch.stack(losses).mean() + avg_ce_loss = torch.stack(ce_losses).mean() + avg_distill_loss = torch.stack(distill_losses).mean() + # `reported_loss` is cross entropy loss on language modeling; `total_loss` is training loss. + self.metric_logger.log("eval/loss", avg_ce_loss.item()) + self.metric_logger.log("eval/distill_loss", avg_distill_loss.item()) + self.metric_logger.log("eval/total_loss", avg_loss.item()) if self._should_log_eval_input: self.metric_logger.log("eval/batch", str(eval_fingerprint)) diff --git a/src/product_keys/trainer.py b/src/product_keys/trainer.py index 83b0e106..1a7e441f 100644 --- a/src/product_keys/trainer.py +++ b/src/product_keys/trainer.py @@ -5,7 +5,7 @@ import torch.distributed as dist from typing import List -from src.core.trainer import Trainer +from src.core.trainer import LossMetrics, Trainer @define(slots=False) @@ -45,7 +45,7 @@ def _preprocess_and_mask_input(self, batch): return input_ids, labels - def calculate_loss(self, batch): + def calculate_loss(self, batch) -> LossMetrics: """ Calculates the MLM loss. The loss is calculated only for the masked tokens. @@ -78,11 +78,12 @@ def _mlm_loss_calculation(input_ids, target_ids): loss = _mlm_loss_calculation(input_ids, target_ids) if self.model.training: loss.backward() - losses.append(loss.item()) + losses.append(loss.detach()) # gloo backend supports only sum reduce operation, therfore we first divide by world size and then sum - avg_loss = torch.tensor(losses, device=loss.device).sum() + avg_loss = torch.stack(losses).sum() if dist.is_initialized(): dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM) - return avg_loss / float(os.environ["WORLD_SIZE"]) + avg_loss = avg_loss / float(os.environ["WORLD_SIZE"]) + return LossMetrics(total_loss=avg_loss, reported_loss=avg_loss) diff --git a/src/projected_compression/trainer.py b/src/projected_compression/trainer.py index 16e54842..3601b631 100644 --- a/src/projected_compression/trainer.py +++ b/src/projected_compression/trainer.py @@ -33,7 +33,7 @@ def train(self): self.model.train() self.model.prepare_compressed_weights() - loss = self.calculate_loss(batch) + loss_metrics = self.calculate_loss(batch) if self.only_compress_model_gradient_clipping: grad_norm = torch.nn.utils.clip_grad_norm_( @@ -56,7 +56,7 @@ def train(self): self.model.parameters(), self.gradient_clipping, grad_norm ) - self.log_metrics(loss, grad_norm) + self.log_metrics(loss_metrics, grad_norm) self.optimizer.step() self.optimizer.zero_grad() self.scheduler.step() diff --git a/src/projected_compression/trainer_distillation.py b/src/projected_compression/trainer_distillation.py index 2d62c448..00aaee7b 100644 --- a/src/projected_compression/trainer_distillation.py +++ b/src/projected_compression/trainer_distillation.py @@ -35,7 +35,7 @@ def train(self): self.model.train() self.model.prepare_compressed_weights() - loss = self.calculate_loss(batch) + loss_metrics = self.calculate_loss(batch) if self.only_compress_model_gradient_clipping: grad_norm = torch.nn.utils.clip_grad_norm_( @@ -58,7 +58,7 @@ def train(self): self.model.parameters(), self.gradient_clipping, grad_norm ) - self.log_metrics(loss, grad_norm) + self.log_metrics(loss_metrics, grad_norm) self.optimizer.step() self.optimizer.zero_grad() self.scheduler.step()