Skip to content

Commit f7208c9

Browse files
DevNullx64Oceania2018
authored andcommitted
Refactor: Model.Evaluate.cs
1 parent 02cb239 commit f7208c9

File tree

1 file changed

+36
-93
lines changed

1 file changed

+36
-93
lines changed

src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs

Lines changed: 36 additions & 93 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,38 @@ namespace Tensorflow.Keras.Engine
1414
{
1515
public partial class Model
1616
{
17+
protected Dictionary<string, float> evaluate(CallbackList callbacks, DataHandler data_handler, bool is_val)
18+
{
19+
callbacks.on_test_begin();
20+
21+
//Dictionary<string, float>? logs = null;
22+
var logs = new Dictionary<string, float>();
23+
int x_size = data_handler.DataAdapter.GetDataset().FirstInputTensorCount;
24+
foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
25+
{
26+
reset_metrics();
27+
callbacks.on_epoch_begin(epoch);
28+
// data_handler.catch_stop_iteration();
29+
30+
foreach (var step in data_handler.steps())
31+
{
32+
callbacks.on_test_batch_begin(step);
33+
34+
var data = iterator.next();
35+
36+
logs = train_step(data_handler, new Tensors(data.Take(x_size)), new Tensors(data.Skip(x_size)));
37+
tf_with(ops.control_dependencies(Array.Empty<object>()), ctl => _test_counter.assign_add(1));
38+
39+
var end_step = step + data_handler.StepIncrement;
40+
41+
if (!is_val)
42+
callbacks.on_test_batch_end(end_step, logs);
43+
}
44+
}
45+
46+
return logs;
47+
}
48+
1749
/// <summary>
1850
/// Returns the loss value & metrics values for the model in test mode.
1951
/// </summary>
@@ -64,31 +96,8 @@ public Dictionary<string, float> evaluate(Tensor x, Tensor y,
6496
Verbose = verbose,
6597
Steps = data_handler.Inferredsteps
6698
});
67-
callbacks.on_test_begin();
68-
69-
//Dictionary<string, float>? logs = null;
70-
var logs = new Dictionary<string, float>();
71-
foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
72-
{
73-
reset_metrics();
74-
// data_handler.catch_stop_iteration();
7599

76-
foreach (var step in data_handler.steps())
77-
{
78-
callbacks.on_test_batch_begin(step);
79-
logs = test_function(data_handler, iterator);
80-
var end_step = step + data_handler.StepIncrement;
81-
if (is_val == false)
82-
callbacks.on_test_batch_end(end_step, logs);
83-
}
84-
}
85-
86-
var results = new Dictionary<string, float>();
87-
foreach (var log in logs)
88-
{
89-
results[log.Key] = log.Value;
90-
}
91-
return results;
100+
return evaluate(callbacks, data_handler, is_val);
92101
}
93102

94103
public Dictionary<string, float> evaluate(IEnumerable<Tensor> x, Tensor y, int verbose = 1, bool is_val = false)
@@ -107,31 +116,8 @@ public Dictionary<string, float> evaluate(IEnumerable<Tensor> x, Tensor y, int v
107116
Verbose = verbose,
108117
Steps = data_handler.Inferredsteps
109118
});
110-
callbacks.on_test_begin();
111119

112-
Dictionary<string, float> logs = null;
113-
foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
114-
{
115-
reset_metrics();
116-
callbacks.on_epoch_begin(epoch);
117-
// data_handler.catch_stop_iteration();
118-
119-
foreach (var step in data_handler.steps())
120-
{
121-
callbacks.on_test_batch_begin(step);
122-
logs = test_function(data_handler, iterator);
123-
var end_step = step + data_handler.StepIncrement;
124-
if (is_val == false)
125-
callbacks.on_test_batch_end(end_step, logs);
126-
}
127-
}
128-
129-
var results = new Dictionary<string, float>();
130-
foreach (var log in logs)
131-
{
132-
results[log.Key] = log.Value;
133-
}
134-
return results;
120+
return evaluate(callbacks, data_handler, is_val);
135121
}
136122

137123

@@ -150,51 +136,8 @@ public Dictionary<string, float> evaluate(IDatasetV2 x, int verbose = 1, bool is
150136
Verbose = verbose,
151137
Steps = data_handler.Inferredsteps
152138
});
153-
callbacks.on_test_begin();
154-
155-
Dictionary<string, float> logs = null;
156-
foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
157-
{
158-
reset_metrics();
159-
callbacks.on_epoch_begin(epoch);
160-
// data_handler.catch_stop_iteration();
161-
162-
foreach (var step in data_handler.steps())
163-
{
164-
callbacks.on_test_batch_begin(step);
165-
logs = test_function(data_handler, iterator);
166-
var end_step = step + data_handler.StepIncrement;
167-
if (is_val == false)
168-
callbacks.on_test_batch_end(end_step, logs);
169-
}
170-
}
171-
172-
var results = new Dictionary<string, float>();
173-
foreach (var log in logs)
174-
{
175-
results[log.Key] = log.Value;
176-
}
177-
return results;
178-
}
179-
180-
Dictionary<string, float> test_function(DataHandler data_handler, OwnedIterator iterator)
181-
{
182-
var data = iterator.next();
183-
var x_size = data_handler.DataAdapter.GetDataset().FirstInputTensorCount;
184-
var outputs = train_step(data_handler, new Tensors(data.Take(x_size)), new Tensors(data.Skip(x_size)));
185-
tf_with(ops.control_dependencies(new object[0]), ctl => _test_counter.assign_add(1));
186-
return outputs;
187-
}
188-
189-
Dictionary<string, float> test_step(DataHandler data_handler, Tensor x, Tensor y)
190-
{
191-
(x, y) = data_handler.DataAdapter.Expand1d(x, y);
192-
var y_pred = Apply(x, training: false);
193-
var loss = compiled_loss.Call(y, y_pred);
194-
195-
compiled_metrics.update_state(y, y_pred);
196139

197-
return metrics.Select(x => (x.Name, x.result())).ToDictionary(x=>x.Item1, x=>(float)x.Item2);
140+
return evaluate(callbacks, data_handler, is_val);
198141
}
199142
}
200-
}
143+
}

0 commit comments

Comments
 (0)