Skip to content

Commit 8a81a53

Browse files
committed
Simplify cholesky infer_shape test and remove slow mark
1 parent 9452257 commit 8a81a53

File tree

1 file changed

+12
-10
lines changed

1 file changed

+12
-10
lines changed

tests/tensor/test_slinalg.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -122,18 +122,20 @@ def test_cholesky_grad_indef():
122122
assert np.all(np.isnan(chol_f(mat)))
123123

124124

125-
@pytest.mark.slow
126-
def test_cholesky_shape():
127-
rng = np.random.default_rng(utt.fetch_seed())
125+
def test_cholesky_infer_shape():
128126
x = matrix()
129-
for l in (cholesky(x), Cholesky(lower=True)(x), Cholesky(lower=False)(x)):
130-
f_chol = pytensor.function([x], l.shape)
127+
f_chol = pytensor.function([x], [cholesky(x).shape, cholesky(x, lower=False).shape])
128+
if config.mode != "FAST_COMPILE":
131129
topo_chol = f_chol.maker.fgraph.toposort()
132-
if config.mode != "FAST_COMPILE":
133-
assert sum(node.op.__class__ == Cholesky for node in topo_chol) == 0
134-
for shp in [2, 3, 5]:
135-
m = np.cov(rng.standard_normal((shp, shp + 10))).astype(config.floatX)
136-
np.testing.assert_equal(f_chol(m), (shp, shp))
130+
f_chol.dprint()
131+
assert not any(
132+
isinstance(getattr(node.op, "core_op", node.op), Cholesky)
133+
for node in topo_chol
134+
)
135+
for shp in [2, 3, 5]:
136+
res1, res2 = f_chol(np.eye(shp).astype(x.dtype))
137+
assert tuple(res1) == (shp, shp)
138+
assert tuple(res2) == (shp, shp)
137139

138140

139141
def test_eigvalsh():

0 commit comments

Comments
 (0)