using System;
using System.Collections.Generic;
using System.Text;
using NumSharp;
using Tensorflow;
using TensorFlowNET.Examples.Utility;
using static Tensorflow.Python;

namespace TensorFlowNET.Examples.ImageProcess
{
    /// <summary>
    /// Classifier for hand-written digits (MNIST).
    /// NOTE: despite the class name and the example name ("MNIST CNN"), the graph built by
    /// <see cref="BuildGraph"/> currently contains no convolutional layers. It is one
    /// fully-connected ReLU hidden layer followed by a fully-connected output layer,
    /// trained with the Adam optimizer.
    /// Tutorial describing the CNN architecture this example is named after:
    /// http://www.easy-tensorflow.com/tf-tutorials/convolutional-neural-nets-cnns/cnn1
    /// </summary>
    public class DigitRecognitionCNN : IExample
    {
        public bool Enabled { get; set; } = true;
        public bool IsImportingGraph { get; set; } = false;

        public string Name => "MNIST CNN";

        // Input image geometry.
        const int img_h = 28;
        const int img_w = 28;

        // Pass criteria used by Run(); kept as double so the float-vs-double
        // comparison behaves exactly as it did with the inline literals.
        private const double max_test_loss = 0.09;
        private const double min_test_accuracy = 0.95;

        private readonly int img_size_flat = img_h * img_w; // 784, the total number of pixels
        private readonly int n_classes = 10; // Number of classes, one class per digit

        // Hyper-parameters
        private readonly int epochs = 10;
        private readonly int batch_size = 100;
        private readonly float learning_rate = 0.001f;
        private readonly int h1 = 200; // number of nodes in the 1st hidden layer

        // Data set, populated by PrepareData().
        Datasets<DataSetMnist> mnist;

        // Graph handles, populated by BuildGraph().
        Tensor x, y;
        Tensor loss, accuracy;
        Operation optimizer;

        // Print batch metrics every display_freq iterations.
        private readonly int display_freq = 100;

        // Results of the last Test() call; the initial values make Run() fail
        // its thresholds if Test() never assigns them.
        float accuracy_test = 0f;
        float loss_test = 1f;

        /// <summary>
        /// Loads the data, builds the graph, then trains and tests inside a session.
        /// </summary>
        /// <returns>
        /// <c>true</c> when the test loss is below 0.09 and the test accuracy is above 0.95.
        /// </returns>
        public bool Run()
        {
            PrepareData();
            BuildGraph();

            with(tf.Session(), sess =>
            {
                Train(sess);
                Test(sess);
            });

            return loss_test < max_test_loss && accuracy_test > min_test_accuracy;
        }

        /// <summary>
        /// Creates a new default graph holding the placeholders, the two fully-connected
        /// layers, the loss, the Adam training op and the accuracy metric, and stores the
        /// handles needed by <see cref="Train"/> and <see cref="Test"/> in fields.
        /// </summary>
        /// <returns>The graph that was created.</returns>
        public Graph BuildGraph()
        {
            var graph = new Graph().as_default();

            // Placeholders for inputs (x) and outputs (y); -1 leaves the batch dimension open.
            x = tf.placeholder(tf.float32, shape: (-1, img_size_flat), name: "X");
            y = tf.placeholder(tf.float32, shape: (-1, n_classes), name: "Y");

            // Create a fully-connected layer with h1 nodes as hidden layer
            var fc1 = fc_layer(x, h1, "FC1", use_relu: true);
            // Create a fully-connected layer with n_classes nodes as output layer
            var output_logits = fc_layer(fc1, n_classes, "OUT", use_relu: false);

            // Define the loss function, optimizer, and accuracy
            var cross_entropy = tf.nn.softmax_cross_entropy_with_logits(labels: y, logits: output_logits);
            loss = tf.reduce_mean(cross_entropy, name: "loss");
            optimizer = tf.train.AdamOptimizer(learning_rate: learning_rate, name: "Adam-op").minimize(loss);
            var correct_prediction = tf.equal(tf.argmax(output_logits, 1), tf.argmax(y, 1), name: "correct_pred");
            accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32), name: "accuracy");

            // Network predictions: the result is not used here, but the op is still
            // created so that the "predictions" node exists in the graph.
            _ = tf.argmax(output_logits, axis: 1, name: "predictions");

            return graph;
        }

        /// <summary>
        /// Builds one fully-connected layer: <c>matmul(x, W) + b</c>, optionally followed by ReLU.
        /// </summary>
        /// <param name="x">Input tensor; its second dimension is used as the input width.</param>
        /// <param name="num_units">Number of output units of the layer.</param>
        /// <param name="name">Suffix used for the variable names ("W_" + name, "b_" + name).</param>
        /// <param name="use_relu">When <c>true</c>, applies ReLU to the layer output.</param>
        /// <returns>The output tensor of the layer.</returns>
        private Tensor fc_layer(Tensor x, int num_units, string name, bool use_relu = true)
        {
            var in_dim = x.shape[1];

            // Weights: truncated-normal initialisation with a small standard deviation.
            var initer = tf.truncated_normal_initializer(stddev: 0.01f);
            var W = tf.get_variable("W_" + name,
                dtype: tf.float32,
                shape: (in_dim, num_units),
                initializer: initer);

            // Biases: initialised from a zero-valued constant.
            var initial = tf.constant(0f, num_units);
            var b = tf.get_variable("b_" + name,
                dtype: tf.float32,
                initializer: initial);

            var layer = tf.matmul(x, W) + b;
            if (use_relu)
            {
                layer = tf.nn.relu(layer);
            }

            return layer;
        }

        /// <summary>Not implemented for this example.</summary>
        /// <exception cref="NotImplementedException">Always thrown.</exception>
        public Graph ImportGraph() => throw new NotImplementedException();

        /// <summary>Not implemented for this example.</summary>
        /// <exception cref="NotImplementedException">Always thrown.</exception>
        public void Predict(Session sess) => throw new NotImplementedException();

        /// <summary>
        /// Reads the MNIST data sets from the "mnist" location with one-hot encoded labels.
        /// </summary>
        public void PrepareData()
        {
            mnist = MNIST.read_data_sets("mnist", one_hot: true);
        }

        /// <summary>
        /// Initialises the variables and trains for <c>epochs</c> epochs, printing batch
        /// metrics every <c>display_freq</c> iterations and validation metrics after each epoch.
        /// </summary>
        /// <param name="sess">Session in which the ops are run.</param>
        public void Train(Session sess)
        {
            // Number of training iterations in each epoch
            var num_tr_iter = mnist.train.labels.len / batch_size;

            var init = tf.global_variables_initializer();
            sess.run(init);

            foreach (var epoch in range(epochs))
            {
                print($"Training epoch: {epoch + 1}");
                // Randomly shuffle the training data at the beginning of each epoch
                var (x_train, y_train) = mnist.Randomize(mnist.train.data, mnist.train.labels);

                foreach (var iteration in range(num_tr_iter))
                {
                    var start = iteration * batch_size;
                    var end = (iteration + 1) * batch_size;
                    var (x_batch, y_batch) = mnist.GetNextBatch(x_train, y_train, start, end);

                    // Run optimization op (backprop)
                    sess.run(optimizer, new FeedItem(x, x_batch), new FeedItem(y, y_batch));

                    if (iteration % display_freq == 0)
                    {
                        // Calculate and display the batch loss and accuracy.
                        // The locals are typed float on purpose: the assignment relies on
                        // the conversion from the session result element to float.
                        var result = sess.run(new[] { loss, accuracy }, new FeedItem(x, x_batch), new FeedItem(y, y_batch));
                        float batch_loss = result[0];
                        float batch_accuracy = result[1];
                        print($"iter {iteration.ToString("000")}: Loss={batch_loss.ToString("0.0000")}, Training Accuracy={batch_accuracy.ToString("P")}");
                    }
                }

                // Run validation after every epoch
                var results1 = sess.run(new[] { loss, accuracy }, new FeedItem(x, mnist.validation.data), new FeedItem(y, mnist.validation.labels));
                float validation_loss = results1[0];
                float validation_accuracy = results1[1];
                print("---------------------------------------------------------");
                print($"Epoch: {epoch + 1}, validation loss: {validation_loss.ToString("0.0000")}, validation accuracy: {validation_accuracy.ToString("P")}");
                print("---------------------------------------------------------");
            }
        }

        /// <summary>
        /// Evaluates loss and accuracy on the test set, stores them in
        /// <c>loss_test</c> / <c>accuracy_test</c> and prints them.
        /// </summary>
        /// <param name="sess">Session in which the ops are run.</param>
        public void Test(Session sess)
        {
            var result = sess.run(new[] { loss, accuracy }, new FeedItem(x, mnist.test.data), new FeedItem(y, mnist.test.labels));
            loss_test = result[0];
            accuracy_test = result[1];
            print("---------------------------------------------------------");
            print($"Test loss: {loss_test.ToString("0.0000")}, test accuracy: {accuracy_test.ToString("P")}");
            print("---------------------------------------------------------");
        }
    }
}