2025, Oct 31 13:00

How DQN with RNNs in TorchRL Handles Episode Boundaries: Masking Done Flags and Safe Slice Sampling

Learn how DQN with RNNs in TorchRL avoids cross-episode leakage using done masks, replay buffers, and SliceSampler for safe bootstrapped targets in training.

Training DQN with RNNs in TorchRL often raises a practical question about how episodes and batches are handled: if a collector stitches multiple episodes into a single batch, do value targets leak across episode boundaries during training? The short answer is no. Proper handling of done/terminated/truncated markers prevents cross-contamination, and replay buffers can sample either single steps or trajectory slices without mixing unrelated data.

Minimal example of the apparent problem

Consider a batch that concatenates two episodes back-to-back. If one aligned next-state values across the whole batch without respecting episode ends, the last transition of episode 1 would accidentally use the first state of episode 2 as its next state. The snippet below demonstrates the pitfall and the correct masking that avoids it.

import torch
# Two back-to-back episodes of length 3 each: indices [0..2] and [3..5]
rew = torch.tensor([0.1, 0.0, 1.0, 0.2, -0.1, 0.3])
done = torch.tensor([0,   0,   1,   0,    0,    1], dtype=torch.float32)
gamma = 0.99
# Per-position state values: pretend q_vals[t] = max_a' Q(s_t, a')
q_vals = torch.tensor([0.5, 0.4, 0.9, 0.7, 0.6, 0.3])
# Align next-state values to current transitions via a left shift:
# q_next_shift[t] = max_a' Q(s_{t+1}, a')
q_next_shift = torch.roll(q_vals, shifts=-1)
q_next_shift[-1] = 0.0  # padding for the final element in the batch
# Naive target ignores episode boundaries (WRONG)
target_naive = rew + gamma * q_next_shift
# Correct target masks terminals: no bootstrap beyond done
# This is the key to avoiding cross-contamination
target_masked = rew + gamma * (1.0 - done) * q_next_shift
print("target_naive:", target_naive)
print("target_masked:", target_masked)

The naive computation uses the next entry even when the current step is terminal, which would implicitly bridge two unrelated episodes. The masked computation zeroes out bootstrapping at terminals, so targets stop exactly at done/terminated/truncated steps.

What actually happens with TorchRL collectors and losses

Collectors may return batches that combine splits of different trajectories. Feeding such a batch to objectives that understand temporal data is safe because they rely on done/terminated/truncated markers to separate trajectories and prevent any cross-episode influence. In a typical DQNLoss workflow, collected data is written to a replay buffer, and training then samples either individual transitions or entire trajectory slices. When trajectory slices are needed, a SliceSampler ensures they stay within episode boundaries. In either case there is no cross-contamination.
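As a rough illustration, the sketch below builds two replay buffers over the same stitched data: one with a RandomSampler for unrelated single transitions, and one with a SliceSampler for contiguous within-episode windows. It is a minimal sketch rather than a full training loop; the field names ("observation", "action", the nested ("next", ...) entries) and the keyword arguments shown are assumptions that may vary across TorchRL versions.

import torch
from tensordict import TensorDict
from torchrl.data import LazyTensorStorage, TensorDictReplayBuffer
from torchrl.data import RandomSampler, SliceSampler

# Two stitched episodes of length 3, mirroring the toy tensors above.
# Key names are illustrative; exact names depend on your environment setup.
data = TensorDict(
    {
        "observation": torch.randn(6, 4),
        "action": torch.randint(0, 2, (6, 1)),
        ("next", "observation"): torch.randn(6, 4),
        ("next", "reward"): torch.tensor([0.1, 0.0, 1.0, 0.2, -0.1, 0.3]).unsqueeze(-1),
        ("next", "done"): torch.tensor([0, 0, 1, 0, 0, 1], dtype=torch.bool).unsqueeze(-1),
    },
    batch_size=[6],
)

# Buffer 1: random single transitions; each carries its own next state and done flag.
rb_steps = TensorDictReplayBuffer(
    storage=LazyTensorStorage(1000),
    sampler=RandomSampler(),
    batch_size=4,
)

# Buffer 2: contiguous length-2 windows that never cross an episode end.
rb_slices = TensorDictReplayBuffer(
    storage=LazyTensorStorage(1000),
    sampler=SliceSampler(slice_len=2, end_key=("next", "done")),
    batch_size=4,  # 4 steps in total = 2 slices of length 2
)

rb_steps.extend(data)
rb_slices.extend(data)
transitions = rb_steps.sample()  # shuffled single steps
windows = rb_slices.sample()     # within-episode temporal chunks

Either buffer can be filled directly from a collector loop; the slice variant is the one you want when an RNN needs contiguous, single-episode sequences.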

Practical fix if you implement targets manually

If you ever compute bootstrapped targets yourself instead of delegating to a provided loss, make sure to apply the episodic mask. This is the same idea shown above and is the reason why stacking multiple episodes in a single batch is fine as long as done markers are respected. Below is a compact pattern for masking that mirrors the safe behavior.

def safe_dqn_targets(r_t, done_t, q_next_t, gamma):
    # r_t: [T] rewards
    # done_t: [T] {0,1} flags with 1 at terminal steps
    # q_next_t: [T] aligned Q(s_{t+1}) as in the previous example
    # gamma: scalar discount
    return r_t + gamma * (1.0 - done_t) * q_next_t
# Example reuse with the batch above
targets = safe_dqn_targets(rew, done, q_next_shift, gamma)
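The same masking extends unchanged to batched sequences, which is how slice-sampled data for an RNN is typically shaped. As a small sketch, reusing the toy tensors above (a real slice-sampled batch would carry its own per-row data), stacking the two episodes into a [B, T] batch and applying the same function gives per-row targets that each stop at their own terminal.

# Reshape the stitched batch into [B, T] = [2, 3]: one row per episode
rew_bt = rew.reshape(2, 3)
done_bt = done.reshape(2, 3)
q_next_bt = q_next_shift.reshape(2, 3)
targets_bt = safe_dqn_targets(rew_bt, done_bt, q_next_bt, gamma)
# Row 0 ends with the bare terminal reward 1.0, row 1 with 0.3:
# the (1 - done) factor zeroes the bootstrap term at each row's last step
print(targets_bt)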

Sampling transitions vs trajectory slices

Replay buffers in this setup can draw single transitions or contiguous windows that do not cross terminals. When full or partial trajectories are required, a SliceSampler fits the job; it respects episode boundaries so that temporal computations remain local to each trajectory. The following conceptual helper demonstrates how one would enumerate in-episode windows without crossing terminals.

def windows_within_episodes(done_flags, window):
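    # done_flags: 1D {0,1} tensor; assumes the batch ends at a terminal step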
    idx = 0
    spans = []
    n = len(done_flags)
    while idx < n:
        # find episode segment [ep_start, ep_end], inclusive of terminal
        ep_start = idx
        while idx < n and done_flags[idx].item() == 0:
            idx += 1
        ep_end = idx  # terminal at ep_end
        # generate fixed-size windows fully inside [ep_start, ep_end]
        for s in range(ep_start, ep_end + 1):
            e = s + window
            if e - 1 > ep_end:
                break
            spans.append((s, e))
        idx += 1  # move past terminal
    return spans
# Example: all length-2 windows that never cross episode boundaries
spans = windows_within_episodes(done, window=2)
print(spans)
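# -> [(0, 2), (1, 3), (3, 5), (4, 6)]
# Each (s, e) is half-open, covering indices s..e-1; no window spans the
# boundary between index 2 (end of episode 1) and index 3 (start of episode 2)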

This mirrors what a trajectory-aware sampler does conceptually. In production you would rely on the built-in SliceSampler, as in the replay buffer sketch above, so the buffer returns contiguous temporal chunks drawn from a single trajectory only.

Related building blocks in TorchRL

SliceSampler is designed to sample slices of trajectories safely. Temporal objectives such as GAE illustrate how to operate on stacked trajectories while respecting done/terminated/truncated markers. There are also LLM collectors that can yield full trajectories only, and that capability could be generalized to other collectors.

Why this matters

When using RNNs or any temporally aware model in reinforcement learning, clarity about episode boundaries is essential. Misaligned targets near terminals and windows that cross episodes can silently degrade learning. Correct masking and boundary-aware sampling ensure that stacked batches remain efficient while preserving the integrity of each trajectory.

Takeaways

Stitched batches from collectors are safe to feed into TorchRL objectives because done/terminated/truncated markers prevent cross-episode leakage. DQNLoss workflows write data into a replay buffer and then sample individual steps or within-episode slices, and with a SliceSampler you stay inside a single trajectory. If you implement any bootstrapped target by hand, apply the episodic mask so no computation crosses terminal steps. That is enough to enjoy the efficiency of large batched training without mixing trajectories.

The article is based on a question from StackOverflow by Ícaro Lorran and an answer by vmoens.