diff --git a/tunix/models/gemma/model.py b/tunix/models/gemma/model.py
index 4f8385f14..d58962221 100644
--- a/tunix/models/gemma/model.py
+++ b/tunix/models/gemma/model.py
@@ -543,39 +543,37 @@ class FeedForward(nnx.Module):

   def __init__(
       self,
-      features: int,
-      hidden_dim: int,
+      config: ModelConfig,
       *,
       rngs: nnx.Rngs,
-      shd_config: ShardingConfig = ShardingConfig.get_default_sharding(),
   ):
-    self.shd_config = shd_config
+    self.config = config
     kernel_init_fn = nnx.initializers.zeros_init()
     self.gate_proj = nnx.Linear(
-        in_features=features,
-        out_features=hidden_dim,
+        in_features=config.embed_dim,
+        out_features=config.hidden_dim,
         use_bias=False,
         rngs=rngs,
         kernel_init=nnx.with_partitioning(
-            kernel_init_fn, shd_config.ffw_weight_df
+            kernel_init_fn, config.shd_config.ffw_weight_df
         ),
     )
     self.up_proj = nnx.Linear(
-        in_features=features,
-        out_features=hidden_dim,
+        in_features=config.embed_dim,
+        out_features=config.hidden_dim,
         use_bias=False,
         rngs=rngs,
         kernel_init=nnx.with_partitioning(
-            kernel_init_fn, shd_config.ffw_weight_df
+            kernel_init_fn, config.shd_config.ffw_weight_df
         ),
     )
     self.down_proj = nnx.Linear(
-        in_features=hidden_dim,
-        out_features=features,
+        in_features=config.hidden_dim,
+        out_features=config.embed_dim,
         use_bias=False,
         rngs=rngs,
         kernel_init=nnx.with_partitioning(
-            kernel_init_fn, shd_config.ffw_weight_fd
+            kernel_init_fn, config.shd_config.ffw_weight_fd
         ),
     )
@@ -586,7 +584,7 @@ def __call__(self, x: jaxtyping.ArrayLike) -> jaxtyping.Array:
     ff1 = self.up_proj(x)
     activations = gate_value * ff1
-    activations = shard(activations, self.shd_config.act_btf)
+    activations = shard(activations, self.config.shd_config.act_btf)
     outputs = self.down_proj(activations)
     return outputs
@@ -597,48 +595,43 @@ class Block(nnx.Module):

   def __init__(
       self,
-      num_heads: int,
-      num_kv_heads: int,
-      embed_dim: int,
-      head_dim: int,
-      hidden_dim: int,
-      use_post_attn_norm: bool,
-      use_post_ffw_norm: bool,
-      attn_type: AttentionType,
+      config: ModelConfig,
       *,
+      attn_type: AttentionType,
       rngs: nnx.Rngs,
-      attn_logits_soft_cap: float | None,
-      sliding_window_size: int | None = None,
-      shd_config: ShardingConfig = ShardingConfig.get_default_sharding(),
-      remat_config: RematConfig = RematConfig.BLOCK,
   ):
+    self.config = config
     self.pre_attention_norm = RMSNorm(
-        embed_dim, rngs=rngs, shd_config=shd_config
+        config.embed_dim, rngs=rngs, shd_config=config.shd_config
     )
     self.attn = Attention(
-        num_heads=num_heads,
-        num_kv_heads=num_kv_heads,
-        features=embed_dim,
-        head_dim=head_dim,
+        num_heads=config.num_heads,
+        num_kv_heads=config.num_kv_heads,
+        features=config.embed_dim,
+        head_dim=config.head_dim,
         attn_type=attn_type,
-        attn_logits_soft_cap=attn_logits_soft_cap,
-        sliding_window_size=sliding_window_size,
+        attn_logits_soft_cap=config.attn_logits_soft_cap,
+        sliding_window_size=config.sliding_window_size,
         rngs=rngs,
-        shd_config=shd_config,
-        remat_config=remat_config,
+        shd_config=config.shd_config,
+        remat_config=config.remat_config,
     )
-    if use_post_attn_norm:
-      self.post_attn_norm = RMSNorm(embed_dim, rngs=rngs, shd_config=shd_config)
+    if config.use_post_attn_norm:
+      self.post_attn_norm = RMSNorm(
+          config.embed_dim, rngs=rngs, shd_config=config.shd_config
+      )
-    self.pre_ffw_norm = RMSNorm(embed_dim, rngs=rngs, shd_config=shd_config)
+    self.pre_ffw_norm = RMSNorm(
+        config.embed_dim, rngs=rngs, shd_config=config.shd_config
+    )
     self.mlp = FeedForward(
-        features=embed_dim,
-        hidden_dim=hidden_dim,
+        config=config,
         rngs=rngs,
-        shd_config=shd_config,
     )
-    if use_post_ffw_norm:
-      self.post_ffw_norm = RMSNorm(embed_dim, rngs=rngs, shd_config=shd_config)
+    if config.use_post_ffw_norm:
+      self.post_ffw_norm = RMSNorm(
+          config.embed_dim, rngs=rngs, shd_config=config.shd_config
+      )

   def __call__(
       self,
@@ -655,7 +648,7 @@ def __call__(
         attn_mask,
     )

-    if self.use_post_attn_norm:
+    if self.config.use_post_attn_norm:
       attn_output = self.post_attn_norm(attn_output)
     attn_output += x
@@ -663,20 +656,12 @@ def __call__(
     outputs = self.pre_ffw_norm(attn_output)
     outputs = self.mlp(outputs)

-    if self.use_post_ffw_norm:
+    if self.config.use_post_ffw_norm:
       outputs = self.post_ffw_norm(outputs)
     outputs += attn_output
     return cache, outputs

-  @property
-  def use_post_attn_norm(self):
-    return hasattr(self, 'post_attn_norm') and self.post_attn_norm is not None
-
-  @property
-  def use_post_ffw_norm(self):
-    return hasattr(self, 'post_ffw_norm') and self.post_ffw_norm is not None
-

 class RMSNorm(nnx.Module):
   """RMSNorm layer."""
@@ -864,19 +849,9 @@ def __init__(
     )
     self.layers = compat.ModuleList([
         Block(
-            num_heads=config.num_heads,
-            num_kv_heads=config.num_kv_heads,
-            embed_dim=config.embed_dim,
-            head_dim=config.head_dim,
-            hidden_dim=config.hidden_dim,
-            sliding_window_size=config.sliding_window_size,
-            use_post_attn_norm=config.use_post_attn_norm,
-            use_post_ffw_norm=config.use_post_ffw_norm,
-            attn_logits_soft_cap=config.attn_logits_soft_cap,
+            config=config,
             attn_type=attn_type,
             rngs=rngs,
-            shd_config=config.shd_config,
-            remat_config=config.remat_config,
         )
         for _, attn_type in zip(
             range(config.num_layers), config.attention_types
diff --git a/tunix/models/gemma3/model.py b/tunix/models/gemma3/model.py
index 8d54b0b90..6da22b02c 100644
--- a/tunix/models/gemma3/model.py
+++ b/tunix/models/gemma3/model.py
@@ -733,128 +733,121 @@ class FeedForward(nnx.Module):

   def __init__(
       self,
-      features: int,
-      hidden_dim: int,
+      config: ModelConfig,
       *,
       rngs: nnx.Rngs,
-      shd_config: ShardingConfig = ShardingConfig.get_default_sharding(),
-      param_dtype: jnp.dtype = jnp.bfloat16,
   ):
-    self.shd_config = shd_config
+    self.config = config
     kernel_init_fn = nnx.initializers.zeros_init()
     self.gate_proj = nnx.Linear(
-        in_features=features,
-        out_features=hidden_dim,
+        in_features=config.embed_dim,
+        out_features=config.hidden_dim,
         use_bias=False,
         rngs=rngs,
-        param_dtype=param_dtype,
+        param_dtype=config.param_dtype,
         kernel_init=nnx.with_partitioning(
-            kernel_init_fn, shd_config.ffw_weight_df
+            kernel_init_fn, config.shd_config.ffw_weight_df
         ),
     )
     self.up_proj = nnx.Linear(
-        in_features=features,
-        out_features=hidden_dim,
+        in_features=config.embed_dim,
+        out_features=config.hidden_dim,
         use_bias=False,
         rngs=rngs,
-        param_dtype=param_dtype,
+        param_dtype=config.param_dtype,
         kernel_init=nnx.with_partitioning(
-            kernel_init_fn, shd_config.ffw_weight_df
+            kernel_init_fn, config.shd_config.ffw_weight_df
         ),
     )
     self.down_proj = nnx.Linear(
-        in_features=hidden_dim,
-        out_features=features,
+        in_features=config.hidden_dim,
+        out_features=config.embed_dim,
         use_bias=False,
         rngs=rngs,
-        param_dtype=param_dtype,
+        param_dtype=config.param_dtype,
         kernel_init=nnx.with_partitioning(
-            kernel_init_fn, shd_config.ffw_weight_fd
+            kernel_init_fn, config.shd_config.ffw_weight_fd
         ),
     )

-  @jax.named_scope('feed_forward')
-  def __call__(self, x: jaxtyping.ArrayLike) -> jaxtyping.Array:
-    with jax.named_scope('gate_proj'):
-      ff_gate = self.gate_proj(x)
+  def block(
+      self,
+      x: jaxtyping.Array,
+  ) -> jaxtyping.Array:
+    ff_gate = self.gate_proj(x)
     gate_value = nnx.gelu(ff_gate)
-
-    with jax.named_scope('up_proj'):
-      ff1 = self.up_proj(x)
+    ff1 = self.up_proj(x)
     activations = gate_value * ff1
-    activations = sharding_utils.shard(activations, self.shd_config.act_btf)
-
-    with jax.named_scope('down_proj'):
-      outputs = self.down_proj(activations)
+    activations = sharding_utils.shard(
+        activations, self.config.shd_config.act_btf
+    )
+    outputs = self.down_proj(activations)
     return outputs

+  @jax.named_scope('feed_forward')
+  def __call__(self, x: jaxtyping.ArrayLike) -> jaxtyping.Array:
+    if self.config.remat_config == RematConfig.BLOCK:
+      return nnx.remat(self.block.__func__)(self, x)
+    else:
+      return self.block(x)
+

-class Block(nnx.Module):
+class DecoderLayer(nnx.Module):
   """Transformer block."""

   def __init__(
       self,
-      *,
-      num_heads: int,
-      num_kv_heads: int,
-      embed_dim: int,
-      head_dim: int,
-      hidden_dim: int,
+      config: ModelConfig,
       attn_type: AttentionType,
+      *,
       rngs: nnx.Rngs,
-      sliding_window_size: int | None,
-      rope_base_frequency: int,
-      rope_scale_factor: float,
-      query_pre_attn_norm: QueryPreAttentionNormalisation,
-      shd_config: ShardingConfig = ShardingConfig.get_default_sharding(),
-      remat_config: RematConfig = RematConfig.NONE,
-      param_dtype: jnp.dtype = jnp.bfloat16,
   ):
     self.pre_attention_norm = RMSNorm(
-        embed_dim,
+        config.embed_dim,
         rngs=rngs,
-        sharding=shd_config.rms_norm_weight,
-        param_dtype=param_dtype,
+        sharding=config.shd_config.rms_norm_weight,
+        param_dtype=config.param_dtype,
     )
     self.attn = Attention(
-        num_heads=num_heads,
-        num_kv_heads=num_kv_heads,
-        features=embed_dim,
-        head_dim=head_dim,
+        num_heads=config.num_heads,
+        num_kv_heads=config.num_kv_heads,
+        features=config.embed_dim,
+        head_dim=config.head_dim,
         attn_type=attn_type,
-        sliding_window_size=sliding_window_size,
-        rope_base_frequency=rope_base_frequency,
-        rope_scale_factor=rope_scale_factor,
-        query_pre_attn_norm=query_pre_attn_norm,
+        sliding_window_size=config.sliding_window_size,
+        rope_base_frequency=config.local_base_frequency
+        if attn_type == AttentionType.LOCAL_SLIDING
+        else config.global_base_frequency,
+        rope_scale_factor=config.local_scale_factor
+        if attn_type == AttentionType.LOCAL_SLIDING
+        else config.global_scale_factor,
+        query_pre_attn_norm=config.query_pre_attn_norm,
         rngs=rngs,
-        shd_config=shd_config,
-        remat_config=remat_config,
-        param_dtype=param_dtype,
+        shd_config=config.shd_config,
+        remat_config=config.remat_config,
+        param_dtype=config.param_dtype,
     )
     self.post_attention_norm = RMSNorm(
-        embed_dim,
+        config.embed_dim,
         rngs=rngs,
-        sharding=shd_config.rms_norm_weight,
-        param_dtype=param_dtype,
+        sharding=config.shd_config.rms_norm_weight,
+        param_dtype=config.param_dtype,
     )
     self.pre_ffw_norm = RMSNorm(
-        embed_dim,
+        config.embed_dim,
         rngs=rngs,
-        sharding=shd_config.rms_norm_weight,
-        param_dtype=param_dtype,
+        sharding=config.shd_config.rms_norm_weight,
+        param_dtype=config.param_dtype,
     )
     self.mlp = FeedForward(
-        features=embed_dim,
-        hidden_dim=hidden_dim,
+        config=config,
         rngs=rngs,
-        shd_config=shd_config,
-        param_dtype=param_dtype,
     )
     self.post_ffw_norm = RMSNorm(
-        embed_dim,
+        config.embed_dim,
         rngs=rngs,
-        sharding=shd_config.rms_norm_weight,
-        param_dtype=param_dtype,
+        sharding=config.shd_config.rms_norm_weight,
+        param_dtype=config.param_dtype,
     )

   def __call__(
@@ -938,25 +931,10 @@ def __init__(self, config: ModelConfig, *, rngs: nnx.Rngs):
         param_dtype=config.param_dtype,
     )
     self.layers = compat.ModuleList([
-        Block(
-            num_heads=config.num_heads,
-            num_kv_heads=config.num_kv_heads,
-            embed_dim=config.embed_dim,
-            head_dim=config.head_dim,
-            hidden_dim=config.hidden_dim,
-            sliding_window_size=config.sliding_window_size,
+        DecoderLayer(
+            config=config,
             attn_type=attn_type,
-            rope_base_frequency=config.local_base_frequency
-            if attn_type == AttentionType.LOCAL_SLIDING
-            else config.global_base_frequency,
-            rope_scale_factor=config.local_scale_factor
-            if attn_type == AttentionType.LOCAL_SLIDING
-            else config.global_scale_factor,
-            query_pre_attn_norm=config.query_pre_attn_norm,
             rngs=rngs,
-            shd_config=config.shd_config,
-            remat_config=config.remat_config,
-            param_dtype=config.param_dtype,
         )
         for _, attn_type in zip(
             range(config.num_layers), itertools.cycle(GEMMA3_ATTENTION_PATTERN)
diff --git a/tunix/models/llama3/model.py b/tunix/models/llama3/model.py
index 7b2358bc9..ec912506b 100644
--- a/tunix/models/llama3/model.py
+++ b/tunix/models/llama3/model.py
@@ -443,6 +443,7 @@ def __init__(
       *,
       rngs: nnx.Rngs,
   ):
+    self.config = config
     self.shd_config = config.shd_config
     kernel_init_fn = nnx.initializers.zeros_init()
     self.gate_proj = nnx.Linear(
@@ -473,13 +474,22 @@ def __init__(
         ),
     )

-  @jax.named_scope('feed_forward')
-  def __call__(self, x: jaxtyping.ArrayLike) -> jaxtyping.Array:
+  def block(
+      self,
+      x: jaxtyping.Array,
+  ) -> jaxtyping.Array:
     activations = nnx.silu(self.gate_proj(x)) * self.up_proj(x)
     activations = shard(activations, self.shd_config.act_btf)
     outputs = self.down_proj(activations)
     return outputs

+  @jax.named_scope('feed_forward')
+  def __call__(self, x: jaxtyping.ArrayLike) -> jaxtyping.Array:
+    if self.config.remat_config == RematConfig.BLOCK:
+      return nnx.remat(self.block.__func__)(self, x)
+    else:
+      return self.block(x)
+

 class DecoderLayer(nnx.Module):
   """DecoderLayer."""
diff --git a/tunix/models/qwen2/model.py b/tunix/models/qwen2/model.py
index 94a7f066d..f9b79929b 100644
--- a/tunix/models/qwen2/model.py
+++ b/tunix/models/qwen2/model.py
@@ -514,6 +514,7 @@ def __init__(
       *,
       rngs: nnx.Rngs,
   ):
+    self.config = config
     self.shd_config = config.shd_config
     kernel_init_fn = nnx.initializers.zeros_init()
     self.gate_proj = nnx.Linear(
@@ -544,13 +545,22 @@ def __init__(
         ),
     )

-  @jax.named_scope('feed_forward')
-  def __call__(self, x: jaxtyping.ArrayLike) -> jaxtyping.Array:
+  def block(
+      self,
+      x: jaxtyping.Array,
+  ) -> jaxtyping.Array:
     activations = nnx.silu(self.gate_proj(x)) * self.up_proj(x)
     activations = shard(activations, self.shd_config.act_btf)
     outputs = self.down_proj(activations)
     return outputs

+  @jax.named_scope('feed_forward')
+  def __call__(self, x: jaxtyping.ArrayLike) -> jaxtyping.Array:
+    if self.config.remat_config == RematConfig.BLOCK:
+      return nnx.remat(self.block.__func__)(self, x)
+    else:
+      return self.block(x)
+

 class DecoderLayer(nnx.Module):
   """DecoderLayer."""
diff --git a/tunix/models/qwen3/model.py b/tunix/models/qwen3/model.py
index 135c9093e..d2e8309a4 100644
--- a/tunix/models/qwen3/model.py
+++ b/tunix/models/qwen3/model.py
@@ -576,6 +576,7 @@ def __init__(
       *,
       rngs: nnx.Rngs,
   ):
+    self.config = config
     self.shd_config = config.shd_config
     kernel_init_fn = nnx.initializers.zeros_init()
     self.gate_proj = nnx.Linear(
@@ -606,13 +607,22 @@ def __init__(
         ),
     )

-  @jax.named_scope('feed_forward')
-  def __call__(self, x: jaxtyping.ArrayLike) -> jaxtyping.Array:
+  def block(
+      self,
+      x: jaxtyping.Array,
+  ) -> jaxtyping.Array:
     activations = nnx.silu(self.gate_proj(x)) * self.up_proj(x)
     activations = shard(activations, self.shd_config.act_btf)
     outputs = self.down_proj(activations)
     return outputs

+  @jax.named_scope('feed_forward')
+  def __call__(self, x: jaxtyping.ArrayLike) -> jaxtyping.Array:
+    if self.config.remat_config == RematConfig.BLOCK:
+      return nnx.remat(self.block.__func__)(self, x)
+    else:
+      return self.block(x)
+

 class DecoderLayer(nnx.Module):
   """DecoderLayer."""
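
Note on the remat dispatch idiom that recurs in every FeedForward above: `self.block` is a bound method, so the new `__call__` applies `nnx.remat` to the underlying function via `__func__` and passes the module back in as an explicit argument, presumably so the module's state is visible to nnx's transform machinery as an argument rather than hidden inside a closure. Below is a minimal self-contained sketch of the same shape; it is not tunix code: `Toy` stands in for `FeedForward`, and a plain `use_remat` bool stands in for `config.remat_config == RematConfig.BLOCK`.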
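import jax
import jax.numpy as jnp
from flax import nnx


class Toy(nnx.Module):
  """Stand-in for FeedForward, reproducing the remat dispatch shape."""

  def __init__(self, use_remat: bool, *, rngs: nnx.Rngs):
    # Stands in for `config.remat_config == RematConfig.BLOCK`.
    self.use_remat = use_remat
    self.proj = nnx.Linear(4, 4, use_bias=False, rngs=rngs)

  def block(self, x: jax.Array) -> jax.Array:
    # The segment whose intermediate activations are recomputed in the
    # backward pass when rematerialization is enabled.
    return self.proj(nnx.silu(x))

  def __call__(self, x: jax.Array) -> jax.Array:
    if self.use_remat:
      # Remat the unbound function and pass `self` explicitly, mirroring
      # `nnx.remat(self.block.__func__)(self, x)` in the diff.
      return nnx.remat(self.block.__func__)(self, x)
    return self.block(x)


y = Toy(use_remat=True, rngs=nnx.Rngs(0))(jnp.ones((2, 4)))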
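In the gemma3 file the restructuring also drops the per-projection `jax.named_scope` blocks from the old `__call__`, keeping only the outer 'feed_forward' scope on the new dispatcher; and in all files only `FeedForward.block` is wrapped here, while attention rematerialization continues to flow through the `remat_config` already passed to `Attention`.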