Skip to content

Commit c58aef3

Browse files
committed
Allow passing static shape to tensor creation helpers
* Also default dtype to "floatX" when using `tensor`
1 parent 16d1cbe commit c58aef3

File tree

2 files changed

+260
-43
lines changed

2 files changed

+260
-43
lines changed

pytensor/tensor/type.py

Lines changed: 156 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import logging
22
import warnings
3-
from typing import TYPE_CHECKING, Iterable, Optional, Tuple, Union
3+
from typing import TYPE_CHECKING, Iterable, Literal, Optional, Tuple, Union
44

55
import numpy as np
66

@@ -775,9 +775,16 @@ def values_eq_approx_always_true(a, b):
775775
)
776776

777777

778-
def tensor(*args, **kwargs):
778+
def tensor(
779+
dtype: Optional["DTypeLike"] = None,
780+
**kwargs,
781+
) -> "TensorVariable":
782+
783+
if dtype is None:
784+
dtype = config.floatX
785+
779786
name = kwargs.pop("name", None)
780-
return TensorType(*args, **kwargs)(name=name)
787+
return TensorType(dtype=dtype, **kwargs)(name=name)
781788

782789

783790
cscalar = TensorType("complex64", ())
@@ -794,7 +801,10 @@ def tensor(*args, **kwargs):
794801
ulscalar = TensorType("uint64", ())
795802

796803

797-
def scalar(name=None, dtype=None):
804+
def scalar(
805+
name: Optional[str] = None,
806+
dtype: Optional["DTypeLike"] = None,
807+
) -> "TensorVariable":
798808
"""Return a symbolic scalar variable.
799809
800810
Parameters
@@ -832,20 +842,47 @@ def scalar(name=None, dtype=None):
832842
lvector = TensorType("int64", shape=(None,))
833843

834844

835-
def vector(name=None, dtype=None):
845+
ST = Union[int, None]
846+
847+
848+
def _validate_static_shape(shape, ndim: int) -> Tuple[ST, ...]:
849+
850+
if not isinstance(shape, tuple):
851+
raise TypeError(f"Shape must be a tuple, got {type(shape)}")
852+
853+
if len(shape) != ndim:
854+
raise ValueError(f"Shape must be a tuple of length {ndim}, got {shape}")
855+
856+
if not all(sh is None or isinstance(sh, int) for sh in shape):
857+
raise TypeError(f"Shape entries must be None or integer, got {shape}")
858+
859+
return shape
860+
861+
862+
def vector(
863+
name: Optional[str] = None,
864+
dtype: Optional["DTypeLike"] = None,
865+
shape: Optional[Tuple[ST]] = (None,),
866+
) -> "TensorVariable":
836867
"""Return a symbolic vector variable.
837868
838869
Parameters
839870
----------
840-
dtype: numeric
841-
None means to use pytensor.config.floatX.
842871
name
843872
A name to attach to this variable
873+
shape
874+
A tuple of static sizes for each dimension of the variable. By default, each dimension length is `None` which
875+
allows that dimension to change size across evaluations.
876+
dtype
877+
Data type of tensor variable. By default, it's pytensor.config.floatX.
844878
845879
"""
846880
if dtype is None:
847881
dtype = config.floatX
848-
type = TensorType(dtype, shape=(None,))
882+
883+
shape = _validate_static_shape(shape, ndim=1)
884+
885+
type = TensorType(dtype, shape=shape)
849886
return type(name)
850887

851888

@@ -867,20 +904,28 @@ def vector(name=None, dtype=None):
867904
lmatrix = TensorType("int64", shape=(None, None))
868905

869906

870-
def matrix(name=None, dtype=None):
907+
def matrix(
908+
name: Optional[str] = None,
909+
dtype: Optional["DTypeLike"] = None,
910+
shape: Optional[Tuple[ST, ST]] = (None, None),
911+
) -> "TensorVariable":
871912
"""Return a symbolic matrix variable.
872913
873914
Parameters
874915
----------
875-
dtype: numeric
876-
None means to use pytensor.config.floatX.
877916
name
878-
A name to attach to this variable.
917+
A name to attach to this variable
918+
shape
919+
A tuple of static sizes for each dimension of the variable. By default, each dimension length is `None` which
920+
allows that dimension to change size across evaluations.
921+
dtype
922+
Data type of tensor variable. By default, it's pytensor.config.floatX.
879923
880924
"""
881925
if dtype is None:
882926
dtype = config.floatX
883-
type = TensorType(dtype, shape=(None, None))
927+
shape = _validate_static_shape(shape, ndim=2)
928+
type = TensorType(dtype, shape=shape)
884929
return type(name)
885930

886931

@@ -902,20 +947,34 @@ def matrix(name=None, dtype=None):
902947
lrow = TensorType("int64", shape=(1, None))
903948

904949

905-
def row(name=None, dtype=None):
950+
def row(
951+
name: Optional[str] = None,
952+
dtype: Optional["DTypeLike"] = None,
953+
shape: Optional[Tuple[Literal[1], ST]] = (1, None),
954+
) -> "TensorVariable":
906955
"""Return a symbolic row variable (i.e. shape ``(1, None)``).
907956
908957
Parameters
909958
----------
910-
dtype: numeric type
911-
None means to use pytensor.config.floatX.
912959
name
913-
A name to attach to this variable.
960+
A name to attach to this variable
961+
shape
962+
A tuple of static sizes for each dimension of the variable. By default, each dimension length is `None` which
963+
allows that dimension to change size across evaluations.
964+
dtype
965+
Data type of tensor variable. By default, it's pytensor.config.floatX.
914966
915967
"""
916968
if dtype is None:
917969
dtype = config.floatX
918-
type = TensorType(dtype, shape=(1, None))
970+
shape = _validate_static_shape(shape, ndim=2)
971+
972+
if shape[0] != 1:
973+
raise ValueError(
974+
f"The first dimension of a `row` must have shape 1, got {shape[0]}"
975+
)
976+
977+
type = TensorType(dtype, shape=shape)
919978
return type(name)
920979

921980

@@ -932,21 +991,31 @@ def row(name=None, dtype=None):
932991

933992

934993
def col(
935-
name: Optional[str] = None, dtype: Optional["DTypeLike"] = None
994+
name: Optional[str] = None,
995+
dtype: Optional["DTypeLike"] = None,
996+
shape: Optional[Tuple[ST, Literal[1]]] = (None, 1),
936997
) -> "TensorVariable":
937998
"""Return a symbolic column variable (i.e. shape ``(None, 1)``).
938999
9391000
Parameters
9401001
----------
9411002
name
942-
A name to attach to this variable.
1003+
A name to attach to this variable
1004+
shape
1005+
A tuple of static sizes for each dimension of the variable. By default, each dimension length is `None` which
1006+
allows that dimension to change size across evaluations.
9431007
dtype
944-
``None`` means to use `pytensor.config.floatX`.
1008+
Data type of tensor variable. By default, it's pytensor.config.floatX.
9451009
9461010
"""
9471011
if dtype is None:
9481012
dtype = config.floatX
949-
type = TensorType(dtype, shape=(None, 1))
1013+
shape = _validate_static_shape(shape, ndim=2)
1014+
if shape[1] != 1:
1015+
raise ValueError(
1016+
f"The second dimension of a `col` must have shape 1, got {shape[1]}"
1017+
)
1018+
type = TensorType(dtype, shape=shape)
9501019
return type(name)
9511020

9521021

@@ -963,21 +1032,27 @@ def col(
9631032

9641033

9651034
def tensor3(
966-
name: Optional[str] = None, dtype: Optional["DTypeLike"] = None
1035+
name: Optional[str] = None,
1036+
dtype: Optional["DTypeLike"] = None,
1037+
shape: Optional[Tuple[ST, ST, ST]] = (None, None, None),
9671038
) -> "TensorVariable":
9681039
"""Return a symbolic 3D variable.
9691040
9701041
Parameters
9711042
----------
9721043
name
973-
A name to attach to this variable.
1044+
A name to attach to this variable
1045+
shape
1046+
A tuple of static sizes for each dimension of the variable. By default, each dimension length is `None` which
1047+
allows that dimension to change size across evaluations.
9741048
dtype
975-
``None`` means to use `pytensor.config.floatX`.
1049+
Data type of tensor variable. By default, it's pytensor.config.floatX.
9761050
9771051
"""
9781052
if dtype is None:
9791053
dtype = config.floatX
980-
type = TensorType(dtype, shape=(None, None, None))
1054+
shape = _validate_static_shape(shape, ndim=3)
1055+
type = TensorType(dtype, shape=shape)
9811056
return type(name)
9821057

9831058

@@ -996,21 +1071,27 @@ def tensor3(
9961071

9971072

9981073
def tensor4(
999-
name: Optional[str] = None, dtype: Optional["DTypeLike"] = None
1074+
name: Optional[str] = None,
1075+
dtype: Optional["DTypeLike"] = None,
1076+
shape: Optional[Tuple[ST, ST, ST, ST]] = (None, None, None, None),
10001077
) -> "TensorVariable":
10011078
"""Return a symbolic 4D variable.
10021079
10031080
Parameters
10041081
----------
10051082
name
1006-
A name to attach to this variable.
1083+
A name to attach to this variable
1084+
shape
1085+
A tuple of static sizes for each dimension of the variable. By default, each dimension length is `None` which
1086+
allows that dimension to change size across evaluations.
10071087
dtype
1008-
``None`` means to use `pytensor.config.floatX`.
1088+
Data type of tensor variable. By default, it's pytensor.config.floatX.
10091089
10101090
"""
10111091
if dtype is None:
10121092
dtype = config.floatX
1013-
type = TensorType(dtype, shape=(None, None, None, None))
1093+
shape = _validate_static_shape(shape, ndim=4)
1094+
type = TensorType(dtype, shape=shape)
10141095
return type(name)
10151096

10161097

@@ -1029,21 +1110,27 @@ def tensor4(
10291110

10301111

10311112
def tensor5(
1032-
name: Optional[str] = None, dtype: Optional["DTypeLike"] = None
1113+
name: Optional[str] = None,
1114+
dtype: Optional["DTypeLike"] = None,
1115+
shape: Optional[Tuple[ST, ST, ST, ST, ST]] = (None, None, None, None, None),
10331116
) -> "TensorVariable":
10341117
"""Return a symbolic 5D variable.
10351118
10361119
Parameters
10371120
----------
10381121
name
1039-
A name to attach to this variable.
1122+
A name to attach to this variable
1123+
shape
1124+
A tuple of static sizes for each dimension of the variable. By default, each dimension length is `None` which
1125+
allows that dimension to change size across evaluations.
10401126
dtype
1041-
``None`` means to use `pytensor.config.floatX`.
1127+
Data type of tensor variable. By default, it's pytensor.config.floatX.
10421128
10431129
"""
10441130
if dtype is None:
10451131
dtype = config.floatX
1046-
type = TensorType(dtype, shape=(None, None, None, None, None))
1132+
shape = _validate_static_shape(shape, ndim=5)
1133+
type = TensorType(dtype, shape=shape)
10471134
return type(name)
10481135

10491136

@@ -1062,21 +1149,34 @@ def tensor5(
10621149

10631150

10641151
def tensor6(
1065-
name: Optional[str] = None, dtype: Optional["DTypeLike"] = None
1152+
name: Optional[str] = None,
1153+
dtype: Optional["DTypeLike"] = None,
1154+
shape: Optional[Tuple[ST, ST, ST, ST, ST, ST]] = (
1155+
None,
1156+
None,
1157+
None,
1158+
None,
1159+
None,
1160+
None,
1161+
),
10661162
) -> "TensorVariable":
10671163
"""Return a symbolic 6D variable.
10681164
10691165
Parameters
10701166
----------
10711167
name
1072-
A name to attach to this variable.
1168+
A name to attach to this variable
1169+
shape
1170+
A tuple of static sizes for each dimension of the variable. By default, each dimension length is `None` which
1171+
allows that dimension to change size across evaluations.
10731172
dtype
1074-
``None`` means to use `pytensor.config.floatX`.
1173+
Data type of tensor variable. By default, it's pytensor.config.floatX.
10751174
10761175
"""
10771176
if dtype is None:
10781177
dtype = config.floatX
1079-
type = TensorType(dtype, shape=(None,) * 6)
1178+
shape = _validate_static_shape(shape, ndim=6)
1179+
type = TensorType(dtype, shape=shape)
10801180
return type(name)
10811181

10821182

@@ -1095,21 +1195,35 @@ def tensor6(
10951195

10961196

10971197
def tensor7(
1098-
name: Optional[str] = None, dtype: Optional["DTypeLike"] = None
1198+
name: Optional[str] = None,
1199+
dtype: Optional["DTypeLike"] = None,
1200+
shape: Optional[Tuple[ST, ST, ST, ST, ST, ST, ST]] = (
1201+
None,
1202+
None,
1203+
None,
1204+
None,
1205+
None,
1206+
None,
1207+
None,
1208+
),
10991209
) -> "TensorVariable":
11001210
"""Return a symbolic 7-D variable.
11011211
11021212
Parameters
11031213
----------
11041214
name
1105-
A name to attach to this variable.
1215+
A name to attach to this variable
1216+
shape
1217+
A tuple of static sizes for each dimension of the variable. By default, each dimension length is `None` which
1218+
allows that dimension to change size across evaluations.
11061219
dtype
1107-
``None`` means to use `pytensor.config.floatX`.
1220+
Data type of tensor variable. By default, it's pytensor.config.floatX.
11081221
11091222
"""
11101223
if dtype is None:
11111224
dtype = config.floatX
1112-
type = TensorType(dtype, shape=(None,) * 7)
1225+
shape = _validate_static_shape(shape, ndim=7)
1226+
type = TensorType(dtype, shape=shape)
11131227
return type(name)
11141228

11151229

0 commit comments

Comments
 (0)