diff --git a/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs b/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs
index 9fcd0d70f..6b2c38c32 100644
--- a/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs
+++ b/src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs
@@ -1,4 +1,5 @@
using System;
+using Tensorflow.Framework.Models;
using Tensorflow.NumPy;
using static Google.Protobuf.Reflection.FieldDescriptorProto.Types;
@@ -133,11 +134,16 @@ public ILayer EinsumDense(string equation,
public ILayer GlobalMaxPooling1D(string data_format = "channels_last");
public ILayer GlobalMaxPooling2D(string data_format = "channels_last");
- public Tensors Input(Shape shape,
+ public Tensors Input(Shape shape = null,
int batch_size = -1,
string name = null,
+ TF_DataType dtype = TF_DataType.DtInvalid,
bool sparse = false,
- bool ragged = false);
+ Tensor tensor = null,
+ bool ragged = false,
+ TypeSpec type_spec = null,
+ Shape batch_input_shape = null,
+ Shape batch_shape = null);
public ILayer InputLayer(Shape input_shape,
string name = null,
bool sparse = false,
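
Usage sketch for the widened interface signature (illustrative only, not part of the diff; assumes `using static Tensorflow.KerasApi;` and made-up shapes/names): callers can now pin the dtype or supply a full batch-inclusive shape.

    // dtype pinned explicitly; shape excludes the batch dimension.
    var x = keras.layers.Input(shape: new Shape(32), dtype: TF_DataType.TF_FLOAT);
    // batch_input_shape includes the batch dimension (here 8) instead of shape + batch_size.
    var y = keras.layers.Input(batch_input_shape: new Shape(8, 32), name: "batched_input");
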
diff --git a/src/TensorFlowNET.Keras/KerasInterface.cs b/src/TensorFlowNET.Keras/KerasInterface.cs
index e0d148cef..2bde713c0 100644
--- a/src/TensorFlowNET.Keras/KerasInterface.cs
+++ b/src/TensorFlowNET.Keras/KerasInterface.cs
@@ -12,6 +12,7 @@
using Tensorflow.Keras.Optimizers;
using Tensorflow.Keras.Utils;
using System.Threading;
+using Tensorflow.Framework.Models;
namespace Tensorflow.Keras
{
@@ -66,33 +67,16 @@ public Functional Model(Tensors inputs, Tensors outputs, string name = null)
/// If set, the layer will not create a placeholder tensor.
/// </param>
/// <returns></returns>
- public Tensor Input(Shape shape = null,
- int batch_size = -1,
- Shape batch_input_shape = null,
- TF_DataType dtype = TF_DataType.DtInvalid,
- string name = null,
- bool sparse = false,
- bool ragged = false,
- Tensor tensor = null)
- {
- if (batch_input_shape != null)
- shape = batch_input_shape.dims.Skip(1).ToArray();
-
- var args = new InputLayerArgs
- {
- Name = name,
- InputShape = shape,
- BatchInputShape = batch_input_shape,
- BatchSize = batch_size,
- DType = dtype,
- Sparse = sparse,
- Ragged = ragged,
- InputTensor = tensor
- };
-
- var layer = new InputLayer(args);
-
- return layer.InboundNodes[0].Outputs;
- }
+ public Tensors Input(Shape shape = null,
+ int batch_size = -1,
+ string name = null,
+ TF_DataType dtype = TF_DataType.DtInvalid,
+ bool sparse = false,
+ Tensor tensor = null,
+ bool ragged = false,
+ TypeSpec type_spec = null,
+ Shape batch_input_shape = null,
+ Shape batch_shape = null) => keras.layers.Input(shape, batch_size, name,
+ dtype, sparse, tensor, ragged, type_spec, batch_input_shape, batch_shape);
}
}
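
With the body reduced to a forwarding expression, `keras.Input(...)` and `keras.layers.Input(...)` now share one implementation and both return `Tensors`. A minimal sketch of the two entry points feeding a functional model (illustrative only; the layer width and names are assumptions):

    var a = keras.Input(shape: new Shape(16), name: "a");        // forwards to keras.layers.Input
    var b = keras.layers.Input(shape: new Shape(16), name: "b"); // direct call, same behaviour
    var outputs = keras.layers.Dense(4).Apply(a);
    var model = keras.Model(a, outputs);
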
diff --git a/src/TensorFlowNET.Keras/Layers/LayersApi.cs b/src/TensorFlowNET.Keras/Layers/LayersApi.cs
index 0d71b2713..cf689edf1 100644
--- a/src/TensorFlowNET.Keras/Layers/LayersApi.cs
+++ b/src/TensorFlowNET.Keras/Layers/LayersApi.cs
@@ -1,4 +1,5 @@
using System;
+using Tensorflow.Framework.Models;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.ArgsDefinition.Core;
using Tensorflow.Keras.ArgsDefinition.Rnn;
@@ -471,20 +472,56 @@ public ILayer Flatten(string data_format = null)
/// In this case, values of 'None' in the 'shape' argument represent ragged dimensions. For more information about RaggedTensors, see this guide.
/// </param>
/// <returns>A tensor.</returns>
- public Tensors Input(Shape shape,
+ public Tensors Input(Shape shape = null,
int batch_size = -1,
string name = null,
+ TF_DataType dtype = TF_DataType.DtInvalid,
bool sparse = false,
- bool ragged = false)
+ Tensor tensor = null,
+ bool ragged = false,
+ TypeSpec type_spec = null,
+ Shape batch_input_shape = null,
+ Shape batch_shape = null)
{
- var input_layer = new InputLayer(new InputLayerArgs
+ if(sparse && ragged)
+ {
+ throw new ValueError("Cannot set both `sparse` and `ragged` to `true` in a Keras `Input`.");
+ }
+
+ InputLayerArgs input_layer_config = new()
{
- InputShape = shape,
- BatchSize= batch_size,
Name = name,
+ DType = dtype,
Sparse = sparse,
- Ragged = ragged
- });
+ Ragged = ragged,
+ InputTensor = tensor,
+ // `type_spec` is intentionally not forwarded to InputLayerArgs here.
+ };
+
+ if(shape is not null && batch_input_shape is not null)
+ {
+ throw new ValueError("Only provide the `shape` OR `batch_input_shape` argument "
+ + "to Input, not both at the same time.");
+ }
+
+ if(batch_input_shape is null && shape is null && tensor is null && type_spec is null)
+ {
+ throw new ValueError("Please provide to Input a `shape` or a `tensor` or a `type_spec` argument. Note that " +
+ "`shape` does not include the batch dimension.");
+ }
+
+ if(batch_input_shape is not null)
+ {
+ shape = batch_input_shape["1:"];
+ input_layer_config.BatchInputShape = batch_input_shape;
+ }
+ else
+ {
+ input_layer_config.BatchSize = batch_size;
+ input_layer_config.InputShape = shape;
+ }
+
+ var input_layer = new InputLayer(input_layer_config);
return input_layer.InboundNodes[0].Outputs;
}
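
A caller-side sketch of the argument validation added above (illustrative only; shapes are made up, and the commented-out calls are the combinations this change rejects):

    var ok = keras.layers.Input(shape: new Shape(4));                       // valid: shape only
    var fromBatch = keras.layers.Input(batch_input_shape: new Shape(2, 4)); // valid: shape becomes (4) via batch_input_shape["1:"]
    // keras.layers.Input(shape: new Shape(4), sparse: true, ragged: true);         // ValueError: sparse and ragged together
    // keras.layers.Input(shape: new Shape(4), batch_input_shape: new Shape(2, 4)); // ValueError: shape and batch_input_shape together
    // keras.layers.Input();                                                        // ValueError: no shape, tensor or type_spec
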
diff --git a/test/TensorFlowNET.Keras.UnitTest/Layers/AttentionTest.cs b/test/TensorFlowNET.Keras.UnitTest/Layers/AttentionTest.cs
index 02298ce81..e5987f298 100644
--- a/test/TensorFlowNET.Keras.UnitTest/Layers/AttentionTest.cs
+++ b/test/TensorFlowNET.Keras.UnitTest/Layers/AttentionTest.cs
@@ -158,7 +158,7 @@ public void test_masked_attention()
var value = keras.Input(shape: (2, 8));
var mask_tensor = keras.Input(shape:(4, 2));
var attention_layer = keras.layers.MultiHeadAttention(num_heads: 2, key_dim: 2);
- attention_layer.Apply(new[] { query, value, mask_tensor });
+ attention_layer.Apply(new Tensor[] { query, value, mask_tensor });
var from_data = 10 * np.random.randn(batch_size, 4, 8);
var to_data = 10 * np.random.randn(batch_size, 2, 8);
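
The explicit element type matters because `keras.Input` now returns `Tensors`: with `new[]` the array would be inferred as `Tensors[]`, which `Apply` cannot take directly, while `new Tensor[]` leans on the implicit `Tensors`-to-`Tensor` conversion. A minimal sketch of that distinction (illustrative only):

    Tensors q = keras.Input(shape: new Shape(4, 8)); // Input now yields Tensors
    Tensor qt = q;                                   // implicit Tensors -> Tensor conversion
    var bundle = new Tensor[] { qt, qt };            // explicit element type; new[] { q, q } would infer Tensors[]
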