Skip to content

Commit bd26bbd

Browse files
committed
Orthogonal initializer.
1 parent 321ddfc commit bd26bbd

File tree

16 files changed

+202
-32
lines changed

16 files changed

+202
-32
lines changed

src/TensorFlowNET.Console/Program.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
using System;
2+
using Tensorflow.Keras;
23
using static Tensorflow.Binding;
34

45
namespace Tensorflow
@@ -7,6 +8,8 @@ class Program
78
{
89
static void Main(string[] args)
910
{
11+
tf.UseKeras<KerasInterface>();
12+
1013
var diag = new Diagnostician();
1114
// diag.Diagnose(@"D:\memory.txt");
1215

src/TensorFlowNET.Core/APIs/tf.linalg.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,12 @@ public Tensor lstsq(Tensor matrix, Tensor rhs,
5858
NDArray l2_regularizer = null, bool fast = true, string name = null)
5959
=> ops.matrix_solve_ls(matrix, rhs, l2_regularizer: l2_regularizer, fast: fast, name: name);
6060

61+
/// <summary>
/// Computes the QR factorization of <paramref name="input"/> by forwarding to <c>ops.qr</c>.
/// </summary>
/// <param name="input">Tensor to factorize.</param>
/// <param name="full_matrices">Forwarded unchanged to <c>ops.qr</c>.</param>
/// <param name="name">Optional operation name.</param>
/// <returns>The tensors produced by <c>ops.qr</c>.</returns>
public Tensors qr(Tensor input, bool full_matrices = true, string name = null)
{
    // NOTE(review): this wrapper defaults full_matrices to true while the
    // linalg_ops.qr method added in the same commit defaults to false —
    // confirm which default is intended.
    var factors = ops.qr(input, full_matrices: full_matrices, name: name);
    return factors;
}
63+
64+
/// <summary>
/// Returns the diagonal part of <paramref name="input"/> by forwarding to
/// <c>gen_array_ops.diag_part</c>.
/// </summary>
/// <param name="input">Tensor whose diagonal part is requested.</param>
/// <param name="name">Optional operation name.</param>
/// <returns>The tensor produced by <c>gen_array_ops.diag_part</c>.</returns>
public Tensor tensor_diag_part(Tensor input, string name = null)
{
    return gen_array_ops.diag_part(input, name: name);
}
66+
6167
/// <summary>
/// Forwards to <c>math_ops.tensordot</c> with the given operands and axes.
/// </summary>
/// <param name="x">First operand.</param>
/// <param name="y">Second operand.</param>
/// <param name="axes">Axes specification, passed through unchanged.</param>
/// <param name="name">Optional operation name.</param>
/// <returns>The tensor produced by <c>math_ops.tensordot</c>.</returns>
public Tensor tensordot(Tensor x, Tensor y, NDArray axes, string name = null)
{
    return math_ops.tensordot(x, y, axes, name: name);
}
6369
}

src/TensorFlowNET.Core/APIs/tf.random.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,12 @@ public Tensor normal(Shape shape,
3939
int? seed = null,
4040
string name = null) => random_ops.random_normal(shape, mean, stddev, dtype, seed, name);
4141

42+
/// <summary>
/// Draws values from a normal distribution by forwarding to
/// <c>stateless_random_ops.stateless_random_normal</c>.
/// </summary>
/// <param name="shape">Shape of the tensor to generate.</param>
/// <param name="mean">Mean forwarded to the underlying op helper.</param>
/// <param name="stddev">Standard deviation forwarded to the underlying op helper.</param>
/// <param name="dtype">Element type of the result.</param>
/// <param name="name">Optional operation name.</param>
/// <returns>The tensor produced by <c>stateless_random_normal</c>.</returns>
public Tensor stateless_normal(Shape shape,
    float mean = 0.0f,
    float stddev = 1.0f,
    TF_DataType dtype = TF_DataType.TF_FLOAT,
    string name = null)
{
    // No seed is exposed here, so the callee's default (seed == null) applies.
    return stateless_random_ops.stateless_random_normal(shape, mean, stddev, dtype, name: name);
}
47+
4248
/// <summary>
4349
/// Outputs random values from a truncated normal distribution.
4450
/// </summary>
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Keras
6+
{
7+
public interface IInitializersApi
8+
{
9+
/// <summary>
/// Creates an orthogonal-matrix initializer.
/// </summary>
/// <param name="gain">Scalar factor; the Orthogonal implementation in this commit multiplies the generated tensor by it.</param>
/// <param name="seed">Optional seed handed to the initializer's constructor.</param>
/// <returns>An <see cref="IInitializer"/> instance.</returns>
IInitializer Orthogonal(float gain = 1.0f, int? seed = null);
10+
}
11+
}

src/TensorFlowNET.Core/Keras/IKerasApi.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,6 @@ namespace Tensorflow.Keras
88
public interface IKerasApi
99
{
1010
public ILayersApi layers { get; }
11+
public IInitializersApi initializers { get; }
1112
}
1213
}

src/TensorFlowNET.Core/NumPy/NDArrayRender.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,7 @@ static string Render(NDArray array)
109109
TF_DataType.TF_INT8 => Render(array.ToArray<sbyte>(), array.shape),
110110
TF_DataType.TF_INT32 => Render(array.ToArray<int>(), array.shape),
111111
TF_DataType.TF_INT64 => Render(array.ToArray<long>(), array.shape),
112+
TF_DataType.TF_UINT64 => Render(array.ToArray<ulong>(), array.shape),
112113
TF_DataType.TF_FLOAT => Render(array.ToArray<float>(), array.shape),
113114
TF_DataType.TF_DOUBLE => Render(array.ToArray<double>(), array.shape),
114115
_ => Render(array.ToArray<byte>(), array.shape)
Lines changed: 49 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,62 @@
1-
using System;
1+
/*****************************************************************************
2+
Copyright 2023 Haiping Chen. All Rights Reserved.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
******************************************************************************/
16+
17+
using System;
218
using System.Linq;
3-
using static Tensorflow.TensorShapeProto.Types;
19+
using static Tensorflow.Binding;
420

5-
namespace Tensorflow.Operations.Initializers
21+
namespace Tensorflow.Operations.Initializers;
22+
23+
/// <summary>
/// Initializer that fills a tensor with an orthogonal matrix obtained from the
/// QR factorization of a normally distributed random matrix, scaled by a gain.
/// </summary>
public class Orthogonal : IInitializer
{
    readonly float _gain;
    readonly int? _seed;

    /// <summary>
    /// Creates the initializer.
    /// </summary>
    /// <param name="gain">Factor the generated tensor is multiplied by.</param>
    /// <param name="seed">
    /// Optional seed. When set, it is used for the random draw so repeated calls
    /// with the same shape give the same values; when null a seed is generated.
    /// </param>
    public Orthogonal(float gain = 1.0f, int? seed = null)
    {
        _gain = gain;
        _seed = seed;
    }

    /// <summary>
    /// Generates the initial value for <c>args.Shape</c>.
    /// </summary>
    /// <param name="args">Shape and dtype of the tensor to create; an invalid dtype falls back to TF_FLOAT.</param>
    /// <returns>A tensor of the requested shape.</returns>
    /// <exception cref="NotImplementedException">
    /// Thrown when the flattened shape has fewer rows than columns.
    /// </exception>
    public Tensor Apply(InitializerArgs args)
    {
        var dtype = args.DType == TF_DataType.DtInvalid ? TF_DataType.TF_FLOAT : args.DType;
        return _generate_init_val(args.Shape, dtype);
    }

    private Tensor _generate_init_val(Shape shape, TF_DataType dtype)
    {
        // Flatten to 2-D: every leading dimension is folded into the row count.
        var num_rows = 1L;
        foreach (var dim in shape.dims.Take(shape.ndim - 1))
            num_rows *= dim;
        var num_cols = shape.dims.Last();

        // Fail before doing any tensor work; this case would need Q transposed,
        // which is not implemented yet.
        if (num_rows < num_cols)
        {
            // q = tf.linalg.matrix_transpose(q);
            throw new NotImplementedException(
                "Orthogonal initializer does not support shapes with fewer rows than columns yet.");
        }

        var flat_shape = (Math.Max(num_cols, num_rows), Math.Min(num_cols, num_rows));

        // Honor the user-supplied seed. The pair layout { seed, 0 } mirrors the
        // one stateless_random_normal builds when it generates a seed itself.
        int[] seed_pair = _seed.HasValue ? new[] { _seed.Value, 0 } : null;
        var a = stateless_random_ops.stateless_random_normal(flat_shape, dtype: dtype, seed: seed_pair);

        // Compute the qr factorization
        var (q, r) = tf.linalg.qr(a, full_matrices: false);

        // Make Q uniform: scale by the sign of R's diagonal.
        var d = tf.linalg.tensor_diag_part(r);
        q *= tf.sign(d);

        return _gain * tf.reshape(q, shape);
    }
}

src/TensorFlowNET.Core/Operations/gen_array_ops.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,9 @@ public static Tensor[] concat_offset(Tensor concat_dim, Tensor[] shape, string n
113113
/// <summary>
/// Executes the <c>Diag</c> op on <paramref name="diagonal"/>.
/// </summary>
/// <param name="diagonal">Input tensor passed to the op.</param>
/// <param name="name">Optional operation name.</param>
/// <returns>The op's output tensor.</returns>
public static Tensor diag(Tensor diagonal, string name = null)
{
    var op_args = new ExecuteOpArgs(diagonal);
    return tf.Context.ExecuteOp("Diag", name, op_args);
}
115115

116+
/// <summary>
/// Executes the <c>DiagPart</c> op on the given tensor.
/// </summary>
/// <param name="diagonal">Input tensor passed to the op.</param>
/// <param name="name">Optional operation name.</param>
/// <returns>The op's output tensor.</returns>
public static Tensor diag_part(Tensor diagonal, string name = null)
{
    var op_args = new ExecuteOpArgs(diagonal);
    return tf.Context.ExecuteOp("DiagPart", name, op_args);
}
118+
116119
/// <summary>
/// Executes the <c>ExpandDims</c> op on <paramref name="input"/>.
/// </summary>
/// <param name="input">Tensor passed to the op.</param>
/// <param name="axis">Passed both as an op input and as the <c>dim</c> attribute.</param>
/// <param name="name">Optional operation name.</param>
/// <returns>The op's output tensor.</returns>
public static Tensor expand_dims(Tensor input, int axis, string name = null)
{
    var op_args = new ExecuteOpArgs(input, axis)
        .SetAttributes(new { dim = axis });
    return tf.Context.ExecuteOp("ExpandDims", name, op_args);
}

src/TensorFlowNET.Core/Operations/gen_random_ops.cs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,10 @@ You may obtain a copy of the License at
1313
See the License for the specific language governing permissions and
1414
limitations under the License.
1515
******************************************************************************/
16+
using static Tensorflow.ApiDef.Types;
17+
using System.Reflection;
1618
using static Tensorflow.Binding;
19+
using System.Xml.Linq;
1720

1821
namespace Tensorflow
1922
{
@@ -85,6 +88,15 @@ public static Tensor truncated_normal(Tensor shape, TF_DataType dtype, int? seed
8588
int? seed2 = 0, string name = null)
8689
=> tf.Context.ExecuteOp("TruncatedNormal", name, new ExecuteOpArgs(shape)
8790
.SetAttributes(new { dtype, seed = seed ?? 0, seed2 = seed2 ?? 0 }));
91+
/// <summary>
/// Executes the <c>StatelessRandomNormalV2</c> op.
/// </summary>
/// <param name="shape">Tensor describing the output shape.</param>
/// <param name="key">Key tensor passed to the op.</param>
/// <param name="counter">Counter tensor passed to the op.</param>
/// <param name="alg">Algorithm id passed to the op.</param>
/// <param name="dtype">Set as the op's <c>dtype</c> attribute.</param>
/// <param name="name">Optional operation name.</param>
/// <returns>The op's output tensor.</returns>
public static Tensor stateless_random_normal_v2(Tensor shape, Tensor key, Tensor counter,
    int alg, TF_DataType dtype, string name = null)
{
    var op_args = new ExecuteOpArgs(shape, key, counter, alg)
        .SetAttributes(new { dtype });
    return tf.Context.ExecuteOp("StatelessRandomNormalV2", name, op_args);
}
96+
97+
/// <summary>
/// Executes the <c>StatelessRandomGetKeyCounter</c> op for the given seed.
/// </summary>
/// <param name="seed">Seed values passed to the op.</param>
/// <param name="name">Optional operation name.</param>
/// <returns>The op's outputs; callers in this codebase read index 0 as the key and index 1 as the counter.</returns>
public static Tensors stateless_random_get_key_counter(int[] seed, string name = null)
{
    var op_args = new ExecuteOpArgs(seed);
    return tf.Context.ExecuteOp("StatelessRandomGetKeyCounter", name, op_args);
}
88100

89101
public static Tensor multinomial(Tensor logits, int num_samples, int? seed = 0,
90102
int? seed2 = 0, TF_DataType output_dtype = TF_DataType.TF_INT64, string name = null)

src/TensorFlowNET.Core/Operations/linalg_ops.cs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,5 +129,12 @@ public Tensor matrix_triangular_solve(Tensor matrix, Tensor rhs, bool lower = tr
129129
lower,
130130
adjoint
131131
}));
132+
133+
/// <summary>
/// Executes the <c>Qr</c> op on <paramref name="input"/>.
/// </summary>
/// <param name="input">Tensor to factorize.</param>
/// <param name="full_matrices">Set as the op's <c>full_matrices</c> attribute.</param>
/// <param name="name">Optional operation name.</param>
/// <returns>The op's outputs.</returns>
public Tensors qr(Tensor input, bool full_matrices = false, string name = null)
{
    var op_args = new ExecuteOpArgs(input).SetAttributes(new { full_matrices });
    return tf.Context.ExecuteOp("Qr", name, op_args);
}
132139
}
133140
}
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
/*****************************************************************************
2+
Copyright 2023 Haiping Chen. All Rights Reserved.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
******************************************************************************/
16+
17+
using static Tensorflow.ApiDef.Types;
18+
using System.Reflection;
19+
using static Tensorflow.Binding;
20+
using System;
21+
22+
namespace Tensorflow;
23+
24+
/// <summary>
/// Helpers that build random tensors on top of the stateless random ops in
/// <c>gen_random_ops</c>.
/// </summary>
public class stateless_random_ops
{
    // Algorithm id handed to StatelessRandomNormalV2.
    // NOTE(review): presumably 3 selects TensorFlow's automatic algorithm choice — confirm.
    private const int DEFAULT_ALGORITHM = 3;

    // One shared generator for fallback seeds. Creating a new Random per call can
    // yield identical time-based seeds for calls made in quick succession on
    // older runtimes. Random is not thread-safe, hence the lock.
    private static readonly Random _seed_source = new Random();
    private static readonly object _seed_lock = new object();

    /// <summary>
    /// Draws values from a normal distribution using the stateless random ops.
    /// </summary>
    /// <param name="shape">Shape of the tensor to generate.</param>
    /// <param name="mean">Value added to the scaled samples.</param>
    /// <param name="stddev">Factor the raw samples are multiplied by.</param>
    /// <param name="dtype">Element type of the result.</param>
    /// <param name="seed">Seed values; when null a pair { random int, 0 } is generated.</param>
    /// <param name="name">Optional operation name.</param>
    /// <returns>A tensor computed as <c>rnd * stddev + mean</c>.</returns>
    public static Tensor stateless_random_normal(Shape shape,
        float mean = 0.0f,
        float stddev = 1.0f,
        TF_DataType dtype = TF_DataType.TF_FLOAT,
        int[]? seed = null,
        string name = null)
    {
        return tf_with(ops.name_scope(name, "stateless_random_normal", new { shape, seed, mean, stddev }), scope =>
        {
            name = scope;
            var shape_tensor = _ShapeTensor(shape);
            var mean_tensor = ops.convert_to_tensor(mean, dtype: dtype, name: "mean");
            var stddev_tensor = ops.convert_to_tensor(stddev, dtype: dtype, name: "stddev");

            var effective_seed = seed ?? _next_seed();
            var (key, counter) = _get_key_counter(effective_seed, DEFAULT_ALGORITHM);
            var rnd = gen_random_ops.stateless_random_normal_v2(shape: shape_tensor, key: key, counter: counter, dtype: dtype, alg: DEFAULT_ALGORITHM);

            // Use the dtype-converted stddev tensor (previously computed but unused)
            // so both operands of the multiply share the requested dtype.
            var value = math_ops.add(rnd * stddev_tensor, mean_tensor, name: name);
            // TODO: tensor_util.maybe_set_static_shape(value, shape)
            return value;
        });
    }

    // Produces the fallback seed pair used when the caller supplies none.
    private static int[] _next_seed()
    {
        lock (_seed_lock)
        {
            return new[] { _seed_source.Next(), 0 };
        }
    }

    // Converts the requested shape into a tensor named "shape".
    private static Tensor _ShapeTensor(int[] shape)
    {
        return ops.convert_to_tensor(shape, name: "shape");
    }

    // Derives the (key, counter) pair for a seed. `alg` is accepted but not
    // forwarded to the underlying op by this implementation.
    private static (Tensor, Tensor) _get_key_counter(int[] seed, int alg)
    {
        var results = gen_random_ops.stateless_random_get_key_counter(seed);
        return (results[0], results[1]);
    }
}

src/TensorFlowNET.Core/tensorflow.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,10 @@ public tensorflow()
6767

6868
/// <summary>
/// Installs a Keras API implementation of type <typeparamref name="T"/> unless
/// one has already been set; later calls leave the existing instance in place.
/// </summary>
/// <typeparam name="T">Keras API implementation with a parameterless constructor.</typeparam>
public void UseKeras<T>() where T : IKerasApi, new()
{
    if (keras != null)
        return;

    keras = new T();
}
7275

7376
public string VERSION => c_api.StringPiece(c_api.TF_Version());

src/TensorFlowNET.Keras/Initializers.cs renamed to src/TensorFlowNET.Keras/InitializersApi.cs

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -16,18 +16,20 @@ limitations under the License.
1616

1717
using Tensorflow.Operations.Initializers;
1818

19-
namespace Tensorflow.Keras
19+
namespace Tensorflow.Keras;
20+
21+
/// <summary>
/// Factory methods for the initializers exposed through the Keras API.
/// </summary>
public partial class InitializersApi : IInitializersApi
{
    /// <summary>
    /// He normal initializer, built as a variance-scaling initializer with
    /// factor 2.0 and "fan_in" mode.
    /// </summary>
    /// <param name="seed">Optional seed forwarded to the initializer.</param>
    /// <returns>The configured initializer.</returns>
    public IInitializer he_normal(int? seed = null)
        => new VarianceScaling(factor: 2.0f, mode: "fan_in", seed: seed);

    /// <summary>
    /// Orthogonal initializer.
    /// </summary>
    /// <param name="gain">Gain forwarded to the initializer.</param>
    /// <param name="seed">Optional seed forwarded to the initializer.</param>
    /// <returns>The configured initializer.</returns>
    public IInitializer Orthogonal(float gain = 1.0f, int? seed = null)
    {
        return new Orthogonal(gain: gain, seed: seed);
    }
}

src/TensorFlowNET.Keras/KerasInterface.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ namespace Tensorflow.Keras
1818
public class KerasInterface : IKerasApi
1919
{
2020
public KerasDataset datasets { get; } = new KerasDataset();
21-
public Initializers initializers { get; } = new Initializers();
21+
public IInitializersApi initializers { get; } = new InitializersApi();
2222
public Regularizers regularizers { get; } = new Regularizers();
2323
public ILayersApi layers { get; } = new LayersApi();
2424
public LossesApi losses { get; } = new LossesApi();

test/TensorFlowNET.Keras.UnitTest/EagerModeTestBase.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using Microsoft.VisualStudio.TestTools.UnitTesting;
22
using System;
3+
using Tensorflow.Keras;
34
using static Tensorflow.Binding;
45

56
namespace TensorFlowNET.Keras.UnitTest
@@ -9,6 +10,8 @@ public class EagerModeTestBase
910
[TestInitialize]
1011
public void TestInit()
1112
{
13+
tf.UseKeras<KerasInterface>();
14+
1215
if (!tf.executing_eagerly())
1316
tf.enable_eager_execution();
1417
tf.Context.ensure_initialized();
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
using Microsoft.VisualStudio.TestTools.UnitTesting;
2+
using System;
3+
using System.Collections.Generic;
4+
using System.Linq;
5+
using System.Text;
6+
using TensorFlowNET.Keras.UnitTest;
7+
using static Tensorflow.Binding;
8+
9+
namespace Tensorflow.Keras.UnitTest;
10+
11+
/// <summary>
/// Tests for the initializers exposed through <c>tf.keras.initializers</c>.
/// </summary>
[TestClass]
public class InitializerTest : EagerModeTestBase
{
    /// <summary>
    /// Applies the orthogonal initializer to a 2x2 shape and checks that a
    /// value is produced.
    /// </summary>
    [TestMethod]
    public void Orthogonal()
    {
        var initializer = tf.keras.initializers.Orthogonal();

        var values = initializer.Apply(new InitializerArgs((2, 2)));

        // Previously the test asserted nothing and could only fail by throwing.
        // TODO: also verify the shape and that values^T * values is close to identity.
        Assert.IsNotNull(values);
    }
}

0 commit comments

Comments
 (0)