2025, Oct 17 08:00
How to Avoid TracerArrayConversionError in JAX: Don't Mutate Object Attributes Inside jit
Learn why JAX jit and object state clash, causing TracerArrayConversionError, and how to fix it with pure functions, returning values, and pytrees over mutation.
JAX, jit, and object state don’t always mix. A common pitfall is attempting to mutate attributes inside a jitted method and then trying to use those attributes as if they were ordinary NumPy arrays. The result is a TracerArrayConversionError and a bit of head-scratching about when values become concrete and how to structure code to avoid surprises.
Problem setup
The issue appears when a jitted function assigns an intermediate value to a member attribute and that attribute is later accessed outside the jitted context. The code below reproduces the behavior.
from functools import partial
import jax
import jax.numpy as jnp
import numpy as np
def host_bootstrap():
    print("This code runs once at the beginning of the program.")
    return jnp.array([1.0, 2.0, 3.0])
class Box:
    @partial(jax.jit, static_argnums=0)
    def apply_jit(self, arr):
        print("This code runs once during the first trace.")
        out = arr * 2
        self.cache = out
        return out
# Program execution
payload = host_bootstrap()
box = Box()
out1 = box.apply_jit(payload)  # Tracing happens here
np.array(out1)       # works: the returned value is a concrete array
np.array(box.cache)  # raises TracerArrayConversionError
Calling np.array on the method result works, but doing the same on the attribute fails with:
jax.errors.TracerArrayConversionError: The numpy.ndarray conversion method __array__() was called on traced array with shape float32[3]
What actually goes wrong
JAX transformations like jit are built for pure functions. A pure function doesn’t mutate inputs or close over hidden state, and its outputs depend only on its inputs. Assigning to self.cache inside the jitted method mutates a function input and breaks that purity contract. During the jit trace/compile stage, values like out are tracers rather than concrete arrays. Returning a value from the jitted function gives JAX a well-defined boundary to materialize it; assigning the same value to an attribute does not.
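A minimal sketch of the tracer-versus-concrete distinction, assuming a JAX version that still exposes jax.core.Tracer. The names double and seen are hypothetical, and appending to a list inside jit is itself a side effect, used here only to observe what the function body sees during tracing.

```python
import jax
import jax.numpy as jnp

seen = []  # records what the function body observes while being traced

@jax.jit
def double(arr):
    out = arr * 2
    # During tracing, `out` is a tracer, not a concrete array.
    seen.append(isinstance(out, jax.core.Tracer))
    return out

result = double(jnp.array([1.0, 2.0, 3.0]))

# The returned value crossed the jit boundary, so it is a concrete array.
returned_is_tracer = isinstance(result, jax.core.Tracer)
print(seen, returned_is_tracer)
```

Inside the function the intermediate is a tracer, while the value that comes back from the call is an ordinary JAX array that NumPy can convert.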
When will the value in the attribute become concrete?
Only when it is returned from the JIT-compiled function. If it is not returned, the attribute keeps holding the tracer that leaked out of tracing. That tracer is never replaced with a concrete array, so converting it to a NumPy array raises the conversion error. More generally, code that mutates state inside jit has no well-defined behavior.
There’s another red flag here: marking self as static while modifying it. That is explicitly discouraged. The JAX documentation on pure functions and the FAQ entry on using jit with methods cover the pitfalls and the recommended approaches.
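To illustrate why a static self that changes is risky, here is a small sketch modeled on the scenario described in the JAX FAQ. The class Scaler and its factor attribute are hypothetical. The class uses Python's default identity-based hashing, so jit treats the mutated object as the same static argument and reuses the compiled function.

```python
from functools import partial

import jax
import jax.numpy as jnp

class Scaler:  # hypothetical class, for illustration only
    def __init__(self, factor):
        self.factor = factor

    @partial(jax.jit, static_argnums=0)
    def scale(self, arr):
        # self is static, so self.factor is baked in as a constant at trace time.
        return arr * self.factor

s = Scaler(2.0)
x = jnp.array([1.0, 2.0, 3.0])

first = s.scale(x)   # traced and compiled with factor == 2.0
s.factor = 10.0      # mutation does not change the object's hash
second = s.scale(x)  # cache hit: the stale factor is reused

print(first, second)
```

The second call silently reuses the old factor instead of raising an error, which is arguably worse than the TracerArrayConversionError because nothing signals the problem.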
Fixing the design
If the value needs to be used outside the jitted function, return it and handle any object state updates on the Python side. That preserves purity and gives JAX a clean way to materialize the result. Also, don’t mark self as static when you intend to update or depend on it.
import jax
import jax.numpy as jnp
import numpy as np
def host_bootstrap():
    print("This code runs once at the beginning of the program.")
    return jnp.array([1.0, 2.0, 3.0])
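class Box:
    # apply_jit never reads self, so it is a staticmethod. A plain @jax.jit method
    # would receive self as a traced argument, which fails unless Box is a pytree.
    @staticmethod
    @jax.jit
    def apply_jit(arr):
        print("This code runs once during the first trace.")
        out = arr * 2
        return out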
# Program execution
payload = host_bootstrap()
box = Box()
out1 = box.apply_jit(payload)
# Update object state outside jit
box.cache = out1
# Both conversions are fine now
np.array(out1)
np.array(box.cache)
Another viable approach is to make the object a pytree, avoid marking it static, and return the updated instance from the function. This aligns with the idea of explicit state passing rather than hidden mutation.
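A hedged sketch of the pytree approach, using jax.tree_util.register_pytree_node_class. The constructor signature and the choice to return both the new Box and the result are assumptions made for illustration, not something the original answer prescribes.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.tree_util import register_pytree_node_class

@register_pytree_node_class
class Box:
    def __init__(self, cache):
        self.cache = cache

    def tree_flatten(self):
        # Children are the dynamic (traced) leaves; there is no static aux data.
        return (self.cache,), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

    @jax.jit
    def apply_jit(self, arr):
        out = arr * 2
        # Build a new instance instead of mutating self, and return it.
        return Box(out), out

box = Box(jnp.zeros(3))
box, out1 = box.apply_jit(jnp.array([1.0, 2.0, 3.0]))

# The cache came back through the jit boundary, so it is concrete.
cache_np = np.array(box.cache)
print(cache_np)
```

Because the updated instance is a return value, its cache attribute is materialized like any other output, and the caller decides whether to rebind the name.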
Why this matters
JAX’s compilation and transformation model assumes pure functions. Leaning on side effects during tracing breaks those assumptions and leads to fragile, undefined behavior. Returning values provides a clear materialization boundary and keeps transformations predictable and composable.
Takeaways
If you need a value outside a jitted method, return it. Avoid mutating attributes inside jit, and don’t mark self as static when it is being modified. If state must flow across method calls, pass it explicitly, or structure your class as a pytree and return the updated state. These patterns keep your code compatible with JAX’s execution model and eliminate TracerArrayConversionError surprises.
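As one possible shape for explicit state passing without a class, the sketch below threads a cache through a jitted function. The function step and the accumulate-into-cache behavior are hypothetical and chosen only to show state going in as an argument and coming out as a return value.

```python
import jax
import jax.numpy as jnp

@jax.jit
def step(cache, arr):
    out = arr * 2
    new_cache = cache + out  # hypothetical update rule: accumulate results
    return new_cache, out

cache = jnp.zeros(3)
payload = jnp.array([1.0, 2.0, 3.0])

# State is carried by the caller and rebound after every call.
for _ in range(2):
    cache, out = step(cache, payload)

print(cache)
```

Nothing is hidden from JAX here: every value the function depends on is an input, and every value the caller needs afterwards is an output.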
The article is based on a question from StackOverflow by Warm_Duscher and an answer by jakevdp.