
Commit a0df810

fix: training LSTM does not align with tensorflow.
1 parent 675b93a commit a0df810

14 files changed (+68, -37 lines)

src/TensorFlowNET.Core/Binding.Util.cs

Lines changed: 1 addition & 1 deletion
@@ -503,7 +503,7 @@ public static TF_DataType GetDataType(this object data)
             case Tensors tensors:
                 return tensors.dtype;
             case IEnumerable<Tensor> tensors:
-                return tensors.First().dtype;
+                return tensors.Where(x => x is not null).First().dtype;
             case RefVariable variable:
                 return variable.dtype;
             case ResourceVariable variable:
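The null filter matters when a tensor collection contains empty slots (for example, state lists that have not been materialized yet); a minimal sketch of the difference, using an illustrative list:

using System.Collections.Generic;
using System.Linq;
using Tensorflow;
using static Tensorflow.Binding;

// Illustrative input: a tensor list whose first slot is still null.
IEnumerable<Tensor> tensors = new List<Tensor> { null, tf.constant(1.0f) };

// Before this commit: tensors.First() returns null, so reading .dtype throws NullReferenceException.
// After: null entries are skipped and the dtype of the first real tensor is used.
var dtype = tensors.Where(x => x is not null).First().dtype;   // TF_DataType.TF_FLOAT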

src/TensorFlowNET.Core/Eager/EagerRunner.TFE_TapeGradient.cs

Lines changed: 1 addition & 1 deletion
@@ -65,7 +65,7 @@ public Tensor[] TFE_TapeGradient(ITape tape,
             {
                 outgrad_vec = output_gradients.ToList();
             }
-            var result = tape.ComputeGradient(target_vec, sources_vec, source_tensors_that_are_targets, outgrad_vec, false);
+            var result = tape.ComputeGradient(target_vec, sources_vec, source_tensors_that_are_targets, outgrad_vec, true);


             bool unconnected_gradients_zero = unconnected_gradients == "zero";
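This is the code path behind eager-mode gradients; a minimal sketch of the public API that ends up here (standard TensorFlow.NET GradientTape usage, values illustrative):

using Tensorflow;
using static Tensorflow.Binding;

var x = tf.Variable(3.0f);
using var tape = tf.GradientTape();
var y = tf.square(x);
// TFE_TapeGradient runs under the hood when the tape resolves dy/dx.
var grad = tape.gradient(y, x);   // expected value: 2 * x = 6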

src/TensorFlowNET.Core/Eager/EagerTensor.ToString.cs

Lines changed: 6 additions & 1 deletion
@@ -10,6 +10,11 @@ public override string ToString()
             var str = NDArrayRender.ToString(nd);
             return $"tf.Tensor: shape={shape}, dtype={dtype.as_numpy_name()}, numpy={str}";
         }
-
+        public string ToString(int maxLength)
+        {
+            var nd = new NDArray(this);
+            var str = NDArrayRender.ToString(nd, maxLength);
+            return $"tf.Tensor: shape={shape}, dtype={dtype.as_numpy_name()}, numpy={str}";
+        }
     }
 }
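A small usage sketch of the new overload (assuming eager mode, where tf.constant yields an EagerTensor at runtime; the cast is only for illustration):

using System;
using Tensorflow;
using Tensorflow.Eager;
using static Tensorflow.Binding;

var t = (EagerTensor)tf.constant(new float[100]);
Console.WriteLine(t.ToString());    // default rendering caps each axis at 10 elements
Console.WriteLine(t.ToString(4));   // tighter cap: roughly 2 leading and 2 trailing elements per axis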

src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/LSTMCellArgs.cs

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@ public class LSTMCellArgs : AutoSerializeLayerArgs
         [JsonProperty("unit_forget_bias")]
         public bool UnitForgetBias { get; set; } = true;
         [JsonProperty("implementation")]
-        public int Implementation { get; set; } = 1;
+        public int Implementation { get; set; } = 2;

     }
 }

src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs

Lines changed: 1 addition & 1 deletion
@@ -182,7 +182,7 @@ public ILayer LSTM(int units,
             bool unit_forget_bias = true,
             float dropout = 0f,
             float recurrent_dropout = 0f,
-            int implementation = 1,
+            int implementation = 2,
             bool return_sequences = false,
             bool return_state = false,
             bool go_backwards = false,

src/TensorFlowNET.Core/NumPy/NDArrayRender.cs

Lines changed: 9 additions & 9 deletions
@@ -7,20 +7,20 @@ namespace Tensorflow.NumPy
 {
     public class NDArrayRender
     {
-        public static string ToString(NDArray array)
+        public static string ToString(NDArray array, int maxLength = 10)
         {
             Shape shape = array.shape;
             if (shape.IsScalar)
                 return Render(array);

             var s = new StringBuilder();
             s.Append("array(");
-            Build(s, array);
+            Build(s, array, maxLength);
             s.Append(")");
             return s.ToString();
         }

-        static void Build(StringBuilder s, NDArray array)
+        static void Build(StringBuilder s, NDArray array, int maxLength)
         {
             var shape = array.shape;

@@ -35,11 +35,11 @@ static void Build(StringBuilder s, NDArray array)
             var len = shape[0];
             s.Append("[");

-            if (len <= 10)
+            if (len <= maxLength)
             {
                 for (int i = 0; i < len; i++)
                 {
-                    Build(s, array[i]);
+                    Build(s, array[i], maxLength);
                     if (i < len - 1)
                     {
                         s.Append(", ");
@@ -49,9 +49,9 @@ static void Build(StringBuilder s, NDArray array)
             }
             else
             {
-                for (int i = 0; i < 5; i++)
+                for (int i = 0; i < maxLength / 2; i++)
                 {
-                    Build(s, array[i]);
+                    Build(s, array[i], maxLength);
                     if (i < len - 1)
                     {
                         s.Append(", ");
@@ -62,9 +62,9 @@ static void Build(StringBuilder s, NDArray array)
                 s.Append(" ... ");
                 s.AppendLine();

-                for (int i = (int)len - 5; i < len; i++)
+                for (int i = (int)len - maxLength / 2; i < len; i++)
                 {
-                    Build(s, array[i]);
+                    Build(s, array[i], maxLength);
                     if (i < len - 1)
                     {
                         s.Append(", ");
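A minimal sketch of what the new maxLength parameter changes (np.arange is only used here to get a long 1-D array):

using System;
using Tensorflow.NumPy;

var nd = np.arange(100);
// Default cap of 10: longer axes are elided in the middle with " ... ".
Console.WriteLine(NDArrayRender.ToString(nd));
// Custom cap of 6: roughly maxLength / 2 = 3 leading and 3 trailing elements are printed per axis.
Console.WriteLine(NDArrayRender.ToString(nd, maxLength: 6));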
NpyLoadInitializer.cs (new file)

Lines changed: 22 additions & 0 deletions

@@ -0,0 +1,22 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Tensorflow.NumPy;
+
+namespace Tensorflow.Operations.Initializers
+{
+    /// <summary>
+    /// An initializer specially used for debugging (to load weights from disk).
+    /// </summary>
+    class NpyLoadInitializer : IInitializer
+    {
+        string _path;
+        public NpyLoadInitializer(string path) { _path = path; }
+        public string ClassName => "";
+        public IDictionary<string, object> Config => new Dictionary<string, object>();
+        public Tensor Apply(InitializerArgs args)
+        {
+            return np.load(_path);
+        }
+    }
+}
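A sketch of how this debug-only initializer might be wired in (the path and the commented layer code are illustrative; the class has no access modifier, so it is internal to the library):

using Tensorflow;
using Tensorflow.Operations.Initializers;

// Hypothetical scenario: seed a layer with weights dumped from the Python side via np.save,
// so the C# and TensorFlow (Python) runs start from identical parameters while comparing
// training behaviour. The initializer ignores the requested shape and returns np.load(_path).
IInitializer kernelInit = new NpyLoadInitializer("debug/lstm_kernel.npy");
// Inside a layer's build(), it would be passed wherever an initializer is expected, e.g.:
// _kernel = add_weight("kernel", (input_dim, units * 4), initializer: kernelInit);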

src/TensorFlowNET.Core/Tensorflow.Binding.csproj

Lines changed: 1 addition & 1 deletion
@@ -111,7 +111,7 @@ https://tensorflownet.readthedocs.io</Description>
     <PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.148" />
     <PackageReference Include="Newtonsoft.Json" Version="13.0.3" />
     <PackageReference Include="OneOf" Version="3.0.223" />
-    <PackageReference Include="Protobuf.Text" Version="0.7.0" />
+    <PackageReference Include="Protobuf.Text" Version="0.7.1" />
     <PackageReference Include="Serilog.Sinks.Console" Version="4.1.0" />
   </ItemGroup>

src/TensorFlowNET.Core/Training/Trackable.cs

Lines changed: 1 addition & 2 deletions
@@ -179,8 +179,7 @@ protected virtual IVariableV1 _add_variable_with_custom_getter(VariableArgs args
             // handles slot variables.
             if (!args.Overwrite || new_variable is RefVariable || new_variable is Trackable)
             {
-                var temp = new_variable as Trackable;
-                var res = _track_trackable(temp, args.Name, args.Overwrite);
+                var res = _track_trackable(new_variable as Trackable, args.Name, args.Overwrite);
                 Debug.Assert(res is IVariableV1);
                 return res as IVariableV1;
             }

src/TensorFlowNET.Keras/Layers/LayersApi.cs

Lines changed: 4 additions & 3 deletions
@@ -793,7 +793,7 @@ public IRnnCell LSTMCell(int uints,
             bool unit_forget_bias = true,
             float dropout = 0f,
             float recurrent_dropout = 0f,
-            int implementation = 1)
+            int implementation = 2)
             => new LSTMCell(new LSTMCellArgs
             {
                 Units = uints,
@@ -846,7 +846,7 @@ public ILayer LSTM(int units,
             bool unit_forget_bias = true,
             float dropout = 0f,
             float recurrent_dropout = 0f,
-            int implementation = 1,
+            int implementation = 2,
             bool return_sequences = false,
             bool return_state = false,
             bool go_backwards = false,
@@ -869,7 +869,8 @@ public ILayer LSTM(int units,
                 GoBackwards = go_backwards,
                 Stateful = stateful,
                 TimeMajor = time_major,
-                Unroll = unroll
+                Unroll = unroll,
+                UnitForgetBias = unit_forget_bias
             });

         /// <summary>
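For context, Keras's implementation flag switches between two numerically equivalent formulations: 1 computes the four gate projections as separate smaller products, while 2 batches them into larger fused matmuls, which is the default in upstream tf.keras. With the new defaults and the UnitForgetBias pass-through, a plain call now matches its TensorFlow counterpart without extra arguments; a minimal sketch against the existing public API:

using static Tensorflow.Binding;
using static Tensorflow.KerasApi;

var inputs = keras.Input((28, 28));
// implementation: 2 and unit_forget_bias: true are now the effective defaults,
// mirroring tf.keras.layers.LSTM.
var x = keras.layers.LSTM(100).Apply(inputs);
var outputs = keras.layers.Dense(10, activation: "softmax").Apply(x);
var model = keras.Model(inputs, outputs);
model.summary();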

src/TensorFlowNET.Keras/Layers/Rnn/LSTMCell.cs

Lines changed: 7 additions & 4 deletions
@@ -1,4 +1,5 @@
-using Serilog.Core;
+using Newtonsoft.Json;
+using Serilog.Core;
 using System.Diagnostics;
 using Tensorflow.Common.Extensions;
 using Tensorflow.Common.Types;
@@ -54,6 +55,7 @@ public LSTMCell(LSTMCellArgs args)

         public override void build(KerasShapesWrapper input_shape)
         {
+            base.build(input_shape);
             var single_shape = input_shape.ToSingleShape();
             var input_dim = single_shape[-1];
             _kernel = add_weight("kernel", (input_dim, _args.Units * 4),
@@ -82,7 +84,8 @@ Tensor bias_initializer()
                 _bias_initializer = _args.BiasInitializer;
             }
             _bias = add_weight("bias", (_args.Units * 4),
-                initializer: _bias_initializer);
+                initializer: _bias_initializer
+                );
             }
             built = true;
         }
@@ -203,7 +206,7 @@ public Tensors _compute_carry_and_output(Tensor[] x, Tensor[] h_tm1, Tensor c_tm
                 x_c + math_ops.matmul(h_tm1_c, _recurrent_kernel_slice));
             _recurrent_kernel_slice = tf.slice(_recurrent_kernel_tensor,
                 new[] { 0, _args.Units * 3 }, new[] { startIndex, _args.Units });
-            var o = _args.RecurrentActivation.Apply(
+            var o = _args.Activation.Apply(
                 x_o + math_ops.matmul(h_tm1_o, _recurrent_kernel_slice));

             return new Tensors(c, o);
@@ -220,7 +223,7 @@ public Tensors _compute_carry_and_output_fused(Tensor[] z, Tensor c_tm1)
             Tensor z0 = z[0], z1 = z[1], z2 = z[2], z3 = z[3];
             var i = _args.RecurrentActivation.Apply(z0);
             var f = _args.RecurrentActivation.Apply(z1);
-            var c = f * c_tm1 + i * _args.RecurrentActivation.Apply(z2);
+            var c = f * c_tm1 + i * _args.Activation.Apply(z2);
             var o = _args.RecurrentActivation.Apply(z3);
             return new Tensors(c, o);
         }
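For context, the standard Keras-style fused LSTM step that this code mirrors is (with sigma the recurrent activation, sigmoid by default, and tanh the cell activation):

i_t = sigma(z0)
f_t = sigma(z1)
c_t = f_t * c_{t-1} + i_t * tanh(z2)
o_t = sigma(z3)
h_t = o_t * tanh(c_t)

The candidate term tanh(z2) uses the cell activation, so applying RecurrentActivation (sigmoid) to z2, as the old code did, squashed the cell-state update into [0, 1] and made training drift away from TensorFlow; the fix routes z2 through _args.Activation instead.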

test/TensorFlowNET.Keras.UnitTest/Layers/Rnn.Test.cs

Lines changed: 7 additions & 10 deletions
@@ -60,26 +60,23 @@ public void TrainLSTMWithMnist()
         {
             var input = keras.Input((784));
             var x = keras.layers.Reshape((28, 28)).Apply(input);
-            //x = keras.layers.LSTM(50, return_sequences: true).Apply(x);
-            //x = keras.layers.LSTM(100, return_sequences: true).Apply(x);
-            //x = keras.layers.LSTM(150, return_sequences: true).Apply(x);
-            x = keras.layers.LSTM(4, implementation: 2).Apply(x);
-            //x = keras.layers.Dense(100).Apply(x);
+            x = keras.layers.LSTM(50, return_sequences: true).Apply(x);
+            x = keras.layers.LSTM(100).Apply(x);
             var output = keras.layers.Dense(10, activation: "softmax").Apply(x);

             var model = keras.Model(input, output);
             model.summary();
-            model.compile(keras.optimizers.Adam(), keras.losses.SparseCategoricalCrossentropy(), new string[] { "accuracy" });
+            model.compile(keras.optimizers.Adam(), keras.losses.CategoricalCrossentropy(), new string[] { "accuracy" });

             var data_loader = new MnistModelLoader();
             var dataset = data_loader.LoadAsync(new ModelLoadSetting
             {
                 TrainDir = "mnist",
-                OneHot = false,
-                ValidationSize = 58000,
+                OneHot = true,
+                ValidationSize = 55000,
             }).Result;

-            model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size: 16, epochs: 30);
+            model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size: 16, epochs: 1);
         }

         [TestMethod]
@@ -102,7 +99,7 @@ public void SimpleRNN()
                 ValidationSize = 58000,
             }).Result;

-            model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size: 16, epochs: 10);
+            model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size: 16, epochs: 2);
         }

         [TestMethod]
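The loader and loss changes go together: with OneHot = true the MNIST labels arrive as 10-dimensional one-hot vectors, which is what CategoricalCrossentropy expects, while SparseCategoricalCrossentropy expects integer class indices. As a fragment (reusing the model built in the test above):

// One-hot labels (OneHot = true)  -> categorical cross-entropy:
model.compile(keras.optimizers.Adam(), keras.losses.CategoricalCrossentropy(), new string[] { "accuracy" });
// Integer labels (OneHot = false) -> sparse categorical cross-entropy would be the matching pair:
// model.compile(keras.optimizers.Adam(), keras.losses.SparseCategoricalCrossentropy(), new string[] { "accuracy" });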

tools/Tensorflow.CodeGen/FunctionGenerator.cs

Lines changed: 6 additions & 2 deletions
@@ -83,8 +83,12 @@ public void AppendFunction(OpDef op, StringBuilder sb)

             sb.AppendLine("}"); // try

-            sb.Append("catch(NotOkStatusException ex)\n{\n");
-            sb.AppendLine("throw ex;");
+            sb.Append("catch(NotOkStatusException ex1)\n{\n");
+            sb.AppendLine("throw ex1;");
+            sb.AppendLine("}"); // catch
+
+            sb.Append("catch(InvalidArgumentError ex2)\n{\n");
+            sb.AppendLine("throw ex2;");
             sb.AppendLine("}"); // catch

             sb.Append("catch(Exception)\n{\n");

tools/Tensorflow.CodeGen/Tensorflow.CodeGen.csproj

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@

   <ItemGroup>
     <PackageReference Include="Microsoft.CodeAnalysis.CSharp.Scripting" Version="4.6.0-1.final" />
-    <PackageReference Include="Protobuf.Text" Version="0.7.0" />
+    <PackageReference Include="Protobuf.Text" Version="0.7.1" />
   </ItemGroup>

   <ItemGroup>

0 commit comments
