Skip to content

Commit ad36e37

Browse files
authored
Merge pull request #1021 from Wanglongzhi2001/master
Finish EarlyStopping
2 parents 14da379 + 059ad48 commit ad36e37

File tree

3 files changed

+4
-2
lines changed

3 files changed

+4
-2
lines changed

src/TensorFlowNET.Core/Keras/Layers/ILayer.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ public interface ILayer: IWithTrackable, IKerasConfigable
1717
List<IVariableV1> TrainableVariables { get; }
1818
List<IVariableV1> TrainableWeights { get; }
1919
List<IVariableV1> NonTrainableWeights { get; }
20-
List<IVariableV1> Weights { get; set}
20+
List<IVariableV1> Weights { get; set; }
2121
Shape OutputShape { get; }
2222
Shape BatchInputShape { get; }
2323
TensorShapeConfig BuildInputShape { get; }

src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,8 @@ public abstract class RnnCell : ILayer, RNNArgs.IRnnArgCell
8484
protected bool built = false;
8585
public bool Built => built;
8686

87+
List<IVariableV1> ILayer.Weights { get => throw new NotImplementedException(); set => throw new NotImplementedException(); }
88+
8789
public RnnCell(bool trainable = true,
8890
string name = null,
8991
TF_DataType dtype = TF_DataType.DtInvalid,

src/TensorFlowNET.Keras/Callbacks/Earlystopping.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,8 +102,8 @@ public void on_epoch_end(int epoch, Dictionary<string, float> epoch_logs)
102102
{
103103
Console.WriteLine($"Restoring model weights from the end of the best epoch: {_best_epoch + 1}");
104104
}
105+
_parameters.Model.Weights = _best_weights;
105106
}
106-
_parameters.Model.Weights = _best_weights;
107107
}
108108
}
109109
public void on_train_end()

0 commit comments

Comments
 (0)