2025, Dec 22 03:00
JAX vs NumPy for array reshaping: benchmarks, views vs copies, and when JIT makes it faster
Learn why JAX seems slower than NumPy for array reshaping (broadcast_to, moveaxis), how views vs copies and dispatch affect performance, and when JIT wins.
JAX vs. NumPy for array reshaping: why simple ops can look slower, and what to do about it
Problem overview
Porting a transformation-heavy numerical pipeline from NumPy to JAX to take advantage of JIT compilation often starts with a surprise: basic array manipulations such as broadcast_to and moveaxis run slower in JAX, both in eager mode and, sometimes, even under JIT. Microbenchmarks that apply only those two operations, even on large batches, can show NumPy finishing them far faster than both eager and jitted JAX.
This behavior is expected once you account for the very different execution models of the two libraries, especially around views vs. copies and per-op dispatch cost. The short version: NumPy is optimized for extremely fast per-op execution on CPU and many manipulations are view-only; JAX is optimized to fuse sequences of operations via JIT. You will see the payoff in pipelines, not in isolated micro-ops.
Minimal benchmark demonstrating the mismatch
The snippet below compares broadcast_to and moveaxis in NumPy, JAX eager, and JAX with JIT. Names are arbitrary; the logic is identical in all three cases.
```python
import timeit

import jax
import jax.numpy as jnp
import numpy as np
from jax import jit

# 4x4 homogeneous transform; NumPy keeps float64, JAX converts to float32 by default
arr_base_np = np.array([[1, 0, 0, 0.5],
                        [0, 1, 0, 0],
                        [0, 0, 1, 0],
                        [0, 0, 0, 1]])
arr_base_jx = jnp.array(arr_base_np)
batch_total = 1_000_000

print("### Benchmark: moveaxis + broadcast_to ###")

# NumPy: both calls return views, so no element data is copied
elapsed_np = timeit.timeit(
    lambda: np.moveaxis(np.broadcast_to(arr_base_np[:, :, None], (4, 4, batch_total)), 2, 0),
    number=10,
)
print(f"NumPy: moveaxis + broadcast_to → {elapsed_np:.6f} s")

# JAX eager: each op is dispatched separately and returns a new, materialized array.
# Warm up once so one-time compilation of the eager ops is not timed.
jnp.moveaxis(jnp.broadcast_to(arr_base_jx[:, :, None], (4, 4, batch_total)), 2, 0).block_until_ready()
elapsed_jx = timeit.timeit(
    lambda: jnp.moveaxis(jnp.broadcast_to(arr_base_jx[:, :, None], (4, 4, batch_total)), 2, 0).block_until_ready(),
    number=10,
)
print(f"JAX: moveaxis + broadcast_to → {elapsed_jx:.6f} s")

# JAX JIT: compiled once, but the (batch_total, 4, 4) output is still written in full
@jit
def jit_move_bcast(a):
    return jnp.moveaxis(jnp.broadcast_to(a[:, :, None], (4, 4, batch_total)), 2, 0)

jit_move_bcast(arr_base_jx).block_until_ready()  # compile outside the timed region
elapsed_jit = timeit.timeit(
    lambda: jit_move_bcast(arr_base_jx).block_until_ready(),
    number=10,
)
print(f"JAX JIT: moveaxis + broadcast_to → {elapsed_jit:.6f} s")

print("\n### Benchmark: broadcast_to only ###")

# NumPy: a zero-stride view, independent of batch_total
elapsed_np_b = timeit.timeit(
    lambda: np.broadcast_to(arr_base_np[:, :, None], (4, 4, batch_total)),
    number=10,
)
print(f"NumPy: broadcast_to → {elapsed_np_b:.6f} s")

# JAX eager (warmed up first, as above)
jnp.broadcast_to(arr_base_jx[:, :, None], (4, 4, batch_total)).block_until_ready()
elapsed_jx_b = timeit.timeit(
    lambda: jnp.broadcast_to(arr_base_jx[:, :, None], (4, 4, batch_total)).block_until_ready(),
    number=10,
)
print(f"JAX: broadcast_to → {elapsed_jx_b:.6f} s")

# JAX JIT
@jit
def jit_bcast_only(a):
    return jnp.broadcast_to(a[:, :, None], (4, 4, batch_total))

jit_bcast_only(arr_base_jx).block_until_ready()  # compile outside the timed region
elapsed_jit_b = timeit.timeit(
    lambda: jit_bcast_only(arr_base_jx).block_until_ready(),
    number=10,
)
print(f"JAX JIT: broadcast_to → {elapsed_jit_b:.6f} s")
```
What’s actually happening and why it looks slower
Two differences dominate here. First, in NumPy, many shape-manipulation operations (broadcasting, transposing, reshaping, slicing, moveaxis) typically return views of the same underlying buffer: a lightweight array object with a different shape and strides that points at the existing data instead of copying it. Broadcasting a 4x4 matrix to shape (4, 4, 1_000_000) in NumPy therefore costs about the same as broadcasting it to (4, 4, 2), because the new axis simply gets a stride of 0. In JAX, two array objects cannot share memory, so the same category of operations returns copies rather than views. When an operation is cheap in NumPy only because it is view-only, the eager JAX equivalent does real work and moves real memory: here, writing 16 million elements into a freshly allocated buffer.
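To make the view-versus-copy point concrete, the NumPy-only sketch below inspects the strides of the broadcast and moved arrays and shows that they alias the original 16 numbers. The names `base`, `b`, and `m` are illustrative; only the shapes mirror the benchmark.

```python
import numpy as np

base = np.eye(4)          # 4x4 float64 matrix: 128 bytes of real data
batch = 1_000_000

# broadcast_to returns a read-only view; the new batch axis has stride 0
b = np.broadcast_to(base[:, :, None], (4, 4, batch))
print(b.strides)          # (32, 8, 0)
print(b.flags.writeable)  # False

# moveaxis only permutes shape and strides, so the result is still a view
m = np.moveaxis(b, 2, 0)
print(m.shape, m.strides)  # (1000000, 4, 4) (0, 32, 8)

# A write to the base is visible in every batch entry,
# because all of them alias the same 16 numbers
base[0, 3] = 0.5
print(m[0, 0, 3], m[999_999, 0, 3])  # 0.5 0.5
```

Neither call touches a million-element buffer, which is why the NumPy timings in the benchmark barely depend on batch_total.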
Second, the per-operation dispatch overheads are very different. NumPy has extremely low dispatch cost for individual array ops on CPU. JAX incurs a higher cost per op when executed eagerly. On microbenchmarks where the operation is cheap, that overhead dominates and makes JAX look slow.
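A rough way to see the per-op cost is to time a trivial operation on a tiny array, where the arithmetic is negligible and most of the measured time is dispatch. This is an illustrative sketch rather than a rigorous benchmark; the array size and call count are arbitrary choices, and absolute numbers depend on hardware and library versions.

```python
import timeit

import jax.numpy as jnp
import numpy as np

x_np = np.ones(4, dtype=np.float32)
x_jx = jnp.ones(4, dtype=jnp.float32)

# Warm up: the first eager call compiles and caches the op for this shape/dtype
(x_jx + 1).block_until_ready()

n_calls = 10_000
# Adding 1 to 4 elements is essentially free, so these are mostly dispatch costs
t_np = timeit.timeit(lambda: x_np + 1, number=n_calls) / n_calls
t_jx = timeit.timeit(lambda: (x_jx + 1).block_until_ready(), number=n_calls) / n_calls

print(f"NumPy per op:     {t_np * 1e6:.2f} us")
print(f"JAX eager per op: {t_jx * 1e6:.2f} us")
```

On typical CPU setups the eager JAX figure should come out noticeably larger than the NumPy one; that fixed per-op cost is what compiling a whole sequence under jax.jit amortizes.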
Given those two points, it’s natural to ask how JAX ever wins. The answer is JIT. Under jax.jit, JAX compiles a whole sequence of operations into a single XLA program, and the compiler can fuse them. Intermediate results need not be allocated as explicit buffers, and the dispatch overhead is paid once for the whole computation rather than once per operation. NumPy has no cross-operation fusion: every call dispatches separately and materializes its result (unless that result is a view).
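The sketch below contrasts a chain of elementwise operations followed by a reduction in NumPy, where each step allocates a temporary array, with the same arithmetic under jax.jit, where XLA is free to fuse the steps. The function names and the array size are illustrative assumptions, and whether the jitted version wins depends on hardware and input size.

```python
import timeit

import jax
import jax.numpy as jnp
import numpy as np

def chain_np(x):
    # NumPy evaluates each operator separately, allocating a temporary each time
    return (np.sin(x) * 2.0 + x ** 2 - 1.0).sum()

@jax.jit
def chain_jit(x):
    # Same arithmetic; XLA can fuse the elementwise steps and the reduction
    return (jnp.sin(x) * 2.0 + x ** 2 - 1.0).sum()

x_np = np.linspace(0.0, 1.0, 10_000_000, dtype=np.float32)
x_jx = jnp.asarray(x_np)

chain_jit(x_jx).block_until_ready()  # compile once, outside the timed region

t_np = timeit.timeit(lambda: chain_np(x_np), number=10)
t_jit = timeit.timeit(lambda: chain_jit(x_jx).block_until_ready(), number=10)
print(f"NumPy (separate temporaries): {t_np:.4f} s")
print(f"JAX JIT (fusible):            {t_jit:.4f} s")
```

The two results can differ slightly in the last digits because float32 reductions may be accumulated in a different order.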
How JIT changes the picture
Within JIT, intermediate Python-level arrays do not necessarily correspond to any concrete buffer. Instead, the compiler fuses and rewrites the computation graph. A simple example makes this clear.
```python
import jax
import jax.numpy as jnp

@jax.jit
def take_and_sum(v):
    picked = v[::2]
    return picked.sum()

source = jnp.arange(10)
print(take_and_sum.lower(source).compile().as_text())
```
Here, it doesn’t make sense to ask whether picked is a view or a copy. Under JIT, the slice and the sum can be fused, and the intermediate array may never be created. The compiled HLO typically shows a single fused kernel that reads every second element of the input and reduces it, without materializing the sliced array as a separate buffer; the exact text varies with the backend and JAX version.
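One way to see the effect of compilation is to print the module both before and after XLA optimizes it. The sketch below uses the same `take_and_sum` function: `Lowered.as_text()` shows the unoptimized module, in which the strided slice (which may appear as a slice or a gather, depending on the JAX version) and the reduction are separate operations, while `Compiled.as_text()` shows the optimized HLO, where you can look for a fusion that combines them.

```python
import jax
import jax.numpy as jnp

@jax.jit
def take_and_sum(v):
    picked = v[::2]          # a strided view in NumPy; just an expression here
    return picked.sum()

source = jnp.arange(10)
lowered = take_and_sum.lower(source)

# Before XLA optimization: slicing and reduction are distinct ops
print(lowered.as_text())

# After XLA optimization: look for a fusion containing both
print(lowered.compile().as_text())

print(take_and_sum(source))  # 20  (0 + 2 + 4 + 6 + 8)
```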
What to change in practice
If you benchmark individual cheap array manipulations, expect NumPy to be faster. The right way to assess JAX is to wrap real sequences of transformations in jax.jit so that fusion can do its job. This changes the cost model: the compiler emits one program, the dispatch overhead is amortized, and buffers for intermediates often disappear. For single-step microbenchmarks, JAX can still lag behind NumPy even with JIT, because there is little to fuse and the array a jitted function returns is always materialized: jit_move_bcast must still write all 16 million output elements, whereas the NumPy view never touches them. On realistic pipelines, JAX is often faster, including on CPU, but the reliable answer for a given workload comes from a fair benchmark of both.
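As a sketch of a fairer comparison, the example below times a small end-to-end pipeline instead of isolated reshapes: it applies the 4x4 homogeneous transform from the benchmark to a batch of 3D points, in NumPy and under jax.jit, compiling once before timing and calling block_until_ready. The function names, point count, and random inputs are illustrative assumptions, and the pipeline is deliberately written with the same broadcast_to/moveaxis steps as the benchmark rather than in its fastest possible form.

```python
import timeit

import jax
import jax.numpy as jnp
import numpy as np

def transform_np(mat, pts):
    # mat: (4, 4) homogeneous transform, pts: (N, 3) points
    n = pts.shape[0]
    homo = np.concatenate([pts, np.ones((n, 1), pts.dtype)], axis=1)         # (N, 4)
    mats = np.moveaxis(np.broadcast_to(mat[:, :, None], (4, 4, n)), 2, 0)  # (N, 4, 4) view
    out = np.einsum("nij,nj->ni", mats, homo)                               # (N, 4)
    return out[:, :3] / out[:, 3:]                                          # divide by w

@jax.jit
def transform_jit(mat, pts):
    # Same steps; shapes are static under jit, so n is a plain Python int
    n = pts.shape[0]
    homo = jnp.concatenate([pts, jnp.ones((n, 1), pts.dtype)], axis=1)
    mats = jnp.moveaxis(jnp.broadcast_to(mat[:, :, None], (4, 4, n)), 2, 0)
    out = jnp.einsum("nij,nj->ni", mats, homo)
    return out[:, :3] / out[:, 3:]

mat_np = np.array([[1, 0, 0, 0.5],
                   [0, 1, 0, 0],
                   [0, 0, 1, 0],
                   [0, 0, 0, 1]], dtype=np.float32)
pts_np = np.random.default_rng(0).random((1_000_000, 3), dtype=np.float32)
mat_jx, pts_jx = jnp.asarray(mat_np), jnp.asarray(pts_np)

transform_jit(mat_jx, pts_jx).block_until_ready()  # compile outside the timed region

t_np = timeit.timeit(lambda: transform_np(mat_np, pts_np), number=10)
t_jit = timeit.timeit(lambda: transform_jit(mat_jx, pts_jx).block_until_ready(), number=10)
print(f"NumPy pipeline:   {t_np:.4f} s")
print(f"JAX JIT pipeline: {t_jit:.4f} s")
```

Both sides use float32 so that the comparison is not skewed by NumPy's float64 default; when benchmarking your own pipeline, it is worth giving the NumPy version the same care, since a better NumPy formulation can change the outcome.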
If you want to inspect what happens to intermediates, you can look at the compiled representation, as shown above. Within a jitted function, asking whether a given Python variable is a view or a copy is not well-posed; the compiler may eliminate it entirely, reorder operations, or fuse them into larger kernels.
Why this matters
Confusing microbenchmark results can lead to the wrong architectural choices. Understanding that NumPy frequently returns views while JAX does not, and that JAX’s strengths show up only when it can fuse sequences under JIT, helps you benchmark fairly and design your pipeline to match the execution model. It also explains why isolated array layout tweaks look “slow” in JAX: they’re cheap views in NumPy, but real work in eager JAX, and not necessarily materialized inside JIT.
Takeaways
Use JAX for pipelines, not for one-off array shuffles. Expect isolated broadcast_to or moveaxis to be faster in NumPy due to views and very low dispatch overhead. Wrap end-to-end computations in jax.jit so the compiler can fuse operations, eliminate intermediates, and pay the dispatch cost once. If you need to reason about what truly happens to intermediates, inspect the compiled HLO of a jitted function rather than drawing conclusions from Python-level variables. For more background, the JAX FAQ covers this exact topic: https://docs.jax.dev/en/latest/faq.html#is-jax-faster-than-numpy.