@@ -4,14 +4,15 @@ import {
4
4
Graph ,
5
5
Session ,
6
6
SGDOptimizer ,
7
- NDArrayMathGPU ,
7
+ ENV ,
8
+ NDArrayMath ,
8
9
CostReduction ,
9
10
} from 'deeplearn' ;
10
11
11
- // Encapsulates math operations on the CPU and GPU.
12
- const math = new NDArrayMathGPU ( ) ;
13
-
14
12
class ColorAccessibilityModel {
13
+ // Encapsulates math operations on the CPU and GPU.
14
+ math = ENV . math ;
15
+
15
16
// Runs training.
16
17
session ;
17
18
@@ -47,27 +48,32 @@ class ColorAccessibilityModel {
47
48
this . predictionTensor = this . createFullyConnectedLayer ( graph , fullyConnectedLayer , 3 , 2 ) ;
48
49
this . costTensor = graph . meanSquaredCost ( this . targetTensor , this . predictionTensor ) ;
49
50
50
- this . session = new Session ( graph , math ) ;
51
+ this . session = new Session ( graph , this . math ) ;
51
52
52
53
this . prepareTrainingSet ( trainingSet ) ;
53
54
}
54
55
55
56
prepareTrainingSet ( trainingSet ) {
56
- math . scope ( ( ) => {
57
- const { rawInputs, rawTargets } = trainingSet ;
57
+ const oldMath = ENV . math ;
58
+ const safeMode = false ;
59
+ const math = new NDArrayMath ( 'cpu' , safeMode ) ;
60
+ ENV . setMath ( math ) ;
58
61
59
- const inputArray = rawInputs . map ( v => Array1D . new ( this . normalizeColor ( v ) ) ) ;
60
- const targetArray = rawTargets . map ( v => Array1D . new ( v ) ) ;
62
+ const { rawInputs, rawTargets } = trainingSet ;
61
63
62
- const shuffledInputProviderBuilder = new InCPUMemoryShuffledInputProviderBuilder ( [ inputArray , targetArray ] ) ;
63
- const [ inputProvider , targetProvider ] = shuffledInputProviderBuilder . getInputProviders ( ) ;
64
+ const inputArray = rawInputs . map ( v => Array1D . new ( this . normalizeColor ( v ) ) ) ;
65
+ const targetArray = rawTargets . map ( v => Array1D . new ( v ) ) ;
64
66
65
- // Maps tensors to InputProviders.
66
- this . feedEntries = [
67
- { tensor : this . inputTensor , data : inputProvider } ,
68
- { tensor : this . targetTensor , data : targetProvider } ,
69
- ] ;
70
- } ) ;
67
+ const shuffledInputProviderBuilder = new InCPUMemoryShuffledInputProviderBuilder ( [ inputArray , targetArray ] ) ;
68
+ const [ inputProvider , targetProvider ] = shuffledInputProviderBuilder . getInputProviders ( ) ;
69
+
70
+ // Maps tensors to InputProviders.
71
+ this . feedEntries = [
72
+ { tensor : this . inputTensor , data : inputProvider } ,
73
+ { tensor : this . targetTensor , data : targetProvider } ,
74
+ ] ;
75
+
76
+ ENV . setMath ( oldMath ) ;
71
77
}
72
78
73
79
train ( step , computeCost ) {
@@ -77,7 +83,7 @@ class ColorAccessibilityModel {
77
83
78
84
// Train one batch.
79
85
let costValue ;
80
- math . scope ( ( ) => {
86
+ this . math . scope ( ( ) => {
81
87
const cost = this . session . train (
82
88
this . costTensor ,
83
89
this . feedEntries ,
@@ -98,7 +104,7 @@ class ColorAccessibilityModel {
98
104
predict ( rgb ) {
99
105
let classifier = [ ] ;
100
106
101
- math . scope ( ( ) => {
107
+ this . math . scope ( ( ) => {
102
108
const mapping = [ {
103
109
tensor : this . inputTensor ,
104
110
data : Array1D . new ( this . normalizeColor ( rgb ) ) ,
0 commit comments