2025, Dec 17 05:00

Understanding def _() in Python: how decorators invoke single-underscore functions in JAX

Learn why def _() runs without being referenced: Python decorators capture and call the function object, enabling conditional execution with JAX cond.

Single underscore names can be baffling at a glance, especially on decorated functions. You see a def _(): and immediately wonder how it is ever invoked. The short answer: it is not meant to be referenced by name at all, yet it still executes because the decorator receives the function object and calls it directly.

Example that raises the question

The pattern below defines a function inside another function and assigns it the name _. The implementation relies on a decorator that decides whether to run the wrapped callable.

import functools
import jax
from jax._src.pallas import helpers as pl_helpers

def execute_on_primary_core(axis_label: str):
  """Runs a function on the first core of the given axis."""
  core_count = jax.lax.axis_size(axis_label)
  if core_count == 1:
    return lambda g: g()

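  def invoker(g):
    idx = jax.lax.axis_index(axis_label)

    @pl_helpers.when(idx == 0)
    @functools.wraps(g)
    def _():  # How is this called?
      # When idx is traced, when() stages this through jax.lax.cond with a
      # no-op branch that returns None, so g should return None as well.
      return g()
    # invoker returns nothing: g is run only for its side effects.

  return invoker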

What is actually going on

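A name that is exactly _ is not the same as a name that merely starts with an underscore, such as _helper, which conventionally marks an internal detail. Using _ as the entire name signals “a name is syntactically required, but it will not be used.” Decorator syntax is what makes this work: writing @dec above def _(): is shorthand for defining _ and then immediately rebinding _ = dec(_). The function object is handed straight to the decorator, and the decorator is responsible for calling it when a condition is met; the name itself never matters.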

The relevant decorator implementation shows the mechanism clearly: it receives the function object, not its name, and calls it conditionally.

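import jax

def when_conditionally(predicate):
  def wrap(fn):
    if isinstance(predicate, bool):
      # Static Python bool: decide now, in plain Python.
      if predicate:
        fn()
    else:
      # Traced predicate: stage both branches; the choice happens at runtime.
      jax.lax.cond(predicate, fn, lambda: None)
    # No return value, so the decorated name is rebound to None.
  return wrap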

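The decorator captures the callable as fn and either decides in plain Python (when predicate is a bool) or routes it through jax.lax.cond (when predicate is a traced JAX value, in which case the branch is selected at runtime). Either way, this happens while the def statement itself executes, and at no point does the decorator look up the function by its identifier. That is why tools that search for direct references to _ will not find a call site: the call happens inside wrap, which receives the function object as its fn argument.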
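The sketch below exercises both paths with the same decorator logic. trace_log, static_case, and traced_case are hypothetical names for illustration. Appending to a Python list inside a traced function only records trace-time activity; that is exactly what the demo wants to expose, but it is not something to rely on in real kernels.

```python
import jax
import jax.numpy as jnp

def when_conditionally(predicate):
  # Same logic as the decorator above.
  def wrap(fn):
    if isinstance(predicate, bool):
      if predicate:
        fn()
    else:
      jax.lax.cond(predicate, fn, lambda: None)
  return wrap

trace_log = []  # Python-side record of when each body was run or traced

def static_case(flag: bool):
  @when_conditionally(flag)
  def _():
    trace_log.append(("static", flag))

@jax.jit
def traced_case(x):
  @when_conditionally(x > 0)  # x > 0 is a tracer here, so cond is used
  def _():
    trace_log.append("traced branch")
  return x

static_case(True)    # bool path: body runs immediately
static_case(False)   # bool path: body never runs
traced_case(jnp.float32(-1.0))  # body is traced even though x <= 0
after_first = len(trace_log)
traced_case(jnp.float32(2.0))   # same shape and dtype: jit reuses the cached trace
```

With a plain bool the body either runs or does not; with a traced predicate the body is always traced once, and the runtime value only picks which compiled branch executes.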

How to reason about the fix

There is nothing to fix in the control flow. The key is understanding that the decorator drives execution. If you prefer to see the same idea with names that emphasize intent, here is an equivalent rewrite with unchanged behavior.

import functools
import jax

# Same semantics as the referenced decorator

def run_if(predicate):
  def attach(callable_obj):
    if isinstance(predicate, bool):
      if predicate:
        callable_obj()
    else:
      jax.lax.cond(predicate, callable_obj, lambda: None)
  return attach

# Same structure as the earlier example, different identifiers only

def apply_on_first(axis_token: str):
  total = jax.lax.axis_size(axis_token)
  if total == 1:
    return lambda op: op()

  def binder(op):
    position = jax.lax.axis_index(axis_token)

    @run_if(position == 0)
    @functools.wraps(op)
    def _():
      return op()

  return binder

Why this matters

Understanding the single-underscore convention keeps you from chasing ghosts in code search or audit workflows. If you are scanning for invocations of _, you will not find them because the decorator already owns the call path. It also tells you when the body runs: at the point where the decorated def appears, not at some later call. With a traced predicate there is one more wrinkle: jax.lax.cond traces both branches, so the body is traced and compiled even if the predicate turns out false at runtime, and any Python side effects inside it happen at trace time.

Takeaways

If you encounter def _(): in modern Python code, read it as “a callable whose identifier is intentionally irrelevant.” Focus on the decorator that receives it. That is where invocation and control-flow decisions are made, often by directly calling the function object or by routing through a conditional primitive. When auditing, trace decorators first, not the textual references to the function’s name.
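As a rough sketch of “trace decorators first”, the standard-library ast module (Python 3.9+ for ast.unparse) can list every def _ together with the decorators that will receive it. find_underscore_defs is a hypothetical helper, not part of any existing tool.

```python
import ast

def find_underscore_defs(source: str):
    """Return (line, [decorator source]) for every function named _."""
    results = []
    for node in ast.walk(ast.parse(source)):
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) and node.name == "_":
            # The decorators, not textual references to _, hold the call path.
            decorators = [ast.unparse(d) for d in node.decorator_list]
            results.append((node.lineno, decorators))
    return results

sample = '''
import functools

def invoker(g, idx):
    @pl_helpers.when(idx == 0)
    @functools.wraps(g)
    def _():
        return g()
'''

print(find_underscore_defs(sample))
```

The source is only parsed, never executed, so undefined names such as pl_helpers are fine; the reported line number is that of the def itself, which is how Python 3.8+ records decorated functions.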