Skip to content

Commit 6d66092

Browse files
committed
Allow ComputeOutputShape to override in subclass #660.
1 parent d75366c commit 6d66092

File tree

2 files changed

+7
-3
lines changed

2 files changed

+7
-3
lines changed

src/TensorFlowNET.Keras/Engine/Layer.Layers.cs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
using System.Collections.Generic;
1+
using System;
2+
using System.Collections.Generic;
23

34
namespace Tensorflow.Keras.Engine
45
{
@@ -11,5 +12,8 @@ protected void StackLayers(params ILayer[] layers)
1112
{
1213
_layers.AddRange(layers);
1314
}
15+
16+
public virtual TensorShape ComputeOutputShape(TensorShape input_shape)
17+
=> throw new NotImplementedException("");
1418
}
1519
}

src/TensorFlowNET.Keras/Layers/Reshaping/Reshape.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,11 @@ protected override Tensors Call(Tensors inputs, Tensor state = null, bool is_tra
2626

2727
var result = array_ops.reshape(inputs, shape.ToArray());
2828
if (!tf.Context.executing_eagerly())
29-
result.set_shape(compute_output_shape(inputs.shape));
29+
result.set_shape(ComputeOutputShape(inputs.shape));
3030
return result;
3131
}
3232

33-
TensorShape compute_output_shape(TensorShape input_shape)
33+
public override TensorShape ComputeOutputShape(TensorShape input_shape)
3434
{
3535
if (input_shape.dims[0] == -1)
3636
{

0 commit comments

Comments
 (0)