2025, Oct 04 01:00

Estimating Neuron Importance with Pre-Activation Gradients in PyTorch: Backward Hooks, Validation, and Pruning Insights

Learn how to compute neuron importance via pre-activation gradients in PyTorch using backward hooks, verify with a toy model, and understand limits for pruning.

When estimating neuron importance in classification models, a common baseline is the partial derivative of the correct-class logit with respect to a neuron’s pre-activation. For importance at the cluster level, these derivatives can be averaged across the samples in a cluster. A practical problem is that the resulting gradients can look tiny, and pruning the “top neurons” ranked by these scores may do no better than random pruning. That can look like a bug in the gradient extraction, especially when a simpler heuristic, such as ranking neurons by their pre-activation values, seems to work better.

Repro of the gradient-based attribution used for neuron pre-activations

The following snippet computes per-neuron importance for a given layer by collecting the partial derivatives of the target logit with respect to that layer’s pre-activations, then averaging them across samples. It assumes the hooked layer is a PyTorch Linear layer followed by an activation, so the last argument of the backward hook (grad_output) holds the gradient with respect to the layer’s output, which is the pre-activation.

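from abc import ABC, abstractmethod
from typing import Dict
import numpy as np
import torch
import torch.nn as nn
from torch.nn import Module
from torch.utils.data import DataLoader
class InfluenceBase(ABC):
    def __init__(self, net: nn.Module, loaders: Dict[int, DataLoader]):
        self.net = net
        self.loaders = loaders  # maps a cluster/group id to a DataLoader over its samples
        if self.net is not None:
            self.net.eval()
    @abstractmethod
    def score(self, sublayer: Module, group_id: int) -> np.ndarray:
        pass
class GradInfluence(InfluenceBase):
    def score(self, sublayer: Module, group_id: int) -> np.ndarray:
        device = next(self.net.parameters()).device
        # Note: only the first batch of the group's loader is used
        batch_x, batch_y = next(iter(self.loaders[group_id]))
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)
        batch_x.requires_grad = True
        traces = []
        def tap(mod, gin, gout):
            # gout[0] = d(target logit) / d(sublayer output), shape (1, out_features)
            traces.append(gout[0].detach().cpu())
        h = sublayer.register_full_backward_hook(tap)
        for j in range(len(batch_x)):
            xj = batch_x[j].unsqueeze(0)
            logits = self.net(xj).squeeze()  # shape (num_classes,)
            cls = batch_y[j].item()
            logits[cls].backward()  # one backward per sample; the hook fires once each time
        h.remove()
        g_stack = torch.cat(traces, dim=0)  # (num_samples, out_features)
        g_mean = g_stack.mean(dim=0).numpy()  # signed mean per neuron
        return g_mean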
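The assumption that the hook’s gout[0] is the derivative with respect to the Linear layer’s output (the pre-activation) can be checked directly. The sketch below uses a made-up two-layer model, with arbitrary layer sizes and target class, and compares the hook’s value with torch.autograd.grad applied to an explicitly kept pre-activation tensor.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lin = nn.Linear(4, 6)    # hooked layer: its output is the pre-activation
head = nn.Linear(6, 3)   # toy classification head
x = torch.randn(1, 4, requires_grad=True)
cls = 2                  # arbitrary target class for the check

# (1) Hook-based, as in GradInfluence.score
captured = []
h = lin.register_full_backward_hook(lambda mod, gin, gout: captured.append(gout[0].detach()))
head(torch.relu(lin(x)))[0, cls].backward()
h.remove()
hook_grad = captured[0]

# (2) Direct: keep the pre-activation tensor and differentiate the logit w.r.t. it
pre = lin(x)
(direct_grad,) = torch.autograd.grad(head(torch.relu(pre))[0, cls], pre)

print(torch.allclose(hook_grad, direct_grad))  # True
```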

What this code actually computes and why

The hook is registered on a Linear layer. During backpropagation, its gout argument (grad_output) holds the gradient with respect to the layer’s output, which, because the activation is applied afterwards, is the pre-activation. The loop runs one forward and one backward pass per sample, taking the scalar logit of the ground-truth class and differentiating it with respect to the hooked layer’s pre-activations. The per-sample gradients are collected, concatenated, and averaged per neuron to give the cluster-level importance; note that only the first batch drawn from the cluster’s loader is used, so the average covers that batch rather than the whole cluster. There is no need to zero parameter gradients inside the loop, because the attribution is read from the hook, whose grad_output is computed fresh on every backward pass, not from the accumulated .grad attributes of the parameters.

In other words, if your goal is “partial derivative of the correct class logit with respect to neuron pre-activations,” this code performs the intended computation for a Linear layer.
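In eval mode, and assuming no cross-sample coupling such as BatchNorm in training mode, each sample’s logits depend only on that sample. The per-sample loop can then be replaced by a single backward pass on the sum of the target logits: row j of the hooked gradient is sample j’s own derivative, so the per-neuron mean is unchanged. A minimal sketch under that assumption; batched_grad_scores is a hypothetical helper, not part of the original code.

```python
import numpy as np
import torch
import torch.nn as nn


def batched_grad_scores(net: nn.Module, sublayer: nn.Module,
                        xs: torch.Tensor, ys: torch.Tensor) -> np.ndarray:
    """Per-neuron mean of d(target logit)/d(sublayer output) using one backward pass.

    Assumes samples do not interact in the forward pass (eval mode, no
    BatchNorm in training mode), so row j of the hooked gradient depends
    only on sample j's own target logit.
    """
    device = next(net.parameters()).device
    xs = xs.to(device).clone().requires_grad_(True)
    ys = ys.to(device).long()
    captured = []

    def tap(mod, gin, gout):
        # gout[0] has shape (batch, out_features): one gradient row per sample
        captured.append(gout[0].detach().cpu())

    h = sublayer.register_full_backward_hook(tap)
    try:
        logits = net(xs)                                    # (batch, num_classes)
        logits.gather(1, ys.view(-1, 1)).sum().backward()   # sum of each sample's target logit
    finally:
        h.remove()
    return captured[0].mean(dim=0).numpy()
```

The equivalence relies on eval mode, which InfluenceBase.__init__ already sets.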

Sanity-check with a controlled toy model

A reliable way to verify that the approach behaves sensibly is to construct a simple model and dataset where the “right answer” is obvious. Consider a network with three separate branches, each reading exactly one input feature, followed by a shared classification head. The label depends only on the second feature, so the neurons in the middle branch (branch1) should rank as most important under the gradient metric.

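import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
class MiniNet(nn.Module):
    def __init__(self, width=4):
        super().__init__()
        # Three independent single-feature branches: Linear -> ReLU
        self.branch0 = nn.Sequential(nn.Linear(1, width), nn.ReLU())
        self.branch1 = nn.Sequential(nn.Linear(1, width), nn.ReLU())
        self.branch2 = nn.Sequential(nn.Linear(1, width), nn.ReLU())
        # Shared two-class head over the concatenated branch outputs
        self.head = nn.Linear(3 * width, 2)
    def forward(self, z):
        # z[:, [k]] keeps shape (batch, 1) for each single-feature branch
        z0 = z[:, [0]]
        z1 = z[:, [1]]
        z2 = z[:, [2]]
        u0 = self.branch0(z0)
        u1 = self.branch1(z1)
        u2 = self.branch2(z2)
        u = torch.cat([u0, u1, u2], dim=1)
        return self.head(u)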

Now build synthetic data so that the label is determined by the sign of the second feature.

M = 200
feats = np.random.randn(M, 3).astype(np.float32)
labels = (feats[:, 1] > 0).astype(np.int64)
ds = TensorDataset(torch.from_numpy(feats), torch.from_numpy(labels))
dl = DataLoader(ds, batch_size=16, shuffle=True)
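Because the label is the sign of the second feature, the informative feature can be confirmed before any training. A quick check is the absolute correlation of each feature with the label; only the middle entry should be large. The sketch uses a seeded NumPy generator, an assumption for reproducibility (the snippet above uses the unseeded global np.random.randn), and feature_label_correlation is a hypothetical helper.

```python
import numpy as np


def feature_label_correlation(feats: np.ndarray, labels: np.ndarray) -> np.ndarray:
    """Absolute Pearson correlation between each feature column and the label."""
    y = labels.astype(np.float64)
    return np.array([abs(np.corrcoef(feats[:, k].astype(np.float64), y)[0, 1])
                     for k in range(feats.shape[1])])


rng = np.random.default_rng(0)  # seeded generator (assumption, for reproducibility)
feats = rng.standard_normal((200, 3)).astype(np.float32)
labels = (feats[:, 1] > 0).astype(np.int64)
corr = feature_label_correlation(feats, labels)
print(corr.round(2))  # only the middle entry should be large
```

For a standard normal feature X and label 1[X > 0], the population correlation is √(2/π) ≈ 0.80, while independent features have population correlation 0.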

Train the model briefly.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = MiniNet().to(device)
optimizer = torch.optim.Adam(net.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()
for ep in range(10):
    for xb, yb in dl:
        xb, yb = xb.to(device), yb.to(device)
        optimizer.zero_grad()
        pred = net(xb)
        loss = criterion(pred, yb)
        loss.backward()
        optimizer.step()
    if (ep + 1) % 5 == 0:
        # loss, pred and yb come from the epoch's last mini-batch only
        acc = (pred.argmax(1) == yb).float().mean().item()
        print(f"Epoch {ep+1}, Loss={loss.item():.4f}, Acc={acc:.3f}")
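In the training loop, pred, yb and loss are left over from the last mini-batch of the epoch (8 samples when M = 200 and batch_size = 16), so the printed accuracy is a noisy estimate. A hedged sketch of an evaluation pass over the whole loader follows; dataset_accuracy is a hypothetical helper, and it temporarily switches the model to eval mode.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader


@torch.no_grad()
def dataset_accuracy(net: nn.Module, loader: DataLoader) -> float:
    """Accuracy over every sample the loader yields, not just the last mini-batch."""
    device = next(net.parameters()).device
    was_training = net.training
    net.eval()
    correct, total = 0, 0
    for xb, yb in loader:
        xb, yb = xb.to(device), yb.to(device)
        correct += (net(xb).argmax(1) == yb).sum().item()
        total += yb.numel()
    net.train(was_training)  # restore the previous mode
    return correct / total
```

After training, dataset_accuracy(net, dl) would report accuracy on all 200 samples.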

Compute the average partial derivatives for each branch’s Linear layer (branch0[0], branch1[0], branch2[0]). The function mirrors the earlier GradInfluence.score logic and reads gradients at the pre-activations.

def grad_scores(net, sublayer, inputs, targets):
    # Standalone version of GradInfluence.score for the toy model
    device = next(net.parameters()).device
    xs = inputs.to(device)
    ys = targets.to(device)
    xs.requires_grad = True
    logs = []
    def tap(mod, gin, gout):
        # gradient of the target logit w.r.t. the sublayer's output (pre-activation)
        logs.append(gout[0].detach().cpu())
    h = sublayer.register_full_backward_hook(tap)
    for k in range(len(xs)):
        xi = xs[k].unsqueeze(0)
        out = net(xi).squeeze()
        yi = ys[k].item()
        out[yi].backward()
    h.remove()
    g = torch.cat(logs, dim=0)
    return g.mean(dim=0).numpy()  # signed mean over samples, one value per neuron
xb, yb = next(iter(dl))  # one shuffled batch of 16 samples
# branchK[0] is the Linear layer in front of each branch's ReLU
s0 = grad_scores(net, net.branch0[0], xb, yb)
s1 = grad_scores(net, net.branch1[0], xb, yb)
s2 = grad_scores(net, net.branch2[0], xb, yb)
# mean over the branch's neurons of |signed mean gradient|
print("Branch0 importance:", abs(s0).mean())
print("Branch1 importance:", abs(s1).mean())
print("Branch2 importance:", abs(s2).mean())

A representative run yields outputs of the form:

Branch0 importance: 0.083575
Branch1 importance: 0.3273803  <- highest, as expected
Branch2 importance: 0.05412347

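The exact numbers vary from run to run, since neither the data generation, the weight initialization, nor the shuffling is seeded, but the pattern shows that the gradient-based attribution over pre-activations surfaces the genuinely informative branch when the signal is clear.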
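For MiniNet the quantity has a closed form, which helps explain the ranking. Write a_{b,i} for the pre-activation of neuron i in branch b (the output of branch0[0], branch1[0] or branch2[0]), w for width, and W for head.weight; because forward concatenates u0, u1, u2 in that order, unit i of branch b feeds column bw + i of W. Since the ReLU derivative is an indicator, the per-sample derivative of the target logit z_y and the printed branch score s_b are:

```latex
\frac{\partial z_{y}}{\partial a_{b,i}} = W_{y,\,bw+i}\;\mathbb{1}\!\left[a_{b,i} > 0\right]

s_b = \frac{1}{w}\sum_{i=0}^{w-1}\left|\,\frac{1}{N}\sum_{n=1}^{N} W_{y_n,\,bw+i}\;\mathbb{1}\!\left[a^{(n)}_{b,i} > 0\right]\right|
```

Here N is the number of samples passed to grad_scores (16) and y_n is the label of sample n. A branch can therefore only score high if its units are active and the head assigns them large weights for the target classes; plausibly, training enlarges those weights for branch1 because it is the only branch whose input predicts the label.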

What this means for real-world setups

If the same code produces weak separation in your application, that does not necessarily indicate a bug in the derivative collection. It can happen when signals are subtle, distributed across many units, or otherwise not aligned with the simple “correct-class logit gradient at pre-activation” proxy. The toy setup above is a useful sanity check: if it highlights the intended branch, the mechanism is working as designed. You can then use similar controlled scenarios to compare different ranking schemes, including your pruning-after-ranking protocol.
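One way to run the pruning comparison in such a controlled setting is to simulate pruning by zeroing selected units of a Linear layer with a forward hook, then compare accuracy after removing the k top-ranked units with the average over several random k-subsets. A hedged sketch: the helper names, ranking by absolute score, and zeroing outputs (rather than physically removing weights) are assumptions, and it presumes the layer is followed by an activation with f(0) = 0 such as ReLU, the model is in eval mode, and xs and ys are on the model’s device.

```python
import numpy as np
import torch
import torch.nn as nn


@torch.no_grad()
def accuracy_with_pruned_units(net: nn.Module, layer: nn.Module, units,
                               xs: torch.Tensor, ys: torch.Tensor) -> float:
    """Accuracy when the given output units of `layer` are zeroed (simulated pruning)."""
    idx = torch.as_tensor(np.asarray(units, dtype=np.int64))

    def zero_units(mod, inp, out):
        out = out.clone()
        out[:, idx.to(out.device)] = 0.0  # zeroed pre-activation -> ReLU output is 0 too
        return out                        # a returned tensor replaces the layer's output

    h = layer.register_forward_hook(zero_units)
    try:
        acc = (net(xs).argmax(1) == ys).float().mean().item()
    finally:
        h.remove()
    return acc


def compare_top_vs_random(net: nn.Module, layer: nn.Module, scores: np.ndarray,
                          xs: torch.Tensor, ys: torch.Tensor, k: int,
                          n_random: int = 20, seed: int = 0):
    """Return (accuracy after pruning the k highest-|score| units,
    mean accuracy after pruning k random units over n_random draws)."""
    rng = np.random.default_rng(seed)
    top = np.argsort(-np.abs(scores))[:k]
    top_acc = accuracy_with_pruned_units(net, layer, top, xs, ys)
    rand_accs = [
        accuracy_with_pruned_units(net, layer, rng.choice(len(scores), size=k, replace=False), xs, ys)
        for _ in range(n_random)
    ]
    return top_acc, float(np.mean(rand_accs))
```

For MiniNet, passing net.branch1[0] with s1 as the scores would prune within one branch; a useful ranking should leave lower accuracy after top-k pruning than after random pruning. Ranking across branches would require pruning several layers at once, which this sketch does not cover.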

Why it’s worth knowing

Gradient-based attributions are sensitive to the data regime. Verifying the pipeline with a controlled example gives confidence that your hooks and aggregation are correct. From there, any discrepancy between methods in a real dataset likely comes from the data and model dynamics rather than a coding error in the gradient extraction.
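One data-dependent effect that is easy to overlook: both snippets average signed gradients across samples before any absolute value is taken, so a neuron whose derivative is large but changes sign from sample to sample (for example, between samples of different classes, where the target logit differs) can receive a near-zero score. Averaging absolute values is one alternative aggregation you could compare. The sketch below assumes a stacked (samples × neurons) gradient matrix like g_stack in GradInfluence.score; aggregate_scores and the example gradients are hypothetical.

```python
import numpy as np


def aggregate_scores(g_stack: np.ndarray) -> dict:
    """Two per-neuron scores from a (num_samples, num_neurons) gradient matrix."""
    return {
        # |mean of signed gradients|: what abs(grad_scores(...)) computes per neuron
        "abs_of_mean": np.abs(g_stack.mean(axis=0)),
        # mean of |gradients|: opposite-sign gradients across samples cannot cancel
        "mean_of_abs": np.abs(g_stack).mean(axis=0),
    }


# Hypothetical gradients: neuron 0 flips sign across samples, neuron 1 is constant.
g = np.array([[1.0, 0.2],
              [-1.0, 0.2],
              [1.0, 0.2],
              [-1.0, 0.2]])
scores = aggregate_scores(g)
# scores["abs_of_mean"] is [0.0, 0.2]: neuron 0 looks least important
# scores["mean_of_abs"] is [1.0, 0.2]: neuron 0 is the most sensitive
```

Neither variant is guaranteed to predict the effect of pruning; they are two rankings worth running through the same pruning comparison.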

Takeaway

If you need neuron importance tied to pre-activations, the backward hook on a Linear layer’s outputs retrieves the correct partial derivatives, and averaging them across samples yields a cluster-level score. Use a simple, strongly structured task to validate that your procedure elevates the known-important units, then bring the same tooling back to your production data to interpret and compare importance measures in a way that is faithful to how your model actually learns.

The article is based on a question from StackOverflow by jonupp and an answer by Sachin Hosmani.