1
- using Tensorflow . NumPy ;
2
1
using System ;
3
2
using System . Collections . Generic ;
4
3
using System . Linq ;
4
+ using Tensorflow ;
5
5
using Tensorflow . Keras . ArgsDefinition ;
6
+ using Tensorflow . Keras . Callbacks ;
6
7
using Tensorflow . Keras . Engine . DataAdapters ;
7
- using static Tensorflow . Binding ;
8
8
using Tensorflow . Keras . Layers ;
9
9
using Tensorflow . Keras . Utils ;
10
- using Tensorflow ;
11
- using Tensorflow . Keras . Callbacks ;
10
+ using Tensorflow . NumPy ;
11
+ using static Tensorflow . Binding ;
12
12
13
13
namespace Tensorflow . Keras . Engine
14
14
{
@@ -27,7 +27,7 @@ public partial class Model
27
27
/// <param name="use_multiprocessing"></param>
28
28
/// <param name="return_dict"></param>
29
29
/// <param name="is_val"></param>
30
- public Dictionary < string , float > evaluate ( NDArray x , NDArray y ,
30
+ public Dictionary < string , float > evaluate ( Tensor x , Tensor y ,
31
31
int batch_size = - 1 ,
32
32
int verbose = 1 ,
33
33
int steps = - 1 ,
@@ -64,34 +64,11 @@ public Dictionary<string, float> evaluate(NDArray x, NDArray y,
64
64
Verbose = verbose ,
65
65
Steps = data_handler . Inferredsteps
66
66
} ) ;
67
- callbacks . on_test_begin ( ) ;
68
-
69
- //Dictionary<string, float>? logs = null;
70
- var logs = new Dictionary < string , float > ( ) ;
71
- foreach ( var ( epoch , iterator ) in data_handler . enumerate_epochs ( ) )
72
- {
73
- reset_metrics ( ) ;
74
- // data_handler.catch_stop_iteration();
75
-
76
- foreach ( var step in data_handler . steps ( ) )
77
- {
78
- callbacks . on_test_batch_begin ( step ) ;
79
- logs = test_function ( data_handler , iterator ) ;
80
- var end_step = step + data_handler . StepIncrement ;
81
- if ( is_val == false )
82
- callbacks . on_test_batch_end ( end_step , logs ) ;
83
- }
84
- }
85
67
86
- var results = new Dictionary < string , float > ( ) ;
87
- foreach ( var log in logs )
88
- {
89
- results [ log . Key ] = log . Value ;
90
- }
91
- return results ;
68
+ return evaluate ( data_handler , callbacks , is_val , test_function ) ;
92
69
}
93
70
94
- public Dictionary < string , float > evaluate ( IEnumerable < Tensor > x , NDArray y , int verbose = 1 , bool is_val = false )
71
+ public Dictionary < string , float > evaluate ( IEnumerable < Tensor > x , Tensor y , int verbose = 1 , bool is_val = false )
95
72
{
96
73
var data_handler = new DataHandler ( new DataHandlerArgs
97
74
{
@@ -107,34 +84,10 @@ public Dictionary<string, float> evaluate(IEnumerable<Tensor> x, NDArray y, int
107
84
Verbose = verbose ,
108
85
Steps = data_handler . Inferredsteps
109
86
} ) ;
110
- callbacks . on_test_begin ( ) ;
111
87
112
- Dictionary < string , float > logs = null ;
113
- foreach ( var ( epoch , iterator ) in data_handler . enumerate_epochs ( ) )
114
- {
115
- reset_metrics ( ) ;
116
- callbacks . on_epoch_begin ( epoch ) ;
117
- // data_handler.catch_stop_iteration();
118
-
119
- foreach ( var step in data_handler . steps ( ) )
120
- {
121
- callbacks . on_test_batch_begin ( step ) ;
122
- logs = test_step_multi_inputs_function ( data_handler , iterator ) ;
123
- var end_step = step + data_handler . StepIncrement ;
124
- if ( is_val == false )
125
- callbacks . on_test_batch_end ( end_step , logs ) ;
126
- }
127
- }
128
-
129
- var results = new Dictionary < string , float > ( ) ;
130
- foreach ( var log in logs )
131
- {
132
- results [ log . Key ] = log . Value ;
133
- }
134
- return results ;
88
+ return evaluate ( data_handler , callbacks , is_val , test_step_multi_inputs_function ) ;
135
89
}
136
90
137
-
138
91
public Dictionary < string , float > evaluate ( IDatasetV2 x , int verbose = 1 , bool is_val = false )
139
92
{
140
93
var data_handler = new DataHandler ( new DataHandlerArgs
@@ -150,9 +103,24 @@ public Dictionary<string, float> evaluate(IDatasetV2 x, int verbose = 1, bool is
150
103
Verbose = verbose ,
151
104
Steps = data_handler . Inferredsteps
152
105
} ) ;
106
+
107
+ return evaluate ( data_handler , callbacks , is_val , test_function ) ;
108
+ }
109
+
110
+ /// <summary>
111
+ /// Internal bare implementation of evaluate function.
112
+ /// </summary>
113
+ /// <param name="data_handler">Interations handling objects</param>
114
+ /// <param name="callbacks"></param>
115
+ /// <param name="test_func">The function to be called on each batch of data.</param>
116
+ /// <param name="is_val">Whether it is validation or test.</param>
117
+ /// <returns></returns>
118
+ Dictionary < string , float > evaluate ( DataHandler data_handler , CallbackList callbacks , bool is_val , Func < DataHandler , Tensor [ ] , Dictionary < string , float > > test_func )
119
+ {
153
120
callbacks . on_test_begin ( ) ;
154
121
155
- Dictionary < string , float > logs = null ;
122
+ var results = new Dictionary < string , float > ( ) ;
123
+ var logs = results ;
156
124
foreach ( var ( epoch , iterator ) in data_handler . enumerate_epochs ( ) )
157
125
{
158
126
reset_metrics ( ) ;
@@ -162,45 +130,47 @@ public Dictionary<string, float> evaluate(IDatasetV2 x, int verbose = 1, bool is
162
130
foreach ( var step in data_handler . steps ( ) )
163
131
{
164
132
callbacks . on_test_batch_begin ( step ) ;
165
- logs = test_function ( data_handler , iterator ) ;
133
+
134
+ logs = test_func ( data_handler , iterator . next ( ) ) ;
135
+
136
+ tf_with ( ops . control_dependencies ( Array . Empty < object > ( ) ) , ctl => _train_counter . assign_add ( 1 ) ) ;
137
+
166
138
var end_step = step + data_handler . StepIncrement ;
167
- if ( is_val == false )
139
+ if ( ! is_val )
168
140
callbacks . on_test_batch_end ( end_step , logs ) ;
169
141
}
142
+
143
+ if ( ! is_val )
144
+ callbacks . on_epoch_end ( epoch , logs ) ;
170
145
}
171
146
172
- var results = new Dictionary < string , float > ( ) ;
173
147
foreach ( var log in logs )
174
148
{
175
149
results [ log . Key ] = log . Value ;
176
150
}
151
+
177
152
return results ;
178
153
}
179
154
180
- Dictionary < string , float > test_function ( DataHandler data_handler , OwnedIterator iterator )
155
+ Dictionary < string , float > test_function ( DataHandler data_handler , Tensor [ ] data )
181
156
{
182
- var data = iterator . next ( ) ;
183
- var outputs = test_step ( data_handler , data [ 0 ] , data [ 1 ] ) ;
184
- tf_with ( ops . control_dependencies ( new object [ 0 ] ) , ctl => _test_counter . assign_add ( 1 ) ) ;
157
+ var ( x , y ) = data_handler . DataAdapter . Expand1d ( data [ 0 ] , data [ 1 ] ) ;
158
+
159
+ var y_pred = Apply ( x , training : false ) ;
160
+ var loss = compiled_loss . Call ( y , y_pred ) ;
161
+
162
+ compiled_metrics . update_state ( y , y_pred ) ;
163
+
164
+ var outputs = metrics . Select ( x => ( x . Name , x . result ( ) ) ) . ToDictionary ( x => x . Name , x => ( float ) x . Item2 ) ;
185
165
return outputs ;
186
166
}
187
- Dictionary < string , float > test_step_multi_inputs_function ( DataHandler data_handler , OwnedIterator iterator )
167
+
168
+ Dictionary < string , float > test_step_multi_inputs_function ( DataHandler data_handler , Tensor [ ] data )
188
169
{
189
- var data = iterator . next ( ) ;
190
170
var x_size = data_handler . DataAdapter . GetDataset ( ) . FirstInputTensorCount ;
191
171
var outputs = train_step ( data_handler , new Tensors ( data . Take ( x_size ) . ToArray ( ) ) , new Tensors ( data . Skip ( x_size ) . ToArray ( ) ) ) ;
192
172
tf_with ( ops . control_dependencies ( new object [ 0 ] ) , ctl => _train_counter . assign_add ( 1 ) ) ;
193
173
return outputs ;
194
174
}
195
- Dictionary < string , float > test_step ( DataHandler data_handler , Tensor x , Tensor y )
196
- {
197
- ( x , y ) = data_handler . DataAdapter . Expand1d ( x , y ) ;
198
- var y_pred = Apply ( x , training : false ) ;
199
- var loss = compiled_loss . Call ( y , y_pred ) ;
200
-
201
- compiled_metrics . update_state ( y , y_pred ) ;
202
-
203
- return metrics . Select ( x => ( x . Name , x . result ( ) ) ) . ToDictionary ( x=> x . Item1 , x=> ( float ) x . Item2 ) ;
204
- }
205
175
}
206
176
}
0 commit comments