3
3
from typing import TYPE_CHECKING , Iterable , Optional , Tuple , Union
4
4
5
5
import numpy as np
6
+ from typing_extensions import Literal
6
7
7
8
import pytensor
8
9
from pytensor import scalar as aes
@@ -775,9 +776,16 @@ def values_eq_approx_always_true(a, b):
775
776
)
776
777
777
778
778
- def tensor (* args , ** kwargs ):
779
+ def tensor (
780
+ dtype : Optional ["DTypeLike" ] = None ,
781
+ ** kwargs ,
782
+ ) -> "TensorVariable" :
783
+
784
+ if dtype is None :
785
+ dtype = config .floatX
786
+
779
787
name = kwargs .pop ("name" , None )
780
- return TensorType (* args , ** kwargs )(name = name )
788
+ return TensorType (dtype = dtype , ** kwargs )(name = name )
781
789
782
790
783
791
cscalar = TensorType ("complex64" , ())
@@ -794,7 +802,10 @@ def tensor(*args, **kwargs):
794
802
ulscalar = TensorType ("uint64" , ())
795
803
796
804
797
- def scalar (name = None , dtype = None ):
805
+ def scalar (
806
+ name : Optional [str ] = None ,
807
+ dtype : Optional ["DTypeLike" ] = None ,
808
+ ) -> "TensorVariable" :
798
809
"""Return a symbolic scalar variable.
799
810
800
811
Parameters
@@ -832,20 +843,47 @@ def scalar(name=None, dtype=None):
832
843
lvector = TensorType ("int64" , shape = (None ,))
833
844
834
845
835
- def vector (name = None , dtype = None ):
846
+ ST = Union [int , None ]
847
+
848
+
849
+ def _validate_static_shape (shape , ndim : int ) -> Tuple [ST , ...]:
850
+
851
+ if not isinstance (shape , tuple ):
852
+ raise TypeError (f"Shape must be a tuple, got { type (shape )} " )
853
+
854
+ if len (shape ) != ndim :
855
+ raise ValueError (f"Shape must be a tuple of length { ndim } , got { shape } " )
856
+
857
+ if not all (sh is None or isinstance (sh , int ) for sh in shape ):
858
+ raise TypeError (f"Shape entries must be None or integer, got { shape } " )
859
+
860
+ return shape
861
+
862
+
863
+ def vector (
864
+ name : Optional [str ] = None ,
865
+ dtype : Optional ["DTypeLike" ] = None ,
866
+ shape : Optional [Tuple [ST ]] = (None ,),
867
+ ) -> "TensorVariable" :
836
868
"""Return a symbolic vector variable.
837
869
838
870
Parameters
839
871
----------
840
- dtype: numeric
841
- None means to use pytensor.config.floatX.
842
872
name
843
873
A name to attach to this variable
874
+ shape
875
+ A tuple of static sizes for each dimension of the variable. By default, each dimension length is `None` which
876
+ allows that dimension to change size across evaluations.
877
+ dtype
878
+ Data type of tensor variable. By default, it's pytensor.config.floatX.
844
879
845
880
"""
846
881
if dtype is None :
847
882
dtype = config .floatX
848
- type = TensorType (dtype , shape = (None ,))
883
+
884
+ shape = _validate_static_shape (shape , ndim = 1 )
885
+
886
+ type = TensorType (dtype , shape = shape )
849
887
return type (name )
850
888
851
889
@@ -867,20 +905,28 @@ def vector(name=None, dtype=None):
867
905
lmatrix = TensorType ("int64" , shape = (None , None ))
868
906
869
907
870
- def matrix (name = None , dtype = None ):
908
+ def matrix (
909
+ name : Optional [str ] = None ,
910
+ dtype : Optional ["DTypeLike" ] = None ,
911
+ shape : Optional [Tuple [ST , ST ]] = (None , None ),
912
+ ) -> "TensorVariable" :
871
913
"""Return a symbolic matrix variable.
872
914
873
915
Parameters
874
916
----------
875
- dtype: numeric
876
- None means to use pytensor.config.floatX.
877
917
name
878
- A name to attach to this variable.
918
+ A name to attach to this variable
919
+ shape
920
+ A tuple of static sizes for each dimension of the variable. By default, each dimension length is `None` which
921
+ allows that dimension to change size across evaluations.
922
+ dtype
923
+ Data type of tensor variable. By default, it's pytensor.config.floatX.
879
924
880
925
"""
881
926
if dtype is None :
882
927
dtype = config .floatX
883
- type = TensorType (dtype , shape = (None , None ))
928
+ shape = _validate_static_shape (shape , ndim = 2 )
929
+ type = TensorType (dtype , shape = shape )
884
930
return type (name )
885
931
886
932
@@ -902,20 +948,34 @@ def matrix(name=None, dtype=None):
902
948
lrow = TensorType ("int64" , shape = (1 , None ))
903
949
904
950
905
- def row (name = None , dtype = None ):
951
+ def row (
952
+ name : Optional [str ] = None ,
953
+ dtype : Optional ["DTypeLike" ] = None ,
954
+ shape : Optional [Tuple [Literal [1 ], ST ]] = (1 , None ),
955
+ ) -> "TensorVariable" :
906
956
"""Return a symbolic row variable (i.e. shape ``(1, None)``).
907
957
908
958
Parameters
909
959
----------
910
- dtype: numeric type
911
- None means to use pytensor.config.floatX.
912
960
name
913
- A name to attach to this variable.
961
+ A name to attach to this variable
962
+ shape
963
+ A tuple of static sizes for each dimension of the variable. By default, each dimension length is `None` which
964
+ allows that dimension to change size across evaluations.
965
+ dtype
966
+ Data type of tensor variable. By default, it's pytensor.config.floatX.
914
967
915
968
"""
916
969
if dtype is None :
917
970
dtype = config .floatX
918
- type = TensorType (dtype , shape = (1 , None ))
971
+ shape = _validate_static_shape (shape , ndim = 2 )
972
+
973
+ if shape [0 ] != 1 :
974
+ raise ValueError (
975
+ f"The first dimension of a `row` must have shape 1, got { shape [0 ]} "
976
+ )
977
+
978
+ type = TensorType (dtype , shape = shape )
919
979
return type (name )
920
980
921
981
@@ -932,21 +992,31 @@ def row(name=None, dtype=None):
932
992
933
993
934
994
def col (
935
- name : Optional [str ] = None , dtype : Optional ["DTypeLike" ] = None
995
+ name : Optional [str ] = None ,
996
+ dtype : Optional ["DTypeLike" ] = None ,
997
+ shape : Optional [Tuple [ST , Literal [1 ]]] = (None , 1 ),
936
998
) -> "TensorVariable" :
937
999
"""Return a symbolic column variable (i.e. shape ``(None, 1)``).
938
1000
939
1001
Parameters
940
1002
----------
941
1003
name
942
- A name to attach to this variable.
1004
+ A name to attach to this variable
1005
+ shape
1006
+ A tuple of static sizes for each dimension of the variable. By default, each dimension length is `None` which
1007
+ allows that dimension to change size across evaluations.
943
1008
dtype
944
- ``None`` means to use ` pytensor.config.floatX` .
1009
+ Data type of tensor variable. By default, it's pytensor.config.floatX.
945
1010
946
1011
"""
947
1012
if dtype is None :
948
1013
dtype = config .floatX
949
- type = TensorType (dtype , shape = (None , 1 ))
1014
+ shape = _validate_static_shape (shape , ndim = 2 )
1015
+ if shape [1 ] != 1 :
1016
+ raise ValueError (
1017
+ f"The second dimension of a `col` must have shape 1, got { shape [1 ]} "
1018
+ )
1019
+ type = TensorType (dtype , shape = shape )
950
1020
return type (name )
951
1021
952
1022
@@ -963,21 +1033,27 @@ def col(
963
1033
964
1034
965
1035
def tensor3 (
966
- name : Optional [str ] = None , dtype : Optional ["DTypeLike" ] = None
1036
+ name : Optional [str ] = None ,
1037
+ dtype : Optional ["DTypeLike" ] = None ,
1038
+ shape : Optional [Tuple [ST , ST , ST ]] = (None , None , None ),
967
1039
) -> "TensorVariable" :
968
1040
"""Return a symbolic 3D variable.
969
1041
970
1042
Parameters
971
1043
----------
972
1044
name
973
- A name to attach to this variable.
1045
+ A name to attach to this variable
1046
+ shape
1047
+ A tuple of static sizes for each dimension of the variable. By default, each dimension length is `None` which
1048
+ allows that dimension to change size across evaluations.
974
1049
dtype
975
- ``None`` means to use ` pytensor.config.floatX` .
1050
+ Data type of tensor variable. By default, it's pytensor.config.floatX.
976
1051
977
1052
"""
978
1053
if dtype is None :
979
1054
dtype = config .floatX
980
- type = TensorType (dtype , shape = (None , None , None ))
1055
+ shape = _validate_static_shape (shape , ndim = 3 )
1056
+ type = TensorType (dtype , shape = shape )
981
1057
return type (name )
982
1058
983
1059
@@ -996,21 +1072,27 @@ def tensor3(
996
1072
997
1073
998
1074
def tensor4 (
999
- name : Optional [str ] = None , dtype : Optional ["DTypeLike" ] = None
1075
+ name : Optional [str ] = None ,
1076
+ dtype : Optional ["DTypeLike" ] = None ,
1077
+ shape : Optional [Tuple [ST , ST , ST , ST ]] = (None , None , None , None ),
1000
1078
) -> "TensorVariable" :
1001
1079
"""Return a symbolic 4D variable.
1002
1080
1003
1081
Parameters
1004
1082
----------
1005
1083
name
1006
- A name to attach to this variable.
1084
+ A name to attach to this variable
1085
+ shape
1086
+ A tuple of static sizes for each dimension of the variable. By default, each dimension length is `None` which
1087
+ allows that dimension to change size across evaluations.
1007
1088
dtype
1008
- ``None`` means to use ` pytensor.config.floatX` .
1089
+ Data type of tensor variable. By default, it's pytensor.config.floatX.
1009
1090
1010
1091
"""
1011
1092
if dtype is None :
1012
1093
dtype = config .floatX
1013
- type = TensorType (dtype , shape = (None , None , None , None ))
1094
+ shape = _validate_static_shape (shape , ndim = 4 )
1095
+ type = TensorType (dtype , shape = shape )
1014
1096
return type (name )
1015
1097
1016
1098
@@ -1029,21 +1111,27 @@ def tensor4(
1029
1111
1030
1112
1031
1113
def tensor5 (
1032
- name : Optional [str ] = None , dtype : Optional ["DTypeLike" ] = None
1114
+ name : Optional [str ] = None ,
1115
+ dtype : Optional ["DTypeLike" ] = None ,
1116
+ shape : Optional [Tuple [ST , ST , ST , ST , ST ]] = (None , None , None , None , None ),
1033
1117
) -> "TensorVariable" :
1034
1118
"""Return a symbolic 5D variable.
1035
1119
1036
1120
Parameters
1037
1121
----------
1038
1122
name
1039
- A name to attach to this variable.
1123
+ A name to attach to this variable
1124
+ shape
1125
+ A tuple of static sizes for each dimension of the variable. By default, each dimension length is `None` which
1126
+ allows that dimension to change size across evaluations.
1040
1127
dtype
1041
- ``None`` means to use ` pytensor.config.floatX` .
1128
+ Data type of tensor variable. By default, it's pytensor.config.floatX.
1042
1129
1043
1130
"""
1044
1131
if dtype is None :
1045
1132
dtype = config .floatX
1046
- type = TensorType (dtype , shape = (None , None , None , None , None ))
1133
+ shape = _validate_static_shape (shape , ndim = 5 )
1134
+ type = TensorType (dtype , shape = shape )
1047
1135
return type (name )
1048
1136
1049
1137
@@ -1062,21 +1150,34 @@ def tensor5(
1062
1150
1063
1151
1064
1152
def tensor6 (
1065
- name : Optional [str ] = None , dtype : Optional ["DTypeLike" ] = None
1153
+ name : Optional [str ] = None ,
1154
+ dtype : Optional ["DTypeLike" ] = None ,
1155
+ shape : Optional [Tuple [ST , ST , ST , ST , ST , ST ]] = (
1156
+ None ,
1157
+ None ,
1158
+ None ,
1159
+ None ,
1160
+ None ,
1161
+ None ,
1162
+ ),
1066
1163
) -> "TensorVariable" :
1067
1164
"""Return a symbolic 6D variable.
1068
1165
1069
1166
Parameters
1070
1167
----------
1071
1168
name
1072
- A name to attach to this variable.
1169
+ A name to attach to this variable
1170
+ shape
1171
+ A tuple of static sizes for each dimension of the variable. By default, each dimension length is `None` which
1172
+ allows that dimension to change size across evaluations.
1073
1173
dtype
1074
- ``None`` means to use ` pytensor.config.floatX` .
1174
+ Data type of tensor variable. By default, it's pytensor.config.floatX.
1075
1175
1076
1176
"""
1077
1177
if dtype is None :
1078
1178
dtype = config .floatX
1079
- type = TensorType (dtype , shape = (None ,) * 6 )
1179
+ shape = _validate_static_shape (shape , ndim = 6 )
1180
+ type = TensorType (dtype , shape = shape )
1080
1181
return type (name )
1081
1182
1082
1183
@@ -1095,21 +1196,35 @@ def tensor6(
1095
1196
1096
1197
1097
1198
def tensor7 (
1098
- name : Optional [str ] = None , dtype : Optional ["DTypeLike" ] = None
1199
+ name : Optional [str ] = None ,
1200
+ dtype : Optional ["DTypeLike" ] = None ,
1201
+ shape : Optional [Tuple [ST , ST , ST , ST , ST , ST , ST ]] = (
1202
+ None ,
1203
+ None ,
1204
+ None ,
1205
+ None ,
1206
+ None ,
1207
+ None ,
1208
+ None ,
1209
+ ),
1099
1210
) -> "TensorVariable" :
1100
1211
"""Return a symbolic 7-D variable.
1101
1212
1102
1213
Parameters
1103
1214
----------
1104
1215
name
1105
- A name to attach to this variable.
1216
+ A name to attach to this variable
1217
+ shape
1218
+ A tuple of static sizes for each dimension of the variable. By default, each dimension length is `None` which
1219
+ allows that dimension to change size across evaluations.
1106
1220
dtype
1107
- ``None`` means to use ` pytensor.config.floatX` .
1221
+ Data type of tensor variable. By default, it's pytensor.config.floatX.
1108
1222
1109
1223
"""
1110
1224
if dtype is None :
1111
1225
dtype = config .floatX
1112
- type = TensorType (dtype , shape = (None ,) * 7 )
1226
+ shape = _validate_static_shape (shape , ndim = 7 )
1227
+ type = TensorType (dtype , shape = shape )
1113
1228
return type (name )
1114
1229
1115
1230
0 commit comments