2025, Nov 01 16:01

TorchRL में DQN/RNN: done मास्क, SliceSampler और एपिसोड-सुरक्षित टार्गेट्स

TorchRL में DQN/RNN के लिए done/terminated मास्क, Replay Buffer और SliceSampler कैसे एपिसोड सीमाओं पर बूटस्ट्रैप रोकते हैं और टार्गेट लीकेज से बचाते हैं।

TorchRL में RNNs के साथ DQN एपिसोड और बैचों को कैसे संभालता है, इस पर अक्सर एक व्यावहारिक सवाल उठता है: अगर एक कलेक्टर कई एपिसोड को जोड़कर एक ही बैच बना देता है, तो क्या ट्रेनिंग के दौरान वैल्यू टार्गेट्स एपिसोड की सीमाओं को पार करके ‘लीक’ होते हैं? संक्षिप्त जवाब: नहीं। done/terminated/truncated मार्करों का सही उपयोग किसी भी तरह की क्रॉस-कंटैमिनेशन को रोकता है, और रिप्ले बफर या तो एकल स्टेप्स या ट्राजेक्टरी के स्लाइस सैंपल कर सकता है—बिना असंबंधित डेटा मिलाए।

स्पष्ट दिखने वाली समस्या का न्यूनतम उदाहरण

ऐसा बैच लें जिसमें दो एपिसोड लगातार जोड़ दिए गए हों। अगर पूरे बैच में एपिसोड के अंत का सम्मान किए बिना next-state की वैल्यूज़ को संरेखित कर दिया जाए, तो एपिसोड 1 का आखिरी ट्रांज़िशन गलती से एपिसोड 2 की पहली स्टेट को अपनी next state मान लेगा। नीचे दिया स्निपेट इस गलती और उसे रोकने के लिए सही मास्किंग दोनों को दिखाता है।

import torch
# लंबाई 3 के दो एपिसोड लगातार जुड़े हुए: सूचकांक [0..2] और [3..5]
rew = torch.tensor([0.1, 0.0, 1.0, 0.2, -0.1, 0.3])
done = torch.tensor([0,   0,   1,   0,    0,    1], dtype=torch.float32)
gamma = 0.99
# मान लें कि ये हर स्थिति t के लिए max_a' Q(s_{t+1}, a') हैं
q_next = torch.tensor([0.5, 0.4, 0.9, 0.7, 0.6, 0.3])
# next-state वैल्यूज़ को वर्तमान ट्रांज़िशन के साथ left shift द्वारा संरेखित करें
q_next_shift = torch.roll(q_next, shifts=-1)
q_next_shift[-1] = 0.0  # बैच के अंतिम एलिमेंट के लिए पैडिंग
# भोला-सा target एपिसोड की सीमाओं को नज़रअंदाज़ करता है (गलत)
target_naive = rew + gamma * q_next_shift
# सही target टर्मिनल्स पर मास्क लगाता है: done के बाद कोई बूटस्ट्रैप नहीं
# क्रॉस-कंटैमिनेशन से बचने की यही कुंजी है
target_masked = rew + gamma * (1.0 - done) * q_next_shift
print("target_naive:", target_naive)
print("target_masked:", target_masked)

भोली-भाली गणना टर्मिनल स्टेप पर भी अगली एंट्री उठा लेती है, जिससे दो असंबंधित एपिसोड अनजाने में आपस में जुड़ जाते हैं। मास्क की गई गणना टर्मिनल पर बूटस्ट्रैपिंग को शून्य कर देती है, इसलिए टार्गेट्स बिल्कुल done/terminated/truncated स्टेप्स पर ही रुकते हैं।

वास्तव में TorchRL के कलेक्टर्स और लॉसेज़ क्या करते हैं

कलेक्टर्स ऐसे बैच लौटा सकते हैं जिनमें अलग-अलग ट्राजेक्टरी के हिस्से जुड़े हों। ऐसे बैच को समय-आधारित डेटा समझने वाले ऑब्जेक्टिव्स को देना सुरक्षित है, क्योंकि वे ट्राजेक्टरी अलग रखने और क्रॉस-एपिसोड प्रभाव रोकने के लिए done/terminated/truncated मार्करों पर निर्भर रहते हैं। खास तौर पर, DQNLoss डेटा को रिप्ले बफर में लिखता है, और ट्रेनिंग के दौरान या तो अलग-अलग ट्रांज़िशन या पूरी ट्राजेक्टरी के स्लाइस सैंपल होते हैं। जब ट्राजेक्टरी स्लाइस चाहिए, तो SliceSampler यह सुनिश्चित करता है कि स्लाइस एपिसोड की सीमा के भीतर ही रहें। दोनों ही स्थितियों में कोई क्रॉस-कंटैमिनेशन नहीं होता।

अगर आप टार्गेट्स खुद गणना कर रहे हों तो व्यावहारिक उपाय

अगर आप दिए गए लॉस पर छोड़ने के बजाय बूटस्ट्रैप्ड टार्गेट्स खुद निकालते हैं, तो एपिसोडिक मास्क ज़रूर लगाएँ। यही विचार ऊपर दिखाया गया है और यही वजह है कि done मार्करों का सम्मान हो तो कई एपिसोड को एक ही बैच में जोड़ना ठीक रहता है। नीचे मास्किंग का संक्षिप्त पैटर्न दिया है जो इसी सुरक्षित व्यवहार को दोहराता है।

def safe_dqn_targets(r_t, done_t, q_next_t, gamma):
    # r_t: [T] रिवॉर्ड्स
    # done_t: [T] {0,1} फ्लैग; टर्मिनल स्टेप पर 1
    # q_next_t: [T] संरेखित Q(s_{t+1}), जैसे पिछले उदाहरण में
    # gamma: स्केलर डिस्काउंट
    return r_t + gamma * (1.0 - done_t) * q_next_t
# ऊपर वाले बैच के साथ उदाहरण के तौर पर पुन: उपयोग
targets = safe_dqn_targets(rew, done, q_next_shift, gamma)

ट्रांज़िशन बनाम ट्राजेक्टरी स्लाइस का सैंपलिंग

इस सेटअप में रिप्ले बफर या तो एकल ट्रांज़िशन निकाल सकता है या ऐसे निरंतर विंडो जो टर्मिनल को पार नहीं करते। जब पूरी या आंशिक ट्राजेक्टरी चाहिए, तो SliceSampler उपयुक्त है; वह एपिसोड सीमाओं का सम्मान करता है ताकि समय-संबंधी गणनाएँ हर ट्राजेक्टरी के भीतर ही रहें। नीचे दिया गया अवधारणात्मक हेल्पर दिखाता है कि टर्मिनल को पार किए बिना एपिसोड के भीतर विंडो कैसे गिनें।

def windows_within_episodes(done_flags, window):
    idx = 0
    spans = []
    n = len(done_flags)
    while idx < n:
        # एपिसोड सेगमेंट [ep_start, ep_end] खोजें, टर्मिनल सहित
        ep_start = idx
        while idx < n and done_flags[idx].item() == 0:
            idx += 1
        ep_end = idx  # ep_end पर टर्मिनल
        # [ep_start, ep_end] के भीतर पूरी तरह स्थिर आकार की विंडो बनाएँ
        for s in range(ep_start, ep_end + 1):
            e = s + window
            if e - 1 > ep_end:
                break
            spans.append((s, e))
        idx += 1  # टर्मिनल से आगे बढ़ें
    return spans
# उदाहरण: सभी लंबाई-2 विंडो जो कभी एपिसोड सीमाएँ नहीं पार करतीं
spans = windows_within_episodes(done, window=2)
print(spans)

यह अवधारणा उस व्यवहार का प्रतिबिंब है जो एक ट्राजेक्टरी-सचेत सैंपलर करता है। प्रोडक्शन में, आप बिल्ट-इन SliceSampler पर भरोसा करेंगे ताकि रिप्ले बफर सिर्फ एक ही ट्राजेक्टरी से निरंतर समय खंड लौटाए।

TorchRL में संबंधित बिल्डिंग ब्लॉक्स

SliceSampler ट्राजेक्टरी के स्लाइस को सुरक्षित रूप से सैंपल करने के लिए बनाया गया है। GAE जैसे टेम्पोरल ऑब्जेक्टिव दिखाते हैं कि done/terminated/truncated मार्करों का सम्मान करते हुए स्टैक्ड ट्राजेक्टरी पर कैसे काम किया जाए। कुछ LLM कलेक्टर्स भी हैं जो केवल पूर्ण ट्राजेक्टरी लौटाते हैं, और इस क्षमता को अन्य कलेक्टर्स तक सामान्यीकृत किया जा सकता है।

यह क्यों महत्वपूर्ण है

रिइनफोर्समेंट लर्निंग में RNNs या किसी भी समय-सचेत मॉडल का उपयोग करते समय एपिसोड सीमाओं की स्पष्टता बेहद जरूरी है। टर्मिनल के पास गलत-संरेखित टार्गेट्स और एपिसोड को पार करती विंडो सीखने की गुणवत्ता को चुपचाप गिरा सकती हैं। सही मास्किंग और सीमा-सचेत सैंपलिंग स्टैक्ड बैचों को कुशल बनाए रखते हुए हर ट्राजेक्टरी की अखंडता सुरक्षित रखती है।

मुख्य बातें

कलेक्टर्स से आए जोड़कर बनाए गए बैच TorchRL के ऑब्जेक्टिव्स में सुरक्षित रूप से दिए जा सकते हैं, क्योंकि done/terminated/truncated मार्कर क्रॉस-एपिसोड लीकेज रोकते हैं। DQNLoss वर्कफ़्लो डेटा को रिप्ले बफर में लिखता है और फिर या तो अलग-अलग स्टेप्स या एपिसोड के भीतर के स्लाइस सैंपल करता है; SliceSampler के साथ आप एक ही ट्राजेक्टरी के भीतर रहते हैं। अगर आप कोई भी बूटस्ट्रैप्ड टार्गेट हाथ से बनाते हैं, तो एपिसोडिक मास्क लगाएँ ताकि कोई भी गणना टर्मिनल स्टेप्स को पार न करे। इतना पर्याप्त है कि आप बड़े बैच वाली ट्रेनिंग की दक्षता का लाभ उठा सकें—बिना ट्राजेक्टरी मिलाए।

यह लेख StackOverflow के प्रश्न (लेखक: Ícaro Lorran) और vmoens के उत्तर पर आधारित है।