From fa8d48ed6509ec1ccbf3133ae280a08f7b59b89f Mon Sep 17 00:00:00 2001 From: mmcky Date: Fri, 28 Nov 2025 14:39:26 +1100 Subject: [PATCH] FIX: jax_intro timeout - use lax.fori_loop instead of Python for loop The compute_call_price_jax function was timing out during builds due to JAX unrolling the Python for loop during JIT compilation. With large arrays (M=10,000,000), this causes excessive compilation time. Solution: Replace the Python for loop with jax.lax.fori_loop, which compiles the loop efficiently without unrolling. Same fix as QuantEcon/lecture-python-programming.myst#442 --- lectures/jax_intro.md | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/lectures/jax_intro.md b/lectures/jax_intro.md index 740821a5..f43f0964 100644 --- a/lectures/jax_intro.md +++ b/lectures/jax_intro.md @@ -727,16 +727,31 @@ def compute_call_price_jax(β=β, s = jnp.full(M, np.log(S0)) h = jnp.full(M, h0) - for t in range(n): + + def update(i, loop_state): + s, h, key = loop_state key, subkey = jax.random.split(key) Z = jax.random.normal(subkey, (2, M)) s = s + μ + jnp.exp(h) * Z[0, :] h = ρ * h + ν * Z[1, :] + new_loop_state = s, h, key + return new_loop_state + + initial_loop_state = s, h, key + final_loop_state = jax.lax.fori_loop(0, n, update, initial_loop_state) + s, h, key = final_loop_state + expectation = jnp.mean(jnp.maximum(jnp.exp(s) - K, 0)) return β**n * expectation ``` +```{note} +We use `jax.lax.fori_loop` instead of a Python `for` loop. +This allows JAX to compile the loop efficiently without unrolling it, +which significantly reduces compilation time for large arrays. +``` + Let's run it once to compile it: ```{code-cell} ipython3