Skip to content

Commit 9721a34

Browse files
committed
refactor MNIST dataset.
1 parent 03336c8 commit 9721a34

File tree

10 files changed

+321
-119
lines changed

10 files changed

+321
-119
lines changed

test/TensorFlowNET.Examples/BasicModels/KMeansClustering.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ public class KMeansClustering : IExample
2727
public int? test_size = null;
2828
public int batch_size = 1024; // The number of samples per batch
2929

30-
Datasets mnist;
30+
Datasets<DataSetMnist> mnist;
3131
NDArray full_data_x;
3232
int num_steps = 20; // Total steps to train
3333
int k = 25; // The number of clusters
@@ -50,8 +50,8 @@ public bool Run()
5050

5151
public void PrepareData()
5252
{
53-
mnist = MnistDataSet.read_data_sets("mnist", one_hot: true, train_size: train_size, validation_size:validation_size, test_size:test_size);
54-
full_data_x = mnist.train.images;
53+
mnist = MNIST.read_data_sets("mnist", one_hot: true, train_size: train_size, validation_size:validation_size, test_size:test_size);
54+
full_data_x = mnist.train.data;
5555

5656
// download graph meta data
5757
string url = "https://raw.githubusercontent.com/SciSharp/TensorFlow.NET/master/graph/kmeans.meta";
@@ -141,7 +141,7 @@ public void Train(Session sess)
141141
var accuracy_op = tf.reduce_mean(cast);
142142

143143
// Test Model
144-
var (test_x, test_y) = (mnist.test.images, mnist.test.labels);
144+
var (test_x, test_y) = (mnist.test.data, mnist.test.labels);
145145
result = sess.run(accuracy_op, new FeedItem(X, test_x), new FeedItem(Y, test_y));
146146
accuray_test = result;
147147
print($"Test Accuracy: {accuray_test}");

test/TensorFlowNET.Examples/BasicModels/LogisticRegression.cs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ public class LogisticRegression : IExample
3232
private float learning_rate = 0.01f;
3333
private int display_step = 1;
3434

35-
Datasets mnist;
35+
Datasets<DataSetMnist> mnist;
3636

3737
public bool Run()
3838
{
@@ -102,7 +102,7 @@ public bool Run()
102102
var correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1));
103103
// Calculate accuracy
104104
var accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32));
105-
float acc = accuracy.eval(new FeedItem(x, mnist.test.images), new FeedItem(y, mnist.test.labels));
105+
float acc = accuracy.eval(new FeedItem(x, mnist.test.data), new FeedItem(y, mnist.test.labels));
106106
print($"Accuracy: {acc.ToString("F4")}");
107107

108108
return acc > 0.9;
@@ -111,7 +111,7 @@ public bool Run()
111111

112112
public void PrepareData()
113113
{
114-
mnist = MnistDataSet.read_data_sets("mnist", one_hot: true, train_size: train_size, validation_size: validation_size, test_size: test_size);
114+
mnist = MNIST.read_data_sets("mnist", one_hot: true, train_size: train_size, validation_size: validation_size, test_size: test_size);
115115
}
116116

117117
public void SaveModel(Session sess)

test/TensorFlowNET.Examples/BasicModels/NearestNeighbor.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ public class NearestNeighbor : IExample
1717
{
1818
public bool Enabled { get; set; } = true;
1919
public string Name => "Nearest Neighbor";
20-
Datasets mnist;
20+
Datasets<DataSetMnist> mnist;
2121
NDArray Xtr, Ytr, Xte, Yte;
2222
public int? TrainSize = null;
2323
public int ValidationSize = 5000;
@@ -70,7 +70,7 @@ public bool Run()
7070

7171
public void PrepareData()
7272
{
73-
mnist = MnistDataSet.read_data_sets("mnist", one_hot: true, train_size: TrainSize, validation_size:ValidationSize, test_size:TestSize);
73+
mnist = MNIST.read_data_sets("mnist", one_hot: true, train_size: TrainSize, validation_size:ValidationSize, test_size:TestSize);
7474
// In this example, we limit mnist data
7575
(Xtr, Ytr) = mnist.train.next_batch(TrainSize==null ? 5000 : TrainSize.Value / 100); // 5000 for training (nn candidates)
7676
(Xte, Yte) = mnist.test.next_batch(TestSize==null ? 200 : TestSize.Value / 100); // 200 for testing
Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
using NumSharp;
2+
using System;
3+
using System.Collections.Generic;
4+
using System.Text;
5+
using Tensorflow;
6+
using TensorFlowNET.Examples.Utility;
7+
using static Tensorflow.Python;
8+
9+
namespace TensorFlowNET.Examples.ImageProcess
10+
{
11+
/// <summary>
12+
/// Convolutional Neural Network classifier for Hand Written Digits
13+
/// CNN architecture with two convolutional layers, followed by two fully-connected layers at the end.
14+
/// Use Stochastic Gradient Descent (SGD) optimizer.
15+
/// http://www.easy-tensorflow.com/tf-tutorials/convolutional-neural-nets-cnns/cnn1
16+
/// </summary>
17+
public class DigitRecognitionCNN : IExample
18+
{
19+
public bool Enabled { get; set; } = true;
20+
public bool IsImportingGraph { get; set; } = false;
21+
22+
public string Name => "MNIST CNN";
23+
24+
const int img_h = 28;
25+
const int img_w = 28;
26+
int img_size_flat = img_h * img_w; // 784, the total number of pixels
27+
int n_classes = 10; // Number of classes, one class per digit
28+
// Hyper-parameters
29+
int epochs = 10;
30+
int batch_size = 100;
31+
float learning_rate = 0.001f;
32+
int h1 = 200; // number of nodes in the 1st hidden layer
33+
Datasets<DataSetMnist> mnist;
34+
35+
Tensor x, y;
36+
Tensor loss, accuracy;
37+
Operation optimizer;
38+
39+
int display_freq = 100;
40+
float accuracy_test = 0f;
41+
float loss_test = 1f;
42+
43+
public bool Run()
44+
{
45+
PrepareData();
46+
BuildGraph();
47+
48+
with(tf.Session(), sess =>
49+
{
50+
Train(sess);
51+
Test(sess);
52+
});
53+
54+
return loss_test < 0.09 && accuracy_test > 0.95;
55+
}
56+
57+
public Graph BuildGraph()
58+
{
59+
var graph = new Graph().as_default();
60+
61+
// Placeholders for inputs (x) and outputs(y)
62+
x = tf.placeholder(tf.float32, shape: (-1, img_size_flat), name: "X");
63+
y = tf.placeholder(tf.float32, shape: (-1, n_classes), name: "Y");
64+
65+
// Create a fully-connected layer with h1 nodes as hidden layer
66+
var fc1 = fc_layer(x, h1, "FC1", use_relu: true);
67+
// Create a fully-connected layer with n_classes nodes as output layer
68+
var output_logits = fc_layer(fc1, n_classes, "OUT", use_relu: false);
69+
// Define the loss function, optimizer, and accuracy
70+
var logits = tf.nn.softmax_cross_entropy_with_logits(labels: y, logits: output_logits);
71+
loss = tf.reduce_mean(logits, name: "loss");
72+
optimizer = tf.train.AdamOptimizer(learning_rate: learning_rate, name: "Adam-op").minimize(loss);
73+
var correct_prediction = tf.equal(tf.argmax(output_logits, 1), tf.argmax(y, 1), name: "correct_pred");
74+
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name: "accuracy");
75+
76+
// Network predictions
77+
var cls_prediction = tf.argmax(output_logits, axis: 1, name: "predictions");
78+
79+
return graph;
80+
}
81+
82+
private Tensor fc_layer(Tensor x, int num_units, string name, bool use_relu = true)
83+
{
84+
var in_dim = x.shape[1];
85+
86+
var initer = tf.truncated_normal_initializer(stddev: 0.01f);
87+
var W = tf.get_variable("W_" + name,
88+
dtype: tf.float32,
89+
shape: (in_dim, num_units),
90+
initializer: initer);
91+
92+
var initial = tf.constant(0f, num_units);
93+
var b = tf.get_variable("b_" + name,
94+
dtype: tf.float32,
95+
initializer: initial);
96+
97+
var layer = tf.matmul(x, W) + b;
98+
if (use_relu)
99+
layer = tf.nn.relu(layer);
100+
101+
return layer;
102+
}
103+
104+
public Graph ImportGraph() => throw new NotImplementedException();
105+
106+
public void Predict(Session sess) => throw new NotImplementedException();
107+
108+
public void PrepareData()
109+
{
110+
mnist = MNIST.read_data_sets("mnist", one_hot: true);
111+
}
112+
113+
public void Train(Session sess)
114+
{
115+
// Number of training iterations in each epoch
116+
var num_tr_iter = mnist.train.labels.len / batch_size;
117+
118+
var init = tf.global_variables_initializer();
119+
sess.run(init);
120+
121+
float loss_val = 100.0f;
122+
float accuracy_val = 0f;
123+
124+
foreach (var epoch in range(epochs))
125+
{
126+
print($"Training epoch: {epoch + 1}");
127+
// Randomly shuffle the training data at the beginning of each epoch
128+
var (x_train, y_train) = mnist.Randomize(mnist.train.data, mnist.train.labels);
129+
130+
foreach (var iteration in range(num_tr_iter))
131+
{
132+
var start = iteration * batch_size;
133+
var end = (iteration + 1) * batch_size;
134+
var (x_batch, y_batch) = mnist.GetNextBatch(x_train, y_train, start, end);
135+
136+
// Run optimization op (backprop)
137+
sess.run(optimizer, new FeedItem(x, x_batch), new FeedItem(y, y_batch));
138+
139+
if (iteration % display_freq == 0)
140+
{
141+
// Calculate and display the batch loss and accuracy
142+
var result = sess.run(new[] { loss, accuracy }, new FeedItem(x, x_batch), new FeedItem(y, y_batch));
143+
loss_val = result[0];
144+
accuracy_val = result[1];
145+
print($"iter {iteration.ToString("000")}: Loss={loss_val.ToString("0.0000")}, Training Accuracy={accuracy_val.ToString("P")}");
146+
}
147+
}
148+
149+
// Run validation after every epoch
150+
var results1 = sess.run(new[] { loss, accuracy }, new FeedItem(x, mnist.validation.data), new FeedItem(y, mnist.validation.labels));
151+
loss_val = results1[0];
152+
accuracy_val = results1[1];
153+
print("---------------------------------------------------------");
154+
print($"Epoch: {epoch + 1}, validation loss: {loss_val.ToString("0.0000")}, validation accuracy: {accuracy_val.ToString("P")}");
155+
print("---------------------------------------------------------");
156+
}
157+
}
158+
159+
public void Test(Session sess)
160+
{
161+
var result = sess.run(new[] { loss, accuracy }, new FeedItem(x, mnist.test.data), new FeedItem(y, mnist.test.labels));
162+
loss_test = result[0];
163+
accuracy_test = result[1];
164+
print("---------------------------------------------------------");
165+
print($"Test loss: {loss_test.ToString("0.0000")}, test accuracy: {accuracy_test.ToString("P")}");
166+
print("---------------------------------------------------------");
167+
}
168+
}
169+
}

test/TensorFlowNET.Examples/ImageProcess/DigitRecognitionNN.cs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ public class DigitRecognitionNN : IExample
3030
int batch_size = 100;
3131
float learning_rate = 0.001f;
3232
int h1 = 200; // number of nodes in the 1st hidden layer
33-
Datasets mnist;
33+
Datasets<DataSetMnist> mnist;
3434

3535
Tensor x, y;
3636
Tensor loss, accuracy;
@@ -107,7 +107,7 @@ private Tensor fc_layer(Tensor x, int num_units, string name, bool use_relu = tr
107107

108108
public void PrepareData()
109109
{
110-
mnist = MnistDataSet.read_data_sets("mnist", one_hot: true);
110+
mnist = MNIST.read_data_sets("mnist", one_hot: true);
111111
}
112112

113113
public void Train(Session sess)
@@ -125,7 +125,7 @@ public void Train(Session sess)
125125
{
126126
print($"Training epoch: {epoch + 1}");
127127
// Randomly shuffle the training data at the beginning of each epoch
128-
var (x_train, y_train) = randomize(mnist.train.images, mnist.train.labels);
128+
var (x_train, y_train) = randomize(mnist.train.data, mnist.train.labels);
129129

130130
foreach (var iteration in range(num_tr_iter))
131131
{
@@ -147,7 +147,7 @@ public void Train(Session sess)
147147
}
148148

149149
// Run validation after every epoch
150-
var results1 = sess.run(new[] { loss, accuracy }, new FeedItem(x, mnist.validation.images), new FeedItem(y, mnist.validation.labels));
150+
var results1 = sess.run(new[] { loss, accuracy }, new FeedItem(x, mnist.validation.data), new FeedItem(y, mnist.validation.labels));
151151
loss_val = results1[0];
152152
accuracy_val = results1[1];
153153
print("---------------------------------------------------------");
@@ -158,7 +158,7 @@ public void Train(Session sess)
158158

159159
public void Test(Session sess)
160160
{
161-
var result = sess.run(new[] { loss, accuracy }, new FeedItem(x, mnist.test.images), new FeedItem(y, mnist.test.labels));
161+
var result = sess.run(new[] { loss, accuracy }, new FeedItem(x, mnist.test.data), new FeedItem(y, mnist.test.labels));
162162
loss_test = result[0];
163163
accuracy_test = result[1];
164164
print("---------------------------------------------------------");
@@ -171,7 +171,7 @@ public void Test(Session sess)
171171
var perm = np.random.permutation(y.shape[0]);
172172

173173
np.random.shuffle(perm);
174-
return (mnist.train.images[perm], mnist.train.labels[perm]);
174+
return (mnist.train.data[perm], mnist.train.labels[perm]);
175175
}
176176

177177
/// <summary>

test/TensorFlowNET.Examples/Utility/DataSet.cs

Lines changed: 0 additions & 86 deletions
This file was deleted.

0 commit comments

Comments
 (0)