Skip to content

Commit d47ce12

Browse files
ricardoV94brandonwillard
authored andcommitted
Do not use c_code for non-dense inputs in CheckAndRaise
1 parent 5e7b095 commit d47ce12

File tree

2 files changed

+23
-1
lines changed

2 files changed

+23
-1
lines changed

aesara/raise_op.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from aesara.link.c.op import COp
1111
from aesara.link.c.params_type import ParamsType
1212
from aesara.link.c.type import Generic
13+
from aesara.tensor.type import DenseTensorType
1314

1415

1516
class ExceptionType(Generic):
@@ -101,6 +102,10 @@ def connection_pattern(self, node):
101102
return [[1]] + [[0]] * (len(node.inputs) - 1)
102103

103104
def c_code(self, node, name, inames, onames, props):
105+
if not isinstance(node.inputs[0].type, DenseTensorType):
106+
raise NotImplementedError(
107+
f"CheckAndRaise c_code not implemented for input type {node.inputs[0].type}"
108+
)
104109
value_name, *cond_names = inames
105110
out_name = onames[0]
106111
check = []
@@ -129,7 +134,7 @@ def c_code(self, node, name, inames, onames, props):
129134
return res
130135

131136
def c_code_cache_version(self):
132-
return (1, 0)
137+
return (1, 1)
133138

134139
def infer_shape(self, fgraph, node, input_shapes):
135140
return [input_shapes[0]]

tests/test_raise_op.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
import numpy as np
22
import pytest
3+
import scipy.sparse
34

45
import aesara
56
import aesara.tensor as at
67
from aesara.compile.mode import OPT_FAST_RUN, Mode
78
from aesara.graph.basic import Constant, equal_computations
89
from aesara.raise_op import Assert, CheckAndRaise, assert_op
10+
from aesara.sparse import as_sparse_variable
911
from tests import unittest_tools as utt
1012

1113

@@ -114,3 +116,18 @@ def test_infer_shape(self):
114116
self._compile_and_check(
115117
[admat, adscal, bdscal], [out], [admat_val, adscal_val, bdscal_val], Assert
116118
)
119+
120+
121+
def test_CheckAndRaise_sparse_variable():
122+
check_and_raise = CheckAndRaise(ValueError, "sparse_check")
123+
124+
spe1 = scipy.sparse.csc_matrix(scipy.sparse.eye(5, 3))
125+
aspe1 = as_sparse_variable(spe1)
126+
a1 = check_and_raise(aspe1, aspe1.sum() > 2)
127+
assert a1.sum().eval() == 3
128+
129+
spe2 = scipy.sparse.csc_matrix(scipy.sparse.eye(5, 1))
130+
aspe2 = as_sparse_variable(spe2)
131+
a2 = check_and_raise(aspe1, aspe2.sum() > 2)
132+
with pytest.raises(ValueError, match="sparse_check"):
133+
a2.sum().eval()

0 commit comments

Comments
 (0)