Skip to content

Commit 95ee0e8

Browse files
committed
Merge branch 'master' of https://github.com/SciSharp/TensorFlow.NET into rnn-dev
2 parents 51b5f17 + 0454c7b commit 95ee0e8

File tree

8 files changed

+143
-93
lines changed

8 files changed

+143
-93
lines changed

src/TensorFlowNET.Core/APIs/tf.math.cs

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ You may obtain a copy of the License at
1414
limitations under the License.
1515
******************************************************************************/
1616

17+
using Tensorflow.NumPy;
1718
using Tensorflow.Operations;
1819

1920
namespace Tensorflow
@@ -42,7 +43,6 @@ public Tensor erf(Tensor x, string name = null)
4243

4344
public Tensor multiply(Tensor x, Tensor y, string name = null)
4445
=> math_ops.multiply(x, y, name: name);
45-
4646
public Tensor divide_no_nan(Tensor a, Tensor b, string name = null)
4747
=> math_ops.div_no_nan(a, b);
4848

@@ -452,7 +452,18 @@ public Tensor multiply(Tensor x, Tensor y, string name = null)
452452
/// <returns></returns>
453453
public Tensor multiply<Tx, Ty>(Tx x, Ty y, string name = null)
454454
=> gen_math_ops.mul(ops.convert_to_tensor(x), ops.convert_to_tensor(y), name: name);
455-
455+
/// <summary>
456+
/// return scalar product
457+
/// </summary>
458+
/// <typeparam name="Tx"></typeparam>
459+
/// <typeparam name="Ty"></typeparam>
460+
/// <param name="x"></param>
461+
/// <param name="y"></param>
462+
/// <param name="axes"></param>
463+
/// <param name="name"></param>
464+
/// <returns></returns>
465+
public Tensor dot_prod<Tx, Ty>(Tx x, Ty y, NDArray axes, string name = null)
466+
=> math_ops.tensordot(convert_to_tensor(x), convert_to_tensor(y), axes, name: name);
456467
public Tensor negative(Tensor x, string name = null)
457468
=> gen_math_ops.neg(x, name);
458469

src/TensorFlowNET.Core/Binding.Util.cs

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -486,7 +486,28 @@ public static Shape GetShape(this object data)
486486
throw new NotImplementedException("");
487487
}
488488
}
489-
489+
public static NDArray GetFlattenArray(NDArray x)
490+
{
491+
switch (x.GetDataType())
492+
{
493+
case TF_DataType.TF_FLOAT:
494+
x = x.ToArray<float>();
495+
break;
496+
case TF_DataType.TF_DOUBLE:
497+
x = x.ToArray<double>();
498+
break;
499+
case TF_DataType.TF_INT16:
500+
case TF_DataType.TF_INT32:
501+
x = x.ToArray<int>();
502+
break;
503+
case TF_DataType.TF_INT64:
504+
x = x.ToArray<long>();
505+
break;
506+
default:
507+
break;
508+
}
509+
return x;
510+
}
490511
public static TF_DataType GetDataType(this object data)
491512
{
492513
var type = data.GetType();

src/TensorFlowNET.Core/Keras/Engine/IModel.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ void load_weights(string filepath,
6060
bool skip_mismatch = false,
6161
object options = null);
6262

63-
Dictionary<string, float> evaluate(NDArray x, NDArray y,
63+
Dictionary<string, float> evaluate(Tensor x, Tensor y,
6464
int batch_size = -1,
6565
int verbose = 1,
6666
int steps = -1,

src/TensorFlowNET.Core/NumPy/Numpy.Math.cs

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,30 @@ public static NDArray prod(NDArray array, Axis? axis = null, Type? dtype = null,
4949
[AutoNumPy]
5050
public static NDArray prod<T>(params T[] array) where T : unmanaged
5151
=> new NDArray(tf.reduce_prod(new NDArray(array)));
52+
[AutoNumPy]
53+
public static NDArray dot(NDArray x1, NDArray x2, NDArray? axes = null, string? name = null)
54+
{
55+
//if axes mentioned
56+
if (axes != null)
57+
{
58+
return new NDArray(tf.dot_prod(x1, x2, axes, name));
59+
}
60+
if (x1.shape.ndim > 1)
61+
{
62+
x1 = GetFlattenArray(x1);
63+
}
64+
if (x2.shape.ndim > 1)
65+
{
66+
x2 = GetFlattenArray(x2);
67+
}
68+
//if axes not mentioned, default 0,0
69+
return new NDArray(tf.dot_prod(x1, x2, axes: new int[] { 0, 0 }, name));
5270

71+
}
5372
[AutoNumPy]
5473
public static NDArray power(NDArray x, NDArray y) => new NDArray(tf.pow(x, y));
74+
[AutoNumPy]
75+
public static NDArray square(NDArray x) => new NDArray(tf.square(x));
5576

5677
[AutoNumPy]
5778
public static NDArray sin(NDArray x) => new NDArray(math_ops.sin(x));

src/TensorFlowNET.Core/Tensors/Tensors.cs

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -226,62 +226,62 @@ public T[] ToArray<T>() where T: unmanaged
226226
}
227227

228228
#region Explicit Conversions
229-
public unsafe static explicit operator bool(Tensors tensor)
229+
public static explicit operator bool(Tensors tensor)
230230
{
231231
return (bool)tensor.Single;
232232
}
233233

234-
public unsafe static explicit operator sbyte(Tensors tensor)
234+
public static explicit operator sbyte(Tensors tensor)
235235
{
236236
return (sbyte)tensor.Single;
237237
}
238238

239-
public unsafe static explicit operator byte(Tensors tensor)
239+
public static explicit operator byte(Tensors tensor)
240240
{
241241
return (byte)tensor.Single;
242242
}
243243

244-
public unsafe static explicit operator ushort(Tensors tensor)
244+
public static explicit operator ushort(Tensors tensor)
245245
{
246246
return (ushort)tensor.Single;
247247
}
248248

249-
public unsafe static explicit operator short(Tensors tensor)
249+
public static explicit operator short(Tensors tensor)
250250
{
251251
return (short)tensor.Single;
252252
}
253253

254-
public unsafe static explicit operator int(Tensors tensor)
254+
public static explicit operator int(Tensors tensor)
255255
{
256256
return (int)tensor.Single;
257257
}
258258

259-
public unsafe static explicit operator uint(Tensors tensor)
259+
public static explicit operator uint(Tensors tensor)
260260
{
261261
return (uint)tensor.Single;
262262
}
263263

264-
public unsafe static explicit operator long(Tensors tensor)
264+
public static explicit operator long(Tensors tensor)
265265
{
266266
return (long)tensor.Single;
267267
}
268268

269-
public unsafe static explicit operator ulong(Tensors tensor)
269+
public static explicit operator ulong(Tensors tensor)
270270
{
271271
return (ulong)tensor.Single;
272272
}
273273

274-
public unsafe static explicit operator float(Tensors tensor)
274+
public static explicit operator float(Tensors tensor)
275275
{
276276
return (byte)tensor.Single;
277277
}
278278

279-
public unsafe static explicit operator double(Tensors tensor)
279+
public static explicit operator double(Tensors tensor)
280280
{
281281
return (double)tensor.Single;
282282
}
283283

284-
public unsafe static explicit operator string(Tensors tensor)
284+
public static explicit operator string(Tensors tensor)
285285
{
286286
return (string)tensor.Single;
287287
}
Lines changed: 45 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
1-
using Tensorflow.NumPy;
21
using System;
32
using System.Collections.Generic;
43
using System.Linq;
4+
using Tensorflow;
55
using Tensorflow.Keras.ArgsDefinition;
6+
using Tensorflow.Keras.Callbacks;
67
using Tensorflow.Keras.Engine.DataAdapters;
7-
using static Tensorflow.Binding;
88
using Tensorflow.Keras.Layers;
99
using Tensorflow.Keras.Utils;
10-
using Tensorflow;
11-
using Tensorflow.Keras.Callbacks;
10+
using Tensorflow.NumPy;
11+
using static Tensorflow.Binding;
1212

1313
namespace Tensorflow.Keras.Engine
1414
{
@@ -27,7 +27,7 @@ public partial class Model
2727
/// <param name="use_multiprocessing"></param>
2828
/// <param name="return_dict"></param>
2929
/// <param name="is_val"></param>
30-
public Dictionary<string, float> evaluate(NDArray x, NDArray y,
30+
public Dictionary<string, float> evaluate(Tensor x, Tensor y,
3131
int batch_size = -1,
3232
int verbose = 1,
3333
int steps = -1,
@@ -64,34 +64,11 @@ public Dictionary<string, float> evaluate(NDArray x, NDArray y,
6464
Verbose = verbose,
6565
Steps = data_handler.Inferredsteps
6666
});
67-
callbacks.on_test_begin();
68-
69-
//Dictionary<string, float>? logs = null;
70-
var logs = new Dictionary<string, float>();
71-
foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
72-
{
73-
reset_metrics();
74-
// data_handler.catch_stop_iteration();
75-
76-
foreach (var step in data_handler.steps())
77-
{
78-
callbacks.on_test_batch_begin(step);
79-
logs = test_function(data_handler, iterator);
80-
var end_step = step + data_handler.StepIncrement;
81-
if (is_val == false)
82-
callbacks.on_test_batch_end(end_step, logs);
83-
}
84-
}
8567

86-
var results = new Dictionary<string, float>();
87-
foreach (var log in logs)
88-
{
89-
results[log.Key] = log.Value;
90-
}
91-
return results;
68+
return evaluate(data_handler, callbacks, is_val, test_function);
9269
}
9370

94-
public Dictionary<string, float> evaluate(IEnumerable<Tensor> x, NDArray y, int verbose = 1, bool is_val = false)
71+
public Dictionary<string, float> evaluate(IEnumerable<Tensor> x, Tensor y, int verbose = 1, bool is_val = false)
9572
{
9673
var data_handler = new DataHandler(new DataHandlerArgs
9774
{
@@ -107,34 +84,10 @@ public Dictionary<string, float> evaluate(IEnumerable<Tensor> x, NDArray y, int
10784
Verbose = verbose,
10885
Steps = data_handler.Inferredsteps
10986
});
110-
callbacks.on_test_begin();
11187

112-
Dictionary<string, float> logs = null;
113-
foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
114-
{
115-
reset_metrics();
116-
callbacks.on_epoch_begin(epoch);
117-
// data_handler.catch_stop_iteration();
118-
119-
foreach (var step in data_handler.steps())
120-
{
121-
callbacks.on_test_batch_begin(step);
122-
logs = test_step_multi_inputs_function(data_handler, iterator);
123-
var end_step = step + data_handler.StepIncrement;
124-
if (is_val == false)
125-
callbacks.on_test_batch_end(end_step, logs);
126-
}
127-
}
128-
129-
var results = new Dictionary<string, float>();
130-
foreach (var log in logs)
131-
{
132-
results[log.Key] = log.Value;
133-
}
134-
return results;
88+
return evaluate(data_handler, callbacks, is_val, test_step_multi_inputs_function);
13589
}
13690

137-
13891
public Dictionary<string, float> evaluate(IDatasetV2 x, int verbose = 1, bool is_val = false)
13992
{
14093
var data_handler = new DataHandler(new DataHandlerArgs
@@ -150,9 +103,24 @@ public Dictionary<string, float> evaluate(IDatasetV2 x, int verbose = 1, bool is
150103
Verbose = verbose,
151104
Steps = data_handler.Inferredsteps
152105
});
106+
107+
return evaluate(data_handler, callbacks, is_val, test_function);
108+
}
109+
110+
/// <summary>
111+
/// Internal bare implementation of evaluate function.
112+
/// </summary>
113+
/// <param name="data_handler">Interations handling objects</param>
114+
/// <param name="callbacks"></param>
115+
/// <param name="test_func">The function to be called on each batch of data.</param>
116+
/// <param name="is_val">Whether it is validation or test.</param>
117+
/// <returns></returns>
118+
Dictionary<string, float> evaluate(DataHandler data_handler, CallbackList callbacks, bool is_val, Func<DataHandler, Tensor[], Dictionary<string, float>> test_func)
119+
{
153120
callbacks.on_test_begin();
154121

155-
Dictionary<string, float> logs = null;
122+
var results = new Dictionary<string, float>();
123+
var logs = results;
156124
foreach (var (epoch, iterator) in data_handler.enumerate_epochs())
157125
{
158126
reset_metrics();
@@ -162,45 +130,47 @@ public Dictionary<string, float> evaluate(IDatasetV2 x, int verbose = 1, bool is
162130
foreach (var step in data_handler.steps())
163131
{
164132
callbacks.on_test_batch_begin(step);
165-
logs = test_function(data_handler, iterator);
133+
134+
logs = test_func(data_handler, iterator.next());
135+
136+
tf_with(ops.control_dependencies(Array.Empty<object>()), ctl => _train_counter.assign_add(1));
137+
166138
var end_step = step + data_handler.StepIncrement;
167-
if (is_val == false)
139+
if (!is_val)
168140
callbacks.on_test_batch_end(end_step, logs);
169141
}
142+
143+
if (!is_val)
144+
callbacks.on_epoch_end(epoch, logs);
170145
}
171146

172-
var results = new Dictionary<string, float>();
173147
foreach (var log in logs)
174148
{
175149
results[log.Key] = log.Value;
176150
}
151+
177152
return results;
178153
}
179154

180-
Dictionary<string, float> test_function(DataHandler data_handler, OwnedIterator iterator)
155+
Dictionary<string, float> test_function(DataHandler data_handler, Tensor[] data)
181156
{
182-
var data = iterator.next();
183-
var outputs = test_step(data_handler, data[0], data[1]);
184-
tf_with(ops.control_dependencies(new object[0]), ctl => _test_counter.assign_add(1));
157+
var (x, y) = data_handler.DataAdapter.Expand1d(data[0], data[1]);
158+
159+
var y_pred = Apply(x, training: false);
160+
var loss = compiled_loss.Call(y, y_pred);
161+
162+
compiled_metrics.update_state(y, y_pred);
163+
164+
var outputs = metrics.Select(x => (x.Name, x.result())).ToDictionary(x => x.Name, x => (float)x.Item2);
185165
return outputs;
186166
}
187-
Dictionary<string, float> test_step_multi_inputs_function(DataHandler data_handler, OwnedIterator iterator)
167+
168+
Dictionary<string, float> test_step_multi_inputs_function(DataHandler data_handler, Tensor[] data)
188169
{
189-
var data = iterator.next();
190170
var x_size = data_handler.DataAdapter.GetDataset().FirstInputTensorCount;
191171
var outputs = train_step(data_handler, new Tensors(data.Take(x_size).ToArray()), new Tensors(data.Skip(x_size).ToArray()));
192172
tf_with(ops.control_dependencies(new object[0]), ctl => _train_counter.assign_add(1));
193173
return outputs;
194174
}
195-
Dictionary<string, float> test_step(DataHandler data_handler, Tensor x, Tensor y)
196-
{
197-
(x, y) = data_handler.DataAdapter.Expand1d(x, y);
198-
var y_pred = Apply(x, training: false);
199-
var loss = compiled_loss.Call(y, y_pred);
200-
201-
compiled_metrics.update_state(y, y_pred);
202-
203-
return metrics.Select(x => (x.Name, x.result())).ToDictionary(x=>x.Item1, x=>(float)x.Item2);
204-
}
205175
}
206176
}

src/TensorFlowNET.Keras/Engine/Model.Fit.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -266,7 +266,7 @@ History FitInternal(DataHandler data_handler, int epochs, int verbose, List<ICal
266266
{
267267
// Because evaluate calls call_test_batch_end, this interferes with our output on the screen
268268
// so we need to pass a is_val parameter to stop on_test_batch_end
269-
var val_logs = evaluate(validation_data.Value.Item1, validation_data.Value.Item2, is_val:true);
269+
var val_logs = evaluate((Tensor)validation_data.Value.Item1, validation_data.Value.Item2, is_val:true);
270270
foreach (var log in val_logs)
271271
{
272272
logs["val_" + log.Key] = log.Value;

0 commit comments

Comments
 (0)