From 879089828a69ed83d215e4cf1491ca2f211842c2 Mon Sep 17 00:00:00 2001 From: Mateusz Borowski Date: Wed, 17 Dec 2025 10:53:09 +0100 Subject: [PATCH 1/9] Add initial PK implementation --- src/product_keys/model.py | 139 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 139 insertions(+) diff --git a/src/product_keys/model.py b/src/product_keys/model.py index 56a8432c..75efaabf 100644 --- a/src/product_keys/model.py +++ b/src/product_keys/model.py @@ -223,6 +223,8 @@ def forward(self, x): # Calculate similarity between full Q and the reconstructed candidates # q needs unsqueeze to broadcast: (B, H, S, 1, D) @ (B, H, S, K*K, D).T + # TODO + # ! we've calculated both halves separately, we can reuse that scores_final = (q.unsqueeze(-2) * candidates).sum(dim=-1) # Select top K closest combinations @@ -260,3 +262,140 @@ def forward(self, x): attn_output = attn_output.squeeze(-2) return self.o_proj(attn_output.transpose(1, 2).contiguous().flatten(-2)) + + +class ProductKeysMemory(nn.Module): + def __init__( + self, input_dim: int, query_dim: int, n_sub_keys: int, k_neighbors: int, n_heads:int=4 + ): + super().__init__() + self.n_heads = n_heads + self.k = k_neighbors + self.n_sub_keys = n_sub_keys + self.query_dim = query_dim + self.sub_key_dim = query_dim // 2 + self.memory_dim = input_dim # todo remove this param, rename input_dim to smth like dmodel + + # Query Network (Learnable) + self.query_proj = nn.Linear(input_dim, n_heads * query_dim) + self.query_bn = nn.BatchNorm1d(n_heads * query_dim) + + # Sub-Keys (Learnable, Separate for each head) + self.c1 = nn.Parameter(torch.randn(n_heads, n_sub_keys, self.sub_key_dim)) + self.c2 = nn.Parameter(torch.randn(n_heads, n_sub_keys, self.sub_key_dim)) + + # Memory Values (Learnable, Shared across heads) + self.values = nn.Embedding(n_sub_keys * n_sub_keys, self.memory_dim) + nn.init.normal_(self.values.weight, mean=0, std=input_dim**-0.5) + + def _get_knn_indices_gpu(self, queries, codebooks): + # queries: (batch, head, dim) + # codebooks: (head, num_sub_keys, dim) + scores = torch.einsum("bhd,hnd->bhn", queries, codebooks) + + _, indices = torch.topk(scores, k=self.k, dim=-1, largest=True) + + return indices + + def _gather_keys(self, codebook, indices): + # codebook: (n_heads, n_keys, dim) + # indices: (batch, n_heads, k) + # Output: (batch, n_heads, k, dim) + + # Expand codebook to batch size + # (1, n_heads, n_keys, dim) -> (batch, n_heads, n_keys, dim) + cb_exp = codebook.unsqueeze(0).expand(indices.size(0), -1, -1, -1) + + # Expand indices to dim + # (batch, n_heads, k, 1) -> (batch, n_heads, k, dim) + idx_exp = indices.unsqueeze(-1).expand(-1, -1, -1, codebook.size(-1)) + + return torch.gather(cb_exp, 2, idx_exp) + + def forward(self, x): + bs, seq_len, input_dim = x.shape + + # Flatten batch and sequence for processing + x_flat = x.view(-1, input_dim) # (batch*seq, input_dim) + q = self.query_proj(x_flat) + q = self.query_bn(q) + q = q.view(-1, self.n_heads, self.query_dim) + + q1, q2 = torch.chunk(q, 2, dim=-1) # Each: (batch, n_heads, sub_dim) + + # Retrieve Sub-Key Indices + idx1 = self._get_knn_indices_gpu(q1, self.c1) + idx2 = self._get_knn_indices_gpu(q2, self.c2) + + k1_selected = self._gather_keys(self.c1, idx1) + k2_selected = self._gather_keys(self.c2, idx2) + + # Compute Dot Products (Scores) + # q1: (batch, n_heads, dim) -> unsqueeze -> (batch, n_heads, 1, dim) + # Todo we've calculated both halves separately, we can reuse that + scores1 = (q1.unsqueeze(2) * k1_selected).sum(dim=-1) # (batch, n_heads, k) + scores2 = (q2.unsqueeze(2) * k2_selected).sum(dim=-1) # (batch, n_heads, k) + + # Cartesian Product Sum + # (batch, n_heads, k, 1) + (batch, n_heads, 1, k) -> (batch, n_heads, k, k) + all_scores = scores1.unsqueeze(3) + scores2.unsqueeze(2) + + # Flatten (k, k) to k^2 to find global top-k + all_scores_flat = all_scores.view( + all_scores.size(0), self.n_heads, -1 + ) # (batch, n_heads, k*k) + + # Select global top-k from the k^2 candidates + top_scores, top_indices_flat = torch.topk(all_scores_flat, self.k, dim=-1) + + # 4. Map back to Global Memory Indices + # We need to find which (i, j) pair in the k*k grid corresponded to the top scores + + # Create grid of local indices + # idx1: (batch, n_heads, k) + idx1_grid = ( + idx1.unsqueeze(3) + .expand(-1, -1, -1, self.k) + .reshape(idx1.size(0), self.n_heads, -1) + ) + idx2_grid = ( + idx2.unsqueeze(2) + .expand(-1, -1, self.k, -1) + .reshape(idx2.size(0), self.n_heads, -1) + ) + + # Gather the actual codebook indices corresponding to the winners + best_idx1 = torch.gather(idx1_grid, 2, top_indices_flat) + best_idx2 = torch.gather(idx2_grid, 2, top_indices_flat) + + # Calculate global memory index: i * |C| + j + global_indices = best_idx1 * self.n_sub_keys + best_idx2 + + # 5. Read from Value Memory (Weighted Sum) + # Softmax over the top-k scores + attn_weights = F.softmax(top_scores, dim=-1) # (batch, n_heads, k) + + # Fetch Values + # self.values.weight: (total_keys, val_dim) + # global_indices: (batch, n_heads, k) + # Output: (batch, n_heads, k, val_dim) + + # Since embedding weight is 2D, we flatten indices to lookup + flat_indices = global_indices.view(-1) + values_selected = F.embedding(flat_indices, self.values.weight) + values_selected = values_selected.view(*global_indices.shape, self.memory_dim) + + # Weighted Sum + # (batch, n_heads, k, val_dim) * (batch, n_heads, k, 1) -> sum over k + head_outputs = (values_selected * attn_weights.unsqueeze(-1)).sum( + dim=2 + ) # (batch, n_heads, val_dim) + + # 6. Multi-Head Aggregation + # "The memory simply sums the output m_i(x) of each head" [cite: 210] + output = head_outputs.sum(dim=1) # (batch, val_dim) + + # Restore sequence dimension + output = output.view(bs, seq_len, self.memory_dim) + + return output From 07466099180a78c8742dc6ec65731c959362b16e Mon Sep 17 00:00:00 2001 From: Wojciech Weremczuk Date: Wed, 14 Jan 2026 09:43:15 +0100 Subject: [PATCH 2/9] PK feed-forward improvements --- configs/product_keys/pkm.yaml | 109 +++++++++++++++++++++ src/product_keys/model.py | 176 ++++++++++++++-------------------- 2 files changed, 179 insertions(+), 106 deletions(-) create mode 100644 configs/product_keys/pkm.yaml diff --git a/configs/product_keys/pkm.yaml b/configs/product_keys/pkm.yaml new file mode 100644 index 00000000..3927d546 --- /dev/null +++ b/configs/product_keys/pkm.yaml @@ -0,0 +1,109 @@ +# @package _global_ +defaults: + - /_cluster/entropy@_here_ + - /_model/tiny@_here_ + - /_trainer/llama@_here_ + - /_dataset/c4@_here_ + - /_checkpoints/none@_here_ + - /_misc/default@_here_ + - _self_ + +common: + sequence_length: 1024 + batch_size: 64 + dmodel: 768 + datt: ${common.dmodel} + n_blocks: 12 + q_heads: 12 + kv_heads: 12 + vocab_size: 50304 # GPT-2 vocab + + pkm_n_sub_keys: 128 # 128^2 = 16,384 memory slots + pkm_k: 32 + pkm_query_dim: 32 + pkm_n_heads: 4 + +trainer: + _target_: src.product_keys.trainer.MaskedLMTrainer + masking_percentage: 0.20 + mask_token_id: 50257 + unmaskable_special_tokens: [50256, 50257] + gradient_accumulation_steps: 2 + n_steps: 50000 + learning_rate: 3e-4 + train_dataloader: + tokenize_fn: + _target_: src.product_keys.datasets.gpt2_mask_tokenize_fn + eval_dataloader: + tokenize_fn: + _target_: src.product_keys.datasets.gpt2_mask_tokenize_fn + +infrastructure: + metric_logger: + name: pkm_test_run + project_name: pmtest/tml-bgw + tags: + - pkm_memory + - mlm_training + - "dmodel=${common.dmodel}" + slurm: + gres: gpu:1 + time: "0-04:00:00" + job-name: ${infrastructure.metric_logger.name} + +model: + encoder: + block_fn: + ff_layer_fn: + _target_: src.product_keys.model.ProductKeysMemory + _partial_: true + d_model: ${common.dmodel} + query_dim: ${common.pkm_query_dim} + n_sub_keys: ${common.pkm_n_sub_keys} + k_neighbors: ${common.pkm_k} + n_heads: ${common.pkm_n_heads} + + attention_fn: + _target_: src.product_keys.model.RoPETopKAttention + _partial_: true + dmodel: ${common.dmodel} + q_heads: ${common.q_heads} + kv_heads: ${common.kv_heads} + seq_len: ${common.sequence_length} + + q_proj_fn: + _target_: src.projected_compression.model.Linear + _partial_: true + in_features: ${common.dmodel} + out_features: ${common.datt} + partial_init_fn: + _target_: src.projected_compression.model.llm_random_weight_init + _partial_: true + scale: 1 + + k_proj_fn: + _target_: src.projected_compression.model.Linear + _partial_: true + in_features: ${common.dmodel} + out_features: ${common.datt} + partial_init_fn: + _target_: src.projected_compression.model.llm_random_weight_init + _partial_: true + scale: 1 + + v_proj_fn: ${model.encoder.block_fn.attention_fn.k_proj_fn} + + o_proj_fn: + _target_: src.projected_compression.model.Linear + _partial_: true + in_features: ${common.datt} + out_features: ${common.dmodel} + partial_init_fn: + _target_: src.projected_compression.model.llm_random_weight_init + _partial_: true + scale: 1 + + rope_base: 10000 + rope_scale_freqs: true + top_k: 32 + top_k_before_softmax: true diff --git a/src/product_keys/model.py b/src/product_keys/model.py index 75efaabf..292de55a 100644 --- a/src/product_keys/model.py +++ b/src/product_keys/model.py @@ -266,136 +266,100 @@ def forward(self, x): class ProductKeysMemory(nn.Module): def __init__( - self, input_dim: int, query_dim: int, n_sub_keys: int, k_neighbors: int, n_heads:int=4 + self, + d_model: int, + query_dim: int, + n_sub_keys: int, + k_neighbors: int, + n_heads: int = 4, + **kwargs, # To ignore unused args ): super().__init__() self.n_heads = n_heads self.k = k_neighbors self.n_sub_keys = n_sub_keys self.query_dim = query_dim - self.sub_key_dim = query_dim // 2 - self.memory_dim = input_dim # todo remove this param, rename input_dim to smth like dmodel - # Query Network (Learnable) - self.query_proj = nn.Linear(input_dim, n_heads * query_dim) + # Query Network + # Projects input to query space. BatchNorm is crucial for PKM stability/convergence. + self.query_proj = nn.Linear(d_model, n_heads * query_dim) self.query_bn = nn.BatchNorm1d(n_heads * query_dim) - # Sub-Keys (Learnable, Separate for each head) - self.c1 = nn.Parameter(torch.randn(n_heads, n_sub_keys, self.sub_key_dim)) - self.c2 = nn.Parameter(torch.randn(n_heads, n_sub_keys, self.sub_key_dim)) - - # Memory Values (Learnable, Shared across heads) - self.values = nn.Embedding(n_sub_keys * n_sub_keys, self.memory_dim) - nn.init.normal_(self.values.weight, mean=0, std=input_dim**-0.5) - - def _get_knn_indices_gpu(self, queries, codebooks): - # queries: (batch, head, dim) - # codebooks: (head, num_sub_keys, dim) + # Sub-Keys (Codebooks) + # Two separate sets of keys for the product quantization + self.c1 = nn.Parameter(torch.randn(n_heads, n_sub_keys, query_dim // 2)) + self.c2 = nn.Parameter(torch.randn(n_heads, n_sub_keys, query_dim // 2)) + + # Memory Values + # The actual values retrieved. Size is (n_sub_keys^2, d_model) + self.values = nn.Embedding(n_sub_keys * n_sub_keys, d_model) + nn.init.normal_(self.values.weight, mean=0, std=d_model**-0.5) + + def _get_knn(self, queries, codebooks): + """ + Calculates dot product scores and retrieves top-k indices and values. + """ + # queries: (batch, head, sub_dim) + # codebooks: (head, n_sub_keys, sub_dim) + + # Calculate similarity (dot product) scores = torch.einsum("bhd,hnd->bhn", queries, codebooks) - - _, indices = torch.topk(scores, k=self.k, dim=-1, largest=True) - - return indices - - def _gather_keys(self, codebook, indices): - # codebook: (n_heads, n_keys, dim) - # indices: (batch, n_heads, k) - # Output: (batch, n_heads, k, dim) - - # Expand codebook to batch size - # (1, n_heads, n_keys, dim) -> (batch, n_heads, n_keys, dim) - cb_exp = codebook.unsqueeze(0).expand(indices.size(0), -1, -1, -1) - - # Expand indices to dim - # (batch, n_heads, k, 1) -> (batch, n_heads, k, dim) - idx_exp = indices.unsqueeze(-1).expand(-1, -1, -1, codebook.size(-1)) - - return torch.gather(cb_exp, 2, idx_exp) + + # Select top-k + top_scores, top_indices = torch.topk(scores, k=self.k, dim=-1, largest=True) + return top_scores, top_indices def forward(self, x): - bs, seq_len, input_dim = x.shape + bs, seq_len, d_model = x.shape - # Flatten batch and sequence for processing - x_flat = x.view(-1, input_dim) # (batch*seq, input_dim) + # 1. Query Projection + x_flat = x.view(-1, d_model) q = self.query_proj(x_flat) q = self.query_bn(q) - q = q.view(-1, self.n_heads, self.query_dim) - - q1, q2 = torch.chunk(q, 2, dim=-1) # Each: (batch, n_heads, sub_dim) + q = q.view(bs * seq_len, self.n_heads, self.query_dim) - # Retrieve Sub-Key Indices - idx1 = self._get_knn_indices_gpu(q1, self.c1) - idx2 = self._get_knn_indices_gpu(q2, self.c2) + # Split query into two halves for product quantization + q1, q2 = torch.chunk(q, 2, dim=-1) - k1_selected = self._gather_keys(self.c1, idx1) - k2_selected = self._gather_keys(self.c2, idx2) + # 2. Retrieve Top-K candidates for each half + scores1, idx1 = self._get_knn(q1, self.c1) + scores2, idx2 = self._get_knn(q2, self.c2) - # Compute Dot Products (Scores) - # q1: (batch, n_heads, dim) -> unsqueeze -> (batch, n_heads, 1, dim) - # Todo we've calculated both halves separately, we can reuse that - scores1 = (q1.unsqueeze(2) * k1_selected).sum(dim=-1) # (batch, n_heads, k) - scores2 = (q2.unsqueeze(2) * k2_selected).sum(dim=-1) # (batch, n_heads, k) - - # Cartesian Product Sum - # (batch, n_heads, k, 1) + (batch, n_heads, 1, k) -> (batch, n_heads, k, k) + # 3. Cartesian Product of Scores + # Sum every score from the first half with every score from the second half + # (BS, H, K, 1) + (BS, H, 1, K) -> (BS, H, K, K) all_scores = scores1.unsqueeze(3) + scores2.unsqueeze(2) - # Flatten (k, k) to k^2 to find global top-k - all_scores_flat = all_scores.view( - all_scores.size(0), self.n_heads, -1 - ) # (batch, n_heads, k*k) - - # Select global top-k from the k^2 candidates - top_scores, top_indices_flat = torch.topk(all_scores_flat, self.k, dim=-1) - - # 4. Map back to Global Memory Indices - # We need to find which (i, j) pair in the k*k grid corresponded to the top scores - - # Create grid of local indices - # idx1: (batch, n_heads, k) - idx1_grid = ( - idx1.unsqueeze(3) - .expand(-1, -1, -1, self.k) - .reshape(idx1.size(0), self.n_heads, -1) - ) - idx2_grid = ( - idx2.unsqueeze(2) - .expand(-1, -1, self.k, -1) - .reshape(idx2.size(0), self.n_heads, -1) - ) - - # Gather the actual codebook indices corresponding to the winners - best_idx1 = torch.gather(idx1_grid, 2, top_indices_flat) - best_idx2 = torch.gather(idx2_grid, 2, top_indices_flat) - - # Calculate global memory index: i * |C| + j - global_indices = best_idx1 * self.n_sub_keys + best_idx2 + # Flatten the KxK grid to K^2 to find the global top-k + all_scores_flat = all_scores.view(bs * seq_len, self.n_heads, -1) + + # Select the best combinations (global top-k) + global_scores, global_top_indices = torch.topk(all_scores_flat, self.k, dim=-1) - # 5. Read from Value Memory (Weighted Sum) - # Softmax over the top-k scores - attn_weights = F.softmax(top_scores, dim=-1) # (batch, n_heads, k) + # 4. Index Mapping + # Map the flattened indices back to the original codebook indices + idx1_pos = global_top_indices // self.k + idx2_pos = global_top_indices % self.k - # Fetch Values - # self.values.weight: (total_keys, val_dim) - # global_indices: (batch, n_heads, k) - # Output: (batch, n_heads, k, val_dim) + # Gather the actual sub-key indices + real_idx1 = torch.gather(idx1, 2, idx1_pos) + real_idx2 = torch.gather(idx2, 2, idx2_pos) - # Since embedding weight is 2D, we flatten indices to lookup - flat_indices = global_indices.view(-1) - values_selected = F.embedding(flat_indices, self.values.weight) - values_selected = values_selected.view(*global_indices.shape, self.memory_dim) + # Calculate the global memory index: i * N_keys + j + memory_indices = real_idx1 * self.n_sub_keys + real_idx2 - # Weighted Sum - # (batch, n_heads, k, val_dim) * (batch, n_heads, k, 1) -> sum over k - head_outputs = (values_selected * attn_weights.unsqueeze(-1)).sum( - dim=2 - ) # (batch, n_heads, val_dim) + # 5. Read from Memory + attn_weights = F.softmax(global_scores, dim=-1) # (BS, H, K) - # 6. Multi-Head Aggregation - # "The memory simply sums the output m_i(x) of each head" [cite: 210] - output = head_outputs.sum(dim=1) # (batch, val_dim) + flat_indices = memory_indices.view(-1) + values_selected = self.values(flat_indices) + values_selected = values_selected.view(bs * seq_len, self.n_heads, self.k, d_model) - # Restore sequence dimension - output = output.view(bs, seq_len, self.memory_dim) + # Weighted sum of retrieved values + out_heads = (values_selected * attn_weights.unsqueeze(-1)).sum(dim=2) - return output + # 6. Aggregation + # Sum outputs across all heads + output = out_heads.sum(dim=1) # (BS, d_model) + + return output.view(bs, seq_len, d_model) From 378179d5e4a7d1a36e971d552f28605698432f79 Mon Sep 17 00:00:00 2001 From: Wojciech Weremczuk Date: Wed, 4 Feb 2026 17:23:29 +0100 Subject: [PATCH 3/9] Product keys in FF. Adjusting configs to run comparisons --- configs/product_keys/baseline.yaml | 163 +++++++++++++++++++ configs/product_keys/pk_mlm.yaml | 24 +-- configs/product_keys/pkm.yaml | 186 +++++++++++++++------- configs/product_keys/top_k_attention.yaml | 49 +++--- src/core/model.py | 4 +- src/product_keys/model.py | 55 ++++++- src/projected_compression/model.py | 4 +- 7 files changed, 391 insertions(+), 94 deletions(-) create mode 100644 configs/product_keys/baseline.yaml diff --git a/configs/product_keys/baseline.yaml b/configs/product_keys/baseline.yaml new file mode 100644 index 00000000..397185b3 --- /dev/null +++ b/configs/product_keys/baseline.yaml @@ -0,0 +1,163 @@ +# @package _global_ +defaults: + - /_cluster/helios@_here_ + - /_model/tiny@_here_ + - /_trainer/llama@_here_ + - /_dataset/c4@_here_ + - /_checkpoints/none@_here_ + - /_misc/default@_here_ + - _self_ + +common: + sequence_length: 1024 + batch_size: 64 + dmodel: 1024 + dff: 2724 + datt: ${common.dmodel} + n_blocks: 16 + q_heads: 16 + kv_heads: 16 + vocab_size: 50304 + +trainer: + _target_: src.product_keys.trainer.MaskedLMTrainer + masking_percentage: 0.2 + mask_token_id: 50257 + unmaskable_special_tokens: [50256, 50257] # <|endoftext|> + gradient_accumulation_steps: 2 + n_steps: 77050 + ^learning_rate: [1e-4, 2e-4, 5e-4, 1e-3, 2e-3, 5e-3] + train_dataloader: + dataset: + tokenize_fn: + _target_: src.product_keys.datasets.gpt2_mask_tokenize_fn + eval_dataloader: + dataset: + tokenize_fn: + _target_: src.product_keys.datasets.gpt2_mask_tokenize_fn + + +infrastructure: + metric_logger: + type: wandb + wandb_entity: ideas_cv + name: baseline + project_name: tml-bgw + tags: + - nano + - baseline + - "seq_len=${common.sequence_length}" + - "n_layers=${common.n_blocks}" + - "dmodel=${common.dmodel}" + slurm: + gres: gpu:2 + time: "1-00:00:00" + job-name: ${infrastructure.metric_logger.name} + +model: + _target_: src.projected_compression.model.LLM + + embedding: + _target_: src.projected_compression.model.TransformerEmbedding + vocab_size: ${common.vocab_size} + dmodel: ${common.dmodel} + init_fn: + _target_: src.projected_compression.model.trunc_normal_ + _partial_: true + + encoder: + _target_: src.projected_compression.model.TransformerEncoder + n_blocks: ${common.n_blocks} + block_fn: + _target_: src.projected_compression.model.TransformerBlock + _partial_: true + norm_fn: + _target_: src.core.model.RMSNorm + _partial_: true + eps: 1e-5 + normalized_shape: ${common.dmodel} + + attention_fn: + _target_: src.projected_compression.model.RoPEAttention + _partial_: true + dmodel: ${common.dmodel} + q_heads: ${common.q_heads} + kv_heads: ${common.kv_heads} + seq_len: ${common.sequence_length} + causal: false + + q_proj_fn: + _target_: src.projected_compression.model.Linear + _partial_: true + in_features: ${common.dmodel} + out_features: ${common.datt} + partial_init_fn: + _target_: src.projected_compression.model.llm_random_weight_init + _partial_: true + scale: 1 + + k_proj_fn: + _target_: src.projected_compression.model.Linear + _partial_: true + in_features: ${common.dmodel} + out_features: ${eval:'(${common.datt} // ${model.encoder.block_fn.attention_fn.q_heads}) * ${model.encoder.block_fn.attention_fn.kv_heads}'} + partial_init_fn: + _target_: src.projected_compression.model.llm_random_weight_init + _partial_: true + scale: 1 + + v_proj_fn: ${model.encoder.block_fn.attention_fn.k_proj_fn} + + # o_proj_fn: ${model.encoder.block_fn.attention_fn.q_proj_fn} # TODO check have I done it right pls + o_proj_fn: + _target_: src.projected_compression.model.Linear + _partial_: true + in_features: ${common.datt} + out_features: ${common.dmodel} + partial_init_fn: + _target_: src.projected_compression.model.llm_random_weight_init + _partial_: true + scale: 1 + + rope_base: 500000 + rope_scale_freqs: true + + ff_layer_fn: + _target_: src.projected_compression.model.ProjectedLlamaFeedForward + _partial_: true + ff_pre_act_fn: + _target_: src.projected_compression.model.Linear + _partial_: true + in_features: ${common.dmodel} + out_features: ${common.dff} + partial_init_fn: + _target_: src.projected_compression.model.llm_random_weight_init + _partial_: true + scale: 1 + ff_post_act_fn: + _target_: src.projected_compression.model.Linear + _partial_: true + in_features: ${common.dff} + out_features: ${common.dmodel} + partial_init_fn: + _target_: src.projected_compression.model.llm_random_weight_init + _partial_: true + scale: 1 + gate_fn: ${model.encoder.block_fn.ff_layer_fn.ff_pre_act_fn} + + head: + _target_: src.projected_compression.model.TransformerHead + linear_fn: + _target_: src.projected_compression.model.Linear + _partial_: true + in_features: ${common.dmodel} + out_features: ${common.vocab_size} + partial_init_fn: + _target_: src.projected_compression.model.llm_random_weight_init + _partial_: true + scale: 1 + norm_fn: + _target_: src.core.model.RMSNorm + _partial_: true + eps: 1e-5 + normalized_shape: ${common.dmodel} \ No newline at end of file diff --git a/configs/product_keys/pk_mlm.yaml b/configs/product_keys/pk_mlm.yaml index 0d636623..aaa3a416 100644 --- a/configs/product_keys/pk_mlm.yaml +++ b/configs/product_keys/pk_mlm.yaml @@ -1,6 +1,6 @@ # @package _global_ defaults: - - /_cluster/entropy@_here_ + - /_cluster/helios@_here_ - /_model/tiny@_here_ - /_trainer/llama@_here_ - /_dataset/c4@_here_ @@ -26,28 +26,31 @@ trainer: unmaskable_special_tokens: [50256, 50257] # <|endoftext|> gradient_accumulation_steps: 2 n_steps: 77050 - learning_rate: 5e-4 + ^learning_rate: [1e-4, 2e-4, 5e-4, 1e-3, 2e-3, 5e-3] train_dataloader: - tokenize_fn: - _target_: src.product_keys.datasets.gpt2_mask_tokenize_fn - - + dataset: + tokenize_fn: + _target_: src.product_keys.datasets.gpt2_mask_tokenize_fn eval_dataloader: - tokenize_fn: - _target_: src.product_keys.datasets.gpt2_mask_tokenize_fn + dataset: + tokenize_fn: + _target_: src.product_keys.datasets.gpt2_mask_tokenize_fn infrastructure: metric_logger: + type: wandb + wandb_entity: ideas_cv name: pk_mlm - project_name: pmtest/tml-bgw + project_name: tml-bgw tags: - nano - pk_mlm - "seq_len=${common.sequence_length}" - "n_layers=${common.n_blocks}" + - "dmodel=${common.dmodel}" slurm: - gres: gpu:1 + gres: gpu:2 time: "1-00:00:00" job-name: ${infrastructure.metric_logger.name} @@ -61,6 +64,7 @@ model: q_heads: ${common.q_heads} kv_heads: ${common.kv_heads} seq_len: ${common.sequence_length} + causal: false q_proj_fn: _target_: src.projected_compression.model.Linear diff --git a/configs/product_keys/pkm.yaml b/configs/product_keys/pkm.yaml index 3927d546..519ef158 100644 --- a/configs/product_keys/pkm.yaml +++ b/configs/product_keys/pkm.yaml @@ -1,6 +1,6 @@ # @package _global_ defaults: - - /_cluster/entropy@_here_ + - /_cluster/helios@_here_ - /_model/tiny@_here_ - /_trainer/llama@_here_ - /_dataset/c4@_here_ @@ -11,99 +11,171 @@ defaults: common: sequence_length: 1024 batch_size: 64 - dmodel: 768 + dmodel: 1024 + dff: 2724 datt: ${common.dmodel} - n_blocks: 12 - q_heads: 12 - kv_heads: 12 + n_blocks: 16 + q_heads: 16 + kv_heads: 16 vocab_size: 50304 # GPT-2 vocab - pkm_n_sub_keys: 128 # 128^2 = 16,384 memory slots + # pkm_n_sub_keys: 256 # 256^2 = 65,536 memory slots + pkm_n_sub_keys: 512 # 512^2 = 262,144 memory slots pkm_k: 32 - pkm_query_dim: 32 + pkm_query_dim: 512 pkm_n_heads: 4 trainer: _target_: src.product_keys.trainer.MaskedLMTrainer - masking_percentage: 0.20 + masking_percentage: 0.2 mask_token_id: 50257 - unmaskable_special_tokens: [50256, 50257] + unmaskable_special_tokens: [50256, 50257] # <|endoftext|> gradient_accumulation_steps: 2 - n_steps: 50000 - learning_rate: 3e-4 + n_steps: 77050 + ^learning_rate: [1e-4, 2e-4, 5e-4, 1e-3, 2e-3, 5e-3] train_dataloader: - tokenize_fn: - _target_: src.product_keys.datasets.gpt2_mask_tokenize_fn + dataset: + tokenize_fn: + _target_: src.product_keys.datasets.gpt2_mask_tokenize_fn eval_dataloader: - tokenize_fn: - _target_: src.product_keys.datasets.gpt2_mask_tokenize_fn + dataset: + tokenize_fn: + _target_: src.product_keys.datasets.gpt2_mask_tokenize_fn infrastructure: metric_logger: - name: pkm_test_run - project_name: pmtest/tml-bgw + type: wandb + wandb_entity: ideas_cv + name: pkm + project_name: tml-bgw tags: - - pkm_memory - - mlm_training + - nano + - pkm + - "seq_len=${common.sequence_length}" + - "n_layers=${common.n_blocks}" - "dmodel=${common.dmodel}" + - "pkm_k=${common.pkm_k}" + - "pkm_n_sub_keys=${common.pkm_n_sub_keys}" slurm: - gres: gpu:1 - time: "0-04:00:00" + gres: gpu:2 + time: "1-00:00:00" job-name: ${infrastructure.metric_logger.name} model: + _target_: src.projected_compression.model.LLM + + embedding: + _target_: src.projected_compression.model.TransformerEmbedding + vocab_size: ${common.vocab_size} + dmodel: ${common.dmodel} + init_fn: + _target_: src.projected_compression.model.trunc_normal_ + _partial_: true + encoder: + _target_: src.projected_compression.model.TransformerEncoder + n_blocks: ${common.n_blocks} block_fn: - ff_layer_fn: - _target_: src.product_keys.model.ProductKeysMemory + _target_: src.product_keys.model.HybridTransformerBlock + _partial_: true + pkm_indices: [7, 14] # Layers 8 and 15 use PKM + + norm_fn: + _target_: src.core.model.RMSNorm _partial_: true - d_model: ${common.dmodel} - query_dim: ${common.pkm_query_dim} - n_sub_keys: ${common.pkm_n_sub_keys} - k_neighbors: ${common.pkm_k} - n_heads: ${common.pkm_n_heads} + eps: 1e-5 + normalized_shape: ${common.dmodel} attention_fn: - _target_: src.product_keys.model.RoPETopKAttention + _target_: src.projected_compression.model.RoPEAttention _partial_: true dmodel: ${common.dmodel} q_heads: ${common.q_heads} kv_heads: ${common.kv_heads} seq_len: ${common.sequence_length} - + causal: false + q_proj_fn: _target_: src.projected_compression.model.Linear _partial_: true in_features: ${common.dmodel} out_features: ${common.datt} partial_init_fn: - _target_: src.projected_compression.model.llm_random_weight_init - _partial_: true - scale: 1 - + _target_: src.projected_compression.model.llm_random_weight_init + _partial_: true + scale: 1 + k_proj_fn: - _target_: src.projected_compression.model.Linear - _partial_: true - in_features: ${common.dmodel} - out_features: ${common.datt} - partial_init_fn: - _target_: src.projected_compression.model.llm_random_weight_init - _partial_: true - scale: 1 - + _target_: src.projected_compression.model.Linear + _partial_: true + in_features: ${common.dmodel} + out_features: ${eval:'(${common.datt} // ${model.encoder.block_fn.attention_fn.q_heads}) * ${model.encoder.block_fn.attention_fn.kv_heads}'} + partial_init_fn: + _target_: src.projected_compression.model.llm_random_weight_init + _partial_: true + scale: 1 + v_proj_fn: ${model.encoder.block_fn.attention_fn.k_proj_fn} - - o_proj_fn: - _target_: src.projected_compression.model.Linear - _partial_: true - in_features: ${common.datt} - out_features: ${common.dmodel} - partial_init_fn: - _target_: src.projected_compression.model.llm_random_weight_init - _partial_: true - scale: 1 - - rope_base: 10000 + + # o_proj_fn: ${model.encoder.block_fn.attention_fn.q_proj_fn} # TODO check have I done it right pls + o_proj_fn: + _target_: src.projected_compression.model.Linear + _partial_: true + in_features: ${common.datt} + out_features: ${common.dmodel} + partial_init_fn: + _target_: src.projected_compression.model.llm_random_weight_init + _partial_: true + scale: 1 + + rope_base: 500000 rope_scale_freqs: true - top_k: 32 - top_k_before_softmax: true + + ff_layer_fn: + _target_: src.projected_compression.model.ProjectedLlamaFeedForward + _partial_: true + ff_pre_act_fn: + _target_: src.projected_compression.model.Linear + _partial_: true + in_features: ${common.dmodel} + out_features: ${common.dff} + partial_init_fn: + _target_: src.projected_compression.model.llm_random_weight_init + _partial_: true + scale: 1 + ff_post_act_fn: + _target_: src.projected_compression.model.Linear + _partial_: true + in_features: ${common.dff} + out_features: ${common.dmodel} + partial_init_fn: + _target_: src.projected_compression.model.llm_random_weight_init + _partial_: true + scale: 1 + gate_fn: ${model.encoder.block_fn.ff_layer_fn.ff_pre_act_fn} + + pkm_layer_fn: + _target_: src.product_keys.model.ProductKeysMemory + _partial_: true + d_model: ${common.dmodel} + query_dim: ${common.pkm_query_dim} + n_sub_keys: ${common.pkm_n_sub_keys} + k_neighbors: ${common.pkm_k} + n_heads: ${common.pkm_n_heads} + + head: + _target_: src.projected_compression.model.TransformerHead + linear_fn: + _target_: src.projected_compression.model.Linear + _partial_: true + in_features: ${common.dmodel} + out_features: ${common.vocab_size} + partial_init_fn: + _target_: src.projected_compression.model.llm_random_weight_init + _partial_: true + scale: 1 + norm_fn: + _target_: src.core.model.RMSNorm + _partial_: true + eps: 1e-5 + normalized_shape: ${common.dmodel} diff --git a/configs/product_keys/top_k_attention.yaml b/configs/product_keys/top_k_attention.yaml index 99ad5758..16f52d51 100644 --- a/configs/product_keys/top_k_attention.yaml +++ b/configs/product_keys/top_k_attention.yaml @@ -1,6 +1,6 @@ # @package _global_ defaults: - - /_cluster/entropy@_here_ + - /_cluster/helios@_here_ - /_model/tiny@_here_ - /_trainer/llama@_here_ - /_dataset/c4@_here_ @@ -9,37 +9,47 @@ defaults: - _self_ common: - sequence_length: 2048 - batch_size: 32 - dmodel: 768 - dff: 2042 + sequence_length: 1024 + batch_size: 64 + dmodel: 1024 + dff: 2724 datt: ${common.dmodel} - n_blocks: 12 - q_heads: 12 - kv_heads: 12 - vocab_size: 128256 + n_blocks: 16 + q_heads: 16 + kv_heads: 16 + vocab_size: 50304 trainer: + _target_: src.product_keys.trainer.MaskedLMTrainer + masking_percentage: 0.2 + mask_token_id: 50257 + unmaskable_special_tokens: [50256, 50257] # <|endoftext|> gradient_accumulation_steps: 2 - n_steps: 56000 - learning_rate: 1e-3 - - checkpoint: - save: - type: huggingface - path: checkpoint + n_steps: 77050 + ^learning_rate: [1e-4, 2e-4, 5e-4, 1e-3, 2e-3, 5e-3] + train_dataloader: + dataset: + tokenize_fn: + _target_: src.product_keys.datasets.gpt2_mask_tokenize_fn + eval_dataloader: + dataset: + tokenize_fn: + _target_: src.product_keys.datasets.gpt2_mask_tokenize_fn infrastructure: metric_logger: + type: wandb + wandb_entity: ideas_cv name: top_k_attention - project_name: pmtest/tml-bgw + project_name: tml-bgw tags: - nano - top_k_attention - - "lr=${trainer.learning_rate}" - "seq_len=${common.sequence_length}" + - "n_layers=${common.n_blocks}" + - "dmodel=${common.dmodel}" slurm: - gres: gpu:1 + gres: gpu:2 time: "1-00:00:00" job-name: ${infrastructure.metric_logger.name} @@ -57,6 +67,7 @@ model: q_heads: ${common.q_heads} kv_heads: ${common.kv_heads} seq_len: ${common.sequence_length} + causal: false q_proj_fn: _target_: src.projected_compression.model.Linear diff --git a/src/core/model.py b/src/core/model.py index 0d20118e..9401b128 100644 --- a/src/core/model.py +++ b/src/core/model.py @@ -380,6 +380,7 @@ def __init__( seq_len, rope_base, rope_scale_freqs: bool, + causal=True, ): super().__init__() self.q_proj = q_proj_fn() @@ -392,6 +393,7 @@ def __init__( self.kv_heads = kv_heads self.dhead = self.q_proj.weight.shape[0] // self.q_heads self.dmodel = dmodel + self.causal = causal self.rope = RoPE( dhead=self.dhead, @@ -418,7 +420,7 @@ def forward(self, x): k = repeat_kv(k, self.q_heads // self.kv_heads) v = repeat_kv(v, self.q_heads // self.kv_heads) attention_output = self.attention_mechanism( - query=q, key=k, value=v, causal=True + query=q, key=k, value=v, causal=self.causal ) output = self.o_proj(attention_output.transpose(1, 2).contiguous().flatten(-2)) diff --git a/src/product_keys/model.py b/src/product_keys/model.py index 292de55a..a460ee7a 100644 --- a/src/product_keys/model.py +++ b/src/product_keys/model.py @@ -6,7 +6,7 @@ from torch.nn.init import trunc_normal_ import logging -from src.core.model import AttentionMechanism, RoPE +from src.core.model import AttentionMechanism, Residual, RoPE logger = logging.getLogger(__name__) @@ -20,6 +20,44 @@ def deterministic_weight_init(fan_in, scale): return partial(trunc_normal_, mean=0.0, std=std, a=low, b=high, generator=generator) +class HybridTransformerBlock(nn.Module): + def __init__( + self, + block_id: int, + norm_fn, + attention_fn, + ff_layer_fn, + pkm_layer_fn, + pkm_indices: list[int], + ): + super().__init__() + self.log_name = f"block[{block_id}]" + + self.attention_layer = Residual( + norm=norm_fn(), + layer=attention_fn(), + log_name=f"{self.log_name}/residual_attention", + ) + + if block_id in pkm_indices: + selected_layer = pkm_layer_fn() + layer_type_name = "pkm" + else: + selected_layer = ff_layer_fn() + layer_type_name = "feedforward" + + self.ff_layer = Residual( + norm=norm_fn(), + layer=selected_layer, + log_name=f"{self.log_name}/residual_{layer_type_name}", + ) + + def forward(self, x): + x = self.attention_layer(x) + x = self.ff_layer(x) + return x + + class RoPETopKAttention(nn.Module): def __init__( self, @@ -35,6 +73,7 @@ def __init__( rope_scale_freqs: bool, top_k: int, top_k_before_softmax: bool = True, + causal: bool = True, ): super().__init__() self.q_proj = q_proj_fn() @@ -50,6 +89,8 @@ def __init__( self.top_k = top_k self.top_k_before_softmax = top_k_before_softmax + + self.causal = causal self.rope = RoPE( dhead=self.dhead, @@ -85,17 +126,19 @@ def forward(self, x): # standard attention if seq_len is smaller or equal top_k if seq_len <= self.top_k: attention_output = self.attention_mechanism( - query=q, key=k, value=v, causal=True + query=q, key=k, value=v, causal=self.causal ) return self.o_proj( attention_output.transpose(1, 2).contiguous().flatten(-2) ) attention_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.dhead) - causal_mask = torch.triu( - torch.ones(seq_len, seq_len, device=attention_scores.device), diagonal=1 - ).bool() - attention_scores = attention_scores.masked_fill(causal_mask, float("-inf")) + + if self.causal: + causal_mask = torch.triu( + torch.ones(seq_len, seq_len, device=attention_scores.device), diagonal=1 + ).bool() + attention_scores = attention_scores.masked_fill(causal_mask, float("-inf")) if self.top_k_before_softmax: attention_scores = self.__apply_topk_mask( diff --git a/src/projected_compression/model.py b/src/projected_compression/model.py index 4b1371db..527a4f04 100644 --- a/src/projected_compression/model.py +++ b/src/projected_compression/model.py @@ -239,6 +239,7 @@ def __init__( low_freq_factor=1, high_freq_factor=4, original_max_position_embeddings=8192, + causal=True, ): super().__init__() self.q_proj = q_proj_fn() @@ -251,6 +252,7 @@ def __init__( self.kv_heads = kv_heads self.dhead = self.q_proj.weight.shape[0] // self.q_heads self.dmodel = dmodel + self.causal = causal self.rope = RoPE( dhead=self.dhead, @@ -279,7 +281,7 @@ def forward(self, x): k = repeat_kv(k, self.q_heads // self.kv_heads) v = repeat_kv(v, self.q_heads // self.kv_heads) attention_output = self.attention_mechanism( - query=q, key=k, value=v, causal=True + query=q, key=k, value=v, causal=self.causal ) output = self.o_proj(attention_output.transpose(1, 2).contiguous().flatten(-2)) From 1f93892b34eae647b190a717de96190404d1b8ef Mon Sep 17 00:00:00 2001 From: Wojciech Weremczuk Date: Wed, 18 Feb 2026 10:39:22 +0100 Subject: [PATCH 4/9] Fix --- src/core/model.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/core/model.py b/src/core/model.py index 9401b128..0d20118e 100644 --- a/src/core/model.py +++ b/src/core/model.py @@ -380,7 +380,6 @@ def __init__( seq_len, rope_base, rope_scale_freqs: bool, - causal=True, ): super().__init__() self.q_proj = q_proj_fn() @@ -393,7 +392,6 @@ def __init__( self.kv_heads = kv_heads self.dhead = self.q_proj.weight.shape[0] // self.q_heads self.dmodel = dmodel - self.causal = causal self.rope = RoPE( dhead=self.dhead, @@ -420,7 +418,7 @@ def forward(self, x): k = repeat_kv(k, self.q_heads // self.kv_heads) v = repeat_kv(v, self.q_heads // self.kv_heads) attention_output = self.attention_mechanism( - query=q, key=k, value=v, causal=self.causal + query=q, key=k, value=v, causal=True ) output = self.o_proj(attention_output.transpose(1, 2).contiguous().flatten(-2)) From a75519ceff87282b554fce54fe2f0e6c19d3f20b Mon Sep 17 00:00:00 2001 From: Wojciech Weremczuk Date: Wed, 18 Feb 2026 10:47:21 +0100 Subject: [PATCH 5/9] Fixed formatting --- src/product_keys/model.py | 34 ++++++++++++++++++---------------- 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/src/product_keys/model.py b/src/product_keys/model.py index a460ee7a..da02ac85 100644 --- a/src/product_keys/model.py +++ b/src/product_keys/model.py @@ -38,7 +38,7 @@ def __init__( layer=attention_fn(), log_name=f"{self.log_name}/residual_attention", ) - + if block_id in pkm_indices: selected_layer = pkm_layer_fn() layer_type_name = "pkm" @@ -89,7 +89,7 @@ def __init__( self.top_k = top_k self.top_k_before_softmax = top_k_before_softmax - + self.causal = causal self.rope = RoPE( @@ -133,7 +133,7 @@ def forward(self, x): ) attention_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.dhead) - + if self.causal: causal_mask = torch.triu( torch.ones(seq_len, seq_len, device=attention_scores.device), diagonal=1 @@ -309,11 +309,11 @@ def forward(self, x): class ProductKeysMemory(nn.Module): def __init__( - self, - d_model: int, - query_dim: int, - n_sub_keys: int, - k_neighbors: int, + self, + d_model: int, + query_dim: int, + n_sub_keys: int, + k_neighbors: int, n_heads: int = 4, **kwargs, # To ignore unused args ): @@ -344,10 +344,10 @@ def _get_knn(self, queries, codebooks): """ # queries: (batch, head, sub_dim) # codebooks: (head, n_sub_keys, sub_dim) - + # Calculate similarity (dot product) scores = torch.einsum("bhd,hnd->bhn", queries, codebooks) - + # Select top-k top_scores, top_indices = torch.topk(scores, k=self.k, dim=-1, largest=True) return top_scores, top_indices @@ -375,7 +375,7 @@ def forward(self, x): # Flatten the KxK grid to K^2 to find the global top-k all_scores_flat = all_scores.view(bs * seq_len, self.n_heads, -1) - + # Select the best combinations (global top-k) global_scores, global_top_indices = torch.topk(all_scores_flat, self.k, dim=-1) @@ -392,17 +392,19 @@ def forward(self, x): memory_indices = real_idx1 * self.n_sub_keys + real_idx2 # 5. Read from Memory - attn_weights = F.softmax(global_scores, dim=-1) # (BS, H, K) + attn_weights = F.softmax(global_scores, dim=-1) # (BS, H, K) flat_indices = memory_indices.view(-1) - values_selected = self.values(flat_indices) - values_selected = values_selected.view(bs * seq_len, self.n_heads, self.k, d_model) + values_selected = self.values(flat_indices) + values_selected = values_selected.view( + bs * seq_len, self.n_heads, self.k, d_model + ) # Weighted sum of retrieved values out_heads = (values_selected * attn_weights.unsqueeze(-1)).sum(dim=2) # 6. Aggregation # Sum outputs across all heads - output = out_heads.sum(dim=1) # (BS, d_model) - + output = out_heads.sum(dim=1) # (BS, d_model) + return output.view(bs, seq_len, d_model) From 8e810768267b2dc83ab9706b4b3eaa61afa80dc4 Mon Sep 17 00:00:00 2001 From: Wojciech Weremczuk Date: Mon, 23 Feb 2026 22:48:32 +0100 Subject: [PATCH 6/9] Review fixes, added optimizer_param_groups for lr --- configs/product_keys/pkm.yaml | 8 +++++- main.py | 45 ++++++++++++++++++++++++++++++++-- src/product_keys/model.py | 46 +++++++++++++++++++++++------------ 3 files changed, 81 insertions(+), 18 deletions(-) diff --git a/configs/product_keys/pkm.yaml b/configs/product_keys/pkm.yaml index 519ef158..747d1fde 100644 --- a/configs/product_keys/pkm.yaml +++ b/configs/product_keys/pkm.yaml @@ -21,6 +21,7 @@ common: # pkm_n_sub_keys: 256 # 256^2 = 65,536 memory slots pkm_n_sub_keys: 512 # 512^2 = 262,144 memory slots + # pkm_n_sub_keys: 1024 # 1024^2 = 1,048,576 memory slots pkm_k: 32 pkm_query_dim: 512 pkm_n_heads: 4 @@ -33,6 +34,10 @@ trainer: gradient_accumulation_steps: 2 n_steps: 77050 ^learning_rate: [1e-4, 2e-4, 5e-4, 1e-3, 2e-3, 5e-3] + # optimizer_param_groups: + # - regex: ".*pkm_layer_fn.*values.*" + # # ^lr: [1e-4, 2e-4, 5e-4, 1e-3, 2e-3, 5e-3] # Sweep different learning rates for PKM values + # lr: 1e-3 train_dataloader: dataset: tokenize_fn: @@ -78,7 +83,8 @@ model: block_fn: _target_: src.product_keys.model.HybridTransformerBlock _partial_: true - pkm_indices: [7, 14] # Layers 8 and 15 use PKM + # pkm_indices: [7, 14] # Layers 8 and 15 use PKM + ^pkm_indices: [[7, 14], [3, 7, 11, 15], [1, 3, 5, 7, 9, 11, 13, 15]] # Sweeping different numbers of PKM layers norm_fn: _target_: src.core.model.RMSNorm diff --git a/main.py b/main.py index be7213a9..a0040e7e 100644 --- a/main.py +++ b/main.py @@ -1,3 +1,4 @@ +import re import os import hydra import yaml @@ -197,9 +198,49 @@ def get_model_optimizer_scheduler(cfg, model, learning_rate): logger.info("Initialization failed, exiting...") return None, None, None model = setup_distributed_training(model, cfg.trainer.distributed) + + optimizer_groups = [] + # If optimizer_param_groups is defined in config, use generic regex-based grouping + if hasattr(cfg.trainer, "optimizer_param_groups") and cfg.trainer.optimizer_param_groups: + assigned_param_ids = set() + + for group_cfg in cfg.trainer.optimizer_param_groups: + group_regex = group_cfg.regex + group_lr = group_cfg.get("lr", learning_rate) + group_params = [] + group_matches = [] + + for name, param in model.named_parameters(): + if id(param) in assigned_param_ids: + continue + + if re.search(group_regex, name): + group_params.append(param) + assigned_param_ids.add(id(param)) + group_matches.append(name) + + if group_params: + logger.info(f"Optimizer group regex='{group_regex}' lr={group_lr} matched {len(group_params)} params: {group_matches}") + optimizer_groups.append({"params": group_params, "lr": group_lr}) + else: + logger.warning(f"Optimizer group regex='{group_regex}' matched no parameters.") + + # Add remaining parameters to the default group + default_params = [] + for name, param in model.named_parameters(): + if id(param) not in assigned_param_ids: + default_params.append(param) + + if default_params: + optimizer_groups.append({"params": default_params, "lr": learning_rate}) + logger.info(f"Default optimizer group lr={learning_rate} contains {len(default_params)} remaining params.") + + else: + # Fallback to simple default group (or previous hardcoded logic if we wanted to keep it, but user asked to generalize) + optimizer_groups = [{"params": model.parameters(), "lr": learning_rate}] + optimizer = torch.optim.AdamW( - model.parameters(), - lr=learning_rate, + optimizer_groups, weight_decay=cfg.trainer.weight_decay, ) scheduler = instantiate(cfg.trainer.scheduler)( diff --git a/src/product_keys/model.py b/src/product_keys/model.py index da02ac85..6cf9e0fc 100644 --- a/src/product_keys/model.py +++ b/src/product_keys/model.py @@ -370,7 +370,7 @@ def forward(self, x): # 3. Cartesian Product of Scores # Sum every score from the first half with every score from the second half - # (BS, H, K, 1) + (BS, H, 1, K) -> (BS, H, K, K) + # (BS*Seq, H, K, 1) + (BS*Seq, H, 1, K) -> (BS*Seq, H, K, K) all_scores = scores1.unsqueeze(3) + scores2.unsqueeze(2) # Flatten the KxK grid to K^2 to find the global top-k @@ -392,19 +392,35 @@ def forward(self, x): memory_indices = real_idx1 * self.n_sub_keys + real_idx2 # 5. Read from Memory - attn_weights = F.softmax(global_scores, dim=-1) # (BS, H, K) - - flat_indices = memory_indices.view(-1) - values_selected = self.values(flat_indices) - values_selected = values_selected.view( - bs * seq_len, self.n_heads, self.k, d_model - ) - - # Weighted sum of retrieved values - out_heads = (values_selected * attn_weights.unsqueeze(-1)).sum(dim=2) + attn_weights = F.softmax(global_scores, dim=-1) # (BS*Seq, H, K) + + # Flatten indices and weights to the format expected by embedding_bag + # The "bag" dimension is (BS * Seq * Heads), with K elements in each bag + flat_indices = memory_indices.view(-1, self.k) + flat_weights = attn_weights.view(-1, self.k) + + # Fused Lookup + Weighted Sum + # We avoid creating the massive (BS*Seq, H, K, d_model) tensor by using embedding_bag + is_bfloat16 = flat_weights.dtype == torch.bfloat16 + if is_bfloat16: + flat_weights_fp32 = flat_weights.to(torch.float32) + values_weight_fp32 = self.values.weight.to(torch.float32) + out_flat = F.embedding_bag( + input=flat_indices, + weight=values_weight_fp32, + per_sample_weights=flat_weights_fp32, + mode='sum' + ) + out_flat = out_flat.to(torch.bfloat16) + else: + out_flat = F.embedding_bag( + input=flat_indices, + weight=self.values.weight, + per_sample_weights=flat_weights, + mode='sum' + ) # 6. Aggregation - # Sum outputs across all heads - output = out_heads.sum(dim=1) # (BS, d_model) - - return output.view(bs, seq_len, d_model) + # Restore the correct dimensions: (BS*Seq, H, d_model) -> (BS, Seq, H, d_model) + out_flat = out_flat.view(bs, seq_len, self.n_heads, d_model) + return out_flat.sum(dim=2) # Output: (BS, Seq, d_model) From fe79d21d3dab74f3c211c8ff2b2100f713b0f276 Mon Sep 17 00:00:00 2001 From: Wojciech Weremczuk Date: Tue, 24 Feb 2026 21:32:25 +0100 Subject: [PATCH 7/9] Fixes and formatting --- configs/product_keys/baseline.yaml | 5 +++-- configs/product_keys/pkm.yaml | 32 ++++++++++++++++++------------ main.py | 25 +++++++++++++++-------- src/product_keys/model.py | 10 ++++++---- 4 files changed, 45 insertions(+), 27 deletions(-) diff --git a/configs/product_keys/baseline.yaml b/configs/product_keys/baseline.yaml index 397185b3..769435cb 100644 --- a/configs/product_keys/baseline.yaml +++ b/configs/product_keys/baseline.yaml @@ -26,7 +26,8 @@ trainer: unmaskable_special_tokens: [50256, 50257] # <|endoftext|> gradient_accumulation_steps: 2 n_steps: 77050 - ^learning_rate: [1e-4, 2e-4, 5e-4, 1e-3, 2e-3, 5e-3] + # ^learning_rate: [1e-4, 2e-4, 5e-4, 1e-3, 2e-3, 5e-3] + learning_rate: 1e-3 train_dataloader: dataset: tokenize_fn: @@ -44,8 +45,8 @@ infrastructure: name: baseline project_name: tml-bgw tags: - - nano - baseline + - mlm - "seq_len=${common.sequence_length}" - "n_layers=${common.n_blocks}" - "dmodel=${common.dmodel}" diff --git a/configs/product_keys/pkm.yaml b/configs/product_keys/pkm.yaml index 747d1fde..bd3d4e37 100644 --- a/configs/product_keys/pkm.yaml +++ b/configs/product_keys/pkm.yaml @@ -19,9 +19,12 @@ common: kv_heads: 16 vocab_size: 50304 # GPT-2 vocab + # pkm_n_sub_keys: 128 # 128^2 = 16,384 memory slots # pkm_n_sub_keys: 256 # 256^2 = 65,536 memory slots - pkm_n_sub_keys: 512 # 512^2 = 262,144 memory slots - # pkm_n_sub_keys: 1024 # 1024^2 = 1,048,576 memory slots + # pkm_n_sub_keys: 384 # 384^2 = 147,456 memory slots + # pkm_n_sub_keys: 512 # 512^2 = 262,144 memory slots + # pkm_n_sub_keys: 768 # 768^2 = 589,824 memory slots + pkm_n_sub_keys: 1024 # 1024^2 = 1,048,576 memory slots pkm_k: 32 pkm_query_dim: 512 pkm_n_heads: 4 @@ -33,11 +36,11 @@ trainer: unmaskable_special_tokens: [50256, 50257] # <|endoftext|> gradient_accumulation_steps: 2 n_steps: 77050 - ^learning_rate: [1e-4, 2e-4, 5e-4, 1e-3, 2e-3, 5e-3] - # optimizer_param_groups: - # - regex: ".*pkm_layer_fn.*values.*" - # # ^lr: [1e-4, 2e-4, 5e-4, 1e-3, 2e-3, 5e-3] # Sweep different learning rates for PKM values - # lr: 1e-3 + # ^learning_rate: [1e-4, 2e-4, 5e-4, 1e-3, 2e-3, 5e-3] + learning_rate: 1e-3 + optimizer_param_groups: + - regex: ".*pkm_layer_fn.*values.*" + lr: 5e-3 train_dataloader: dataset: tokenize_fn: @@ -51,16 +54,18 @@ infrastructure: metric_logger: type: wandb wandb_entity: ideas_cv - name: pkm + name: "pkm_${common.pkm_n_sub_keys}_memory_lr_5e-3" project_name: tml-bgw tags: - - nano - pkm + - mlm + - "pkm_memory_lr=${trainer.optimizer_param_groups.0.lr}" + - "pkm_n_sub_keys=${common.pkm_n_sub_keys}" + - "pkm_k=${common.pkm_k}" + - "pkm_indices=${model.encoder.block_fn.pkm_indices}" - "seq_len=${common.sequence_length}" - "n_layers=${common.n_blocks}" - "dmodel=${common.dmodel}" - - "pkm_k=${common.pkm_k}" - - "pkm_n_sub_keys=${common.pkm_n_sub_keys}" slurm: gres: gpu:2 time: "1-00:00:00" @@ -83,8 +88,9 @@ model: block_fn: _target_: src.product_keys.model.HybridTransformerBlock _partial_: true - # pkm_indices: [7, 14] # Layers 8 and 15 use PKM - ^pkm_indices: [[7, 14], [3, 7, 11, 15], [1, 3, 5, 7, 9, 11, 13, 15]] # Sweeping different numbers of PKM layers + pkm_indices: [7, 14] # Layers 8 and 15 use PKM + # pkm_indices: [3, 7, 10, 14] + # pkm_indices: [2, 4, 7, 9, 12, 14] norm_fn: _target_: src.core.model.RMSNorm diff --git a/main.py b/main.py index a0040e7e..c76390c7 100644 --- a/main.py +++ b/main.py @@ -201,9 +201,12 @@ def get_model_optimizer_scheduler(cfg, model, learning_rate): optimizer_groups = [] # If optimizer_param_groups is defined in config, use generic regex-based grouping - if hasattr(cfg.trainer, "optimizer_param_groups") and cfg.trainer.optimizer_param_groups: + if ( + hasattr(cfg.trainer, "optimizer_param_groups") + and cfg.trainer.optimizer_param_groups + ): assigned_param_ids = set() - + for group_cfg in cfg.trainer.optimizer_param_groups: group_regex = group_cfg.regex group_lr = group_cfg.get("lr", learning_rate) @@ -213,27 +216,33 @@ def get_model_optimizer_scheduler(cfg, model, learning_rate): for name, param in model.named_parameters(): if id(param) in assigned_param_ids: continue - + if re.search(group_regex, name): group_params.append(param) assigned_param_ids.add(id(param)) group_matches.append(name) - + if group_params: - logger.info(f"Optimizer group regex='{group_regex}' lr={group_lr} matched {len(group_params)} params: {group_matches}") + logger.info( + f"Optimizer group regex='{group_regex}' lr={group_lr} matched {len(group_params)} params: {group_matches}" + ) optimizer_groups.append({"params": group_params, "lr": group_lr}) else: - logger.warning(f"Optimizer group regex='{group_regex}' matched no parameters.") + logger.warning( + f"Optimizer group regex='{group_regex}' matched no parameters." + ) # Add remaining parameters to the default group default_params = [] for name, param in model.named_parameters(): if id(param) not in assigned_param_ids: default_params.append(param) - + if default_params: optimizer_groups.append({"params": default_params, "lr": learning_rate}) - logger.info(f"Default optimizer group lr={learning_rate} contains {len(default_params)} remaining params.") + logger.info( + f"Default optimizer group lr={learning_rate} contains {len(default_params)} remaining params." + ) else: # Fallback to simple default group (or previous hardcoded logic if we wanted to keep it, but user asked to generalize) diff --git a/src/product_keys/model.py b/src/product_keys/model.py index 6cf9e0fc..e6715964 100644 --- a/src/product_keys/model.py +++ b/src/product_keys/model.py @@ -330,8 +330,10 @@ def __init__( # Sub-Keys (Codebooks) # Two separate sets of keys for the product quantization - self.c1 = nn.Parameter(torch.randn(n_heads, n_sub_keys, query_dim // 2)) - self.c2 = nn.Parameter(torch.randn(n_heads, n_sub_keys, query_dim // 2)) + self.c1 = nn.Parameter(torch.empty(n_heads, n_sub_keys, query_dim // 2)) + self.c2 = nn.Parameter(torch.empty(n_heads, n_sub_keys, query_dim // 2)) + nn.init.normal_(self.c1, mean=0, std=d_model**-0.5) + nn.init.normal_(self.c2, mean=0, std=d_model**-0.5) # Memory Values # The actual values retrieved. Size is (n_sub_keys^2, d_model) @@ -409,7 +411,7 @@ def forward(self, x): input=flat_indices, weight=values_weight_fp32, per_sample_weights=flat_weights_fp32, - mode='sum' + mode="sum", ) out_flat = out_flat.to(torch.bfloat16) else: @@ -417,7 +419,7 @@ def forward(self, x): input=flat_indices, weight=self.values.weight, per_sample_weights=flat_weights, - mode='sum' + mode="sum", ) # 6. Aggregation From c4cc2d9e96c95f0fc4ec4ace0ae4880b74466602 Mon Sep 17 00:00:00 2001 From: Wojciech Weremczuk Date: Mon, 2 Mar 2026 18:35:55 +0100 Subject: [PATCH 8/9] Optimizer param groups --- configs/product_keys/pkm.yaml | 23 +++++++++++----------- main.py | 35 ++++++++++++++++++++++++++++++++-- src/core/schedulers.py | 36 ++++++++++++++++++++++------------- src/product_keys/model.py | 1 - 4 files changed, 68 insertions(+), 27 deletions(-) diff --git a/configs/product_keys/pkm.yaml b/configs/product_keys/pkm.yaml index bd3d4e37..bb87afab 100644 --- a/configs/product_keys/pkm.yaml +++ b/configs/product_keys/pkm.yaml @@ -21,13 +21,14 @@ common: # pkm_n_sub_keys: 128 # 128^2 = 16,384 memory slots # pkm_n_sub_keys: 256 # 256^2 = 65,536 memory slots - # pkm_n_sub_keys: 384 # 384^2 = 147,456 memory slots - # pkm_n_sub_keys: 512 # 512^2 = 262,144 memory slots + # pkm_n_sub_keys: 384 # 384^2 = 147,456 memory slots + pkm_n_sub_keys: 512 # 512^2 = 262,144 memory slots # pkm_n_sub_keys: 768 # 768^2 = 589,824 memory slots - pkm_n_sub_keys: 1024 # 1024^2 = 1,048,576 memory slots + # pkm_n_sub_keys: 1024 # 1024^2 = 1,048,576 memory slots pkm_k: 32 pkm_query_dim: 512 pkm_n_heads: 4 + pkm_indices: [7, 14] trainer: _target_: src.product_keys.trainer.MaskedLMTrainer @@ -36,11 +37,14 @@ trainer: unmaskable_special_tokens: [50256, 50257] # <|endoftext|> gradient_accumulation_steps: 2 n_steps: 77050 - # ^learning_rate: [1e-4, 2e-4, 5e-4, 1e-3, 2e-3, 5e-3] learning_rate: 1e-3 + ^learning_rate: [5e-4, 1e-3, 2e-3] optimizer_param_groups: - - regex: ".*pkm_layer_fn.*values.*" - lr: 5e-3 + - regex: ".*ff_layer.layer.values.*" + lr: "${eval:'1.0 * ${trainer.learning_rate}'}" + indices_filter: + layer_path: "encoder.blocks" + indices: ${common.pkm_indices} train_dataloader: dataset: tokenize_fn: @@ -54,12 +58,11 @@ infrastructure: metric_logger: type: wandb wandb_entity: ideas_cv - name: "pkm_${common.pkm_n_sub_keys}_memory_lr_5e-3" + name: "pkm" project_name: tml-bgw tags: - pkm - mlm - - "pkm_memory_lr=${trainer.optimizer_param_groups.0.lr}" - "pkm_n_sub_keys=${common.pkm_n_sub_keys}" - "pkm_k=${common.pkm_k}" - "pkm_indices=${model.encoder.block_fn.pkm_indices}" @@ -88,9 +91,7 @@ model: block_fn: _target_: src.product_keys.model.HybridTransformerBlock _partial_: true - pkm_indices: [7, 14] # Layers 8 and 15 use PKM - # pkm_indices: [3, 7, 10, 14] - # pkm_indices: [2, 4, 7, 9, 12, 14] + pkm_indices: ${common.pkm_indices} norm_fn: _target_: src.core.model.RMSNorm diff --git a/main.py b/main.py index c76390c7..065c97b0 100644 --- a/main.py +++ b/main.py @@ -206,6 +206,9 @@ def get_model_optimizer_scheduler(cfg, model, learning_rate): and cfg.trainer.optimizer_param_groups ): assigned_param_ids = set() + logger.info( + f"Model parameters: {[name for name, _ in model.named_parameters()]}" + ) for group_cfg in cfg.trainer.optimizer_param_groups: group_regex = group_cfg.regex @@ -213,11 +216,37 @@ def get_model_optimizer_scheduler(cfg, model, learning_rate): group_params = [] group_matches = [] + # Determine filter criteria if provided + target_indices = None + layer_path_prefix = "" + if "indices_filter" in group_cfg and group_cfg.indices_filter: + raw_indices = group_cfg.indices_filter.get("indices", []) + if OmegaConf.is_list(raw_indices) or isinstance( + raw_indices, (list, tuple) + ): + target_indices = set(raw_indices) + layer_path_prefix = group_cfg.indices_filter.get("layer_path", "") + for name, param in model.named_parameters(): if id(param) in assigned_param_ids: continue if re.search(group_regex, name): + # Apply indices filter logic if configured + if target_indices is not None and layer_path_prefix: + # Assumption: The layer index is the number immediately following the prefix in the parameter name + # Structure: {layer_path_prefix}.{INDEX}.{...} + prefix_esc = re.escape(layer_path_prefix) + # We use search to find prefix.INDEX. anywhere in the name + match_idx = re.search(rf"{prefix_esc}\.(\d+)\.", name) + + if not match_idx: + continue + + layer_idx = int(match_idx.group(1)) + if layer_idx not in target_indices: + continue + group_params.append(param) assigned_param_ids.add(id(param)) group_matches.append(name) @@ -245,7 +274,6 @@ def get_model_optimizer_scheduler(cfg, model, learning_rate): ) else: - # Fallback to simple default group (or previous hardcoded logic if we wanted to keep it, but user asked to generalize) optimizer_groups = [{"params": model.parameters(), "lr": learning_rate}] optimizer = torch.optim.AdamW( @@ -378,7 +406,10 @@ def run(cfg: OmegaConf, metric_logger=None): if model is not None: logger.info(f"Model initialized") - trainer = instantiate(cfg.trainer) + # exclude optimizer_param_groups from trainer kwargs to avoid error + trainer_args = OmegaConf.to_container(cfg.trainer, resolve=True) + trainer_args.pop("optimizer_param_groups", None) + trainer = instantiate(trainer_args) if "distillation" in cfg: if cfg.distillation.load.type == "huggingface": diff --git a/src/core/schedulers.py b/src/core/schedulers.py index d50f2609..75ab85b0 100644 --- a/src/core/schedulers.py +++ b/src/core/schedulers.py @@ -1,3 +1,4 @@ +import math import torch from torch.optim.lr_scheduler import SequentialLR, LinearLR, ConstantLR @@ -61,27 +62,36 @@ def load_state_dict(self, loaded_state): def get_cosine_scheduler_with_warmup( optimizer, warmup_steps: int, n_steps: int, final_lr_fraction: float ): - assert ( - len(optimizer.param_groups) == 1 - ), "Cosine scheduler only supports one param group" - optimizer_lr = optimizer.param_groups[0][ - "lr" - ] # param_groups changes when applying scheduler warmup = torch.optim.lr_scheduler.LinearLR( optimizer, start_factor=0.1, end_factor=1.0, total_iters=warmup_steps, ) - after_warmup_steps = n_steps - warmup_steps - 1 + # Cosine scheduler phase starts after warmup and 1 constant step + cosine_start_step = warmup_steps + 1 + T_max = n_steps - cosine_start_step + + def cosine_lambda(step): + # Calculate progress t within the cosine phase + if step < cosine_start_step: + return 1.0 + t = step - cosine_start_step + if t >= T_max: + return final_lr_fraction + # Decay from 1.0 to final_lr_fraction + return final_lr_fraction + 0.5 * (1 - final_lr_fraction) * ( + 1 + math.cos(math.pi * t / T_max) + ) + constant_scheduler = torch.optim.lr_scheduler.ConstantLR( - optimizer, factor=1.0 - ) # TODO this is only because of a bug in llm-random - cosine_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( - optimizer, - T_max=after_warmup_steps, - eta_min=final_lr_fraction * optimizer_lr, + optimizer, factor=1.0, total_iters=1 + ) + # Use LambdaLR to allow different base/min LRs per parameter group (proportional decay) + cosine_scheduler = torch.optim.lr_scheduler.LambdaLR( + optimizer, lr_lambda=cosine_lambda ) + training_scheduler = torch.optim.lr_scheduler.SequentialLR( optimizer, schedulers=[warmup, constant_scheduler, cosine_scheduler], diff --git a/src/product_keys/model.py b/src/product_keys/model.py index e6715964..cd373f6e 100644 --- a/src/product_keys/model.py +++ b/src/product_keys/model.py @@ -315,7 +315,6 @@ def __init__( n_sub_keys: int, k_neighbors: int, n_heads: int = 4, - **kwargs, # To ignore unused args ): super().__init__() self.n_heads = n_heads From b2e625c5373270109042e6f54f5676c8e026fa51 Mon Sep 17 00:00:00 2001 From: Wojciech Weremczuk Date: Tue, 21 Apr 2026 22:09:34 +0200 Subject: [PATCH 9/9] Test PK in FF on CLM task --- configs/product_keys/baseline.yaml | 45 +++++++++++++----------- configs/product_keys/pkm.yaml | 55 ++++++++++++++++++------------ 2 files changed, 59 insertions(+), 41 deletions(-) diff --git a/configs/product_keys/baseline.yaml b/configs/product_keys/baseline.yaml index 769435cb..27cc1919 100644 --- a/configs/product_keys/baseline.yaml +++ b/configs/product_keys/baseline.yaml @@ -17,36 +17,41 @@ common: n_blocks: 16 q_heads: 16 kv_heads: 16 - vocab_size: 50304 + # vocab_size: 50304 + vocab_size: 128256 + +# trainer: +# _target_: src.product_keys.trainer.MaskedLMTrainer +# masking_percentage: 0.2 +# mask_token_id: 50257 +# unmaskable_special_tokens: [50256, 50257] # <|endoftext|> +# gradient_accumulation_steps: 2 +# n_steps: 77050 +# # ^learning_rate: [1e-4, 2e-4, 5e-4, 1e-3, 2e-3, 5e-3] +# learning_rate: 1e-3 +# train_dataloader: +# dataset: +# tokenize_fn: +# _target_: src.product_keys.datasets.gpt2_mask_tokenize_fn +# eval_dataloader: +# dataset: +# tokenize_fn: +# _target_: src.product_keys.datasets.gpt2_mask_tokenize_fn trainer: - _target_: src.product_keys.trainer.MaskedLMTrainer - masking_percentage: 0.2 - mask_token_id: 50257 - unmaskable_special_tokens: [50256, 50257] # <|endoftext|> - gradient_accumulation_steps: 2 n_steps: 77050 - # ^learning_rate: [1e-4, 2e-4, 5e-4, 1e-3, 2e-3, 5e-3] - learning_rate: 1e-3 - train_dataloader: - dataset: - tokenize_fn: - _target_: src.product_keys.datasets.gpt2_mask_tokenize_fn - eval_dataloader: - dataset: - tokenize_fn: - _target_: src.product_keys.datasets.gpt2_mask_tokenize_fn - + ^learning_rate: [5e-4, 1e-3, 2e-3] + # learning_rate: 1e-3 infrastructure: metric_logger: type: wandb wandb_entity: ideas_cv - name: baseline + name: baseline_causal project_name: tml-bgw tags: - baseline - - mlm + - clm - "seq_len=${common.sequence_length}" - "n_layers=${common.n_blocks}" - "dmodel=${common.dmodel}" @@ -85,7 +90,7 @@ model: q_heads: ${common.q_heads} kv_heads: ${common.kv_heads} seq_len: ${common.sequence_length} - causal: false + causal: true q_proj_fn: _target_: src.projected_compression.model.Linear diff --git a/configs/product_keys/pkm.yaml b/configs/product_keys/pkm.yaml index bb87afab..e4b80c01 100644 --- a/configs/product_keys/pkm.yaml +++ b/configs/product_keys/pkm.yaml @@ -1,6 +1,6 @@ # @package _global_ defaults: - - /_cluster/helios@_here_ + - /_cluster/entropy@_here_ - /_model/tiny@_here_ - /_trainer/llama@_here_ - /_dataset/c4@_here_ @@ -17,25 +17,46 @@ common: n_blocks: 16 q_heads: 16 kv_heads: 16 - vocab_size: 50304 # GPT-2 vocab + # vocab_size: 50304 # GPT-2 vocab + vocab_size: 128256 # pkm_n_sub_keys: 128 # 128^2 = 16,384 memory slots - # pkm_n_sub_keys: 256 # 256^2 = 65,536 memory slots + pkm_n_sub_keys: 256 # 256^2 = 65,536 memory slots # pkm_n_sub_keys: 384 # 384^2 = 147,456 memory slots - pkm_n_sub_keys: 512 # 512^2 = 262,144 memory slots + # pkm_n_sub_keys: 512 # 512^2 = 262,144 memory slots # pkm_n_sub_keys: 768 # 768^2 = 589,824 memory slots # pkm_n_sub_keys: 1024 # 1024^2 = 1,048,576 memory slots pkm_k: 32 - pkm_query_dim: 512 + # pkm_query_dim: 512 + pkm_query_dim: 1024 pkm_n_heads: 4 pkm_indices: [7, 14] +# trainer: +# _target_: src.product_keys.trainer.MaskedLMTrainer +# masking_percentage: 0.2 +# mask_token_id: 50257 +# unmaskable_special_tokens: [50256, 50257] # <|endoftext|> +# gradient_accumulation_steps: 2 +# n_steps: 77050 +# learning_rate: 1e-3 +# ^learning_rate: [5e-4, 1e-3, 2e-3] +# optimizer_param_groups: +# - regex: ".*ff_layer.layer.values.*" +# lr: "${eval:'1.0 * ${trainer.learning_rate}'}" +# indices_filter: +# layer_path: "encoder.blocks" +# indices: ${common.pkm_indices} +# train_dataloader: +# dataset: +# tokenize_fn: +# _target_: src.product_keys.datasets.gpt2_mask_tokenize_fn +# eval_dataloader: +# dataset: +# tokenize_fn: +# _target_: src.product_keys.datasets.gpt2_mask_tokenize_fn + trainer: - _target_: src.product_keys.trainer.MaskedLMTrainer - masking_percentage: 0.2 - mask_token_id: 50257 - unmaskable_special_tokens: [50256, 50257] # <|endoftext|> - gradient_accumulation_steps: 2 n_steps: 77050 learning_rate: 1e-3 ^learning_rate: [5e-4, 1e-3, 2e-3] @@ -45,24 +66,16 @@ trainer: indices_filter: layer_path: "encoder.blocks" indices: ${common.pkm_indices} - train_dataloader: - dataset: - tokenize_fn: - _target_: src.product_keys.datasets.gpt2_mask_tokenize_fn - eval_dataloader: - dataset: - tokenize_fn: - _target_: src.product_keys.datasets.gpt2_mask_tokenize_fn infrastructure: metric_logger: type: wandb wandb_entity: ideas_cv - name: "pkm" + name: pkm_causal_256_query_dim_1024 project_name: tml-bgw tags: - pkm - - mlm + - clm - "pkm_n_sub_keys=${common.pkm_n_sub_keys}" - "pkm_k=${common.pkm_k}" - "pkm_indices=${model.encoder.block_fn.pkm_indices}" @@ -106,7 +119,7 @@ model: q_heads: ${common.q_heads} kv_heads: ${common.kv_heads} seq_len: ${common.sequence_length} - causal: false + causal: true q_proj_fn: _target_: src.projected_compression.model.Linear