Skip to content

Commit ab38b24

Browse files
committed
Add benchmark test for CAReduce
1 parent 6e6a17a commit ab38b24

File tree

1 file changed

+20
-0
lines changed

1 file changed

+20
-0
lines changed

tests/tensor/test_elemwise.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -985,3 +985,23 @@ def test_CAReduce(self):
985985
assert isinstance(vect_node.op, Any)
986986
assert vect_node.op.axis == (1,)
987987
assert vect_node.inputs[0] is bool_tns
988+
989+
990+
@pytest.mark.parametrize("axis", (0, 1, 2, None), ids=lambda x: f"axis={x}")
991+
@pytest.mark.parametrize(
992+
"c_contiguous", (True, False), ids=lambda x: f"c_contiguous={x}"
993+
)
994+
def test_careduce_benchmark(axis, c_contiguous, benchmark):
995+
N = 256
996+
x_test = np.random.uniform(size=(N, N, N))
997+
transpose_axis = (0, 1, 2) if c_contiguous else (2, 0, 1)
998+
999+
x = pytensor.shared(x_test, name="x", shape=x_test.shape)
1000+
out = x.transpose(transpose_axis).sum(axis=axis)
1001+
fn = pytensor.function([], out)
1002+
1003+
np.testing.assert_allclose(
1004+
fn(),
1005+
x_test.transpose(transpose_axis).sum(axis=axis),
1006+
)
1007+
benchmark(fn)

0 commit comments

Comments
 (0)