2025, Oct 04 01:18

Как посчитать важность нейронов по предактивациям: градиентная атрибуция в PyTorch

Как получить важность нейронов через производные логита по предактивациям в PyTorch: рабочий код с backward hook, проверка на простой модели и выводы.

При оценке важности нейронов в моделях классификации распространённой отправной точкой служит частная производная логита правильного класса по предактивации нейрона. Если нужна оценка на уровне кластера, можно усреднить эти производные по всем примерам внутри кластера. На практике иногда возникает ситуация, когда полученные градиенты оказываются слишком малы, и удаление «топ‑нейронов» по таким баллам не превосходит случайное прореживание. Это легко принять за ошибку при извлечении градиентов, особенно если более простые эвристики, вроде использования самих предактиваций, работают заметно лучше.

Воспроизведение градиентной атрибуции для предактиваций нейронов

Ниже приведён фрагмент кода, который вычисляет важность каждого нейрона для выбранного слоя: собираются частные производные целевого логита по предактивациям этого слоя, после чего они усредняются по примерам. Предположим, что указанный слой — это PyTorch Linear, поэтому последний аргумент бэквард‑хука соответствует градиентам по предактивациям.

class InfluenceBase(ABC):
    def __init__(self, net: nn.Module, loaders: Dict[int, DataLoader]):
        self.net = net
        self.loaders = loaders
        if self.net is not None:
            self.net.eval()

    @abstractmethod
    def score(self, sublayer: Module, group_id: int) -> np.ndarray:
        pass


class GradInfluence(InfluenceBase):
    def score(self, sublayer: Module, group_id: int) -> np.ndarray:
        device = next(self.net.parameters()).device

        batch_x, batch_y = next(iter(self.loaders[group_id]))
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)
        batch_x.requires_grad = True

        traces = []

        def tap(mod, gin, gout):
            traces.append(gout[0].detach().cpu())

        h = sublayer.register_full_backward_hook(tap)

        for j in range(len(batch_x)):
            xj = batch_x[j].unsqueeze(0)
            logits = self.net(xj).squeeze()
            cls = batch_y[j].item()

            logits[cls].backward()

        h.remove()

        g_stack = torch.cat(traces, dim=0)
        g_mean = g_stack.mean(dim=0).numpy()
        return g_mean

Что именно считает этот код и почему

Хук установлен на слой Linear. Во время обратного прохода он получает градиенты по выходам этого слоя до активации, то есть по предактивациям. Цикл выполняет обратное распространение для каждого примера, беря скалярный логит корректного класса и вычисляя его градиент по предактивациям «подключённого» слоя. Эти градиенты собираются для каждого объекта, конкатенируются и усредняются по нейронам, формируя оценку важности на уровне кластера. Обнулять градиенты параметров в цикле не требуется, поскольку атрибуция считывается из хука, а не из аккумуляторов градиентов параметров.

Иначе говоря, если ваша цель — «частная производная логита правильного класса по предактивации нейрона», то для слоя Linear этот код выполняет именно такую вычислительную процедуру.

Проверка на простой контрольной модели

Надёжный способ убедиться, что подход ведёт себя разумно, — сконструировать простую модель и датасет, где «правильный ответ» очевиден. Рассмотрим сеть с тремя раздельными ветвями: каждая читает ровно один признак, а далее идёт общий классификационный «хэд». Сильно предсказательным является только второй признак, поэтому нейроны во второй ветви должны получить наибольшие оценки важности по градиентному метрику.

class MiniNet(nn.Module):
    def __init__(self, width=4):
        super().__init__()
        self.branch0 = nn.Sequential(nn.Linear(1, width), nn.ReLU())
        self.branch1 = nn.Sequential(nn.Linear(1, width), nn.ReLU())
        self.branch2 = nn.Sequential(nn.Linear(1, width), nn.ReLU())
        self.head = nn.Linear(3 * width, 2)

    def forward(self, z):
        z0 = z[:, [0]]
        z1 = z[:, [1]]
        z2 = z[:, [2]]
        u0 = self.branch0(z0)
        u1 = self.branch1(z1)
        u2 = self.branch2(z2)
        u = torch.cat([u0, u1, u2], dim=1)
        return self.head(u)

Теперь сгенерируем синтетические данные так, чтобы метка определялась знаком второго признака.

M = 200
feats = np.random.randn(M, 3).astype(np.float32)
labels = (feats[:, 1] > 0).astype(np.int64)

ds = TensorDataset(torch.from_numpy(feats), torch.from_numpy(labels))
dl = DataLoader(ds, batch_size=16, shuffle=True)

Кратко обучим модель.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net = MiniNet().to(device)

optimizer = torch.optim.Adam(net.parameters(), lr=0.01)
criterion = nn.CrossEntropyLoss()

for ep in range(10):
    for xb, yb in dl:
        xb, yb = xb.to(device), yb.to(device)
        optimizer.zero_grad()
        pred = net(xb)
        loss = criterion(pred, yb)
        loss.backward()
        optimizer.step()
    if (ep + 1) % 5 == 0:
        acc = (pred.argmax(1) == yb).float().mean().item()
        print(f"Epoch {ep+1}, Loss={loss.item():.4f}, Acc={acc:.3f}")

Посчитаем средние частные производные для каждого Linear‑слоя в ветвях. Это повторяет логику выше и рассматривает предактивации.

def grad_scores(net, sublayer, inputs, targets):
    device = next(net.parameters()).device
    xs = inputs.to(device)
    ys = targets.to(device)
    xs.requires_grad = True

    logs = []

    def tap(mod, gin, gout):
        logs.append(gout[0].detach().cpu())

    h = sublayer.register_full_backward_hook(tap)

    for k in range(len(xs)):
        xi = xs[k].unsqueeze(0)
        out = net(xi).squeeze()
        yi = ys[k].item()
        out[yi].backward()

    h.remove()

    g = torch.cat(logs, dim=0)
    return g.mean(dim=0).numpy()

xb, yb = next(iter(dl))
s0 = grad_scores(net, net.branch0[0], xb, yb)
s1 = grad_scores(net, net.branch1[0], xb, yb)
s2 = grad_scores(net, net.branch2[0], xb, yb)

print("Branch0 importance:", abs(s0).mean())
print("Branch1 importance:", abs(s1).mean())
print("Branch2 importance:", abs(s2).mean())

Типичный запуск даёт примерно такой вывод:

Branch0 importance: 0.083575
Branch1 importance: 0.3273803  <- highest, as expected
Branch2 importance: 0.05412347

Это показывает: градиентная атрибуция по предактивациям действительно выделяет информативную ветвь, когда сигнал достаточно явный.

Что это означает для реальных сценариев

Если тот же код даёт слабое разделение в вашем приложении, это не обязательно указывает на ошибку в сборе производных. Так бывает, когда сигналы слабы, «размазаны» по множеству единиц или просто плохо согласуются с простым прокси «градиент логита правильного класса по предактивации». Приведённая выше игрушечная постановка — полезная проверка: если она подсвечивает нужную ветвь, механизм работает как задумано. Далее можно использовать похожие контролируемые сценарии, чтобы сравнивать разные схемы ранжирования, включая ваш протокол прореживания после ранжирования.

Почему это важно

Градиентные атрибуции чувствительны к режиму данных. Проверка конвейера на контролируемом примере даёт уверенность, что хуки и агрегирование настроены корректно. После этого расхождения между методами на реальном датасете, скорее всего, обусловлены данными и динамикой обучения модели, а не багом в извлечении градиентов.

Вывод

Если вам нужна важность нейронов относительно предактиваций, обратный хук на выходах слоя Linear возвращает нужные частные производные, а их усреднение по примерам даёт оценку на уровне кластера. Используйте простой, жёстко структурированный тест, чтобы убедиться, что процедура поднимает заведомо важные единицы, а затем применяйте тот же инструментарий к вашим производственным данным, интерпретируя и сопоставляя меры важности так, как это соответствует реальному процессу обучения модели.

Статья основана на вопросе на StackOverflow от jonupp и ответе Sachin Hosmani.