From 90880e248fe9294e44c27f33b883fc0608d6714d Mon Sep 17 00:00:00 2001 From: Ian Schweer Date: Mon, 20 Jan 2025 07:50:10 -0800 Subject: [PATCH 1/3] Allow function dispatch for constants --- pytensor/link/pytorch/dispatch/basic.py | 15 +++++++++++++-- pytensor/link/pytorch/linker.py | 1 + pytensor/link/utils.py | 18 +++++++++++++++++- 3 files changed, 31 insertions(+), 3 deletions(-) diff --git a/pytensor/link/pytorch/dispatch/basic.py b/pytensor/link/pytorch/dispatch/basic.py index 11e1d6c63a..5ec5a366d6 100644 --- a/pytensor/link/pytorch/dispatch/basic.py +++ b/pytensor/link/pytorch/dispatch/basic.py @@ -36,10 +36,12 @@ def pytorch_typify_tensor(data, dtype=None, **kwargs): @pytorch_typify.register(slice) @pytorch_typify.register(NoneType) -@pytorch_typify.register(np.number) def pytorch_typify_no_conversion_needed(data, **kwargs): return data +@pytorch_typify.register(np.number) +def pytorch_typify_extract(data, **kwargs): + return data.item() @singledispatch def pytorch_funcify(op, node=None, storage_map=None, **kwargs): @@ -57,11 +59,20 @@ def pytorch_funcify_FunctionGraph( conversion_func=pytorch_funcify, **kwargs, ): + def constants_wrapper(x, **kwargs): + x = pytorch_typify(x) + + @torch.compiler.assume_constant_result + def torch_assume_constant(arg=x): + return arg + + return torch_assume_constant + built_kwargs = {"conversion_func": conversion_func, **kwargs} return fgraph_to_python( fgraph, conversion_func, - type_conversion_fn=pytorch_typify, + type_conversion_fn=constants_wrapper, fgraph_name=fgraph_name, **built_kwargs, ) diff --git a/pytensor/link/pytorch/linker.py b/pytensor/link/pytorch/linker.py index d47aa43dda..4a5acd5b85 100644 --- a/pytensor/link/pytorch/linker.py +++ b/pytensor/link/pytorch/linker.py @@ -51,6 +51,7 @@ class wrapper: """ def __init__(self, fn, gen_functors): + self._fn = fn self.fn = torch.compile(fn) self.gen_functors = gen_functors.copy() diff --git a/pytensor/link/utils.py b/pytensor/link/utils.py index 9cbc3838dd..d02398f85b 100644 --- a/pytensor/link/utils.py +++ b/pytensor/link/utils.py @@ -749,9 +749,25 @@ def fgraph_to_python( ) if input_storage[0] is not None or isinstance(i, Constant): # Constants need to be assigned locally and referenced - global_env[local_input_name] = type_conversion_fn( + getter_or_value = type_conversion_fn( input_storage[0], variable=i, storage=input_storage, **kwargs ) + if callable(getter_or_value): + # we got passed a function, this could be used to indicate something + # to the backend. We'll embed it + new_output_name = unique_name(i) + getter_unique_name = unique_name(getter_or_value) + global_env[getter_unique_name] = getter_or_value + assign_str = ( + f"{new_output_name} = {getter_unique_name}()" + ) + body_assigns.append(assign_str) + node_input_names.append(new_output_name) + continue + else: + global_env[local_input_name] = type_conversion_fn( + input_storage[0], variable=i, storage=input_storage, **kwargs + ) # TODO: We could attempt to use the storage arrays directly # E.g. `local_input_name = f"{local_input_name}[0]"` node_input_names.append(local_input_name) From 96ec5314b43ead2c846d20010d69147c1cfed935 Mon Sep 17 00:00:00 2001 From: ischweer Date: Mon, 20 Jan 2025 19:03:16 -0800 Subject: [PATCH 2/3] Lint --- pytensor/link/pytorch/dispatch/basic.py | 2 ++ pytensor/link/pytorch/linker.py | 1 - pytensor/link/utils.py | 4 +--- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/pytensor/link/pytorch/dispatch/basic.py b/pytensor/link/pytorch/dispatch/basic.py index 5ec5a366d6..d0626b68ca 100644 --- a/pytensor/link/pytorch/dispatch/basic.py +++ b/pytensor/link/pytorch/dispatch/basic.py @@ -39,10 +39,12 @@ def pytorch_typify_tensor(data, dtype=None, **kwargs): def pytorch_typify_no_conversion_needed(data, **kwargs): return data + @pytorch_typify.register(np.number) def pytorch_typify_extract(data, **kwargs): return data.item() + @singledispatch def pytorch_funcify(op, node=None, storage_map=None, **kwargs): """Create a PyTorch compatible function from an PyTensor `Op`.""" diff --git a/pytensor/link/pytorch/linker.py b/pytensor/link/pytorch/linker.py index 4a5acd5b85..d47aa43dda 100644 --- a/pytensor/link/pytorch/linker.py +++ b/pytensor/link/pytorch/linker.py @@ -51,7 +51,6 @@ class wrapper: """ def __init__(self, fn, gen_functors): - self._fn = fn self.fn = torch.compile(fn) self.gen_functors = gen_functors.copy() diff --git a/pytensor/link/utils.py b/pytensor/link/utils.py index d02398f85b..6b4c2f20f2 100644 --- a/pytensor/link/utils.py +++ b/pytensor/link/utils.py @@ -758,9 +758,7 @@ def fgraph_to_python( new_output_name = unique_name(i) getter_unique_name = unique_name(getter_or_value) global_env[getter_unique_name] = getter_or_value - assign_str = ( - f"{new_output_name} = {getter_unique_name}()" - ) + assign_str = f"{new_output_name} = {getter_unique_name}()" body_assigns.append(assign_str) node_input_names.append(new_output_name) continue From dc908cb4486889ad5f17f3f66b9d6e90da6e2fa8 Mon Sep 17 00:00:00 2001 From: ischweer Date: Sun, 2 Feb 2025 18:04:07 -0800 Subject: [PATCH 3/3] Fix dangling reference --- pytensor/link/pytorch/dispatch/basic.py | 11 ++--------- pytensor/link/pytorch/linker.py | 16 +++++++++++++++- pytensor/link/utils.py | 4 ++-- 3 files changed, 19 insertions(+), 12 deletions(-) diff --git a/pytensor/link/pytorch/dispatch/basic.py b/pytensor/link/pytorch/dispatch/basic.py index d0626b68ca..99977a5915 100644 --- a/pytensor/link/pytorch/dispatch/basic.py +++ b/pytensor/link/pytorch/dispatch/basic.py @@ -61,20 +61,13 @@ def pytorch_funcify_FunctionGraph( conversion_func=pytorch_funcify, **kwargs, ): - def constants_wrapper(x, **kwargs): - x = pytorch_typify(x) - - @torch.compiler.assume_constant_result - def torch_assume_constant(arg=x): - return arg - - return torch_assume_constant + if "type_conversion_fn" not in kwargs: + kwargs["type_conversion_fn"] = pytorch_typify built_kwargs = {"conversion_func": conversion_func, **kwargs} return fgraph_to_python( fgraph, conversion_func, - type_conversion_fn=constants_wrapper, fgraph_name=fgraph_name, **built_kwargs, ) diff --git a/pytensor/link/pytorch/linker.py b/pytensor/link/pytorch/linker.py index d47aa43dda..8a6fc8b6f5 100644 --- a/pytensor/link/pytorch/linker.py +++ b/pytensor/link/pytorch/linker.py @@ -10,7 +10,9 @@ def __init__(self, *args, **kwargs): self.gen_functors = [] def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs): - from pytensor.link.pytorch.dispatch import pytorch_funcify + import torch + + from pytensor.link.pytorch.dispatch import pytorch_funcify, pytorch_typify # We want to have globally unique names # across the entire pytensor graph, not @@ -25,9 +27,21 @@ def conversion_func_register(*args, **kwargs): self.gen_functors.append((f"_{name}", functor)) return functor + def constants_wrapper(x, **kwargs): + x = pytorch_typify(x) + + @torch.compiler.assume_constant_result + def torch_assume_constant(arg=x): + return arg + + name = kwargs["unique_name"](torch_assume_constant) + self.gen_functors.append((f"_{name}", torch_assume_constant)) + return torch_assume_constant + built_kwargs = { "unique_name": generator, "conversion_func": conversion_func_register, + "type_conversion_fn": constants_wrapper, **kwargs, } return pytorch_funcify( diff --git a/pytensor/link/utils.py b/pytensor/link/utils.py index 6b4c2f20f2..142fefc04d 100644 --- a/pytensor/link/utils.py +++ b/pytensor/link/utils.py @@ -766,8 +766,8 @@ def fgraph_to_python( global_env[local_input_name] = type_conversion_fn( input_storage[0], variable=i, storage=input_storage, **kwargs ) - # TODO: We could attempt to use the storage arrays directly - # E.g. `local_input_name = f"{local_input_name}[0]"` + # TODO: We could attempt to use the storage arrays directly + # E.g. `local_input_name = f"{local_input_name}[0]"` node_input_names.append(local_input_name) node_output_names = [unique_name(v) for v in node.outputs]