2025, Oct 18 01:46
क्यों vmap में बैच आकार बदलते ही GPU float32 परिणाम बदलते हैं
JAX vmap में बैच आकार बदलने पर GPU पर float32 परिणाम सूक्ष्म रूप से क्यों बदलते हैं, समझें: फ्लोटिंग-पॉइंट राउंडिंग, क्रिया-क्रम, Equinox MLP उदाहरण और व्यावहारिक उपाय.
अलग-अलग बैच आकारों पर किसी न्यूरल नेटवर्क को vmap करते समय हल्का-सा संख्यात्मक विचलन भ्रमित कर सकता है, खासकर तब जब इनपुट बिल्कुल समान हों। यदि vmapped फ़ंक्शन में आप जितनी पंक्तियाँ भेजते हैं, उसके आधार पर आउटपुट की शुरुआती कुछ पंक्तियों में छोटे-छोटे अंतर दिखते हैं, तो भी यह व्यवहार सही हो सकता है। संक्षेप में, फ्लोटिंग‑पॉइंट गणित क्रियाओं के क्रम के प्रति संवेदनशील होता है; क्रम बदलते ही राउंडिंग की प्रोफ़ाइल बदल जाती है। GPU पर float32 में यह आसानी से दिखता है, जबकि CPU पर या float64 में यह अंतर गायब हो सकता है।
समस्या को पुनरुत्पादित करना
नीचे दिया गया स्निपेट एक बैच पर Equinox MLP को दो बार लागू करता है: पहले एक छोटे स्लाइस पर, फिर पूरे एरे पर—और दोनों के परिणामों की शुरुआती पंक्तियों की तुलना करता है। सिर्फ नाम मूल पैटर्न से अलग हैं; प्रोग्राम का व्यवहार समान है।
import jax
import jax.numpy as jnp
import equinox as eqx
def batch_apply(arr_in, net_fn):
    mapped = eqx.filter_vmap(net_fn.__call__)
    return mapped(arr_in)
rng = jax.random.PRNGKey(0)
rng, key_net = jax.random.split(rng, 2)
model = eqx.nn.MLP(2, 2, 10, 2, key=key_net)
rng, key_x = jax.random.split(rng, 2)
xb = jax.random.normal(key_x, (10000, 2))
delta = batch_apply(xb[:10], model) - batch_apply(xb, model)[:10]
print("eqx error:", delta)
अंतर वास्तव में क्यों आता है
यह अपेक्षित है और vmap रूपांतरण के लिए कोई विशेष बात नहीं है। बात फ्लोटिंग‑पॉइंट अंकगणित की है। float32 में हर ऑपरेशन पर थोड़ी‑थोड़ी राउंडिंग त्रुटि जमा होती है। जब आप “एक‑सा” गणित अलग‑अलग गणनात्मक मार्गों से चलाते हैं, तो ये राउंडिंग त्रुटियाँ अलग स्थानों और अलग क्रम में जमा होती हैं, जिससे परिणाम थोड़े खिसक जाते हैं। बैच आकार बदलने से यह तय होता है कि कंप्यूटेशन कैसे शेड्यूल और फ्यूज़ होगा, इसलिए प्रभावी क्रिया‑क्रम भी बदल जाता है।
डिवाइस भी मायने रखता है, क्योंकि ऑपरेशन का अनुक्रम आर्किटेक्चर पर निर्भर करता है। CPU पर निष्पादन आमतौर पर अधिक क्रमिक संचय क्रम का पालन करता है, जो व्यवहार में अलग‑अलग बैच आकारों के बीच भी ऑपरेशनों के क्रम को एक‑सा रख सकता है। GPU पर निष्पादन अत्यधिक समांतर होता है, और समांतर कार्य‑विभाजन व संचय के पैटर्न इनपुट के आकारों पर निर्भर करते हैं। इसी अनुक्रम के फर्क से राउंडिंग अलग होती है, इसलिए float32 में संख्याएँ थोड़ी भिन्न दिखती हैं।
क्या किया जाए
यह व्यवहार स्वयं में सही है। यह अंतर अलग‑अलग क्रिया‑क्रमों के कारण होने वाले वैध फ्लोटिंग‑पॉइंट राउंडिंग भेद से आता है। देखे गए सेटअप में, float64 का उपयोग करने से यह अंतर गायब हो गया, और CPU पर चलाने से भी; दोनों ही मामलों में संख्यात्मक विशेषताएँ या अनुक्रम इतना बदल गया कि इस उदाहरण में दिखने वाला विचलन समाप्त हो गया।
यह क्यों मायने रखता है
float32 में आकार बदलने पर बैच्ड GPU गणनाओं को बिट‑स्तरीय रूप से निर्धारक मानना उचित नहीं है। यदि आप बैच आकारों या डिवाइसों के बीच परिणामों की तुलना करते हैं और पूर्ण समानता की उम्मीद रखते हैं, तो सामान्य फ्लोटिंग‑पॉइंट व्यवहार को आप तर्कगत बग समझ बैठ सकते हैं। यह समझना कि शेड्यूलिंग, बैचिंग और डिवाइस की समानांतरता संचय क्रम को प्रभावित करती है, अपेक्षाएँ और जाँच सही रखने में मदद करता है।
निष्कर्ष
GPU पर float32 में vmapped बैच आकार बदलने पर मानों में छोटे‑छोटे अंतर अलग क्रिया‑क्रमों के साथ जमा हुई फ्लोटिंग‑पॉइंट राउंडिंग का स्वाभाविक परिणाम हैं। गणना फिर भी सही है। यदि आपके उपयोग‑मामले में अधिक कड़े संख्यात्मक मिलान की जरूरत है, तो देखे गए कॉन्फ़िगरेशन में float64 पर स्विच करना या CPU पर मूल्यांकन करना यह अंतर हटाता दिखा। अन्यथा, सूक्ष्म विचलनों को समानांतर फ्लोटिंग‑पॉइंट गणना का सामान्य हिस्सा मानें और समानता का आकलन उसी दृष्टि से करें।