
Commit 2bdf26d

jan-wassenberg authored and copybara-github committed
Support bf16 output of Matmul
Adds Stride to ConstMat, to support decompression of C output for tests.

matmul_test: add line numbers to output. Also ignore "N is not a multiple of nc" when N == nc.

PiperOrigin-RevId: 731096662
1 parent b3b4b9f commit 2bdf26d
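
For context, a minimal sketch of the stride-aware view this commit introduces. This is an illustration only, not the actual gemma.cpp definitions: the member names and the simplified Extents2D below are assumptions inferred from the call sites visible in the diffs that follow.

#include <cstddef>

// Sketch: a read-only row-major matrix view whose rows may be padded, so the
// distance between row starts ("stride", in elements) can exceed the logical
// column count. Element (r, c) then lives at ptr[r * stride + c].
struct Extents2D {
  size_t rows;
  size_t cols;
};

template <typename T>
struct ConstMat {
  const T* ptr;       // first element of row 0
  Extents2D extents;  // logical rows x cols
  size_t stride;      // elements between row starts; stride >= extents.cols
};

template <typename T>
ConstMat<T> MakeConstMat(const T* ptr, Extents2D extents, size_t stride) {
  return ConstMat<T>{ptr, extents, stride};
}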

6 files changed: +355, -264 lines


gemma/gemma-inl.h

Lines changed: 5 additions & 3 deletions
@@ -734,7 +734,8 @@ HWY_NOINLINE void FFWNoVit(Activations& activations, size_t num_interleaved,
 
   // Hidden layer -> output layer.
   auto activations_mat = MakeConstMat(
-      hidden_activations.Row(0), Extents2D(num_interleaved, ffh_hidden_dim));
+      hidden_activations.Row(0), Extents2D(num_interleaved, ffh_hidden_dim),
+      hidden_activations.Stride());
 
   MatMul(activations_mat, w_output, output_bias, *activations.env, ffw_out);
 }
@@ -773,8 +774,9 @@ HWY_NOINLINE void FFWVit(Activations& activations, size_t num_interleaved,
       multiplier.Row(0), ff_hidden_dim * num_interleaved);
 
   // Hidden layer -> output layer.
-  auto activations_mat = MakeConstMat(
-      hidden_activations.Row(0), Extents2D(num_interleaved, ff_hidden_dim));
+  auto activations_mat = MakeConstMat(hidden_activations.Row(0),
+                                      Extents2D(num_interleaved, ff_hidden_dim),
+                                      hidden_activations.Stride());
 
   MatMul(activations_mat, w_output, output_bias, *activations.env, ffw_out);
 }
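
Why both call sites now forward hidden_activations.Stride(): rows of the activation batch may be padded for alignment, so the row pitch can differ from the logical width (ffh_hidden_dim / ff_hidden_dim). A minimal sketch of the addressing this implies, under the assumption that the batch is row-major with possibly padded rows; ElementAt is a hypothetical helper, not a gemma.cpp function:

#include <cstddef>

// Illustration only: with padded rows, element (r, c) is at
// ptr[r * stride + c], where stride >= cols. Hard-coding stride == cols
// would read into the padding whenever rows are aligned/padded.
template <typename T>
const T& ElementAt(const T* ptr, size_t r, size_t c, size_t stride) {
  return ptr[r * stride + c];
}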

ops/bench_matmul.cc

Lines changed: 13 additions & 12 deletions
@@ -133,21 +133,22 @@ void PrintSpeed(const Extents2D& A_extents, const Extents2D& B_extents,
 
 // Generates inputs and prints observed throughput of MatMul.
 // M = A rows, K = A cols, N = C cols.
-template <typename MatTA, typename MatTB = MatTA>
+template <typename TA, typename TB = TA, typename TC = float>
 void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
   hwy::ThreadPool& pool = env.parallel.Pools().Pool(0);
   if (env.print_config || env.print_measurement) {
     fprintf(stderr, "\n");
   }
-  fprintf(stderr, "BenchMatMul %zu, %zu, %zu, add=%d, TA=%s, TB=%s\n", M, K, N,
-          add, TypeName<MatTA>(), TypeName<MatTB>());
+  fprintf(stderr,
+          "BenchMatMul %zu, %zu, %zu, add=%d, TA=%s, TB=%s, TC=%s\n",  //
+          M, K, N, add, TypeName<TA>(), TypeName<TB>(), TypeName<TC>());
 
   const Extents2D A_extents(M, K);
   const Extents2D B_extents(N, K);  // already transposed
   const Extents2D C_extents(M, N);
 
-  RowVectorBatch<float> c_slow_batch = AllocateAlignedRows<float>(C_extents);
-  RowVectorBatch<float> c_batch = AllocateAlignedRows<float>(C_extents);
+  RowVectorBatch<TC> c_slow_batch = AllocateAlignedRows<TC>(C_extents);
+  RowVectorBatch<TC> c_batch = AllocateAlignedRows<TC>(C_extents);
 
   std::unique_ptr<MatStorageT<float>> add_storage;
   if (add) {
@@ -156,14 +157,14 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
     add_storage->set_scale(1.0f);
   }
 
-  MatStoragePtr<MatTA> a = GenerateMat<MatTA>(A_extents, pool);
-  MatStoragePtr<MatTB> b_trans = GenerateTransposedMat<MatTB>(B_extents, pool);
+  MatStoragePtr<TA> a = GenerateMat<TA>(A_extents, pool);
+  MatStoragePtr<TB> b_trans = GenerateTransposedMat<TB>(B_extents, pool);
   HWY_ASSERT(a && b_trans);
   const auto A = ConstMatFromWeights(*a);
   const auto B = ConstMatFromWeights(*b_trans);
 
   const float* add_row = add ? add_storage->data_scale1() : nullptr;
-  const RowPtrF C = RowPtrFromBatch(c_batch);
+  const RowPtr<TC> C = RowPtrFromBatch(c_batch);
 
   // Fewer reps for large batch sizes, which take longer.
   const size_t num_samples = M < 32 ? 20 : 12;
@@ -173,7 +174,7 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
   // Ensure usage conditions are set before autotuning. Both binding and
   // spinning may materially affect the choice of config. No harm in calling
   // BindB/C if there is a single package: they will be a no-op.
-  BindB(B_extents.rows, B, env.parallel);
+  BindB(B_extents.rows, sizeof(TC), B, env.parallel);
   BindC(A_extents.rows, C, env.parallel);
 
   Tristate use_spinning = Tristate::kDefault;
@@ -191,7 +192,7 @@ void BenchMatMul(size_t M, size_t K, size_t N, bool add, MatMulEnv& env) {
     per_key = MatMul(A, B, add_row, env, C);
     const double t1 = hwy::platform::Now();
     double elapsed = t1 - t0;
-    keep += C.Row(0)[hwy::Unpredictable1()];
+    keep += hwy::ConvertScalarTo<double>(C.Row(0)[hwy::Unpredictable1()]);
 
     // Only record times after autotuning finished.
     if (per_key->autotune.Best()) times.push_back(elapsed);
@@ -229,8 +230,8 @@ void BenchAllMatMul() {
 
   for (size_t batch_size : {1, 4, 128, 512}) {
     constexpr bool kAdd = false;
-    BenchMatMul<BF16, SFP>(batch_size, 24576, 3072, kAdd, env);
-    BenchMatMul<BF16, SFP>(batch_size, 3072, 24576, kAdd, env);
+    BenchMatMul<BF16, SFP, BF16>(batch_size, 24576, 3072, kAdd, env);
+    BenchMatMul<BF16, SFP, BF16>(batch_size, 3072, 24576, kAdd, env);
   }
 
   PROFILER_PRINT_RESULTS();
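
A note on two of the benchmark changes above. The keep accumulator reads one C element through hwy::Unpredictable1() so the compiler cannot optimize the MatMul away; with TC = BF16 that element is hwy::bfloat16_t rather than a built-in arithmetic type, so hwy::ConvertScalarTo<double> (a real Highway utility in hwy/base.h) performs the explicit widening. A small sketch of that idiom, generic over the output type:

#include "hwy/base.h"  // hwy::ConvertScalarTo, hwy::bfloat16_t

// Sketch: widen any output element type (float, BF16, ...) to double before
// accumulating, which keeps a single code path for every TC.
template <typename TC>
double WidenForAccumulation(TC value) {
  return hwy::ConvertScalarTo<double>(value);
}

Separately, BindB now also receives sizeof(TC), presumably because the per-package partitioning of B's rows mirrors the partitioning of C's columns, whose byte width is no longer fixed to sizeof(float).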
