diff --git a/configs/product_keys/baseline.yaml b/configs/product_keys/baseline.yaml new file mode 100644 index 00000000..27cc1919 --- /dev/null +++ b/configs/product_keys/baseline.yaml @@ -0,0 +1,169 @@ +# @package _global_ +defaults: + - /_cluster/helios@_here_ + - /_model/tiny@_here_ + - /_trainer/llama@_here_ + - /_dataset/c4@_here_ + - /_checkpoints/none@_here_ + - /_misc/default@_here_ + - _self_ + +common: + sequence_length: 1024 + batch_size: 64 + dmodel: 1024 + dff: 2724 + datt: ${common.dmodel} + n_blocks: 16 + q_heads: 16 + kv_heads: 16 + # vocab_size: 50304 + vocab_size: 128256 + +# trainer: +# _target_: src.product_keys.trainer.MaskedLMTrainer +# masking_percentage: 0.2 +# mask_token_id: 50257 +# unmaskable_special_tokens: [50256, 50257] # <|endoftext|> +# gradient_accumulation_steps: 2 +# n_steps: 77050 +# # ^learning_rate: [1e-4, 2e-4, 5e-4, 1e-3, 2e-3, 5e-3] +# learning_rate: 1e-3 +# train_dataloader: +# dataset: +# tokenize_fn: +# _target_: src.product_keys.datasets.gpt2_mask_tokenize_fn +# eval_dataloader: +# dataset: +# tokenize_fn: +# _target_: src.product_keys.datasets.gpt2_mask_tokenize_fn + +trainer: + n_steps: 77050 + ^learning_rate: [5e-4, 1e-3, 2e-3] + # learning_rate: 1e-3 + +infrastructure: + metric_logger: + type: wandb + wandb_entity: ideas_cv + name: baseline_causal + project_name: tml-bgw + tags: + - baseline + - clm + - "seq_len=${common.sequence_length}" + - "n_layers=${common.n_blocks}" + - "dmodel=${common.dmodel}" + slurm: + gres: gpu:2 + time: "1-00:00:00" + job-name: ${infrastructure.metric_logger.name} + +model: + _target_: src.projected_compression.model.LLM + + embedding: + _target_: src.projected_compression.model.TransformerEmbedding + vocab_size: ${common.vocab_size} + dmodel: ${common.dmodel} + init_fn: + _target_: src.projected_compression.model.trunc_normal_ + _partial_: true + + encoder: + _target_: src.projected_compression.model.TransformerEncoder + n_blocks: ${common.n_blocks} + block_fn: + _target_: src.projected_compression.model.TransformerBlock + _partial_: true + norm_fn: + _target_: src.core.model.RMSNorm + _partial_: true + eps: 1e-5 + normalized_shape: ${common.dmodel} + + attention_fn: + _target_: src.projected_compression.model.RoPEAttention + _partial_: true + dmodel: ${common.dmodel} + q_heads: ${common.q_heads} + kv_heads: ${common.kv_heads} + seq_len: ${common.sequence_length} + causal: true + + q_proj_fn: + _target_: src.projected_compression.model.Linear + _partial_: true + in_features: ${common.dmodel} + out_features: ${common.datt} + partial_init_fn: + _target_: src.projected_compression.model.llm_random_weight_init + _partial_: true + scale: 1 + + k_proj_fn: + _target_: src.projected_compression.model.Linear + _partial_: true + in_features: ${common.dmodel} + out_features: ${eval:'(${common.datt} // ${model.encoder.block_fn.attention_fn.q_heads}) * ${model.encoder.block_fn.attention_fn.kv_heads}'} + partial_init_fn: + _target_: src.projected_compression.model.llm_random_weight_init + _partial_: true + scale: 1 + + v_proj_fn: ${model.encoder.block_fn.attention_fn.k_proj_fn} + + # o_proj_fn: ${model.encoder.block_fn.attention_fn.q_proj_fn} # TODO check have I done it right pls + o_proj_fn: + _target_: src.projected_compression.model.Linear + _partial_: true + in_features: ${common.datt} + out_features: ${common.dmodel} + partial_init_fn: + _target_: src.projected_compression.model.llm_random_weight_init + _partial_: true + scale: 1 + + rope_base: 500000 + rope_scale_freqs: true + + ff_layer_fn: + _target_: src.projected_compression.model.ProjectedLlamaFeedForward + _partial_: true + ff_pre_act_fn: + _target_: src.projected_compression.model.Linear + _partial_: true + in_features: ${common.dmodel} + out_features: ${common.dff} + partial_init_fn: + _target_: src.projected_compression.model.llm_random_weight_init + _partial_: true + scale: 1 + ff_post_act_fn: + _target_: src.projected_compression.model.Linear + _partial_: true + in_features: ${common.dff} + out_features: ${common.dmodel} + partial_init_fn: + _target_: src.projected_compression.model.llm_random_weight_init + _partial_: true + scale: 1 + gate_fn: ${model.encoder.block_fn.ff_layer_fn.ff_pre_act_fn} + + head: + _target_: src.projected_compression.model.TransformerHead + linear_fn: + _target_: src.projected_compression.model.Linear + _partial_: true + in_features: ${common.dmodel} + out_features: ${common.vocab_size} + partial_init_fn: + _target_: src.projected_compression.model.llm_random_weight_init + _partial_: true + scale: 1 + norm_fn: + _target_: src.core.model.RMSNorm + _partial_: true + eps: 1e-5 + normalized_shape: ${common.dmodel} \ No newline at end of file diff --git a/configs/product_keys/pk_mlm.yaml b/configs/product_keys/pk_mlm.yaml index 0d636623..aaa3a416 100644 --- a/configs/product_keys/pk_mlm.yaml +++ b/configs/product_keys/pk_mlm.yaml @@ -1,6 +1,6 @@ # @package _global_ defaults: - - /_cluster/entropy@_here_ + - /_cluster/helios@_here_ - /_model/tiny@_here_ - /_trainer/llama@_here_ - /_dataset/c4@_here_ @@ -26,28 +26,31 @@ trainer: unmaskable_special_tokens: [50256, 50257] # <|endoftext|> gradient_accumulation_steps: 2 n_steps: 77050 - learning_rate: 5e-4 + ^learning_rate: [1e-4, 2e-4, 5e-4, 1e-3, 2e-3, 5e-3] train_dataloader: - tokenize_fn: - _target_: src.product_keys.datasets.gpt2_mask_tokenize_fn - - + dataset: + tokenize_fn: + _target_: src.product_keys.datasets.gpt2_mask_tokenize_fn eval_dataloader: - tokenize_fn: - _target_: src.product_keys.datasets.gpt2_mask_tokenize_fn + dataset: + tokenize_fn: + _target_: src.product_keys.datasets.gpt2_mask_tokenize_fn infrastructure: metric_logger: + type: wandb + wandb_entity: ideas_cv name: pk_mlm - project_name: pmtest/tml-bgw + project_name: tml-bgw tags: - nano - pk_mlm - "seq_len=${common.sequence_length}" - "n_layers=${common.n_blocks}" + - "dmodel=${common.dmodel}" slurm: - gres: gpu:1 + gres: gpu:2 time: "1-00:00:00" job-name: ${infrastructure.metric_logger.name} @@ -61,6 +64,7 @@ model: q_heads: ${common.q_heads} kv_heads: ${common.kv_heads} seq_len: ${common.sequence_length} + causal: false q_proj_fn: _target_: src.projected_compression.model.Linear diff --git a/configs/product_keys/pkm.yaml b/configs/product_keys/pkm.yaml new file mode 100644 index 00000000..e4b80c01 --- /dev/null +++ b/configs/product_keys/pkm.yaml @@ -0,0 +1,207 @@ +# @package _global_ +defaults: + - /_cluster/entropy@_here_ + - /_model/tiny@_here_ + - /_trainer/llama@_here_ + - /_dataset/c4@_here_ + - /_checkpoints/none@_here_ + - /_misc/default@_here_ + - _self_ + +common: + sequence_length: 1024 + batch_size: 64 + dmodel: 1024 + dff: 2724 + datt: ${common.dmodel} + n_blocks: 16 + q_heads: 16 + kv_heads: 16 + # vocab_size: 50304 # GPT-2 vocab + vocab_size: 128256 + + # pkm_n_sub_keys: 128 # 128^2 = 16,384 memory slots + pkm_n_sub_keys: 256 # 256^2 = 65,536 memory slots + # pkm_n_sub_keys: 384 # 384^2 = 147,456 memory slots + # pkm_n_sub_keys: 512 # 512^2 = 262,144 memory slots + # pkm_n_sub_keys: 768 # 768^2 = 589,824 memory slots + # pkm_n_sub_keys: 1024 # 1024^2 = 1,048,576 memory slots + pkm_k: 32 + # pkm_query_dim: 512 + pkm_query_dim: 1024 + pkm_n_heads: 4 + pkm_indices: [7, 14] + +# trainer: +# _target_: src.product_keys.trainer.MaskedLMTrainer +# masking_percentage: 0.2 +# mask_token_id: 50257 +# unmaskable_special_tokens: [50256, 50257] # <|endoftext|> +# gradient_accumulation_steps: 2 +# n_steps: 77050 +# learning_rate: 1e-3 +# ^learning_rate: [5e-4, 1e-3, 2e-3] +# optimizer_param_groups: +# - regex: ".*ff_layer.layer.values.*" +# lr: "${eval:'1.0 * ${trainer.learning_rate}'}" +# indices_filter: +# layer_path: "encoder.blocks" +# indices: ${common.pkm_indices} +# train_dataloader: +# dataset: +# tokenize_fn: +# _target_: src.product_keys.datasets.gpt2_mask_tokenize_fn +# eval_dataloader: +# dataset: +# tokenize_fn: +# _target_: src.product_keys.datasets.gpt2_mask_tokenize_fn + +trainer: + n_steps: 77050 + learning_rate: 1e-3 + ^learning_rate: [5e-4, 1e-3, 2e-3] + optimizer_param_groups: + - regex: ".*ff_layer.layer.values.*" + lr: "${eval:'1.0 * ${trainer.learning_rate}'}" + indices_filter: + layer_path: "encoder.blocks" + indices: ${common.pkm_indices} + +infrastructure: + metric_logger: + type: wandb + wandb_entity: ideas_cv + name: pkm_causal_256_query_dim_1024 + project_name: tml-bgw + tags: + - pkm + - clm + - "pkm_n_sub_keys=${common.pkm_n_sub_keys}" + - "pkm_k=${common.pkm_k}" + - "pkm_indices=${model.encoder.block_fn.pkm_indices}" + - "seq_len=${common.sequence_length}" + - "n_layers=${common.n_blocks}" + - "dmodel=${common.dmodel}" + slurm: + gres: gpu:2 + time: "1-00:00:00" + job-name: ${infrastructure.metric_logger.name} + +model: + _target_: src.projected_compression.model.LLM + + embedding: + _target_: src.projected_compression.model.TransformerEmbedding + vocab_size: ${common.vocab_size} + dmodel: ${common.dmodel} + init_fn: + _target_: src.projected_compression.model.trunc_normal_ + _partial_: true + + encoder: + _target_: src.projected_compression.model.TransformerEncoder + n_blocks: ${common.n_blocks} + block_fn: + _target_: src.product_keys.model.HybridTransformerBlock + _partial_: true + pkm_indices: ${common.pkm_indices} + + norm_fn: + _target_: src.core.model.RMSNorm + _partial_: true + eps: 1e-5 + normalized_shape: ${common.dmodel} + + attention_fn: + _target_: src.projected_compression.model.RoPEAttention + _partial_: true + dmodel: ${common.dmodel} + q_heads: ${common.q_heads} + kv_heads: ${common.kv_heads} + seq_len: ${common.sequence_length} + causal: true + + q_proj_fn: + _target_: src.projected_compression.model.Linear + _partial_: true + in_features: ${common.dmodel} + out_features: ${common.datt} + partial_init_fn: + _target_: src.projected_compression.model.llm_random_weight_init + _partial_: true + scale: 1 + + k_proj_fn: + _target_: src.projected_compression.model.Linear + _partial_: true + in_features: ${common.dmodel} + out_features: ${eval:'(${common.datt} // ${model.encoder.block_fn.attention_fn.q_heads}) * ${model.encoder.block_fn.attention_fn.kv_heads}'} + partial_init_fn: + _target_: src.projected_compression.model.llm_random_weight_init + _partial_: true + scale: 1 + + v_proj_fn: ${model.encoder.block_fn.attention_fn.k_proj_fn} + + # o_proj_fn: ${model.encoder.block_fn.attention_fn.q_proj_fn} # TODO check have I done it right pls + o_proj_fn: + _target_: src.projected_compression.model.Linear + _partial_: true + in_features: ${common.datt} + out_features: ${common.dmodel} + partial_init_fn: + _target_: src.projected_compression.model.llm_random_weight_init + _partial_: true + scale: 1 + + rope_base: 500000 + rope_scale_freqs: true + + ff_layer_fn: + _target_: src.projected_compression.model.ProjectedLlamaFeedForward + _partial_: true + ff_pre_act_fn: + _target_: src.projected_compression.model.Linear + _partial_: true + in_features: ${common.dmodel} + out_features: ${common.dff} + partial_init_fn: + _target_: src.projected_compression.model.llm_random_weight_init + _partial_: true + scale: 1 + ff_post_act_fn: + _target_: src.projected_compression.model.Linear + _partial_: true + in_features: ${common.dff} + out_features: ${common.dmodel} + partial_init_fn: + _target_: src.projected_compression.model.llm_random_weight_init + _partial_: true + scale: 1 + gate_fn: ${model.encoder.block_fn.ff_layer_fn.ff_pre_act_fn} + + pkm_layer_fn: + _target_: src.product_keys.model.ProductKeysMemory + _partial_: true + d_model: ${common.dmodel} + query_dim: ${common.pkm_query_dim} + n_sub_keys: ${common.pkm_n_sub_keys} + k_neighbors: ${common.pkm_k} + n_heads: ${common.pkm_n_heads} + + head: + _target_: src.projected_compression.model.TransformerHead + linear_fn: + _target_: src.projected_compression.model.Linear + _partial_: true + in_features: ${common.dmodel} + out_features: ${common.vocab_size} + partial_init_fn: + _target_: src.projected_compression.model.llm_random_weight_init + _partial_: true + scale: 1 + norm_fn: + _target_: src.core.model.RMSNorm + _partial_: true + eps: 1e-5 + normalized_shape: ${common.dmodel} diff --git a/configs/product_keys/top_k_attention.yaml b/configs/product_keys/top_k_attention.yaml index 99ad5758..16f52d51 100644 --- a/configs/product_keys/top_k_attention.yaml +++ b/configs/product_keys/top_k_attention.yaml @@ -1,6 +1,6 @@ # @package _global_ defaults: - - /_cluster/entropy@_here_ + - /_cluster/helios@_here_ - /_model/tiny@_here_ - /_trainer/llama@_here_ - /_dataset/c4@_here_ @@ -9,37 +9,47 @@ defaults: - _self_ common: - sequence_length: 2048 - batch_size: 32 - dmodel: 768 - dff: 2042 + sequence_length: 1024 + batch_size: 64 + dmodel: 1024 + dff: 2724 datt: ${common.dmodel} - n_blocks: 12 - q_heads: 12 - kv_heads: 12 - vocab_size: 128256 + n_blocks: 16 + q_heads: 16 + kv_heads: 16 + vocab_size: 50304 trainer: + _target_: src.product_keys.trainer.MaskedLMTrainer + masking_percentage: 0.2 + mask_token_id: 50257 + unmaskable_special_tokens: [50256, 50257] # <|endoftext|> gradient_accumulation_steps: 2 - n_steps: 56000 - learning_rate: 1e-3 - - checkpoint: - save: - type: huggingface - path: checkpoint + n_steps: 77050 + ^learning_rate: [1e-4, 2e-4, 5e-4, 1e-3, 2e-3, 5e-3] + train_dataloader: + dataset: + tokenize_fn: + _target_: src.product_keys.datasets.gpt2_mask_tokenize_fn + eval_dataloader: + dataset: + tokenize_fn: + _target_: src.product_keys.datasets.gpt2_mask_tokenize_fn infrastructure: metric_logger: + type: wandb + wandb_entity: ideas_cv name: top_k_attention - project_name: pmtest/tml-bgw + project_name: tml-bgw tags: - nano - top_k_attention - - "lr=${trainer.learning_rate}" - "seq_len=${common.sequence_length}" + - "n_layers=${common.n_blocks}" + - "dmodel=${common.dmodel}" slurm: - gres: gpu:1 + gres: gpu:2 time: "1-00:00:00" job-name: ${infrastructure.metric_logger.name} @@ -57,6 +67,7 @@ model: q_heads: ${common.q_heads} kv_heads: ${common.kv_heads} seq_len: ${common.sequence_length} + causal: false q_proj_fn: _target_: src.projected_compression.model.Linear diff --git a/main.py b/main.py index be7213a9..065c97b0 100644 --- a/main.py +++ b/main.py @@ -1,3 +1,4 @@ +import re import os import hydra import yaml @@ -197,9 +198,86 @@ def get_model_optimizer_scheduler(cfg, model, learning_rate): logger.info("Initialization failed, exiting...") return None, None, None model = setup_distributed_training(model, cfg.trainer.distributed) + + optimizer_groups = [] + # If optimizer_param_groups is defined in config, use generic regex-based grouping + if ( + hasattr(cfg.trainer, "optimizer_param_groups") + and cfg.trainer.optimizer_param_groups + ): + assigned_param_ids = set() + logger.info( + f"Model parameters: {[name for name, _ in model.named_parameters()]}" + ) + + for group_cfg in cfg.trainer.optimizer_param_groups: + group_regex = group_cfg.regex + group_lr = group_cfg.get("lr", learning_rate) + group_params = [] + group_matches = [] + + # Determine filter criteria if provided + target_indices = None + layer_path_prefix = "" + if "indices_filter" in group_cfg and group_cfg.indices_filter: + raw_indices = group_cfg.indices_filter.get("indices", []) + if OmegaConf.is_list(raw_indices) or isinstance( + raw_indices, (list, tuple) + ): + target_indices = set(raw_indices) + layer_path_prefix = group_cfg.indices_filter.get("layer_path", "") + + for name, param in model.named_parameters(): + if id(param) in assigned_param_ids: + continue + + if re.search(group_regex, name): + # Apply indices filter logic if configured + if target_indices is not None and layer_path_prefix: + # Assumption: The layer index is the number immediately following the prefix in the parameter name + # Structure: {layer_path_prefix}.{INDEX}.{...} + prefix_esc = re.escape(layer_path_prefix) + # We use search to find prefix.INDEX. anywhere in the name + match_idx = re.search(rf"{prefix_esc}\.(\d+)\.", name) + + if not match_idx: + continue + + layer_idx = int(match_idx.group(1)) + if layer_idx not in target_indices: + continue + + group_params.append(param) + assigned_param_ids.add(id(param)) + group_matches.append(name) + + if group_params: + logger.info( + f"Optimizer group regex='{group_regex}' lr={group_lr} matched {len(group_params)} params: {group_matches}" + ) + optimizer_groups.append({"params": group_params, "lr": group_lr}) + else: + logger.warning( + f"Optimizer group regex='{group_regex}' matched no parameters." + ) + + # Add remaining parameters to the default group + default_params = [] + for name, param in model.named_parameters(): + if id(param) not in assigned_param_ids: + default_params.append(param) + + if default_params: + optimizer_groups.append({"params": default_params, "lr": learning_rate}) + logger.info( + f"Default optimizer group lr={learning_rate} contains {len(default_params)} remaining params." + ) + + else: + optimizer_groups = [{"params": model.parameters(), "lr": learning_rate}] + optimizer = torch.optim.AdamW( - model.parameters(), - lr=learning_rate, + optimizer_groups, weight_decay=cfg.trainer.weight_decay, ) scheduler = instantiate(cfg.trainer.scheduler)( @@ -328,7 +406,10 @@ def run(cfg: OmegaConf, metric_logger=None): if model is not None: logger.info(f"Model initialized") - trainer = instantiate(cfg.trainer) + # exclude optimizer_param_groups from trainer kwargs to avoid error + trainer_args = OmegaConf.to_container(cfg.trainer, resolve=True) + trainer_args.pop("optimizer_param_groups", None) + trainer = instantiate(trainer_args) if "distillation" in cfg: if cfg.distillation.load.type == "huggingface": diff --git a/src/core/schedulers.py b/src/core/schedulers.py index d50f2609..75ab85b0 100644 --- a/src/core/schedulers.py +++ b/src/core/schedulers.py @@ -1,3 +1,4 @@ +import math import torch from torch.optim.lr_scheduler import SequentialLR, LinearLR, ConstantLR @@ -61,27 +62,36 @@ def load_state_dict(self, loaded_state): def get_cosine_scheduler_with_warmup( optimizer, warmup_steps: int, n_steps: int, final_lr_fraction: float ): - assert ( - len(optimizer.param_groups) == 1 - ), "Cosine scheduler only supports one param group" - optimizer_lr = optimizer.param_groups[0][ - "lr" - ] # param_groups changes when applying scheduler warmup = torch.optim.lr_scheduler.LinearLR( optimizer, start_factor=0.1, end_factor=1.0, total_iters=warmup_steps, ) - after_warmup_steps = n_steps - warmup_steps - 1 + # Cosine scheduler phase starts after warmup and 1 constant step + cosine_start_step = warmup_steps + 1 + T_max = n_steps - cosine_start_step + + def cosine_lambda(step): + # Calculate progress t within the cosine phase + if step < cosine_start_step: + return 1.0 + t = step - cosine_start_step + if t >= T_max: + return final_lr_fraction + # Decay from 1.0 to final_lr_fraction + return final_lr_fraction + 0.5 * (1 - final_lr_fraction) * ( + 1 + math.cos(math.pi * t / T_max) + ) + constant_scheduler = torch.optim.lr_scheduler.ConstantLR( - optimizer, factor=1.0 - ) # TODO this is only because of a bug in llm-random - cosine_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( - optimizer, - T_max=after_warmup_steps, - eta_min=final_lr_fraction * optimizer_lr, + optimizer, factor=1.0, total_iters=1 + ) + # Use LambdaLR to allow different base/min LRs per parameter group (proportional decay) + cosine_scheduler = torch.optim.lr_scheduler.LambdaLR( + optimizer, lr_lambda=cosine_lambda ) + training_scheduler = torch.optim.lr_scheduler.SequentialLR( optimizer, schedulers=[warmup, constant_scheduler, cosine_scheduler], diff --git a/src/product_keys/model.py b/src/product_keys/model.py index 56a8432c..cd373f6e 100644 --- a/src/product_keys/model.py +++ b/src/product_keys/model.py @@ -6,7 +6,7 @@ from torch.nn.init import trunc_normal_ import logging -from src.core.model import AttentionMechanism, RoPE +from src.core.model import AttentionMechanism, Residual, RoPE logger = logging.getLogger(__name__) @@ -20,6 +20,44 @@ def deterministic_weight_init(fan_in, scale): return partial(trunc_normal_, mean=0.0, std=std, a=low, b=high, generator=generator) +class HybridTransformerBlock(nn.Module): + def __init__( + self, + block_id: int, + norm_fn, + attention_fn, + ff_layer_fn, + pkm_layer_fn, + pkm_indices: list[int], + ): + super().__init__() + self.log_name = f"block[{block_id}]" + + self.attention_layer = Residual( + norm=norm_fn(), + layer=attention_fn(), + log_name=f"{self.log_name}/residual_attention", + ) + + if block_id in pkm_indices: + selected_layer = pkm_layer_fn() + layer_type_name = "pkm" + else: + selected_layer = ff_layer_fn() + layer_type_name = "feedforward" + + self.ff_layer = Residual( + norm=norm_fn(), + layer=selected_layer, + log_name=f"{self.log_name}/residual_{layer_type_name}", + ) + + def forward(self, x): + x = self.attention_layer(x) + x = self.ff_layer(x) + return x + + class RoPETopKAttention(nn.Module): def __init__( self, @@ -35,6 +73,7 @@ def __init__( rope_scale_freqs: bool, top_k: int, top_k_before_softmax: bool = True, + causal: bool = True, ): super().__init__() self.q_proj = q_proj_fn() @@ -51,6 +90,8 @@ def __init__( self.top_k = top_k self.top_k_before_softmax = top_k_before_softmax + self.causal = causal + self.rope = RoPE( dhead=self.dhead, length=seq_len, @@ -85,17 +126,19 @@ def forward(self, x): # standard attention if seq_len is smaller or equal top_k if seq_len <= self.top_k: attention_output = self.attention_mechanism( - query=q, key=k, value=v, causal=True + query=q, key=k, value=v, causal=self.causal ) return self.o_proj( attention_output.transpose(1, 2).contiguous().flatten(-2) ) attention_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.dhead) - causal_mask = torch.triu( - torch.ones(seq_len, seq_len, device=attention_scores.device), diagonal=1 - ).bool() - attention_scores = attention_scores.masked_fill(causal_mask, float("-inf")) + + if self.causal: + causal_mask = torch.triu( + torch.ones(seq_len, seq_len, device=attention_scores.device), diagonal=1 + ).bool() + attention_scores = attention_scores.masked_fill(causal_mask, float("-inf")) if self.top_k_before_softmax: attention_scores = self.__apply_topk_mask( @@ -223,6 +266,8 @@ def forward(self, x): # Calculate similarity between full Q and the reconstructed candidates # q needs unsqueeze to broadcast: (B, H, S, 1, D) @ (B, H, S, K*K, D).T + # TODO + # ! we've calculated both halves separately, we can reuse that scores_final = (q.unsqueeze(-2) * candidates).sum(dim=-1) # Select top K closest combinations @@ -260,3 +305,123 @@ def forward(self, x): attn_output = attn_output.squeeze(-2) return self.o_proj(attn_output.transpose(1, 2).contiguous().flatten(-2)) + + +class ProductKeysMemory(nn.Module): + def __init__( + self, + d_model: int, + query_dim: int, + n_sub_keys: int, + k_neighbors: int, + n_heads: int = 4, + ): + super().__init__() + self.n_heads = n_heads + self.k = k_neighbors + self.n_sub_keys = n_sub_keys + self.query_dim = query_dim + + # Query Network + # Projects input to query space. BatchNorm is crucial for PKM stability/convergence. + self.query_proj = nn.Linear(d_model, n_heads * query_dim) + self.query_bn = nn.BatchNorm1d(n_heads * query_dim) + + # Sub-Keys (Codebooks) + # Two separate sets of keys for the product quantization + self.c1 = nn.Parameter(torch.empty(n_heads, n_sub_keys, query_dim // 2)) + self.c2 = nn.Parameter(torch.empty(n_heads, n_sub_keys, query_dim // 2)) + nn.init.normal_(self.c1, mean=0, std=d_model**-0.5) + nn.init.normal_(self.c2, mean=0, std=d_model**-0.5) + + # Memory Values + # The actual values retrieved. Size is (n_sub_keys^2, d_model) + self.values = nn.Embedding(n_sub_keys * n_sub_keys, d_model) + nn.init.normal_(self.values.weight, mean=0, std=d_model**-0.5) + + def _get_knn(self, queries, codebooks): + """ + Calculates dot product scores and retrieves top-k indices and values. + """ + # queries: (batch, head, sub_dim) + # codebooks: (head, n_sub_keys, sub_dim) + + # Calculate similarity (dot product) + scores = torch.einsum("bhd,hnd->bhn", queries, codebooks) + + # Select top-k + top_scores, top_indices = torch.topk(scores, k=self.k, dim=-1, largest=True) + return top_scores, top_indices + + def forward(self, x): + bs, seq_len, d_model = x.shape + + # 1. Query Projection + x_flat = x.view(-1, d_model) + q = self.query_proj(x_flat) + q = self.query_bn(q) + q = q.view(bs * seq_len, self.n_heads, self.query_dim) + + # Split query into two halves for product quantization + q1, q2 = torch.chunk(q, 2, dim=-1) + + # 2. Retrieve Top-K candidates for each half + scores1, idx1 = self._get_knn(q1, self.c1) + scores2, idx2 = self._get_knn(q2, self.c2) + + # 3. Cartesian Product of Scores + # Sum every score from the first half with every score from the second half + # (BS*Seq, H, K, 1) + (BS*Seq, H, 1, K) -> (BS*Seq, H, K, K) + all_scores = scores1.unsqueeze(3) + scores2.unsqueeze(2) + + # Flatten the KxK grid to K^2 to find the global top-k + all_scores_flat = all_scores.view(bs * seq_len, self.n_heads, -1) + + # Select the best combinations (global top-k) + global_scores, global_top_indices = torch.topk(all_scores_flat, self.k, dim=-1) + + # 4. Index Mapping + # Map the flattened indices back to the original codebook indices + idx1_pos = global_top_indices // self.k + idx2_pos = global_top_indices % self.k + + # Gather the actual sub-key indices + real_idx1 = torch.gather(idx1, 2, idx1_pos) + real_idx2 = torch.gather(idx2, 2, idx2_pos) + + # Calculate the global memory index: i * N_keys + j + memory_indices = real_idx1 * self.n_sub_keys + real_idx2 + + # 5. Read from Memory + attn_weights = F.softmax(global_scores, dim=-1) # (BS*Seq, H, K) + + # Flatten indices and weights to the format expected by embedding_bag + # The "bag" dimension is (BS * Seq * Heads), with K elements in each bag + flat_indices = memory_indices.view(-1, self.k) + flat_weights = attn_weights.view(-1, self.k) + + # Fused Lookup + Weighted Sum + # We avoid creating the massive (BS*Seq, H, K, d_model) tensor by using embedding_bag + is_bfloat16 = flat_weights.dtype == torch.bfloat16 + if is_bfloat16: + flat_weights_fp32 = flat_weights.to(torch.float32) + values_weight_fp32 = self.values.weight.to(torch.float32) + out_flat = F.embedding_bag( + input=flat_indices, + weight=values_weight_fp32, + per_sample_weights=flat_weights_fp32, + mode="sum", + ) + out_flat = out_flat.to(torch.bfloat16) + else: + out_flat = F.embedding_bag( + input=flat_indices, + weight=self.values.weight, + per_sample_weights=flat_weights, + mode="sum", + ) + + # 6. Aggregation + # Restore the correct dimensions: (BS*Seq, H, d_model) -> (BS, Seq, H, d_model) + out_flat = out_flat.view(bs, seq_len, self.n_heads, d_model) + return out_flat.sum(dim=2) # Output: (BS, Seq, d_model) diff --git a/src/projected_compression/model.py b/src/projected_compression/model.py index 4b1371db..527a4f04 100644 --- a/src/projected_compression/model.py +++ b/src/projected_compression/model.py @@ -239,6 +239,7 @@ def __init__( low_freq_factor=1, high_freq_factor=4, original_max_position_embeddings=8192, + causal=True, ): super().__init__() self.q_proj = q_proj_fn() @@ -251,6 +252,7 @@ def __init__( self.kv_heads = kv_heads self.dhead = self.q_proj.weight.shape[0] // self.q_heads self.dmodel = dmodel + self.causal = causal self.rope = RoPE( dhead=self.dhead, @@ -279,7 +281,7 @@ def forward(self, x): k = repeat_kv(k, self.q_heads // self.kv_heads) v = repeat_kv(v, self.q_heads // self.kv_heads) attention_output = self.attention_mechanism( - query=q, key=k, value=v, causal=True + query=q, key=k, value=v, causal=self.causal ) output = self.o_proj(attention_output.transpose(1, 2).contiguous().flatten(-2))