2025, Dec 06 11:00

Comparing SymPy Expressions for Equivalence with Structural (AST) Matching, Not Polynomials

Learn to test SymPy expression equivalence beyond polynomials with structural (AST) matching. Handles logs, exps and symbolic parameters in mixed expressions.

Comparing two SymPy expressions for equivalence sounds straightforward until non-polynomial pieces show up. A quick attempt is to convert both sides into polynomials in selected symbols and compare their monomials and coefficients. That works for plain polynomials but breaks as soon as you introduce log, exp or other non-polynomial operators. Below is a practical walkthrough of the pitfall and a robust structural approach that handles complex expressions while keeping parameters symbolic.

Problem setup

The goal is to decide whether two expressions are equivalent with respect to a given set of variables, up to the choice of symbolic parameters. The intent is captured by tests like a*x + b versus c*x + d, or a*log(b*x + c) + d versus e + f*log(g + h*x). Coefficients and constant offsets may be different symbols, but the structural dependence on the chosen variables (such as x) should match.

Baseline approach that fails on non-polynomials

Here is a direct comparison via sympy.Poly. It expands both sides, builds polynomials in the provided variables, and compares monomials and numerical coefficients.

import sympy
def poly_equiv(lhs, rhs, vars_list=()):
    # Expand both sides and view them as polynomials in the chosen variables.
    p1 = sympy.Poly(sympy.expand(lhs), *vars_list)
    p2 = sympy.Poly(sympy.expand(rhs), *vars_list)
    # Both sides must contain exactly the same monomials.
    if set(p1.monoms()) != set(p2.monoms()):
        return False
    # With equal monomial sets, the coefficient lists line up term by term.
    # Only numeric coefficients are compared; symbolic ones may differ freely.
    for c1, c2 in zip(p1.coeffs(), p2.coeffs()):
        if c1.is_Number and c2.is_Number and c1 != c2:
            return False
    return True
# Minimal demonstration that raises on non-polynomials
x = sympy.Symbol('x')
left = sympy.sympify("a*log(b*x+c)+d")
right = sympy.sympify("e+log(g+x*h)*f")
# This call triggers sympy.polys.polyerrors.PolynomialError
poly_equiv(left, right, [x])
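For genuine polynomials the baseline does what it promises. The short sketch below (symbol names chosen for illustration) shows what poly_equiv actually compares: monomials are exponent tuples over the generators, and symbolic coefficients fail the is_Number test and are skipped, so only numeric coefficients can cause a mismatch.

```python
import sympy

a, b, c, d, x = sympy.symbols('a b c d x')

p1 = sympy.Poly(a*x + b, x)
p2 = sympy.Poly(c*x + d, x)
# Monomials are exponent tuples over the generators: x**1 and x**0.
print(p1.monoms(), p2.monoms())  # [(1,), (0,)] [(1,), (0,)]
# Symbolic coefficients are not Numbers, so the baseline ignores them.
print(p1.coeffs(), p2.coeffs())  # [a, b] [c, d]

# Numeric coefficients, by contrast, must agree term by term.
p3 = sympy.Poly(2*x + b, x)
p4 = sympy.Poly(3*x + d, x)
print(p3.coeffs(), p4.coeffs())  # [2, b] [3, d]
```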

Why this breaks

sympy.Poly expects a polynomial in the specified generators. Coefficients may be arbitrary expressions that do not involve the generators, so log(b) is acceptable when the generator is x. As soon as a generator appears inside log, exp or another non-polynomial function, however, no polynomial representation exists and SymPy raises PolynomialError. This means the “compare by polynomial monomials and coefficients” strategy is unsuitable for mixed expressions.
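A minimal check of that boundary, using the same symbol names as above: log(b) is accepted as an ordinary coefficient, while log(b*x + c) is rejected because it contains the generator.

```python
import sympy
from sympy.polys.polyerrors import PolynomialError

a, b, c, x = sympy.symbols('a b c x')

# log(b) does not contain the generator x, so it is simply a coefficient.
p = sympy.Poly(a*x + sympy.log(b), x)
print(p.coeffs())  # [a, log(b)]

# log(b*x + c) contains x, so it cannot be written as a polynomial in x.
try:
    sympy.Poly(a*sympy.log(b*x + c), x)
    raised = False
except PolynomialError as err:
    raised = True
    print('PolynomialError:', err)
```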

In this context, “equivalent” means having the same structure with respect to the target variables while the free symbolic parameters are allowed to differ, as the tests indicate. For example, a*x + b and c*x + d should match with respect to x, and so should a*log(b*x+c) + d and e + f*log(g + h*x).

A structural strategy that handles logs, exps and friends

The following approach avoids polynomials entirely. It normalizes expressions relative to the chosen variables, collapses subexpressions that do not depend on those variables, and then checks that both sides have the same abstract syntax tree shape. It also accounts for commutativity by sorting arguments in a consistent way during comparison. This method passed the author's original test suite, including the logarithmic and exponential cases; one case was flagged as tricky and may behave inconsistently in practice.

import sympy
# Create a fresh placeholder symbol each time
def fresh_atom():
    return sympy.Dummy()
# Normalize an expression with respect to a set of pivot symbols.
# Children are normalized first. Then, inside every Add/Mul that contains a
# pivot-free, non-numeric term, all pivot-free terms (numbers included) are
# dropped and a single fresh placeholder takes their place.
def collapse_for_vars(expr, pivots):
    expr = sympy.collect(expr, extract_terms(expr, pivots))
    for arg in expr.args:
        expr = expr.subs(arg, collapse_for_vars(arg, pivots))
    expr = sympy.simplify(expr)
    if expr.is_Add or expr.is_Mul:
        # has(*pivots) is True when a term contains any pivot symbol.
        with_syms = [arg for arg in expr.args if arg.has(*pivots)]
        no_syms = [arg for arg in expr.args
                   if not arg.is_Number and not arg.has(*pivots)]
        if no_syms:
            # Rebuild the node from its parts instead of calling subs():
            # expr.subs(2, 0) would also rewrite every other 2 in the tree,
            # turning x**2 into 1 and 2*x into 0.
            expr = sympy.simplify(expr.func(*with_syms, fresh_atom()))
    return expr
# Extract a sortable key for subexpressions so commutative arguments can be
# compared irrespective of order. Numbers and non-pivot symbols are stripped
# from products, so a*log(x) and 2*log(x) share the key log(x).
def extract_terms(expr, pivots):
    if expr.is_Add:
        return [extract_terms(a, pivots)[0] for a in expr.args]
    if expr.is_Mul:
        # Rebuild the product from the remaining factors instead of calling
        # subs(n, 1), which would also rewrite numbers inside other factors
        # (2*x**2 would become x).
        kept = [arg for arg in expr.args
                if not arg.is_Number and not (arg.is_Symbol and arg not in pivots)]
        return [sympy.Mul(*kept)]
    return [expr]
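# Sort key for the arguments of a commutative node: the extract_terms key with
# every non-pivot symbol renamed to one shared name, so the order does not
# depend on the arbitrary names of the Dummy placeholders.
def shape_key(expr, pivots):
    key = extract_terms(expr, pivots)[0]
    hide = {s: sympy.Symbol('_p') for s in key.free_symbols if s not in pivots}
    return str(key.xreplace(hide))
# Structural equality with respect to pivots.
def same_shape(a, b, pivots):
    # Pivot variables must match exactly on both sides.
    if a in pivots or b in pivots:
        return a == b
    # Leaves: a number or a non-pivot symbol (parameter or placeholder) may
    # stand in for any other leaf, but two numbers must be equal.
    a_leaf = a.is_Symbol or a.is_Number
    b_leaf = b.is_Symbol or b.is_Number
    if a_leaf or b_leaf:
        if a.is_Number and b.is_Number:
            return a == b
        return a_leaf and b_leaf
    # Inner nodes need the same head and the same number of arguments.
    if a.func != b.func or len(a.args) != len(b.args):
        return False
    a_args, b_args = a.args, b.args
    # Only Add and Mul are commutative; elsewhere argument order matters
    # (base versus exponent of a Pow, for example).
    if a.is_Add or a.is_Mul:
        a_args = sorted(a_args, key=lambda arg: shape_key(arg, pivots))
        b_args = sorted(b_args, key=lambda arg: shape_key(arg, pivots))
    return all(same_shape(u, v, pivots) for u, v in zip(a_args, b_args))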
# Public API: normalize both sides with respect to pivots and compare shapes.
def shape_equal(lhs, rhs, pivots=()):
    lhs = sympy.expand(sympy.sympify(lhs))
    lhs = collapse_for_vars(lhs, pivots)
    lhs = sympy.expand(lhs)
    rhs = sympy.expand(sympy.sympify(rhs))
    rhs = collapse_for_vars(rhs, pivots)
    rhs = sympy.expand(rhs)
    return same_shape(lhs, rhs, pivots)
# Example: works for logarithmic structure relative to x
a, b, c, d, e, f, g, h, x = sympy.symbols('a b c d e f g h x')
expr_left = sympy.sympify("a*log(b*x+c)+d")
expr_right = sympy.sympify("e+log(g+x*h)*f")
shape_equal(expr_left, expr_right, [x])

What changed and how it helps

The structural method first expands and simplifies each side. Inside each Add and Mul it separates numbers, parts that depend on the chosen variables, and parts that do not. If a variable-free, non-numeric part is present, all variable-free parts (numbers included) are replaced by a single placeholder, so the parameters cannot affect the structure. Arguments of commutative nodes are then compared after being sorted by a derived key, which avoids order-based mismatches. Finally, the algorithm recurses down both expression trees and requires the designated variables to match exactly. This makes it possible to declare a*log(b*x+c)+d and e + f*log(g + h*x) equivalent with respect to x, while still distinguishing expressions whose dependence on the variables differs.
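The collapsing and comparison steps can also be written more compactly with SymPy's Expr.as_independent, which splits an Add into (variable-free sum, rest) and a Mul into (variable-free product, rest). The sketch below swaps that method in for the hand-written classification; the function names are hypothetical, it skips the collect and simplify passes, and it is an illustration under those assumptions rather than a drop-in replacement.

```python
import sympy

def collapse(expr, pivots):
    """Replace the pivot-free part of every Add/Mul with one fresh placeholder."""
    if not expr.args:  # numbers and symbols are leaves
        return expr
    expr = expr.func(*[collapse(arg, pivots) for arg in expr.args])
    if expr.is_Add or expr.is_Mul:
        indep, dep = expr.as_independent(*pivots)
        # Plain numbers alone are kept; a part involving any parameter is
        # folded, together with those numbers, into a single Dummy.
        if indep.free_symbols:
            expr = expr.func(dep, sympy.Dummy())
    return expr

def term_key(term, pivots):
    """Sort key: the pivot-dependent factor, with parameter names hidden."""
    dep = term.as_independent(*pivots, as_Add=False)[1]
    hide = {s: sympy.Symbol('_p') for s in dep.free_symbols if s not in pivots}
    return str(dep.xreplace(hide))

def match_shape(a, b, pivots):
    if a in pivots or b in pivots:  # pivots must match exactly
        return a == b
    a_leaf, b_leaf = a.is_Symbol or a.is_Number, b.is_Symbol or b.is_Number
    if a_leaf or b_leaf:  # parameters match any leaf; numbers must be equal
        if a.is_Number and b.is_Number:
            return a == b
        return a_leaf and b_leaf
    if a.func != b.func or len(a.args) != len(b.args):
        return False
    a_args, b_args = a.args, b.args
    if a.is_Add or a.is_Mul:  # only these nodes are commutative
        a_args = sorted(a_args, key=lambda t: term_key(t, pivots))
        b_args = sorted(b_args, key=lambda t: term_key(t, pivots))
    return all(match_shape(u, v, pivots) for u, v in zip(a_args, b_args))

def compact_shape_equal(lhs, rhs, pivots):
    lhs = collapse(sympy.expand(sympy.sympify(lhs)), pivots)
    rhs = collapse(sympy.expand(sympy.sympify(rhs)), pivots)
    return match_shape(lhs, rhs, pivots)

x = sympy.Symbol('x')
print(compact_shape_equal("a*log(b*x+c)+d", "e+log(g+x*h)*f", [x]))  # True
print(compact_shape_equal("a*x+b", "c*x**2+d", [x]))                  # False
```

Because the placeholder is attached by rebuilding the node rather than by substitution, numbers elsewhere in the tree are left alone, so x**2 + 2 + a and x**2 + b still compare equal with respect to x.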

On the extended test set, all checks passed. One case is explicitly called out as tricky and may sometimes fail and sometimes pass; that instability should be kept in mind when validating more exotic compositions.
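Which case is unstable is not stated. One plausible contributor, offered here as an assumption rather than a confirmed diagnosis, is that Dummy placeholders print with their names (SymPy prints Dummy('p9') as _p9), so a str()-based sort key can order otherwise identical terms differently depending on how the placeholders happened to be numbered. The sketch uses explicitly named Dummies to make this reproducible; hidden_key is a hypothetical helper that renames every non-pivot symbol to one shared name before printing.

```python
import sympy

x = sympy.Symbol('x')

def hidden_key(expr, pivots):
    """Hypothetical helper: print expr with all non-pivot symbols renamed."""
    hide = {s: sympy.Symbol('_p') for s in expr.free_symbols if s not in pivots}
    return str(expr.xreplace(hide))

# Explicit names stand in for SymPy's internal Dummy counter.
p9, p10 = sympy.Dummy('p9'), sympy.Dummy('p10')
t1 = sympy.log(p9 * x)
t2 = sympy.log(p10 * x)

print(str(t1) == str(t2))                          # False: names leak into the key
print(str(p10) < str(p9))                          # True: "_p10" sorts before "_p9"
print(hidden_key(t1, [x]) == hidden_key(t2, [x]))  # True: the names are hidden
```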

Why you want this in your toolbox

Real-world symbolic workloads rarely stay within the comfort zone of polynomials. Once logs, exponentials or other composite functions appear, you need a comparison strategy that respects structure rather than trying to coerce everything into a polynomial. The normalization plus AST-shape comparison gives you a way to reason about families of expressions with symbolic parameters without committing to specific coefficients. It also keeps the control surface small: the meaning of “equivalence” is explicit through the set of variables you care about.

Practical takeaways

Use polynomial comparison only for actual polynomials; otherwise expect PolynomialError when non-polynomial constructs are present. If you need equivalence up to symbolic parameters, normalize expressions with respect to the variables of interest and compare their tree shape. When writing checks, remember that assert is a statement, not a function, and prefer x is None to x == None to test for sentinel values.
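A small illustration of those habits, built around a hypothetical helper as_poly_or_none that returns None instead of raising when its input is not polynomial in the given generators:

```python
import sympy
from sympy.polys.polyerrors import PolynomialError

def as_poly_or_none(expr, *gens):
    """Hypothetical helper: a Poly, or None if expr is not polynomial in gens."""
    try:
        return sympy.Poly(expr, *gens)
    except PolynomialError:
        return None

a, b, x = sympy.symbols('a b x')
# Test the sentinel with `is None`, and write assert as a statement:
# assert(cond, "msg") asserts a non-empty tuple and therefore always passes.
assert as_poly_or_none(a*x + b, x) is not None, "a*x + b is polynomial in x"
assert as_poly_or_none(a*sympy.log(x) + b, x) is None, "log(x) is not"
```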

As with any symbolic normalization, keep an eye on edge cases. If you see occasional instability in a difficult case, capture it as a dedicated test and investigate how the transformation pipeline (expand, collect, simplify) interacts with that syntax.
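When chasing such a case, it can help to record the expression after every stage and see which step changes its shape. A minimal sketch with a hypothetical trace_pipeline helper; it calls collect with the pivot symbols directly, which simplifies the extract_terms-based keys used above.

```python
import sympy

def trace_pipeline(expr, pivots):
    """Hypothetical helper: (stage name, expression) after each step."""
    stages = []
    expr = sympy.expand(expr)
    stages.append(('expand', expr))
    expr = sympy.collect(expr, list(pivots))
    stages.append(('collect', expr))
    expr = sympy.simplify(expr)
    stages.append(('simplify', expr))
    return stages

a, b, x = sympy.symbols('a b x')
stages = trace_pipeline(a*(x + 1) + b, [x])
for name, value in stages:
    print(f'{name:>8}: {value}')
```

Running the same trace on both sides of a flaky comparison, and over several runs, shows whether the difference appears before or after normalization.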

The bottom line is simple. Be explicit about what “equivalent” means in your domain, avoid forcing non-polynomial structures into polynomials, and compare the right thing: the structure of dependence on the variables you actually care about.