2025, Nov 09 15:03
Правильное усреднение MSE по батчам и выбор train или eval
Разбираем, как корректно усреднять MSE по батчам в PyTorch: делить на число батчей, а не объектов. Когда считать метрику в train и eval. Без искажений метрик.
Как правильно усреднять MSE по батчам и когда использовать режим train или eval
При обучении регрессионной модели обычно подводят итог по каждой эпохе с помощью среднеквадратичной ошибки (MSE). Есть тонкость, которая легко искажает метрику: делить накопленные потери по батчам на число объектов в датасете или на количество батчей. Ещё один частый вопрос — в каком режиме должна находиться модель при вычислении тренировочной метрики.
Пример проблемы
Ниже приведён цикл обучения, который накапливает MSE из каждого мини-батча, а затем делит на размер датасета. Параллельно считается корреляция для обеих выборок.
def fit_and_validate(net, loader_tr, loader_va, epochs=40, lr_rate=1e-3, wd=1e-5, patience_n=5):
net = net.to(device)
mse_fn = nn.MSELoss()
opt = optim.Adam(net.parameters(), lr=lr_rate, weight_decay=wd)
best_mse_val = float('inf')
stagnation = 0
for ep in range(epochs):
net.train()
agg_train_loss, preds_tr, gts_tr = 0, [], []
for feats, y in loader_tr:
feats, y = feats.to(device), y.to(device)
out = net(feats)
loss_val = mse_fn(out, y)
opt.zero_grad()
loss_val.backward()
opt.step()
agg_train_loss += loss_val.item()
preds_tr.extend(out.detach().cpu().numpy())
gts_tr.extend(y.cpu().numpy())
net.eval()
agg_val_loss, preds_va, gts_va = 0, [], []
with torch.no_grad():
for feats, y in loader_va:
feats, y = feats.to(device), y.to(device)
out = net(feats)
loss_val = mse_fn(out, y)
agg_val_loss += loss_val.item()
preds_va.extend(out.detach().cpu().numpy())
gts_va.extend(y.cpu().numpy())
mse_tr = agg_train_loss / len(train_data)
pc_tr = robust_pearsonr(preds_tr, gts_tr)
mse_va = agg_val_loss / len(val_data)
pc_va = robust_pearsonr(preds_va, gts_va)Что здесь не так и почему
Потеря на батч, которую возвращает nn.MSELoss(), уже является средним по объектам внутри этого батча. В течение эпохи вы суммируете по одному среднему значению на батч. К концу цикла у вас получается сумма средних по батчам, а не сумма покомпонентных квадратов ошибок по всем объектам. Чтобы превратить эту сумму батчевых средних в среднее за эпоху, нормировать нужно на число батчей в эпохе. Это как раз длина соответствующего загрузчика данных.
Исправление
Делите накопленные средние по батчам на количество батчей, то есть на длину loader’а для соответствующей выборки.
def fit_and_validate(net, loader_tr, loader_va, epochs=40, lr_rate=1e-3, wd=1e-5, patience_n=5):
net = net.to(device)
mse_fn = nn.MSELoss()
opt = optim.Adam(net.parameters(), lr=lr_rate, weight_decay=wd)
best_mse_val = float('inf')
stagnation = 0
for ep in range(epochs):
net.train()
agg_train_loss, preds_tr, gts_tr = 0, [], []
for feats, y in loader_tr:
feats, y = feats.to(device), y.to(device)
out = net(feats)
loss_val = mse_fn(out, y)
opt.zero_grad()
loss_val.backward()
opt.step()
agg_train_loss += loss_val.item()
preds_tr.extend(out.detach().cpu().numpy())
gts_tr.extend(y.cpu().numpy())
net.eval()
agg_val_loss, preds_va, gts_va = 0, [], []
with torch.no_grad():
for feats, y in loader_va:
feats, y = feats.to(device), y.to(device)
out = net(feats)
loss_val = mse_fn(out, y)
agg_val_loss += loss_val.item()
preds_va.extend(out.detach().cpu().numpy())
gts_va.extend(y.cpu().numpy())
mse_tr = agg_train_loss / len(loader_tr)
pc_tr = robust_pearsonr(preds_tr, gts_tr)
mse_va = agg_val_loss / len(loader_va)
pc_va = robust_pearsonr(preds_va, gts_va)Какой режим использовать при расчёте train MSE
Если ваша функция потерь — nn.MSELoss, то тренировочная потеря уже и есть train MSE, и вы естественным образом считаете её в режиме обучения. Если же вы хотите считать MSE как отдельную метрику, а не как loss, можно вычислять её в режиме eval.
Почему это важно
Нормирование по размеру датасета при накоплении средних по батчам искажает метрику эпохи и может как приглушать, так и усиливать сигнал. Использование числа батчей согласует среднее за эпоху с тем, что реально агрегировалось. Это делает сравнение train и validation надёжнее и помогает по последовательным цифрам понять, не начинается ли переобучение.
Итог
Когда вы суммируете одну «среднюю на батч» потерю по эпохе, делите на количество батчей, а не на размер датасета. Если loss — это MSE, у вас уже есть train MSE прямо из тренировочного цикла; если считаете её отдельно, делайте это в режиме eval. Соблюдая эти два правила, вы получите метрики, точно отражающие поведение модели.