2025, Dec 24 06:02
Как исправить ошибку Trying to backward through the graph a second time при обучении GAN в PyTorch
Почему в GAN на PyTorch возникает ошибка Trying to backward through the graph a second time при нескольких шагах генератора и как исправить её с detach.
Когда вы настраиваете GAN так, чтобы Генератор обучался несколько раз на один шаг Дискриминатора, легко попасть в ловушки autograd. Часто это выливается в ошибку «Trying to backward through the graph a second time», хотя backward вызывается для функции потерь Дискриминатора лишь один раз. Причина тоньше: на вход Дискриминатору попадает тензор, полученный через вычислительный граф, который уже был освобождён во время backward Генератора.
Постановка задачи
В примере используется бинарная кросс‑энтропия как функция потерь, оптимизатор Adam для обеих моделей и вход 28×28, развёрнутый в 784 признака. Задача — делать n обновлений Генератора на m шагов Дискриминатора, где n > m.
for epoch in range(total_epochs):
for batch_id, (real_batch, _) in enumerate(data_stream):
real_batch = real_batch.view(-1, 784).to(dev)
bsz = real_batch.shape[0]
# Обновления Генератора (n раз)
for _ in range(g_update_factor):
z = torch.randn(bsz, latent_dim).to(dev)
fake_batch = netG(z)
gen_scores = netD(fake_batch).view(-1)
g_loss = bce_loss(gen_scores, torch.ones_like(gen_scores))
g_loss.backward()
optG.step()
netG.zero_grad()
# Обновления Дискриминатора (m раз)
for _ in range(d_update_factor):
real_scores = netD(real_batch).view(-1)
d_real = bce_loss(real_scores, torch.ones_like(real_scores))
fake_scores = netD(fake_batch).view(-1)
d_fake = bce_loss(fake_scores, torch.zeros_like(fake_scores))
d_loss = (d_real + d_fake) * 0.5
d_loss.backward() # здесь возникает ошибка
optD.step()
netD.zero_grad()
Почему это не работает
После завершения обратного прохода Генератора PyTorch освобождает вычислительный граф, связанный с fake_batch. Позже Дискриминатор пытается посчитать свою потерю, используя тот же тензор. Autograd видит тензор, привязанный к уже несуществующему графу, и пытается провести по нему градиенты, что и вызывает ошибку «Trying to backward through the graph a second time».
Проблема — в вызове Дискриминатора на выходе Генератора, который всё ещё требует градиентов, хотя исходный граф уже был «съеден» обновлением Генератора. Иными словами, от Дискриминатора требуют делать backprop по пути, который уже удалён.
Решение
Дискриминатор должен воспринимать выборки Генератора как константы. Отсоедините (detach) синтетический батч перед подачей в Дискриминатор. Тогда autograd не будет пытаться отслеживать градиенты через Генератор на шаге Дискриминатора, и проблема «освобождённого графа» исчезнет. Есть альтернативный путь — вызывать backward у Генератора с retain_graph=True, но это не соответствует задуманной схеме обучения и увеличивает потребление памяти.
for epoch in range(total_epochs):
for batch_id, (real_batch, _) in enumerate(data_stream):
real_batch = real_batch.view(-1, 784).to(dev)
bsz = real_batch.shape[0]
# Обновления Генератора (n раз)
for _ in range(g_update_factor):
z = torch.randn(bsz, latent_dim).to(dev)
fake_batch = netG(z)
gen_scores = netD(fake_batch).view(-1)
g_loss = bce_loss(gen_scores, torch.ones_like(gen_scores))
g_loss.backward()
optG.step()
netG.zero_grad()
# Обновления Дискриминатора (m раз)
for _ in range(d_update_factor):
real_scores = netD(real_batch).view(-1)
d_real = bce_loss(real_scores, torch.ones_like(real_scores))
# detach, чтобы D не пропускал градиенты через граф G
fake_scores = netD(fake_batch.detach()).view(-1)
d_fake = bce_loss(fake_scores, torch.zeros_like(fake_scores))
d_loss = (d_real + d_fake) * 0.5
d_loss.backward()
optD.step()
netD.zero_grad()
Если по какой-то причине вам действительно нужно переиспользовать один и тот же вычислительный граф в нескольких обратных проходах, можно вызвать backward с параметром retain_graph=True на этапе Генератора. Однако в данной схеме обучения цель — не позволять Дискриминатору проходить по графу Генератора, и detach остаётся простым и корректным решением.
Почему это важно
Чередующаяся оптимизация в GAN опирается на чёткую границу между двумя фазами обновлений. Если во время хода Дискриминатора ему «видны» градиенты через Генератор, это и концептуально противоречит целевой постановке, и на практике приводит к ошибкам, как только шаг Генератора освобождает свой граф. Отсоединяя выходы Генератора, вы явно фиксируете расписание обучения и избегаете загадочных ошибок autograd.
Выводы
Если на один шаг Дискриминатора приходится несколько обновлений Генератора, во время обновления Дискриминатора всегда подавайте ему отделённые (detached) выходы Генератора. Используйте retain_graph=True только тогда, когда осознанно переиспользуете граф; в остальных случаях держите фазы разнесёнными — так вы сохраните корректность и контроль над памятью.