Skip to content

Commit 321ddfc

Browse files
committed
Fix Model.build.
1 parent 0f7bf4d commit 321ddfc

File tree

15 files changed

+104
-48
lines changed

15 files changed

+104
-48
lines changed

src/TensorFlowNET.Console/SimpleRnnTest.cs

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,20 +12,16 @@ public class SimpleRnnTest
1212
{
1313
public void Run()
1414
{
15-
tf.keras = new KerasInterface();
16-
var inputs = np.random.random((32, 10, 8)).astype(np.float32);
17-
var simple_rnn = tf.keras.layers.SimpleRNN(4);
18-
var output = simple_rnn.Apply(inputs); // The output has shape `[32, 4]`.
19-
if (output.shape == (32, 4))
20-
{
15+
tf.UseKeras<KerasInterface>();
16+
var inputs = np.random.random((6, 10, 8)).astype(np.float32);
17+
//var simple_rnn = tf.keras.layers.SimpleRNN(4);
18+
//var output = simple_rnn.Apply(inputs); // The output has shape `[32, 4]`.
2119

22-
}
23-
/*simple_rnn = tf.keras.layers.SimpleRNN(
24-
4, return_sequences = True, return_state = True)
20+
var simple_rnn = tf.keras.layers.SimpleRNN(4, return_sequences: true, return_state: true);
2521

26-
# whole_sequence_output has shape `[32, 10, 4]`.
27-
# final_state has shape `[32, 4]`.
28-
whole_sequence_output, final_state = simple_rnn(inputs)*/
22+
// whole_sequence_output has shape `[32, 10, 4]`.
23+
// final_state has shape `[32, 4]`.
24+
var (whole_sequence_output, final_state) = simple_rnn.Apply(inputs);
2925
}
3026
}
3127
}

src/TensorFlowNET.Core/Keras/Layers/ILayer.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ public interface ILayer
99
string Name { get; }
1010
bool Trainable { get; }
1111
bool Built { get; }
12+
void build(Shape input_shape);
1213
List<ILayer> Layers { get; }
1314
List<INode> InboundNodes { get; }
1415
List<INode> OutboundNodes { get; }

src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -163,7 +163,9 @@ public ILayer SimpleRNN(int units,
163163
string activation = "tanh",
164164
string kernel_initializer = "glorot_uniform",
165165
string recurrent_initializer = "orthogonal",
166-
string bias_initializer = "zeros");
166+
string bias_initializer = "zeros",
167+
bool return_sequences = false,
168+
bool return_state = false);
167169

168170
public ILayer Subtract();
169171
}
Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,32 @@
11
using System;
2+
using System.Linq;
3+
using static Tensorflow.TensorShapeProto.Types;
24

35
namespace Tensorflow.Operations.Initializers
46
{
57
public class Orthogonal : IInitializer
68
{
9+
float _gain = 0f;
10+
11+
public Orthogonal(float gain = 1.0f, int? seed = null)
12+
{
13+
14+
}
15+
716
public Tensor Apply(InitializerArgs args)
817
{
9-
throw new NotImplementedException();
18+
return _generate_init_val(args.Shape, args.DType);
19+
}
20+
21+
private Tensor _generate_init_val(Shape shape, TF_DataType dtype)
22+
{
23+
var num_rows = 1L;
24+
foreach (var dim in shape.dims.Take(shape.ndim - 1))
25+
num_rows *= dim;
26+
var num_cols = shape.dims.Last();
27+
var flat_shape = (Math.Max(num_cols, num_rows), Math.Min(num_cols, num_rows));
28+
29+
throw new NotImplementedException("");
1030
}
1131
}
1232
}

src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,5 +147,10 @@ public LayerArgs get_config()
147147
{
148148
throw new NotImplementedException();
149149
}
150+
151+
public void build(Shape input_shape)
152+
{
153+
throw new NotImplementedException();
154+
}
150155
}
151156
}

src/TensorFlowNET.Core/tensorflow.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,11 @@ public tensorflow()
6565
InitGradientEnvironment();
6666
}
6767

68+
public void UseKeras<T>() where T : IKerasApi, new()
69+
{
70+
keras = new T();
71+
}
72+
6873
public string VERSION => c_api.StringPiece(c_api.TF_Version());
6974

7075
private void InitGradientEnvironment()

src/TensorFlowNET.Keras/Engine/Functional.cs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,12 @@ protected void _init_graph_network(Tensors inputs, Tensors outputs)
6565
}
6666

6767
// Keep track of the network's nodes and layers.
68-
(NetworkNodes, NodesByDepth, _self_tracked_trackables, _) = MapGraphNetwork(inputs, outputs);
68+
(NetworkNodes, NodesByDepth, var layers, _) = MapGraphNetwork(inputs, outputs);
69+
70+
if (!_self_tracked_trackables.Any())
71+
{
72+
_self_tracked_trackables = layers;
73+
}
6974

7075
// Build self.input_names and self.output_names.
7176
_set_output_names();

src/TensorFlowNET.Keras/Engine/Model.Build.cs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
11
using System;
22
using System.Linq;
33
using Tensorflow.Graphs;
4-
using Tensorflow.Keras.ArgsDefinition;
5-
using Tensorflow.Keras.Losses;
6-
using Tensorflow.Keras.Optimizers;
74
using static Tensorflow.Binding;
85
using static Tensorflow.KerasApi;
96

@@ -13,6 +10,12 @@ public partial class Model
1310
{
1411
public override void build(Shape input_shape)
1512
{
13+
if (this is Functional || this is Sequential)
14+
{
15+
base.build(input_shape);
16+
return;
17+
}
18+
1619
var graph = tf.executing_eagerly() ? new FuncGraph("build_graph") : keras.backend.get_graph();
1720

1821
graph.as_default();

src/TensorFlowNET.Keras/Engine/Sequential.cs

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -122,15 +122,9 @@ public void add(ILayer layer)
122122
else
123123
{
124124
_self_tracked_trackables.add(layer);
125-
_handle_deferred_layer_dependencies(layer);
126125
}
127126
}
128127

129-
void _handle_deferred_layer_dependencies(params ILayer[] layers)
130-
{
131-
_self_tracked_trackables.AddRange(layers);
132-
}
133-
134128
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
135129
{
136130
if (!_has_explicit_input_shape)
@@ -156,7 +150,7 @@ void _build_graph_network_for_inferred_shape(Shape input_shape, TF_DataType inpu
156150
ops.init_scope();
157151
var inputs = keras.Input(batch_input_shape: input_shape,
158152
dtype: input_dtype,
159-
name: $"{_self_tracked_trackables[0].Name}_input");
153+
name: _self_tracked_trackables[0].Name.EndsWith("_input") ? _self_tracked_trackables[0].Name : $"{_self_tracked_trackables[0].Name}_input");
160154
Tensors layer_input = inputs;
161155
Tensors layer_output = null;
162156
Tensors outputs = null;

src/TensorFlowNET.Keras/Layers/LayersApi.cs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -658,14 +658,18 @@ public ILayer SimpleRNN(int units,
658658
string activation = "tanh",
659659
string kernel_initializer = "glorot_uniform",
660660
string recurrent_initializer = "orthogonal",
661-
string bias_initializer = "zeros")
661+
string bias_initializer = "zeros",
662+
bool return_sequences = false,
663+
bool return_state = false)
662664
=> new SimpleRNN(new SimpleRNNArgs
663665
{
664666
Units = units,
665667
Activation = GetActivationByName(activation),
666668
KernelInitializer = GetInitializerByName(kernel_initializer),
667-
RecurrentInitializer= GetInitializerByName(recurrent_initializer),
668-
BiasInitializer= GetInitializerByName(bias_initializer)
669+
RecurrentInitializer = GetInitializerByName(recurrent_initializer),
670+
BiasInitializer = GetInitializerByName(bias_initializer),
671+
ReturnSequences = return_sequences,
672+
ReturnState = return_state
669673
});
670674

671675
/// <summary>

src/TensorFlowNET.Keras/Layers/Rnn/RNN.cs

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ public class RNN : Layer
1818
private int _num_constants = 0;
1919
protected IVariableV1 kernel;
2020
protected IVariableV1 bias;
21-
21+
protected ILayer cell;
2222
public RNN(RNNArgs args) : base(PreConstruct(args))
2323
{
2424
this.args = args;
@@ -37,6 +37,14 @@ public RNN(RNNArgs args) : base(PreConstruct(args))
3737
//}
3838
}
3939

40+
public override void build(Shape input_shape)
41+
{
42+
if (!cell.Built)
43+
{
44+
cell.build(input_shape);
45+
}
46+
}
47+
4048
private static RNNArgs PreConstruct(RNNArgs args)
4149
{
4250
if (args.Kwargs == null)

src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNN.cs

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,22 +9,10 @@ namespace Tensorflow.Keras.Layers.Rnn
99
public class SimpleRNN : RNN
1010
{
1111
SimpleRNNArgs args;
12-
SimpleRNNCell cell;
1312
public SimpleRNN(SimpleRNNArgs args) : base(args)
1413
{
1514
this.args = args;
16-
}
17-
18-
public override void build(Shape input_shape)
19-
{
20-
var input_dim = input_shape[-1];
21-
22-
kernel = add_weight("kernel", (input_shape[-1], args.Units),
23-
initializer: args.KernelInitializer
24-
//regularizer = self.kernel_regularizer,
25-
//constraint = self.kernel_constraint,
26-
//caching_device = default_caching_device,
27-
);
15+
cell = new SimpleRNNCell(args);
2816
}
2917
}
3018
}

src/TensorFlowNET.Keras/Layers/Rnn/SimpleRNNCell.cs

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,36 @@ namespace Tensorflow.Keras.Layers.Rnn
88
{
99
public class SimpleRNNCell : Layer
1010
{
11+
SimpleRNNArgs args;
12+
IVariableV1 kernel;
13+
IVariableV1 recurrent_kernel;
14+
IVariableV1 bias;
15+
1116
public SimpleRNNCell(SimpleRNNArgs args) : base(args)
1217
{
13-
18+
this.args = args;
1419
}
1520

1621
public override void build(Shape input_shape)
1722
{
18-
23+
var input_dim = input_shape[-1];
24+
25+
kernel = add_weight("kernel", (input_shape[-1], args.Units),
26+
initializer: args.KernelInitializer
27+
);
28+
29+
recurrent_kernel = add_weight("recurrent_kernel", (args.Units, args.Units),
30+
initializer: args.RecurrentInitializer
31+
);
32+
33+
if (args.UseBias)
34+
{
35+
bias = add_weight("bias", (args.Units),
36+
initializer: args.RecurrentInitializer
37+
);
38+
}
39+
40+
built = true;
1941
}
2042
}
2143
}

test/TensorFlowNET.Keras.UnitTest/Layers/LayersTest.cs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -150,10 +150,13 @@ public void EinsumDense()
150150
[TestMethod]
151151
public void SimpleRNN()
152152
{
153-
var inputs = np.random.random((32, 10, 8)).astype(np.float32);
154-
var simple_rnn = keras.layers.SimpleRNN(4);
153+
tf.UseKeras<KerasInterface>();
154+
var inputs = np.arange(6 * 10 * 8).reshape((6, 10, 8)).astype(np.float32);
155+
/*var simple_rnn = keras.layers.SimpleRNN(4);
155156
var output = simple_rnn.Apply(inputs);
156-
Assert.AreEqual((32, 4), output.shape);
157+
Assert.AreEqual((32, 4), output.shape);*/
158+
var simple_rnn = tf.keras.layers.SimpleRNN(4, return_sequences: true, return_state: true);
159+
var (whole_sequence_output, final_state) = simple_rnn.Apply(inputs);
157160
}
158161

159162
[TestMethod]

test/TensorFlowNET.UnitTest/Tensorflow.Binding.UnitTest.csproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747

4848
<ItemGroup>
4949
<PackageReference Include="FluentAssertions" Version="5.10.3" />
50-
<PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.144" />
50+
<PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.148" />
5151
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="17.0.0" />
5252
<PackageReference Include="MSTest.TestAdapter" Version="2.2.8" />
5353
<PackageReference Include="MSTest.TestFramework" Version="2.2.8" />

0 commit comments

Comments
 (0)