2025, Oct 31 06:17

DQN на Atari Pong: как ошибка порядка каналов в PyTorch убивает обучение

Разбираем, почему DQN на Atari Pong «встает» при верном лоссе: Conv2d в PyTorch требует channel-first. Показываем баг с HWC и исправление CHW для стабильного обучения

Почему DQN, которая блестяще справляется с игрушечной задачей, на Atari Pong «выходит на плато», — классическая история отладки: с алгоритмом все в порядке, числа выглядят разумно, лосс уходит к нулю, а счет так и не растет. Виноват здесь не оптимизатор и не расписание эпсилона. Причина — в раскладке тензора.

Что идет не так

Наблюдение из стека Pong имеет форму 84 × 84 × 4. То есть высота, ширина, затем каналы. Conv2d в PyTorch ожидает порядок channel-first: N × C × H × W. Когда вход подается как есть, а первый сверточный слой создается с in_channels, равным 84, сеть интерпретирует каждую строку изображения как отдельный канал, а четыре сложенных кадра — как пиксели ширины. Обучение идет на перепутанном представлении мира. Таргет-сеть быстро зеркалит выходы основной сети, Huber loss стремится к нулю, и при этом агент ничего полезного о Pong не осваивает.

Минимальный пример, воспроизводящий проблему

Суть проблемы видна уже на входе. Ниже код, который строит первый сверточный слой по первой размерности тензора наблюдения и без преобразований передает сырой тензор H × W × C прямо в Conv2d.

class PixelAgent(nn.Module):
    def linear_fall(self, step):
        if step >= self.cfg.lr_n_steps:
            return self.cfg.lr_end / self.cfg.lr_begin
        return 1.0 - (step / self.cfg.lr_n_steps) * (1 - self.cfg.lr_end / self.cfg.lr_begin)
    def __init__(self, game_env, cfg, dev):
        super().__init__()
        self.env = game_env
        self.cfg = cfg
        self.dev = dev
        # Неверно: используется H вместо количества каналов
        self.conv_a = nn.Conv2d(
            in_channels=self.env.observation_space.shape[0],  # 84
            out_channels=32,
            kernel_size=8,
            stride=4,
            dtype=torch.float32,
        )
        self.conv_b = nn.Conv2d(32, 64, 4, 2, dtype=torch.float32)
        self.conv_c = nn.Conv2d(64, 64, 3, 1, dtype=torch.float32)
        self.fc_a = nn.Linear(3136, 512, dtype=torch.float32)
        self.fc_b = nn.Linear(512, self.env.action_space.n, dtype=torch.float32)
        self.act = nn.ReLU()
        self.opt = optim.RMSprop(
            self.parameters(),
            lr=self.cfg.lr_begin,
            alpha=self.cfg.squared_gradient_momentum,
            eps=self.cfg.rms_eps,
        )
        self.lr_sched = torch.optim.lr_scheduler.LambdaLR(self.opt, lr_lambda=self.linear_fall)
        self.loss_fn = nn.SmoothL1Loss()
    def forward(self, tensor_in):
        x = self.conv_a(tensor_in)
        x = self.act(x)
        x = self.conv_b(x)
        x = self.act(x)
        x = self.conv_c(x)
        x = self.act(x)
        if len(x.shape) == 3:
            x = torch.flatten(x)
        elif len(x.shape) == 4:
            x = torch.flatten(x, start_dim=1)
        x = self.fc_a(x)
        x = self.act(x)
        x = self.fc_b(x)
        return x

В таком варианте мониторинг показывает падающие максимум/среднее вознаграждение и уменьшающиеся Q-значения, в то время как лосс держится около нуля. Это согласуется с тем, что таргет-сеть подгоняется под некорректное представление, а не под устойчивое улучшение политики.

Разбор причины

Семантика Conv2d в PyTorch однозначна: размерность каналов должна идти первой. Стековое наблюдение после препроцессинга Atari обычно имеет вид «высота × ширина × кадры». Если не переставить оси, веса сверток обучаются по неверной размерности. Модель все еще способна уменьшать ошибку временной разницы, потому что онлайн- и таргет-сети повторяют одну и ту же ошибку, но выученные признаки не соответствуют осмысленной пространственной или временной структуре Pong.

Исправление

Нужны две правки. Во-первых, перед подачей в сеть переставьте наблюдение в формат channel-first и масштабируйте значения в [0, 1]. Во-вторых, задайте первому сверточному слою четыре канала, а не 84.

def to_chw(obs_np: np.ndarray) -> torch.Tensor:
    t = torch.from_numpy(obs_np).float().div_(255.0)
    # (H, W, C) -> (C, H, W)
    if t.ndim == 3:
        t = t.permute(2, 0, 1)
    return t
class PixelAgentFixed(nn.Module):
    def linear_fall(self, step):
        if step >= self.cfg.lr_n_steps:
            return self.cfg.lr_end / self.cfg.lr_begin
        return 1.0 - (step / self.cfg.lr_n_steps) * (1 - self.cfg.lr_end / self.cfg.lr_begin)
    def __init__(self, game_env, cfg, dev):
        super().__init__()
        self.env = game_env
        self.cfg = cfg
        self.dev = dev
        # Верно: 4 сложенных кадра — это каналы
        self.conv_a = nn.Conv2d(in_channels=4, out_channels=32, kernel_size=8, stride=4)
        self.conv_b = nn.Conv2d(32, 64, 4, 2)
        self.conv_c = nn.Conv2d(64, 64, 3, 1)
        self.fc_a = nn.Linear(3136, 512)
        self.fc_b = nn.Linear(512, self.env.action_space.n)
        self.act = nn.ReLU()
        self.opt = optim.RMSprop(
            self.parameters(),
            lr=self.cfg.lr_begin,
            alpha=self.cfg.squared_gradient_momentum,
            eps=self.cfg.rms_eps,
        )
        self.lr_sched = torch.optim.lr_scheduler.LambdaLR(self.opt, lr_lambda=self.linear_fall)
        self.loss_fn = nn.SmoothL1Loss()
    def forward(self, tensor_in):
        x = self.conv_a(tensor_in)
        x = self.act(x)
        x = self.conv_b(x)
        x = self.act(x)
        x = self.conv_c(x)
        x = self.act(x)
        if len(x.shape) == 3:
            x = torch.flatten(x)
        elif len(x.shape) == 4:
            x = torch.flatten(x, start_dim=1)
        x = self.fc_a(x)
        x = self.act(x)
        x = self.fc_b(x)
        return x
# Использование в цикле взаимодействия
# obs_np имеет форму (84, 84, 4)
obs_t = to_chw(obs_np).to(device)  # (4, 84, 84)

Если удобнее решить это на границе окружения, можно сразу запросить channel-first: включите размерность канала для градаций серого, и стекер кадров сформирует правильную раскладку.

Почему это важно

Deep Q-Networks опираются на свертки, извлекающие пространственные признаки, и на стек кадров, передающий короткую динамику. При неверно заданной оси каналов оба сигнала рушатся. TD-таргет все равно следует за онлайн-сетью — поэтому SmoothL1 выглядит отлично — но политика, которая набирала бы очки в Pong, не появляется. Исправление порядка осей возвращает осмысленные градиенты и снова делает возможным рост кривых обучения.

Практические рекомендации

Есть несколько рабочих настроек, согласованных с оригинальным подходом. BatchNorm здесь скорее мешает, потому что переходы из буфера воспроизведения не i.i.d., и нормализация ухудшает обучение. Обрезка вознаграждений до ±1 и Huber loss уместны и их стоит оставить. С корректным порядком каналов на Pong средняя отдача обычно начинает расти примерно после миллиона шагов и может достигать не менее 15 после восьми–десяти миллионов на обычном DQN.

Вывод

Прежде чем крутить параметры оптимизатора или график эпсилона, проверьте раскладку тензоров на границе модели. Для входов Atari это означает перевод из height–width–channels в channel–height–width и соответствие in_channels первого свертечного слоя числу сложенных кадров. После этого тот же алгоритм, который работал на игрушечной задаче, наконец «увидит» Pong как нужно, и кривая обучения пойдет вверх, а не будет лежать на нуле.

Статья основана на вопросе с StackOverflow от Rohan Patel и ответе от Dmitry543.