Skip to content

Commit 0e455fd

Browse files
committed
Reworked tests in implementation of Shape in PyTorch
1 parent 4d4abd7 commit 0e455fd

File tree

3 files changed

+5
-17
lines changed

3 files changed

+5
-17
lines changed

pytensor/link/pytorch/dispatch/shape.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,7 @@
66

77
@pytorch_funcify.register(Reshape)
88
def pytorch_funcify_Reshape(op, node, **kwargs):
9-
shape = node.inputs[1]
10-
11-
def reshape(x, shape=shape):
9+
def reshape(x, shape):
1210
return torch.reshape(x, tuple(shape))
1311

1412
return reshape

tests/link/pytorch/test_extra_ops.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,8 @@ def test_pytorch_CumOp(axis, dtype):
5151
pytest.param(
5252
None,
5353
3,
54-
marks=pytest.mark.xfail(reason="Reshape not implemented"),
54+
marks=pytest.mark.xfail(reason="Issue in Elemwise"),
55+
# TODO: add reference to issue
5556
),
5657
],
5758
)

tests/link/pytorch/test_shape.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
import numpy as np
22

33
import pytensor.tensor as pt
4-
from pytensor.compile.ops import DeepCopyOp, ViewOp
54
from pytensor.configdefaults import config
65
from pytensor.graph.fg import FunctionGraph
76
from pytensor.tensor.shape import Shape, Shape_i, Unbroadcast, reshape
@@ -46,27 +45,17 @@ def test_pytorch_Reshape_constant():
4645
compare_pytorch_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)])
4746

4847

49-
def test_pytorch_Reshape_shape_graph_input():
48+
def test_pytorch_Reshape_dynamic():
5049
a = vector("a")
5150
shape_pt = iscalar("b")
5251
x = reshape(a, (shape_pt, shape_pt))
5352
x_fg = FunctionGraph([a, shape_pt], [x])
5453
compare_pytorch_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX), 2])
5554

5655

57-
def test_pytorch_compile_ops():
58-
x = DeepCopyOp()(pt.as_tensor_variable(1.1))
59-
x_fg = FunctionGraph([], [x])
60-
61-
compare_pytorch_and_py(x_fg, [])
62-
56+
def test_pytorch_unbroadcast():
6357
x_np = np.zeros((20, 1, 1))
6458
x = Unbroadcast(0, 2)(pt.as_tensor_variable(x_np))
6559
x_fg = FunctionGraph([], [x])
6660

6761
compare_pytorch_and_py(x_fg, [])
68-
69-
x = ViewOp()(pt.as_tensor_variable(x_np))
70-
x_fg = FunctionGraph([], [x])
71-
72-
compare_pytorch_and_py(x_fg, [])

0 commit comments

Comments
 (0)