2026, Jan 10 23:00
Why Your CLIP Similarities Don't Separate: Fixing Residual-LayerNorm Order and Forcing True Self-Attention
Debug a CLIP reimplementation where image-text similarity collapses: fix residual-LayerNorm wiring, enforce self-attention, recover cat-dog separation.
When you reimplement CLIP around a pretrained checkpoint and your image–text similarity barely differs between prompts like “a photo of a cat” and “a photo of a dog,” something subtle is off. The model compiles, shapes line up, no missing weights, but the logits fail to separate. That was exactly the situation here: a cat image produced almost the same score against both prompts.
Cat similarity: tensor([[-3.5724]], grad_fn=<MulBackward0>)
Dog similarity: tensor([[-3.4155]], grad_fn=<MulBackward0>)
Minimal reproduction of the issue
The implementation is based on the openai/clip-vit-base-patch32 checkpoint. The text path pools by taking the hidden state at the EOS token position, the tokenizer context_length is 77, and the checkpoint loads without missing or unexpected keys. The problem surfaced in the transformer block inside the text stack. Below is a condensed version with the names rewritten, preserving the original logic that led to the poor separation.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms import Resize

# PictureEncoder, TorchBpeTokenizer, MLP and AttentionPool2d are project-local modules not shown here.

class ClipBridge(nn.Module):
    def __init__(self, proj_width: int = 768, text_width: int = 512):
        super().__init__()
        self.vision_backbone = PictureEncoder(project_dim=proj_width)
        self.text_backbone = SentenceEncoder(embed_dim=text_width)
        self.tok = TorchBpeTokenizer()
        self.logit_scale = nn.Parameter(torch.ones([]) * 0.7)
        self.proj_img = nn.Linear(proj_width, text_width, bias=False)
        self.proj_txt = nn.Linear(text_width, text_width, bias=False)
        self.vision_backbone.eval()
        self.text_backbone.eval()

    def forward(self, pixels: torch.Tensor, txt_emb: torch.Tensor) -> torch.Tensor:
        pixels = Resize(size=(224, 224))(pixels)
        img_feat = self.vision_backbone(pixels)
        # Project both modalities into the shared space, then L2-normalize.
        txt_feat = self.proj_txt(txt_emb)
        img_feat = self.proj_img(img_feat)
        txt_feat = F.normalize(txt_feat, dim=-1)
        img_feat = F.normalize(img_feat, dim=-1)
        # Scaled cosine similarity between image and text features.
        logits = self.logit_scale.exp() * (img_feat @ txt_feat.t())
        return logits

    def vectorize_text(self, token_ids, attn_mask=None):
        if attn_mask is None:
            token_ids, attn_mask = self.tok.tokenize(token_ids)
        if token_ids.dim() == 1:
            token_ids = token_ids.unsqueeze(0)
        with torch.no_grad():
            out = self.text_backbone(token_ids.long(), attn_mask)
        return out
class SentenceEncoder(nn.Module):
    def __init__(self, embed_dim: int = 512):
        super().__init__()
        vocab_size = 49408
        self.embed = nn.Module()
        self.embed.token = nn.Embedding(vocab_size, embed_dim)
        self.embed.pos = nn.Embedding(77, embed_dim)
        self.blocks = TransformerStack(hidden_size=embed_dim)
        self.out_norm = nn.LayerNorm(embed_dim)

    def forward(self, tokens: torch.Tensor, attn_mask: torch.Tensor):
        y = self.embed.token(tokens.long())
        # Build position ids on the same device as the embeddings.
        pos = torch.arange(y.size(1), device=y.device)
        y += self.embed.pos(pos).to(y.dtype)
        y = y.permute(1, 0, 2)  # (batch, seq, dim) -> (seq, batch, dim)
        y = self.blocks(y, attn_mask)
        y = y.permute(1, 0, 2)  # back to (batch, seq, dim)
        if y.dim() == 2:
            y = y.unsqueeze(0)
        if attn_mask.dim() == 1:
            attn_mask = attn_mask.unsqueeze(0)
        # EOS pooling: pick the hidden state at the position of the largest token id.
        y = y[torch.arange(y.size(0)), tokens.argmax(dim=-1)]
        return self.out_norm(y)
class Block(nn.Module):
    def __init__(self, model_dim: int = 768, ratio: int = 4, heads: int = 8):
        super().__init__()
        self.norm1 = nn.LayerNorm(model_dim)
        self.norm2 = nn.LayerNorm(model_dim)
        self.ff = MLP(embed_size=model_dim, ratio=ratio)
        self.attn = AttentionPool2d(num_heads=heads, embed_dim=model_dim)

    def forward(self, x: torch.Tensor, src_pad_key=None):
        # Bug: x is overwritten by its normalized version, so the original input is lost.
        x = self.norm1(x)
        if src_pad_key is not None:
            attn_out = self.attn(x, src_pad_key=src_pad_key, use_self_attention=True)
        else:
            # Bug: use_self_attention is not set on this path.
            attn_out = self.attn(x)
        # Bug: the residual is added to the normalized tensor, not to the sub-layer input.
        x += attn_out
        x = self.norm2(x)
        x += self.ff(x)
        return x
class TransformerStack(nn.Module):
    def __init__(self, hidden_size=768):
        super().__init__()
        self.layers = nn.ModuleList([Block(model_dim=hidden_size) for _ in range(12)])

    def forward(self, x: torch.Tensor, attention_mask=None):
        if attention_mask is not None:
            src_key = attention_mask == 0
            if src_key.dim() == 1:
                src_key = src_key.unsqueeze(0)
            for layer in self.layers:
                x = layer(x, src_key)
        else:
            for layer in self.layers:
                x = layer(x)
        return x
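The EOS selection in SentenceEncoder.forward relies on tokens.argmax(dim=-1) landing on the EOS position. A minimal sketch of why that works, assuming the CLIP BPE convention that the end-of-text token has the largest id in the 49408-entry vocabulary (49407) and that padding uses a smaller or equal id; the token ids and the hidden tensor below are illustrative stand-ins, not real tokenizer output:

```python
import torch

SOS_ID = 49406  # assumed start-of-text id
EOS_ID = 49407  # assumed end-of-text id, the largest id in the vocabulary

# Two toy prompts padded to length 8 (the inner ids are illustrative).
tokens = torch.tensor([
    [SOS_ID, 320, 1125, 539, 320, 2368, EOS_ID, 0],
    [SOS_ID, 320, 1929, EOS_ID, 0, 0, 0, 0],
])
hidden = torch.randn(2, 8, 4)  # (batch, seq, dim) stand-in for the transformer output

# argmax over the sequence returns the position of the largest id, i.e. EOS.
eos_pos = tokens.argmax(dim=-1)
pooled = hidden[torch.arange(hidden.size(0)), eos_pos]

print(eos_pos.tolist())  # [6, 3]
print(pooled.shape)      # torch.Size([2, 4])
```

If a tokenizer ever assigns ids above the EOS id, this shortcut would silently select the wrong position, so it is worth asserting on a known prompt.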
What actually went wrong
The transformer residual path inside the text stack was implemented incorrectly. Instead of saving the sub-layer input as a skip connection and adding it back to the sub-layer output, the code normalized the input, computed attention, and immediately added the attention output to the already-normalized tensor. The same pattern repeated around the MLP. As a result, the un-normalized residual stream is discarded at every sub-layer, which is not the computation the pretrained weights expect, and the representations used to compute the image–text logits degrade.
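To make the difference concrete, here is a minimal sketch that runs both wirings through the same weights. ToyBlock is a hypothetical stand-in, not the article's Block: plain Linear layers take the place of attention and the MLP so the example stays self-contained.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

class ToyBlock(nn.Module):
    """Hypothetical block; Linear layers stand in for attention and the MLP."""

    def __init__(self, dim: int = 8):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.attn = nn.Linear(dim, dim)  # placeholder for self-attention
        self.ff = nn.Linear(dim, dim)    # placeholder for the MLP

    def forward_pre_ln(self, x):
        # Intended wiring: the residual is the un-normalized sub-layer input.
        x = x + self.attn(self.norm1(x))
        x = x + self.ff(self.norm2(x))
        return x

    def forward_miswired(self, x):
        # Miswired: the residual is the normalized tensor, so the raw stream is dropped.
        x = self.norm1(x)
        x = x + self.attn(x)
        x = self.norm2(x)
        x = x + self.ff(x)
        return x

block = ToyBlock().eval()
x = 5.0 * torch.randn(2, 4, 8)  # (batch, tokens, dim)

with torch.no_grad():
    good = block.forward_pre_ln(x)
    bad = block.forward_miswired(x)

print(good.shape == bad.shape)             # shapes agree, so nothing crashes
print(torch.allclose(good, bad))           # values do not agree
print((good - bad).abs().max().item())     # size of the discrepancy
```

Both paths use identical parameters and produce identical shapes, which is why a checkpoint can load cleanly while the outputs are still wrong.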
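There was also a mode switch in the attention module: depending on how it was called, it ran either self-attention or a pooled variant. In the original block, the branch taken when no padding mask is passed calls self.attn(x) without use_self_attention=True, so that path falls back to the module's default mode. That discrepancy matters because the text transformer needs true self-attention, with query = x at every position, to align with how the image encoder works.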
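The article's AttentionPool2d is not shown, so the following is only a sketch of the general hazard, using torch.nn.MultiheadAttention and assuming the pooled variant queries with a single mean token (as attention pooling commonly does). It illustrates how a pooled output can slip through a residual add without raising a shape error:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len, batch, dim = 5, 2, 16

# Default layout for nn.MultiheadAttention is (seq, batch, dim).
mha = nn.MultiheadAttention(embed_dim=dim, num_heads=4).eval()
x = torch.randn(seq_len, batch, dim)

with torch.no_grad():
    # Self-attention: every position issues its own query.
    self_out, _ = mha(x, x, x)
    # Pooled variant (assumed): one mean-token query attends over the sequence.
    pooled_query = x.mean(dim=0, keepdim=True)
    pooled_out, _ = mha(pooled_query, x, x)

print(self_out.shape)    # torch.Size([5, 2, 16])
print(pooled_out.shape)  # torch.Size([1, 2, 16])

# Broadcasting hides the mismatch: the residual add succeeds either way.
mixed = x + pooled_out
print(mixed.shape)       # torch.Size([5, 2, 16])
```

Because the add broadcasts, every position receives the same pooled vector, and the only symptom is degraded features downstream.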
Fix and corrected code
The residual for each sub-layer must be taken from that sub-layer's input and added back to its output. In addition, the attention should always be invoked in self-attention mode. With those two changes, the similarity scores separate cleanly.
class Block(nn.Module):
    def __init__(self, model_dim: int = 768, ratio: int = 4, heads: int = 8):
        super().__init__()
        self.norm1 = nn.LayerNorm(model_dim)
        self.norm2 = nn.LayerNorm(model_dim)
        self.ff = MLP(embed_size=model_dim, ratio=ratio)
        self.attn = AttentionPool2d(num_heads=heads, embed_dim=model_dim)

    def forward(self, x: torch.Tensor, src_pad_key=None):
        # Attention sub-layer: keep the un-normalized input as the residual.
        skip = x
        x = self.norm1(x)
        if src_pad_key is not None:
            x = self.attn(x, src_pad_key=src_pad_key, use_self_attention=True)
        else:
            # Force self-attention on the unmasked path as well.
            x = self.attn(x, use_self_attention=True)
        x += skip
        # MLP sub-layer: same pattern, residual taken before the norm.
        skip = x
        x = self.norm2(x)
        x = self.ff(x)
        x += skip
        return x
After correcting the residual math and enforcing self-attention, the outputs look as expected for a cat image versus “cat” and “dog” text prompts.
Cat similarity: tensor([[25.4132]], grad_fn=<MulBackward0>)
Dog similarity: tensor([[21.8544]], grad_fn=<MulBackward0>)
cosine cat/dog: 0.8438754677772522
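To read these numbers as a classification decision, the two logits can be passed through a softmax over the prompts. The sketch below does that arithmetic with the values reported in this article; the softmax helper is written here for illustration and is not part of the original code.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

# Logits reported in this article, ordered as [cat prompt, dog prompt].
before = softmax([-3.5724, -3.4155])
after = softmax([25.4132, 21.8544])

print([round(p, 3) for p in before])  # [0.461, 0.539]
print([round(p, 3) for p in after])   # [0.972, 0.028]
```

Before the fix the cat image is close to a coin flip and even leans slightly toward the dog prompt; after the fix roughly 97% of the probability mass lands on the cat prompt.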
Why this matters for practitioners
With pretrained checkpoints, a single misplaced residual add or an unexpected attention path is enough to undermine the whole image–text alignment, even if all tensor shapes and parameter names match and the checkpoint loads cleanly. The tokenizer setup, EOS selection, and embedding sizes can be correct, yet the similarity still collapses if the sub-layer wiring is not faithful.
Takeaways
When porting CLIP-like stacks, be precise about the residual–LayerNorm order around attention and MLP. Keep the attention in self-attention mode so that query = x, matching the behavior assumed elsewhere in the model. If your scores barely separate between contradictory prompts, start by diffing the transformer block math against the intended pattern, then validate that the attention path is not silently switching to a pooled variant. Once the residual connections and attention mode are aligned, the similarity distribution recovers and the image–text scores become discriminative, as they should.
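One way to carry out that diff in practice is to capture per-layer outputs from a trusted reference and from the port, then find the first layer where they disagree. The sketch below is a generic helper under stated assumptions: capture_outputs and first_divergence are hypothetical names, and two small nn.Sequential stacks stand in for the reference and the reimplementation.

```python
import copy
import torch
import torch.nn as nn

def capture_outputs(model, layers, *inputs):
    """Run `model` once and return each listed layer's output, in call order."""
    captured = []
    hooks = [
        layer.register_forward_hook(
            lambda mod, args, out: captured.append(out.detach().clone())
        )
        for layer in layers
    ]
    try:
        with torch.no_grad():
            model(*inputs)
    finally:
        for hook in hooks:
            hook.remove()
    return captured

def first_divergence(reference, candidate, atol=1e-4):
    """Return the index of the first layer whose outputs differ, or None."""
    for idx, (ref, cand) in enumerate(zip(reference, candidate)):
        if ref.shape != cand.shape or not torch.allclose(ref, cand, atol=atol):
            return idx
    return None

torch.manual_seed(0)
reference = nn.Sequential(*[nn.Linear(8, 8) for _ in range(4)]).eval()
candidate = copy.deepcopy(reference)
with torch.no_grad():
    candidate[2].bias.add_(1.0)  # simulate a wiring bug in layer 2

x = torch.randn(3, 8)
ref_states = capture_outputs(reference, list(reference), x)
cand_states = capture_outputs(candidate, list(candidate), x)
print(first_divergence(ref_states, cand_states))  # 2
```

Applied to a real port, the same idea would hook each transformer block in both models and feed them identical token ids; the first mismatching index points at the block to inspect.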