2025, Oct 31 05:16

CTDE в RLlib: стабильная настройка SAC без PrioritizedEpisodeReplayBuffer

CTDE в RLlib: почему падает PrioritizedEpisodeReplayBuffer с GroupAgentsWrapper и кастомным RLModule, и как стабилизировать обучение SAC с EpisodeReplayBuffer.

Централизованное обучение, децентрализованное исполнение (CTDE) в RLlib — мощный подход, однако связующий код вокруг группировки агентов и реплеев может неожиданно зависеть от структуры наблюдений. В этом руководстве разбирается практический сценарий сбоя, возникающий при сочетании GroupAgentsWrapper из Ray RLlib, кастомного RLModule для CTDE и PrioritizedEpisodeReplayBuffer, а также показана конфигурация, которая стабильно обучается без изменений в логике модели и семантике окружения.

Краткий разбор проблемы

Пользовательский RLModule, реализующий ValueFunctionAPI, TargetNetworkAPI и QNetAPI, применяется с PPO, APPO и SAC. Мультиагентное окружение сводится к одноагентному интерфейсу через GroupAgentsWrapper. В тестовом окружении (вариация «Камень–Ножницы–Бумага») обучение успешно проходит с PPO/APPO/SAC под управлением ray.tune.Tuner. Но в реальном окружении с той же структурой пространства наблюдений обучение срывается при использовании PrioritizedEpisodeReplayBuffer. Первый симптом — утверждение в single_agent_episode.concat_episode:

assert np.all(other.observations[0] == self.observations[-1]) — ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

Если обойти это утверждение ради отладки, обучение немного продолжается и затем падает глубже — во время выборки из сегментного дерева приоритезированного буфера эпизодов с ошибкой выхода за диапазон при префикс-суммировании.

Структура сгруппированных наблюдений совпадает в тестовой и реальной настройках: мультиагентный Dict({"agent1": Box, "agent2": Box}) превращается в одноагентный Dict({"grouped": Tuple(Box, Box)}). Несмотря на это, путь через приоритезированный буфер ломается только в реальном окружении.

Пример в стиле «repro»

Следующее окружение повторяет рабочий тест «Камень–Ножницы–Бумага». Логика программы сохраняется, меняются лишь обозначения и имена.

class RpsArena(MultiAgentEnv):
    MOVE_R = 0
    MOVE_P = 1
    MOVE_S = 2
    OUTCOME = {
        (MOVE_R, MOVE_R): (0, 0),
        (MOVE_R, MOVE_P): (-1, 1),
        (MOVE_R, MOVE_S): (1, -1),
        (MOVE_P, MOVE_R): (1, -1),
        (MOVE_P, MOVE_P): (0, 0),
        (MOVE_P, MOVE_S): (-1, 1),
        (MOVE_S, MOVE_R): (-1, 1),
        (MOVE_S, MOVE_P): (1, -1),
        (MOVE_S, MOVE_S): (0, 0),
    }
    def __init__(self, env_config=None):
        super().__init__()
        self.agent_names = ["p1", "p2"]
        self.agents = self.possible_agents = self.agent_names
        self.observation_spaces = self.action_spaces = gym.spaces.Dict({
            "p1": gym.spaces.Box(low=0, high=2, shape=(1,)),
            "p2": gym.spaces.Box(low=0, high=2, shape=(1,)),
        })
        self.turns = 0
    def reset(self, *, seed=None, options=None):
        self.turns = 0
        return {
            "p1": np.array([0.0], dtype=np.float32),
            "p2": np.array([0.0], dtype=np.float32),
        }, {}
    def step(self, action_dict):
        self.turns += 1
        m1 = int(action_dict["p1"].item())
        m2 = int(action_dict["p2"].item())
        obs = {
            "p1": np.array([m2], dtype=np.float32),
            "p2": np.array([m1], dtype=np.float32),
        }
        r1, r2 = self.OUTCOME[m1, m2]
        rew = {"p1": r1, "p2": r2}
        dones = {"__all__": bool(self.turns >= 10)}
        truncs = {"__all__": bool(self.turns >= 10)}
        return obs, rew, dones, truncs, {}
class RpsGrouped(MultiAgentEnv):
    def __init__(self, env_config=None):
        super().__init__()
        base = RpsArena(env_config)
        tuple_obs = self._to_tuple_space(base.observation_spaces)
        tuple_act = self._to_tuple_space(base.action_spaces)
        self.core = base.with_agent_groups(
            groups={"pack": ["p1", "p2"]},
            obs_space=tuple_obs,
            act_space=tuple_act,
        )
        self.agent_names = ["pack"]
        self.agents = self.possible_agents = self.agent_names
        self.orig_ids = base.agent_names
        self.observation_space = gym.spaces.Dict({"pack": tuple_obs})
        self.action_space = gym.spaces.Dict({"pack": tuple_act})
    def reset(self, *, seed=None, options=None):
        obs, info = self.core.reset(seed=seed, options=options)
        grouped = {k: tuple(v) for k, v in obs.items()}
        return grouped, info
    def step(self, action_dict):
        obs, rew, dones, truncs, info = self.core.step(action_dict)
        grouped = {k: tuple(v) for k, v in obs.items()}
        total_rew = sum(rew.values())
        return grouped, total_rew, dones["__all__"], truncs["__all__"], info
    @staticmethod
    def _to_tuple_space(dspace: gym.spaces.Dict) -> gym.spaces.Tuple:
        keys = sorted(dspace.keys())
        return gym.spaces.Tuple(tuple(dspace[k] for k in keys))

Конфигурация SAC, которая воспроизводит сбойный путь с кастомным модулем для CTDE, выглядит так:

train_cfg = (
    SACConfig()
    .environment(RpsGrouped, env_config={})
    .framework("torch")
    .rl_module(
        rl_module_spec=RLModuleSpec(
            module_class=CtdeModule,
            observation_space=RpsGrouped.observation_space,
            action_space=RpsGrouped.action_space,
        )
    )
    .training(
        twin_q=True,
        replay_buffer_config={"type": "PrioritizedEpisodeReplayBuffer"},
    )
    .evaluation(evaluation_config=SACConfig.overrides(exploration=False))
)

Что на самом деле ломается

Сбои проявляются при склейке эпизодов, а позднее — при выборке из сегментного дерева приоритезированного буфера эпизодов. Наблюдаемая закономерность: PrioritizedEpisodeReplayBuffer не справляется со сложными наблюдениями в сочетании с кастомным RLModule при использовании для CTCE и CTDE. Та же модель и структура пространства наблюдений работают в сгруппированном тесте «Камень–Ножницы–Бумага», но не в реальном окружении, хотя оба предоставляют на одноагентном интерфейсе Dict({"grouped": Tuple(Box, Box)}).

Рабочее решение

Отказ от PrioritizedEpisodeReplayBuffer решает проблему. Для CTCE и CTDE используйте EpisodeReplayBuffer. Для DTDE — MultiAgentEpisodeReplayBuffer. Дополнительно задайте в replay_buffer_config параметры: replay_sequence_length = 1, replay_burn_in = 0, replay_zero_init_states = True, а сбор данных ведите целыми эпизодами. С этими изменениями обучение идет корректно без правок окружения или API RLModule.

Ниже — обновленная настройка SAC для CTDE:

stable_cfg = (
    SACConfig()
    .environment(RpsGrouped, env_config={})
    .framework("torch")
    .rl_module(
        rl_module_spec=RLModuleSpec(
            module_class=CtdeModule,
            observation_space=RpsGrouped.observation_space,
            action_space=RpsGrouped.action_space,
        )
    )
    .training(
        twin_q=True,
        replay_buffer_config={
            "type": "EpisodeReplayBuffer",
            "replay_sequence_length": 1,
            "replay_burn_in": 0,
            "replay_zero_init_states": True,
        },
    )
    .env_runners(batch_mode="complete_episodes")
    .evaluation(evaluation_config=SACConfig.overrides(exploration=False))
)

Если вы работаете в DTDE, замените EpisodeReplayBuffer на MultiAgentEpisodeReplayBuffer и сохраните те же значения длины последовательности, burn-in и нуловой инициализации состояний.

Почему это важно

Мультиагентные пайплайны в RLlib зависят от корректной склейки эпизодов и выборки из реплея. Когда наблюдения становятся структурированными — Dict из Tuple массивов — некоторые реализации буферов могут оказаться чувствительны к такой форме в сочетании с пользовательской обвязкой модуля. Понимание того, какие буферы устойчивы к сложным наблюдениям при CTCE, CTDE или DTDE, экономит часы отладки и помогает избежать скрытых падений, проявляющихся уже в ходе обучения.

Практический вывод

Если при работе со сложными наблюдениями и кастомным RLModule вы видите утверждения в single_agent_episode.concat_episode или ошибки выборки в приоритезированном буфере, сперва попробуйте заменить PrioritizedEpisodeReplayBuffer. Для CTCE и CTDE подходит EpisodeReplayBuffer, для DTDE — MultiAgentEpisodeReplayBuffer. Задайте replay_sequence_length = 1, replay_burn_in = 0, replay_zero_init_states = True и собирайте батчи как полные эпизоды. Такое сочетание стабилизирует цикл обучения без изменений в логике окружения или модуле CTDE.

Статья основана на вопросе на StackOverflow от Nelson Salazar и ответе от Nelson Salazar.