2025, Oct 21 02:00

When Broadcasting Isn't Enough in JAX: Use vmap for Explicit, Faster, Safer Per-Sample Vectorization

Learn when to use JAX vmap over broadcasting: handle non-vectorized ops like jnp.histogram, keep single-sample logic, improve readability and performance.

JAX and vectorization can feel confusing at first glance. Many array operations already accept batched inputs through broadcasting, so why does jax.vmap exist at all? The short answer: not every operation is natively vectorized, and explicit vectorization is often clearer, safer, and faster than fiddling with dimensions or writing Python loops.

When “already vectorized” is not enough

Broadcasting and batch axes are pervasive in deep learning codebases, and they often cover the common cases. But there are operations in JAX that are not natively vectorized across a batch dimension. For those, you either reach for an explicit loop or you use jax.vmap to create a batched version of the function. This keeps the code readable and typically improves performance compared to Python-side loops.

Problem demonstration

Consider jnp.histogram. It does not accept a batch axis natively. A straightforward approach is to loop in Python and stack the results manually.

import jax
import jax.numpy as jnp

# Single-sample histogram helper
def single_hist(vec, nbins, span):
    # jnp.histogram returns (hist, bin_edges); we keep the histogram only
    counts, _ = jnp.histogram(vec, bins=nbins, range=span)
    return counts

# Toy batch of samples (2 samples, each with 3 values)
batch_samples = jnp.array([[0.0, 1.0, 2.0],
                           [2.0, 3.0, 1.0]])

num_bins = 4
value_range = (0.0, 4.0)

# Python-side batching (loop + stack)
looped = [single_hist(row, num_bins, value_range) for row in batch_samples]
batched_counts = jnp.stack(looped, axis=0)

This works, but it introduces a Python loop and extra ceremony around stacking. It is also easy to make mistakes when manually adding or juggling dimensions just to simulate a batch.

What is the core issue?

The root of the confusion is that “vectorized” can mean two different things: some primitives are implemented to naturally handle a batch dimension; others are scalar or single-sample by default and need to be batched. jnp.histogram and jnp.bincount belong to the latter group. In those cases, jax.vmap provides a clean path to express “apply this function independently over a batch” without changing the function’s single-sample semantics or contorting shapes.
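For instance, jnp.bincount behaves the same way: its single-sample output length depends on the data unless you pin it down. A minimal sketch, with illustrative label values and an assumed number of classes, could look like this:

import jax
import jax.numpy as jnp

# Toy label batch: 2 samples, 4 labels each (illustrative values)
labels = jnp.array([[0, 1, 1, 2],
                    [2, 2, 0, 3]])

# Pass an explicit length so the output shape stays static under vmap
per_sample_counts = jax.vmap(lambda row: jnp.bincount(row, length=4))(labels)
# per_sample_counts has shape (2, 4): one count vector per sample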

There is also a style dimension. Sometimes developers introduce a local extra dimension to avoid writing a loop, then reduce back later. While that works, the intent is often clearer when stated explicitly with jax.vmap.
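As a small illustration of the two styles, consider computing distances from each sample point to a set of anchor points. The broadcasting version inserts a temporary axis and reduces it away; the vmap version writes single-sample logic and maps it. The names and shapes here are made up for the sketch:

import jax
import jax.numpy as jnp

points = jnp.arange(12.0).reshape(4, 3)   # 4 points, 3 features each
anchors = jnp.ones((5, 3))                # 5 reference points

# Broadcasting style: add axes, subtract, then reduce over the feature axis
dists_broadcast = jnp.sqrt(((points[:, None, :] - anchors[None, :, :]) ** 2).sum(-1))

# vmap style: single-point logic, mapped over the leading axis
def dists_to_anchors(p):
    return jnp.sqrt(((p - anchors) ** 2).sum(-1))

dists_vmap = jax.vmap(dists_to_anchors)(points)   # same (4, 5) result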

Solution with jax.vmap

jax.vmap transforms a single-sample function into a batched one over a chosen axis. It works over PyTrees, which enables library-level conventions that hide batch handling entirely. For independent operations across samples, it is a natural fit.

import jax
import jax.numpy as jnp

# Same single-sample function as above
def single_hist(vec, nbins, span):
    counts, _ = jnp.histogram(vec, bins=nbins, range=span)
    return counts

batch_samples = jnp.array([[0.0, 1.0, 2.0],
                           [2.0, 3.0, 1.0]])

num_bins = 4
value_range = (0.0, 4.0)

# Vectorized version: apply over the leading axis of batch_samples
vmapped_hist = jax.vmap(lambda v: single_hist(v, num_bins, value_range))
counts_batched = vmapped_hist(batch_samples)

The vmapped version expresses the same logic as the looped implementation, but without the explicit loop and manual stacking. This is convenient: it removes the Python-side loop, which improves readability as well as performance.
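An equivalent formulation, continuing from the snippet above, uses in_axes to mark which arguments carry the batch axis instead of closing over the static settings in a lambda. This is only a stylistic variant:

# 0 maps the samples over their leading axis; None passes the settings through unchanged
vmapped_hist_axes = jax.vmap(single_hist, in_axes=(0, None, None))
counts_batched_axes = vmapped_hist_axes(batch_samples, num_bins, value_range)

Both versions produce the same (2, 4) array of counts as the looped implementation.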

Where vmap shines beyond the basics

vmap operates over PyTrees, so entire parameter structures can be handled without writing axis-handling code. Some libraries, such as equinox, embrace this convention and encourage vmapping across the whole parameter tree. This removes the need to manually thread batch axes through your model's code. The approach assumes independence across samples, and it will not work for operations that fundamentally mix information between samples, such as batch normalization.
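A minimal sketch of that single-sample convention in plain JAX, with a toy parameter tree and shapes assumed for illustration: the model function is written for one example, the parameters live in a PyTree, and vmap maps only the data axis while leaving the parameters unbatched.

import jax
import jax.numpy as jnp

params = {"w": jnp.ones((3, 2)), "b": jnp.zeros(2)}   # toy linear layer as a PyTree

def apply_single(params, x):
    # x has shape (3,): one sample, no batch axis anywhere in this function
    return x @ params["w"] + params["b"]

batch_x = jnp.ones((5, 3))   # 5 independent samples
# in_axes=(None, 0): broadcast the parameters, map over the leading axis of the data
batch_out = jax.vmap(apply_single, in_axes=(None, 0))(params, batch_x)   # shape (5, 2)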

In other scenarios, you might be tempted to add a temporary axis to force a broadcast, run an operation, and then reduce away the extra dimension. vmap often communicates that intent more directly. As a concrete intuition, think of cases like applying convolution2d with different kernels per sample. One way is to stack kernels, replicate and stack channels, and run a single convolution over the enlarged axis. Another way is to write a single-sample convolution and then vmap it across the kernel or sample axis. Both approaches can work; vmap simply states the per-sample independence explicitly.
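A hedged sketch of the per-sample-kernel case, using jax.scipy.signal.convolve2d as a stand-in for the convolution mentioned above (the shapes and mode are illustrative assumptions):

import jax
import jax.numpy as jnp
from jax.scipy.signal import convolve2d

images = jnp.ones((4, 8, 8))     # 4 samples
kernels = jnp.ones((4, 3, 3))    # a different kernel for each sample

def conv_single(img, ker):
    # Single-sample 2-D convolution; no batch axis in sight
    return convolve2d(img, ker, mode="same")

# Map over the leading axis of both the images and the kernels
per_sample = jax.vmap(conv_single)(images, kernels)   # shape (4, 8, 8)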

Why this matters

Batching strategy is not just a stylistic choice; it affects correctness, readability, and performance. Knowing when an operation is not natively vectorized prevents silent shape bugs and accidental Python loops that limit throughput. Using vmap where appropriate keeps single-sample logic intact, composes cleanly with PyTrees, and aligns with libraries that rely on vmap by convention.

Practical guidance

In day-to-day code, broadcasting and batch axes are a solid default. Reach for jax.vmap when you hit functions without native vectorization, when a library’s design encourages vmapping over PyTrees, or when you need to vectorize along non-conventional axes. If you find yourself introducing ad hoc dimensions and reductions to dodge a loop, consider whether a vmap would make the intent clearer and the code easier to maintain.
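For the non-conventional-axis case, in_axes selects which axis of an input carries the batch. A tiny sketch with made-up data:

import jax
import jax.numpy as jnp

data = jnp.arange(12.0).reshape(3, 4)            # batch runs along axis 1
col_sums = jax.vmap(jnp.sum, in_axes=1)(data)    # shape (4,): one sum per column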

There is no universal rule, and personal preference plays a role. The key is to recognize the dividing line: use native batching when it exists and is idiomatic, and use vmap to express independent per-sample computation when it does not. Keeping that distinction in mind leads to simpler, more robust JAX code.

The article is based on a question from StackOverflow by Mingruifu Lin and an answer by Axel Donath.