2025, Dec 28 09:02

Минимальный сдвиг между массивами в Python: почему векторизация NumPy ошибается и как решить циклом с Numba/JIT

Как посчитать минимальный сдвиг между массивами в Python: почему векторизация NumPy даёт сбой, корректный алгоритм на циклах и ускорение через Numba/JIT.

Сопоставить два целочисленных массива, вычисляя для каждой позиции, на сколько шагов нужно сдвинуться, чтобы попасть на ближайшее равное значение в другом массиве, кажется простой задачей — пока не попытаешься её векторизовать. Реальный сценарий: миллион элементов, небольшой набор значений — целые числа в диапазоне [0, 10], без пропусков, и жёсткое требование находить минимальный сдвиг для каждой позиции. Ниже — краткое объяснение, почему наивный подход с NumPy даёт сбой, и как написать корректное решение на циклах, которое при необходимости можно JIT-компилировать.

Reproducing the problem

Рассмотрим два массива и ожидаемые сдвиги по индексам, чтобы добраться до ближайшего совпадения во втором массиве. На первой позиции слева стоит 1; в правом массиве ближайшая 1 находится на расстоянии трёх шагов, поэтому результат начинается с 3. Нули на совпадающих позициях дают сдвиг 0. Задача — посчитать такой сдвиг для каждого индекса.

import numpy as np

arr_u = np.array([1, 0, 2, 0, 0, 1, 0, 2, 0, 0, 1, 0, 2, 0, 0])
arr_v = np.array([2, 0, 0, 1, 0, 2, 0, 0, 1, 0, 2, 0, 0, 1, 0])

expected = np.array([3, 0, 2, 1, 0, 2, 0, 2, 1, 0, 2, 0, 2, 1, 0])

Заманчиво для каждого уникального значения найти все индексы и вычесть их. Следующая функция демонстрирует эту ловушку. Она выглядит векторизованной и лаконичной, но даёт неправильные расстояния, потому что сопоставляет вхождения по порядку, а не по близости.

def offsets_broadcast_fail(seq_x, seq_y):
    assert len(seq_x) == len(seq_y)
    out = np.zeros(len(seq_x))
    classes = np.unique(seq_x)
    for tag in classes:
        idx_x = np.where(seq_x == tag)
        idx_y = np.where(seq_y == tag)
        out[idx_x] = np.subtract(idx_x, idx_y)
    return np.abs(out).astype(int)

bad = offsets_broadcast_fail(arr_u, arr_v)

В итоге получается соответствие по порядку вхождений, а не к ближайшим элементам. В примере выше функция возвращает [3, 0, 2, 1, 0, 3, 0, 2, 1, 0, 3, 0, 2, 1, 0] вместо ожидаемого [3, 0, 2, 1, 0, 2, 0, 2, 1, 0, 2, 0, 2, 1, 0].

Why the vectorized attempt fails

Для каждого значения метод вычитает один набор индексов из другого. Это неявно связывает k‑е вхождение в одном массиве с k‑м вхождением в другом, независимо от того, где реально находится ближайший элемент с тем же значением. Сопоставление по порядку вхождений — не то же самое, что поиск ближайшего совпадения вокруг каждой позиции. В результате некоторые расстояния оказываются завышенными, хотя рядом есть более близкое одинаковое значение на другом порядковом месте.

A simple, correct approach

Поиск по своей природе локален и зависит от условий. Вместо глобального сопоставления идите от каждого индекса наружу — влево и вправо — пока не встретите то же значение во втором массиве. Если элементы уже совпадают на одной позиции, сдвиг равен нулю. Если совпадения нет вовсе, запишите -1. Реализация ниже следует этой логике и её легко проверять и понимать.

def nearest_shift_offsets(seq_a, seq_b):
    assert len(seq_a) == len(seq_b)
    n = len(seq_a)
    out = np.zeros(n, np.int64)
    for pos, val in enumerate(seq_a):
        if seq_b[pos] == val:
            out[pos] = 0
            continue
        for step in range(1, max(n - pos, pos)):
            if (pos - step >= 0) and (seq_b[pos - step] == val):
                out[pos] = step
                break
            elif (pos + step < n) and (seq_b[pos + step] == val):
                out[pos] = step
                break
        else:
            out[pos] = -1
    return out

good = nearest_shift_offsets(arr_u, arr_v)

Для приведённых входных данных это даёт ожидаемый массив сдвигов.

What to expect at scale

Ограничения входных данных важны. Размер массивов может достигать порядка миллиона элементов; значения — плотные целые числа в диапазоне [0, 10]; пропусков нет. При таких условиях базовым выбором остаётся поиск в цикле. Если время выполнения устраивает — на этом можно остановиться. Если нет — функцию можно JIT-компилировать. При указанных ограничениях есть конкретная ремарка: явное задание буфера вывода как np.int64 обеспечивает совместимость с Numba. По отчётам профилирования, JIT-версия очень быстра при, например, равномерном распределении значений, но может быть крайне медленной в патологических случаях, вроде отсортированных входов. И следите за вводящими в заблуждение замерами: отмечалось, что микросекундные результаты для входов масштабом в миллион элементов неправдоподобны.

Why this is worth knowing

Не всякая задача с массивами выигрывает от векторизации. Поиск и выравнивание часто требуют такого управления потоком, которое плохо ложится на broadcasting и редукции. Умение вовремя отказаться от погонь за векторизованной «однострочкой» в пользу простых циклов убережёт от тонких ошибок корректности и бесполезной оптимизации. В этом случае нет готовых примитивов NumPy или SciPy, рассчитанных именно на такой «сдвиг до ближайшего совпадения», поэтому самый простой корректный подход — ещё и наиболее удобная база для поддержки.

Conclusion

Если задача — найти ближайший сдвиг по каждому индексу между двумя массивами, не используйте глобальную арифметику над наборами индексов, которая связывает вхождения по порядку. Реализуйте прямой локальный поиск по индексу, расширяя область в обе стороны, пока не встретится совпадение. Проверьте на небольших примерах и масштабируйте. Если время выполнения станет критичным, JIT-компиляция той же логики способна дать практичный прирост, но помните, что производительность зависит от распределения данных. Сохраняйте тип выходного массива совместимым для компиляции и скептически относитесь к подозрительно кратким замерам.