Skip to content

Commit 18d2512

Browse files
committed
fix keras sequential.
1 parent 824308a commit 18d2512

File tree

6 files changed

+87
-40
lines changed

6 files changed

+87
-40
lines changed

src/TensorFlowNET.Core/Binding.Util.cs

Lines changed: 10 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -98,35 +98,23 @@ private static string _tostring(object obj)
9898
default:
9999
return obj?.ToString() ?? "null";
100100
}
101-
102-
object[] toObjectArray(Array arr)
103-
{
104-
var len = arr.LongLength;
105-
var ret = new object[len];
106-
for (long i = 0; i < len; i++)
107-
{
108-
ret[i] = arr.GetValue(i);
109-
}
110-
111-
return ret;
112-
}
113101
}
114102

115-
private static TextWriter writer = null;
103+
private static TextWriter _writer = Console.Out;
116104

117105
public static TextWriter tf_output_redirect {
118106
set
119107
{
120-
var originWriter = writer ?? Console.Out;
121-
originWriter.Flush();
122-
if (originWriter is StringWriter)
123-
(originWriter as StringWriter).GetStringBuilder().Clear();
124-
writer = value;
125-
}
126-
get
127-
{
128-
return writer ?? Console.Out;
108+
if(_writer != null)
109+
{
110+
_writer.Flush();
111+
if (_writer is StringWriter sw)
112+
sw.GetStringBuilder().Clear();
113+
}
114+
115+
_writer = value;
129116
}
117+
get => _writer ?? Console.Out;
130118
}
131119

132120
public static void print(object obj)

src/TensorFlowNET.Core/DisposableObject.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ private void Dispose(bool disposing)
4848
}
4949

5050
// free unmanaged memory
51-
// if (_handle != IntPtr.Zero)
51+
if (_handle != IntPtr.Zero)
5252
{
5353
// Call the appropriate methods to clean up
5454
// unmanaged resources here.

src/TensorFlowNET.Keras/Engine/Layer.Apply.cs

Lines changed: 4 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,32 +14,23 @@ public partial class Layer
1414
/// <returns></returns>
1515
public Tensors Apply(Tensors inputs, Tensor state = null, bool training = false)
1616
{
17-
callContext = callContext?.Value != null ? callContext : new ThreadLocal<CallContext>()
18-
{
19-
Value = new CallContext()
20-
};
17+
if (callContext.Value == null)
18+
callContext.Value = new CallContext();
2119

2220
if (_in_functional_construction_mode(inputs))
2321
return FunctionalConstructionCall(inputs);
2422

25-
Tensors outputs = null;
26-
2723
var eager = tf.executing_eagerly();
2824
using var ctxManager = CallContext.enter(build_graph: false);
2925

30-
string nameScope = "";
31-
if (eager)
32-
nameScope = Name;
33-
else
34-
nameScope = _name_scope();
35-
26+
string nameScope = eager ? name : _name_scope();
3627
var scope = ops.name_scope(nameScope);
3728
scope.__enter__();
3829

3930
if (!built)
4031
MaybeBuild(inputs);
4132

42-
outputs = Call(inputs, state: state, training: training);
33+
var outputs = Call(inputs, state: state, training: training);
4334

4435
// memory leak
4536
// _set_connectivity_metadata_(inputs, outputs);

src/TensorFlowNET.Keras/Engine/Layer.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,11 +84,13 @@ public abstract partial class Layer : AutoTrackable, ILayer
8484
List<INode> outboundNodes;
8585
public List<INode> OutboundNodes => outboundNodes;
8686

87-
ThreadLocal<CallContext> callContext;
87+
ThreadLocal<CallContext> callContext = new ThreadLocal<CallContext>();
8888
public CallContext CallContext => callContext.Value;
8989
public Tensor[] input => inboundNodes[0].input_tensors;
9090
public Dictionary<int, List<INode>> NodesByDepth { get; set; }
9191
public Shape output_shape => inboundNodes[0].Outputs.shape;
92+
protected List<ILayer> _self_tracked_trackables;
93+
9294
public Layer(LayerArgs args)
9395
{
9496
this.args = args;
@@ -106,6 +108,7 @@ public Layer(LayerArgs args)
106108
non_trainable_weights = new List<IVariableV1>();
107109
computePreviousMask = false;
108110
updates = new List<Operation>();
111+
_self_tracked_trackables = new List<ILayer>();
109112

110113
inboundNodes = new List<INode>();
111114
outboundNodes = new List<INode>();

src/TensorFlowNET.Keras/Engine/Sequential.cs

Lines changed: 67 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ You may obtain a copy of the License at
1414
limitations under the License.
1515
******************************************************************************/
1616

17+
using System;
1718
using System.Linq;
1819
using System.Collections.Generic;
1920
using Tensorflow.Keras.ArgsDefinition;
@@ -35,8 +36,9 @@ public class Sequential : Functional
3536
bool _auto_track_sub_layers;
3637
Shape _inferred_input_shape;
3738
bool _has_explicit_input_shape;
38-
39+
bool _graph_initialized;
3940
public Shape output_shape => outputs[0].shape;
41+
List<INode> _created_nodes;
4042

4143
public Sequential(SequentialArgs args)
4244
: base(args.Inputs, args.Outputs, name: args.Name)
@@ -49,12 +51,13 @@ public Sequential(SequentialArgs args)
4951
_auto_track_sub_layers = false;
5052
_has_explicit_input_shape = false;
5153
_is_graph_network = false;
54+
_created_nodes = new List<INode>();
5255

5356
// Add to the model any layers passed to the constructor.
5457
if (args.Layers != null)
5558
{
5659
foreach (var layer in args.Layers)
57-
add(layer as Layer);
60+
add(layer);
5861
}
5962
}
6063

@@ -118,7 +121,69 @@ public void add(ILayer layer)
118121
}
119122
else
120123
{
124+
_self_tracked_trackables.add(layer);
125+
_handle_deferred_layer_dependencies(layer);
126+
}
127+
}
121128

129+
void _handle_deferred_layer_dependencies(params ILayer[] layers)
130+
{
131+
_layers.AddRange(layers);
132+
}
133+
134+
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
135+
{
136+
if (!_has_explicit_input_shape)
137+
{
138+
_build_graph_network_for_inferred_shape(inputs.shape, inputs.dtype);
139+
}
140+
141+
if(_graph_initialized)
142+
{
143+
if (!built)
144+
_init_graph_network(this.inputs, outputs);
145+
return base.Call(inputs, state, training);
146+
}
147+
148+
return base.Call(inputs, state, training);
149+
}
150+
151+
void _build_graph_network_for_inferred_shape(Shape input_shape, TF_DataType input_dtype)
152+
{
153+
ops.init_scope();
154+
var inputs = keras.Input(batch_input_shape: input_shape,
155+
dtype: input_dtype,
156+
name: $"{_layers[0].Name}_input");
157+
Tensors layer_input = inputs;
158+
Tensors layer_output = null;
159+
Tensors outputs = null;
160+
161+
foreach (var layer in _layers)
162+
{
163+
clear_previously_created_nodes(layer, _created_nodes);
164+
layer_output = layer.Apply(layer_input);
165+
// Keep track of nodes just created above
166+
track_nodes_created_by_last_call(layer, _created_nodes);
167+
layer_input = layer_output;
168+
outputs = layer_output;
169+
}
170+
_init_graph_network(inputs, outputs);
171+
_graph_initialized = true;
172+
_inferred_input_shape = input_shape;
173+
}
174+
175+
void clear_previously_created_nodes(ILayer layer, List<INode> created_nodes)
176+
{
177+
178+
}
179+
180+
void track_nodes_created_by_last_call(ILayer layer, List<INode> created_nodes)
181+
{
182+
var node = layer.InboundNodes.Last();
183+
created_nodes.Add(node);
184+
foreach(var prev_layer in node.iterate_inbound())
185+
{
186+
created_nodes.add(prev_layer.Item1.OutboundNodes.Last());
122187
}
123188
}
124189
}

src/TensorFlowNET.Keras/Layers/Core/Dense.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ protected override Tensors Call(Tensors inputs, Tensor state = null, bool? train
7171
var rank = inputs.rank;
7272
if (rank > 2)
7373
{
74-
throw new NotImplementedException("call rank > 2");
74+
outputs = tf.linalg.tensordot(inputs, kernel.AsTensor(), new[,] { { rank - 1 }, { 0 } });
7575
}
7676
else
7777
{

0 commit comments

Comments
 (0)