2025, Oct 21 02:31

jax.vmap बनाम ब्रॉडकास्टिंग: बैचिंग के लिए सही तरीका

JAX में jax.vmap कब उपयोग करें? ब्रॉडकास्टिंग से आगे वेक्टराइज़ेशन, jnp.histogram जैसे non-batched ऑपरेशन्स, PyTrees और कन्वोल्यूशन पर व्यावहारिक मार्गदर्शन.

पहली नज़र में JAX और वेक्टराइज़ेशन उलझाऊ लग सकते हैं। कई ऐरे ऑपरेशन्स ब्रॉडकास्टिंग के जरिए पहले से ही बैच इनपुट स्वीकार करते हैं, तो फिर jax.vmap की जरूरत क्यों है? छोटा जवाब: हर ऑपरेशन मूल रूप से वेक्टराइज़्ड नहीं होता, और कई बार स्पष्ट वेक्टराइज़ेशन डाइमेंशनों से छेड़छाड़ करने या Python लूप लिखने की तुलना में ज्यादा साफ, सुरक्षित और तेज़ होता है।

जब “पहले से वेक्टराइज़्ड” होना काफी नहीं होता

डीप लर्निंग कोडबेस में ब्रॉडकास्टिंग और बैच अक्ष आम हैं और अक्सर सामान्य जरूरतें पूरी कर देते हैं। लेकिन JAX में ऐसे ऑपरेशन्स भी हैं जो बैच डाइमेंशन पर स्वाभाविक रूप से वेक्टराइज़्ड नहीं हैं। ऐसे मामलों में या तो आप साफ-साफ लूप लिखते हैं, या jax.vmap से फ़ंक्शन का बैच्ड रूप बनाते हैं। इससे कोड पठनीय रहता है और आम तौर पर Python-स्तरीय लूप्स की तुलना में प्रदर्शन बेहतर होता है।

समस्या का उदाहरण

सोचिए jnp.histogram के बारे में। यह स्वयं बैच अक्ष स्वीकार नहीं करता। सीधी राह है Python में लूप चलाना और नतीजों को हाथ से stack करना।

import jax
import jax.numpy as jnp
# एकल-सैंपल हिस्टोग्राम के लिए सहायक फ़ंक्शन
def single_hist(vec, nbins, span):
    # jnp.histogram (hist, bin_edges) लौटाता है; हम केवल histogram रखते हैं
    counts, _ = jnp.histogram(vec, bins=nbins, range=span)
    return counts
# खिलौना बैच: 2 सैंपल, प्रत्येक में 3 मान
batch_samples = jnp.array([[0.0, 1.0, 2.0],
                           [2.0, 3.0, 1.0]])
num_bins = 4
value_range = (0.0, 4.0)
# Python पक्ष पर बैचिंग (loop + stack)
looped = [single_hist(row, num_bins, value_range) for row in batch_samples]
batched_counts = jnp.stack(looped, axis=0)

यह तरीका काम तो करता है, लेकिन इसमें Python लूप और stacking का अतिरिक्त झंझट आ जाता है। केवल बैच का अनुकरण करने के लिए हाथ से डाइमेंशन जोड़ते-घटाते समय गलती करना भी आसान है।

मूल समस्या क्या है?

उलझन की जड़ यह है कि “वेक्टराइज़्ड” दो अलग मायने रख सकता है: कुछ प्रिमिटिव्स स्वाभाविक रूप से बैच डाइमेंशन संभालते हैं; अन्य डिफॉल्ट रूप से स्केलर या एकल-सैंपल होते हैं और उन्हें बैच करना पड़ता है। jnp.histogram और jnp.bincount इसी दूसरी श्रेणी में आते हैं। ऐसे में jax.vmap साफ तरीके से यह व्यक्त करता है कि “इस फ़ंक्शन को बैच पर स्वतंत्र रूप से लागू करो” — बिना फ़ंक्शन की एकल-सैंपल सेमान्टिक्स बदले और बिना आकृतियों (shapes) को मोड़े-तोड़ें।

स्टाइल का पहलू भी है। कभी-कभी डेवलपर्स लूप से बचने के लिए स्थानीय रूप से एक अतिरिक्त डाइमेंशन जोड़ते हैं और बाद में उसे घटा देते हैं। यह चले तो जाता है, पर jax.vmap से वही इरादा ज़्यादा साफ-साफ बयान होता है।

jax.vmap के साथ समाधान

jax.vmap एकल-सैंपल फ़ंक्शन को चुनिंदा अक्ष पर बैच्ड फ़ंक्शन में बदल देता है। यह PyTrees पर काम करता है, जिससे लाइब्रेरी-स्तरीय कन्वेंशन्स बन पाते हैं जो बैच हैंडलिंग को पूरी तरह छिपा भी सकते हैं। सैंपल-टू-सैंपल स्वतंत्र ऑपरेशन्स के लिए यह स्वाभाविक फिट है।

import jax
import jax.numpy as jnp
# ऊपर जैसा ही एकल-सैंपल फ़ंक्शन
def single_hist(vec, nbins, span):
    counts, _ = jnp.histogram(vec, bins=nbins, range=span)
    return counts
batch_samples = jnp.array([[0.0, 1.0, 2.0],
                           [2.0, 3.0, 1.0]])
num_bins = 4
value_range = (0.0, 4.0)
# वेक्टराइज़्ड रूप: batch_samples के अग्रणी अक्ष पर लागू करें
vmapped_hist = jax.vmap(lambda v: single_hist(v, num_bins, value_range))
counts_batched = vmapped_hist(batch_samples)

vmapped संस्करण वही तर्क बताता है जो लूप वाला रूप करता था, बस अब न तो स्पष्ट लूप है और न मैन्युअल stacking। यह सुविधाजनक है और लूप्स से बचकर पठनीयता के साथ प्रदर्शन भी बेहतर करता है।

बुनियादी बातों से आगे vmap कहाँ चमकता है

vmap PyTrees पर चलता है, इसलिए पूरे पैरामीटर स्ट्रक्चर बिना अक्ष-संबंधी कोड लिखे संभाले जा सकते हैं। कुछ लाइब्रेरी, जैसे equinox, इस कन्वेंशन को अपनाती हैं और पूरे पैरामीटर ट्री पर vmapping को प्रोत्साहित करती हैं। इससे आपके मॉडल के कोड में बैच अक्षों को हाथ से थ्रेड करने की जरूरत नहीं रहती। यह तरीका सैंपल्स के बीच स्वतंत्रता मानता है और उन ऑपरेशन्स के लिए काम नहीं करेगा जो बुनियादी तौर पर सैंपल्स के बीच जानकारी मिलाते हैं, जैसे बैच नॉर्म।

अन्य स्थितियों में, आप किसी ऑपरेशन को चलाने के लिए जबरन ब्रॉडकास्ट कराने हेतु अस्थायी अक्ष जोड़ने, फिर उसे घटाने का सोच सकते हैं। कई बार vmap वही इरादा ज़्यादा सीधे तौर पर जताता है। सहज समझ के लिए, मान लीजिए प्रति सैंपल अलग-अलग kernels के साथ convolution2d लागू करनी है। एक तरीका है kernels को stack करना, चैनल्स को प्रतिलिपि/stack करना, और बढ़ी हुई अक्ष पर एक ही convolution चलाना। दूसरा तरीका है एकल-सैंपल convolution लिखना और फिर kernel या sample अक्ष पर उसे vmap करना। दोनों काम कर सकते हैं; vmap बस प्रति-सैंपल स्वतंत्रता को साफ-साफ व्यक्त करता है।

यह क्यों मायने रखता है

बैचिंग रणनीति सिर्फ शैली नहीं है; यह शुद्धता, पठनीयता और प्रदर्शन को प्रभावित करती है। कब कोई ऑपरेशन मूल रूप से वेक्टराइज़्ड नहीं है, यह जानना मूक shape-बग्स और अनजाने Python लूप्स से बचाता है जो थ्रूपुट सीमित कर देते हैं। जहाँ उपयुक्त हो, vmap इस्तेमाल करने से एकल-सैंपल तर्क जस का तस रहता है, PyTrees के साथ साफ़-सुथरा जुड़ता है, और उन लाइब्रेरीज़ के अनुरूप होता है जो कन्वेंशन से vmap पर निर्भर हैं।

व्यावहारिक मार्गदर्शन

रोज़मर्रा के कोड में ब्रॉडकास्टिंग और बैच अक्ष एक मजबूत डिफॉल्ट हैं। jax.vmap का सहारा तब लें जब किसी फ़ंक्शन में मूल वेक्टराइज़ेशन न हो, जब किसी लाइब्रेरी की डिज़ाइन PyTrees पर vmapping को बढ़ावा दे, या जब गैर-परंपरागत अक्षों पर वेक्टराइज़ेशन चाहिए। यदि आप लूप से बचने के लिए मनमाने डाइमेंशन जोड़-घटा रहे हैं, सोचें कि क्या vmap इरादे को और साफ करेगा और रखरखाव आसान बनाएगा।

कोई सार्वभौमिक नियम नहीं है, और व्यक्तिगत पसंद भी मायने रखती है। मुख्य बात यह पहचानना है: जहाँ मूल बैचिंग उपलब्ध और प्रचलित है, वहीं उसे अपनाएँ; और जहाँ नहीं है, वहाँ स्वतंत्र प्रति-सैंपल गणना व्यक्त करने के लिए vmap का उपयोग करें। इस भेद को याद रखना JAX कोड को सरल और अधिक मज़बूत बनाता है।

यह लेख StackOverflow पर एक प्रश्न (लेखक: Mingruifu Lin) और Axel Donath के उत्तर पर आधारित है।