Skip to content

Commit d639ce3

Browse files
authored
Merge pull request #1007 from Wanglongzhi2001/master
Add EarlyStopping callback
2 parents 3aa2738 + a23b80c commit d639ce3

File tree

9 files changed

+252
-7
lines changed

9 files changed

+252
-7
lines changed

src/TensorFlowNET.Core/Keras/Engine/ICallback.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ public interface ICallback
44
{
55
Dictionary<string, List<float>> history { get; set; }
66
void on_train_begin();
7+
void on_train_end();
78
void on_epoch_begin(int epoch);
89
void on_train_batch_begin(long step);
910
void on_train_batch_end(long end_step, Dictionary<string, float> logs);

src/TensorFlowNET.Core/Keras/Engine/IModel.cs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ ICallback fit(NDArray x, NDArray y,
1717
int batch_size = -1,
1818
int epochs = 1,
1919
int verbose = 1,
20+
List<ICallback> callbacks = null,
2021
float validation_split = 0f,
2122
bool shuffle = true,
2223
int initial_epoch = 0,
@@ -28,6 +29,7 @@ ICallback fit(IEnumerable<NDArray> x, NDArray y,
2829
int batch_size = -1,
2930
int epochs = 1,
3031
int verbose = 1,
32+
List<ICallback> callbacks = null,
3133
float validation_split = 0f,
3234
bool shuffle = true,
3335
int initial_epoch = 0,
@@ -73,4 +75,6 @@ Tensors predict(Tensors x,
7375
void summary(int line_length = -1, float[] positions = null);
7476

7577
IKerasConfig get_config();
78+
79+
void set_stopTraining_true();
7680
}

src/TensorFlowNET.Keras/Callbacks/CallbackList.cs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@ namespace Tensorflow.Keras.Callbacks;
77

88
public class CallbackList
99
{
10-
List<ICallback> callbacks = new List<ICallback>();
10+
// Made public so that user-defined callbacks can be added to the callbacks list.
11+
public List<ICallback> callbacks = new List<ICallback>();
1112
public History History => callbacks[0] as History;
1213

1314
public CallbackList(CallbackParams parameters)
@@ -66,7 +67,7 @@ public void on_predict_end()
6667

6768
public void on_test_batch_begin(long step)
6869
{
69-
callbacks.ForEach(x => x.on_train_batch_begin(step));
70+
callbacks.ForEach(x => x.on_test_batch_begin(step));
7071
}
7172
public void on_test_batch_end(long end_step, IEnumerable<(string, Tensor)> logs)
7273
{
Lines changed: 155 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,155 @@
using Tensorflow.Keras.Engine;
namespace Tensorflow.Keras.Callbacks;


/// <summary>
/// Stop training when a monitored metric has stopped improving.
/// </summary>
/// <remarks>
/// The caller must supply a <see cref="CallbackParams"/> that at least carries the Model;
/// the callback asks the model to stop via <c>set_stopTraining_true()</c>.
/// </remarks>
public class EarlyStopping : ICallback
{
    int _patience;              // epochs with no improvement allowed before stopping
    float _min_delta;           // minimum change that counts as an improvement
    int _verbose;
    int _stopped_epoch;         // epoch at which training was stopped (0 = never stopped)
    int _wait;                  // epochs elapsed since the last improvement
    int _best_epoch;
    int _start_from_epoch;      // warm-up: epochs to ignore before monitoring starts
    float _best;
    float _baseline;
    string _monitor;
    string _mode;               // "auto" | "min" | "max"
    bool _restore_best_weights;
    List<IVariableV1>? _best_weights;
    CallbackParams _parameters;
    public Dictionary<string, List<float>>? history { get; set; }

    // User needs to pass a CallbackParams to EarlyStopping; it must at least carry the model.
    public EarlyStopping(CallbackParams parameters, string monitor = "val_loss", float min_delta = 0f, int patience = 0,
        int verbose = 1, string mode = "auto", float baseline = 0f, bool restore_best_weights = false,
        int start_from_epoch = 0)
    {
        _parameters = parameters;
        _stopped_epoch = 0;
        _wait = 0;
        _monitor = monitor;
        _patience = patience;
        _verbose = verbose;
        _baseline = baseline;
        _start_from_epoch = start_from_epoch;
        _min_delta = Math.Abs(min_delta);
        _restore_best_weights = restore_best_weights;
        _mode = mode;
        if (mode != "auto" && mode != "min" && mode != "max")
        {
            // Use interpolation, not C-style "%s": Console.WriteLine never substitutes printf tokens.
            Console.WriteLine($"EarlyStopping mode {mode} is unknown, fallback to auto mode.");
            _mode = "auto";
        }
    }

    public void on_train_begin()
    {
        _wait = 0;
        _stopped_epoch = 0;
        _best_epoch = 0;
        // Start from the worst possible value for the monitored direction, so the
        // first observed value always becomes the new best (Keras does the same).
        _best = _is_maximizing() ? float.NegativeInfinity : float.PositiveInfinity;
    }

    public void on_epoch_begin(int epoch)
    {

    }

    public void on_train_batch_begin(long step)
    {

    }

    public void on_train_batch_end(long end_step, Dictionary<string, float> logs)
    {
    }

    public void on_epoch_end(int epoch, Dictionary<string, float> epoch_logs)
    {
        var current = get_monitor_value(epoch_logs);
        // Skip if the monitored metric is missing or we are still in the warm-up stage.
        if (current is null || epoch < _start_from_epoch)
            return;
        // Remember the weights after the first monitored epoch so there is always
        // something to restore even if no progress is ever made.
        if (_restore_best_weights && _best_weights == null)
        {
            // NOTE(review): this stores references to the live variables, not a snapshot,
            // so the "best" weights track the current ones until weight cloning exists.
            _best_weights = _parameters.Model.TrainableWeights;
        }
        _wait += 1;

        if (_is_improvement(current.Value, _best))
        {
            _best = current.Value;
            _best_epoch = epoch;
            if (_restore_best_weights)
                _best_weights = _parameters.Model.TrainableWeights;
            // Only restart wait if we beat both the baseline and our previous best.
            // (baseline == 0f is treated as "no baseline", matching the default.)
            if (_baseline == 0f || _is_improvement(current.Value, _baseline))
                _wait = 0;
        }
        // Only check after the first epoch.
        if (_wait >= _patience && epoch > 0)
        {
            _stopped_epoch = epoch;
            _parameters.Model.set_stopTraining_true();
            if (_restore_best_weights && _best_weights != null)
            {
                if (_verbose > 0)
                {
                    Console.WriteLine($"Restoring model weights from the end of the best epoch: {_best_epoch + 1}");
                }
            }
            // Because loading the weight variable into the model has not yet been implemented,
            // EarlyStopping can't load best_weight yet.
            // TODO(Wanglongzhi2001): implement it.
            // _parameters.Model.load_weights(best_weights);
        }
    }

    public void on_train_end()
    {
        if (_stopped_epoch > 0 && _verbose > 0)
        {
            Console.WriteLine($"Epoch {_stopped_epoch + 1}: early stopping");
        }
    }

    public void on_predict_begin() { }
    public void on_predict_batch_begin(long step) { }
    public void on_predict_batch_end(long end_step, Dictionary<string, Tensors> logs) { }
    public void on_predict_end() { }
    public void on_test_begin() { }
    public void on_test_batch_begin(long step) { }
    public void on_test_batch_end(long end_step, IEnumerable<(string, Tensor)> logs) { }

    // Returns the monitored metric for this epoch, or null (with a warning) when the
    // metric is absent from the logs. A nullable return avoids both the
    // KeyNotFoundException of direct indexing and the ambiguity of a 0f sentinel.
    float? get_monitor_value(Dictionary<string, float> logs)
    {
        logs = logs ?? new Dictionary<string, float>();
        if (!logs.TryGetValue(_monitor, out var monitor_value))
        {
            Console.WriteLine($"Early stopping conditioned on metric {_monitor} " +
                $"which is not available. Available metrics are: {string.Join(", ", logs.Keys)}");
            return null;
        }
        return monitor_value;
    }

    // True when the monitored metric should be maximized (mode "max", or accuracy-like
    // metrics under "auto"); otherwise it is minimized, like a loss.
    bool _is_maximizing()
    {
        if (_mode == "max")
            return true;
        if (_mode == "min")
            return false;
        return _monitor.EndsWith("acc") || _monitor.EndsWith("accuracy") || _monitor.EndsWith("auc");
    }

    // An improvement must beat the reference by strictly more than min_delta,
    // in the direction implied by the mode/metric (Keras semantics).
    public bool _is_improvement(float monitor_value, float reference_value)
    {
        return _is_maximizing()
            ? (monitor_value - _min_delta) > reference_value
            : (monitor_value + _min_delta) < reference_value;
    }
}

src/TensorFlowNET.Keras/Callbacks/History.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ public void on_test_begin()
2323
epochs = new List<int>();
2424
history = new Dictionary<string, List<float>>();
2525
}
26+
public void on_train_end() { }
2627
public void on_epoch_begin(int epoch)
2728
{
2829

src/TensorFlowNET.Keras/Callbacks/ProgbarLogger.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ public void on_train_begin()
2222
_called_in_fit = true;
2323
_sw = new Stopwatch();
2424
}
25+
public void on_train_end() { }
2526
public void on_test_begin()
2627
{
2728
_sw = new Stopwatch();

src/TensorFlowNET.Keras/Engine/Model.Fit.cs

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,15 @@ public partial class Model
1919
/// <param name="y"></param>
2020
/// <param name="batch_size"></param>
2121
/// <param name="epochs"></param>
22+
/// <param name="callbacks"></param>
2223
/// <param name="verbose"></param>
2324
/// <param name="validation_split"></param>
2425
/// <param name="shuffle"></param>
2526
public ICallback fit(NDArray x, NDArray y,
2627
int batch_size = -1,
2728
int epochs = 1,
2829
int verbose = 1,
30+
List<ICallback> callbacks = null,
2931
float validation_split = 0f,
3032
bool shuffle = true,
3133
int initial_epoch = 0,
@@ -59,14 +61,15 @@ public ICallback fit(NDArray x, NDArray y,
5961
StepsPerExecution = _steps_per_execution
6062
});
6163

62-
return FitInternal(data_handler, epochs, verbose, validation_data: null,
64+
return FitInternal(data_handler, epochs, verbose, callbackList: callbacks, validation_data: null,
6365
train_step_func: train_step_function);
6466
}
6567

6668
public ICallback fit(IEnumerable<NDArray> x, NDArray y,
6769
int batch_size = -1,
6870
int epochs = 1,
6971
int verbose = 1,
72+
List<ICallback> callbacks = null,
7073
float validation_split = 0f,
7174
bool shuffle = true,
7275
int initial_epoch = 0,
@@ -107,12 +110,12 @@ public ICallback fit(IEnumerable<NDArray> x, NDArray y,
107110
if (data_handler.DataAdapter.GetDataset().structure.Length > 2 ||
108111
data_handler.DataAdapter.GetDataset().FirstInputTensorCount > 1)
109112
{
110-
return FitInternal(data_handler, epochs, verbose, validation_data: null,
113+
return FitInternal(data_handler, epochs, verbose, callbackList: callbacks, validation_data: null,
111114
train_step_func: train_step_multi_inputs_function);
112115
}
113116
else
114117
{
115-
return FitInternal(data_handler, epochs, verbose, validation_data: null,
118+
return FitInternal(data_handler, epochs, verbose, callbackList: callbacks, validation_data: null,
116119
train_step_func: train_step_function);
117120
}
118121
}
@@ -122,6 +125,7 @@ public History fit(IDatasetV2 dataset,
122125
int batch_size = -1,
123126
int epochs = 1,
124127
int verbose = 1,
128+
List<ICallback> callbacks = null,
125129
float validation_split = 0f,
126130
bool shuffle = true,
127131
int initial_epoch = 0,
@@ -143,11 +147,11 @@ public History fit(IDatasetV2 dataset,
143147
StepsPerExecution = _steps_per_execution
144148
});
145149

146-
return FitInternal(data_handler, epochs, verbose, validation_data: validation_data,
150+
return FitInternal(data_handler, epochs, verbose, callbacks, validation_data: validation_data,
147151
train_step_func: train_step_function);
148152
}
149153

150-
History FitInternal(DataHandler data_handler, int epochs, int verbose, IDatasetV2 validation_data,
154+
History FitInternal(DataHandler data_handler, int epochs, int verbose, List<ICallback> callbackList, IDatasetV2 validation_data,
151155
Func<DataHandler, OwnedIterator, Dictionary<string, float>> train_step_func)
152156
{
153157
stop_training = false;
@@ -159,6 +163,13 @@ History FitInternal(DataHandler data_handler, int epochs, int verbose, IDatasetV
159163
Epochs = epochs,
160164
Steps = data_handler.Inferredsteps
161165
});
166+
167+
if (callbackList != null)
168+
{
169+
foreach(var callback in callbackList)
170+
callbacks.callbacks.add(callback);
171+
}
172+
162173
callbacks.on_train_begin();
163174

164175
foreach (var (epoch, iterator) in data_handler.enumerate_epochs())

src/TensorFlowNET.Keras/Engine/Model.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,5 +144,11 @@ public override IDictionary<string, Trackable> _trackable_children(SaveType save
144144
var children = base._trackable_children(save_type, cache);
145145
return children;
146146
}
147+
148+
149+
void IModel.set_stopTraining_true()
150+
{
151+
stop_training = true;
152+
}
147153
}
148154
}
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow.Keras.UnitTest.Helpers;
using static Tensorflow.Binding;
using Tensorflow;
using Tensorflow.Keras.Optimizers;
using Tensorflow.Keras.Callbacks;
using Tensorflow.Keras.Engine;
using System.Collections.Generic;
using static Tensorflow.KerasApi;
using Tensorflow.Keras;


namespace TensorFlowNET.Keras.UnitTest
{
    [TestClass]
    public class EarltstoppingTest
    {
        [TestMethod]
        // Because loading the weight variable into the model has not yet been implemented,
        // you'd better not set patience too large: on stop, the weights simply equal the
        // last epoch's weights.
        public void Earltstopping()
        {
            var layers = keras.layers;
            // Small CNN on CIFAR-10, just enough to drive the fit loop.
            var model = keras.Sequential(new List<ILayer>
            {
                layers.Rescaling(1.0f / 255, input_shape: (32, 32, 3)),
                layers.Conv2D(32, 3, padding: "same", activation: keras.activations.Relu),
                layers.MaxPooling2D(),
                layers.Flatten(),
                layers.Dense(128, activation: keras.activations.Relu),
                layers.Dense(10)
            });


            model.summary();

            model.compile(optimizer: keras.optimizers.RMSprop(1e-3f),
                loss: keras.losses.SparseCategoricalCrossentropy(from_logits: true),
                metrics: new[] { "acc" });

            var num_epochs = 3;
            var batch_size = 8;

            var ((x_train, y_train), (x_test, y_test)) = keras.datasets.cifar10.load_data();
            // NOTE(review): the model's Rescaling layer already divides by 255, so this
            // line rescales the inputs a second time — confirm whether that is intentional.
            x_train = x_train / 255.0f;
            // Define a CallbackParams first; the parameters passed must at least contain
            // Model and Epochs.
            CallbackParams callback_parameters = new CallbackParams
            {
                Model = model,
                Epochs = num_epochs,
            };
            // Define the early-stopping callback.
            ICallback earlystop = new EarlyStopping(callback_parameters, "accuracy");
            // Define a callback list, then add the early stopping callback to it.
            // Use List<T>.Add (the standard method) rather than a lowercase `add`
            // extension, which only resolves when the binding extensions are in scope.
            var callbacks = new List<ICallback>();
            callbacks.Add(earlystop);

            model.fit(x_train[new Slice(0, 2000)], y_train[new Slice(0, 2000)], batch_size, num_epochs, callbacks: callbacks);
        }

    }


}

0 commit comments

Comments
 (0)