2025, Dec 02 06:02
Эффективное извлечение пар строк по булевой маске в NumPy без бродкаста
Как извлекать парные строки из массивов NumPy по булевой маске без бродкаста: используем np.where для координат и прямую индексацию, экономя память и время.
Извлечение парных строк с помощью булевой маски по двум массивам разной формы легко превращается в «пылесос» памяти, если материализовать всю растянутую (broadcast) сетку. Ниже — компактный способ получить тот же результат без выделения огромного промежуточного тензора.
Задача
Есть два массива, у которых последняя ось имеет размер 3, и булева маска, заданная на декартовом произведении их пространственных осей. Требуется получить два выхода формы Mx3 (с одинаковым M), содержащие те строки из каждого источника, для которых маска истинна.
import numpy as np
np.random.seed(0)
srcA = np.random.rand(4, 5, 3)
srcB = np.random.rand(6, 3)
mask = np.random.rand(4, 5, 6) >= 0.7
Прямолинейный подход — повторить оба входа до полной объединённой формы и проиндексировать их маской. Это срабатывает, но заставляет создать огромный промежуточный массив размера M1x...xMPxN1x...xNQ — ровно того, чего хочется избежать.
shapeA, shapeB = srcA.shape[:-1], srcB.shape[:-1]
tgt_shape = shapeA + shapeB + (3,)
expA = srcA[..., None, :].repeat(np.prod(shapeB), axis=-2).reshape(tgt_shape)
expB = srcB[None, ..., :].repeat(np.prod(shapeA), axis=0).reshape(tgt_shape)
outA, outB = expA[mask, :], expB[mask, :]
Почему это невыгодно
Повторение до декартова произведения резервирует память, пропорциональную произведению пространственных размеров обоих входов. Даже если конечный отбор разреженный, промежуточный массив становится узким местом. Маска уже кодирует подходящие позиции; важно напрямую использовать эти индексы положений на исходных массивах.
Решение: один раз получить индексы — дважды собрать
Вызовите np.where(mask), чтобы получить координаты всех истинных позиций на объединённой сетке, затем разделите полученный кортеж координат на части, адресующие соответственно srcA и srcB. Поскольку оба входа как минимум двумерные, никаких особых случаев не требуется.
idx_tup = np.where(mask)
p_dims = len(srcA.shape) - 1 # пространственные оси srcA (без последней оси=3)
q_dims = len(srcB.shape) - 1 # пространственные оси srcB (без последней оси=3)
selA = idx_tup[:p_dims]
selB = idx_tup[p_dims:p_dims + q_dims]
resA = srcA[selA]
resB = srcB[selB]
Кортеж idx_tup содержит по одному массиву индексов на каждую пространственную ось объединённой сетки. Первые p_dims массивов указывают на пространственные оси srcA, а следующие q_dims — на srcB. Умное индексирование этими срезами даёт два массива формы Kx3, где K — число истинных элементов в маске.
Зачем это нужно
Такой подход исключает необходимость выделять весь бродкаст‑тензор. Маска используется один раз для получения координат, а затем эти координаты напрямую выбирают строки из исходных массивов. Память расходуется примерно как на входные данные плюс массивы индексов — обычно это на порядки меньше, чем декартово произведение.
Итоги
Если булева маска покрывает произведение двух входных сеток, не материализуйте бродкаст‑массивы. Один раз вычислите координаты индексов через np.where, разделите их по числу пространственных осей каждого входа и соберите данные из исходников. Вы сохраните форму вывода Mx3 для обоих результатов и полностью обойдёте тяжёлый промежуточный шаг.