2025, Oct 31 06:32
DQN डिबगिंग: PyTorch Conv2d के लिए HWC से CHW और in_channels=4
Atari Pong पर DQN न सीख रहा? कारण अक्सर PyTorch Conv2d में चैनल-फर्स्ट टेन्सर लेआउट की गलती है. HWC से CHW permute करें, in_channels=4 सेट करें—पूरी गाइड पढ़ें.
एक खिलौना-से कार्य पर धाक जमाने वाला DQN जब Atari Pong पर बिल्कुल ठहर जाता है, तो यह एक क्लासिक डिबगिंग कहानी बन जाती है: एल्गोरिद्म ठीक है, मेट्रिक्स समझदारी वाले लगते हैं, लॉस शून्य की ओर जाता है, फिर भी स्कोर बढ़ता ही नहीं। इस मामले में दोषी न तो ऑप्टिमाइज़र है और न ही एप्सिलॉन शेड्यूलिंग—मुद्दा टेन्सर लेआउट का है।
क्या गलत हो रहा है
Pong स्टैक से आने वाला observation 84 × 84 × 4 आकार का होता है—पहले ऊँचाई, फिर चौड़ाई, और अंत में चैनल्स। PyTorch का Conv2d चैनल-फर्स्ट फॉर्मेट चाहता है: N × C × H × W। जब इनपुट को ज्यों-का-त्यों भेज दिया जाता है और पहली कॉन्वोल्यूशन में in_channels को 84 सेट किया जाता है, तो नेटवर्क हर इमेज की पंक्ति को अलग-अलग चैनल मान लेता है और चार स्टैक किए गए फ़्रेम्स को चौड़ाई के पिक्सेल की तरह ट्रीट करता है। नतीजा यह कि ट्रेनिंग दुनिया की एक उलझी हुई तस्वीर पर होती है। टारगेट नेटवर्क जल्द ही मुख्य नेटवर्क के आउटपुट की नकल करने लगता है, Huber लॉस शून्य के पास आ टिकता है, और एजेंट फिर भी Pong के बारे में कोई उपयोगी बात नहीं सीख पाता।
समस्या को दोहराने वाला न्यूनतम उदाहरण
समस्या का मूल कारण इनपुट पर ही दिख जाता है। नीचे दिया गया कोड observation टेन्सर के पहले आयाम से पहली कॉन्वोल्यूशन बनाता है और कच्चा H × W × C टेन्सर सीधे Conv2d में भेज देता है।
class PixelAgent(nn.Module):
    def linear_fall(self, step):
        if step >= self.cfg.lr_n_steps:
            return self.cfg.lr_end / self.cfg.lr_begin
        return 1.0 - (step / self.cfg.lr_n_steps) * (1 - self.cfg.lr_end / self.cfg.lr_begin)
    def __init__(self, game_env, cfg, dev):
        super().__init__()
        self.env = game_env
        self.cfg = cfg
        self.dev = dev
        # गलत: चैनल की गिनती के बजाय H का उपयोग
        self.conv_a = nn.Conv2d(
            in_channels=self.env.observation_space.shape[0],  # 84
            out_channels=32,
            kernel_size=8,
            stride=4,
            dtype=torch.float32,
        )
        self.conv_b = nn.Conv2d(32, 64, 4, 2, dtype=torch.float32)
        self.conv_c = nn.Conv2d(64, 64, 3, 1, dtype=torch.float32)
        self.fc_a = nn.Linear(3136, 512, dtype=torch.float32)
        self.fc_b = nn.Linear(512, self.env.action_space.n, dtype=torch.float32)
        self.act = nn.ReLU()
        self.opt = optim.RMSprop(
            self.parameters(),
            lr=self.cfg.lr_begin,
            alpha=self.cfg.squared_gradient_momentum,
            eps=self.cfg.rms_eps,
        )
        self.lr_sched = torch.optim.lr_scheduler.LambdaLR(self.opt, lr_lambda=self.linear_fall)
        self.loss_fn = nn.SmoothL1Loss()
    def forward(self, tensor_in):
        x = self.conv_a(tensor_in)
        x = self.act(x)
        x = self.conv_b(x)
        x = self.act(x)
        x = self.conv_c(x)
        x = self.act(x)
        if len(x.shape) == 3:
            x = torch.flatten(x)
        elif len(x.shape) == 4:
            x = torch.flatten(x, start_dim=1)
        x = self.fc_a(x)
        x = self.act(x)
        x = self.fc_b(x)
        return x
इस सेटअप में मॉनिटरिंग दिखाती है कि अधिकतम/औसत रिवार्ड घटते जा रहे हैं और Q-वैल्यू छोटी हो रही हैं, जबकि लॉस शून्य के आसपास ठहरा रहता है। यह ऐसे टारगेट नेटवर्क के अनुरूप है जो स्थिर नीति सुधार के बजाय एक गलत प्रतिनिधित्व से मेल खा रहा है।
जड़ कारण की व्याख्या
PyTorch में Conv2d की परिभाषा साफ़ है: चैनल का आयाम सबसे पहले होना चाहिए। Atari प्रीप्रोसेसिंग से आने वाला स्टैक्ड observation आमतौर पर height-by-width-by-frames होता है। यदि आप अक्षों को permute नहीं करते, तो कॉन्वोल्यूशन के वज़न गलत आयाम के सापेक्ष सीखे जाते हैं। मॉडल तब भी टेम्पोरल-डिफरेंस त्रुटि घटा सकता है, क्योंकि ऑनलाइन और टारगेट दोनों नेटवर्क एक ही गलती कर रहे हैं, लेकिन सीखी गई फीचर्स Pong में अर्थपूर्ण स्थानिक या समय-गत संरचना से मेल नहीं खातीं।
समाधान
दो बदलाव ज़रूरी हैं। पहला, observation को चैनल-फर्स्ट में permute करें और नेटवर्क में भेजने से पहले उसे [0, 1] में स्केल करें। दूसरा, पहली कॉन्वोल्यूशन को 84 नहीं बल्कि चार चैनल अपेक्षित हों, ऐसा सेट करें।
def to_chw(obs_np: np.ndarray) -> torch.Tensor:
    t = torch.from_numpy(obs_np).float().div_(255.0)
    # (H, W, C) -> (C, H, W)
    if t.ndim == 3:
        t = t.permute(2, 0, 1)
    return t
class PixelAgentFixed(nn.Module):
    def linear_fall(self, step):
        if step >= self.cfg.lr_n_steps:
            return self.cfg.lr_end / self.cfg.lr_begin
        return 1.0 - (step / self.cfg.lr_n_steps) * (1 - self.cfg.lr_end / self.cfg.lr_begin)
    def __init__(self, game_env, cfg, dev):
        super().__init__()
        self.env = game_env
        self.cfg = cfg
        self.dev = dev
        # सही: 4 स्टैक्ड फ्रेम्स चैनल के रूप में
        self.conv_a = nn.Conv2d(in_channels=4, out_channels=32, kernel_size=8, stride=4)
        self.conv_b = nn.Conv2d(32, 64, 4, 2)
        self.conv_c = nn.Conv2d(64, 64, 3, 1)
        self.fc_a = nn.Linear(3136, 512)
        self.fc_b = nn.Linear(512, self.env.action_space.n)
        self.act = nn.ReLU()
        self.opt = optim.RMSprop(
            self.parameters(),
            lr=self.cfg.lr_begin,
            alpha=self.cfg.squared_gradient_momentum,
            eps=self.cfg.rms_eps,
        )
        self.lr_sched = torch.optim.lr_scheduler.LambdaLR(self.opt, lr_lambda=self.linear_fall)
        self.loss_fn = nn.SmoothL1Loss()
    def forward(self, tensor_in):
        x = self.conv_a(tensor_in)
        x = self.act(x)
        x = self.conv_b(x)
        x = self.act(x)
        x = self.conv_c(x)
        x = self.act(x)
        if len(x.shape) == 3:
            x = torch.flatten(x)
        elif len(x.shape) == 4:
            x = torch.flatten(x, start_dim=1)
        x = self.fc_a(x)
        x = self.act(x)
        x = self.fc_b(x)
        return x
# इंटरैक्शन लूप में उपयोग
# obs_np (84, 84, 4) है
obs_t = to_chw(obs_np).to(device)  # (4, 84, 84)
अगर आप इसे एनवायरनमेंट की सीमा पर ही ठीक करना चाहें, तो ग्रेस्केल में एक चैनल आयाम सक्षम करके सीधे चैनल-फर्स्ट मांग सकते हैं और फ्रेम-स्टैकर को सही लेआउट बनाने दें।
यह क्यों मायने रखता है
Deep Q-Networks स्थानिक फीचर्स निकालने वाली कॉन्वोल्यूशंस और अल्पकालिक डाइनैमिक्स को समेटने वाले स्टैक्ड फ्रेम्स पर निर्भर करते हैं। जब चैनल अक्ष गलत सेट हो, तो दोनों संकेत ध्वस्त हो जाते हैं। TD टारगेट अभी भी ऑनलाइन नेटवर्क का पीछा करता है, यही कारण है कि SmoothL1 लॉस बढ़िया दिखता है, पर एजेंट ऐसी नीति कभी नहीं सीखता जो Pong में स्कोर करा सके। अक्षों के क्रम को ठीक करने से अर्थपूर्ण ग्रेडिएंट लौट आते हैं और सीखने की वक्र फिर चढ़ने लगती है।
व्यावहारिक सुरक्षा-उपाय
कुछ व्यावहारिक विकल्प मूल सेटअप के साथ बेहतर मेल खाते हैं। यहाँ BatchNorm उल्टा असर डालता है, क्योंकि रिप्ले किए गए ट्रांज़िशन i.i.d. नहीं होते—इसलिए फायदा करने के बजाय नुकसान पहुँचाता है। रिवार्ड को ±1 तक क्लिप करना और Huber लॉस उपयुक्त हैं और बने रहने चाहिए। Pong में चैनल क्रम सही होने पर औसत रिटर्न आम तौर पर लगभग दस लाख स्टेप्स के बाद बढ़ना शुरू करता है और सादे DQN के साथ आठ से दस मिलियन स्टेप्स में कम से कम पंद्रह तक पहुँच सकता है।
निष्कर्ष
ऑप्टिमाइज़र सेटिंग्स या शेड्यूलिंग ट्यूनिंग के पीछे भागने से पहले, मॉडल की सीमा पर टेन्सर लेआउट जाँच लें। Atari इनपुट के लिए इसका मतलब है height-width-channels से channel-height-width में बदलना और पहली कॉन्वोल्यूशन के in_channels को स्टैक्ड फ्रेम्स की संख्या से मिलाना। यह होते ही वही एल्गोरिद्म जो खिलौना कार्य पर चला था, अब Pong को सही नज़र से देख पाएगा—और ट्रेनिंग कर्व सपाट रहने के बजाय ऊपर चढ़ेगा।
यह लेख StackOverflow पर प्रश्न (लेखक: Rohan Patel) और Dmitry543 के उत्तर पर आधारित है।