2025, Oct 02 21:17

Параметризация jax.lax.scan: lambda, def или carry — что выбрать

Сравниваем три подхода к jax.lax.scan — lambda, def и передача параметров через carry. Почему в JAX они компилируются одинаково и что выбрать на практике.

Параметризация цикла jax.lax.scan часто приводит к одному и тому же вопросу: захватывать константы через замыкание или прокидывать их через состояние scan? На практике три распространённых приёма — использование lambda, объявление локальной функции через def или передача параметров в carry — компилируются в JAX практически в идентичные пути выполнения. Понимание, почему рантайм видит их одинаково, позволяет выбирать исходя из читаемости, а не из предполагаемых различий в производительности.

Воспроизводимая настройка, демонстрирующая три подхода

import jax
import jax.numpy as jnp
import time

STEPS = 100

def evolve(s, a, dtau, damp):
    v = jnp.array([jnp.cos(s[2]) * a[0], jnp.sin(s[2]) * a[0], a[1]])
    s_next = s + v * dtau - damp * v
    return s_next

def advance(buf, a, dtau, damp):
    s, t = buf
    s_next = evolve(s, a[:, t], dtau, damp)
    return (s_next, t + 1), s_next

# Замыкание через lambda
def run_with_lambda(s0, a, dtau, damp):
    step_fn = lambda carry, _inp: advance(carry, a, dtau, damp)
    init = (s0, 0)
    return jax.lax.scan(step_fn, init, None, length=STEPS)

# Замыкание через def
def run_with_def(s0, a, dtau, damp):
    def step_fn(carry, _inp):
        return advance(carry, a, dtau, damp)
    return jax.lax.scan(step_fn, (s0, 0), None, length=STEPS)

# Передавать всё через carry
def advance_with_carry(buf, _inp):
    s, t, a, dtau, damp = buf
    s_next = evolve(s, a[:, t], dtau, damp)
    return (s_next, t + 1, a, dtau, damp), s_next

def run_with_carry(s0, a, dtau, damp):
    return jax.lax.scan(advance_with_carry, (s0, 0, a, dtau, damp), None, length=STEPS)

# Демонстрационные входные данные
s0 = jnp.array([2., 3., 4.])
rng = jax.random.PRNGKey(0)
a = jax.random.uniform(rng, shape=(2, STEPS))
dtau = 0.01
damp = 1e-6

# Необязательная вспомогательная функция для бенчмарка
def run_bench(g, label):
    g_jit = jax.jit(g)
    carry, traj = g_jit(s0, a, dtau, damp)
    carry[0].block_until_ready()
    t0 = time.time()
    for _ in range(1000):
        carry, ys = g_jit(s0, a, dtau, damp)
        carry[0].block_until_ready()
    t1 = time.time()
    print(f"{label:10s} | cached run: {t1 - t0:.6f}s")

# run_bench(run_with_lambda, "lambda")
# run_bench(run_with_def,    "def")
# run_bench(run_with_carry,  "carry_all")

Что на самом деле происходит в JAX

Все три варианта идиоматичны для JAX, и с точки зрения рантайма нет причин отдавать одному предпочтение. Анонимные функции, созданные через lambda, и именованные функции, объявленные через def, обрабатываются одинаково; JAX не различает их в процессе трассировки и компиляции. Поэтому замыкание, которое захватывает a, dtau и damp, эквивалентно независимо от того, написано ли оно через lambda или def.

Вариант с carry отличается лишь способом передачи тех же значений в шаговую функцию: они протягиваются через carry в scan, а не захватываются как замкнутые константы. Это даёт немного иные параметры при понижении scan, но не меняет скомпилированную функцию так, чтобы это влияло на исполнение.

Ещё один терминологический нюанс — «lambda-assignment» (присваивание lambda). В варианте выше действительно происходит присваивание lambda-переменной. Некоторые гайдлайны по стилю Python советуют его избегать, но для исполнения в JAX это не проблема; следовать этому совету или нет — отдельный вопрос стиля.

Практическое решение

Поскольку на уровне JAX эти варианты равнозначны, выбирайте форму, при которой код проще читать и сопровождать. Если удобнее оформить шаг как локальную функцию, замыкающую константы, — отлично. Если предпочтительнее явно держать параметры в состоянии scan, передавая их через carry, — это так же допустимо. Ниже — компактная версия с локальным def:

import jax
import jax.numpy as jnp

STEPS = 100

def evolve(s, a, dtau, damp):
    v = jnp.array([jnp.cos(s[2]) * a[0], jnp.sin(s[2]) * a[0], a[1]])
    return s + v * dtau - damp * v

def advance(buf, a, dtau, damp):
    s, t = buf
    s_next = evolve(s, a[:, t], dtau, damp)
    return (s_next, t + 1), s_next

def rollout_scan(s0, a, dtau, damp):
    def step(carry, _):
        return advance(carry, a, dtau, damp)
    return jax.lax.scan(step, (s0, 0), None, length=STEPS)

Почему это важно

Понимание того, что lambda, def и carry приводят к практически одинаковому скомпилированному коду, помогает избежать преждевременных оптимизаций и сосредоточиться на самой динамике системы. Это также проясняет, что кажущиеся различия вряд ли связаны с выбором между замыканием и передачей параметров, а трассировка в JAX одинаково относится к анонимным и именованным вызываемым объектам.

Выводы

Используйте тот подход, который делает код понятнее вашей команде. Для JAX lambda и def взаимозаменяемы, а передача параметров через carry вместо их замыкания меняет лишь способ параметризации scan, но не поведение скомпилированной функции. Если сомневаетесь — измерьте на своей конфигурации, но изначального преимущества по времени выполнения у какого‑то одного стиля не ожидается.

Статья основана на вопросе на StackOverflow от user1168149 и ответе jakevdp.