Add docs on implementing Pytorch Ops (and CumOp) #837
@@ -0,0 +1,18 @@ (new file: the PyTorch dispatch implementation for CumOp)
import torch

from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.tensor.extra_ops import CumOp


@pytorch_funcify.register(CumOp)
def pytorch_funcify_Cumop(op, **kwargs):
    dim = op.axis
    mode = op.mode

    def cumop(x, dim=dim, mode=mode):
        if mode == "add":
            return torch.cumsum(x, dim=dim)
        else:
            return torch.cumprod(x, dim=dim)

    return cumop
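
For reference, once this dispatch is registered it can be exercised directly by calling `pytorch_funcify` on a `CumOp` instance. The snippet below is a minimal sketch and not part of this PR; it assumes `CumOp` can be constructed as `CumOp(axis=..., mode=...)` and that importing the dispatch package is enough to apply the registration above (if the constructor differs, the op can instead be pulled from a graph, e.g. `pt.cumsum(a, axis=0).owner.op`).

```python
import torch

import pytensor.link.pytorch.dispatch  # noqa: F401 (assumed to import the module that registers the CumOp dispatch)
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.tensor.extra_ops import CumOp

# Hypothetical check of the dispatch above (constructor signature assumed).
op = CumOp(axis=0, mode="add")
cumsum_torch = pytorch_funcify(op)  # dispatches to pytorch_funcify_Cumop
x = torch.arange(6, dtype=torch.float32).reshape(2, 3)
print(cumsum_torch(x))  # cumulative sum along dim 0, same result as np.cumsum
```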
@@ -0,0 +1,31 @@ (new file: the test for the CumOp PyTorch conversion)
import numpy as np

import pytensor.tensor as pt
from pytensor.configdefaults import config
from pytensor.graph import FunctionGraph
from pytensor.graph.op import get_test_value
from tests.link.pytorch.test_basic import compare_pytorch_and_py


def test_pytorch_CumOp():
    """Test PyTorch conversion of the `CumOp` `Op`."""

    # Create a symbolic input for the first input of `CumOp`
    a = pt.matrix("a")

    # Create test value tag for a
    a.tag.test_value = np.arange(9, dtype=config.floatX).reshape((3, 3))

Review comment: No need for test values and tags. We're planning to deprecate that functionality as well.

    # Create the output variable
    out = pt.cumsum(a, axis=0)

Review comment: The axis can be parametrized (prod and add as well) instead of adding more conditions inside the test.
Reply: Tried this on the original Op.
Reply: The Op
Reply: Checked again, there is no error if we use it. We could try adding a check and raise, but would that be needed in other Op implementations?
Reply: No, it gives a TypeError.
Reply: Which is fine, but it probably raises the TypeError in an obscure place. We should raise already in the init method of the Op to save people time.
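
A hedged sketch of the parametrization suggested above (the test name and exact parameter lists are illustrative, not from the PR; it also passes test values straight to `compare_pytorch_and_py`, as suggested later in this review):

```python
import numpy as np
import pytest

import pytensor.tensor as pt
from pytensor.configdefaults import config
from pytensor.graph import FunctionGraph
from tests.link.pytorch.test_basic import compare_pytorch_and_py


@pytest.mark.parametrize("axis", [0, 1])
@pytest.mark.parametrize("pt_op", [pt.cumsum, pt.cumprod])  # the "add" and "mul" modes of CumOp
def test_pytorch_CumOp_parametrized(axis, pt_op):
    # Build a small graph for the given mode and axis
    a = pt.matrix("a")
    out = pt_op(a, axis=axis)
    fgraph = FunctionGraph([a], [out])

    # Compare the PyTorch and Python implementations on a concrete input
    test_value = np.arange(9, dtype=config.floatX).reshape((3, 3))
    compare_pytorch_and_py(fgraph, [test_value])
```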

    # Create a PyTensor `FunctionGraph`
    fgraph = FunctionGraph([a], [out])

    # Pass the graph and inputs to the testing function
    compare_pytorch_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    # For the second mode of CumOp
    out = pt.cumprod(a, axis=1)
    fgraph = FunctionGraph([a], [out])
    compare_pytorch_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

Review comment: Here just pass the test values (instead of adding them as tags and then retrieving them).
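
In the test above, that would mean dropping the `a.tag.test_value` line and the `get_test_value` calls. A minimal sketch of the replacement call, reusing `np`, `config`, and `fgraph` from the test above and assuming `compare_pytorch_and_py` accepts the raw arrays (as its call signature above suggests):

```python
# Pass the concrete input values directly, no tags or get_test_value needed
test_value = np.arange(9, dtype=config.floatX).reshape((3, 3))
compare_pytorch_and_py(fgraph, [test_value])
```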

Review comment: This is not needed, the returned functions are never called by the user.