2025, Dec 09 13:00

Why switching from Python lists to jnp.array changes JAX optimization: weakly typed scalars, dtype, and reproducible training

Learn how Python lists vs jnp.array trigger dtype and precision shifts in JAX. Fix weakly typed scalars for consistent Flax/Optax training and reproducible runs.

When Python lists meet JAX arrays: why a tiny type detail derails your optimization

Switching a preprocessing function from returning a list of lists to a list of jnp.array can look like a harmless refactor. Yet in JAX-based training loops, that subtle change can materially alter the outcome. In one RNN optimization workflow using Flax and Optax, the same pipeline produced a metric of about 0.9997 when parameters were prepared as a list of lists, but only around 0.998 when the same parameters were prepared as a list of jnp.array. Everything else was held constant: seeds, steps, iterations, and the initial dictionary of parameters.

This guide explains why that happens, how JAX’s weakly-typed scalars can nudge you into different precision regimes, and what to do so your training behaves consistently.

Minimal setup that triggers the discrepancy

The pipeline starts from a nested dict of scalar parameters. That dict is flattened at each time step and fed into an RNN-driven optimization loop. Here is an example input structure:

import jax.numpy as jnp

initial_params = {
    "param1": {
        "gamma": 0.1,
        "delta": -3 * jnp.pi / 2,
    }
}

The preprocessing function was implemented in two ways. In one version it returns a list of lists; in the other it returns a list of jnp.array. The core logic is identical; only the container type differs.

Version that returns Python lists of leaves:

import jax
import jax.numpy as jnp

def pack_params_from_mapping(cfg_tree):
    """
    Convert a nested mapping of parameters to a flat Python list and record segment lengths.

    Args:
        cfg_tree: Nested mapping of parameters.

    Returns:
        tuple: list of lists of leaves, and list of segment lengths.
    """
    packed = []
    seg_lengths = []
    for branch in cfg_tree.values():
        leaf_vals = jax.tree_util.tree_leaves(branch)
        packed.append(leaf_vals)
        seg_lengths.append(len(leaf_vals))
    return packed, seg_lengths

Version that returns JAX arrays of leaves:

import jax
import jax.numpy as jnp

def pack_params_from_mapping(cfg_tree):
    """
    Convert a nested mapping of parameters to a flat list of jnp.array and record segment lengths.

    Args:
        cfg_tree: Nested mapping of parameters.

    Returns:
        tuple: list of jnp.array leaves, and list of segment lengths.
    """
    packed = []
    seg_lengths = []
    for branch in cfg_tree.values():
        leaf_vals = jax.tree_util.tree_leaves(branch)
        arr = jnp.array(leaf_vals)
        packed.append(arr)
        seg_lengths.append(arr.shape[0])
    return packed, seg_lengths
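
A quick way to see the difference is to call either variant on the example initial_params above and inspect what comes back. A minimal sketch, assuming whichever pack_params_from_mapping variant you are testing is the one currently in scope:

packed, seg_lengths = pack_params_from_mapping(initial_params)

print(seg_lengths)       # [2]: one branch ("param1") with two leaves
print(type(packed[0]))   # a list for the first variant, a JAX array for the second

# The list variant keeps plain Python floats, which JAX later treats as
# weakly typed; the jnp.array variant carries an explicit dtype
# (float32 by default, float64 when jax_enable_x64 is enabled).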

What is really happening: weakly-typed scalars and precision

The operative difference between the two approaches is not the container itself, but the dtypes that flow into subsequent computations. Python floats inside lists are treated by JAX as weakly-typed values. In practice, weakly-typed scalars defer their dtype to whatever they are combined with. This can silently pull computations toward lower precision in mixed-type expressions.

Consider this short, self-contained illustration. With 64-bit enabled, a list of Python floats remains weakly typed, whereas an explicit array fixes the dtype and leads to higher-precision computation downstream.

import jax
import jax.numpy as jnp

jax.config.update('jax_enable_x64', True)

val_list = [0.1, -4.71238898]      # Python floats: weakly typed in JAX ops
val_array = jnp.array(val_list)    # strongly typed float64 array (x64 enabled)

x32 = jnp.array(1.0, dtype=jnp.float32)

# Mixing a float32 array with a Python float (weak type)
# keeps the result in float32.
print((x32 + val_list[1]).dtype)   # float32

# Mixing a float32 array with a jnp.array element (strict type)
# promotes the result to float64 here.
print((x32 + val_array[1]).dtype)  # float64

In other words, using lists of Python floats can push parts of your computation into float32, while arrays of the same values can anchor them in float64. In an optimization loop, these small differences accumulate and can lead to noticeably different objective values, even when everything else is deterministic.
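
To get a feel for how such a regime difference compounds, here is a toy sketch. The update rule is hypothetical and has nothing to do with the RNN pipeline; it only shows that the dtype of the starting value decides the precision of every intermediate step, and that the float32 and float64 paths drift apart over many iterations:

import jax
import jax.numpy as jnp

jax.config.update('jax_enable_x64', True)

def toy_loop(x):
    # Repeat the same tiny update; the dtype of x fixes the precision
    # of every intermediate value.
    for _ in range(1_000):
        x = x * 1.0000001 + 1e-7
    return x

low = toy_loop(jnp.array(1.0, dtype=jnp.float32))   # stays in float32
high = toy_loop(jnp.array(1.0, dtype=jnp.float64))  # stays in float64
print(float(high) - float(low))  # typically a small but nonzero drift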

How to make the behavior consistent

The practical fix is to make dtypes explicit at the boundaries of your data pipeline and avoid mixing weakly-typed Python floats with JAX arrays. Based on the behavior above, there are two straightforward ways to do that.

One option is to convert to jnp.array and choose an explicit dtype for the leaves as soon as you flatten the parameter structure. That makes all downstream math operate at a known precision:

import jax
import jax.numpy as jnp

jax.config.update('jax_enable_x64', True)

def pack_params_from_mapping(cfg_tree):
    packed = []
    seg_lengths = []
    for branch in cfg_tree.values():
        leaf_vals = jax.tree_util.tree_leaves(branch)
        # Anchor the leaves at an explicit dtype so downstream math stays in float64.
        arr = jnp.array(leaf_vals, dtype=jnp.float64)
        packed.append(arr)
        seg_lengths.append(arr.shape[0])
    return packed, seg_lengths

Another lever is to initialize your scalars as jnp.float64 (or your chosen dtype) right in the input dictionary. That keeps precision consistent even if you keep the list-of-lists shape later. You can also enable 64-bit mode so JAX defaults to float64 like NumPy.
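
For the example dict above, that could look like the following sketch, with each scalar wrapped in the chosen dtype at the source:

import jax
import jax.numpy as jnp

jax.config.update('jax_enable_x64', True)

initial_params = {
    "param1": {
        "gamma": jnp.float64(0.1),
        "delta": jnp.float64(-3 * jnp.pi / 2),
    }
}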

A side note that can simplify your pipeline: dictionaries are native PyTrees in JAX. Depending on your use case, you might not need to convert to lists at all, because tree utilities already understand nested dicts.
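
For instance, a sketch that fixes the leaf dtype directly on the nested dict with tree_map, without ever building intermediate lists:

import jax
import jax.numpy as jnp

jax.config.update('jax_enable_x64', True)

# Reusing the example dict from above.
initial_params = {
    "param1": {
        "gamma": 0.1,
        "delta": -3 * jnp.pi / 2,
    }
}

# tree_map visits every leaf of the nested dict, so no manual flattening is needed.
typed_params = jax.tree_util.tree_map(
    lambda leaf: jnp.asarray(leaf, dtype=jnp.float64), initial_params
)

# When a flat view is needed, tree_leaves provides it (leaves in sorted-key order).
leaf_list = jax.tree_util.tree_leaves(typed_params)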

Why this matters in training loops

Optimizer dynamics are sensitive to numerical precision. Seemingly tiny differences in dtype promotion can alter gradients and step magnitudes, leading to different convergence profiles and final metrics. If your input path sometimes delivers weakly-typed Python floats and other times delivers strict-typed JAX arrays, you are implicitly running two different numeric regimes.

Being explicit about types ensures repeatability when everything else is deterministic. It also prevents hard-to-debug discrepancies between two runs that look identical on the surface.
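
If you want to enforce this at the boundary, a small guard can verify that every leaf entering the training loop has the expected dtype. The helper below (check_leaf_dtypes is hypothetical, not part of the original pipeline) is one way to fail fast:

import jax
import jax.numpy as jnp

def check_leaf_dtypes(tree, expected=jnp.float64):
    # Hypothetical helper: raise early if any leaf deviates from the
    # dtype the rest of the pipeline assumes.
    for leaf in jax.tree_util.tree_leaves(tree):
        dtype = jnp.asarray(leaf).dtype
        if dtype != expected:
            raise TypeError(f"found leaf with dtype {dtype}, expected {expected}")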

Takeaways

If switching from a list of lists to a list of jnp.array changes the outcome, the likely cause is JAX’s handling of weakly-typed scalars and resulting precision differences. Make the dtype unambiguous at the boundary: either build arrays with an explicit dtype for your leaves, initialize the inputs as jnp.float64 (or your chosen precision), or enable 64-bit mode when that is appropriate for your workload. Keeping dtypes consistent across the pipeline avoids accidental precision downgrades and stabilizes your optimization behavior.