diff --git a/init2winit/model_lib/mdlm_rope_nanodo.py b/init2winit/model_lib/mdlm_rope_nanodo.py index f45fdd31..b42f25ab 100644 --- a/init2winit/model_lib/mdlm_rope_nanodo.py +++ b/init2winit/model_lib/mdlm_rope_nanodo.py @@ -91,18 +91,7 @@ def setup(self): ) self.blocks = [rope_nanodo.TBlock(cfg) for _ in range(cfg.N)] - if cfg.normalization == 'layernorm': - self.out_ln = nn.LayerNorm( - dtype=cfg.dtype, param_dtype=cfg.param_dtype, use_bias=False - ) - elif cfg.normalization == 'rmsnorm': - self.out_ln = nn.RMSNorm( - dtype=cfg.dtype, - param_dtype=cfg.param_dtype, - epsilon=cfg.rmsnorm_epsilon, - ) - else: - raise ValueError(f'Unknown normalization: {cfg.normalization}') + self.out_ln = cfg.make_norm() if cfg.tie_embeddings: self.output_proj = None diff --git a/init2winit/model_lib/rope_nanodo.py b/init2winit/model_lib/rope_nanodo.py index ffda5592..1046ac4f 100644 --- a/init2winit/model_lib/rope_nanodo.py +++ b/init2winit/model_lib/rope_nanodo.py @@ -26,15 +26,12 @@ from flax import linen as nn from init2winit import utils from init2winit.model_lib import base_model -from init2winit.model_lib import model_utils import jax import jax.numpy as jnp from ml_collections.config_dict import config_dict partial = functools.partial -ParameterType = model_utils.ParameterType -NamedSharding = jax.sharding.NamedSharding -P = jax.sharding.PartitionSpec + DEFAULT_HPARAMS = config_dict.ConfigDict( dict( @@ -48,6 +45,8 @@ mlp_activation='glu', qk_norm=True, tie_embeddings=True, + use_residual_scaling=False, + initializer='xavier', ) ) @@ -62,10 +61,6 @@ class DoConfig: V: int # vocab size F: int # FF inner dimension L: int # sequence length - kernel_init: nn.initializers.Initializer = nn.initializers.xavier_uniform() - embed_init: nn.initializers.Initializer = nn.initializers.variance_scaling( - 1.0, 'fan_in', 'normal', out_axis=0 - ) dtype: jnp.dtype = jnp.bfloat16 param_dtype: jnp.dtype = jnp.float32 rmsnorm_epsilon: float = 1e-6 @@ -76,6 +71,50 @@ class DoConfig: qk_norm: bool = True is_causal: bool = True eps: float = 1e-6 + use_residual_scaling: bool = False + initializer: str = 'xavier' # 'xavier' or 'std0.02' + + # Derived initializers (set in __post_init__) + attention_init: nn.initializers.Initializer = dataclasses.field(init=False) + linear_init: nn.initializers.Initializer = dataclasses.field(init=False) + embed_init: nn.initializers.Initializer = dataclasses.field(init=False) + residual_init: nn.initializers.Initializer = dataclasses.field(init=False) + + def __post_init__(self): + if self.initializer == 'xavier': + self.attention_init = nn.initializers.xavier_uniform() + self.linear_init = nn.initializers.xavier_uniform() + self.embed_init = nn.initializers.variance_scaling( + 1.0, 'fan_in', 'normal', out_axis=0 + ) + elif self.initializer == 'std0.02': + self.attention_init = nn.initializers.normal(stddev=0.02) + self.linear_init = nn.initializers.normal(stddev=0.02) + self.embed_init = nn.initializers.normal(stddev=0.02) + else: + raise ValueError(f'Unknown initializer: {self.initializer}') + + if self.use_residual_scaling and self.initializer == 'std0.02': + self.residual_init = nn.initializers.normal( + stddev=0.02 / jnp.sqrt(2 * self.N) + ) + else: + self.residual_init = self.linear_init + + def make_norm(self): + """Returns a normalization layer based on config.""" + if self.normalization == 'layernorm': + return nn.LayerNorm( + dtype=self.dtype, param_dtype=self.param_dtype, use_bias=False + ) + elif self.normalization == 'rmsnorm': + return nn.RMSNorm( + dtype=self.dtype, + param_dtype=self.param_dtype, + epsilon=self.rmsnorm_epsilon, + ) + else: + raise ValueError(f'Unknown normalization: {self.normalization}') class Mlp(nn.Module): @@ -86,10 +125,9 @@ class Mlp(nn.Module): @nn.compact def __call__(self, x_BxLxD: jax.Array): cfg = self.cfg - # Use Xavier uniform initialization explicitly linear = partial( nn.Dense, - kernel_init=cfg.kernel_init, + kernel_init=cfg.linear_init, use_bias=False, dtype=cfg.dtype, param_dtype=cfg.param_dtype, @@ -112,7 +150,15 @@ def __call__(self, x_BxLxD: jax.Array): else: raise ValueError(f'Unknown activation: {cfg.mlp_activation}') x_BxLxF = mlp_activation(x_BxLxF) - x_BxLxD = linear(cfg.D)(x_BxLxF) + x_BxLxD = nn.Dense( + cfg.D, + kernel_init=cfg.residual_init + if cfg.use_residual_scaling + else cfg.linear_init, + use_bias=False, + dtype=cfg.dtype, + param_dtype=cfg.param_dtype, + )(x_BxLxF) return x_BxLxD @@ -178,7 +224,7 @@ def setup(self): nn.DenseGeneral, axis=-1, features=(cfg.H, self.Dh), - kernel_init=cfg.kernel_init, + kernel_init=cfg.attention_init, use_bias=False, dtype=cfg.dtype, param_dtype=cfg.param_dtype, @@ -191,7 +237,9 @@ def setup(self): features=cfg.D, name='attn_out_proj', # axis=(-2, -1), # - kernel_init=cfg.kernel_init, + kernel_init=cfg.residual_init + if cfg.use_residual_scaling + else cfg.linear_init, use_bias=False, dtype=cfg.dtype, param_dtype=cfg.param_dtype, @@ -246,24 +294,13 @@ class TBlock(nn.Module): def __call__(self, in_BxLxD: jax.Array): cfg = self.docfg - # "pre-layernorm" - if cfg.normalization == 'layernorm': - x_BxLxD = nn.LayerNorm( - dtype=cfg.dtype, param_dtype=cfg.param_dtype, use_bias=False - )(in_BxLxD) - elif cfg.normalization == 'rmsnorm': - x_BxLxD = nn.RMSNorm( - dtype=cfg.dtype, - param_dtype=cfg.param_dtype, - epsilon=cfg.rmsnorm_epsilon, - )(in_BxLxD) - else: - raise ValueError(f'Unknown normalization: {cfg.normalization}') + x_BxLxD = cfg.make_norm()(in_BxLxD) x_BxLxD = Attention(cfg)(x_BxLxD) x_BxLxD += in_BxLxD - z_BxLxD = Mlp(cfg)(x_BxLxD) + z_BxLxD = cfg.make_norm()(x_BxLxD) + z_BxLxD = Mlp(cfg)(z_BxLxD) return x_BxLxD + z_BxLxD @@ -282,27 +319,8 @@ def setup(self): dtype=cfg.dtype, param_dtype=cfg.param_dtype, ) - self.pos_embed = nn.Embed( - num_embeddings=cfg.L, - features=cfg.D, - embedding_init=cfg.embed_init, - dtype=cfg.dtype, - param_dtype=cfg.param_dtype, - ) - self.blocks = [TBlock(cfg) for _ in range(cfg.N)] - if cfg.normalization == 'layernorm': - self.out_ln = nn.LayerNorm( - dtype=cfg.dtype, param_dtype=cfg.param_dtype, use_bias=False - ) - elif cfg.normalization == 'rmsnorm': - self.out_ln = nn.RMSNorm( - dtype=cfg.dtype, - param_dtype=cfg.param_dtype, - epsilon=cfg.rmsnorm_epsilon, - ) - else: - raise ValueError(f'Unknown normalization: {cfg.normalization}') + self.out_ln = cfg.make_norm() # Output projection - tied to input embeddings if configured if cfg.tie_embeddings: @@ -344,6 +362,8 @@ def build_flax_module(self): normalization=self.hps['normalization'], qk_norm=self.hps['qk_norm'], tie_embeddings=self.hps['tie_embeddings'], + use_residual_scaling=self.hps['use_residual_scaling'], + initializer=self.hps['initializer'], ) return TransformerDo(config)