
Add docs on implementing Pytorch Ops (and CumOp) #837


Merged · 8 commits · Jul 4, 2024
1 change: 1 addition & 0 deletions pytensor/link/pytorch/dispatch/__init__.py
@@ -4,4 +4,5 @@
# # Load dispatch specializations
import pytensor.link.pytorch.dispatch.scalar
import pytensor.link.pytorch.dispatch.elemwise
import pytensor.link.pytorch.dispatch.extra_ops
# isort: on
18 changes: 18 additions & 0 deletions pytensor/link/pytorch/dispatch/extra_ops.py
@@ -0,0 +1,18 @@
import torch

from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.tensor.extra_ops import CumOp


@pytorch_funcify.register(CumOp)
def pytorch_funcify_Cumop(op, **kwargs):
dim = op.axis
mode = op.mode

def cumop(x, dim=dim, mode=mode):
Member:
This is not needed; the returned functions are never called by the user.

Suggested change:
- def cumop(x, dim=dim, mode=mode):
+ def cumop(x):

if mode == "add":
return torch.cumsum(x, dim=dim)
else:
return torch.cumprod(x, dim=dim)

return cumop
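The `pytorch_funcify.register(CumOp)` pattern above follows Python's `functools.singledispatch` protocol: each `Op` class registers a converter that returns the callable the backend will run. A minimal, self-contained sketch of the same pattern, using a hypothetical stand-in `CumOp` class and an `itertools.accumulate`-based cumulative op in place of `torch` (so it runs without PyTensor or PyTorch installed):

```python
import operator
from functools import singledispatch
from itertools import accumulate


class CumOp:
    """Hypothetical stand-in for PyTensor's CumOp (not the real class)."""

    def __init__(self, axis=None, mode="add"):
        self.axis = axis
        self.mode = mode


@singledispatch
def funcify(op, **kwargs):
    # Fallback: no converter registered for this Op type
    raise NotImplementedError(f"No conversion for {type(op).__name__}")


@funcify.register(CumOp)
def funcify_CumOp(op, **kwargs):
    # Read Op properties once, at conversion time
    mode = op.mode

    def cumop(x):
        # "add" -> running sum, anything else -> running product
        fn = operator.add if mode == "add" else operator.mul
        return list(accumulate(x, fn))

    return cumop


cumsum = funcify(CumOp(mode="add"))
cumprod = funcify(CumOp(mode="mul"))
```

The dispatch happens on the type of the first argument, which is why registering on the `Op` class is enough; the returned inner function closes over the `Op`'s properties.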
31 changes: 31 additions & 0 deletions tests/link/pytorch/test_extra_ops.py
@@ -0,0 +1,31 @@
import numpy as np

import pytensor.tensor as pt
from pytensor.configdefaults import config
from pytensor.graph import FunctionGraph
from pytensor.graph.op import get_test_value
from tests.link.pytorch.test_basic import compare_pytorch_and_py


def test_pytorch_CumOp():
"""Test PyTorch conversion of the `CumOp` `Op`."""

# Create a symbolic input for the first input of `CumOp`
a = pt.matrix("a")

# Create test value tag for a
Member:

Suggested change:
- # Create test value tag for a
+ # Create test value

a.tag.test_value = np.arange(9, dtype=config.floatX).reshape((3, 3))
Member:
No need for test values and tags; we're planning to deprecate that functionality as well.

# Create the output variable
out = pt.cumsum(a, axis=0)
Member (@ricardoV94, Jun 20, 2024):
Test axis=None and axis=tuple(...) if the original Op supports them. If a tuple is allowed, make sure you use more dimensions (say 3) and ask for only a subset of them (say 2) in the axis, so that you test something genuinely different from axis=None or axis=int.

The axis can be parametrized (prod and add as well) instead of adding more conditions inside the test.

Contributor Author:
Tried this on the original Op: axis=tuple(...) does not work and raises a TypeError, while axis=None returns the output as a 1-D array.

Member:
The Op's __init__ doesn't seem to check the axis explicitly, but it does assume it is either None or an int. Can we add a check and raise an explicit ValueError if it's neither?

Contributor Author:
Checked again: there is no error if we use axis=(0), and PyTorch also returns the same output. The error only appears when the tuple has more than one element (even np.cumsum raises a TypeError in that case).

We could add a check and raise, but would that be needed in other Op implementations? Since this would be used as an example, it might be confusing if such a check is not needed elsewhere.

Member:
(0) is 0, not a tuple containing 0; it would have to be (0,) to be a one-element tuple. Does it work with (0,)?

Contributor Author:
No, it raises a TypeError.

Member:
Which is fine, but it probably raises the TypeError in an obscure place. We should raise already in the init method of the Op to save people time.

# Create a PyTensor `FunctionGraph`
fgraph = FunctionGraph([a], [out])

# Pass the graph and inputs to the testing function
compare_pytorch_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

# For the second mode of CumOp
out = pt.cumprod(a, axis=1)
fgraph = FunctionGraph([a], [out])
compare_pytorch_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
Member:
Here, just pass the test values directly (instead of adding them as tags and then retrieving them).
