@@ -14,6 +14,38 @@ namespace Tensorflow.Keras.Engine
14
14
{
15
15
public partial class Model
16
16
{
17
+ protected Dictionary < string , float > evaluate ( CallbackList callbacks , DataHandler data_handler , bool is_val )
18
+ {
19
+ callbacks . on_test_begin ( ) ;
20
+
21
+ //Dictionary<string, float>? logs = null;
22
+ var logs = new Dictionary < string , float > ( ) ;
23
+ int x_size = data_handler . DataAdapter . GetDataset ( ) . FirstInputTensorCount ;
24
+ foreach ( var ( epoch , iterator ) in data_handler . enumerate_epochs ( ) )
25
+ {
26
+ reset_metrics ( ) ;
27
+ callbacks . on_epoch_begin ( epoch ) ;
28
+ // data_handler.catch_stop_iteration();
29
+
30
+ foreach ( var step in data_handler . steps ( ) )
31
+ {
32
+ callbacks . on_test_batch_begin ( step ) ;
33
+
34
+ var data = iterator . next ( ) ;
35
+
36
+ logs = train_step ( data_handler , new Tensors ( data . Take ( x_size ) ) , new Tensors ( data . Skip ( x_size ) ) ) ;
37
+ tf_with ( ops . control_dependencies ( Array . Empty < object > ( ) ) , ctl => _test_counter . assign_add ( 1 ) ) ;
38
+
39
+ var end_step = step + data_handler . StepIncrement ;
40
+
41
+ if ( ! is_val )
42
+ callbacks . on_test_batch_end ( end_step , logs ) ;
43
+ }
44
+ }
45
+
46
+ return logs ;
47
+ }
48
+
17
49
/// <summary>
18
50
/// Returns the loss value & metrics values for the model in test mode.
19
51
/// </summary>
@@ -64,31 +96,8 @@ public Dictionary<string, float> evaluate(Tensor x, Tensor y,
64
96
Verbose = verbose ,
65
97
Steps = data_handler . Inferredsteps
66
98
} ) ;
67
- callbacks . on_test_begin ( ) ;
68
-
69
- //Dictionary<string, float>? logs = null;
70
- var logs = new Dictionary < string , float > ( ) ;
71
- foreach ( var ( epoch , iterator ) in data_handler . enumerate_epochs ( ) )
72
- {
73
- reset_metrics ( ) ;
74
- // data_handler.catch_stop_iteration();
75
99
76
- foreach ( var step in data_handler . steps ( ) )
77
- {
78
- callbacks . on_test_batch_begin ( step ) ;
79
- logs = test_function ( data_handler , iterator ) ;
80
- var end_step = step + data_handler . StepIncrement ;
81
- if ( is_val == false )
82
- callbacks . on_test_batch_end ( end_step , logs ) ;
83
- }
84
- }
85
-
86
- var results = new Dictionary < string , float > ( ) ;
87
- foreach ( var log in logs )
88
- {
89
- results [ log . Key ] = log . Value ;
90
- }
91
- return results ;
100
+ return evaluate ( callbacks , data_handler , is_val ) ;
92
101
}
93
102
94
103
public Dictionary < string , float > evaluate ( IEnumerable < Tensor > x , Tensor y , int verbose = 1 , bool is_val = false )
@@ -107,31 +116,8 @@ public Dictionary<string, float> evaluate(IEnumerable<Tensor> x, Tensor y, int v
107
116
Verbose = verbose ,
108
117
Steps = data_handler . Inferredsteps
109
118
} ) ;
110
- callbacks . on_test_begin ( ) ;
111
119
112
- Dictionary < string , float > logs = null ;
113
- foreach ( var ( epoch , iterator ) in data_handler . enumerate_epochs ( ) )
114
- {
115
- reset_metrics ( ) ;
116
- callbacks . on_epoch_begin ( epoch ) ;
117
- // data_handler.catch_stop_iteration();
118
-
119
- foreach ( var step in data_handler . steps ( ) )
120
- {
121
- callbacks . on_test_batch_begin ( step ) ;
122
- logs = test_function ( data_handler , iterator ) ;
123
- var end_step = step + data_handler . StepIncrement ;
124
- if ( is_val == false )
125
- callbacks . on_test_batch_end ( end_step , logs ) ;
126
- }
127
- }
128
-
129
- var results = new Dictionary < string , float > ( ) ;
130
- foreach ( var log in logs )
131
- {
132
- results [ log . Key ] = log . Value ;
133
- }
134
- return results ;
120
+ return evaluate ( callbacks , data_handler , is_val ) ;
135
121
}
136
122
137
123
@@ -150,51 +136,8 @@ public Dictionary<string, float> evaluate(IDatasetV2 x, int verbose = 1, bool is
150
136
Verbose = verbose ,
151
137
Steps = data_handler . Inferredsteps
152
138
} ) ;
153
- callbacks . on_test_begin ( ) ;
154
-
155
- Dictionary < string , float > logs = null ;
156
- foreach ( var ( epoch , iterator ) in data_handler . enumerate_epochs ( ) )
157
- {
158
- reset_metrics ( ) ;
159
- callbacks . on_epoch_begin ( epoch ) ;
160
- // data_handler.catch_stop_iteration();
161
-
162
- foreach ( var step in data_handler . steps ( ) )
163
- {
164
- callbacks . on_test_batch_begin ( step ) ;
165
- logs = test_function ( data_handler , iterator ) ;
166
- var end_step = step + data_handler . StepIncrement ;
167
- if ( is_val == false )
168
- callbacks . on_test_batch_end ( end_step , logs ) ;
169
- }
170
- }
171
-
172
- var results = new Dictionary < string , float > ( ) ;
173
- foreach ( var log in logs )
174
- {
175
- results [ log . Key ] = log . Value ;
176
- }
177
- return results ;
178
- }
179
-
180
- Dictionary < string , float > test_function ( DataHandler data_handler , OwnedIterator iterator )
181
- {
182
- var data = iterator . next ( ) ;
183
- var x_size = data_handler . DataAdapter . GetDataset ( ) . FirstInputTensorCount ;
184
- var outputs = train_step ( data_handler , new Tensors ( data . Take ( x_size ) ) , new Tensors ( data . Skip ( x_size ) ) ) ;
185
- tf_with ( ops . control_dependencies ( new object [ 0 ] ) , ctl => _test_counter . assign_add ( 1 ) ) ;
186
- return outputs ;
187
- }
188
-
189
- Dictionary < string , float > test_step ( DataHandler data_handler , Tensor x , Tensor y )
190
- {
191
- ( x , y ) = data_handler . DataAdapter . Expand1d ( x , y ) ;
192
- var y_pred = Apply ( x , training : false ) ;
193
- var loss = compiled_loss . Call ( y , y_pred ) ;
194
-
195
- compiled_metrics . update_state ( y , y_pred ) ;
196
139
197
- return metrics . Select ( x => ( x . Name , x . result ( ) ) ) . ToDictionary ( x => x . Item1 , x => ( float ) x . Item2 ) ;
140
+ return evaluate ( callbacks , data_handler , is_val ) ;
198
141
}
199
142
}
200
- }
143
+ }
0 commit comments