Skip to content

Commit 459730a

Browse files
committed
upgrade to deeplearn 15 tensorflow/tfjs-core#461
1 parent 7a2bfe5 commit 459730a

File tree

3 files changed

+78
-64
lines changed

3 files changed

+78
-64
lines changed

package-lock.json

Lines changed: 52 additions & 44 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
"version": "0.1.0",
44
"private": true,
55
"dependencies": {
6-
"deeplearn": "0.3.12",
6+
"deeplearn": "0.3.15",
77
"react": "^16.1.1",
88
"react-dom": "^16.1.1",
99
"react-scripts": "1.0.17"

src/neuralNetwork.js

Lines changed: 25 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,15 @@ import {
44
Graph,
55
Session,
66
SGDOptimizer,
7-
NDArrayMathGPU,
7+
ENV,
8+
NDArrayMath,
89
CostReduction,
910
} from 'deeplearn';
1011

11-
// Encapsulates math operations on the CPU and GPU.
12-
const math = new NDArrayMathGPU();
13-
1412
class ColorAccessibilityModel {
13+
// Encapsulates math operations on the CPU and GPU.
14+
math = ENV.math;
15+
1516
// Runs training.
1617
session;
1718

@@ -47,27 +48,32 @@ class ColorAccessibilityModel {
4748
this.predictionTensor = this.createFullyConnectedLayer(graph, fullyConnectedLayer, 3, 2);
4849
this.costTensor = graph.meanSquaredCost(this.targetTensor, this.predictionTensor);
4950

50-
this.session = new Session(graph, math);
51+
this.session = new Session(graph, this.math);
5152

5253
this.prepareTrainingSet(trainingSet);
5354
}
5455

5556
prepareTrainingSet(trainingSet) {
56-
math.scope(() => {
57-
const { rawInputs, rawTargets } = trainingSet;
57+
const oldMath = ENV.math;
58+
const safeMode = false;
59+
const math = new NDArrayMath('cpu', safeMode);
60+
ENV.setMath(math);
5861

59-
const inputArray = rawInputs.map(v => Array1D.new(this.normalizeColor(v)));
60-
const targetArray = rawTargets.map(v => Array1D.new(v));
62+
const { rawInputs, rawTargets } = trainingSet;
6163

62-
const shuffledInputProviderBuilder = new InCPUMemoryShuffledInputProviderBuilder([ inputArray, targetArray ]);
63-
const [ inputProvider, targetProvider ] = shuffledInputProviderBuilder.getInputProviders();
64+
const inputArray = rawInputs.map(v => Array1D.new(this.normalizeColor(v)));
65+
const targetArray = rawTargets.map(v => Array1D.new(v));
6466

65-
// Maps tensors to InputProviders.
66-
this.feedEntries = [
67-
{ tensor: this.inputTensor, data: inputProvider },
68-
{ tensor: this.targetTensor, data: targetProvider },
69-
];
70-
});
67+
const shuffledInputProviderBuilder = new InCPUMemoryShuffledInputProviderBuilder([ inputArray, targetArray ]);
68+
const [ inputProvider, targetProvider ] = shuffledInputProviderBuilder.getInputProviders();
69+
70+
// Maps tensors to InputProviders.
71+
this.feedEntries = [
72+
{ tensor: this.inputTensor, data: inputProvider },
73+
{ tensor: this.targetTensor, data: targetProvider },
74+
];
75+
76+
ENV.setMath(oldMath);
7177
}
7278

7379
train(step, computeCost) {
@@ -77,7 +83,7 @@ class ColorAccessibilityModel {
7783

7884
// Train one batch.
7985
let costValue;
80-
math.scope(() => {
86+
this.math.scope(() => {
8187
const cost = this.session.train(
8288
this.costTensor,
8389
this.feedEntries,
@@ -98,7 +104,7 @@ class ColorAccessibilityModel {
98104
predict(rgb) {
99105
let classifier = [];
100106

101-
math.scope(() => {
107+
this.math.scope(() => {
102108
const mapping = [{
103109
tensor: this.inputTensor,
104110
data: Array1D.new(this.normalizeColor(rgb)),

0 commit comments

Comments
 (0)