2025, Oct 03 07:00

How JAX JIT Caching Works with pytrees, scan, and lambda closures: what triggers recompilation

Learn how JAX JIT caching works with pytrees, scan, and lambda closures: what affects cache keys, what changes trigger recompilation, and keeping runs stable.

JAX scan, pytrees, lambda closures, and JIT caching often meet in the same piece of code, and then the question appears: will capturing a pytree via a lambda force recompilation every time a new instance is passed? Let’s break down what actually triggers JIT recompilation, how pytree flattening feeds the cache key, and why this setup is fine as long as you keep the right parts stable.

Minimal example that raises the question

The following program scans over an input vector. The state update logic is packaged into a pytree object that carries both an array and a scalar, and the object is closed over via a lambda inside a jitted function.

import jax
import jax.numpy as jnp
from jax import tree_util

class ConfigBox:
    def __init__(self, vec, scale):
        self.vec = vec
        self.scale = scale

    def step_fn(self, st, inp):
        s = st
        y = inp
        nxt = (self.vec + s + jnp.ones(self.scale)) * y
        return nxt

    def _tree_flatten(self):
        leaves = (self.vec,)          # children: traced, dynamic
        meta = {'scale': self.scale}  # aux data: static, part of the cache key
        return (leaves, meta)

    @classmethod
    def _tree_unflatten(cls, meta, leaves):
        return cls(*leaves, **meta)


tree_util.register_pytree_node(
    ConfigBox,
    ConfigBox._tree_flatten,
    ConfigBox._tree_unflatten,
)

def step_wrapper(cfg, st, inp):
    s = st
    y = inp
    s_next = cfg.step_fn(s, y)
    return s_next, [s_next]

@jax.jit
def run_scan_jit(cfg):
    body = lambda st, inp: step_wrapper(cfg, st, inp)
    st0 = jnp.array([0., 1.])
    seq = jnp.array([1., 2., 3.])
    last, out = jax.lax.scan(body, st0, seq)
    return last, out

if __name__ == "__main__":
    cfg1 = ConfigBox(jnp.array([1., 2.]), 2)
    last_state, outputs = run_scan_jit(cfg1)
    print(last_state)

    cfg2 = ConfigBox(jnp.array([3., 4.]), 2)
    last_state, outputs = run_scan_jit(cfg2)
    print(last_state)

What actually triggers recompilation

In JAX, what controls JIT cache hits for pytrees is how they flatten. Values returned in the auxiliary part of flattening (the aux data) become part of the tree structure and are treated as static. Values returned as children are treated as dynamic. Static entries affect the cache key via their Python value, so a different value means a different key. Dynamic entries affect the cache key via their array structure (shape, dtype, sharding), but not via their contents.

Applied to the example above, meta holds scale and is static. Changing the value of scale will trigger recompilation. The leaves hold vec and are dynamic. Changing the values in vec won’t recompile, but changing its shape, dtype, or sharding will. In the two calls shown, both vec instances have the same shape, dtype, and sharding, and scale is unchanged; therefore, there is no recompilation.
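The cases above can be checked with a trace counter. The sketch below is an illustration rather than the article's program: `Box`, `total`, and `trace_count` are hypothetical names. It relies on the fact that Python side effects inside a jitted function run only while JAX traces it, so each increment marks a cache miss. The sharding case is left out because it needs a multi-device setup.

```python
import jax
import jax.numpy as jnp
from jax import tree_util


class Box:
    """Hypothetical pytree: `vec` is a dynamic leaf, `scale` is static aux data."""

    def __init__(self, vec, scale):
        self.vec = vec
        self.scale = scale


tree_util.register_pytree_node(
    Box,
    lambda b: ((b.vec,), b.scale),                # flatten -> (children, aux_data)
    lambda aux, children: Box(children[0], aux),  # unflatten(aux_data, children)
)

trace_count = 0


@jax.jit
def total(box):
    global trace_count
    trace_count += 1  # Python side effect: runs only during tracing
    return jnp.sum(box.vec) * box.scale


total(Box(jnp.array([1., 2.]), 2))
total(Box(jnp.array([3., 4.]), 2))      # new values, same shape/dtype/scale
after_values = trace_count

total(Box(jnp.array([1., 2., 3.]), 2))  # new shape
after_shape = trace_count

total(Box(jnp.array([1, 2, 3]), 2))     # new dtype (integer instead of float)
after_dtype = trace_count

total(Box(jnp.array([1., 2., 3.]), 3))  # new static value
after_static = trace_count

print(after_values, after_shape, after_dtype, after_static)
```

Only the second call reuses the cached executable; each later call changes either the array structure or the static value and is traced again.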

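Another common worry is whether the lambda used to bake in cfg affects the JIT cache key. It does not in this setup. The cache key is built from the arguments of the jitted function, here cfg, through its flattened leaves and aux data. The lambda is created inside run_scan_jit, so it exists only while JAX traces the function, and tracing happens only on a cache miss. On a cache hit the Python body, including the lambda, does not run at all.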
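A small experiment can make the lambda point concrete. This is a sketch under assumptions, not the original program: `run`, `step`, and `counts` are hypothetical names, and a plain array argument stands in for the pytree. The counters are Python side effects, so they advance only while JAX is tracing.

```python
import jax
import jax.numpy as jnp

counts = {"outer": 0, "body": 0}


@jax.jit
def run(offset):
    counts["outer"] += 1  # incremented only when `run` is traced

    def step(carry, x):
        counts["body"] += 1  # incremented only when scan traces the body
        nxt = (carry + offset) * x
        return nxt, nxt

    # A fresh lambda object is built here, but only during tracing.
    body = lambda carry, x: step(carry, x)
    return jax.lax.scan(body, jnp.zeros(2), jnp.array([1., 2., 3.]))


last1, _ = run(jnp.array([1., 2.]))
snapshot = dict(counts)

last2, _ = run(jnp.array([3., 4.]))  # same shape and dtype, new values

print(snapshot, counts)
print(last1, last2)
```

If the counters are identical before and after the second call, neither the jitted function nor the scan body was traced again, even though the results differ because the closed-over values differ.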

Solution: keep static and dynamic parts straight

No changes are required to avoid recompilation in the shown scenario. As long as the static fields keep the same value and the dynamic arrays keep the same shape, dtype, and sharding, JIT will reuse the cached executable. If needed, you can confirm cache reuse using the approach referenced at https://stackoverflow.com/a/70127930/2937831. It is also possible to observe that the jitted function’s cache remains at size 1 after both calls in this program.
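One way to look at the cache size directly is sketched below. It assumes the jitted function exposes a `_cache_size` method, which is a private, undocumented detail of current JAX releases and may change or disappear, so the hypothetical helper `cache_size` falls back to `None` when the method is missing. The function `double` is a stand-in for illustration only.

```python
import jax
import jax.numpy as jnp


@jax.jit
def double(x):
    return 2 * x


def cache_size(fn):
    # `_cache_size` is private API (an assumption about the installed JAX
    # version); return None instead of failing if it is not available.
    probe = getattr(fn, "_cache_size", None)
    return probe() if callable(probe) else None


double(jnp.array([1., 2.]))
double(jnp.array([3., 4.]))      # same shape and dtype as the first call
size_same_structure = cache_size(double)

double(jnp.array([1., 2., 3.]))  # different shape
size_new_shape = cache_size(double)

print(size_same_structure, size_new_shape)
```

A less version-dependent alternative is to set `jax.config.update("jax_log_compiles", True)` before running, which makes JAX log each compilation as it happens.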

The program from the question, condensed slightly but otherwise unchanged, already satisfies these conditions:

import jax
import jax.numpy as jnp
from jax import tree_util

class ConfigBox:
    def __init__(self, vec, scale):
        self.vec = vec
        self.scale = scale

    def step_fn(self, st, inp):
        return (self.vec + st + jnp.ones(self.scale)) * inp

    def _tree_flatten(self):
        leaves = (self.vec,)
        meta = {'scale': self.scale}
        return (leaves, meta)

    @classmethod
    def _tree_unflatten(cls, meta, leaves):
        return cls(*leaves, **meta)


tree_util.register_pytree_node(
    ConfigBox,
    ConfigBox._tree_flatten,
    ConfigBox._tree_unflatten,
)

def step_wrapper(cfg, st, inp):
    nxt = cfg.step_fn(st, inp)
    return nxt, [nxt]

@jax.jit
def run_scan_jit(cfg):
    body = lambda st, inp: step_wrapper(cfg, st, inp)
    st0 = jnp.array([0., 1.])
    seq = jnp.array([1., 2., 3.])
    return jax.lax.scan(body, st0, seq)

Why it matters

Understanding which pieces are static and which are dynamic is the difference between steady, predictable runs and unexpected recompile spikes. Arrays can change value freely without invalidating the cache, while changing shapes, dtypes, or sharding patterns will. Scalars placed in the static part of the pytree will cause recompilation when their values change. Lambda captures in this pattern do not alter caching behavior.

Takeaways

If a pytree is part of your JAX program, decide intentionally what goes into the dynamic leaves and what goes into the static aux data. Keep static fields stable when you expect cache reuse, and keep dynamic arrays consistent in shape, dtype, and sharding when you want to avoid recompilation. Closing over a dynamic value with a lambda is fine under those conditions. If in doubt, verify cache reuse using the method referenced above or by inspecting the jitted function’s cache size; you should see it remain at 1 across calls that don’t change the static data or the array structure.

The article is based on a question from StackOverflow by user1168149 and an answer by jakevdp.