Skip to content

Commit 3aa2738

Browse files
committed
Add layers.Normalization.
1 parent 544ff91 commit 3aa2738

File tree

13 files changed

+367
-5
lines changed

13 files changed

+367
-5
lines changed
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
using Newtonsoft.Json;
2+
3+
namespace Tensorflow.Keras.ArgsDefinition;
4+
5+
/// <summary>
/// Configuration for the <c>Normalization</c> preprocessing layer.
/// </summary>
public class NormalizationArgs : PreprocessingLayerArgs
{
    /// <summary>Axis (or axes) that keep their own statistics; null means no axis is kept.</summary>
    [JsonProperty("axis")]
    public Axis? Axis { get; set; }

    /// <summary>Fixed mean to normalize with; when null the layer creates weights to be filled by <c>adapt</c>.</summary>
    [JsonProperty("mean")]
    public float? Mean { get; set; }

    /// <summary>Fixed variance to normalize with; read together with <see cref="Mean"/>.</summary>
    [JsonProperty("variance")]
    public float? Variance { get; set; }

    /// <summary>When true the layer applies the inverse transform (<c>mean + input * std</c>).</summary>
    // Serialized under an explicit lower-case key so it matches the sibling properties above.
    [JsonProperty("invert")]
    public bool Invert { get; set; } = false;
}

src/TensorFlowNET.Core/Keras/Layers/ILayer.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,5 +23,6 @@ public interface ILayer: IWithTrackable, IKerasConfigable
2323
TensorShapeConfig BuildInputShape { get; }
2424
TF_DataType DType { get; }
2525
int count_params();
26+
/// <summary>
/// Gives the layer a chance to derive internal state from sample data before it is used.
/// </summary>
/// <param name="data">Tensor holding the sample data to adapt to.</param>
/// <param name="batch_size">Optional batch size hint; null when the caller does not specify one.</param>
/// <param name="steps">Optional step-count hint; null when the caller does not specify one.</param>
void adapt(Tensor data, int? batch_size = null, int? steps = null);
2627
}
2728
}

src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@ public ILayer LayerNormalization(Axis? axis,
156156
IInitializer beta_initializer = null,
157157
IInitializer gamma_initializer = null);
158158

159+
/// <summary>
/// Creates a normalization preprocessing layer.
/// </summary>
/// <param name="axis">Axis that keeps its own statistics; defaults to -1.</param>
/// <param name="mean">Fixed mean to use, or null to leave it unset.</param>
/// <param name="variance">Fixed variance to use, or null to leave it unset.</param>
/// <param name="invert">Whether the layer should be configured to invert the normalization; defaults to false.</param>
/// <returns>The configured layer.</returns>
public ILayer Normalization(int? axis = -1, float? mean = null, float? variance = null, bool invert = false);
159160
public ILayer LeakyReLU(float alpha = 0.3f);
160161

161162
public ILayer LSTM(int units,

src/TensorFlowNET.Core/NumPy/ShapeHelper.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,9 @@ public class ShapeHelper
99
{
1010
public static long GetSize(Shape shape)
1111
{
12+
if (shape.IsNull)
13+
return 0;
14+
1215
// scalar
1316
if (shape.ndim == 0)
1417
return 1;

src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,5 +159,10 @@ public void build(Shape input_shape)
159159
}
160160

161161
public Trackable GetTrackable() { throw new NotImplementedException(); }
162+
163+
/// <summary>
/// Not supported by this cell: every call throws.
/// </summary>
/// <exception cref="NotImplementedException">Always.</exception>
public void adapt(Tensor data, int? batch_size = null, int? steps = null)
    => throw new NotImplementedException();
162167
}
163168
}

src/TensorFlowNET.Core/tensorflow.cs

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ limitations under the License.
1616

1717
using Serilog;
1818
using Serilog.Core;
19+
using System.Reflection;
1920
using System.Threading;
2021
using Tensorflow.Contexts;
2122
using Tensorflow.Eager;
@@ -52,7 +53,29 @@ public partial class tensorflow
5253
ThreadLocal<IEagerRunner> _runner = new ThreadLocal<IEagerRunner>(() => new EagerRunner());
5354
public IEagerRunner Runner => _runner.Value;
5455

55-
public IKerasApi keras { get; set; }
56+
private IKerasApi _keras;
57+
/// <summary>
/// Keras API entry point, resolved on first access by loading the
/// "Tensorflow.Keras" assembly and instantiating a type that implements
/// <see cref="IKerasApi"/>. The instance is cached in <c>_keras</c> for later reads.
/// </summary>
/// <exception cref="Exception">
/// Thrown when the loaded assembly contains no instantiable <see cref="IKerasApi"/> implementation.
/// </exception>
public IKerasApi keras
{
    get
    {
        if (_keras != null)
        {
            return _keras;
        }

        var kerasAssembly = Assembly.Load("Tensorflow.Keras");

        // Only concrete classes qualify: an abstract class or a derived interface
        // also reports IKerasApi among its interfaces but cannot be constructed.
        var implementation = kerasAssembly.GetTypes()
            .FirstOrDefault(t => t.IsClass && !t.IsAbstract && typeof(IKerasApi).IsAssignableFrom(t));
        if (implementation == null)
        {
            throw new Exception("Can't find keras library.");
        }

        // Direct cast: the filter above guarantees the type implements IKerasApi,
        // so a failure here should surface instead of silently caching null.
        _keras = (IKerasApi)Activator.CreateInstance(implementation);
        return _keras;
    }
}
5679

5780
public tensorflow()
5881
{

src/TensorFlowNET.Keras/Engine/Layer.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,5 +344,10 @@ public int count_params()
344344

345345
public virtual IKerasConfig get_config()
346346
=> args;
347+
348+
/// <summary>
/// Hook for layers that compute state from data. The base implementation
/// does nothing; subclasses override it to do real work.
/// </summary>
/// <param name="data">Sample data (ignored here).</param>
/// <param name="batch_size">Optional batch size (ignored here).</param>
/// <param name="steps">Optional step count (ignored here).</param>
public virtual void adapt(Tensor data, int? batch_size = null, int? steps = null) { }
347352
}
348353
}

src/TensorFlowNET.Keras/KerasInterface.cs

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,6 @@ public class KerasInterface : IKerasApi
2020
{
2121
private static KerasInterface _instance = null;
2222
private static readonly object _lock = new object();
23-
private KerasInterface()
24-
{
25-
Tensorflow.Binding.tf.keras = this;
26-
}
2723

2824
public static KerasInterface Instance
2925
{

src/TensorFlowNET.Keras/Layers/LayersApi.cs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -872,5 +872,14 @@ public ILayer CategoryEncoding(int num_tokens, string output_mode = "one_hot", b
872872
Sparse = sparse,
873873
CountWeights = count_weights
874874
});
875+
876+
/// <summary>
/// Builds a <c>Normalization</c> layer from the given settings.
/// </summary>
/// <param name="axis">Axis forwarded to <c>NormalizationArgs.Axis</c>.</param>
/// <param name="mean">Value forwarded to <c>NormalizationArgs.Mean</c>.</param>
/// <param name="variance">Value forwarded to <c>NormalizationArgs.Variance</c>.</param>
/// <param name="invert">Value forwarded to <c>NormalizationArgs.Invert</c>.</param>
/// <returns>The newly created layer.</returns>
public ILayer Normalization(int? axis = -1, float? mean = null, float? variance = null, bool invert = false)
{
    var layerArgs = new NormalizationArgs
    {
        Invert = invert,
        Variance = variance,
        Mean = mean,
        Axis = axis
    };

    return new Normalization(layerArgs);
}
875884
}
876885
}
Lines changed: 173 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,173 @@
1+
/*****************************************************************************
2+
Copyright 2023 Haiping Chen. All Rights Reserved.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
******************************************************************************/
16+
17+
using Tensorflow.Keras.ArgsDefinition;
18+
19+
namespace Tensorflow.Keras.Layers
20+
{
21+
/// <summary>
/// Preprocessing layer that maps its input to
/// <c>(input - mean) / max(sqrt(variance), epsilon)</c>, or applies the inverse
/// mapping <c>mean + input * max(sqrt(variance), epsilon)</c> when the args
/// request inversion.
/// </summary>
/// <remarks>
/// The statistics come from one of two places: fixed values passed through the
/// args, or non-trainable weights created in <c>build</c> and filled in by
/// <c>adapt</c> / <c>update_state</c>.
/// </remarks>
public class Normalization : PreprocessingLayer
{
    NormalizationArgs _args;

    // Axes that keep their own statistics, as given by the caller (entries may be negative).
    int[] axis;
    // Axes reduced over when computing batch moments; set by build().
    int[] _reduce_axis;
    // Running statistics; only created when no fixed mean was supplied, otherwise null.
    IVariableV1 adapt_mean, adapt_variance, count;
    // Statistics applied in Call(), already reshaped to _broadcast_shape.
    Tensor mean, variance;
    Shape _broadcast_shape;
    float? input_mean, input_variance;
    TF_DataType compute_dtype = tf.float32;

    /// <summary>
    /// Creates the layer from its configuration.
    /// </summary>
    /// <param name="args">Layer configuration (axis, optional fixed mean/variance, invert flag).</param>
    /// <exception cref="System.ArgumentException">
    /// Only one of <c>Mean</c> and <c>Variance</c> was supplied.
    /// </exception>
    public Normalization(NormalizationArgs args) : base(args)
    {
        // build() multiplies both values into tensors when a mean is present and
        // ignores the variance entirely when it is not, so a lone value is a
        // configuration error that is better reported here.
        if ((args.Mean == null) != (args.Variance == null))
        {
            throw new System.ArgumentException(
                "When setting values directly, both `mean` and `variance` must be set.");
        }

        _args = args;
        // Take a private copy so this layer never aliases the array owned by the args.
        axis = args.Axis == null ? new int[0] : args.Axis.axis.ToArray();
        input_mean = args.Mean;
        input_variance = args.Variance;
    }

    /// <summary>
    /// Resolves the kept/reduced axes for <paramref name="input_shape"/> and either
    /// creates the adaptable weights or materializes the fixed statistics.
    /// </summary>
    /// <param name="input_shape">Shape of the layer input.</param>
    public override void build(Shape input_shape)
    {
        base.build(input_shape);
        var ndim = input_shape.ndim;

        // Negative axes count from the end. Sorting keeps the weight shape in the
        // same dimension order as _broadcast_shape and the batch moments, both of
        // which follow the input's axis order.
        var _keep_axis = axis.Select(d => d >= 0 ? d : d + ndim).OrderBy(d => d).ToArray();
        _reduce_axis = range(ndim).Where(d => !_keep_axis.Contains(d)).ToArray();
        // Broadcast any reduced axes.
        _broadcast_shape = new Shape(range(ndim).Select(d => _keep_axis.Contains(d) ? input_shape.dims[d] : 1).ToArray());
        var mean_and_var_shape = _keep_axis.Select(d => input_shape.dims[d]).ToArray();

        if (input_mean == null)
        {
            adapt_mean = add_weight("mean",
                mean_and_var_shape,
                dtype: tf.float32,
                initializer: tf.zeros_initializer,
                trainable: false);

            adapt_variance = add_weight("variance",
                mean_and_var_shape,
                dtype: tf.float32,
                initializer: tf.ones_initializer,
                trainable: false);

            count = add_weight("count",
                Shape.Scalar,
                dtype: tf.int64,
                initializer: tf.zeros_initializer,
                trainable: false);

            // Publish the initial (zero mean / unit variance) statistics.
            finalize_state();
        }
        else
        {
            mean = input_mean * np.ones(mean_and_var_shape);
            variance = input_variance * np.ones(mean_and_var_shape);
            mean = tf.reshape(mean, _broadcast_shape);
            variance = tf.reshape(variance, _broadcast_shape);
            mean = tf.cast(mean, compute_dtype);
            variance = tf.cast(variance, compute_dtype);
        }
    }

    /// <summary>
    /// Resets the running statistics to zero mean, unit variance and zero count.
    /// Does nothing when the layer uses fixed statistics or has no weights yet.
    /// </summary>
    public override void reset_state()
    {
        // Either condition alone means there are no adapt_* variables to touch.
        if (input_mean != null || adapt_mean is null)
        {
            return;
        }
        adapt_mean.assign(tf.zeros_like(adapt_mean.AsTensor()));
        adapt_variance.assign(tf.ones_like(adapt_variance.AsTensor()));
        count.assign(tf.zeros_like(count.AsTensor()));
    }

    /// <summary>
    /// Copies the running statistics into the tensors used by <c>Call</c>,
    /// reshaped for broadcasting. Does nothing when the layer uses fixed
    /// statistics or has no weights yet.
    /// </summary>
    public override void finalize_state()
    {
        // Checked against the variable itself (not `built`) because build() calls
        // this right after creating the weights.
        if (input_mean != null || adapt_mean is null)
        {
            return;
        }
        mean = tf.reshape(adapt_mean.AsTensor(), _broadcast_shape);
        variance = tf.reshape(adapt_variance.AsTensor(), _broadcast_shape);
    }

    /// <summary>
    /// Folds one batch into the running mean, variance and element count.
    /// </summary>
    /// <param name="data">Batch of input data.</param>
    /// <exception cref="System.InvalidOperationException">
    /// The layer was configured with fixed statistics, or has not been built.
    /// </exception>
    public override void update_state(Tensor data)
    {
        if (input_mean != null)
        {
            throw new System.InvalidOperationException(
                "Cannot `adapt` a Normalization layer that is initialized with static `mean` and `variance`.");
        }
        if (adapt_mean is null)
        {
            throw new System.InvalidOperationException(
                "`build` must be called before `update_state`.");
        }

        data = tf.cast(data, adapt_mean.dtype);
        var (batch_mean, batch_variance) = tf.nn.moments(data, axes: _reduce_axis);
        var batch_shape = tf.shape(data, out_type: count.dtype);

        // Number of elements that contributed to each statistic in this batch.
        var batch_count = constant_op.constant(1L);
        if (_reduce_axis != null)
        {
            var batch_reduce_shape = tf.gather(batch_shape, constant_op.constant(_reduce_axis));
            batch_count = tf.reduce_prod(batch_reduce_shape);
        }

        // Weight old and new statistics by their share of the total element count.
        var total_count = batch_count + count.AsTensor();
        var batch_weight = tf.cast(batch_count, dtype: compute_dtype) / tf.cast(
            total_count, dtype: compute_dtype);
        var existing_weight = 1.0 - batch_weight;
        var total_mean = adapt_mean.AsTensor() * existing_weight + batch_mean * batch_weight;

        // Each side's variance is shifted by its squared distance to the combined mean.
        var total_variance = (
            adapt_variance.AsTensor() + tf.square(adapt_mean.AsTensor() - total_mean)
        ) * existing_weight + (
            batch_variance + tf.square(batch_mean - total_mean)
        ) * batch_weight;
        adapt_mean.assign(total_mean);
        adapt_variance.assign(total_variance);
        count.assign(total_count);
    }

    /// <summary>The layer does not change the shape of its input.</summary>
    public override Shape ComputeOutputShape(Shape input_shape)
    {
        return input_shape;
    }

    /// <summary>Computes the statistics from <paramref name="data"/> via the base implementation.</summary>
    public override void adapt(Tensor data, int? batch_size = null, int? steps = null)
    {
        base.adapt(data, batch_size: batch_size, steps: steps);
    }

    /// <summary>
    /// Applies the normalization (or its inverse) to <paramref name="inputs"/>.
    /// The standard deviation is floored at the backend epsilon.
    /// </summary>
    protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
    {
        if (_args.Invert)
        {
            return mean + (
                inputs * tf.maximum(tf.sqrt(variance), keras.backend.epsilon())
            );
        }
        else
        {
            return (inputs - mean) / tf.maximum(
                tf.sqrt(variance), keras.backend.epsilon());
        }
    }
}
173+
}

src/TensorFlowNET.Keras/Layers/Preprocessing/PreprocessingLayer.cs

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,95 @@
33
using System.Text;
44
using Tensorflow.Keras.ArgsDefinition;
55
using Tensorflow.Keras.Engine;
6+
using Tensorflow.Keras.Engine.DataAdapters;
67

78
namespace Tensorflow.Keras.Layers
89
{
910
/// <summary>
/// Base class for layers whose state is computed from data through
/// <c>adapt</c>. Subclasses override <c>reset_state</c>, <c>update_state</c>
/// and <c>finalize_state</c>; the defaults here do nothing.
/// </summary>
public class PreprocessingLayer : Layer
{
    bool _is_compiled;
    bool _is_adapted;
    IVariableV1 _steps_per_execution;
    PreprocessingLayerArgs _args;

    public PreprocessingLayer(PreprocessingLayerArgs args) : base(args)
    {
        _args = args;
    }

    /// <summary>
    /// Runs a single pass over <paramref name="data"/>, calling
    /// <c>update_state</c> for every batch and <c>finalize_state</c> at the end.
    /// </summary>
    /// <param name="data">Data to adapt to.</param>
    /// <param name="batch_size">
    /// Batch size for the pass; when null the batch size from the layer args is used.
    /// </param>
    /// <param name="steps">
    /// NOTE(review): accepted for interface compatibility but not forwarded to the data handler.
    /// </param>
    public override void adapt(Tensor data, int? batch_size = null, int? steps = null)
    {
        if (!_is_compiled)
        {
            compile();
        }

        // Start from a clean slate when the layer already holds state.
        if (built)
        {
            reset_state();
        }

        var data_handler = new DataHandler(new DataHandlerArgs
        {
            X = new Tensors(data),
            // Honor the caller's batch size; fall back to the configured one.
            BatchSize = batch_size ?? _args.BatchSize,
            Epochs = 1,
            StepsPerExecution = _steps_per_execution
        });

        foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
        {
            foreach (var _ in data_handler.steps())
            {
                run_step(iterator);
            }
        }
        finalize_state();
        _is_adapted = true;
    }

    // Pulls the next batch, builds the layer on first use, and feeds the batch to update_state.
    private void run_step(OwnedIterator iterator)
    {
        var data = iterator.next();
        _adapt_maybe_build(data[0]);
        update_state(data[0]);
    }

    /// <summary>Clears any accumulated state. No-op by default.</summary>
    public virtual void reset_state()
    {

    }

    /// <summary>Finishes the state once all batches have been seen. No-op by default.</summary>
    public virtual void finalize_state()
    {

    }

    /// <summary>Accumulates one batch into the layer state. No-op by default.</summary>
    public virtual void update_state(Tensor data)
    {

    }

    // Builds the layer from the first batch's shape if it has not been built yet.
    private void _adapt_maybe_build(Tensor data)
    {
        if (!built)
        {
            var data_shape = data.shape;
            // Fallback batch input shape: -1 in every dimension.
            var data_shape_nones = Enumerable.Range(0, data.ndim).Select(x => -1).ToArray();
            _args.BatchInputShape = BatchInputShape ?? new Shape(data_shape_nones);
            build(data_shape);
            built = true;
        }
    }

    /// <summary>
    /// Prepares the layer for <c>adapt</c> by creating the steps-per-execution variable.
    /// </summary>
    /// <param name="run_eagerly">Currently unused.</param>
    /// <param name="steps_per_execution">Initial value stored in the variable.</param>
    public void compile(bool run_eagerly = false, int steps_per_execution = 1)
    {
        _steps_per_execution = tf.Variable(
            steps_per_execution,
            dtype: tf.int64,
            aggregation: VariableAggregation.OnlyFirstReplica
        );

        _is_compiled = true;
    }
}
1697
}

0 commit comments

Comments
 (0)