Skip to content

Add EarlyStopping callback #1007

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 6 commits into from
Mar 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/TensorFlowNET.Core/Keras/Engine/ICallback.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ public interface ICallback
{
Dictionary<string, List<float>> history { get; set; }
void on_train_begin();
void on_train_end();
void on_epoch_begin(int epoch);
void on_train_batch_begin(long step);
void on_train_batch_end(long end_step, Dictionary<string, float> logs);
Expand Down
4 changes: 4 additions & 0 deletions src/TensorFlowNET.Core/Keras/Engine/IModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ ICallback fit(NDArray x, NDArray y,
int batch_size = -1,
int epochs = 1,
int verbose = 1,
List<ICallback> callbacks = null,
float validation_split = 0f,
bool shuffle = true,
int initial_epoch = 0,
Expand All @@ -28,6 +29,7 @@ ICallback fit(IEnumerable<NDArray> x, NDArray y,
int batch_size = -1,
int epochs = 1,
int verbose = 1,
List<ICallback> callbacks = null,
float validation_split = 0f,
bool shuffle = true,
int initial_epoch = 0,
Expand Down Expand Up @@ -73,4 +75,6 @@ Tensors predict(Tensors x,
void summary(int line_length = -1, float[] positions = null);

IKerasConfig get_config();

void set_stopTraining_true();
}
5 changes: 3 additions & 2 deletions src/TensorFlowNET.Keras/Callbacks/CallbackList.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ namespace Tensorflow.Keras.Callbacks;

public class CallbackList
{
List<ICallback> callbacks = new List<ICallback>();
// Made public so that user-defined callbacks can be added to this callback list.
public List<ICallback> callbacks = new List<ICallback>();
public History History => callbacks[0] as History;

public CallbackList(CallbackParams parameters)
Expand Down Expand Up @@ -66,7 +67,7 @@ public void on_predict_end()

public void on_test_batch_begin(long step)
{
callbacks.ForEach(x => x.on_train_batch_begin(step));
callbacks.ForEach(x => x.on_test_batch_begin(step));
}
public void on_test_batch_end(long end_step, IEnumerable<(string, Tensor)> logs)
{
Expand Down
155 changes: 155 additions & 0 deletions src/TensorFlowNET.Keras/Callbacks/Earlystopping.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,155 @@
using Tensorflow.Keras.Engine;
namespace Tensorflow.Keras.Callbacks;


/// <summary>
/// Stop training when a monitored metric has stopped improving.
/// The user needs to pass a <see cref="CallbackParams"/> (which must at least
/// carry the model) plus the name of the metric to monitor.
/// </summary>
public class EarlyStopping : ICallback
{
    int _patience;              // epochs with no improvement allowed before stopping
    int _min_delta;             // minimum change in the monitored value to count as improvement
    int _verbose;
    int _stopped_epoch;         // epoch index at which training was stopped (0 = never stopped)
    int _wait;                  // epochs elapsed since the last improvement
    int _best_epoch;
    int _start_from_epoch;      // warm-up: epochs before this index are ignored
    float _best;                // best monitored value seen so far
    float _baseline;            // 0f acts as the "no baseline" sentinel
    string _monitor;
    string _mode;               // "auto" | "min" | "max"
    bool _restore_best_weights;
    List<IVariableV1>? _best_weights;
    CallbackParams _parameters;
    public Dictionary<string, List<float>>? history { get; set; }

    // The user needs to pass a CallbackParams to EarlyStopping; CallbackParams at least needs the model.
    public EarlyStopping(CallbackParams parameters,string monitor = "val_loss", int min_delta = 0, int patience = 0,
        int verbose = 1, string mode = "auto", float baseline = 0f, bool restore_best_weights = false,
        int start_from_epoch = 0)
    {
        _parameters = parameters;
        _stopped_epoch = 0;
        _wait = 0;
        _monitor = monitor;
        _patience = patience;
        _verbose = verbose;
        _baseline = baseline;
        _start_from_epoch = start_from_epoch;
        _min_delta = Math.Abs(min_delta);
        _restore_best_weights = restore_best_weights;
        _mode = mode;
        if (mode != "auto" && mode != "min" && mode != "max")
        {
            // FIX: the original used printf-style "%s", which .NET composite formatting
            // never substitutes, and it also never actually fell back to auto mode.
            Console.WriteLine($"EarlyStopping mode {mode} is unknown, fallback to auto mode.");
            _mode = "auto";
        }
    }

    public void on_train_begin()
    {
        _wait = 0;
        _stopped_epoch = 0;
        _best_epoch = 0;
        // FIX: start from the worst value for the monitored direction. The original
        // always used +Inf, so in "max"/accuracy mode no epoch could ever count as an
        // improvement and training stopped after `patience` epochs regardless of progress.
        _best = _greater_is_better() ? float.NegativeInfinity : float.PositiveInfinity;
    }

    public void on_epoch_begin(int epoch)
    {

    }

    public void on_train_batch_begin(long step)
    {

    }

    public void on_train_batch_end(long end_step, Dictionary<string, float> logs)
    {
    }

    public void on_epoch_end(int epoch, Dictionary<string, float> epoch_logs)
    {
        var current = get_monitor_value(epoch_logs);
        // If no monitor value exists (0f sentinel) or still in initial warm-up stage.
        if (current == 0f || epoch < _start_from_epoch)
            return;
        // Restore the weights after first epoch if no progress is ever made.
        if (_restore_best_weights && _best_weights == null)
        {
            _best_weights = _parameters.Model.TrainableWeights;
        }
        _wait += 1;

        if (_is_improvement(current, _best))
        {
            _best = current;
            _best_epoch = epoch;
            if (_restore_best_weights)
                _best_weights = _parameters.Model.TrainableWeights;
            // Only restart wait if we beat both the baseline and our previous best.
            if (_baseline == 0f || _is_improvement(current, _baseline))
                _wait = 0;
        }
        // Only check after the first epoch.
        if (_wait >= _patience && epoch > 0)
        {
            _stopped_epoch = epoch;
            _parameters.Model.set_stopTraining_true();
            if (_restore_best_weights && _best_weights != null)
            {
                if (_verbose > 0)
                {
                    Console.WriteLine($"Restoring model weights from the end of the best epoch: {_best_epoch + 1}");
                }
            }
            // Because loading the weight variable into the model has not yet been implemented, so Earlystopping can't load best_weight yet.
            // TODO(Wanglongzhi2001): implement it.
            // _parameters.Model.load_weights(best_weights);
        }
    }

    public void on_train_end()
    {
        if (_stopped_epoch > 0 && _verbose > 0)
        {
            Console.WriteLine($"Epoch {_stopped_epoch + 1}: early stopping");
        }
    }
    public void on_predict_begin() { }
    public void on_predict_batch_begin(long step) { }
    public void on_predict_batch_end(long end_step, Dictionary<string, Tensors> logs) { }
    public void on_predict_end() { }
    public void on_test_begin() { }
    public void on_test_batch_begin(long step) { }
    public void on_test_batch_end(long end_step, IEnumerable<(string, Tensor)> logs) { }

    /// <summary>
    /// Look up the monitored metric in <paramref name="logs"/>.
    /// Returns 0f (the "missing" sentinel used by on_epoch_end) when absent.
    /// </summary>
    float get_monitor_value(Dictionary<string, float> logs)
    {
        logs = logs ?? new Dictionary<string, float>();
        // FIX: the original indexed logs[_monitor] directly, which throws
        // KeyNotFoundException before the "not available" warning could ever print.
        if (!logs.TryGetValue(_monitor, out float monitor_value))
        {
            Console.WriteLine($"Early stopping conditioned on metric {_monitor} " +
                $"which is not available. Available metrics are: {string.Join(", ", logs.Keys)}");
            return 0f;
        }
        return monitor_value;
    }

    // True when larger monitored values are better: explicit "max" mode, or "auto"
    // mode with an accuracy-like metric name (same name heuristic as Keras).
    bool _greater_is_better()
    {
        if (_mode == "max")
            return true;
        if (_mode == "min")
            return false;
        return _monitor.EndsWith("acc") || _monitor.EndsWith("accuracy") || _monitor.EndsWith("auc");
    }

    public bool _is_improvement(float monitor_value, float reference_value)
    {
        // Direction-aware comparison; min_delta shifts the value to demand a
        // minimum amount of improvement (comparison operators kept as in the original).
        return _greater_is_better()
            ? (monitor_value - _min_delta) >= reference_value
            : (monitor_value - _min_delta) < reference_value;
    }
}
1 change: 1 addition & 0 deletions src/TensorFlowNET.Keras/Callbacks/History.cs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ public void on_test_begin()
epochs = new List<int>();
history = new Dictionary<string, List<float>>();
}
public void on_train_end() { }
public void on_epoch_begin(int epoch)
{

Expand Down
1 change: 1 addition & 0 deletions src/TensorFlowNET.Keras/Callbacks/ProgbarLogger.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ public void on_train_begin()
_called_in_fit = true;
_sw = new Stopwatch();
}
public void on_train_end() { }
public void on_test_begin()
{
_sw = new Stopwatch();
Expand Down
21 changes: 16 additions & 5 deletions src/TensorFlowNET.Keras/Engine/Model.Fit.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,15 @@ public partial class Model
/// <param name="y"></param>
/// <param name="batch_size"></param>
/// <param name="epochs"></param>
/// <param name="callbacks"></param>
/// <param name="verbose"></param>
/// <param name="validation_split"></param>
/// <param name="shuffle"></param>
public ICallback fit(NDArray x, NDArray y,
int batch_size = -1,
int epochs = 1,
int verbose = 1,
List<ICallback> callbacks = null,
float validation_split = 0f,
bool shuffle = true,
int initial_epoch = 0,
Expand Down Expand Up @@ -59,14 +61,15 @@ public ICallback fit(NDArray x, NDArray y,
StepsPerExecution = _steps_per_execution
});

return FitInternal(data_handler, epochs, verbose, validation_data: null,
return FitInternal(data_handler, epochs, verbose, callbackList: callbacks, validation_data: null,
train_step_func: train_step_function);
}

public ICallback fit(IEnumerable<NDArray> x, NDArray y,
int batch_size = -1,
int epochs = 1,
int verbose = 1,
List<ICallback> callbacks = null,
float validation_split = 0f,
bool shuffle = true,
int initial_epoch = 0,
Expand Down Expand Up @@ -107,12 +110,12 @@ public ICallback fit(IEnumerable<NDArray> x, NDArray y,
if (data_handler.DataAdapter.GetDataset().structure.Length > 2 ||
data_handler.DataAdapter.GetDataset().FirstInputTensorCount > 1)
{
return FitInternal(data_handler, epochs, verbose, validation_data: null,
return FitInternal(data_handler, epochs, verbose, callbackList: callbacks, validation_data: null,
train_step_func: train_step_multi_inputs_function);
}
else
{
return FitInternal(data_handler, epochs, verbose, validation_data: null,
return FitInternal(data_handler, epochs, verbose, callbackList: callbacks, validation_data: null,
train_step_func: train_step_function);
}
}
Expand All @@ -122,6 +125,7 @@ public History fit(IDatasetV2 dataset,
int batch_size = -1,
int epochs = 1,
int verbose = 1,
List<ICallback> callbacks = null,
float validation_split = 0f,
bool shuffle = true,
int initial_epoch = 0,
Expand All @@ -143,11 +147,11 @@ public History fit(IDatasetV2 dataset,
StepsPerExecution = _steps_per_execution
});

return FitInternal(data_handler, epochs, verbose, validation_data: validation_data,
return FitInternal(data_handler, epochs, verbose, callbacks, validation_data: validation_data,
train_step_func: train_step_function);
}

History FitInternal(DataHandler data_handler, int epochs, int verbose, IDatasetV2 validation_data,
History FitInternal(DataHandler data_handler, int epochs, int verbose, List<ICallback> callbackList, IDatasetV2 validation_data,
Func<DataHandler, OwnedIterator, Dictionary<string, float>> train_step_func)
{
stop_training = false;
Expand All @@ -159,6 +163,13 @@ History FitInternal(DataHandler data_handler, int epochs, int verbose, IDatasetV
Epochs = epochs,
Steps = data_handler.Inferredsteps
});

if (callbackList != null)
{
foreach(var callback in callbackList)
callbacks.callbacks.add(callback);
}

callbacks.on_train_begin();

foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
Expand Down
6 changes: 6 additions & 0 deletions src/TensorFlowNET.Keras/Engine/Model.cs
Original file line number Diff line number Diff line change
Expand Up @@ -144,5 +144,11 @@ public override IDictionary<string, Trackable> _trackable_children(SaveType save
var children = base._trackable_children(save_type, cache);
return children;
}


/// <summary>
/// Explicit <c>IModel</c> implementation that sets <c>stop_training</c>, signalling
/// the fit loop to stop; invoked by callbacks such as EarlyStopping.
/// </summary>
void IModel.set_stopTraining_true()
{
stop_training = true;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
using Microsoft.VisualStudio.TestTools.UnitTesting;
using Tensorflow.Keras.UnitTest.Helpers;
using static Tensorflow.Binding;
using Tensorflow;
using Tensorflow.Keras.Optimizers;
using Tensorflow.Keras.Callbacks;
using Tensorflow.Keras.Engine;
using System.Collections.Generic;
using static Tensorflow.KerasApi;
using Tensorflow.Keras;


namespace TensorFlowNET.Keras.UnitTest
{
    [TestClass]
    public class EarltstoppingTest
    {
        [TestMethod]
        // Because loading the weight variable into the model has not yet been implemented,
        // you'd better not set patience too large, because the weights will equal the last epoch's weights.
        public void Earltstopping()
        {
            // Small CNN classifier for CIFAR-10, just enough to drive the callback.
            var layers = keras.layers;
            var model = keras.Sequential(new List<ILayer>
            {
                layers.Rescaling(1.0f / 255, input_shape: (32, 32, 3)),
                layers.Conv2D(32, 3, padding: "same", activation: keras.activations.Relu),
                layers.MaxPooling2D(),
                layers.Flatten(),
                layers.Dense(128, activation: keras.activations.Relu),
                layers.Dense(10)
            });

            model.summary();

            model.compile(optimizer: keras.optimizers.RMSprop(1e-3f),
                loss: keras.losses.SparseCategoricalCrossentropy(from_logits: true),
                metrics: new[] { "acc" });

            var num_epochs = 3;
            var batch_size = 8;

            var ((x_train, y_train), (x_test, y_test)) = keras.datasets.cifar10.load_data();
            x_train = x_train / 255.0f;
            // Define a CallbackParams first; the parameters you pass must at least contain Model and Epochs.
            CallbackParams callback_parameters = new CallbackParams
            {
                Model = model,
                Epochs = num_epochs,
            };
            // Define your early stopping callback.
            ICallback earlystop = new EarlyStopping(callback_parameters, "accuracy");
            // Define a callback list, then add the early stopping callback to it.
            // FIX: List<T> exposes Add (PascalCase); the original `callbacks.add(...)` does not compile.
            var callbacks = new List<ICallback>();
            callbacks.Add(earlystop);

            // Train on a 2000-sample slice to keep the test quick.
            model.fit(x_train[new Slice(0, 2000)], y_train[new Slice(0, 2000)], batch_size, num_epochs, callbacks: callbacks);
        }
    }
}