
Commit 35f0df9

ricardoV94 authored and twiecki committed

Make params exclusive to COp's

Also removes them from the signature of perform

1 parent 5f84027 · commit 35f0df9

File tree

20 files changed: +105 −132 lines


doc/extending/using_params.rst

Lines changed: 17 additions & 21 deletions
@@ -1,10 +1,10 @@
 .. _extending_op_params:
 
-===============
-Using Op params
-===============
+================
+Using COp params
+================
 
-The Op params is a facility to pass some runtime parameters to the
+The COp params is a facility to pass some runtime parameters to the
 code of an op without modifying it. It can enable a single instance
 of C code to serve different needs and therefore reduce compilation.
@@ -53,7 +53,7 @@ following methods will be used for the type:
 - :meth:`__hash__ <Type.__hash__>`
 - :meth:`values_eq <Type.values_eq>`
 
-Additionally if you want to use your params with C code, you need to extend `COp`
+Additionally, to use your params with C code, you need to extend `COp`
 and implement the following methods:
 
 - :meth:`c_declare <CLinkerType.c_declare>`
@@ -65,24 +65,24 @@ You can also define other convenience methods such as
 :meth:`c_headers <CLinkerType.c_headers>` if you need any special things.
 
 
-Registering the params with your Op
------------------------------------
+Registering the params with your COp
+------------------------------------
 
-To declare that your Op uses params you have to set the class
+To declare that your `COp` uses params you have to set the class
 attribute :attr:`params_type` to an instance of your params Type.
 
 .. note::
 
     If you want to have multiple parameters, PyTensor provides the convenient class
    :class:`pytensor.link.c.params_type.ParamsType` that allows to bundle many parameters into
-    one object that will be available in both Python (as a Python object) and C code (as a struct).
+    one object that will be available to the C code (as a struct).
 
 For example if we decide to use an int as the params the following
 would be appropriate:
 
 .. code-block:: python
 
-    class MyOp(Op):
+    class MyOp(COp):
         params_type = Generic()
 
 After that you need to define a :meth:`get_params` method on your
@@ -115,12 +115,7 @@ Having declared a params for your Op will affect the expected
 signature of :meth:`perform`. The new expected signature will have an
 extra parameter at the end which corresponds to the params object.
 
-.. warning::
-
-    If you do not account for this extra parameter, the code will fail
-    at runtime if it tries to run the python version.
-
-Also, for the C code, the `sub` dictionary will contain an extra entry
+The `sub` dictionary for `COp`\s with params will contain an extra entry
 `'params'` which will map to the variable name of the params object.
 This is true for all methods that receive a `sub` parameter, so this
 means that you can use your params in the :meth:`c_code <COp.c_code>`
@@ -131,7 +126,7 @@ A simple example
 ----------------
 
 This is a simple example which uses a params object to pass a value.
-This `Op` will multiply a scalar input by a fixed floating point value.
+This `COp` will multiply a scalar input by a fixed floating point value.
 
 Since the value in this case is a python float, we chose Generic as
 the params type.
@@ -156,9 +151,10 @@ the params type.
         inp = as_scalar(inp)
         return Apply(self, [inp], [inp.type()])
 
-    def perform(self, node, inputs, output_storage, params):
-        # Here params is a python float so this is ok
-        output_storage[0][0] = inputs[0] * params
+    def perform(self, node, inputs, output_storage):
+        # Because params is a python float we can use `self.mul` directly.
+        # If it's something fancier, call `self.params_type.filter(self.get_params(node))`
+        output_storage[0][0] = inputs[0] * self.mul
 
     def c_code(self, node, name, inputs, outputs, sub):
         return ("%(z)s = %(x)s * PyFloat_AsDouble(%(p)s);" %
@@ -174,7 +170,7 @@ weights.
 
 .. testcode::
 
-    from pytensor.graph.op import Op
+    from pytensor.link.c.op import COp
     from pytensor.link.c.type import Generic
     from pytensor.scalar import as_scalar
 
pytensor/graph/basic.py

Lines changed: 0 additions & 11 deletions
@@ -30,7 +30,6 @@
 from pytensor.configdefaults import config
 from pytensor.graph.utils import (
     MetaObject,
-    MethodNotDefined,
     Scratchpad,
     TestValueError,
     ValidatingScratchpad,
@@ -151,16 +150,6 @@ def __init__(
                 f"The 'outputs' argument to Apply must contain Variable instances with no owner, not {output}"
             )
 
-    def run_params(self):
-        """
-        Returns the params for the node, or NoParams if no params is set.
-
-        """
-        try:
-            return self.op.get_params(self)
-        except MethodNotDefined:
-            return NoParams
-
     def __getstate__(self):
         d = self.__dict__
         # ufunc don't pickle/unpickle well
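With `Apply.run_params` removed, call sites now query the op directly and guard for non-C ops, exactly as the linker changes further down do. A minimal sketch of the replacement pattern (the free-function name `_node_params` is hypothetical, for illustration only):

```python
from pytensor.graph.basic import NoParams
from pytensor.graph.utils import MethodNotDefined
from pytensor.link.c.interface import CLinkerOp


def _node_params(node):
    # Mirrors what callers now inline: params are exclusive to COps,
    # so any non-CLinkerOp node simply has none.
    if not isinstance(node.op, CLinkerOp):
        return NoParams
    try:
        return node.op.get_params(node)
    except MethodNotDefined:
        return NoParams
```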

pytensor/graph/op.py

Lines changed: 9 additions & 56 deletions
@@ -16,15 +16,13 @@
 
 import pytensor
 from pytensor.configdefaults import config
-from pytensor.graph.basic import Apply, NoParams, Variable
+from pytensor.graph.basic import Apply, Variable
 from pytensor.graph.utils import (
     MetaObject,
-    MethodNotDefined,
     TestValueError,
     add_tag_trace,
     get_variable_trace_string,
 )
-from pytensor.link.c.params_type import Params, ParamsType
 
 
 if TYPE_CHECKING:
@@ -37,10 +35,7 @@
 ComputeMapType = dict[Variable, list[bool]]
 InputStorageType = list[StorageCellType]
 OutputStorageType = list[StorageCellType]
-ParamsInputType = Optional[tuple[Any, ...]]
-PerformMethodType = Callable[
-    [Apply, list[Any], OutputStorageType, ParamsInputType], None
-]
+PerformMethodType = Callable[[Apply, list[Any], OutputStorageType], None]
 BasicThunkType = Callable[[], None]
 ThunkCallableType = Callable[
     [PerformMethodType, StorageMapType, ComputeMapType, Apply], None
@@ -202,7 +197,6 @@ class Op(MetaObject):
 
     itypes: Optional[Sequence["Type"]] = None
     otypes: Optional[Sequence["Type"]] = None
-    params_type: Optional[ParamsType] = None
 
     _output_type_depends_on_input_value = False
     """
@@ -426,7 +420,6 @@ def perform(
         node: Apply,
         inputs: Sequence[Any],
         output_storage: OutputStorageType,
-        params: ParamsInputType = None,
     ) -> None:
         """Calculate the function on the inputs and put the variables in the output storage.
 
@@ -442,8 +435,6 @@ def perform(
             these lists). Each sub-list corresponds to value of each
             `Variable` in :attr:`node.outputs`. The primary purpose of this method
             is to set the values of these sub-lists.
-        params
-            A tuple containing the values of each entry in :attr:`Op.__props__`.
 
         Notes
         -----
@@ -481,22 +472,6 @@ def do_constant_folding(self, fgraph: "FunctionGraph", node: Apply) -> bool:
         """
         return True
 
-    def get_params(self, node: Apply) -> Params:
-        """Try to get parameters for the `Op` when :attr:`Op.params_type` is set to a `ParamsType`."""
-        if isinstance(self.params_type, ParamsType):
-            wrapper = self.params_type
-            if not all(hasattr(self, field) for field in wrapper.fields):
-                # Let's print missing attributes for debugging.
-                not_found = tuple(
-                    field for field in wrapper.fields if not hasattr(self, field)
-                )
-                raise AttributeError(
-                    f"{type(self).__name__}: missing attributes {not_found} for ParamsType."
-                )
-            # ParamsType.get_params() will apply filtering to attributes.
-            return self.params_type.get_params(self)
-        raise MethodNotDefined("get_params")
-
     def prepare_node(
         self,
         node: Apply,
@@ -538,34 +513,12 @@ def make_py_thunk(
         else:
             p = node.op.perform
 
-        params = node.run_params()
-
-        if params is NoParams:
-            # default arguments are stored in the closure of `rval`
-            @is_thunk_type
-            def rval(
-                p=p, i=node_input_storage, o=node_output_storage, n=node, params=None
-            ):
-                r = p(n, [x[0] for x in i], o)
-                for o in node.outputs:
-                    compute_map[o][0] = True
-                return r
-
-        else:
-            params_val = node.params_type.filter(params)
-
-            @is_thunk_type
-            def rval(
-                p=p,
-                i=node_input_storage,
-                o=node_output_storage,
-                n=node,
-                params=params_val,
-            ):
-                r = p(n, [x[0] for x in i], o, params)
-                for o in node.outputs:
-                    compute_map[o][0] = True
-                return r
+        @is_thunk_type
+        def rval(p=p, i=node_input_storage, o=node_output_storage, n=node):
+            r = p(n, [x[0] for x in i], o)
+            for o in node.outputs:
+                compute_map[o][0] = True
+            return r
 
         rval.inputs = node_input_storage
         rval.outputs = node_output_storage
@@ -640,7 +593,7 @@ class _NoPythonOp(Op):
 
     """
 
-    def perform(self, node, inputs, output_storage, params=None):
+    def perform(self, node, inputs, output_storage):
         raise NotImplementedError("No Python implementation is provided by this Op.")
 
 
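The surviving `rval` relies on the idiom the deleted comment spelled out: default argument values are evaluated once, when the function is defined, so each thunk captures its own node and storage without needing a params branch. A standalone illustration of the idiom (toy names, not PyTensor API):

```python
def make_thunk(perform, node, input_storage, output_storage):
    # Defaults bind at definition time, so every thunk built in a loop
    # keeps references to its own node and storage cells.
    def thunk(p=perform, i=input_storage, o=output_storage, n=node):
        return p(n, [cell[0] for cell in i], o)

    return thunk
```

Without the defaults, thunks created in a loop would all close over the loop variable and see only the last node.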

pytensor/link/c/basic.py

Lines changed: 11 additions & 2 deletions
@@ -20,6 +20,7 @@
     io_toposort,
     vars_between,
 )
+from pytensor.graph.utils import MethodNotDefined
 from pytensor.link.basic import Container, Linker, LocalLinker, PerformLinker
 from pytensor.link.c.cmodule import (
     METH_VARARGS,
@@ -617,7 +618,12 @@ def fetch_variables(self):
         # that needs it
         self.node_params = dict()
         for node in self.node_order:
-            params = node.run_params()
+            if not isinstance(node.op, CLinkerOp):
+                continue
+            try:
+                params = node.op.get_params(node)
+            except MethodNotDefined:
+                params = NoParams
             if params is not NoParams:
                 # try to avoid creating more than one variable for the
                 # same params.
@@ -803,7 +809,10 @@ def code_gen(self):
 
             sub = dict(failure_var=failure_var)
 
-            params = node.run_params()
+            try:
+                params = op.get_params(node)
+            except MethodNotDefined:
+                params = NoParams
             if params is not NoParams:
                 params_var = symbol[self.node_params[params]]
 

pytensor/link/c/interface.py

Lines changed: 24 additions & 1 deletion
@@ -1,11 +1,16 @@
+import typing
 import warnings
 from abc import abstractmethod
-from typing import Callable
+from typing import Callable, Optional
 
 from pytensor.graph.basic import Apply, Constant
 from pytensor.graph.utils import MethodNotDefined
 
 
+if typing.TYPE_CHECKING:
+    from pytensor.link.c.params_type import Params, ParamsType
+
+
 class CLinkerObject:
     """Standard methods for an `Op` or `Type` used with the `CLinker`."""
 
@@ -172,6 +177,8 @@ def c_code_cache_version(self) -> tuple[int, ...]:
 class CLinkerOp(CLinkerObject):
     """Interface definition for `Op` subclasses compiled by `CLinker`."""
 
+    params_type: Optional["ParamsType"] = None
+
     @abstractmethod
     def c_code(
         self,
@@ -362,6 +369,22 @@ def c_cleanup_code_struct(self, node: Apply, name: str) -> str:
         """
         return ""
 
+    def get_params(self, node: Apply) -> "Params":
+        """Try to get parameters for the `Op` when :attr:`Op.params_type` is set to a `ParamsType`."""
+        if self.params_type is not None:
+            wrapper = self.params_type
+            if not all(hasattr(self, field) for field in wrapper.fields):
+                # Let's print missing attributes for debugging.
+                not_found = tuple(
+                    field for field in wrapper.fields if not hasattr(self, field)
+                )
+                raise AttributeError(
+                    f"{type(self).__name__}: missing attributes {not_found} for ParamsType."
+                )
+            # ParamsType.get_params() will apply filtering to attributes.
+            return self.params_type.get_params(self)
+        raise MethodNotDefined("get_params")
+
 
 class CLinkerType(CLinkerObject):
     r"""Interface specification for `Type`\s that can be arguments to a `CLinkerOp`.

pytensor/link/c/op.py

Lines changed: 2 additions & 2 deletions
@@ -664,7 +664,7 @@ class _NoPythonCOp(COp):
 
     """
 
-    def perform(self, node, inputs, output_storage, params=None):
+    def perform(self, node, inputs, output_storage):
         raise NotImplementedError("No Python implementation is provided by this COp.")
 
 
@@ -675,7 +675,7 @@ class _NoPythonExternalCOp(ExternalCOp):
 
     """
 
-    def perform(self, node, inputs, output_storage, params=None):
+    def perform(self, node, inputs, output_storage):
         raise NotImplementedError(
             "No Python implementation is provided by this ExternalCOp."
         )

pytensor/link/numba/dispatch/basic.py

Lines changed: 5 additions & 16 deletions
@@ -21,7 +21,7 @@
 from pytensor import config
 from pytensor.compile.builders import OpFromGraph
 from pytensor.compile.ops import DeepCopyOp
-from pytensor.graph.basic import Apply, NoParams
+from pytensor.graph.basic import Apply
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.type import Type
 from pytensor.ifelse import IfElse
@@ -383,22 +383,11 @@ def generate_fallback_impl(op, node=None, storage_map=None, **kwargs):
     ret_sig = get_numba_type(node.outputs[0].type)
 
     output_types = tuple(out.type for out in node.outputs)
-    params = node.run_params()
 
-    if params is not NoParams:
-        params_val = dict(node.params_type.filter(params))
-
-        def py_perform(inputs):
-            outputs = [[None] for i in range(n_outputs)]
-            op.perform(node, inputs, outputs, params_val)
-            return outputs
-
-    else:
-
-        def py_perform(inputs):
-            outputs = [[None] for i in range(n_outputs)]
-            op.perform(node, inputs, outputs)
-            return outputs
+    def py_perform(inputs):
+        outputs = [[None] for i in range(n_outputs)]
+        op.perform(node, inputs, outputs)
+        return outputs
 
 
     if n_outputs == 1:
pytensor/raise_op.py

Lines changed: 1 addition & 1 deletion
@@ -90,7 +90,7 @@ def make_node(self, value: Variable, *conds: Variable):
             [value.type()],
         )
 
-    def perform(self, node, inputs, outputs, params):
+    def perform(self, node, inputs, outputs):
         (out,) = outputs
         val, *conds = inputs
         out[0] = val
