2025, Oct 31 05:16
CTDE в RLlib: стабильная настройка SAC без PrioritizedEpisodeReplayBuffer
CTDE в RLlib: почему падает PrioritizedEpisodeReplayBuffer с GroupAgentsWrapper и кастомным RLModule, и как стабилизировать обучение SAC с EpisodeReplayBuffer.
Централизованное обучение, децентрализованное исполнение (CTDE) в RLlib — мощный подход, однако связующий код вокруг группировки агентов и реплеев может неожиданно зависеть от структуры наблюдений. В этом руководстве разбирается практический сценарий сбоя, возникающий при сочетании GroupAgentsWrapper из Ray RLlib, кастомного RLModule для CTDE и PrioritizedEpisodeReplayBuffer, а также показана конфигурация, которая стабильно обучается без изменений в логике модели и семантике окружения.
Краткий разбор проблемы
Пользовательский RLModule, реализующий ValueFunctionAPI, TargetNetworkAPI и QNetAPI, применяется с PPO, APPO и SAC. Мультиагентное окружение сводится к одноагентному интерфейсу через GroupAgentsWrapper. В тестовом окружении (вариация «Камень–Ножницы–Бумага») обучение успешно проходит с PPO/APPO/SAC под управлением ray.tune.Tuner. Но в реальном окружении с той же структурой пространства наблюдений обучение срывается при использовании PrioritizedEpisodeReplayBuffer. Первый симптом — утверждение в single_agent_episode.concat_episode:
assert np.all(other.observations[0] == self.observations[-1]) — ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
Если обойти это утверждение ради отладки, обучение немного продолжается и затем падает глубже — во время выборки из сегментного дерева приоритезированного буфера эпизодов с ошибкой выхода за диапазон при префикс-суммировании.
Структура сгруппированных наблюдений совпадает в тестовой и реальной настройках: мультиагентный Dict({"agent1": Box, "agent2": Box}) превращается в одноагентный Dict({"grouped": Tuple(Box, Box)}). Несмотря на это, путь через приоритезированный буфер ломается только в реальном окружении.
Пример в стиле «repro»
Следующее окружение повторяет рабочий тест «Камень–Ножницы–Бумага». Логика программы сохраняется, меняются лишь обозначения и имена.
class RpsArena(MultiAgentEnv):
MOVE_R = 0
MOVE_P = 1
MOVE_S = 2
OUTCOME = {
(MOVE_R, MOVE_R): (0, 0),
(MOVE_R, MOVE_P): (-1, 1),
(MOVE_R, MOVE_S): (1, -1),
(MOVE_P, MOVE_R): (1, -1),
(MOVE_P, MOVE_P): (0, 0),
(MOVE_P, MOVE_S): (-1, 1),
(MOVE_S, MOVE_R): (-1, 1),
(MOVE_S, MOVE_P): (1, -1),
(MOVE_S, MOVE_S): (0, 0),
}
def __init__(self, env_config=None):
super().__init__()
self.agent_names = ["p1", "p2"]
self.agents = self.possible_agents = self.agent_names
self.observation_spaces = self.action_spaces = gym.spaces.Dict({
"p1": gym.spaces.Box(low=0, high=2, shape=(1,)),
"p2": gym.spaces.Box(low=0, high=2, shape=(1,)),
})
self.turns = 0
def reset(self, *, seed=None, options=None):
self.turns = 0
return {
"p1": np.array([0.0], dtype=np.float32),
"p2": np.array([0.0], dtype=np.float32),
}, {}
def step(self, action_dict):
self.turns += 1
m1 = int(action_dict["p1"].item())
m2 = int(action_dict["p2"].item())
obs = {
"p1": np.array([m2], dtype=np.float32),
"p2": np.array([m1], dtype=np.float32),
}
r1, r2 = self.OUTCOME[m1, m2]
rew = {"p1": r1, "p2": r2}
dones = {"__all__": bool(self.turns >= 10)}
truncs = {"__all__": bool(self.turns >= 10)}
return obs, rew, dones, truncs, {}
class RpsGrouped(MultiAgentEnv):
def __init__(self, env_config=None):
super().__init__()
base = RpsArena(env_config)
tuple_obs = self._to_tuple_space(base.observation_spaces)
tuple_act = self._to_tuple_space(base.action_spaces)
self.core = base.with_agent_groups(
groups={"pack": ["p1", "p2"]},
obs_space=tuple_obs,
act_space=tuple_act,
)
self.agent_names = ["pack"]
self.agents = self.possible_agents = self.agent_names
self.orig_ids = base.agent_names
self.observation_space = gym.spaces.Dict({"pack": tuple_obs})
self.action_space = gym.spaces.Dict({"pack": tuple_act})
def reset(self, *, seed=None, options=None):
obs, info = self.core.reset(seed=seed, options=options)
grouped = {k: tuple(v) for k, v in obs.items()}
return grouped, info
def step(self, action_dict):
obs, rew, dones, truncs, info = self.core.step(action_dict)
grouped = {k: tuple(v) for k, v in obs.items()}
total_rew = sum(rew.values())
return grouped, total_rew, dones["__all__"], truncs["__all__"], info
@staticmethod
def _to_tuple_space(dspace: gym.spaces.Dict) -> gym.spaces.Tuple:
keys = sorted(dspace.keys())
return gym.spaces.Tuple(tuple(dspace[k] for k in keys))
Конфигурация SAC, которая воспроизводит сбойный путь с кастомным модулем для CTDE, выглядит так:
train_cfg = (
SACConfig()
.environment(RpsGrouped, env_config={})
.framework("torch")
.rl_module(
rl_module_spec=RLModuleSpec(
module_class=CtdeModule,
observation_space=RpsGrouped.observation_space,
action_space=RpsGrouped.action_space,
)
)
.training(
twin_q=True,
replay_buffer_config={"type": "PrioritizedEpisodeReplayBuffer"},
)
.evaluation(evaluation_config=SACConfig.overrides(exploration=False))
)
Что на самом деле ломается
Сбои проявляются при склейке эпизодов, а позднее — при выборке из сегментного дерева приоритезированного буфера эпизодов. Наблюдаемая закономерность: PrioritizedEpisodeReplayBuffer не справляется со сложными наблюдениями в сочетании с кастомным RLModule при использовании для CTCE и CTDE. Та же модель и структура пространства наблюдений работают в сгруппированном тесте «Камень–Ножницы–Бумага», но не в реальном окружении, хотя оба предоставляют на одноагентном интерфейсе Dict({"grouped": Tuple(Box, Box)}).
Рабочее решение
Отказ от PrioritizedEpisodeReplayBuffer решает проблему. Для CTCE и CTDE используйте EpisodeReplayBuffer. Для DTDE — MultiAgentEpisodeReplayBuffer. Дополнительно задайте в replay_buffer_config параметры: replay_sequence_length = 1, replay_burn_in = 0, replay_zero_init_states = True, а сбор данных ведите целыми эпизодами. С этими изменениями обучение идет корректно без правок окружения или API RLModule.
Ниже — обновленная настройка SAC для CTDE:
stable_cfg = (
SACConfig()
.environment(RpsGrouped, env_config={})
.framework("torch")
.rl_module(
rl_module_spec=RLModuleSpec(
module_class=CtdeModule,
observation_space=RpsGrouped.observation_space,
action_space=RpsGrouped.action_space,
)
)
.training(
twin_q=True,
replay_buffer_config={
"type": "EpisodeReplayBuffer",
"replay_sequence_length": 1,
"replay_burn_in": 0,
"replay_zero_init_states": True,
},
)
.env_runners(batch_mode="complete_episodes")
.evaluation(evaluation_config=SACConfig.overrides(exploration=False))
)
Если вы работаете в DTDE, замените EpisodeReplayBuffer на MultiAgentEpisodeReplayBuffer и сохраните те же значения длины последовательности, burn-in и нуловой инициализации состояний.
Почему это важно
Мультиагентные пайплайны в RLlib зависят от корректной склейки эпизодов и выборки из реплея. Когда наблюдения становятся структурированными — Dict из Tuple массивов — некоторые реализации буферов могут оказаться чувствительны к такой форме в сочетании с пользовательской обвязкой модуля. Понимание того, какие буферы устойчивы к сложным наблюдениям при CTCE, CTDE или DTDE, экономит часы отладки и помогает избежать скрытых падений, проявляющихся уже в ходе обучения.
Практический вывод
Если при работе со сложными наблюдениями и кастомным RLModule вы видите утверждения в single_agent_episode.concat_episode или ошибки выборки в приоритезированном буфере, сперва попробуйте заменить PrioritizedEpisodeReplayBuffer. Для CTCE и CTDE подходит EpisodeReplayBuffer, для DTDE — MultiAgentEpisodeReplayBuffer. Задайте replay_sequence_length = 1, replay_burn_in = 0, replay_zero_init_states = True и собирайте батчи как полные эпизоды. Такое сочетание стабилизирует цикл обучения без изменений в логике окружения или модуле CTDE.
Статья основана на вопросе на StackOverflow от Nelson Salazar и ответе от Nelson Salazar.