2025, Oct 18 01:00
Understanding numeric drift in JAX vmap across batch sizes on GPU in float32, and practical ways to reduce it
Learn why JAX vmap outputs differ across batch sizes on GPU in float32: floating point operation order, scheduling, and fusion. Repro and fixes included (float64 or CPU).
Minor numeric drift when vmapping a neural network over different batch sizes can be puzzling, especially when the underlying inputs are identical. If you see small discrepancies in the first few rows of the output depending on how many rows you pass into a vmapped function, that behavior can still be correct. The short explanation is that floating point arithmetic is sensitive to the sequence of operations; change the order, change the rounding profile. On GPU with float32 this shows up easily, while on CPU or in float64 it may disappear.
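A three-number example makes "change the order, change the rounding" concrete. The sketch below uses NumPy float32 scalars. That choice is an assumption made for portability; jax.numpy float32 follows the same IEEE round-to-nearest-even rules. The value 2^-24 is exactly half the gap between 1.0 and the next float32. Added to 1.0 on its own, it rounds back to 1.0. Combined with its twin first, it survives as one full gap.

```python
import numpy as np

a = np.float32(1.0)
b = np.float32(2.0 ** -24)  # half of the float32 spacing just above 1.0
c = np.float32(2.0 ** -24)

left = (a + b) + c   # each tiny term is rounded away separately (ties round to even)
right = a + (b + c)  # the tiny terms add up to one full spacing before meeting 1.0

print(left, right, left == right)
```

Here `left` is exactly 1.0 and `right` is exactly 1 + 2^-23. The inputs are identical, but the grouping differs, and so does the float32 result.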
Reproducing the issue
The snippet below applies an Equinox MLP to the same random inputs twice. It first runs the model on the first 10 rows alone, then on the full 10,000-row array, and it subtracts the first 10 rows of the two results. It follows the original report's code with only the variable names changed, so the program behaves the same way.
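```python
import jax
import jax.numpy as jnp
import equinox as eqx

def batch_apply(arr_in, net_fn):
    # Vectorise the model's per-example __call__ over the leading (batch) axis.
    mapped = eqx.filter_vmap(net_fn.__call__)
    return mapped(arr_in)

rng = jax.random.PRNGKey(0)
rng, key_net = jax.random.split(rng, 2)
# MLP(in_size=2, out_size=2, width_size=10, depth=2)
model = eqx.nn.MLP(2, 2, 10, 2, key=key_net)

rng, key_x = jax.random.split(rng, 2)
xb = jax.random.normal(key_x, (10000, 2))

# The same 10 input rows, evaluated as a batch of 10 and inside a batch of 10000.
delta = batch_apply(xb[:10], model) - batch_apply(xb, model)[:10]
print("eqx error:", delta)
```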
What actually causes the discrepancy
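This is expected, and it is not specific to vmap as a transformation; it is a property of floating point arithmetic. In float32, every operation whose exact result is not representable gets rounded, and because of that rounding, floating point addition is not associative: (a + b) + c need not equal a + (b + c). When you execute “the same” math through different computational paths, the roundings happen at different points and in different orders, which shifts the results slightly. Changing the batch size changes the shapes that get compiled and the kernels and tilings the backend may select for them. It can also change how operations are scheduled and fused. The effective operation order can therefore differ, including the order in which the terms of each dot product are accumulated.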
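The standard rounding model from numerical analysis puts numbers on this. It assumes round-to-nearest and no overflow or underflow, and it is background here, not a measurement of the MLP above. Each operation returns the exact result perturbed by a relative error no larger than the unit roundoff u:

$$\mathrm{fl}(a \circ b) = (a \circ b)(1 + \delta), \qquad |\delta| \le u, \qquad \circ \in \{+, -, \times, /\}$$

$$u_{\text{float32}} = 2^{-24} \approx 5.96 \times 10^{-8}, \qquad u_{\text{float64}} = 2^{-53} \approx 1.11 \times 10^{-16}$$

For left-to-right summation of n terms x_1, ..., x_n with exact sum S_n, such as the accumulation inside one output unit's dot product, the computed sum \hat{S}_n satisfies

$$\left|\hat{S}_n - S_n\right| \le \gamma_{n-1} \sum_{i=1}^{n} |x_i|, \qquad \gamma_k = \frac{k u}{1 - k u}$$

The bound limits how large the error can be, but it does not determine which value you get. Each ordering produces its own set of δ values, so two orderings of the same terms generally land on different float32 results that both sit inside the bound.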
The device matters because operation sequencing differs by architecture. CPU execution tends to follow a more serial accumulation order, which in practice can keep the order of operations consistent across batch sizes. GPU execution is highly parallel, and parallel work partitioning and accumulation patterns depend on input shapes. That difference in sequencing leads to different rounding, hence slightly different numbers in float32.
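A toy model shows how partitioning alone can change the answer. The hypothetical `chunked_sum` below mimics a parallel reduction. It splits the input into chunks, each standing in for a block of parallel work whose size would depend on the problem shape. It sums each chunk, then sums the partial results, all in float32. This is an illustration under that assumption and does not describe how XLA or the GPU libraries actually tile work. For four values whose exact sum is 2, one chunk of four gives 1.0, while two chunks of two give 0.0.

```python
import numpy as np

def chunked_sum(values, chunk):
    """Hypothetical parallel-reduction model: sum fixed-size chunks, then the partials (float32)."""
    partials = []
    for start in range(0, len(values), chunk):
        acc = np.float32(0.0)
        for v in values[start:start + chunk]:
            acc = np.float32(acc + v)  # every partial sum is rounded to float32
        partials.append(acc)
    total = np.float32(0.0)
    for p in partials:
        total = np.float32(total + p)
    return total

# 1e8 is exactly representable in float32, but the spacing there is 8, so 1e8 + 1 rounds back to 1e8.
x = np.array([1e8, 1.0, -1e8, 1.0], dtype=np.float32)

print(chunked_sum(x, 4))  # one chunk: ((1e8 + 1) - 1e8) + 1
print(chunked_sum(x, 2))  # two chunks: (1e8 + 1) + (-1e8 + 1)
```

Neither result is "wrong" under float32 rules. The two results differ only because the partial sums are formed in different places, which is the same mechanism at work when a GPU partitions a batch differently for different shapes.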
What to do about it
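The behavior itself is correct: the discrepancy comes from legitimate floating point rounding along different operation orders, not from a bug in vmap or in the model. In the observed setup, switching to float64 removed the discrepancy, and so did running on CPU. Float64 shrinks the rounding error of each operation by a factor of 2^29 (roughly 5 × 10^8). Order-dependent differences therefore leave a far smaller footprint, and in this example no visible one. CPU execution tends to keep the accumulation order the same across batch sizes, so both paths round identically.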
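Both fixes can be tried side by side. To keep the sketch self-contained, it swaps the Equinox model for a hypothetical plain-JAX MLP (`init_mlp`, `mlp`) with the same 2 -> 10 -> 10 -> 2 ReLU shape but not Equinox's initialization. The helper `head_drift` is likewise illustrative. The sketch relies on two real JAX switches. `jax.config.update("jax_enable_x64", True)` must run before any arrays are created. The `jax.default_device` context manager pins work to the CPU backend. Drift values depend on your hardware, so none are claimed here.

```python
import jax

# 64-bit mode must be enabled at startup, before any JAX arrays exist.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp


def init_mlp(key, sizes, dtype):
    # Hypothetical stand-in for eqx.nn.MLP(2, 2, 10, 2): layers 2 -> 10 -> 10 -> 2.
    params = []
    keys = jax.random.split(key, len(sizes) - 1)
    for k, (n_in, n_out) in zip(keys, zip(sizes[:-1], sizes[1:])):
        # Python-float scale keeps the weights in the requested dtype.
        w = jax.random.normal(k, (n_in, n_out), dtype) * (1.0 / n_in) ** 0.5
        params.append((w, jnp.zeros((n_out,), dtype)))
    return params


def mlp(params, x):
    # Per-example forward pass: ReLU hidden layers, linear output layer.
    for w, b in params[:-1]:
        x = jax.nn.relu(x @ w + b)
    w, b = params[-1]
    return x @ w + b


def head_drift(dtype, n_small=10, n_total=10000):
    # Largest |difference| between the first n_small rows computed in a batch
    # of n_small and the same rows computed inside a batch of n_total.
    key_p, key_x = jax.random.split(jax.random.PRNGKey(0))
    params = init_mlp(key_p, [2, 10, 10, 2], dtype)
    xb = jax.random.normal(key_x, (n_total, 2), dtype)
    batched = jax.vmap(mlp, in_axes=(None, 0))  # share params, map over rows
    small = batched(params, xb[:n_small])
    full = batched(params, xb)[:n_small]
    return float(jnp.max(jnp.abs(small - full)))


# Option 1: stay on the default device (a GPU if present) but compute in float64.
print("float32 drift:", head_drift(jnp.float32))
print("float64 drift:", head_drift(jnp.float64))

# Option 2: keep float32 but pin the computation to the CPU backend.
with jax.default_device(jax.devices("cpu")[0]):
    print("float32 drift on CPU:", head_drift(jnp.float32))
```

Enabling x64 is a global switch. It doubles the memory of every float array, and on many GPUs, especially consumer cards, float64 throughput is far below float32. Pinning to CPU avoids that cost but gives up GPU parallelism.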
Why this matters
It’s important not to treat batched GPU computations as bitwise deterministic across shape changes in float32. If you compare results across batch sizes or devices, expecting exact equality can lead you to misinterpret routine floating point behavior as a logic bug. Understanding that scheduling, batching, and device parallelism influence accumulation order helps you set appropriate expectations and diagnostics.
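In practice, that means replacing exact equality with a tolerance check. The helper below is a hypothetical sketch. It runs a batched function on a slice and on the full batch, then compares the overlapping rows with `np.testing.assert_allclose`. The default `rtol=1e-5` (on the order of 100 float32 ulps, relative) and `atol=1e-6` are assumptions, not recommended values. Tune them to your model's output scale.

```python
import numpy as np


def check_batch_consistency(batched_fn, x, n_small=10, rtol=1e-5, atol=1e-6):
    """Hypothetical check: do the first n_small rows agree between a slice and the full batch?

    batched_fn maps a (batch, ...) array to a (batch, ...) array, e.g. a jax.vmap- or
    eqx.filter_vmap-wrapped model. Raises AssertionError if the rows disagree beyond
    the tolerance; otherwise returns the largest absolute difference.
    """
    small = np.asarray(batched_fn(x[:n_small]))          # rows computed in a small batch
    full_head = np.asarray(batched_fn(x)[:n_small])      # same rows inside the full batch
    np.testing.assert_allclose(small, full_head, rtol=rtol, atol=atol)
    return float(np.max(np.abs(small - full_head)))
```

With the repro above, `check_batch_consistency(lambda a: batch_apply(a, model), xb)` should tolerate float32-sized drift. It would still flag a genuine logic bug, such as an output that depends on the batch size.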
Conclusion
Small value differences across vmapped batch sizes on GPU in float32 are a natural consequence of floating point rounding accumulated along different operation orders. The computation is still correct. If you require tighter numeric agreement for your use case, the observed configuration showed that switching to float64 or evaluating on CPU removed the discrepancy. Otherwise, accept tiny deviations as a normal part of parallel floating point computation and evaluate equality with that in mind.