From be8764fb2b6cb9ff2a4786638aa5038f9dcbd44c Mon Sep 17 00:00:00 2001 From: Angus Gibson Date: Tue, 30 Sep 2025 15:11:54 +0100 Subject: [PATCH 1/2] Add test for staggered timestep dependencies This introduces a dependency between non-adjacent timesteps, which can trigger some incorrect behaviour depending on the checkpointing scheduler in use. --- tests/firedrake/adjoint/test_assignment.py | 39 ++++++++++++++++++++++ 1 file changed, 39 insertions(+) diff --git a/tests/firedrake/adjoint/test_assignment.py b/tests/firedrake/adjoint/test_assignment.py index a6ba4e6c83..20af354939 100644 --- a/tests/firedrake/adjoint/test_assignment.py +++ b/tests/firedrake/adjoint/test_assignment.py @@ -281,3 +281,42 @@ def test_adjoint_cleanup(scheduler, rg): dtemp = rg.uniform(u_0.function_space()) assert taylor_test(reduced_functional, u_0, dtemp) > 1.99999999 + + +@pytest.mark.skipcomplex +@pytest.mark.parametrize("scheduler", (None, SingleMemoryStorageSchedule())) +def test_adjoint_stagger(scheduler, rg): + # This test checks that the adjoint does not discard too many checkpoint + # variables. This is achieved by computing the derivative before conducting + # the Taylor test. This extra derivative is the thing that would cause the + # spurious discards. + + # get tape + tape = get_working_tape() + tape.clear_tape() + continue_annotation() + + if scheduler is not None: + tape.enable_checkpointing(scheduler) + + mesh = SquareMesh(1, 1, 1, quadrilateral=True) + + V = FunctionSpace(mesh, "CG", 1) + R = FunctionSpace(mesh, "R", 0) + + u_0 = Function(V).assign(1.0) + u = Function(V).assign(u_0) + r = Function(R) + + for i in tape.timestepper(iter(range(10))): + if i % 3 == 0: + r.assign(r + 1.0) + u.project(r * u) + + J = assemble(u ** 2 * dx) + + pause_annotation() + reduced_functional = ReducedFunctional(J, Control(u_0)) + + dtemp = rg.uniform(u_0.function_space()) + assert taylor_test(reduced_functional, u_0, dtemp) > 1.99999999 From 8515512f41ba5fd6e33568a69e1b14dfb3516375 Mon Sep 17 00:00:00 2001 From: Angus Gibson Date: Wed, 1 Oct 2025 13:51:30 +0100 Subject: [PATCH 2/2] tmp! point to pyadjoint working branch --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 8b03f6bcc7..2affdc28b3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -34,7 +34,7 @@ dependencies = [ "pkgconfig", "progress", # TODO RELEASE: use a release - "pyadjoint-ad @ git+https://github.com/dolfin-adjoint/pyadjoint", + "pyadjoint-ad @ git+https://github.com/angus-g/pyadjoint@angus-g/smss-special-case", "pycparser", "pytools[siphash]", "requests",