2025, Dec 25 11:00

Build a Numerically Stable Non-Central Chi-Squared Density in PyTorch with a Log-Domain Series

Learn to compute a numerically stable non-central chi-squared pdf in PyTorch using a log-space series, then check it against a histogram of simulated samples.

Building a numerically stable non-central chi-squared density in PyTorch is tricky if you evaluate the closed-form pdf directly with multiplications and divisions: the straightforward route quickly runs into underflow and overflow. The goal is to compute the density in a way that works with torch tensors, including broadcasting over a grid of evaluation points, and stays stable across reasonable parameter ranges.

Problem setup

The target is the non-central chi-squared distribution with k degrees of freedom and non-centrality parameter lambda. A natural way to validate the density is to simulate a sum of k squared, unit-variance Normal random variables with means mu_i (so that lambda is the sum of the mu_i squared) and compare a normalized histogram of those sums with the computed pdf.
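For reference, the standard series form of the density, which the code in this article implements, is written out below (it assumes lambda > 0). The last two lines are the per-term log-domain version and the sum that recovers the density.

```latex
\begin{aligned}
Q &= \sum_{i=1}^{k} X_i^2, \qquad X_i \sim \mathcal{N}(\mu_i, 1), \qquad \lambda = \sum_{i=1}^{k} \mu_i^2, \\
f(x; k, \lambda) &= \tfrac{1}{2}\, e^{-(x+\lambda)/2} \left(\frac{x}{\lambda}\right)^{k/4 - 1/2} I_{\nu}\!\left(\sqrt{\lambda x}\right), \qquad \nu = \frac{k}{2} - 1, \\
I_\nu(z) &= \sum_{j=0}^{\infty} \frac{(z/2)^{2j+\nu}}{j!\,\Gamma(j+\nu+1)}, \\
\log f_j(x) &= \log\tfrac{1}{2} - \frac{x+\lambda}{2} + \left(\frac{k}{4} - \frac{1}{2}\right)(\log x - \log\lambda) + (2j+\nu)\log\frac{\sqrt{\lambda x}}{2} - \log\Gamma(j+1) - \log\Gamma(j+\nu+1), \\
f(x; k, \lambda) &= \sum_{j=0}^{\infty} \exp\bigl(\log f_j(x)\bigr).
\end{aligned}
```

In code, the final sum is truncated at a finite j.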

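import torch
import matplotlib.pyplot as plt
plt.style.use('dark_background')
# k = 5 unit-variance Normals with random means in [0, 1)
shift = torch.rand(5)
gauss = torch.distributions.Normal(shift, 1)
n_samples = 5000
X = gauss.sample((n_samples,))   # shape (n_samples, 5)
Q = (X**2).sum(-1)               # non-central chi-squared samples
# evaluation grid for the pdf (used in the next step)
grid = torch.linspace(0.1, Q.max().item() + 10, 100)
plt.hist(Q, bins=int(n_samples**0.5), density=True)
plt.show()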
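Before comparing against the pdf, the simulated sums can be sanity-checked against two known facts about the non-central chi-squared distribution: its mean is k + lambda and its variance is 2(k + 2 lambda). The sketch below is an extra check, not part of the original recipe; the seed and the larger sample size are arbitrary choices.

```python
import torch

# Sanity check on simulated sums of squared Normals (illustrative sketch):
# for Q = sum_i X_i^2 with X_i ~ N(mu_i, 1), E[Q] = k + lambda and
# Var[Q] = 2 * (k + 2 * lambda), where lambda = sum_i mu_i^2.
torch.manual_seed(0)
shift = torch.rand(5)            # means mu_i, so k = 5
lam = (shift**2).sum()
k = shift.numel()

X = torch.distributions.Normal(shift, 1).sample((200_000,))
Q = (X**2).sum(-1)

mean_theory = k + lam
var_theory = 2 * (k + 2 * lam)
print(f"mean: sample {Q.mean():.3f}, theory {mean_theory:.3f}")
print(f"var:  sample {Q.var():.3f}, theory {var_theory:.3f}")
```

If these two numbers disagree noticeably, the problem is in the simulation rather than in the density code.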

What makes it fail in practice

The direct pdf formula multiplies very large and very small numbers. The factor exp(-(x + lambda)/2) becomes tiny (and eventually underflows) as x + lambda grows, while the modified Bessel function of the first kind, I_nu(sqrt(lambda x)), grows roughly like exp(sqrt(lambda x)). On top of that, the Bessel function is itself an infinite series whose individual terms involve large powers and factorials. That combination is numerically fragile. The key to making it work in torch is to perform as much of the computation as possible in log-space and exponentiate only at the very end. The series is infinite in principle, so in practice it is truncated after a finite number of terms.
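To make the fragility concrete, the sketch below evaluates a single Bessel-series term, (z/2)^(2j+nu) / (j! Gamma(j+nu+1)), once directly and once in log space, using the default float32 dtype. The values z = 60, nu = 1.5 (that is, k = 5) and j = 30 are illustrative choices, not taken from the original.

```python
import torch

# One term of the series for I_nu(z), evaluated two ways in float32.
z = torch.tensor(60.0)    # e.g. sqrt(lambda * x) with lambda * x = 3600
nu = torch.tensor(1.5)    # nu = k/2 - 1 for k = 5
j = torch.tensor(30.0)    # series index near the largest term (about z/2)

# Direct: numerator and denominator both exceed float32's max (~3.4e38)
numer = (z / 2) ** (2 * j + nu)
denom = torch.exp(torch.lgamma(j + 1)) * torch.exp(torch.lgamma(j + nu + 1))
direct = numer / denom

# Log space: powers become multiples of a log, factorials become lgamma,
# and a single exp at the end gives a finite value
log_term = (2 * j + nu) * torch.log(z / 2) - torch.lgamma(j + 1) - torch.lgamma(j + nu + 1)

print(numer, denom, direct)     # tensor(inf) tensor(inf) tensor(nan)
print(log_term, log_term.exp())
```

The direct route turns a perfectly representable number into inf / inf = nan, while the log-space route never forms the huge intermediates.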

Solution: stable log-domain computation

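The approach below writes the density in terms of the series for the modified Bessel function of the first kind, as given in standard references for the non-central chi-squared distribution, and replaces products with sums in the log domain. The series is truncated after 100 terms (j = 0, ..., 99). Broadcasting is handled with unsqueeze, which puts the series index in a trailing dimension so that a grid of evaluation points and the series terms line up.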
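The broadcasting pattern can be seen in isolation on a toy series whose truncated sum has a closed form. The sketch below uses the geometric series sum_j exp(-j x) as a stand-in (an illustrative choice, not the chi-squared terms): the points get a trailing axis, the series index is 1-D, and reducing over the last axis sums the series per point.

```python
import torch

# Toy series with the same shape logic as the ncx2 code:
# points x of shape (N,), series index j of shape (n_terms,).
N, n_terms = 100, 100
x = torch.linspace(0.1, 40.0, N, dtype=torch.float64)    # shape (N,)
j = torch.arange(n_terms, dtype=x.dtype)                  # shape (n_terms,)

# log of the j-th term exp(-j * x), one entry per (point, j) pair
log_terms = -j * x.unsqueeze(-1)       # (N, 1) * (n_terms,) -> (N, n_terms)
series = log_terms.exp().sum(-1)       # reduce over the series axis -> (N,)

# closed form of the truncated geometric series, for comparison
closed = (1 - torch.exp(-n_terms * x)) / (1 - torch.exp(-x))
print(log_terms.shape, series.shape)   # torch.Size([100, 100]) torch.Size([100])
print(torch.allclose(series, closed))
```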

import math
import torch
import matplotlib.pyplot as plt
plt.style.use('dark_background')

def _to_float_tensor(x):
    # wrap Python numbers (and integer tensors) as floating-point tensors
    x = torch.as_tensor(x)
    return x if x.is_floating_point() else x.to(torch.get_default_dtype())

# Log of each term of the series for the modified Bessel function of the first kind:
# I_nu(z) = sum_j (z/2)^(2j + nu) / (j! * Gamma(j + nu + 1)).
# The sum should go to infinity; we keep n_terms terms (j = 0, ..., n_terms - 1).
def log_bessel_terms(nu, z, n_terms=100):
    z = _to_float_tensor(z)
    nu = _to_float_tensor(nu)
    j = torch.arange(n_terms, dtype=z.dtype)      # float, so lgamma gets a float input
    log_half_z = (0.5 * z).log().unsqueeze(-1)    # trailing axis for the series index
    denom = torch.lgamma(j + nu + 1) + torch.lgamma(j + 1)
    return (2 * j + nu) * log_half_z - denom      # shape (..., n_terms)

def ncx2_pdf(u, lam, df, n_terms=100):
    u, lam, df = (_to_float_tensor(t) for t in (u, lam, df))
    log_terms = log_bessel_terms(0.5 * df - 1, (lam * u).sqrt(), n_terms)
    u = u.unsqueeze(-1)   # out-of-place, so the caller's tensor keeps its shape
    log_prefactor = (
        math.log(0.5)
        - 0.5 * (u + lam)
        + (u.log() - lam.log()) * (0.25 * df - 0.5)
    )
    # exponentiate each full log term, then sum over the series dimension
    return (log_prefactor + log_terms).exp().sum(-1)
# simulate a non-central chi-squared via a sum of squared unit-variance Normals
shift = torch.rand(5)               # means mu_i, so df = 5
gauss = torch.distributions.Normal(shift, 1)
lam = (shift**2).sum()              # non-centrality: lambda = sum of mu_i^2
n_samples = 5000
X = gauss.sample((n_samples,))
Q = (X**2).sum(-1)
# evaluate the pdf on a grid and compare it to the histogram
grid = torch.linspace(0.1, Q.max().item() + 10, 100)
pdf = ncx2_pdf(grid, lam, len(shift))
plt.title(f"df={len(shift)}, lambda={lam:0.2f}")
plt.hist(Q, bins=int(n_samples**0.5), density=True)
plt.plot(grid, pdf)
plt.show()

Why this works

The stability comes from turning the multiplicative structure into additive log terms. Each term of the Bessel series is built in log-space: torch.lgamma supplies log j! and log Gamma(j + nu + 1), and the powers of z/2 become multiples of log(z/2). Only after the full log-density of each series term has been assembled do we exponentiate and sum along the series dimension. That last step cannot overflow, because every term is positive and therefore no larger than the density itself. Truncating at 100 terms is adequate when sqrt(lambda u) is moderate: the largest terms sit near j ≈ sqrt(lambda u)/2, so much larger arguments need more terms.
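Exponentiating before summing does have one limit: far in the tail the density itself underflows to zero, so its log becomes -inf. When you need the log-density, for example in a likelihood, a natural variant is to combine the per-term logs with torch.logsumexp instead of .exp().sum(-1). The sketch below is that swapped-in variant, not the original code; the function name ncx2_log_pdf, the float64 dtype, and the test values are illustrative assumptions. It also checks that the density integrates to about 1 and has mean about df + lambda.

```python
import math
import torch

def ncx2_log_pdf(u, lam, df, n_terms=100):
    """Log-density of the non-central chi-squared distribution (assumes lam > 0, u > 0).

    Same truncated Bessel series as ncx2_pdf, but the per-term logs are combined
    with torch.logsumexp, so the result stays finite where the density underflows.
    """
    u, lam, df = (torch.as_tensor(t, dtype=torch.float64) for t in (u, lam, df))
    nu = 0.5 * df - 1
    j = torch.arange(n_terms, dtype=u.dtype)    # series index, shape (n_terms,)
    uu = u.unsqueeze(-1)                        # shape (..., 1)
    log_terms = (
        math.log(0.5)
        - 0.5 * (uu + lam)
        + (uu.log() - lam.log()) * (0.25 * df - 0.5)
        + (2 * j + nu) * (0.5 * (lam * uu).sqrt()).log()
        - torch.lgamma(j + 1)
        - torch.lgamma(j + nu + 1)
    )                                           # shape (..., n_terms)
    return torch.logsumexp(log_terms, dim=-1)

lam, df = 2.0, 5
u = torch.tensor([1.0, 5.0, 2000.0], dtype=torch.float64)
log_p = ncx2_log_pdf(u, lam, df)
print(log_p)          # all entries finite
print(log_p.exp())    # the last entry underflows to 0 in float64

# normalization and mean check on a wide grid (trapezoidal rule)
grid = torch.linspace(1e-6, 200.0, 20_001, dtype=torch.float64)
density = ncx2_log_pdf(grid, lam, df).exp()
print(torch.trapezoid(density, grid))          # should be close to 1
print(torch.trapezoid(grid * density, grid))   # should be close to df + lam = 7
```

The pdf is then just ncx2_log_pdf(...).exp(), and the log form can feed a loss directly.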

Why you want this in your toolbox

If you are fitting models or running simulations that rely on non-central chi-squared densities, you need a routine that behaves well for a range of parameters and integrates smoothly with torch tensors. A log-domain implementation reduces numerical issues and lets you overlay densities on simulated data without fighting instability.
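As one illustration of the fitting use case, the sketch below estimates lambda by maximum likelihood on simulated data, letting autograd differentiate through the log-domain series. It redefines a compact logsumexp log-density (a hypothetical helper, redefined so the block runs on its own). The log-parameterization of lambda, the Adam optimizer, the learning rate, and the step count are arbitrary choices for the example.

```python
import math
import torch

def ncx2_log_pdf(u, lam, df, n_terms=100):
    # truncated Bessel-series log-density combined with logsumexp (lam > 0, u > 0)
    nu = 0.5 * df - 1
    j = torch.arange(n_terms, dtype=u.dtype)
    uu = u.unsqueeze(-1)
    log_terms = (
        math.log(0.5)
        - 0.5 * (uu + lam)
        + (uu.log() - lam.log()) * (0.25 * df - 0.5)
        + (2 * j + nu) * (0.5 * (lam * uu).sqrt()).log()
        - torch.lgamma(j + 1)
        - torch.lgamma(j + nu + 1)
    )
    return torch.logsumexp(log_terms, dim=-1)

torch.manual_seed(0)
df, lam_true = 5, 2.0
# equal means chosen so that sum(mu_i^2) == lam_true
mu = torch.full((df,), math.sqrt(lam_true / df), dtype=torch.float64)
Q = (torch.distributions.Normal(mu, 1.0).sample((5000,)) ** 2).sum(-1)

# optimize log(lambda) so that lambda stays positive
log_lam = torch.zeros((), dtype=torch.float64, requires_grad=True)
opt = torch.optim.Adam([log_lam], lr=0.05)
for _ in range(300):
    opt.zero_grad()
    nll = -ncx2_log_pdf(Q, log_lam.exp(), df).mean()   # mean negative log-likelihood
    nll.backward()
    opt.step()

lam_hat = log_lam.exp().item()
print(f"true lambda = {lam_true}, estimate = {lam_hat:.3f}")
```

Working with the log-density keeps the loss finite even for samples in the tail, which is where a pdf that is exponentiated first would contribute log(0).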

Takeaways

Use a series-based expression for the non-central chi-squared density, keep intermediate computations in log-space, and exponentiate only right before summing over the series terms. To validate, simulate unit-variance Normals with non-zero means, form the sum of squares, and compare the pdf evaluated on a grid with the normalized histogram. The series is infinite in theory; in code it is truncated after a finite number of terms to get a usable approximation.