From d93e6c25a743578634970896249a529a1f8b9255 Mon Sep 17 00:00:00 2001 From: Ian Schweer Date: Sun, 14 Jul 2024 21:29:18 -0700 Subject: [PATCH] WIP torch clone --- pytensor/link/pytorch/dispatch/__init__.py | 2 ++ pytensor/link/pytorch/dispatch/basic.py | 11 +++++++++++ tests/link/pytorch/test_basic.py | 5 +++++ 3 files changed, 18 insertions(+) diff --git a/pytensor/link/pytorch/dispatch/__init__.py b/pytensor/link/pytorch/dispatch/__init__.py index 143d6b1bcb..d9807fc706 100644 --- a/pytensor/link/pytorch/dispatch/__init__.py +++ b/pytensor/link/pytorch/dispatch/__init__.py @@ -6,4 +6,6 @@ import pytensor.link.pytorch.dispatch.elemwise import pytensor.link.pytorch.dispatch.extra_ops import pytensor.link.pytorch.dispatch.sort +import pytensor.link.pytorch.dispatch.basic # isort: on + diff --git a/pytensor/link/pytorch/dispatch/basic.py b/pytensor/link/pytorch/dispatch/basic.py index 37622a8294..5da13fe517 100644 --- a/pytensor/link/pytorch/dispatch/basic.py +++ b/pytensor/link/pytorch/dispatch/basic.py @@ -116,3 +116,14 @@ def eye(N, M, k): return zeros return eye + +def torch_safe_clone(x): + # Detach to prevent the autograd overhead from following + return torch.clone(x.detach()) + +@pytorch_funcify.register(DeepCopyOp) +def pytorch_funcify_DeepCopyOp(op, *kwargs): + def deepcopyop(x): + return torch_safe_clone(x) + + return deepcopyop \ No newline at end of file diff --git a/tests/link/pytorch/test_basic.py b/tests/link/pytorch/test_basic.py index 0ccb1c454f..f5b59476be 100644 --- a/tests/link/pytorch/test_basic.py +++ b/tests/link/pytorch/test_basic.py @@ -7,6 +7,7 @@ import pytensor.tensor.basic as ptb from pytensor.compile.function import function from pytensor.compile.mode import get_mode +from pytensor.compile.ops import DeepCopyOp from pytensor.compile.sharedvalue import SharedVariable, shared from pytensor.configdefaults import config from pytensor.graph.basic import Apply @@ -294,3 +295,7 @@ def test_eye(dtype): for _M in range(1, 6): for _k in list(range(_M + 2)) + [-x for x in range(1, _N + 2)]: np.testing.assert_array_equal(fn(_N, _M, _k), np.eye(_N, _M, _k)) + +@pytest.mark.xfail(raises=NotImplementedError) +def test_deepcopy(): + assert False