Skip to content

Commit f1bc4a3

Browse files
committed
Refactor encoding helper
1 parent 164f24f commit f1bc4a3

File tree

2 files changed

+16
-13
lines changed

2 files changed

+16
-13
lines changed

pytensor/link/numba/dispatch/elemwise.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
import base64
2-
import pickle
31
from collections.abc import Callable
42
from functools import singledispatch
53
from numbers import Number
@@ -22,7 +20,10 @@
2220
numba_njit,
2321
use_optimized_cheap_pass,
2422
)
25-
from pytensor.link.numba.dispatch.vectorize_codegen import _vectorized
23+
from pytensor.link.numba.dispatch.vectorize_codegen import (
24+
_vectorized,
25+
encode_literals,
26+
)
2627
from pytensor.link.utils import compile_function_src, get_name_for_object
2728
from pytensor.scalar.basic import (
2829
AND,
@@ -482,19 +483,16 @@ def numba_funcify_Elemwise(op, node, **kwargs):
482483
op.scalar_op, node=scalar_node, parent_node=node, fastmath=flags, **kwargs
483484
)
484485

485-
ndim = node.outputs[0].ndim
486-
output_bc_patterns = tuple([(False,) * ndim for _ in node.outputs])
487-
input_bc_patterns = tuple([input_var.broadcastable for input_var in node.inputs])
488-
output_dtypes = tuple(variable.dtype for variable in node.outputs)
486+
input_bc_patterns = tuple([inp.type.broadcastable for inp in node.inputs])
487+
output_bc_patterns = tuple([out.type.broadcastable for out in node.inputs])
488+
output_dtypes = tuple(out.type.dtype for out in node.outputs)
489489
inplace_pattern = tuple(op.inplace_pattern.items())
490490

491491
# numba doesn't support nested literals right now...
492-
input_bc_patterns_enc = base64.encodebytes(pickle.dumps(input_bc_patterns)).decode()
493-
output_bc_patterns_enc = base64.encodebytes(
494-
pickle.dumps(output_bc_patterns)
495-
).decode()
496-
output_dtypes_enc = base64.encodebytes(pickle.dumps(output_dtypes)).decode()
497-
inplace_pattern_enc = base64.encodebytes(pickle.dumps(inplace_pattern)).decode()
492+
input_bc_patterns_enc = encode_literals(input_bc_patterns)
493+
output_bc_patterns_enc = encode_literals(output_bc_patterns)
494+
output_dtypes_enc = encode_literals(output_dtypes)
495+
inplace_pattern_enc = encode_literals(inplace_pattern)
498496

499497
def elemwise_wrapper(*inputs):
500498
return _vectorized(

pytensor/link/numba/dispatch/vectorize_codegen.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import base64
44
import pickle
5+
from collections.abc import Sequence
56
from typing import Any
67

78
import numba
@@ -13,6 +14,10 @@
1314
from numba.np import arrayobj
1415

1516

17+
def encode_literals(literals: Sequence) -> str:
18+
return base64.encodebytes(pickle.dumps(literals)).decode()
19+
20+
1621
_jit_options = {
1722
"fastmath": {
1823
"arcp", # Allow Reciprocal

0 commit comments

Comments
 (0)