From 15e4336e20ecc040bd828ad3c5f47dc7b91ce5b0 Mon Sep 17 00:00:00 2001 From: Luc BOLOGNA Date: Mon, 29 May 2023 19:45:34 +0200 Subject: [PATCH] Update PredictInternational on Model.Predict.cs Fix issue if data_handler.steps() > 1 --- src/TensorFlowNET.Keras/Engine/Model.Predict.cs | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/TensorFlowNET.Keras/Engine/Model.Predict.cs b/src/TensorFlowNET.Keras/Engine/Model.Predict.cs index fc8d784ca..cbe4a7295 100644 --- a/src/TensorFlowNET.Keras/Engine/Model.Predict.cs +++ b/src/TensorFlowNET.Keras/Engine/Model.Predict.cs @@ -99,7 +99,8 @@ Tensors PredictInternal(DataHandler data_handler, int verbose) } else { - batch_outputs = tf.concat(new Tensor[] { batch_outputs, tmp_batch_outputs[0] }, axis: 0); + for (int i = 0; i < batch_outputs.Length; i++) + batch_outputs[i] = tf.concat(new Tensor[] { batch_outputs[i], tmp_batch_outputs[i] }, axis: 0); } var end_step = step + data_handler.StepIncrement; @@ -116,7 +117,7 @@ Tensors run_predict_step(OwnedIterator iterator) { var data = iterator.next(); var outputs = predict_step(data); - tf_with(ops.control_dependencies(new object[0]), ctl => _predict_counter.assign_add(1)); + tf_with(ops.control_dependencies(Array.Empty()), ctl => _predict_counter.assign_add(1)); return outputs; }