2025, Dec 30 03:02

Когда в JAX использовать scan, а когда Python for под jit

Сравниваем jax.lax.scan и Python for под jit: как выбор влияет на компиляцию и скорость, когда что использовать, настройка unroll, примеры и советы по бенчмарку.

Выбор между циклом for на Python и jax.lax.scan в JAX — это не просто вопрос стиля. Он влияет на поведение при компиляции, производительность на исполнении и на то, как компилятор способен оптимизировать вашу программу. Если вы когда‑либо переключали реализацию цикла туда‑сюда и недоумевали, почему скорость меняется неинтуитивно, этот материал для вас.

Проблема в контексте

Рассмотрим вложенную схему обучения, где внешний тренировочный цикл вызывает функцию, которая сама несколько раз итерируется и выполняет rollout через scan. В такой задаче естественно спросить, стоит ли оставлять внутренний цикл в виде Python for или переписать его на scan, и как этот выбор будет масштабироваться, если окружающий код повторяет его много раз.

for a in range(num_train_steps):
  for b in range(num_env_steps):
    execute()

@jax.jit
def execute():
  for c in range(num_algo_iters):
    jax.lax.scan(rollout_step, init_carry, xs=None, length=inner_scan_len)

В тесте, где num_env_steps менялся (1, 100, 1000, 10000), а функция execute компилировалась через jit, внутренний цикл c поочерёдно реализовывался как for и как scan, при этом самый глубокий rollout оставался на scan. При 5 итерациях цикла c и 2 итерациях во внутреннем scan наблюдались примерные тайминги на уровне act(): 1.5, 11.3, 99.0, 956.2 секунды для варианта с scan в c и 5.1, 14.5, 103.6, 972.7 секунды для варианта с for в c. В этом эксперименте версия с for не стала быстрее даже при росте количества внешних повторений.

Почему так происходит

JAX разворачивает управляющие конструкции Python внутри функций под jit. Цикл for со 100 итерациями превращается в линейную программу с сотней копий тела. Плюс в том, что компилятор может оптимизировать «поперёк» итераций — например, склеивать операции между соседними шагами или выкидывать целые подграфы, если их вывод нигде не используется. Минус в том, что стоимость компиляции растёт сверхлинейно с размером программы, поэтому крупные тела и множество итераций приводят к долгой компиляции.

С jax.lax.scan или jax.lax.fori_loop цикл остаётся внутри HLO. Тело разбирается и компилируется один раз, что делает компиляцию намного эффективнее. Компромисс в том, что у компилятора меньше возможностей оптимизировать между итерациями, поэтому относительно полностью развёрнутого for часть скорости на исполнении может потеряться.

Единого победителя нет. Небольшие тела с малым числом итераций часто лучше работают с for, потому что компилятор может агрессивно оптимизировать поперёк итераций без взрывного времени компиляции. Большие тела или множество итераций чаще выигрывают от scan или fori_loop, поскольку они удерживают стоимость компиляции в разумных пределах.

Полезно также скорректировать ожидания по сложности. Для Python‑циклов for стоит ожидать сверхлинейного роста времени компиляции по мере увеличения развёрнутой программы, тогда как время выполнения может быть сублинейным, а может и нет — это зависит от конкретных операций и эвристик компилятора. В случае scan компиляция обычно гораздо эффективнее, потому что тело компилируется один раз, а сам цикл представлен в HLO; на исполнении же оптимизации между итерациями обычно менее агрессивны, чем у полностью развёрнутого кода.

Минимальный эквивалентный пример

Ниже показаны две версии с одинаковой логикой, но по‑разному выражённым циклом c. Самый внутренний rollout в обоих случаях остаётся scan.

import jax
import jax.numpy as jnp
from jax import lax

# Фиктивное тело rollout для ясности семантики
# состояние на вход, состояние на выход

def rollout_core(carry, _):
  return carry + 1.0, None

# Вариант A: внутренний цикл c в виде Python for
@jax.jit
def run_step_for(init_state, algo_iters, inner_len):
  carry = init_state
  for c in range(algo_iters):
    carry, _ = lax.scan(rollout_core, carry, xs=None, length=inner_len)
  return carry

# Вариант B: внутренний цикл c в виде scan
@jax.jit
def run_step_scan(init_state, algo_iters, inner_len):
  def c_body(carry, _):
    return lax.scan(rollout_core, carry, xs=None, length=inner_len)
  return lax.scan(c_body, init_state, xs=None, length=algo_iters)[0]

Обе версии выдают одно и то же конечное состояние для заданных init_state, algo_iters и inner_len; только в одной цикл c — это Python for, а в другой — scan.

Тонкая настройка компромисса с помощью unroll

У scan есть параметр unroll, который позволяет варьировать поведение между крайними вариантами. Значение unroll=True делает scan почти эквивалентным развёрнутому for. Можно частично разворачивать, передав целое число, например unroll=n в диапазоне 1 < n < total_iterations: фактически это создаёт небольшой развёрнутый цикл внутри каждого шага scan. Так открывается больше возможностей для оптимизаций, а стоимость компиляции остаётся контролируемой.

@jax.jit
def run_step_scan_tuned(init_state, algo_iters, inner_len, outer_unroll, inner_unroll):
  def c_body(carry, _):
    return lax.scan(rollout_core, carry, xs=None, length=inner_len, unroll=inner_unroll)
  return lax.scan(c_body, init_state, xs=None, length=algo_iters, unroll=outer_unroll)[0]

Когда на каком‑то уровне стоит unroll=True, ожидайте поведения, схожего с Python for на этом уровне, включая рост времени компиляции. Если оставить unroll по умолчанию, компиляция обычно значительно эффективнее.

Какой вариант выбрать?

Лучший выбор зависит от вашей программы и приоритетов. Если тело цикла небольшое, а число итераций умеренное, for внутри jit часто показывает хорошие результаты за счёт оптимизаций между итерациями. Если тело крупное или итераций много, scan или fori_loop обычно компилируются гораздо быстрее и в сумме оказываются предпочтительнее. Нет гарантии, что увеличение числа повторов до 100k или миллиона перевернёт результат в пользу for; исход зависит и от компиляции, и от поведения на исполнении именно в вашей нагрузке.

Установка unroll=True в scan делает его по сути эквивалентным for с точки зрения компилятора. Значит, ожидайте того же типа сверхлинейного роста усилий на компиляцию и потенциальной выгоды на исполнении от более широких оптимизаций. Это не универсальный способ ускориться; это лишь движение в сторону полностью развёрнутого варианта.

Корректное измерение производительности имеет значение

Когда замеряете код JAX, важно бенчмаркать правильно. Нужно синхронизироваться с устройством, чтобы измерять реальное время вычислений, а не только постановку задач в очередь. Обычно для этого вызывают .block_until_ready() на результате. См. также официальные рекомендации: JAX FAQ по бенчмаркингу.

Почему это важно знать

То, как представлен цикл, напрямую определяет пространство поиска компилятора. Оставляя цикл в Python, вы даёте компилятору шанс сливать и отбрасывать операции между итерациями, платя за это компиляцией, растущей быстрее линейно. Перемещение цикла в HLO делает компиляцию эффективной и предсказуемой, но сокращает число трансформаций между итерациями. Понимание этого спектра помогает осознанно выбирать подход, а не механически заменять for на scan или наоборот.

Практические рекомендации

Сначала определите, что важнее для вашей задачи: задержка на компиляцию или устоявшаяся скорость на исполнении. Для небольших и коротких циклов Python for внутри jit обычно подходит. Для больших тел или множества итераций предпочтительнее scan или fori_loop. При необходимости тонкой настройки используйте unroll в scan, чтобы балансировать стоимость компиляции и потенциальные оптимизации на рантайме. И всегда бенчмаркйте с правильной синхронизацией, чтобы получать надёжные цифры.

Заключение

Не существует универсального правила, по которому for категорически быстрее scan или наоборот. Рассматривайте выбор как компромисс между стоимостью компиляции и возможностями оптимизации во время выполнения. Используйте scan для масштабируемости, for — для агрессивных межитерационных оптимизаций на небольших задачах, а к unroll обращайтесь, когда нужен середнячок. Измеряйте с .block_until_ready() и по лучшим практикам бенчмаркинга JAX — и пусть именно эти измерения подскажут, что выбрать в вашем коде.