2026, Jan 04 05:00
Stabilizing OmniEmbed Multimodal Embeddings on Qwen2.5-Omni: Avoid Cache, Align Devices, Fix FlashAttention2 NaNs
Fix NaN multimodal embeddings in OmniEmbed on Qwen2.5-Omni: disable generation cache, align tensors to the model device, or use SDPA instead of FlashAttention2.
Embedding images and videos with OmniEmbed on top of Qwen2.5-Omni can unexpectedly yield NaNs, depending on the order in which inputs are processed. A typical symptom: in a fresh session, embedding a video works, but embedding an image first returns NaN; if a video is embedded before the image, the image pass succeeds. Below is a short walkthrough of why this happens and how to make the pipeline deterministic and stable.
Problem setup
The following snippet mirrors a typical multimodal embedding flow and reproduces the issue. It uses FlashAttention2, enables the generation cache, and constructs a representation from the last token’s hidden state.
import torch
from transformers import AutoProcessor, Qwen2_5OmniThinkerForConditionalGeneration
from qwen_omni_utils import process_mm_info as pack_mm
pipe = AutoProcessor.from_pretrained("Qwen/Qwen2.5-Omni-7B")
lm = Qwen2_5OmniThinkerForConditionalGeneration.from_pretrained(
    "Tevatron/OmniEmbed-v0.1-multivent",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
    device_map="auto"
).eval()
pipe.tokenizer.padding_side = "left"
lm.padding_side = "left"
# Embedding routine (reproduces the NaN issue)
def build_embed(dialogue):
    prompt = pipe.apply_chat_template(dialogue, tokenize=False, add_generation_prompt=True) + "<|endoftext|>"
    aud, img, vid = pack_mm(dialogue, use_audio_in_video=True)
    payload = pipe(
        text=prompt,
        audio=aud,
        images=img,
        videos=vid,
        return_tensors="pt",
        padding="longest",
    )
    # Note: payload is never moved to lm.device, and kpos is created on the CPU
    kpos = torch.arange(0, payload["input_ids"].shape[1])
    prepared = lm.prepare_inputs_for_generation(**payload, use_cache=True, cache_position=kpos)
    out = lm(**prepared, return_dict=True, output_hidden_states=True)
    # Last-token pooling over the final layer, then L2 normalization
    last = out.hidden_states[-1]
    vec = last[:, -1]
    vec = torch.nn.functional.normalize(vec, p=2, dim=-1)
    return vec
# Calls
toy_video = [{
    "role": "user",
    "content": [{"type": "video", "video": "https://huggingface.co/Tevatron/OmniEmbed-v0.1/resolve/main/assets/mapo_tofu.mp4"}]
}]
embed_v = build_embed(toy_video)
toy_image = [{
    "role": "user",
    "content": [{"type": "image", "image": "https://huggingface.co/Tevatron/OmniEmbed-v0.1/resolve/main/assets/qwen2.5omni_hgf.png"}]
}]
embed_i = build_embed(toy_image)
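The repro sets padding_side to "left" and pools with last[:, -1]. The two choices only matter once several dialogues are batched, but they belong together: with left padding, the final column of every row is a real token (here the appended <|endoftext|>), while with right padding a shorter row would end on a pad token and its "embedding" would be the hidden state of padding. The framework-free sketch below illustrates this with plain lists; the pad id 0 and the stand-in id 7 for <|endoftext|> are arbitrary assumptions, not the real Qwen vocabulary.

```python
from typing import List, Sequence

PAD_ID = 0  # hypothetical pad token id, for illustration only


def pad_batch(seqs: Sequence[Sequence[int]], pad_id: int = PAD_ID, side: str = "left") -> List[List[int]]:
    """Pad token-id sequences to the longest length, on the left or on the right."""
    width = max(len(s) for s in seqs)
    padded = []
    for s in seqs:
        fill = [pad_id] * (width - len(s))
        padded.append(fill + list(s) if side == "left" else list(s) + fill)
    return padded


def last_position(batch: List[List[int]]) -> List[int]:
    """Mimic last[:, -1]: take whatever sits in the final column of each row."""
    return [row[-1] for row in batch]


# Two prompts of different lengths; 7 stands in for the trailing <|endoftext|> id.
prompts = [[5, 6, 7], [8, 9, 10, 11, 7]]

print(last_position(pad_batch(prompts, side="left")))   # [7, 7]: every row ends on a real token
print(last_position(pad_batch(prompts, side="right")))  # [0, 7]: the shorter row ends on padding
```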
What’s going on
The NaNs stem from combining the generation cache with FlashAttention2 in a multimodal setup, without ensuring that every tensor taking part in the forward pass resides on the same device as the model. In the snippet above, neither the processor output nor cache_position is moved: torch.arange allocates kpos on the CPU by default, while device_map="auto" typically places the model on a GPU. That combination produces invalid values inside attention, which then surface as NaN embeddings. The order dependence is a side effect of this misconfiguration, so a successful video pass does not guarantee that a later image pass will be stable.
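One way to catch a device mismatch before it reaches attention is to audit the inputs right before the forward call. The helper below is a hypothetical, duck-typed sketch: any value with a .device attribute (a torch.Tensor has one) is compared with the target device, and everything else is skipped. A small FakeTensor stand-in replaces real tensors so the example runs without torch or a GPU.

```python
from dataclasses import dataclass
from typing import Any, List, Mapping


def misplaced_inputs(inputs: Mapping[str, Any], target_device: Any) -> List[str]:
    """Return the names of inputs whose .device differs from target_device.

    Duck-typed: any value exposing a .device attribute (such as a torch.Tensor)
    is checked; plain Python values without one are skipped.
    """
    bad = []
    for name, value in inputs.items():
        device = getattr(value, "device", None)
        if device is not None and device != target_device:
            bad.append(name)
    return bad


@dataclass
class FakeTensor:
    """Stand-in for a tensor so the sketch runs without torch or a GPU."""
    device: str


inputs = {
    "input_ids": FakeTensor("cuda:0"),
    "pixel_values": FakeTensor("cuda:0"),
    "cache_position": FakeTensor("cpu"),  # like torch.arange(...) without device=
    "return_dict": True,                  # plain flag, skipped by the audit
}
print(misplaced_inputs(inputs, "cuda:0"))  # ['cache_position']
```

In the repro, calling misplaced_inputs({**payload, "cache_position": kpos}, lm.device) just before prepare_inputs_for_generation would be expected to list every processor output as well as cache_position when the model sits on a GPU, since none of them are moved.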
Fix and stable pattern
Do not use the generation cache when extracting embeddings: call the model directly with use_cache=False instead of going through prepare_inputs_for_generation. In addition, explicitly move every tensor produced by the processor to the model's device. The adjusted function below keeps the same prompt construction and last-token pooling, but skips the cache and enforces device consistency. It also casts the selected hidden state to float before normalization.
def build_embed_stable(dialogue):
    prompt = pipe.apply_chat_template(dialogue, tokenize=False, add_generation_prompt=True) + "<|endoftext|>"
    aud, img, vid = pack_mm(dialogue, use_audio_in_video=True)
    payload = pipe(
        text=prompt,
        audio=aud,
        images=img,
        videos=vid,
        return_tensors="pt",
        padding="longest",
    )
    # Move every tensor the processor produced onto the model's device
    aligned = {k: (v.to(lm.device) if torch.is_tensor(v) else v) for k, v in payload.items()}
    # Single forward pass, no generation cache
    out = lm(
        **aligned,
        return_dict=True,
        output_hidden_states=True,
        use_cache=False,
    )
    # Last-token pooling, upcast to float32 before L2 normalization
    vec = torch.nn.functional.normalize(out.hidden_states[-1][:, -1].float(), p=2, dim=-1)
    return vec
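The stable function casts the last hidden state to float before calling torch.nn.functional.normalize, which divides by max(||v||_2, eps) with eps defaulting to 1e-12. The cast improves the precision of the norm computation, but it cannot repair a NaN produced earlier in the forward pass: one NaN component makes the norm NaN, and the division spreads it to the whole vector. The plain-Python analogue below is an illustration of that behavior, not torch's implementation.

```python
import math
from typing import List, Sequence


def l2_normalize(vec: Sequence[float], eps: float = 1e-12) -> List[float]:
    """Plain-Python analogue of torch.nn.functional.normalize(v, p=2, dim=-1).

    Divides by max(||v||_2, eps), so an all-zero vector stays zero instead of
    dividing by zero.
    """
    norm = math.sqrt(sum(x * x for x in vec))
    if math.isnan(norm):
        # A single NaN poisons the norm, so every output component becomes NaN.
        return [math.nan] * len(vec)
    return [x / max(norm, eps) for x in vec]


print(l2_normalize([3.0, 4.0]))           # [0.6, 0.8]
print(l2_normalize([0.0, 0.0]))           # [0.0, 0.0]
print(l2_normalize([3.0, float("nan")]))  # [nan, nan]
```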
If NaNs persist even with the cache disabled and tensors co-located, switch from FlashAttention2 to PyTorch's scaled dot product attention by passing attn_implementation="sdpa" when loading the model. Loading with torch_dtype=torch.float16 instead of torch.bfloat16 is another option to try.
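When trying SDPA as a fallback, the choice can be made once, for example at service start-up or in CI, by embedding a few probe inputs with each candidate backend and keeping the first one whose outputs are all finite. The sketch below assumes two hypothetical hooks: load(name), which would wrap from_pretrained(..., attn_implementation=name), and embed(model, probe), which would run the stable routine against that model and return a list of floats. The stubs in the usage example only simulate a FlashAttention2 failure.

```python
import math
from typing import Any, Callable, Iterable, Sequence


def all_finite(vec: Sequence[float]) -> bool:
    """True if no component is NaN or infinite."""
    return all(math.isfinite(x) for x in vec)


def pick_attn_backend(
    load: Callable[[str], Any],
    embed: Callable[[Any, Any], Sequence[float]],
    probes: Iterable[Any],
    candidates: Sequence[str] = ("flash_attention_2", "sdpa"),
) -> str:
    """Return the first attention backend whose probe embeddings are all finite.

    `load` and `embed` are hypothetical hooks supplied by the caller.
    """
    probes = list(probes)
    for name in candidates:
        model = load(name)
        if all(all_finite(embed(model, p)) for p in probes):
            return name
    raise RuntimeError(f"every backend produced non-finite embeddings: {list(candidates)}")


# Stubs: the "model" is just the backend name, and FlashAttention2 fails on images.
def fake_load(name):
    return name


def fake_embed(model, probe):
    if model == "flash_attention_2" and probe == "image":
        return [float("nan"), 0.0]
    return [0.6, 0.8]


print(pick_attn_backend(fake_load, fake_embed, ["image", "video"]))  # sdpa
```

Because the reported failure depends on order, the probe list starts with an image, mirroring the failing case. Each candidate means loading a 7B model, so in a real setup it is worth releasing a rejected model before trying the next one.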
Why this matters
Embedding pipelines often run in batch jobs and services where a single NaN can invalidate an entire retrieval or ranking stage. Settings that look harmless, such as enabling the generation cache, can be inappropriate for representation extraction, especially in multimodal paths: the cache exists to reuse keys and values across decoding steps, while an embedding needs only one forward pass, so the cache adds state without any benefit. Ensuring device consistency and choosing a stable attention backend are small changes that prevent subtle, order-dependent failures.
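A cheap safeguard for batch jobs is to check every embedding for NaN or infinite values before it reaches an index, and to set failing items aside for a retry or investigation. The sketch below is a minimal illustration on plain lists of floats; in a torch pipeline the same per-row check could be written as torch.isfinite(vecs).all(dim=-1).

```python
import math
from typing import List, Sequence, Tuple


def split_finite(batch: Sequence[Sequence[float]]) -> Tuple[List[int], List[int]]:
    """Partition row indices into (finite, non_finite) before anything is indexed."""
    good, bad = [], []
    for i, row in enumerate(batch):
        (good if all(math.isfinite(x) for x in row) else bad).append(i)
    return good, bad


batch = [
    [0.6, 0.8],
    [float("nan"), 0.1],
    [1.0, 0.0],
    [float("inf"), 0.0],
]
good, bad = split_finite(batch)
print(good, bad)  # [0, 2] [1, 3]
```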
Takeaways
When you need multimodal embeddings from OmniEmbed on Qwen2.5-Omni, avoid the generation cache, push every tensor from the processor onto the model's device, and switch to SDPA if FlashAttention2 still misbehaves. These adjustments keep the embedding flow deterministic for both images and videos and help you sidestep NaN surprises.
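To confirm that a fix actually removes the order dependence, a regression check can embed the same items in every order, each time from a fresh state, and require finite, matching results. The sketch below assumes a hypothetical make_embedder() factory; in a real pipeline it would need to start from a fresh process or a freshly loaded model, since the failure was observed at the start of a session. The tolerance of 1e-3 is an arbitrary assumption meant to absorb small numerical differences, not a value from OmniEmbed. The number of orders grows factorially, so this suits a handful of probe items.

```python
import math
from itertools import permutations
from typing import Callable, Dict, Sequence


def check_order_independent(
    make_embedder: Callable[[], Callable[[str], Sequence[float]]],
    items: Sequence[str],
    tol: float = 1e-3,
) -> None:
    """Embed `items` in every order, each with a fresh embedder, and compare.

    Raises AssertionError on a non-finite value or on results that differ by
    more than `tol` between orders.
    """
    reference: Dict[str, Sequence[float]] = {}
    for order in permutations(items):
        embed = make_embedder()  # fresh state per order, like a new session
        for item in order:
            vec = embed(item)
            assert all(math.isfinite(x) for x in vec), f"non-finite embedding for {item!r} in order {order}"
            ref = reference.setdefault(item, vec)
            assert all(abs(a - b) <= tol for a, b in zip(vec, ref)), f"{item!r} changed with order {order}"


# Stub imitating the reported symptom: an image embedded first in a session is NaN.
def make_buggy():
    calls = []

    def embed(item):
        calls.append(item)
        if item == "image" and calls == ["image"]:
            return [float("nan"), 0.0]
        return [1.0, 0.0] if item == "image" else [0.0, 1.0]

    return embed


def make_fixed():
    return lambda item: [1.0, 0.0] if item == "image" else [0.0, 1.0]


check_order_independent(make_fixed, ["video", "image"])  # passes silently
try:
    check_order_independent(make_buggy, ["video", "image"])
except AssertionError as err:
    print("caught:", err)  # caught: non-finite embedding for 'image' in order ('image', 'video')
```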