Skip to content

Commit d94c685

Browse files
committed
Add node to connectivity between two layers.
1 parent 59ee7ef commit d94c685

File tree

11 files changed

+141
-19
lines changed

11 files changed

+141
-19
lines changed

src/TensorFlowNET.Core/APIs/keras.layers.cs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ public static Embedding Embedding(int input_dim, int output_dim,
1818
embeddings_initializer,
1919
mask_zero);
2020

21-
public static InputLayer Input(int[] batch_shape = null,
21+
public static Tensor[] Input(int[] batch_shape = null,
2222
TF_DataType dtype = TF_DataType.DtInvalid,
2323
string name = null,
2424
bool sparse = false,
@@ -35,7 +35,9 @@ public static InputLayer Input(int[] batch_shape = null,
3535
sparse: sparse,
3636
input_tensor: tensor);
3737

38-
throw new NotImplementedException("");
38+
var outputs = input_layer.inbound_nodes[0].output_tensors;
39+
40+
return outputs;
3941
}
4042
}
4143
}

src/TensorFlowNET.Core/Keras/Engine/Sequential.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,12 @@ public void add(Layer layer)
3333
batch_shape: batch_shape,
3434
dtype: dtype,
3535
name: layer._name + "_input");
36+
37+
// This will build the current layer
38+
// and create the node connecting the current layer
39+
// to the input layer we just created.
40+
layer.__call__(x);
41+
set_inputs = true;
3642
}
3743
}
3844
}

src/TensorFlowNET.Core/Keras/Layers/Embedding.cs

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@ public class Embedding : Layer
99
private int input_dim;
1010
private int output_dim;
1111
private bool mask_zero;
12+
public RefVariable embeddings;
13+
public IInitializer embeddings_initializer;
1214

1315
public Embedding(int input_dim, int output_dim,
1416
IInitializer embeddings_initializer = null,
@@ -18,10 +20,17 @@ public Embedding(int input_dim, int output_dim,
1820
{
1921
this.input_dim = input_dim;
2022
this.output_dim = output_dim;
21-
if (embeddings_initializer == null)
22-
embeddings_initializer = tf.uniform_initializer;
23+
this.embeddings_initializer = embeddings_initializer == null ? tf.uniform_initializer : embeddings_initializer;
2324
this.mask_zero = mask_zero;
2425
supports_masking = mask_zero;
2526
}
27+
28+
protected override void build(TensorShape input_shape)
29+
{
30+
embeddings = add_weight(shape: new int[] { input_dim, output_dim },
31+
initializer: embeddings_initializer,
32+
name: "embeddings");
33+
built = true;
34+
}
2635
}
2736
}

src/TensorFlowNET.Core/Keras/Layers/InputLayer.cs

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ public class InputLayer : Layer
1111
{
1212
public bool sparse;
1313
public int? batch_size;
14+
public bool is_placeholder;
1415

1516
public InputLayer(int[] input_shape = null,
1617
int? batch_size = null,
@@ -24,7 +25,7 @@ public InputLayer(int[] input_shape = null,
2425
this.batch_size = batch_size;
2526
this.supports_masking = true;
2627

27-
if(input_tensor == null)
28+
if (input_tensor == null)
2829
{
2930
var batch_input_shape = new int[] { batch_size.HasValue ? batch_size.Value : -1, -1 };
3031

@@ -39,7 +40,17 @@ public InputLayer(int[] input_shape = null,
3940
dtype: dtype,
4041
name: name);
4142
}
43+
44+
is_placeholder = true;
45+
_batch_input_shape = batch_input_shape;
4246
}
47+
48+
new Node(this,
49+
inbound_layers: new Layer[0],
50+
node_indices: new int[0],
51+
tensor_indices: new int[0],
52+
input_tensors: new Tensor[] { input_tensor },
53+
output_tensors: new Tensor[] { input_tensor });
4354
}
4455
}
4556
}

src/TensorFlowNET.Core/Keras/Layers/Layer.cs

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,12 @@ public class Layer : CheckpointableBase
3939
protected List<Operation> _updates;
4040
public int[] _batch_input_shape;
4141

42+
private List<Node> _inbound_nodes;
43+
public List<Node> inbound_nodes => _inbound_nodes;
44+
45+
private List<Node> _outbound_nodes;
46+
public List<Node> outbound_nodes => _outbound_nodes;
47+
4248
public Layer(bool trainable = true,
4349
string name = null,
4450
TF_DataType dtype = TF_DataType.DtInvalid,
@@ -59,13 +65,15 @@ public Layer(bool trainable = true,
5965
_batch_input_shape = new int[] { -1, -1 };
6066

6167
_dtype = dtype;
68+
69+
_inbound_nodes = new List<Node>();
6270
}
6371

64-
public Tensor __call__(Tensor inputs,
72+
public Tensor __call__(Tensor[] inputs,
6573
Tensor training = null,
6674
VariableScope scope = null)
6775
{
68-
var input_list = new Tensor[] { inputs };
76+
var input_list = inputs;
6977
Tensor outputs = null;
7078

7179
// We will attempt to build a TF graph if & only if all inputs are symbolic.
@@ -88,9 +96,9 @@ public Tensor __call__(Tensor inputs,
8896
// Symbolic execution on symbolic tensors. We will attempt to build
8997
// the corresponding TF subgraph inside `backend.get_graph()`
9098
var graph = backend.get_graph();
91-
outputs = call(inputs, training: training);
92-
_handle_activity_regularization(inputs, outputs);
93-
_set_mask_metadata(inputs, outputs, null);
99+
outputs = call(inputs[0], training: training);
100+
_handle_activity_regularization(inputs[0], outputs);
101+
_set_mask_metadata(inputs[0], outputs, null);
94102
}
95103
});
96104

@@ -125,10 +133,10 @@ protected virtual string _name_scope()
125133
return null;
126134
}
127135

128-
protected void _maybe_build(Tensor inputs)
136+
protected void _maybe_build(Tensor[] inputs)
129137
{
130-
var input_list = new Tensor[] { inputs };
131-
build(inputs.getShape());
138+
var input_list = inputs;
139+
build(input_list[0].getShape());
132140
}
133141

134142
protected virtual void build(TensorShape input_shape)
@@ -143,10 +151,16 @@ protected virtual RefVariable add_weight(string name,
143151
bool? trainable = null,
144152
Func<string, int[], TF_DataType, IInitializer, bool, RefVariable> getter = null)
145153
{
154+
if (dtype == TF_DataType.DtInvalid)
155+
dtype = TF_DataType.TF_FLOAT;
156+
157+
if (trainable == null)
158+
trainable = true;
159+
146160
var variable = _add_variable_with_custom_getter(name,
147161
shape,
148162
dtype: dtype,
149-
getter: getter,
163+
getter: getter == null ? base_layer_utils.make_variable : getter,
150164
overwrite: true,
151165
initializer: initializer,
152166
trainable: trainable.Value);
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
using System.Text;
5+
6+
namespace Tensorflow.Keras.Layers
7+
{
8+
/// <summary>
9+
/// A `Node` describes the connectivity between two layers.
10+
/// </summary>
11+
public class Node
12+
{
13+
public InputLayer outbound_layer;
14+
public Layer[] inbound_layers;
15+
public int[] node_indices;
16+
public int[] tensor_indices;
17+
public Tensor[] input_tensors;
18+
public Tensor[] output_tensors;
19+
public int[][] input_shapes;
20+
public int[][] output_shapes;
21+
22+
/// <summary>
23+
///
24+
/// </summary>
25+
/// <param name="outbound_layer">
26+
/// the layer that takes
27+
/// `input_tensors` and turns them into `output_tensors`
28+
/// (the node gets created when the `call`
29+
/// method of the layer was called).
30+
/// </param>
31+
/// <param name="inbound_layers">
32+
/// a list of layers, the same length as `input_tensors`,
33+
/// the layers from where `input_tensors` originate.
34+
/// </param>
35+
/// <param name="node_indices">
36+
/// a list of integers, the same length as `inbound_layers`.
37+
/// `node_indices[i]` is the origin node of `input_tensors[i]`
38+
/// (necessary since each inbound layer might have several nodes,
39+
/// e.g. if the layer is being shared with a different data stream).
40+
/// </param>
41+
/// <param name="tensor_indices"></param>
42+
/// <param name="input_tensors">list of input tensors.</param>
43+
/// <param name="output_tensors">list of output tensors.</param>
44+
public Node(InputLayer outbound_layer,
45+
Layer[] inbound_layers,
46+
int[] node_indices,
47+
int[] tensor_indices,
48+
Tensor[] input_tensors,
49+
Tensor[] output_tensors)
50+
{
51+
this.outbound_layer = outbound_layer;
52+
this.inbound_layers = inbound_layers;
53+
this.node_indices = node_indices;
54+
this.tensor_indices = tensor_indices;
55+
this.input_tensors = input_tensors;
56+
this.output_tensors = output_tensors;
57+
58+
input_shapes = input_tensors.Select(x => x._shape_tuple()).ToArray();
59+
output_shapes = output_tensors.Select(x => x._shape_tuple()).ToArray();
60+
61+
// Add nodes to all layers involved.
62+
foreach (var layer in inbound_layers)
63+
{
64+
if (layer != null)
65+
layer.outbound_nodes.Add(this);
66+
}
67+
68+
outbound_layer.inbound_nodes.Add(this);
69+
}
70+
}
71+
}

src/TensorFlowNET.Core/Keras/Engine/base_layer_utils.cs renamed to src/TensorFlowNET.Core/Keras/Utils/base_layer_utils.cs

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,19 @@
22
using System.Collections.Generic;
33
using System.Text;
44

5-
namespace Tensorflow.Keras.Engine
5+
namespace Tensorflow.Keras.Utils
66
{
77
public class base_layer_utils
88
{
9+
public static RefVariable make_variable(string name,
10+
int[] shape,
11+
TF_DataType dtype = TF_DataType.TF_FLOAT,
12+
IInitializer initializer = null,
13+
bool trainable = false)
14+
{
15+
throw new NotImplementedException("");
16+
}
17+
918
/// <summary>
1019
/// Makes a layer name (or arbitrary string) unique within a TensorFlow graph.
1120
/// </summary>

src/TensorFlowNET.Core/Layers/Layer.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ public Tensor __call__(Tensor inputs,
5252

5353
Python.with(scope_context_manager, scope2 => _current_scope = scope2);
5454
// Actually call layer
55-
var outputs = base.__call__(inputs, training: training);
55+
var outputs = base.__call__(new Tensor[] { inputs }, training: training);
5656

5757
// Update global default collections.
5858
_add_elements_to_collection(_updates.ToArray(), new string[] { ops.GraphKeys.UPDATE_OPS });

src/TensorFlowNET.Core/TensorFlowNET.Core.csproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ Docs: https://tensorflownet.readthedocs.io</Description>
4343

4444
<ItemGroup>
4545
<PackageReference Include="Google.Protobuf" Version="3.7.0" />
46-
<PackageReference Include="NumSharp" Version="0.7.4" />
46+
<PackageReference Include="NumSharp" Version="0.8.0" />
4747
</ItemGroup>
4848

4949
<ItemGroup>

test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
</PropertyGroup>
77

88
<ItemGroup>
9-
<PackageReference Include="NumSharp" Version="0.7.4" />
9+
<PackageReference Include="NumSharp" Version="0.8.0" />
1010
<PackageReference Include="SharpZipLib" Version="1.1.0" />
1111
<PackageReference Include="TensorFlow.NET" Version="0.4.2" />
1212
</ItemGroup>

test/TensorFlowNET.UnitTest/TensorFlowNET.UnitTest.csproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
<PackageReference Include="Microsoft.NET.Test.Sdk" Version="16.0.1" />
2020
<PackageReference Include="MSTest.TestAdapter" Version="1.4.0" />
2121
<PackageReference Include="MSTest.TestFramework" Version="1.4.0" />
22-
<PackageReference Include="NumSharp" Version="0.7.4" />
22+
<PackageReference Include="NumSharp" Version="0.8.0" />
2323
<PackageReference Include="TensorFlow.NET" Version="0.4.2" />
2424
</ItemGroup>
2525

0 commit comments

Comments
 (0)