Skip to content

Commit 271dcef

Browse files
committed
fix keras model predict return result.
1 parent ec340ee commit 271dcef

File tree

6 files changed

+128
-3
lines changed

6 files changed

+128
-3
lines changed

src/TensorFlowNET.Core/Sessions/BaseSession.cs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -293,7 +293,8 @@ protected override void DisposeUnmanagedResources(IntPtr handle)
293293
// c_api.TF_CloseSession(handle, tf.Status.Handle);
294294
if (tf.Status == null || tf.Status.Handle.IsInvalid)
295295
{
296-
c_api.TF_DeleteSession(handle, c_api.TF_NewStatus());
296+
using var status = new Status();
297+
c_api.TF_DeleteSession(handle, status.Handle);
297298
}
298299
else
299300
{

src/TensorFlowNET.Keras/Callbacks/CallbackList.cs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,5 +39,25 @@ public void on_epoch_end(int epoch, Dictionary<string, float> epoch_logs)
3939
{
4040
callbacks.ForEach(x => x.on_epoch_end(epoch, epoch_logs));
4141
}
42+
43+
public void on_predict_begin()
44+
{
45+
callbacks.ForEach(x => x.on_predict_begin());
46+
}
47+
48+
public void on_predict_batch_begin(long step)
49+
{
50+
callbacks.ForEach(x => x.on_predict_batch_begin(step));
51+
}
52+
53+
public void on_predict_batch_end(long end_step, Dictionary<string, Tensors> logs)
54+
{
55+
callbacks.ForEach(x => x.on_predict_batch_end(end_step, logs));
56+
}
57+
58+
public void on_predict_end()
59+
{
60+
callbacks.ForEach(x => x.on_predict_end());
61+
}
4262
}
4363
}

src/TensorFlowNET.Keras/Callbacks/History.cs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,5 +48,26 @@ public void on_epoch_end(int epoch, Dictionary<string, float> epoch_logs)
4848
history[log.Key].Add((float)log.Value);
4949
}
5050
}
51+
52+
public void on_predict_begin()
53+
{
54+
epochs = new List<int>();
55+
history = new Dictionary<string, List<float>>();
56+
}
57+
58+
public void on_predict_batch_begin(long step)
59+
{
60+
61+
}
62+
63+
public void on_predict_batch_end(long end_step, Dictionary<string, Tensors> logs)
64+
{
65+
66+
}
67+
68+
public void on_predict_end()
69+
{
70+
71+
}
5172
}
5273
}

src/TensorFlowNET.Keras/Callbacks/ICallback.cs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,9 @@ public interface ICallback
1111
void on_train_batch_begin(long step);
1212
void on_train_batch_end(long end_step, Dictionary<string, float> logs);
1313
void on_epoch_end(int epoch, Dictionary<string, float> epoch_logs);
14+
void on_predict_begin();
15+
void on_predict_batch_begin(long step);
16+
void on_predict_batch_end(long end_step, Dictionary<string, Tensors> logs);
17+
void on_predict_end();
1418
}
1519
}

src/TensorFlowNET.Keras/Callbacks/ProgbarLogger.cs

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
using PureHDF;
2-
using System;
1+
using System;
32
using System.Collections.Generic;
43
using System.Diagnostics;
54
using System.Linq;
@@ -77,5 +76,26 @@ void _maybe_init_progbar()
7776
{
7877

7978
}
79+
80+
public void on_predict_begin()
81+
{
82+
_reset_progbar();
83+
_maybe_init_progbar();
84+
}
85+
86+
public void on_predict_batch_begin(long step)
87+
{
88+
89+
}
90+
91+
public void on_predict_batch_end(long end_step, Dictionary<string, Tensors> logs)
92+
{
93+
94+
}
95+
96+
public void on_predict_end()
97+
{
98+
99+
}
80100
}
81101
}

src/TensorFlowNET.Keras/Engine/Model.Predict.cs

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,70 @@
55
using Tensorflow.Keras.ArgsDefinition;
66
using Tensorflow.Keras.Engine.DataAdapters;
77
using static Tensorflow.Binding;
8+
using Tensorflow.Keras.Callbacks;
89

910
namespace Tensorflow.Keras.Engine
1011
{
1112
public partial class Model
1213
{
14+
public Tensors predict(IDatasetV2 dataset,
15+
int batch_size = -1,
16+
int verbose = 0,
17+
int steps = -1,
18+
int max_queue_size = 10,
19+
int workers = 1,
20+
bool use_multiprocessing = false)
21+
{
22+
var data_handler = new DataHandler(new DataHandlerArgs
23+
{
24+
Dataset = dataset,
25+
BatchSize = batch_size,
26+
StepsPerEpoch = steps,
27+
InitialEpoch = 0,
28+
Epochs = 1,
29+
MaxQueueSize = max_queue_size,
30+
Workers = workers,
31+
UseMultiprocessing = use_multiprocessing,
32+
Model = this,
33+
StepsPerExecution = _steps_per_execution
34+
});
35+
36+
var callbacks = new CallbackList(new CallbackParams
37+
{
38+
Model = this,
39+
Verbose = verbose,
40+
Epochs = 1,
41+
Steps = data_handler.Inferredsteps
42+
});
43+
44+
Tensor batch_outputs = null;
45+
_predict_counter.assign(0);
46+
callbacks.on_predict_begin();
47+
foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
48+
{
49+
foreach (var step in data_handler.steps())
50+
{
51+
callbacks.on_predict_batch_begin(step);
52+
var tmp_batch_outputs = run_predict_step(iterator);
53+
if (batch_outputs == null)
54+
{
55+
batch_outputs = tmp_batch_outputs[0];
56+
}
57+
else
58+
{
59+
batch_outputs = tf.concat(new Tensor[] { batch_outputs, tmp_batch_outputs[0] }, axis: 0);
60+
}
61+
62+
var end_step = step + data_handler.StepIncrement;
63+
callbacks.on_predict_batch_end(end_step, new Dictionary<string, Tensors> { { "outputs", batch_outputs } });
64+
}
65+
GC.Collect();
66+
}
67+
68+
callbacks.on_predict_end();
69+
return batch_outputs;
70+
}
71+
1372
/// <summary>
1473
/// Generates output predictions for the input samples.
1574
/// </summary>

0 commit comments

Comments
 (0)