Skip to content

Commit 0effee4

Browse files
DevNullx64Oceania2018
authored andcommitted
Update Model.Evaluate.cs
Fix my bad: Bad handling between test_function and test_step_multi_inputs_function.
1 parent f7208c9 commit 0effee4

File tree

1 file changed

+75
-41
lines changed

1 file changed

+75
-41
lines changed

src/TensorFlowNET.Keras/Engine/Model.Evaluate.cs

Lines changed: 75 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,51 +1,19 @@
1-
using Tensorflow.NumPy;
21
using System;
32
using System.Collections.Generic;
43
using System.Linq;
4+
using Tensorflow;
55
using Tensorflow.Keras.ArgsDefinition;
6+
using Tensorflow.Keras.Callbacks;
67
using Tensorflow.Keras.Engine.DataAdapters;
7-
using static Tensorflow.Binding;
88
using Tensorflow.Keras.Layers;
99
using Tensorflow.Keras.Utils;
10-
using Tensorflow;
11-
using Tensorflow.Keras.Callbacks;
10+
using Tensorflow.NumPy;
11+
using static Tensorflow.Binding;
1212

1313
namespace Tensorflow.Keras.Engine
1414
{
1515
public partial class Model
1616
{
17-
protected Dictionary<string, float> evaluate(CallbackList callbacks, DataHandler data_handler, bool is_val)
18-
{
19-
callbacks.on_test_begin();
20-
21-
//Dictionary<string, float>? logs = null;
22-
var logs = new Dictionary<string, float>();
23-
int x_size = data_handler.DataAdapter.GetDataset().FirstInputTensorCount;
24-
foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
25-
{
26-
reset_metrics();
27-
callbacks.on_epoch_begin(epoch);
28-
// data_handler.catch_stop_iteration();
29-
30-
foreach (var step in data_handler.steps())
31-
{
32-
callbacks.on_test_batch_begin(step);
33-
34-
var data = iterator.next();
35-
36-
logs = train_step(data_handler, new Tensors(data.Take(x_size)), new Tensors(data.Skip(x_size)));
37-
tf_with(ops.control_dependencies(Array.Empty<object>()), ctl => _test_counter.assign_add(1));
38-
39-
var end_step = step + data_handler.StepIncrement;
40-
41-
if (!is_val)
42-
callbacks.on_test_batch_end(end_step, logs);
43-
}
44-
}
45-
46-
return logs;
47-
}
48-
4917
/// <summary>
5018
/// Returns the loss value & metrics values for the model in test mode.
5119
/// </summary>
@@ -97,7 +65,7 @@ public Dictionary<string, float> evaluate(Tensor x, Tensor y,
9765
Steps = data_handler.Inferredsteps
9866
});
9967

100-
return evaluate(callbacks, data_handler, is_val);
68+
return evaluate(data_handler, callbacks, is_val, test_function);
10169
}
10270

10371
public Dictionary<string, float> evaluate(IEnumerable<Tensor> x, Tensor y, int verbose = 1, bool is_val = false)
@@ -117,10 +85,9 @@ public Dictionary<string, float> evaluate(IEnumerable<Tensor> x, Tensor y, int v
11785
Steps = data_handler.Inferredsteps
11886
});
11987

120-
return evaluate(callbacks, data_handler, is_val);
88+
return evaluate(data_handler, callbacks, is_val, test_step_multi_inputs_function);
12189
}
12290

123-
12491
public Dictionary<string, float> evaluate(IDatasetV2 x, int verbose = 1, bool is_val = false)
12592
{
12693
var data_handler = new DataHandler(new DataHandlerArgs
@@ -137,7 +104,74 @@ public Dictionary<string, float> evaluate(IDatasetV2 x, int verbose = 1, bool is
137104
Steps = data_handler.Inferredsteps
138105
});
139106

140-
return evaluate(callbacks, data_handler, is_val);
107+
return evaluate(data_handler, callbacks, is_val, test_function);
108+
}
109+
110+
/// <summary>
111+
/// Internal bare implementation of evaluate function.
112+
/// </summary>
113+
/// <param name="data_handler">Interations handling objects</param>
114+
/// <param name="callbacks"></param>
115+
/// <param name="test_func">The function to be called on each batch of data.</param>
116+
/// <param name="is_val">Whether it is validation or test.</param>
117+
/// <returns></returns>
118+
Dictionary<string, float> evaluate(DataHandler data_handler, CallbackList callbacks, bool is_val, Func<DataHandler, Tensor[], Dictionary<string, float>> test_func)
119+
{
120+
callbacks.on_test_begin();
121+
122+
var results = new Dictionary<string, float>();
123+
var logs = results;
124+
foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
125+
{
126+
reset_metrics();
127+
callbacks.on_epoch_begin(epoch);
128+
// data_handler.catch_stop_iteration();
129+
130+
foreach (var step in data_handler.steps())
131+
{
132+
callbacks.on_test_batch_begin(step);
133+
134+
var data = iterator.next();
135+
136+
logs = test_func(data_handler, iterator.next());
137+
138+
tf_with(ops.control_dependencies(Array.Empty<object>()), ctl => _train_counter.assign_add(1));
139+
140+
var end_step = step + data_handler.StepIncrement;
141+
if (!is_val)
142+
callbacks.on_test_batch_end(end_step, logs);
143+
}
144+
145+
if (!is_val)
146+
callbacks.on_epoch_end(epoch, logs);
147+
}
148+
149+
foreach (var log in logs)
150+
{
151+
results[log.Key] = log.Value;
152+
}
153+
154+
return results;
155+
}
156+
157+
Dictionary<string, float> test_function(DataHandler data_handler, Tensor[] data)
158+
{
159+
var (x, y) = data_handler.DataAdapter.Expand1d(data[0], data[1]);
160+
161+
var y_pred = Apply(x, training: false);
162+
var loss = compiled_loss.Call(y, y_pred);
163+
164+
compiled_metrics.update_state(y, y_pred);
165+
166+
var outputs = metrics.Select(x => (x.Name, x.result())).ToDictionary(x => x.Name, x => (float)x.Item2);
167+
return outputs;
168+
}
169+
170+
Dictionary<string, float> test_step_multi_inputs_function(DataHandler data_handler, Tensor[] data)
171+
{
172+
var x_size = data_handler.DataAdapter.GetDataset().FirstInputTensorCount;
173+
var outputs = train_step(data_handler, new Tensors(data.Take(x_size)), new Tensors(data.Skip(x_size)));
174+
return outputs;
141175
}
142176
}
143-
}
177+
}

0 commit comments

Comments
 (0)