2025, Dec 28 15:02

Списки Python vs jnp.array в JAX: как dtype влияет на обучение RNN

Почему списки Python и jnp.array в JAX дают разную точность (float32/float64) и траектории RNN на Flax/Optax; как закрепить dtype и стабилизировать обучение.

Когда списки Python встречаются с массивами JAX: почему крошечная деталь в типах срывает оптимизацию

Замена функции препроцессинга, которая возвращает список списков, на вариант, возвращающий список jnp.array, кажется безобидным рефакторингом. Однако в обучающих циклах на JAX эта тонкость способна заметно изменить результат. В одном конвейере оптимизации RNN на Flax и Optax тот же пайплайн давал метрику около 0.9997, когда параметры подготавливались как список списков, и лишь примерно 0.998, когда те же параметры формировались как список jnp.array. Все прочее оставалось неизменным: сиды, число шагов и итераций, а также исходный словарь параметров.

Этот разбор показывает, откуда берётся расхождение, как слабо типизированные скаляры в JAX ненавязчиво уводят вас в разные режимы точности и что сделать, чтобы обучение вело себя стабильно.

Минимальный пример, который воспроизводит расхождение

Пайплайн стартует с вложенного словаря скалярных параметров. На каждом шаге времени словарь «сплющивается» и подаётся в оптимизационную петлю на RNN. Пример структуры входа:

initial_params = {
    "param1": {
        "gamma": 0.1,
        "delta": -3 * jnp.pi / 2,
    }
}

Функцию препроцессинга реализовали в двух вариантах. В одном она возвращает список списков; в другом — список jnp.array. Логика идентична, отличается лишь тип контейнера.

Вариант, который возвращает списки Python из листьев:

import jax
import jax.numpy as jnp

def pack_params_from_mapping(cfg_tree):
    """
    Convert a nested mapping of parameters to a flat Python list and record segment lengths.

    Args:
        cfg_tree: Nested mapping of parameters.

    Returns:
        tuple: list of lists of leaves, and list of segment lengths.
    """
    packed = []
    seg_lengths = []
    for branch in cfg_tree.values():
        leaf_vals = jax.tree_util.tree_leaves(branch)
        packed.append(leaf_vals)
        seg_lengths.append(len(leaf_vals))
    return packed, seg_lengths

Вариант, который возвращает массивы JAX из листьев:

import jax
import jax.numpy as jnp

def pack_params_from_mapping(cfg_tree):
    """
    Convert a nested mapping of parameters to a flat list of jnp.array and record segment lengths.

    Args:
        cfg_tree: Nested mapping of parameters.

    Returns:
        tuple: list of jnp.array leaves, and list of segment lengths.
    """
    packed = []
    seg_lengths = []
    for branch in cfg_tree.values():
        leaf_vals = jax.tree_util.tree_leaves(branch)
        arr = jnp.array(leaf_vals)
        packed.append(arr)
        seg_lengths.append(arr.shape[0])
    return packed, seg_lengths

Что происходит на самом деле: слабо типизированные скаляры и точность

Существенное различие между подходами не в контейнерах как таковых, а в типах данных (dtype), попадающих в последующие вычисления. Числа с плавающей точкой Python внутри списков JAX трактует как слабо типизированные значения. На практике слабо типизированные скаляры подстраивают свой dtype под то, с чем они комбинируются. В смешанных выражениях это незаметно тянет вычисления к более низкой точности.

Посмотрите на короткую самодостаточную иллюстрацию. При включённой 64-битной точности список чисел Python остаётся слабо типизированным, тогда как явный массив фиксирует dtype и далее ведёт к более точным вычислениям.

import jax
import jax.numpy as jnp

jax.config.update('jax_enable_x64', True)

val_list = [0.1, -4.71238898]
val_array = jnp.array(val_list)

x32 = jnp.float32(1.0)

# Смешивание массива float32 с числом Python (слабый тип)
# даёт результат в float32.
x32 + val_list[1]

# Смешивание массива float32 с элементом jnp.array (строгий тип)
# здесь приводит к результату float64.
x32 + val_array[1]

Иными словами, списки чисел Python могут уводить часть вычислений в float32, тогда как массивы с теми же значениями закрепляют их в float64. В оптимизационной петле такие мелочи накапливаются и могут заметно менять итоговую целевую метрику, даже если всё остальное детерминировано.

Как добиться одинакового поведения

Практическое решение — явно задавать dtype на границах вашего конвейера данных и не смешивать слабо типизированные числа Python с массивами JAX. С учётом описанного выше есть два простых пути.

Первый вариант — сразу после «сплющивания» структуры параметров конвертировать листья в jnp.array с явным dtype. Так вы фиксируете точность для всех последующих вычислений:

import jax
import jax.numpy as jnp

jax.config.update('jax_enable_x64', True)

def pack_params_from_mapping(cfg_tree):
    packed = []
    seg_lengths = []
    for branch in cfg_tree.values():
        leaf_vals = jax.tree_util.tree_leaves(branch)
        arr = jnp.array(leaf_vals, dtype=jnp.float64)
        packed.append(arr)
        seg_lengths.append(arr.shape[0])
    return packed, seg_lengths

Другой рычаг — инициализировать скаляры в исходном словаре сразу как jnp.float64 (или с выбранным вами типом). Тогда точность останется согласованной, даже если сохранить форму «список списков». Можно также включить 64-битный режим, чтобы JAX, как и NumPy, по умолчанию использовал float64.

Небольшое примечание, упрощающее конвейер: словари — родные PyTree в JAX. В зависимости от сценария, преобразовывать их в списки вовсе не обязательно — tree-утилиты уже понимают вложенные dict.

Почему это важно в обучающих циклах

Динамика оптимизаторов чувствительна к числовой точности. Незаметные различия в правилах приведения типов меняют градиенты и величину шагов, в итоге приводя к иным траекториям сходимости и финальным метрикам. Если на вход иногда попадают слабо типизированные числа Python, а иногда — строго типизированные массивы JAX, вы фактически запускаете два разных численных режима.

Явная фиксация типов обеспечивает повторяемость при прочих равных. К тому же она избавляет от трудных для диагностики расхождений между двумя, казалось бы, идентичными запусками.

Выводы

Если переход от «списка списков» к «списку jnp.array» меняет результат, вероятная причина — обработка JAX слабо типизированных скаляров и возникающие различия в точности. Сделайте dtype однозначным на границе: создавайте массивы с явным dtype для листьев, инициализируйте входы как jnp.float64 (или в выбранной вами точности) либо включайте 64-битный режим, когда это соответствует вашей нагрузке. Согласованность типов по всему конвейеру предотвращает случайные понижения точности и стабилизирует поведение оптимизации.