2025, Dec 14 09:00

How to Feed a 28×28 Hand-Drawn Digit into a PyTorch Conv2D Model: Reshape 784 to (1,28,28) and Decode Output with argmax

Learn the exact tensor shape for PyTorch Conv2D on MNIST: reshape 784 values to (1,28,28), run eval, then pick the label via argmax. Concise inference guide.

Feeding a hand‑drawn 28×28 digit into a PyTorch Conv2D model sounds trivial until you hit tensor shape mismatches and opaque outputs. The good news: you only need to get two things right—how to format the input tensor for Conv2D and how to read the model’s prediction. Below is a compact, end‑to‑end guide that sticks to the essentials without hiding what happens under the hood.

Problem setup

You have a CNN trained on MNIST and you want to run inference from a web app where users draw in black and white. The drawing is currently stored as a flattened array of 784 values containing 0s and 1s. The model architecture is standard for MNIST and expects a single‑channel image.

import torch
import torch.nn as nn
import torch.nn.functional as F

class DigitNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_a = nn.Conv2d(1, 8, 3, stride=1, padding=1)  # 1 input channel, 8 filters
        self.pooler = nn.MaxPool2d(2, stride=2)                # halves height and width
        self.conv_b = nn.Conv2d(8, 8, 3, stride=1, padding=1)
        self.fc_a = nn.Linear(7 * 7 * 8, 128)                  # 8 channels at 7x7 after two pools
        self.fc_b = nn.Linear(128, 128)
        self.fc_c = nn.Linear(128, 10)                         # one logit per digit class

    def forward(self, x):
        x = self.pooler(F.relu(self.conv_a(x)))  # 28x28 -> 14x14
        x = self.pooler(F.relu(self.conv_b(x)))  # 14x14 -> 7x7
        # Flatten the last three dims (channels, height, width) so the model
        # accepts both batched (N, 1, 28, 28) and unbatched (1, 28, 28) inputs;
        # flatten(x, 1) would mis-flatten the unbatched case.
        x = torch.flatten(x, -3)
        x = F.relu(self.fc_a(x))
        x = F.relu(self.fc_b(x))
        x = self.fc_c(x)  # raw logits, no softmax needed for argmax
        return x

net = DigitNet()
net.load_state_dict(torch.load("model_weights.pth", weights_only=True))
net.eval()

What’s actually going wrong

The first layer is Conv2D, which accepts inputs shaped either as (N, C_in, H_in, W_in) or as (C_in, H_in, W_in). Here N is the batch size, C_in is the number of channels, H_in and W_in are height and width. For MNIST, C_in is 1 and the spatial size is 28×28. Since you process one image at a time, you can use the single‑sample form and pass a tensor shaped as (1, 28, 28).
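A quick sanity check with random dummy tensors (a sketch, not real canvas data, reusing the imports and net from the setup above) confirms the first layer accepts both forms:

batched = torch.rand(4, 1, 28, 28)   # (N, C_in, H_in, W_in)
single = torch.rand(1, 28, 28)       # (C_in, H_in, W_in), no batch dimension

print(net.conv_a(batched).shape)     # torch.Size([4, 8, 28, 28])
print(net.conv_a(single).shape)      # torch.Size([8, 28, 28])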

At the output side, the classification head produces a vector of 10 logits, one score per digit class, not the class itself. During inference, you pick the index of the largest logit as the predicted class.
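For instance, with made-up logit values:

logits = torch.tensor([-1.2, 0.3, 4.7, -0.5, 1.1, 0.0, -2.3, 3.9, 0.8, -0.1])
print(logits.argmax().item())  # 2 -> the model predicts the digit 2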

Solution: shape the input and read the output

Start with your flattened array of 784 values and reshape it to match the Conv2D input. Two details matter: the values must be floats (Conv2d weights are float32, so 0/1 integers need a cast), and the result must carry the channel dimension. If height and width appear swapped in your data, you can swap the two spatial dimensions afterward.

def pack_canvas(flat_values):
    # Cast to float32: Conv2d weights are float32 and reject int64 input.
    t = torch.tensor(flat_values, dtype=torch.float32)
    t = t.view(1, 28, 28)  # (C_in, H_in, W_in) for a single-channel image
    # If needed, swap height and width: t = t.transpose(1, 2)
    return t

def pick_label(logits):
    return logits.argmax().item()  # index of the largest logit, as a plain int

# Example end-to-end run from a flattened 0/1 array.
# user_pixels = [0, 0, 1, 1, ..., 1]  # length 784; a Python list or NumPy array both work
sample = pack_canvas(user_pixels)
with torch.no_grad():  # inference only: skip gradient tracking
    output = net(sample)
pred_digit = pick_label(output)
print(pred_digit)
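If you ever need to score several drawings per request, the same helper extends to the batched 4-D form. A sketch, assuming two hypothetical pixel arrays pixels_a and pixels_b:

batch = torch.stack([pack_canvas(p) for p in (pixels_a, pixels_b)])  # (2, 1, 28, 28)
with torch.no_grad():
    logits = net(batch)                # (2, 10), one row of logits per drawing
preds = logits.argmax(dim=1).tolist()  # one predicted digit per drawing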

Why this matters

Convolution expects a specific tensor layout. Feed it a wrongly shaped tensor and you don't get useful logits; you get a runtime error or meaningless numbers. Likewise, a classification head produces a vector of scores, not the class itself. Taking the argmax is the bridge from model output to an actual label.
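To see the failure mode concretely, feed the un-reshaped vector straight into the net (a sketch; the exact error text varies by PyTorch version):

flat = torch.rand(784)  # forgot to reshape
try:
    net(flat)
except RuntimeError as err:
    print("shape error:", err)  # conv2d expects 3D or 4D input, not 1D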

Practical notes

You can pass either a nested Python list or a NumPy array to torch.tensor when constructing the input. If the reshaped 28×28 image appears rotated or transposed relative to how the canvas stores pixels, swap the spatial axes with a transpose between the view and the forward pass.
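Both input types produce the same tensor; a small sketch with dummy data:

import numpy as np

nested_list = [[0.0, 1.0], [1.0, 0.0]]
from_list = torch.tensor(nested_list)
from_numpy = torch.tensor(np.array(nested_list, dtype=np.float32))
print(torch.equal(from_list, from_numpy))  # True

# If the canvas stores pixels column-major, swap H and W after the view:
img = torch.rand(1, 28, 28).transpose(1, 2)  # still (1, 28, 28)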

Wrap‑up

Keep the pipeline simple: reshape the 784‑long vector into a (1, 28, 28) tensor for a single‑channel Conv2D input, run the model in eval mode, then take argmax over the output vector to get the digit. A small amount of shape hygiene on input and a single reduction on output is all it takes to get clean, deterministic inference for MNIST‑style digit recognition.