Skip to content

Commit 86cc3b8

Browse files
ricardoV94jessegrabowski
authored andcommitted
Benchmark special vector case in GEMV
1 parent a7b4652 commit 86cc3b8

File tree

1 file changed

+26
-0
lines changed

1 file changed

+26
-0
lines changed

tests/tensor/test_blas_c.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -413,6 +413,32 @@ class TestBlasStridesC(TestBlasStrides):
413413
mode = mode_blas_opt
414414

415415

416+
def test_gemv_vector_dot_perf(benchmark):
417+
n = 400_000
418+
a = pt.vector("A", shape=(n,))
419+
b = pt.vector("x", shape=(n,))
420+
421+
out = CGemv(inplace=True)(
422+
pt.empty((1,)),
423+
1.0,
424+
a[None],
425+
b,
426+
0.0,
427+
)
428+
fn = pytensor.function([a, b], out, accept_inplace=True, trust_input=True)
429+
430+
rng = np.random.default_rng(430)
431+
test_a = rng.normal(size=n)
432+
test_b = rng.normal(size=n)
433+
434+
np.testing.assert_allclose(
435+
fn(test_a, test_b),
436+
np.dot(test_a, test_b),
437+
)
438+
439+
benchmark(fn, test_a, test_b)
440+
441+
416442
@pytest.mark.parametrize(
417443
"neg_stride1", (True, False), ids=["neg_stride1", "pos_stride1"]
418444
)

0 commit comments

Comments
 (0)