Skip to content

Commit 0921b0e

Browse files
committed
upgrade to deeplearn 15 tensorflow/tfjs-core#461
1 parent 6bf0d7a commit 0921b0e

File tree

3 files changed

+75
-61
lines changed

3 files changed

+75
-61
lines changed

package-lock.json

Lines changed: 52 additions & 44 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
"version": "0.1.0",
44
"private": true,
55
"dependencies": {
6-
"deeplearn": "0.3.12",
6+
"deeplearn": "0.3.15",
77
"mnist": "^1.1.0",
88
"react": "^16.1.1",
99
"react-dom": "^16.1.1",

src/neuralNetwork.js

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,14 @@ import {
44
Graph,
55
Session,
66
SGDOptimizer,
7-
NDArrayMathGPU,
7+
ENV,
8+
NDArrayMath,
89
CostReduction,
910
} from 'deeplearn';
1011

11-
const math = new NDArrayMathGPU();
12-
1312
class MnistModel {
13+
math = ENV.math;
14+
1415
session;
1516

1617
initialLearningRate = 0.06;
@@ -42,32 +43,37 @@ class MnistModel {
4243
this.predictionTensor = this.createFullyConnectedLayer(graph, fullyConnectedLayer, 3, 10);
4344
this.costTensor = graph.meanSquaredCost(this.targetTensor, this.predictionTensor);
4445

45-
this.session = new Session(graph, math);
46+
this.session = new Session(graph, this.math);
4647

4748
this.prepareTrainingSet(trainingSet);
4849
}
4950

5051
prepareTrainingSet(trainingSet) {
51-
math.scope(() => {
52-
const inputArray = trainingSet.map(v => Array1D.new(v.input));
53-
const targetArray = trainingSet.map(v => Array1D.new(v.output));
52+
const oldMath = ENV.math;
53+
const safeMode = false;
54+
const math = new NDArrayMath('cpu', safeMode);
55+
ENV.setMath(math);
5456

55-
const shuffledInputProviderBuilder = new InCPUMemoryShuffledInputProviderBuilder([ inputArray, targetArray ]);
56-
const [ inputProvider, targetProvider ] = shuffledInputProviderBuilder.getInputProviders();
57+
const inputArray = trainingSet.map(v => Array1D.new(v.input));
58+
const targetArray = trainingSet.map(v => Array1D.new(v.output));
5759

58-
this.feedEntries = [
59-
{ tensor: this.inputTensor, data: inputProvider },
60-
{ tensor: this.targetTensor, data: targetProvider },
61-
];
62-
});
60+
const shuffledInputProviderBuilder = new InCPUMemoryShuffledInputProviderBuilder([ inputArray, targetArray ]);
61+
const [ inputProvider, targetProvider ] = shuffledInputProviderBuilder.getInputProviders();
62+
63+
this.feedEntries = [
64+
{ tensor: this.inputTensor, data: inputProvider },
65+
{ tensor: this.targetTensor, data: targetProvider },
66+
];
67+
68+
ENV.setMath(oldMath);
6369
}
6470

6571
train(step, computeCost) {
6672
let learningRate = this.initialLearningRate * Math.pow(0.90, Math.floor(step / 50));
6773
this.optimizer.setLearningRate(learningRate);
6874

6975
let costValue;
70-
math.scope(() => {
76+
this.math.scope(() => {
7177
const cost = this.session.train(
7278
this.costTensor,
7379
this.feedEntries,
@@ -87,7 +93,7 @@ class MnistModel {
8793
predict(pixels) {
8894
let classifier = [];
8995

90-
math.scope(() => {
96+
this.math.scope(() => {
9197
const mapping = [{
9298
tensor: this.inputTensor,
9399
data: Array1D.new(pixels),

0 commit comments

Comments
 (0)