Skip to content

Commit 7fd83dc

Browse files
Fix TensorFromScalar
1 parent 5c73991 commit 7fd83dc

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

pytensor/link/pytorch/dispatch/basic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def pytorch_typify_tensor(data, dtype=None, **kwargs):
3333
@pytorch_typify.register(slice)
3434
@pytorch_typify.register(NoneType)
3535
@pytorch_typify.register(np.number)
36-
def pytorch_typify_scalar(data, **kwargs):
36+
def pytorch_typify_no_conversion_needed(data, **kwargs):
3737
return data
3838

3939

@@ -153,6 +153,6 @@ def makevector(*x):
153153
@pytorch_funcify.register(TensorFromScalar)
154154
def pytorch_funcify_TensorFromScalar(op, **kwargs):
155155
def tensorfromscalar(x):
156-
return x
156+
return torch.as_tensor(x)
157157

158158
return tensorfromscalar

0 commit comments

Comments
 (0)