Skip to content

Commit 1981bb5

Browse files
committed
Modify get_scalar_constant_value to work with PyTensor Variables
1 parent 7a01715 commit 1981bb5

File tree

2 files changed

+9
-6
lines changed

2 files changed

+9
-6
lines changed

pytensor/tensor/basic.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -267,9 +267,8 @@ def get_scalar_constant_value(
267267
If 'v' is not a scalar, it raises a NotScalarConstantError.
268268
269269
"""
270-
if isinstance(v, np.ndarray):
271-
data = v.data
272-
if data.ndim != 0:
270+
if isinstance(v, (Variable, np.ndarray)):
271+
if v.ndim != 0:
273272
raise NotScalarConstantError()
274273
return get_underlying_scalar_constant_value(
275274
v, elemwise, only_process_constants, max_recur

tests/tensor/test_basic.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3430,10 +3430,14 @@ def test_None_and_NoneConst(self, only_process_constants):
34303430
)
34313431

34323432

3433-
def test_get_scalar_constant_value():
3433+
@pytest.mark.parametrize(
3434+
["valid_inp", "invalid_inp"],
3435+
((np.array(4), np.zeros(5)), (at.constant(4), at.constant(3, ndim=1))),
3436+
)
3437+
def test_get_scalar_constant_value(valid_inp, invalid_inp):
34343438
with pytest.raises(NotScalarConstantError):
3435-
get_scalar_constant_value(np.zeros(5))
3436-
assert get_scalar_constant_value(np.array(4)) == 4
3439+
get_scalar_constant_value(invalid_inp)
3440+
assert get_scalar_constant_value(valid_inp) == 4
34373441

34383442

34393443
def test_complex_mod_failure():

0 commit comments

Comments
 (0)