|
1 |
| -import base64 |
2 |
| -import pickle |
3 | 1 | from collections.abc import Callable
|
4 | 2 | from functools import singledispatch
|
5 | 3 | from numbers import Number
|
|
22 | 20 | numba_njit,
|
23 | 21 | use_optimized_cheap_pass,
|
24 | 22 | )
|
25 |
| -from pytensor.link.numba.dispatch.vectorize_codegen import _vectorized |
| 23 | +from pytensor.link.numba.dispatch.vectorize_codegen import ( |
| 24 | + _vectorized, |
| 25 | + encode_literals, |
| 26 | +) |
26 | 27 | from pytensor.link.utils import compile_function_src, get_name_for_object
|
27 | 28 | from pytensor.scalar.basic import (
|
28 | 29 | AND,
|
@@ -482,19 +483,16 @@ def numba_funcify_Elemwise(op, node, **kwargs):
|
482 | 483 | op.scalar_op, node=scalar_node, parent_node=node, fastmath=flags, **kwargs
|
483 | 484 | )
|
484 | 485 |
|
485 |
| - ndim = node.outputs[0].ndim |
486 |
| - output_bc_patterns = tuple([(False,) * ndim for _ in node.outputs]) |
487 |
| - input_bc_patterns = tuple([input_var.broadcastable for input_var in node.inputs]) |
488 |
| - output_dtypes = tuple(variable.dtype for variable in node.outputs) |
| 486 | + input_bc_patterns = tuple([inp.type.broadcastable for inp in node.inputs]) |
| 487 | + output_bc_patterns = tuple([out.type.broadcastable for out in node.inputs]) |
| 488 | + output_dtypes = tuple(out.type.dtype for out in node.outputs) |
489 | 489 | inplace_pattern = tuple(op.inplace_pattern.items())
|
490 | 490 |
|
491 | 491 | # numba doesn't support nested literals right now...
|
492 |
| - input_bc_patterns_enc = base64.encodebytes(pickle.dumps(input_bc_patterns)).decode() |
493 |
| - output_bc_patterns_enc = base64.encodebytes( |
494 |
| - pickle.dumps(output_bc_patterns) |
495 |
| - ).decode() |
496 |
| - output_dtypes_enc = base64.encodebytes(pickle.dumps(output_dtypes)).decode() |
497 |
| - inplace_pattern_enc = base64.encodebytes(pickle.dumps(inplace_pattern)).decode() |
| 492 | + input_bc_patterns_enc = encode_literals(input_bc_patterns) |
| 493 | + output_bc_patterns_enc = encode_literals(output_bc_patterns) |
| 494 | + output_dtypes_enc = encode_literals(output_dtypes) |
| 495 | + inplace_pattern_enc = encode_literals(inplace_pattern) |
498 | 496 |
|
499 | 497 | def elemwise_wrapper(*inputs):
|
500 | 498 | return _vectorized(
|
|
0 commit comments