2026, Jan 13 00:02

Как добавить предвычисленные узловые признаки к графам ZINC в PyTorch Geometric

Разбираем, как добавить предвычисленные узловые тензоры к каждому графу ZINC в PyTorch Geometric: сохранение в Data и оборачивание в InMemoryDataset.

Прикрепление заранее вычисленных тензорных признаков к каждому графу в наборе данных torch_geometric — частая задача, когда вы дополняете представления узлов вне привычного пайплайна. Идея проста: взять уже посчитанный тензор на граф и сохранить его рядом с существующими полями в каждом объекте Data, чтобы последующий код мог прозрачно его использовать.

Постановка задачи

Допустим, вы загружаете тренировочную выборку ZINC и смотрите на первый граф. Перед вами объект Data с признаками узлов, рёбрами и целевой переменной:

from torch_geometric.datasets import ZINC
zinc_train = ZINC(root='my_path', split='train')
print(zinc_train[0])  # Пример структуры
# Data(x=[33, 1], edge_index=[2, 72], edge_attr=[72], y=[1])

Вы уже вычислили дополнительный тензорный признак для каждого графа и собрали их в список: i‑й тензор соответствует i‑му графу в датасете. Каждый тензор — это признаковое описание на уровне узлов. В идеале у каждого Data должен появиться этот «груз» следующего вида:

Data(x=[33, 1], edge_index=[2, 72], edge_attr=[72], y=[1], new_feature=[33, 12])

Что происходит на самом деле

Каждый элемент ZINC — это экземпляр torch_geometric.data.Data. Эти объекты гибкие и умеют хранить произвольные атрибуты. Если внешние признаки подготовлены по графам и их первая размерность совпадает с числом узлов конкретного графа, их можно напрямую прикрепить к каждому экземпляру Data и держать внутри датасета.

Решение

Добавьте новые тензоры, проходя по базовому датасету и устанавливая атрибут у каждого объекта Data. Удобный приём — обернуть исходный датасет в легковесный InMemoryDataset, который один раз выполнит обогащение и будет отдавать уже изменённые элементы.

import torch
from torch_geometric.datasets import ZINC
from torch_geometric.data import InMemoryDataset
# 1) Загрузите исходный набор данных
base_ds = ZINC(root='my_path', split='train')
# 2) Постройте список новых тензоров на уровне узлов, выровненных с base_ds
# Замените ниже на ваши реальные тензоры; размеры должны соответствовать каждому графу
aug_tensor_list = []
for g in base_ds:
    node_count = g.x.size(0)
    # Пример-заглушка: [num_nodes, 12]
    feat_tensor = torch.randn(node_count, 12)
    aug_tensor_list.append(feat_tensor)
# 3) Оберните и прикрепите признаки к каждому объекту Data
class ZINCEnriched(InMemoryDataset):
    def __init__(self, src_ds, feature_list):
        self._store = []
        for idx in range(len(src_ds)):
            item = src_ds[idx]
            item.new_feature = feature_list[idx]
            self._store.append(item)
        super().__init__('.', transform=None, pre_transform=None)
        self.data, self.slices = self.collate(self._store)
    def __len__(self):
        return len(self._store)
    def get(self, index):
        return self._store[index]
# 4) Создайте расширенный набор данных
zinc_with_extra = ZINCEnriched(base_ds, aug_tensor_list)
# 5) Посмотрите на пример
example = zinc_with_extra[0]
print(example)
print("Shape of new feature:", example.new_feature.shape)
# Data(x=[33, 1], edge_index=[2, 72], edge_attr=[72], y=[1], new_feature=[33, 12])
# Форма нового признака: torch.Size([33, 12])

Почему это важно

Как только дополнительный тензор закреплён за каждым объектом Data, он перемещается вместе с графом куда бы тот ни пошёл. Так предобработка и моделирование остаются в одном русле: модели не нужно жонглировать параллельными структурами — вся узловая информация доступна в одном месте.

Что следует запомнить

Если соблюдается взаимно‑однозначное соответствие между элементами датасета и вашими тензорами, а размерность по узлам совпадает для каждого графа, можно просто задать новый атрибут у каждого объекта Data. Подход выше проверен на датасете, использованном именно так. Если при инициализации датасета вы задаёте transform, поведение может отличаться; чтобы воспроизвести рабочую конфигурацию, применяйте метод к датасету, созданному так, как показано здесь.