From 37fb461c20d97e32b8c41e787f821e466775d60d Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 5 Jan 2023 12:02:04 +0100 Subject: [PATCH 1/7] Fix bug in gradient of Elemwise containing multi-output scalars --- pytensor/tensor/elemwise.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pytensor/tensor/elemwise.py b/pytensor/tensor/elemwise.py index bbbd3831f2..aa77326638 100644 --- a/pytensor/tensor/elemwise.py +++ b/pytensor/tensor/elemwise.py @@ -636,6 +636,9 @@ def transform(r): return DimShuffle((), ["x"] * nd)(res) new_r = Elemwise(node.op, {})(*[transform(ipt) for ipt in node.inputs]) + if isinstance(new_r, (list, tuple)): + # Scalar Op with multiple outputs + new_r = new_r[r.owner.outputs.index(r)] return new_r ret = [] From f6f3ef6161cfa5480e44b3f999bc194ca5742ffa Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Fri, 28 Apr 2023 11:25:52 +0200 Subject: [PATCH 2/7] Refactor baseclass ScalarInnerGraphOp from Composite Op --- pytensor/scalar/basic.py | 288 ++++++++++++++++++++------------------- 1 file changed, 148 insertions(+), 140 deletions(-) diff --git a/pytensor/scalar/basic.py b/pytensor/scalar/basic.py index b0639ff588..96cf107d64 100644 --- a/pytensor/scalar/basic.py +++ b/pytensor/scalar/basic.py @@ -3986,7 +3986,150 @@ def c_code(self, *args, **kwargs): complex_from_polar = ComplexFromPolar(name="complex_from_polar") -class Composite(ScalarOp, HasInnerGraph): +class ScalarInnerGraphOp(ScalarOp, HasInnerGraph): + """Includes boilerplate code for Python and C-implementation of Scalar Ops with inner graph.""" + + def __init__(self, *args, **kwargs): + self.prepare_node_called = set() + + @property + def fn(self): + return None + + @property + def inner_inputs(self): + return self.fgraph.inputs + + @property + def inner_outputs(self): + return self.fgraph.outputs + + @property + def py_perform_fn(self): + if hasattr(self, "_py_perform_fn"): + return self._py_perform_fn + + from pytensor.link.utils import fgraph_to_python + + def python_convert(op, node=None, **kwargs): + assert node is not None + + n_outs = len(node.outputs) + + if n_outs > 1: + + def _perform(*inputs, outputs=[[None]] * n_outs): + op.perform(node, inputs, outputs) + return tuple(o[0] for o in outputs) + + else: + + def _perform(*inputs, outputs=[[None]]): + op.perform(node, inputs, outputs) + return outputs[0][0] + + return _perform + + self._py_perform_fn = fgraph_to_python(self.fgraph, python_convert) + return self._py_perform_fn + + def impl(self, *inputs): + output_storage = [[None] for i in range(self.nout)] + self.perform(None, inputs, output_storage) + ret = to_return_values([storage[0] for storage in output_storage]) + if self.nout > 1: + ret = tuple(ret) + return ret + + def c_code_cache_version(self): + rval = list(self.c_code_cache_version_outer()) + for x in self.fgraph.toposort(): + xv = x.op.c_code_cache_version() + if xv: + rval.append(xv) + else: + return () + return tuple(rval) + + def c_header_dirs(self, **kwargs): + rval = sum( + (subnode.op.c_header_dirs(**kwargs) for subnode in self.fgraph.toposort()), + [], + ) + return rval + + def c_support_code(self, **kwargs): + # Remove duplicate code blocks by using a `set` + rval = { + subnode.op.c_support_code(**kwargs).strip() + for subnode in self.fgraph.toposort() + } + return "\n".join(sorted(rval)) + + def c_support_code_apply(self, node, name): + rval = [] + for subnode, subnodename in zip(self.fgraph.toposort(), self.nodenames): + subnode_support_code = subnode.op.c_support_code_apply( + subnode, subnodename % 
dict(nodename=name) + ) + if subnode_support_code: + rval.append(subnode_support_code) + # there should be no need to remove duplicate code blocks because + # each block should have been specialized for the given nodename. + # Any block that isn't specialized should be returned via + # c_support_code instead of c_support_code_apply. + return "\n".join(rval) + + def prepare_node(self, node, storage_map, compute_map, impl): + if impl not in self.prepare_node_called: + for n in list_of_nodes(self.inputs, self.outputs): + n.op.prepare_node(n, None, None, impl) + self.prepare_node_called.add(impl) + + def __eq__(self, other): + if self is other: + return True + if ( + type(self) != type(other) + or self.nin != other.nin + or self.nout != other.nout + ): + return False + + # TODO FIXME: Why this? Shouldn't we expect equivalent inputs to this + # object to generate the same `_c_code`? + return self.c_code_template == other.c_code_template + + def __hash__(self): + # Note that in general, the configparser settings at the time + # of code generation (__init__) affect the semantics of this Op. + # This function assumes that all relevant info about the configparser + # is embodied in _c_code. So the _c_code, rather than self.fgraph, + # is the signature of the semantics of this Op. + # _c_code is preserved through unpickling, so the Op will not change + # semantics when it is reloaded with different configparser + # settings. + # + # TODO FIXME: Doesn't the above just mean that we should be including + # the relevant "configparser settings" here? Also, why should we even + # care about the exact form of the generated C code when comparing + # `Op`s? All this smells of leaky concerns and interfaces. + return hash((type(self), self.nin, self.nout, self.c_code_template)) + + def __getstate__(self): + rval = dict(self.__dict__) + rval.pop("_c_code", None) + rval.pop("_py_perform_fn", None) + rval.pop("_fgraph", None) + rval.pop("prepare_node_called", None) + return rval + + def __setstate__(self, d): + self.__dict__.update(d) + self.prepare_node_called = set() + + +class Composite(ScalarInnerGraphOp): """ Composite is an Op that takes a graph of scalar operations and produces c code for the whole graph. 
Its purpose is to implement loop @@ -4043,19 +4186,7 @@ def __init__(self, inputs, outputs, name="Composite"): self.outputs_type = tuple([output.type for output in outputs]) self.nin = len(inputs) self.nout = len(outputs) - self.prepare_node_called = set() - - @property - def fn(self): - return None - - @property - def inner_inputs(self): - return self.fgraph.inputs - - @property - def inner_outputs(self): - return self.fgraph.outputs + super().__init__() def __str__(self): return self.name @@ -4076,35 +4207,6 @@ def make_new_inplace(self, output_types_preference=None, name=None): super(Composite, out).__init__(output_types_preference, name) return out - @property - def py_perform(self): - if hasattr(self, "_py_perform_fn"): - return self._py_perform_fn - - from pytensor.link.utils import fgraph_to_python - - def python_convert(op, node=None, **kwargs): - assert node is not None - - n_outs = len(node.outputs) - - if n_outs > 1: - - def _perform(*inputs, outputs=[[None]] * n_outs): - op.perform(node, inputs, outputs) - return tuple(o[0] for o in outputs) - - else: - - def _perform(*inputs, outputs=[[None]]): - op.perform(node, inputs, outputs) - return outputs[0][0] - - return _perform - - self._py_perform_fn = fgraph_to_python(self.fgraph, python_convert) - return self._py_perform_fn - @property def fgraph(self): if hasattr(self, "_fgraph"): @@ -4139,12 +4241,6 @@ def fgraph(self): self._fgraph = fgraph return self._fgraph - def prepare_node(self, node, storage_map, compute_map, impl): - if impl not in self.prepare_node_called: - for n in list_of_nodes(self.inputs, self.outputs): - n.op.prepare_node(n, None, None, impl) - self.prepare_node_called.add(impl) - def clone_float32(self): # This will not modify the fgraph or the nodes new_ins, new_outs = composite_f32.apply(self.fgraph) @@ -4155,8 +4251,6 @@ def clone(self): return Composite(new_ins, new_outs) def output_types(self, input_types): - # TODO FIXME: What's the intended purpose/use of this method, and why - # does it even need to be a method? if tuple(input_types) != self.inputs_type: raise TypeError( f"Wrong types for Composite. Expected {self.inputs_type}, got {tuple(input_types)}." @@ -4183,63 +4277,13 @@ def make_node(self, *inputs): return node def perform(self, node, inputs, output_storage): - outputs = self.py_perform(*inputs) + outputs = self.py_perform_fn(*inputs) for storage, out_val in zip(output_storage, outputs): storage[0] = out_val - def impl(self, *inputs): - output_storage = [[None] for i in range(self.nout)] - self.perform(None, inputs, output_storage) - ret = to_return_values([storage[0] for storage in output_storage]) - if self.nout > 1: - ret = tuple(ret) - return ret - def grad(self, inputs, output_grads): raise NotImplementedError("grad is not implemented for Composite") - def __eq__(self, other): - if self is other: - return True - if ( - type(self) != type(other) - or self.nin != other.nin - or self.nout != other.nout - ): - return False - - # TODO FIXME: Why this? Shouldn't we expect equivalent inputs to this - # object to generate the same `_c_code`? - return self.c_code_template == other.c_code_template - - def __hash__(self): - # Note that in general, the configparser settings at the time - # of code generation (__init__) affect the semantics of this Op. - # This function assumes that all relevant info about the configparser - # is embodied in _c_code. So the _c_code, rather than self.fgraph, - # is the signature of the semantics of this Op. 
- # _c_code is preserved through unpickling, so the Op will not change - # semantics when it is reloaded with different configparser - # settings. - # - # TODO FIXME: Doesn't the above just mean that we should be including - # the relevant "configparser settings" here? Also, why should we even - # care about the exact form of the generated C code when comparing - # `Op`s? All this smells of leaky concerns and interfaces. - return hash((type(self), self.nin, self.nout, self.c_code_template)) - - def __getstate__(self): - rval = dict(self.__dict__) - rval.pop("_c_code", None) - rval.pop("_py_perform_fn", None) - rval.pop("_fgraph", None) - rval.pop("prepare_node_called", None) - return rval - - def __setstate__(self, d): - self.__dict__.update(d) - self.prepare_node_called = set() - @property def c_code_template(self): from pytensor.link.c.interface import CLinkerType @@ -4317,44 +4361,8 @@ def c_code(self, node, nodename, inames, onames, sub): return self.c_code_template % d - def c_code_cache_version(self): - rval = [3] - for x in self.fgraph.toposort(): - xv = x.op.c_code_cache_version() - if xv: - rval.append(xv) - else: - return () - return tuple(rval) - - def c_header_dirs(self, **kwargs): - rval = sum( - (subnode.op.c_header_dirs(**kwargs) for subnode in self.fgraph.toposort()), - [], - ) - return rval - - def c_support_code(self, **kwargs): - # Remove duplicate code blocks by using a `set` - rval = { - subnode.op.c_support_code(**kwargs).strip() - for subnode in self.fgraph.toposort() - } - return "\n".join(sorted(rval)) - - def c_support_code_apply(self, node, name): - rval = [] - for subnode, subnodename in zip(self.fgraph.toposort(), self.nodenames): - subnode_support_code = subnode.op.c_support_code_apply( - subnode, subnodename % dict(nodename=name) - ) - if subnode_support_code: - rval.append(subnode_support_code) - # there should be no need to remove duplicate code blocks because - # each block should have been specialized for the given nodename. - # Any block that isn't specialized should be returned via - # c_support_code instead of c_support_code_apply. - return "\n".join(rval) + def c_code_cache_version_outer(self) -> Tuple[int, ...]: + return (3,) class Compositef32: From 65d118332ad08b599b0fb280ac1c2c47bb1eba57 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 3 Apr 2023 13:05:25 +0200 Subject: [PATCH 3/7] Implement Scalar Loop Op --- pytensor/graph/basic.py | 6 +- pytensor/scalar/loop.py | 388 ++++++++++++++++++++++++++ pytensor/tensor/rewriting/elemwise.py | 10 +- tests/scalar/test_loop.py | 279 ++++++++++++++++++ 4 files changed, 677 insertions(+), 6 deletions(-) create mode 100644 pytensor/scalar/loop.py create mode 100644 tests/scalar/test_loop.py diff --git a/pytensor/graph/basic.py b/pytensor/graph/basic.py index 321357a13f..31bd8d311a 100644 --- a/pytensor/graph/basic.py +++ b/pytensor/graph/basic.py @@ -1115,12 +1115,12 @@ def truncated_graph_inputs( def clone( - inputs: List[Variable], - outputs: List[Variable], + inputs: Sequence[Variable], + outputs: Sequence[Variable], copy_inputs: bool = True, copy_orphans: Optional[bool] = None, clone_inner_graphs: bool = False, -) -> Tuple[Collection[Variable], Collection[Variable]]: +) -> Tuple[List[Variable], List[Variable]]: r"""Copies the sub-graph contained between inputs and outputs. 
Parameters diff --git a/pytensor/scalar/loop.py b/pytensor/scalar/loop.py new file mode 100644 index 0000000000..74ee15bdbe --- /dev/null +++ b/pytensor/scalar/loop.py @@ -0,0 +1,388 @@ +import warnings +from copy import copy +from itertools import chain +from textwrap import dedent +from typing import Literal, Optional, Sequence, Tuple + +from pytensor.compile import rebuild_collect_shared +from pytensor.graph import Constant, FunctionGraph, Variable, clone +from pytensor.graph.rewriting.basic import MergeOptimizer +from pytensor.scalar.basic import ScalarInnerGraphOp, ScalarOp, as_scalar + + +class ScalarLoop(ScalarInnerGraphOp): + """Scalar Op that encapsulates a scalar loop operation. + + This Op can be used for the gradient of other Scalar Ops. + It is much more restricted that `Scan` in that the entire inner graph must be composed of Scalar operations. + + """ + + init_param: Tuple[str, ...] = ( + "init", + "update", + "constant", + "until", + "until_condition_failed", + ) + + def __init__( + self, + init: Sequence[Variable], + update: Sequence[Variable], + constant: Optional[Sequence[Variable]] = None, + until: Optional[Variable] = None, + until_condition_failed: Literal["ignore", "warn", "raise"] = "warn", + name="ScalarLoop", + ): + if until_condition_failed not in ["ignore", "warn", "raise"]: + raise ValueError( + f"Invalid until_condition_failed: {until_condition_failed}" + ) + + if constant is None: + constant = [] + if not len(init) == len(update): + raise ValueError("An update must be given for each init variable") + if until: + inputs, (*outputs, until) = clone([*init, *constant], [*update, until]) + self.outputs = copy([*outputs, until]) + else: + inputs, outputs = clone([*init, *constant], update) + self.outputs = copy(outputs) + self.inputs = copy(inputs) + + self.inputs_type = tuple(input.type for input in inputs) + self.outputs_type = tuple(output.type for output in outputs) + self.nin = len(inputs) + 1 # n_steps is not part of the inner graph + self.nout = len(outputs) # until is not output + self.is_while = bool(until) + self.until_condition_failed = until_condition_failed + self.name = name + self._validate_fgraph(FunctionGraph(self.inputs, self.outputs, clone=False)) + super().__init__() + + def output_types(self, input_types): + return self.outputs_type + + def _validate_fgraph(self, fgraph: FunctionGraph) -> None: + for node in fgraph.apply_nodes: + if not isinstance(node.op, ScalarOp): + raise TypeError( + "The fgraph of ScalarLoop must be composed exclusively of ScalarOp nodes" + ) + + init = fgraph.inputs + update = fgraph.outputs + + if self.is_while: + *update, until = update + if not until.type.dtype == "bool": + raise TypeError( + f"Until condition must be boolean, got {until}({until.type.dtype})" + ) + + for i, u in zip(init, update): + if i.type != u.type: + raise TypeError( + "Init and update types must be the same: " + f"{i}({i.type}) != {u}({u.type})" + ) + if set(init) & set(update): + raise ValueError( + "Some inputs and outputs are the same variable. " + "If you want to return an output as a lagged input, wrap it in an identity Op." + ) + + @property + def fgraph(self): + if hasattr(self, "_fgraph"): + return self._fgraph + + fgraph = FunctionGraph(self.inputs, self.outputs) + # TODO: We could convert to TensorVariable, optimize graph, + # and then convert back to ScalarVariable. + # This would introduce rewrites like `log(1 + x) -> log1p`. 
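+        # For now we only run the merge rewrite, so duplicate sub-expressions
+        # in the inner graph are computed once (e.g. an expression shared by
+        # two outputs). Note: comment added editorially to describe the call below.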
+ MergeOptimizer().rewrite(fgraph) + self._validate_fgraph(fgraph) + + # Clone identical outputs that have been merged + if len(set(fgraph.outputs)) != len(self.outputs): + old_outputs = fgraph.outputs + new_outputs = [] + for output in old_outputs: + if output not in new_outputs: + new_outputs.append(output) + else: + node = output.owner + output_idx = node.outputs.index(output) + new_output = node.clone().outputs[output_idx] + new_outputs.append(new_output) + fgraph = FunctionGraph(fgraph.inputs, new_outputs, clone=False) + + self._fgraph = fgraph + return self._fgraph + + def clone(self): + if self.is_while: + *update, until = self.outputs + else: + update, until = self.outputs, None + init = self.inputs[: len(update)] + constant = self.inputs[len(update) :] + return ScalarLoop( + init=init, + update=update, + constant=constant, + until=until, + until_condition_failed=self.until_condition_failed, + name=self.name, + ) + + @property + def fn(self): + raise NotImplementedError + + def make_new_inplace(self, output_types_preference=None, name=None): + """ + This op.__init__ fct don't have the same parameter as other scalar op. + This break the insert_inplace_optimizer optimization. + This fct allow fix patch this. + + """ + d = {k: getattr(self, k) for k in self.init_param} + out = self.__class__(**d) + if name: + out.name = name + else: + name = out.name + super(ScalarLoop, out).__init__(output_types_preference, name) + return out + + def make_node(self, n_steps, *inputs): + assert len(inputs) == self.nin - 1 + + n_steps = as_scalar(n_steps) + if not n_steps.type.dtype.startswith("int"): + raise TypeError( + "The first variable of ScalarLoop (n_steps) must be of integer type. " + f"Got {n_steps.type.dtype}", + ) + + if self.inputs_type == tuple([i.type for i in inputs]): + return super().make_node(n_steps, *inputs) + else: + # Make a new op with the right input types. + res = rebuild_collect_shared( + self.outputs, + replace=dict(zip(self.inputs, inputs)), + rebuild_strict=False, + ) + if self.is_while: + *cloned_update, cloned_until = res[1] + else: + cloned_update, cloned_until = res[1], None + cloned_inputs = [res[2][0][i] for i in inputs] + cloned_init = cloned_inputs[: len(cloned_update)] + cloned_constant = cloned_inputs[len(cloned_update) :] + # This will fail if the cloned init have a different dtype than the cloned_update + op = ScalarLoop( + init=cloned_init, + update=cloned_update, + constant=cloned_constant, + until=cloned_until, + until_condition_failed=self.until_condition_failed, + name=self.name, + ) + node = op.make_node(n_steps, *inputs) + return node + + def perform(self, node, inputs, output_storage): + n_steps, *inputs = inputs + n_update = len(self.outputs) - (1 if self.is_while else 0) + carry, constant = inputs[:n_update], inputs[n_update:] + inner_fn = self.py_perform_fn + + if self.is_while: + until = True + for i in range(n_steps): + *carry, until = inner_fn(*carry, *constant) + if until: + break + + if not until: # no-break + if self.until_condition_failed == "raise": + raise RuntimeError( + f"Until condition in ScalarLoop {self.name} not reached!" 
+ ) + elif self.until_condition_failed == "warn": + warnings.warn( + f"Until condition in ScalarLoop {self.name} not reached!", + RuntimeWarning, + ) + else: + if n_steps < 0: + raise ValueError("ScalarLoop does not have a termination condition.") + for i in range(n_steps): + carry = inner_fn(*carry, *constant) + + for storage, out_val in zip(output_storage, carry): + storage[0] = out_val + + @property + def c_code_template(self): + from pytensor.link.c.interface import CLinkerType + + if hasattr(self, "_c_code"): + return self._c_code + + fgraph = self.fgraph + + # The first input is `n_steps` so we skip it in the mapping dictionary + n_update = len(self.outputs) - (1 if self.is_while else 0) + carry_subd = { + c: f"%(i{int(i)})s" for i, c in enumerate(fgraph.inputs[:n_update], start=1) + } + constant_subd = { + c: f"%(i{int(i)})s" + for i, c in enumerate(fgraph.inputs[n_update:], start=n_update + 1) + } + update_subd = { + u: f"%(o{int(i)})s" for i, u in enumerate(fgraph.outputs[:n_update]) + } + until_subd = {u: "until" for u in fgraph.outputs[n_update:]} + subd = {**carry_subd, **constant_subd, **update_subd, **until_subd} + + for var in fgraph.variables: + if var.owner is None: + if var not in self.fgraph.inputs: + # This is an orphan + if isinstance(var, Constant) and isinstance(var.type, CLinkerType): + subd[var] = var.type.c_literal(var.data) + else: + raise ValueError( + "All orphans in the fgraph to ScalarLoop must" + " be Constant, CLinkerType instances." + ) + elif any(i.dtype == "float16" for i in var.owner.inputs) or any( + o.dtype == "float16" for o in var.owner.outputs + ): + # flag for elemwise ops to check. + self.inner_float16 = True + + _c_code = "{\n" + if self.is_while: + _c_code += "bool until = 1;\n\n" + + # Copy carried inputs + for i, (var, name) in enumerate(carry_subd.items()): + copy_var_name = f"{name}_copy{i}" + _c_code += f"{var.type.dtype_specs()[1]} {copy_var_name} = {name};\n" + carry_subd[var] = copy_var_name + subd[var] = copy_var_name + + # _c_code += 'printf("inputs=[");' + # for i in range(1, len(fgraph.inputs)): + # _c_code += f'printf("%%.16g, ", %(i{i})s);' + # _c_code += 'printf("]\\n");\n' + + _c_code += "\nfor(%(n_steps_dtype)s i = 0; i < %(n_steps)s; i++){\n" + + self.nodenames = [ + f"%(nodename)s_subnode{int(j)}" for j, n in enumerate(fgraph.toposort()) + ] + + i = 0 + for j, node in enumerate(fgraph.toposort()): + for output in node.outputs: + if output not in subd: + i += 1 + name = f"V%(id)s_tmp{int(i)}" + subd[output] = name + _c_code += f"{output.type.dtype_specs()[1]} {name};\n" + s = node.op.c_code( + node, + self.nodenames[j], + # Any node that depended on `init` will depend on `update` instead + # The initial value of `update` was set to `init` before the loop + [subd[input] for input in node.inputs], + [subd[output] for output in node.outputs], + dict(fail="%(fail)s", id=f"%(id)s_{int(j)}"), + ) + _c_code += s + _c_code += "\n" + + # Set the carry variables to the output variables + _c_code += "\n" + for init, update in zip(carry_subd.values(), update_subd.values()): + _c_code += f"{init} = {update};\n" + + # _c_code += 'printf("%%ld\\n", i);\n' + # for carry in range(1, 10): + # _c_code += f'printf("\\t %%.g\\n", i, %(i{carry})s_copy{carry-1});\n' + + if self.is_while: + _c_code += "\nif(until){break;}\n" + + _c_code += "}\n" + + # End of the loop + if self.is_while: + if self.until_condition_failed == "raise": + _c_code += dedent( + f""" + if (!until) {{ + PyErr_SetString(PyExc_RuntimeError, "Until condition in ScalarLoop 
{self.name} not reached!"); + %(fail)s + }} + """ + ) + elif self.until_condition_failed == "warn": + _c_code += dedent( + f""" + if (!until) {{ + PyErr_WarnEx(PyExc_RuntimeWarning, "Until condition in ScalarLoop {self.name} not reached!", 1); + }} + """ + ) + + _c_code += "}\n" + + self._c_code = _c_code + + return self._c_code + + def c_code(self, node, nodename, inames, onames, sub): + d = dict( + chain( + zip((f"i{int(i)}" for i in range(len(inames))), inames), + zip((f"o{int(i)}" for i in range(len(onames))), onames), + ), + **sub, + ) + d["nodename"] = nodename + if "id" not in sub: + # The use of a dummy id is safe as the code is in a separate block. + # It won't generate conflicting variable name. + d["id"] = "_DUMMY_ID_" + + # When called inside Elemwise we don't have access to the dtype + # via the usual `f"dtype_{inames[i]}"` variable + d["n_steps"] = inames[0] + d["n_steps_dtype"] = "npy_" + node.inputs[0].dtype + + res = self.c_code_template % d + # print(res) + return res + + def c_code_cache_version_outer(self): + return (1,) + + def __eq__(self, other): + return ( + super().__eq__(other) + and self.until_condition_failed == other.until_condition_failed + ) + + def __hash__(self): + return hash((super().__hash__(), self.until_condition_failed)) diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index d3be0e2079..03308b9983 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -22,6 +22,7 @@ ) from pytensor.graph.rewriting.db import SequenceDB from pytensor.graph.utils import InconsistencyError, MethodNotDefined +from pytensor.scalar.loop import ScalarLoop from pytensor.tensor.basic import ( MakeVector, alloc, @@ -66,9 +67,12 @@ def print_profile(cls, stream, prof, level=0): print(blanc, n, ndim[n], file=stream) def candidate_input_idxs(self, node): - if isinstance(node.op.scalar_op, aes.Composite) and len(node.outputs) > 1: - # TODO: Implement specialized InplaceCompositeOptimizer with logic - # needed to correctly assign inplace for multi-output Composites + # TODO: Implement specialized InplaceCompositeOptimizer with logic + # needed to correctly assign inplace for multi-output Composites + # and ScalarLoops + if isinstance(node.op.scalar_op, ScalarLoop): + return [] + if isinstance(node.op.scalar_op, aes.Composite) and (len(node.outputs) > 1): return [] else: return range(len(node.outputs)) diff --git a/tests/scalar/test_loop.py b/tests/scalar/test_loop.py new file mode 100644 index 0000000000..041960ca46 --- /dev/null +++ b/tests/scalar/test_loop.py @@ -0,0 +1,279 @@ +import re + +import numpy as np +import pytest + +from pytensor import Mode, function +from pytensor.scalar import ( + Composite, + as_scalar, + cos, + exp, + float16, + float32, + float64, + identity, + int64, + sin, +) +from pytensor.scalar.loop import ScalarLoop +from pytensor.tensor import exp as tensor_exp + + +mode = pytest.mark.parametrize( + "mode", + [ + Mode(optimizer="fast_compile", linker="py"), + Mode(optimizer="fast_compile", linker="cvm"), + ], +) + + +@mode +def test_single_output(mode): + n_steps = int64("n_steps") + x0 = float64("x0") + const = float64("const") + x = x0 + const + + op = ScalarLoop(init=[x0], constant=[const], update=[x]) + x = op(n_steps, x0, const) + + fn = function([n_steps, x0, const], x, mode=mode) + np.testing.assert_allclose(fn(5, 0, 1), 5) + np.testing.assert_allclose(fn(5, 0, 2), 10) + np.testing.assert_allclose(fn(4, 3, -1), -1) + + +@mode +def test_multiple_output(mode): + n_steps 
= int64("n_steps") + x0 = float64("x0") + y0 = int64("y0") + const = float64("const") + x = x0 + const + y = y0 + 1 + + op = ScalarLoop(init=[x0, y0], constant=[const], update=[x, y]) + x, y = op(n_steps, x0, y0, const) + + fn = function([n_steps, x0, y0, const], [x, y], mode=mode) + + res_x, res_y = fn(n_steps=5, x0=0, y0=0, const=1) + np.testing.assert_allclose(res_x, 5) + np.testing.assert_allclose(res_y, 5) + + res_x, res_y = fn(n_steps=5, x0=0, y0=0, const=2) + np.testing.assert_allclose(res_x, 10) + np.testing.assert_allclose(res_y, 5) + + res_x, res_y = fn(n_steps=4, x0=3, y0=2, const=-1) + np.testing.assert_allclose(res_x, -1) + np.testing.assert_allclose(res_y, 6) + + +@mode +def test_input_not_aliased_to_update(mode): + n_steps = int64("n_steps") + x0 = float64("x0") + y0 = float64("y0") + const = float64("const") + + def update(x_prev, y_prev): + x = x_prev + const + # y depends on x_prev, so x_prev should not be overriden by x! + y = y_prev + x_prev + return [x, y] + + op = ScalarLoop(init=[x0, y0], constant=[const], update=update(x0, y0)) + x, y = op(n_steps, x0, y0, const) + + fn = function([n_steps, x0, y0, const], y, mode=mode) + np.testing.assert_allclose(fn(n_steps=1, x0=0, y0=0, const=1), 0.0) + np.testing.assert_allclose(fn(n_steps=2, x0=0, y0=0, const=1), 1.0) + np.testing.assert_allclose(fn(n_steps=3, x0=0, y0=0, const=1), 3.0) + np.testing.assert_allclose(fn(n_steps=4, x0=0, y0=0, const=1), 6.0) + np.testing.assert_allclose(fn(n_steps=5, x0=0, y0=0, const=1), 10.0) + + +@mode +def test_until(mode): + n_steps = int64("n_steps") + x0 = float64("x0") + x = x0 + 1 + until = x >= 10 + + op = ScalarLoop(init=[x0], update=[x], until=until, until_condition_failed="ignore") + fn = function([n_steps, x0], op(n_steps, x0), mode=mode) + np.testing.assert_allclose(fn(n_steps=20, x0=0), 10) + np.testing.assert_allclose(fn(n_steps=20, x0=1), 10) + np.testing.assert_allclose(fn(n_steps=5, x0=1), 6) + + op = ScalarLoop( + init=[x0], + update=[x], + until=until, + until_condition_failed="warn", + name="TestLoop", + ) + fn = function([n_steps, x0], op(n_steps, x0), mode=mode) + np.testing.assert_allclose(fn(n_steps=20, x0=0), 10) + np.testing.assert_allclose(fn(n_steps=20, x0=1), 10) + with pytest.warns( + RuntimeWarning, match="Until condition in ScalarLoop TestLoop not reached!" + ): + np.testing.assert_allclose(fn(n_steps=5, x0=1), 6) + + op = ScalarLoop( + init=[x0], + update=[x], + until=until, + until_condition_failed="raise", + name="TestLoop", + ) + fn = function([n_steps, x0], op(n_steps, x0), mode=mode) + np.testing.assert_allclose(fn(n_steps=20, x0=0), 10) + np.testing.assert_allclose(fn(n_steps=20, x0=1), 10) + with pytest.raises( + RuntimeError, match="Until condition in ScalarLoop TestLoop not reached!" 
+ ): + fn(n_steps=5, x0=1) + + +def test_update_missing_error(): + x0 = float64("x0") + const = float64("const") + with pytest.raises( + ValueError, match="An update must be given for each init variable" + ): + ScalarLoop(init=[x0], constant=[const], update=[]) + + +def test_init_update_type_error(): + x0 = float32("x0") + const = float64("const") + x = x0 + const + assert x.type.dtype == "float64" + with pytest.raises(TypeError, match="Init and update types must be the same"): + ScalarLoop(init=[x0], constant=[const], update=[x]) + + +def test_rebuild_dtype(): + n_steps = int64("n_steps") + x0 = float64("x0") + const = float64("const") + x = x0 + const + op = ScalarLoop(init=[x0], constant=[const], update=[x]) + + # If x0 is float32 but const is still float64, the output type will not be able to match + x0_float32 = float32("x0_float32") + with pytest.raises(TypeError, match="Init and update types must be the same"): + op(n_steps, x0_float32, const) + + # Now it should be fine + const_float32 = float32("const_float32") + y = op(n_steps, x0_float32, const_float32) + assert y.dtype == "float32" + + +def test_non_scalar_error(): + x0 = float64("x0") + x = as_scalar(tensor_exp(x0)) + + with pytest.raises( + TypeError, match="must be composed exclusively of ScalarOp nodes" + ): + ScalarLoop(init=[x0], constant=[], update=[x]) + + +def test_n_steps_type_error(): + x0 = float64("x0") + const = float64("const") + x = x0 + const + + op = ScalarLoop(init=[x0], constant=[const], update=[x]) + with pytest.raises( + TypeError, match=re.escape("(n_steps) must be of integer type. Got float64") + ): + op(float64("n_steps"), x0, const) + + +def test_same_out_as_inp_error(): + xtm2 = float64("xtm2") + xtm1 = float64("xtm1") + x = xtm2 + xtm1 + + with pytest.raises( + ValueError, match="Some inputs and outputs are the same variable" + ): + ScalarLoop(init=[xtm2, xtm1], update=[xtm1, x]) + + +@mode +def test_lags(mode): + n_steps = int64("n_steps") + xtm2 = float64("xtm2") + xtm1 = float64("xtm1") + x = xtm2 + xtm1 + + op = ScalarLoop(init=[xtm2, xtm1], update=[identity(xtm1), x]) + _, x = op(n_steps, xtm2, xtm1) + + fn = function([n_steps, xtm2, xtm1], x, mode=mode) + np.testing.assert_allclose(fn(n_steps=5, xtm2=0, xtm1=1), 8) + + +@mode +def test_inner_composite(mode): + n_steps = int64("n_steps") + x = float64("x") + + one = Composite([x], [cos(exp(x)) ** 2 + sin(exp(x)) ** 2])(x) + + op = ScalarLoop(init=[x], update=[one + x]) + y = op(n_steps, x) + + fn = function([n_steps, x], y, mode=mode) + np.testing.assert_allclose(fn(n_steps=5, x=2.53), 2.53 + 5) + + # Now with a dtype that must be rebuilt + x16 = float16("x16") + y16 = op(n_steps, x16) + assert y16.type.dtype == "float16" + + fn32 = function([n_steps, x16], y16, mode=mode) + np.testing.assert_allclose( + fn32(n_steps=9, x16=np.array(4.73, dtype="float16")), + 4.73 + 9, + rtol=1e-3, + ) + + +@mode +def test_inner_loop(mode): + n_steps = int64("n_steps") + x = float64("x") + + x_in = float64("x_in") + inner_loop_op = ScalarLoop(init=[x_in], update=[x_in + 1]) + + outer_loop_op = ScalarLoop( + init=[x], update=[inner_loop_op(n_steps, x)], constant=[n_steps] + ) + y = outer_loop_op(n_steps, x, n_steps) + + fn = function([n_steps, x], y, mode=mode) + np.testing.assert_allclose(fn(n_steps=5, x=0), 5**2) + np.testing.assert_allclose(fn(n_steps=7, x=0), 7**2) + np.testing.assert_allclose(fn(n_steps=7, x=1), 7**2 + 1) + + # Now with a dtype that must be rebuilt + x16 = float16("x16") + y16 = outer_loop_op(n_steps, x16, n_steps) + assert y16.type.dtype == 
"float16" + + fn32 = function([n_steps, x16], y16, mode=mode) + np.testing.assert_allclose( + fn32(n_steps=3, x16=np.array(2.5, dtype="float16")), + 3**2 + 2.5, + ) From 0f012ccb981a8a18a09b6366ea0e7103d7934817 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Mon, 3 Apr 2023 13:25:00 +0200 Subject: [PATCH 4/7] Use ScalarLoop for gammainc(c) gradients --- pytensor/scalar/math.py | 312 +++++++++++++++++++------------- tests/tensor/test_math_scipy.py | 31 +++- 2 files changed, 209 insertions(+), 134 deletions(-) diff --git a/pytensor/scalar/math.py b/pytensor/scalar/math.py index 42e70ddb9b..1a2ca7de9d 100644 --- a/pytensor/scalar/math.py +++ b/pytensor/scalar/math.py @@ -18,8 +18,11 @@ BinaryScalarOp, ScalarOp, UnaryScalarOp, + as_scalar, complex_types, + constant, discrete_types, + eq, exp, expm1, float64, @@ -27,6 +30,7 @@ isinf, log, log1p, + sqrt, switch, true_div, upcast, @@ -34,6 +38,7 @@ upgrade_to_float64, upgrade_to_float_no_complex, ) +from pytensor.scalar.loop import ScalarLoop class Erf(UnaryScalarOp): @@ -595,7 +600,7 @@ def grad(self, inputs, grads): (k, x) = inputs (gz,) = grads return [ - gz * gammainc_der(k, x), + gz * gammainc_grad(k, x), gz * exp(-x + (k - 1) * log(x) - gammaln(k)), ] @@ -644,7 +649,7 @@ def grad(self, inputs, grads): (k, x) = inputs (gz,) = grads return [ - gz * gammaincc_der(k, x), + gz * gammaincc_grad(k, x), gz * -exp(-x + (k - 1) * log(x) - gammaln(k)), ] @@ -675,162 +680,209 @@ def __hash__(self): gammaincc = GammaIncC(upgrade_to_float, name="gammaincc") -class GammaIncDer(BinaryScalarOp): - """ - Gradient of the the regularized lower gamma function (P) wrt to the first - argument (k, a.k.a. alpha). Adapted from STAN `grad_reg_lower_inc_gamma.hpp` +def _make_scalar_loop(n_steps, init, constant, inner_loop_fn, name): + init = [as_scalar(x) for x in init] + constant = [as_scalar(x) for x in constant] + # Create dummy types, in case some variables have the same initial form + init_ = [x.type() for x in init] + constant_ = [x.type() for x in constant] + update_, until_ = inner_loop_fn(*init_, *constant_) + op = ScalarLoop( + init=init_, + constant=constant_, + update=update_, + until=until_, + until_condition_failed="warn", + name=name, + ) + S, *_ = op(n_steps, *init, *constant) + return S + + +def gammainc_grad(k, x): + """Gradient of the regularized lower gamma function (P) wrt to the first + argument (k, a.k.a. alpha). + + Adapted from STAN `grad_reg_lower_inc_gamma.hpp` Reference: Gautschi, W. (1979). A computational procedure for incomplete gamma functions. ACM Transactions on Mathematical Software (TOMS), 5(4), 466-481. 
""" + dtype = upcast(k.type.dtype, x.type.dtype, "float32") - def impl(self, k, x): - if x == 0: - return 0 - - sqrt_exp = -756 - x**2 + 60 * x - if ( - (k < 0.8 and x > 15) - or (k < 12 and x > 30) - or (sqrt_exp > 0 and k < np.sqrt(sqrt_exp)) - ): - return -GammaIncCDer.st_impl(k, x) - - precision = 1e-10 - max_iters = int(1e5) + def grad_approx(skip_loop): + precision = np.array(1e-10, dtype=config.floatX) + max_iters = switch( + skip_loop, np.array(0, dtype="int32"), np.array(1e5, dtype="int32") + ) - log_x = np.log(x) - log_gamma_k_plus_1 = scipy.special.gammaln(k + 1) + log_x = log(x) + log_gamma_k_plus_1 = gammaln(k + 1) - k_plus_n = k + # First loop + k_plus_n = k # Should not overflow unless k > 2,147,383,647 log_gamma_k_plus_n_plus_1 = log_gamma_k_plus_1 - sum_a = 0.0 - for n in range(0, max_iters + 1): - term = np.exp(k_plus_n * log_x - log_gamma_k_plus_n_plus_1) - sum_a += term + sum_a0 = np.array(0.0, dtype=dtype) - if term <= precision: - break + def inner_loop_a(sum_a, log_gamma_k_plus_n_plus_1, k_plus_n, log_x): + term = exp(k_plus_n * log_x - log_gamma_k_plus_n_plus_1) + sum_a += term - log_gamma_k_plus_n_plus_1 += np.log1p(k_plus_n) + log_gamma_k_plus_n_plus_1 += log1p(k_plus_n) k_plus_n += 1 - - if n >= max_iters: - warnings.warn( - f"gammainc_der did not converge after {n} iterations", - RuntimeWarning, + return ( + (sum_a, log_gamma_k_plus_n_plus_1, k_plus_n), + (term <= precision), ) - return np.nan - k_plus_n = k + init = [sum_a0, log_gamma_k_plus_n_plus_1, k_plus_n] + constant = [log_x] + sum_a = _make_scalar_loop( + max_iters, init, constant, inner_loop_a, name="gammainc_grad_a" + ) + + # Second loop + n = np.array(0, dtype="int32") log_gamma_k_plus_n_plus_1 = log_gamma_k_plus_1 - sum_b = 0.0 - for n in range(0, max_iters + 1): - term = np.exp( - k_plus_n * log_x - log_gamma_k_plus_n_plus_1 - ) * scipy.special.digamma(k_plus_n + 1) - sum_b += term + k_plus_n = k + sum_b0 = np.array(0.0, dtype=dtype) - if term <= precision and n >= 1: # Require at least two iterations - return np.exp(-x) * (log_x * sum_a - sum_b) + def inner_loop_b(sum_b, log_gamma_k_plus_n_plus_1, n, k_plus_n, log_x): + term = exp(k_plus_n * log_x - log_gamma_k_plus_n_plus_1) * psi(k_plus_n + 1) + sum_b += term - log_gamma_k_plus_n_plus_1 += np.log1p(k_plus_n) + log_gamma_k_plus_n_plus_1 += log1p(k_plus_n) + n += 1 k_plus_n += 1 + return ( + (sum_b, log_gamma_k_plus_n_plus_1, n, k_plus_n), + # Require at least two iterations + ((term <= precision) & (n > 1)), + ) - warnings.warn( - f"gammainc_der did not converge after {n} iterations", - RuntimeWarning, + init = [sum_b0, log_gamma_k_plus_n_plus_1, n, k_plus_n] + constant = [log_x] + sum_b, *_ = _make_scalar_loop( + max_iters, init, constant, inner_loop_b, name="gammainc_grad_b" ) - return np.nan - - def c_code(self, *args, **kwargs): - raise NotImplementedError() - - -gammainc_der = GammaIncDer(upgrade_to_float, name="gammainc_der") - -class GammaIncCDer(BinaryScalarOp): - """ - Gradient of the the regularized upper gamma function (Q) wrt to the first - argument (k, a.k.a. alpha). 
Adapted from STAN `grad_reg_inc_gamma.hpp` + grad_approx = exp(-x) * (log_x * sum_a - sum_b) + return grad_approx + + zero_branch = eq(x, 0) + sqrt_exp = -756 - x**2 + 60 * x + gammaincc_branch = ( + ((k < 0.8) & (x > 15)) + | ((k < 12) & (x > 30)) + | ((sqrt_exp > 0) & (k < sqrt(sqrt_exp))) + ) + grad = switch( + zero_branch, + 0, + switch( + gammaincc_branch, + -gammaincc_grad(k, x, skip_loops=zero_branch | (~gammaincc_branch)), + grad_approx(skip_loop=zero_branch | gammaincc_branch), + ), + ) + return grad + + +def gammaincc_grad(k, x, skip_loops=constant(False, dtype="bool")): + """Gradient of the regularized upper gamma function (Q) wrt to the first + argument (k, a.k.a. alpha). + + Adapted from STAN `grad_reg_inc_gamma.hpp` + + skip_loops is used for faster branching when this function is called by `gammainc_der` """ + dtype = upcast(k.type.dtype, x.type.dtype, "float32") - @staticmethod - def st_impl(k, x): - gamma_k = scipy.special.gamma(k) - digamma_k = scipy.special.digamma(k) - log_x = np.log(x) - - # asymptotic expansion http://dlmf.nist.gov/8.11#E2 - if (x >= k) and (x >= 8): - S = 0 - k_minus_one_minus_n = k - 1 - fac = k_minus_one_minus_n - dfac = 1 - xpow = x + gamma_k = gamma(k) + digamma_k = psi(k) + log_x = log(x) + + def approx_a(skip_loop): + n_steps = switch( + skip_loop, np.array(0, dtype="int32"), np.array(9, dtype="int32") + ) + sum_a0 = np.array(0.0, dtype=dtype) + dfac = np.array(1.0, dtype=dtype) + xpow = x + k_minus_one_minus_n = k - 1 + fac = k_minus_one_minus_n + delta = true_div(dfac, xpow) + + def inner_loop_a(sum_a, delta, xpow, k_minus_one_minus_n, fac, dfac, x): + sum_a += delta + xpow *= x + k_minus_one_minus_n -= 1 + dfac = k_minus_one_minus_n * dfac + fac + fac *= k_minus_one_minus_n delta = dfac / xpow + return (sum_a, delta, xpow, k_minus_one_minus_n, fac, dfac), () - for n in range(1, 10): - k_minus_one_minus_n -= 1 - S += delta - xpow *= x - dfac = k_minus_one_minus_n * dfac + fac - fac *= k_minus_one_minus_n - delta = dfac / xpow - if np.isinf(delta): - warnings.warn( - "gammaincc_der did not converge", - RuntimeWarning, - ) - return np.nan + init = [sum_a0, delta, xpow, k_minus_one_minus_n, fac, dfac] + constant = [x] + sum_a = _make_scalar_loop( + n_steps, init, constant, inner_loop_a, name="gammaincc_grad_a" + ) + grad_approx_a = ( + gammaincc(k, x) * (log_x - digamma_k) + + exp(-x + (k - 1) * log_x) * sum_a / gamma_k + ) + return grad_approx_a + def approx_b(skip_loop): + max_iters = switch( + skip_loop, np.array(0, dtype="int32"), np.array(1e5, dtype="int32") + ) + log_precision = np.array(np.log(1e-6), dtype=config.floatX) + + sum_b0 = np.array(0.0, dtype=dtype) + log_s = np.array(0.0, dtype=dtype) + s_sign = np.array(1, dtype="int8") + n = np.array(1, dtype="int32") + log_delta = log_s - 2 * log(k) + + def inner_loop_b(sum_b, log_s, s_sign, log_delta, n, k, log_x): + delta = exp(log_delta) + sum_b += switch(s_sign > 0, delta, -delta) + s_sign = -s_sign + + # log will cast >int16 to float64 + log_s_inc = log_x - log(n) + if log_s_inc.type.dtype != log_s.type.dtype: + log_s_inc = log_s_inc.astype(log_s.type.dtype) + log_s += log_s_inc + + new_log_delta = log_s - 2 * log(n + k) + if new_log_delta.type.dtype != log_delta.type.dtype: + new_log_delta = new_log_delta.astype(log_delta.type.dtype) + log_delta = new_log_delta + + n += 1 return ( - scipy.special.gammaincc(k, x) * (log_x - digamma_k) - + np.exp(-x + (k - 1) * log_x) * S / gamma_k - ) - - # gradient of series expansion http://dlmf.nist.gov/8.7#E3 - else: - log_precision = np.log(1e-6) 
- max_iters = int(1e5) - S = 0 - log_s = 0.0 - s_sign = 1 - log_delta = log_s - 2 * np.log(k) - for n in range(1, max_iters + 1): - S += np.exp(log_delta) if s_sign > 0 else -np.exp(log_delta) - s_sign = -s_sign - log_s += log_x - np.log(n) - log_delta = log_s - 2 * np.log(n + k) - - if np.isinf(log_delta): - warnings.warn( - "gammaincc_der did not converge", - RuntimeWarning, - ) - return np.nan - - if log_delta <= log_precision: - return ( - scipy.special.gammainc(k, x) * (digamma_k - log_x) - + np.exp(k * log_x) * S / gamma_k - ) - - warnings.warn( - f"gammaincc_der did not converge after {n} iterations", - RuntimeWarning, + (sum_b, log_s, s_sign, log_delta, n), + log_delta <= log_precision, ) - return np.nan - - def impl(self, k, x): - return self.st_impl(k, x) - - def c_code(self, *args, **kwargs): - raise NotImplementedError() - -gammaincc_der = GammaIncCDer(upgrade_to_float, name="gammaincc_der") + init = [sum_b0, log_s, s_sign, log_delta, n] + constant = [k, log_x] + sum_b = _make_scalar_loop( + max_iters, init, constant, inner_loop_b, name="gammaincc_grad_b" + ) + grad_approx_b = ( + gammainc(k, x) * (digamma_k - log_x) + exp(k * log_x) * sum_b / gamma_k + ) + return grad_approx_b + + branch_a = (x >= k) & (x >= 8) + return switch( + branch_a, + approx_a(skip_loop=~branch_a | skip_loops), + approx_b(skip_loop=branch_a | skip_loops), + ) class GammaU(BinaryScalarOp): diff --git a/tests/tensor/test_math_scipy.py b/tests/tensor/test_math_scipy.py index 5d0af664c8..4e90b8a081 100644 --- a/tests/tensor/test_math_scipy.py +++ b/tests/tensor/test_math_scipy.py @@ -3,6 +3,8 @@ import numpy as np import pytest +from pytensor.gradient import verify_grad + scipy = pytest.importorskip("scipy") @@ -11,11 +13,11 @@ import scipy.special import scipy.stats -from pytensor import function +from pytensor import function, grad from pytensor import tensor as at from pytensor.compile.mode import get_default_mode from pytensor.configdefaults import config -from pytensor.tensor import inplace +from pytensor.tensor import gammaincc, inplace, vector from tests import unittest_tools as utt from tests.tensor.utils import ( _good_broadcast_unary_chi2sf, @@ -387,6 +389,9 @@ def test_gammainc_ddk_tabulated_values(): gammaincc_ddk = at.grad(gammainc_out, k) f_grad = function([k, x], gammaincc_ddk) + rtol = 1e-5 if config.floatX == "float64" else 1e-2 + atol = 1e-10 if config.floatX == "float64" else 1e-6 + for test_k, test_x, expected_ddk in ( (0.0001, 0, 0), # Limit condition (0.0001, 0.0001, -8.62594024578651), @@ -421,10 +426,27 @@ def test_gammainc_ddk_tabulated_values(): (19.0001, 29.7501, -0.007828749832965796), ): np.testing.assert_allclose( - f_grad(test_k, test_x), expected_ddk, rtol=1e-5, atol=1e-14 + f_grad(test_k, test_x), expected_ddk, rtol=rtol, atol=atol ) +def test_gammaincc_ddk_performance(benchmark): + rng = np.random.default_rng(1) + k = vector("k") + x = vector("x") + + out = gammaincc(k, x) + grad_fn = function([k, x], grad(out.sum(), wrt=[k]), mode="FAST_RUN") + vals = [ + # Values that hit the second branch of the gradient + np.full((1000,), 3.2), + np.full((1000,), 0.01), + ] + + verify_grad(gammaincc, vals, rng=rng) + benchmark(grad_fn, *vals) + + TestGammaUBroadcast = makeBroadcastTester( op=at.gammau, expected=expected_gammau, @@ -796,7 +818,7 @@ def test_boik_robison_cox(self): betainc_out = at.betainc(a, b, z) betainc_grad = at.grad(betainc_out, [a, b]) f_grad = function([a, b, z], betainc_grad) - + decimal = 7 if config.floatX == "float64" else 5 for test_a, test_b, test_z, 
expected_dda, expected_ddb in ( (1.5, 11.0, 0.001, -4.5720356e-03, 1.1845673e-04), (1.5, 11.0, 0.5, -2.5501997e-03, 9.0824388e-04), @@ -806,6 +828,7 @@ def test_boik_robison_cox(self): np.testing.assert_almost_equal( f_grad(test_a, test_b, test_z), [expected_dda, expected_ddb], + decimal=decimal, ) def test_beta_inc_stan_grad_combined(self): From 5954874796fb35f32eac0369ebe6e5847edfd2d3 Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Tue, 25 Apr 2023 16:02:24 +0200 Subject: [PATCH 5/7] Use ScalarLoop for betainc gradient --- pytensor/scalar/math.py | 185 ++++++++++++++++++++++---------------- tests/scalar/test_math.py | 4 +- 2 files changed, 108 insertions(+), 81 deletions(-) diff --git a/pytensor/scalar/math.py b/pytensor/scalar/math.py index 1a2ca7de9d..f9bfa1826d 100644 --- a/pytensor/scalar/math.py +++ b/pytensor/scalar/math.py @@ -14,10 +14,9 @@ from pytensor.configdefaults import config from pytensor.gradient import grad_not_implemented +from pytensor.scalar.basic import BinaryScalarOp, ScalarOp, UnaryScalarOp +from pytensor.scalar.basic import abs as scalar_abs from pytensor.scalar.basic import ( - BinaryScalarOp, - ScalarOp, - UnaryScalarOp, as_scalar, complex_types, constant, @@ -27,9 +26,12 @@ expm1, float64, float_types, + identity, isinf, log, log1p, + reciprocal, + scalar_maximum, sqrt, switch, true_div, @@ -1329,8 +1331,8 @@ def grad(self, inp, grads): (gz,) = grads return [ - gz * betainc_der(a, b, x, True), - gz * betainc_der(a, b, x, False), + gz * betainc_grad(a, b, x, True), + gz * betainc_grad(a, b, x, False), gz * exp( log1p(-x) * (b - 1) @@ -1346,28 +1348,28 @@ def c_code(self, *args, **kwargs): betainc = BetaInc(upgrade_to_float_no_complex, name="betainc") -class BetaIncDer(ScalarOp): - """ - Gradient of the regularized incomplete beta function wrt to the first - argument (alpha) or the second argument (beta), depending on whether the - fourth argument to betainc_der is `True` or `False`, respectively. +def betainc_grad(p, q, x, wrtp: bool): + """Gradient of the regularized lower gamma function (P) wrt to the first + argument (k, a.k.a. alpha). - Reference: Boik, R. J., & Robison-Cox, J. F. (1998). Derivatives of the incomplete beta function. - Journal of Statistical Software, 3(1), 1-20. + Adapted from STAN `grad_reg_lower_inc_gamma.hpp` + + Reference: Gautschi, W. (1979). A computational procedure for incomplete gamma functions. + ACM Transactions on Mathematical Software (TOMS), 5(4), 466-481. 
""" - nin = 4 + def _betainc_der(p, q, x, wrtp, skip_loop): + dtype = upcast(p.type.dtype, q.type.dtype, x.type.dtype, "float32") + + def betaln(a, b): + return gammaln(a) + (gammaln(b) - gammaln(a + b)) - def impl(self, p, q, x, wrtp): def _betainc_a_n(f, p, q, n): """ Numerator (a_n) of the nth approximant of the continued fraction representation of the regularized incomplete beta function """ - if n == 1: - return p * f * (q - 1) / (q * (p + 1)) - p2n = p + 2 * n F1 = p**2 * f**2 * (n - 1) / (q**2) F2 = ( @@ -1377,7 +1379,11 @@ def _betainc_a_n(f, p, q, n): / ((p2n - 3) * (p2n - 2) ** 2 * (p2n - 1)) ) - return F1 * F2 + return switch( + eq(n, 1), + p * f * (q - 1) / (q * (p + 1)), + F1 * F2, + ) def _betainc_b_n(f, p, q, n): """ @@ -1397,9 +1403,6 @@ def _betainc_da_n_dp(f, p, q, n): Derivative of a_n wrt p """ - if n == 1: - return -p * f * (q - 1) / (q * (p + 1) ** 2) - pp = p**2 ppp = pp * p p2n = p + 2 * n @@ -1414,20 +1417,25 @@ def _betainc_da_n_dp(f, p, q, n): D1 = q**2 * (p2n - 3) ** 2 D2 = (p2n - 2) ** 3 * (p2n - 1) ** 2 - return (N1 / D1) * (N2a + N2b + N2c + N2d + N2e) / D2 + return switch( + eq(n, 1), + -p * f * (q - 1) / (q * (p + 1) ** 2), + (N1 / D1) * (N2a + N2b + N2c + N2d + N2e) / D2, + ) def _betainc_da_n_dq(f, p, q, n): """ Derivative of a_n wrt q """ - if n == 1: - return p * f / (q * (p + 1)) - p2n = p + 2 * n F1 = (p**2 * f**2 / (q**2)) * (n - 1) * (p + n - 1) * (2 * q + p - 2) D1 = (p2n - 3) * (p2n - 2) ** 2 * (p2n - 1) - return F1 / D1 + return switch( + eq(n, 1), + p * f / (q * (p + 1)), + F1 / D1, + ) def _betainc_db_n_dp(f, p, q, n): """ @@ -1452,42 +1460,44 @@ def _betainc_db_n_dq(f, p, q, n): p2n = p + 2 * n return -(p**2 * f) / (q * (p2n - 2) * p2n) - # Input validation - if not (0 <= x <= 1) or p < 0 or q < 0: - return np.nan - - if x > (p / (p + q)): - return -self.impl(q, p, 1 - x, not wrtp) - - min_iters = 3 - max_iters = 200 - err_threshold = 1e-12 - - derivative_old = 0 + min_iters = np.array(3, dtype="int32") + max_iters = switch( + skip_loop, np.array(0, dtype="int32"), np.array(200, dtype="int32") + ) + err_threshold = np.array(1e-12, dtype=config.floatX) - Am2, Am1 = 1, 1 - Bm2, Bm1 = 0, 1 - dAm2, dAm1 = 0, 0 - dBm2, dBm1 = 0, 0 + Am2, Am1 = np.array(1, dtype=dtype), np.array(1, dtype=dtype) + Bm2, Bm1 = np.array(0, dtype=dtype), np.array(1, dtype=dtype) + dAm2, dAm1 = np.array(0, dtype=dtype), np.array(0, dtype=dtype) + dBm2, dBm1 = np.array(0, dtype=dtype), np.array(0, dtype=dtype) f = (q * x) / (p * (1 - x)) - K = np.exp( - p * np.log(x) - + (q - 1) * np.log1p(-x) - - np.log(p) - - scipy.special.betaln(p, q) - ) + K = exp(p * log(x) + (q - 1) * log1p(-x) - log(p) - betaln(p, q)) if wrtp: - dK = ( - np.log(x) - - 1 / p - + scipy.special.digamma(p + q) - - scipy.special.digamma(p) - ) + dK = log(x) - reciprocal(p) + psi(p + q) - psi(p) else: - dK = np.log1p(-x) + scipy.special.digamma(p + q) - scipy.special.digamma(q) - - for n in range(1, max_iters + 1): + dK = log1p(-x) + psi(p + q) - psi(q) + + derivative = np.array(0, dtype=dtype) + n = np.array(1, dtype="int16") # Enough for 200 max iters + + def inner_loop( + derivative, + Am2, + Am1, + Bm2, + Bm1, + dAm2, + dAm1, + dBm2, + dBm1, + n, + f, + p, + q, + K, + dK, + ): a_n_ = _betainc_a_n(f, p, q, n) b_n_ = _betainc_b_n(f, p, q, n) if wrtp: @@ -1502,36 +1512,53 @@ def _betainc_db_n_dq(f, p, q, n): dA = da_n * Am2 + a_n_ * dAm2 + db_n * Am1 + b_n_ * dAm1 dB = da_n * Bm2 + a_n_ * dBm2 + db_n * Bm1 + b_n_ * dBm1 - Am2, Am1 = Am1, A - Bm2, Bm1 = Bm1, B - dAm2, dAm1 = dAm1, dA - dBm2, dBm1 = dBm1, 
dB - - if n < min_iters - 1: - continue + Am2, Am1 = identity(Am1), identity(A) + Bm2, Bm1 = identity(Bm1), identity(B) + dAm2, dAm1 = identity(dAm1), identity(dA) + dBm2, dBm1 = identity(dBm1), identity(dB) F1 = A / B F2 = (dA - F1 * dB) / B - derivative = K * (F1 * dK + F2) + derivative_new = K * (F1 * dK + F2) - errapx = abs(derivative_old - derivative) - d_errapx = errapx / max(err_threshold, abs(derivative)) - derivative_old = derivative - - if d_errapx <= err_threshold: - return derivative + errapx = scalar_abs(derivative - derivative_new) + d_errapx = errapx / scalar_maximum( + err_threshold, scalar_abs(derivative_new) + ) - warnings.warn( - f"betainc_der did not converge after {n} iterations", - RuntimeWarning, - ) - return np.nan + min_iters_cond = n > (min_iters - 1) + derivative = switch( + min_iters_cond, + derivative_new, + derivative, + ) + n += 1 - def c_code(self, *args, **kwargs): - raise NotImplementedError() + return ( + (derivative, Am2, Am1, Bm2, Bm1, dAm2, dAm1, dBm2, dBm1, n), + (d_errapx <= err_threshold) & min_iters_cond, + ) + init = [derivative, Am2, Am1, Bm2, Bm1, dAm2, dAm1, dBm2, dBm1, n] + constant = [f, p, q, K, dK] + grad = _make_scalar_loop( + max_iters, init, constant, inner_loop, name="betainc_grad" + ) + return grad -betainc_der = BetaIncDer(upgrade_to_float_no_complex, name="betainc_der") + # Input validation + nan_branch = (x < 0) | (x > 1) | (p < 0) | (q < 0) + flip_branch = x > (p / (p + q)) + grad = switch( + nan_branch, + np.nan, + switch( + flip_branch, + -_betainc_der(q, p, 1 - x, not wrtp, skip_loop=nan_branch | (~flip_branch)), + _betainc_der(p, q, x, wrtp, skip_loop=nan_branch | flip_branch), + ), + ) + return grad class Hyp2F1(ScalarOp): diff --git a/tests/scalar/test_math.py b/tests/scalar/test_math.py index 456ff8a5f6..ed09aa8426 100644 --- a/tests/scalar/test_math.py +++ b/tests/scalar/test_math.py @@ -8,7 +8,7 @@ from pytensor.link.c.basic import CLinker from pytensor.scalar.math import ( betainc, - betainc_der, + betainc_grad, gammainc, gammaincc, gammal, @@ -82,7 +82,7 @@ def test_betainc(): def test_betainc_derivative_nan(): a, b, x = at.scalars("a", "b", "x") - res = betainc_der(a, b, x, True) + res = betainc_grad(a, b, x, True) test_func = function([a, b, x], res, mode=Mode("py")) assert not np.isnan(test_func(1, 1, 1)) assert np.isnan(test_func(1, 1, -1)) From 9d31e8f8484773979c7c798fa713b35a727ec5ab Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Tue, 25 Apr 2023 18:58:03 +0200 Subject: [PATCH 6/7] Use ScalarLoop for hyp2f1 gradient --- pytensor/scalar/math.py | 284 +++++++++++++++------------ tests/tensor/test_math_scipy.py | 329 +++++++++++++++++--------------- 2 files changed, 332 insertions(+), 281 deletions(-) diff --git a/pytensor/scalar/math.py b/pytensor/scalar/math.py index f9bfa1826d..2e01584512 100644 --- a/pytensor/scalar/math.py +++ b/pytensor/scalar/math.py @@ -5,7 +5,6 @@ """ import os -import warnings from textwrap import dedent import numpy as np @@ -26,7 +25,9 @@ expm1, float64, float_types, + floor, identity, + integer_types, isinf, log, log1p, @@ -853,15 +854,13 @@ def inner_loop_b(sum_b, log_s, s_sign, log_delta, n, k, log_x): s_sign = -s_sign # log will cast >int16 to float64 - log_s_inc = log_x - log(n) - if log_s_inc.type.dtype != log_s.type.dtype: - log_s_inc = log_s_inc.astype(log_s.type.dtype) - log_s += log_s_inc + log_s += log_x - log(n) + if log_s.type.dtype != dtype: + log_s = log_s.astype(dtype) - new_log_delta = log_s - 2 * log(n + k) - if new_log_delta.type.dtype != log_delta.type.dtype: - 
new_log_delta = new_log_delta.astype(log_delta.type.dtype) - log_delta = new_log_delta + log_delta = log_s - 2 * log(n + k) + if log_delta.type.dtype != dtype: + log_delta = log_delta.astype(dtype) n += 1 return ( @@ -1581,9 +1580,9 @@ def grad(self, inputs, grads): a, b, c, z = inputs (gz,) = grads return [ - gz * hyp2f1_der(a, b, c, z, wrt=0), - gz * hyp2f1_der(a, b, c, z, wrt=1), - gz * hyp2f1_der(a, b, c, z, wrt=2), + gz * hyp2f1_grad(a, b, c, z, wrt=0), + gz * hyp2f1_grad(a, b, c, z, wrt=1), + gz * hyp2f1_grad(a, b, c, z, wrt=2), gz * ((a * b) / c) * hyp2f1(a + 1, b + 1, c + 1, z), ] @@ -1594,134 +1593,165 @@ def c_code(self, *args, **kwargs): hyp2f1 = Hyp2F1(upgrade_to_float, name="hyp2f1") -class Hyp2F1Der(ScalarOp): - """ - Derivatives of the Gaussian Hypergeometric function ``2F1(a, b; c; z)`` with respect to one of the first 3 inputs. +def _unsafe_sign(x): + # Unlike scalar.sign we don't worry about x being 0 or nan + return switch(x > 0, 1, -1) - Adapted from https://github.com/stan-dev/math/blob/develop/stan/math/prim/fun/grad_2F1.hpp - """ - nin = 5 +def hyp2f1_grad(a, b, c, z, wrt: int): + dtype = upcast(a.type.dtype, b.type.dtype, c.type.dtype, z.type.dtype, "float32") - def impl(self, a, b, c, z, wrt): - def check_2f1_converges(a, b, c, z) -> bool: - num_terms = 0 - is_polynomial = False + def check_2f1_converges(a, b, c, z): + def is_nonpositive_integer(x): + if x.type.dtype not in integer_types: + return eq(floor(x), x) & (x <= 0) + else: + return x <= 0 - def is_nonpositive_integer(x): - return x <= 0 and x.is_integer() + a_is_polynomial = is_nonpositive_integer(a) & (scalar_abs(a) >= 0) + num_terms = switch( + a_is_polynomial, + floor(scalar_abs(a)).astype("int64"), + 0, + ) - if is_nonpositive_integer(a) and abs(a) >= num_terms: - is_polynomial = True - num_terms = int(np.floor(abs(a))) - if is_nonpositive_integer(b) and abs(b) >= num_terms: - is_polynomial = True - num_terms = int(np.floor(abs(b))) + b_is_polynomial = is_nonpositive_integer(b) & (scalar_abs(b) >= num_terms) + num_terms = switch( + b_is_polynomial, + floor(scalar_abs(b)).astype("int64"), + num_terms, + ) - is_undefined = is_nonpositive_integer(c) and abs(c) <= num_terms + is_undefined = is_nonpositive_integer(c) & (scalar_abs(c) <= num_terms) + is_polynomial = a_is_polynomial | b_is_polynomial - return not is_undefined and ( - is_polynomial or np.abs(z) < 1 or (np.abs(z) == 1 and c > (a + b)) - ) + return (~is_undefined) & ( + is_polynomial | (scalar_abs(z) < 1) | (eq(scalar_abs(z), 1) & (c > (a + b))) + ) - def compute_grad_2f1(a, b, c, z, wrt): - """ - Notes - ----- - The algorithm can be derived by looking at the ratio of two successive terms in the series - β_{k+1}/β_{k} = A(k)/B(k) - β_{k+1} = A(k)/B(k) * β_{k} - d[β_{k+1}] = d[A(k)/B(k)] * β_{k} + A(k)/B(k) * d[β_{k}] via the product rule - - In the 2F1, A(k)/B(k) corresponds to (((a + k) * (b + k) / ((c + k) (1 + k))) * z - - The partial d[A(k)/B(k)] with respect to the 3 first inputs can be obtained from the ratio A(k)/B(k), - by dropping the respective term - d/da[A(k)/B(k)] = A(k)/B(k) / (a + k) - d/db[A(k)/B(k)] = A(k)/B(k) / (b + k) - d/dc[A(k)/B(k)] = A(k)/B(k) * (c + k) - - The algorithm is implemented in the log scale, which adds the complexity of working with absolute terms and - tracking their signs. 
- """ + def compute_grad_2f1(a, b, c, z, wrt, skip_loop): + """ + Notes + ----- + The algorithm can be derived by looking at the ratio of two successive terms in the series + β_{k+1}/β_{k} = A(k)/B(k) + β_{k+1} = A(k)/B(k) * β_{k} + d[β_{k+1}] = d[A(k)/B(k)] * β_{k} + A(k)/B(k) * d[β_{k}] via the product rule + + In the 2F1, A(k)/B(k) corresponds to (((a + k) * (b + k) / ((c + k) (1 + k))) * z + + The partial d[A(k)/B(k)] with respect to the 3 first inputs can be obtained from the ratio A(k)/B(k), + by dropping the respective term + d/da[A(k)/B(k)] = A(k)/B(k) / (a + k) + d/db[A(k)/B(k)] = A(k)/B(k) / (b + k) + d/dc[A(k)/B(k)] = A(k)/B(k) * (c + k) + + The algorithm is implemented in the log scale, which adds the complexity of working with absolute terms and + tracking their signs. + """ + + wrt_a = wrt_b = False + if wrt == 0: + wrt_a = True + elif wrt == 1: + wrt_b = True + elif wrt != 2: + raise ValueError(f"wrt must be 0, 1, or 2, got {wrt}") + + min_steps = np.array( + 10, dtype="int32" + ) # https://github.com/stan-dev/math/issues/2857 + max_steps = switch( + skip_loop, np.array(0, dtype="int32"), np.array(int(1e6), dtype="int32") + ) + precision = np.array(1e-14, dtype=config.floatX) - wrt_a = wrt_b = False - if wrt == 0: - wrt_a = True - elif wrt == 1: - wrt_b = True - elif wrt != 2: - raise ValueError(f"wrt must be 0, 1, or 2, got {wrt}") - - min_steps = 10 # https://github.com/stan-dev/math/issues/2857 - max_steps = int(1e6) - precision = 1e-14 - - res = 0 - - if z == 0: - return res - - log_g_old = -np.inf - log_t_old = 0.0 - log_t_new = 0.0 - sign_z = np.sign(z) - log_z = np.log(np.abs(z)) - - log_g_old_sign = 1 - log_t_old_sign = 1 - log_t_new_sign = 1 - sign_zk = sign_z - - for k in range(max_steps): - p = (a + k) * (b + k) / ((c + k) * (k + 1)) - if p == 0: - return res - log_t_new += np.log(np.abs(p)) + log_z - log_t_new_sign = np.sign(p) * log_t_new_sign - - term = log_g_old_sign * log_t_old_sign * np.exp(log_g_old - log_t_old) - if wrt_a: - term += np.reciprocal(a + k) - elif wrt_b: - term += np.reciprocal(b + k) - else: - term -= np.reciprocal(c + k) - - log_g_old = log_t_new + np.log(np.abs(term)) - log_g_old_sign = np.sign(term) * log_t_new_sign - g_current = log_g_old_sign * np.exp(log_g_old) * sign_zk - res += g_current - - log_t_old = log_t_new - log_t_old_sign = log_t_new_sign - sign_zk *= sign_z - - if k >= min_steps and np.abs(g_current) <= precision: - return res - - warnings.warn( - f"hyp2f1_der did not converge after {k} iterations", - RuntimeWarning, - ) - return np.nan + grad = np.array(0, dtype=dtype) + + log_g = np.array(-np.inf, dtype=dtype) + log_g_sign = np.array(1, dtype="int8") + + log_t = np.array(0.0, dtype=dtype) + log_t_sign = np.array(1, dtype="int8") + + log_z = log(scalar_abs(z)) + sign_z = _unsafe_sign(z) + + sign_zk = sign_z + k = np.array(0, dtype="int32") + + def inner_loop( + grad, + log_g, + log_g_sign, + log_t, + log_t_sign, + sign_zk, + k, + a, + b, + c, + log_z, + sign_z, + ): + p = (a + k) * (b + k) / ((c + k) * (k + 1)) + if p.type.dtype != dtype: + p = p.astype(dtype) + + term = log_g_sign * log_t_sign * exp(log_g - log_t) + if wrt_a: + term += reciprocal(a + k) + elif wrt_b: + term += reciprocal(b + k) + else: + term -= reciprocal(c + k) + + if term.type.dtype != dtype: + term = term.astype(dtype) + + log_t = log_t + log(scalar_abs(p)) + log_z + log_t_sign = (_unsafe_sign(p) * log_t_sign).astype("int8") + log_g = log_t + log(scalar_abs(term)) + log_g_sign = (_unsafe_sign(term) * log_t_sign).astype("int8") + + g_current = log_g_sign 
* exp(log_g) * sign_zk - # TODO: We could implement the Euler transform to expand supported domain, as Stan does - if not check_2f1_converges(a, b, c, z): - warnings.warn( - f"Hyp2F1 does not meet convergence conditions with given arguments a={a}, b={b}, c={c}, z={z}", - RuntimeWarning, + # If p==0, don't update grad and get out of while loop next + grad = switch( + eq(p, 0), + grad, + grad + g_current, ) - return np.nan - return compute_grad_2f1(a, b, c, z, wrt=wrt) + sign_zk *= sign_z + k += 1 - def __call__(self, a, b, c, z, wrt, **kwargs): - # This allows wrt to be a keyword argument - return super().__call__(a, b, c, z, wrt, **kwargs) + return ( + (grad, log_g, log_g_sign, log_t, log_t_sign, sign_zk, k), + (eq(p, 0) | ((k > min_steps) & (scalar_abs(g_current) <= precision))), + ) - def c_code(self, *args, **kwargs): - raise NotImplementedError() + init = [grad, log_g, log_g_sign, log_t, log_t_sign, sign_zk, k] + constant = [a, b, c, log_z, sign_z] + grad = _make_scalar_loop( + max_steps, init, constant, inner_loop, name="hyp2f1_grad" + ) + return switch( + eq(z, 0), + 0, + grad, + ) -hyp2f1_der = Hyp2F1Der(upgrade_to_float, name="hyp2f1_der") + # We have to pass the converges flag to interrupt the loop, as the switch is not lazy + z_is_zero = eq(z, 0) + converges = check_2f1_converges(a, b, c, z) + return switch( + z_is_zero, + 0, + switch( + converges, + compute_grad_2f1(a, b, c, z, wrt, skip_loop=z_is_zero | (~converges)), + np.nan, + ), + ) diff --git a/tests/tensor/test_math_scipy.py b/tests/tensor/test_math_scipy.py index 4e90b8a081..25e35ee0e6 100644 --- a/tests/tensor/test_math_scipy.py +++ b/tests/tensor/test_math_scipy.py @@ -1,4 +1,4 @@ -from contextlib import ExitStack as does_not_warn +import warnings import numpy as np import pytest @@ -872,162 +872,183 @@ def test_beta_inc_stan_grad_combined(self): ) -def test_hyp2f1_grad_stan_cases(): - """This test reuses the same test cases as in: - https://github.com/stan-dev/math/blob/master/test/unit/math/prim/fun/grad_2F1_test.cpp - https://github.com/andrjohns/math/blob/develop/test/unit/math/prim/fun/hypergeometric_2F1_test.cpp - - Note: The expected_ddz was computed from the perform method, as it is not part of all Stan tests - """ - a1, a2, b1, z = at.scalars("a1", "a2", "b1", "z") - betainc_out = at.hyp2f1(a1, a2, b1, z) - betainc_grad = at.grad(betainc_out, [a1, a2, b1, z]) - f_grad = function([a1, a2, b1, z], betainc_grad) - - rtol = 1e-9 if config.floatX == "float64" else 1e-3 - - for ( - test_a1, - test_a2, - test_b1, - test_z, - expected_dda1, - expected_dda2, - expected_ddb1, - expected_ddz, - ) in ( - ( - 3.70975, - 1.0, - 2.70975, - -0.2, - -0.0488658806159776, - -0.193844936204681, - 0.0677809985598383, - 0.8652952472723672, - ), - (3.70975, 1.0, 2.70975, 0, 0, 0, 0, 1.369037734108313), - ( - 1.0, - 1.0, - 1.0, - 0.6, - 2.290726829685388, - 2.290726829685388, - -2.290726829685388, - 6.25, - ), - ( - 1.0, - 31.0, - 41.0, - 1.0, - 6.825270649241036, - 0.4938271604938271, - -0.382716049382716, - 17.22222222222223, - ), - ( - 1.0, - -2.1, - 41.0, - 1.0, - -0.04921317604093563, - 0.02256814168279349, - 0.00118482743834665, - -0.04854621426218426, - ), - ( - 1.0, - -0.5, - 10.6, - 0.3, - -0.01443822031245647, - 0.02829710651967078, - 0.00136986255602642, - -0.04846036062115473, - ), - ( - 1.0, - -0.5, - 10.0, - 0.3, - -0.0153218866216130, - 0.02999436412836072, - 0.0015413242328729, - -0.05144686244336445, - ), - ( - -0.5, - -4.5, - 11.0, - 0.3, - -0.1227022810085707, - -0.01298849638043795, - -0.0053540982315572, - 
0.1959735211840362, - ), - ( - -0.5, - -4.5, - -3.2, - 0.9, - 0.85880025358111, - 0.4677704416159314, - -4.19010422485256, - -2.959196647856408, - ), - ( - 3.70975, - 1.0, - 2.70975, - -0.2, - -0.0488658806159776, - -0.193844936204681, - 0.0677809985598383, - 0.865295247272367, - ), - ( - 2.0, - 1.0, - 2.0, - 0.4, - 0.4617734323582945, - 0.851376039609984, - -0.4617734323582945, - 2.777777777777778, - ), - ( - 3.70975, - 1.0, - 2.70975, - 0.999696, - 29369830.002773938200417693317785, - 36347869.41885337, - -30843032.10697079073015067426929807, - 26278034019.28811, - ), - # Cases where series does not converge - (1.0, 12.0, 10.0, 1.0, np.nan, np.nan, np.nan, np.inf), - (1.0, 12.0, 20.0, 1.2, np.nan, np.nan, np.nan, np.inf), - # Case where series converges under Euler transform (not implemented!) - # (1.0, 1.0, 2.0, -5.0, -0.321040199556840, -0.321040199556840, 0.129536268190289, 0.0383370454357889), - (1.0, 1.0, 2.0, -5.0, np.nan, np.nan, np.nan, 0.0383370454357889), - ): - expectation = ( - pytest.warns( - RuntimeWarning, match="Hyp2F1 does not meet convergence conditions" - ) - if np.any( - np.isnan([expected_dda1, expected_dda2, expected_ddb1, expected_ddz]) +class TestHyp2F1Grad: + few_iters_case = ( + 2.0, + 1.0, + 2.0, + 0.4, + 0.4617734323582945, + 0.851376039609984, + -0.4617734323582945, + 2.777777777777778, + ) + + many_iters_case = ( + 3.70975, + 1.0, + 2.70975, + 0.999696, + 29369830.002773938200417693317785, + 36347869.41885337, + -30843032.10697079073015067426929807, + 26278034019.28811, + ) + + def test_hyp2f1_grad_stan_cases(self): + """This test reuses the same test cases as in: + https://github.com/stan-dev/math/blob/master/test/unit/math/prim/fun/grad_2F1_test.cpp + https://github.com/andrjohns/math/blob/develop/test/unit/math/prim/fun/hypergeometric_2F1_test.cpp + + Note: The expected_ddz was computed from the perform method, as it is not part of all Stan tests + """ + a1, a2, b1, z = at.scalars("a1", "a2", "b1", "z") + hyp2f1_out = at.hyp2f1(a1, a2, b1, z) + hyp2f1_grad = at.grad(hyp2f1_out, [a1, a2, b1, z]) + f_grad = function([a1, a2, b1, z], hyp2f1_grad) + + rtol = 1e-9 if config.floatX == "float64" else 2e-3 + for ( + test_a1, + test_a2, + test_b1, + test_z, + expected_dda1, + expected_dda2, + expected_ddb1, + expected_ddz, + ) in ( + ( + 3.70975, + 1.0, + 2.70975, + -0.2, + -0.0488658806159776, + -0.193844936204681, + 0.0677809985598383, + 0.8652952472723672, + ), + (3.70975, 1.0, 2.70975, 0, 0, 0, 0, 1.369037734108313), + ( + 1.0, + 1.0, + 1.0, + 0.6, + 2.290726829685388, + 2.290726829685388, + -2.290726829685388, + 6.25, + ), + ( + 1.0, + 31.0, + 41.0, + 1.0, + 6.825270649241036, + 0.4938271604938271, + -0.382716049382716, + 17.22222222222223, + ), + ( + 1.0, + -2.1, + 41.0, + 1.0, + -0.04921317604093563, + 0.02256814168279349, + 0.00118482743834665, + -0.04854621426218426, + ), + ( + 1.0, + -0.5, + 10.6, + 0.3, + -0.01443822031245647, + 0.02829710651967078, + 0.00136986255602642, + -0.04846036062115473, + ), + ( + 1.0, + -0.5, + 10.0, + 0.3, + -0.0153218866216130, + 0.02999436412836072, + 0.0015413242328729, + -0.05144686244336445, + ), + ( + -0.5, + -4.5, + 11.0, + 0.3, + -0.1227022810085707, + -0.01298849638043795, + -0.0053540982315572, + 0.1959735211840362, + ), + ( + -0.5, + -4.5, + -3.2, + 0.9, + 0.85880025358111, + 0.4677704416159314, + -4.19010422485256, + -2.959196647856408, + ), + ( + 3.70975, + 1.0, + 2.70975, + -0.2, + -0.0488658806159776, + -0.193844936204681, + 0.0677809985598383, + 0.865295247272367, + ), + self.few_iters_case, + 
self.many_iters_case, + # Cases where series does not converge + (1.0, 12.0, 10.0, 1.0, np.nan, np.nan, np.nan, np.inf), + (1.0, 12.0, 20.0, 1.2, np.nan, np.nan, np.nan, np.inf), + # Case where series converges under Euler transform (not implemented!) + # (1.0, 1.0, 2.0, -5.0, -0.321040199556840, -0.321040199556840, 0.129536268190289, 0.0383370454357889), + (1.0, 1.0, 2.0, -5.0, np.nan, np.nan, np.nan, 0.0383370454357889), + ): + with warnings.catch_warnings(): + warnings.simplefilter("error") + warnings.filterwarnings( + "ignore", + category=RuntimeWarning, + message="divide by zero encountered in log", + ) + result = np.array(f_grad(test_a1, test_a2, test_b1, test_z)) + + np.testing.assert_allclose( + result, + np.array([expected_dda1, expected_dda2, expected_ddb1, expected_ddz]), + rtol=rtol, ) - else does_not_warn() - ) - with expectation: - result = np.array(f_grad(test_a1, test_a2, test_b1, test_z)) + @pytest.mark.parametrize("case", (few_iters_case, many_iters_case)) + @pytest.mark.parametrize("wrt", ("a", "all")) + def test_benchmark(self, case, wrt, benchmark): + a1, a2, b1, z = at.scalars("a1", "a2", "b1", "z") + hyp2f1_out = at.hyp2f1(a1, a2, b1, z) + hyp2f1_grad = at.grad(hyp2f1_out, wrt=a1 if wrt == "a" else [a1, a2, b1, z]) + f_grad = function([a1, a2, b1, z], hyp2f1_grad) + + (test_a1, test_a2, test_b1, test_z, *expected_dds) = case + + result = benchmark(f_grad, test_a1, test_a2, test_b1, test_z) + + rtol = 1e-9 if config.floatX == "float64" else 2e-3 + expected_result = expected_dds[0] if wrt == "a" else np.array(expected_dds) np.testing.assert_allclose( result, - np.array([expected_dda1, expected_dda2, expected_ddb1, expected_ddz]), + expected_result, rtol=rtol, ) From 34230a2cb0e9fa5caa1a4bb1e8c993d6df22fd8a Mon Sep 17 00:00:00 2001 From: Ricardo Vieira Date: Thu, 27 Apr 2023 09:52:29 +0200 Subject: [PATCH 7/7] Fuse hyp2f1 grads --- pytensor/scalar/math.py | 323 +++++++++++++++----------- pytensor/tensor/rewriting/elemwise.py | 60 +++++ tests/tensor/test_math_scipy.py | 37 +++ 3 files changed, 285 insertions(+), 135 deletions(-) diff --git a/pytensor/scalar/math.py b/pytensor/scalar/math.py index 2e01584512..a72d19d7aa 100644 --- a/pytensor/scalar/math.py +++ b/pytensor/scalar/math.py @@ -5,7 +5,9 @@ """ import os +from functools import reduce from textwrap import dedent +from typing import Tuple import numpy as np import scipy.special @@ -683,14 +685,20 @@ def __hash__(self): gammaincc = GammaIncC(upgrade_to_float, name="gammaincc") -def _make_scalar_loop(n_steps, init, constant, inner_loop_fn, name): - init = [as_scalar(x) for x in init] +def _make_scalar_loop(n_steps, init, constant, inner_loop_fn, name, loop_op=ScalarLoop): + init = [as_scalar(x) if x is not None else None for x in init] constant = [as_scalar(x) for x in constant] + # Create dummy types, in case some variables have the same initial form - init_ = [x.type() for x in init] + init_ = [x.type() if x is not None else None for x in init] constant_ = [x.type() for x in constant] update_, until_ = inner_loop_fn(*init_, *constant_) - op = ScalarLoop( + + # Filter Nones + init = [i for i in init if i is not None] + init_ = [i for i in init_ if i is not None] + update_ = [u for u in update_ if u is not None] + op = loop_op( init=init_, constant=constant_, update=update_, @@ -698,8 +706,7 @@ def _make_scalar_loop(n_steps, init, constant, inner_loop_fn, name): until_condition_failed="warn", name=name, ) - S, *_ = op(n_steps, *init, *constant) - return S + return op(n_steps, *init, *constant) def gammainc_grad(k, 
x): @@ -740,7 +747,7 @@ def inner_loop_a(sum_a, log_gamma_k_plus_n_plus_1, k_plus_n, log_x): init = [sum_a0, log_gamma_k_plus_n_plus_1, k_plus_n] constant = [log_x] - sum_a = _make_scalar_loop( + sum_a, *_ = _make_scalar_loop( max_iters, init, constant, inner_loop_a, name="gammainc_grad_a" ) @@ -827,7 +834,7 @@ def inner_loop_a(sum_a, delta, xpow, k_minus_one_minus_n, fac, dfac, x): init = [sum_a0, delta, xpow, k_minus_one_minus_n, fac, dfac] constant = [x] - sum_a = _make_scalar_loop( + sum_a, *_ = _make_scalar_loop( n_steps, init, constant, inner_loop_a, name="gammaincc_grad_a" ) grad_approx_a = ( @@ -870,7 +877,7 @@ def inner_loop_b(sum_b, log_s, s_sign, log_delta, n, k, log_x): init = [sum_b0, log_s, s_sign, log_delta, n] constant = [k, log_x] - sum_b = _make_scalar_loop( + sum_b, *_ = _make_scalar_loop( max_iters, init, constant, inner_loop_b, name="gammaincc_grad_b" ) grad_approx_b = ( @@ -1540,7 +1547,7 @@ def inner_loop( init = [derivative, Am2, Am1, Bm2, Bm1, dAm2, dAm1, dBm2, dBm1, n] constant = [f, p, q, K, dK] - grad = _make_scalar_loop( + grad, *_ = _make_scalar_loop( max_iters, init, constant, inner_loop, name="betainc_grad" ) return grad @@ -1579,10 +1586,11 @@ def impl(self, a, b, c, z): def grad(self, inputs, grads): a, b, c, z = inputs (gz,) = grads + grad_a, grad_b, grad_c = hyp2f1_grad(a, b, c, z, wrt=[0, 1, 2]) return [ - gz * hyp2f1_grad(a, b, c, z, wrt=0), - gz * hyp2f1_grad(a, b, c, z, wrt=1), - gz * hyp2f1_grad(a, b, c, z, wrt=2), + gz * grad_a, + gz * grad_b, + gz * grad_c, gz * ((a * b) / c) * hyp2f1(a + 1, b + 1, c + 1, z), ] @@ -1598,92 +1606,55 @@ def _unsafe_sign(x): return switch(x > 0, 1, -1) -def hyp2f1_grad(a, b, c, z, wrt: int): - dtype = upcast(a.type.dtype, b.type.dtype, c.type.dtype, z.type.dtype, "float32") - - def check_2f1_converges(a, b, c, z): - def is_nonpositive_integer(x): - if x.type.dtype not in integer_types: - return eq(floor(x), x) & (x <= 0) - else: - return x <= 0 +class Grad2F1Loop(ScalarLoop): + """Subclass of ScalarLoop for easier targetting in rewrites""" - a_is_polynomial = is_nonpositive_integer(a) & (scalar_abs(a) >= 0) - num_terms = switch( - a_is_polynomial, - floor(scalar_abs(a)).astype("int64"), - 0, - ) - b_is_polynomial = is_nonpositive_integer(b) & (scalar_abs(b) >= num_terms) - num_terms = switch( - b_is_polynomial, - floor(scalar_abs(b)).astype("int64"), - num_terms, - ) +def _grad_2f1_loop(a, b, c, z, *, skip_loop, wrt, dtype): + """ + Notes + ----- + The algorithm can be derived by looking at the ratio of two successive terms in the series + β_{k+1}/β_{k} = A(k)/B(k) + β_{k+1} = A(k)/B(k) * β_{k} + d[β_{k+1}] = d[A(k)/B(k)] * β_{k} + A(k)/B(k) * d[β_{k}] via the product rule - is_undefined = is_nonpositive_integer(c) & (scalar_abs(c) <= num_terms) - is_polynomial = a_is_polynomial | b_is_polynomial + In the 2F1, A(k)/B(k) corresponds to (((a + k) * (b + k) / ((c + k) (1 + k))) * z - return (~is_undefined) & ( - is_polynomial | (scalar_abs(z) < 1) | (eq(scalar_abs(z), 1) & (c > (a + b))) - ) + The partial d[A(k)/B(k)] with respect to the 3 first inputs can be obtained from the ratio A(k)/B(k), + by dropping the respective term + d/da[A(k)/B(k)] = A(k)/B(k) / (a + k) + d/db[A(k)/B(k)] = A(k)/B(k) / (b + k) + d/dc[A(k)/B(k)] = A(k)/B(k) * (c + k) - def compute_grad_2f1(a, b, c, z, wrt, skip_loop): - """ - Notes - ----- - The algorithm can be derived by looking at the ratio of two successive terms in the series - β_{k+1}/β_{k} = A(k)/B(k) - β_{k+1} = A(k)/B(k) * β_{k} - d[β_{k+1}] = d[A(k)/B(k)] * β_{k} + A(k)/B(k) * 
d[β_{k}] via the product rule - - In the 2F1, A(k)/B(k) corresponds to (((a + k) * (b + k) / ((c + k) (1 + k))) * z - - The partial d[A(k)/B(k)] with respect to the 3 first inputs can be obtained from the ratio A(k)/B(k), - by dropping the respective term - d/da[A(k)/B(k)] = A(k)/B(k) / (a + k) - d/db[A(k)/B(k)] = A(k)/B(k) / (b + k) - d/dc[A(k)/B(k)] = A(k)/B(k) * (c + k) - - The algorithm is implemented in the log scale, which adds the complexity of working with absolute terms and - tracking their signs. - """ - - wrt_a = wrt_b = False - if wrt == 0: - wrt_a = True - elif wrt == 1: - wrt_b = True - elif wrt != 2: - raise ValueError(f"wrt must be 0, 1, or 2, got {wrt}") - - min_steps = np.array( - 10, dtype="int32" - ) # https://github.com/stan-dev/math/issues/2857 - max_steps = switch( - skip_loop, np.array(0, dtype="int32"), np.array(int(1e6), dtype="int32") - ) - precision = np.array(1e-14, dtype=config.floatX) + The algorithm is implemented in the log scale, which adds the complexity of working with absolute terms and + tracking their signs. + """ - grad = np.array(0, dtype=dtype) + min_steps = np.array( + 10, dtype="int32" + ) # https://github.com/stan-dev/math/issues/2857 + max_steps = switch( + skip_loop, np.array(0, dtype="int32"), np.array(int(1e6), dtype="int32") + ) + precision = np.array(1e-14, dtype=config.floatX) - log_g = np.array(-np.inf, dtype=dtype) - log_g_sign = np.array(1, dtype="int8") + grads = [np.array(0, dtype=dtype) if i in wrt else None for i in range(3)] + log_gs = [np.array(-np.inf, dtype=dtype) if i in wrt else None for i in range(3)] + log_gs_signs = [np.array(1, dtype="int8") if i in wrt else None for i in range(3)] - log_t = np.array(0.0, dtype=dtype) - log_t_sign = np.array(1, dtype="int8") + log_t = np.array(0.0, dtype=dtype) + log_t_sign = np.array(1, dtype="int8") - log_z = log(scalar_abs(z)) - sign_z = _unsafe_sign(z) + log_z = log(scalar_abs(z)) + sign_z = _unsafe_sign(z) - sign_zk = sign_z - k = np.array(0, dtype="int32") + sign_zk = sign_z + k = np.array(0, dtype="int32") - def inner_loop( - grad, - log_g, - log_g_sign, + def inner_loop(*args): + ( + *grads_vars, log_t, log_t_sign, sign_zk, @@ -1693,65 +1664,147 @@ def inner_loop( c, log_z, sign_z, - ): - p = (a + k) * (b + k) / ((c + k) * (k + 1)) - if p.type.dtype != dtype: - p = p.astype(dtype) - - term = log_g_sign * log_t_sign * exp(log_g - log_t) - if wrt_a: - term += reciprocal(a + k) - elif wrt_b: - term += reciprocal(b + k) - else: - term -= reciprocal(c + k) + ) = args + + ( + grad_a, + grad_b, + grad_c, + log_g_a, + log_g_b, + log_g_c, + log_g_sign_a, + log_g_sign_b, + log_g_sign_c, + ) = grads_vars + + p = (a + k) * (b + k) / ((c + k) * (k + 1)) + if p.type.dtype != dtype: + p = p.astype(dtype) + + # If p==0, don't update grad and get out of while loop next + p_zero = eq(p, 0) + + if 0 in wrt: + term_a = log_g_sign_a * log_t_sign * exp(log_g_a - log_t) + term_a += reciprocal(a + k) + if term_a.type.dtype != dtype: + term_a = term_a.astype(dtype) + if 1 in wrt: + term_b = log_g_sign_b * log_t_sign * exp(log_g_b - log_t) + term_b += reciprocal(b + k) + if term_b.type.dtype != dtype: + term_b = term_b.astype(dtype) + if 2 in wrt: + term_c = log_g_sign_c * log_t_sign * exp(log_g_c - log_t) + term_c -= reciprocal(c + k) + if term_c.type.dtype != dtype: + term_c = term_c.astype(dtype) + + log_t = log_t + log(scalar_abs(p)) + log_z + log_t_sign = (_unsafe_sign(p) * log_t_sign).astype("int8") + + grads = [None] * 3 + log_gs = [None] * 3 + log_gs_signs = [None] * 3 + grad_incs = [None] * 3 + + 
if 0 in wrt: + log_g_a = log_t + log(scalar_abs(term_a)) + log_g_sign_a = (_unsafe_sign(term_a) * log_t_sign).astype("int8") + grad_inc_a = log_g_sign_a * exp(log_g_a) * sign_zk + grads[0] = switch(p_zero, grad_a, grad_a + grad_inc_a) + log_gs[0] = log_g_a + log_gs_signs[0] = log_g_sign_a + grad_incs[0] = grad_inc_a + if 1 in wrt: + log_g_b = log_t + log(scalar_abs(term_b)) + log_g_sign_b = (_unsafe_sign(term_b) * log_t_sign).astype("int8") + grad_inc_b = log_g_sign_b * exp(log_g_b) * sign_zk + grads[1] = switch(p_zero, grad_b, grad_b + grad_inc_b) + log_gs[1] = log_g_b + log_gs_signs[1] = log_g_sign_b + grad_incs[1] = grad_inc_b + if 2 in wrt: + log_g_c = log_t + log(scalar_abs(term_c)) + log_g_sign_c = (_unsafe_sign(term_c) * log_t_sign).astype("int8") + grad_inc_c = log_g_sign_c * exp(log_g_c) * sign_zk + grads[2] = switch(p_zero, grad_c, grad_c + grad_inc_c) + log_gs[2] = log_g_c + log_gs_signs[2] = log_g_sign_c + grad_incs[2] = grad_inc_c + + sign_zk *= sign_z + k += 1 + + abs_grad_incs = [ + scalar_abs(grad_inc) for grad_inc in grad_incs if grad_inc is not None + ] + if len(grad_incs) == 1: + [max_abs_grad_inc] = grad_incs + else: + max_abs_grad_inc = reduce(scalar_maximum, abs_grad_incs) - if term.type.dtype != dtype: - term = term.astype(dtype) + return ( + (*grads, *log_gs, *log_gs_signs, log_t, log_t_sign, sign_zk, k), + (eq(p, 0) | ((k > min_steps) & (max_abs_grad_inc <= precision))), + ) - log_t = log_t + log(scalar_abs(p)) + log_z - log_t_sign = (_unsafe_sign(p) * log_t_sign).astype("int8") - log_g = log_t + log(scalar_abs(term)) - log_g_sign = (_unsafe_sign(term) * log_t_sign).astype("int8") + init = [*grads, *log_gs, *log_gs_signs, log_t, log_t_sign, sign_zk, k] + constant = [a, b, c, log_z, sign_z] + loop_outs = _make_scalar_loop( + max_steps, init, constant, inner_loop, name="hyp2f1_grad", loop_op=Grad2F1Loop + ) + return loop_outs[: len(wrt)] - g_current = log_g_sign * exp(log_g) * sign_zk - # If p==0, don't update grad and get out of while loop next - grad = switch( - eq(p, 0), - grad, - grad + g_current, - ) +def hyp2f1_grad(a, b, c, z, wrt: Tuple[int, ...]): + dtype = upcast(a.type.dtype, b.type.dtype, c.type.dtype, z.type.dtype, "float32") - sign_zk *= sign_z - k += 1 + def check_2f1_converges(a, b, c, z): + def is_nonpositive_integer(x): + if x.type.dtype not in integer_types: + return eq(floor(x), x) & (x <= 0) + else: + return x <= 0 - return ( - (grad, log_g, log_g_sign, log_t, log_t_sign, sign_zk, k), - (eq(p, 0) | ((k > min_steps) & (scalar_abs(g_current) <= precision))), - ) + a_is_polynomial = is_nonpositive_integer(a) & (scalar_abs(a) >= 0) + num_terms = switch( + a_is_polynomial, + floor(scalar_abs(a)).astype("int64"), + 0, + ) - init = [grad, log_g, log_g_sign, log_t, log_t_sign, sign_zk, k] - constant = [a, b, c, log_z, sign_z] - grad = _make_scalar_loop( - max_steps, init, constant, inner_loop, name="hyp2f1_grad" + b_is_polynomial = is_nonpositive_integer(b) & (scalar_abs(b) >= num_terms) + num_terms = switch( + b_is_polynomial, + floor(scalar_abs(b)).astype("int64"), + num_terms, ) - return switch( - eq(z, 0), - 0, - grad, + is_undefined = is_nonpositive_integer(c) & (scalar_abs(c) <= num_terms) + is_polynomial = a_is_polynomial | b_is_polynomial + + return (~is_undefined) & ( + is_polynomial | (scalar_abs(z) < 1) | (eq(scalar_abs(z), 1) & (c > (a + b))) ) # We have to pass the converges flag to interrupt the loop, as the switch is not lazy z_is_zero = eq(z, 0) converges = check_2f1_converges(a, b, c, z) - return switch( - z_is_zero, - 0, - switch( - 
converges, - compute_grad_2f1(a, b, c, z, wrt, skip_loop=z_is_zero | (~converges)), - np.nan, - ), + grads = _grad_2f1_loop( + a, b, c, z, skip_loop=z_is_zero | (~converges), wrt=wrt, dtype=dtype ) + + return [ + switch( + z_is_zero, + 0, + switch( + converges, + grad, + np.nan, + ), + ) + for grad in grads + ] diff --git a/pytensor/tensor/rewriting/elemwise.py b/pytensor/tensor/rewriting/elemwise.py index 03308b9983..d5a5a8eb9f 100644 --- a/pytensor/tensor/rewriting/elemwise.py +++ b/pytensor/tensor/rewriting/elemwise.py @@ -23,6 +23,7 @@ from pytensor.graph.rewriting.db import SequenceDB from pytensor.graph.utils import InconsistencyError, MethodNotDefined from pytensor.scalar.loop import ScalarLoop +from pytensor.scalar.math import Grad2F1Loop, _grad_2f1_loop from pytensor.tensor.basic import ( MakeVector, alloc, @@ -31,6 +32,7 @@ ) from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise from pytensor.tensor.exceptions import NotScalarConstantError +from pytensor.tensor.math import exp from pytensor.tensor.rewriting.basic import register_canonicalize, register_specialize from pytensor.tensor.shape import shape_padleft from pytensor.tensor.var import TensorConstant @@ -1215,3 +1217,61 @@ def local_careduce_fusion(fgraph, node): "fusion", position=49, ) + + +@register_specialize +@node_rewriter([Elemwise]) +def local_useless_2f1grad_loop(fgraph, node): + # Remove unused terms from the hyp2f1 grad loop + + loop_op = node.op.scalar_op + if not isinstance(loop_op, Grad2F1Loop): + return + + grad_related_vars = node.outputs[:-4] + # Rewrite was already applied + if len(grad_related_vars) // 3 != 3: + return None + + grad_vars = grad_related_vars[:3] + grad_var_is_used = [bool(fgraph.clients.get(v)) for v in grad_vars] + + # Nothing to do here + if sum(grad_var_is_used) == 3: + return None + + # Check that None of the remaining vars is used anywhere + if any(bool(fgraph.clients.get(v)) for v in node.outputs[3:]): + return None + + a, b, c, log_z, sign_z = node.inputs[-5:] + z = exp(log_z) * sign_z + + # Reconstruct scalar loop with relevant outputs + a_, b_, c_, z_ = (x.type.to_scalar_type()() for x in (a, b, c, z)) + wrt = [i for i, used in enumerate(grad_var_is_used) if used] + new_loop_op = _grad_2f1_loop( + a_, b_, c_, z_, skip_loop=False, wrt=wrt, dtype=a_.type.dtype + )[0].owner.op + + # Reconstruct elemwise loop + new_elemwise_op = Elemwise(scalar_op=new_loop_op) + n_steps = node.inputs[0] + init_grad_vars = node.inputs[1:10] + other_inputs = node.inputs[10:] + + init_grads = init_grad_vars[: len(wrt)] + init_gs = init_grad_vars[3 : 3 + len(wrt)] + init_gs_signs = init_grad_vars[6 : 6 + len(wrt)] + subset_init_grad_vars = init_grads + init_gs + init_gs_signs + + new_outs = new_elemwise_op(n_steps, *subset_init_grad_vars, *other_inputs) + + replacements = {} + i = 0 + for grad_var, is_used in zip(grad_vars, grad_var_is_used): + if not is_used: + continue + replacements[grad_var] = new_outs[i] + i += 1 + return replacements diff --git a/tests/tensor/test_math_scipy.py b/tests/tensor/test_math_scipy.py index 25e35ee0e6..d39f30b652 100644 --- a/tests/tensor/test_math_scipy.py +++ b/tests/tensor/test_math_scipy.py @@ -4,6 +4,8 @@ import pytest from pytensor.gradient import verify_grad +from pytensor.scalar import ScalarLoop +from pytensor.tensor.elemwise import Elemwise scipy = pytest.importorskip("scipy") @@ -1052,3 +1054,38 @@ def test_benchmark(self, case, wrt, benchmark): expected_result, rtol=rtol, ) + + @pytest.mark.parametrize("wrt", ([0], [1], [2], [0, 1], [1, 2], [0, 2], 
[0, 1, 2])) + def test_unused_grad_loop_opt(self, wrt): + """Test that we don't compute unnecessary outputs in the grad scalar loop""" + ( + test_a1, + test_a2, + test_b1, + test_z, + *expected_dds, + expected_ddz, + ) = self.few_iters_case + + a1, a2, b1, z = at.scalars("a1", "a2", "b1", "z") + hyp2f1_out = at.hyp2f1(a1, a2, b1, z) + wrt_vars = [v for i, v in enumerate((a1, a2, b1, z)) if i in wrt] + hyp2f1_grad = at.grad(hyp2f1_out, wrt=wrt_vars) + + mode = get_default_mode().including("local_useless_2f1grad_loop") + f_grad = function([a1, a2, b1, z], hyp2f1_grad, mode=mode) + + [scalar_loop_op] = [ + node.op.scalar_op + for node in f_grad.maker.fgraph.apply_nodes + if isinstance(node.op, Elemwise) + and isinstance(node.op.scalar_op, ScalarLoop) + ] + assert scalar_loop_op.nin == 10 + 3 * len(wrt) + + rtol = 1e-9 if config.floatX == "float64" else 2e-3 + np.testing.assert_allclose( + f_grad(test_a1, test_a2, test_b1, test_z), + [dd for i, dd in enumerate(expected_dds) if i in wrt], + rtol=rtol, + )
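
For reference, the recursion that both the removed pure-Python Hyp2F1Der.impl and the new ScalarLoop graph implement — the update d[β_{k+1}] = d[A(k)/B(k)]·β_k + A(k)/B(k)·d[β_k], carried in log scale with separate sign tracking — can be sketched in plain NumPy by transcribing the implementation removed in patch 6. This is an illustrative sketch only: the name grad_2f1_reference, its default step limits, and the tolerance are placeholders chosen here for the example, not part of the PyTensor API.

import numpy as np

def grad_2f1_reference(a, b, c, z, wrt, max_steps=int(1e6), min_steps=10, precision=1e-14):
    """d/d{a,b,c} of 2F1(a, b; c; z), transcribed from the removed pure-Python impl."""
    if wrt not in (0, 1, 2):
        raise ValueError(f"wrt must be 0, 1, or 2, got {wrt}")
    if z == 0:
        return 0.0
    res = 0.0
    log_g_old = -np.inf   # log|d(beta_k)|: derivative of the previous series term (none yet)
    log_t_old = 0.0       # log|beta_k|: previous series term (beta_0 = 1)
    log_t_new = 0.0       # log|beta_{k+1}|: updated inside the loop
    log_g_old_sign = log_t_old_sign = log_t_new_sign = 1
    log_z, sign_z = np.log(np.abs(z)), np.sign(z)
    sign_zk = sign_z
    for k in range(max_steps):
        p = (a + k) * (b + k) / ((c + k) * (k + 1))  # A(k)/B(k) without the z factor
        if p == 0:
            return res
        log_t_new += np.log(np.abs(p)) + log_z
        log_t_new_sign = np.sign(p) * log_t_new_sign
        # Carried ratio d[beta_k]/beta_k plus the partial of A(k)/B(k) w.r.t. the chosen input
        term = log_g_old_sign * log_t_old_sign * np.exp(log_g_old - log_t_old)
        if wrt == 0:
            term += 1 / (a + k)
        elif wrt == 1:
            term += 1 / (b + k)
        else:
            term -= 1 / (c + k)
        log_g_old = log_t_new + np.log(np.abs(term))
        log_g_old_sign = np.sign(term) * log_t_new_sign
        g_current = log_g_old_sign * np.exp(log_g_old) * sign_zk  # signed contribution of this term
        res += g_current
        log_t_old, log_t_old_sign = log_t_new, log_t_new_sign
        sign_zk *= sign_z
        if k >= min_steps and np.abs(g_current) <= precision:
            return res
    return np.nan  # did not converge within max_steps

For the few_iters_case used in the tests above, grad_2f1_reference(2.0, 1.0, 2.0, 0.4, wrt=0) returns approximately 0.4617734323582945, matching expected_dda1 in the Stan-derived table, which makes it a convenient cross-check against the compiled ScalarLoop graph.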