Skip to content

Commit 00a8a88

Browse files
committed
Add benchmark test for CAReduce
1 parent a8303a0 commit 00a8a88

File tree

1 file changed

+26
-0
lines changed

1 file changed

+26
-0
lines changed

tests/tensor/test_elemwise.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -985,3 +985,29 @@ def test_CAReduce(self):
985985
assert isinstance(vect_node.op, Any)
986986
assert vect_node.op.axis == (1,)
987987
assert vect_node.inputs[0] is bool_tns
988+
989+
990+
@pytest.mark.parametrize(
991+
"axis",
992+
(0, 1, 2, (0, 1), (0, 2), (1, 2), None),
993+
ids=lambda x: f"axis={x}",
994+
)
995+
@pytest.mark.parametrize(
996+
"c_contiguous",
997+
(True, False),
998+
ids=lambda x: f"c_contiguous={x}",
999+
)
1000+
def test_careduce_benchmark(axis, c_contiguous, benchmark):
1001+
N = 256
1002+
x_test = np.random.uniform(size=(N, N, N))
1003+
transpose_axis = (0, 1, 2) if c_contiguous else (2, 0, 1)
1004+
1005+
x = pytensor.shared(x_test, name="x", shape=x_test.shape)
1006+
out = x.transpose(transpose_axis).sum(axis=axis)
1007+
fn = pytensor.function([], out)
1008+
1009+
np.testing.assert_allclose(
1010+
fn(),
1011+
x_test.transpose(transpose_axis).sum(axis=axis),
1012+
)
1013+
benchmark(fn)

0 commit comments

Comments
 (0)