2025, Nov 07 23:00

Remove per-group linear trend in Polars using Struct, map_batches, and over to add residuals without loops

Learn how to remove a per-group linear trend in Polars by computing residuals with Struct, map_batches and over, preserving the DataFrame shape without loops.

Fitting and removing a linear trend per group is a common preprocessing step, but it’s easy to fall into a loop-and-concat pattern that discards the strengths of a columnar engine. The goal here is simple: for each combination of GROUP1 and GROUP2, fit a line on X and Y, subtract the fit from the original Y, and keep the original shape with one extra residual column.

Problem setup

The straightforward implementation loops over the groups, runs a linear fit, computes residuals, and concatenates the partial results. It works, but it doesn’t leverage group-aware expressions or windowing, it requires manual concatenation, and the concatenated output is not guaranteed to preserve the original row order.

import polars as pl
import numpy as np

# Example frame
data = pl.DataFrame(
    {
        "GROUP1": [1, 1, 1, 2, 2, 2],
        "GROUP2": ["A", "A", "A", "B", "B", "B"],
        "X": [0.0, 1.0, 2.0, 0.0, 1.0, 2.0],
        "Y": [5.0, 7.0, 9.0, 3.0, 4.0, 6.0],
    }
)

# Looping implementation
def remove_linear_trend_per_group(frame: pl.DataFrame) -> pl.DataFrame:
    parts = []
    for _, chunk in frame.group_by(["GROUP1", "GROUP2"]):
        xx = chunk["X"].to_numpy()
        yy = chunk["Y"].to_numpy()

        slope, intercept = np.polyfit(xx, yy, 1)
        resid = yy - (slope * xx + intercept)

        parts.append(chunk.with_columns(pl.Series("residual", resid)))
    return pl.concat(parts)

out = remove_linear_trend_per_group(data)
print(out)

Why this is not ideal

When you reach for group_by().agg(), you inevitably summarize each group, which collapses rows and loses the original shape. Using with_columns without a group context also won’t help, because you need the linear fit to be computed per group, then applied back to each row of that same group. The loop-and-concat approach preserves the height, but it prevents you from expressing the logic as a single declarative pipeline and sidesteps the engine’s ability to operate over groups efficiently.

Solution: Struct + map_batches + over

The key is to pass multiple columns to a user function in one go and apply it per group while retaining the original rows. You can wrap X and Y into a Struct, feed that to .map_batches(), and scope the operation using .over(). This keeps the full width and height of the data and adds a residual column computed per group.

import polars as pl
import numpy as np

# UDF that receives a Struct Series and returns residuals for that batch
def calc_residuals_batch(s: pl.Series) -> pl.Series:
    xs = s.struct.field("X").to_numpy()
    ys = s.struct.field("Y").to_numpy()
    m, c = np.polyfit(xs, ys, 1)  # slope and intercept of the least-squares line
    return pl.Series(ys - (m * xs + c))

result = data.with_columns(
    pl.struct("X", "Y")
      .map_batches(calc_residuals_batch)
      .over("GROUP1", "GROUP2")
      .alias("residual")
)

print(result)

What’s happening under the hood

Struct combines multiple column values so they travel together into the function as a single Series. map_batches applies the function to a whole Series at a time instead of element-wise, which is exactly what you want when the function expects and returns arrays aligned to the current group’s rows. over establishes the group window: within it, the Series passed to the function contains only the rows of the current GROUP1 and GROUP2 combination, and the returned residuals are mapped back onto every row of that group.

Why this matters

This approach keeps your pipeline expressive and row-preserving. You avoid manual slicing and concatenation, keep all original columns intact, and attach a residual column produced by a per-group model. It’s also easier to reason about because the grouping context, column packaging, and transformation are expressed declaratively.

Takeaways

When you need a per-group transformation that depends on multiple columns but must return a full-length column, combine the inputs into a Struct, apply map_batches for batch-aware Python logic, and scope it with over to compute within each group. This pattern lets you fit a linear model per group and subtract it from the original data without leaving the dataframe expression context.

The article is based on a question from StackOverflow by Thomas and an answer by jqurious.