2025, Oct 03 07:16
Когда JAX перекомпилирует JIT: pytrees, scan и лямбда‑замыкания
Разбираем, что запускает перекомпиляцию JAX JIT: как pytrees влияют на ключ кеша, scan и лямбда‑замыканий, и как держать статическое и динамическое раздельно.
scan JAX, pytrees, лямбда‑замыкания и кеш JIT часто встречаются в одном и том же фрагменте кода — и тогда возникает вопрос: приведёт ли захват pytree через лямбда‑функцию к перекомпиляции каждый раз при передаче нового экземпляра? Разберёмся, что именно запускает перекомпиляцию JIT, как pytrees влияют на ключ кеша и почему всё будет корректно, если держать нужные части стабильными.
Минимальный пример, из которого возникает вопрос
Следующая программа выполняет scan по входному вектору. Логику обновления состояния мы упаковали в объект‑pytree, который содержит и массив, и скаляр, а сам объект замыкается лямбдой внутри функции под JIT.
import jax
import jax.numpy as jnp
from jax import tree_util
class ConfigBox:
    def __init__(self, vec, scale):
        self.vec = vec
        self.scale = scale
    def step_fn(self, st, inp):
        s = st
        y = inp
        nxt = (self.vec + s + jnp.ones(self.scale)) * y
        return nxt
    def _tree_flatten(self):
        leaves = (self.vec,)
        meta = {'scale': self.scale}
        return (leaves, meta)
    @classmethod
    def _tree_unflatten(cls, meta, leaves):
        return cls(*leaves, **meta)
tree_util.register_pytree_node(
    ConfigBox,
    ConfigBox._tree_flatten,
    ConfigBox._tree_unflatten,
)
def step_wrapper(cfg, st, inp):
    s = st
    y = inp
    s_next = cfg.step_fn(s, y)
    return s_next, [s_next]
@jax.jit
def run_scan_jit(cfg):
    body = lambda st, inp: step_wrapper(cfg, st, inp)
    st0 = jnp.array([0., 1.])
    seq = jnp.array([1., 2., 3.])
    last, out = jax.lax.scan(body, st0, seq)
    return last, out
if __name__ == "__main__":
    cfg1 = ConfigBox(jnp.array([1., 2.]), 2)
    last_state, outputs = run_scan_jit(cfg1)
    print(last_state)
    cfg2 = ConfigBox(jnp.array([3., 4.]), 2)
    last_state, outputs = run_scan_jit(cfg2)
    print(last_state)
Что на самом деле запускает перекомпиляцию
В JAX на попадания в кеш JIT для pytrees влияет то, как они разворачиваются. Значения, возвращаемые во вспомогательной части разворачивания, трактуются как статические. Значения, возвращаемые как потомки, считаются динамическими. Статические элементы попадают в ключ кеша по их Python‑значению. Динамические элементы влияют на ключ через структуру массива (shape, dtype, sharding), но не через содержимое.
Применительно к примеру выше meta хранит scale и является статической. Изменение значения scale вызовет перекомпиляцию. В leaves лежит vec — это динамика. Изменение значений внутри vec перекомпиляции не вызовет, но изменение его формы, dtype или схемы шардинга — вызовет. В двух показанных вызовах оба экземпляра vec имеют одинаковые shape, dtype и sharding, а scale не меняется; следовательно, перекомпиляции нет.
Ещё одно типичное опасение — влияет ли лямбда, в которую «запекается» cfg, на ключ кеша JIT. В этой конфигурации — нет, потому что сама лямбда не входит в результат разворачивания pytree и, значит, на ключ не воздействует.
Решение: разделяйте статическое и динамическое
В данном сценарии ничего менять не нужно, чтобы избежать перекомпиляции. Пока статические поля сохраняют то же значение, а динамические массивы — ту же форму, dtype и sharding, JIT переиспользует уже скомпилированный исполняемый код. При желании можно подтвердить повторное использование кеша подходом, указанным здесь: https://stackoverflow.com/a/70127930/2937831. Также можно заметить, что размер кеша jitted‑функции остаётся равным 1 после обоих вызовов в этой программе.
Тот же код, без изменений, соответствует этим условиям:
import jax
import jax.numpy as jnp
from jax import tree_util
class ConfigBox:
    def __init__(self, vec, scale):
        self.vec = vec
        self.scale = scale
    def step_fn(self, st, inp):
        return (self.vec + st + jnp.ones(self.scale)) * inp
    def _tree_flatten(self):
        leaves = (self.vec,)
        meta = {'scale': self.scale}
        return (leaves, meta)
    @classmethod
    def _tree_unflatten(cls, meta, leaves):
        return cls(*leaves, **meta)
tree_util.register_pytree_node(
    ConfigBox,
    ConfigBox._tree_flatten,
    ConfigBox._tree_unflatten,
)
def step_wrapper(cfg, st, inp):
    nxt = cfg.step_fn(st, inp)
    return nxt, [nxt]
@jax.jit
def run_scan_jit(cfg):
    body = lambda st, inp: step_wrapper(cfg, st, inp)
    st0 = jnp.array([0., 1.])
    seq = jnp.array([1., 2., 3.])
    return jax.lax.scan(body, st0, seq)
Почему это важно
Понимание, какие части статические, а какие динамические, — это разница между ровными, предсказуемыми запусками и неожиданными всплесками перекомпиляций. Значения массивов можно менять свободно — кеш от этого не инвалидируется; а вот изменение форм, типов dtype или схем шардинга — инвалидирует. Скаляры, помещённые в статическую часть pytree, приведут к перекомпиляции при изменении их значений. Захват значений лямбдой в таком паттерне на поведение кеша не влияет.
Итоги
Если pytree участвует в вашей программе на JAX, осознанно решайте, что должно попасть в динамические листья, а что — во вспомогательные статические данные. Держите статические поля неизменными, когда рассчитываете на повторное использование кеша, и сохраняйте согласованность формы и dtype динамических массивов, когда хотите избежать перекомпиляции. Захватывать динамическое значение лямбдой можно — при соблюдении этих условий это безопасно. Если сомневаетесь, проверьте повторное использование кеша указанным выше способом или посмотрите на размер кеша jitted‑функции: он должен оставаться равным 1 для вызовов, где не меняются статические данные и структура массивов.
Статья основана на вопросе на StackOverflow от user1168149 и ответе jakevdp.