Skip to content

Commit afb109b

Browse files
committed
Add get_scalar_constant method to raise for non-zero ndim
1 parent 277559b commit afb109b

File tree

1 file changed

+12
-0
lines changed

1 file changed

+12
-0
lines changed

pytensor/tensor/basic.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,6 +255,17 @@ def _obj_is_wrappable_as_tensor(x):
255255
)
256256

257257

258+
def get_scalar_constant(v, elemwise=True, only_process_constants=False, max_recur=10):
259+
"""
260+
Checks whether 'v' is a scalar based on 'ndim'
261+
"""
262+
if isinstance(v, np.ndarray):
263+
data = v.data
264+
if data.ndim != 0:
265+
raise NotScalarConstantError()
266+
return get_scalar_constant_value(v, elemwise, only_process_constants, max_recur)
267+
268+
258269
def get_scalar_constant_value(
259270
orig_v, elemwise=True, only_process_constants=False, max_recur=10
260271
):
@@ -4094,6 +4105,7 @@ def take_along_axis(arr, indices, axis=0):
40944105
"cast",
40954106
"scalar_from_tensor",
40964107
"tensor_from_scalar",
4108+
"get_scalar_constant",
40974109
"get_scalar_constant_value",
40984110
"constant",
40994111
"as_tensor_variable",

0 commit comments

Comments
 (0)