2025, Sep 21 17:00
Control Hyperparameters with Optax Schedules in JAX + Flax nnx using the Optimizer Step (Flax 0.11)
Learn to drive model hyperparameters with Optax schedules in JAX + Flax nnx. Use the optimizer step counter to fetch per-step values, with a Flax 0.11 example.
Controlling a model hyperparameter over training with an Optax schedule is straightforward once you know where to source the current step. The key is that schedules are plain callables and the optimizer exposes a step counter you can pass into them. Below is a minimal, end-to-end example in JAX + Flax nnx + Optax that shows how to retrieve the scheduled value on every optimization step and collect it during training. The example targets Flax 0.11.
Complete example
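from jax import numpy as jnp
from jax import random
from flax import nnx
import optax
from matplotlib import pyplot as plt

if __name__ == '__main__':
    # Two batches of 55 scalar inputs: one for training, one held out for testing.
    arr_shape = (2, 55, 1)
    # Full-batch training: each scan iteration is one optimizer step.
    num_epochs = 123
    keyset = nnx.Rngs(123)
    net = nnx.Linear(1, 1, rngs=keyset)
    prng_key = keyset.params()
    xgrid = random.uniform(prng_key, arr_shape, minval=-10, maxval=10)

    def gen_line(x, m=2.234, b=-1.123):
        return m * x + b

    y_obs1, y_obs2 = gen_line(xgrid)
    x_head, x_tail = xgrid

    # Both schedules finish their transition after 90% of the steps.
    decay_ratio = 0.9
    lr_schedule = optax.schedules.cosine_decay_schedule(
        init_value=2e-1,
        decay_steps=int(decay_ratio * num_epochs),
        alpha=0.01,
    )
    # Schedule for the model hyperparameter we want to control.
    hp_schedule = optax.schedules.linear_schedule(
        init_value=12,
        end_value=234,
        transition_steps=int(decay_ratio * num_epochs),
    )

    # Flax 0.11: `wrt` selects which Variable type the optimizer updates.
    trainable = nnx.Param
    opt = nnx.Optimizer(
        net,
        tx=optax.adam(lr_schedule),
        wrt=trainable,
    )

    # Carry (model, optimizer) through the scan, broadcast the data to every
    # iteration, and stack the per-step outputs along a new leading axis.
    @nnx.scan(
        in_axes=(nnx.Carry, None, None),
        out_axes=(nnx.Carry, 0),
        length=num_epochs,
    )
    def train_scan(state, xb, yb):
        def criterion(module, inputs, targets):
            preds = module(inputs)
            diff = targets - preds
            mse = jnp.mean(diff ** 2)
            mae = jnp.mean(jnp.abs(diff))
            return mse, mae  # MSE is differentiated; MAE is returned as aux

        module, opt_state = state  # opt_state is the nnx.Optimizer itself
        # Evaluate the schedule at the current optimizer step (0 on the first
        # iteration); update() below increments the counter by one.
        cur_hp = hp_schedule(opt_state.step.value)
        (mse_val, mae_val), grads = nnx.value_and_grad(criterion, has_aux=True)(module, xb, yb)
        opt_state.update(module, grads)  # Flax 0.11 signature: update(model, grads)
        return (module, opt_state), (mse_val, mae_val, cur_hp)

    (net, opt), (loss_hist, mae_hist, hp_series) = train_scan((net, opt), x_head, y_obs1)

    print('AFTER TRAINING')
    print('training loss:', loss_hist[-1])
    # Predict on both batches, but score only the held-out one.
    y_pred1, y_pred2 = net(xgrid)
    err = y_obs2 - y_pred2
    test_loss = jnp.mean(err * err)
    print('test loss:', test_loss)
    print('m approximation:', net.kernel.value)
    print('b approximation:', net.bias.value)

    # hp_series holds one scheduled value per optimizer step.
    plt.plot(hp_series)
    plt.xlabel('optimizer step')
    plt.ylabel('scheduled hyperparameter')
    plt.show()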
What’s the actual problem?
The requirement is to drive a model hyperparameter with a schedule during training, ideally as a smooth function of the training progress. It can be tempting to assume the schedule needs to be embedded into the optimizer itself. That is unnecessary here: schedules in Optax are plain functions that map a step index to a value. The only missing piece is a reliable step counter inside the training loop.
Why the schedule alone wasn’t enough
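A schedule needs a monotonically increasing step to produce the current hyperparameter value. Without threading the current optimization step into it, the schedule has nothing to evaluate. The confusion often comes from looking for a dedicated “injection” mechanism. Optax does provide optax.inject_hyperparams, but it exposes the hyperparameters of the optimizer transformation itself; a hyperparameter used by the model or the loss does not need it. The schedule is already the correct abstraction. You just call it.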
The working approach
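In Flax nnx, the nnx.Optimizer object keeps the current iteration count in its step attribute, and each call to update() increments it by one. Reading opt_state.step.value inside the scanned training function (opt_state is the name the example gives to the nnx.Optimizer once it is unpacked from the scan carry) provides exactly what the schedule expects. Because the value is read before update() is called, the first iteration sees step 0. Call the schedule with that value to get the per-step hyperparameter, and use or record it as needed. In the example, the value is returned from the scan alongside the losses, so hp_series becomes the stacked time series of the scheduled hyperparameter across training.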
Why this detail matters
Being able to schedule hyperparameters directly against the optimization step keeps the training loop explicit and predictable. You can evolve quantities gradually over epochs or steps, and you retain full control because the schedule is just a function call in the same place you compute the loss. This helps when you want to pace changes in model behavior alongside parameter updates without bolting logic into the optimizer.
Takeaways
Use the optimizer step counter as the single source of truth for training progress and feed it into the schedule each iteration. Treat schedules as regular callables: evaluate them when you need their current value and propagate that value through your forward pass or metrics. If you need to analyze or log the evolution of a hyperparameter, return it from the scan to collect the full trajectory over training.
The article is based on a question from StackOverflow by user137146 and an answer by user137146.