Skip to content

Commit e1e77b0

Browse files
committed
check for deepcopy or view rather than string
1 parent 840d573 commit e1e77b0

File tree

1 file changed

+11
-2
lines changed

1 file changed

+11
-2
lines changed

tests/sandbox/linalg/test_linalg.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import pytensor
55
from pytensor import function
66
from pytensor import tensor as at
7+
from pytensor.compile import DeepCopyOp, ViewOp
78
from pytensor.configdefaults import config
89
from pytensor.sandbox.linalg.ops import inv_as_solve, spectral_radius_bound
910
from pytensor.tensor.elemwise import DimShuffle
@@ -163,7 +164,11 @@ def test_cholesky_dot_lower():
163164
C = cholesky_lower(L.dot(L.T))
164165
f = pytensor.function([L], C)
165166
if config.mode != "FAST_COMPILE":
166-
assert f.maker.fgraph.outputs[0].name == "L"
167+
assert (f.maker.fgraph.outputs[0] == f.maker.fgraph.inputs[0]) or (
168+
(o := f.maker.fgraph.outputs[0].owner)
169+
and isinstance(o.op, (DeepCopyOp, ViewOp))
170+
and o.inputs[0] == f.maker.fgraph.inputs[0]
171+
)
167172

168173

169174
def test_cholesky_dot_upper():
@@ -175,4 +180,8 @@ def test_cholesky_dot_upper():
175180
C = cholesky_upper(U.T.dot(U))
176181
f = pytensor.function([U], C)
177182
if config.mode != "FAST_COMPILE":
178-
assert f.maker.fgraph.outputs[0].name == "U"
183+
assert (f.maker.fgraph.outputs[0] == f.maker.fgraph.inputs[0]) or (
184+
(o := f.maker.fgraph.outputs[0].owner)
185+
and isinstance(o.op, (DeepCopyOp, ViewOp))
186+
and o.inputs[0] == f.maker.fgraph.inputs[0]
187+
)

0 commit comments

Comments
 (0)