Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 1 addition & 3 deletions src/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,6 @@ def __init__(
seq_len,
rope_base,
rope_scale_freqs: bool,
causal=True,
):
super().__init__()
self.q_proj = q_proj_fn()
Expand All @@ -393,7 +392,6 @@ def __init__(
self.kv_heads = kv_heads
self.dhead = self.q_proj.weight.shape[0] // self.q_heads
self.dmodel = dmodel
self.causal = causal

self.rope = RoPE(
dhead=self.dhead,
Expand All @@ -420,7 +418,7 @@ def forward(self, x):
k = repeat_kv(k, self.q_heads // self.kv_heads)
v = repeat_kv(v, self.q_heads // self.kv_heads)
attention_output = self.attention_mechanism(
query=q, key=k, value=v, causal=self.causal
query=q, key=k, value=v, causal=True
)

output = self.o_proj(attention_output.transpose(1, 2).contiguous().flatten(-2))
Expand Down
73 changes: 46 additions & 27 deletions src/product_keys/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(
layer=attention_fn(),
log_name=f"{self.log_name}/residual_attention",
)

if block_id in pkm_indices:
selected_layer = pkm_layer_fn()
layer_type_name = "pkm"
Expand Down Expand Up @@ -89,7 +89,7 @@ def __init__(

self.top_k = top_k
self.top_k_before_softmax = top_k_before_softmax

self.causal = causal

self.rope = RoPE(
Expand Down Expand Up @@ -133,7 +133,7 @@ def forward(self, x):
)

attention_scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.dhead)

if self.causal:
causal_mask = torch.triu(
torch.ones(seq_len, seq_len, device=attention_scores.device), diagonal=1
Expand Down Expand Up @@ -415,13 +415,12 @@ def forward(self, x):

class ProductKeysMemory(nn.Module):
def __init__(
self,
d_model: int,
query_dim: int,
n_sub_keys: int,
k_neighbors: int,
self,
d_model: int,
query_dim: int,
n_sub_keys: int,
k_neighbors: int,
n_heads: int = 4,
**kwargs, # To ignore unused args
):
super().__init__()
self.n_heads = n_heads
Expand All @@ -436,8 +435,10 @@ def __init__(

# Sub-Keys (Codebooks)
# Two separate sets of keys for the product quantization
self.c1 = nn.Parameter(torch.randn(n_heads, n_sub_keys, query_dim // 2))
self.c2 = nn.Parameter(torch.randn(n_heads, n_sub_keys, query_dim // 2))
self.c1 = nn.Parameter(torch.empty(n_heads, n_sub_keys, query_dim // 2))
self.c2 = nn.Parameter(torch.empty(n_heads, n_sub_keys, query_dim // 2))
nn.init.normal_(self.c1, mean=0, std=d_model**-0.5)
nn.init.normal_(self.c2, mean=0, std=d_model**-0.5)

# Memory Values
# The actual values retrieved. Size is (n_sub_keys^2, d_model)
Expand All @@ -450,10 +451,10 @@ def _get_knn(self, queries, codebooks):
"""
# queries: (batch, head, sub_dim)
# codebooks: (head, n_sub_keys, sub_dim)

# Calculate similarity (dot product)
scores = torch.einsum("bhd,hnd->bhn", queries, codebooks)

# Select top-k
top_scores, top_indices = torch.topk(scores, k=self.k, dim=-1, largest=True)
return top_scores, top_indices
Expand All @@ -476,12 +477,12 @@ def forward(self, x):

# 3. Cartesian Product of Scores
# Sum every score from the first half with every score from the second half
# (BS, H, K, 1) + (BS, H, 1, K) -> (BS, H, K, K)
# (BS*Seq, H, K, 1) + (BS*Seq, H, 1, K) -> (BS*Seq, H, K, K)
all_scores = scores1.unsqueeze(3) + scores2.unsqueeze(2)

# Flatten the KxK grid to K^2 to find the global top-k
all_scores_flat = all_scores.view(bs * seq_len, self.n_heads, -1)

# Select the best combinations (global top-k)
global_scores, global_top_indices = torch.topk(all_scores_flat, self.k, dim=-1)

Expand All @@ -498,17 +499,35 @@ def forward(self, x):
memory_indices = real_idx1 * self.n_sub_keys + real_idx2

# 5. Read from Memory
attn_weights = F.softmax(global_scores, dim=-1) # (BS, H, K)

flat_indices = memory_indices.view(-1)
values_selected = self.values(flat_indices)
values_selected = values_selected.view(bs * seq_len, self.n_heads, self.k, d_model)

# Weighted sum of retrieved values
out_heads = (values_selected * attn_weights.unsqueeze(-1)).sum(dim=2)
attn_weights = F.softmax(global_scores, dim=-1) # (BS*Seq, H, K)

# Flatten indices and weights to the format expected by embedding_bag
# The "bag" dimension is (BS * Seq * Heads), with K elements in each bag
flat_indices = memory_indices.view(-1, self.k)
flat_weights = attn_weights.view(-1, self.k)

# Fused Lookup + Weighted Sum
# We avoid creating the massive (BS*Seq, H, K, d_model) tensor by using embedding_bag
is_bfloat16 = flat_weights.dtype == torch.bfloat16
if is_bfloat16:
flat_weights_fp32 = flat_weights.to(torch.float32)
values_weight_fp32 = self.values.weight.to(torch.float32)
out_flat = F.embedding_bag(
input=flat_indices,
weight=values_weight_fp32,
per_sample_weights=flat_weights_fp32,
mode="sum",
)
out_flat = out_flat.to(torch.bfloat16)
else:
out_flat = F.embedding_bag(
input=flat_indices,
weight=self.values.weight,
per_sample_weights=flat_weights,
mode="sum",
)

# 6. Aggregation
# Sum outputs across all heads
output = out_heads.sum(dim=1) # (BS, d_model)

return output.view(bs, seq_len, d_model)
# Restore the correct dimensions: (BS*Seq, H, d_model) -> (BS, Seq, H, d_model)
out_flat = out_flat.view(bs, seq_len, self.n_heads, d_model)
return out_flat.sum(dim=2) # Output: (BS, Seq, d_model)