Pytorch support for Join and Careduce Ops #869
Changes from 8 commits
First file in the diff:

@@ -2,6 +2,7 @@
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.math import All, Any, Max, Min, Prod, Sum
from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad

@@ -37,6 +38,64 @@
    return dimshuffle


@pytorch_funcify.register(Sum)
def pytorch_funcify_sum(op, **kwargs):
    def torch_sum(x):
        return torch.sum(x, dim=op.axis)

    return torch_sum

@pytorch_funcify.register(All)
def pytorch_funcify_all(op, **kwargs):
    dim = op.axis

    def torch_all(x):
        return torch.all(x, dim=dim)

    return torch_all

@pytorch_funcify.register(Prod)
def pytorch_funcify_prod(op, **kwargs):
    dim = op.axis[0]

    def torch_prod(x):
        return torch.prod(x, dim=dim)

    return torch_prod

Review thread on the line dim = op.axis[0]:

Why axis[0]?

Then we need to change the logic, because it is possible for them to be tuples with more than one entry.

We could do something like this:

If you reduce in reversed order you don't have to worry about the keepdims.

Sounds good, a bit surprising that they don't support multiple axes.

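The code the reviewer had in mind is not shown on the page; the following is only a hedged sketch of one way a multi-axis reduction could be handled, assuming torch.prod accepts a single dim at a time. The helper name torch_prod_multi_axis is illustrative and not from the PR. Reducing in reversed sorted axis order keeps the indices of the remaining axes valid, so no keepdim bookkeeping is needed.

import torch

def torch_prod_multi_axis(x, axis):
    # Hypothetical sketch, not the PR's code: reduce one axis at a time,
    # highest axis first, so the remaining axis indices are unchanged.
    if axis is None:
        return torch.prod(x)
    for d in sorted(axis, reverse=True):
        x = torch.prod(x, dim=d)
    return x

# Example: reducing a 2x3x4 tensor over axes (1, 2) leaves shape (2,).
x = torch.arange(1.0, 25.0).reshape(2, 3, 4)
print(torch_prod_multi_axis(x, (1, 2)).shape)  # torch.Size([2])

The same loop would apply to max and min, which are the other reductions below that currently take only op.axis[0].
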
@pytorch_funcify.register(Any)
def pytorch_funcify_any(op, **kwargs):
    dim = op.axis

    def torch_any(x):
        return torch.any(x, dim=dim)

    return torch_any

@pytorch_funcify.register(Max)
def pytorch_funcify_max(op, **kwargs):
    dim = op.axis[0]

    def torch_max(x):
        # torch.max with a dim argument returns (values, indices); only the values are needed
        return torch.max(x, dim=dim).values

    return torch_max

@pytorch_funcify.register(Min)
def pytorch_funcify_min(op, **kwargs):
    dim = op.axis[0]

    def torch_min(x):
        return torch.min(x, dim=dim).values

    return torch_min

@pytorch_funcify.register(Softmax)
def pytorch_funcify_Softmax(op, **kwargs):
    axis = op.axis

Second file in the diff (tests):

@@ -2,6 +2,7 @@
import pytest

import pytensor.tensor as pt
import pytensor.tensor.math as ptm
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor import elemwise as pt_elemwise

@@ -57,6 +58,72 @@ def test_pytorch_elemwise():
    compare_pytorch_and_py(fg, [[0.9, 0.9]])

@pytest.mark.parametrize("axis", [None, 0, 1]) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Can we parametrize these tests with the reduce function? Since they all look the same, we can reduce a bunch of lines. Or at least separate only those that need numerical inputs from those that need boolean (all and any). There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also I would like to test axis = (1, 2), and have |
||
def test_pytorch_sum(axis): | ||
a_pt = matrix("a") | ||
test_value = np.array([[1, 2], [3, 4]]).astype(config.floatX) | ||
|
||
x = pt.math.sum(a_pt, axis=axis) | ||
x_fg = FunctionGraph([a_pt], [x]) | ||
|
||
compare_pytorch_and_py(x_fg, [test_value]) | ||
|
||
|
||
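A hedged sketch of what such a parametrized test might look like. The test name, the reduce_fn parameter, and the axis cases are illustrative assumptions rather than the PR's code; matrix and compare_pytorch_and_py are the helpers already used by this test module, and axis=(0, 1) stands in for the multi-axis coverage the review asks for.

import numpy as np
import pytest

import pytensor.tensor.math as ptm
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor.type import matrix


@pytest.mark.parametrize("reduce_fn", [ptm.sum, ptm.prod, ptm.max, ptm.min])
@pytest.mark.parametrize("axis", [None, 0, 1, (0, 1)])
def test_pytorch_careduce(reduce_fn, axis):
    # Illustrative sketch: one test body shared by the numerical reductions.
    a_pt = matrix("a")
    test_value = np.array([[1, 2], [3, 4]]).astype(config.floatX)

    x = reduce_fn(a_pt, axis=axis)
    x_fg = FunctionGraph([a_pt], [x])

    # compare_pytorch_and_py is the existing helper in this test module.
    compare_pytorch_and_py(x_fg, [test_value])

Following the review suggestion, all and any would get a separate parametrized test of the same shape with a boolean test_value.
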
@pytest.mark.parametrize("axis", [None, 0, 1]) | ||
def test_pytorch_all(axis): | ||
a_pt = matrix("a") | ||
test_value = np.array([[True, False, True], [False, True, True]]) | ||
|
||
x = ptm.all(a_pt, axis=axis) | ||
x_fg = FunctionGraph([a_pt], [x]) | ||
|
||
compare_pytorch_and_py(x_fg, [test_value]) | ||
|
||
|
||
@pytest.mark.parametrize("axis", [0, 1]) | ||
def test_pytorch_prod(axis): | ||
a_pt = matrix("a") | ||
test_value = np.array([[1, 2], [3, 4]]).astype(config.floatX) | ||
|
||
x = ptm.prod(a_pt, axis=axis) | ||
x_fg = FunctionGraph([a_pt], [x]) | ||
|
||
compare_pytorch_and_py(x_fg, [test_value]) | ||
|
||
|
||
@pytest.mark.parametrize("axis", [None, 0, 1]) | ||
def test_pytorch_any(axis): | ||
a_pt = matrix("a") | ||
test_value = np.array([[True, False, True], [False, True, True]]) | ||
|
||
x = ptm.any(a_pt, axis=axis) | ||
x_fg = FunctionGraph([a_pt], [x]) | ||
|
||
compare_pytorch_and_py(x_fg, [test_value]) | ||
|
||
|
||
@pytest.mark.parametrize("axis", [0, 1]) | ||
def test_pytorch_max(axis): | ||
a_pt = matrix("a") | ||
test_value = np.array([[1, 2], [3, 4]]).astype(config.floatX) | ||
|
||
x = ptm.max(a_pt, axis=axis) | ||
x_fg = FunctionGraph([a_pt], [x]) | ||
|
||
compare_pytorch_and_py(x_fg, [test_value]) | ||
|
||
|
||
@pytest.mark.parametrize("axis", [0, 1]) | ||
def test_pytorch_min(axis): | ||
a_pt = matrix("a") | ||
test_value = np.array([[1, 2], [3, 4]]).astype(config.floatX) | ||
|
||
x = ptm.min(a_pt, axis=axis) | ||
x_fg = FunctionGraph([a_pt], [x]) | ||
|
||
compare_pytorch_and_py(x_fg, [test_value]) | ||
|
||
|
||
@pytest.mark.parametrize("dtype", ["float64", "int64"]) | ||
@pytest.mark.parametrize("axis", [None, 0, 1]) | ||
def test_softmax(axis, dtype): | ||