2025, Oct 03 07:31
JAX JIT cache: scan, pytrees और lambda closures का असर
JAX JIT में scan, pytree और lambda capture का असर समझें: क्या री-कंपाइलेशन ट्रिगर करता है, cache key में static बनाम dynamic की भूमिका, और कब executable का reuse होता है.
JAX का scan, pytrees, lambda closures और JIT caching अक्सर एक ही कोड में साथ आते हैं — और तब सवाल उठता है: क्या lambda के जरिए किसी pytree को कैप्चर करना हर नई instance पर दोबारा कंपाइल करवाएगा? आइए समझते हैं कि JIT री-कंपाइलेशन वास्तव में किससे ट्रिगर होता है, pytrees cache key को कैसे नियंत्रित करते हैं, और क्यों आपका सेटअप तब तक ठीक है जब तक सही हिस्से स्थिर बने रहें।
न्यूनतम उदाहरण जो सवाल उठाता है
यह प्रोग्राम एक इनपुट वेक्टर पर scan चलाता है। स्टेट अपडेट लॉजिक एक ऐसे pytree ऑब्जेक्ट में पैक है जो एक array और एक scalar दोनों रखता है, और इस ऑब्जेक्ट को एक jitted फ़ंक्शन के अंदर lambda के जरिए क्लोज़ किया गया है।
import jax
import jax.numpy as jnp
from jax import tree_util
class ConfigBox:
    def __init__(self, vec, scale):
        self.vec = vec
        self.scale = scale
    def step_fn(self, st, inp):
        s = st
        y = inp
        nxt = (self.vec + s + jnp.ones(self.scale)) * y
        return nxt
    def _tree_flatten(self):
        leaves = (self.vec,)
        meta = {'scale': self.scale}
        return (leaves, meta)
    @classmethod
    def _tree_unflatten(cls, meta, leaves):
        return cls(*leaves, **meta)
tree_util.register_pytree_node(
    ConfigBox,
    ConfigBox._tree_flatten,
    ConfigBox._tree_unflatten,
)
def step_wrapper(cfg, st, inp):
    s = st
    y = inp
    s_next = cfg.step_fn(s, y)
    return s_next, [s_next]
@jax.jit
def run_scan_jit(cfg):
    body = lambda st, inp: step_wrapper(cfg, st, inp)
    st0 = jnp.array([0., 1.])
    seq = jnp.array([1., 2., 3.])
    last, out = jax.lax.scan(body, st0, seq)
    return last, out
if __name__ == "__main__":
    cfg1 = ConfigBox(jnp.array([1., 2.]), 2)
    last_state, outputs = run_scan_jit(cfg1)
    print(last_state)
    cfg2 = ConfigBox(jnp.array([3., 4.]), 2)
    last_state, outputs = run_scan_jit(cfg2)
    print(last_state)
वास्तव में री-कंपाइलेशन किससे होता है
JAX में pytrees के लिए JIT cache hits इस बात पर निर्भर करते हैं कि वे कैसे flatten होते हैं। Flattening के सहायक (auxiliary) हिस्से में लौटाए गए मान static माने जाते हैं। जो मान बच्चों (children) के रूप में लौटते हैं, वे dynamic माने जाते हैं। Static प्रविष्टियाँ अपने Python मान के जरिए cache key को प्रभावित करती हैं। Dynamic प्रविष्टियाँ अपने array के ढांचे (shape, dtype, sharding) के जरिए cache key को प्रभावित करती हैं, लेकिन उनकी सामग्री नहीं करती।
उपरोक्त उदाहरण पर इसे लागू करें तो meta में scale है और वह static है। scale का मान बदलने से री-कंपाइलेशन ट्रिगर होगा। leaves में vec है और वह dynamic है। vec के मान बदलने से री-कंपाइल नहीं होगा, लेकिन उसके shape, dtype या sharding बदलने से होगा। दिखाए गए दोनों कॉल्स में vec की shape, dtype और sharding एक जैसी है, और scale बदला नहीं है; इसलिए री-कंपाइलेशन नहीं होता।
एक और सामान्य चिंता यह होती है कि cfg को bind करने के लिए इस्तेमाल किया गया lambda क्या JIT cache key को प्रभावित करता है। इस सेटअप में ऐसा नहीं है, क्योंकि lambda pytree flattening के आउटपुट का हिस्सा नहीं है और इसलिए key को प्रभावित नहीं करता।
समाधान: static और dynamic हिस्सों को स्पष्ट रखें
दिखाए गए परिदृश्य में री-कंपाइलेशन से बचने के लिए कोई बदलाव आवश्यक नहीं है। जब तक static फ़ील्ड्स के मान एक जैसे रहते हैं और dynamic arrays की shape, dtype और sharding समान रहती हैं, JIT कैश किए गए executable को ही पुन: उपयोग करेगा। आवश्यकता हो तो आप https://stackoverflow.com/a/70127930/2937831 पर संदर्भित तरीके से cache reuse को सत्यापित कर सकते हैं। आप यह भी देख सकते हैं कि इस प्रोग्राम की दोनों कॉल्स के बाद jitted फ़ंक्शन का cache आकार 1 ही रहता है।
उसी कोड में, बिना बदलाव किए, यह मानदंड पूरा होता है:
import jax
import jax.numpy as jnp
from jax import tree_util
class ConfigBox:
    def __init__(self, vec, scale):
        self.vec = vec
        self.scale = scale
    def step_fn(self, st, inp):
        return (self.vec + st + jnp.ones(self.scale)) * inp
    def _tree_flatten(self):
        leaves = (self.vec,)
        meta = {'scale': self.scale}
        return (leaves, meta)
    @classmethod
    def _tree_unflatten(cls, meta, leaves):
        return cls(*leaves, **meta)
tree_util.register_pytree_node(
    ConfigBox,
    ConfigBox._tree_flatten,
    ConfigBox._tree_unflatten,
)
def step_wrapper(cfg, st, inp):
    nxt = cfg.step_fn(st, inp)
    return nxt, [nxt]
@jax.jit
def run_scan_jit(cfg):
    body = lambda st, inp: step_wrapper(cfg, st, inp)
    st0 = jnp.array([0., 1.])
    seq = jnp.array([1., 2., 3.])
    return jax.lax.scan(body, st0, seq)
यह क्यों महत्वपूर्ण है
किन हिस्सों को static और किन्हें dynamic माना जाए, यह समझना स्थिर और अनुमानित रन व अनपेक्षित री-कंपाइल स्पाइक्स के बीच का अंतर तय करता है। Arrays के मान बदले जा सकते हैं, इससे cache अमान्य नहीं होता; लेकिन shape, dtype या sharding पैटर्न बदलेंगे तो होगा। जो scalar static हिस्से में रखे जाते हैं, उनके मान बदलने पर री-कंपाइलेशन होता है। इस पैटर्न में lambda के जरिए की गई capture caching व्यवहार को नहीं बदलती।
मुख्य निष्कर्ष
यदि आपके JAX प्रोग्राम में कोई pytree शामिल है, तो सोच-समझकर तय करें कि dynamic leaves में क्या जाए और static auxiliary डेटा में क्या। जब cache reuse चाहिए, तो static फ़ील्ड्स के मान स्थिर रखें; और री-कंपाइल से बचना हो, तो dynamic arrays की shape और dtype सुसंगत रखें। इन शर्तों के तहत किसी dynamic मान को lambda से कैप्चर करना ठीक है। संदेह हो तो ऊपर संदर्भित तरीके से cache reuse जाँच लें, या jitted फ़ंक्शन के cache आकार पर नजर डालें; जिन कॉल्स में static डेटा या array संरचना नहीं बदलती, उनमें यह 1 ही रहना चाहिए।
यह लेख StackOverflow पर प्रश्न (लेखक: user1168149) और jakevdp के उत्तर पर आधारित है।