Oct 25, 2025, 17:00
Fix RLlib CTDE crashes with PrioritizedEpisodeReplayBuffer: use EpisodeReplayBuffer and stable settings
Fix RLlib CTDE multi-agent crashes: when complex Dict/Tuple observations break PrioritizedEpisodeReplayBuffer, switch to EpisodeReplayBuffer with a few tuned replay settings.
Centralized Training, Decentralized Execution (CTDE) with RLlib is a powerful pattern, but glue code around multi-agent grouping and replay buffers can be surprisingly sensitive to observation structure. This guide walks through a practical failure mode that appears when combining Ray RLlib’s GroupAgentsWrapper, a custom RLModule for CTDE, and PrioritizedEpisodeReplayBuffer, and shows a configuration that trains reliably without changing model logic or environment semantics.
Problem overview
A custom RLModule implementing ValueFunctionAPI, TargetNetworkAPI, and QNetAPI is used with PPO, APPO, and SAC. A multi-agent environment is grouped into a single-agent interface via GroupAgentsWrapper. In a test environment (a Rock–Paper–Scissors variant), training succeeds with PPO/APPO/SAC under ray.tune.Tuner. However, in a real environment with the same observation-space structure, training fails when using PrioritizedEpisodeReplayBuffer. The first visible symptom is an assertion in single_agent_episode.concat_episode:
assert np.all(other.observations[0] == self.observations[-1])
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
After bypassing that assertion for debugging, training proceeds for a bit and then fails deeper inside the prioritized episode buffer’s segment tree during sampling with an out-of-range prefix-sum assertion.
The grouped observation structure is consistent across test and real settings: a multi-agent Dict({"agent1": Box, "agent2": Box}) becomes a single-agent Dict({"grouped": Tuple(Box, Box)}). Despite that, the prioritized buffer path breaks only on the real environment.
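One plausible mechanism, offered here as an assumption rather than something confirmed in the original report: when observations are stored as tuples of NumPy arrays, comparing two such tuples forces Python to take the truth value of an element-wise array comparison. With length-1 Boxes (as in the test environment) that truth value is unambiguous, but with multi-element arrays the comparison raises exactly the ValueError shown above:
import numpy as np

# Length-1 observations: bool() of the element-wise comparison is unambiguous.
a1 = (np.array([0.0], dtype=np.float32), np.array([1.0], dtype=np.float32))
b1 = (np.array([0.0], dtype=np.float32), np.array([1.0], dtype=np.float32))
print(np.all(a1 == b1))  # True

# Multi-element observations (hypothetical real-env shape): the tuple comparison
# calls bool() on a multi-element array result and raises before np.all runs.
a2 = (np.array([0.0, 1.0, 2.0], dtype=np.float32),)
b2 = (np.array([0.0, 1.0, 2.0], dtype=np.float32),)
try:
    np.all(a2 == b2)
except ValueError as exc:
    print(exc)  # The truth value of an array with more than one element is ambiguous...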
Repro-style example
The following environment mirrors the working Rock–Paper–Scissors test. It keeps the same program logic while using different symbols and names.
import gymnasium as gym
import numpy as np
from ray.rllib.env.multi_agent_env import MultiAgentEnv


class RpsArena(MultiAgentEnv):
    # Move encodings: rock, paper, scissors.
    MOVE_R = 0
    MOVE_P = 1
    MOVE_S = 2
    # Payoff table: (move_p1, move_p2) -> (reward_p1, reward_p2).
    OUTCOME = {
        (MOVE_R, MOVE_R): (0, 0),
        (MOVE_R, MOVE_P): (-1, 1),
        (MOVE_R, MOVE_S): (1, -1),
        (MOVE_P, MOVE_R): (1, -1),
        (MOVE_P, MOVE_P): (0, 0),
        (MOVE_P, MOVE_S): (-1, 1),
        (MOVE_S, MOVE_R): (-1, 1),
        (MOVE_S, MOVE_P): (1, -1),
        (MOVE_S, MOVE_S): (0, 0),
    }
    def __init__(self, env_config=None):
        super().__init__()
        self.agent_names = ["p1", "p2"]
        self.agents = self.possible_agents = self.agent_names
        # Each player both observes and acts with a single float in [0, 2].
        self.observation_spaces = self.action_spaces = gym.spaces.Dict({
            "p1": gym.spaces.Box(low=0, high=2, shape=(1,)),
            "p2": gym.spaces.Box(low=0, high=2, shape=(1,)),
        })
        self.turns = 0
    def reset(self, *, seed=None, options=None):
        self.turns = 0
        return {
            "p1": np.array([0.0], dtype=np.float32),
            "p2": np.array([0.0], dtype=np.float32),
        }, {}
    def step(self, action_dict):
        self.turns += 1
        # Actions arrive as float Boxes; cast to ints for the payoff lookup.
        m1 = int(action_dict["p1"].item())
        m2 = int(action_dict["p2"].item())
        # Each player observes the opponent's last move.
        obs = {
            "p1": np.array([m2], dtype=np.float32),
            "p2": np.array([m1], dtype=np.float32),
        }
        r1, r2 = self.OUTCOME[m1, m2]
        rew = {"p1": r1, "p2": r2}
        # End the episode after 10 turns (terminated and truncated both set).
        dones = {"__all__": bool(self.turns >= 10)}
        truncs = {"__all__": bool(self.turns >= 10)}
        return obs, rew, dones, truncs, {}
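As a quick sanity check (illustrative only, not part of the original report), a short random rollout shows the per-agent Dict interface before grouping:
env = RpsArena()
obs, _ = env.reset()
for _ in range(3):
    moves = {
        aid: np.array([float(np.random.randint(0, 3))], dtype=np.float32)
        for aid in env.agent_names
    }
    obs, rew, terminateds, truncateds, _ = env.step(moves)
    print(obs, rew, terminateds["__all__"])
The grouped wrapper below converts this two-agent Dict interface into the single-agent Dict/Tuple view that the CTDE module consumes.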
class RpsGrouped(MultiAgentEnv):
    """Groups both players into a single "pack" agent with Tuple observations."""
    def __init__(self, env_config=None):
        super().__init__()
        base = RpsArena(env_config)
        # Flatten the per-agent Dict spaces into Tuples (sorted by agent id).
        tuple_obs = self._to_tuple_space(base.observation_spaces)
        tuple_act = self._to_tuple_space(base.action_spaces)
        # GroupAgentsWrapper exposes both players as one logical agent.
        self.core = base.with_agent_groups(
            groups={"pack": ["p1", "p2"]},
            obs_space=tuple_obs,
            act_space=tuple_act,
        )
        self.agent_names = ["pack"]
        self.agents = self.possible_agents = self.agent_names
        self.orig_ids = base.agent_names
        # Single-agent view of the grouped spaces: Dict({"pack": Tuple(Box, Box)}).
        self.observation_space = gym.spaces.Dict({"pack": tuple_obs})
        self.action_space = gym.spaces.Dict({"pack": tuple_act})
    def reset(self, *, seed=None, options=None):
        obs, info = self.core.reset(seed=seed, options=options)
        # The wrapper returns grouped observations as lists; expose them as tuples.
        grouped = {k: tuple(v) for k, v in obs.items()}
        return grouped, info
    def step(self, action_dict):
        obs, rew, dones, truncs, info = self.core.step(action_dict)
        grouped = {k: tuple(v) for k, v in obs.items()}
        # Collapse the grouped rewards (scalars or per-agent lists) into one team reward.
        total_rew = float(np.sum(list(rew.values())))
        return grouped, total_rew, dones["__all__"], truncs["__all__"], info
    @staticmethod
    def _to_tuple_space(dspace: gym.spaces.Dict) -> gym.spaces.Tuple:
        keys = sorted(dspace.keys())
        return gym.spaces.Tuple(tuple(dspace[k] for k in keys))
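A quick check of the grouped interface (illustrative; it assumes the wrapper groups observations in the declared agent order):
genv = RpsGrouped()
print(genv.observation_space)  # Dict({"pack": Tuple(Box, Box)})
gobs, _ = genv.reset()
print(gobs["pack"])            # (array([0.], dtype=float32), array([0.], dtype=float32))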
A SAC configuration that triggers the failing path looks like this (CtdeModule is the custom CTDE RLModule, defined elsewhere):
from ray.rllib.algorithms.sac import SACConfig
from ray.rllib.core.rl_module.rl_module import RLModuleSpec  # import path may vary across Ray versions

# Instantiate the env once so the grouped spaces exist (they are set in __init__).
probe_env = RpsGrouped()

train_cfg = (
    SACConfig()
    .environment(RpsGrouped, env_config={})
    .framework("torch")
    .rl_module(
        rl_module_spec=RLModuleSpec(
            module_class=CtdeModule,  # custom CTDE RLModule (user-defined)
            observation_space=probe_env.observation_space,
            action_space=probe_env.action_space,
        )
    )
    .training(
        twin_q=True,
        replay_buffer_config={"type": "PrioritizedEpisodeReplayBuffer"},
    )
    .evaluation(evaluation_config=SACConfig.overrides(explore=False))
)
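Running this under ray.tune.Tuner, as in the original setup, reproduces the crash. A minimal launch sketch (it assumes a recent Ray version that accepts AlgorithmConfig objects as param_space):
from ray import tune

tuner = tune.Tuner(
    "SAC",
    param_space=train_cfg,  # older Ray versions may require train_cfg.to_dict()
)
results = tuner.fit()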
What’s actually going wrong
The failures appear first while concatenating episodes and later while sampling from the prioritized episode buffer's segment tree. The observed pattern is that PrioritizedEpisodeReplayBuffer does not handle complex (Dict/Tuple) observations in combination with the custom RLModule in CTCE and CTDE setups. The same model and observation-space structure work in the grouped Rock–Paper–Scissors test but not in the real environment, even though both expose Dict({"grouped": Tuple(Box, Box)}) at the single-agent interface.
Working fix
Switching away from PrioritizedEpisodeReplayBuffer resolves the issue. For CTCE and CTDE, use EpisodeReplayBuffer; for DTDE, use MultiAgentEpisodeReplayBuffer. In addition, in replay_buffer_config set replay_sequence_length to 1, replay_burn_in to 0, and replay_zero_init_states to True, and collect complete episodes on the env runners. With these changes, training proceeds correctly without altering the environment or the RLModule APIs.
Here is the adjusted SAC setup for CTDE:
stable_cfg = (
    SACConfig()
    .environment(RpsGrouped, env_config={})
    .framework("torch")
    .rl_module(
        rl_module_spec=RLModuleSpec(
            module_class=CtdeModule,
            observation_space=probe_env.observation_space,
            action_space=probe_env.action_space,
        )
    )
    .training(
        twin_q=True,
        replay_buffer_config={
            # Non-prioritized episode buffer copes with Dict/Tuple observations.
            "type": "EpisodeReplayBuffer",
            "replay_sequence_length": 1,
            "replay_burn_in": 0,
            "replay_zero_init_states": True,
        },
    )
    # Collect whole episodes, as in the working configuration.
    .env_runners(batch_mode="complete_episodes")
    .evaluation(evaluation_config=SACConfig.overrides(explore=False))
)
If you operate in DTDE, replace EpisodeReplayBuffer with MultiAgentEpisodeReplayBuffer and keep the same sequence-length, burn-in, and zero-init settings.
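For reference, a DTDE setup would keep the same stability settings and only swap the buffer type (a sketch mirroring the config above):
dtde_replay_buffer_config = {
    "type": "MultiAgentEpisodeReplayBuffer",
    "replay_sequence_length": 1,
    "replay_burn_in": 0,
    "replay_zero_init_states": True,
}
# Pass it the same way: .training(replay_buffer_config=dtde_replay_buffer_config)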
Why this matters
Multi-agent pipelines in RLlib depend on consistent episode stitching and replay sampling. When observations become structured (a Dict of Tuples of arrays), certain buffer implementations can be intolerant of that shape in combination with custom module plumbing. Knowing which replay buffers are robust to complex observations under CTCE, CTDE, or DTDE saves hours of debugging and protects against subtle crashes that surface late in training.
Practical takeaway
If you see assertions in single_agent_episode.concat_episode or sampling errors in the prioritized episode buffer when using complex observations and a custom RLModule, first try replacing PrioritizedEpisodeReplayBuffer. EpisodeReplayBuffer works for CTCE and CTDE, while MultiAgentEpisodeReplayBuffer is appropriate for DTDE. Set replay_sequence_length to 1, replay_burn_in to 0, replay_zero_init_states to True, and collect batches as complete episodes. This combination keeps the training loop stable without touching your environment logic or the CTDE module.
The article is based on a StackOverflow question by Nelson Salazar and his own answer.