@@ -3986,7 +3986,150 @@ def c_code(self, *args, **kwargs):
 complex_from_polar = ComplexFromPolar(name="complex_from_polar")
 
 
-class Composite(ScalarOp, HasInnerGraph):
+class ScalarInnerGraphOp(ScalarOp, HasInnerGraph):
+    """Includes boilerplate code for Python and C implementations of scalar `Op`s with an inner graph."""
+
+    def __init__(self, *args, **kwargs):
+        self.prepare_node_called = set()
+
+    @property
+    def fn(self):
+        return None
+
+    @property
+    def inner_inputs(self):
+        return self.fgraph.inputs
+
+    @property
+    def inner_outputs(self):
+        return self.fgraph.outputs
+
+    @property
+    def py_perform_fn(self):
+        if hasattr(self, "_py_perform_fn"):
+            return self._py_perform_fn
+
+        from pytensor.link.utils import fgraph_to_python
+
+        def python_convert(op, node=None, **kwargs):
+            assert node is not None
+
+            n_outs = len(node.outputs)
+
+            if n_outs > 1:
+
+                def _perform(*inputs, outputs=[[None]] * n_outs):
+                    op.perform(node, inputs, outputs)
+                    return tuple(o[0] for o in outputs)
+
+            else:
+
+                def _perform(*inputs, outputs=[[None]]):
+                    op.perform(node, inputs, outputs)
+                    return outputs[0][0]
+
+            return _perform
+
+        self._py_perform_fn = fgraph_to_python(self.fgraph, python_convert)
+        return self._py_perform_fn
+
+    def impl(self, *inputs):
+        output_storage = [[None] for i in range(self.nout)]
+        self.perform(None, inputs, output_storage)
+        ret = to_return_values([storage[0] for storage in output_storage])
+        if self.nout > 1:
+            ret = tuple(ret)
+        return ret
+
+    def c_code_cache_version(self):
+        rval = list(self.c_code_cache_version_outer())
+        for x in self.fgraph.toposort():
+            xv = x.op.c_code_cache_version()
+            if xv:
+                rval.append(xv)
+            else:
+                return ()
+        return tuple(rval)
+
+    def c_header_dirs(self, **kwargs):
+        rval = sum(
+            (subnode.op.c_header_dirs(**kwargs) for subnode in self.fgraph.toposort()),
+            [],
+        )
+        return rval
+
+    def c_support_code(self, **kwargs):
+        # Remove duplicate code blocks by using a `set`
+        rval = {
+            subnode.op.c_support_code(**kwargs).strip()
+            for subnode in self.fgraph.toposort()
+        }
+        return "\n".join(sorted(rval))
+
+    def c_support_code_apply(self, node, name):
+        rval = []
+        for subnode, subnodename in zip(self.fgraph.toposort(), self.nodenames):
+            subnode_support_code = subnode.op.c_support_code_apply(
+                subnode, subnodename % dict(nodename=name)
+            )
+            if subnode_support_code:
+                rval.append(subnode_support_code)
+        # There should be no need to remove duplicate code blocks, because
+        # each block should have been specialized for the given nodename.
+        # Any block that isn't specialized should be returned via
+        # c_support_code instead of c_support_code_apply.
+        return "\n".join(rval)
+
+    def prepare_node(self, node, storage_map, compute_map, impl):
+        if impl not in self.prepare_node_called:
+            for n in list_of_nodes(self.inputs, self.outputs):
+                n.op.prepare_node(n, None, None, impl)
+            self.prepare_node_called.add(impl)
+
+    def __eq__(self, other):
+        if self is other:
+            return True
+        if (
+            type(self) != type(other)
+            or self.nin != other.nin
+            or self.nout != other.nout
+        ):
+            return False
+
+        # TODO FIXME: Why this? Shouldn't we expect equivalent inputs to this
+        # object to generate the same `_c_code`?
+        return self.c_code_template == other.c_code_template
+
+    def __hash__(self):
+        # Note that in general, the configparser settings at the time
+        # of code generation (__init__) affect the semantics of this Op.
+        # This function assumes that all relevant info about the configparser
+        # is embodied in _c_code. So the _c_code, rather than self.fgraph,
+        # is the signature of the semantics of this Op.
+        # _c_code is preserved through unpickling, so the Op will not change
+        # semantics when it is reloaded with different configparser
+        # settings.
+        #
+        # TODO FIXME: Doesn't the above just mean that we should be including
+        # the relevant "configparser settings" here? Also, why should we even
+        # care about the exact form of the generated C code when comparing
+        # `Op`s? All this smells of leaky concerns and interfaces.
+        return hash((type(self), self.nin, self.nout, self.c_code_template))
+
+    def __getstate__(self):
+        rval = dict(self.__dict__)
+        rval.pop("_c_code", None)
+        rval.pop("_py_perform_fn", None)
+        rval.pop("_fgraph", None)
+        rval.pop("prepare_node_called", None)
+        return rval
+
+    def __setstate__(self, d):
+        self.__dict__.update(d)
+        self.prepare_node_called = set()
+
+
+class Composite(ScalarInnerGraphOp):
     """
     Composite is an Op that takes a graph of scalar operations and
     produces c code for the whole graph. Its purpose is to implement loop
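
Note (editor's sketch, not part of the commit): `py_perform_fn` above adapts PyTensor's storage-based `perform` convention, in which an Op writes each result into a preallocated one-element list, into an ordinary value-returning function. A minimal stand-alone illustration, where `fake_perform` is a hypothetical stand-in for a scalar Op's `perform`:

def fake_perform(node, inputs, output_storage):
    # Hypothetical single-output scalar op: add the two inputs and
    # write the result into the preallocated slot.
    output_storage[0][0] = inputs[0] + inputs[1]

storage = [[None]]               # one one-element list per output
fake_perform(None, (2, 3), storage)
assert storage[0][0] == 5        # the wrapper's `_perform` reads this slot back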
@@ -4001,7 +4144,7 @@ class Composite(ScalarOp, HasInnerGraph):
     def __init__(self, inputs, outputs, name="Composite"):
         self.name = name
         # We need to clone the graph as sometimes its nodes already
-        # contain a reference to an fgraph. As we want the Composite
+        # contain a reference to a fgraph. As we want the Composite
         # to be pickable, we can't have reference to fgraph.
 
         # Also, if there is Composite in the inner graph, we want to
@@ -4043,19 +4186,7 @@ def __init__(self, inputs, outputs, name="Composite"):
         self.outputs_type = tuple([output.type for output in outputs])
         self.nin = len(inputs)
         self.nout = len(outputs)
-        self.prepare_node_called = set()
-
-    @property
-    def fn(self):
-        return None
-
-    @property
-    def inner_inputs(self):
-        return self.fgraph.inputs
-
-    @property
-    def inner_outputs(self):
-        return self.fgraph.outputs
+        super().__init__()
 
     def __str__(self):
         return self.name
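
Note (editor's sketch, assuming the public `pytensor.scalar` API; not part of the commit): the kind of scalar inner graph a `Composite` wraps, and the constructor signature shown in the hunk above:

import pytensor.scalar as ps

x = ps.float64("x")
y = ps.float64("y")
out = ps.add(ps.mul(x, y), x)        # inner scalar graph: x*y + x
comp = ps.Composite([x, y], [out])   # one Op fusing the whole graph
node = comp.make_node(x, y)          # usable like any other scalar Op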
@@ -4076,35 +4207,6 @@ def make_new_inplace(self, output_types_preference=None, name=None):
         super(Composite, out).__init__(output_types_preference, name)
         return out
 
-    @property
-    def py_perform(self):
-        if hasattr(self, "_py_perform_fn"):
-            return self._py_perform_fn
-
-        from pytensor.link.utils import fgraph_to_python
-
-        def python_convert(op, node=None, **kwargs):
-            assert node is not None
-
-            n_outs = len(node.outputs)
-
-            if n_outs > 1:
-
-                def _perform(*inputs, outputs=[[None]] * n_outs):
-                    op.perform(node, inputs, outputs)
-                    return tuple(o[0] for o in outputs)
-
-            else:
-
-                def _perform(*inputs, outputs=[[None]]):
-                    op.perform(node, inputs, outputs)
-                    return outputs[0][0]
-
-            return _perform
-
-        self._py_perform_fn = fgraph_to_python(self.fgraph, python_convert)
-        return self._py_perform_fn
-
     @property
     def fgraph(self):
         if hasattr(self, "_fgraph"):
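
Note (editor's sketch, plain Python; not part of the commit): the removed `py_perform` property, like `py_perform_fn` and `fgraph` in the new base class, uses the hasattr-guarded lazy-caching property pattern:

class Lazy:
    @property
    def expensive(self):
        if hasattr(self, "_expensive"):
            return self._expensive           # cached on later accesses
        self._expensive = sum(range(10**6))  # computed only once
        return self._expensive

obj = Lazy()
assert obj.expensive == obj.expensive        # second access hits the cache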
@@ -4139,12 +4241,6 @@ def fgraph(self):
         self._fgraph = fgraph
         return self._fgraph
 
-    def prepare_node(self, node, storage_map, compute_map, impl):
-        if impl not in self.prepare_node_called:
-            for n in list_of_nodes(self.inputs, self.outputs):
-                n.op.prepare_node(n, None, None, impl)
-            self.prepare_node_called.add(impl)
-
     def clone_float32(self):
         # This will not modify the fgraph or the nodes
         new_ins, new_outs = composite_f32.apply(self.fgraph)
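
Note (editor's sketch, plain Python; not part of the commit): the `prepare_node` guard moved to `ScalarInnerGraphOp` above makes inner-graph preparation run at most once per implementation key, via the `prepare_node_called` set:

class Preparer:
    def __init__(self):
        self.prepare_node_called = set()

    def prepare_node(self, impl):
        if impl not in self.prepare_node_called:
            print(f"preparing inner nodes for {impl!r}")  # runs once per impl
            self.prepare_node_called.add(impl)

p = Preparer()
p.prepare_node("c")   # prepares
p.prepare_node("c")   # no-op: already prepared for "c"
p.prepare_node("py")  # prepares again for a new impl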
@@ -4155,8 +4251,6 @@ def clone(self):
         return Composite(new_ins, new_outs)
 
     def output_types(self, input_types):
-        # TODO FIXME: What's the intended purpose/use of this method, and why
-        # does it even need to be a method?
         if tuple(input_types) != self.inputs_type:
             raise TypeError(
                 f"Wrong types for Composite. Expected {self.inputs_type}, got {tuple(input_types)}."
@@ -4183,63 +4277,13 @@ def make_node(self, *inputs):
         return node
 
     def perform(self, node, inputs, output_storage):
-        outputs = self.py_perform(*inputs)
+        outputs = self.py_perform_fn(*inputs)
         for storage, out_val in zip(output_storage, outputs):
             storage[0] = out_val
 
-    def impl(self, *inputs):
-        output_storage = [[None] for i in range(self.nout)]
-        self.perform(None, inputs, output_storage)
-        ret = to_return_values([storage[0] for storage in output_storage])
-        if self.nout > 1:
-            ret = tuple(ret)
-        return ret
-
     def grad(self, inputs, output_grads):
         raise NotImplementedError("grad is not implemented for Composite")
 
-    def __eq__(self, other):
-        if self is other:
-            return True
-        if (
-            type(self) != type(other)
-            or self.nin != other.nin
-            or self.nout != other.nout
-        ):
-            return False
-
-        # TODO FIXME: Why this? Shouldn't we expect equivalent inputs to this
-        # object to generate the same `_c_code`?
-        return self.c_code_template == other.c_code_template
-
-    def __hash__(self):
-        # Note that in general, the configparser settings at the time
-        # of code generation (__init__) affect the semantics of this Op.
-        # This function assumes that all relevant info about the configparser
-        # is embodied in _c_code. So the _c_code, rather than self.fgraph,
-        # is the signature of the semantics of this Op.
-        # _c_code is preserved through unpickling, so the Op will not change
-        # semantics when it is reloaded with different configparser
-        # settings.
-        #
-        # TODO FIXME: Doesn't the above just mean that we should be including
-        # the relevant "configparser settings" here? Also, why should we even
-        # care about the exact form of the generated C code when comparing
-        # `Op`s? All this smells of leaky concerns and interfaces.
-        return hash((type(self), self.nin, self.nout, self.c_code_template))
-
-    def __getstate__(self):
-        rval = dict(self.__dict__)
-        rval.pop("_c_code", None)
-        rval.pop("_py_perform_fn", None)
-        rval.pop("_fgraph", None)
-        rval.pop("prepare_node_called", None)
-        return rval
-
-    def __setstate__(self, d):
-        self.__dict__.update(d)
-        self.prepare_node_called = set()
-
     @property
     def c_code_template(self):
         from pytensor.link.c.interface import CLinkerType
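
Note (editor's sketch, plain Python, not PyTensor; not part of the commit): the `__getstate__`/`__setstate__` pair moved to the base class drops derived caches before pickling and rebuilds the bookkeeping set afterwards. A self-contained imitation of that pattern, with hypothetical names:

import pickle

class CachedThing:
    def __init__(self):
        self.name = "thing"                # real state that survives pickling
        self.prepare_node_called = set()   # bookkeeping, rebuilt on load
        self._fgraph = object()            # stand-in for an unpicklable cache

    def __getstate__(self):
        rval = dict(self.__dict__)
        rval.pop("_fgraph", None)          # drop caches, keep the essentials
        rval.pop("prepare_node_called", None)
        return rval

    def __setstate__(self, d):
        self.__dict__.update(d)
        self.prepare_node_called = set()   # rebuilt empty, as above

thing = pickle.loads(pickle.dumps(CachedThing()))
assert thing.name == "thing"
assert thing.prepare_node_called == set()
assert not hasattr(thing, "_fgraph")       # the cache is recomputed lazily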
@@ -4317,44 +4361,8 @@ def c_code(self, node, nodename, inames, onames, sub):
 
         return self.c_code_template % d
 
-    def c_code_cache_version(self):
-        rval = [3]
-        for x in self.fgraph.toposort():
-            xv = x.op.c_code_cache_version()
-            if xv:
-                rval.append(xv)
-            else:
-                return ()
-        return tuple(rval)
-
-    def c_header_dirs(self, **kwargs):
-        rval = sum(
-            (subnode.op.c_header_dirs(**kwargs) for subnode in self.fgraph.toposort()),
-            [],
-        )
-        return rval
-
-    def c_support_code(self, **kwargs):
-        # Remove duplicate code blocks by using a `set`
-        rval = {
-            subnode.op.c_support_code(**kwargs).strip()
-            for subnode in self.fgraph.toposort()
-        }
-        return "\n".join(sorted(rval))
-
-    def c_support_code_apply(self, node, name):
-        rval = []
-        for subnode, subnodename in zip(self.fgraph.toposort(), self.nodenames):
-            subnode_support_code = subnode.op.c_support_code_apply(
-                subnode, subnodename % dict(nodename=name)
-            )
-            if subnode_support_code:
-                rval.append(subnode_support_code)
-        # there should be no need to remove duplicate code blocks because
-        # each block should have been specialized for the given nodename.
-        # Any block that isn't specialized should be returned via
-        # c_support_code instead of c_support_code_apply.
-        return "\n".join(rval)
+    def c_code_cache_version_outer(self) -> Tuple[int, ...]:
+        return (3,)
 
 
 class Compositef32:
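
Note (editor's sketch of the logic now in `ScalarInnerGraphOp.c_code_cache_version`, which consumes the `(3,)` returned above; not part of the commit): the outer version tuple is extended with every inner op's version, and a single unversioned inner op disables caching entirely:

def combined_cache_version(outer, inner_versions):
    rval = list(outer)
    for xv in inner_versions:
        if xv:
            rval.append(xv)
        else:
            return ()   # one unversioned op invalidates the whole cache key
    return tuple(rval)

assert combined_cache_version((3,), [(1,), (2, 0)]) == (3, (1,), (2, 0))
assert combined_cache_version((3,), [(1,), ()]) == ()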