2025, Nov 09 15:03

Правильное усреднение MSE по батчам и выбор train или eval

Разбираем, как корректно усреднять MSE по батчам в PyTorch: делить на число батчей, а не объектов. Когда считать метрику в train и eval. Без искажений метрик.

Как правильно усреднять MSE по батчам и когда использовать режим train или eval

При обучении регрессионной модели обычно подводят итог по каждой эпохе с помощью среднеквадратичной ошибки (MSE). Есть тонкость, которая легко искажает метрику: делить накопленные потери по батчам на число объектов в датасете или на количество батчей. Ещё один частый вопрос — в каком режиме должна находиться модель при вычислении тренировочной метрики.

Пример проблемы

Ниже приведён цикл обучения, который накапливает MSE из каждого мини-батча, а затем делит на размер датасета. Параллельно считается корреляция для обеих выборок.

def fit_and_validate(net, loader_tr, loader_va, epochs=40, lr_rate=1e-3, wd=1e-5, patience_n=5):
    net = net.to(device)
    mse_fn = nn.MSELoss()
    opt = optim.Adam(net.parameters(), lr=lr_rate, weight_decay=wd)
    
    best_mse_val = float('inf')
    stagnation = 0

    for ep in range(epochs):
        net.train()
        agg_train_loss, preds_tr, gts_tr = 0, [], []
        for feats, y in loader_tr:
            feats, y = feats.to(device), y.to(device)
            out = net(feats)
            loss_val = mse_fn(out, y)

            opt.zero_grad()
            loss_val.backward()
            opt.step()

            agg_train_loss += loss_val.item()
            preds_tr.extend(out.detach().cpu().numpy())
            gts_tr.extend(y.cpu().numpy())

        net.eval()
        agg_val_loss, preds_va, gts_va = 0, [], []
        with torch.no_grad():
            for feats, y in loader_va:
                feats, y = feats.to(device), y.to(device)
                out = net(feats)
                loss_val = mse_fn(out, y)

                agg_val_loss += loss_val.item()
                preds_va.extend(out.detach().cpu().numpy())
                gts_va.extend(y.cpu().numpy())

        mse_tr = agg_train_loss / len(train_data)
        pc_tr = robust_pearsonr(preds_tr, gts_tr)
        mse_va = agg_val_loss / len(val_data)
        pc_va = robust_pearsonr(preds_va, gts_va)

Что здесь не так и почему

Потеря на батч, которую возвращает nn.MSELoss(), уже является средним по объектам внутри этого батча. В течение эпохи вы суммируете по одному среднему значению на батч. К концу цикла у вас получается сумма средних по батчам, а не сумма покомпонентных квадратов ошибок по всем объектам. Чтобы превратить эту сумму батчевых средних в среднее за эпоху, нормировать нужно на число батчей в эпохе. Это как раз длина соответствующего загрузчика данных.

Исправление

Делите накопленные средние по батчам на количество батчей, то есть на длину loader’а для соответствующей выборки.

def fit_and_validate(net, loader_tr, loader_va, epochs=40, lr_rate=1e-3, wd=1e-5, patience_n=5):
    net = net.to(device)
    mse_fn = nn.MSELoss()
    opt = optim.Adam(net.parameters(), lr=lr_rate, weight_decay=wd)
    
    best_mse_val = float('inf')
    stagnation = 0

    for ep in range(epochs):
        net.train()
        agg_train_loss, preds_tr, gts_tr = 0, [], []
        for feats, y in loader_tr:
            feats, y = feats.to(device), y.to(device)
            out = net(feats)
            loss_val = mse_fn(out, y)

            opt.zero_grad()
            loss_val.backward()
            opt.step()

            agg_train_loss += loss_val.item()
            preds_tr.extend(out.detach().cpu().numpy())
            gts_tr.extend(y.cpu().numpy())

        net.eval()
        agg_val_loss, preds_va, gts_va = 0, [], []
        with torch.no_grad():
            for feats, y in loader_va:
                feats, y = feats.to(device), y.to(device)
                out = net(feats)
                loss_val = mse_fn(out, y)

                agg_val_loss += loss_val.item()
                preds_va.extend(out.detach().cpu().numpy())
                gts_va.extend(y.cpu().numpy())

        mse_tr = agg_train_loss / len(loader_tr)
        pc_tr = robust_pearsonr(preds_tr, gts_tr)
        mse_va = agg_val_loss / len(loader_va)
        pc_va = robust_pearsonr(preds_va, gts_va)

Какой режим использовать при расчёте train MSE

Если ваша функция потерь — nn.MSELoss, то тренировочная потеря уже и есть train MSE, и вы естественным образом считаете её в режиме обучения. Если же вы хотите считать MSE как отдельную метрику, а не как loss, можно вычислять её в режиме eval.

Почему это важно

Нормирование по размеру датасета при накоплении средних по батчам искажает метрику эпохи и может как приглушать, так и усиливать сигнал. Использование числа батчей согласует среднее за эпоху с тем, что реально агрегировалось. Это делает сравнение train и validation надёжнее и помогает по последовательным цифрам понять, не начинается ли переобучение.

Итог

Когда вы суммируете одну «среднюю на батч» потерю по эпохе, делите на количество батчей, а не на размер датасета. Если loss — это MSE, у вас уже есть train MSE прямо из тренировочного цикла; если считаете её отдельно, делайте это в режиме eval. Соблюдая эти два правила, вы получите метрики, точно отражающие поведение модели.

Статья основана на вопросе на StackOverflow от mansi и ответе от nicod.