2025, Sep 30 11:00

Correctly batching PySpark DataFrame rows by threshold with a resetting cumulative sum that keeps the boundary row

Learn how to batch PySpark DataFrame rows by a threshold using a resetting cumulative sum and pandas UDF, ensuring the boundary row stays in the same batch.

Batching records in a PySpark DataFrame until a running total crosses a threshold sounds straightforward, but one detail trips up many implementations: the boundary row that pushes the sum over the limit must stay in the same batch. If you drop that record or roll it to the next group, your batches won’t match the intended semantics.

Problem

We have a DataFrame with two columns, ID and Count. The goal is to build batches (lists of IDs) by scanning the rows in order and adding IDs to the current batch until the cumulative Count first meets or exceeds a given threshold; at that moment the batch closes and the boundary ID stays inside it.

from pyspark.sql import SparkSession, Window
from pyspark.sql import functions as F
from pyspark.sql.types import IntegerType
import pandas as pd

spark = SparkSession.builder.getOrCreate()  # reuses an active session in a shell or notebook
rows = [
    ("abc", 500),
    ("def", 300),
    ("ghi", 400),
    ("jkl", 200),
    ("mno", 1100),
    ("pqr", 900),
]
frame = spark.createDataFrame(rows, ["ID", "Count"])
threshold = 1000

For a threshold of 1000, the batches should look like [abc, def, ghi] with a sum of 1200, then [jkl, pqr] with a sum of 1100, and [mno] as a standalone batch because its Count is already at least the threshold.

Why the naive approach fails

It is tempting to reach for floor-based grouping or a simple running sum plus integer division. That path rolls the row that actually pushes the total over the threshold into the next batch, so the grouping no longer matches the intended semantics. As the core idea states:

Floor and other methods don’t include the trailing ID that sends it over the threshold. You would need some sort of resetting counter.
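
To see this concretely, here is a minimal sketch (not part of the original answer) that applies the floor-of-running-sum idea to the sample frame and threshold defined above; the naive_batch column name is just for illustration.

w = Window.orderBy(F.monotonically_increasing_id())
naive = frame.withColumn("naive_batch", F.floor(F.sum("Count").over(w) / threshold))
naive.show()
# Running totals: 500, 800, 1200, 1400, 2500, 3400 -> naive_batch 0, 0, 1, 1, 2, 3.
# ghi ends up in batch 1, even though it is the row that should close batch 0.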

This is where a resetting cumulative sum comes in: you accumulate until the threshold is reached or exceeded, close the batch while keeping that boundary row in it, reset the accumulator, and continue.

Solution

The approach below follows a two-step strategy. First, split out rows whose Count already meets or exceeds the threshold so they form standalone batches. Then, for the remaining rows, compute batch identifiers using a resetting cumulative sum implemented via a pandas_udf. Finally, number the standalone rows after the last computed batch and union the results.

from pyspark.sql import functions as F
from pyspark.sql.types import IntegerType
from pyspark.sql import Window
import pandas as pd
rows = [
    ("abc", 500),
    ("def", 300),
    ("ghi", 400),
    ("jkl", 200),
    ("mno", 1100),
    ("pqr", 900),
]
frame = spark.createDataFrame(rows, ["ID", "Count"])
threshold = 1000
# 1) Separate rows that already meet/exceed the threshold
over_limit = frame.where(frame.Count >= threshold)
under_limit = frame.where(frame.Count < threshold)
# 2) Attach row numbers to maintain a consistent processing order for later numbering
under_limit = under_limit.withColumn(
    "idx", F.row_number().over(Window.orderBy(F.monotonically_increasing_id()))
)
over_limit = over_limit.withColumn(
    "idx", F.row_number().over(Window.orderBy(F.monotonically_increasing_id()))
)
# 3) Assign batch ids using a resetting cumulative sum
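# Note: a scalar pandas UDF runs per partition (and per Arrow record batch), so the
# running total below assumes the under-threshold rows arrive as one ordered batch.
# The row_number window above pulls everything into a single partition, which holds
# for this small sample; larger data may need an explicit repartition and sort.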
@F.pandas_udf(IntegerType())
def tag_batches(cnt: pd.Series) -> pd.Series:
    tags = []
    acc = 0
    bucket_no = 0
    for val in cnt:
        acc += val
        tags.append(bucket_no)  # the boundary row keeps the current batch id
        if acc >= threshold:    # threshold reached: close the batch and reset
            bucket_no += 1
            acc = 0
    return pd.Series(tags)
under_limit = under_limit.withColumn("batch_id", tag_batches(F.col("Count")))
# 4) Continue numbering for standalone rows after the last batch of under-threshold rows
last_batch = under_limit.agg(F.max("batch_id")).collect()[0][0]
over_limit = over_limit.withColumn(
    "batch_id",
    # idx starts at 1, so the first standalone row follows directly after last_batch
    F.col("idx") + last_batch,
)
# 5) Combine results and view
result_ds = under_limit.unionByName(over_limit, allowMissingColumns=True)
result_ds.select("ID", "Count", "batch_id").show()
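# For this sample the grouping should come out as batch 0 = [abc, def, ghi],
# batch 1 = [jkl, pqr], and batch 2 = [mno] (the standalone row).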

What’s happening under the hood

The crucial piece is the pandas_udf that walks through the Count values while maintaining a running accumulator. Every row is tagged with the current batch id before the threshold check, so the boundary row receives the id of the batch it closes; only then does the accumulator reset and the batch id increment. This guarantees that the row which pushed the total over the threshold stays in its batch rather than being deferred to the next one. Rows that already meet the threshold are handled separately and receive batch ids that continue from where the previous series stopped. The row numbers generated over a monotonically increasing id keep that final numbering step deterministic.
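
As a quick illustration outside Spark, the same resetting logic can be traced on the under-threshold counts with plain pandas; tag_batches_local below is a hypothetical local mirror of the UDF body, not part of the original answer.

import pandas as pd

def tag_batches_local(cnt: pd.Series, threshold: int = 1000) -> pd.Series:
    tags, acc, bucket_no = [], 0, 0
    for val in cnt:
        acc += val
        tags.append(bucket_no)   # boundary row keeps the batch it closes
        if acc >= threshold:     # close the batch and reset the accumulator
            bucket_no += 1
            acc = 0
    return pd.Series(tags)

print(list(tag_batches_local(pd.Series([500, 300, 400, 200, 900]))))
# [0, 0, 0, 1, 1] -> [abc, def, ghi] and [jkl, pqr]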

Why this matters

Batch construction that preserves boundary rows is essential when batches drive downstream jobs, resource planning, or SLA windows. If the turning-point record jumps to the next batch, you distort both batch sizes and the number of batches, which can cascade into uneven processing and hard-to-debug behavior. The resetting cumulative sum approach keeps the grouping consistent with the intended definition.

Takeaways

If a grouping must include the first record that takes the cumulative total over a threshold, use a resetting cumulative sum rather than floor-like logic. Keep items already at or above the threshold as their own batches and number them after the smaller batches. Ensure a consistent processing order for reproducible grouping, then union the results and proceed with downstream processing.

The article is based on a question from StackOverflow by fstr and an answer by Chris.