2026, Jan 04 18:02

Инференс MNIST в PyTorch: подача 28×28 в Conv2D и чтение предсказания

Как из массива 784 значений получить тензор (1, 28, 28) для Conv2D в PyTorch и корректно прочитать logits модели MNIST через argmax. С примерами и пояснениями.

Передать нарисованную от руки цифру 28×28 в модель PyTorch с Conv2D кажется пустяком — пока не столкнёшься с несовпадениями форм тензоров и непрозрачными выходами. Хорошая новость: важно правильно сделать всего две вещи — как оформить входной тензор для Conv2D и как прочитать предсказание модели. Ниже — компактное, сквозное руководство, которое держится за основы и не скрывает, что происходит под капотом.

Постановка задачи

У вас есть CNN, обученная на MNIST, и вы хотите запускать инференс из веб‑приложения, где пользователи рисуют чёрно‑белые изображения. Рисунок сейчас хранится как «сплющенный» массив из 784 значений, содержащих 0 и 1. Архитектура модели типична для MNIST и ожидает одноканальное изображение.

import torch
import torch.nn as nn
import torch.nn.functional as F
class DigitNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv_a = nn.Conv2d(1, 8, 3, stride=1, padding=1)
        self.pooler = nn.MaxPool2d(2, stride=2)
        self.conv_b = nn.Conv2d(8, 8, 3, stride=1, padding=1)
        self.fc_a = nn.Linear(7 * 7 * 8, 128)
        self.fc_b = nn.Linear(128, 128)
        self.fc_c = nn.Linear(128, 10)
    def forward(self, x):
        x = self.pooler(F.relu(self.conv_a(x)))
        x = self.pooler(F.relu(self.conv_b(x)))
        x = torch.flatten(x, 1)
        x = F.relu(self.fc_a(x))
        x = F.relu(self.fc_b(x))
        x = self.fc_c(x)
        return x
net = DigitNet()
net.load_state_dict(torch.load("model_weights.pth", weights_only=True))
net.eval()

Что именно идёт не так

Первый слой — Conv2D, который принимает входы формы либо (N, C_in, H_in, W_in), либо (C_in, H_in, W_in). Здесь N — размер батча, C_in — число каналов, H_in и W_in — высота и ширина. Для MNIST C_in равен 1, а пространственный размер — 28×28. Поскольку вы обрабатываете по одному изображению, можно использовать форму для одного образца и передать тензор формы (1, 28, 28).

На выходе классификатор обучался с one‑hot целями. В режиме инференса берите индекс наибольшего логита как предсказанный класс.

Решение: привести вход и прочитать выход

Начните со своего «сплющенного» массива из 784 значений и преобразуйте его к форме, ожидаемой Conv2D. Если в ваших данных высота и ширина оказались перепутаны, можно поменять местами две пространственные оси после преобразования.

def pack_canvas(flat_values):
    t = torch.tensor(flat_values)
    t = t.view(1, 28, 28)
    # При необходимости поменяйте местами высоту и ширину: t = t.transpose(1, 2)
    return t
def pick_label(logits):
    return logits.argmax()
# Пример сквозного запуска из «сплющенного» массива 0/1:
# user_pixels = [0, 0, 1, 1, ..., 1]  # длина 784
# Входом может быть многомерный массив Python или NumPy, переданный в torch.tensor(...)
# user_pixels — это «сплющенный» массив длины 784
sample = pack_canvas(user_pixels)
output = net(sample)
pred_digit = pick_label(output)
print(pred_digit)

Почему это важно

Свёртка ожидает конкретную раскладку тензора. Если подать тензор неправильной формы, полезных градиентов или логитов вы не получите: либо возникнут ошибки выполнения, либо бессмысленные числа. Точно так же классификационные головы возвращают вектор оценок, а не сам класс. Выбор максимума — мостик от выходов модели к реальной метке.

Практические заметки

В качестве входа в torch.tensor можно передать как обычный многомерный массив Python, так и массив NumPy. Если преобразованное 28×28 выглядит повёрнутым или транспонированным относительно того, как холст хранит пиксели, поменяйте местами пространственные оси с помощью transpose между view и прямым проходом.

Итоги

Держите конвейер простым: преобразуйте вектор длины 784 в тензор формы (1, 28, 28) для одноканального входа Conv2D, запустите модель в режиме eval, затем возьмите argmax по выходному вектору, чтобы получить цифру. Небольшая гигиена форм на входе и одно сокращение на выходе — и вы получаете чистый, детерминированный инференс для распознавания цифр в стиле MNIST.