From f4ed0e5686427dd46121789f021ec779eb974c1d Mon Sep 17 00:00:00 2001 From: Ahmed Khaled Date: Mon, 18 May 2026 14:31:35 -0700 Subject: [PATCH] internal PiperOrigin-RevId: 917429077 --- hessian/model_debugger.py | 141 ++-- hessian/model_debugger_callback.py | 31 +- hessian/precondition.py | 42 +- hessian/test_model_debugger.py | 100 ++- hessian/test_precondition.py | 80 +- init2winit/base_callback.py | 17 +- init2winit/callbacks.py | 1 - init2winit/checkpoint.py | 56 +- init2winit/dataset_lib/autoaugment.py | 198 +++-- .../dataset_lib/criteo_terabyte_dataset.py | 128 ++- init2winit/dataset_lib/data_selectors.py | 14 +- init2winit/dataset_lib/data_utils.py | 59 +- init2winit/dataset_lib/datasets.py | 2 + init2winit/dataset_lib/fake_dataset.py | 23 +- init2winit/dataset_lib/fastmri_dataset.py | 120 +-- init2winit/dataset_lib/image_preprocessing.py | 34 +- init2winit/dataset_lib/imagenet_dataset.py | 5 +- .../dataset_lib/imagenet_preprocessing.py | 123 +-- init2winit/dataset_lib/librispeech.py | 76 +- .../dataset_lib/librispeech_input_pipeline.py | 55 +- .../dataset_lib/lm1b_input_pipeline_v2.py | 125 +-- init2winit/dataset_lib/lm1b_v2.py | 52 +- .../dataset_lib/mlperf_imagenet_dataset.py | 71 +- .../dataset_lib/mlperf_input_pipeline.py | 52 +- init2winit/dataset_lib/mt_pipeline.py | 290 ++++--- init2winit/dataset_lib/mt_pipeline_test.py | 142 ++-- init2winit/dataset_lib/mt_tokenizer.py | 83 +- init2winit/dataset_lib/nanodo_c4.py | 24 +- .../dataset_lib/nanodo_data_loader_shared.py | 3 +- init2winit/dataset_lib/nanodo_fineweb_edu.py | 19 +- init2winit/dataset_lib/nqm_noise.py | 6 +- init2winit/dataset_lib/ogbg_molpcba.py | 55 +- init2winit/dataset_lib/pg19.py | 208 +++-- init2winit/dataset_lib/protein_vocab.py | 97 ++- init2winit/dataset_lib/proteins.py | 150 ++-- .../dataset_lib/small_image_datasets.py | 396 +++++---- init2winit/dataset_lib/test_data_utils.py | 9 +- init2winit/dataset_lib/test_datasets.py | 128 +-- .../test_fineweb_edu_10b_input_pipeline.py | 1 + init2winit/dataset_lib/test_ogbg_molpcba.py | 80 +- init2winit/dataset_lib/test_ogbg_static.py | 1 + .../dataset_lib/test_small_image_datasets.py | 36 +- .../dataset_lib/test_wikitext_tokenizer.py | 1 + init2winit/dataset_lib/translate_wmt.py | 95 ++- init2winit/dataset_lib/wikitext103.py | 17 +- .../dataset_lib/wikitext103_input_pipeline.py | 72 +- init2winit/dataset_lib/wikitext2.py | 33 +- .../dataset_lib/wikitext2_input_pipeline.py | 53 +- init2winit/gradient_statistics_callback.py | 26 +- init2winit/hyperparameters.py | 20 +- init2winit/init_lib/initializers.py | 2 + init2winit/init_lib/meta_init.py | 132 +-- init2winit/init_lib/sparse_init.py | 40 +- init2winit/init_lib/test_initializers.py | 35 +- init2winit/main.py | 21 +- init2winit/model_lib/adabelief_densenet.py | 38 +- init2winit/model_lib/adabelief_resnet.py | 52 +- init2winit/model_lib/adabelief_vgg.py | 84 +- init2winit/model_lib/attention.py | 335 ++++---- init2winit/model_lib/autoencoder.py | 11 +- init2winit/model_lib/base_model.py | 55 +- init2winit/model_lib/binarize_layers.py | 486 ++++++----- init2winit/model_lib/conformer.py | 351 ++++---- .../model_lib/convolutional_autoencoder.py | 26 +- init2winit/model_lib/deepspeech.py | 267 +++--- init2winit/model_lib/dlrm.py | 90 +- init2winit/model_lib/fully_connected.py | 15 +- init2winit/model_lib/gnn.py | 27 +- .../model_lib/librispeech_preprocessor.py | 381 ++++++--- .../model_lib/local_attention_transformer.py | 780 ++++++++++-------- init2winit/model_lib/lstm.py | 127 +-- init2winit/model_lib/lstm_lm.py | 27 +- init2winit/model_lib/max_pooling_cnn.py | 69 +- .../model_lib/metrics_minimized_registry.py | 11 +- init2winit/model_lib/mlperf_resnet.py | 108 ++- init2winit/model_lib/model_utils.py | 80 +- init2winit/model_lib/models.py | 1 - init2winit/model_lib/nanodo.py | 10 +- init2winit/model_lib/normalization.py | 139 ++-- init2winit/model_lib/nqm.py | 79 +- init2winit/model_lib/partition_tree.py | 3 +- init2winit/model_lib/resnet.py | 93 ++- init2winit/model_lib/rope_nanodo.py | 7 +- init2winit/model_lib/simple_cnn.py | 33 +- init2winit/model_lib/spectrum_augmenter.py | 49 +- .../test_local_attention_transformer.py | 41 +- init2winit/model_lib/test_losses.py | 269 +++--- init2winit/model_lib/test_metrics.py | 383 ++++++--- init2winit/model_lib/test_models.py | 215 +++-- init2winit/model_lib/test_normalization.py | 72 +- init2winit/model_lib/transformer_lm.py | 184 +++-- init2winit/model_lib/transformer_stu_lm.py | 194 +++-- .../model_lib/transformer_stu_tensordot_lm.py | 198 +++-- init2winit/model_lib/unet.py | 54 +- init2winit/model_lib/vit.py | 238 ++++-- init2winit/model_lib/wide_resnet.py | 48 +- init2winit/model_lib/xformer_translate.py | 613 ++++++++------ .../model_lib/xformer_translate_binary.py | 527 ++++++------ .../xformer_translate_mlc_variant.py | 596 +++++++------ init2winit/mt_eval/decode.py | 366 ++++---- init2winit/mt_eval/eval_utils.py | 17 +- init2winit/mt_eval/inference.py | 122 +-- init2winit/mt_eval/mt_callback.py | 87 +- init2winit/optimizer_lib/factor_sam.py | 18 +- .../optimizer_lib/gradient_accumulator.py | 39 +- .../optimizer_lib/kitchen_sink/__init__.py | 1 - .../optimizer_lib/kitchen_sink/_src/alias.py | 51 +- .../kitchen_sink/_src/combine.py | 27 +- .../optimizer_lib/kitchen_sink/_src/core.py | 9 +- .../optimizer_lib/kitchen_sink/_src/mask.py | 8 +- .../kitchen_sink/_src/preconditioner.py | 59 +- .../kitchen_sink/_src/test_core.py | 131 ++- .../kitchen_sink/_src/test_mask.py | 65 +- .../kitchen_sink/_src/test_preconditioner.py | 57 +- .../kitchen_sink/_src/test_transform.py | 224 ++--- .../kitchen_sink/_src/transform.py | 90 +- .../optimizer_lib/kitchen_sink/_src/utils.py | 25 +- .../linalg/low_rank_root_update.py | 18 +- .../linalg/low_rank_root_update_test.py | 24 +- .../linalg/paterson_stockmeyer.py | 8 +- .../optimizer_lib/linalg/pth_inv_root_rmn.py | 17 +- .../linalg/pth_inv_root_rmn_coefficients.py | 35 +- .../linalg/pth_inv_root_rmn_test.py | 48 +- .../optimizer_lib/linalg/root_selector.py | 13 +- .../optimizer_lib/online_newton_step.py | 93 ++- init2winit/optimizer_lib/optimizers.py | 104 +-- init2winit/optimizer_lib/pax_adafactor.py | 102 ++- init2winit/optimizer_lib/samuel.py | 9 +- init2winit/optimizer_lib/search_subspace.py | 95 ++- .../sharpness_aware_minimization.py | 22 +- .../test_gradient_accumulator.py | 97 ++- init2winit/optimizer_lib/test_optimizers.py | 6 +- .../optimizer_lib/test_search_subspace.py | 41 +- init2winit/optimizer_lib/test_utils.py | 3 +- init2winit/optimizer_lib/utils.py | 6 +- init2winit/test_checkpoint.py | 82 +- init2winit/test_schedules.py | 391 ++++++--- init2winit/test_training_metrics_grabber.py | 351 ++++---- init2winit/test_utils.py | 10 +- init2winit/tools/inspect_dataset.py | 10 +- init2winit/trainer_lib/base_trainer.py | 74 +- .../trainer_lib/test_mdlm_integration.py | 1 + init2winit/trainer_lib/test_trainer.py | 433 ++++++---- init2winit/trainer_lib/trainer_utils.py | 5 +- init2winit/trainer_lib/trainers.py | 1 - init2winit/trainer_lib/training_algorithm.py | 1 - init2winit/training_metrics_grabber.py | 202 +++-- init2winit/utils.py | 68 +- 148 files changed, 8933 insertions(+), 6179 deletions(-) diff --git a/hessian/model_debugger.py b/hessian/model_debugger.py index 1560b0f3..02a5b344 100644 --- a/hessian/model_debugger.py +++ b/hessian/model_debugger.py @@ -34,7 +34,7 @@ def qvalue(array): - return jnp.linalg.norm(array.reshape(-1))**2 / array.size + return jnp.linalg.norm(array.reshape(-1)) ** 2 / array.size def cvalue(activations): @@ -55,10 +55,9 @@ def tag_qcvalue(module, activations, name): module.sow('qcvalues', name, qc_value) -def tag_residual_activations(module, - identity_path, - other_path, - name='residual'): +def tag_residual_activations( + module, identity_path, other_path, name='residual' +): """Used in inspecting the forward pass and residual networks. Residual connections involve adding x + F(x) for some function F. This @@ -79,7 +78,8 @@ def tag_residual_activations(module, 'qvalues', name + 'q', # avoid scope collision with the cvalue jnp.array((res_values_q, add_values_q)), - reduce_fn=lambda x, y: y) + reduce_fn=lambda x, y: y, + ) res_values_c = cvalue(identity_path) add_values_c = cvalue(other_path) @@ -87,7 +87,8 @@ def tag_residual_activations(module, 'cvalues', name + 'c', jnp.array((res_values_c, add_values_c)), - reduce_fn=lambda x, y: y) + reduce_fn=lambda x, y: y, + ) def pmap_then_unreplicate(leaf_fn, axis_name='batch'): @@ -106,10 +107,12 @@ def pmap_then_unreplicate(leaf_fn, axis_name='batch'): the values on device 0. """ p_leaf_fn = jax.pmap(leaf_fn, axis_name=axis_name) + def tree_fn(*args): vals = p_leaf_fn(*args) vals_on_host = jax.tree.map(lambda x: x[0], vals) return vals_on_host + return tree_fn @@ -138,9 +141,9 @@ def append_pytree_leaves(full_pytree, to_append): return jax.tree.map(lambda x, y: array_append(y, x), to_append, full_pytree) -def create_forward_pass_stats_fn(apply_fn, - capture_activation_norms=False, - sown_collection_names=None): +def create_forward_pass_stats_fn( + apply_fn, capture_activation_norms=False, sown_collection_names=None +): """Creates a function which grabs intermediate values from forward pass. If capture_activation_norms=True then we run the forward pass with @@ -168,6 +171,7 @@ def create_forward_pass_stats_fn(apply_fn, mutables = ['intermediates', 'batch_stats'] if sown_collection_names is not None: mutables.extend(sown_collection_names) + def get_forward_pass_statistics(params, batch, rng): _, forward_pass_statistics = apply_fn( params=params, @@ -176,7 +180,8 @@ def get_forward_pass_statistics(params, batch, rng): rngs={'dropout': rng}, capture_intermediates=capture_activation_norms, mutable=mutables, - train=True) + train=True, + ) forward_pass_statistics = flax.core.unfreeze(forward_pass_statistics) # We run in train mode but throw away the updated batch stats because we @@ -186,14 +191,17 @@ def get_forward_pass_statistics(params, batch, rng): if 'intermediates' in forward_pass_statistics: # This calculation corresponds to the average q-value across the batch. forward_pass_statistics['intermediate_qvalue'] = jax.tree.map( - qvalue, forward_pass_statistics['intermediates']) + qvalue, forward_pass_statistics['intermediates'] + ) forward_pass_statistics['intermediate_cvalue'] = jax.tree.map( - cvalue, forward_pass_statistics['intermediates']) + cvalue, forward_pass_statistics['intermediates'] + ) # Don't want to store the full activations. forward_pass_statistics.pop('intermediates') return forward_pass_statistics + return get_forward_pass_statistics @@ -215,7 +223,9 @@ def remove_leaf_tuples(tree): def unflatten(d): return flax.core.freeze( flax.traverse_util.unflatten_dict( - {tuple(k.split('/')): v for k, v in d.items()})) + {tuple(k.split('/')): v for k, v in d.items()} + ) + ) # Not a JAX Array Type - just an empty holder of static information: @@ -282,15 +292,17 @@ def build_skip_flags(paths_to_skip): class ModelDebugger: """Debugging tool for internal layers of a model.""" - def __init__(self, - forward_pass=None, - grad_fn=None, - use_pmap=True, - save_every=1, - metrics_logger=None, - pytree_metrics_logger=None, - skip_flags=None, - skip_groups=None): + def __init__( + self, + forward_pass=None, + grad_fn=None, + use_pmap=True, + save_every=1, + metrics_logger=None, + pytree_metrics_logger=None, + skip_flags=None, + skip_groups=None, + ): """Used to inspect a models forward and backward pass. The following keys are required in config - @@ -315,8 +327,10 @@ def __init__(self, attention layer jacobians contribute to the vanishing gradient problem"? """ if metrics_logger and (metrics_logger._json_path is None): - raise ValueError('To use the ModelDebugger with a metrics_logger, a json' - ' path must be specified when building metrics_logger') + raise ValueError( + 'To use the ModelDebugger with a metrics_logger, a json' + ' path must be specified when building metrics_logger' + ) self._save_every = save_every self._metrics_logger = metrics_logger self._pytree_metrics_logger = pytree_metrics_logger @@ -349,31 +363,36 @@ def __init__(self, ): self._stored_metrics = pytree_metrics_logger.load_latest_pytree() - def _grab_statistics(self, - step, - batch=None, - rng=None, - params=None, - grad=None, - update=None, - grad_norms_sql2=None, - update_norms_sql2=None, - param_norms_sql2=None): + def _grab_statistics( + self, + step, + batch=None, + rng=None, + params=None, + grad=None, + update=None, + grad_norms_sql2=None, + update_norms_sql2=None, + param_norms_sql2=None, + ): """Computes layerwise gradient and parameter norm statistics.""" metrics_dict = {'step': step} if grad is None and grad_norms_sql2 is None and self.grad_fn is not None: grad = self.grad_fn(params, batch, rng) - for tup in zip([params, grad, update], - [param_norms_sql2, grad_norms_sql2, update_norms_sql2], - ['param', 'grad', 'update']): + for tup in zip( + [params, grad, update], + [param_norms_sql2, grad_norms_sql2, update_norms_sql2], + ['param', 'grad', 'update'], + ): variable_tree, norm_tree_sql2, key = tup if variable_tree and not norm_tree_sql2: norm_tree_sql2 = self._tree_norm_fn_sql2(variable_tree) if norm_tree_sql2: metrics_dict['{}_norms_sql2'.format(key)] = norm_tree_sql2 metrics_dict['global_{}_norm_sql2'.format(key)] = sum( - jax.tree.leaves(norm_tree_sql2)) + jax.tree.leaves(norm_tree_sql2) + ) return metrics_dict @@ -451,18 +470,20 @@ def run_skip_analysis(self, params, batch, rng): grad_dict['unmodified_gradient'] = self._tree_norm_fn_sql2(new_g) return {'skip_analysis': grad_dict} - def full_eval(self, - step, - params=None, - grad=None, - update=None, - fwd_pass_summaries=None, - param_norms_sql2=None, - grad_norms_sql2=None, - update_norms_sql2=None, - extra_scalar_metrics=None, - batch=None, - rng=None): + def full_eval( + self, + step, + params=None, + grad=None, + update=None, + fwd_pass_summaries=None, + param_norms_sql2=None, + grad_norms_sql2=None, + update_norms_sql2=None, + extra_scalar_metrics=None, + batch=None, + rng=None, + ): """Computes statistics of the forward and backward pass and save to disk. Currently what is written to disk is a pytree, with a dict at the top level @@ -511,8 +532,13 @@ def full_eval(self, """ all_metrics = {'step': step} if any([ - params, grad, update, param_norms_sql2, grad_norms_sql2, - update_norms_sql2, self.grad_fn + params, + grad, + update, + param_norms_sql2, + grad_norms_sql2, + update_norms_sql2, + self.grad_fn, ]): metrics_dict = self._grab_statistics( step=step, @@ -523,7 +549,8 @@ def full_eval(self, update=update, grad_norms_sql2=grad_norms_sql2, update_norms_sql2=update_norms_sql2, - param_norms_sql2=param_norms_sql2) + param_norms_sql2=param_norms_sql2, + ) all_metrics.update(metrics_dict) if extra_scalar_metrics: @@ -535,7 +562,8 @@ def full_eval(self, if self.forward_pass and not fwd_pass_summaries: if batch is None: raise ValueError( - 'Must supply a batch when computing forward pass stats.') + 'Must supply a batch when computing forward pass stats.' + ) fwd_pass_summaries = self.forward_pass(params, batch, rng) @@ -554,7 +582,8 @@ def full_eval(self, self._stored_metrics = append_pytree_leaves( self._stored_metrics, - remove_leaf_tuples(flax.core.unfreeze(all_metrics))) + remove_leaf_tuples(flax.core.unfreeze(all_metrics)), + ) if self._metrics_logger and jax.process_index() == 0: self._maybe_save_metrics(step) diff --git a/hessian/model_debugger_callback.py b/hessian/model_debugger_callback.py index 176f3f0c..2b90d934 100644 --- a/hessian/model_debugger_callback.py +++ b/hessian/model_debugger_callback.py @@ -30,12 +30,9 @@ } -def get_grad(params, - batch, - rng, - batch_stats=None, - module_flags=None, - training_cost=None): +def get_grad( + params, batch, rng, batch_stats=None, module_flags=None, training_cost=None +): """Single step of the training loop. Args: @@ -57,13 +54,11 @@ def get_grad(params, kwargs = {'module_flags': module_flags} else: kwargs = {} + def opt_cost(params): return training_cost( - params, - batch, - batch_stats=batch_stats, - dropout_rng=rng, - **kwargs) + params, batch, batch_stats=batch_stats, dropout_rng=rng, **kwargs + ) grad_fn = jax.value_and_grad(opt_cost, has_aux=True) _, grad = grad_fn(params) @@ -104,7 +99,8 @@ def __init__( get_act_stats_fn = model_debugger.create_forward_pass_stats_fn( model.apply_on_batch, capture_activation_norms=True, - sown_collection_names=callback_config.get('sown_collection_names')) + sown_collection_names=callback_config.get('sown_collection_names'), + ) batch_stats = jax.tree.map(lambda x: x[:][0], batch_stats) grad_fn = functools.partial( get_grad, @@ -117,7 +113,8 @@ def __init__( metrics_logger=logger, grad_fn=grad_fn, skip_flags=callback_config.get('skip_flags'), - skip_groups=callback_config.get('skip_groups')) + skip_groups=callback_config.get('skip_groups'), + ) # pmap functions for the training loop # in_axes = (params = 0, batch_stats = 0, batch = 0, step = None, # lr = None, rng = None, local_device_index = 0, training_metrics_grabber=0, @@ -150,14 +147,16 @@ def run_eval(self, params, batch_stats, optimizer_state, global_step): """ del optimizer_state del batch_stats - p_norms = jax.tree.map(lambda x: jnp.linalg.norm(x[0].reshape(-1))**2, - params) + p_norms = jax.tree.map( + lambda x: jnp.linalg.norm(x[0].reshape(-1)) ** 2, params + ) self.debugger.full_eval( step=global_step, params=params, param_norms_sql2=p_norms, batch=self.batch, - rng=self.batch_rng) + rng=self.batch_rng, + ) return {} diff --git a/hessian/precondition.py b/hessian/precondition.py index 77eefd11..72aec334 100644 --- a/hessian/precondition.py +++ b/hessian/precondition.py @@ -55,11 +55,13 @@ def _make_inv_power_preconditioner(diag, eps, eps_root=0.0, power=0.5): # KS transforms for which preconditioning is implemented. -SUPPORTED_KS_TRANSFORMS = ['scale_by_adam', - 'scale_by_nadam', - 'scale_by_amsgrad', - 'precondition_by_rms', - 'precondition_by_rss'] +SUPPORTED_KS_TRANSFORMS = [ + 'scale_by_adam', + 'scale_by_nadam', + 'scale_by_amsgrad', + 'precondition_by_rms', + 'precondition_by_rss', +] def _make_ks_preconditioner(element, state, hps): @@ -75,31 +77,37 @@ def _make_ks_preconditioner(element, state, hps): """ err_msg = 'all KS hps must be set in order to compute preconditioned Hessian' if element == 'scale_by_adam' or element == 'scale_by_nadam': - if not ('b2' in hps and 'debias' in hps - and 'eps' in hps and 'eps_root' in hps): + if not ( + 'b2' in hps and 'debias' in hps and 'eps' in hps and 'eps_root' in hps + ): raise ValueError(err_msg) nu = _maybe_bias_correct(state.nu, hps['b2'], state.count, hps['debias']) - return _make_inv_power_preconditioner(nu, hps['eps'], hps['eps_root'], - hps.get('power', 0.5)) + return _make_inv_power_preconditioner( + nu, hps['eps'], hps['eps_root'], hps.get('power', 0.5) + ) elif element == 'scale_by_amsgrad': if not ('eps' in hps and 'eps_root' in hps): raise ValueError(err_msg) return _make_inv_power_preconditioner(state.nu, hps['eps'], hps['eps_root']) elif element == 'precondition_by_rms': - if not ('decay' in hps and 'debias' in hps - and 'eps' in hps and 'eps_root' in hps): + if not ( + 'decay' in hps + and 'debias' in hps + and 'eps' in hps + and 'eps_root' in hps + ): raise ValueError(err_msg) nu = _maybe_bias_correct(state.nu, hps['decay'], state.count, hps['debias']) return _make_inv_power_preconditioner(nu, hps['eps'], hps['eps_root']) elif element == 'precondition_by_rss': if 'eps' not in hps: raise ValueError(err_msg) - return _make_inv_power_preconditioner(state.sum_of_squares, - hps['eps'], 0.0) + return _make_inv_power_preconditioner(state.sum_of_squares, hps['eps'], 0.0) -def make_diag_preconditioner(optimizer, opt_hparams, - optimizer_state, precondition_config): +def make_diag_preconditioner( + optimizer, opt_hparams, optimizer_state, precondition_config +): """Construct a diagonal preconditioner. Given an optimizer and its state, return that optimizer's preconditioner. @@ -119,6 +127,7 @@ def make_diag_preconditioner(optimizer, opt_hparams, opt_hparams: (dict) The opt_hparams dict from the init2winit config. optimizer_state: (pytree) The unreplicated optimizer state. precondition_config: (ConfigDict) Configs for the preconditioner. + Returns: (pytree) diagonal preconditioner """ @@ -138,7 +147,8 @@ def make_diag_preconditioner(optimizer, opt_hparams, # a single KS transform chain, and a single preconditioner in that chain. if optimizer == 'kitchen_sink' and '0' in opt_hparams: precondition_steps = [ - step for step in opt_hparams.keys() + step + for step in opt_hparams.keys() if opt_hparams[step]['element'] in SUPPORTED_KS_TRANSFORMS ] if len(precondition_steps) != 1: diff --git a/hessian/test_model_debugger.py b/hessian/test_model_debugger.py index 9e92a005..ce74e911 100644 --- a/hessian/test_model_debugger.py +++ b/hessian/test_model_debugger.py @@ -123,7 +123,8 @@ def test_model_debugger(self): debugger = model_debugger.ModelDebugger(use_pmap=False) rep_rng = flax.jax_utils.replicate(rng) metrics = debugger.full_eval( - 10, params=variables['params'], grad=variables['params'], rng=rep_rng) + 10, params=variables['params'], grad=variables['params'], rng=rep_rng + ) expected_keys = [ 'step', 'global_param_norm_sql2', @@ -150,23 +151,33 @@ def apply_on_batch(params, batch_stats, batch, **apply_kwargs): get_act_stats_fn = model_debugger.create_forward_pass_stats_fn( apply_on_batch, capture_activation_norms=True, - sown_collection_names=['qvalues']) + sown_collection_names=['qvalues'], + ) debugger = model_debugger.ModelDebugger( - forward_pass=get_act_stats_fn, use_pmap=True) + forward_pass=get_act_stats_fn, use_pmap=True + ) metrics = debugger.full_eval( - step=0, params=rep_params, batch=xs, rng=rep_rng) + step=0, params=rep_params, batch=xs, rng=rep_rng + ) expected_output = np.dot(xs[0], params['Dense_0']['kernel']) - expected_q_value = np.linalg.norm(expected_output)**2 / expected_output.size + expected_q_value = ( + np.linalg.norm(expected_output) ** 2 / expected_output.size + ) expected_c_value = model_debugger.cvalue(expected_output) - expected_output_norm = np.linalg.norm( - expected_output)**2 / expected_output.size - expected_input_norm = float(np.linalg.norm(xs))**2 / xs.size + expected_output_norm = ( + np.linalg.norm(expected_output) ** 2 / expected_output.size + ) + expected_input_norm = float(np.linalg.norm(xs)) ** 2 / xs.size expected_keys = [ - 'qvalues', 'intermediate_qvalue', 'intermediate_cvalue', 'step', - 'param_norms_sql2', 'global_param_norm_sql2' + 'qvalues', + 'intermediate_qvalue', + 'intermediate_cvalue', + 'step', + 'param_norms_sql2', + 'global_param_norm_sql2', ] self.assertEqual(set(expected_keys), set(metrics.keys())) @@ -174,19 +185,21 @@ def apply_on_batch(params, batch_stats, batch, **apply_kwargs): self.assertAlmostEqual( float(expected_q_value), float(metrics['intermediate_qvalue']['__call__'][0]), - places=5) + places=5, + ) self.assertAlmostEqual( float(expected_c_value), float(metrics['intermediate_cvalue']['__call__'][0]), - places=5) + places=5, + ) self.assertAlmostEqual( - expected_input_norm, - float(metrics['qvalues']['residualq'][0]), - places=5) + expected_input_norm, float(metrics['qvalues']['residualq'][0]), places=5 + ) self.assertAlmostEqual( expected_output_norm, float(metrics['qvalues']['residualq'][1]), - places=5) + places=5, + ) def test_skip_analysis(self): """Test that we can selectively turn off layers in the backward pass.""" @@ -208,8 +221,8 @@ class B(nn.Module): @nn.compact def __call__(self, x): x = nn.Dense( - 1, kernel_init=nn.initializers.constant(3), use_bias=False)( - x) + 1, kernel_init=nn.initializers.constant(3), use_bias=False + )(x) y = C()(x) return y @@ -219,8 +232,8 @@ class C(nn.Module): @nn.compact def __call__(self, x): x = nn.Dense( - 1, kernel_init=nn.initializers.constant(2), use_bias=False)( - x) + 1, kernel_init=nn.initializers.constant(2), use_bias=False + )(x) return x init_rng = jax.random.PRNGKey(0) @@ -244,15 +257,21 @@ def fake_loss_fn(ps, x, rng, module_flags=None): use_pmap=True, grad_fn=grad_fn, skip_flags=['B_0/C_0'], - skip_groups=['test_group']) + skip_groups=['test_group'], + ) rep_params = flax.jax_utils.replicate(ps) rep_rng = flax.jax_utils.replicate(rng) rep_x = flax.jax_utils.replicate(x) metrics = debugger.full_eval( - step=10, params=rep_params, rng=rep_rng, batch=rep_x) + step=10, params=rep_params, rng=rep_rng, batch=rep_x + ) expected_keys = [ - 'step', 'global_param_norm_sql2', 'param_norms_sql2', 'grad_norms_sql2', - 'global_grad_norm_sql2', 'skip_analysis' + 'step', + 'global_param_norm_sql2', + 'param_norms_sql2', + 'grad_norms_sql2', + 'global_grad_norm_sql2', + 'skip_analysis', ] self.assertEqual(set(expected_keys), set(metrics.keys())) skip_dict = metrics['skip_analysis'] @@ -286,8 +305,7 @@ def test_model_debugger_restore(self): rep_variables = set_up_cnn() pytree_path = os.path.join(self.test_dir, 'metrics') - pytree_metrics_logger = utils.PytreeMetricLogger( - pytree_path=pytree_path) + pytree_metrics_logger = utils.PytreeMetricLogger(pytree_path=pytree_path) metrics_logger = utils.MetricLogger(self.test_dir) # Fake grad_fn for testing. @@ -299,7 +317,8 @@ def grad_fn(params, batch, rng): use_pmap=False, grad_fn=grad_fn, metrics_logger=metrics_logger, - pytree_metrics_logger=pytree_metrics_logger) + pytree_metrics_logger=pytree_metrics_logger, + ) # eval twice to test the concat extra_metrics = {'train_loss': 1.0} @@ -308,12 +327,14 @@ def grad_fn(params, batch, rng): 10, params=rep_variables['params'], grad=rep_variables['params'], - extra_scalar_metrics=extra_metrics) + extra_scalar_metrics=extra_metrics, + ) metrics = debugger.full_eval( 20, params=rep_variables['params'], grad=None, # use internal gradient comp - extra_scalar_metrics=extra_metrics2) + extra_scalar_metrics=extra_metrics2, + ) expected_keys = [ 'step', 'global_param_norm_sql2', @@ -331,28 +352,35 @@ def grad_fn(params, batch, rng): self.assertEqual(metrics['global_grad_norm_sql2'].shape, expected_shape) # Test stored metrics is concatenated. expected_shape = (2,) - self.assertEqual(loaded_metrics['global_grad_norm_sql2'].shape, - expected_shape) + self.assertEqual( + loaded_metrics['global_grad_norm_sql2'].shape, expected_shape + ) # check param norms were saved correctly self.assertEqual( - loaded_metrics['param_norms_sql2']['Conv_0']['kernel'].shape, (2,)) + loaded_metrics['param_norms_sql2']['Conv_0']['kernel'].shape, (2,) + ) self.assertEqual(loaded_metrics['train_loss'][0], 1.0) # Test restore of prior metrics. new_debugger = model_debugger.ModelDebugger( use_pmap=True, metrics_logger=metrics_logger, - pytree_metrics_logger=pytree_metrics_logger) + pytree_metrics_logger=pytree_metrics_logger, + ) _ = new_debugger.full_eval( 30, params=rep_variables['params'], grad=rep_variables['params'], - extra_scalar_metrics=extra_metrics2) + extra_scalar_metrics=extra_metrics2, + ) pytree_metrics_logger.wait_until_pytree_checkpoint_finished() self.assertEqual( - new_debugger.stored_metrics['param_norms_sql2']['Conv_0'] - ['kernel'].shape, (3,)) + new_debugger.stored_metrics['param_norms_sql2']['Conv_0'][ + 'kernel' + ].shape, + (3,), + ) if __name__ == '__main__': diff --git a/hessian/test_precondition.py b/hessian/test_precondition.py index 6cefd5a1..1ffaa9b4 100644 --- a/hessian/test_precondition.py +++ b/hessian/test_precondition.py @@ -27,8 +27,7 @@ import optax -def _calculate_adam_preconditioner(gradients, beta2, epsilon, - bias_correct): +def _calculate_adam_preconditioner(gradients, beta2, epsilon, bias_correct): """Compute the Adam preconditioner after several steps of training. Args: @@ -43,9 +42,11 @@ def _calculate_adam_preconditioner(gradients, beta2, epsilon, nu = jax.tree.map(lambda x: 0.0, gradients[0]) for gradient in gradients: gradient_sq = jax.tree.map(jnp.square, gradient) - nu = jax.tree.map(lambda nu, g: beta2*nu + (1 - beta2)*g, nu, gradient_sq) + nu = jax.tree.map( + lambda nu, g: beta2 * nu + (1 - beta2) * g, nu, gradient_sq + ) if bias_correct: - nu = jax.tree.map(lambda nu: nu / (1 - beta2**(len(gradients))), nu) + nu = jax.tree.map(lambda nu: nu / (1 - beta2 ** (len(gradients))), nu) return jax.tree.map(lambda nu: jnp.sqrt(nu) + epsilon, nu) @@ -61,11 +62,9 @@ def test_adam(self): beta2 = 0.999 epsilon = 1e-7 - opt_hparams = FrozenConfigDict({ - 'beta1': beta1, - 'beta2': beta2, - 'epsilon': epsilon - }) + opt_hparams = FrozenConfigDict( + {'beta1': beta1, 'beta2': beta2, 'epsilon': epsilon} + ) hparams = FrozenConfigDict({ 'optimizer': 'adam', 'opt_hparams': opt_hparams, @@ -77,8 +76,10 @@ def test_adam(self): init_fn, update_fn = optimizers.get_optimizer(hparams) params = {'foo': 1.0, 'bar': {'baz': 3.0}} - gradients = [{'foo': 0.5, 'bar': {'baz': 0.1}}, - {'foo': 0.2, 'bar': {'baz': 0.6}}] + gradients = [ + {'foo': 0.5, 'bar': {'baz': 0.1}}, + {'foo': 0.2, 'bar': {'baz': 0.6}}, + ] optimizer_state = init_fn(params) optimizer_state.hyperparams['learning_rate'] = lr @@ -89,21 +90,29 @@ def test_adam(self): # yes bias correction expected_preconditioner = _calculate_adam_preconditioner( - gradients, beta2, epsilon, bias_correct=True) + gradients, beta2, epsilon, bias_correct=True + ) preconditioner = make_diag_preconditioner( - 'adam', opt_hparams, optimizer_state, - FrozenConfigDict(dict(bias_correction=True))) + 'adam', + opt_hparams, + optimizer_state, + FrozenConfigDict(dict(bias_correction=True)), + ) self.assertTrue(pytree_allclose(expected_preconditioner, preconditioner)) # no bias correction expected_preconditioner = _calculate_adam_preconditioner( - gradients, beta2, epsilon, bias_correct=False) + gradients, beta2, epsilon, bias_correct=False + ) preconditioner = make_diag_preconditioner( - 'adam', opt_hparams, optimizer_state, - FrozenConfigDict(dict(bias_correction=False))) + 'adam', + opt_hparams, + optimizer_state, + FrozenConfigDict(dict(bias_correction=False)), + ) self.assertTrue(pytree_allclose(expected_preconditioner, preconditioner)) @@ -116,10 +125,7 @@ def test_adam_ks(self): optimizer = 'kitchen_sink' opt_hparams = FrozenConfigDict({ - '0': { - 'element': 'sanitize_values', - 'hps': {} - }, + '0': {'element': 'sanitize_values', 'hps': {}}, '1': { 'element': 'scale_by_adam', 'hps': { @@ -127,15 +133,10 @@ def test_adam_ks(self): 'b2': beta2, 'eps': epsilon, 'eps_root': 0.0, - 'debias': True - } - }, - '2': { - 'element': 'clip_updates', - 'hps': { - 'clip_threshold': 1000.0 - } + 'debias': True, + }, }, + '2': {'element': 'clip_updates', 'hps': {'clip_threshold': 1000.0}}, }) hparams = FrozenConfigDict({ @@ -149,27 +150,28 @@ def test_adam_ks(self): init_fn, update_fn = optimizers.get_optimizer(hparams) params = {'foo': 1.0, 'bar': {'baz': 3.0}} - gradients = [{'foo': 0.5, 'bar': {'baz': 0.1}}, - {'foo': 0.2, 'bar': {'baz': 0.6}}] + gradients = [ + {'foo': 0.5, 'bar': {'baz': 0.1}}, + {'foo': 0.2, 'bar': {'baz': 0.6}}, + ] optimizer_state = init_fn(params) optimizer_state.hyperparams['learning_rate'] = lr for gradient in gradients: - updates, optimizer_state = update_fn(gradient, - optimizer_state, - params) + updates, optimizer_state = update_fn(gradient, optimizer_state, params) params = optax.apply_updates(params, updates) expected_preconditioner = _calculate_adam_preconditioner( - gradients, beta2, epsilon, bias_correct=True) + gradients, beta2, epsilon, bias_correct=True + ) preconditioner = make_diag_preconditioner( - optimizer, opt_hparams, optimizer_state, - FrozenConfigDict(dict())) + optimizer, opt_hparams, optimizer_state, FrozenConfigDict(dict()) + ) + + self.assertTrue(pytree_allclose(expected_preconditioner, preconditioner)) - self.assertTrue(pytree_allclose( - expected_preconditioner, preconditioner)) if __name__ == '__main__': absltest.main() diff --git a/init2winit/base_callback.py b/init2winit/base_callback.py index 835c0f01..c4ef8317 100644 --- a/init2winit/base_callback.py +++ b/init2winit/base_callback.py @@ -37,9 +37,20 @@ class BaseCallBack: """Base callback to specify the required API.""" - def __init__(self, model, params, batch_stats, optimizer_state, - optimizer_update_fn, dataset, hps, callback_config, train_dir, - rng, mesh): + def __init__( + self, + model, + params, + batch_stats, + optimizer_state, + optimizer_update_fn, + dataset, + hps, + callback_config, + train_dir, + rng, + mesh, + ): """Defines the API for callback construction.""" pass diff --git a/init2winit/callbacks.py b/init2winit/callbacks.py index 42940e00..98d6db9a 100644 --- a/init2winit/callbacks.py +++ b/init2winit/callbacks.py @@ -19,7 +19,6 @@ from init2winit.hessian import model_debugger_callback from init2winit.mt_eval import mt_callback - _ALL_CALLBACKS = { 'mt': mt_callback.MTEvaluationCallback, 'model_debugger': model_debugger_callback.ModelDebugCallback, diff --git a/init2winit/checkpoint.py b/init2winit/checkpoint.py index 49ed335f..aa6364af 100644 --- a/init2winit/checkpoint.py +++ b/init2winit/checkpoint.py @@ -18,9 +18,11 @@ This is useful for training neural networks with stax, where model parameters are nested numpy arrays. """ + from absl import flags from absl import logging import jax + # pylint: disable=g-importing-member from jax.experimental.multihost_utils import process_allgather import orbax.checkpoint as ocp @@ -49,7 +51,8 @@ def maybe_restore_checkpoint( unreplicated_batch_stats, unreplicated_training_metrics_state, orbax_checkpoint_manager=None, - orbax_checkpoint_manager_external=None): + orbax_checkpoint_manager_external=None, +): """Optionally restores from a checkpoint. The checkpoint logic is as follows: if `orbax_checkpoint_manager` contains @@ -96,7 +99,7 @@ def maybe_restore_checkpoint( # train_dir does not exist or if it exists and contains no checkpoints. # Note that we could likely change the below line to: # found_checkpoint = latest_ckpt != unreplicated_checkpoint_state - found_checkpoint = (latest_ckpt['global_step'] != uninitialized_global_step) + found_checkpoint = latest_ckpt['global_step'] != uninitialized_global_step # If there's a latest checkpoint in the train_dir, restore from that. if found_checkpoint: @@ -123,7 +126,8 @@ def maybe_restore_checkpoint( 0, # global_step 0, # sum_train_cost 0, # preemption_count - False) # is_restored + False, + ) # is_restored else: # Else, don't restore from any checkpoint. return ( unreplicated_optimizer_state, @@ -133,7 +137,8 @@ def maybe_restore_checkpoint( 0, # global_step 0, # sum_train_cost 0, # preemption_count - False) # is_restored + False, + ) # is_restored return ( ckpt_to_return['optimizer_state'], @@ -143,7 +148,8 @@ def maybe_restore_checkpoint( ckpt_to_return['global_step'], # global_step ckpt_to_return['sum_train_cost'], ckpt_to_return['preemption_count'], # preemption_count - is_restored) # is_restored + is_restored, + ) # is_restored def unreplicate_and_save_checkpoint( @@ -154,38 +160,39 @@ def unreplicate_and_save_checkpoint( global_step, preemption_count, sum_train_cost, - orbax_checkpoint_manager): + orbax_checkpoint_manager, +): """Saves pytree, step, preemption_count, and sum_train_cost to train_dir.""" logging.info('Saving checkpoint to ckpt_%d', global_step) # jax.device_get doesn't work if jax.Array lives on multiple hosts. # So we first all_gather it to the host and then call jax.device_get if jax.process_count() > 1: unreplicated_optimizer_state = jax.device_get( - process_allgather(optimizer_state, tiled=True)) + process_allgather(optimizer_state, tiled=True) + ) unreplicated_params = jax.device_get(process_allgather(params, tiled=True)) else: unreplicated_optimizer_state = jax.device_get(optimizer_state) unreplicated_params = jax.device_get(params) unreplicated_batch_stats = jax.device_get(batch_stats) - unreplicated_training_metrics_state = jax.device_get( - training_metrics_state) + unreplicated_training_metrics_state = jax.device_get(training_metrics_state) unreplicated_sum_train_cost = jax.device_get(sum_train_cost) - state = dict(global_step=global_step, - preemption_count=preemption_count, - sum_train_cost=unreplicated_sum_train_cost, - optimizer_state=unreplicated_optimizer_state, - params=unreplicated_params, - batch_stats=unreplicated_batch_stats, - training_metrics_grabber=unreplicated_training_metrics_state) - save_checkpoint(global_step, - state, - orbax_checkpoint_manager=orbax_checkpoint_manager) + state = dict( + global_step=global_step, + preemption_count=preemption_count, + sum_train_cost=unreplicated_sum_train_cost, + optimizer_state=unreplicated_optimizer_state, + params=unreplicated_params, + batch_stats=unreplicated_batch_stats, + training_metrics_grabber=unreplicated_training_metrics_state, + ) + save_checkpoint( + global_step, state, orbax_checkpoint_manager=orbax_checkpoint_manager + ) logging.info('Done saving checkpoint.') -def save_checkpoint(step, - state, - orbax_checkpoint_manager): +def save_checkpoint(step, state, orbax_checkpoint_manager): """Saves checkpoint to train_dir. A list of checkpoints will be stored in train_dir/step. @@ -229,9 +236,10 @@ def load_latest_checkpoint(target=None, orbax_checkpoint_manager=None): """Loads the most recent checkpoint listed in train_dir. Args: - target: used for checkpointing, a pytree whose structure will be used - to structure the restored checkpoint data. + target: used for checkpointing, a pytree whose structure will be used to + structure the restored checkpoint data. orbax_checkpoint_manager: An orbax.CheckpointManager instance. + Returns: The state restored from the checkpoint. If using Flax checkpointing and target=None, this will return a unstructured dictionary containing the diff --git a/init2winit/dataset_lib/autoaugment.py b/init2winit/dataset_lib/autoaugment.py index 47bd45af..b4407c42 100644 --- a/init2winit/dataset_lib/autoaugment.py +++ b/init2winit/dataset_lib/autoaugment.py @@ -34,20 +34,22 @@ AutoAugment Reference: https://arxiv.org/abs/1805.09501 RandAugment Reference: https://arxiv.org/abs/1909.13719 """ + from __future__ import absolute_import from __future__ import division from __future__ import print_function import inspect import math + import tensorflow.compat.v1 as tf + from tensorflow.contrib import image as contrib_image from tensorflow.contrib import training as contrib_training - # This signifies the max integer that the controller RNN could predict for the # augmentation scheme. -_MAX_LEVEL = 10. +_MAX_LEVEL = 10.0 def policy_v0(): @@ -149,11 +151,10 @@ def cutout(image, pad_size, replace=0): Args: image: An image Tensor of type uint8. - pad_size: Specifies how big the zero mask that will be generated is that - is applied to the image. The mask will be of size - (2*pad_size x 2*pad_size). - replace: What pixel value to fill in the image in the area that has - the cutout mask applied to it. + pad_size: Specifies how big the zero mask that will be generated is that is + applied to the image. The mask will be of size (2*pad_size x 2*pad_size). + replace: What pixel value to fill in the image in the area that has the + cutout mask applied to it. Returns: An image Tensor that is of type uint8. @@ -163,30 +164,31 @@ def cutout(image, pad_size, replace=0): # Sample the center location in the image where the zero mask will be applied. cutout_center_height = tf.random_uniform( - shape=[], minval=0, maxval=image_height, - dtype=tf.int32) + shape=[], minval=0, maxval=image_height, dtype=tf.int32 + ) cutout_center_width = tf.random_uniform( - shape=[], minval=0, maxval=image_width, - dtype=tf.int32) + shape=[], minval=0, maxval=image_width, dtype=tf.int32 + ) lower_pad = tf.maximum(0, cutout_center_height - pad_size) upper_pad = tf.maximum(0, image_height - cutout_center_height - pad_size) left_pad = tf.maximum(0, cutout_center_width - pad_size) right_pad = tf.maximum(0, image_width - cutout_center_width - pad_size) - cutout_shape = [image_height - (lower_pad + upper_pad), - image_width - (left_pad + right_pad)] + cutout_shape = [ + image_height - (lower_pad + upper_pad), + image_width - (left_pad + right_pad), + ] padding_dims = [[lower_pad, upper_pad], [left_pad, right_pad]] mask = tf.pad( - tf.zeros(cutout_shape, dtype=image.dtype), - padding_dims, constant_values=1) + tf.zeros(cutout_shape, dtype=image.dtype), padding_dims, constant_values=1 + ) mask = tf.expand_dims(mask, -1) mask = tf.tile(mask, [1, 1, 3]) image = tf.where( - tf.equal(mask, 0), - tf.ones_like(image, dtype=image.dtype) * replace, - image) + tf.equal(mask, 0), tf.ones_like(image, dtype=image.dtype) * replace, image + ) return image @@ -250,8 +252,8 @@ def rotate(image, degrees, replace): degrees: Float, a scalar angle in degrees to rotate all images by. If degrees is positive the image will be rotated clockwise otherwise it will be rotated counterclockwise. - replace: A one or three value 1D tensor to fill empty pixels caused by - the rotate operation. + replace: A one or three value 1D tensor to fill empty pixels caused by the + rotate operation. Returns: The rotated version of image. @@ -286,7 +288,8 @@ def shear_x(image, level, replace): # [1 level # 0 1]. image = contrib_image.transform( - wrap(image), [1., level, 0., 0., 1., 0., 0., 0.]) + wrap(image), [1.0, level, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0] + ) return unwrap(image, replace) @@ -297,7 +300,8 @@ def shear_y(image, level, replace): # [1 0 # level 1]. image = contrib_image.transform( - wrap(image), [1., 0., 0., level, 1., 0., 0., 0.]) + wrap(image), [1.0, 0.0, 0.0, level, 1.0, 0.0, 0.0, 0.0] + ) return unwrap(image, replace) @@ -347,9 +351,14 @@ def sharpness(image, factor): # Make image 4D for conv operation. image = tf.expand_dims(image, 0) # SMOOTH PIL Kernel. - kernel = tf.constant( - [[1, 1, 1], [1, 5, 1], [1, 1, 1]], dtype=tf.float32, - shape=[3, 3, 1, 1]) / 13. + kernel = ( + tf.constant( + [[1, 1, 1], [1, 5, 1], [1, 1, 1]], + dtype=tf.float32, + shape=[3, 3, 1, 1], + ) + / 13.0 + ) # Tile across channel dimension. kernel = tf.tile(kernel, [1, 1, 3, 1]) strides = [1, 1, 1, 1] @@ -357,7 +366,8 @@ def sharpness(image, factor): # Some augmentation that uses depth-wise conv will cause crashing when # training on GPU. degenerate = tf.nn.depthwise_conv2d( - image, kernel, strides, padding='VALID', rate=[1, 1]) + image, kernel, strides, padding='VALID', rate=[1, 1] + ) degenerate = tf.clip_by_value(degenerate, 0.0, 255.0) degenerate = tf.squeeze(tf.cast(degenerate, tf.uint8), [0]) @@ -374,6 +384,7 @@ def sharpness(image, factor): def equalize(image): """Implements Equalize function from PIL using TF ops.""" + def scale_channel(im, c): """Scale the data in the channel to implement equalize.""" im = tf.cast(im[:, :, c], tf.int32) @@ -397,9 +408,11 @@ def build_lut(histo, step): # If step is zero, return the original image. Otherwise, build # lut from the full histogram and step and then index from it. - result = tf.cond(tf.equal(step, 0), - lambda: im, - lambda: tf.gather(build_lut(histo, step), im)) + result = tf.cond( + tf.equal(step, 0), + lambda: im, + lambda: tf.gather(build_lut(histo, step), im), + ) return tf.cast(result, tf.uint8) @@ -457,7 +470,8 @@ def unwrap(image, replace): flattened_image = tf.where( tf.equal(alpha_channel, 0), tf.ones_like(flattened_image, dtype=image.dtype) * replace, - flattened_image) + flattened_image, + ) image = tf.reshape(flattened_image, image_shape) image = tf.slice(image, [0, 0, 0], [image_shape[0], image_shape[1], 3]) @@ -492,7 +506,7 @@ def _randomly_negate_tensor(tensor): def _rotate_level_to_arg(level): - level = (level/_MAX_LEVEL) * 30. + level = (level / _MAX_LEVEL) * 30.0 level = _randomly_negate_tensor(level) return (level,) @@ -502,23 +516,23 @@ def _shrink_level_to_arg(level): if level == 0: return (1.0,) # if level is zero, do not shrink the image # Maximum shrinking ratio is 2.9. - level = 2. / (_MAX_LEVEL / level) + 0.9 + level = 2.0 / (_MAX_LEVEL / level) + 0.9 return (level,) def _enhance_level_to_arg(level): - return ((level/_MAX_LEVEL) * 1.8 + 0.1,) + return ((level / _MAX_LEVEL) * 1.8 + 0.1,) def _shear_level_to_arg(level): - level = (level/_MAX_LEVEL) * 0.3 + level = (level / _MAX_LEVEL) * 0.3 # Flip level to negative with 50% chance. level = _randomly_negate_tensor(level) return (level,) def _translate_level_to_arg(level, translate_const): - level = (level/_MAX_LEVEL) * float(translate_const) + level = (level / _MAX_LEVEL) * float(translate_const) # Flip level to negative with 50% chance. level = _randomly_negate_tensor(level) return (level,) @@ -530,21 +544,25 @@ def level_to_arg(hparams): 'Equalize': lambda level: (), 'Invert': lambda level: (), 'Rotate': _rotate_level_to_arg, - 'Posterize': lambda level: (int((level/_MAX_LEVEL) * 4),), - 'Solarize': lambda level: (int((level/_MAX_LEVEL) * 256),), - 'SolarizeAdd': lambda level: (int((level/_MAX_LEVEL) * 110),), + 'Posterize': lambda level: (int((level / _MAX_LEVEL) * 4),), + 'Solarize': lambda level: (int((level / _MAX_LEVEL) * 256),), + 'SolarizeAdd': lambda level: (int((level / _MAX_LEVEL) * 110),), 'Color': _enhance_level_to_arg, 'Contrast': _enhance_level_to_arg, 'Brightness': _enhance_level_to_arg, 'Sharpness': _enhance_level_to_arg, 'ShearX': _shear_level_to_arg, 'ShearY': _shear_level_to_arg, - 'Cutout': lambda level: (int((level/_MAX_LEVEL) * hparams.cutout_const),), + 'Cutout': lambda level: ( + int((level / _MAX_LEVEL) * hparams.cutout_const), + ), # pylint:disable=g-long-lambda 'TranslateX': lambda level: _translate_level_to_arg( - level, hparams.translate_const), + level, hparams.translate_const + ), 'TranslateY': lambda level: _translate_level_to_arg( - level, hparams.translate_const), + level, hparams.translate_const + ), # pylint:enable=g-long-lambda } @@ -589,11 +607,11 @@ def _apply_func_with_prob(func, image, args, prob): # Apply the function with probability `prob`. should_apply_op = tf.cast( - tf.floor(tf.random_uniform([], dtype=tf.float32) + prob), tf.bool) + tf.floor(tf.random_uniform([], dtype=tf.float32) + prob), tf.bool + ) augmented_image = tf.cond( - should_apply_op, - lambda: func(image, *args), - lambda: image) + should_apply_op, lambda: func(image, *args), lambda: image + ) return augmented_image @@ -602,16 +620,16 @@ def select_and_apply_random_policy(policies, image): policy_to_select = tf.random_uniform([], maxval=len(policies), dtype=tf.int32) # Note that using tf.case instead of tf.conds would result in significantly # larger graphs and would even break export for some larger policies. - for (i, policy) in enumerate(policies): + for i, policy in enumerate(policies): image = tf.cond( tf.equal(i, policy_to_select), lambda selected_policy=policy: selected_policy(image), - lambda: image) + lambda: image, + ) return image -def build_and_apply_nas_policy(policies, image, - augmentation_hparams): +def build_and_apply_nas_policy(policies, image, augmentation_hparams): """Build a policy from the given policies passed in and apply to image. Args: @@ -643,19 +661,20 @@ def build_and_apply_nas_policy(policies, image, policy_info = list(policy_info) + [replace_value, augmentation_hparams] tf_policy.append(_parse_policy_info(*policy_info)) + # Now build the tf policy that will apply the augmentation procedue # on image. def make_final_policy(tf_policy_): def final_policy(image_): for func, prob, args in tf_policy_: - image_ = _apply_func_with_prob( - func, image_, args, prob) + image_ = _apply_func_with_prob(func, image_, args, prob) return image_ + return final_policy + tf_policies.append(make_final_policy(tf_policy)) - augmented_image = select_and_apply_random_policy( - tf_policies, image) + augmented_image = select_and_apply_random_policy(tf_policies, image) return augmented_image @@ -667,25 +686,25 @@ def distort_image_with_autoaugment(image, augmentation_name): Args: image: `Tensor` of shape [height, width, 3] representing an image. augmentation_name: The name of the AutoAugment policy to use. The available - options are `v0` and `test`. `v0` is the policy used for - all of the results in the paper and was found to achieve the best results - on the COCO dataset. `v1`, `v2` and `v3` are additional good policies - found on the COCO dataset that have slight variation in what operations - were used during the search procedure along with how many operations are - applied in parallel to a single image (2 vs 3). + options are `v0` and `test`. `v0` is the policy used for all of the + results in the paper and was found to achieve the best results on the COCO + dataset. `v1`, `v2` and `v3` are additional good policies found on the + COCO dataset that have slight variation in what operations were used + during the search procedure along with how many operations are applied in + parallel to a single image (2 vs 3). Returns: A tuple containing the augmented versions of `image`. """ - available_policies = {'v0': policy_v0, - 'test': policy_vtest} + available_policies = {'v0': policy_v0, 'test': policy_vtest} if augmentation_name not in available_policies: raise ValueError('Invalid augmentation_name: {}'.format(augmentation_name)) policy = available_policies[augmentation_name]() # Hparams that will be used for AutoAugment. augmentation_hparams = contrib_training.HParams( - cutout_const=100, translate_const=250) + cutout_const=100, translate_const=250 + ) return build_and_apply_nas_policy(policy, image, augmentation_hparams) @@ -701,8 +720,8 @@ def distort_image_with_randaugment(image, num_layers, magnitude, key): sequentially to an image. Represented as (N) in the paper. Usually best values will be in the range [1, 3]. magnitude: Integer, shared magnitude across all augmentation operations. - Represented as (M) in the paper. Usually best values are in the range - [5, 30]. + Represented as (M) in the paper. Usually best values are in the range [5, + 30]. key: an rng key from tf.random.experimental.stateless_fold_in. Returns: @@ -711,34 +730,49 @@ def distort_image_with_randaugment(image, num_layers, magnitude, key): replace_value = [128] * 3 tf.logging.info('Using RandAug.') augmentation_hparams = contrib_training.HParams( - cutout_const=40, translate_const=100) + cutout_const=40, translate_const=100 + ) available_ops = [ - 'AutoContrast', 'Equalize', 'Invert', 'Rotate', 'Posterize', - 'Solarize', 'Color', 'Contrast', 'Brightness', 'Sharpness', - 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Cutout', 'SolarizeAdd'] + 'AutoContrast', + 'Equalize', + 'Invert', + 'Rotate', + 'Posterize', + 'Solarize', + 'Color', + 'Contrast', + 'Brightness', + 'Sharpness', + 'ShearX', + 'ShearY', + 'TranslateX', + 'TranslateY', + 'Cutout', + 'SolarizeAdd', + ] for layer_num in range(num_layers): key = tf.random.experimental.stateless_fold_in(key, layer_num) - op_to_select = tf.random.stateless_uniform([], - seed=key, - maxval=len(available_ops), - dtype=tf.int32) + op_to_select = tf.random.stateless_uniform( + [], seed=key, maxval=len(available_ops), dtype=tf.int32 + ) random_magnitude = float(magnitude) with tf.name_scope('randaug_layer_{}'.format(layer_num)): - for (i, op_name) in enumerate(available_ops): + for i, op_name in enumerate(available_ops): key = tf.random.experimental.stateless_fold_in(key, i) - prob = tf.random.stateless_uniform([], - seed=key, - minval=0.2, - maxval=0.8, - dtype=tf.float32) - func, _, args = _parse_policy_info(op_name, prob, random_magnitude, - replace_value, augmentation_hparams) + prob = tf.random.stateless_uniform( + [], seed=key, minval=0.2, maxval=0.8, dtype=tf.float32 + ) + func, _, args = _parse_policy_info( + op_name, prob, random_magnitude, replace_value, augmentation_hparams + ) image = tf.cond( tf.equal(i, op_to_select), # pylint:disable=g-long-lambda lambda selected_func=func, selected_args=args: selected_func( - image, *selected_args), + image, *selected_args + ), # pylint:enable=g-long-lambda - lambda: image) + lambda: image, + ) return image diff --git a/init2winit/dataset_lib/criteo_terabyte_dataset.py b/init2winit/dataset_lib/criteo_terabyte_dataset.py index eea9316d..7133cb92 100644 --- a/init2winit/dataset_lib/criteo_terabyte_dataset.py +++ b/init2winit/dataset_lib/criteo_terabyte_dataset.py @@ -55,9 +55,32 @@ # Raw vocab sizes from # https://cloud.google.com/tpu/docs/tutorials/dlrm-dcn-2.x#run-model. _VOCAB_SIZES = [ - 39884406, 39043, 17289, 7420, 20263, 3, 7120, 1543, 63, 38532951, 2953546, - 403346, 10, 2208, 11938, 155, 4, 976, 14, 39979771, 25641295, 39664984, - 585935, 12972, 108, 36 + 39884406, + 39043, + 17289, + 7420, + 20263, + 3, + 7120, + 1543, + 63, + 38532951, + 2953546, + 403346, + 10, + 2208, + 11938, + 155, + 4, + 976, + 14, + 39979771, + 25641295, + 39664984, + 585935, + 12972, + 108, + 36, ] @@ -69,7 +92,8 @@ def _parse_example_fn(num_dense_features, example): categorical_defaults = [['00000000'] for _ in range(len(_VOCAB_SIZES))] record_defaults = label_defaults + int_defaults + categorical_defaults fields = tf.io.decode_csv( - example, record_defaults, field_delim='\t', na_value='-1') + example, record_defaults, field_delim='\t', na_value='-1' + ) num_labels = 1 features = {} @@ -87,9 +111,11 @@ def _parse_example_fn(num_dense_features, example): # We append the column index to the string to make the same id in different # columns unique. cat_features.append( - tf.strings.to_hash_bucket_fast(field + str(idx), _VOCAB_SIZES[idx])) + tf.strings.to_hash_bucket_fast(field + str(idx), _VOCAB_SIZES[idx]) + ) cat_features = tf.cast( - tf.stack(cat_features, axis=1), dtype=int_features.dtype) + tf.stack(cat_features, axis=1), dtype=int_features.dtype + ) features['inputs'] = tf.concat([int_features, cat_features], axis=1) return features @@ -100,7 +126,8 @@ def criteo_tsv_reader( file_path, num_dense_features, batch_size, - num_batches_to_prefetch): + num_batches_to_prefetch, +): """Input reader fn for pre-processed Criteo data. Raw Criteo data is assumed to be preprocessed in the following way: @@ -110,13 +137,14 @@ def criteo_tsv_reader( 4. Categorical data is bucketized and are hence tf.int32. Args: - split: a text string indicating which split, one of - {'train', 'eval_train', 'validation', 'test'}. + split: a text string indicating which split, one of {'train', 'eval_train', + 'validation', 'test'}. shuffle_rng: jax.random.PRNGKey used for shuffling, only used in training. file_path: filepath to the criteo dataset. num_dense_features: number of dense features. batch_size: per-host batch size. num_batches_to_prefetch: number of batches to prefetch. + Returns: A tf.data.Dataset object. """ @@ -128,16 +156,20 @@ def criteo_tsv_reader( is_training = split == 'train' if is_training: file_shuffle_seed, data_shuffle_seed = jax.random.split(shuffle_rng, 2) - file_shuffle_seed = data_utils.convert_jax_to_tf_random_seed(file_shuffle_seed) - data_shuffle_seed = data_utils.convert_jax_to_tf_random_seed(data_shuffle_seed) + file_shuffle_seed = data_utils.convert_jax_to_tf_random_seed( + file_shuffle_seed + ) + data_shuffle_seed = data_utils.convert_jax_to_tf_random_seed( + data_shuffle_seed + ) file_shuffle_seed = multihost_utils.broadcast_one_to_all( - file_shuffle_seed, - is_source=jax.process_index() == 0 + file_shuffle_seed, is_source=jax.process_index() == 0 ) ds = tf.data.Dataset.list_files( - file_path, shuffle=is_training, seed=file_shuffle_seed) + file_path, shuffle=is_training, seed=file_shuffle_seed + ) index = jax.process_index() num_hosts = jax.process_count() ds = ds.shard(num_hosts, index) @@ -148,7 +180,8 @@ def criteo_tsv_reader( cycle_length=64, block_length=batch_size // 8, num_parallel_calls=64, - deterministic=False) + deterministic=False, + ) if is_training: ds = ds.shuffle(buffer_size=524_288 * 100, seed=data_shuffle_seed) ds = ds.batch(batch_size, drop_remainder=is_training) @@ -253,13 +286,13 @@ def _eval_numpy_iterator( Caps the number of batches to the split size divided across hosts. If the source runs out before `num_batches`, yields zero-filled batches so every host sees the same count. - + Args: num_batches (int): number of batches to process. per_host_eval_batch_size (int): batch size per host. tf_dataset (tfds.Dataset): source tensorflow dataset. split_size (int): total number of examples in the eval split. - + Yields: Padded numpy batches. """ @@ -271,9 +304,13 @@ def _eval_numpy_iterator( # treat them all as being the same batch size. num_hosts = jax.process_count() num_batches_in_split = math.ceil( - split_size / (per_host_eval_batch_size * num_hosts)) - if (num_batches is None or num_batches < 0 or - num_batches > num_batches_in_split): + split_size / (per_host_eval_batch_size * num_hosts) + ) + if ( + num_batches is None + or num_batches < 0 + or num_batches > num_batches_in_split + ): logging.info('Setting num_batches to %d.', num_batches_in_split) num_batches = num_batches_in_split @@ -285,28 +322,33 @@ def _eval_numpy_iterator( except StopIteration: if zeros_batch is None: zeros_batch = jax.tree.map( - lambda x: np.zeros_like(x, dtype=x.dtype), batch) + lambda x: np.zeros_like(x, dtype=x.dtype), batch + ) yield zeros_batch continue batch = data_utils.maybe_pad_batch( - batch, desired_batch_size=per_host_eval_batch_size) + batch, desired_batch_size=per_host_eval_batch_size + ) yield batch -def get_criteo1tb(shuffle_rng, - batch_size, - eval_batch_size, - hps): +def get_criteo1tb(shuffle_rng, batch_size, eval_batch_size, hps): """Get the Criteo 1TB train and eval iterators.""" process_count = jax.process_count() if batch_size % process_count != 0: - raise ValueError('process_count={} must divide batch_size={}.'.format( - process_count, batch_size)) + raise ValueError( + 'process_count={} must divide batch_size={}.'.format( + process_count, batch_size + ) + ) if eval_batch_size is None: eval_batch_size = batch_size if eval_batch_size % process_count != 0: - raise ValueError('process_count={} must divide eval_batch_size={}.'.format( - process_count, eval_batch_size)) + raise ValueError( + 'process_count={} must divide eval_batch_size={}.'.format( + process_count, eval_batch_size + ) + ) per_host_eval_batch_size = eval_batch_size // process_count per_host_batch_size = batch_size // process_count @@ -446,7 +488,8 @@ def _get_criteo1tb_tsv( """Load Criteo 1TB from raw TSV files (legacy path).""" train_file_path = os.path.join(RAW_CRITEO1TB_FILE_PATH, 'train/*/*') validation_file_path = os.path.join( - RAW_CRITEO1TB_FILE_PATH, 'val_set_second_half_of_day23_not_used/*') + RAW_CRITEO1TB_FILE_PATH, 'val_set_second_half_of_day23_not_used/*' + ) test_file_path = os.path.join(RAW_CRITEO1TB_FILE_PATH, 'eval/day_23/*') train_dataset = criteo_tsv_reader( @@ -455,7 +498,8 @@ def _get_criteo1tb_tsv( file_path=train_file_path, num_dense_features=hps.num_dense_features, batch_size=per_host_batch_size, - num_batches_to_prefetch=num_batches_to_prefetch) + num_batches_to_prefetch=num_batches_to_prefetch, + ) data_utils.log_rss('train dataset created') if num_device_prefetches > 0: train_iterator_fn = lambda: data_utils.prefetch_iterator( @@ -472,7 +516,8 @@ def _get_criteo1tb_tsv( file_path=train_file_path, num_dense_features=hps.num_dense_features, batch_size=per_host_eval_batch_size, - num_batches_to_prefetch=num_batches_to_prefetch) + num_batches_to_prefetch=num_batches_to_prefetch, + ) eval_train_iterator_fn = functools.partial( _eval_numpy_iterator, per_host_eval_batch_size=per_host_eval_batch_size, @@ -486,7 +531,8 @@ def _get_criteo1tb_tsv( file_path=validation_file_path, num_dense_features=hps.num_dense_features, batch_size=per_host_eval_batch_size, - num_batches_to_prefetch=num_batches_to_prefetch) + num_batches_to_prefetch=num_batches_to_prefetch, + ) validation_iterator_fn = functools.partial( _eval_numpy_iterator, per_host_eval_batch_size=per_host_eval_batch_size, @@ -500,7 +546,8 @@ def _get_criteo1tb_tsv( file_path=test_file_path, num_dense_features=hps.num_dense_features, batch_size=per_host_eval_batch_size, - num_batches_to_prefetch=num_batches_to_prefetch) + num_batches_to_prefetch=num_batches_to_prefetch, + ) test_iterator_fn = functools.partial( _eval_numpy_iterator, per_host_eval_batch_size=per_host_eval_batch_size, @@ -524,13 +571,14 @@ def _get_criteo1tb_tsv( train_iterator_fn, eval_train_iterator_fn, validation_iterator_fn, - test_iterator_fn) + test_iterator_fn, + ) def get_fake_batch(hps): return { - 'inputs': - np.zeros((hps.batch_size, *hps.input_shape), dtype=hps.model_dtype), - 'targets': - np.zeros((hps.batch_size,), dtype=hps.model_dtype), + 'inputs': np.zeros( + (hps.batch_size, *hps.input_shape), dtype=hps.model_dtype + ), + 'targets': np.zeros((hps.batch_size,), dtype=hps.model_dtype), } diff --git a/init2winit/dataset_lib/data_selectors.py b/init2winit/dataset_lib/data_selectors.py index d64b0e0e..6e54030a 100644 --- a/init2winit/dataset_lib/data_selectors.py +++ b/init2winit/dataset_lib/data_selectors.py @@ -23,7 +23,8 @@ def noop( batch_stats, hps, global_step, - constant_base_rng): + constant_base_rng, +): """An example no-op data selector that just yields the next batch. Args: @@ -35,8 +36,8 @@ def noop( global_step: the current global step. constant_base_rng: the RNG used for the experiment. IMPORTANT NOTE: this will be constant for all calls to this function, in order to get a unique - RNG each time we need to do - `rng = jax.random.fold_in(constant_base_rng, global_step)`. + RNG each time we need to do `rng = jax.random.fold_in(constant_base_rng, + global_step)`. Yields: A batch of data. @@ -57,7 +58,8 @@ def data_echoing( batch_stats, hps, global_step, - constant_base_rng): + constant_base_rng, +): """An example data echoing selector. Args: @@ -69,8 +71,8 @@ def data_echoing( global_step: the current global step. constant_base_rng: the RNG used for the experiment. IMPORTANT NOTE: this will be constant for all calls to this function, in order to get a unique - RNG each time we need to do - `rng = jax.random.fold_in(constant_base_rng, global_step)`. + RNG each time we need to do `rng = jax.random.fold_in(constant_base_rng, + global_step)`. Yields: A batch of data. diff --git a/init2winit/dataset_lib/data_utils.py b/init2winit/dataset_lib/data_utils.py index 5f0ec34d..1c644000 100644 --- a/init2winit/dataset_lib/data_utils.py +++ b/init2winit/dataset_lib/data_utils.py @@ -30,13 +30,15 @@ import jraph import numpy as np - -Dataset = collections.namedtuple('Dataset', [ - 'train_iterator_fn', - 'eval_train_epoch', - 'valid_epoch', - 'test_epoch', -]) +Dataset = collections.namedtuple( + 'Dataset', + [ + 'train_iterator_fn', + 'eval_train_epoch', + 'valid_epoch', + 'test_epoch', + ], +) def log_rss(msg: str): @@ -45,8 +47,9 @@ def log_rss(msg: str): logging.info('%s — RSS: %.1f MB', msg, rss_mb) -def prefetch_iterator(source_iter: Iterator[jax.typing.ArrayLike], - num_prefetch: int) -> Iterator[jax.typing.ArrayLike]: +def prefetch_iterator( + source_iter: Iterator[jax.typing.ArrayLike], num_prefetch: int +) -> Iterator[jax.typing.ArrayLike]: """Wraps the given iterator with prefetching. Args: @@ -121,14 +124,16 @@ def iterator_as_numpy(iterator): yield jax.tree.map(lambda y: y._numpy(), x) # pylint: disable=protected-access -def image_iterator(data, - rescale, - output_shape, - is_one_hot, - autoencoder, - shuffle_rng=None, - augment_fn=None, - include_example_keys=False): +def image_iterator( + data, + rescale, + output_shape, + is_one_hot, + autoencoder, + shuffle_rng=None, + augment_fn=None, + include_example_keys=False, +): """Preprocesses the batch data arrays in the data generator. Rescales inputs. One hot encode targets if is_one_hot is true. @@ -166,11 +171,13 @@ def image_iterator(data, yield {'inputs': inputs, 'targets': targets} -def maybe_pad_batch(batch, - desired_batch_size, - data_format=None, - mask_key=None, - padding_value=0.0): +def maybe_pad_batch( + batch, + desired_batch_size, + data_format=None, + mask_key=None, + padding_value=0.0, +): """Zero pad the batch on the right to desired_batch_size. All keys in the batch dictionary will have their corresponding arrays padded. @@ -187,9 +194,9 @@ def maybe_pad_batch(batch, dimension to pad. If not provided then it is assumed the first dimension is the batch dimension. mask_key: Typically used for text datasets, it's either 'inputs' (for - encoder only models like language models) or 'targets' - (for encoder-decoder models like seq2seq tasks) to decide weights for - padded sequence. For Image datasets, this will be (most likely) unused. + encoder only models like language models) or 'targets' (for + encoder-decoder models like seq2seq tasks) to decide weights for padded + sequence. For Image datasets, this will be (most likely) unused. padding_value: value to be used as padding. Returns: @@ -247,7 +254,7 @@ def make_global_array(local_data, mesh): """Util to combine per-host batches into a global batch array. Args: - local_data: local data batch on host. + local_data: local data batch on host. mesh: mesh specification to shard the data. Returns: diff --git a/init2winit/dataset_lib/datasets.py b/init2winit/dataset_lib/datasets.py index e9a22ec3..b051f38b 100644 --- a/init2winit/dataset_lib/datasets.py +++ b/init2winit/dataset_lib/datasets.py @@ -25,6 +25,7 @@ from init2winit.dataset_lib import imagenet_dataset from init2winit.dataset_lib import librispeech from init2winit.dataset_lib import lm1b_v2 + # We get TF v2 eager execution error if we import fineweb_edu_10b # and fineweb_edu_10b_mdlm before lm1b_v2 from init2winit.dataset_lib import fineweb_edu_10b # pylint: disable=g-bad-import-order @@ -34,6 +35,7 @@ from init2winit.dataset_lib import nanodo_fineweb_edu from init2winit.dataset_lib import nqm_noise from init2winit.dataset_lib import ogbg_molpcba + from init2winit.dataset_lib import proteins from init2winit.dataset_lib import small_image_datasets from init2winit.dataset_lib import translate_wmt diff --git a/init2winit/dataset_lib/fake_dataset.py b/init2winit/dataset_lib/fake_dataset.py index 332b6cb8..4fa18349 100644 --- a/init2winit/dataset_lib/fake_dataset.py +++ b/init2winit/dataset_lib/fake_dataset.py @@ -14,6 +14,7 @@ # limitations under the License. """Fake image input pipeline. Returns the same batch of ones over and over.""" + import copy from init2winit.dataset_lib import data_utils @@ -22,7 +23,6 @@ from ml_collections.config_dict import config_dict import numpy as np - TRAIN_IMAGES = 1281167 EVAL_IMAGES = 50000 @@ -31,11 +31,14 @@ IMAGE_SIZE = 224 -DEFAULT_HPARAMS = config_dict.ConfigDict(dict( - input_shape=(224, 224, 3), - output_shape=(NUM_CLASSES,), - train_size=TRAIN_IMAGES, - valid_size=EVAL_IMAGES)) +DEFAULT_HPARAMS = config_dict.ConfigDict( + dict( + input_shape=(224, 224, 3), + output_shape=(NUM_CLASSES,), + train_size=TRAIN_IMAGES, + valid_size=EVAL_IMAGES, + ) +) METADATA = { 'apply_one_hot_in_loss': False, @@ -49,8 +52,7 @@ def get_fake_batch(hps): num_classes = hps.output_shape[0] train_input_shape = (batch_size, *input_shape) images = jnp.ones(train_input_shape, dtype=jnp.float32) - labels = jax.nn.one_hot( - np.zeros((batch_size,)), num_classes, dtype=jnp.int32) + labels = jax.nn.one_hot(np.zeros((batch_size,)), num_classes, dtype=jnp.int32) batch = { 'inputs': images, 'targets': labels, @@ -89,6 +91,7 @@ def eval_train_epoch(*args, **kwargs): del kwargs return yield # This yield is needed to make this a valid (null) iterator. + # pylint: enable=unreachable # pylint: disable=unreachable @@ -97,7 +100,9 @@ def test_epoch(*args, **kwargs): del kwargs return yield # This yield is needed to make this a valid (null) iterator. + # pylint: enable=unreachable return data_utils.Dataset( - train_iterator_fn, eval_train_epoch, valid_epoch, test_epoch) + train_iterator_fn, eval_train_epoch, valid_epoch, test_epoch + ) diff --git a/init2winit/dataset_lib/fastmri_dataset.py b/init2winit/dataset_lib/fastmri_dataset.py index 3dab35ee..1e912800 100644 --- a/init2winit/dataset_lib/fastmri_dataset.py +++ b/init2winit/dataset_lib/fastmri_dataset.py @@ -31,25 +31,27 @@ gfile = tf.io.gfile listdir = tf.io.gfile.listdir -DEFAULT_HPARAMS = config_dict.ConfigDict(dict( - input_shape=(320, 320), - output_shape=(320, 320), - data_dir='', - train_dir='knee_singlecoil_train', - train_size=34742, - num_train_h5_files=973, - val_dir='knee_singlecoil_val', - valid_size=3554, - num_valid_h5_files=100, - # NOTE(dsuo): ground truth is not publicly available for the test set - # `knee_singlecoil_test_v2`, so we split the validation set into roughly - # two, 100 files for validation, 99 for test. This amounts to 3,554 slices - # for validation, 3,581 for test. - test_dir='knee_singlecoil_val', - test_size=3581, - num_test_h5_files=99, - eval_seed=0, -)) +DEFAULT_HPARAMS = config_dict.ConfigDict( + dict( + input_shape=(320, 320), + output_shape=(320, 320), + data_dir='', + train_dir='knee_singlecoil_train', + train_size=34742, + num_train_h5_files=973, + val_dir='knee_singlecoil_val', + valid_size=3554, + num_valid_h5_files=100, + # NOTE(dsuo): ground truth is not publicly available for the test + # set `knee_singlecoil_test_v2`, so we split the validation set into + # roughly two, 100 files for validation, 99 for test. This amounts + # to 3,554 slices for validation, 3,581 for test. + test_dir='knee_singlecoil_val', + test_size=3581, + num_test_h5_files=99, + eval_seed=0, + ) +) METADATA = { @@ -57,8 +59,9 @@ } -def _process_example(kspace, kspace_shape, target, target_shape, volume_max, - seed): +def _process_example( + kspace, kspace_shape, target, target_shape, volume_max, seed +): """Generate a single example (slice from mri image). Args: @@ -83,14 +86,17 @@ def _process_example(kspace, kspace_shape, target, target_shape, volume_max, acceleration = tf.convert_to_tensor(4.0, dtype=tf.float32) num_low_frequencies = tf.cast( - num_cols_float * center_fraction, dtype=tf.int32) + num_cols_float * center_fraction, dtype=tf.int32 + ) # calculate_center_mask mask = tf.zeros(num_cols, dtype=tf.float32) pad = (num_cols - num_low_frequencies + 1) // 2 mask = tf.tensor_scatter_nd_update( - mask, tf.reshape(tf.range(pad, pad + num_low_frequencies), (-1, 1)), - tf.ones(num_low_frequencies)) + mask, + tf.reshape(tf.range(pad, pad + num_low_frequencies), (-1, 1)), + tf.ones(num_low_frequencies), + ) # reshape_mask center_mask = tf.reshape(mask, (1, num_cols)) @@ -102,8 +108,8 @@ def _process_example(kspace, kspace_shape, target, target_shape, volume_max, ) mask = tf.cast( - tf.random.stateless_uniform((num_cols,), seed) < prob, - dtype=tf.float32) + tf.random.stateless_uniform((num_cols,), seed) < prob, dtype=tf.float32 + ) acceleration_mask = tf.reshape(mask, (1, num_cols)) mask = tf.math.maximum(center_mask, acceleration_mask) @@ -118,8 +124,10 @@ def _process_example(kspace, kspace_shape, target, target_shape, volume_max, image = tf.signal.fftshift(shifted_image, axes=(0, 1)) scaling_norm = tf.cast( tf.math.sqrt( - tf.cast(tf.math.reduce_prod(tf.shape(kspace)[-2:]), 'float32')), - kspace.dtype) + tf.cast(tf.math.reduce_prod(tf.shape(kspace)[-2:]), 'float32') + ), + kspace.dtype, + ) image = image * scaling_norm image = tf.stack((tf.math.real(image), tf.math.imag(image)), axis=-1) @@ -132,7 +140,7 @@ def _process_example(kspace, kspace_shape, target, target_shape, volume_max, image = image[..., w_from:w_to, h_from:h_to, :] # complex_abs - abs_image = tf.math.sqrt(tf.math.reduce_sum(image ** 2, axis=-1)) + abs_image = tf.math.sqrt(tf.math.reduce_sum(image**2, axis=-1)) # normalize_instance mean = tf.math.reduce_mean(abs_image) @@ -151,14 +159,17 @@ def _process_example(kspace, kspace_shape, target, target_shape, volume_max, 'targets': target, 'mean': mean, 'std': std, - 'volume_max': volume_max + 'volume_max': volume_max, } def _h5_to_examples(path): """Yield MRI slices from an hdf5 file containing a single MRI volume.""" - tf.print('fastmri_dataset._h5_to_examples call:', path, - datetime.datetime.now().strftime('%H:%M:%S:%f')) + tf.print( + 'fastmri_dataset._h5_to_examples call:', + path, + datetime.datetime.now().strftime('%H:%M:%S:%f'), + ) with gfile.GFile(path, 'rb') as gf: with h5py.File(gf, 'r') as hf: # NOTE(dsuo): logic taken from reference code @@ -166,7 +177,8 @@ def _h5_to_examples(path): for i in range(hf['kspace'].shape[0]): yield hf['kspace'][i], hf['kspace'][i].shape, hf['reconstruction_esc'][ - i], hf['reconstruction_esc'][i].shape, volume_max + i + ], hf['reconstruction_esc'][i].shape, volume_max def _create_generator(filename): @@ -178,7 +190,8 @@ def _create_generator(filename): tf.TensorSpec(shape=(), dtype=tf.float32), ) return tf.data.Dataset.from_generator( - _h5_to_examples, args=(filename,), output_signature=signature) + _h5_to_examples, args=(filename,), output_signature=signature + ) def load_split(per_host_batch_size, split, hps, shuffle_rng=None): @@ -192,6 +205,7 @@ def load_split(per_host_batch_size, split, hps, shuffle_rng=None): hps: The hparams the experiment is run with. Required fields are num_train_h5_files, num_test_h5_files, and num_valid_h5_files. shuffle_rng: The RNG used to shuffle the split. + Returns: A `tf.data.Dataset`. """ @@ -293,20 +307,23 @@ def load_split(per_host_batch_size, split, hps, shuffle_rng=None): _create_generator, cycle_length=32, block_length=64, - num_parallel_calls=hps.num_tf_data_map_parallel_calls) + num_parallel_calls=hps.num_tf_data_map_parallel_calls, + ) ds = ds.cache() def process_example(example_index, example): if split == 'train': process_rng = tf.random.experimental.stateless_fold_in( - shuffle_rng, example_index) + shuffle_rng, example_index + ) else: # NOTE(dsuo): we use fixed randomness for eval. process_rng = tf.cast(jax.random.PRNGKey(hps.eval_seed), tf.int64) return _process_example(*example, process_rng) ds = ds.enumerate().map( - process_example, num_parallel_calls=hps.num_tf_data_map_parallel_calls) + process_example, num_parallel_calls=hps.num_tf_data_map_parallel_calls + ) if split == 'train': ds = ds.shuffle( @@ -344,8 +361,9 @@ def get_fastmri(shuffle_rng, batch_size, eval_batch_size, hps): train_ds = tfds.as_numpy(train_ds) # NOTE(dsuo): fastMRI has fixed randomness for eval. - eval_train_ds = load_split(per_host_eval_batch_size, 'eval_train', hps, - shuffle_rng) + eval_train_ds = load_split( + per_host_eval_batch_size, 'eval_train', hps, shuffle_rng + ) eval_train_ds = tfds.as_numpy(eval_train_ds) eval_ds = load_split(per_host_eval_batch_size, 'val', hps, shuffle_rng) @@ -369,20 +387,20 @@ def test_epoch(num_batches=None): for batch in itertools.islice(test_ds, num_batches): yield data_utils.maybe_pad_batch(batch, per_host_eval_batch_size) - return data_utils.Dataset(train_iterator_fn, eval_train_epoch, valid_epoch, - test_epoch) + return data_utils.Dataset( + train_iterator_fn, eval_train_epoch, valid_epoch, test_epoch + ) def get_fake_batch(hps): return { - 'inputs': - np.ones((hps.batch_size, *hps.input_shape), dtype=hps.model_dtype), - 'targets': - np.ones((hps.batch_size, *hps.output_shape), dtype=hps.model_dtype), - 'mean': - np.ones((hps.batch_size,), dtype=hps.model_dtype), - 'std': - np.ones((hps.batch_size,), dtype=hps.model_dtype), - 'volume_max': - np.ones((hps.batch_size,), dtype=hps.model_dtype), + 'inputs': np.ones( + (hps.batch_size, *hps.input_shape), dtype=hps.model_dtype + ), + 'targets': np.ones( + (hps.batch_size, *hps.output_shape), dtype=hps.model_dtype + ), + 'mean': np.ones((hps.batch_size,), dtype=hps.model_dtype), + 'std': np.ones((hps.batch_size,), dtype=hps.model_dtype), + 'volume_max': np.ones((hps.batch_size,), dtype=hps.model_dtype), } diff --git a/init2winit/dataset_lib/image_preprocessing.py b/init2winit/dataset_lib/image_preprocessing.py index 7c674fd3..f36a6013 100644 --- a/init2winit/dataset_lib/image_preprocessing.py +++ b/init2winit/dataset_lib/image_preprocessing.py @@ -26,8 +26,9 @@ def _crop(key, image, hps): """Randomly shifts the window viewing the image.""" pixpad = (hps.crop_num_pixels, hps.crop_num_pixels) zero = (0, 0) - padded_image = jnp.pad(image, (pixpad, pixpad, zero), - mode='constant', constant_values=0.0) + padded_image = jnp.pad( + image, (pixpad, pixpad, zero), mode='constant', constant_values=0.0 + ) corner = random.randint(key, (2,), 0, 2 * hps.crop_num_pixels) corner = jnp.concatenate((corner, jnp.zeros((1,), jnp.int32))) cropped_image = lax.dynamic_slice(padded_image, corner, image.shape) @@ -108,8 +109,10 @@ def mixup(key, alpha, images, labels): mixed_labels = weight * labels + (1.0 - weight) * labels[::-1] weight = jnp.reshape(weight, (1,) * 0 + (batch_size,) + (1,) * (3)) - reverse = tuple(slice(images.shape[i]) if d != 'N' else slice(-1, None, -1) - for i, d in enumerate(image_format)) + reverse = tuple( + slice(images.shape[i]) if d != 'N' else slice(-1, None, -1) + for i, d in enumerate(image_format) + ) mixed_images = weight * images + (1.0 - weight) * images[reverse] @@ -123,12 +126,11 @@ def augment_cifar10(key, images, labels, hps): key: Rng key. images: A batch of images with shape [batch, height, width, channels]. labels: A batch of labels with shape [batch, ...] - hps: HParams object. hps.alpha parameterizes - the beta distribution sampling the mixup probabilities. - hps.crop_num_pixels determines the max amount of pixels for which the - viewing window will be shifted. hps.flip_probability determines the - probability of applying a random flip. hps.use_mixup determines whether - or not mixup is applied. + hps: HParams object. hps.alpha parameterizes the beta distribution sampling + the mixup probabilities. hps.crop_num_pixels determines the max amount of + pixels for which the viewing window will be shifted. hps.flip_probability + determines the probability of applying a random flip. hps.use_mixup + determines whether or not mixup is applied. Returns: A tuple containing the augmented images and labels. @@ -141,9 +143,15 @@ def augment_cifar10(key, images, labels, hps): # Random flip flip_mask = random.uniform(flip_rng, (1,) * (-1) + (batch_size,) + (1,) * (3)) images = jnp.where( - flip_mask < hps.flip_probability, images[tuple( - slice(images.shape[i]) if d != 'W' else slice(-1, None, -1) - for i, d in enumerate('NHWC'))], images) + flip_mask < hps.flip_probability, + images[ + tuple( + slice(images.shape[i]) if d != 'W' else slice(-1, None, -1) + for i, d in enumerate('NHWC') + ) + ], + images, + ) images = crop(crop_rng, images, hps) diff --git a/init2winit/dataset_lib/imagenet_dataset.py b/init2winit/dataset_lib/imagenet_dataset.py index 49c4852b..96b6c5e0 100644 --- a/init2winit/dataset_lib/imagenet_dataset.py +++ b/init2winit/dataset_lib/imagenet_dataset.py @@ -257,7 +257,8 @@ def load_split( dtype=tf.float32, image_size=224, shuffle_rng=None, - tfds_dataset_name='imagenet2012:5.*.*'): # pyformat: disable + tfds_dataset_name='imagenet2012:5.*.*', +): """Creates a split from the ImageNet dataset using TensorFlow Datasets. The dataset returned by this function will repeat forever if split == 'train', @@ -373,7 +374,7 @@ def mixup_batch(batch_index, batch): per_batch_mixup_rng = tf.random.experimental.stateless_fold_in( mixup_rng, batch_index ) - (inputs, targets) = image_preprocessing.mixup_tf( + inputs, targets = image_preprocessing.mixup_tf( per_batch_mixup_rng, batch['inputs'], batch['targets'], diff --git a/init2winit/dataset_lib/imagenet_preprocessing.py b/init2winit/dataset_lib/imagenet_preprocessing.py index c5bfdd17..84e98f80 100644 --- a/init2winit/dataset_lib/imagenet_preprocessing.py +++ b/init2winit/dataset_lib/imagenet_preprocessing.py @@ -26,13 +26,15 @@ NUM_CLASSES = 1000 -def distorted_bounding_box_crop(image_bytes, - bbox, - rng_seed, - min_object_covered=0.1, - aspect_ratio_range=(0.75, 1.33), - area_range=(0.05, 1.0), - max_attempts=10): +def distorted_bounding_box_crop( + image_bytes, + bbox, + rng_seed, + min_object_covered=0.1, + aspect_ratio_range=(0.75, 1.33), + area_range=(0.05, 1.0), + max_attempts=10, +): """Generates cropped_image using one of the bboxes randomly distorted. See `tf.image.sample_distorted_bounding_box` for more documentation. @@ -58,15 +60,18 @@ def distorted_bounding_box_crop(image_bytes, cropped image `Tensor` """ shape = tf.image.extract_jpeg_shape(image_bytes) - sample_distorted_bounding_box = tf.image.stateless_sample_distorted_bounding_box( - shape, - seed=rng_seed, - bounding_boxes=bbox, - min_object_covered=min_object_covered, - aspect_ratio_range=aspect_ratio_range, - area_range=area_range, - max_attempts=max_attempts, - use_image_if_no_bounding_boxes=True) + sample_distorted_bounding_box = ( + tf.image.stateless_sample_distorted_bounding_box( + shape, + seed=rng_seed, + bounding_boxes=bbox, + min_object_covered=min_object_covered, + aspect_ratio_range=aspect_ratio_range, + area_range=area_range, + max_attempts=max_attempts, + use_image_if_no_bounding_boxes=True, + ) + ) bbox_begin, bbox_size, _ = sample_distorted_bounding_box # Crop the image to the specified bounding box. @@ -79,8 +84,9 @@ def distorted_bounding_box_crop(image_bytes, def _resize(image, image_size): - return tf.image.resize([image], [image_size, image_size], - method=tf.image.ResizeMethod.BICUBIC)[0] + return tf.image.resize( + [image], [image_size, image_size], method=tf.image.ResizeMethod.BICUBIC + )[0] def _at_least_x_are_equal(a, b, x): @@ -90,8 +96,13 @@ def _at_least_x_are_equal(a, b, x): return tf.greater_equal(tf.reduce_sum(match), x) -def _decode_and_random_crop(image_bytes, image_size, rng_seed, area_range, - use_center_crop_if_random_failed): +def _decode_and_random_crop( + image_bytes, + image_size, + rng_seed, + area_range, + use_center_crop_if_random_failed, +): """Make a random crop of image_size.""" bbox = tf.constant([0.0, 0.0, 1.0, 1.0], dtype=tf.float32, shape=[1, 1, 4]) image = distorted_bounding_box_crop( @@ -99,15 +110,18 @@ def _decode_and_random_crop(image_bytes, image_size, rng_seed, area_range, bbox, rng_seed=rng_seed, min_object_covered=0.1, - aspect_ratio_range=(3. / 4, 4. / 3.), + aspect_ratio_range=(3.0 / 4, 4.0 / 3.0), area_range=area_range, - max_attempts=10) + max_attempts=10, + ) original_shape = tf.image.extract_jpeg_shape(image_bytes) bad = _at_least_x_are_equal(original_shape, tf.shape(image), 3) - image = tf.cond(bad and use_center_crop_if_random_failed, - lambda: _decode_and_center_crop(image_bytes, image_size), - lambda: _resize(image, image_size)) + image = tf.cond( + bad and use_center_crop_if_random_failed, + lambda: _decode_and_center_crop(image_bytes, image_size), + lambda: _resize(image, image_size), + ) return image @@ -143,11 +157,9 @@ def _resize_for_inception(image, size, method='bilinear'): return tf.cast(image, dtype) -def _decode_and_inception_crop(image_data, - size, - area_min=5, - area_max=100, - method='bilinear'): +def _decode_and_inception_crop( + image_data, size, area_min=5, area_max=100, method='bilinear' +): """Decode jpeg and add inception crop. Args: @@ -166,7 +178,8 @@ def _decode_and_inception_crop(image_data, tf.zeros([0, 0, 4], tf.float32), area_range=(area_min / 100, area_max / 100), min_object_covered=0, # Don't enforce a minimum area. - use_image_if_no_bounding_boxes=True) + use_image_if_no_bounding_boxes=True, + ) # Crop the image to the specified bounding box. offset_y, offset_x, _ = tf.unstack(begin) @@ -185,14 +198,20 @@ def _decode_and_center_crop(image_bytes, image_size): image_width = shape[1] padded_center_crop_size = tf.cast( - ((image_size / (image_size + CROP_PADDING)) * - tf.cast(tf.minimum(image_height, image_width), tf.float32)), tf.int32) + ( + (image_size / (image_size + CROP_PADDING)) + * tf.cast(tf.minimum(image_height, image_width), tf.float32) + ), + tf.int32, + ) offset_height = ((image_height - padded_center_crop_size) + 1) // 2 offset_width = ((image_width - padded_center_crop_size) + 1) // 2 crop_window = tf.stack([ - offset_height, offset_width, padded_center_crop_size, - padded_center_crop_size + offset_height, + offset_width, + padded_center_crop_size, + padded_center_crop_size, ]) image = tf.image.decode_and_crop_jpeg(image_bytes, crop_window, channels=3) image = _resize(image, image_size) @@ -206,17 +225,19 @@ def normalize_image(image): return image -def preprocess_for_train(image_bytes, - rng_seed, - dtype=tf.float32, - image_size=224, - crop='random', - random_crop_area_range=(0.08, 1.0), - use_center_crop_if_random_failed=False, - random_flip=True, - use_randaug=False, - randaug_magnitude=0, - randaug_num_layers=0): +def preprocess_for_train( + image_bytes, + rng_seed, + dtype=tf.float32, + image_size=224, + crop='random', + random_crop_area_range=(0.08, 1.0), + use_center_crop_if_random_failed=False, + random_flip=True, + use_randaug=False, + randaug_magnitude=0, + randaug_num_layers=0, +): """Preprocesses the given image for training. Args: @@ -245,7 +266,8 @@ def preprocess_for_train(image_bytes, image_size, crop_rng, random_crop_area_range, - use_center_crop_if_random_failed=use_center_crop_if_random_failed) + use_center_crop_if_random_failed=use_center_crop_if_random_failed, + ) elif crop == 'inception': image = _decode_and_inception_crop(image_bytes, image_size) elif crop == 'center': @@ -261,10 +283,9 @@ def preprocess_for_train(image_bytes, # NOTE(dsuo): autoaugment code expects uint8 image; not sure why we use # float32[0, 255], but just making sure pipeline runs. image = tf.cast(tf.clip_by_value(image, 0, 255), tf.uint8) - image = autoaugment.distort_image_with_randaugment(image, - randaug_num_layers, - randaug_magnitude, - randaug_rng) + image = autoaugment.distort_image_with_randaugment( + image, randaug_num_layers, randaug_magnitude, randaug_rng + ) image = tf.cast(image, tf.float32) image = normalize_image(image) diff --git a/init2winit/dataset_lib/librispeech.py b/init2winit/dataset_lib/librispeech.py index d79f9e3b..096cfd00 100644 --- a/init2winit/dataset_lib/librispeech.py +++ b/init2winit/dataset_lib/librispeech.py @@ -39,7 +39,9 @@ output_shape=(-1, VOCAB_SIZE), train_size=281241, tokenizer_vocab_path='', - tokenizer_type='SPM')) + tokenizer_type='SPM', + ) +) METADATA = {'apply_one_hot_in_loss': False} @@ -53,37 +55,53 @@ def get_librispeech(shuffle_rng, batch_size, eval_batch_size=None, hps=None): """Wrapper to conform to the general dataset API.""" process_count = jax.process_count() if batch_size % process_count != 0: - raise ValueError('process_count={} must divide batch_size={}.'.format( - process_count, batch_size)) + raise ValueError( + 'process_count={} must divide batch_size={}.'.format( + process_count, batch_size + ) + ) per_host_batch_size = batch_size // process_count if eval_batch_size is None: eval_batch_size = batch_size if eval_batch_size % process_count != 0: - raise ValueError('process_count={} must divide eval_batch_size={}.'.format( - process_count, eval_batch_size)) + raise ValueError( + 'process_count={} must divide eval_batch_size={}.'.format( + process_count, eval_batch_size + ) + ) per_host_eval_batch_size = eval_batch_size // process_count - return _get_librispeech(hps, per_host_batch_size, per_host_eval_batch_size, - shuffle_rng) + return _get_librispeech( + hps, per_host_batch_size, per_host_eval_batch_size, shuffle_rng + ) -def _get_librispeech(hps, per_host_batch_size, per_host_eval_batch_size, - shuffle_rng): +def _get_librispeech( + hps, per_host_batch_size, per_host_eval_batch_size, shuffle_rng +): """Data generators for lm1b.""" n_devices = jax.local_device_count() if per_host_batch_size % n_devices != 0: - raise ValueError('n_devices={} must divide per_host_batch_size={}.'.format( - n_devices, per_host_batch_size)) + raise ValueError( + 'n_devices={} must divide per_host_batch_size={}.'.format( + n_devices, per_host_batch_size + ) + ) if per_host_eval_batch_size % n_devices != 0: raise ValueError( 'n_devices={} must divide per_host_eval_batch_size={}.'.format( - n_devices, per_host_eval_batch_size)) + n_devices, per_host_eval_batch_size + ) + ) - train_ds, eval_ds, test_ds = librispeech_input_pipeline.get_librispeech_datasets( - hps, per_host_batch_size, per_host_eval_batch_size, shuffle_rng) + train_ds, eval_ds, test_ds = ( + librispeech_input_pipeline.get_librispeech_datasets( + hps, per_host_batch_size, per_host_eval_batch_size, shuffle_rng + ) + ) def train_iterator_fn(): for batch in iter(train_ds): @@ -99,14 +117,16 @@ def valid_epoch(num_batches=None): for batch in itertools.islice(valid_iter, num_batches): batch = _batch_to_dict(batch) yield data_utils.maybe_pad_batch( - batch, desired_batch_size=per_host_eval_batch_size, padding_value=1.0) + batch, desired_batch_size=per_host_eval_batch_size, padding_value=1.0 + ) def test_epoch(num_batches=None): test_iter = iter(test_ds) for batch in itertools.islice(test_iter, num_batches): batch = _batch_to_dict(batch) yield data_utils.maybe_pad_batch( - batch, desired_batch_size=per_host_eval_batch_size, padding_value=1.0) + batch, desired_batch_size=per_host_eval_batch_size, padding_value=1.0 + ) # pylint: enable=unreachable return Dataset(train_iterator_fn, eval_train_epoch, valid_epoch, test_epoch) @@ -114,16 +134,16 @@ def test_epoch(num_batches=None): def get_fake_batch(hps): return { - 'inputs': - np.ones((hps.batch_size, hps.max_input_length), - dtype=hps.model_dtype), - 'input_paddings': - np.ones((hps.batch_size, hps.max_input_length), - dtype=hps.model_dtype), - 'targets': - np.ones((hps.batch_size, hps.max_target_length), - dtype=hps.model_dtype), - 'target_paddings': - np.ones((hps.batch_size, hps.max_target_length), - dtype=hps.model_dtype), + 'inputs': np.ones( + (hps.batch_size, hps.max_input_length), dtype=hps.model_dtype + ), + 'input_paddings': np.ones( + (hps.batch_size, hps.max_input_length), dtype=hps.model_dtype + ), + 'targets': np.ones( + (hps.batch_size, hps.max_target_length), dtype=hps.model_dtype + ), + 'target_paddings': np.ones( + (hps.batch_size, hps.max_target_length), dtype=hps.model_dtype + ), } diff --git a/init2winit/dataset_lib/librispeech_input_pipeline.py b/init2winit/dataset_lib/librispeech_input_pipeline.py index d9d9cdca..d477bfe6 100644 --- a/init2winit/dataset_lib/librispeech_input_pipeline.py +++ b/init2winit/dataset_lib/librispeech_input_pipeline.py @@ -24,7 +24,6 @@ import tensorflow as tf import tensorflow_datasets as tfds - # pylint: disable=g-import-not-at-top try: from init2winit.dataset_lib import spm_tokenizer @@ -39,7 +38,7 @@ ALLOWED_TOKENIZERS = ['SPM', 'RAW'] -class CharacterTokenizer(): +class CharacterTokenizer: """Tokenizer that uses raw character level vocab.""" def __init__(self): @@ -49,7 +48,7 @@ def __init__(self): '': 2, '_': 3, ' ': 4, - '\'': 5, + "'": 5, 'A': 6, 'B': 7, 'C': 8, @@ -132,12 +131,12 @@ def _preprocess_output(features, tokenizer, hps): Args: features: input tf data features - tokenizer: tokenizer to be used to tokenize the input - feature's output transcription. + tokenizer: tokenizer to be used to tokenize the input feature's output + transcription. hps: hyperparameters for the dataset pipeline set upstream, this is used to - extract flags controlling which tokenizer is used. - Uses sentence piece tokenizer if hps.use_spm_tokenizer = True. - Uses simple character level tokenizer if hps.use_character_tokenizer=True. + extract flags controlling which tokenizer is used. Uses sentence piece + tokenizer if hps.use_spm_tokenizer = True. Uses simple character level + tokenizer if hps.use_character_tokenizer=True. Returns: outputs tf data features with tokenized transcripts. @@ -145,19 +144,23 @@ def _preprocess_output(features, tokenizer, hps): if hps.tokenizer_type == 'SPM': features['targets'] = tokenizer.tokenize(features['targets']) features['target_paddings'] = tf.zeros_like( - features['targets'], dtype=tf.float32) + features['targets'], dtype=tf.float32 + ) elif hps.tokenizer_type == 'RAW': features['targets'] = tf.py_function( - func=tokenizer.tokenize, inp=[features['targets']], Tout=tf.int32) + func=tokenizer.tokenize, inp=[features['targets']], Tout=tf.int32 + ) features['target_paddings'] = tf.zeros_like( - features['targets'], dtype=tf.float32) + features['targets'], dtype=tf.float32 + ) return features def _make_input_paddings(features): features['input_paddings'] = tf.zeros_like( - features['inputs'], dtype=tf.float32) + features['inputs'], dtype=tf.float32 + ) return features @@ -173,8 +176,9 @@ def _normalize_feature_names(features): return features -def get_raw_dataset(dataset_builder: tfds.core.DatasetBuilder, - split: str, shuffle_seed=None) -> tf.data.Dataset: +def get_raw_dataset( + dataset_builder: tfds.core.DatasetBuilder, split: str, shuffle_seed=None +) -> tf.data.Dataset: """Loads the raw dataset and normalizes feature keys. Args: @@ -192,7 +196,8 @@ def get_raw_dataset(dataset_builder: tfds.core.DatasetBuilder, ds = dataset_builder.as_dataset( split=per_host_split, shuffle_files=(shuffle_seed is not None), - read_config=tfds.ReadConfig(shuffle_seed=shuffle_seed)) + read_config=tfds.ReadConfig(shuffle_seed=shuffle_seed), + ) ds = ds.map(_normalize_feature_names, num_parallel_calls=AUTOTUNE) return ds @@ -217,7 +222,8 @@ def preprocess_data( raise ValueError( 'Passed in tokenizer_type value does not correspond to currently ' 'supported tokenizers, make sure one of SPM or RAW is set' - ' as tokenizer_type flag.') + ' as tokenizer_type flag.' + ) if hps.tokenizer_type == 'SPM': tokenizer = spm_tokenizer.load_tokenizer(hps.tokenizer_vocab_path) @@ -226,7 +232,8 @@ def preprocess_data( dataset = dataset.map( functools.partial(_preprocess_output, tokenizer=tokenizer, hps=hps), - num_parallel_calls=10) + num_parallel_calls=10, + ) dataset = dataset.map(_make_input_paddings, num_parallel_calls=10) @@ -234,7 +241,8 @@ def preprocess_data( # note that audio filtering is post frequency domain conversion. if max_input_length > 0 and max_target_length > 0: inputs_length_filter, targets_length_filter = length_filter( - max_input_length, max_target_length) + max_input_length, max_target_length + ) dataset = dataset.filter(inputs_length_filter) dataset = dataset.filter(targets_length_filter) @@ -255,8 +263,9 @@ def preprocess_data( 'inputs': 0, 'targets': 0, 'input_paddings': 1.0, - 'target_paddings': 1.0 - }) + 'target_paddings': 1.0, + }, + ) if prefetch_size: dataset = dataset.prefetch(prefetch_size) @@ -285,7 +294,8 @@ def get_librispeech_datasets( train=True, batch_size=per_host_batch_size, hps=hps, - shuffle_seed=utils.convert_jax_to_tf_random_seed(data_shuffle_seed)) + shuffle_seed=utils.convert_jax_to_tf_random_seed(data_shuffle_seed), + ) eval_ds = preprocess_data( eval_data, @@ -300,6 +310,7 @@ def get_librispeech_datasets( train=False, batch_size=per_host_eval_batch_size, hps=hps, - drop_remainder=False) + drop_remainder=False, + ) return train_ds, eval_ds, test_ds diff --git a/init2winit/dataset_lib/lm1b_input_pipeline_v2.py b/init2winit/dataset_lib/lm1b_input_pipeline_v2.py index 8494ab8f..11740cbb 100644 --- a/init2winit/dataset_lib/lm1b_input_pipeline_v2.py +++ b/init2winit/dataset_lib/lm1b_input_pipeline_v2.py @@ -48,8 +48,9 @@ def __call__(self, features: Features) -> Features: return features -def get_raw_dataset(dataset_builder: tfds.core.DatasetBuilder, - split: str) -> tf.data.Dataset: +def get_raw_dataset( + dataset_builder: tfds.core.DatasetBuilder, split: str +) -> tf.data.Dataset: """Loads a raw text dataset and normalizes feature keys. Args: @@ -62,17 +63,20 @@ def get_raw_dataset(dataset_builder: tfds.core.DatasetBuilder, 'targets'. """ per_host_split = deterministic_data.get_read_instruction_for_host( - split, dataset_info=dataset_builder.info, drop_remainder=False) + split, dataset_info=dataset_builder.info, drop_remainder=False + ) ds = dataset_builder.as_dataset(split=per_host_split, shuffle_files=False) ds = ds.map( - NormalizeFeatureNamesOp(dataset_builder.info), - num_parallel_calls=AUTOTUNE) + NormalizeFeatureNamesOp(dataset_builder.info), num_parallel_calls=AUTOTUNE + ) return ds -def pack_dataset(dataset: tf.data.Dataset, - key2length: Union[int, Dict[str, int]], - keys: Optional[List[str]] = None) -> tf.data.Dataset: +def pack_dataset( + dataset: tf.data.Dataset, + key2length: Union[int, Dict[str, int]], + keys: Optional[List[str]] = None, +) -> tf.data.Dataset: """Creates a 'packed' version of a dataset on-the-fly. Adapted from the mesh-tf implementation. @@ -117,8 +121,10 @@ def pack_dataset(dataset: tf.data.Dataset, keys = list(shapes.keys()) for k in keys: if k not in shapes: - raise ValueError('Key %s not found in dataset. Available keys are %s' % - (k, shapes.keys())) + raise ValueError( + 'Key %s not found in dataset. Available keys are %s' + % (k, shapes.keys()) + ) if not shapes[k].is_compatible_with(tf.TensorShape([None])): raise ValueError('Tensors to be packed must be one-dimensional.') # make sure that the length dictionary contains all keys as well as the @@ -131,13 +137,15 @@ def pack_dataset(dataset: tf.data.Dataset, # trim to length dataset = dataset.map( - lambda x: {k: x[k][:key2length[k]] for k in keys}, - num_parallel_calls=AUTOTUNE) + lambda x: {k: x[k][: key2length[k]] for k in keys}, + num_parallel_calls=AUTOTUNE, + ) # Setting batch_size=length ensures that the concatenated sequences (if they # have length >=1) are sufficient to fill at least one packed example. batch_size = max(key2length.values()) dataset = dataset.padded_batch( - batch_size, padded_shapes={k: [-1] for k in keys}) + batch_size, padded_shapes={k: [-1] for k in keys} + ) dataset = _pack_with_tf_ops(dataset, keys, key2length) # Set the Tensor shapes correctly since they get lost in the process. @@ -147,8 +155,9 @@ def my_fn(x): return dataset.map(my_fn, num_parallel_calls=AUTOTUNE) -def _pack_with_tf_ops(dataset: tf.data.Dataset, keys: List[str], - key2length: Dict[str, int]) -> tf.data.Dataset: +def _pack_with_tf_ops( + dataset: tf.data.Dataset, keys: List[str], key2length: Dict[str, int] +) -> tf.data.Dataset: """Helper-function for packing a dataset which has already been batched. Helper for pack_dataset() Uses tf.while_loop. @@ -173,7 +182,8 @@ def write_packed_example(partial, outputs): for k in keys_etc: new_outputs[k] = outputs[k].write( outputs[k].size(), - tf.pad(partial[k], [[0, key2length[k] - tf.size(partial[k])]])) + tf.pad(partial[k], [[0, key2length[k] - tf.size(partial[k])]]), + ) return new_partial, new_outputs def map_fn(x): @@ -193,9 +203,11 @@ def map_fn(x): outputs = {} for k in keys: outputs[k] = tf.TensorArray( - tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]]) + tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]] + ) outputs[k + '_position'] = tf.TensorArray( - tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]]) + tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]] + ) def body_fn(i, partial, outputs): """Body function for while_loop. @@ -212,13 +224,15 @@ def body_fn(i, partial, outputs): one_example = {} for k in keys: val = tf.cast(x[k][i], tf.int32) - val = val[:tf.reduce_sum(tf.cast(tf.not_equal(val, 0), tf.int32))] + val = val[: tf.reduce_sum(tf.cast(tf.not_equal(val, 0), tf.int32))] one_example[k] = val for k in keys: can_append = tf.logical_and( can_append, tf.less_equal( - tf.size(partial[k]) + tf.size(one_example[k]), key2length[k])) + tf.size(partial[k]) + tf.size(one_example[k]), key2length[k] + ), + ) def false_fn(): return write_packed_example(partial, outputs) @@ -229,12 +243,12 @@ def true_fn(): partial, outputs = tf.cond(can_append, true_fn, false_fn) new_partial = {} for k in keys: - new_seq = one_example[k][:key2length[k]] + new_seq = one_example[k][: key2length[k]] new_seq_len = tf.size(new_seq) new_partial[k] = tf.concat([partial[k], new_seq], 0) new_partial[k + '_position'] = tf.concat( - [partial[k + '_position'], - tf.range(new_seq_len)], 0) + [partial[k + '_position'], tf.range(new_seq_len)], 0 + ) partial = new_partial return i + 1, partial, outputs @@ -248,14 +262,14 @@ def true_fn(): {k: tf.TensorShape([None]) for k in keys_etc}, {k: tf.TensorShape(None) for k in keys_etc}, ), - maximum_iterations=dynamic_batch_size) + maximum_iterations=dynamic_batch_size, + ) _, outputs = write_packed_example(partial, outputs) packed = {k: outputs[k].stack() for k in keys_etc} for k in keys: - packed[k + '_segmentation'] = ( - tf.cumsum( - tf.cast(tf.equal(packed[k + '_position'], 0), tf.int32), axis=1) * - tf.cast(tf.not_equal(packed[k], 0), tf.int32)) + packed[k + '_segmentation'] = tf.cumsum( + tf.cast(tf.equal(packed[k + '_position'], 0), tf.int32), axis=1 + ) * tf.cast(tf.not_equal(packed[k], 0), tf.int32) return packed dataset = dataset.map(map_fn, num_parallel_calls=AUTOTUNE) @@ -265,16 +279,18 @@ def true_fn(): # ----------------------------------------------------------------------------- # Main dataset prep routines. # ----------------------------------------------------------------------------- -def preprocess_data(dataset, - train=True, - num_epochs=1, - pack_examples=False, - shuffle_buffer_size=1024, - max_length=512, - batch_size=256, - drop_remainder=True, - prefetch_size=AUTOTUNE, - shuffle_seed=None): +def preprocess_data( + dataset, + train=True, + num_epochs=1, + pack_examples=False, + shuffle_buffer_size=1024, + max_length=512, + batch_size=256, + drop_remainder=True, + prefetch_size=AUTOTUNE, + shuffle_seed=None, +): """Shuffle and batch/pack the given dataset.""" def length_filter(max_len): @@ -301,15 +317,10 @@ def filter_fn(x): else: # simple (static-shape) padded batching dataset = dataset.padded_batch( batch_size, - padded_shapes={ - 'inputs': max_length, - 'targets': max_length - }, - padding_values={ - 'inputs': 0, - 'targets': 0 - }, - drop_remainder=drop_remainder) + padded_shapes={'inputs': max_length, 'targets': max_length}, + padding_values={'inputs': 0, 'targets': 0}, + drop_remainder=drop_remainder, + ) repeated_dataset = dataset.repeat(num_epochs) @@ -319,8 +330,9 @@ def filter_fn(x): return repeated_dataset, dataset -def get_lm1b_datasets(hps, per_host_batch_size, per_host_eval_batch_size, - shuffle_rng): +def get_lm1b_datasets( + hps, per_host_batch_size, per_host_eval_batch_size, shuffle_rng +): """Load and return dataset of batched examples for use during training.""" if hps.vocab_path is None: vocab_path = os.path.expanduser('~/lm2b_sentencepiece_model') @@ -340,11 +352,14 @@ def get_lm1b_datasets(hps, per_host_batch_size, per_host_eval_batch_size, train_data, vocab_path=vocab_path, vocab_size=hps.vocab_size, - max_corpus_chars=hps.max_corpus_chars) + max_corpus_chars=hps.max_corpus_chars, + ) train_data = train_data.map( - spm_tokenizer.TokenizeOp(sp_tokenizer), num_parallel_calls=AUTOTUNE) + spm_tokenizer.TokenizeOp(sp_tokenizer), num_parallel_calls=AUTOTUNE + ) eval_data = eval_data.map( - spm_tokenizer.TokenizeOp(sp_tokenizer), num_parallel_calls=AUTOTUNE) + spm_tokenizer.TokenizeOp(sp_tokenizer), num_parallel_calls=AUTOTUNE + ) train_ds, eval_train_ds = preprocess_data( train_data, @@ -353,7 +368,8 @@ def get_lm1b_datasets(hps, per_host_batch_size, per_host_eval_batch_size, pack_examples=hps.pack_examples, batch_size=per_host_batch_size, max_length=hps.max_target_length, - shuffle_seed=data_utils.convert_jax_to_tf_random_seed(shuffle_rng)) + shuffle_seed=data_utils.convert_jax_to_tf_random_seed(shuffle_rng), + ) eval_ds, _ = preprocess_data( eval_data, @@ -361,6 +377,7 @@ def get_lm1b_datasets(hps, per_host_batch_size, per_host_eval_batch_size, pack_examples=hps.pack_examples, batch_size=per_host_eval_batch_size, max_length=hps.max_eval_target_length, - drop_remainder=False) + drop_remainder=False, + ) return train_ds, eval_train_ds, eval_ds diff --git a/init2winit/dataset_lib/lm1b_v2.py b/init2winit/dataset_lib/lm1b_v2.py index b2c718c9..32212118 100644 --- a/init2winit/dataset_lib/lm1b_v2.py +++ b/init2winit/dataset_lib/lm1b_v2.py @@ -37,7 +37,9 @@ vocab_path=None, train_size=30301028, pack_examples=False, - max_corpus_chars=10**7)) + max_corpus_chars=10**7, + ) +) METADATA = { 'apply_one_hot_in_loss': True, @@ -56,25 +58,30 @@ def get_lm1b(shuffle_rng, batch_size, eval_batch_size=None, hps=None): """Wrapper to conform to the general dataset API.""" process_count = jax.process_count() if batch_size % process_count != 0: - raise ValueError('process_count={} must divide batch_size={}.'.format( - process_count, batch_size)) + raise ValueError( + 'process_count={} must divide batch_size={}.'.format( + process_count, batch_size + ) + ) per_host_batch_size = batch_size // process_count if eval_batch_size is None: eval_batch_size = batch_size if eval_batch_size % process_count != 0: - raise ValueError('process_count={} must divide eval_batch_size={}.'.format( - process_count, eval_batch_size)) + raise ValueError( + 'process_count={} must divide eval_batch_size={}.'.format( + process_count, eval_batch_size + ) + ) per_host_eval_batch_size = eval_batch_size // process_count - return _get_lm1b(hps, per_host_batch_size, per_host_eval_batch_size, - shuffle_rng) + return _get_lm1b( + hps, per_host_batch_size, per_host_eval_batch_size, shuffle_rng + ) -def maybe_pad_batch(batch, - desired_batch_size, - mask_key='targets'): +def maybe_pad_batch(batch, desired_batch_size, mask_key='targets'): """Zero pad the batch on the right to desired_batch_size. All keys in the batch dictionary will have their corresponding arrays padded. @@ -88,9 +95,9 @@ def maybe_pad_batch(batch, desired_batch_size: All arrays in the dict will be padded to have first dimension equal to desired_batch_size. mask_key: Typically used for text datasets, it's either 'inputs' (for - encoder only models like language models) or 'targets' - (for encoder-decoder models like seq2seq tasks) to decide weights for - padded sequence. For Image datasets, this will be (most likely) unused. + encoder only models like language models) or 'targets' (for + encoder-decoder models like seq2seq tasks) to decide weights for padded + sequence. For Image datasets, this will be (most likely) unused. Returns: A dictionary mapping the same keys to the padded batches. Additionally we @@ -103,8 +110,9 @@ def maybe_pad_batch(batch, raise ValueError(f'Incorrect mask key {mask_key}.') if 'weights' in batch: - batch['weights'] = np.multiply(batch['weights'], - np.where(batch[mask_key] > 0, 1, 0)) + batch['weights'] = np.multiply( + batch['weights'], np.where(batch[mask_key] > 0, 1, 0) + ) else: batch['weights'] = np.where(batch[mask_key] > 0, 1, 0) @@ -130,16 +138,22 @@ def _get_lm1b(hps, per_host_batch_size, per_host_eval_batch_size, shuffle_rng): """Data generators for lm1b.""" n_devices = jax.local_device_count() if per_host_batch_size % n_devices != 0: - raise ValueError('n_devices={} must divide per_host_batch_size={}.'.format( - n_devices, per_host_batch_size)) + raise ValueError( + 'n_devices={} must divide per_host_batch_size={}.'.format( + n_devices, per_host_batch_size + ) + ) if per_host_eval_batch_size % n_devices != 0: raise ValueError( 'n_devices={} must divide per_host_eval_batch_size={}.'.format( - n_devices, per_host_eval_batch_size)) + n_devices, per_host_eval_batch_size + ) + ) train_ds, eval_train_ds, eval_ds = lm1b_input_pipeline_v2.get_lm1b_datasets( - hps, per_host_batch_size, per_host_eval_batch_size, shuffle_rng) + hps, per_host_batch_size, per_host_eval_batch_size, shuffle_rng + ) def train_iterator_fn(): for batch in iter(train_ds): diff --git a/init2winit/dataset_lib/mlperf_imagenet_dataset.py b/init2winit/dataset_lib/mlperf_imagenet_dataset.py index 59b85276..ec1751f0 100644 --- a/init2winit/dataset_lib/mlperf_imagenet_dataset.py +++ b/init2winit/dataset_lib/mlperf_imagenet_dataset.py @@ -25,24 +25,23 @@ import numpy as np import tensorflow.compat.v2 as tf - -DEFAULT_HPARAMS = config_dict.ConfigDict(dict( - input_shape=(224, 224, 3), - output_shape=(1000,), - train_size=1281167, - valid_size=50000, - test_size=10000, # ImageNet-v2. - use_imagenetv2_test=True)) +DEFAULT_HPARAMS = config_dict.ConfigDict( + dict( + input_shape=(224, 224, 3), + output_shape=(1000,), + train_size=1281167, + valid_size=50000, + test_size=10000, # ImageNet-v2. + use_imagenetv2_test=True, + ) +) METADATA = { 'apply_one_hot_in_loss': False, } -def get_mlperf_imagenet(rng, - batch_size, - eval_batch_size, - hps=None): +def get_mlperf_imagenet(rng, batch_size, eval_batch_size, hps=None): """Data generators for imagenet. Args: @@ -58,13 +57,15 @@ def get_mlperf_imagenet(rng, if batch_size % jax.device_count() != 0: raise ValueError( 'Require batch_size % jax.device_count(), received ' - 'batch_size={}, device_count={}.'.format( - batch_size, jax.device_count())) + 'batch_size={}, device_count={}.'.format(batch_size, jax.device_count()) + ) if eval_batch_size % jax.device_count() != 0: raise ValueError( 'Require eval_batch_size % jax.device_count(), received ' 'eval_batch_size={}, device_count={}.'.format( - eval_batch_size, jax.device_count())) + eval_batch_size, jax.device_count() + ) + ) host_batch_size = batch_size // jax.process_count() eval_host_batch_size = eval_batch_size // jax.process_count() @@ -78,21 +79,24 @@ def get_mlperf_imagenet(rng, dtype=input_dtype, split='train', rng=rng, - shuffle_size=shuffle_buffer_size) + shuffle_size=shuffle_buffer_size, + ) eval_train_ds = mlperf_input_pipeline.load_split( host_batch_size, dtype=input_dtype, split='eval_train', rng=rng, - shuffle_size=shuffle_buffer_size) + shuffle_size=shuffle_buffer_size, + ) eval_ds = mlperf_input_pipeline.load_split( eval_host_batch_size, dtype=input_dtype, split='validation', rng=rng, - shuffle_size=shuffle_buffer_size) + shuffle_size=shuffle_buffer_size, + ) # We do not have TFRecords of ImageNet-v2 in the same format as the # train/validation splits above, so we reuse the same test split from the @@ -104,7 +108,8 @@ def get_mlperf_imagenet(rng, 'test', hps=hps, image_size=224, - tfds_dataset_name='imagenet_v2/matched-frequency') + tfds_dataset_name='imagenet_v2/matched-frequency', + ) # We cannot use tfds.as_numpy because this calls tensor.numpy() which does an # additional copy of the tensor, instead we call tensor._numpy() below. @@ -116,7 +121,8 @@ def eval_train_epoch(num_batches=None): num_batches = 0 eval_train_iter = iter(eval_train_ds) np_iter = data_utils.iterator_as_numpy( - itertools.islice(eval_train_iter, num_batches)) + itertools.islice(eval_train_iter, num_batches) + ) for batch in np_iter: yield data_utils.maybe_pad_batch(batch, eval_host_batch_size) @@ -125,7 +131,8 @@ def valid_epoch(num_batches=None): num_batches = max_eval_steps valid_iter = iter(eval_ds) np_iter = data_utils.iterator_as_numpy( - itertools.islice(valid_iter, num_batches)) + itertools.islice(valid_iter, num_batches) + ) for batch in np_iter: yield data_utils.maybe_pad_batch(batch, eval_host_batch_size) @@ -133,7 +140,8 @@ def test_epoch(num_batches=None): if test_ds: test_iter = iter(test_ds) np_iter = data_utils.iterator_as_numpy( - itertools.islice(test_iter, num_batches)) + itertools.islice(test_iter, num_batches) + ) for batch in np_iter: yield data_utils.maybe_pad_batch(batch, eval_host_batch_size) else: @@ -143,18 +151,17 @@ def test_epoch(num_batches=None): # pylint: enable=unreachable return data_utils.Dataset( - train_iterator_fn, - eval_train_epoch, - valid_epoch, - test_epoch) + train_iterator_fn, eval_train_epoch, valid_epoch, test_epoch + ) def get_fake_batch(hps): return { - 'inputs': - np.ones((hps.batch_size, *hps.input_shape), dtype=hps.model_dtype), - 'targets': - np.ones((hps.batch_size, *hps.output_shape), dtype=hps.model_dtype), - 'weights': - np.ones((hps.batch_size,), dtype=hps.model_dtype), + 'inputs': np.ones( + (hps.batch_size, *hps.input_shape), dtype=hps.model_dtype + ), + 'targets': np.ones( + (hps.batch_size, *hps.output_shape), dtype=hps.model_dtype + ), + 'weights': np.ones((hps.batch_size,), dtype=hps.model_dtype), } diff --git a/init2winit/dataset_lib/mlperf_input_pipeline.py b/init2winit/dataset_lib/mlperf_input_pipeline.py index ca4c3c13..0a3e2e5b 100644 --- a/init2winit/dataset_lib/mlperf_input_pipeline.py +++ b/init2winit/dataset_lib/mlperf_input_pipeline.py @@ -19,7 +19,6 @@ from init2winit.dataset_lib import data_utils from init2winit.dataset_lib import imagenet_preprocessing - import jax import tensorflow as tf @@ -29,13 +28,15 @@ STDDEV_RGB = [0.229 * 255, 0.224 * 255, 0.225 * 255] -def load_split(batch_size, - split, - dtype, - rng, - data_dir=None, - image_size=224, - shuffle_size=16384): +def load_split( + batch_size, + split, + dtype, + rng, + data_dir=None, + image_size=224, + shuffle_size=16384, +): """Returns the input_fn. Args: @@ -49,6 +50,7 @@ def load_split(batch_size, image_size: the size to resize the images to using `tf.image.resize(..., method='bicubic')`. shuffle_size: the size of the shuffle buffer used in `dataset.shuffler()`. + Returns: a tf.data.Dataset that is batched and preprocessed, and optionally shuffled and repeated, for ImageNet based off the MLPerf codebase. Note that for evaluation, the final partial batches are not yet padded to be the same @@ -76,19 +78,23 @@ def dataset_parser(*value): if len(value) > 1: [example_index, value] = value per_example_rng = tf.random.experimental.stateless_fold_in( - tf.cast(preprocess_rng, tf.int64), example_index) + tf.cast(preprocess_rng, tf.int64), example_index + ) elif split == 'train': raise ValueError( 'Must enumerate() over tf.data.Dataset when training in order to get ' - 'a per-example index to fold into a per-example seed.') + 'a per-example index to fold into a per-example seed.' + ) else: value = value[0] per_example_rng = None parsed = tf.io.parse_single_example( - value, { + value, + { 'image/encoded': tf.io.FixedLenFeature((), tf.string, ''), - 'image/class/label': tf.io.FixedLenFeature([], tf.int64, 0) - }) + 'image/class/label': tf.io.FixedLenFeature([], tf.int64, 0), + }, + ) image_bytes = tf.reshape(parsed['image/encoded'], []) label = tf.cast(tf.reshape(parsed['image/class/label'], []), tf.int32) - 1 @@ -101,10 +107,12 @@ def dataset_parser(*value): crop='random', random_crop_area_range=(0.05, 1.0), use_center_crop_if_random_failed=False, - random_flip=True) + random_flip=True, + ) else: image = imagenet_preprocessing.preprocess_for_eval( - image_bytes, dtype, image_size) + image_bytes, dtype, image_size + ) return { 'inputs': image, @@ -115,18 +123,22 @@ def dataset_parser(*value): num_hosts = jax.process_count() use_training_files = split in ['train', 'eval_train'] file_pattern = os.path.join( - data_dir, 'train-*' if use_training_files else 'validation-*') + data_dir, 'train-*' if use_training_files else 'validation-*' + ) dataset = tf.data.Dataset.list_files(file_pattern, shuffle=False) dataset = dataset.shard(num_hosts, index) concurrent_files = min(10, 1024 // num_hosts) - dataset = dataset.interleave(tf.data.TFRecordDataset, concurrent_files, 1, - concurrent_files) + dataset = dataset.interleave( + tf.data.TFRecordDataset, concurrent_files, 1, concurrent_files + ) if split == 'train': dataset = dataset.cache() # cache compressed JPEGs instead dataset = dataset.shuffle( - shuffle_size, reshuffle_each_iteration=True, - seed=data_utils.convert_jax_to_tf_random_seed(shuffle_rng)).repeat() + shuffle_size, + reshuffle_each_iteration=True, + seed=data_utils.convert_jax_to_tf_random_seed(shuffle_rng), + ).repeat() dataset = dataset.enumerate().map(dataset_parser, 64) else: dataset = dataset.map(dataset_parser, 64) diff --git a/init2winit/dataset_lib/mt_pipeline.py b/init2winit/dataset_lib/mt_pipeline.py index bfc37c2d..e061d85f 100644 --- a/init2winit/dataset_lib/mt_pipeline.py +++ b/init2winit/dataset_lib/mt_pipeline.py @@ -26,13 +26,13 @@ import tensorflow as tf import tensorflow_datasets as tfds - AUTOTUNE = tf.data.AUTOTUNE Features = Dict[str, tf.Tensor] -def get_user_defined_symbols(ds_info: tfds.core.DatasetInfo, - reverse_translation: bool): +def get_user_defined_symbols( + ds_info: tfds.core.DatasetInfo, reverse_translation: bool +): """Get language id token.""" if reverse_translation: token, _ = ds_info.supervised_keys @@ -58,8 +58,7 @@ def __call__(self, features: Features) -> Features: class TaskTokenOp: """Adds '2xx' task token to 'inputs'.""" - def __init__(self, ds_info: tfds.core.DatasetInfo, - reverse_translation: bool): + def __init__(self, ds_info: tfds.core.DatasetInfo, reverse_translation: bool): self.token_lang = get_user_defined_symbols(ds_info, reverse_translation) def __call__(self, features): @@ -79,9 +78,7 @@ def __call__(self, features): return features -def maybe_pad_batch(batch, - desired_batch_size, - mask_key='targets'): +def maybe_pad_batch(batch, desired_batch_size, mask_key='targets'): """Zero pad the batch on the right to desired_batch_size. All keys in the batch dictionary will have their corresponding arrays padded. @@ -95,9 +92,9 @@ def maybe_pad_batch(batch, desired_batch_size: All arrays in the dict will be padded to have first dimension equal to desired_batch_size. mask_key: Typically used for text datasets, it's either 'inputs' (for - encoder only models like language models) or 'targets' - (for encoder-decoder models like seq2seq tasks) to decide weights for - padded sequence. For Image datasets, this will be (most likely) unused. + encoder only models like language models) or 'targets' (for + encoder-decoder models like seq2seq tasks) to decide weights for padded + sequence. For Image datasets, this will be (most likely) unused. Returns: A dictionary mapping the same keys to the padded batches. Additionally we @@ -110,8 +107,9 @@ def maybe_pad_batch(batch, raise ValueError(f'Incorrect mask key {mask_key}.') if 'weights' in batch: - batch['weights'] = np.multiply(batch['weights'], - np.where(batch[mask_key] > 0, 1, 0)) + batch['weights'] = np.multiply( + batch['weights'], np.where(batch[mask_key] > 0, 1, 0) + ) else: batch['weights'] = np.where(batch[mask_key] > 0, 1, 0) @@ -133,11 +131,13 @@ def zero_pad(ar, pad_axis): return padded_batch -def get_raw_dataset(dataset_builder: tfds.core.DatasetBuilder, - split: str, - *, - reverse_translation: bool = False, - add_language_token: bool = False) -> tf.data.Dataset: +def get_raw_dataset( + dataset_builder: tfds.core.DatasetBuilder, + split: str, + *, + reverse_translation: bool = False, + add_language_token: bool = False, +) -> tf.data.Dataset: """Loads a raw WMT dataset and normalizes feature keys. Args: @@ -154,23 +154,30 @@ def get_raw_dataset(dataset_builder: tfds.core.DatasetBuilder, """ num_examples = dataset_builder.info.splits[split].num_examples per_host_split = deterministic_data.get_read_instruction_for_host( - split, num_examples, drop_remainder=False) + split, num_examples, drop_remainder=False + ) ds = dataset_builder.as_dataset(split=per_host_split, shuffle_files=False) ds = ds.map( NormalizeFeatureNamesOp( - dataset_builder.info, reverse_translation=reverse_translation), - num_parallel_calls=AUTOTUNE) + dataset_builder.info, reverse_translation=reverse_translation + ), + num_parallel_calls=AUTOTUNE, + ) if add_language_token: ds = ds.map( - TaskTokenOp(dataset_builder.info, - reverse_translation=reverse_translation), - num_parallel_calls=AUTOTUNE) + TaskTokenOp( + dataset_builder.info, reverse_translation=reverse_translation + ), + num_parallel_calls=AUTOTUNE, + ) return ds -def pack_dataset(dataset: tf.data.Dataset, - key2length: Union[int, Dict[str, int]], - keys: Optional[List[str]] = None) -> tf.data.Dataset: +def pack_dataset( + dataset: tf.data.Dataset, + key2length: Union[int, Dict[str, int]], + keys: Optional[List[str]] = None, +) -> tf.data.Dataset: """Creates a 'packed' version of a dataset on-the-fly. Adapted from the mesh-tf implementation. @@ -215,8 +222,10 @@ def pack_dataset(dataset: tf.data.Dataset, keys = list(shapes.keys()) for k in keys: if k not in shapes: - raise ValueError('Key %s not found in dataset. Available keys are %s' % - (k, shapes.keys())) + raise ValueError( + 'Key %s not found in dataset. Available keys are %s' + % (k, shapes.keys()) + ) if not shapes[k].is_compatible_with(tf.TensorShape([None])): raise ValueError('Tensors to be packed must be one-dimensional.') # make sure that the length dictionary contains all keys as well as the @@ -229,13 +238,15 @@ def pack_dataset(dataset: tf.data.Dataset, # trim to length dataset = dataset.map( - lambda x: {k: x[k][:key2length[k]] for k in keys}, - num_parallel_calls=AUTOTUNE) + lambda x: {k: x[k][: key2length[k]] for k in keys}, + num_parallel_calls=AUTOTUNE, + ) # Setting batch_size=length ensures that the concatenated sequences (if they # have length >=1) are sufficient to fill at least one packed example. batch_size = max(key2length.values()) dataset = dataset.padded_batch( - batch_size, padded_shapes={k: [-1] for k in keys}) + batch_size, padded_shapes={k: [-1] for k in keys} + ) dataset = _pack_with_tf_ops(dataset, keys, key2length) # Set the Tensor shapes correctly since they get lost in the process. @@ -245,8 +256,9 @@ def my_fn(x): return dataset.map(my_fn, num_parallel_calls=AUTOTUNE) -def _pack_with_tf_ops(dataset: tf.data.Dataset, keys: List[str], - key2length: Dict[str, int]) -> tf.data.Dataset: +def _pack_with_tf_ops( + dataset: tf.data.Dataset, keys: List[str], key2length: Dict[str, int] +) -> tf.data.Dataset: """Helper-function for packing a dataset which has already been batched. Helper for pack_dataset() Uses tf.while_loop. @@ -271,7 +283,8 @@ def write_packed_example(partial, outputs): for k in keys_etc: new_outputs[k] = outputs[k].write( outputs[k].size(), - tf.pad(partial[k], [[0, key2length[k] - tf.size(partial[k])]])) + tf.pad(partial[k], [[0, key2length[k] - tf.size(partial[k])]]), + ) return new_partial, new_outputs def map_fn(x): @@ -291,9 +304,11 @@ def map_fn(x): outputs = {} for k in keys: outputs[k] = tf.TensorArray( - tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]]) + tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]] + ) outputs[k + '_position'] = tf.TensorArray( - tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]]) + tf.int32, size=0, dynamic_size=True, element_shape=[key2length[k]] + ) def body_fn(i, partial, outputs): """Body function for while_loop. @@ -310,13 +325,15 @@ def body_fn(i, partial, outputs): one_example = {} for k in keys: val = tf.cast(x[k][i], tf.int32) - val = val[:tf.reduce_sum(tf.cast(tf.not_equal(val, 0), tf.int32))] + val = val[: tf.reduce_sum(tf.cast(tf.not_equal(val, 0), tf.int32))] one_example[k] = val for k in keys: can_append = tf.logical_and( can_append, tf.less_equal( - tf.size(partial[k]) + tf.size(one_example[k]), key2length[k])) + tf.size(partial[k]) + tf.size(one_example[k]), key2length[k] + ), + ) def false_fn(): return write_packed_example(partial, outputs) @@ -327,12 +344,12 @@ def true_fn(): partial, outputs = tf.cond(can_append, true_fn, false_fn) new_partial = {} for k in keys: - new_seq = one_example[k][:key2length[k]] + new_seq = one_example[k][: key2length[k]] new_seq_len = tf.size(new_seq) new_partial[k] = tf.concat([partial[k], new_seq], 0) new_partial[k + '_position'] = tf.concat( - [partial[k + '_position'], - tf.range(new_seq_len)], 0) + [partial[k + '_position'], tf.range(new_seq_len)], 0 + ) partial = new_partial return i + 1, partial, outputs @@ -346,14 +363,14 @@ def true_fn(): {k: tf.TensorShape([None]) for k in keys_etc}, {k: tf.TensorShape(None) for k in keys_etc}, ), - maximum_iterations=dynamic_batch_size) + maximum_iterations=dynamic_batch_size, + ) _, outputs = write_packed_example(partial, outputs) packed = {k: outputs[k].stack() for k in keys_etc} for k in keys: - packed[k + '_segmentation'] = ( - tf.cumsum( - tf.cast(tf.equal(packed[k + '_position'], 0), tf.int32), axis=1) * - tf.cast(tf.not_equal(packed[k], 0), tf.int32)) + packed[k + '_segmentation'] = tf.cumsum( + tf.cast(tf.equal(packed[k + '_position'], 0), tf.int32), axis=1 + ) * tf.cast(tf.not_equal(packed[k], 0), tf.int32) return packed dataset = dataset.map(map_fn, num_parallel_calls=AUTOTUNE) @@ -363,50 +380,64 @@ def true_fn(): # ----------------------------------------------------------------------------- # Main dataset prep routines. # ----------------------------------------------------------------------------- -def get_sampled_dataset(ds_builders, - split: str, - rates: List[int], - reverse_translation: bool, - add_language_token: bool, - loss_weights: List[float], - is_training: bool, - sample_seed: int, - shuffle_seed: int, - shuffle_buffer_size: int = 1024): +def get_sampled_dataset( + ds_builders, + split: str, + rates: List[int], + reverse_translation: bool, + add_language_token: bool, + loss_weights: List[float], + is_training: bool, + sample_seed: int, + shuffle_seed: int, + shuffle_buffer_size: int = 1024, +): """Create a sampled training dataset.""" raw_data = [] for builder in ds_builders: - raw_data.append(get_raw_dataset( - builder, split, - reverse_translation=reverse_translation, - add_language_token=add_language_token)) + raw_data.append( + get_raw_dataset( + builder, + split, + reverse_translation=reverse_translation, + add_language_token=add_language_token, + ) + ) if loss_weights is not None: - raw_data = [ds.map(ImportanceSamplingOp(weight), - num_parallel_calls=AUTOTUNE) for ds, weight - in zip(raw_data, loss_weights)] + raw_data = [ + ds.map(ImportanceSamplingOp(weight), num_parallel_calls=AUTOTUNE) + for ds, weight in zip(raw_data, loss_weights) + ] - def _shuffle_repeat(dataset, shuffle_seed: int, - shuffle_buffer_size): + def _shuffle_repeat(dataset, shuffle_seed: int, shuffle_buffer_size): dataset = dataset.shuffle(shuffle_buffer_size, seed=shuffle_seed) dataset = dataset.repeat() return dataset if is_training: - raw_data = [_shuffle_repeat(data, shuffle_seed=shuffle_seed, - shuffle_buffer_size=shuffle_buffer_size) - for data in raw_data] + raw_data = [ + _shuffle_repeat( + data, + shuffle_seed=shuffle_seed, + shuffle_buffer_size=shuffle_buffer_size, + ) + for data in raw_data + ] sampled_raw_data = tf.data.experimental.sample_from_datasets( - raw_data, rates, seed=sample_seed) + raw_data, rates, seed=sample_seed + ) return sampled_raw_data -def preprocess_wmt_data(dataset, - pack_examples: bool = True, - max_length: int = 100, - batch_size: int = 256, - prefetch_size: int = AUTOTUNE): +def preprocess_wmt_data( + dataset, + pack_examples: bool = True, + max_length: int = 100, + batch_size: int = 256, + prefetch_size: int = AUTOTUNE, +): """Shuffle and batch/pack the given dataset.""" def length_filter(max_len): @@ -431,26 +462,18 @@ def filter_fn(x): padded_shapes={ 'inputs': max_length, 'targets': max_length, - 'weights': (1,) - }, - padding_values={ - 'inputs': 0, - 'targets': 0, - 'weights': 0.0 + 'weights': (1,), }, - drop_remainder=False) + padding_values={'inputs': 0, 'targets': 0, 'weights': 0.0}, + drop_remainder=False, + ) else: dataset = dataset.padded_batch( batch_size, - padded_shapes={ - 'inputs': max_length, - 'targets': max_length - }, - padding_values={ - 'inputs': 0, - 'targets': 0 - }, - drop_remainder=False) + padded_shapes={'inputs': max_length, 'targets': max_length}, + padding_values={'inputs': 0, 'targets': 0}, + drop_remainder=False, + ) if prefetch_size: dataset = dataset.prefetch(prefetch_size) @@ -458,21 +481,27 @@ def filter_fn(x): return dataset -def get_wmt_datasets(config: config_dict.ConfigDict, - *, - shuffle_seed: int, - sample_seed: int, - n_devices: int, - per_host_batch_size: int, - per_host_eval_batch_size: int, - vocab_path: Optional[str] = None): +def get_wmt_datasets( + config: config_dict.ConfigDict, + *, + shuffle_seed: int, + sample_seed: int, + n_devices: int, + per_host_batch_size: int, + per_host_eval_batch_size: int, + vocab_path: Optional[str] = None, +): """Load and return dataset of batched examples for use during training.""" if per_host_batch_size % n_devices: - raise ValueError("Batch size %d isn't divided evenly by n_devices %d" % - (per_host_batch_size, n_devices)) + raise ValueError( + "Batch size %d isn't divided evenly by n_devices %d" + % (per_host_batch_size, n_devices) + ) if per_host_eval_batch_size % n_devices: - raise ValueError("Eval Batch size %d isn't divided evenly by n_devices %d" % - (per_host_eval_batch_size, n_devices)) + raise ValueError( + "Eval Batch size %d isn't divided evenly by n_devices %d" + % (per_host_eval_batch_size, n_devices) + ) if vocab_path is None: vocab_path = os.path.expanduser('~/wmt_sentencepiece_model') @@ -482,8 +511,9 @@ def get_wmt_datasets(config: config_dict.ConfigDict, else: dataset_keys = config.tfds_dataset_keys dataset_rates = config.rates - train_ds_builders = [tfds.builder(tfds_dataset_key) for tfds_dataset_key in - dataset_keys] + train_ds_builders = [ + tfds.builder(tfds_dataset_key) for tfds_dataset_key in dataset_keys + ] eval_ds_builder = tfds.builder(config.tfds_eval_dataset_key) if config.tfds_predict_dataset_key: @@ -493,33 +523,42 @@ def get_wmt_datasets(config: config_dict.ConfigDict, # TODO(dxin): give each task its own reverse translation bool. sampled_train_data = get_sampled_dataset( - train_ds_builders, config.train_split, dataset_rates, - config.reverse_translation, config.add_language_token, + train_ds_builders, + config.train_split, + dataset_rates, + config.reverse_translation, + config.add_language_token, loss_weights=config.loss_weights, - is_training=True, sample_seed=sample_seed, - shuffle_seed=shuffle_seed) + is_training=True, + sample_seed=sample_seed, + shuffle_seed=shuffle_seed, + ) eval_data = get_raw_dataset( eval_ds_builder, config.eval_split, reverse_translation=config.reverse_translation, - add_language_token=config.add_language_token) + add_language_token=config.add_language_token, + ) predict_data = get_raw_dataset( predict_ds_builder, config.predict_split, reverse_translation=config.reverse_translation, - add_language_token=config.add_language_token) + add_language_token=config.add_language_token, + ) # Tokenize data. user_defined_symbols = [] if config.add_language_token: user_defined_symbols = [ get_user_defined_symbols(ds_builders.info, True) - for ds_builders in train_ds_builders] - user_defined_symbols.extend( - [get_user_defined_symbols(ds_builders.info, False) - for ds_builders in train_ds_builders]) + for ds_builders in train_ds_builders + ] + user_defined_symbols.extend([ + get_user_defined_symbols(ds_builders.info, False) + for ds_builders in train_ds_builders + ]) user_defined_symbols = list(set(user_defined_symbols)) # TODO(dxin): Check in vocab file eventually. @@ -531,30 +570,37 @@ def get_wmt_datasets(config: config_dict.ConfigDict, character_coverage=config.character_coverage, byte_fallback=config.byte_fallback, split_digits=config.split_digits, - user_defined_symbols=user_defined_symbols) + user_defined_symbols=user_defined_symbols, + ) sampled_train_data = sampled_train_data.map( - tokenizer.TokenizeOp(sp_tokenizer), num_parallel_calls=AUTOTUNE) + tokenizer.TokenizeOp(sp_tokenizer), num_parallel_calls=AUTOTUNE + ) eval_data = eval_data.map( - tokenizer.TokenizeOp(sp_tokenizer), num_parallel_calls=AUTOTUNE) + tokenizer.TokenizeOp(sp_tokenizer), num_parallel_calls=AUTOTUNE + ) predict_data = predict_data.map( - tokenizer.TokenizeOp(sp_tokenizer), num_parallel_calls=AUTOTUNE) + tokenizer.TokenizeOp(sp_tokenizer), num_parallel_calls=AUTOTUNE + ) sampled_train_ds = preprocess_wmt_data( sampled_train_data, pack_examples=config.pack_examples, batch_size=per_host_batch_size, - max_length=config.max_target_length) + max_length=config.max_target_length, + ) eval_ds = preprocess_wmt_data( eval_data, pack_examples=False, batch_size=per_host_eval_batch_size, - max_length=config.max_eval_target_length) + max_length=config.max_eval_target_length, + ) predict_ds = preprocess_wmt_data( predict_data, pack_examples=False, batch_size=per_host_eval_batch_size, - max_length=config.max_predict_length) + max_length=config.max_predict_length, + ) return sampled_train_ds, eval_ds, predict_ds diff --git a/init2winit/dataset_lib/mt_pipeline_test.py b/init2winit/dataset_lib/mt_pipeline_test.py index 0eb42fbf..c3c05e87 100644 --- a/init2winit/dataset_lib/mt_pipeline_test.py +++ b/init2winit/dataset_lib/mt_pipeline_test.py @@ -38,8 +38,9 @@ class MTPipelineTest(absltest.TestCase): - def _get_datasets(self, shuffle_seed=None, sample_seed=None, - pack_examples=False): + def _get_datasets( + self, shuffle_seed=None, sample_seed=None, pack_examples=False + ): config = translate_wmt.DEFAULT_HPARAMS config.vocab_size = 32 config.max_corpus_chars = 1000 @@ -48,7 +49,8 @@ def _get_datasets(self, shuffle_seed=None, sample_seed=None, config.max_predict_length = _PREDICT_TARGET_LENGTH config.pack_examples = pack_examples config.tfds_dataset_keys = [ - 'wmt15_translate/de-en', 'wmt15_translate/ru-en' + 'wmt15_translate/de-en', + 'wmt15_translate/ru-en', ] config.tfds_eval_dataset_key = 'wmt14_translate/de-en' @@ -67,7 +69,8 @@ def _get_datasets(self, shuffle_seed=None, sample_seed=None, n_devices=2, per_host_batch_size=4, per_host_eval_batch_size=4, - vocab_path=vocab_path) + vocab_path=vocab_path, + ) return train_ds, eval_ds, predict_ds def test_train_ds_golden(self): @@ -75,10 +78,13 @@ def test_train_ds_golden(self): train_ds, _, _ = self._get_datasets(pack_examples=False) expected_shape = [4, _TARGET_LENGTH] # 4 batch_size. for batch in train_ds.take(3): - self.assertEqual({k: v.shape.as_list() for k, v in batch.items()}, { - 'inputs': expected_shape, - 'targets': expected_shape, - }) + self.assertEqual( + {k: v.shape.as_list() for k, v in batch.items()}, + { + 'inputs': expected_shape, + 'targets': expected_shape, + }, + ) def test_train_ds_packed(self): # packed train dataset @@ -87,62 +93,80 @@ def test_train_ds_packed(self): # For training we pack multiple short examples in one example. # *_position and *_segmentation indicate the boundaries. for batch in train_ds.take(3): - self.assertEqual({k: v.shape.as_list() for k, v in batch.items()}, { - 'inputs': expected_shape, - 'inputs_position': expected_shape, - 'inputs_segmentation': expected_shape, - 'targets': expected_shape, - 'targets_position': expected_shape, - 'targets_segmentation': expected_shape, - }) + self.assertEqual( + {k: v.shape.as_list() for k, v in batch.items()}, + { + 'inputs': expected_shape, + 'inputs_position': expected_shape, + 'inputs_segmentation': expected_shape, + 'targets': expected_shape, + 'targets_position': expected_shape, + 'targets_segmentation': expected_shape, + }, + ) def test_train_ds_determinism(self): # unpacked train dataset - train_ds, _, _ = self._get_datasets(shuffle_seed=_SHUFFLE_SEED, - sample_seed=_SAMPLE_SEED) - train_ds_copy, _, _ = self._get_datasets(shuffle_seed=_SHUFFLE_SEED, - sample_seed=_SAMPLE_SEED) + train_ds, _, _ = self._get_datasets( + shuffle_seed=_SHUFFLE_SEED, sample_seed=_SAMPLE_SEED + ) + train_ds_copy, _, _ = self._get_datasets( + shuffle_seed=_SHUFFLE_SEED, sample_seed=_SAMPLE_SEED + ) batch_idx_to_test = 1 train_ds_batch = next( - itertools.islice(train_ds, batch_idx_to_test, batch_idx_to_test + 1)) + itertools.islice(train_ds, batch_idx_to_test, batch_idx_to_test + 1) + ) train_ds_copy_batch = next( - itertools.islice(train_ds_copy, batch_idx_to_test, - batch_idx_to_test + 1)) + itertools.islice( + train_ds_copy, batch_idx_to_test, batch_idx_to_test + 1 + ) + ) self.assertTrue( - jnp.array_equal(train_ds_batch['inputs'], - train_ds_copy_batch['inputs'])) + jnp.array_equal(train_ds_batch['inputs'], train_ds_copy_batch['inputs']) + ) self.assertTrue( - jnp.array_equal(train_ds_batch['targets'], - train_ds_copy_batch['targets'])) + jnp.array_equal( + train_ds_batch['targets'], train_ds_copy_batch['targets'] + ) + ) def test_eval_ds(self): _, eval_ds, _ = self._get_datasets() expected_shape = [4, _EVAL_TARGET_LENGTH] # 4 batch_size. for batch in eval_ds.take(3): - self.assertEqual({k: v.shape.as_list() for k, v in batch.items()}, { - 'inputs': expected_shape, - 'targets': expected_shape, - }) + self.assertEqual( + {k: v.shape.as_list() for k, v in batch.items()}, + { + 'inputs': expected_shape, + 'targets': expected_shape, + }, + ) def test_predict_ds(self): _, _, predict_ds = self._get_datasets() expected_shape = [4, _PREDICT_TARGET_LENGTH] # 4 batch_size. for batch in predict_ds.take(3): - self.assertEqual({k: v.shape.as_list() for k, v in batch.items()}, { - 'inputs': expected_shape, - 'targets': expected_shape, - }) + self.assertEqual( + {k: v.shape.as_list() for k, v in batch.items()}, + { + 'inputs': expected_shape, + 'targets': expected_shape, + }, + ) class MTSamplePipelineTest(absltest.TestCase): - def _get_sampled_datasets(self, rates, num_examples: int, shuffle_seed=None, - sample_seed=None): + def _get_sampled_datasets( + self, rates, num_examples: int, shuffle_seed=None, sample_seed=None + ): # Test sampling ratio. config = translate_wmt.DEFAULT_HPARAMS config.rates = rates config.tfds_dataset_keys = [ - 'wmt15_translate/de-en', 'wmt15_translate/ru-en' + 'wmt15_translate/de-en', + 'wmt15_translate/ru-en', ] config.tfds_eval_dataset_key = 'wmt14_translate/de-en' @@ -150,22 +174,30 @@ def _get_sampled_datasets(self, rates, num_examples: int, shuffle_seed=None, if config.tfds_dataset_key: train_ds_builders = [tfds.builder(config.tfds_dataset_key)] else: - train_ds_builders = [tfds.builder(tfds_dataset_key) for - tfds_dataset_key in config.tfds_dataset_keys] + train_ds_builders = [ + tfds.builder(tfds_dataset_key) + for tfds_dataset_key in config.tfds_dataset_keys + ] sampled_train_data = mt_pipeline.get_sampled_dataset( - train_ds_builders, config.train_split, config.rates, + train_ds_builders, + config.train_split, + config.rates, reverse_translation=True, add_language_token=True, loss_weights=config.loss_weights, is_training=True, sample_seed=sample_seed, - shuffle_seed=shuffle_seed) + shuffle_seed=shuffle_seed, + ) # Get the language task ids. - lang_ids = [mt_pipeline.get_user_defined_symbols(train_ds_builder.info, - reverse_translation=True) - for train_ds_builder in train_ds_builders] + lang_ids = [ + mt_pipeline.get_user_defined_symbols( + train_ds_builder.info, reverse_translation=True + ) + for train_ds_builder in train_ds_builders + ] rates = {lang_id: rate for lang_id, rate in zip(lang_ids, config.rates)} return sampled_train_data, rates, lang_ids @@ -173,32 +205,38 @@ def _get_sampled_datasets(self, rates, num_examples: int, shuffle_seed=None, def test_sample_batch_ds_ratio(self): # Test sampling ratios. sampled_train_data, rates, lang_ids = self._get_sampled_datasets( - [0.5, 0.5], 20000, _SHUFFLE_SEED, _SAMPLE_SEED) + [0.5, 0.5], 20000, _SHUFFLE_SEED, _SAMPLE_SEED + ) counts = {lang_id: 0 for lang_id in lang_ids} # Establish the counts. batch_size = 10000 for batch in sampled_train_data.batch(batch_size).take(1): for input_data in batch['inputs']: - counts[input_data.numpy().decode('utf-8')[0:len(lang_ids[0])]] += 1 + counts[input_data.numpy().decode('utf-8')[0 : len(lang_ids[0])]] += 1 # Check if counts are in line w/ the sampling ratio. - self.assertTrue(np.allclose( - [rates[lang_id] for lang_id in lang_ids], - [counts[lang_id]/batch_size for lang_id in lang_ids], atol=1e-2)) + self.assertTrue( + np.allclose( + [rates[lang_id] for lang_id in lang_ids], + [counts[lang_id] / batch_size for lang_id in lang_ids], + atol=1e-2, + ) + ) def test_sample_batch_ds_repeat(self): # Test that the sampled dataset can repeat datasets if the data in one # stream is exhausted before the other. sampled_train_data, _, lang_ids = self._get_sampled_datasets( - [0.9, 0.1], 300, _SHUFFLE_SEED, _SAMPLE_SEED) + [0.9, 0.1], 300, _SHUFFLE_SEED, _SAMPLE_SEED + ) # Take enough batches to exhaust the dataset with the larger sampling rate. batch_size = 100 counts = {lang_id: 0 for lang_id in lang_ids} for batch in sampled_train_data.batch(batch_size).take(20): for input_data in batch['inputs']: - counts[input_data.numpy().decode('utf-8')[0:len(lang_ids[0])]] += 1 + counts[input_data.numpy().decode('utf-8')[0 : len(lang_ids[0])]] += 1 self.assertTrue(all([counts[lang_id] > 0 for lang_id in lang_ids])) diff --git a/init2winit/dataset_lib/mt_tokenizer.py b/init2winit/dataset_lib/mt_tokenizer.py index 3efaa257..95fb3870 100644 --- a/init2winit/dataset_lib/mt_tokenizer.py +++ b/init2winit/dataset_lib/mt_tokenizer.py @@ -19,7 +19,7 @@ import os import tempfile import time -from typing import Any, Dict, Iterable, Tuple, List +from typing import Any, Dict, Iterable, Tuple from absl import logging import jax @@ -40,7 +40,7 @@ def _dump_chars_to_textfile( dataset: tf.data.Dataset, maxchars: int = int(1e7), - data_keys=('inputs', 'targets') + data_keys=('inputs', 'targets'), ) -> Tuple[str, int]: """Write part of a TFDS sentence dataset to lines in a text file. @@ -55,7 +55,8 @@ def _dump_chars_to_textfile( char_count = 0 ds_iter = dataset.as_numpy_iterator() with tempfile.NamedTemporaryFile( - delete=False, prefix='/tmp/ds_chars') as outfp: + delete=False, prefix='/tmp/ds_chars' + ) as outfp: while char_count < maxchars: example = next(ds_iter) for k in data_keys: @@ -65,17 +66,19 @@ def _dump_chars_to_textfile( return outfp.name, char_count -def _train_sentencepiece(dataset: tf.data.Dataset, - *, - vocab_size: int, - maxchars: int = int(1e7), - model_path: str, - model_type: str = 'unigram', - character_coverage: float = 1.0, - byte_fallback: bool = False, - split_digits: bool = False, - data_keys: Tuple[str, str] = ('inputs', 'targets'), - user_defined_symbols: List[str] = []): +def _train_sentencepiece( + dataset: tf.data.Dataset, + *, + vocab_size: int, + maxchars: int = int(1e7), + model_path: str, + model_type: str = 'unigram', + character_coverage: float = 1.0, + byte_fallback: bool = False, + split_digits: bool = False, + data_keys: Tuple[str, str] = ('inputs', 'targets'), + user_defined_symbols: Tuple[str, ...] = (), +): """Train SentencePiece tokenizer from subset of tf dataset. Args: @@ -97,19 +100,23 @@ def _train_sentencepiece(dataset: tf.data.Dataset, """ abs_model_path = os.path.abspath(os.path.expanduser(model_path)) fname, _ = _dump_chars_to_textfile( - dataset, maxchars=maxchars, data_keys=data_keys) + dataset, maxchars=maxchars, data_keys=data_keys + ) with tempfile.NamedTemporaryFile( - delete=False, prefix='/tmp/sp_tmp') as model_fp: + delete=False, prefix='/tmp/sp_tmp' + ) as model_fp: pass # we just want a prefix'd tmp-filename user_defined_symbols = ['▁' + symbol for symbol in user_defined_symbols] user_defined_symbols = ','.join(user_defined_symbols) argstr = ' '.join([ - f'--input={fname}', f'--vocab_size={vocab_size}', + f'--input={fname}', + f'--vocab_size={vocab_size}', f'--character_coverage={character_coverage}', - f'--model_prefix={model_fp.name}', f'--model_type={model_type}', + f'--model_prefix={model_fp.name}', + f'--model_type={model_type}', f'--user_defined_symbols={user_defined_symbols}', f'--byte_fallback={byte_fallback}', - f'--split_digits={split_digits}' + f'--split_digits={split_digits}', ]) spm.SentencePieceTrainer.Train(argstr) if jax.process_index() == 0: @@ -126,28 +133,33 @@ def _train_sentencepiece(dataset: tf.data.Dataset, return abs_model_path -def _load_sentencepiece_tokenizer(model_path: str, - add_bos: bool = False, - add_eos: bool = True, - reverse: bool = False): +def _load_sentencepiece_tokenizer( + model_path: str, + add_bos: bool = False, + add_eos: bool = True, + reverse: bool = False, +): """Load a tf-text SentencePiece tokenizer from given model filepath.""" with gfile.GFile(model_path, 'rb') as model_fp: sp_model = model_fp.read() sp_tokenizer = tftxt.SentencepieceTokenizer( - model=sp_model, add_bos=add_bos, add_eos=add_eos, reverse=reverse) + model=sp_model, add_bos=add_bos, add_eos=add_eos, reverse=reverse + ) return sp_tokenizer -def load_or_train_tokenizer(dataset: tf.data.Dataset, - *, - vocab_path: str, - vocab_size: int, - max_corpus_chars: int, - character_coverage: float = 1.0, - byte_fallback: bool = False, - split_digits: bool = False, - data_keys: Tuple[str, str] = ('inputs', 'targets'), - user_defined_symbols: List[str] = []): +def load_or_train_tokenizer( + dataset: tf.data.Dataset, + *, + vocab_path: str, + vocab_size: int, + max_corpus_chars: int, + character_coverage: float = 1.0, + byte_fallback: bool = False, + split_digits: bool = False, + data_keys: Tuple[str, str] = ('inputs', 'targets'), + user_defined_symbols: Tuple[str, ...] = (), +): """Loads the tokenizer at `vocab_path` or trains a one from `dataset`.""" try: return _load_sentencepiece_tokenizer(vocab_path) @@ -162,7 +174,8 @@ def load_or_train_tokenizer(dataset: tf.data.Dataset, byte_fallback=byte_fallback, split_digits=split_digits, data_keys=data_keys, - user_defined_symbols=user_defined_symbols) + user_defined_symbols=user_defined_symbols, + ) return _load_sentencepiece_tokenizer(vocab_path) diff --git a/init2winit/dataset_lib/nanodo_c4.py b/init2winit/dataset_lib/nanodo_c4.py index 170a74b2..d2d46ebd 100644 --- a/init2winit/dataset_lib/nanodo_c4.py +++ b/init2winit/dataset_lib/nanodo_c4.py @@ -109,14 +109,12 @@ def py_batched_tfds( operations=pygrain_ops, sampler=index_sampler, worker_count=worker_count, - worker_buffer_size=worker_buffer_size + worker_buffer_size=worker_buffer_size, ) return batched_dataloader -def get_dataset( - shuffle_rng, batch_size, eval_batch_size=None, hps=None -): +def get_dataset(shuffle_rng, batch_size, eval_batch_size=None, hps=None): """Data generators for Nanodo.""" shuffle_seed = data_utils.convert_jax_to_tf_random_seed(shuffle_rng) @@ -153,22 +151,14 @@ def get_dataset( def train_iterator_fn(): for example in iter(train_ds): inputs, targets, weights = data_loader.get_in_out(example) - batch = { - 'inputs': inputs, - 'targets': targets, - 'weights': weights - } + batch = {'inputs': inputs, 'targets': targets, 'weights': weights} yield batch def eval_train_epoch(num_batches=None): eval_train_iter = iter(eval_ds) for example in itertools.islice(eval_train_iter, num_batches): inputs, targets, weights = data_loader.get_in_out(example) - batch = { - 'inputs': inputs, - 'targets': targets, - 'weights': weights - } + batch = {'inputs': inputs, 'targets': targets, 'weights': weights} yield batch @@ -176,11 +166,7 @@ def valid_epoch(num_batches=None): valid_iter = iter(eval_ds) for example in itertools.islice(valid_iter, num_batches): inputs, targets, weights = data_loader.get_in_out(example) - batch = { - 'inputs': inputs, - 'targets': targets, - 'weights': weights - } + batch = {'inputs': inputs, 'targets': targets, 'weights': weights} yield batch # pylint: disable=unreachable diff --git a/init2winit/dataset_lib/nanodo_data_loader_shared.py b/init2winit/dataset_lib/nanodo_data_loader_shared.py index d3b057da..1ada5370 100644 --- a/init2winit/dataset_lib/nanodo_data_loader_shared.py +++ b/init2winit/dataset_lib/nanodo_data_loader_shared.py @@ -27,12 +27,10 @@ import grain.python as grain import jax import jax.numpy as jnp - import numpy as np import sentencepiece as spm - PAD_ID = 0 @@ -81,6 +79,7 @@ def EncodeAsIds(self, text: Union[bytes, str]) -> list[int]: def DecodeIds(self, ids: Iterable[int]) -> str: return bytes(ids).decode('utf-8') + # pylint: enable=invalid-name diff --git a/init2winit/dataset_lib/nanodo_fineweb_edu.py b/init2winit/dataset_lib/nanodo_fineweb_edu.py index 5876c5df..6cfe825f 100644 --- a/init2winit/dataset_lib/nanodo_fineweb_edu.py +++ b/init2winit/dataset_lib/nanodo_fineweb_edu.py @@ -28,7 +28,6 @@ import numpy as np import tensorflow_datasets as tfds - VOCAB_SIZE = 100864 EVAL_SPLIT = 'cc_main_2024_10' @@ -173,22 +172,14 @@ def get_dataset(shuffle_rng, batch_size, eval_batch_size=None, hps=None): def train_iterator_fn(): for example in iter(train_ds): inputs, targets, weights = data_loader.get_in_out(example) - batch = { - 'inputs': inputs, - 'targets': targets, - 'weights': weights - } + batch = {'inputs': inputs, 'targets': targets, 'weights': weights} yield batch def eval_train_epoch(num_batches=None): eval_train_iter = iter(eval_ds) for example in itertools.islice(eval_train_iter, num_batches): inputs, targets, weights = data_loader.get_in_out(example) - batch = { - 'inputs': inputs, - 'targets': targets, - 'weights': weights - } + batch = {'inputs': inputs, 'targets': targets, 'weights': weights} yield batch @@ -196,11 +187,7 @@ def valid_epoch(num_batches=None): valid_iter = iter(eval_ds) for example in itertools.islice(valid_iter, num_batches): inputs, targets, weights = data_loader.get_in_out(example) - batch = { - 'inputs': inputs, - 'targets': targets, - 'weights': weights - } + batch = {'inputs': inputs, 'targets': targets, 'weights': weights} yield batch # pylint: disable=unreachable diff --git a/init2winit/dataset_lib/nqm_noise.py b/init2winit/dataset_lib/nqm_noise.py index e5642da3..67ae25a4 100644 --- a/init2winit/dataset_lib/nqm_noise.py +++ b/init2winit/dataset_lib/nqm_noise.py @@ -20,7 +20,6 @@ from ml_collections.config_dict import config_dict import numpy as np - NQM_HPARAMS = config_dict.ConfigDict( dict( train_size=1e10, @@ -28,7 +27,8 @@ test_size=0, input_shape=(100,), # This determines the dimension. output_shape=(1,), - )) + ) +) NQM_METADATA = { 'apply_one_hot_in_loss': False, } @@ -49,6 +49,7 @@ def get_nqm_noise(shuffle_rng, batch_size, eval_batch_size, hps=None): eval_batch_size: Not used. hps: Hparams object. We only refer to hps.input_shape to determine the dimension of the noise. + Returns: train_epoch, eval_train_epoch, valid_epoch, test_epoch: three generators. Only train_epoch is used. @@ -83,6 +84,7 @@ def valid_epoch(*args, **kwargs): del kwargs return yield # This yield is needed to make this a valid (null) iterator. + # pylint: enable=unreachable # pylint: disable=unreachable diff --git a/init2winit/dataset_lib/ogbg_molpcba.py b/init2winit/dataset_lib/ogbg_molpcba.py index 3b7c9b80..9da4e870 100644 --- a/init2winit/dataset_lib/ogbg_molpcba.py +++ b/init2winit/dataset_lib/ogbg_molpcba.py @@ -98,10 +98,9 @@ def __len__(self): return len(self.data) -def _load_dataset(split, - should_shuffle=False, - shuffle_seed=None, - shuffle_buffer_size=None): +def _load_dataset( + split, should_shuffle=False, shuffle_seed=None, shuffle_buffer_size=None +): """Loads a dataset split from TFDS.""" if should_shuffle: assert shuffle_seed is not None and shuffle_buffer_size is not None @@ -113,20 +112,23 @@ def _load_dataset(split, dataset_shuffle_seed = None read_config = tfds.ReadConfig( - add_tfds_id=True, shuffle_seed=file_shuffle_seed) + add_tfds_id=True, shuffle_seed=file_shuffle_seed + ) dataset = tfds.load( 'ogbg_molpcba', split=split, shuffle_files=should_shuffle, - read_config=read_config) + read_config=read_config, + ) logging.info('Loading in memory dataset...') dataset = list(tfds.as_numpy(dataset)) return _InMemoryDataset(dataset, should_shuffle, dataset_shuffle_seed) -def _to_jraph(example, add_bidirectional_edges, add_virtual_node, - add_self_loops): +def _to_jraph( + example, add_bidirectional_edges, add_virtual_node, add_self_loops +): """Converts an example graph to jraph.GraphsTuple.""" if hasattr(example['edge_feat'], '_numpy'): example = data_utils.tf_to_numpy(example) @@ -152,16 +154,19 @@ def _to_jraph(example, add_bidirectional_edges, add_virtual_node, new_senders = np.concatenate([new_senders, np.arange(num_nodes)]) new_receivers = np.concatenate([new_receivers, np.arange(num_nodes)]) edge_feat = np.concatenate( - [edge_feat, np.zeros((num_nodes, edge_feat.shape[-1]))]) + [edge_feat, np.zeros((num_nodes, edge_feat.shape[-1]))] + ) num_edges += num_nodes if add_virtual_node: node_feat = np.concatenate([node_feat, np.zeros_like(node_feat[0, None])]) new_senders = np.concatenate([new_senders, np.arange(num_nodes)]) new_receivers = np.concatenate( - [new_receivers, np.full((num_nodes,), num_nodes)]) + [new_receivers, np.full((num_nodes,), num_nodes)] + ) edge_feat = np.concatenate( - [edge_feat, np.zeros((num_nodes, edge_feat.shape[-1]))]) + [edge_feat, np.zeros((num_nodes, edge_feat.shape[-1]))] + ) num_edges += num_nodes num_nodes += 1 @@ -174,7 +179,8 @@ def _to_jraph(example, add_bidirectional_edges, add_virtual_node, receivers=new_receivers, # Keep the labels with the graph for batching. They will be removed # in the processed batch. - globals=np.expand_dims(labels, axis=0)) + globals=np.expand_dims(labels, axis=0), + ) def _get_weights_by_nan_and_padding(labels, padding_mask): @@ -550,11 +556,13 @@ def _get_dynamic_batch_iterator( _to_jraph, add_bidirectional_edges=add_bidirectional_edges, add_virtual_node=add_virtual_node, - add_self_loops=add_self_loops) + add_self_loops=add_self_loops, + ) jraph_iter = map(to_jraph_partial, dataset_iter) - batched_iter = jraph.dynamically_batch(jraph_iter, max_n_nodes + 1, - max_n_edges, max_n_graphs + 1) + batched_iter = jraph.dynamically_batch( + jraph_iter, max_n_nodes + 1, max_n_edges, max_n_graphs + 1 + ) count = 0 graphs_shards = [] @@ -569,7 +577,8 @@ def _get_dynamic_batch_iterator( graph = batched_graph._replace(globals={}) replaced_labels, weights = _get_weights_by_nan_and_padding( - labels, jraph.get_graph_padding_mask(graph)) + labels, jraph.get_graph_padding_mask(graph) + ) graphs_shards.append(graph) labels_shards.append(replaced_labels) @@ -582,7 +591,7 @@ def _get_dynamic_batch_iterator( # It is possible we may be leaking host memory with the np call. 'inputs': jraph.batch_np(graphs_shards), 'targets': np.vstack(labels_shards), - 'weights': np.vstack(weights_shards) + 'weights': np.vstack(weights_shards), } count = 0 @@ -613,6 +622,7 @@ def get_ogbg_molpcba(shuffle_rng, batch_size, eval_batch_size, hps=None): shuffle_buffer_size = 2**15 shuffle_rng_train, shuffle_rng_eval_train = jax.random.split(shuffle_rng) + def _log_mem(label): rss_mb = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024 logging.info('[ogbg] %s — RSS: %.1f MB', label, rss_mb) @@ -623,7 +633,8 @@ def _log_mem(label): 'train', should_shuffle=True, shuffle_seed=shuffle_rng_train, - shuffle_buffer_size=shuffle_buffer_size) + shuffle_buffer_size=shuffle_buffer_size, + ) _log_mem('After loading train split') eval_train_size = min(hps.valid_size, len(train_ds)) # Use a random subset of the training data for eval_train. @@ -720,8 +731,9 @@ def dataset_iterator(): num_nodes = int(rng.normal(loc=num_nodes_mean, scale=num_nodes_std)) # NOTE(dsuo): we want at least as many edges as we have nodes. - num_edges = max(num_nodes, - int(rng.normal(loc=num_edges_mean, scale=num_edges_std))) + num_edges = max( + num_nodes, int(rng.normal(loc=num_edges_mean, scale=num_edges_std)) + ) # NOTE(dsuo): create an edge between pair of consecutive nodes to have # a well-formed molecule. @@ -732,7 +744,8 @@ def dataset_iterator(): # NOTE(dsuo): create random edges for any remaining. if num_edges > num_nodes: edge_index[num_nodes:num_edges, :] = rng.choice( - num_nodes, (num_edges - num_nodes, 2)) + num_nodes, (num_edges - num_nodes, 2) + ) yield { 'edge_feat': tf.ones((num_edges, 3), dtype=tf.float32), diff --git a/init2winit/dataset_lib/pg19.py b/init2winit/dataset_lib/pg19.py index db2afec5..240d1f42 100644 --- a/init2winit/dataset_lib/pg19.py +++ b/init2winit/dataset_lib/pg19.py @@ -54,18 +54,21 @@ VOCAB_SIZE = 35561 -DEFAULT_HPARAMS = config_dict.ConfigDict(dict( - max_target_length=8192, - shuffle_size=512, - train_size=302013, - input_shape=(16,), - output_shape=(VOCAB_SIZE,), - vocab_path=None, - data_dir=None, - vocab_size=VOCAB_SIZE, - max_corpus_chars=10**7, - eod_id=1, - eval_split='test')) +DEFAULT_HPARAMS = config_dict.ConfigDict( + dict( + max_target_length=8192, + shuffle_size=512, + train_size=302013, + input_shape=(16,), + output_shape=(VOCAB_SIZE,), + vocab_path=None, + data_dir=None, + vocab_size=VOCAB_SIZE, + max_corpus_chars=10**7, + eod_id=1, + eval_split='test', + ) +) METADATA = { 'apply_one_hot_in_loss': True, @@ -110,9 +113,9 @@ def map_line_length(tensor: tf.Tensor) -> Feature: return (tf.cast(len(tensor), tf.int64), tensor) -def scan_func(state: tf.Tensor, - element: tf.Tensor, - max_target_length: int = 8192): +def scan_func( + state: tf.Tensor, element: tf.Tensor, max_target_length: int = 8192 +): """Maps (old_state, input_element) to (new_state, output_element). Args: @@ -143,7 +146,8 @@ def init_func(key: tf.Tensor) -> tf.TensorArray: """Maps a nested structure of tensors to a scalar tf.int64 tensor.""" del key # Not used by Reducer return tf.TensorArray( - tf.int64, size=0, dynamic_size=True, clear_after_read=False) + tf.int64, size=0, dynamic_size=True, clear_after_read=False + ) def reduce_func(state: tf.Tensor, element: Feature) -> tf.TensorArray: @@ -171,8 +175,9 @@ def preprocess_example(example: Feature) -> Dict[tf.Tensor, tf.Tensor]: return map_line_length(example) -def generate_features(dataset: tf.data.Dataset, - hps: config_dict.ConfigDict) -> Feature: +def generate_features( + dataset: tf.data.Dataset, hps: config_dict.ConfigDict +) -> Feature: """Preprocesses a dataset before serliazing features and saving TFRecords. Args: @@ -184,14 +189,20 @@ def generate_features(dataset: tf.data.Dataset, """ dataset = dataset.map(preprocess_example, num_parallel_calls=AUTOTUNE) dataset = dataset.scan( - (tf.convert_to_tensor( - 0, dtype=tf.int64), tf.convert_to_tensor(0, dtype=tf.int64)), - functools.partial(scan_func, max_target_length=hps.max_target_length)) + ( + tf.convert_to_tensor(0, dtype=tf.int64), + tf.convert_to_tensor(0, dtype=tf.int64), + ), + functools.partial(scan_func, max_target_length=hps.max_target_length), + ) reducer = tf.data.experimental.Reducer( - init_func, reduce_func, - functools.partial(finalize_func, eod_id=hps.eod_id)) + init_func, + reduce_func, + functools.partial(finalize_func, eod_id=hps.eod_id), + ) dataset = dataset.apply( - tf.data.experimental.group_by_reducer(key_func, reducer)) + tf.data.experimental.group_by_reducer(key_func, reducer) + ) return dataset @@ -233,8 +244,13 @@ def create_example(tensor: tf.Tensor) -> tf.train.Example: return tf.train.Example(features=tf.train.Features(feature=feature)) -def write_pg19_tfrecords(data_dir: str, split: str, vocab_path: str, - hps: config_dict.ConfigDict, dataset_builder): +def write_pg19_tfrecords( + data_dir: str, + split: str, + vocab_path: str, + hps: config_dict.ConfigDict, + dataset_builder, +): """Writes preprocessed PG-19 data as TF Records split into shards. Args: @@ -256,14 +272,16 @@ def write_pg19_tfrecords(data_dir: str, split: str, vocab_path: str, train_data, vocab_path=vocab_path, vocab_size=hps.vocab_size, - max_corpus_chars=hps.max_corpus_chars) + max_corpus_chars=hps.max_corpus_chars, + ) if split == 'dev': split = 'validation' split_dataset = generate_dataset(dataset_builder, split) split_dataset = split_dataset.map( - spm_tokenizer.TokenizeOp(sp_tokenizer), num_parallel_calls=AUTOTUNE) + spm_tokenizer.TokenizeOp(sp_tokenizer), num_parallel_calls=AUTOTUNE + ) split_dataset = generate_features(split_dataset, hps) num_shards = TFRECORDS_SHARDS[split] @@ -280,8 +298,9 @@ def write_pg19_tfrecords(data_dir: str, split: str, vocab_path: str, writer.write(feature.SerializeToString()) -def decode_and_preprocess_example(encoded_example: Feature, - hps: config_dict.ConfigDict) -> Feature: +def decode_and_preprocess_example( + encoded_example: Feature, hps: config_dict.ConfigDict +) -> Feature: """Decodes a serialized data in an encoded_example. Args: @@ -292,12 +311,14 @@ def decode_and_preprocess_example(encoded_example: Feature, output: a feature dictionary with decoded data. """ example = tf.io.parse_example( - encoded_example, { - 'targets': - tf.io.FixedLenSequenceFeature( - shape=[], dtype=tf.int64, allow_missing=True) - }) - return example['targets'][:hps.max_target_length] + encoded_example, + { + 'targets': tf.io.FixedLenSequenceFeature( + shape=[], dtype=tf.int64, allow_missing=True + ) + }, + ) + return example['targets'][: hps.max_target_length] def output_preprocess(tensor: tf.Tensor) -> Tuple[Feature, Feature]: @@ -314,16 +335,18 @@ def output_preprocess(tensor: tf.Tensor) -> Tuple[Feature, Feature]: return add_inputs_and_targets(tensor) -def get_dataset(data_dir: str, - split: str, - vocab_path: str, - per_host_batch_size: int, - hps: config_dict.ConfigDict, - shuffle: bool, - shuffle_rng: jax.random.PRNGKey, - process_count: int, - repeat: bool = False, - drop_remainder: bool = False) -> tf.data.Dataset: +def get_dataset( + data_dir: str, + split: str, + vocab_path: str, + per_host_batch_size: int, + hps: config_dict.ConfigDict, + shuffle: bool, + shuffle_rng: jax.random.PRNGKey, + process_count: int, + repeat: bool = False, + drop_remainder: bool = False, +) -> tf.data.Dataset: """Loads and decodes PG-19 TFRecords. Args: @@ -346,11 +369,17 @@ def get_dataset(data_dir: str, split = 'dev' data_files = tf.io.matching_files(os.path.join(data_dir, (split + '*'))) if not exists(data_dir) and tf.size(data_files) == 0: - logging.info('There is no directory like %s or' - ' there is no TFRecords for %s split', data_dir, split) - logging.info('Generating TFRecords for the %s split.' - ' If it is a train split then it might' - ' take 5-6 hours to complete.', split) + logging.info( + 'There is no directory like %s or there is no TFRecords for %s split', + data_dir, + split, + ) + logging.info( + 'Generating TFRecords for the %s split.' + ' If it is a train split then it might' + ' take 5-6 hours to complete.', + split, + ) pg19_builder = tfds.builder('pg19') write_pg19_tfrecords( @@ -358,7 +387,8 @@ def get_dataset(data_dir: str, split=split, vocab_path=vocab_path, hps=hps, - dataset_builder=pg19_builder) + dataset_builder=pg19_builder, + ) data_files = tf.io.matching_files(os.path.join(data_dir, (split + '*'))) if split == 'train': @@ -368,7 +398,8 @@ def get_dataset(data_dir: str, dataset = tf.data.TFRecordDataset(data_files, buffer_size=8 * 1024 * 1024) dataset = dataset.map( lambda x: decode_and_preprocess_example(x, hps=hps), - num_parallel_calls=AUTOTUNE) + num_parallel_calls=AUTOTUNE, + ) # In T2T they shuffle twice with a buffer sizes of 1024 and 512 respectively # We shuffle only once with a buffer size 512 since it does not seem to affect # results @@ -379,7 +410,8 @@ def get_dataset(data_dir: str, dataset = dataset.padded_batch( batch_size=per_host_batch_size, padded_shapes=hps.max_target_length, - drop_remainder=drop_remainder) + drop_remainder=drop_remainder, + ) dataset = dataset.map(output_preprocess, num_parallel_calls=AUTOTUNE) dataset = dataset.prefetch(AUTOTUNE) return dataset @@ -390,7 +422,7 @@ def get_pg19_datasets( per_host_batch_size: int, per_host_eval_batch_size: int, shuffle_rng: jax.random.PRNGKey, - process_count: int + process_count: int, ) -> Tuple[tf.data.Dataset, tf.data.Dataset, tf.data.Dataset]: """Preprocesses a dataset. @@ -429,7 +461,8 @@ def get_pg19_datasets( shuffle=True, shuffle_rng=convert_jax_to_tf_random_seed(shuffle_rng), drop_remainder=True, - process_count=process_count) + process_count=process_count, + ) eval_ds = get_dataset( data_dir=data_dir, split='validation', @@ -438,7 +471,8 @@ def get_pg19_datasets( hps=hps, shuffle=False, shuffle_rng=None, - process_count=process_count) + process_count=process_count, + ) test_ds = get_dataset( data_dir=data_dir, split='test', @@ -447,15 +481,18 @@ def get_pg19_datasets( hps=hps, shuffle=False, shuffle_rng=None, - process_count=process_count) + process_count=process_count, + ) return train_ds, eval_ds, test_ds -def get_pg19(shuffle_rng: jax.random.PRNGKey = None, - batch_size: int = 8, - eval_batch_size: Optional[int] = None, - hps: config_dict.ConfigDict = None): +def get_pg19( + shuffle_rng: jax.random.PRNGKey = None, + batch_size: int = 8, + eval_batch_size: Optional[int] = None, + hps: config_dict.ConfigDict = None, +): """PG-19 data generator. Args: @@ -469,32 +506,46 @@ def get_pg19(shuffle_rng: jax.random.PRNGKey = None, """ process_count = jax.process_count() if batch_size % process_count != 0: - raise ValueError('process_count={} must divide batch_size={}.'.format( - process_count, batch_size)) + raise ValueError( + 'process_count={} must divide batch_size={}.'.format( + process_count, batch_size + ) + ) per_host_batch_size = batch_size // process_count if eval_batch_size is None: eval_batch_size = batch_size if eval_batch_size % process_count != 0: - raise ValueError('process_count={} must divide eval_batch_size={}.'.format( - process_count, eval_batch_size)) + raise ValueError( + 'process_count={} must divide eval_batch_size={}.'.format( + process_count, eval_batch_size + ) + ) per_host_eval_batch_size = eval_batch_size // process_count n_devices = jax.local_device_count() if per_host_batch_size % 1 != 0: - raise ValueError('n_devices={} must divide per_host_batch_size={}.'.format( - n_devices, per_host_batch_size)) + raise ValueError( + 'n_devices={} must divide per_host_batch_size={}.'.format( + n_devices, per_host_batch_size + ) + ) if per_host_eval_batch_size % 1 != 0: raise ValueError( 'n_devices={} must divide per_host_eval_batch_size={}.'.format( - n_devices, per_host_eval_batch_size)) - - train_ds, eval_ds, test_ds = get_pg19_datasets(hps, per_host_batch_size, - per_host_eval_batch_size, - shuffle_rng, - process_count) + n_devices, per_host_eval_batch_size + ) + ) + + train_ds, eval_ds, test_ds = get_pg19_datasets( + hps, + per_host_batch_size, + per_host_eval_batch_size, + shuffle_rng, + process_count, + ) def train_iterator_fn(): """Iterates over the train dataset and yields Numpy batches.""" @@ -513,7 +564,8 @@ def valid_epoch(num_batches: int = None): for batch in itertools.islice(iter(test_ds), num_batches): batch = tf_to_numpy(batch) yield maybe_pad_batch( - batch, desired_batch_size=per_host_eval_batch_size, padding_value=0) + batch, desired_batch_size=per_host_eval_batch_size, padding_value=0 + ) # pylint: disable=unreachable def test_epoch(*args, **kwargs): @@ -529,13 +581,15 @@ def valid_epoch(num_batches: int = None): for batch in itertools.islice(iter(eval_ds), num_batches): batch = tf_to_numpy(batch) yield maybe_pad_batch( - batch, desired_batch_size=per_host_eval_batch_size, padding_value=0) + batch, desired_batch_size=per_host_eval_batch_size, padding_value=0 + ) def test_epoch(num_batches: int = None): """Iterates over the test dataset and yields Numpy batches.""" for batch in itertools.islice(iter(test_ds), num_batches): batch = tf_to_numpy(batch) yield maybe_pad_batch( - batch, desired_batch_size=per_host_eval_batch_size, padding_value=0) + batch, desired_batch_size=per_host_eval_batch_size, padding_value=0 + ) return Dataset(train_iterator_fn, eval_train_epoch, valid_epoch, test_epoch) diff --git a/init2winit/dataset_lib/protein_vocab.py b/init2winit/dataset_lib/protein_vocab.py index 3a2de57a..4e10c641 100644 --- a/init2winit/dataset_lib/protein_vocab.py +++ b/init2winit/dataset_lib/protein_vocab.py @@ -14,6 +14,7 @@ # limitations under the License. """Classes specifying different protein domains.""" + import collections from collections import abc import copy @@ -40,18 +41,20 @@ class Vocabulary(object): """Basic vocabulary used to represent output tokens for domains.""" - def __init__(self, - tokens, - name=None, - include_bos=False, - include_eos=False, - include_pad=False, - include_mask=False, - bos_token=BOS_TOKEN, - eos_token=EOS_TOKEN, - pad_token=PAD_TOKEN, - mask_token=MASK_TOKEN, - disallow_sep_token=True): + def __init__( + self, + tokens, + name=None, + include_bos=False, + include_eos=False, + include_pad=False, + include_mask=False, + bos_token=BOS_TOKEN, + eos_token=EOS_TOKEN, + pad_token=PAD_TOKEN, + mask_token=MASK_TOKEN, + disallow_sep_token=True, + ): """A token vocabulary. Args: @@ -61,22 +64,22 @@ def __init__(self, name: An optional name for the vocab. include_bos: Whether to append `bos_token` to `tokens` that marks the beginning of a sequence. - include_eos: Whether to append `eos_token` to `tokens` that marks the - end of a sequence. + include_eos: Whether to append `eos_token` to `tokens` that marks the end + of a sequence. include_pad: Whether to append `pad_token` to `tokens` to marks past end of sequence. include_mask: Whether to append `mask_token` to `tokens` to mark masked positions. - bos_token: A special token than marks the beginning of sequence. - Ignored if `include_bos == False`. - eos_token: A special token than marks the end of sequence. - Ignored if `include_eos == False`. - pad_token: A special token than marks past the end of sequence. - Ignored if `include_pad == False`. + bos_token: A special token than marks the beginning of sequence. Ignored + if `include_bos == False`. + eos_token: A special token than marks the end of sequence. Ignored if + `include_eos == False`. + pad_token: A special token than marks past the end of sequence. Ignored if + `include_pad == False`. mask_token: A special token than marks MASKED positions for e.g. BERT. Ignored if `include_mask == False`. disallow_sep_token: If True, disallow `|` appearing in the vocabulary, - which is used as separator token when serializing to csv. + which is used as separator token when serializing to csv. """ if not isinstance(tokens, abc.Iterable): tokens = range(tokens) @@ -95,7 +98,8 @@ def __init__(self, special_tokens = sorted(set(tokens) & set([SEP_TOKEN])) if special_tokens: raise ValueError( - f'tokens contains reserved special tokens: {special_tokens}!') + f'tokens contains reserved special tokens: {special_tokens}!' + ) self._name = name self._set_tokens(tokens) @@ -114,9 +118,11 @@ def _set_tokens(self, tokens): self._tokens = tokens self._token_ids = list(range(len(tokens))) self._id_to_token = collections.OrderedDict( - zip(self._token_ids, self._tokens)) + zip(self._token_ids, self._tokens) + ) self._token_to_id = collections.OrderedDict( - zip(self._tokens, self._token_ids)) + zip(self._tokens, self._token_ids) + ) def __setstate__(self, state): """Create vocab from dict version.""" @@ -139,7 +145,8 @@ def __getstate__(self): eos_token=self._eos_token, pad_token=self._pad_token, mask_token=self._mask_token, - )) + ) + ) def as_dict(self): """Serialize vocabulary to dict.""" @@ -178,26 +185,32 @@ def token_ids(self): @property def bos(self): """Returns the index of the BOS token or None if unspecified.""" - return (None if self._bos_token is None else - self._token_to_id[self._bos_token]) + return ( + None if self._bos_token is None else self._token_to_id[self._bos_token] + ) @property def eos(self): """Returns the index of the EOS token or None if unspecified.""" - return (None if self._eos_token is None else - self._token_to_id[self._eos_token]) + return ( + None if self._eos_token is None else self._token_to_id[self._eos_token] + ) @property def mask(self): """Returns the index of the MASK token or None if unspecified.""" - return (None if self._mask_token is None else - self._token_to_id[self._mask_token]) + return ( + None + if self._mask_token is None + else self._token_to_id[self._mask_token] + ) @property def pad(self): """Returns the index of the PAD token or None if unspecified.""" - return (None - if self._pad_token is None else self._token_to_id[self._pad_token]) + return ( + None if self._pad_token is None else self._token_to_id[self._pad_token] + ) def is_valid(self, value): """Tests if a value is a valid token id and returns a bool.""" @@ -243,12 +256,14 @@ def decode(self, values, stop_at_eos=False, as_str=False): return ''.join(tokens) if as_str else tokens -def make_protein_vocab(include_anomalous_amino_acids=True, - include_bos=True, - include_eos=True, - include_pad=True, - include_mask=True, - include_align_tokens=False): +def make_protein_vocab( + include_anomalous_amino_acids=True, + include_bos=True, + include_eos=True, + include_pad=True, + include_mask=True, + include_align_tokens=False, +): """Returns a vocabulary for proteins.""" tokens = list(_AA_TOKENS) if include_anomalous_amino_acids: @@ -261,5 +276,5 @@ def make_protein_vocab(include_anomalous_amino_acids=True, include_bos=include_bos, include_eos=include_eos, include_pad=include_pad, - include_mask=include_mask) - + include_mask=include_mask, + ) diff --git a/init2winit/dataset_lib/proteins.py b/init2winit/dataset_lib/proteins.py index 347f0b38..436333e7 100644 --- a/init2winit/dataset_lib/proteins.py +++ b/init2winit/dataset_lib/proteins.py @@ -40,7 +40,8 @@ output_shape=(31,), train_size=27000000, data_name='uniref50/unaligned_encoded', - )) + ) +) METADATA = { 'apply_one_hot_in_loss': True, @@ -65,18 +66,26 @@ def _crop(x, max_length, sample): start = tf.random.uniform( (), dtype=tf.int32, - maxval=tf.maximum(1, - tf.shape(x)[0] - max_length + 1)) + maxval=tf.maximum(1, tf.shape(x)[0] - max_length + 1), + ) else: start = 0 - x = x[start:(start + max_length)] + x = x[start : (start + max_length)] return x -def preprocess_masked(inputs, random_tokens, mask_token, pad_token, mask_rate, - mask_token_proportion, random_token_proportion, mode, - rng): +def preprocess_masked( + inputs, + random_tokens, + mask_token, + pad_token, + mask_rate, + mask_token_proportion, + random_token_proportion, + mode, + rng, +): """Preprocess inputs for masked language modeling. Args: @@ -99,8 +108,9 @@ def preprocess_masked(inputs, random_tokens, mask_token, pad_token, mask_rate, """ total = random_token_proportion + mask_token_proportion if total < 0 or total > 1: - raise ValueError('Sum of random proportion and mask proportion must be' - ' in [0, 1] range.') + raise ValueError( + 'Sum of random proportion and mask proportion must be in [0, 1] range.' + ) targets = inputs if mode == Mode.PREDICT: @@ -126,7 +136,8 @@ def preprocess_masked(inputs, random_tokens, mask_token, pad_token, mask_rate, # Generate full array of random tokens. rng, subrng = jax.random.split(rng) random_ids = jax.random.randint( - subrng, inputs.shape, minval=0, maxval=len(random_tokens)) + subrng, inputs.shape, minval=0, maxval=len(random_tokens) + ) fullrandom = random_tokens[random_ids] # Full array of MASK tokens @@ -137,13 +148,16 @@ def preprocess_masked(inputs, random_tokens, mask_token, pad_token, mask_rate, # Remaining probability mass stays original values after MASK and RANDOM. # MASK tokens. - masked_inputs = jnp.where(rand < mask_token_proportion, fullmask, - inputs) + masked_inputs = jnp.where(rand < mask_token_proportion, fullmask, inputs) # Random tokens. masked_inputs = jnp.where( - jnp.logical_and(rand >= mask_token_proportion, - rand < mask_token_proportion + random_token_proportion), - fullrandom, masked_inputs) + jnp.logical_and( + rand >= mask_token_proportion, + rand < mask_token_proportion + random_token_proportion, + ), + fullrandom, + masked_inputs, + ) # Only replace positions where `should_mask` masked_inputs = jnp.where(should_mask, masked_inputs, inputs) @@ -152,14 +166,16 @@ def preprocess_masked(inputs, random_tokens, mask_token, pad_token, mask_rate, return masked_inputs, targets, weights -class BertMasker(): +class BertMasker: """Construct BERT masker given a vocab.""" - def __init__(self, - vocab, - mask_rate=0.15, - mask_token_proportion=0.1, - random_token_proportion=0.8): + def __init__( + self, + vocab, + mask_rate=0.15, + mask_token_proportion=0.1, + random_token_proportion=0.8, + ): self._vocab = vocab if vocab.mask is None: raise ValueError('Vocabulary must specify a MASK token.') @@ -182,7 +198,8 @@ def __call__(self, inputs, mode, rng): pad_token=self._vocab.pad, mask_rate=self._mask_rate, mask_token_proportion=self._mask_token_proportion, - random_token_proportion=self._random_token_proportion) + random_token_proportion=self._random_token_proportion, + ) return inputs, targets, weights @@ -194,20 +211,23 @@ def shift_right(x, bos_token): x, pad_widths, mode='constant', - constant_values=tf.constant(bos_token, dtype=x.dtype)) + constant_values=tf.constant(bos_token, dtype=x.dtype), + ) return padded[:, :-1] -def load_protein_tfds(name, - split, - field='sequence', - num_epochs=1, - shuffle_buffer=2**15, - batch_size=32, - max_length=None, - sample_length=True, - data_dir=None, - drop_remainder=True): +def load_protein_tfds( + name, + split, + field='sequence', + num_epochs=1, + shuffle_buffer=2**15, + batch_size=32, + max_length=None, + sample_length=True, + data_dir=None, + drop_remainder=True, +): """Load protein tfds by name. If split is `train`, shuffle data. @@ -235,7 +255,8 @@ def load_protein_tfds(name, split=split, with_info=False, data_dir=data_dir, - shuffle_files=shuffle) + shuffle_files=shuffle, + ) # Construct vocab from stored metadata. # TODO(ddohan): Regenerate dataset with new vocab. @@ -248,9 +269,9 @@ def _get_field(example): # pylint: disable=protected-access if max_length: ds = ds.map( - functools.partial( - _crop, max_length=max_length, sample=sample_length), - num_parallel_calls=tf.data.experimental.AUTOTUNE) + functools.partial(_crop, max_length=max_length, sample=sample_length), + num_parallel_calls=tf.data.experimental.AUTOTUNE, + ) # pylint: disable=protected-access if shuffle: @@ -263,16 +284,14 @@ def _get_field(example): batch_size, padded_shapes=max_length, padding_values=np.array(vocab.pad, dtype=np.int64), - drop_remainder=drop_remainder) + drop_remainder=drop_remainder, + ) ds = ds.prefetch(tf.data.experimental.AUTOTUNE) return ds, vocab -def load_dataset(data_name, - batch_size, - eval_batch_size, - length=512): +def load_dataset(data_name, batch_size, eval_batch_size, length=512): """Load protein dataset. Args: @@ -286,9 +305,11 @@ def load_dataset(data_name, """ logging.info('Loading data_name: %s', data_name) train_ds, vocab = load_protein_tfds( - data_name, 'train', max_length=length, batch_size=batch_size) + data_name, 'train', max_length=length, batch_size=batch_size + ) valid_ds, vocab = load_protein_tfds( - data_name, 'validation', max_length=length, batch_size=eval_batch_size) + data_name, 'validation', max_length=length, batch_size=eval_batch_size + ) return train_ds, valid_ds, vocab @@ -296,11 +317,7 @@ def load_dataset(data_name, def _batch_to_dict(batch, masker, mode, rng): batch = data_utils.tf_to_numpy(batch) inputs, targets, weights = masker(batch, mode, rng) - return { - 'inputs': inputs, - 'targets': targets, - 'weights': weights - } + return {'inputs': inputs, 'targets': targets, 'weights': weights} def get_uniref(shuffle_rng, batch_size, eval_batch_size, hps=None): @@ -308,38 +325,39 @@ def get_uniref(shuffle_rng, batch_size, eval_batch_size, hps=None): per_host_batch_size = batch_size // jax.process_count() per_host_eval_batch_size = eval_batch_size // jax.process_count() return _get_uniref( - per_host_batch_size, - per_host_eval_batch_size, - hps, - shuffle_rng) + per_host_batch_size, per_host_eval_batch_size, hps, shuffle_rng + ) -def _get_uniref( - per_host_batch_size, - per_host_eval_batch_size, - hps, - data_rng): +def _get_uniref(per_host_batch_size, per_host_eval_batch_size, hps, data_rng): """Data generators for Uniref50 clustered protein dataset.""" # TODO(gilmer) Currently uniref drops the last partial batch on eval. logging.warning( - 'Currently the Protein dataset drops the last partial batch on eval') + 'Currently the Protein dataset drops the last partial batch on eval' + ) if jax.process_count() > 1: raise NotImplementedError('Proteins does not support multihost training') n_devices = jax.local_device_count() if per_host_batch_size % n_devices != 0: - raise ValueError('n_devices={} must divide per_host_batch_size={}.'.format( - n_devices, per_host_batch_size)) + raise ValueError( + 'n_devices={} must divide per_host_batch_size={}.'.format( + n_devices, per_host_batch_size + ) + ) if per_host_eval_batch_size % n_devices != 0: raise ValueError( 'n_devices={} must divide per_host_eval_batch_size={}.'.format( - n_devices, per_host_eval_batch_size)) + n_devices, per_host_eval_batch_size + ) + ) train_ds, eval_ds, vocab = load_dataset( hps.data_name, batch_size=per_host_batch_size, eval_batch_size=per_host_eval_batch_size, - length=hps.max_target_length) + length=hps.max_target_length, + ) masker = BertMasker(vocab=vocab) @@ -351,14 +369,16 @@ def train_iterator_fn(): def eval_train_epoch(num_batches=None): eval_train_iter = iter(train_ds) for batch_index, batch in enumerate( - itertools.islice(eval_train_iter, num_batches)): + itertools.islice(eval_train_iter, num_batches) + ): batch_rng = jax.random.fold_in(data_rng, batch_index) yield _batch_to_dict(batch, masker, 'eval', batch_rng) def valid_epoch(num_batches=None): valid_iter = iter(eval_ds) for batch_index, batch in enumerate( - itertools.islice(valid_iter, num_batches)): + itertools.islice(valid_iter, num_batches) + ): batch_rng = jax.random.fold_in(data_rng, batch_index) yield _batch_to_dict(batch, masker, 'eval', batch_rng) diff --git a/init2winit/dataset_lib/small_image_datasets.py b/init2winit/dataset_lib/small_image_datasets.py index 34202128..55f524a0 100644 --- a/init2winit/dataset_lib/small_image_datasets.py +++ b/init2winit/dataset_lib/small_image_datasets.py @@ -29,99 +29,121 @@ import tensorflow.compat.v2 as tf import tensorflow_datasets as tfds - -MNIST_HPARAMS = config_dict.ConfigDict(dict( - train_size=50000, - valid_size=10000, - test_size=10000, - input_shape=(28, 28, 1), - output_shape=(10,))) +MNIST_HPARAMS = config_dict.ConfigDict( + dict( + train_size=50000, + valid_size=10000, + test_size=10000, + input_shape=(28, 28, 1), + output_shape=(10,), + ) +) MNIST_METADATA = { 'apply_one_hot_in_loss': False, } -MNIST_AUTOENCODER_HPARAMS = config_dict.ConfigDict(dict( - train_size=50000, - valid_size=10000, - test_size=10000, - input_shape=(28, 28, 1), - output_shape=(784,))) +MNIST_AUTOENCODER_HPARAMS = config_dict.ConfigDict( + dict( + train_size=50000, + valid_size=10000, + test_size=10000, + input_shape=(28, 28, 1), + output_shape=(784,), + ) +) MNIST_AUTOENCODER_METADATA = { 'apply_one_hot_in_loss': False, } -FASHION_MNIST_HPARAMS = config_dict.ConfigDict(dict( - train_size=45000, - valid_size=5000, - test_size=10000, - input_shape=(28, 28, 1), - output_shape=(10,))) +FASHION_MNIST_HPARAMS = config_dict.ConfigDict( + dict( + train_size=45000, + valid_size=5000, + test_size=10000, + input_shape=(28, 28, 1), + output_shape=(10,), + ) +) FASHION_MNIST_METADATA = { 'apply_one_hot_in_loss': False, } -CIFAR10_DEFAULT_HPARAMS = config_dict.ConfigDict(dict( - flip_probability=0.5, - alpha=1.0, - crop_num_pixels=4, - use_mixup=True, - train_size=45000, - valid_size=5000, - test_size=10000, - input_shape=(32, 32, 3), - output_shape=(10,))) +CIFAR10_DEFAULT_HPARAMS = config_dict.ConfigDict( + dict( + flip_probability=0.5, + alpha=1.0, + crop_num_pixels=4, + use_mixup=True, + train_size=45000, + valid_size=5000, + test_size=10000, + input_shape=(32, 32, 3), + output_shape=(10,), + ) +) CIFAR10_METADATA = { 'apply_one_hot_in_loss': False, } -CIFAR100_DEFAULT_HPARAMS = config_dict.ConfigDict(dict( - flip_probability=0.5, - alpha=1.0, - crop_num_pixels=4, - use_mixup=True, - train_size=45000, - valid_size=5000, - test_size=10000, - input_shape=(32, 32, 3), - output_shape=(100,))) +CIFAR100_DEFAULT_HPARAMS = config_dict.ConfigDict( + dict( + flip_probability=0.5, + alpha=1.0, + crop_num_pixels=4, + use_mixup=True, + train_size=45000, + valid_size=5000, + test_size=10000, + input_shape=(32, 32, 3), + output_shape=(100,), + ) +) CIFAR100_METADATA = { 'apply_one_hot_in_loss': False, } -SVHN_NO_EXTRA_DEFAULT_HPARAMS = config_dict.ConfigDict(dict( - flip_probability=0.5, - alpha=1.0, - crop_num_pixels=4, - use_mixup=False, - train_size=73257 - 7000, - valid_size=7000, - test_size=26032, - input_shape=(32, 32, 3), - output_shape=(10,))) +SVHN_NO_EXTRA_DEFAULT_HPARAMS = config_dict.ConfigDict( + dict( + flip_probability=0.5, + alpha=1.0, + crop_num_pixels=4, + use_mixup=False, + train_size=73257 - 7000, + valid_size=7000, + test_size=26032, + input_shape=(32, 32, 3), + output_shape=(10,), + ) +) SVHN_NO_EXTRA_METADATA = { 'apply_one_hot_in_loss': False, } -def _eval_batches(images, - labels, - per_host_batch_size, - num_batches=None, - valid_example_keys=None): +def _eval_batches( + images, + labels, + per_host_batch_size, + num_batches=None, + valid_example_keys=None, +): """Produce a stream of batches for a single evaluation epoch.""" for idx in itertools.islice( - range(0, images.shape[0], per_host_batch_size), num_batches): - inputs = jnp.array(images[idx:idx + per_host_batch_size]) - targets = jnp.array(labels[idx:idx + per_host_batch_size]) + range(0, images.shape[0], per_host_batch_size), num_batches + ): + inputs = jnp.array(images[idx : idx + per_host_batch_size]) + targets = jnp.array(labels[idx : idx + per_host_batch_size]) data_dict = { 'inputs': inputs, 'targets': targets, 'weights': jnp.ones(inputs.shape[0], dtype=inputs.dtype), } if valid_example_keys is not None and len(valid_example_keys) == len( - images): - data_dict['example_key'] = valid_example_keys[idx:idx + - per_host_batch_size] + images + ): + data_dict['example_key'] = valid_example_keys[ + idx : idx + per_host_batch_size + ] data_dict = data_utils.maybe_pad_batch(data_dict, per_host_batch_size) yield data_dict @@ -129,7 +151,7 @@ def _eval_batches(images, def _shard_by_host_id(array): split_size = len(array) // jax.process_count() start = split_size * jax.process_index() - return array[start: start+split_size] + return array[start : start + split_size] def _prepare_small_image_datasets( @@ -146,18 +168,21 @@ def _prepare_small_image_datasets( augment_fn, is_one_hot=True, autoencoder=False, - include_example_keys=False): + include_example_keys=False, +): """Prepare Dataset using tf.data.Datasets of the different splits.""" if autoencoder and is_one_hot: raise ValueError( - 'One hot encoding cannot be applied to autoencoder datasets.') + 'One hot encoding cannot be applied to autoencoder datasets.' + ) eval_image_iterator = functools.partial( data_utils.image_iterator, rescale=rescale, output_shape=output_shape, is_one_hot=is_one_hot, - autoencoder=autoencoder) + autoencoder=autoencoder, + ) # Setup the eval_train split as a copy of the training data, in the form of # the first `num_train_batches` batches of the data as an np.array. @@ -165,24 +190,33 @@ def _prepare_small_image_datasets( num_train_batches = train_size // per_host_batch_size eval_train_data = list( - itertools.islice(eval_train_iterator, 0, num_train_batches)) + itertools.islice(eval_train_iterator, 0, num_train_batches) + ) eval_train_inputs = jnp.array([batch['inputs'] for batch in eval_train_data]) - eval_train_inputs_shape = (num_train_batches * per_host_batch_size, - *input_shape) + eval_train_inputs_shape = ( + num_train_batches * per_host_batch_size, + *input_shape, + ) eval_train_inputs = np.reshape(eval_train_inputs, eval_train_inputs_shape) eval_train_targets = jnp.array( - [batch['targets'] for batch in eval_train_data]) - eval_train_output_shape = (num_train_batches * per_host_batch_size, - *output_shape) + [batch['targets'] for batch in eval_train_data] + ) + eval_train_output_shape = ( + num_train_batches * per_host_batch_size, + *output_shape, + ) eval_train_targets = np.reshape(eval_train_targets, eval_train_output_shape) valid_inputs = jnp.array([]) valid_targets = jnp.array([]) valid_example_keys = jnp.array([]) if data_valid: - valid_data = next(eval_image_iterator( - data_valid, include_example_keys=include_example_keys)) + valid_data = next( + eval_image_iterator( + data_valid, include_example_keys=include_example_keys + ) + ) valid_inputs = valid_data['inputs'] valid_targets = valid_data['targets'] if include_example_keys: @@ -212,24 +246,25 @@ def _prepare_small_image_datasets( is_one_hot=is_one_hot, autoencoder=autoencoder, shuffle_rng=shuffle_rng, - augment_fn=augment_fn) + augment_fn=augment_fn, + ) eval_train_epoch = functools.partial( _eval_batches, eval_train_inputs, eval_train_targets, - per_host_eval_batch_size) + per_host_eval_batch_size, + ) valid_epoch = functools.partial( _eval_batches, valid_inputs, valid_targets, per_host_eval_batch_size, - valid_example_keys=valid_example_keys) + valid_example_keys=valid_example_keys, + ) test_epoch = functools.partial( - _eval_batches, - test_inputs, - test_targets, - per_host_eval_batch_size) + _eval_batches, test_inputs, test_targets, per_host_eval_batch_size + ) return Dataset(train_iterator_fn, eval_train_epoch, valid_epoch, test_epoch) @@ -247,7 +282,8 @@ def _process_small_tfds_image_ds( shuffle_rng, augment_fn=image_preprocessing.identity_augment, is_one_hot=True, - autoencoder=False): + autoencoder=False, +): """Helper wrapper around tfds which converts the data into the init2winit API. Returns three data generators, train, validation and test. The API of @@ -282,8 +318,8 @@ def _process_small_tfds_image_ds( input_shape: (tuple) Used to check that the data is of the correct shape. output_shape: (tuple) Shape of network output. shuffle_rng: jax.random.PRNGKey - augment_fn: Function with API (rng, images, labels) -> images, labels. - This function will be applied to every training batch. + augment_fn: Function with API (rng, images, labels) -> images, labels. This + function will be applied to every training batch. is_one_hot: (bool) If true, targets are one hot encoded. autoencoder: (bool) If true, targets are set to input images. @@ -293,37 +329,57 @@ def _process_small_tfds_image_ds( augment_fn = jax.jit(augment_fn) train_split = tfds.core.ReadInstruction('train', to=train_size, unit='abs') - data_train = tfds.load(name=dataset_name, split=train_split, - as_dataset_kwargs={'shuffle_files': False}) + data_train = tfds.load( + name=dataset_name, + split=train_split, + as_dataset_kwargs={'shuffle_files': False}, + ) data_train = data_train.cache() # Ensure a different shuffle of the training data on each host. shuffle_rng = jax.random.fold_in(shuffle_rng, jax.process_index()) data_train = data_train.shuffle( train_size, reshuffle_each_iteration=True, - seed=data_utils.convert_jax_to_tf_random_seed(shuffle_rng)) + seed=data_utils.convert_jax_to_tf_random_seed(shuffle_rng), + ) data_train = data_train.repeat() data_train = data_train.batch(per_host_batch_size) if valid_size > 0: valid_split = tfds.core.ReadInstruction( - 'train', from_=-valid_size, unit='abs') - data_valid = tfds.load(name=dataset_name, split=valid_split, - as_dataset_kwargs={'shuffle_files': False}) + 'train', from_=-valid_size, unit='abs' + ) + data_valid = tfds.load( + name=dataset_name, + split=valid_split, + as_dataset_kwargs={'shuffle_files': False}, + ) data_valid = data_valid.batch(valid_size) else: data_valid = None - data_test = tfds.load(name=dataset_name, split='test', - as_dataset_kwargs={'shuffle_files': False}) + data_test = tfds.load( + name=dataset_name, + split='test', + as_dataset_kwargs={'shuffle_files': False}, + ) data_test = data_test.batch(test_size) - return _prepare_small_image_datasets(data_train, data_valid, data_test, - per_host_batch_size, - per_host_eval_batch_size, train_size, - rescale, input_shape, output_shape, - shuffle_rng, augment_fn, - is_one_hot, autoencoder) + return _prepare_small_image_datasets( + data_train, + data_valid, + data_test, + per_host_batch_size, + per_host_eval_batch_size, + train_size, + rescale, + input_shape, + output_shape, + shuffle_rng, + augment_fn, + is_one_hot, + autoencoder, + ) def _load_deterministic_with_custom_validation( @@ -338,7 +394,8 @@ def _load_deterministic_with_custom_validation( output_shape, shuffle_rng, augment_fn=image_preprocessing.identity_augment, - include_example_keys=False): + include_example_keys=False, +): """Load a small image dataset with a deterministic validation set. This allows users to deterministically load a validation set that is sampled @@ -375,12 +432,15 @@ def _load_deterministic_with_custom_validation( train_ds = tfds.load( dataset_name, split=tfds.core.ReadInstruction( - 'train', from_=0, to=train_size, unit='abs')) + 'train', from_=0, to=train_size, unit='abs' + ), + ) data_train = train_ds.shuffle( train_size, reshuffle_each_iteration=True, - seed=int(jax.random.randint(shuffle_rng, (), 0, 1000))) + seed=int(jax.random.randint(shuffle_rng, (), 0, 1000)), + ) data_train = data_train.repeat() data_train = data_train.batch(per_host_batch_size) @@ -389,7 +449,9 @@ def _load_deterministic_with_custom_validation( dataset_name, read_config=read_config, split=tfds.core.ReadInstruction( - 'train', from_=train_size, to=train_size + valid_size, unit='abs')) + 'train', from_=train_size, to=train_size + valid_size, unit='abs' + ), + ) data_valid = valid_ds.batch(valid_size) else: valid_ds = tf.data.Dataset.from_tensor_slices([]) @@ -398,7 +460,8 @@ def _load_deterministic_with_custom_validation( data_test = tfds.load( name=dataset_name, split='test', - as_dataset_kwargs={'shuffle_files': False}) + as_dataset_kwargs={'shuffle_files': False}, + ) data_test = data_test.batch(test_size) return _prepare_small_image_datasets( @@ -413,7 +476,8 @@ def _load_deterministic_with_custom_validation( output_shape, shuffle_rng, augment_fn, - include_example_keys=include_example_keys) + include_example_keys=include_example_keys, + ) def get_mnist(shuffle_rng, batch_size, eval_batch_size, hps=None): @@ -439,19 +503,25 @@ def get_mnist(shuffle_rng, batch_size, eval_batch_size, hps=None): size yielded from valid_epoch() and test_epoch(). hps: Hparams object. hps.train_size, hps.valid_size, and hps.test_size will specify the sizes of the various data splits. + Returns: train_epoch, valid_epoch, test_epoch: three generators. """ per_host_batch_size = batch_size // jax.process_count() per_host_eval_batch_size = eval_batch_size // jax.process_count() rescale = lambda x: x / 255.0 - return _process_small_tfds_image_ds('mnist', - per_host_batch_size, - per_host_eval_batch_size, - hps.train_size, hps.valid_size, - hps.test_size, rescale, - hps.input_shape, - hps.output_shape, shuffle_rng) + return _process_small_tfds_image_ds( + 'mnist', + per_host_batch_size, + per_host_eval_batch_size, + hps.train_size, + hps.valid_size, + hps.test_size, + rescale, + hps.input_shape, + hps.output_shape, + shuffle_rng, + ) def get_mnist_autoencoder(shuffle_rng, batch_size, eval_batch_size, hps=None): @@ -484,19 +554,23 @@ def get_mnist_autoencoder(shuffle_rng, batch_size, eval_batch_size, hps=None): per_host_batch_size = batch_size // jax.process_count() per_host_eval_batch_size = eval_batch_size // jax.process_count() rescale = lambda x: x / 255.0 - return _process_small_tfds_image_ds('mnist', - per_host_batch_size, - per_host_eval_batch_size, hps.train_size, - hps.valid_size, hps.test_size, rescale, - hps.input_shape, hps.output_shape, - shuffle_rng, - is_one_hot=False, autoencoder=True) - - -def get_fashion_mnist(shuffle_rng, - batch_size, - eval_batch_size, - hps=None): + return _process_small_tfds_image_ds( + 'mnist', + per_host_batch_size, + per_host_eval_batch_size, + hps.train_size, + hps.valid_size, + hps.test_size, + rescale, + hps.input_shape, + hps.output_shape, + shuffle_rng, + is_one_hot=False, + autoencoder=True, + ) + + +def get_fashion_mnist(shuffle_rng, batch_size, eval_batch_size, hps=None): """Returns generators for the Fashion MNIST train, validation, and test set. Returns three data generators, train, validation and test. The API of @@ -519,25 +593,28 @@ def get_fashion_mnist(shuffle_rng, size yielded from valid_epoch() and test_epoch(). hps: Hparams object. hps.train_size, hps.valid_size, and hps.test_size will specify the sizes of the various data splits. + Returns: train_epoch, valid_epoch, test_epoch: three generators. """ per_host_batch_size = batch_size // jax.process_count() per_host_eval_batch_size = eval_batch_size // jax.process_count() rescale = lambda x: x / 255.0 - return _process_small_tfds_image_ds('fashion_mnist', - per_host_batch_size, - per_host_eval_batch_size, - hps.train_size, hps.valid_size, - hps.test_size, rescale, - hps.input_shape, - hps.output_shape, shuffle_rng) - - -def get_cifar10(shuffle_rng, - batch_size, - eval_batch_size, - hps=None): + return _process_small_tfds_image_ds( + 'fashion_mnist', + per_host_batch_size, + per_host_eval_batch_size, + hps.train_size, + hps.valid_size, + hps.test_size, + rescale, + hps.input_shape, + hps.output_shape, + shuffle_rng, + ) + + +def get_cifar10(shuffle_rng, batch_size, eval_batch_size, hps=None): """Returns generators for the CIFAR10 train, validation, and test set. Returns three data generators, train, validation and test. The API of @@ -578,8 +655,9 @@ def get_cifar10(shuffle_rng, augment_fn = functools.partial(image_preprocessing.augment_cifar10, hps=hps) if hps.train_size + hps.valid_size > 50000: - raise ValueError('The sum of train_size and valid_size should not exceed ' - '50k.') + raise ValueError( + 'The sum of train_size and valid_size should not exceed 50k.' + ) if hps.test_size > 10000: raise ValueError('test_size should not exceed 10k') @@ -595,13 +673,11 @@ def get_cifar10(shuffle_rng, hps.output_shape, shuffle_rng, augment_fn, - include_example_keys=hps.get('include_example_keys', False)) + include_example_keys=hps.get('include_example_keys', False), + ) -def get_cifar100(shuffle_rng, - batch_size, - eval_batch_size, - hps=None): +def get_cifar100(shuffle_rng, batch_size, eval_batch_size, hps=None): """Returns generators for the CIFAR100 train, validation, and test set. Returns three data generators, train, validation and test. The API of @@ -625,32 +701,36 @@ def get_cifar100(shuffle_rng, hps: Hparams object. hps.train_size, hps.valid_size, and hps.test_size will specify the sizes of the various data splits. See image_preprocessing for hparams that control the data augmentation. + Returns: train_epoch, valid_epoch, test_epoch: three generators. """ per_host_batch_size = batch_size // jax.process_count() per_host_eval_batch_size = eval_batch_size // jax.process_count() - mean = jnp.array([129.3041658, 124.06996185, 112.4340492])[None, None, - None, :] + mean = jnp.array([129.3041658, 124.06996185, 112.4340492])[ + None, None, None, : + ] std = jnp.array([68.17024395, 65.3918073, 70.41836985])[None, None, None, :] rescale = lambda x: (x - mean) / std augment_fn = functools.partial(image_preprocessing.augment_cifar10, hps=hps) - return _process_small_tfds_image_ds('cifar100', - per_host_batch_size, - per_host_eval_batch_size, - hps.train_size, hps.valid_size, - hps.test_size, rescale, - hps.input_shape, - hps.output_shape, shuffle_rng, augment_fn) + return _process_small_tfds_image_ds( + 'cifar100', + per_host_batch_size, + per_host_eval_batch_size, + hps.train_size, + hps.valid_size, + hps.test_size, + rescale, + hps.input_shape, + hps.output_shape, + shuffle_rng, + augment_fn, + ) # TODO(znado): add version that also has the "extra" training examples. -def get_svhn_no_extra( - shuffle_rng, - batch_size, - eval_batch_size, - hps=None): +def get_svhn_no_extra(shuffle_rng, batch_size, eval_batch_size, hps=None): """Returns generators for the SVHN train, validation, and test set. Note that we do not include the "extra" 531131 examples which are sometimes @@ -681,6 +761,7 @@ def get_svhn_no_extra( hps: Hparams object. hps.train_size, hps.valid_size, and hps.test_size will specify the sizes of the various data splits. See image_preprocessing for hparams that control the data augmentation. + Returns: train_epoch, valid_epoch, test_epoch: three generators. """ @@ -699,4 +780,5 @@ def get_svhn_no_extra( hps.input_shape, hps.output_shape, shuffle_rng, - augment_fn=augment_fn) + augment_fn=augment_fn, + ) diff --git a/init2winit/dataset_lib/test_data_utils.py b/init2winit/dataset_lib/test_data_utils.py index fdbce54e..960ed9c7 100644 --- a/init2winit/dataset_lib/test_data_utils.py +++ b/init2winit/dataset_lib/test_data_utils.py @@ -70,7 +70,8 @@ def test_padding(self, image_format, batch_axis, input_shape): """Test that the shape is the expected padded shape.""" batch = {'inputs': np.ones(input_shape)} padded_batch = data_utils.maybe_pad_batch( - batch, desired_batch_size, image_format) + batch, desired_batch_size, image_format + ) expected_shapes = list(input_shape) expected_shapes[batch_axis] = desired_batch_size self.assertEqual(padded_batch['inputs'].shape, tuple(expected_shapes)) @@ -92,7 +93,8 @@ def test_padding_seq2seq(self): expected_targets_shape = (desired_batch_size, target_len_max) expected_weights_shape = (desired_batch_size, target_len_max) padded_batch = data_utils.maybe_pad_batch( - batch, desired_batch_size, data_format=None, mask_key='targets') + batch, desired_batch_size, data_format=None, mask_key='targets' + ) self.assertEqual(padded_batch['inputs'].shape, expected_inputs_shape) self.assertEqual(padded_batch['targets'].shape, expected_targets_shape) self.assertEqual(padded_batch['weights'].shape, expected_weights_shape) @@ -104,7 +106,8 @@ def test_padding_seq2seq(self): # # pad at sequence_len axis expected_weights_array[:, target_len_true:] = 0 self.assertTrue( - np.array_equal(padded_batch['weights'], expected_weights_array)) + np.array_equal(padded_batch['weights'], expected_weights_array) + ) class CachedIteratorFactoryTest(absltest.TestCase): diff --git a/init2winit/dataset_lib/test_datasets.py b/init2winit/dataset_lib/test_datasets.py index 1e9866a7..9baedca8 100644 --- a/init2winit/dataset_lib/test_datasets.py +++ b/init2winit/dataset_lib/test_datasets.py @@ -27,7 +27,6 @@ import numpy as np import tensorflow_datasets as tfds - FLAGS = flags.FLAGS @@ -57,91 +56,124 @@ def test_determinism(self, ds): shuffle_rng=jax.random.PRNGKey(0), batch_size=batch_size, eval_batch_size=eval_batch_size, - hps=hps) + hps=hps, + ) dataset_copy = dataset_builder( shuffle_rng=jax.random.PRNGKey(0), batch_size=batch_size, eval_batch_size=eval_batch_size, - hps=hps) + hps=hps, + ) batch_idx_to_test = 1 - saved_batch = next(itertools.islice( - dataset.train_iterator_fn(), batch_idx_to_test, - batch_idx_to_test + 1)) - saved_batch_same_epoch = next(itertools.islice( - dataset_copy.train_iterator_fn(), batch_idx_to_test, - batch_idx_to_test + 1)) - saved_batch_diff_epoch = next(itertools.islice( - dataset.train_iterator_fn(), batch_idx_to_test + 3, - batch_idx_to_test + 4)) - - saved_batch_eval = next(itertools.islice( - dataset.valid_epoch(), batch_idx_to_test, - batch_idx_to_test + 1)) + saved_batch = next( + itertools.islice( + dataset.train_iterator_fn(), + batch_idx_to_test, + batch_idx_to_test + 1, + ) + ) + saved_batch_same_epoch = next( + itertools.islice( + dataset_copy.train_iterator_fn(), + batch_idx_to_test, + batch_idx_to_test + 1, + ) + ) + saved_batch_diff_epoch = next( + itertools.islice( + dataset.train_iterator_fn(), + batch_idx_to_test + 3, + batch_idx_to_test + 4, + ) + ) + + saved_batch_eval = next( + itertools.islice( + dataset.valid_epoch(), batch_idx_to_test, batch_idx_to_test + 1 + ) + ) saved_batch_eval_same_epoch = next( - itertools.islice(dataset_copy.valid_epoch(), batch_idx_to_test, - batch_idx_to_test + 1)) + itertools.islice( + dataset_copy.valid_epoch(), batch_idx_to_test, batch_idx_to_test + 1 + ) + ) self.assertTrue( - jnp.array_equal(saved_batch['inputs'], - saved_batch_same_epoch['inputs'])) + jnp.array_equal(saved_batch['inputs'], saved_batch_same_epoch['inputs']) + ) self.assertTrue( - jnp.array_equal(saved_batch['targets'], - saved_batch_same_epoch['targets'])) + jnp.array_equal( + saved_batch['targets'], saved_batch_same_epoch['targets'] + ) + ) self.assertFalse( - jnp.array_equal(saved_batch['inputs'], - saved_batch_diff_epoch['inputs'])) + jnp.array_equal(saved_batch['inputs'], saved_batch_diff_epoch['inputs']) + ) self.assertFalse( - jnp.array_equal(saved_batch['targets'], - saved_batch_diff_epoch['targets'])) + jnp.array_equal( + saved_batch['targets'], saved_batch_diff_epoch['targets'] + ) + ) self.assertTrue( - jnp.array_equal(saved_batch_eval['inputs'], - saved_batch_eval_same_epoch['inputs'])) + jnp.array_equal( + saved_batch_eval['inputs'], saved_batch_eval_same_epoch['inputs'] + ) + ) # Check shapes - expected_shape = jnp.array([ - batch_size, hps.input_shape[0], hps.input_shape[1], hps.input_shape[2] - ]) + expected_shape = jnp.array( + [batch_size, hps.input_shape[0], hps.input_shape[1], hps.input_shape[2]] + ) expected_shape_eval = jnp.array([ - eval_batch_size, hps.input_shape[0], - hps.input_shape[1], hps.input_shape[2], + eval_batch_size, + hps.input_shape[0], + hps.input_shape[1], + hps.input_shape[2], ]) self.assertTrue( - jnp.array_equal(saved_batch['inputs'].shape, expected_shape)) + jnp.array_equal(saved_batch['inputs'].shape, expected_shape) + ) self.assertTrue( - jnp.array_equal(saved_batch_eval['inputs'].shape, expected_shape_eval)) + jnp.array_equal(saved_batch_eval['inputs'].shape, expected_shape_eval) + ) expected_target_shape = jnp.array( - [batch_size, get_dataset_hparams(ds)['output_shape'][-1]]) - self.assertTrue(jnp.array_equal(saved_batch['targets'].shape, - expected_target_shape)) + [batch_size, get_dataset_hparams(ds)['output_shape'][-1]] + ) + self.assertTrue( + jnp.array_equal(saved_batch['targets'].shape, expected_target_shape) + ) # Check that the training gen drops the last partial batch. drop_partial_batches = list( - itertools.islice(dataset.train_iterator_fn(), 0, 2)) + itertools.islice(dataset.train_iterator_fn(), 0, 2) + ) # Check that the validation set correctly pads the final partial batch. no_drop_partial_batches = list(dataset.test_epoch(num_batches=3)) self.assertLen(drop_partial_batches, 2) self.assertLen(no_drop_partial_batches, 3) expected_shape = jnp.array([ - 80 % batch_size, hps.input_shape[0], - hps.input_shape[1], hps.input_shape[2], + 80 % batch_size, + hps.input_shape[0], + hps.input_shape[1], + hps.input_shape[2], ]) self.assertTrue( - jnp.array_equal(no_drop_partial_batches[2]['inputs'].shape, - expected_shape)) + jnp.array_equal( + no_drop_partial_batches[2]['inputs'].shape, expected_shape + ) + ) # We expect the partial batch to have 40 % 16 = 8 non padded inputs. self.assertEqual(no_drop_partial_batches[2]['weights'].sum(), 8) # Test number of batches num_batches = 1 - num_generated = len( - [ - b for b in itertools.islice( - dataset.train_iterator_fn(), 0, num_batches) - ]) + num_generated = len([ + b for b in itertools.islice(dataset.train_iterator_fn(), 0, num_batches) + ]) self.assertEqual(num_batches, num_generated) diff --git a/init2winit/dataset_lib/test_fineweb_edu_10b_input_pipeline.py b/init2winit/dataset_lib/test_fineweb_edu_10b_input_pipeline.py index 734df008..5be94067 100644 --- a/init2winit/dataset_lib/test_fineweb_edu_10b_input_pipeline.py +++ b/init2winit/dataset_lib/test_fineweb_edu_10b_input_pipeline.py @@ -46,5 +46,6 @@ def test_batch_with_padding(self): self.assertTrue(np.array_equal(padded_batch, padded_batch_expected)) + if __name__ == '__main__': absltest.main() diff --git a/init2winit/dataset_lib/test_ogbg_molpcba.py b/init2winit/dataset_lib/test_ogbg_molpcba.py index 4e0ca9ce..1f5e0788 100644 --- a/init2winit/dataset_lib/test_ogbg_molpcba.py +++ b/init2winit/dataset_lib/test_ogbg_molpcba.py @@ -42,21 +42,18 @@ def _make_graph(num_nodes, num_edges, labels): return { - 'num_edges': - np.array([num_edges]), - 'num_nodes': - np.array([num_nodes]), - 'edge_index': - np.array( - list( - itertools.islice( - itertools.combinations(range(num_nodes), 2), num_edges))), - 'edge_feat': - np.ones((num_edges, 3)).astype('float32'), - 'node_feat': - np.ones((num_nodes, 9)).astype('float32'), - 'labels': - labels + 'num_edges': np.array([num_edges]), + 'num_nodes': np.array([num_nodes]), + 'edge_index': np.array( + list( + itertools.islice( + itertools.combinations(range(num_nodes), 2), num_edges + ) + ) + ), + 'edge_feat': np.ones((num_edges, 3)).astype('float32'), + 'node_feat': np.ones((num_nodes, 9)).astype('float32'), + 'labels': labels, } @@ -67,7 +64,8 @@ def _as_dataset(*args, **kwargs): def get_iter(): return ( _make_graph(num_nodes, num_edges, labels) - for num_nodes, num_edges, labels in zip(NUMS_NODES, NUMS_EDGES, LABELS)) + for num_nodes, num_edges, labels in zip(NUMS_NODES, NUMS_EDGES, LABELS) + ) return tf.data.Dataset.from_generator( get_iter, @@ -78,7 +76,8 @@ def get_iter(): 'node_feat': tf.TensorSpec(shape=(None, 9), dtype=np.float32), 'num_edges': tf.TensorSpec(shape=(1,), dtype=np.int64), 'num_nodes': tf.TensorSpec(shape=(1,), dtype=np.int64), - }) + }, + ) def _get_dataset(shuffle_seed, additional_hps=None): @@ -104,7 +103,8 @@ def _get_dataset(shuffle_seed, additional_hps=None): shuffle_rng=shuffle_seed, batch_size=batch_size, eval_batch_size=eval_batch_size, - hps=hps) + hps=hps, + ) return dataset @@ -126,12 +126,14 @@ def test_get_batch_pads_correctly(self): # The graphs are padded to the right size self.assertEqual(inputs.n_node.shape[0], BATCH_SIZE + 1) self.assertEqual( - np.sum(inputs.n_node), BATCH_SIZE * NODES_SIZE_MULTIPLIER + 1) + np.sum(inputs.n_node), BATCH_SIZE * NODES_SIZE_MULTIPLIER + 1 + ) self.assertEqual(np.sum(inputs.n_edge), BATCH_SIZE * EDGES_SIZE_MULTIPLIER) # Weights are zero at NaN labels and in padded examples - self.assertNDArrayNear(batch['weights'], - np.array([[1, 1], [0, 1], [0, 0]]), 1e-3) + self.assertNDArrayNear( + batch['weights'], np.array([[1, 1], [0, 1], [0, 0]]), 1e-3 + ) self.assertFalse(np.any(np.isnan(batch['targets']))) def test_train_shuffle_is_deterministic(self): @@ -156,30 +158,32 @@ def test_add_virtual_node(self): num_nodes = np.array(NUMS_NODES[0]) num_edges = np.array(NUMS_EDGES[0]) + self.assertNDArrayNear(inputs.n_node[0], np.array(num_nodes + 1), 1e-3) self.assertNDArrayNear( - inputs.n_node[0], np.array(num_nodes + 1), 1e-3) + inputs.n_edge[0], np.array(num_edges + num_nodes), 1e-3 + ) self.assertNDArrayNear( - inputs.n_edge[0], np.array(num_edges + num_nodes), 1e-3) + inputs.edges[num_edges : num_edges + num_nodes], + np.zeros_like(inputs.edges[num_edges : num_edges + num_nodes]), + 1e-3, + ) self.assertNDArrayNear( - inputs.edges[num_edges:num_edges + num_nodes], - np.zeros_like(inputs.edges[num_edges:num_edges + num_nodes]), 1e-3) - self.assertNDArrayNear(inputs.nodes[num_nodes], - np.zeros_like(inputs.nodes[num_nodes]), 1e-3) + inputs.nodes[num_nodes], np.zeros_like(inputs.nodes[num_nodes]), 1e-3 + ) def test_add_bidirectional_edges(self): """Tests that adding bidirectional edges works correctly.""" dataset = _get_dataset( - jax.random.PRNGKey(0), {'add_bidirectional_edges': True}) + jax.random.PRNGKey(0), {'add_bidirectional_edges': True} + ) batch = next(dataset.valid_epoch()) inputs = batch['inputs'] num_nodes = np.array(NUMS_NODES[0]) num_edges = np.array(NUMS_EDGES[0]) - self.assertNDArrayNear( - inputs.n_node[0], np.array(num_nodes), 1e-3) - self.assertNDArrayNear( - inputs.n_edge[0], np.array(num_edges * 2), 1e-3) + self.assertNDArrayNear(inputs.n_node[0], np.array(num_nodes), 1e-3) + self.assertNDArrayNear(inputs.n_edge[0], np.array(num_edges * 2), 1e-3) def test_add_self_loops(self): """Tests that adding self loops works correctly.""" @@ -190,13 +194,15 @@ def test_add_self_loops(self): num_nodes = np.array(NUMS_NODES[0]) num_edges = np.array(NUMS_EDGES[0]) + self.assertNDArrayNear(inputs.n_node[0], np.array(num_nodes), 1e-3) self.assertNDArrayNear( - inputs.n_node[0], np.array(num_nodes), 1e-3) - self.assertNDArrayNear( - inputs.n_edge[0], np.array(num_edges + num_nodes), 1e-3) + inputs.n_edge[0], np.array(num_edges + num_nodes), 1e-3 + ) self.assertNDArrayNear( - inputs.edges[num_edges:num_edges + num_nodes], - np.zeros_like(inputs.edges[num_edges:num_edges + num_nodes]), 1e-3) + inputs.edges[num_edges : num_edges + num_nodes], + np.zeros_like(inputs.edges[num_edges : num_edges + num_nodes]), + 1e-3, + ) if __name__ == '__main__': diff --git a/init2winit/dataset_lib/test_ogbg_static.py b/init2winit/dataset_lib/test_ogbg_static.py index 42871597..336bb1c2 100644 --- a/init2winit/dataset_lib/test_ogbg_static.py +++ b/init2winit/dataset_lib/test_ogbg_static.py @@ -404,5 +404,6 @@ def test_all_augmentations(self): msg=f'weights mismatch with all augmentations in batch {i}', ) + if __name__ == '__main__': tf.test.main() diff --git a/init2winit/dataset_lib/test_small_image_datasets.py b/init2winit/dataset_lib/test_small_image_datasets.py index 74b9a2ca..53e1e0a7 100644 --- a/init2winit/dataset_lib/test_small_image_datasets.py +++ b/init2winit/dataset_lib/test_small_image_datasets.py @@ -29,7 +29,9 @@ class SmallImageDatasetsTest(absltest.TestCase): def test_cifar10(self): """Test example generation in CIFAR10 is reproducible.""" dataset = small_image_datasets.get_cifar10( - random.PRNGKey(0), 1, 1, + random.PRNGKey(0), + 1, + 1, config_dict.ConfigDict( dict( flip_probability=0.5, @@ -41,24 +43,30 @@ def test_cifar10(self): test_size=10000, include_example_keys=True, input_shape=(32, 32, 3), - output_shape=(10,)))) + output_shape=(10,), + ) + ), + ) examples = itertools.islice(dataset.valid_epoch(), 10) example_keys = [ example['example_key'][0].decode('utf-8') for example in examples ] - self.assertEqual(example_keys, [ - 'cifar10-train.array_record-00000-of-00001__45000', - 'cifar10-train.array_record-00000-of-00001__45001', - 'cifar10-train.array_record-00000-of-00001__45002', - 'cifar10-train.array_record-00000-of-00001__45003', - 'cifar10-train.array_record-00000-of-00001__45004', - 'cifar10-train.array_record-00000-of-00001__45005', - 'cifar10-train.array_record-00000-of-00001__45006', - 'cifar10-train.array_record-00000-of-00001__45007', - 'cifar10-train.array_record-00000-of-00001__45008', - 'cifar10-train.array_record-00000-of-00001__45009', - ]) + self.assertEqual( + example_keys, + [ + 'cifar10-train.array_record-00000-of-00001__45000', + 'cifar10-train.array_record-00000-of-00001__45001', + 'cifar10-train.array_record-00000-of-00001__45002', + 'cifar10-train.array_record-00000-of-00001__45003', + 'cifar10-train.array_record-00000-of-00001__45004', + 'cifar10-train.array_record-00000-of-00001__45005', + 'cifar10-train.array_record-00000-of-00001__45006', + 'cifar10-train.array_record-00000-of-00001__45007', + 'cifar10-train.array_record-00000-of-00001__45008', + 'cifar10-train.array_record-00000-of-00001__45009', + ], + ) if __name__ == '__main__': diff --git a/init2winit/dataset_lib/test_wikitext_tokenizer.py b/init2winit/dataset_lib/test_wikitext_tokenizer.py index f289c49e..640696eb 100644 --- a/init2winit/dataset_lib/test_wikitext_tokenizer.py +++ b/init2winit/dataset_lib/test_wikitext_tokenizer.py @@ -52,5 +52,6 @@ def test_tokenizer_vocab_size(self): self.assertEqual(num_unique_tokens, num_unique_words + 2) + if __name__ == '__main__': absltest.main() diff --git a/init2winit/dataset_lib/translate_wmt.py b/init2winit/dataset_lib/translate_wmt.py index f6e7331d..ceaeb8e6 100644 --- a/init2winit/dataset_lib/translate_wmt.py +++ b/init2winit/dataset_lib/translate_wmt.py @@ -70,7 +70,8 @@ # max_eval_target_length, # max_predict_length) # input_shape = [(max_len,), (max_len,)] - )) + ) +) METADATA = { 'apply_one_hot_in_loss': True, @@ -83,10 +84,9 @@ def get_translate_wmt(shuffle_rng, batch_size, eval_batch_size=None, hps=None): per_host_batch_size = batch_size // jax.process_count() per_host_eval_batch_size = eval_batch_size // jax.process_count() - return _get_translate_wmt(per_host_batch_size, - per_host_eval_batch_size, - hps, - shuffle_rng) + return _get_translate_wmt( + per_host_batch_size, per_host_eval_batch_size, hps, shuffle_rng + ) def validate_hparams(hps): @@ -96,26 +96,32 @@ def validate_hparams(hps): if hps.tfds_dataset_keys: # trains multilingual model assert len(hps.tfds_dataset_keys) >= 1 else: - raise ValueError('Either set tfds_dataset_key to train bilingual model' + - 'or set tfds_dataset_keys to train multilingual model') + raise ValueError( + 'Either set tfds_dataset_key to train bilingual model' + + 'or set tfds_dataset_keys to train multilingual model' + ) if hps.rates: assert len(hps.tfds_dataset_keys) == len(hps.rates) -def _get_translate_wmt(per_host_batch_size, - per_host_eval_batch_size, - hps, - shuffle_rng): +def _get_translate_wmt( + per_host_batch_size, per_host_eval_batch_size, hps, shuffle_rng +): """Data generators for wmt translate task.""" n_devices = jax.local_device_count() if per_host_batch_size % n_devices != 0: - raise ValueError('n_devices={} must divide per_host_batch_size={}.'.format( - n_devices, per_host_batch_size)) + raise ValueError( + 'n_devices={} must divide per_host_batch_size={}.'.format( + n_devices, per_host_batch_size + ) + ) if per_host_eval_batch_size % n_devices != 0: raise ValueError( 'n_devices={} must divide per_host_eval_batch_size={}.'.format( - n_devices, per_host_eval_batch_size)) + n_devices, per_host_eval_batch_size + ) + ) validate_hparams(hps) @@ -128,22 +134,21 @@ def _get_translate_wmt(per_host_batch_size, n_devices=jax.local_device_count(), per_host_batch_size=per_host_batch_size, per_host_eval_batch_size=per_host_eval_batch_size, - vocab_path=vocab_path) + vocab_path=vocab_path, + ) def train_iterator_fn(): for batch in iter(train_ds): yield mt_pipeline.maybe_pad_batch( - data_utils.tf_to_numpy(batch), - per_host_batch_size, - mask_key='targets') + data_utils.tf_to_numpy(batch), per_host_batch_size, mask_key='targets' + ) def eval_train_epoch(num_batches=None): eval_train_iter = iter(train_ds) for batch in itertools.islice(eval_train_iter, num_batches): yield mt_pipeline.maybe_pad_batch( - data_utils.tf_to_numpy(batch), - per_host_batch_size, - mask_key='targets') + data_utils.tf_to_numpy(batch), per_host_batch_size, mask_key='targets' + ) def valid_epoch(num_batches=None): valid_iter = iter(eval_ds) @@ -151,7 +156,8 @@ def valid_epoch(num_batches=None): yield mt_pipeline.maybe_pad_batch( data_utils.tf_to_numpy(batch), per_host_eval_batch_size, - mask_key='targets') + mask_key='targets', + ) def test_epoch(num_batches=None): predict_iter = iter(predict_ds) @@ -159,7 +165,8 @@ def test_epoch(num_batches=None): yield mt_pipeline.maybe_pad_batch( data_utils.tf_to_numpy(batch), per_host_eval_batch_size, - mask_key='targets') + mask_key='targets', + ) return Dataset(train_iterator_fn, eval_train_epoch, valid_epoch, test_epoch) @@ -167,31 +174,31 @@ def test_epoch(num_batches=None): def get_fake_batch(hps): """Build fake batch for translate_wmt.""" batch = { - 'inputs': - np.ones((hps.batch_size, hps.max_target_length), - dtype=np.int32), - 'targets': - np.ones((hps.batch_size, hps.max_target_length), - dtype=np.int32), - 'weights': - np.ones((hps.batch_size, hps.max_target_length), - dtype=np.int64), + 'inputs': np.ones( + (hps.batch_size, hps.max_target_length), dtype=np.int32 + ), + 'targets': np.ones( + (hps.batch_size, hps.max_target_length), dtype=np.int32 + ), + 'weights': np.ones( + (hps.batch_size, hps.max_target_length), dtype=np.int64 + ), } if hps.pack_examples: batch.update({ - 'inputs_position': - np.ones((hps.batch_size, hps.max_target_length), - dtype=np.int32), - 'inputs_segmentation': - np.ones((hps.batch_size, hps.max_target_length), - dtype=np.int32), - 'targets_position': - np.ones((hps.batch_size, hps.max_target_length), - dtype=np.int32), - 'targets_segmentation': - np.ones((hps.batch_size, hps.max_target_length), - dtype=np.int32), + 'inputs_position': np.ones( + (hps.batch_size, hps.max_target_length), dtype=np.int32 + ), + 'inputs_segmentation': np.ones( + (hps.batch_size, hps.max_target_length), dtype=np.int32 + ), + 'targets_position': np.ones( + (hps.batch_size, hps.max_target_length), dtype=np.int32 + ), + 'targets_segmentation': np.ones( + (hps.batch_size, hps.max_target_length), dtype=np.int32 + ), }) return batch diff --git a/init2winit/dataset_lib/wikitext103.py b/init2winit/dataset_lib/wikitext103.py index 440826e8..88f1aba4 100644 --- a/init2winit/dataset_lib/wikitext103.py +++ b/init2winit/dataset_lib/wikitext103.py @@ -41,7 +41,8 @@ tokenizer='word', tokenizer_vocab_path=None, vocab_size=input_pipeline.WORD_VOCAB_SIZE, - )) + ) +) METADATA = { @@ -71,7 +72,8 @@ def get_wikitext103( batch_size: int, eval_batch_size: int = None, hps: config_dict.ConfigDict = None, - pad_id: int = PAD_ID) -> Dataset: + pad_id: int = PAD_ID, +) -> Dataset: """Returns Wikitext-103 Dataset. Args: @@ -93,7 +95,9 @@ def get_wikitext103( if batch_size % process_count != 0: raise ValueError( 'process_count={} must divide batch_size={}.'.format( - process_count, batch_size)) + process_count, batch_size + ) + ) if eval_batch_size is None: eval_batch_size = batch_size @@ -101,7 +105,9 @@ def get_wikitext103( if eval_batch_size % process_count != 0: raise ValueError( 'process_count={} must divide batch_size={}.'.format( - process_count, batch_size)) + process_count, batch_size + ) + ) train_dataset, eval_train_dataset, valid_dataset, test_dataset = ( input_pipeline.get_wikitext103_dataset( @@ -129,5 +135,4 @@ def test_epoch(num_batches=None): for batch in itertools.islice(iter(test_dataset), num_batches): yield add_weights_to_batch(data_utils.tf_to_numpy(batch), pad_id) - return Dataset(train_iterator_fn, eval_train_epoch, valid_epoch, - test_epoch) + return Dataset(train_iterator_fn, eval_train_epoch, valid_epoch, test_epoch) diff --git a/init2winit/dataset_lib/wikitext103_input_pipeline.py b/init2winit/dataset_lib/wikitext103_input_pipeline.py index 7e7f1167..fd642172 100644 --- a/init2winit/dataset_lib/wikitext103_input_pipeline.py +++ b/init2winit/dataset_lib/wikitext103_input_pipeline.py @@ -37,12 +37,13 @@ AUTOTUNE = tf.data.experimental.AUTOTUNE -def get_trained_tokenizer(train_dataset: Union[tf.data.Dataset, str], - tokenizer: str, - vocab_path: str = SPM_TOKENIZER_VOCAB_PATH, - vocab_size: int = SPM_TOKENIZER_VOCAB_SIZE, - max_corpus_chars: int = MAX_CORPUS_CHARS, - ) -> tf.data.Dataset: +def get_trained_tokenizer( + train_dataset: Union[tf.data.Dataset, str], + tokenizer: str, + vocab_path: str = SPM_TOKENIZER_VOCAB_PATH, + vocab_size: int = SPM_TOKENIZER_VOCAB_SIZE, + max_corpus_chars: int = MAX_CORPUS_CHARS, +) -> tf.data.Dataset: """Returns a tokenizer trained on the train dataset. Args: @@ -71,11 +72,12 @@ def get_trained_tokenizer(train_dataset: Union[tf.data.Dataset, str], return tokenizer -def batch_with_padding(dataset: tf.data.Dataset, - batch_size, - padded_shapes=None, - padding_id=PAD_ID, - ): +def batch_with_padding( + dataset: tf.data.Dataset, + batch_size, + padded_shapes=None, + padding_id=PAD_ID, +): """Batches a tf.data.Dataset and adds padding if len(dataset) is not divisible by the batch size. Args: @@ -85,14 +87,14 @@ def batch_with_padding(dataset: tf.data.Dataset, padding_id: value for padding, for elements in new batch Returns: - """ batched_dataset = dataset.batch(batch_size, drop_remainder=False) # tf.data.Dataset.padded.batch pads elements in the batch so we call it # again with batch_size=1 to pad each element in original batch. padded_batched_dataset = batched_dataset.padded_batch( - 1, padded_shapes=padded_shapes, padding_values=padding_id) + 1, padded_shapes=padded_shapes, padding_values=padding_id + ) # Remove extra dimension resulting from the batch_size=1. padded_batched_dataset = padded_batched_dataset.unbatch() @@ -129,10 +131,12 @@ def get_wikitext103_dataset( test_text_dataset = tf.data.TextLineDataset(test_path) # Tokenize data - tokenizer = get_trained_tokenizer(train_text_dataset, - hps.tokenizer, - hps.tokenizer_vocab_path, - hps.vocab_size) + tokenizer = get_trained_tokenizer( + train_text_dataset, + hps.tokenizer, + hps.tokenizer_vocab_path, + hps.vocab_size, + ) train_dataset_tokenized = train_text_dataset.map(tokenizer.tokenize) valid_dataset_tokenized = valid_text_dataset.map(tokenizer.tokenize) test_dataset_tokenized = test_text_dataset.map(tokenizer.tokenize) @@ -173,24 +177,29 @@ def get_wikitext103_dataset( # Split the sequences into inputs and targets. train_dataset_sequences = train_dataset_sequences.map( - lambda x: {'inputs': x, 'targets': x}, num_parallel_calls=AUTOTUNE) + lambda x: {'inputs': x, 'targets': x}, num_parallel_calls=AUTOTUNE + ) eval_train_dataset_sequences = eval_train_sequences.map( - lambda x: {'inputs': x, 'targets': x}, num_parallel_calls=AUTOTUNE) + lambda x: {'inputs': x, 'targets': x}, num_parallel_calls=AUTOTUNE + ) valid_dataset_sequences = valid_dataset_sequences.map( - lambda x: {'inputs': x, 'targets': x}, num_parallel_calls=AUTOTUNE) + lambda x: {'inputs': x, 'targets': x}, num_parallel_calls=AUTOTUNE + ) test_dataset_sequences = test_dataset_sequences.map( - lambda x: {'inputs': x, 'targets': x}, num_parallel_calls=AUTOTUNE) + lambda x: {'inputs': x, 'targets': x}, num_parallel_calls=AUTOTUNE + ) # Shuffle the train sequences. train_dataset_sequences = train_dataset_sequences.shuffle( - SHUFFLE_BUFFER_SIZE, seed=shuffle_seed) + SHUFFLE_BUFFER_SIZE, seed=shuffle_seed + ) # Perform batching for training, validation and testing. # Make training data repeat indefinitely. train_dataset_sequences = train_dataset_sequences.repeat() train_dataset = train_dataset_sequences.batch( - train_batch_size, - drop_remainder=False).prefetch(tf.data.experimental.AUTOTUNE) + train_batch_size, drop_remainder=False + ).prefetch(tf.data.experimental.AUTOTUNE) # Use padded batches for eval_train, validation and test_datasets since the # sequences do not repeat indefintely. @@ -199,23 +208,26 @@ def get_wikitext103_dataset( train_batch_size, padded_shapes={ 'inputs': (train_batch_size, None), - 'targets': (train_batch_size, None) - }).prefetch(tf.data.experimental.AUTOTUNE) + 'targets': (train_batch_size, None), + }, + ).prefetch(tf.data.experimental.AUTOTUNE) valid_dataset = batch_with_padding( valid_dataset_sequences, valid_batch_size, padded_shapes={ 'inputs': (valid_batch_size, None), - 'targets': (valid_batch_size, None) - }).prefetch(tf.data.experimental.AUTOTUNE) + 'targets': (valid_batch_size, None), + }, + ).prefetch(tf.data.experimental.AUTOTUNE) test_dataset = batch_with_padding( test_dataset_sequences, test_batch_size, padded_shapes={ 'inputs': (test_batch_size, None), - 'targets': (test_batch_size, None) - }).prefetch(tf.data.experimental.AUTOTUNE) + 'targets': (test_batch_size, None), + }, + ).prefetch(tf.data.experimental.AUTOTUNE) return train_dataset, eval_train_dataset, valid_dataset, test_dataset diff --git a/init2winit/dataset_lib/wikitext2.py b/init2winit/dataset_lib/wikitext2.py index 2eb71b04..0201dc70 100644 --- a/init2winit/dataset_lib/wikitext2.py +++ b/init2winit/dataset_lib/wikitext2.py @@ -37,8 +37,9 @@ vocab_size=VOCAB_SIZE, # TODO(kasimbeg) : add vocab path after seperating out tokenizer # vocab_path=None, - train_size=59676 # Number of sequences. - )) + train_size=59676, # Number of sequences. + ) +) METADATA = { 'apply_one_hot_in_loss': True, @@ -65,7 +66,8 @@ def get_wikitext2( data_rng, batch_size: int, eval_batch_size: int = None, - hps: config_dict.ConfigDict = None,) -> Dataset: + hps: config_dict.ConfigDict = None, +) -> Dataset: """Returns Wikitext-2 Dataset. Args: @@ -86,22 +88,28 @@ def get_wikitext2( if batch_size % process_count != 0: raise ValueError( 'process_count={} must divide batch_size={}.'.format( - process_count, batch_size)) + process_count, batch_size + ) + ) if eval_batch_size % process_count != 0: raise ValueError( 'process_count={} must divide batch_size={}.'.format( - process_count, batch_size)) + process_count, batch_size + ) + ) if eval_batch_size is None: eval_batch_size = batch_size - train_dataset, eval_train_dataset, valid_dataset, test_dataset = input_pipeline.get_wikitext2_dataset( - hps, - train_batch_size=batch_size, - valid_batch_size=eval_batch_size, - test_batch_size=eval_batch_size, - shuffle_seed=data_utils.convert_jax_to_tf_random_seed(data_rng), + train_dataset, eval_train_dataset, valid_dataset, test_dataset = ( + input_pipeline.get_wikitext2_dataset( + hps, + train_batch_size=batch_size, + valid_batch_size=eval_batch_size, + test_batch_size=eval_batch_size, + shuffle_seed=data_utils.convert_jax_to_tf_random_seed(data_rng), + ) ) def train_iterator_fn(): @@ -120,5 +128,4 @@ def test_epoch(num_batches=None): for batch in itertools.islice(iter(test_dataset), num_batches): yield add_weights_to_batch(data_utils.tf_to_numpy(batch)) - return Dataset(train_iterator_fn, eval_train_epoch, valid_epoch, - test_epoch) + return Dataset(train_iterator_fn, eval_train_epoch, valid_epoch, test_epoch) diff --git a/init2winit/dataset_lib/wikitext2_input_pipeline.py b/init2winit/dataset_lib/wikitext2_input_pipeline.py index 79cab2ef..dfc8950f 100644 --- a/init2winit/dataset_lib/wikitext2_input_pipeline.py +++ b/init2winit/dataset_lib/wikitext2_input_pipeline.py @@ -32,17 +32,20 @@ AUTOTUNE = tf.data.experimental.AUTOTUNE -def get_trained_tokenizer(train_dataset: tf.data.Dataset,) -> tf.data.Dataset: +def get_trained_tokenizer( + train_dataset: tf.data.Dataset, +) -> tf.data.Dataset: tokenizer = wikitext_tokenizer.Tokenizer() tokenizer.train(train_dataset) return tokenizer -def batch_with_padding(dataset: tf.data.Dataset, - batch_size, - padded_shapes=None, - padding_id=PAD_ID, - ): +def batch_with_padding( + dataset: tf.data.Dataset, + batch_size, + padded_shapes=None, + padding_id=PAD_ID, +): """Batches a tf.data.Dataset and adds padding if len(dataset) not divisible by the batch size. Args: @@ -52,14 +55,14 @@ def batch_with_padding(dataset: tf.data.Dataset, padding_id: value for padding, for elements in new batch Returns: - """ batched_dataset = dataset.batch(batch_size, drop_remainder=False) # tf.data.Dataset.padded.batch pads elements in the batch so we call it # again with batch_size=1 to pad each element in original batch. padded_batched_dataset = batched_dataset.padded_batch( - 1, padded_shapes=padded_shapes, padding_values=padding_id) + 1, padded_shapes=padded_shapes, padding_values=padding_id + ) # Remove extra dimension resulting from the batch_size=1. padded_batched_dataset = padded_batched_dataset.unbatch() @@ -68,8 +71,11 @@ def batch_with_padding(dataset: tf.data.Dataset, def get_wikitext2_dataset( - hps: config_dict.ConfigDict, train_batch_size: int, valid_batch_size: int, - test_batch_size: int, shuffle_seed: int + hps: config_dict.ConfigDict, + train_batch_size: int, + valid_batch_size: int, + test_batch_size: int, + shuffle_seed: int, ) -> tuple[tf.data.Dataset, tf.data.Dataset, tf.data.Dataset, tf.data.Dataset]: """Returns wikitext-2 dataset. @@ -123,13 +129,17 @@ def get_wikitext2_dataset( # Split the sequences into inputs and targets. train_dataset_sequences = train_dataset_sequences.map( - lambda x: {'inputs': x, 'targets': x}, num_parallel_calls=AUTOTUNE) + lambda x: {'inputs': x, 'targets': x}, num_parallel_calls=AUTOTUNE + ) eval_train_dataset_sequences = eval_train_dataset_sequences.map( - lambda x: {'inputs': x, 'targets': x}, num_parallel_calls=AUTOTUNE) + lambda x: {'inputs': x, 'targets': x}, num_parallel_calls=AUTOTUNE + ) valid_dataset_sequences = valid_dataset_sequences.map( - lambda x: {'inputs': x, 'targets': x}, num_parallel_calls=AUTOTUNE) + lambda x: {'inputs': x, 'targets': x}, num_parallel_calls=AUTOTUNE + ) test_dataset_sequences = test_dataset_sequences.map( - lambda x: {'inputs': x, 'targets': x}, num_parallel_calls=AUTOTUNE) + lambda x: {'inputs': x, 'targets': x}, num_parallel_calls=AUTOTUNE + ) # Shuffle the train sequences. train_dataset_sequences = train_dataset_sequences.shuffle( @@ -149,23 +159,26 @@ def get_wikitext2_dataset( train_batch_size, padded_shapes={ 'inputs': (train_batch_size, None), - 'targets': (train_batch_size, None) - }).prefetch(tf.data.experimental.AUTOTUNE) + 'targets': (train_batch_size, None), + }, + ).prefetch(tf.data.experimental.AUTOTUNE) valid_dataset = batch_with_padding( valid_dataset_sequences, valid_batch_size, padded_shapes={ 'inputs': (valid_batch_size, None), - 'targets': (valid_batch_size, None) - }).prefetch(tf.data.experimental.AUTOTUNE) + 'targets': (valid_batch_size, None), + }, + ).prefetch(tf.data.experimental.AUTOTUNE) test_dataset = batch_with_padding( test_dataset_sequences, test_batch_size, padded_shapes={ 'inputs': (test_batch_size, None), - 'targets': (test_batch_size, None) - }).prefetch(tf.data.experimental.AUTOTUNE) + 'targets': (test_batch_size, None), + }, + ).prefetch(tf.data.experimental.AUTOTUNE) return train_dataset, eval_train_dataset, valid_dataset, test_dataset diff --git a/init2winit/gradient_statistics_callback.py b/init2winit/gradient_statistics_callback.py index 044dc4be..8f728c76 100644 --- a/init2winit/gradient_statistics_callback.py +++ b/init2winit/gradient_statistics_callback.py @@ -13,8 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Callback for computing gradient statistics given set of params. -""" +"""Callback for computing gradient statistics given set of params.""" import functools import itertools @@ -69,9 +68,7 @@ def __init__( self.num_updates = 0 self.orbax_checkpoint_manager = ocp.CheckpointManager( self.save_path, - options=ocp.CheckpointManagerOptions( - max_to_keep=1, create=True - ), + options=ocp.CheckpointManagerOptions(max_to_keep=1, create=True), ) def update(params, batch, batch_stats, dropout_rng): @@ -88,9 +85,7 @@ def opt_cost(params): return grad - params_sharding = jax.tree_util.tree_map( - lambda x: x.sharding, params - ) + params_sharding = jax.tree_util.tree_map(lambda x: x.sharding, params) batch_stats_sharding = nn.get_sharding(batch_stats, self.mesh) self.jitted_update = jax.jit( @@ -98,16 +93,16 @@ def opt_cost(params): in_shardings=( params_sharding, jax.sharding.NamedSharding( - self.mesh, jax.sharding.PartitionSpec('devices')), + self.mesh, jax.sharding.PartitionSpec('devices') + ), batch_stats_sharding, - None + None, ), - out_shardings=(params_sharding) + out_shardings=(params_sharding), ) def run_eval(self, params, batch_stats, optimizer_state, global_step): - """Computes gradient statistics from mini batches over full training data. - """ + """Computes gradient statistics from mini batches over full training data.""" del optimizer_state train_iter = itertools.islice( self.dataset.train_iterator_fn(), self.num_batches_in_training_epoch @@ -149,12 +144,13 @@ def run_eval(self, params, batch_stats, optimizer_state, global_step): state = dict( grad_std=jax.device_get(grad_std), grad_mean=jax.device_get(grad_mean), - step=global_step + step=global_step, ) checkpoint.save_checkpoint( step=global_step, state=state, - orbax_checkpoint_manager=self.orbax_checkpoint_manager) + orbax_checkpoint_manager=self.orbax_checkpoint_manager, + ) return {} diff --git a/init2winit/hyperparameters.py b/init2winit/hyperparameters.py index b7c1b4fb..401eba97 100644 --- a/init2winit/hyperparameters.py +++ b/init2winit/hyperparameters.py @@ -14,6 +14,7 @@ # limitations under the License. """Hyperparameter management logic.""" + import json from typing import Dict @@ -30,18 +31,16 @@ def expand_key(hparams, key_pieces, index, value): """Util to safely expand dotted keys in a dictionary. Args: - hparams: the hparams dictionary containing dotted keys. - key_pieces: - list containing pieces of dotted key. e.g: ['a', 'b', 'c'] for 'a.b.c' - index: - current index being read within key_pieces. - value: - value to be inserted for dotted key. + hparams: the hparams dictionary containing dotted keys. + key_pieces: list containing pieces of dotted key. e.g: ['a', 'b', 'c'] for + 'a.b.c' + index: current index being read within key_pieces. + value: value to be inserted for dotted key. Raises: ValueError: 1) if any prefix of dotted key is not a dictionary - 2) if dotted key overrides a constant value already set in dictionary. + 2) if dotted key overrides a constant value already set in dictionary. """ curr_p = key_pieces[index] @@ -57,8 +56,8 @@ def expand_key(hparams, key_pieces, index, value): raise ValueError( 'prefix = {} already exists with value = {}'.format( '.'.join(key_pieces[: index + 1]), hparams[curr_p] - ) ) + ) else: if curr_p not in hparams: hparams[curr_p] = {} @@ -219,7 +218,8 @@ def build_hparams( logging.warning('Unrecognized top-level hparams: %s', new_keys) if any(k not in allowed_unrecognized_hparams for k in new_keys): raise ValueError( - f'Unrecognized top-level hparams not in allowlist: {new_keys}') + f'Unrecognized top-level hparams not in allowlist: {new_keys}' + ) with merged.unlocked(): merged.update(overrides_dict) else: diff --git a/init2winit/init_lib/initializers.py b/init2winit/init_lib/initializers.py index feb277f7..33ae6054 100644 --- a/init2winit/init_lib/initializers.py +++ b/init2winit/init_lib/initializers.py @@ -42,6 +42,8 @@ def noop( ): """No-op init.""" return params + + # pylint: enable=unused-argument DEFAULT_HPARAMS = config_dict.ConfigDict() diff --git a/init2winit/init_lib/meta_init.py b/init2winit/init_lib/meta_init.py index 400dfb13..638c1fbf 100644 --- a/init2winit/init_lib/meta_init.py +++ b/init2winit/init_lib/meta_init.py @@ -37,20 +37,22 @@ import numpy as np import optax - # Small hparams for quicker tests. -DEFAULT_HPARAMS = config_dict.ConfigDict(dict( - meta_learning_rate=.1, - meta_steps=50, - meta_batch_size=8, - epsilon=1e-5, - meta_momentum=0.5, -)) +DEFAULT_HPARAMS = config_dict.ConfigDict( + dict( + meta_learning_rate=0.1, + meta_steps=50, + meta_batch_size=8, + epsilon=1e-5, + meta_momentum=0.5, + ) +) def _count_params(tree): - return jax.tree_util.tree_reduce(operator.add, - jax.tree.map(lambda x: x.size, tree)) + return jax.tree_util.tree_reduce( + operator.add, jax.tree.map(lambda x: x.size, tree) + ) def scale_params(params, scalars): @@ -93,11 +95,13 @@ def meta_loss(params_to_loss, scalars, normalized_params, epsilon): nparams = _count_params(g) def meta_term(g, hgp): - ratio = (g-hgp) / (g + epsilon * jax.lax.stop_gradient(2*(g >= 0) - 1)) + ratio = (g - hgp) / (g + epsilon * jax.lax.stop_gradient(2 * (g >= 0) - 1)) return jnp.sum(jnp.abs(ratio - 1)) - return jax.tree_util.tree_reduce(operator.add, jax.tree.map( - meta_term, g, hgp)) / nparams + return ( + jax.tree_util.tree_reduce(operator.add, jax.tree.map(meta_term, g, hgp)) + / nparams + ) def normalize(node): @@ -115,16 +119,18 @@ def _get_non_bias_params(params): return bias_and_scalar_keys -def meta_optimize_scales(loss_fn, - fprop, - normalized_params, - norms, - hps, - input_shape, - output_shape, - rng_key, - metrics_logger=None, - log_every=10): +def meta_optimize_scales( + loss_fn, + fprop, + normalized_params, + norms, + hps, + input_shape, + output_shape, + rng_key, + metrics_logger=None, + log_every=10, +): """Implements MetaInit initializer. Args: @@ -148,8 +154,11 @@ def meta_optimize_scales(loss_fn, """ num_outputs = output_shape[-1] if hps.meta_batch_size % jax.device_count() != 0: - raise ValueError('meta_bs: {}, n_devices: {}'.format( - hps.meta_batch_size, jax.device_count())) + raise ValueError( + 'meta_bs: {}, n_devices: {}'.format( + hps.meta_batch_size, jax.device_count() + ) + ) def get_batch(rng_key): """Return a fake batch of data.""" @@ -160,10 +169,15 @@ def get_batch(rng_key): input_key, target_key = jax.random.split(rng_key) inputs = jax.random.normal(input_key, meta_input_shape) - targets = jax.random.randint(target_key, ( - jax.local_device_count(), - hps.meta_batch_size // jax.device_count(), - ), 0, num_outputs) + targets = jax.random.randint( + target_key, + ( + jax.local_device_count(), + hps.meta_batch_size // jax.device_count(), + ), + 0, + num_outputs, + ) targets = jnp.eye(num_outputs)[targets] return (inputs, targets) @@ -174,13 +188,14 @@ def get_batch(rng_key): for key in non_bias_and_scalar_keys: logging.info(key) traversal = flax.traverse_util.ModelParamTraversal( - lambda path, _: path in non_bias_and_scalar_keys) + lambda path, _: path in non_bias_and_scalar_keys + ) # Non-bias, non-scalar norms. meta_params = traversal.update(lambda x: x, norms) meta_opt_init_fn, meta_opt_update_fn = optax.sgd( - learning_rate=hps.meta_learning_rate, - momentum=hps.meta_momentum) + learning_rate=hps.meta_learning_rate, momentum=hps.meta_momentum + ) meta_optimizer_state = meta_opt_init_fn(meta_params) meta_optimizer_state = jax_utils.replicate(meta_optimizer_state) meta_params = jax_utils.replicate(meta_params) @@ -189,9 +204,11 @@ def get_batch(rng_key): @functools.partial(jax.pmap, axis_name='batch') def update(meta_params, optimizer_state, inputs, targets): """Update step.""" + def params_to_loss(params): loss_value, loss_weight = loss_fn( - fprop({'params': params}, inputs, train=True), targets) + fprop({'params': params}, inputs, train=True), targets + ) return loss_value / loss_weight def _meta_loss(params): @@ -202,7 +219,8 @@ def _meta_loss(params): grads = model_utils.cross_device_avg(grads) grads = jax.tree.map(jnp.sign, grads) meta_updates, new_meta_optimizer_state = meta_opt_update_fn( - grads, optimizer_state, params=meta_params) + grads, optimizer_state, params=meta_params + ) new_meta_params = optax.apply_updates(meta_params, meta_updates) return new_meta_params, new_meta_optimizer_state, loss @@ -213,18 +231,19 @@ def _meta_loss(params): inputs, targets = get_batch(batch_rng) meta_params, meta_optimizer_state, loss_value = update( - meta_params, meta_optimizer_state, inputs, targets) + meta_params, meta_optimizer_state, inputs, targets + ) training_curve.append(loss_value) - if (jax.process_index() == 0 and - (i % log_every == 0 or (i + 1) == hps.meta_steps)): + if jax.process_index() == 0 and ( + i % log_every == 0 or (i + 1) == hps.meta_steps + ): end = time.perf_counter() - logging.info('Cumulative time (seconds): %d', end-start) + logging.info('Cumulative time (seconds): %d', end - start) logging.info('meta_init step %d, loss: %f', i, float(loss_value[0])) if metrics_logger is not None: - metrics_logger.append_scalar_metrics({ - 'global_step': i, - 'meta_loss': float(loss_value[0]) - }) + metrics_logger.append_scalar_metrics( + {'global_step': i, 'meta_loss': float(loss_value[0])} + ) # Create a new model with the learned init. learned_norms = jax_utils.unreplicate(meta_params) @@ -234,21 +253,24 @@ def _meta_loss(params): def _log_shape_and_norms(pytree, metrics_logger, key): shape_and_norms = jax.tree.map( lambda x: (str(x.shape), str(np.linalg.norm(x.reshape(-1)))), - unfreeze(pytree)) + unfreeze(pytree), + ) logging.info(json.dumps(shape_and_norms, sort_keys=True, indent=4)) if metrics_logger is not None: metrics_logger.append_json_object({'key': key, 'value': shape_and_norms}) -def meta_init(loss_fn, - flax_module, - params, - hps, - input_shape, - output_shape, - rng_key, - metrics_logger=None, - log_every=10): +def meta_init( + loss_fn, + flax_module, + params, + hps, + input_shape, + output_shape, + rng_key, + metrics_logger=None, + log_every=10, +): """Implements MetaInit initializer. Args: @@ -272,8 +294,7 @@ def meta_init(loss_fn, _log_shape_and_norms(params, metrics_logger, key='init_norms') # First grab the norms of all weights and rescale params to have norm 1. logging.info('Running meta init') - norms = jax.tree.map(lambda node: jnp.linalg.norm(node.reshape(-1)), - params) + norms = jax.tree.map(lambda node: jnp.linalg.norm(node.reshape(-1)), params) normalized_params = jax.tree.map(normalize, params) @@ -287,7 +308,8 @@ def meta_init(loss_fn, output_shape, rng_key, metrics_logger=metrics_logger, - log_every=log_every) + log_every=log_every, + ) new_params = scale_params(normalized_params, learned_norms) if jax.process_index() == 0: diff --git a/init2winit/init_lib/sparse_init.py b/init2winit/init_lib/sparse_init.py index 2a9caef3..0e3ba902 100644 --- a/init2winit/init_lib/sparse_init.py +++ b/init2winit/init_lib/sparse_init.py @@ -26,18 +26,24 @@ from ml_collections.config_dict import config_dict import numpy as np -DEFAULT_HPARAMS = config_dict.ConfigDict(dict(non_zero_connection_weights=15,)) +DEFAULT_HPARAMS = config_dict.ConfigDict( + dict( + non_zero_connection_weights=15, + ) +) -def sparse_init(loss_fn, - flax_module, - params, - hps, - input_shape, - output_shape, - rng_key, - metrics_logger=None, - log_every=10): +def sparse_init( + loss_fn, + flax_module, + params, + hps, + input_shape, + output_shape, + rng_key, + metrics_logger=None, + log_every=10, +): """Implements SparseInit initializer. Args: @@ -73,8 +79,11 @@ def sparse_init(loss_fn, for k in range(num_units_in): if num_units_out > hps.non_zero_connection_weights: non_zero_units_out = jax.random.choice( - rng_keys_in[k], num_units_out, (hps.non_zero_connection_weights,), - replace=False) + rng_keys_in[k], + num_units_out, + (hps.non_zero_connection_weights,), + replace=False, + ) mask[k, non_zero_units_out] = False else: mask[k, :] = False @@ -84,8 +93,11 @@ def sparse_init(loss_fn, for k in range(num_units_out): if num_units_in > hps.non_zero_connection_weights: non_zero_units_in = jax.random.choice( - rng_keys_out[k], num_units_in, (hps.non_zero_connection_weights,), - replace=False) + rng_keys_out[k], + num_units_in, + (hps.non_zero_connection_weights,), + replace=False, + ) mask[non_zero_units_in, k] = False else: mask[:, k] = False diff --git a/init2winit/init_lib/test_initializers.py b/init2winit/init_lib/test_initializers.py index ef836fb6..27eaef20 100644 --- a/init2winit/init_lib/test_initializers.py +++ b/init2winit/init_lib/test_initializers.py @@ -16,6 +16,7 @@ r"""Tests for initializers.py. """ + import copy import functools @@ -31,7 +32,6 @@ import jax.numpy as jnp import jax.tree_util - BATCH_SIZE = 10 OUTPUT_SHAPE = (5,) MODEL_TO_INPUT_SHAPE = { @@ -41,8 +41,10 @@ } -TEST_PARAMETRIZATION = (('test_{}'.format(init), init) - for init in initializers._ALL_INITIALIZERS.keys()) # pylint: disable=protected-access +TEST_PARAMETRIZATION = ( + ('test_{}'.format(init), init) + for init in initializers._ALL_INITIALIZERS.keys() # pylint: disable=protected-access +) def _load_model(model_name): @@ -88,25 +90,27 @@ def test_initializers(self, init): hps=init_hps, input_shape=input_shape[1:], output_shape=OUTPUT_SHAPE, - rng_key=init_rng) + rng_key=init_rng, + ) # Check new params are still valid params outputs = flax_module.apply( - {'params': new_params}, jnp.ones(input_shape), train=True) + {'params': new_params}, jnp.ones(input_shape), train=True + ) utils.log_pytree_shape_and_statistics(new_params) self.assertEqual(outputs.shape, (input_shape[0], OUTPUT_SHAPE[-1])) - @parameterized.named_parameters(('test_{}'.format(model_name), model_name) - for model_name in MODEL_TO_INPUT_SHAPE.keys()) + @parameterized.named_parameters( + ('test_{}'.format(model_name), model_name) + for model_name in MODEL_TO_INPUT_SHAPE.keys() + ) def test_meta_loss(self, model_name): """Test that meta_init does not update the bias scalars.""" rng = jax.random.PRNGKey(0) flax_module, params, input_shape, _ = _load_model(model_name) - norms = jax.tree.map(lambda node: jnp.linalg.norm(node.reshape(-1)), - params) - normalized_params = jax.tree.map(meta_init.normalize, - params) + norms = jax.tree.map(lambda node: jnp.linalg.norm(node.reshape(-1)), params) + normalized_params = jax.tree.map(meta_init.normalize, params) loss_name = 'cross_entropy' loss_fn = losses.get_loss_fn(loss_name) learned_norms, _ = meta_init.meta_optimize_scales( @@ -117,7 +121,8 @@ def test_meta_loss(self, model_name): hps=meta_init.DEFAULT_HPARAMS, input_shape=input_shape[1:], output_shape=OUTPUT_SHAPE, - rng_key=rng) + rng_key=rng, + ) # Check that all learned bias scales are 0, the meta loss should be # independent of these terms. @@ -145,14 +150,16 @@ def test_sparse_init(self): hps=init_hps, input_shape=input_shape[1:], output_shape=OUTPUT_SHAPE, - rng_key=rng) + rng_key=rng, + ) # Check new params are sparse for key in new_params: num_units_in, num_units_out = new_params[key]['kernel'].shape self.assertLess( jnp.count_nonzero(new_params[key]['kernel']), - (num_units_in + num_units_out) * non_zero_connection_weights) + (num_units_in + num_units_out) * non_zero_connection_weights, + ) self.assertEqual(jnp.count_nonzero(new_params[key]['bias']), 0) diff --git a/init2winit/main.py b/init2winit/main.py index c327e9c6..bfd2f54b 100644 --- a/init2winit/main.py +++ b/init2winit/main.py @@ -72,9 +72,12 @@ # Special, has no direct equivalent in config. -flags.DEFINE_string('experiment_dir', None, - 'Path to save weights and other results. Each trial ' - 'directory will have path experiment_dir/worker_id/.') +flags.DEFINE_string( + 'experiment_dir', + None, + 'Path to save weights and other results. Each trial ' + 'directory will have path experiment_dir/worker_id/.', +) FLAGS = flags.FLAGS @@ -152,11 +155,13 @@ def _run( dataset_builder = datasets.get_dataset(dataset_name) data_selector = datasets.get_data_selector(data_selector_name) dataset_meta_data = datasets.get_dataset_meta_data(dataset_name) - input_pipeline_hps = config_dict.ConfigDict(dict( - num_tf_data_prefetches=num_tf_data_prefetches, - num_device_prefetches=num_device_prefetches, - num_tf_data_map_parallel_calls=num_tf_data_map_parallel_calls, - )) + input_pipeline_hps = config_dict.ConfigDict( + dict( + num_tf_data_prefetches=num_tf_data_prefetches, + num_device_prefetches=num_device_prefetches, + num_tf_data_map_parallel_calls=num_tf_data_map_parallel_calls, + ) + ) training_algorithm_class = ( training_algorithms_registry.get_training_algorithm( training_algorithm_name diff --git a/init2winit/model_lib/adabelief_densenet.py b/init2winit/model_lib/adabelief_densenet.py index 9b06ac5e..96b60f0b 100644 --- a/init2winit/model_lib/adabelief_densenet.py +++ b/init2winit/model_lib/adabelief_densenet.py @@ -25,7 +25,6 @@ The original DenseNet paper can be found here: https://arxiv.org/abs/1608.06993?source=post_page--------------------------- - """ import functools @@ -37,7 +36,6 @@ import jax.numpy as jnp from ml_collections.config_dict import config_dict - DEFAULT_HPARAMS = config_dict.ConfigDict( dict( num_layers=121, # Must be one of [121, 169, 201, 161] @@ -65,6 +63,7 @@ class BottleneckBlock(nn.Module): preceded by 1x1 convoluational operation (and the correpsonding batch normalization and ReLU). """ + growth_rate: int dtype: model_utils.Dtype = jnp.float32 normalizer: str = 'batch_norm' @@ -84,7 +83,8 @@ def __call__(self, x, train): features=self.growth_rate, kernel_size=(3, 3), padding=((1, 1), (1, 1)), - name='conv2')(y) + name='conv2', + )(y) # Concatenate the output and input along the features dimension. y = jnp.concatenate([y, x], axis=3) @@ -97,6 +97,7 @@ class TransitionBlock(nn.Module): Downsampling is achieved by a 1x1 convoluationl layer (with the associated batch norm and ReLU) and a 2x2 average pooling layer. """ + num_features: int use_kernel_size_as_stride_in_pooling: bool dtype: model_utils.Dtype = jnp.float32 @@ -113,7 +114,8 @@ def __call__(self, x, train): y = nn.avg_pool( y, window_shape=(2, 2), - strides=(2, 2) if self.use_kernel_size_as_stride_in_pooling else (1, 1)) + strides=(2, 2) if self.use_kernel_size_as_stride_in_pooling else (1, 1), + ) return y @@ -123,6 +125,7 @@ class DenseNet(nn.Module): The network consists of an inital convolutaional layer, four dense blocks connected by transition blocks, a pooling layer and a classification layer. """ + num_layers: int num_outputs: int growth_rate: int @@ -153,25 +156,25 @@ def update_num_features(num_features, num_blocks, growth_rate, reduction): features=num_features, kernel_size=(3, 3), padding=((1, 1), (1, 1)), - name='conv1')(x) + name='conv1', + )(x) # Internal dense and transtion blocks num_blocks = _block_size_options[self.num_layers] block = functools.partial( - BottleneckBlock, - dtype=self.dtype, - normalizer=self.normalizer) + BottleneckBlock, dtype=self.dtype, normalizer=self.normalizer + ) for i in range(3): y = dense_layers(y, block, num_blocks[i], self.growth_rate) - num_features = update_num_features(num_features, num_blocks[i], - self.growth_rate, self.reduction) + num_features = update_num_features( + num_features, num_blocks[i], self.growth_rate, self.reduction + ) y = TransitionBlock( num_features, dtype=self.dtype, normalizer=self.normalizer, - use_kernel_size_as_stride_in_pooling=self - .use_kernel_size_as_stride_in_pooling)( - y, train=train) + use_kernel_size_as_stride_in_pooling=self.use_kernel_size_as_stride_in_pooling, + )(y, train=train) # Final dense block y = dense_layers(y, block, num_blocks[3], self.growth_rate) @@ -183,13 +186,15 @@ def update_num_features(num_features, num_blocks, growth_rate, reduction): y = nn.avg_pool( y, window_shape=(4, 4), - strides=(4, 4) if self.use_kernel_size_as_stride_in_pooling else (1, 1)) + strides=(4, 4) if self.use_kernel_size_as_stride_in_pooling else (1, 1), + ) # Classification layer y = jnp.reshape(y, (y.shape[0], -1)) if self.normalize_classifier_input: maybe_normalize = model_utils.get_normalizer( - self.normalize_classifier_input, train) + self.normalize_classifier_input, train + ) y = maybe_normalize()(y) y = y * self.classification_scale_factor @@ -218,8 +223,7 @@ def build_flax_module(self): num_outputs=self.hps['output_shape'][-1], growth_rate=self.hps.growth_rate, reduction=self.hps.reduction, - use_kernel_size_as_stride_in_pooling=self.hps - .use_kernel_size_as_stride_in_pooling, + use_kernel_size_as_stride_in_pooling=self.hps.use_kernel_size_as_stride_in_pooling, dtype=self.hps.model_dtype, normalizer=self.hps.normalizer, normalize_classifier_input=self.hps.normalize_classifier_input, diff --git a/init2winit/model_lib/adabelief_resnet.py b/init2winit/model_lib/adabelief_resnet.py index 6954e771..b027b611 100644 --- a/init2winit/model_lib/adabelief_resnet.py +++ b/init2winit/model_lib/adabelief_resnet.py @@ -30,7 +30,6 @@ https://arxiv.org/abs/2010.07468 https://github.com/juntang-zhuang/Adabelief-Optimizer/blob/update_0.1.0/PyTorch_Experiments/classification_cifar10/models/resnet.py - """ import functools @@ -44,7 +43,6 @@ import jax.numpy as jnp from ml_collections.config_dict import config_dict - DEFAULT_HPARAMS = config_dict.ConfigDict( dict( num_filters=16, @@ -55,11 +53,13 @@ model_dtype='float32', virtual_batch_size=None, data_format='NHWC', - )) + ) +) class BasicResidualBlock(nn.Module): """Basic ResNet block.""" + filters: int strides: Tuple[int, int] = (1, 1) dtype: model_utils.Dtype = jnp.float32 @@ -81,16 +81,18 @@ def __call__(self, x, train): batch_size=self.batch_size, virtual_batch_size=self.virtual_batch_size, total_batch_size=self.total_batch_size, - data_format=self.data_format) + data_format=self.data_format, + ) conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype) residual = x if needs_projection: residual = conv( - self.filters, (1, 1), self.strides, 'VALID', - name='proj_conv')(residual) + self.filters, (1, 1), self.strides, 'VALID', name='proj_conv' + )(residual) residual = batch_norm(name='proj_bn')( - residual, use_running_average=not train) + residual, use_running_average=not train + ) y = conv(self.filters, (3, 3), self.strides, 'SAME', name='conv1')(x) y = batch_norm(name='bn1')(y, use_running_average=not train) @@ -103,6 +105,7 @@ def __call__(self, x, train): class BottleneckResidualBlock(nn.Module): """Bottleneck ResNet block.""" + filters: int strides: Tuple[int, int] = (1, 1) dtype: model_utils.Dtype = jnp.float32 @@ -124,15 +127,18 @@ def __call__(self, x, train): batch_size=self.batch_size, virtual_batch_size=self.virtual_batch_size, total_batch_size=self.total_batch_size, - data_format=self.data_format) + data_format=self.data_format, + ) conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype) residual = x if needs_projection: - residual = conv( - self.filters * 4, (1, 1), self.strides, name='proj_conv')(residual) + residual = conv(self.filters * 4, (1, 1), self.strides, name='proj_conv')( + residual + ) residual = batch_norm(name='proj_bn')( - residual, use_running_average=not train) + residual, use_running_average=not train + ) y = conv(self.filters, (1, 1), name='conv1')(x) y = batch_norm(name='bn1')(y, use_running_average=not train) @@ -143,13 +149,15 @@ def __call__(self, x, train): y = conv(self.filters * 4, (1, 1), name='conv3')(y) y = batch_norm(name='bn3', scale_init=nn.initializers.zeros)( - y, use_running_average=not train) + y, use_running_average=not train + ) y = nn.relu(residual + y) return y class ResNet(nn.Module): """Adabelief ResNetV1.""" + num_outputs: int num_filters: int = 64 num_layers: int = 50 @@ -167,11 +175,14 @@ def __call__(self, x, train): raise ValueError('Please provide a valid number of layers') block_sizes = _block_size_options[self.num_layers] x = nn.Conv( - self.num_filters, (3, 3), (1, 1), + self.num_filters, + (3, 3), + (1, 1), 'SAME', use_bias=False, dtype=self.dtype, - name='init_conv')(x) + name='init_conv', + )(x) x = normalization.VirtualBatchNorm( momentum=self.batch_norm_momentum, epsilon=self.batch_norm_epsilon, @@ -180,7 +191,8 @@ def __call__(self, x, train): batch_size=self.batch_size, virtual_batch_size=self.virtual_batch_size, total_batch_size=self.total_batch_size, - data_format=self.data_format)(x, use_running_average=not train) + data_format=self.data_format, + )(x, use_running_average=not train) x = nn.relu(x) residual_block = block_type_options[self.num_layers] for i, block_size in enumerate(block_sizes): @@ -195,7 +207,8 @@ def __call__(self, x, train): batch_size=self.batch_size, virtual_batch_size=self.virtual_batch_size, total_batch_size=self.total_batch_size, - data_format=self.data_format)(x, train=train) + data_format=self.data_format, + )(x, train=train) x = jnp.mean(x, axis=(1, 2)) x = nn.Dense(self.num_outputs, dtype=self.dtype)(x) return x @@ -210,7 +223,7 @@ def __call__(self, x, train): 50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3], - 200: [3, 24, 36, 3] + 200: [3, 24, 36, 3], } block_type_options = { @@ -220,7 +233,7 @@ def __call__(self, x, train): 50: BottleneckResidualBlock, 101: BottleneckResidualBlock, 152: BottleneckResidualBlock, - 200: BottleneckResidualBlock + 200: BottleneckResidualBlock, } @@ -239,7 +252,8 @@ def build_flax_module(self): batch_size=self.hps.batch_size, virtual_batch_size=self.hps.virtual_batch_size, total_batch_size=self.hps.total_accumulated_batch_size, - data_format=self.hps.data_format) + data_format=self.hps.data_format, + ) def get_fake_inputs(self, hps): """Helper method solely for the purpose of initialzing the model.""" diff --git a/init2winit/model_lib/adabelief_vgg.py b/init2winit/model_lib/adabelief_vgg.py index 95422ea1..a02e5367 100644 --- a/init2winit/model_lib/adabelief_vgg.py +++ b/init2winit/model_lib/adabelief_vgg.py @@ -35,7 +35,6 @@ import jax.numpy as jnp from ml_collections.config_dict import config_dict - DEFAULT_HPARAMS = config_dict.ConfigDict( dict( num_layers=11, # Must be one of [11, 13, 16, 19] @@ -76,6 +75,7 @@ def features(x, num_layers, normalizer, dtype, train): class VGG(nn.Module): """Adabelief VGG.""" + num_layers: int num_outputs: int normalizer: str = 'none' @@ -86,7 +86,8 @@ def __call__(self, x, train): x = features(x, self.num_layers, self.normalizer, self.dtype, train) x = jnp.reshape(x, (x.shape[0], -1)) x = classifier( - x, self.num_outputs, dropout_rate=0.5, deterministic=not train) + x, self.num_outputs, dropout_rate=0.5, deterministic=not train + ) return x @@ -96,32 +97,95 @@ def __call__(self, x, train): # letter M indicates a max pooling layer. _layer_size_options = { 1: [ - 8, 'M', 16, 'M', 32, 32, 'M', 64, 64, 'M', 64, 64, 'M' - ], # used for testing only. + 8, + 'M', + 16, + 'M', + 32, + 32, + 'M', + 64, + 64, + 'M', + 64, + 64, + 'M', + ], # used for testing only. 11: [64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'], 13: [ - 64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M' + 64, + 64, + 'M', + 128, + 128, + 'M', + 256, + 256, + 'M', + 512, + 512, + 'M', + 512, + 512, + 'M', ], 16: [ - 64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, - 512, 512, 'M' + 64, + 64, + 'M', + 128, + 128, + 'M', + 256, + 256, + 256, + 'M', + 512, + 512, + 512, + 'M', + 512, + 512, + 512, + 'M', ], 19: [ - 64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M', 512, 512, 512, 512, - 'M', 512, 512, 512, 512, 'M' + 64, + 64, + 'M', + 128, + 128, + 'M', + 256, + 256, + 256, + 256, + 'M', + 512, + 512, + 512, + 512, + 'M', + 512, + 512, + 512, + 512, + 'M', ], } # pylint: disable=[missing-class-docstring] class AdaBeliefVGGModel(base_model.BaseModel): + def build_flax_module(self): """Adabelief VGG.""" return VGG( num_layers=self.hps.num_layers, num_outputs=self.hps['output_shape'][-1], dtype=self.hps.model_dtype, - normalizer=self.hps.normalizer) + normalizer=self.hps.normalizer, + ) def get_fake_inputs(self, hps): """Helper method solely for the purpose of initialzing the model.""" diff --git a/init2winit/model_lib/attention.py b/init2winit/model_lib/attention.py index e57dc466..e3e0fcf0 100644 --- a/init2winit/model_lib/attention.py +++ b/init2winit/model_lib/attention.py @@ -46,7 +46,7 @@ def tag_attention_logits(logits): def tag_attention_probs(softmax_probs): this_module = nn.module._context.module_stack[-1] # pylint: disable=protected-access - num_bad = (softmax_probs > .9999).sum() + num_bad = (softmax_probs > 0.9999).sum() this_module.sow('attn_prob9999', 'attn_prob9999', num_bad) @@ -102,10 +102,8 @@ def _dot_product_attention( dtype = query.dtype assert query.ndim == key.ndim, 'q, k must have same rank.' - assert query.shape[:-3] == key.shape[:-3], ( - 'q, k batch dims must match.') - assert query.shape[-2] == key.shape[-2], ( - 'q, k num_heads must match.') + assert query.shape[:-3] == key.shape[:-3], 'q, k batch dims must match.' + assert query.shape[-2] == key.shape[-2], 'q, k num_heads must match.' assert query.shape[-1] == key.shape[-1], 'q, k depths must match.' # calculate attention matrix @@ -113,8 +111,10 @@ def _dot_product_attention( query = query / jnp.sqrt(depth).astype(dtype) # attn weight shape is (batch..., num_heads, q_length, kv_length) # Note: Very important that attn_temp applied BEFORE the masking operation. - attn_weights = jnp.einsum('...qhd,...khd->...hqk', query, key, - precision=precision) * attn_temp + attn_weights = ( + jnp.einsum('...qhd,...khd->...hqk', query, key, precision=precision) + * attn_temp + ) # apply attention bias: masking, dropout, proximity bias, etc. if bias is not None: @@ -130,7 +130,7 @@ def _dot_product_attention( tag_attention_probs(attn_weights) # apply attention dropout - if not deterministic and dropout_rate > 0.: + if not deterministic and dropout_rate > 0.0: keep_prob = 1.0 - dropout_rate if broadcast_dropout: # dropout is broadcast across the batch + head dimensions @@ -138,27 +138,29 @@ def _dot_product_attention( keep = random.bernoulli(dropout_rng, keep_prob, dropout_shape) else: keep = random.bernoulli(dropout_rng, keep_prob, attn_weights.shape) - multiplier = (keep.astype(dtype) / - jnp.asarray(keep_prob, dtype=dtype)) + multiplier = keep.astype(dtype) / jnp.asarray(keep_prob, dtype=dtype) attn_weights = attn_weights * multiplier # return weighted sum over values for each query position - return jnp.einsum('...hqk,...khd->...qhd', attn_weights, value, - precision=precision) - - -def dot_product_attention(query: Array, - key: Array, - value: Array, - bias: Optional[Array] = None, - mask: Optional[Array] = None, - broadcast_dropout: bool = True, - dropout_rng: Optional[PRNGKey] = None, - dropout_rate: float = 0., - deterministic: bool = False, - dtype: Optional[Dtype] = None, - precision: PrecisionLike = None, - attn_temp: float = 1.0): + return jnp.einsum( + '...hqk,...khd->...qhd', attn_weights, value, precision=precision + ) + + +def dot_product_attention( + query: Array, + key: Array, + value: Array, + bias: Optional[Array] = None, + mask: Optional[Array] = None, + broadcast_dropout: bool = True, + dropout_rng: Optional[PRNGKey] = None, + dropout_rate: float = 0.0, + deterministic: bool = False, + dtype: Optional[Dtype] = None, + precision: PrecisionLike = None, + attn_temp: float = 1.0, +): """Computes dot-product attention given query, key, and value. This is the core function for applying attention based on @@ -168,21 +170,19 @@ def dot_product_attention(query: Array, Note: query, key, value needn't have any batch dimensions. Args: - query: queries for calculating attention with shape of - `[batch..., q_length, num_heads, qk_depth_per_head]`. - key: keys for calculating attention with shape of - `[batch..., kv_length, num_heads, qk_depth_per_head]`. - value: values to be used in attention with shape of - `[batch..., kv_length, num_heads, v_depth_per_head]`. + query: queries for calculating attention with shape of `[batch..., q_length, + num_heads, qk_depth_per_head]`. + key: keys for calculating attention with shape of `[batch..., kv_length, + num_heads, qk_depth_per_head]`. + value: values to be used in attention with shape of `[batch..., kv_length, + num_heads, v_depth_per_head]`. bias: bias for the attention weights. This should be broadcastable to the - shape `[batch..., num_heads, q_length, kv_length]`. - This can be used for incorporating causal masks, padding masks, - proximity bias, etc. + shape `[batch..., num_heads, q_length, kv_length]`. This can be used for + incorporating causal masks, padding masks, proximity bias, etc. mask: mask for the attention weights. This should be broadcastable to the - shape `[batch..., num_heads, q_length, kv_length]`. - This can be used for incorporating causal masks. - Attention weights are masked out if their corresponding mask value - is `False`. + shape `[batch..., num_heads, q_length, kv_length]`. This can be used for + incorporating causal masks. Attention weights are masked out if their + corresponding mask value is `False`. broadcast_dropout: bool: use a broadcasted dropout along batch dims. dropout_rng: JAX PRNGKey: to be used for dropout dropout_rate: dropout rate @@ -190,8 +190,8 @@ def dot_product_attention(query: Array, dtype: the dtype of the computation (default: infer from inputs) precision: numerical precision of the computation see `jax.lax.Precision` for details. - attn_temp: Attention logits will be normalized by C / sqrt{d} where C is - the attn_temp. + attn_temp: Attention logits will be normalized by C / sqrt{d} where C is the + attn_temp. Returns: Output of shape `[batch..., q_length, num_heads, v_depth_per_head]`. @@ -227,38 +227,37 @@ def dot_product_attention(query: Array, class MultiHeadDotProductAttention(Module): """Multi-head dot-product attention. - Attributes: - num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1]) - should be divisible by the number of heads. - dtype: the dtype of the computation - (default: infer from inputs and params) - param_dtype: the dtype passed to parameter initializers (default: float32) - qkv_features: dimension of the key, query, and value. - out_features: dimension of the last projection - broadcast_dropout: bool: use a broadcasted dropout along batch dims. - dropout_rate: dropout rate - deterministic: if false, the attention weight is masked randomly - using dropout, whereas if true, the attention weights - are deterministic. - precision: numerical precision of the computation see `jax.lax.Precision` - for details. - kernel_init: initializer for the kernel of the Dense layers. - bias_init: initializer for the bias of the Dense layers. - use_bias: bool: whether pointwise QKVO dense transforms use bias. - attention_fn: dot_product_attention or compatible function. Accepts - query, key, value, and returns output of shape - `[bs, dim1, dim2, ..., dimN,, num_heads, value_channels]`` - decode: whether to prepare and use an autoregressive cache. - normalize_attention: Apply LayerNorm to query and key before computing - dot_product_attention. + Attributes: + num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1]) + should be divisible by the number of heads. + dtype: the dtype of the computation (default: infer from inputs and params) + param_dtype: the dtype passed to parameter initializers (default: float32) + qkv_features: dimension of the key, query, and value. + out_features: dimension of the last projection + broadcast_dropout: bool: use a broadcasted dropout along batch dims. + dropout_rate: dropout rate + deterministic: if false, the attention weight is masked randomly using + dropout, whereas if true, the attention weights are deterministic. + precision: numerical precision of the computation see `jax.lax.Precision` + for details. + kernel_init: initializer for the kernel of the Dense layers. + bias_init: initializer for the bias of the Dense layers. + use_bias: bool: whether pointwise QKVO dense transforms use bias. + attention_fn: dot_product_attention or compatible function. Accepts query, + key, value, and returns output of shape `[bs, dim1, dim2, ..., dimN,, + num_heads, value_channels]`` + decode: whether to prepare and use an autoregressive cache. + normalize_attention: Apply LayerNorm to query and key before computing + dot_product_attention. """ + num_heads: int dtype: Optional[Dtype] = None param_dtype: Dtype = jnp.float32 qkv_features: Optional[int] = None out_features: Optional[int] = None broadcast_dropout: bool = True - dropout_rate: float = 0. + dropout_rate: float = 0.0 deterministic: Optional[bool] = None precision: PrecisionLike = None kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init @@ -270,52 +269,55 @@ class MultiHeadDotProductAttention(Module): attn_temp: float = 1.0 @compact - def __call__(self, - inputs_q: Array, - inputs_kv: Array, - mask: Optional[Array] = None, - deterministic: Optional[bool] = None): + def __call__( + self, + inputs_q: Array, + inputs_kv: Array, + mask: Optional[Array] = None, + deterministic: Optional[bool] = None, + ): """Applies multi-head dot product attention on the input data. Projects the inputs into multi-headed query, key, and value vectors, applies dot-product attention and project the results to an output vector. Args: - inputs_q: input queries of shape - `[batch_sizes..., length, features]`. - inputs_kv: key/values of shape - `[batch_sizes..., length, features]`. - mask: attention mask of shape - `[batch_sizes..., num_heads, query_length, key/value_length]`. - Attention weights are masked out if their corresponding mask value - is `False`. - deterministic: if false, the attention weight is masked randomly - using dropout, whereas if true, the attention weights - are deterministic. + inputs_q: input queries of shape `[batch_sizes..., length, features]`. + inputs_kv: key/values of shape `[batch_sizes..., length, features]`. + mask: attention mask of shape `[batch_sizes..., num_heads, query_length, + key/value_length]`. Attention weights are masked out if their + corresponding mask value is `False`. + deterministic: if false, the attention weight is masked randomly using + dropout, whereas if true, the attention weights are deterministic. Returns: output of shape `[batch_sizes..., length, features]`. """ features = self.out_features or inputs_q.shape[-1] qkv_features = self.qkv_features or inputs_q.shape[-1] - assert qkv_features % self.num_heads == 0, ( - 'Memory dimension must be divisible by number of heads.') + assert ( + qkv_features % self.num_heads == 0 + ), 'Memory dimension must be divisible by number of heads.' head_dim = qkv_features // self.num_heads - dense = functools.partial(DenseGeneral, - axis=-1, - dtype=self.dtype, - param_dtype=self.param_dtype, - features=(self.num_heads, head_dim), - kernel_init=self.kernel_init, - bias_init=self.bias_init, - use_bias=self.use_bias, - precision=self.precision) + dense = functools.partial( + DenseGeneral, + axis=-1, + dtype=self.dtype, + param_dtype=self.param_dtype, + features=(self.num_heads, head_dim), + kernel_init=self.kernel_init, + bias_init=self.bias_init, + use_bias=self.use_bias, + precision=self.precision, + ) # project inputs_q to multi-headed q/k/v # dimensions are then [batch..., length, n_heads, n_features_per_head] - query, key, value = (dense(name='query')(inputs_q), - dense(name='key')(inputs_kv), - dense(name='value')(inputs_kv)) + query, key, value = ( + dense(name='query')(inputs_q), + dense(name='key')(inputs_kv), + dense(name='value')(inputs_kv), + ) if self.normalize_attention: query = nn.LayerNorm(name='query_ln')(query) key = nn.LayerNorm(name='key_ln')(key) @@ -325,21 +327,27 @@ def __call__(self, if self.decode: # detect if we're initializing by absence of existing cache data. is_initialized = self.has_variable('cache', 'cached_key') - cached_key = self.variable('cache', 'cached_key', - jnp.zeros, key.shape, key.dtype) - cached_value = self.variable('cache', 'cached_value', - jnp.zeros, value.shape, value.dtype) - cache_index = self.variable('cache', 'cache_index', - lambda: jnp.array(0, dtype=jnp.int32)) + cached_key = self.variable( + 'cache', 'cached_key', jnp.zeros, key.shape, key.dtype + ) + cached_value = self.variable( + 'cache', 'cached_value', jnp.zeros, value.shape, value.dtype + ) + cache_index = self.variable( + 'cache', 'cache_index', lambda: jnp.array(0, dtype=jnp.int32) + ) if is_initialized: *batch_dims, max_length, num_heads, depth_per_head = ( - cached_key.value.shape) + cached_key.value.shape + ) # shape check of cached keys against query input expected_shape = tuple(batch_dims) + (1, num_heads, depth_per_head) if expected_shape != query.shape: - raise ValueError('Autoregressive cache shape error, ' - 'expected query shape %s instead got %s.' % - (expected_shape, query.shape)) + raise ValueError( + 'Autoregressive cache shape error, ' + 'expected query shape %s instead got %s.' + % (expected_shape, query.shape) + ) # update key, value caches with our new 1d spatial slices cur_index = cache_index.value indices = (0,) * len(batch_dims) + (cur_index, 0, 0) @@ -354,13 +362,19 @@ def __call__(self, # not the remaining zero elements. mask = combine_masks( mask, - jnp.broadcast_to(jnp.arange(max_length) <= cur_index, - tuple(batch_dims) + (1, 1, max_length))) + jnp.broadcast_to( + jnp.arange(max_length) <= cur_index, + tuple(batch_dims) + (1, 1, max_length), + ), + ) dropout_rng = None - if self.dropout_rate > 0.: # Require `deterministic` only if using dropout. - m_deterministic = merge_param('deterministic', self.deterministic, - deterministic) + if ( + self.dropout_rate > 0.0 + ): # Require `deterministic` only if using dropout. + m_deterministic = merge_param( + 'deterministic', self.deterministic, deterministic + ) if not m_deterministic: dropout_rng = self.make_rng('dropout') else: @@ -378,17 +392,20 @@ def __call__(self, deterministic=m_deterministic, dtype=self.dtype, precision=self.precision, - attn_temp=self.attn_temp) # pytype: disable=wrong-keyword-args + attn_temp=self.attn_temp, + ) # pytype: disable=wrong-keyword-args # back to the original inputs dimensions - out = DenseGeneral(features=features, - axis=(-2, -1), - kernel_init=self.kernel_init, - bias_init=self.bias_init, - use_bias=self.use_bias, - dtype=self.dtype, - param_dtype=self.param_dtype, - precision=self.precision, - name='out')(x) + out = DenseGeneral( + features=features, + axis=(-2, -1), + kernel_init=self.kernel_init, + bias_init=self.bias_init, + use_bias=self.use_bias, + dtype=self.dtype, + param_dtype=self.param_dtype, + precision=self.precision, + name='out', + )(x) return out @@ -396,39 +413,43 @@ class SelfAttention(MultiHeadDotProductAttention): """Self-attention special case of multi-head dot-product attention.""" @compact - def __call__(self, inputs_q: Array, mask: Optional[Array] = None, - deterministic: Optional[bool] = None): + def __call__( + self, + inputs_q: Array, + mask: Optional[Array] = None, + deterministic: Optional[bool] = None, + ): """Applies multi-head dot product self-attention on the input data. Projects the inputs into multi-headed query, key, and value vectors, applies dot-product attention and project the results to an output vector. Args: - inputs_q: input queries of shape - `[batch_sizes..., length, features]`. - mask: attention mask of shape - `[batch_sizes..., num_heads, query_length, key/value_length]`. - Attention weights are masked out if their corresponding mask value - is `False`. - deterministic: if false, the attention weight is masked randomly - using dropout, whereas if true, the attention weights - are deterministic. + inputs_q: input queries of shape `[batch_sizes..., length, features]`. + mask: attention mask of shape `[batch_sizes..., num_heads, query_length, + key/value_length]`. Attention weights are masked out if their + corresponding mask value is `False`. + deterministic: if false, the attention weight is masked randomly using + dropout, whereas if true, the attention weights are deterministic. Returns: output of shape `[batch_sizes..., length, features]`. """ - return super().__call__(inputs_q, inputs_q, mask, - deterministic=deterministic) + return super().__call__( + inputs_q, inputs_q, mask, deterministic=deterministic + ) # mask-making utility functions -def make_attention_mask(query_input: Array, - key_input: Array, - pairwise_fn: Callable[..., Any] = jnp.multiply, - extra_batch_dims: int = 0, - dtype: Dtype = jnp.float32): +def make_attention_mask( + query_input: Array, + key_input: Array, + pairwise_fn: Callable[..., Any] = jnp.multiply, + extra_batch_dims: int = 0, + dtype: Dtype = jnp.float32, +): """Mask-making helper for attention weights. In case of 1d inputs (i.e., `[batch..., len_q]`, `[batch..., len_kv]`, the @@ -439,23 +460,24 @@ def make_attention_mask(query_input: Array, query_input: a batched, flat input of query_length size key_input: a batched, flat input of key_length size pairwise_fn: broadcasting elementwise comparison function - extra_batch_dims: number of extra batch dims to add singleton - axes for, none by default + extra_batch_dims: number of extra batch dims to add singleton axes for, none + by default dtype: mask return dtype Returns: A `[batch..., 1, len_q, len_kv]` shaped mask for 1d attention. """ - mask = pairwise_fn(jnp.expand_dims(query_input, axis=-1), - jnp.expand_dims(key_input, axis=-2)) + mask = pairwise_fn( + jnp.expand_dims(query_input, axis=-1), jnp.expand_dims(key_input, axis=-2) + ) mask = jnp.expand_dims(mask, axis=-3) mask = jnp.expand_dims(mask, axis=tuple(range(extra_batch_dims))) return mask.astype(dtype) -def make_causal_mask(x: Array, - extra_batch_dims: int = 0, - dtype: Dtype = jnp.float32) -> Array: +def make_causal_mask( + x: Array, extra_batch_dims: int = 0, dtype: Dtype = jnp.float32 +) -> Array: """Make a causal mask for self-attention. In case of 1d inputs (i.e., `[batch..., len]`, the self-attention weights @@ -464,20 +486,24 @@ def make_causal_mask(x: Array, Args: x: input array of shape `[batch..., len]` - extra_batch_dims: number of batch dims to add singleton axes for, - none by default + extra_batch_dims: number of batch dims to add singleton axes for, none by + default dtype: mask return dtype Returns: A `[batch..., 1, len, len]` shaped causal mask for 1d attention. """ idxs = jnp.broadcast_to(jnp.arange(x.shape[-1], dtype=jnp.int32), x.shape) - return make_attention_mask(idxs, idxs, jnp.greater_equal, - extra_batch_dims=extra_batch_dims, dtype=dtype) + return make_attention_mask( + idxs, + idxs, + jnp.greater_equal, + extra_batch_dims=extra_batch_dims, + dtype=dtype, + ) -def combine_masks(*masks: Optional[Array], - dtype: Dtype = jnp.float32) -> Array: +def combine_masks(*masks: Optional[Array], dtype: Dtype = jnp.float32) -> Array: """Combine attention masks. Args: @@ -490,8 +516,9 @@ def combine_masks(*masks: Optional[Array], masks_list = [m for m in masks if m is not None] if not masks_list: return None - assert all(map(lambda x: x.ndim == masks_list[0].ndim, masks_list)), ( - f'masks must have same rank: {tuple(map(lambda x: x.ndim, masks_list))}') + assert all( + map(lambda x: x.ndim == masks_list[0].ndim, masks_list) + ), f'masks must have same rank: {tuple(map(lambda x: x.ndim, masks_list))}' mask, *other_masks = masks_list for other_mask in other_masks: mask = jnp.logical_and(mask, other_mask) diff --git a/init2winit/model_lib/autoencoder.py b/init2winit/model_lib/autoencoder.py index 9e6c7bb0..29a2a072 100644 --- a/init2winit/model_lib/autoencoder.py +++ b/init2winit/model_lib/autoencoder.py @@ -38,7 +38,8 @@ activation_function=['relu', 'relu', 'relu', 'relu', 'relu'], kernel_scales=[1.0] * 6, model_dtype='float32', - )) + ) +) class AutoEncoderModel(base_model.BaseModel): @@ -46,19 +47,19 @@ class AutoEncoderModel(base_model.BaseModel): def build_flax_module(self): kernel_inits = [ - initializers.normal(scale) - for scale in self.hps.kernel_scales + initializers.normal(scale) for scale in self.hps.kernel_scales ] return FullyConnected( num_outputs=self.hps['output_shape'][-1], hid_sizes=self.hps.hid_sizes, activation_function=self.hps.activation_function, - kernel_inits=kernel_inits) + kernel_inits=kernel_inits, + ) def get_fake_inputs(self, hps): """Helper method solely for the purpose of initialzing the model.""" dummy_inputs = [ jnp.zeros((hps.batch_size, *hps.input_shape), dtype=hps.model_dtype) - ] + ] return dummy_inputs diff --git a/init2winit/model_lib/base_model.py b/init2winit/model_lib/base_model.py index 66538a52..1249e91b 100644 --- a/init2winit/model_lib/base_model.py +++ b/init2winit/model_lib/base_model.py @@ -34,8 +34,14 @@ P = jax.sharding.PartitionSpec -def _evaluate_batch(flax_module, params, batch_stats, batch, metrics_bundle, - apply_one_hot_in_loss): +def _evaluate_batch( + flax_module, + params, + batch_stats, + batch, + metrics_bundle, + apply_one_hot_in_loss, +): """Evaluates metrics on the given batch. We use the CLU metrics library to evaluate the metrics, and we require that @@ -46,8 +52,8 @@ def _evaluate_batch(flax_module, params, batch_stats, batch, metrics_bundle, flax_module: the Flax linen.nn.Module. params: A dict of trainable model parameters. Passed as {'params': params} into flax_module.apply(). - batch_stats: A dict of non-trainable model state. Passed as - {'batch_stats': batch_stats} into flax_module.apply(). + batch_stats: A dict of non-trainable model state. Passed as {'batch_stats': + batch_stats} into flax_module.apply(). batch: A dictionary with keys 'inputs', 'targets', 'weights'. metrics_bundle: A group of metrics to use for evaluation. apply_one_hot_in_loss: Indicates whether or not the targets are one hot @@ -59,7 +65,8 @@ def _evaluate_batch(flax_module, params, batch_stats, batch, metrics_bundle, """ variables = {'params': params, 'batch_stats': batch_stats} logits = flax_module.apply( - variables, batch['inputs'], mutable=False, train=False) + variables, batch['inputs'], mutable=False, train=False + ) targets = batch['targets'] if apply_one_hot_in_loss: @@ -74,7 +81,8 @@ def _evaluate_batch(flax_module, params, batch_stats, batch, metrics_bundle, # We don't use CLU's `mask` argument here, we handle it ourselves through # `weights`. return metrics_bundle.single_from_model_output( - logits=logits, targets=targets, weights=weights) + logits=logits, targets=targets, weights=weights + ) class BaseModel(object): @@ -193,7 +201,8 @@ def initialize(self, initializer, hps, rng, metrics_logger): fake_input_batch = fake_inputs elif isinstance(hps.input_shape, list): # Typical case for seq2seq models fake_input_batch = [ - np.zeros((2, *x), model_dtype) for x in hps.input_shape] + np.zeros((2, *x), model_dtype) for x in hps.input_shape + ] else: # Typical case for classification models fake_input_batch = [np.zeros((2, *hps.input_shape), model_dtype)] params_rng, init_rng, dropout_rng = jax.random.split(rng, num=3) @@ -220,8 +229,9 @@ def initialize(self, initializer, hps, rng, metrics_logger): functools.partial(self.flax_module.init, train=False), **jit_kwargs ) - init_dict = model_init_fn({'params': params_rng, 'dropout': dropout_rng}, - *fake_input_batch) + init_dict = model_init_fn( + {'params': params_rng, 'dropout': dropout_rng}, *fake_input_batch + ) logging.info( 'Flax module init call took %f seconds.', @@ -332,7 +342,8 @@ def get_sharding(self, params, mesh): sharding_overrides = self.get_sharding_overrides(mesh) overriden_shardings = self._apply_sharding_overrides( - params, mesh, default_shardings, sharding_overrides) + params, mesh, default_shardings, sharding_overrides + ) return overriden_shardings def evaluate_batch(self, params, batch_stats, batch): @@ -343,7 +354,8 @@ def evaluate_batch(self, params, batch_stats, batch): batch_stats, batch, self.metrics_bundle, - self.dataset_meta_data['apply_one_hot_in_loss']) + self.dataset_meta_data['apply_one_hot_in_loss'], + ) def apply_on_batch(self, params, batch_stats, batch, **apply_kwargs): """Wrapper around flax_module.apply. @@ -375,11 +387,14 @@ def training_cost(self, params, batch, batch_stats=None, dropout_rng=None): if dropout_rng is not None: apply_kwargs['rngs'] = {'dropout': dropout_rng} - logits, new_batch_stats = self.apply_on_batch(params, batch_stats, batch, - **apply_kwargs) + logits, new_batch_stats = self.apply_on_batch( + params, batch_stats, batch, **apply_kwargs + ) weights = batch.get('weights') - return self.training_objective_fn(params, logits, batch['targets'], - weights), new_batch_stats + return ( + self.training_objective_fn(params, logits, batch['targets'], weights), + new_batch_stats, + ) def training_objective_fn(self, params, logits, targets, weights): """Returns the training objective (loss + regularizer) on a batch of logits. @@ -399,9 +414,7 @@ def training_objective_fn(self, params, logits, targets, weights): # is multi-class cross-entropy. label_smoothing = self.hps.get('label_smoothing') if self._loss_name == 'cross_entropy' and label_smoothing is not None: - targets = model_utils.apply_label_smoothing( - targets, label_smoothing - ) + targets = model_utils.apply_label_smoothing(targets, label_smoothing) objective_numerator, objective_denominator = self.loss_fn( logits, targets, weights @@ -427,7 +440,8 @@ def param_shapes(self): if self._param_shapes is None: raise ValueError( 'This should not happen, model.initialize() should be called ' - 'before model.param_shapes!') + 'before model.param_shapes!' + ) return self._param_shapes @property @@ -436,7 +450,8 @@ def params_types(self): if self._param_types is None: raise ValueError( 'This should not happen, model.initialize() should be called ' - 'before model.param_types!') + 'before model.param_types!' + ) return self._param_types def is_output_params(self, param_key: str) -> bool: diff --git a/init2winit/model_lib/binarize_layers.py b/init2winit/model_lib/binarize_layers.py index b953473c..6dd1f4ca 100644 --- a/init2winit/model_lib/binarize_layers.py +++ b/init2winit/model_lib/binarize_layers.py @@ -20,10 +20,9 @@ third_party/py/flax/linen/linear.py and third_party/py/flax/linen/attention.py. """ - import dataclasses import functools -from typing import Any, Callable, Iterable, Optional, Sequence, Tuple, Union, TYPE_CHECKING +from typing import Any, Callable, Iterable, Optional, Sequence, TYPE_CHECKING, Tuple, Union from flax import struct as flax_struct from flax.linen.attention import combine_masks from flax.linen.dtypes import promote_dtype @@ -40,14 +39,20 @@ from ml_collections import config_dict import numpy as np - PRNGKey = Any Shape = Tuple[int, ...] Dtype = Any Array = Any -dataclass = flax_struct.dataclass if not TYPE_CHECKING else dataclasses.dataclass -PrecisionLike = Union[None, str, lax.Precision, Tuple[str, str], - Tuple[lax.Precision, lax.Precision]] +dataclass = ( + flax_struct.dataclass if not TYPE_CHECKING else dataclasses.dataclass +) +PrecisionLike = Union[ + None, + str, + lax.Precision, + Tuple[str, str], + Tuple[lax.Precision, lax.Precision], +] default_kernel_init = lecun_normal() @@ -114,16 +119,17 @@ def floor_with_gradient(x: jax.Array) -> jax.Array: """Floor with Straight-Through-Estimator gradient.""" return jnp.floor(x) + add_straight_through_estimator(floor_with_gradient) class BinarizeOps: - """Binarization operator. - """ + """Binarization operator.""" @dataclasses.dataclass class HParams: """Hyperparameters for binarization.""" + # A fixed quantization bound within which inputs have gradients bound: Optional[Union[float, jnp.ndarray]] # A small value subtracted from the clipping bound to prevent from overflow @@ -131,19 +137,20 @@ class HParams: # Axis along which to automatically get bound values scale_axis: Optional[Union[Iterable[int], str]] - def __init__(self, - bound: jnp.ndarray, - epsilon: float, # default value 2**(-7) - dtype: Any): + def __init__( + self, + bound: jnp.ndarray, + epsilon: float, # default value 2**(-7) + dtype: Any, + ): self.bound = bound self.epsilon = epsilon self.dtype = dtype @classmethod - def create_ops(cls, - x: jnp.ndarray, - hparams: HParams, - dtype: Any = jnp.bfloat16): + def create_ops( + cls, x: jnp.ndarray, hparams: HParams, dtype: Any = jnp.bfloat16 + ): """Create BinarizationOps for symmetric inputs clipped to [-bounds, bounds]. Args: @@ -165,10 +172,7 @@ def create_ops(cls, bound = jnp.asarray(bound, dtype) bound += jnp.finfo(jnp.float32).eps # avoid dividing by zero bound = jax.lax.stop_gradient(bound) # no gradietns on the bound - return cls( - bound=bound, - dtype=dtype, - epsilon=hparams.epsilon) + return cls(bound=bound, dtype=dtype, epsilon=hparams.epsilon) def binarize(self, x: jnp.ndarray) -> jnp.ndarray: """The binarization operation. @@ -206,6 +210,7 @@ class BiDense(Module): weight_bin_hparams: hyperparameters for binarizing the kernels. inputs_bin_hparams: hyperparameters for binarizing the inputs. """ + features: int use_bias: bool = True dtype: Optional[Dtype] = None @@ -226,13 +231,16 @@ def __call__(self, inputs: Array) -> Array: Returns: The transformed input. """ - kernel = self.param('kernel', - self.kernel_init, - (jnp.shape(inputs)[-1], self.features), - self.param_dtype) + kernel = self.param( + 'kernel', + self.kernel_init, + (jnp.shape(inputs)[-1], self.features), + self.param_dtype, + ) if self.use_bias: - bias = self.param('bias', self.bias_init, (self.features,), - self.param_dtype) + bias = self.param( + 'bias', self.bias_init, (self.features,), self.param_dtype + ) else: bias = None inputs, kernel, bias = promote_dtype(inputs, kernel, bias, dtype=self.dtype) @@ -241,7 +249,8 @@ def __call__(self, inputs: Array) -> Array: if self.weight_bin_hparams is not None: whps = self.weight_bin_hparams assert whps.scale_axis in [ - 'auto', None + 'auto', + None, ], f'scale_axis can only be auto or None. Got {whps.scale_axis}.' if whps.scale_axis == 'auto': # automatically select one bound per out_feature @@ -251,12 +260,15 @@ def __call__(self, inputs: Array) -> Array: weight_ops = BinarizeOps.create_ops(kernel, whps, dtype=self.dtype) w_bound_shape = weight_ops.bound.shape assert w_bound_shape == (1, self.features) or w_bound_shape == ( - 1, 1), f'kernel bound shape {w_bound_shape} is not realistic' + 1, + 1, + ), f'kernel bound shape {w_bound_shape} is not realistic' kernel = weight_ops.binarize(kernel) if self.inputs_bin_hparams is not None: ahps = self.inputs_bin_hparams assert ahps.scale_axis in [ - 'auto', None + 'auto', + None, ], f'scale_axis can only be auto or None. Got {ahps.scale_axis}.' if ahps.scale_axis == 'auto': ahps = ahps.to_dict() @@ -265,13 +277,22 @@ def __call__(self, inputs: Array) -> Array: inputs_ops = BinarizeOps.create_ops(inputs, ahps, dtype=self.dtype) a_bound_shape = inputs_ops.bound.shape assert a_bound_shape == ( - inputs.shape[0], inputs.shape[1], 1) or a_bound_shape == ( - 1, 1, 1), f'inputs bound shape {a_bound_shape} is not realistic' + inputs.shape[0], + inputs.shape[1], + 1, + ) or a_bound_shape == ( + 1, + 1, + 1, + ), f'inputs bound shape {a_bound_shape} is not realistic' inputs = inputs_ops.binarize(inputs) # ---------------------- - y = lax.dot_general(inputs, kernel, - (((inputs.ndim - 1,), (0,)), ((), ())), - precision=self.precision) + y = lax.dot_general( + inputs, + kernel, + (((inputs.ndim - 1,), (0,)), ((), ())), + precision=self.precision, + ) if bias is not None: y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,)) return y @@ -307,6 +328,7 @@ class BiDenseGeneral(Module): weight_bin_hparams: hyperparameters for binarizing the kernels. inputs_bin_hparams: hyperparameters for binarizing the inputs. """ + features: Union[int, Sequence[int]] axis: Union[int, Sequence[int]] = -1 batch_dims: Sequence[int] = () @@ -335,8 +357,11 @@ def __call__(self, inputs: Array) -> Array: if batch_dims: max_dim = np.max(batch_dims) if set(batch_dims) != set(range(max_dim + 1)): - raise ValueError('batch_dims %s must be consecutive leading ' - 'dimensions starting from 0.' % str(batch_dims)) + raise ValueError( + 'batch_dims %s must be consecutive leading ' + 'dimensions starting from 0.' + % str(batch_dims) + ) ndim = inputs.ndim n_batch_dims = len(batch_dims) @@ -346,34 +371,51 @@ def __call__(self, inputs: Array) -> Array: def kernel_init_wrap(rng, shape, dtype=jnp.float32): size_batch_dims = np.prod(shape[:n_batch_dims], dtype=np.int32) - flat_shape = (np.prod(shape[n_batch_dims:n_axis + n_batch_dims]), - np.prod(shape[-n_features:]),) - kernel = jnp.concatenate([self.kernel_init(rng, flat_shape, dtype) - for _ in range(size_batch_dims)], axis=0) + flat_shape = ( + np.prod(shape[n_batch_dims : n_axis + n_batch_dims]), + np.prod(shape[-n_features:]), + ) + kernel = jnp.concatenate( + [ + self.kernel_init(rng, flat_shape, dtype) + for _ in range(size_batch_dims) + ], + axis=0, + ) return jnp.reshape(kernel, shape) batch_shape = tuple(inputs.shape[ax] for ax in batch_dims) # batch and non-contracting dims of input with 1s for batch dims. expanded_batch_shape = tuple( inputs.shape[ax] if ax in batch_dims else 1 - for ax in range(inputs.ndim) if ax not in axis) + for ax in range(inputs.ndim) + if ax not in axis + ) kernel_shape = tuple(inputs.shape[ax] for ax in axis) + features - kernel = self.param('kernel', kernel_init_wrap, batch_shape + kernel_shape, - self.param_dtype) + kernel = self.param( + 'kernel', kernel_init_wrap, batch_shape + kernel_shape, self.param_dtype + ) batch_ind = tuple(range(n_batch_dims)) contract_ind = tuple(range(n_batch_dims, n_axis + n_batch_dims)) if self.use_bias: + def bias_init_wrap(rng, shape, dtype=jnp.float32): size_batch_dims = np.prod(shape[:n_batch_dims], dtype=np.int32) flat_shape = (np.prod(shape[-n_features:]),) - bias = jnp.concatenate([self.bias_init(rng, flat_shape, dtype) - for _ in range(size_batch_dims)], axis=0) + bias = jnp.concatenate( + [ + self.bias_init(rng, flat_shape, dtype) + for _ in range(size_batch_dims) + ], + axis=0, + ) return jnp.reshape(bias, shape) - bias = self.param('bias', bias_init_wrap, batch_shape + features, - self.param_dtype) + bias = self.param( + 'bias', bias_init_wrap, batch_shape + features, self.param_dtype + ) else: bias = None @@ -384,7 +426,8 @@ def bias_init_wrap(rng, shape, dtype=jnp.float32): if self.weight_bin_hparams is not None: whps = self.weight_bin_hparams assert whps.scale_axis in [ - 'auto', None + 'auto', + None, ], f'scale_axis can only be auto or None. Got {whps.scale_axis}.' if whps.scale_axis == 'auto': # automatically select one bound per out_feature @@ -393,14 +436,16 @@ def bias_init_wrap(rng, shape, dtype=jnp.float32): whps = config_dict.ConfigDict(whps) weight_ops = BinarizeOps.create_ops(kernel, whps, dtype=self.dtype) w_bound_shape = weight_ops.bound.shape - assert w_bound_shape == (1,) * n_axis + features or w_bound_shape == ( - 1, - ) * kernel.ndim, f'kernel bound shape {w_bound_shape} is not realistic' + assert ( + w_bound_shape == (1,) * n_axis + features + or w_bound_shape == (1,) * kernel.ndim + ), f'kernel bound shape {w_bound_shape} is not realistic' kernel = weight_ops.binarize(kernel) if self.inputs_bin_hparams is not None: ahps = self.inputs_bin_hparams assert ahps.scale_axis in [ - 'auto', None + 'auto', + None, ], f'scale_axis can only be auto or None. Got {ahps.scale_axis}.' if ahps.scale_axis == 'auto': ahps = ahps.to_dict() @@ -408,17 +453,21 @@ def bias_init_wrap(rng, shape, dtype=jnp.float32): ahps = config_dict.ConfigDict(ahps) inputs_ops = BinarizeOps.create_ops(inputs, ahps, dtype=self.dtype) a_bound_shape = inputs_ops.bound.shape - assert a_bound_shape == tuple([ - inputs.shape[i] for i in range(ndim - n_axis) - ]) + (1,) * n_axis or a_bound_shape == ( - 1,) * ndim, f'inputs bound shape {a_bound_shape} is not realistic' + assert ( + a_bound_shape + == tuple([inputs.shape[i] for i in range(ndim - n_axis)]) + + (1,) * n_axis + or a_bound_shape == (1,) * ndim + ), f'inputs bound shape {a_bound_shape} is not realistic' inputs = inputs_ops.binarize(inputs) # ---------------------- - out = lax.dot_general(inputs, - kernel, - ((axis, contract_ind), (batch_dims, batch_ind)), - precision=self.precision) + out = lax.dot_general( + inputs, + kernel, + ((axis, contract_ind), (batch_dims, batch_ind)), + precision=self.precision, + ) # dot_general output has shape [batch_dims/group_dims] + [feature_dims] if self.use_bias: # expand bias shape to broadcast bias over batch dims. @@ -427,16 +476,18 @@ def bias_init_wrap(rng, shape, dtype=jnp.float32): return out -def dot_product_attention_weights(query: Array, - key: Array, - bias: Optional[Array] = None, - mask: Optional[Array] = None, - broadcast_dropout: bool = True, - dropout_rng: Optional[PRNGKey] = None, - dropout_rate: float = 0., - deterministic: bool = False, - dtype: Optional[Dtype] = None, - precision: PrecisionLike = None): +def dot_product_attention_weights( + query: Array, + key: Array, + bias: Optional[Array] = None, + mask: Optional[Array] = None, + broadcast_dropout: bool = True, + dropout_rng: Optional[PRNGKey] = None, + dropout_rate: float = 0.0, + deterministic: bool = False, + dtype: Optional[Dtype] = None, + precision: PrecisionLike = None, +): """Computes dot-product attention weights given query and key. Used by :func:`dot_product_attention`, which is what you'll most likely use. @@ -444,19 +495,17 @@ def dot_product_attention_weights(query: Array, you can directly call this function and call einsum yourself. Args: - query: queries for calculating attention with shape of - `[batch..., q_length, num_heads, qk_depth_per_head]`. - key: keys for calculating attention with shape of - `[batch..., kv_length, num_heads, qk_depth_per_head]`. + query: queries for calculating attention with shape of `[batch..., q_length, + num_heads, qk_depth_per_head]`. + key: keys for calculating attention with shape of `[batch..., kv_length, + num_heads, qk_depth_per_head]`. bias: bias for the attention weights. This should be broadcastable to the - shape `[batch..., num_heads, q_length, kv_length]`. - This can be used for incorporating causal masks, padding masks, - proximity bias, etc. + shape `[batch..., num_heads, q_length, kv_length]`. This can be used for + incorporating causal masks, padding masks, proximity bias, etc. mask: mask for the attention weights. This should be broadcastable to the - shape `[batch..., num_heads, q_length, kv_length]`. - This can be used for incorporating causal masks. - Attention weights are masked out if their corresponding mask value - is `False`. + shape `[batch..., num_heads, q_length, kv_length]`. This can be used for + incorporating causal masks. Attention weights are masked out if their + corresponding mask value is `False`. broadcast_dropout: bool: use a broadcasted dropout along batch dims. dropout_rng: JAX PRNGKey: to be used for dropout dropout_rate: dropout rate @@ -472,18 +521,17 @@ def dot_product_attention_weights(query: Array, dtype = query.dtype assert query.ndim == key.ndim, 'q, k must have same rank.' - assert query.shape[:-3] == key.shape[:-3], ( - 'q, k batch dims must match.') - assert query.shape[-2] == key.shape[-2], ( - 'q, k num_heads must match.') + assert query.shape[:-3] == key.shape[:-3], 'q, k batch dims must match.' + assert query.shape[-2] == key.shape[-2], 'q, k num_heads must match.' assert query.shape[-1] == key.shape[-1], 'q, k depths must match.' # calculate attention matrix depth = query.shape[-1] query = query / jnp.sqrt(depth).astype(dtype) # attn weight shape is (batch..., num_heads, q_length, kv_length) - attn_weights = jnp.einsum('...qhd,...khd->...hqk', query, key, - precision=precision) + attn_weights = jnp.einsum( + '...qhd,...khd->...hqk', query, key, precision=precision + ) # apply attention bias: masking, dropout, proximity bias, etc. if bias is not None: @@ -497,7 +545,7 @@ def dot_product_attention_weights(query: Array, attn_weights = jax.nn.softmax(attn_weights).astype(dtype) # apply attention dropout - if not deterministic and dropout_rate > 0.: + if not deterministic and dropout_rate > 0.0: keep_prob = 1.0 - dropout_rate if broadcast_dropout: # dropout is broadcast across the batch + head dimensions @@ -505,24 +553,25 @@ def dot_product_attention_weights(query: Array, keep = random.bernoulli(dropout_rng, keep_prob, dropout_shape) else: keep = random.bernoulli(dropout_rng, keep_prob, attn_weights.shape) - multiplier = (keep.astype(dtype) / - jnp.asarray(keep_prob, dtype=dtype)) + multiplier = keep.astype(dtype) / jnp.asarray(keep_prob, dtype=dtype) attn_weights = attn_weights * multiplier return attn_weights -def dot_product_attention(query: Array, - key: Array, - value: Array, - bias: Optional[Array] = None, - mask: Optional[Array] = None, - broadcast_dropout: bool = True, - dropout_rng: Optional[PRNGKey] = None, - dropout_rate: float = 0., - deterministic: bool = False, - dtype: Optional[Dtype] = None, - precision: PrecisionLike = None): +def dot_product_attention( + query: Array, + key: Array, + value: Array, + bias: Optional[Array] = None, + mask: Optional[Array] = None, + broadcast_dropout: bool = True, + dropout_rng: Optional[PRNGKey] = None, + dropout_rate: float = 0.0, + deterministic: bool = False, + dtype: Optional[Dtype] = None, + precision: PrecisionLike = None, +): """Computes dot-product attention given query, key, and value. This is the core function for applying attention based on @@ -532,21 +581,19 @@ def dot_product_attention(query: Array, Note: query, key, value needn't have any batch dimensions. Args: - query: queries for calculating attention with shape of - `[batch..., q_length, num_heads, qk_depth_per_head]`. - key: keys for calculating attention with shape of - `[batch..., kv_length, num_heads, qk_depth_per_head]`. - value: values to be used in attention with shape of - `[batch..., kv_length, num_heads, v_depth_per_head]`. + query: queries for calculating attention with shape of `[batch..., q_length, + num_heads, qk_depth_per_head]`. + key: keys for calculating attention with shape of `[batch..., kv_length, + num_heads, qk_depth_per_head]`. + value: values to be used in attention with shape of `[batch..., kv_length, + num_heads, v_depth_per_head]`. bias: bias for the attention weights. This should be broadcastable to the - shape `[batch..., num_heads, q_length, kv_length]`. - This can be used for incorporating causal masks, padding masks, - proximity bias, etc. + shape `[batch..., num_heads, q_length, kv_length]`. This can be used for + incorporating causal masks, padding masks, proximity bias, etc. mask: mask for the attention weights. This should be broadcastable to the - shape `[batch..., num_heads, q_length, kv_length]`. - This can be used for incorporating causal masks. - Attention weights are masked out if their corresponding mask value - is `False`. + shape `[batch..., num_heads, q_length, kv_length]`. This can be used for + incorporating causal masks. Attention weights are masked out if their + corresponding mask value is `False`. broadcast_dropout: bool: use a broadcasted dropout along batch dims. dropout_rng: JAX PRNGKey: to be used for dropout dropout_rate: dropout rate @@ -561,57 +608,68 @@ def dot_product_attention(query: Array, query, key, value = promote_dtype(query, key, value, dtype=dtype) dtype = query.dtype assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.' - assert query.shape[:-3] == key.shape[:-3] == value.shape[:-3], ( - 'q, k, v batch dims must match.') - assert query.shape[-2] == key.shape[-2] == value.shape[-2], ( - 'q, k, v num_heads must match.') + assert ( + query.shape[:-3] == key.shape[:-3] == value.shape[:-3] + ), 'q, k, v batch dims must match.' + assert ( + query.shape[-2] == key.shape[-2] == value.shape[-2] + ), 'q, k, v num_heads must match.' assert key.shape[-3] == value.shape[-3], 'k, v lengths must match.' # compute attention weights attn_weights = dot_product_attention_weights( - query, key, bias, mask, broadcast_dropout, dropout_rng, dropout_rate, - deterministic, dtype, precision) + query, + key, + bias, + mask, + broadcast_dropout, + dropout_rng, + dropout_rate, + deterministic, + dtype, + precision, + ) # return weighted sum over values for each query position - return jnp.einsum('...hqk,...khd->...qhd', attn_weights, value, - precision=precision) + return jnp.einsum( + '...hqk,...khd->...qhd', attn_weights, value, precision=precision + ) class MultiHeadDotProductAttention(Module): """Multi-head dot-product attention. - Attributes: - num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1]) - should be divisible by the number of heads. - dtype: the dtype of the computation - (default: infer from inputs and params) - param_dtype: the dtype passed to parameter initializers (default: float32) - qkv_features: dimension of the key, query, and value. - out_features: dimension of the last projection - broadcast_dropout: bool: use a broadcasted dropout along batch dims. - dropout_rate: dropout rate - deterministic: if false, the attention weight is masked randomly - using dropout, whereas if true, the attention weights - are deterministic. - precision: numerical precision of the computation see `jax.lax.Precision` - for details. - kernel_init: initializer for the kernel of the Dense layers. - bias_init: initializer for the bias of the Dense layers. - use_bias: bool: whether pointwise QKVO dense transforms use bias. - attention_fn: dot_product_attention or compatible function. Accepts - query, key, value, and returns output of shape - `[bs, dim1, dim2, ..., dimN,, num_heads, value_channels]`` - decode: whether to prepare and use an autoregressive cache. - binarize_hparams: hyperparameters for binarization. - dynamic_context: contains flags that decide if performing quantization. + Attributes: + num_heads: number of attention heads. Features (i.e. inputs_q.shape[-1]) + should be divisible by the number of heads. + dtype: the dtype of the computation (default: infer from inputs and params) + param_dtype: the dtype passed to parameter initializers (default: float32) + qkv_features: dimension of the key, query, and value. + out_features: dimension of the last projection + broadcast_dropout: bool: use a broadcasted dropout along batch dims. + dropout_rate: dropout rate + deterministic: if false, the attention weight is masked randomly using + dropout, whereas if true, the attention weights are deterministic. + precision: numerical precision of the computation see `jax.lax.Precision` + for details. + kernel_init: initializer for the kernel of the Dense layers. + bias_init: initializer for the bias of the Dense layers. + use_bias: bool: whether pointwise QKVO dense transforms use bias. + attention_fn: dot_product_attention or compatible function. Accepts query, + key, value, and returns output of shape `[bs, dim1, dim2, ..., dimN,, + num_heads, value_channels]`` + decode: whether to prepare and use an autoregressive cache. + binarize_hparams: hyperparameters for binarization. + dynamic_context: contains flags that decide if performing quantization. """ + num_heads: int dtype: Optional[Dtype] = None param_dtype: Dtype = jnp.float32 qkv_features: Optional[int] = None out_features: Optional[int] = None broadcast_dropout: bool = True - dropout_rate: float = 0. + dropout_rate: float = 0.0 deterministic: Optional[bool] = None precision: PrecisionLike = None kernel_init: Callable[[PRNGKey, Shape, Dtype], Array] = default_kernel_init @@ -627,28 +685,26 @@ class MultiHeadDotProductAttention(Module): ) @compact - def __call__(self, - inputs_q: Array, - inputs_kv: Array, - mask: Optional[Array] = None, - deterministic: Optional[bool] = None): + def __call__( + self, + inputs_q: Array, + inputs_kv: Array, + mask: Optional[Array] = None, + deterministic: Optional[bool] = None, + ): """Applies multi-head dot product attention on the input data. Projects the inputs into multi-headed query, key, and value vectors, applies dot-product attention and project the results to an output vector. Args: - inputs_q: input queries of shape - `[batch_sizes..., length, features]`. - inputs_kv: key/values of shape - `[batch_sizes..., length, features]`. - mask: attention mask of shape - `[batch_sizes..., num_heads, query_length, key/value_length]`. - Attention weights are masked out if their corresponding mask value - is `False`. - deterministic: if false, the attention weight is masked randomly - using dropout, whereas if true, the attention weights - are deterministic. + inputs_q: input queries of shape `[batch_sizes..., length, features]`. + inputs_kv: key/values of shape `[batch_sizes..., length, features]`. + mask: attention mask of shape `[batch_sizes..., num_heads, query_length, + key/value_length]`. Attention weights are masked out if their + corresponding mask value is `False`. + deterministic: if false, the attention weight is masked randomly using + dropout, whereas if true, the attention weights are deterministic. Returns: output of shape `[batch_sizes..., length, features]`. @@ -663,8 +719,9 @@ def __call__(self, features = self.out_features or inputs_q.shape[-1] qkv_features = self.qkv_features or inputs_q.shape[-1] - assert qkv_features % self.num_heads == 0, ( - 'Memory dimension must be divisible by number of heads.') + assert ( + qkv_features % self.num_heads == 0 + ), 'Memory dimension must be divisible by number of heads.' head_dim = qkv_features // self.num_heads dense = functools.partial( @@ -678,20 +735,26 @@ def __call__(self, use_bias=self.use_bias, precision=self.precision, weight_bin_hparams=whps, - inputs_bin_hparams=kqv_ahps) + inputs_bin_hparams=kqv_ahps, + ) # project inputs_q to multi-headed q/k/v # dimensions are then [batch..., length, n_heads, n_features_per_head] - query, key, value = (dense(name='query')(inputs_q), - dense(name='key')(inputs_kv), - dense(name='value')(inputs_kv)) + query, key, value = ( + dense(name='query')(inputs_q), + dense(name='key')(inputs_kv), + dense(name='value')(inputs_kv), + ) # add layernorms after K,Q,V transformation to support binarization query_normalize = model_utils.get_normalizer( - 'layer_norm', not deterministic, dtype=self.dtype) + 'layer_norm', not deterministic, dtype=self.dtype + ) key_normalize = model_utils.get_normalizer( - 'layer_norm', not deterministic, dtype=self.dtype) + 'layer_norm', not deterministic, dtype=self.dtype + ) value_normalize = model_utils.get_normalizer( - 'layer_norm', not deterministic, dtype=self.dtype) + 'layer_norm', not deterministic, dtype=self.dtype + ) query = query_normalize()(query) key = key_normalize()(key) value = value_normalize()(value) @@ -701,21 +764,27 @@ def __call__(self, if self.decode: # detect if we're initializing by absence of existing cache data. is_initialized = self.has_variable('cache', 'cached_key') - cached_key = self.variable('cache', 'cached_key', - jnp.zeros, key.shape, key.dtype) - cached_value = self.variable('cache', 'cached_value', - jnp.zeros, value.shape, value.dtype) - cache_index = self.variable('cache', 'cache_index', - lambda: jnp.array(0, dtype=jnp.int32)) + cached_key = self.variable( + 'cache', 'cached_key', jnp.zeros, key.shape, key.dtype + ) + cached_value = self.variable( + 'cache', 'cached_value', jnp.zeros, value.shape, value.dtype + ) + cache_index = self.variable( + 'cache', 'cache_index', lambda: jnp.array(0, dtype=jnp.int32) + ) if is_initialized: *batch_dims, max_length, num_heads, depth_per_head = ( - cached_key.value.shape) + cached_key.value.shape + ) # shape check of cached keys against query input expected_shape = tuple(batch_dims) + (1, num_heads, depth_per_head) if expected_shape != query.shape: - raise ValueError('Autoregressive cache shape error, ' - 'expected query shape %s instead got %s.' % - (expected_shape, query.shape)) + raise ValueError( + 'Autoregressive cache shape error, ' + 'expected query shape %s instead got %s.' + % (expected_shape, query.shape) + ) # update key, value caches with our new 1d spatial slices cur_index = cache_index.value indices = (0,) * len(batch_dims) + (cur_index, 0, 0) @@ -730,13 +799,19 @@ def __call__(self, # not the remaining zero elements. mask = combine_masks( mask, - jnp.broadcast_to(jnp.arange(max_length) <= cur_index, - tuple(batch_dims) + (1, 1, max_length))) + jnp.broadcast_to( + jnp.arange(max_length) <= cur_index, + tuple(batch_dims) + (1, 1, max_length), + ), + ) dropout_rng = None - if self.dropout_rate > 0.: # Require `deterministic` only if using dropout. - m_deterministic = merge_param('deterministic', self.deterministic, - deterministic) + if ( + self.dropout_rate > 0.0 + ): # Require `deterministic` only if using dropout. + m_deterministic = merge_param( + 'deterministic', self.deterministic, deterministic + ) if not m_deterministic: dropout_rng = self.make_rng('dropout') else: @@ -753,7 +828,8 @@ def __call__(self, broadcast_dropout=self.broadcast_dropout, deterministic=m_deterministic, dtype=self.dtype, - precision=self.precision) # pytype: disable=wrong-keyword-args + precision=self.precision, + ) # pytype: disable=wrong-keyword-args # back to the original inputs dimensions out = BiDenseGeneral( # use the binarized version of DenseGeneral features=features, @@ -766,11 +842,13 @@ def __call__(self, precision=self.precision, weight_bin_hparams=whps, inputs_bin_hparams=out_ahps, - name='out')(x) + name='out', + )(x) # add a layernorm followed by a local shortcut to support binarization out_normalize = model_utils.get_normalizer( - 'layer_norm', not deterministic, dtype=self.dtype) + 'layer_norm', not deterministic, dtype=self.dtype + ) out = out_normalize()(out) out = out + jnp.reshape(x, out.shape) return out @@ -780,26 +858,28 @@ class SelfAttention(MultiHeadDotProductAttention): """Self-attention special case of multi-head dot-product attention.""" @compact - def __call__(self, inputs_q: Array, mask: Optional[Array] = None, - deterministic: Optional[bool] = None): + def __call__( + self, + inputs_q: Array, + mask: Optional[Array] = None, + deterministic: Optional[bool] = None, + ): """Applies multi-head dot product self-attention on the input data. Projects the inputs into multi-headed query, key, and value vectors, applies dot-product attention and project the results to an output vector. Args: - inputs_q: input queries of shape - `[batch_sizes..., length, features]`. - mask: attention mask of shape - `[batch_sizes..., num_heads, query_length, key/value_length]`. - Attention weights are masked out if their corresponding mask value - is `False`. - deterministic: if false, the attention weight is masked randomly - using dropout, whereas if true, the attention weights - are deterministic. + inputs_q: input queries of shape `[batch_sizes..., length, features]`. + mask: attention mask of shape `[batch_sizes..., num_heads, query_length, + key/value_length]`. Attention weights are masked out if their + corresponding mask value is `False`. + deterministic: if false, the attention weight is masked randomly using + dropout, whereas if true, the attention weights are deterministic. Returns: output of shape `[batch_sizes..., length, features]`. """ - return super().__call__(inputs_q, inputs_q, mask, - deterministic=deterministic) + return super().__call__( + inputs_q, inputs_q, mask, deterministic=deterministic + ) diff --git a/init2winit/model_lib/conformer.py b/init2winit/model_lib/conformer.py index 146ef857..05dc0d53 100644 --- a/init2winit/model_lib/conformer.py +++ b/init2winit/model_lib/conformer.py @@ -68,7 +68,8 @@ enable_conformer_post_layer_norm=True, use_lingvo_attention=False, attn_temperature=1.0, - )) + ) +) DEFAULT_HPARAMS = config_dict.ConfigDict( @@ -93,12 +94,15 @@ enable_decoder_pre_layer_norm=True, enable_conformer_post_layer_norm=True, use_lingvo_attention=False, - attn_temperature=1.0)) + attn_temperature=1.0, + ) +) @struct.dataclass class ConformerConfig: """Global hyperparameters used to minimize obnoxious kwarg plumbing.""" + vocab_size: int = 0 dtype: Any = jnp.float32 encoder_dim: int = 0 @@ -143,6 +147,7 @@ class LayerNorm(nn.Module): zeros, this differs from default flax implementation of multiplying by scale and initializing to ones. """ + dim: int = 0 epsilon: float = 1e-6 @@ -156,7 +161,7 @@ def __call__(self, inputs): var = jnp.mean(jnp.square(inputs - mean), axis=[-1], keepdims=True) normed_inputs = (inputs - mean) * jax.lax.rsqrt(var + self.epsilon) - normed_inputs *= (1 + self.scale) + normed_inputs *= 1 + self.scale normed_inputs += self.bias return normed_inputs @@ -169,6 +174,7 @@ class Subsample(nn.Module): encoder_dim: model dimension of conformer. input_dropout_rate: dropout rate for inputs. """ + encoder_dim: int = 0 input_dropout_rate: float = 0.0 @@ -178,29 +184,32 @@ def __call__(self, inputs, input_paddings, train): outputs = jnp.expand_dims(inputs, axis=-1) outputs, output_paddings = Conv2dSubsampling( - input_channels=1, output_channels=self.encoder_dim)(outputs, - output_paddings) + input_channels=1, output_channels=self.encoder_dim + )(outputs, output_paddings) outputs, output_paddings = Conv2dSubsampling( - input_channels=self.encoder_dim, - output_channels=self.encoder_dim)(outputs, output_paddings) + input_channels=self.encoder_dim, output_channels=self.encoder_dim + )(outputs, output_paddings) batch_size, subsampled_lengths, subsampled_dims, channels = outputs.shape outputs = jnp.reshape( - outputs, (batch_size, subsampled_lengths, subsampled_dims * channels)) + outputs, (batch_size, subsampled_lengths, subsampled_dims * channels) + ) outputs = nn.Dense( self.encoder_dim, use_bias=True, - kernel_init=nn.initializers.xavier_uniform())(outputs) + kernel_init=nn.initializers.xavier_uniform(), + )(outputs) outputs = outputs + AddPositionalEmbedding(embedding_dim=self.encoder_dim)( - seq_length=outputs.shape[1]) + seq_length=outputs.shape[1] + ) - outputs = nn.Dropout( - rate=self.input_dropout_rate, deterministic=not train)( - outputs) + outputs = nn.Dropout(rate=self.input_dropout_rate, deterministic=not train)( + outputs + ) return outputs, output_paddings @@ -212,6 +221,7 @@ class Conv2dSubsampling(nn.Module): 2) Also performs strided convolution over input_paddings to return the correct paddings for downstream layers. """ + input_channels: int = 0 output_channels: int = 0 filter_stride: List[int] = (2, 2) @@ -219,10 +229,12 @@ class Conv2dSubsampling(nn.Module): def setup(self): self.filter_shape = (3, 3, self.input_channels, self.output_channels) - self.kernel = self.param('kernel', nn.initializers.xavier_uniform(), - self.filter_shape) - self.bias = self.param('bias', lambda rng, s: jnp.zeros(s, jnp.float32), - self.output_channels) + self.kernel = self.param( + 'kernel', nn.initializers.xavier_uniform(), self.filter_shape + ) + self.bias = self.param( + 'bias', lambda rng, s: jnp.zeros(s, jnp.float32), self.output_channels + ) @nn.compact def __call__(self, inputs, paddings): @@ -235,7 +247,8 @@ def __call__(self, inputs, paddings): padding=self.padding, rhs_dilation=(1, 1), dimension_numbers=('NHWC', 'HWIO', 'NHWC'), - feature_group_count=feature_group_count) + feature_group_count=feature_group_count, + ) outputs += jnp.reshape(self.bias, (1,) * (outputs.ndim - 1) + (-1,)) outputs = nn.relu(outputs) @@ -250,19 +263,21 @@ def __call__(self, inputs, paddings): rhs=jnp.ones([1, 1, 1]), window_strides=self.filter_stride[:1], padding=[(0, pad_len)], - dimension_numbers=('NHC', 'HIO', 'NHC')) + dimension_numbers=('NHC', 'HIO', 'NHC'), + ) out_padding = jnp.squeeze(out_padding, axis=-1) # Mask outputs by correct paddings to ensure padded elements in inputs map # to padded value in outputs. - outputs = outputs * (1.0 - - jnp.expand_dims(jnp.expand_dims(out_padding, -1), -1)) + outputs = outputs * ( + 1.0 - jnp.expand_dims(jnp.expand_dims(out_padding, -1), -1) + ) return outputs, out_padding class FeedForwardModule(nn.Module): - """Feedforward block of conformer layer. - """ + """Feedforward block of conformer layer.""" + config: ConformerConfig @nn.compact @@ -274,23 +289,25 @@ def __call__(self, inputs, padding_mask=None, train=False): inputs = nn.Dense( config.encoder_dim * config.feed_forward_expansion_factor, use_bias=True, - kernel_init=nn.initializers.xavier_uniform())( - inputs) + kernel_init=nn.initializers.xavier_uniform(), + )(inputs) inputs = model_utils.ACTIVATIONS[self.config.activation_function](inputs) inputs = nn.Dropout(rate=config.feed_forward_dropout_rate)( - inputs, deterministic=not train) + inputs, deterministic=not train + ) inputs = inputs * padding_mask inputs = nn.Dense( config.encoder_dim, use_bias=True, - kernel_init=nn.initializers.xavier_uniform())( - inputs) + kernel_init=nn.initializers.xavier_uniform(), + )(inputs) inputs = inputs * padding_mask inputs = nn.Dropout(rate=config.feed_forward_residual_dropout_rate)( - inputs, deterministic=not train) + inputs, deterministic=not train + ) return inputs @@ -302,6 +319,7 @@ class AddPositionalEmbedding(nn.Module): max_len: maximum possible length for the input posemb_init: positional embedding initializer """ + min_timescale: int = 1 max_timescale: int = 10_000 embedding_dim: int = 512 @@ -310,22 +328,24 @@ class AddPositionalEmbedding(nn.Module): def __call__(self, seq_length): position = jnp.arange(seq_length, dtype=jnp.float32)[jnp.newaxis, :] num_timescales = self.embedding_dim // 2 - log_timescale_increment = ( - math.log(float(self.max_timescale) / float(self.min_timescale)) / - jnp.maximum(jnp.asarray(num_timescales, dtype=jnp.float32) - 1, 1)) + log_timescale_increment = math.log( + float(self.max_timescale) / float(self.min_timescale) + ) / jnp.maximum(jnp.asarray(num_timescales, dtype=jnp.float32) - 1, 1) inv_timescales = self.min_timescale * jnp.exp( - jnp.arange(num_timescales, dtype=jnp.float32) * - -log_timescale_increment) + jnp.arange(num_timescales, dtype=jnp.float32) * -log_timescale_increment + ) scaled_time = ( - position[:, :, jnp.newaxis] * - inv_timescales[jnp.newaxis, jnp.newaxis, :]) + position[:, :, jnp.newaxis] + * inv_timescales[jnp.newaxis, jnp.newaxis, :] + ) signal = jnp.concatenate( - [jnp.sin(scaled_time), jnp.cos(scaled_time)], - axis=2).astype(jnp.float32) + [jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=2 + ).astype(jnp.float32) # Force usage of `np` rather than `jnp` to compute static values at trace # time. - signal = jnp.pad(signal, - [[0, 0], [0, 0], [0, np.mod(self.embedding_dim, 2)]]) + signal = jnp.pad( + signal, [[0, 0], [0, 0], [0, np.mod(self.embedding_dim, 2)]] + ) return signal @@ -333,6 +353,7 @@ def __call__(self, seq_length): # https://github.com/tensorflow/lingvo/blob/7de4ca8fff3cb28c2ecb21bbd7b02a964ce727f7/lingvo/jax/layers/attentions.py#L201 class QueryScaler(nn.Module): """A layer to scale individual dims of the query attention matrix.""" + dim: int = 0 def setup(self): @@ -342,8 +363,10 @@ def setup(self): def __call__(self, inputs): inputs_shape = inputs.shape if inputs_shape[-1] != self.dim: - raise ValueError('QueryScaler expects inputs to have' - ' same last dimension as scaling param.') + raise ValueError( + 'QueryScaler expects inputs to have' + ' same last dimension as scaling param.' + ) # 1.0/jax.nn.softplus(0.0) = 1.442695041. Hard code this number so that we # can avoid unnecessary XLA op fusion mess on TPU. @@ -358,18 +381,20 @@ def __call__(self, inputs): # Modifying flax linen default dot product attention function to add # query scaling, reference to original function here : # https://github.com/google/flax/blob/a9af38085a7a49b571cf37d375060fd683e74972/flax/linen/attention.py#L121 -def dot_product_attention(query, - key, - value, - bias=None, - mask=None, - broadcast_dropout=True, - dropout_rng=None, - dropout_rate=0., - deterministic=False, - dtype=jnp.float32, - precision=None, - temperature=1.0): +def dot_product_attention( + query, + key, + value, + bias=None, + mask=None, + broadcast_dropout=True, + dropout_rng=None, + dropout_rate=0.0, + deterministic=False, + dtype=jnp.float32, + precision=None, + temperature=1.0, +): """Computes dot-product attention given query, key, and value. This is the core function for applying attention based on @@ -380,21 +405,19 @@ def dot_product_attention(query, Note: query, key, value needn't have any batch dimensions. Args: - query: queries for calculating attention with shape of - `[batch..., q_length, num_heads, qk_depth_per_head]`. - key: keys for calculating attention with shape of - `[batch..., kv_length, num_heads, qk_depth_per_head]`. - value: values to be used in attention with shape of - `[batch..., kv_length, num_heads, v_depth_per_head]`. + query: queries for calculating attention with shape of `[batch..., q_length, + num_heads, qk_depth_per_head]`. + key: keys for calculating attention with shape of `[batch..., kv_length, + num_heads, qk_depth_per_head]`. + value: values to be used in attention with shape of `[batch..., kv_length, + num_heads, v_depth_per_head]`. bias: bias for the attention weights. This should be broadcastable to the - shape `[batch..., num_heads, q_length, kv_length]`. - This can be used for incorporating causal masks, padding masks, - proximity bias, etc. + shape `[batch..., num_heads, q_length, kv_length]`. This can be used for + incorporating causal masks, padding masks, proximity bias, etc. mask: mask for the attention weights. This should be broadcastable to the - shape `[batch..., num_heads, q_length, kv_length]`. - This can be used for incorporating causal masks. - Attention weights are masked out if their corresponding mask value - is `False`. + shape `[batch..., num_heads, q_length, kv_length]`. This can be used for + incorporating causal masks. Attention weights are masked out if their + corresponding mask value is `False`. broadcast_dropout: bool: use a broadcasted dropout along batch dims. dropout_rng: JAX PRNGKey: to be used for dropout dropout_rate: dropout rate @@ -408,23 +431,36 @@ def dot_product_attention(query, Output of shape `[batch..., q_length, num_heads, v_depth_per_head]`. """ assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.' - assert query.shape[:-3] == key.shape[:-3] == value.shape[:-3], ( - 'q, k, v batch dims must match.') - assert query.shape[-2] == key.shape[-2] == value.shape[-2], ( - 'q, k, v num_heads must match.') + assert ( + query.shape[:-3] == key.shape[:-3] == value.shape[:-3] + ), 'q, k, v batch dims must match.' + assert ( + query.shape[-2] == key.shape[-2] == value.shape[-2] + ), 'q, k, v num_heads must match.' assert key.shape[-3] == value.shape[-3], 'k, v lengths must match.' # compute attention weights query = QueryScaler(dim=query.shape[-1])(query) - attn_weights = nn.dot_product_attention_weights(query, key, bias, mask, - broadcast_dropout, - dropout_rng, dropout_rate, - deterministic, dtype, - precision) + attn_weights = nn.dot_product_attention_weights( + query, + key, + bias, + mask, + broadcast_dropout, + dropout_rng, + dropout_rate, + deterministic, + dtype, + precision, + ) # return weighted sum over values for each query position - return jnp.einsum('...hqk,...khd->...qhd', attn_weights, value, - precision=precision) * temperature + return ( + jnp.einsum( + '...hqk,...khd->...qhd', attn_weights, value, precision=precision + ) + * temperature + ) class MultiHeadedSelfAttention(nn.Module): @@ -436,6 +472,7 @@ class MultiHeadedSelfAttention(nn.Module): Note: this attention implementation uses a learned scale parameter to scale query matrix before passing it to flax attention module. """ + config: ConformerConfig = None def setup(self): @@ -444,7 +481,8 @@ def setup(self): num_heads=self.config.num_attention_heads, hidden_dim=self.config.encoder_dim, input_dim=self.config.encoder_dim, - dim_per_head=dim_per_head) + dim_per_head=dim_per_head, + ) def _get_large_negative_number(self, dtype): if jnp.issubdtype(dtype, jnp.inexact): @@ -467,7 +505,8 @@ def __call__(self, inputs, paddings, train): config = self.config mask_paddings = 1 - paddings attention_mask = nn.make_attention_mask( - mask_paddings > 0, mask_paddings > 0, dtype=jnp.float32) + mask_paddings > 0, mask_paddings > 0, dtype=jnp.float32 + ) inputs = LayerNorm(dim=config.encoder_dim)(inputs) @@ -478,10 +517,12 @@ def __call__(self, inputs, paddings, train): query_vec=inputs, key_vec=inputs, value_vec=inputs, - atten_mask=atten_mask)[0] + atten_mask=atten_mask, + )[0] else: attn_fn = functools.partial( - dot_product_attention, temperature=config.attn_temperature) + dot_product_attention, temperature=config.attn_temperature + ) result = nn.SelfAttention( num_heads=config.num_attention_heads, qkv_features=config.encoder_dim, @@ -493,11 +534,12 @@ def __call__(self, inputs, paddings, train): broadcast_dropout=False, attention_fn=attn_fn, dropout_rate=config.attention_dropout_rate, - deterministic=not train)(inputs, mask=attention_mask) + deterministic=not train, + )(inputs, mask=attention_mask) result = nn.Dropout( - rate=config.attention_residual_dropout_rate, deterministic=not train)( - result) + rate=config.attention_residual_dropout_rate, deterministic=not train + )(result) return result @@ -514,16 +556,19 @@ class BatchNorm(nn.Module): and the corresponding defaults for momentum and epsilon have been copied over from lingvo. """ + config: ConformerConfig def setup(self): dim = self.config.encoder_dim dtype = self.config.dtype - self.ra_mean = self.variable('batch_stats', 'mean', - lambda s: jnp.zeros(s, dtype), dim) - self.ra_var = self.variable('batch_stats', 'var', - lambda s: jnp.ones(s, dtype), dim) + self.ra_mean = self.variable( + 'batch_stats', 'mean', lambda s: jnp.zeros(s, dtype), dim + ) + self.ra_var = self.variable( + 'batch_stats', 'var', lambda s: jnp.ones(s, dtype), dim + ) self.gamma = self.param('scale', nn.initializers.zeros, dim, dtype) self.beta = self.param('bias', nn.initializers.zeros, dim, dtype) @@ -541,7 +586,8 @@ def __call__(self, inputs, input_paddings, train): mask = 1.0 - padding sum_v = jnp.sum(inputs * mask, axis=reduce_over_dims, keepdims=False) count_v = jnp.sum( - jnp.ones_like(inputs) * mask, axis=reduce_over_dims, keepdims=False) + jnp.ones_like(inputs) * mask, axis=reduce_over_dims, keepdims=False + ) count_v = jnp.maximum(count_v, 1.0) mean = sum_v / count_v @@ -549,14 +595,13 @@ def __call__(self, inputs, input_paddings, train): sum_vv = jnp.sum( (inputs - mean) * (inputs - mean) * mask, axis=reduce_over_dims, - keepdims=False) + keepdims=False, + ) var = sum_vv / count_v - self.ra_mean.value = momentum * self.ra_mean.value + ( - 1 - momentum) * mean - self.ra_var.value = momentum * self.ra_var.value + ( - 1 - momentum) * var + self.ra_mean.value = momentum * self.ra_mean.value + (1 - momentum) * mean + self.ra_var.value = momentum * self.ra_var.value + (1 - momentum) * var else: mean = self.ra_mean.value var = self.ra_var.value @@ -589,6 +634,7 @@ class ConvolutionBlock(nn.Module): | output """ + config: ConformerConfig @nn.compact @@ -599,12 +645,14 @@ def __call__(self, inputs, input_paddings, train): input_gated1 = nn.Dense( config.encoder_dim, kernel_init=nn.initializers.xavier_uniform(), - use_bias=True)(inputs) + use_bias=True, + )(inputs) input_gated2 = nn.Dense( config.encoder_dim, kernel_init=nn.initializers.xavier_uniform(), - use_bias=True)(inputs) + use_bias=True, + )(inputs) inputs = input_gated1 * jax.nn.sigmoid(input_gated2) inputs = inputs * (1 - jnp.expand_dims(input_paddings, -1)) @@ -616,17 +664,19 @@ def __call__(self, inputs, input_paddings, train): padding='SAME', feature_group_count=config.encoder_dim, use_bias=False, - kernel_init=nn.initializers.xavier_uniform())(inputs) + kernel_init=nn.initializers.xavier_uniform(), + )(inputs) inputs = BatchNorm(config)(inputs, input_paddings, train) inputs = model_utils.ACTIVATIONS[self.config.activation_function](inputs) inputs = nn.Dense( - config.encoder_dim, - kernel_init=nn.initializers.xavier_uniform())(inputs) + config.encoder_dim, kernel_init=nn.initializers.xavier_uniform() + )(inputs) inputs = nn.Dropout( - rate=config.conv_residual_dropout_rate, deterministic=not train)(inputs) + rate=config.conv_residual_dropout_rate, deterministic=not train + )(inputs) return inputs @@ -641,8 +691,8 @@ class ConformerBlock(nn.Module): x = x + 0.5 * FeedForward(x) y = layer_norm(x) - """ + config: ConformerConfig @nn.compact @@ -651,15 +701,18 @@ def __call__(self, inputs, input_paddings, train): padding_mask = jnp.expand_dims(1 - input_paddings, -1) inputs = inputs + 0.5 * FeedForwardModule(config=self.config)( - inputs, padding_mask, train) + inputs, padding_mask, train + ) inputs = inputs + MultiHeadedSelfAttention(config=self.config)( - inputs, input_paddings, train) + inputs, input_paddings, train + ) inputs = inputs + ConvolutionBlock(config)(inputs, input_paddings, train) inputs = inputs + 0.5 * FeedForwardModule(config=self.config)( - inputs, padding_mask, train) + inputs, padding_mask, train + ) if config.enable_conformer_post_layer_norm: inputs = LayerNorm(dim=config.encoder_dim)(inputs) @@ -674,6 +727,7 @@ class ConformerEncoderDecoder(nn.Module): for each time step. The output is then fed into a CTC loss which eliminates the need for alignment with targets. """ + config: ConformerConfig def setup(self): @@ -685,8 +739,8 @@ def setup(self): time_mask_max_frames=config.time_mask_max_frames, time_mask_max_ratio=config.time_mask_max_ratio, time_masks_per_frame=config.time_masks_per_frame, - use_dynamic_time_mask_max_frames=config - .use_dynamic_time_mask_max_frames) + use_dynamic_time_mask_max_frames=config.use_dynamic_time_mask_max_frames, + ) @nn.compact def __call__(self, inputs, input_paddings, train): @@ -700,8 +754,8 @@ def __call__(self, inputs, input_paddings, train): outputs, output_paddings = preprocessor.MelFilterbankFrontend( preprocessing_config, per_bin_mean=preprocessor.LIBRISPEECH_MEAN_VECTOR, - per_bin_stddev=preprocessor.LIBRISPEECH_STD_VECTOR)( - outputs, output_paddings) + per_bin_stddev=preprocessor.LIBRISPEECH_STD_VECTOR, + )(outputs, output_paddings) # Ablate random parts of input along temporal and frequency dimension # following the specaug procedure in https://arxiv.org/abs/1904.08779. @@ -711,8 +765,8 @@ def __call__(self, inputs, input_paddings, train): # Subsample input by a factor of 4 by performing strided convolutions. outputs, output_paddings = Subsample( encoder_dim=config.encoder_dim, - input_dropout_rate=config.input_dropout_rate)(outputs, output_paddings, - train) + input_dropout_rate=config.input_dropout_rate, + )(outputs, output_paddings, train) # Run the conformer encoder layers. for _ in range(config.num_encoder_layers): @@ -724,7 +778,8 @@ def __call__(self, inputs, input_paddings, train): outputs = nn.Dense( config.vocab_size, use_bias=True, - kernel_init=nn.initializers.xavier_uniform())(outputs) + kernel_init=nn.initializers.xavier_uniform(), + )(outputs) return outputs, output_paddings @@ -751,10 +806,13 @@ def collapse_and_remove_blanks(self, labels, seq_length, blank_id: int = 0): labels = (labels * blank_mask).astype(labels.dtype) # Mask labels that don't equal previous label. - label_mask = jnp.concatenate([ - jnp.ones_like(labels[:, :1], dtype=jnp.int32), - jnp.not_equal(labels[:, 1:], labels[:, :-1]) - ], axis=1) + label_mask = jnp.concatenate( + [ + jnp.ones_like(labels[:, :1], dtype=jnp.int32), + jnp.not_equal(labels[:, 1:], labels[:, :-1]), + ], + axis=1, + ) # Filter labels that aren't in the original sequence. maxlen = labels.shape[1] @@ -790,8 +848,10 @@ def collapse_and_remove_blanks(self, labels, seq_length, blank_id: int = 0): # Reshape back to square batch. batch_size = labels.shape[0] new_shape = [batch_size, new_maxlen] - return (jnp.reshape(flat, new_shape).astype(labels.dtype), - new_seq_len.astype(seq_length.dtype)) + return ( + jnp.reshape(flat, new_shape).astype(labels.dtype), + new_seq_len.astype(seq_length.dtype), + ) def greedy_decode(self, logits, logit_paddings): per_frame_max = jnp.argmax(logits, axis=-1) @@ -804,22 +864,21 @@ def evaluate_batch(self, params, batch_stats, batch): """Evaluates cross_entopy on the given batch.""" logits, logit_paddings = self.flax_module.apply( - { - 'params': params, - 'batch_stats': batch_stats - }, + {'params': params, 'batch_stats': batch_stats}, batch['inputs'], batch['input_paddings'], train=False, - mutable=False) + mutable=False, + ) labels = batch['targets'] label_paddings = batch['target_paddings'] - (objective_numerator, objective_denominator) = self.loss_fn( - logits, logit_paddings, labels, label_paddings) + objective_numerator, objective_denominator = self.loss_fn( + logits, logit_paddings, labels, label_paddings + ) - normalized_loss = (objective_numerator / (objective_denominator)) + normalized_loss = objective_numerator / (objective_denominator) hyps, hyp_paddings = self.greedy_decode(logits, logit_paddings) return self.metrics_bundle.single_from_model_output( @@ -827,7 +886,8 @@ def evaluate_batch(self, params, batch_stats, batch): hyps=hyps, hyp_paddings=hyp_paddings, targets=labels, - target_paddings=label_paddings) + target_paddings=label_paddings, + ) def training_cost(self, params, batch, batch_stats=None, dropout_rng=None): """Return CTC loss.""" @@ -835,24 +895,23 @@ def training_cost(self, params, batch, batch_stats=None, dropout_rng=None): # For more information on flax.linen.Module.apply, see the docs at # https://flax.readthedocs.io/en/latest/flax.linen.html#flax.linen.Module.apply. (outputs, output_paddings), new_batch_stats = self.flax_module.apply( - { - 'params': params, - 'batch_stats': batch_stats - }, + {'params': params, 'batch_stats': batch_stats}, batch['inputs'], batch['input_paddings'], rngs={'dropout': dropout_rng}, mutable=['batch_stats'], - train=True) + train=True, + ) labels = batch['targets'] label_paddings = batch['target_paddings'] - (objective_numerator, objective_denominator) = self.loss_fn( - outputs, output_paddings, labels, label_paddings) + objective_numerator, objective_denominator = self.loss_fn( + outputs, output_paddings, labels, label_paddings + ) # epsilon added to handle empty batch case if we encounter one. - objective_value = (objective_numerator / (objective_denominator + 1e-9)) + objective_value = objective_numerator / (objective_denominator + 1e-9) return objective_value, new_batch_stats def apply_on_batch(self, params, batch_stats, batch, **apply_kwargs): @@ -863,10 +922,8 @@ def apply_on_batch(self, params, batch_stats, batch, **apply_kwargs): variables = {'params': params} return self.flax_module.apply( - variables, - batch['inputs'], - batch['input_paddings'], - **apply_kwargs) + variables, batch['inputs'], batch['input_paddings'], **apply_kwargs + ) def build_flax_module(self): config = ConformerConfig( @@ -881,18 +938,16 @@ def build_flax_module(self): time_mask_max_frames=self.hps.time_mask_max_frames, time_mask_max_ratio=self.hps.time_mask_max_ratio, time_masks_per_frame=self.hps.time_masks_per_frame, - use_dynamic_time_mask_max_frames=self.hps - .use_dynamic_time_mask_max_frames, + use_dynamic_time_mask_max_frames=self.hps.use_dynamic_time_mask_max_frames, use_specaug=self.hps.use_specaug, attention_residual_dropout_rate=self.hps.residual_dropout_rate, feed_forward_residual_dropout_rate=self.hps.residual_dropout_rate, input_dropout_rate=self.hps.input_dropout_rate, - enable_conformer_post_layer_norm=self.hps - .enable_conformer_post_layer_norm, + enable_conformer_post_layer_norm=self.hps.enable_conformer_post_layer_norm, enable_decoder_pre_layer_norm=self.hps.enable_decoder_pre_layer_norm, use_lingvo_attention=self.hps.use_lingvo_attention, activation_function=self.hps.activation_function, - ) + ) module = ConformerEncoderDecoder(config) return module @@ -933,19 +988,17 @@ def build_flax_module(self): time_mask_max_frames=self.hps.time_mask_max_frames, time_mask_max_ratio=self.hps.time_mask_max_ratio, time_masks_per_frame=self.hps.time_masks_per_frame, - use_dynamic_time_mask_max_frames=self.hps - .use_dynamic_time_mask_max_frames, + use_dynamic_time_mask_max_frames=self.hps.use_dynamic_time_mask_max_frames, use_specaug=self.hps.use_specaug, attention_residual_dropout_rate=self.hps.dropout_rate, feed_forward_residual_dropout_rate=self.hps.dropout_rate, input_dropout_rate=aux_dropout_rate, - enable_conformer_post_layer_norm=self.hps - .enable_conformer_post_layer_norm, + enable_conformer_post_layer_norm=self.hps.enable_conformer_post_layer_norm, enable_decoder_pre_layer_norm=self.hps.enable_decoder_pre_layer_norm, use_lingvo_attention=self.hps.use_lingvo_attention, attn_temperature=self.hps.attn_temperature, activation_function=self.hps.activation_function, - ) + ) module = ConformerEncoderDecoder(config) return module diff --git a/init2winit/model_lib/convolutional_autoencoder.py b/init2winit/model_lib/convolutional_autoencoder.py index 607b9bf5..18d89588 100644 --- a/init2winit/model_lib/convolutional_autoencoder.py +++ b/init2winit/model_lib/convolutional_autoencoder.py @@ -28,7 +28,6 @@ from jax import numpy as jnp from ml_collections.config_dict import config_dict - # small test hparams from # https://blog.keras.io/building-autoencoders-in-keras.html DEFAULT_HPARAMS = config_dict.ConfigDict( @@ -51,7 +50,8 @@ }, activation_function='relu', model_dtype='float32', - )) + ) +) class ConvAutoEncoder(nn.Module): @@ -61,6 +61,7 @@ class ConvAutoEncoder(nn.Module): [batch_size_per_device, *input_shape] where input_shape may be of arbitrary rank. The model flatten the input before applying a dense layer. """ + output_shape: Sequence[int] encoder: Dict[str, Any] decoder: Dict[str, Any] @@ -79,7 +80,8 @@ def __call__(self, x, train): ] if len(set(len(self.encoder[k]) for k in encoder_keys)) > 1: raise ValueError( - 'The elements in encoder dict do not have the same length.') + 'The elements in encoder dict do not have the same length.' + ) decoder_keys = [ 'filter_sizes', @@ -90,19 +92,23 @@ def __call__(self, x, train): ] if len(set(len(self.decoder[k]) for k in decoder_keys)) > 1: raise ValueError( - 'The elements in decoder dict do not have the same length.') + 'The elements in decoder dict do not have the same length.' + ) # encoder for i in range(len(self.encoder['filter_sizes'])): x = nn.Conv( self.encoder['filter_sizes'][i], self.encoder['kernel_sizes'][i], - padding=self.encoder['kernel_paddings'][i])(x) + padding=self.encoder['kernel_paddings'][i], + )(x) x = model_utils.ACTIVATIONS[self.encoder['activations'][i]](x) x = nn.max_pool( - x, self.encoder['window_sizes'][i], + x, + self.encoder['window_sizes'][i], strides=self.encoder['strides'][i], - padding=self.encoder['window_paddings'][i]) + padding=self.encoder['window_paddings'][i], + ) # decoder for i in range(len(self.decoder['filter_sizes'])): @@ -110,7 +116,8 @@ def __call__(self, x, train): self.decoder['filter_sizes'][i], self.decoder['kernel_sizes'][i], self.decoder['window_sizes'][i], - padding=self.decoder['paddings'][i])(x) + padding=self.decoder['paddings'][i], + )(x) x = model_utils.ACTIVATIONS[self.decoder['activations'][i]](x) return x @@ -122,7 +129,8 @@ def build_flax_module(self): return ConvAutoEncoder( output_shape=self.hps.output_shape, encoder=self.hps.encoder, - decoder=self.hps.decoder) + decoder=self.hps.decoder, + ) def get_fake_inputs(self, hps): """Helper method solely for the purpose of initialzing the model.""" diff --git a/init2winit/model_lib/deepspeech.py b/init2winit/model_lib/deepspeech.py index 89b33b8b..a80fe5b5 100644 --- a/init2winit/model_lib/deepspeech.py +++ b/init2winit/model_lib/deepspeech.py @@ -37,7 +37,6 @@ import jax.numpy as jnp from ml_collections.config_dict import config_dict - Array = jnp.ndarray StateType = Union[Array, Tuple[Array, ...]] PRNGKey = Any @@ -68,7 +67,9 @@ bidirectional=True, enable_subsampling_batchnorm=False, enable_synced_batchnorm=False, - layernorm_everywhere=False)) + layernorm_everywhere=False, + ) +) DEFAULT_HPARAMS = config_dict.ConfigDict( @@ -94,12 +95,15 @@ bidirectional=True, enable_subsampling_batchnorm=False, enable_synced_batchnorm=False, - layernorm_everywhere=False)) + layernorm_everywhere=False, + ) +) @struct.dataclass class DeepspeechConfig: """Global hyperparameters used to minimize obnoxious kwarg plumbing.""" + vocab_size: int = 0 dtype: Any = jnp.float32 encoder_dim: int = 0 @@ -136,6 +140,7 @@ class Subsample(nn.Module): encoder_dim: model dimension of conformer. input_dropout_rate: dropout rate for inputs. """ + config: DeepspeechConfig @nn.compact @@ -164,21 +169,24 @@ def __call__(self, inputs, output_paddings, train): output_channels=config.encoder_dim, enable_batchnorm=config.enable_subsampling_batchnorm, enable_synced_batchnorm=config.enable_synced_batchnorm, - activation=config.activation + activation=config.activation, )(outputs, output_paddings, train) batch_size, subsampled_lengths, subsampled_dims, channels = outputs.shape outputs = jnp.reshape( - outputs, (batch_size, subsampled_lengths, subsampled_dims * channels)) + outputs, (batch_size, subsampled_lengths, subsampled_dims * channels) + ) outputs = nn.Dense( config.encoder_dim, use_bias=True, - kernel_init=nn.initializers.xavier_uniform())(outputs) + kernel_init=nn.initializers.xavier_uniform(), + )(outputs) outputs = nn.Dropout( - rate=config.input_dropout_rate, deterministic=not train)(outputs) + rate=config.input_dropout_rate, deterministic=not train + )(outputs) return outputs, output_paddings @@ -190,6 +198,7 @@ class Conv2dSubsampling(nn.Module): 2) Also performs strided convolution over input_paddings to return the correct paddings for downstream layers. """ + input_channels: int = 0 output_channels: int = 0 filter_stride: List[int] = (2, 2) @@ -204,10 +213,12 @@ class Conv2dSubsampling(nn.Module): def setup(self): self.filter_shape = (3, 3, self.input_channels, self.output_channels) - self.kernel = self.param('kernel', nn.initializers.xavier_uniform(), - self.filter_shape) - self.bias = self.param('bias', lambda rng, s: jnp.zeros(s, jnp.float32), - self.output_channels) + self.kernel = self.param( + 'kernel', nn.initializers.xavier_uniform(), self.filter_shape + ) + self.bias = self.param( + 'bias', lambda rng, s: jnp.zeros(s, jnp.float32), self.output_channels + ) @nn.compact def __call__(self, inputs, paddings, train): @@ -220,15 +231,19 @@ def __call__(self, inputs, paddings, train): padding=self.padding, rhs_dilation=(1, 1), dimension_numbers=('NHWC', 'HWIO', 'NHWC'), - feature_group_count=feature_group_count) + feature_group_count=feature_group_count, + ) outputs += jnp.reshape(self.bias, (1,) * (outputs.ndim - 1) + (-1,)) if self.enable_batchnorm: - outputs = BatchNorm(self.encoder_dim, self.dtype, - self.batch_norm_momentum, self.batch_norm_epsilon, - self.enable_synced_batchnorm)( - outputs, input_paddings=None, train=train) + outputs = BatchNorm( + self.encoder_dim, + self.dtype, + self.batch_norm_momentum, + self.batch_norm_epsilon, + self.enable_synced_batchnorm, + )(outputs, input_paddings=None, train=train) if self.activation in model_utils.ACTIVATIONS: outputs = model_utils.ACTIVATIONS[self.activation](outputs) @@ -245,19 +260,22 @@ def __call__(self, inputs, paddings, train): rhs=jnp.ones([1, 1, 1]), window_strides=self.filter_stride[:1], padding=[(0, pad_len)], - dimension_numbers=('NHC', 'HIO', 'NHC')) + dimension_numbers=('NHC', 'HIO', 'NHC'), + ) out_padding = jnp.squeeze(out_padding, axis=-1) # Mask outputs by correct paddings to ensure padded elements in inputs map # to padded value in outputs. - outputs = outputs * (1.0 - - jnp.expand_dims(jnp.expand_dims(out_padding, -1), -1)) + outputs = outputs * ( + 1.0 - jnp.expand_dims(jnp.expand_dims(out_padding, -1), -1) + ) return outputs, out_padding class FeedForwardModule(nn.Module): """Feedforward block of conformer layer.""" + config: DeepspeechConfig @nn.compact @@ -273,12 +291,14 @@ def __call__(self, inputs, input_paddings=None, train=False): config.dtype, config.batch_norm_momentum, config.batch_norm_epsilon, - config.enable_synced_batchnorm)(inputs, input_paddings, train) + config.enable_synced_batchnorm, + )(inputs, input_paddings, train) inputs = nn.Dense( config.encoder_dim, use_bias=True, - kernel_init=nn.initializers.xavier_uniform())(inputs) + kernel_init=nn.initializers.xavier_uniform(), + )(inputs) if config.activation in model_utils.ACTIVATIONS: inputs = model_utils.ACTIVATIONS[config.activation](inputs) @@ -287,7 +307,8 @@ def __call__(self, inputs, input_paddings=None, train=False): inputs *= padding_mask inputs = nn.Dropout(rate=config.feed_forward_dropout_rate)( - inputs, deterministic=not train) + inputs, deterministic=not train + ) return inputs @@ -302,6 +323,7 @@ class LayerNorm(nn.Module): zeros, this differs from default flax implementation of multiplying by scale and initializing to ones. """ + dim: int = 0 epsilon: float = 1e-6 @@ -333,6 +355,7 @@ class BatchNorm(nn.Module): and the corresponding defaults for momentum and epsilon have been copied over from lingvo. """ + encoder_dim: int = 0 dtype: Any = jnp.float32 batch_norm_momentum: float = 0.999 @@ -343,10 +366,12 @@ def setup(self): dim = self.encoder_dim dtype = self.dtype - self.ra_mean = self.variable('batch_stats', 'mean', - lambda s: jnp.zeros(s, dtype), dim) - self.ra_var = self.variable('batch_stats', 'var', - lambda s: jnp.ones(s, dtype), dim) + self.ra_mean = self.variable( + 'batch_stats', 'mean', lambda s: jnp.zeros(s, dtype), dim + ) + self.ra_var = self.variable( + 'batch_stats', 'var', lambda s: jnp.ones(s, dtype), dim + ) self.gamma = self.param('scale', nn.initializers.zeros, dim, dtype) self.beta = self.param('bias', nn.initializers.zeros, dim, dtype) @@ -375,7 +400,8 @@ def __call__(self, inputs, input_paddings=None, train=False): mask = 1.0 - padding sum_v = jnp.sum(inputs * mask, axis=reduce_over_dims, keepdims=False) count_v = jnp.sum( - jnp.ones_like(inputs) * mask, axis=reduce_over_dims, keepdims=False) + jnp.ones_like(inputs) * mask, axis=reduce_over_dims, keepdims=False + ) count_v = jnp.maximum(count_v, 1.0) mean = sum_v / count_v @@ -383,7 +409,8 @@ def __call__(self, inputs, input_paddings=None, train=False): sum_vv = jnp.sum( (inputs - mean) * (inputs - mean) * mask, axis=reduce_over_dims, - keepdims=False) + keepdims=False, + ) var = sum_vv / count_v @@ -402,6 +429,7 @@ def __call__(self, inputs, input_paddings=None, train=False): + @jax.vmap def flip_sequences(inputs: Array, lengths: Array) -> Array: """Flips a sequence of inputs along the time dimension. @@ -451,6 +479,7 @@ class GenericRNNSequenceEncoder(nn.Module): greater than zero, you must use an RNN cell that implements `RecurrentDropoutCell` such as RecurrentDropoutOptimizedLSTMCell. """ + hidden_size: int cell_type: Type[nn.RNNCellBase] cell_kwargs: Mapping[str, Any] = flax.core.FrozenDict() @@ -464,9 +493,15 @@ def setup(self): variable_broadcast='params', in_axes=(1, flax.core.axes_scan.broadcast, flax.core.axes_scan.broadcast), out_axes=1, - split_rngs={'params': False}) - def unroll_cell(self, cell_state: StateType, inputs: Array, - recurrent_dropout_mask: Optional[Array], deterministic: bool): + split_rngs={'params': False}, + ) + def unroll_cell( + self, + cell_state: StateType, + inputs: Array, + recurrent_dropout_mask: Optional[Array], + deterministic: bool, + ): """Unrolls a recurrent cell over an input sequence. Args: @@ -487,12 +522,14 @@ def unroll_cell(self, cell_state: StateType, inputs: Array, new_cell_state, output = self.cell(cell_state, inputs) return new_cell_state, (new_cell_state, output) - def __call__(self, - inputs: Array, - lengths: Array, - initial_state: StateType, - reverse: bool = False, - deterministic: bool = False): + def __call__( + self, + inputs: Array, + lengths: Array, + initial_state: StateType, + reverse: bool = False, + deterministic: bool = False, + ): """Unrolls the RNN cell over the inputs. Arguments: @@ -516,11 +553,12 @@ def __call__(self, inputs = flip_sequences(inputs, lengths) recurrent_dropout_mask = None - _, (cell_states, outputs) = self.unroll_cell(initial_state, inputs, - recurrent_dropout_mask, - deterministic) + _, (cell_states, outputs) = self.unroll_cell( + initial_state, inputs, recurrent_dropout_mask, deterministic + ) final_state = jax.tree.map( - lambda x: x[jnp.arange(inputs.shape[0]), lengths - 1], cell_states) + lambda x: x[jnp.arange(inputs.shape[0]), lengths - 1], cell_states + ) if reverse: outputs = flip_sequences(outputs, lengths) @@ -550,11 +588,12 @@ class GenericRNN(nn.Module): concatenate the outputs from the two directions. cell_kwargs: Optional keyword arguments to instantiate the cell with. """ + cell_type: Type[nn.RNNCellBase] hidden_size: int num_layers: int = 1 - dropout_rate: float = 0. - recurrent_dropout_rate: float = 0. + dropout_rate: float = 0.0 + recurrent_dropout_rate: float = 0.0 bidirectional: bool = False cell_kwargs: Mapping[str, Any] = flax.core.FrozenDict() @@ -564,7 +603,7 @@ def __call__( inputs: Array, lengths: Array, initial_states: Optional[Sequence[StateType]] = None, - deterministic: bool = False + deterministic: bool = False, ) -> Tuple[Array, Sequence[StateType]]: """Processes the input sequence using the recurrent cell. @@ -573,11 +612,11 @@ def __call__( lengths: The lengths of each sequence in the batch. [batch_size] initial_states: The initial states for the cells. You must provide `num_layers` initial states (when using bidirectional, `num_layers * - 2`). - These must be ordered in the following way: (layer_0_forward, - layer_0_backward, layer_1_forward, layer_1_backward, ...). If None, - all initial states will be initialized with zeros. + 2`). These must be ordered in the following way: (layer_0_forward, + layer_0_backward, layer_1_forward, layer_1_backward, ...). If None, all + initial states will be initialized with zeros. deterministic: Disables dropout between layers when set to True. + Returns: The sequence of all outputs for the final layer, and a list of final states for each cell and direction. Directions are alternated (first @@ -616,11 +655,13 @@ def __call__( cell_kwargs=self.cell_kwargs, hidden_size=self.hidden_size, recurrent_dropout_rate=self.recurrent_dropout_rate, - name=f'{self.name}SequenceEncoder_{cell_idx}')( - inputs, - lengths, - initial_state=initial_states[cell_idx], - deterministic=deterministic) + name=f'{self.name}SequenceEncoder_{cell_idx}', + )( + inputs, + lengths, + initial_state=initial_states[cell_idx], + deterministic=deterministic, + ) final_states.append(final_state) cell_idx += 1 @@ -631,12 +672,14 @@ def __call__( cell_kwargs=self.cell_kwargs, hidden_size=self.hidden_size, recurrent_dropout_rate=self.recurrent_dropout_rate, - name=f'{self.name}SequenceEncoder_{cell_idx}')( - inputs, - lengths, - initial_state=initial_states[cell_idx], - reverse=True, - deterministic=deterministic) + name=f'{self.name}SequenceEncoder_{cell_idx}', + )( + inputs, + lengths, + initial_state=initial_states[cell_idx], + reverse=True, + deterministic=deterministic, + ) outputs = jnp.concatenate([outputs, backward_outputs], axis=-1) final_states.append(backward_final_state) cell_idx += 1 @@ -665,10 +708,11 @@ class LSTM(nn.Module): best for hidden sizes up to 2048. cell_kwargs: Optional keyword arguments to instantiate the cell with. """ + hidden_size: int num_layers: int = 1 - dropout_rate: float = 0. - recurrent_dropout_rate: float = 0. + dropout_rate: float = 0.0 + recurrent_dropout_rate: float = 0.0 bidirectional: bool = False cell_type: Any = nn.OptimizedLSTMCell cell_kwargs: Mapping[str, Any] = flax.core.FrozenDict() @@ -679,7 +723,8 @@ def __call__( inputs: Array, lengths: Array, initial_states: Optional[Sequence[StateType]] = None, - deterministic: bool = False) -> Tuple[Array, Sequence[StateType]]: + deterministic: bool = False, + ) -> Tuple[Array, Sequence[StateType]]: """Processes an input sequence with an LSTM cell. Example usage: @@ -695,8 +740,8 @@ def __call__( initial_states: The initial states for the cells. You must provide `num_layers` initial states (when using bidirectional, `num_layers * 2`). These must be ordered in the following way: (layer_0_forward, - layer_0_backward, layer_1_forward, layer_1_backward, ...). If None, - all initial states will be initialized with zeros. + layer_0_backward, layer_1_forward, layer_1_backward, ...). If None, all + initial states will be initialized with zeros. deterministic: Disables dropout between layers when set to True. Returns: @@ -712,11 +757,13 @@ def __call__( recurrent_dropout_rate=self.recurrent_dropout_rate, bidirectional=self.bidirectional, cell_kwargs=self.cell_kwargs, - name='LSTM')( - inputs, - lengths, - initial_states=initial_states, - deterministic=deterministic) + name='LSTM', + )( + inputs, + lengths, + initial_states=initial_states, + deterministic=deterministic, + ) class BatchRNN(nn.Module): @@ -724,6 +771,7 @@ class BatchRNN(nn.Module): High level overview: """ + config: DeepspeechConfig @nn.compact @@ -763,6 +811,7 @@ class DeepSpeechEncoderDecoder(nn.Module): for each time step. The output is then fed into a CTC loss which eliminates the need for alignment with targets. """ + config: DeepspeechConfig def setup(self): @@ -774,7 +823,7 @@ def setup(self): time_mask_max_frames=config.time_mask_max_frames, time_mask_max_ratio=config.time_mask_max_ratio, time_masks_per_frame=config.time_masks_per_frame, - use_dynamic_time_mask_max_frames=config.use_dynamic_time_mask_max_frames + use_dynamic_time_mask_max_frames=config.use_dynamic_time_mask_max_frames, ) @nn.compact @@ -789,8 +838,8 @@ def __call__(self, inputs, input_paddings, train): outputs, output_paddings = preprocessor.MelFilterbankFrontend( preprocessing_config, per_bin_mean=preprocessor.LIBRISPEECH_MEAN_VECTOR, - per_bin_stddev=preprocessor.LIBRISPEECH_STD_VECTOR)(outputs, - output_paddings) + per_bin_stddev=preprocessor.LIBRISPEECH_STD_VECTOR, + )(outputs, output_paddings) # Ablate random parts of input along temporal and frequency dimension # following the specaug procedure in https://arxiv.org/abs/1904.08779. @@ -805,13 +854,9 @@ def __call__(self, inputs, input_paddings, train): # Run the lstm layers. for _ in range(config.num_lstm_layers): if config.enable_residual_connections: - outputs = outputs + BatchRNN(config)( - outputs, output_paddings, train - ) + outputs = outputs + BatchRNN(config)(outputs, output_paddings, train) else: - outputs = BatchRNN(config)( - outputs, output_paddings, train - ) + outputs = BatchRNN(config)(outputs, output_paddings, train) for _ in range(config.num_ffn_layers): if config.enable_residual_connections: @@ -830,7 +875,8 @@ def __call__(self, inputs, input_paddings, train): outputs = nn.Dense( config.vocab_size, use_bias=True, - kernel_init=nn.initializers.xavier_uniform())(outputs) + kernel_init=nn.initializers.xavier_uniform(), + )(outputs) return outputs, output_paddings @@ -857,11 +903,13 @@ def collapse_and_remove_blanks(self, labels, seq_length, blank_id: int = 0): labels = (labels * blank_mask).astype(labels.dtype) # Mask labels that don't equal previous label. - label_mask = jnp.concatenate([ - jnp.ones_like(labels[:, :1], dtype=jnp.int32), - jnp.not_equal(labels[:, 1:], labels[:, :-1]) - ], - axis=1) + label_mask = jnp.concatenate( + [ + jnp.ones_like(labels[:, :1], dtype=jnp.int32), + jnp.not_equal(labels[:, 1:], labels[:, :-1]), + ], + axis=1, + ) # Filter labels that aren't in the original sequence. maxlen = labels.shape[1] @@ -897,8 +945,10 @@ def collapse_and_remove_blanks(self, labels, seq_length, blank_id: int = 0): # Reshape back to square batch. batch_size = labels.shape[0] new_shape = [batch_size, new_maxlen] - return (jnp.reshape(flat, new_shape).astype(labels.dtype), - new_seq_len.astype(seq_length.dtype)) + return ( + jnp.reshape(flat, new_shape).astype(labels.dtype), + new_seq_len.astype(seq_length.dtype), + ) def greedy_decode(self, logits, logit_paddings): per_frame_max = jnp.argmax(logits, axis=-1) @@ -921,11 +971,12 @@ def evaluate_batch(self, params, batch_stats, batch): labels = batch['targets'] label_paddings = batch['target_paddings'] - (objective_numerator, objective_denominator) = self.loss_fn( - logits, logit_paddings, labels, label_paddings) + objective_numerator, objective_denominator = self.loss_fn( + logits, logit_paddings, labels, label_paddings + ) # epsilon added to handle empty batch case if we encounter one. - normalized_loss = (objective_numerator / (objective_denominator + 1e-9)) + normalized_loss = objective_numerator / (objective_denominator + 1e-9) hyps, hyp_paddings = self.greedy_decode(logits, logit_paddings) return self.metrics_bundle.single_from_model_output( @@ -933,18 +984,17 @@ def evaluate_batch(self, params, batch_stats, batch): hyps=hyps, hyp_paddings=hyp_paddings, targets=labels, - target_paddings=label_paddings) + target_paddings=label_paddings, + ) def apply_on_batch(self, params, batch_stats, batch, **apply_kwargs): """Wrapper around flax_module.apply.""" return self.flax_module.apply( - { - 'params': params, - 'batch_stats': batch_stats - }, + {'params': params, 'batch_stats': batch_stats}, batch['inputs'], batch['input_paddings'], - **apply_kwargs) + **apply_kwargs, + ) def training_cost(self, params, batch, batch_stats=None, dropout_rng=None): """Return CTC loss.""" @@ -952,23 +1002,22 @@ def training_cost(self, params, batch, batch_stats=None, dropout_rng=None): # For more information on flax.linen.Module.apply, see the docs at # https://flax.readthedocs.io/en/latest/flax.linen.html#flax.linen.Module.apply. (outputs, output_paddings), new_batch_stats = self.flax_module.apply( - { - 'params': params, - 'batch_stats': batch_stats - }, + {'params': params, 'batch_stats': batch_stats}, batch['inputs'], batch['input_paddings'], rngs={'dropout': dropout_rng}, mutable=['batch_stats'], - train=True) + train=True, + ) labels = batch['targets'] label_paddings = batch['target_paddings'] - (objective_numerator, objective_denominator) = self.loss_fn( - outputs, output_paddings, labels, label_paddings) + objective_numerator, objective_denominator = self.loss_fn( + outputs, output_paddings, labels, label_paddings + ) - objective_value = (objective_numerator / (objective_denominator)) + objective_value = objective_numerator / (objective_denominator) return objective_value, new_batch_stats def build_flax_module(self): @@ -983,8 +1032,7 @@ def build_flax_module(self): time_mask_max_frames=self.hps.time_mask_max_frames, time_mask_max_ratio=self.hps.time_mask_max_ratio, time_masks_per_frame=self.hps.time_masks_per_frame, - use_dynamic_time_mask_max_frames=self.hps - .use_dynamic_time_mask_max_frames, + use_dynamic_time_mask_max_frames=self.hps.use_dynamic_time_mask_max_frames, use_specaug=self.hps.use_specaug, input_dropout_rate=self.hps.input_dropout_rate, feed_forward_dropout_rate=self.hps.feed_forward_dropout_rate, @@ -994,7 +1042,8 @@ def build_flax_module(self): enable_subsampling_batchnorm=self.hps.enable_subsampling_batchnorm, enable_synced_batchnorm=self.hps.enable_synced_batchnorm, activation=self.hps.activation, - layernorm_everywhere=self.hps.layernorm_everywhere) + layernorm_everywhere=self.hps.layernorm_everywhere, + ) module = DeepSpeechEncoderDecoder(config) return module @@ -1034,8 +1083,7 @@ def build_flax_module(self): time_mask_max_frames=self.hps.time_mask_max_frames, time_mask_max_ratio=self.hps.time_mask_max_ratio, time_masks_per_frame=self.hps.time_masks_per_frame, - use_dynamic_time_mask_max_frames=self.hps - .use_dynamic_time_mask_max_frames, + use_dynamic_time_mask_max_frames=self.hps.use_dynamic_time_mask_max_frames, use_specaug=self.hps.use_specaug, input_dropout_rate=aux_dropout_rate, feed_forward_dropout_rate=self.hps.dropout_rate, @@ -1045,7 +1093,8 @@ def build_flax_module(self): enable_subsampling_batchnorm=self.hps.enable_subsampling_batchnorm, enable_synced_batchnorm=self.hps.enable_synced_batchnorm, activation=self.hps.activation, - layernorm_everywhere=self.hps.layernorm_everywhere) + layernorm_everywhere=self.hps.layernorm_everywhere, + ) module = DeepSpeechEncoderDecoder(config) return module diff --git a/init2winit/model_lib/dlrm.py b/init2winit/model_lib/dlrm.py index 786b37f5..888d9f3d 100644 --- a/init2winit/model_lib/dlrm.py +++ b/init2winit/model_lib/dlrm.py @@ -20,7 +20,6 @@ import flax.linen as nn from init2winit.model_lib import base_model from init2winit.model_lib import model_utils - import jax from jax import nn as jnn import jax.numpy as jnp @@ -44,7 +43,8 @@ dropout_rate=0.0, normalizer='none', # dropout will exist only if there are at least two top mlp layers - )) + ) +) DEFAULT_RESNET_HPARAMS = DEFAULT_HPARAMS.copy_and_resolve_references() DEFAULT_RESNET_HPARAMS.mlp_top_dims = [128, 128, 1] @@ -106,7 +106,8 @@ class DLRM(nn.Module): def get_embeddings(self, embedding_table_block, indices_block): """Get embeddings for a block of indices.""" embedding_table_block = jax.lax.all_gather( - embedding_table_block, 'devices', axis=1, tiled=True) + embedding_table_block, 'devices', axis=1, tiled=True + ) embeddings = jnp.take(embedding_table_block, indices_block, axis=0) return embeddings @@ -117,9 +118,13 @@ def __call__(self, x, train): mesh=model_utils.get_default_mesh(), in_specs=( P(None, 'devices'), - P('devices',), + P( + 'devices', + ), + ), + out_specs=P( + 'devices', ), - out_specs=P('devices',), ) bot_mlp_input, cat_features = jnp.split(x, [self.num_dense_features], 1) cat_features = jnp.asarray(cat_features, dtype=jnp.int32) @@ -138,8 +143,9 @@ def __call__(self, x, train): bot_mlp_input = normalizer_layer()(bot_mlp_input) bot_mlp_output = bot_mlp_input batch_size = bot_mlp_output.shape[0] - feature_stack = jnp.reshape(bot_mlp_output, - [batch_size, -1, self.embed_dim]) + feature_stack = jnp.reshape( + bot_mlp_output, [batch_size, -1, self.embed_dim] + ) base_init_fn = jnn.initializers.uniform(scale=1.0) if self.embedding_init_multiplier is None: @@ -148,21 +154,23 @@ def __call__(self, x, train): embedding_init_multiplier = self.embedding_init_multiplier # Embedding table init and lookup for a single unified table. idx_lookup = cat_features % self.vocab_size + def scaled_init(key, shape, dtype=jnp.float_): return base_init_fn(key, shape, dtype) * embedding_init_multiplier embedding_table = self.param( - 'embedding_table', - scaled_init, - [self.vocab_size, self.embed_dim]) + 'embedding_table', scaled_init, [self.vocab_size, self.embed_dim] + ) embed_features = shmapped_get_embeddings(embedding_table, idx_lookup) embed_features = normalizer_layer()(embed_features) feature_stack = jnp.concatenate([feature_stack, embed_features], axis=1) dot_interact_output = dot_interact( - concat_features=feature_stack, keep_diags=self.keep_diags) - top_mlp_input = jnp.concatenate([bot_mlp_output, dot_interact_output], - axis=-1) + concat_features=feature_stack, keep_diags=self.keep_diags + ) + top_mlp_input = jnp.concatenate( + [bot_mlp_output, dot_interact_output], axis=-1 + ) mlp_input_dim = top_mlp_input.shape[1] mlp_top_dims = self.mlp_top_dims num_layers_top = len(mlp_top_dims) @@ -171,17 +179,19 @@ def scaled_init(key, shape, dtype=jnp.float_): top_mlp_input = nn.Dense( fan_out, kernel_init=jnn.initializers.normal( - stddev=(2.0 / (fan_in + fan_out))**0.5), + stddev=(2.0 / (fan_in + fan_out)) ** 0.5 + ), bias_init=jnn.initializers.normal( - stddev=(1.0 / mlp_top_dims[layer_idx])**0.5))( - top_mlp_input) + stddev=(1.0 / mlp_top_dims[layer_idx]) ** 0.5 + ), + )(top_mlp_input) if layer_idx < (num_layers_top - 1): top_mlp_input = activation_fn(top_mlp_input) top_mlp_input = normalizer_layer()(top_mlp_input) if self.dropout_rate > 0.0 and layer_idx == num_layers_top - 2: top_mlp_input = nn.Dropout( - rate=self.dropout_rate, deterministic=not train)( - top_mlp_input) + rate=self.dropout_rate, deterministic=not train + )(top_mlp_input) logits = top_mlp_input return logits @@ -210,7 +220,8 @@ def build_flax_module(self): embed_dim=self.hps.embed_dim, keep_diags=self.hps.keep_diags, dropout_rate=self.hps.dropout_rate, - normalizer=self.hps.normalizer) + normalizer=self.hps.normalizer, + ) def get_fake_inputs(self, hps): """Helper method solely for purpose of initalizing the model.""" @@ -267,7 +278,8 @@ def __call__(self, x, train): mlp_bottom_dims[0], kernel_init=jnn.initializers.glorot_uniform(), bias_init=jnn.initializers.normal( - stddev=1.0 / mlp_bottom_dims[0]**0.5), + stddev=1.0 / mlp_bottom_dims[0] ** 0.5 + ), )(bot_mlp_input) bot_mlp_input = activation_fn(bot_mlp_input) @@ -286,18 +298,19 @@ def __call__(self, x, train): embedding_init_multiplier = self.embedding_init_multiplier # Embedding table init and lookup for a single unified table. idx_lookup = jnp.reshape(cat_features, [-1]) % self.vocab_size + def scaled_init(key, shape, dtype=jnp.float_): return base_init_fn(key, shape, dtype) * embedding_init_multiplier embedding_table = self.param( - 'embedding_table', - scaled_init, - [self.vocab_size, self.embed_dim]) + 'embedding_table', scaled_init, [self.vocab_size, self.embed_dim] + ) embed_features = embedding_table[idx_lookup] batch_size = bot_mlp_input.shape[0] embed_features = jnp.reshape( - embed_features, (batch_size, 26 * self.embed_dim)) + embed_features, (batch_size, 26 * self.embed_dim) + ) top_mlp_input = jnp.concatenate([bot_mlp_input, embed_features], axis=1) mlp_input_dim = top_mlp_input.shape[1] mlp_top_dims = self.mlp_top_dims @@ -305,33 +318,37 @@ def scaled_init(key, shape, dtype=jnp.float_): top_mlp_input = nn.Dense( mlp_top_dims[0], kernel_init=jnn.initializers.normal( - stddev=(2.0 / (mlp_input_dim + mlp_top_dims[0]))**0.5), + stddev=(2.0 / (mlp_input_dim + mlp_top_dims[0])) ** 0.5 + ), bias_init=jnn.initializers.normal( - stddev=(1.0 / mlp_top_dims[0])**0.5))( - top_mlp_input) + stddev=(1.0 / mlp_top_dims[0]) ** 0.5 + ), + )(top_mlp_input) top_mlp_input = activation_fn(top_mlp_input) for layer_idx, fan_out in list(enumerate(mlp_top_dims))[1:-1]: fan_in = mlp_top_dims[layer_idx - 1] x = nn.Dense( fan_out, kernel_init=jnn.initializers.normal( - stddev=(2.0 / (fan_in + fan_out))**0.5), + stddev=(2.0 / (fan_in + fan_out)) ** 0.5 + ), bias_init=jnn.initializers.normal( - stddev=(1.0 / mlp_top_dims[layer_idx])**0.5))( - top_mlp_input) + stddev=(1.0 / mlp_top_dims[layer_idx]) ** 0.5 + ), + )(top_mlp_input) x = activation_fn(x) if self.dropout_rate > 0.0 and layer_idx == num_layers_top - 2: - x = nn.Dropout( - rate=self.dropout_rate, deterministic=not train)(x) + x = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(x) top_mlp_input += x # In the DLRM model the last layer width is always 1. We can hardcode that # below. logits = nn.Dense( 1, kernel_init=jnn.initializers.normal( - stddev=(2.0 / (mlp_top_dims[-2] + 1))**0.5), - bias_init=jnn.initializers.normal( - stddev=1.0))(top_mlp_input) + stddev=(2.0 / (mlp_top_dims[-2] + 1)) ** 0.5 + ), + bias_init=jnn.initializers.normal(stddev=1.0), + )(top_mlp_input) return logits @@ -349,7 +366,8 @@ def build_flax_module(self): num_dense_features=self.hps.num_dense_features, embed_dim=self.hps.embed_dim, keep_diags=self.hps.keep_diags, - dropout_rate=self.hps.dropout_rate) + dropout_rate=self.hps.dropout_rate, + ) def get_fake_inputs(self, hps): """Helper method solely for purpose of initalizing the model.""" diff --git a/init2winit/model_lib/fully_connected.py b/init2winit/model_lib/fully_connected.py index 4bd60ef9..f8902f85 100644 --- a/init2winit/model_lib/fully_connected.py +++ b/init2winit/model_lib/fully_connected.py @@ -14,6 +14,7 @@ # limitations under the License. """Simple fully connected feedforward neural network classifier.""" + import copy from typing import Any, Tuple @@ -24,7 +25,6 @@ import jax.numpy as jnp from ml_collections.config_dict import config_dict - # small hparams used for unit tests DEFAULT_HPARAMS = config_dict.ConfigDict( dict( @@ -43,6 +43,7 @@ class FullyConnected(nn.Module): [batch_size_per_device, *input_shape] where input_shape may be of arbitrary rank. The model flatten the input before applying a dense layer. """ + num_outputs: int hid_sizes: Tuple[int] activation_function: Any @@ -56,20 +57,23 @@ def __call__(self, x, train): if len(self.activation_function) != len(self.hid_sizes): raise ValueError( 'The number of activation functions must be equal to the number ' - 'of hidden layers') + 'of hidden layers' + ) activation_function = copy.deepcopy(self.activation_function) else: activation_function = [self.activation_function] * len(self.hid_sizes) x = jnp.reshape(x, (x.shape[0], -1)) for i, (num_hid, init) in enumerate( - zip(self.hid_sizes, self.kernel_inits[:-1])): + zip(self.hid_sizes, self.kernel_inits[:-1]) + ): x = nn.Dense(num_hid, kernel_init=init, bias_init=self.bias_init)(x) x = model_utils.ACTIVATIONS[activation_function[i]](x) x = nn.Dense( self.num_outputs, kernel_init=self.kernel_inits[-1], - bias_init=self.bias_init)(x) + bias_init=self.bias_init, + )(x) return x @@ -86,7 +90,8 @@ def build_flax_module(self): num_outputs=self.hps['output_shape'][-1], hid_sizes=tuple(self.hps.hid_sizes), activation_function=self.hps.activation_function, - kernel_inits=tuple(kernel_inits)) + kernel_inits=tuple(kernel_inits), + ) def get_fake_inputs(self, hps): """Helper method solely for the purpose of initialzing the model.""" diff --git a/init2winit/model_lib/gnn.py b/init2winit/model_lib/gnn.py index 9f7da0b7..685b69c0 100644 --- a/init2winit/model_lib/gnn.py +++ b/init2winit/model_lib/gnn.py @@ -25,6 +25,7 @@ closest to the GIN+Virtual Node (Xu et al., 2018) model, and it reaches equivalent performance on the ogbg-molpcba dataset. """ + from typing import Tuple from flax import linen as nn @@ -44,7 +45,8 @@ num_message_passing_steps=5, normalizer='layer_norm', dropout_rate=0.1, - )) + ) +) def _make_embed(latent_dim): @@ -77,6 +79,7 @@ class GNN(nn.Module): The model assumes the input data is a jraph.GraphsTuple without global variables. The final prediction will be encoded in the globals. """ + num_outputs: int latent_dim: int hidden_dims: Tuple[int] @@ -92,11 +95,13 @@ def __call__(self, graph, train): activation = model_utils.ACTIVATIONS[self.activation_function] graph = graph._replace( - globals=jnp.zeros([graph.n_node.shape[0], self.num_outputs])) + globals=jnp.zeros([graph.n_node.shape[0], self.num_outputs]) + ) embedder = jraph.GraphMapFeatures( embed_node_fn=_make_embed(self.latent_dim), - embed_edge_fn=_make_embed(self.latent_dim)) + embed_edge_fn=_make_embed(self.latent_dim), + ) graph = embedder(graph) for _ in range(self.num_message_passing_steps): @@ -105,17 +110,21 @@ def __call__(self, graph, train): self.hidden_dims, maybe_normalize_fn=maybe_normalize_fn, dropout=dropout, - activation=activation), + activation=activation, + ), update_node_fn=_make_mlp( self.hidden_dims, maybe_normalize_fn=maybe_normalize_fn, dropout=dropout, - activation=activation), + activation=activation, + ), update_global_fn=_make_mlp( self.hidden_dims, maybe_normalize_fn=maybe_normalize_fn, dropout=dropout, - activation=activation)) + activation=activation, + ), + ) graph = net(graph) @@ -139,7 +148,8 @@ def get_fake_inputs(self, hps): edges=jnp.ones((1,) + hps.input_edge_shape), globals=jnp.zeros((1,) + hps.output_shape), senders=jnp.asarray([0]), - receivers=jnp.asarray([0])) + receivers=jnp.asarray([0]), + ) # We need to wrap the GraphsTuple in a list so that it can be passed as # *inputs to the model init function. return [graph] @@ -152,4 +162,5 @@ def build_flax_module(self): normalizer=self.hps['normalizer'], dropout_rate=self.hps['dropout_rate'], num_message_passing_steps=self.hps['num_message_passing_steps'], - activation_function=self.hps['activation_function'],) + activation_function=self.hps['activation_function'], + ) diff --git a/init2winit/model_lib/librispeech_preprocessor.py b/init2winit/model_lib/librispeech_preprocessor.py index 6ff1ce21..f79422dc 100644 --- a/init2winit/model_lib/librispeech_preprocessor.py +++ b/init2winit/model_lib/librispeech_preprocessor.py @@ -36,68 +36,175 @@ _MEL_HIGH_FREQUENCY_Q = 1127.0 LIBRISPEECH_MEAN_VECTOR = [ - -7.6047816276550293, -7.1206226348876953, -6.8864245414733887, - -6.8705768585205078, -6.9667720794677734, -7.1084094047546387, - -6.9528026580810547, -6.783994197845459, -6.6195521354675293, - -6.4876265525817871, -6.4120659828186035, -6.394047737121582, - -6.4244871139526367, -6.3993711471557617, -6.5158271789550781, - -6.7137999534606934, -6.8476877212524414, -6.9885001182556152, - -6.9221386909484863, -7.146148681640625, -7.2040400505065918, - -7.0537552833557129, -7.3140382766723633, -7.1223249435424805, - -7.30251407623291, -7.1212143898010254, -7.2425732612609863, - -7.1730537414550781, -7.0979413986206055, -7.088747501373291, - -6.9849910736083984, -6.8787732124328613, -6.7602753639221191, - -6.6300945281982422, -6.5145769119262695, -6.4245057106018066, - -6.356513500213623, -6.31787633895874, -6.2660770416259766, - -6.2468328475952148, -6.2821526527404785, -6.1908388137817383, - -6.2484354972839355, -6.1472640037536621, -6.0924725532531738, - -6.0171003341674805, -5.9250402450561523, -5.8535833358764648, - -5.8209109306335449, -5.8118929862976074, -5.80783748626709, - -5.7714629173278809, -5.7453732490539551, -5.7705655097961426, - -5.7765641212463379, -5.7831673622131348, -5.7954087257385254, - -5.7994823455810547, -5.8023476600646973, -5.8047118186950684, - -5.8168182373046875, -5.8844799995422363, -5.9727106094360352, - -6.0444660186767578, -6.1284866333007812, -6.2257585525512695, - -6.3157496452331543, -6.39061164855957, -6.4928598403930664, - -6.5498456954956055, -6.6054320335388184, -6.6508378982543945, - -6.66917610168457, -6.6726889610290527, -6.684234619140625, - -6.6974577903747559, -6.75471830368042, -6.7949142456054688, - -6.8634209632873535, -6.94186544418335 + -7.6047816276550293, + -7.1206226348876953, + -6.8864245414733887, + -6.8705768585205078, + -6.9667720794677734, + -7.1084094047546387, + -6.9528026580810547, + -6.783994197845459, + -6.6195521354675293, + -6.4876265525817871, + -6.4120659828186035, + -6.394047737121582, + -6.4244871139526367, + -6.3993711471557617, + -6.5158271789550781, + -6.7137999534606934, + -6.8476877212524414, + -6.9885001182556152, + -6.9221386909484863, + -7.146148681640625, + -7.2040400505065918, + -7.0537552833557129, + -7.3140382766723633, + -7.1223249435424805, + -7.30251407623291, + -7.1212143898010254, + -7.2425732612609863, + -7.1730537414550781, + -7.0979413986206055, + -7.088747501373291, + -6.9849910736083984, + -6.8787732124328613, + -6.7602753639221191, + -6.6300945281982422, + -6.5145769119262695, + -6.4245057106018066, + -6.356513500213623, + -6.31787633895874, + -6.2660770416259766, + -6.2468328475952148, + -6.2821526527404785, + -6.1908388137817383, + -6.2484354972839355, + -6.1472640037536621, + -6.0924725532531738, + -6.0171003341674805, + -5.9250402450561523, + -5.8535833358764648, + -5.8209109306335449, + -5.8118929862976074, + -5.80783748626709, + -5.7714629173278809, + -5.7453732490539551, + -5.7705655097961426, + -5.7765641212463379, + -5.7831673622131348, + -5.7954087257385254, + -5.7994823455810547, + -5.8023476600646973, + -5.8047118186950684, + -5.8168182373046875, + -5.8844799995422363, + -5.9727106094360352, + -6.0444660186767578, + -6.1284866333007812, + -6.2257585525512695, + -6.3157496452331543, + -6.39061164855957, + -6.4928598403930664, + -6.5498456954956055, + -6.6054320335388184, + -6.6508378982543945, + -6.66917610168457, + -6.6726889610290527, + -6.684234619140625, + -6.6974577903747559, + -6.75471830368042, + -6.7949142456054688, + -6.8634209632873535, + -6.94186544418335, ] LIBRISPEECH_STD_VECTOR = [ - 3.4353282451629639, 3.5962932109832764, 3.7012472152709961, - 3.7369205951690674, 3.7535104751586914, 3.693629264831543, - 3.6922497749328613, 3.7641522884368896, 3.8419716358184814, - 3.8999848365783691, 3.9294240474700928, 3.9317409992218018, - 3.9139585494995117, 3.9031598567962646, 3.8691999912261963, - 3.8155081272125244, 3.7644970417022705, 3.7099106311798096, - 3.6965086460113525, 3.6003766059875488, 3.5493226051330566, - 3.5465121269226074, 3.45003604888916, 3.4712812900543213, - 3.4084610939025879, 3.4408135414123535, 3.4104881286621094, - 3.4217638969421387, 3.4312851428985596, 3.4199209213256836, - 3.4305806159973145, 3.4382665157318115, 3.4580366611480713, - 3.4817991256713867, 3.4958710670471191, 3.5036792755126953, - 3.5047574043273926, 3.4988734722137451, 3.493056058883667, - 3.4822943210601807, 3.459430456161499, 3.4612770080566406, - 3.4559063911437988, 3.4755423069000244, 3.4971549510955811, - 3.5326557159423828, 3.5705199241638184, 3.5920312404632568, - 3.596907377243042, 3.5913500785827637, 3.5865931510925293, - 3.5826809406280518, 3.5837743282318115, 3.5895791053771973, - 3.5819313526153564, 3.5837869644165039, 3.5861184597015381, - 3.5889589786529541, 3.592214822769165, 3.5939455032348633, - 3.5856630802154541, 3.5884113311767578, 3.5921022891998291, - 3.5870490074157715, 3.5806570053100586, 3.5731067657470703, - 3.5617532730102539, 3.54980731010437, 3.5527374744415283, - 3.5475366115570068, 3.5387849807739258, 3.5256178379058838, - 3.5031836032867432, 3.4922726154327393, 3.4879646301269531, - 3.4725594520568848, 3.4558389186859131, 3.4351828098297119, - 3.4284293651580811, 3.4299170970916748 + 3.4353282451629639, + 3.5962932109832764, + 3.7012472152709961, + 3.7369205951690674, + 3.7535104751586914, + 3.693629264831543, + 3.6922497749328613, + 3.7641522884368896, + 3.8419716358184814, + 3.8999848365783691, + 3.9294240474700928, + 3.9317409992218018, + 3.9139585494995117, + 3.9031598567962646, + 3.8691999912261963, + 3.8155081272125244, + 3.7644970417022705, + 3.7099106311798096, + 3.6965086460113525, + 3.6003766059875488, + 3.5493226051330566, + 3.5465121269226074, + 3.45003604888916, + 3.4712812900543213, + 3.4084610939025879, + 3.4408135414123535, + 3.4104881286621094, + 3.4217638969421387, + 3.4312851428985596, + 3.4199209213256836, + 3.4305806159973145, + 3.4382665157318115, + 3.4580366611480713, + 3.4817991256713867, + 3.4958710670471191, + 3.5036792755126953, + 3.5047574043273926, + 3.4988734722137451, + 3.493056058883667, + 3.4822943210601807, + 3.459430456161499, + 3.4612770080566406, + 3.4559063911437988, + 3.4755423069000244, + 3.4971549510955811, + 3.5326557159423828, + 3.5705199241638184, + 3.5920312404632568, + 3.596907377243042, + 3.5913500785827637, + 3.5865931510925293, + 3.5826809406280518, + 3.5837743282318115, + 3.5895791053771973, + 3.5819313526153564, + 3.5837869644165039, + 3.5861184597015381, + 3.5889589786529541, + 3.592214822769165, + 3.5939455032348633, + 3.5856630802154541, + 3.5884113311767578, + 3.5921022891998291, + 3.5870490074157715, + 3.5806570053100586, + 3.5731067657470703, + 3.5617532730102539, + 3.54980731010437, + 3.5527374744415283, + 3.5475366115570068, + 3.5387849807739258, + 3.5256178379058838, + 3.5031836032867432, + 3.4922726154327393, + 3.4879646301269531, + 3.4725594520568848, + 3.4558389186859131, + 3.4351828098297119, + 3.4284293651580811, + 3.4299170970916748, ] @struct.dataclass class LibrispeechPreprocessingConfig: """Config to hold all preprocessing options for librispeech dataset.""" + sample_rate: float = 16000.0 frame_size_ms: float = 25.0 frame_step_ms: float = 10.0 @@ -117,8 +224,9 @@ class LibrispeechPreprocessingConfig: def _hertz_to_mel(frequencies_hertz): """Convert hertz to mel.""" - return _MEL_HIGH_FREQUENCY_Q * jnp.log(1.0 + (frequencies_hertz / - _MEL_BREAK_FREQUENCY_HERTZ)) + return _MEL_HIGH_FREQUENCY_Q * jnp.log( + 1.0 + (frequencies_hertz / _MEL_BREAK_FREQUENCY_HERTZ) + ) def _pad_end_length(num_timesteps, frame_step, frame_size): @@ -130,11 +238,13 @@ def _pad_end_length(num_timesteps, frame_step, frame_size): return padded_length - num_timesteps -def frame(x, - frame_length: int, - frame_step: int, - pad_end: bool = False, - pad_value: Union[int, float] = 0.0): +def frame( + x, + frame_length: int, + frame_step: int, + pad_end: bool = False, + pad_value: Union[int, float] = 0.0, +): """Slides a window and extract values. This function extracts `x[:, n:n+frame_length, :]` with sliding `n` with @@ -160,24 +270,31 @@ def frame(x, if pad_end: num_extends = _pad_end_length(num_timesteps, frame_step, frame_length) x = jnp.pad( - x, ((0, 0), (0, num_extends), (0, 0)), + x, + ((0, 0), (0, num_extends), (0, 0)), 'constant', - constant_values=pad_value) + constant_values=pad_value, + ) flat_y = jax.lax.conv_general_dilated_patches( - x, (frame_length,), (frame_step,), + x, + (frame_length,), + (frame_step,), 'VALID', - dimension_numbers=('NTC', 'OIT', 'NTC')) + dimension_numbers=('NTC', 'OIT', 'NTC'), + ) ret = flat_y.reshape(flat_y.shape[:-1] + (num_channels, frame_length)) return ret.transpose((0, 1, 3, 2)) -def linear_to_mel_weight_matrix(num_mel_bins: int = 20, - num_spectrogram_bins: int = 129, - sample_rate: Union[int, float] = 8000, - lower_edge_hertz: Union[int, float] = 125.0, - upper_edge_hertz: Union[int, float] = 3800.0, - dtype: Any = jnp.float32): +def linear_to_mel_weight_matrix( + num_mel_bins: int = 20, + num_spectrogram_bins: int = 129, + sample_rate: Union[int, float] = 8000, + lower_edge_hertz: Union[int, float] = 125.0, + upper_edge_hertz: Union[int, float] = 3800.0, + dtype: Any = jnp.float32, +): r"""Jax-port of `tf.signal.linear_to_mel_weight_matrix`. Args: @@ -209,23 +326,29 @@ def linear_to_mel_weight_matrix(num_mel_bins: int = 20, if num_mel_bins <= 0: raise ValueError('num_mel_bins must be positive. Got: %s' % num_mel_bins) if lower_edge_hertz < 0.0: - raise ValueError('lower_edge_hertz must be non-negative. Got: %s' % - lower_edge_hertz) + raise ValueError( + 'lower_edge_hertz must be non-negative. Got: %s' % lower_edge_hertz + ) if lower_edge_hertz >= upper_edge_hertz: - raise ValueError('lower_edge_hertz %.1f >= upper_edge_hertz %.1f' % - (lower_edge_hertz, upper_edge_hertz)) + raise ValueError( + 'lower_edge_hertz %.1f >= upper_edge_hertz %.1f' + % (lower_edge_hertz, upper_edge_hertz) + ) if sample_rate <= 0.0: raise ValueError('sample_rate must be positive. Got: %s' % sample_rate) if upper_edge_hertz > sample_rate / 2: - raise ValueError('upper_edge_hertz must not be larger than the Nyquist ' - 'frequency (sample_rate / 2). Got %s for sample_rate: %s' % - (upper_edge_hertz, sample_rate)) + raise ValueError( + 'upper_edge_hertz must not be larger than the Nyquist ' + 'frequency (sample_rate / 2). Got %s for sample_rate: %s' + % (upper_edge_hertz, sample_rate) + ) # HTK excludes the spectrogram DC bin. bands_to_zero = 1 nyquist_hertz = sample_rate / 2.0 linear_frequencies = jnp.linspace( - 0.0, nyquist_hertz, num_spectrogram_bins, dtype=dtype)[bands_to_zero:] + 0.0, nyquist_hertz, num_spectrogram_bins, dtype=dtype + )[bands_to_zero:] spectrogram_bins_mel = _hertz_to_mel(linear_frequencies)[:, jnp.newaxis] # Compute num_mel_bins triples of (lower_edge, center, upper_edge). The @@ -236,7 +359,8 @@ def linear_to_mel_weight_matrix(num_mel_bins: int = 20, _hertz_to_mel(lower_edge_hertz), _hertz_to_mel(upper_edge_hertz), num_mel_bins + 2, - dtype=dtype) + dtype=dtype, + ) # Split the triples up and reshape them into [1, num_mel_bins] tensors. lower_edge_mel = edges[:-2][jnp.newaxis, :] @@ -246,9 +370,11 @@ def linear_to_mel_weight_matrix(num_mel_bins: int = 20, # Calculate lower and upper slopes for every spectrogram bin. # Line segments are linear in the mel domain, not Hertz. lower_slopes = (spectrogram_bins_mel - lower_edge_mel) / ( - center_mel - lower_edge_mel) + center_mel - lower_edge_mel + ) upper_slopes = (upper_edge_mel - spectrogram_bins_mel) / ( - upper_edge_mel - center_mel) + upper_edge_mel - center_mel + ) # Intersect the line segments with each other and zero. mel_weights_matrix = jnp.maximum(0.0, jnp.minimum(lower_slopes, upper_slopes)) @@ -276,22 +402,25 @@ def _hanning_greco(win_support, frame_size, dtype): if frame_size < win_support: raise ValueError( 'Provided frame_size = {} is lower than win_support = {}'.format( - frame_size, win_support)) + frame_size, win_support + ) + ) arg = jnp.pi * 2.0 / (win_support) - hann = 0.5 - (0.5 * jnp.cos(arg * - (jnp.arange(win_support, dtype=dtype) + 0.5))) + hann = 0.5 - ( + 0.5 * jnp.cos(arg * (jnp.arange(win_support, dtype=dtype) + 0.5)) + ) zero_size = frame_size - win_support return jnp.pad(hann, [(0, zero_size)]) def _next_pow_of_two(x: Union[int, float]) -> int: - return int(2**np.ceil(np.log2(x))) + return int(2 ** np.ceil(np.log2(x))) class SpectrogramFrontend(nn.Module): - """Layer to convert input audio signals from time domain to frequency domain. - """ + """Layer to convert input audio signals from time domain to frequency domain.""" + config: LibrispeechPreprocessingConfig = None input_scale_factor: float = 1.0 output_log: bool = False @@ -299,8 +428,9 @@ class SpectrogramFrontend(nn.Module): def setup(self) -> None: p = self.config self._frame_step = int(round(p.sample_rate * p.frame_step_ms / 1000.0)) - self._frame_size = int(round( - p.sample_rate * p.frame_size_ms / 1000.0)) + 1 # +1 for the preemph + self._frame_size = ( + int(round(p.sample_rate * p.frame_size_ms / 1000.0)) + 1 + ) # +1 for the preemph # TF-version has maximum of 512, but it's not always necessary self.fft_size = _next_pow_of_two(self._frame_size) @@ -330,22 +460,30 @@ def f(frame_size, dtype): def _apply_preemphasis(self, framed_signal): p = self.config if p.preemph_htk_flavor: - return jnp.concatenate([ - framed_signal[:, :, :1, :] * (1. - p.preemph), - (framed_signal[:, :, 1:-1, :] - - p.preemph * framed_signal[:, :, :-2, :]) - ], axis=2) + return jnp.concatenate( + [ + framed_signal[:, :, :1, :] * (1.0 - p.preemph), + ( + framed_signal[:, :, 1:-1, :] + - p.preemph * framed_signal[:, :, :-2, :] + ), + ], + axis=2, + ) else: - return (framed_signal[:, :, 1:, :] - - p.preemph * framed_signal[:, :, :-1, :]) + return ( + framed_signal[:, :, 1:, :] - p.preemph * framed_signal[:, :, :-1, :] + ) def fprop_paddings(self, input_paddings): p = self.config if p.pad_end: - num_extends = _pad_end_length(input_paddings.shape[1], self._frame_step, - self._frame_size) + num_extends = _pad_end_length( + input_paddings.shape[1], self._frame_step, self._frame_size + ) input_paddings = jnp.pad( - input_paddings, ((0, 0), (0, num_extends)), constant_values=1.0) + input_paddings, ((0, 0), (0, num_extends)), constant_values=1.0 + ) return jax.lax.reduce_window( input_paddings, @@ -353,7 +491,8 @@ def fprop_paddings(self, input_paddings): computation=jax.lax.min, window_dimensions=[1, self._frame_size], window_strides=[1, self._frame_step], - padding='valid') + padding='valid', + ) def next_prng_key(self, name='dropout'): return self.make_rng(name) @@ -376,7 +515,8 @@ def __call__(self, inputs, input_paddings): pcm_audio_chunk = inputs.astype(jnp.float32) * self.input_scale_factor framed_signal = frame( - pcm_audio_chunk, self._frame_size, self._frame_step, pad_end=p.pad_end) + pcm_audio_chunk, self._frame_size, self._frame_step, pad_end=p.pad_end + ) if p.preemph != 0.0: preemphasized = self._apply_preemphasis(framed_signal) @@ -384,8 +524,10 @@ def __call__(self, inputs, input_paddings): preemphasized = framed_signal[..., :-1, :] if p.noise_scale > 0.0: - noise_signal = jax.random.normal(self.next_prng_key(), - preemphasized.shape) * p.noise_scale + noise_signal = ( + jax.random.normal(self.next_prng_key(), preemphasized.shape) + * p.noise_scale + ) else: noise_signal = jnp.zeros(preemphasized.shape) @@ -408,8 +550,8 @@ def __call__(self, inputs, input_paddings): class MelFilterbankFrontend(nn.Module): - """Layer to compute log mel spectograms from input audio signals. - """ + """Layer to compute log mel spectograms from input audio signals.""" + config: LibrispeechPreprocessingConfig = None use_divide_stream: bool = True per_bin_mean: Optional[float] = None @@ -418,9 +560,10 @@ class MelFilterbankFrontend(nn.Module): def setup(self): p = self.config - input_scale_factor = 2 ** -15 if self.use_divide_stream else 1.0 + input_scale_factor = 2**-15 if self.use_divide_stream else 1.0 self.stft = SpectrogramFrontend( - p, input_scale_factor=input_scale_factor, output_log=False) + p, input_scale_factor=input_scale_factor, output_log=False + ) if self.per_bin_mean is None: per_bin_mean = [0.0] * p.num_bins @@ -432,11 +575,12 @@ def setup(self): else: per_bin_stddev = self.per_bin_stddev - self._normalizer_mean = jnp.array(per_bin_mean)[jnp.newaxis, jnp.newaxis, :, - jnp.newaxis] - self._normalizer_stddev = jnp.array(per_bin_stddev)[jnp.newaxis, - jnp.newaxis, :, - jnp.newaxis] + self._normalizer_mean = jnp.array(per_bin_mean)[ + jnp.newaxis, jnp.newaxis, :, jnp.newaxis + ] + self._normalizer_stddev = jnp.array(per_bin_stddev)[ + jnp.newaxis, jnp.newaxis, :, jnp.newaxis + ] @nn.compact def __call__(self, inputs, input_paddings): @@ -449,14 +593,17 @@ def __call__(self, inputs, input_paddings): num_spectrogram_bins=spect.shape[2], sample_rate=p.sample_rate, lower_edge_hertz=p.lower_edge_hertz, - upper_edge_hertz=p.upper_edge_hertz) + upper_edge_hertz=p.upper_edge_hertz, + ) mel_spectrogram = jnp.einsum('fn,btfc->btnc', mel_weights, spect) logmel_spectrogram = jnp.log(jnp.maximum(mel_spectrogram, p.output_floor)) normalized_logmel_spectrogram = ( - (logmel_spectrogram - self._normalizer_mean) / self._normalizer_stddev) + logmel_spectrogram - self._normalizer_mean + ) / self._normalizer_stddev - normalized_logmel_spectrogram = jnp.squeeze(normalized_logmel_spectrogram, - -1) + normalized_logmel_spectrogram = jnp.squeeze( + normalized_logmel_spectrogram, -1 + ) return normalized_logmel_spectrogram, spect_paddings diff --git a/init2winit/model_lib/local_attention_transformer.py b/init2winit/model_lib/local_attention_transformer.py index b489cd01..2b6c0b03 100644 --- a/init2winit/model_lib/local_attention_transformer.py +++ b/init2winit/model_lib/local_attention_transformer.py @@ -22,6 +22,7 @@ developed by: aurkor@google.com and msaffar@google.com. """ + import itertools import math from typing import Any, Dict, List, Sequence, Tuple, Union @@ -34,13 +35,11 @@ import numpy as np INITIALIZERS = { - 'variance_scaling': - nn.initializers.variance_scaling( - 0.2, mode='fan_avg', distribution='uniform'), - 'glorot_uniform': - nn.initializers.glorot_uniform(), - 'uniform': - nn.initializers.uniform() + 'variance_scaling': nn.initializers.variance_scaling( + 0.2, mode='fan_avg', distribution='uniform' + ), + 'glorot_uniform': nn.initializers.glorot_uniform(), + 'uniform': nn.initializers.uniform(), } DEFAULT_HPARAMS = config_dict.ConfigDict( @@ -90,7 +89,8 @@ def shift_right(x, axis=1): pad_widths = [(0, 0)] * len(x.shape) pad_widths[axis] = (1, 0) padded = jnp.pad( - x, pad_widths, mode='constant', constant_values=x.dtype.type(0)) + x, pad_widths, mode='constant', constant_values=x.dtype.type(0) + ) return jax.lax.dynamic_slice_in_dim(padded, 0, padded.shape[axis] - 1, axis) @@ -102,6 +102,7 @@ class FeedForward(nn.Module): feedforward_dropout: dropout rate in the Dropout layers. kernel_init: kernel initializer in the Dense layers. """ + feedforward_depths: Sequence[int] = None feedforward_dropout: float = 0.0 kernel_init: Any = nn.initializers.glorot_uniform() @@ -123,26 +124,29 @@ def __call__(self, input_x: Tensor, train: bool = False) -> Tensor: features=self.feedforward_depths[0], kernel_init=self.kernel_init, use_bias=True, - name='conv1')(input_x) + name='conv1', + )(input_x) x = nn.relu(x) - x = nn.Dropout( - rate=self.feedforward_dropout, deterministic=not train)(x) + x = nn.Dropout(rate=self.feedforward_dropout, deterministic=not train)(x) output = nn.Dense( features=self.feedforward_depths[1], kernel_init=self.kernel_init, use_bias=True, - name='conv2')(x) + name='conv2', + )(x) return output # TODO(krasowiak): Potentially replace lines 128-1226 # with modified Flax SelfAttention -def decode_step_to_index(decode_step: int, - array_shape: Tuple[int], - query_shape: Tuple[int] = (256,)) -> Tuple[int]: +def decode_step_to_index( + decode_step: int, + array_shape: Tuple[int, ...], + query_shape: Tuple[int, ...] = (256,), +) -> Tuple[int, ...]: """Maps decode step to n-d index according to blocked raster scan order. Args: @@ -155,8 +159,10 @@ def decode_step_to_index(decode_step: int, `decode_step` w.r.t. blocked raster scan order. """ if len(query_shape) != len(array_shape): - raise ValueError(f'Query ({query_shape}) and array ({array_shape})' - ' shapes not the same length.') + raise ValueError( + f'Query ({query_shape}) and array ({array_shape})' + ' shapes not the same length.' + ) blocks_per_dimension = [t // q for t, q in zip(array_shape, query_shape)] items_in_block = np.prod(query_shape, dtype=jnp.int32) @@ -182,7 +188,7 @@ def decode_step_to_index(decode_step: int, def get_item_at_decode_step( input_array: Tensor, decode_step: int = None, - query_shape: Tuple[int] = (256,) + query_shape: Tuple[int, ...] = (256,), ) -> Tensor: """Extracts a single item from an n-d array at `decode_step` position. @@ -200,10 +206,12 @@ def get_item_at_decode_step( index = decode_step_to_index( decode_step=decode_step, array_shape=x_shape[1:-1], - query_shape=query_shape) + query_shape=query_shape, + ) index = [i.tolist() for i in index] - output = input_array[:x_shape[0], index[0]:index[0] + len(index), - 0:x_shape[-1]] + output = input_array[ + : x_shape[0], index[0] : index[0] + len(index), 0 : x_shape[-1] + ] return output @@ -227,7 +235,8 @@ def ones_matrix_band_part( num_cols: int, max_backward: int, max_forward: int, - output_shape: Tuple[int] = None) -> Tensor: + output_shape: Tuple[int, ...] = None, +) -> Tensor: """Prepares a matrix band part of 1s. Args: @@ -259,10 +268,8 @@ def ones_matrix_band_part( def attention_bias_local( - length: int, - max_backward: int, - max_forward: int, - scale_factor: float = -1e9) -> Tensor: + length: int, max_backward: int, max_forward: int, scale_factor: float = -1e9 +) -> Tensor: """Creates a bias array to be added to attention logits. A position may attend to positions at most max_distance from it, @@ -284,14 +291,15 @@ def attention_bias_local( num_cols=length, max_backward=max_backward, max_forward=max_forward, - output_shape=[1, 1, length, length]) + output_shape=[1, 1, length, length], + ) output = scale_factor * (1.0 - output) return output -def attention_bias_lower_triangle(length: int, - bias_cache: Dict[str, - Tensor] = None) -> Tensor: +def attention_bias_lower_triangle( + length: int, bias_cache: Dict[str, Tensor] = None +) -> Tensor: """Creates an bias tensor to be added to attention logits. Args: @@ -313,10 +321,10 @@ def attention_bias_lower_triangle(length: int, def causal_attention_bias_nd( - query_shape: Tuple[int] = (256,), - memory_flange: Tuple[int] = (256,), + query_shape: Tuple[int, ...] = (256,), + memory_flange: Tuple[int, ...] = (256,), decode_step: int = None, - bias_cache: Dict[str, Tensor] = None + bias_cache: Dict[str, Tensor] = None, ) -> Tensor: """Creates a causal attention bias for local nd attention. @@ -336,39 +344,47 @@ def causal_attention_bias_nd( return bias_cache[cache_key] if all([m % q != 0 for q, m in zip(query_shape, memory_flange)]): - raise ValueError(f'Query ({query_shape}) and memory ({memory_flange})' - ' modulo not equal to 0.') + raise ValueError( + f'Query ({query_shape}) and memory ({memory_flange})' + ' modulo not equal to 0.' + ) blocks_per_memory_flange = [ m // q for q, m in zip(query_shape, memory_flange) ] - prev_blocks = np.prod( - [2 * b + 1 for b in blocks_per_memory_flange], - dtype=jnp.int32) // 2 + prev_blocks = ( + np.prod([2 * b + 1 for b in blocks_per_memory_flange], dtype=jnp.int32) + // 2 + ) all_blocks = np.prod( - [blocks_per_memory_flange[0] + 1] + - [2 * b + 1 for b in blocks_per_memory_flange[1:]], - dtype=jnp.int32) + [blocks_per_memory_flange[0] + 1] + + [2 * b + 1 for b in blocks_per_memory_flange[1:]], + dtype=jnp.int32, + ) future_blocks = all_blocks - prev_blocks - 1 items_in_block = np.prod(query_shape, dtype=jnp.int32) items_in_query = items_in_block if decode_step is None else 1 prev_blocks_attn = jnp.zeros( - [1, 1, items_in_query, prev_blocks * items_in_block]) + [1, 1, items_in_query, prev_blocks * items_in_block] + ) if decode_step is None: - center_block_attn = attention_bias_lower_triangle(length=items_in_block, - bias_cache=bias_cache) + center_block_attn = attention_bias_lower_triangle( + length=items_in_block, bias_cache=bias_cache + ) else: step_in_block = decode_step % items_in_block cond = jnp.less_equal( - jnp.arange(stop=items_in_block, dtype=jnp.int32), - step_in_block).reshape(shape=[1, 1, items_in_query, items_in_block]) + jnp.arange(stop=items_in_block, dtype=jnp.int32), step_in_block + ).reshape(shape=[1, 1, items_in_query, items_in_block]) x = jnp.zeros([1, 1, items_in_query, items_in_block]) y = -1e9 * jnp.ones([1, 1, items_in_query, items_in_block]) center_block_attn = jnp.where(cond, x, y) future_blocks_attn = -1e9 * jnp.ones( - [1, 1, items_in_query, future_blocks * items_in_block]) + [1, 1, items_in_query, future_blocks * items_in_block] + ) output = jnp.concatenate( - [prev_blocks_attn, center_block_attn, future_blocks_attn], axis=3) + [prev_blocks_attn, center_block_attn, future_blocks_attn], axis=3 + ) if decode_step is None: if bias_cache is None: @@ -390,19 +406,22 @@ def maybe_tile(input_x: Tensor, input_y: Tensor) -> Tensor: x_shape = input_x.shape y_shape = input_y.shape if len(x_shape) != len(y_shape): - raise ValueError(f'Query ({x_shape}) and array ({y_shape})' - ' shapes not the same length.') + raise ValueError( + f'Query ({x_shape}) and array ({y_shape}) shapes not the same length.' + ) x_tile = [1] y_tile = [1] for x_dim, y_dim in zip(x_shape[1:-1], y_shape[1:-1]): try: if x_dim % y_dim != 0: - raise ValueError(f'X_dim ({x_dim}) and y_dim ({y_dim})' - ' modulos not equal to 0.') + raise ValueError( + f'X_dim ({x_dim}) and y_dim ({y_dim}) modulos not equal to 0.' + ) except ValueError as maybe_tile_error: if y_dim % x_dim != 0: - raise ValueError(f'X_dim ({x_dim}) and y_dim ({y_dim})' - ' modulos not equal to 0.') from maybe_tile_error + raise ValueError( + f'X_dim ({x_dim}) and y_dim ({y_dim}) modulos not equal to 0.' + ) from maybe_tile_error if x_dim == y_dim: x_tile.append(1) y_tile.append(1) @@ -420,12 +439,12 @@ def maybe_tile(input_x: Tensor, input_y: Tensor) -> Tensor: def local_attention_bias_nd( v_array: Tensor, - query_shape: Tuple[int] = (256,), - memory_flange: Tuple[int] = (256,), + query_shape: Tuple[int, ...] = (256,), + memory_flange: Tuple[int, ...] = (256,), masked: bool = True, cache_padding_bias: bool = False, decode_step: int = None, - bias_cache: Dict[str, Tensor] = None + bias_cache: Dict[str, Tensor] = None, ) -> Tensor: """Creates an attention bias for local n-d attention. @@ -444,14 +463,14 @@ def local_attention_bias_nd( items_in_memory] if cache_padding_bias is True. """ cache_names = ['_'.join(map(str, i)) for i in [query_shape, memory_flange]] - cache_key = 'local_attention_bias_{}_{}_{}_{}'.format(cache_names[0], - cache_names[1], masked, - cache_padding_bias) + cache_key = 'local_attention_bias_{}_{}_{}_{}'.format( + cache_names[0], cache_names[1], masked, cache_padding_bias + ) if bias_cache and cache_key in bias_cache and decode_step is None: return bias_cache[cache_key] if cache_padding_bias: - array = embedding_to_padding(embedding=v_array[:1, :, :, :]) * -1e9, + array = (embedding_to_padding(embedding=v_array[:1, :, :, :]) * -1e9,) padding_attn_bias = jnp.expand_dims(array, axis=-2) else: array = embedding_to_padding(embedding=v_array) * -1e9 @@ -462,9 +481,11 @@ def local_attention_bias_nd( query_shape=query_shape, memory_flange=memory_flange, decode_step=decode_step, - bias_cache=bias_cache) - causal_attn_bias, padding_attn_bias = maybe_tile(input_x=causal_attn_bias, - input_y=padding_attn_bias) + bias_cache=bias_cache, + ) + causal_attn_bias, padding_attn_bias = maybe_tile( + input_x=causal_attn_bias, input_y=padding_attn_bias + ) output = jnp.minimum(causal_attn_bias, padding_attn_bias) else: output = padding_attn_bias @@ -476,7 +497,7 @@ def local_attention_bias_nd( return output -def pad_to_multiple_nd(input_x: Tensor, block_shape: Tuple[int]) -> Tensor: +def pad_to_multiple_nd(input_x: Tensor, block_shape: Tuple[int, ...]) -> Tensor: """Ensures the input is a multiple of a provided shape. Args: @@ -489,15 +510,14 @@ def pad_to_multiple_nd(input_x: Tensor, block_shape: Tuple[int]) -> Tensor: """ shape = input_x.shape paddings = [-l % b for l, b in zip(shape[1:-1], block_shape)] - output = jnp.pad( - input_x, [(0, 0)] + [(0, p) for p in paddings] + [(0, 0)]) + output = jnp.pad(input_x, [(0, 0)] + [(0, p) for p in paddings] + [(0, 0)]) return output def select_block_for_decode_step( input_x: Tensor, decode_step: int = None, - query_shape: Tuple[int] = (256,) + query_shape: Tuple[int, ...] = (256,), ) -> Tensor: """Selects one block from the input array that contains position `decode_step`. @@ -514,11 +534,15 @@ def select_block_for_decode_step( blocked_x_shape = input_x.shape x_shape = [b * q for b, q in zip(blocked_x_shape[1:-2], query_shape)] index = decode_step_to_index( - decode_step=decode_step, array_shape=query_shape, query_shape=x_shape) + decode_step=decode_step, array_shape=query_shape, query_shape=x_shape + ) blocked_index = [i // q for i, q in zip(index, query_shape)] - output = input_x[:blocked_x_shape[0], - blocked_index[0]:blocked_index[0] + len(blocked_index), - 0:blocked_x_shape[-2], 0:blocked_x_shape[-1]] + output = input_x[ + : blocked_x_shape[0], + blocked_index[0] : blocked_index[0] + len(blocked_index), + 0 : blocked_x_shape[-2], + 0 : blocked_x_shape[-1], + ] return output @@ -536,28 +560,39 @@ def break_into_blocks_nd(input_x: Tensor, block_shape: Tuple[int]) -> Tensor: """ x_shape = list(input_x.shape) if all([l % b != 0 for l, b in zip(x_shape[1:], block_shape)]): - raise ValueError(f'X_shape[1:] ({x_shape[1:]}) and block ({block_shape})' - ' modulo not equal to 0.') + raise ValueError( + f'X_shape[1:] ({x_shape[1:]}) and block ({block_shape})' + ' modulo not equal to 0.' + ) blocks_per_dimension = [l // b for l, b in zip(x_shape[1:], block_shape)] reshape_to = list( - itertools.chain.from_iterable(zip(blocks_per_dimension, block_shape))) + itertools.chain.from_iterable(zip(blocks_per_dimension, block_shape)) + ) input_x = jnp.reshape(input_x, [-1] + reshape_to + x_shape[-1:]) block_dimensions_index = [2 * (i + 1) for i in range(len(block_shape))] - axes = [0] + [i - 1 for i in block_dimensions_index - ] + block_dimensions_index + [2 * len(block_shape) + 1] + axes = ( + [0] + + [i - 1 for i in block_dimensions_index] + + block_dimensions_index + + [2 * len(block_shape) + 1] + ) input_x = jnp.transpose(input_x, axes) - axes = [-1] + blocks_per_dimension + [ - np.prod(block_shape, dtype=jnp.int32) - ] + x_shape[-1:] + axes = ( + [-1] + + blocks_per_dimension + + [np.prod(block_shape, dtype=jnp.int32)] + + x_shape[-1:] + ) output = jnp.reshape(input_x, axes) return output def break_into_memory_blocks_nd( input_x: Tensor, - query_shape: Tuple[int] = (256,), - memory_flange: Tuple[int] = (256,), - masked: bool = True) -> Tensor: + query_shape: Tuple[int, ...] = (256,), + memory_flange: Tuple[int, ...] = (256,), + masked: bool = True, +) -> Tensor: """Breaks an input array into memory blocks around query blocks. Args: @@ -572,8 +607,10 @@ def break_into_memory_blocks_nd( which is equal to q[i] + 2m[i] or q[i] + m[i] if masked attention and i = 1 """ if all([m % q != 0 for q, m in zip(query_shape, memory_flange)]): - raise ValueError(f'Query ({query_shape}) and memory ({memory_flange})' - ' modulo not equal to 0.') + raise ValueError( + f'Query ({query_shape}) and memory ({memory_flange})' + ' modulo not equal to 0.' + ) original_x_shape = input_x.shape blocks_in_memory_flange = [m // b for b, m in zip(query_shape, memory_flange)] num_query_blocks = [ @@ -583,12 +620,14 @@ def break_into_memory_blocks_nd( if masked: input_x = jnp.pad( input_x, - [[0, 0], [memory_flange[0], 0]] + - [[p, p] for p in memory_flange[1:]] + [[0, 0]]) + [[0, 0], [memory_flange[0], 0]] + + [[p, p] for p in memory_flange[1:]] + + [[0, 0]], + ) else: input_x = jnp.pad( - input_x, - [[0, 0]] + [[p, p] for p in memory_flange] + [[0, 0]]) + input_x, [[0, 0]] + [[p, p] for p in memory_flange] + [[0, 0]] + ) query_blocks = break_into_blocks_nd(input_x=input_x, block_shape=query_shape) start_indices_per_dimension = [] @@ -602,9 +641,9 @@ def break_into_memory_blocks_nd( slices = [] for start_indices in itertools.product(*start_indices_per_dimension): - s = query_blocks[0:, - start_indices[0]:start_indices[0] + num_query_blocks[0], - 0:, 0:] + s = query_blocks[ + 0:, start_indices[0] : start_indices[0] + num_query_blocks[0], 0:, 0: + ] slices.append(s) output = jnp.concatenate(slices, axis=-2) return output @@ -627,8 +666,9 @@ def flatten_blocks_nd(input_x: Tensor) -> Tensor: return output -def unflatten_blocks_nd(input_x: Tensor, - blocks_per_dimension: List[int]) -> Tensor: +def unflatten_blocks_nd( + input_x: Tensor, blocks_per_dimension: List[int] +) -> Tensor: """Converts a flattened array to a blocked array. Args: @@ -639,25 +679,24 @@ def unflatten_blocks_nd(input_x: Tensor, output: an array of shape [batch, d1, d2, ..., dn, items_in_block, depth]. """ x_shape = list(input_x.shape) - assert x_shape[1] == np.prod( - blocks_per_dimension, dtype=jnp.int32) - output = jnp.reshape( - input_x, [-1] + blocks_per_dimension + x_shape[-2:]) + assert x_shape[1] == np.prod(blocks_per_dimension, dtype=jnp.int32) + output = jnp.reshape(input_x, [-1] + blocks_per_dimension + x_shape[-2:]) return output def break_bias_into_blocks( input_bias: Tensor, local_num_heads: int = 8, - memory_query_shape: Tuple[int] = (256,), - memory_flange: Tuple[int] = (256,), + memory_query_shape: Tuple[int, ...] = (256,), + memory_flange: Tuple[int, ...] = (256,), masked: bool = True, - decode_step: int = None) -> Tensor: + decode_step: int = None, +) -> Tensor: """Breaks bias into a blocked array. Args: input_bias: a bias array of shape of shape [batch * heads, num_blocks, - items_in_query, items_in_memory]. + items_in_query, items_in_memory]. local_num_heads: a number of local attention heads. memory_query_shape: a tuple with a memory query shape. memory_flange: a tuple with a memory flange shape. @@ -676,20 +715,20 @@ def break_bias_into_blocks( input_x=x, query_shape=memory_query_shape, memory_flange=memory_flange, - masked=masked) + masked=masked, + ) if decode_step is not None: x = select_block_for_decode_step( - input_x=x, decode_step=decode_step, query_shape=memory_query_shape) + input_x=x, decode_step=decode_step, query_shape=memory_query_shape + ) x = flatten_blocks_nd(input_x=x) output = jnp.squeeze(x, axis=-1) return output -def cast_like( - input_x: Tensor, - input_y: Tensor) -> Tensor: +def cast_like(input_x: Tensor, input_y: Tensor) -> Tensor: """Cast the same dtype on the first array as on the second if necessary. Args: @@ -711,8 +750,9 @@ def generate_relative_positions_matrix( length_q: int, length_k: int, max_relative_position: int = 513, - query_shape: Tuple[int] = (256,), - decode_step: int = None) -> Tensor: + query_shape: Tuple[int, ...] = (256,), + decode_step: int = None, +) -> Tensor: """Generates matrix of relative positions. Args: @@ -737,11 +777,13 @@ def generate_relative_positions_matrix( else: block_len = np.prod(query_shape) positive_positions = block_len - decode_step % block_len - distance_mat = jnp.expand_dims( - jnp.arange(-length_k, 0, 1), - axis=0) + positive_positions + distance_mat = ( + jnp.expand_dims(jnp.arange(-length_k, 0, 1), axis=0) + + positive_positions + ) distance_mat_clipped = jnp.clip( - distance_mat, -max_relative_position, max_relative_position) + distance_mat, -max_relative_position, max_relative_position + ) output = distance_mat_clipped + max_relative_position return output @@ -755,17 +797,20 @@ class RelativePositionEmbeddings(nn.Module): query_shape: a tuple with a query shape. embedding_init: embeddings initializer. """ + embed_layer_name: str max_relative_position: int = 513 depth: int = 129 - query_shape: Tuple[int] = (256,) + query_shape: Tuple[int, ...] = (256,) decode_step: int = None embedding_init: Any = nn.initializers.glorot_uniform() def setup(self): self.embedding = self.param( - self.embed_layer_name, self.embedding_init, - (self.max_relative_position * 2 + 1, self.depth)) + self.embed_layer_name, + self.embedding_init, + (self.max_relative_position * 2 + 1, self.depth), + ) def __call__(self, length_q: int, length_k: int) -> Tensor: """Applies the RelativePositionEmbeddings module. @@ -782,16 +827,15 @@ def __call__(self, length_q: int, length_k: int) -> Tensor: length_k=length_k, max_relative_position=self.max_relative_position, query_shape=self.query_shape, - decode_step=self.decode_step) - output = jnp.take( - self.embedding, indices=relative_positions_matrix, axis=0) + decode_step=self.decode_step, + ) + output = jnp.take(self.embedding, indices=relative_positions_matrix, axis=0) return output -def relative_attention_inner(input_x: Tensor, - input_y: Tensor, - input_z: Tensor, - transpose: bool) -> Tensor: +def relative_attention_inner( + input_x: Tensor, input_y: Tensor, input_z: Tensor, transpose: bool +) -> Tensor: """Calculates position-aware inner dot-product attention. Args: @@ -807,20 +851,18 @@ def relative_attention_inner(input_x: Tensor, """ if transpose: xy_matmul = jnp.einsum( - 'bhxd,bhyd->bhxy', - input_x, - input_y, - precision=jax.lax.Precision.HIGHEST) + 'bhxd,bhyd->bhxy', input_x, input_y, precision=jax.lax.Precision.HIGHEST + ) x_tz_matmul_r_t = jnp.einsum( - 'bhxd,xyd->bhxy', input_x, input_z, precision=jax.lax.Precision.HIGHEST) + 'bhxd,xyd->bhxy', input_x, input_z, precision=jax.lax.Precision.HIGHEST + ) else: xy_matmul = jnp.einsum( - 'bhxd,bhdy->bhxy', - input_x, - input_y, - precision=jax.lax.Precision.HIGHEST) + 'bhxd,bhdy->bhxy', input_x, input_y, precision=jax.lax.Precision.HIGHEST + ) x_tz_matmul_r_t = jnp.einsum( - 'bhxd,xdy->bhxy', input_x, input_z, precision=jax.lax.Precision.HIGHEST) + 'bhxd,xdy->bhxy', input_x, input_z, precision=jax.lax.Precision.HIGHEST + ) output = xy_matmul + x_tz_matmul_r_t return output @@ -838,19 +880,22 @@ class RelativeDotProductAttention(nn.Module): query_shape: a tuple with a query shape. embedding_init: embeddings initializer. """ + max_relative_position: int = 513 - query_shape: Tuple[int] = (256,) + query_shape: Tuple[int, ...] = (256,) attention_dropout: float = 0.0 decode_step: int = None embedding_init: Any = nn.initializers.glorot_uniform() @nn.compact def __call__( - self, q_array: Tensor, + self, + q_array: Tensor, k_array: Tensor, v_array: Tensor, bias_array: Tensor, - train: bool = False) -> Tensor: + train: bool = False, + ) -> Tensor: """Applies the RelativeDotProductAttention module. Args: @@ -866,9 +911,10 @@ def __call__( weights """ if not self.max_relative_position: - raise ValueError('Max relative position (%s) should be > 0 when using ' - 'relative self attention.' % - (self.max_relative_position)) + raise ValueError( + 'Max relative position (%s) should be > 0 when using ' + 'relative self attention.' % (self.max_relative_position) + ) depth = k_array.shape[3] length_k = k_array.shape[2] length_q = q_array.shape[2] @@ -879,32 +925,32 @@ def __call__( depth=depth, query_shape=self.query_shape, decode_step=self.decode_step, - embedding_init=self.embedding_init)( - length_q=length_q, length_k=length_k) + embedding_init=self.embedding_init, + )(length_q=length_q, length_k=length_k) relations_values = RelativePositionEmbeddings( embed_layer_name='relative_positions_values', max_relative_position=self.max_relative_position, depth=depth, query_shape=self.query_shape, decode_step=self.decode_step, - embedding_init=self.embedding_init)( - length_q=length_q, length_k=length_k) + embedding_init=self.embedding_init, + )(length_q=length_q, length_k=length_k) logits = relative_attention_inner( - input_x=q_array, - input_y=k_array, - input_z=relations_keys, - transpose=True) + input_x=q_array, input_y=k_array, input_z=relations_keys, transpose=True + ) if bias_array is not None: logits += bias_array weights = nn.softmax(logits) if self.attention_dropout: weights = nn.Dropout( - rate=self.attention_dropout, deterministic=not train)(weights) + rate=self.attention_dropout, deterministic=not train + )(weights) output = relative_attention_inner( input_x=weights, input_y=v_array, input_z=relations_values, - transpose=False) + transpose=False, + ) return output, weights @@ -919,19 +965,23 @@ class DotProductAttention(nn.Module): query_shape: a tuple with a query shape. embedding_init: embeddings initializer. """ + local_relative: bool = True max_relative_position: int = 513 attention_dropout: float = 0.0 decode_step: int = None - query_shape: Tuple[int] = (256,) + query_shape: Tuple[int, ...] = (256,) embedding_init: Any = nn.initializers.glorot_uniform() @nn.compact def __call__( - self, q_array: Tensor, + self, + q_array: Tensor, k_array: Tensor, v_array: Tensor, - bias_array: Tensor, train: bool = False) -> Tensor: + bias_array: Tensor, + train: bool = False, + ) -> Tensor: """Applies the DotProductAttention module. Args: @@ -953,11 +1003,13 @@ def __call__( decode_step=self.decode_step, query_shape=self.query_shape, embedding_init=self.embedding_init, - )(q_array=q_array, - k_array=k_array, - v_array=v_array, - bias_array=bias_array, - train=train) + )( + q_array=q_array, + k_array=k_array, + v_array=v_array, + bias_array=bias_array, + train=train, + ) logits = jnp.matmul(q_array, jnp.transpose(k_array, axes=(0, 1, 3, 2))) if bias_array is not None: bias_array = cast_like(input_x=bias_array, input_y=logits) @@ -965,7 +1017,8 @@ def __call__( weights = nn.softmax(logits) if self.attention_dropout: weights = nn.Dropout( - rate=self.attention_dropout, deterministic=not train)(weights) + rate=self.attention_dropout, deterministic=not train + )(weights) output = jnp.matmul(weights, v_array) return output, weights @@ -989,7 +1042,7 @@ def combine_heads_nd(input_x: Tensor) -> Tensor: return output -def put_back_blocks_nd(input_x: Tensor, block_shape: Tuple[int]) -> Tensor: +def put_back_blocks_nd(input_x: Tensor, block_shape: Tuple[int, ...]) -> Tensor: """Restructures an input array from blocks to normal ordering. Args: @@ -1004,21 +1057,30 @@ def put_back_blocks_nd(input_x: Tensor, block_shape: Tuple[int]) -> Tensor: if isinstance(x_shape[-2], int): if x_shape[-2] != np.prod(block_shape): - raise ValueError(f'X_shape[-2] ({x_shape[-2]}) and block ({block_shape})' - ' are not equal.') + raise ValueError( + f'X_shape[-2] ({x_shape[-2]}) and block ({block_shape})' + ' are not equal.' + ) - x = jnp.reshape( - input_x, x_shape[:-2] + list(block_shape) + x_shape[-1:]) + x = jnp.reshape(input_x, x_shape[:-2] + list(block_shape) + x_shape[-1:]) block_dimension_index = list(range(1, len(block_shape) + 1)) block_shape_index = [b + len(block_shape) for b in block_dimension_index] interleaved_dimensions = list( itertools.chain.from_iterable( - zip(block_dimension_index, block_shape_index))) - x = jnp.transpose(x, - [0] + interleaved_dimensions + [2 * len(block_shape) + 1]) - axes = [-1] + [ - x_shape[2 * i + 1] * x_shape[2 * i + 2] for i in range(len(block_shape)) - ] + x_shape[-1:] + zip(block_dimension_index, block_shape_index) + ) + ) + x = jnp.transpose( + x, [0] + interleaved_dimensions + [2 * len(block_shape) + 1] + ) + axes = ( + [-1] + + [ + x_shape[2 * i + 1] * x_shape[2 * i + 2] + for i in range(len(block_shape)) + ] + + x_shape[-1:] + ) output = jnp.reshape(x, axes) return output @@ -1045,9 +1107,10 @@ class LocalAttention(nn.Module): kernel_init: kernel initializer in the Dense layers. embedding_init: embeddings initializer. """ - query_shape: Tuple[int] = (256,) - memory_query_shape: Tuple[int] = (512,) - memory_flange: Tuple[int] = (256,) + + query_shape: Tuple[int, ...] = (256,) + memory_query_shape: Tuple[int, ...] = (512,) + memory_flange: Tuple[int, ...] = (256,) local_num_heads: int = 8 local_relative: bool = True memory_antecedent: bool = None @@ -1064,11 +1127,13 @@ class LocalAttention(nn.Module): embedding_init: Any = nn.initializers.glorot_uniform() @nn.compact - def __call__(self, - q_array: Tensor, - k_array: Tensor, - v_array: Tensor, - train: bool = False) -> Tensor: + def __call__( + self, + q_array: Tensor, + k_array: Tensor, + v_array: Tensor, + train: bool = False, + ) -> Tensor: """Applies the LocalAttention module. Args: @@ -1085,23 +1150,22 @@ def __call__(self, if all([m % b != 0 for m, b in zip(self.memory_flange, self.query_shape)]): raise ValueError( f'Query ({self.query_shape}) and memory ({self.memory_flange})' - ' modulo not equal to 0.') + ' modulo not equal to 0.' + ) if self.decode_step is not None: - q_array = jnp.reshape( - q_array, [-1] + list(q_array.shape[2:])) + q_array = jnp.reshape(q_array, [-1] + list(q_array.shape[2:])) latest_q = get_item_at_decode_step( input_array=q_array, decode_step=self.decode_step, - query_shape=self.query_shape) + query_shape=self.query_shape, + ) q_array = jnp.reshape( - q_array, - [-1, self.local_num_heads] + - list(q_array.shape[1:])) + q_array, [-1, self.local_num_heads] + list(q_array.shape[1:]) + ) latest_q = jnp.reshape( - latest_q, - [-1, self.local_num_heads] + - list(latest_q.shape[1:])) + latest_q, [-1, self.local_num_heads] + list(latest_q.shape[1:]) + ) q_shape = list(latest_q.shape) else: q_shape = list(q_array.shape) @@ -1123,38 +1187,45 @@ def __call__(self, if self.decode_step is None: q_array = pad_to_multiple_nd( - input_x=q_array, block_shape=self.query_shape) + input_x=q_array, block_shape=self.query_shape + ) k_array = pad_to_multiple_nd(input_x=k_array, block_shape=mem_query_shape) v_array = pad_to_multiple_nd(input_x=v_array, block_shape=mem_query_shape) if self.decode_step is None: q_array = break_into_blocks_nd( - input_x=q_array, block_shape=self.query_shape) + input_x=q_array, block_shape=self.query_shape + ) else: q_array = jnp.reshape( - q_array, [-1] + [1] * (len(q_shape) - 3) + [q_shape[-1]]) + q_array, [-1] + [1] * (len(q_shape) - 3) + [q_shape[-1]] + ) k_array = break_into_memory_blocks_nd( input_x=k_array, query_shape=mem_query_shape, memory_flange=self.memory_flange, - masked=self.masked) + masked=self.masked, + ) v_array = break_into_memory_blocks_nd( input_x=v_array, query_shape=mem_query_shape, memory_flange=self.memory_flange, - masked=self.masked) + masked=self.masked, + ) blocks_per_dim = list(q_array.shape[1:-2]) if self.decode_step is not None: k_array = select_block_for_decode_step( input_x=k_array, decode_step=self.decode_step, - query_shape=mem_query_shape) + query_shape=mem_query_shape, + ) v_array = select_block_for_decode_step( input_x=v_array, decode_step=self.decode_step, - query_shape=mem_query_shape) + query_shape=mem_query_shape, + ) q_array = flatten_blocks_nd(input_x=q_array) k_array = flatten_blocks_nd(input_x=k_array) @@ -1172,7 +1243,8 @@ def __call__(self, masked=self.masked, cache_padding_bias=self.cache_padding_bias, decode_step=self.decode_step, - bias_cache=bias_cache) + bias_cache=bias_cache, + ) if self.padding_bias is not None: padding_bias = break_bias_into_blocks(input_bias=self.padding_bias) @@ -1191,40 +1263,40 @@ def __call__(self, attention_dropout=self.attention_dropout, decode_step=self.decode_step, query_shape=self.query_shape, - embedding_init=self.embedding_init)( - q_array=q_array, - k_array=k_array, - v_array=v_array, - bias_array=attn_bias, - train=train) + embedding_init=self.embedding_init, + )( + q_array=q_array, + k_array=k_array, + v_array=v_array, + bias_array=attn_bias, + train=train, + ) output = unflatten_blocks_nd( - input_x=output, blocks_per_dimension=blocks_per_dim) + input_x=output, blocks_per_dimension=blocks_per_dim + ) output = jnp.reshape( - output, - [q_shape[0], self.local_num_heads] + - list(output.shape[1:])) + output, [q_shape[0], self.local_num_heads] + list(output.shape[1:]) + ) outputs.append(output) output = jnp.concatenate(outputs, axis=1) output_shape = list(output.shape) output = jnp.reshape( - output, [output_shape[0], self.local_num_heads, -1, output_shape[-1]]) + output, [output_shape[0], self.local_num_heads, -1, output_shape[-1]] + ) output = nn.Dense( output_shape[-1], kernel_init=self.kernel_init, use_bias=False, - name='dense')( - output) + name='dense', + )(output) output = jnp.reshape(output, output_shape) if self.decode_step is None: - output = jnp.reshape( - output, [-1] + list(output.shape[2:])) - output = put_back_blocks_nd( - input_x=output, block_shape=self.query_shape) - output = jnp.reshape( - output, q_shape[:2] + list(output.shape[1:])) - output = output[0:, 0:, 0:q_shape[2:-1][0], 0:] + output = jnp.reshape(output, [-1] + list(output.shape[2:])) + output = put_back_blocks_nd(input_x=output, block_shape=self.query_shape) + output = jnp.reshape(output, q_shape[:2] + list(output.shape[1:])) + output = output[0:, 0:, 0 : q_shape[2:-1][0], 0:] return output @@ -1246,11 +1318,15 @@ def split_heads_nd(input_x: Tensor, num_heads: int = 8) -> Tensor: m = x_shape[-1] if isinstance(m, int) and isinstance(num_heads, int): if m % num_heads != 0: - raise ValueError(f'X_shape[-1] ({m}) and memory ({num_heads})' - ' modulo not equal to 0.') + raise ValueError( + f'X_shape[-1] ({m}) and memory ({num_heads}) modulo not equal to 0.' + ) x = jnp.reshape(input_x, x_shape[:-1] + [num_heads, m // num_heads]) - axes = [0, num_dimensions + 1] + list(range( - 1, num_dimensions + 1)) + [num_dimensions + 2] + axes = ( + [0, num_dimensions + 1] + + list(range(1, num_dimensions + 1)) + + [num_dimensions + 2] + ) output = jnp.transpose(x, axes=axes) return output @@ -1258,8 +1334,9 @@ def split_heads_nd(input_x: Tensor, num_heads: int = 8) -> Tensor: def put_item_in_decode_step( input_x: Tensor, decode_step: int = None, - query_shape: Tuple[int] = (256,), - replacement: Any = 1.0) -> Tensor: + query_shape: Tuple[int, ...] = (256,), + replacement: Any = 1.0, +) -> Tensor: """Puts a single item into an an array at the `decode_step` position. Args: @@ -1278,13 +1355,11 @@ def put_item_in_decode_step( index = decode_step_to_index( decode_step=decode_step, array_shape=query_shape, - query_shape=x_shape[2:-1]) + query_shape=x_shape[2:-1], + ) flattened_x = jnp.reshape( - input_x, - [ - -1, x_shape[1], - np.prod(x_shape[2:-1]), x_shape[-1] - ]) + input_x, [-1, x_shape[1], np.prod(x_shape[2:-1]), x_shape[-1]] + ) flattened_x = jnp.transpose(flattened_x, axes=[2, 0, 1, 3]) flattened_index = 0 @@ -1325,15 +1400,16 @@ class MultiHeadAttention(nn.Module): kernel_init: kernel initializer in the Dense layers. embedding_init: embeddings initializer. """ + hidden_size: int = 1032 - memory_query_shape: Tuple[int] = (512,) + memory_query_shape: Tuple[int, ...] = (512,) cache: Dict[str, Tensor] = None bias_cache: Dict[str, Tensor] = None memory_antecedent: Tensor = None total_key_depth: int = 1032 total_value_depth: int = 1032 - query_shape: Tuple[int] = (256,) - memory_flange: Tuple[int] = (256,) + query_shape: Tuple[int, ...] = (256,) + memory_flange: Tuple[int, ...] = (256,) local_num_heads: int = 8 local_relative: bool = True masked: bool = True @@ -1360,18 +1436,23 @@ def __call__(self, query_antecedent: Tensor, train: bool = False) -> Tensor: [batch, 1, ..., 1, output_depth] if decode_step is set. """ if self.total_key_depth % self.local_num_heads != 0: - raise ValueError('Key depth (%d) must be divisible by the number of ' - 'attention heads (%d).' % - (self.total_key_depth, self.local_num_heads)) + raise ValueError( + 'Key depth (%d) must be divisible by the number of ' + 'attention heads (%d).' % (self.total_key_depth, self.local_num_heads) + ) if self.total_value_depth % self.local_num_heads != 0: - raise ValueError('Value depth (%d) must be divisible by the number of ' - 'attention heads (%d).' % - (self.total_value_depth, self.local_num_heads)) + raise ValueError( + 'Value depth (%d) must be divisible by the number of ' + 'attention heads (%d).' + % (self.total_value_depth, self.local_num_heads) + ) if self.share_qk: if self.memory_antecedent: - raise ValueError(f'Memory ({self.mmemory_antecedent}) must be None ' - 'if share_qk is True.') + raise ValueError( + f'Memory ({self.mmemory_antecedent}) must be None ' + 'if share_qk is True.' + ) if self.cache is None: cache = {} @@ -1390,50 +1471,57 @@ def __call__(self, query_antecedent: Tensor, train: bool = False) -> Tensor: latest_antecedent = get_item_at_decode_step( input_array=query_antecedent, decode_step=self.decode_step, - query_shape=self.query_shape) + query_shape=self.query_shape, + ) latest_q = nn.Dense( self.total_key_depth, kernel_init=self.kernel_init, use_bias=False, - name='latest_q')( - latest_antecedent) + name='latest_q', + )(latest_antecedent) latest_k = nn.Dense( self.total_key_depth, kernel_init=self.kernel_init, use_bias=False, - name='latest_k')( - latest_antecedent) + name='latest_k', + )(latest_antecedent) latest_v = nn.Dense( self.total_value_depth, kernel_init=self.kernel_init, use_bias=False, - name='latest_v')( - latest_antecedent) + name='latest_v', + )(latest_antecedent) latest_q = split_heads_nd( - input_x=latest_q, num_heads=self.local_num_heads) + input_x=latest_q, num_heads=self.local_num_heads + ) key_depth_per_head = self.total_key_depth // self.local_num_heads latest_k = split_heads_nd( - input_x=latest_k, num_heads=self.local_num_heads) + input_x=latest_k, num_heads=self.local_num_heads + ) latest_v = split_heads_nd( - input_x=latest_v, num_heads=self.local_num_heads) + input_x=latest_v, num_heads=self.local_num_heads + ) q_array = cache['q'] k_array = cache['k'] v_array = cache['v'] - q_array = put_item_in_decode_step(q_array, latest_q, self.decode_step, - self.query_shape) + q_array = put_item_in_decode_step( + q_array, latest_q, self.decode_step, self.query_shape + ) if self.memory_antecedent is None: k_array = put_item_in_decode_step( input_x=k_array, replacement=latest_k, decode_step=self.decode_step, - query_shape=self.query_shape) + query_shape=self.query_shape, + ) v_array = put_item_in_decode_step( input_x=v_array, replacement=latest_v, decode_step=self.decode_step, - query_shape=self.query_shape) + query_shape=self.query_shape, + ) cache['q'] = q_array cache['k'] = k_array @@ -1444,24 +1532,24 @@ def __call__(self, query_antecedent: Tensor, train: bool = False) -> Tensor: self.total_key_depth, kernel_init=self.kernel_init, use_bias=False, - name='q')( - query_antecedent) + name='q', + )(query_antecedent) k_array = nn.Dense( self.total_key_depth, kernel_init=self.kernel_init, use_bias=False, - name='k')( - query_antecedent) + name='k', + )(query_antecedent) v_array = nn.Dense( self.total_value_depth, kernel_init=self.kernel_init, use_bias=False, - name='v')( - query_antecedent) + name='v', + )(query_antecedent) q_array = split_heads_nd(input_x=q_array, num_heads=self.local_num_heads) key_depth_per_head = self.total_key_depth // self.local_num_heads - q_array *= key_depth_per_head ** -0.5 + q_array *= key_depth_per_head**-0.5 k_array = split_heads_nd(input_x=k_array, num_heads=self.local_num_heads) v_array = split_heads_nd(input_x=v_array, num_heads=self.local_num_heads) @@ -1492,8 +1580,8 @@ def __call__(self, query_antecedent: Tensor, train: bool = False) -> Tensor: token_bias=self.token_bias, padding_bias=self.padding_bias, kernel_init=self.kernel_init, - embedding_init=self.embedding_init)( - q_array=q_array, k_array=k_array, v_array=v_array, train=train) + embedding_init=self.embedding_init, + )(q_array=q_array, k_array=k_array, v_array=v_array, train=train) output = combine_heads_nd(input_x=output) @@ -1501,18 +1589,20 @@ def __call__(self, query_antecedent: Tensor, train: bool = False) -> Tensor: self.hidden_size, kernel_init=self.kernel_init, use_bias=False, - name='output_transform')( - output) + name='output_transform', + )(output) return output # TODO(krasowiak): Potentially replace lines 1506-1563 # with modified AddPositionEmbs from init2winit.model_lib.transformer_lm -def get_timing_signal_1d(length: int, - channels: int, - min_timescale: float = 1.0, - max_timescale: float = 1.0e4, - start_index: int = 0) -> Tensor: +def get_timing_signal_1d( + length: int, + channels: int, + min_timescale: float = 1.0, + max_timescale: float = 1.0e4, + start_index: int = 0, +) -> Tensor: """Positional encoding helper function. Args: @@ -1527,20 +1617,17 @@ def get_timing_signal_1d(length: int, """ position = jnp.arange(length + start_index, dtype=jnp.float32) num_timescales = channels // 2 - log_timescale_increment = ( - math.log(float(max_timescale) / float(min_timescale)) / - jnp.maximum(num_timescales - 1, 1)) + log_timescale_increment = math.log( + float(max_timescale) / float(min_timescale) + ) / jnp.maximum(num_timescales - 1, 1) inv_timescales = min_timescale * jnp.exp( - jnp.arange(num_timescales, dtype=jnp.float32) * - -log_timescale_increment) - scaled_time = jnp.expand_dims( - position, axis=1) * jnp.expand_dims( - inv_timescales, axis=0) - signal = jnp.concatenate( - [jnp.sin(scaled_time), - jnp.cos(scaled_time)], axis=1) - signal = jnp.pad( - signal, [[0, 0], [0, jnp.mod(channels, 2)]]) + jnp.arange(num_timescales, dtype=jnp.float32) * -log_timescale_increment + ) + scaled_time = jnp.expand_dims(position, axis=1) * jnp.expand_dims( + inv_timescales, axis=0 + ) + signal = jnp.concatenate([jnp.sin(scaled_time), jnp.cos(scaled_time)], axis=1) + signal = jnp.pad(signal, [[0, 0], [0, jnp.mod(channels, 2)]]) output = jnp.reshape(signal, [1, length, channels]) return output @@ -1562,7 +1649,8 @@ class ProcessInput(nn.Module): preprocessing_init: initializer for the first preprocessing step. embedding_init: embeddings initializer. """ - query_shape: Tuple[int] = (256,) + + query_shape: Tuple[int, ...] = (256,) preprocess_dropout_a: float = 0.0 vocab_size: int = 98302 embedding_dims: int = 1032 @@ -1576,9 +1664,11 @@ class ProcessInput(nn.Module): embed_layer_name: str = 'embeddings' def setup(self): - self.embedding = self.param(self.embed_layer_name, - self.embedding_init, - (self.vocab_size, self.embedding_dims)) + self.embedding = self.param( + self.embed_layer_name, + self.embedding_init, + (self.vocab_size, self.embedding_dims), + ) @nn.compact def __call__(self, input_x: Tensor, train: bool = False) -> Tensor: @@ -1601,10 +1691,14 @@ def __call__(self, input_x: Tensor, train: bool = False) -> Tensor: raise ValueError(f'Query length ({self.query_shape}) is not equal 1.') if self.preprocess_dropout_a: - dropout_array = self.param('dropout_array', self.preprocessing_init, - shape) - input_x = jnp.where(dropout_array < self.preprocess_dropout_a, - jnp.zeros_like(input_x), input_x) + dropout_array = self.param( + 'dropout_array', self.preprocessing_init, shape + ) + input_x = jnp.where( + dropout_array < self.preprocess_dropout_a, + jnp.zeros_like(input_x), + input_x, + ) output = jnp.expand_dims(input_x, axis=-1) output = shift_right(output) @@ -1613,24 +1707,26 @@ def __call__(self, input_x: Tensor, train: bool = False) -> Tensor: if self.preprocess_dropout_b: output = nn.Dropout( - rate=self.preprocess_dropout_b, deterministic=not train)( - output) + rate=self.preprocess_dropout_b, deterministic=not train + )(output) output = nn.Dense( self.hidden_size, kernel_init=self.kernel_init, use_bias=True, - name='emb_dense')( - output) + name='emb_dense', + )(output) if self.add_timing_signal: output += get_timing_signal_1d( - length=self.max_target_length, channels=self.hidden_size) + length=self.max_target_length, channels=self.hidden_size + ) return output def process_partial_targets_decoding( - targets: Tensor, query_shape: Tuple[int] = (256,)) -> Tensor: + targets: Tensor, query_shape: Tuple[int, ...] = (256,) +) -> Tensor: """Preprocesses tokenized input sequences in the decoding process. Args: @@ -1644,13 +1740,11 @@ def process_partial_targets_decoding( seq_length = targets_shape[1] blocks_per_dim = [seq_length // q for q in query_shape] targets = jnp.reshape( - targets, - [ - targets_shape[0], -1, - np.prod(query_shape), 1 - ]) + targets, [targets_shape[0], -1, np.prod(query_shape), 1] + ) targets = unflatten_blocks_nd( - input_x=targets, blocks_per_dimension=blocks_per_dim) + input_x=targets, blocks_per_dimension=blocks_per_dim + ) targets = put_back_blocks_nd(input_x=targets, block_shape=query_shape) outputs = jnp.reshape(targets, [-1, seq_length]) return outputs @@ -1663,14 +1757,14 @@ class LayerPostProcess(nn.Module): post_attention_epsilon: an epsilon value for LayerNorm. post_attention_dropout: a dropout rate in the postprocessing process. """ + post_attention_epsilon: float = 1e-6 post_attention_dropout: float = 0.1 @nn.compact - def __call__(self, - input_x: Tensor, - input_y: Tensor, - train: bool = False) -> Tensor: + def __call__( + self, input_x: Tensor, input_y: Tensor, train: bool = False + ) -> Tensor: """Applies the LayerPostProcess module. Args: @@ -1684,13 +1778,15 @@ def __call__(self, output: a postprocessed array of shape [batch, max_target_length, hidden_size]. """ - y = nn.Dropout( - rate=self.post_attention_dropout, deterministic=not train)(input_y) + y = nn.Dropout(rate=self.post_attention_dropout, deterministic=not train)( + input_y + ) output = y + input_x output = nn.LayerNorm( - epsilon=self.post_attention_epsilon, name='layer_norm')(output) + epsilon=self.post_attention_epsilon, name='layer_norm' + )(output) return output @@ -1728,6 +1824,7 @@ class DecoderBlock(nn.Module): kernel_init: kernel initializer in the Dense layers. embedding_init: embeddings initializer. """ + layer: int = None decoder_dropout_a: float = 0.1 post_attention_epsilon: float = 1e-6 @@ -1738,9 +1835,9 @@ class DecoderBlock(nn.Module): memory_antecedent: Tensor = None total_key_depth: int = 1032 total_value_depth: int = 1032 - query_shape: Tuple[int] = (256,) - memory_query_shape: Tuple[int] = (512,) - memory_flange: int = (256,) + query_shape: Tuple[int, ...] = (256,) + memory_query_shape: Tuple[int, ...] = (512,) + memory_flange: Tuple[int, ...] = (256,) local_num_heads: int = 8 local_relative: bool = True masked: bool = True @@ -1768,8 +1865,9 @@ def __call__(self, input_x: Tensor, train: bool = False) -> Tensor: Returns: output: an array of shape [batch, max_target_length, hidden_size]. """ - x = nn.Dropout( - rate=self.decoder_dropout_a, deterministic=not train)(input_x) + x = nn.Dropout(rate=self.decoder_dropout_a, deterministic=not train)( + input_x + ) if self.cache is None: cache = {} @@ -1809,29 +1907,31 @@ def __call__(self, input_x: Tensor, train: bool = False) -> Tensor: token_bias=self.token_bias, padding_bias=self.padding_bias, kernel_init=self.kernel_init, - embedding_init=self.embedding_init)(query_antecedent=x, train=train) + embedding_init=self.embedding_init, + )(query_antecedent=x, train=train) x = LayerPostProcess( post_attention_epsilon=self.post_attention_epsilon, - post_attention_dropout=self.post_attention_dropout)( - input_x=x, input_y=y, train=train) + post_attention_dropout=self.post_attention_dropout, + )(input_x=x, input_y=y, train=train) y = FeedForward( feedforward_depths=self.feedforward_depths, feedforward_dropout=self.feedforward_dropout, - kernel_init=self.kernel_init)( - input_x=x, train=train) + kernel_init=self.kernel_init, + )(input_x=x, train=train) output = LayerPostProcess( post_attention_epsilon=self.post_attention_epsilon, - post_attention_dropout=self.post_attention_dropout)( - input_x=x, input_y=y, train=train) + post_attention_dropout=self.post_attention_dropout, + )(input_x=x, input_y=y, train=train) if self.decode_step is not None: output = get_item_at_decode_step( input_array=output, decode_step=self.decode_step, - query_shape=self.query_shape) + query_shape=self.query_shape, + ) return output @@ -1847,11 +1947,11 @@ class LocalAttentionTransformerArchitecture(nn.Module): query_shape: a tuple with a query shape. max_target_length: the maximum allowed length of the tokenized sequence. preprocess_dropout_a: a dropout rate in the first preprocessing Dropout - layer. + layer. embedding_dims: an embedding dimension. vocab_size: vocabulary size. preprocess_dropout_b: a dropout rate in the second preprocessing Dropout - layer. + layer. hidden_size: a hidden size depth. add_timing_signal: whether to add positional encoding. num_decoder_layers: number of decoder blocks/layers to use in the @@ -1880,12 +1980,13 @@ class LocalAttentionTransformerArchitecture(nn.Module): feedforward_depths: number of neurons in the 1st and 2nd Dense layers. dtype: data types of the final logits. """ + kernel_init: Any = nn.initializers.glorot_uniform() preprocessing_init: Any = nn.initializers.uniform() embedding_init: Any = nn.initializers.glorot_uniform() decode: bool = False decode_step: int = None - query_shape: Tuple[int] = (256,) + query_shape: Tuple[int, ...] = (256,) max_target_length: int = 8192 preprocess_dropout_a: float = 0.0 embedding_dims: int = 1032 @@ -1901,8 +2002,8 @@ class LocalAttentionTransformerArchitecture(nn.Module): bias_cache: Dict[str, Tensor] = None local_num_heads: int = 8 cache: Dict[str, Tensor] = None - memory_query_shape: Tuple[int] = (512,) - memory_flange: Tuple[int] = (256,) + memory_query_shape: Tuple[int, ...] = (512,) + memory_flange: Tuple[int, ...] = (256,) cache_padding_bias: bool = False max_relative_position: int = 513 attention_dropout: float = 0.0 @@ -1918,9 +2019,7 @@ class LocalAttentionTransformerArchitecture(nn.Module): dtype: str = 'float32' @nn.compact - def __call__(self, - input_x: Tensor, - train: bool = False) -> Tensor: + def __call__(self, input_x: Tensor, train: bool = False) -> Tensor: """Applies the TransformerLocalAttentionArchitecture module. Args: @@ -1941,7 +2040,8 @@ def __call__(self, if self.decode: x = process_partial_targets_decoding( - targets=x, query_shape=self.query_shape) + targets=x, query_shape=self.query_shape + ) x = ProcessInput( preprocess_dropout_a=self.preprocess_dropout_a, @@ -1953,8 +2053,8 @@ def __call__(self, max_target_length=self.max_target_length, kernel_init=self.kernel_init, preprocessing_init=self.preprocessing_init, - embedding_init=self.embedding_init)( - input_x=x, train=train) + embedding_init=self.embedding_init, + )(input_x=x, train=train) for layer in range(self.num_decoder_layers): y = DecoderBlock( @@ -1984,14 +2084,15 @@ def __call__(self, feedforward_depths=self.feedforward_depths, feedforward_dropout=self.feedforward_dropout, kernel_init=self.kernel_init, - embedding_init=self.embedding_init)( - input_x=x, train=train) + embedding_init=self.embedding_init, + )(input_x=x, train=train) output = nn.Dense( self.vocab_size, kernel_init=self.kernel_init, use_bias=True, - name='final_dense_2')(y) + name='final_dense_2', + )(y) output = output.astype(self.dtype) return output @@ -2004,7 +2105,8 @@ def evaluate_batch(self, params, batch_stats, batch): """Returns evaulation metrics on the given batch.""" variables = {'params': params, 'batch_stats': batch_stats} logits = self.flax_module.apply( - variables, batch['inputs'], mutable=False, train=False) + variables, batch['inputs'], mutable=False, train=False + ) targets = batch['targets'] # Class 0 is reserved for padding weights = jnp.not_equal(targets, 0).astype(jnp.float32) @@ -2031,7 +2133,8 @@ def training_cost(self, params, batch, batch_stats=None, dropout_rng=None): apply_kwargs['rngs'] = {'dropout': dropout_rng} logits, new_batch_stats = self.flax_module.apply( - all_variables, batch['inputs'], **apply_kwargs) + all_variables, batch['inputs'], **apply_kwargs + ) targets = batch['targets'] # Class 0 is reserved for padding weights = jnp.not_equal(targets, 0).astype(jnp.float32) @@ -2081,7 +2184,8 @@ def build_flax_module(self): post_attention_dropout=self.hps.post_attention_dropout, feedforward_dropout=self.hps.feedforward_dropout, feedforward_depths=self.hps.feedforward_depths, - dtype=self.hps.model_dtype) + dtype=self.hps.model_dtype, + ) def get_fake_inputs(self, hps): """Helper method solely for the purpose of initialzing the model.""" diff --git a/init2winit/model_lib/lstm.py b/init2winit/model_lib/lstm.py index 96225b60..a4b95ffc 100644 --- a/init2winit/model_lib/lstm.py +++ b/init2winit/model_lib/lstm.py @@ -16,6 +16,7 @@ """Module with flax LSTM class. """ + import abc import functools from typing import (Any, Mapping, Optional, Sequence, Tuple, Type, Union) @@ -64,10 +65,11 @@ def flip_sequences(inputs: Array, lengths: Array) -> Array: return inputs[idxs] -def sample_recurrent_dropout_mask(rng: Any, rate: float, batch_size: int, - hidden_size: int) -> Optional[Array]: +def sample_recurrent_dropout_mask( + rng: Any, rate: float, batch_size: int, hidden_size: int +) -> Optional[Array]: """Samples a recurrent dropout mask.""" - if rate == 0.: + if rate == 0.0: return None mask = random.bernoulli(rng, p=1 - rate, shape=(batch_size, hidden_size)) # Scale recurrent dropout mask to control for magnitude at test time. @@ -77,28 +79,34 @@ def sample_recurrent_dropout_mask(rng: Any, rate: float, batch_size: int, class RecurrentDropoutCell(abc.ABC): """Interface for cells that know how to apply recurrent dropout.""" - def __call__(self, - cell_state: StateType, - inputs: Array, - recurrent_dropout_mask: Optional[Array], - deterministic: bool = False): + def __call__( + self, + cell_state: StateType, + inputs: Array, + recurrent_dropout_mask: Optional[Array], + deterministic: bool = False, + ): pass - def get_recurrent_dropout_mask(self, rate: float, batch_size: int, - hidden_size: int): + def get_recurrent_dropout_mask( + self, rate: float, batch_size: int, hidden_size: int + ): pass -class RecurrentDropoutOptimizedLSTMCell(nn.OptimizedLSTMCell, - RecurrentDropoutCell): +class RecurrentDropoutOptimizedLSTMCell( + nn.OptimizedLSTMCell, RecurrentDropoutCell +): """An optimized LSTM cell that applies recurrent dropout on h (and not c).""" @nn.compact - def __call__(self, # pytype: disable=signature-mismatch # jax-ndarray - cell_state: Tuple[Array, Array], - inputs: Array, - recurrent_dropout_mask: Optional[Array] = None, - deterministic: bool = False): + def __call__( + self, # pytype: disable=signature-mismatch # jax-ndarray + cell_state: Tuple[Array, Array], + inputs: Array, + recurrent_dropout_mask: Optional[Array] = None, + deterministic: bool = False, + ): """Applies recurrent dropout on h in the state and performs one step.""" if not deterministic and recurrent_dropout_mask is not None: c, h = cell_state @@ -106,8 +114,9 @@ def __call__(self, # pytype: disable=signature-mismatch # jax-ndarray return super().__call__(cell_state, inputs) # pylint: disable=no-value-for-parameter - def get_recurrent_dropout_mask(self, rate: float, batch_size: int, - hidden_size: int): + def get_recurrent_dropout_mask( + self, rate: float, batch_size: int, hidden_size: int + ): """Returns a recurrent dropout mask for this cell.""" rng = self.make_rng('dropout') return sample_recurrent_dropout_mask(rng, rate, batch_size, hidden_size) @@ -128,6 +137,7 @@ class GenericRNNSequenceEncoder(nn.Module): greater than zero, you must use an RNN cell that implements `RecurrentDropoutCell` such as RecurrentDropoutOptimizedLSTMCell. """ + hidden_size: int cell_type: Type[nn.RNNCellBase] cell_kwargs: Mapping[str, Any] = flax.core.FrozenDict() @@ -141,9 +151,15 @@ def setup(self): variable_broadcast=('params', 'params_axes'), in_axes=(1, flax.core.axes_scan.broadcast, flax.core.axes_scan.broadcast), out_axes=1, - split_rngs={'params': False}) - def unroll_cell(self, cell_state: StateType, inputs: Array, - recurrent_dropout_mask: Optional[Array], deterministic: bool): + split_rngs={'params': False}, + ) + def unroll_cell( + self, + cell_state: StateType, + inputs: Array, + recurrent_dropout_mask: Optional[Array], + deterministic: bool, + ): """Unrolls a recurrent cell over an input sequence. Args: @@ -162,19 +178,22 @@ def unroll_cell(self, cell_state: StateType, inputs: Array, # This returns both the state and the output, so we can slice out the # correct final states later. if isinstance(self.cell, RecurrentDropoutCell): - new_cell_state, output = self.cell(cell_state, inputs, - recurrent_dropout_mask, deterministic) + new_cell_state, output = self.cell( + cell_state, inputs, recurrent_dropout_mask, deterministic + ) else: new_cell_state, output = self.cell(cell_state, inputs) return new_cell_state, (new_cell_state, output) - def __call__(self, - inputs: Array, - lengths: Array, - initial_state: StateType, - reverse: bool = False, - deterministic: bool = False): + def __call__( + self, + inputs: Array, + lengths: Array, + initial_state: StateType, + reverse: bool = False, + deterministic: bool = False, + ): """Unrolls the RNN cell over the inputs. Arguments: @@ -198,26 +217,29 @@ def __call__(self, inputs = flip_sequences(inputs, lengths) # Sample a recurrent dropout mask if recurrent dropout is requested. - if self.recurrent_dropout_rate > 0. and not deterministic: + if self.recurrent_dropout_rate > 0.0 and not deterministic: if not isinstance(self.cell, RecurrentDropoutCell): - raise ValueError( - ('The provided cell does not support recurrent dropout, but ' - f'recurrent_dropout_rate is set to {self.recurrent_dropout_rate}. ' - 'Please provide a cell that implements `RecurrentDropoutCell`.')) + raise ValueError(( + 'The provided cell does not support recurrent dropout, but ' + f'recurrent_dropout_rate is set to {self.recurrent_dropout_rate}. ' + 'Please provide a cell that implements `RecurrentDropoutCell`.' + )) recurrent_dropout_mask = self.cell.get_recurrent_dropout_mask( rate=self.recurrent_dropout_rate, batch_size=inputs.shape[0], - hidden_size=self.hidden_size) + hidden_size=self.hidden_size, + ) else: recurrent_dropout_mask = None - _, (cell_states, outputs) = self.unroll_cell(initial_state, inputs, - recurrent_dropout_mask, - deterministic) + _, (cell_states, outputs) = self.unroll_cell( + initial_state, inputs, recurrent_dropout_mask, deterministic + ) final_state = jax.tree.map( - lambda x: x[jnp.arange(inputs.shape[0]), lengths - 1], cell_states) + lambda x: x[jnp.arange(inputs.shape[0]), lengths - 1], cell_states + ) if reverse: outputs = flip_sequences(outputs, lengths) @@ -246,10 +268,11 @@ class GenericRNN(nn.Module): residual_connections: Add residual connection between layers. cell_kwargs: Optional keyword arguments to instantiate the cell with. """ + cell_type: Type[nn.RNNCellBase] hidden_sizes: Sequence[int] - dropout_rate: float = 0. - recurrent_dropout_rate: float = 0. + dropout_rate: float = 0.0 + recurrent_dropout_rate: float = 0.0 bidirectional: bool = False residual_connections: bool = False cell_kwargs: Mapping[str, Any] = flax.core.FrozenDict() @@ -260,7 +283,8 @@ def __call__( inputs: Array, lengths: Array, initial_states: Optional[Sequence[StateType]] = None, - deterministic: bool = False) -> Tuple[Array, Sequence[StateType]]: + deterministic: bool = False, + ) -> Tuple[Array, Sequence[StateType]]: """Processes the input sequence using the recurrent cell. Args: @@ -353,15 +377,16 @@ def __call__( if self.residual_connections: assert inputs.shape == outputs.shape, ( f'For residual connections, inputs ({inputs.shape}) and ' - f'outputs ({outputs.shape}) must be the same shape.') + f'outputs ({outputs.shape}) must be the same shape.' + ) inputs += outputs else: inputs = outputs # Apply dropout between layers. - inputs = nn.Dropout( - rate=self.dropout_rate, deterministic=deterministic)( - inputs) + inputs = nn.Dropout(rate=self.dropout_rate, deterministic=deterministic)( + inputs + ) return outputs, final_states @@ -384,9 +409,10 @@ class LSTM(nn.Module): best for hidden sizes up to 2048. cell_kwargs: Optional keyword arguments to instantiate the cell with. """ + hidden_sizes: Sequence[int] - dropout_rate: float = 0. - recurrent_dropout_rate: float = 0. + dropout_rate: float = 0.0 + recurrent_dropout_rate: float = 0.0 bidirectional: bool = False residual_connections: bool = False cell_type: Any = RecurrentDropoutOptimizedLSTMCell @@ -398,7 +424,8 @@ def __call__( inputs: Array, lengths: Array, initial_states: Optional[Sequence[StateType]] = None, - deterministic: bool = False) -> Tuple[Array, Sequence[StateType]]: + deterministic: bool = False, + ) -> Tuple[Array, Sequence[StateType]]: """Processes an input sequence with an LSTM cell. Example usage: diff --git a/init2winit/model_lib/lstm_lm.py b/init2winit/model_lib/lstm_lm.py index 5bc2b554..695e0246 100644 --- a/init2winit/model_lib/lstm_lm.py +++ b/init2winit/model_lib/lstm_lm.py @@ -18,6 +18,7 @@ Inspired by https://github.com/pytorch/examples/blob/main/word_language_model/model.py """ + from typing import Any, Mapping, Tuple, Union import flax @@ -55,7 +56,8 @@ def shift_right(x, axis=1): pad_widths = [(0, 0)] * len(x.shape) pad_widths[axis] = (1, 0) padded = jnp.pad( - x, pad_widths, mode='constant', constant_values=x.dtype.type(MASK_TOKEN)) + x, pad_widths, mode='constant', constant_values=x.dtype.type(MASK_TOKEN) + ) return lax.dynamic_slice_in_dim(padded, 0, padded.shape[axis] - 1, axis) @@ -73,8 +75,8 @@ class LSTMLM(nn.Module): bidirectional: Process the sequence left-to-right and right-to-left and contatenate the outputs of the two directions. residual_connections: Add residual connection between layers. - tie_embeddings: If true share weights of the input embedding layer - with the output embeddings. + tie_embeddings: If true share weights of the input embedding layer with the + output embeddings. projection_layer: If true add projection layer after LSTM from LSTM output size (hidden_sizes[-1]) to emb_dim. Useful if tie_embeddings is True and hidden_sizes[-1] is not equal to emb_dim. @@ -83,11 +85,12 @@ class LSTMLM(nn.Module): using `flax.linen.LSTMCell` instead. cell_kwargs: Optional keyword arguments to instatiate the cell with. """ + emb_dim: int vocab_size: int hidden_sizes: list[int] - dropout_rate: float = 0. - recurrent_dropout_rate: float = 0. + dropout_rate: float = 0.0 + recurrent_dropout_rate: float = 0.0 bidirectional: bool = False residual_connections: bool = False tie_embeddings: bool = False @@ -148,8 +151,9 @@ def __call__( deterministic=(not train), ) # Apply dropout on outputs - output = nn.Dropout( - rate=self.dropout_rate, deterministic=(not train))(output) + output = nn.Dropout(rate=self.dropout_rate, deterministic=(not train))( + output + ) # Optionally apply projection layer if self.projection_layer: @@ -174,8 +178,8 @@ def evaluate_batch(self, params, batch_stats, batch): metric_fn(logits, targets, weights), including the argument names. Args: - params: A dict of trainable model parameters. Passed as - {'params': params} into flax_module.apply(). + params: A dict of trainable model parameters. Passed as {'params': params} + into flax_module.apply(). batch_stats: A dict of non-trainable model state. Passed as {'batch_stats': batch_stats} into flax_module.apply(). batch: A dictionary with keys 'inputs', 'targets', 'weights'. @@ -218,6 +222,7 @@ def build_flax_module(self): ) def get_fake_inputs(self, hps): - dummy_inputs = jnp.ones((hps.batch_size, hps.sequence_length), - dtype='int32') + dummy_inputs = jnp.ones( + (hps.batch_size, hps.sequence_length), dtype='int32' + ) return [dummy_inputs] diff --git a/init2winit/model_lib/max_pooling_cnn.py b/init2winit/model_lib/max_pooling_cnn.py index d73d0a64..c851195b 100644 --- a/init2winit/model_lib/max_pooling_cnn.py +++ b/init2winit/model_lib/max_pooling_cnn.py @@ -18,6 +18,7 @@ This model can be used to implement the 3c3d architecture from: https://github.com/fsschneider/DeepOBS/blob/master/deepobs/tensorflow/testproblems/_3c3d.py """ + from typing import Any, Sequence from flax import linen as nn @@ -25,23 +26,23 @@ from init2winit.model_lib import model_utils from jax.nn import initializers import jax.numpy as jnp - from ml_collections.config_dict import config_dict - # small hparams used for unit tests -DEFAULT_HPARAMS = config_dict.ConfigDict(dict( - num_filters=[64, 96, 128], - kernel_sizes=[5, 3, 3], - kernel_paddings=['VALID', 'VALID', 'SAME'], - window_sizes=[3, 3, 3], - window_paddings=['SAME', 'SAME', 'SAME'], - strides=[2, 2, 2], - num_dense_units=[512, 256], - activation_fn='relu', - normalizer='none', - model_dtype='float32', -)) +DEFAULT_HPARAMS = config_dict.ConfigDict( + dict( + num_filters=[64, 96, 128], + kernel_sizes=[5, 3, 3], + kernel_paddings=['VALID', 'VALID', 'SAME'], + window_sizes=[3, 3, 3], + window_paddings=['SAME', 'SAME', 'SAME'], + strides=[2, 2, 2], + num_dense_units=[512, 256], + activation_fn='relu', + normalizer='none', + model_dtype='float32', + ) +) class MaxPoolingCNN(nn.Module): @@ -49,6 +50,7 @@ class MaxPoolingCNN(nn.Module): The model assumes the input shape is [batch, H, W, C]. """ + num_outputs: int num_filters: Sequence[int] kernel_sizes: Sequence[int] @@ -66,31 +68,47 @@ class MaxPoolingCNN(nn.Module): def __call__(self, x, train): maybe_normalize = model_utils.get_normalizer(self.normalizer, train) iterator = zip( - self.num_filters, self.kernel_sizes, self.kernel_paddings, - self.window_sizes, self.window_paddings, self.strides) - for num_filters, kernel_size, kernel_padding, window_size, window_padding, stride in iterator: + self.num_filters, + self.kernel_sizes, + self.kernel_paddings, + self.window_sizes, + self.window_paddings, + self.strides, + ) + for ( + num_filters, + kernel_size, + kernel_padding, + window_size, + window_padding, + stride, + ) in iterator: x = nn.Conv( - num_filters, (kernel_size, kernel_size), (1, 1), + num_filters, + (kernel_size, kernel_size), + (1, 1), padding=kernel_padding, kernel_init=self.kernel_init, - bias_init=self.bias_init)(x) + bias_init=self.bias_init, + )(x) x = model_utils.ACTIVATIONS[self.activation_fn](x) x = maybe_normalize()(x) x = nn.max_pool( x, window_shape=(window_size, window_size), strides=(stride, stride), - padding=window_padding) + padding=window_padding, + ) x = jnp.reshape(x, (x.shape[0], -1)) for num_units in self.num_dense_units: x = nn.Dense( - num_units, kernel_init=self.kernel_init, bias_init=self.bias_init)(x) + num_units, kernel_init=self.kernel_init, bias_init=self.bias_init + )(x) x = model_utils.ACTIVATIONS[self.activation_fn](x) x = maybe_normalize()(x) x = nn.Dense( - self.num_outputs, - kernel_init=self.kernel_init, - bias_init=self.bias_init)(x) + self.num_outputs, kernel_init=self.kernel_init, bias_init=self.bias_init + )(x) return x @@ -109,7 +127,8 @@ def build_flax_module(self): strides=self.hps.strides, num_dense_units=self.hps.num_dense_units, activation_fn=self.hps.activation_fn, - normalizer=self.hps.normalizer) + normalizer=self.hps.normalizer, + ) def get_fake_inputs(self, hps): """Helper method solely for the purpose of initialzing the model.""" diff --git a/init2winit/model_lib/metrics_minimized_registry.py b/init2winit/model_lib/metrics_minimized_registry.py index cc15b5ed..9351862c 100644 --- a/init2winit/model_lib/metrics_minimized_registry.py +++ b/init2winit/model_lib/metrics_minimized_registry.py @@ -18,10 +18,10 @@ This file is useful when writing configs that perform tuning studies. In general, users should call is_minimized on eval metric column names. """ + import collections import itertools - MIN_EVAL_METRICS = [ 'ce_loss', 'error_rate', @@ -40,7 +40,8 @@ def generate_eval_cols(metrics: collections.abc.Iterable[str]) -> list[str]: MINIMIZE_REGISTRY = {k: True for k in generate_eval_cols(MIN_EVAL_METRICS)} MINIMIZE_REGISTRY.update( - {k: False for k in generate_eval_cols(MAX_EVAL_METRICS)}) + {k: False for k in generate_eval_cols(MAX_EVAL_METRICS)} +) MINIMIZE_REGISTRY['train_cost'] = True MINIMIZE_REGISTRY['callback/wmt14_translate/de-en/valid/bleu_score'] = False MINIMIZE_REGISTRY['callback/wmt14_translate/de-en/test/bleu_score'] = False @@ -55,5 +56,7 @@ def is_minimized(col_name: str) -> bool: if col in col_name: return MINIMIZE_REGISTRY[col] - raise ValueError(f'Column {col_name} not found in `MINIMIZE_REGISTRY` as ' - 'either a column name or a substring of a column name.') + raise ValueError( + f'Column {col_name} not found in `MINIMIZE_REGISTRY` as ' + 'either a column name or a substring of a column name.' + ) diff --git a/init2winit/model_lib/mlperf_resnet.py b/init2winit/model_lib/mlperf_resnet.py index 5bbf85f3..ce070e99 100644 --- a/init2winit/model_lib/mlperf_resnet.py +++ b/init2winit/model_lib/mlperf_resnet.py @@ -14,8 +14,8 @@ # limitations under the License. """Flax implementation of the MLPerf ResNet V1.5 model.""" -import functools +import functools from typing import Any, Optional, Tuple from flax import linen as nn @@ -26,43 +26,48 @@ import jax.numpy as jnp from ml_collections.config_dict import config_dict - -FAKE_MODEL_DEFAULT_HPARAMS = config_dict.ConfigDict(dict( - num_filters=16, - num_layers=18, # Must be one of [18, 34, 50, 101, 152, 200] - model_dtype='float32', - virtual_batch_size=64, - data_format='NHWC', - activation_function='relu', - dropout_rate=0.0, -)) +FAKE_MODEL_DEFAULT_HPARAMS = config_dict.ConfigDict( + dict( + num_filters=16, + num_layers=18, # Must be one of [18, 34, 50, 101, 152, 200] + model_dtype='float32', + virtual_batch_size=64, + data_format='NHWC', + activation_function='relu', + dropout_rate=0.0, + ) +) # Used for the mlperf version of Resnet. -MLPERF_DEFAULT_HPARAMS = config_dict.ConfigDict(dict( - num_filters=16, - # We set default to 18 for faster unit tests. - num_layers=18, # Must be one of [18, 34, 50, 101, 152, 200] - bn_output_scale=0.0, - batch_norm_momentum=0.9, - batch_norm_epsilon=1e-5, - model_dtype='float32', - virtual_batch_size=64, - data_format='NHWC', - activation_function='relu', - dropout_rate=0.0, -)) +MLPERF_DEFAULT_HPARAMS = config_dict.ConfigDict( + dict( + num_filters=16, + # We set default to 18 for faster unit tests. + num_layers=18, # Must be one of [18, 34, 50, 101, 152, 200] + bn_output_scale=0.0, + batch_norm_momentum=0.9, + batch_norm_epsilon=1e-5, + model_dtype='float32', + virtual_batch_size=64, + data_format='NHWC', + activation_function='relu', + dropout_rate=0.0, + ) +) def _constant_init(factor): def init_fn(key, shape, dtype=jnp.float32): del key return jnp.ones(shape, dtype) * factor + return init_fn class ResidualBlock(nn.Module): """Bottleneck ResNet block.""" + filters: int strides: Tuple[int, int] = (1, 1) axis_name: Optional[str] = None @@ -90,14 +95,17 @@ def __call__(self, x, train): batch_size=self.batch_size, virtual_batch_size=self.virtual_batch_size, total_batch_size=self.total_batch_size, - data_format=self.data_format) + data_format=self.data_format, + ) conv = functools.partial(nn.Conv, use_bias=False, dtype=self.dtype) residual = x if needs_projection: - residual = conv( - self.filters * 4, (1, 1), self.strides, name='proj_conv')(residual) + residual = conv(self.filters * 4, (1, 1), self.strides, name='proj_conv')( + residual + ) residual = batch_norm(name='proj_bn')( - residual, use_running_average=not train) + residual, use_running_average=not train + ) y = conv(self.filters, (1, 1), name='conv1')(x) y = batch_norm(name='bn1')(y, use_running_average=not train) activation_fn = model_utils.ACTIVATIONS[self.activation_function] @@ -106,15 +114,16 @@ def __call__(self, x, train): y = batch_norm(name='bn2')(y, use_running_average=not train) y = activation_fn(y) y = conv(self.filters * 4, (1, 1), name='conv3')(y) - y = batch_norm( - name='bn3', scale_init=_constant_init(self.bn_output_scale))( - y, use_running_average=not train) + y = batch_norm(name='bn3', scale_init=_constant_init(self.bn_output_scale))( + y, use_running_average=not train + ) y = activation_fn(residual + y) return y class ResNet(nn.Module): """ResNetV1.""" + num_classes: int num_filters: int = 64 num_layers: int = 50 @@ -137,8 +146,14 @@ def __call__(self, x, train): raise ValueError('Please provide a valid number of layers') block_sizes = _block_size_options[self.num_layers] conv = functools.partial(nn.Conv, padding=[(3, 3), (3, 3)]) - x = conv(self.num_filters, kernel_size=(7, 7), strides=(2, 2), - use_bias=False, dtype=self.dtype, name='conv0')(x) + x = conv( + self.num_filters, + kernel_size=(7, 7), + strides=(2, 2), + use_bias=False, + dtype=self.dtype, + name='conv0', + )(x) x = normalization.VirtualBatchNorm( momentum=self.batch_norm_momentum, epsilon=self.batch_norm_epsilon, @@ -149,14 +164,15 @@ def __call__(self, x, train): batch_size=self.batch_size, virtual_batch_size=self.virtual_batch_size, total_batch_size=self.total_batch_size, - data_format=self.data_format)(x, use_running_average=not train) + data_format=self.data_format, + )(x, use_running_average=not train) x = model_utils.ACTIVATIONS[self.activation_function](x) # MLperf-required x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME') for i, block_size in enumerate(block_sizes): for j in range(block_size): strides = (2, 2) if i > 0 and j == 0 else (1, 1) x = ResidualBlock( - self.num_filters * 2 ** i, + self.num_filters * 2**i, strides=strides, axis_name=self.axis_name, axis_index_groups=self.axis_index_groups, @@ -169,14 +185,16 @@ def __call__(self, x, train): total_batch_size=self.total_batch_size, data_format=self.data_format, activation_function=self.activation_function, - )(x, train=train) + )(x, train=train) x = jnp.mean(x, axis=(1, 2)) if self.dropout_rate > 0.0: x = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(x) - x = nn.Dense(self.num_classes, kernel_init=nn.initializers.normal(), - dtype=self.dtype)(x) + x = nn.Dense( + self.num_classes, kernel_init=nn.initializers.normal(), dtype=self.dtype + )(x) return x + # a dictionary mapping the number of layers in a resnet to the number of blocks # in each stage of the model. _block_size_options = { @@ -185,12 +203,13 @@ def __call__(self, x, train): 50: [3, 4, 6, 3], 101: [3, 4, 23, 3], 152: [3, 8, 36, 3], - 200: [3, 24, 36, 3] + 200: [3, 24, 36, 3], } class FakeResNet(nn.Module): """Minimal NN (for debugging) with the same signature as a ResNet.""" + num_classes: int axis_name: Optional[str] = None axis_index_groups: Optional[Any] = None @@ -205,10 +224,12 @@ def __call__(self, x, train): name='init_bn', axis_name=self.axis_name, axis_index_groups=self.axis_index_groups, - dtype=self.dtype)(x) + dtype=self.dtype, + )(x) x = jnp.mean(x, axis=(1, 2)) - x = nn.Dense(self.num_classes, kernel_init=nn.initializers.normal(), - dtype=self.dtype)(x) + x = nn.Dense( + self.num_classes, kernel_init=nn.initializers.normal(), dtype=self.dtype + )(x) return x @@ -229,7 +250,8 @@ def build_flax_module(self): total_batch_size=self.hps.total_accumulated_batch_size, data_format=self.hps.data_format, activation_function=self.hps.activation_function, - dropout_rate=self.hps.dropout_rate) + dropout_rate=self.hps.dropout_rate, + ) def get_fake_inputs(self, hps): """Helper method solely for purpose of initializing the model.""" diff --git a/init2winit/model_lib/model_utils.py b/init2winit/model_lib/model_utils.py index 95a40a35..417e9d42 100644 --- a/init2winit/model_lib/model_utils.py +++ b/init2winit/model_lib/model_utils.py @@ -14,9 +14,9 @@ # limitations under the License. """Common code used by different models.""" + import enum import functools - from typing import Any, Callable, Dict, Iterable from absl import logging @@ -50,15 +50,17 @@ lecun_normal = functools.partial( initializers.variance_scaling, mode='fan_in', - distribution='truncated_normal') + distribution='truncated_normal', +) # This trick is used in fairseq's multihead attention. # https://github.com/facebookresearch/fairseq/blob/0.12.2-release/fairseq/modules/multihead_attention.py#L171 # pylint: disable=line-too-long xavier_uniform_over_sqrt2 = functools.partial( initializers.variance_scaling, - scale=1./2, + scale=1.0 / 2, mode='fan_avg', - distribution='uniform') + distribution='uniform', +) INITIALIZERS = { 'delta_orthogonal': initializers.delta_orthogonal, @@ -71,6 +73,7 @@ class ScalarMultiply(nn.Module): """Layer which multiplies by a single scalar.""" + scale_init: Any = initializers.ones @nn.compact @@ -78,13 +81,15 @@ def __call__(self, x): return x * self.param('scale', self.scale_init, ()) -def get_normalizer(normalizer, - train, - batch_size=None, - virtual_batch_size=None, - total_batch_size=None, - dtype=jnp.float32, - data_format='NHWC'): +def get_normalizer( + normalizer, + train, + batch_size=None, + virtual_batch_size=None, + total_batch_size=None, + dtype=jnp.float32, + data_format='NHWC', +): """Maps a string to the given normalizer function. We return a function that returns the normalization module, deferring the @@ -105,8 +110,8 @@ def get_normalizer(normalizer, Args: normalizer: One of ['batch_norm', 'virtual_batch_norm', 'layer_norm', 'none']. - train: Boolean indiciating if we are running in train or inference mode - for batch norm. + train: Boolean indiciating if we are running in train or inference mode for + batch norm. batch_size: only used for virtual batch norm, the batch size. virtual_batch_size: only used for virtual batch norm, the virtual batch size. @@ -128,7 +133,8 @@ def get_normalizer(normalizer, use_running_average=not train, momentum=0.9, epsilon=1e-5, - dtype=dtype) + dtype=dtype, + ) elif normalizer == 'virtual_batch_norm': return functools.partial( normalization.VirtualBatchNorm, @@ -139,18 +145,23 @@ def get_normalizer(normalizer, virtual_batch_size=virtual_batch_size, data_format=data_format, total_batch_size=total_batch_size, - dtype=dtype) + dtype=dtype, + ) elif normalizer in ['layer_norm', 'pre_layer_norm', 'post_layer_norm']: return functools.partial(nn.LayerNorm, dtype=dtype) elif normalizer == 'none': + def identity_wrapper(*args, **kwargs): del args del kwargs + def identity(x, *args, **kwargs): del args del kwargs return x + return identity + return identity_wrapper else: raise ValueError('Unknown normalizer: {}'.format(normalizer)) @@ -224,8 +235,8 @@ def l2_regularization(params, l2_decay_rank_threshold): Args: params: Pytree containing parameters. l2_decay_rank_threshold: The calculation will only include parameters with - param.ndim >= l2_decay_rank_threshold. Set to 2 to ignore all bias and - batch_norm params in the model. + param.ndim >= l2_decay_rank_threshold. Set to 2 to ignore all bias and + batch_norm params in the model. Returns: weight_l2: the squared l2 norm of all params matching the threshold. @@ -257,13 +268,15 @@ def flatten_dict(nested_dict, sep='/'): For example, if the dictionary is {'outer1': {'inner1': 1, 'inner2': 2}}. This will return {'/outer1/inner1': 1, '/outer1/inner2': 2}. With sep='/' this - will match how flax.traverse_util.ParamTreeTraversal flattens the keys, to allow for + will match how flax.traverse_util.ParamTreeTraversal flattens the keys, to + allow for easy filtering when using traversals. Requires the nested dictionaries contain no cycles. Args: nested_dict: A nested dictionary. sep: The separator to use when concatenating the dictionary strings. + Returns: The flattened dictionary. """ @@ -310,6 +323,7 @@ def rescale_layers(params, layer_rescale_factors): # Define this so that if using pytree iteration utilities, can iterate over the # model shapes pytree without iterating over the shape tuples. class ShapeTuple: + def __init__(self, shape_tuple): self.shape_tuple = shape_tuple @@ -326,6 +340,7 @@ def param_shapes(params): class ParameterType(enum.Enum): """Different types of neural network parameters.""" + WEIGHT = 0 BIAS = 1 CONV_WEIGHT = 2 @@ -361,7 +376,8 @@ def param_types(shapes, parent_name: str = '') -> Dict[str, ParameterType]: name = name.lower() if isinstance(value, dict) or isinstance(value, FrozenDict): param_types_dict[original_name] = param_types( - value, parent_name=parent_name + '/' + name) + value, parent_name=parent_name + '/' + name + ) else: if 'batchnorm' in parent_name or 'bn' in parent_name: if name == 'scale': @@ -370,7 +386,8 @@ def param_types(shapes, parent_name: str = '') -> Dict[str, ParameterType]: param_types_dict[original_name] = ParameterType.BATCH_NORM_BIAS else: raise ValueError( - f'Unrecognized batch norm parameter: {parent_name}/{name}.') + f'Unrecognized batch norm parameter: {parent_name}/{name}.' + ) elif ( 'layernorm' in parent_name or 'layer_norm_' in parent_name @@ -383,7 +400,8 @@ def param_types(shapes, parent_name: str = '') -> Dict[str, ParameterType]: param_types_dict[original_name] = ParameterType.LAYER_NORM_BIAS else: raise ValueError( - f'Unrecognized layer norm parameter: {parent_name}/{name}.') + f'Unrecognized layer norm parameter: {parent_name}/{name}.' + ) elif 'rmsnorm' in parent_name: if name == 'scale': param_types_dict[original_name] = ParameterType.RMSNORM_SCALE @@ -391,13 +409,15 @@ def param_types(shapes, parent_name: str = '') -> Dict[str, ParameterType]: param_types_dict[original_name] = ParameterType.RMSNORM_BIAS else: raise ValueError( - f'Unrecognized rms norm parameter: {parent_name}/{name}.') + f'Unrecognized rms norm parameter: {parent_name}/{name}.' + ) elif 'queryscaler' in parent_name: if name == 'scale': param_types_dict[original_name] = ParameterType.ATTENTION_SCALE else: raise ValueError( - f'Unrecognized query scaler parameter: {parent_name}/{name}.') + f'Unrecognized query scaler parameter: {parent_name}/{name}.' + ) elif 'conv' in parent_name: if 'bias' in name: param_types_dict[original_name] = ParameterType.BIAS @@ -406,8 +426,9 @@ def param_types(shapes, parent_name: str = '') -> Dict[str, ParameterType]: # Note that this is exact equality, not contained in, because # flax.linen.Embed names the embedding parameter "embedding" # https://github.com/google/flax/blob/main/flax/linen/linear.py#L604. - elif ('embedding' in name or - ('embedding' in parent_name and name == 'kernel')): + elif 'embedding' in name or ( + 'embedding' in parent_name and name == 'kernel' + ): param_types_dict[original_name] = ParameterType.EMBEDDING elif ( 'attention' in parent_name @@ -430,7 +451,8 @@ def param_types(shapes, parent_name: str = '') -> Dict[str, ParameterType]: param_types_dict[original_name] = ParameterType.ATTENTION_QKV else: raise ValueError( - f'Unrecognized attention parameter: {parent_name}/{name}.') + f'Unrecognized attention parameter: {parent_name}/{name}.' + ) elif 'lstm' in parent_name: if name == 'kernel': param_types_dict[original_name] = ParameterType.LSTM_WEIGHT @@ -438,7 +460,8 @@ def param_types(shapes, parent_name: str = '') -> Dict[str, ParameterType]: param_types_dict[original_name] = ParameterType.LSTM_BIAS else: raise ValueError( - f'Unrecognized attention parameter: {parent_name}/{name}.') + f'Unrecognized attention parameter: {parent_name}/{name}.' + ) elif 'subsample' in parent_name and 'dense' in parent_name: if name == 'kernel': param_types_dict[original_name] = ParameterType.SUBSAMPLE_WEIGHT @@ -451,8 +474,7 @@ def param_types(shapes, parent_name: str = '') -> Dict[str, ParameterType]: elif 'x' in name: param_types_dict[original_name] = ParameterType.NQM_PARAM else: - raise ValueError( - f'Unrecognized parameter: {parent_name}/{name}.') + raise ValueError(f'Unrecognized parameter: {parent_name}/{name}.') return param_types_dict diff --git a/init2winit/model_lib/models.py b/init2winit/model_lib/models.py index 51ea7df4..d8864dc9 100644 --- a/init2winit/model_lib/models.py +++ b/init2winit/model_lib/models.py @@ -45,7 +45,6 @@ from init2winit.model_lib import xformer_translate_binary from init2winit.model_lib import xformer_translate_mlc_variant - _ALL_MODELS = { 'fully_connected': ( fully_connected.FullyConnectedModel, diff --git a/init2winit/model_lib/nanodo.py b/init2winit/model_lib/nanodo.py index dac53a55..ef252b83 100644 --- a/init2winit/model_lib/nanodo.py +++ b/init2winit/model_lib/nanodo.py @@ -59,12 +59,14 @@ class DoConfig: L: int # sequence length kernel_init: nn.initializers.Initializer = nn.initializers.xavier_uniform() embed_init: nn.initializers.Initializer = nn.initializers.variance_scaling( - 1.0, 'fan_in', 'normal', out_axis=0) + 1.0, 'fan_in', 'normal', out_axis=0 + ) dtype: jnp.dtype = jnp.float32 class TransformerDo(nn.Module): """Transformer decoder-only.""" + docfg: DoConfig def setup(self): @@ -97,14 +99,14 @@ def __call__(self, y_BxL: jax.Array, train: bool): class Mlp(nn.Module): """Multilayer perceptron.""" + cfg: DoConfig @nn.compact def __call__(self, x_BxLxD: jax.Array): cfg = self.cfg linear = functools.partial( - nn.Dense, kernel_init=cfg.kernel_init, use_bias=False, - dtype=cfg.dtype + nn.Dense, kernel_init=cfg.kernel_init, use_bias=False, dtype=cfg.dtype ) x_BxLxF = linear(cfg.F)(x_BxLxD) x_BxLxF = jax.nn.gelu(x_BxLxF) @@ -114,6 +116,7 @@ def __call__(self, x_BxLxD: jax.Array): class TBlock(nn.Module): """Transformer Block.""" + docfg: DoConfig @nn.compact @@ -133,6 +136,7 @@ def __call__(self, in_BxLxD: jax.Array): class CausalAttn(nn.Module): """Causal attention layer.""" + cfg: DoConfig @nn.compact diff --git a/init2winit/model_lib/normalization.py b/init2winit/model_lib/normalization.py index d57f67f8..087ade9c 100644 --- a/init2winit/model_lib/normalization.py +++ b/init2winit/model_lib/normalization.py @@ -14,10 +14,10 @@ # limitations under the License. """Virtual Batch Normalization Flax module.""" + from typing import Any, Callable, Iterable, Optional from flax import linen as nn - from jax import lax from jax.nn import initializers import jax.numpy as jnp @@ -28,17 +28,16 @@ def _absolute_dims(rank, dims): def _get_batch_axis( - data_format, - x, - virtual_batch_size, - use_running_average, - axis_index_groups): + data_format, x, virtual_batch_size, use_running_average, axis_index_groups +): """Get the batch axis of input x, and check for a valid virtual batch size.""" if data_format: if 'N' not in data_format: raise ValueError( 'Could not locate batch axis "N" in `data_format={}`.'.format( - data_format)) + data_format + ) + ) batch_axis = data_format.index('N') else: batch_axis = 0 @@ -48,15 +47,19 @@ def _get_batch_axis( if data_format is None: raise ValueError( 'Must provide `data_format` when providing `virtual_batch_size` ' - 'to a VirtualBatchNorm layer.') + 'to a VirtualBatchNorm layer.' + ) if axis_index_groups is not None: raise ValueError( 'Only one of `virtual_batch_size` or `axis_index_groups` can ' - 'be provided to a VirtualBatchNorm layer.') + 'be provided to a VirtualBatchNorm layer.' + ) if virtual_batch_size < 1: raise ValueError( 'Must have a `virtual_batch_size` > 1, received {}.'.format( - virtual_batch_size)) + virtual_batch_size + ) + ) if x.shape[batch_axis] % virtual_batch_size != 0: raise ValueError( '`virtual_batch_size={}` must evenly divide ' @@ -64,7 +67,9 @@ def _get_batch_axis( 'batch size < hps.virtual_batch_size; either decrease the number of ' 'cores, increase the total batch size, or decrease ' 'hps.virtual_batch_size.'.format( - virtual_batch_size, x.shape[batch_axis])) + virtual_batch_size, x.shape[batch_axis] + ) + ) return batch_axis @@ -86,14 +91,14 @@ class VirtualBatchNorm(nn.Module): Attributes: x: the input to be normalized. axis: the feature or non-batch axis of the input. - momentum: decay rate for the exponential moving average of - the batch statistics. + momentum: decay rate for the exponential moving average of the batch + statistics. epsilon: a small float added to variance to avoid dividing by zero. dtype: the dtype of the computation (default: float32). use_bias: if True, bias (beta) is added. - use_scale: if True, multiply by scale (gamma). - When the next layer is linear (also e.g. nn.relu), this can be disabled - since the scaling will be done by the next layer. + use_scale: if True, multiply by scale (gamma). When the next layer is linear + (also e.g. nn.relu), this can be disabled since the scaling will be done + by the next layer. bias_init: initializer for bias, by default, zero. scale_init: initializer for scale, by default, one. axis_name: the axis name used to combine batch statistics from multiple @@ -106,12 +111,12 @@ class VirtualBatchNorm(nn.Module): batch_size: the batch size used for each forward pass. We must explicitly pass this instead of relying on `x.shape` because we may initialize the model with a batch of zeros with a different batch size. - virtual_batch_size: the size of the virtual batches to construct on - each device, which will be used to normalize sub-batches of each - per-device batch. Must evenly divide the per-device batch size (as - determined by `x`), and cannot be combined with `axis_index_groups`. - Passing the default value of None will replicate the existing nn.BatchNorm - behavior without virtual batches. + virtual_batch_size: the size of the virtual batches to construct on each + device, which will be used to normalize sub-batches of each per-device + batch. Must evenly divide the per-device batch size (as determined by + `x`), and cannot be combined with `axis_index_groups`. Passing the default + value of None will replicate the existing nn.BatchNorm behavior without + virtual batches. total_batch_size: only necessary when using gradient accumulation, the total batch size used to calculate accumulated gradients. This is required here because we need to store `total_batch_size // virtual_batch_size` EMAs @@ -119,6 +124,7 @@ class VirtualBatchNorm(nn.Module): data_format: only used when `virtual_batch_size` is set, to determine the batch axis. """ + use_running_average: Optional[bool] = None axis: int = -1 momentum: float = 0.99 @@ -155,7 +161,8 @@ def __call__(self, x, use_running_average: Optional[bool] = None): Normalized inputs (the same shape as inputs). """ use_running_average = nn.module.merge_param( - 'use_running_average', self.use_running_average, use_running_average) + 'use_running_average', self.use_running_average, use_running_average + ) virtual_batch_size = self.virtual_batch_size batch_axis = _get_batch_axis( @@ -163,7 +170,8 @@ def __call__(self, x, use_running_average: Optional[bool] = None): x, virtual_batch_size, use_running_average, - self.axis_index_groups) + self.axis_index_groups, + ) if virtual_batch_size is None: virtual_batch_size = self.batch_size @@ -196,30 +204,47 @@ def __call__(self, x, use_running_average: Optional[bool] = None): # `virtual_batch_size, which should only happen if we are initializing # with dummy variables (typically of batch size 2). min(x.shape[batch_axis], virtual_batch_size), - *x.shape[batch_axis + 1:]) + *x.shape[batch_axis + 1 :], + ) x = jnp.reshape(x, sub_batched_shape) - ra_mean = self.variable('batch_stats', 'batch_norm_running_mean', - lambda s: jnp.zeros(s, jnp.float32), - feature_shape) - ra_var = self.variable('batch_stats', 'batch_norm_running_var', - lambda s: jnp.ones(s, jnp.float32), - feature_shape) + ra_mean = self.variable( + 'batch_stats', + 'batch_norm_running_mean', + lambda s: jnp.zeros(s, jnp.float32), + feature_shape, + ) + ra_var = self.variable( + 'batch_stats', + 'batch_norm_running_var', + lambda s: jnp.ones(s, jnp.float32), + feature_shape, + ) # If using gradient accumulation, use these to accumulate the activations # for the current batch before folding them into the running average. mean_accumulator = self.variable( - 'batch_stats', 'batch_norm_mean_accumulator', - lambda s: jnp.zeros(s, jnp.float32), feature_shape) + 'batch_stats', + 'batch_norm_mean_accumulator', + lambda s: jnp.zeros(s, jnp.float32), + feature_shape, + ) mean2_accumulator = self.variable( - 'batch_stats', 'batch_norm_mean2_accumulator', - lambda s: jnp.zeros(s, jnp.float32), feature_shape) + 'batch_stats', + 'batch_norm_mean2_accumulator', + lambda s: jnp.zeros(s, jnp.float32), + feature_shape, + ) # A counter that is used to determine which accumulation pass we are # currently in. This will increment from 0 until we have accumulated # gradients calculated on `self.total_batch_size` examples. This should only # ever be saved on disk as 0 because we only checkpoint after accumulating # enough examples to make an update. - grad_accum_counter = self.variable('batch_stats', 'grad_accum_counter', - lambda s: jnp.zeros(s, jnp.int32), []) + grad_accum_counter = self.variable( + 'batch_stats', + 'grad_accum_counter', + lambda s: jnp.zeros(s, jnp.int32), + [], + ) # See NOTE above on initialization behavior. initializing = self.is_mutable_collection('params') @@ -237,15 +262,17 @@ def __call__(self, x, use_running_average: Optional[bool] = None): else: # Shape (num_sub_batches, x.shape[axis]). mean = jnp.mean(x, axis=reduction_axis, keepdims=False) - mean2 = jnp.mean( - lax.square(x), axis=reduction_axis, keepdims=False) + mean2 = jnp.mean(lax.square(x), axis=reduction_axis, keepdims=False) if self.axis_name is not None and not initializing: concatenated_mean = jnp.concatenate([mean, mean2]) mean, mean2 = jnp.split( lax.pmean( concatenated_mean, axis_name=self.axis_name, - axis_index_groups=self.axis_index_groups), 2) + axis_index_groups=self.axis_index_groups, + ), + 2, + ) var = mean2 - lax.square(mean) if not initializing: @@ -257,19 +284,22 @@ def __call__(self, x, use_running_average: Optional[bool] = None): # running averages. should_update_ra = grad_accum_counter_inc // passes_per_total_batch ra_mean_update = ( - should_update_ra * mean_accumulator.value / grad_accum_counter_inc) + should_update_ra * mean_accumulator.value / grad_accum_counter_inc + ) ra_mean.value = ( - (1 - should_update_ra * (1 - self.momentum)) * ra_mean.value + - (1 - self.momentum) * ra_mean_update) + 1 - should_update_ra * (1 - self.momentum) + ) * ra_mean.value + (1 - self.momentum) * ra_mean_update ra_var_update = should_update_ra * ( - mean2_accumulator.value / grad_accum_counter_inc - - lax.square(mean_accumulator.value / grad_accum_counter_inc)) + mean2_accumulator.value / grad_accum_counter_inc + - lax.square(mean_accumulator.value / grad_accum_counter_inc) + ) ra_var.value = ( - (1 - should_update_ra * (1 - self.momentum)) * ra_var.value + - (1 - self.momentum) * ra_var_update) + 1 - should_update_ra * (1 - self.momentum) + ) * ra_var.value + (1 - self.momentum) * ra_var_update grad_accum_counter.value = ( - grad_accum_counter_inc % passes_per_total_batch) + grad_accum_counter_inc % passes_per_total_batch + ) # Reset the activation accumulators every `passes_per_total_batch` steps # (np.sign == 0 if grad_accum_counter == 0). mean_accumulator.value *= jnp.sign(grad_accum_counter.value) @@ -277,15 +307,16 @@ def __call__(self, x, use_running_average: Optional[bool] = None): y = x - mean.reshape((num_sub_batches, *feature_shape)) mul = lax.rsqrt( - var.reshape((num_sub_batches, *feature_shape)) + self.epsilon) + var.reshape((num_sub_batches, *feature_shape)) + self.epsilon + ) if self.use_scale: mul = mul * self.param( - 'scale', self.scale_init, reduced_feature_shape).reshape( - (1, *feature_shape)) + 'scale', self.scale_init, reduced_feature_shape + ).reshape((1, *feature_shape)) y = y * mul if self.use_bias: - y = y + self.param( - 'bias', self.bias_init, reduced_feature_shape).reshape( - (1, *feature_shape)) + y = y + self.param('bias', self.bias_init, reduced_feature_shape).reshape( + (1, *feature_shape) + ) y = jnp.reshape(y, input_shape) return jnp.asarray(y, self.dtype) diff --git a/init2winit/model_lib/nqm.py b/init2winit/model_lib/nqm.py index baa8a350..d43ba606 100644 --- a/init2winit/model_lib/nqm.py +++ b/init2winit/model_lib/nqm.py @@ -16,6 +16,7 @@ r"""NQM Model. """ + from flax import linen as nn from init2winit.model_lib import base_model from init2winit.model_lib import model_utils @@ -26,13 +27,15 @@ from scipy.stats import ortho_group # small hparams used for unit tests -DEFAULT_HPARAMS = config_dict.ConfigDict(dict( - # Note the dimension is set by input_shape. - hessian_decay_power=1, - noise_decay_power=1, - nqm_mode='diagH_diagC', - model_dtype='float32', -)) +DEFAULT_HPARAMS = config_dict.ConfigDict( + dict( + # Note the dimension is set by input_shape. + hessian_decay_power=1, + noise_decay_power=1, + nqm_mode='diagH_diagC', + model_dtype='float32', + ) +) class NQMLoss(nn.Module): @@ -46,12 +49,13 @@ class NQMLoss(nn.Module): Attrs: hessian: dxd psd matrix representing the loss hessian. - noise_scaling: dxd matrix used to scale noise_input. This is a matrix N - s.t. N^T N = C, where C will be the covariance of the scaled noise. + noise_scaling: dxd matrix used to scale noise_input. This is a matrix N s.t. + N^T N = C, where C will be the covariance of the scaled noise. train: Ignored, we add this to conform to the i2w model API. noise_input = jnp.asarray(noise_input) x = self.param('x', (noise_input.shape[-1],), initializers.ones) """ + hessian: model_utils.Array noise_scaling: model_utils.Array train: bool = True @@ -74,17 +78,17 @@ def __call__(self, noise_input, train): # NQM loss = 1/2 x^T hessian x + x.T noise_scaling noise_input # this gives grad loss = hessian x + eps, where eps ~ N(0, C) return jnp.dot(jnp.dot(x.T, self.hessian), x) / 2 + jnp.mean( - jnp.dot(jnp.dot(noise_input, self.noise_scaling), x)) + jnp.dot(jnp.dot(noise_input, self.noise_scaling), x) + ) def quadratic_form(u, sigma): return np.dot(np.dot(u.T, sigma), u) -def _get_nqm_matrices(dim, - hessian_decay_power=1.0, - noise_decay_power=1.0, - mode='diagH_noC'): +def _get_nqm_matrices( + dim, hessian_decay_power=1.0, noise_decay_power=1.0, mode='diagH_noC' +): """Returns a hessian and a noise scaling matrix used in the NQM. The corresponding loss will be equal to 1/2 x^T H x + eps sqrtC x. @@ -114,7 +118,7 @@ def _get_nqm_matrices(dim, Args: dim: dimension of the matrices to generate. hessian_decay_power: Hessian eigenvalues will be of the form 1 / i^power. - noise_decay_power : Noise eigenvalues will be of the form 1 / i^power. + noise_decay_power : Noise eigenvalues will be of the form 1 / i^power. mode: One of the modes listed above. Returns: @@ -122,10 +126,12 @@ def _get_nqm_matrices(dim, sqrtC^T sqrtC = C. """ hessian_eigs = np.array( - [1.0 / np.power(i, hessian_decay_power) for i in range(1, dim + 1)]) + [1.0 / np.power(i, hessian_decay_power) for i in range(1, dim + 1)] + ) hessian_eigs = np.diag(hessian_eigs) noise_scaling_eigs = np.array( - [1.0 / np.power(i, noise_decay_power / 2.0) for i in range(1, dim + 1)]) + [1.0 / np.power(i, noise_decay_power / 2.0) for i in range(1, dim + 1)] + ) noise_scaling_eigs = np.diag(noise_scaling_eigs) @@ -137,19 +143,25 @@ def _get_nqm_matrices(dim, ortho_matrix = ortho_group.rvs(dim=dim) # H = U^T Sigma U if mode == 'H_noC': - return (quadratic_form(ortho_matrix, hessian_eigs), - np.zeros_like(noise_scaling_eigs)) + return ( + quadratic_form(ortho_matrix, hessian_eigs), + np.zeros_like(noise_scaling_eigs), + ) elif mode == 'H_codiagC': # noise matrix = noise_scaling_eigs U - return (quadratic_form(ortho_matrix, hessian_eigs), - np.dot(noise_scaling_eigs, ortho_matrix)) + return ( + quadratic_form(ortho_matrix, hessian_eigs), + np.dot(noise_scaling_eigs, ortho_matrix), + ) elif mode == 'H_offdiagC': # Sample a new rotation matrix for the noise c_ortho_matrix = ortho_group.rvs(dim=dim) - return (quadratic_form(ortho_matrix, hessian_eigs), - np.dot(noise_scaling_eigs, c_ortho_matrix)) + return ( + quadratic_form(ortho_matrix, hessian_eigs), + np.dot(noise_scaling_eigs, c_ortho_matrix), + ) elif mode == 'diagH_offdiagC': # Sample a new rotation matrix for the noise @@ -186,13 +198,11 @@ def __init__( def evaluate_batch(self, params, batch_stats, batch): """Evals the NQM loss.""" logits = self.flax_module.apply( - {'params': params}, batch['inputs'], train=False) + {'params': params}, batch['inputs'], train=False + ) # Trainer eval assumes eval function sums, not averages. loss = logits * batch['inputs'].shape[0] - metrics = { - 'loss': loss, - 'num_examples': batch['inputs'].shape[0] - } + metrics = {'loss': loss, 'num_examples': batch['inputs'].shape[0]} return metrics def training_cost(self, params, batch, batch_stats=None, dropout_rng=None): @@ -213,14 +223,17 @@ def training_cost(self, params, batch, batch_stats=None, dropout_rng=None): """ del dropout_rng average_loss = self.flax_module.apply( - {'params': params}, batch['inputs'], train=True) + {'params': params}, batch['inputs'], train=True + ) return average_loss, batch_stats def build_flax_module(self): - hessian, noise_scaling = _get_nqm_matrices(self.hps.input_shape[0], - self.hps.hessian_decay_power, - self.hps.noise_decay_power, - self.hps.nqm_mode) + hessian, noise_scaling = _get_nqm_matrices( + self.hps.input_shape[0], + self.hps.hessian_decay_power, + self.hps.noise_decay_power, + self.hps.nqm_mode, + ) return NQMLoss(hessian=hessian, noise_scaling=noise_scaling) def get_fake_inputs(self, hps): diff --git a/init2winit/model_lib/partition_tree.py b/init2winit/model_lib/partition_tree.py index 80b05ee0..ef76b5e4 100644 --- a/init2winit/model_lib/partition_tree.py +++ b/init2winit/model_lib/partition_tree.py @@ -42,6 +42,7 @@ def create_partition_flat_params_fn(key_map): partition_flat_params, a function which returns a partitioned param dictionary. """ + def partition_flat_params(flat_params): subparam_groups = {} for tup in flat_params: @@ -51,6 +52,7 @@ def partition_flat_params(flat_params): subparam_groups[mapped_key][tup] = flat_params[tup] return subparam_groups + return partition_flat_params @@ -76,4 +78,3 @@ def get_test_group(params): def get_skip_analysis_fn(name): return skip_analysis_registry[name] - diff --git a/init2winit/model_lib/resnet.py b/init2winit/model_lib/resnet.py index 4198259f..a5d96c92 100644 --- a/init2winit/model_lib/resnet.py +++ b/init2winit/model_lib/resnet.py @@ -14,6 +14,7 @@ # limitations under the License. """Flax implementation of ResNet V1.""" + import functools from typing import Optional, Tuple @@ -25,23 +26,24 @@ import jax.numpy as jnp from ml_collections.config_dict import config_dict - -DEFAULT_HPARAMS = config_dict.ConfigDict(dict( - num_filters=16, - num_layers=18, # Must be one of [18, 34, 50, 101, 152, 200] - batch_norm_momentum=0.9, - batch_norm_epsilon=1e-5, - # Make this a string to avoid having to import jnp into the configs. - model_dtype='float32', - virtual_batch_size=64, - data_format='NHWC', - block_type='post_activation', # either pre_activation or post_activation - bn_relu_conv=True, # only used for block_type='pre_activation' - use_bn=True, - dropout_rate=0.0, - activation_function='relu', - extra_norm_on_residual=False, -)) +DEFAULT_HPARAMS = config_dict.ConfigDict( + dict( + num_filters=16, + num_layers=18, # Must be one of [18, 34, 50, 101, 152, 200] + batch_norm_momentum=0.9, + batch_norm_epsilon=1e-5, + # Make this a string to avoid having to import jnp into the configs. + model_dtype='float32', + virtual_batch_size=64, + data_format='NHWC', + block_type='post_activation', # pre_activation or post_activation + bn_relu_conv=True, # only used for block_type='pre_activation' + use_bn=True, + dropout_rate=0.0, + activation_function='relu', + extra_norm_on_residual=False, + ) +) class PreActResidualBlock(nn.Module): @@ -52,6 +54,7 @@ class PreActResidualBlock(nn.Module): non-linearity after the residual connection, the preactivation block applies it before the residual connection. """ + filters: int strides: Tuple[int] = (1, 1) dtype: model_utils.Dtype = jnp.float32 @@ -69,6 +72,7 @@ class PreActResidualBlock(nn.Module): @nn.compact def __call__(self, x, train): needs_projection = x.shape[-1] != self.filters * 4 or self.strides != (1, 1) + def maybe_normalize(name): if self.use_bn: return normalization.VirtualBatchNorm( @@ -79,7 +83,8 @@ def maybe_normalize(name): virtual_batch_size=self.virtual_batch_size, total_batch_size=self.total_batch_size, data_format=self.data_format, - name=name) + name=name, + ) else: return lambda x, **kwargs: x @@ -87,10 +92,12 @@ def maybe_normalize(name): residual = x if needs_projection: - residual = conv( - self.filters * 4, (1, 1), self.strides, name='proj_conv')(residual) + residual = conv(self.filters * 4, (1, 1), self.strides, name='proj_conv')( + residual + ) residual = maybe_normalize(name='proj_bn')( - residual, use_running_average=not train) + residual, use_running_average=not train + ) def _bn_nonlin(y, name): if self.bn_relu_conv: @@ -110,12 +117,14 @@ def _bn_nonlin(y, name): if self.extra_norm_on_residual: return maybe_normalize(name='extra')( - y + residual, use_running_average=not train) + y + residual, use_running_average=not train + ) return y + residual class ResidualBlock(nn.Module): """Bottleneck ResNet block.""" + filters: int strides: Tuple[int] = (1, 1) dtype: model_utils.Dtype = jnp.float32 @@ -133,6 +142,7 @@ class ResidualBlock(nn.Module): @nn.compact def __call__(self, x, train): needs_projection = x.shape[-1] != self.filters * 4 or self.strides != (1, 1) + def maybe_normalize(name, scale_init=nn.initializers.ones): if self.use_bn: return normalization.VirtualBatchNorm( @@ -144,7 +154,8 @@ def maybe_normalize(name, scale_init=nn.initializers.ones): total_batch_size=self.total_batch_size, data_format=self.data_format, scale_init=scale_init, - name=name) + name=name, + ) else: return lambda x, **kwargs: x @@ -152,10 +163,12 @@ def maybe_normalize(name, scale_init=nn.initializers.ones): residual = x if needs_projection: - residual = conv( - self.filters * 4, (1, 1), self.strides, name='proj_conv')(residual) + residual = conv(self.filters * 4, (1, 1), self.strides, name='proj_conv')( + residual + ) residual = maybe_normalize(name='proj_bn')( - residual, use_running_average=not train) + residual, use_running_average=not train + ) y = conv(self.filters, (1, 1), name='conv1')(x) y = maybe_normalize(name='bn1')(y, use_running_average=not train) @@ -166,7 +179,8 @@ def maybe_normalize(name, scale_init=nn.initializers.ones): y = conv(self.filters * 4, (1, 1), name='conv3')(y) y = maybe_normalize(name='bn3', scale_init=nn.initializers.zeros)( - y, use_running_average=not train) + y, use_running_average=not train + ) y = model_utils.ACTIVATIONS[self.activation_function](residual + y) if self.extra_norm_on_residual: @@ -176,6 +190,7 @@ def maybe_normalize(name, scale_init=nn.initializers.ones): class ResNet(nn.Module): """ResNetV1.""" + num_outputs: int num_filters: int = 64 num_layers: int = 50 @@ -199,10 +214,14 @@ def __call__(self, x, train): raise ValueError('Please provide a valid number of layers') block_sizes = _block_size_options[self.num_layers] - x = nn.Conv(self.num_filters, (7, 7), (2, 2), - use_bias=False, - dtype=self.dtype, - name='init_conv')(x) + x = nn.Conv( + self.num_filters, + (7, 7), + (2, 2), + use_bias=False, + dtype=self.dtype, + name='init_conv', + )(x) if self.use_bn: x = normalization.VirtualBatchNorm( momentum=self.batch_norm_momentum, @@ -212,7 +231,8 @@ def __call__(self, x, train): batch_size=self.batch_size, virtual_batch_size=self.virtual_batch_size, total_batch_size=self.total_batch_size, - data_format=self.data_format)(x, use_running_average=not train) + data_format=self.data_format, + )(x, use_running_average=not train) x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME') if self.block_type == 'post_activation': residual_block = ResidualBlock @@ -226,7 +246,7 @@ def __call__(self, x, train): strides = (2, 2) if i > 0 and j == 0 else (1, 1) index += 1 x = residual_block( - self.num_filters * 2 ** i, + self.num_filters * 2**i, strides=strides, dtype=self.dtype, batch_norm_momentum=self.batch_norm_momentum, @@ -239,7 +259,7 @@ def __call__(self, x, train): use_bn=self.use_bn, activation_function=self.activation_function, extra_norm_on_residual=self.extra_norm_on_residual, - )(x, train=train) + )(x, train=train) x = jnp.mean(x, axis=(1, 2)) if self.dropout_rate > 0.0: x = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(x) @@ -257,7 +277,7 @@ def __call__(self, x, train): 101: [3, 4, 23, 3], 152: [3, 8, 36, 3], 200: [3, 24, 36, 3], - 1000: [3, 24 * 5, 36 * 5, 3] + 1000: [3, 24 * 5, 36 * 5, 3], } @@ -282,7 +302,8 @@ def build_flax_module(self): use_bn=self.hps.use_bn, dropout_rate=self.hps.dropout_rate, activation_function=self.hps.activation_function, - extra_norm_on_residual=self.hps.extra_norm_on_residual) + extra_norm_on_residual=self.hps.extra_norm_on_residual, + ) def get_fake_inputs(self, hps): """Helper method solely for purpose of initializing the model.""" diff --git a/init2winit/model_lib/rope_nanodo.py b/init2winit/model_lib/rope_nanodo.py index 803cfff9..933f8eaa 100644 --- a/init2winit/model_lib/rope_nanodo.py +++ b/init2winit/model_lib/rope_nanodo.py @@ -99,9 +99,10 @@ def __call__(self, x_BxLxD: jax.Array): # parameters instead of 2 * hidden_dim * D parameters. hidden_dim = cfg.F * 2 / 3 # Round up to the nearest multiple of cfg.multiple_of - hidden_dim = int(cfg.multiple_of * ( - (hidden_dim + cfg.multiple_of - 1) // cfg.multiple_of - )) + hidden_dim = int( + cfg.multiple_of + * ((hidden_dim + cfg.multiple_of - 1) // cfg.multiple_of) + ) # Double the hidden dimension for GLU x_BxLxF = linear(2 * hidden_dim)(x_BxLxD) else: diff --git a/init2winit/model_lib/simple_cnn.py b/init2winit/model_lib/simple_cnn.py index e1d35b53..292e3a8c 100644 --- a/init2winit/model_lib/simple_cnn.py +++ b/init2winit/model_lib/simple_cnn.py @@ -14,6 +14,7 @@ # limitations under the License. """Simple convnet classifier.""" + from typing import Sequence from flax import linen as nn @@ -21,17 +22,17 @@ from init2winit.model_lib import model_utils from jax.nn import initializers import jax.numpy as jnp - from ml_collections.config_dict import config_dict - # small hparams used for unit tests -DEFAULT_HPARAMS = config_dict.ConfigDict(dict( - num_filters=[20, 10], - kernel_sizes=[3, 3], - activation_function='relu', - model_dtype='float32', -)) +DEFAULT_HPARAMS = config_dict.ConfigDict( + dict( + num_filters=[20, 10], + kernel_sizes=[3, 3], + activation_function='relu', + model_dtype='float32', + ) +) class SimpleCNN(nn.Module): @@ -39,6 +40,7 @@ class SimpleCNN(nn.Module): The model assumes the input shape is [batch, H, W, C]. """ + num_outputs: int num_filters: Sequence[int] kernel_sizes: Sequence[int] @@ -50,15 +52,17 @@ class SimpleCNN(nn.Module): def __call__(self, x, train): for num_filters, kernel_size in zip(self.num_filters, self.kernel_sizes): x = nn.Conv( - num_filters, (kernel_size, kernel_size), (1, 1), + num_filters, + (kernel_size, kernel_size), + (1, 1), kernel_init=self.kernel_init, - bias_init=self.bias_init)(x) + bias_init=self.bias_init, + )(x) x = model_utils.ACTIVATIONS[self.activation_function](x) x = jnp.reshape(x, (x.shape[0], -1)) x = nn.Dense( - self.num_outputs, - kernel_init=self.kernel_init, - bias_init=self.bias_init)(x) + self.num_outputs, kernel_init=self.kernel_init, bias_init=self.bias_init + )(x) return x @@ -71,7 +75,8 @@ def build_flax_module(self): num_outputs=self.hps['output_shape'][-1], num_filters=self.hps.num_filters, kernel_sizes=self.hps.kernel_sizes, - activation_function=self.hps.activation_function) + activation_function=self.hps.activation_function, + ) def get_fake_inputs(self, hps): """Helper method solely for the purpose of initialzing the model.""" diff --git a/init2winit/model_lib/spectrum_augmenter.py b/init2winit/model_lib/spectrum_augmenter.py index 69f20277..b00ef4da 100644 --- a/init2winit/model_lib/spectrum_augmenter.py +++ b/init2winit/model_lib/spectrum_augmenter.py @@ -31,6 +31,7 @@ class SpecAug(nn.Module): This is an essential component in speech recognition models that helps achieve better word error rates. """ + freq_mask_count: int = 1 freq_mask_max_bins: int = 15 time_mask_count: int = 1 @@ -42,14 +43,16 @@ class SpecAug(nn.Module): def next_prng_key(self, name='dropout'): return self.make_rng(name) - def _get_mask(self, - batch_size, - choose_range, - mask_size, - max_length=None, - masks_per_frame=0.0, - multiplicity=1, - max_ratio=1.0): + def _get_mask( + self, + batch_size, + choose_range, + mask_size, + max_length=None, + masks_per_frame=0.0, + multiplicity=1, + max_ratio=1.0, + ): # Sample lengths for multiple masks. if max_length and max_length > 0: max_length = jnp.tile(max_length, (batch_size,)) @@ -59,9 +62,11 @@ def _get_mask(self, key=self.next_prng_key(), shape=(batch_size, multiplicity), minval=0.0, - maxval=1.0) - masked_frame_size = jnp.einsum('b,bm->bm', max_length, - masked_portion).astype(jnp.int32) + maxval=1.0, + ) + masked_frame_size = jnp.einsum( + 'b,bm->bm', max_length, masked_portion + ).astype(jnp.int32) # Make sure the sampled length was smaller than max_ratio * length_bound. # Note that sampling in this way was biased # (shorter sequence may over-masked.) @@ -71,7 +76,8 @@ def _get_mask(self, # Choose starting point. random_start = jax.random.uniform( - key=self.next_prng_key(), shape=(batch_size, multiplicity), maxval=1.0) + key=self.next_prng_key(), shape=(batch_size, multiplicity), maxval=1.0 + ) start_with_in_valid_range = random_start * (choose_range - length + 1) start = start_with_in_valid_range.astype(jnp.int32) end = start + length - 1 @@ -92,10 +98,12 @@ def _get_mask(self, if masks_per_frame > 0: multiplicity_weights = jnp.tile( jnp.expand_dims(jnp.arange(multiplicity, dtype=jnp.int32), 0), - [batch_size, 1]) + [batch_size, 1], + ) multiplicity_tensor = masks_per_frame * choose_range - multiplicity_weights = (multiplicity_weights < - multiplicity_tensor).astype(jnp.int32) + multiplicity_weights = ( + multiplicity_weights < multiplicity_tensor + ).astype(jnp.int32) pre_mask = jnp.einsum('bmt,bm->bt', pre_mask, multiplicity_weights) else: pre_mask = jnp.einsum('bmt->bt', pre_mask) @@ -111,8 +119,9 @@ def _time_mask(self, inputs, length): max_ratio = self.time_mask_max_ratio # If maximum mask length is zero, do nothing. - if ((time_mask_max_frames == 0 and not use_dynamic_time_mask_max_frames) or - max_ratio <= 0.0): + if ( + time_mask_max_frames == 0 and not use_dynamic_time_mask_max_frames + ) or max_ratio <= 0.0: return inputs if multiplicity == 0: return inputs @@ -130,7 +139,8 @@ def _time_mask(self, inputs, length): max_length=time_mask_max_frames, masks_per_frame=self.time_masks_per_frame, multiplicity=multiplicity, - max_ratio=max_ratio) + max_ratio=max_ratio, + ) outputs = jnp.einsum('bxy,bx->bxy', inputs, block_arrays) return outputs @@ -155,7 +165,8 @@ def _frequency_mask(self, inputs): max_length=freq_mask_max_bins, masks_per_frame=0.0, multiplicity=multiplicity, - max_ratio=1.0) + max_ratio=1.0, + ) outputs = jnp.einsum('bxy,by->bxy', inputs, block_arrays) return outputs diff --git a/init2winit/model_lib/test_local_attention_transformer.py b/init2winit/model_lib/test_local_attention_transformer.py index aa8c2949..2b3675ca 100644 --- a/init2winit/model_lib/test_local_attention_transformer.py +++ b/init2winit/model_lib/test_local_attention_transformer.py @@ -13,8 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Tests for local atention transformer. -""" +"""Tests for local atention transformer.""" from absl.testing import absltest from absl.testing import parameterized @@ -29,7 +28,6 @@ class LocalAttentionTransformerTests(parameterized.TestCase): Reference shapes and data types come from Colab notebooks, where TF and JAX/FLAX support functions where compared. - """ def test_decode_step_to_index(self): @@ -37,7 +35,8 @@ def test_decode_step_to_index(self): decode_step = 10 array_shape = (8192,) output = local_attention_transformer.decode_step_to_index( - decode_step, array_shape) + decode_step, array_shape + ) self.assertEqual(output, (10,)) @@ -66,7 +65,8 @@ def test_ones_matrix_band_part(self): max_forward = 10 output_shape = [1, 1, 2, 2] output = local_attention_transformer.ones_matrix_band_part( - num_rows, num_cols, max_backward, max_forward, output_shape) + num_rows, num_cols, max_backward, max_forward, output_shape + ) self.assertEqual(output.shape, (1, 1, 2, 2)) self.assertEqual(output.dtype, np.float32) @@ -77,7 +77,8 @@ def test_attention_bias_local(self): max_backward = 10 max_forward = 10 output = local_attention_transformer.attention_bias_local( - length, max_backward, max_forward) + length, max_backward, max_forward + ) self.assertEqual(output.shape, (1, 1, 2, 2)) self.assertEqual(output.dtype, np.float32) @@ -130,7 +131,8 @@ def test_select_block_for_decode_step(self): x = np.array(np.random.rand(1, 10, 256, 16), dtype=np.float32) decode_step = 10 output = local_attention_transformer.select_block_for_decode_step( - x, decode_step) + x, decode_step + ) self.assertEqual(output.shape, (1, 1, 256, 16)) self.assertEqual(output.dtype, np.float32) @@ -165,7 +167,8 @@ def test_unflatten_blocks_nd(self): x = np.array(np.random.rand(1, 2, 3, 4), dtype=np.float32) blocks_per_dimension = [2] output = local_attention_transformer.unflatten_blocks_nd( - x, blocks_per_dimension) + x, blocks_per_dimension + ) self.assertEqual(output.shape, (1, 2, 3, 4)) self.assertEqual(output.dtype, np.float32) @@ -192,7 +195,8 @@ def test_generate_relative_positions_matrix(self): length_q = 10 length_k = 10 output = local_attention_transformer.generate_relative_positions_matrix( - length_q, length_k) + length_q, length_k + ) self.assertEqual(output.shape, (10, 10)) self.assertEqual(output.dtype, np.int32) @@ -203,7 +207,8 @@ def test_generate_relative_positions_embeddings(self): length_k = 10 rng_key = jax.random.PRNGKey(0) model = local_attention_transformer.RelativePositionEmbeddings( - embed_layer_name='unit_test') + embed_layer_name='unit_test' + ) params = model.init(rng_key, length_q, length_k) output = model.apply(params, length_q, length_k) @@ -217,7 +222,8 @@ def test_relative_attention_inner(self): z = np.array(np.random.rand(1, 4, 3), dtype=np.float32) transpose = False output = local_attention_transformer.relative_attention_inner( - x, y, z, transpose) + x, y, z, transpose + ) self.assertEqual(output.shape, (1, 1, 1, 3)) self.assertEqual(output.dtype, np.float32) @@ -297,7 +303,8 @@ def test_put_item_in_decode_step(self): """Tests support function put_item_in_decode_step.""" x = jnp.array(np.random.rand(3, 8, 1, 4), dtype=np.float32) output = local_attention_transformer.put_item_in_decode_step( - input_x=x, decode_step=1) + input_x=x, decode_step=1 + ) self.assertEqual(output.shape, (3, 8, 1, 4)) self.assertEqual(output.dtype, np.float32) @@ -337,7 +344,8 @@ def test_process_partial_targets_decoding(self): """Tests support function process_partial_targets_decoding.""" x = np.array(np.random.rand(1, 256), dtype=np.float32) output = local_attention_transformer.process_partial_targets_decoding( - targets=x) + targets=x + ) self.assertEqual(output.shape, (1, 256)) self.assertEqual(output.dtype, np.float32) @@ -349,7 +357,8 @@ def test_feedforward(self): key1 = jax.random.PRNGKey(0) model = local_attention_transformer.FeedForward( - feedforward_depths=feedforward_depths) + feedforward_depths=feedforward_depths + ) params = model.init(key1, x) output = model.apply(params, x) @@ -374,7 +383,8 @@ def test_decoder_block(self): x = jnp.array(np.random.rand(1, 2, 1032), dtype=jnp.float32) rng_key = jax.random.PRNGKey(0) model = local_attention_transformer.DecoderBlock( - feedforward_depths=feedforward_depths) + feedforward_depths=feedforward_depths + ) params = model.init(rng_key, x) imp_output = model.apply(params, x) @@ -384,4 +394,3 @@ def test_decoder_block(self): if __name__ == '__main__': absltest.main() - diff --git a/init2winit/model_lib/test_losses.py b/init2winit/model_lib/test_losses.py index e32573d5..7fe6eec7 100644 --- a/init2winit/model_lib/test_losses.py +++ b/init2winit/model_lib/test_losses.py @@ -16,6 +16,7 @@ """Tests for losses.py. """ + import functools import types @@ -51,107 +52,139 @@ 'rescaled_loss_m': 10.0, }) -CLASSIFICATION_TEST_DATA = [{ - 'logits': - np.array([[5, 3, 4, -3, 7], [2, 5, -5, 5, 6], [-6, -5, 8, -6, 4], - [15, 8, -6, 4, 2], [-7, 5, -6, 9, 0]]), - 'one_hot_targets': - np.array([[1, 0, 0, 0, 0], [0, 0, 1, 0, 0], [0, 0, 0, 1, 0], - [0, 0, 0, 0, 1], [0, 1, 0, 0, 0]]), - 'weights': - None, - 'hps': HPS_1, - 'cross_entropy': - 8.956906, - 'bi_tempered_cross_entropy': - 1.9120569, - 'rescaled_mean_squared_error': - 37.56, -}, { - 'logits': - np.array([[4, 2, 0, -4, 5], [14, 2, -5, 10, 12], [20, -3, 7, -9, 6], - [5, 7, -1, 2, -8], [4, -7, 9, 0, 2]]), - 'one_hot_targets': - np.array([[0, 1, 0, 0, 0], [0, 0, 0, 1, 0], [1, 0, 0, 0, 0], - [0, 0, 0, 0, 1], [0, 0, 1, 0, 0]]), - 'weights': - np.array([2, 7, 0, 3, 0]), - 'hps': HPS_2, - 'cross_entropy': - 6.7589717, - 'bi_tempered_cross_entropy': - 1.7393580, - 'rescaled_mean_squared_error': - 140.56666, -}] +CLASSIFICATION_TEST_DATA = [ + { + 'logits': np.array([ + [5, 3, 4, -3, 7], + [2, 5, -5, 5, 6], + [-6, -5, 8, -6, 4], + [15, 8, -6, 4, 2], + [-7, 5, -6, 9, 0], + ]), + 'one_hot_targets': np.array([ + [1, 0, 0, 0, 0], + [0, 0, 1, 0, 0], + [0, 0, 0, 1, 0], + [0, 0, 0, 0, 1], + [0, 1, 0, 0, 0], + ]), + 'weights': None, + 'hps': HPS_1, + 'cross_entropy': 8.956906, + 'bi_tempered_cross_entropy': 1.9120569, + 'rescaled_mean_squared_error': 37.56, + }, + { + 'logits': np.array([ + [4, 2, 0, -4, 5], + [14, 2, -5, 10, 12], + [20, -3, 7, -9, 6], + [5, 7, -1, 2, -8], + [4, -7, 9, 0, 2], + ]), + 'one_hot_targets': np.array([ + [0, 1, 0, 0, 0], + [0, 0, 0, 1, 0], + [1, 0, 0, 0, 0], + [0, 0, 0, 0, 1], + [0, 0, 1, 0, 0], + ]), + 'weights': np.array([2, 7, 0, 3, 0]), + 'hps': HPS_2, + 'cross_entropy': 6.7589717, + 'bi_tempered_cross_entropy': 1.7393580, + 'rescaled_mean_squared_error': 140.56666, + }, +] -RECONSTRUCTION_TEST_DATA = [{ - 'logits': - np.array([[4, -5, 8, -10], [-5, 7, 4, 11], [12, 5, 5, -9], - [7, -11, -4, 8]]).astype(float), - 'targets': - np.array([[0.05, 0.02, 0.96, 0.02], [0.05, 0.001, 0.5, 0.4], - [0.68, 0.92, 0.12, 0.22], [0.34, 0.44, 0.29, 0.2]]), - 'weights': - None, - 'hps': HPS_1, - 'sigmoid_binary_cross_entropy': - 11.996754, - 'bi_tempered_sigmoid_binary_cross_entropy': - 1.9910424, - 'sigmoid_mean_squared_error': - 1.180348, -}, { - 'logits': - np.array([[[4, -5], [8, -10]], [[-5, 7], [4, 11]], [[12, 5], [5, -9]], - [[7, -11], [-4, 8]]]).astype(float), - 'targets': - np.array([[[0.05, 0.02], [0.96, 0.02]], [[0.05, 0.001], [0.5, 0.4]], - [[0.68, 0.92], [0.12, 0.22]], [[0.34, 0.44], [0.29, 0.2]]]), - 'weights': - None, - 'hps': HPS_1, - 'sigmoid_binary_cross_entropy': - 11.996754, - 'bi_tempered_sigmoid_binary_cross_entropy': - 1.9910425, - 'sigmoid_mean_squared_error': - 1.180348, -}, { - 'logits': - np.array([[4, -5, 8, -10], [-5, 7, 4, 11], [12, 5, 5, -9], - [7, -11, -4, 8]]).astype(float), - 'targets': - np.array([[0.05, 0.02, 0.96, 0.02], [0.05, 0.001, 0.5, 0.4], - [0.68, 0.92, 0.12, 0.22], [0.34, 0.44, 0.29, 0.2]]), - 'weights': - np.array([0, 4, 0, 2]), - 'hps': HPS_1, - 'sigmoid_binary_cross_entropy': - 16.259, - 'bi_tempered_sigmoid_binary_cross_entropy': - 2.6654808, - 'sigmoid_mean_squared_error': - 1.5073959, -}] +RECONSTRUCTION_TEST_DATA = [ + { + 'logits': ( + np.array([ + [4, -5, 8, -10], + [-5, 7, 4, 11], + [12, 5, 5, -9], + [7, -11, -4, 8], + ]).astype(float) + ), + 'targets': np.array([ + [0.05, 0.02, 0.96, 0.02], + [0.05, 0.001, 0.5, 0.4], + [0.68, 0.92, 0.12, 0.22], + [0.34, 0.44, 0.29, 0.2], + ]), + 'weights': None, + 'hps': HPS_1, + 'sigmoid_binary_cross_entropy': 11.996754, + 'bi_tempered_sigmoid_binary_cross_entropy': 1.9910424, + 'sigmoid_mean_squared_error': 1.180348, + }, + { + 'logits': ( + np.array([ + [[4, -5], [8, -10]], + [[-5, 7], [4, 11]], + [[12, 5], [5, -9]], + [[7, -11], [-4, 8]], + ]).astype(float) + ), + 'targets': np.array([ + [[0.05, 0.02], [0.96, 0.02]], + [[0.05, 0.001], [0.5, 0.4]], + [[0.68, 0.92], [0.12, 0.22]], + [[0.34, 0.44], [0.29, 0.2]], + ]), + 'weights': None, + 'hps': HPS_1, + 'sigmoid_binary_cross_entropy': 11.996754, + 'bi_tempered_sigmoid_binary_cross_entropy': 1.9910425, + 'sigmoid_mean_squared_error': 1.180348, + }, + { + 'logits': ( + np.array([ + [4, -5, 8, -10], + [-5, 7, 4, 11], + [12, 5, 5, -9], + [7, -11, -4, 8], + ]).astype(float) + ), + 'targets': np.array([ + [0.05, 0.02, 0.96, 0.02], + [0.05, 0.001, 0.5, 0.4], + [0.68, 0.92, 0.12, 0.22], + [0.34, 0.44, 0.29, 0.2], + ]), + 'weights': np.array([0, 4, 0, 2]), + 'hps': HPS_1, + 'sigmoid_binary_cross_entropy': 16.259, + 'bi_tempered_sigmoid_binary_cross_entropy': 2.6654808, + 'sigmoid_mean_squared_error': 1.5073959, + }, +] -CROSS_ENTROPY_TEST_DATA = [{ - 'logits': - np.array([[4, 7], [-2, 5], [8, 6], [-10, -4], [3, -5]]).astype(float), - 'targets': - np.array([[1, 0], [0, 1], [1, 0], [1, 0], [0, 1]]), - 'weights': - None, - 'hps': HPS_1, -}, { - 'logits': - np.array([[4, 7], [-2, 5], [8, 6], [-10, -4], [3, -5]]).astype(float), - 'targets': - np.array([[1, 0], [0, 1], [1, 0], [1, 0], [0, 1]]), - 'weights': - np.array([2, 0, 0, 6, 1]), - 'hps': HPS_1, -}] +CROSS_ENTROPY_TEST_DATA = [ + { + 'logits': ( + np.array([[4, 7], [-2, 5], [8, 6], [-10, -4], [3, -5]]).astype( + float + ) + ), + 'targets': np.array([[1, 0], [0, 1], [1, 0], [1, 0], [0, 1]]), + 'weights': None, + 'hps': HPS_1, + }, + { + 'logits': ( + np.array([[4, 7], [-2, 5], [8, 6], [-10, -4], [3, -5]]).astype( + float + ) + ), + 'targets': np.array([[1, 0], [0, 1], [1, 0], [1, 0], [0, 1]]), + 'weights': np.array([2, 0, 0, 6, 1]), + 'hps': HPS_1, + }, +] CLASSIFICATION_KEYS = [ (loss_name, loss_name) for loss_name in CLASSIFICATION_LOSSES @@ -205,7 +238,8 @@ def test_classification_losses(self, loss_name): self.assertAlmostEqual( loss_fn(data['logits'], data['one_hot_targets'], data['weights']), data[loss_name], - places=5) + places=5, + ) @parameterized.named_parameters(*RECONSTRUCTION_KEYS) def test_regression_losses(self, loss_name): @@ -215,14 +249,18 @@ def test_regression_losses(self, loss_name): self.assertAlmostEqual( loss_fn(data['logits'], data['targets'], data['weights']), data[loss_name], - places=6) + places=6, + ) def test_cross_entropy_loss_fn(self): + """Tests equivalence of binary and multi-class cross entropy.""" for data in CROSS_ENTROPY_TEST_DATA: for binary_loss_name, loss_name in [ ('sigmoid_binary_cross_entropy', 'cross_entropy'), - ('bi_tempered_sigmoid_binary_cross_entropy', - 'bi_tempered_cross_entropy') + ( + 'bi_tempered_sigmoid_binary_cross_entropy', + 'bi_tempered_cross_entropy', + ), ]: sigmoid_binary_ce_fn = losses.get_loss_fn(binary_loss_name, data['hps']) sigmoid_binary_ce_fn = wrap_loss(self, sigmoid_binary_ce_fn) @@ -230,20 +268,23 @@ def test_cross_entropy_loss_fn(self): ce_fn = wrap_loss(self, ce_fn) self.assertAlmostEqual( sigmoid_binary_ce_fn( - np.array([[logits[0] - logits[1]] for logits in data['logits'] - ]), + np.array( + [[logits[0] - logits[1]] for logits in data['logits']] + ), np.array([[targets[0]] for targets in data['targets']]), - data['weights']), + data['weights'], + ), ce_fn(data['logits'], data['targets'], data['weights']), - places=5) + places=5, + ) def test_sigmoid_cross_entropy_per_label_weights(self): """Tests whether per label weights mask the correct entries.""" for binary_loss_name in [ 'sigmoid_binary_cross_entropy', - 'bi_tempered_sigmoid_binary_cross_entropy']: - sigmoid_binary_ce_fn = losses.get_loss_fn( - binary_loss_name, HPS_1) + 'bi_tempered_sigmoid_binary_cross_entropy', + ]: + sigmoid_binary_ce_fn = losses.get_loss_fn(binary_loss_name, HPS_1) sigmoid_binary_ce_fn = wrap_loss(self, sigmoid_binary_ce_fn) logits = np.arange(15).reshape(3, 5) targets = np.arange(15, 30).reshape(3, 5) @@ -260,8 +301,11 @@ def test_sigmoid_cross_entropy_per_label_weights(self): # per-label case. self.assertAlmostEqual( sigmoid_binary_ce_fn(logits, targets, per_label_weights), - sigmoid_binary_ce_fn(logits[:, :4], targets[:, :4], - per_example_weights) / 4) + sigmoid_binary_ce_fn( + logits[:, :4], targets[:, :4], per_example_weights + ) + / 4, + ) # optax ctc loss blank token has id = 0 by default @parameterized.named_parameters( @@ -302,5 +346,6 @@ def test_weighted_mean_absolute_error(self, logits, targets, result): self.assertAlmostEqual(loss_value, jax.numpy.array([result])) + if __name__ == '__main__': absltest.main() diff --git a/init2winit/model_lib/test_metrics.py b/init2winit/model_lib/test_metrics.py index 009d0381..b42c790a 100644 --- a/init2winit/model_lib/test_metrics.py +++ b/init2winit/model_lib/test_metrics.py @@ -34,21 +34,32 @@ class MetricsTest(parameterized.TestCase): @parameterized.named_parameters( dict( testcase_name='basic', - targets=np.array([[1., 0.], [0., 1.]]), + targets=np.array([[1.0, 0.0], [0.0, 1.0]]), logits=np.array([[0.5, 0.5], [0.5, 0.5]]), - weights=np.array([[1., 1.], [1., 1.]]), - result=0.5), + weights=np.array([[1.0, 1.0], [1.0, 1.0]]), + result=0.5, + ), dict( testcase_name='weights', - targets=np.array([[1., 0.,], [0., 1.], [0., 1.]]), + targets=np.array([ + [ + 1.0, + 0.0, + ], + [0.0, 1.0], + [0.0, 1.0], + ]), logits=np.array([[0.5, 0.5], [0.5, 0.5], [0.5, 0.7]]), - weights=np.array([[1., 1.], [0., 1.], [1., 0.]]), - result=0.5)) + weights=np.array([[1.0, 1.0], [0.0, 1.0], [1.0, 0.0]]), + result=0.5, + ), + ) def test_map(self, logits, targets, weights, result): """Tests the mean average precision computation.""" average_precision = metrics.OGBGMeanAveragePrecision.from_model_output( - logits=logits, targets=targets, weights=weights).compute() + logits=logits, targets=targets, weights=weights + ).compute() self.assertAlmostEqual(average_precision, result) def test_structural_similarity(self): @@ -56,89 +67,251 @@ def test_structural_similarity(self): NOTE(dsuo): we test with the defaults used in the FastMRI workload. """ - im1 = np.array( - [[ - 0.73833251, 0.89810601, 0.59434839, 0.112503, 0.40403852, - 0.05790091, 0.81124133, 0.68376994, 0.58584383, 0.0930026 + im1 = np.array([ + [ + 0.73833251, + 0.89810601, + 0.59434839, + 0.112503, + 0.40403852, + 0.05790091, + 0.81124133, + 0.68376994, + 0.58584383, + 0.0930026, + ], + [ + 0.42542345, + 0.56915387, + 0.17478424, + 0.1589461, + 0.74287562, + 0.47219216, + 0.52649117, + 0.50070807, + 0.67500359, + 0.37819205, + ], + [ + 0.68373459, + 0.64230722, + 0.97335725, + 0.24565012, + 0.48942928, + 0.50963254, + 0.37571989, + 0.38366919, + 0.36945232, + 0.09163938, + ], + [ + 0.15319016, + 0.36161473, + 0.61484123, + 0.17523618, + 0.73859486, + 0.16115077, + 0.01884255, + 0.98497526, + 0.02232614, + 0.71922009, + ], + [ + 0.5019574, + 0.80491521, + 0.65586547, + 0.03463707, + 0.47130842, + 0.63220364, + 0.76905247, + 0.50815002, + 0.92499088, + 0.20647629, + ], + [ + 0.77917087, + 0.07486334, + 0.36286554, + 0.27815142, + 0.706411, + 0.41677936, + 0.23959606, + 0.51440788, + 0.75697984, + 0.80130235, + ], + [ + 0.55735541, + 0.62024555, + 0.87929081, + 0.8054033, + 0.12014468, + 0.22865017, + 0.23542662, + 0.71521724, + 0.40843243, + 0.25604842, + ], + [ + 0.26829756, + 0.19476007, + 0.38425566, + 0.38231672, + 0.12902957, + 0.60572083, + 0.65571312, + 0.6134444, + 0.13835472, + 0.06307113, + ], + [ + 0.75307163, + 0.44877311, + 0.31321154, + 0.03057511, + 0.86964725, + 0.89864268, + 0.53593918, + 0.87913734, + 0.72179573, + 0.03976076, + ], + [ + 0.55826683, + 0.18015783, + 0.90069321, + 0.55805617, + 0.71336459, + 0.67812158, + 0.27842012, + 0.85890605, + 0.22515742, + 0.60356573, + ], + ]) + + im2 = np.array([ + [ + 0.38785895, + 0.69382307, + 0.24389369, + 0.89767306, + 0.42301789, + 0.16313277, + 0.80090617, + 0.16567136, + 0.71543147, + 0.30399568, + ], + [ + 0.35219695, + 0.57636231, + 0.5339162, + 0.51421423, + 0.55444482, + 0.3299572, + 0.08871051, + 0.90975499, + 0.26302511, + 0.08448494, + ], + [ + 0.88601557, + 0.41470639, + 0.68370194, + 0.64813528, + 0.86429226, + 0.69276718, + 0.66361842, + 0.4851298, + 0.74617258, + 0.28851107, + ], + [ + 0.84669316, + 0.99759206, + 0.79429959, + 0.19977481, + 0.4833177, + 0.57696104, + 0.13978823, + 0.63513837, + 0.73423608, + 0.83064902, + ], + [ + 0.45382891, + 0.41018542, + 0.86997271, + 0.39990761, + 0.32097822, + 0.52282046, + 0.05960004, + 0.95429451, + 0.03181412, + 0.80956527, + ], + [ + 0.90649511, + 0.61557879, + 0.59897015, + 0.94188484, + 0.90297625, + 0.76986281, + 0.2392755, + 0.33402192, + 0.36923513, + 0.54177217, + ], + [ + 0.44340055, + 0.58440755, + 0.45363187, + 0.74527457, + 0.23761691, + 0.74693863, + 0.11449182, + 0.48795747, + 0.94897711, + 0.01631275, + ], + [ + 0.30310764, + 0.07203944, + 0.11931363, + 0.48873794, + 0.18900569, + 0.00643777, + 0.30659393, + 0.41300417, + 0.69529398, + 0.24826242, + ], + [ + 0.98076131, + 0.6875125, + 0.54545994, + 0.16997529, + 0.0698003, + 0.59835326, + 0.58198102, + 0.8785474, + 0.69644425, + 0.73404286, ], - [ - 0.42542345, 0.56915387, 0.17478424, 0.1589461, 0.74287562, - 0.47219216, 0.52649117, 0.50070807, 0.67500359, 0.37819205 - ], - [ - 0.68373459, 0.64230722, 0.97335725, 0.24565012, 0.48942928, - 0.50963254, 0.37571989, 0.38366919, 0.36945232, 0.09163938 - ], - [ - 0.15319016, 0.36161473, 0.61484123, 0.17523618, 0.73859486, - 0.16115077, 0.01884255, 0.98497526, 0.02232614, 0.71922009 - ], - [ - 0.5019574, 0.80491521, 0.65586547, 0.03463707, 0.47130842, - 0.63220364, 0.76905247, 0.50815002, 0.92499088, 0.20647629 - ], - [ - 0.77917087, 0.07486334, 0.36286554, 0.27815142, 0.706411, - 0.41677936, 0.23959606, 0.51440788, 0.75697984, 0.80130235 - ], - [ - 0.55735541, 0.62024555, 0.87929081, 0.8054033, 0.12014468, - 0.22865017, 0.23542662, 0.71521724, 0.40843243, 0.25604842 - ], - [ - 0.26829756, 0.19476007, 0.38425566, 0.38231672, 0.12902957, - 0.60572083, 0.65571312, 0.6134444, 0.13835472, 0.06307113 - ], - [ - 0.75307163, 0.44877311, 0.31321154, 0.03057511, 0.86964725, - 0.89864268, 0.53593918, 0.87913734, 0.72179573, 0.03976076 - ], - [ - 0.55826683, 0.18015783, 0.90069321, 0.55805617, 0.71336459, - 0.67812158, 0.27842012, 0.85890605, 0.22515742, 0.60356573 - ]]) - - im2 = np.array( - [[ - 0.38785895, 0.69382307, 0.24389369, 0.89767306, 0.42301789, - 0.16313277, 0.80090617, 0.16567136, 0.71543147, 0.30399568 + [ + 0.1310947, + 0.91694649, + 0.32005394, + 0.98112882, + 0.4818337, + 0.26479291, + 0.97803938, + 0.03502056, + 0.72615619, + 0.72047081, ], - [ - 0.35219695, 0.57636231, 0.5339162, 0.51421423, 0.55444482, - 0.3299572, 0.08871051, 0.90975499, 0.26302511, 0.08448494 - ], - [ - 0.88601557, 0.41470639, 0.68370194, 0.64813528, 0.86429226, - 0.69276718, 0.66361842, 0.4851298, 0.74617258, 0.28851107 - ], - [ - 0.84669316, 0.99759206, 0.79429959, 0.19977481, 0.4833177, - 0.57696104, 0.13978823, 0.63513837, 0.73423608, 0.83064902 - ], - [ - 0.45382891, 0.41018542, 0.86997271, 0.39990761, 0.32097822, - 0.52282046, 0.05960004, 0.95429451, 0.03181412, 0.80956527 - ], - [ - 0.90649511, 0.61557879, 0.59897015, 0.94188484, 0.90297625, - 0.76986281, 0.2392755, 0.33402192, 0.36923513, 0.54177217 - ], - [ - 0.44340055, 0.58440755, 0.45363187, 0.74527457, 0.23761691, - 0.74693863, 0.11449182, 0.48795747, 0.94897711, 0.01631275 - ], - [ - 0.30310764, 0.07203944, 0.11931363, 0.48873794, 0.18900569, - 0.00643777, 0.30659393, 0.41300417, 0.69529398, 0.24826242 - ], - [ - 0.98076131, 0.6875125, 0.54545994, 0.16997529, 0.0698003, - 0.59835326, 0.58198102, 0.8785474, 0.69644425, 0.73404286 - ], - [ - 0.1310947, 0.91694649, 0.32005394, 0.98112882, 0.4818337, - 0.26479291, 0.97803938, 0.03502056, 0.72615619, 0.72047081 - ]]) + ]) jim1 = jnp.array(im1) jim2 = jnp.array(im2) @@ -154,7 +327,7 @@ def test_wer_spm(self): source_sentence = "Let's start praying this test passes!" decoded_sentence = source_sentence - class MockSPMTokenizer(): + class MockSPMTokenizer: def tokenize(self, s): return s @@ -166,25 +339,29 @@ def detokenize(self, s): tokenizer_type = 'SPM' with mock.patch.object( - tokenizer, 'tokenize', return_value=[1, 2, 3], autospec=True): + tokenizer, 'tokenize', return_value=[1, 2, 3], autospec=True + ): with mock.patch.object( - tokenizer, - 'detokenize', - return_value=[source_sentence], - autospec=True): - source_tokens = jnp.array([tokenizer.tokenize(source_sentence)], - dtype=jnp.int32) + tokenizer, 'detokenize', return_value=[source_sentence], autospec=True + ): + source_tokens = jnp.array( + [tokenizer.tokenize(source_sentence)], dtype=jnp.int32 + ) source_paddings = jnp.array([[0.0, 0.0, 0.0]], dtype=jnp.float32) - decoded_tokens = jnp.array([tokenizer.tokenize(decoded_sentence)], - dtype=jnp.int32) + decoded_tokens = jnp.array( + [tokenizer.tokenize(decoded_sentence)], dtype=jnp.int32 + ) decoded_paddings = jnp.array([[0.0, 0.0, 0.0]], dtype=jnp.float32) - word_errors, num_words = metrics.compute_wer(decoded_tokens, - decoded_paddings, - source_tokens, - source_paddings, tokenizer, - tokenizer_type) + word_errors, num_words = metrics.compute_wer( + decoded_tokens, + decoded_paddings, + source_tokens, + source_paddings, + tokenizer, + tokenizer_type, + ) self.assertEqual(word_errors, 0) self.assertEqual(num_words, 6.0) diff --git a/init2winit/model_lib/test_models.py b/init2winit/model_lib/test_models.py index 0a19e153..f84d6f99 100644 --- a/init2winit/model_lib/test_models.py +++ b/init2winit/model_lib/test_models.py @@ -38,7 +38,6 @@ from ml_collections.config_dict import config_dict import numpy as np - HIDDEN_SIZES = (50, 50) INPUT_SHAPE = { @@ -348,8 +347,14 @@ autoencoder_models = ['autoencoder', 'convolutional_autoencoder'] text_models = ['transformer', 'performer', 'lstm'] classification_models = [ - 'fully_connected', 'simple_cnn', 'max_pooling_cnn', 'wide_resnet', 'resnet', - 'adabelief_densenet', 'adabelief_vgg', 'fake_resnet' + 'fully_connected', + 'simple_cnn', + 'max_pooling_cnn', + 'wide_resnet', + 'resnet', + 'adabelief_densenet', + 'adabelief_vgg', + 'fake_resnet', ] binary_classification_models = ['dlrm', 'dlrm_resnet'] # TODO(kasimbeg) generative_models = ['unet'] # TODO(kasimbeg) @@ -370,10 +375,14 @@ ] # pylint: disable=g-complex-comprehension model_init_keys = [ - ('test_model_{}_dtype_{}'.format( + ( + 'test_model_{}_dtype_{}'.format( + model_str, + dtype, + ), model_str, dtype, - ), model_str, dtype) + ) for model_str, dtype in itertools.product(all_models, dtypes) if model_str not in skipped_models ] @@ -384,9 +393,11 @@ binary_classification_keys = [ ('test_{}'.format(m), m) for m in binary_classification_models ] -text_keys = [('test_{}_{}'.format(m, d), m, d) - for m, d in itertools.product(text_models, dtypes) - if d != 'bfloat16' or m == 'transformer'] +text_keys = [ + ('test_{}_{}'.format(m, d), m, d) + for m, d in itertools.product(text_models, dtypes) + if d != 'bfloat16' or m == 'transformer' +] dtype_keys = [('test_{}'.format(t), t) for t in dtypes] remat_scan_keys = [('test_no_remat_scan', None), ('test_remat_scan', (2, 2))] dtype_and_remat_scan_keys = [ @@ -418,7 +429,8 @@ def _get_fake_inputs_for_initialization(model, hps): return fake_inputs else: raise NotImplementedError( - 'Method get_fake_inputs not implemented for model.') + 'Method get_fake_inputs not implemented for model.' + ) return fake_inputs @@ -437,12 +449,10 @@ def _initialize_model(model_str, model_dtype): model = model_cls( hps, - dataset_meta_data={ - 'shift_inputs': True, - 'causal': True - }, + dataset_meta_data={'shift_inputs': True, 'causal': True}, loss_name=LOSS_NAME[model_str], - metrics_name=METRICS_NAME[model_str]) + metrics_name=METRICS_NAME[model_str], + ) rng = jax.random.PRNGKey(0) initializer = initializers.get_initializer('noop') init_dict = model.initialize(initializer, hps, rng, metrics_logger=None) @@ -476,7 +486,8 @@ def test_classification_models(self, model_str): model = model_cls(hps, {}, loss, metrics) xs = jnp.array(np.random.normal(size=INPUT_SHAPE['classification'])) model_init_fn = jax.jit( - functools.partial(model.flax_module.init, train=False)) + functools.partial(model.flax_module.init, train=False) + ) init_dict = model_init_fn({'params': params_rng}, xs) params = init_dict['params'] batch_stats = init_dict.get('batch_stats', {}) @@ -487,9 +498,12 @@ def test_classification_models(self, model_str): xs, mutable=['batch_stats'], rngs={'dropout': dropout_rng}, - train=True) - self.assertEqual(outputs.shape, (INPUT_SHAPE['classification'][0], - OUTPUT_SHAPE['classification'][-1])) + train=True, + ) + self.assertEqual( + outputs.shape, + (INPUT_SHAPE['classification'][0], OUTPUT_SHAPE['classification'][-1]), + ) # If it's a batch norm model check the batch stats changed. if batch_stats: @@ -499,10 +513,12 @@ def test_classification_models(self, model_str): # Test batch_norm in inference mode. outputs = model.flax_module.apply( - {'params': params, 'batch_stats': batch_stats}, xs, train=False) + {'params': params, 'batch_stats': batch_stats}, xs, train=False + ) self.assertEqual( outputs.shape, - (INPUT_SHAPE['classification'][0], OUTPUT_SHAPE['classification'][-1])) + (INPUT_SHAPE['classification'][0], OUTPUT_SHAPE['classification'][-1]), + ) @parameterized.named_parameters(*text_keys) def test_text_models(self, model_str, dtype_str): @@ -548,41 +564,45 @@ def test_text_models(self, model_str, dtype_str): rng = jax.random.PRNGKey(0) loss = 'cross_entropy' metrics = 'classification_metrics' - model = model_cls(small_hps, { - 'max_len': 64, - 'shift_inputs': True, - 'causal': True - }, loss, metrics) + model = model_cls( + small_hps, + {'max_len': 64, 'shift_inputs': True, 'causal': True}, + loss, + metrics, + ) xs = jnp.array( - np.random.randint(size=text_input_shape, low=1, high=vocab_size)) + np.random.randint(size=text_input_shape, low=1, high=vocab_size) + ) dropout_rng, params_rng = jax.random.split(rng) model_init_fn = jax.jit( - functools.partial(model.flax_module.init, train=False)) + functools.partial(model.flax_module.init, train=False) + ) init_dict = model_init_fn({'params': params_rng}, xs) params = init_dict['params'] batch_stats = init_dict.get('batch_stats', {}) param_type_matches_model_type = jax.tree_util.tree_map( - lambda x: x.dtype == small_hps.model_dtype, params) + lambda x: x.dtype == small_hps.model_dtype, params + ) self.assertTrue( - jax.tree_util.tree_reduce(lambda x, y: x and y, - param_type_matches_model_type)) + jax.tree_util.tree_reduce( + lambda x, y: x and y, param_type_matches_model_type + ) + ) # Check that the forward pass works with mutated batch_stats. # Due to a bug in flax, this jit is required, otherwise the model errors. @jax.jit def forward_pass(params, xs, dropout_rng): outputs, new_batch_stats = model.flax_module.apply( - { - 'params': params, - 'batch_stats': batch_stats - }, + {'params': params, 'batch_stats': batch_stats}, xs, mutable=['batch_stats'], capture_intermediates=True, rngs={'dropout': dropout_rng}, - train=True) + train=True, + ) return outputs, new_batch_stats outputs, new_batch_stats = forward_pass(params, xs, dropout_rng) @@ -654,10 +674,7 @@ def test_translate_model(self, dtype_str, remat_scan_lengths): 'dropout_rate': 0.1, 'attention_dropout_rate': 0.1, 'momentum': 0.9, - 'lr_hparams': { - 'base_lr': 0.005, - 'schedule': 'constant' - }, + 'lr_hparams': {'base_lr': 0.005, 'schedule': 'constant'}, 'vocab_size': vocab_size, 'output_shape': (vocab_size,), 'model_dtype': dtype_str, @@ -675,27 +692,30 @@ def test_translate_model(self, dtype_str, remat_scan_lengths): rng = jax.random.PRNGKey(0) loss = 'cross_entropy' metrics = 'classification_metrics' - model = model_cls(small_hps, { - 'shift_outputs': True, - 'causal': True - }, loss, metrics) + model = model_cls( + small_hps, {'shift_outputs': True, 'causal': True}, loss, metrics + ) xs = jnp.array( - np.random.randint(size=text_src_input_shape, low=1, - high=vocab_size)) + np.random.randint(size=text_src_input_shape, low=1, high=vocab_size) + ) ys = jnp.array( - np.random.randint(size=text_tgt_input_shape, low=1, - high=vocab_size)) + np.random.randint(size=text_tgt_input_shape, low=1, high=vocab_size) + ) dropout_rng, params_rng = jax.random.split(rng) model_init_fn = jax.jit( - functools.partial(model.flax_module.init, train=False)) + functools.partial(model.flax_module.init, train=False) + ) init_dict = model_init_fn({'params': params_rng}, xs, ys) params = init_dict['params'] param_type_matches_model_type = jax.tree_util.tree_map( - lambda x: x.dtype == small_hps.model_dtype, params) + lambda x: x.dtype == small_hps.model_dtype, params + ) self.assertTrue( - jax.tree_util.tree_reduce(lambda x, y: x and y, - param_type_matches_model_type)) + jax.tree_util.tree_reduce( + lambda x, y: x and y, param_type_matches_model_type + ) + ) # Test forward pass. @jax.jit @@ -706,7 +726,8 @@ def forward_pass(params, xs, ys, dropout_rng): ys, rngs={'dropout': dropout_rng}, capture_intermediates=True, - train=True) + train=True, + ) return outputs, intermediates logits, intermediates = forward_pass(params, xs, ys, dropout_rng) @@ -714,13 +735,17 @@ def forward_pass(params, xs, ys, dropout_rng): # TODO(ankugarg): Add tests for individual encoder/decoder (inference mode). self.assertEqual( logits.shape, - (text_tgt_input_shape[0], text_tgt_input_shape[1], vocab_size)) + (text_tgt_input_shape[0], text_tgt_input_shape[1], vocab_size), + ) intermediates_type_matches_model_type = jax.tree_util.tree_map( - lambda x: x.dtype == small_hps.model_dtype, intermediates) + lambda x: x.dtype == small_hps.model_dtype, intermediates + ) self.assertTrue( - jax.tree_util.tree_reduce(lambda x, y: x and y, - intermediates_type_matches_model_type)) + jax.tree_util.tree_reduce( + lambda x, y: x and y, intermediates_type_matches_model_type + ) + ) def test_nqm(self): """Test the noisy quadratic model.""" @@ -735,7 +760,8 @@ def test_nqm(self): noise_decay_power=1.0, nqm_mode='diagH_diagC', model_dtype='float32', - )) + ) + ) model_cls = models.get_model('nqm') params_rng = jax.random.PRNGKey(0) @@ -743,7 +769,8 @@ def test_nqm(self): noise_eps = jnp.array(np.random.normal(size=(batch_size, dim))) xs = np.zeros((batch_size, dim)) model_init_fn = jax.jit( - functools.partial(model.flax_module.init, train=False)) + functools.partial(model.flax_module.init, train=False) + ) params = model_init_fn({'params': params_rng}, xs)['params'] model_x = params['x'] @@ -756,12 +783,14 @@ def loss(params, inputs): np.array([ 1.0 / np.power(i, model_hps.hessian_decay_power) for i in range(1, dim + 1) - ])) + ]) + ) noise_matrix = np.diag( np.array([ 1.0 / np.power(i, model_hps.noise_decay_power / 2.0) for i in range(1, dim + 1) - ])) + ]) + ) noise = jnp.dot(noise_eps, noise_matrix) mean_noise = np.mean(noise, axis=0) @@ -788,7 +817,8 @@ def test_autoencoder_model(self, model_str): model = model_cls(hps, {}, loss, metrics) xs = jnp.array(np.random.normal(size=INPUT_SHAPE[model_str])) model_init_fn = jax.jit( - functools.partial(model.flax_module.init, train=False)) + functools.partial(model.flax_module.init, train=False) + ) init_dict = model_init_fn({'params': params_rng}, xs) params = init_dict['params'] batch_stats = init_dict.get('batch_stats', {}) @@ -798,10 +828,12 @@ def test_autoencoder_model(self, model_str): {'params': params, 'batch_stats': batch_stats}, xs, mutable=['batch_stats'], - train=True) + train=True, + ) self.assertEqual( outputs.shape, - tuple([INPUT_SHAPE[model_str][0]] + list(OUTPUT_SHAPE[model_str]))) + tuple([INPUT_SHAPE[model_str][0]] + list(OUTPUT_SHAPE[model_str])), + ) # If it's a batch norm model check the batch stats changed. if batch_stats: @@ -811,10 +843,12 @@ def test_autoencoder_model(self, model_str): # Test batch_norm in inference mode. outputs = model.flax_module.apply( - {'params': params, 'batch_stats': batch_stats}, xs, train=False) + {'params': params, 'batch_stats': batch_stats}, xs, train=False + ) self.assertEqual( outputs.shape, - tuple([INPUT_SHAPE[model_str][0]] + list(OUTPUT_SHAPE[model_str]))) + tuple([INPUT_SHAPE[model_str][0]] + list(OUTPUT_SHAPE[model_str])), + ) def test_graph_model(self): """Test forward pass of the GNN model.""" @@ -823,11 +857,13 @@ def test_graph_model(self): output_shape = (5,) model_str = 'gnn' model_hps = models.get_model_hparams(model_str) - model_hps.update({'output_shape': output_shape, - 'latent_dim': 10, - 'hidden_dims': (10,), - 'batch_size': 5, - 'normalizer': 'batch_norm'}) + model_hps.update({ + 'output_shape': output_shape, + 'latent_dim': 10, + 'hidden_dims': (10,), + 'batch_size': 5, + 'normalizer': 'batch_norm', + }) model_cls = models.get_model(model_str) rng = jax.random.PRNGKey(0) dropout_rng, params_rng = jax.random.split(rng) @@ -841,14 +877,17 @@ def test_graph_model(self): inputs = jraph.get_fully_connected_graph( n_node_per_graph=node_per_graph, n_graph=num_graphs, - node_features=np.ones((num_graphs * node_per_graph,) + - node_input_shape), + node_features=np.ones( + (num_graphs * node_per_graph,) + node_input_shape + ), ) inputs = inputs._replace( - edges=np.ones((num_graphs * edge_per_graph,) + edge_input_shape)) + edges=np.ones((num_graphs * edge_per_graph,) + edge_input_shape) + ) padded_inputs = jraph.pad_with_graphs(inputs, 20, 50, 7) model_init_fn = jax.jit( - functools.partial(model.flax_module.init, train=False)) + functools.partial(model.flax_module.init, train=False) + ) init_dict = model_init_fn({'params': params_rng}, padded_inputs) params = init_dict['params'] batch_stats = init_dict['batch_stats'] @@ -859,7 +898,8 @@ def test_graph_model(self): padded_inputs, mutable=['batch_stats'], rngs={'dropout': dropout_rng}, - train=True) + train=True, + ) self.assertEqual(outputs.shape, (7,) + output_shape) def test_local_attention_transformer(self): @@ -879,19 +919,18 @@ def test_local_attention_transformer(self): model = model_cls(model_hps, {}, loss, metrics) inputs = jnp.array(np.random.randint(size=(1, 16), low=1, high=8)) model_init_fn = jax.jit( - functools.partial(model.flax_module.init, train=False)) + functools.partial(model.flax_module.init, train=False) + ) init_dict = model_init_fn({'params': params_rng}, inputs) params = init_dict['params'] batch_stats = init_dict.get('batch_stats', {}) outputs, _ = model.flax_module.apply( - { - 'params': params, - 'batch_stats': batch_stats - }, + {'params': params, 'batch_stats': batch_stats}, inputs, mutable=['batch_stats'], rngs={'dropout': dropout_rng}, - train=True) + train=True, + ) self.assertEqual(outputs.shape, (1, 16, 8)) @parameterized.named_parameters(*lstm_keys) @@ -1016,15 +1055,16 @@ def test_vit_model_apply_sharding_overrides_failure(self): # pylint: disable=protected-access try: _ = model._apply_sharding_overrides( - params, host_mesh, nn.get_sharding(params, host_mesh), bad_overrides) + params, host_mesh, nn.get_sharding(params, host_mesh), bad_overrides + ) except ValueError as e: error = e self.assertEqual( str(error), 'Param shape (1536,) is not compatible with sharding ' - 'NamedSharding(mesh=Mesh(\'devices\': 8, axis_types=(Auto,)), ' - 'spec=P(None, \'devices\'), memory_kind=device)', + "NamedSharding(mesh=Mesh('devices': 8, axis_types=(Auto,)), " + "spec=P(None, 'devices'), memory_kind=device)", ) good_overrides = { @@ -1056,7 +1096,12 @@ def test_vit_params_shapes(self): }, 'conv_patch_extract': { 'bias': model_utils.ShapeTuple((384,)), - 'kernel': model_utils.ShapeTuple((16, 16, 3, 384,)), + 'kernel': model_utils.ShapeTuple(( + 16, + 16, + 3, + 384, + )), }, 'head': { 'bias': model_utils.ShapeTuple((1000,)), diff --git a/init2winit/model_lib/test_normalization.py b/init2winit/model_lib/test_normalization.py index 75f0d33f..3cd880c4 100644 --- a/init2winit/model_lib/test_normalization.py +++ b/init2winit/model_lib/test_normalization.py @@ -17,10 +17,8 @@ import functools from absl.testing import absltest - from flax import linen as nn from init2winit.model_lib import normalization - import jax import jax.numpy as jnp import numpy as np @@ -31,7 +29,8 @@ def _init(flax_module, rng, input_shape): model_init_fn = jax.jit( - functools.partial(flax_module.init, use_running_average=False)) + functools.partial(flax_module.init, use_running_average=False) + ) xs = np.zeros(input_shape) init_dict = model_init_fn({'params': rng}, xs) params = init_dict['params'] @@ -61,33 +60,38 @@ def test_batch_norm(self): bn_params, bn_state = _init(bn_flax_module, rng, input_shape) vbn_flax_module = normalization.VirtualBatchNorm( - momentum=0.9, virtual_batch_size=batch_size, data_format='NHWC') + momentum=0.9, virtual_batch_size=batch_size, data_format='NHWC' + ) vbn_params, vbn_state = _init(vbn_flax_module, rng, input_shape) _, bn_state = bn_flax_module.apply( {'params': bn_params, 'batch_stats': bn_state}, x, mutable=['batch_stats'], - use_running_average=False) + use_running_average=False, + ) bn_state = bn_state['batch_stats'] bn_y, bn_state = bn_flax_module.apply( {'params': bn_params, 'batch_stats': bn_state}, x, mutable=['batch_stats'], - use_running_average=False) + use_running_average=False, + ) bn_state = bn_state['batch_stats'] _, vbn_state = vbn_flax_module.apply( {'params': vbn_params, 'batch_stats': vbn_state}, x, mutable=['batch_stats'], - use_running_average=False) + use_running_average=False, + ) vbn_state = vbn_state['batch_stats'] vbn_y, vbn_state = vbn_flax_module.apply( {'params': vbn_params, 'batch_stats': vbn_state}, x, mutable=['batch_stats'], - use_running_average=False) + use_running_average=False, + ) vbn_state = vbn_state['batch_stats'] # Test that the layer forward passes are the same. @@ -97,11 +101,13 @@ def test_batch_norm(self): np.testing.assert_allclose( bn_state['mean'], np.squeeze(vbn_state['batch_norm_running_mean']), - atol=1e-4) + atol=1e-4, + ) np.testing.assert_allclose( bn_state['var'], np.squeeze(vbn_state['batch_norm_running_var']), - atol=1e-4) + atol=1e-4, + ) def test_forward_pass(self): """Test that two calls are the same as one with twice the batch size.""" @@ -123,57 +129,66 @@ def test_forward_pass(self): x_both = jnp.concatenate((x1, x2)) expected_bn_y = jnp.concatenate( - (jnp.ones(half_input_shape) * -1.0, jnp.ones(half_input_shape))) + (jnp.ones(half_input_shape) * -1.0, jnp.ones(half_input_shape)) + ) bn_flax_module = nn.BatchNorm(momentum=0.9) bn_params, bn_state = _init(bn_flax_module, rng, input_shape) vbn_flax_module = normalization.VirtualBatchNorm( - momentum=0.9, virtual_batch_size=batch_size, data_format='NHWC') + momentum=0.9, virtual_batch_size=batch_size, data_format='NHWC' + ) vbn_params, vbn_state = _init(vbn_flax_module, rng, input_shape) bn_y1, _ = bn_flax_module.apply( {'params': bn_params, 'batch_stats': bn_state}, x1, mutable=['batch_stats'], - use_running_average=False) + use_running_average=False, + ) bn_y2, _ = bn_flax_module.apply( {'params': bn_params, 'batch_stats': bn_state}, x2, mutable=['batch_stats'], - use_running_average=False) + use_running_average=False, + ) bn_y_both, _ = bn_flax_module.apply( {'params': bn_params, 'batch_stats': bn_state}, x_both, mutable=['batch_stats'], - use_running_average=False) + use_running_average=False, + ) vbn_y_both, vbn_state = vbn_flax_module.apply( {'params': vbn_params, 'batch_stats': vbn_state}, x_both, mutable=['batch_stats'], - use_running_average=False) + use_running_average=False, + ) vbn_state = vbn_state['batch_stats'] # Test that the layer forward passes behave as expected. np.testing.assert_allclose(bn_y1, expected_bn_y, atol=1e-4) np.testing.assert_allclose(bn_y2, expected_bn_y, atol=1e-4) np.testing.assert_allclose( - vbn_y_both, jnp.concatenate((bn_y1, bn_y2)), atol=1e-4) + vbn_y_both, jnp.concatenate((bn_y1, bn_y2)), atol=1e-4 + ) # Test that the virtual batch norm and nn.BatchNorm layers do not perform # the same calculation on the concatenated batch. # There is no negative of `np.testing.assert_allclose` so we test that the # diff is greater than zero. np.testing.assert_array_less( - -jnp.abs(vbn_y_both - bn_y_both), jnp.zeros_like(vbn_y_both)) + -jnp.abs(vbn_y_both - bn_y_both), jnp.zeros_like(vbn_y_both) + ) _, vbn_state = vbn_flax_module.apply( {'params': vbn_params, 'batch_stats': vbn_state}, x_both, mutable=['batch_stats'], - use_running_average=False) + use_running_average=False, + ) vbn_state = vbn_state['batch_stats'] # The mean running average stats at 0.0, and the variance starts at 1.0. So @@ -184,16 +199,18 @@ def test_forward_pass(self): expected_mean_ema_x2 = 0.19 * jnp.mean(x2) * jnp.ones((feature_size,)) expected_mean_ema_both = (expected_mean_ema_x1 + expected_mean_ema_x2) / 2.0 expected_var_ema_both = ( - (0.19 * jnp.std(jnp.concatenate((x1, x2))) ** 2.0 + 0.81) * - jnp.ones((feature_size,))) + 0.19 * jnp.std(jnp.concatenate((x1, x2))) ** 2.0 + 0.81 + ) * jnp.ones((feature_size,)) np.testing.assert_allclose( np.squeeze(vbn_state['batch_norm_running_mean']), expected_mean_ema_both, - atol=1e-4) + atol=1e-4, + ) np.testing.assert_allclose( np.squeeze(vbn_state['batch_norm_running_var']), expected_var_ema_both, - atol=1e-4) + atol=1e-4, + ) def test_different_eval_batch_size(self): """Test virtual BN can use a different batch size for evals.""" @@ -204,20 +221,23 @@ def test_different_eval_batch_size(self): x = 2.0 * jnp.ones(input_shape) vbn_flax_module = normalization.VirtualBatchNorm( - momentum=0.9, virtual_batch_size=batch_size, data_format='NHWC') + momentum=0.9, virtual_batch_size=batch_size, data_format='NHWC' + ) vbn_params, vbn_state = _init(vbn_flax_module, rng, input_shape) _, vbn_state = vbn_flax_module.apply( {'params': vbn_params, 'batch_stats': vbn_state}, x, mutable=['batch_stats'], - use_running_average=False) + use_running_average=False, + ) vbn_state = vbn_state['batch_stats'] vbn_flax_module.apply( {'params': vbn_params, 'batch_stats': vbn_state}, jnp.ones((13, 3, 3, feature_size)), - use_running_average=True) + use_running_average=True, + ) if __name__ == '__main__': diff --git a/init2winit/model_lib/transformer_lm.py b/init2winit/model_lib/transformer_lm.py index f2fb4b3d..89104a69 100644 --- a/init2winit/model_lib/transformer_lm.py +++ b/init2winit/model_lib/transformer_lm.py @@ -18,6 +18,7 @@ Adapted from https://github.com/google/flax/blob/b60f7f45b90f8fc42a88b1639c9cc88a40b298d3/examples/lm1b/models.py """ + from typing import Any, Optional from flax import linen as nn @@ -48,7 +49,8 @@ model_dtype='float32', decode=False, normalize_attention=False, - )) + ) +) def shift_right(x, axis=1): @@ -56,7 +58,8 @@ def shift_right(x, axis=1): pad_widths = [(0, 0)] * len(x.shape) pad_widths[axis] = (1, 0) padded = jnp.pad( - x, pad_widths, mode='constant', constant_values=x.dtype.type(0)) + x, pad_widths, mode='constant', constant_values=x.dtype.type(0) + ) return lax.dynamic_slice_in_dim(padded, 0, padded.shape[axis] - 1, axis) @@ -66,7 +69,7 @@ def shift_inputs(x, segment_ids=None, axis=1): # For packed targets, the first shifted token of a new sequence is made # 0, rather than being the EOS token for the last sequence. if segment_ids is not None: - shifted *= (segment_ids == shift_right(segment_ids, axis=axis)) + shifted *= segment_ids == shift_right(segment_ids, axis=axis) return shifted @@ -87,7 +90,8 @@ def init(key, shape, dtype=np.float32): pe = np.zeros((max_len, d_feature), dtype=dtype) position = np.arange(0, max_len)[:, np.newaxis] div_term = np.exp( - np.arange(0, d_feature, 2) * -(np.log(10000.0) / d_feature)) + np.arange(0, d_feature, 2) * -(np.log(10000.0) / d_feature) + ) pe[:, 0::2] = np.sin(position * div_term) pe[:, 1::2] = np.cos(position * div_term) pe = pe[np.newaxis, :, :] # [1, max_len, d_feature] @@ -104,6 +108,7 @@ class AddPositionEmbs(nn.Module): posemb_init: positional embedding initializer decode: whether to run in single-position autoregressive mode. """ + max_len: int = 2048 posemb_init: model_utils.Initializer = nn.initializers.normal(stddev=1.0) decode: bool = False @@ -125,24 +130,28 @@ def __call__(self, inputs, inputs_positions=None, dtype=np.float32): output: `(bs, timesteps, in_dim)` """ # inputs.shape is (batch_size, seq_len, emb_dim) - assert inputs.ndim == 3, ('Number of dimensions should be 3,' - ' but it is: %d' % inputs.ndim) + assert inputs.ndim == 3, ( + 'Number of dimensions should be 3, but it is: %d' % inputs.ndim + ) length = inputs.shape[1] pos_emb_shape = (1, self.max_len, inputs.shape[-1]) if self.posemb_init is None: # Use a fixed (non-learned) sinusoidal position embedding. - pos_embedding = sinusoidal_init(max_len=self.max_len)(None, pos_emb_shape, - dtype) + pos_embedding = sinusoidal_init(max_len=self.max_len)( + None, pos_emb_shape, dtype + ) else: - pos_embedding = self.param('pos_embedding', self.posemb_init, - pos_emb_shape, dtype) + pos_embedding = self.param( + 'pos_embedding', self.posemb_init, pos_emb_shape, dtype + ) pe = pos_embedding[:, :length, :] # We use a cache position index for tracking decoding position. if self.decode: is_initialized = self.has_variable('cache', 'cache_index') - cache_index = self.variable('cache', 'cache_index', - lambda: jnp.array(0, dtype=jnp.uint32)) + cache_index = self.variable( + 'cache', 'cache_index', lambda: jnp.array(0, dtype=jnp.uint32) + ) if is_initialized: i = cache_index.value cache_index.value = i + 1 @@ -158,6 +167,7 @@ def __call__(self, inputs, inputs_positions=None, dtype=np.float32): class MlpBlock(nn.Module): """Transformer MLP block.""" + mlp_dim: int out_dim: Optional[int] = None dropout_rate: float = 0.1 @@ -174,8 +184,8 @@ def __call__(self, inputs, train): kernel_init=self.kernel_init, bias_init=self.bias_init, dtype=self.dtype, - param_dtype=self.dtype)( - inputs) + param_dtype=self.dtype, + )(inputs) x = nn.gelu(x) x = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(x) output = nn.Dense( @@ -183,8 +193,8 @@ def __call__(self, inputs, train): kernel_init=self.kernel_init, bias_init=self.bias_init, dtype=self.dtype, - param_dtype=self.dtype)( - x) + param_dtype=self.dtype, + )(x) output = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(output) return output @@ -192,16 +202,17 @@ def __call__(self, inputs, train): class Transformer1DBlock(nn.Module): """Transformer layer (https://openreview.net/forum?id=H1e5GJBtDr). - qkv_dim: dimension of the query/key/value - mlp_dim: dimension of the mlp on top of attention block - num_heads: number of heads - dropout_rate: dropout rate - attention_dropout_rate: dropout rate for attention weights - normalizer: One of 'batch_norm', 'layer_norm', 'post_layer_norm', - 'pre_layer_norm', 'none' - attention_fn: Attention function to use. If None, defaults to - nn.dot_product_attention. + qkv_dim: dimension of the query/key/value + mlp_dim: dimension of the mlp on top of attention block + num_heads: number of heads + dropout_rate: dropout rate + attention_dropout_rate: dropout rate for attention weights + normalizer: One of 'batch_norm', 'layer_norm', 'post_layer_norm', + 'pre_layer_norm', 'none' + attention_fn: Attention function to use. If None, defaults to + nn.dot_product_attention. """ + qkv_dim: int mlp_dim: int num_heads: int @@ -214,13 +225,15 @@ class Transformer1DBlock(nn.Module): normalize_attention: bool = False @nn.compact - def __call__(self, - inputs, - train, - decoder_mask=None, - encoder_decoder_mask=None, - inputs_positions=None, - inputs_segmentation=None): + def __call__( + self, + inputs, + train, + decoder_mask=None, + encoder_decoder_mask=None, + inputs_positions=None, + inputs_segmentation=None, + ): """Applies Transformer1DBlock module. Args: @@ -233,23 +246,29 @@ def __call__(self, Returns: output after transformer block. - """ # Attention block. assert inputs.ndim == 3 if self.normalizer in [ - 'batch_norm', 'layer_norm', 'pre_layer_norm', 'none' + 'batch_norm', + 'layer_norm', + 'pre_layer_norm', + 'none', ]: maybe_pre_normalize = model_utils.get_normalizer( - self.normalizer, train, dtype=self.dtype) + self.normalizer, train, dtype=self.dtype + ) maybe_post_normalize = model_utils.get_normalizer( - 'none', train, dtype=self.dtype) + 'none', train, dtype=self.dtype + ) elif self.normalizer == 'post_layer_norm': maybe_pre_normalize = model_utils.get_normalizer( - 'none', train, dtype=self.dtype) + 'none', train, dtype=self.dtype + ) maybe_post_normalize = model_utils.get_normalizer( - self.normalizer, train, dtype=self.dtype) + self.normalizer, train, dtype=self.dtype + ) else: raise ValueError('Unsupported normalizer: {}'.format(self.normalizer)) @@ -272,7 +291,8 @@ def __call__(self, attention_fn=attention_fn, dropout_rate=self.attention_dropout_rate, normalize_attention=self.normalize_attention, - deterministic=not train)(x, decoder_mask) + deterministic=not train, + )(x, decoder_mask) x = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(x) x = x + inputs x = maybe_post_normalize(param_dtype=self.dtype)(x) @@ -280,8 +300,8 @@ def __call__(self, # MLP block. y = maybe_pre_normalize(param_dtype=self.dtype)(x) y = MlpBlock( - mlp_dim=self.mlp_dim, dropout_rate=self.dropout_rate, dtype=self.dtype)( - y, train=train) + mlp_dim=self.mlp_dim, dropout_rate=self.dropout_rate, dtype=self.dtype + )(y, train=train) res = x + y return maybe_post_normalize(param_dtype=self.dtype)(res) @@ -290,25 +310,26 @@ def __call__(self, class TransformerLM(nn.Module): """Transformer Model for language modeling. - vocab_size: size of the vocabulary - emb_dim: dimension of embedding - num_heads: number of heads - num_layers: number of layers - qkv_dim: dimension of the query/key/value - mlp_dim: dimension of the mlp on top of attention block - max_len: maximum length. - train: bool: if model is training. - causal: Whether to apply causal masking. - shift: bool: if we right-shift input - this is only disabled for - fast, looped single-token autoregressive decoding. - dropout_rate: dropout rate - attention_dropout_rate: dropout rate for attention weights - normalizer: One of 'batch_norm', 'layer_norm', 'none' - attention_fn: Attention function to use. If None, defaults to - nn.dot_product_attention. - decode: whether to run in single-position autoregressive mode. - pad_token: Indicates which input tokens are padded. + vocab_size: size of the vocabulary + emb_dim: dimension of embedding + num_heads: number of heads + num_layers: number of layers + qkv_dim: dimension of the query/key/value + mlp_dim: dimension of the mlp on top of attention block + max_len: maximum length. + train: bool: if model is training. + causal: Whether to apply causal masking. + shift: bool: if we right-shift input - this is only disabled for + fast, looped single-token autoregressive decoding. + dropout_rate: dropout rate + attention_dropout_rate: dropout rate for attention weights + normalizer: One of 'batch_norm', 'layer_norm', 'none' + attention_fn: Attention function to use. If None, defaults to + nn.dot_product_attention. + decode: whether to run in single-position autoregressive mode. + pad_token: Indicates which input tokens are padded. """ + vocab_size: int shared_embedding: Any = None logits_via_embedding: bool = False @@ -331,11 +352,9 @@ class TransformerLM(nn.Module): pad_token: int = 0 @nn.compact - def __call__(self, - inputs, - train, - inputs_positions=None, - inputs_segmentation=None): + def __call__( + self, inputs, train, inputs_positions=None, inputs_segmentation=None + ): """Applies Transformer model on the inputs. Args: @@ -356,13 +375,16 @@ def __call__(self, else: decoder_mask = nn.combine_masks( nn.make_attention_mask(inputs > 0, inputs > 0, dtype=dtype), - nn.make_causal_mask(inputs, dtype=dtype)) + nn.make_causal_mask(inputs, dtype=dtype), + ) if inputs_segmentation is not None: decoder_mask = nn.combine_masks( decoder_mask, nn.make_attention_mask( - inputs_segmentation, inputs_segmentation, jnp.equal, dtype=dtype)) + inputs_segmentation, inputs_segmentation, jnp.equal, dtype=dtype + ), + ) y = inputs.astype('int32') if not self.decode: @@ -380,7 +402,8 @@ def __call__(self, features=self.emb_dim, dtype=dtype, param_dtype=dtype, - embedding_init=nn.initializers.normal(stddev=1.0)) + embedding_init=nn.initializers.normal(stddev=1.0), + ) else: output_embed = self.shared_embedding @@ -390,8 +413,8 @@ def __call__(self, max_len=self.max_len, posemb_init=sinusoidal_init(max_len=self.max_len), decode=self.decode, - name='posembed_output')( - y, inputs_positions=inputs_positions, dtype=dtype) + name='posembed_output', + )(y, inputs_positions=inputs_positions, dtype=dtype) y = nn.Dropout(rate=self.dropout_rate)(y, deterministic=not train) y = y.astype(dtype) @@ -406,16 +429,19 @@ def __call__(self, normalize_attention=self.normalize_attention, attention_fn=self.attention_fn, normalizer=self.normalizer, - dtype=dtype)( - inputs=y, - train=train, - decoder_mask=decoder_mask, - encoder_decoder_mask=None, - inputs_positions=None, - inputs_segmentation=None,) + dtype=dtype, + )( + inputs=y, + train=train, + decoder_mask=decoder_mask, + encoder_decoder_mask=None, + inputs_positions=None, + inputs_segmentation=None, + ) if self.normalizer in ['batch_norm', 'layer_norm', 'pre_layer_norm']: maybe_normalize = model_utils.get_normalizer( - self.normalizer, train, dtype=dtype) + self.normalizer, train, dtype=dtype + ) y = maybe_normalize(param_dtype=dtype)(y) if self.logits_via_embedding: @@ -430,8 +456,8 @@ def __call__(self, bias_init=nn.initializers.normal(stddev=1e-6), dtype=dtype, param_dtype=dtype, - name='logits_dense')( - y) + name='logits_dense', + )(y) return logits.astype(dtype) diff --git a/init2winit/model_lib/transformer_stu_lm.py b/init2winit/model_lib/transformer_stu_lm.py index 2a5d2fcd..9a2695e1 100644 --- a/init2winit/model_lib/transformer_stu_lm.py +++ b/init2winit/model_lib/transformer_stu_lm.py @@ -18,6 +18,7 @@ Adapted from https://github.com/google/flax/blob/b60f7f45b90f8fc42a88b1639c9cc88a40b298d3/examples/lm1b/models.py """ + import functools import logging import pickle @@ -57,7 +58,8 @@ add_stu_norm=True, do_stu_proj=True, zero_init_input=False, - )) + ) +) def shift_right(x, axis=1): @@ -65,7 +67,8 @@ def shift_right(x, axis=1): pad_widths = [(0, 0)] * len(x.shape) pad_widths[axis] = (1, 0) padded = jnp.pad( - x, pad_widths, mode='constant', constant_values=x.dtype.type(0)) + x, pad_widths, mode='constant', constant_values=x.dtype.type(0) + ) return lax.dynamic_slice_in_dim(padded, 0, padded.shape[axis] - 1, axis) @@ -75,7 +78,7 @@ def shift_inputs(x, segment_ids=None, axis=1): # For packed targets, the first shifted token of a new sequence is made # 0, rather than being the EOS token for the last sequence. if segment_ids is not None: - shifted *= (segment_ids == shift_right(segment_ids, axis=axis)) + shifted *= segment_ids == shift_right(segment_ids, axis=axis) return shifted @@ -96,7 +99,8 @@ def init(key, shape, dtype=np.float32): pe = np.zeros((max_len, d_feature), dtype=dtype) position = np.arange(0, max_len)[:, np.newaxis] div_term = np.exp( - np.arange(0, d_feature, 2) * -(np.log(10000.0) / d_feature)) + np.arange(0, d_feature, 2) * -(np.log(10000.0) / d_feature) + ) pe[:, 0::2] = np.sin(position * div_term) pe[:, 1::2] = np.cos(position * div_term) pe = pe[np.newaxis, :, :] # [1, max_len, d_feature] @@ -113,6 +117,7 @@ class AddPositionEmbs(nn.Module): posemb_init: positional embedding initializer decode: whether to run in single-position autoregressive mode. """ + max_len: int = 2048 posemb_init: model_utils.Initializer = nn.initializers.normal(stddev=1.0) decode: bool = False @@ -134,24 +139,28 @@ def __call__(self, inputs, inputs_positions=None, dtype=np.float32): output: `(bs, timesteps, in_dim)` """ # inputs.shape is (batch_size, seq_len, emb_dim) - assert inputs.ndim == 3, ('Number of dimensions should be 3,' - ' but it is: %d' % inputs.ndim) + assert inputs.ndim == 3, ( + 'Number of dimensions should be 3, but it is: %d' % inputs.ndim + ) length = inputs.shape[1] pos_emb_shape = (1, self.max_len, inputs.shape[-1]) if self.posemb_init is None: # Use a fixed (non-learned) sinusoidal position embedding. - pos_embedding = sinusoidal_init(max_len=self.max_len)(None, pos_emb_shape, - dtype) + pos_embedding = sinusoidal_init(max_len=self.max_len)( + None, pos_emb_shape, dtype + ) else: - pos_embedding = self.param('pos_embedding', self.posemb_init, - pos_emb_shape, dtype) + pos_embedding = self.param( + 'pos_embedding', self.posemb_init, pos_emb_shape, dtype + ) pe = pos_embedding[:, :length, :] # We use a cache position index for tracking decoding position. if self.decode: is_initialized = self.has_variable('cache', 'cache_index') - cache_index = self.variable('cache', 'cache_index', - lambda: jnp.array(0, dtype=jnp.uint32)) + cache_index = self.variable( + 'cache', 'cache_index', lambda: jnp.array(0, dtype=jnp.uint32) + ) if is_initialized: i = cache_index.value cache_index.value = i + 1 @@ -167,6 +176,7 @@ def __call__(self, inputs, inputs_positions=None, dtype=np.float32): class MlpBlock(nn.Module): """Transformer MLP block.""" + mlp_dim: int out_dim: Optional[int] = None dropout_rate: float = 0.1 @@ -183,8 +193,8 @@ def __call__(self, inputs, train): kernel_init=self.kernel_init, bias_init=self.bias_init, dtype=self.dtype, - param_dtype=self.dtype)( - inputs) + param_dtype=self.dtype, + )(inputs) x = nn.gelu(x) x = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(x) output = nn.Dense( @@ -192,8 +202,8 @@ def __call__(self, inputs, train): kernel_init=self.kernel_init, bias_init=self.bias_init, dtype=self.dtype, - param_dtype=self.dtype)( - x) + param_dtype=self.dtype, + )(x) output = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(output) return output @@ -201,16 +211,17 @@ def __call__(self, inputs, train): class TransformerSTUHybridBlock(nn.Module): """Transformer layer (https://openreview.net/forum?id=H1e5GJBtDr). - qkv_dim: dimension of the query/key/value - mlp_dim: dimension of the mlp on top of attention block - num_heads: number of heads - dropout_rate: dropout rate - attention_dropout_rate: dropout rate for attention weights - normalizer: One of 'batch_norm', 'layer_norm', 'post_layer_norm', - 'pre_layer_norm', 'none' - attention_fn: Attention function to use. If None, defaults to - nn.dot_product_attention. + qkv_dim: dimension of the query/key/value + mlp_dim: dimension of the mlp on top of attention block + num_heads: number of heads + dropout_rate: dropout rate + attention_dropout_rate: dropout rate for attention weights + normalizer: One of 'batch_norm', 'layer_norm', 'post_layer_norm', + 'pre_layer_norm', 'none' + attention_fn: Attention function to use. If None, defaults to + nn.dot_product_attention. """ + qkv_dim: int mlp_dim: int num_heads: int @@ -228,13 +239,15 @@ class TransformerSTUHybridBlock(nn.Module): zero_init_input: bool = False @nn.compact - def __call__(self, - inputs, - train, - decoder_mask=None, - encoder_decoder_mask=None, - inputs_positions=None, - inputs_segmentation=None): + def __call__( + self, + inputs, + train, + decoder_mask=None, + encoder_decoder_mask=None, + inputs_positions=None, + inputs_segmentation=None, + ): """Applies Transformer1DBlock module. Args: @@ -247,23 +260,29 @@ def __call__(self, Returns: output after transformer block. - """ # Attention block. assert inputs.ndim == 3 if self.normalizer in [ - 'batch_norm', 'layer_norm', 'pre_layer_norm', 'none' + 'batch_norm', + 'layer_norm', + 'pre_layer_norm', + 'none', ]: maybe_pre_normalize = model_utils.get_normalizer( - self.normalizer, train, dtype=self.dtype) + self.normalizer, train, dtype=self.dtype + ) maybe_post_normalize = model_utils.get_normalizer( - 'none', train, dtype=self.dtype) + 'none', train, dtype=self.dtype + ) elif self.normalizer == 'post_layer_norm': maybe_pre_normalize = model_utils.get_normalizer( - 'none', train, dtype=self.dtype) + 'none', train, dtype=self.dtype + ) maybe_post_normalize = model_utils.get_normalizer( - self.normalizer, train, dtype=self.dtype) + self.normalizer, train, dtype=self.dtype + ) else: raise ValueError('Unsupported normalizer: {}'.format(self.normalizer)) @@ -321,7 +340,8 @@ def __call__(self, attention_fn=attention_fn, dropout_rate=self.attention_dropout_rate, normalize_attention=self.normalize_attention, - deterministic=not train)(x, decoder_mask) + deterministic=not train, + )(x, decoder_mask) x = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(x) x = x + stu_outs x = maybe_post_normalize(param_dtype=self.dtype)(x) @@ -329,8 +349,8 @@ def __call__(self, # MLP block. y = maybe_pre_normalize(param_dtype=self.dtype)(x) y = MlpBlock( - mlp_dim=self.mlp_dim, dropout_rate=self.dropout_rate, dtype=self.dtype)( - y, train=train) + mlp_dim=self.mlp_dim, dropout_rate=self.dropout_rate, dtype=self.dtype + )(y, train=train) res = x + y return maybe_post_normalize(param_dtype=self.dtype)(res) @@ -349,9 +369,7 @@ def conv_fft(v: jnp.ndarray, u: jnp.ndarray) -> jnp.ndarray: """ # Convolve two vectors of length l (x.shape[0]) and truncate to the l oldest # values. - tr_conv = lambda x, y: jax.scipy.signal.convolve(x, y)[ - : y.shape[0] - ] + tr_conv = lambda x, y: jax.scipy.signal.convolve(x, y)[: y.shape[0]] # Convolve each sequence of length l in v with each sequence in u. mvconv = jax.vmap(tr_conv, in_axes=(1, None), out_axes=1) @@ -393,25 +411,26 @@ def compute_xtilde( class TransformerLM(nn.Module): """Transformer Model for language modeling. - vocab_size: size of the vocabulary - emb_dim: dimension of embedding - num_heads: number of heads - num_layers: number of layers - qkv_dim: dimension of the query/key/value - mlp_dim: dimension of the mlp on top of attention block - max_len: maximum length. - train: bool: if model is training. - causal: Whether to apply causal masking. - shift: bool: if we right-shift input - this is only disabled for - fast, looped single-token autoregressive decoding. - dropout_rate: dropout rate - attention_dropout_rate: dropout rate for attention weights - normalizer: One of 'batch_norm', 'layer_norm', 'none' - attention_fn: Attention function to use. If None, defaults to - nn.dot_product_attention. - decode: whether to run in single-position autoregressive mode. - pad_token: Indicates which input tokens are padded. + vocab_size: size of the vocabulary + emb_dim: dimension of embedding + num_heads: number of heads + num_layers: number of layers + qkv_dim: dimension of the query/key/value + mlp_dim: dimension of the mlp on top of attention block + max_len: maximum length. + train: bool: if model is training. + causal: Whether to apply causal masking. + shift: bool: if we right-shift input - this is only disabled for + fast, looped single-token autoregressive decoding. + dropout_rate: dropout rate + attention_dropout_rate: dropout rate for attention weights + normalizer: One of 'batch_norm', 'layer_norm', 'none' + attention_fn: Attention function to use. If None, defaults to + nn.dot_product_attention. + decode: whether to run in single-position autoregressive mode. + pad_token: Indicates which input tokens are padded. """ + vocab_size: int input_len: int num_eigh: int @@ -440,7 +459,7 @@ class TransformerLM(nn.Module): zero_init_input: bool = False def setup(self): - path = EIGVEC_PATH+f'{self.input_len}.pkl' + path = EIGVEC_PATH + f'{self.input_len}.pkl' with gfile.Open(path, 'rb') as my_file: eig_vals, eig_vecs = pickle.load(my_file) @@ -449,16 +468,14 @@ def setup(self): logging.info('Eig vecs shape: %s', eig_vecs.shape) self.eigh = ( - eig_vals[-self.num_eigh:], - eig_vecs[:, -self.num_eigh:], + eig_vals[-self.num_eigh :], + eig_vecs[:, -self.num_eigh :], ) @nn.compact - def __call__(self, - inputs, - train, - inputs_positions=None, - inputs_segmentation=None): + def __call__( + self, inputs, train, inputs_positions=None, inputs_segmentation=None + ): """Applies Transformer model on the inputs. Args: @@ -481,13 +498,16 @@ def __call__(self, else: decoder_mask = nn.combine_masks( nn.make_attention_mask(inputs > 0, inputs > 0, dtype=dtype), - nn.make_causal_mask(inputs, dtype=dtype)) + nn.make_causal_mask(inputs, dtype=dtype), + ) if inputs_segmentation is not None: decoder_mask = nn.combine_masks( decoder_mask, nn.make_attention_mask( - inputs_segmentation, inputs_segmentation, jnp.equal, dtype=dtype)) + inputs_segmentation, inputs_segmentation, jnp.equal, dtype=dtype + ), + ) y = inputs.astype('int32') if not self.decode: @@ -505,7 +525,8 @@ def __call__(self, features=self.emb_dim, dtype=dtype, param_dtype=dtype, - embedding_init=nn.initializers.normal(stddev=1.0)) + embedding_init=nn.initializers.normal(stddev=1.0), + ) else: output_embed = self.shared_embedding @@ -515,8 +536,8 @@ def __call__(self, max_len=self.max_len, posemb_init=sinusoidal_init(max_len=self.max_len), decode=self.decode, - name='posembed_output')( - y, inputs_positions=inputs_positions, dtype=dtype) + name='posembed_output', + )(y, inputs_positions=inputs_positions, dtype=dtype) y = nn.Dropout(rate=self.dropout_rate)(y, deterministic=not train) y = y.astype(dtype) @@ -536,16 +557,19 @@ def __call__(self, make_stu_residual=self.make_stu_residual, add_stu_norm=self.add_stu_norm, do_stu_proj=self.do_stu_proj, - zero_init_input=self.zero_init_input)( - inputs=y, - train=train, - decoder_mask=decoder_mask, - encoder_decoder_mask=None, - inputs_positions=None, - inputs_segmentation=None,) + zero_init_input=self.zero_init_input, + )( + inputs=y, + train=train, + decoder_mask=decoder_mask, + encoder_decoder_mask=None, + inputs_positions=None, + inputs_segmentation=None, + ) if self.normalizer in ['batch_norm', 'layer_norm', 'pre_layer_norm']: maybe_normalize = model_utils.get_normalizer( - self.normalizer, train, dtype=dtype) + self.normalizer, train, dtype=dtype + ) y = maybe_normalize(param_dtype=dtype)(y) if self.logits_via_embedding: @@ -560,8 +584,8 @@ def __call__(self, bias_init=nn.initializers.normal(stddev=1e-6), dtype=dtype, param_dtype=dtype, - name='logits_dense')( - y) + name='logits_dense', + )(y) return logits.astype(dtype) diff --git a/init2winit/model_lib/transformer_stu_tensordot_lm.py b/init2winit/model_lib/transformer_stu_tensordot_lm.py index 351cde4d..cc4afcfc 100644 --- a/init2winit/model_lib/transformer_stu_tensordot_lm.py +++ b/init2winit/model_lib/transformer_stu_tensordot_lm.py @@ -18,6 +18,7 @@ Adapted from https://github.com/google/flax/blob/b60f7f45b90f8fc42a88b1639c9cc88a40b298d3/examples/lm1b/models.py """ + import functools import logging import pickle @@ -57,7 +58,8 @@ add_stu_norm=True, zero_init_input=False, add_negative_vecs=False, - )) + ) +) def shift_right(x, axis=1): @@ -65,7 +67,8 @@ def shift_right(x, axis=1): pad_widths = [(0, 0)] * len(x.shape) pad_widths[axis] = (1, 0) padded = jnp.pad( - x, pad_widths, mode='constant', constant_values=x.dtype.type(0)) + x, pad_widths, mode='constant', constant_values=x.dtype.type(0) + ) return lax.dynamic_slice_in_dim(padded, 0, padded.shape[axis] - 1, axis) @@ -75,7 +78,7 @@ def shift_inputs(x, segment_ids=None, axis=1): # For packed targets, the first shifted token of a new sequence is made # 0, rather than being the EOS token for the last sequence. if segment_ids is not None: - shifted *= (segment_ids == shift_right(segment_ids, axis=axis)) + shifted *= segment_ids == shift_right(segment_ids, axis=axis) return shifted @@ -96,7 +99,8 @@ def init(key, shape, dtype=np.float32): pe = np.zeros((max_len, d_feature), dtype=dtype) position = np.arange(0, max_len)[:, np.newaxis] div_term = np.exp( - np.arange(0, d_feature, 2) * -(np.log(10000.0) / d_feature)) + np.arange(0, d_feature, 2) * -(np.log(10000.0) / d_feature) + ) pe[:, 0::2] = np.sin(position * div_term) pe[:, 1::2] = np.cos(position * div_term) pe = pe[np.newaxis, :, :] # [1, max_len, d_feature] @@ -113,6 +117,7 @@ class AddPositionEmbs(nn.Module): posemb_init: positional embedding initializer decode: whether to run in single-position autoregressive mode. """ + max_len: int = 2048 posemb_init: model_utils.Initializer = nn.initializers.normal(stddev=1.0) decode: bool = False @@ -134,24 +139,28 @@ def __call__(self, inputs, inputs_positions=None, dtype=np.float32): output: `(bs, timesteps, in_dim)` """ # inputs.shape is (batch_size, seq_len, emb_dim) - assert inputs.ndim == 3, ('Number of dimensions should be 3,' - ' but it is: %d' % inputs.ndim) + assert inputs.ndim == 3, ( + 'Number of dimensions should be 3, but it is: %d' % inputs.ndim + ) length = inputs.shape[1] pos_emb_shape = (1, self.max_len, inputs.shape[-1]) if self.posemb_init is None: # Use a fixed (non-learned) sinusoidal position embedding. - pos_embedding = sinusoidal_init(max_len=self.max_len)(None, pos_emb_shape, - dtype) + pos_embedding = sinusoidal_init(max_len=self.max_len)( + None, pos_emb_shape, dtype + ) else: - pos_embedding = self.param('pos_embedding', self.posemb_init, - pos_emb_shape, dtype) + pos_embedding = self.param( + 'pos_embedding', self.posemb_init, pos_emb_shape, dtype + ) pe = pos_embedding[:, :length, :] # We use a cache position index for tracking decoding position. if self.decode: is_initialized = self.has_variable('cache', 'cache_index') - cache_index = self.variable('cache', 'cache_index', - lambda: jnp.array(0, dtype=jnp.uint32)) + cache_index = self.variable( + 'cache', 'cache_index', lambda: jnp.array(0, dtype=jnp.uint32) + ) if is_initialized: i = cache_index.value cache_index.value = i + 1 @@ -167,6 +176,7 @@ def __call__(self, inputs, inputs_positions=None, dtype=np.float32): class MlpBlock(nn.Module): """Transformer MLP block.""" + mlp_dim: int out_dim: Optional[int] = None dropout_rate: float = 0.1 @@ -183,8 +193,8 @@ def __call__(self, inputs, train): kernel_init=self.kernel_init, bias_init=self.bias_init, dtype=self.dtype, - param_dtype=self.dtype)( - inputs) + param_dtype=self.dtype, + )(inputs) x = nn.gelu(x) x = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(x) output = nn.Dense( @@ -192,8 +202,8 @@ def __call__(self, inputs, train): kernel_init=self.kernel_init, bias_init=self.bias_init, dtype=self.dtype, - param_dtype=self.dtype)( - x) + param_dtype=self.dtype, + )(x) output = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(output) return output @@ -201,16 +211,17 @@ def __call__(self, inputs, train): class TransformerSTUHybridBlock(nn.Module): """Transformer layer (https://openreview.net/forum?id=H1e5GJBtDr). - qkv_dim: dimension of the query/key/value - mlp_dim: dimension of the mlp on top of attention block - num_heads: number of heads - dropout_rate: dropout rate - attention_dropout_rate: dropout rate for attention weights - normalizer: One of 'batch_norm', 'layer_norm', 'post_layer_norm', - 'pre_layer_norm', 'none' - attention_fn: Attention function to use. If None, defaults to - nn.dot_product_attention. + qkv_dim: dimension of the query/key/value + mlp_dim: dimension of the mlp on top of attention block + num_heads: number of heads + dropout_rate: dropout rate + attention_dropout_rate: dropout rate for attention weights + normalizer: One of 'batch_norm', 'layer_norm', 'post_layer_norm', + 'pre_layer_norm', 'none' + attention_fn: Attention function to use. If None, defaults to + nn.dot_product_attention. """ + qkv_dim: int mlp_dim: int num_heads: int @@ -227,13 +238,15 @@ class TransformerSTUHybridBlock(nn.Module): zero_init_input: bool = False @nn.compact - def __call__(self, - inputs, - train, - decoder_mask=None, - encoder_decoder_mask=None, - inputs_positions=None, - inputs_segmentation=None): + def __call__( + self, + inputs, + train, + decoder_mask=None, + encoder_decoder_mask=None, + inputs_positions=None, + inputs_segmentation=None, + ): """Applies Transformer1DBlock module. Args: @@ -246,23 +259,29 @@ def __call__(self, Returns: output after transformer block. - """ # Attention block. assert inputs.ndim == 3 if self.normalizer in [ - 'batch_norm', 'layer_norm', 'pre_layer_norm', 'none' + 'batch_norm', + 'layer_norm', + 'pre_layer_norm', + 'none', ]: maybe_pre_normalize = model_utils.get_normalizer( - self.normalizer, train, dtype=self.dtype) + self.normalizer, train, dtype=self.dtype + ) maybe_post_normalize = model_utils.get_normalizer( - 'none', train, dtype=self.dtype) + 'none', train, dtype=self.dtype + ) elif self.normalizer == 'post_layer_norm': maybe_pre_normalize = model_utils.get_normalizer( - 'none', train, dtype=self.dtype) + 'none', train, dtype=self.dtype + ) maybe_post_normalize = model_utils.get_normalizer( - self.normalizer, train, dtype=self.dtype) + self.normalizer, train, dtype=self.dtype + ) else: raise ValueError('Unsupported normalizer: {}'.format(self.normalizer)) @@ -278,9 +297,7 @@ def __call__(self, # Convolve two vectors of length l and truncate to first input_len values. # TODO(dsuo): Need to investigate when input len is shorter than # self.cfg.sequence_length. - tr_conv = lambda x, y: jax.scipy.signal.convolve(x, y)[ - : y.shape[0] - ] + tr_conv = lambda x, y: jax.scipy.signal.convolve(x, y)[: y.shape[0]] # Compute d_in convolutions between [l, d_out] by [l, d_out]. Output shape # is [l, d_out]. @@ -352,7 +369,8 @@ def __call__(self, attention_fn=attention_fn, dropout_rate=self.attention_dropout_rate, normalize_attention=self.normalize_attention, - deterministic=not train)(x, decoder_mask) + deterministic=not train, + )(x, decoder_mask) x = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(x) x = x + stu_outs x = maybe_post_normalize(param_dtype=self.dtype)(x) @@ -360,8 +378,8 @@ def __call__(self, # MLP block. y = maybe_pre_normalize(param_dtype=self.dtype)(x) y = MlpBlock( - mlp_dim=self.mlp_dim, dropout_rate=self.dropout_rate, dtype=self.dtype)( - y, train=train) + mlp_dim=self.mlp_dim, dropout_rate=self.dropout_rate, dtype=self.dtype + )(y, train=train) res = x + y return maybe_post_normalize(param_dtype=self.dtype)(res) @@ -380,9 +398,7 @@ def conv_fft(v: jnp.ndarray, u: jnp.ndarray) -> jnp.ndarray: """ # Convolve two vectors of length l (x.shape[0]) and truncate to the l oldest # values. - tr_conv = lambda x, y: jax.scipy.signal.convolve(x, y)[ - : y.shape[0] - ] + tr_conv = lambda x, y: jax.scipy.signal.convolve(x, y)[: y.shape[0]] # Convolve each sequence of length l in v with each sequence in u. mvconv = jax.vmap(tr_conv, in_axes=(1, None), out_axes=1) @@ -433,25 +449,26 @@ def alternate_sign(inputs: np.ndarray): class TransformerLM(nn.Module): """Transformer Model for language modeling. - vocab_size: size of the vocabulary - emb_dim: dimension of embedding - num_heads: number of heads - num_layers: number of layers - qkv_dim: dimension of the query/key/value - mlp_dim: dimension of the mlp on top of attention block - max_len: maximum length. - train: bool: if model is training. - causal: Whether to apply causal masking. - shift: bool: if we right-shift input - this is only disabled for - fast, looped single-token autoregressive decoding. - dropout_rate: dropout rate - attention_dropout_rate: dropout rate for attention weights - normalizer: One of 'batch_norm', 'layer_norm', 'none' - attention_fn: Attention function to use. If None, defaults to - nn.dot_product_attention. - decode: whether to run in single-position autoregressive mode. - pad_token: Indicates which input tokens are padded. + vocab_size: size of the vocabulary + emb_dim: dimension of embedding + num_heads: number of heads + num_layers: number of layers + qkv_dim: dimension of the query/key/value + mlp_dim: dimension of the mlp on top of attention block + max_len: maximum length. + train: bool: if model is training. + causal: Whether to apply causal masking. + shift: bool: if we right-shift input - this is only disabled for + fast, looped single-token autoregressive decoding. + dropout_rate: dropout rate + attention_dropout_rate: dropout rate for attention weights + normalizer: One of 'batch_norm', 'layer_norm', 'none' + attention_fn: Attention function to use. If None, defaults to + nn.dot_product_attention. + decode: whether to run in single-position autoregressive mode. + pad_token: Indicates which input tokens are padded. """ + vocab_size: int input_len: int num_eigh: int @@ -480,7 +497,7 @@ class TransformerLM(nn.Module): add_negative_vecs: bool = False def setup(self): - path = EIGVEC_PATH+f'{self.input_len}.pkl' + path = EIGVEC_PATH + f'{self.input_len}.pkl' with gfile.Open(path, 'rb') as my_file: eig_vals, eig_vecs = pickle.load(my_file) @@ -488,8 +505,8 @@ def setup(self): logging.info('Eig vals shape: %s', eig_vals.shape) logging.info('Eig vecs shape: %s', eig_vecs.shape) - eig_vals = eig_vals[-self.num_eigh:] - eig_vecs_pos = eig_vecs[:, -self.num_eigh:] + eig_vals = eig_vals[-self.num_eigh :] + eig_vecs_pos = eig_vecs[:, -self.num_eigh :] if self.add_negative_vecs: eig_vecs_neg = alternate_sign(eig_vecs_pos) @@ -506,11 +523,9 @@ def setup(self): self.eigh = expanded_eigh @nn.compact - def __call__(self, - inputs, - train, - inputs_positions=None, - inputs_segmentation=None): + def __call__( + self, inputs, train, inputs_positions=None, inputs_segmentation=None + ): """Applies Transformer model on the inputs. Args: @@ -533,13 +548,16 @@ def __call__(self, else: decoder_mask = nn.combine_masks( nn.make_attention_mask(inputs > 0, inputs > 0, dtype=dtype), - nn.make_causal_mask(inputs, dtype=dtype)) + nn.make_causal_mask(inputs, dtype=dtype), + ) if inputs_segmentation is not None: decoder_mask = nn.combine_masks( decoder_mask, nn.make_attention_mask( - inputs_segmentation, inputs_segmentation, jnp.equal, dtype=dtype)) + inputs_segmentation, inputs_segmentation, jnp.equal, dtype=dtype + ), + ) y = inputs.astype('int32') if not self.decode: @@ -557,7 +575,8 @@ def __call__(self, features=self.emb_dim, dtype=dtype, param_dtype=dtype, - embedding_init=nn.initializers.normal(stddev=1.0)) + embedding_init=nn.initializers.normal(stddev=1.0), + ) else: output_embed = self.shared_embedding @@ -567,8 +586,8 @@ def __call__(self, max_len=self.max_len, posemb_init=sinusoidal_init(max_len=self.max_len), decode=self.decode, - name='posembed_output')( - y, inputs_positions=inputs_positions, dtype=dtype) + name='posembed_output', + )(y, inputs_positions=inputs_positions, dtype=dtype) y = nn.Dropout(rate=self.dropout_rate)(y, deterministic=not train) y = y.astype(dtype) @@ -587,16 +606,19 @@ def __call__(self, eigh=self.eigh, make_stu_residual=self.make_stu_residual, add_stu_norm=self.add_stu_norm, - zero_init_input=self.zero_init_input)( - inputs=y, - train=train, - decoder_mask=decoder_mask, - encoder_decoder_mask=None, - inputs_positions=None, - inputs_segmentation=None,) + zero_init_input=self.zero_init_input, + )( + inputs=y, + train=train, + decoder_mask=decoder_mask, + encoder_decoder_mask=None, + inputs_positions=None, + inputs_segmentation=None, + ) if self.normalizer in ['batch_norm', 'layer_norm', 'pre_layer_norm']: maybe_normalize = model_utils.get_normalizer( - self.normalizer, train, dtype=dtype) + self.normalizer, train, dtype=dtype + ) y = maybe_normalize(param_dtype=dtype)(y) if self.logits_via_embedding: @@ -611,8 +633,8 @@ def __call__(self, bias_init=nn.initializers.normal(stddev=1e-6), dtype=dtype, param_dtype=dtype, - name='logits_dense')( - y) + name='logits_dense', + )(y) return logits.astype(dtype) diff --git a/init2winit/model_lib/unet.py b/init2winit/model_lib/unet.py index b12b4982..8ff400dd 100644 --- a/init2winit/model_lib/unet.py +++ b/init2winit/model_lib/unet.py @@ -32,7 +32,6 @@ import jax.numpy as jnp from ml_collections import config_dict - DEFAULT_HPARAMS = config_dict.ConfigDict( dict( out_chans=1, @@ -42,7 +41,8 @@ activation='leaky_relu', model_dtype='float32', normalizer='unet_instance_norm', - )) + ) +) def _compute_stats(x, axes): @@ -53,7 +53,7 @@ def _compute_stats(x, axes): mean2 = jnp.mean(jnp.square(x), axes) # mean2 - _abs_sq(mean) is not guaranteed to be non-negative due # to floating point round-off errors. - var = jnp.maximum(0., mean2 - jnp.square(mean)) + var = jnp.maximum(0.0, mean2 - jnp.square(mean)) return mean, var @@ -77,16 +77,17 @@ def _simple_instance_norm2d(x, axes, epsilon=1e-5): class UNet(nn.Module): """Jax / Flax implementation of a U-Net model. - O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks - for biomedical image segmentation. In International Conference on Medical - image computing and computer-assisted intervention, pages 234–241. - Springer, 2015. + O. Ronneberger, P. Fischer, and Thomas Brox. U-net: Convolutional networks + for biomedical image segmentation. In International Conference on Medical + image computing and computer-assisted intervention, pages 234–241. + Springer, 2015. - out_chans: Number of channels in the output to the U-Net model. - chans: Number of output channels of the first convolution layer. - num_pool_layers: Number of down-sampling and up-sampling layers. - drop_prob: Dropout probability. + out_chans: Number of channels in the output to the U-Net model. + chans: Number of output channels of the first convolution layer. + num_pool_layers: Number of down-sampling and up-sampling layers. + drop_prob: Dropout probability. """ + out_chans: int chans: int = 32 num_pool_layers: int = 4 @@ -166,6 +167,7 @@ class ConvBlock(nn.Module): out_chans: Number of channels in the output. drop_prob: Dropout probability. """ + out_chans: int drop_prob: float activation: str @@ -188,8 +190,8 @@ def __call__(self, x, train=True): features=self.out_chans, kernel_size=(3, 3), strides=(1, 1), - use_bias=False)( - x) + use_bias=False, + )(x) # InstanceNorm2d was implemented with no learnable params in reference code # so this is a simple normalization along channels if self.normalizer == 'unet_instance_norm': @@ -208,14 +210,14 @@ def __call__(self, x, train=True): # Ref code uses dropout2d which applies the same mask for the entire channel # Replicated by using broadcast dims to have the same filter on HW x = nn.Dropout( - self.drop_prob, broadcast_dims=(1, 2), deterministic=not train)( - x) + self.drop_prob, broadcast_dims=(1, 2), deterministic=not train + )(x) x = nn.Conv( features=self.out_chans, kernel_size=(3, 3), strides=(1, 1), - use_bias=False)( - x) + use_bias=False, + )(x) # InstanceNorm2d was implemented with no learnable params in reference code # so this is a simple normalization along channels if self.normalizer == 'unet_instance_norm': @@ -232,8 +234,8 @@ def __call__(self, x, train=True): else: raise ValueError('Unsupported activation: {}'.format(self.activation)) x = nn.Dropout( - self.drop_prob, broadcast_dims=(1, 2), deterministic=not train)( - x) + self.drop_prob, broadcast_dims=(1, 2), deterministic=not train + )(x) return x @@ -243,6 +245,7 @@ class TransposeConvBlock(nn.Module): out_chans: Number of channels in the output. """ + out_chans: int activation: str @@ -257,8 +260,8 @@ def __call__(self, x): jnp.array: Output tensor of shape `(N, H*2, W*2, out_chans)`. """ x = nn.ConvTranspose( - self.out_chans, kernel_size=(2, 2), strides=(2, 2), use_bias=False)( - x) + self.out_chans, kernel_size=(2, 2), strides=(2, 2), use_bias=False + )(x) x = _simple_instance_norm2d(x, (1, 2)) if self.activation == 'leaky_relu': x = jax.nn.leaky_relu(x, negative_slope=0.2) @@ -277,7 +280,8 @@ def evaluate_batch(self, params, batch_stats, batch): """Evaluates metrics under self.metrics_name on the given_batch.""" variables = {'params': params, 'batch_stats': batch_stats} logits = self.flax_module.apply( - variables, batch['inputs'], mutable=False, train=False) + variables, batch['inputs'], mutable=False, train=False + ) targets = batch['targets'] # map the dict values (which are functions) to function(targets, logits) @@ -295,7 +299,8 @@ def evaluate_batch(self, params, batch_stats, batch): mean=batch.get('mean'), std=batch.get('std'), volume_max=batch.get('volume_max'), - axis_name='batch') + axis_name='batch', + ) def build_flax_module(self): """Unet implementation.""" @@ -305,7 +310,8 @@ def build_flax_module(self): num_pool_layers=self.hps.num_pool_layers, drop_prob=self.hps.dropout_rate, activation=self.hps.activation, - normalizer=self.hps.normalizer) + normalizer=self.hps.normalizer, + ) def get_fake_inputs(self, hps): """Helper method solely for the purpose of initializing the model.""" diff --git a/init2winit/model_lib/vit.py b/init2winit/model_lib/vit.py index 275f9dcc..a12b1773 100644 --- a/init2winit/model_lib/vit.py +++ b/init2winit/model_lib/vit.py @@ -28,7 +28,6 @@ from ml_collections.config_dict import config_dict import numpy as np - # NOTE(dsuo): could be useful to have a `base_config` for models as well. DEFAULT_HPARAMS = config_dict.ConfigDict( dict( @@ -53,17 +52,18 @@ layer_norm_struct=None, attn_temperature=1.0, use_glu=False, - )) + ) +) -def posemb_sincos_2d(h, w, width, temperature=10_000., dtype=jnp.float32): +def posemb_sincos_2d(h, w, width, temperature=10_000.0, dtype=jnp.float32): """Follows the MoCo v3 logic.""" y, x = jnp.mgrid[:h, :w] if width % 4 != 0: raise ValueError('Width must be mult of 4 for sincos posemb.') omega = jnp.arange(width // 4) / (width // 4 - 1) - omega = 1. / (temperature ** omega) + omega = 1.0 / (temperature**omega) y = jnp.einsum('m,d->md', y.flatten(), omega) x = jnp.einsum('m,d->md', x.flatten(), omega) pe = jnp.concatenate([jnp.sin(x), jnp.cos(x), jnp.sin(y), jnp.cos(y)], axis=1) @@ -72,26 +72,32 @@ def posemb_sincos_2d(h, w, width, temperature=10_000., dtype=jnp.float32): def get_posemb(self, emb_type, seqshape, width, name, dtype=jnp.float32): if emb_type == 'learn': - return self.param(name, nn.initializers.normal(stddev=1 / np.sqrt(width)), - (1, np.prod(seqshape), width), dtype) + return self.param( + name, + nn.initializers.normal(stddev=1 / np.sqrt(width)), + (1, np.prod(seqshape), width), + dtype, + ) elif emb_type == 'sincos2d': return posemb_sincos_2d(*seqshape, width, dtype=dtype) else: raise ValueError(f'Unknown posemb type: {emb_type}') -def dot_product_attention(query, - key, - value, - bias=None, - mask=None, - broadcast_dropout=True, - dropout_rng=None, - dropout_rate=0., - deterministic=False, - dtype=jnp.float32, - precision=None, - temperature=1.0): +def dot_product_attention( + query, + key, + value, + bias=None, + mask=None, + broadcast_dropout=True, + dropout_rng=None, + dropout_rate=0.0, + deterministic=False, + dtype=jnp.float32, + precision=None, + temperature=1.0, +): """Computes dot-product attention given query, key, and value. This is the core function for applying attention based on @@ -101,21 +107,19 @@ def dot_product_attention(query, Note: query, key, value needn't have any batch dimensions. Args: - query: queries for calculating attention with shape of - `[batch..., q_length, num_heads, qk_depth_per_head]`. - key: keys for calculating attention with shape of - `[batch..., kv_length, num_heads, qk_depth_per_head]`. - value: values to be used in attention with shape of - `[batch..., kv_length, num_heads, v_depth_per_head]`. + query: queries for calculating attention with shape of `[batch..., q_length, + num_heads, qk_depth_per_head]`. + key: keys for calculating attention with shape of `[batch..., kv_length, + num_heads, qk_depth_per_head]`. + value: values to be used in attention with shape of `[batch..., kv_length, + num_heads, v_depth_per_head]`. bias: bias for the attention weights. This should be broadcastable to the - shape `[batch..., num_heads, q_length, kv_length]`. - This can be used for incorporating causal masks, padding masks, - proximity bias, etc. + shape `[batch..., num_heads, q_length, kv_length]`. This can be used for + incorporating causal masks, padding masks, proximity bias, etc. mask: mask for the attention weights. This should be broadcastable to the - shape `[batch..., num_heads, q_length, kv_length]`. - This can be used for incorporating causal masks. - Attention weights are masked out if their corresponding mask value - is `False`. + shape `[batch..., num_heads, q_length, kv_length]`. This can be used for + incorporating causal masks. Attention weights are masked out if their + corresponding mask value is `False`. broadcast_dropout: bool: use a broadcasted dropout along batch dims. dropout_rng: JAX PRNGKey: to be used for dropout dropout_rate: dropout rate @@ -129,24 +133,40 @@ def dot_product_attention(query, Output of shape `[batch..., q_length, num_heads, v_depth_per_head]`. """ assert key.ndim == query.ndim == value.ndim, 'q, k, v must have same rank.' - assert query.shape[:-3] == key.shape[:-3] == value.shape[:-3], ( - 'q, k, v batch dims must match.') - assert query.shape[-2] == key.shape[-2] == value.shape[-2], ( - 'q, k, v num_heads must match.') + assert ( + query.shape[:-3] == key.shape[:-3] == value.shape[:-3] + ), 'q, k, v batch dims must match.' + assert ( + query.shape[-2] == key.shape[-2] == value.shape[-2] + ), 'q, k, v num_heads must match.' assert key.shape[-3] == value.shape[-3], 'k, v lengths must match.' # compute attention weights attn_weights = nn.dot_product_attention_weights( - query, key, bias, mask, broadcast_dropout, dropout_rng, dropout_rate, - deterministic, dtype, precision) + query, + key, + bias, + mask, + broadcast_dropout, + dropout_rng, + dropout_rate, + deterministic, + dtype, + precision, + ) # return weighted sum over values for each query position - return jnp.einsum('...hqk,...khd->...qhd', attn_weights, value, - precision=precision) * temperature + return ( + jnp.einsum( + '...hqk,...khd->...qhd', attn_weights, value, precision=precision + ) + * temperature + ) class MlpBlock(nn.Module): """Transformer MLP / feed-forward block.""" + mlp_dim: Optional[int] = None # Defaults to 4x input dim dropout: float = 0.0 activation: str = 'gelu' @@ -168,9 +188,7 @@ def __call__(self, x, train=True): raise ValueError('Unsupported activation: {}'.format(self.activation)) if self.use_glu: - y = nn.Dense( - self.mlp_dim, - **inits)(x) + y = nn.Dense(self.mlp_dim, **inits)(x) x = x * y x = nn.Dropout(rate=self.dropout)(x, train) @@ -180,6 +198,7 @@ def __call__(self, x, train=True): class Encoder1DBlock(nn.Module): """Single transformer encoder block (MHSA + MLP).""" + mlp_dim: Optional[int] = None # Defaults to 4x input dim num_heads: int = 12 dropout: float = 0.0 @@ -224,13 +243,14 @@ def __call__(self, x, train=True): in_axis=-2, out_axis=-1, batch_axis=(), - dtype=jnp.float_), + dtype=jnp.float_, + ), deterministic=train, attention_fn=attn_fn, name='MultiHeadDotProductAttention_1', )(y) y = nn.Dropout(rate=self.dropout)(y, train) - x = 2 * (self.residual_alpha * x + (1-self.residual_alpha) * y) + x = 2 * (self.residual_alpha * x + (1 - self.residual_alpha) * y) x = maybe_post_normalize()(x) out['+sa'] = x @@ -240,10 +260,10 @@ def __call__(self, x, train=True): dropout=self.dropout, name='MlpBlock_3', activation=self.activation, - use_glu=self.use_glu + use_glu=self.use_glu, )(y, train) y = nn.Dropout(rate=self.dropout)(y, train) - x = 2*(self.residual_alpha*x+(1-self.residual_alpha)*y) + x = 2 * (self.residual_alpha * x + (1 - self.residual_alpha) * y) x = maybe_post_normalize()(x) if self.resnet_style_residual: activation_fn = model_utils.ACTIVATIONS[self.activation] @@ -255,6 +275,7 @@ def __call__(self, x, train=True): class Encoder(nn.Module): """Transformer Model Encoder for sequence to sequence translation.""" + depth: int mlp_dim: Optional[int] = None # Defaults to 4x input dim num_heads: int = 12 @@ -286,9 +307,10 @@ def __call__(self, x, train=True): residual_alpha=self.residual_alpha, scale_attention_init=self.scale_attention_init, attn_temperature=self.attn_temperature, - use_glu=self.use_glu) + use_glu=self.use_glu, + ) else: - assert(len(self.layer_norm_struct)) == self.depth + assert (len(self.layer_norm_struct)) == self.depth block = Encoder1DBlock( name=f'encoderblock_{lyr}', mlp_dim=self.mlp_dim, @@ -300,7 +322,8 @@ def __call__(self, x, train=True): residual_alpha=self.residual_alpha, scale_attention_init=self.scale_attention_init, attn_temperature=self.attn_temperature, - use_glu=self.use_glu) + use_glu=self.use_glu, + ) x, out[f'block{lyr:02d}'] = block(x, train) out['pre_ln'] = x # Alias for last block, but without the number in it. @@ -312,32 +335,39 @@ def __call__(self, x, train=True): class MAPHead(nn.Module): """Multihead Attention Pooling.""" + mlp_dim: Optional[int] = None # Defaults to 4x input dim num_heads: int = 12 normalizer: str = 'pre_layer_norm' + @nn.compact def __call__(self, x): # TODO(lbeyer): condition on GAP(x) n, _, d = x.shape - probe = self.param('probe', nn.initializers.xavier_uniform(), - (1, 1, d), x.dtype) + probe = self.param( + 'probe', nn.initializers.xavier_uniform(), (1, 1, d), x.dtype + ) probe = jnp.tile(probe, [n, 1, 1]) x = nn.MultiHeadDotProductAttention( - num_heads=self.num_heads, - kernel_init=nn.initializers.xavier_uniform())(probe, x) + num_heads=self.num_heads, kernel_init=nn.initializers.xavier_uniform() + )(probe, x) # TODO(lbeyer): dropout on head? if self.normalizer == 'pre_layer_norm': maybe_pre_normalize = model_utils.get_normalizer( - self.normalizer, train=True, dtype=None) + self.normalizer, train=True, dtype=None + ) maybe_post_normalize = model_utils.get_normalizer( - 'none', train=True, dtype=None) + 'none', train=True, dtype=None + ) elif self.normalizer == 'post_layer_norm': maybe_pre_normalize = model_utils.get_normalizer( - 'none', train=True, dtype=None) + 'none', train=True, dtype=None + ) maybe_post_normalize = model_utils.get_normalizer( - self.normalizer, train=True, dtype=None) + self.normalizer, train=True, dtype=None + ) else: raise ValueError('Unsupported normalizer: {}'.format(self.normalizer)) y = maybe_pre_normalize()(x) @@ -369,21 +399,27 @@ class ViT(nn.Module): layer_norm_struct: Sequence[str] = None attn_temperature: float = 1.0 use_glu: bool = False + @nn.compact def __call__(self, x, *, train=False): out = {} # Patch extraction x = out['stem'] = nn.Conv( - self.width, self.patch_size, strides=self.patch_size, - padding='VALID', name='conv_patch_extract')(x) + self.width, + self.patch_size, + strides=self.patch_size, + padding='VALID', + name='conv_patch_extract', + )(x) n, h, w, c = x.shape x = jnp.reshape(x, [n, h * w, c]) # Add posemb before adding extra token. x = out['with_posemb'] = x + get_posemb( - self, self.posemb, (h, w), c, 'pos_embedding', x.dtype) + self, self.posemb, (h, w), c, 'pos_embedding', x.dtype + ) if self.pool_type == 'tok': cls = self.param('cls', nn.initializers.zeros, (1, 1, c), x.dtype) @@ -405,16 +441,16 @@ def __call__(self, x, *, train=False): layer_norm_struct=self.layer_norm_struct, attn_temperature=self.attn_temperature, use_glu=self.use_glu, - name='Transformer')( - x, train=not train) + name='Transformer', + )(x, train=not train) encoded = out['encoded'] = x if self.pool_type == 'map': x = out['head_input'] = MAPHead( num_heads=self.num_heads, mlp_dim=self.mlp_dim, - normalizer=self.normalizer)( - x) + normalizer=self.normalizer, + )(x) elif self.pool_type == 'gap': x = out['head_input'] = jnp.mean(x, axis=1) elif self.pool_type == '0': @@ -455,11 +491,25 @@ def build_flax_module(self): """Vision transformer.""" keys = [ - 'num_classes', 'rep_size', 'pool_type', 'posemb', 'width', 'depth', - 'mlp_dim', 'num_heads', 'patch_size', 'dropout_rate', 'head_zeroinit', - 'normalizer', 'activation', 'resnet_style_residual', 'residual_alpha', - 'scale_attention_init', 'layer_norm_struct', 'attn_temperature', - 'use_glu' + 'num_classes', + 'rep_size', + 'pool_type', + 'posemb', + 'width', + 'depth', + 'mlp_dim', + 'num_heads', + 'patch_size', + 'dropout_rate', + 'head_zeroinit', + 'normalizer', + 'activation', + 'resnet_style_residual', + 'residual_alpha', + 'scale_attention_init', + 'layer_norm_struct', + 'attn_temperature', + 'use_glu', ] args = {k: self.hps[k] for k in keys} @@ -492,10 +542,50 @@ def decode_variant(variant): return { # pylint:disable=line-too-long # Reference: Table 2 of https://arxiv.org/abs/2106.04560. - 'width': {'Ti': 192, 'S': 384, 'V': 384, 'M': 512, 'B': 768, 'L': 1024, 'H': 1280, 'g': 1408, 'G': 1664}[v], - 'depth': {'Ti': 12, 'S': 12, 'V': 12, 'M': 12, 'B': 12, 'L': 24, 'H': 32, 'g': 40, 'G': 48}[v], - 'mlp_dim': {'Ti': 768, 'S': 1536, 'V': 1152, 'M': 2048, 'B': 3072, 'L': 4096, 'H': 5120, 'g': 6144, 'G': 8192}[v], - 'num_heads': {'Ti': 3, 'S': 6, 'V': 6, 'M': 8, 'B': 12, 'L': 16, 'H': 16, 'g': 16, 'G': 16}[v], + 'width': { + 'Ti': 192, + 'S': 384, + 'V': 384, + 'M': 512, + 'B': 768, + 'L': 1024, + 'H': 1280, + 'g': 1408, + 'G': 1664, + }[v], + 'depth': { + 'Ti': 12, + 'S': 12, + 'V': 12, + 'M': 12, + 'B': 12, + 'L': 24, + 'H': 32, + 'g': 40, + 'G': 48, + }[v], + 'mlp_dim': { + 'Ti': 768, + 'S': 1536, + 'V': 1152, + 'M': 2048, + 'B': 3072, + 'L': 4096, + 'H': 5120, + 'g': 6144, + 'G': 8192, + }[v], + 'num_heads': { + 'Ti': 3, + 'S': 6, + 'V': 6, + 'M': 8, + 'B': 12, + 'L': 16, + 'H': 16, + 'g': 16, + 'G': 16, + }[v], # pylint:enable=line-too-long - 'patch_size': (int(patch), int(patch)) + 'patch_size': (int(patch), int(patch)), } diff --git a/init2winit/model_lib/wide_resnet.py b/init2winit/model_lib/wide_resnet.py index f9379b1f..136537ac 100644 --- a/init2winit/model_lib/wide_resnet.py +++ b/init2winit/model_lib/wide_resnet.py @@ -14,6 +14,7 @@ # limitations under the License. """Wide Resnet Model.""" + from typing import List, Optional, Tuple from flax import linen as nn @@ -21,10 +22,8 @@ from init2winit.model_lib import model_utils from jax.nn import initializers import jax.numpy as jnp - from ml_collections.config_dict import config_dict - # small hparams used for unit tests DEFAULT_HPARAMS = config_dict.ConfigDict( dict( @@ -39,12 +38,14 @@ virtual_batch_size=None, model_dtype='float32', activation_function='relu', - group_strides=[(1, 1), (2, 2), (2, 2)]) + group_strides=[(1, 1), (2, 2), (2, 2)], + ) ) class WideResnetBlock(nn.Module): """Defines a single WideResnetBlock.""" + channels: int strides: List[Tuple[int]] conv_kernel_init: model_utils.Initializer = initializers.lecun_normal() @@ -62,7 +63,8 @@ def __call__(self, x, train): train, batch_size=self.batch_size, virtual_batch_size=self.virtual_batch_size, - total_batch_size=self.total_batch_size) + total_batch_size=self.total_batch_size, + ) y = maybe_normalize(name='bn1')(x) y = model_utils.ACTIVATIONS[self.activation_function](y) @@ -74,7 +76,8 @@ def __call__(self, x, train): self.strides, padding='SAME', kernel_init=self.conv_kernel_init, - use_bias=False)(y) + use_bias=False, + )(y) y = nn.Conv( self.channels, @@ -83,7 +86,8 @@ def __call__(self, x, train): padding='SAME', name='conv1', kernel_init=self.conv_kernel_init, - use_bias=False)(y) + use_bias=False, + )(y) y = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(y) y = maybe_normalize(name='bn2')(y) y = model_utils.ACTIVATIONS[self.activation_function](y) @@ -93,7 +97,8 @@ def __call__(self, x, train): padding='SAME', name='conv2', kernel_init=self.conv_kernel_init, - use_bias=False)(y) + use_bias=False, + )(y) if self.normalizer == 'none': y = model_utils.ScalarMultiply()(y) @@ -103,6 +108,7 @@ def __call__(self, x, train): class WideResnetGroup(nn.Module): """Defines a WideResnetGroup.""" + blocks_per_group: int channels: int strides: Tuple[int, int] = (1, 1) @@ -126,12 +132,14 @@ def __call__(self, x, train): activation_function=self.activation_function, batch_size=self.batch_size, virtual_batch_size=self.virtual_batch_size, - total_batch_size=self.total_batch_size)(x, train=train) + total_batch_size=self.total_batch_size, + )(x, train=train) return x class WideResnet(nn.Module): """Defines the WideResnet Model.""" + blocks_per_group: int channel_multiplier: int group_strides: List[Tuple[int]] @@ -153,7 +161,8 @@ def __call__(self, x, train): padding='SAME', name='init_conv', kernel_init=self.conv_kernel_init, - use_bias=False)(x) + use_bias=False, + )(x) x = WideResnetGroup( self.blocks_per_group, 16 * self.channel_multiplier, @@ -164,7 +173,8 @@ def __call__(self, x, train): activation_function=self.activation_function, batch_size=self.batch_size, virtual_batch_size=self.virtual_batch_size, - total_batch_size=self.total_batch_size)(x, train=train) + total_batch_size=self.total_batch_size, + )(x, train=train) x = WideResnetGroup( self.blocks_per_group, 32 * self.channel_multiplier, @@ -175,7 +185,8 @@ def __call__(self, x, train): activation_function=self.activation_function, batch_size=self.batch_size, virtual_batch_size=self.virtual_batch_size, - total_batch_size=self.total_batch_size)(x, train=train) + total_batch_size=self.total_batch_size, + )(x, train=train) x = WideResnetGroup( self.blocks_per_group, 64 * self.channel_multiplier, @@ -186,13 +197,15 @@ def __call__(self, x, train): activation_function=self.activation_function, batch_size=self.batch_size, virtual_batch_size=self.virtual_batch_size, - total_batch_size=self.total_batch_size)(x, train=train) + total_batch_size=self.total_batch_size, + )(x, train=train) maybe_normalize = model_utils.get_normalizer( self.normalizer, train, batch_size=self.batch_size, virtual_batch_size=self.virtual_batch_size, - total_batch_size=self.total_batch_size) + total_batch_size=self.total_batch_size, + ) x = maybe_normalize()(x) x = model_utils.ACTIVATIONS[self.activation_function](x) x = nn.avg_pool(x, (8, 8)) @@ -212,15 +225,18 @@ def build_flax_module(self): group_strides=self.hps.group_strides, num_outputs=self.hps['output_shape'][-1], conv_kernel_init=model_utils.INITIALIZERS[self.hps.conv_kernel_init]( - self.hps.conv_kernel_scale), + self.hps.conv_kernel_scale + ), dense_kernel_init=model_utils.INITIALIZERS[self.hps.dense_kernel_init]( - self.hps.dense_kernel_scale), + self.hps.dense_kernel_scale + ), dropout_rate=self.hps.dropout_rate, normalizer=self.hps.normalizer, activation_function=self.hps.activation_function, batch_size=self.hps.batch_size, virtual_batch_size=self.hps.virtual_batch_size, - total_batch_size=self.hps.total_accumulated_batch_size) + total_batch_size=self.hps.total_accumulated_batch_size, + ) def get_fake_inputs(self, hps): """Helper method solely for the purpose of initialzing the model.""" diff --git a/init2winit/model_lib/xformer_translate.py b/init2winit/model_lib/xformer_translate.py index ff7e5984..445ff61b 100644 --- a/init2winit/model_lib/xformer_translate.py +++ b/init2winit/model_lib/xformer_translate.py @@ -23,10 +23,10 @@ Transformer, using remat_scan with configuration (3, 3) results in a 30% time increase in the backward pass. """ + import functools from typing import Any, Callable, Optional, Sequence from absl import logging - from flax import linen as nn from init2winit import utils from init2winit.model_lib import attention @@ -39,7 +39,6 @@ from ml_collections.config_dict import config_dict import numpy as np - MLCOMMONS_DEFAULT_HPARAMS = config_dict.ConfigDict( dict( share_embeddings=False, @@ -62,7 +61,8 @@ dec_cross_attn_kernel_init='xavier_uniform', decode=False, normalize_attention=False, - )) + ) +) DEFAULT_HPARAMS = config_dict.ConfigDict( @@ -86,7 +86,8 @@ dec_cross_attn_kernel_init='xavier_uniform', decode=False, normalize_attention=False, - )) + ) +) class Scannable(nn.Module): @@ -97,6 +98,7 @@ class Scannable(nn.Module): input to a layer is of the form x, *others where x is changed by the layer and *others are extra arguments static throughout the layers. """ + build_fn: Callable[[], nn.Module] train: False @@ -107,14 +109,14 @@ def __call__(self, x): """Applies the Module to inputs. Args: - x: the inputs to the module. It is assumed to be a tuple of pytrees. - The first element of the tuple is mapped by self.block into an output - of the same structure (e.g. the Decoder activations fed to the - Encoder-Decoder Multi-Head-Attention). - The other elements are static arguments used by self.block that would - stay the same if we apply multiple self.block's one after the other - (e.g. the encoder output used by the Encoder-Decoder + x: the inputs to the module. It is assumed to be a tuple of pytrees. The + first element of the tuple is mapped by self.block into an output of the + same structure (e.g. the Decoder activations fed to the Encoder-Decoder + Multi-Head-Attention). The other elements are static arguments used by + self.block that would stay the same if we apply multiple self.block's + one after the other (e.g. the encoder output used by the Encoder-Decoder Multi-Head-Attention). + Returns: self.block(x[0], *x[1:]), *x[1:]. """ @@ -136,7 +138,8 @@ def shift_right(x, axis=1): pad_widths = [(0, 0)] * len(x.shape) pad_widths[axis] = (1, 0) padded = jnp.pad( - x, pad_widths, mode='constant', constant_values=x.dtype.type(0)) + x, pad_widths, mode='constant', constant_values=x.dtype.type(0) + ) return padded[:, :-1] @@ -160,8 +163,8 @@ def init(key, shape, dtype=np.float32): position = np.arange(0, max_len)[:, np.newaxis] scale_factor = -np.log(max_scale / min_scale) / (d_feature // 2 - 1) div_term = min_scale * np.exp(np.arange(0, d_feature // 2) * scale_factor) - pe[:, :d_feature // 2] = np.sin(position * div_term) - pe[:, d_feature // 2:2 * (d_feature // 2)] = np.cos(position * div_term) + pe[:, : d_feature // 2] = np.sin(position * div_term) + pe[:, d_feature // 2 : 2 * (d_feature // 2)] = np.cos(position * div_term) pe = pe[np.newaxis, :, :] # [1, max_len, d_feature] return jnp.array(pe) @@ -176,16 +179,15 @@ class AddPositionEmbs(nn.Module): (non-learned) sinusoidal embedding table. decode: whether to use an autoregressive cache. """ + max_len: int = 512 posemb_init: Optional[model_utils.Initializer] = None decode: bool = False @nn.compact - def __call__(self, - inputs, - inputs_position=None, - train=True, - dtype=np.float32): + def __call__( + self, inputs, inputs_position=None, train=True, dtype=np.float32 + ): """Applies AddPositionEmbs module. By default this layer uses a fixed sinusoidal embedding table. If a @@ -203,22 +205,26 @@ def __call__(self, """ del train # inputs.shape is (batch_size, seq_len, emb_dim) - assert inputs.ndim == 3, ('Number of dimensions should be 3,' - ' but it is: %d' % inputs.ndim) + assert inputs.ndim == 3, ( + 'Number of dimensions should be 3, but it is: %d' % inputs.ndim + ) length = inputs.shape[1] pos_emb_shape = (1, self.max_len, inputs.shape[-1]) if self.posemb_init is None: # Use a fixed (non-learned) sinusoidal position embedding. - pos_embedding = sinusoidal_init(max_len=self.max_len)(None, pos_emb_shape, - dtype) + pos_embedding = sinusoidal_init(max_len=self.max_len)( + None, pos_emb_shape, dtype + ) else: pos_embedding = self.param( - 'pos_embedding', self.posemb_init, pos_emb_shape, dtype) + 'pos_embedding', self.posemb_init, pos_emb_shape, dtype + ) pe = pos_embedding[:, :length, :] if self.decode: is_initialized = self.has_variable('cache', 'cache_index') - cache_index = self.variable('cache', 'cache_index', - lambda: jnp.array(0, dtype=jnp.uint32)) + cache_index = self.variable( + 'cache', 'cache_index', lambda: jnp.array(0, dtype=jnp.uint32) + ) if is_initialized: i = cache_index.value cache_index.value = i + 1 @@ -234,6 +240,7 @@ def __call__(self, class MlpBlock(nn.Module): """Transformer MLP / feed-forward block.""" + mlp_dim: int dtype: model_utils.Dtype = jnp.float32 out_dim: Optional[int] = None @@ -250,8 +257,8 @@ def __call__(self, inputs, train): dtype=self.dtype, param_dtype=self.dtype, kernel_init=self.kernel_init, - bias_init=self.bias_init)( - inputs) + bias_init=self.bias_init, + )(inputs) x = nn.relu(x) x = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(x) output = nn.Dense( @@ -259,10 +266,9 @@ def __call__(self, inputs, train): dtype=self.dtype, param_dtype=self.dtype, kernel_init=self.kernel_init, - bias_init=self.bias_init)( - x) - output = nn.Dropout( - rate=self.dropout_rate, deterministic=not train)(output) + bias_init=self.bias_init, + )(x) + output = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(output) return output @@ -279,10 +285,11 @@ class Encoder1DBlock(nn.Module): normalizer: One of 'batch_norm', 'layer_norm', 'post_layer_norm', 'pre_layer_norm', 'none' normalize_attention: Apply LayerNorm to query and key before computing - dot_product_attention. - enc_self_attn_kernel_init_fn: initializer for encoder's - self attention matrices. + dot_product_attention. + enc_self_attn_kernel_init_fn: initializer for encoder's self attention + matrices. """ + qkv_dim: int mlp_dim: int num_heads: int @@ -294,10 +301,7 @@ class Encoder1DBlock(nn.Module): enc_self_attn_kernel_init_fn: model_utils.Initializer = initializers.xavier_uniform() # pylint: disable=line-too-long @nn.compact - def __call__(self, - inputs, - encoder_mask=None, - train=True): + def __call__(self, inputs, encoder_mask=None, train=True): """Applies Encoder1DBlock module. Args: @@ -311,16 +315,24 @@ def __call__(self, # Attention block. assert inputs.ndim == 3 if self.normalizer in [ - 'batch_norm', 'layer_norm', 'pre_layer_norm', 'none']: + 'batch_norm', + 'layer_norm', + 'pre_layer_norm', + 'none', + ]: maybe_pre_normalize = model_utils.get_normalizer( - self.normalizer, train, dtype=self.dtype) + self.normalizer, train, dtype=self.dtype + ) maybe_post_normalize = model_utils.get_normalizer( - 'none', train, dtype=self.dtype) + 'none', train, dtype=self.dtype + ) elif self.normalizer == 'post_layer_norm': maybe_pre_normalize = model_utils.get_normalizer( - 'none', train, dtype=self.dtype) + 'none', train, dtype=self.dtype + ) maybe_post_normalize = model_utils.get_normalizer( - self.normalizer, train, dtype=self.dtype) + self.normalizer, train, dtype=self.dtype + ) else: raise ValueError('Unsupported normalizer: {}'.format(self.normalizer)) @@ -336,8 +348,8 @@ def __call__(self, broadcast_dropout=False, dropout_rate=self.attention_dropout_rate, normalize_attention=self.normalize_attention, - name='EncoderSelfAttention')( - x, mask=encoder_mask, deterministic=not train) + name='EncoderSelfAttention', + )(x, mask=encoder_mask, deterministic=not train) x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not train) x = x + inputs @@ -349,7 +361,8 @@ def __call__(self, mlp_dim=self.mlp_dim, dtype=self.dtype, dropout_rate=self.dropout_rate, - name='MLPBlock')(y, train=train) + name='MLPBlock', + )(y, train=train) res = x + y return maybe_post_normalize(param_dtype=self.dtype)(res) @@ -368,13 +381,14 @@ class EncoderDecoder1DBlock(nn.Module): normalizer: One of 'batch_norm', 'layer_norm', 'post_layer_norm', 'pre_layer_norm', 'none' normalize_attention: Apply LayerNorm to query and key before computing - dot_product_attention. - dec_self_attn_kernel_init_fn: initializer for decoder's - self attention matrices. - dec_cross_attn_kernel_init_fn: initializer for decoder's - cross attention matrices. + dot_product_attention. + dec_self_attn_kernel_init_fn: initializer for decoder's self attention + matrices. + dec_cross_attn_kernel_init_fn: initializer for decoder's cross attention + matrices. decode: whether to use an autoregressive cache. """ + qkv_dim: int mlp_dim: int num_heads: int @@ -388,12 +402,14 @@ class EncoderDecoder1DBlock(nn.Module): decode: bool = False @nn.compact - def __call__(self, - targets, - encoded, - decoder_mask=None, - encoder_decoder_mask=None, - train=True): + def __call__( + self, + targets, + encoded, + decoder_mask=None, + encoder_decoder_mask=None, + train=True, + ): """Applies EncoderDecoder1DBlock module. Args: @@ -409,16 +425,24 @@ def __call__(self, # Decoder block. assert targets.ndim == 3 if self.normalizer in [ - 'batch_norm', 'layer_norm', 'pre_layer_norm', 'none']: + 'batch_norm', + 'layer_norm', + 'pre_layer_norm', + 'none', + ]: maybe_pre_normalize = model_utils.get_normalizer( - self.normalizer, train, dtype=self.dtype) + self.normalizer, train, dtype=self.dtype + ) maybe_post_normalize = model_utils.get_normalizer( - 'none', train, dtype=self.dtype) + 'none', train, dtype=self.dtype + ) elif self.normalizer == 'post_layer_norm': maybe_pre_normalize = model_utils.get_normalizer( - 'none', train, dtype=self.dtype) + 'none', train, dtype=self.dtype + ) maybe_post_normalize = model_utils.get_normalizer( - self.normalizer, train, dtype=self.dtype) + self.normalizer, train, dtype=self.dtype + ) else: raise ValueError('Unsupported normalizer: {}'.format(self.normalizer)) @@ -435,8 +459,8 @@ def __call__(self, dropout_rate=self.attention_dropout_rate, decode=self.decode, name='DecoderSelfAttention', - normalize_attention=self.normalize_attention)( - x, decoder_mask, deterministic=not train) + normalize_attention=self.normalize_attention, + )(x, decoder_mask, deterministic=not train) x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not train) x = x + targets @@ -453,11 +477,10 @@ def __call__(self, use_bias=False, broadcast_dropout=False, dropout_rate=self.attention_dropout_rate, - normalize_attention=self.normalize_attention)( - y, encoded, encoder_decoder_mask, deterministic=not train) + normalize_attention=self.normalize_attention, + )(y, encoded, encoder_decoder_mask, deterministic=not train) - y = nn.Dropout(rate=self.dropout_rate)( - y, deterministic=not train) + y = nn.Dropout(rate=self.dropout_rate)(y, deterministic=not train) y = y + x y = maybe_post_normalize(param_dtype=self.dtype)(y) @@ -467,7 +490,8 @@ def __call__(self, mlp_dim=self.mlp_dim, dtype=self.dtype, dropout_rate=self.dropout_rate, - name='MLPBlock')(z, train=train) + name='MLPBlock', + )(z, train=train) res = y + z return maybe_post_normalize(param_dtype=self.dtype)(res) @@ -476,27 +500,28 @@ def __call__(self, class Encoder(nn.Module): """Transformer Model Encoder for sequence to sequence translation. - vocab_size: size of the vocabulary - shared_embedding: a shared embedding layer to use. - dtype: the jnp.dtype for the model parameters. - emb_dim: dimension of embedding - num_heads: number of heads - enc_num_layers: number of layers. It is ignored if enc_remat_scan_lengths - is not None. - qkv_dim: dimension of the query/key/value - mlp_dim: dimension of the mlp on top of attention block - max_len: maximum length. - dropout_rate: dropout rate - normalizer: One of 'batch_norm', 'layer_norm', 'none' - normalize_attention: Apply LayerNorm to query and key before computing - dot_product_attention. - attention_dropout_rate: dropout rate for attention weights - enc_self_attn_kernel_init_fn: initializer for encoder's - self attention matrices. - enc_remat_scan_lengths: if not None, it is the sequence of lengths to use - for remat_scan. See flax.linen.remat_scan; in this case this - defines the total number of layers, not enc_num_layers. + vocab_size: size of the vocabulary + shared_embedding: a shared embedding layer to use. + dtype: the jnp.dtype for the model parameters. + emb_dim: dimension of embedding + num_heads: number of heads + enc_num_layers: number of layers. It is ignored if enc_remat_scan_lengths + is not None. + qkv_dim: dimension of the query/key/value + mlp_dim: dimension of the mlp on top of attention block + max_len: maximum length. + dropout_rate: dropout rate + normalizer: One of 'batch_norm', 'layer_norm', 'none' + normalize_attention: Apply LayerNorm to query and key before computing + dot_product_attention. + attention_dropout_rate: dropout rate for attention weights + enc_self_attn_kernel_init_fn: initializer for encoder's + self attention matrices. + enc_remat_scan_lengths: if not None, it is the sequence of lengths to use + for remat_scan. See flax.linen.remat_scan; in this case this + defines the total number of layers, not enc_num_layers. """ + vocab_size: int shared_embedding: Any = None dtype: jnp.dtype = jnp.float32 @@ -514,11 +539,9 @@ class Encoder(nn.Module): enc_remat_scan_lengths: Optional[Sequence[int]] = None @nn.compact - def __call__(self, - inputs, - inputs_position=None, - encoder_mask=None, - train=True): + def __call__( + self, inputs, inputs_position=None, encoder_mask=None, train=True + ): """Applies Transformer model on the inputs. Args: @@ -540,14 +563,15 @@ def __call__(self, param_dtype=self.dtype, features=self.emb_dim, embedding_init=nn.initializers.normal(stddev=1.0), - name='input_vocab_embeddings') + name='input_vocab_embeddings', + ) else: input_embed = self.shared_embedding x = inputs.astype('int32') x = input_embed(x) x = AddPositionEmbs( - max_len=self.max_len, decode=False, name='posembed_input')( - x, inputs_position=inputs_position, train=train, dtype=self.dtype) + max_len=self.max_len, decode=False, name='posembed_input' + )(x, inputs_position=inputs_position, train=train, dtype=self.dtype) x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not train) # Input encoder. @@ -561,22 +585,26 @@ def __call__(self, attention_dropout_rate=self.attention_dropout_rate, normalizer=self.normalizer, normalize_attention=self.normalize_attention, - enc_self_attn_kernel_init_fn=self.enc_self_attn_kernel_init_fn) + enc_self_attn_kernel_init_fn=self.enc_self_attn_kernel_init_fn, + ) if self.enc_remat_scan_lengths is None: for lyr in range(self.enc_num_layers): x = build_fn(name=f'encoderblock_{lyr}')( - x, encoder_mask=encoder_mask, train=train) + x, encoder_mask=encoder_mask, train=train + ) else: - logging.info('Using Remat Scan, ignoring enc_num_layers; ' - 'number of layers=%d', np.prod(self.enc_remat_scan_lengths)) - enc_stack = nn.remat_scan( - Scannable, lengths=self.enc_remat_scan_lengths)(build_fn=build_fn, - train=train, - name='EncoderStack') + logging.info( + 'Using Remat Scan, ignoring enc_num_layers; number of layers=%d', + np.prod(self.enc_remat_scan_lengths), + ) + enc_stack = nn.remat_scan(Scannable, lengths=self.enc_remat_scan_lengths)( + build_fn=build_fn, train=train, name='EncoderStack' + ) x = enc_stack((x, encoder_mask))[0] if self.normalizer in ['batch_norm', 'layer_norm', 'pre_layer_norm']: maybe_normalize = model_utils.get_normalizer( - self.normalizer, train, dtype=self.dtype) + self.normalizer, train, dtype=self.dtype + ) x = maybe_normalize(param_dtype=self.dtype)(x) return x @@ -584,33 +612,34 @@ def __call__(self, class Decoder(nn.Module): """Transformer Model Decoder for sequence to sequence translation. - output_vocab_size: size of the vocabulary. - shared_embedding: a shared embedding layer to use. - logits_via_embedding: bool: whether final logit transform shares embedding - weights. - dtype: the jnp.dtype for the model parameters. - emb_dim: dimension of embedding. - num_heads: number of heads. - dec_num_layers: number of layers. It is ignored if dec_remat_scan_lengths - is not None. - qkv_dim: dimension of the query/key/value. - mlp_dim: dimension of the mlp on top of attention block. - max_len: maximum length. - decode: whether to use an autoregressive cache. - dropout_rate: dropout rate. - normalizer: One of 'batch_norm', 'layer_norm', 'post_layer_norm', - 'pre_layer_norm', 'none' - normalize_attention: Apply LayerNorm to query and key before computing - dot_product_attention. - attention_dropout_rate: dropout rate for attention weights. - dec_self_attn_kernel_init_fn: initializer for decoder's - self attention matrices. - dec_cross_attn_kernel_init_fn: initializer for decoder's - cross attention matrices. - dec_remat_scan_lengths: if not None, it is the sequence of lengths to use - for remat_scan. See flax.linen.remat_scan; in this case this - defines the total number of layers, not dec_num_layers. + output_vocab_size: size of the vocabulary. + shared_embedding: a shared embedding layer to use. + logits_via_embedding: bool: whether final logit transform shares embedding + weights. + dtype: the jnp.dtype for the model parameters. + emb_dim: dimension of embedding. + num_heads: number of heads. + dec_num_layers: number of layers. It is ignored if dec_remat_scan_lengths + is not None. + qkv_dim: dimension of the query/key/value. + mlp_dim: dimension of the mlp on top of attention block. + max_len: maximum length. + decode: whether to use an autoregressive cache. + dropout_rate: dropout rate. + normalizer: One of 'batch_norm', 'layer_norm', 'post_layer_norm', + 'pre_layer_norm', 'none' + normalize_attention: Apply LayerNorm to query and key before computing + dot_product_attention. + attention_dropout_rate: dropout rate for attention weights. + dec_self_attn_kernel_init_fn: initializer for decoder's + self attention matrices. + dec_cross_attn_kernel_init_fn: initializer for decoder's + cross attention matrices. + dec_remat_scan_lengths: if not None, it is the sequence of lengths to use + for remat_scan. See flax.linen.remat_scan; in this case this + defines the total number of layers, not dec_num_layers. """ + output_vocab_size: int shared_embedding: Any = None logits_via_embedding: bool = False @@ -631,13 +660,15 @@ class Decoder(nn.Module): dec_remat_scan_lengths: Optional[Sequence[int]] = None @nn.compact - def __call__(self, - encoded, - targets, - targets_position=None, - decoder_mask=None, - encoder_decoder_mask=None, - train=True): + def __call__( + self, + encoded, + targets, + targets_position=None, + decoder_mask=None, + encoder_decoder_mask=None, + train=True, + ): """Applies Transformer model on the inputs. Args: @@ -646,7 +677,6 @@ def __call__(self, targets_position: input subsequence positions for packed examples. decoder_mask: decoder self-attention mask. encoder_decoder_mask: encoder-decoder attention mask. - train: whether it is training. Returns: @@ -663,7 +693,8 @@ def __call__(self, dtype=self.dtype, param_dtype=self.dtype, embedding_init=nn.initializers.normal(stddev=1.0), - name='output_vocab_embeddings') + name='output_vocab_embeddings', + ) else: output_embed = self.shared_embedding @@ -672,11 +703,8 @@ def __call__(self, y = shift_right(y) y = output_embed(y) y = AddPositionEmbs( - max_len=self.max_len, decode=self.decode, name='posembed_output')( - y, - inputs_position=targets_position, - train=train, - dtype=self.dtype) + max_len=self.max_len, decode=self.decode, name='posembed_output' + )(y, inputs_position=targets_position, train=train, dtype=self.dtype) y = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(y) # Target-Input Decoder @@ -692,7 +720,8 @@ def __call__(self, normalize_attention=self.normalize_attention, dec_self_attn_kernel_init_fn=self.dec_self_attn_kernel_init_fn, dec_cross_attn_kernel_init_fn=self.dec_cross_attn_kernel_init_fn, - decode=self.decode) + decode=self.decode, + ) if self.dec_remat_scan_lengths is None: for lyr in range(self.dec_num_layers): @@ -701,20 +730,23 @@ def __call__(self, encoded, decoder_mask=decoder_mask, encoder_decoder_mask=encoder_decoder_mask, - train=train) + train=train, + ) else: - logging.info('Using Remat Scan, ignoring enc_num_layers; ' - 'number of layers=%d', np.prod(self.dec_remat_scan_lengths)) - dec_stack = nn.remat_scan( - Scannable, lengths=self.dec_remat_scan_lengths)(build_fn=build_fn, - train=train, - name='DecoderStack') + logging.info( + 'Using Remat Scan, ignoring enc_num_layers; number of layers=%d', + np.prod(self.dec_remat_scan_lengths), + ) + dec_stack = nn.remat_scan(Scannable, lengths=self.dec_remat_scan_lengths)( + build_fn=build_fn, train=train, name='DecoderStack' + ) if decoder_mask is not None: decoder_mask = decoder_mask.astype(self.dtype) y = dec_stack((y, encoded, decoder_mask, encoder_decoder_mask))[0] if self.normalizer in ['batch_norm', 'layer_norm', 'pre_layer_norm']: maybe_normalize = model_utils.get_normalizer( - self.normalizer, train, dtype=self.dtype) + self.normalizer, train, dtype=self.dtype + ) y = maybe_normalize(param_dtype=self.dtype)(y) # Decoded Logits @@ -731,8 +763,8 @@ def __call__(self, param_dtype=self.dtype, kernel_init=nn.initializers.xavier_uniform(), bias_init=nn.initializers.normal(stddev=1e-6), - name='logitdense')( - y) + name='logitdense', + )(y) return logits @@ -745,37 +777,38 @@ def __call__(self, class Transformer(nn.Module): """Transformer Model for sequence to sequence translation. - vocab_size: size of the input vocabulary. - output_vocab_size: size of the output vocabulary. If None, the output - vocabulary size is assumed to be the same as vocab_size. - share_embeddings: bool: share embedding layer for inputs and targets. - logits_via_embedding: bool: whether final logit transform shares embedding - weights. - dtype: the jnp.dtype for the model parameters. - emb_dim: dimension of embedding. - num_heads: number of heads. - enc_num_layers: number of encoder layers. - enc_remat_scan_lengths: Optional sequence of lengths to use with - flax.linen.remat_scan. - dec_num_layers: number of decoder layers. - dec_remat_scan_lengths: Optional sequence of lengths to use with - flax.linen.remat_scan. - qkv_dim: dimension of the query/key/value. - mlp_dim: dimension of the mlp on top of attention block. - max_len: maximum length. - dropout_rate: dropout rate. - attention_dropout_rate: dropout rate for attention weights. - normalizer: One of 'batch_norm', 'layer_norm', 'none' - normalize_attention: Apply LayerNorm to query and key before computing - dot_product_attention. - enc_self_attn_kernel_init_fn: initializer for encoder's - self attention matrices. - dec_self_attn_kernel_init_fn: initializer for decoder's - self attention matrices. - dec_cross_attn_kernel_init_fn: initializer for decoder's - cross attention matrices. - decode: whether to use an autoregressive cache. + vocab_size: size of the input vocabulary. + output_vocab_size: size of the output vocabulary. If None, the output + vocabulary size is assumed to be the same as vocab_size. + share_embeddings: bool: share embedding layer for inputs and targets. + logits_via_embedding: bool: whether final logit transform shares embedding + weights. + dtype: the jnp.dtype for the model parameters. + emb_dim: dimension of embedding. + num_heads: number of heads. + enc_num_layers: number of encoder layers. + enc_remat_scan_lengths: Optional sequence of lengths to use with + flax.linen.remat_scan. + dec_num_layers: number of decoder layers. + dec_remat_scan_lengths: Optional sequence of lengths to use with + flax.linen.remat_scan. + qkv_dim: dimension of the query/key/value. + mlp_dim: dimension of the mlp on top of attention block. + max_len: maximum length. + dropout_rate: dropout rate. + attention_dropout_rate: dropout rate for attention weights. + normalizer: One of 'batch_norm', 'layer_norm', 'none' + normalize_attention: Apply LayerNorm to query and key before computing + dot_product_attention. + enc_self_attn_kernel_init_fn: initializer for encoder's + self attention matrices. + dec_self_attn_kernel_init_fn: initializer for decoder's + self attention matrices. + dec_cross_attn_kernel_init_fn: initializer for decoder's + cross attention matrices. + decode: whether to use an autoregressive cache. """ + vocab_size: Optional[int] = None output_vocab_size: Optional[int] = None share_embeddings: bool = False @@ -801,26 +834,32 @@ class Transformer(nn.Module): def setup(self): if self.enc_num_layers and self.enc_remat_scan_lengths: - raise ValueError(f'Only one of enc_num_layers ({self.enc_num_layers})' - 'and enc_remat_scan_lengths' - f'({self.enc_remat_scan_lengths}) can be set.') + raise ValueError( + f'Only one of enc_num_layers ({self.enc_num_layers})' + 'and enc_remat_scan_lengths' + f'({self.enc_remat_scan_lengths}) can be set.' + ) if self.dec_num_layers and self.dec_remat_scan_lengths: - raise ValueError(f'Only one of dec_num_layers ({self.dec_num_layers})' - 'and dec_remat_scan_lengths' - f'({self.dec_remat_scan_lengths}) can be set.') + raise ValueError( + f'Only one of dec_num_layers ({self.dec_num_layers})' + 'and dec_remat_scan_lengths' + f'({self.dec_remat_scan_lengths}) can be set.' + ) if self.share_embeddings: if self.output_vocab_size is not None: - assert self.output_vocab_size == self.vocab_size, ( - "can't share embedding with different vocab sizes.") + assert ( + self.output_vocab_size == self.vocab_size + ), "can't share embedding with different vocab sizes." self.shared_embedding = nn.Embed( num_embeddings=self.vocab_size, features=self.emb_dim, dtype=self.dtype, param_dtype=self.dtype, embedding_init=nn.initializers.normal(stddev=1.0), - name='VocabEmbeddings') + name='VocabEmbeddings', + ) else: self.shared_embedding = None @@ -840,7 +879,8 @@ def setup(self): normalizer=self.normalizer, enc_self_attn_kernel_init_fn=self.enc_self_attn_kernel_init_fn, enc_remat_scan_lengths=self.enc_remat_scan_lengths, - name='encoder') + name='encoder', + ) self.decoder = Decoder( output_vocab_size=self.output_vocab_size, shared_embedding=self.shared_embedding, @@ -860,17 +900,20 @@ def setup(self): dec_cross_attn_kernel_init_fn=self.dec_self_attn_kernel_init_fn, decode=self.should_decode, dec_remat_scan_lengths=self.dec_remat_scan_lengths, - name='decoder') + name='decoder', + ) @nn.compact - def __call__(self, - inputs, - targets, - inputs_position=None, - targets_position=None, - inputs_segmentation=None, - targets_segmentation=None, - train=False): + def __call__( + self, + inputs, + targets, + inputs_position=None, + targets_position=None, + inputs_segmentation=None, + targets_segmentation=None, + train=False, + ): """Applies Transformer model on the inputs. Args: @@ -885,18 +928,22 @@ def __call__(self, Returns: Output: [batch_size, target_sequence_length, qkv_dim] """ - encoded = self.encode(inputs, - inputs_position=inputs_position, - inputs_segmentation=inputs_segmentation, - train=train) - - logits = self.decode(encoded, - inputs, # only used for masks - targets, - targets_position=targets_position, - inputs_segmentation=inputs_segmentation, - targets_segmentation=targets_segmentation, - train=train) + encoded = self.encode( + inputs, + inputs_position=inputs_position, + inputs_segmentation=inputs_segmentation, + train=train, + ) + + logits = self.decode( + encoded, + inputs, # only used for masks + targets, + targets_position=targets_position, + inputs_segmentation=inputs_segmentation, + targets_segmentation=targets_segmentation, + train=train, + ) return logits # The following two methods allow us to run the trained Transformer in @@ -905,39 +952,39 @@ def __call__(self, # cache object for iteratively storing keys and values during the decoding # process. - def encode(self, - inputs, - inputs_position=None, - inputs_segmentation=None, - train=False): + def encode( + self, inputs, inputs_position=None, inputs_segmentation=None, train=False + ): # Make padding attention mask. dtype = self.dtype - encoder_mask = nn.make_attention_mask( - inputs > 0, inputs > 0, dtype=dtype) + encoder_mask = nn.make_attention_mask(inputs > 0, inputs > 0, dtype=dtype) # Add segmentation block-diagonal attention mask if using segmented data. if inputs_segmentation is not None: encoder_mask = nn.combine_masks( encoder_mask, - nn.make_attention_mask(inputs_segmentation, - inputs_segmentation, - jnp.equal, - dtype=dtype)) + nn.make_attention_mask( + inputs_segmentation, inputs_segmentation, jnp.equal, dtype=dtype + ), + ) encoded = self.encoder( inputs, inputs_position=inputs_position, encoder_mask=encoder_mask, - train=train) + train=train, + ) return encoded - def decode(self, - encoded, - inputs, - targets, - targets_position=None, - inputs_segmentation=None, - targets_segmentation=None, - train=False): + def decode( + self, + encoded, + inputs, + targets, + targets_position=None, + inputs_segmentation=None, + targets_segmentation=None, + train=False, + ): # Make padding attention masks. dtype = self.dtype if self.should_decode: @@ -945,28 +992,31 @@ def decode(self, # used. decoder_mask = None encoder_decoder_mask = nn.make_attention_mask( - jnp.ones_like(targets) > 0, inputs > 0, dtype=dtype) + jnp.ones_like(targets) > 0, inputs > 0, dtype=dtype + ) else: decoder_mask = nn.combine_masks( nn.make_attention_mask(targets > 0, targets > 0, dtype=dtype), - nn.make_causal_mask(targets, dtype=dtype)) + nn.make_causal_mask(targets, dtype=dtype), + ) encoder_decoder_mask = nn.make_attention_mask( - targets > 0, inputs > 0, dtype=dtype) + targets > 0, inputs > 0, dtype=dtype + ) # Add segmentation block-diagonal attention masks if using segmented data. if inputs_segmentation is not None: decoder_mask = nn.combine_masks( decoder_mask, - nn.make_attention_mask(targets_segmentation, - targets_segmentation, - jnp.equal, - dtype=dtype)) + nn.make_attention_mask( + targets_segmentation, targets_segmentation, jnp.equal, dtype=dtype + ), + ) encoder_decoder_mask = nn.combine_masks( encoder_decoder_mask, - nn.make_attention_mask(targets_segmentation, - inputs_segmentation, - jnp.equal, - dtype=dtype)) + nn.make_attention_mask( + targets_segmentation, inputs_segmentation, jnp.equal, dtype=dtype + ), + ) logits = self.decoder( encoded, @@ -974,7 +1024,8 @@ def decode(self, targets_position=targets_position, decoder_mask=decoder_mask, encoder_decoder_mask=encoder_decoder_mask, - train=train) + train=train, + ) return logits @@ -994,14 +1045,12 @@ def evaluate_batch(self, params, batch_stats, batch): # Add log-perplexity metric. return self.metrics_bundle.single_from_model_output( - logits=logits, targets=targets, weights=weights, axis_name='batch') - - def apply_on_batch(self, - params, - batch_stats, - batch, - train=True, - **apply_kwargs): + logits=logits, targets=targets, weights=weights, axis_name='batch' + ) + + def apply_on_batch( + self, params, batch_stats, batch, train=True, **apply_kwargs + ): """Wrapper around flax_module.apply.""" variables = {'params': params} if batch_stats is not None: @@ -1035,7 +1084,8 @@ def training_cost(self, params, batch, batch_stats=None, dropout_rng=None): batch, mutable=['batch_stats'], rngs={'dropout': dropout_rng}, - train=True) + train=True, + ) weights = batch.get('weights') targets = batch['targets'] @@ -1045,30 +1095,37 @@ def training_cost(self, params, batch, batch_stats=None, dropout_rng=None): # Optionally apply label smoothing. if self.hps.get('label_smoothing') is not None: targets = model_utils.apply_label_smoothing( - targets, self.hps.get('label_smoothing')) - (total_loss, total_weight) = self.loss_fn( - logits, targets, weights) + targets, self.hps.get('label_smoothing') + ) + total_loss, total_weight = self.loss_fn(logits, targets, weights) # (total_loss, total_weight) = lax.psum( # (total_loss, total_weight), axis_name='batch') - total_loss = (total_loss / total_weight) + total_loss = total_loss / total_weight if self.hps.get('l2_decay_factor'): l2_loss = model_utils.l2_regularization( - params, self.hps.l2_decay_rank_threshold) + params, self.hps.l2_decay_rank_threshold + ) total_loss += 0.5 * self.hps.l2_decay_factor * l2_loss return total_loss, (new_batch_stats) def build_flax_module(self): - max_len = max(self.hps.max_target_length, self.hps.max_eval_target_length, - self.hps.max_predict_length) + max_len = max( + self.hps.max_target_length, + self.hps.max_eval_target_length, + self.hps.max_predict_length, + ) enc_self_attn_kernel_init_fn = model_utils.INITIALIZERS[ - self.hps.enc_self_attn_kernel_init]() + self.hps.enc_self_attn_kernel_init + ]() dec_self_attn_kernel_init_fn = model_utils.INITIALIZERS[ - self.hps.dec_self_attn_kernel_init]() + self.hps.dec_self_attn_kernel_init + ]() dec_cross_attn_kernel_init_fn = model_utils.INITIALIZERS[ - self.hps.dec_cross_attn_kernel_init]() + self.hps.dec_cross_attn_kernel_init + ]() dtype = utils.dtype_from_str(self.hps.model_dtype) return Transformer( @@ -1113,14 +1170,20 @@ class MLCommonsTransformerTranslate(TransformerTranslate): """ def build_flax_module(self): - max_len = max(self.hps.max_target_length, self.hps.max_eval_target_length, - self.hps.max_predict_length) + max_len = max( + self.hps.max_target_length, + self.hps.max_eval_target_length, + self.hps.max_predict_length, + ) enc_self_attn_kernel_init_fn = model_utils.INITIALIZERS[ - self.hps.enc_self_attn_kernel_init]() + self.hps.enc_self_attn_kernel_init + ]() dec_self_attn_kernel_init_fn = model_utils.INITIALIZERS[ - self.hps.dec_self_attn_kernel_init]() + self.hps.dec_self_attn_kernel_init + ]() dec_cross_attn_kernel_init_fn = model_utils.INITIALIZERS[ - self.hps.dec_cross_attn_kernel_init]() + self.hps.dec_cross_attn_kernel_init + ]() dtype = utils.dtype_from_str(self.hps.model_dtype) aux_dropout_rate = ( self.hps.dropout_rate diff --git a/init2winit/model_lib/xformer_translate_binary.py b/init2winit/model_lib/xformer_translate_binary.py index ace8d5e4..b77fe64d 100644 --- a/init2winit/model_lib/xformer_translate_binary.py +++ b/init2winit/model_lib/xformer_translate_binary.py @@ -17,6 +17,7 @@ Adapted from third_party/py/language/google/generation/tsukuyomi/models.py """ + import dataclasses from typing import Any, Optional @@ -96,7 +97,8 @@ def shift_right(x, axis=1): pad_widths = [(0, 0)] * len(x.shape) pad_widths[axis] = (1, 0) padded = jnp.pad( - x, pad_widths, mode='constant', constant_values=x.dtype.type(0)) + x, pad_widths, mode='constant', constant_values=x.dtype.type(0) + ) return padded[:, :-1] @@ -120,8 +122,8 @@ def init(key, shape, dtype=np.float32): position = np.arange(0, max_len)[:, np.newaxis] scale_factor = -np.log(max_scale / min_scale) / (d_feature // 2 - 1) div_term = min_scale * np.exp(np.arange(0, d_feature // 2) * scale_factor) - pe[:, :d_feature // 2] = np.sin(position * div_term) - pe[:, d_feature // 2:2 * (d_feature // 2)] = np.cos(position * div_term) + pe[:, : d_feature // 2] = np.sin(position * div_term) + pe[:, d_feature // 2 : 2 * (d_feature // 2)] = np.cos(position * div_term) pe = pe[np.newaxis, :, :] # [1, max_len, d_feature] return jnp.array(pe) @@ -136,15 +138,13 @@ class AddPositionEmbs(nn.Module): (non-learned) sinusoidal embedding table. decode: whether to use an autoregressive cache. """ + max_len: int = 512 posemb_init: Optional[model_utils.Initializer] = None decode: bool = False @nn.compact - def __call__(self, - inputs, - inputs_positions=None, - train=True): + def __call__(self, inputs, inputs_positions=None, train=True): """Applies AddPositionEmbs module. By default this layer uses a fixed sinusoidal embedding table. If a @@ -161,22 +161,26 @@ def __call__(self, """ del train # inputs.shape is (batch_size, seq_len, emb_dim) - assert inputs.ndim == 3, ('Number of dimensions should be 3,' - ' but it is: %d' % inputs.ndim) + assert inputs.ndim == 3, ( + 'Number of dimensions should be 3, but it is: %d' % inputs.ndim + ) length = inputs.shape[1] pos_emb_shape = (1, self.max_len, inputs.shape[-1]) if self.posemb_init is None: # Use a fixed (non-learned) sinusoidal position embedding. pos_embedding = sinusoidal_init(max_len=self.max_len)( - None, pos_emb_shape, None) + None, pos_emb_shape, None + ) else: pos_embedding = self.param( - 'pos_embedding', pos_emb_shape, self.posemb_init) + 'pos_embedding', pos_emb_shape, self.posemb_init + ) pe = pos_embedding[:, :length, :] if self.decode: is_initialized = self.has_variable('cache', 'cache_index') - cache_index = self.variable('cache', 'cache_index', - lambda: jnp.array(0, dtype=jnp.uint32)) + cache_index = self.variable( + 'cache', 'cache_index', lambda: jnp.array(0, dtype=jnp.uint32) + ) if is_initialized: i = cache_index.value cache_index.value = i + 1 @@ -192,6 +196,7 @@ def __call__(self, class MlpBlock(nn.Module): """Transformer MLP / feed-forward block.""" + mlp_dim: int dtype: model_utils.Dtype = jnp.float32 out_dim: Optional[int] = None @@ -219,11 +224,13 @@ def __call__(self, inputs, train): kernel_init=self.kernel_init, bias_init=self.bias_init, weight_bin_hparams=whps, - inputs_bin_hparams=ahps)(inputs) + inputs_bin_hparams=ahps, + )(inputs) x = nn.relu(x) # add layer norm 1 for the sake of binarizing input activations of dense2 dense1_normalize = model_utils.get_normalizer( - 'layer_norm', train, dtype=self.dtype) + 'layer_norm', train, dtype=self.dtype + ) x = dense1_normalize()(x) x = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(x) output = binarize_layers.BiDense( @@ -232,13 +239,14 @@ def __call__(self, inputs, train): kernel_init=self.kernel_init, bias_init=self.bias_init, weight_bin_hparams=whps, - inputs_bin_hparams=ahps)(x) + inputs_bin_hparams=ahps, + )(x) # Add layer norm 2 to adjust the output magnitude after binarization dense2_normalize = model_utils.get_normalizer( - 'layer_norm', train, dtype=self.dtype) + 'layer_norm', train, dtype=self.dtype + ) output = dense2_normalize()(output) - output = nn.Dropout( - rate=self.dropout_rate, deterministic=not train)(output) + output = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(output) return output @@ -254,9 +262,10 @@ class Encoder1DBlock(nn.Module): attention_dropout_rate: Dropout rate for attention weights. normalizer: One of 'batch_norm', 'layer_norm', 'post_layer_norm', 'pre_layer_norm', 'none' - enc_self_attn_kernel_init_fn: initializer for encoder's - self attention matrices. + enc_self_attn_kernel_init_fn: initializer for encoder's self attention + matrices. """ + qkv_dim: int mlp_dim: int num_heads: int @@ -273,10 +282,7 @@ class Encoder1DBlock(nn.Module): ) @nn.compact - def __call__(self, - inputs, - encoder_mask=None, - train=True): + def __call__(self, inputs, encoder_mask=None, train=True): """Applies Encoder1DBlock module. Args: @@ -290,16 +296,24 @@ def __call__(self, # Attention block. assert inputs.ndim == 3 if self.normalizer in [ - 'batch_norm', 'layer_norm', 'pre_layer_norm', 'none']: + 'batch_norm', + 'layer_norm', + 'pre_layer_norm', + 'none', + ]: maybe_pre_normalize = model_utils.get_normalizer( - self.normalizer, train, dtype=self.dtype) + self.normalizer, train, dtype=self.dtype + ) maybe_post_normalize = model_utils.get_normalizer( - 'none', train, dtype=self.dtype) + 'none', train, dtype=self.dtype + ) elif self.normalizer == 'post_layer_norm': maybe_pre_normalize = model_utils.get_normalizer( - 'none', train, dtype=self.dtype) + 'none', train, dtype=self.dtype + ) maybe_post_normalize = model_utils.get_normalizer( - self.normalizer, train, dtype=self.dtype) + self.normalizer, train, dtype=self.dtype + ) else: raise ValueError('Unsupported normalizer: {}'.format(self.normalizer)) @@ -315,8 +329,8 @@ def __call__(self, dropout_rate=self.attention_dropout_rate, binarize_hparams=self.binarize_hparams, dynamic_context=self.dynamic_context, - name='EncoderSelfAttention')( - x, mask=encoder_mask, deterministic=not train) + name='EncoderSelfAttention', + )(x, mask=encoder_mask, deterministic=not train) x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not train) x = x + inputs @@ -330,7 +344,8 @@ def __call__(self, dropout_rate=self.dropout_rate, binarize_hparams=self.binarize_hparams, dynamic_context=self.dynamic_context, - name='MLPBlock')(y, train=train) + name='MLPBlock', + )(y, train=train) res = x + y return maybe_post_normalize()(res) @@ -348,12 +363,13 @@ class EncoderDecoder1DBlock(nn.Module): attention_dropout_rate: Dropout rate for attention weights normalizer: One of 'batch_norm', 'layer_norm', 'post_layer_norm', 'pre_layer_norm', 'none' - dec_self_attn_kernel_init_fn: initializer for decoder's - self attention matrices. - dec_cross_attn_kernel_init_fn: initializer for decoder's - cross attention matrices. + dec_self_attn_kernel_init_fn: initializer for decoder's self attention + matrices. + dec_cross_attn_kernel_init_fn: initializer for decoder's cross attention + matrices. decode: whether to use an autoregressive cache. """ + qkv_dim: int mlp_dim: int num_heads: int @@ -372,12 +388,14 @@ class EncoderDecoder1DBlock(nn.Module): ) @nn.compact - def __call__(self, - targets, - encoded, - decoder_mask=None, - encoder_decoder_mask=None, - train=True): + def __call__( + self, + targets, + encoded, + decoder_mask=None, + encoder_decoder_mask=None, + train=True, + ): """Applies EncoderDecoder1DBlock module. Args: @@ -393,16 +411,24 @@ def __call__(self, # Decoder block. assert targets.ndim == 3 if self.normalizer in [ - 'batch_norm', 'layer_norm', 'pre_layer_norm', 'none']: + 'batch_norm', + 'layer_norm', + 'pre_layer_norm', + 'none', + ]: maybe_pre_normalize = model_utils.get_normalizer( - self.normalizer, train, dtype=self.dtype) + self.normalizer, train, dtype=self.dtype + ) maybe_post_normalize = model_utils.get_normalizer( - 'none', train, dtype=self.dtype) + 'none', train, dtype=self.dtype + ) elif self.normalizer == 'post_layer_norm': maybe_pre_normalize = model_utils.get_normalizer( - 'none', train, dtype=self.dtype) + 'none', train, dtype=self.dtype + ) maybe_post_normalize = model_utils.get_normalizer( - self.normalizer, train, dtype=self.dtype) + self.normalizer, train, dtype=self.dtype + ) else: raise ValueError('Unsupported normalizer: {}'.format(self.normalizer)) @@ -419,7 +445,8 @@ def __call__(self, decode=self.decode, binarize_hparams=self.binarize_hparams, dynamic_context=self.dynamic_context, - name='DecoderSelfAttention')(x, decoder_mask, deterministic=not train) + name='DecoderSelfAttention', + )(x, decoder_mask, deterministic=not train) x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not train) x = x + targets @@ -436,11 +463,10 @@ def __call__(self, broadcast_dropout=False, dropout_rate=self.attention_dropout_rate, binarize_hparams=self.binarize_hparams, - dynamic_context=self.dynamic_context)( - y, encoded, encoder_decoder_mask, deterministic=not train) + dynamic_context=self.dynamic_context, + )(y, encoded, encoder_decoder_mask, deterministic=not train) - y = nn.Dropout(rate=self.dropout_rate)( - y, deterministic=not train) + y = nn.Dropout(rate=self.dropout_rate)(y, deterministic=not train) y = y + x y = maybe_post_normalize()(y) @@ -452,7 +478,8 @@ def __call__(self, dropout_rate=self.dropout_rate, binarize_hparams=self.binarize_hparams, dynamic_context=self.dynamic_context, - name='MLPBlock')(z, train=train) + name='MLPBlock', + )(z, train=train) res = y + z return maybe_post_normalize()(res) @@ -461,21 +488,22 @@ def __call__(self, class Encoder(nn.Module): """Transformer Model Encoder for sequence to sequence translation. - vocab_size: size of the vocabulary - shared_embedding: a shared embedding layer to use. - use_bfloat16: bool: whether use bfloat16. - emb_dim: dimension of embedding - num_heads: number of heads - enc_num_layers: number of layers - qkv_dim: dimension of the query/key/value - mlp_dim: dimension of the mlp on top of attention block - max_len: maximum length. - dropout_rate: dropout rate - normalizer: One of 'batch_norm', 'layer_norm', 'none' - attention_dropout_rate: dropout rate for attention weights - enc_self_attn_kernel_init_fn: initializer for encoder's - self attention matrices. + vocab_size: size of the vocabulary + shared_embedding: a shared embedding layer to use. + use_bfloat16: bool: whether use bfloat16. + emb_dim: dimension of embedding + num_heads: number of heads + enc_num_layers: number of layers + qkv_dim: dimension of the query/key/value + mlp_dim: dimension of the mlp on top of attention block + max_len: maximum length. + dropout_rate: dropout rate + normalizer: One of 'batch_norm', 'layer_norm', 'none' + attention_dropout_rate: dropout rate for attention weights + enc_self_attn_kernel_init_fn: initializer for encoder's + self attention matrices. """ + vocab_size: int shared_embedding: Any = None use_bfloat16: bool = False @@ -497,11 +525,9 @@ class Encoder(nn.Module): ) @nn.compact - def __call__(self, - inputs, - inputs_positions=None, - encoder_mask=None, - train=True): + def __call__( + self, inputs, inputs_positions=None, encoder_mask=None, train=True + ): """Applies Transformer model on the inputs. Args: @@ -521,16 +547,15 @@ def __call__(self, num_embeddings=self.vocab_size, features=self.emb_dim, embedding_init=nn.initializers.normal(stddev=1.0), - name='input_vocab_embeddings') + name='input_vocab_embeddings', + ) else: input_embed = self.shared_embedding x = inputs.astype('int32') x = input_embed(x) x = AddPositionEmbs( - max_len=self.max_len, - decode=False, - name='posembed_input')( - x, inputs_positions=inputs_positions, train=train) + max_len=self.max_len, decode=False, name='posembed_input' + )(x, inputs_positions=inputs_positions, train=train) x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not train) if self.use_bfloat16: @@ -552,13 +577,12 @@ def __call__(self, enc_self_attn_kernel_init_fn=self.enc_self_attn_kernel_init_fn, binarize_hparams=self.binarize_hparams, dynamic_context=self.dynamic_context, - name=f'encoderblock_{lyr}')( - x, - encoder_mask=encoder_mask, - train=train) + name=f'encoderblock_{lyr}', + )(x, encoder_mask=encoder_mask, train=train) if self.normalizer in ['batch_norm', 'layer_norm', 'pre_layer_norm']: maybe_normalize = model_utils.get_normalizer( - self.normalizer, train, dtype=dtype) + self.normalizer, train, dtype=dtype + ) x = maybe_normalize()(x) return x @@ -566,27 +590,28 @@ def __call__(self, class Decoder(nn.Module): """Transformer Model Decoder for sequence to sequence translation. - output_vocab_size: size of the vocabulary. - shared_embedding: a shared embedding layer to use. - logits_via_embedding: bool: whether final logit transform shares embedding - weights. - use_bfloat16: bool: whether use bfloat16. - emb_dim: dimension of embedding. - num_heads: number of heads. - dec_num_layers: number of layers. - qkv_dim: dimension of the query/key/value. - mlp_dim: dimension of the mlp on top of attention block. - max_len: maximum length. - decode: whether to use an autoregressive cache. - dropout_rate: dropout rate. - normalizer: One of 'batch_norm', 'layer_norm', 'post_layer_norm', - 'pre_layer_norm', 'none' - attention_dropout_rate: dropout rate for attention weights. - dec_self_attn_kernel_init_fn: initializer for decoder's - self attention matrices. - dec_cross_attn_kernel_init_fn: initializer for decoder's - cross attention matrices. + output_vocab_size: size of the vocabulary. + shared_embedding: a shared embedding layer to use. + logits_via_embedding: bool: whether final logit transform shares embedding + weights. + use_bfloat16: bool: whether use bfloat16. + emb_dim: dimension of embedding. + num_heads: number of heads. + dec_num_layers: number of layers. + qkv_dim: dimension of the query/key/value. + mlp_dim: dimension of the mlp on top of attention block. + max_len: maximum length. + decode: whether to use an autoregressive cache. + dropout_rate: dropout rate. + normalizer: One of 'batch_norm', 'layer_norm', 'post_layer_norm', + 'pre_layer_norm', 'none' + attention_dropout_rate: dropout rate for attention weights. + dec_self_attn_kernel_init_fn: initializer for decoder's + self attention matrices. + dec_cross_attn_kernel_init_fn: initializer for decoder's + cross attention matrices. """ + output_vocab_size: int shared_embedding: Any = None logits_via_embedding: bool = False @@ -611,13 +636,15 @@ class Decoder(nn.Module): ) @nn.compact - def __call__(self, - encoded, - targets, - targets_positions=None, - decoder_mask=None, - encoder_decoder_mask=None, - train=True): + def __call__( + self, + encoded, + targets, + targets_positions=None, + decoder_mask=None, + encoder_decoder_mask=None, + train=True, + ): """Applies Transformer model on the inputs. Args: @@ -626,7 +653,6 @@ def __call__(self, targets_positions: input subsequence positions for packed examples. decoder_mask: decoder self-attention mask. encoder_decoder_mask: encoder-decoder attention mask. - train: whether it is training. Returns: @@ -642,7 +668,8 @@ def __call__(self, num_embeddings=self.output_vocab_size, features=self.emb_dim, embedding_init=nn.initializers.normal(stddev=1.0), - name='output_vocab_embeddings') + name='output_vocab_embeddings', + ) else: output_embed = self.shared_embedding @@ -651,10 +678,8 @@ def __call__(self, y = shift_right(y) y = output_embed(y) y = AddPositionEmbs( - max_len=self.max_len, - decode=self.decode, - name='posembed_output')( - y, inputs_positions=targets_positions, train=train) + max_len=self.max_len, decode=self.decode, name='posembed_output' + )(y, inputs_positions=targets_positions, train=train) y = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(y) if self.use_bfloat16: @@ -675,15 +700,18 @@ def __call__(self, decode=self.decode, binarize_hparams=self.binarize_hparams, dynamic_context=self.dynamic_context, - name=f'encoderdecoderblock_{lyr}')( - y, - encoded, - decoder_mask=decoder_mask, - encoder_decoder_mask=encoder_decoder_mask, - train=train) + name=f'encoderdecoderblock_{lyr}', + )( + y, + encoded, + decoder_mask=decoder_mask, + encoder_decoder_mask=encoder_decoder_mask, + train=train, + ) if self.normalizer in ['batch_norm', 'layer_norm', 'pre_layer_norm']: maybe_normalize = model_utils.get_normalizer( - self.normalizer, train, dtype=dtype) + self.normalizer, train, dtype=dtype + ) y = maybe_normalize()(y) # Decoded Logits @@ -699,7 +727,8 @@ def __call__(self, dtype=dtype, kernel_init=nn.initializers.xavier_uniform(), bias_init=nn.initializers.normal(stddev=1e-6), - name='logitdense')(y) + name='logitdense', + )(y) return logits @@ -712,31 +741,32 @@ def __call__(self, class Transformer(nn.Module): """Transformer Model for sequence to sequence translation. - vocab_size: size of the input vocabulary. - output_vocab_size: size of the output vocabulary. If None, the output - vocabulary size is assumed to be the same as vocab_size. - share_embeddings: bool: share embedding layer for inputs and targets. - logits_via_embedding: bool: whether final logit transform shares embedding - weights. - use_bfloat16: bool: whether use bfloat16. - emb_dim: dimension of embedding. - num_heads: number of heads. - enc_num_layers: number of encoder layers. - dec_num_layers: number of decoder layers. - qkv_dim: dimension of the query/key/value. - mlp_dim: dimension of the mlp on top of attention block. - max_len: maximum length. - dropout_rate: dropout rate. - attention_dropout_rate: dropout rate for attention weights. - normalizer: One of 'batch_norm', 'layer_norm', 'none' - enc_self_attn_kernel_init_fn: initializer for encoder's - self attention matrices. - dec_self_attn_kernel_init_fn: initializer for decoder's - self attention matrices. - dec_cross_attn_kernel_init_fn: initializer for decoder's - cross attention matrices. - decode: whether to use an autoregressive cache. + vocab_size: size of the input vocabulary. + output_vocab_size: size of the output vocabulary. If None, the output + vocabulary size is assumed to be the same as vocab_size. + share_embeddings: bool: share embedding layer for inputs and targets. + logits_via_embedding: bool: whether final logit transform shares embedding + weights. + use_bfloat16: bool: whether use bfloat16. + emb_dim: dimension of embedding. + num_heads: number of heads. + enc_num_layers: number of encoder layers. + dec_num_layers: number of decoder layers. + qkv_dim: dimension of the query/key/value. + mlp_dim: dimension of the mlp on top of attention block. + max_len: maximum length. + dropout_rate: dropout rate. + attention_dropout_rate: dropout rate for attention weights. + normalizer: One of 'batch_norm', 'layer_norm', 'none' + enc_self_attn_kernel_init_fn: initializer for encoder's + self attention matrices. + dec_self_attn_kernel_init_fn: initializer for decoder's + self attention matrices. + dec_cross_attn_kernel_init_fn: initializer for decoder's + cross attention matrices. + decode: whether to use an autoregressive cache. """ + vocab_size: Optional[int] = None output_vocab_size: Optional[int] = None share_embeddings: bool = False @@ -766,13 +796,15 @@ class Transformer(nn.Module): def setup(self): if self.share_embeddings: if self.output_vocab_size is not None: - assert self.output_vocab_size == self.vocab_size, ( - "can't share embedding with different vocab sizes.") + assert ( + self.output_vocab_size == self.vocab_size + ), "can't share embedding with different vocab sizes." self.shared_embedding = nn.Embed( num_embeddings=self.vocab_size, features=self.emb_dim, embedding_init=nn.initializers.normal(stddev=1.0), - name='VocabEmbeddings') + name='VocabEmbeddings', + ) else: self.shared_embedding = None @@ -792,7 +824,8 @@ def setup(self): enc_self_attn_kernel_init_fn=self.enc_self_attn_kernel_init_fn, binarize_hparams=self.binarize_hparams, dynamic_context=self.dynamic_context, - name='encoder') + name='encoder', + ) self.decoder = Decoder( output_vocab_size=self.output_vocab_size, shared_embedding=self.shared_embedding, @@ -812,17 +845,20 @@ def setup(self): decode=self.should_decode, binarize_hparams=self.binarize_hparams, dynamic_context=self.dynamic_context, - name='decoder') + name='decoder', + ) @nn.compact - def __call__(self, - inputs, - targets, - inputs_positions=None, - targets_positions=None, - inputs_segmentation=None, - targets_segmentation=None, - train=False): + def __call__( + self, + inputs, + targets, + inputs_positions=None, + targets_positions=None, + inputs_segmentation=None, + targets_segmentation=None, + train=False, + ): """Applies Transformer model on the inputs. Args: @@ -837,18 +873,22 @@ def __call__(self, Returns: Output: [batch_size, target_sequence_length, qkv_dim] """ - encoded = self.encode(inputs, - inputs_positions=inputs_positions, - inputs_segmentation=inputs_segmentation, - train=train) - - logits = self.decode(encoded, - inputs, # only used for masks - targets, - targets_positions=targets_positions, - inputs_segmentation=inputs_segmentation, - targets_segmentation=targets_segmentation, - train=train) + encoded = self.encode( + inputs, + inputs_positions=inputs_positions, + inputs_segmentation=inputs_segmentation, + train=train, + ) + + logits = self.decode( + encoded, + inputs, # only used for masks + targets, + targets_positions=targets_positions, + inputs_segmentation=inputs_segmentation, + targets_segmentation=targets_segmentation, + train=train, + ) return logits.astype(jnp.float32) if self.use_bfloat16 else logits # The following two methods allow us to run the trained Transformer in @@ -857,39 +897,39 @@ def __call__(self, # cache object for iteratively storing keys and values during the decoding # process. - def encode(self, - inputs, - inputs_positions=None, - inputs_segmentation=None, - train=False): + def encode( + self, inputs, inputs_positions=None, inputs_segmentation=None, train=False + ): # Make padding attention mask. dtype = jnp.bfloat16 if self.use_bfloat16 else jnp.float32 - encoder_mask = nn.make_attention_mask( - inputs > 0, inputs > 0, dtype=dtype) + encoder_mask = nn.make_attention_mask(inputs > 0, inputs > 0, dtype=dtype) # Add segmentation block-diagonal attention mask if using segmented data. if inputs_segmentation is not None: encoder_mask = nn.combine_masks( encoder_mask, - nn.make_attention_mask(inputs_segmentation, - inputs_segmentation, - jnp.equal, - dtype=dtype)) + nn.make_attention_mask( + inputs_segmentation, inputs_segmentation, jnp.equal, dtype=dtype + ), + ) encoded = self.encoder( inputs, inputs_positions=inputs_positions, encoder_mask=encoder_mask, - train=train) + train=train, + ) return encoded - def decode(self, - encoded, - inputs, - targets, - targets_positions=None, - inputs_segmentation=None, - targets_segmentation=None, - train=False): + def decode( + self, + encoded, + inputs, + targets, + targets_positions=None, + inputs_segmentation=None, + targets_segmentation=None, + train=False, + ): # Make padding attention masks. dtype = jnp.bfloat16 if self.use_bfloat16 else jnp.float32 if self.should_decode: @@ -897,28 +937,31 @@ def decode(self, # used. decoder_mask = None encoder_decoder_mask = nn.make_attention_mask( - jnp.ones_like(targets) > 0, inputs > 0, dtype=dtype) + jnp.ones_like(targets) > 0, inputs > 0, dtype=dtype + ) else: decoder_mask = nn.combine_masks( nn.make_attention_mask(targets > 0, targets > 0, dtype=dtype), - nn.make_causal_mask(targets, dtype=dtype)) + nn.make_causal_mask(targets, dtype=dtype), + ) encoder_decoder_mask = nn.make_attention_mask( - targets > 0, inputs > 0, dtype=dtype) + targets > 0, inputs > 0, dtype=dtype + ) # Add segmentation block-diagonal attention masks if using segmented data. if inputs_segmentation is not None: decoder_mask = nn.combine_masks( decoder_mask, - nn.make_attention_mask(targets_segmentation, - targets_segmentation, - jnp.equal, - dtype=dtype)) + nn.make_attention_mask( + targets_segmentation, targets_segmentation, jnp.equal, dtype=dtype + ), + ) encoder_decoder_mask = nn.combine_masks( encoder_decoder_mask, - nn.make_attention_mask(targets_segmentation, - inputs_segmentation, - jnp.equal, - dtype=dtype)) + nn.make_attention_mask( + targets_segmentation, inputs_segmentation, jnp.equal, dtype=dtype + ), + ) logits = self.decoder( encoded, @@ -926,7 +969,8 @@ def decode(self, targets_positions=targets_positions, decoder_mask=decoder_mask, encoder_decoder_mask=encoder_decoder_mask, - train=train) + train=train, + ) return logits @@ -939,6 +983,7 @@ def __init__(self, hps, dataset_meta_data, loss_name, metrics_name): # TODO(ankugarg): Initialize cache for fast auto-regressive decoding here. # Also, initilaize tokenizer here to de-tokenize predicted logits # from beach search to target language sequence. + # pylint: disable=useless-super-delegation def evaluate_batch(self, params, batch_stats, batch, dynamic_context): @@ -946,7 +991,8 @@ def evaluate_batch(self, params, batch_stats, batch, dynamic_context): # replace the dynamic_context attribute across all layers in the model self.flax_module = dataclasses.replace( - self.flax_module, dynamic_context=dynamic_context) + self.flax_module, dynamic_context=dynamic_context + ) # TODO(ankugarg): Augment with other metrics like log-perplexity. logits = self.flax_module.apply( {'params': params, 'batch_stats': batch_stats}, @@ -956,7 +1002,8 @@ def evaluate_batch(self, params, batch_stats, batch, dynamic_context): targets_positions=batch.get('targets_positions'), inputs_segmentation=batch.get('inputs_segmentation'), targets_segmentation=batch.get('targets_segmentation'), - train=False) + train=False, + ) weights = batch.get('weights') targets = batch['targets'] @@ -983,22 +1030,26 @@ def apply_on_batch(self, params, batch_stats, batch, **apply_kwargs): batch.get('targets_positions'), batch.get('inputs_segmentation'), batch.get('targets_segmentation'), - **apply_kwargs) - - def training_cost(self, - params, - batch, - batch_stats=None, - dropout_rng=None, - dynamic_context=DynamicContext(), - teacher_params=None, - teacher_batch_stats=None, - teacher_model=None): + **apply_kwargs, + ) + + def training_cost( + self, + params, + batch, + batch_stats=None, + dropout_rng=None, + dynamic_context=DynamicContext(), + teacher_params=None, + teacher_batch_stats=None, + teacher_model=None, + ): """Return cross entropy loss with (optional) L2 penalty on the weights.""" # replace the dynamic_context attribute across all layers in the model self.flax_module = dataclasses.replace( - self.flax_module, dynamic_context=dynamic_context) + self.flax_module, dynamic_context=dynamic_context + ) # inputs/targets positions and segmentations are required when we have # packed examples. logits, new_batch_stats = self.apply_on_batch( @@ -1007,7 +1058,8 @@ def training_cost(self, batch, mutable=['batch_stats'], rngs={'dropout': dropout_rng}, - train=True) + train=True, + ) weights = batch.get('weights') targets = batch['targets'] @@ -1017,7 +1069,8 @@ def training_cost(self, # Optionally apply label smoothing. if self.hps.get('label_smoothing') is not None: targets = model_utils.apply_label_smoothing( - targets, self.hps.get('label_smoothing')) + targets, self.hps.get('label_smoothing') + ) if teacher_model is not None: # label smoothing is overwritten (always disabled) during distillation targets = teacher_model.flax_module.apply( @@ -1028,31 +1081,39 @@ def training_cost(self, batch.get('targets_positions'), batch.get('inputs_segmentation'), batch.get('targets_segmentation'), - train=False) + train=False, + ) targets = lax.stop_gradient(softmax(targets)) - (total_loss, total_weight) = self.loss_fn( - logits, targets, weights) + total_loss, total_weight = self.loss_fn(logits, targets, weights) - (total_loss, total_weight) = lax.psum( - (total_loss, total_weight), axis_name='batch') + total_loss, total_weight = lax.psum( + (total_loss, total_weight), axis_name='batch' + ) - total_loss = (total_loss / total_weight) + total_loss = total_loss / total_weight if self.hps.get('l2_decay_factor'): l2_loss = model_utils.l2_regularization( - params, self.hps.l2_decay_rank_threshold) + params, self.hps.l2_decay_rank_threshold + ) total_loss += 0.5 * self.hps.l2_decay_factor * l2_loss return total_loss, (new_batch_stats) def build_flax_module(self): - max_len = max(self.hps.max_target_length, self.hps.max_eval_target_length, - self.hps.max_predict_length) + max_len = max( + self.hps.max_target_length, + self.hps.max_eval_target_length, + self.hps.max_predict_length, + ) enc_self_attn_kernel_init_fn = model_utils.INITIALIZERS[ - self.hps.enc_self_attn_kernel_init]() + self.hps.enc_self_attn_kernel_init + ]() dec_self_attn_kernel_init_fn = model_utils.INITIALIZERS[ - self.hps.dec_self_attn_kernel_init]() + self.hps.dec_self_attn_kernel_init + ]() dec_cross_attn_kernel_init_fn = model_utils.INITIALIZERS[ - self.hps.dec_cross_attn_kernel_init]() + self.hps.dec_cross_attn_kernel_init + ]() use_bfloat16 = self.hps.model_dtype == 'bfloat16' return Transformer( diff --git a/init2winit/model_lib/xformer_translate_mlc_variant.py b/init2winit/model_lib/xformer_translate_mlc_variant.py index 6cea66c1..631c4d8f 100644 --- a/init2winit/model_lib/xformer_translate_mlc_variant.py +++ b/init2winit/model_lib/xformer_translate_mlc_variant.py @@ -24,10 +24,10 @@ 3. hps.residual_scale: Insted of residual being x + F(x), we set it to a * x + (1 - a) * F(x). """ + import functools from typing import Any, Callable, Optional, Sequence from absl import logging - from flax import linen as nn from init2winit import utils from init2winit.model_lib import attention @@ -67,7 +67,8 @@ ffn_activation='relu', residual_scale=1.0, attn_temp=1.0, - )) + ) +) class Scannable(nn.Module): @@ -78,6 +79,7 @@ class Scannable(nn.Module): input to a layer is of the form x, *others where x is changed by the layer and *others are extra arguments static throughout the layers. """ + build_fn: Callable[[], nn.Module] train: False @@ -88,14 +90,14 @@ def __call__(self, x): """Applies the Module to inputs. Args: - x: the inputs to the module. It is assumed to be a tuple of pytrees. - The first element of the tuple is mapped by self.block into an output - of the same structure (e.g. the Decoder activations fed to the - Encoder-Decoder Multi-Head-Attention). - The other elements are static arguments used by self.block that would - stay the same if we apply multiple self.block's one after the other - (e.g. the encoder output used by the Encoder-Decoder + x: the inputs to the module. It is assumed to be a tuple of pytrees. The + first element of the tuple is mapped by self.block into an output of the + same structure (e.g. the Decoder activations fed to the Encoder-Decoder + Multi-Head-Attention). The other elements are static arguments used by + self.block that would stay the same if we apply multiple self.block's + one after the other (e.g. the encoder output used by the Encoder-Decoder Multi-Head-Attention). + Returns: self.block(x[0], *x[1:]), *x[1:]. """ @@ -117,7 +119,8 @@ def shift_right(x, axis=1): pad_widths = [(0, 0)] * len(x.shape) pad_widths[axis] = (1, 0) padded = jnp.pad( - x, pad_widths, mode='constant', constant_values=x.dtype.type(0)) + x, pad_widths, mode='constant', constant_values=x.dtype.type(0) + ) return padded[:, :-1] @@ -141,8 +144,8 @@ def init(key, shape, dtype=np.float32): position = np.arange(0, max_len)[:, np.newaxis] scale_factor = -np.log(max_scale / min_scale) / (d_feature // 2 - 1) div_term = min_scale * np.exp(np.arange(0, d_feature // 2) * scale_factor) - pe[:, :d_feature // 2] = np.sin(position * div_term) - pe[:, d_feature // 2:2 * (d_feature // 2)] = np.cos(position * div_term) + pe[:, : d_feature // 2] = np.sin(position * div_term) + pe[:, d_feature // 2 : 2 * (d_feature // 2)] = np.cos(position * div_term) pe = pe[np.newaxis, :, :] # [1, max_len, d_feature] return jnp.array(pe) @@ -157,16 +160,15 @@ class AddPositionEmbs(nn.Module): (non-learned) sinusoidal embedding table. decode: whether to use an autoregressive cache. """ + max_len: int = 512 posemb_init: Optional[model_utils.Initializer] = None decode: bool = False @nn.compact - def __call__(self, - inputs, - inputs_position=None, - train=True, - dtype=np.float32): + def __call__( + self, inputs, inputs_position=None, train=True, dtype=np.float32 + ): """Applies AddPositionEmbs module. By default this layer uses a fixed sinusoidal embedding table. If a @@ -184,22 +186,26 @@ def __call__(self, """ del train # inputs.shape is (batch_size, seq_len, emb_dim) - assert inputs.ndim == 3, ('Number of dimensions should be 3,' - ' but it is: %d' % inputs.ndim) + assert inputs.ndim == 3, ( + 'Number of dimensions should be 3, but it is: %d' % inputs.ndim + ) length = inputs.shape[1] pos_emb_shape = (1, self.max_len, inputs.shape[-1]) if self.posemb_init is None: # Use a fixed (non-learned) sinusoidal position embedding. - pos_embedding = sinusoidal_init(max_len=self.max_len)(None, pos_emb_shape, - dtype) + pos_embedding = sinusoidal_init(max_len=self.max_len)( + None, pos_emb_shape, dtype + ) else: pos_embedding = self.param( - 'pos_embedding', self.posemb_init, pos_emb_shape, dtype) + 'pos_embedding', self.posemb_init, pos_emb_shape, dtype + ) pe = pos_embedding[:, :length, :] if self.decode: is_initialized = self.has_variable('cache', 'cache_index') - cache_index = self.variable('cache', 'cache_index', - lambda: jnp.array(0, dtype=jnp.uint32)) + cache_index = self.variable( + 'cache', 'cache_index', lambda: jnp.array(0, dtype=jnp.uint32) + ) if is_initialized: i = cache_index.value cache_index.value = i + 1 @@ -215,6 +221,7 @@ def __call__(self, class MlpBlock(nn.Module): """Transformer MLP / feed-forward block.""" + mlp_dim: int dtype: model_utils.Dtype = jnp.float32 out_dim: Optional[int] = None @@ -233,8 +240,8 @@ def __call__(self, inputs, train): dtype=self.dtype, param_dtype=self.dtype, kernel_init=self.kernel_init, - bias_init=self.bias_init)( - inputs) + bias_init=self.bias_init, + )(inputs) activation_fn = model_utils.ACTIVATIONS[self.ffn_activation] x = activation_fn(x) @@ -244,8 +251,8 @@ def __call__(self, inputs, train): dtype=self.dtype, param_dtype=self.dtype, kernel_init=self.kernel_init, - bias_init=self.bias_init)( - inputs) + bias_init=self.bias_init, + )(inputs) x = x * y x = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(x) @@ -254,10 +261,9 @@ def __call__(self, inputs, train): dtype=self.dtype, param_dtype=self.dtype, kernel_init=self.kernel_init, - bias_init=self.bias_init)( - x) - output = nn.Dropout( - rate=self.dropout_rate, deterministic=not train)(output) + bias_init=self.bias_init, + )(x) + output = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(output) return output @@ -274,10 +280,11 @@ class Encoder1DBlock(nn.Module): normalizer: One of 'batch_norm', 'layer_norm', 'post_layer_norm', 'pre_layer_norm', 'none' normalize_attention: Apply LayerNorm to query and key before computing - dot_product_attention. - enc_self_attn_kernel_init_fn: initializer for encoder's - self attention matrices. + dot_product_attention. + enc_self_attn_kernel_init_fn: initializer for encoder's self attention + matrices. """ + qkv_dim: int mlp_dim: int num_heads: int @@ -293,10 +300,7 @@ class Encoder1DBlock(nn.Module): attn_temp: float = 1.0 @nn.compact - def __call__(self, - inputs, - encoder_mask=None, - train=True): + def __call__(self, inputs, encoder_mask=None, train=True): """Applies Encoder1DBlock module. Args: @@ -310,16 +314,24 @@ def __call__(self, # Attention block. assert inputs.ndim == 3 if self.normalizer in [ - 'batch_norm', 'layer_norm', 'pre_layer_norm', 'none']: + 'batch_norm', + 'layer_norm', + 'pre_layer_norm', + 'none', + ]: maybe_pre_normalize = model_utils.get_normalizer( - self.normalizer, train, dtype=self.dtype) + self.normalizer, train, dtype=self.dtype + ) maybe_post_normalize = model_utils.get_normalizer( - 'none', train, dtype=self.dtype) + 'none', train, dtype=self.dtype + ) elif self.normalizer == 'post_layer_norm': maybe_pre_normalize = model_utils.get_normalizer( - 'none', train, dtype=self.dtype) + 'none', train, dtype=self.dtype + ) maybe_post_normalize = model_utils.get_normalizer( - self.normalizer, train, dtype=self.dtype) + self.normalizer, train, dtype=self.dtype + ) else: raise ValueError('Unsupported normalizer: {}'.format(self.normalizer)) @@ -336,8 +348,8 @@ def __call__(self, dropout_rate=self.attention_dropout_rate, normalize_attention=self.normalize_attention, name='EncoderSelfAttention', - attn_temp=self.attn_temp)( - x, mask=encoder_mask, deterministic=not train) + attn_temp=self.attn_temp, + )(x, mask=encoder_mask, deterministic=not train) x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not train) x = x * self.residual_scale + inputs @@ -352,7 +364,7 @@ def __call__(self, name='MLPBlock', ffn_activation=self.ffn_activation, glu=self.glu, - )(y, train=train) + )(y, train=train) res = x + y * self.residual_scale return maybe_post_normalize(param_dtype=self.dtype)(res) @@ -371,13 +383,14 @@ class EncoderDecoder1DBlock(nn.Module): normalizer: One of 'batch_norm', 'layer_norm', 'post_layer_norm', 'pre_layer_norm', 'none' normalize_attention: Apply LayerNorm to query and key before computing - dot_product_attention. - dec_self_attn_kernel_init_fn: initializer for decoder's - self attention matrices. - dec_cross_attn_kernel_init_fn: initializer for decoder's - cross attention matrices. + dot_product_attention. + dec_self_attn_kernel_init_fn: initializer for decoder's self attention + matrices. + dec_cross_attn_kernel_init_fn: initializer for decoder's cross attention + matrices. decode: whether to use an autoregressive cache. """ + qkv_dim: int mlp_dim: int num_heads: int @@ -395,12 +408,14 @@ class EncoderDecoder1DBlock(nn.Module): attn_temp: float = 1.0 @nn.compact - def __call__(self, - targets, - encoded, - decoder_mask=None, - encoder_decoder_mask=None, - train=True): + def __call__( + self, + targets, + encoded, + decoder_mask=None, + encoder_decoder_mask=None, + train=True, + ): """Applies EncoderDecoder1DBlock module. Args: @@ -416,16 +431,24 @@ def __call__(self, # Decoder block. assert targets.ndim == 3 if self.normalizer in [ - 'batch_norm', 'layer_norm', 'pre_layer_norm', 'none']: + 'batch_norm', + 'layer_norm', + 'pre_layer_norm', + 'none', + ]: maybe_pre_normalize = model_utils.get_normalizer( - self.normalizer, train, dtype=self.dtype) + self.normalizer, train, dtype=self.dtype + ) maybe_post_normalize = model_utils.get_normalizer( - 'none', train, dtype=self.dtype) + 'none', train, dtype=self.dtype + ) elif self.normalizer == 'post_layer_norm': maybe_pre_normalize = model_utils.get_normalizer( - 'none', train, dtype=self.dtype) + 'none', train, dtype=self.dtype + ) maybe_post_normalize = model_utils.get_normalizer( - self.normalizer, train, dtype=self.dtype) + self.normalizer, train, dtype=self.dtype + ) else: raise ValueError('Unsupported normalizer: {}'.format(self.normalizer)) @@ -443,8 +466,8 @@ def __call__(self, decode=self.decode, name='DecoderSelfAttention', normalize_attention=self.normalize_attention, - attn_temp=self.attn_temp)( - x, decoder_mask, deterministic=not train) + attn_temp=self.attn_temp, + )(x, decoder_mask, deterministic=not train) x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not train) x = x * self.residual_scale + targets @@ -462,11 +485,10 @@ def __call__(self, broadcast_dropout=False, dropout_rate=self.attention_dropout_rate, normalize_attention=self.normalize_attention, - attn_temp=self.attn_temp)( - y, encoded, encoder_decoder_mask, deterministic=not train) + attn_temp=self.attn_temp, + )(y, encoded, encoder_decoder_mask, deterministic=not train) - y = nn.Dropout(rate=self.dropout_rate)( - y, deterministic=not train) + y = nn.Dropout(rate=self.dropout_rate)(y, deterministic=not train) y = y * self.residual_scale + x y = maybe_post_normalize(param_dtype=self.dtype)(y) @@ -479,7 +501,7 @@ def __call__(self, name='MLPBlock', ffn_activation=self.ffn_activation, glu=self.glu, - )(z, train=train) + )(z, train=train) res = y + z * self.residual_scale return maybe_post_normalize(param_dtype=self.dtype)(res) @@ -488,27 +510,28 @@ def __call__(self, class Encoder(nn.Module): """Transformer Model Encoder for sequence to sequence translation. - vocab_size: size of the vocabulary - shared_embedding: a shared embedding layer to use. - dtype: the jnp.dtype for the model parameters. - emb_dim: dimension of embedding - num_heads: number of heads - enc_num_layers: number of layers. It is ignored if enc_remat_scan_lengths - is not None. - qkv_dim: dimension of the query/key/value - mlp_dim: dimension of the mlp on top of attention block - max_len: maximum length. - dropout_rate: dropout rate - normalizer: One of 'batch_norm', 'layer_norm', 'none' - normalize_attention: Apply LayerNorm to query and key before computing - dot_product_attention. - attention_dropout_rate: dropout rate for attention weights - enc_self_attn_kernel_init_fn: initializer for encoder's - self attention matrices. - enc_remat_scan_lengths: if not None, it is the sequence of lengths to use - for remat_scan. See flax.linen.remat_scan; in this case this - defines the total number of layers, not enc_num_layers. + vocab_size: size of the vocabulary + shared_embedding: a shared embedding layer to use. + dtype: the jnp.dtype for the model parameters. + emb_dim: dimension of embedding + num_heads: number of heads + enc_num_layers: number of layers. It is ignored if enc_remat_scan_lengths + is not None. + qkv_dim: dimension of the query/key/value + mlp_dim: dimension of the mlp on top of attention block + max_len: maximum length. + dropout_rate: dropout rate + normalizer: One of 'batch_norm', 'layer_norm', 'none' + normalize_attention: Apply LayerNorm to query and key before computing + dot_product_attention. + attention_dropout_rate: dropout rate for attention weights + enc_self_attn_kernel_init_fn: initializer for encoder's + self attention matrices. + enc_remat_scan_lengths: if not None, it is the sequence of lengths to use + for remat_scan. See flax.linen.remat_scan; in this case this + defines the total number of layers, not enc_num_layers. """ + vocab_size: int shared_embedding: Any = None dtype: jnp.dtype = jnp.float32 @@ -530,11 +553,9 @@ class Encoder(nn.Module): attn_temp: float = 1.0 @nn.compact - def __call__(self, - inputs, - inputs_position=None, - encoder_mask=None, - train=True): + def __call__( + self, inputs, inputs_position=None, encoder_mask=None, train=True + ): """Applies Transformer model on the inputs. Args: @@ -556,14 +577,15 @@ def __call__(self, param_dtype=self.dtype, features=self.emb_dim, embedding_init=nn.initializers.normal(stddev=1.0), - name='input_vocab_embeddings') + name='input_vocab_embeddings', + ) else: input_embed = self.shared_embedding x = inputs.astype('int32') x = input_embed(x) x = AddPositionEmbs( - max_len=self.max_len, decode=False, name='posembed_input')( - x, inputs_position=inputs_position, train=train, dtype=self.dtype) + max_len=self.max_len, decode=False, name='posembed_input' + )(x, inputs_position=inputs_position, train=train, dtype=self.dtype) x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not train) # Input encoder. @@ -582,22 +604,25 @@ def __call__(self, glu=self.glu, residual_scale=self.residual_scale, attn_temp=self.attn_temp, - ) + ) if self.enc_remat_scan_lengths is None: for lyr in range(self.enc_num_layers): x = build_fn(name=f'encoderblock_{lyr}')( - x, encoder_mask=encoder_mask, train=train) + x, encoder_mask=encoder_mask, train=train + ) else: - logging.info('Using Remat Scan, ignoring enc_num_layers; ' - 'number of layers=%d', np.prod(self.enc_remat_scan_lengths)) - enc_stack = nn.remat_scan( - Scannable, lengths=self.enc_remat_scan_lengths)(build_fn=build_fn, - train=train, - name='EncoderStack') + logging.info( + 'Using Remat Scan, ignoring enc_num_layers; number of layers=%d', + np.prod(self.enc_remat_scan_lengths), + ) + enc_stack = nn.remat_scan(Scannable, lengths=self.enc_remat_scan_lengths)( + build_fn=build_fn, train=train, name='EncoderStack' + ) x = enc_stack((x, encoder_mask))[0] if self.normalizer in ['batch_norm', 'layer_norm', 'pre_layer_norm']: maybe_normalize = model_utils.get_normalizer( - self.normalizer, train, dtype=self.dtype) + self.normalizer, train, dtype=self.dtype + ) x = maybe_normalize(param_dtype=self.dtype)(x) return x @@ -605,33 +630,34 @@ def __call__(self, class Decoder(nn.Module): """Transformer Model Decoder for sequence to sequence translation. - output_vocab_size: size of the vocabulary. - shared_embedding: a shared embedding layer to use. - logits_via_embedding: bool: whether final logit transform shares embedding - weights. - dtype: the jnp.dtype for the model parameters. - emb_dim: dimension of embedding. - num_heads: number of heads. - dec_num_layers: number of layers. It is ignored if dec_remat_scan_lengths - is not None. - qkv_dim: dimension of the query/key/value. - mlp_dim: dimension of the mlp on top of attention block. - max_len: maximum length. - decode: whether to use an autoregressive cache. - dropout_rate: dropout rate. - normalizer: One of 'batch_norm', 'layer_norm', 'post_layer_norm', - 'pre_layer_norm', 'none' - normalize_attention: Apply LayerNorm to query and key before computing - dot_product_attention. - attention_dropout_rate: dropout rate for attention weights. - dec_self_attn_kernel_init_fn: initializer for decoder's - self attention matrices. - dec_cross_attn_kernel_init_fn: initializer for decoder's - cross attention matrices. - dec_remat_scan_lengths: if not None, it is the sequence of lengths to use - for remat_scan. See flax.linen.remat_scan; in this case this - defines the total number of layers, not dec_num_layers. + output_vocab_size: size of the vocabulary. + shared_embedding: a shared embedding layer to use. + logits_via_embedding: bool: whether final logit transform shares embedding + weights. + dtype: the jnp.dtype for the model parameters. + emb_dim: dimension of embedding. + num_heads: number of heads. + dec_num_layers: number of layers. It is ignored if dec_remat_scan_lengths + is not None. + qkv_dim: dimension of the query/key/value. + mlp_dim: dimension of the mlp on top of attention block. + max_len: maximum length. + decode: whether to use an autoregressive cache. + dropout_rate: dropout rate. + normalizer: One of 'batch_norm', 'layer_norm', 'post_layer_norm', + 'pre_layer_norm', 'none' + normalize_attention: Apply LayerNorm to query and key before computing + dot_product_attention. + attention_dropout_rate: dropout rate for attention weights. + dec_self_attn_kernel_init_fn: initializer for decoder's + self attention matrices. + dec_cross_attn_kernel_init_fn: initializer for decoder's + cross attention matrices. + dec_remat_scan_lengths: if not None, it is the sequence of lengths to use + for remat_scan. See flax.linen.remat_scan; in this case this + defines the total number of layers, not dec_num_layers. """ + output_vocab_size: int shared_embedding: Any = None logits_via_embedding: bool = False @@ -656,13 +682,15 @@ class Decoder(nn.Module): ffn_activation: str = 'relu' @nn.compact - def __call__(self, - encoded, - targets, - targets_position=None, - decoder_mask=None, - encoder_decoder_mask=None, - train=True): + def __call__( + self, + encoded, + targets, + targets_position=None, + decoder_mask=None, + encoder_decoder_mask=None, + train=True, + ): """Applies Transformer model on the inputs. Args: @@ -671,7 +699,6 @@ def __call__(self, targets_position: input subsequence positions for packed examples. decoder_mask: decoder self-attention mask. encoder_decoder_mask: encoder-decoder attention mask. - train: whether it is training. Returns: @@ -688,7 +715,8 @@ def __call__(self, dtype=self.dtype, param_dtype=self.dtype, embedding_init=nn.initializers.normal(stddev=1.0), - name='output_vocab_embeddings') + name='output_vocab_embeddings', + ) else: output_embed = self.shared_embedding @@ -697,11 +725,8 @@ def __call__(self, y = shift_right(y) y = output_embed(y) y = AddPositionEmbs( - max_len=self.max_len, decode=self.decode, name='posembed_output')( - y, - inputs_position=targets_position, - train=train, - dtype=self.dtype) + max_len=self.max_len, decode=self.decode, name='posembed_output' + )(y, inputs_position=targets_position, train=train, dtype=self.dtype) y = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(y) # Target-Input Decoder @@ -721,7 +746,7 @@ def __call__(self, glu=self.glu, residual_scale=self.residual_scale, attn_temp=self.attn_temp, - ) + ) if self.dec_remat_scan_lengths is None: for lyr in range(self.dec_num_layers): @@ -730,20 +755,23 @@ def __call__(self, encoded, decoder_mask=decoder_mask, encoder_decoder_mask=encoder_decoder_mask, - train=train) + train=train, + ) else: - logging.info('Using Remat Scan, ignoring enc_num_layers; ' - 'number of layers=%d', np.prod(self.dec_remat_scan_lengths)) - dec_stack = nn.remat_scan( - Scannable, lengths=self.dec_remat_scan_lengths)(build_fn=build_fn, - train=train, - name='DecoderStack') + logging.info( + 'Using Remat Scan, ignoring enc_num_layers; number of layers=%d', + np.prod(self.dec_remat_scan_lengths), + ) + dec_stack = nn.remat_scan(Scannable, lengths=self.dec_remat_scan_lengths)( + build_fn=build_fn, train=train, name='DecoderStack' + ) if decoder_mask is not None: decoder_mask = decoder_mask.astype(self.dtype) y = dec_stack((y, encoded, decoder_mask, encoder_decoder_mask))[0] if self.normalizer in ['batch_norm', 'layer_norm', 'pre_layer_norm']: maybe_normalize = model_utils.get_normalizer( - self.normalizer, train, dtype=self.dtype) + self.normalizer, train, dtype=self.dtype + ) y = maybe_normalize(param_dtype=self.dtype)(y) # Decoded Logits @@ -760,8 +788,8 @@ def __call__(self, param_dtype=self.dtype, kernel_init=nn.initializers.xavier_uniform(), bias_init=nn.initializers.normal(stddev=1e-6), - name='logitdense')( - y) + name='logitdense', + )(y) return logits @@ -774,37 +802,38 @@ def __call__(self, class Transformer(nn.Module): """Transformer Model for sequence to sequence translation. - vocab_size: size of the input vocabulary. - output_vocab_size: size of the output vocabulary. If None, the output - vocabulary size is assumed to be the same as vocab_size. - share_embeddings: bool: share embedding layer for inputs and targets. - logits_via_embedding: bool: whether final logit transform shares embedding - weights. - dtype: the jnp.dtype for the model parameters. - emb_dim: dimension of embedding. - num_heads: number of heads. - enc_num_layers: number of encoder layers. - enc_remat_scan_lengths: Optional sequence of lengths to use with - flax.linen.remat_scan. - dec_num_layers: number of decoder layers. - dec_remat_scan_lengths: Optional sequence of lengths to use with - flax.linen.remat_scan. - qkv_dim: dimension of the query/key/value. - mlp_dim: dimension of the mlp on top of attention block. - max_len: maximum length. - dropout_rate: dropout rate. - attention_dropout_rate: dropout rate for attention weights. - normalizer: One of 'batch_norm', 'layer_norm', 'none' - normalize_attention: Apply LayerNorm to query and key before computing - dot_product_attention. - enc_self_attn_kernel_init_fn: initializer for encoder's - self attention matrices. - dec_self_attn_kernel_init_fn: initializer for decoder's - self attention matrices. - dec_cross_attn_kernel_init_fn: initializer for decoder's - cross attention matrices. - decode: whether to use an autoregressive cache. + vocab_size: size of the input vocabulary. + output_vocab_size: size of the output vocabulary. If None, the output + vocabulary size is assumed to be the same as vocab_size. + share_embeddings: bool: share embedding layer for inputs and targets. + logits_via_embedding: bool: whether final logit transform shares embedding + weights. + dtype: the jnp.dtype for the model parameters. + emb_dim: dimension of embedding. + num_heads: number of heads. + enc_num_layers: number of encoder layers. + enc_remat_scan_lengths: Optional sequence of lengths to use with + flax.linen.remat_scan. + dec_num_layers: number of decoder layers. + dec_remat_scan_lengths: Optional sequence of lengths to use with + flax.linen.remat_scan. + qkv_dim: dimension of the query/key/value. + mlp_dim: dimension of the mlp on top of attention block. + max_len: maximum length. + dropout_rate: dropout rate. + attention_dropout_rate: dropout rate for attention weights. + normalizer: One of 'batch_norm', 'layer_norm', 'none' + normalize_attention: Apply LayerNorm to query and key before computing + dot_product_attention. + enc_self_attn_kernel_init_fn: initializer for encoder's + self attention matrices. + dec_self_attn_kernel_init_fn: initializer for decoder's + self attention matrices. + dec_cross_attn_kernel_init_fn: initializer for decoder's + cross attention matrices. + decode: whether to use an autoregressive cache. """ + vocab_size: Optional[int] = None output_vocab_size: Optional[int] = None share_embeddings: bool = False @@ -834,26 +863,32 @@ class Transformer(nn.Module): def setup(self): if self.enc_num_layers and self.enc_remat_scan_lengths: - raise ValueError(f'Only one of enc_num_layers ({self.enc_num_layers})' - 'and enc_remat_scan_lengths' - f'({self.enc_remat_scan_lengths}) can be set.') + raise ValueError( + f'Only one of enc_num_layers ({self.enc_num_layers})' + 'and enc_remat_scan_lengths' + f'({self.enc_remat_scan_lengths}) can be set.' + ) if self.dec_num_layers and self.dec_remat_scan_lengths: - raise ValueError(f'Only one of dec_num_layers ({self.dec_num_layers})' - 'and dec_remat_scan_lengths' - f'({self.dec_remat_scan_lengths}) can be set.') + raise ValueError( + f'Only one of dec_num_layers ({self.dec_num_layers})' + 'and dec_remat_scan_lengths' + f'({self.dec_remat_scan_lengths}) can be set.' + ) if self.share_embeddings: if self.output_vocab_size is not None: - assert self.output_vocab_size == self.vocab_size, ( - "can't share embedding with different vocab sizes.") + assert ( + self.output_vocab_size == self.vocab_size + ), "can't share embedding with different vocab sizes." self.shared_embedding = nn.Embed( num_embeddings=self.vocab_size, features=self.emb_dim, dtype=self.dtype, param_dtype=self.dtype, embedding_init=nn.initializers.normal(stddev=1.0), - name='VocabEmbeddings') + name='VocabEmbeddings', + ) else: self.shared_embedding = None @@ -878,7 +913,7 @@ def setup(self): glu=self.glu, residual_scale=self.residual_scale, attn_temp=self.attn_temp, - ) + ) self.decoder = Decoder( output_vocab_size=self.output_vocab_size, shared_embedding=self.shared_embedding, @@ -903,17 +938,19 @@ def setup(self): glu=self.glu, residual_scale=self.residual_scale, attn_temp=self.attn_temp, - ) + ) @nn.compact - def __call__(self, - inputs, - targets, - inputs_position=None, - targets_position=None, - inputs_segmentation=None, - targets_segmentation=None, - train=False): + def __call__( + self, + inputs, + targets, + inputs_position=None, + targets_position=None, + inputs_segmentation=None, + targets_segmentation=None, + train=False, + ): """Applies Transformer model on the inputs. Args: @@ -928,18 +965,22 @@ def __call__(self, Returns: Output: [batch_size, target_sequence_length, qkv_dim] """ - encoded = self.encode(inputs, - inputs_position=inputs_position, - inputs_segmentation=inputs_segmentation, - train=train) - - logits = self.decode(encoded, - inputs, # only used for masks - targets, - targets_position=targets_position, - inputs_segmentation=inputs_segmentation, - targets_segmentation=targets_segmentation, - train=train) + encoded = self.encode( + inputs, + inputs_position=inputs_position, + inputs_segmentation=inputs_segmentation, + train=train, + ) + + logits = self.decode( + encoded, + inputs, # only used for masks + targets, + targets_position=targets_position, + inputs_segmentation=inputs_segmentation, + targets_segmentation=targets_segmentation, + train=train, + ) return logits # The following two methods allow us to run the trained Transformer in @@ -948,39 +989,39 @@ def __call__(self, # cache object for iteratively storing keys and values during the decoding # process. - def encode(self, - inputs, - inputs_position=None, - inputs_segmentation=None, - train=False): + def encode( + self, inputs, inputs_position=None, inputs_segmentation=None, train=False + ): # Make padding attention mask. dtype = self.dtype - encoder_mask = nn.make_attention_mask( - inputs > 0, inputs > 0, dtype=dtype) + encoder_mask = nn.make_attention_mask(inputs > 0, inputs > 0, dtype=dtype) # Add segmentation block-diagonal attention mask if using segmented data. if inputs_segmentation is not None: encoder_mask = nn.combine_masks( encoder_mask, - nn.make_attention_mask(inputs_segmentation, - inputs_segmentation, - jnp.equal, - dtype=dtype)) + nn.make_attention_mask( + inputs_segmentation, inputs_segmentation, jnp.equal, dtype=dtype + ), + ) encoded = self.encoder( inputs, inputs_position=inputs_position, encoder_mask=encoder_mask, - train=train) + train=train, + ) return encoded - def decode(self, - encoded, - inputs, - targets, - targets_position=None, - inputs_segmentation=None, - targets_segmentation=None, - train=False): + def decode( + self, + encoded, + inputs, + targets, + targets_position=None, + inputs_segmentation=None, + targets_segmentation=None, + train=False, + ): # Make padding attention masks. dtype = self.dtype if self.should_decode: @@ -988,28 +1029,31 @@ def decode(self, # used. decoder_mask = None encoder_decoder_mask = nn.make_attention_mask( - jnp.ones_like(targets) > 0, inputs > 0, dtype=dtype) + jnp.ones_like(targets) > 0, inputs > 0, dtype=dtype + ) else: decoder_mask = nn.combine_masks( nn.make_attention_mask(targets > 0, targets > 0, dtype=dtype), - nn.make_causal_mask(targets, dtype=dtype)) + nn.make_causal_mask(targets, dtype=dtype), + ) encoder_decoder_mask = nn.make_attention_mask( - targets > 0, inputs > 0, dtype=dtype) + targets > 0, inputs > 0, dtype=dtype + ) # Add segmentation block-diagonal attention masks if using segmented data. if inputs_segmentation is not None: decoder_mask = nn.combine_masks( decoder_mask, - nn.make_attention_mask(targets_segmentation, - targets_segmentation, - jnp.equal, - dtype=dtype)) + nn.make_attention_mask( + targets_segmentation, targets_segmentation, jnp.equal, dtype=dtype + ), + ) encoder_decoder_mask = nn.combine_masks( encoder_decoder_mask, - nn.make_attention_mask(targets_segmentation, - inputs_segmentation, - jnp.equal, - dtype=dtype)) + nn.make_attention_mask( + targets_segmentation, inputs_segmentation, jnp.equal, dtype=dtype + ), + ) logits = self.decoder( encoded, @@ -1017,7 +1061,8 @@ def decode(self, targets_position=targets_position, decoder_mask=decoder_mask, encoder_decoder_mask=encoder_decoder_mask, - train=train) + train=train, + ) return logits @@ -1037,14 +1082,12 @@ def evaluate_batch(self, params, batch_stats, batch): # Add log-perplexity metric. return self.metrics_bundle.single_from_model_output( - logits=logits, targets=targets, weights=weights) - - def apply_on_batch(self, - params, - batch_stats, - batch, - train=True, - **apply_kwargs): + logits=logits, targets=targets, weights=weights + ) + + def apply_on_batch( + self, params, batch_stats, batch, train=True, **apply_kwargs + ): """Wrapper around flax_module.apply.""" variables = {'params': params} if batch_stats is not None: @@ -1078,7 +1121,8 @@ def training_cost(self, params, batch, batch_stats=None, dropout_rng=None): batch, mutable=['batch_stats'], rngs={'dropout': dropout_rng}, - train=True) + train=True, + ) weights = batch.get('weights') targets = batch['targets'] @@ -1088,30 +1132,38 @@ def training_cost(self, params, batch, batch_stats=None, dropout_rng=None): # Optionally apply label smoothing. if self.hps.get('label_smoothing') is not None: targets = model_utils.apply_label_smoothing( - targets, self.hps.get('label_smoothing')) - (total_loss, total_weight) = self.loss_fn( - logits, targets, weights) + targets, self.hps.get('label_smoothing') + ) + total_loss, total_weight = self.loss_fn(logits, targets, weights) - (total_loss, total_weight) = lax.psum( - (total_loss, total_weight), axis_name='batch') + total_loss, total_weight = lax.psum( + (total_loss, total_weight), axis_name='batch' + ) - total_loss = (total_loss / total_weight) + total_loss = total_loss / total_weight if self.hps.get('l2_decay_factor'): l2_loss = model_utils.l2_regularization( - params, self.hps.l2_decay_rank_threshold) + params, self.hps.l2_decay_rank_threshold + ) total_loss += 0.5 * self.hps.l2_decay_factor * l2_loss return total_loss, (new_batch_stats) def build_flax_module(self): - max_len = max(self.hps.max_target_length, self.hps.max_eval_target_length, - self.hps.max_predict_length) + max_len = max( + self.hps.max_target_length, + self.hps.max_eval_target_length, + self.hps.max_predict_length, + ) enc_self_attn_kernel_init_fn = model_utils.INITIALIZERS[ - self.hps.enc_self_attn_kernel_init]() + self.hps.enc_self_attn_kernel_init + ]() dec_self_attn_kernel_init_fn = model_utils.INITIALIZERS[ - self.hps.dec_self_attn_kernel_init]() + self.hps.dec_self_attn_kernel_init + ]() dec_cross_attn_kernel_init_fn = model_utils.INITIALIZERS[ - self.hps.dec_cross_attn_kernel_init]() + self.hps.dec_cross_attn_kernel_init + ]() dtype = utils.dtype_from_str(self.hps.model_dtype) aux_dropout_rate = ( self.hps.dropout_rate diff --git a/init2winit/mt_eval/decode.py b/init2winit/mt_eval/decode.py index 34531674..4a6254d9 100644 --- a/init2winit/mt_eval/decode.py +++ b/init2winit/mt_eval/decode.py @@ -102,7 +102,7 @@ def unflatten_beam_dim(x, batch_size, beam_size, offset: int = 0): return x assert batch_size * beam_size == x.shape[offset] xshape = list(x.shape) - newshape = xshape[:offset] + [batch_size, beam_size] + xshape[offset + 1:] + newshape = xshape[:offset] + [batch_size, beam_size] + xshape[offset + 1 :] return x.reshape(newshape) @@ -111,8 +111,9 @@ def flat_batch_beam_expand(x, beam_size, offset: int = 0): return flatten_beam_dim(add_beam_dim(x, beam_size, offset), offset) -def gather_beams(nested, beam_indices, batch_size, new_beam_size, - offset: int = 0): +def gather_beams( + nested, beam_indices, batch_size, new_beam_size, offset: int = 0 +): """Gathers the beam slices indexed by beam_indices into new beam array. Args: @@ -128,8 +129,10 @@ def gather_beams(nested, beam_indices, batch_size, new_beam_size, """ batch_indices = jnp.reshape( jnp.arange(batch_size * new_beam_size) // new_beam_size, - (batch_size, new_beam_size)) + (batch_size, new_beam_size), + ) assert offset < 4, 'scan_over_layers_offset >= 4 is not supported' + def gather_fn(x): if is_scalar(x, offset): return x @@ -143,6 +146,7 @@ def gather_fn(x): return x[:, :, batch_indices, beam_indices] else: return x[:, :, :, batch_indices, beam_indices] + return jax.tree.map(gather_fn, nested) @@ -169,6 +173,7 @@ def gather_topk_beams(nested, score_or_log_prob, batch_size, new_beam_size): @flax.struct.dataclass class BeamState: """Holds beam search state data.""" + # The position of the decoding loop in the length dimension. cur_index: jax.Array # scalar int32: current decoded length index # The active sequence log probabilities and finished sequence scores. @@ -188,6 +193,7 @@ class BeamState: @flax.struct.dataclass class SamplingState: """Holds sampling state data.""" + # The position of the decoding loop in the length dimension. cur_index: jax.Array # scalar int32: current decoded length index # The active sequence probabilities and finished sequence scores. @@ -202,45 +208,43 @@ class SamplingState: cache: Any # Any pytree of arrays, e.g. flax attention Cache object -def beam_init(batch_size, - beam_size, - max_decode_len, - cache, - offset: int = 0): +def beam_init(batch_size, beam_size, max_decode_len, cache, offset: int = 0): """Initializes the beam search state data structure.""" cur_index0 = jnp.array(0) live_logprobs0 = jnp.tile( - jnp.array([0.0] + [NEG_INF] * (beam_size - 1)), - [batch_size, 1]) + jnp.array([0.0] + [NEG_INF] * (beam_size - 1)), [batch_size, 1] + ) finished_scores0 = jnp.ones((batch_size, beam_size)) * NEG_INF - live_seqs0 = jnp.zeros( - (batch_size, beam_size, max_decode_len), jnp.int32) - finished_seqs0 = jnp.zeros( - (batch_size, beam_size, max_decode_len), jnp.int32) + live_seqs0 = jnp.zeros((batch_size, beam_size, max_decode_len), jnp.int32) + finished_seqs0 = jnp.zeros((batch_size, beam_size, max_decode_len), jnp.int32) finished_flags0 = jnp.zeros((batch_size, beam_size), jnp.bool_) # add beam dimension to attention cache pytree elements - beam_cache0 = jax.tree.map(lambda x: add_beam_dim(x, beam_size, offset), - cache) - return BeamState(cur_index=cur_index0, - live_logprobs=live_logprobs0, - finished_scores=finished_scores0, - live_seqs=live_seqs0, - finished_seqs=finished_seqs0, - finished_flags=finished_flags0, - cache=beam_cache0) - - -def sampling_init(batch_size: int, - sample_size: int, - max_decode_len: int, - cache): + beam_cache0 = jax.tree.map( + lambda x: add_beam_dim(x, beam_size, offset), cache + ) + return BeamState( + cur_index=cur_index0, + live_logprobs=live_logprobs0, + finished_scores=finished_scores0, + live_seqs=live_seqs0, + finished_seqs=finished_seqs0, + finished_flags=finished_flags0, + cache=beam_cache0, + ) + + +def sampling_init( + batch_size: int, sample_size: int, max_decode_len: int, cache +): """Initializes the sampling state data structure.""" cur_index0 = jnp.array(0) all_log_probs0 = jnp.tile( - jnp.array([0.0] + [NEG_INF] * (sample_size - 1)), [batch_size, 1]) + jnp.array([0.0] + [NEG_INF] * (sample_size - 1)), [batch_size, 1] + ) all_seqs0 = jnp.zeros((batch_size, sample_size, max_decode_len), jnp.int32) - finished_seqs0 = jnp.zeros((batch_size, sample_size, max_decode_len), - jnp.int32) + finished_seqs0 = jnp.zeros( + (batch_size, sample_size, max_decode_len), jnp.int32 + ) finished_flags0 = jnp.zeros((batch_size, sample_size), jnp.bool_) # add sample dimension to attention cache pytree elements sample_cache0 = jax.tree.map(lambda x: add_beam_dim(x, sample_size), cache) @@ -251,18 +255,21 @@ def sampling_init(batch_size: int, all_seqs=all_seqs0, finished_seqs=finished_seqs0, finished_flags=finished_flags0, - cache=sample_cache0) + cache=sample_cache0, + ) # Beam search routine: -def beam_search(inputs, - cache, - tokens_to_logits, - beam_size=4, - alpha=0.6, - eos_id=EOS_ID, - max_decode_len=None, - offset: int = 0): +def beam_search( + inputs, + cache, + tokens_to_logits, + beam_size=4, + alpha=0.6, + eos_id=EOS_ID, + max_decode_len=None, + offset: int = 0, +): """Beam search for transformer machine translation. Args: @@ -289,15 +296,14 @@ def beam_search(inputs, end_marker = jnp.array(eos_id) # initialize beam search state - beam_search_init_state = beam_init(batch_size, - beam_size, - max_decode_len, - cache, offset) + beam_search_init_state = beam_init( + batch_size, beam_size, max_decode_len, cache, offset + ) def beam_search_loop_cond_fn(state): """Beam search loop termination condition.""" # Have we reached max decoding length? - not_at_end = (state.cur_index < max_decode_len - 1) + not_at_end = state.cur_index < max_decode_len - 1 # Is no further progress in the beam search possible? # Get the best possible scores from alive sequences. @@ -305,10 +311,12 @@ def beam_search_loop_cond_fn(state): best_live_scores = state.live_logprobs[:, -1:] / min_brevity_penalty # Get the worst scores from finished sequences. worst_finished_scores = jnp.min( - state.finished_scores, axis=1, keepdims=True) + state.finished_scores, axis=1, keepdims=True + ) # Mask out scores from slots without any actual finished sequences. worst_finished_scores = jnp.where( - state.finished_flags, worst_finished_scores, NEG_INF) + state.finished_flags, worst_finished_scores, NEG_INF + ) # If no best possible live score is better than current worst finished # scores, the search cannot improve the finished set further. search_terminated = jnp.all(worst_finished_scores > best_live_scores) @@ -324,12 +332,15 @@ def beam_search_loop_body_fn(state): # dimension for feeding into the model. # --> [batch * beam, 1] flat_ids = flatten_beam_dim( - lax.dynamic_slice(state.live_seqs, (0, 0, state.cur_index), - (batch_size, beam_size, 1))) + lax.dynamic_slice( + state.live_seqs, (0, 0, state.cur_index), (batch_size, beam_size, 1) + ) + ) # Flatten beam dimension into batch to be compatible with model. # {[batch, beam, ...], ...} --> {[batch * beam, ...], ...} - flat_cache = jax.tree.map(functools.partial(flatten_beam_dim, - offset=offset), state.cache) + flat_cache = jax.tree.map( + functools.partial(flatten_beam_dim, offset=offset), state.cache + ) # Call fast-decoder model on current tokens to get next-position logits. # --> [batch * beam, vocab] @@ -350,8 +361,9 @@ def unflatten_beam_dim_in_cache(x): candidate_log_probs = jax.nn.log_softmax(logits) # Add new logprobs to existing prefix logprobs. # --> [batch, beam, vocab] - log_probs = ( - candidate_log_probs + jnp.expand_dims(state.live_logprobs, axis=2)) + log_probs = candidate_log_probs + jnp.expand_dims( + state.live_logprobs, axis=2 + ) # We'll need the vocab size, gather it from the log probability dimension. vocab_size = log_probs.shape[2] @@ -371,8 +383,9 @@ def unflatten_beam_dim_in_cache(x): topk_beam_indices = topk_indices // vocab_size # Gather 2*k top beams and beam-associated caches. # --> [batch, 2*beams, length], {[batch, 2*beams, ...], ...} - topk_seq = gather_beams(state.live_seqs, topk_beam_indices, batch_size, - beams_to_keep) + topk_seq = gather_beams( + state.live_seqs, topk_beam_indices, batch_size, beams_to_keep + ) # Append the most probable 2*K token IDs to the top 2*K sequences # Recover token id by modulo division and expand Id array for broadcasting. @@ -380,13 +393,14 @@ def unflatten_beam_dim_in_cache(x): topk_ids = jnp.expand_dims(topk_indices % vocab_size, axis=2) # Update sequences for the 2*K top-k new sequences. # --> [batch, 2*beams, length] - topk_seq = lax.dynamic_update_slice(topk_seq, topk_ids, - (0, 0, state.cur_index + 1)) + topk_seq = lax.dynamic_update_slice( + topk_seq, topk_ids, (0, 0, state.cur_index + 1) + ) # Update LIVE (in-progress) sequences: # Did any of these sequences reach an end marker? # --> [batch, 2*beams] - newly_finished = (topk_seq[:, :, state.cur_index + 1] == end_marker) + newly_finished = topk_seq[:, :, state.cur_index + 1] == end_marker # To prevent these newly finished sequences from being added to the LIVE # set of active beam search sequences, set their log probs to a very large # negative value. @@ -394,15 +408,17 @@ def unflatten_beam_dim_in_cache(x): # Gather top-k beams. _, new_topk_indices = lax.top_k(new_log_probs, k=beam_size) # --> [batch, beams, length], [batch, beams], {[batch, beams, ...], ...} - top_alive_seq, top_alive_log_probs = gather_beams([topk_seq, new_log_probs], - new_topk_indices, - batch_size, beam_size) + top_alive_seq, top_alive_log_probs = gather_beams( + [topk_seq, new_log_probs], new_topk_indices, batch_size, beam_size + ) - top_alive_indices = gather_beams(topk_beam_indices, new_topk_indices, - batch_size, beam_size) + top_alive_indices = gather_beams( + topk_beam_indices, new_topk_indices, batch_size, beam_size + ) # Apply offset to the cache - top_alive_cache = gather_beams(new_cache, top_alive_indices, batch_size, - beam_size, offset=offset) + top_alive_cache = gather_beams( + new_cache, top_alive_indices, batch_size, beam_size, offset=offset + ) # Update FINISHED (reached end of sentence) sequences: # Calculate new seq scores from log probabilities. @@ -415,16 +431,23 @@ def unflatten_beam_dim_in_cache(x): # new finished sequence scores to existing finished scores and select the # best from the new set of beams. finished_seqs = jnp.concatenate( # --> [batch, 3*beams, length] - [state.finished_seqs, topk_seq], - axis=1) + [state.finished_seqs, topk_seq], axis=1 + ) finished_scores = jnp.concatenate( # --> [batch, 3*beams] - [state.finished_scores, new_scores], axis=1) + [state.finished_scores, new_scores], axis=1 + ) finished_flags = jnp.concatenate( # --> [batch, 3*beams] - [state.finished_flags, newly_finished], axis=1) + [state.finished_flags, newly_finished], axis=1 + ) # --> [batch, beams, length], [batch, beams], [batch, beams] top_finished_seq, top_finished_scores, top_finished_flags = ( - gather_topk_beams([finished_seqs, finished_scores, finished_flags], - finished_scores, batch_size, beam_size)) + gather_topk_beams( + [finished_seqs, finished_scores, finished_flags], + finished_scores, + batch_size, + beam_size, + ) + ) return BeamState( cur_index=state.cur_index + 1, @@ -433,36 +456,45 @@ def unflatten_beam_dim_in_cache(x): live_seqs=top_alive_seq, finished_seqs=top_finished_seq, finished_flags=top_finished_flags, - cache=top_alive_cache) + cache=top_alive_cache, + ) # Run while loop and get final beam search state. - final_state = lax.while_loop(beam_search_loop_cond_fn, - beam_search_loop_body_fn, beam_search_init_state) + final_state = lax.while_loop( + beam_search_loop_cond_fn, beam_search_loop_body_fn, beam_search_init_state + ) # Account for the edge-case where there are no finished sequences for a # particular batch item. If so, return live sequences for that batch item. # --> [batch] none_finished = jnp.any(final_state.finished_flags, axis=1) # --> [batch, beams, length] - finished_seqs = jnp.where(none_finished[:, None, None], - final_state.finished_seqs, final_state.live_seqs) + finished_seqs = jnp.where( + none_finished[:, None, None], + final_state.finished_seqs, + final_state.live_seqs, + ) # --> [batch, beams] - finished_scores = jnp.where(none_finished[:, - None], final_state.finished_scores, - final_state.live_logprobs) + finished_scores = jnp.where( + none_finished[:, None], + final_state.finished_scores, + final_state.live_logprobs, + ) return finished_seqs, finished_scores -def sampling(inputs: jax.Array, - cache: Any, - tokens_to_logits: Callable[..., Tuple[jax.Array, Any]], - rng: Union[jnp.ndarray, np.ndarray, int], - sample_size: int, - eos_id: int, - max_decode_len: Optional[int], - temperature: Optional[int], - rescale_log_probs: Optional[int]): +def sampling( + inputs: jax.Array, + cache: Any, + tokens_to_logits: Callable[..., Tuple[jax.Array, Any]], + rng: Union[jnp.ndarray, np.ndarray, int], + sample_size: int, + eos_id: int, + max_decode_len: Optional[int], + temperature: Optional[int], + rescale_log_probs: Optional[int], +): """Sampling for transformer machine translation. Args: @@ -475,7 +507,7 @@ def sampling(inputs: jax.Array, eos_id: int: id of end-of-sentence token for target vocabulary. max_decode_len: int: maximum length of decoded translations. temperature: float: sampling temperature. temp ~ 0.0 approaches greedy - sampling. temp = 1.0 means no-op. temp >> 1 approaches uniform sampling. + sampling. temp = 1.0 means no-op. temp >> 1 approaches uniform sampling. rescale_log_probs: bool: whether to apply temperature, topp, and topk rescaling to the log probs which are returned. If True, the log_probs will include these transformations (for example, with topk=1, all log_probs @@ -494,13 +526,14 @@ def sampling(inputs: jax.Array, end_marker = jnp.array(eos_id) # initialize sampling search state - sampling_init_state = sampling_init(batch_size, sample_size, max_decode_len, - cache) + sampling_init_state = sampling_init( + batch_size, sample_size, max_decode_len, cache + ) def sampling_loop_cond_fn(state: SamplingState): """Sampling loop termination condition.""" # Have we reached max decoding length? - not_at_end = (state.cur_index < max_decode_len - 1) + not_at_end = state.cur_index < max_decode_len - 1 return not_at_end def sampling_loop_body_fn(state: SamplingState): @@ -511,8 +544,12 @@ def sampling_loop_body_fn(state: SamplingState): # --> [batch * sample, 1] # Using beam flattening function. flat_ids = flatten_beam_dim( - lax.dynamic_slice(state.all_seqs, (0, 0, state.cur_index), - (batch_size, sample_size, 1))) + lax.dynamic_slice( + state.all_seqs, + (0, 0, state.cur_index), + (batch_size, sample_size, 1), + ) + ) # Flatten sample dimension into batch to be compatible with model. # {[batch, sample, ...], ...} --> {[batch * sample, ...], ...} flat_cache = jax.tree.map(flatten_beam_dim, state.cache) @@ -524,13 +561,14 @@ def sampling_loop_body_fn(state: SamplingState): # Unflatten sample dimension in attention cache arrays # {[batch * sample, ...], ...} --> {[batch, sample, ...], ...} new_cache = jax.tree.map( - lambda x: unflatten_beam_dim(x, batch_size, sample_size), - new_flat_cache) + lambda x: unflatten_beam_dim(x, batch_size, sample_size), new_flat_cache + ) def sample_logits_with_nonzero_temperature(flat_logits_to_sample): # TODO(ankugarg): Implement top-p and top-k aka nucleus sampling here. scaled_logits = flat_logits_to_sample / jnp.maximum( - temperature, MIN_TEMPERATURE) + temperature, MIN_TEMPERATURE + ) # rngs = jax.random.split(rng, batch_size * sample_size + 1) sampled_ids = jax.random.categorical(rng, scaled_logits).astype(jnp.int32) if rescale_log_probs: @@ -540,8 +578,10 @@ def sample_logits_with_nonzero_temperature(flat_logits_to_sample): # [batch * sample size, vocab] -> [batch * sample_size] sampled_log_probs = jnp.squeeze( jnp.take_along_axis( - log_probs, jnp.expand_dims(sampled_ids, axis=1), axis=-1), - axis=-1) + log_probs, jnp.expand_dims(sampled_ids, axis=1), axis=-1 + ), + axis=-1, + ) return (sampled_ids, sampled_log_probs) def sample_logits_with_zero_temperature(flat_logits_to_sample): @@ -555,15 +595,18 @@ def sample_logits_with_zero_temperature(flat_logits_to_sample): # [batch * sample size, vocab] -> [batch * sample_size] sampled_log_probs = jnp.squeeze( jnp.take_along_axis( - log_probs, jnp.expand_dims(sampled_ids, axis=1), axis=-1), - axis=-1) + log_probs, jnp.expand_dims(sampled_ids, axis=1), axis=-1 + ), + axis=-1, + ) return (sampled_ids, sampled_log_probs) - (sampled_ids, - sampled_log_probs) = lax.cond(temperature > MIN_TEMPERATURE, - sample_logits_with_nonzero_temperature, - sample_logits_with_zero_temperature, - flat_logits) + sampled_ids, sampled_log_probs = lax.cond( + temperature > MIN_TEMPERATURE, + sample_logits_with_nonzero_temperature, + sample_logits_with_zero_temperature, + flat_logits, + ) # Reshape. sampled_ids = sampled_ids.reshape(batch_size, sample_size) @@ -573,27 +616,32 @@ def sample_logits_with_zero_temperature(flat_logits_to_sample): new_log_probs = sampled_log_probs + state.all_log_probs # Need to get the new sequences now # [batch, sample, length] - updated_seq = lax.dynamic_update_slice(state.all_seqs, sampled_ids, - (0, 0, state.cur_index + 1)) + updated_seq = lax.dynamic_update_slice( + state.all_seqs, sampled_ids, (0, 0, state.cur_index + 1) + ) # Did any of these sequences reach an end marker? --> [batch, sample] - newly_finished_flag = (updated_seq[:, :, state.cur_index + 1] == end_marker) + newly_finished_flag = updated_seq[:, :, state.cur_index + 1] == end_marker # Make sure not to add already finished sequence again to the list of # finished sequences. newly_finished_flag = jnp.where( - state.finished_flags, jnp.zeros_like(newly_finished_flag, jnp.bool_), - newly_finished_flag) + state.finished_flags, + jnp.zeros_like(newly_finished_flag, jnp.bool_), + newly_finished_flag, + ) # new set of finished sequences finished_seqs = state.finished_seqs - finished_seqs = jnp.where(newly_finished_flag[:, :, None], updated_seq, - finished_seqs) + finished_seqs = jnp.where( + newly_finished_flag[:, :, None], updated_seq, finished_seqs + ) # create final finished flags that combine sequences finished on this # iteration with sequences finished in the past. - finished_flags = jnp.where(newly_finished_flag, newly_finished_flag, - state.finished_flags) + finished_flags = jnp.where( + newly_finished_flag, newly_finished_flag, state.finished_flags + ) return SamplingState( cur_index=state.cur_index + 1, @@ -601,26 +649,33 @@ def sample_logits_with_zero_temperature(flat_logits_to_sample): all_seqs=updated_seq, finished_seqs=finished_seqs, finished_flags=finished_flags, - cache=new_cache) + cache=new_cache, + ) # Run while loop and get final sampling result. - final_state = lax.while_loop(sampling_loop_cond_fn, sampling_loop_body_fn, - sampling_init_state) - - finished_seqs = jnp.where(final_state.finished_flags[:, :, None], - final_state.finished_seqs, final_state.all_seqs) + final_state = lax.while_loop( + sampling_loop_cond_fn, sampling_loop_body_fn, sampling_init_state + ) + + finished_seqs = jnp.where( + final_state.finished_flags[:, :, None], + final_state.finished_seqs, + final_state.all_seqs, + ) # Ignore the first token in each sequence. return finished_seqs[:, :, 1:] -def decode_step(batch, - params, - cache, - max_decode_len, - flax_module, - eos_id=EOS_ID, - beam_size=4, - offset: int = 0): +def decode_step( + batch, + params, + cache, + max_decode_len, + flax_module, + eos_id=EOS_ID, + beam_size=4, + offset: int = 0, +): """Predict translation with fast decoding beam search on a batch.""" inputs = batch['inputs'] @@ -631,10 +686,8 @@ def decode_step(batch, # i.e. if we denote each batch element subtensor as el[n]: # [el0, el1, el2] --> beamsize=2 --> [el0,el0,el1,el1,el2,el2] encoded_inputs = flax_module.apply( - {'params': params}, - inputs, - train=False, - method=flax_module.encode) + {'params': params}, inputs, train=False, method=flax_module.encode + ) # Inputs don't need an offset in case of scan over layers. encoded_inputs = flat_batch_beam_expand(encoded_inputs, beam_size, offset=0) raw_inputs = flat_batch_beam_expand(inputs, beam_size, offset=0) @@ -651,7 +704,8 @@ def tokens_ids_to_logits(flat_ids, flat_cache): targets=flat_ids, train=False, mutable=['cache'], - method=flax_module.decode) + method=flax_module.decode, + ) new_flat_cache = new_vars['cache'] # Remove singleton sequence-length dimension: # [batch * beam, 1, vocab] --> [batch * beam, vocab] @@ -668,7 +722,8 @@ def tokens_ids_to_logits(flat_ids, flat_cache): alpha=0.6, eos_id=eos_id, max_decode_len=max_decode_len, - offset=offset) + offset=offset, + ) # Beam search returns [n_batch, n_beam, n_length + 1] with beam dimension # sorted in increasing order of log-probability. @@ -676,17 +731,18 @@ def tokens_ids_to_logits(flat_ids, flat_cache): return beam_seqs[:, -1, 1:] -def sampling_step(batch: np.ndarray, - params: Mapping[str, Any], - cache: Mapping[str, Any], - max_decode_len: int, - rng: Union[jnp.ndarray, np.ndarray, int], - flax_module, - eos_id=EOS_ID, - sample_size: int = 20, - temperature: int = 1.0, - rescale_log_probs: int = 1, - ): +def sampling_step( + batch: np.ndarray, + params: Mapping[str, Any], + cache: Mapping[str, Any], + max_decode_len: int, + rng: Union[jnp.ndarray, np.ndarray, int], + flax_module, + eos_id=EOS_ID, + sample_size: int = 20, + temperature: int = 1.0, + rescale_log_probs: int = 1, +): """Performs one step of sampling.""" inputs = batch['inputs'] # Prepare transformer fast-decoder call for sampling: for sampling we @@ -695,10 +751,9 @@ def sampling_step(batch: np.ndarray, # rather than tiled. # i.e. if we denote each batch element subtensor as el[n]: # [el0, el1, el2] --> sample_size=2 --> [el0,el0,el1,el1,el2,el2] - encoded_inputs = flax_module.apply({'params': params}, - inputs, - train=False, - method=flax_module.encode) + encoded_inputs = flax_module.apply( + {'params': params}, inputs, train=False, method=flax_module.encode + ) encoded_inputs = flat_batch_beam_expand(encoded_inputs, sample_size) raw_inputs = flat_batch_beam_expand(inputs, sample_size) @@ -714,7 +769,8 @@ def tokens_ids_to_logits(flat_ids, flat_cache): targets=flat_ids, train=False, mutable=['cache'], - method=flax_module.decode) + method=flax_module.decode, + ) new_flat_cache = new_vars['cache'] # Remove singleton sequence-length dimension: # [batch_size * sample_size, 1, vocab] --> [batch_size * sample_size, vocab] @@ -732,8 +788,8 @@ def tokens_ids_to_logits(flat_ids, flat_cache): eos_id=eos_id, max_decode_len=max_decode_len, temperature=temperature, - rescale_log_probs=rescale_log_probs) + rescale_log_probs=rescale_log_probs, + ) # Sampling returns [batch_size, sampling_size, max_predict_length -1]. return sampling_seqs - diff --git a/init2winit/mt_eval/eval_utils.py b/init2winit/mt_eval/eval_utils.py index 427d221d..87f79836 100644 --- a/init2winit/mt_eval/eval_utils.py +++ b/init2winit/mt_eval/eval_utils.py @@ -29,9 +29,12 @@ def compute_bleu_from_predictions(predictions, references, language_code, name): """Computes BLEU score given predictions and references.""" - sacrebleu_tokenizer = 'zh' if language_code == 'zh' else sacrebleu.DEFAULT_TOKENIZER + sacrebleu_tokenizer = ( + 'zh' if language_code == 'zh' else sacrebleu.DEFAULT_TOKENIZER + ) bleu_score = sacrebleu.corpus_bleu( - predictions, [references], tokenize=sacrebleu_tokenizer).score + predictions, [references], tokenize=sacrebleu_tokenizer + ).score return {name: bleu_score} @@ -60,10 +63,8 @@ def save_evals(ckpt_dir, ckpt_step, eval_split, bleu_score): def _load_checkpoint(checkpoint_path, params): """Load model (and batch stats) from checkpoint.""" target = dict( - params=params, - global_step=-1, - preemption_count=0, - sum_train_cost=0.0) + params=params, global_step=-1, preemption_count=0, sum_train_cost=0.0 + ) ckpt = checkpoint.load_checkpoint( checkpoint_path, target=target, @@ -78,9 +79,7 @@ def average_checkpoints(checkpoint_paths, params): # Sum parameters of separate models together. params = _load_checkpoint(checkpoint_paths[0], params) for checkpoint_path in checkpoint_paths[1:]: - params_update = _load_checkpoint( - checkpoint_path, params - ) + params_update = _load_checkpoint(checkpoint_path, params) # TODO(dxin): Make this averaging process more numerically stable. params = jax.tree.map(lambda x, y: x + y, params, params_update) diff --git a/init2winit/mt_eval/inference.py b/init2winit/mt_eval/inference.py index aa148bc8..5a92544e 100644 --- a/init2winit/mt_eval/inference.py +++ b/init2winit/mt_eval/inference.py @@ -64,8 +64,9 @@ class InferenceManager(object): def __init__(self, *args, **kwargs): if kwargs['mode'] not in ['offline', 'online']: - raise ValueError('BLEU score computation only support online or ' - 'offline modes.') + raise ValueError( + 'BLEU score computation only support online or offline modes.' + ) self.mesh = kwargs['mesh'] self.finalize_batch_fn = kwargs['finalize_batch_fn'] if kwargs['mode'] == 'offline': @@ -73,14 +74,16 @@ def __init__(self, *args, **kwargs): else: self.init_online_evaluator(*args) - def init_offline_evaluator(self, - checkpoint_dir, - hps, - rng, - model_cls, - dataset, - dataset_meta_data, - mt_eval_config): + def init_offline_evaluator( + self, + checkpoint_dir, + hps, + rng, + model_cls, + dataset, + dataset_meta_data, + mt_eval_config, + ): """Utility for initializing offline BLEU evaluator.""" self.checkpoint_dir = checkpoint_dir self.hps = hps @@ -97,13 +100,9 @@ def init_offline_evaluator(self, self.encoder = self.load_tokenizer(hps.vocab_path) self.initialize_model(model_cls, dataset_meta_data, dropout_rng, params_rng) - def init_online_evaluator(self, - hps, - rng, - model_cls, - dataset, - dataset_metadata, - mt_eval_config): + def init_online_evaluator( + self, hps, rng, model_cls, dataset, dataset_metadata, mt_eval_config + ): """Utility for initializing online BLEU evaluator.""" self.hps = hps self.eos_id = decode.EOS_ID @@ -150,8 +149,9 @@ def load_tokenizer(self, vocab_path): return mt_tokenizer._load_sentencepiece_tokenizer(vocab_path) # pylint: enable=protected-access - def initialize_model(self, model_cls, dataset_meta_data, dropout_rng, - params_rng): + def initialize_model( + self, model_cls, dataset_meta_data, dropout_rng, params_rng + ): """Initialie model, especially cache for fast auto-regressive decoding.""" loss_name = 'cross_entropy' metrics_name = 'classification_metrics' @@ -162,9 +162,11 @@ def initialize_model(self, model_cls, dataset_meta_data, dropout_rng, model = model_cls(hps, dataset_meta_data, loss_name, metrics_name) xs = [np.zeros((self.hps.batch_size, *x)) for x in self.hps.input_shape] model_init_fn = jax.jit( - functools.partial(model.flax_module.init, train=False)) + functools.partial(model.flax_module.init, train=False) + ) init_dict = model_init_fn( - {'params': params_rng, 'dropout': dropout_rng}, *xs) + {'params': params_rng, 'dropout': dropout_rng}, *xs + ) params = init_dict['params'] self.flax_module = model.flax_module self.params = params @@ -173,7 +175,9 @@ def initialize_model(self, model_cls, dataset_meta_data, dropout_rng, self.initialize_cache, max_length=self.max_length, params_rng=params_rng, - dropout_rng=dropout_rng)) + dropout_rng=dropout_rng, + ) + ) def initialize_cache(self, inputs, max_length, params_rng, dropout_rng): """Initialize a cache for a given input shape and max decode length.""" @@ -181,16 +185,16 @@ def initialize_cache(self, inputs, max_length, params_rng, dropout_rng): targets_shape = (inputs.shape[0], max_length) + inputs.shape[2:] model_init_fn = jax.jit( - functools.partial(self.flax_module.init, train=False)) + functools.partial(self.flax_module.init, train=False) + ) xs = [jnp.ones(inputs.shape), jnp.ones(targets_shape)] - init_dict = model_init_fn({ - 'params': params_rng, - 'dropout': dropout_rng - }, *xs) + init_dict = model_init_fn( + {'params': params_rng, 'dropout': dropout_rng}, *xs + ) return init_dict['cache'] def decode_tokens(self, toks): - valid_toks = toks[:np.argmax(toks == self.eos_id) + 1].astype(np.int32) + valid_toks = toks[: np.argmax(toks == self.eos_id) + 1].astype(np.int32) return self.encoder.detokenize(valid_toks).numpy().decode('utf-8') def current_batch_size(self, batch): @@ -208,7 +212,8 @@ def build_predictor(self): eos_id=self.eos_id, sample_size=self.mt_eval_config.get('sample_size'), temperature=self.mt_eval_config.get('temperature'), - rescale_log_probs=self.mt_eval_config.get('rescale_log_probs')) + rescale_log_probs=self.mt_eval_config.get('rescale_log_probs'), + ) else: decoder = functools.partial( decode.decode_step, @@ -216,7 +221,8 @@ def build_predictor(self): flax_module=self.flax_module, eos_id=self.eos_id, beam_size=self.mt_eval_config.get('beam_size'), - offset=self.mt_eval_config.get('scan_over_layers_offset', 0)) + offset=self.mt_eval_config.get('scan_over_layers_offset', 0), + ) self.predictor = jax.jit(decoder) def translate_and_calculate_bleu(self): @@ -228,16 +234,19 @@ def translate_and_calculate_bleu(self): ckpt_paths = eval_utils.get_checkpoints_in_range( checkpoint_dir=self.checkpoint_dir, lower_bound=step - self.ckpt_avg_window, - upper_bound=step) + upper_bound=step, + ) logging.info('Current checkpoints: %s', ckpt_paths) params = eval_utils.average_checkpoints( - checkpoint_paths=ckpt_paths, - params=self.params) + checkpoint_paths=ckpt_paths, params=self.params + ) _, params_replicated = utils.shard_pytree(params, self.mesh) decoding_output = self.translate_and_calculate_bleu_single_model( - params_replicated, self.eval_split) - logging.info('Sacre bleu score at step %d: %f', step, - decoding_output.bleu_score) + params_replicated, self.eval_split + ) + logging.info( + 'Sacre bleu score at step %d: %f', step, decoding_output.bleu_score + ) decoding_outputs.append(decoding_output) return decoding_outputs @@ -256,32 +265,34 @@ def translate_and_calculate_bleu_single_model(self, params, eval_split): weights = pred_batch['weights'] current_batch_size = int(weights[..., 0].sum()) if self.mt_eval_config.get('decoding_type') == 'beam_search': - self.process_beam_search_output(inputs, targets, predicted, - current_batch_size, decode_output) + self.process_beam_search_output( + inputs, targets, predicted, current_batch_size, decode_output + ) else: - self.process_sampling_output(inputs, targets, predicted, - current_batch_size, decode_output) + self.process_sampling_output( + inputs, targets, predicted, current_batch_size, decode_output + ) - logging.info('Predictions: %d References %d Sources %d.', - len(decode_output.translation_list), - len(decode_output.reference_list), - len(decode_output.source_list)) + logging.info( + 'Predictions: %d References %d Sources %d.', + len(decode_output.translation_list), + len(decode_output.reference_list), + len(decode_output.source_list), + ) if self.mt_eval_config.get('decoding_type') == 'beam_search': bleu_score = eval_utils.compute_bleu_from_predictions( decode_output.translation_list, decode_output.reference_list, self.mt_eval_config.get('tl_code'), - 'sacrebleu')['sacrebleu'] + 'sacrebleu', + )['sacrebleu'] decode_output.bleu_score = bleu_score decode_output.decoding_type = self.mt_eval_config.get('decoding_type') return decode_output - def process_beam_search_output(self, - inputs, - targets, - predicted, - batch_size, - decode_output): + def process_beam_search_output( + self, inputs, targets, predicted, batch_size, decode_output + ): """Process output if its beam search decoding.""" # Non-final dimensions are batch dimensions. @@ -296,12 +307,9 @@ def process_beam_search_output(self, decode_output.reference_list.append(curr_ref) decode_output.translation_list.append(curr_pred) - def process_sampling_output(self, - inputs, - targets, - predicted, - batch_size, - decode_output): + def process_sampling_output( + self, inputs, targets, predicted, batch_size, decode_output + ): """Process output if its sampling decoding.""" # Non-final dimensions are batch dimensions. diff --git a/init2winit/mt_eval/mt_callback.py b/init2winit/mt_eval/mt_callback.py index fd0cfdbe..2e7b1af3 100644 --- a/init2winit/mt_eval/mt_callback.py +++ b/init2winit/mt_eval/mt_callback.py @@ -50,12 +50,22 @@ from ml_collections.config_dict import config_dict import numpy as np - _REQUIRED_KEYS = [ - 'dataset_name', 'model_name', 'tfds_dataset_key', 'tfds_eval_dataset_key', - 'tfds_predict_dataset_key', 'reverse_translation', 'eval_batch_size', - 'eval_train_num_batches', 'eval_num_batches', 'eval_splits', - 'max_decode_length', 'tl_code', 'beam_size', 'decoding_type'] + 'dataset_name', + 'model_name', + 'tfds_dataset_key', + 'tfds_eval_dataset_key', + 'tfds_predict_dataset_key', + 'reverse_translation', + 'eval_batch_size', + 'eval_train_num_batches', + 'eval_num_batches', + 'eval_splits', + 'max_decode_length', + 'tl_code', + 'beam_size', + 'decoding_type', +] _SPLITS = ['train', 'valid', 'test'] @@ -115,16 +125,20 @@ def __init__( def _validate_callback_config(self): assert all(key in self.callback_config for key in _REQUIRED_KEYS), ( - 'callback config must contain these required keys:', _REQUIRED_KEYS) - assert ('vocab_path' not in self.callback_config), ( + 'callback config must contain these required keys:', + _REQUIRED_KEYS, + ) + assert 'vocab_path' not in self.callback_config, ( 'Eval must use same vocab file as used in training. No need to specify' ' vocab file. One from training config will be used.' ) assert all( split_name in set(_SPLITS) for split_name in self.callback_config['eval_splits'] - ), ('callback_config.eval_splits must contain only subset of these splits:', - _SPLITS) + ), ( + 'callback_config.eval_splits must contain only subset of these splits:', + _SPLITS, + ) def _get_dataset(self, hps, rng): """Sets ups dataset builders.""" @@ -134,19 +148,17 @@ def _get_dataset(self, hps, rng): dataset_builder = datasets.get_dataset(self.callback_config['dataset_name']) dataset_metadata = datasets.get_dataset_meta_data( - self.callback_config['dataset_name']) + self.callback_config['dataset_name'] + ) dataset = dataset_builder( rng, hparams.batch_size, eval_batch_size=self.callback_config['eval_batch_size'], - hps=hparams) + hps=hparams, + ) return dataset, dataset_metadata - def _evaluate(self, - params, - batch_stats, - batch_iter, - evaluate_batch_jitted): + def _evaluate(self, params, batch_stats, batch_iter, evaluate_batch_jitted): """Compute aggregated metrics on the given data iterator. This function is taken as is from trainer.py to avoid circular dependency. @@ -155,8 +167,7 @@ def _evaluate(self, Args: params: model params. batch_stats: A dict of batch_stats. - batch_iter: Generator which yields batches. Must support the API - for b in batch_iter: + batch_iter: Generator which yields batches. evaluate_batch_jitted: A function with API evaluate_batch_jitted(params, batch_stats, batch). Returns a dictionary mapping keys to the metric values across the sharded batch. @@ -199,11 +210,10 @@ def _merge_and_apply_prefix(self, d1, d2, prefix): """ d1 = d1.copy() for key in d2: - d1[prefix+key] = d2[key] + d1[prefix + key] = d2[key] return d1 - def run_eval( - self, params, batch_stats, optimizer_state, global_step): + def run_eval(self, params, batch_stats, optimizer_state, global_step): """Runs the MT models to evals specified by MT model. Args: @@ -225,13 +235,16 @@ def run_eval( for eval_split in self.callback_config['eval_splits']: if 'train' in eval_split: ds_splits_dict[eval_split] = self.dataset.eval_train_epoch( - self.callback_config['eval_train_num_batches']) + self.callback_config['eval_train_num_batches'] + ) elif 'valid' in eval_split: ds_splits_dict[eval_split] = self.dataset.valid_epoch( - self.callback_config['eval_num_batches']) + self.callback_config['eval_num_batches'] + ) else: ds_splits_dict[eval_split] = self.dataset.test_epoch( - self.callback_config['eval_num_batches']) + self.callback_config['eval_num_batches'] + ) metrics = {} @@ -240,19 +253,31 @@ def run_eval( try: decoding_output = ( self.inference_manager.translate_and_calculate_bleu_single_model( - params, split_name)) - split_metrics = self._evaluate(params, batch_stats, split_iter, - self.evaluate_batch_jitted) + params, split_name + ) + ) + split_metrics = self._evaluate( + params, batch_stats, split_iter, self.evaluate_batch_jitted + ) split_metrics['bleu_score'] = decoding_output.bleu_score metrics = self._merge_and_apply_prefix( - metrics, split_metrics, 'callback/' + - self.callback_config['tfds_dataset_key'] + '/' + split_name + '/') + metrics, + split_metrics, + 'callback/' + + self.callback_config['tfds_dataset_key'] + + '/' + + split_name + + '/', + ) except utils.TrainingDivergedError as err: # we don't want to stop training. del err - logging.info('Callback evaluation diverged for dataset %s at step:%d', - self.tfds_dataset_key, global_step) + logging.info( + 'Callback evaluation diverged for dataset %s at step:%d', + self.tfds_dataset_key, + global_step, + ) continue return metrics diff --git a/init2winit/optimizer_lib/factor_sam.py b/init2winit/optimizer_lib/factor_sam.py index a2d668fa..cfe5064b 100644 --- a/init2winit/optimizer_lib/factor_sam.py +++ b/init2winit/optimizer_lib/factor_sam.py @@ -37,7 +37,8 @@ def normalize_vector(y: jnp.ndarray) -> jnp.ndarray: y: A pytree of numpy ndarray, vector y in the equation above. """ gradient_norm = jnp.sqrt( - sum([jnp.sum(jnp.square(e)) for e in jax.tree_util.tree_leaves(y)])) + sum([jnp.sum(jnp.square(e)) for e in jax.tree_util.tree_leaves(y)]) + ) normalized_gradient = jax.tree.map(lambda x: x / gradient_norm, y) return normalized_gradient @@ -55,7 +56,7 @@ def sam_update( alpha=1.0, ): """SAM update function.""" - (grad_fn, params) = grad_fn_params_tuple + grad_fn, params = grad_fn_params_tuple updates = normalize_vector(updates) noised_params = jax.tree_util.tree_map( lambda p, u: p + rho * u, params, updates @@ -108,7 +109,7 @@ def init_fn(params): def update_fn(updates, state, grad_fn_params_tuple): # Updates here have been averaged across devices in Trainer before being # sent to the optimizer. - (_, params) = grad_fn_params_tuple + _, params = grad_fn_params_tuple # Update function in between SAM steps. intermediate_update_fn = clean_update @@ -128,9 +129,14 @@ def update_fn(updates, state, grad_fn_params_tuple): if grad_clip: updates_norm = jnp.sqrt(model_utils.l2_regularization(updates, 0)) scaled_updates = jax.tree.map( - lambda x: x / (updates_norm + _GRAD_CLIP_EPS) * grad_clip, updates) - updates = jax.lax.cond(updates_norm > grad_clip, lambda _: scaled_updates, - lambda _: updates, None) + lambda x: x / (updates_norm + _GRAD_CLIP_EPS) * grad_clip, updates + ) + updates = jax.lax.cond( + updates_norm > grad_clip, + lambda _: scaled_updates, + lambda _: updates, + None, + ) # TODO(thetish): Explore different order for base optimizer and SAM. For # example, in Adam preconditioning the SAM perturbation is helpful. return base_opt_update_fn(updates, state, params) # Apply base optimizer diff --git a/init2winit/optimizer_lib/gradient_accumulator.py b/init2winit/optimizer_lib/gradient_accumulator.py index 7dbe9005..509b2b69 100644 --- a/init2winit/optimizer_lib/gradient_accumulator.py +++ b/init2winit/optimizer_lib/gradient_accumulator.py @@ -42,6 +42,7 @@ Note that we only sync gradients when we are about to update the model, in order to avoid unnecessary cross replica communications. """ + from typing import NamedTuple, Optional from init2winit.optimizer_lib import utils as optimizer_utils @@ -52,6 +53,7 @@ class GradientAccumulatorState(NamedTuple): """State for the gradient accumulator.""" + base_state: NamedTuple # The state of the base optimizer. hyperparams: dict[str, jnp.ndarray] num_per_step_batches: jnp.ndarray # shape=(), dtype=jnp.int32. @@ -79,23 +81,29 @@ def accumulate_gradients( generate updates given the total gradient. base_opt_update_fn: The update function for the base optimizer used to generate updates given the total gradient. + Returns: An (init_fn, update_fn) tuple. """ - if (virtual_batch_size is not None and - virtual_batch_size > per_step_batch_size): + if ( + virtual_batch_size is not None + and virtual_batch_size > per_step_batch_size + ): raise ValueError( 'Gradient accumulation does not currently support using a virtual ' 'batch size ({}) that is larger than the per-step batch size ({}), as ' 'this would require multiple forward steps *up to each batch norm ' 'layer* in order to properly calculate the batch statistics necessary ' 'to simulate a larger batch size.'.format( - virtual_batch_size, per_step_batch_size)) + virtual_batch_size, per_step_batch_size + ) + ) if total_batch_size % per_step_batch_size != 0: raise ValueError( 'Need to step a per-step batch size ({}) that evenly divides the total ' - 'batch size ({}).'.format(per_step_batch_size, total_batch_size)) + 'batch size ({}).'.format(per_step_batch_size, total_batch_size) + ) steps_per_update = total_batch_size // per_step_batch_size @@ -105,7 +113,8 @@ def init_fn(params): base_state=base_state, hyperparams=base_state.hyperparams, num_per_step_batches=jnp.zeros([], jnp.int32), - accumulations=jax.tree.map(jnp.zeros_like, params)) + accumulations=jax.tree.map(jnp.zeros_like, params), + ) @optimizer_utils.no_cross_device_gradient_aggregation def update_fn(updates, state, params=None, **extra_args): @@ -119,15 +128,18 @@ def total_batch_update(total_gradients, params, state): # account any example weighting, which is rarely used for training # batches. total_gradients = jax.tree.map( - lambda x: x / steps_per_update, total_gradients) + lambda x: x / steps_per_update, total_gradients + ) updates, updated_base_state = base_opt_update_fn( - total_gradients, state.base_state, params=params, **extra_args) + total_gradients, state.base_state, params=params, **extra_args + ) reset_state = GradientAccumulatorState( base_state=updated_base_state, hyperparams=updated_base_state.hyperparams, num_per_step_batches=0, - accumulations=zeros_params) + accumulations=zeros_params, + ) return updates, reset_state def accumulation_continuation(updated_accumulations, _, state): @@ -135,18 +147,21 @@ def accumulation_continuation(updated_accumulations, _, state): base_state=state.base_state, hyperparams=state.base_state.hyperparams, num_per_step_batches=state.num_per_step_batches + 1, - accumulations=updated_accumulations) + accumulations=updated_accumulations, + ) return zeros_params, updated_state - updated_accumulations = jax.tree.map(lambda g, acc: g + acc, updates, - state.accumulations) + updated_accumulations = jax.tree.map( + lambda g, acc: g + acc, updates, state.accumulations + ) updates, state = jax.lax.cond( state.num_per_step_batches == steps_per_update - 1, total_batch_update, accumulation_continuation, updated_accumulations, params, - state) + state, + ) return updates, state return optax.GradientTransformation(init_fn, update_fn) diff --git a/init2winit/optimizer_lib/kitchen_sink/__init__.py b/init2winit/optimizer_lib/kitchen_sink/__init__.py index 92a44a83..ca87ebea 100644 --- a/init2winit/optimizer_lib/kitchen_sink/__init__.py +++ b/init2winit/optimizer_lib/kitchen_sink/__init__.py @@ -48,7 +48,6 @@ from init2winit.optimizer_lib.kitchen_sink._src.transform import ScaleByAMSGradState from init2winit.optimizer_lib.kitchen_sink._src.utils import unfreeze_wrapper - __version__ = '0.0.1' __all__ = ( diff --git a/init2winit/optimizer_lib/kitchen_sink/_src/alias.py b/init2winit/optimizer_lib/kitchen_sink/_src/alias.py index 347b2f32..4aa5abcb 100644 --- a/init2winit/optimizer_lib/kitchen_sink/_src/alias.py +++ b/init2winit/optimizer_lib/kitchen_sink/_src/alias.py @@ -14,6 +14,7 @@ # limitations under the License. """Aliases for optimizers not found in optax.""" + from typing import Any, Callable, Optional, Union from init2winit.optimizer_lib.kitchen_sink._src import transform import jax.numpy as jnp @@ -28,8 +29,9 @@ def nadamw( eps_root: float = 0.0, debias: bool = True, weight_decay: float = 0.0, - weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], - Any]]] = None, + weight_decay_mask: Optional[ + Union[Any, Callable[[optax.Params], Any]] + ] = None, ) -> optax.GradientTransformation: """Rescale updates according to the NAdam algorithm. @@ -67,7 +69,8 @@ def nadamw( return optax.chain( transform.scale_by_nadam(b1, b2, eps, eps_root, debias), optax.add_decayed_weights(weight_decay, weight_decay_mask), - transform.scale_by_learning_rate(learning_rate)) + transform.scale_by_learning_rate(learning_rate), + ) def nadampw( @@ -79,8 +82,9 @@ def nadampw( debias: bool = True, power: float = 2.0, weight_decay: float = 0.0, - weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], - Any]]] = None, + weight_decay_mask: Optional[ + Union[Any, Callable[[optax.Params], Any]] + ] = None, ) -> optax.GradientTransformation: """Rescale updates according to the NAdam algorithm. @@ -120,7 +124,8 @@ def nadampw( return optax.chain( transform.scale_by_nadam(b1, b2, eps, eps_root, debias, power=power), optax.add_decayed_weights(weight_decay, weight_decay_mask), - transform.scale_by_learning_rate(learning_rate)) + transform.scale_by_learning_rate(learning_rate), + ) def adamw_generic( @@ -134,8 +139,9 @@ def adamw_generic( disable_preconditioning: bool = False, nesterov: bool = True, weight_decay: float = 0.0, - weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], - Any]]] = None, + weight_decay_mask: Optional[ + Union[Any, Callable[[optax.Params], Any]] + ] = None, disable_multiply_wd_by_base_lr: bool = False, base_lr_multiplier: float = 1.0, ) -> optax.GradientTransformation: @@ -215,25 +221,26 @@ def adapropw( use_nesterov: bool = True, quantized_dtype: str = 'float32', weight_decay: float = 0.0, - weight_decay_mask: Optional[Union[Any, Callable[[optax.Params], - Any]]] = None, + weight_decay_mask: Optional[ + Union[Any, Callable[[optax.Params], Any]] + ] = None, ) -> optax.GradientTransformation: """Rescale updates according to the AdaProp algorithm. Args: learning_rate: this is a fixed global scaling factor. b1: decay rate for the exponentially weighted average of grads. - b2: decay rate for the exponentially weighted average of absolute grads - is omitted because it is calculated from alpha and b1. + b2: decay rate for the exponentially weighted average of absolute grads is + omitted because it is calculated from alpha and b1. b3: decay rate for the exponentially weighted average of max grads. b4: decay rate for the exponentially weighted average of reward. eps: term added to the denominator to improve numerical stability. power: the power to use in the preconditioner (the value determines the power to which the absolute value of the grads are raised). use_nesterov: Whether to use Nesterov-style update. - quantized_dtype: type of the quantized input. Allowed options are - 'bfloat16' and 'float32'. If floating-point type is specified, - accumulators are stored as such type, instead of quantized integers. + quantized_dtype: type of the quantized input. Allowed options are 'bfloat16' + and 'float32'. If floating-point type is specified, accumulators are + stored as such type, instead of quantized integers. weight_decay: strength of the weight decay regularization. Note that this weight decay is multiplied with the learning rate. This is consistent with other frameworks such as PyTorch, but different from (Loshchilov et al, @@ -253,10 +260,16 @@ def adapropw( else: q_dtype = jnp.bfloat16 return optax.chain( - transform.scale_by_adaprop(b1=b1, b2=b2, b3=b3, b4=b4, - eps=eps, power=power, - use_nesterov=use_nesterov, - quantized_dtype=q_dtype), + transform.scale_by_adaprop( + b1=b1, + b2=b2, + b3=b3, + b4=b4, + eps=eps, + power=power, + use_nesterov=use_nesterov, + quantized_dtype=q_dtype, + ), optax.add_decayed_weights(weight_decay, weight_decay_mask), transform.scale_by_learning_rate(learning_rate, flip_sign=True), ) diff --git a/init2winit/optimizer_lib/kitchen_sink/_src/combine.py b/init2winit/optimizer_lib/kitchen_sink/_src/combine.py index 6164e4e6..e6e17f83 100644 --- a/init2winit/optimizer_lib/kitchen_sink/_src/combine.py +++ b/init2winit/optimizer_lib/kitchen_sink/_src/combine.py @@ -14,6 +14,7 @@ # limitations under the License. """Combine utilities.""" + import functools from typing import Any, NamedTuple from typing import Callable @@ -27,9 +28,13 @@ # TODO(dsuo): Add back grafting combinator. -def join(by: Union[str, Callable[[optax.GradientTransformation, ...], - optax.Updates]], *args, - **kwargs) -> Callable[..., optax.GradientTransformation]: +def join( + by: Union[ + str, Callable[[optax.GradientTransformation, ...], optax.Updates] + ], + *args, + **kwargs, +) -> Callable[..., optax.GradientTransformation]: """Join multiple chains.""" if by is None or by == 'chain': @@ -53,7 +58,7 @@ def init(params: optax.Params) -> optax.OptState: def update( updates: optax.Updates, state: optax.OptState, - params: Optional[optax.Params] = None + params: Optional[optax.Params] = None, ) -> Tuple[optax.Updates, optax.OptState]: combinator_state, args_state, kwargs_state = state @@ -89,14 +94,16 @@ def update( def _grafting_helper(chain, use_global_norm=False): norm = jax.tree.map(jnp.linalg.norm, chain) if use_global_norm: - global_norm = jax.tree_util.tree_reduce(lambda x, y: jnp.sqrt(x**2 + y**2), - norm) + global_norm = jax.tree_util.tree_reduce( + lambda x, y: jnp.sqrt(x**2 + y**2), norm + ) norm = jax.tree.map(lambda x: global_norm, norm) return norm class GraftingState(NamedTuple): """State for the Layered Adaptive RMS Preconditioner algorithm.""" + mag_norm: Any dir_norm: Any @@ -151,8 +158,12 @@ def init(params, *args, **kwargs): def update(state, *args, **kwargs): args = args + tuple(kwargs.values()) - return functools.reduce( - lambda x, y: jax.tree_multimap(lambda i, j: i + j, x, y), args), state + return ( + functools.reduce( + lambda x, y: jax.tree_multimap(lambda i, j: i + j, x, y), args + ), + state, + ) return init, update diff --git a/init2winit/optimizer_lib/kitchen_sink/_src/core.py b/init2winit/optimizer_lib/kitchen_sink/_src/core.py index 634a8ea1..47443acb 100644 --- a/init2winit/optimizer_lib/kitchen_sink/_src/core.py +++ b/init2winit/optimizer_lib/kitchen_sink/_src/core.py @@ -29,7 +29,6 @@ import ml_collections import optax - # TODO(dsuo): document config syntax. @@ -82,8 +81,9 @@ def _kitchen_sink_helper(config): return tx -def kitchen_sink(config: Dict[str, Any], - learning_rate: float = None) -> optax.GradientTransformation: +def kitchen_sink( + config: Dict[str, Any], learning_rate: float = None +) -> optax.GradientTransformation: """Runs a list of GradientTransforms in parallel and combines. Args: @@ -99,7 +99,8 @@ def kitchen_sink(config: Dict[str, Any], config = config.to_dict() elif not isinstance(config, dict): raise ValueError( - 'Kitchen Sink configuration needs to be a config dict or a python dict') + 'Kitchen Sink configuration needs to be a config dict or a python dict' + ) # Syntactic sugar. If we have an implied chain, make it explicitly a chain. if all([str(i) in config for i in range(len(config))]): diff --git a/init2winit/optimizer_lib/kitchen_sink/_src/mask.py b/init2winit/optimizer_lib/kitchen_sink/_src/mask.py index f89f352e..d14af197 100644 --- a/init2winit/optimizer_lib/kitchen_sink/_src/mask.py +++ b/init2winit/optimizer_lib/kitchen_sink/_src/mask.py @@ -14,6 +14,7 @@ # limitations under the License. """Mask utilities.""" + import flax @@ -31,14 +32,17 @@ def create_mask(fn): def mask(data): flattened_dict = flax.traverse_util.flatten_dict(data) return flax.traverse_util.unflatten_dict( - {k: fn(k, v) for k, v in flattened_dict.items()}) + {k: fn(k, v) for k, v in flattened_dict.items()} + ) return mask def create_weight_decay_mask(): return create_mask( - lambda p, _: 'bias' not in p and not p[-2].startswith('BatchNorm')) + lambda p, _: 'bias' not in p and not p[-2].startswith('BatchNorm') + ) + mask_registry = { 'bias_bn': create_weight_decay_mask(), diff --git a/init2winit/optimizer_lib/kitchen_sink/_src/preconditioner.py b/init2winit/optimizer_lib/kitchen_sink/_src/preconditioner.py index 5778f259..95de2723 100644 --- a/init2winit/optimizer_lib/kitchen_sink/_src/preconditioner.py +++ b/init2winit/optimizer_lib/kitchen_sink/_src/preconditioner.py @@ -87,7 +87,7 @@ def update(updates, state, params=None): 'updates': updates, 'variables': {}, 'moments': {}, - 'output': None + 'output': None, } new_state = [] for s, transform in zip(state, [variable_creator, accumulator, updater]): @@ -100,7 +100,8 @@ def update(updates, state, params=None): def nth_power( - power: Union[int, Tuple[int]] = 2) -> optax.GradientTransformation: + power: Union[int, Tuple[int, ...]] = 2, +) -> optax.GradientTransformation: """Create nth power(s) from gradients.""" if not hasattr(power, '__iter__'): @@ -117,7 +118,7 @@ def init(params: optax.Params) -> optax.OptState: def update( updates: optax.Updates, state: optax.OptState, - params: Optional[optax.Params] = None + params: Optional[optax.Params] = None, ) -> Tuple[optax.Updates, optax.OptState]: del params @@ -133,8 +134,9 @@ def update( return optax.GradientTransformation(init, update) -def ema_accumulator(decay: float = 0.999, - debias: bool = False) -> optax.GradientTransformation: +def ema_accumulator( + decay: float = 0.999, debias: bool = False +) -> optax.GradientTransformation: """Create accumulator that computes EMA on all updates.""" def init(params: optax.Params) -> optax.OptState: @@ -143,7 +145,7 @@ def init(params: optax.Params) -> optax.OptState: def update( updates: optax.Updates, state: optax.OptState, - params: Optional[optax.Params] = None + params: Optional[optax.Params] = None, ) -> Tuple[optax.Updates, optax.OptState]: del params @@ -153,8 +155,11 @@ def update( count = count + jnp.array(1, dtype=jnp.int32) beta = jnp.array(1, dtype=jnp.int32) - decay**count - updates['moments'] = moments if not debias else jax.tree.map( - lambda t: t / beta.astype(t.dtype), moments) + updates['moments'] = ( + moments + if not debias + else jax.tree.map(lambda t: t / beta.astype(t.dtype), moments) + ) return updates, (moments, count) @@ -164,14 +169,20 @@ def update( # TODO(dsuo): from namanagarwal@: revisit `initial_accumulator_value`. # `tensorflow` defaults to this value, but perhaps should consider # 0 or 1e-8 to match rms-type accumulators. -def yogi_accumulator(b2: float = 0.999, - initial_accumulator_value: float = 1e-6, - debias: bool = False) -> optax.GradientTransformation: +def yogi_accumulator( + b2: float = 0.999, + initial_accumulator_value: float = 1e-6, + debias: bool = False, +) -> optax.GradientTransformation: """Create yogi accumulator.""" def init(params: optax.Params) -> optax.OptState: - return (jax.tree.map(lambda p: jnp.full_like(p, initial_accumulator_value), - params), jnp.zeros([], dtype=jnp.int32)) + return ( + jax.tree.map( + lambda p: jnp.full_like(p, initial_accumulator_value), params + ), + jnp.zeros([], dtype=jnp.int32), + ) def update(updates, state, params=None): del params @@ -182,8 +193,11 @@ def update(updates, state, params=None): count = count + jnp.array(1, dtype=jnp.int32) beta = jnp.array(1, dtype=jnp.float32) - b2**count - updates['moments'] = moments if not debias else jax.tree.map( - lambda t: t / beta.astype(t.dtype), moments) + updates['moments'] = ( + moments + if not debias + else jax.tree.map(lambda t: t / beta.astype(t.dtype), moments) + ) return updates, (moments, count) @@ -206,15 +220,20 @@ def init(params: optax.Params) -> optax.OptState: def update( updates: optax.Updates, state: optax.OptState, - params: Optional[optax.Params] = None + params: Optional[optax.Params] = None, ) -> Tuple[optax.Updates, optax.OptState]: del params - grads = updates['updates'] if not use_accumulated_gradient else updates[ - 'moments']['1'] + grads = ( + updates['updates'] + if not use_accumulated_gradient + else updates['moments']['1'] + ) updates['output'] = jax.tree.map( - lambda u, v: u / (jnp.power(v + eps_root, exponent) + eps), grads, - updates['moments'][str(moment)]) + lambda u, v: u / (jnp.power(v + eps_root, exponent) + eps), + grads, + updates['moments'][str(moment)], + ) return updates, state diff --git a/init2winit/optimizer_lib/kitchen_sink/_src/test_core.py b/init2winit/optimizer_lib/kitchen_sink/_src/test_core.py index 74cd7f53..e95986b2 100644 --- a/init2winit/optimizer_lib/kitchen_sink/_src/test_core.py +++ b/init2winit/optimizer_lib/kitchen_sink/_src/test_core.py @@ -34,12 +34,8 @@ class KitchenSinkSuccessTest(absltest.TestCase): def setUp(self): super().setUp() self.basic_sink = kitchen_sink({ - '0': { - 'element': 'nesterov' - }, - '1': { - 'element': 'polyak_hb' - }, + '0': {'element': 'nesterov'}, + '1': {'element': 'polyak_hb'}, }) def test_construction(self): @@ -52,18 +48,15 @@ def test_dummy_step(self): ys = 1 optimizer = kitchen_sink({ - '0': { - 'element': 'nesterov' - }, - '1': { - 'element': 'polyak_hb' - }, + '0': {'element': 'nesterov'}, + '1': {'element': 'polyak_hb'}, }) params = {'w': jnp.ones((num_weights,))} opt_state = optimizer.init(flax.core.FrozenDict(params)) - compute_loss = lambda params, x, y: optax.l2_loss(params['w'].dot(x), - jnp.array(y)) + compute_loss = lambda params, x, y: optax.l2_loss( + params['w'].dot(x), jnp.array(y) + ) grads = jax.grad(compute_loss)(params, xs, ys) updates, opt_state = optimizer.update(grads, opt_state, params) @@ -76,19 +69,14 @@ class KitchenSinkFailTest(absltest.TestCase): """Test transform_chain exceptional behavior.""" def test_bad_hp(self): - self.assertRaises(TypeError, kitchen_sink, - {'0': { - 'element': 'nesterov', - 'hps': { - 'asdf': 1e-4 - } - }}) + self.assertRaises( + TypeError, + kitchen_sink, + {'0': {'element': 'nesterov', 'hps': {'asdf': 1e-4}}}, + ) def test_bad_transform_name(self): - self.assertRaises(ValueError, kitchen_sink, - {'0': { - 'element': 'rasputin' - }}) + self.assertRaises(ValueError, kitchen_sink, {'0': {'element': 'rasputin'}}) class KitchenSinkMaskTest(absltest.TestCase): @@ -99,7 +87,7 @@ def test_no_op(self): optimizer = kitchen_sink({ '0': { 'element': 'nesterov', - 'mask': lambda p: jax.tree.map(lambda x: x.ndim != 1, p) + 'mask': lambda p: jax.tree.map(lambda x: x.ndim != 1, p), } }) params = {'w': jnp.array([1, 2, 3])} @@ -113,7 +101,7 @@ def test_mask_dim_1(self): optimizer = kitchen_sink({ '0': { 'element': 'nesterov', - 'mask': lambda p: jax.tree.map(lambda x: x.ndim != 1, p) + 'mask': lambda p: jax.tree.map(lambda x: x.ndim != 1, p), } }) params = {'w': jnp.array([1, 2, 3]), 'b': jnp.ones((2, 2))} @@ -133,23 +121,18 @@ class FromHParamsTest(chex.TestCase): """Test construction from opt_hparams ml_collections.ConfigDict.""" def test_empty_element(self): - self.assertRaises(ValueError, kitchen_sink, - ml_collections.ConfigDict({'0': {}})) + self.assertRaises( + ValueError, kitchen_sink, ml_collections.ConfigDict({'0': {}}) + ) def test_empty_hps(self): - kitchen_sink( - ml_collections.ConfigDict({'0': { - 'element': 'nesterov' - }})) + kitchen_sink(ml_collections.ConfigDict({'0': {'element': 'nesterov'}})) def test_equal_structs(self): """Test that initial and updated states have the same tree structure.""" - tx = kitchen_sink( - ml_collections.ConfigDict({'0': { - 'element': 'nesterov' - }})) - params = {'a': 1.} - gradients = {'a': 2.} + tx = kitchen_sink(ml_collections.ConfigDict({'0': {'element': 'nesterov'}})) + params = {'a': 1.0} + gradients = {'a': 2.0} state = tx.init(params) _, new_state = tx.update(gradients, state) chex.assert_trees_all_equal_structs(state, new_state) @@ -157,33 +140,24 @@ def test_equal_structs(self): def test_empty_mask(self): kitchen_sink( ml_collections.ConfigDict( - {'0': { - 'element': 'nesterov', - 'hps': { - 'decay': 0.9 - } - }})) + {'0': {'element': 'nesterov', 'hps': {'decay': 0.9}}} + ) + ) def test_one_minus(self): """Test that we appropriately process / remove hps.""" tx = kitchen_sink( ml_collections.ConfigDict( - {'0': { - 'element': 'nesterov', - 'hps': { - 'decay': 0.9 - } - }})) + {'0': {'element': 'nesterov', 'hps': {'decay': 0.9}}} + ) + ) tx_one_minus = kitchen_sink( ml_collections.ConfigDict( - {'0': { - 'element': 'nesterov', - 'hps': { - 'one_minus_decay': 0.1 - } - }})) + {'0': {'element': 'nesterov', 'hps': {'one_minus_decay': 0.1}}} + ) + ) - params = {'a': 1.} + params = {'a': 1.0} state = tx.init(params) updates, state = tx.update(params, state, params) @@ -201,7 +175,7 @@ class GraftCombinatorTest(chex.TestCase): def test_1d(self): """Test grafting.""" - x = 10. + x = 10.0 lr = 0.01 g = kitchen_sink( @@ -209,20 +183,17 @@ def test_1d(self): 'join': { 'mag_chain': { 'element': 'sgd', - 'hps': { - 'learning_rate': -1. - }, + 'hps': {'learning_rate': -1.0}, }, 'dir_chain': { 'element': 'sgd', - 'hps': { - 'learning_rate': -2. - }, - } + 'hps': {'learning_rate': -2.0}, + }, }, 'by': 'grafting', }, - learning_rate=lr) + learning_rate=lr, + ) state = g.init(x) for _ in range(10): grad_fn = jax.value_and_grad(lambda x: x**2) @@ -253,36 +224,31 @@ def __call__(self, x): model = SimpleNN() params = model.init(key, x) - m = optax.sgd(learning_rate=-1., momentum=0.9) - d = optax.sgd(learning_rate=-2., momentum=0.9) + m = optax.sgd(learning_rate=-1.0, momentum=0.9) + d = optax.sgd(learning_rate=-2.0, momentum=0.9) g = kitchen_sink( { 'join': { 'mag_chain': { 'element': 'sgd', - 'hps': { - 'learning_rate': -1., - 'momentum': 0.9 - }, + 'hps': {'learning_rate': -1.0, 'momentum': 0.9}, }, 'dir_chain': { 'element': 'sgd', - 'hps': { - 'learning_rate': -2., - 'momentum': 0.9 - }, - } + 'hps': {'learning_rate': -2.0, 'momentum': 0.9}, + }, }, 'by': 'grafting', }, - learning_rate=lr) + learning_rate=lr, + ) s_m = m.init(params) s_d = d.init(params) s_g = g.init(params) def loss_fn(params): yhat = model.apply(params, x) - loss = jnp.sum((y - yhat)**2) + loss = jnp.sum((y - yhat) ** 2) return loss for _ in range(10): @@ -294,8 +260,9 @@ def loss_fn(params): u_m_n = jax.tree.map(jnp.linalg.norm, u_m) u_d_n = jax.tree.map(jnp.linalg.norm, u_d) - u_g2 = jax.tree.map(lambda m, d, dn: -lr * d / (dn + 1e-6) * m, u_m_n, - u_d, u_d_n) + u_g2 = jax.tree.map( + lambda m, d, dn: -lr * d / (dn + 1e-6) * m, u_m_n, u_d, u_d_n + ) chex.assert_trees_all_close(u_g, u_g2) diff --git a/init2winit/optimizer_lib/kitchen_sink/_src/test_mask.py b/init2winit/optimizer_lib/kitchen_sink/_src/test_mask.py index 310ed91f..68a5538a 100644 --- a/init2winit/optimizer_lib/kitchen_sink/_src/test_mask.py +++ b/init2winit/optimizer_lib/kitchen_sink/_src/test_mask.py @@ -43,8 +43,8 @@ def __call__(self, x): use_running_average=not self.train, momentum=0.9, epsilon=1e-5, - dtype=jnp.float32)( - x) + dtype=jnp.float32, + )(x) return x @@ -83,44 +83,16 @@ def test_simple(self): """Check that the correct tags are removed.""" mask = create_weight_decay_mask() data = { - 'bias': { - 'b': 4 - }, - 'bias': { - 'BatchNorm_0': 4, - 'bias': 5, - 'a': 0 - }, - 'BatchNorm_0': { - 'b': 4 - }, - 'a': { - 'b': { - 'BatchNorm_0': 0, - 'bias': 0 - }, - 'c': 0 - } + 'bias': {'b': 4}, + 'bias': {'BatchNorm_0': 4, 'bias': 5, 'a': 0}, + 'BatchNorm_0': {'b': 4}, + 'a': {'b': {'BatchNorm_0': 0, 'bias': 0}, 'c': 0}, } truth = { - 'bias': { - 'b': False - }, - 'bias': { - 'BatchNorm_0': False, - 'bias': False, - 'a': False - }, - 'BatchNorm_0': { - 'b': False - }, - 'a': { - 'b': { - 'BatchNorm_0': True, - 'bias': False - }, - 'c': True - } + 'bias': {'b': False}, + 'bias': {'BatchNorm_0': False, 'bias': False, 'a': False}, + 'BatchNorm_0': {'b': False}, + 'a': {'b': {'BatchNorm_0': True, 'bias': False}, 'c': True}, } chex.assert_equal(mask(data), truth) @@ -140,9 +112,9 @@ def test_batch(self): @self.variant def train_step(params, x, y): - y1, new_batch_stats = Foo( - filters=7, train=True).apply( - params, x, mutable=['batch_stats']) + y1, new_batch_stats = Foo(filters=7, train=True).apply( + params, x, mutable=['batch_stats'] + ) return jnp.abs(y - y1).sum(), new_batch_stats @@ -150,8 +122,9 @@ def train_step(params, x, y): grads, _ = jax.grad(train_step, has_aux=True)(foo_vars, x, y) updates, state = self.variant(tx.update)(dict(grads['params']), state) - chex.assert_trees_all_close(updates['BatchNorm_0'], - grads['params']['BatchNorm_0']) + chex.assert_trees_all_close( + updates['BatchNorm_0'], grads['params']['BatchNorm_0'] + ) @chex.variants(with_jit=True, without_jit=True) def test_bias(self): @@ -174,8 +147,10 @@ def loss(params, x, y): updates, state = self.variant(tx.update)(dict(grads), state) for i in range(3): - chex.assert_trees_all_close(grads['params'][f'layers_{i}']['bias'], - updates['params'][f'layers_{i}']['bias']) + chex.assert_trees_all_close( + grads['params'][f'layers_{i}']['bias'], + updates['params'][f'layers_{i}']['bias'], + ) if __name__ == '__main__': diff --git a/init2winit/optimizer_lib/kitchen_sink/_src/test_preconditioner.py b/init2winit/optimizer_lib/kitchen_sink/_src/test_preconditioner.py index 842e3f8a..cfbbdf57 100644 --- a/init2winit/optimizer_lib/kitchen_sink/_src/test_preconditioner.py +++ b/init2winit/optimizer_lib/kitchen_sink/_src/test_preconditioner.py @@ -75,11 +75,12 @@ def test_ema_accumulator(self, decay, debias): grads, _ = nth_grads.update(grads, None) updates, state = accumulator.update(grads, state) - actual = jax.tree.map(lambda g, t: (1 - decay) * g + decay * t, - grads['variables'], moments) + actual = jax.tree.map( + lambda g, t: (1 - decay) * g + decay * t, grads['variables'], moments + ) if debias: count += jnp.array(1, dtype=jnp.int32) - beta = jnp.array(1, dtype=jnp.int32) - decay ** count + beta = jnp.array(1, dtype=jnp.int32) - decay**count actual = jax.tree.map(lambda t: t / beta.astype(t.dtype), actual) # pylint: disable=cell-var-from-loop chex.assert_trees_all_close(updates['moments'], actual) @@ -99,18 +100,20 @@ def test_precondition_by_rms(self, decay, eps, eps_root, debias): """Test precondition_by_rms.""" actual_rms = transform.precondition_by_rms( - decay=decay, eps=eps, eps_root=eps_root, debias=debias) - - decon_rms = preconditioner.preconditioner(preconditioner.nth_power, - preconditioner.ema_accumulator, - preconditioner.rexp_updater, - {'power': 2}, { - 'decay': decay, - 'debias': debias - }, { - 'eps': eps, - 'eps_root': eps_root, - }) + decay=decay, eps=eps, eps_root=eps_root, debias=debias + ) + + decon_rms = preconditioner.preconditioner( + preconditioner.nth_power, + preconditioner.ema_accumulator, + preconditioner.rexp_updater, + {'power': 2}, + {'decay': decay, 'debias': debias}, + { + 'eps': eps, + 'eps_root': eps_root, + }, + ) params = jax.random.uniform(jax.random.PRNGKey(0), (10, 10)) @@ -139,8 +142,9 @@ def test_precondition_by_rms(self, decay, eps, eps_root, debias): initial_accumulator_value=[1e-8, 1e-6, 1e-1], debias=[True, False], ) - def test_precondition_by_yogi(self, b2, eps, eps_root, - initial_accumulator_value, debias): + def test_precondition_by_yogi( + self, b2, eps, eps_root, initial_accumulator_value, debias + ): """Test precondition_by_yogi.""" actual_yogi = transform.precondition_by_yogi( @@ -148,18 +152,24 @@ def test_precondition_by_yogi(self, b2, eps, eps_root, eps=eps, eps_root=eps_root, initial_accumulator_value=initial_accumulator_value, - debias=debias) + debias=debias, + ) decon_yogi = preconditioner.preconditioner( - preconditioner.nth_power, preconditioner.yogi_accumulator, - preconditioner.rexp_updater, {'power': 2}, { + preconditioner.nth_power, + preconditioner.yogi_accumulator, + preconditioner.rexp_updater, + {'power': 2}, + { 'b2': b2, 'initial_accumulator_value': initial_accumulator_value, - 'debias': debias - }, { + 'debias': debias, + }, + { 'eps': eps, 'eps_root': eps_root, - }) + }, + ) params = jax.random.uniform(jax.random.PRNGKey(0), (10, 10)) @@ -180,5 +190,6 @@ def test_precondition_by_yogi(self, b2, eps, eps_root, chex.assert_trees_all_close(actual_params, decon_params, atol=1e-4) + if __name__ == '__main__': absltest.main() diff --git a/init2winit/optimizer_lib/kitchen_sink/_src/test_transform.py b/init2winit/optimizer_lib/kitchen_sink/_src/test_transform.py index 796fc4b4..a4051531 100644 --- a/init2winit/optimizer_lib/kitchen_sink/_src/test_transform.py +++ b/init2winit/optimizer_lib/kitchen_sink/_src/test_transform.py @@ -34,8 +34,9 @@ def _optimizer_loop(optimizer, iterations=5): results = [] for _ in range(iterations): compute_loss = lambda params, x, y: optax.l2_loss(params['w'].dot(x), y) - grads = jax.grad(compute_loss)(params, jnp.array([5.0, 6.0]), - jnp.array(4.0)) + grads = jax.grad(compute_loss)( + params, jnp.array([5.0, 6.0]), jnp.array(4.0) + ) updates, opt_state = optimizer.update(grads, opt_state, params) params = optax.apply_updates(params, updates) results.append(params) @@ -46,7 +47,8 @@ def _optimizers_all_close(tx1, tx2, iterations=5, rtol=1e-5): chex.assert_trees_all_close( _optimizer_loop(tx1, iterations), _optimizer_loop(tx2, iterations), - rtol=rtol) + rtol=rtol, + ) class AdamTest(parameterized.TestCase): @@ -62,7 +64,8 @@ def test_adam(self, b1, b2, eps, eps_root): """Test adam. Thoroughly.""" tx1 = optax.scale_by_adam(b1=b1, b2=b2, eps=eps, eps_root=eps_root) tx2 = transform.scale_by_adam( - b1=b1, b2=b2, eps=eps, eps_root=eps_root, debias=True) + b1=b1, b2=b2, eps=eps, eps_root=eps_root, debias=True + ) _optimizers_all_close(tx1, tx2) @@ -74,30 +77,16 @@ def test_correctness(self): """Testing correctness via an independent flax.optim run.""" target_solution = [ - { - 'w': jnp.array([0.40500003, 0.286]) - }, - { - 'w': jnp.array([0.255515, 0.106618]) - }, - { - 'w': jnp.array([0.31884143, 0.18260972]) - }, - { - 'w': jnp.array([0.40163627, 0.28196353]) - }, - { - 'w': jnp.array([0.43924114, 0.32708937]) - }, + {'w': jnp.array([0.40500003, 0.286])}, + {'w': jnp.array([0.255515, 0.106618])}, + {'w': jnp.array([0.31884143, 0.18260972])}, + {'w': jnp.array([0.40163627, 0.28196353])}, + {'w': jnp.array([0.43924114, 0.32708937])}, ] optimizer = kitchen_sink( - {'0': { - 'element': 'nesterov', - 'hps': { - 'decay': 0.7 - } - }}, - learning_rate=0.01) + {'0': {'element': 'nesterov', 'hps': {'decay': 0.7}}}, + learning_rate=0.01, + ) results = _optimizer_loop(optimizer) for target, result in zip(target_solution, results): chex.assert_trees_all_close(target, result) @@ -110,30 +99,16 @@ def test_correctness(self): """Testing correctness via an independent flax.optim run.""" target_solution = [ - { - 'w': jnp.array([0.65, 0.58000004]) - }, - { - 'w': jnp.array([0.26849997, 0.12220004]) - }, - { - 'w': jnp.array([0.09766498, -0.08280197]) - }, - { - 'w': jnp.array([0.17850482, 0.01420582]) - }, - { - 'w': jnp.array([0.38620475, 0.2634457]) - }, + {'w': jnp.array([0.65, 0.58000004])}, + {'w': jnp.array([0.26849997, 0.12220004])}, + {'w': jnp.array([0.09766498, -0.08280197])}, + {'w': jnp.array([0.17850482, 0.01420582])}, + {'w': jnp.array([0.38620475, 0.2634457])}, ] optimizer = kitchen_sink( - {'0': { - 'element': 'polyak_hb', - 'hps': { - 'decay': 0.7 - } - }}, - learning_rate=0.01) + {'0': {'element': 'polyak_hb', 'hps': {'decay': 0.7}}}, + learning_rate=0.01, + ) results = _optimizer_loop(optimizer) for target, result in zip(target_solution, results): chex.assert_trees_all_close(target, result) @@ -154,9 +129,9 @@ def init_fn(params): def update_fn(updates, state, params=None): del params state['count'] += 1 - state['w'] = ((1 - decay) * updates['w'] + decay * state['w']) + state['w'] = (1 - decay) * updates['w'] + decay * state['w'] if debias: - update = {'w': state['w'] / (1 - decay**state['count'])} + update = {'w': state['w'] / (1 - decay ** state['count'])} else: update = {'w': state['w']} return update, state @@ -165,18 +140,16 @@ def update_fn(updates, state, params=None): decay = 0.7 learning_rate = 0.01 - true_ema = optax.chain(ema(decay), optax.scale(-1. * learning_rate)) + true_ema = optax.chain(ema(decay), optax.scale(-1.0 * learning_rate)) ks_ema = kitchen_sink( { '0': { 'element': 'first_moment_ema', - 'hps': { - 'decay': decay, - 'debias': True - } + 'hps': {'decay': decay, 'debias': True}, } }, - learning_rate=learning_rate) + learning_rate=learning_rate, + ) targets = _optimizer_loop(true_ema) results = _optimizer_loop(ks_ema) @@ -192,12 +165,7 @@ def test_debias_false(self): precondition_by_rms = kitchen_sink({ '0': { 'element': 'precondition_by_rms', - 'hps': { - 'decay': 0.9, - 'debias': False, - 'eps': 0, - 'eps_root': 1e-8 - } + 'hps': {'decay': 0.9, 'debias': False, 'eps': 0, 'eps_root': 1e-8}, } }) targets = _optimizer_loop(rms_prop) @@ -208,20 +176,17 @@ def test_debias_false(self): def test_debias_true(self): adam = kitchen_sink( - {'0': { - 'element': 'scale_by_adam', - 'hps': { - 'b1': 0.0 - }, - }}) + { + '0': { + 'element': 'scale_by_adam', + 'hps': {'b1': 0.0}, + } + } + ) precondition_by_rms = kitchen_sink( - {'0': { - 'element': 'precondition_by_rms', - 'hps': { - 'debias': True - } - }}) + {'0': {'element': 'precondition_by_rms', 'hps': {'debias': True}}} + ) targets = _optimizer_loop(adam) results = _optimizer_loop(precondition_by_rms) @@ -242,11 +207,12 @@ def test_with_0_momentum_yogi(self): 'b2': 0.9, 'debias': True, 'eps': 1e-8, - 'eps_root': 1e-6 - } + 'eps_root': 1e-6, + }, } }, - learning_rate=1.0) + learning_rate=1.0, + ) targets = _optimizer_loop(optax_yogi) results = _optimizer_loop(precondition_by_yogi) @@ -276,21 +242,27 @@ def init_fn(params): return State( nu=jax.tree.map(jnp.zeros_like, params), trace=jax.tree.map(jnp.zeros_like, params), - count=jnp.zeros([], jnp.int32)) + count=jnp.zeros([], jnp.int32), + ) def update_fn(updates, state, params=None): del params count = state.count + jnp.array(1, jnp.int32) nu = { - 'w': (1 - rms_decay) * (updates['w']**2) + rms_decay * state.nu['w'] + 'w': ( + (1 - rms_decay) * (updates['w'] ** 2) + + rms_decay * state.nu['w'] + ) } updates = {'w': updates['w'] / (jax.lax.sqrt(nu['w'] + eps_root) + eps)} updates = {'w': updates['w'] * jnp.sqrt((1 - rms_decay**count))} trace = { - 'w': (1 - moment_decay) * updates['w'] + - moment_decay * state.trace['w'] + 'w': ( + (1 - moment_decay) * updates['w'] + + moment_decay * state.trace['w'] + ) } updates = {'w': trace['w']} @@ -308,16 +280,13 @@ def update_fn(updates, state, params=None): 'decay': rms_decay, 'eps': eps, 'eps_root': eps_root, - 'debias': True - } + 'debias': True, + }, }, '1': { 'element': 'first_moment_ema', - 'hps': { - 'decay': moment_decay, - 'debias': True - } - } + 'hps': {'decay': moment_decay, 'debias': True}, + }, }) targets = _optimizer_loop(true_twisted_adam) @@ -344,19 +313,18 @@ def update_fn(updates, state, params=None): _, state = adam.update(updates, state, params) curr_nu = state.nu nu_hat = jax.tree.map(jnp.maximum, curr_nu, prev_nu) - updates = jax.tree.map(lambda m, v: m / (jnp.sqrt(v + 0.0) + 1e-8), - state.mu, nu_hat) + updates = jax.tree.map( + lambda m, v: m / (jnp.sqrt(v + 0.0) + 1e-8), state.mu, nu_hat + ) return updates, optax.ScaleByAdamState( - count=state.count, mu=state.mu, nu=nu_hat) + count=state.count, mu=state.mu, nu=nu_hat + ) return optax.GradientTransformation(init_fn, update_fn) true_amsgrad = amsgrad() - ks_amsgrad = kitchen_sink( - {'0': { - 'element': 'scale_by_amsgrad' - }}) + ks_amsgrad = kitchen_sink({'0': {'element': 'scale_by_amsgrad'}}) targets = _optimizer_loop(true_amsgrad) results = _optimizer_loop(ks_amsgrad) @@ -376,16 +344,17 @@ def test_adagrad(self): 'element': 'precondition_by_rss', 'hps': { 'initial_accumulator_value': 0.3, - } + }, }, '1': { 'element': 'first_moment_ema', 'hps': { 'decay': 0.0, - } - } + }, + }, }, - learning_rate=0.7) + learning_rate=0.7, + ) targets = _optimizer_loop(true_adagrad) results = _optimizer_loop(ks_adagrad) @@ -400,48 +369,26 @@ class EqEMAHBTest(chex.TestCase): def test_equivalence(self): hb = kitchen_sink( { - '0': { - 'element': 'precondition_by_rms', - 'hps': { - 'decay': 0.3 - } - }, - '1': { - 'element': 'polyak_hb', - 'hps': { - 'decay': 0.5 - } - }, + '0': {'element': 'precondition_by_rms', 'hps': {'decay': 0.3}}, + '1': {'element': 'polyak_hb', 'hps': {'decay': 0.5}}, '2': { 'element': 'add_decayed_weights', - 'hps': { - 'weight_decay': 0.1 - } - } + 'hps': {'weight_decay': 0.1}, + }, }, - learning_rate=1.0) + learning_rate=1.0, + ) ema = kitchen_sink( { - '0': { - 'element': 'precondition_by_rms', - 'hps': { - 'decay': 0.3 - } - }, - '1': { - 'element': 'first_moment_ema', - 'hps': { - 'decay': 0.5 - } - }, + '0': {'element': 'precondition_by_rms', 'hps': {'decay': 0.3}}, + '1': {'element': 'first_moment_ema', 'hps': {'decay': 0.5}}, '2': { 'element': 'add_decayed_weights', - 'hps': { - 'weight_decay': 0.05 - } - } + 'hps': {'weight_decay': 0.05}, + }, }, - learning_rate=2.0) + learning_rate=2.0, + ) targets = _optimizer_loop(hb) results = _optimizer_loop(ema) @@ -465,11 +412,12 @@ def test_output_modality_1(self): 'decays': decays, 'scales': scales, 'decay_distribution': decay_distribution, - 'eps_root': 0.0 - } + 'eps_root': 0.0, + }, }, }, - learning_rate=1.0) + learning_rate=1.0, + ) scales = jnp.array([0.9, 0.5, 1.0, 1.0]) betas = jnp.array([0.19, 0.75, 1.0, 1.0]) one_minus_betas = jnp.array([0.81, 0.25, 1.0, 1.0]) @@ -477,7 +425,7 @@ def test_output_modality_1(self): opt_state = ks_opt.init(params) # step 1 grads = {'w': 2 * jnp.ones((4,))} - true_nu = one_minus_betas * (grads['w']**2) + true_nu = one_minus_betas * (grads['w'] ** 2) true_updates = { 'w': -1.0 * jnp.array(scales) * grads['w'] / jnp.sqrt(true_nu) } @@ -486,7 +434,7 @@ def test_output_modality_1(self): params = optax.apply_updates(params, opt_updates) # step2 grads = {'w': jnp.ones((4,))} - true_nu = one_minus_betas * (grads['w']**2) + betas * true_nu + true_nu = one_minus_betas * (grads['w'] ** 2) + betas * true_nu true_updates = { 'w': -1.0 * jnp.array(scales) * grads['w'] / jnp.sqrt(true_nu) } diff --git a/init2winit/optimizer_lib/kitchen_sink/_src/transform.py b/init2winit/optimizer_lib/kitchen_sink/_src/transform.py index db147e14..9b7726b7 100644 --- a/init2winit/optimizer_lib/kitchen_sink/_src/transform.py +++ b/init2winit/optimizer_lib/kitchen_sink/_src/transform.py @@ -24,7 +24,6 @@ import jax.numpy as jnp import optax - # pylint:disable=invalid-name # pylint:disable=no-value-for-parameter @@ -66,8 +65,7 @@ def _update_preconditioner_moment(updates, moments, decay, order): assert order >= 1 and order <= 2 moment_func = lambda x: jnp.power(jnp.abs(x), order) return jax.tree.map( - lambda g, t: (1 - decay) * moment_func(g) + decay * t, - updates, moments + lambda g, t: (1 - decay) * moment_func(g) + decay * t, updates, moments ) @@ -192,7 +190,8 @@ def compute_params_ema_for_eval( def init_fn(params): return optax.EmaState( - count=jnp.array(0, dtype=jnp.int32), ema=jax.tree.map(jnp.copy, params)) + count=jnp.array(0, dtype=jnp.int32), ema=jax.tree.map(jnp.copy, params) + ) def update_fn(updates, state, params): if params is None: @@ -200,7 +199,7 @@ def update_fn(updates, state, params): if warmup: # https://github.com/tensorflow/tensorflow/blob/v2.9.1/tensorflow/python/training/moving_averages.py#L469 - ema_decay = jnp.minimum(decay, (1. + state.count) / (10. + state.count)) + ema_decay = jnp.minimum(decay, (1.0 + state.count) / (10.0 + state.count)) else: ema_decay = decay @@ -535,9 +534,7 @@ class BiasCorrectionState(NamedTuple): count: chex.Array # shape=(), dtype=jnp.int32. -def bias_correction( - decay: float = 0.9 -) -> optax.GradientTransformation: +def bias_correction(decay: float = 0.9) -> optax.GradientTransformation: """Compute the Adam style bias correction. Args: @@ -546,6 +543,7 @@ def bias_correction( Returns: An (init_fn, update_fn) tuple. """ + def init_fn(params): del params return BiasCorrectionState(count=jnp.zeros([], jnp.int32)) @@ -587,7 +585,7 @@ def init_fn(params): jnp.zeros_like, params ) # previous update with step-size/lr included return ScaleBy_Adaptive_GD_State( - r_squared=init_r_squared*jnp.ones([], jnp.float64), + r_squared=init_r_squared * jnp.ones([], jnp.float64), lambda_prev=jnp.zeros([], jnp.float64), lambda_sum=jnp.zeros([], jnp.float64), init_params=init_params, @@ -679,9 +677,7 @@ def update_fn(updates, state, params): state.prev_update, state.lambda_prev, ) - new_update_norm_squared = jax.tree.map( - lambda u: jnp.sum(u ** 2), new_updates - ) + new_update_norm_squared = jax.tree.map(lambda u: jnp.sum(u**2), new_updates) lambda_new = jax.tree.map( lambda l, g, r: 0.5 * (jnp.sqrt(l**2 + jnp.divide(g, r)) - l), state.lambda_sum, @@ -751,9 +747,7 @@ def update_fn(updates, state, params): state.prev_update, state.lambda_prev, ) - new_update_norm_squared = jax.tree.map( - jnp.square, new_updates - ) + new_update_norm_squared = jax.tree.map(jnp.square, new_updates) lambda_new = jax.tree.map( lambda l, g, r: 0.5 * (jnp.sqrt(jnp.square(l) + jnp.divide(g, r)) - l), state.lambda_sum, @@ -763,9 +757,7 @@ def update_fn(updates, state, params): lambda_sum_new = jax.tree.map( lambda l1, l2: l1 + l2, state.lambda_sum, lambda_new ) - new_updates_with_lr = jax.tree.map( - jnp.divide, new_updates, lambda_sum_new - ) + new_updates_with_lr = jax.tree.map(jnp.divide, new_updates, lambda_sum_new) negative_new_updates_with_lr = jax.tree.map( lambda u: -u, new_updates_with_lr ) @@ -821,13 +813,11 @@ def update_fn(updates, state, params): mu_sum_new = 0.5 * ( jnp.sqrt( state.mu_sum**2 - + jnp.divide((4*update_norm_squared), curr_r_squared) + + jnp.divide((4 * update_norm_squared), curr_r_squared) ) + state.mu_sum ) - new_updates_with_lr = jax.tree.map( - lambda u: u / mu_sum_new, updates - ) + new_updates_with_lr = jax.tree.map(lambda u: u / mu_sum_new, updates) return new_updates_with_lr, ScaleBy_Adaptive_GD_Simple_State( r_squared=curr_r_squared, mu_sum=mu_sum_new, @@ -870,18 +860,14 @@ def update_fn(updates, state, params): state.r_squared, curr_distance_norm_squared, ) - update_norm_squared = jax.tree.map( - lambda u: jnp.sum(u ** 2), updates - ) + update_norm_squared = jax.tree.map(lambda u: jnp.sum(u**2), updates) mu_sum_new = jax.tree.map( lambda l, g, r: 0.5 * (jnp.sqrt(l**2 + 4 * jnp.divide(g, r)) + l), state.mu_sum, update_norm_squared, curr_r_squared, ) - new_updates_with_lr = jax.tree.map( - lambda u, l: u / l, updates, mu_sum_new - ) + new_updates_with_lr = jax.tree.map(lambda u, l: u / l, updates, mu_sum_new) return new_updates_with_lr, ScaleBy_Adaptive_GD_Simple_State( r_squared=curr_r_squared, mu_sum=mu_sum_new, @@ -909,11 +895,9 @@ def init_fn(params): init_params = jax.tree.map(jnp.copy, params) # x0 return ScaleBy_Adaptive_GD_Simple_State( r_squared=jax.tree.map( - lambda x: init_r_squared*jnp.ones_like(x), params - ), - mu_sum=jax.tree.map( - lambda x: eps*jnp.ones_like(x), params + lambda x: init_r_squared * jnp.ones_like(x), params ), + mu_sum=jax.tree.map(lambda x: eps * jnp.ones_like(x), params), init_params=init_params, ) @@ -926,18 +910,15 @@ def update_fn(updates, state, params): state.r_squared, curr_distance_norm_squared, ) - update_norm_squared = jax.tree.map( - jnp.square, updates - ) + update_norm_squared = jax.tree.map(jnp.square, updates) mu_sum_new = jax.tree.map( - lambda l, g, r: 0.5*(jnp.sqrt(jnp.square(l) + 4*jnp.divide(g, r)) + l), + lambda l, g, r: 0.5 + * (jnp.sqrt(jnp.square(l) + 4 * jnp.divide(g, r)) + l), state.mu_sum, update_norm_squared, curr_r_squared, ) - new_updates_with_lr = jax.tree.map( - jnp.divide, updates, mu_sum_new - ) + new_updates_with_lr = jax.tree.map(jnp.divide, updates, mu_sum_new) return new_updates_with_lr, ScaleBy_Adaptive_GD_Simple_State( r_squared=curr_r_squared, mu_sum=mu_sum_new, @@ -1220,8 +1201,8 @@ def scale_by_adam_var_preserved( eps_root: term added to the denominator inside the square-root to improve numerical stability when backpropagating gradients through the rescaling. debias: whether to use moment bias correction. Note inspite of - implementation Adam style bias correction might not make sense here. - So it should not be used. + implementation Adam style bias correction might not make sense here. So it + should not be used. power: the power to use in the preconditioner (0.5 in default adam). Returns: @@ -1256,6 +1237,7 @@ def update_fn(updates, state, params=None): class ScaleByAdapropState(NamedTuple): """State for the AdaProp algorithm.""" + count: chex.Array # shape=(), dtype=jnp.int32. pp: optax.Updates mu: optax.Updates @@ -1277,8 +1259,8 @@ def scale_by_adaprop( Args: b1: decay rate for the exponentially weighted average of grads. - b2: decay rate for the exponentially weighted average of absolute grads - is omitted because it is calculated from alpha and b1. + b2: decay rate for the exponentially weighted average of absolute grads is + omitted because it is calculated from alpha and b1. b3: decay rate for the exponentially weighted average of max grads. b4: decay rate for the exponentially weighted average of reward. eps: term added to the denominator to improve numerical stability. @@ -1331,11 +1313,16 @@ def update_fn(updates, state, params): pp = _update_moment(params, state.pp, b4, 1) pp_hat = _bias_correction(pp, b4, new_count) param_change = jax.tree.map(lambda p, i: p - i, params, pp_hat) - g_max = jax.tree.map(lambda g, n: jnp.maximum(jnp.abs(g), raise_power(n)), - updates, nu_hat) + g_max = jax.tree.map( + lambda g, n: jnp.maximum(jnp.abs(g), raise_power(n)), updates, nu_hat + ) gain = jax.tree.map( - lambda r, p, g, x: jnp.maximum(b3*r - p*g/(x + eps), 0.0), - state.gain, param_change, updates, g_max) + lambda r, p, g, x: jnp.maximum(b3 * r - p * g / (x + eps), 0.0), + state.gain, + param_change, + updates, + g_max, + ) wealth = jax.tree.map(lambda g: 1.0 + g, gain) bet_factor = jax.tree.map( @@ -1343,8 +1330,7 @@ def update_fn(updates, state, params): mu_hat, nu_hat, ) - new_updates = jax.tree.map(lambda b, w: b * w, - bet_factor, wealth) + new_updates = jax.tree.map(lambda b, w: b * w, bet_factor, wealth) return new_updates, ScaleByAdapropState( count=new_count, pp=pp, @@ -1523,7 +1509,6 @@ def scale_by_nadam( power to which the absolute value of the grads are raised). use_nesterov: whether to use nesterov update. - Returns: An (init_fn, update_fn) tuple. """ @@ -1813,12 +1798,11 @@ def project_parameters_by_norm( Args: projection_radius: The norm of the projected parameters. order: Order of the norm used for projection. Default is None, i.e. 2. - axis: Axes along which the projection happens. - Default is None, i.e. the entire parameter is projected. + axis: Axes along which the projection happens. Default is None, i.e. the + entire parameter is projected. flip_update_sign: Returns: - """ m = -1 if flip_update_sign else 1 diff --git a/init2winit/optimizer_lib/kitchen_sink/_src/utils.py b/init2winit/optimizer_lib/kitchen_sink/_src/utils.py index 77503971..87b34e29 100644 --- a/init2winit/optimizer_lib/kitchen_sink/_src/utils.py +++ b/init2winit/optimizer_lib/kitchen_sink/_src/utils.py @@ -68,8 +68,10 @@ def wrapped_init_fn(params): def wrapped_update_fn(updates, state, params=None): new_updates, state = update_fn( - flax.core.unfreeze(updates), state, - None if params is None else flax.core.unfreeze(params)) + flax.core.unfreeze(updates), + state, + None if params is None else flax.core.unfreeze(params), + ) if isinstance(updates, flax.core.FrozenDict): new_updates = flax.core.freeze(new_updates) @@ -95,8 +97,11 @@ def is_scale_by_lr(x): return not isinstance(x, str) and x['element'] == 'scale_by_learning_rate' def contains_lr_as_param(x): - return not isinstance(x, str) and x.get( - 'hps', None) and 'learning_rate' in x['hps'] + return ( + not isinstance(x, str) + and x.get('hps', None) + and 'learning_rate' in x['hps'] + ) def update_leaf(x): if contains_lr_as_param(x): @@ -113,14 +118,14 @@ def update_leaf(x): '0': config, '1': { 'element': 'scale_by_learning_rate', - 'hps': { - 'learning_rate': learning_rate - } - } + 'hps': {'learning_rate': learning_rate}, + }, } } elif num_scaled == 1: return map_element(update_leaf, config) else: - logging.warning('Kitchen Sink configuration has more than one ' - 'scale_by_learning_rate. Please double check config') + logging.warning( + 'Kitchen Sink configuration has more than one ' + 'scale_by_learning_rate. Please double check config' + ) diff --git a/init2winit/optimizer_lib/linalg/low_rank_root_update.py b/init2winit/optimizer_lib/linalg/low_rank_root_update.py index ece83085..b2740319 100644 --- a/init2winit/optimizer_lib/linalg/low_rank_root_update.py +++ b/init2winit/optimizer_lib/linalg/low_rank_root_update.py @@ -49,6 +49,7 @@ def chol_inv(m): """ eye = jnp.identity(m.shape[-1], dtype=jnp.float32) l_m = jnp.linalg.cholesky(m) + def _get_inv(): d_m = lax.linalg.triangular_solve( l_m, eye, left_side=True, lower=True, transpose_a=False @@ -62,10 +63,8 @@ def _get_zeros(): def lyapunov_solver( - a: chex.Array, - c: chex.Array, - eps: float, - num_terms: int = 5): + a: chex.Array, c: chex.Array, eps: float, num_terms: int = 5 +): """Solves the Sylvester equation ax + xa = c using R. A. Smith's method. Args: @@ -76,7 +75,6 @@ def lyapunov_solver( Returns: solution to the Lyapunov equation - """ a_norm = norm_upper_bound(a) a_norm = jnp.where(a_norm == 0, 1, a_norm) @@ -103,7 +101,7 @@ def _loop_body(k, val): xx = xx + vv @ (xx @ vv.T) return (vv, xx) - (vv, xx) = lax.fori_loop(0, num_terms, _loop_body, (vv, xx)) + vv, xx = lax.fori_loop(0, num_terms, _loop_body, (vv, xx)) del vv return alph * xx @@ -129,14 +127,14 @@ def matrix_sqrt_update( Returns: Gram factor U of the update to the matrix sqrt. - """ + # given A^{1/2} and an update of the form ZZ^T, returns U such that # (A + ZZ^T)^{1/2} = A^{1/2} + UU^T and # actual update is stat_update @ stat_update.T def block_krylov_basis(a, q, k): pp = (q,) - for _ in range(k-1): + for _ in range(k - 1): pp = pp + (a @ pp[-1],) ks = jnp.hstack(pp) return jnp.linalg.qr(ks)[0] @@ -226,7 +224,5 @@ def _update_inv_sqrt(sqrt_a, isqrt_a, u): block_krylov_dim_multiplier, rng=rng, ) - new_sqrt_x, new_isqrt_x = _update_inv_sqrt( - sqrt_x, isqrt_x, u - ) + new_sqrt_x, new_isqrt_x = _update_inv_sqrt(sqrt_x, isqrt_x, u) return new_sqrt_x, new_isqrt_x diff --git a/init2winit/optimizer_lib/linalg/low_rank_root_update_test.py b/init2winit/optimizer_lib/linalg/low_rank_root_update_test.py index 500dd171..a39c3dfa 100644 --- a/init2winit/optimizer_lib/linalg/low_rank_root_update_test.py +++ b/init2winit/optimizer_lib/linalg/low_rank_root_update_test.py @@ -25,17 +25,19 @@ import scipy.stats -def _small_perturbation(n: int, gamma: float, - rng: np.random.RandomState) -> np.ndarray: +def _small_perturbation( + n: int, gamma: float, rng: np.random.RandomState +) -> np.ndarray: """Returns a vector of absolute values ofnormally distributed values with standard deviation gamma.""" - s = gamma*np.abs(rng.normal(size=n)) + s = gamma * np.abs(rng.normal(size=n)) return s -def _random_singular_values(n: int, gamma: float, - rng: np.random.RandomState) -> np.ndarray: +def _random_singular_values( + n: int, gamma: float, rng: np.random.RandomState +) -> np.ndarray: """Returns n random singular values in [γ, 1].""" - s = gamma**rng.random((n,)) # log of singular values uniformly distributed + s = gamma ** rng.random((n,)) # log of singular values uniformly distributed if n > 0: s[0] = gamma if n > 1: @@ -44,8 +46,8 @@ def _random_singular_values(n: int, gamma: float, def _random_svd( - n: int, gamma: float, - rng: np.random.RandomState) -> Tuple[np.ndarray, np.ndarray]: + n: int, gamma: float, rng: np.random.RandomState +) -> Tuple[np.ndarray, np.ndarray]: """Returns a random SVD decomposition with singular values in [γ, 1].""" # sample a uniformly random orthogonal matrix. v = scipy.stats.ortho_group.rvs(n, random_state=rng) @@ -84,9 +86,9 @@ def test_random_matrix(self, n, p): v = v.astype(np.float64) q = _small_perturbation(n, 1e-4, rng) q = np.diag(q.astype(np.float64)) - a_sqrt = (v * s**(1 / p)) @ v.T - a_isqrt = (v * s**(-1 / p)) @ v.T - exact = (v * (s + q**2)**(-1 / p)) @ v.T + a_sqrt = (v * s ** (1 / p)) @ v.T + a_isqrt = (v * s ** (-1 / p)) @ v.T + exact = (v * (s + q**2) ** (-1 / p)) @ v.T ans = _update_sqrt(a_sqrt, a_isqrt, q)[1] ans = np.array(ans).astype(np.float64) error = np.linalg.norm(ans - exact, 2) / np.linalg.norm(exact, 2) diff --git a/init2winit/optimizer_lib/linalg/paterson_stockmeyer.py b/init2winit/optimizer_lib/linalg/paterson_stockmeyer.py index 310f38d9..5a285e4d 100644 --- a/init2winit/optimizer_lib/linalg/paterson_stockmeyer.py +++ b/init2winit/optimizer_lib/linalg/paterson_stockmeyer.py @@ -14,6 +14,7 @@ # limitations under the License. """Paterson-Stockmeyer method for polynomial evaluation.""" + from typing import Any, Callable, List, Sequence, TypeVar import numpy as np @@ -31,8 +32,9 @@ def _powers(x: T, n: int, product: Callable[[T, T], T]) -> List[T]: return xp[1:] -def polynomial_no_constant(a: Sequence[Any], x: T, product: Callable[[T, T], - T]) -> T: +def polynomial_no_constant( + a: Sequence[Any], x: T, product: Callable[[T, T], T] +) -> T: """Paterson-Stockmeyer evaluation of a[0] x + a[1] x² + ... + a[n-1] xⁿ. A variant of the Paterson-Stockmeyer method for polynomial evaluation @@ -76,7 +78,7 @@ def polynomial_no_constant(a: Sequence[Any], x: T, product: Callable[[T, T], s = int(np.ceil(np.sqrt(n))) xp = _powers(x, s, product) inner = lambda alpha: sum([cj * xj for (cj, xj) in zip(alpha, xp)]) - inner_poly = lambda i: inner(a[s * i:min(n, s * (i + 1))]) + inner_poly = lambda i: inner(a[s * i : min(n, s * (i + 1))]) i = (n + s - 1) // s - 1 y = inner_poly(i) for i in reversed(range(i)): diff --git a/init2winit/optimizer_lib/linalg/pth_inv_root_rmn.py b/init2winit/optimizer_lib/linalg/pth_inv_root_rmn.py index 72049772..8c9e4505 100644 --- a/init2winit/optimizer_lib/linalg/pth_inv_root_rmn.py +++ b/init2winit/optimizer_lib/linalg/pth_inv_root_rmn.py @@ -75,7 +75,7 @@ def _scalar_inverse_root(x: chex.Array, n: int) -> chex.Array: elif n == 8: return lax.rsqrt(lax.sqrt(lax.sqrt(x))) else: - r = x**(1 / n) + r = x ** (1 / n) # One step of Newton's method to polish the root r = ((n - 1) / n) * r + (x / n) / _scalar_power(r, n - 1) return 1 / r @@ -125,8 +125,8 @@ def pth_inv_root_rmn( x: Input matrix must be SPD with eigenvalues >= float32 epsilon. p: Exponent. fast_root: If True, use a lower mixed degree approximation of x^{1/p} (in - the slower version, we use a mixed degree approximation of 2 and 3. - In the fast version, degree 2 and 3 is used). + the slower version, we use a mixed degree approximation of 2 and 3. In the + fast version, degree 2 and 3 is used). precision: Matrix multiplication precision to use (on TPU). See `jax.default_matmul_precision`. stable_iter: Whether to use the stable iteration for the inner loop. @@ -153,7 +153,7 @@ def pth_inv_root_rmn( ais, bis, cis = lax.cond( fast_root, lambda: pth_inv_root_rmn_coefficients.r12_schedule(p), - lambda: pth_inv_root_rmn_coefficients.r23_schedule(p) + lambda: pth_inv_root_rmn_coefficients.r23_schedule(p), ) max_k = len(cis) @@ -161,10 +161,11 @@ def pth_inv_root_rmn( def chol_inv(m): l_m = jnp.linalg.cholesky(m) e = jnp.finfo(jnp.float32).eps + def _avoid_nan_body(val): - (_, e, m) = val + _, e, m = val l = jnp.linalg.cholesky(m + jnp.diag(e * jnp.diag(m))) - e = 2*e + e = 2 * e return (l, e, m) l_m, _, _ = lax.while_loop( @@ -194,6 +195,7 @@ def general_iteration(k, val, p, symm_inv=chol_inv): if stable_iter: w_n = jnp.zeros((n, n), jnp.float32) y_inv = lax.cond(k == 0, lambda m: m, symm_inv, y) + def inside_iter(s, w_sum): a2 = lax.dynamic_index_in_dim(a, s, keepdims=False) b2 = lax.dynamic_index_in_dim(b, s, keepdims=False) @@ -205,6 +207,7 @@ def inside_iter(s, w_sum): y = c * (y + w_n) else: w = jnp.zeros((n, n), jnp.float32) + def inside_iter2(s, w_sum): a2 = lax.dynamic_index_in_dim(a, s, keepdims=False) b2 = lax.dynamic_index_in_dim(b, s, keepdims=False) @@ -247,4 +250,4 @@ def inside_iter2(s, w_sum): # inverse is computed as square of the inverse square root with jax.default_matmul_precision(_get_precision_string(precision)): return alpha * x, beta * val[1] @ val[1] - return 1/beta * val[0], beta * val[1] + return 1 / beta * val[0], beta * val[1] diff --git a/init2winit/optimizer_lib/linalg/pth_inv_root_rmn_coefficients.py b/init2winit/optimizer_lib/linalg/pth_inv_root_rmn_coefficients.py index cf3b5f3c..944debbd 100644 --- a/init2winit/optimizer_lib/linalg/pth_inv_root_rmn_coefficients.py +++ b/init2winit/optimizer_lib/linalg/pth_inv_root_rmn_coefficients.py @@ -15,11 +15,11 @@ """Coefficients for pth inverse root iteration. - The faster schedule allows for a faster convergence of the coupled - iteration. The faster schedule is a R_{1,1} approximation to x^{1/p}. For - matrices with condition number < 1e+7, this converges in 3 Cholesky steps. - The slower schedule is a mixed R_{2,2} andR_{3,3} approximation to x^{1/p}. - It takes a total of 5 serial Cholesky steps or 2 parallel Cholesky steps. +The faster schedule allows for a faster convergence of the coupled +iteration. The faster schedule is a R_{1,1} approximation to x^{1/p}. For +matrices with condition number < 1e+7, this converges in 3 Cholesky steps. +The slower schedule is a mixed R_{2,2} andR_{3,3} approximation to x^{1/p}. +It takes a total of 5 serial Cholesky steps or 2 parallel Cholesky steps. """ from jax import numpy as jnp @@ -176,7 +176,7 @@ def r11_schedule(p: int) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: ]), "alpha": jnp.array([ 0.01760681686165901, - ]) + ]), }, "6": { "a": jnp.array([ @@ -208,7 +208,7 @@ def r11_schedule(p: int) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: ]), "alpha": jnp.array([ 0.06767899452107008, - ]) + ]), }, "8": { "a": jnp.array([ @@ -240,7 +240,7 @@ def r11_schedule(p: int) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: ]), "alpha": jnp.array([ 0.13269068114098673, - ]) + ]), }, } @@ -409,7 +409,7 @@ def r12_schedule(p: int) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: ]), "alpha": jnp.array([ 0.01760681686165901, - ]) + ]), }, "6": { "a": jnp.array([ @@ -442,7 +442,7 @@ def r12_schedule(p: int) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: ]), "alpha": jnp.array([ 0.06767899452107008, - ]) + ]), }, "8": { "a": jnp.array([ @@ -475,7 +475,7 @@ def r12_schedule(p: int) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: ]), "alpha": jnp.array([ 0.13269068114098673, - ]) + ]), }, } @@ -546,7 +546,7 @@ def r23_schedule(p: int) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: ]), "alpha": jnp.array([ 0.00031, - ]) + ]), }, "2": { "a": jnp.array([ @@ -579,7 +579,7 @@ def r23_schedule(p: int) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: ]), "alpha": jnp.array([ 0.00031, - ]) + ]), }, "3": { "a": jnp.array([ @@ -612,7 +612,7 @@ def r23_schedule(p: int) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: ]), "alpha": jnp.array([ 0.00458, - ]) + ]), }, "4": { "a": jnp.array([ @@ -645,7 +645,7 @@ def r23_schedule(p: int) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: ]), "alpha": jnp.array([ 0.0178, - ]) + ]), }, "6": { "a": jnp.array([ @@ -678,7 +678,7 @@ def r23_schedule(p: int) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: ]), "alpha": jnp.array([ 0.06817314583299908, - ]) + ]), }, "8": { "a": jnp.array([ @@ -711,7 +711,7 @@ def r23_schedule(p: int) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: ]), "alpha": jnp.array([ 0.13341664064126335, - ]) + ]), }, } @@ -723,4 +723,3 @@ def r23_schedule(p: int) -> tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: pth_inv_root_coeffs[indexstr]["b"], pth_inv_root_coeffs[indexstr]["c"], ) - diff --git a/init2winit/optimizer_lib/linalg/pth_inv_root_rmn_test.py b/init2winit/optimizer_lib/linalg/pth_inv_root_rmn_test.py index 16978e28..b5261150 100644 --- a/init2winit/optimizer_lib/linalg/pth_inv_root_rmn_test.py +++ b/init2winit/optimizer_lib/linalg/pth_inv_root_rmn_test.py @@ -27,10 +27,11 @@ import scipy.stats -def _random_singular_values(n: int, gamma: float, - rng: np.random.RandomState) -> np.ndarray: +def _random_singular_values( + n: int, gamma: float, rng: np.random.RandomState +) -> np.ndarray: """Returns n random singular values in [γ, 1].""" - s = gamma**rng.random((n,)) # log of singular values uniformly distributed + s = gamma ** rng.random((n,)) # log of singular values uniformly distributed if n > 0: s[0] = gamma if n > 1: @@ -39,8 +40,8 @@ def _random_singular_values(n: int, gamma: float, def _random_svd( - n: int, gamma: float, - rng: np.random.RandomState) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + n: int, gamma: float, rng: np.random.RandomState +) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: """Returns a random SVD decomposition with singular values in [γ, 1].""" # sample a uniformly random orthogonal matrix. u = scipy.stats.ortho_group.rvs(n, random_state=rng) @@ -56,11 +57,15 @@ def _root(x, p): class PthInvRootTest(parameterized.TestCase): - @parameterized.named_parameters({ # pylint:disable=g-complex-comprehension - 'testcase_name': f'n={n}_p={p}', - 'n': n, # pylint: disable=undefined-variable - 'p': p, # pylint: disable=undefined-variable - } for n in [2, 31] for p in [1, 2, 3, 4, 6, 8]) + @parameterized.named_parameters( + { # pylint:disable=g-complex-comprehension + 'testcase_name': f'n={n}_p={p}', + 'n': n, # pylint: disable=undefined-variable + 'p': p, # pylint: disable=undefined-variable + } + for n in [2, 31] + for p in [1, 2, 3, 4, 6, 8] + ) def test_zero_matrix(self, n, p): a = jnp.zeros((n, n), dtype=jnp.float32) x = _root(a, p) @@ -81,12 +86,12 @@ def test_random_matrix(self, n, p, c): rng = np.random.RandomState(seed=42) for k in range(6): - sigma = 10**(-k - 1) # smallest singular value of test matrix + sigma = 10 ** (-k - 1) # smallest singular value of test matrix _, s, v = _random_svd(n, sigma, rng) - s = s.astype(np.float64) * (1e6 ** c) # c tests different matrix scalings + s = s.astype(np.float64) * (1e6**c) # c tests different matrix scalings v = v.astype(np.float64) a = jnp.array((v * s) @ v.T, jnp.float32) - exact = (v * s**(-1 / p)) @ v.T + exact = (v * s ** (-1 / p)) @ v.T x = _root(a, p) x = np.array(x).astype(np.float64) error = np.linalg.norm(x - exact, 2) / np.linalg.norm(exact, 2) @@ -108,12 +113,12 @@ def test_singular_matrix(self, n, p): rng = np.random.RandomState(seed=42) for k in range(6): - sigma = 10**(-k - 1) # smallest singular value of test matrix + sigma = 10 ** (-k - 1) # smallest singular value of test matrix _, s, v = _random_svd(n, sigma, rng) s = s.astype(np.float64) v = v.astype(np.float64) a = jnp.array((v * s) @ v.T, jnp.float32) - exact = (v * s**(-1 / p)) @ v.T + exact = (v * s ** (-1 / p)) @ v.T x = _root(a, p) x = np.array(x).astype(np.float64) error = np.linalg.norm(x - exact, 2) / np.linalg.norm(exact, 2) @@ -122,16 +127,19 @@ def test_singular_matrix(self, n, p): expected_error = 6 * kappa * np.finfo(np.float32).eps self.assertLessEqual(error, expected_error) - @parameterized.named_parameters({ # pylint:disable=g-complex-comprehension - 'testcase_name': '_p={}'.format(p), - 'p': p # pylint: disable=undefined-variable - } for p in [1, 2, 3, 4, 6, 8]) + @parameterized.named_parameters( + { # pylint:disable=g-complex-comprehension + 'testcase_name': '_p={}'.format(p), + 'p': p, # pylint: disable=undefined-variable + } + for p in [1, 2, 3, 4, 6, 8] + ) def test_random_diagonal_matrix(self, p): n = 16 eps = np.finfo(np.float32).eps rng = np.random.RandomState(seed=37) s = _random_singular_values(n, eps, rng) - exact = s.astype(np.float64)**(-1 / p) + exact = s.astype(np.float64) ** (-1 / p) x = _root(np.diag(s).astype(np.float32), p).astype(np.float64) # since the matrix is diagonal, the error should be small despite the # large condition number diff --git a/init2winit/optimizer_lib/linalg/root_selector.py b/init2winit/optimizer_lib/linalg/root_selector.py index 2a4dd170..025b422e 100644 --- a/init2winit/optimizer_lib/linalg/root_selector.py +++ b/init2winit/optimizer_lib/linalg/root_selector.py @@ -72,11 +72,13 @@ def root_selector( def _exact_root(): return f_er(x, p) - rank_array = np.zeros(np.where( - x.shape[-1] < 64, - x.shape[-1], - np.minimum(x.shape[-1] // 8, rank_estimate), - )) + rank_array = np.zeros( + np.where( + x.shape[-1] < 64, + x.shape[-1], + np.minimum(x.shape[-1] // 8, rank_estimate), + ) + ) f_ar = functools.partial( low_rank_root_update.low_rank_root_update, rank_array=rank_array, @@ -84,6 +86,7 @@ def _exact_root(): block_krylov_dim_multiplier=block_krylov_dim_multiplier, verbose=verbose, ) + def _approx_root(): return f_ar(sx, isx, up) diff --git a/init2winit/optimizer_lib/online_newton_step.py b/init2winit/optimizer_lib/online_newton_step.py index 2bd495a9..453d8b55 100644 --- a/init2winit/optimizer_lib/online_newton_step.py +++ b/init2winit/optimizer_lib/online_newton_step.py @@ -24,11 +24,13 @@ import optax -def diag_ons(learning_rate, - weight_decay: float = 0.0, - b1: float = 0.9, - b2: float = 0.999, - eps: float = 1e-8): +def diag_ons( + learning_rate, + weight_decay: float = 0.0, + b1: float = 0.9, + b2: float = 0.999, + eps: float = 1e-8, +): """The diagonal version of Online Newton Step with flexible updates. Args: @@ -46,26 +48,36 @@ def diag_ons(learning_rate, # Diag ONS without momentum and second moment decay return optax.chain( kitchen_sink.precondition_by_rss(eps=eps, power=1.0), - optax.add_decayed_weights(weight_decay), optax.scale(learning_rate)) + optax.add_decayed_weights(weight_decay), + optax.scale(learning_rate), + ) elif b1 == 1.0 and b2 != 1.0: # Diag ONS without momentum but with second moment decay return optax.chain( kitchen_sink.precondition_by_rms( - decay=b2, eps=eps, eps_root=0.0, power=1.0), - optax.add_decayed_weights(weight_decay), optax.scale(learning_rate)) + decay=b2, eps=eps, eps_root=0.0, power=1.0 + ), + optax.add_decayed_weights(weight_decay), + optax.scale(learning_rate), + ) elif b1 != 1.0 and b2 != 1.0: # Diag ONS with momentum and second moment decay return optax.chain( kitchen_sink.scale_by_adam(b1, b2, eps, eps_root=0.0, power=1.0), - optax.add_decayed_weights(weight_decay), optax.scale(learning_rate)) + optax.add_decayed_weights(weight_decay), + optax.scale(learning_rate), + ) -def last_layer_transformation(last_layer_optimizer, base_lr, - last_layer_base_lr, learning_rate): +def last_layer_transformation( + last_layer_optimizer, base_lr, last_layer_base_lr, learning_rate +): """Use an optimizer while scaling by a different learning rate.""" - return optax.chain(last_layer_optimizer, - optax.scale(learning_rate * last_layer_base_lr / base_lr)) + return optax.chain( + last_layer_optimizer, + optax.scale(learning_rate * last_layer_base_lr / base_lr), + ) def sherman_morrison(a_inv, u, alpha): @@ -78,6 +90,7 @@ def sherman_morrison(a_inv, u, alpha): class OnlineNewtonState(NamedTuple): """State holding the sum of gradient squares to date.""" + inv_hessian: optax.Updates @@ -87,7 +100,8 @@ def full_matrix_ons(alpha, initial_accumulator_value=0.1): def init_fn(params): raveled_params, _ = jax.flatten_util.ravel_pytree(params) initial_hessian = jnp.diag( - jnp.full_like(raveled_params, 1. / initial_accumulator_value)) + jnp.full_like(raveled_params, 1.0 / initial_accumulator_value) + ) return OnlineNewtonState(inv_hessian=initial_hessian) @@ -107,33 +121,42 @@ def online_newton_step(learning_rate, alpha, weight_decay): r"""An optimizer that does full matrix preconditioning.""" return optax.chain( - optax.add_decayed_weights(weight_decay), full_matrix_ons(alpha), - optax.sgd(learning_rate)) - - -def multiple_optimizer(last_layer_name, network_optimizer, last_layer_optimizer, - last_layer_base_lr, base_lr): + optax.add_decayed_weights(weight_decay), + full_matrix_ons(alpha), + optax.sgd(learning_rate), + ) + + +def multiple_optimizer( + last_layer_name, + network_optimizer, + last_layer_optimizer, + last_layer_base_lr, + base_lr, +): """Use a different optimizer for the last layer.""" def get_select_fn(layer_name): """Get a function that selects the specified layer as last layer.""" def select_layer(tree): - return {k: ('ll' if k == layer_name else 'net') for k, v in tree.items()} + return {k: 'll' if k == layer_name else 'net' for k, v in tree.items()} return select_layer - return kitchen_sink.unfreeze_wrapper(*optax.multi_transform( - { - 'net': - network_optimizer, - # Scale the learning rate of the last layer according to match - # last_layer_base_lr - 'll': - utils.static_inject_hyperparams(last_layer_transformation) - (last_layer_optimizer, - base_lr, - last_layer_base_lr, - learning_rate=0.0), - }, - get_select_fn(last_layer_name))) + return kitchen_sink.unfreeze_wrapper( + *optax.multi_transform( + { + 'net': network_optimizer, + # Scale the learning rate of the last layer according to match + # last_layer_base_lr + 'll': utils.static_inject_hyperparams(last_layer_transformation)( + last_layer_optimizer, + base_lr, + last_layer_base_lr, + learning_rate=0.0, + ), + }, + get_select_fn(last_layer_name), + ) + ) diff --git a/init2winit/optimizer_lib/optimizers.py b/init2winit/optimizer_lib/optimizers.py index 9db79ccc..13dfdb54 100644 --- a/init2winit/optimizer_lib/optimizers.py +++ b/init2winit/optimizer_lib/optimizers.py @@ -19,6 +19,7 @@ from absl import logging import flax + from init2winit.model_lib.model_utils import ParameterType # pylint: disable=g-importing-member from init2winit.optimizer_lib import gradient_accumulator from init2winit.optimizer_lib import kitchen_sink @@ -30,8 +31,8 @@ from init2winit.optimizer_lib import utils import jax import jax.numpy as jnp -import optax +import optax @@ -116,7 +117,9 @@ def sgd(learning_rate, weight_decay, momentum=None, nesterov=False): return optax.chain( optax.add_decayed_weights(weight_decay), optax.sgd( - learning_rate=learning_rate, momentum=momentum, nesterov=nesterov)) + learning_rate=learning_rate, momentum=momentum, nesterov=nesterov + ), + ) def get_optimizer(hps, model=None, batch_axis_name=None): @@ -226,8 +229,8 @@ def get_optimizer(hps, model=None, batch_axis_name=None): if ParameterType.DEFAULT.value not in param_type_to_optimizer_and_hparams: raise ValueError( - f'Fallback default optimizer not found in param_type_to_grad_tx.' - f' Please add a fallback optimizer to param_type_to_grad_tx =' + 'Fallback default optimizer not found in param_type_to_grad_tx.' + ' Please add a fallback optimizer to param_type_to_grad_tx =' f' {param_type_to_optimizer_and_hparams}' ) param_type_to_grad_tx = {} @@ -239,8 +242,8 @@ def get_optimizer(hps, model=None, batch_axis_name=None): opt_hparams.opt_hparams.update(hparams_to_merge) hps_copy.update(opt_hparams) logging.info('HPS_COPY %s', hps_copy) - param_type_to_grad_tx[param_type] = ( - optax.GradientTransformation(*get_optimizer(hps_copy, model)) + param_type_to_grad_tx[param_type] = optax.GradientTransformation( + *get_optimizer(hps_copy, model) ) param_to_type = model.params_types @@ -269,35 +272,43 @@ def get_optimizer(hps, model=None, batch_axis_name=None): hps_last_layer['l2_decay_factor'] = None network_optimizer = optax.GradientTransformation( - *get_optimizer(hps_network)) + *get_optimizer(hps_network) + ) last_layer_optimizer = optax.GradientTransformation( - *get_optimizer(hps_last_layer)) + *get_optimizer(hps_last_layer) + ) opt_init, opt_update = online_newton_step.multiple_optimizer( last_layer_name=hps.opt_hparams['last_layer_name'], network_optimizer=network_optimizer, last_layer_optimizer=last_layer_optimizer, last_layer_base_lr=hps.opt_hparams['last_layer_base_lr'], - base_lr=hps.lr_hparams['base_lr']) + base_lr=hps.lr_hparams['base_lr'], + ) elif hps.optimizer == 'online_newton_step': opt_init, opt_update = utils.static_inject_hyperparams( - online_newton_step.online_newton_step)( - learning_rate=0.0, # Manually injected on each train step. - alpha=hps.opt_hparams['alpha'], - weight_decay=weight_decay) + online_newton_step.online_newton_step + )( + learning_rate=0.0, # Manually injected on each train step. + alpha=hps.opt_hparams['alpha'], + weight_decay=weight_decay, + ) elif hps.optimizer == 'diag_ons': opt_init, opt_update = utils.static_inject_hyperparams( - online_newton_step.diag_ons)( - learning_rate=1.0, # Set to 1.0 to use as a last layer optimizer. - weight_decay=weight_decay, - b1=hps.opt_hparams['beta1'], - b2=hps.opt_hparams['beta2']) + online_newton_step.diag_ons + )( + learning_rate=1.0, # Set to 1.0 to use as a last layer optimizer. + weight_decay=weight_decay, + b1=hps.opt_hparams['beta1'], + b2=hps.opt_hparams['beta2'], + ) elif hps.optimizer == 'momentum' or hps.optimizer == 'nesterov': opt_init, opt_update = utils.static_inject_hyperparams(sgd)( learning_rate=0.0, # Manually injected on each train step. weight_decay=weight_decay, momentum=hps.opt_hparams['momentum'], - nesterov=(hps.optimizer == 'nesterov')) + nesterov=(hps.optimizer == 'nesterov'), + ) elif hps.optimizer == 'tearfree': sketch_size = hps.opt_hparams.get('sketchy_rank') if sketch_size is not None and sketch_size > 0: @@ -383,15 +394,17 @@ def get_optimizer(hps, model=None, batch_axis_name=None): b1=hps.opt_hparams['beta1'], b2=hps.opt_hparams['beta2'], eps=hps.opt_hparams['epsilon'], - weight_decay=weight_decay) + weight_decay=weight_decay, + ) elif hps.optimizer == 'adafactor': opt_init, opt_update = utils.static_inject_hyperparams(optax.adafactor)( learning_rate=0.0, min_dim_size_to_factor=hps.opt_hparams['min_dim_size_to_factor'], decay_rate=hps.opt_hparams['adafactor_decay_rate'], decay_offset=hps.opt_hparams['decay_offset'], - multiply_by_parameter_scale=hps - .opt_hparams['multiply_by_parameter_scale'], + multiply_by_parameter_scale=hps.opt_hparams[ + 'multiply_by_parameter_scale' + ], clipping_threshold=hps.opt_hparams['clipping_threshold'], momentum=hps.opt_hparams['momentum'], weight_decay_rate=weight_decay, @@ -440,18 +453,14 @@ def get_optimizer(hps, model=None, batch_axis_name=None): safeguard_warmup=hps.opt_hparams['safeguard_warmup'], ) elif hps.optimizer == 'cocob': - opt_init, opt_update = utils.static_inject_hyperparams( - optax.contrib.cocob - )( + opt_init, opt_update = utils.static_inject_hyperparams(optax.contrib.cocob)( learning_rate=0.0, weight_decay=weight_decay, alpha=hps.opt_hparams['alpha'], eps=hps.opt_hparams['eps'], ) elif hps.optimizer == 'momo': - opt_init, opt_update = utils.static_inject_hyperparams( - optax.contrib.momo - )( + opt_init, opt_update = utils.static_inject_hyperparams(optax.contrib.momo)( learning_rate=0.0, beta=hps.opt_hparams['beta'], lower_bound=hps.opt_hparams['lower_bound'], @@ -473,18 +482,14 @@ def get_optimizer(hps, model=None, batch_axis_name=None): ) optimizer_requires_value = True elif hps.optimizer == 'dog': - opt_init, opt_update = utils.static_inject_hyperparams( - optax.contrib.dog - )( + opt_init, opt_update = utils.static_inject_hyperparams(optax.contrib.dog)( learning_rate=0.0, reps_rel=hps.opt_hparams['reps_rel'], eps=hps.opt_hparams['eps'], weight_decay=hps.opt_hparams['weight_decay'], ) elif hps.optimizer == 'dowg': - opt_init, opt_update = utils.static_inject_hyperparams( - optax.contrib.dowg - )( + opt_init, opt_update = utils.static_inject_hyperparams(optax.contrib.dowg)( learning_rate=0.0, eps=hps.opt_hparams['eps'], weight_decay=hps.opt_hparams['weight_decay'], @@ -492,8 +497,8 @@ def get_optimizer(hps, model=None, batch_axis_name=None): elif hps.optimizer == 'kitchen_sink': opt_init, opt_update = utils.static_inject_hyperparams( - kitchen_sink.kitchen_sink)( - learning_rate=0.0, config=hps.opt_hparams) + kitchen_sink.kitchen_sink + )(learning_rate=0.0, config=hps.opt_hparams) elif hps.optimizer == 'samuel': opt_init, opt_update = samuel.from_hparams(hps.opt_hparams) @@ -511,7 +516,8 @@ def get_optimizer(hps, model=None, batch_axis_name=None): total_batch_size=hps.total_accumulated_batch_size, virtual_batch_size=virtual_batch_size, base_opt_init_fn=opt_init, - base_opt_update_fn=opt_update) + base_opt_update_fn=opt_update, + ) if hps.opt_hparams.get('use_sam', False): opt_init, opt_update = ( @@ -586,14 +592,16 @@ def _wrap_update_fn( """ del opt_name - def update_fn(grads, - optimizer_state, - params, - batch=None, - batch_stats=None, - cost_fn=None, - grad_fn=None, - value=None): + def update_fn( + grads, + optimizer_state, + params, + batch=None, + batch_stats=None, + cost_fn=None, + grad_fn=None, + value=None, + ): del batch, batch_stats if send_grad_fn and send_cost_fn: # Note that `value_and_grad` already returns the cost, so there is no need @@ -601,10 +609,12 @@ def update_fn(grads, raise ValueError('send_grad_fn and send_cost_fn must not both be True.') if send_grad_fn: return opt_update( - grads, optimizer_state, grad_fn_params_tuple=(grad_fn, params)) + grads, optimizer_state, grad_fn_params_tuple=(grad_fn, params) + ) elif send_cost_fn: return opt_update( - grads, optimizer_state, cost_fn_params_tuple=(cost_fn, params)) + grads, optimizer_state, cost_fn_params_tuple=(cost_fn, params) + ) elif send_value: return opt_update( grads, diff --git a/init2winit/optimizer_lib/pax_adafactor.py b/init2winit/optimizer_lib/pax_adafactor.py index 53671a87..7e474f20 100644 --- a/init2winit/optimizer_lib/pax_adafactor.py +++ b/init2winit/optimizer_lib/pax_adafactor.py @@ -23,7 +23,6 @@ """ import dataclasses - import functools import re from typing import Any, NamedTuple, Optional, Tuple, Union @@ -32,14 +31,14 @@ from jax import numpy as jnp import optax - JTensor = Any NestedJTensor = Any NestedHParams = Any -def to_quantized(fvalue: JTensor, - quantized_dtype: jnp.dtype) -> Tuple[JTensor, JTensor]: +def to_quantized( + fvalue: JTensor, quantized_dtype: jnp.dtype +) -> Tuple[JTensor, JTensor]: """Converts floating point values `fvalues` to quantized values. We use a very simple quantization scheme where the range is symmetric around @@ -83,14 +82,16 @@ def to_quantized(fvalue: JTensor, if fvalue.ndim < 1: raise ValueError( f'Input array {fvalue} must have a strictly positive number of ' - 'dimensions.') + 'dimensions.' + ) max_abs = jnp.max(jnp.abs(fvalue), axis=0) bucket_size = max_abs / num_buckets bs_expanded = bucket_size[jnp.newaxis, ...] # To avoid divide by 0.0 - bs_nonzero = jnp.where(bs_expanded > 0.0, bs_expanded, - jnp.ones_like(bs_expanded)) + bs_nonzero = jnp.where( + bs_expanded > 0.0, bs_expanded, jnp.ones_like(bs_expanded) + ) ratio = fvalue / bs_nonzero # We use rounding to remove bias. quantized = jnp.round(ratio) @@ -127,8 +128,8 @@ def adafactor_decay_rate_adam(beta2: float, step_counter: JTensor) -> JTensor: """ step = step_counter beta2 = jnp.array(beta2, dtype=jnp.float32) - t = step + 1. - return beta2 * (1. - jnp.power(beta2, t - 1.)) / (1. - jnp.power(beta2, t)) + t = step + 1.0 + return beta2 * (1.0 - jnp.power(beta2, t - 1.0)) / (1.0 - jnp.power(beta2, t)) def adafactor_decay_rate_pow(exponent: float, step_counter: JTensor) -> JTensor: @@ -144,7 +145,7 @@ def adafactor_decay_rate_pow(exponent: float, step_counter: JTensor) -> JTensor: """ step = step_counter exponent = jnp.array(exponent, dtype=jnp.float32) - return 1. - jnp.power((step + 1.), -exponent) + return 1.0 - jnp.power((step + 1.0), -exponent) def reduce_mean(array: JTensor) -> JTensor: @@ -186,6 +187,7 @@ def reduce_rms(array: JTensor) -> JTensor: @dataclasses.dataclass(frozen=True) class _ShardedAdafactorUpdateResult: """Structure containing per-variable info for Adafactor.""" + update: Optional[Any] m: Optional[Any] m_scale: Optional[Any] @@ -196,6 +198,7 @@ class _ShardedAdafactorUpdateResult: class ShardedAdafactorState(NamedTuple): """Overall state of the ShardedAdafactor optimizer.""" + count: JTensor m: Optional[NestedJTensor] m_scale: Optional[NestedJTensor] @@ -228,7 +231,8 @@ def __init__( multiply_by_parameter_scale: bool, epsilon2_param_scale_reg: float, maybe_inf_to_nan: bool, - nesterov: bool) -> None: + nesterov: bool, + ) -> None: """Constructor. See ShardedAdafactor() below.""" self._learning_rate = learning_rate @@ -320,7 +324,8 @@ def to_state(self, count, result_tree): m_scale=jax.tree.map(lambda o: o.m_scale, result_tree), vr=jax.tree.map(lambda o: o.vr, result_tree), vc=jax.tree.map(lambda o: o.vc, result_tree), - v=jax.tree.map(lambda o: o.v, result_tree)) + v=jax.tree.map(lambda o: o.v, result_tree), + ) def init(self, param): """Initializes the optimizer state for a given param.""" @@ -358,7 +363,8 @@ def init(self, param): m_scale=output_m_scale, vr=output_vr, vc=output_vc, - v=output_v) + v=output_v, + ) def inf_to_nan(self, array): """Converting Infinity values to the more sticky NaN.""" @@ -386,8 +392,9 @@ def parameter_scale(self, var): """ return jnp.maximum(reduce_rms(var), jnp.asarray(self._epsilon2, var.dtype)) - def compute_var_and_slot_update(self, count, grad, m, m_scale, vr, vc, v, - param, var_name=None): + def compute_var_and_slot_update( + self, count, grad, m, m_scale, vr, vc, v, param, var_name=None + ): """Computes the var and optimizer slots updates for a single variable.""" # We can probably skip this step grad = grad.astype(jnp.float32) @@ -424,7 +431,7 @@ def compute_var_and_slot_update(self, count, grad, m, m_scale, vr, vc, v, update_scale += grad_squared_mean * 1e-30 # END HACK - mixing_rate = 1. - decay_rate + mixing_rate = 1.0 - decay_rate shape = param.shape output_m = jnp.zeros((1,)) @@ -439,18 +446,23 @@ def compute_var_and_slot_update(self, count, grad, m, m_scale, vr, vc, v, # reduce_mean(). vr_axis, vc_axis = factored_second_moment_dims grad_squared_row_mean = self.inf_to_nan( - jnp.mean(grad_squared, axis=vr_axis)) + jnp.mean(grad_squared, axis=vr_axis) + ) grad_squared_col_mean = self.inf_to_nan( - jnp.mean(grad_squared, axis=vc_axis)) + jnp.mean(grad_squared, axis=vc_axis) + ) new_vr = decay_rate * vr + mixing_rate * grad_squared_row_mean new_vc = decay_rate * vc + mixing_rate * grad_squared_col_mean output_vr = new_vr output_vc = new_vc long_term_mean = jnp.mean(new_vr, axis=-1, keepdims=True) - r_factor = 1. / jnp.sqrt(new_vr / long_term_mean) - c_factor = 1. / jnp.sqrt(new_vc) - x = grad * jnp.expand_dims(r_factor, vr_axis) * jnp.expand_dims( - c_factor, vc_axis) + r_factor = 1.0 / jnp.sqrt(new_vr / long_term_mean) + c_factor = 1.0 / jnp.sqrt(new_vc) + x = ( + grad + * jnp.expand_dims(r_factor, vr_axis) + * jnp.expand_dims(c_factor, vc_axis) + ) else: # v with sharding annotation. new_v = decay_rate * v + mixing_rate * grad_squared @@ -458,7 +470,7 @@ def compute_var_and_slot_update(self, count, grad, m, m_scale, vr, vc, v, x = grad / jnp.sqrt(new_v) if self._clip_threshold is not None: - clipping_denom = jnp.maximum(1., reduce_rms(x) / self._clip_threshold) + clipping_denom = jnp.maximum(1.0, reduce_rms(x) / self._clip_threshold) clipping_denom = self.inf_to_nan(clipping_denom) x /= clipping_denom @@ -471,7 +483,7 @@ def compute_var_and_slot_update(self, count, grad, m, m_scale, vr, vc, v, m = to_float(m, m_scale) if self._nesterov: subtrahend_original = subtrahend - subtrahend = self._beta1 * m + (1. - self._beta1) * subtrahend + subtrahend = self._beta1 * m + (1.0 - self._beta1) * subtrahend subtrahend = self.inf_to_nan(subtrahend) if self._quantized_dtype == jnp.bfloat16: new_m = subtrahend.astype(jnp.bfloat16) @@ -518,7 +530,9 @@ def compute_var_and_slot_update(self, count, grad, m, m_scale, vr, vc, v, ratio = w_norm / g_norm ratio = jnp.where( jnp.greater(w_norm, 0), - jnp.where(jnp.greater(g_norm, 0), (w_norm / g_norm), 1.0), 1.0) + jnp.where(jnp.greater(g_norm, 0), (w_norm / g_norm), 1.0), + 1.0, + ) subtrahend *= ratio return _ShardedAdafactorUpdateResult( @@ -527,7 +541,8 @@ def compute_var_and_slot_update(self, count, grad, m, m_scale, vr, vc, v, m_scale=output_m_scale, vr=output_vr, vc=output_vc, - v=output_v) + v=output_v, + ) def sharded_adafactor( @@ -535,10 +550,10 @@ def sharded_adafactor( weight_decay: Optional[Union[float, dict[str, float]]] = None, layerwise_adaptation: bool = False, decay_method: str = '', - decay_adam: float = 0., - decay_pow: float = 0., - beta1: float = 0., - clip_threshold: Optional[float] = 1., + decay_adam: float = 0.0, + decay_pow: float = 0.0, + beta1: float = 0.0, + clip_threshold: Optional[float] = 1.0, factored: bool = True, epsilon1_grad_sq_reg: float = 1e-30, quantized_dtype: jnp.dtype = jnp.int8, @@ -628,7 +643,8 @@ def sharded_adafactor( assert learning_rate is not None assert decay_method == 'adam' or decay_method == 'pow', ( f'decay_method: {decay_method} not supported. Supported methods are ' - '"pow", or "adam".') + '"pow", or "adam".' + ) sharded_adafactor_helper = _ShardedAdafactorHelper( learning_rate=learning_rate, @@ -650,24 +666,36 @@ def sharded_adafactor( multiply_by_parameter_scale=multiply_by_parameter_scale, epsilon2_param_scale_reg=epsilon2_param_scale_reg, maybe_inf_to_nan=maybe_inf_to_nan, - nesterov=nesterov) + nesterov=nesterov, + ) def init_fn(params): """Initializes the optimizer's state.""" return sharded_adafactor_helper.to_state( jnp.zeros([], jnp.int32), - jax.tree.map(sharded_adafactor_helper.init, params)) + jax.tree.map(sharded_adafactor_helper.init, params), + ) def update_fn(updates, state, params=None): if params is None: raise ValueError( 'You are using a transformation that requires the current value of ' - 'parameters, but you are not passing `params` when calling `update`.') + 'parameters, but you are not passing `params` when calling `update`.' + ) compute_var_and_slot_update_fn = functools.partial( - sharded_adafactor_helper.compute_var_and_slot_update, state.count) - output = jax.tree.map(compute_var_and_slot_update_fn, updates, state.m, - state.m_scale, state.vr, state.vc, state.v, params) + sharded_adafactor_helper.compute_var_and_slot_update, state.count + ) + output = jax.tree.map( + compute_var_and_slot_update_fn, + updates, + state.m, + state.m_scale, + state.vr, + state.vc, + state.v, + params, + ) updates = jax.tree.map(lambda o: o.update, output) count_plus_one = state.count + jnp.array(1, jnp.int32) updated_states = sharded_adafactor_helper.to_state(count_plus_one, output) diff --git a/init2winit/optimizer_lib/samuel.py b/init2winit/optimizer_lib/samuel.py index 5b6a0f1f..eea30967 100644 --- a/init2winit/optimizer_lib/samuel.py +++ b/init2winit/optimizer_lib/samuel.py @@ -77,7 +77,8 @@ def samuel( if num_experts != jax.process_count(): raise ValueError( 'This implementation of SAMUEL requires the number of optimizers to be ' - 'equal to the number of hosts (one host per expert).') + 'equal to the number of hosts (one host per expert).' + ) optimizer = optimizers[jax.process_index()] hps = hps[jax.process_index()] @@ -111,7 +112,8 @@ def update_fn(updates, state, params): # jax.experimental.multihost_utils. # NOTE(dsuo): train_losses is of shape (jax.process_count(),). train_losses = jax.lax.all_gather(train_loss, 'batch').reshape( - jax.process_count(), jax.local_device_count())[:, 0] + jax.process_count(), jax.local_device_count() + )[:, 0] # Compute loss regret and update expert weights. loss_regret = train_losses.at[current_expert].get() - train_losses @@ -150,4 +152,5 @@ def from_hparams(opt_hparams): index += 1 return static_inject_hyperparams(samuel)( - optimizers=optimizers, hps=hps, **opt_hparams['args']) + optimizers=optimizers, hps=hps, **opt_hparams['args'] + ) diff --git a/init2winit/optimizer_lib/search_subspace.py b/init2winit/optimizer_lib/search_subspace.py index 58bc33a1..08d25022 100644 --- a/init2winit/optimizer_lib/search_subspace.py +++ b/init2winit/optimizer_lib/search_subspace.py @@ -17,6 +17,7 @@ TODO(dsuo): suport discrete hparams. """ + import copy import itertools import json @@ -26,6 +27,7 @@ import pandas as pd + def print_search_space(search_space): """Prints search space.""" for key, hp in search_space.items(): @@ -34,22 +36,41 @@ def print_search_space(search_space): print(f'\t- scale_type: {hp["scale_type"]}') -def get_top_k_random_sweep(trials, objective, min_objective, num_top_trials, - num_random_seeds, top_mode='final', **kwargs): +def get_top_k_random_sweep( + trials, + objective, + min_objective, + num_top_trials, + num_random_seeds, + top_mode='final', + **kwargs, +): """Generate num_random trials for top k experiments.""" del kwargs # Get the top k trials as ordered by objective if top_mode == 'final': - top_k_obj = trials[objective].apply(lambda x: x[-1]).sort_values( - ascending=min_objective).head(n=num_top_trials) + top_k_obj = ( + trials[objective] + .apply(lambda x: x[-1]) + .sort_values(ascending=min_objective) + .head(n=num_top_trials) + ) elif top_mode == 'best': if min_objective: - top_k_obj = trials[objective].apply(lambda x: x.min()).sort_values( - ascending=min_objective).head(n=num_top_trials) + top_k_obj = ( + trials[objective] + .apply(lambda x: x.min()) + .sort_values(ascending=min_objective) + .head(n=num_top_trials) + ) else: - top_k_obj = trials[objective].apply(lambda x: x.max()).sort_values( - ascending=min_objective).head(n=num_top_trials) + top_k_obj = ( + trials[objective] + .apply(lambda x: x.max()) + .sort_values(ascending=min_objective) + .head(n=num_top_trials) + ) top_k = trials.loc[top_k_obj.index] print(f'Generating top k trials using objective `{objective}`.') print(top_k_obj) @@ -61,20 +82,26 @@ def get_top_k_random_sweep(trials, objective, min_objective, num_top_trials, return [[hparam] * num_random_seeds for hparam in hparams], list(top_k_obj) -def find_best_cube(trials, - objective, - search_space, - k, - cube_sizes=None, - cube_strides=None, - min_objective=True, - **kwargs): +def find_best_cube( + trials, + objective, + search_space, + k, + cube_sizes=None, + cube_strides=None, + min_objective=True, + **kwargs, +): """Find best cube in original search space.""" del kwargs # Get the top k trials as ordered by objective - top_k_obj = trials[objective].apply(lambda x: x[-1]).sort_values( - ascending=min_objective).head(n=k) + top_k_obj = ( + trials[objective] + .apply(lambda x: x[-1]) + .sort_values(ascending=min_objective) + .head(n=k) + ) top_k_idx = top_k_obj.index for key in search_space.keys(): @@ -83,8 +110,9 @@ def find_best_cube(trials, else: trials[key] = trials[f'hps.{key}'] - top_k_df = pd.concat((trials.loc[top_k_idx][search_space.keys()], top_k_obj), - axis=1) + top_k_df = pd.concat( + (trials.loc[top_k_idx][search_space.keys()], top_k_obj), axis=1 + ) # Compute starting points of hyperparam cubes. cube_start_points = {} @@ -97,7 +125,7 @@ def find_best_cube(trials, if hp['scale_type'] == 'UNIT_LOG_SCALE': hp['mapped_range'] = [ np.floor(np.floor(np.log10(float(min_trial_value)))), - np.ceil(np.ceil(np.log10(float(max_trial_value)))) + np.ceil(np.ceil(np.log10(float(max_trial_value)))), ] elif hp['scale_type'] == 'UNIT_LINEAR_SCALE': # This requires some work @@ -107,7 +135,8 @@ def find_best_cube(trials, cube_start_points[key] = np.arange( hp['mapped_range'][0], hp['mapped_range'][1] - cube_sizes[key] + cube_strides[key], - cube_strides[key]) + cube_strides[key], + ) if len(cube_start_points[key]) == 0: # pylint: disable=g-explicit-length-test cube_start_points[key] = np.array([hp['mapped_range'][0]]) cube_end_points[key] = cube_start_points[key] + cube_sizes[key] @@ -128,8 +157,9 @@ def find_best_cube(trials, best_cube_top_trial_included = False # Find trials from top k in each cube. - for cube_start_point, cube_end_point in zip(cube_start_points, - cube_end_points): + for cube_start_point, cube_end_point in zip( + cube_start_points, cube_end_points + ): points_df = top_k_df top_trial_included = True for i, (start, end) in enumerate(zip(cube_start_point, cube_end_point)): @@ -140,9 +170,11 @@ def find_best_cube(trials, points_df = points_df[selectors] print(cube_start_point, cube_end_point, points_df[objective].mean()) # Check if we have found a better cube. - change_flag = points_df[objective].mean( - ) < best_cube_mean if min_objective else points_df[objective].mean( - ) > best_cube_mean + change_flag = ( + points_df[objective].mean() < best_cube_mean + if min_objective + else points_df[objective].mean() > best_cube_mean + ) # Record best cube. if change_flag: @@ -154,8 +186,8 @@ def find_best_cube(trials, new_search_space = copy.deepcopy(search_space) for val, (key, hp) in zip(best_cube, new_search_space.items()): if hp['scale_type'] == 'UNIT_LOG_SCALE': - hp['min_value'] = np.power(10., np.log10(val)) - hp['max_value'] = np.power(10., np.log10(val) + cube_sizes[key]) + hp['min_value'] = np.power(10.0, np.log10(val)) + hp['max_value'] = np.power(10.0, np.log10(val) + cube_sizes[key]) elif hp['scale_type'] == 'UNIT_LINEAR_SCALE': hp['min_value'] = val hp['max_value'] = val + cube_sizes[key] @@ -163,8 +195,9 @@ def find_best_cube(trials, del hp['mapped_range'] num_trials = len(best_cube_trials) - logging.info('Total number of trials included in the reported cube is %d', - num_trials) + logging.info( + 'Total number of trials included in the reported cube is %d', num_trials + ) if not best_cube_top_trial_included: logging.info('Warning the best trial was not included in the cube') diff --git a/init2winit/optimizer_lib/sharpness_aware_minimization.py b/init2winit/optimizer_lib/sharpness_aware_minimization.py index 2d16924d..1c4e4ca8 100644 --- a/init2winit/optimizer_lib/sharpness_aware_minimization.py +++ b/init2winit/optimizer_lib/sharpness_aware_minimization.py @@ -25,7 +25,6 @@ from typing import Optional from init2winit.model_lib import model_utils - import jax import jax.numpy as jnp import optax @@ -44,7 +43,8 @@ def dual_vector(y: jnp.ndarray) -> jnp.ndarray: y: A pytree of numpy ndarray, vector y in the equation above. """ gradient_norm = jnp.sqrt( - sum([jnp.sum(jnp.square(e)) for e in jax.tree_util.tree_leaves(y)])) + sum([jnp.sum(jnp.square(e)) for e in jax.tree_util.tree_leaves(y)]) + ) normalized_gradient = jax.tree.map(lambda x: x / gradient_norm, y) return normalized_gradient @@ -79,7 +79,7 @@ def init_fn(params): return base_opt_init_fn(params) def update_fn(updates, state, grad_fn_params_tuple): - (grad_fn, params) = grad_fn_params_tuple + grad_fn, params = grad_fn_params_tuple # Updates here have been averaged across devices in Trainer before being # sent to the optimizer. We obtain gradients computed on the noised @@ -87,16 +87,22 @@ def update_fn(updates, state, grad_fn_params_tuple): # gradients and with the same 1e-6 epsilon that is used when clipping the # gradients. updates = dual_vector(updates) - noised_params = jax.tree_util.tree_map(lambda p, u: p + rho * u, params, - updates) + noised_params = jax.tree_util.tree_map( + lambda p, u: p + rho * u, params, updates + ) _, updates = grad_fn(noised_params) updates_norm = jnp.sqrt(model_utils.l2_regularization(updates, 0)) if grad_clip: scaled_updates = jax.tree.map( - lambda x: x / (updates_norm + _GRAD_CLIP_EPS) * grad_clip, updates) - updates = jax.lax.cond(updates_norm > grad_clip, lambda _: scaled_updates, - lambda _: updates, None) + lambda x: x / (updates_norm + _GRAD_CLIP_EPS) * grad_clip, updates + ) + updates = jax.lax.cond( + updates_norm > grad_clip, + lambda _: scaled_updates, + lambda _: updates, + None, + ) updates, state = base_opt_update_fn(updates, state, params) return updates, state diff --git a/init2winit/optimizer_lib/test_gradient_accumulator.py b/init2winit/optimizer_lib/test_gradient_accumulator.py index 123b5837..658af37f 100644 --- a/init2winit/optimizer_lib/test_gradient_accumulator.py +++ b/init2winit/optimizer_lib/test_gradient_accumulator.py @@ -16,6 +16,7 @@ r"""Tests for gradient_accumulator.py. """ + import copy import functools import itertools @@ -52,22 +53,26 @@ def _init_model(model_cls, hps): model = model_cls(hps, dataset_metadata, loss_name, metrics_name) params_rng, dropout_rng = jax.random.split(key, num=2) model_init_fn = jax.jit( - functools.partial(model.flax_module.init, train=False)) + functools.partial(model.flax_module.init, train=False) + ) init_dict = model_init_fn( rngs={'params': params_rng, 'dropout': dropout_rng}, - x=np.zeros((2, *hps.input_shape))) + x=np.zeros((2, *hps.input_shape)), + ) params = init_dict['params'] batch_stats = init_dict.get('batch_stats', {}) return params, batch_stats, model.training_cost -def _optimize(num_steps, - params, - batch_stats, - training_cost, - train_iter, - opt_init, - opt_update): +def _optimize( + num_steps, + params, + batch_stats, + training_cost, + train_iter, + opt_init, + opt_update, +): """Update the Flax model for num_steps steps.""" opt_state = opt_init(params) @@ -76,7 +81,9 @@ def opt_cost(params, batch_stats, batch): params, batch=batch, batch_stats=batch_stats, - dropout_rng=jax.random.PRNGKey(2)) + dropout_rng=jax.random.PRNGKey(2), + ) + grad_fn = jax.value_and_grad(opt_cost, has_aux=True) for _ in range(num_steps): data_batch = next(train_iter) @@ -89,8 +96,7 @@ def opt_cost(params, batch_stats, batch): def _get_fake_text_dataset(batch_size, eval_num_batches): """Yields a single text batch repeatedly for train and test.""" - inputs = jnp.array( - np.random.randint(low=0, high=4, size=(batch_size, 32))) + inputs = jnp.array(np.random.randint(low=0, high=4, size=(batch_size, 32))) batch = { 'inputs': inputs, 'targets': inputs, @@ -122,10 +128,12 @@ def test_epoch(num_batches=None): meta_data = { 'apply_one_hot_in_loss': True, 'shift_inputs': True, - 'causal': True + 'causal': True, } - return (Dataset(train_iterator_fn, eval_train_epoch, valid_epoch, - test_epoch), meta_data) + return ( + Dataset(train_iterator_fn, eval_train_epoch, valid_epoch, test_epoch), + meta_data, + ) class GradientAccumulatorTest(absltest.TestCase): @@ -143,13 +151,15 @@ def tearDown(self): def test_virtual_batch_size_error(self): with self.assertRaisesRegex( - ValueError, 'Gradient accumulation does not currently support using '): + ValueError, 'Gradient accumulation does not currently support using ' + ): gradient_accumulator.accumulate_gradients( per_step_batch_size=32, total_batch_size=96, virtual_batch_size=48, base_opt_init_fn=None, - base_opt_update_fn=None) + base_opt_update_fn=None, + ) def test_accumulation(self): """Test simple gradient accumulation.""" @@ -173,16 +183,19 @@ def test_accumulation(self): 'total_accumulated_batch_size': total_batch_size, }) grad_acc_params, grad_acc_batch_stats, grad_acc_training_cost = _init_model( - model_cls, hps) + model_cls, hps + ) total_dataset = dataset_builder( shuffle_rng=jax.random.PRNGKey(1), batch_size=total_batch_size, eval_batch_size=10, - hps=hps) + hps=hps, + ) # Ensure we see the same exact batches. train_iter = total_dataset.train_iterator_fn() train_iter = itertools.islice(train_iter, 0, num_steps) train_iter = itertools.cycle(train_iter) + def grad_acc_train_iter(): for _ in range(num_steps): total_batch = next(train_iter) @@ -197,13 +210,15 @@ def grad_acc_train_iter(): lrs = jnp.array([1.0, 0.1, 1e-2]) sgd_opt_init, sgd_opt_update = optax.sgd( - learning_rate=lambda t: lrs.at[t].get()) + learning_rate=lambda t: lrs.at[t].get() + ) opt_init, opt_update = gradient_accumulator.accumulate_gradients( per_step_batch_size=per_step_batch_size, total_batch_size=total_batch_size, virtual_batch_size=virtual_batch_size, base_opt_init_fn=sgd_opt_init, - base_opt_update_fn=sgd_opt_update) + base_opt_update_fn=sgd_opt_update, + ) grad_acc_params, grad_acc_batch_stats = _optimize( # Run for 3x the number of steps to see the same number of examples. num_steps=3 * num_steps, @@ -212,7 +227,8 @@ def grad_acc_train_iter(): training_cost=grad_acc_training_cost, train_iter=grad_acc_train_iter(), opt_init=opt_init, - opt_update=opt_update) + opt_update=opt_update, + ) # Compute the same updates, but without gradient accumulation. hps.update({ @@ -227,20 +243,22 @@ def grad_acc_train_iter(): training_cost=training_cost, train_iter=train_iter, opt_init=sgd_opt_init, - opt_update=sgd_opt_update) + opt_update=sgd_opt_update, + ) - diffs_params = jax.tree.map(lambda a, b: jnp.mean(jnp.abs(a - b)), - grad_acc_params, params) + diffs_params = jax.tree.map( + lambda a, b: jnp.mean(jnp.abs(a - b)), grad_acc_params, params + ) def batch_stats_reduce(a, b): if len(a.shape) > 0: # pylint: disable=g-explicit-length-test - return jnp.mean( - jnp.abs(jnp.mean(a, axis=0) - jnp.mean(b, axis=0))) + return jnp.mean(jnp.abs(jnp.mean(a, axis=0) - jnp.mean(b, axis=0))) # The gradient accumulator counters are scalars. return a - b - diffs_batch_stats = jax.tree.map(batch_stats_reduce, grad_acc_batch_stats, - batch_stats) + diffs_batch_stats = jax.tree.map( + batch_stats_reduce, grad_acc_batch_stats, batch_stats + ) # We sometimes get small floating point errors in the gradients, so we # cannot test for the values being exactly the same. acceptable_params_diff = 1e-4 @@ -257,11 +275,11 @@ def check_closeness(root_name, d, max_diff): not_close_dict[new_name] = dd return not_close_dict - not_close_params = check_closeness( - '', diffs_params, acceptable_params_diff) + not_close_params = check_closeness('', diffs_params, acceptable_params_diff) self.assertEmpty(not_close_params) not_close_batch_stats = check_closeness( - '', diffs_batch_stats, acceptable_batch_stats_diff) + '', diffs_batch_stats, acceptable_batch_stats_diff + ) # Note that for the variance variables in the batch stats collection, they # sometimes can start to diverge slightly over time (with a higher number of # training steps), likely due to numerical issues. @@ -300,10 +318,7 @@ def test_text_model(self): 'opt_hparams': { 'momentum': 0.9, }, - 'lr_hparams': { - 'base_lr': 0.005, - 'schedule': 'constant' - }, + 'lr_hparams': {'base_lr': 0.005, 'schedule': 'constant'}, # Training HParams. 'l2_decay_factor': 1e-4, 'l2_decay_rank_threshold': 2, @@ -315,7 +330,8 @@ def test_text_model(self): initializer = initializers.get_initializer('noop') eval_num_batches = 5 dataset, dataset_meta_data = _get_fake_text_dataset( - batch_size=hps.batch_size, eval_num_batches=eval_num_batches) + batch_size=hps.batch_size, eval_num_batches=eval_num_batches + ) eval_batch_size = hps.batch_size model = model_cls(hps, dataset_meta_data, loss_name, metrics_name) @@ -340,10 +356,13 @@ def test_text_model(self): eval_frequency=eval_every, checkpoint_steps=checkpoint_steps, metrics_logger=metrics_logger, - init_logger=init_logger).train()) + init_logger=init_logger, + ).train() + ) with tf.io.gfile.GFile( - os.path.join(self.test_dir, 'measurements.csv')) as f: + os.path.join(self.test_dir, 'measurements.csv') + ) as f: df = pandas.read_csv(f) train_err = df['train/error_rate'].values[-1] # Note that upgrading to Linen made this fail at 0.6. diff --git a/init2winit/optimizer_lib/test_optimizers.py b/init2winit/optimizer_lib/test_optimizers.py index e65f9720..9e5eed21 100644 --- a/init2winit/optimizer_lib/test_optimizers.py +++ b/init2winit/optimizer_lib/test_optimizers.py @@ -14,6 +14,7 @@ # limitations under the License. """Tests for optimizers.""" + import shutil import tempfile @@ -119,8 +120,8 @@ def test_generic_multi_optimizer_init(self): unreplicated_optimizer_state = opt_init_fn(unreplicated_params) self.assertIsInstance( - unreplicated_optimizer_state, - optax.transforms.PartitionState) + unreplicated_optimizer_state, optax.transforms.PartitionState + ) # unreplicated_optimizer_state should be a Dict mapping param type # to opt_state where only params mapping to that param_type have non-empty @@ -161,5 +162,6 @@ def tearDown(self): + if __name__ == '__main__': absltest.main() diff --git a/init2winit/optimizer_lib/test_search_subspace.py b/init2winit/optimizer_lib/test_search_subspace.py index 49b640b5..e7341cd4 100644 --- a/init2winit/optimizer_lib/test_search_subspace.py +++ b/init2winit/optimizer_lib/test_search_subspace.py @@ -24,7 +24,7 @@ 'hps.beta1': [2e-5, 4e-2, 6e-3, 1e-7, 3e-4], 'hps.beta2': [3e-1, 5e-4, 6e-7, 7e-1, 1e-2], 'hps.int': [1, 4, 5, 10, 2], - 'hps.float': [1., 4., 5., 10., 2.], + 'hps.float': [1.0, 4.0, 5.0, 10.0, 2.0], 'objective': [[0.12451], [2.2312], [0.123123], [0.5325], [0.6423]], }) @@ -43,7 +43,7 @@ def test_find_best_cube_log(self): 'min_value': 1e-8, 'max_value': 1e-1, 'type': 'DOUBLE', - 'scale_type': 'UNIT_LOG_SCALE' + 'scale_type': 'UNIT_LOG_SCALE', } } cube_sizes = {'beta1': 2} @@ -55,14 +55,17 @@ def test_find_best_cube_log(self): search_space, k=k, cube_sizes=cube_sizes, - cube_strides=cube_strides) + cube_strides=cube_strides, + ) testing.assert_equal(result['contains_best_trial'], False) testing.assert_equal(result['mean_trial_objective'], 0.12451) - testing.assert_array_equal(result['search_space']['beta1']['min_value'], - 1e-6) - testing.assert_array_equal(result['search_space']['beta1']['max_value'], - 1e-4) + testing.assert_array_equal( + result['search_space']['beta1']['min_value'], 1e-6 + ) + testing.assert_array_equal( + result['search_space']['beta1']['max_value'], 1e-4 + ) testing.assert_equal(len(result['trials']), 1) def test_find_best_cube_int(self): @@ -72,7 +75,7 @@ def test_find_best_cube_int(self): 'min_value': 1, 'max_value': 12, 'type': 'INTEGER', - 'scale_type': 'UNIT_LINEAR_SCALE' + 'scale_type': 'UNIT_LINEAR_SCALE', } } cube_sizes = {'int': 2} @@ -84,7 +87,8 @@ def test_find_best_cube_int(self): search_space, k=k, cube_sizes=cube_sizes, - cube_strides=cube_strides) + cube_strides=cube_strides, + ) testing.assert_equal(result['contains_best_trial'], True) testing.assert_equal(result['mean_trial_objective'], 0.123123) @@ -96,10 +100,10 @@ def test_find_best_cube_float(self): """Test find best cube of float params.""" search_space = { 'float': { - 'min_value': 1., - 'max_value': 12., + 'min_value': 1.0, + 'max_value': 12.0, 'type': 'DOUBLE', - 'scale_type': 'UNIT_LINEAR_SCALE' + 'scale_type': 'UNIT_LINEAR_SCALE', } } cube_sizes = {'float': 2} @@ -111,14 +115,17 @@ def test_find_best_cube_float(self): search_space, k=k, cube_sizes=cube_sizes, - cube_strides=cube_strides) + cube_strides=cube_strides, + ) testing.assert_equal(result['contains_best_trial'], True) testing.assert_equal(result['mean_trial_objective'], 0.123123) - testing.assert_array_equal(result['search_space']['float']['min_value'], - 5.0) - testing.assert_array_equal(result['search_space']['float']['max_value'], - 7.0) + testing.assert_array_equal( + result['search_space']['float']['min_value'], 5.0 + ) + testing.assert_array_equal( + result['search_space']['float']['max_value'], 7.0 + ) testing.assert_equal(len(result['trials']), 1) # TODO(namanagarwal): Get these tests working again. diff --git a/init2winit/optimizer_lib/test_utils.py b/init2winit/optimizer_lib/test_utils.py index 0e231840..a6a23995 100644 --- a/init2winit/optimizer_lib/test_utils.py +++ b/init2winit/optimizer_lib/test_utils.py @@ -24,7 +24,6 @@ from ml_collections.config_dict import ConfigDict import optax - # pylint:disable=duplicate-key @@ -139,7 +138,7 @@ def fun(x): # Test that we can access and set the new hparam state = new_opt.init(init_params) - state = optax.tree_utils.tree_set(state, foo=2.) + state = optax.tree_utils.tree_set(state, foo=2.0) foo = optax.tree_utils.tree_get(state, 'foo') self.assertEqual(foo, 2.0) diff --git a/init2winit/optimizer_lib/utils.py b/init2winit/optimizer_lib/utils.py index edbb6aa5..dc787739 100644 --- a/init2winit/optimizer_lib/utils.py +++ b/init2winit/optimizer_lib/utils.py @@ -219,7 +219,7 @@ def new_update_fn(updates, state, params=None, **extra_args): def append_hparam_name( base_opt: optax.GradientTransformationExtraArgs, hparam_name: str, - default_value: float = 0., + default_value: float = 0.0, ) -> optax.GradientTransformationExtraArgs: """Create artificicial hparam name to comply with pipeline. @@ -243,8 +243,8 @@ def append_hparam_name( state must be 'InjectHyperparamsState'-like such that a hyperparams attribute is present in its state. hparam_name: hyperparameter name to add. - default_value: default value for the new hyperparameter - (never used, simply there to fill the entry) + default_value: default value for the new hyperparameter (never used, simply + there to fill the entry) Returns: new_opt: new ``optax.GradientTransformationExtraArgs`` with new diff --git a/init2winit/test_checkpoint.py b/init2winit/test_checkpoint.py index 81a0eb6d..cb3ce3ce 100644 --- a/init2winit/test_checkpoint.py +++ b/init2winit/test_checkpoint.py @@ -33,7 +33,6 @@ import orbax.checkpoint as ocp from tensorflow.io import gfile - FLAGS = flags.FLAGS INPUT_SHAPE = [10, 28, 28, 1] @@ -57,7 +56,8 @@ def setUp(self): xs = jnp.array(np.random.normal(size=INPUT_SHAPE)) rng, params_rng = jax.random.split(rng) model_init_fn = jax.jit( - functools.partial(model.flax_module.init, train=False)) + functools.partial(model.flax_module.init, train=False) + ) init_dict = model_init_fn({'params': params_rng}, xs) self.params = init_dict['params'] @@ -71,8 +71,7 @@ def test_save_load_roundtrip(self): """Test that saving and loading produces the original state.""" orbax_checkpoint_manager = ocp.CheckpointManager( self.test_dir, - options=ocp.CheckpointManagerOptions( - max_to_keep=1, create=True), + options=ocp.CheckpointManagerOptions(max_to_keep=1, create=True), ) state = dict(params=self.params, global_step=5, completed_epochs=4) checkpoint.save_checkpoint( @@ -93,25 +92,21 @@ def test_delete_old_checkpoints(self): """Test that old checkpoints are deleted.""" orbax_checkpoint_manager = ocp.CheckpointManager( self.test_dir, - options=ocp.CheckpointManagerOptions( - max_to_keep=1, create=True - ), + options=ocp.CheckpointManagerOptions(max_to_keep=1, create=True), + ) + state1 = dict( + params=self.params, + global_step=5, + completed_epochs=4, ) - state1 = dict(params=self.params, - global_step=5, - completed_epochs=4,) checkpoint.save_checkpoint( - 0, - state1, - orbax_checkpoint_manager=orbax_checkpoint_manager) + 0, state1, orbax_checkpoint_manager=orbax_checkpoint_manager + ) - state2 = dict(params=self.params, - global_step=10, - completed_epochs=8) + state2 = dict(params=self.params, global_step=10, completed_epochs=8) checkpoint.save_checkpoint( - 1, - state2, - orbax_checkpoint_manager=orbax_checkpoint_manager) + 1, state2, orbax_checkpoint_manager=orbax_checkpoint_manager + ) orbax_checkpoint_manager.wait_until_finished() dir_contents = gfile.glob(os.path.join(self.test_dir, '*')) @@ -142,21 +137,22 @@ def test_all_variables_restored(self): orbax_checkpoint_manager = ocp.CheckpointManager( fresh_train_dir, - options=ocp.CheckpointManagerOptions( - max_to_keep=1, create=True - ), + options=ocp.CheckpointManagerOptions(max_to_keep=1, create=True), ) checkpoint.save_checkpoint( step=global_step, - state=dict(global_step=global_step, - preemption_count=preemption_count, - sum_train_cost=sum_train_cost, - optimizer_state=saved_optimizer_state, - params=saved_params, - batch_stats=saved_batch_stats, - training_metrics_grabber=saved_training_metrics), - orbax_checkpoint_manager=orbax_checkpoint_manager,) + state=dict( + global_step=global_step, + preemption_count=preemption_count, + sum_train_cost=sum_train_cost, + optimizer_state=saved_optimizer_state, + params=saved_params, + batch_stats=saved_batch_stats, + training_metrics_grabber=saved_training_metrics, + ), + orbax_checkpoint_manager=orbax_checkpoint_manager, + ) ( ret_state, @@ -175,12 +171,8 @@ def test_all_variables_restored(self): orbax_checkpoint_manager=orbax_checkpoint_manager, ) - assert pytree_equal( - ret_state, saved_optimizer_state - ) - assert pytree_equal( - ret_params, saved_params - ) + assert pytree_equal(ret_state, saved_optimizer_state) + assert pytree_equal(ret_params, saved_params) assert pytree_equal( ret_batch_stats, saved_batch_stats, @@ -199,13 +191,13 @@ def test_all_variables_restored(self): def test_maybe_restore_from_checkpoint_logic(self): """Test that the right checkpoint is returned. - 1. If there is no latest checkpoint in the train_dir, then the function - should returnthe passed-in params, batch_stats, etc. - 2. If there is a latest checkpoint in the train_dir, then the function - should return the latest checkpoint. - In the interest of conciseness, this test only checks the params, - not the batch_stats, optimizer_state, or training_metics. The below test - test_all_variables_restored() covers the other three. + 1. If there is no latest checkpoint in the train_dir, then the function + should returnthe passed-in params, batch_stats, etc. + 2. If there is a latest checkpoint in the train_dir, then the function + should return the latest checkpoint. + In the interest of conciseness, this test only checks the params, + not the batch_stats, optimizer_state, or training_metics. The below test + test_all_variables_restored() covers the other three. """ # mock parameters. initial_params = {'foo': 1.0} @@ -215,9 +207,7 @@ def test_maybe_restore_from_checkpoint_logic(self): orbax_checkpoint_manager = ocp.CheckpointManager( checkpoint_dir, - options=ocp.CheckpointManagerOptions( - max_to_keep=1, create=True - ), + options=ocp.CheckpointManagerOptions(max_to_keep=1, create=True), ) # two helper functions diff --git a/init2winit/test_schedules.py b/init2winit/test_schedules.py index 6a1574fb..fad0a114 100644 --- a/init2winit/test_schedules.py +++ b/init2winit/test_schedules.py @@ -27,15 +27,17 @@ class LearningRateTest(absltest.TestCase): def test_polynomial_decay(self): """Test polynomial schedule works correctly with decay_steps_factor.""" - hps = config_dict.ConfigDict(dict( - lr_hparams={ - 'schedule': 'polynomial', - 'power': 2.0, - 'base_lr': .1, - 'end_factor': .01, - 'decay_steps_factor': 0.5, - } - )) + hps = config_dict.ConfigDict( + dict( + lr_hparams={ + 'schedule': 'polynomial', + 'power': 2.0, + 'base_lr': 0.1, + 'end_factor': 0.01, + 'decay_steps_factor': 0.5, + } + ) + ) max_training_steps = 400 lr_fn = schedules.get_schedule_fn(hps.lr_hparams, max_training_steps) hps = hps.lr_hparams @@ -46,20 +48,23 @@ def test_polynomial_decay(self): step, decay_steps, hps['end_factor'] * hps['base_lr'], - power=hps['power'])().numpy() + power=hps['power'], + )().numpy() self.assertAlmostEqual(lr_fn(step), expected_learning_rate) def test_polynomial_decay_decay_steps(self): """Test polynomial schedule works correctly with decay_steps.""" - hps = config_dict.ConfigDict(dict( - lr_hparams={ - 'schedule': 'polynomial', - 'power': 2.0, - 'base_lr': .1, - 'end_factor': .01, - 'decay_steps': 200, - } - )) + hps = config_dict.ConfigDict( + dict( + lr_hparams={ + 'schedule': 'polynomial', + 'power': 2.0, + 'base_lr': 0.1, + 'end_factor': 0.01, + 'decay_steps': 200, + } + ) + ) max_training_steps = 400 lr_fn = schedules.get_schedule_fn(hps.lr_hparams, max_training_steps) hps = hps.lr_hparams @@ -70,7 +75,8 @@ def test_polynomial_decay_decay_steps(self): step, decay_steps, hps['end_factor'] * hps['base_lr'], - power=hps['power'])().numpy() + power=hps['power'], + )().numpy() self.assertAlmostEqual(lr_fn(step), expected_learning_rate) def test_compound_schedule_cosine(self): @@ -88,13 +94,15 @@ def test_compound_schedule_cosine(self): 1 + np.cos(0.9 * np.pi), 0.0, ] - hps = config_dict.ConfigDict(dict( - lr_hparams={ - 'schedule': 'compound', - 'factors': 'constant * cosine', - 'base_lr': 2, - } - )) + hps = config_dict.ConfigDict( + dict( + lr_hparams={ + 'schedule': 'compound', + 'factors': 'constant * cosine', + 'base_lr': 2, + } + ) + ) max_training_steps = 11 lr_fn = schedules.get_schedule_fn(hps.lr_hparams, max_training_steps) for step in range(max_training_steps): @@ -148,14 +156,16 @@ def test_compound_schedule_cosine_with_warmup(self): 4.0 * (1 + np.cos(0.9 * np.pi)), 0.0, ] - hps = config_dict.ConfigDict(dict( - lr_hparams={ - 'schedule': 'compound', - 'factors': 'constant * cosine * linear_warmup', - 'base_lr': 8, - 'warmup_steps': 4, - } - )) + hps = config_dict.ConfigDict( + dict( + lr_hparams={ + 'schedule': 'compound', + 'factors': 'constant * cosine * linear_warmup', + 'base_lr': 8, + 'warmup_steps': 4, + } + ) + ) max_training_steps = 4 + 11 lr_fn = schedules.get_schedule_fn(hps.lr_hparams, max_training_steps) for step in range(max_training_steps): @@ -178,8 +188,8 @@ def test_concatenate_out_of_bounds(self): # Although the schedule is only defined for sum(lengths), out of bounds # steps should be blithely passed into the final sub-schedule piece so we # test an extra 3 steps. - lrs = [lr_fn(t) for t in range(sum(lengths)+3)] - expected_lrs = [10.0] * 3 + [7.0] * (5+3) + lrs = [lr_fn(t) for t in range(sum(lengths) + 3)] + expected_lrs = [10.0] * 3 + [7.0] * (5 + 3) self.assertEqual(lrs, expected_lrs) def test_concatenate(self): @@ -201,7 +211,8 @@ def test_concatenate(self): 'warmup_steps': 5, }) lr_fn = schedules.concatenate( - lengths, lr_hparams_a, lr_hparams_b, lr_hparams_c) + lengths, lr_hparams_a, lr_hparams_b, lr_hparams_c + ) lrs = [lr_fn(t) for t in range(sum(lengths))] lf_fn_a = schedules.get_schedule_fn(lr_hparams_a, lengths[0]) lf_fn_b = schedules.get_schedule_fn(lr_hparams_b, lengths[1]) @@ -228,7 +239,8 @@ def test_schedule_stretching(self): lr_fn = schedules.get_schedule_fn(lr_hparams, max_training_steps) stretch_factor = 3 stretched_lr_fn = schedules.get_schedule_fn( - lr_hparams, max_training_steps, stretch_factor=stretch_factor) + lr_hparams, max_training_steps, stretch_factor=stretch_factor + ) lrs = [lr_fn(t) for t in range(max_training_steps)] stretched_lrs = [ stretched_lr_fn(t) for t in range(stretch_factor * max_training_steps) @@ -246,39 +258,171 @@ def test_schedule_stretching(self): def test_mlperf_schedule(self): """Test there are no changes to the MLPerf polynomial decay schedule.""" expected_lrs = [ - 0.0, 0.2, 0.4, 0.6, 0.8, 1.0, 1.2, 1.4, 1.6, 1.8, 2.0, 2.2, 2.4, 2.6, - 2.8, 3.0, 3.2, 3.4, 3.6, 3.8, 4.0, 4.2, 4.4, 4.6, 4.8, 5.0, 5.2, 5.4, - 5.6, 5.8, 6.0, 6.2, 6.4, 6.6, 6.8, 7.0, 7.2, 7.4, 7.6, 7.8, 8.0, 8.2, - 8.4, 8.6, 8.8, 9.0, 9.2, 9.4, 9.6, 9.8, 10.0, 9.802962, 9.607885, - 9.414769, 9.223614, 9.034419, 8.847184, 8.661909, 8.478596, 8.297242, - 8.117851, 7.940418, 7.764947, 7.5914364, 7.419886, 7.2502966, 7.082668, - 6.917, 6.7532916, 6.591545, 6.4317584, 6.273932, 6.1180663, 5.964162, - 5.812217, 5.662234, 5.5142093, 5.368148, 5.2240453, 5.0819044, 4.941723, - 4.803503, 4.6672425, 4.532944, 4.4006047, 4.2702274, 4.1418095, - 4.0153522, 3.8908558, 3.7683203, 3.647745, 3.5291305, 3.4124763, - 3.297783, 3.18505, 3.0742776, 2.965466, 2.858614, 2.753724, 2.6507936, - 2.5498245, 2.4508152, 2.353767, 2.2586792, 2.1655521, 2.0743854, - 1.9851794, 1.8979341, 1.8126491, 1.7293249, 1.6479613, 1.568558, - 1.4911155, 1.4156334, 1.342112, 1.2705511, 1.2009507, 1.133311, - 1.0676318, 1.0039133, 0.94215524, 0.8823574, 0.8245205, 0.7686443, - 0.71472853, 0.66277343, 0.61277884, 0.56474483, 0.5186714, 0.4745585, - 0.43240622, 0.39221448, 0.35398334, 0.31771275, 0.28340274, 0.2510533, - 0.22066444, 0.19223614, 0.16576843, 0.14126128, 0.11871469, 0.098128565, - 0.079503134, 0.06283828, 0.048134, 0.035390284, 0.02460714, 0.01578457, - 0.00892257, 0.004021142, + 0.0, + 0.2, + 0.4, + 0.6, + 0.8, + 1.0, + 1.2, + 1.4, + 1.6, + 1.8, + 2.0, + 2.2, + 2.4, + 2.6, + 2.8, + 3.0, + 3.2, + 3.4, + 3.6, + 3.8, + 4.0, + 4.2, + 4.4, + 4.6, + 4.8, + 5.0, + 5.2, + 5.4, + 5.6, + 5.8, + 6.0, + 6.2, + 6.4, + 6.6, + 6.8, + 7.0, + 7.2, + 7.4, + 7.6, + 7.8, + 8.0, + 8.2, + 8.4, + 8.6, + 8.8, + 9.0, + 9.2, + 9.4, + 9.6, + 9.8, + 10.0, + 9.802962, + 9.607885, + 9.414769, + 9.223614, + 9.034419, + 8.847184, + 8.661909, + 8.478596, + 8.297242, + 8.117851, + 7.940418, + 7.764947, + 7.5914364, + 7.419886, + 7.2502966, + 7.082668, + 6.917, + 6.7532916, + 6.591545, + 6.4317584, + 6.273932, + 6.1180663, + 5.964162, + 5.812217, + 5.662234, + 5.5142093, + 5.368148, + 5.2240453, + 5.0819044, + 4.941723, + 4.803503, + 4.6672425, + 4.532944, + 4.4006047, + 4.2702274, + 4.1418095, + 4.0153522, + 3.8908558, + 3.7683203, + 3.647745, + 3.5291305, + 3.4124763, + 3.297783, + 3.18505, + 3.0742776, + 2.965466, + 2.858614, + 2.753724, + 2.6507936, + 2.5498245, + 2.4508152, + 2.353767, + 2.2586792, + 2.1655521, + 2.0743854, + 1.9851794, + 1.8979341, + 1.8126491, + 1.7293249, + 1.6479613, + 1.568558, + 1.4911155, + 1.4156334, + 1.342112, + 1.2705511, + 1.2009507, + 1.133311, + 1.0676318, + 1.0039133, + 0.94215524, + 0.8823574, + 0.8245205, + 0.7686443, + 0.71472853, + 0.66277343, + 0.61277884, + 0.56474483, + 0.5186714, + 0.4745585, + 0.43240622, + 0.39221448, + 0.35398334, + 0.31771275, + 0.28340274, + 0.2510533, + 0.22066444, + 0.19223614, + 0.16576843, + 0.14126128, + 0.11871469, + 0.098128565, + 0.079503134, + 0.06283828, + 0.048134, + 0.035390284, + 0.02460714, + 0.01578457, + 0.00892257, + 0.004021142, ] - hps = config_dict.ConfigDict(dict( - lr_hparams={ - 'schedule': 'mlperf_polynomial', - 'base_lr': 10.0, - 'warmup_steps': 50, - 'decay_end': -1, - 'end_lr': 1e-4, - 'power': 2.0, - 'start_lr': 0.0, - 'warmup_power': 1.0, - } - )) + hps = config_dict.ConfigDict( + dict( + lr_hparams={ + 'schedule': 'mlperf_polynomial', + 'base_lr': 10.0, + 'warmup_steps': 50, + 'decay_end': -1, + 'end_lr': 1e-4, + 'power': 2.0, + 'start_lr': 0.0, + 'warmup_power': 1.0, + } + ) + ) max_training_steps = 50 lr_fn = schedules.get_schedule_fn(hps.lr_hparams, max_training_steps) for step in range(max_training_steps): @@ -292,15 +436,30 @@ def test_t2t_rsqrt_normalized_decay(self): 'schedule': 't2t_rsqrt_normalized_decay', 'base_lr': 0.01, 'defer_steps': 10, - })) + } + ) + ) expected_lrs = [ - 0.009999999776482582, 0.009999999776482582, 0.009999999776482582, - 0.009999999776482582, 0.009999999776482582, 0.009999999776482582, - 0.009999999776482582, 0.009999999776482582, 0.009999999776482582, - 0.009999999776482582, 0.009999999776482582, 0.009534625336527824, - 0.009128709323704243, 0.008770580403506756, 0.008451541885733604, - 0.00816496554762125, 0.007905693724751472, 0.0076696500182151794, - 0.007453559897840023, 0.007254762575030327 + 0.009999999776482582, + 0.009999999776482582, + 0.009999999776482582, + 0.009999999776482582, + 0.009999999776482582, + 0.009999999776482582, + 0.009999999776482582, + 0.009999999776482582, + 0.009999999776482582, + 0.009999999776482582, + 0.009999999776482582, + 0.009534625336527824, + 0.009128709323704243, + 0.008770580403506756, + 0.008451541885733604, + 0.00816496554762125, + 0.007905693724751472, + 0.0076696500182151794, + 0.007453559897840023, + 0.007254762575030327, ] max_training_steps = 20 @@ -310,40 +469,46 @@ def test_t2t_rsqrt_normalized_decay(self): def test_raises(self): """Test that an exception is raised with extra hparams.""" - good_hps = config_dict.ConfigDict(dict( - lr_hparams={ - 'schedule': 'mlperf_polynomial', - 'base_lr': .1, - 'warmup_steps': 200, - 'decay_end': -1, - 'end_lr': 1e-4, - 'power': 2.0, - 'start_lr': 0.0, - 'warmup_power': 1.0, - } - )) - bad_hps = config_dict.ConfigDict(dict( - lr_hparams={ - 'schedule': 'mlperf_polynomial', - 'warmup_steps': 200, - 'base_lr': .1, - 'decay_end': -1, - 'end_lr': 1e-4, - 'power': 2.0, - 'start_lr': 0.0, - 'initial_value': .1, - } - )) - bad_hps2 = config_dict.ConfigDict(dict( - lr_hparams={ - 'schedule': 'polynomial', - 'power': 2.0, - 'base_lr': .1, - 'end_factor': .01, - 'decay_steps': 200, - 'decay_steps_factor': 0.5 - } - )) + good_hps = config_dict.ConfigDict( + dict( + lr_hparams={ + 'schedule': 'mlperf_polynomial', + 'base_lr': 0.1, + 'warmup_steps': 200, + 'decay_end': -1, + 'end_lr': 1e-4, + 'power': 2.0, + 'start_lr': 0.0, + 'warmup_power': 1.0, + } + ) + ) + bad_hps = config_dict.ConfigDict( + dict( + lr_hparams={ + 'schedule': 'mlperf_polynomial', + 'warmup_steps': 200, + 'base_lr': 0.1, + 'decay_end': -1, + 'end_lr': 1e-4, + 'power': 2.0, + 'start_lr': 0.0, + 'initial_value': 0.1, + } + ) + ) + bad_hps2 = config_dict.ConfigDict( + dict( + lr_hparams={ + 'schedule': 'polynomial', + 'power': 2.0, + 'base_lr': 0.1, + 'end_factor': 0.01, + 'decay_steps': 200, + 'decay_steps_factor': 0.5, + } + ) + ) # This should pass. schedules.get_schedule_fn(good_hps.lr_hparams, 1) diff --git a/init2winit/test_training_metrics_grabber.py b/init2winit/test_training_metrics_grabber.py index 8cfdaf6f..9dcbecc7 100644 --- a/init2winit/test_training_metrics_grabber.py +++ b/init2winit/test_training_metrics_grabber.py @@ -32,7 +32,6 @@ from optax import InjectHyperparamsState from optax import ScaleByAdamState - FLAGS = flags.FLAGS @@ -41,29 +40,33 @@ class TrainingMetricsGrabberTest(absltest.TestCase): def setUp(self): super(TrainingMetricsGrabberTest, self).setUp() - self.mock_params0 = freeze({'foo': jnp.zeros(5), - 'bar': {'baz': jnp.ones(10)}}) - self.mock_batch_stats = freeze({'foo': jnp.zeros(5), - 'bar': {'baz': jnp.ones(10)}}) - self.mock_grad1 = freeze({'foo': -jnp.ones(5), - 'bar': {'baz': -jnp.ones(10)}}) - self.mock_grad2 = freeze({'foo': -2*jnp.ones(5), - 'bar': {'baz': -2*jnp.ones(10)}}) - - self.mock_nu0 = freeze({'foo': 2*jnp.ones(5), - 'bar': {'baz': 3*jnp.ones(10)}}) - self.mock_nu1 = freeze({'foo': 3*jnp.ones(5), - 'bar': {'baz': 4*jnp.ones(10)}}) + self.mock_params0 = freeze( + {'foo': jnp.zeros(5), 'bar': {'baz': jnp.ones(10)}} + ) + self.mock_batch_stats = freeze( + {'foo': jnp.zeros(5), 'bar': {'baz': jnp.ones(10)}} + ) + self.mock_grad1 = freeze( + {'foo': -jnp.ones(5), 'bar': {'baz': -jnp.ones(10)}} + ) + self.mock_grad2 = freeze( + {'foo': -2 * jnp.ones(5), 'bar': {'baz': -2 * jnp.ones(10)}} + ) + + self.mock_nu0 = freeze( + {'foo': 2 * jnp.ones(5), 'bar': {'baz': 3 * jnp.ones(10)}} + ) + self.mock_nu1 = freeze( + {'foo': 3 * jnp.ones(5), 'bar': {'baz': 4 * jnp.ones(10)}} + ) self.mock_optimizer_state0 = InjectHyperparamsState( - 0, - None, - (ScaleByAdamState(0, None, self.mock_nu0),)) + 0, None, (ScaleByAdamState(0, None, self.mock_nu0),) + ) self.mock_optimizer_state1 = InjectHyperparamsState( - 0, - None, - (ScaleByAdamState(1, None, self.mock_nu1),)) + 0, None, (ScaleByAdamState(1, None, self.mock_nu1),) + ) self.mock_cost0 = 1.0 self.mock_cost1 = 0.5 @@ -73,15 +76,16 @@ def setUp(self): self.num_train_steps = 5 # Simulate running GD with step size 1. - self.mock_params1 = jax.tree.map(lambda p, g: p - self.step_size * g, - self.mock_params0, - self.mock_grad1) - self.mock_params2 = jax.tree.map(lambda p, g: p - self.step_size * g, - self.mock_params1, - self.mock_grad2) + self.mock_params1 = jax.tree.map( + lambda p, g: p - self.step_size * g, self.mock_params0, self.mock_grad1 + ) + self.mock_params2 = jax.tree.map( + lambda p, g: p - self.step_size * g, self.mock_params1, self.mock_grad2 + ) - self.mock_zeros = freeze({'foo': jnp.zeros(5), - 'bar': {'baz': jnp.zeros(10)}}) + self.mock_zeros = freeze( + {'foo': jnp.zeros(5), 'bar': {'baz': jnp.zeros(10)}} + ) def test_init(self): """Test the training metrics initializer.""" @@ -89,30 +93,35 @@ def test_init(self): zeros_like_params = jax.tree.map(jnp.zeros_like, self.mock_params0) zeros_timeseries = jnp.zeros(self.num_train_steps) zeros_timeseries_like_params = jax.tree.map( - lambda x: jnp.zeros(self.num_train_steps), self.mock_params0) + lambda x: jnp.zeros(self.num_train_steps), self.mock_params0 + ) # Test init with everything disabled. - init_fn, _, _ = make_training_metrics(self.num_train_steps, - ConfigDict({})) + init_fn, _, _ = make_training_metrics(self.num_train_steps, ConfigDict({})) initial_metrics_state = init_fn(self.mock_params0, self.mock_batch_stats) self.assertTrue( - pytree_equal({'param_norm': zeros_timeseries}, - initial_metrics_state)) + pytree_equal({'param_norm': zeros_timeseries}, initial_metrics_state) + ) # Test init with enable_ema = True and enable_train_cost=True. - init_fn, _, _ = make_training_metrics(self.num_train_steps, - ConfigDict({}), - enable_ema=True, - enable_train_cost=True, - enable_param_norms=True, - enable_gradient_norm=True, - enable_all_gradient_norms=True, - enable_update_norm=True, - enable_update_norms=True, - enable_batch_stats_norm=True, - enable_all_batch_stats_norms=True) + init_fn, _, _ = make_training_metrics( + self.num_train_steps, + ConfigDict({}), + enable_ema=True, + enable_train_cost=True, + enable_param_norms=True, + enable_gradient_norm=True, + enable_all_gradient_norms=True, + enable_update_norm=True, + enable_update_norms=True, + enable_batch_stats_norm=True, + enable_all_batch_stats_norms=True, + ) initial_metrics_state = init_fn(self.mock_params0, self.mock_batch_stats) - self.assertTrue(pytree_equal(initial_metrics_state, { + self.assertTrue( + pytree_equal( + initial_metrics_state, + { 'train_cost': zeros_timeseries, 'param_norm': zeros_timeseries, 'grad_ema': zeros_like_params, @@ -125,7 +134,10 @@ def test_init(self): 'update_norm': zeros_timeseries, 'update_norms': zeros_timeseries_like_params, 'batch_stats_norm': zeros_timeseries, - 'all_batch_stats_norms': zeros_timeseries_like_params,},)) + 'all_batch_stats_norms': zeros_timeseries_like_params, + }, + ) + ) def test_train_cost(self): """Ensure that the train cost is logged correctly.""" @@ -222,12 +234,15 @@ def test_update_param_norm(self): expected_param_norm = jnp.zeros(5) expected_param_norm1 = jnp.zeros(5) - expected_param_norm = expected_param_norm.at[0].set(total_tree_norm_l2( - self.mock_params0)) - expected_param_norm1 = expected_param_norm1.at[0].set(total_tree_norm_l2( - self.mock_params0)) - expected_param_norm1 = expected_param_norm1.at[1].set(total_tree_norm_l2( - self.mock_params1)) + expected_param_norm = expected_param_norm.at[0].set( + total_tree_norm_l2(self.mock_params0) + ) + expected_param_norm1 = expected_param_norm1.at[0].set( + total_tree_norm_l2(self.mock_params0) + ) + expected_param_norm1 = expected_param_norm1.at[1].set( + total_tree_norm_l2(self.mock_params1) + ) self.assertTrue( pytree_equal( @@ -424,44 +439,59 @@ def test_optstate_sumsq(self): self.num_train_steps, ConfigDict({}), optstate_sumsq_fields=['nu'], - optstate_sum_fields=['nu']) + optstate_sum_fields=['nu'], + ) initial_metrics_state = init_fn(self.mock_params0, self.mock_batch_stats) - self.assertTrue(pytree_equal( - initial_metrics_state['optstate_sumsq'], { - 'nu': jnp.zeros(self.num_train_steps) - } - )) - self.assertTrue(pytree_equal( - initial_metrics_state['optstate_sum'], { - 'nu': jnp.zeros(self.num_train_steps) - } - )) - updated_metrics_state = update_fn(initial_metrics_state, - 0, - self.mock_cost0, - self.mock_grad1, - self.mock_params0, - self.mock_params1, - self.mock_optimizer_state0, - self.mock_batch_stats) - updated_metrics_state = update_fn(updated_metrics_state, - 1, - self.mock_cost1, - self.mock_grad2, - self.mock_params1, - self.mock_params2, - self.mock_optimizer_state1, - self.mock_batch_stats) - - self.assertEqual(updated_metrics_state['optstate_sumsq']['nu'][0], - total_tree_norm_sql2(self.mock_nu0)) - self.assertEqual(updated_metrics_state['optstate_sumsq']['nu'][1], - total_tree_norm_sql2(self.mock_nu1)) - - self.assertEqual(updated_metrics_state['optstate_sum']['nu'][0], - total_tree_sum(self.mock_nu0)) - self.assertEqual(updated_metrics_state['optstate_sum']['nu'][1], - total_tree_sum(self.mock_nu1)) + self.assertTrue( + pytree_equal( + initial_metrics_state['optstate_sumsq'], + {'nu': jnp.zeros(self.num_train_steps)}, + ) + ) + self.assertTrue( + pytree_equal( + initial_metrics_state['optstate_sum'], + {'nu': jnp.zeros(self.num_train_steps)}, + ) + ) + updated_metrics_state = update_fn( + initial_metrics_state, + 0, + self.mock_cost0, + self.mock_grad1, + self.mock_params0, + self.mock_params1, + self.mock_optimizer_state0, + self.mock_batch_stats, + ) + updated_metrics_state = update_fn( + updated_metrics_state, + 1, + self.mock_cost1, + self.mock_grad2, + self.mock_params1, + self.mock_params2, + self.mock_optimizer_state1, + self.mock_batch_stats, + ) + + self.assertEqual( + updated_metrics_state['optstate_sumsq']['nu'][0], + total_tree_norm_sql2(self.mock_nu0), + ) + self.assertEqual( + updated_metrics_state['optstate_sumsq']['nu'][1], + total_tree_norm_sql2(self.mock_nu1), + ) + + self.assertEqual( + updated_metrics_state['optstate_sum']['nu'][0], + total_tree_sum(self.mock_nu0), + ) + self.assertEqual( + updated_metrics_state['optstate_sum']['nu'][1], + total_tree_sum(self.mock_nu1), + ) def test_update_precondition(self): """Test that precondition metrics are computed correctly.""" @@ -470,100 +500,111 @@ def test_update_precondition(self): init_fn, update_fn, _ = make_training_metrics( self.num_train_steps, - ConfigDict({ - 'optimizer': optimizer, - 'opt_hparams': opt_hparams - }), + ConfigDict({'optimizer': optimizer, 'opt_hparams': opt_hparams}), enable_preconditioner_normsq=True, - enable_semip_grad_normsq=True) + enable_semip_grad_normsq=True, + ) initial_metrics_state = init_fn(self.mock_params0, self.mock_batch_stats) - updated_metrics_state = update_fn(initial_metrics_state, - 0, - self.mock_cost0, - self.mock_grad1, - self.mock_params0, - self.mock_params1, - self.mock_optimizer_state0, - self.mock_batch_stats) - updated_metrics_state = update_fn(updated_metrics_state, - 1, - self.mock_cost1, - self.mock_grad2, - self.mock_params1, - self.mock_params2, - self.mock_optimizer_state1, - self.mock_batch_stats) - - pre0 = make_diag_preconditioner(optimizer, opt_hparams, - self.mock_optimizer_state0, ConfigDict({})) - semip_grad_0 = jax.tree.map(lambda g, p: g / (p**0.5), self.mock_grad1, - pre0) - - pre1 = make_diag_preconditioner(optimizer, opt_hparams, - self.mock_optimizer_state1, ConfigDict({})) - semip_grad_1 = jax.tree.map(lambda g, p: g / (p**0.5), self.mock_grad2, - pre1) - - self.assertEqual(updated_metrics_state['preconditioner_normsq'][0], - total_tree_norm_sql2(pre0)) - self.assertEqual(updated_metrics_state['preconditioner_normsq'][1], - total_tree_norm_sql2(pre1)) - - self.assertEqual(updated_metrics_state['semip_grad_normsq'][0], - total_tree_norm_sql2(semip_grad_0)) - self.assertEqual(updated_metrics_state['semip_grad_normsq'][1], - total_tree_norm_sql2(semip_grad_1)) + updated_metrics_state = update_fn( + initial_metrics_state, + 0, + self.mock_cost0, + self.mock_grad1, + self.mock_params0, + self.mock_params1, + self.mock_optimizer_state0, + self.mock_batch_stats, + ) + updated_metrics_state = update_fn( + updated_metrics_state, + 1, + self.mock_cost1, + self.mock_grad2, + self.mock_params1, + self.mock_params2, + self.mock_optimizer_state1, + self.mock_batch_stats, + ) + + pre0 = make_diag_preconditioner( + optimizer, opt_hparams, self.mock_optimizer_state0, ConfigDict({}) + ) + semip_grad_0 = jax.tree.map( + lambda g, p: g / (p**0.5), self.mock_grad1, pre0 + ) + + pre1 = make_diag_preconditioner( + optimizer, opt_hparams, self.mock_optimizer_state1, ConfigDict({}) + ) + semip_grad_1 = jax.tree.map( + lambda g, p: g / (p**0.5), self.mock_grad2, pre1 + ) + + self.assertEqual( + updated_metrics_state['preconditioner_normsq'][0], + total_tree_norm_sql2(pre0), + ) + self.assertEqual( + updated_metrics_state['preconditioner_normsq'][1], + total_tree_norm_sql2(pre1), + ) + + self.assertEqual( + updated_metrics_state['semip_grad_normsq'][0], + total_tree_norm_sql2(semip_grad_0), + ) + self.assertEqual( + updated_metrics_state['semip_grad_normsq'][1], + total_tree_norm_sql2(semip_grad_1), + ) def test_summarize(self): """Test the training metrics summarizer.""" - _, _, summarize_fn = make_training_metrics(self.num_train_steps, - ConfigDict({}), - enable_train_cost=True, - enable_ema=True) + _, _, summarize_fn = make_training_metrics( + self.num_train_steps, + ConfigDict({}), + enable_train_cost=True, + enable_ema=True, + ) metrics_state = { 'train_cost': jnp.array([1.0, 0.5, 0.25, 0.0, 0.0]), - 'param_norm': { - 'foo': 7.0, - 'bar': {'baz': 2.0} - }, - 'grad_ema': { - 'foo': 1 * jnp.ones(5), - 'bar': {'baz': 2 * jnp.ones(10)} - }, + 'param_norm': {'foo': 7.0, 'bar': {'baz': 2.0}}, + 'grad_ema': {'foo': 1 * jnp.ones(5), 'bar': {'baz': 2 * jnp.ones(10)}}, 'grad_sq_ema': { 'foo': 2 * jnp.ones(5), - 'bar': {'baz': 6 * jnp.ones(10)} + 'bar': {'baz': 6 * jnp.ones(10)}, }, 'update_ema': { 'foo': 2 * jnp.ones(5), - 'bar': {'baz': 1 * jnp.ones(10)} + 'bar': {'baz': 1 * jnp.ones(10)}, }, 'update_sq_ema': { 'foo': 6 * jnp.ones(5), - 'bar': {'baz': 2 * jnp.ones(10)} + 'bar': {'baz': 2 * jnp.ones(10)}, }, } tree_summary = summarize_fn(metrics_state) self.assertTrue( pytree_equal( - tree_summary, { - 'param_norm': { - '/foo': 7.0, - '/bar/baz': 2.0 - }, + tree_summary, + { + 'param_norm': {'/foo': 7.0, '/bar/baz': 2.0}, 'grad_var': { '/foo': 5 * (2 - 1**2), - '/bar/baz': 10 * (6 - 2**2) + '/bar/baz': 10 * (6 - 2**2), }, 'update_var': { '/foo': 5 * (6 - 2**2), - '/bar/baz': 10 * (2 - 1**2) + '/bar/baz': 10 * (2 - 1**2), }, 'update_ratio': { '/foo': 5 * (6 - 2**2) / 7.0, - '/bar/baz': 10 * (2 - 1**2) / 2.0 - } - })) + '/bar/baz': 10 * (2 - 1**2) / 2.0, + }, + }, + ) + ) + if __name__ == '__main__': absltest.main() diff --git a/init2winit/test_utils.py b/init2winit/test_utils.py index 0351d2ae..18ddb329 100644 --- a/init2winit/test_utils.py +++ b/init2winit/test_utils.py @@ -86,10 +86,12 @@ def test_run_in_parallel_on_failing_fn(self): def test_array_append(self): """Test appending to an array.""" np.testing.assert_allclose( - utils.array_append(jnp.array([1, 2, 3]), 4), jnp.array([1, 2, 3, 4])) + utils.array_append(jnp.array([1, 2, 3]), 4), jnp.array([1, 2, 3, 4]) + ) np.testing.assert_allclose( utils.array_append(jnp.array([[1, 2], [3, 4]]), jnp.array([5, 6])), - jnp.array([[1, 2], [3, 4], [5, 6]])) + jnp.array([[1, 2], [3, 4], [5, 6]]), + ) def test_tree_norm_sq_l2(self): """Test computing the squared L2 norm of a pytree.""" @@ -99,9 +101,9 @@ def test_tree_norm_sq_l2(self): def test_tree_sum(self): """Test computing the sum of a pytree.""" - pytree = {'foo': 2*jnp.ones(10), 'baz': jnp.ones(20)} + pytree = {'foo': 2 * jnp.ones(10), 'baz': jnp.ones(20)} self.assertEqual(utils.total_tree_sum(pytree), 40) + if __name__ == '__main__': absltest.main() - diff --git a/init2winit/tools/inspect_dataset.py b/init2winit/tools/inspect_dataset.py index d9697503..609648d3 100644 --- a/init2winit/tools/inspect_dataset.py +++ b/init2winit/tools/inspect_dataset.py @@ -36,8 +36,9 @@ flags.DEFINE_string('dataset', None, 'Which dataset to inspect') flags.DEFINE_string('model', None, 'Which model to use') -flags.DEFINE_integer('batch_size', None, - 'Number of examples to retrieve in 1 batch') +flags.DEFINE_integer( + 'batch_size', None, 'Number of examples to retrieve in 1 batch' +) flags.DEFINE_integer('num_batches', None, 'Number of batches to retrieve') FLAGS = flags.FLAGS @@ -80,8 +81,9 @@ def main(unused_argv): rng = jax.random.PRNGKey(0) rng, data_rng = jax.random.split(rng) - dataset = datasets.get_dataset(FLAGS.dataset)(data_rng, batch_size, - batch_size, hps) + dataset = datasets.get_dataset(FLAGS.dataset)( + data_rng, batch_size, batch_size, hps + ) train_iter = dataset.train_iterator_fn() for i in range(num_batches): diff --git a/init2winit/trainer_lib/base_trainer.py b/init2winit/trainer_lib/base_trainer.py index 32d3d7a3..9ed78200 100644 --- a/init2winit/trainer_lib/base_trainer.py +++ b/init2winit/trainer_lib/base_trainer.py @@ -34,7 +34,6 @@ from ml_collections.config_dict import config_dict import orbax.checkpoint as ocp - CHECKPOINT_TTL = 'ttl=180d' @@ -243,15 +242,17 @@ def __init__( # 'params' and 'batch_stats' buffers as we don't re-assign those values in # eval, we do that only in train. self._evaluate_batch_jitted = jax.jit( - self._model.evaluate_batch, donate_argnums=(2,)) + self._model.evaluate_batch, donate_argnums=(2,) + ) # Creates a 1-d mesh with all devices available globally. self._mesh = model_utils.get_default_mesh() # Set training algorithm class. self._training_algorithm_class = training_algorithm_class - logging.info('Using training algorithm class: %s', - self._training_algorithm_class) + logging.info( + 'Using training algorithm class: %s', self._training_algorithm_class + ) if log_frequency is not None: self._log_frequency = log_frequency @@ -272,11 +273,13 @@ def log_model_info(self, unreplicated_params): utils.tabulate_model(self._model, self._hps) logging.info('train_size: %d,', self._hps.train_size) - def maybe_restore_from_checkpoint(self, - unreplicated_optimizer_state, - unreplicated_params, - unreplicated_batch_stats, - unreplicated_metrics_state): + def maybe_restore_from_checkpoint( + self, + unreplicated_optimizer_state, + unreplicated_params, + unreplicated_batch_stats, + unreplicated_metrics_state, + ): """Restores the training state from a checkpoint if one exists. Args: @@ -414,13 +417,19 @@ def _setup_eval_callbacks(self, callback_rng): return eval_callbacks def _run_eval_callbacks(self, report): + """Runs all registered evaluation callbacks and updates the report.""" for eval_callback in self._eval_callbacks: - callback_metrics = eval_callback.run_eval(self._params, self._batch_stats, - self._optimizer_state, - self._global_step) + callback_metrics = eval_callback.run_eval( + self._params, + self._batch_stats, + self._optimizer_state, + self._global_step, + ) if set(callback_metrics.keys()).intersection(set(report.keys())): - raise ValueError('There was a collision between the callback' - 'metrics and the standard eval metrics keys') + raise ValueError( + 'There was a collision between the callback' + 'metrics and the standard eval metrics keys' + ) report.update(callback_metrics) def _check_early_stopping(self, report): @@ -443,7 +452,8 @@ def _check_early_stopping(self, report): self._early_stopping_target_name, report[self._early_stopping_target_name], comparison_string, - self._early_stopping_target_value) + self._early_stopping_target_value, + ) return early_stopping_condition def _build_training_report(self): @@ -553,7 +563,8 @@ def _eval(self, start_step, start_time, eval_rng, save=True): self._sum_train_cost = 0.0 epoch = self._global_step * self._hps.batch_size // self._hps.train_size overall_steps_per_sec = self._get_step_frequency( - self._global_step, start_step, start_time) + self._global_step, start_step, start_time + ) report.update( epoch=epoch, preemption_count=self._preemption_count, @@ -574,7 +585,8 @@ def _eval(self, start_step, start_time, eval_rng, save=True): start_time, self._eval_frequency, self._eval_steps, - eval_time) + eval_time, + ) trainer_utils.log_epoch_report(report, self._metrics_logger) self._time_at_prev_eval_end = time.time() @@ -628,15 +640,15 @@ def setup_and_maybe_restore(self, init_rng, data_rng, callback_rng): start_time = time.time() if self._training_metrics_config is not None: - (metrics_init_fn, self._metrics_update_fn, - self._metrics_summary_fn) = make_training_metrics( - self._num_train_steps, self._hps, **self._training_metrics_config) + metrics_init_fn, self._metrics_update_fn, self._metrics_summary_fn = ( + make_training_metrics( + self._num_train_steps, self._hps, **self._training_metrics_config + ) + ) unreplicated_metrics_state = metrics_init_fn( unreplicated_params, unreplicated_batch_stats ) - logging.info( - 'Metrics initialized in %f seconds', time.time() - start_time - ) + logging.info('Metrics initialized in %f seconds', time.time() - start_time) start_time = time.time() ( unreplicated_optimizer_state, @@ -650,9 +662,7 @@ def setup_and_maybe_restore(self, init_rng, data_rng, callback_rng): unreplicated_metrics_state, ) - logging.info( - 'Checkpoint restored in %f seconds', time.time() - start_time - ) + logging.info('Checkpoint restored in %f seconds', time.time() - start_time) start_time = time.time() ( self._params, @@ -723,7 +733,8 @@ def train(self): batch_stats=self._batch_stats, hps=self._hps, global_step=self._global_step, - constant_base_rng=rng) + constant_base_rng=rng, + ) start_time = time.time() start_step = self._global_step @@ -835,8 +846,13 @@ def update(self, batch, rng, metrics_update_fn, metrics_state, training_cost): """ @abc.abstractmethod - def shard(self, unreplicated_params, unreplicated_optimizer_state, - unreplicated_batch_stats, unreplicated_metrics_state): + def shard( + self, + unreplicated_params, + unreplicated_optimizer_state, + unreplicated_batch_stats, + unreplicated_metrics_state, + ): """Shard the training state. Args: diff --git a/init2winit/trainer_lib/test_mdlm_integration.py b/init2winit/trainer_lib/test_mdlm_integration.py index a16c0d61..ecf9242d 100644 --- a/init2winit/trainer_lib/test_mdlm_integration.py +++ b/init2winit/trainer_lib/test_mdlm_integration.py @@ -260,5 +260,6 @@ def test_loss_decreases_on_pattern(self): eval_results['perplexity'], ) + if __name__ == '__main__': absltest.main() diff --git a/init2winit/trainer_lib/test_trainer.py b/init2winit/trainer_lib/test_trainer.py index d0be571e..3eb3ce28 100644 --- a/init2winit/trainer_lib/test_trainer.py +++ b/init2winit/trainer_lib/test_trainer.py @@ -48,7 +48,6 @@ import tensorflow.compat.v1 as tf # importing this is needed for tfds mocking. import tensorflow_datasets as tfds - FLAGS = flags.FLAGS _VOCAB_SIZE = 4 @@ -84,7 +83,7 @@ def get_column_names(): 'train_cost', 'grad_norm', 'update_norm', - 'train_steps_per_sec' + 'train_steps_per_sec', ] return column_names @@ -92,7 +91,8 @@ def get_column_names(): def _get_fake_text_dataset(batch_size, eval_num_batches): """Yields a single text batch repeatedly for train and test.""" inputs = jnp.array( - np.random.randint(low=0, high=_VOCAB_SIZE, size=(batch_size, _MAX_LEN))) + np.random.randint(low=0, high=_VOCAB_SIZE, size=(batch_size, _MAX_LEN)) + ) batch = { 'inputs': inputs, 'targets': inputs, @@ -124,10 +124,12 @@ def test_epoch(num_batches=None): meta_data = { 'apply_one_hot_in_loss': True, 'shift_inputs': True, - 'causal': True + 'causal': True, } - return (Dataset(train_iterator_fn, eval_train_epoch, valid_epoch, - test_epoch), meta_data) + return ( + Dataset(train_iterator_fn, eval_train_epoch, valid_epoch, test_epoch), + meta_data, + ) def _get_fake_graph_dataset(batch_size, eval_num_batches, hps): @@ -141,10 +143,9 @@ def _get_batch(n_nodes_list): for n_nodes in n_nodes_list: n_edges = n_nodes**2 graph = jraph.get_fully_connected_graph( - n_nodes, 1, - np.ones((n_nodes, *hps.input_node_shape))) - graph = graph._replace( - edges=np.ones((n_edges, *hps.input_edge_shape))) + n_nodes, 1, np.ones((n_nodes, *hps.input_node_shape)) + ) + graph = graph._replace(edges=np.ones((n_edges, *hps.input_edge_shape))) labels = np.ones(hps.output_shape) * (1 if n_nodes in [4, 6] else 0) weights = np.ones(*hps.output_shape) graphs_list.append(graph) @@ -166,6 +167,7 @@ def _get_batch(n_nodes_list): [i for i in range(min_nodes, min_nodes + n_graphs_per_batch)] for _ in range(num_batches) ] + def train_iterator_fn(): for ns in itertools.cycle(n_nodes_list): yield _get_batch(ns) @@ -188,16 +190,19 @@ def test_epoch(num_batches=None): for ns in itertools.islice(itertools.cycle(n_nodes_list), num_batches): yield _get_batch(ns) - return (Dataset(train_iterator_fn, eval_train_epoch, valid_epoch, - test_epoch), { - 'apply_one_hot_in_loss': False, - }) + return ( + Dataset(train_iterator_fn, eval_train_epoch, valid_epoch, test_epoch), + { + 'apply_one_hot_in_loss': False, + }, + ) def _get_fake_dlrm_dataset(batch_size, eval_num_batches, hps): """Yields a single text batch repeatedly for train and test.""" cat_features = np.random.randint( - low=0, high=hps.vocab_size, size=(batch_size, 26)) + low=0, high=hps.vocab_size, size=(batch_size, 26) + ) int_features = np.random.normal(size=(batch_size, hps.num_dense_features)) inputs = np.concatenate((int_features, cat_features), 1) targets = np.random.randint(low=0, high=2, size=(batch_size, 1)) @@ -232,8 +237,10 @@ def test_epoch(num_batches=None): meta_data = { 'apply_one_hot_in_loss': False, } - return (Dataset(train_iterator_fn, eval_train_epoch, valid_epoch, - test_epoch), meta_data) + return ( + Dataset(train_iterator_fn, eval_train_epoch, valid_epoch, test_epoch), + meta_data, + ) class TrainerTest(parameterized.TestCase): @@ -269,7 +276,8 @@ def test_initialize_rescale(self): # First initialize with no rescale. params, _ = model.initialize( - initializer, hps, init_rng, metrics_logger=None) + initializer, hps, init_rng, metrics_logger=None + ) utils.log_pytree_shape_and_statistics(params) # Now rescale a layer by 100. @@ -279,7 +287,8 @@ def test_initialize_rescale(self): } rescaled_params, _ = model.initialize( - initializer, hps, init_rng, metrics_logger=None) + initializer, hps, init_rng, metrics_logger=None + ) # Check the right variable is rescaled v1 = params['Dense_1']['kernel'] @@ -296,7 +305,7 @@ def test_initialize_rescale(self): def test_classifaction_model_evaluate(self): """Test trainer evaluate end to end with classification model metrics.""" # Define a fake model that always outputs the same logits. - fake_batch_logits = np.tile([.5, .2, .7, 0.0], (4, 1)) + fake_batch_logits = np.tile([0.5, 0.2, 0.7, 0.0], (4, 1)) class FakeModel(nn.Module): @@ -305,7 +314,8 @@ def __call__(self, x, train): # Make a single linear layer with the identity as the init. identity_fn = lambda *_: np.eye(4) x = nn.Dense(features=4, use_bias=False, kernel_init=identity_fn)( - fake_batch_logits) + fake_batch_logits + ) return x key = jax.random.PRNGKey(0) @@ -313,9 +323,8 @@ def __call__(self, x, train): fake_flax_module = FakeModel() model_init_fn = jax.jit(fake_flax_module.init) init_dict = model_init_fn( - rngs={'params': params_rng, 'dropout': dropout_rng}, - x=None, - train=False) + rngs={'params': params_rng, 'dropout': dropout_rng}, x=None, train=False + ) mesh_shape = (jax.device_count(),) mesh = jax.sharding.Mesh( mesh_utils.create_device_mesh(mesh_shape, devices=jax.devices()), @@ -327,26 +336,10 @@ def __call__(self, x, train): # 4 evaluation batches of size 4. weights = np.ones((4)) fake_batches = [ - { - 'inputs': None, - 'targets': np.array([3, 2, 1, 0]), - 'weights': weights - }, - { - 'inputs': None, - 'targets': np.array([0, 3, 2, 0]), - 'weights': weights - }, - { - 'inputs': None, - 'targets': np.array([0, 0, 0, 0]), - 'weights': weights - }, - { - 'inputs': None, - 'targets': np.array([1, 1, 1, 1]), - 'weights': weights - }, + {'inputs': None, 'targets': np.array([3, 2, 1, 0]), 'weights': weights}, + {'inputs': None, 'targets': np.array([0, 3, 2, 0]), 'weights': weights}, + {'inputs': None, 'targets': np.array([0, 0, 0, 0]), 'weights': weights}, + {'inputs': None, 'targets': np.array([1, 1, 1, 1]), 'weights': weights}, ] def fake_batches_gen(): @@ -360,7 +353,8 @@ def fake_batches_gen(): batch_stats, batch, metrics.get_metrics('classification_metrics'), - True) + True, + ) evaluate_batch_jitted = jax.jit(eval_fn) # pylint: enable=protected-access @@ -377,13 +371,15 @@ def batch_ce_loss(logits, targets): loss = -np.sum(one_hot_targets * nn.log_softmax(logits), axis=-1) return loss - expected_error_rate = 14.0/16.0 # FakeModel always predicts class 2. + expected_error_rate = 14.0 / 16.0 # FakeModel always predicts class 2. expected_ce_loss = np.mean( - [batch_ce_loss(fake_batch_logits, b['targets']) for b in fake_batches]) + [batch_ce_loss(fake_batch_logits, b['targets']) for b in fake_batches] + ) self.assertEqual(expected_error_rate, evaluated_metrics['error_rate']) self.assertAlmostEqual( - expected_ce_loss, evaluated_metrics['ce_loss'], places=4) + expected_ce_loss, evaluated_metrics['ce_loss'], places=4 + ) self.assertEqual(16, evaluated_metrics['num_examples']) def test_graph_model_trainer(self): @@ -408,10 +404,7 @@ def test_graph_model_trainer(self): 'num_message_passing_steps': 1, 'normalizer': 'none', 'dropout_rate': 0.0, - 'lr_hparams': { - 'base_lr': 0.001, - 'schedule': 'constant' - }, + 'lr_hparams': {'base_lr': 0.001, 'schedule': 'constant'}, 'num_device_prefetches': 0, }) eval_num_batches = 5 @@ -419,7 +412,8 @@ def test_graph_model_trainer(self): loss_name = 'sigmoid_binary_cross_entropy' metrics_name = 'binary_classification_metrics_ogbg_map' dataset, dataset_meta_data = _get_fake_graph_dataset( - batch_size=hps.batch_size, eval_num_batches=eval_num_batches, hps=hps) + batch_size=hps.batch_size, eval_num_batches=eval_num_batches, hps=hps + ) model = model_cls(hps, dataset_meta_data, loss_name, metrics_name) initializer = initializers.get_initializer('noop') @@ -448,8 +442,9 @@ def test_graph_model_trainer(self): ) _ = list(self.trainer.train()) - with tf.io.gfile.GFile(os.path.join(self.test_dir, - 'measurements.csv')) as f: + with tf.io.gfile.GFile( + os.path.join(self.test_dir, 'measurements.csv') + ) as f: df = pandas.read_csv(f) train_loss = df['train/ce_loss'].values self.assertLess(train_loss[-1], train_loss[0]) @@ -476,7 +471,8 @@ def test_dlrm_model_trainer(self): loss_name = 'sigmoid_binary_cross_entropy' metrics_name = 'binary_classification_metrics' dataset, dataset_meta_data = _get_fake_dlrm_dataset( - dataset_hps.batch_size, eval_num_batches, dataset_hps) + dataset_hps.batch_size, eval_num_batches, dataset_hps + ) hps = copy.copy(model_hps) hps.update({ 'train_size': 15, @@ -517,8 +513,9 @@ def test_dlrm_model_trainer(self): ) _ = list(self.trainer.train()) - with tf.io.gfile.GFile(os.path.join(self.test_dir, - 'measurements.csv')) as f: + with tf.io.gfile.GFile( + os.path.join(self.test_dir, 'measurements.csv') + ) as f: df = pandas.read_csv(f) train_loss = df['train/ce_loss'].values self.assertLess(train_loss[-1], train_loss[0]) @@ -554,10 +551,7 @@ def test_text_model_trainer(self): 'opt_hparams': { 'momentum': 0.9, }, - 'lr_hparams': { - 'base_lr': 0.005, - 'schedule': 'constant' - }, + 'lr_hparams': {'base_lr': 0.005, 'schedule': 'constant'}, # Training HParams. 'l2_decay_factor': 1e-4, 'l2_decay_rank_threshold': 2, @@ -571,7 +565,8 @@ def test_text_model_trainer(self): initializer = initializers.get_initializer('noop') eval_num_batches = 5 dataset, dataset_meta_data = _get_fake_text_dataset( - batch_size=hps.batch_size, eval_num_batches=eval_num_batches) + batch_size=hps.batch_size, eval_num_batches=eval_num_batches + ) eval_batch_size = hps.batch_size model = model_cls(hps, dataset_meta_data, loss_name, metrics_name) @@ -602,7 +597,8 @@ def test_text_model_trainer(self): _ = list(self.trainer.train()) with tf.io.gfile.GFile( - os.path.join(self.test_dir, 'measurements.csv')) as f: + os.path.join(self.test_dir, 'measurements.csv') + ) as f: df = pandas.read_csv(f) train_err = df['train/error_rate'].values[-1] # Note that upgrading to Linen made this fail at 0.6. @@ -633,7 +629,8 @@ def test_text_model_trainer(self): ) _ = list(self.trainer.train()) with tf.io.gfile.GFile( - os.path.join(self.test_dir, 'measurements.csv')) as f: + os.path.join(self.test_dir, 'measurements.csv') + ) as f: df = pandas.read_csv(f) train_err = df['train/error_rate'].values[-1] train_loss = df['train/ce_loss'].values[-1] @@ -645,7 +642,8 @@ def test_text_model_trainer(self): self.assertEqual( df['valid/num_examples'].values[-1], - eval_num_batches * eval_batch_size * _MAX_LEN) + eval_num_batches * eval_batch_size * _MAX_LEN, + ) # Check that the correct learning rate was saved in the measurements file. final_step = df['global_step'].values[-1] self.assertEqual(num_train_steps_reload, final_step) @@ -669,20 +667,19 @@ def test_trainer(self): initializer = initializers.get_initializer(initializer_name) dataset_builder = datasets.get_dataset(dataset_name) hparam_overrides = { - 'lr_hparams': { - 'base_lr': 0.1, - 'schedule': 'cosine' - }, + 'lr_hparams': {'base_lr': 0.1, 'schedule': 'cosine'}, 'batch_size': 8, 'train_size': 160, 'valid_size': 96, 'test_size': 80, } - input_pipeline_hps = config_dict.ConfigDict(dict( - num_tf_data_prefetches=-1, - num_device_prefetches=0, - num_tf_data_map_parallel_calls=-1, - )) + input_pipeline_hps = config_dict.ConfigDict( + dict( + num_tf_data_prefetches=-1, + num_device_prefetches=0, + num_tf_data_map_parallel_calls=-1, + ) + ) hps = hyperparameters.build_hparams( model_name, initializer_name, @@ -701,25 +698,34 @@ def as_dataset(self, *args, **kwargs): # pylint: disable=g-long-lambda,g-complex-comprehension return tf.data.Dataset.from_generator( - lambda: ({ - 'image': np.ones(shape=(28, 28, 1), dtype=np.uint8), - 'label': 9, - } for i in range(num_examples)), + lambda: ( + { + 'image': np.ones(shape=(28, 28, 1), dtype=np.uint8), + 'label': 9, + } + for i in range(num_examples) + ), output_types=self.info.features.dtype, output_shapes=self.info.features.shape, ) # This will override the tfds.load(mnist) call to return 100 fake samples. with tfds.testing.mock_data( - as_dataset_fn=as_dataset, num_examples=num_examples): + as_dataset_fn=as_dataset, num_examples=num_examples + ): dataset = dataset_builder( shuffle_rng=jax.random.PRNGKey(0), batch_size=hps.batch_size, eval_batch_size=eval_batch_size, - hps=hps) + hps=hps, + ) - model = model_cls(hps, datasets.get_dataset_meta_data(dataset_name), - loss_name, metrics_name) + model = model_cls( + hps, + datasets.get_dataset_meta_data(dataset_name), + loss_name, + metrics_name, + ) num_train_steps = 40 eval_num_batches = 5 @@ -758,7 +764,8 @@ def as_dataset(self, *args, **kwargs): self.assertLen(epoch_reports, num_train_steps / eval_every) with tf.io.gfile.GFile( - os.path.join(self.test_dir, 'measurements.csv')) as f: + os.path.join(self.test_dir, 'measurements.csv') + ) as f: df = pandas.read_csv(f) train_err = df['train/error_rate'].values[-1] self.assertEqual(df['preemption_count'].values[-1], 0) @@ -766,8 +773,9 @@ def as_dataset(self, *args, **kwargs): self.assertEqual(set(df.columns.values), set(get_column_names())) - model = model_cls(hps, {'apply_one_hot_in_loss': False}, loss_name, - metrics_name) + model = model_cls( + hps, {'apply_one_hot_in_loss': False}, loss_name, metrics_name + ) # Test reload from the checkpoint by increasing num_train_steps. num_train_steps_reload = 100 @@ -791,17 +799,21 @@ def as_dataset(self, *args, **kwargs): ) epoch_reports = list(self.trainer.train()) self.assertLen( - epoch_reports, (num_train_steps_reload - num_train_steps) / eval_every) + epoch_reports, (num_train_steps_reload - num_train_steps) / eval_every + ) with tf.io.gfile.GFile( - os.path.join(self.test_dir, 'measurements.csv')) as f: + os.path.join(self.test_dir, 'measurements.csv') + ) as f: df = pandas.read_csv(f) train_err = df['train/error_rate'].values[-1] train_loss = df['train/ce_loss'].values[-1] self.assertLess(train_err, 0.35) self.assertLess(train_loss, 0.1) - self.assertEqual(df['valid/num_examples'].values[-1], - eval_num_batches * eval_batch_size) + self.assertEqual( + df['valid/num_examples'].values[-1], + eval_num_batches * eval_batch_size, + ) self.assertEqual(df['preemption_count'].values[-1], 1) # Check that the correct learning rate was saved in the measurements file. final_learning_rate = df['learning_rate'].values[-1] @@ -811,10 +823,12 @@ def as_dataset(self, *args, **kwargs): # final_step will be one larger than the last step used to calculate the # lr_decay, hense we plug in (final_step - 1) to the decay formula. # Note that there is a small numerical different here with np vs jnp. - decay_factor = (1 + np.cos( - (final_step - 1) / num_train_steps_reload * np.pi)) * 0.5 - self.assertAlmostEqual(float(final_learning_rate), - hps.lr_hparams['base_lr'] * decay_factor) + decay_factor = ( + 1 + np.cos((final_step - 1) / num_train_steps_reload * np.pi) + ) * 0.5 + self.assertAlmostEqual( + float(final_learning_rate), hps.lr_hparams['base_lr'] * decay_factor + ) self.assertEqual(set(df.columns.values), set(get_column_names())) @@ -826,7 +840,8 @@ def as_dataset(self, *args, **kwargs): targets=np.array([[1, 0], [0, 1], [1, 0], [1, 0]]), weights=np.array([1, 1, 0, 0]), test_metric_names=['error_rate', 'num_examples'], - test_metric_vals=[0.5, 2]), + test_metric_vals=[0.5, 2], + ), dict( testcase_name='fractional_weights', metrics_name='classification_metrics', @@ -834,86 +849,185 @@ def as_dataset(self, *args, **kwargs): targets=np.array([[1, 0], [1, 0]]), weights=np.array([0.3, 0.7]), test_metric_names=['error_rate', 'num_examples'], - test_metric_vals=[0.3, 1]), + test_metric_vals=[0.3, 1], + ), dict( testcase_name='binary_classification_basic', metrics_name='binary_classification_metrics', - logits=np.array([[0.5, 0.5], [0.5, 0.5], [0.5, 0.5], [0.5, 0.5], - [0.5, 0.5], [0.5, 0.5]]), + logits=np.array([ + [0.5, 0.5], + [0.5, 0.5], + [0.5, 0.5], + [0.5, 0.5], + [0.5, 0.5], + [0.5, 0.5], + ]), targets=np.array([[1, 0], [0, 1], [1, 0], [0, 1], [1, 0], [0, 1]]), - weights=np.array([[1., 1.], [1., 1.], [1., 1.], [1., 1.], [1., 1.], - [1., 1.]]), + weights=np.array([ + [1.0, 1.0], + [1.0, 1.0], + [1.0, 1.0], + [1.0, 1.0], + [1.0, 1.0], + [1.0, 1.0], + ]), test_metric_names=['ce_loss', 'average_precision', 'auc_roc'], - test_metric_vals=[0.724077, 0.5, 0.5]), + test_metric_vals=[0.724077, 0.5, 0.5], + ), dict( testcase_name='binary_classification_no_weights', metrics_name='binary_classification_metrics', - logits=np.array([[0.5, 0.5], [0.5, 0.5], [0.5, 0.5], [0.5, 0.5], - [0.5, 0.5], [0.5, 0.5]]), + logits=np.array([ + [0.5, 0.5], + [0.5, 0.5], + [0.5, 0.5], + [0.5, 0.5], + [0.5, 0.5], + [0.5, 0.5], + ]), targets=np.array([[1, 0], [0, 1], [1, 0], [0, 1], [1, 0], [0, 1]]), weights=None, test_metric_names=['ce_loss', 'average_precision', 'auc_roc'], - test_metric_vals=[1.448154, 0.5, 0.5]), + test_metric_vals=[1.448154, 0.5, 0.5], + ), dict( testcase_name='binary_classification_zero_weights', metrics_name='binary_classification_metrics', - logits=np.array([[100, 0.5], [100, 0.5], [0.3, 0.7], [0.5, 0.15], - [0.05, 0.5], [0.9, 0.5]]), + logits=np.array([ + [100, 0.5], + [100, 0.5], + [0.3, 0.7], + [0.5, 0.15], + [0.05, 0.5], + [0.9, 0.5], + ]), targets=np.array([[1, 0], [0, 1], [1, 0], [0, 1], [1, 0], [0, 1]]), - weights=np.array([[0., 1.], [0., 1.], [1., 1.], [1., 1.], [1., 1.], - [1., 1.]]), + weights=np.array([ + [0.0, 1.0], + [0.0, 1.0], + [1.0, 1.0], + [1.0, 1.0], + [1.0, 1.0], + [1.0, 1.0], + ]), test_metric_names=['ce_loss', 'average_precision', 'auc_roc'], - test_metric_vals=[0.8058497, 0.433333, 0.22222]), + test_metric_vals=[0.8058497, 0.433333, 0.22222], + ), dict( testcase_name='binary_classification_1d_weights', metrics_name='binary_classification_metrics', - logits=np.array([[100, 0.5], [0.5, 0.5], [0.5, 0.5], [0.5, 100], - [0.5, 0.5], [0.5, 0.5]]), + logits=np.array([ + [100, 0.5], + [0.5, 0.5], + [0.5, 0.5], + [0.5, 100], + [0.5, 0.5], + [0.5, 0.5], + ]), targets=np.array([[1, 0], [0, 1], [1, 0], [0, 1], [1, 0], [0, 1]]), - weights=np.array([0., 1., 1., 0., 1., 1.,]), + weights=np.array([ + 0.0, + 1.0, + 1.0, + 0.0, + 1.0, + 1.0, + ]), test_metric_names=['ce_loss', 'average_precision', 'auc_roc'], - test_metric_vals=[1.448154, 0.5, 0.5]), + test_metric_vals=[1.448154, 0.5, 0.5], + ), dict( testcase_name='binary_autoencoder_2d_weights', metrics_name='binary_autoencoder_metrics', - logits=np.array([[0.5, 100], [0.5, 100], [0.5, 0.5], [0.5, 0.5], - [0.5, 0.5], [0.5, 0.5]]), + logits=np.array([ + [0.5, 100], + [0.5, 100], + [0.5, 0.5], + [0.5, 0.5], + [0.5, 0.5], + [0.5, 0.5], + ]), targets=np.array([[1, 0], [0, 1], [1, 0], [0, 1], [1, 0], [0, 1]]), - weights=np.array([[1., 0.], [1., 0.], [1., 1.], [1., 1.], [1., 1.], - [1., 1.]]), + weights=np.array([ + [1.0, 0.0], + [1.0, 0.0], + [1.0, 1.0], + [1.0, 1.0], + [1.0, 1.0], + [1.0, 1.0], + ]), test_metric_names=['sigmoid_mean_squared_error'], - test_metric_vals=[0.26499629]), + test_metric_vals=[0.26499629], + ), dict( testcase_name='binary_autoencoder_1d_weights', metrics_name='binary_autoencoder_metrics', - logits=np.array([[0.5, 100], [0.5, 0.5], [0.5, 100], [0.5, 0.5], - [0.5, 0.5], [0.5, 0.5]]), + logits=np.array([ + [0.5, 100], + [0.5, 0.5], + [0.5, 100], + [0.5, 0.5], + [0.5, 0.5], + [0.5, 0.5], + ]), targets=np.array([[1, 0], [0, 1], [1, 0], [0, 1], [1, 0], [0, 1]]), - weights=np.array([0., 1., 0., 1., 1., 1.,]), + weights=np.array([ + 0.0, + 1.0, + 0.0, + 1.0, + 1.0, + 1.0, + ]), test_metric_names=['sigmoid_mean_squared_error'], - test_metric_vals=[0.52999258]), + test_metric_vals=[0.52999258], + ), dict( testcase_name='binary_autoencoder_no_weights', metrics_name='binary_autoencoder_metrics', - logits=np.array([[0.5, 0.5], [0.5, 0.5], [0.5, 0.5], [0.5, 0.5], - [0.5, 0.5], [0.5, 0.5]]), + logits=np.array([ + [0.5, 0.5], + [0.5, 0.5], + [0.5, 0.5], + [0.5, 0.5], + [0.5, 0.5], + [0.5, 0.5], + ]), targets=np.array([[1, 0], [0, 1], [1, 0], [0, 1], [1, 0], [0, 1]]), weights=None, test_metric_names=['sigmoid_mean_squared_error'], - test_metric_vals=[0.5299926]), + test_metric_vals=[0.5299926], + ), dict( testcase_name='binary_classification_metrics_ogbg_map', metrics_name='binary_classification_metrics_ogbg_map', - logits=np.array([[100, 0.5], [0.75, 0.25], [0.15, -0.95], [0.5, 100], - [0.15, 0.5], [0.5, 0.05], [-7.0, 8.1], [-7.0, 8.1]]), + logits=np.array([ + [100, 0.5], + [0.75, 0.25], + [0.15, -0.95], + [0.5, 100], + [0.15, 0.5], + [0.5, 0.05], + [-7.0, 8.1], + [-7.0, 8.1], + ]), targets=np.array( - [[1, 0], [1, 0], [1, 0], [0, 1], [1, 0], [0, 1], [0, 1], [0, 1]]), - weights=np.array([0., 1., 1., 0., 1., 1., 1., 0.]), + [[1, 0], [1, 0], [1, 0], [0, 1], [1, 0], [0, 1], [0, 1], [0, 1]] + ), + weights=np.array([0.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0, 0.0]), test_metric_names=['average_precision', 'ce_loss'], - test_metric_vals=[0.791666, 1.0799019]), + test_metric_vals=[0.791666, 1.0799019], + ), ) - def test_evaluate(self, metrics_name, logits, targets, weights, - test_metric_names, test_metric_vals): + def test_evaluate( + self, + metrics_name, + logits, + targets, + weights, + test_metric_names, + test_metric_vals, + ): """Test metrics merging and evaluation including zero weights.""" def mock_evaluate_batch(params, batch_stats, batch): @@ -924,18 +1038,18 @@ def mock_evaluate_batch(params, batch_stats, batch): return metrics_bundle.single_from_model_output( logits=batch.get('logits'), targets=batch.get('targets'), - weights=batch.get('weights')) + weights=batch.get('weights'), + ) logits = np.split(logits, 2) targets = np.split(targets, 2) weights = np.split(weights, 2) if weights is not None else [None, None] # pylint: disable=g-complex-comprehension - batch_iter = [{ - 'logits': ls, - 'targets': ts, - 'weights': ws - } for ls, ts, ws in zip(logits, targets, weights)] + batch_iter = [ + {'logits': ls, 'targets': ts, 'weights': ws} + for ls, ts, ws in zip(logits, targets, weights) + ] mesh_shape = (jax.device_count(),) mesh = jax.sharding.Mesh( @@ -982,11 +1096,13 @@ def test_early_stopping(self, min_steps): 'valid_size': 96, 'test_size': 80, } - input_pipeline_hps = config_dict.ConfigDict(dict( - num_tf_data_prefetches=-1, - num_device_prefetches=0, - num_tf_data_map_parallel_calls=-1, - )) + input_pipeline_hps = config_dict.ConfigDict( + dict( + num_tf_data_prefetches=-1, + num_device_prefetches=0, + num_tf_data_map_parallel_calls=-1, + ) + ) hps = hyperparameters.build_hparams( model_name, initializer_name, @@ -1005,25 +1121,34 @@ def as_dataset(self, *args, **kwargs): # pylint: disable=g-long-lambda,g-complex-comprehension return tf.data.Dataset.from_generator( - lambda: ({ - 'image': np.ones(shape=(28, 28, 1), dtype=np.uint8), - 'label': 9, - } for i in range(num_examples)), + lambda: ( + { + 'image': np.ones(shape=(28, 28, 1), dtype=np.uint8), + 'label': 9, + } + for i in range(num_examples) + ), output_types=self.info.features.dtype, output_shapes=self.info.features.shape, ) # This will override the tfds.load(mnist) call to return 100 fake samples. with tfds.testing.mock_data( - as_dataset_fn=as_dataset, num_examples=num_examples): + as_dataset_fn=as_dataset, num_examples=num_examples + ): dataset = dataset_builder( shuffle_rng=jax.random.PRNGKey(0), batch_size=hps.batch_size, eval_batch_size=eval_batch_size, - hps=hps) + hps=hps, + ) - model = model_cls(hps, datasets.get_dataset_meta_data(dataset_name), - loss_name, metrics_name) + model = model_cls( + hps, + datasets.get_dataset_meta_data(dataset_name), + loss_name, + metrics_name, + ) num_train_steps = 40 early_stopping_target_name = 'test/ce_loss' @@ -1067,10 +1192,12 @@ def as_dataset(self, *args, **kwargs): if not min_steps: self.assertGreater( epoch_reports[-2][early_stopping_target_name], - early_stopping_target_value) + early_stopping_target_value, + ) self.assertLess( epoch_reports[-1][early_stopping_target_name], - early_stopping_target_value) + early_stopping_target_value, + ) if __name__ == '__main__': diff --git a/init2winit/trainer_lib/trainer_utils.py b/init2winit/trainer_lib/trainer_utils.py index 8237ad2c..d175d4b9 100644 --- a/init2winit/trainer_lib/trainer_utils.py +++ b/init2winit/trainer_lib/trainer_utils.py @@ -286,6 +286,7 @@ def eval_metrics( ) # Metrics are None if the dataset doesn't have that split if split_metrics is not None: - metrics = _merge_and_apply_prefix(metrics, split_metrics, - (split_name + '/')) + metrics = _merge_and_apply_prefix( + metrics, split_metrics, (split_name + '/') + ) return metrics diff --git a/init2winit/trainer_lib/trainers.py b/init2winit/trainer_lib/trainers.py index ae75177a..65983d13 100644 --- a/init2winit/trainer_lib/trainers.py +++ b/init2winit/trainer_lib/trainers.py @@ -17,7 +17,6 @@ from init2winit.trainer_lib import trainer - _ALL_TRAINERS = { 'standard': trainer.Trainer, } diff --git a/init2winit/trainer_lib/training_algorithm.py b/init2winit/trainer_lib/training_algorithm.py index f509b1f0..64f1246c 100644 --- a/init2winit/trainer_lib/training_algorithm.py +++ b/init2winit/trainer_lib/training_algorithm.py @@ -28,7 +28,6 @@ from ml_collections.config_dict import config_dict import optax - _GRAD_CLIP_EPS = 1e-6 diff --git a/init2winit/training_metrics_grabber.py b/init2winit/training_metrics_grabber.py index 8e0335e8..c9136d19 100644 --- a/init2winit/training_metrics_grabber.py +++ b/init2winit/training_metrics_grabber.py @@ -101,13 +101,13 @@ def make_training_metrics(num_train_steps, hps, **config_overrides): where each key is a field name in optstate_normsq_fields, and each value is a jnp array of length num_time_steps containing the time series of the normsq of this optstate field. - - optstate_sumsq_param_wise_fields (list of str): record the sum of squares - of each of these fields in the optimizer state parameter wise. - If this list is non-empty, the metrics state will have a field - "optstate_sumsq_param_wise" which is a dict where each key is a field + - optstate_sumsq_param_wise_fields (list of str): record the sum of squares + of each of these fields in the optimizer state parameter wise. + If this list is non-empty, the metrics state will have a field + "optstate_sumsq_param_wise" which is a dict where each key is a field name in optstate_sum_fields, and each value is a pytree of the same type - as params but with each leaf a jnp array of length num_time_steps - containing the time series of the sum of this optstate field for this + as params but with each leaf a jnp array of length num_time_steps + containing the time series of the sum of this optstate field for this parameter. - optstate_sum_fields (list of str): record the sum of each of these fields in the optimizer state. If this list is non-empty, the metrics state @@ -115,13 +115,13 @@ def make_training_metrics(num_train_steps, hps, **config_overrides): field name in optstate_sum_fields, and each value is a jnp array of length num_time_steps containing the time series of the sum of this optstate field. - - optstate_sum_param_wise_fields (list of str): record the sum of each of - these fields in the optimizer state parameter wise. If this list is - non-empty, the metrics state will have a field - "optstate_sumsq_param_wise" which is a dict where each key is a field + - optstate_sum_param_wise_fields (list of str): record the sum of each of + these fields in the optimizer state parameter wise. If this list is + non-empty, the metrics state will have a field + "optstate_sumsq_param_wise" which is a dict where each key is a field name in optstate_sum_fields, and each value is a pytree of the same type - as params but with each leaf a jnp array of length num_time_steps - containing the time series of the sum of this optstate field for this + as params but with each leaf a jnp array of length num_time_steps + containing the time series of the sum of this optstate field for this parameter. - enable_preconditioner_normsq (bool): if true, the metrics state will have a field "preconditioner_normsq" which is a jnp array of length @@ -178,22 +178,26 @@ def init_fn(params, batch_stats): metrics_state['train_cost'] = jnp.zeros(num_train_steps) if config['enable_param_norms']: metrics_state['param_norms'] = jax.tree.map( - lambda x: jnp.zeros(num_train_steps), params) + lambda x: jnp.zeros(num_train_steps), params + ) if config['enable_batch_stats_norm']: metrics_state['batch_stats_norm'] = jnp.zeros(num_train_steps) if config['enable_all_batch_stats_norms']: metrics_state['all_batch_stats_norms'] = jax.tree.map( - lambda x: jnp.zeros(num_train_steps), batch_stats) + lambda x: jnp.zeros(num_train_steps), batch_stats + ) if config['enable_gradient_norm']: metrics_state['gradient_norm'] = jnp.zeros(num_train_steps) if config['enable_all_gradient_norms']: metrics_state['all_gradient_norms'] = jax.tree.map( - lambda x: jnp.zeros(num_train_steps), params) + lambda x: jnp.zeros(num_train_steps), params + ) if config['enable_update_norm']: metrics_state['update_norm'] = jnp.zeros(num_train_steps) if config['enable_update_norms']: metrics_state['update_norms'] = jax.tree.map( - lambda x: jnp.zeros(num_train_steps), params) + lambda x: jnp.zeros(num_train_steps), params + ) if config['enable_ema']: metrics_state['grad_ema'] = jax.tree.map(jnp.zeros_like, params) metrics_state['grad_sq_ema'] = jax.tree.map(jnp.zeros_like, params) @@ -225,13 +229,23 @@ def init_fn(params, batch_stats): metrics_state['semip_grad_normsq'] = jnp.zeros(num_train_steps) if config['enable_grafting_norms']: metrics_state['mag_norms'] = jax.tree.map( - lambda x: jnp.zeros(num_train_steps), params) + lambda x: jnp.zeros(num_train_steps), params + ) metrics_state['dir_norms'] = jax.tree.map( - lambda x: jnp.zeros(num_train_steps), params) + lambda x: jnp.zeros(num_train_steps), params + ) return metrics_state - def update_fn(metrics_state, step, train_cost, grad, old_params, new_params, - optimizer_state, batch_stats): + def update_fn( + metrics_state, + step, + train_cost, + grad, + old_params, + new_params, + optimizer_state, + batch_stats, + ): """Update the training metrics state. Args: @@ -279,58 +293,83 @@ def update_fn(metrics_state, step, train_cost, grad, old_params, new_params, batch_stats, ) - def _update_fn(metrics_state, step, train_cost, grad, old_params, new_params, - optimizer_state, batch_stats): + def _update_fn( + metrics_state, + step, + train_cost, + grad, + old_params, + new_params, + optimizer_state, + batch_stats, + ): param_norm = jax.tree.map(_compute_leaf_norms, old_params) grad_norm = jax.tree.map(_compute_leaf_norms, grad) batch_stats_norm = jax.tree.map(_compute_leaf_norms, batch_stats) - if (config['enable_update_norm'] or config['enable_update_norms'] or - config['enable_ema']): + if ( + config['enable_update_norm'] + or config['enable_update_norms'] + or config['enable_ema'] + ): update = jax.tree.map(lambda x, y: x - y, old_params, new_params) else: update = None next_metrics_state = {} - next_metrics_state['param_norm'] = metrics_state['param_norm'].at[ - step].set(total_tree_norm_l2(param_norm)) + next_metrics_state['param_norm'] = ( + metrics_state['param_norm'].at[step].set(total_tree_norm_l2(param_norm)) + ) if config['enable_train_cost']: - next_metrics_state['train_cost'] = metrics_state['train_cost'].at[ - step].set(train_cost) + next_metrics_state['train_cost'] = ( + metrics_state['train_cost'].at[step].set(train_cost) + ) if config['enable_param_norms']: next_metrics_state['param_norms'] = _set_pytree_idx( - metrics_state['param_norms'], param_norm, step) + metrics_state['param_norms'], param_norm, step + ) if config['enable_batch_stats_norm']: - next_metrics_state['batch_stats_norm'] = metrics_state[ - 'batch_stats_norm'].at[step].set( - total_tree_norm_l2(batch_stats)) + next_metrics_state['batch_stats_norm'] = ( + metrics_state['batch_stats_norm'] + .at[step] + .set(total_tree_norm_l2(batch_stats)) + ) if config['enable_all_batch_stats_norms']: next_metrics_state['all_batch_stats_norms'] = _set_pytree_idx( - metrics_state['all_batch_stats_norms'], batch_stats_norm, step) + metrics_state['all_batch_stats_norms'], batch_stats_norm, step + ) if config['enable_gradient_norm']: - next_metrics_state['gradient_norm'] = metrics_state['gradient_norm'].at[ - step].set(total_tree_norm_l2(grad)) + next_metrics_state['gradient_norm'] = ( + metrics_state['gradient_norm'].at[step].set(total_tree_norm_l2(grad)) + ) if config['enable_all_gradient_norms']: next_metrics_state['all_gradient_norms'] = _set_pytree_idx( - metrics_state['all_gradient_norms'], grad_norm, step) + metrics_state['all_gradient_norms'], grad_norm, step + ) if config['enable_update_norm']: - next_metrics_state['update_norm'] = metrics_state['update_norm'].at[ - step].set(total_tree_norm_l2(update)) + next_metrics_state['update_norm'] = ( + metrics_state['update_norm'].at[step].set(total_tree_norm_l2(update)) + ) if config['enable_update_norms']: update_norm = jax.tree.map(_compute_leaf_norms, update) next_metrics_state['update_norms'] = _set_pytree_idx( - metrics_state['update_norms'], update_norm, step) + metrics_state['update_norms'], update_norm, step + ) if config['enable_ema']: beta = config['ema_beta'] grad_sq = jax.tree.map(jnp.square, grad) update_sq = jax.tree.map(jnp.square, update) next_metrics_state['grad_ema'] = _advance_ema( - metrics_state['grad_ema'], grad, beta) + metrics_state['grad_ema'], grad, beta + ) next_metrics_state['grad_sq_ema'] = _advance_ema( - metrics_state['grad_sq_ema'], grad_sq, beta) + metrics_state['grad_sq_ema'], grad_sq, beta + ) next_metrics_state['update_ema'] = _advance_ema( - metrics_state['update_ema'], update, beta) + metrics_state['update_ema'], update, beta + ) next_metrics_state['update_sq_ema'] = _advance_ema( - metrics_state['update_sq_ema'], update_sq, beta) + metrics_state['update_sq_ema'], update_sq, beta + ) if config['optstate_sumsq_fields']: next_metrics_state['optstate_sumsq'] = {} for field_name in config['optstate_sumsq_fields']: @@ -338,8 +377,11 @@ def _update_fn(metrics_state, step, train_cost, grad, old_params, new_params, if field is None: raise ValueError('optimizer state has no field {}'.format(field_name)) field_normsq = total_tree_norm_sql2(field) - next_metrics_state['optstate_sumsq'][field_name] = metrics_state[ - 'optstate_sumsq'][field_name].at[step].set(field_normsq) + next_metrics_state['optstate_sumsq'][field_name] = ( + metrics_state['optstate_sumsq'][field_name] + .at[step] + .set(field_normsq) + ) if config['optstate_sumsq_param_wise_fields']: next_metrics_state['optstate_sumsq_param_wise'] = {} for field_name in config['optstate_sumsq_param_wise_fields']: @@ -362,8 +404,9 @@ def _update_fn(metrics_state, step, train_cost, grad, old_params, new_params, if field is None: raise ValueError('optimizer state has no field {}'.format(field_name)) field_normsq = total_tree_sum(field) - next_metrics_state['optstate_sum'][field_name] = metrics_state[ - 'optstate_sum'][field_name].at[step].set(field_normsq) + next_metrics_state['optstate_sum'][field_name] = ( + metrics_state['optstate_sum'][field_name].at[step].set(field_normsq) + ) if config['optstate_sum_param_wise_fields']: next_metrics_state['optstate_sum_param_wise'] = {} for field_name in config['optstate_sum_param_wise_fields']: @@ -378,34 +421,46 @@ def _update_fn(metrics_state, step, train_cost, grad, old_params, new_params, step, ) ) - if (config['enable_preconditioner_normsq'] or - config['enable_semip_grad_normsq']): + if ( + config['enable_preconditioner_normsq'] + or config['enable_semip_grad_normsq'] + ): preconditioner = freeze( - make_diag_preconditioner(hps['optimizer'], hps['opt_hparams'], - optimizer_state, ConfigDict({}))) + make_diag_preconditioner( + hps['optimizer'], + hps['opt_hparams'], + optimizer_state, + ConfigDict({}), + ) + ) if config['enable_preconditioner_normsq']: normsq = total_tree_norm_sql2(preconditioner) - next_metrics_state['preconditioner_normsq'] = metrics_state[ - 'preconditioner_normsq'].at[step].set(normsq) + next_metrics_state['preconditioner_normsq'] = ( + metrics_state['preconditioner_normsq'].at[step].set(normsq) + ) if config['enable_semip_grad_normsq']: - semip_grad = jax.tree.map(lambda g, p: g / (p**0.5), - grad, preconditioner) + semip_grad = jax.tree.map( + lambda g, p: g / (p**0.5), grad, preconditioner + ) semip_grad_normsq = total_tree_norm_sql2(semip_grad) - next_metrics_state['semip_grad_normsq'] = metrics_state[ - 'semip_grad_normsq'].at[step].set(semip_grad_normsq) + next_metrics_state['semip_grad_normsq'] = ( + metrics_state['semip_grad_normsq'].at[step].set(semip_grad_normsq) + ) if config['enable_grafting_norms']: mag_norm = optimizer_utils.extract_field(optimizer_state, 'mag_norm') if mag_norm is None: raise ValueError('optimizer state has no field {}'.format('mag_norm')) mag_norm = freeze(mag_norm) next_metrics_state['mag_norms'] = _set_pytree_idx( - metrics_state['mag_norms'], mag_norm, step) + metrics_state['mag_norms'], mag_norm, step + ) dir_norm = optimizer_utils.extract_field(optimizer_state, 'dir_norm') if dir_norm is None: raise ValueError('optimizer state has no field {}'.format('dir_norm')) dir_norm = freeze(dir_norm) next_metrics_state['dir_norms'] = _set_pytree_idx( - metrics_state['dir_norms'], dir_norm, step) + metrics_state['dir_norms'], dir_norm, step + ) return next_metrics_state @@ -429,17 +484,19 @@ def summarize_fn(metrics_state): def compute_var(first_moment, second_moment): return (second_moment - first_moment**2).sum() - summary['grad_var'] = jax.tree.map(compute_var, - metrics_state['grad_ema'], - metrics_state['grad_sq_ema']) + summary['grad_var'] = jax.tree.map( + compute_var, metrics_state['grad_ema'], metrics_state['grad_sq_ema'] + ) - summary['update_var'] = jax.tree.map(compute_var, - metrics_state['update_ema'], - metrics_state['update_sq_ema']) + summary['update_var'] = jax.tree.map( + compute_var, + metrics_state['update_ema'], + metrics_state['update_sq_ema'], + ) - summary['update_ratio'] = jax.tree.map(operator.truediv, - summary['update_var'], - metrics_state['param_norm']) + summary['update_ratio'] = jax.tree.map( + operator.truediv, summary['update_var'], metrics_state['param_norm'] + ) # This dict will map from "summary key" to "flattened pytree of same shape # as params." @@ -457,9 +514,9 @@ def _map_values(f, dictionary): def _advance_ema(cur_ema, new_val, beta): """Advance an exponential moving average.""" - return jax.tree.map(lambda cur, new: beta * cur + (1 - beta) * new, - cur_ema, - new_val) + return jax.tree.map( + lambda cur, new: beta * cur + (1 - beta) * new, cur_ema, new_val + ) def _compute_leaf_norms(pytree): @@ -478,8 +535,9 @@ def _set_pytree_idx(pytree_of_arrs, new_pytree, idx): Returns: a pytree where we set the "idx" index of each leaf in pytree_of_arrs to the corresponding leaf in new_pytree. - """ + def set_arr(arr, new_value): return arr.at[idx].set(new_value) + return jax.tree.map(set_arr, pytree_of_arrs, new_pytree) diff --git a/init2winit/utils.py b/init2winit/utils.py index a82b7bd1..c9dad95a 100644 --- a/init2winit/utils.py +++ b/init2winit/utils.py @@ -14,6 +14,7 @@ # limitations under the License. """Utility functions for logging and recording training metrics.""" + import concurrent.futures import copy import functools @@ -35,7 +36,6 @@ import pandas as pd from tensorflow.io import gfile - exists = gfile.exists @@ -107,16 +107,16 @@ def set_up_loggers(train_dir, xm_work_unit=None): """Creates a logger for eval metrics as well as initialization metrics.""" csv_path = os.path.join(train_dir, 'measurements.csv') metrics_logger = MetricLogger( - csv_path=csv_path, - xm_work_unit=xm_work_unit, - events_dir=train_dir) + csv_path=csv_path, xm_work_unit=xm_work_unit, events_dir=train_dir + ) init_csv_path = os.path.join(train_dir, 'init_measurements.csv') init_json_path = os.path.join(train_dir, 'init_scalars.json') init_logger = MetricLogger( csv_path=init_csv_path, json_path=init_json_path, - xm_work_unit=xm_work_unit) + xm_work_unit=xm_work_unit, + ) return metrics_logger, init_logger @@ -135,13 +135,16 @@ def run_in_parallel(function, list_of_kwargs_to_function, num_workers): """ if num_workers < 1: raise ValueError( - 'Number of workers must be greater than 0. Was {}'.format(num_workers)) + 'Number of workers must be greater than 0. Was {}'.format(num_workers) + ) with concurrent.futures.ThreadPoolExecutor(num_workers) as executor: futures = [] logging.info( - 'Adding %d jobs to process pool to run in %d parallel ' - 'threads.', len(list_of_kwargs_to_function), num_workers) + 'Adding %d jobs to process pool to run in %d parallel threads.', + len(list_of_kwargs_to_function), + num_workers, + ) for kwargs in list_of_kwargs_to_function: f = executor.submit(function, **kwargs) @@ -171,7 +174,7 @@ def add_log_file(logfile): class PytreeMetricLogger(object): """Used to log pytree metrics during training. - + This class is used to log pytree metrics during training. It is similar to MetricLogger, but it is designed to log pytree metrics. """ @@ -192,7 +195,8 @@ def __init__( self._pytree_path, options=ocp.CheckpointManagerOptions( file_options=orbax_file_options, - max_to_keep=1, create=True, + max_to_keep=1, + create=True, ), ) @@ -240,11 +244,9 @@ class MetricLogger(object): the wrong time. """ - def __init__(self, - csv_path='', - json_path='', - events_dir=None, - **logger_kwargs): + def __init__( + self, csv_path='', json_path='', events_dir=None, **logger_kwargs + ): """Create a recorder for metrics, as CSV or JSON. @@ -264,7 +266,8 @@ def __init__(self, if len(logger_kwargs.keys()) > 1 or 'xm_work_unit' not in logger_kwargs: raise ValueError( 'The only logger_kwarg that should be passed to MetricLogger is ' - 'xm_work_unit.') + 'xm_work_unit.' + ) self._xm_work_unit = logger_kwargs['xm_work_unit'] else: self._xm_work_unit = None @@ -304,7 +307,8 @@ def _read_and_write_csv(): for name, value in metrics.items(): if name not in self._measurements: self._measurements[name] = self._xm_work_unit.get_measurement_series( - label=name) + label=name + ) try: self._measurements[name].create_measurement( objective_value=reduce_to_scalar(value), @@ -316,7 +320,8 @@ def _read_and_write_csv(): if self._tb_metric_writer: self._tb_metric_writer.write_scalars( - step=metrics['global_step'], scalars=metrics) + step=metrics['global_step'], scalars=metrics + ) # This gives a 1-2% slowdown in steps_per_sec on cifar-10 with batch # size 512. We could only flush at the end of training to optimize this. self._tb_metric_writer.flush() @@ -361,7 +366,8 @@ def log_pytree_shape_and_statistics(pytree, json_path=None): shape_dict = jax.tree.map(_summary_str, pytree) absl_logging.info(flax.core.pretty_repr(shape_dict)) total_params = jax.tree_util.tree_reduce( - operator.add, jax.tree.map(lambda x: x.size, pytree)) + operator.add, jax.tree.map(lambda x: x.size, pytree) + ) absl_logging.info('Total params: %d', total_params) @@ -372,18 +378,27 @@ def tabulate_model(model, hps): model: init2winit BaseModel hps: ml_collections.config_dict.config_dict.ConfigDict """ - tabulate_fn = nn.tabulate(model.flax_module, jax.random.PRNGKey(0), - console_kwargs={'force_terminal': False, - 'force_jupyter': False, - 'width': 240}, - ) + tabulate_fn = nn.tabulate( + model.flax_module, + jax.random.PRNGKey(0), + console_kwargs={ + 'force_terminal': False, + 'force_jupyter': False, + 'width': 240, + }, + ) fake_inputs_hps = copy.copy(hps) fake_inputs_hps.batch_size = 2 fake_inputs = model.get_fake_inputs(fake_inputs_hps) # Currently only two models implement the get_fake_batch. # Only attempt to log if we get a valid fake_input_batch. if fake_inputs: - absl_logging.info(tabulate_fn(*fake_inputs, train=False,)) + absl_logging.info( + tabulate_fn( + *fake_inputs, + train=False, + ) + ) def edit_distance(source, target): @@ -429,7 +444,8 @@ def edit_distance(source, target): distance[i][j] = 1 + min( distance[i][j - 1], # Insert distance[i - 1][j], # Remove - distance[i - 1][j - 1]) # Replace + distance[i - 1][j - 1], + ) # Replace return distance[num_source_words][num_target_words]