3
3
from typing import TYPE_CHECKING , Iterable , Optional , Tuple , Union
4
4
5
5
import numpy as np
6
+ from typing_extensions import Literal
6
7
7
8
import pytensor
8
9
from pytensor import scalar as aes
@@ -775,9 +776,17 @@ def values_eq_approx_always_true(a, b):
775
776
)
776
777
777
778
778
- def tensor (* args , ** kwargs ):
779
+ def tensor (
780
+ dtype : Optional ["DTypeLike" ] = None ,
781
+ * args ,
782
+ ** kwargs ,
783
+ ) -> "TensorVariable" :
784
+
785
+ if dtype is None :
786
+ dtype = config .floatX
787
+
779
788
name = kwargs .pop ("name" , None )
780
- return TensorType (* args , ** kwargs )(name = name )
789
+ return TensorType (dtype , * args , ** kwargs )(name = name )
781
790
782
791
783
792
cscalar = TensorType ("complex64" , ())
@@ -794,7 +803,10 @@ def tensor(*args, **kwargs):
794
803
ulscalar = TensorType ("uint64" , ())
795
804
796
805
797
- def scalar (name = None , dtype = None ):
806
+ def scalar (
807
+ name : Optional [str ] = None ,
808
+ dtype : Optional ["DTypeLike" ] = None ,
809
+ ) -> "TensorVariable" :
798
810
"""Return a symbolic scalar variable.
799
811
800
812
Parameters
@@ -832,20 +844,47 @@ def scalar(name=None, dtype=None):
832
844
lvector = TensorType ("int64" , shape = (None ,))
833
845
834
846
835
- def vector (name = None , dtype = None ):
847
+ ST = Union [int , None ]
848
+
849
+
850
+ def _validate_static_shape (shape , ndim : int ) -> Tuple [ST , ...]:
851
+
852
+ if not isinstance (shape , tuple ):
853
+ raise TypeError (f"Shape must be a tuple, got { type (shape )} " )
854
+
855
+ if len (shape ) != ndim :
856
+ raise ValueError (f"Shape must be a tuple of length { ndim } , got { shape } " )
857
+
858
+ if not all (sh is None or isinstance (sh , int ) for sh in shape ):
859
+ raise TypeError (f"Shape entries must be None or integer, got { shape } " )
860
+
861
+ return shape
862
+
863
+
864
+ def vector (
865
+ name : Optional [str ] = None ,
866
+ dtype : Optional ["DTypeLike" ] = None ,
867
+ shape : Optional [Tuple [ST ]] = (None ,),
868
+ ) -> "TensorVariable" :
836
869
"""Return a symbolic vector variable.
837
870
838
871
Parameters
839
872
----------
840
- dtype: numeric
841
- None means to use pytensor.config.floatX.
842
873
name
843
874
A name to attach to this variable
875
+ shape
876
+ A tuple of static sizes for each dimension of the variable. By default, each dimension length is `None` which
877
+ allows that dimension to change size across evaluations.
878
+ dtype
879
+ Data type of tensor variable. By default, it's pytensor.config.floatX.
844
880
845
881
"""
846
882
if dtype is None :
847
883
dtype = config .floatX
848
- type = TensorType (dtype , shape = (None ,))
884
+
885
+ shape = _validate_static_shape (shape , ndim = 1 )
886
+
887
+ type = TensorType (dtype , shape = shape )
849
888
return type (name )
850
889
851
890
@@ -867,20 +906,28 @@ def vector(name=None, dtype=None):
867
906
lmatrix = TensorType ("int64" , shape = (None , None ))
868
907
869
908
870
- def matrix (name = None , dtype = None ):
909
+ def matrix (
910
+ name : Optional [str ] = None ,
911
+ dtype : Optional ["DTypeLike" ] = None ,
912
+ shape : Optional [Tuple [ST , ST ]] = (None , None ),
913
+ ) -> "TensorVariable" :
871
914
"""Return a symbolic matrix variable.
872
915
873
916
Parameters
874
917
----------
875
- dtype: numeric
876
- None means to use pytensor.config.floatX.
877
918
name
878
- A name to attach to this variable.
919
+ A name to attach to this variable
920
+ shape
921
+ A tuple of static sizes for each dimension of the variable. By default, each dimension length is `None` which
922
+ allows that dimension to change size across evaluations.
923
+ dtype
924
+ Data type of tensor variable. By default, it's pytensor.config.floatX.
879
925
880
926
"""
881
927
if dtype is None :
882
928
dtype = config .floatX
883
- type = TensorType (dtype , shape = (None , None ))
929
+ shape = _validate_static_shape (shape , ndim = 2 )
930
+ type = TensorType (dtype , shape = shape )
884
931
return type (name )
885
932
886
933
@@ -902,20 +949,34 @@ def matrix(name=None, dtype=None):
902
949
lrow = TensorType ("int64" , shape = (1 , None ))
903
950
904
951
905
- def row (name = None , dtype = None ):
952
+ def row (
953
+ name : Optional [str ] = None ,
954
+ dtype : Optional ["DTypeLike" ] = None ,
955
+ shape : Optional [Tuple [Literal [1 ], ST ]] = (1 , None ),
956
+ ) -> "TensorVariable" :
906
957
"""Return a symbolic row variable (i.e. shape ``(1, None)``).
907
958
908
959
Parameters
909
960
----------
910
- dtype: numeric type
911
- None means to use pytensor.config.floatX.
912
961
name
913
- A name to attach to this variable.
962
+ A name to attach to this variable
963
+ shape
964
+ A tuple of static sizes for each dimension of the variable. By default, each dimension length is `None` which
965
+ allows that dimension to change size across evaluations.
966
+ dtype
967
+ Data type of tensor variable. By default, it's pytensor.config.floatX.
914
968
915
969
"""
916
970
if dtype is None :
917
971
dtype = config .floatX
918
- type = TensorType (dtype , shape = (1 , None ))
972
+ shape = _validate_static_shape (shape , ndim = 2 )
973
+
974
+ if shape [0 ] != 1 :
975
+ raise ValueError (
976
+ f"The first dimension of a `row` must have shape 1, got { shape [0 ]} "
977
+ )
978
+
979
+ type = TensorType (dtype , shape = shape )
919
980
return type (name )
920
981
921
982
@@ -932,21 +993,31 @@ def row(name=None, dtype=None):
932
993
933
994
934
995
def col (
935
- name : Optional [str ] = None , dtype : Optional ["DTypeLike" ] = None
996
+ name : Optional [str ] = None ,
997
+ dtype : Optional ["DTypeLike" ] = None ,
998
+ shape : Optional [Tuple [ST , Literal [1 ]]] = (None , 1 ),
936
999
) -> "TensorVariable" :
937
1000
"""Return a symbolic column variable (i.e. shape ``(None, 1)``).
938
1001
939
1002
Parameters
940
1003
----------
941
1004
name
942
- A name to attach to this variable.
1005
+ A name to attach to this variable
1006
+ shape
1007
+ A tuple of static sizes for each dimension of the variable. By default, each dimension length is `None` which
1008
+ allows that dimension to change size across evaluations.
943
1009
dtype
944
- ``None`` means to use ` pytensor.config.floatX` .
1010
+ Data type of tensor variable. By default, it's pytensor.config.floatX.
945
1011
946
1012
"""
947
1013
if dtype is None :
948
1014
dtype = config .floatX
949
- type = TensorType (dtype , shape = (None , 1 ))
1015
+ shape = _validate_static_shape (shape , ndim = 2 )
1016
+ if shape [1 ] != 1 :
1017
+ raise ValueError (
1018
+ f"The second dimension of a `col` must have shape 1, got { shape [1 ]} "
1019
+ )
1020
+ type = TensorType (dtype , shape = shape )
950
1021
return type (name )
951
1022
952
1023
@@ -963,21 +1034,27 @@ def col(
963
1034
964
1035
965
1036
def tensor3 (
966
- name : Optional [str ] = None , dtype : Optional ["DTypeLike" ] = None
1037
+ name : Optional [str ] = None ,
1038
+ dtype : Optional ["DTypeLike" ] = None ,
1039
+ shape : Optional [Tuple [ST , ST , ST ]] = (None , None , None ),
967
1040
) -> "TensorVariable" :
968
1041
"""Return a symbolic 3D variable.
969
1042
970
1043
Parameters
971
1044
----------
972
1045
name
973
- A name to attach to this variable.
1046
+ A name to attach to this variable
1047
+ shape
1048
+ A tuple of static sizes for each dimension of the variable. By default, each dimension length is `None` which
1049
+ allows that dimension to change size across evaluations.
974
1050
dtype
975
- ``None`` means to use ` pytensor.config.floatX` .
1051
+ Data type of tensor variable. By default, it's pytensor.config.floatX.
976
1052
977
1053
"""
978
1054
if dtype is None :
979
1055
dtype = config .floatX
980
- type = TensorType (dtype , shape = (None , None , None ))
1056
+ shape = _validate_static_shape (shape , ndim = 3 )
1057
+ type = TensorType (dtype , shape = shape )
981
1058
return type (name )
982
1059
983
1060
@@ -996,21 +1073,27 @@ def tensor3(
996
1073
997
1074
998
1075
def tensor4 (
999
- name : Optional [str ] = None , dtype : Optional ["DTypeLike" ] = None
1076
+ name : Optional [str ] = None ,
1077
+ dtype : Optional ["DTypeLike" ] = None ,
1078
+ shape : Optional [Tuple [ST , ST , ST , ST ]] = (None , None , None , None ),
1000
1079
) -> "TensorVariable" :
1001
1080
"""Return a symbolic 4D variable.
1002
1081
1003
1082
Parameters
1004
1083
----------
1005
1084
name
1006
- A name to attach to this variable.
1085
+ A name to attach to this variable
1086
+ shape
1087
+ A tuple of static sizes for each dimension of the variable. By default, each dimension length is `None` which
1088
+ allows that dimension to change size across evaluations.
1007
1089
dtype
1008
- ``None`` means to use ` pytensor.config.floatX` .
1090
+ Data type of tensor variable. By default, it's pytensor.config.floatX.
1009
1091
1010
1092
"""
1011
1093
if dtype is None :
1012
1094
dtype = config .floatX
1013
- type = TensorType (dtype , shape = (None , None , None , None ))
1095
+ shape = _validate_static_shape (shape , ndim = 4 )
1096
+ type = TensorType (dtype , shape = shape )
1014
1097
return type (name )
1015
1098
1016
1099
@@ -1029,21 +1112,27 @@ def tensor4(
1029
1112
1030
1113
1031
1114
def tensor5 (
1032
- name : Optional [str ] = None , dtype : Optional ["DTypeLike" ] = None
1115
+ name : Optional [str ] = None ,
1116
+ dtype : Optional ["DTypeLike" ] = None ,
1117
+ shape : Optional [Tuple [ST , ST , ST , ST , ST ]] = (None , None , None , None , None ),
1033
1118
) -> "TensorVariable" :
1034
1119
"""Return a symbolic 5D variable.
1035
1120
1036
1121
Parameters
1037
1122
----------
1038
1123
name
1039
- A name to attach to this variable.
1124
+ A name to attach to this variable
1125
+ shape
1126
+ A tuple of static sizes for each dimension of the variable. By default, each dimension length is `None` which
1127
+ allows that dimension to change size across evaluations.
1040
1128
dtype
1041
- ``None`` means to use ` pytensor.config.floatX` .
1129
+ Data type of tensor variable. By default, it's pytensor.config.floatX.
1042
1130
1043
1131
"""
1044
1132
if dtype is None :
1045
1133
dtype = config .floatX
1046
- type = TensorType (dtype , shape = (None , None , None , None , None ))
1134
+ shape = _validate_static_shape (shape , ndim = 5 )
1135
+ type = TensorType (dtype , shape = shape )
1047
1136
return type (name )
1048
1137
1049
1138
@@ -1062,21 +1151,34 @@ def tensor5(
1062
1151
1063
1152
1064
1153
def tensor6 (
1065
- name : Optional [str ] = None , dtype : Optional ["DTypeLike" ] = None
1154
+ name : Optional [str ] = None ,
1155
+ dtype : Optional ["DTypeLike" ] = None ,
1156
+ shape : Optional [Tuple [ST , ST , ST , ST , ST , ST ]] = (
1157
+ None ,
1158
+ None ,
1159
+ None ,
1160
+ None ,
1161
+ None ,
1162
+ None ,
1163
+ ),
1066
1164
) -> "TensorVariable" :
1067
1165
"""Return a symbolic 6D variable.
1068
1166
1069
1167
Parameters
1070
1168
----------
1071
1169
name
1072
- A name to attach to this variable.
1170
+ A name to attach to this variable
1171
+ shape
1172
+ A tuple of static sizes for each dimension of the variable. By default, each dimension length is `None` which
1173
+ allows that dimension to change size across evaluations.
1073
1174
dtype
1074
- ``None`` means to use ` pytensor.config.floatX` .
1175
+ Data type of tensor variable. By default, it's pytensor.config.floatX.
1075
1176
1076
1177
"""
1077
1178
if dtype is None :
1078
1179
dtype = config .floatX
1079
- type = TensorType (dtype , shape = (None ,) * 6 )
1180
+ shape = _validate_static_shape (shape , ndim = 6 )
1181
+ type = TensorType (dtype , shape = shape )
1080
1182
return type (name )
1081
1183
1082
1184
@@ -1095,21 +1197,35 @@ def tensor6(
1095
1197
1096
1198
1097
1199
def tensor7 (
1098
- name : Optional [str ] = None , dtype : Optional ["DTypeLike" ] = None
1200
+ name : Optional [str ] = None ,
1201
+ dtype : Optional ["DTypeLike" ] = None ,
1202
+ shape : Optional [Tuple [ST , ST , ST , ST , ST , ST , ST ]] = (
1203
+ None ,
1204
+ None ,
1205
+ None ,
1206
+ None ,
1207
+ None ,
1208
+ None ,
1209
+ None ,
1210
+ ),
1099
1211
) -> "TensorVariable" :
1100
1212
"""Return a symbolic 7-D variable.
1101
1213
1102
1214
Parameters
1103
1215
----------
1104
1216
name
1105
- A name to attach to this variable.
1217
+ A name to attach to this variable
1218
+ shape
1219
+ A tuple of static sizes for each dimension of the variable. By default, each dimension length is `None` which
1220
+ allows that dimension to change size across evaluations.
1106
1221
dtype
1107
- ``None`` means to use ` pytensor.config.floatX` .
1222
+ Data type of tensor variable. By default, it's pytensor.config.floatX.
1108
1223
1109
1224
"""
1110
1225
if dtype is None :
1111
1226
dtype = config .floatX
1112
- type = TensorType (dtype , shape = (None ,) * 7 )
1227
+ shape = _validate_static_shape (shape , ndim = 7 )
1228
+ type = TensorType (dtype , shape = shape )
1113
1229
return type (name )
1114
1230
1115
1231
0 commit comments