2025, Oct 02 21:00
JAX lax.scan parameterization: closure (lambda/def) vs carry, identical performance explained
Learn whether to use closures (lambda/def) or pass parameters as carry in JAX lax.scan, why the options compile to similar code, and how to choose between them for readability.
Parameterizing a jax.lax.scan loop often raises the same question: should you capture constants with a closure, or carry them through the scan state? In practice the three common patterns (a lambda, a local function defined with def, or parameters passed as carry) compile to virtually identical code in JAX. Understanding why they look the same to the compiler helps you choose based on readability rather than imagined performance trade-offs.
Reproducible setup that shows the three patterns
import jax
import jax.numpy as jnp
import time
STEPS = 100
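def evolve(s, a, dtau, damp):
    # One update: a[0] moves s along the direction given by angle s[2]; a[1] changes s[2]
    v = jnp.array([jnp.cos(s[2]) * a[0], jnp.sin(s[2]) * a[0], a[1]])
    s_next = s + v * dtau - damp * v
    return s_next
def advance(buf, a, dtau, damp):
    # buf is the scan carry (state, step index); a[:, t] picks the controls for step t
    s, t = buf
    s_next = evolve(s, a[:, t], dtau, damp)
    return (s_next, t + 1), s_next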
# Closure via lambda
def run_with_lambda(s0, a, dtau, damp):
    step_fn = lambda carry, _inp: advance(carry, a, dtau, damp)
    init = (s0, 0)
    return jax.lax.scan(step_fn, init, None, length=STEPS)
# Closure via def
def run_with_def(s0, a, dtau, damp):
    def step_fn(carry, _inp):
        return advance(carry, a, dtau, damp)
    return jax.lax.scan(step_fn, (s0, 0), None, length=STEPS)
# Pass everything as carry
def advance_with_carry(buf, _inp):
    s, t, a, dtau, damp = buf
    s_next = evolve(s, a[:, t], dtau, damp)
    return (s_next, t + 1, a, dtau, damp), s_next
def run_with_carry(s0, a, dtau, damp):
    return jax.lax.scan(advance_with_carry, (s0, 0, a, dtau, damp), None, length=STEPS)
# Demo inputs
s0 = jnp.array([2., 3., 4.])
rng = jax.random.PRNGKey(0)
a = jax.random.uniform(rng, shape=(2, STEPS))
dtau = 0.01
damp = 1e-6
# Optional benchmarking helper
def run_bench(g, label, n_iter=1000):
    g_jit = jax.jit(g)
    # The first call traces and compiles, so it is excluded from the timing
    carry, traj = g_jit(s0, a, dtau, damp)
    carry[0].block_until_ready()
    t0 = time.perf_counter()
    for _ in range(n_iter):
        carry, traj = g_jit(s0, a, dtau, damp)
        carry[0].block_until_ready()
    t1 = time.perf_counter()
    print(f"{label:10s} | {n_iter} cached runs: {t1 - t0:.6f}s")
# run_bench(run_with_lambda, "lambda")
# run_bench(run_with_def,    "def")
# run_bench(run_with_carry,  "carry_all")
What actually happens in JAX
All three variants are idiomatic in JAX, and there is no reason to prefer one over the others from the runtime's perspective. Anonymous functions created with lambda and named functions created with def are treated identically: both are ordinary Python function objects, and JAX traces a function by calling it with abstract values, so it never sees which syntax produced it. The closure that captures a, dtau, and damp is therefore equivalent whether you write it via lambda or def.
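One way to check the lambda/def claim yourself is to compare the jaxprs (JAX's traced intermediate representation) of the two variants with jax.make_jaxpr. The sketch below reuses the setup functions from above, with the demo values 0.01 and 1e-6 for dtau and damp; the expectation is that the printed programs match and the numerical results agree.

```python
import jax
import jax.numpy as jnp

STEPS = 100

def evolve(s, a, dtau, damp):
    v = jnp.array([jnp.cos(s[2]) * a[0], jnp.sin(s[2]) * a[0], a[1]])
    return s + v * dtau - damp * v

def advance(buf, a, dtau, damp):
    s, t = buf
    s_next = evolve(s, a[:, t], dtau, damp)
    return (s_next, t + 1), s_next

def run_with_lambda(s0, a, dtau, damp):
    step_fn = lambda carry, _inp: advance(carry, a, dtau, damp)
    return jax.lax.scan(step_fn, (s0, 0), None, length=STEPS)

def run_with_def(s0, a, dtau, damp):
    def step_fn(carry, _inp):
        return advance(carry, a, dtau, damp)
    return jax.lax.scan(step_fn, (s0, 0), None, length=STEPS)

s0 = jnp.array([2., 3., 4.])
a = jax.random.uniform(jax.random.PRNGKey(0), shape=(2, STEPS))

# make_jaxpr traces each function with abstract values and returns the
# program JAX would hand to the compiler
jaxpr_lambda = jax.make_jaxpr(run_with_lambda)(s0, a, 0.01, 1e-6)
jaxpr_def = jax.make_jaxpr(run_with_def)(s0, a, 0.01, 1e-6)
same_jaxpr = str(jaxpr_lambda) == str(jaxpr_def)
print("identical jaxprs:", same_jaxpr)

# The compiled functions also produce the same trajectory
_, traj_lambda = jax.jit(run_with_lambda)(s0, a, 0.01, 1e-6)
_, traj_def = jax.jit(run_with_def)(s0, a, 0.01, 1e-6)
print("same trajectory:", bool(jnp.allclose(traj_lambda, traj_def)))
```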
The carry-based version differs only in how those same values reach the step function: they are threaded through scan's carry instead of being captured as closed-over constants. This changes some parameters of the scan primitive in the traced program (closed-over values enter scan as constants, while carried values are declared as loop state), but it does not change the resulting compiled function in a way that would matter for execution.
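A minimal sketch for looking at that difference, assuming you want to inspect the traced program directly: scan_params is a hypothetical helper that finds the scan equation in the jaxpr and reads its num_consts and num_carry parameters. The exact numbers it prints may depend on your JAX version, so the sketch only prints them; what it does confirm is that the carry variant hands the parameters back in its final carry while both variants compute the same trajectory.

```python
import jax
import jax.numpy as jnp

STEPS = 100

def evolve(s, a, dtau, damp):
    v = jnp.array([jnp.cos(s[2]) * a[0], jnp.sin(s[2]) * a[0], a[1]])
    return s + v * dtau - damp * v

def run_with_def(s0, a, dtau, damp):
    # a, dtau, damp are closed over by step_fn
    def step_fn(carry, _inp):
        s, t = carry
        s_next = evolve(s, a[:, t], dtau, damp)
        return (s_next, t + 1), s_next
    return jax.lax.scan(step_fn, (s0, 0), None, length=STEPS)

def advance_with_carry(buf, _inp):
    # a, dtau, damp travel through the carry unchanged
    s, t, a, dtau, damp = buf
    s_next = evolve(s, a[:, t], dtau, damp)
    return (s_next, t + 1, a, dtau, damp), s_next

def run_with_carry(s0, a, dtau, damp):
    return jax.lax.scan(advance_with_carry, (s0, 0, a, dtau, damp), None, length=STEPS)

def scan_params(fn, *args):
    """Hypothetical helper: (num_consts, num_carry) of the first scan equation."""
    closed = jax.make_jaxpr(fn)(*args)
    eqn = next(e for e in closed.jaxpr.eqns if e.primitive.name == "scan")
    return eqn.params["num_consts"], eqn.params["num_carry"]

s0 = jnp.array([2., 3., 4.])
a = jax.random.uniform(jax.random.PRNGKey(0), shape=(2, STEPS))
args = (s0, a, 0.01, 1e-6)

params_closure = scan_params(run_with_def, *args)
params_carry = scan_params(run_with_carry, *args)
print("closure (num_consts, num_carry):", params_closure)
print("carry   (num_consts, num_carry):", params_carry)

final_closure, traj_closure = jax.jit(run_with_def)(*args)
final_carry, traj_carry = jax.jit(run_with_carry)(*args)
print("final carry sizes:", len(final_closure), len(final_carry))
print("same trajectory:", bool(jnp.allclose(traj_closure, traj_carry)))
```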
A related terminology point is lambda-assignment, which is exactly what the lambda variant above does (step_fn = lambda ...). Python's style guide, PEP 8, recommends a def statement instead, but as far as JAX execution is concerned, lambda-assignment is not problematic; whether you follow that style guidance is a separate choice.
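To illustrate where the two styles actually differ, here is a small sketch using a toy step function (scale and k are hypothetical, not from the original). The difference shows up on the Python side: a lambda's __name__ is "<lambda>", which is what appears in tracebacks, while the def keeps its own name. The traced scan programs are the same.

```python
import jax
import jax.numpy as jnp

def scale(x, k):
    return x * k

k = 3.0  # hypothetical constant, closed over by both step functions

# Lambda assignment: the pattern PEP 8 advises against
step_lambda = lambda carry, _: (scale(carry, k), carry)

# The equivalent def statement
def step_def(carry, _):
    return scale(carry, k), carry

# Visible difference: only the Python function name
print(step_lambda.__name__)
print(step_def.__name__)

# No difference in what JAX traces
x0 = jnp.float32(1.0)
jaxpr_lambda = jax.make_jaxpr(lambda x: jax.lax.scan(step_lambda, x, None, length=4))(x0)
jaxpr_def = jax.make_jaxpr(lambda x: jax.lax.scan(step_def, x, None, length=4))(x0)
print("identical jaxprs:", str(jaxpr_lambda) == str(jaxpr_def))

final, ys = jax.lax.scan(step_def, x0, None, length=4)
print(float(final), ys)
```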
Solution in practice
Because these options are equivalent at the JAX level, pick the shape that makes your code easiest to read and maintain. If it’s clearer to express the step as a locally defined function that closes over constants, do so. If you prefer to surface parameters explicitly in the scan state, passing them as carry is equally acceptable. Here is a compact version using a local def:
import jax
import jax.numpy as jnp
STEPS = 100
def evolve(s, a, dtau, damp):
    v = jnp.array([jnp.cos(s[2]) * a[0], jnp.sin(s[2]) * a[0], a[1]])
    return s + v * dtau - damp * v
def advance(buf, a, dtau, damp):
    s, t = buf
    s_next = evolve(s, a[:, t], dtau, damp)
    return (s_next, t + 1), s_next
def rollout_scan(s0, a, dtau, damp):
    def step(carry, _):
        return advance(carry, a, dtau, damp)
    return jax.lax.scan(step, (s0, 0), None, length=STEPS)
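A short usage sketch for rollout_scan, assuming the same demo inputs as the setup section: scan returns the final carry (the last state and the step counter) together with the stacked per-step outputs, so the trajectory has one row per step and its last row equals the final state.

```python
import jax
import jax.numpy as jnp

STEPS = 100

def evolve(s, a, dtau, damp):
    v = jnp.array([jnp.cos(s[2]) * a[0], jnp.sin(s[2]) * a[0], a[1]])
    return s + v * dtau - damp * v

def advance(buf, a, dtau, damp):
    s, t = buf
    s_next = evolve(s, a[:, t], dtau, damp)
    return (s_next, t + 1), s_next

def rollout_scan(s0, a, dtau, damp):
    def step(carry, _):
        return advance(carry, a, dtau, damp)
    return jax.lax.scan(step, (s0, 0), None, length=STEPS)

s0 = jnp.array([2., 3., 4.])
a = jax.random.uniform(jax.random.PRNGKey(0), shape=(2, STEPS))

(s_final, t_final), traj = jax.jit(rollout_scan)(s0, a, 0.01, 1e-6)
print(traj.shape)    # (100, 3): one state per step
print(int(t_final))  # 100: the step counter threaded through the carry
print(bool(jnp.allclose(traj[-1], s_final)))
```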
Why this matters
Knowing that lambda, def, and carry lead to virtually identical compiled code helps you avoid premature optimization and focus on the system dynamics themselves. It also means that if you do observe a timing difference, it is unlikely to come from whether you used a closure or carried the parameters, since JAX's tracing treats anonymous and named callables the same way.
Takeaways
Use whichever approach results in clearer code for your team. Lambda and def are interchangeable to JAX, and passing parameters as carry instead of closing over them only changes how the scan is parameterized, not the behavior of the compiled function. If in doubt, measure in your specific setup, but expect no inherent runtime advantage of one style over the others.
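If you do want to measure, a sketch of a fairer timing harness than a single loop: time_jitted is a hypothetical helper that excludes compilation with a warm-up call, blocks on the asynchronously dispatched results, and reports the median over several repeats. The two decay functions are toy stand-ins for the closure and carry styles, not code from the original question.

```python
import statistics
import time

import jax
import jax.numpy as jnp

STEPS = 100


def time_jitted(fn, *args, repeats=5, number=100):
    """Hypothetical helper: median seconds per call of jax.jit(fn), compile excluded."""
    fn_jit = jax.jit(fn)
    # Warm-up call: triggers tracing and compilation, so it is not timed
    jax.block_until_ready(fn_jit(*args))
    per_call = []
    for _ in range(repeats):
        t0 = time.perf_counter()
        for _ in range(number):
            out = fn_jit(*args)
        # JAX dispatches asynchronously; wait for the queued work to finish
        jax.block_until_ready(out)
        per_call.append((time.perf_counter() - t0) / number)
    return statistics.median(per_call)


# Toy loops: geometric decay with a closed-over rate vs a carried rate
def decay_closure(x0, k):
    def step(x, _):
        return x * k, x
    return jax.lax.scan(step, x0, None, length=STEPS)


def decay_carry(x0, k):
    def step(carry, _):
        x, rate = carry
        return (x * rate, rate), x
    return jax.lax.scan(step, (x0, k), None, length=STEPS)


x0 = jnp.ones(3)
k = 0.99

t_closure = time_jitted(decay_closure, x0, k)
t_carry = time_jitted(decay_carry, x0, k)
print(f"closure: {t_closure * 1e6:.1f} us/call")
print(f"carry:   {t_carry * 1e6:.1f} us/call")
```

Taking the median of several repeats makes the comparison less sensitive to one-off noise from the Python loop or other processes.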
The article is based on a question from StackOverflow by user1168149 and an answer by jakevdp.