Skip to content

WIP torch clone #937

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pytensor/link/pytorch/dispatch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,6 @@
import pytensor.link.pytorch.dispatch.elemwise
import pytensor.link.pytorch.dispatch.extra_ops
import pytensor.link.pytorch.dispatch.sort
import pytensor.link.pytorch.dispatch.basic
# isort: on

11 changes: 11 additions & 0 deletions pytensor/link/pytorch/dispatch/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,3 +116,14 @@ def eye(N, M, k):
return zeros

return eye

def torch_safe_clone(x):
# Detach to prevent the autograd overhead from following
return torch.clone(x.detach())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We are not handling this detach stuff (nor requires grad anywhere) because we're not planning to use torch autodiff machinery. Also not sure how this works in compile mode (which is what we're targetting for now)?

Copy link
Contributor Author

@Ch0ronomato Ch0ronomato Jul 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yup, i think it makes sense to avoid if y'all aren't planning to use it. The only reason I do it here is because we don't have to have any of the overhead tracking following us around.

https://discuss.pytorch.org/t/when-to-use-detach/98147

Anytime we declare any kind of parameter (I think that includes most of the basic tensor ops) you get a bit of gradient info. In practice I'm not sure how much it uses extra, I would need to measure it, but was just following what I've learned from my time with torch, ha.

Also not sure how this works in compile mode

I don't think you need grad info for compile mode (unless of course we want the backwards pass!) - did I miss those requirements somewhere? I'm not an expert though; source:

>>> import torch
>>> def foo():
...     return torch.rand(5).sin()
... 
>>> foo()
tensor([0.3314, 0.8000, 0.0395, 0.5884, 0.7805])
>>> torch.compile(foo)
<function foo at 0x138ae50d0>
>>> def foo2():
...     with torch.no_grad():
...             return torch.rand(5).sin()
... 
>>> torch.compile(foo2)
<function foo2 at 0x138ae5940>
>>> torch.compile(foo2)()
tensor([0.7675, 0.1882, 0.6245, 0.4960, 0.4802])


@pytorch_funcify.register(DeepCopyOp)
def pytorch_funcify_DeepCopyOp(op, *kwargs):
def deepcopyop(x):
return torch_safe_clone(x)

return deepcopyop
5 changes: 5 additions & 0 deletions tests/link/pytorch/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pytensor.tensor.basic as ptb
from pytensor.compile.function import function
from pytensor.compile.mode import get_mode
from pytensor.compile.ops import DeepCopyOp
from pytensor.compile.sharedvalue import SharedVariable, shared
from pytensor.configdefaults import config
from pytensor.graph.basic import Apply
Expand Down Expand Up @@ -294,3 +295,7 @@ def test_eye(dtype):
for _M in range(1, 6):
for _k in list(range(_M + 2)) + [-x for x in range(1, _N + 2)]:
np.testing.assert_array_equal(fn(_N, _M, _k), np.eye(_N, _M, _k))

@pytest.mark.xfail(raises=NotImplementedError)
def test_deepcopy():
assert False