Commit 5c26a6c

Refactor baseclass ScalarInnerGraphOp from Composite Op
1 parent 37fb461 commit 5c26a6c
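
This commit moves the shared inner-graph boilerplate (fn, inner_inputs/inner_outputs, py_perform_fn, impl, prepare_node, __eq__/__hash__, the pickling helpers, and the generic C header/support-code plumbing) out of Composite into a new ScalarInnerGraphOp base class; Composite now subclasses it and supplies only its Composite-specific pieces plus a c_code_cache_version_outer() hook. A minimal sketch of the subclassing pattern this enables (the class MyInnerGraphOp below is hypothetical and not part of the commit; a real subclass would also have to provide an fgraph and a c_code_template):

    from typing import Tuple

    from pytensor.scalar.basic import ScalarInnerGraphOp


    class MyInnerGraphOp(ScalarInnerGraphOp):
        # Hypothetical subclass: only the subclass-specific pieces live here.
        # py_perform_fn, impl, prepare_node, __eq__/__hash__, pickling, and the
        # generic C support-code/header plumbing are inherited from the base class.

        def __init__(self, inputs, outputs):
            self.inputs, self.outputs = inputs, outputs
            self.nin, self.nout = len(inputs), len(outputs)
            super().__init__()  # sets up prepare_node_called

        def c_code_cache_version_outer(self) -> Tuple[int, ...]:
            # Version of this subclass's own C code generation; the base class
            # appends the cache versions of the inner-graph ops to this tuple.
            return (1,)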

File tree

1 file changed: +149 / -141 lines changed


pytensor/scalar/basic.py

Lines changed: 149 additions & 141 deletions
@@ -3986,7 +3986,150 @@ def c_code(self, *args, **kwargs):
 complex_from_polar = ComplexFromPolar(name="complex_from_polar")
 
 
-class Composite(ScalarOp, HasInnerGraph):
+class ScalarInnerGraphOp(ScalarOp, HasInnerGraph):
+    """Includes boilerplate code for Python and C-implementation of Scalar Ops with inner graph."""
+
+    def __init__(self, *args, **kwargs):
+        self.prepare_node_called = set()
+
+    @property
+    def fn(self):
+        return None
+
+    @property
+    def inner_inputs(self):
+        return self.fgraph.inputs
+
+    @property
+    def inner_outputs(self):
+        return self.fgraph.outputs
+
+    @property
+    def py_perform_fn(self):
+        if hasattr(self, "_py_perform_fn"):
+            return self._py_perform_fn
+
+        from pytensor.link.utils import fgraph_to_python
+
+        def python_convert(op, node=None, **kwargs):
+            assert node is not None
+
+            n_outs = len(node.outputs)
+
+            if n_outs > 1:
+
+                def _perform(*inputs, outputs=[[None]] * n_outs):
+                    op.perform(node, inputs, outputs)
+                    return tuple(o[0] for o in outputs)
+
+            else:
+
+                def _perform(*inputs, outputs=[[None]]):
+                    op.perform(node, inputs, outputs)
+                    return outputs[0][0]
+
+            return _perform
+
+        self._py_perform_fn = fgraph_to_python(self.fgraph, python_convert)
+        return self._py_perform_fn
+
+    def impl(self, *inputs):
+        output_storage = [[None] for i in range(self.nout)]
+        self.perform(None, inputs, output_storage)
+        ret = to_return_values([storage[0] for storage in output_storage])
+        if self.nout > 1:
+            ret = tuple(ret)
+        return ret
+
+    def c_code_cache_version(self):
+        rval = list(self.c_code_cache_version_outer())
+        for x in self.fgraph.toposort():
+            xv = x.op.c_code_cache_version()
+            if xv:
+                rval.append(xv)
+            else:
+                return ()
+        return tuple(rval)
+
+    def c_header_dirs(self, **kwargs):
+        rval = sum(
+            (subnode.op.c_header_dirs(**kwargs) for subnode in self.fgraph.toposort()),
+            [],
+        )
+        return rval
+
+    def c_support_code(self, **kwargs):
+        # Remove duplicate code blocks by using a `set`
+        rval = {
+            subnode.op.c_support_code(**kwargs).strip()
+            for subnode in self.fgraph.toposort()
+        }
+        return "\n".join(sorted(rval))
+
+    def c_support_code_apply(self, node, name):
+        rval = []
+        for subnode, subnodename in zip(self.fgraph.toposort(), self.nodenames):
+            subnode_support_code = subnode.op.c_support_code_apply(
+                subnode, subnodename % dict(nodename=name)
+            )
+            if subnode_support_code:
+                rval.append(subnode_support_code)
+        # there should be no need to remove duplicate code blocks because
+        # each block should have been specialized for the given nodename.
+        # Any block that isn't specialized should be returned via
+        # c_support_code instead of c_support_code_apply.
+        return "\n".join(rval)
+
+    def prepare_node(self, node, storage_map, compute_map, impl):
+        if impl not in self.prepare_node_called:
+            for n in list_of_nodes(self.inputs, self.outputs):
+                n.op.prepare_node(n, None, None, impl)
+            self.prepare_node_called.add(impl)
+
+    def __eq__(self, other):
+        if self is other:
+            return True
+        if (
+            type(self) != type(other)
+            or self.nin != other.nin
+            or self.nout != other.nout
+        ):
+            return False
+
+        # TODO FIXME: Why this? Shouldn't we expect equivalent inputs to this
+        # object to generate the same `_c_code`?
+        return self.c_code_template == other.c_code_template
+
+    def __hash__(self):
+        # Note that in general, the configparser settings at the time
+        # of code generation (__init__) affect the semantics of this Op.
+        # This function assumes that all relevant info about the configparser
+        # is embodied in _c_code. So the _c_code, rather than self.fgraph,
+        # is the signature of the semantics of this Op.
+        # _c_code is preserved through unpickling, so the Op will not change
+        # semantics when it is reloaded with different configparser
+        # settings.
+        #
+        # TODO FIXME: Doesn't the above just mean that we should be including
+        # the relevant "configparser settings" here? Also, why should we even
+        # care about the exact form of the generated C code when comparing
+        # `Op`s? All this smells of leaky concerns and interfaces.
+        return hash((type(self), self.nin, self.nout, self.c_code_template))
+
+    def __getstate__(self):
+        rval = dict(self.__dict__)
+        rval.pop("_c_code", None)
+        rval.pop("_py_perform_fn", None)
+        rval.pop("_fgraph", None)
+        rval.pop("prepare_node_called", None)
+        return rval
+
+    def __setstate__(self, d):
+        self.__dict__.update(d)
+        self.prepare_node_called = set()
+
+
+class Composite(ScalarInnerGraphOp):
     """
     Composite is an Op that takes a graph of scalar operations and
     produces c code for the whole graph. Its purpose is to implement loop
@@ -4001,7 +4144,7 @@ class Composite(ScalarOp, HasInnerGraph):
     def __init__(self, inputs, outputs, name="Composite"):
         self.name = name
         # We need to clone the graph as sometimes its nodes already
-        # contain a reference to an fgraph. As we want the Composite
+        # contain a reference to a fgraph. As we want the Composite
         # to be pickable, we can't have reference to fgraph.
 
         # Also, if there is Composite in the inner graph, we want to
@@ -4043,19 +4186,7 @@ def __init__(self, inputs, outputs, name="Composite"):
         self.outputs_type = tuple([output.type for output in outputs])
         self.nin = len(inputs)
         self.nout = len(outputs)
-        self.prepare_node_called = set()
-
-    @property
-    def fn(self):
-        return None
-
-    @property
-    def inner_inputs(self):
-        return self.fgraph.inputs
-
-    @property
-    def inner_outputs(self):
-        return self.fgraph.outputs
+        super().__init__()
 
     def __str__(self):
         return self.name
@@ -4076,35 +4207,6 @@ def make_new_inplace(self, output_types_preference=None, name=None):
         super(Composite, out).__init__(output_types_preference, name)
         return out
 
-    @property
-    def py_perform(self):
-        if hasattr(self, "_py_perform_fn"):
-            return self._py_perform_fn
-
-        from pytensor.link.utils import fgraph_to_python
-
-        def python_convert(op, node=None, **kwargs):
-            assert node is not None
-
-            n_outs = len(node.outputs)
-
-            if n_outs > 1:
-
-                def _perform(*inputs, outputs=[[None]] * n_outs):
-                    op.perform(node, inputs, outputs)
-                    return tuple(o[0] for o in outputs)
-
-            else:
-
-                def _perform(*inputs, outputs=[[None]]):
-                    op.perform(node, inputs, outputs)
-                    return outputs[0][0]
-
-            return _perform
-
-        self._py_perform_fn = fgraph_to_python(self.fgraph, python_convert)
-        return self._py_perform_fn
-
     @property
     def fgraph(self):
         if hasattr(self, "_fgraph"):
@@ -4139,12 +4241,6 @@ def fgraph(self):
         self._fgraph = fgraph
         return self._fgraph
 
-    def prepare_node(self, node, storage_map, compute_map, impl):
-        if impl not in self.prepare_node_called:
-            for n in list_of_nodes(self.inputs, self.outputs):
-                n.op.prepare_node(n, None, None, impl)
-            self.prepare_node_called.add(impl)
-
     def clone_float32(self):
         # This will not modify the fgraph or the nodes
         new_ins, new_outs = composite_f32.apply(self.fgraph)
@@ -4155,8 +4251,6 @@ def clone(self):
         return Composite(new_ins, new_outs)
 
     def output_types(self, input_types):
-        # TODO FIXME: What's the intended purpose/use of this method, and why
-        # does it even need to be a method?
         if tuple(input_types) != self.inputs_type:
             raise TypeError(
                 f"Wrong types for Composite. Expected {self.inputs_type}, got {tuple(input_types)}."
@@ -4183,63 +4277,13 @@ def make_node(self, *inputs):
         return node
 
     def perform(self, node, inputs, output_storage):
-        outputs = self.py_perform(*inputs)
+        outputs = self.py_perform_fn(*inputs)
         for storage, out_val in zip(output_storage, outputs):
             storage[0] = out_val
 
-    def impl(self, *inputs):
-        output_storage = [[None] for i in range(self.nout)]
-        self.perform(None, inputs, output_storage)
-        ret = to_return_values([storage[0] for storage in output_storage])
-        if self.nout > 1:
-            ret = tuple(ret)
-        return ret
-
     def grad(self, inputs, output_grads):
         raise NotImplementedError("grad is not implemented for Composite")
 
-    def __eq__(self, other):
-        if self is other:
-            return True
-        if (
-            type(self) != type(other)
-            or self.nin != other.nin
-            or self.nout != other.nout
-        ):
-            return False
-
-        # TODO FIXME: Why this? Shouldn't we expect equivalent inputs to this
-        # object to generate the same `_c_code`?
-        return self.c_code_template == other.c_code_template
-
-    def __hash__(self):
-        # Note that in general, the configparser settings at the time
-        # of code generation (__init__) affect the semantics of this Op.
-        # This function assumes that all relevant info about the configparser
-        # is embodied in _c_code. So the _c_code, rather than self.fgraph,
-        # is the signature of the semantics of this Op.
-        # _c_code is preserved through unpickling, so the Op will not change
-        # semantics when it is reloaded with different configparser
-        # settings.
-        #
-        # TODO FIXME: Doesn't the above just mean that we should be including
-        # the relevant "configparser settings" here? Also, why should we even
-        # care about the exact form of the generated C code when comparing
-        # `Op`s? All this smells of leaky concerns and interfaces.
-        return hash((type(self), self.nin, self.nout, self.c_code_template))
-
-    def __getstate__(self):
-        rval = dict(self.__dict__)
-        rval.pop("_c_code", None)
-        rval.pop("_py_perform_fn", None)
-        rval.pop("_fgraph", None)
-        rval.pop("prepare_node_called", None)
-        return rval
-
-    def __setstate__(self, d):
-        self.__dict__.update(d)
-        self.prepare_node_called = set()
-
     @property
     def c_code_template(self):
         from pytensor.link.c.interface import CLinkerType
@@ -4317,44 +4361,8 @@ def c_code(self, node, nodename, inames, onames, sub):
 
         return self.c_code_template % d
 
-    def c_code_cache_version(self):
-        rval = [3]
-        for x in self.fgraph.toposort():
-            xv = x.op.c_code_cache_version()
-            if xv:
-                rval.append(xv)
-            else:
-                return ()
-        return tuple(rval)
-
-    def c_header_dirs(self, **kwargs):
-        rval = sum(
-            (subnode.op.c_header_dirs(**kwargs) for subnode in self.fgraph.toposort()),
-            [],
-        )
-        return rval
-
-    def c_support_code(self, **kwargs):
-        # Remove duplicate code blocks by using a `set`
-        rval = {
-            subnode.op.c_support_code(**kwargs).strip()
-            for subnode in self.fgraph.toposort()
-        }
-        return "\n".join(sorted(rval))
-
-    def c_support_code_apply(self, node, name):
-        rval = []
-        for subnode, subnodename in zip(self.fgraph.toposort(), self.nodenames):
-            subnode_support_code = subnode.op.c_support_code_apply(
-                subnode, subnodename % dict(nodename=name)
-            )
-            if subnode_support_code:
-                rval.append(subnode_support_code)
-        # there should be no need to remove duplicate code blocks because
-        # each block should have been specialized for the given nodename.
-        # Any block that isn't specialized should be returned via
-        # c_support_code instead of c_support_code_apply.
-        return "\n".join(rval)
+    def c_code_cache_version_outer(self) -> Tuple[int, ...]:
+        return (3,)
 
 
 class Compositef32:
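
For users of Composite nothing changes at the call site; a quick check along these lines (a sketch, assuming a PyTensor build that includes this commit) exercises the inherited behaviour:

    import pytensor.scalar as ps
    from pytensor.scalar.basic import Composite, ScalarInnerGraphOp

    x, y = ps.float64("x"), ps.float64("y")
    comp = Composite([x, y], [x + y, x * y])  # two-output inner graph

    print(isinstance(comp, ScalarInnerGraphOp))  # True: Composite now subclasses the new base
    print(comp.impl(2.0, 3.0))                   # the two outputs, 5.0 and 6.0; impl() is inherited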
