Skip to content

Commit c1a9f26

Browse files
Add tests to canonicalise subtensor slices
1 parent 34b084f commit c1a9f26

File tree

1 file changed

+12
-12
lines changed

1 file changed

+12
-12
lines changed

tests/tensor/rewriting/test_subtensor.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2404,22 +2404,22 @@ def test_local_blockwise_advanced_inc_subtensor(set_instead_of_inc):
24042404
np.testing.assert_allclose(fn(test_x, test_y), expected_out)
24052405

24062406

2407-
@pytest.mark.parametrize("fstop, lstop, lstep", [(None, 9, 1), (-1, -1, -1)])
2408-
def test_slice_canonicalize(fstop, lstop, lstep):
2407+
def test_slice_canonicalize():
24092408
x = tensor(shape=(3, 5, None, 9))
2410-
y = x[0:fstop, 0:5, 0:7, 0:lstop:lstep]
2409+
# Test case 1
2410+
y = x[0:None, 0:5, 0:7, 0:9:1]
24112411
f = pytensor.function([x], y)
2412-
test_y = f.maker.fgraph.toposort()
2412+
test_y = f.maker.fgraph.outputs[0].owner.inputs[0]
24132413

2414-
y1 = x[None:None:None, None:None:None, None:7:None, None:None:None]
2414+
expected_y = x[None:None:None, None:None:None, None:7:None]
24152415

2416-
if fstop == -1 and lstop == -1 and lstep == -1:
2417-
y1 = x[None:-1:None, None:None:None, None:7:None, None:-1:-1]
2416+
assert equal_computations([test_y], [expected_y])
24182417

2418+
# Test case 2
2419+
y1 = x[0:-1, 0:5, 0:7, 0:-1:-1]
24192420
f1 = pytensor.function([x], y1)
2420-
expected_y = f1.maker.fgraph.toposort()
2421+
test_y1 = f1.maker.fgraph.outputs[0].owner.inputs[0]
24212422

2422-
assert all(
2423-
equal_computations([x1], [y1])
2424-
for x1, y1 in zip(test_y[0].inputs, expected_y[0].inputs)
2425-
)
2423+
expected_y1 = x[None:-1:None, None:None:None, None:7:None, None:-1:-1]
2424+
2425+
assert equal_computations([test_y1], [expected_y1])

0 commit comments

Comments
 (0)