Commit bd42ed9

Mnist dataset
1 parent d94c685 commit bd42ed9

6 files changed, 250 insertions(+), 1 deletion(-)
test/TensorFlowNET.Examples/LogisticRegression.cs

Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Tensorflow;
+using TensorFlowNET.Examples.Utility;
+
+namespace TensorFlowNET.Examples
+{
+    /// <summary>
+    /// A logistic regression learning algorithm example using TensorFlow library.
+    /// This example is using the MNIST database of handwritten digits
+    /// https://github.com/aymericdamien/TensorFlow-Examples/blob/master/examples/2_BasicModels/logistic_regression.py
+    /// </summary>
+    public class LogisticRegression : Python, IExample
+    {
+        public void Run()
+        {
+            PrepareData();
+        }
+
+        private void PrepareData()
+        {
+            MnistDataSet.read_data_sets("logistic_regression", one_hot: true);
+
+        }
+    }
+}
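For orientation, a minimal sketch of how this example class might be driven, assuming IExample exposes only Run(); the Program class below is illustrative and not part of this commit:

using System;

namespace TensorFlowNET.Examples
{
    class Program
    {
        static void Main(string[] args)
        {
            // Hypothetical driver: run the example end to end.
            IExample example = new LogisticRegression();
            example.Run(); // downloads and prepares the MNIST data
        }
    }
}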

test/TensorFlowNET.Examples/TensorFlowNET.Examples.csproj

Lines changed: 1 addition & 0 deletions
@@ -6,6 +6,7 @@
   </PropertyGroup>
 
   <ItemGroup>
+    <PackageReference Include="DevExpress.Xpo" Version="18.2.6" />
     <PackageReference Include="NumSharp" Version="0.8.0" />
     <PackageReference Include="SharpZipLib" Version="1.1.0" />
     <PackageReference Include="TensorFlow.NET" Version="0.4.2" />

test/TensorFlowNET.Examples/Utility/Compress.cs

Lines changed: 22 additions & 1 deletion
@@ -1,4 +1,5 @@
-using ICSharpCode.SharpZipLib.GZip;
+using ICSharpCode.SharpZipLib.Core;
+using ICSharpCode.SharpZipLib.GZip;
 using ICSharpCode.SharpZipLib.Tar;
 using System;
 using System.IO;
@@ -11,6 +12,26 @@ namespace TensorFlowNET.Examples.Utility
 {
     public class Compress
     {
+        public static void ExtractGZip(string gzipFileName, string targetDir)
+        {
+            // Use a 4K buffer. Any larger is a waste.
+            byte[] dataBuffer = new byte[4096];
+
+            using (System.IO.Stream fs = new FileStream(gzipFileName, FileMode.Open, FileAccess.Read))
+            {
+                using (GZipInputStream gzipStream = new GZipInputStream(fs))
+                {
+                    // Change this to your needs
+                    string fnOut = Path.Combine(targetDir, Path.GetFileNameWithoutExtension(gzipFileName));
+
+                    using (FileStream fsOut = File.Create(fnOut))
+                    {
+                        StreamUtils.Copy(gzipStream, fsOut, dataBuffer);
+                    }
+                }
+            }
+        }
+
        public static void UnZip(String gzArchiveName, String destFolder)
        {
            var flag = gzArchiveName.Split(Path.DirectorySeparatorChar).Last().Split('.').First() + ".bin";
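A usage sketch for the new helper, assuming the .gz archive has already been downloaded into the data directory (the paths are illustrative):

// Illustrative call, not part of this commit: decompresses
// logistic_regression/train-images-idx3-ubyte.gz to
// logistic_regression/train-images-idx3-ubyte in the same directory.
Compress.ExtractGZip("logistic_regression/train-images-idx3-ubyte.gz", "logistic_regression");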
test/TensorFlowNET.Examples/Utility/DataSet.cs

Lines changed: 20 additions & 0 deletions
@@ -0,0 +1,20 @@
+using NumSharp.Core;
+using System;
+using System.Collections.Generic;
+using System.Text;
+using Tensorflow;
+
+namespace TensorFlowNET.Examples.Utility
+{
+    public class DataSet
+    {
+        private int _num_examples;
+
+        public DataSet(NDArray images, NDArray labels, TF_DataType dtype, bool reshape)
+        {
+            _num_examples = images.shape[0];
+            images = images.reshape(images.shape[0], images.shape[1] * images.shape[2]);
+            images = np.multiply(images, 1.0f / 255.0f);
+        }
+    }
+}
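The constructor flattens each 28x28x1 image into a 784-element row and rescales byte intensities from [0, 255] into [0.0, 1.0]; the dtype and reshape parameters are accepted but not yet used. A one-line sketch of the per-pixel scaling, with an illustrative value:

// Illustrative: the scaling np.multiply applies to every pixel.
byte rawPixel = 200;                       // stored intensity, 0..255
float scaled = rawPixel * (1.0f / 255.0f); // ~0.784f, the value the model consumes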
test/TensorFlowNET.Examples/Utility/MnistDataSet.cs

Lines changed: 110 additions & 0 deletions
@@ -0,0 +1,110 @@
+using ICSharpCode.SharpZipLib.Core;
+using ICSharpCode.SharpZipLib.GZip;
+using NumSharp.Core;
+using System;
+using System.Collections.Generic;
+using System.IO;
+using System.Linq;
+using System.Text;
+using Tensorflow;
+
+namespace TensorFlowNET.Examples.Utility
+{
+    public class MnistDataSet
+    {
+        private const string DEFAULT_SOURCE_URL = "https://storage.googleapis.com/cvdf-datasets/mnist/";
+        private const string TRAIN_IMAGES = "train-images-idx3-ubyte.gz";
+        private const string TRAIN_LABELS = "train-labels-idx1-ubyte.gz";
+        private const string TEST_IMAGES = "t10k-images-idx3-ubyte.gz";
+        private const string TEST_LABELS = "t10k-labels-idx1-ubyte.gz";
+
+        public static void read_data_sets(string train_dir,
+            bool one_hot = false,
+            TF_DataType dtype = TF_DataType.DtInvalid,
+            bool reshape = true,
+            int validation_size = 5000,
+            string source_url = DEFAULT_SOURCE_URL)
+        {
+            Web.Download(source_url + TRAIN_IMAGES, train_dir, TRAIN_IMAGES);
+            Compress.ExtractGZip(Path.Join(train_dir, TRAIN_IMAGES), train_dir);
+            var train_images = extract_images(Path.Join(train_dir, TRAIN_IMAGES.Split('.')[0]));
+
+            Web.Download(source_url + TRAIN_LABELS, train_dir, TRAIN_LABELS);
+            Compress.ExtractGZip(Path.Join(train_dir, TRAIN_LABELS), train_dir);
+            var train_labels = extract_labels(Path.Join(train_dir, TRAIN_LABELS.Split('.')[0]), one_hot: one_hot);
+
+            Web.Download(source_url + TEST_IMAGES, train_dir, TEST_IMAGES);
+            Compress.ExtractGZip(Path.Join(train_dir, TEST_IMAGES), train_dir);
+            var test_images = extract_images(Path.Join(train_dir, TEST_IMAGES.Split('.')[0]));
+
+            Web.Download(source_url + TEST_LABELS, train_dir, TEST_LABELS);
+            Compress.ExtractGZip(Path.Join(train_dir, TEST_LABELS), train_dir);
+            var test_labels = extract_labels(Path.Join(train_dir, TEST_LABELS.Split('.')[0]), one_hot: one_hot);
+
+            int end = train_images.shape[0];
+            var validation_images = train_images[np.arange(validation_size)];
+            var validation_labels = train_labels[np.arange(validation_size)];
+            train_images = train_images[np.arange(validation_size, end)];
+            train_labels = train_labels[np.arange(validation_size, end)];
+
+            var train = new DataSet(train_images, train_labels, dtype, reshape);
+        }
+
+        public static NDArray extract_images(string file)
+        {
+            using (var bytestream = new FileStream(file, FileMode.Open))
+            {
+                var magic = _read32(bytestream);
+                if (magic != 2051)
+                    throw new ValueError($"Invalid magic number {magic} in MNIST image file: {file}");
+                var num_images = _read32(bytestream);
+                var rows = _read32(bytestream);
+                var cols = _read32(bytestream);
+                var buf = new byte[rows * cols * num_images];
+                bytestream.Read(buf, 0, buf.Length);
+                var data = np.frombuffer(buf, np.uint8);
+                data = data.reshape((int)num_images, (int)rows, (int)cols, 1);
+                return data;
+            }
+        }
+
+        public static NDArray extract_labels(string file, bool one_hot = false, int num_classes = 10)
+        {
+            using (var bytestream = new FileStream(file, FileMode.Open))
+            {
+                var magic = _read32(bytestream);
+                if (magic != 2049)
+                    throw new ValueError($"Invalid magic number {magic} in MNIST label file: {file}");
+                var num_items = _read32(bytestream);
+                var buf = new byte[num_items];
+                bytestream.Read(buf, 0, buf.Length);
+                var labels = np.frombuffer(buf, np.uint8);
+                if (one_hot)
+                    return dense_to_one_hot(labels, num_classes);
+                return labels;
+            }
+        }
+
+        private static NDArray dense_to_one_hot(NDArray labels_dense, int num_classes)
+        {
+            var num_labels = labels_dense.shape[0];
+            var index_offset = np.arange(num_labels) * num_classes;
+            var labels_one_hot = np.zeros(num_labels, num_classes);
+
+            for (int row = 0; row < num_labels; row++)
+            {
+                var col = labels_dense.Data<byte>(row);
+                labels_one_hot[row, col] = 1;
+            }
+
+            return labels_one_hot;
+        }
+
+        private static uint _read32(FileStream bytestream)
+        {
+            var buffer = new byte[sizeof(uint)];
+            var count = bytestream.Read(buffer, 0, 4);
+            return np.frombuffer(buffer, ">u4").Data<uint>(0);
+        }
+    }
+}
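Two details worth noting in this reader. The IDX files store their magic numbers and counts as big-endian 32-bit integers, which is why _read32 parses its buffer with the ">u4" dtype, and dense_to_one_hot maps label k to a row whose k-th column is 1. A plain-C# sketch of that one-hot expansion, illustrative and not part of the commit:

// Illustrative: expand dense labels {5, 0, 4} into 10-class one-hot rows.
byte[] labels = { 5, 0, 4 };
int numClasses = 10;
var oneHot = new float[labels.Length, numClasses];
for (int row = 0; row < labels.Length; row++)
    oneHot[row, labels[row]] = 1f; // column index = class label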
logistic_regression.py

Lines changed: 70 additions & 0 deletions
@@ -0,0 +1,70 @@
+'''
+A logistic regression learning algorithm example using TensorFlow library.
+This example is using the MNIST database of handwritten digits
+(http://yann.lecun.com/exdb/mnist/)
+Author: Aymeric Damien
+Project: https://github.com/aymericdamien/TensorFlow-Examples/
+'''
+
+from __future__ import print_function
+
+import tensorflow as tf
+
+# Import MNIST data
+from tensorflow.examples.tutorials.mnist import input_data
+mnist = input_data.read_data_sets("/tmp/data/", one_hot=True)
+
+# Parameters
+learning_rate = 0.01
+training_epochs = 25
+batch_size = 100
+display_step = 1
+
+# tf Graph Input
+x = tf.placeholder(tf.float32, [None, 784]) # mnist data image of shape 28*28=784
+y = tf.placeholder(tf.float32, [None, 10]) # 0-9 digits recognition => 10 classes
+
+# Set model weights
+W = tf.Variable(tf.zeros([784, 10]))
+b = tf.Variable(tf.zeros([10]))
+
+# Construct model
+pred = tf.nn.softmax(tf.matmul(x, W) + b) # Softmax
+
+# Minimize error using cross entropy
+cost = tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred), reduction_indices=1))
+# Gradient Descent
+optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)
+
+# Initialize the variables (i.e. assign their default value)
+init = tf.global_variables_initializer()
+
+# Start training
+with tf.Session() as sess:
+
+    # Run the initializer
+    sess.run(init)
+
+    # Training cycle
+    for epoch in range(training_epochs):
+        avg_cost = 0.
+        total_batch = int(mnist.train.num_examples/batch_size)
+        # Loop over all batches
+        for i in range(total_batch):
+            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
+            # Run optimization op (backprop) and cost op (to get loss value)
+            _, c = sess.run([optimizer, cost], feed_dict={x: batch_xs,
+                                                          y: batch_ys})
+            # Compute average loss
+            avg_cost += c / total_batch
+        # Display logs per epoch step
+        if (epoch+1) % display_step == 0:
+            print("Epoch:", '%04d' % (epoch+1), "cost=", "{:.9f}".format(avg_cost))
+
+    print("Optimization Finished!")
+
+    # Test model
+    correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
+    # Calculate accuracy
+    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
+    print("Accuracy:", accuracy.eval({x: mnist.test.images, y: mnist.test.labels}))
