1
1
import logging
2
2
import warnings
3
- from typing import TYPE_CHECKING , Iterable , Optional , Tuple , Union
3
+ from typing import TYPE_CHECKING , Iterable , Literal , Optional , Tuple , Union
4
4
5
5
import numpy as np
6
6
@@ -775,9 +775,16 @@ def values_eq_approx_always_true(a, b):
775
775
)
776
776
777
777
778
- def tensor (* args , ** kwargs ):
778
+ def tensor (
779
+ dtype : Optional ["DTypeLike" ] = None ,
780
+ ** kwargs ,
781
+ ) -> "TensorVariable" :
782
+
783
+ if dtype is None :
784
+ dtype = config .floatX
785
+
779
786
name = kwargs .pop ("name" , None )
780
- return TensorType (* args , ** kwargs )(name = name )
787
+ return TensorType (dtype = dtype , ** kwargs )(name = name )
781
788
782
789
783
790
cscalar = TensorType ("complex64" , ())
@@ -794,7 +801,10 @@ def tensor(*args, **kwargs):
794
801
ulscalar = TensorType ("uint64" , ())
795
802
796
803
797
- def scalar (name = None , dtype = None ):
804
+ def scalar (
805
+ name : Optional [str ] = None ,
806
+ dtype : Optional ["DTypeLike" ] = None ,
807
+ ) -> "TensorVariable" :
798
808
"""Return a symbolic scalar variable.
799
809
800
810
Parameters
@@ -832,20 +842,47 @@ def scalar(name=None, dtype=None):
832
842
lvector = TensorType ("int64" , shape = (None ,))
833
843
834
844
835
- def vector (name = None , dtype = None ):
845
+ ST = Union [int , None ]
846
+
847
+
848
+ def _validate_static_shape (shape , ndim : int ) -> Tuple [ST , ...]:
849
+
850
+ if not isinstance (shape , tuple ):
851
+ raise TypeError (f"Shape must be a tuple, got { type (shape )} " )
852
+
853
+ if len (shape ) != ndim :
854
+ raise ValueError (f"Shape must be a tuple of length { ndim } , got { shape } " )
855
+
856
+ if not all (sh is None or isinstance (sh , int ) for sh in shape ):
857
+ raise TypeError (f"Shape entries must be None or integer, got { shape } " )
858
+
859
+ return shape
860
+
861
+
862
+ def vector (
863
+ name : Optional [str ] = None ,
864
+ dtype : Optional ["DTypeLike" ] = None ,
865
+ shape : Optional [Tuple [ST ]] = (None ,),
866
+ ) -> "TensorVariable" :
836
867
"""Return a symbolic vector variable.
837
868
838
869
Parameters
839
870
----------
840
- dtype: numeric
841
- None means to use pytensor.config.floatX.
842
871
name
843
872
A name to attach to this variable
873
+ shape
874
+ A tuple of static sizes for each dimension of the variable. By default, each dimension length is `None` which
875
+ allows that dimension to change size across evaluations.
876
+ dtype
877
+ Data type of tensor variable. By default, it's pytensor.config.floatX.
844
878
845
879
"""
846
880
if dtype is None :
847
881
dtype = config .floatX
848
- type = TensorType (dtype , shape = (None ,))
882
+
883
+ shape = _validate_static_shape (shape , ndim = 1 )
884
+
885
+ type = TensorType (dtype , shape = shape )
849
886
return type (name )
850
887
851
888
@@ -867,20 +904,28 @@ def vector(name=None, dtype=None):
867
904
lmatrix = TensorType ("int64" , shape = (None , None ))
868
905
869
906
870
- def matrix (name = None , dtype = None ):
907
+ def matrix (
908
+ name : Optional [str ] = None ,
909
+ dtype : Optional ["DTypeLike" ] = None ,
910
+ shape : Optional [Tuple [ST , ST ]] = (None , None ),
911
+ ) -> "TensorVariable" :
871
912
"""Return a symbolic matrix variable.
872
913
873
914
Parameters
874
915
----------
875
- dtype: numeric
876
- None means to use pytensor.config.floatX.
877
916
name
878
- A name to attach to this variable.
917
+ A name to attach to this variable
918
+ shape
919
+ A tuple of static sizes for each dimension of the variable. By default, each dimension length is `None` which
920
+ allows that dimension to change size across evaluations.
921
+ dtype
922
+ Data type of tensor variable. By default, it's pytensor.config.floatX.
879
923
880
924
"""
881
925
if dtype is None :
882
926
dtype = config .floatX
883
- type = TensorType (dtype , shape = (None , None ))
927
+ shape = _validate_static_shape (shape , ndim = 2 )
928
+ type = TensorType (dtype , shape = shape )
884
929
return type (name )
885
930
886
931
@@ -902,20 +947,34 @@ def matrix(name=None, dtype=None):
902
947
lrow = TensorType ("int64" , shape = (1 , None ))
903
948
904
949
905
- def row (name = None , dtype = None ):
950
+ def row (
951
+ name : Optional [str ] = None ,
952
+ dtype : Optional ["DTypeLike" ] = None ,
953
+ shape : Optional [Tuple [Literal [1 ], ST ]] = (1 , None ),
954
+ ) -> "TensorVariable" :
906
955
"""Return a symbolic row variable (i.e. shape ``(1, None)``).
907
956
908
957
Parameters
909
958
----------
910
- dtype: numeric type
911
- None means to use pytensor.config.floatX.
912
959
name
913
- A name to attach to this variable.
960
+ A name to attach to this variable
961
+ shape
962
+ A tuple of static sizes for each dimension of the variable. By default, each dimension length is `None` which
963
+ allows that dimension to change size across evaluations.
964
+ dtype
965
+ Data type of tensor variable. By default, it's pytensor.config.floatX.
914
966
915
967
"""
916
968
if dtype is None :
917
969
dtype = config .floatX
918
- type = TensorType (dtype , shape = (1 , None ))
970
+ shape = _validate_static_shape (shape , ndim = 2 )
971
+
972
+ if shape [0 ] != 1 :
973
+ raise ValueError (
974
+ f"The first dimension of a `row` must have shape 1, got { shape [0 ]} "
975
+ )
976
+
977
+ type = TensorType (dtype , shape = shape )
919
978
return type (name )
920
979
921
980
@@ -932,21 +991,31 @@ def row(name=None, dtype=None):
932
991
933
992
934
993
def col (
935
- name : Optional [str ] = None , dtype : Optional ["DTypeLike" ] = None
994
+ name : Optional [str ] = None ,
995
+ dtype : Optional ["DTypeLike" ] = None ,
996
+ shape : Optional [Tuple [ST , Literal [1 ]]] = (None , 1 ),
936
997
) -> "TensorVariable" :
937
998
"""Return a symbolic column variable (i.e. shape ``(None, 1)``).
938
999
939
1000
Parameters
940
1001
----------
941
1002
name
942
- A name to attach to this variable.
1003
+ A name to attach to this variable
1004
+ shape
1005
+ A tuple of static sizes for each dimension of the variable. By default, each dimension length is `None` which
1006
+ allows that dimension to change size across evaluations.
943
1007
dtype
944
- ``None`` means to use ` pytensor.config.floatX` .
1008
+ Data type of tensor variable. By default, it's pytensor.config.floatX.
945
1009
946
1010
"""
947
1011
if dtype is None :
948
1012
dtype = config .floatX
949
- type = TensorType (dtype , shape = (None , 1 ))
1013
+ shape = _validate_static_shape (shape , ndim = 2 )
1014
+ if shape [1 ] != 1 :
1015
+ raise ValueError (
1016
+ f"The second dimension of a `col` must have shape 1, got { shape [1 ]} "
1017
+ )
1018
+ type = TensorType (dtype , shape = shape )
950
1019
return type (name )
951
1020
952
1021
@@ -963,21 +1032,27 @@ def col(
963
1032
964
1033
965
1034
def tensor3 (
966
- name : Optional [str ] = None , dtype : Optional ["DTypeLike" ] = None
1035
+ name : Optional [str ] = None ,
1036
+ dtype : Optional ["DTypeLike" ] = None ,
1037
+ shape : Optional [Tuple [ST , ST , ST ]] = (None , None , None ),
967
1038
) -> "TensorVariable" :
968
1039
"""Return a symbolic 3D variable.
969
1040
970
1041
Parameters
971
1042
----------
972
1043
name
973
- A name to attach to this variable.
1044
+ A name to attach to this variable
1045
+ shape
1046
+ A tuple of static sizes for each dimension of the variable. By default, each dimension length is `None` which
1047
+ allows that dimension to change size across evaluations.
974
1048
dtype
975
- ``None`` means to use ` pytensor.config.floatX` .
1049
+ Data type of tensor variable. By default, it's pytensor.config.floatX.
976
1050
977
1051
"""
978
1052
if dtype is None :
979
1053
dtype = config .floatX
980
- type = TensorType (dtype , shape = (None , None , None ))
1054
+ shape = _validate_static_shape (shape , ndim = 3 )
1055
+ type = TensorType (dtype , shape = shape )
981
1056
return type (name )
982
1057
983
1058
@@ -996,21 +1071,27 @@ def tensor3(
996
1071
997
1072
998
1073
def tensor4 (
999
- name : Optional [str ] = None , dtype : Optional ["DTypeLike" ] = None
1074
+ name : Optional [str ] = None ,
1075
+ dtype : Optional ["DTypeLike" ] = None ,
1076
+ shape : Optional [Tuple [ST , ST , ST , ST ]] = (None , None , None , None ),
1000
1077
) -> "TensorVariable" :
1001
1078
"""Return a symbolic 4D variable.
1002
1079
1003
1080
Parameters
1004
1081
----------
1005
1082
name
1006
- A name to attach to this variable.
1083
+ A name to attach to this variable
1084
+ shape
1085
+ A tuple of static sizes for each dimension of the variable. By default, each dimension length is `None` which
1086
+ allows that dimension to change size across evaluations.
1007
1087
dtype
1008
- ``None`` means to use ` pytensor.config.floatX` .
1088
+ Data type of tensor variable. By default, it's pytensor.config.floatX.
1009
1089
1010
1090
"""
1011
1091
if dtype is None :
1012
1092
dtype = config .floatX
1013
- type = TensorType (dtype , shape = (None , None , None , None ))
1093
+ shape = _validate_static_shape (shape , ndim = 4 )
1094
+ type = TensorType (dtype , shape = shape )
1014
1095
return type (name )
1015
1096
1016
1097
@@ -1029,21 +1110,27 @@ def tensor4(
1029
1110
1030
1111
1031
1112
def tensor5 (
1032
- name : Optional [str ] = None , dtype : Optional ["DTypeLike" ] = None
1113
+ name : Optional [str ] = None ,
1114
+ dtype : Optional ["DTypeLike" ] = None ,
1115
+ shape : Optional [Tuple [ST , ST , ST , ST , ST ]] = (None , None , None , None , None ),
1033
1116
) -> "TensorVariable" :
1034
1117
"""Return a symbolic 5D variable.
1035
1118
1036
1119
Parameters
1037
1120
----------
1038
1121
name
1039
- A name to attach to this variable.
1122
+ A name to attach to this variable
1123
+ shape
1124
+ A tuple of static sizes for each dimension of the variable. By default, each dimension length is `None` which
1125
+ allows that dimension to change size across evaluations.
1040
1126
dtype
1041
- ``None`` means to use ` pytensor.config.floatX` .
1127
+ Data type of tensor variable. By default, it's pytensor.config.floatX.
1042
1128
1043
1129
"""
1044
1130
if dtype is None :
1045
1131
dtype = config .floatX
1046
- type = TensorType (dtype , shape = (None , None , None , None , None ))
1132
+ shape = _validate_static_shape (shape , ndim = 5 )
1133
+ type = TensorType (dtype , shape = shape )
1047
1134
return type (name )
1048
1135
1049
1136
@@ -1062,21 +1149,34 @@ def tensor5(
1062
1149
1063
1150
1064
1151
def tensor6 (
1065
- name : Optional [str ] = None , dtype : Optional ["DTypeLike" ] = None
1152
+ name : Optional [str ] = None ,
1153
+ dtype : Optional ["DTypeLike" ] = None ,
1154
+ shape : Optional [Tuple [ST , ST , ST , ST , ST , ST ]] = (
1155
+ None ,
1156
+ None ,
1157
+ None ,
1158
+ None ,
1159
+ None ,
1160
+ None ,
1161
+ ),
1066
1162
) -> "TensorVariable" :
1067
1163
"""Return a symbolic 6D variable.
1068
1164
1069
1165
Parameters
1070
1166
----------
1071
1167
name
1072
- A name to attach to this variable.
1168
+ A name to attach to this variable
1169
+ shape
1170
+ A tuple of static sizes for each dimension of the variable. By default, each dimension length is `None` which
1171
+ allows that dimension to change size across evaluations.
1073
1172
dtype
1074
- ``None`` means to use ` pytensor.config.floatX` .
1173
+ Data type of tensor variable. By default, it's pytensor.config.floatX.
1075
1174
1076
1175
"""
1077
1176
if dtype is None :
1078
1177
dtype = config .floatX
1079
- type = TensorType (dtype , shape = (None ,) * 6 )
1178
+ shape = _validate_static_shape (shape , ndim = 6 )
1179
+ type = TensorType (dtype , shape = shape )
1080
1180
return type (name )
1081
1181
1082
1182
@@ -1095,21 +1195,35 @@ def tensor6(
1095
1195
1096
1196
1097
1197
def tensor7 (
1098
- name : Optional [str ] = None , dtype : Optional ["DTypeLike" ] = None
1198
+ name : Optional [str ] = None ,
1199
+ dtype : Optional ["DTypeLike" ] = None ,
1200
+ shape : Optional [Tuple [ST , ST , ST , ST , ST , ST , ST ]] = (
1201
+ None ,
1202
+ None ,
1203
+ None ,
1204
+ None ,
1205
+ None ,
1206
+ None ,
1207
+ None ,
1208
+ ),
1099
1209
) -> "TensorVariable" :
1100
1210
"""Return a symbolic 7-D variable.
1101
1211
1102
1212
Parameters
1103
1213
----------
1104
1214
name
1105
- A name to attach to this variable.
1215
+ A name to attach to this variable
1216
+ shape
1217
+ A tuple of static sizes for each dimension of the variable. By default, each dimension length is `None` which
1218
+ allows that dimension to change size across evaluations.
1106
1219
dtype
1107
- ``None`` means to use ` pytensor.config.floatX` .
1220
+ Data type of tensor variable. By default, it's pytensor.config.floatX.
1108
1221
1109
1222
"""
1110
1223
if dtype is None :
1111
1224
dtype = config .floatX
1112
- type = TensorType (dtype , shape = (None ,) * 7 )
1225
+ shape = _validate_static_shape (shape , ndim = 7 )
1226
+ type = TensorType (dtype , shape = shape )
1113
1227
return type (name )
1114
1228
1115
1229
0 commit comments