2025, Nov 26 21:01

Почему зависает выход PyTorch Lightning на macOS MPS и как это исправить

На Mac Studio с M4 Max PyTorch Lightning зависает при выходе после обучения CIFAR‑10 на MPS из‑за DataLoader workers. Решение: обновиться до nightly PyTorch.

При обучении базовой модели CIFAR‑10 с PyTorch Lightning на Mac Studio с M4 Max цикл тренировки может успешно завершаться, но процесс не выходит. В консоли видно, что обучение остановлено из‑за достижения max_epochs, однако программа затем зависает на неопределённое время, пока вы не прервёте её вручную. Уменьшение num_workers до нуля избавляет от зависания, но замедляет загрузку данных. Удаление валидационного загрузчика ситуацию не меняет; проблема воспроизводится и при значении num_workers, равном одному.

Минимальный пример, который воспроизводит зависание

Следующий скрипт на Lightning повторяет типичную настройку обучения CIFAR‑10 и после успешного прогона демонстрирует проблему завершения. Базовая логика не менялась; имена сущностей слегка переименованы для наглядности.

import torch
import torch.nn as nn
import torch.nn.functional as F

import lightning as L

from torch.utils.data import DataLoader
from torchvision import datasets, transforms


class TinyCifarNet(L.LightningModule):
    def __init__(self):
        super().__init__()
        self.conv_a = nn.Conv2d(3, 32, 3, padding=1)
        self.conv_b = nn.Conv2d(32, 64, 3, padding=1)
        self.conv_c = nn.Conv2d(64, 64, 3, padding=1)
        self.down = nn.MaxPool2d(2, 2)
        self.dense_a = nn.Linear(64 * 4 * 4, 512)
        self.dense_b = nn.Linear(512, 10)

    def forward(self, data):
        data = self.down(F.relu(self.conv_a(data)))
        data = self.down(F.relu(self.conv_b(data)))
        data = self.down(F.relu(self.conv_c(data)))
        data = data.view(-1, 64 * 4 * 4)
        data = F.relu(self.dense_a(data))
        data = self.dense_b(data)
        return data

    def training_step(self, pack, step_idx):
        inputs, targets = pack
        scores = self(inputs)
        loss = F.cross_entropy(scores, targets)
        top1 = (scores.argmax(1) == targets).float().mean()
        self.log("train_loss", loss)
        self.log("train_acc", top1)
        return loss

    def validation_step(self, pack, step_idx):
        inputs, targets = pack
        scores = self(inputs)
        loss = F.cross_entropy(scores, targets)
        top1 = (scores.argmax(1) == targets).float().mean()
        self.log("val_loss", loss)
        self.log("val_acc", top1)

    def test_step(self, pack, step_idx):
        inputs, targets = pack
        scores = self(inputs)
        loss = F.cross_entropy(scores, targets)
        top1 = (scores.argmax(1) == targets).float().mean()
        self.log("test_loss", loss)
        self.log("test_acc", top1)

    def configure_optimizers(self):
        opt = torch.optim.Adam(self.parameters(), lr=1e-3)
        return opt


if __name__ == "__main__":
    aug_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    aug_eval = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    ds_train = datasets.CIFAR10(root="./data", train=True, download=True, transform=aug_train)
    ds_eval = datasets.CIFAR10(root="./data", train=False, download=True, transform=aug_eval)

    ldr_train = DataLoader(ds_train, batch_size=64, shuffle=True, num_workers=14, persistent_workers=True)
    ldr_eval = DataLoader(ds_eval, batch_size=64, shuffle=False, num_workers=14, persistent_workers=True)

    net = TinyCifarNet()

    runner = L.Trainer(max_epochs=5, accelerator="mps", devices="auto")
    runner.fit(net, ldr_train, ldr_eval)

Что на самом деле происходит

Симптомы выглядят как проблема Lightning, но это не специфично для Lightning. Обычный цикл обучения на чистом PyTorch ведёт себя так же в этой конфигурации: по завершении тренировки процесс не завершает работу корректно. Если опустить num_workers до нуля, скрипт выходит, что сильно намекает на связь с корректным завершением процессов‑воркеров. В окружении, где это наблюдалось, использовались PyTorch 2.7.1, PyTorch Lightning 2.5.1, macOS 15.5, ускорение MPS на Mac Studio с M4 Max; зависание воспроизводилось даже при num_workers, равном одному. Очень базовый учебный скрипт из Lightning может выглядеть незатронутым, из‑за чего поначалу легко запутаться.

Надёжное решение, которое сработало

На практике помогли два подхода. Быстрый выход можно форсировать через os._exit(0), но это просто убивает интерпретатор, обходя обычную очистку ресурсов. Устойчивое решение — обновить PyTorch до nightly‑сборки, которая сняла зависание при завершении в этом окружении.

pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu

На момент написания эта команда ставит дев‑сборку грядущей серии 2.8 и устраняет описанную проблему.

Почему это важно

Тихие зависания при завершении легко пропустить в CI‑пайплайнах, ноутбуках или долгих заданиях на совместном железе. Они зря тратят время GPU или MPS, удерживают файловые дескрипторы и усложняют автоматизацию вокруг обучающих задач. Понимание, что поведение не связано конкретно с Lightning и что обновлённая сборка PyTorch исправляет его, может сберечь часы отладки настроек загрузчиков и колбэков, которые ни при чём.

Практическое резюме

Если обучение в PyTorch Lightning завершается, но процесс на macOS с MPS не выходит, не спешите переписывать цикл тренировки или выкидывать функциональность. Сначала проверьте, воспроизводится ли то же поведение в минимальном скрипте на чистом PyTorch. Если да, попробуйте перейти на свежую nightly‑сборку PyTorch — в описанной ситуации это полностью сняло проблему. Установка num_workers в ноль работает как обходной путь, но замедляет загрузку данных; используйте его лишь как временную меру. Актуальный стек фреймворков — часто самый простой способ избежать тонких багов в завершении многопроцессной работы.