Skip to content

Commit 78b5120

Browse files
committed
Refactor test and change expected counts of Alloc that were due to BlasOpt
1 parent e468381 commit 78b5120

File tree

1 file changed

+29
-26
lines changed

1 file changed

+29
-26
lines changed

tests/tensor/test_basic.py

Lines changed: 29 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -758,40 +758,43 @@ def check_allocs_in_fgraph(fgraph, n):
758758
def setup_method(self):
759759
self.rng = np.random.default_rng(seed=utt.fetch_seed())
760760

761-
def test_alloc_constant_folding(self):
761+
@pytest.mark.parametrize(
762+
"subtensor_fn, expected_grad_n_alloc",
763+
[
764+
# IncSubtensor1
765+
(lambda x: x[:60], 1),
766+
# AdvancedIncSubtensor1
767+
(lambda x: x[np.arange(60)], 1),
768+
# AdvancedIncSubtensor
769+
(lambda x: x[np.arange(50), np.arange(50)], 1),
770+
],
771+
)
772+
def test_alloc_constant_folding(self, subtensor_fn, expected_grad_n_alloc):
762773
test_params = np.asarray(self.rng.standard_normal(50 * 60), self.dtype)
763774

764775
some_vector = vector("some_vector", dtype=self.dtype)
765776
some_matrix = some_vector.reshape((60, 50))
766777
variables = self.shared(np.ones((50,), dtype=self.dtype))
767-
idx = constant(np.arange(50))
768778

769-
for alloc_, (subtensor, n_alloc) in zip(
770-
self.allocs,
771-
[
772-
# IncSubtensor1
773-
(some_matrix[:60], 2),
774-
# AdvancedIncSubtensor1
775-
(some_matrix[arange(60)], 2),
776-
# AdvancedIncSubtensor
777-
(some_matrix[idx, idx], 1),
778-
],
779-
):
780-
derp = pt_sum(dense_dot(subtensor, variables))
779+
subtensor = subtensor_fn(some_matrix)
781780

782-
fobj = pytensor.function([some_vector], derp, mode=self.mode)
783-
grad_derp = pytensor.grad(derp, some_vector)
784-
fgrad = pytensor.function([some_vector], grad_derp, mode=self.mode)
785-
786-
topo_obj = fobj.maker.fgraph.toposort()
787-
assert sum(isinstance(node.op, type(alloc_)) for node in topo_obj) == 0
781+
derp = pt_sum(dense_dot(subtensor, variables))
782+
fobj = pytensor.function([some_vector], derp, mode=self.mode)
783+
assert (
784+
sum(isinstance(node.op, Alloc) for node in fobj.maker.fgraph.apply_nodes)
785+
== 0
786+
)
787+
# TODO: Assert something about the value if we bothered to call it?
788+
fobj(test_params)
788789

789-
topo_grad = fgrad.maker.fgraph.toposort()
790-
assert (
791-
sum(isinstance(node.op, type(alloc_)) for node in topo_grad) == n_alloc
792-
), (alloc_, subtensor, n_alloc, topo_grad)
793-
fobj(test_params)
794-
fgrad(test_params)
790+
grad_derp = pytensor.grad(derp, some_vector)
791+
fgrad = pytensor.function([some_vector], grad_derp, mode=self.mode)
792+
assert (
793+
sum(isinstance(node.op, Alloc) for node in fgrad.maker.fgraph.apply_nodes)
794+
== expected_grad_n_alloc
795+
)
796+
# TODO: Assert something about the value if we bothered to call it?
797+
fgrad(test_params)
795798

796799
def test_alloc_output(self):
797800
val = constant(self.rng.standard_normal((1, 1)), dtype=self.dtype)

0 commit comments

Comments
 (0)