2025, Oct 25 21:00

Fixing DQN Failure on Atari Pong: Correcting HWC vs CHW Tensor Layout for PyTorch Conv2d

DQN flatlining on Atari Pong? The issue may be tensor layout. Fix PyTorch Conv2d channel order (HWC to CHW), set in_channels=4, and get learning back on track.

A DQN that aces a toy task yet flatlines on Atari Pong is a classic debugging story: the algorithm is fine, the numbers look sane, the loss goes to zero, yet the score never climbs. In this case the culprit isn’t the optimizer or the epsilon schedule. It’s the tensor layout.

What goes wrong

The observation coming from the Pong stack is shaped as 84 × 84 × 4. That is height, width, then channels. PyTorch’s Conv2d expects channel-first: N × C × H × W. When the input is fed in as-is and the first convolution is built with in_channels equal to 84, the network interprets each image row as a separate channel and treats the four stacked frames as width pixels. Training proceeds on a scrambled view of the world. The target network quickly mirrors the main network’s outputs, the Huber loss approaches zero, and the agent still learns nothing useful about Pong.
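A quick shape check makes the mix-up concrete. The snippet below only inspects shapes and does no training; it assumes the usual 84 × 84 × 4 frame stack described above.

import numpy as np
import torch

obs = np.zeros((84, 84, 4), dtype=np.uint8)  # H, W, C straight from the frame stack
batch = torch.from_numpy(obs).unsqueeze(0)   # shape (1, 84, 84, 4)
# Conv2d reads the last three axes as C, H, W, so it would see 84 channels,
# a height of 84, and a width of 4: image rows become channels and the four
# stacked frames become "width" pixels.
print(batch.shape)  # torch.Size([1, 84, 84, 4])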

Minimal example that reproduces the issue

The core of the problem is visible right at the input. The code below builds the first convolution using the first dimension of the observation tensor and forwards the raw H × W × C tensor straight into Conv2d.

import torch
import torch.nn as nn
import torch.optim as optim


class PixelAgent(nn.Module):
    def linear_fall(self, step):
        # LambdaLR multiplier: decays the learning rate linearly from lr_begin
        # to lr_end over lr_n_steps, then holds it at lr_end.
        if step >= self.cfg.lr_n_steps:
            return self.cfg.lr_end / self.cfg.lr_begin
        return 1.0 - (step / self.cfg.lr_n_steps) * (1 - self.cfg.lr_end / self.cfg.lr_begin)

    def __init__(self, game_env, cfg, dev):
        super().__init__()
        self.env = game_env
        self.cfg = cfg
        self.dev = dev
        # Wrong: uses H instead of channel count
        self.conv_a = nn.Conv2d(
            in_channels=self.env.observation_space.shape[0],  # 84
            out_channels=32,
            kernel_size=8,
            stride=4,
            dtype=torch.float32,
        )
        self.conv_b = nn.Conv2d(32, 64, 4, 2, dtype=torch.float32)
        self.conv_c = nn.Conv2d(64, 64, 3, 1, dtype=torch.float32)
        self.fc_a = nn.Linear(3136, 512, dtype=torch.float32)
        self.fc_b = nn.Linear(512, self.env.action_space.n, dtype=torch.float32)
        self.act = nn.ReLU()
        self.opt = optim.RMSprop(
            self.parameters(),
            lr=self.cfg.lr_begin,
            alpha=self.cfg.squared_gradient_momentum,
            eps=self.cfg.rms_eps,
        )
        self.lr_sched = torch.optim.lr_scheduler.LambdaLR(self.opt, lr_lambda=self.linear_fall)
        self.loss_fn = nn.SmoothL1Loss()

    def forward(self, tensor_in):
        x = self.conv_a(tensor_in)
        x = self.act(x)
        x = self.conv_b(x)
        x = self.act(x)
        x = self.conv_c(x)
        x = self.act(x)
        if len(x.shape) == 3:
            x = torch.flatten(x)
        elif len(x.shape) == 4:
            x = torch.flatten(x, start_dim=1)
        x = self.fc_a(x)
        x = self.act(x)
        x = self.fc_b(x)
        return x

With this setup, monitoring shows decreasing max and average reward and shrinking Q-values while the loss hovers near zero. That pattern is consistent with the target network simply tracking the online network on a scrambled representation, not with genuine policy improvement.

Root cause explained

Conv2d semantics in PyTorch are unambiguous: the channel dimension must be first. A stacked observation from Atari preprocessing is typically height-by-width-by-frames. If you do not permute the axes, the convolution weights are learned against the wrong dimension. The model can still reduce the temporal-difference error because both online and target networks share the same mistake, but the learned features don’t correspond to meaningful spatial or temporal structure in Pong.
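A cheap guard against this class of bug is a one-time startup assertion that the permuted observation’s channel axis matches the first convolution’s in_channels. A minimal, self-contained sketch with placeholder values standing in for the real observation and layer:

import numpy as np
import torch
import torch.nn as nn

# Hypothetical stand-ins: a raw HWC observation and the agent's first conv layer.
obs_np = np.zeros((84, 84, 4), dtype=np.uint8)
first_conv = nn.Conv2d(in_channels=4, out_channels=32, kernel_size=8, stride=4)

chw = torch.from_numpy(obs_np).permute(2, 0, 1)  # (H, W, C) -> (C, H, W)
assert chw.shape[0] == first_conv.in_channels, (
    f"first conv expects {first_conv.in_channels} channels, "
    f"got an observation with {chw.shape[0]}"
)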

The fix

Two changes are required. First, permute the observation to channel-first and scale to [0, 1] before passing it to the network. Second, set the first convolution to expect four channels, not 84.

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim


def to_chw(obs_np: np.ndarray) -> torch.Tensor:
    # Convert a uint8 (H, W, C) frame stack into a float32 (C, H, W) tensor in [0, 1].
    t = torch.from_numpy(obs_np).float().div_(255.0)
    # (H, W, C) -> (C, H, W)
    if t.ndim == 3:
        t = t.permute(2, 0, 1)
    return t


class PixelAgentFixed(nn.Module):
    def linear_fall(self, step):
        if step >= self.cfg.lr_n_steps:
            return self.cfg.lr_end / self.cfg.lr_begin
        return 1.0 - (step / self.cfg.lr_n_steps) * (1 - self.cfg.lr_end / self.cfg.lr_begin)

    def __init__(self, game_env, cfg, dev):
        super().__init__()
        self.env = game_env
        self.cfg = cfg
        self.dev = dev
        # Correct: 4 stacked frames as channels
        self.conv_a = nn.Conv2d(in_channels=4, out_channels=32, kernel_size=8, stride=4)
        self.conv_b = nn.Conv2d(32, 64, 4, 2)
        self.conv_c = nn.Conv2d(64, 64, 3, 1)
        self.fc_a = nn.Linear(3136, 512)
        self.fc_b = nn.Linear(512, self.env.action_space.n)
        self.act = nn.ReLU()
        self.opt = optim.RMSprop(
            self.parameters(),
            lr=self.cfg.lr_begin,
            alpha=self.cfg.squared_gradient_momentum,
            eps=self.cfg.rms_eps,
        )
        self.lr_sched = torch.optim.lr_scheduler.LambdaLR(self.opt, lr_lambda=self.linear_fall)
        self.loss_fn = nn.SmoothL1Loss()

    def forward(self, tensor_in):
        x = self.conv_a(tensor_in)
        x = self.act(x)
        x = self.conv_b(x)
        x = self.act(x)
        x = self.conv_c(x)
        x = self.act(x)
        if len(x.shape) == 3:
            x = torch.flatten(x)
        elif len(x.shape) == 4:
            x = torch.flatten(x, start_dim=1)
        x = self.fc_a(x)
        x = self.act(x)
        x = self.fc_b(x)
        return x


# Usage in the interaction loop
# obs_np is (84, 84, 4)
obs_t = to_chw(obs_np).to(device)  # (4, 84, 84)
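
The same conversion applies when sampling a minibatch from replay: each stored observation should end up as a (4, 84, 84) tensor, stacked into an (N, 4, 84, 84) batch before the forward pass. A sketch, where sampled_obs, agent, and device are placeholder names for the replay sample, the PixelAgentFixed instance, and the torch device:

# sampled_obs: a list of (84, 84, 4) numpy arrays drawn from the replay buffer
batch_t = torch.stack([to_chw(o) for o in sampled_obs]).to(device)  # (N, 4, 84, 84)
q_values = agent(batch_t)                                           # (N, num_actions)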

If you prefer to address it at the environment boundary, you can request channel-first directly by enabling a channel dimension in grayscale and letting the frame stacker produce the correct layout.
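Another way to handle it at the environment boundary is a small observation wrapper that transposes each frame before it ever reaches the agent. The wrapper below is a hypothetical sketch built on gymnasium’s standard ObservationWrapper base class, not code from the original question:

import numpy as np
import gymnasium as gym


class ChannelFirstObs(gym.ObservationWrapper):
    # Transpose (H, W, C) observations to (C, H, W) at the environment boundary.
    def __init__(self, env):
        super().__init__(env)
        h, w, c = env.observation_space.shape
        # Assumes uint8 pixel observations in [0, 255].
        self.observation_space = gym.spaces.Box(
            low=0, high=255, shape=(c, h, w), dtype=np.uint8
        )

    def observation(self, obs):
        return np.transpose(np.asarray(obs), (2, 0, 1))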

Why this matters

Deep Q-Networks rely on convolutions that extract spatial features and stacked frames that encode short-term dynamics. When the channel axis is mis-specified, both of these signals collapse. The TD target still tracks the online network, which is why the SmoothL1 loss looks excellent, but the agent never acquires a policy that scores in Pong. Fixing the axis order restores meaningful gradients and makes a rising learning curve possible again.

Practical guardrails

A few operational choices from the original setup are worth keeping. BatchNorm is counterproductive here because replayed transitions are not i.i.d., so it tends to hurt rather than help. Clipping rewards to [-1, 1] and using the Huber loss are appropriate and should stay. With the corrected channel ordering on Pong, the average return typically starts moving after roughly one million steps and can reach at least 15 after eight to ten million steps with a vanilla DQN.
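Reward clipping in particular is a one-liner applied before a transition is stored. A sketch, assuming a replay buffer with a push method (the name is illustrative, not from the original code):

import numpy as np

def store_transition(replay_buffer, obs_t, action, reward, next_obs_t, done):
    # Clip the raw Atari reward to [-1, 1] before it enters the replay buffer.
    clipped_r = float(np.clip(reward, -1.0, 1.0))
    replay_buffer.push(obs_t, action, clipped_r, next_obs_t, done)  # push() is illustrative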

Takeaway

Before chasing optimizer settings or scheduling tweaks, verify tensor layouts at the model boundary. For Atari inputs that means converting from height-width-channels to channel-height-width and matching the first convolution’s in_channels to the number of stacked frames. Once that is in place, the same algorithm that worked on a toy task can finally see Pong the way it was meant to be seen, and the training curve should rise instead of flatlining.

The article is based on a question from StackOverflow by Rohan Patel and an answer by Dmitry543.