From 231af3cbadd896400a6b8d32ab9c2d09a242a5df Mon Sep 17 00:00:00 2001 From: Ahmed Khaled Date: Tue, 2 Jun 2026 11:02:46 -0700 Subject: [PATCH] internal PiperOrigin-RevId: 925464038 --- init2winit/checkpoint.py | 92 ++++++++++----- init2winit/dataset_lib/data_utils.py | 98 +++++++++++----- init2winit/model_lib/mdlm_rope_nanodo.py | 13 +-- init2winit/model_lib/rope_nanodo.py | 114 +++++++++++-------- init2winit/trainer_lib/base_trainer.py | 7 ++ init2winit/trainer_lib/trainer.py | 14 ++- init2winit/trainer_lib/training_algorithm.py | 15 +++ 7 files changed, 232 insertions(+), 121 deletions(-) diff --git a/init2winit/checkpoint.py b/init2winit/checkpoint.py index 49ed335f..2affecc5 100644 --- a/init2winit/checkpoint.py +++ b/init2winit/checkpoint.py @@ -18,13 +18,15 @@ This is useful for training neural networks with stax, where model parameters are nested numpy arrays. """ + from absl import flags from absl import logging +from init2winit.dataset_lib import data_utils import jax -# pylint: disable=g-importing-member -from jax.experimental.multihost_utils import process_allgather +from jax.experimental import multihost_utils import orbax.checkpoint as ocp + FLAGS = flags.FLAGS @@ -49,7 +51,8 @@ def maybe_restore_checkpoint( unreplicated_batch_stats, unreplicated_training_metrics_state, orbax_checkpoint_manager=None, - orbax_checkpoint_manager_external=None): + orbax_checkpoint_manager_external=None, +): """Optionally restores from a checkpoint. The checkpoint logic is as follows: if `orbax_checkpoint_manager` contains @@ -77,9 +80,16 @@ def maybe_restore_checkpoint( in train_dir. """ uninitialized_global_step = -1 + # Unwrap CpuOffloaded leaves before passing to Orbax — it only accepts + # numpy/jax arrays as target leaves. The training algorithm's + # restore_optimizer_state() hook re-wraps them after restore. + unwrapped_optimizer_state = jax.tree.map( + lambda x: x.array if isinstance(x, data_utils.CpuOffloaded) else x, + unreplicated_optimizer_state, + ) unreplicated_checkpoint_state = dict( params=unreplicated_params, - optimizer_state=unreplicated_optimizer_state, + optimizer_state=unwrapped_optimizer_state, batch_stats=unreplicated_batch_stats, training_metrics_grabber=unreplicated_training_metrics_state, global_step=uninitialized_global_step, @@ -96,7 +106,7 @@ def maybe_restore_checkpoint( # train_dir does not exist or if it exists and contains no checkpoints. # Note that we could likely change the below line to: # found_checkpoint = latest_ckpt != unreplicated_checkpoint_state - found_checkpoint = (latest_ckpt['global_step'] != uninitialized_global_step) + found_checkpoint = latest_ckpt['global_step'] != uninitialized_global_step # If there's a latest checkpoint in the train_dir, restore from that. if found_checkpoint: @@ -123,7 +133,8 @@ def maybe_restore_checkpoint( 0, # global_step 0, # sum_train_cost 0, # preemption_count - False) # is_restored + False, + ) # is_restored else: # Else, don't restore from any checkpoint. return ( unreplicated_optimizer_state, @@ -133,7 +144,8 @@ def maybe_restore_checkpoint( 0, # global_step 0, # sum_train_cost 0, # preemption_count - False) # is_restored + False, + ) # is_restored return ( ckpt_to_return['optimizer_state'], @@ -143,7 +155,8 @@ def maybe_restore_checkpoint( ckpt_to_return['global_step'], # global_step ckpt_to_return['sum_train_cost'], ckpt_to_return['preemption_count'], # preemption_count - is_restored) # is_restored + is_restored, + ) # is_restored def unreplicate_and_save_checkpoint( @@ -154,38 +167,56 @@ def unreplicate_and_save_checkpoint( global_step, preemption_count, sum_train_cost, - orbax_checkpoint_manager): + orbax_checkpoint_manager, +): """Saves pytree, step, preemption_count, and sum_train_cost to train_dir.""" logging.info('Saving checkpoint to ckpt_%d', global_step) # jax.device_get doesn't work if jax.Array lives on multiple hosts. # So we first all_gather it to the host and then call jax.device_get if jax.process_count() > 1: - unreplicated_optimizer_state = jax.device_get( - process_allgather(optimizer_state, tiled=True)) - unreplicated_params = jax.device_get(process_allgather(params, tiled=True)) + unreplicated_optimizer_state = jax.tree.map( + lambda x: x + if isinstance(x, data_utils.CpuOffloaded) + else jax.device_get(multihost_utils.process_allgather(x, tiled=True)), + optimizer_state, + ) + unreplicated_params = jax.device_get( + multihost_utils.process_allgather(params, tiled=True) + ) else: - unreplicated_optimizer_state = jax.device_get(optimizer_state) + unreplicated_optimizer_state = jax.tree.map( + lambda x: x + if isinstance(x, data_utils.CpuOffloaded) + else jax.device_get(x), + optimizer_state, + ) unreplicated_params = jax.device_get(params) unreplicated_batch_stats = jax.device_get(batch_stats) - unreplicated_training_metrics_state = jax.device_get( - training_metrics_state) + unreplicated_training_metrics_state = jax.device_get(training_metrics_state) unreplicated_sum_train_cost = jax.device_get(sum_train_cost) - state = dict(global_step=global_step, - preemption_count=preemption_count, - sum_train_cost=unreplicated_sum_train_cost, - optimizer_state=unreplicated_optimizer_state, - params=unreplicated_params, - batch_stats=unreplicated_batch_stats, - training_metrics_grabber=unreplicated_training_metrics_state) - save_checkpoint(global_step, - state, - orbax_checkpoint_manager=orbax_checkpoint_manager) + # Unwrap CpuOffloaded leaves to plain numpy arrays for Orbax serialization. + # CpuOffloaded is a runtime-only wrapper for sharding control; on disk the + # wrapped arrays are stored as regular numpy arrays. + unreplicated_optimizer_state = jax.tree.map( + lambda x: x.array if isinstance(x, data_utils.CpuOffloaded) else x, + unreplicated_optimizer_state, + ) + state = dict( + global_step=global_step, + preemption_count=preemption_count, + sum_train_cost=unreplicated_sum_train_cost, + optimizer_state=unreplicated_optimizer_state, + params=unreplicated_params, + batch_stats=unreplicated_batch_stats, + training_metrics_grabber=unreplicated_training_metrics_state, + ) + save_checkpoint( + global_step, state, orbax_checkpoint_manager=orbax_checkpoint_manager + ) logging.info('Done saving checkpoint.') -def save_checkpoint(step, - state, - orbax_checkpoint_manager): +def save_checkpoint(step, state, orbax_checkpoint_manager): """Saves checkpoint to train_dir. A list of checkpoints will be stored in train_dir/step. @@ -229,9 +260,10 @@ def load_latest_checkpoint(target=None, orbax_checkpoint_manager=None): """Loads the most recent checkpoint listed in train_dir. Args: - target: used for checkpointing, a pytree whose structure will be used - to structure the restored checkpoint data. + target: used for checkpointing, a pytree whose structure will be used to + structure the restored checkpoint data. orbax_checkpoint_manager: An orbax.CheckpointManager instance. + Returns: The state restored from the checkpoint. If using Flax checkpointing and target=None, this will return a unstructured dictionary containing the diff --git a/init2winit/dataset_lib/data_utils.py b/init2winit/dataset_lib/data_utils.py index 5f0ec34d..26882eb0 100644 --- a/init2winit/dataset_lib/data_utils.py +++ b/init2winit/dataset_lib/data_utils.py @@ -30,13 +30,15 @@ import jraph import numpy as np - -Dataset = collections.namedtuple('Dataset', [ - 'train_iterator_fn', - 'eval_train_epoch', - 'valid_epoch', - 'test_epoch', -]) +Dataset = collections.namedtuple( + 'Dataset', + [ + 'train_iterator_fn', + 'eval_train_epoch', + 'valid_epoch', + 'test_epoch', + ], +) def log_rss(msg: str): @@ -45,8 +47,9 @@ def log_rss(msg: str): logging.info('%s — RSS: %.1f MB', msg, rss_mb) -def prefetch_iterator(source_iter: Iterator[jax.typing.ArrayLike], - num_prefetch: int) -> Iterator[jax.typing.ArrayLike]: +def prefetch_iterator( + source_iter: Iterator[jax.typing.ArrayLike], num_prefetch: int +) -> Iterator[jax.typing.ArrayLike]: """Wraps the given iterator with prefetching. Args: @@ -121,14 +124,16 @@ def iterator_as_numpy(iterator): yield jax.tree.map(lambda y: y._numpy(), x) # pylint: disable=protected-access -def image_iterator(data, - rescale, - output_shape, - is_one_hot, - autoencoder, - shuffle_rng=None, - augment_fn=None, - include_example_keys=False): +def image_iterator( + data, + rescale, + output_shape, + is_one_hot, + autoencoder, + shuffle_rng=None, + augment_fn=None, + include_example_keys=False, +): """Preprocesses the batch data arrays in the data generator. Rescales inputs. One hot encode targets if is_one_hot is true. @@ -166,11 +171,13 @@ def image_iterator(data, yield {'inputs': inputs, 'targets': targets} -def maybe_pad_batch(batch, - desired_batch_size, - data_format=None, - mask_key=None, - padding_value=0.0): +def maybe_pad_batch( + batch, + desired_batch_size, + data_format=None, + mask_key=None, + padding_value=0.0, +): """Zero pad the batch on the right to desired_batch_size. All keys in the batch dictionary will have their corresponding arrays padded. @@ -187,9 +194,9 @@ def maybe_pad_batch(batch, dimension to pad. If not provided then it is assumed the first dimension is the batch dimension. mask_key: Typically used for text datasets, it's either 'inputs' (for - encoder only models like language models) or 'targets' - (for encoder-decoder models like seq2seq tasks) to decide weights for - padded sequence. For Image datasets, this will be (most likely) unused. + encoder only models like language models) or 'targets' (for + encoder-decoder models like seq2seq tasks) to decide weights for padded + sequence. For Image datasets, this will be (most likely) unused. padding_value: value to be used as padding. Returns: @@ -247,7 +254,7 @@ def make_global_array(local_data, mesh): """Util to combine per-host batches into a global batch array. Args: - local_data: local data batch on host. + local_data: local data batch on host. mesh: mesh specification to shard the data. Returns: @@ -265,13 +272,46 @@ def make_global_array(local_data, mesh): return global_array +class CpuOffloaded: + """Marker wrapper for arrays that should remain on CPU. + + Wraps a numpy array to signal to the trainer's sharding and checkpoint + code that this leaf should be skipped during JAX sharding operations + and device transfers. Used by optimizers that offload state to host + memory (e.g., single-worker DiLoCo's slow_params and nesterov_b). + + The wrapped array is accessible via the `array` attribute. + """ + + def __init__(self, array): + self.array = array + + @property + def shape(self): + return self.array.shape + + @property + def dtype(self): + return self.array.dtype + + def __repr__(self): + return f'CpuOffloaded(shape={self.shape}, dtype={self.dtype})' + + def shard_pytree(pytree, mesh, shardings=None): + """Shards a pytree with the given shardings and mesh.""" + if shardings is None: shardings = nn.get_sharding(pytree, mesh) + + def _maybe_shard(arr, sharding): + """Shards the given array if the sharding is not None.""" + if sharding is None: + return arr + return jax.make_array_from_process_local_data(sharding, arr, arr.shape) + pytree = jax.tree_util.tree_map( - lambda arr, sharding: jax.make_array_from_process_local_data( - sharding, arr, arr.shape - ), + _maybe_shard, pytree, shardings, ) diff --git a/init2winit/model_lib/mdlm_rope_nanodo.py b/init2winit/model_lib/mdlm_rope_nanodo.py index f45fdd31..b42f25ab 100644 --- a/init2winit/model_lib/mdlm_rope_nanodo.py +++ b/init2winit/model_lib/mdlm_rope_nanodo.py @@ -91,18 +91,7 @@ def setup(self): ) self.blocks = [rope_nanodo.TBlock(cfg) for _ in range(cfg.N)] - if cfg.normalization == 'layernorm': - self.out_ln = nn.LayerNorm( - dtype=cfg.dtype, param_dtype=cfg.param_dtype, use_bias=False - ) - elif cfg.normalization == 'rmsnorm': - self.out_ln = nn.RMSNorm( - dtype=cfg.dtype, - param_dtype=cfg.param_dtype, - epsilon=cfg.rmsnorm_epsilon, - ) - else: - raise ValueError(f'Unknown normalization: {cfg.normalization}') + self.out_ln = cfg.make_norm() if cfg.tie_embeddings: self.output_proj = None diff --git a/init2winit/model_lib/rope_nanodo.py b/init2winit/model_lib/rope_nanodo.py index ffda5592..1046ac4f 100644 --- a/init2winit/model_lib/rope_nanodo.py +++ b/init2winit/model_lib/rope_nanodo.py @@ -26,15 +26,12 @@ from flax import linen as nn from init2winit import utils from init2winit.model_lib import base_model -from init2winit.model_lib import model_utils import jax import jax.numpy as jnp from ml_collections.config_dict import config_dict partial = functools.partial -ParameterType = model_utils.ParameterType -NamedSharding = jax.sharding.NamedSharding -P = jax.sharding.PartitionSpec + DEFAULT_HPARAMS = config_dict.ConfigDict( dict( @@ -48,6 +45,8 @@ mlp_activation='glu', qk_norm=True, tie_embeddings=True, + use_residual_scaling=False, + initializer='xavier', ) ) @@ -62,10 +61,6 @@ class DoConfig: V: int # vocab size F: int # FF inner dimension L: int # sequence length - kernel_init: nn.initializers.Initializer = nn.initializers.xavier_uniform() - embed_init: nn.initializers.Initializer = nn.initializers.variance_scaling( - 1.0, 'fan_in', 'normal', out_axis=0 - ) dtype: jnp.dtype = jnp.bfloat16 param_dtype: jnp.dtype = jnp.float32 rmsnorm_epsilon: float = 1e-6 @@ -76,6 +71,50 @@ class DoConfig: qk_norm: bool = True is_causal: bool = True eps: float = 1e-6 + use_residual_scaling: bool = False + initializer: str = 'xavier' # 'xavier' or 'std0.02' + + # Derived initializers (set in __post_init__) + attention_init: nn.initializers.Initializer = dataclasses.field(init=False) + linear_init: nn.initializers.Initializer = dataclasses.field(init=False) + embed_init: nn.initializers.Initializer = dataclasses.field(init=False) + residual_init: nn.initializers.Initializer = dataclasses.field(init=False) + + def __post_init__(self): + if self.initializer == 'xavier': + self.attention_init = nn.initializers.xavier_uniform() + self.linear_init = nn.initializers.xavier_uniform() + self.embed_init = nn.initializers.variance_scaling( + 1.0, 'fan_in', 'normal', out_axis=0 + ) + elif self.initializer == 'std0.02': + self.attention_init = nn.initializers.normal(stddev=0.02) + self.linear_init = nn.initializers.normal(stddev=0.02) + self.embed_init = nn.initializers.normal(stddev=0.02) + else: + raise ValueError(f'Unknown initializer: {self.initializer}') + + if self.use_residual_scaling and self.initializer == 'std0.02': + self.residual_init = nn.initializers.normal( + stddev=0.02 / jnp.sqrt(2 * self.N) + ) + else: + self.residual_init = self.linear_init + + def make_norm(self): + """Returns a normalization layer based on config.""" + if self.normalization == 'layernorm': + return nn.LayerNorm( + dtype=self.dtype, param_dtype=self.param_dtype, use_bias=False + ) + elif self.normalization == 'rmsnorm': + return nn.RMSNorm( + dtype=self.dtype, + param_dtype=self.param_dtype, + epsilon=self.rmsnorm_epsilon, + ) + else: + raise ValueError(f'Unknown normalization: {self.normalization}') class Mlp(nn.Module): @@ -86,10 +125,9 @@ class Mlp(nn.Module): @nn.compact def __call__(self, x_BxLxD: jax.Array): cfg = self.cfg - # Use Xavier uniform initialization explicitly linear = partial( nn.Dense, - kernel_init=cfg.kernel_init, + kernel_init=cfg.linear_init, use_bias=False, dtype=cfg.dtype, param_dtype=cfg.param_dtype, @@ -112,7 +150,15 @@ def __call__(self, x_BxLxD: jax.Array): else: raise ValueError(f'Unknown activation: {cfg.mlp_activation}') x_BxLxF = mlp_activation(x_BxLxF) - x_BxLxD = linear(cfg.D)(x_BxLxF) + x_BxLxD = nn.Dense( + cfg.D, + kernel_init=cfg.residual_init + if cfg.use_residual_scaling + else cfg.linear_init, + use_bias=False, + dtype=cfg.dtype, + param_dtype=cfg.param_dtype, + )(x_BxLxF) return x_BxLxD @@ -178,7 +224,7 @@ def setup(self): nn.DenseGeneral, axis=-1, features=(cfg.H, self.Dh), - kernel_init=cfg.kernel_init, + kernel_init=cfg.attention_init, use_bias=False, dtype=cfg.dtype, param_dtype=cfg.param_dtype, @@ -191,7 +237,9 @@ def setup(self): features=cfg.D, name='attn_out_proj', # axis=(-2, -1), # - kernel_init=cfg.kernel_init, + kernel_init=cfg.residual_init + if cfg.use_residual_scaling + else cfg.linear_init, use_bias=False, dtype=cfg.dtype, param_dtype=cfg.param_dtype, @@ -246,24 +294,13 @@ class TBlock(nn.Module): def __call__(self, in_BxLxD: jax.Array): cfg = self.docfg - # "pre-layernorm" - if cfg.normalization == 'layernorm': - x_BxLxD = nn.LayerNorm( - dtype=cfg.dtype, param_dtype=cfg.param_dtype, use_bias=False - )(in_BxLxD) - elif cfg.normalization == 'rmsnorm': - x_BxLxD = nn.RMSNorm( - dtype=cfg.dtype, - param_dtype=cfg.param_dtype, - epsilon=cfg.rmsnorm_epsilon, - )(in_BxLxD) - else: - raise ValueError(f'Unknown normalization: {cfg.normalization}') + x_BxLxD = cfg.make_norm()(in_BxLxD) x_BxLxD = Attention(cfg)(x_BxLxD) x_BxLxD += in_BxLxD - z_BxLxD = Mlp(cfg)(x_BxLxD) + z_BxLxD = cfg.make_norm()(x_BxLxD) + z_BxLxD = Mlp(cfg)(z_BxLxD) return x_BxLxD + z_BxLxD @@ -282,27 +319,8 @@ def setup(self): dtype=cfg.dtype, param_dtype=cfg.param_dtype, ) - self.pos_embed = nn.Embed( - num_embeddings=cfg.L, - features=cfg.D, - embedding_init=cfg.embed_init, - dtype=cfg.dtype, - param_dtype=cfg.param_dtype, - ) - self.blocks = [TBlock(cfg) for _ in range(cfg.N)] - if cfg.normalization == 'layernorm': - self.out_ln = nn.LayerNorm( - dtype=cfg.dtype, param_dtype=cfg.param_dtype, use_bias=False - ) - elif cfg.normalization == 'rmsnorm': - self.out_ln = nn.RMSNorm( - dtype=cfg.dtype, - param_dtype=cfg.param_dtype, - epsilon=cfg.rmsnorm_epsilon, - ) - else: - raise ValueError(f'Unknown normalization: {cfg.normalization}') + self.out_ln = cfg.make_norm() # Output projection - tied to input embeddings if configured if cfg.tie_embeddings: @@ -344,6 +362,8 @@ def build_flax_module(self): normalization=self.hps['normalization'], qk_norm=self.hps['qk_norm'], tie_embeddings=self.hps['tie_embeddings'], + use_residual_scaling=self.hps['use_residual_scaling'], + initializer=self.hps['initializer'], ) return TransformerDo(config) diff --git a/init2winit/trainer_lib/base_trainer.py b/init2winit/trainer_lib/base_trainer.py index 32d3d7a3..7dbb4132 100644 --- a/init2winit/trainer_lib/base_trainer.py +++ b/init2winit/trainer_lib/base_trainer.py @@ -653,6 +653,13 @@ def setup_and_maybe_restore(self, init_rng, data_rng, callback_rng): logging.info( 'Checkpoint restored in %f seconds', time.time() - start_time ) + # Allow training algorithms to post-process restored optimizer state + # (e.g., re-wrap CpuOffloaded leaves stripped during serialization). + unreplicated_optimizer_state = ( + self.training_algorithm.restore_optimizer_state( + unreplicated_optimizer_state + ) + ) start_time = time.time() ( self._params, diff --git a/init2winit/trainer_lib/trainer.py b/init2winit/trainer_lib/trainer.py index c35c39b1..272b4fa4 100644 --- a/init2winit/trainer_lib/trainer.py +++ b/init2winit/trainer_lib/trainer.py @@ -105,10 +105,18 @@ def shard( _, params = data_utils.shard_pytree( unreplicated_params, self._mesh, params_sharding ) + + def _get_sharding(x): + """Returns the sharding for the given pytree node.""" + if isinstance(x, data_utils.CpuOffloaded): + return None + elif isinstance(x, jax.Array) and isinstance(x.sharding, NamedSharding): + return x.sharding + else: + return NamedSharding(self._mesh, P()) + optimizer_state_sharding = jax.tree_util.tree_map( - lambda x: x.sharding - if isinstance(x.sharding, NamedSharding) - else NamedSharding(self._mesh, P()), + _get_sharding, unreplicated_optimizer_state, ) diff --git a/init2winit/trainer_lib/training_algorithm.py b/init2winit/trainer_lib/training_algorithm.py index f509b1f0..522f4b73 100644 --- a/init2winit/trainer_lib/training_algorithm.py +++ b/init2winit/trainer_lib/training_algorithm.py @@ -197,6 +197,21 @@ def init_optimizer_state( Optimizer state: Pytree of optimizer state. """ + def restore_optimizer_state(self, optimizer_state): + """Post-processes optimizer state after checkpoint restore. + + Override this method in subclasses that use runtime wrappers (e.g., + CpuOffloaded) which are stripped during serialization and need to be + re-applied after deserialization. + + Args: + optimizer_state: The restored optimizer state pytree (plain numpy arrays). + + Returns: + The post-processed optimizer state, ready for sharding. + """ + return optimizer_state + # Per-optimizer default opt_hparams for OptaxTrainingAlgorithm. # These consolidate all the inline defaults from get_optimizer() in