diff --git a/README.rst b/README.rst index 0be106c05f..9ea767c528 100644 --- a/README.rst +++ b/README.rst @@ -22,69 +22,69 @@ Getting started .. code-block:: python - import pytensor - from pytensor import tensor as pt - - # Declare two symbolic floating-point scalars - a = pt.dscalar("a") - b = pt.dscalar("b") - - # Create a simple example expression - c = a + b - - # Convert the expression into a callable object that takes `(a, b)` - # values as input and computes the value of `c`. - f_c = pytensor.function([a, b], c) - - assert f_c(1.5, 2.5) == 4.0 - - # Compute the gradient of the example expression with respect to `a` - dc = pytensor.grad(c, a) - - f_dc = pytensor.function([a, b], dc) - - assert f_dc(1.5, 2.5) == 1.0 - - # Compiling functions with `pytensor.function` also optimizes - # expression graphs by removing unnecessary operations and - # replacing computations with more efficient ones. - - v = pt.vector("v") - M = pt.matrix("M") - - d = a/a + (M + a).dot(v) - - pytensor.dprint(d) - # Elemwise{add,no_inplace} [id A] '' - # |InplaceDimShuffle{x} [id B] '' - # | |Elemwise{true_div,no_inplace} [id C] '' - # | |a [id D] - # | |a [id D] - # |dot [id E] '' - # |Elemwise{add,no_inplace} [id F] '' - # | |M [id G] - # | |InplaceDimShuffle{x,x} [id H] '' - # | |a [id D] - # |v [id I] - - f_d = pytensor.function([a, v, M], d) - - # `a/a` -> `1` and the dot product is replaced with a BLAS function - # (i.e. CGemv) - pytensor.dprint(f_d) - # Elemwise{Add}[(0, 1)] [id A] '' 5 - # |TensorConstant{(1,) of 1.0} [id B] - # |CGemv{inplace} [id C] '' 4 - # |AllocEmpty{dtype='float64'} [id D] '' 3 - # | |Shape_i{0} [id E] '' 2 - # | |M [id F] - # |TensorConstant{1.0} [id G] - # |Elemwise{add,no_inplace} [id H] '' 1 - # | |M [id F] - # | |InplaceDimShuffle{x,x} [id I] '' 0 - # | |a [id J] - # |v [id K] - # |TensorConstant{0.0} [id L] + import pytensor + from pytensor import tensor as pt + + # Declare two symbolic floating-point scalars + a = pt.dscalar("a") + b = pt.dscalar("b") + + # Create a simple example expression + c = a + b + + # Convert the expression into a callable object that takes `(a, b)` + # values as input and computes the value of `c`. + f_c = pytensor.function([a, b], c) + + assert f_c(1.5, 2.5) == 4.0 + + # Compute the gradient of the example expression with respect to `a` + dc = pytensor.grad(c, a) + + f_dc = pytensor.function([a, b], dc) + + assert f_dc(1.5, 2.5) == 1.0 + + # Compiling functions with `pytensor.function` also optimizes + # expression graphs by removing unnecessary operations and + # replacing computations with more efficient ones. + + v = pt.vector("v") + M = pt.matrix("M") + + d = a/a + (M + a).dot(v) + + pytensor.dprint(d) + # Add [id A] + # ├─ ExpandDims{axis=0} [id B] + # │ └─ True_div [id C] + # │ ├─ a [id D] + # │ └─ a [id D] + # └─ dot [id E] + # ├─ Add [id F] + # │ ├─ M [id G] + # │ └─ ExpandDims{axes=[0, 1]} [id H] + # │ └─ a [id D] + # └─ v [id I] + + f_d = pytensor.function([a, v, M], d) + + # `a/a` -> `1` and the dot product is replaced with a BLAS function + # (i.e. CGemv) + pytensor.dprint(f_d) + # Add [id A] 5 + # ├─ [1.] [id B] + # └─ CGemv{inplace} [id C] 4 + # ├─ AllocEmpty{dtype='float64'} [id D] 3 + # │ └─ Shape_i{0} [id E] 2 + # │ └─ M [id F] + # ├─ 1.0 [id G] + # ├─ Add [id H] 1 + # │ ├─ M [id F] + # │ └─ ExpandDims{axes=[0, 1]} [id I] 0 + # │ └─ a [id J] + # ├─ v [id K] + # └─ 0.0 [id L] See `the PyTensor documentation `__ for in-depth tutorials. diff --git a/pytensor/graph/basic.py b/pytensor/graph/basic.py index 7e59bf69e1..5111607a0a 100644 --- a/pytensor/graph/basic.py +++ b/pytensor/graph/basic.py @@ -763,13 +763,20 @@ def signature(self): return (self.type, self.data) def __str__(self): - if self.name is not None: - return self.name - else: - name = str(self.data) - if len(name) > 20: - name = name[:10] + "..." + name[-10:] - return f"{type(self).__name__}{{{name}}}" + data_str = str(self.data) + if len(data_str) > 20: + data_str = data_str[:10].strip() + " ... " + data_str[-10:].strip() + + if self.name is None: + return data_str + + return f"{self.name}{{{data_str}}}" + + def __repr__(self): + data_str = repr(self.data) + if len(data_str) > 20: + data_str = data_str[:10].strip() + " ... " + data_str[-10:].strip() + return f"{type(self).__name__}({repr(self.type)}, data={data_str})" def clone(self, **kwargs): return self diff --git a/pytensor/graph/op.py b/pytensor/graph/op.py index b59394eb32..0cda6f1327 100644 --- a/pytensor/graph/op.py +++ b/pytensor/graph/op.py @@ -621,6 +621,11 @@ def make_thunk( def __str__(self): return getattr(type(self), "__name__", super().__str__()) + def __repr__(self): + props = getattr(self, "__props__", ()) + props = ",".join(f"{prop}={getattr(self, prop, '?')}" for prop in props) + return f"{self.__class__.__name__}({props})" + class _NoPythonOp(Op): """A class used to indicate that an `Op` does not provide a Python implementation. diff --git a/pytensor/graph/utils.py b/pytensor/graph/utils.py index 9b0abc5c73..6f00c687bf 100644 --- a/pytensor/graph/utils.py +++ b/pytensor/graph/utils.py @@ -234,6 +234,7 @@ def __eq__(self, other): dct["__eq__"] = __eq__ + # FIXME: This overrides __str__ inheritance when props are provided if "__str__" not in dct: if len(props) == 0: diff --git a/pytensor/printing.py b/pytensor/printing.py index 0b866f079e..11b266021b 100644 --- a/pytensor/printing.py +++ b/pytensor/printing.py @@ -291,7 +291,7 @@ def debugprint( for var in inputs_to_print: _debugprint( var, - prefix="-", + prefix="→ ", depth=depth, done=done, print_type=print_type, @@ -342,11 +342,17 @@ def debugprint( if len(inner_graph_vars) > 0: print("", file=_file) - new_prefix = " >" - new_prefix_child = " >" + prefix = "" + new_prefix = prefix + " ← " + new_prefix_child = prefix + " " print("Inner graphs:", file=_file) + printed_inner_graphs_nodes = set() for ig_var in inner_graph_vars: + if ig_var.owner in printed_inner_graphs_nodes: + continue + else: + printed_inner_graphs_nodes.add(ig_var.owner) # This is a work-around to maintain backward compatibility # (e.g. to only print inner graphs that have been compiled through # a call to `Op.prepare_node`) @@ -385,6 +391,7 @@ def debugprint( _debugprint( ig_var, + prefix=prefix, depth=depth, done=done, print_type=print_type, @@ -399,13 +406,14 @@ def debugprint( print_op_info=print_op_info, print_destroy_map=print_destroy_map, print_view_map=print_view_map, + is_inner_graph_header=True, ) if print_fgraph_inputs: for inp in inner_inputs: _debugprint( inp, - prefix="-", + prefix=" → ", depth=depth, done=done, print_type=print_type, @@ -485,6 +493,7 @@ def _debugprint( parent_node: Optional[Apply] = None, print_op_info: bool = False, inner_graph_node: Optional[Apply] = None, + is_inner_graph_header: bool = False, ) -> TextIO: r"""Print the graph represented by `var`. @@ -625,7 +634,10 @@ def get_id_str( else: data = "" - var_output = f"{prefix}{node.op}{output_idx}{id_str}{type_str}{var_name}{destroy_map_str}{view_map_str}{o}{data}" + if is_inner_graph_header: + var_output = f"{prefix}{node.op}{id_str}{destroy_map_str}{view_map_str}{o}" + else: + var_output = f"{prefix}{node.op}{output_idx}{id_str}{type_str}{var_name}{destroy_map_str}{view_map_str}{o}{data}" if print_op_info and node not in op_information: op_information.update(op_debug_information(node.op, node)) @@ -633,7 +645,7 @@ def get_id_str( node_info = ( parent_node and op_information.get(parent_node) ) or op_information.get(node) - if node_info and var in node_info: + if node_info and var in node_info and not is_inner_graph_header: var_output = f"{var_output} ({node_info[var]})" if profile and profile.apply_time and node in profile.apply_time: @@ -660,12 +672,13 @@ def get_id_str( if not already_done and ( not stop_on_name or not (hasattr(var, "name") and var.name is not None) ): - new_prefix = prefix_child + " |" - new_prefix_child = prefix_child + " |" + new_prefix = prefix_child + " ├─ " + new_prefix_child = prefix_child + " │ " for in_idx, in_var in enumerate(node.inputs): if in_idx == len(node.inputs) - 1: - new_prefix_child = prefix_child + " " + new_prefix = prefix_child + " └─ " + new_prefix_child = prefix_child + " " if hasattr(in_var, "owner") and hasattr(in_var.owner, "op"): if ( @@ -698,6 +711,8 @@ def get_id_str( print_view_map=print_view_map, inner_graph_node=inner_graph_node, ) + elif not is_inner_graph_header: + print(prefix_child + " └─ ···", file=file) else: id_str = get_id_str(var) diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py index 96cf107d64..9c270e2a45 100644 --- a/pytensor/scalar/basic.py +++ b/pytensor/scalar/basic.py @@ -4143,6 +4143,7 @@ class Composite(ScalarInnerGraphOp): def __init__(self, inputs, outputs, name="Composite"): self.name = name + self._name = None # We need to clone the graph as sometimes its nodes already # contain a reference to an fgraph. As we want the Composite # to be pickable, we can't have reference to fgraph. @@ -4189,7 +4190,26 @@ def __init__(self, inputs, outputs, name="Composite"): super().__init__() def __str__(self): - return self.name + if self._name is not None: + return self._name + + # Rename internal variables + for i, r in enumerate(self.fgraph.inputs): + r.name = f"i{int(i)}" + for i, r in enumerate(self.fgraph.outputs): + r.name = f"o{int(i)}" + io = set(self.fgraph.inputs + self.fgraph.outputs) + for i, r in enumerate(self.fgraph.variables): + if r not in io and len(self.fgraph.clients[r]) > 1: + r.name = f"t{int(i)}" + + if len(self.fgraph.outputs) > 1 or len(self.fgraph.apply_nodes) > 10: + self._name = "Composite{...}" + else: + outputs_str = ", ".join([pprint(output) for output in self.fgraph.outputs]) + self._name = f"Composite{{{outputs_str}}}" + + return self._name def make_new_inplace(self, output_types_preference=None, name=None): """ diff --git a/pytensor/scan/op.py b/pytensor/scan/op.py index b35cfa9b7b..69eb0a5fc2 100644 --- a/pytensor/scan/op.py +++ b/pytensor/scan/op.py @@ -1282,27 +1282,18 @@ def __eq__(self, other): ) def __str__(self): - device_str = "cpu" - if self.info.as_while: - name = "do_while" - else: - name = "for" - aux_txt = "%s" + inplace = "none" if len(self.destroy_map.keys()) > 0: # Check if all outputs are inplace if sorted(self.destroy_map.keys()) == sorted( range(self.info.n_mit_mot + self.info.n_mit_sot + self.info.n_sit_sot) ): - aux_txt += "all_inplace,%s,%s}" + inplace = "all" else: - aux_txt += "{inplace{" - for k in self.destroy_map.keys(): - aux_txt += str(k) + "," - aux_txt += "},%s,%s}" - else: - aux_txt += "{%s,%s}" - aux_txt = aux_txt % (name, device_str, str(self.name)) - return aux_txt + inplace = str(list(self.destroy_map.keys())) + return ( + f"Scan{{{self.name}, while_loop={self.info.as_while}, inplace={inplace}}}" + ) def __hash__(self): return hash( diff --git a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py index f498ae5079..6d19579030 100644 --- a/pytensor/tensor/elemwise.py +++ b/pytensor/tensor/elemwise.py @@ -14,7 +14,7 @@ from pytensor.link.c.params_type import ParamsType from pytensor.misc.frozendict import frozendict from pytensor.misc.safe_asarray import _asarray -from pytensor.printing import FunctionPrinter, Printer, pprint +from pytensor.printing import Printer, pprint from pytensor.scalar import get_scalar_type from pytensor.scalar.basic import bool as scalar_bool from pytensor.scalar.basic import identity as scalar_identity @@ -215,10 +215,18 @@ def make_node(self, _input): return Apply(self, [input], [output]) def __str__(self): - if self.inplace: - return "InplaceDimShuffle{%s}" % ",".join(str(x) for x in self.new_order) - else: - return "DimShuffle{%s}" % ",".join(str(x) for x in self.new_order) + shuffle = sorted(self.shuffle) != self.shuffle + if self.augment and not (shuffle or self.drop): + if len(self.augment) == 1: + return f"ExpandDims{{axis={self.augment[0]}}}" + return f"ExpandDims{{axes={self.augment}}}" + if self.drop and not (self.augment or shuffle): + if len(self.drop) == 1: + return f"DropDims{{axis={self.drop[0]}}}" + return f"DropDims{{axes={self.drop}}}" + if shuffle and not (self.augment or self.drop): + return f"Transpose{{axes={self.shuffle}}}" + return f"DimShuffle{{order=[{','.join(map(str, self.new_order))}]}}" def perform(self, node, inp, out, params): (res,) = inp @@ -490,15 +498,9 @@ def make_node(self, *inputs): return Apply(self, inputs, outputs) def __str__(self): - if self.name is None: - if self.inplace_pattern: - items = list(self.inplace_pattern.items()) - items.sort() - return f"{type(self).__name__}{{{self.scalar_op}}}{items}" - else: - return f"{type(self).__name__}{{{self.scalar_op}}}" - else: + if self.name: return self.name + return str(self.scalar_op).capitalize() def R_op(self, inputs, eval_points): outs = self(*inputs, return_list=True) @@ -1469,23 +1471,17 @@ def clone( return res - def __str__(self): - prefix = f"{type(self).__name__}{{{self.scalar_op}}}" - extra_params = [] - - if self.axis is not None: - axis = ", ".join(str(x) for x in self.axis) - extra_params.append(f"axis=[{axis}]") - - if self.acc_dtype: - extra_params.append(f"acc_dtype={self.acc_dtype}") - - extra_params_str = ", ".join(extra_params) - - if extra_params_str: - return f"{prefix}{{{extra_params_str}}}" + def _axis_str(self): + axis = self.axis + if axis is None: + return "axes=None" + elif len(axis) == 1: + return f"axis={axis[0]}" else: - return f"{prefix}" + return f"axes={list(axis)}" + + def __str__(self): + return f"{type(self).__name__}{{{self.scalar_op}, {self._axis_str()}}}" def perform(self, node, inp, out): (input,) = inp @@ -1729,21 +1725,17 @@ def construct(symbol): symbolname = symbolname or symbol.__name__ if symbolname.endswith("_inplace"): - elemwise_name = f"Elemwise{{{symbolname},inplace}}" - scalar_op = getattr(scalar, symbolname[: -len("_inplace")]) + base_symbol_name = symbolname[: -len("_inplace")] + scalar_op = getattr(scalar, base_symbol_name) inplace_scalar_op = scalar_op.__class__(transfer_type(0)) rval = Elemwise( inplace_scalar_op, {0: 0}, - name=elemwise_name, nfunc_spec=(nfunc and (nfunc, nin, nout)), ) else: - elemwise_name = f"Elemwise{{{symbolname},no_inplace}}" scalar_op = getattr(scalar, symbolname) - rval = Elemwise( - scalar_op, name=elemwise_name, nfunc_spec=(nfunc and (nfunc, nin, nout)) - ) + rval = Elemwise(scalar_op, nfunc_spec=(nfunc and (nfunc, nin, nout))) if getattr(symbol, "__doc__"): rval.__doc__ = symbol.__doc__ + "\n\n " + rval.__doc__ @@ -1753,8 +1745,6 @@ def construct(symbol): rval.__epydoc_asRoutine = symbol rval.__module__ = symbol.__module__ - pprint.assign(rval, FunctionPrinter([symbolname.replace("_inplace", "=")])) - return rval if symbol: diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py index 66373400d9..358c3f3724 100644 --- a/pytensor/tensor/math.py +++ b/pytensor/tensor/math.py @@ -583,7 +583,12 @@ def max_and_argmax(a, axis=None, keepdims=False): return [out, argout] -class NonZeroCAReduce(CAReduce): +class FixedOpCAReduce(CAReduce): + def __str__(self): + return f"{type(self).__name__}{{{self._axis_str()}}}" + + +class NonZeroDimsCAReduce(FixedOpCAReduce): def _c_all(self, node, name, inames, onames, sub): decl, checks, alloc, loop, end = super()._c_all(node, name, inames, onames, sub) @@ -614,7 +619,7 @@ def _c_all(self, node, name, inames, onames, sub): return decl, checks, alloc, loop, end -class Max(NonZeroCAReduce): +class Max(NonZeroDimsCAReduce): nfunc_spec = ("max", 1, 1) def __init__(self, axis): @@ -625,7 +630,7 @@ def clone(self, **kwargs): return type(self)(axis=axis) -class Min(NonZeroCAReduce): +class Min(NonZeroDimsCAReduce): nfunc_spec = ("min", 1, 1) def __init__(self, axis): @@ -1496,7 +1501,7 @@ def complex_from_polar(abs, angle): """Return complex-valued tensor from polar coordinate specification.""" -class Mean(CAReduce): +class Mean(FixedOpCAReduce): __props__ = ("axis",) nfunc_spec = ("mean", 1, 1) @@ -2356,7 +2361,7 @@ def outer(x, y): return dot(x.dimshuffle(0, "x"), y.dimshuffle("x", 0)) -class All(CAReduce): +class All(FixedOpCAReduce): """Applies `logical and` to all the values of a tensor along the specified axis(es). @@ -2370,12 +2375,6 @@ def __init__(self, axis=None): def _output_dtype(self, idtype): return "bool" - def __str__(self): - if self.axis is None: - return "All" - else: - return "All{%s}" % ", ".join(map(str, self.axis)) - def make_node(self, input): input = as_tensor_variable(input) if input.dtype != "bool": @@ -2392,7 +2391,7 @@ def clone(self, **kwargs): return type(self)(axis=axis) -class Any(CAReduce): +class Any(FixedOpCAReduce): """Applies `bitwise or` to all the values of a tensor along the specified axis(es). @@ -2406,12 +2405,6 @@ def __init__(self, axis=None): def _output_dtype(self, idtype): return "bool" - def __str__(self): - if self.axis is None: - return "Any" - else: - return "Any{%s}" % ", ".join(map(str, self.axis)) - def make_node(self, input): input = as_tensor_variable(input) if input.dtype != "bool": @@ -2428,7 +2421,7 @@ def clone(self, **kwargs): return type(self)(axis=axis) -class Sum(CAReduce): +class Sum(FixedOpCAReduce): """ Sums all the values of a tensor along the specified axis(es). @@ -2449,14 +2442,6 @@ def __init__(self, axis=None, dtype=None, acc_dtype=None): upcast_discrete_output=True, ) - def __str__(self): - name = self.__class__.__name__ - axis = "" - if self.axis is not None: - axis = ", ".join(str(x) for x in self.axis) - axis = f"axis=[{axis}], " - return f"{name}{{{axis}acc_dtype={self.acc_dtype}}}" - def L_op(self, inp, out, grads): (x,) = inp @@ -2526,7 +2511,7 @@ def sum(input, axis=None, dtype=None, keepdims=False, acc_dtype=None): pprint.assign(Sum, printing.FunctionPrinter(["sum"], ["axis"])) -class Prod(CAReduce): +class Prod(FixedOpCAReduce): """ Multiplies all the values of a tensor along the specified axis(es). @@ -2537,7 +2522,6 @@ class Prod(CAReduce): """ __props__ = ("scalar_op", "axis", "dtype", "acc_dtype", "no_zeros_in_input") - nfunc_spec = ("prod", 1, 1) def __init__(self, axis=None, dtype=None, acc_dtype=None, no_zeros_in_input=False): @@ -2683,6 +2667,14 @@ def clone(self, **kwargs): no_zeros_in_input=no_zeros_in_input, ) + def __str__(self): + if self.no_zeros_in_input: + return f"{super().__str__()[:-1]}, no_zeros_in_input}})" + return super().__str__() + + def __repr__(self): + return f"{super().__repr__()[:-1]}, no_zeros_in_input={self.no_zeros_in_input})" + def prod( input, @@ -2751,7 +2743,7 @@ def c_code_cache_version(self): mul_without_zeros = MulWithoutZeros(aes.upcast_out, name="mul_without_zeros") -class ProdWithoutZeros(CAReduce): +class ProdWithoutZeros(FixedOpCAReduce): def __init__(self, axis=None, dtype=None, acc_dtype=None): super().__init__( mul_without_zeros, diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py index cf604eec1b..b0d124f1c8 100644 --- a/pytensor/tensor/rewriting/math.py +++ b/pytensor/tensor/rewriting/math.py @@ -42,7 +42,8 @@ All, Any, Dot, - NonZeroCAReduce, + FixedOpCAReduce, + NonZeroDimsCAReduce, Prod, ProdWithoutZeros, Sum, @@ -1671,7 +1672,8 @@ def local_op_of_op(fgraph, node): ProdWithoutZeros, ] + CAReduce.__subclasses__() - + NonZeroCAReduce.__subclasses__() + + FixedOpCAReduce.__subclasses__() + + NonZeroDimsCAReduce.__subclasses__() ) diff --git a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py index 99c335348b..e0c0060058 100644 --- a/pytensor/tensor/subtensor.py +++ b/pytensor/tensor/subtensor.py @@ -840,22 +840,34 @@ def __hash__(self): @staticmethod def str_from_slice(entry): - msg = [] - for x in [entry.start, entry.stop, entry.step]: - if x is None: - msg.append("") - else: - msg.append(str(x)) - return ":".join(msg) + if entry.step: + return ":".join( + ( + "start" if entry.start else "", + "stop" if entry.stop else "", + "step", + ) + ) + if entry.stop: + return f"{'start' if entry.start else ''}:stop" + if entry.start: + return "start:" + return ":" - def __str__(self): + @staticmethod + def str_from_indices(idx_list): indices = [] - for entry in self.idx_list: + letter_indexes = 0 + for entry in idx_list: if isinstance(entry, slice): - indices.append(self.str_from_slice(entry)) + indices.append(Subtensor.str_from_slice(entry)) else: - indices.append(str(entry)) - return f"{self.__class__.__name__}{{{', '.join(indices)}}}" + indices.append("ijk"[letter_indexes % 3] * (letter_indexes // 3 + 1)) + letter_indexes += 1 + return ", ".join(indices) + + def __str__(self): + return f"{self.__class__.__name__}{{{self.str_from_indices(self.idx_list)}}}" @staticmethod def default_helper_c_code_args(): @@ -1498,21 +1510,8 @@ def __hash__(self): return hash((type(self), idx_list, self.inplace, self.set_instead_of_inc)) def __str__(self): - indices = [] - for entry in self.idx_list: - if isinstance(entry, slice): - indices.append(Subtensor.str_from_slice(entry)) - else: - indices.append(str(entry)) - if self.inplace: - msg = "Inplace" - else: - msg = "" - if not self.set_instead_of_inc: - msg += "Inc" - else: - msg += "Set" - return f"{self.__class__.__name__}{{{msg};{', '.join(indices)}}}" + name = "SetSubtensor" if self.set_instead_of_inc else "IncSubtensor" + return f"{name}{{{Subtensor.str_from_indices(self.idx_list)}}}" def make_node(self, x, y, *inputs): """ @@ -2661,10 +2660,10 @@ def __init__( self.ignore_duplicates = ignore_duplicates def __str__(self): - return "{}{{{}, {}}}".format( - self.__class__.__name__, - "inplace=" + str(self.inplace), - " set_instead_of_inc=" + str(self.set_instead_of_inc), + return ( + "AdvancedSetSubtensor" + if self.set_instead_of_inc + else "AdvancedIncSubtensor" ) def make_node(self, x, y, *inputs): diff --git a/pytensor/tensor/type.py b/pytensor/tensor/type.py index 3a07afecd7..c4356d8117 100644 --- a/pytensor/tensor/type.py +++ b/pytensor/tensor/type.py @@ -386,6 +386,8 @@ def __str__(self): if self.name: return self.name else: + shape = self.shape + len_shape = len(shape) def shape_str(s): if s is None: @@ -393,14 +395,18 @@ def shape_str(s): else: return str(s) - formatted_shape = ", ".join([shape_str(s) for s in self.shape]) - if len(self.shape) == 1: + formatted_shape = ", ".join([shape_str(s) for s in shape]) + if len_shape == 1: formatted_shape += "," - return f"TensorType({self.dtype}, ({formatted_shape}))" + if len_shape > 2: + name = f"Tensor{len_shape}" + else: + name = ("Scalar", "Vector", "Matrix")[len_shape] + return f"{name}({self.dtype}, shape=({formatted_shape}))" def __repr__(self): - return str(self) + return f"TensorType({self.dtype}, shape={self.shape})" @staticmethod def may_share_memory(a, b): diff --git a/pytensor/tensor/var.py b/pytensor/tensor/var.py index 428cd8fffd..7cfd4cef87 100644 --- a/pytensor/tensor/var.py +++ b/pytensor/tensor/var.py @@ -616,9 +616,9 @@ def __iter__(self): except TypeError: # This prevents accidental iteration via sum(self) raise TypeError( - "TensorType does not support iteration. " - "Maybe you are using builtins.sum instead of " - "pytensor.tensor.math.sum? (Maybe .max?)" + "TensorType does not support iteration.\n" + "\tDid you pass a PyTensor variable to a function that expects a list?\n" + "\tMaybe you are using builtins.sum instead of pytensor.tensor.sum?" ) @property @@ -1023,21 +1023,6 @@ def __init__(self, type: _TensorTypeType, data, name=None): Constant.__init__(self, new_type, data, name) - def __str__(self): - unique_val = get_unique_value(self) - if unique_val is not None: - val = f"{self.data.shape} of {unique_val}" - else: - val = f"{self.data}" - if len(val) > 20: - val = val[:10] + ".." + val[-10:] - - if self.name is not None: - name = self.name - else: - name = "TensorConstant" - return f"{name}{{{val}}}" - def signature(self): return TensorConstantSignature((self.type, self.data)) diff --git a/tests/compile/test_builders.py b/tests/compile/test_builders.py index ad45b8270c..8fdeb18470 100644 --- a/tests/compile/test_builders.py +++ b/tests/compile/test_builders.py @@ -572,18 +572,18 @@ def test_debugprint(): lines = output_str.split("\n") exp_res = """OpFromGraph{inline=False} [id A] - |x [id B] - |y [id C] - |z [id D] + ├─ x [id B] + ├─ y [id C] + └─ z [id D] Inner graphs: OpFromGraph{inline=False} [id A] - >Elemwise{add,no_inplace} [id E] - > |*0- [id F] - > |Elemwise{mul,no_inplace} [id G] - > |*1- [id H] - > |*2- [id I] + ← Add [id E] + ├─ *0- [id F] + └─ Mul [id G] + ├─ *1- [id H] + └─ *2- [id I] """ for truth, out in zip(exp_res.split("\n"), lines): diff --git a/tests/graph/rewriting/test_basic.py b/tests/graph/rewriting/test_basic.py index 7e249ddc3b..a690a1c88f 100644 --- a/tests/graph/rewriting/test_basic.py +++ b/tests/graph/rewriting/test_basic.py @@ -166,7 +166,7 @@ def test_constant(self): e = op1(op1(x, y), y) g = FunctionGraph([y], [e]) OpKeyPatternNodeRewriter((op1, z, "1"), (op2, "1", z)).rewrite(g) - assert str(g) == "FunctionGraph(Op1(Op2(y, z), y))" + assert str(g) == "FunctionGraph(Op1(Op2(y, z{2}), y))" def test_constraints(self): x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z") diff --git a/tests/link/test_utils.py b/tests/link/test_utils.py index 812bd6304c..80bab78012 100644 --- a/tests/link/test_utils.py +++ b/tests/link/test_utils.py @@ -156,7 +156,7 @@ def func(*args, op=op): assert ( """ - # Elemwise{add,no_inplace}(Test + # Add(Test # Op().0, Test # Op().1) """ diff --git a/tests/scalar/test_basic.py b/tests/scalar/test_basic.py index 13f6a21110..acbecc7aba 100644 --- a/tests/scalar/test_basic.py +++ b/tests/scalar/test_basic.py @@ -183,7 +183,7 @@ def test_composite_printing(self): make_function(DualLinker().accept(g)) assert str(g) == ( - "FunctionGraph(*1 -> Composite(x, y, z), *1::1, *1::2, *1::3, *1::4, *1::5, *1::6, *1::7)" + "FunctionGraph(*1 -> Composite{...}(x, y, z), *1::1, *1::2, *1::3, *1::4, *1::5, *1::6, *1::7)" ) def test_non_scalar_error(self): diff --git a/tests/scan/test_printing.py b/tests/scan/test_printing.py index 882c708966..725a48627d 100644 --- a/tests/scan/test_printing.py +++ b/tests/scan/test_printing.py @@ -26,40 +26,43 @@ def test_debugprint_sitsot(): output_str = debugprint(final_result, file="str", print_op_info=True) lines = output_str.split("\n") - expected_output = """Subtensor{int64} [id A] - |Subtensor{int64::} [id B] - | |for{cpu,scan_fn} [id C] (outer_out_sit_sot-0) - | | |k [id D] (n_steps) - | | |IncSubtensor{Set;:int64:} [id E] (outer_in_sit_sot-0) - | | | |AllocEmpty{dtype='float64'} [id F] - | | | | |Elemwise{add,no_inplace} [id G] - | | | | | |k [id D] - | | | | | |Subtensor{int64} [id H] - | | | | | |Shape [id I] - | | | | | | |Unbroadcast{0} [id J] - | | | | | | |InplaceDimShuffle{x,0} [id K] - | | | | | | |Elemwise{second,no_inplace} [id L] - | | | | | | |A [id M] - | | | | | | |InplaceDimShuffle{x} [id N] - | | | | | | |TensorConstant{1.0} [id O] - | | | | | |ScalarConstant{0} [id P] - | | | | |Subtensor{int64} [id Q] - | | | | |Shape [id R] - | | | | | |Unbroadcast{0} [id J] - | | | | |ScalarConstant{1} [id S] - | | | |Unbroadcast{0} [id J] - | | | |ScalarFromTensor [id T] - | | | |Subtensor{int64} [id H] - | | |A [id M] (outer_in_non_seqs-0) - | |ScalarConstant{1} [id U] - |ScalarConstant{-1} [id V] + expected_output = """Subtensor{i} [id A] + ├─ Subtensor{start:} [id B] + │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id C] (outer_out_sit_sot-0) + │ │ ├─ k [id D] (n_steps) + │ │ ├─ SetSubtensor{:stop} [id E] (outer_in_sit_sot-0) + │ │ │ ├─ AllocEmpty{dtype='float64'} [id F] + │ │ │ │ ├─ Add [id G] + │ │ │ │ │ ├─ k [id D] + │ │ │ │ │ └─ Subtensor{i} [id H] + │ │ │ │ │ ├─ Shape [id I] + │ │ │ │ │ │ └─ Unbroadcast{0} [id J] + │ │ │ │ │ │ └─ ExpandDims{axis=0} [id K] + │ │ │ │ │ │ └─ Second [id L] + │ │ │ │ │ │ ├─ A [id M] + │ │ │ │ │ │ └─ ExpandDims{axis=0} [id N] + │ │ │ │ │ │ └─ 1.0 [id O] + │ │ │ │ │ └─ 0 [id P] + │ │ │ │ └─ Subtensor{i} [id Q] + │ │ │ │ ├─ Shape [id R] + │ │ │ │ │ └─ Unbroadcast{0} [id J] + │ │ │ │ │ └─ ··· + │ │ │ │ └─ 1 [id S] + │ │ │ ├─ Unbroadcast{0} [id J] + │ │ │ │ └─ ··· + │ │ │ └─ ScalarFromTensor [id T] + │ │ │ └─ Subtensor{i} [id H] + │ │ │ └─ ··· + │ │ └─ A [id M] (outer_in_non_seqs-0) + │ └─ 1 [id U] + └─ -1 [id V] Inner graphs: - for{cpu,scan_fn} [id C] (outer_out_sit_sot-0) - >Elemwise{mul,no_inplace} [id W] (inner_out_sit_sot-0) - > |*0- [id X] -> [id E] (inner_in_sit_sot-0) - > |*1- [id Y] -> [id M] (inner_in_non_seqs-0)""" + Scan{scan_fn, while_loop=False, inplace=none} [id C] + ← Mul [id W] (inner_out_sit_sot-0) + ├─ *0- [id X] -> [id E] (inner_in_sit_sot-0) + └─ *1- [id Y] -> [id M] (inner_in_non_seqs-0)""" for truth, out in zip(expected_output.split("\n"), lines): assert truth.strip() == out.strip() @@ -81,40 +84,43 @@ def test_debugprint_sitsot_no_extra_info(): output_str = debugprint(final_result, file="str", print_op_info=False) lines = output_str.split("\n") - expected_output = """Subtensor{int64} [id A] - |Subtensor{int64::} [id B] - | |for{cpu,scan_fn} [id C] - | | |k [id D] - | | |IncSubtensor{Set;:int64:} [id E] - | | | |AllocEmpty{dtype='float64'} [id F] - | | | | |Elemwise{add,no_inplace} [id G] - | | | | | |k [id D] - | | | | | |Subtensor{int64} [id H] - | | | | | |Shape [id I] - | | | | | | |Unbroadcast{0} [id J] - | | | | | | |InplaceDimShuffle{x,0} [id K] - | | | | | | |Elemwise{second,no_inplace} [id L] - | | | | | | |A [id M] - | | | | | | |InplaceDimShuffle{x} [id N] - | | | | | | |TensorConstant{1.0} [id O] - | | | | | |ScalarConstant{0} [id P] - | | | | |Subtensor{int64} [id Q] - | | | | |Shape [id R] - | | | | | |Unbroadcast{0} [id J] - | | | | |ScalarConstant{1} [id S] - | | | |Unbroadcast{0} [id J] - | | | |ScalarFromTensor [id T] - | | | |Subtensor{int64} [id H] - | | |A [id M] - | |ScalarConstant{1} [id U] - |ScalarConstant{-1} [id V] + expected_output = """Subtensor{i} [id A] + ├─ Subtensor{start:} [id B] + │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id C] + │ │ ├─ k [id D] + │ │ ├─ SetSubtensor{:stop} [id E] + │ │ │ ├─ AllocEmpty{dtype='float64'} [id F] + │ │ │ │ ├─ Add [id G] + │ │ │ │ │ ├─ k [id D] + │ │ │ │ │ └─ Subtensor{i} [id H] + │ │ │ │ │ ├─ Shape [id I] + │ │ │ │ │ │ └─ Unbroadcast{0} [id J] + │ │ │ │ │ │ └─ ExpandDims{axis=0} [id K] + │ │ │ │ │ │ └─ Second [id L] + │ │ │ │ │ │ ├─ A [id M] + │ │ │ │ │ │ └─ ExpandDims{axis=0} [id N] + │ │ │ │ │ │ └─ 1.0 [id O] + │ │ │ │ │ └─ 0 [id P] + │ │ │ │ └─ Subtensor{i} [id Q] + │ │ │ │ ├─ Shape [id R] + │ │ │ │ │ └─ Unbroadcast{0} [id J] + │ │ │ │ │ └─ ··· + │ │ │ │ └─ 1 [id S] + │ │ │ ├─ Unbroadcast{0} [id J] + │ │ │ │ └─ ··· + │ │ │ └─ ScalarFromTensor [id T] + │ │ │ └─ Subtensor{i} [id H] + │ │ │ └─ ··· + │ │ └─ A [id M] + │ └─ 1 [id U] + └─ -1 [id V] Inner graphs: - for{cpu,scan_fn} [id C] - >Elemwise{mul,no_inplace} [id W] - > |*0- [id X] -> [id E] - > |*1- [id Y] -> [id M]""" + Scan{scan_fn, while_loop=False, inplace=none} [id C] + ← Mul [id W] + ├─ *0- [id X] -> [id E] + └─ *1- [id Y] -> [id M]""" for truth, out in zip(expected_output.split("\n"), lines): assert truth.strip() == out.strip() @@ -141,43 +147,48 @@ def test_debugprint_nitsot(): output_str = debugprint(polynomial, file="str", print_op_info=True) lines = output_str.split("\n") - expected_output = """Sum{acc_dtype=float64} [id A] - |for{cpu,scan_fn} [id B] (outer_out_nit_sot-0) - |Elemwise{scalar_minimum,no_inplace} [id C] (outer_in_nit_sot-0) - | |Subtensor{int64} [id D] - | | |Shape [id E] - | | | |Subtensor{int64::} [id F] 'coefficients[0:]' - | | | |coefficients [id G] - | | | |ScalarConstant{0} [id H] - | | |ScalarConstant{0} [id I] - | |Subtensor{int64} [id J] - | |Shape [id K] - | | |Subtensor{int64::} [id L] - | | |ARange{dtype='int64'} [id M] - | | | |TensorConstant{0} [id N] - | | | |TensorConstant{10000} [id O] - | | | |TensorConstant{1} [id P] - | | |ScalarConstant{0} [id Q] - | |ScalarConstant{0} [id R] - |Subtensor{:int64:} [id S] (outer_in_seqs-0) - | |Subtensor{int64::} [id F] 'coefficients[0:]' - | |ScalarFromTensor [id T] - | |Elemwise{scalar_minimum,no_inplace} [id C] - |Subtensor{:int64:} [id U] (outer_in_seqs-1) - | |Subtensor{int64::} [id L] - | |ScalarFromTensor [id V] - | |Elemwise{scalar_minimum,no_inplace} [id C] - |Elemwise{scalar_minimum,no_inplace} [id C] (outer_in_nit_sot-0) - |x [id W] (outer_in_non_seqs-0) + expected_output = """Sum{axes=None} [id A] + └─ Scan{scan_fn, while_loop=False, inplace=none} [id B] (outer_out_nit_sot-0) + ├─ Minimum [id C] (outer_in_nit_sot-0) + │ ├─ Subtensor{i} [id D] + │ │ ├─ Shape [id E] + │ │ │ └─ Subtensor{start:} [id F] 'coefficients[0:]' + │ │ │ ├─ coefficients [id G] + │ │ │ └─ 0 [id H] + │ │ └─ 0 [id I] + │ └─ Subtensor{i} [id J] + │ ├─ Shape [id K] + │ │ └─ Subtensor{start:} [id L] + │ │ ├─ ARange{dtype='int64'} [id M] + │ │ │ ├─ 0 [id N] + │ │ │ ├─ 10000 [id O] + │ │ │ └─ 1 [id P] + │ │ └─ 0 [id Q] + │ └─ 0 [id R] + ├─ Subtensor{:stop} [id S] (outer_in_seqs-0) + │ ├─ Subtensor{start:} [id F] 'coefficients[0:]' + │ │ └─ ··· + │ └─ ScalarFromTensor [id T] + │ └─ Minimum [id C] + │ └─ ··· + ├─ Subtensor{:stop} [id U] (outer_in_seqs-1) + │ ├─ Subtensor{start:} [id L] + │ │ └─ ··· + │ └─ ScalarFromTensor [id V] + │ └─ Minimum [id C] + │ └─ ··· + ├─ Minimum [id C] (outer_in_nit_sot-0) + │ └─ ··· + └─ x [id W] (outer_in_non_seqs-0) Inner graphs: - for{cpu,scan_fn} [id B] (outer_out_nit_sot-0) - >Elemwise{mul,no_inplace} [id X] (inner_out_nit_sot-0) - > |*0- [id Y] -> [id S] (inner_in_seqs-0) - > |Elemwise{pow,no_inplace} [id Z] - > |*2- [id BA] -> [id W] (inner_in_non_seqs-0) - > |*1- [id BB] -> [id U] (inner_in_seqs-1)""" + Scan{scan_fn, while_loop=False, inplace=none} [id B] + ← Mul [id X] (inner_out_nit_sot-0) + ├─ *0- [id Y] -> [id S] (inner_in_seqs-0) + └─ Pow [id Z] + ├─ *2- [id BA] -> [id W] (inner_in_non_seqs-0) + └─ *1- [id BB] -> [id U] (inner_in_seqs-1)""" for truth, out in zip(expected_output.split("\n"), lines): assert truth.strip() == out.strip() @@ -214,77 +225,85 @@ def compute_A_k(A, k): output_str = debugprint(final_result, file="str", print_op_info=True) lines = output_str.split("\n") - expected_output = """Sum{acc_dtype=float64} [id A] - |for{cpu,scan_fn} [id B] (outer_out_nit_sot-0) - |Elemwise{scalar_minimum,no_inplace} [id C] (outer_in_nit_sot-0) - | |Subtensor{int64} [id D] - | | |Shape [id E] - | | | |Subtensor{int64::} [id F] 'c[0:]' - | | | |c [id G] - | | | |ScalarConstant{0} [id H] - | | |ScalarConstant{0} [id I] - | |Subtensor{int64} [id J] - | |Shape [id K] - | | |Subtensor{int64::} [id L] - | | |ARange{dtype='int64'} [id M] - | | | |TensorConstant{0} [id N] - | | | |TensorConstant{10} [id O] - | | | |TensorConstant{1} [id P] - | | |ScalarConstant{0} [id Q] - | |ScalarConstant{0} [id R] - |Subtensor{:int64:} [id S] (outer_in_seqs-0) - | |Subtensor{int64::} [id F] 'c[0:]' - | |ScalarFromTensor [id T] - | |Elemwise{scalar_minimum,no_inplace} [id C] - |Subtensor{:int64:} [id U] (outer_in_seqs-1) - | |Subtensor{int64::} [id L] - | |ScalarFromTensor [id V] - | |Elemwise{scalar_minimum,no_inplace} [id C] - |Elemwise{scalar_minimum,no_inplace} [id C] (outer_in_nit_sot-0) - |A [id W] (outer_in_non_seqs-0) - |k [id X] (outer_in_non_seqs-1) + expected_output = """Sum{axes=None} [id A] + └─ Scan{scan_fn, while_loop=False, inplace=none} [id B] (outer_out_nit_sot-0) + ├─ Minimum [id C] (outer_in_nit_sot-0) + │ ├─ Subtensor{i} [id D] + │ │ ├─ Shape [id E] + │ │ │ └─ Subtensor{start:} [id F] 'c[0:]' + │ │ │ ├─ c [id G] + │ │ │ └─ 0 [id H] + │ │ └─ 0 [id I] + │ └─ Subtensor{i} [id J] + │ ├─ Shape [id K] + │ │ └─ Subtensor{start:} [id L] + │ │ ├─ ARange{dtype='int64'} [id M] + │ │ │ ├─ 0 [id N] + │ │ │ ├─ 10 [id O] + │ │ │ └─ 1 [id P] + │ │ └─ 0 [id Q] + │ └─ 0 [id R] + ├─ Subtensor{:stop} [id S] (outer_in_seqs-0) + │ ├─ Subtensor{start:} [id F] 'c[0:]' + │ │ └─ ··· + │ └─ ScalarFromTensor [id T] + │ └─ Minimum [id C] + │ └─ ··· + ├─ Subtensor{:stop} [id U] (outer_in_seqs-1) + │ ├─ Subtensor{start:} [id L] + │ │ └─ ··· + │ └─ ScalarFromTensor [id V] + │ └─ Minimum [id C] + │ └─ ··· + ├─ Minimum [id C] (outer_in_nit_sot-0) + │ └─ ··· + ├─ A [id W] (outer_in_non_seqs-0) + └─ k [id X] (outer_in_non_seqs-1) Inner graphs: - for{cpu,scan_fn} [id B] (outer_out_nit_sot-0) - >Elemwise{mul,no_inplace} [id Y] (inner_out_nit_sot-0) - > |InplaceDimShuffle{x} [id Z] - > | |*0- [id BA] -> [id S] (inner_in_seqs-0) - > |Elemwise{pow,no_inplace} [id BB] - > |Subtensor{int64} [id BC] - > | |Subtensor{int64::} [id BD] - > | | |for{cpu,scan_fn} [id BE] (outer_out_sit_sot-0) - > | | | |*3- [id BF] -> [id X] (inner_in_non_seqs-1) (n_steps) - > | | | |IncSubtensor{Set;:int64:} [id BG] (outer_in_sit_sot-0) - > | | | | |AllocEmpty{dtype='float64'} [id BH] - > | | | | | |Elemwise{add,no_inplace} [id BI] - > | | | | | | |*3- [id BF] -> [id X] (inner_in_non_seqs-1) - > | | | | | | |Subtensor{int64} [id BJ] - > | | | | | | |Shape [id BK] - > | | | | | | | |Unbroadcast{0} [id BL] - > | | | | | | | |InplaceDimShuffle{x,0} [id BM] - > | | | | | | | |Elemwise{second,no_inplace} [id BN] - > | | | | | | | |*2- [id BO] -> [id W] (inner_in_non_seqs-0) - > | | | | | | | |InplaceDimShuffle{x} [id BP] - > | | | | | | | |TensorConstant{1.0} [id BQ] - > | | | | | | |ScalarConstant{0} [id BR] - > | | | | | |Subtensor{int64} [id BS] - > | | | | | |Shape [id BT] - > | | | | | | |Unbroadcast{0} [id BL] - > | | | | | |ScalarConstant{1} [id BU] - > | | | | |Unbroadcast{0} [id BL] - > | | | | |ScalarFromTensor [id BV] - > | | | | |Subtensor{int64} [id BJ] - > | | | |*2- [id BO] -> [id W] (inner_in_non_seqs-0) (outer_in_non_seqs-0) - > | | |ScalarConstant{1} [id BW] - > | |ScalarConstant{-1} [id BX] - > |InplaceDimShuffle{x} [id BY] - > |*1- [id BZ] -> [id U] (inner_in_seqs-1) - - for{cpu,scan_fn} [id BE] (outer_out_sit_sot-0) - >Elemwise{mul,no_inplace} [id CA] (inner_out_sit_sot-0) - > |*0- [id CB] -> [id BG] (inner_in_sit_sot-0) - > |*1- [id CC] -> [id BO] (inner_in_non_seqs-0)""" + Scan{scan_fn, while_loop=False, inplace=none} [id B] + ← Mul [id Y] (inner_out_nit_sot-0) + ├─ ExpandDims{axis=0} [id Z] + │ └─ *0- [id BA] -> [id S] (inner_in_seqs-0) + └─ Pow [id BB] + ├─ Subtensor{i} [id BC] + │ ├─ Subtensor{start:} [id BD] + │ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id BE] (outer_out_sit_sot-0) + │ │ │ ├─ *3- [id BF] -> [id X] (inner_in_non_seqs-1) (n_steps) + │ │ │ ├─ SetSubtensor{:stop} [id BG] (outer_in_sit_sot-0) + │ │ │ │ ├─ AllocEmpty{dtype='float64'} [id BH] + │ │ │ │ │ ├─ Add [id BI] + │ │ │ │ │ │ ├─ *3- [id BF] -> [id X] (inner_in_non_seqs-1) + │ │ │ │ │ │ └─ Subtensor{i} [id BJ] + │ │ │ │ │ │ ├─ Shape [id BK] + │ │ │ │ │ │ │ └─ Unbroadcast{0} [id BL] + │ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id BM] + │ │ │ │ │ │ │ └─ Second [id BN] + │ │ │ │ │ │ │ ├─ *2- [id BO] -> [id W] (inner_in_non_seqs-0) + │ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id BP] + │ │ │ │ │ │ │ └─ 1.0 [id BQ] + │ │ │ │ │ │ └─ 0 [id BR] + │ │ │ │ │ └─ Subtensor{i} [id BS] + │ │ │ │ │ ├─ Shape [id BT] + │ │ │ │ │ │ └─ Unbroadcast{0} [id BL] + │ │ │ │ │ │ └─ ··· + │ │ │ │ │ └─ 1 [id BU] + │ │ │ │ ├─ Unbroadcast{0} [id BL] + │ │ │ │ │ └─ ··· + │ │ │ │ └─ ScalarFromTensor [id BV] + │ │ │ │ └─ Subtensor{i} [id BJ] + │ │ │ │ └─ ··· + │ │ │ └─ *2- [id BO] -> [id W] (inner_in_non_seqs-0) (outer_in_non_seqs-0) + │ │ └─ 1 [id BW] + │ └─ -1 [id BX] + └─ ExpandDims{axis=0} [id BY] + └─ *1- [id BZ] -> [id U] (inner_in_seqs-1) + + Scan{scan_fn, while_loop=False, inplace=none} [id BE] + ← Mul [id CA] (inner_out_sit_sot-0) + ├─ *0- [id CB] -> [id BG] (inner_in_sit_sot-0) + └─ *1- [id CC] -> [id BO] (inner_in_non_seqs-0)""" for truth, out in zip(expected_output.split("\n"), lines): assert truth.strip() == out.strip() @@ -296,86 +315,94 @@ def compute_A_k(A, k): ) lines = output_str.split("\n") - expected_output = """-c [id A] - -k [id B] - -A [id C] - Sum{acc_dtype=float64} [id D] 13 - |for{cpu,scan_fn} [id E] 12 (outer_out_nit_sot-0) - |Elemwise{scalar_minimum,no_inplace} [id F] 7 (outer_in_nit_sot-0) - | |Subtensor{int64} [id G] 6 - | | |Shape [id H] 5 - | | | |Subtensor{int64::} [id I] 'c[0:]' 4 - | | | |c [id A] - | | | |ScalarConstant{0} [id J] - | | |ScalarConstant{0} [id K] - | |Subtensor{int64} [id L] 3 - | |Shape [id M] 2 - | | |Subtensor{int64::} [id N] 1 - | | |ARange{dtype='int64'} [id O] 0 - | | | |TensorConstant{0} [id P] - | | | |TensorConstant{10} [id Q] - | | | |TensorConstant{1} [id R] - | | |ScalarConstant{0} [id S] - | |ScalarConstant{0} [id T] - |Subtensor{:int64:} [id U] 11 (outer_in_seqs-0) - | |Subtensor{int64::} [id I] 'c[0:]' 4 - | |ScalarFromTensor [id V] 10 - | |Elemwise{scalar_minimum,no_inplace} [id F] 7 - |Subtensor{:int64:} [id W] 9 (outer_in_seqs-1) - | |Subtensor{int64::} [id N] 1 - | |ScalarFromTensor [id X] 8 - | |Elemwise{scalar_minimum,no_inplace} [id F] 7 - |Elemwise{scalar_minimum,no_inplace} [id F] 7 (outer_in_nit_sot-0) - |A [id C] (outer_in_non_seqs-0) - |k [id B] (outer_in_non_seqs-1) + expected_output = """→ c [id A] + → k [id B] + → A [id C] + Sum{axes=None} [id D] 13 + └─ Scan{scan_fn, while_loop=False, inplace=none} [id E] 12 (outer_out_nit_sot-0) + ├─ Minimum [id F] 7 (outer_in_nit_sot-0) + │ ├─ Subtensor{i} [id G] 6 + │ │ ├─ Shape [id H] 5 + │ │ │ └─ Subtensor{start:} [id I] 'c[0:]' 4 + │ │ │ ├─ c [id A] + │ │ │ └─ 0 [id J] + │ │ └─ 0 [id K] + │ └─ Subtensor{i} [id L] 3 + │ ├─ Shape [id M] 2 + │ │ └─ Subtensor{start:} [id N] 1 + │ │ ├─ ARange{dtype='int64'} [id O] 0 + │ │ │ ├─ 0 [id P] + │ │ │ ├─ 10 [id Q] + │ │ │ └─ 1 [id R] + │ │ └─ 0 [id S] + │ └─ 0 [id T] + ├─ Subtensor{:stop} [id U] 11 (outer_in_seqs-0) + │ ├─ Subtensor{start:} [id I] 'c[0:]' 4 + │ │ └─ ··· + │ └─ ScalarFromTensor [id V] 10 + │ └─ Minimum [id F] 7 + │ └─ ··· + ├─ Subtensor{:stop} [id W] 9 (outer_in_seqs-1) + │ ├─ Subtensor{start:} [id N] 1 + │ │ └─ ··· + │ └─ ScalarFromTensor [id X] 8 + │ └─ Minimum [id F] 7 + │ └─ ··· + ├─ Minimum [id F] 7 (outer_in_nit_sot-0) + │ └─ ··· + ├─ A [id C] (outer_in_non_seqs-0) + └─ k [id B] (outer_in_non_seqs-1) Inner graphs: - for{cpu,scan_fn} [id E] (outer_out_nit_sot-0) - -*0- [id Y] -> [id U] (inner_in_seqs-0) - -*1- [id Z] -> [id W] (inner_in_seqs-1) - -*2- [id BA] -> [id C] (inner_in_non_seqs-0) - -*3- [id BB] -> [id B] (inner_in_non_seqs-1) - >Elemwise{mul,no_inplace} [id BC] (inner_out_nit_sot-0) - > |InplaceDimShuffle{x} [id BD] - > | |*0- [id Y] (inner_in_seqs-0) - > |Elemwise{pow,no_inplace} [id BE] - > |Subtensor{int64} [id BF] - > | |Subtensor{int64::} [id BG] - > | | |for{cpu,scan_fn} [id BH] (outer_out_sit_sot-0) - > | | | |*3- [id BB] (inner_in_non_seqs-1) (n_steps) - > | | | |IncSubtensor{Set;:int64:} [id BI] (outer_in_sit_sot-0) - > | | | | |AllocEmpty{dtype='float64'} [id BJ] - > | | | | | |Elemwise{add,no_inplace} [id BK] - > | | | | | | |*3- [id BB] (inner_in_non_seqs-1) - > | | | | | | |Subtensor{int64} [id BL] - > | | | | | | |Shape [id BM] - > | | | | | | | |Unbroadcast{0} [id BN] - > | | | | | | | |InplaceDimShuffle{x,0} [id BO] - > | | | | | | | |Elemwise{second,no_inplace} [id BP] - > | | | | | | | |*2- [id BA] (inner_in_non_seqs-0) - > | | | | | | | |InplaceDimShuffle{x} [id BQ] - > | | | | | | | |TensorConstant{1.0} [id BR] - > | | | | | | |ScalarConstant{0} [id BS] - > | | | | | |Subtensor{int64} [id BT] - > | | | | | |Shape [id BU] - > | | | | | | |Unbroadcast{0} [id BN] - > | | | | | |ScalarConstant{1} [id BV] - > | | | | |Unbroadcast{0} [id BN] - > | | | | |ScalarFromTensor [id BW] - > | | | | |Subtensor{int64} [id BL] - > | | | |*2- [id BA] (inner_in_non_seqs-0) (outer_in_non_seqs-0) - > | | |ScalarConstant{1} [id BX] - > | |ScalarConstant{-1} [id BY] - > |InplaceDimShuffle{x} [id BZ] - > |*1- [id Z] (inner_in_seqs-1) - - for{cpu,scan_fn} [id BH] (outer_out_sit_sot-0) - -*0- [id CA] -> [id BI] (inner_in_sit_sot-0) - -*1- [id CB] -> [id BA] (inner_in_non_seqs-0) - >Elemwise{mul,no_inplace} [id CC] (inner_out_sit_sot-0) - > |*0- [id CA] (inner_in_sit_sot-0) - > |*1- [id CB] (inner_in_non_seqs-0)""" + Scan{scan_fn, while_loop=False, inplace=none} [id E] + → *0- [id Y] -> [id U] (inner_in_seqs-0) + → *1- [id Z] -> [id W] (inner_in_seqs-1) + → *2- [id BA] -> [id C] (inner_in_non_seqs-0) + → *3- [id BB] -> [id B] (inner_in_non_seqs-1) + ← Mul [id BC] (inner_out_nit_sot-0) + ├─ ExpandDims{axis=0} [id BD] + │ └─ *0- [id Y] (inner_in_seqs-0) + └─ Pow [id BE] + ├─ Subtensor{i} [id BF] + │ ├─ Subtensor{start:} [id BG] + │ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id BH] (outer_out_sit_sot-0) + │ │ │ ├─ *3- [id BB] (inner_in_non_seqs-1) (n_steps) + │ │ │ ├─ SetSubtensor{:stop} [id BI] (outer_in_sit_sot-0) + │ │ │ │ ├─ AllocEmpty{dtype='float64'} [id BJ] + │ │ │ │ │ ├─ Add [id BK] + │ │ │ │ │ │ ├─ *3- [id BB] (inner_in_non_seqs-1) + │ │ │ │ │ │ └─ Subtensor{i} [id BL] + │ │ │ │ │ │ ├─ Shape [id BM] + │ │ │ │ │ │ │ └─ Unbroadcast{0} [id BN] + │ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id BO] + │ │ │ │ │ │ │ └─ Second [id BP] + │ │ │ │ │ │ │ ├─ *2- [id BA] (inner_in_non_seqs-0) + │ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id BQ] + │ │ │ │ │ │ │ └─ 1.0 [id BR] + │ │ │ │ │ │ └─ 0 [id BS] + │ │ │ │ │ └─ Subtensor{i} [id BT] + │ │ │ │ │ ├─ Shape [id BU] + │ │ │ │ │ │ └─ Unbroadcast{0} [id BN] + │ │ │ │ │ │ └─ ··· + │ │ │ │ │ └─ 1 [id BV] + │ │ │ │ ├─ Unbroadcast{0} [id BN] + │ │ │ │ │ └─ ··· + │ │ │ │ └─ ScalarFromTensor [id BW] + │ │ │ │ └─ Subtensor{i} [id BL] + │ │ │ │ └─ ··· + │ │ │ └─ *2- [id BA] (inner_in_non_seqs-0) (outer_in_non_seqs-0) + │ │ └─ 1 [id BX] + │ └─ -1 [id BY] + └─ ExpandDims{axis=0} [id BZ] + └─ *1- [id Z] (inner_in_seqs-1) + + Scan{scan_fn, while_loop=False, inplace=none} [id BH] + → *0- [id CA] -> [id BI] (inner_in_sit_sot-0) + → *1- [id CB] -> [id BA] (inner_in_non_seqs-0) + ← Mul [id CC] (inner_out_sit_sot-0) + ├─ *0- [id CA] (inner_in_sit_sot-0) + └─ *1- [id CB] (inner_in_non_seqs-0)""" for truth, out in zip(expected_output.split("\n"), lines): assert truth.strip() == out.strip() @@ -402,54 +429,55 @@ def fn(a_m2, a_m1, b_m2, b_m1): output_str = debugprint(final_result, file="str", print_op_info=True) lines = output_str.split("\n") - expected_output = """Elemwise{add,no_inplace} [id A] - |Subtensor{int64::} [id B] - | |for{cpu,scan_fn}.0 [id C] (outer_out_mit_sot-0) - | | |TensorConstant{5} [id D] (n_steps) - | | |IncSubtensor{Set;:int64:} [id E] (outer_in_mit_sot-0) - | | | |AllocEmpty{dtype='int64'} [id F] - | | | | |Elemwise{add,no_inplace} [id G] - | | | | |TensorConstant{5} [id D] - | | | | |Subtensor{int64} [id H] - | | | | |Shape [id I] - | | | | | |Subtensor{:int64:} [id J] - | | | | | | [id K] - | | | | | |ScalarConstant{2} [id L] - | | | | |ScalarConstant{0} [id M] - | | | |Subtensor{:int64:} [id J] - | | | |ScalarFromTensor [id N] - | | | |Subtensor{int64} [id H] - | | |IncSubtensor{Set;:int64:} [id O] (outer_in_mit_sot-1) - | | |AllocEmpty{dtype='int64'} [id P] - | | | |Elemwise{add,no_inplace} [id Q] - | | | |TensorConstant{5} [id D] - | | | |Subtensor{int64} [id R] - | | | |Shape [id S] - | | | | |Subtensor{:int64:} [id T] - | | | | | [id U] - | | | | |ScalarConstant{2} [id V] - | | | |ScalarConstant{0} [id W] - | | |Subtensor{:int64:} [id T] - | | |ScalarFromTensor [id X] - | | |Subtensor{int64} [id R] - | |ScalarConstant{2} [id Y] - |Subtensor{int64::} [id Z] - |for{cpu,scan_fn}.1 [id C] (outer_out_mit_sot-1) - |ScalarConstant{2} [id BA] + expected_output = """Add [id A] + ├─ Subtensor{start:} [id B] + │ ├─ Scan{scan_fn, while_loop=False, inplace=none}.0 [id C] (outer_out_mit_sot-0) + │ │ ├─ 5 [id D] (n_steps) + │ │ ├─ SetSubtensor{:stop} [id E] (outer_in_mit_sot-0) + │ │ │ ├─ AllocEmpty{dtype='int64'} [id F] + │ │ │ │ └─ Add [id G] + │ │ │ │ ├─ 5 [id D] + │ │ │ │ └─ Subtensor{i} [id H] + │ │ │ │ ├─ Shape [id I] + │ │ │ │ │ └─ Subtensor{:stop} [id J] + │ │ │ │ │ ├─ [id K] + │ │ │ │ │ └─ 2 [id L] + │ │ │ │ └─ 0 [id M] + │ │ │ ├─ Subtensor{:stop} [id J] + │ │ │ │ └─ ··· + │ │ │ └─ ScalarFromTensor [id N] + │ │ │ └─ Subtensor{i} [id H] + │ │ │ └─ ··· + │ │ └─ SetSubtensor{:stop} [id O] (outer_in_mit_sot-1) + │ │ ├─ AllocEmpty{dtype='int64'} [id P] + │ │ │ └─ Add [id Q] + │ │ │ ├─ 5 [id D] + │ │ │ └─ Subtensor{i} [id R] + │ │ │ ├─ Shape [id S] + │ │ │ │ └─ Subtensor{:stop} [id T] + │ │ │ │ ├─ [id U] + │ │ │ │ └─ 2 [id V] + │ │ │ └─ 0 [id W] + │ │ ├─ Subtensor{:stop} [id T] + │ │ │ └─ ··· + │ │ └─ ScalarFromTensor [id X] + │ │ └─ Subtensor{i} [id R] + │ │ └─ ··· + │ └─ 2 [id Y] + └─ Subtensor{start:} [id Z] + ├─ Scan{scan_fn, while_loop=False, inplace=none}.1 [id C] (outer_out_mit_sot-1) + │ └─ ··· + └─ 2 [id BA] Inner graphs: - for{cpu,scan_fn}.0 [id C] (outer_out_mit_sot-0) - >Elemwise{add,no_inplace} [id BB] (inner_out_mit_sot-0) - > |*1- [id BC] -> [id E] (inner_in_mit_sot-0-1) - > |*0- [id BD] -> [id E] (inner_in_mit_sot-0-0) - >Elemwise{add,no_inplace} [id BE] (inner_out_mit_sot-1) - > |*3- [id BF] -> [id O] (inner_in_mit_sot-1-1) - > |*2- [id BG] -> [id O] (inner_in_mit_sot-1-0) - - for{cpu,scan_fn}.1 [id C] (outer_out_mit_sot-1) - >Elemwise{add,no_inplace} [id BB] (inner_out_mit_sot-0) - >Elemwise{add,no_inplace} [id BE] (inner_out_mit_sot-1)""" + Scan{scan_fn, while_loop=False, inplace=none} [id C] + ← Add [id BB] (inner_out_mit_sot-0) + ├─ *1- [id BC] -> [id E] (inner_in_mit_sot-0-1) + └─ *0- [id BD] -> [id E] (inner_in_mit_sot-0-0) + ← Add [id BE] (inner_out_mit_sot-1) + ├─ *3- [id BF] -> [id O] (inner_in_mit_sot-1-1) + └─ *2- [id BG] -> [id O] (inner_in_mit_sot-1-0)""" for truth, out in zip(expected_output.split("\n"), lines): assert truth.strip() == out.strip() @@ -473,107 +501,119 @@ def test_debugprint_mitmot(): output_str = debugprint(final_result, file="str", print_op_info=True) lines = output_str.split("\n") - expected_output = """Subtensor{int64} [id A] - |for{cpu,grad_of_scan_fn}.1 [id B] (outer_out_sit_sot-0) - | |Elemwise{sub,no_inplace} [id C] (n_steps) - | | |Subtensor{int64} [id D] - | | | |Shape [id E] - | | | | |for{cpu,scan_fn} [id F] (outer_out_sit_sot-0) - | | | | |k [id G] (n_steps) - | | | | |IncSubtensor{Set;:int64:} [id H] (outer_in_sit_sot-0) - | | | | | |AllocEmpty{dtype='float64'} [id I] - | | | | | | |Elemwise{add,no_inplace} [id J] - | | | | | | | |k [id G] - | | | | | | | |Subtensor{int64} [id K] - | | | | | | | |Shape [id L] - | | | | | | | | |Unbroadcast{0} [id M] - | | | | | | | | |InplaceDimShuffle{x,0} [id N] - | | | | | | | | |Elemwise{second,no_inplace} [id O] - | | | | | | | | |A [id P] - | | | | | | | | |InplaceDimShuffle{x} [id Q] - | | | | | | | | |TensorConstant{1.0} [id R] - | | | | | | | |ScalarConstant{0} [id S] - | | | | | | |Subtensor{int64} [id T] - | | | | | | |Shape [id U] - | | | | | | | |Unbroadcast{0} [id M] - | | | | | | |ScalarConstant{1} [id V] - | | | | | |Unbroadcast{0} [id M] - | | | | | |ScalarFromTensor [id W] - | | | | | |Subtensor{int64} [id K] - | | | | |A [id P] (outer_in_non_seqs-0) - | | | |ScalarConstant{0} [id X] - | | |TensorConstant{1} [id Y] - | |Subtensor{:int64:} [id Z] (outer_in_seqs-0) - | | |Subtensor{::int64} [id BA] - | | | |Subtensor{:int64:} [id BB] - | | | | |for{cpu,scan_fn} [id F] (outer_out_sit_sot-0) - | | | | |ScalarConstant{-1} [id BC] - | | | |ScalarConstant{-1} [id BD] - | | |ScalarFromTensor [id BE] - | | |Elemwise{sub,no_inplace} [id C] - | |Subtensor{:int64:} [id BF] (outer_in_seqs-1) - | | |Subtensor{:int64:} [id BG] - | | | |Subtensor{::int64} [id BH] - | | | | |for{cpu,scan_fn} [id F] (outer_out_sit_sot-0) - | | | | |ScalarConstant{-1} [id BI] - | | | |ScalarConstant{-1} [id BJ] - | | |ScalarFromTensor [id BK] - | | |Elemwise{sub,no_inplace} [id C] - | |Subtensor{::int64} [id BL] (outer_in_mit_mot-0) - | | |IncSubtensor{Inc;int64::} [id BM] - | | | |Elemwise{second,no_inplace} [id BN] - | | | | |for{cpu,scan_fn} [id F] (outer_out_sit_sot-0) - | | | | |InplaceDimShuffle{x,x} [id BO] - | | | | |TensorConstant{0.0} [id BP] - | | | |IncSubtensor{Inc;int64} [id BQ] - | | | | |Elemwise{second,no_inplace} [id BR] - | | | | | |Subtensor{int64::} [id BS] - | | | | | | |for{cpu,scan_fn} [id F] (outer_out_sit_sot-0) - | | | | | | |ScalarConstant{1} [id BT] - | | | | | |InplaceDimShuffle{x,x} [id BU] - | | | | | |TensorConstant{0.0} [id BV] - | | | | |Elemwise{second} [id BW] - | | | | | |Subtensor{int64} [id BX] - | | | | | | |Subtensor{int64::} [id BS] - | | | | | | |ScalarConstant{-1} [id BY] - | | | | | |InplaceDimShuffle{x} [id BZ] - | | | | | |Elemwise{second,no_inplace} [id CA] - | | | | | |Sum{acc_dtype=float64} [id CB] - | | | | | | |Subtensor{int64} [id BX] - | | | | | |TensorConstant{1.0} [id CC] - | | | | |ScalarConstant{-1} [id BY] - | | | |ScalarConstant{1} [id BT] - | | |ScalarConstant{-1} [id CD] - | |Alloc [id CE] (outer_in_sit_sot-0) - | | |TensorConstant{0.0} [id CF] - | | |Elemwise{add,no_inplace} [id CG] - | | | |Elemwise{sub,no_inplace} [id C] - | | | |TensorConstant{1} [id CH] - | | |Subtensor{int64} [id CI] - | | |Shape [id CJ] - | | | |A [id P] - | | |ScalarConstant{0} [id CK] - | |A [id P] (outer_in_non_seqs-0) - |ScalarConstant{-1} [id CL] + expected_output = """Subtensor{i} [id A] + ├─ Scan{grad_of_scan_fn, while_loop=False, inplace=none}.1 [id B] (outer_out_sit_sot-0) + │ ├─ Sub [id C] (n_steps) + │ │ ├─ Subtensor{i} [id D] + │ │ │ ├─ Shape [id E] + │ │ │ │ └─ Scan{scan_fn, while_loop=False, inplace=none} [id F] (outer_out_sit_sot-0) + │ │ │ │ ├─ k [id G] (n_steps) + │ │ │ │ ├─ SetSubtensor{:stop} [id H] (outer_in_sit_sot-0) + │ │ │ │ │ ├─ AllocEmpty{dtype='float64'} [id I] + │ │ │ │ │ │ ├─ Add [id J] + │ │ │ │ │ │ │ ├─ k [id G] + │ │ │ │ │ │ │ └─ Subtensor{i} [id K] + │ │ │ │ │ │ │ ├─ Shape [id L] + │ │ │ │ │ │ │ │ └─ Unbroadcast{0} [id M] + │ │ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id N] + │ │ │ │ │ │ │ │ └─ Second [id O] + │ │ │ │ │ │ │ │ ├─ A [id P] + │ │ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id Q] + │ │ │ │ │ │ │ │ └─ 1.0 [id R] + │ │ │ │ │ │ │ └─ 0 [id S] + │ │ │ │ │ │ └─ Subtensor{i} [id T] + │ │ │ │ │ │ ├─ Shape [id U] + │ │ │ │ │ │ │ └─ Unbroadcast{0} [id M] + │ │ │ │ │ │ │ └─ ··· + │ │ │ │ │ │ └─ 1 [id V] + │ │ │ │ │ ├─ Unbroadcast{0} [id M] + │ │ │ │ │ │ └─ ··· + │ │ │ │ │ └─ ScalarFromTensor [id W] + │ │ │ │ │ └─ Subtensor{i} [id K] + │ │ │ │ │ └─ ··· + │ │ │ │ └─ A [id P] (outer_in_non_seqs-0) + │ │ │ └─ 0 [id X] + │ │ └─ 1 [id Y] + │ ├─ Subtensor{:stop} [id Z] (outer_in_seqs-0) + │ │ ├─ Subtensor{::step} [id BA] + │ │ │ ├─ Subtensor{:stop} [id BB] + │ │ │ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id F] (outer_out_sit_sot-0) + │ │ │ │ │ └─ ··· + │ │ │ │ └─ -1 [id BC] + │ │ │ └─ -1 [id BD] + │ │ └─ ScalarFromTensor [id BE] + │ │ └─ Sub [id C] + │ │ └─ ··· + │ ├─ Subtensor{:stop} [id BF] (outer_in_seqs-1) + │ │ ├─ Subtensor{:stop} [id BG] + │ │ │ ├─ Subtensor{::step} [id BH] + │ │ │ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id F] (outer_out_sit_sot-0) + │ │ │ │ │ └─ ··· + │ │ │ │ └─ -1 [id BI] + │ │ │ └─ -1 [id BJ] + │ │ └─ ScalarFromTensor [id BK] + │ │ └─ Sub [id C] + │ │ └─ ··· + │ ├─ Subtensor{::step} [id BL] (outer_in_mit_mot-0) + │ │ ├─ IncSubtensor{start:} [id BM] + │ │ │ ├─ Second [id BN] + │ │ │ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id F] (outer_out_sit_sot-0) + │ │ │ │ │ └─ ··· + │ │ │ │ └─ ExpandDims{axes=[0, 1]} [id BO] + │ │ │ │ └─ 0.0 [id BP] + │ │ │ ├─ IncSubtensor{i} [id BQ] + │ │ │ │ ├─ Second [id BR] + │ │ │ │ │ ├─ Subtensor{start:} [id BS] + │ │ │ │ │ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id F] (outer_out_sit_sot-0) + │ │ │ │ │ │ │ └─ ··· + │ │ │ │ │ │ └─ 1 [id BT] + │ │ │ │ │ └─ ExpandDims{axes=[0, 1]} [id BU] + │ │ │ │ │ └─ 0.0 [id BV] + │ │ │ │ ├─ Second [id BW] + │ │ │ │ │ ├─ Subtensor{i} [id BX] + │ │ │ │ │ │ ├─ Subtensor{start:} [id BS] + │ │ │ │ │ │ │ └─ ··· + │ │ │ │ │ │ └─ -1 [id BY] + │ │ │ │ │ └─ ExpandDims{axis=0} [id BZ] + │ │ │ │ │ └─ Second [id CA] + │ │ │ │ │ ├─ Sum{axes=None} [id CB] + │ │ │ │ │ │ └─ Subtensor{i} [id BX] + │ │ │ │ │ │ └─ ··· + │ │ │ │ │ └─ 1.0 [id CC] + │ │ │ │ └─ -1 [id BY] + │ │ │ └─ 1 [id BT] + │ │ └─ -1 [id CD] + │ ├─ Alloc [id CE] (outer_in_sit_sot-0) + │ │ ├─ 0.0 [id CF] + │ │ ├─ Add [id CG] + │ │ │ ├─ Sub [id C] + │ │ │ │ └─ ··· + │ │ │ └─ 1 [id CH] + │ │ └─ Subtensor{i} [id CI] + │ │ ├─ Shape [id CJ] + │ │ │ └─ A [id P] + │ │ └─ 0 [id CK] + │ └─ A [id P] (outer_in_non_seqs-0) + └─ -1 [id CL] Inner graphs: - for{cpu,grad_of_scan_fn}.1 [id B] (outer_out_sit_sot-0) - >Elemwise{add,no_inplace} [id CM] (inner_out_mit_mot-0-0) - > |Elemwise{mul} [id CN] - > | |*2- [id CO] -> [id BL] (inner_in_mit_mot-0-0) - > | |*5- [id CP] -> [id P] (inner_in_non_seqs-0) - > |*3- [id CQ] -> [id BL] (inner_in_mit_mot-0-1) - >Elemwise{add,no_inplace} [id CR] (inner_out_sit_sot-0) - > |Elemwise{mul} [id CS] - > | |*2- [id CO] -> [id BL] (inner_in_mit_mot-0-0) - > | |*0- [id CT] -> [id Z] (inner_in_seqs-0) - > |*4- [id CU] -> [id CE] (inner_in_sit_sot-0) - - for{cpu,scan_fn} [id F] (outer_out_sit_sot-0) - >Elemwise{mul,no_inplace} [id CV] (inner_out_sit_sot-0) - > |*0- [id CT] -> [id H] (inner_in_sit_sot-0) - > |*1- [id CW] -> [id P] (inner_in_non_seqs-0)""" + Scan{grad_of_scan_fn, while_loop=False, inplace=none} [id B] + ← Add [id CM] (inner_out_mit_mot-0-0) + ├─ Mul [id CN] + │ ├─ *2- [id CO] -> [id BL] (inner_in_mit_mot-0-0) + │ └─ *5- [id CP] -> [id P] (inner_in_non_seqs-0) + └─ *3- [id CQ] -> [id BL] (inner_in_mit_mot-0-1) + ← Add [id CR] (inner_out_sit_sot-0) + ├─ Mul [id CS] + │ ├─ *2- [id CO] -> [id BL] (inner_in_mit_mot-0-0) + │ └─ *0- [id CT] -> [id Z] (inner_in_seqs-0) + └─ *4- [id CU] -> [id CE] (inner_in_sit_sot-0) + + Scan{scan_fn, while_loop=False, inplace=none} [id F] + ← Mul [id CV] (inner_out_sit_sot-0) + ├─ *0- [id CT] -> [id H] (inner_in_sit_sot-0) + └─ *1- [id CW] -> [id P] (inner_in_non_seqs-0)""" for truth, out in zip(expected_output.split("\n"), lines): assert truth.strip() == out.strip() @@ -601,41 +641,40 @@ def no_shared_fn(n, x_tm1, M): # (i.e. from `Scan._fn`) out = pytensor.function([M], out, updates=updates, mode="FAST_RUN") - expected_output = """forall_inplace,cpu,scan_fn} [id A] 2 (outer_out_sit_sot-0) - |TensorConstant{20000} [id B] (n_steps) - |TensorConstant{[ 0 ..998 19999]} [id C] (outer_in_seqs-0) - |IncSubtensor{InplaceSet;:int64:} [id D] 1 (outer_in_sit_sot-0) - | |AllocEmpty{dtype='int64'} [id E] 0 - | | |TensorConstant{20000} [id B] - | |TensorConstant{(1,) of 0} [id F] - | |ScalarConstant{1} [id G] - | [id H] (outer_in_non_seqs-0) + expected_output = """Scan{scan_fn, while_loop=False, inplace=all} [id A] 2 (outer_out_sit_sot-0) + ├─ 20000 [id B] (n_steps) + ├─ [ 0 ... 998 19999] [id C] (outer_in_seqs-0) + ├─ SetSubtensor{:stop} [id D] 1 (outer_in_sit_sot-0) + │ ├─ AllocEmpty{dtype='int64'} [id E] 0 + │ │ └─ 20000 [id B] + │ ├─ [0] [id F] + │ └─ 1 [id G] + └─ [id H] (outer_in_non_seqs-0) Inner graphs: - forall_inplace,cpu,scan_fn} [id A] (outer_out_sit_sot-0) - >Elemwise{Composite} [id I] (inner_out_sit_sot-0) - > |TensorConstant{0} [id J] - > |Subtensor{int64, int64, uint8} [id K] - > | |*2- [id L] -> [id H] (inner_in_non_seqs-0) - > | |ScalarFromTensor [id M] - > | | |*0- [id N] -> [id C] (inner_in_seqs-0) - > | |ScalarFromTensor [id O] - > | | |*1- [id P] -> [id D] (inner_in_sit_sot-0) - > | |ScalarConstant{0} [id Q] - > |TensorConstant{1} [id R] - - Elemwise{Composite} [id I] - >Switch [id S] - > |LT [id T] - > | | [id U] - > | | [id V] - > | [id W] - > | [id U] + Scan{scan_fn, while_loop=False, inplace=all} [id A] + ← Composite{switch(lt(i0, i1), i2, i0)} [id I] (inner_out_sit_sot-0) + ├─ 0 [id J] + ├─ Subtensor{i, j, k} [id K] + │ ├─ *2- [id L] -> [id H] (inner_in_non_seqs-0) + │ ├─ ScalarFromTensor [id M] + │ │ └─ *0- [id N] -> [id C] (inner_in_seqs-0) + │ ├─ ScalarFromTensor [id O] + │ │ └─ *1- [id P] -> [id D] (inner_in_sit_sot-0) + │ └─ 0 [id Q] + └─ 1 [id R] + + Composite{switch(lt(i0, i1), i2, i0)} [id I] + ← Switch [id S] 'o0' + ├─ LT [id T] + │ ├─ i0 [id U] + │ └─ i1 [id V] + ├─ i2 [id W] + └─ i0 [id U] """ output_str = debugprint(out, file="str", print_op_info=True) - print(output_str) lines = output_str.split("\n") for truth, out in zip(expected_output.split("\n"), lines): diff --git a/tests/tensor/rewriting/test_basic.py b/tests/tensor/rewriting/test_basic.py index dc0dfa58a9..06694f6a67 100644 --- a/tests/tensor/rewriting/test_basic.py +++ b/tests/tensor/rewriting/test_basic.py @@ -12,6 +12,7 @@ from pytensor.compile.mode import get_default_mode, get_mode from pytensor.compile.ops import DeepCopyOp, deep_copy_op from pytensor.configdefaults import config +from pytensor.graph.basic import equal_computations from pytensor.graph.fg import FunctionGraph from pytensor.graph.rewriting.basic import check_stack_trace, out2in from pytensor.graph.rewriting.db import RewriteDatabaseQuery @@ -1410,31 +1411,28 @@ def simple_rewrite(self, g): def test_matrix_matrix(self): a, b = matrices("ab") - g = self.simple_rewrite(FunctionGraph([a, b], [dot(a, b).T])) - sg = "FunctionGraph(dot(InplaceDimShuffle{1,0}(b), InplaceDimShuffle{1,0}(a)))" - assert str(g) == sg, (str(g), sg) + g = self.simple_rewrite(FunctionGraph([a, b], [dot(a, b).T], clone=False)) + assert equal_computations(g.outputs, [dot(b.T, a.T)]) assert check_stack_trace(g, ops_to_check="all") def test_row_matrix(self): a = vector("a") b = matrix("b") g = rewrite( - FunctionGraph([a, b], [dot(a.dimshuffle("x", 0), b).T]), + FunctionGraph([a, b], [dot(a.dimshuffle("x", 0), b).T], clone=False), level="stabilize", ) - sg = "FunctionGraph(dot(InplaceDimShuffle{1,0}(b), InplaceDimShuffle{0,x}(a)))" - assert str(g) == sg, (str(g), sg) + assert equal_computations(g.outputs, [dot(b.T, a.dimshuffle(0, "x"))]) assert check_stack_trace(g, ops_to_check="all") def test_matrix_col(self): a = vector("a") b = matrix("b") g = rewrite( - FunctionGraph([a, b], [dot(b, a.dimshuffle(0, "x")).T]), + FunctionGraph([a, b], [dot(b, a.dimshuffle(0, "x")).T], clone=False), level="stabilize", ) - sg = "FunctionGraph(dot(InplaceDimShuffle{x,0}(a), InplaceDimShuffle{1,0}(b)))" - assert str(g) == sg, (str(g), sg) + assert equal_computations(g.outputs, [dot(a.dimshuffle("x", 0), b.T)]) assert check_stack_trace(g, ops_to_check="all") diff --git a/tests/tensor/rewriting/test_elemwise.py b/tests/tensor/rewriting/test_elemwise.py index ddec8c5292..1f149ed965 100644 --- a/tests/tensor/rewriting/test_elemwise.py +++ b/tests/tensor/rewriting/test_elemwise.py @@ -12,7 +12,7 @@ from pytensor.compile.mode import Mode, get_default_mode from pytensor.configdefaults import config from pytensor.gradient import grad -from pytensor.graph.basic import Constant +from pytensor.graph.basic import Constant, equal_computations from pytensor.graph.fg import FunctionGraph from pytensor.graph.rewriting.basic import check_stack_trace, out2in from pytensor.graph.rewriting.db import RewriteDatabaseQuery @@ -86,113 +86,66 @@ def inputs(xbc=(0, 0), ybc=(0, 0), zbc=(0, 0)): class TestDimshuffleLift: def test_double_transpose(self): - x, y, z = inputs() + x, *_ = inputs() e = ds(ds(x, (1, 0)), (1, 0)) - g = FunctionGraph([x], [e]) - # TODO FIXME: Construct these graphs and compare them. - assert ( - str(g) == "FunctionGraph(InplaceDimShuffle{1,0}(InplaceDimShuffle{1,0}(x)))" - ) + g = FunctionGraph([x], [e], clone=False) + assert isinstance(g.outputs[0].owner.op, DimShuffle) dimshuffle_lift.rewrite(g) - assert str(g) == "FunctionGraph(x)" + assert g.outputs[0] is x # no need to check_stack_trace as graph is supposed to be empty def test_merge2(self): - x, y, z = inputs() + x, *_ = inputs() e = ds(ds(x, (1, "x", 0)), (2, 0, "x", 1)) - g = FunctionGraph([x], [e]) - # TODO FIXME: Construct these graphs and compare them. - assert ( - str(g) - == "FunctionGraph(InplaceDimShuffle{2,0,x,1}(InplaceDimShuffle{1,x,0}(x)))" - ), str(g) + g = FunctionGraph([x], [e], clone=False) + assert len(g.apply_nodes) == 2 dimshuffle_lift.rewrite(g) - assert str(g) == "FunctionGraph(InplaceDimShuffle{0,1,x,x}(x))", str(g) + assert equal_computations(g.outputs, [x.dimshuffle(0, 1, "x", "x")]) # Check stacktrace was copied over correctly after rewrite was applied assert check_stack_trace(g, ops_to_check="all") def test_elim3(self): x, y, z = inputs() e = ds(ds(ds(x, (0, "x", 1)), (2, 0, "x", 1)), (1, 0)) - g = FunctionGraph([x], [e]) - # TODO FIXME: Construct these graphs and compare them. - assert str(g) == ( - "FunctionGraph(InplaceDimShuffle{1,0}(InplaceDimShuffle{2,0,x,1}" - "(InplaceDimShuffle{0,x,1}(x))))" - ), str(g) + g = FunctionGraph([x], [e], clone=False) + assert isinstance(g.outputs[0].owner.op, DimShuffle) dimshuffle_lift.rewrite(g) - assert str(g) == "FunctionGraph(x)", str(g) + assert g.outputs[0] is x # no need to check_stack_trace as graph is supposed to be empty def test_lift(self): x, y, z = inputs([False] * 1, [False] * 2, [False] * 3) e = x + y + z - g = FunctionGraph([x, y, z], [e]) - - # TODO FIXME: Construct these graphs and compare them. - # It does not really matter if the DimShuffles are inplace - # or not. - init_str_g_inplace = ( - "FunctionGraph(Elemwise{add,no_inplace}(InplaceDimShuffle{x,0,1}" - "(Elemwise{add,no_inplace}(InplaceDimShuffle{x,0}(x), y)), z))" - ) - init_str_g_noinplace = ( - "FunctionGraph(Elemwise{add,no_inplace}(DimShuffle{x,0,1}" - "(Elemwise{add,no_inplace}(DimShuffle{x,0}(x), y)), z))" - ) - assert str(g) in (init_str_g_inplace, init_str_g_noinplace), str(g) - - rewrite_str_g_inplace = ( - "FunctionGraph(Elemwise{add,no_inplace}(Elemwise{add,no_inplace}" - "(InplaceDimShuffle{x,x,0}(x), InplaceDimShuffle{x,0,1}(y)), z))" - ) - rewrite_str_g_noinplace = ( - "FunctionGraph(Elemwise{add,no_inplace}(Elemwise{add,no_inplace}" - "(DimShuffle{x,x,0}(x), DimShuffle{x,0,1}(y)), z))" - ) + g = FunctionGraph([x, y, z], [e], clone=False) dimshuffle_lift.rewrite(g) - assert str(g) in (rewrite_str_g_inplace, rewrite_str_g_noinplace), str(g) + assert equal_computations( + g.outputs, + [(x.dimshuffle("x", "x", 0) + y.dimshuffle("x", 0, 1)) + z], + ) # Check stacktrace was copied over correctly after rewrite was applied assert check_stack_trace(g, ops_to_check="all") def test_recursive_lift(self): - v = vector(dtype="float64") - m = matrix(dtype="float64") + v = vector("v", dtype="float64") + m = matrix("m", dtype="float64") out = ((v + 42) * (m + 84)).T - g = FunctionGraph([v, m], [out]) - # TODO FIXME: Construct these graphs and compare them. - init_str_g = ( - "FunctionGraph(InplaceDimShuffle{1,0}(Elemwise{mul,no_inplace}" - "(InplaceDimShuffle{x,0}(Elemwise{add,no_inplace}" - "(, " - "InplaceDimShuffle{x}(TensorConstant{42}))), " - "Elemwise{add,no_inplace}" - "(, " - "InplaceDimShuffle{x,x}(TensorConstant{84})))))" - ) - assert str(g) == init_str_g - new_out = local_dimshuffle_lift.transform(g, g.outputs[0].owner)[0] - new_g = FunctionGraph(g.inputs, [new_out]) - rewrite_str_g = ( - "FunctionGraph(Elemwise{mul,no_inplace}(Elemwise{add,no_inplace}" - "(InplaceDimShuffle{0,x}(), " - "InplaceDimShuffle{x,x}(TensorConstant{42})), " - "Elemwise{add,no_inplace}(InplaceDimShuffle{1,0}" - "(), " - "InplaceDimShuffle{x,x}(TensorConstant{84}))))" + g = FunctionGraph([v, m], [out], clone=False) + new_out = local_dimshuffle_lift.transform(g, g.outputs[0].owner) + assert equal_computations( + new_out, + [(v.dimshuffle(0, "x") + 42) * (m.T + 84)], ) - assert str(new_g) == rewrite_str_g # Check stacktrace was copied over correctly after rewrite was applied + new_g = FunctionGraph(g.inputs, new_out, clone=False) assert check_stack_trace(new_g, ops_to_check="all") def test_useless_dimshuffle(self): - x, _, _ = inputs() + x, *_ = inputs() e = ds(x, (0, 1)) - g = FunctionGraph([x], [e]) - # TODO FIXME: Construct these graphs and compare them. - assert str(g) == "FunctionGraph(InplaceDimShuffle{0,1}(x))" + g = FunctionGraph([x], [e], clone=False) + assert isinstance(g.outputs[0].owner.op, DimShuffle) dimshuffle_lift.rewrite(g) - assert str(g) == "FunctionGraph(x)" + assert g.outputs[0] is x # Check stacktrace was copied over correctly after rewrite was applied assert hasattr(g.outputs[0].tag, "trace") @@ -203,17 +156,10 @@ def test_dimshuffle_on_broadcastable(self): ds_y = ds(y, (2, 1, 0)) # useless ds_z = ds(z, (2, 1, 0)) # useful ds_u = ds(u, ("x")) # useful - g = FunctionGraph([x, y, z, u], [ds_x, ds_y, ds_z, ds_u]) - # TODO FIXME: Construct these graphs and compare them. - assert ( - str(g) - == "FunctionGraph(InplaceDimShuffle{0,x}(x), InplaceDimShuffle{2,1,0}(y), InplaceDimShuffle{2,1,0}(z), InplaceDimShuffle{x}(TensorConstant{1}))" - ) + g = FunctionGraph([x, y, z, u], [ds_x, ds_y, ds_z, ds_u], clone=False) + assert len(g.apply_nodes) == 4 dimshuffle_lift.rewrite(g) - assert ( - str(g) - == "FunctionGraph(x, y, InplaceDimShuffle{2,1,0}(z), InplaceDimShuffle{x}(TensorConstant{1}))" - ) + assert equal_computations(g.outputs, [x, y, z.T, u.dimshuffle("x")]) # Check stacktrace was copied over correctly after rewrite was applied assert hasattr(g.outputs[0].tag, "trace") @@ -237,34 +183,32 @@ def test_local_useless_dimshuffle_in_reshape(): reshape_dimshuffle_row, reshape_dimshuffle_col, ], + clone=False, ) - - # TODO FIXME: Construct these graphs and compare them. - assert str(g) == ( - "FunctionGraph(Reshape{1}(InplaceDimShuffle{x,0}(vector), Shape(vector)), " - "Reshape{2}(InplaceDimShuffle{x,0,x,1}(mat), Shape(mat)), " - "Reshape{2}(InplaceDimShuffle{1,x}(row), Shape(row)), " - "Reshape{2}(InplaceDimShuffle{0}(col), Shape(col)))" - ) + assert len(g.apply_nodes) == 4 * 3 useless_dimshuffle_in_reshape = out2in(local_useless_dimshuffle_in_reshape) useless_dimshuffle_in_reshape.rewrite(g) - assert str(g) == ( - "FunctionGraph(Reshape{1}(vector, Shape(vector)), " - "Reshape{2}(mat, Shape(mat)), " - "Reshape{2}(row, Shape(row)), " - "Reshape{2}(col, Shape(col)))" + assert equal_computations( + g.outputs, + [ + reshape(vec, vec.shape), + reshape(mat, mat.shape), + reshape(row, row.shape), + reshape(col, col.shape), + ], ) - # Check stacktrace was copied over correctly after rewrite was applied assert check_stack_trace(g, ops_to_check="all") # Check that the rewrite does not get applied when the order # of dimensions has changed. reshape_dimshuffle_mat2 = reshape(mat.dimshuffle("x", 1, "x", 0), mat.shape) - h = FunctionGraph([mat], [reshape_dimshuffle_mat2]) - str_h = str(h) + h = FunctionGraph([mat], [reshape_dimshuffle_mat2], clone=False) + assert len(h.apply_nodes) == 3 useless_dimshuffle_in_reshape.rewrite(h) - assert str(h) == str_h + assert equal_computations( + h.outputs, [reshape(mat.dimshuffle("x", 1, "x", 0), mat.shape)] + ) class TestFusion: diff --git a/tests/tensor/test_elemwise.py b/tests/tensor/test_elemwise.py index 8db1afe7e3..40e7db879c 100644 --- a/tests/tensor/test_elemwise.py +++ b/tests/tensor/test_elemwise.py @@ -676,14 +676,9 @@ def test_infer_shape(self, dtype=None, pre_scalar_op=None): def test_str(self): op = CAReduce(aes.add, axis=None) - assert str(op) == "CAReduce{add}" + assert str(op) == "CAReduce{add, axes=None}" op = CAReduce(aes.add, axis=(1,)) - assert str(op) == "CAReduce{add}{axis=[1]}" - - op = CAReduce(aes.add, axis=None, acc_dtype="float64") - assert str(op) == "CAReduce{add}{acc_dtype=float64}" - op = CAReduce(aes.add, axis=(1,), acc_dtype="float64") - assert str(op) == "CAReduce{add}{axis=[1], acc_dtype=float64}" + assert str(op) == "CAReduce{add, axis=1}" def test_repeated_axis(self): x = vector("x") @@ -802,10 +797,8 @@ def test_input_dimensions_match_c(self): self.check_input_dimensions_match(Mode(linker="c")) def test_str(self): - op = Elemwise(aes.add, inplace_pattern=None, name=None) - assert str(op) == "Elemwise{add}" op = Elemwise(aes.add, inplace_pattern={0: 0}, name=None) - assert str(op) == "Elemwise{add}[(0, 0)]" + assert str(op) == "Add" op = Elemwise(aes.add, inplace_pattern=None, name="my_op") assert str(op) == "my_op" diff --git a/tests/tensor/test_type.py b/tests/tensor/test_type.py index 8ef4ca2864..4e9c456829 100644 --- a/tests/tensor/test_type.py +++ b/tests/tensor/test_type.py @@ -252,7 +252,7 @@ def test_fixed_shape_basic(): assert t1.shape == (2, 3) assert t1.broadcastable == (False, False) - assert str(t1) == "TensorType(float64, (2, 3))" + assert str(t1) == "Matrix(float64, shape=(2, 3))" t1 = TensorType("float64", shape=(1,)) assert t1.shape == (1,) diff --git a/tests/tensor/test_var.py b/tests/tensor/test_var.py index 483b0f7d28..c17f524797 100644 --- a/tests/tensor/test_var.py +++ b/tests/tensor/test_var.py @@ -213,7 +213,7 @@ def test_print_constant(): c = pytensor.tensor.constant(1, name="const") assert str(c) == "const{1}" d = pytensor.tensor.constant(1) - assert str(d) == "TensorConstant{1}" + assert str(d) == "1" @pytest.mark.parametrize( diff --git a/tests/test_printing.py b/tests/test_printing.py index e481db5ac2..72c47943b5 100644 --- a/tests/test_printing.py +++ b/tests/test_printing.py @@ -106,9 +106,9 @@ def test_min_informative_str(): mis = min_informative_str(G).replace("\t", " ") - reference = """A. Elemwise{add,no_inplace} + reference = """A. Add B. C - C. Elemwise{add,no_inplace} + C. Add D. D E. E""" @@ -144,13 +144,13 @@ def test_debugprint(): s = s.getvalue() reference = dedent( r""" - Elemwise{add,no_inplace} [id 0] - |Elemwise{add,no_inplace} [id 1] 'C' - | |A [id 2] - | |B [id 3] - |Elemwise{add,no_inplace} [id 4] - |D [id 5] - |E [id 6] + Add [id 0] + ├─ Add [id 1] 'C' + │ ├─ A [id 2] + │ └─ B [id 3] + └─ Add [id 4] + ├─ D [id 5] + └─ E [id 6] """ ).lstrip() @@ -162,13 +162,13 @@ def test_debugprint(): # The additional white space are needed! reference = dedent( r""" - Elemwise{add,no_inplace} [id A] - |Elemwise{add,no_inplace} [id B] 'C' - | |A [id C] - | |B [id D] - |Elemwise{add,no_inplace} [id E] - |D [id F] - |E [id G] + Add [id A] + ├─ Add [id B] 'C' + │ ├─ A [id C] + │ └─ B [id D] + └─ Add [id E] + ├─ D [id F] + └─ E [id G] """ ).lstrip() @@ -180,11 +180,12 @@ def test_debugprint(): # The additional white space are needed! reference = dedent( r""" - Elemwise{add,no_inplace} [id A] - |Elemwise{add,no_inplace} [id B] 'C' - |Elemwise{add,no_inplace} [id C] - |D [id D] - |E [id E] + Add [id A] + ├─ Add [id B] 'C' + │ └─ ··· + └─ Add [id C] + ├─ D [id D] + └─ E [id E] """ ).lstrip() @@ -195,13 +196,13 @@ def test_debugprint(): s = s.getvalue() reference = dedent( r""" - Elemwise{add,no_inplace} - |Elemwise{add,no_inplace} 'C' - | |A - | |B - |Elemwise{add,no_inplace} - |D - |E + Add + ├─ Add 'C' + │ ├─ A + │ └─ B + └─ Add + ├─ D + └─ E """ ).lstrip() @@ -212,11 +213,11 @@ def test_debugprint(): s = s.getvalue() reference = dedent( r""" - Elemwise{add,no_inplace} 0 [None] - |A [None] - |B [None] - |D [None] - |E [None] + Add 0 [None] + ├─ A [None] + ├─ B [None] + ├─ D [None] + └─ E [None] """ ).lstrip() @@ -230,11 +231,11 @@ def test_debugprint(): s = s.getvalue() reference = dedent( r""" - Elemwise{add,no_inplace} 0 [None] - |A [None] - |B [None] - |D [None] - |E [None] + Add 0 [None] + ├─ A [None] + ├─ B [None] + ├─ D [None] + └─ E [None] """ ).lstrip() @@ -248,11 +249,11 @@ def test_debugprint(): s = s.getvalue() reference = dedent( r""" - Elemwise{add,no_inplace} 0 [None] - |A [None] - |B [None] - |D [None] - |E [None] + Add 0 [None] + ├─ A [None] + ├─ B [None] + ├─ D [None] + └─ E [None] """ ).lstrip() @@ -273,27 +274,27 @@ def test_debugprint(): s = s.getvalue() exp_res = dedent( r""" - Elemwise{Composite} 4 - |InplaceDimShuffle{x,0} v={0: [0]} 3 - | |CGemv{inplace} d={0: [0]} 2 - | |AllocEmpty{dtype='float64'} 1 - | | |Shape_i{0} 0 - | | |B - | |TensorConstant{1.0} - | |B - | | - | |TensorConstant{0.0} - |D - |A + Composite{(i2 + (i0 - i1))} 4 + ├─ ExpandDims{axis=0} v={0: [0]} 3 + │ └─ CGemv{inplace} d={0: [0]} 2 + │ ├─ AllocEmpty{dtype='float64'} 1 + │ │ └─ Shape_i{0} 0 + │ │ └─ B + │ ├─ 1.0 + │ ├─ B + │ ├─ + │ └─ 0.0 + ├─ D + └─ A Inner graphs: - Elemwise{Composite} - >add - > | - > |sub - > | - > | + Composite{(i2 + (i0 - i1))} + ← add 'o0' + ├─ i2 + └─ sub + ├─ i0 + └─ i1 """ ).lstrip() @@ -313,11 +314,11 @@ def test_debugprint_id_type(): debugprint(e_at, id_type="auto", file=s) s = s.getvalue() - exp_res = f"""Elemwise{{add,no_inplace}} [id {e_at.auto_name}] - |dot [id {d_at.auto_name}] - | | [id {b_at.auto_name}] - | | [id {a_at.auto_name}] - | [id {a_at.auto_name}] + exp_res = f"""Add [id {e_at.auto_name}] + ├─ dot [id {d_at.auto_name}] + │ ├─ [id {b_at.auto_name}] + │ └─ [id {a_at.auto_name}] + └─ [id {a_at.auto_name}] """ assert [l.strip() for l in s.split("\n")] == [ @@ -328,7 +329,7 @@ def test_debugprint_id_type(): def test_pprint(): x = dvector() y = x[1] - assert pp(y) == "[1]" + assert pp(y) == "[1]" def test_debugprint_inner_graph(): @@ -351,15 +352,15 @@ def test_debugprint_inner_graph(): lines = output_str.split("\n") exp_res = """MyInnerGraphOp [id A] - |3 [id B] - |4 [id C] + ├─ 3 [id B] + └─ 4 [id C] Inner graphs: MyInnerGraphOp [id A] - >op2 [id D] 'igo1' - > |*0- [id E] - > |*1- [id F] + ← op2 [id D] 'igo1' + ├─ *0- [id E] + └─ *1- [id F] """ for exp_line, res_line in zip(exp_res.split("\n"), lines): @@ -375,19 +376,19 @@ def test_debugprint_inner_graph(): lines = output_str.split("\n") exp_res = """MyInnerGraphOp [id A] - |5 [id B] + └─ 5 [id B] Inner graphs: MyInnerGraphOp [id A] - >MyInnerGraphOp [id C] - > |*0- [id D] - > |*1- [id E] + ← MyInnerGraphOp [id C] + ├─ *0- [id D] + └─ *1- [id E] MyInnerGraphOp [id C] - >op2 [id F] 'igo1' - > |*0- [id D] - > |*1- [id E] + ← op2 [id F] 'igo1' + ├─ *0- [id D] + └─ *1- [id E] """ for exp_line, res_line in zip(exp_res.split("\n"), lines):