2025, Oct 02 21:34
JAX में lax.scan को पैरामीटराइज़ करना: lambda/def क्लोज़र बनाम carry
JAX में lax.scan के तीन पैटर्न—lambda/def क्लोज़र और carry—कम्पाइल व प्रदर्शन पर लगभग समान हैं। अंतर समझें, पठनीयता के आधार पर सही तरीका चुनें. उदाहरण सहित
jax.lax.scan लूप को पैरामीटराइज़ करते समय अक्सर वही सवाल उठता है: क्या स्थिर मानों को क्लोज़र में कैप्चर करें, या उन्हें स्कैन की state के साथ carry करें? व्यवहार में तीन प्रचलित तरीके—lambda का इस्तेमाल, def से लोकल फ़ंक्शन बनाना, या पैरामीटर्स को carry के रूप में पास करना—JAX में लगभग एक जैसे कोड पथ में कम्पाइल होते हैं। यह समझना कि रनटाइम को ये सब एक-समान क्यों दिखते हैं, आपको काल्पनिक परफ़ॉर्मेंस ट्रेड‑ऑफ़ के बजाय पठनीयता के आधार पर चुनाव करने में मदद करता है।
तीनों पैटर्न दिखाने वाला पुनरुत्पादनीय सेटअप
import jax
import jax.numpy as jnp
import time
STEPS = 100
def evolve(s, a, dtau, damp):
    v = jnp.array([jnp.cos(s[2]) * a[0], jnp.sin(s[2]) * a[0], a[1]])
    s_next = s + v * dtau - damp * v
    return s_next
def advance(buf, a, dtau, damp):
    s, t = buf
    s_next = evolve(s, a[:, t], dtau, damp)
    return (s_next, t + 1), s_next
# Closure via lambda
def run_with_lambda(s0, a, dtau, damp):
    step_fn = lambda carry, _inp: advance(carry, a, dtau, damp)
    init = (s0, 0)
    return jax.lax.scan(step_fn, init, None, length=STEPS)
# Closure via def
def run_with_def(s0, a, dtau, damp):
    def step_fn(carry, _inp):
        return advance(carry, a, dtau, damp)
    return jax.lax.scan(step_fn, (s0, 0), None, length=STEPS)
# Pass everything as carry
def advance_with_carry(buf, _inp):
    s, t, a, dtau, damp = buf
    s_next = evolve(s, a[:, t], dtau, damp)
    return (s_next, t + 1, a, dtau, damp), s_next
def run_with_carry(s0, a, dtau, damp):
    return jax.lax.scan(advance_with_carry, (s0, 0, a, dtau, damp), None, length=STEPS)
# Demo inputs
s0 = jnp.array([2., 3., 4.])
rng = jax.random.PRNGKey(0)
a = jax.random.uniform(rng, shape=(2, STEPS))
dtau = 0.01
damp = 1e-6
# Optional benchmarking helper
def run_bench(g, label):
    g_jit = jax.jit(g)
    carry, traj = g_jit(s0, a, dtau, damp)
    carry[0].block_until_ready()
    t0 = time.time()
    for _ in range(1000):
        carry, ys = g_jit(s0, a, dtau, damp)
        carry[0].block_until_ready()
    t1 = time.time()
    print(f"{label:10s} | cached run: {t1 - t0:.6f}s")
# run_bench(run_with_lambda, "lambda")
# run_bench(run_with_def,    "def")
# run_bench(run_with_carry,  "carry_all")
JAX में वास्तव में क्या होता है
ये सभी तीनों रूप JAX में मानक माने जाते हैं, और रनटाइम के दृष्टिकोण से किसी एक को दूसरों पर प्राथमिकता देने का कारण नहीं है। lambda से बने अनाम फ़ंक्शन और def से बने नामित फ़ंक्शन को समान रूप से扱ा जाता है; JAX tracing और compilation के दौरान इनके बीच भेद नहीं करता। इसलिए a, dtau और damp को कैप्चर करने वाला क्लोज़र, चाहे आप उसे lambda से लिखें या def से, बराबर ही होता है।
carry‑आधारित संस्करण में अंतर केवल इतना है कि वही मान step फ़ंक्शन तक किस तरह पहुँचते हैं: उन्हें closed‑over constants के रूप में कैप्चर करने के बजाय scan के carry के ज़रिए थ्रेड किया जाता है। इससे scan के lowering में पैरामीटर्स थोड़े अलग दिखते हैं, पर निष्पादन के लिहाज़ से कम्पाइल हुई फ़ंक्शन के व्यवहार में कोई सार्थक फर्क नहीं पड़ता।
यहाँ अक्सर उठने वाला एक शब्दार्थ बिंदु है lambda‑assignment। ऊपर के lambda वाले रूप में वास्तव में एक lambda को किसी नाम को सौंपा गया है। कुछ Python शैली दिशानिर्देश इसे टालने की सलाह देते हैं, लेकिन JAX निष्पादन के संदर्भ में lambda‑assignment कोई समस्या नहीं है; उस शैली का पालन करना या नहीं करना अलग निर्णय है।
व्यवहार में समाधान
चूँकि JAX स्तर पर ये विकल्प समकक्ष हैं, वह रूप चुनें जिससे आपका कोड पढ़ना और संभालना आसान हो। अगर constants को क्लोज़र में लेकर step को लोकल फ़ंक्शन के रूप में लिखना ज़्यादा स्पष्ट लगता है, तो वैसा करें। और यदि आप चाहते हैं कि पैरामीटर्स scan की state में स्पष्ट रूप से दिखें, तो उन्हें carry के रूप में पास करना भी उतना ही ठीक है। नीचे local def का एक संक्षिप्त रूप दिया है:
import jax
import jax.numpy as jnp
STEPS = 100
def evolve(s, a, dtau, damp):
    v = jnp.array([jnp.cos(s[2]) * a[0], jnp.sin(s[2]) * a[0], a[1]])
    return s + v * dtau - damp * v
def advance(buf, a, dtau, damp):
    s, t = buf
    s_next = evolve(s, a[:, t], dtau, damp)
    return (s_next, t + 1), s_next
def rollout_scan(s0, a, dtau, damp):
    def step(carry, _):
        return advance(carry, a, dtau, damp)
    return jax.lax.scan(step, (s0, 0), None, length=STEPS)
यह क्यों मायने रखता है
यह जानना कि lambda, def और carry लगभग एक‑से कम्पाइल्ड कोड तक ले जाते हैं, आपको समय से पहले अनुकूलन करने से बचाता है और ध्यान प्रणाली के डायनेमिक्स पर केंद्रित रखता है। यह भी स्पष्ट हो जाता है कि जो भी अंतर महसूस हों, वे क्लोज़र बनाम carry की पसंद से आने की संभावना कम है, और JAX का tracing अनाम व नामित callables को एक समान मानता है।
मुख्य बातें
अपनी टीम के लिए जो तरीका कोड को सबसे स्पष्ट बनाता है, वही अपनाएँ। JAX की दृष्टि में lambda और def परस्पर विनिमेय हैं, और पैरामीटर्स को क्लोज़र में रखने की जगह carry के रूप में पास करना केवल scan के पैरामीटराइज़ होने का ढंग बदलता है, कम्पाइल्ड फ़ंक्शन का व्यवहार नहीं। संदेह हो तो अपने सेटअप में मापें, पर किसी एक शैली से अंतर्निहित रनटाइम लाभ की उम्मीद न करें।
यह लेख StackOverflow पर question (प्रश्न) user1168149 द्वारा और jakevdp के उत्तर पर आधारित है।