2025, Dec 18 21:02

Почему в LlamaForCausalLM метки выровнены с input_ids, а prompt маскируют -100

Разбираем обучение LlamaForCausalLM на паре «prompt → ответ»: как считается loss, почему метки не сдвигают, а prompt маскируют -100. Пример разметки и код

Обучение LlamaForCausalLM на задаче в духе seq2seq — например, «диалог → краткое резюме» — часто поднимает вопрос о разметке: почему в датасете для меток используют те же id токенов, что и для входа, помечая часть с подсказкой значением -100, а не сдвигают метки на одну позицию? Разберёмся в настройке, заглянем в код и уточним, как именно считается loss в моделях CausalLM, чтобы поведение стало понятным.

Минимальный пример, из‑за которого возникает путаница

При подготовке данных диалог (prompt) и итоговое резюме склеиваются в одну последовательность. Токены из prompt игнорируются при расчёте функции потерь, а токены из резюме идут в метки. Ключевая логика выглядит так:

tok = tokenizer
row = sample

lead_seq = tok.encode(tok.bos_token + row["prompt"], add_special_tokens=False)
abstract_seq = tok.encode(row["summary"] + tok.eos_token, add_special_tokens=False)

row = {
    "input_ids": lead_seq + abstract_seq,
    "attention_mask": [1] * (len(lead_seq) + len(abstract_seq)),
    "labels": [-100] * len(lead_seq) + abstract_seq,
}

Если посмотреть на батч, видно: метки повторяют целевую часть input_ids, а позиции prompt заполнены -100.

train_loader = train_dataloader
mini = next(iter(train_loader))
print(mini["input_ids"][0][35:40])
print(mini["labels"][0][35:40])

tensor([19791, 512, 32, 36645, 41778])
tensor([ -100, -100, 32, 36645, 41778])

На первый взгляд это кажется странным, если вы ожидаете, что labels[i] должны соответствовать input_ids[i+1] для каждого предсказываемого токена. Почему же это работает?

Что на самом деле предсказывает CausalLM

Каузальные языковые модели предсказывают каждый токен, опираясь на все предыдущие токены в той же последовательности. По сути, i‑й прогноз использует префикс до позиции i−1. Если смотреть на всю последовательность, «золотые» токены, с которыми нужно совпадение на каждом шаге, — это те же токены, что уже стоят в самой последовательности. Поэтому в целевом фрагменте labels выровнены с input_ids.

Сдвиг выполняется под капотом при расчёте лосса. Популярные реализации принимают labels, выровненные с input_ids, и внутри смещают целевые значения на один шаг при вычислении кросс-энтропии. Вручную сдвигать метки вне модели не нужно.

Почему prompt помечают -100

В функцию потерь должна вносить вклад только часть с резюме. Поэтому участок prompt в labels заполняется -100 — этим мы говорим лоссу игнорировать эти позиции. Модель по‑прежнему «видит» токены prompt как контекст, но они не влияют на loss.

Для эффективности логиты генерируются за один проход для всей последовательности, а затем loss считается по выровненным меткам. Если сгенерированная часть короче меток, её дополняют паддингом; если длиннее — обрезают. Важна сопоставимость именно на целевом участке последовательности.

Как правильно формировать метки

Ручной сдвиг не требуется. Оставляйте метки равными целевым токенам и маскируйте prompt значением -100. Рабочий шаблон полностью соответствует обработке датасета выше. Хотите более компактную запись — функционально это то же самое:

tok = tokenizer
ex = sample

ctx = tok.encode(tok.bos_token + ex["prompt"], add_special_tokens=False)
resp = tok.encode(ex["summary"] + tok.eos_token, add_special_tokens=False)

ex = {
    "input_ids": ctx + resp,
    "attention_mask": [1] * (len(ctx) + len(resp)),
    "labels": [-100] * len(ctx) + resp,
}

Именно такого формата CausalLM ожидает при супервизионном дообучении задач вроде суммаризации, когда они формулируются как предсказание следующего токена над склеенной последовательностью prompt+target.

Почему эта деталь важна

Неверное понимание выравнивания меток ведёт к ошибочной предобработке и, как следствие, к неверным сигналам лосса. Осознание, что сдвиг в CausalLM делается внутри, уберегает от двойного сдвига и от случайного обучения предсказывать не те позиции. Маска -100 на prompt гарантирует, что модель учится только по токенам резюме, но при этом использует prompt как контекст.

Выводы

При дообучении CausalLM на парах «подсказка → ответ» склеивайте prompt и выход, помечайте часть prompt значением -100, чтобы исключить её из лосса, а в сегменте меток используйте целевые токены как есть. Полагайтесь на код модели для сдвига на один токен при вычислении лосса. Эта простая конвенция упрощает предобработку и соответствует тому, как на самом деле считается loss.