2025, Oct 22 15:01

ConditionalCategorical в pomegranate: правильная форма probs и n_categories

Разбираем TypeError при инициализации ConditionalCategorical в pomegranate: почему одна матрица NumPy не работает и как передать probs и n_categories списками.

При настройке простого ConditionalCategorical в pomegranate хочется просто передать одну 2×2 матрицу NumPy как таблицу вероятностей и идти дальше. Но так легко наткнуться на неожиданный TypeError где‑то внутри библиотеки. Исправление минимальное, однако форма входных данных имеет значение.

Минимальный пример воспроизведения

Ниже фрагмент кода, который задаёт бинарного родителя и бинарного потомка с прямым one‑hot‑отображением, после чего инициализирует ConditionalCategorical. Именно он вызывает ошибку.

from pomegranate.distributions import ConditionalCategorical
import numpy as np
prob_grid = [
    [1.0, 0.0],  # родитель = 0 -> потомок = 0
    [0.0, 1.0],  # родитель = 1 -> потомок = 1
]
probs_np = np.array(prob_grid, dtype=np.float32)
cat_sizes = [2, 2]
cond_model = ConditionalCategorical(probs_np, n_categories=cat_sizes)
print("ConditionalCategorical constructed:", cond_model)

В чём на самом деле проблема

ConditionalCategorical ожидает, что аргумент probs будет списком массивов NumPy. Если передать один массив, меняется логика вычисления внутренней формы, и позже при инициализации происходит сбой. Помимо обёртки таблицы вероятностей в список, n_categories также должен соответствовать этой вложенности.

Исправление и корректный пример

Обверните таблицу вероятностей в список и обновите n_categories, чтобы он отражал такую вложенность.

from pomegranate.distributions import ConditionalCategorical
import numpy as np
prob_grid = [
    [1.0, 0.0],  # родитель = 0 -> потомок = 0
    [0.0, 1.0],  # родитель = 1 -> потомок = 1
]
probs_np = np.array(prob_grid, dtype=np.float32)
cat_sizes = [2, 2]
cond_model = ConditionalCategorical(probs=[probs_np], n_categories=[cat_sizes])
print("ConditionalCategorical constructed:", cond_model)

Почему это важно

Интерфейсы, принимающие структурированные таблицы вероятностей, часто кодируют семантику формы через типы контейнеров, а не только через размерности массивов. Здесь список сообщает инициализатору ожидаемую организацию условных распределений. Соблюдение этого контракта предотвращает непонятные ошибки во время выполнения и обеспечивает корректное поведение последующих операций.

Итог

Если ConditionalCategorical падает при передаче одного массива NumPy, передайте probs как список массивов и отразите ту же структуру в n_categories. Эта небольшая правка приводит входы в соответствие ожиданиям конструктора и делает настройку модели чистой и предсказуемой.

Эта статья основана на вопросе с StackOverflow от Isaac A и ответе от p011yr011n.