2026, Jan 04 18:02
Инференс MNIST в PyTorch: подача 28×28 в Conv2D и чтение предсказания
Как из массива 784 значений получить тензор (1, 28, 28) для Conv2D в PyTorch и корректно прочитать logits модели MNIST через argmax. С примерами и пояснениями.
Передать нарисованную от руки цифру 28×28 в модель PyTorch с Conv2D кажется пустяком — пока не столкнёшься с несовпадениями форм тензоров и непрозрачными выходами. Хорошая новость: важно правильно сделать всего две вещи — как оформить входной тензор для Conv2D и как прочитать предсказание модели. Ниже — компактное, сквозное руководство, которое держится за основы и не скрывает, что происходит под капотом.
Постановка задачи
У вас есть CNN, обученная на MNIST, и вы хотите запускать инференс из веб‑приложения, где пользователи рисуют чёрно‑белые изображения. Рисунок сейчас хранится как «сплющенный» массив из 784 значений, содержащих 0 и 1. Архитектура модели типична для MNIST и ожидает одноканальное изображение.
import torch
import torch.nn as nn
import torch.nn.functional as F
class DigitNet(nn.Module):
def __init__(self):
super().__init__()
self.conv_a = nn.Conv2d(1, 8, 3, stride=1, padding=1)
self.pooler = nn.MaxPool2d(2, stride=2)
self.conv_b = nn.Conv2d(8, 8, 3, stride=1, padding=1)
self.fc_a = nn.Linear(7 * 7 * 8, 128)
self.fc_b = nn.Linear(128, 128)
self.fc_c = nn.Linear(128, 10)
def forward(self, x):
x = self.pooler(F.relu(self.conv_a(x)))
x = self.pooler(F.relu(self.conv_b(x)))
x = torch.flatten(x, 1)
x = F.relu(self.fc_a(x))
x = F.relu(self.fc_b(x))
x = self.fc_c(x)
return x
net = DigitNet()
net.load_state_dict(torch.load("model_weights.pth", weights_only=True))
net.eval()
Что именно идёт не так
Первый слой — Conv2D, который принимает входы формы либо (N, C_in, H_in, W_in), либо (C_in, H_in, W_in). Здесь N — размер батча, C_in — число каналов, H_in и W_in — высота и ширина. Для MNIST C_in равен 1, а пространственный размер — 28×28. Поскольку вы обрабатываете по одному изображению, можно использовать форму для одного образца и передать тензор формы (1, 28, 28).
На выходе классификатор обучался с one‑hot целями. В режиме инференса берите индекс наибольшего логита как предсказанный класс.
Решение: привести вход и прочитать выход
Начните со своего «сплющенного» массива из 784 значений и преобразуйте его к форме, ожидаемой Conv2D. Если в ваших данных высота и ширина оказались перепутаны, можно поменять местами две пространственные оси после преобразования.
def pack_canvas(flat_values):
t = torch.tensor(flat_values)
t = t.view(1, 28, 28)
# При необходимости поменяйте местами высоту и ширину: t = t.transpose(1, 2)
return t
def pick_label(logits):
return logits.argmax()
# Пример сквозного запуска из «сплющенного» массива 0/1:
# user_pixels = [0, 0, 1, 1, ..., 1] # длина 784
# Входом может быть многомерный массив Python или NumPy, переданный в torch.tensor(...)
# user_pixels — это «сплющенный» массив длины 784
sample = pack_canvas(user_pixels)
output = net(sample)
pred_digit = pick_label(output)
print(pred_digit)
Почему это важно
Свёртка ожидает конкретную раскладку тензора. Если подать тензор неправильной формы, полезных градиентов или логитов вы не получите: либо возникнут ошибки выполнения, либо бессмысленные числа. Точно так же классификационные головы возвращают вектор оценок, а не сам класс. Выбор максимума — мостик от выходов модели к реальной метке.
Практические заметки
В качестве входа в torch.tensor можно передать как обычный многомерный массив Python, так и массив NumPy. Если преобразованное 28×28 выглядит повёрнутым или транспонированным относительно того, как холст хранит пиксели, поменяйте местами пространственные оси с помощью transpose между view и прямым проходом.
Итоги
Держите конвейер простым: преобразуйте вектор длины 784 в тензор формы (1, 28, 28) для одноканального входа Conv2D, запустите модель в режиме eval, затем возьмите argmax по выходному вектору, чтобы получить цифру. Небольшая гигиена форм на входе и одно сокращение на выходе — и вы получаете чистый, детерминированный инференс для распознавания цифр в стиле MNIST.