Skip to content

Commit 7d25118

Browse files
Fix pytorch_typify for non-array dtypes
1 parent 56358cc commit 7d25118

File tree

1 file changed

+5
-5
lines changed

1 file changed

+5
-5
lines changed

pytensor/link/pytorch/dispatch/basic.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,12 @@
1313
@singledispatch
1414
def pytorch_typify(data, dtype=None, **kwargs):
1515
r"""Convert instances of PyTensor `Type`\s to PyTorch types."""
16-
try:
16+
if isinstance(data, NoneType):
17+
raise NotImplementedError(f"pytorch_typify is not implemented for {type(data)}")
18+
if isinstance(data, slice):
19+
raise NotImplementedError(f"pytorch_typify is not implemented for {type(data)}")
20+
else:
1721
return torch.as_tensor(data, dtype=dtype)
18-
except RuntimeError:
19-
raise NotImplementedError(
20-
f"pytorch_typify got type {type(data)}, it should be an array"
21-
)
2222

2323

2424
@pytorch_typify.register(NoneType)

0 commit comments

Comments
 (0)