2025, Dec 03 07:00

Efficient JAX Image Reconstruction from Patch Grids Using Reshape and Transpose (No Python Loops)

Learn how to reconstruct images from patch arrays in JAX using only reshape and transpose, with no Python loops: faster, XLA-friendly code, plus jit tips and the shape algebra behind it.

Rebuilding images from patch arrays is a common task in JAX pipelines. The naive approach, nested Python loops, produces the right result but leaves performance on the table, since every element is copied by a separate indexing operation dispatched from Python, and it can lead to memory issues on accelerators. The good news: when the patches sit on a regular, non-overlapping grid, you can reassemble the original tensor using only transpose and reshape.

Problem setup

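We start from a batch of images and cut each one into patches. The reconstruction below is correct but inefficient: it loops in pure Python over the grid, channel, and in-patch dimensions, copying one element per iteration. It also assumes a batch of a single image, because the reshape into the patch grid drops the batch axis.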

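import numpy as np
import jax
import jax.numpy as jnp
from jax import lax

# Illustrative sizes (assumed; the original does not state them)
N_CH, IMG_HEIGHT, IMG_WIDTH = 3, 32, 32
PH, PW = 8, 8                           # patch height and width
PATCH_VEC = N_CH * PH * PW              # features per flattened patch
V_SPLITS = IMG_HEIGHT // PH             # patches along the height
H_SPLITS = IMG_WIDTH // PW              # patches along the width
EXP_SHAPE = (1, N_CH, IMG_HEIGHT, IMG_WIDTH)

# Not used below; extract_tiles does not need an explicit kernel
patch_kernel = jnp.ones((PATCH_VEC, N_CH, PH, PW), dtype=jnp.float32)

def extract_tiles(x):
    # (B, C, H, W) -> (B, C*PH*PW, V_SPLITS, H_SPLITS); each patch is flattened in (C, PH, PW) order
    tile_grid = lax.conv_general_dilated_patches(
        x, (PH, PW), (PH, PW), padding='VALID'
    )
    # Move the flattened-patch axis last: (B, V_SPLITS, H_SPLITS, C*PH*PW)
    return jnp.transpose(tile_grid, [0, 2, 3, 1])

# bfrc is a batch of images shaped (batch, channels, height, width); batch size 1 here
bfrc = jax.random.normal(jax.random.PRNGKey(0), EXP_SHAPE)  # illustrative input
tile_buf = extract_tiles(bfrc)

# Both names refer to the same array; only tiles_vh_c_ph_pw is used below.
# The reshape drops the batch axis, so it only works for a batch of one image.
tiles_vh_c_ph_pw = tiles_alias = jnp.reshape(
    tile_buf, (V_SPLITS, H_SPLITS, N_CH, PH, PW)
)

# A NumPy buffer, because JAX arrays are immutable and do not support item assignment
recon_img = np.zeros(EXP_SHAPE)

for vi in range(tiles_vh_c_ph_pw.shape[0]):
    for hi in range(tiles_vh_c_ph_pw.shape[1]):
        for ch_ix in range(tiles_vh_c_ph_pw.shape[2]):
            for pr in range(tiles_vh_c_ph_pw.shape[3]):
                for pc in range(tiles_vh_c_ph_pw.shape[4]):
                    r_idx = vi * PH + pr
                    c_idx = hi * PW + pc
                    recon_img[0, ch_ix, r_idx, c_idx] = tiles_vh_c_ph_pw[vi, hi, ch_ix, pr, pc]

# This assert passes
assert jnp.max(jnp.abs(recon_img - bfrc[0])) == 0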

What’s really going on

The patch tensor is a regular grid with shape (V_SPLITS, H_SPLITS, N_CH, PH, PW). Because the patches tile the image without overlap, folding them back is a pure layout operation: no arithmetic is needed, only reordering axes and merging or splitting dimensions. The reshape into this grid is valid because conv_general_dilated_patches flattens each patch with the channel axis outermost, in (C, PH, PW) order. The line that binds two names to the same reshaped tensor looks like a naming slip, but it does not change behavior.
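
To see that patch extraction itself is just a change of layout, the sketch below builds the same patch tensor as extract_tiles using only reshape and transpose, and compares it with the conv_general_dilated_patches output. The image and patch sizes are illustrative assumptions, and passing precision=lax.Precision.HIGHEST is our own addition, meant to keep the identity-kernel convolution exact on GPUs that default to reduced-precision (TF32) matmuls.

```python
import jax
import jax.numpy as jnp
from jax import lax

# Illustrative sizes (assumed): one 3-channel 16x24 image cut into 4x8 patches
N_CH, IMG_HEIGHT, IMG_WIDTH = 3, 16, 24
PH, PW = 4, 8
V_SPLITS, H_SPLITS = IMG_HEIGHT // PH, IMG_WIDTH // PW

x = jax.random.normal(jax.random.PRNGKey(0), (1, N_CH, IMG_HEIGHT, IMG_WIDTH))

# Patches via convolution, as in extract_tiles: (1, C*PH*PW, V, H) -> (1, V, H, C*PH*PW)
conv_tiles = lax.conv_general_dilated_patches(
    x, (PH, PW), (PH, PW), padding="VALID", precision=lax.Precision.HIGHEST
)
conv_tiles = jnp.transpose(conv_tiles, (0, 2, 3, 1))

# Same patches with layout ops only: split height into (V, PH) and width into (H, PW)
split = x.reshape(1, N_CH, V_SPLITS, PH, H_SPLITS, PW)
# Put the grid axes first and the (C, PH, PW) patch axes last: (1, V, H, C, PH, PW)
grid_first = jnp.transpose(split, (0, 2, 4, 1, 3, 5))
# Flatten each patch in (C, PH, PW) order
layout_tiles = grid_first.reshape(1, V_SPLITS, H_SPLITS, N_CH * PH * PW)

print(bool(jnp.array_equal(conv_tiles, layout_tiles)))
```

If the two tensors match, reconstruction is simply the inverse permutation and reshape, which is what the rest of the article builds.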

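Assuming the input batch has shape (batch, channels, height, width), that IMG_HEIGHT is divisible by PH and IMG_WIDTH by PW (VALID padding silently drops any remainder), and that V_SPLITS = IMG_HEIGHT // PH and H_SPLITS = IMG_WIDTH // PW, reconstruction reduces to a sequence of reshape and transpose steps. The loop's index arithmetic, r_idx = vi * PH + pr and c_idx = hi * PW + pc, tells you the target layout: in row-major order, each image axis is a (grid index, in-patch index) pair with the grid index varying slowest. So before merging, the axes must be interleaved as (V_SPLITS, PH) for rows and (H_SPLITS, PW) for columns; where the channel axis sits is a separate choice.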
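
One way to check this index arithmetic is to express the loop's scatter as a gather: build index arrays for every (vi, hi, ch, pr, pc) position, read the reassembled image at (ch, vi * PH + pr, hi * PW + pc), and compare the result with the tile grid. The toy sizes below are assumptions chosen for readability, and the single transpose to (C, V, PH, H, PW) is an equivalent alternative to the article's channels-last route.

```python
import jax.numpy as jnp

# Toy sizes (assumed): a 2x3 grid of 2x2 patches with 2 channels
V, H, C, PH, PW = 2, 3, 2, 2, 2
tiles = jnp.arange(V * H * C * PH * PW).reshape(V, H, C, PH, PW)

# Interleave grid and in-patch axes, grid index first: (C, V, PH, H, PW), then merge
image = jnp.transpose(tiles, (2, 0, 3, 1, 4)).reshape(C, V * PH, H * PW)

# Index arrays for every (vi, hi, ch, pr, pc) position of the tile grid
vi, hi, ci, pr, pc = jnp.indices((V, H, C, PH, PW))

# Read the image with the loop's formulas: row = vi*PH + pr, col = hi*PW + pc
gathered = image[ci, vi * PH + pr, hi * PW + pc]

print(bool(jnp.array_equal(gathered, tiles)))
```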

Efficient JAX solution

Below is a layout-only reconstruction that avoids Python loops and uses JAX array operations.

# tiles_vh_c_ph_pw has shape (V_SPLITS, H_SPLITS, N_CH, PH, PW)
v_bins, h_bins, n_chan, p_h, p_w = tiles_vh_c_ph_pw.shape

full_h = v_bins * p_h
full_w = h_bins * p_w

# Move channels to the last axis inside each patch block: (V, H, PH, PW, C)
ordered = jnp.transpose(tiles_vh_c_ph_pw, (0, 1, 3, 4, 2))

# Interleave grid and in-patch axes so each image axis is (grid index, in-patch index):
# (V, H, PH, PW, C) -> (V, PH, H, PW, C)
restored = ordered.transpose(0, 2, 1, 3, 4)
# Merge (V, PH) into rows and (H, PW) into columns: (full_h, full_w, C)
restored = restored.reshape(full_h, full_w, n_chan)

# Channels first plus a leading batch axis. Final shape: (1, C, H, W)
reconstructed_batch = jnp.transpose(restored, (2, 0, 1))[jnp.newaxis, ...]

# Pure data movement, so the result matches the input exactly
assert jnp.array_equal(reconstructed_batch, bfrc)

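This approach is much cheaper than the nested loops because it issues a handful of compiled JAX array operations instead of one Python-level indexing operation per element. Wrapping it in @jax.jit additionally lets XLA compile the whole chain as one program, which usually makes it faster still.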
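
A minimal jit-compiled sketch, assuming the same (V_SPLITS, H_SPLITS, N_CH, PH, PW) input layout. tiles_to_image is a hypothetical name, and it folds the article's steps into a single transpose to (C, V, PH, H, PW) followed by one reshape. The check at the end compares it with a small loop-based reference that copies one whole patch per iteration.

```python
import jax
import jax.numpy as jnp
import numpy as np

@jax.jit
def tiles_to_image(tiles):
    # Hypothetical helper: (V, H, C, PH, PW) patch grid -> (1, C, V*PH, H*PW) image
    v_bins, h_bins, n_chan, p_h, p_w = tiles.shape  # shapes are static under jit
    image = jnp.transpose(tiles, (2, 0, 3, 1, 4))   # (C, V, PH, H, PW)
    image = image.reshape(n_chan, v_bins * p_h, h_bins * p_w)
    return image[jnp.newaxis]

# Small illustrative grid (assumed sizes)
V, H, C, PH, PW = 2, 3, 3, 4, 5
tiles = jax.random.normal(jax.random.PRNGKey(1), (V, H, C, PH, PW))

# Loop reference with the article's index mapping, one patch block at a time
tiles_np = np.asarray(tiles)
ref = np.zeros((1, C, V * PH, H * PW), dtype=tiles_np.dtype)
for vi in range(V):
    for hi in range(H):
        ref[0, :, vi * PH:(vi + 1) * PH, hi * PW:(hi + 1) * PW] = tiles_np[vi, hi]

out = tiles_to_image(tiles)
print(out.shape, np.array_equal(np.asarray(out), ref))
```

Because array shapes are static under jax.jit, the shape unpacking inside the function is ordinary Python evaluated at trace time; calling it with a different grid or patch size triggers a recompile.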

Why this matters

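Eliminating Python-side loops lets XLA see the whole chain of reshapes and transposes and run it on the device as a few data-movement operations, often a single fused copy, instead of dispatching one small indexing operation per element from Python. That cuts overhead and improves memory behavior. In practice, wrapping the reconstruction in @jax.jit may also be necessary to avoid out-of-memory errors on GPU: executed eagerly, each intermediate step can allocate its own output buffer, whereas under jit XLA can fuse the steps and skip most intermediates.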

Takeaways

When you slice images into a regular patch grid and need to reassemble them, think in terms of shape algebra: transpose to gather axes that belong together, then reshape to collapse or expand them into the desired image layout. Verify the exact tensor shapes you have—especially the channel position and the grid dimensions—and prefer compiled JAX transformations to Python loops. Keeping shapes and constants explicit also makes the example reproducible and easier to reason about.
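
The article's reshape drops the batch axis, so it only handles one image at a time. Following the advice to verify shapes explicitly, here is a sketch that keeps the batch axis and checks the per-patch feature count before reshaping. make_tiles and untile_batch are hypothetical helpers: make_tiles is a variant of extract_tiles that takes the patch size as arguments (with precision=lax.Precision.HIGHEST added as an assumption for exactness on GPU), and the input sizes are illustrative.

```python
import jax
import jax.numpy as jnp
from jax import lax

def make_tiles(x, ph, pw):
    # Variant of extract_tiles: (B, C, H, W) -> (B, V, H', C*ph*pw)
    t = lax.conv_general_dilated_patches(
        x, (ph, pw), (ph, pw), padding="VALID", precision=lax.Precision.HIGHEST
    )
    return jnp.transpose(t, (0, 2, 3, 1))

def untile_batch(tile_buf, n_ch, ph, pw):
    # Hypothetical helper: (B, V, H', C*ph*pw) -> (B, C, V*ph, H'*pw)
    b, v, h, feat = tile_buf.shape
    if feat != n_ch * ph * pw:
        raise ValueError(f"expected {n_ch * ph * pw} features per patch, got {feat}")
    grid = tile_buf.reshape(b, v, h, n_ch, ph, pw)  # features are ordered (C, ph, pw)
    grid = jnp.transpose(grid, (0, 3, 1, 4, 2, 5))  # (B, C, V, ph, H', pw)
    return grid.reshape(b, n_ch, v * ph, h * pw)

# Channel count and patch size must be static so the shape check runs at trace time
untile = jax.jit(untile_batch, static_argnums=(1, 2, 3))

x = jax.random.normal(jax.random.PRNGKey(2), (4, 3, 32, 48))  # illustrative batch of 4
rebuilt = untile(make_tiles(x, 8, 16), 3, 8, 16)
print(rebuilt.shape, bool(jnp.array_equal(rebuilt, x)))
```

Raising during tracing means a wrong channel count or patch size fails loudly instead of silently producing a scrambled image of the right total size.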