Skip to content

Commit f462c55

Browse files
committed
print Accuracy of LogisticRegression
1 parent 9c161b1 commit f462c55

File tree

5 files changed

+41
-8
lines changed

5 files changed

+41
-8
lines changed

src/TensorFlowNET.Core/APIs/tf.math.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ public static Tensor subtract<T>(Tensor x, T[] y, string name = null) where T :
2121
public static Tensor log(Tensor x, string name = null)
2222
=> gen_math_ops.log(x, name);
2323

24+
public static Tensor equal(Tensor x, Tensor y, string name = null)
25+
=> gen_math_ops.equal(x, y, name);
26+
2427
public static Tensor multiply(Tensor x, Tensor y)
2528
=> gen_math_ops.mul(x, y);
2629

src/TensorFlowNET.Core/Operations/gen_math_ops.cs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,6 +128,20 @@ public static Tensor sub<Tx, Ty>(Tx x, Ty y, string name = null)
128128
return _op.outputs[0];
129129
}
130130

131+
/// <summary>
132+
/// Returns the truth value of (x == y) element-wise.
133+
/// </summary>
134+
/// <param name="x"></param>
135+
/// <param name="y"></param>
136+
/// <param name="name"></param>
137+
/// <returns></returns>
138+
public static Tensor equal(Tensor x, Tensor y, string name = null)
139+
{
140+
var _op = _op_def_lib._apply_op_helper("Equal", name, args: new { x, y });
141+
142+
return _op.outputs[0];
143+
}
144+
131145
public static Tensor mul(Tensor x, Tensor y, string name = null)
132146
{
133147
var _op = _op_def_lib._apply_op_helper("Mul", name, args: new { x, y });

src/TensorFlowNET.Core/Sessions/_ElementFetchMapper.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,9 @@ public override NDArray build_results(List<object> values)
4343
case NDArray value:
4444
result = value;
4545
break;
46+
case float fVal:
47+
result = fVal;
48+
break;
4649
default:
4750
break;
4851
}

src/TensorFlowNET.Core/Tensors/Tensor.cs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,12 @@ public Tensor MaybeMove()
168168
/// <param name="feed_dict">A dictionary that maps `Tensor` objects to feed values.</param>
169169
/// <param name="session">The `Session` to be used to evaluate this tensor.</param>
170170
/// <returns></returns>
171-
public NDArray eval(FeedItem[] feed_dict = null, Session session = null)
171+
public NDArray eval(params FeedItem[] feed_dict)
172+
{
173+
return ops._eval_using_default_session(this, feed_dict, graph);
174+
}
175+
176+
public NDArray eval(Session session, FeedItem[] feed_dict = null)
172177
{
173178
return ops._eval_using_default_session(this, feed_dict, graph, session);
174179
}

test/TensorFlowNET.Examples/LogisticRegression.cs

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,17 +17,14 @@ namespace TensorFlowNET.Examples
1717
public class LogisticRegression : Python, IExample
1818
{
1919
private float learning_rate = 0.01f;
20-
private int training_epochs = 25;
20+
private int training_epochs = 5;
2121
private int batch_size = 100;
2222
private int display_step = 1;
2323

2424
public void Run()
2525
{
26-
PrepareData();
27-
}
26+
var mnist = PrepareData();
2827

29-
private void PrepareData()
30-
{
3128
// tf Graph Input
3229
var x = tf.placeholder(tf.float32, new TensorShape(-1, 784)); // mnist data image of shape 28*28=784
3330
var y = tf.placeholder(tf.float32, new TensorShape(-1, 10)); // 0-9 digits recognition => 10 classes
@@ -50,12 +47,12 @@ private void PrepareData()
5047

5148
with(tf.Session(), sess =>
5249
{
53-
var mnist = MnistDataSet.read_data_sets("logistic_regression", one_hot: true);
50+
5451
// Run the initializer
5552
sess.run(init);
5653

5754
// Training cycle
58-
foreach(var epoch in range(training_epochs))
55+
foreach (var epoch in range(training_epochs))
5956
{
6057
var avg_cost = 0.0f;
6158
var total_batch = mnist.train.num_examples / batch_size;
@@ -81,7 +78,18 @@ private void PrepareData()
8178
print("Optimization Finished!");
8279

8380
// Test model
81+
var correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1));
82+
// Calculate accuracy
83+
var accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32));
84+
float acc = accuracy.eval(new FeedItem(x, mnist.test.images), new FeedItem(y, mnist.test.labels));
85+
print($"Accuracy: {acc.ToString("F4")}");
8486
});
8587
}
88+
89+
private Datasets PrepareData()
90+
{
91+
var mnist = MnistDataSet.read_data_sets("logistic_regression", one_hot: true);
92+
return mnist;
93+
}
8694
}
8795
}

0 commit comments

Comments
 (0)