[BUG Report]: duplicated variable names in layer components causes the loading weights failure. #1166

Open
@lingbai-kong

Description

When creating a layer that contains two different embeddings, the variable names of the two embeddings are identical. This confuses the load_weights process and leads to a prediction error after loading: `Unhandled exception. Tensorflow.RuntimeError: Attempting to capture an EagerTensor without building a function.`

Reproduction Steps

Run the following code:

using Newtonsoft.Json;
using Tensorflow;
using Tensorflow.Keras;
using Tensorflow.NumPy;
using Tensorflow.Common.Types;
using Tensorflow.Keras.Engine;
using Tensorflow.Keras.Saving;
using Tensorflow.Keras.ArgsDefinition;
using static Tensorflow.KerasApi;
using static Tensorflow.Binding;
using System.IO;

public class TokenAndPositionEmbeddingArgs : AutoSerializeLayerArgs
{
    [JsonProperty("max_len")]
    public int Maxlen { get; set; }
    [JsonProperty("vocab_size")]
    public int VocabSize { get; set; }
    [JsonProperty("embed_dim")]
    public int EmbedDim { get; set; }
    [JsonProperty("activity_regularizer")]
    public override IRegularizer ActivityRegularizer { get => base.ActivityRegularizer; set => base.ActivityRegularizer = value; }
}

public class TokenAndPositionEmbedding : Layer
{
    TokenAndPositionEmbeddingArgs args;
    Tensor positions_base;
    ILayer token_emb;
    ILayer pos_emb;

    public TokenAndPositionEmbedding(TokenAndPositionEmbeddingArgs args) : base(args)
    {
        this.args = args;
    }

    public override void build(KerasShapesWrapper input_shape)
    {
        _buildInputShape = input_shape;
        positions_base = tf.constant(tf.range(start: 0, limit: args.Maxlen, delta: 1));
        token_emb = keras.layers.Embedding(input_dim: args.VocabSize, output_dim: args.EmbedDim);
        pos_emb = keras.layers.Embedding(input_dim: args.Maxlen, output_dim: args.EmbedDim);
        StackLayers(token_emb, pos_emb);
        built = true;
    }

    protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = false, IOptionalArgs? optional_args = null)
    {
        var embedding = token_emb.Apply(inputs, state, training ?? false, optional_args);
        var positions = pos_emb.Apply(positions_base, state, training ?? false, optional_args);
        return (Tensor)embedding + (Tensor)positions;
    }
}
class Program
{
    static void Main(string[] args)
    {
        Run();
    }

    static void Run()
    {
        var inputs = keras.Input(shape: new[] { 200 });
        var embedding = new TokenAndPositionEmbedding(
            new TokenAndPositionEmbeddingArgs
            {
                Maxlen = 200,
                VocabSize = 20000,
                EmbedDim = 32
            });
        var outputs = embedding.Apply(inputs);
        outputs = keras.layers.GlobalAveragePooling1D().Apply(outputs);
        outputs = keras.layers.Dense(2, activation: "softmax").Apply(outputs);
        var model = keras.Model(inputs: inputs, outputs: outputs);

        var x = new NDArray(tf.range(start: 0, limit: 8 * 200)).reshape(new[] { 8, 200 });
        var y = new NDArray(new[] { 0, 1, 0, 0, 1, 1, 1, 0 });

        model.summary();
        model.compile(optimizer: keras.optimizers.Adam(learning_rate: 0.01f), loss: keras.losses.SparseCategoricalCrossentropy(), metrics: new string[] { "accuracy" });
        model.fit(x, y, batch_size: 1, epochs: 10);

        var token_emb_var_name = model.Layers[1].Layers[0].TrainableVariables[0].Name;
        var pos_emb_var_name = model.Layers[1].Layers[1].TrainableVariables[0].Name;
        Console.WriteLine(token_emb_var_name);
        Console.WriteLine(pos_emb_var_name);

        model.save_weights("weights.h5");
        var load_model = keras.Model(inputs: inputs, outputs: outputs);
        load_model.load_weights("weights.h5");
        load_model.predict(x);
    }
}

The variable names of both token_emb and pos_emb are token_and_position_embedding/embedding/embeddings:0, so their parameters share the same key in the saved h5 file. When loading weights, hdf5_format.load_weights_from_hdf5_group therefore assigns the parameters intended for pos_emb to token_emb.
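The root cause is that both auto-generated Embedding sublayers receive the same name. Python Keras avoids this by uniquifying auto-generated names with a per-name counter; a minimal standalone sketch of that scheme (illustrative only, the class and method names here are invented for this example and are not TensorFlow.NET's actual implementation):

```csharp
using System;
using System.Collections.Generic;

// Sketch of name uniquification as done for auto-generated layer names
// in Python Keras: repeated base names get suffixes _1, _2, ...
class UniqueNamer
{
    private readonly Dictionary<string, int> counts = new();

    public string Unique(string baseName)
    {
        counts.TryGetValue(baseName, out var n);
        counts[baseName] = n + 1;
        // First use keeps the base name; later uses get a numeric suffix.
        return n == 0 ? baseName : $"{baseName}_{n}";
    }
}

class Demo
{
    static void Main()
    {
        var namer = new UniqueNamer();
        Console.WriteLine(namer.Unique("embedding")); // embedding
        Console.WriteLine(namer.Unique("embedding")); // embedding_1
    }
}
```

With distinct sublayer names such as embedding and embedding_1, the two weight tables would get distinct keys in the saved h5 file and load_weights could no longer confuse them.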

Known Workarounds

Redefine pos_emb as an explicitly named weight, so that its variable name no longer collides:

tf_with(ops.name_scope("position_embeddings"), scope =>
{
    position_embeddings = add_weight(name: "position_embedding", shape: (200, 32));
});
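The same idea can be applied to both tables so that neither relies on the colliding auto-generated Embedding names. A sketch of such a build/Call pair, assuming add_weight and tf.nn.embedding_lookup behave as in upstream Keras/TensorFlow (adjust to the actual TensorFlow.NET signatures if they differ):

```csharp
IVariableV1 token_embeddings;
IVariableV1 position_embeddings;

public override void build(KerasShapesWrapper input_shape)
{
    _buildInputShape = input_shape;
    positions_base = tf.constant(tf.range(start: 0, limit: args.Maxlen, delta: 1));
    // Explicitly named weights under distinct name scopes, so the keys
    // written to the h5 file no longer collide.
    tf_with(ops.name_scope("token_embeddings"), scope =>
    {
        token_embeddings = add_weight(name: "token_embedding",
                                      shape: (args.VocabSize, args.EmbedDim));
    });
    tf_with(ops.name_scope("position_embeddings"), scope =>
    {
        position_embeddings = add_weight(name: "position_embedding",
                                         shape: (args.Maxlen, args.EmbedDim));
    });
    built = true;
}

protected override Tensors Call(Tensors inputs, Tensors state = null, bool? training = false, IOptionalArgs? optional_args = null)
{
    // Look up rows of the named weight tables instead of applying the
    // auto-named Embedding sublayers.
    var embedding = tf.nn.embedding_lookup(token_embeddings, inputs);
    var positions = tf.nn.embedding_lookup(position_embeddings, positions_base);
    return (Tensor)embedding + (Tensor)positions;
}
```

This sidesteps the bug rather than fixing it; the underlying issue of duplicated auto-generated names in StackLayers still stands.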

Configuration and Other Information

No response

Metadata

Labels: bug (Something isn't working)