From f81f79e72e944ecccb043f33ba1404c22483d10e Mon Sep 17 00:00:00 2001 From: Jacob Mathias Schreiner Date: Tue, 4 Mar 2025 13:36:10 +0100 Subject: [PATCH 01/17] add steps to args --- neural_lam/train_model.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/neural_lam/train_model.py b/neural_lam/train_model.py index e8b402d52..f1b050a5c 100644 --- a/neural_lam/train_model.py +++ b/neural_lam/train_model.py @@ -72,6 +72,12 @@ def main(input_args=None): default=200, help="upper epoch limit (default: 200)", ) + parser.add_argument( + "--steps", + type=int, + default=-1, + help="upper step limit (default: None)", + ) parser.add_argument( "--batch_size", type=int, default=4, help="batch size (default: 4)" ) @@ -308,6 +314,7 @@ def main(input_args=None): ) trainer = pl.Trainer( max_epochs=args.epochs, + max_steps=args.steps, deterministic=True, strategy="ddp", accelerator=device_name, From 9358f66f78936833555ce534f1aa95cd963b1268 Mon Sep 17 00:00:00 2001 From: Jacob Mathias Schreiner Date: Tue, 4 Mar 2025 15:30:42 +0100 Subject: [PATCH 02/17] first tests --- neural_lam/lr_scheduler.py | 82 +++++++++++++++++++++++++++++++++++ neural_lam/models/ar_model.py | 1 + tests/test_lr_scheduler.py | 25 +++++++++++ 3 files changed, 108 insertions(+) create mode 100644 neural_lam/lr_scheduler.py create mode 100644 tests/test_lr_scheduler.py diff --git a/neural_lam/lr_scheduler.py b/neural_lam/lr_scheduler.py new file mode 100644 index 000000000..d0da08728 --- /dev/null +++ b/neural_lam/lr_scheduler.py @@ -0,0 +1,82 @@ +# Standard library +import math + +# Third-party +import matplotlib.pyplot as plt +import torch + + +class WarmupCosineAnnealingLR(torch.optim.lr_scheduler.LRScheduler): + def __init__( + self, + optimizer, + total_steps, + warmup_steps=1000, + max_lr=0.001, + min_lr=0.00001, + ): + self.max_steps = max_steps + self.warmup_steps = warmup_steps + self.max_lr = max_lr + self.min_lr = min_lr + schedule = MattsSchedule( + total_steps=100, warmup_steps=55, min_lr=0, max_lr=1 + ) + super().__init__(optimizer) + + def get_lr(self): + self.base_lrs + lrs = [1 for group in self.optimizer.param_groups] + return lrs + + def warmup(self, step): + return step / self.warmup_steps * self.max + + def cosine_annealing(self, step): + if step > self.max_steps: + return self.min + + return self.min + 0.5 * (self.max_lr - self.min_lr) * ( + 1 + math.cos(math.pi * step / self.max_steps) + ) + + +class MattsSchedule: + def __init__( + self, + total_steps, + warmup_steps=1000, + min_lr=0.00001, + max_lr=0.001, + ): + self.max_steps = total_steps + self.warmup_steps = warmup_steps + self.annealing_steps = total_steps - warmup_steps + + self.max_lr = max_lr + self.min_lr = min_lr + + def warmup(self, step): + return step / self.warmup_steps * self.max_lr + + def cosine_annealing(self, step): + return self.min_lr + 0.5 * (self.max_lr - self.min_lr) * ( + 1 + math.cos(math.pi * step / self.annealing_steps) + ) + + def calculate_lr(self, step): + if step < self.warmup_steps: + lr = self.warmup(step) + elif step < self.max_steps: + lr = self.cosine_annealing(step - self.warmup_steps) + else: + lr = self.min_lr + return lr + + def get_lr(self, step): + __import__("pdb").set_trace() # TODO delme kj:w + + return [self.calculate_lr(step) for _ in optimizer.param_groups] + + def __len__(self): + return self.max_steps diff --git a/neural_lam/models/ar_model.py b/neural_lam/models/ar_model.py index f3769f194..0bec5722e 100644 --- a/neural_lam/models/ar_model.py +++ b/neural_lam/models/ar_model.py @@ -193,6 +193,7 @@ def configure_optimizers(self): opt = torch.optim.AdamW( self.parameters(), lr=self.args.lr, betas=(0.9, 0.95) ) + return opt @property diff --git a/tests/test_lr_scheduler.py b/tests/test_lr_scheduler.py new file mode 100644 index 000000000..b2f2e9012 --- /dev/null +++ b/tests/test_lr_scheduler.py @@ -0,0 +1,25 @@ +# Standard library +import warnings +from unittest.mock import MagicMock + +# Third-party +import pytest +import torch + +# First-party +from neural_lam import lr_scheduler + + +@pytest.fixture +def model(): + return torch.nn.Linear(1, 1) + + +@pytest.fixture +def optimizer(model): + return torch.optim.SGD(model.parameters(), lr=0.01) # Real optimizer + + +def test_warmup_cosine_annealing_can_instantiate(optimizer): + lrs = lr_scheduler.WarmupCosineAnnealingLR(optimizer, max_steps=1000) + __import__("pdb").set_trace() # TODO delme From 3b0f24e37d7bdb6a5c76d3bfb592cc1ccbd6fd37 Mon Sep 17 00:00:00 2001 From: Jacob Mathias Schreiner Date: Wed, 5 Mar 2025 13:46:50 +0100 Subject: [PATCH 03/17] add WarmupCosineAnnealingScheduler --- tests/test_lr_scheduler.py | 46 ++++++++++++++++++++++++++++++++------ 1 file changed, 39 insertions(+), 7 deletions(-) diff --git a/tests/test_lr_scheduler.py b/tests/test_lr_scheduler.py index b2f2e9012..0df79d482 100644 --- a/tests/test_lr_scheduler.py +++ b/tests/test_lr_scheduler.py @@ -1,8 +1,6 @@ -# Standard library -import warnings -from unittest.mock import MagicMock - # Third-party +import matplotlib.pyplot as plt +import numpy as np import pytest import torch @@ -17,9 +15,43 @@ def model(): @pytest.fixture def optimizer(model): - return torch.optim.SGD(model.parameters(), lr=0.01) # Real optimizer + return torch.optim.Adam(model.parameters()) # Real optimizer def test_warmup_cosine_annealing_can_instantiate(optimizer): - lrs = lr_scheduler.WarmupCosineAnnealingLR(optimizer, max_steps=1000) - __import__("pdb").set_trace() # TODO delme + min_factor = 0.001 + max_factor = 1 + warmup_steps = 10 + annealing_steps = 10 + initial_lr = optimizer.param_groups[0]["lr"] + + linear = lr_scheduler.WarmupCosineAnnealingLR( + optimizer, + min_factor=min_factor, + max_factor=max_factor, + annealing_steps=annealing_steps, + warmup_steps=warmup_steps, + ) + + lrs = [] + for _ in range(25): + lrs.append(optimizer.param_groups[0]["lr"]) + linear.step() + + expected_warmup_lr = np.linspace( + min_factor * initial_lr, + max_factor * initial_lr, + warmup_steps, + endpoint=False, + ) + warmup_lr = lrs[:warmup_steps] + assert np.allclose(warmup_lr, expected_warmup_lr) + + annealing_lr = lrs[warmup_steps : warmup_steps + annealing_steps] + expected_annealing_lr = min_factor * initial_lr + 0.5 * ( + max_factor * initial_lr - min_factor * initial_lr + ) * (1 + np.cos(np.pi * np.arange(annealing_steps) / annealing_steps)) + assert np.allclose(annealing_lr, expected_annealing_lr) + + last_lr = lrs[-1] + assert all(lrs[-5:] == np.ones(5) * last_lr) From 2480d8bc3e8084ed45f2b8abe511792ff0efe05e Mon Sep 17 00:00:00 2001 From: Jacob Mathias Schreiner Date: Wed, 5 Mar 2025 13:46:54 +0100 Subject: [PATCH 04/17] add test --- neural_lam/lr_scheduler.py | 100 ++++++++++++------------------------- 1 file changed, 33 insertions(+), 67 deletions(-) diff --git a/neural_lam/lr_scheduler.py b/neural_lam/lr_scheduler.py index d0da08728..2c192ada3 100644 --- a/neural_lam/lr_scheduler.py +++ b/neural_lam/lr_scheduler.py @@ -1,8 +1,4 @@ -# Standard library -import math - # Third-party -import matplotlib.pyplot as plt import torch @@ -10,73 +6,43 @@ class WarmupCosineAnnealingLR(torch.optim.lr_scheduler.LRScheduler): def __init__( self, optimizer, - total_steps, - warmup_steps=1000, - max_lr=0.001, - min_lr=0.00001, + warmup_steps=10, + annealing_steps=90, + max_factor=1.0, + min_factor=0.001, ): - self.max_steps = max_steps self.warmup_steps = warmup_steps - self.max_lr = max_lr - self.min_lr = min_lr - schedule = MattsSchedule( - total_steps=100, warmup_steps=55, min_lr=0, max_lr=1 - ) - super().__init__(optimizer) - - def get_lr(self): - self.base_lrs - lrs = [1 for group in self.optimizer.param_groups] - return lrs - - def warmup(self, step): - return step / self.warmup_steps * self.max - - def cosine_annealing(self, step): - if step > self.max_steps: - return self.min - - return self.min + 0.5 * (self.max_lr - self.min_lr) * ( - 1 + math.cos(math.pi * step / self.max_steps) + self.annealing_steps = annealing_steps + initial_learning_rate = optimizer.param_groups[0]["lr"] + + self.warmup_scheduler = torch.optim.lr_scheduler.LinearLR( + optimizer, + start_factor=min_factor, + end_factor=max_factor, + total_iters=warmup_steps, ) - -class MattsSchedule: - def __init__( - self, - total_steps, - warmup_steps=1000, - min_lr=0.00001, - max_lr=0.001, - ): - self.max_steps = total_steps - self.warmup_steps = warmup_steps - self.annealing_steps = total_steps - warmup_steps - - self.max_lr = max_lr - self.min_lr = min_lr - - def warmup(self, step): - return step / self.warmup_steps * self.max_lr - - def cosine_annealing(self, step): - return self.min_lr + 0.5 * (self.max_lr - self.min_lr) * ( - 1 + math.cos(math.pi * step / self.annealing_steps) + self.annealing_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR( + optimizer, + T_max=annealing_steps, + eta_min=min_factor * initial_learning_rate, ) - def calculate_lr(self, step): - if step < self.warmup_steps: - lr = self.warmup(step) - elif step < self.max_steps: - lr = self.cosine_annealing(step - self.warmup_steps) - else: - lr = self.min_lr - return lr - - def get_lr(self, step): - __import__("pdb").set_trace() # TODO delme kj:w - - return [self.calculate_lr(step) for _ in optimizer.param_groups] + super().__init__(optimizer) - def __len__(self): - return self.max_steps + def get_lr(self): + if self.last_epoch <= self.warmup_steps: + return self.warmup_scheduler.get_last_lr() + elif self.last_epoch <= self.warmup_steps + self.annealing_steps: + self.annealing_scheduler.step() + + return True + + def step(self): + if self._step_count == 0: + pass + elif self._step_count <= self.warmup_steps: + self.warmup_scheduler.step() + elif self._step_count <= self.warmup_steps + self.annealing_steps: + self.annealing_scheduler.step() + self._step_count += 1 From e0cb8fbb877e68ae9be4c7cdac8aee11f914e656 Mon Sep 17 00:00:00 2001 From: Jacob Mathias Schreiner Date: Wed, 5 Mar 2025 14:15:47 +0100 Subject: [PATCH 05/17] rename --- tests/test_lr_scheduler.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_lr_scheduler.py b/tests/test_lr_scheduler.py index 0df79d482..7af6defbf 100644 --- a/tests/test_lr_scheduler.py +++ b/tests/test_lr_scheduler.py @@ -19,13 +19,13 @@ def optimizer(model): def test_warmup_cosine_annealing_can_instantiate(optimizer): - min_factor = 0.001 + min_factor = 0.01 max_factor = 1 warmup_steps = 10 annealing_steps = 10 initial_lr = optimizer.param_groups[0]["lr"] - linear = lr_scheduler.WarmupCosineAnnealingLR( + scheduler = lr_scheduler.WarmupCosineAnnealingLR( optimizer, min_factor=min_factor, max_factor=max_factor, @@ -36,7 +36,7 @@ def test_warmup_cosine_annealing_can_instantiate(optimizer): lrs = [] for _ in range(25): lrs.append(optimizer.param_groups[0]["lr"]) - linear.step() + scheduler.step() expected_warmup_lr = np.linspace( min_factor * initial_lr, From 32dded5ff89b8d2f79b29f54726e08185a328e6b Mon Sep 17 00:00:00 2001 From: Jacob Mathias Schreiner Date: Wed, 5 Mar 2025 14:16:16 +0100 Subject: [PATCH 06/17] add plot of lrs with default values --- neural_lam/lr_scheduler.py | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/neural_lam/lr_scheduler.py b/neural_lam/lr_scheduler.py index 2c192ada3..7f582f1bb 100644 --- a/neural_lam/lr_scheduler.py +++ b/neural_lam/lr_scheduler.py @@ -46,3 +46,34 @@ def step(self): elif self._step_count <= self.warmup_steps + self.annealing_steps: self.annealing_scheduler.step() self._step_count += 1 + + +if __name__ == "__main__": + # Third-party + import matplotlib.pyplot as plt + + model = torch.nn.Linear(1, 1) + opt = torch.optim.Adam(model.parameters()) + scheduler = WarmupCosineAnnealingLR( + opt, warmup_steps=20, annealing_steps=100 + ) + + lrs = [] + for _ in range(150): + lrs.append(opt.param_groups[0]["lr"]) + scheduler.step() + + plt.plot(lrs) + plt.vlines(20, 0, max(lrs), colors="k", linestyles="dashed") + plt.vlines(120, 0, max(lrs), colors="k", linestyles="dashed") + plt.text(21, max(lrs) / 2, "warmup ended", fontsize=10, color="k") + plt.text(121, max(lrs) / 2, "annealing ended", fontsize=10, color="k") + + plt.hlines(max(lrs), 15, 25, colors="k", linestyles="dashed") + plt.text(26, max(lrs), f"{max(lrs):.2e}", fontsize=10, color="k") + plt.text(121, min(lrs), f"{min(lrs):.2e}", fontsize=10, color="k") + + plt.xlabel("Step") + plt.ylabel("Learning Rate") + + plt.show() From 0312f10e1d72b6bd6f669f1c81fa0ca6f56c56ec Mon Sep 17 00:00:00 2001 From: Jacob Mathias Schreiner Date: Wed, 5 Mar 2025 14:49:55 +0100 Subject: [PATCH 07/17] rename --- tests/test_lr_scheduler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_lr_scheduler.py b/tests/test_lr_scheduler.py index 7af6defbf..fbbd0c71e 100644 --- a/tests/test_lr_scheduler.py +++ b/tests/test_lr_scheduler.py @@ -15,10 +15,10 @@ def model(): @pytest.fixture def optimizer(model): - return torch.optim.Adam(model.parameters()) # Real optimizer + return torch.optim.Adam(model.parameters()) -def test_warmup_cosine_annealing_can_instantiate(optimizer): +def test_warmup_cosine_annealing_produces_expected_schedule(optimizer): min_factor = 0.01 max_factor = 1 warmup_steps = 10 From 5d96ccee000eb8b38780af74f24be52a377fc73f Mon Sep 17 00:00:00 2001 From: Jacob Mathias Schreiner Date: Wed, 5 Mar 2025 14:57:46 +0100 Subject: [PATCH 08/17] remove unused import --- tests/test_lr_scheduler.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_lr_scheduler.py b/tests/test_lr_scheduler.py index fbbd0c71e..65b78a9f1 100644 --- a/tests/test_lr_scheduler.py +++ b/tests/test_lr_scheduler.py @@ -1,5 +1,4 @@ # Third-party -import matplotlib.pyplot as plt import numpy as np import pytest import torch From dba25321a666ddfecb632c88b985eef60b24add1 Mon Sep 17 00:00:00 2001 From: Jacob Mathias Schreiner Date: Wed, 5 Mar 2025 15:04:05 +0100 Subject: [PATCH 09/17] use _step --- neural_lam/lr_scheduler.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/neural_lam/lr_scheduler.py b/neural_lam/lr_scheduler.py index 7f582f1bb..e504bf205 100644 --- a/neural_lam/lr_scheduler.py +++ b/neural_lam/lr_scheduler.py @@ -6,8 +6,8 @@ class WarmupCosineAnnealingLR(torch.optim.lr_scheduler.LRScheduler): def __init__( self, optimizer, - warmup_steps=10, - annealing_steps=90, + warmup_steps=1000, + annealing_steps=100000, max_factor=1.0, min_factor=0.001, ): @@ -31,9 +31,9 @@ def __init__( super().__init__(optimizer) def get_lr(self): - if self.last_epoch <= self.warmup_steps: + if self._step_count <= self.warmup_steps: return self.warmup_scheduler.get_last_lr() - elif self.last_epoch <= self.warmup_steps + self.annealing_steps: + elif self._step_count <= self.warmup_steps + self.annealing_steps: self.annealing_scheduler.step() return True @@ -49,6 +49,7 @@ def step(self): if __name__ == "__main__": + # Run this code to visualize the learning rate schedule # Third-party import matplotlib.pyplot as plt From df4fba8ca18d2fd69f81d99339f506fd456a473d Mon Sep 17 00:00:00 2001 From: Jacob Mathias Schreiner Date: Wed, 5 Mar 2025 15:06:29 +0100 Subject: [PATCH 10/17] add comment --- tests/test_lr_scheduler.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/test_lr_scheduler.py b/tests/test_lr_scheduler.py index 65b78a9f1..1a648a63d 100644 --- a/tests/test_lr_scheduler.py +++ b/tests/test_lr_scheduler.py @@ -47,6 +47,8 @@ def test_warmup_cosine_annealing_produces_expected_schedule(optimizer): assert np.allclose(warmup_lr, expected_warmup_lr) annealing_lr = lrs[warmup_steps : warmup_steps + annealing_steps] + + # Formula for the cosine annealing expected_annealing_lr = min_factor * initial_lr + 0.5 * ( max_factor * initial_lr - min_factor * initial_lr ) * (1 + np.cos(np.pi * np.arange(annealing_steps) / annealing_steps)) From ddbf4e65cb4ed676e678296da90fc28e896af965 Mon Sep 17 00:00:00 2001 From: Jacob Mathias Schreiner Date: Wed, 5 Mar 2025 15:22:49 +0100 Subject: [PATCH 11/17] empty From 5550f8300514bfaed73ee988a88b481c483069e5 Mon Sep 17 00:00:00 2001 From: Jacob Mathias Schreiner Date: Thu, 6 Mar 2025 10:18:37 +0100 Subject: [PATCH 12/17] Trigger From 0f6c26a7cb6cdc42212f5d985b19617461c32cf5 Mon Sep 17 00:00:00 2001 From: matschreiner Date: Mon, 10 Mar 2025 10:20:12 +0100 Subject: [PATCH 13/17] Update tests/test_lr_scheduler.py Co-authored-by: SimonKamuk <43374850+SimonKamuk@users.noreply.github.com> --- tests/test_lr_scheduler.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_lr_scheduler.py b/tests/test_lr_scheduler.py index 1a648a63d..c179d5eaf 100644 --- a/tests/test_lr_scheduler.py +++ b/tests/test_lr_scheduler.py @@ -54,5 +54,5 @@ def test_warmup_cosine_annealing_produces_expected_schedule(optimizer): ) * (1 + np.cos(np.pi * np.arange(annealing_steps) / annealing_steps)) assert np.allclose(annealing_lr, expected_annealing_lr) - last_lr = lrs[-1] - assert all(lrs[-5:] == np.ones(5) * last_lr) + end_lr = np.array(lrs[warmup_steps + annealing_steps:]) + assert all(end_lr == min_factor * initial_lr) From 5263cffc1a12b46d87241dc21a7ad3c97c52aba8 Mon Sep 17 00:00:00 2001 From: Jacob Mathias Schreiner Date: Mon, 10 Mar 2025 10:41:33 +0100 Subject: [PATCH 14/17] remove visualization code --- neural_lam/lr_scheduler.py | 32 -------------------------------- 1 file changed, 32 deletions(-) diff --git a/neural_lam/lr_scheduler.py b/neural_lam/lr_scheduler.py index e504bf205..121ead6bf 100644 --- a/neural_lam/lr_scheduler.py +++ b/neural_lam/lr_scheduler.py @@ -46,35 +46,3 @@ def step(self): elif self._step_count <= self.warmup_steps + self.annealing_steps: self.annealing_scheduler.step() self._step_count += 1 - - -if __name__ == "__main__": - # Run this code to visualize the learning rate schedule - # Third-party - import matplotlib.pyplot as plt - - model = torch.nn.Linear(1, 1) - opt = torch.optim.Adam(model.parameters()) - scheduler = WarmupCosineAnnealingLR( - opt, warmup_steps=20, annealing_steps=100 - ) - - lrs = [] - for _ in range(150): - lrs.append(opt.param_groups[0]["lr"]) - scheduler.step() - - plt.plot(lrs) - plt.vlines(20, 0, max(lrs), colors="k", linestyles="dashed") - plt.vlines(120, 0, max(lrs), colors="k", linestyles="dashed") - plt.text(21, max(lrs) / 2, "warmup ended", fontsize=10, color="k") - plt.text(121, max(lrs) / 2, "annealing ended", fontsize=10, color="k") - - plt.hlines(max(lrs), 15, 25, colors="k", linestyles="dashed") - plt.text(26, max(lrs), f"{max(lrs):.2e}", fontsize=10, color="k") - plt.text(121, min(lrs), f"{min(lrs):.2e}", fontsize=10, color="k") - - plt.xlabel("Step") - plt.ylabel("Learning Rate") - - plt.show() From e6b1746967a0b82af8bbba6a492c6e37e3ce19d0 Mon Sep 17 00:00:00 2001 From: Jacob Mathias Schreiner Date: Mon, 10 Mar 2025 10:41:55 +0100 Subject: [PATCH 15/17] get_lr should return lr not take step --- neural_lam/lr_scheduler.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/neural_lam/lr_scheduler.py b/neural_lam/lr_scheduler.py index 121ead6bf..af33289f1 100644 --- a/neural_lam/lr_scheduler.py +++ b/neural_lam/lr_scheduler.py @@ -34,7 +34,7 @@ def get_lr(self): if self._step_count <= self.warmup_steps: return self.warmup_scheduler.get_last_lr() elif self._step_count <= self.warmup_steps + self.annealing_steps: - self.annealing_scheduler.step() + return self.annealing_scheduler.get_last_lr() return True From 27552cd95e59ce0b4abff4f440fd73de3b3fe666 Mon Sep 17 00:00:00 2001 From: Jacob Mathias Schreiner Date: Mon, 10 Mar 2025 10:44:29 +0100 Subject: [PATCH 16/17] only support single group --- neural_lam/lr_scheduler.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/neural_lam/lr_scheduler.py b/neural_lam/lr_scheduler.py index af33289f1..188d18f0b 100644 --- a/neural_lam/lr_scheduler.py +++ b/neural_lam/lr_scheduler.py @@ -13,6 +13,10 @@ def __init__( ): self.warmup_steps = warmup_steps self.annealing_steps = annealing_steps + assert ( + len(optimizer.param_groups) == 1 + ), "WarmupCosineAnnealingLR only supports training with one parameter group" + initial_learning_rate = optimizer.param_groups[0]["lr"] self.warmup_scheduler = torch.optim.lr_scheduler.LinearLR( From 56a33bd12e2aa3e57dd5a9e9ba775de02ed6f242 Mon Sep 17 00:00:00 2001 From: Jacob Mathias Schreiner Date: Mon, 10 Mar 2025 10:47:04 +0100 Subject: [PATCH 17/17] make sure only one parameter group is being used --- neural_lam/lr_scheduler.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/neural_lam/lr_scheduler.py b/neural_lam/lr_scheduler.py index 188d18f0b..5b1a6bf56 100644 --- a/neural_lam/lr_scheduler.py +++ b/neural_lam/lr_scheduler.py @@ -13,11 +13,13 @@ def __init__( ): self.warmup_steps = warmup_steps self.annealing_steps = annealing_steps + + # TODO generalize this to support multiple parameter groups assert ( len(optimizer.param_groups) == 1 ), "WarmupCosineAnnealingLR only supports training with one parameter group" - - initial_learning_rate = optimizer.param_groups[0]["lr"] + [param_group] = optimizer.param_groups + initial_learning_rate = param_group["lr"] self.warmup_scheduler = torch.optim.lr_scheduler.LinearLR( optimizer,