Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 39 additions & 64 deletions tunix/models/gemma/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -543,39 +543,37 @@ class FeedForward(nnx.Module):

def __init__(
self,
features: int,
hidden_dim: int,
config: ModelConfig,
*,
rngs: nnx.Rngs,
shd_config: ShardingConfig = ShardingConfig.get_default_sharding(),
):
self.shd_config = shd_config
self.config = config
kernel_init_fn = nnx.initializers.zeros_init()
self.gate_proj = nnx.Linear(
in_features=features,
out_features=hidden_dim,
in_features=config.embed_dim,
out_features=config.hidden_dim,
use_bias=False,
rngs=rngs,
kernel_init=nnx.with_partitioning(
kernel_init_fn, shd_config.ffw_weight_df
kernel_init_fn, config.shd_config.ffw_weight_df
),
)
self.up_proj = nnx.Linear(
in_features=features,
out_features=hidden_dim,
in_features=config.embed_dim,
out_features=config.hidden_dim,
use_bias=False,
rngs=rngs,
kernel_init=nnx.with_partitioning(
kernel_init_fn, shd_config.ffw_weight_df
kernel_init_fn, config.shd_config.ffw_weight_df
),
)
self.down_proj = nnx.Linear(
in_features=hidden_dim,
out_features=features,
in_features=config.hidden_dim,
out_features=config.embed_dim,
use_bias=False,
rngs=rngs,
kernel_init=nnx.with_partitioning(
kernel_init_fn, shd_config.ffw_weight_fd
kernel_init_fn, config.shd_config.ffw_weight_fd
),
)

Expand All @@ -586,7 +584,7 @@ def __call__(self, x: jaxtyping.ArrayLike) -> jaxtyping.Array:

ff1 = self.up_proj(x)
activations = gate_value * ff1
activations = shard(activations, self.shd_config.act_btf)
activations = shard(activations, self.config.shd_config.act_btf)

outputs = self.down_proj(activations)
return outputs
Expand All @@ -597,48 +595,43 @@ class Block(nnx.Module):

def __init__(
self,
num_heads: int,
num_kv_heads: int,
embed_dim: int,
head_dim: int,
hidden_dim: int,
use_post_attn_norm: bool,
use_post_ffw_norm: bool,
attn_type: AttentionType,
config: ModelConfig,
*,
attn_type: AttentionType,
rngs: nnx.Rngs,
attn_logits_soft_cap: float | None,
sliding_window_size: int | None = None,
shd_config: ShardingConfig = ShardingConfig.get_default_sharding(),
remat_config: RematConfig = RematConfig.BLOCK,
):
self.config = config
self.pre_attention_norm = RMSNorm(
embed_dim, rngs=rngs, shd_config=shd_config
config.embed_dim, rngs=rngs, shd_config=config.shd_config
)
self.attn = Attention(
num_heads=num_heads,
num_kv_heads=num_kv_heads,
features=embed_dim,
head_dim=head_dim,
num_heads=config.num_heads,
num_kv_heads=config.num_kv_heads,
features=config.embed_dim,
head_dim=config.head_dim,
attn_type=attn_type,
attn_logits_soft_cap=attn_logits_soft_cap,
sliding_window_size=sliding_window_size,
attn_logits_soft_cap=config.attn_logits_soft_cap,
sliding_window_size=config.sliding_window_size,
rngs=rngs,
shd_config=shd_config,
remat_config=remat_config,
shd_config=config.shd_config,
remat_config=config.remat_config,
)
if use_post_attn_norm:
self.post_attn_norm = RMSNorm(embed_dim, rngs=rngs, shd_config=shd_config)
if config.use_post_attn_norm:
self.post_attn_norm = RMSNorm(
config.embed_dim, rngs=rngs, shd_config=config.shd_config
)

self.pre_ffw_norm = RMSNorm(embed_dim, rngs=rngs, shd_config=shd_config)
self.pre_ffw_norm = RMSNorm(
config.embed_dim, rngs=rngs, shd_config=config.shd_config
)
self.mlp = FeedForward(
features=embed_dim,
hidden_dim=hidden_dim,
config=config,
rngs=rngs,
shd_config=shd_config,
)
if use_post_ffw_norm:
self.post_ffw_norm = RMSNorm(embed_dim, rngs=rngs, shd_config=shd_config)
if config.use_post_ffw_norm:
self.post_ffw_norm = RMSNorm(
config.embed_dim, rngs=rngs, shd_config=config.shd_config
)

def __call__(
self,
Expand All @@ -655,28 +648,20 @@ def __call__(
attn_mask,
)

if self.use_post_attn_norm:
if self.config.use_post_attn_norm:
attn_output = self.post_attn_norm(attn_output)

attn_output += x

outputs = self.pre_ffw_norm(attn_output)
outputs = self.mlp(outputs)

if self.use_post_ffw_norm:
if self.config.use_post_ffw_norm:
outputs = self.post_ffw_norm(outputs)

outputs += attn_output
return cache, outputs

@property
def use_post_attn_norm(self):
return hasattr(self, 'post_attn_norm') and self.post_attn_norm is not None

@property
def use_post_ffw_norm(self):
return hasattr(self, 'post_ffw_norm') and self.post_ffw_norm is not None


class RMSNorm(nnx.Module):
"""RMSNorm layer."""
Expand Down Expand Up @@ -864,19 +849,9 @@ def __init__(
)
self.layers = compat.ModuleList([
Block(
num_heads=config.num_heads,
num_kv_heads=config.num_kv_heads,
embed_dim=config.embed_dim,
head_dim=config.head_dim,
hidden_dim=config.hidden_dim,
sliding_window_size=config.sliding_window_size,
use_post_attn_norm=config.use_post_attn_norm,
use_post_ffw_norm=config.use_post_ffw_norm,
attn_logits_soft_cap=config.attn_logits_soft_cap,
config=config,
attn_type=attn_type,
rngs=rngs,
shd_config=config.shd_config,
remat_config=config.remat_config,
)
for _, attn_type in zip(
range(config.num_layers), config.attention_types
Expand Down
Loading