From 078708e3c14d8a5b35209a9dc6d37db1110660bd Mon Sep 17 00:00:00 2001 From: Shreyas Singh Date: Tue, 14 Mar 2023 16:56:44 +0530 Subject: [PATCH 1/7] Add get_scalar_constant method to raise for non-zero ndim --- pytensor/tensor/basic.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index eaa468c4c7..effd6495f4 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -255,6 +255,17 @@ def _obj_is_wrappable_as_tensor(x): ) +def get_scalar_constant(v, elemwise=True, only_process_constants=False, max_recur=10): + """ + Checks whether 'v' is a scalar based on 'ndim' + """ + if isinstance(v, np.ndarray): + data = v.data + if data.ndim != 0: + raise NotScalarConstantError() + return get_scalar_constant_value(v, elemwise, only_process_constants, max_recur) + + def get_scalar_constant_value( orig_v, elemwise=True, only_process_constants=False, max_recur=10 ): @@ -4094,6 +4105,7 @@ def take_along_axis(arr, indices, axis=0): "cast", "scalar_from_tensor", "tensor_from_scalar", + "get_scalar_constant", "get_scalar_constant_value", "constant", "as_tensor_variable", From 4d51f7737e529427c450b468d2252bc9862a630e Mon Sep 17 00:00:00 2001 From: Shreyas Singh Date: Tue, 14 Mar 2023 17:31:23 +0530 Subject: [PATCH 2/7] Rename get_scalar_constant_value to get_underlying_scalar_constant --- doc/library/tensor/basic.rst | 2 +- pytensor/__init__.py | 6 +- pytensor/gradient.py | 4 +- pytensor/link/jax/dispatch/tensor_basic.py | 6 +- pytensor/scan/basic.py | 6 +- pytensor/scan/rewriting.py | 6 +- pytensor/tensor/basic.py | 38 ++++----- pytensor/tensor/blas.py | 2 +- pytensor/tensor/conv/abstract_conv.py | 12 +-- pytensor/tensor/elemwise.py | 2 +- pytensor/tensor/exceptions.py | 2 +- pytensor/tensor/extra_ops.py | 2 +- pytensor/tensor/random/op.py | 4 +- pytensor/tensor/rewriting/basic.py | 6 +- pytensor/tensor/rewriting/elemwise.py | 4 +- pytensor/tensor/rewriting/math.py | 36 ++++----- pytensor/tensor/rewriting/shape.py | 6 +- pytensor/tensor/rewriting/subtensor.py | 20 ++--- pytensor/tensor/shape.py | 12 +-- pytensor/tensor/subtensor.py | 12 +-- pytensor/tensor/var.py | 4 +- tests/sparse/test_basic.py | 2 +- tests/tensor/test_basic.py | 94 +++++++++++----------- tests/tensor/test_elemwise.py | 4 +- tests/tensor/test_math.py | 4 +- 25 files changed, 148 insertions(+), 148 deletions(-) diff --git a/doc/library/tensor/basic.rst b/doc/library/tensor/basic.rst index fab197dd75..fda72a9c36 100644 --- a/doc/library/tensor/basic.rst +++ b/doc/library/tensor/basic.rst @@ -577,7 +577,7 @@ them perfectly, but a `dscalar` otherwise. .. method:: round(mode="half_away_from_zero") :noindex: .. method:: trace() - .. method:: get_scalar_constant_value() + .. method:: get_underlying_scalar_constant() .. method:: zeros_like(model, dtype=None) All the above methods are equivalent to NumPy for PyTensor on the current tensor. diff --git a/pytensor/__init__.py b/pytensor/__init__.py index 8cd4bcd972..789c465ce6 100644 --- a/pytensor/__init__.py +++ b/pytensor/__init__.py @@ -137,7 +137,7 @@ def _as_symbolic(x, **kwargs) -> Variable: # isort: on -def get_scalar_constant_value(v): +def get_underlying_scalar_constant(v): """Return the constant scalar (i.e. 0-D) value underlying variable `v`. If `v` is the output of dim-shuffles, fills, allocs, cast, etc. @@ -153,8 +153,8 @@ def get_scalar_constant_value(v): if sparse and isinstance(v.type, sparse.SparseTensorType): if v.owner is not None and isinstance(v.owner.op, sparse.CSM): data = v.owner.inputs[0] - return tensor.get_scalar_constant_value(data) - return tensor.get_scalar_constant_value(v) + return tensor.get_underlying_scalar_constant(data) + return tensor.get_underlying_scalar_constant(v) # isort: off diff --git a/pytensor/gradient.py b/pytensor/gradient.py index 246a3d3b0b..930acfbdb3 100644 --- a/pytensor/gradient.py +++ b/pytensor/gradient.py @@ -1325,7 +1325,7 @@ def try_to_copy_if_needed(var): f" {i}. Since this input is only connected " "to integer-valued outputs, it should " "evaluate to zeros, but it evaluates to" - f"{pytensor.get_scalar_constant_value(term)}." + f"{pytensor.get_underlying_scalar_constant(term)}." ) raise ValueError(msg) @@ -2086,7 +2086,7 @@ def _is_zero(x): no_constant_value = True try: - constant_value = pytensor.get_scalar_constant_value(x) + constant_value = pytensor.get_underlying_scalar_constant(x) no_constant_value = False except pytensor.tensor.exceptions.NotScalarConstantError: pass diff --git a/pytensor/link/jax/dispatch/tensor_basic.py b/pytensor/link/jax/dispatch/tensor_basic.py index 8db4b255a8..4545705e28 100644 --- a/pytensor/link/jax/dispatch/tensor_basic.py +++ b/pytensor/link/jax/dispatch/tensor_basic.py @@ -18,7 +18,7 @@ ScalarFromTensor, Split, TensorFromScalar, - get_scalar_constant_value, + get_underlying_scalar_constant, ) from pytensor.tensor.exceptions import NotScalarConstantError @@ -106,7 +106,7 @@ def join(axis, *tensors): def jax_funcify_Split(op: Split, node, **kwargs): _, axis, splits = node.inputs try: - constant_axis = get_scalar_constant_value(axis) + constant_axis = get_underlying_scalar_constant(axis) except NotScalarConstantError: constant_axis = None warnings.warn( @@ -116,7 +116,7 @@ def jax_funcify_Split(op: Split, node, **kwargs): try: constant_splits = np.array( [ - get_scalar_constant_value(splits[i]) + get_underlying_scalar_constant(splits[i]) for i in range(get_vector_length(splits)) ] ) diff --git a/pytensor/scan/basic.py b/pytensor/scan/basic.py index 2875ca18a2..973d4be7bf 100644 --- a/pytensor/scan/basic.py +++ b/pytensor/scan/basic.py @@ -12,7 +12,7 @@ from pytensor.graph.utils import MissingInputError, TestValueError from pytensor.scan.op import Scan, ScanInfo from pytensor.scan.utils import expand_empty, safe_new, until -from pytensor.tensor.basic import get_scalar_constant_value +from pytensor.tensor.basic import get_underlying_scalar_constant from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.math import minimum from pytensor.tensor.shape import shape_padleft, unbroadcast @@ -147,7 +147,7 @@ def isNaN_or_Inf_or_None(x): isStr = False if not isNaN and not isInf: try: - val = get_scalar_constant_value(x) + val = get_underlying_scalar_constant(x) isInf = np.isinf(val) isNaN = np.isnan(val) except Exception: @@ -476,7 +476,7 @@ def wrap_into_list(x): n_fixed_steps = int(n_steps) else: try: - n_fixed_steps = at.get_scalar_constant_value(n_steps) + n_fixed_steps = at.get_underlying_scalar_constant(n_steps) except NotScalarConstantError: n_fixed_steps = None diff --git a/pytensor/scan/rewriting.py b/pytensor/scan/rewriting.py index e71fb271de..abab3644c9 100644 --- a/pytensor/scan/rewriting.py +++ b/pytensor/scan/rewriting.py @@ -49,7 +49,7 @@ safe_new, scan_can_remove_outs, ) -from pytensor.tensor.basic import Alloc, AllocEmpty, get_scalar_constant_value +from pytensor.tensor.basic import Alloc, AllocEmpty, get_underlying_scalar_constant from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.math import Dot, dot, maximum, minimum @@ -1956,13 +1956,13 @@ def belongs_to_set(self, node, set_nodes): nsteps = node.inputs[0] try: - nsteps = int(get_scalar_constant_value(nsteps)) + nsteps = int(get_underlying_scalar_constant(nsteps)) except NotScalarConstantError: pass rep_nsteps = rep.inputs[0] try: - rep_nsteps = int(get_scalar_constant_value(rep_nsteps)) + rep_nsteps = int(get_underlying_scalar_constant(rep_nsteps)) except NotScalarConstantError: pass diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index effd6495f4..85b9a684c1 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -263,10 +263,10 @@ def get_scalar_constant(v, elemwise=True, only_process_constants=False, max_recu data = v.data if data.ndim != 0: raise NotScalarConstantError() - return get_scalar_constant_value(v, elemwise, only_process_constants, max_recur) + return get_underlying_scalar_constant(v, elemwise, only_process_constants, max_recur) -def get_scalar_constant_value( +def get_underlying_scalar_constant( orig_v, elemwise=True, only_process_constants=False, max_recur=10 ): """Return the constant scalar(0-D) value underlying variable `v`. @@ -369,7 +369,7 @@ def get_scalar_constant_value( elif isinstance(v.owner.op, CheckAndRaise): # check if all conditions are constant and true conds = [ - get_scalar_constant_value(c, max_recur=max_recur) + get_underlying_scalar_constant(c, max_recur=max_recur) for c in v.owner.inputs[1:] ] if builtins.all(0 == c.ndim and c != 0 for c in conds): @@ -383,7 +383,7 @@ def get_scalar_constant_value( continue if isinstance(v.owner.op, _scalar_constant_value_elemwise_ops): const = [ - get_scalar_constant_value(i, max_recur=max_recur) + get_underlying_scalar_constant(i, max_recur=max_recur) for i in v.owner.inputs ] ret = [[None]] @@ -402,7 +402,7 @@ def get_scalar_constant_value( v.owner.op.scalar_op, _scalar_constant_value_elemwise_ops ): const = [ - get_scalar_constant_value(i, max_recur=max_recur) + get_underlying_scalar_constant(i, max_recur=max_recur) for i in v.owner.inputs ] ret = [[None]] @@ -448,7 +448,7 @@ def get_scalar_constant_value( ): idx = v.owner.op.idx_list[0] if isinstance(idx, Type): - idx = get_scalar_constant_value( + idx = get_underlying_scalar_constant( v.owner.inputs[1], max_recur=max_recur ) try: @@ -482,14 +482,14 @@ def get_scalar_constant_value( ): idx = v.owner.op.idx_list[0] if isinstance(idx, Type): - idx = get_scalar_constant_value( + idx = get_underlying_scalar_constant( v.owner.inputs[1], max_recur=max_recur ) # Python 2.4 does not support indexing with numpy.integer # So we cast it. idx = int(idx) ret = v.owner.inputs[0].owner.inputs[idx] - ret = get_scalar_constant_value(ret, max_recur=max_recur) + ret = get_underlying_scalar_constant(ret, max_recur=max_recur) # MakeVector can cast implicitly its input in some case. return _asarray(ret, dtype=v.type.dtype) @@ -504,7 +504,7 @@ def get_scalar_constant_value( idx_list = op.idx_list idx = idx_list[0] if isinstance(idx, Type): - idx = get_scalar_constant_value( + idx = get_underlying_scalar_constant( owner.inputs[1], max_recur=max_recur ) grandparent = leftmost_parent.owner.inputs[0] @@ -519,7 +519,7 @@ def get_scalar_constant_value( if not (idx < ndim): msg = ( - "get_scalar_constant_value detected " + "get_underlying_scalar_constant detected " f"deterministic IndexError: x.shape[{int(idx)}] " f"when x.ndim={int(ndim)}." ) @@ -1581,7 +1581,7 @@ def do_constant_folding(self, fgraph, node): @_get_vector_length.register(Alloc) def _get_vector_length_Alloc(var_inst, var): try: - return get_scalar_constant_value(var.owner.inputs[1]) + return get_underlying_scalar_constant(var.owner.inputs[1]) except NotScalarConstantError: raise ValueError(f"Length of {var} cannot be determined") @@ -1832,17 +1832,17 @@ def perform(self, node, inp, out_): def extract_constant(x, elemwise=True, only_process_constants=False): """ - This function is basically a call to tensor.get_scalar_constant_value. + This function is basically a call to tensor.get_underlying_scalar_constant. The main difference is the behaviour in case of failure. While - get_scalar_constant_value raises an TypeError, this function returns x, + get_underlying_scalar_constant raises an TypeError, this function returns x, as a tensor if possible. If x is a ScalarVariable from a scalar_from_tensor, we remove the conversion. If x is just a ScalarVariable, we convert it to a tensor with tensor_from_scalar. """ try: - x = get_scalar_constant_value(x, elemwise, only_process_constants) + x = get_underlying_scalar_constant(x, elemwise, only_process_constants) except NotScalarConstantError: pass if isinstance(x, aes.ScalarVariable) or isinstance( @@ -2212,7 +2212,7 @@ def make_node(self, axis, *tensors): if not isinstance(axis, int): try: - axis = int(get_scalar_constant_value(axis)) + axis = int(get_underlying_scalar_constant(axis)) except NotScalarConstantError: pass @@ -2461,7 +2461,7 @@ def infer_shape(self, fgraph, node, ishapes): def _get_vector_length_Join(op, var): axis, *arrays = var.owner.inputs try: - axis = get_scalar_constant_value(axis) + axis = get_underlying_scalar_constant(axis) assert axis == 0 and builtins.all(a.ndim == 1 for a in arrays) return builtins.sum(get_vector_length(a) for a in arrays) except NotScalarConstantError: @@ -2873,7 +2873,7 @@ def infer_shape(self, fgraph, node, i_shapes): def is_constant_value(var, value): try: - v = get_scalar_constant_value(var) + v = get_underlying_scalar_constant(var) return np.all(v == value) except NotScalarConstantError: pass @@ -3785,7 +3785,7 @@ def make_node(self, a, choices): static_out_shape = () for s in out_shape: try: - s_val = pytensor.get_scalar_constant_value(s) + s_val = pytensor.get_underlying_scalar_constant(s) except (NotScalarConstantError, AttributeError): s_val = None @@ -4106,7 +4106,7 @@ def take_along_axis(arr, indices, axis=0): "scalar_from_tensor", "tensor_from_scalar", "get_scalar_constant", - "get_scalar_constant_value", + "get_underlying_scalar_constant", "constant", "as_tensor_variable", "as_tensor", diff --git a/pytensor/tensor/blas.py b/pytensor/tensor/blas.py index cd10807921..ae3324fc1d 100644 --- a/pytensor/tensor/blas.py +++ b/pytensor/tensor/blas.py @@ -1834,7 +1834,7 @@ def local_gemm_to_ger(fgraph, node): xv = x.dimshuffle(0) yv = y.dimshuffle(1) try: - bval = at.get_scalar_constant_value(b) + bval = at.get_underlying_scalar_constant(b) except NotScalarConstantError: # b isn't a constant, GEMM is doing useful pre-scaling return diff --git a/pytensor/tensor/conv/abstract_conv.py b/pytensor/tensor/conv/abstract_conv.py index 29408e7d1d..53c591bdd0 100644 --- a/pytensor/tensor/conv/abstract_conv.py +++ b/pytensor/tensor/conv/abstract_conv.py @@ -24,7 +24,7 @@ from pytensor.graph.basic import Apply, Variable from pytensor.graph.op import Op from pytensor.raise_op import Assert -from pytensor.tensor.basic import as_tensor_variable, get_scalar_constant_value +from pytensor.tensor.basic import as_tensor_variable, get_underlying_scalar_constant from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.var import TensorConstant, TensorVariable @@ -495,8 +495,8 @@ def check_dim(given, computed): if given is None or computed is None: return True try: - given = get_scalar_constant_value(given) - computed = get_scalar_constant_value(computed) + given = get_underlying_scalar_constant(given) + computed = get_underlying_scalar_constant(computed) return int(given) == int(computed) except NotScalarConstantError: # no answer possible, accept for now @@ -532,7 +532,7 @@ def assert_conv_shape(shape): out_shape = [] for i, n in enumerate(shape): try: - const_n = get_scalar_constant_value(n) + const_n = get_underlying_scalar_constant(n) if i < 2: if const_n < 0: raise ValueError( @@ -2200,7 +2200,7 @@ def __init__( if imshp_i is not None: # Components of imshp should be constant or ints try: - get_scalar_constant_value(imshp_i, only_process_constants=True) + get_underlying_scalar_constant(imshp_i, only_process_constants=True) except NotScalarConstantError: raise ValueError( "imshp should be None or a tuple of constant int values" @@ -2213,7 +2213,7 @@ def __init__( if kshp_i is not None: # Components of kshp should be constant or ints try: - get_scalar_constant_value(kshp_i, only_process_constants=True) + get_underlying_scalar_constant(kshp_i, only_process_constants=True) except NotScalarConstantError: raise ValueError( "kshp should be None or a tuple of constant int values" diff --git a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py index 1e316e5afa..87992c7235 100644 --- a/pytensor/tensor/elemwise.py +++ b/pytensor/tensor/elemwise.py @@ -759,7 +759,7 @@ def perform(self, node, inputs, output_storage): ufunc = self.ufunc elif not hasattr(node.tag, "ufunc"): # It happen that make_thunk isn't called, like in - # get_scalar_constant_value + # get_underlying_scalar_constant self.prepare_node(node, None, None, "py") # prepare_node will add ufunc to self or the tag # depending if we can reuse it or not. So we need to diff --git a/pytensor/tensor/exceptions.py b/pytensor/tensor/exceptions.py index d6996dfe39..59ad52fada 100644 --- a/pytensor/tensor/exceptions.py +++ b/pytensor/tensor/exceptions.py @@ -4,7 +4,7 @@ class ShapeError(Exception): class NotScalarConstantError(Exception): """ - Raised by get_scalar_constant_value if called on something that is + Raised by get_underlying_scalar_constant if called on something that is not a scalar constant. """ diff --git a/pytensor/tensor/extra_ops.py b/pytensor/tensor/extra_ops.py index 5844b295df..5025d6e382 100644 --- a/pytensor/tensor/extra_ops.py +++ b/pytensor/tensor/extra_ops.py @@ -671,7 +671,7 @@ def make_node(self, x, repeats): out_shape = [None] else: try: - const_reps = at.get_scalar_constant_value(repeats) + const_reps = at.get_underlying_scalar_constant(repeats) except NotScalarConstantError: const_reps = None if const_reps == 1: diff --git a/pytensor/tensor/random/op.py b/pytensor/tensor/random/op.py index 10b5a06b84..e04b56dd64 100644 --- a/pytensor/tensor/random/op.py +++ b/pytensor/tensor/random/op.py @@ -12,7 +12,7 @@ from pytensor.tensor.basic import ( as_tensor_variable, constant, - get_scalar_constant_value, + get_underlying_scalar_constant, get_vector_length, infer_static_shape, ) @@ -277,7 +277,7 @@ def infer_shape(self, fgraph, node, input_shapes): try: size_len = get_vector_length(size) except ValueError: - size_len = get_scalar_constant_value(size_shape[0]) + size_len = get_underlying_scalar_constant(size_shape[0]) size = tuple(size[n] for n in range(size_len)) diff --git a/pytensor/tensor/rewriting/basic.py b/pytensor/tensor/rewriting/basic.py index dbf2b472b1..14dce47b1d 100644 --- a/pytensor/tensor/rewriting/basic.py +++ b/pytensor/tensor/rewriting/basic.py @@ -32,7 +32,7 @@ cast, extract_constant, fill, - get_scalar_constant_value, + get_underlying_scalar_constant, join, ones_like, switch, @@ -802,7 +802,7 @@ def local_remove_useless_assert(fgraph, node): n_conds = len(node.inputs[1:]) for c in node.inputs[1:]: try: - const = get_scalar_constant_value(c) + const = get_underlying_scalar_constant(c) if 0 != const.ndim or const == 0: # Should we raise an error here? How to be sure it @@ -895,7 +895,7 @@ def local_join_empty(fgraph, node): return new_inputs = [] try: - join_idx = get_scalar_constant_value( + join_idx = get_underlying_scalar_constant( node.inputs[0], only_process_constants=True ) except NotScalarConstantError: diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index 494b08025d..5d231421d4 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -22,7 +22,7 @@ ) from pytensor.graph.rewriting.db import SequenceDB from pytensor.graph.utils import InconsistencyError, MethodNotDefined -from pytensor.tensor.basic import MakeVector, alloc, cast, get_scalar_constant_value +from pytensor.tensor.basic import MakeVector, alloc, cast, get_underlying_scalar_constant from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.rewriting.basic import register_canonicalize, register_specialize @@ -495,7 +495,7 @@ def local_upcast_elemwise_constant_inputs(fgraph, node): else: try: # works only for scalars - cval_i = get_scalar_constant_value( + cval_i = get_underlying_scalar_constant( i, only_process_constants=True ) if all(i.broadcastable): diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index ac42197199..061835325f 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -31,7 +31,7 @@ constant, extract_constant, fill, - get_scalar_constant_value, + get_underlying_scalar_constant, ones_like, switch, zeros_like, @@ -112,7 +112,7 @@ def scalarconsts_rest(inputs, elemwise=True, only_process_constants=False): nonconsts = [] for i in inputs: try: - v = get_scalar_constant_value( + v = get_underlying_scalar_constant( i, elemwise=elemwise, only_process_constants=only_process_constants ) consts.append(v) @@ -165,13 +165,13 @@ def local_0_dot_x(fgraph, node): y = node.inputs[1] replace = False try: - if get_scalar_constant_value(x, only_process_constants=True) == 0: + if get_underlying_scalar_constant(x, only_process_constants=True) == 0: replace = True except NotScalarConstantError: pass try: - if get_scalar_constant_value(y, only_process_constants=True) == 0: + if get_underlying_scalar_constant(y, only_process_constants=True) == 0: replace = True except NotScalarConstantError: pass @@ -585,7 +585,7 @@ def local_mul_switch_sink(fgraph, node): switch_node = i.owner try: if ( - get_scalar_constant_value( + get_underlying_scalar_constant( switch_node.inputs[1], only_process_constants=True ) == 0.0 @@ -613,7 +613,7 @@ def local_mul_switch_sink(fgraph, node): pass try: if ( - get_scalar_constant_value( + get_underlying_scalar_constant( switch_node.inputs[2], only_process_constants=True ) == 0.0 @@ -665,7 +665,7 @@ def local_div_switch_sink(fgraph, node): switch_node = node.inputs[0].owner try: if ( - get_scalar_constant_value( + get_underlying_scalar_constant( switch_node.inputs[1], only_process_constants=True ) == 0.0 @@ -691,7 +691,7 @@ def local_div_switch_sink(fgraph, node): pass try: if ( - get_scalar_constant_value( + get_underlying_scalar_constant( switch_node.inputs[2], only_process_constants=True ) == 0.0 @@ -1493,7 +1493,7 @@ def investigate(node): and investigate(node.inputs[0].owner) ): try: - cst = get_scalar_constant_value(node.inputs[1], only_process_constants=True) + cst = get_underlying_scalar_constant(node.inputs[1], only_process_constants=True) res = zeros_like(node.inputs[0], dtype=dtype, opt=True) @@ -1733,7 +1733,7 @@ def local_reduce_join(fgraph, node): # We add the new check late to don't add extra warning. try: - join_axis = get_scalar_constant_value( + join_axis = get_underlying_scalar_constant( join_node.inputs[0], only_process_constants=True ) @@ -1816,7 +1816,7 @@ def local_opt_alloc(fgraph, node): inp = node_inps.owner.inputs[0] shapes = node_inps.owner.inputs[1:] try: - val = get_scalar_constant_value(inp, only_process_constants=True) + val = get_underlying_scalar_constant(inp, only_process_constants=True) assert val.size == 1 val = val.reshape(1)[0] # check which type of op @@ -1948,7 +1948,7 @@ def local_mul_zero(fgraph, node): for i in node.inputs: try: - value = get_scalar_constant_value(i) + value = get_underlying_scalar_constant(i) except NotScalarConstantError: continue # print 'MUL by value', value, node.inputs @@ -2230,7 +2230,7 @@ def local_add_specialize(fgraph, node): new_inputs = [] for inp in node.inputs: try: - y = get_scalar_constant_value(inp) + y = get_underlying_scalar_constant(inp) except NotScalarConstantError: y = inp if np.all(y == 0.0): @@ -2329,7 +2329,7 @@ def local_abs_merge(fgraph, node): inputs.append(i.owner.inputs[0]) elif isinstance(i, Constant): try: - const = get_scalar_constant_value(i, only_process_constants=True) + const = get_underlying_scalar_constant(i, only_process_constants=True) except NotScalarConstantError: return False if not (const >= 0).all(): @@ -2878,7 +2878,7 @@ def check_input(inputs): mul_neg = mul(*mul_inputs) try: - cst2 = get_scalar_constant_value( + cst2 = get_underlying_scalar_constant( mul_neg.owner.inputs[0], only_process_constants=True ) except NotScalarConstantError: @@ -2912,7 +2912,7 @@ def check_input(inputs): x = erfc_x try: - cst = get_scalar_constant_value( + cst = get_underlying_scalar_constant( erfc_x.owner.inputs[0], only_process_constants=True ) except NotScalarConstantError: @@ -2979,7 +2979,7 @@ def _is_1(expr): """ try: - v = get_scalar_constant_value(expr) + v = get_underlying_scalar_constant(expr) return np.allclose(v, 1) except NotScalarConstantError: return False @@ -3147,7 +3147,7 @@ def is_neg(var): if var_node.op == mul and len(var_node.inputs) >= 2: for idx, mul_input in enumerate(var_node.inputs): try: - constant = get_scalar_constant_value(mul_input) + constant = get_underlying_scalar_constant(mul_input) is_minus_1 = np.allclose(constant, -1) except NotScalarConstantError: is_minus_1 = False diff --git a/pytensor/tensor/rewriting/shape.py b/pytensor/tensor/rewriting/shape.py index 5bc0ab9505..0b6b2707f9 100644 --- a/pytensor/tensor/rewriting/shape.py +++ b/pytensor/tensor/rewriting/shape.py @@ -24,7 +24,7 @@ cast, constant, extract_constant, - get_scalar_constant_value, + get_underlying_scalar_constant, stack, ) from pytensor.tensor.elemwise import DimShuffle, Elemwise @@ -226,7 +226,7 @@ def shape_ir(self, i, r): # Do not call make_node for test_value s = Shape_i(i)(r) try: - s = get_scalar_constant_value(s) + s = get_underlying_scalar_constant(s) except NotScalarConstantError: pass return s @@ -310,7 +310,7 @@ def unpack(self, s_i, var): assert len(idx) == 1 idx = idx[0] try: - i = get_scalar_constant_value(idx) + i = get_underlying_scalar_constant(idx) except NotScalarConstantError: pass else: diff --git a/pytensor/tensor/rewriting/subtensor.py b/pytensor/tensor/rewriting/subtensor.py index d121586860..c6458129a4 100644 --- a/pytensor/tensor/rewriting/subtensor.py +++ b/pytensor/tensor/rewriting/subtensor.py @@ -25,7 +25,7 @@ cast, concatenate, extract_constant, - get_scalar_constant_value, + get_underlying_scalar_constant, switch, ) from pytensor.tensor.elemwise import Elemwise @@ -756,7 +756,7 @@ def local_subtensor_make_vector(fgraph, node): elif isinstance(idx, Variable): if idx.ndim == 0: try: - v = get_scalar_constant_value(idx, only_process_constants=True) + v = get_underlying_scalar_constant(idx, only_process_constants=True) try: ret = [x.owner.inputs[v]] except IndexError: @@ -808,7 +808,7 @@ def local_useless_inc_subtensor(fgraph, node): # This is an increment operation, so the array being incremented must # consist of all zeros in order for the entire operation to be useless try: - c = get_scalar_constant_value(x) + c = get_underlying_scalar_constant(x) if c != 0: return except NotScalarConstantError: @@ -927,7 +927,7 @@ def local_useless_subtensor(fgraph, node): if isinstance(idx.stop, (int, np.integer)): length_pos_data = sys.maxsize try: - length_pos_data = get_scalar_constant_value( + length_pos_data = get_underlying_scalar_constant( length_pos, only_process_constants=True ) except NotScalarConstantError: @@ -992,7 +992,7 @@ def local_useless_AdvancedSubtensor1(fgraph, node): # get length of the indexed tensor along the first axis try: - length = get_scalar_constant_value( + length = get_underlying_scalar_constant( shape_of[node.inputs[0]][0], only_process_constants=True ) except NotScalarConstantError: @@ -1329,7 +1329,7 @@ def local_incsubtensor_of_zeros(fgraph, node): try: # Don't use only_process_constants=True. We need to # investigate Alloc of 0s but with non constant shape. - if get_scalar_constant_value(y, elemwise=False) == 0: + if get_underlying_scalar_constant(y, elemwise=False) == 0: # No need to copy over the stacktrace, # because x should already have a stacktrace return [x] @@ -1375,12 +1375,12 @@ def local_setsubtensor_of_constants(fgraph, node): # Don't use only_process_constants=True. We need to # investigate Alloc of 0s but with non constant shape. try: - replace_x = get_scalar_constant_value(x, elemwise=False) + replace_x = get_underlying_scalar_constant(x, elemwise=False) except NotScalarConstantError: return try: - replace_y = get_scalar_constant_value(y, elemwise=False) + replace_y = get_underlying_scalar_constant(y, elemwise=False) except NotScalarConstantError: return @@ -1668,7 +1668,7 @@ def local_join_subtensors(fgraph, node): axis, tensors = node.inputs[0], node.inputs[1:] try: - axis = get_scalar_constant_value(axis) + axis = get_underlying_scalar_constant(axis) except NotScalarConstantError: return @@ -1729,7 +1729,7 @@ def local_join_subtensors(fgraph, node): if step is None: continue try: - if get_scalar_constant_value(step, only_process_constants=True) != 1: + if get_underlying_scalar_constant(step, only_process_constants=True) != 1: return None except NotScalarConstantError: return None diff --git a/pytensor/tensor/shape.py b/pytensor/tensor/shape.py index 18f51ba3ca..ecffc7ba3b 100644 --- a/pytensor/tensor/shape.py +++ b/pytensor/tensor/shape.py @@ -397,7 +397,7 @@ class SpecifyShape(COp): _f16_ok = True def make_node(self, x, *shape): - from pytensor.tensor.basic import get_scalar_constant_value + from pytensor.tensor.basic import get_underlying_scalar_constant x = at.as_tensor_variable(x) @@ -426,7 +426,7 @@ def make_node(self, x, *shape): type_shape[i] = xts else: try: - type_s = get_scalar_constant_value(s) + type_s = get_underlying_scalar_constant(s) if type_s is not None: type_shape[i] = int(type_s) except NotScalarConstantError: @@ -457,9 +457,9 @@ def infer_shape(self, fgraph, node, shapes): for dim in range(node.inputs[0].type.ndim): s = shape[dim] try: - s = at.get_scalar_constant_value(s) + s = at.get_underlying_scalar_constant(s) # We assume that `None` shapes are always retrieved by - # `get_scalar_constant_value`, and only in that case do we default to + # `get_underlying_scalar_constant`, and only in that case do we default to # the shape of the input variable if s is None: s = xshape[dim] @@ -581,7 +581,7 @@ def specify_shape( @_get_vector_length.register(SpecifyShape) def _get_vector_length_SpecifyShape(op, var): try: - return at.get_scalar_constant_value(var.owner.inputs[1]).item() + return at.get_underlying_scalar_constant(var.owner.inputs[1]).item() except NotScalarConstantError: raise ValueError(f"Length of {var} cannot be determined") @@ -635,7 +635,7 @@ def make_node(self, x, shp): y = shp_list[index] y = at.as_tensor_variable(y) try: - s_val = at.get_scalar_constant_value(y).item() + s_val = at.get_underlying_scalar_constant(y).item() if s_val >= 0: out_shape[index] = s_val except NotScalarConstantError: diff --git a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py index 12d7d0a3f8..9afc3d2197 100644 --- a/pytensor/tensor/subtensor.py +++ b/pytensor/tensor/subtensor.py @@ -20,7 +20,7 @@ from pytensor.printing import Printer, pprint, set_precedence from pytensor.scalar.basic import ScalarConstant from pytensor.tensor import _get_vector_length, as_tensor_variable, get_vector_length -from pytensor.tensor.basic import alloc, get_scalar_constant_value +from pytensor.tensor.basic import alloc, get_underlying_scalar_constant from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.exceptions import ( AdvancedIndexingError, @@ -656,7 +656,7 @@ def conv(val): return slice(conv(val.start), conv(val.stop), conv(val.step)) else: try: - return get_scalar_constant_value( + return get_underlying_scalar_constant( val, only_process_constants=only_process_constants, elemwise=elemwise, @@ -733,7 +733,7 @@ def make_node(self, x, *inputs): if s == 1: start = p.start try: - start = get_scalar_constant_value(start) + start = get_underlying_scalar_constant(start) except NotScalarConstantError: pass if start is None or start == 0: @@ -2808,17 +2808,17 @@ def _get_vector_length_Subtensor(op, var): start = ( None if indices[0].start is None - else get_scalar_constant_value(indices[0].start) + else get_underlying_scalar_constant(indices[0].start) ) stop = ( None if indices[0].stop is None - else get_scalar_constant_value(indices[0].stop) + else get_underlying_scalar_constant(indices[0].stop) ) step = ( None if indices[0].step is None - else get_scalar_constant_value(indices[0].step) + else get_underlying_scalar_constant(indices[0].step) ) if start == stop: diff --git a/pytensor/tensor/var.py b/pytensor/tensor/var.py index ea9290787f..7da43ffbc4 100644 --- a/pytensor/tensor/var.py +++ b/pytensor/tensor/var.py @@ -756,8 +756,8 @@ def trace(self): # This value is set so that PyTensor arrays will trump NumPy operators. __array_priority__ = 1000 - def get_scalar_constant_value(self): - return at.basic.get_scalar_constant_value(self) + def get_underlying_scalar_constant(self): + return at.basic.get_underlying_scalar_constant(self) def zeros_like(model, dtype=None): return at.basic.zeros_like(model, dtype=dtype) diff --git a/tests/sparse/test_basic.py b/tests/sparse/test_basic.py index c83e0eea19..ec8509d202 100644 --- a/tests/sparse/test_basic.py +++ b/tests/sparse/test_basic.py @@ -1043,7 +1043,7 @@ def test_basic(self): from pytensor.tensor.exceptions import NotScalarConstantError with pytest.raises(NotScalarConstantError): - at.get_scalar_constant_value(s, only_process_constants=True) + at.get_underlying_scalar_constant(s, only_process_constants=True) # TODO: # def test_sparse_as_tensor_variable(self): diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py index 694d38b9c1..126b42d6ba 100644 --- a/tests/tensor/test_basic.py +++ b/tests/tensor/test_basic.py @@ -51,7 +51,7 @@ flatnonzero, flatten, full_like, - get_scalar_constant_value, + get_underlying_scalar_constant, get_vector_length, horizontal_stack, identity_like, @@ -3263,52 +3263,52 @@ def test_dimshuffle_duplicate(): DimShuffle((False,), (0, 0))(x) -class TestGetScalarConstantValue: +class TestGetUnderlyingScalarConstant: def test_basic(self): with pytest.raises(NotScalarConstantError): - get_scalar_constant_value(aes.int64()) + get_underlying_scalar_constant(aes.int64()) - res = get_scalar_constant_value(at.as_tensor(10)) + res = get_underlying_scalar_constant(at.as_tensor(10)) assert res == 10 assert isinstance(res, np.ndarray) - res = get_scalar_constant_value(np.array(10)) + res = get_underlying_scalar_constant(np.array(10)) assert res == 10 assert isinstance(res, np.ndarray) a = at.stack([1, 2, 3]) - assert get_scalar_constant_value(a[0]) == 1 - assert get_scalar_constant_value(a[1]) == 2 - assert get_scalar_constant_value(a[2]) == 3 + assert get_underlying_scalar_constant(a[0]) == 1 + assert get_underlying_scalar_constant(a[1]) == 2 + assert get_underlying_scalar_constant(a[2]) == 3 b = iscalar() a = at.stack([b, 2, 3]) with pytest.raises(NotScalarConstantError): - get_scalar_constant_value(a[0]) - assert get_scalar_constant_value(a[1]) == 2 - assert get_scalar_constant_value(a[2]) == 3 + get_underlying_scalar_constant(a[0]) + assert get_underlying_scalar_constant(a[1]) == 2 + assert get_underlying_scalar_constant(a[2]) == 3 - # For now get_scalar_constant_value goes through only MakeVector and Join of + # For now get_underlying_scalar_constant goes through only MakeVector and Join of # scalars. v = ivector() a = at.stack([v, [2], [3]]) with pytest.raises(NotScalarConstantError): - get_scalar_constant_value(a[0]) + get_underlying_scalar_constant(a[0]) with pytest.raises(NotScalarConstantError): - get_scalar_constant_value(a[1]) + get_underlying_scalar_constant(a[1]) with pytest.raises(NotScalarConstantError): - get_scalar_constant_value(a[2]) + get_underlying_scalar_constant(a[2]) # Test the case SubTensor(Shape(v)) when the dimensions # is broadcastable. v = row() - assert get_scalar_constant_value(v.shape[0]) == 1 + assert get_underlying_scalar_constant(v.shape[0]) == 1 - res = at.get_scalar_constant_value(at.as_tensor([10, 20]).shape[0]) + res = at.get_underlying_scalar_constant(at.as_tensor([10, 20]).shape[0]) assert isinstance(res, np.ndarray) assert 2 == res - res = at.get_scalar_constant_value( + res = at.get_underlying_scalar_constant( 9 + at.as_tensor([1.0]).shape[0], elemwise=True, only_process_constants=False, @@ -3320,63 +3320,63 @@ def test_basic(self): @pytest.mark.xfail(reason="Incomplete implementation") def test_DimShufle(self): a = as_tensor_variable(1.0)[None][0] - assert get_scalar_constant_value(a) == 1 + assert get_underlying_scalar_constant(a) == 1 def test_subtensor_of_constant(self): c = constant(random(5)) for i in range(c.value.shape[0]): - assert get_scalar_constant_value(c[i]) == c.value[i] + assert get_underlying_scalar_constant(c[i]) == c.value[i] c = constant(random(5, 5)) for i in range(c.value.shape[0]): for j in range(c.value.shape[1]): - assert get_scalar_constant_value(c[i, j]) == c.value[i, j] + assert get_underlying_scalar_constant(c[i, j]) == c.value[i, j] def test_numpy_array(self): # Regression test for crash when called on a numpy array. - assert get_scalar_constant_value(np.array(3)) == 3 + assert get_underlying_scalar_constant(np.array(3)) == 3 with pytest.raises(NotScalarConstantError): - get_scalar_constant_value(np.array([0, 1])) + get_underlying_scalar_constant(np.array([0, 1])) with pytest.raises(NotScalarConstantError): - get_scalar_constant_value(np.array([])) + get_underlying_scalar_constant(np.array([])) def test_make_vector(self): mv = make_vector(1, 2, 3) with pytest.raises(NotScalarConstantError): - get_scalar_constant_value(mv) - assert get_scalar_constant_value(mv[0]) == 1 - assert get_scalar_constant_value(mv[1]) == 2 - assert get_scalar_constant_value(mv[2]) == 3 - assert get_scalar_constant_value(mv[np.int32(0)]) == 1 - assert get_scalar_constant_value(mv[np.int64(1)]) == 2 - assert get_scalar_constant_value(mv[np.uint(2)]) == 3 + get_underlying_scalar_constant(mv) + assert get_underlying_scalar_constant(mv[0]) == 1 + assert get_underlying_scalar_constant(mv[1]) == 2 + assert get_underlying_scalar_constant(mv[2]) == 3 + assert get_underlying_scalar_constant(mv[np.int32(0)]) == 1 + assert get_underlying_scalar_constant(mv[np.int64(1)]) == 2 + assert get_underlying_scalar_constant(mv[np.uint(2)]) == 3 t = aes.ScalarType("int64") with pytest.raises(NotScalarConstantError): - get_scalar_constant_value(mv[t()]) + get_underlying_scalar_constant(mv[t()]) def test_shape_i(self): c = constant(np.random.random((3, 4))) s = Shape_i(0)(c) - assert get_scalar_constant_value(s) == 3 + assert get_underlying_scalar_constant(s) == 3 s = Shape_i(1)(c) - assert get_scalar_constant_value(s) == 4 + assert get_underlying_scalar_constant(s) == 4 d = pytensor.shared(np.random.standard_normal((1, 1)), shape=(1, 1)) f = ScalarFromTensor()(Shape_i(0)(d)) - assert get_scalar_constant_value(f) == 1 + assert get_underlying_scalar_constant(f) == 1 def test_elemwise(self): # We test only for a few elemwise, the list of all supported # elemwise are in the fct. c = constant(np.random.random()) s = c + 1 - assert np.allclose(get_scalar_constant_value(s), c.data + 1) + assert np.allclose(get_underlying_scalar_constant(s), c.data + 1) s = c - 1 - assert np.allclose(get_scalar_constant_value(s), c.data - 1) + assert np.allclose(get_underlying_scalar_constant(s), c.data - 1) s = c * 1.2 - assert np.allclose(get_scalar_constant_value(s), c.data * 1.2) + assert np.allclose(get_underlying_scalar_constant(s), c.data * 1.2) s = c < 0.5 - assert np.allclose(get_scalar_constant_value(s), int(c.data < 0.5)) + assert np.allclose(get_underlying_scalar_constant(s), int(c.data < 0.5)) s = at.second(c, 0.4) - assert np.allclose(get_scalar_constant_value(s), 0.4) + assert np.allclose(get_underlying_scalar_constant(s), 0.4) def test_assert(self): # Make sure we still get the constant value if it is wrapped in @@ -3386,25 +3386,25 @@ def test_assert(self): # condition is always True a = Assert()(c, c > 1) - assert get_scalar_constant_value(a) == 2 + assert get_underlying_scalar_constant(a) == 2 with config.change_flags(compute_test_value="off"): # condition is always False a = Assert()(c, c > 2) with pytest.raises(NotScalarConstantError): - get_scalar_constant_value(a) + get_underlying_scalar_constant(a) # condition is not constant a = Assert()(c, c > x) with pytest.raises(NotScalarConstantError): - get_scalar_constant_value(a) + get_underlying_scalar_constant(a) def test_second(self): # Second should apply when the value is constant but not the shape c = constant(np.random.random()) shp = vector() s = at.second(shp, c) - assert get_scalar_constant_value(s) == c.data + assert get_underlying_scalar_constant(s) == c.data def test_copy(self): # Make sure we do not return the internal storage of a constant, @@ -3418,14 +3418,14 @@ def test_copy(self): @pytest.mark.parametrize("only_process_constants", (True, False)) def test_None_and_NoneConst(self, only_process_constants): with pytest.raises(NotScalarConstantError): - get_scalar_constant_value( + get_underlying_scalar_constant( None, only_process_constants=only_process_constants ) assert ( - get_scalar_constant_value( + get_underlying_scalar_constant( NoneConst, only_process_constants=only_process_constants ) - is None + is None ) diff --git a/tests/tensor/test_elemwise.py b/tests/tensor/test_elemwise.py index 65bd61c656..9bb1fd3154 100644 --- a/tests/tensor/test_elemwise.py +++ b/tests/tensor/test_elemwise.py @@ -823,8 +823,8 @@ def test_partial_static_shape_info(self): assert len(res_shape) == 1 assert len(res_shape[0]) == 2 - assert pytensor.get_scalar_constant_value(res_shape[0][0]) == 1 - assert pytensor.get_scalar_constant_value(res_shape[0][1]) == 1 + assert pytensor.get_underlying_scalar_constant(res_shape[0][0]) == 1 + assert pytensor.get_underlying_scalar_constant(res_shape[0][1]) == 1 def test_multi_output(self): class CustomElemwise(Elemwise): diff --git a/tests/tensor/test_math.py b/tests/tensor/test_math.py index 60ff3178b5..7b8407c08b 100644 --- a/tests/tensor/test_math.py +++ b/tests/tensor/test_math.py @@ -27,7 +27,7 @@ as_tensor_variable, constant, eye, - get_scalar_constant_value, + get_underlying_scalar_constant, switch, ) from pytensor.tensor.elemwise import CAReduce, Elemwise @@ -894,7 +894,7 @@ def test_arg_grad(self): x = matrix() cost = argmax(x, axis=0).sum() gx = grad(cost, x) - val = get_scalar_constant_value(gx) + val = get_underlying_scalar_constant(gx) assert val == 0.0 def test_grad(self): From cc6340c3aa7b6960cb639d213cf3134f06b56357 Mon Sep 17 00:00:00 2001 From: Shreyas Singh Date: Tue, 14 Mar 2023 17:49:36 +0530 Subject: [PATCH 3/7] Add test for get_scalar_constant --- tests/tensor/test_basic.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py index 126b42d6ba..1dcca44539 100644 --- a/tests/tensor/test_basic.py +++ b/tests/tensor/test_basic.py @@ -51,6 +51,7 @@ flatnonzero, flatten, full_like, + get_scalar_constant, get_underlying_scalar_constant, get_vector_length, horizontal_stack, @@ -3429,6 +3430,12 @@ def test_None_and_NoneConst(self, only_process_constants): ) +def test_get_scalar_constant(): + with pytest.raises(NotScalarConstantError): + get_scalar_constant(np.zeros(5)) + assert get_scalar_constant(np.array(4)) == 4 + + def test_complex_mod_failure(): # Make sure % fails on complex numbers. x = vector(dtype="complex64") From cce055e2b1061dc54cb79836181496841b4efa03 Mon Sep 17 00:00:00 2001 From: Shreyas Singh Date: Sat, 18 Mar 2023 04:19:42 +0530 Subject: [PATCH 4/7] Modify docstrings for get_scalar_constant method --- pytensor/tensor/basic.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index 85b9a684c1..23bf88fa4d 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -257,7 +257,13 @@ def _obj_is_wrappable_as_tensor(x): def get_scalar_constant(v, elemwise=True, only_process_constants=False, max_recur=10): """ - Checks whether 'v' is a scalar based on 'ndim' + Checks whether 'v' is a scalar (ndim = 0). + + If 'v' is a scalar then this function fetches the underlying constant by calling + 'get_underlying_scalar_constant()'. + + If 'v' is not a scalar, it raises a NotScalarConstantError. + """ if isinstance(v, np.ndarray): data = v.data From bfbcc8ff43f3e9545044ac3bb2d6fdde673266c1 Mon Sep 17 00:00:00 2001 From: Shreyas Singh Date: Sat, 18 Mar 2023 04:21:29 +0530 Subject: [PATCH 5/7] Reformat with black --- pytensor/tensor/basic.py | 4 +++- pytensor/tensor/rewriting/elemwise.py | 7 ++++++- pytensor/tensor/rewriting/math.py | 8 ++++++-- pytensor/tensor/rewriting/subtensor.py | 5 ++++- tests/tensor/test_basic.py | 4 ++-- 5 files changed, 21 insertions(+), 7 deletions(-) diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index 23bf88fa4d..38d67a7bfc 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -269,7 +269,9 @@ def get_scalar_constant(v, elemwise=True, only_process_constants=False, max_recu data = v.data if data.ndim != 0: raise NotScalarConstantError() - return get_underlying_scalar_constant(v, elemwise, only_process_constants, max_recur) + return get_underlying_scalar_constant( + v, elemwise, only_process_constants, max_recur + ) def get_underlying_scalar_constant( diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index 5d231421d4..8eb2d891a3 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -22,7 +22,12 @@ ) from pytensor.graph.rewriting.db import SequenceDB from pytensor.graph.utils import InconsistencyError, MethodNotDefined -from pytensor.tensor.basic import MakeVector, alloc, cast, get_underlying_scalar_constant +from pytensor.tensor.basic import ( + MakeVector, + alloc, + cast, + get_underlying_scalar_constant, +) from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.rewriting.basic import register_canonicalize, register_specialize diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index 061835325f..5a66bf8a0f 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -1493,7 +1493,9 @@ def investigate(node): and investigate(node.inputs[0].owner) ): try: - cst = get_underlying_scalar_constant(node.inputs[1], only_process_constants=True) + cst = get_underlying_scalar_constant( + node.inputs[1], only_process_constants=True + ) res = zeros_like(node.inputs[0], dtype=dtype, opt=True) @@ -2329,7 +2331,9 @@ def local_abs_merge(fgraph, node): inputs.append(i.owner.inputs[0]) elif isinstance(i, Constant): try: - const = get_underlying_scalar_constant(i, only_process_constants=True) + const = get_underlying_scalar_constant( + i, only_process_constants=True + ) except NotScalarConstantError: return False if not (const >= 0).all(): diff --git a/pytensor/tensor/rewriting/subtensor.py b/pytensor/tensor/rewriting/subtensor.py index c6458129a4..51035106ab 100644 --- a/pytensor/tensor/rewriting/subtensor.py +++ b/pytensor/tensor/rewriting/subtensor.py @@ -1729,7 +1729,10 @@ def local_join_subtensors(fgraph, node): if step is None: continue try: - if get_underlying_scalar_constant(step, only_process_constants=True) != 1: + if ( + get_underlying_scalar_constant(step, only_process_constants=True) + != 1 + ): return None except NotScalarConstantError: return None diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py index 1dcca44539..23f1418260 100644 --- a/tests/tensor/test_basic.py +++ b/tests/tensor/test_basic.py @@ -3423,10 +3423,10 @@ def test_None_and_NoneConst(self, only_process_constants): None, only_process_constants=only_process_constants ) assert ( - get_underlying_scalar_constant( + get_underlying_scalar_constant( NoneConst, only_process_constants=only_process_constants ) - is None + is None ) From 7a01715028b4c836926d715122d9478471754511 Mon Sep 17 00:00:00 2001 From: Shreyas Singh Date: Wed, 22 Mar 2023 13:47:17 +0530 Subject: [PATCH 6/7] Refactor to get_scalar_constant_value and get_underlying_scalar_constant_value --- doc/library/tensor/basic.rst | 2 +- pytensor/__init__.py | 4 +- pytensor/link/jax/dispatch/tensor_basic.py | 6 +- pytensor/scan/basic.py | 6 +- pytensor/scan/rewriting.py | 10 ++- pytensor/tensor/basic.py | 44 ++++----- pytensor/tensor/blas.py | 2 +- pytensor/tensor/conv/abstract_conv.py | 19 ++-- pytensor/tensor/elemwise.py | 2 +- pytensor/tensor/exceptions.py | 2 +- pytensor/tensor/extra_ops.py | 2 +- pytensor/tensor/random/op.py | 4 +- pytensor/tensor/rewriting/basic.py | 6 +- pytensor/tensor/rewriting/elemwise.py | 4 +- pytensor/tensor/rewriting/math.py | 38 ++++---- pytensor/tensor/rewriting/shape.py | 6 +- pytensor/tensor/rewriting/subtensor.py | 24 ++--- pytensor/tensor/shape.py | 12 +-- pytensor/tensor/subtensor.py | 12 +-- pytensor/tensor/var.py | 2 +- tests/sparse/test_basic.py | 2 +- tests/tensor/test_basic.py | 100 ++++++++++----------- tests/tensor/test_math.py | 4 +- 23 files changed, 166 insertions(+), 147 deletions(-) diff --git a/doc/library/tensor/basic.rst b/doc/library/tensor/basic.rst index fda72a9c36..911583da92 100644 --- a/doc/library/tensor/basic.rst +++ b/doc/library/tensor/basic.rst @@ -577,7 +577,7 @@ them perfectly, but a `dscalar` otherwise. .. method:: round(mode="half_away_from_zero") :noindex: .. method:: trace() - .. method:: get_underlying_scalar_constant() + .. method:: get_underlying_scalar_constant_value() .. method:: zeros_like(model, dtype=None) All the above methods are equivalent to NumPy for PyTensor on the current tensor. diff --git a/pytensor/__init__.py b/pytensor/__init__.py index 789c465ce6..8e01416b93 100644 --- a/pytensor/__init__.py +++ b/pytensor/__init__.py @@ -153,8 +153,8 @@ def get_underlying_scalar_constant(v): if sparse and isinstance(v.type, sparse.SparseTensorType): if v.owner is not None and isinstance(v.owner.op, sparse.CSM): data = v.owner.inputs[0] - return tensor.get_underlying_scalar_constant(data) - return tensor.get_underlying_scalar_constant(v) + return tensor.get_underlying_scalar_constant_value(data) + return tensor.get_underlying_scalar_constant_value(v) # isort: off diff --git a/pytensor/link/jax/dispatch/tensor_basic.py b/pytensor/link/jax/dispatch/tensor_basic.py index 4545705e28..7981eb21a8 100644 --- a/pytensor/link/jax/dispatch/tensor_basic.py +++ b/pytensor/link/jax/dispatch/tensor_basic.py @@ -18,7 +18,7 @@ ScalarFromTensor, Split, TensorFromScalar, - get_underlying_scalar_constant, + get_underlying_scalar_constant_value, ) from pytensor.tensor.exceptions import NotScalarConstantError @@ -106,7 +106,7 @@ def join(axis, *tensors): def jax_funcify_Split(op: Split, node, **kwargs): _, axis, splits = node.inputs try: - constant_axis = get_underlying_scalar_constant(axis) + constant_axis = get_underlying_scalar_constant_value(axis) except NotScalarConstantError: constant_axis = None warnings.warn( @@ -116,7 +116,7 @@ def jax_funcify_Split(op: Split, node, **kwargs): try: constant_splits = np.array( [ - get_underlying_scalar_constant(splits[i]) + get_underlying_scalar_constant_value(splits[i]) for i in range(get_vector_length(splits)) ] ) diff --git a/pytensor/scan/basic.py b/pytensor/scan/basic.py index 973d4be7bf..d5109a0a9c 100644 --- a/pytensor/scan/basic.py +++ b/pytensor/scan/basic.py @@ -12,7 +12,7 @@ from pytensor.graph.utils import MissingInputError, TestValueError from pytensor.scan.op import Scan, ScanInfo from pytensor.scan.utils import expand_empty, safe_new, until -from pytensor.tensor.basic import get_underlying_scalar_constant +from pytensor.tensor.basic import get_underlying_scalar_constant_value from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.math import minimum from pytensor.tensor.shape import shape_padleft, unbroadcast @@ -147,7 +147,7 @@ def isNaN_or_Inf_or_None(x): isStr = False if not isNaN and not isInf: try: - val = get_underlying_scalar_constant(x) + val = get_underlying_scalar_constant_value(x) isInf = np.isinf(val) isNaN = np.isnan(val) except Exception: @@ -476,7 +476,7 @@ def wrap_into_list(x): n_fixed_steps = int(n_steps) else: try: - n_fixed_steps = at.get_underlying_scalar_constant(n_steps) + n_fixed_steps = at.get_underlying_scalar_constant_value(n_steps) except NotScalarConstantError: n_fixed_steps = None diff --git a/pytensor/scan/rewriting.py b/pytensor/scan/rewriting.py index abab3644c9..797d4c4062 100644 --- a/pytensor/scan/rewriting.py +++ b/pytensor/scan/rewriting.py @@ -49,7 +49,11 @@ safe_new, scan_can_remove_outs, ) -from pytensor.tensor.basic import Alloc, AllocEmpty, get_underlying_scalar_constant +from pytensor.tensor.basic import ( + Alloc, + AllocEmpty, + get_underlying_scalar_constant_value, +) from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.math import Dot, dot, maximum, minimum @@ -1956,13 +1960,13 @@ def belongs_to_set(self, node, set_nodes): nsteps = node.inputs[0] try: - nsteps = int(get_underlying_scalar_constant(nsteps)) + nsteps = int(get_underlying_scalar_constant_value(nsteps)) except NotScalarConstantError: pass rep_nsteps = rep.inputs[0] try: - rep_nsteps = int(get_underlying_scalar_constant(rep_nsteps)) + rep_nsteps = int(get_underlying_scalar_constant_value(rep_nsteps)) except NotScalarConstantError: pass diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index 38d67a7bfc..1cc6949e87 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -255,12 +255,14 @@ def _obj_is_wrappable_as_tensor(x): ) -def get_scalar_constant(v, elemwise=True, only_process_constants=False, max_recur=10): +def get_scalar_constant_value( + v, elemwise=True, only_process_constants=False, max_recur=10 +): """ Checks whether 'v' is a scalar (ndim = 0). If 'v' is a scalar then this function fetches the underlying constant by calling - 'get_underlying_scalar_constant()'. + 'get_underlying_scalar_constant_value()'. If 'v' is not a scalar, it raises a NotScalarConstantError. @@ -269,12 +271,12 @@ def get_scalar_constant(v, elemwise=True, only_process_constants=False, max_recu data = v.data if data.ndim != 0: raise NotScalarConstantError() - return get_underlying_scalar_constant( + return get_underlying_scalar_constant_value( v, elemwise, only_process_constants, max_recur ) -def get_underlying_scalar_constant( +def get_underlying_scalar_constant_value( orig_v, elemwise=True, only_process_constants=False, max_recur=10 ): """Return the constant scalar(0-D) value underlying variable `v`. @@ -377,7 +379,7 @@ def get_underlying_scalar_constant( elif isinstance(v.owner.op, CheckAndRaise): # check if all conditions are constant and true conds = [ - get_underlying_scalar_constant(c, max_recur=max_recur) + get_underlying_scalar_constant_value(c, max_recur=max_recur) for c in v.owner.inputs[1:] ] if builtins.all(0 == c.ndim and c != 0 for c in conds): @@ -391,7 +393,7 @@ def get_underlying_scalar_constant( continue if isinstance(v.owner.op, _scalar_constant_value_elemwise_ops): const = [ - get_underlying_scalar_constant(i, max_recur=max_recur) + get_underlying_scalar_constant_value(i, max_recur=max_recur) for i in v.owner.inputs ] ret = [[None]] @@ -410,7 +412,7 @@ def get_underlying_scalar_constant( v.owner.op.scalar_op, _scalar_constant_value_elemwise_ops ): const = [ - get_underlying_scalar_constant(i, max_recur=max_recur) + get_underlying_scalar_constant_value(i, max_recur=max_recur) for i in v.owner.inputs ] ret = [[None]] @@ -456,7 +458,7 @@ def get_underlying_scalar_constant( ): idx = v.owner.op.idx_list[0] if isinstance(idx, Type): - idx = get_underlying_scalar_constant( + idx = get_underlying_scalar_constant_value( v.owner.inputs[1], max_recur=max_recur ) try: @@ -490,14 +492,14 @@ def get_underlying_scalar_constant( ): idx = v.owner.op.idx_list[0] if isinstance(idx, Type): - idx = get_underlying_scalar_constant( + idx = get_underlying_scalar_constant_value( v.owner.inputs[1], max_recur=max_recur ) # Python 2.4 does not support indexing with numpy.integer # So we cast it. idx = int(idx) ret = v.owner.inputs[0].owner.inputs[idx] - ret = get_underlying_scalar_constant(ret, max_recur=max_recur) + ret = get_underlying_scalar_constant_value(ret, max_recur=max_recur) # MakeVector can cast implicitly its input in some case. return _asarray(ret, dtype=v.type.dtype) @@ -512,7 +514,7 @@ def get_underlying_scalar_constant( idx_list = op.idx_list idx = idx_list[0] if isinstance(idx, Type): - idx = get_underlying_scalar_constant( + idx = get_underlying_scalar_constant_value( owner.inputs[1], max_recur=max_recur ) grandparent = leftmost_parent.owner.inputs[0] @@ -527,7 +529,7 @@ def get_underlying_scalar_constant( if not (idx < ndim): msg = ( - "get_underlying_scalar_constant detected " + "get_underlying_scalar_constant_value detected " f"deterministic IndexError: x.shape[{int(idx)}] " f"when x.ndim={int(ndim)}." ) @@ -1589,7 +1591,7 @@ def do_constant_folding(self, fgraph, node): @_get_vector_length.register(Alloc) def _get_vector_length_Alloc(var_inst, var): try: - return get_underlying_scalar_constant(var.owner.inputs[1]) + return get_underlying_scalar_constant_value(var.owner.inputs[1]) except NotScalarConstantError: raise ValueError(f"Length of {var} cannot be determined") @@ -1840,17 +1842,17 @@ def perform(self, node, inp, out_): def extract_constant(x, elemwise=True, only_process_constants=False): """ - This function is basically a call to tensor.get_underlying_scalar_constant. + This function is basically a call to tensor.get_underlying_scalar_constant_value. The main difference is the behaviour in case of failure. While - get_underlying_scalar_constant raises an TypeError, this function returns x, + get_underlying_scalar_constant_value raises an TypeError, this function returns x, as a tensor if possible. If x is a ScalarVariable from a scalar_from_tensor, we remove the conversion. If x is just a ScalarVariable, we convert it to a tensor with tensor_from_scalar. """ try: - x = get_underlying_scalar_constant(x, elemwise, only_process_constants) + x = get_underlying_scalar_constant_value(x, elemwise, only_process_constants) except NotScalarConstantError: pass if isinstance(x, aes.ScalarVariable) or isinstance( @@ -2220,7 +2222,7 @@ def make_node(self, axis, *tensors): if not isinstance(axis, int): try: - axis = int(get_underlying_scalar_constant(axis)) + axis = int(get_underlying_scalar_constant_value(axis)) except NotScalarConstantError: pass @@ -2469,7 +2471,7 @@ def infer_shape(self, fgraph, node, ishapes): def _get_vector_length_Join(op, var): axis, *arrays = var.owner.inputs try: - axis = get_underlying_scalar_constant(axis) + axis = get_underlying_scalar_constant_value(axis) assert axis == 0 and builtins.all(a.ndim == 1 for a in arrays) return builtins.sum(get_vector_length(a) for a in arrays) except NotScalarConstantError: @@ -2881,7 +2883,7 @@ def infer_shape(self, fgraph, node, i_shapes): def is_constant_value(var, value): try: - v = get_underlying_scalar_constant(var) + v = get_underlying_scalar_constant_value(var) return np.all(v == value) except NotScalarConstantError: pass @@ -4113,8 +4115,8 @@ def take_along_axis(arr, indices, axis=0): "cast", "scalar_from_tensor", "tensor_from_scalar", - "get_scalar_constant", - "get_underlying_scalar_constant", + "get_scalar_constant_value", + "get_underlying_scalar_constant_value", "constant", "as_tensor_variable", "as_tensor", diff --git a/pytensor/tensor/blas.py b/pytensor/tensor/blas.py index ae3324fc1d..1282cabae5 100644 --- a/pytensor/tensor/blas.py +++ b/pytensor/tensor/blas.py @@ -1834,7 +1834,7 @@ def local_gemm_to_ger(fgraph, node): xv = x.dimshuffle(0) yv = y.dimshuffle(1) try: - bval = at.get_underlying_scalar_constant(b) + bval = at.get_underlying_scalar_constant_value(b) except NotScalarConstantError: # b isn't a constant, GEMM is doing useful pre-scaling return diff --git a/pytensor/tensor/conv/abstract_conv.py b/pytensor/tensor/conv/abstract_conv.py index 53c591bdd0..b747cb3a28 100644 --- a/pytensor/tensor/conv/abstract_conv.py +++ b/pytensor/tensor/conv/abstract_conv.py @@ -24,7 +24,10 @@ from pytensor.graph.basic import Apply, Variable from pytensor.graph.op import Op from pytensor.raise_op import Assert -from pytensor.tensor.basic import as_tensor_variable, get_underlying_scalar_constant +from pytensor.tensor.basic import ( + as_tensor_variable, + get_underlying_scalar_constant_value, +) from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.var import TensorConstant, TensorVariable @@ -495,8 +498,8 @@ def check_dim(given, computed): if given is None or computed is None: return True try: - given = get_underlying_scalar_constant(given) - computed = get_underlying_scalar_constant(computed) + given = get_underlying_scalar_constant_value(given) + computed = get_underlying_scalar_constant_value(computed) return int(given) == int(computed) except NotScalarConstantError: # no answer possible, accept for now @@ -532,7 +535,7 @@ def assert_conv_shape(shape): out_shape = [] for i, n in enumerate(shape): try: - const_n = get_underlying_scalar_constant(n) + const_n = get_underlying_scalar_constant_value(n) if i < 2: if const_n < 0: raise ValueError( @@ -2200,7 +2203,9 @@ def __init__( if imshp_i is not None: # Components of imshp should be constant or ints try: - get_underlying_scalar_constant(imshp_i, only_process_constants=True) + get_underlying_scalar_constant_value( + imshp_i, only_process_constants=True + ) except NotScalarConstantError: raise ValueError( "imshp should be None or a tuple of constant int values" @@ -2213,7 +2218,9 @@ def __init__( if kshp_i is not None: # Components of kshp should be constant or ints try: - get_underlying_scalar_constant(kshp_i, only_process_constants=True) + get_underlying_scalar_constant_value( + kshp_i, only_process_constants=True + ) except NotScalarConstantError: raise ValueError( "kshp should be None or a tuple of constant int values" diff --git a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py index 87992c7235..bbbd3831f2 100644 --- a/pytensor/tensor/elemwise.py +++ b/pytensor/tensor/elemwise.py @@ -759,7 +759,7 @@ def perform(self, node, inputs, output_storage): ufunc = self.ufunc elif not hasattr(node.tag, "ufunc"): # It happen that make_thunk isn't called, like in - # get_underlying_scalar_constant + # get_underlying_scalar_constant_value self.prepare_node(node, None, None, "py") # prepare_node will add ufunc to self or the tag # depending if we can reuse it or not. So we need to diff --git a/pytensor/tensor/exceptions.py b/pytensor/tensor/exceptions.py index 59ad52fada..c4f107b3bd 100644 --- a/pytensor/tensor/exceptions.py +++ b/pytensor/tensor/exceptions.py @@ -4,7 +4,7 @@ class ShapeError(Exception): class NotScalarConstantError(Exception): """ - Raised by get_underlying_scalar_constant if called on something that is + Raised by get_underlying_scalar_constant_value if called on something that is not a scalar constant. """ diff --git a/pytensor/tensor/extra_ops.py b/pytensor/tensor/extra_ops.py index 5025d6e382..bd3cf71fb4 100644 --- a/pytensor/tensor/extra_ops.py +++ b/pytensor/tensor/extra_ops.py @@ -671,7 +671,7 @@ def make_node(self, x, repeats): out_shape = [None] else: try: - const_reps = at.get_underlying_scalar_constant(repeats) + const_reps = at.get_underlying_scalar_constant_value(repeats) except NotScalarConstantError: const_reps = None if const_reps == 1: diff --git a/pytensor/tensor/random/op.py b/pytensor/tensor/random/op.py index e04b56dd64..5a3b4aea19 100644 --- a/pytensor/tensor/random/op.py +++ b/pytensor/tensor/random/op.py @@ -12,7 +12,7 @@ from pytensor.tensor.basic import ( as_tensor_variable, constant, - get_underlying_scalar_constant, + get_underlying_scalar_constant_value, get_vector_length, infer_static_shape, ) @@ -277,7 +277,7 @@ def infer_shape(self, fgraph, node, input_shapes): try: size_len = get_vector_length(size) except ValueError: - size_len = get_underlying_scalar_constant(size_shape[0]) + size_len = get_underlying_scalar_constant_value(size_shape[0]) size = tuple(size[n] for n in range(size_len)) diff --git a/pytensor/tensor/rewriting/basic.py b/pytensor/tensor/rewriting/basic.py index 14dce47b1d..c3bb653870 100644 --- a/pytensor/tensor/rewriting/basic.py +++ b/pytensor/tensor/rewriting/basic.py @@ -32,7 +32,7 @@ cast, extract_constant, fill, - get_underlying_scalar_constant, + get_underlying_scalar_constant_value, join, ones_like, switch, @@ -802,7 +802,7 @@ def local_remove_useless_assert(fgraph, node): n_conds = len(node.inputs[1:]) for c in node.inputs[1:]: try: - const = get_underlying_scalar_constant(c) + const = get_underlying_scalar_constant_value(c) if 0 != const.ndim or const == 0: # Should we raise an error here? How to be sure it @@ -895,7 +895,7 @@ def local_join_empty(fgraph, node): return new_inputs = [] try: - join_idx = get_underlying_scalar_constant( + join_idx = get_underlying_scalar_constant_value( node.inputs[0], only_process_constants=True ) except NotScalarConstantError: diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index 8eb2d891a3..d3be0e2079 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -26,7 +26,7 @@ MakeVector, alloc, cast, - get_underlying_scalar_constant, + get_underlying_scalar_constant_value, ) from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise from pytensor.tensor.exceptions import NotScalarConstantError @@ -500,7 +500,7 @@ def local_upcast_elemwise_constant_inputs(fgraph, node): else: try: # works only for scalars - cval_i = get_underlying_scalar_constant( + cval_i = get_underlying_scalar_constant_value( i, only_process_constants=True ) if all(i.broadcastable): diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index 5a66bf8a0f..9eab75bfe9 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -31,7 +31,7 @@ constant, extract_constant, fill, - get_underlying_scalar_constant, + get_underlying_scalar_constant_value, ones_like, switch, zeros_like, @@ -112,7 +112,7 @@ def scalarconsts_rest(inputs, elemwise=True, only_process_constants=False): nonconsts = [] for i in inputs: try: - v = get_underlying_scalar_constant( + v = get_underlying_scalar_constant_value( i, elemwise=elemwise, only_process_constants=only_process_constants ) consts.append(v) @@ -165,13 +165,13 @@ def local_0_dot_x(fgraph, node): y = node.inputs[1] replace = False try: - if get_underlying_scalar_constant(x, only_process_constants=True) == 0: + if get_underlying_scalar_constant_value(x, only_process_constants=True) == 0: replace = True except NotScalarConstantError: pass try: - if get_underlying_scalar_constant(y, only_process_constants=True) == 0: + if get_underlying_scalar_constant_value(y, only_process_constants=True) == 0: replace = True except NotScalarConstantError: pass @@ -585,7 +585,7 @@ def local_mul_switch_sink(fgraph, node): switch_node = i.owner try: if ( - get_underlying_scalar_constant( + get_underlying_scalar_constant_value( switch_node.inputs[1], only_process_constants=True ) == 0.0 @@ -613,7 +613,7 @@ def local_mul_switch_sink(fgraph, node): pass try: if ( - get_underlying_scalar_constant( + get_underlying_scalar_constant_value( switch_node.inputs[2], only_process_constants=True ) == 0.0 @@ -665,7 +665,7 @@ def local_div_switch_sink(fgraph, node): switch_node = node.inputs[0].owner try: if ( - get_underlying_scalar_constant( + get_underlying_scalar_constant_value( switch_node.inputs[1], only_process_constants=True ) == 0.0 @@ -691,7 +691,7 @@ def local_div_switch_sink(fgraph, node): pass try: if ( - get_underlying_scalar_constant( + get_underlying_scalar_constant_value( switch_node.inputs[2], only_process_constants=True ) == 0.0 @@ -1493,7 +1493,7 @@ def investigate(node): and investigate(node.inputs[0].owner) ): try: - cst = get_underlying_scalar_constant( + cst = get_underlying_scalar_constant_value( node.inputs[1], only_process_constants=True ) @@ -1735,7 +1735,7 @@ def local_reduce_join(fgraph, node): # We add the new check late to don't add extra warning. try: - join_axis = get_underlying_scalar_constant( + join_axis = get_underlying_scalar_constant_value( join_node.inputs[0], only_process_constants=True ) @@ -1818,7 +1818,9 @@ def local_opt_alloc(fgraph, node): inp = node_inps.owner.inputs[0] shapes = node_inps.owner.inputs[1:] try: - val = get_underlying_scalar_constant(inp, only_process_constants=True) + val = get_underlying_scalar_constant_value( + inp, only_process_constants=True + ) assert val.size == 1 val = val.reshape(1)[0] # check which type of op @@ -1950,7 +1952,7 @@ def local_mul_zero(fgraph, node): for i in node.inputs: try: - value = get_underlying_scalar_constant(i) + value = get_underlying_scalar_constant_value(i) except NotScalarConstantError: continue # print 'MUL by value', value, node.inputs @@ -2232,7 +2234,7 @@ def local_add_specialize(fgraph, node): new_inputs = [] for inp in node.inputs: try: - y = get_underlying_scalar_constant(inp) + y = get_underlying_scalar_constant_value(inp) except NotScalarConstantError: y = inp if np.all(y == 0.0): @@ -2331,7 +2333,7 @@ def local_abs_merge(fgraph, node): inputs.append(i.owner.inputs[0]) elif isinstance(i, Constant): try: - const = get_underlying_scalar_constant( + const = get_underlying_scalar_constant_value( i, only_process_constants=True ) except NotScalarConstantError: @@ -2882,7 +2884,7 @@ def check_input(inputs): mul_neg = mul(*mul_inputs) try: - cst2 = get_underlying_scalar_constant( + cst2 = get_underlying_scalar_constant_value( mul_neg.owner.inputs[0], only_process_constants=True ) except NotScalarConstantError: @@ -2916,7 +2918,7 @@ def check_input(inputs): x = erfc_x try: - cst = get_underlying_scalar_constant( + cst = get_underlying_scalar_constant_value( erfc_x.owner.inputs[0], only_process_constants=True ) except NotScalarConstantError: @@ -2983,7 +2985,7 @@ def _is_1(expr): """ try: - v = get_underlying_scalar_constant(expr) + v = get_underlying_scalar_constant_value(expr) return np.allclose(v, 1) except NotScalarConstantError: return False @@ -3151,7 +3153,7 @@ def is_neg(var): if var_node.op == mul and len(var_node.inputs) >= 2: for idx, mul_input in enumerate(var_node.inputs): try: - constant = get_underlying_scalar_constant(mul_input) + constant = get_underlying_scalar_constant_value(mul_input) is_minus_1 = np.allclose(constant, -1) except NotScalarConstantError: is_minus_1 = False diff --git a/pytensor/tensor/rewriting/shape.py b/pytensor/tensor/rewriting/shape.py index 0b6b2707f9..7b9c80c96c 100644 --- a/pytensor/tensor/rewriting/shape.py +++ b/pytensor/tensor/rewriting/shape.py @@ -24,7 +24,7 @@ cast, constant, extract_constant, - get_underlying_scalar_constant, + get_underlying_scalar_constant_value, stack, ) from pytensor.tensor.elemwise import DimShuffle, Elemwise @@ -226,7 +226,7 @@ def shape_ir(self, i, r): # Do not call make_node for test_value s = Shape_i(i)(r) try: - s = get_underlying_scalar_constant(s) + s = get_underlying_scalar_constant_value(s) except NotScalarConstantError: pass return s @@ -310,7 +310,7 @@ def unpack(self, s_i, var): assert len(idx) == 1 idx = idx[0] try: - i = get_underlying_scalar_constant(idx) + i = get_underlying_scalar_constant_value(idx) except NotScalarConstantError: pass else: diff --git a/pytensor/tensor/rewriting/subtensor.py b/pytensor/tensor/rewriting/subtensor.py index 51035106ab..4eff892903 100644 --- a/pytensor/tensor/rewriting/subtensor.py +++ b/pytensor/tensor/rewriting/subtensor.py @@ -25,7 +25,7 @@ cast, concatenate, extract_constant, - get_underlying_scalar_constant, + get_underlying_scalar_constant_value, switch, ) from pytensor.tensor.elemwise import Elemwise @@ -756,7 +756,9 @@ def local_subtensor_make_vector(fgraph, node): elif isinstance(idx, Variable): if idx.ndim == 0: try: - v = get_underlying_scalar_constant(idx, only_process_constants=True) + v = get_underlying_scalar_constant_value( + idx, only_process_constants=True + ) try: ret = [x.owner.inputs[v]] except IndexError: @@ -808,7 +810,7 @@ def local_useless_inc_subtensor(fgraph, node): # This is an increment operation, so the array being incremented must # consist of all zeros in order for the entire operation to be useless try: - c = get_underlying_scalar_constant(x) + c = get_underlying_scalar_constant_value(x) if c != 0: return except NotScalarConstantError: @@ -927,7 +929,7 @@ def local_useless_subtensor(fgraph, node): if isinstance(idx.stop, (int, np.integer)): length_pos_data = sys.maxsize try: - length_pos_data = get_underlying_scalar_constant( + length_pos_data = get_underlying_scalar_constant_value( length_pos, only_process_constants=True ) except NotScalarConstantError: @@ -992,7 +994,7 @@ def local_useless_AdvancedSubtensor1(fgraph, node): # get length of the indexed tensor along the first axis try: - length = get_underlying_scalar_constant( + length = get_underlying_scalar_constant_value( shape_of[node.inputs[0]][0], only_process_constants=True ) except NotScalarConstantError: @@ -1329,7 +1331,7 @@ def local_incsubtensor_of_zeros(fgraph, node): try: # Don't use only_process_constants=True. We need to # investigate Alloc of 0s but with non constant shape. - if get_underlying_scalar_constant(y, elemwise=False) == 0: + if get_underlying_scalar_constant_value(y, elemwise=False) == 0: # No need to copy over the stacktrace, # because x should already have a stacktrace return [x] @@ -1375,12 +1377,12 @@ def local_setsubtensor_of_constants(fgraph, node): # Don't use only_process_constants=True. We need to # investigate Alloc of 0s but with non constant shape. try: - replace_x = get_underlying_scalar_constant(x, elemwise=False) + replace_x = get_underlying_scalar_constant_value(x, elemwise=False) except NotScalarConstantError: return try: - replace_y = get_underlying_scalar_constant(y, elemwise=False) + replace_y = get_underlying_scalar_constant_value(y, elemwise=False) except NotScalarConstantError: return @@ -1668,7 +1670,7 @@ def local_join_subtensors(fgraph, node): axis, tensors = node.inputs[0], node.inputs[1:] try: - axis = get_underlying_scalar_constant(axis) + axis = get_underlying_scalar_constant_value(axis) except NotScalarConstantError: return @@ -1730,7 +1732,9 @@ def local_join_subtensors(fgraph, node): continue try: if ( - get_underlying_scalar_constant(step, only_process_constants=True) + get_underlying_scalar_constant_value( + step, only_process_constants=True + ) != 1 ): return None diff --git a/pytensor/tensor/shape.py b/pytensor/tensor/shape.py index ecffc7ba3b..ffb407a121 100644 --- a/pytensor/tensor/shape.py +++ b/pytensor/tensor/shape.py @@ -397,7 +397,7 @@ class SpecifyShape(COp): _f16_ok = True def make_node(self, x, *shape): - from pytensor.tensor.basic import get_underlying_scalar_constant + from pytensor.tensor.basic import get_underlying_scalar_constant_value x = at.as_tensor_variable(x) @@ -426,7 +426,7 @@ def make_node(self, x, *shape): type_shape[i] = xts else: try: - type_s = get_underlying_scalar_constant(s) + type_s = get_underlying_scalar_constant_value(s) if type_s is not None: type_shape[i] = int(type_s) except NotScalarConstantError: @@ -457,9 +457,9 @@ def infer_shape(self, fgraph, node, shapes): for dim in range(node.inputs[0].type.ndim): s = shape[dim] try: - s = at.get_underlying_scalar_constant(s) + s = at.get_underlying_scalar_constant_value(s) # We assume that `None` shapes are always retrieved by - # `get_underlying_scalar_constant`, and only in that case do we default to + # `get_underlying_scalar_constant_value`, and only in that case do we default to # the shape of the input variable if s is None: s = xshape[dim] @@ -581,7 +581,7 @@ def specify_shape( @_get_vector_length.register(SpecifyShape) def _get_vector_length_SpecifyShape(op, var): try: - return at.get_underlying_scalar_constant(var.owner.inputs[1]).item() + return at.get_underlying_scalar_constant_value(var.owner.inputs[1]).item() except NotScalarConstantError: raise ValueError(f"Length of {var} cannot be determined") @@ -635,7 +635,7 @@ def make_node(self, x, shp): y = shp_list[index] y = at.as_tensor_variable(y) try: - s_val = at.get_underlying_scalar_constant(y).item() + s_val = at.get_underlying_scalar_constant_value(y).item() if s_val >= 0: out_shape[index] = s_val except NotScalarConstantError: diff --git a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py index 9afc3d2197..95534770ab 100644 --- a/pytensor/tensor/subtensor.py +++ b/pytensor/tensor/subtensor.py @@ -20,7 +20,7 @@ from pytensor.printing import Printer, pprint, set_precedence from pytensor.scalar.basic import ScalarConstant from pytensor.tensor import _get_vector_length, as_tensor_variable, get_vector_length -from pytensor.tensor.basic import alloc, get_underlying_scalar_constant +from pytensor.tensor.basic import alloc, get_underlying_scalar_constant_value from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.exceptions import ( AdvancedIndexingError, @@ -656,7 +656,7 @@ def conv(val): return slice(conv(val.start), conv(val.stop), conv(val.step)) else: try: - return get_underlying_scalar_constant( + return get_underlying_scalar_constant_value( val, only_process_constants=only_process_constants, elemwise=elemwise, @@ -733,7 +733,7 @@ def make_node(self, x, *inputs): if s == 1: start = p.start try: - start = get_underlying_scalar_constant(start) + start = get_underlying_scalar_constant_value(start) except NotScalarConstantError: pass if start is None or start == 0: @@ -2808,17 +2808,17 @@ def _get_vector_length_Subtensor(op, var): start = ( None if indices[0].start is None - else get_underlying_scalar_constant(indices[0].start) + else get_underlying_scalar_constant_value(indices[0].start) ) stop = ( None if indices[0].stop is None - else get_underlying_scalar_constant(indices[0].stop) + else get_underlying_scalar_constant_value(indices[0].stop) ) step = ( None if indices[0].step is None - else get_underlying_scalar_constant(indices[0].step) + else get_underlying_scalar_constant_value(indices[0].step) ) if start == stop: diff --git a/pytensor/tensor/var.py b/pytensor/tensor/var.py index 7da43ffbc4..428cd8fffd 100644 --- a/pytensor/tensor/var.py +++ b/pytensor/tensor/var.py @@ -757,7 +757,7 @@ def trace(self): __array_priority__ = 1000 def get_underlying_scalar_constant(self): - return at.basic.get_underlying_scalar_constant(self) + return at.basic.get_underlying_scalar_constant_value(self) def zeros_like(model, dtype=None): return at.basic.zeros_like(model, dtype=dtype) diff --git a/tests/sparse/test_basic.py b/tests/sparse/test_basic.py index ec8509d202..0c6d59b064 100644 --- a/tests/sparse/test_basic.py +++ b/tests/sparse/test_basic.py @@ -1043,7 +1043,7 @@ def test_basic(self): from pytensor.tensor.exceptions import NotScalarConstantError with pytest.raises(NotScalarConstantError): - at.get_underlying_scalar_constant(s, only_process_constants=True) + at.get_underlying_scalar_constant_value(s, only_process_constants=True) # TODO: # def test_sparse_as_tensor_variable(self): diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py index 23f1418260..9d243a7adc 100644 --- a/tests/tensor/test_basic.py +++ b/tests/tensor/test_basic.py @@ -51,8 +51,8 @@ flatnonzero, flatten, full_like, - get_scalar_constant, - get_underlying_scalar_constant, + get_scalar_constant_value, + get_underlying_scalar_constant_value, get_vector_length, horizontal_stack, identity_like, @@ -3264,52 +3264,52 @@ def test_dimshuffle_duplicate(): DimShuffle((False,), (0, 0))(x) -class TestGetUnderlyingScalarConstant: +class TestGetUnderlyingScalarConstantValue: def test_basic(self): with pytest.raises(NotScalarConstantError): - get_underlying_scalar_constant(aes.int64()) + get_underlying_scalar_constant_value(aes.int64()) - res = get_underlying_scalar_constant(at.as_tensor(10)) + res = get_underlying_scalar_constant_value(at.as_tensor(10)) assert res == 10 assert isinstance(res, np.ndarray) - res = get_underlying_scalar_constant(np.array(10)) + res = get_underlying_scalar_constant_value(np.array(10)) assert res == 10 assert isinstance(res, np.ndarray) a = at.stack([1, 2, 3]) - assert get_underlying_scalar_constant(a[0]) == 1 - assert get_underlying_scalar_constant(a[1]) == 2 - assert get_underlying_scalar_constant(a[2]) == 3 + assert get_underlying_scalar_constant_value(a[0]) == 1 + assert get_underlying_scalar_constant_value(a[1]) == 2 + assert get_underlying_scalar_constant_value(a[2]) == 3 b = iscalar() a = at.stack([b, 2, 3]) with pytest.raises(NotScalarConstantError): - get_underlying_scalar_constant(a[0]) - assert get_underlying_scalar_constant(a[1]) == 2 - assert get_underlying_scalar_constant(a[2]) == 3 + get_underlying_scalar_constant_value(a[0]) + assert get_underlying_scalar_constant_value(a[1]) == 2 + assert get_underlying_scalar_constant_value(a[2]) == 3 - # For now get_underlying_scalar_constant goes through only MakeVector and Join of + # For now get_underlying_scalar_constant_value goes through only MakeVector and Join of # scalars. v = ivector() a = at.stack([v, [2], [3]]) with pytest.raises(NotScalarConstantError): - get_underlying_scalar_constant(a[0]) + get_underlying_scalar_constant_value(a[0]) with pytest.raises(NotScalarConstantError): - get_underlying_scalar_constant(a[1]) + get_underlying_scalar_constant_value(a[1]) with pytest.raises(NotScalarConstantError): - get_underlying_scalar_constant(a[2]) + get_underlying_scalar_constant_value(a[2]) # Test the case SubTensor(Shape(v)) when the dimensions # is broadcastable. v = row() - assert get_underlying_scalar_constant(v.shape[0]) == 1 + assert get_underlying_scalar_constant_value(v.shape[0]) == 1 - res = at.get_underlying_scalar_constant(at.as_tensor([10, 20]).shape[0]) + res = at.get_underlying_scalar_constant_value(at.as_tensor([10, 20]).shape[0]) assert isinstance(res, np.ndarray) assert 2 == res - res = at.get_underlying_scalar_constant( + res = at.get_underlying_scalar_constant_value( 9 + at.as_tensor([1.0]).shape[0], elemwise=True, only_process_constants=False, @@ -3321,63 +3321,63 @@ def test_basic(self): @pytest.mark.xfail(reason="Incomplete implementation") def test_DimShufle(self): a = as_tensor_variable(1.0)[None][0] - assert get_underlying_scalar_constant(a) == 1 + assert get_underlying_scalar_constant_value(a) == 1 def test_subtensor_of_constant(self): c = constant(random(5)) for i in range(c.value.shape[0]): - assert get_underlying_scalar_constant(c[i]) == c.value[i] + assert get_underlying_scalar_constant_value(c[i]) == c.value[i] c = constant(random(5, 5)) for i in range(c.value.shape[0]): for j in range(c.value.shape[1]): - assert get_underlying_scalar_constant(c[i, j]) == c.value[i, j] + assert get_underlying_scalar_constant_value(c[i, j]) == c.value[i, j] def test_numpy_array(self): # Regression test for crash when called on a numpy array. - assert get_underlying_scalar_constant(np.array(3)) == 3 + assert get_underlying_scalar_constant_value(np.array(3)) == 3 with pytest.raises(NotScalarConstantError): - get_underlying_scalar_constant(np.array([0, 1])) + get_underlying_scalar_constant_value(np.array([0, 1])) with pytest.raises(NotScalarConstantError): - get_underlying_scalar_constant(np.array([])) + get_underlying_scalar_constant_value(np.array([])) def test_make_vector(self): mv = make_vector(1, 2, 3) with pytest.raises(NotScalarConstantError): - get_underlying_scalar_constant(mv) - assert get_underlying_scalar_constant(mv[0]) == 1 - assert get_underlying_scalar_constant(mv[1]) == 2 - assert get_underlying_scalar_constant(mv[2]) == 3 - assert get_underlying_scalar_constant(mv[np.int32(0)]) == 1 - assert get_underlying_scalar_constant(mv[np.int64(1)]) == 2 - assert get_underlying_scalar_constant(mv[np.uint(2)]) == 3 + get_underlying_scalar_constant_value(mv) + assert get_underlying_scalar_constant_value(mv[0]) == 1 + assert get_underlying_scalar_constant_value(mv[1]) == 2 + assert get_underlying_scalar_constant_value(mv[2]) == 3 + assert get_underlying_scalar_constant_value(mv[np.int32(0)]) == 1 + assert get_underlying_scalar_constant_value(mv[np.int64(1)]) == 2 + assert get_underlying_scalar_constant_value(mv[np.uint(2)]) == 3 t = aes.ScalarType("int64") with pytest.raises(NotScalarConstantError): - get_underlying_scalar_constant(mv[t()]) + get_underlying_scalar_constant_value(mv[t()]) def test_shape_i(self): c = constant(np.random.random((3, 4))) s = Shape_i(0)(c) - assert get_underlying_scalar_constant(s) == 3 + assert get_underlying_scalar_constant_value(s) == 3 s = Shape_i(1)(c) - assert get_underlying_scalar_constant(s) == 4 + assert get_underlying_scalar_constant_value(s) == 4 d = pytensor.shared(np.random.standard_normal((1, 1)), shape=(1, 1)) f = ScalarFromTensor()(Shape_i(0)(d)) - assert get_underlying_scalar_constant(f) == 1 + assert get_underlying_scalar_constant_value(f) == 1 def test_elemwise(self): # We test only for a few elemwise, the list of all supported # elemwise are in the fct. c = constant(np.random.random()) s = c + 1 - assert np.allclose(get_underlying_scalar_constant(s), c.data + 1) + assert np.allclose(get_underlying_scalar_constant_value(s), c.data + 1) s = c - 1 - assert np.allclose(get_underlying_scalar_constant(s), c.data - 1) + assert np.allclose(get_underlying_scalar_constant_value(s), c.data - 1) s = c * 1.2 - assert np.allclose(get_underlying_scalar_constant(s), c.data * 1.2) + assert np.allclose(get_underlying_scalar_constant_value(s), c.data * 1.2) s = c < 0.5 - assert np.allclose(get_underlying_scalar_constant(s), int(c.data < 0.5)) + assert np.allclose(get_underlying_scalar_constant_value(s), int(c.data < 0.5)) s = at.second(c, 0.4) - assert np.allclose(get_underlying_scalar_constant(s), 0.4) + assert np.allclose(get_underlying_scalar_constant_value(s), 0.4) def test_assert(self): # Make sure we still get the constant value if it is wrapped in @@ -3387,25 +3387,25 @@ def test_assert(self): # condition is always True a = Assert()(c, c > 1) - assert get_underlying_scalar_constant(a) == 2 + assert get_underlying_scalar_constant_value(a) == 2 with config.change_flags(compute_test_value="off"): # condition is always False a = Assert()(c, c > 2) with pytest.raises(NotScalarConstantError): - get_underlying_scalar_constant(a) + get_underlying_scalar_constant_value(a) # condition is not constant a = Assert()(c, c > x) with pytest.raises(NotScalarConstantError): - get_underlying_scalar_constant(a) + get_underlying_scalar_constant_value(a) def test_second(self): # Second should apply when the value is constant but not the shape c = constant(np.random.random()) shp = vector() s = at.second(shp, c) - assert get_underlying_scalar_constant(s) == c.data + assert get_underlying_scalar_constant_value(s) == c.data def test_copy(self): # Make sure we do not return the internal storage of a constant, @@ -3419,21 +3419,21 @@ def test_copy(self): @pytest.mark.parametrize("only_process_constants", (True, False)) def test_None_and_NoneConst(self, only_process_constants): with pytest.raises(NotScalarConstantError): - get_underlying_scalar_constant( + get_underlying_scalar_constant_value( None, only_process_constants=only_process_constants ) assert ( - get_underlying_scalar_constant( + get_underlying_scalar_constant_value( NoneConst, only_process_constants=only_process_constants ) is None ) -def test_get_scalar_constant(): +def test_get_scalar_constant_value(): with pytest.raises(NotScalarConstantError): - get_scalar_constant(np.zeros(5)) - assert get_scalar_constant(np.array(4)) == 4 + get_scalar_constant_value(np.zeros(5)) + assert get_scalar_constant_value(np.array(4)) == 4 def test_complex_mod_failure(): diff --git a/tests/tensor/test_math.py b/tests/tensor/test_math.py index 7b8407c08b..8ee5ac4544 100644 --- a/tests/tensor/test_math.py +++ b/tests/tensor/test_math.py @@ -27,7 +27,7 @@ as_tensor_variable, constant, eye, - get_underlying_scalar_constant, + get_underlying_scalar_constant_value, switch, ) from pytensor.tensor.elemwise import CAReduce, Elemwise @@ -894,7 +894,7 @@ def test_arg_grad(self): x = matrix() cost = argmax(x, axis=0).sum() gx = grad(cost, x) - val = get_underlying_scalar_constant(gx) + val = get_underlying_scalar_constant_value(gx) assert val == 0.0 def test_grad(self): From 1981bb546e67676d965b0aec53b43223c6ae3e9b Mon Sep 17 00:00:00 2001 From: Shreyas Singh Date: Wed, 22 Mar 2023 13:54:36 +0530 Subject: [PATCH 7/7] Modify get_scalar_constant_value to work with PyTensor Variables --- pytensor/tensor/basic.py | 5 ++--- tests/tensor/test_basic.py | 10 +++++++--- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/pytensor/tensor/basic.py b/pytensor/tensor/basic.py index 1cc6949e87..b35ddec4bc 100644 --- a/pytensor/tensor/basic.py +++ b/pytensor/tensor/basic.py @@ -267,9 +267,8 @@ def get_scalar_constant_value( If 'v' is not a scalar, it raises a NotScalarConstantError. """ - if isinstance(v, np.ndarray): - data = v.data - if data.ndim != 0: + if isinstance(v, (Variable, np.ndarray)): + if v.ndim != 0: raise NotScalarConstantError() return get_underlying_scalar_constant_value( v, elemwise, only_process_constants, max_recur diff --git a/tests/tensor/test_basic.py b/tests/tensor/test_basic.py index 9d243a7adc..a3f3177229 100644 --- a/tests/tensor/test_basic.py +++ b/tests/tensor/test_basic.py @@ -3430,10 +3430,14 @@ def test_None_and_NoneConst(self, only_process_constants): ) -def test_get_scalar_constant_value(): +@pytest.mark.parametrize( + ["valid_inp", "invalid_inp"], + ((np.array(4), np.zeros(5)), (at.constant(4), at.constant(3, ndim=1))), +) +def test_get_scalar_constant_value(valid_inp, invalid_inp): with pytest.raises(NotScalarConstantError): - get_scalar_constant_value(np.zeros(5)) - assert get_scalar_constant_value(np.array(4)) == 4 + get_scalar_constant_value(invalid_inp) + assert get_scalar_constant_value(valid_inp) == 4 def test_complex_mod_failure():