diff --git a/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs b/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs
index 9fcd0d70f..6b2c38c32 100644
--- a/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs
+++ b/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs
@@ -1,4 +1,5 @@
 using System;
+using Tensorflow.Framework.Models;
 using Tensorflow.NumPy;
 using static Google.Protobuf.Reflection.FieldDescriptorProto.Types;
 
@@ -133,11 +134,16 @@ public ILayer EinsumDense(string equation,
         public ILayer GlobalMaxPooling1D(string data_format = "channels_last");
         public ILayer GlobalMaxPooling2D(string data_format = "channels_last");
 
-        public Tensors Input(Shape shape,
+        public Tensors Input(Shape shape = null,
                              int batch_size = -1,
                              string name = null,
+                             TF_DataType dtype = TF_DataType.DtInvalid,
                              bool sparse = false,
-                             bool ragged = false);
+                             Tensor tensor = null,
+                             bool ragged = false,
+                             TypeSpec type_spec = null,
+                             Shape batch_input_shape = null,
+                             Shape batch_shape = null);
         public ILayer InputLayer(Shape input_shape,
                                  string name = null,
                                  bool sparse = false,
diff --git a/src/TensorFlowNET.Keras/KerasInterface.cs b/src/TensorFlowNET.Keras/KerasInterface.cs
index e0d148cef..2bde713c0 100644
--- a/src/TensorFlowNET.Keras/KerasInterface.cs
+++ b/src/TensorFlowNET.Keras/KerasInterface.cs
@@ -12,6 +12,7 @@
 using Tensorflow.Keras.Optimizers;
 using Tensorflow.Keras.Utils;
 using System.Threading;
+using Tensorflow.Framework.Models;
 
 namespace Tensorflow.Keras
 {
@@ -66,33 +67,16 @@ public Functional Model(Tensors inputs, Tensors outputs, string name = null)
         /// If set, the layer will not create a placeholder tensor.
         /// </param>
         /// <returns></returns>
-        public Tensor Input(Shape shape = null,
-            int batch_size = -1,
-            Shape batch_input_shape = null,
-            TF_DataType dtype = TF_DataType.DtInvalid,
-            string name = null,
-            bool sparse = false,
-            bool ragged = false,
-            Tensor tensor = null)
-        {
-            if (batch_input_shape != null)
-                shape = batch_input_shape.dims.Skip(1).ToArray();
-
-            var args = new InputLayerArgs
-            {
-                Name = name,
-                InputShape = shape,
-                BatchInputShape = batch_input_shape,
-                BatchSize = batch_size,
-                DType = dtype,
-                Sparse = sparse,
-                Ragged = ragged,
-                InputTensor = tensor
-            };
-
-            var layer = new InputLayer(args);
-
-            return layer.InboundNodes[0].Outputs;
-        }
+        public Tensors Input(Shape shape = null,
+            int batch_size = -1,
+            string name = null,
+            TF_DataType dtype = TF_DataType.DtInvalid,
+            bool sparse = false,
+            Tensor tensor = null,
+            bool ragged = false,
+            TypeSpec type_spec = null,
+            Shape batch_input_shape = null,
+            Shape batch_shape = null) => keras.layers.Input(shape, batch_size, name,
+                dtype, sparse, tensor, ragged, type_spec, batch_input_shape, batch_shape);
     }
 }
diff --git a/src/TensorFlowNET.Keras/Layers/LayersApi.cs b/src/TensorFlowNET.Keras/Layers/LayersApi.cs
index 0d71b2713..cf689edf1 100644
--- a/src/TensorFlowNET.Keras/Layers/LayersApi.cs
+++ b/src/TensorFlowNET.Keras/Layers/LayersApi.cs
@@ -1,4 +1,5 @@
 using System;
+using Tensorflow.Framework.Models;
 using Tensorflow.Keras.ArgsDefinition;
 using Tensorflow.Keras.ArgsDefinition.Core;
 using Tensorflow.Keras.ArgsDefinition.Rnn;
@@ -471,20 +472,56 @@ public ILayer Flatten(string data_format = null)
         /// In this case, values of 'None' in the 'shape' argument represent ragged dimensions. For more information about RaggedTensors, see this guide.
         /// </param>
         /// <returns>A tensor.</returns>
-        public Tensors Input(Shape shape,
+        public Tensors Input(Shape shape = null,
                              int batch_size = -1,
                              string name = null,
+                             TF_DataType dtype = TF_DataType.DtInvalid,
                              bool sparse = false,
-                             bool ragged = false)
+                             Tensor tensor = null,
+                             bool ragged = false,
+                             TypeSpec type_spec = null,
+                             Shape batch_input_shape = null,
+                             Shape batch_shape = null)
         {
-            var input_layer = new InputLayer(new InputLayerArgs
+            if(sparse && ragged)
+            {
+                throw new ValueError("Cannot set both `sparse` and `ragged` to `true` in a Keras `Input`.");
+            }
+
+            InputLayerArgs input_layer_config = new()
             {
-                InputShape = shape,
-                BatchSize= batch_size,
                 Name = name,
+                DType = dtype,
                 Sparse = sparse,
-                Ragged = ragged
-            });
+                Ragged = ragged,
+                InputTensor = tensor,
+                // skip the `type_spec`
+            };
+
+            if(shape is not null && batch_input_shape is not null)
+            {
+                throw new ValueError("Only provide the `shape` OR `batch_input_shape` argument " + "to Input, not both at the same time.");
+            }
+
+            if(batch_input_shape is null && shape is null && tensor is null && type_spec is null)
+            {
+                throw new ValueError("Please provide to Input a `shape` or a `tensor` or a `type_spec` argument. Note that " + "`shape` does not include the batch dimension.");
+            }
+
+            if(batch_input_shape is not null)
+            {
+                shape = batch_input_shape["1:"];
+                input_layer_config.BatchInputShape = batch_input_shape;
+            }
+            else
+            {
+                input_layer_config.BatchSize = batch_size;
+                input_layer_config.InputShape = shape;
+            }
+
+            var input_layer = new InputLayer(input_layer_config);
             return input_layer.InboundNodes[0].Outputs;
         }
 
diff --git a/test/TensorFlowNET.Keras.UnitTest/Layers/AttentionTest.cs b/test/TensorFlowNET.Keras.UnitTest/Layers/AttentionTest.cs
index 02298ce81..e5987f298 100644
--- a/test/TensorFlowNET.Keras.UnitTest/Layers/AttentionTest.cs
+++ b/test/TensorFlowNET.Keras.UnitTest/Layers/AttentionTest.cs
@@ -158,7 +158,7 @@ public void test_masked_attention()
             var query = keras.Input(shape: (4, 8));
             var value = keras.Input(shape: (2, 8));
             var mask_tensor = keras.Input(shape:(4, 2));
             var attention_layer = keras.layers.MultiHeadAttention(num_heads: 2, key_dim: 2);
-            attention_layer.Apply(new[] { query, value, mask_tensor });
+            attention_layer.Apply(new Tensor[] { query, value, mask_tensor });
             var from_data = 10 * np.random.randn(batch_size, 4, 8);
             var to_data = 10 * np.random.randn(batch_size, 2, 8);
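
Usage sketch (not part of the diff): the snippet below shows how the updated keras.Input overload could be called once this change is in place. The shape values and the "images" tensor name are illustrative assumptions; the error behaviour described in the comments comes from the validation added in LayersApi.Input above.

using Tensorflow;                    // Shape
using static Tensorflow.KerasApi;    // exposes the `keras` entry point

// `shape` excludes the batch dimension; batch_size defaults to -1 (unknown).
var images = keras.Input(shape: (32, 32, 3), name: "images");

// `batch_input_shape` includes the batch dimension; Input derives `shape` from it as batch_input_shape["1:"].
var batched = keras.Input(batch_input_shape: new Shape(8, 32, 32, 3));

// Supplying both `shape` and `batch_input_shape`, or none of `shape`/`tensor`/`type_spec`, now throws a ValueError.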