2025, Sep 21 20:01

Как управлять гиперпараметрами по расписанию в Optax (JAX/Flax nnx)

Пошаговый пример на JAX и Flax nnx показывает, как с помощью расписаний Optax управлять гиперпараметрами из счетчика шагов оптимизатора. Код и советы.

Управлять гиперпараметром модели в процессе обучения с помощью расписания Optax несложно, если понимать, откуда брать текущий номер шага. Главное — расписания представляют собой обычные вызываемые функции, а оптимизатор предоставляет счетчик шагов, который можно передавать в эти функции. Ниже — минимальный сквозной пример на JAX + Flax nnx + Optax, показывающий, как на каждом шаге оптимизации получать значение из расписания и собирать его в ходе обучения. Пример рассчитан на Flax 0.11.

Полный пример

from jax import numpy as jnp
from jax import random
from flax import nnx
import optax
from matplotlib import pyplot as plt

if __name__ == '__main__':
    arr_shape = (2,55,1)
    num_epochs = 123
    keyset = nnx.Rngs(123)
    net = nnx.Linear(1, 1, rngs=keyset)

    prng_key = keyset.params()
    xgrid = random.uniform(prng_key, arr_shape, minval=-10, maxval=10)

    def gen_line(x, m=2.234, b=-1.123):
        return m * x + b
    y_obs1, y_obs2 = gen_line(xgrid)
    x_head, x_tail = xgrid

    decay_ratio = 0.9
    lr_schedule = optax.schedules.cosine_decay_schedule(
        init_value=2e-1,
        decay_steps=int(decay_ratio * num_epochs),
        alpha=0.01,
    )

    hp_schedule = optax.schedules.linear_schedule(
        init_value=12,
        end_value=234,
        transition_steps=int(decay_ratio * num_epochs),
    )

    trainable = nnx.Param
    opt = nnx.Optimizer(
        net,
        tx=optax.adam(lr_schedule),
        wrt=trainable,
    )

    @nnx.scan(
        in_axes=(nnx.Carry, None, None),
        out_axes=(nnx.Carry, 0),
        length=num_epochs,
    )
    def train_scan(state, xb, yb):
        def criterion(module, inputs, targets):
            preds = module(inputs)
            diff = targets - preds
            mse = jnp.mean(diff ** 2)
            mae = jnp.mean(jnp.abs(diff))
            return mse, mae

        module, opt_state = state
        cur_hp = hp_schedule(opt_state.step.value)
        (mse_val, mae_val), grads = nnx.value_and_grad(criterion, has_aux=True)(module, xb, yb)
        opt_state.update(module, grads)
        return (module, opt_state), (mse_val, mae_val, cur_hp)

    (net, opt), (loss_hist, mae_hist, hp_series) = train_scan((net, opt), x_head, y_obs1)

    print('AFTER TRAINING')
    print('training loss:', loss_hist[-1])

    y_pred1, y_pred2 = net(xgrid)
    err = y_obs2 - y_pred2
    test_loss = jnp.mean(err * err)
    print('test loss:', test_loss)
    print('m approximation:', net.kernel.value)
    print('b approximation:', net.bias.value)

В чем собственно задача?

Задача — управлять гиперпараметром модели по расписанию во время обучения, желательно как плавной функцией от прогресса обучения. Легко предположить, что расписание нужно «встраивать» прямо в оптимизатор. Здесь это не требуется: в Optax расписания — это обычные функции, которые отображают индекс шага в значение. Единственное, что нужно, — надежный счетчик шагов внутри цикла обучения.

Почему одного расписания недостаточно

Чтобы выдать текущее значение гиперпараметра, расписанию нужен монотонно растущий номер шага. Если не передать ему текущий шаг оптимизации, вычислять будет нечего. Путаница часто возникает из‑за поиска специального «механизма внедрения», но расписание уже является нужной абстракцией. Его просто нужно вызвать.

Рабочий подход

Состояние оптимизатора хранит текущий номер итерации. Доступ к нему через optimizer.step.value внутри сканируемой функции обучения дает именно то, что ожидает расписание. Передайте это значение в расписание, чтобы получить гиперпараметр на данном шаге, и далее используйте или сохраняйте его. В примере значение возвращается из scan, поэтому hp_series оказывается собранным временным рядом запланированного гиперпараметра по всем шагам обучения.

Почему это важно

Привязка гиперпараметров к шагу оптимизации делает цикл обучения прозрачным и предсказуемым. Значения можно плавно изменять по эпохам или шагам, оставаясь полного контроля: расписание — это просто вызов функции там же, где считается loss. Это удобно, когда нужно согласованно менять поведение модели вместе с обновлением параметров, не встраивая дополнительную логику в оптимизатор.

Итоги

Используйте счетчик шагов оптимизатора как единственный источник истины о прогрессе обучения и передавайте его в расписание на каждой итерации. Относитесь к расписаниям как к обычным вызываемым функциям: вычисляйте их по мере необходимости и прокидывайте полученное значение в прямой проход или метрики. Если важно анализировать или логировать динамику гиперпараметра, возвращайте его из scan, чтобы получить полную траекторию за время обучения.

Статья основана на вопросе на StackOverflow от user137146 и ответе от user137146.