@@ -133,21 +133,22 @@ void PrintSpeed(const Extents2D& A_extents, const Extents2D& B_extents,
133
133
134
134
// Generates inputs and prints observed throughput of MatMul.
135
135
// M = A rows, K = A cols, N = C cols.
136
- template <typename MatTA , typename MatTB = MatTA >
136
+ template <typename TA , typename TB = TA, typename TC = float >
137
137
void BenchMatMul (size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
138
138
hwy::ThreadPool& pool = env.parallel .Pools ().Pool (0 );
139
139
if (env.print_config || env.print_measurement ) {
140
140
fprintf (stderr, " \n " );
141
141
}
142
- fprintf (stderr, " BenchMatMul %zu, %zu, %zu, add=%d, TA=%s, TB=%s\n " , M, K, N,
143
- add, TypeName<MatTA>(), TypeName<MatTB>());
142
+ fprintf (stderr,
143
+ " BenchMatMul %zu, %zu, %zu, add=%d, TA=%s, TB=%s, TC=%s\n " , //
144
+ M, K, N, add, TypeName<TA>(), TypeName<TB>(), TypeName<TC>());
144
145
145
146
const Extents2D A_extents (M, K);
146
147
const Extents2D B_extents (N, K); // already transposed
147
148
const Extents2D C_extents (M, N);
148
149
149
- RowVectorBatch<float > c_slow_batch = AllocateAlignedRows<float >(C_extents);
150
- RowVectorBatch<float > c_batch = AllocateAlignedRows<float >(C_extents);
150
+ RowVectorBatch<TC > c_slow_batch = AllocateAlignedRows<TC >(C_extents);
151
+ RowVectorBatch<TC > c_batch = AllocateAlignedRows<TC >(C_extents);
151
152
152
153
std::unique_ptr<MatStorageT<float >> add_storage;
153
154
if (add) {
@@ -156,14 +157,14 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
156
157
add_storage->set_scale (1 .0f );
157
158
}
158
159
159
- MatStoragePtr<MatTA > a = GenerateMat<MatTA >(A_extents, pool);
160
- MatStoragePtr<MatTB > b_trans = GenerateTransposedMat<MatTB >(B_extents, pool);
160
+ MatStoragePtr<TA > a = GenerateMat<TA >(A_extents, pool);
161
+ MatStoragePtr<TB > b_trans = GenerateTransposedMat<TB >(B_extents, pool);
161
162
HWY_ASSERT (a && b_trans);
162
163
const auto A = ConstMatFromWeights (*a);
163
164
const auto B = ConstMatFromWeights (*b_trans);
164
165
165
166
const float * add_row = add ? add_storage->data_scale1 () : nullptr ;
166
- const RowPtrF C = RowPtrFromBatch (c_batch);
167
+ const RowPtr<TC> C = RowPtrFromBatch (c_batch);
167
168
168
169
// Fewer reps for large batch sizes, which take longer.
169
170
const size_t num_samples = M < 32 ? 20 : 12 ;
@@ -173,7 +174,7 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
173
174
// Ensure usage conditions are set before autotuning. Both binding and
174
175
// spinning may materially affect the choice of config. No harm in calling
175
176
// BindB/C if there is a single package: they will be a no-op.
176
- BindB (B_extents.rows , B, env.parallel );
177
+ BindB (B_extents.rows , sizeof (TC), B, env.parallel );
177
178
BindC (A_extents.rows , C, env.parallel );
178
179
179
180
Tristate use_spinning = Tristate::kDefault ;
@@ -191,7 +192,7 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
191
192
per_key = MatMul (A, B, add_row, env, C);
192
193
const double t1 = hwy::platform::Now ();
193
194
double elapsed = t1 - t0;
194
- keep += C.Row (0 )[hwy::Unpredictable1 ()];
195
+ keep += hwy::ConvertScalarTo< double >( C.Row (0 )[hwy::Unpredictable1 ()]) ;
195
196
196
197
// Only record times after autotuning finished.
197
198
if (per_key->autotune .Best ()) times.push_back (elapsed);
@@ -229,8 +230,8 @@ void BenchAllMatMul() {
229
230
230
231
for (size_t batch_size : {1 , 4 , 128 , 512 }) {
231
232
constexpr bool kAdd = false ;
232
- BenchMatMul<BF16, SFP>(batch_size, 24576 , 3072 , kAdd , env);
233
- BenchMatMul<BF16, SFP>(batch_size, 3072 , 24576 , kAdd , env);
233
+ BenchMatMul<BF16, SFP, BF16 >(batch_size, 24576 , 3072 , kAdd , env);
234
+ BenchMatMul<BF16, SFP, BF16 >(batch_size, 3072 , 24576 , kAdd , env);
234
235
}
235
236
236
237
PROFILER_PRINT_RESULTS ();
0 commit comments