1
- using Tensorflow . NumPy ;
2
1
using System ;
3
2
using System . Collections . Generic ;
4
3
using System . Linq ;
4
+ using Tensorflow ;
5
5
using Tensorflow . Keras . ArgsDefinition ;
6
+ using Tensorflow . Keras . Callbacks ;
6
7
using Tensorflow . Keras . Engine . DataAdapters ;
7
- using static Tensorflow . Binding ;
8
8
using Tensorflow . Keras . Layers ;
9
9
using Tensorflow . Keras . Utils ;
10
- using Tensorflow ;
11
- using Tensorflow . Keras . Callbacks ;
10
+ using Tensorflow . NumPy ;
11
+ using static Tensorflow . Binding ;
12
12
13
13
namespace Tensorflow . Keras . Engine
14
14
{
15
15
public partial class Model
16
16
{
17
- protected Dictionary < string , float > evaluate ( CallbackList callbacks , DataHandler data_handler , bool is_val )
18
- {
19
- callbacks . on_test_begin ( ) ;
20
-
21
- //Dictionary<string, float>? logs = null;
22
- var logs = new Dictionary < string , float > ( ) ;
23
- int x_size = data_handler . DataAdapter . GetDataset ( ) . FirstInputTensorCount ;
24
- foreach ( var ( epoch , iterator ) in data_handler . enumerate_epochs ( ) )
25
- {
26
- reset_metrics ( ) ;
27
- callbacks . on_epoch_begin ( epoch ) ;
28
- // data_handler.catch_stop_iteration();
29
-
30
- foreach ( var step in data_handler . steps ( ) )
31
- {
32
- callbacks . on_test_batch_begin ( step ) ;
33
-
34
- var data = iterator . next ( ) ;
35
-
36
- logs = train_step ( data_handler , new Tensors ( data . Take ( x_size ) ) , new Tensors ( data . Skip ( x_size ) ) ) ;
37
- tf_with ( ops . control_dependencies ( Array . Empty < object > ( ) ) , ctl => _test_counter . assign_add ( 1 ) ) ;
38
-
39
- var end_step = step + data_handler . StepIncrement ;
40
-
41
- if ( ! is_val )
42
- callbacks . on_test_batch_end ( end_step , logs ) ;
43
- }
44
- }
45
-
46
- return logs ;
47
- }
48
-
49
17
/// <summary>
50
18
/// Returns the loss value & metrics values for the model in test mode.
51
19
/// </summary>
@@ -97,7 +65,7 @@ public Dictionary<string, float> evaluate(Tensor x, Tensor y,
97
65
Steps = data_handler . Inferredsteps
98
66
} ) ;
99
67
100
- return evaluate ( callbacks , data_handler , is_val ) ;
68
+ return evaluate ( data_handler , callbacks , is_val , test_function ) ;
101
69
}
102
70
103
71
public Dictionary < string , float > evaluate ( IEnumerable < Tensor > x , Tensor y , int verbose = 1 , bool is_val = false )
@@ -117,10 +85,9 @@ public Dictionary<string, float> evaluate(IEnumerable<Tensor> x, Tensor y, int v
117
85
Steps = data_handler . Inferredsteps
118
86
} ) ;
119
87
120
- return evaluate ( callbacks , data_handler , is_val ) ;
88
+ return evaluate ( data_handler , callbacks , is_val , test_step_multi_inputs_function ) ;
121
89
}
122
90
123
-
124
91
public Dictionary < string , float > evaluate ( IDatasetV2 x , int verbose = 1 , bool is_val = false )
125
92
{
126
93
var data_handler = new DataHandler ( new DataHandlerArgs
@@ -137,7 +104,74 @@ public Dictionary<string, float> evaluate(IDatasetV2 x, int verbose = 1, bool is
137
104
Steps = data_handler . Inferredsteps
138
105
} ) ;
139
106
140
- return evaluate ( callbacks , data_handler , is_val ) ;
107
+ return evaluate ( data_handler , callbacks , is_val , test_function ) ;
108
+ }
109
+
110
+ /// <summary>
111
+ /// Internal bare implementation of evaluate function.
112
+ /// </summary>
113
+ /// <param name="data_handler">Interations handling objects</param>
114
+ /// <param name="callbacks"></param>
115
+ /// <param name="test_func">The function to be called on each batch of data.</param>
116
+ /// <param name="is_val">Whether it is validation or test.</param>
117
+ /// <returns></returns>
118
+ Dictionary < string , float > evaluate ( DataHandler data_handler , CallbackList callbacks , bool is_val , Func < DataHandler , Tensor [ ] , Dictionary < string , float > > test_func )
119
+ {
120
+ callbacks . on_test_begin ( ) ;
121
+
122
+ var results = new Dictionary < string , float > ( ) ;
123
+ var logs = results ;
124
+ foreach ( var ( epoch , iterator ) in data_handler . enumerate_epochs ( ) )
125
+ {
126
+ reset_metrics ( ) ;
127
+ callbacks . on_epoch_begin ( epoch ) ;
128
+ // data_handler.catch_stop_iteration();
129
+
130
+ foreach ( var step in data_handler . steps ( ) )
131
+ {
132
+ callbacks . on_test_batch_begin ( step ) ;
133
+
134
+ var data = iterator . next ( ) ;
135
+
136
+ logs = test_func ( data_handler , iterator . next ( ) ) ;
137
+
138
+ tf_with ( ops . control_dependencies ( Array . Empty < object > ( ) ) , ctl => _train_counter . assign_add ( 1 ) ) ;
139
+
140
+ var end_step = step + data_handler . StepIncrement ;
141
+ if ( ! is_val )
142
+ callbacks . on_test_batch_end ( end_step , logs ) ;
143
+ }
144
+
145
+ if ( ! is_val )
146
+ callbacks . on_epoch_end ( epoch , logs ) ;
147
+ }
148
+
149
+ foreach ( var log in logs )
150
+ {
151
+ results [ log . Key ] = log . Value ;
152
+ }
153
+
154
+ return results ;
155
+ }
156
+
157
+ Dictionary < string , float > test_function ( DataHandler data_handler , Tensor [ ] data )
158
+ {
159
+ var ( x , y ) = data_handler . DataAdapter . Expand1d ( data [ 0 ] , data [ 1 ] ) ;
160
+
161
+ var y_pred = Apply ( x , training : false ) ;
162
+ var loss = compiled_loss . Call ( y , y_pred ) ;
163
+
164
+ compiled_metrics . update_state ( y , y_pred ) ;
165
+
166
+ var outputs = metrics . Select ( x => ( x . Name , x . result ( ) ) ) . ToDictionary ( x => x . Name , x => ( float ) x . Item2 ) ;
167
+ return outputs ;
168
+ }
169
+
170
+ Dictionary < string , float > test_step_multi_inputs_function ( DataHandler data_handler , Tensor [ ] data )
171
+ {
172
+ var x_size = data_handler . DataAdapter . GetDataset ( ) . FirstInputTensorCount ;
173
+ var outputs = train_step ( data_handler , new Tensors ( data . Take ( x_size ) ) , new Tensors ( data . Skip ( x_size ) ) ) ;
174
+ return outputs ;
141
175
}
142
176
}
143
- }
177
+ }
0 commit comments