2025, Sep 30 11:16

Как сформировать партии в PySpark по порогу без потери граничной строки

Разбираем правильное формирование партий в PySpark DataFrame по порогу: сбрасываемая накопительная сумма в pandas_udf, учет граничной строки и стабильный порядок

Формировать партии записей в PySpark DataFrame до тех пор, пока накопительная сумма не перейдёт порог, на первый взгляд просто. Но тут часто ошибаются в одной детали: граничная строка, которая переваливает сумму за предел, должна остаться в той же партии. Если удалить её или перенести в следующую группу, результат перестанет соответствовать задуманной логике.

Проблема

Есть DataFrame с двумя колонками: ID и Count. Наша цель — собирать партии (списки ID) так, чтобы при последовательном просмотре данных мы добавляли ID в партию до первого момента, когда накопленный Count достигает или превышает заданный порог. В этот же момент партия закрывается, включая тот самый граничный ID.

from pyspark.sql import functions as F
from pyspark.sql.types import IntegerType
from pyspark.sql import Window
import pandas as pd

rows = [
    ("abc", 500),
    ("def", 300),
    ("ghi", 400),
    ("jkl", 200),
    ("mno", 1100),
    ("pqr", 900),
]
frame = spark.createDataFrame(rows, ["ID", "Count"])
threshold = 1000

При пороге 1000 партии должны получиться такими: [abc, def, ghi] с суммой 1200, затем [jkl, pqr] с суммой 1100, и [mno] как отдельная партия, потому что его Count уже не меньше порога.

Почему наивный подход даёт сбой

Хочется применить группировку через «пол» или сделать простую накопительную сумму с целочисленным делением. Но такие приёмы «теряют» строку, которая как раз и переваливает сумму через порог, — в итоге группы формируются неправильно. Суть проблемы в следующем:

Методы на базе floor и похожие не включают завершающий ID, который переносит сумму за порог. Нужен счётчик, который умеет сбрасываться.

Здесь выручает «сбрасываемая» накопительная сумма: накапливаем до достижения или превышения порога, закрываем партию, оставляя в ней граничную строку, затем обнуляем аккумулятор и идём дальше.

Решение

Подход ниже реализует двухшаговую стратегию. Сначала отделяем строки, у которых Count уже достигает или превышает порог, — они образуют самостоятельные партии. Затем для остальных строк вычисляем идентификаторы партий через «сбрасываемую» накопительную сумму, реализованную в pandas_udf. Наконец, продолжаем нумерацию для одиночных строк после последней полученной партии и объединяем результаты.

from pyspark.sql import functions as F
from pyspark.sql.types import IntegerType
from pyspark.sql import Window
import pandas as pd

rows = [
    ("abc", 500),
    ("def", 300),
    ("ghi", 400),
    ("jkl", 200),
    ("mno", 1100),
    ("pqr", 900),
]
frame = spark.createDataFrame(rows, ["ID", "Count"])

# 1) Отделяем строки, которые уже достигают/превышают порог
over_limit = frame.where(frame.Count >= 1000)
under_limit = frame.where(frame.Count < 1000)

# 2) Добавляем порядковые номера, чтобы зафиксировать стабильный порядок обработки для последующей нумерации
under_limit = under_limit.withColumn(
    "idx", F.row_number().over(Window.orderBy(F.monotonically_increasing_id()))
)
over_limit = over_limit.withColumn(
    "idx", F.row_number().over(Window.orderBy(F.monotonically_increasing_id()))
)

# 3) Назначаем идентификаторы партий с помощью «сбрасываемой» накопительной суммы
@F.pandas_udf(IntegerType())
def tag_batches(cnt: pd.Series) -> pd.Series:
    tags = []
    acc = 0
    bucket_no = 0
    for val in cnt:
        acc += val
        tags.append(bucket_no)
        if acc >= 1000:
            bucket_no += 1
            acc = 0
    return pd.Series(tags)

under_limit = under_limit.withColumn("batch_id", tag_batches(F.col("Count")))

# 4) Продолжаем нумерацию для одиночных строк после последней партии строк ниже порога
over_limit = over_limit.withColumn(
    "batch_id",
    F.col("idx") + under_limit.agg({"batch_id": "max"}).collect()[0][0] + 1,
)

# 5) Объединяем результаты и смотрим вывод
result_ds = under_limit.unionByName(over_limit, allowMissingColumns=True)
result_ds.select("ID", "Count", "batch_id").show()

Что происходит под капотом

Ключевой элемент — pandas_udf, который проходит по значениям Count и ведёт текущий аккумулятор. Каждый раз, когда аккумулятор достигает или превышает порог, функция выдаёт для этой строки текущий идентификатор партии, затем обнуляет аккумулятор и увеличивает номер партии. Это гарантирует, что строка, перевалившая сумму через порог, останется в той же партии, а не уйдёт в следующую. Строки, которые уже удовлетворяют порогу, обрабатываются отдельно и получают номера партий, продолжающие последовательность после предыдущих. Номера строк, сгенерированные поверх монотонно возрастающего идентификатора, помогают сохранить детерминированный порядок для этапа нумерации.

Зачем это важно

Корректное формирование партий с сохранением граничных строк важно, когда от партий зависят последующие задачи, планирование ресурсов или окна SLA. Если ключевая запись «перескакивает» в следующую партию, искажаются и размеры партий, и их количество, что ведёт к неравномерной обработке и сложным для отладки эффектам. Подход со «сбрасываемой» накопительной суммой удерживает группировку в точном соответствии с определением.

Выводы

Если группировка должна включать первую запись, которая переносит накопленную сумму за порог, применяйте «сбрасываемую» накопительную сумму, а не логику в духе floor. Элементы, уже достигшие или превысившие порог, выделяйте в отдельные партии и нумеруйте их после малых партий. Обеспечьте стабильный порядок обработки для воспроизводимой группировки, затем объединяйте результаты и переходите к дальнейшим шагам.

Статья основана на вопросе на StackOverflow от fstr и ответе Chris.