Skip to content

Commit 62982a2

Browse files
committed
bench: add gemm benchmark
1 parent dd069b1 commit 62982a2

File tree

1 file changed

+216
-0
lines changed

1 file changed

+216
-0
lines changed
Lines changed: 216 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,216 @@
1+
/**
2+
* @license Apache-2.0
3+
*
4+
* Copyright (c) 2024 The Stdlib Authors.
5+
*
6+
* Licensed under the Apache License, Version 2.0 (the "License");
7+
* you may not use this file except in compliance with the License.
8+
* You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing, software
13+
* distributed under the License is distributed on an "AS IS" BASIS,
14+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
* See the License for the specific language governing permissions and
16+
* limitations under the License.
17+
*/
18+
19+
/* eslint-disable max-len */
20+
21+
'use strict';
22+
23+
// MODULES //
24+
25+
var resolve = require( 'path' ).resolve;
26+
var bench = require( '@stdlib/bench' );
27+
var discreteUniform = require( '@stdlib/random/array/discrete-uniform' );
28+
var randi = require( '@stdlib/random/base/discrete-uniform' ).factory;
29+
var filled2dBy = require( '@stdlib/array/base/filled2d-by' );
30+
var pow = require( '@stdlib/math/base/special/pow' );
31+
var floor = require( '@stdlib/math/base/special/floor' );
32+
var numel = require( '@stdlib/ndarray/base/numel' );
33+
var shape2strides = require( '@stdlib/ndarray/base/shape2strides' );
34+
var isnanf = require( '@stdlib/math/base/assert/is-nanf' );
35+
var format = require( '@stdlib/string/format' );
36+
var tryRequire = require( '@stdlib/utils/try-require' );
37+
38+
// var sgemm = require( '@stdlib/blas/base/sgemm' ).ndarray;
39+
var sgemm = require( '@stdlib/utils/noop' ); // FIXME: remove once `sgemm` merged
40+
41+
var pkg = require( './../package.json' ).name;
42+
43+
44+
// VARIABLES //
45+
46+
var mathjs = tryRequire( resolve( __dirname, '..', 'node_modules', 'mathjs' ) );
47+
var opts = {
48+
'skip': ( mathjs instanceof Error )
49+
};
50+
var OPTS = {
51+
'dtype': 'float32'
52+
};
53+
54+
55+
// FUNCTIONS //
56+
57+
/**
58+
* Creates a benchmark function.
59+
*
60+
* @private
61+
* @param {PositiveIntegerArray} shapeA - shape of the first array
62+
* @param {string} orderA - memory layout of the first array
63+
* @param {PositiveIntegerArray} shapeB - shape of the second array
64+
* @param {string} orderB - memory layout of the second array
65+
* @param {PositiveIntegerArray} shapeC - shape of the third array
66+
* @param {string} orderC - memory layout of the third array
67+
* @returns {Function} benchmark function
68+
*/
69+
function createBenchmark1( shapeA, orderA, shapeB, orderB, shapeC, orderC ) {
70+
var sa;
71+
var sb;
72+
var sc;
73+
var A;
74+
var B;
75+
var C;
76+
77+
A = discreteUniform( numel( shapeA ), 0, 10, OPTS );
78+
B = discreteUniform( numel( shapeB ), 0, 10, OPTS );
79+
C = discreteUniform( numel( shapeC ), 0, 10, OPTS );
80+
81+
sa = shape2strides( shapeA, orderA );
82+
sb = shape2strides( shapeB, orderB );
83+
sc = shape2strides( shapeC, orderC );
84+
85+
return benchmark;
86+
87+
/**
88+
* Benchmark function.
89+
*
90+
* @private
91+
* @param {Benchmark} b - benchmark instance
92+
*/
93+
function benchmark( b ) {
94+
var i;
95+
96+
b.tic();
97+
for ( i = 0; i < b.iterations; i++ ) {
98+
sgemm( 'no-transpose', 'no-transpose', shapeA[0], shapeC[1], shapeB[0], 0.5, A, sa[0], sa[1], 0, B, sb[0], sb[1], 0, 2.0, C, sc[0], sc[1], 0 );
99+
if ( isnanf( C[ i%C.length ] ) ) {
100+
b.fail( 'should not return NaN' );
101+
}
102+
}
103+
b.toc();
104+
if ( isnanf( C[ i%C.length ] ) ) {
105+
b.fail( 'should not return NaN' );
106+
}
107+
b.pass( 'benchmark finished' );
108+
b.end();
109+
}
110+
}
111+
112+
/**
113+
* Creates a benchmark function.
114+
*
115+
* @private
116+
* @param {PositiveIntegerArray} shapeA - shape of the first array
117+
* @param {PositiveIntegerArray} shapeB - shape of the second array
118+
* @param {PositiveIntegerArray} shapeC - shape of the third array
119+
* @returns {Function} benchmark function
120+
*/
121+
function createBenchmark2( shapeA, shapeB, shapeC ) {
122+
var abuf;
123+
var bbuf;
124+
var cbuf;
125+
var A;
126+
var B;
127+
var C;
128+
129+
abuf = filled2dBy( shapeA, randi( 0, 10 ) );
130+
bbuf = filled2dBy( shapeB, randi( 0, 10 ) );
131+
cbuf = filled2dBy( shapeC, randi( 0, 10 ) );
132+
133+
A = mathjs.matrix( abuf );
134+
B = mathjs.matrix( bbuf );
135+
C = mathjs.matrix( cbuf );
136+
137+
return benchmark;
138+
139+
/**
140+
* Benchmark function.
141+
*
142+
* @private
143+
* @param {Benchmark} b - benchmark instance
144+
*/
145+
function benchmark( b ) {
146+
var out;
147+
var i;
148+
149+
b.tic();
150+
for ( i = 0; i < b.iterations; i++ ) {
151+
out = mathjs.add( mathjs.multiply( A, B ), C );
152+
if ( isnanf( out.get( [ i%shapeC[0], i%shapeC[1] ] ) ) ) {
153+
b.fail( 'should not return NaN' );
154+
}
155+
}
156+
b.toc();
157+
if ( isnanf( out.get( [ i%shapeC[0], i%shapeC[1] ] ) ) ) {
158+
b.fail( 'should not return NaN' );
159+
}
160+
b.pass( 'benchmark finished' );
161+
b.end();
162+
}
163+
}
164+
165+
166+
// MAIN //
167+
168+
/**
169+
* Main execution sequence.
170+
*
171+
* @private
172+
*/
173+
function main() {
174+
var shapes;
175+
var orders;
176+
var min;
177+
var max;
178+
var N;
179+
var f;
180+
var i;
181+
182+
min = 1; // 10^min
183+
max = 5; // 10^max
184+
185+
for ( i = min; i <= max; i++ ) {
186+
N = floor( pow( pow( 10, i ), 1.0/2.0 ) );
187+
shapes = [
188+
[ N, N ],
189+
[ N, N ],
190+
[ N, N ]
191+
];
192+
orders = [
193+
'row-major',
194+
'row-major',
195+
'row-major'
196+
];
197+
f = createBenchmark1( shapes[0], orders[0], shapes[1], orders[1], shapes[2], orders[2] );
198+
bench( format( '%s::stdlib:blas/base/sgemm:dtype=%s,orders=(%s),size=%d,shapes={(%s),(%s),(%s)}', OPTS.dtype, pkg, orders.join( ',' ), numel( shapes[2] ), shapes[0].join( ',' ), shapes[1].join( ',' ), shapes[2].join( ',' ) ), f );
199+
200+
f = createBenchmark2( shapes[0], shapes[1], shapes[2] );
201+
bench( format( '%s::mathjs:multiply:dtype=%s,size=%d,shapes={(%s),(%s),(%s)}', OPTS.dtype, pkg, numel( shapes[2] ), shapes[0].join( ',' ), shapes[1].join( ',' ), shapes[2].join( ',' ) ), opts, f );
202+
203+
orders = [
204+
'row-major',
205+
'column-major',
206+
'row-major'
207+
];
208+
f = createBenchmark1( shapes[0], orders[0], shapes[1], orders[1], shapes[2], orders[2] );
209+
bench( format( '%s::stdlib:blas/base/sgemm:dtype=%s,orders=(%s),size=%d,shapes={(%s),(%s),(%s)}', OPTS.dtype, pkg, orders.join( ',' ), numel( shapes[2] ), shapes[0].join( ',' ), shapes[1].join( ',' ), shapes[2].join( ',' ) ), f );
210+
211+
f = createBenchmark2( shapes[0], shapes[1], shapes[2] );
212+
bench( format( '%s::mathjs:multiply:dtype=%s,size=%d,shapes={(%s),(%s),(%s)}', OPTS.dtype, pkg, numel( shapes[2] ), shapes[0].join( ',' ), shapes[1].join( ',' ), shapes[2].join( ',' ) ), opts, f );
213+
}
214+
}
215+
216+
main();

0 commit comments

Comments
 (0)