Skip to content

Implement scalar loop for iterative gradients #283

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
May 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions pytensor/graph/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -1115,12 +1115,12 @@ def truncated_graph_inputs(


def clone(
inputs: List[Variable],
outputs: List[Variable],
inputs: Sequence[Variable],
outputs: Sequence[Variable],
copy_inputs: bool = True,
copy_orphans: Optional[bool] = None,
clone_inner_graphs: bool = False,
) -> Tuple[Collection[Variable], Collection[Variable]]:
) -> Tuple[List[Variable], List[Variable]]:
r"""Copies the sub-graph contained between inputs and outputs.

Parameters
Expand Down
288 changes: 148 additions & 140 deletions pytensor/scalar/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3986,7 +3986,150 @@ def c_code(self, *args, **kwargs):
complex_from_polar = ComplexFromPolar(name="complex_from_polar")


class Composite(ScalarOp, HasInnerGraph):
class ScalarInnerGraphOp(ScalarOp, HasInnerGraph):
    """Includes boilerplate code for Python and C-implementation of Scalar Ops with inner graph.

    The methods below read several attributes that this class does not
    define itself: ``fgraph``, ``inputs``, ``outputs``, ``nin``, ``nout``,
    ``nodenames``, ``c_code_template``, ``perform`` and
    ``c_code_cache_version_outer``. They must be supplied by the concrete
    subclass (or another base class).
    """

    def __init__(self, *args, **kwargs):
        """Initialize the per-instance ``prepare_node`` bookkeeping.

        NOTE(review): ``args``/``kwargs`` are accepted but ignored, and no
        parent initializer is invoked from here -- confirm this is intended
        for every subclass that forwards arguments.
        """
        # Set of ``impl`` values for which `prepare_node` has already been
        # propagated to the inner nodes.
        self.prepare_node_called = set()

    @property
    def fn(self):
        """Always ``None`` for this class."""
        return None

    @property
    def inner_inputs(self):
        """Inputs of the inner graph (``self.fgraph.inputs``)."""
        return self.fgraph.inputs

    @property
    def inner_outputs(self):
        """Outputs of the inner graph (``self.fgraph.outputs``)."""
        return self.fgraph.outputs

    @property
    def py_perform_fn(self):
        """Python callable built from the inner graph, created lazily.

        The callable is produced by ``fgraph_to_python`` on first access and
        cached on the instance as ``_py_perform_fn``.
        """
        if hasattr(self, "_py_perform_fn"):
            return self._py_perform_fn

        # Imported lazily, as in the original code.
        from pytensor.link.utils import fgraph_to_python

        def python_convert(op, node=None, **kwargs):
            """Wrap ``op.perform`` for ``node`` into a plain function of its inputs."""
            assert node is not None

            n_outs = len(node.outputs)

            if n_outs > 1:

                def _perform(*inputs, outputs=None):
                    # Each output needs its own storage cell. A shared
                    # default such as ``[[None]] * n_outs`` would alias one
                    # cell ``n_outs`` times, so every output would end up
                    # holding the last value written.
                    if outputs is None:
                        outputs = [[None] for _ in range(n_outs)]
                    op.perform(node, inputs, outputs)
                    return tuple(o[0] for o in outputs)

            else:

                def _perform(*inputs, outputs=None):
                    # Fresh storage per call: avoids a mutable default that
                    # would keep the last result alive between calls.
                    if outputs is None:
                        outputs = [[None]]
                    op.perform(node, inputs, outputs)
                    return outputs[0][0]

            return _perform

        self._py_perform_fn = fgraph_to_python(self.fgraph, python_convert)
        return self._py_perform_fn

    def impl(self, *inputs):
        """Evaluate the Op on ``inputs`` by delegating to ``self.perform``.

        ``perform`` is called with ``node=None`` and one storage cell per
        output. The stored values are passed through ``to_return_values``;
        when ``self.nout > 1`` the result is converted to a tuple.
        """
        output_storage = [[None] for _ in range(self.nout)]
        self.perform(None, inputs, output_storage)
        ret = to_return_values([storage[0] for storage in output_storage])
        if self.nout > 1:
            ret = tuple(ret)
        return ret

    def c_code_cache_version(self):
        """Combine the outer version with the version of every inner node's Op.

        Returns ``()`` as soon as one inner Op reports an empty (falsy)
        version; otherwise a tuple starting with the entries of
        ``self.c_code_cache_version_outer()`` followed by each inner Op's
        version in ``toposort`` order.
        """
        rval = list(self.c_code_cache_version_outer())
        for x in self.fgraph.toposort():
            xv = x.op.c_code_cache_version()
            if xv:
                rval.append(xv)
            else:
                return ()
        return tuple(rval)

    def c_header_dirs(self, **kwargs):
        """Concatenate, in ``toposort`` order, the header dirs of all inner Ops."""
        rval = []
        # ``extend`` keeps this linear; ``sum(lists, [])`` re-copies the
        # accumulated list on every addition.
        for subnode in self.fgraph.toposort():
            rval.extend(subnode.op.c_header_dirs(**kwargs))
        return rval

    def c_support_code(self, **kwargs):
        """Return the de-duplicated, sorted support code of all inner Ops."""
        # Remove duplicate code blocks by using a `set`
        rval = {
            subnode.op.c_support_code(**kwargs).strip()
            for subnode in self.fgraph.toposort()
        }
        return "\n".join(sorted(rval))

    def c_support_code_apply(self, node, name):
        """Join the node-specific support code of every inner node.

        Each entry of ``self.nodenames`` is a ``%``-format template that is
        filled in with ``nodename=name`` before being handed to the inner Op.
        """
        rval = []
        for subnode, subnodename in zip(self.fgraph.toposort(), self.nodenames):
            subnode_support_code = subnode.op.c_support_code_apply(
                subnode, subnodename % dict(nodename=name)
            )
            if subnode_support_code:
                rval.append(subnode_support_code)
        # there should be no need to remove duplicate code blocks because
        # each block should have been specialized for the given nodename.
        # Any block that isn't specialized should be returned via
        # c_support_code instead of c_support_code_apply.
        return "\n".join(rval)

    def prepare_node(self, node, storage_map, compute_map, impl):
        """Propagate ``prepare_node`` to the inner nodes, once per ``impl``.

        The inner nodes are called with ``None`` for both the storage and
        compute maps; ``node``, ``storage_map`` and ``compute_map`` of the
        outer call are not used.
        """
        if impl not in self.prepare_node_called:
            for n in list_of_nodes(self.inputs, self.outputs):
                n.op.prepare_node(n, None, None, impl)
            self.prepare_node_called.add(impl)

    def __eq__(self, other):
        """Equal when type, ``nin``, ``nout`` and ``c_code_template`` all match."""
        if self is other:
            return True
        if (
            type(self) is not type(other)
            or self.nin != other.nin
            or self.nout != other.nout
        ):
            return False

        # TODO FIXME: Why this? Shouldn't we expect equivalent inputs to this
        # object to generate the same `_c_code`?
        return self.c_code_template == other.c_code_template

    def __hash__(self):
        """Hash over the same fields that `__eq__` compares."""
        # Note that in general, the configparser settings at the time
        # of code generation (__init__) affect the semantics of this Op.
        # This function assumes that all relevant info about the configparser
        # is embodied in _c_code. So the _c_code, rather than self.fgraph,
        # is the signature of the semantics of this Op.
        # _c_code is preserved through unpickling, so the Op will not change
        # semantics when it is reloaded with different configparser
        # settings.
        #
        # TODO FIXME: Doesn't the above just mean that we should be including
        # the relevant "configparser settings" here? Also, why should we even
        # care about the exact form of the generated C code when comparing
        # `Op`s? All this smells of leaky concerns and interfaces.
        return hash((type(self), self.nin, self.nout, self.c_code_template))

    def __getstate__(self):
        """Return a copy of ``__dict__`` without the lazily built caches.

        NOTE(review): ``_c_code`` is dropped here although the comment in
        `__hash__` says it is preserved through unpickling -- confirm which
        one is accurate.
        """
        rval = dict(self.__dict__)
        rval.pop("_c_code", None)
        rval.pop("_py_perform_fn", None)
        rval.pop("_fgraph", None)
        rval.pop("prepare_node_called", None)
        return rval

    def __setstate__(self, d):
        """Restore the instance dict and reset the ``prepare_node`` bookkeeping."""
        self.__dict__.update(d)
        self.prepare_node_called = set()


class Composite(ScalarInnerGraphOp):
"""
Composite is an Op that takes a graph of scalar operations and
produces c code for the whole graph. Its purpose is to implement loop
Expand Down Expand Up @@ -4043,19 +4186,7 @@ def __init__(self, inputs, outputs, name="Composite"):
self.outputs_type = tuple([output.type for output in outputs])
self.nin = len(inputs)
self.nout = len(outputs)
self.prepare_node_called = set()

@property
def fn(self):
return None

@property
def inner_inputs(self):
return self.fgraph.inputs

@property
def inner_outputs(self):
return self.fgraph.outputs
super().__init__()

def __str__(self):
return self.name
Expand All @@ -4076,35 +4207,6 @@ def make_new_inplace(self, output_types_preference=None, name=None):
super(Composite, out).__init__(output_types_preference, name)
return out

@property
def py_perform(self):
if hasattr(self, "_py_perform_fn"):
return self._py_perform_fn

from pytensor.link.utils import fgraph_to_python

def python_convert(op, node=None, **kwargs):
assert node is not None

n_outs = len(node.outputs)

if n_outs > 1:

def _perform(*inputs, outputs=[[None]] * n_outs):
op.perform(node, inputs, outputs)
return tuple(o[0] for o in outputs)

else:

def _perform(*inputs, outputs=[[None]]):
op.perform(node, inputs, outputs)
return outputs[0][0]

return _perform

self._py_perform_fn = fgraph_to_python(self.fgraph, python_convert)
return self._py_perform_fn

@property
def fgraph(self):
if hasattr(self, "_fgraph"):
Expand Down Expand Up @@ -4139,12 +4241,6 @@ def fgraph(self):
self._fgraph = fgraph
return self._fgraph

def prepare_node(self, node, storage_map, compute_map, impl):
if impl not in self.prepare_node_called:
for n in list_of_nodes(self.inputs, self.outputs):
n.op.prepare_node(n, None, None, impl)
self.prepare_node_called.add(impl)

def clone_float32(self):
# This will not modify the fgraph or the nodes
new_ins, new_outs = composite_f32.apply(self.fgraph)
Expand All @@ -4155,8 +4251,6 @@ def clone(self):
return Composite(new_ins, new_outs)

def output_types(self, input_types):
# TODO FIXME: What's the intended purpose/use of this method, and why
# does it even need to be a method?
if tuple(input_types) != self.inputs_type:
raise TypeError(
f"Wrong types for Composite. Expected {self.inputs_type}, got {tuple(input_types)}."
Expand All @@ -4183,63 +4277,13 @@ def make_node(self, *inputs):
return node

def perform(self, node, inputs, output_storage):
outputs = self.py_perform(*inputs)
outputs = self.py_perform_fn(*inputs)
for storage, out_val in zip(output_storage, outputs):
storage[0] = out_val

def impl(self, *inputs):
output_storage = [[None] for i in range(self.nout)]
self.perform(None, inputs, output_storage)
ret = to_return_values([storage[0] for storage in output_storage])
if self.nout > 1:
ret = tuple(ret)
return ret

def grad(self, inputs, output_grads):
raise NotImplementedError("grad is not implemented for Composite")

def __eq__(self, other):
if self is other:
return True
if (
type(self) != type(other)
or self.nin != other.nin
or self.nout != other.nout
):
return False

# TODO FIXME: Why this? Shouldn't we expect equivalent inputs to this
# object to generate the same `_c_code`?
return self.c_code_template == other.c_code_template

def __hash__(self):
# Note that in general, the configparser settings at the time
# of code generation (__init__) affect the semantics of this Op.
# This function assumes that all relevant info about the configparser
# is embodied in _c_code. So the _c_code, rather than self.fgraph,
# is the signature of the semantics of this Op.
# _c_code is preserved through unpickling, so the Op will not change
# semantics when it is reloaded with different configparser
# settings.
#
# TODO FIXME: Doesn't the above just mean that we should be including
# the relevant "configparser settings" here? Also, why should we even
# care about the exact form of the generated C code when comparing
# `Op`s? All this smells of leaky concerns and interfaces.
return hash((type(self), self.nin, self.nout, self.c_code_template))

def __getstate__(self):
rval = dict(self.__dict__)
rval.pop("_c_code", None)
rval.pop("_py_perform_fn", None)
rval.pop("_fgraph", None)
rval.pop("prepare_node_called", None)
return rval

def __setstate__(self, d):
self.__dict__.update(d)
self.prepare_node_called = set()

@property
def c_code_template(self):
from pytensor.link.c.interface import CLinkerType
Expand Down Expand Up @@ -4317,44 +4361,8 @@ def c_code(self, node, nodename, inames, onames, sub):

return self.c_code_template % d

def c_code_cache_version(self):
rval = [3]
for x in self.fgraph.toposort():
xv = x.op.c_code_cache_version()
if xv:
rval.append(xv)
else:
return ()
return tuple(rval)

def c_header_dirs(self, **kwargs):
rval = sum(
(subnode.op.c_header_dirs(**kwargs) for subnode in self.fgraph.toposort()),
[],
)
return rval

def c_support_code(self, **kwargs):
# Remove duplicate code blocks by using a `set`
rval = {
subnode.op.c_support_code(**kwargs).strip()
for subnode in self.fgraph.toposort()
}
return "\n".join(sorted(rval))

def c_support_code_apply(self, node, name):
rval = []
for subnode, subnodename in zip(self.fgraph.toposort(), self.nodenames):
subnode_support_code = subnode.op.c_support_code_apply(
subnode, subnodename % dict(nodename=name)
)
if subnode_support_code:
rval.append(subnode_support_code)
# there should be no need to remove duplicate code blocks because
# each block should have been specialized for the given nodename.
# Any block that isn't specialized should be returned via
# c_support_code instead of c_support_code_apply.
return "\n".join(rval)
def c_code_cache_version_outer(self) -> Tuple[int, ...]:
return (3,)


class Compositef32:
Expand Down
Loading