Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 62 additions & 30 deletions init2winit/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,15 @@
This is useful for training neural networks with stax, where model parameters
are nested numpy arrays.
"""

from absl import flags
from absl import logging
from init2winit.dataset_lib import data_utils
import jax
# pylint: disable=g-importing-member
from jax.experimental.multihost_utils import process_allgather
from jax.experimental import multihost_utils
import orbax.checkpoint as ocp


FLAGS = flags.FLAGS


Expand All @@ -49,7 +51,8 @@ def maybe_restore_checkpoint(
unreplicated_batch_stats,
unreplicated_training_metrics_state,
orbax_checkpoint_manager=None,
orbax_checkpoint_manager_external=None):
orbax_checkpoint_manager_external=None,
):
"""Optionally restores from a checkpoint.

The checkpoint logic is as follows: if `orbax_checkpoint_manager` contains
Expand Down Expand Up @@ -77,9 +80,16 @@ def maybe_restore_checkpoint(
in train_dir.
"""
uninitialized_global_step = -1
# Unwrap CpuOffloaded leaves before passing to Orbax — it only accepts
# numpy/jax arrays as target leaves. The training algorithm's
# restore_optimizer_state() hook re-wraps them after restore.
unwrapped_optimizer_state = jax.tree.map(
lambda x: x.array if isinstance(x, data_utils.CpuOffloaded) else x,
unreplicated_optimizer_state,
)
unreplicated_checkpoint_state = dict(
params=unreplicated_params,
optimizer_state=unreplicated_optimizer_state,
optimizer_state=unwrapped_optimizer_state,
batch_stats=unreplicated_batch_stats,
training_metrics_grabber=unreplicated_training_metrics_state,
global_step=uninitialized_global_step,
Expand All @@ -96,7 +106,7 @@ def maybe_restore_checkpoint(
# train_dir does not exist or if it exists and contains no checkpoints.
# Note that we could likely change the below line to:
# found_checkpoint = latest_ckpt != unreplicated_checkpoint_state
found_checkpoint = (latest_ckpt['global_step'] != uninitialized_global_step)
found_checkpoint = latest_ckpt['global_step'] != uninitialized_global_step

# If there's a latest checkpoint in the train_dir, restore from that.
if found_checkpoint:
Expand All @@ -123,7 +133,8 @@ def maybe_restore_checkpoint(
0, # global_step
0, # sum_train_cost
0, # preemption_count
False) # is_restored
False,
) # is_restored
else: # Else, don't restore from any checkpoint.
return (
unreplicated_optimizer_state,
Expand All @@ -133,7 +144,8 @@ def maybe_restore_checkpoint(
0, # global_step
0, # sum_train_cost
0, # preemption_count
False) # is_restored
False,
) # is_restored

return (
ckpt_to_return['optimizer_state'],
Expand All @@ -143,7 +155,8 @@ def maybe_restore_checkpoint(
ckpt_to_return['global_step'], # global_step
ckpt_to_return['sum_train_cost'],
ckpt_to_return['preemption_count'], # preemption_count
is_restored) # is_restored
is_restored,
) # is_restored


def unreplicate_and_save_checkpoint(
Expand All @@ -154,38 +167,56 @@ def unreplicate_and_save_checkpoint(
global_step,
preemption_count,
sum_train_cost,
orbax_checkpoint_manager):
orbax_checkpoint_manager,
):
"""Saves pytree, step, preemption_count, and sum_train_cost to train_dir."""
logging.info('Saving checkpoint to ckpt_%d', global_step)
# jax.device_get doesn't work if jax.Array lives on multiple hosts.
# So we first all_gather it to the host and then call jax.device_get
if jax.process_count() > 1:
unreplicated_optimizer_state = jax.device_get(
process_allgather(optimizer_state, tiled=True))
unreplicated_params = jax.device_get(process_allgather(params, tiled=True))
unreplicated_optimizer_state = jax.tree.map(
lambda x: x
if isinstance(x, data_utils.CpuOffloaded)
else jax.device_get(multihost_utils.process_allgather(x, tiled=True)),
optimizer_state,
)
unreplicated_params = jax.device_get(
multihost_utils.process_allgather(params, tiled=True)
)
else:
unreplicated_optimizer_state = jax.device_get(optimizer_state)
unreplicated_optimizer_state = jax.tree.map(
lambda x: x
if isinstance(x, data_utils.CpuOffloaded)
else jax.device_get(x),
optimizer_state,
)
unreplicated_params = jax.device_get(params)
unreplicated_batch_stats = jax.device_get(batch_stats)
unreplicated_training_metrics_state = jax.device_get(
training_metrics_state)
unreplicated_training_metrics_state = jax.device_get(training_metrics_state)
unreplicated_sum_train_cost = jax.device_get(sum_train_cost)
state = dict(global_step=global_step,
preemption_count=preemption_count,
sum_train_cost=unreplicated_sum_train_cost,
optimizer_state=unreplicated_optimizer_state,
params=unreplicated_params,
batch_stats=unreplicated_batch_stats,
training_metrics_grabber=unreplicated_training_metrics_state)
save_checkpoint(global_step,
state,
orbax_checkpoint_manager=orbax_checkpoint_manager)
# Unwrap CpuOffloaded leaves to plain numpy arrays for Orbax serialization.
# CpuOffloaded is a runtime-only wrapper for sharding control; on disk the
# wrapped arrays are stored as regular numpy arrays.
unreplicated_optimizer_state = jax.tree.map(
lambda x: x.array if isinstance(x, data_utils.CpuOffloaded) else x,
unreplicated_optimizer_state,
)
state = dict(
global_step=global_step,
preemption_count=preemption_count,
sum_train_cost=unreplicated_sum_train_cost,
optimizer_state=unreplicated_optimizer_state,
params=unreplicated_params,
batch_stats=unreplicated_batch_stats,
training_metrics_grabber=unreplicated_training_metrics_state,
)
save_checkpoint(
global_step, state, orbax_checkpoint_manager=orbax_checkpoint_manager
)
logging.info('Done saving checkpoint.')


def save_checkpoint(step,
state,
orbax_checkpoint_manager):
def save_checkpoint(step, state, orbax_checkpoint_manager):
"""Saves checkpoint to train_dir.

A list of checkpoints will be stored in train_dir/step.
Expand Down Expand Up @@ -229,9 +260,10 @@ def load_latest_checkpoint(target=None, orbax_checkpoint_manager=None):
"""Loads the most recent checkpoint listed in train_dir.

Args:
target: used for checkpointing, a pytree whose structure will be used
to structure the restored checkpoint data.
target: used for checkpointing, a pytree whose structure will be used to
structure the restored checkpoint data.
orbax_checkpoint_manager: An orbax.CheckpointManager instance.

Returns:
The state restored from the checkpoint. If using Flax checkpointing and
target=None, this will return an unstructured dictionary containing the
Expand Down
98 changes: 69 additions & 29 deletions init2winit/dataset_lib/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,15 @@
import jraph
import numpy as np


Dataset = collections.namedtuple('Dataset', [
'train_iterator_fn',
'eval_train_epoch',
'valid_epoch',
'test_epoch',
])
Dataset = collections.namedtuple(
'Dataset',
[
'train_iterator_fn',
'eval_train_epoch',
'valid_epoch',
'test_epoch',
],
)


def log_rss(msg: str):
Expand All @@ -45,8 +47,9 @@ def log_rss(msg: str):
logging.info('%s — RSS: %.1f MB', msg, rss_mb)


def prefetch_iterator(source_iter: Iterator[jax.typing.ArrayLike],
num_prefetch: int) -> Iterator[jax.typing.ArrayLike]:
def prefetch_iterator(
source_iter: Iterator[jax.typing.ArrayLike], num_prefetch: int
) -> Iterator[jax.typing.ArrayLike]:
"""Wraps the given iterator with prefetching.

Args:
Expand Down Expand Up @@ -121,14 +124,16 @@ def iterator_as_numpy(iterator):
yield jax.tree.map(lambda y: y._numpy(), x) # pylint: disable=protected-access


def image_iterator(data,
rescale,
output_shape,
is_one_hot,
autoencoder,
shuffle_rng=None,
augment_fn=None,
include_example_keys=False):
def image_iterator(
data,
rescale,
output_shape,
is_one_hot,
autoencoder,
shuffle_rng=None,
augment_fn=None,
include_example_keys=False,
):
"""Preprocesses the batch data arrays in the data generator.

Rescales inputs. One hot encode targets if is_one_hot is true.
Expand Down Expand Up @@ -166,11 +171,13 @@ def image_iterator(data,
yield {'inputs': inputs, 'targets': targets}


def maybe_pad_batch(batch,
desired_batch_size,
data_format=None,
mask_key=None,
padding_value=0.0):
def maybe_pad_batch(
batch,
desired_batch_size,
data_format=None,
mask_key=None,
padding_value=0.0,
):
"""Zero pad the batch on the right to desired_batch_size.

All keys in the batch dictionary will have their corresponding arrays padded.
Expand All @@ -187,9 +194,9 @@ def maybe_pad_batch(batch,
dimension to pad. If not provided then it is assumed the first dimension
is the batch dimension.
mask_key: Typically used for text datasets, it's either 'inputs' (for
encoder only models like language models) or 'targets'
(for encoder-decoder models like seq2seq tasks) to decide weights for
padded sequence. For Image datasets, this will be (most likely) unused.
encoder only models like language models) or 'targets' (for
encoder-decoder models like seq2seq tasks) to decide weights for padded
sequence. For Image datasets, this will be (most likely) unused.
padding_value: value to be used as padding.

Returns:
Expand Down Expand Up @@ -247,7 +254,7 @@ def make_global_array(local_data, mesh):
"""Util to combine per-host batches into a global batch array.

Args:
local_data: local data batch on host.
local_data: local data batch on host.
mesh: mesh specification to shard the data.

Returns:
Expand All @@ -265,13 +272,46 @@ def make_global_array(local_data, mesh):
return global_array


class CpuOffloaded:
  """Tags an array as one that must stay in host (CPU) memory.

  An instance holds a numpy array and acts as a marker: the trainer's sharding
  and checkpoint code checks for this type and leaves such leaves out of JAX
  sharding operations and device transfers. Optimizers that keep part of their
  state on the host use it (e.g., single-worker DiLoCo's slow_params and
  nesterov_b).

  Attributes:
    array: the wrapped array, stored exactly as it was passed in.
  """

  def __init__(self, array):
    # No copy or conversion: the wrapper only holds a reference.
    self.array = array

  @property
  def shape(self):
    """Shape of the wrapped array, forwarded from `array.shape`."""
    wrapped = self.array
    return wrapped.shape

  @property
  def dtype(self):
    """Dtype of the wrapped array, forwarded from `array.dtype`."""
    wrapped = self.array
    return wrapped.dtype

  def __repr__(self):
    # Show only shape and dtype, never the array contents.
    description = f'CpuOffloaded(shape={self.shape}, dtype={self.dtype})'
    return description


def shard_pytree(pytree, mesh, shardings=None):
"""Shards a pytree with the given shardings and mesh."""

if shardings is None:
shardings = nn.get_sharding(pytree, mesh)

def _maybe_shard(arr, sharding):
"""Shards the given array if the sharding is not None."""
if sharding is None:
return arr
return jax.make_array_from_process_local_data(sharding, arr, arr.shape)

pytree = jax.tree_util.tree_map(
lambda arr, sharding: jax.make_array_from_process_local_data(
sharding, arr, arr.shape
),
_maybe_shard,
pytree,
shardings,
)
Expand Down
13 changes: 1 addition & 12 deletions init2winit/model_lib/mdlm_rope_nanodo.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,18 +91,7 @@ def setup(self):
)

self.blocks = [rope_nanodo.TBlock(cfg) for _ in range(cfg.N)]
if cfg.normalization == 'layernorm':
self.out_ln = nn.LayerNorm(
dtype=cfg.dtype, param_dtype=cfg.param_dtype, use_bias=False
)
elif cfg.normalization == 'rmsnorm':
self.out_ln = nn.RMSNorm(
dtype=cfg.dtype,
param_dtype=cfg.param_dtype,
epsilon=cfg.rmsnorm_epsilon,
)
else:
raise ValueError(f'Unknown normalization: {cfg.normalization}')
self.out_ln = cfg.make_norm()

if cfg.tie_embeddings:
self.output_proj = None
Expand Down
Loading