diff --git a/README.rst b/README.rst
index 0be106c05f..9ea767c528 100644
--- a/README.rst
+++ b/README.rst
@@ -22,69 +22,69 @@ Getting started
.. code-block:: python
- import pytensor
- from pytensor import tensor as pt
-
- # Declare two symbolic floating-point scalars
- a = pt.dscalar("a")
- b = pt.dscalar("b")
-
- # Create a simple example expression
- c = a + b
-
- # Convert the expression into a callable object that takes `(a, b)`
- # values as input and computes the value of `c`.
- f_c = pytensor.function([a, b], c)
-
- assert f_c(1.5, 2.5) == 4.0
-
- # Compute the gradient of the example expression with respect to `a`
- dc = pytensor.grad(c, a)
-
- f_dc = pytensor.function([a, b], dc)
-
- assert f_dc(1.5, 2.5) == 1.0
-
- # Compiling functions with `pytensor.function` also optimizes
- # expression graphs by removing unnecessary operations and
- # replacing computations with more efficient ones.
-
- v = pt.vector("v")
- M = pt.matrix("M")
-
- d = a/a + (M + a).dot(v)
-
- pytensor.dprint(d)
- # Elemwise{add,no_inplace} [id A] ''
- # |InplaceDimShuffle{x} [id B] ''
- # | |Elemwise{true_div,no_inplace} [id C] ''
- # | |a [id D]
- # | |a [id D]
- # |dot [id E] ''
- # |Elemwise{add,no_inplace} [id F] ''
- # | |M [id G]
- # | |InplaceDimShuffle{x,x} [id H] ''
- # | |a [id D]
- # |v [id I]
-
- f_d = pytensor.function([a, v, M], d)
-
- # `a/a` -> `1` and the dot product is replaced with a BLAS function
- # (i.e. CGemv)
- pytensor.dprint(f_d)
- # Elemwise{Add}[(0, 1)] [id A] '' 5
- # |TensorConstant{(1,) of 1.0} [id B]
- # |CGemv{inplace} [id C] '' 4
- # |AllocEmpty{dtype='float64'} [id D] '' 3
- # | |Shape_i{0} [id E] '' 2
- # | |M [id F]
- # |TensorConstant{1.0} [id G]
- # |Elemwise{add,no_inplace} [id H] '' 1
- # | |M [id F]
- # | |InplaceDimShuffle{x,x} [id I] '' 0
- # | |a [id J]
- # |v [id K]
- # |TensorConstant{0.0} [id L]
+ import pytensor
+ from pytensor import tensor as pt
+
+ # Declare two symbolic floating-point scalars
+ a = pt.dscalar("a")
+ b = pt.dscalar("b")
+
+ # Create a simple example expression
+ c = a + b
+
+ # Convert the expression into a callable object that takes `(a, b)`
+ # values as input and computes the value of `c`.
+ f_c = pytensor.function([a, b], c)
+
+ assert f_c(1.5, 2.5) == 4.0
+
+ # Compute the gradient of the example expression with respect to `a`
+ dc = pytensor.grad(c, a)
+
+ f_dc = pytensor.function([a, b], dc)
+
+ assert f_dc(1.5, 2.5) == 1.0
+
+ # Compiling functions with `pytensor.function` also optimizes
+ # expression graphs by removing unnecessary operations and
+ # replacing computations with more efficient ones.
+
+ v = pt.vector("v")
+ M = pt.matrix("M")
+
+ d = a/a + (M + a).dot(v)
+
+ pytensor.dprint(d)
+ # Add [id A]
+ # ├─ ExpandDims{axis=0} [id B]
+ # │ └─ True_div [id C]
+ # │ ├─ a [id D]
+ # │ └─ a [id D]
+ # └─ dot [id E]
+ # ├─ Add [id F]
+ # │ ├─ M [id G]
+ # │ └─ ExpandDims{axes=[0, 1]} [id H]
+ # │ └─ a [id D]
+ # └─ v [id I]
+
+ f_d = pytensor.function([a, v, M], d)
+
+ # `a/a` -> `1` and the dot product is replaced with a BLAS function
+ # (i.e. CGemv)
+ pytensor.dprint(f_d)
+ # Add [id A] 5
+ # ├─ [1.] [id B]
+ # └─ CGemv{inplace} [id C] 4
+ # ├─ AllocEmpty{dtype='float64'} [id D] 3
+ # │ └─ Shape_i{0} [id E] 2
+ # │ └─ M [id F]
+ # ├─ 1.0 [id G]
+ # ├─ Add [id H] 1
+ # │ ├─ M [id F]
+ # │ └─ ExpandDims{axes=[0, 1]} [id I] 0
+ # │ └─ a [id J]
+ # ├─ v [id K]
+ # └─ 0.0 [id L]
See `the PyTensor documentation <https://pytensor.readthedocs.io/en/latest/>`__ for in-depth tutorials.
diff --git a/pytensor/graph/basic.py b/pytensor/graph/basic.py
index 7e59bf69e1..5111607a0a 100644
--- a/pytensor/graph/basic.py
+++ b/pytensor/graph/basic.py
@@ -763,13 +763,20 @@ def signature(self):
return (self.type, self.data)
def __str__(self):
- if self.name is not None:
- return self.name
- else:
- name = str(self.data)
- if len(name) > 20:
- name = name[:10] + "..." + name[-10:]
- return f"{type(self).__name__}{{{name}}}"
+ data_str = str(self.data)
+ if len(data_str) > 20:
+ data_str = data_str[:10].strip() + " ... " + data_str[-10:].strip()
+
+ if self.name is None:
+ return data_str
+
+ return f"{self.name}{{{data_str}}}"
+
+ def __repr__(self):
+ data_str = repr(self.data)
+ if len(data_str) > 20:
+ data_str = data_str[:10].strip() + " ... " + data_str[-10:].strip()
+ return f"{type(self).__name__}({repr(self.type)}, data={data_str})"
def clone(self, **kwargs):
return self
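
A rough sketch of what the reworked ``Constant`` formatting produces (illustrative values; the exact dtype and data ``repr`` depend on how the constant was built):

    import pytensor.tensor as pt

    z = pt.constant(2, name="z")
    # __str__ now wraps the data in the variable's name, e.g. "z{2}";
    # an unnamed constant prints just its data, e.g. "2".
    print(str(z))
    # __repr__ stays unambiguous, combining the type with a (possibly truncated)
    # repr of the data.
    print(repr(z))
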
diff --git a/pytensor/graph/op.py b/pytensor/graph/op.py
index b59394eb32..0cda6f1327 100644
--- a/pytensor/graph/op.py
+++ b/pytensor/graph/op.py
@@ -621,6 +621,11 @@ def make_thunk(
def __str__(self):
return getattr(type(self), "__name__", super().__str__())
+ def __repr__(self):
+ props = getattr(self, "__props__", ())
+ props = ",".join(f"{prop}={getattr(self, prop, '?')}" for prop in props)
+ return f"{self.__class__.__name__}({props})"
+
class _NoPythonOp(Op):
"""A class used to indicate that an `Op` does not provide a Python implementation.
diff --git a/pytensor/graph/utils.py b/pytensor/graph/utils.py
index 9b0abc5c73..6f00c687bf 100644
--- a/pytensor/graph/utils.py
+++ b/pytensor/graph/utils.py
@@ -234,6 +234,7 @@ def __eq__(self, other):
dct["__eq__"] = __eq__
+ # FIXME: This overrides __str__ inheritance when props are provided
if "__str__" not in dct:
if len(props) == 0:
diff --git a/pytensor/printing.py b/pytensor/printing.py
index 0b866f079e..11b266021b 100644
--- a/pytensor/printing.py
+++ b/pytensor/printing.py
@@ -291,7 +291,7 @@ def debugprint(
for var in inputs_to_print:
_debugprint(
var,
- prefix="-",
+ prefix="→ ",
depth=depth,
done=done,
print_type=print_type,
@@ -342,11 +342,17 @@ def debugprint(
if len(inner_graph_vars) > 0:
print("", file=_file)
- new_prefix = " >"
- new_prefix_child = " >"
+ prefix = ""
+ new_prefix = prefix + " ← "
+ new_prefix_child = prefix + " "
print("Inner graphs:", file=_file)
+ printed_inner_graphs_nodes = set()
for ig_var in inner_graph_vars:
+ if ig_var.owner in printed_inner_graphs_nodes:
+ continue
+ else:
+ printed_inner_graphs_nodes.add(ig_var.owner)
# This is a work-around to maintain backward compatibility
# (e.g. to only print inner graphs that have been compiled through
# a call to `Op.prepare_node`)
@@ -385,6 +391,7 @@ def debugprint(
_debugprint(
ig_var,
+ prefix=prefix,
depth=depth,
done=done,
print_type=print_type,
@@ -399,13 +406,14 @@ def debugprint(
print_op_info=print_op_info,
print_destroy_map=print_destroy_map,
print_view_map=print_view_map,
+ is_inner_graph_header=True,
)
if print_fgraph_inputs:
for inp in inner_inputs:
_debugprint(
inp,
- prefix="-",
+ prefix=" → ",
depth=depth,
done=done,
print_type=print_type,
@@ -485,6 +493,7 @@ def _debugprint(
parent_node: Optional[Apply] = None,
print_op_info: bool = False,
inner_graph_node: Optional[Apply] = None,
+ is_inner_graph_header: bool = False,
) -> TextIO:
r"""Print the graph represented by `var`.
@@ -625,7 +634,10 @@ def get_id_str(
else:
data = ""
- var_output = f"{prefix}{node.op}{output_idx}{id_str}{type_str}{var_name}{destroy_map_str}{view_map_str}{o}{data}"
+ if is_inner_graph_header:
+ var_output = f"{prefix}{node.op}{id_str}{destroy_map_str}{view_map_str}{o}"
+ else:
+ var_output = f"{prefix}{node.op}{output_idx}{id_str}{type_str}{var_name}{destroy_map_str}{view_map_str}{o}{data}"
if print_op_info and node not in op_information:
op_information.update(op_debug_information(node.op, node))
@@ -633,7 +645,7 @@ def get_id_str(
node_info = (
parent_node and op_information.get(parent_node)
) or op_information.get(node)
- if node_info and var in node_info:
+ if node_info and var in node_info and not is_inner_graph_header:
var_output = f"{var_output} ({node_info[var]})"
if profile and profile.apply_time and node in profile.apply_time:
@@ -660,12 +672,13 @@ def get_id_str(
if not already_done and (
not stop_on_name or not (hasattr(var, "name") and var.name is not None)
):
- new_prefix = prefix_child + " |"
- new_prefix_child = prefix_child + " |"
+ new_prefix = prefix_child + " ├─ "
+ new_prefix_child = prefix_child + " │ "
for in_idx, in_var in enumerate(node.inputs):
if in_idx == len(node.inputs) - 1:
- new_prefix_child = prefix_child + " "
+ new_prefix = prefix_child + " └─ "
+ new_prefix_child = prefix_child + " "
if hasattr(in_var, "owner") and hasattr(in_var.owner, "op"):
if (
@@ -698,6 +711,8 @@ def get_id_str(
print_view_map=print_view_map,
inner_graph_node=inner_graph_node,
)
+ elif not is_inner_graph_header:
+ print(prefix_child + " └─ ···", file=file)
else:
id_str = get_id_str(var)
diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py
index 96cf107d64..9c270e2a45 100644
--- a/pytensor/scalar/basic.py
+++ b/pytensor/scalar/basic.py
@@ -4143,6 +4143,7 @@ class Composite(ScalarInnerGraphOp):
def __init__(self, inputs, outputs, name="Composite"):
self.name = name
+ self._name = None
# We need to clone the graph as sometimes its nodes already
# contain a reference to an fgraph. As we want the Composite
# to be pickable, we can't have reference to fgraph.
@@ -4189,7 +4190,26 @@ def __init__(self, inputs, outputs, name="Composite"):
super().__init__()
def __str__(self):
- return self.name
+ if self._name is not None:
+ return self._name
+
+ # Rename internal variables
+ for i, r in enumerate(self.fgraph.inputs):
+ r.name = f"i{int(i)}"
+ for i, r in enumerate(self.fgraph.outputs):
+ r.name = f"o{int(i)}"
+ io = set(self.fgraph.inputs + self.fgraph.outputs)
+ for i, r in enumerate(self.fgraph.variables):
+ if r not in io and len(self.fgraph.clients[r]) > 1:
+ r.name = f"t{int(i)}"
+
+ if len(self.fgraph.outputs) > 1 or len(self.fgraph.apply_nodes) > 10:
+ self._name = "Composite{...}"
+ else:
+ outputs_str = ", ".join([pprint(output) for output in self.fgraph.outputs])
+ self._name = f"Composite{{{outputs_str}}}"
+
+ return self._name
def make_new_inplace(self, output_types_preference=None, name=None):
"""
diff --git a/pytensor/scan/op.py b/pytensor/scan/op.py
index b35cfa9b7b..69eb0a5fc2 100644
--- a/pytensor/scan/op.py
+++ b/pytensor/scan/op.py
@@ -1282,27 +1282,18 @@ def __eq__(self, other):
)
def __str__(self):
- device_str = "cpu"
- if self.info.as_while:
- name = "do_while"
- else:
- name = "for"
- aux_txt = "%s"
+ inplace = "none"
if len(self.destroy_map.keys()) > 0:
# Check if all outputs are inplace
if sorted(self.destroy_map.keys()) == sorted(
range(self.info.n_mit_mot + self.info.n_mit_sot + self.info.n_sit_sot)
):
- aux_txt += "all_inplace,%s,%s}"
+ inplace = "all"
else:
- aux_txt += "{inplace{"
- for k in self.destroy_map.keys():
- aux_txt += str(k) + ","
- aux_txt += "},%s,%s}"
- else:
- aux_txt += "{%s,%s}"
- aux_txt = aux_txt % (name, device_str, str(self.name))
- return aux_txt
+ inplace = str(list(self.destroy_map.keys()))
+ return (
+ f"Scan{{{self.name}, while_loop={self.info.as_while}, inplace={inplace}}}"
+ )
def __hash__(self):
return hash(
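
For illustration, a minimal Scan (the same power-of-``A`` loop the printing tests below are built on) now identifies itself by its inner-function name plus the while-loop and inplace flags, instead of the old ``for{cpu,...}`` spelling:

    import pytensor
    import pytensor.tensor as pt

    k = pt.iscalar("k")
    A = pt.vector("A")
    result, _ = pytensor.scan(
        fn=lambda prior_result, A: prior_result * A,
        outputs_info=pt.ones_like(A),
        non_sequences=A,
        n_steps=k,
    )
    # The Scan node now prints as, roughly:
    #   Scan{scan_fn, while_loop=False, inplace=none}
    pytensor.dprint(result[-1])
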
diff --git a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py
index f498ae5079..6d19579030 100644
--- a/pytensor/tensor/elemwise.py
+++ b/pytensor/tensor/elemwise.py
@@ -14,7 +14,7 @@
from pytensor.link.c.params_type import ParamsType
from pytensor.misc.frozendict import frozendict
from pytensor.misc.safe_asarray import _asarray
-from pytensor.printing import FunctionPrinter, Printer, pprint
+from pytensor.printing import Printer, pprint
from pytensor.scalar import get_scalar_type
from pytensor.scalar.basic import bool as scalar_bool
from pytensor.scalar.basic import identity as scalar_identity
@@ -215,10 +215,18 @@ def make_node(self, _input):
return Apply(self, [input], [output])
def __str__(self):
- if self.inplace:
- return "InplaceDimShuffle{%s}" % ",".join(str(x) for x in self.new_order)
- else:
- return "DimShuffle{%s}" % ",".join(str(x) for x in self.new_order)
+ shuffle = sorted(self.shuffle) != self.shuffle
+ if self.augment and not (shuffle or self.drop):
+ if len(self.augment) == 1:
+ return f"ExpandDims{{axis={self.augment[0]}}}"
+ return f"ExpandDims{{axes={self.augment}}}"
+ if self.drop and not (self.augment or shuffle):
+ if len(self.drop) == 1:
+ return f"DropDims{{axis={self.drop[0]}}}"
+ return f"DropDims{{axes={self.drop}}}"
+ if shuffle and not (self.augment or self.drop):
+ return f"Transpose{{axes={self.shuffle}}}"
+ return f"DimShuffle{{order=[{','.join(map(str, self.new_order))}]}}"
def perform(self, node, inp, out, params):
(res,) = inp
@@ -490,15 +498,9 @@ def make_node(self, *inputs):
return Apply(self, inputs, outputs)
def __str__(self):
- if self.name is None:
- if self.inplace_pattern:
- items = list(self.inplace_pattern.items())
- items.sort()
- return f"{type(self).__name__}{{{self.scalar_op}}}{items}"
- else:
- return f"{type(self).__name__}{{{self.scalar_op}}}"
- else:
+ if self.name:
return self.name
+ return str(self.scalar_op).capitalize()
def R_op(self, inputs, eval_points):
outs = self(*inputs, return_list=True)
@@ -1469,23 +1471,17 @@ def clone(
return res
- def __str__(self):
- prefix = f"{type(self).__name__}{{{self.scalar_op}}}"
- extra_params = []
-
- if self.axis is not None:
- axis = ", ".join(str(x) for x in self.axis)
- extra_params.append(f"axis=[{axis}]")
-
- if self.acc_dtype:
- extra_params.append(f"acc_dtype={self.acc_dtype}")
-
- extra_params_str = ", ".join(extra_params)
-
- if extra_params_str:
- return f"{prefix}{{{extra_params_str}}}"
+ def _axis_str(self):
+ axis = self.axis
+ if axis is None:
+ return "axes=None"
+ elif len(axis) == 1:
+ return f"axis={axis[0]}"
else:
- return f"{prefix}"
+ return f"axes={list(axis)}"
+
+ def __str__(self):
+ return f"{type(self).__name__}{{{self.scalar_op}, {self._axis_str()}}}"
def perform(self, node, inp, out):
(input,) = inp
@@ -1729,21 +1725,17 @@ def construct(symbol):
symbolname = symbolname or symbol.__name__
if symbolname.endswith("_inplace"):
- elemwise_name = f"Elemwise{{{symbolname},inplace}}"
- scalar_op = getattr(scalar, symbolname[: -len("_inplace")])
+ base_symbol_name = symbolname[: -len("_inplace")]
+ scalar_op = getattr(scalar, base_symbol_name)
inplace_scalar_op = scalar_op.__class__(transfer_type(0))
rval = Elemwise(
inplace_scalar_op,
{0: 0},
- name=elemwise_name,
nfunc_spec=(nfunc and (nfunc, nin, nout)),
)
else:
- elemwise_name = f"Elemwise{{{symbolname},no_inplace}}"
scalar_op = getattr(scalar, symbolname)
- rval = Elemwise(
- scalar_op, name=elemwise_name, nfunc_spec=(nfunc and (nfunc, nin, nout))
- )
+ rval = Elemwise(scalar_op, nfunc_spec=(nfunc and (nfunc, nin, nout)))
if getattr(symbol, "__doc__"):
rval.__doc__ = symbol.__doc__ + "\n\n " + rval.__doc__
@@ -1753,8 +1745,6 @@ def construct(symbol):
rval.__epydoc_asRoutine = symbol
rval.__module__ = symbol.__module__
- pprint.assign(rval, FunctionPrinter([symbolname.replace("_inplace", "=")]))
-
return rval
if symbol:
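
A quick sketch of the Op names produced by the changes in this file (comments show the essential part of each printed line; ``dprint`` still appends ids and other annotations):

    import pytensor
    import pytensor.tensor as pt

    x = pt.matrix("x")
    pytensor.dprint(x.T)                      # pure permutation -> Transpose{axes=[1, 0]}
    pytensor.dprint(x.dimshuffle("x", 0, 1))  # only new axes    -> ExpandDims{axis=0}
    pytensor.dprint(pt.exp(x))                # Elemwise prints its scalar op, capitalized -> Exp
    pytensor.dprint(x.sum())                  # CAReduce includes the axis spec -> Sum{axes=None}
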
diff --git a/pytensor/tensor/math.py b/pytensor/tensor/math.py
index 66373400d9..358c3f3724 100644
--- a/pytensor/tensor/math.py
+++ b/pytensor/tensor/math.py
@@ -583,7 +583,12 @@ def max_and_argmax(a, axis=None, keepdims=False):
return [out, argout]
-class NonZeroCAReduce(CAReduce):
+class FixedOpCAReduce(CAReduce):
+ def __str__(self):
+ return f"{type(self).__name__}{{{self._axis_str()}}}"
+
+
+class NonZeroDimsCAReduce(FixedOpCAReduce):
def _c_all(self, node, name, inames, onames, sub):
decl, checks, alloc, loop, end = super()._c_all(node, name, inames, onames, sub)
@@ -614,7 +619,7 @@ def _c_all(self, node, name, inames, onames, sub):
return decl, checks, alloc, loop, end
-class Max(NonZeroCAReduce):
+class Max(NonZeroDimsCAReduce):
nfunc_spec = ("max", 1, 1)
def __init__(self, axis):
@@ -625,7 +630,7 @@ def clone(self, **kwargs):
return type(self)(axis=axis)
-class Min(NonZeroCAReduce):
+class Min(NonZeroDimsCAReduce):
nfunc_spec = ("min", 1, 1)
def __init__(self, axis):
@@ -1496,7 +1501,7 @@ def complex_from_polar(abs, angle):
"""Return complex-valued tensor from polar coordinate specification."""
-class Mean(CAReduce):
+class Mean(FixedOpCAReduce):
__props__ = ("axis",)
nfunc_spec = ("mean", 1, 1)
@@ -2356,7 +2361,7 @@ def outer(x, y):
return dot(x.dimshuffle(0, "x"), y.dimshuffle("x", 0))
-class All(CAReduce):
+class All(FixedOpCAReduce):
"""Applies `logical and` to all the values of a tensor along the
specified axis(es).
@@ -2370,12 +2375,6 @@ def __init__(self, axis=None):
def _output_dtype(self, idtype):
return "bool"
- def __str__(self):
- if self.axis is None:
- return "All"
- else:
- return "All{%s}" % ", ".join(map(str, self.axis))
-
def make_node(self, input):
input = as_tensor_variable(input)
if input.dtype != "bool":
@@ -2392,7 +2391,7 @@ def clone(self, **kwargs):
return type(self)(axis=axis)
-class Any(CAReduce):
+class Any(FixedOpCAReduce):
"""Applies `bitwise or` to all the values of a tensor along the
specified axis(es).
@@ -2406,12 +2405,6 @@ def __init__(self, axis=None):
def _output_dtype(self, idtype):
return "bool"
- def __str__(self):
- if self.axis is None:
- return "Any"
- else:
- return "Any{%s}" % ", ".join(map(str, self.axis))
-
def make_node(self, input):
input = as_tensor_variable(input)
if input.dtype != "bool":
@@ -2428,7 +2421,7 @@ def clone(self, **kwargs):
return type(self)(axis=axis)
-class Sum(CAReduce):
+class Sum(FixedOpCAReduce):
"""
Sums all the values of a tensor along the specified axis(es).
@@ -2449,14 +2442,6 @@ def __init__(self, axis=None, dtype=None, acc_dtype=None):
upcast_discrete_output=True,
)
- def __str__(self):
- name = self.__class__.__name__
- axis = ""
- if self.axis is not None:
- axis = ", ".join(str(x) for x in self.axis)
- axis = f"axis=[{axis}], "
- return f"{name}{{{axis}acc_dtype={self.acc_dtype}}}"
-
def L_op(self, inp, out, grads):
(x,) = inp
@@ -2526,7 +2511,7 @@ def sum(input, axis=None, dtype=None, keepdims=False, acc_dtype=None):
pprint.assign(Sum, printing.FunctionPrinter(["sum"], ["axis"]))
-class Prod(CAReduce):
+class Prod(FixedOpCAReduce):
"""
Multiplies all the values of a tensor along the specified axis(es).
@@ -2537,7 +2522,6 @@ class Prod(CAReduce):
"""
__props__ = ("scalar_op", "axis", "dtype", "acc_dtype", "no_zeros_in_input")
-
nfunc_spec = ("prod", 1, 1)
def __init__(self, axis=None, dtype=None, acc_dtype=None, no_zeros_in_input=False):
@@ -2683,6 +2667,14 @@ def clone(self, **kwargs):
no_zeros_in_input=no_zeros_in_input,
)
+ def __str__(self):
+ if self.no_zeros_in_input:
+            return f"{super().__str__()[:-1]}, no_zeros_in_input}}"
+ return super().__str__()
+
+ def __repr__(self):
+ return f"{super().__repr__()[:-1]}, no_zeros_in_input={self.no_zeros_in_input})"
+
def prod(
input,
@@ -2751,7 +2743,7 @@ def c_code_cache_version(self):
mul_without_zeros = MulWithoutZeros(aes.upcast_out, name="mul_without_zeros")
-class ProdWithoutZeros(CAReduce):
+class ProdWithoutZeros(FixedOpCAReduce):
def __init__(self, axis=None, dtype=None, acc_dtype=None):
super().__init__(
mul_without_zeros,
diff --git a/pytensor/tensor/rewriting/math.py b/pytensor/tensor/rewriting/math.py
index cf604eec1b..b0d124f1c8 100644
--- a/pytensor/tensor/rewriting/math.py
+++ b/pytensor/tensor/rewriting/math.py
@@ -42,7 +42,8 @@
All,
Any,
Dot,
- NonZeroCAReduce,
+ FixedOpCAReduce,
+ NonZeroDimsCAReduce,
Prod,
ProdWithoutZeros,
Sum,
@@ -1671,7 +1672,8 @@ def local_op_of_op(fgraph, node):
ProdWithoutZeros,
]
+ CAReduce.__subclasses__()
- + NonZeroCAReduce.__subclasses__()
+ + FixedOpCAReduce.__subclasses__()
+ + NonZeroDimsCAReduce.__subclasses__()
)
diff --git a/pytensor/tensor/subtensor.py b/pytensor/tensor/subtensor.py
index 99c335348b..e0c0060058 100644
--- a/pytensor/tensor/subtensor.py
+++ b/pytensor/tensor/subtensor.py
@@ -840,22 +840,34 @@ def __hash__(self):
@staticmethod
def str_from_slice(entry):
- msg = []
- for x in [entry.start, entry.stop, entry.step]:
- if x is None:
- msg.append("")
- else:
- msg.append(str(x))
- return ":".join(msg)
+ if entry.step:
+ return ":".join(
+ (
+ "start" if entry.start else "",
+ "stop" if entry.stop else "",
+ "step",
+ )
+ )
+ if entry.stop:
+ return f"{'start' if entry.start else ''}:stop"
+ if entry.start:
+ return "start:"
+ return ":"
- def __str__(self):
+ @staticmethod
+ def str_from_indices(idx_list):
indices = []
- for entry in self.idx_list:
+ letter_indexes = 0
+ for entry in idx_list:
if isinstance(entry, slice):
- indices.append(self.str_from_slice(entry))
+ indices.append(Subtensor.str_from_slice(entry))
else:
- indices.append(str(entry))
- return f"{self.__class__.__name__}{{{', '.join(indices)}}}"
+ indices.append("ijk"[letter_indexes % 3] * (letter_indexes // 3 + 1))
+ letter_indexes += 1
+ return ", ".join(indices)
+
+ def __str__(self):
+ return f"{self.__class__.__name__}{{{self.str_from_indices(self.idx_list)}}}"
@staticmethod
def default_helper_c_code_args():
@@ -1498,21 +1510,8 @@ def __hash__(self):
return hash((type(self), idx_list, self.inplace, self.set_instead_of_inc))
def __str__(self):
- indices = []
- for entry in self.idx_list:
- if isinstance(entry, slice):
- indices.append(Subtensor.str_from_slice(entry))
- else:
- indices.append(str(entry))
- if self.inplace:
- msg = "Inplace"
- else:
- msg = ""
- if not self.set_instead_of_inc:
- msg += "Inc"
- else:
- msg += "Set"
- return f"{self.__class__.__name__}{{{msg};{', '.join(indices)}}}"
+ name = "SetSubtensor" if self.set_instead_of_inc else "IncSubtensor"
+ return f"{name}{{{Subtensor.str_from_indices(self.idx_list)}}}"
def make_node(self, x, y, *inputs):
"""
@@ -2661,10 +2660,10 @@ def __init__(
self.ignore_duplicates = ignore_duplicates
def __str__(self):
- return "{}{{{}, {}}}".format(
- self.__class__.__name__,
- "inplace=" + str(self.inplace),
- " set_instead_of_inc=" + str(self.set_instead_of_inc),
+ return (
+ "AdvancedSetSubtensor"
+ if self.set_instead_of_inc
+ else "AdvancedIncSubtensor"
)
def make_node(self, x, y, *inputs):
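
A short sketch of the new index notation: symbolic integer indices are rendered as ``i``, ``j``, ``k``, ... and slice bounds as ``start``/``stop``/``step``, while Inc/SetSubtensor drop the old ``{Inc;...}``/``{Set;...}`` spelling:

    import pytensor
    import pytensor.tensor as pt

    x = pt.vector("x")
    i = pt.iscalar("i")
    pytensor.dprint(x[i])                          # Subtensor{i}
    pytensor.dprint(x[1:])                         # Subtensor{start:}
    pytensor.dprint(pt.set_subtensor(x[:i], 0.0))  # SetSubtensor{:stop}
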
diff --git a/pytensor/tensor/type.py b/pytensor/tensor/type.py
index 3a07afecd7..c4356d8117 100644
--- a/pytensor/tensor/type.py
+++ b/pytensor/tensor/type.py
@@ -386,6 +386,8 @@ def __str__(self):
if self.name:
return self.name
else:
+ shape = self.shape
+ len_shape = len(shape)
def shape_str(s):
if s is None:
@@ -393,14 +395,18 @@ def shape_str(s):
else:
return str(s)
- formatted_shape = ", ".join([shape_str(s) for s in self.shape])
- if len(self.shape) == 1:
+ formatted_shape = ", ".join([shape_str(s) for s in shape])
+ if len_shape == 1:
formatted_shape += ","
- return f"TensorType({self.dtype}, ({formatted_shape}))"
+ if len_shape > 2:
+ name = f"Tensor{len_shape}"
+ else:
+ name = ("Scalar", "Vector", "Matrix")[len_shape]
+ return f"{name}({self.dtype}, shape=({formatted_shape}))"
def __repr__(self):
- return str(self)
+ return f"TensorType({self.dtype}, shape={self.shape})"
@staticmethod
def may_share_memory(a, b):
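
A minimal sketch of the split between the readable ``__str__`` and the unambiguous ``__repr__`` (dtype shown assuming the default ``floatX=float64``):

    import pytensor.tensor as pt
    from pytensor.tensor.type import TensorType

    M = pt.matrix("M")
    print(M.type)        # e.g. "Matrix(float64, shape=(?, ?))"
    print(repr(M.type))  # e.g. "TensorType(float64, shape=(None, None))"

    vtype = TensorType("float64", shape=(5,))
    print(vtype)         # e.g. "Vector(float64, shape=(5,))"
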
diff --git a/pytensor/tensor/var.py b/pytensor/tensor/var.py
index 428cd8fffd..7cfd4cef87 100644
--- a/pytensor/tensor/var.py
+++ b/pytensor/tensor/var.py
@@ -616,9 +616,9 @@ def __iter__(self):
except TypeError:
# This prevents accidental iteration via sum(self)
raise TypeError(
- "TensorType does not support iteration. "
- "Maybe you are using builtins.sum instead of "
- "pytensor.tensor.math.sum? (Maybe .max?)"
+ "TensorType does not support iteration.\n"
+ "\tDid you pass a PyTensor variable to a function that expects a list?\n"
+ "\tMaybe you are using builtins.sum instead of pytensor.tensor.sum?"
)
@property
@@ -1023,21 +1023,6 @@ def __init__(self, type: _TensorTypeType, data, name=None):
Constant.__init__(self, new_type, data, name)
- def __str__(self):
- unique_val = get_unique_value(self)
- if unique_val is not None:
- val = f"{self.data.shape} of {unique_val}"
- else:
- val = f"{self.data}"
- if len(val) > 20:
- val = val[:10] + ".." + val[-10:]
-
- if self.name is not None:
- name = self.name
- else:
- name = "TensorConstant"
- return f"{name}{{{val}}}"
-
def signature(self):
return TensorConstantSignature((self.type, self.data))
diff --git a/tests/compile/test_builders.py b/tests/compile/test_builders.py
index ad45b8270c..8fdeb18470 100644
--- a/tests/compile/test_builders.py
+++ b/tests/compile/test_builders.py
@@ -572,18 +572,18 @@ def test_debugprint():
lines = output_str.split("\n")
exp_res = """OpFromGraph{inline=False} [id A]
- |x [id B]
- |y [id C]
- |z [id D]
+ ├─ x [id B]
+ ├─ y [id C]
+ └─ z [id D]
Inner graphs:
OpFromGraph{inline=False} [id A]
- >Elemwise{add,no_inplace} [id E]
- > |*0- [id F]
- > |Elemwise{mul,no_inplace} [id G]
- > |*1- [id H]
- > |*2- [id I]
+ ← Add [id E]
+ ├─ *0- [id F]
+ └─ Mul [id G]
+ ├─ *1- [id H]
+ └─ *2- [id I]
"""
for truth, out in zip(exp_res.split("\n"), lines):
diff --git a/tests/graph/rewriting/test_basic.py b/tests/graph/rewriting/test_basic.py
index 7e249ddc3b..a690a1c88f 100644
--- a/tests/graph/rewriting/test_basic.py
+++ b/tests/graph/rewriting/test_basic.py
@@ -166,7 +166,7 @@ def test_constant(self):
e = op1(op1(x, y), y)
g = FunctionGraph([y], [e])
OpKeyPatternNodeRewriter((op1, z, "1"), (op2, "1", z)).rewrite(g)
- assert str(g) == "FunctionGraph(Op1(Op2(y, z), y))"
+ assert str(g) == "FunctionGraph(Op1(Op2(y, z{2}), y))"
def test_constraints(self):
x, y, z = MyVariable("x"), MyVariable("y"), MyVariable("z")
diff --git a/tests/link/test_utils.py b/tests/link/test_utils.py
index 812bd6304c..80bab78012 100644
--- a/tests/link/test_utils.py
+++ b/tests/link/test_utils.py
@@ -156,7 +156,7 @@ def func(*args, op=op):
assert (
"""
- # Elemwise{add,no_inplace}(Test
+ # Add(Test
# Op().0, Test
# Op().1)
"""
diff --git a/tests/scalar/test_basic.py b/tests/scalar/test_basic.py
index 13f6a21110..acbecc7aba 100644
--- a/tests/scalar/test_basic.py
+++ b/tests/scalar/test_basic.py
@@ -183,7 +183,7 @@ def test_composite_printing(self):
make_function(DualLinker().accept(g))
assert str(g) == (
- "FunctionGraph(*1 -> Composite(x, y, z), *1::1, *1::2, *1::3, *1::4, *1::5, *1::6, *1::7)"
+ "FunctionGraph(*1 -> Composite{...}(x, y, z), *1::1, *1::2, *1::3, *1::4, *1::5, *1::6, *1::7)"
)
def test_non_scalar_error(self):
diff --git a/tests/scan/test_printing.py b/tests/scan/test_printing.py
index 882c708966..725a48627d 100644
--- a/tests/scan/test_printing.py
+++ b/tests/scan/test_printing.py
@@ -26,40 +26,43 @@ def test_debugprint_sitsot():
output_str = debugprint(final_result, file="str", print_op_info=True)
lines = output_str.split("\n")
- expected_output = """Subtensor{int64} [id A]
- |Subtensor{int64::} [id B]
- | |for{cpu,scan_fn} [id C] (outer_out_sit_sot-0)
- | | |k [id D] (n_steps)
- | | |IncSubtensor{Set;:int64:} [id E] (outer_in_sit_sot-0)
- | | | |AllocEmpty{dtype='float64'} [id F]
- | | | | |Elemwise{add,no_inplace} [id G]
- | | | | | |k [id D]
- | | | | | |Subtensor{int64} [id H]
- | | | | | |Shape [id I]
- | | | | | | |Unbroadcast{0} [id J]
- | | | | | | |InplaceDimShuffle{x,0} [id K]
- | | | | | | |Elemwise{second,no_inplace} [id L]
- | | | | | | |A [id M]
- | | | | | | |InplaceDimShuffle{x} [id N]
- | | | | | | |TensorConstant{1.0} [id O]
- | | | | | |ScalarConstant{0} [id P]
- | | | | |Subtensor{int64} [id Q]
- | | | | |Shape [id R]
- | | | | | |Unbroadcast{0} [id J]
- | | | | |ScalarConstant{1} [id S]
- | | | |Unbroadcast{0} [id J]
- | | | |ScalarFromTensor [id T]
- | | | |Subtensor{int64} [id H]
- | | |A [id M] (outer_in_non_seqs-0)
- | |ScalarConstant{1} [id U]
- |ScalarConstant{-1} [id V]
+ expected_output = """Subtensor{i} [id A]
+ ├─ Subtensor{start:} [id B]
+ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id C] (outer_out_sit_sot-0)
+ │ │ ├─ k [id D] (n_steps)
+ │ │ ├─ SetSubtensor{:stop} [id E] (outer_in_sit_sot-0)
+ │ │ │ ├─ AllocEmpty{dtype='float64'} [id F]
+ │ │ │ │ ├─ Add [id G]
+ │ │ │ │ │ ├─ k [id D]
+ │ │ │ │ │ └─ Subtensor{i} [id H]
+ │ │ │ │ │ ├─ Shape [id I]
+ │ │ │ │ │ │ └─ Unbroadcast{0} [id J]
+ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id K]
+ │ │ │ │ │ │ └─ Second [id L]
+ │ │ │ │ │ │ ├─ A [id M]
+ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id N]
+ │ │ │ │ │ │ └─ 1.0 [id O]
+ │ │ │ │ │ └─ 0 [id P]
+ │ │ │ │ └─ Subtensor{i} [id Q]
+ │ │ │ │ ├─ Shape [id R]
+ │ │ │ │ │ └─ Unbroadcast{0} [id J]
+ │ │ │ │ │ └─ ···
+ │ │ │ │ └─ 1 [id S]
+ │ │ │ ├─ Unbroadcast{0} [id J]
+ │ │ │ │ └─ ···
+ │ │ │ └─ ScalarFromTensor [id T]
+ │ │ │ └─ Subtensor{i} [id H]
+ │ │ │ └─ ···
+ │ │ └─ A [id M] (outer_in_non_seqs-0)
+ │ └─ 1 [id U]
+ └─ -1 [id V]
Inner graphs:
- for{cpu,scan_fn} [id C] (outer_out_sit_sot-0)
- >Elemwise{mul,no_inplace} [id W] (inner_out_sit_sot-0)
- > |*0- [id X] -> [id E] (inner_in_sit_sot-0)
- > |*1- [id Y] -> [id M] (inner_in_non_seqs-0)"""
+ Scan{scan_fn, while_loop=False, inplace=none} [id C]
+ ← Mul [id W] (inner_out_sit_sot-0)
+ ├─ *0- [id X] -> [id E] (inner_in_sit_sot-0)
+ └─ *1- [id Y] -> [id M] (inner_in_non_seqs-0)"""
for truth, out in zip(expected_output.split("\n"), lines):
assert truth.strip() == out.strip()
@@ -81,40 +84,43 @@ def test_debugprint_sitsot_no_extra_info():
output_str = debugprint(final_result, file="str", print_op_info=False)
lines = output_str.split("\n")
- expected_output = """Subtensor{int64} [id A]
- |Subtensor{int64::} [id B]
- | |for{cpu,scan_fn} [id C]
- | | |k [id D]
- | | |IncSubtensor{Set;:int64:} [id E]
- | | | |AllocEmpty{dtype='float64'} [id F]
- | | | | |Elemwise{add,no_inplace} [id G]
- | | | | | |k [id D]
- | | | | | |Subtensor{int64} [id H]
- | | | | | |Shape [id I]
- | | | | | | |Unbroadcast{0} [id J]
- | | | | | | |InplaceDimShuffle{x,0} [id K]
- | | | | | | |Elemwise{second,no_inplace} [id L]
- | | | | | | |A [id M]
- | | | | | | |InplaceDimShuffle{x} [id N]
- | | | | | | |TensorConstant{1.0} [id O]
- | | | | | |ScalarConstant{0} [id P]
- | | | | |Subtensor{int64} [id Q]
- | | | | |Shape [id R]
- | | | | | |Unbroadcast{0} [id J]
- | | | | |ScalarConstant{1} [id S]
- | | | |Unbroadcast{0} [id J]
- | | | |ScalarFromTensor [id T]
- | | | |Subtensor{int64} [id H]
- | | |A [id M]
- | |ScalarConstant{1} [id U]
- |ScalarConstant{-1} [id V]
+ expected_output = """Subtensor{i} [id A]
+ ├─ Subtensor{start:} [id B]
+ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id C]
+ │ │ ├─ k [id D]
+ │ │ ├─ SetSubtensor{:stop} [id E]
+ │ │ │ ├─ AllocEmpty{dtype='float64'} [id F]
+ │ │ │ │ ├─ Add [id G]
+ │ │ │ │ │ ├─ k [id D]
+ │ │ │ │ │ └─ Subtensor{i} [id H]
+ │ │ │ │ │ ├─ Shape [id I]
+ │ │ │ │ │ │ └─ Unbroadcast{0} [id J]
+ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id K]
+ │ │ │ │ │ │ └─ Second [id L]
+ │ │ │ │ │ │ ├─ A [id M]
+ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id N]
+ │ │ │ │ │ │ └─ 1.0 [id O]
+ │ │ │ │ │ └─ 0 [id P]
+ │ │ │ │ └─ Subtensor{i} [id Q]
+ │ │ │ │ ├─ Shape [id R]
+ │ │ │ │ │ └─ Unbroadcast{0} [id J]
+ │ │ │ │ │ └─ ···
+ │ │ │ │ └─ 1 [id S]
+ │ │ │ ├─ Unbroadcast{0} [id J]
+ │ │ │ │ └─ ···
+ │ │ │ └─ ScalarFromTensor [id T]
+ │ │ │ └─ Subtensor{i} [id H]
+ │ │ │ └─ ···
+ │ │ └─ A [id M]
+ │ └─ 1 [id U]
+ └─ -1 [id V]
Inner graphs:
- for{cpu,scan_fn} [id C]
- >Elemwise{mul,no_inplace} [id W]
- > |*0- [id X] -> [id E]
- > |*1- [id Y] -> [id M]"""
+ Scan{scan_fn, while_loop=False, inplace=none} [id C]
+ ← Mul [id W]
+ ├─ *0- [id X] -> [id E]
+ └─ *1- [id Y] -> [id M]"""
for truth, out in zip(expected_output.split("\n"), lines):
assert truth.strip() == out.strip()
@@ -141,43 +147,48 @@ def test_debugprint_nitsot():
output_str = debugprint(polynomial, file="str", print_op_info=True)
lines = output_str.split("\n")
- expected_output = """Sum{acc_dtype=float64} [id A]
- |for{cpu,scan_fn} [id B] (outer_out_nit_sot-0)
- |Elemwise{scalar_minimum,no_inplace} [id C] (outer_in_nit_sot-0)
- | |Subtensor{int64} [id D]
- | | |Shape [id E]
- | | | |Subtensor{int64::} [id F] 'coefficients[0:]'
- | | | |coefficients [id G]
- | | | |ScalarConstant{0} [id H]
- | | |ScalarConstant{0} [id I]
- | |Subtensor{int64} [id J]
- | |Shape [id K]
- | | |Subtensor{int64::} [id L]
- | | |ARange{dtype='int64'} [id M]
- | | | |TensorConstant{0} [id N]
- | | | |TensorConstant{10000} [id O]
- | | | |TensorConstant{1} [id P]
- | | |ScalarConstant{0} [id Q]
- | |ScalarConstant{0} [id R]
- |Subtensor{:int64:} [id S] (outer_in_seqs-0)
- | |Subtensor{int64::} [id F] 'coefficients[0:]'
- | |ScalarFromTensor [id T]
- | |Elemwise{scalar_minimum,no_inplace} [id C]
- |Subtensor{:int64:} [id U] (outer_in_seqs-1)
- | |Subtensor{int64::} [id L]
- | |ScalarFromTensor [id V]
- | |Elemwise{scalar_minimum,no_inplace} [id C]
- |Elemwise{scalar_minimum,no_inplace} [id C] (outer_in_nit_sot-0)
- |x [id W] (outer_in_non_seqs-0)
+ expected_output = """Sum{axes=None} [id A]
+ └─ Scan{scan_fn, while_loop=False, inplace=none} [id B] (outer_out_nit_sot-0)
+ ├─ Minimum [id C] (outer_in_nit_sot-0)
+ │ ├─ Subtensor{i} [id D]
+ │ │ ├─ Shape [id E]
+ │ │ │ └─ Subtensor{start:} [id F] 'coefficients[0:]'
+ │ │ │ ├─ coefficients [id G]
+ │ │ │ └─ 0 [id H]
+ │ │ └─ 0 [id I]
+ │ └─ Subtensor{i} [id J]
+ │ ├─ Shape [id K]
+ │ │ └─ Subtensor{start:} [id L]
+ │ │ ├─ ARange{dtype='int64'} [id M]
+ │ │ │ ├─ 0 [id N]
+ │ │ │ ├─ 10000 [id O]
+ │ │ │ └─ 1 [id P]
+ │ │ └─ 0 [id Q]
+ │ └─ 0 [id R]
+ ├─ Subtensor{:stop} [id S] (outer_in_seqs-0)
+ │ ├─ Subtensor{start:} [id F] 'coefficients[0:]'
+ │ │ └─ ···
+ │ └─ ScalarFromTensor [id T]
+ │ └─ Minimum [id C]
+ │ └─ ···
+ ├─ Subtensor{:stop} [id U] (outer_in_seqs-1)
+ │ ├─ Subtensor{start:} [id L]
+ │ │ └─ ···
+ │ └─ ScalarFromTensor [id V]
+ │ └─ Minimum [id C]
+ │ └─ ···
+ ├─ Minimum [id C] (outer_in_nit_sot-0)
+ │ └─ ···
+ └─ x [id W] (outer_in_non_seqs-0)
Inner graphs:
- for{cpu,scan_fn} [id B] (outer_out_nit_sot-0)
- >Elemwise{mul,no_inplace} [id X] (inner_out_nit_sot-0)
- > |*0- [id Y] -> [id S] (inner_in_seqs-0)
- > |Elemwise{pow,no_inplace} [id Z]
- > |*2- [id BA] -> [id W] (inner_in_non_seqs-0)
- > |*1- [id BB] -> [id U] (inner_in_seqs-1)"""
+ Scan{scan_fn, while_loop=False, inplace=none} [id B]
+ ← Mul [id X] (inner_out_nit_sot-0)
+ ├─ *0- [id Y] -> [id S] (inner_in_seqs-0)
+ └─ Pow [id Z]
+ ├─ *2- [id BA] -> [id W] (inner_in_non_seqs-0)
+ └─ *1- [id BB] -> [id U] (inner_in_seqs-1)"""
for truth, out in zip(expected_output.split("\n"), lines):
assert truth.strip() == out.strip()
@@ -214,77 +225,85 @@ def compute_A_k(A, k):
output_str = debugprint(final_result, file="str", print_op_info=True)
lines = output_str.split("\n")
- expected_output = """Sum{acc_dtype=float64} [id A]
- |for{cpu,scan_fn} [id B] (outer_out_nit_sot-0)
- |Elemwise{scalar_minimum,no_inplace} [id C] (outer_in_nit_sot-0)
- | |Subtensor{int64} [id D]
- | | |Shape [id E]
- | | | |Subtensor{int64::} [id F] 'c[0:]'
- | | | |c [id G]
- | | | |ScalarConstant{0} [id H]
- | | |ScalarConstant{0} [id I]
- | |Subtensor{int64} [id J]
- | |Shape [id K]
- | | |Subtensor{int64::} [id L]
- | | |ARange{dtype='int64'} [id M]
- | | | |TensorConstant{0} [id N]
- | | | |TensorConstant{10} [id O]
- | | | |TensorConstant{1} [id P]
- | | |ScalarConstant{0} [id Q]
- | |ScalarConstant{0} [id R]
- |Subtensor{:int64:} [id S] (outer_in_seqs-0)
- | |Subtensor{int64::} [id F] 'c[0:]'
- | |ScalarFromTensor [id T]
- | |Elemwise{scalar_minimum,no_inplace} [id C]
- |Subtensor{:int64:} [id U] (outer_in_seqs-1)
- | |Subtensor{int64::} [id L]
- | |ScalarFromTensor [id V]
- | |Elemwise{scalar_minimum,no_inplace} [id C]
- |Elemwise{scalar_minimum,no_inplace} [id C] (outer_in_nit_sot-0)
- |A [id W] (outer_in_non_seqs-0)
- |k [id X] (outer_in_non_seqs-1)
+ expected_output = """Sum{axes=None} [id A]
+ └─ Scan{scan_fn, while_loop=False, inplace=none} [id B] (outer_out_nit_sot-0)
+ ├─ Minimum [id C] (outer_in_nit_sot-0)
+ │ ├─ Subtensor{i} [id D]
+ │ │ ├─ Shape [id E]
+ │ │ │ └─ Subtensor{start:} [id F] 'c[0:]'
+ │ │ │ ├─ c [id G]
+ │ │ │ └─ 0 [id H]
+ │ │ └─ 0 [id I]
+ │ └─ Subtensor{i} [id J]
+ │ ├─ Shape [id K]
+ │ │ └─ Subtensor{start:} [id L]
+ │ │ ├─ ARange{dtype='int64'} [id M]
+ │ │ │ ├─ 0 [id N]
+ │ │ │ ├─ 10 [id O]
+ │ │ │ └─ 1 [id P]
+ │ │ └─ 0 [id Q]
+ │ └─ 0 [id R]
+ ├─ Subtensor{:stop} [id S] (outer_in_seqs-0)
+ │ ├─ Subtensor{start:} [id F] 'c[0:]'
+ │ │ └─ ···
+ │ └─ ScalarFromTensor [id T]
+ │ └─ Minimum [id C]
+ │ └─ ···
+ ├─ Subtensor{:stop} [id U] (outer_in_seqs-1)
+ │ ├─ Subtensor{start:} [id L]
+ │ │ └─ ···
+ │ └─ ScalarFromTensor [id V]
+ │ └─ Minimum [id C]
+ │ └─ ···
+ ├─ Minimum [id C] (outer_in_nit_sot-0)
+ │ └─ ···
+ ├─ A [id W] (outer_in_non_seqs-0)
+ └─ k [id X] (outer_in_non_seqs-1)
Inner graphs:
- for{cpu,scan_fn} [id B] (outer_out_nit_sot-0)
- >Elemwise{mul,no_inplace} [id Y] (inner_out_nit_sot-0)
- > |InplaceDimShuffle{x} [id Z]
- > | |*0- [id BA] -> [id S] (inner_in_seqs-0)
- > |Elemwise{pow,no_inplace} [id BB]
- > |Subtensor{int64} [id BC]
- > | |Subtensor{int64::} [id BD]
- > | | |for{cpu,scan_fn} [id BE] (outer_out_sit_sot-0)
- > | | | |*3- [id BF] -> [id X] (inner_in_non_seqs-1) (n_steps)
- > | | | |IncSubtensor{Set;:int64:} [id BG] (outer_in_sit_sot-0)
- > | | | | |AllocEmpty{dtype='float64'} [id BH]
- > | | | | | |Elemwise{add,no_inplace} [id BI]
- > | | | | | | |*3- [id BF] -> [id X] (inner_in_non_seqs-1)
- > | | | | | | |Subtensor{int64} [id BJ]
- > | | | | | | |Shape [id BK]
- > | | | | | | | |Unbroadcast{0} [id BL]
- > | | | | | | | |InplaceDimShuffle{x,0} [id BM]
- > | | | | | | | |Elemwise{second,no_inplace} [id BN]
- > | | | | | | | |*2- [id BO] -> [id W] (inner_in_non_seqs-0)
- > | | | | | | | |InplaceDimShuffle{x} [id BP]
- > | | | | | | | |TensorConstant{1.0} [id BQ]
- > | | | | | | |ScalarConstant{0} [id BR]
- > | | | | | |Subtensor{int64} [id BS]
- > | | | | | |Shape [id BT]
- > | | | | | | |Unbroadcast{0} [id BL]
- > | | | | | |ScalarConstant{1} [id BU]
- > | | | | |Unbroadcast{0} [id BL]
- > | | | | |ScalarFromTensor [id BV]
- > | | | | |Subtensor{int64} [id BJ]
- > | | | |*2- [id BO] -> [id W] (inner_in_non_seqs-0) (outer_in_non_seqs-0)
- > | | |ScalarConstant{1} [id BW]
- > | |ScalarConstant{-1} [id BX]
- > |InplaceDimShuffle{x} [id BY]
- > |*1- [id BZ] -> [id U] (inner_in_seqs-1)
-
- for{cpu,scan_fn} [id BE] (outer_out_sit_sot-0)
- >Elemwise{mul,no_inplace} [id CA] (inner_out_sit_sot-0)
- > |*0- [id CB] -> [id BG] (inner_in_sit_sot-0)
- > |*1- [id CC] -> [id BO] (inner_in_non_seqs-0)"""
+ Scan{scan_fn, while_loop=False, inplace=none} [id B]
+ ← Mul [id Y] (inner_out_nit_sot-0)
+ ├─ ExpandDims{axis=0} [id Z]
+ │ └─ *0- [id BA] -> [id S] (inner_in_seqs-0)
+ └─ Pow [id BB]
+ ├─ Subtensor{i} [id BC]
+ │ ├─ Subtensor{start:} [id BD]
+ │ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id BE] (outer_out_sit_sot-0)
+ │ │ │ ├─ *3- [id BF] -> [id X] (inner_in_non_seqs-1) (n_steps)
+ │ │ │ ├─ SetSubtensor{:stop} [id BG] (outer_in_sit_sot-0)
+ │ │ │ │ ├─ AllocEmpty{dtype='float64'} [id BH]
+ │ │ │ │ │ ├─ Add [id BI]
+ │ │ │ │ │ │ ├─ *3- [id BF] -> [id X] (inner_in_non_seqs-1)
+ │ │ │ │ │ │ └─ Subtensor{i} [id BJ]
+ │ │ │ │ │ │ ├─ Shape [id BK]
+ │ │ │ │ │ │ │ └─ Unbroadcast{0} [id BL]
+ │ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id BM]
+ │ │ │ │ │ │ │ └─ Second [id BN]
+ │ │ │ │ │ │ │ ├─ *2- [id BO] -> [id W] (inner_in_non_seqs-0)
+ │ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id BP]
+ │ │ │ │ │ │ │ └─ 1.0 [id BQ]
+ │ │ │ │ │ │ └─ 0 [id BR]
+ │ │ │ │ │ └─ Subtensor{i} [id BS]
+ │ │ │ │ │ ├─ Shape [id BT]
+ │ │ │ │ │ │ └─ Unbroadcast{0} [id BL]
+ │ │ │ │ │ │ └─ ···
+ │ │ │ │ │ └─ 1 [id BU]
+ │ │ │ │ ├─ Unbroadcast{0} [id BL]
+ │ │ │ │ │ └─ ···
+ │ │ │ │ └─ ScalarFromTensor [id BV]
+ │ │ │ │ └─ Subtensor{i} [id BJ]
+ │ │ │ │ └─ ···
+ │ │ │ └─ *2- [id BO] -> [id W] (inner_in_non_seqs-0) (outer_in_non_seqs-0)
+ │ │ └─ 1 [id BW]
+ │ └─ -1 [id BX]
+ └─ ExpandDims{axis=0} [id BY]
+ └─ *1- [id BZ] -> [id U] (inner_in_seqs-1)
+
+ Scan{scan_fn, while_loop=False, inplace=none} [id BE]
+ ← Mul [id CA] (inner_out_sit_sot-0)
+ ├─ *0- [id CB] -> [id BG] (inner_in_sit_sot-0)
+ └─ *1- [id CC] -> [id BO] (inner_in_non_seqs-0)"""
for truth, out in zip(expected_output.split("\n"), lines):
assert truth.strip() == out.strip()
@@ -296,86 +315,94 @@ def compute_A_k(A, k):
)
lines = output_str.split("\n")
- expected_output = """-c [id A]
- -k [id B]
- -A [id C]
- Sum{acc_dtype=float64} [id D] 13
- |for{cpu,scan_fn} [id E] 12 (outer_out_nit_sot-0)
- |Elemwise{scalar_minimum,no_inplace} [id F] 7 (outer_in_nit_sot-0)
- | |Subtensor{int64} [id G] 6
- | | |Shape [id H] 5
- | | | |Subtensor{int64::} [id I] 'c[0:]' 4
- | | | |c [id A]
- | | | |ScalarConstant{0} [id J]
- | | |ScalarConstant{0} [id K]
- | |Subtensor{int64} [id L] 3
- | |Shape [id M] 2
- | | |Subtensor{int64::} [id N] 1
- | | |ARange{dtype='int64'} [id O] 0
- | | | |TensorConstant{0} [id P]
- | | | |TensorConstant{10} [id Q]
- | | | |TensorConstant{1} [id R]
- | | |ScalarConstant{0} [id S]
- | |ScalarConstant{0} [id T]
- |Subtensor{:int64:} [id U] 11 (outer_in_seqs-0)
- | |Subtensor{int64::} [id I] 'c[0:]' 4
- | |ScalarFromTensor [id V] 10
- | |Elemwise{scalar_minimum,no_inplace} [id F] 7
- |Subtensor{:int64:} [id W] 9 (outer_in_seqs-1)
- | |Subtensor{int64::} [id N] 1
- | |ScalarFromTensor [id X] 8
- | |Elemwise{scalar_minimum,no_inplace} [id F] 7
- |Elemwise{scalar_minimum,no_inplace} [id F] 7 (outer_in_nit_sot-0)
- |A [id C] (outer_in_non_seqs-0)
- |k [id B] (outer_in_non_seqs-1)
+ expected_output = """→ c [id A]
+ → k [id B]
+ → A [id C]
+ Sum{axes=None} [id D] 13
+ └─ Scan{scan_fn, while_loop=False, inplace=none} [id E] 12 (outer_out_nit_sot-0)
+ ├─ Minimum [id F] 7 (outer_in_nit_sot-0)
+ │ ├─ Subtensor{i} [id G] 6
+ │ │ ├─ Shape [id H] 5
+ │ │ │ └─ Subtensor{start:} [id I] 'c[0:]' 4
+ │ │ │ ├─ c [id A]
+ │ │ │ └─ 0 [id J]
+ │ │ └─ 0 [id K]
+ │ └─ Subtensor{i} [id L] 3
+ │ ├─ Shape [id M] 2
+ │ │ └─ Subtensor{start:} [id N] 1
+ │ │ ├─ ARange{dtype='int64'} [id O] 0
+ │ │ │ ├─ 0 [id P]
+ │ │ │ ├─ 10 [id Q]
+ │ │ │ └─ 1 [id R]
+ │ │ └─ 0 [id S]
+ │ └─ 0 [id T]
+ ├─ Subtensor{:stop} [id U] 11 (outer_in_seqs-0)
+ │ ├─ Subtensor{start:} [id I] 'c[0:]' 4
+ │ │ └─ ···
+ │ └─ ScalarFromTensor [id V] 10
+ │ └─ Minimum [id F] 7
+ │ └─ ···
+ ├─ Subtensor{:stop} [id W] 9 (outer_in_seqs-1)
+ │ ├─ Subtensor{start:} [id N] 1
+ │ │ └─ ···
+ │ └─ ScalarFromTensor [id X] 8
+ │ └─ Minimum [id F] 7
+ │ └─ ···
+ ├─ Minimum [id F] 7 (outer_in_nit_sot-0)
+ │ └─ ···
+ ├─ A [id C] (outer_in_non_seqs-0)
+ └─ k [id B] (outer_in_non_seqs-1)
Inner graphs:
- for{cpu,scan_fn} [id E] (outer_out_nit_sot-0)
- -*0- [id Y] -> [id U] (inner_in_seqs-0)
- -*1- [id Z] -> [id W] (inner_in_seqs-1)
- -*2- [id BA] -> [id C] (inner_in_non_seqs-0)
- -*3- [id BB] -> [id B] (inner_in_non_seqs-1)
- >Elemwise{mul,no_inplace} [id BC] (inner_out_nit_sot-0)
- > |InplaceDimShuffle{x} [id BD]
- > | |*0- [id Y] (inner_in_seqs-0)
- > |Elemwise{pow,no_inplace} [id BE]
- > |Subtensor{int64} [id BF]
- > | |Subtensor{int64::} [id BG]
- > | | |for{cpu,scan_fn} [id BH] (outer_out_sit_sot-0)
- > | | | |*3- [id BB] (inner_in_non_seqs-1) (n_steps)
- > | | | |IncSubtensor{Set;:int64:} [id BI] (outer_in_sit_sot-0)
- > | | | | |AllocEmpty{dtype='float64'} [id BJ]
- > | | | | | |Elemwise{add,no_inplace} [id BK]
- > | | | | | | |*3- [id BB] (inner_in_non_seqs-1)
- > | | | | | | |Subtensor{int64} [id BL]
- > | | | | | | |Shape [id BM]
- > | | | | | | | |Unbroadcast{0} [id BN]
- > | | | | | | | |InplaceDimShuffle{x,0} [id BO]
- > | | | | | | | |Elemwise{second,no_inplace} [id BP]
- > | | | | | | | |*2- [id BA] (inner_in_non_seqs-0)
- > | | | | | | | |InplaceDimShuffle{x} [id BQ]
- > | | | | | | | |TensorConstant{1.0} [id BR]
- > | | | | | | |ScalarConstant{0} [id BS]
- > | | | | | |Subtensor{int64} [id BT]
- > | | | | | |Shape [id BU]
- > | | | | | | |Unbroadcast{0} [id BN]
- > | | | | | |ScalarConstant{1} [id BV]
- > | | | | |Unbroadcast{0} [id BN]
- > | | | | |ScalarFromTensor [id BW]
- > | | | | |Subtensor{int64} [id BL]
- > | | | |*2- [id BA] (inner_in_non_seqs-0) (outer_in_non_seqs-0)
- > | | |ScalarConstant{1} [id BX]
- > | |ScalarConstant{-1} [id BY]
- > |InplaceDimShuffle{x} [id BZ]
- > |*1- [id Z] (inner_in_seqs-1)
-
- for{cpu,scan_fn} [id BH] (outer_out_sit_sot-0)
- -*0- [id CA] -> [id BI] (inner_in_sit_sot-0)
- -*1- [id CB] -> [id BA] (inner_in_non_seqs-0)
- >Elemwise{mul,no_inplace} [id CC] (inner_out_sit_sot-0)
- > |*0- [id CA] (inner_in_sit_sot-0)
- > |*1- [id CB] (inner_in_non_seqs-0)"""
+ Scan{scan_fn, while_loop=False, inplace=none} [id E]
+ → *0- [id Y] -> [id U] (inner_in_seqs-0)
+ → *1- [id Z] -> [id W] (inner_in_seqs-1)
+ → *2- [id BA] -> [id C] (inner_in_non_seqs-0)
+ → *3- [id BB] -> [id B] (inner_in_non_seqs-1)
+ ← Mul [id BC] (inner_out_nit_sot-0)
+ ├─ ExpandDims{axis=0} [id BD]
+ │ └─ *0- [id Y] (inner_in_seqs-0)
+ └─ Pow [id BE]
+ ├─ Subtensor{i} [id BF]
+ │ ├─ Subtensor{start:} [id BG]
+ │ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id BH] (outer_out_sit_sot-0)
+ │ │ │ ├─ *3- [id BB] (inner_in_non_seqs-1) (n_steps)
+ │ │ │ ├─ SetSubtensor{:stop} [id BI] (outer_in_sit_sot-0)
+ │ │ │ │ ├─ AllocEmpty{dtype='float64'} [id BJ]
+ │ │ │ │ │ ├─ Add [id BK]
+ │ │ │ │ │ │ ├─ *3- [id BB] (inner_in_non_seqs-1)
+ │ │ │ │ │ │ └─ Subtensor{i} [id BL]
+ │ │ │ │ │ │ ├─ Shape [id BM]
+ │ │ │ │ │ │ │ └─ Unbroadcast{0} [id BN]
+ │ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id BO]
+ │ │ │ │ │ │ │ └─ Second [id BP]
+ │ │ │ │ │ │ │ ├─ *2- [id BA] (inner_in_non_seqs-0)
+ │ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id BQ]
+ │ │ │ │ │ │ │ └─ 1.0 [id BR]
+ │ │ │ │ │ │ └─ 0 [id BS]
+ │ │ │ │ │ └─ Subtensor{i} [id BT]
+ │ │ │ │ │ ├─ Shape [id BU]
+ │ │ │ │ │ │ └─ Unbroadcast{0} [id BN]
+ │ │ │ │ │ │ └─ ···
+ │ │ │ │ │ └─ 1 [id BV]
+ │ │ │ │ ├─ Unbroadcast{0} [id BN]
+ │ │ │ │ │ └─ ···
+ │ │ │ │ └─ ScalarFromTensor [id BW]
+ │ │ │ │ └─ Subtensor{i} [id BL]
+ │ │ │ │ └─ ···
+ │ │ │ └─ *2- [id BA] (inner_in_non_seqs-0) (outer_in_non_seqs-0)
+ │ │ └─ 1 [id BX]
+ │ └─ -1 [id BY]
+ └─ ExpandDims{axis=0} [id BZ]
+ └─ *1- [id Z] (inner_in_seqs-1)
+
+ Scan{scan_fn, while_loop=False, inplace=none} [id BH]
+ → *0- [id CA] -> [id BI] (inner_in_sit_sot-0)
+ → *1- [id CB] -> [id BA] (inner_in_non_seqs-0)
+ ← Mul [id CC] (inner_out_sit_sot-0)
+ ├─ *0- [id CA] (inner_in_sit_sot-0)
+ └─ *1- [id CB] (inner_in_non_seqs-0)"""
for truth, out in zip(expected_output.split("\n"), lines):
assert truth.strip() == out.strip()
@@ -402,54 +429,55 @@ def fn(a_m2, a_m1, b_m2, b_m1):
output_str = debugprint(final_result, file="str", print_op_info=True)
lines = output_str.split("\n")
- expected_output = """Elemwise{add,no_inplace} [id A]
- |Subtensor{int64::} [id B]
- | |for{cpu,scan_fn}.0 [id C] (outer_out_mit_sot-0)
- | | |TensorConstant{5} [id D] (n_steps)
- | | |IncSubtensor{Set;:int64:} [id E] (outer_in_mit_sot-0)
- | | | |AllocEmpty{dtype='int64'} [id F]
- | | | | |Elemwise{add,no_inplace} [id G]
- | | | | |TensorConstant{5} [id D]
- | | | | |Subtensor{int64} [id H]
- | | | | |Shape [id I]
- | | | | | |Subtensor{:int64:} [id J]
- | | | | | | [id K]
- | | | | | |ScalarConstant{2} [id L]
- | | | | |ScalarConstant{0} [id M]
- | | | |Subtensor{:int64:} [id J]
- | | | |ScalarFromTensor [id N]
- | | | |Subtensor{int64} [id H]
- | | |IncSubtensor{Set;:int64:} [id O] (outer_in_mit_sot-1)
- | | |AllocEmpty{dtype='int64'} [id P]
- | | | |Elemwise{add,no_inplace} [id Q]
- | | | |TensorConstant{5} [id D]
- | | | |Subtensor{int64} [id R]
- | | | |Shape [id S]
- | | | | |Subtensor{:int64:} [id T]
- | | | | | [id U]
- | | | | |ScalarConstant{2} [id V]
- | | | |ScalarConstant{0} [id W]
- | | |Subtensor{:int64:} [id T]
- | | |ScalarFromTensor [id X]
- | | |Subtensor{int64} [id R]
- | |ScalarConstant{2} [id Y]
- |Subtensor{int64::} [id Z]
- |for{cpu,scan_fn}.1 [id C] (outer_out_mit_sot-1)
- |ScalarConstant{2} [id BA]
+ expected_output = """Add [id A]
+ ├─ Subtensor{start:} [id B]
+ │ ├─ Scan{scan_fn, while_loop=False, inplace=none}.0 [id C] (outer_out_mit_sot-0)
+ │ │ ├─ 5 [id D] (n_steps)
+ │ │ ├─ SetSubtensor{:stop} [id E] (outer_in_mit_sot-0)
+ │ │ │ ├─ AllocEmpty{dtype='int64'} [id F]
+ │ │ │ │ └─ Add [id G]
+ │ │ │ │ ├─ 5 [id D]
+ │ │ │ │ └─ Subtensor{i} [id H]
+ │ │ │ │ ├─ Shape [id I]
+ │ │ │ │ │ └─ Subtensor{:stop} [id J]
+ │ │ │ │ │ ├─ [id K]
+ │ │ │ │ │ └─ 2 [id L]
+ │ │ │ │ └─ 0 [id M]
+ │ │ │ ├─ Subtensor{:stop} [id J]
+ │ │ │ │ └─ ···
+ │ │ │ └─ ScalarFromTensor [id N]
+ │ │ │ └─ Subtensor{i} [id H]
+ │ │ │ └─ ···
+ │ │ └─ SetSubtensor{:stop} [id O] (outer_in_mit_sot-1)
+ │ │ ├─ AllocEmpty{dtype='int64'} [id P]
+ │ │ │ └─ Add [id Q]
+ │ │ │ ├─ 5 [id D]
+ │ │ │ └─ Subtensor{i} [id R]
+ │ │ │ ├─ Shape [id S]
+ │ │ │ │ └─ Subtensor{:stop} [id T]
+ │ │ │ │ ├─ [id U]
+ │ │ │ │ └─ 2 [id V]
+ │ │ │ └─ 0 [id W]
+ │ │ ├─ Subtensor{:stop} [id T]
+ │ │ │ └─ ···
+ │ │ └─ ScalarFromTensor [id X]
+ │ │ └─ Subtensor{i} [id R]
+ │ │ └─ ···
+ │ └─ 2 [id Y]
+ └─ Subtensor{start:} [id Z]
+ ├─ Scan{scan_fn, while_loop=False, inplace=none}.1 [id C] (outer_out_mit_sot-1)
+ │ └─ ···
+ └─ 2 [id BA]
Inner graphs:
- for{cpu,scan_fn}.0 [id C] (outer_out_mit_sot-0)
- >Elemwise{add,no_inplace} [id BB] (inner_out_mit_sot-0)
- > |*1- [id BC] -> [id E] (inner_in_mit_sot-0-1)
- > |*0- [id BD] -> [id E] (inner_in_mit_sot-0-0)
- >Elemwise{add,no_inplace} [id BE] (inner_out_mit_sot-1)
- > |*3- [id BF] -> [id O] (inner_in_mit_sot-1-1)
- > |*2- [id BG] -> [id O] (inner_in_mit_sot-1-0)
-
- for{cpu,scan_fn}.1 [id C] (outer_out_mit_sot-1)
- >Elemwise{add,no_inplace} [id BB] (inner_out_mit_sot-0)
- >Elemwise{add,no_inplace} [id BE] (inner_out_mit_sot-1)"""
+ Scan{scan_fn, while_loop=False, inplace=none} [id C]
+ ← Add [id BB] (inner_out_mit_sot-0)
+ ├─ *1- [id BC] -> [id E] (inner_in_mit_sot-0-1)
+ └─ *0- [id BD] -> [id E] (inner_in_mit_sot-0-0)
+ ← Add [id BE] (inner_out_mit_sot-1)
+ ├─ *3- [id BF] -> [id O] (inner_in_mit_sot-1-1)
+ └─ *2- [id BG] -> [id O] (inner_in_mit_sot-1-0)"""
for truth, out in zip(expected_output.split("\n"), lines):
assert truth.strip() == out.strip()
@@ -473,107 +501,119 @@ def test_debugprint_mitmot():
output_str = debugprint(final_result, file="str", print_op_info=True)
lines = output_str.split("\n")
- expected_output = """Subtensor{int64} [id A]
- |for{cpu,grad_of_scan_fn}.1 [id B] (outer_out_sit_sot-0)
- | |Elemwise{sub,no_inplace} [id C] (n_steps)
- | | |Subtensor{int64} [id D]
- | | | |Shape [id E]
- | | | | |for{cpu,scan_fn} [id F] (outer_out_sit_sot-0)
- | | | | |k [id G] (n_steps)
- | | | | |IncSubtensor{Set;:int64:} [id H] (outer_in_sit_sot-0)
- | | | | | |AllocEmpty{dtype='float64'} [id I]
- | | | | | | |Elemwise{add,no_inplace} [id J]
- | | | | | | | |k [id G]
- | | | | | | | |Subtensor{int64} [id K]
- | | | | | | | |Shape [id L]
- | | | | | | | | |Unbroadcast{0} [id M]
- | | | | | | | | |InplaceDimShuffle{x,0} [id N]
- | | | | | | | | |Elemwise{second,no_inplace} [id O]
- | | | | | | | | |A [id P]
- | | | | | | | | |InplaceDimShuffle{x} [id Q]
- | | | | | | | | |TensorConstant{1.0} [id R]
- | | | | | | | |ScalarConstant{0} [id S]
- | | | | | | |Subtensor{int64} [id T]
- | | | | | | |Shape [id U]
- | | | | | | | |Unbroadcast{0} [id M]
- | | | | | | |ScalarConstant{1} [id V]
- | | | | | |Unbroadcast{0} [id M]
- | | | | | |ScalarFromTensor [id W]
- | | | | | |Subtensor{int64} [id K]
- | | | | |A [id P] (outer_in_non_seqs-0)
- | | | |ScalarConstant{0} [id X]
- | | |TensorConstant{1} [id Y]
- | |Subtensor{:int64:} [id Z] (outer_in_seqs-0)
- | | |Subtensor{::int64} [id BA]
- | | | |Subtensor{:int64:} [id BB]
- | | | | |for{cpu,scan_fn} [id F] (outer_out_sit_sot-0)
- | | | | |ScalarConstant{-1} [id BC]
- | | | |ScalarConstant{-1} [id BD]
- | | |ScalarFromTensor [id BE]
- | | |Elemwise{sub,no_inplace} [id C]
- | |Subtensor{:int64:} [id BF] (outer_in_seqs-1)
- | | |Subtensor{:int64:} [id BG]
- | | | |Subtensor{::int64} [id BH]
- | | | | |for{cpu,scan_fn} [id F] (outer_out_sit_sot-0)
- | | | | |ScalarConstant{-1} [id BI]
- | | | |ScalarConstant{-1} [id BJ]
- | | |ScalarFromTensor [id BK]
- | | |Elemwise{sub,no_inplace} [id C]
- | |Subtensor{::int64} [id BL] (outer_in_mit_mot-0)
- | | |IncSubtensor{Inc;int64::} [id BM]
- | | | |Elemwise{second,no_inplace} [id BN]
- | | | | |for{cpu,scan_fn} [id F] (outer_out_sit_sot-0)
- | | | | |InplaceDimShuffle{x,x} [id BO]
- | | | | |TensorConstant{0.0} [id BP]
- | | | |IncSubtensor{Inc;int64} [id BQ]
- | | | | |Elemwise{second,no_inplace} [id BR]
- | | | | | |Subtensor{int64::} [id BS]
- | | | | | | |for{cpu,scan_fn} [id F] (outer_out_sit_sot-0)
- | | | | | | |ScalarConstant{1} [id BT]
- | | | | | |InplaceDimShuffle{x,x} [id BU]
- | | | | | |TensorConstant{0.0} [id BV]
- | | | | |Elemwise{second} [id BW]
- | | | | | |Subtensor{int64} [id BX]
- | | | | | | |Subtensor{int64::} [id BS]
- | | | | | | |ScalarConstant{-1} [id BY]
- | | | | | |InplaceDimShuffle{x} [id BZ]
- | | | | | |Elemwise{second,no_inplace} [id CA]
- | | | | | |Sum{acc_dtype=float64} [id CB]
- | | | | | | |Subtensor{int64} [id BX]
- | | | | | |TensorConstant{1.0} [id CC]
- | | | | |ScalarConstant{-1} [id BY]
- | | | |ScalarConstant{1} [id BT]
- | | |ScalarConstant{-1} [id CD]
- | |Alloc [id CE] (outer_in_sit_sot-0)
- | | |TensorConstant{0.0} [id CF]
- | | |Elemwise{add,no_inplace} [id CG]
- | | | |Elemwise{sub,no_inplace} [id C]
- | | | |TensorConstant{1} [id CH]
- | | |Subtensor{int64} [id CI]
- | | |Shape [id CJ]
- | | | |A [id P]
- | | |ScalarConstant{0} [id CK]
- | |A [id P] (outer_in_non_seqs-0)
- |ScalarConstant{-1} [id CL]
+ expected_output = """Subtensor{i} [id A]
+ ├─ Scan{grad_of_scan_fn, while_loop=False, inplace=none}.1 [id B] (outer_out_sit_sot-0)
+ │ ├─ Sub [id C] (n_steps)
+ │ │ ├─ Subtensor{i} [id D]
+ │ │ │ ├─ Shape [id E]
+ │ │ │ │ └─ Scan{scan_fn, while_loop=False, inplace=none} [id F] (outer_out_sit_sot-0)
+ │ │ │ │ ├─ k [id G] (n_steps)
+ │ │ │ │ ├─ SetSubtensor{:stop} [id H] (outer_in_sit_sot-0)
+ │ │ │ │ │ ├─ AllocEmpty{dtype='float64'} [id I]
+ │ │ │ │ │ │ ├─ Add [id J]
+ │ │ │ │ │ │ │ ├─ k [id G]
+ │ │ │ │ │ │ │ └─ Subtensor{i} [id K]
+ │ │ │ │ │ │ │ ├─ Shape [id L]
+ │ │ │ │ │ │ │ │ └─ Unbroadcast{0} [id M]
+ │ │ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id N]
+ │ │ │ │ │ │ │ │ └─ Second [id O]
+ │ │ │ │ │ │ │ │ ├─ A [id P]
+ │ │ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id Q]
+ │ │ │ │ │ │ │ │ └─ 1.0 [id R]
+ │ │ │ │ │ │ │ └─ 0 [id S]
+ │ │ │ │ │ │ └─ Subtensor{i} [id T]
+ │ │ │ │ │ │ ├─ Shape [id U]
+ │ │ │ │ │ │ │ └─ Unbroadcast{0} [id M]
+ │ │ │ │ │ │ │ └─ ···
+ │ │ │ │ │ │ └─ 1 [id V]
+ │ │ │ │ │ ├─ Unbroadcast{0} [id M]
+ │ │ │ │ │ │ └─ ···
+ │ │ │ │ │ └─ ScalarFromTensor [id W]
+ │ │ │ │ │ └─ Subtensor{i} [id K]
+ │ │ │ │ │ └─ ···
+ │ │ │ │ └─ A [id P] (outer_in_non_seqs-0)
+ │ │ │ └─ 0 [id X]
+ │ │ └─ 1 [id Y]
+ │ ├─ Subtensor{:stop} [id Z] (outer_in_seqs-0)
+ │ │ ├─ Subtensor{::step} [id BA]
+ │ │ │ ├─ Subtensor{:stop} [id BB]
+ │ │ │ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id F] (outer_out_sit_sot-0)
+ │ │ │ │ │ └─ ···
+ │ │ │ │ └─ -1 [id BC]
+ │ │ │ └─ -1 [id BD]
+ │ │ └─ ScalarFromTensor [id BE]
+ │ │ └─ Sub [id C]
+ │ │ └─ ···
+ │ ├─ Subtensor{:stop} [id BF] (outer_in_seqs-1)
+ │ │ ├─ Subtensor{:stop} [id BG]
+ │ │ │ ├─ Subtensor{::step} [id BH]
+ │ │ │ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id F] (outer_out_sit_sot-0)
+ │ │ │ │ │ └─ ···
+ │ │ │ │ └─ -1 [id BI]
+ │ │ │ └─ -1 [id BJ]
+ │ │ └─ ScalarFromTensor [id BK]
+ │ │ └─ Sub [id C]
+ │ │ └─ ···
+ │ ├─ Subtensor{::step} [id BL] (outer_in_mit_mot-0)
+ │ │ ├─ IncSubtensor{start:} [id BM]
+ │ │ │ ├─ Second [id BN]
+ │ │ │ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id F] (outer_out_sit_sot-0)
+ │ │ │ │ │ └─ ···
+ │ │ │ │ └─ ExpandDims{axes=[0, 1]} [id BO]
+ │ │ │ │ └─ 0.0 [id BP]
+ │ │ │ ├─ IncSubtensor{i} [id BQ]
+ │ │ │ │ ├─ Second [id BR]
+ │ │ │ │ │ ├─ Subtensor{start:} [id BS]
+ │ │ │ │ │ │ ├─ Scan{scan_fn, while_loop=False, inplace=none} [id F] (outer_out_sit_sot-0)
+ │ │ │ │ │ │ │ └─ ···
+ │ │ │ │ │ │ └─ 1 [id BT]
+ │ │ │ │ │ └─ ExpandDims{axes=[0, 1]} [id BU]
+ │ │ │ │ │ └─ 0.0 [id BV]
+ │ │ │ │ ├─ Second [id BW]
+ │ │ │ │ │ ├─ Subtensor{i} [id BX]
+ │ │ │ │ │ │ ├─ Subtensor{start:} [id BS]
+ │ │ │ │ │ │ │ └─ ···
+ │ │ │ │ │ │ └─ -1 [id BY]
+ │ │ │ │ │ └─ ExpandDims{axis=0} [id BZ]
+ │ │ │ │ │ └─ Second [id CA]
+ │ │ │ │ │ ├─ Sum{axes=None} [id CB]
+ │ │ │ │ │ │ └─ Subtensor{i} [id BX]
+ │ │ │ │ │ │ └─ ···
+ │ │ │ │ │ └─ 1.0 [id CC]
+ │ │ │ │ └─ -1 [id BY]
+ │ │ │ └─ 1 [id BT]
+ │ │ └─ -1 [id CD]
+ │ ├─ Alloc [id CE] (outer_in_sit_sot-0)
+ │ │ ├─ 0.0 [id CF]
+ │ │ ├─ Add [id CG]
+ │ │ │ ├─ Sub [id C]
+ │ │ │ │ └─ ···
+ │ │ │ └─ 1 [id CH]
+ │ │ └─ Subtensor{i} [id CI]
+ │ │ ├─ Shape [id CJ]
+ │ │ │ └─ A [id P]
+ │ │ └─ 0 [id CK]
+ │ └─ A [id P] (outer_in_non_seqs-0)
+ └─ -1 [id CL]
Inner graphs:
- for{cpu,grad_of_scan_fn}.1 [id B] (outer_out_sit_sot-0)
- >Elemwise{add,no_inplace} [id CM] (inner_out_mit_mot-0-0)
- > |Elemwise{mul} [id CN]
- > | |*2- [id CO] -> [id BL] (inner_in_mit_mot-0-0)
- > | |*5- [id CP] -> [id P] (inner_in_non_seqs-0)
- > |*3- [id CQ] -> [id BL] (inner_in_mit_mot-0-1)
- >Elemwise{add,no_inplace} [id CR] (inner_out_sit_sot-0)
- > |Elemwise{mul} [id CS]
- > | |*2- [id CO] -> [id BL] (inner_in_mit_mot-0-0)
- > | |*0- [id CT] -> [id Z] (inner_in_seqs-0)
- > |*4- [id CU] -> [id CE] (inner_in_sit_sot-0)
-
- for{cpu,scan_fn} [id F] (outer_out_sit_sot-0)
- >Elemwise{mul,no_inplace} [id CV] (inner_out_sit_sot-0)
- > |*0- [id CT] -> [id H] (inner_in_sit_sot-0)
- > |*1- [id CW] -> [id P] (inner_in_non_seqs-0)"""
+ Scan{grad_of_scan_fn, while_loop=False, inplace=none} [id B]
+ ← Add [id CM] (inner_out_mit_mot-0-0)
+ ├─ Mul [id CN]
+ │ ├─ *2- [id CO] -> [id BL] (inner_in_mit_mot-0-0)
+ │ └─ *5- [id CP] -> [id P] (inner_in_non_seqs-0)
+ └─ *3- [id CQ] -> [id BL] (inner_in_mit_mot-0-1)
+ ← Add [id CR] (inner_out_sit_sot-0)
+ ├─ Mul [id CS]
+ │ ├─ *2- [id CO] -> [id BL] (inner_in_mit_mot-0-0)
+ │ └─ *0- [id CT] -> [id Z] (inner_in_seqs-0)
+ └─ *4- [id CU] -> [id CE] (inner_in_sit_sot-0)
+
+ Scan{scan_fn, while_loop=False, inplace=none} [id F]
+ ← Mul [id CV] (inner_out_sit_sot-0)
+ ├─ *0- [id CT] -> [id H] (inner_in_sit_sot-0)
+ └─ *1- [id CW] -> [id P] (inner_in_non_seqs-0)"""
for truth, out in zip(expected_output.split("\n"), lines):
assert truth.strip() == out.strip()
@@ -601,41 +641,40 @@ def no_shared_fn(n, x_tm1, M):
# (i.e. from `Scan._fn`)
out = pytensor.function([M], out, updates=updates, mode="FAST_RUN")
- expected_output = """forall_inplace,cpu,scan_fn} [id A] 2 (outer_out_sit_sot-0)
- |TensorConstant{20000} [id B] (n_steps)
- |TensorConstant{[ 0 ..998 19999]} [id C] (outer_in_seqs-0)
- |IncSubtensor{InplaceSet;:int64:} [id D] 1 (outer_in_sit_sot-0)
- | |AllocEmpty{dtype='int64'} [id E] 0
- | | |TensorConstant{20000} [id B]
- | |TensorConstant{(1,) of 0} [id F]
- | |ScalarConstant{1} [id G]
- | [id H] (outer_in_non_seqs-0)
+ expected_output = """Scan{scan_fn, while_loop=False, inplace=all} [id A] 2 (outer_out_sit_sot-0)
+ ├─ 20000 [id B] (n_steps)
+ ├─ [ 0 ... 998 19999] [id C] (outer_in_seqs-0)
+ ├─ SetSubtensor{:stop} [id D] 1 (outer_in_sit_sot-0)
+ │ ├─ AllocEmpty{dtype='int64'} [id E] 0
+ │ │ └─ 20000 [id B]
+ │ ├─ [0] [id F]
+ │ └─ 1 [id G]
+ └─ [id H] (outer_in_non_seqs-0)
Inner graphs:
- forall_inplace,cpu,scan_fn} [id A] (outer_out_sit_sot-0)
- >Elemwise{Composite} [id I] (inner_out_sit_sot-0)
- > |TensorConstant{0} [id J]
- > |Subtensor{int64, int64, uint8} [id K]
- > | |*2- [id L] -> [id H] (inner_in_non_seqs-0)
- > | |ScalarFromTensor [id M]
- > | | |*0- [id N] -> [id C] (inner_in_seqs-0)
- > | |ScalarFromTensor [id O]
- > | | |*1- [id P] -> [id D] (inner_in_sit_sot-0)
- > | |ScalarConstant{0} [id Q]
- > |TensorConstant{1} [id R]
-
- Elemwise{Composite} [id I]
- >Switch [id S]
- > |LT [id T]
- > | | [id U]
- > | | [id V]
- > | [id W]
- > | [id U]
+ Scan{scan_fn, while_loop=False, inplace=all} [id A]
+ ← Composite{switch(lt(i0, i1), i2, i0)} [id I] (inner_out_sit_sot-0)
+ ├─ 0 [id J]
+ ├─ Subtensor{i, j, k} [id K]
+ │ ├─ *2- [id L] -> [id H] (inner_in_non_seqs-0)
+ │ ├─ ScalarFromTensor [id M]
+ │ │ └─ *0- [id N] -> [id C] (inner_in_seqs-0)
+ │ ├─ ScalarFromTensor [id O]
+ │ │ └─ *1- [id P] -> [id D] (inner_in_sit_sot-0)
+ │ └─ 0 [id Q]
+ └─ 1 [id R]
+
+ Composite{switch(lt(i0, i1), i2, i0)} [id I]
+ ← Switch [id S] 'o0'
+ ├─ LT [id T]
+ │ ├─ i0 [id U]
+ │ └─ i1 [id V]
+ ├─ i2 [id W]
+ └─ i0 [id U]
"""
output_str = debugprint(out, file="str", print_op_info=True)
- print(output_str)
lines = output_str.split("\n")
for truth, out in zip(expected_output.split("\n"), lines):
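
# The expected strings in the hunks above exercise the new tree-style
# `debugprint` layout: `├─`/`└─` branches, `←` for inner-graph outputs, and
# `···` for subgraphs that were already printed. A minimal, hedged sketch of
# producing that kind of output for a small `Scan` graph (the variables `k`
# and `A` below are illustrative and not taken from the tests):
#
#     import pytensor
#     import pytensor.tensor as pt
#
#     k = pt.iscalar("k")
#     A = pt.vector("A")
#
#     # Repeated elementwise multiplication of `A`, as in the scan tutorial.
#     result, _ = pytensor.scan(
#         fn=lambda prior, A: prior * A,
#         outputs_info=pt.ones_like(A),
#         non_sequences=A,
#         n_steps=k,
#     )
#
#     # Prints a Scan{scan_fn, ...} node with the tree markers and an
#     # "Inner graphs:" section, plus (outer_in_*/inner_in_*) annotations.
#     pytensor.dprint(result[-1], print_op_info=True)
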
diff --git a/tests/tensor/rewriting/test_basic.py b/tests/tensor/rewriting/test_basic.py
index dc0dfa58a9..06694f6a67 100644
--- a/tests/tensor/rewriting/test_basic.py
+++ b/tests/tensor/rewriting/test_basic.py
@@ -12,6 +12,7 @@
from pytensor.compile.mode import get_default_mode, get_mode
from pytensor.compile.ops import DeepCopyOp, deep_copy_op
from pytensor.configdefaults import config
+from pytensor.graph.basic import equal_computations
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import check_stack_trace, out2in
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
@@ -1410,31 +1411,28 @@ def simple_rewrite(self, g):
def test_matrix_matrix(self):
a, b = matrices("ab")
- g = self.simple_rewrite(FunctionGraph([a, b], [dot(a, b).T]))
- sg = "FunctionGraph(dot(InplaceDimShuffle{1,0}(b), InplaceDimShuffle{1,0}(a)))"
- assert str(g) == sg, (str(g), sg)
+ g = self.simple_rewrite(FunctionGraph([a, b], [dot(a, b).T], clone=False))
+ assert equal_computations(g.outputs, [dot(b.T, a.T)])
assert check_stack_trace(g, ops_to_check="all")
def test_row_matrix(self):
a = vector("a")
b = matrix("b")
g = rewrite(
- FunctionGraph([a, b], [dot(a.dimshuffle("x", 0), b).T]),
+ FunctionGraph([a, b], [dot(a.dimshuffle("x", 0), b).T], clone=False),
level="stabilize",
)
- sg = "FunctionGraph(dot(InplaceDimShuffle{1,0}(b), InplaceDimShuffle{0,x}(a)))"
- assert str(g) == sg, (str(g), sg)
+ assert equal_computations(g.outputs, [dot(b.T, a.dimshuffle(0, "x"))])
assert check_stack_trace(g, ops_to_check="all")
def test_matrix_col(self):
a = vector("a")
b = matrix("b")
g = rewrite(
- FunctionGraph([a, b], [dot(b, a.dimshuffle(0, "x")).T]),
+ FunctionGraph([a, b], [dot(b, a.dimshuffle(0, "x")).T], clone=False),
level="stabilize",
)
- sg = "FunctionGraph(dot(InplaceDimShuffle{x,0}(a), InplaceDimShuffle{1,0}(b)))"
- assert str(g) == sg, (str(g), sg)
+ assert equal_computations(g.outputs, [dot(a.dimshuffle("x", 0), b.T)])
assert check_stack_trace(g, ops_to_check="all")
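
# The rewritten assertions above replace brittle `str(FunctionGraph)`
# comparisons with `equal_computations`, which checks structural equality of
# graphs. A minimal sketch of that idiom (`a` and `b` are illustrative):
#
#     import pytensor.tensor as pt
#     from pytensor.graph.basic import equal_computations
#
#     a = pt.matrix("a")
#     b = pt.matrix("b")
#
#     expr = pt.dot(b.T, a.T)
#
#     # True: same ops applied to the same input variables in the same order.
#     assert equal_computations([expr], [pt.dot(b.T, a.T)])
#
#     # False: structurally different graphs (transpose of a product vs.
#     # product of transposes), even though they are mathematically equal.
#     assert not equal_computations([expr], [pt.dot(a, b).T])
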
diff --git a/tests/tensor/rewriting/test_elemwise.py b/tests/tensor/rewriting/test_elemwise.py
index ddec8c5292..1f149ed965 100644
--- a/tests/tensor/rewriting/test_elemwise.py
+++ b/tests/tensor/rewriting/test_elemwise.py
@@ -12,7 +12,7 @@
from pytensor.compile.mode import Mode, get_default_mode
from pytensor.configdefaults import config
from pytensor.gradient import grad
-from pytensor.graph.basic import Constant
+from pytensor.graph.basic import Constant, equal_computations
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import check_stack_trace, out2in
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
@@ -86,113 +86,66 @@ def inputs(xbc=(0, 0), ybc=(0, 0), zbc=(0, 0)):
class TestDimshuffleLift:
def test_double_transpose(self):
- x, y, z = inputs()
+ x, *_ = inputs()
e = ds(ds(x, (1, 0)), (1, 0))
- g = FunctionGraph([x], [e])
- # TODO FIXME: Construct these graphs and compare them.
- assert (
- str(g) == "FunctionGraph(InplaceDimShuffle{1,0}(InplaceDimShuffle{1,0}(x)))"
- )
+ g = FunctionGraph([x], [e], clone=False)
+ assert isinstance(g.outputs[0].owner.op, DimShuffle)
dimshuffle_lift.rewrite(g)
- assert str(g) == "FunctionGraph(x)"
+ assert g.outputs[0] is x
# no need to check_stack_trace as graph is supposed to be empty
def test_merge2(self):
- x, y, z = inputs()
+ x, *_ = inputs()
e = ds(ds(x, (1, "x", 0)), (2, 0, "x", 1))
- g = FunctionGraph([x], [e])
- # TODO FIXME: Construct these graphs and compare them.
- assert (
- str(g)
- == "FunctionGraph(InplaceDimShuffle{2,0,x,1}(InplaceDimShuffle{1,x,0}(x)))"
- ), str(g)
+ g = FunctionGraph([x], [e], clone=False)
+ assert len(g.apply_nodes) == 2
dimshuffle_lift.rewrite(g)
- assert str(g) == "FunctionGraph(InplaceDimShuffle{0,1,x,x}(x))", str(g)
+ assert equal_computations(g.outputs, [x.dimshuffle(0, 1, "x", "x")])
# Check stacktrace was copied over correctly after rewrite was applied
assert check_stack_trace(g, ops_to_check="all")
def test_elim3(self):
x, y, z = inputs()
e = ds(ds(ds(x, (0, "x", 1)), (2, 0, "x", 1)), (1, 0))
- g = FunctionGraph([x], [e])
- # TODO FIXME: Construct these graphs and compare them.
- assert str(g) == (
- "FunctionGraph(InplaceDimShuffle{1,0}(InplaceDimShuffle{2,0,x,1}"
- "(InplaceDimShuffle{0,x,1}(x))))"
- ), str(g)
+ g = FunctionGraph([x], [e], clone=False)
+ assert isinstance(g.outputs[0].owner.op, DimShuffle)
dimshuffle_lift.rewrite(g)
- assert str(g) == "FunctionGraph(x)", str(g)
+ assert g.outputs[0] is x
# no need to check_stack_trace as graph is supposed to be empty
def test_lift(self):
x, y, z = inputs([False] * 1, [False] * 2, [False] * 3)
e = x + y + z
- g = FunctionGraph([x, y, z], [e])
-
- # TODO FIXME: Construct these graphs and compare them.
- # It does not really matter if the DimShuffles are inplace
- # or not.
- init_str_g_inplace = (
- "FunctionGraph(Elemwise{add,no_inplace}(InplaceDimShuffle{x,0,1}"
- "(Elemwise{add,no_inplace}(InplaceDimShuffle{x,0}(x), y)), z))"
- )
- init_str_g_noinplace = (
- "FunctionGraph(Elemwise{add,no_inplace}(DimShuffle{x,0,1}"
- "(Elemwise{add,no_inplace}(DimShuffle{x,0}(x), y)), z))"
- )
- assert str(g) in (init_str_g_inplace, init_str_g_noinplace), str(g)
-
- rewrite_str_g_inplace = (
- "FunctionGraph(Elemwise{add,no_inplace}(Elemwise{add,no_inplace}"
- "(InplaceDimShuffle{x,x,0}(x), InplaceDimShuffle{x,0,1}(y)), z))"
- )
- rewrite_str_g_noinplace = (
- "FunctionGraph(Elemwise{add,no_inplace}(Elemwise{add,no_inplace}"
- "(DimShuffle{x,x,0}(x), DimShuffle{x,0,1}(y)), z))"
- )
+ g = FunctionGraph([x, y, z], [e], clone=False)
dimshuffle_lift.rewrite(g)
- assert str(g) in (rewrite_str_g_inplace, rewrite_str_g_noinplace), str(g)
+ assert equal_computations(
+ g.outputs,
+ [(x.dimshuffle("x", "x", 0) + y.dimshuffle("x", 0, 1)) + z],
+ )
# Check stacktrace was copied over correctly after rewrite was applied
assert check_stack_trace(g, ops_to_check="all")
def test_recursive_lift(self):
- v = vector(dtype="float64")
- m = matrix(dtype="float64")
+ v = vector("v", dtype="float64")
+ m = matrix("m", dtype="float64")
out = ((v + 42) * (m + 84)).T
- g = FunctionGraph([v, m], [out])
- # TODO FIXME: Construct these graphs and compare them.
- init_str_g = (
- "FunctionGraph(InplaceDimShuffle{1,0}(Elemwise{mul,no_inplace}"
- "(InplaceDimShuffle{x,0}(Elemwise{add,no_inplace}"
- "(, "
- "InplaceDimShuffle{x}(TensorConstant{42}))), "
- "Elemwise{add,no_inplace}"
- "(, "
- "InplaceDimShuffle{x,x}(TensorConstant{84})))))"
- )
- assert str(g) == init_str_g
- new_out = local_dimshuffle_lift.transform(g, g.outputs[0].owner)[0]
- new_g = FunctionGraph(g.inputs, [new_out])
- rewrite_str_g = (
- "FunctionGraph(Elemwise{mul,no_inplace}(Elemwise{add,no_inplace}"
- "(InplaceDimShuffle{0,x}(), "
- "InplaceDimShuffle{x,x}(TensorConstant{42})), "
- "Elemwise{add,no_inplace}(InplaceDimShuffle{1,0}"
- "(), "
- "InplaceDimShuffle{x,x}(TensorConstant{84}))))"
+ g = FunctionGraph([v, m], [out], clone=False)
+ new_out = local_dimshuffle_lift.transform(g, g.outputs[0].owner)
+ assert equal_computations(
+ new_out,
+ [(v.dimshuffle(0, "x") + 42) * (m.T + 84)],
)
- assert str(new_g) == rewrite_str_g
# Check stacktrace was copied over correctly after rewrite was applied
+ new_g = FunctionGraph(g.inputs, new_out, clone=False)
assert check_stack_trace(new_g, ops_to_check="all")
def test_useless_dimshuffle(self):
- x, _, _ = inputs()
+ x, *_ = inputs()
e = ds(x, (0, 1))
- g = FunctionGraph([x], [e])
- # TODO FIXME: Construct these graphs and compare them.
- assert str(g) == "FunctionGraph(InplaceDimShuffle{0,1}(x))"
+ g = FunctionGraph([x], [e], clone=False)
+ assert isinstance(g.outputs[0].owner.op, DimShuffle)
dimshuffle_lift.rewrite(g)
- assert str(g) == "FunctionGraph(x)"
+ assert g.outputs[0] is x
# Check stacktrace was copied over correctly after rewrite was applied
assert hasattr(g.outputs[0].tag, "trace")
@@ -203,17 +156,10 @@ def test_dimshuffle_on_broadcastable(self):
ds_y = ds(y, (2, 1, 0)) # useless
ds_z = ds(z, (2, 1, 0)) # useful
ds_u = ds(u, ("x")) # useful
- g = FunctionGraph([x, y, z, u], [ds_x, ds_y, ds_z, ds_u])
- # TODO FIXME: Construct these graphs and compare them.
- assert (
- str(g)
- == "FunctionGraph(InplaceDimShuffle{0,x}(x), InplaceDimShuffle{2,1,0}(y), InplaceDimShuffle{2,1,0}(z), InplaceDimShuffle{x}(TensorConstant{1}))"
- )
+ g = FunctionGraph([x, y, z, u], [ds_x, ds_y, ds_z, ds_u], clone=False)
+ assert len(g.apply_nodes) == 4
dimshuffle_lift.rewrite(g)
- assert (
- str(g)
- == "FunctionGraph(x, y, InplaceDimShuffle{2,1,0}(z), InplaceDimShuffle{x}(TensorConstant{1}))"
- )
+ assert equal_computations(g.outputs, [x, y, z.T, u.dimshuffle("x")])
# Check stacktrace was copied over correctly after rewrite was applied
assert hasattr(g.outputs[0].tag, "trace")
@@ -237,34 +183,32 @@ def test_local_useless_dimshuffle_in_reshape():
reshape_dimshuffle_row,
reshape_dimshuffle_col,
],
+ clone=False,
)
-
- # TODO FIXME: Construct these graphs and compare them.
- assert str(g) == (
- "FunctionGraph(Reshape{1}(InplaceDimShuffle{x,0}(vector), Shape(vector)), "
- "Reshape{2}(InplaceDimShuffle{x,0,x,1}(mat), Shape(mat)), "
- "Reshape{2}(InplaceDimShuffle{1,x}(row), Shape(row)), "
- "Reshape{2}(InplaceDimShuffle{0}(col), Shape(col)))"
- )
+ assert len(g.apply_nodes) == 4 * 3
useless_dimshuffle_in_reshape = out2in(local_useless_dimshuffle_in_reshape)
useless_dimshuffle_in_reshape.rewrite(g)
- assert str(g) == (
- "FunctionGraph(Reshape{1}(vector, Shape(vector)), "
- "Reshape{2}(mat, Shape(mat)), "
- "Reshape{2}(row, Shape(row)), "
- "Reshape{2}(col, Shape(col)))"
+ assert equal_computations(
+ g.outputs,
+ [
+ reshape(vec, vec.shape),
+ reshape(mat, mat.shape),
+ reshape(row, row.shape),
+ reshape(col, col.shape),
+ ],
)
-
# Check stacktrace was copied over correctly after rewrite was applied
assert check_stack_trace(g, ops_to_check="all")
# Check that the rewrite does not get applied when the order
# of dimensions has changed.
reshape_dimshuffle_mat2 = reshape(mat.dimshuffle("x", 1, "x", 0), mat.shape)
- h = FunctionGraph([mat], [reshape_dimshuffle_mat2])
- str_h = str(h)
+ h = FunctionGraph([mat], [reshape_dimshuffle_mat2], clone=False)
+ assert len(h.apply_nodes) == 3
useless_dimshuffle_in_reshape.rewrite(h)
- assert str(h) == str_h
+ assert equal_computations(
+ h.outputs, [reshape(mat.dimshuffle("x", 1, "x", 0), mat.shape)]
+ )
class TestFusion:
diff --git a/tests/tensor/test_elemwise.py b/tests/tensor/test_elemwise.py
index 8db1afe7e3..40e7db879c 100644
--- a/tests/tensor/test_elemwise.py
+++ b/tests/tensor/test_elemwise.py
@@ -676,14 +676,9 @@ def test_infer_shape(self, dtype=None, pre_scalar_op=None):
def test_str(self):
op = CAReduce(aes.add, axis=None)
- assert str(op) == "CAReduce{add}"
+ assert str(op) == "CAReduce{add, axes=None}"
op = CAReduce(aes.add, axis=(1,))
- assert str(op) == "CAReduce{add}{axis=[1]}"
-
- op = CAReduce(aes.add, axis=None, acc_dtype="float64")
- assert str(op) == "CAReduce{add}{acc_dtype=float64}"
- op = CAReduce(aes.add, axis=(1,), acc_dtype="float64")
- assert str(op) == "CAReduce{add}{axis=[1], acc_dtype=float64}"
+ assert str(op) == "CAReduce{add, axis=1}"
def test_repeated_axis(self):
x = vector("x")
@@ -802,10 +797,8 @@ def test_input_dimensions_match_c(self):
self.check_input_dimensions_match(Mode(linker="c"))
def test_str(self):
- op = Elemwise(aes.add, inplace_pattern=None, name=None)
- assert str(op) == "Elemwise{add}"
op = Elemwise(aes.add, inplace_pattern={0: 0}, name=None)
- assert str(op) == "Elemwise{add}[(0, 0)]"
+ assert str(op) == "Add"
op = Elemwise(aes.add, inplace_pattern=None, name="my_op")
assert str(op) == "my_op"
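
# For reference, the shortened Op names asserted above can be reproduced
# directly; the expected values are copied from the tests, and the `aes`
# alias for `pytensor.scalar` follows the test suite's convention:
#
#     import pytensor.scalar as aes
#     from pytensor.tensor.elemwise import CAReduce, Elemwise
#
#     assert str(CAReduce(aes.add, axis=None)) == "CAReduce{add, axes=None}"
#     assert str(Elemwise(aes.add, inplace_pattern={0: 0})) == "Add"
#     assert str(Elemwise(aes.add, inplace_pattern=None, name="my_op")) == "my_op"
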
diff --git a/tests/tensor/test_type.py b/tests/tensor/test_type.py
index 8ef4ca2864..4e9c456829 100644
--- a/tests/tensor/test_type.py
+++ b/tests/tensor/test_type.py
@@ -252,7 +252,7 @@ def test_fixed_shape_basic():
assert t1.shape == (2, 3)
assert t1.broadcastable == (False, False)
- assert str(t1) == "TensorType(float64, (2, 3))"
+ assert str(t1) == "Matrix(float64, shape=(2, 3))"
t1 = TensorType("float64", shape=(1,))
assert t1.shape == (1,)
diff --git a/tests/tensor/test_var.py b/tests/tensor/test_var.py
index 483b0f7d28..c17f524797 100644
--- a/tests/tensor/test_var.py
+++ b/tests/tensor/test_var.py
@@ -213,7 +213,7 @@ def test_print_constant():
c = pytensor.tensor.constant(1, name="const")
assert str(c) == "const{1}"
d = pytensor.tensor.constant(1)
- assert str(d) == "TensorConstant{1}"
+ assert str(d) == "1"
@pytest.mark.parametrize(
diff --git a/tests/test_printing.py b/tests/test_printing.py
index e481db5ac2..72c47943b5 100644
--- a/tests/test_printing.py
+++ b/tests/test_printing.py
@@ -106,9 +106,9 @@ def test_min_informative_str():
mis = min_informative_str(G).replace("\t", " ")
- reference = """A. Elemwise{add,no_inplace}
+ reference = """A. Add
B. C
- C. Elemwise{add,no_inplace}
+ C. Add
D. D
E. E"""
@@ -144,13 +144,13 @@ def test_debugprint():
s = s.getvalue()
reference = dedent(
r"""
- Elemwise{add,no_inplace} [id 0]
- |Elemwise{add,no_inplace} [id 1] 'C'
- | |A [id 2]
- | |B [id 3]
- |Elemwise{add,no_inplace} [id 4]
- |D [id 5]
- |E [id 6]
+ Add [id 0]
+ ├─ Add [id 1] 'C'
+ │ ├─ A [id 2]
+ │ └─ B [id 3]
+ └─ Add [id 4]
+ ├─ D [id 5]
+ └─ E [id 6]
"""
).lstrip()
@@ -162,13 +162,13 @@ def test_debugprint():
# The additional white space is needed!
reference = dedent(
r"""
- Elemwise{add,no_inplace} [id A]
- |Elemwise{add,no_inplace} [id B] 'C'
- | |A [id C]
- | |B [id D]
- |Elemwise{add,no_inplace} [id E]
- |D [id F]
- |E [id G]
+ Add [id A]
+ ├─ Add [id B] 'C'
+ │ ├─ A [id C]
+ │ └─ B [id D]
+ └─ Add [id E]
+ ├─ D [id F]
+ └─ E [id G]
"""
).lstrip()
@@ -180,11 +180,12 @@ def test_debugprint():
# The additional white space is needed!
reference = dedent(
r"""
- Elemwise{add,no_inplace} [id A]
- |Elemwise{add,no_inplace} [id B] 'C'
- |Elemwise{add,no_inplace} [id C]
- |D [id D]
- |E [id E]
+ Add [id A]
+ ├─ Add [id B] 'C'
+ │ └─ ···
+ └─ Add [id C]
+ ├─ D [id D]
+ └─ E [id E]
"""
).lstrip()
@@ -195,13 +196,13 @@ def test_debugprint():
s = s.getvalue()
reference = dedent(
r"""
- Elemwise{add,no_inplace}
- |Elemwise{add,no_inplace} 'C'
- | |A
- | |B
- |Elemwise{add,no_inplace}
- |D
- |E
+ Add
+ ├─ Add 'C'
+ │ ├─ A
+ │ └─ B
+ └─ Add
+ ├─ D
+ └─ E
"""
).lstrip()
@@ -212,11 +213,11 @@ def test_debugprint():
s = s.getvalue()
reference = dedent(
r"""
- Elemwise{add,no_inplace} 0 [None]
- |A [None]
- |B [None]
- |D [None]
- |E [None]
+ Add 0 [None]
+ ├─ A [None]
+ ├─ B [None]
+ ├─ D [None]
+ └─ E [None]
"""
).lstrip()
@@ -230,11 +231,11 @@ def test_debugprint():
s = s.getvalue()
reference = dedent(
r"""
- Elemwise{add,no_inplace} 0 [None]
- |A [None]
- |B [None]
- |D [None]
- |E [None]
+ Add 0 [None]
+ ├─ A [None]
+ ├─ B [None]
+ ├─ D [None]
+ └─ E [None]
"""
).lstrip()
@@ -248,11 +249,11 @@ def test_debugprint():
s = s.getvalue()
reference = dedent(
r"""
- Elemwise{add,no_inplace} 0 [None]
- |A [None]
- |B [None]
- |D [None]
- |E [None]
+ Add 0 [None]
+ ├─ A [None]
+ ├─ B [None]
+ ├─ D [None]
+ └─ E [None]
"""
).lstrip()
@@ -273,27 +274,27 @@ def test_debugprint():
s = s.getvalue()
exp_res = dedent(
r"""
- Elemwise{Composite} 4
- |InplaceDimShuffle{x,0} v={0: [0]} 3
- | |CGemv{inplace} d={0: [0]} 2
- | |AllocEmpty{dtype='float64'} 1
- | | |Shape_i{0} 0
- | | |B
- | |TensorConstant{1.0}
- | |B
- | |
- | |TensorConstant{0.0}
- |D
- |A
+ Composite{(i2 + (i0 - i1))} 4
+ ├─ ExpandDims{axis=0} v={0: [0]} 3
+ │ └─ CGemv{inplace} d={0: [0]} 2
+ │ ├─ AllocEmpty{dtype='float64'} 1
+ │ │ └─ Shape_i{0} 0
+ │ │ └─ B
+ │ ├─ 1.0
+ │ ├─ B
+ │ ├─
+ │ └─ 0.0
+ ├─ D
+ └─ A
Inner graphs:
- Elemwise{Composite}
- >add
- > |
- > |sub
- > |
- > |
+ Composite{(i2 + (i0 - i1))}
+ ← add 'o0'
+ ├─ i2
+ └─ sub
+ ├─ i0
+ └─ i1
"""
).lstrip()
@@ -313,11 +314,11 @@ def test_debugprint_id_type():
debugprint(e_at, id_type="auto", file=s)
s = s.getvalue()
- exp_res = f"""Elemwise{{add,no_inplace}} [id {e_at.auto_name}]
- |dot [id {d_at.auto_name}]
- | | [id {b_at.auto_name}]
- | | [id {a_at.auto_name}]
- | [id {a_at.auto_name}]
+ exp_res = f"""Add [id {e_at.auto_name}]
+ ├─ dot [id {d_at.auto_name}]
+ │ ├─ [id {b_at.auto_name}]
+ │ └─ [id {a_at.auto_name}]
+ └─ [id {a_at.auto_name}]
"""
assert [l.strip() for l in s.split("\n")] == [
@@ -328,7 +329,7 @@ def test_debugprint_id_type():
def test_pprint():
x = dvector()
y = x[1]
- assert pp(y) == "[1]"
+ assert pp(y) == "[1]"
def test_debugprint_inner_graph():
@@ -351,15 +352,15 @@ def test_debugprint_inner_graph():
lines = output_str.split("\n")
exp_res = """MyInnerGraphOp [id A]
- |3 [id B]
- |4 [id C]
+ ├─ 3 [id B]
+ └─ 4 [id C]
Inner graphs:
MyInnerGraphOp [id A]
- >op2 [id D] 'igo1'
- > |*0- [id E]
- > |*1- [id F]
+ ← op2 [id D] 'igo1'
+ ├─ *0- [id E]
+ └─ *1- [id F]
"""
for exp_line, res_line in zip(exp_res.split("\n"), lines):
@@ -375,19 +376,19 @@ def test_debugprint_inner_graph():
lines = output_str.split("\n")
exp_res = """MyInnerGraphOp [id A]
- |5 [id B]
+ └─ 5 [id B]
Inner graphs:
MyInnerGraphOp [id A]
- >MyInnerGraphOp [id C]
- > |*0- [id D]
- > |*1- [id E]
+ ← MyInnerGraphOp [id C]
+ ├─ *0- [id D]
+ └─ *1- [id E]
MyInnerGraphOp [id C]
- >op2 [id F] 'igo1'
- > |*0- [id D]
- > |*1- [id E]
+ ← op2 [id F] 'igo1'
+ ├─ *0- [id D]
+ └─ *1- [id E]
"""
for exp_line, res_line in zip(exp_res.split("\n"), lines):