Skip to content

Commit 9e714a6

Browse files
committed
Allow passing static shape to tensor creation helpers
* Also default dtype to "floatX" when using `tensor`
1 parent 16d1cbe commit 9e714a6

File tree

2 files changed

+259
-42
lines changed

2 files changed

+259
-42
lines changed

pytensor/tensor/type.py

Lines changed: 156 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from typing import TYPE_CHECKING, Iterable, Optional, Tuple, Union
44

55
import numpy as np
6+
from typing_extensions import Literal
67

78
import pytensor
89
from pytensor import scalar as aes
@@ -775,9 +776,16 @@ def values_eq_approx_always_true(a, b):
775776
)
776777

777778

778-
def tensor(*args, **kwargs):
779+
def tensor(
780+
dtype: Optional["DTypeLike"] = None,
781+
**kwargs,
782+
) -> "TensorVariable":
783+
784+
if dtype is None:
785+
dtype = config.floatX
786+
779787
name = kwargs.pop("name", None)
780-
return TensorType(*args, **kwargs)(name=name)
788+
return TensorType(dtype=dtype, **kwargs)(name=name)
781789

782790

783791
cscalar = TensorType("complex64", ())
@@ -794,7 +802,10 @@ def tensor(*args, **kwargs):
794802
ulscalar = TensorType("uint64", ())
795803

796804

797-
def scalar(name=None, dtype=None):
805+
def scalar(
806+
name: Optional[str] = None,
807+
dtype: Optional["DTypeLike"] = None,
808+
) -> "TensorVariable":
798809
"""Return a symbolic scalar variable.
799810
800811
Parameters
@@ -832,20 +843,47 @@ def scalar(name=None, dtype=None):
832843
lvector = TensorType("int64", shape=(None,))
833844

834845

835-
def vector(name=None, dtype=None):
846+
ST = Union[int, None]
847+
848+
849+
def _validate_static_shape(shape, ndim: int) -> Tuple[ST, ...]:
850+
851+
if not isinstance(shape, tuple):
852+
raise TypeError(f"Shape must be a tuple, got {type(shape)}")
853+
854+
if len(shape) != ndim:
855+
raise ValueError(f"Shape must be a tuple of length {ndim}, got {shape}")
856+
857+
if not all(sh is None or isinstance(sh, int) for sh in shape):
858+
raise TypeError(f"Shape entries must be None or integer, got {shape}")
859+
860+
return shape
861+
862+
863+
def vector(
864+
name: Optional[str] = None,
865+
dtype: Optional["DTypeLike"] = None,
866+
shape: Optional[Tuple[ST]] = (None,),
867+
) -> "TensorVariable":
836868
"""Return a symbolic vector variable.
837869
838870
Parameters
839871
----------
840-
dtype: numeric
841-
None means to use pytensor.config.floatX.
842872
name
843873
A name to attach to this variable
874+
shape
875+
A tuple of static sizes for each dimension of the variable. By default, each dimension length is `None` which
876+
allows that dimension to change size across evaluations.
877+
dtype
878+
Data type of tensor variable. By default, it's pytensor.config.floatX.
844879
845880
"""
846881
if dtype is None:
847882
dtype = config.floatX
848-
type = TensorType(dtype, shape=(None,))
883+
884+
shape = _validate_static_shape(shape, ndim=1)
885+
886+
type = TensorType(dtype, shape=shape)
849887
return type(name)
850888

851889

@@ -867,20 +905,28 @@ def vector(name=None, dtype=None):
867905
lmatrix = TensorType("int64", shape=(None, None))
868906

869907

870-
def matrix(name=None, dtype=None):
908+
def matrix(
909+
name: Optional[str] = None,
910+
dtype: Optional["DTypeLike"] = None,
911+
shape: Optional[Tuple[ST, ST]] = (None, None),
912+
) -> "TensorVariable":
871913
"""Return a symbolic matrix variable.
872914
873915
Parameters
874916
----------
875-
dtype: numeric
876-
None means to use pytensor.config.floatX.
877917
name
878-
A name to attach to this variable.
918+
A name to attach to this variable
919+
shape
920+
A tuple of static sizes for each dimension of the variable. By default, each dimension length is `None` which
921+
allows that dimension to change size across evaluations.
922+
dtype
923+
Data type of tensor variable. By default, it's pytensor.config.floatX.
879924
880925
"""
881926
if dtype is None:
882927
dtype = config.floatX
883-
type = TensorType(dtype, shape=(None, None))
928+
shape = _validate_static_shape(shape, ndim=2)
929+
type = TensorType(dtype, shape=shape)
884930
return type(name)
885931

886932

@@ -902,20 +948,34 @@ def matrix(name=None, dtype=None):
902948
lrow = TensorType("int64", shape=(1, None))
903949

904950

905-
def row(name=None, dtype=None):
951+
def row(
952+
name: Optional[str] = None,
953+
dtype: Optional["DTypeLike"] = None,
954+
shape: Optional[Tuple[Literal[1], ST]] = (1, None),
955+
) -> "TensorVariable":
906956
"""Return a symbolic row variable (i.e. shape ``(1, None)``).
907957
908958
Parameters
909959
----------
910-
dtype: numeric type
911-
None means to use pytensor.config.floatX.
912960
name
913-
A name to attach to this variable.
961+
A name to attach to this variable
962+
shape
963+
A tuple of static sizes for each dimension of the variable. By default, each dimension length is `None` which
964+
allows that dimension to change size across evaluations.
965+
dtype
966+
Data type of tensor variable. By default, it's pytensor.config.floatX.
914967
915968
"""
916969
if dtype is None:
917970
dtype = config.floatX
918-
type = TensorType(dtype, shape=(1, None))
971+
shape = _validate_static_shape(shape, ndim=2)
972+
973+
if shape[0] != 1:
974+
raise ValueError(
975+
f"The first dimension of a `row` must have shape 1, got {shape[0]}"
976+
)
977+
978+
type = TensorType(dtype, shape=shape)
919979
return type(name)
920980

921981

@@ -932,21 +992,31 @@ def row(name=None, dtype=None):
932992

933993

934994
def col(
935-
name: Optional[str] = None, dtype: Optional["DTypeLike"] = None
995+
name: Optional[str] = None,
996+
dtype: Optional["DTypeLike"] = None,
997+
shape: Optional[Tuple[ST, Literal[1]]] = (None, 1),
936998
) -> "TensorVariable":
937999
"""Return a symbolic column variable (i.e. shape ``(None, 1)``).
9381000
9391001
Parameters
9401002
----------
9411003
name
942-
A name to attach to this variable.
1004+
A name to attach to this variable
1005+
shape
1006+
A tuple of static sizes for each dimension of the variable. By default, each dimension length is `None` which
1007+
allows that dimension to change size across evaluations.
9431008
dtype
944-
``None`` means to use `pytensor.config.floatX`.
1009+
Data type of tensor variable. By default, it's pytensor.config.floatX.
9451010
9461011
"""
9471012
if dtype is None:
9481013
dtype = config.floatX
949-
type = TensorType(dtype, shape=(None, 1))
1014+
shape = _validate_static_shape(shape, ndim=2)
1015+
if shape[1] != 1:
1016+
raise ValueError(
1017+
f"The second dimension of a `col` must have shape 1, got {shape[1]}"
1018+
)
1019+
type = TensorType(dtype, shape=shape)
9501020
return type(name)
9511021

9521022

@@ -963,21 +1033,27 @@ def col(
9631033

9641034

9651035
def tensor3(
966-
name: Optional[str] = None, dtype: Optional["DTypeLike"] = None
1036+
name: Optional[str] = None,
1037+
dtype: Optional["DTypeLike"] = None,
1038+
shape: Optional[Tuple[ST, ST, ST]] = (None, None, None),
9671039
) -> "TensorVariable":
9681040
"""Return a symbolic 3D variable.
9691041
9701042
Parameters
9711043
----------
9721044
name
973-
A name to attach to this variable.
1045+
A name to attach to this variable
1046+
shape
1047+
A tuple of static sizes for each dimension of the variable. By default, each dimension length is `None` which
1048+
allows that dimension to change size across evaluations.
9741049
dtype
975-
``None`` means to use `pytensor.config.floatX`.
1050+
Data type of tensor variable. By default, it's pytensor.config.floatX.
9761051
9771052
"""
9781053
if dtype is None:
9791054
dtype = config.floatX
980-
type = TensorType(dtype, shape=(None, None, None))
1055+
shape = _validate_static_shape(shape, ndim=3)
1056+
type = TensorType(dtype, shape=shape)
9811057
return type(name)
9821058

9831059

@@ -996,21 +1072,27 @@ def tensor3(
9961072

9971073

9981074
def tensor4(
999-
name: Optional[str] = None, dtype: Optional["DTypeLike"] = None
1075+
name: Optional[str] = None,
1076+
dtype: Optional["DTypeLike"] = None,
1077+
shape: Optional[Tuple[ST, ST, ST, ST]] = (None, None, None, None),
10001078
) -> "TensorVariable":
10011079
"""Return a symbolic 4D variable.
10021080
10031081
Parameters
10041082
----------
10051083
name
1006-
A name to attach to this variable.
1084+
A name to attach to this variable
1085+
shape
1086+
A tuple of static sizes for each dimension of the variable. By default, each dimension length is `None` which
1087+
allows that dimension to change size across evaluations.
10071088
dtype
1008-
``None`` means to use `pytensor.config.floatX`.
1089+
Data type of tensor variable. By default, it's pytensor.config.floatX.
10091090
10101091
"""
10111092
if dtype is None:
10121093
dtype = config.floatX
1013-
type = TensorType(dtype, shape=(None, None, None, None))
1094+
shape = _validate_static_shape(shape, ndim=4)
1095+
type = TensorType(dtype, shape=shape)
10141096
return type(name)
10151097

10161098

@@ -1029,21 +1111,27 @@ def tensor4(
10291111

10301112

10311113
def tensor5(
1032-
name: Optional[str] = None, dtype: Optional["DTypeLike"] = None
1114+
name: Optional[str] = None,
1115+
dtype: Optional["DTypeLike"] = None,
1116+
shape: Optional[Tuple[ST, ST, ST, ST, ST]] = (None, None, None, None, None),
10331117
) -> "TensorVariable":
10341118
"""Return a symbolic 5D variable.
10351119
10361120
Parameters
10371121
----------
10381122
name
1039-
A name to attach to this variable.
1123+
A name to attach to this variable
1124+
shape
1125+
A tuple of static sizes for each dimension of the variable. By default, each dimension length is `None` which
1126+
allows that dimension to change size across evaluations.
10401127
dtype
1041-
``None`` means to use `pytensor.config.floatX`.
1128+
Data type of tensor variable. By default, it's pytensor.config.floatX.
10421129
10431130
"""
10441131
if dtype is None:
10451132
dtype = config.floatX
1046-
type = TensorType(dtype, shape=(None, None, None, None, None))
1133+
shape = _validate_static_shape(shape, ndim=5)
1134+
type = TensorType(dtype, shape=shape)
10471135
return type(name)
10481136

10491137

@@ -1062,21 +1150,34 @@ def tensor5(
10621150

10631151

10641152
def tensor6(
1065-
name: Optional[str] = None, dtype: Optional["DTypeLike"] = None
1153+
name: Optional[str] = None,
1154+
dtype: Optional["DTypeLike"] = None,
1155+
shape: Optional[Tuple[ST, ST, ST, ST, ST, ST]] = (
1156+
None,
1157+
None,
1158+
None,
1159+
None,
1160+
None,
1161+
None,
1162+
),
10661163
) -> "TensorVariable":
10671164
"""Return a symbolic 6D variable.
10681165
10691166
Parameters
10701167
----------
10711168
name
1072-
A name to attach to this variable.
1169+
A name to attach to this variable
1170+
shape
1171+
A tuple of static sizes for each dimension of the variable. By default, each dimension length is `None` which
1172+
allows that dimension to change size across evaluations.
10731173
dtype
1074-
``None`` means to use `pytensor.config.floatX`.
1174+
Data type of tensor variable. By default, it's pytensor.config.floatX.
10751175
10761176
"""
10771177
if dtype is None:
10781178
dtype = config.floatX
1079-
type = TensorType(dtype, shape=(None,) * 6)
1179+
shape = _validate_static_shape(shape, ndim=6)
1180+
type = TensorType(dtype, shape=shape)
10801181
return type(name)
10811182

10821183

@@ -1095,21 +1196,35 @@ def tensor6(
10951196

10961197

10971198
def tensor7(
1098-
name: Optional[str] = None, dtype: Optional["DTypeLike"] = None
1199+
name: Optional[str] = None,
1200+
dtype: Optional["DTypeLike"] = None,
1201+
shape: Optional[Tuple[ST, ST, ST, ST, ST, ST, ST]] = (
1202+
None,
1203+
None,
1204+
None,
1205+
None,
1206+
None,
1207+
None,
1208+
None,
1209+
),
10991210
) -> "TensorVariable":
11001211
"""Return a symbolic 7-D variable.
11011212
11021213
Parameters
11031214
----------
11041215
name
1105-
A name to attach to this variable.
1216+
A name to attach to this variable
1217+
shape
1218+
A tuple of static sizes for each dimension of the variable. By default, each dimension length is `None` which
1219+
allows that dimension to change size across evaluations.
11061220
dtype
1107-
``None`` means to use `pytensor.config.floatX`.
1221+
Data type of tensor variable. By default, it's pytensor.config.floatX.
11081222
11091223
"""
11101224
if dtype is None:
11111225
dtype = config.floatX
1112-
type = TensorType(dtype, shape=(None,) * 7)
1226+
shape = _validate_static_shape(shape, ndim=7)
1227+
type = TensorType(dtype, shape=shape)
11131228
return type(name)
11141229

11151230

0 commit comments

Comments
 (0)