Skip to content

Commit c5bfbcd

Browse files
committed
Allow passing static shape to tensor creation helpers
* Also default dtype to "floatX" when using `tensor`
1 parent 16d1cbe commit c5bfbcd

File tree

2 files changed

+265
-42
lines changed

2 files changed

+265
-42
lines changed

pytensor/tensor/type.py

Lines changed: 157 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from typing import TYPE_CHECKING, Iterable, Optional, Tuple, Union
44

55
import numpy as np
6+
from typing_extensions import Literal
67

78
import pytensor
89
from pytensor import scalar as aes
@@ -775,9 +776,17 @@ def values_eq_approx_always_true(a, b):
775776
)
776777

777778

778-
def tensor(*args, **kwargs):
779+
def tensor(
780+
dtype: Optional["DTypeLike"] = None,
781+
*args,
782+
**kwargs,
783+
) -> "TensorVariable":
784+
785+
if dtype is None:
786+
dtype = config.floatX
787+
779788
name = kwargs.pop("name", None)
780-
return TensorType(*args, **kwargs)(name=name)
789+
return TensorType(dtype, *args, **kwargs)(name=name)
781790

782791

783792
cscalar = TensorType("complex64", ())
@@ -794,7 +803,10 @@ def tensor(*args, **kwargs):
794803
ulscalar = TensorType("uint64", ())
795804

796805

797-
def scalar(name=None, dtype=None):
806+
def scalar(
807+
name: Optional[str] = None,
808+
dtype: Optional["DTypeLike"] = None,
809+
) -> "TensorVariable":
798810
"""Return a symbolic scalar variable.
799811
800812
Parameters
@@ -832,20 +844,47 @@ def scalar(name=None, dtype=None):
832844
lvector = TensorType("int64", shape=(None,))
833845

834846

835-
def vector(name=None, dtype=None):
847+
ST = Union[int, None]
848+
849+
850+
def _validate_static_shape(shape, ndim: int) -> Tuple[ST, ...]:
851+
852+
if not isinstance(shape, tuple):
853+
raise TypeError(f"Shape must be a tuple, got {type(shape)}")
854+
855+
if len(shape) != ndim:
856+
raise ValueError(f"Shape must be a tuple of length {ndim}, got {shape}")
857+
858+
if not all(sh is None or isinstance(sh, int) for sh in shape):
859+
raise TypeError(f"Shape entries must be None or integer, got {shape}")
860+
861+
return shape
862+
863+
864+
def vector(
865+
name: Optional[str] = None,
866+
dtype: Optional["DTypeLike"] = None,
867+
shape: Optional[Tuple[ST]] = (None,),
868+
) -> "TensorVariable":
836869
"""Return a symbolic vector variable.
837870
838871
Parameters
839872
----------
840-
dtype: numeric
841-
None means to use pytensor.config.floatX.
842873
name
843874
A name to attach to this variable
875+
shape
876+
A tuple of static sizes for each dimension of the variable. By default, each dimension length is `None` which
877+
allows that dimension to change size across evaluations.
878+
dtype
879+
Data type of tensor variable. By default, it's pytensor.config.floatX.
844880
845881
"""
846882
if dtype is None:
847883
dtype = config.floatX
848-
type = TensorType(dtype, shape=(None,))
884+
885+
shape = _validate_static_shape(shape, ndim=1)
886+
887+
type = TensorType(dtype, shape=shape)
849888
return type(name)
850889

851890

@@ -867,20 +906,28 @@ def vector(name=None, dtype=None):
867906
lmatrix = TensorType("int64", shape=(None, None))
868907

869908

870-
def matrix(name=None, dtype=None):
909+
def matrix(
910+
name: Optional[str] = None,
911+
dtype: Optional["DTypeLike"] = None,
912+
shape: Optional[Tuple[ST, ST]] = (None, None),
913+
) -> "TensorVariable":
871914
"""Return a symbolic matrix variable.
872915
873916
Parameters
874917
----------
875-
dtype: numeric
876-
None means to use pytensor.config.floatX.
877918
name
878-
A name to attach to this variable.
919+
A name to attach to this variable
920+
shape
921+
A tuple of static sizes for each dimension of the variable. By default, each dimension length is `None` which
922+
allows that dimension to change size across evaluations.
923+
dtype
924+
Data type of tensor variable. By default, it's pytensor.config.floatX.
879925
880926
"""
881927
if dtype is None:
882928
dtype = config.floatX
883-
type = TensorType(dtype, shape=(None, None))
929+
shape = _validate_static_shape(shape, ndim=2)
930+
type = TensorType(dtype, shape=shape)
884931
return type(name)
885932

886933

@@ -902,20 +949,34 @@ def matrix(name=None, dtype=None):
902949
lrow = TensorType("int64", shape=(1, None))
903950

904951

905-
def row(name=None, dtype=None):
952+
def row(
953+
name: Optional[str] = None,
954+
dtype: Optional["DTypeLike"] = None,
955+
shape: Optional[Tuple[Literal[1], ST]] = (1, None),
956+
) -> "TensorVariable":
906957
"""Return a symbolic row variable (i.e. shape ``(1, None)``).
907958
908959
Parameters
909960
----------
910-
dtype: numeric type
911-
None means to use pytensor.config.floatX.
912961
name
913-
A name to attach to this variable.
962+
A name to attach to this variable
963+
shape
964+
A tuple of static sizes for each dimension of the variable. The first dimension must be 1. By default, the second
964+
dimension length is `None`, which allows that dimension to change size across evaluations.
966+
dtype
967+
Data type of tensor variable. By default, it's pytensor.config.floatX.
914968
915969
"""
916970
if dtype is None:
917971
dtype = config.floatX
918-
type = TensorType(dtype, shape=(1, None))
972+
shape = _validate_static_shape(shape, ndim=2)
973+
974+
if shape[0] != 1:
975+
raise ValueError(
976+
f"The first dimension of a `row` must have shape 1, got {shape[0]}"
977+
)
978+
979+
type = TensorType(dtype, shape=shape)
919980
return type(name)
920981

921982

@@ -932,21 +993,31 @@ def row(name=None, dtype=None):
932993

933994

934995
def col(
935-
name: Optional[str] = None, dtype: Optional["DTypeLike"] = None
996+
name: Optional[str] = None,
997+
dtype: Optional["DTypeLike"] = None,
998+
shape: Optional[Tuple[ST, Literal[1]]] = (None, 1),
936999
) -> "TensorVariable":
9371000
"""Return a symbolic column variable (i.e. shape ``(None, 1)``).
9381001
9391002
Parameters
9401003
----------
9411004
name
942-
A name to attach to this variable.
1005+
A name to attach to this variable
1006+
shape
1007+
A tuple of static sizes for each dimension of the variable. The second dimension must be 1. By default, the first
1007+
dimension length is `None`, which allows that dimension to change size across evaluations.
9431009
dtype
944-
``None`` means to use `pytensor.config.floatX`.
1010+
Data type of tensor variable. By default, it's pytensor.config.floatX.
9451011
9461012
"""
9471013
if dtype is None:
9481014
dtype = config.floatX
949-
type = TensorType(dtype, shape=(None, 1))
1015+
shape = _validate_static_shape(shape, ndim=2)
1016+
if shape[1] != 1:
1017+
raise ValueError(
1018+
f"The second dimension of a `col` must have shape 1, got {shape[1]}"
1019+
)
1020+
type = TensorType(dtype, shape=shape)
9501021
return type(name)
9511022

9521023

@@ -963,21 +1034,27 @@ def col(
9631034

9641035

9651036
def tensor3(
966-
name: Optional[str] = None, dtype: Optional["DTypeLike"] = None
1037+
name: Optional[str] = None,
1038+
dtype: Optional["DTypeLike"] = None,
1039+
shape: Optional[Tuple[ST, ST, ST]] = (None, None, None),
9671040
) -> "TensorVariable":
9681041
"""Return a symbolic 3D variable.
9691042
9701043
Parameters
9711044
----------
9721045
name
973-
A name to attach to this variable.
1046+
A name to attach to this variable
1047+
shape
1048+
A tuple of static sizes for each dimension of the variable. By default, each dimension length is `None` which
1049+
allows that dimension to change size across evaluations.
9741050
dtype
975-
``None`` means to use `pytensor.config.floatX`.
1051+
Data type of tensor variable. By default, it's pytensor.config.floatX.
9761052
9771053
"""
9781054
if dtype is None:
9791055
dtype = config.floatX
980-
type = TensorType(dtype, shape=(None, None, None))
1056+
shape = _validate_static_shape(shape, ndim=3)
1057+
type = TensorType(dtype, shape=shape)
9811058
return type(name)
9821059

9831060

@@ -996,21 +1073,27 @@ def tensor3(
9961073

9971074

9981075
def tensor4(
999-
name: Optional[str] = None, dtype: Optional["DTypeLike"] = None
1076+
name: Optional[str] = None,
1077+
dtype: Optional["DTypeLike"] = None,
1078+
shape: Optional[Tuple[ST, ST, ST, ST]] = (None, None, None, None),
10001079
) -> "TensorVariable":
10011080
"""Return a symbolic 4D variable.
10021081
10031082
Parameters
10041083
----------
10051084
name
1006-
A name to attach to this variable.
1085+
A name to attach to this variable
1086+
shape
1087+
A tuple of static sizes for each dimension of the variable. By default, each dimension length is `None` which
1088+
allows that dimension to change size across evaluations.
10071089
dtype
1008-
``None`` means to use `pytensor.config.floatX`.
1090+
Data type of tensor variable. By default, it's pytensor.config.floatX.
10091091
10101092
"""
10111093
if dtype is None:
10121094
dtype = config.floatX
1013-
type = TensorType(dtype, shape=(None, None, None, None))
1095+
shape = _validate_static_shape(shape, ndim=4)
1096+
type = TensorType(dtype, shape=shape)
10141097
return type(name)
10151098

10161099

@@ -1029,21 +1112,27 @@ def tensor4(
10291112

10301113

10311114
def tensor5(
1032-
name: Optional[str] = None, dtype: Optional["DTypeLike"] = None
1115+
name: Optional[str] = None,
1116+
dtype: Optional["DTypeLike"] = None,
1117+
shape: Optional[Tuple[ST, ST, ST, ST, ST]] = (None, None, None, None, None),
10331118
) -> "TensorVariable":
10341119
"""Return a symbolic 5D variable.
10351120
10361121
Parameters
10371122
----------
10381123
name
1039-
A name to attach to this variable.
1124+
A name to attach to this variable
1125+
shape
1126+
A tuple of static sizes for each dimension of the variable. By default, each dimension length is `None` which
1127+
allows that dimension to change size across evaluations.
10401128
dtype
1041-
``None`` means to use `pytensor.config.floatX`.
1129+
Data type of tensor variable. By default, it's pytensor.config.floatX.
10421130
10431131
"""
10441132
if dtype is None:
10451133
dtype = config.floatX
1046-
type = TensorType(dtype, shape=(None, None, None, None, None))
1134+
shape = _validate_static_shape(shape, ndim=5)
1135+
type = TensorType(dtype, shape=shape)
10471136
return type(name)
10481137

10491138

@@ -1062,21 +1151,34 @@ def tensor5(
10621151

10631152

10641153
def tensor6(
1065-
name: Optional[str] = None, dtype: Optional["DTypeLike"] = None
1154+
name: Optional[str] = None,
1155+
dtype: Optional["DTypeLike"] = None,
1156+
shape: Optional[Tuple[ST, ST, ST, ST, ST, ST]] = (
1157+
None,
1158+
None,
1159+
None,
1160+
None,
1161+
None,
1162+
None,
1163+
),
10661164
) -> "TensorVariable":
10671165
"""Return a symbolic 6D variable.
10681166
10691167
Parameters
10701168
----------
10711169
name
1072-
A name to attach to this variable.
1170+
A name to attach to this variable
1171+
shape
1172+
A tuple of static sizes for each dimension of the variable. By default, each dimension length is `None` which
1173+
allows that dimension to change size across evaluations.
10731174
dtype
1074-
``None`` means to use `pytensor.config.floatX`.
1175+
Data type of tensor variable. By default, it's pytensor.config.floatX.
10751176
10761177
"""
10771178
if dtype is None:
10781179
dtype = config.floatX
1079-
type = TensorType(dtype, shape=(None,) * 6)
1180+
shape = _validate_static_shape(shape, ndim=6)
1181+
type = TensorType(dtype, shape=shape)
10801182
return type(name)
10811183

10821184

@@ -1095,21 +1197,35 @@ def tensor6(
10951197

10961198

10971199
def tensor7(
1098-
name: Optional[str] = None, dtype: Optional["DTypeLike"] = None
1200+
name: Optional[str] = None,
1201+
dtype: Optional["DTypeLike"] = None,
1202+
shape: Optional[Tuple[ST, ST, ST, ST, ST, ST, ST]] = (
1203+
None,
1204+
None,
1205+
None,
1206+
None,
1207+
None,
1208+
None,
1209+
None,
1210+
),
10991211
) -> "TensorVariable":
11001212
"""Return a symbolic 7-D variable.
11011213
11021214
Parameters
11031215
----------
11041216
name
1105-
A name to attach to this variable.
1217+
A name to attach to this variable
1218+
shape
1219+
A tuple of static sizes for each dimension of the variable. By default, each dimension length is `None` which
1220+
allows that dimension to change size across evaluations.
11061221
dtype
1107-
``None`` means to use `pytensor.config.floatX`.
1222+
Data type of tensor variable. By default, it's pytensor.config.floatX.
11081223
11091224
"""
11101225
if dtype is None:
11111226
dtype = config.floatX
1112-
type = TensorType(dtype, shape=(None,) * 7)
1227+
shape = _validate_static_shape(shape, ndim=7)
1228+
type = TensorType(dtype, shape=shape)
11131229
return type(name)
11141230

11151231

0 commit comments

Comments
 (0)