Skip to content

Commit 369d8ab

Browse files
committed
_ShapesFullySpecifiedAndEqual
1 parent b95601f commit 369d8ab

File tree

5 files changed

+88
-10
lines changed

5 files changed

+88
-10
lines changed

src/TensorFlowNET.Core/Gradients/math_grad.cs

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,10 @@ public static Tensor[] _AddGrad(Operation op, Tensor[] grads)
2222
var sy = array_ops.shape(y);
2323
var (rx, ry) = gen_array_ops.broadcast_gradient_args(sx, sy);
2424

25-
var r1 = gen_array_ops.reshape(math_ops.reduce_sum(grad, rx), sx);
26-
var r2 = gen_array_ops.reshape(math_ops.reduce_sum(grad, ry), sy);
25+
var sum1 = math_ops.reduce_sum(grad, rx);
26+
var r1 = gen_array_ops.reshape(sum1, sx);
27+
var sum2 = math_ops.reduce_sum(grad, ry);
28+
var r2 = gen_array_ops.reshape(sum2, sy);
2729

2830
return new Tensor[] { r1, r2 };
2931
}
@@ -48,7 +50,8 @@ public static Tensor[] _MulGrad(Operation op, Tensor[] grads)
4850
var x = op.inputs[0];
4951
var y = op.inputs[1];
5052
var grad = grads[0];
51-
if (grad is Tensor && _ShapesFullySpecifiedAndEqual(x, y, grad) &&
53+
if (grad is Tensor &&
54+
_ShapesFullySpecifiedAndEqual(x, y, grad) &&
5255
new TF_DataType[] { tf.int32, tf.float32 }.Contains(grad.dtype))
5356
return new Tensor[] { gen_math_ops.mul(grad, y), gen_math_ops.mul(grad, x) };
5457

@@ -60,10 +63,11 @@ public static Tensor[] _MulGrad(Operation op, Tensor[] grads)
6063
y = math_ops.conj(y);
6164

6265
var mul1 = gen_math_ops.mul(grad, y);
63-
var mul2 = gen_math_ops.mul(x, grad);
6466
var reduce_sum1 = math_ops.reduce_sum(mul1, rx);
65-
var reduce_sum2 = math_ops.reduce_sum(mul2, ry);
6667
var reshape1 = gen_array_ops.reshape(reduce_sum1, sx);
68+
69+
var mul2 = gen_math_ops.mul(x, grad);
70+
var reduce_sum2 = math_ops.reduce_sum(mul2, ry);
6771
var reshape2 = gen_array_ops.reshape(reduce_sum2, sy);
6872

6973
return new Tensor[] { reshape1, reshape2 };
@@ -146,7 +150,13 @@ public static Tensor[] _SubGrad(Operation op, Tensor[] grads)
146150

147151
public static bool _ShapesFullySpecifiedAndEqual(Tensor x, Tensor y, Tensor grad)
148152
{
149-
return x.NDims == y.NDims && y.NDims == grad.NDims && x.NDims > -1;
153+
var x_shape = x._shape_tuple();
154+
var y_shape = y._shape_tuple();
155+
var grad_shape = grad._shape_tuple();
156+
return Enumerable.SequenceEqual(x_shape, y_shape) &&
157+
Enumerable.SequenceEqual(y_shape, grad_shape) &&
158+
x.NDims != -1 &&
159+
!x_shape.Contains(-1);
150160
}
151161

152162
public static Tensor[] _SumGrad(Operation op, Tensor[] grads)

src/TensorFlowNET.Core/Python.cs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
using System;
33
using System.Collections.Generic;
44
using System.ComponentModel;
5+
using System.Linq;
56
using System.Text;
67

78
namespace Tensorflow
@@ -16,6 +17,11 @@ protected void print(object obj)
1617
Console.WriteLine(obj.ToString());
1718
}
1819

20+
protected IEnumerable<int> range(int end)
21+
{
22+
return Enumerable.Range(0, end);
23+
}
24+
1925
public static T New<T>(object args) where T : IPyClass
2026
{
2127
var instance = Activator.CreateInstance<T>();

src/TensorFlowNET.Core/Tensors/Tensor.cs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ public partial class Tensor : Python, IDisposable, ITensorOrOperation
4343
public IntPtr buffer => _handle == IntPtr.Zero ? IntPtr.Zero : c_api.TF_TensorData(_handle);
4444
public int num_consumers(TF_Output oper_out) => _handle == IntPtr.Zero ? 0 : c_api.TF_OperationOutputNumConsumers(oper_out);
4545

46+
private TF_Output? _tf_output;
47+
4648
public long[] shape
4749
{
4850
get
@@ -123,7 +125,10 @@ public Operation[] consumers()
123125

124126
public TF_Output _as_tf_output()
125127
{
126-
return new TF_Output(op, value_index);
128+
if(!_tf_output.HasValue)
129+
_tf_output = new TF_Output(op, value_index);
130+
131+
return _tf_output.Value;
127132
}
128133

129134
public T[] Data<T>()

test/TensorFlowNET.Examples/LogisticRegression.cs

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1-
using NumSharp.Core;
1+
using Newtonsoft.Json;
2+
using NumSharp.Core;
23
using System;
34
using System.Collections.Generic;
5+
using System.Linq;
46
using System.Text;
57
using Tensorflow;
68
using TensorFlowNET.Examples.Utility;
@@ -26,8 +28,6 @@ public void Run()
2628

2729
private void PrepareData()
2830
{
29-
//var mnist = MnistDataSet.read_data_sets("logistic_regression", one_hot: true);
30-
3131
// tf Graph Input
3232
var x = tf.placeholder(tf.float32, new TensorShape(-1, 784)); // mnist data image of shape 28*28=784
3333
var y = tf.placeholder(tf.float32, new TensorShape(-1, 10)); // 0-9 digits recognition => 10 classes
@@ -49,13 +49,37 @@ private void PrepareData()
4949
// Gradient Descent
5050
var optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost);
5151

52+
//var new_saver = tf.train.import_meta_graph("logistic_regression.meta.bin");
53+
54+
/*var text = JsonConvert.SerializeObject(tf.get_default_graph(), new JsonSerializerSettings
55+
{
56+
Formatting = Formatting.Indented
57+
});*/
58+
5259
// Initialize the variables (i.e. assign their default value)
5360
var init = tf.global_variables_initializer();
5461

5562
with(tf.Session(), sess =>
5663
{
64+
var mnist = MnistDataSet.read_data_sets("logistic_regression", one_hot: true);
5765
// Run the initializer
5866
sess.run(init);
67+
68+
// Training cycle
69+
foreach(var epoch in range(training_epochs))
70+
{
71+
var avg_cost = 0.0f;
72+
var total_batch = (int)(mnist.train.num_examples / batch_size);
73+
// Loop over all batches
74+
foreach (var i in range(total_batch))
75+
{
76+
var (batch_xs, batch_ys) = mnist.train.next_batch(batch_size);
77+
// Run optimization op (backprop) and cost op (to get loss value)
78+
/*sess.run(optimizer,
79+
new FeedItem(x, batch_xs),
80+
new FeedItem(y, batch_ys));*/
81+
}
82+
}
5983
});
6084
}
6185
}

test/TensorFlowNET.Examples/Utility/DataSet.cs

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,15 @@ namespace TensorFlowNET.Examples.Utility
99
public class DataSet
1010
{
1111
private int _num_examples;
12+
public int num_examples => _num_examples;
1213
private int _epochs_completed;
14+
public int epochs_completed => _epochs_completed;
1315
private int _index_in_epoch;
16+
public int index_in_epoch => _index_in_epoch;
1417
private NDArray _images;
18+
public NDArray images => _images;
1519
private NDArray _labels;
20+
public NDArray labels => _labels;
1621

1722
public DataSet(NDArray images, NDArray labels, TF_DataType dtype, bool reshape)
1823
{
@@ -26,5 +31,33 @@ public DataSet(NDArray images, NDArray labels, TF_DataType dtype, bool reshape)
2631
_epochs_completed = 0;
2732
_index_in_epoch = 0;
2833
}
34+
35+
public (int, int) next_batch(int batch_size, bool fake_data = false, bool shuffle = true)
36+
{
37+
var start = _index_in_epoch;
38+
// Shuffle for the first epoch
39+
if(_epochs_completed == 0 && start == 0 && shuffle)
40+
{
41+
var perm0 = np.arange(_num_examples);
42+
np.random.shuffle(perm0);
43+
_images = images[perm0];
44+
_labels = labels[perm0];
45+
}
46+
47+
// Go to the next epoch
48+
if (start + batch_size > _num_examples)
49+
{
50+
// Finished epoch
51+
_epochs_completed += 1;
52+
53+
throw new NotImplementedException("next_batch");
54+
}
55+
else
56+
{
57+
_index_in_epoch += batch_size;
58+
var end = _index_in_epoch;
59+
return (_images[np.arange(start, end)], _labels[np.arange(start, end)]);
60+
}
61+
}
2962
}
3063
}

0 commit comments

Comments
 (0)