Pytorch support for Join and Careduce Ops #869
Changes from 10 commits: 48d16ff, a8c1fb1, 14dde44, 0bdeb2e, bbb3622, 519fd47, 9073a56, 036082b, 8dfa059, 18bbc8c, 35153f7
@@ -2,6 +2,7 @@

from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.math import All, Any, Max, Min, Prod, Sum
from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
@@ -37,6 +38,69 @@
    return dimshuffle


@pytorch_funcify.register(Sum)
def pytorch_funcify_sum(op, **kwargs):
    def torch_sum(x):
        return torch.sum(x, dim=op.axis)

    return torch_sum

(ricardoV94 marked a review conversation on torch_sum as resolved.)


@pytorch_funcify.register(All)
def pytorch_funcify_all(op, **kwargs):
    def torch_all(x):
        return torch.all(x, dim=op.axis)

    return torch_all


@pytorch_funcify.register(Prod)
def pytorch_funcify_prod(op, **kwargs):
    def torch_prod(x):
        if isinstance(op.axis, tuple):
            for d in op.axis[::-1]:
                x = torch.prod(x, dim=d)
            return x
        else:
            return torch.prod(x.flatten(), dim=0)

    return torch_prod

(Review comment on the axis loop: "More readable?", with a suggested change; "Same for the others.")
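
Aside (not part of the diff): torch.prod reduces a single dimension per call, which is why the tuple case loops over op.axis; iterating in reverse order removes the highest remaining dimension first, so the lower axis indices stay valid. A standalone sanity check of that equivalence:

    import torch

    x = torch.arange(24.0).reshape(2, 3, 4)
    axis = (1, 2)

    out = x
    for d in axis[::-1]:  # reduce the highest axis first
        out = torch.prod(out, dim=d)

    # identical to reducing dim 2 and then dim 1 explicitly
    assert torch.equal(out, torch.prod(torch.prod(x, dim=2), dim=1))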


@pytorch_funcify.register(Any)
def pytorch_funcify_any(op, **kwargs):
    def torch_any(x):
        return torch.any(x, dim=op.axis)

    return torch_any


@pytorch_funcify.register(Max)
def pytorch_funcify_max(op, **kwargs):
    def torch_max(x):
        if isinstance(op.axis, tuple):
            for d in op.axis[::-1]:
                x = torch.max(x, dim=d).values
            return x
        else:
            return torch.max(x.flatten(), dim=0).values

    return torch_max


@pytorch_funcify.register(Min)
def pytorch_funcify_min(op, **kwargs):
    def torch_min(x):
        if isinstance(op.axis, tuple):
            for d in op.axis[::-1]:
                x = torch.min(x, dim=d).values
            return x
        else:
            return torch.min(x.flatten(), dim=0).values

    return torch_min


@pytorch_funcify.register(Softmax)
def pytorch_funcify_Softmax(op, **kwargs):
    axis = op.axis
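
The Join dispatch named in the PR title is not shown in this excerpt. A minimal sketch of what that registration could look like, assuming Join's runtime signature is (axis, *tensors) as in PyTensor's other backends (the import path and body below are illustrative, not taken from the diff):

    from pytensor.tensor.basic import Join


    @pytorch_funcify.register(Join)
    def pytorch_funcify_join(op, **kwargs):
        def join(axis, *tensors):
            # torch.cat concatenates along a single integer dim
            return torch.cat(tensors, dim=int(axis))

        return join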

(Second file in the diff: tests for the PyTorch backend.)
@@ -2,11 +2,12 @@
import pytest

import pytensor.tensor as pt
import pytensor.tensor.math as ptm
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor import elemwise as pt_elemwise
from pytensor.tensor.special import SoftmaxGrad, log_softmax, softmax
-from pytensor.tensor.type import matrix, tensor, vector
+from pytensor.tensor.type import matrix, tensor, tensor3, vector
from tests.link.pytorch.test_basic import compare_pytorch_and_py
@@ -57,6 +58,46 @@ def test_pytorch_elemwise():
    compare_pytorch_and_py(fg, [[0.9, 0.9]])


(Review comment on the axis cases below: "I think this is sufficient", with a suggested change.)

@pytest.mark.parametrize("fn", [ptm.sum, ptm.prod, ptm.max, ptm.min])
@pytest.mark.parametrize("axis", [0, 1, (0, 1), (1, 2), (1, -1)])
def test_pytorch_careduce(fn, axis):
    a_pt = tensor3("a")
    test_value = np.array(
        [
            [
                [1, 1, 1, 1],
                [2, 2, 2, 2],
            ],
            [
                [3, 3, 3, 3],
                [4, 4, 4, 4],
            ],
        ]
    ).astype(config.floatX)

    x = fn(a_pt, axis=axis)
    x_fg = FunctionGraph([a_pt], [x])

    compare_pytorch_and_py(x_fg, [test_value])
@pytest.mark.parametrize("fn", [ptm.any, ptm.all]) | ||||||
@pytest.mark.parametrize("axis", [None, 0, 1, (0, 1)]) | ||||||
def test_pytorch_any_all(fn, axis): | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
a_pt = matrix("a") | ||||||
test_value = np.array([[True, False, True], [False, True, True]]) | ||||||
|
||||||
x = fn(a_pt, axis=axis) | ||||||
x_fg = FunctionGraph([a_pt], [x]) | ||||||
|
||||||
compare_pytorch_and_py(x_fg, [test_value]) | ||||||
|
||||||
|
||||||
@pytest.mark.parametrize("dtype", ["float64", "int64"]) | ||||||
@pytest.mark.parametrize("axis", [None, 0, 1]) | ||||||
def test_softmax(axis, dtype): | ||||||
|
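
Usage sketch (not part of the diff): roughly how a reduction graph would run on the new backend, assuming the PyTorch linker is exposed as mode="PYTORCH" on this branch:

    import numpy as np

    import pytensor
    import pytensor.tensor as pt

    x = pt.matrix("x")
    out = pt.sum(x, axis=(0, 1))

    fn = pytensor.function([x], out, mode="PYTORCH")  # mode name is an assumption
    print(fn(np.arange(6, dtype="float32").reshape(2, 3)))  # 15.0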