2025, Nov 01 15:46

Как избежать утечки таргетов между эпизодами в DQN с RNN в TorchRL

Разбираем, почему в DQN с RNN в TorchRL таргеты не «перетекают» через done: корректное маскирование, буфер воспроизведения и SliceSampler для безопасных батчей

То, как DQN с RNN работает с эпизодами и батчами в TorchRL, часто вызывает практический вопрос: если коллектор сшивает несколько эпизодов в один батч, «утекают» ли целевые значения через границы эпизодов во время обучения? Короткий ответ: нет. Корректная обработка маркеров done/terminated/truncated не допускает перекрёстного влияния, а буферы воспроизведения могут выбирать либо отдельные шаги, либо срезы траекторий, не смешивая несвязанные данные.

Минимальный пример кажущейся проблемы

Представьте батч, в котором два эпизода просто стоят подряд. Если механически выровнять значения следующего состояния по всему батчу, не учитывая конец эпизода, то последний переход первого эпизода по ошибке возьмёт в качестве следующего состояния первое состояние второго эпизода. Фрагмент ниже показывает эту ловушку и корректное маскирование, которое её устраняет.

import torch
# Два идущих подряд эпизода по 3 шага: индексы [0..2] и [3..5]
rew = torch.tensor([0.1, 0.0, 1.0, 0.2, -0.1, 0.3])
done = torch.tensor([0,   0,   1,   0,    0,    1], dtype=torch.float32)
gamma = 0.99
# Будем считать, что это значения max_a' Q(s_{t+1}, a') для каждой позиции t
q_next = torch.tensor([0.5, 0.4, 0.9, 0.7, 0.6, 0.3])
# Выравниваем значения следующего состояния с текущими переходами сдвигом влево
q_next_shift = torch.roll(q_next, shifts=-1)
q_next_shift[-1] = 0.0  # заполнитель для последнего элемента батча
# Наивная цель игнорирует границы эпизодов (НЕПРАВИЛЬНО)
target_naive = rew + gamma * q_next_shift
# Правильная цель маскирует терминальные шаги: без бутстрепа за пределами done
# Это ключ к избеганию перекрёстного смешения
target_masked = rew + gamma * (1.0 - done) * q_next_shift
print("target_naive:", target_naive)
print("target_masked:", target_masked)

Наивный расчёт берёт следующий элемент даже тогда, когда текущий шаг терминальный, тем самым неявно «сшивая» два не связанные эпизода. Маскирование обнуляет бутстреп на терминалах, поэтому цели обрываются точно на шагах done/terminated/truncated.

Как это действительно устроено в коллекторах и функциях потерь TorchRL

Коллекторы могут возвращать батчи, в которых смешаны фрагменты различных траекторий. Передавать такой батч в цели/лоссы, учитывающие временную структуру, безопасно: они опираются на маркеры done/terminated/truncated, чтобы разделять траектории и исключать влияние между эпизодами. В частности, DQNLoss записывает данные в буфер воспроизведения, а обучение затем выбирает либо отдельные переходы, либо целые срезы траекторий. Когда нужны срезы, использование SliceSampler гарантирует, что окна не выходят за границы эпизодов. В обоих случаях перекрёстного смешения нет.

Практический рецепт, если вы вычисляете таргеты вручную

Если вы когда-либо вычисляете бутстрепные таргеты самостоятельно, а не делегируете это готовой функции потерь, обязательно применяйте эпизодную маску. Это та же идея, что показана выше, — именно поэтому можно складывать несколько эпизодов в один батч, пока соблюдаются маркеры done. Ниже — компактный шаблон маскирования с тем же безопасным поведением.

def safe_dqn_targets(r_t, done_t, q_next_t, gamma):
    # r_t: [T] вознаграждения
    # done_t: [T] флаги {0,1}, где 1 на терминальных шагах
    # q_next_t: [T] выровненные Q(s_{t+1}), как в предыдущем примере
    # gamma: скалярный дисконт
    return r_t + gamma * (1.0 - done_t) * q_next_t
# Пример повторного использования с батчем выше
targets = safe_dqn_targets(rew, done, q_next_shift, gamma)

Выбор отдельных переходов и срезов траекторий

В такой схеме буферы воспроизведения могут выбирать единичные переходы или непрерывные окна, которые не пересекают терминалы. Когда требуются полные или частичные траектории, подходит SliceSampler: он уважает границы эпизодов, так что временные вычисления остаются локальными для каждой траектории. Ниже — концептуальный помощник, показывающий, как перечислять окна внутри эпизода, не переходя через терминалы.

def windows_within_episodes(done_flags, window):
    idx = 0
    spans = []
    n = len(done_flags)
    while idx < n:
        # находим отрезок эпизода [ep_start, ep_end], включая терминальный шаг
        ep_start = idx
        while idx < n and done_flags[idx].item() == 0:
            idx += 1
        ep_end = idx  # терминал на ep_end
        # генерируем окна фиксированного размера полностью внутри [ep_start, ep_end]
        for s in range(ep_start, ep_end + 1):
            e = s + window
            if e - 1 > ep_end:
                break
            spans.append((s, e))
        idx += 1  # переходим за терминальный шаг
    return spans
# Пример: все окна длины 2, которые не пересекают границы эпизодов
spans = windows_within_episodes(done, window=2)
print(spans)

Это концептуально отражает работу семплера, понимающего траектории. В продакшене вы будете полагаться на встроенный SliceSampler, чтобы буфер воспроизведения возвращал только непрерывные временные фрагменты из одной траектории.

Связанные компоненты в TorchRL

SliceSampler предназначен для безопасной выборки срезов траекторий. Временные цели, такие как GAE, показывают, как работать со «стеком» траекторий, уважая маркеры done/terminated/truncated. Есть и LLM‑коллекторы, которые умеют выдавать только полные траектории; этот подход можно обобщить и на другие коллекторы.

Почему это важно

При использовании RNN или любых моделей, чувствительных ко времени, в обучении с подкреплением важно чётко знать границы эпизодов. Неверно выровненные таргеты рядом с терминалами и окна, пересекающие эпизоды, могут незаметно ухудшать обучение. Корректное маскирование и выборка с учётом границ гарантируют, что склеенные батчи остаются эффективными и при этом сохраняют целостность каждой траектории.

Выводы

Сшитые батчи из коллекторов можно безопасно передавать в цели TorchRL, поскольку маркеры done/terminated/truncated предотвращают утечку между эпизодами. Рабочий процесс с DQNLoss записывает данные в буфер воспроизведения, а затем выбирает отдельные шаги или срезы внутри эпизода; с SliceSampler вы всегда остаетесь в пределах одной траектории. Если вы рассчитываете любой бутстрепный таргет вручную, применяйте эпизодную маску, чтобы вычисления не переходили через терминальные шаги. Этого достаточно, чтобы пользоваться эффективностью крупного батч‑обучения, не смешивая траектории.

Статья основана на вопросе с StackOverflow от Ícaro Lorran и ответе от vmoens.