Skip to content

Commit 6a150d1

Browse files
committed
bench: explicitly set backends to override tf-node default
1 parent 6865db1 commit 6a150d1

File tree

1 file changed

+35
-17
lines changed

1 file changed

+35
-17
lines changed

docs/migration-guides/tfjs/benchmark/benchmark.gemm.square_matrices.js

Lines changed: 35 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -121,18 +121,11 @@ function createBenchmark2( shapeA, shapeB, shapeC ) {
121121
var abuf;
122122
var bbuf;
123123
var cbuf;
124-
var A;
125-
var B;
126-
var C;
127124

128125
abuf = discreteUniform( numel( shapeA ), 0, 10, OPTS );
129126
bbuf = discreteUniform( numel( shapeB ), 0, 10, OPTS );
130127
cbuf = discreteUniform( numel( shapeC ), 0, 10, OPTS );
131128

132-
A = tf.tensor( abuf, shapeA, OPTS.dtype );
133-
B = tf.tensor( bbuf, shapeB, OPTS.dtype );
134-
C = tf.tensor( cbuf, shapeC, OPTS.dtype );
135-
136129
return benchmark;
137130

138131
/**
@@ -143,19 +136,35 @@ function createBenchmark2( shapeA, shapeB, shapeC ) {
143136
*/
144137
function benchmark( b ) {
145138
var out;
139+
var A;
140+
var B;
141+
var C;
142+
var D;
146143
var i;
147144

145+
tf.setBackend( 'cpu' );
146+
147+
A = tf.tensor( abuf, shapeA, OPTS.dtype );
148+
B = tf.tensor( bbuf, shapeB, OPTS.dtype );
149+
C = tf.tensor( cbuf, shapeC, OPTS.dtype );
150+
148151
b.tic();
149152
for ( i = 0; i < b.iterations; i++ ) {
150-
out = tf.add( tf.matMul( A, B ), C );
153+
D = tf.matMul( A, B );
154+
out = tf.add( D, C );
151155
if ( typeof out !== 'object' ) {
152156
b.fail( 'should return an object' );
153157
}
158+
tf.dispose( D );
159+
tf.dispose( out );
154160
}
155161
b.toc();
156162
if ( typeof out !== 'object' ) {
157163
b.fail( 'should return an object' );
158164
}
165+
tf.dispose( A );
166+
tf.dispose( B );
167+
tf.dispose( C );
159168
b.pass( 'benchmark finished' );
160169
b.end();
161170
}
@@ -174,18 +183,11 @@ function createBenchmark3( shapeA, shapeB, shapeC ) {
174183
var abuf;
175184
var bbuf;
176185
var cbuf;
177-
var A;
178-
var B;
179-
var C;
180186

181187
abuf = discreteUniform( numel( shapeA ), 0, 10, OPTS );
182188
bbuf = discreteUniform( numel( shapeB ), 0, 10, OPTS );
183189
cbuf = discreteUniform( numel( shapeC ), 0, 10, OPTS );
184190

185-
A = tfnode.tensor( abuf, shapeA, OPTS.dtype );
186-
B = tfnode.tensor( bbuf, shapeB, OPTS.dtype );
187-
C = tfnode.tensor( cbuf, shapeC, OPTS.dtype );
188-
189191
return benchmark;
190192

191193
/**
@@ -196,19 +198,35 @@ function createBenchmark3( shapeA, shapeB, shapeC ) {
196198
*/
197199
function benchmark( b ) {
198200
var out;
201+
var A;
202+
var B;
203+
var C;
204+
var D;
199205
var i;
200206

207+
tfnode.setBackend( 'tensorflow' );
208+
209+
A = tfnode.tensor( abuf, shapeA, OPTS.dtype );
210+
B = tfnode.tensor( bbuf, shapeB, OPTS.dtype );
211+
C = tfnode.tensor( cbuf, shapeC, OPTS.dtype );
212+
201213
b.tic();
202214
for ( i = 0; i < b.iterations; i++ ) {
203-
out = tfnode.add( tfnode.matMul( A, B ), C );
215+
D = tfnode.matMul( A, B );
216+
out = tfnode.add( D, C );
204217
if ( typeof out !== 'object' ) {
205218
b.fail( 'should return an object' );
206219
}
220+
tfnode.dispose( D );
221+
tfnode.dispose( out );
207222
}
208223
b.toc();
209224
if ( typeof out !== 'object' ) {
210225
b.fail( 'should return an object' );
211226
}
227+
tfnode.dispose( A );
228+
tfnode.dispose( B );
229+
tfnode.dispose( C );
212230
b.pass( 'benchmark finished' );
213231
b.end();
214232
}
@@ -232,7 +250,7 @@ function main() {
232250
var i;
233251

234252
min = 1; // 10^min
235-
max = 3; // 10^max
253+
max = 6; // 10^max
236254

237255
for ( i = min; i <= max; i++ ) {
238256
N = floor( pow( pow( 10, i ), 1.0/2.0 ) );

0 commit comments

Comments
 (0)