2025, Oct 18 01:31

Численный дрейф в JAX vmap на GPU: влияние размера батча и порядка операций

Почему JAX vmap на GPU в float32 даёт разные результаты при смене размера батча: влияние порядка операций и округления. Решения: float64 или запуск на CPU.

Небольшой численный дрейф при применении vmap к нейросети на разных размерах батча может озадачить, особенно если исходные входы идентичны. Если вы замечаете малые расхождения в первых строках результата в зависимости от того, сколько строк вы передаёте в функцию с vmap, такое поведение может быть корректным. Короткий ответ в том, что вычисления с плавающей точкой чувствительны к порядку операций: меняется порядок — меняется и профиль округления. На GPU с float32 это проявляется легко, тогда как на CPU или при float64 оно может исчезнуть.

Воспроизведение проблемы

Фрагмент ниже дважды применяет Equinox MLP к батчу: сначала к небольшому срезу, затем ко всему массиву, после чего сравнивает первые строки обоих результатов. Отличаются лишь имена относительно исходного паттерна; поведение программы такое же.

import jax
import jax.numpy as jnp
import equinox as eqx

def batch_apply(arr_in, net_fn):
    mapped = eqx.filter_vmap(net_fn.__call__)
    return mapped(arr_in)

rng = jax.random.PRNGKey(0)
rng, key_net = jax.random.split(rng, 2)
model = eqx.nn.MLP(2, 2, 10, 2, key=key_net)

rng, key_x = jax.random.split(rng, 2)
xb = jax.random.normal(key_x, (10000, 2))

delta = batch_apply(xb[:10], model) - batch_apply(xb, model)[:10]
print("eqx error:", delta)

Что именно вызывает расхождение

Это ожидаемо и не специфично для vmap как трансформации. Дело в арифметике с плавающей точкой. В float32 на каждом шаге накапливается ошибка округления. Когда вы выполняете «те же» вычисления по разным вычислительным путям, эти ошибки накапливаются в разных местах и в ином порядке, что слегка сдвигает результат. Изменение размера батча меняет то, как вычисления планируются и сливаются, поэтому фактический порядок операций отличается.

Устройство тоже имеет значение, потому что порядок операций зависит от архитектуры. На CPU выполнение, как правило, следует более последовательному порядку аккумуляции, что на практике может сохранять порядок операций стабильным при разных размерах батча. На GPU выполнение сильно параллельно, а разбиение работы и схемы накопления зависят от формы входа. Эта разница в последовательности ведёт к иному округлению и, как следствие, к немного отличающимся числам в float32.

Что с этим делать

Само по себе поведение корректно. Расхождение возникает из-за допустимых различий округления с плавающей точкой при ином порядке операций. В наблюдаемой настройке использование float64 устранило расхождение, как и запуск на CPU; оба варианта достаточно изменяют численные характеристики или последовательность, чтобы убрать видимый дрейф в этом примере.

Почему это важно

Не стоит считать, что батчевые вычисления на GPU в float32 будут побитово детерминированными при изменении формы входа. Если сравнивать результаты между размерами батча или устройствами и ожидать точного равенства, можно принять обычное поведение чисел с плавающей точкой за логическую ошибку. Понимание того, что планирование, батчинг и параллелизм устройства влияют на порядок накопления, помогает выстроить корректные ожидания и диагностику.

Вывод

Небольшие различия значений между размерностями vmapped-батчей на GPU в float32 — естественное следствие округления с плавающей точкой, накапливающегося при разном порядке операций. Вычисление при этом остаётся корректным. Если для вашей задачи требуется более тесное численное совпадение, в наблюдаемой конфигурации переключение на float64 или выполнение на CPU устраняли различие. В противном случае принимайте крошечные отклонения как нормальную часть параллельных вычислений с плавающей точкой и оценивайте равенство с учётом этого.

Статья основана на вопросе с StackOverflow от hvater и ответе от jakevdp.