Skip to content

Commit 809c228

Browse files
brandonwillardricardoV94
authored andcommitted
Change type_conversion_fn to const_conversion_fn in fgraph_to_python
1 parent 27459b6 commit 809c228

File tree

3 files changed

+7
-7
lines changed

3 files changed

+7
-7
lines changed

pytensor/link/jax/dispatch/basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def jax_funcify_FunctionGraph(
4949
return fgraph_to_python(
5050
fgraph,
5151
jax_funcify,
52-
type_conversion_fn=jax_typify,
52+
const_conversion_fn=jax_typify,
5353
fgraph_name=fgraph_name,
5454
**kwargs,
5555
)

pytensor/link/numba/dispatch/basic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -442,7 +442,7 @@ def numba_funcify_FunctionGraph(
442442
return fgraph_to_python(
443443
fgraph,
444444
numba_funcify,
445-
type_conversion_fn=numba_typify,
445+
const_conversion_fn=numba_typify,
446446
fgraph_name=fgraph_name,
447447
**kwargs,
448448
)

pytensor/link/utils.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -678,7 +678,7 @@ def fgraph_to_python(
678678
fgraph: FunctionGraph,
679679
op_conversion_fn: Callable,
680680
*,
681-
type_conversion_fn: Callable = lambda x, **kwargs: x,
681+
const_conversion_fn: Callable = lambda x, **kwargs: x,
682682
order: Optional[List[Apply]] = None,
683683
storage_map: Optional["StorageMapType"] = None,
684684
fgraph_name: str = "fgraph_to_python",
@@ -698,8 +698,8 @@ def fgraph_to_python(
698698
A callable used to convert nodes inside `fgraph` based on their `Op`
699699
types. It must have the signature
700700
``(op: Op, node: Apply=None, storage_map: Dict[Variable, List[Optional[Any]]]=None, **kwargs)``.
701-
type_conversion_fn
702-
A callable used to convert the values in `storage_map`. It must have
701+
const_conversion_fn
702+
A callable used to convert the `Constant` values in `storage_map`. It must have
703703
the signature
704704
``(value: Optional[Any], variable: Variable=None, storage: List[Optional[Any]]=None, **kwargs)``.
705705
order
@@ -753,7 +753,7 @@ def fgraph_to_python(
753753
)
754754
if input_storage[0] is not None or isinstance(i, Constant):
755755
# Constants need to be assigned locally and referenced
756-
global_env[local_input_name] = type_conversion_fn(
756+
global_env[local_input_name] = const_conversion_fn(
757757
input_storage[0], variable=i, storage=input_storage, **kwargs
758758
)
759759
# TODO: We could attempt to use the storage arrays directly
@@ -776,7 +776,7 @@ def fgraph_to_python(
776776
output_storage = storage_map.setdefault(
777777
out, [None if not isinstance(out, Constant) else out.data]
778778
)
779-
global_env[local_output_name] = type_conversion_fn(
779+
global_env[local_output_name] = const_conversion_fn(
780780
output_storage[0],
781781
variable=out,
782782
storage=output_storage,

0 commit comments

Comments
 (0)