4
4
import pytensor
5
5
from pytensor import function
6
6
from pytensor import tensor as at
7
+ from pytensor .compile import DeepCopyOp , ViewOp
7
8
from pytensor .configdefaults import config
8
9
from pytensor .sandbox .linalg .ops import inv_as_solve , spectral_radius_bound
9
10
from pytensor .tensor .elemwise import DimShuffle
@@ -163,7 +164,11 @@ def test_cholesky_dot_lower():
163
164
C = cholesky_lower (L .dot (L .T ))
164
165
f = pytensor .function ([L ], C )
165
166
if config .mode != "FAST_COMPILE" :
166
- assert f .maker .fgraph .outputs [0 ].name == "L"
167
+ assert (f .maker .fgraph .outputs [0 ] == f .maker .fgraph .inputs [0 ]) or (
168
+ (o := f .maker .fgraph .outputs [0 ].owner )
169
+ and isinstance (o .op , (DeepCopyOp , ViewOp ))
170
+ and o .inputs [0 ] == f .maker .fgraph .inputs [0 ]
171
+ )
167
172
168
173
169
174
def test_cholesky_dot_upper ():
@@ -175,4 +180,8 @@ def test_cholesky_dot_upper():
175
180
C = cholesky_upper (U .T .dot (U ))
176
181
f = pytensor .function ([U ], C )
177
182
if config .mode != "FAST_COMPILE" :
178
- assert f .maker .fgraph .outputs [0 ].name == "U"
183
+ assert (f .maker .fgraph .outputs [0 ] == f .maker .fgraph .inputs [0 ]) or (
184
+ (o := f .maker .fgraph .outputs [0 ].owner )
185
+ and isinstance (o .op , (DeepCopyOp , ViewOp ))
186
+ and o .inputs [0 ] == f .maker .fgraph .inputs [0 ]
187
+ )
0 commit comments