Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 1 addition & 12 deletions init2winit/model_lib/mdlm_rope_nanodo.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,18 +91,7 @@ def setup(self):
)

self.blocks = [rope_nanodo.TBlock(cfg) for _ in range(cfg.N)]
if cfg.normalization == 'layernorm':
self.out_ln = nn.LayerNorm(
dtype=cfg.dtype, param_dtype=cfg.param_dtype, use_bias=False
)
elif cfg.normalization == 'rmsnorm':
self.out_ln = nn.RMSNorm(
dtype=cfg.dtype,
param_dtype=cfg.param_dtype,
epsilon=cfg.rmsnorm_epsilon,
)
else:
raise ValueError(f'Unknown normalization: {cfg.normalization}')
self.out_ln = cfg.make_norm()

if cfg.tie_embeddings:
self.output_proj = None
Expand Down
114 changes: 67 additions & 47 deletions init2winit/model_lib/rope_nanodo.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,12 @@
from flax import linen as nn
from init2winit import utils
from init2winit.model_lib import base_model
from init2winit.model_lib import model_utils
import jax
import jax.numpy as jnp
from ml_collections.config_dict import config_dict

partial = functools.partial
ParameterType = model_utils.ParameterType
NamedSharding = jax.sharding.NamedSharding
P = jax.sharding.PartitionSpec


DEFAULT_HPARAMS = config_dict.ConfigDict(
dict(
Expand All @@ -48,6 +45,8 @@
mlp_activation='glu',
qk_norm=True,
tie_embeddings=True,
use_residual_scaling=False,
initializer='xavier',
)
)

Expand All @@ -62,10 +61,6 @@ class DoConfig:
V: int # vocab size
F: int # FF inner dimension
L: int # sequence length
kernel_init: nn.initializers.Initializer = nn.initializers.xavier_uniform()
embed_init: nn.initializers.Initializer = nn.initializers.variance_scaling(
1.0, 'fan_in', 'normal', out_axis=0
)
dtype: jnp.dtype = jnp.bfloat16
param_dtype: jnp.dtype = jnp.float32
rmsnorm_epsilon: float = 1e-6
Expand All @@ -76,6 +71,50 @@ class DoConfig:
qk_norm: bool = True
is_causal: bool = True
eps: float = 1e-6
use_residual_scaling: bool = False
initializer: str = 'xavier' # 'xavier' or 'std0.02'

# Derived initializers (set in __post_init__)
attention_init: nn.initializers.Initializer = dataclasses.field(init=False)
linear_init: nn.initializers.Initializer = dataclasses.field(init=False)
embed_init: nn.initializers.Initializer = dataclasses.field(init=False)
residual_init: nn.initializers.Initializer = dataclasses.field(init=False)

def __post_init__(self):
if self.initializer == 'xavier':
self.attention_init = nn.initializers.xavier_uniform()
self.linear_init = nn.initializers.xavier_uniform()
self.embed_init = nn.initializers.variance_scaling(
1.0, 'fan_in', 'normal', out_axis=0
)
elif self.initializer == 'std0.02':
self.attention_init = nn.initializers.normal(stddev=0.02)
self.linear_init = nn.initializers.normal(stddev=0.02)
self.embed_init = nn.initializers.normal(stddev=0.02)
else:
raise ValueError(f'Unknown initializer: {self.initializer}')

if self.use_residual_scaling and self.initializer == 'std0.02':
self.residual_init = nn.initializers.normal(
stddev=0.02 / jnp.sqrt(2 * self.N)
)
else:
self.residual_init = self.linear_init

def make_norm(self):
"""Returns a normalization layer based on config."""
if self.normalization == 'layernorm':
return nn.LayerNorm(
dtype=self.dtype, param_dtype=self.param_dtype, use_bias=False
)
elif self.normalization == 'rmsnorm':
return nn.RMSNorm(
dtype=self.dtype,
param_dtype=self.param_dtype,
epsilon=self.rmsnorm_epsilon,
)
else:
raise ValueError(f'Unknown normalization: {self.normalization}')


class Mlp(nn.Module):
Expand All @@ -86,10 +125,9 @@ class Mlp(nn.Module):
@nn.compact
def __call__(self, x_BxLxD: jax.Array):
cfg = self.cfg
# Use Xavier uniform initialization explicitly
linear = partial(
nn.Dense,
kernel_init=cfg.kernel_init,
kernel_init=cfg.linear_init,
use_bias=False,
dtype=cfg.dtype,
param_dtype=cfg.param_dtype,
Expand All @@ -112,7 +150,15 @@ def __call__(self, x_BxLxD: jax.Array):
else:
raise ValueError(f'Unknown activation: {cfg.mlp_activation}')
x_BxLxF = mlp_activation(x_BxLxF)
x_BxLxD = linear(cfg.D)(x_BxLxF)
x_BxLxD = nn.Dense(
cfg.D,
kernel_init=cfg.residual_init
if cfg.use_residual_scaling
else cfg.linear_init,
use_bias=False,
dtype=cfg.dtype,
param_dtype=cfg.param_dtype,
)(x_BxLxF)
return x_BxLxD


Expand Down Expand Up @@ -178,7 +224,7 @@ def setup(self):
nn.DenseGeneral,
axis=-1,
features=(cfg.H, self.Dh),
kernel_init=cfg.kernel_init,
kernel_init=cfg.attention_init,
use_bias=False,
dtype=cfg.dtype,
param_dtype=cfg.param_dtype,
Expand All @@ -191,7 +237,9 @@ def setup(self):
features=cfg.D,
name='attn_out_proj',
# axis=(-2, -1), #
kernel_init=cfg.kernel_init,
kernel_init=cfg.residual_init
if cfg.use_residual_scaling
else cfg.linear_init,
use_bias=False,
dtype=cfg.dtype,
param_dtype=cfg.param_dtype,
Expand Down Expand Up @@ -246,24 +294,13 @@ class TBlock(nn.Module):
def __call__(self, in_BxLxD: jax.Array):
cfg = self.docfg

# "pre-layernorm"
if cfg.normalization == 'layernorm':
x_BxLxD = nn.LayerNorm(
dtype=cfg.dtype, param_dtype=cfg.param_dtype, use_bias=False
)(in_BxLxD)
elif cfg.normalization == 'rmsnorm':
x_BxLxD = nn.RMSNorm(
dtype=cfg.dtype,
param_dtype=cfg.param_dtype,
epsilon=cfg.rmsnorm_epsilon,
)(in_BxLxD)
else:
raise ValueError(f'Unknown normalization: {cfg.normalization}')
x_BxLxD = cfg.make_norm()(in_BxLxD)

x_BxLxD = Attention(cfg)(x_BxLxD)
x_BxLxD += in_BxLxD

z_BxLxD = Mlp(cfg)(x_BxLxD)
z_BxLxD = cfg.make_norm()(x_BxLxD)
z_BxLxD = Mlp(cfg)(z_BxLxD)

return x_BxLxD + z_BxLxD

Expand All @@ -282,27 +319,8 @@ def setup(self):
dtype=cfg.dtype,
param_dtype=cfg.param_dtype,
)
self.pos_embed = nn.Embed(
num_embeddings=cfg.L,
features=cfg.D,
embedding_init=cfg.embed_init,
dtype=cfg.dtype,
param_dtype=cfg.param_dtype,
)

self.blocks = [TBlock(cfg) for _ in range(cfg.N)]
if cfg.normalization == 'layernorm':
self.out_ln = nn.LayerNorm(
dtype=cfg.dtype, param_dtype=cfg.param_dtype, use_bias=False
)
elif cfg.normalization == 'rmsnorm':
self.out_ln = nn.RMSNorm(
dtype=cfg.dtype,
param_dtype=cfg.param_dtype,
epsilon=cfg.rmsnorm_epsilon,
)
else:
raise ValueError(f'Unknown normalization: {cfg.normalization}')
self.out_ln = cfg.make_norm()

# Output projection - tied to input embeddings if configured
if cfg.tie_embeddings:
Expand Down Expand Up @@ -344,6 +362,8 @@ def build_flax_module(self):
normalization=self.hps['normalization'],
qk_norm=self.hps['qk_norm'],
tie_embeddings=self.hps['tie_embeddings'],
use_residual_scaling=self.hps['use_residual_scaling'],
initializer=self.hps['initializer'],
)
return TransformerDo(config)

Expand Down