2025, Dec 06 15:00

How to Fix 'Trying to backward through the graph a second time' in PyTorch GAN Training by Detaching Generator Outputs

Fix PyTorch GAN autograd error 'Trying to backward through the graph a second time' by detaching Generator outputs before Discriminator updates. Train safely.

When you push a GAN to train the Generator multiple times per Discriminator update, it’s easy to run into autograd pitfalls. A common failure shows up as “Trying to backward through the graph a second time,” even though you only call backward once on the Discriminator’s loss. The root cause is subtler: what you feed to the Discriminator was produced through a graph that has already been freed by the Generator’s backward pass.

Problem setup

The setup uses Binary Cross Entropy for the loss, Adam optimizers for both models, and 28×28 inputs flattened to 784 features. The goal is to update the Generator n times per m updates of the Discriminator, with n > m.
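
The model definitions are not part of the original snippet, so the sketch below only pins down the names the loop relies on; the layer sizes, learning rates, and update factors are placeholder assumptions.

import torch
import torch.nn as nn

dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
latent_dim = 100       # assumed noise dimension
g_update_factor = 3    # n: Generator updates per batch (assumed)
d_update_factor = 1    # m: Discriminator updates per batch (assumed), n > m

# Placeholder MLPs; the original architectures are not shown.
netG = nn.Sequential(
    nn.Linear(latent_dim, 256), nn.ReLU(),
    nn.Linear(256, 784), nn.Tanh(),
).to(dev)
netD = nn.Sequential(
    nn.Linear(784, 256), nn.LeakyReLU(0.2),
    nn.Linear(256, 1), nn.Sigmoid(),
).to(dev)

bce_loss = nn.BCELoss()
optG = torch.optim.Adam(netG.parameters(), lr=2e-4)
optD = torch.optim.Adam(netD.parameters(), lr=2e-4)
# total_epochs and data_stream (a DataLoader over 28x28 images) are assumed to exist.

The training loop that reproduces the error: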

for epoch in range(total_epochs):
    for batch_id, (real_batch, _) in enumerate(data_stream):
        real_batch = real_batch.view(-1, 784).to(dev)
        bsz = real_batch.shape[0]

        # Generator updates (n times)
        for _ in range(g_update_factor):
            z = torch.randn(bsz, latent_dim).to(dev)
            fake_batch = netG(z)
            gen_scores = netD(fake_batch).view(-1)
            g_loss = bce_loss(gen_scores, torch.ones_like(gen_scores))
            g_loss.backward()
            optG.step()
            netG.zero_grad()

        # Discriminator updates (m times)
        for _ in range(d_update_factor):
            real_scores = netD(real_batch).view(-1)
            d_real = bce_loss(real_scores, torch.ones_like(real_scores))
            fake_scores = netD(fake_batch).view(-1)
            d_fake = bce_loss(fake_scores, torch.zeros_like(fake_scores))
            d_loss = (d_real + d_fake) * 0.5
            d_loss.backward()  # error triggers here
            optD.step()
            netD.zero_grad()

Why this fails

After the Generator’s backward pass completes, PyTorch frees the intermediate buffers of the computation graph that produced fake_batch. The Discriminator then builds its fake loss from that same tensor, and when d_loss.backward() runs, autograd follows fake_batch back into the Generator’s graph, finds its saved buffers gone, and raises “Trying to backward through the graph a second time.”

The offending tensor is fake_batch: it still requires grad, so the Discriminator’s loss stays connected to the Generator’s graph even though that graph was already consumed by the Generator update. In other words, d_loss.backward() is asked to backpropagate through a path that has been discarded.
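
The same failure can be reproduced in a few lines, independent of the GAN code; the tensors below are just stand-ins for the Generator’s and Discriminator’s parameters.

import torch

x = torch.randn(3, requires_grad=True)   # stands in for the Generator's parameters
w = torch.randn(3, requires_grad=True)   # stands in for the Discriminator's parameters

h = x.exp()          # stands in for fake_batch = netG(z); exp saves its output for backward
h.sum().backward()   # stands in for g_loss.backward(); frees the saved buffers behind h

bad = (w * h).sum()  # still connected to h's (now freed) graph
# bad.backward()     # RuntimeError: Trying to backward through the graph a second time

good = (w * h.detach()).sum()  # same role as netD(fake_batch.detach())
good.backward()                # works: gradients flow into w only

In the GAN loop, w’s role is played by the Discriminator’s parameters, which is why detaching the fake batch still leaves the Discriminator something to learn from.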

The fix

The Discriminator should treat the Generator’s samples as constants. Detach the synthetic batch before passing it to the Discriminator: this stops autograd from tracking gradients back through the Generator during the Discriminator step and sidesteps the freed-graph issue. The alternative of retaining the graph on the Generator’s backward pass exists, but it isn’t the intended training flow here and increases memory usage.
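
For clarity on what detach() does here: it returns a tensor that shares the same data but carries no grad_fn, so a forward pass built on it cannot reach back into the Generator’s graph. The tensor below is just a stand-in for netG’s output.

import torch

out = torch.randn(16, 784, requires_grad=True) * 2  # non-leaf stand-in for netG(z); has a grad_fn
print(out.requires_grad, out.grad_fn is not None)   # True True
frozen = out.detach()                               # same data, cut from the graph
print(frozen.requires_grad, frozen.grad_fn)         # False None

With that in place, the corrected loop is identical to the original except for the single detach call: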

for epoch in range(total_epochs):
    for batch_id, (real_batch, _) in enumerate(data_stream):
        real_batch = real_batch.view(-1, 784).to(dev)
        bsz = real_batch.shape[0]

        # Generator updates (n times)
        for _ in range(g_update_factor):
            z = torch.randn(bsz, latent_dim).to(dev)
            fake_batch = netG(z)
            gen_scores = netD(fake_batch).view(-1)
            g_loss = bce_loss(gen_scores, torch.ones_like(gen_scores))
            g_loss.backward()
            optG.step()
            netG.zero_grad()

        # Discriminator updates (m times)
        for _ in range(d_update_factor):
            real_scores = netD(real_batch).view(-1)
            d_real = bce_loss(real_scores, torch.ones_like(real_scores))
            # detach so D does not backprop through G's graph
            fake_scores = netD(fake_batch.detach()).view(-1)
            d_fake = bce_loss(fake_scores, torch.zeros_like(fake_scores))
            d_loss = (d_real + d_fake) * 0.5
            d_loss.backward()
            optD.step()
            netD.zero_grad()

If you truly need to reuse the same computation graph across multiple backward passes for some reason, you can call backward with retain_graph=True in the Generator phase. However, in this training regime the goal is simply to prevent the Discriminator from traversing the Generator’s graph, and detaching is the straightforward, correct approach.
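
As a minimal illustration of what retain_graph=True buys you (a toy example, not the GAN loop): the first backward keeps the graph’s saved buffers alive, so a second backward through the same graph succeeds and gradients accumulate.

import torch

x = torch.randn(3, requires_grad=True)
loss = x.exp().sum()
loss.backward(retain_graph=True)  # keep the saved buffers so the graph can be reused
loss.backward()                   # second pass is now allowed; grads accumulate in x.grad

In the GAN schedule above, this would also keep the Generator’s activations in memory across iterations, which is exactly the cost the detach-based fix avoids.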

Why this detail matters

Alternating optimization in GANs relies on a clear boundary between the two update phases. Letting the Discriminator see gradients through the Generator during its turn is both conceptually wrong for the objective and error-prone in practice once the Generator step frees its graph. Detaching the Generator’s outputs makes the training schedule explicit and avoids cryptic autograd errors.

Takeaways

When performing multiple Generator updates per Discriminator step, always pass detached Generator outputs to the Discriminator during its update. Reserve retain_graph=True for cases where you intentionally need to reuse a graph; otherwise, keep the phases decoupled to maintain correctness and control memory usage.