Skip to content

Commit ce3ddb2

Browse files
AhmedZeroOceania2018
authored andcommitted
Initial regularizers.
1 parent b5b4c51 commit ce3ddb2

File tree

6 files changed

+61
-10
lines changed

6 files changed

+61
-10
lines changed

src/TensorFlowNET.Core/Keras/ArgsDefinition/Core/DenseArgs.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,12 @@ public class DenseArgs : LayerArgs
3333
/// <summary>
3434
/// Regularizer function applied to the `kernel` weights matrix.
3535
/// </summary>
36-
public IInitializer KernelRegularizer { get; set; }
36+
public IRegularizer KernelRegularizer { get; set; }
3737

3838
/// <summary>
3939
/// Regularizer function applied to the bias vector.
4040
/// </summary>
41-
public IInitializer BiasRegularizer { get; set; }
41+
public IRegularizer BiasRegularizer { get; set; }
4242

4343
/// <summary>
4444
/// Constraint function applied to the `kernel` weights matrix.

src/TensorFlowNET.Core/Keras/Regularizers/RegularizerArgs.cs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,5 +2,12 @@
22
{
33
public class RegularizerArgs
44
{
5+
public Tensor X { get; set; }
6+
7+
8+
public RegularizerArgs(Tensor x)
9+
{
10+
X = x;
11+
}
512
}
613
}

src/TensorFlowNET.Keras/Engine/Layer.cs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ protected virtual void build(Tensors inputs)
206206

207207
protected virtual void add_loss(Func<Tensor> losses)
208208
{
209-
209+
210210
}
211211

212212
/// <summary>
@@ -217,10 +217,13 @@ protected virtual void add_loss(Func<Tensor> losses)
217217
/// <param name="regularizer"></param>
218218
void _handle_weight_regularization(string name, IVariableV1 variable, IRegularizer regularizer)
219219
{
220-
add_loss(() => regularizer.Apply(new RegularizerArgs
221-
{
222220

223-
}));
221+
add_loss(() => tf_with(ops.name_scope(name + "/Regularizer"), scope =>
222+
regularizer.Apply(new RegularizerArgs(variable.AsTensor())
223+
{
224+
225+
})
226+
));
224227
}
225228

226229
/*protected virtual void add_update(Tensor[] updates, bool inputs = false)
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
using System;
2+
3+
namespace Tensorflow.Keras
4+
{
5+
public class L1 : IRegularizer
6+
{
7+
float l1;
8+
9+
public L1(float l1 = 0.01f)
10+
{
11+
this.l1 = l1;
12+
}
13+
14+
public Tensor Apply(RegularizerArgs args)
15+
{
16+
return l1 * math_ops.reduce_sum(math_ops.abs(args.X));
17+
}
18+
}
19+
}
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
using System;
2+
using static Tensorflow.Binding;
3+
namespace Tensorflow.Keras
4+
{
5+
public class L1L2 : IRegularizer
6+
{
7+
float l1;
8+
float l2;
9+
10+
public L1L2(float l1 = 0.0f, float l2 = 0.0f)
11+
{
12+
this.l1 = l1;
13+
this.l2 = l2;
14+
15+
}
16+
public Tensor Apply(RegularizerArgs args)
17+
{
18+
Tensor regularization = tf.constant(0.0, args.X.dtype);
19+
regularization += l1 * math_ops.reduce_sum(math_ops.abs(args.X));
20+
regularization += l2 * math_ops.reduce_sum(math_ops.square(args.X));
21+
return regularization;
22+
}
23+
}
24+
}

src/TensorFlowNET.Keras/Regularizers/L2.cs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
1-
using System;
2-
3-
namespace Tensorflow.Keras
1+
namespace Tensorflow.Keras
42
{
53
public class L2 : IRegularizer
64
{
@@ -13,7 +11,7 @@ public L2(float l2 = 0.01f)
1311

1412
public Tensor Apply(RegularizerArgs args)
1513
{
16-
throw new NotImplementedException();
14+
return l2 * math_ops.reduce_sum(math_ops.square(args.X));
1715
}
1816
}
1917
}

0 commit comments

Comments
 (0)