Skip to content

Commit c5cdf2c

Browse files
committed
Fixed model.fit return results. #927
1 parent f48ba40 commit c5cdf2c

File tree

11 files changed

+258
-63
lines changed

11 files changed

+258
-63
lines changed

src/TensorFlowNET.Core/NumPy/Numpy.Manipulation.cs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,10 @@ namespace Tensorflow.NumPy
88
{
99
public partial class np
1010
{
11+
[AutoNumPy]
12+
public static NDArray concatenate((NDArray, NDArray) tuple, int axis = 0)
13+
=> new NDArray(array_ops.concat(new[] { tuple.Item1, tuple.Item2 }, axis));
14+
1115
[AutoNumPy]
1216
public static NDArray concatenate(NDArray[] arrays, int axis = 0) => new NDArray(array_ops.concat(arrays, axis));
1317

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Keras.Callbacks
6+
{
7+
public class CallbackList
8+
{
9+
List<ICallback> callbacks = new List<ICallback>();
10+
public History History => callbacks[0] as History;
11+
12+
public CallbackList(CallbackParams parameters)
13+
{
14+
callbacks.Add(new History(parameters));
15+
callbacks.Add(new ProgbarLogger(parameters));
16+
}
17+
18+
public void on_train_begin()
19+
{
20+
callbacks.ForEach(x => x.on_train_begin());
21+
}
22+
23+
public void on_epoch_begin(int epoch)
24+
{
25+
callbacks.ForEach(x => x.on_epoch_begin(epoch));
26+
}
27+
28+
public void on_train_batch_begin(long step)
29+
{
30+
callbacks.ForEach(x => x.on_train_batch_begin(step));
31+
}
32+
33+
public void on_train_batch_end(long end_step, Dictionary<string, float> logs)
34+
{
35+
callbacks.ForEach(x => x.on_train_batch_end(end_step, logs));
36+
}
37+
38+
public void on_epoch_end(int epoch, Dictionary<string, float> epoch_logs)
39+
{
40+
callbacks.ForEach(x => x.on_epoch_end(epoch, epoch_logs));
41+
}
42+
}
43+
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using Tensorflow.Keras.Engine;
5+
6+
namespace Tensorflow.Keras.Callbacks
7+
{
8+
public class CallbackParams
9+
{
10+
public IModel Model { get; set; }
11+
public int Verbose { get; set; }
12+
public int Epochs { get; set; }
13+
public long Steps { get; set; }
14+
}
15+
}
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Keras.Callbacks
6+
{
7+
public class History : ICallback
8+
{
9+
List<int> epochs;
10+
CallbackParams _parameters;
11+
public Dictionary<string, List<float>> history { get; set; }
12+
13+
public History(CallbackParams parameters)
14+
{
15+
_parameters = parameters;
16+
}
17+
18+
public void on_train_begin()
19+
{
20+
epochs = new List<int>();
21+
history = new Dictionary<string, List<float>>();
22+
}
23+
24+
public void on_epoch_begin(int epoch)
25+
{
26+
27+
}
28+
29+
public void on_train_batch_begin(long step)
30+
{
31+
32+
}
33+
34+
public void on_train_batch_end(long end_step, Dictionary<string, float> logs)
35+
{
36+
}
37+
38+
public void on_epoch_end(int epoch, Dictionary<string, float> epoch_logs)
39+
{
40+
epochs.Add(epoch);
41+
42+
foreach (var log in epoch_logs)
43+
{
44+
if (!history.ContainsKey(log.Key))
45+
{
46+
history[log.Key] = new List<float>();
47+
}
48+
history[log.Key].Add((float)log.Value);
49+
}
50+
}
51+
}
52+
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Keras.Callbacks
6+
{
7+
public interface ICallback
8+
{
9+
void on_train_begin();
10+
void on_epoch_begin(int epoch);
11+
void on_train_batch_begin(long step);
12+
void on_train_batch_end(long end_step, Dictionary<string, float> logs);
13+
void on_epoch_end(int epoch, Dictionary<string, float> epoch_logs);
14+
}
15+
}
Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,81 @@
1+
using PureHDF;
2+
using System;
3+
using System.Collections.Generic;
4+
using System.Diagnostics;
5+
using System.Linq;
6+
using System.Text;
7+
8+
namespace Tensorflow.Keras.Callbacks
9+
{
10+
public class ProgbarLogger : ICallback
11+
{
12+
bool _called_in_fit = false;
13+
int seen = 0;
14+
CallbackParams _parameters;
15+
Stopwatch _sw;
16+
17+
public ProgbarLogger(CallbackParams parameters)
18+
{
19+
_parameters = parameters;
20+
}
21+
22+
public void on_train_begin()
23+
{
24+
_called_in_fit = true;
25+
_sw = new Stopwatch();
26+
}
27+
28+
public void on_epoch_begin(int epoch)
29+
{
30+
_reset_progbar();
31+
_maybe_init_progbar();
32+
Binding.tf_output_redirect.WriteLine($"Epoch: {epoch + 1:D3}/{_parameters.Epochs:D3}");
33+
}
34+
35+
public void on_train_batch_begin(long step)
36+
{
37+
_sw.Restart();
38+
}
39+
40+
public void on_train_batch_end(long end_step, Dictionary<string, float> logs)
41+
{
42+
_sw.Stop();
43+
var elapse = _sw.ElapsedMilliseconds;
44+
var results = string.Join(" - ", logs.Select(x => $"{x.Key}: {(float)x.Value:F6}"));
45+
46+
var progress = "";
47+
var length = 30.0 / _parameters.Steps;
48+
for (int i = 0; i < Math.Floor(end_step * length - 1); i++)
49+
progress += "=";
50+
if (progress.Length < 28)
51+
progress += ">";
52+
else
53+
progress += "=";
54+
55+
var remaining = "";
56+
for (int i = 1; i < 30 - progress.Length; i++)
57+
remaining += ".";
58+
59+
Binding.tf_output_redirect.Write($"{end_step + 1:D4}/{_parameters.Steps:D4} [{progress}{remaining}] - {elapse}ms/step - {results}");
60+
if (!Console.IsOutputRedirected)
61+
{
62+
Console.CursorLeft = 0;
63+
}
64+
}
65+
66+
public void on_epoch_end(int epoch, Dictionary<string, float> epoch_logs)
67+
{
68+
Console.WriteLine();
69+
}
70+
71+
void _reset_progbar()
72+
{
73+
seen = 0;
74+
}
75+
76+
void _maybe_init_progbar()
77+
{
78+
79+
}
80+
}
81+
}

src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ public void evaluate(NDArray x, NDArray y,
3131
bool use_multiprocessing = false,
3232
bool return_dict = false)
3333
{
34-
data_handler = new DataHandler(new DataHandlerArgs
34+
var data_handler = new DataHandler(new DataHandlerArgs
3535
{
3636
X = x,
3737
Y = y,
@@ -46,7 +46,6 @@ public void evaluate(NDArray x, NDArray y,
4646
StepsPerExecution = _steps_per_execution
4747
});
4848

49-
Binding.tf_output_redirect.WriteLine($"Testing...");
5049
foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
5150
{
5251
reset_metrics();
@@ -56,22 +55,20 @@ public void evaluate(NDArray x, NDArray y,
5655
foreach (var step in data_handler.steps())
5756
{
5857
// callbacks.on_train_batch_begin(step)
59-
results = test_function(iterator);
58+
results = test_function(data_handler, iterator);
6059
}
61-
Binding.tf_output_redirect.WriteLine($"iterator: {epoch + 1}, " + string.Join(", ", results.Select(x => $"{x.Item1}: {(float)x.Item2}")));
6260
}
6361
}
6462

6563
public KeyValuePair<string, float>[] evaluate(IDatasetV2 x)
6664
{
67-
data_handler = new DataHandler(new DataHandlerArgs
65+
var data_handler = new DataHandler(new DataHandlerArgs
6866
{
6967
Dataset = x,
7068
Model = this,
7169
StepsPerExecution = _steps_per_execution
7270
});
7371

74-
Binding.tf_output_redirect.WriteLine($"Testing...");
7572
IEnumerable<(string, Tensor)> logs = null;
7673
foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
7774
{
@@ -82,22 +79,21 @@ public KeyValuePair<string, float>[] evaluate(IDatasetV2 x)
8279
foreach (var step in data_handler.steps())
8380
{
8481
// callbacks.on_train_batch_begin(step)
85-
logs = test_function(iterator);
82+
logs = test_function(data_handler, iterator);
8683
}
87-
Binding.tf_output_redirect.WriteLine($"iterator: {epoch + 1}, " + string.Join(", ", logs.Select(x => $"{x.Item1}: {(float)x.Item2}")));
8884
}
8985
return logs.Select(x => new KeyValuePair<string, float>(x.Item1, (float)x.Item2)).ToArray();
9086
}
9187

92-
IEnumerable<(string, Tensor)> test_function(OwnedIterator iterator)
88+
IEnumerable<(string, Tensor)> test_function(DataHandler data_handler, OwnedIterator iterator)
9389
{
9490
var data = iterator.next();
95-
var outputs = test_step(data[0], data[1]);
91+
var outputs = test_step(data_handler, data[0], data[1]);
9692
tf_with(ops.control_dependencies(new object[0]), ctl => _test_counter.assign_add(1));
9793
return outputs;
9894
}
9995

100-
List<(string, Tensor)> test_step(Tensor x, Tensor y)
96+
List<(string, Tensor)> test_step(DataHandler data_handler, Tensor x, Tensor y)
10197
{
10298
(x, y) = data_handler.DataAdapter.Expand1d(x, y);
10399
var y_pred = Apply(x, training: false);

0 commit comments

Comments
 (0)