Skip to content

Commit 53baa05

Browse files
committed
Fix the duplicated weights in Keras.Model.
1 parent 4d0a64f commit 53baa05

File tree

10 files changed

+116
-53
lines changed

10 files changed

+116
-53
lines changed

src/TensorFlowNET.Core/Keras/Common/CustomizedDTypeJsonConverter.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ public override void WriteJson(JsonWriter writer, object? value, JsonSerializer
3232
}
3333
else
3434
{
35-
return (TF_DataType)serializer.Deserialize(reader, typeof(TF_DataType));
35+
return (TF_DataType)serializer.Deserialize(reader, typeof(int));
3636
}
3737
}
3838
}

src/TensorFlowNET.Core/Keras/Layers/ILayer.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ public interface ILayer: IWithTrackable, IKerasConfigable
1919
List<IVariableV1> TrainableVariables { get; }
2020
List<IVariableV1> TrainableWeights { get; }
2121
List<IVariableV1> NonTrainableWeights { get; }
22+
List<IVariableV1> Weights { get; }
2223
Shape OutputShape { get; }
2324
Shape BatchInputShape { get; }
2425
TensorShapeConfig BuildInputShape { get; }

src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ public abstract class RnnCell : ILayer, RNNArgs.IRnnArgCell
7171

7272
public List<IVariableV1> TrainableVariables => throw new NotImplementedException();
7373
public List<IVariableV1> TrainableWeights => throw new NotImplementedException();
74+
public List<IVariableV1> Weights => throw new NotImplementedException();
7475
public List<IVariableV1> NonTrainableWeights => throw new NotImplementedException();
7576

7677
public Shape OutputShape => throw new NotImplementedException();

src/TensorFlowNET.Keras/Engine/Layer.Layers.cs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using System;
22
using System.Collections.Generic;
3+
using System.Linq;
34

45
namespace Tensorflow.Keras.Engine
56
{
@@ -14,5 +15,30 @@ protected void StackLayers(params ILayer[] layers)
1415

1516
public virtual Shape ComputeOutputShape(Shape input_shape)
1617
=> throw new NotImplementedException("");
18+
19+
protected List<IVariableV1> _gather_children_variables(bool include_trainable = false, bool include_non_trainable = false)
20+
{
21+
List<IVariableV1> res = new();
22+
var nested_layers = _flatten_layers(false, false);
23+
foreach (var layer in nested_layers)
24+
{
25+
if (layer is Layer l)
26+
{
27+
if (include_trainable == true && include_non_trainable == true)
28+
{
29+
res.AddRange(l.Variables);
30+
}
31+
else if (include_trainable == true && include_non_trainable == false)
32+
{
33+
res.AddRange(l.TrainableVariables);
34+
}
35+
else if(include_trainable == false && include_non_trainable == true)
36+
{
37+
res.AddRange(l.NonTrainableVariables);
38+
}
39+
}
40+
}
41+
return res;
42+
}
1743
}
1844
}

src/TensorFlowNET.Keras/Engine/Layer.cs

Lines changed: 51 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -67,10 +67,58 @@ public abstract partial class Layer : AutoTrackable, ILayer
6767
public bool SupportsMasking { get; set; }
6868
protected List<IVariableV1> _trainable_weights;
6969

70-
public virtual List<IVariableV1> TrainableVariables => _trainable_weights;
70+
public virtual List<IVariableV1> TrainableVariables => TrainableWeights;
7171

7272
protected List<IVariableV1> _non_trainable_weights;
73-
public List<IVariableV1> non_trainable_variables => _non_trainable_weights;
73+
public List<IVariableV1> NonTrainableVariables => NonTrainableWeights;
74+
public List<IVariableV1> Variables => Weights;
75+
76+
public virtual List<IVariableV1> TrainableWeights
77+
{
78+
get
79+
{
80+
if (!this.Trainable)
81+
{
82+
return new List<IVariableV1>();
83+
}
84+
var children_weights = _gather_children_variables(true);
85+
return children_weights.Concat(_trainable_weights).Distinct().ToList();
86+
}
87+
}
88+
89+
public virtual List<IVariableV1> NonTrainableWeights
90+
{
91+
get
92+
{
93+
if (!this.Trainable)
94+
{
95+
var children_weights = _gather_children_variables(true, true);
96+
return children_weights.Concat(_trainable_weights).Concat(_non_trainable_weights).Distinct().ToList();
97+
}
98+
else
99+
{
100+
var children_weights = _gather_children_variables(include_non_trainable: true);
101+
return children_weights.Concat(_non_trainable_weights).Distinct().ToList();
102+
}
103+
}
104+
}
105+
106+
public virtual List<IVariableV1> Weights
107+
{
108+
get
109+
{
110+
return TrainableWeights.Concat(NonTrainableWeights).ToList();
111+
}
112+
set
113+
{
114+
if (Weights.Count() != value.Count()) throw new ValueError(
115+
$"You called `set_weights` on layer \"{this.name}\"" +
116+
$"with a weight list of length {len(value)}, but the layer was " +
117+
$"expecting {len(Weights)} weights.");
118+
foreach (var (this_w, v_w) in zip(Weights, value))
119+
this_w.assign(v_w, read_value: true);
120+
}
121+
}
74122

75123
protected int id;
76124
public int Id => id;
@@ -290,46 +338,9 @@ protected virtual void _init_set_name(string name, bool zero_based = true)
290338
public int count_params()
291339
{
292340
if (Trainable)
293-
return layer_utils.count_params(this, weights);
341+
return layer_utils.count_params(this, Weights);
294342
return 0;
295343
}
296-
List<IVariableV1> ILayer.TrainableWeights
297-
{
298-
get
299-
{
300-
return _trainable_weights;
301-
}
302-
}
303-
304-
List<IVariableV1> ILayer.NonTrainableWeights
305-
{
306-
get
307-
{
308-
return _non_trainable_weights;
309-
}
310-
}
311-
312-
public List<IVariableV1> weights
313-
{
314-
get
315-
{
316-
var weights = new List<IVariableV1>();
317-
weights.AddRange(_trainable_weights);
318-
weights.AddRange(_non_trainable_weights);
319-
return weights;
320-
}
321-
set
322-
{
323-
if (weights.Count() != value.Count()) throw new ValueError(
324-
$"You called `set_weights` on layer \"{this.name}\"" +
325-
$"with a weight list of length {len(value)}, but the layer was " +
326-
$"expecting {len(weights)} weights.");
327-
foreach (var (this_w, v_w) in zip(weights, value))
328-
this_w.assign(v_w, read_value: true);
329-
}
330-
}
331-
332-
public List<IVariableV1> Variables => weights;
333344

334345
public virtual IKerasConfig get_config()
335346
=> args;

src/TensorFlowNET.Keras/Engine/Model.cs

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -89,10 +89,11 @@ void _init_batch_counters()
8989
public override List<ILayer> Layers
9090
=> _flatten_layers(recursive: false, include_self: false).ToList();
9191

92-
public override List<IVariableV1> TrainableVariables
92+
public override List<IVariableV1> TrainableWeights
9393
{
9494
get
9595
{
96+
// skip the assertion of weights created.
9697
var variables = new List<IVariableV1>();
9798

9899
if (!Trainable)
@@ -103,18 +104,40 @@ public override List<IVariableV1> TrainableVariables
103104
foreach (var trackable_obj in _self_tracked_trackables)
104105
{
105106
if (trackable_obj.Trainable)
106-
variables.AddRange(trackable_obj.TrainableVariables);
107+
variables.AddRange(trackable_obj.TrainableWeights);
107108
}
108109

109-
foreach (var layer in _self_tracked_trackables)
110+
variables.AddRange(_trainable_weights);
111+
112+
return variables.Distinct().ToList();
113+
}
114+
}
115+
116+
public override List<IVariableV1> NonTrainableWeights
117+
{
118+
get
119+
{
120+
// skip the assertion of weights created.
121+
var variables = new List<IVariableV1>();
122+
123+
foreach (var trackable_obj in _self_tracked_trackables)
110124
{
111-
if (layer.Trainable)
112-
variables.AddRange(layer.TrainableVariables);
125+
variables.AddRange(trackable_obj.NonTrainableWeights);
113126
}
114127

115-
// variables.AddRange(_trainable_weights);
128+
if (!Trainable)
129+
{
130+
var trainable_variables = new List<IVariableV1>();
131+
foreach (var trackable_obj in _self_tracked_trackables)
132+
{
133+
variables.AddRange(trackable_obj.TrainableWeights);
134+
}
135+
variables.AddRange(trainable_variables);
136+
variables.AddRange(_trainable_weights);
137+
variables.AddRange(_non_trainable_weights);
138+
}
116139

117-
return variables;
140+
return variables.Distinct().ToList();
118141
}
119142
}
120143

src/TensorFlowNET.Keras/Metrics/Metric.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ public virtual Tensor update_state(Tensor y_true, Tensor y_pred, Tensor sample_w
5656

5757
public virtual void reset_states()
5858
{
59-
foreach (var v in weights)
59+
foreach (var v in Weights)
6060
v.assign(0);
6161
}
6262

src/TensorFlowNET.Keras/Saving/SavedModel/Save.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ public static IDictionary<string, Trackable> wrap_layer_objects(Layer layer, IDi
130130
if (x is ResourceVariable or RefVariable) return (Trackable)x;
131131
else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer.");
132132
}));
133-
var non_trainable_variables = TrackableDataStructure.wrap_or_unwrap(layer.non_trainable_variables.Select(x =>
133+
var non_trainable_variables = TrackableDataStructure.wrap_or_unwrap(layer.NonTrainableVariables.Select(x =>
134134
{
135135
if (x is ResourceVariable or RefVariable) return (Trackable)x;
136136
else throw new TypeError($"The type{x.GetType()} is not supported for the wrapping of layer.");

src/TensorFlowNET.Keras/Utils/layer_utils.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ public static void print_summary(Model model, int line_length = -1, float[] posi
104104
}
105105

106106
var trainable_count = count_params(model, model.TrainableVariables);
107-
var non_trainable_count = count_params(model, model.non_trainable_variables);
107+
var non_trainable_count = count_params(model, model.NonTrainableVariables);
108108

109109
print($"Total params: {trainable_count + non_trainable_count}");
110110
print($"Trainable params: {trainable_count}");

test/TensorFlowNET.Keras.UnitTest/SaveModel/SequentialModelLoad.cs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ public class SequentialModelLoad
2121
[TestMethod]
2222
public void SimpleModelFromSequential()
2323
{
24-
new SequentialModelSave().SimpleModelFromSequential();
25-
var model = keras.models.load_model(@"./pb_simple_sequential");
24+
//new SequentialModelSave().SimpleModelFromSequential();
25+
var model = keras.models.load_model(@"D:\development\tf.net\tf_test\tf.net.simple.sequential");
2626

2727
model.summary();
2828

@@ -40,5 +40,6 @@ public void SimpleModelFromSequential()
4040
}).Result;
4141

4242
model.fit(dataset.Train.Data, dataset.Train.Labels, batch_size, num_epochs);
43+
model.summary();
4344
}
4445
}

0 commit comments

Comments
 (0)