2025, Dec 10 11:00

Python for Loop vs jax.lax.scan in JAX: compile-time vs runtime, JIT trade-offs and unroll tips

Learn when to use a Python for loop or jax.lax.scan in JAX. Compare JIT compile time vs runtime, tune with unroll, and benchmark to maximize performance.

Choosing between a Python for loop and jax.lax.scan in JAX is not just a stylistic decision. It affects compilation behavior, runtime performance, and the way the compiler can optimize your program. If you’ve ever flipped a loop back and forth and wondered why performance moves in non‑intuitive ways, this guide is for you.

Problem in context

Consider a nested training setup in which an outer training loop calls a function that itself iterates several times and performs a rollout via scan. For this kind of workload, it's natural to ask whether the inner loop should remain a Python for loop or be refactored into a scan, and how that choice scales as the surrounding code repeats it many times.

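import jax
num_train_steps, num_env_steps, num_algo_iters, inner_scan_len = 2, 100, 5, 2  # illustrative sizes
def rollout_step(carry, _):  # placeholder rollout body: carry in, carry out
  return carry + 1.0, None
@jax.jit
def execute(carry):
  for c in range(num_algo_iters):  # loop c: a Python for, unrolled during tracing
    carry, _ = jax.lax.scan(rollout_step, carry, xs=None, length=inner_scan_len)
  return carry  # return the result so the compiler cannot drop the work as dead code
carry = 0.0
for a in range(num_train_steps):
  for b in range(num_env_steps):
    carry = execute(carry)  # traced and compiled on the first call, cached afterwards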

In one experiment, execute was jitted and num_env_steps was varied over 1, 100, 1000 and 10000. The c loop was written either as a Python for or as a scan, and the innermost rollout stayed a scan in both variants. The c loop ran 5 iterations and the innermost scan ran 2. The measured timings were approximately 1.5, 11.3, 99.0 and 956.2 seconds with c as a scan, and 5.1, 14.5, 103.6 and 972.7 seconds with c as a for. The for version never overtook the scan version as the outer repetition count grew. It was about 3.6 seconds slower at a single repetition, where one-off tracing and compilation likely dominate, and about 16.5 seconds slower at 10000 repetitions.

Why this happens

JAX does not compile a Python for as a loop. Inside a jitted function, the for runs during tracing, so each iteration is recorded separately and the loop is effectively unrolled. A for with 100 iterations therefore becomes a single straight-line program carrying 100 copies of the body. The upside is that the compiler can optimize across iterations, for example by fusing ops between neighboring steps or by dropping entire subgraphs whose outputs are unused. The downside is that compilation cost grows super‑linearly with program size, so large bodies and many iterations can lead to long compile times.
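You can see the unrolling in the jaxpr, the intermediate program JAX records while tracing. The sketch below is illustrative rather than taken from the experiment above. It uses a toy lax.sin body and an arbitrary N. The helper count_prims is hypothetical: it tallies the top-level primitives in the program returned by jax.make_jaxpr.

```python
import jax
import jax.numpy as jnp
from jax import lax

N = 100  # number of loop iterations (illustrative)

def loop_for(x):
  # Python runs this loop while tracing, so every iteration is recorded.
  for _ in range(N):
    x = lax.sin(x)
  return x

def loop_scan(x):
  # scan traces the body once and records a single loop primitive.
  def body(carry, _):
    return lax.sin(carry), None
  return lax.scan(body, x, xs=None, length=N)[0]

def count_prims(fn, x):
  """Count top-level primitives by name in fn's traced program (hypothetical helper)."""
  counts = {}
  for eqn in jax.make_jaxpr(fn)(x).jaxpr.eqns:
    counts[eqn.primitive.name] = counts.get(eqn.primitive.name, 0) + 1
  return counts

x = jnp.ones(4)
print(count_prims(loop_for, x))   # one 'sin' entry per iteration
print(count_prims(loop_scan, x))  # a single 'scan' entry, no top-level 'sin'
```

The for version holds N sin operations before XLA even starts optimizing. The scan version stays a single scan primitive however large N gets.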

With jax.lax.scan or jax.lax.fori_loop, the loop is represented as a single loop construct in the HLO program. The body is traced and compiled once, which makes compilation much more efficient. The tradeoff is that the compiler has fewer degrees of freedom to optimize across iterations, so you may leave some runtime performance on the table compared with a fully unrolled for.
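For a loop with a fixed trip count, jax.lax.fori_loop(lower, upper, body_fun, init_val) is the most direct replacement for a Python for. Its body_fun receives the loop index and the carry. The minimal sketch below assumes a hypothetical step function. It checks that the unrolled for, fori_loop and scan forms compute the same values, since only their compiled representation should differ.

```python
import jax
import jax.numpy as jnp
from jax import lax

N = 50  # illustrative trip count

def step(x):
  # Hypothetical loop body: any pure function of the carry works here.
  return jnp.sin(x) + 0.1 * x

@jax.jit
def with_for(x):
  for _ in range(N):  # unrolled: N copies of step in the traced program
    x = step(x)
  return x

@jax.jit
def with_fori(x):
  # fori_loop's body takes (index, carry); the index is unused here.
  return lax.fori_loop(0, N, lambda i, c: step(c), x)

@jax.jit
def with_scan(x):
  return lax.scan(lambda c, _: (step(c), None), x, xs=None, length=N)[0]

x = jnp.linspace(0.0, 1.0, 8)
a, b, c = with_for(x), with_fori(x), with_scan(x)
print(bool(jnp.allclose(a, b)), bool(jnp.allclose(a, c)))
```

The comparison uses allclose rather than exact equality because the unrolled and rolled programs may be fused differently, which can change the last few bits of floating-point results.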

There is no single winner. Smaller bodies with fewer iterations often favor for because the compiler can be aggressive across iterations without exploding compile time. Larger bodies or many iterations often favor scan or fori_loop because they keep compilation manageable.

It's also useful to calibrate expectations about complexity. With Python for loops, expect compile time to grow super‑linearly as the unrolled program grows. Runtime may or may not scale sublinearly, depending on the specific operations and the compiler's heuristics. With scan, compilation is generally far more efficient because the body is compiled once and the loop is represented in HLO. At runtime, however, the program may be optimized less aggressively across iterations than a fully unrolled one.
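To see the compile-time side of this trade-off on your own hardware, you can time tracing, lowering and compilation directly with JAX's ahead-of-time API, jax.jit(f).lower(x).compile(), which compiles without executing. This micro-benchmark is a sketch under stated assumptions. The body, the sizes and the helper names (make_for, make_scan, compile_seconds, sweep) are illustrative, and absolute numbers will vary with backend and JAX version.

```python
import time
import jax
import jax.numpy as jnp
from jax import lax

def body(x):
  # Illustrative loop body: a few elementwise ops on the carry.
  return jnp.tanh(x) * 0.5 + x

def make_for(n):
  def f(x):
    for _ in range(n):  # unrolled during tracing: program size grows with n
      x = body(x)
    return x
  return jax.jit(f)

def make_scan(n):
  def f(x):
    # One copy of body inside a rolled loop of length n
    return lax.scan(lambda c, _: (body(c), None), x, xs=None, length=n)[0]
  return jax.jit(f)

def compile_seconds(jitted, x):
  """Wall-clock seconds to trace, lower and compile, without executing."""
  t0 = time.perf_counter()
  jitted.lower(x).compile()
  return time.perf_counter() - t0

def sweep(sizes=(10, 100, 300)):
  x = jnp.ones((64, 64))
  return {n: (compile_seconds(make_for(n), x), compile_seconds(make_scan(n), x))
          for n in sizes}

for n, (t_for, t_scan) in sweep().items():
  print(f"n={n:4d}  for: {t_for:.3f}s  scan: {t_scan:.3f}s")
```

Each make_for(n) and make_scan(n) call creates a new jitted function, so every measurement is a fresh compilation rather than a cache hit.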

A minimal, equivalent setup

The following two versions keep the same core logic but differ in how the c loop is expressed. The innermost rollout remains a scan in both cases.

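from functools import partial
import jax
import jax.numpy as jnp
from jax import lax
# Dummy rollout body to keep semantics clear
# carry in, carry out
def rollout_core(carry, _):
  return carry + 1.0, None
# Loop counts are marked static: range() and scan's length need Python ints, not
# traced values, and each new (algo_iters, inner_len) pair triggers a fresh compile.
# Version A: inner loop c as a Python for (unrolled during tracing)
@partial(jax.jit, static_argnames=("algo_iters", "inner_len"))
def run_step_for(init_state, algo_iters, inner_len):
  carry = init_state
  for c in range(algo_iters):
    carry, _ = lax.scan(rollout_core, carry, xs=None, length=inner_len)
  return carry
# Version B: inner loop c as a scan (c_body is traced once)
@partial(jax.jit, static_argnames=("algo_iters", "inner_len"))
def run_step_scan(init_state, algo_iters, inner_len):
  def c_body(carry, _):
    # The inner scan already returns (new_carry, None), the pair c_body must return
    return lax.scan(rollout_core, carry, xs=None, length=inner_len)
  return lax.scan(c_body, init_state, xs=None, length=algo_iters)[0]
# Usage: a float32 carry keeps carry + 1.0 from changing the dtype across iterations
x0 = jnp.float32(0.0)
print(run_step_for(x0, algo_iters=5, inner_len=2))   # 10.0
print(run_step_scan(x0, algo_iters=5, inner_len=2))  # 10.0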

Both versions produce the same final state for a given init_state, algo_iters and inner_len, but one uses a Python for and the other uses scan for the c loop.

Tuning the tradeoff with unroll

scan has an unroll parameter that lets you dial between the two extremes. The default, unroll=1, keeps the loop fully rolled. Setting unroll=True fully unrolls it, so scan behaves more‑or‑less like a Python for. You can also partially unroll by passing an integer n with 1 < n < length, which places n copies of the body inside each iteration of the rolled loop. This exposes more cross-iteration optimization opportunities while keeping compilation under control.

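from functools import partial
@partial(jax.jit, static_argnames=("algo_iters", "inner_len", "outer_unroll", "inner_unroll"))
def run_step_scan_tuned(init_state, algo_iters, inner_len, outer_unroll=1, inner_unroll=1):
  # unroll takes a positive int or a bool and, like the lengths, must be static
  def c_body(carry, _):
    return lax.scan(rollout_core, carry, xs=None, length=inner_len, unroll=inner_unroll)
  return lax.scan(c_body, init_state, xs=None, length=algo_iters, unroll=outer_unroll)[0]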

With unroll=True at a given level, expect behavior closer to a Python for at that level, including longer compile times. With the default unroll=1, compilation is typically much more efficient.
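Because unroll only changes how the loop is lowered, it should not change results. It should change only the size of the program XLA has to compile. The sketch below checks both. It assumes a toy sin body and an arbitrary LENGTH, and it uses the length of the lowered StableHLO text from .lower(...).as_text() as a rough proxy for program size. The example passes unroll=LENGTH instead of True for full unrolling, because True == 1 in Python and the two values could collide as static arguments.

```python
from functools import partial
import jax
import jax.numpy as jnp
from jax import lax

LENGTH = 64  # illustrative scan length

@partial(jax.jit, static_argnames=("unroll",))
def rollout(x, unroll):
  def body(carry, _):
    return jnp.sin(carry) + 0.1 * carry, None
  return lax.scan(body, x, xs=None, length=LENGTH, unroll=unroll)[0]

x = jnp.linspace(0.0, 1.0, 8)
reference = rollout(x, unroll=1)

def program_size(unroll):
  """Characters of lowered StableHLO text: a rough proxy for what XLA must compile."""
  return len(rollout.lower(x, unroll=unroll).as_text())

for unroll in (1, 4, 16, LENGTH):  # unroll=LENGTH fully unrolls the loop
  out = rollout(x, unroll=unroll)
  print(f"unroll={unroll:3d}  allclose={bool(jnp.allclose(out, reference))}  "
        f"size={program_size(unroll)}")
```

You should see the lowered program grow with unroll while the results agree to floating-point tolerance.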

So which should you use?

The best option depends on your program and priorities. If the loop body is small and the iteration count is modest, a for inside jit often performs well due to cross‑iteration optimizations. If the loop body is larger or the iteration count is high, scan or fori_loop usually compiles much faster and may be the better choice overall. There is no guarantee that pushing repeats to 100k or a million will flip the result in favor of for; the outcome depends on both compilation and runtime characteristics of your specific workload.

Setting unroll=True in scan makes it essentially equivalent to a for in terms of how the compiler sees it. That means you should expect the same kind of super‑linear growth in compilation effort, along with the possibility of runtime benefits from broader optimization. It will not universally yield speed‑ups; it simply moves you toward the unrolled end of the spectrum.

Benchmarking correctly matters

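When timing JAX code, make sure you benchmark properly. JAX dispatches work asynchronously, so a call can return before the device has finished computing. Block on the result so you measure actual computation time rather than just the time to enqueue work; calling .block_until_ready() on the result is the usual way to do this. Also remember that the first call to a jitted function includes tracing and compilation, so time it separately from steady-state calls. See also the official guidance: JAX FAQ on benchmarking.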
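A minimal timing harness along these lines separates the first call, which includes tracing and compilation, from steady-state execution. The workload (rollout), the repeat count and the helper name benchmark are illustrative assumptions. For functions that return a pytree rather than a single array, jax.block_until_ready(out) blocks on every leaf.

```python
import statistics
import time
from functools import partial
import jax
import jax.numpy as jnp
from jax import lax

@partial(jax.jit, static_argnames=("length",))
def rollout(x, length):
  # Toy workload standing in for a real jitted step (illustrative only).
  def body(carry, _):
    return jnp.sin(carry) + 0.1 * carry, None
  return lax.scan(body, x, xs=None, length=length)[0]

def benchmark(fn, *args, repeats=20, **kwargs):
  """Return (first_call_seconds, median_steady_state_seconds) for a jitted fn."""
  t0 = time.perf_counter()
  fn(*args, **kwargs).block_until_ready()  # first call: trace + compile + run
  first_call = time.perf_counter() - t0
  samples = []
  for _ in range(repeats):
    t0 = time.perf_counter()
    fn(*args, **kwargs).block_until_ready()  # wait for the device, not just dispatch
    samples.append(time.perf_counter() - t0)
  return first_call, statistics.median(samples)

x = jnp.ones((1024,))
first, steady = benchmark(rollout, x, length=1000)
print(f"first call (incl. compile): {first:.4f}s  steady state: {steady * 1e6:.1f}us")
```

Taking the median of several steady-state samples reduces noise from other processes. The gap between the first-call and steady-state numbers gives a rough sense of how much of a short run is compile time.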

Why this is worth knowing

Loop representation directly shapes the compiler’s search space. Leaving a loop in Python gives the compiler the chance to fuse and prune across iterations at the cost of compile time that grows faster than linearly. Pushing the loop into HLO makes compilation efficient and predictable, at the cost of fewer cross‑iteration transformations. Understanding this spectrum helps you make principled choices rather than blanket‑replacing for with scan or vice versa.

Practical guidance

Start by clarifying whether compilation latency or steady‑state runtime is more important for your application. For small, short loops, a Python for within jit is often fine. For large bodies or many iterations, prefer scan or fori_loop. If you need to tune, use scan’s unroll to balance compilation cost against potential runtime optimizations. And always benchmark with proper synchronization to get trustworthy numbers.

Closing thoughts

There is no one‑size‑fits‑all rule that makes for categorically faster than scan or the other way around. Treat the choice as a tradeoff between compilation cost and runtime opportunities. Use scan for scalability, use for for aggressive cross‑iteration optimization on smaller problems, and reach for unroll when you need a middle ground. Measure with .block_until_ready() and the JAX benchmarking best practices, and let those measurements guide the decision in your codebase.