Skip to content

Commit 6b154dd

Browse files
committed
tf.log, tf.nn.softmax
1 parent bd42ed9 commit 6b154dd

File tree

9 files changed

+97
-10
lines changed

9 files changed

+97
-10
lines changed

src/TensorFlowNET.Core/APIs/tf.math.cs

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@ public static Tensor sqrt(Tensor a, string name = null)
1818
public static Tensor subtract<T>(Tensor x, T[] y, string name = null) where T : struct
1919
=> gen_math_ops.sub(x, ops.convert_to_tensor(y, dtype: x.dtype.as_base_dtype(), name: "y"), name);
2020

21+
public static Tensor log(Tensor x, string name = null)
22+
=> gen_math_ops.log(x, name);
23+
2124
public static Tensor multiply(Tensor x, Tensor y)
2225
=> gen_math_ops.mul(x, y);
2326

@@ -33,11 +36,11 @@ public static Tensor pow<T1, T2>(T1 x, T2 y)
3336
/// <param name="input"></param>
3437
/// <param name="axis"></param>
3538
/// <returns></returns>
36-
public static Tensor reduce_sum(Tensor input, int[] axis = null)
39+
public static Tensor reduce_sum(Tensor input, int[] axis = null, int? reduction_indices = null)
3740
=> math_ops.reduce_sum(input);
3841

39-
public static Tensor reduce_mean(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null)
40-
=> math_ops.reduce_mean(input_tensor, axis: axis, keepdims: keepdims, name: name);
42+
public static Tensor reduce_mean(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null, int? reduction_indices = null)
43+
=> math_ops.reduce_mean(input_tensor, axis: axis, keepdims: keepdims, name: name, reduction_indices: reduction_indices);
4144

4245
public static Tensor cast(Tensor x, TF_DataType dtype = TF_DataType.DtInvalid, string name = null)
4346
=> math_ops.cast(x, dtype, name);

src/TensorFlowNET.Core/APIs/tf.nn.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,9 @@ public static Tensor bias_add(Tensor value, RefVariable bias, string data_format
5656
});
5757
}
5858

59+
public static Tensor softmax(Tensor logits, int axis = -1, string name = null)
60+
=> gen_nn_ops.softmax(logits, name);
61+
5962
public static Tensor softmax_cross_entropy_with_logits_v2(Tensor labels, Tensor logits, int axis = -1, string name = null)
6063
=> nn_ops.softmax_cross_entropy_with_logits_v2_helper(labels, logits, axis: axis, name: name);
6164
}

src/TensorFlowNET.Core/Operations/NnOps/gen_nn_ops.cs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,16 @@ public static Tensor relu_grad(Tensor gradients, Tensor features, string name =
146146
return _op.outputs[0];
147147
}
148148

149+
public static Tensor softmax(Tensor logits, string name = null)
150+
{
151+
var _op = _op_def_lib._apply_op_helper("Softmax", name: name, args: new
152+
{
153+
logits
154+
});
155+
156+
return _op.outputs[0];
157+
}
158+
149159
/// <summary>
150160
/// Computes softmax cross entropy cost and gradients to backpropagate.
151161
/// </summary>

src/TensorFlowNET.Core/Operations/array_ops.py.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ private static Tensor _constant_if_small<T>(T value, Shape shape, TF_DataType dt
4242
else
4343
{
4444
tShape = constant_op._tensor_shape_tensor_conversion_function(shape.as_shape());
45-
var c = constant_op.constant(0);
45+
var c = constant_op.constant(0, dtype: dtype);
4646
return gen_array_ops.fill(tShape, c, name: name);
4747
}
4848
}

src/TensorFlowNET.Core/Operations/math_ops.py.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ public static Tensor cast(Tensor x, TF_DataType dtype = TF_DataType.DtInvalid, s
3838
/// dimensions.Must be in the range `[-rank(input_tensor), rank(input_tensor))`.</param>
3939
/// <param name="keepdims"> If true, retains reduced dimensions with length 1.</param>
4040
/// <param name="name"> A name for the operation (optional).</param>
41-
public static Tensor reduce_mean(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null)
41+
public static Tensor reduce_mean(Tensor input_tensor, int[] axis = null, bool keepdims = false, string name = null, int? reduction_indices = null)
4242
{
4343
var r = _ReductionDims(input_tensor, axis);
4444
if (axis == null)

test/TensorFlowNET.Examples/LogisticRegression.cs

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
using System;
1+
using NumSharp.Core;
2+
using System;
23
using System.Collections.Generic;
34
using System.Text;
45
using Tensorflow;
@@ -13,15 +14,46 @@ namespace TensorFlowNET.Examples
1314
/// </summary>
1415
public class LogisticRegression : Python, IExample
1516
{
17+
private float learning_rate = 0.01f;
18+
private int training_epochs = 25;
19+
private int batch_size = 100;
20+
private int display_step = 1;
21+
1622
public void Run()
1723
{
1824
PrepareData();
1925
}
2026

2127
private void PrepareData()
2228
{
23-
MnistDataSet.read_data_sets("logistic_regression", one_hot: true);
24-
29+
var mnist = MnistDataSet.read_data_sets("logistic_regression", one_hot: true);
30+
31+
// tf Graph Input
32+
var x = tf.placeholder(tf.float32, new TensorShape(-1, 784)); // mnist data image of shape 28*28=784
33+
var y = tf.placeholder(tf.float32, new TensorShape(-1, 10)); // 0-9 digits recognition => 10 classes
34+
35+
// Set model weights
36+
var W = tf.Variable(tf.zeros(new Shape(784, 10)));
37+
var b = tf.Variable(tf.zeros(new Shape(10)));
38+
39+
// Construct model
40+
var pred = tf.nn.softmax(tf.matmul(x, W) + b); // Softmax
41+
42+
// Minimize error using cross entropy
43+
var sum = -tf.reduce_sum(y * tf.log(pred), reduction_indices: 1);
44+
var cost = tf.reduce_mean(sum);
45+
46+
// Gradient Descent
47+
var optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost);
48+
49+
// Initialize the variables (i.e. assign their default value)
50+
var init = tf.global_variables_initializer();
51+
52+
with(tf.Session(), sess =>
53+
{
54+
// Run the initializer
55+
sess.run(init);
56+
});
2557
}
2658
}
2759
}

test/TensorFlowNET.Examples/Utility/DataSet.cs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,22 @@ namespace TensorFlowNET.Examples.Utility
99
public class DataSet
1010
{
1111
private int _num_examples;
12+
private int _epochs_completed;
13+
private int _index_in_epoch;
14+
private NDArray _images;
15+
private NDArray _labels;
1216

1317
public DataSet(NDArray images, NDArray labels, TF_DataType dtype, bool reshape)
1418
{
1519
_num_examples = images.shape[0];
1620
images = images.reshape(images.shape[0], images.shape[1] * images.shape[2]);
21+
images.astype(dtype.as_numpy_datatype());
1722
images = np.multiply(images, 1.0f / 255.0f);
23+
24+
_images = images;
25+
_labels = labels;
26+
_epochs_completed = 0;
27+
_index_in_epoch = 0;
1828
}
1929
}
2030
}
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace TensorFlowNET.Examples.Utility
6+
{
7+
public class Datasets
8+
{
9+
private DataSet _train;
10+
public DataSet train => _train;
11+
12+
private DataSet _validation;
13+
public DataSet validation => _validation;
14+
15+
private DataSet _test;
16+
public DataSet test => _test;
17+
18+
public Datasets(DataSet train, DataSet validation, DataSet test)
19+
{
20+
_train = train;
21+
_validation = validation;
22+
_test = test;
23+
}
24+
}
25+
}

test/TensorFlowNET.Examples/Utility/MnistDataSet.cs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,9 @@ public class MnistDataSet
1818
private const string TEST_IMAGES = "t10k-images-idx3-ubyte.gz";
1919
private const string TEST_LABELS = "t10k-labels-idx1-ubyte.gz";
2020

21-
public static void read_data_sets(string train_dir,
21+
public static Datasets read_data_sets(string train_dir,
2222
bool one_hot = false,
23-
TF_DataType dtype = TF_DataType.DtInvalid,
23+
TF_DataType dtype = TF_DataType.TF_FLOAT,
2424
bool reshape = true,
2525
int validation_size = 5000,
2626
string source_url = DEFAULT_SOURCE_URL)
@@ -48,6 +48,10 @@ public static void read_data_sets(string train_dir,
4848
train_labels = train_labels[np.arange(validation_size, end)];
4949

5050
var train = new DataSet(train_images, train_labels, dtype, reshape);
51+
var validation = new DataSet(validation_images, validation_labels, dtype, reshape);
52+
var test = new DataSet(test_images, test_labels, dtype, reshape);
53+
54+
return new Datasets(train, validation, test);
5155
}
5256

5357
public static NDArray extract_images(string file)

0 commit comments

Comments
 (0)