2025, Nov 10 21:01

Удаление линейного тренда по группам в Polars: struct, map_batches и over

Как в Polars убрать линейный тренд по группам без циклов: упаковываем X и Y в struct, считаем модель в map_batches и применяем over, получая столбец остатков.

Подгонка и удаление линейного тренда по группам — распространённый шаг предобработки, но легко увязнуть в схеме «цикл + конкатенация», которая перечёркивает сильные стороны колонночного движка. Цель проста: для каждой комбинации GROUP1 и GROUP2 подогнать прямую по X и Y, вычесть её из исходного Y и сохранить исходную форму данных, добавив один столбец с остатками.

Постановка задачи

Лобовой подход перебирает группы, выполняет линейную аппроксимацию, считает остатки и склеивает частичные результаты. Он работает, но не задействует выражения с учётом групп или оконные механизмы и требует ручной конкатенации.

import polars as pl
import numpy as np

# Пример фрейма
data = pl.DataFrame(
    {
        "GROUP1": [1, 1, 1, 2, 2, 2],
        "GROUP2": ["A", "A", "A", "B", "B", "B"],
        "X": [0.0, 1.0, 2.0, 0.0, 1.0, 2.0],
        "Y": [5.0, 7.0, 9.0, 3.0, 4.0, 6.0],
    }
)

# Реализация с циклом
def remove_linear_trend_per_group(frame: pl.DataFrame) -> pl.DataFrame:
    parts = []
    for _, chunk in frame.group_by(["GROUP1", "GROUP2"]):
        xx = chunk["X"].to_numpy()
        yy = chunk["Y"].to_numpy()

        slope, intercept = np.polyfit(xx, yy, 1)
        resid = yy - (slope * xx + intercept)

        parts.append(chunk.with_columns(pl.Series("residual", resid)))
    return pl.concat(parts)

out = remove_linear_trend_per_group(data)
print(out)

Почему это неидеально

Если воспользоваться group_by().agg(), вы неизбежно агрегируете каждую группу — строки схлопываются, и исходная форма теряется. Применение with_columns без контекста групп тоже не спасает: модель нужно посчитать на уровне группы и затем применить к каждой строке этой же группы. Подход «цикл + конкатенация» сохраняет высоту таблицы, но не даёт выразить логику единой декларативной цепочкой и обходит возможности движка эффективно работать по группам.

Решение: Struct + map_batches + over

Идея в том, чтобы передать пользовательской функции сразу несколько столбцов и применить её по группам, не теряя строки. Оберните X и Y в Struct, передайте его в .map_batches() и задайте область действия через .over(). Так вы сохраните и ширину, и высоту данных и добавите столбец с остатками, рассчитанный для каждой группы.

import polars as pl
import numpy as np

# UDF, который получает Struct из колонок и возвращает остатки
def calc_residuals_batch(s: pl.Series) -> pl.Series:
    xs, ys = s.struct.unnest()
    m, c = np.polyfit(xs, ys, 1)
    return ys - (m * xs + c)

result = data.with_columns(
    pl.struct("X", "Y")
      .map_batches(calc_residuals_batch)
      .over("GROUP1", "GROUP2")
      .alias("residual")
)

print(result)

Что происходит под капотом

Struct объединяет значения нескольких столбцов, чтобы они вместе попали в функцию. map_batches применяет функцию пакетами, а не поэлементно — именно это нужно, когда функция ожидает и возвращает массивы, согласованные со строками текущей группы. over задаёт групповое окно, гарантируя, что линейная подгонка выполняется в пределах каждой комбинации GROUP1 и GROUP2, а результат выравнивается по каждой строке этой группы.

Почему это важно

Такой подход сохраняет выразительность конвейера и не рушит строки. Не нужен ручной срез и склейка: все исходные столбцы остаются на месте, а к ним добавляется столбец остатков, рассчитанный моделью на уровне группы. Логику проще понимать: контекст группировки, упаковка столбцов и преобразование заданы декларативно.

Итоги

Когда нужно групповое преобразование, которое зависит от нескольких столбцов и при этом должно вернуть столбец полной длины, объединяйте входы в Struct, применяйте map_batches для пакетной Python-логики и ограничивайте вычисления через over, чтобы считать внутри каждой группы. Этот приём позволяет подогнать линейную модель по группам и вычесть её из исходных данных, не выходя из контекста выражений датафрейма.

Статья основана на вопросе с сайта StackOverflow от Thomas и ответе от jqurious.