2026, Jan 13 05:00

Freeze parameters in flax.nnx without errors: align wrt filter and value_and_grad with nnx.DiffState

Learn how to freeze parameters in flax.nnx without update errors: align the wrt filter and value_and_grad using nnx.DiffState to fix tree-structure mismatches.

Freezing part of a model while fine-tuning the rest is a common transfer-learning pattern. In flax.nnx you choose which parameters the optimizer updates with its wrt filter, but there is a subtlety: the gradient transform must be restricted to the same filter. If it isn't, the optimizer update fails with a tree-structure mismatch.

Reproducing the issue

The example below aims to freeze the kernel of an nnx.Linear and train only the bias. The training loop computes gradients with nnx.value_and_grad and updates an nnx.Optimizer configured with a wrt filter that selects just the bias. This leads to a mismatch during the optimizer update.

from jax import numpy as jnp
from jax import random
from flax import nnx
import optax
def glinear(x, slope=2.234, intercept=-1.123):
    return slope * x + intercept
def loss_eval(module, feats, targets):
    preds = module(feats)
    diffs = targets - preds
    mse = jnp.mean(diffs ** 2)
    mae = jnp.mean(jnp.abs(diffs))
    return mse, mae
if __name__ == '__main__':
    dims = (2, 55, 1)  # two splits (train, test) of 55 samples with 1 feature
    steps = 123
    seeds = nnx.Rngs(123)
    layer = nnx.Linear(1, 1, rngs=seeds)
    layer.kernel.value = jnp.array([[2.0]])  # pretend this is a loaded pretrained kernel
    pkey = seeds.params()
    xfull = random.uniform(pkey, dims, minval=-10, maxval=10)
    tgt1, tgt2 = glinear(xfull)
    x1, x2 = xfull
    # No argnums filter: gradients are computed for every nnx.Param (kernel and bias)
    vg = nnx.value_and_grad(loss_eval, has_aux=True)
    @nnx.scan(
        in_axes=(nnx.Carry, None, None,),
        out_axes=(nnx.Carry, 0),
        length=steps
    )
    def train_loop(opt_state, xb, yb):
        (mse, mae), grad_tree = vg(opt_state.model, xb, yb)
        opt_state.update(grad_tree)
        return opt_state, (mse, mae)
    # The optimizer only tracks parameters whose path contains "bias"
    mask = nnx.All(nnx.PathContains("bias"))
    trainer = nnx.Optimizer(layer, optax.adam(learning_rate=1e-3), wrt=mask)
    trainer, (mse_hist, mae_hist) = train_loop(trainer, x1, tgt1)
    print('AFTER TRAINING')
    print('training loss:', mse_hist[-1])
    yhat1, yhat2 = trainer.model(xfull)
    err2 = tgt2 - yhat2
    test_mse = jnp.mean(err2 * err2)
    print('test loss:', test_mse)
    print('m approximation:', trainer.model.kernel.value)
    print('b approximation:', trainer.model.bias.value)

The run fails inside opt_state.update with a tree-structure mismatch: the optimizer tracks only bias, while the gradient transform returns gradients for a larger set of parameters (both kernel and bias).

ValueError: Mismatch custom node data: ('bias', 'kernel') != ('bias',); value: State({ ... })

Why it breaks

The optimizer's wrt filter restricts its state to parameters whose path contains "bias". The gradient transform, however, is created without any filter, so it differentiates with respect to every nnx.Param of its first argument and returns gradients for both bias and kernel. When update combines that gradient tree with the optimizer's parameter tree, the two trees have different keys, ('bias', 'kernel') versus ('bias',), and the mismatch error is raised. The missing link is nnx.DiffState, which restricts the gradient transform to the same parameter selection the optimizer uses.

The fix

Define the parameter filter before creating the gradient transform. Build an nnx.DiffState from that filter and pass it to nnx.value_and_grad via argnums: nnx.DiffState(0, mask) means "differentiate argument 0 (the module), but only the variables that match mask". This makes the gradient tree layout match what the optimizer expects. For more background on DiffState, see the transforms documentation for nnx.grad here: https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html. Another illustrative discussion of parameter filtering with nnx.DiffState can be found here: https://github.com/google/flax/issues/4167.

from jax import numpy as jnp
from jax import random
from flax import nnx
import optax
def glinear(x, slope=2.234, intercept=-1.123):
    return slope * x + intercept
def loss_eval(module, feats, targets):
    preds = module(feats)
    diffs = targets - preds
    mse = jnp.mean(diffs ** 2)
    mae = jnp.mean(jnp.abs(diffs))
    return mse, mae
if __name__ == '__main__':
    dims = (2, 55, 1)  # two splits (train, test) of 55 samples with 1 feature
    steps = 123
    seeds = nnx.Rngs(123)
    layer = nnx.Linear(1, 1, rngs=seeds)
    layer.kernel.value = jnp.array([[2.0]])  # pretend this is a loaded pretrained kernel
    pkey = seeds.params()
    xfull = random.uniform(pkey, dims, minval=-10, maxval=10)
    tgt1, tgt2 = glinear(xfull)
    x1, x2 = xfull
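    # One filter drives both the optimizer (wrt) and the gradient transform
    mask = nnx.All(nnx.PathContains("bias"))
    # Differentiate argument 0 (the module), but only the variables matching mask
    diff = nnx.DiffState(0, mask)
    vg = nnx.value_and_grad(loss_eval, argnums=diff, has_aux=True)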
    @nnx.scan(
        in_axes=(nnx.Carry, None, None,),
        out_axes=(nnx.Carry, 0),
        length=steps
    )
    def train_loop(opt_state, xb, yb):
        (mse, mae), grad_tree = vg(opt_state.model, xb, yb)
        opt_state.update(grad_tree)
        return opt_state, (mse, mae)
    # Same optimizer as the failing example; only the gradient transform changed
    trainer = nnx.Optimizer(layer, optax.adam(learning_rate=1e-3), wrt=mask)
    trainer, (mse_hist, mae_hist) = train_loop(trainer, x1, tgt1)
    print('AFTER TRAINING')
    print('training loss:', mse_hist[-1])
    yhat1, yhat2 = trainer.model(xfull)
    err2 = tgt2 - yhat2
    test_mse = jnp.mean(err2 * err2)
    print('test loss:', test_mse)
    print('m approximation:', trainer.model.kernel.value)
    print('b approximation:', trainer.model.bias.value)

Why this detail matters

Partial fine-tuning is a first-class use case in transfer learning. In nnx, filtering with wrt is only half of the story; the gradient transform must target the same subset. Aligning nnx.value_and_grad with the optimizer through nnx.DiffState prevents tree mismatches, and the frozen parameters are passed through without being differentiated, so optimization stays focused on the selected parameters.

Takeaways

When freezing parameters in flax.nnx, make your optimizer and your gradient transform agree on which parameters are trainable. Define the parameter filter up front, wrap it in nnx.DiffState, and pass it as argnums to nnx.value_and_grad. This keeps the gradient tree's structure consistent with the optimizer's wrt filter and avoids update-time surprises.