2025, Nov 05 00:02

Сохранение и загрузка Keras‑моделей с пользовательскими слоями: как исправить ValueError

Разбираем ValueError при загрузке Keras‑моделей с пользовательскими слоями: регистрация сериализуемых классов, корректный get_config и сохранение в .keras.

Сохранение и повторная загрузка моделей Keras, использующих пользовательские слои, может оказаться неожиданно сложной. Если вы следовали руководству по Transformer ASR и получили модель, которая успешно обучается и сохраняется, но при загрузке падает с ValueError о недостающих переменных, скорее всего, вы столкнулись с проблемой сериализации пользовательских объектов. Ниже — наглядный разбор сценария сбоя и того, как корректно его устранить.

Воспроизведение проблемы

После обучения модели распознавания речи одну эпоху модель сохраняется без ошибок:

net.save("asr_epoch1.keras")

При попытке загрузить модель обратно вместе с её пользовательскими компонентами используется такой подход:

catalog = {
    'TokenEmbedding': LexemeEmbedding,
    'SpeechFeatureEmbedding': AudioFeatureBlock,
    'TransformerEncoder': StackEncoder,
    'TransformerDecoder': StackDecoder,
    'Transformer': Seq2SeqTransformer,
    'CustomSchedule': WarmupSchedule
}
restored = keras.models.load_model("asr_epoch1.keras", custom_objects=catalog, compile=False)

Но при загрузке появляется ошибка примерно такого вида:

ValueError: A total of 51 objects could not be loaded. Example error message for object <Dense name=dense_65, built=True>:
Layer 'dense_65' expected 2 variables, but received 0 variables during loading. Expected: ['kernel', 'bias']
List of objects that could not be loaded:
[<Dense name=dense_65, built=True>, <Embedding name=embedding_10, built=True>, <Embedding name=embedding_11, built=True>, <EinsumDense name=key, built=True>, <EinsumDense name=attention_output, built=True>, <EinsumDense name=query, built=True>, <EinsumDense name=value, built=True>, <Dense name=dense_63, built=True>, <Dense name=dense_64, built=True>, <LayerNormalization name=layer_normalization_63, built=True>, <LayerNormalization name=layer_normalization_64, built=True>, <LayerNormalization name=layer_normalization_65 .......

Сама модель собирается и обучается так:

net = Seq2SeqTransformer(
    num_hid=200,
    num_head=2,
    num_feed_forward=400,
    target_maxlen=max_target_len,
    num_layers_enc=4,
    num_layers_dec=1,
    num_classes=34,
)
opt = keras.optimizers.Adam(lr)
net.compile(optimizer=opt, loss=loss_obj)
hist = net.fit(train_ds, validation_data=eval_ds, callbacks=[progress_cb], epochs=1)

Также пробовали сохранить в формат H5:

net.save("asr_epoch1.h5")

но при загрузке это тоже не сработало.

Что происходит

Ошибка возникает во время десериализации, когда Keras заново создаёт слои и подгружает их веса. Сообщение «ожидалось N переменных, а получено 0» указывает на то, что пользовательские компоненты не были корректно подготовлены к сериализации. Необходимо зарегистрировать пользовательские слои для сериализации в Keras и обеспечить, чтобы их метод get_config возвращал параметры инициализации — так при загрузке Keras сможет воссоздать их правильно. Важно: эти определения должны быть добавлены до того, как модель была собрана и сохранена; иначе сохранённый артефакт не будет содержать нужных метаданных для восстановления переменных.

Решение

Зарегистрируйте каждый пользовательский компонент с помощью @keras.saving.register_keras_serializable и сделайте так, чтобы get_config возвращал ровно те аргументы, которые используются в __init__. Затем запустите среду «с чистого листа», заново соберите модель, обучите, сохраните и только после этого загружайте её с custom_objects.

Ниже пример сериализуемого блока для извлечения аудиопризнаков, используемого внутри модели ASR:

@keras.saving.register_keras_serializable(package="components")
class AudioFeatureBlock(layers.Layer):
    def __init__(self, hid_units=64, max_len=100):
        super().__init__()
        self.hid_units = hid_units
        self.max_len = max_len
        self.f1 = keras.layers.Conv1D(
            hid_units, 11, strides=2, padding="same", activation="relu"
        )
        self.f2 = keras.layers.Conv1D(
            hid_units, 11, strides=2, padding="same", activation="relu"
        )
        self.f3 = keras.layers.Conv1D(
            hid_units, 11, strides=2, padding="same", activation="relu"
        )
    def get_config(self):
        return {
            "hid_units": self.hid_units,
            "max_len": self.max_len,
        }
    def call(self, inp):
        z = self.f1(inp)
        z = self.f2(z)
        return self.f3(z)

После регистрации пользовательских слоёв и возврата параметров инициализации в get_config перезапустите ноутбук и снова обучите модель, чтобы сохранённый файл отразил обновлённую схему сериализации. Затем модель можно загрузить так:

catalog = {
    'TokenEmbedding': LexemeEmbedding,
    'SpeechFeatureEmbedding': AudioFeatureBlock,
    'TransformerEncoder': StackEncoder,
    'TransformerDecoder': StackDecoder,
    'Transformer': Seq2SeqTransformer,
    'CustomSchedule': WarmupSchedule
}
net2 = keras.models.load_model("asr_epoch1.keras", custom_objects=catalog, compile=False)

Почему это важно

При построении ASR‑систем или любых глубоких моделей с пользовательскими компонентами вы часто работаете в ноутбуках. Порядок выполнения и сохранение состояния критичны: если добавить хуки сериализации после создания модели, сохранённый артефакт окажется неконсистентным и не загрузится. Регистрация компонентов и полный get_config до начала обучения гарантируют, что файл модели будет переносим и корректно восстановится в разных сессиях и окружениях.

Практические итоги

Определяйте и регистрируйте каждый пользовательский слой или утилиту, которые участвуют в графе модели. В get_config возвращайте все аргументы конструктора. Перезапустите окружение, заново соберите модель, обучите и сохраните в формате .keras. При загрузке при необходимости укажите custom_objects и compile=False. Если застряли, сократите код до минимального примера, который строит, сохраняет и загружает модель — так быстрее удастся локализовать точку сбоя.

Статья основана на вопросе на StackOverflow от FaisalShakeel и ответе от FaisalShakeel.