Skip to content

Commit 03902a7

Browse files
Raise NotImplementedError from pytorch_typify
1 parent 4f292f0 commit 03902a7

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

pytensor/link/pytorch/dispatch/basic.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,9 @@ def pytorch_typify(data, dtype=None, **kwargs):
1616
try:
1717
return torch.as_tensor(data, dtype=dtype)
1818
except RuntimeError:
19-
raise RuntimeError(f"Data is of type {type(data)}, it should be an array")
19+
raise NotImplementedError(
20+
f"pytorch_typify got type {type(data)}, it should be an array"
21+
)
2022

2123

2224
@pytorch_typify.register(NoneType)

tests/link/pytorch/test_subtensor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def test_pytorch_Subtensor():
5151
out_fg = FunctionGraph([x_pt], [out_pt])
5252
compare_pytorch_and_py(out_fg, [x_np])
5353

54-
with pytest.raises(RuntimeError):
54+
with pytest.raises(NotImplementedError):
5555
out_pt = x_pt[[1, 2], :, [3, 4]]
5656
assert isinstance(out_pt.owner.op, pt_subtensor.AdvancedSubtensor)
5757
out_fg = FunctionGraph([x_pt], [out_pt])

0 commit comments

Comments
 (0)