2025, Dec 28 15:02
Списки Python vs jnp.array в JAX: как dtype влияет на обучение RNN
Почему списки Python и jnp.array в JAX дают разную точность (float32/float64) и траектории RNN на Flax/Optax; как закрепить dtype и стабилизировать обучение.
Когда списки Python встречаются с массивами JAX: почему крошечная деталь в типах срывает оптимизацию
Замена функции препроцессинга, которая возвращает список списков, на вариант, возвращающий список jnp.array, кажется безобидным рефакторингом. Однако в обучающих циклах на JAX эта тонкость способна заметно изменить результат. В одном конвейере оптимизации RNN на Flax и Optax тот же пайплайн давал метрику около 0.9997, когда параметры подготавливались как список списков, и лишь примерно 0.998, когда те же параметры формировались как список jnp.array. Все прочее оставалось неизменным: сиды, число шагов и итераций, а также исходный словарь параметров.
Этот разбор показывает, откуда берётся расхождение, как слабо типизированные скаляры в JAX ненавязчиво уводят вас в разные режимы точности и что сделать, чтобы обучение вело себя стабильно.
Минимальный пример, который воспроизводит расхождение
Пайплайн стартует с вложенного словаря скалярных параметров. На каждом шаге времени словарь «сплющивается» и подаётся в оптимизационную петлю на RNN. Пример структуры входа:
initial_params = {
"param1": {
"gamma": 0.1,
"delta": -3 * jnp.pi / 2,
}
}
Функцию препроцессинга реализовали в двух вариантах. В одном она возвращает список списков; в другом — список jnp.array. Логика идентична, отличается лишь тип контейнера.
Вариант, который возвращает списки Python из листьев:
import jax
import jax.numpy as jnp
def pack_params_from_mapping(cfg_tree):
"""
Convert a nested mapping of parameters to a flat Python list and record segment lengths.
Args:
cfg_tree: Nested mapping of parameters.
Returns:
tuple: list of lists of leaves, and list of segment lengths.
"""
packed = []
seg_lengths = []
for branch in cfg_tree.values():
leaf_vals = jax.tree_util.tree_leaves(branch)
packed.append(leaf_vals)
seg_lengths.append(len(leaf_vals))
return packed, seg_lengths
Вариант, который возвращает массивы JAX из листьев:
import jax
import jax.numpy as jnp
def pack_params_from_mapping(cfg_tree):
"""
Convert a nested mapping of parameters to a flat list of jnp.array and record segment lengths.
Args:
cfg_tree: Nested mapping of parameters.
Returns:
tuple: list of jnp.array leaves, and list of segment lengths.
"""
packed = []
seg_lengths = []
for branch in cfg_tree.values():
leaf_vals = jax.tree_util.tree_leaves(branch)
arr = jnp.array(leaf_vals)
packed.append(arr)
seg_lengths.append(arr.shape[0])
return packed, seg_lengths
Что происходит на самом деле: слабо типизированные скаляры и точность
Существенное различие между подходами не в контейнерах как таковых, а в типах данных (dtype), попадающих в последующие вычисления. Числа с плавающей точкой Python внутри списков JAX трактует как слабо типизированные значения. На практике слабо типизированные скаляры подстраивают свой dtype под то, с чем они комбинируются. В смешанных выражениях это незаметно тянет вычисления к более низкой точности.
Посмотрите на короткую самодостаточную иллюстрацию. При включённой 64-битной точности список чисел Python остаётся слабо типизированным, тогда как явный массив фиксирует dtype и далее ведёт к более точным вычислениям.
import jax
import jax.numpy as jnp
jax.config.update('jax_enable_x64', True)
val_list = [0.1, -4.71238898]
val_array = jnp.array(val_list)
x32 = jnp.float32(1.0)
# Смешивание массива float32 с числом Python (слабый тип)
# даёт результат в float32.
x32 + val_list[1]
# Смешивание массива float32 с элементом jnp.array (строгий тип)
# здесь приводит к результату float64.
x32 + val_array[1]
Иными словами, списки чисел Python могут уводить часть вычислений в float32, тогда как массивы с теми же значениями закрепляют их в float64. В оптимизационной петле такие мелочи накапливаются и могут заметно менять итоговую целевую метрику, даже если всё остальное детерминировано.
Как добиться одинакового поведения
Практическое решение — явно задавать dtype на границах вашего конвейера данных и не смешивать слабо типизированные числа Python с массивами JAX. С учётом описанного выше есть два простых пути.
Первый вариант — сразу после «сплющивания» структуры параметров конвертировать листья в jnp.array с явным dtype. Так вы фиксируете точность для всех последующих вычислений:
import jax
import jax.numpy as jnp
jax.config.update('jax_enable_x64', True)
def pack_params_from_mapping(cfg_tree):
packed = []
seg_lengths = []
for branch in cfg_tree.values():
leaf_vals = jax.tree_util.tree_leaves(branch)
arr = jnp.array(leaf_vals, dtype=jnp.float64)
packed.append(arr)
seg_lengths.append(arr.shape[0])
return packed, seg_lengths
Другой рычаг — инициализировать скаляры в исходном словаре сразу как jnp.float64 (или с выбранным вами типом). Тогда точность останется согласованной, даже если сохранить форму «список списков». Можно также включить 64-битный режим, чтобы JAX, как и NumPy, по умолчанию использовал float64.
Небольшое примечание, упрощающее конвейер: словари — родные PyTree в JAX. В зависимости от сценария, преобразовывать их в списки вовсе не обязательно — tree-утилиты уже понимают вложенные dict.
Почему это важно в обучающих циклах
Динамика оптимизаторов чувствительна к числовой точности. Незаметные различия в правилах приведения типов меняют градиенты и величину шагов, в итоге приводя к иным траекториям сходимости и финальным метрикам. Если на вход иногда попадают слабо типизированные числа Python, а иногда — строго типизированные массивы JAX, вы фактически запускаете два разных численных режима.
Явная фиксация типов обеспечивает повторяемость при прочих равных. К тому же она избавляет от трудных для диагностики расхождений между двумя, казалось бы, идентичными запусками.
Выводы
Если переход от «списка списков» к «списку jnp.array» меняет результат, вероятная причина — обработка JAX слабо типизированных скаляров и возникающие различия в точности. Сделайте dtype однозначным на границе: создавайте массивы с явным dtype для листьев, инициализируйте входы как jnp.float64 (или в выбранной вами точности) либо включайте 64-битный режим, когда это соответствует вашей нагрузке. Согласованность типов по всему конвейеру предотвращает случайные понижения точности и стабилизирует поведение оптимизации.