From 1edbae9cd3eb65558cb0bae402187508a5518cb7 Mon Sep 17 00:00:00 2001 From: Wojciech Weremczuk Date: Thu, 23 Apr 2026 11:30:13 +0200 Subject: [PATCH] Product keys in FF fixes --- src/core/model.py | 4 +-- src/product_keys/model.py | 73 ++++++++++++++++++++++++--------------- 2 files changed, 47 insertions(+), 30 deletions(-) diff --git a/src/core/model.py b/src/core/model.py index 9401b128..0d20118e 100644 --- a/src/core/model.py +++ b/src/core/model.py @@ -380,7 +380,6 @@ def __init__( seq_len, rope_base, rope_scale_freqs: bool, - causal=True, ): super().__init__() self.q_proj = q_proj_fn() @@ -393,7 +392,6 @@ def __init__( self.kv_heads = kv_heads self.dhead = self.q_proj.weight.shape[0] // self.q_heads self.dmodel = dmodel - self.causal = causal self.rope = RoPE( dhead=self.dhead, @@ -420,7 +418,7 @@ def forward(self, x): k = repeat_kv(k, self.q_heads // self.kv_heads) v = repeat_kv(v, self.q_heads // self.kv_heads) attention_output = self.attention_mechanism( - query=q, key=k, value=v, causal=self.causal + query=q, key=k, value=v, causal=True ) output = self.o_proj(attention_output.transpose(1, 2).contiguous().flatten(-2)) diff --git a/src/product_keys/model.py b/src/product_keys/model.py index 36a297b4..938f3213 100644 --- a/src/product_keys/model.py +++ b/src/product_keys/model.py @@ -38,7 +38,7 @@ def __init__( layer=attention_fn(), log_name=f"{self.log_name}/residual_attention", ) - + if block_id in pkm_indices: selected_layer = pkm_layer_fn() layer_type_name = "pkm" @@ -89,7 +89,7 @@ def __init__( self.top_k = top_k self.top_k_before_softmax = top_k_before_softmax - + self.causal = causal self.rope = RoPE( @@ -133,7 +133,7 @@ def forward(self, x): ) attention_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.dhead) - + if self.causal: causal_mask = torch.triu( torch.ones(seq_len, seq_len, device=attention_scores.device), diagonal=1 @@ -415,13 +415,12 @@ def forward(self, x): class ProductKeysMemory(nn.Module): def __init__( - self, - d_model: int, - query_dim: int, - n_sub_keys: int, - k_neighbors: int, + self, + d_model: int, + query_dim: int, + n_sub_keys: int, + k_neighbors: int, n_heads: int = 4, - **kwargs, # To ignore unused args ): super().__init__() self.n_heads = n_heads @@ -436,8 +435,10 @@ def __init__( # Sub-Keys (Codebooks) # Two separate sets of keys for the product quantization - self.c1 = nn.Parameter(torch.randn(n_heads, n_sub_keys, query_dim // 2)) - self.c2 = nn.Parameter(torch.randn(n_heads, n_sub_keys, query_dim // 2)) + self.c1 = nn.Parameter(torch.empty(n_heads, n_sub_keys, query_dim // 2)) + self.c2 = nn.Parameter(torch.empty(n_heads, n_sub_keys, query_dim // 2)) + nn.init.normal_(self.c1, mean=0, std=d_model**-0.5) + nn.init.normal_(self.c2, mean=0, std=d_model**-0.5) # Memory Values # The actual values retrieved. Size is (n_sub_keys^2, d_model) @@ -450,10 +451,10 @@ def _get_knn(self, queries, codebooks): """ # queries: (batch, head, sub_dim) # codebooks: (head, n_sub_keys, sub_dim) - + # Calculate similarity (dot product) scores = torch.einsum("bhd,hnd->bhn", queries, codebooks) - + # Select top-k top_scores, top_indices = torch.topk(scores, k=self.k, dim=-1, largest=True) return top_scores, top_indices @@ -476,12 +477,12 @@ def forward(self, x): # 3. Cartesian Product of Scores # Sum every score from the first half with every score from the second half - # (BS, H, K, 1) + (BS, H, 1, K) -> (BS, H, K, K) + # (BS*Seq, H, K, 1) + (BS*Seq, H, 1, K) -> (BS*Seq, H, K, K) all_scores = scores1.unsqueeze(3) + scores2.unsqueeze(2) # Flatten the KxK grid to K^2 to find the global top-k all_scores_flat = all_scores.view(bs * seq_len, self.n_heads, -1) - + # Select the best combinations (global top-k) global_scores, global_top_indices = torch.topk(all_scores_flat, self.k, dim=-1) @@ -498,17 +499,35 @@ def forward(self, x): memory_indices = real_idx1 * self.n_sub_keys + real_idx2 # 5. Read from Memory - attn_weights = F.softmax(global_scores, dim=-1) # (BS, H, K) - - flat_indices = memory_indices.view(-1) - values_selected = self.values(flat_indices) - values_selected = values_selected.view(bs * seq_len, self.n_heads, self.k, d_model) - - # Weighted sum of retrieved values - out_heads = (values_selected * attn_weights.unsqueeze(-1)).sum(dim=2) + attn_weights = F.softmax(global_scores, dim=-1) # (BS*Seq, H, K) + + # Flatten indices and weights to the format expected by embedding_bag + # The "bag" dimension is (BS * Seq * Heads), with K elements in each bag + flat_indices = memory_indices.view(-1, self.k) + flat_weights = attn_weights.view(-1, self.k) + + # Fused Lookup + Weighted Sum + # We avoid creating the massive (BS*Seq, H, K, d_model) tensor by using embedding_bag + is_bfloat16 = flat_weights.dtype == torch.bfloat16 + if is_bfloat16: + flat_weights_fp32 = flat_weights.to(torch.float32) + values_weight_fp32 = self.values.weight.to(torch.float32) + out_flat = F.embedding_bag( + input=flat_indices, + weight=values_weight_fp32, + per_sample_weights=flat_weights_fp32, + mode="sum", + ) + out_flat = out_flat.to(torch.bfloat16) + else: + out_flat = F.embedding_bag( + input=flat_indices, + weight=self.values.weight, + per_sample_weights=flat_weights, + mode="sum", + ) # 6. Aggregation - # Sum outputs across all heads - output = out_heads.sum(dim=1) # (BS, d_model) - - return output.view(bs, seq_len, d_model) + # Restore the correct dimensions: (BS*Seq, H, d_model) -> (BS, Seq, H, d_model) + out_flat = out_flat.view(bs, seq_len, self.n_heads, d_model) + return out_flat.sum(dim=2) # Output: (BS, Seq, d_model)