Skip to content

Commit 5a2433c

Browse files
committed
abstract PrepareData interface for example
1 parent f462c55 commit 5a2433c

15 files changed

+70
-27
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ Read the docs & book [The Definitive Guide to Tensorflow.NET](https://tensorflow
7272
* [Basic Operations](test/TensorFlowNET.Examples/BasicOperations.cs)
7373
* [Image Recognition](test/TensorFlowNET.Examples/ImageRecognition.cs)
7474
* [Linear Regression](test/TensorFlowNET.Examples/LinearRegression.cs)
75+
* [Logistic Regression](test/TensorFlowNET.Examples/LogisticRegression.cs)
7576
* [Text Classification](test/TensorFlowNET.Examples/TextClassificationWithMovieReviews.cs)
7677
* [CNN Text Classification](test/TensorFlowNET.Examples/CnnTextClassification.cs)
7778
* [Naive Bayes Classification](test/TensorFlowNET.Examples/NaiveBayesClassifier.cs)

src/TensorFlowNET.Core/TensorFlowNET.Core.csproj

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717
<Description>Google's TensorFlow binding in .NET Standard.
1818
Docs: https://tensorflownet.readthedocs.io</Description>
1919
<AssemblyVersion>0.5.0.0</AssemblyVersion>
20-
<PackageReleaseNotes>Add a lot of APIs to build neural networks model</PackageReleaseNotes>
20+
<PackageReleaseNotes>Add Logistic Regression to do MNIST.
21+
Add a lot of APIs to build neural networks model</PackageReleaseNotes>
2122
<LangVersion>7.2</LangVersion>
2223
<FileVersion>0.5.0.0</FileVersion>
2324
</PropertyGroup>

test/TensorFlowNET.Examples/BasicEagerApi.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ namespace TensorFlowNET.Examples
1212
public class BasicEagerApi : IExample
1313
{
1414
private Tensor a, b, c, d;
15+
1516
public void Run()
1617
{
1718
// Set Eager API
@@ -34,5 +35,9 @@ public void Run()
3435

3536
// Full compatibility with Numpy
3637
}
38+
39+
public void PrepareData()
40+
{
41+
}
3742
}
3843
}

test/TensorFlowNET.Examples/BasicOperations.cs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,5 +96,9 @@ public void Run()
9696
}
9797
}
9898
}
99+
100+
public void PrepareData()
101+
{
102+
}
99103
}
100104
}

test/TensorFlowNET.Examples/HelloWorld.cs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,5 +33,9 @@ of the Constant op. */
3333
}
3434
}
3535
}
36+
37+
public void PrepareData()
38+
{
39+
}
3640
}
3741
}

test/TensorFlowNET.Examples/IExample.cs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,5 +11,6 @@ namespace TensorFlowNET.Examples
1111
public interface IExample
1212
{
1313
void Run();
14+
void PrepareData();
1415
}
1516
}

test/TensorFlowNET.Examples/ImageRecognition.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ private NDArray ReadTensorFromImageFile(string file_name,
7878
});
7979
}
8080

81-
private void PrepareData()
81+
public void PrepareData()
8282
{
8383
Directory.CreateDirectory(dir);
8484

test/TensorFlowNET.Examples/InceptionArchGoogLeNet.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ private NDArray ReadTensorFromImageFile(string file_name,
8383
});
8484
}
8585

86-
private void PrepareData()
86+
public void PrepareData()
8787
{
8888
Directory.CreateDirectory(dir);
8989

test/TensorFlowNET.Examples/LinearRegression.cs

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,13 @@ public class LinearRegression : Python, IExample
1919
int training_epochs = 1000;
2020
int display_step = 50;
2121

22+
NDArray train_X, train_Y;
23+
int n_samples;
24+
2225
public void Run()
2326
{
2427
// Training Data
25-
var train_X = np.array(3.3f, 4.4f, 5.5f, 6.71f, 6.93f, 4.168f, 9.779f, 6.182f, 7.59f, 2.167f,
26-
7.042f, 10.791f, 5.313f, 7.997f, 5.654f, 9.27f, 3.1f);
27-
var train_Y = np.array(1.7f, 2.76f, 2.09f, 3.19f, 1.694f, 1.573f, 3.366f, 2.596f, 2.53f, 1.221f,
28-
2.827f, 3.465f, 1.65f, 2.904f, 2.42f, 2.94f, 1.3f);
29-
var n_samples = train_X.shape[0];
28+
PrepareData();
3029

3130
// tf Graph Input
3231
var X = tf.placeholder(tf.float32);
@@ -95,5 +94,14 @@ public void Run()
9594
Console.WriteLine($"Absolute mean square loss difference: {Math.Abs((float)training_cost - (float)testing_cost)}");
9695
});
9796
}
97+
98+
public void PrepareData()
99+
{
100+
train_X = np.array(3.3f, 4.4f, 5.5f, 6.71f, 6.93f, 4.168f, 9.779f, 6.182f, 7.59f, 2.167f,
101+
7.042f, 10.791f, 5.313f, 7.997f, 5.654f, 9.27f, 3.1f);
102+
train_Y = np.array(1.7f, 2.76f, 2.09f, 3.19f, 1.694f, 1.573f, 3.366f, 2.596f, 2.53f, 1.221f,
103+
2.827f, 3.465f, 1.65f, 2.904f, 2.42f, 2.94f, 1.3f);
104+
n_samples = train_X.shape[0];
105+
}
98106
}
99107
}

test/TensorFlowNET.Examples/LogisticRegression.cs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,11 @@ public class LogisticRegression : Python, IExample
2121
private int batch_size = 100;
2222
private int display_step = 1;
2323

24+
Datasets mnist;
25+
2426
public void Run()
2527
{
26-
var mnist = PrepareData();
28+
PrepareData();
2729

2830
// tf Graph Input
2931
var x = tf.placeholder(tf.float32, new TensorShape(-1, 784)); // mnist data image of shape 28*28=784
@@ -86,10 +88,9 @@ public void Run()
8688
});
8789
}
8890

89-
private Datasets PrepareData()
91+
public void PrepareData()
9092
{
91-
var mnist = MnistDataSet.read_data_sets("logistic_regression", one_hot: true);
92-
return mnist;
93+
mnist = MnistDataSet.read_data_sets("logistic_regression", one_hot: true);
9394
}
9495
}
9596
}

test/TensorFlowNET.Examples/MetaGraph.cs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,5 +27,9 @@ private void ImportMetaGraph(string dir)
2727
logits: logits);
2828
});
2929
}
30+
31+
public void PrepareData()
32+
{
33+
}
3034
}
3135
}

test/TensorFlowNET.Examples/NaiveBayesClassifier.cs

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,9 @@ public class NaiveBayesClassifier : Python, IExample
1515
public Normal dist { get; set; }
1616
public void Run()
1717
{
18-
np.array<float>(1.0f, 1.0f);
19-
var X = np.array<float>(new float[][] { new float[] { 1.0f, 1.0f }, new float[] { 2.0f, 2.0f }, new float[] { -1.0f, -1.0f }, new float[] { -2.0f, -2.0f }, new float[] { 1.0f, -1.0f }, new float[] { 2.0f, -2.0f }, });
20-
var y = np.array<int>(0,0,1,1,2,2);
18+
np.array(1.0f, 1.0f);
19+
var X = np.array(new float[][] { new float[] { 1.0f, 1.0f }, new float[] { 2.0f, 2.0f }, new float[] { -1.0f, -1.0f }, new float[] { -2.0f, -2.0f }, new float[] { 1.0f, -1.0f }, new float[] { 2.0f, -2.0f }, });
20+
var y = np.array(0,0,1,1,2,2);
2121
fit(X, y);
2222
// Create a regular grid and classify each point
2323
}
@@ -102,5 +102,10 @@ public Tensor predict (NDArray X)
102102
// exp to get the actual probabilities
103103
return tf.exp(log_prob);
104104
}
105+
106+
public void PrepareData()
107+
{
108+
109+
}
105110
}
106111
}

test/TensorFlowNET.Examples/NamedEntityRecognition.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,5 +14,10 @@ public void Run()
1414
{
1515
throw new NotImplementedException();
1616
}
17+
18+
public void PrepareData()
19+
{
20+
throw new NotImplementedException();
21+
}
1722
}
1823
}

test/TensorFlowNET.Examples/TextClassification/TextClassificationTrain.cs

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ public class TextClassificationTrain : Python, IExample
2323

2424
public void Run()
2525
{
26-
download_dbpedia();
26+
PrepareData();
2727
Console.WriteLine("Building dataset...");
2828
var (x, y, alphabet_size) = DataHelpers.build_char_dataset("train", "vdcnn", CHAR_MAX_LEN);
2929

@@ -32,17 +32,9 @@ public void Run()
3232
with(tf.Session(), sess =>
3333
{
3434
new VdCnn(alphabet_size, CHAR_MAX_LEN, NUM_CLASS);
35-
3635
});
3736
}
3837

39-
public void download_dbpedia()
40-
{
41-
string url = "https://github.com/le-scientifique/torchDatasets/raw/master/dbpedia_csv.tar.gz";
42-
Web.Download(url, dataDir, dataFileName);
43-
Compress.ExtractTGZ(Path.Join(dataDir, dataFileName), dataDir);
44-
}
45-
4638
private (int[][], int[][], int[], int[]) train_test_split(int[][] x, int[] y, float test_size = 0.3f)
4739
{
4840
int len = x.Length;
@@ -75,5 +67,12 @@ public void download_dbpedia()
7567

7668
return (train_x.ToArray(), valid_x.ToArray(), train_y.ToArray(), valid_y.ToArray());
7769
}
70+
71+
public void PrepareData()
72+
{
73+
string url = "https://github.com/le-scientifique/torchDatasets/raw/master/dbpedia_csv.tar.gz";
74+
Web.Download(url, dataDir, dataFileName);
75+
Compress.ExtractTGZ(Path.Join(dataDir, dataFileName), dataDir);
76+
}
7877
}
7978
}

test/TensorFlowNET.Examples/TextClassification/TextClassificationWithMovieReviews.cs

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,11 @@ public class TextClassificationWithMovieReviews : Python, IExample
1313
{
1414
string dir = "text_classification_with_movie_reviews";
1515
string dataFile = "imdb.zip";
16+
NDArray train_data, train_labels, test_data, test_labels;
1617

1718
public void Run()
1819
{
19-
var((train_data, train_labels), (test_data, test_labels)) = PrepareData();
20+
PrepareData();
2021

2122
Console.WriteLine($"Training entries: {train_data.size}, labels: {train_labels.size}");
2223

@@ -40,7 +41,7 @@ public void Run()
4041
model.add(keras.layers.Embedding(vocab_size, 16));
4142
}
4243

43-
private ((NDArray, NDArray), (NDArray, NDArray)) PrepareData()
44+
public void PrepareData()
4445
{
4546
Directory.CreateDirectory(dir);
4647

@@ -71,7 +72,11 @@ public void Run()
7172
var y_train = labels_train;
7273
var y_test = labels_test;
7374

74-
return ((x_train, y_train), (x_test, y_test));
75+
x_train = train_data;
76+
train_labels = y_train;
77+
78+
test_data = x_test;
79+
test_labels = y_test;
7580
}
7681

7782
private NDArray ReadData(string file)

0 commit comments

Comments
 (0)