-
Notifications
You must be signed in to change notification settings - Fork 132
WIP torch clone #937
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
WIP torch clone #937
Conversation
Is it better to have this open, or in draft? |
If it's ready for review and merging according to you it should be open and marked as ready for review (default), otherwise marked as a draft |
|
||
def torch_safe_clone(x): | ||
# Detach to prevent the autograd overhead from following | ||
return torch.clone(x.detach()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We are not handling this detach stuff (nor requires grad anywhere) because we're not planning to use torch autodiff machinery. Also not sure how this works in compile mode (which is what we're targetting for now)?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yup, i think it makes sense to avoid if y'all aren't planning to use it. The only reason I do it here is because we don't have to have any of the overhead tracking following us around.
https://discuss.pytorch.org/t/when-to-use-detach/98147
Anytime we declare any kind of parameter (I think that includes most of the basic tensor ops) you get a bit of gradient info. In practice I'm not sure how much it uses extra, I would need to measure it, but was just following what I've learned from my time with torch, ha.
Also not sure how this works in compile mode
I don't think you need grad info for compile mode (unless of course we want the backwards pass!) - did I miss those requirements somewhere? I'm not an expert though; source:
- https://dev-discuss.pytorch.org/t/how-does-torch-compile-work-with-autograd/1621/2 - a discussion on it's use
- some code
>>> import torch
>>> def foo():
... return torch.rand(5).sin()
...
>>> foo()
tensor([0.3314, 0.8000, 0.0395, 0.5884, 0.7805])
>>> torch.compile(foo)
<function foo at 0x138ae50d0>
>>> def foo2():
... with torch.no_grad():
... return torch.rand(5).sin()
...
>>> torch.compile(foo2)
<function foo2 at 0x138ae5940>
>>> torch.compile(foo2)()
tensor([0.7675, 0.1882, 0.6245, 0.4960, 0.4802])
sounds good, i'll move this over to a draft |
Idk how I missed this, but this is already implemented https://github.com/pymc-devs/pytensor/blame/main/pytensor/link/pytorch/dispatch/basic.py#L57. I don't think we need any changes. |
Description
Adds the torch DeepCopyOp
Related Issue
Checklist
Type of change