From 367707343b118564f6a86b651d37bc79e8a3d9a1 Mon Sep 17 00:00:00 2001 From: Vincent Roulet Date: Fri, 5 Jun 2026 11:39:08 -0700 Subject: [PATCH] Add grad_clip to static_argnames in training algorithms PiperOrigin-RevId: 927403292 --- init2winit/trainer_lib/test_trainer.py | 6 +++++- init2winit/trainer_lib/training_algorithm.py | 1 + 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/init2winit/trainer_lib/test_trainer.py b/init2winit/trainer_lib/test_trainer.py index d0be571e..8700edf3 100644 --- a/init2winit/trainer_lib/test_trainer.py +++ b/init2winit/trainer_lib/test_trainer.py @@ -652,7 +652,8 @@ def test_text_model_trainer(self): self.assertEqual(set(df.columns.values), set(get_column_names())) - def test_trainer(self): + @parameterized.parameters(None, 1.0) + def test_trainer(self, grad_clip): """Test training for two epochs on MNIST with a small model.""" rng = jax.random.PRNGKey(0) @@ -678,6 +679,9 @@ def test_trainer(self): 'valid_size': 96, 'test_size': 80, } + if grad_clip is not None: + hparam_overrides['opt_hparams.grad_clip'] = grad_clip + input_pipeline_hps = config_dict.ConfigDict(dict( num_tf_data_prefetches=-1, num_device_prefetches=0, diff --git a/init2winit/trainer_lib/training_algorithm.py b/init2winit/trainer_lib/training_algorithm.py index f509b1f0..4f07383b 100644 --- a/init2winit/trainer_lib/training_algorithm.py +++ b/init2winit/trainer_lib/training_algorithm.py @@ -902,6 +902,7 @@ def update_params( static_argnames=( 'training_cost_fn', 'optimizer_update_fn', + 'grad_clip', ), donate_argnums=(0, 1, 2), )