-
Notifications
You must be signed in to change notification settings - Fork 134
Closed
Labels
Description
Please describe the purpose of filing this issue
Would be nice if we could pass the static shape to the tensor constructors now that it is supported
import pytensor.tensor as pt
pt.vector(shape=(3, 4))
For some types that constrain the shape, we could assert they are compatible
pt.col(shape=(5, 1))
pt.col(shape=(1, 5)) # Raise ValueError
In addition, would be nice if the more general at.tensor
, had a default dtype of floatX
, which is the most common case
pt.tensor("float64", shape=(None, None, 5)) # Fine
pt.tensor(shape=(None, None, 5))
# TypeError: TensorType.__init__() missing 1 required positional argument: 'dtype'
These are implemented here:
pytensor/pytensor/tensor/type.py
Lines 778 to 1118 in b7a2d30
def tensor(*args, **kwargs): | |
name = kwargs.pop("name", None) | |
return TensorType(*args, **kwargs)(name=name) | |
cscalar = TensorType("complex64", ()) | |
zscalar = TensorType("complex128", ()) | |
fscalar = TensorType("float32", ()) | |
dscalar = TensorType("float64", ()) | |
bscalar = TensorType("int8", ()) | |
wscalar = TensorType("int16", ()) | |
iscalar = TensorType("int32", ()) | |
lscalar = TensorType("int64", ()) | |
ubscalar = TensorType("uint8", ()) | |
uwscalar = TensorType("uint16", ()) | |
uiscalar = TensorType("uint32", ()) | |
ulscalar = TensorType("uint64", ()) | |
def scalar(name=None, dtype=None): | |
"""Return a symbolic scalar variable. | |
Parameters | |
---------- | |
dtype: numeric | |
None means to use pytensor.config.floatX. | |
name | |
A name to attach to this variable. | |
""" | |
if dtype is None: | |
dtype = config.floatX | |
type = TensorType(dtype, ()) | |
return type(name) | |
scalars, fscalars, dscalars, iscalars, lscalars = apply_across_args( | |
scalar, fscalar, dscalar, iscalar, lscalar | |
) | |
int_types = bscalar, wscalar, iscalar, lscalar | |
float_types = fscalar, dscalar | |
complex_types = cscalar, zscalar | |
int_scalar_types = int_types | |
float_scalar_types = float_types | |
complex_scalar_types = complex_types | |
cvector = TensorType("complex64", shape=(None,)) | |
zvector = TensorType("complex128", shape=(None,)) | |
fvector = TensorType("float32", shape=(None,)) | |
dvector = TensorType("float64", shape=(None,)) | |
bvector = TensorType("int8", shape=(None,)) | |
wvector = TensorType("int16", shape=(None,)) | |
ivector = TensorType("int32", shape=(None,)) | |
lvector = TensorType("int64", shape=(None,)) | |
def vector(name=None, dtype=None): | |
"""Return a symbolic vector variable. | |
Parameters | |
---------- | |
dtype: numeric | |
None means to use pytensor.config.floatX. | |
name | |
A name to attach to this variable | |
""" | |
if dtype is None: | |
dtype = config.floatX | |
type = TensorType(dtype, shape=(None,)) | |
return type(name) | |
vectors, fvectors, dvectors, ivectors, lvectors = apply_across_args( | |
vector, fvector, dvector, ivector, lvector | |
) | |
int_vector_types = bvector, wvector, ivector, lvector | |
float_vector_types = fvector, dvector | |
complex_vector_types = cvector, zvector | |
cmatrix = TensorType("complex64", shape=(None, None)) | |
zmatrix = TensorType("complex128", shape=(None, None)) | |
fmatrix = TensorType("float32", shape=(None, None)) | |
dmatrix = TensorType("float64", shape=(None, None)) | |
bmatrix = TensorType("int8", shape=(None, None)) | |
wmatrix = TensorType("int16", shape=(None, None)) | |
imatrix = TensorType("int32", shape=(None, None)) | |
lmatrix = TensorType("int64", shape=(None, None)) | |
def matrix(name=None, dtype=None): | |
"""Return a symbolic matrix variable. | |
Parameters | |
---------- | |
dtype: numeric | |
None means to use pytensor.config.floatX. | |
name | |
A name to attach to this variable. | |
""" | |
if dtype is None: | |
dtype = config.floatX | |
type = TensorType(dtype, shape=(None, None)) | |
return type(name) | |
matrices, fmatrices, dmatrices, imatrices, lmatrices = apply_across_args( | |
matrix, fmatrix, dmatrix, imatrix, lmatrix | |
) | |
int_matrix_types = bmatrix, wmatrix, imatrix, lmatrix | |
float_matrix_types = fmatrix, dmatrix | |
complex_matrix_types = cmatrix, zmatrix | |
crow = TensorType("complex64", shape=(1, None)) | |
zrow = TensorType("complex128", shape=(1, None)) | |
frow = TensorType("float32", shape=(1, None)) | |
drow = TensorType("float64", shape=(1, None)) | |
brow = TensorType("int8", shape=(1, None)) | |
wrow = TensorType("int16", shape=(1, None)) | |
irow = TensorType("int32", shape=(1, None)) | |
lrow = TensorType("int64", shape=(1, None)) | |
def row(name=None, dtype=None): | |
"""Return a symbolic row variable (i.e. shape ``(1, None)``). | |
Parameters | |
---------- | |
dtype: numeric type | |
None means to use pytensor.config.floatX. | |
name | |
A name to attach to this variable. | |
""" | |
if dtype is None: | |
dtype = config.floatX | |
type = TensorType(dtype, shape=(1, None)) | |
return type(name) | |
rows, frows, drows, irows, lrows = apply_across_args(row, frow, drow, irow, lrow) | |
ccol = TensorType("complex64", shape=(None, 1)) | |
zcol = TensorType("complex128", shape=(None, 1)) | |
fcol = TensorType("float32", shape=(None, 1)) | |
dcol = TensorType("float64", shape=(None, 1)) | |
bcol = TensorType("int8", shape=(None, 1)) | |
wcol = TensorType("int16", shape=(None, 1)) | |
icol = TensorType("int32", shape=(None, 1)) | |
lcol = TensorType("int64", shape=(None, 1)) | |
def col( | |
name: Optional[str] = None, dtype: Optional["DTypeLike"] = None | |
) -> "TensorVariable": | |
"""Return a symbolic column variable (i.e. shape ``(None, 1)``). | |
Parameters | |
---------- | |
name | |
A name to attach to this variable. | |
dtype | |
``None`` means to use `pytensor.config.floatX`. | |
""" | |
if dtype is None: | |
dtype = config.floatX | |
type = TensorType(dtype, shape=(None, 1)) | |
return type(name) | |
cols, fcols, dcols, icols, lcols = apply_across_args(col, fcol, dcol, icol, lcol) | |
ctensor3 = TensorType("complex64", shape=((None,) * 3)) | |
ztensor3 = TensorType("complex128", shape=((None,) * 3)) | |
ftensor3 = TensorType("float32", shape=((None,) * 3)) | |
dtensor3 = TensorType("float64", shape=((None,) * 3)) | |
btensor3 = TensorType("int8", shape=((None,) * 3)) | |
wtensor3 = TensorType("int16", shape=((None,) * 3)) | |
itensor3 = TensorType("int32", shape=((None,) * 3)) | |
ltensor3 = TensorType("int64", shape=((None,) * 3)) | |
def tensor3( | |
name: Optional[str] = None, dtype: Optional["DTypeLike"] = None | |
) -> "TensorVariable": | |
"""Return a symbolic 3D variable. | |
Parameters | |
---------- | |
name | |
A name to attach to this variable. | |
dtype | |
``None`` means to use `pytensor.config.floatX`. | |
""" | |
if dtype is None: | |
dtype = config.floatX | |
type = TensorType(dtype, shape=(None, None, None)) | |
return type(name) | |
tensor3s, ftensor3s, dtensor3s, itensor3s, ltensor3s = apply_across_args( | |
tensor3, ftensor3, dtensor3, itensor3, ltensor3 | |
) | |
ctensor4 = TensorType("complex64", shape=((None,) * 4)) | |
ztensor4 = TensorType("complex128", shape=((None,) * 4)) | |
ftensor4 = TensorType("float32", shape=((None,) * 4)) | |
dtensor4 = TensorType("float64", shape=((None,) * 4)) | |
btensor4 = TensorType("int8", shape=((None,) * 4)) | |
wtensor4 = TensorType("int16", shape=((None,) * 4)) | |
itensor4 = TensorType("int32", shape=((None,) * 4)) | |
ltensor4 = TensorType("int64", shape=((None,) * 4)) | |
def tensor4( | |
name: Optional[str] = None, dtype: Optional["DTypeLike"] = None | |
) -> "TensorVariable": | |
"""Return a symbolic 4D variable. | |
Parameters | |
---------- | |
name | |
A name to attach to this variable. | |
dtype | |
``None`` means to use `pytensor.config.floatX`. | |
""" | |
if dtype is None: | |
dtype = config.floatX | |
type = TensorType(dtype, shape=(None, None, None, None)) | |
return type(name) | |
tensor4s, ftensor4s, dtensor4s, itensor4s, ltensor4s = apply_across_args( | |
tensor4, ftensor4, dtensor4, itensor4, ltensor4 | |
) | |
ctensor5 = TensorType("complex64", shape=((None,) * 5)) | |
ztensor5 = TensorType("complex128", shape=((None,) * 5)) | |
ftensor5 = TensorType("float32", shape=((None,) * 5)) | |
dtensor5 = TensorType("float64", shape=((None,) * 5)) | |
btensor5 = TensorType("int8", shape=((None,) * 5)) | |
wtensor5 = TensorType("int16", shape=((None,) * 5)) | |
itensor5 = TensorType("int32", shape=((None,) * 5)) | |
ltensor5 = TensorType("int64", shape=((None,) * 5)) | |
def tensor5( | |
name: Optional[str] = None, dtype: Optional["DTypeLike"] = None | |
) -> "TensorVariable": | |
"""Return a symbolic 5D variable. | |
Parameters | |
---------- | |
name | |
A name to attach to this variable. | |
dtype | |
``None`` means to use `pytensor.config.floatX`. | |
""" | |
if dtype is None: | |
dtype = config.floatX | |
type = TensorType(dtype, shape=(None, None, None, None, None)) | |
return type(name) | |
tensor5s, ftensor5s, dtensor5s, itensor5s, ltensor5s = apply_across_args( | |
tensor5, ftensor5, dtensor5, itensor5, ltensor5 | |
) | |
ctensor6 = TensorType("complex64", shape=((None,) * 6)) | |
ztensor6 = TensorType("complex128", shape=((None,) * 6)) | |
ftensor6 = TensorType("float32", shape=((None,) * 6)) | |
dtensor6 = TensorType("float64", shape=((None,) * 6)) | |
btensor6 = TensorType("int8", shape=((None,) * 6)) | |
wtensor6 = TensorType("int16", shape=((None,) * 6)) | |
itensor6 = TensorType("int32", shape=((None,) * 6)) | |
ltensor6 = TensorType("int64", shape=((None,) * 6)) | |
def tensor6( | |
name: Optional[str] = None, dtype: Optional["DTypeLike"] = None | |
) -> "TensorVariable": | |
"""Return a symbolic 6D variable. | |
Parameters | |
---------- | |
name | |
A name to attach to this variable. | |
dtype | |
``None`` means to use `pytensor.config.floatX`. | |
""" | |
if dtype is None: | |
dtype = config.floatX | |
type = TensorType(dtype, shape=(None,) * 6) | |
return type(name) | |
tensor6s, ftensor6s, dtensor6s, itensor6s, ltensor6s = apply_across_args( | |
tensor6, ftensor6, dtensor6, itensor6, ltensor6 | |
) | |
ctensor7 = TensorType("complex64", shape=((None,) * 7)) | |
ztensor7 = TensorType("complex128", shape=((None,) * 7)) | |
ftensor7 = TensorType("float32", shape=((None,) * 7)) | |
dtensor7 = TensorType("float64", shape=((None,) * 7)) | |
btensor7 = TensorType("int8", shape=((None,) * 7)) | |
wtensor7 = TensorType("int16", shape=((None,) * 7)) | |
itensor7 = TensorType("int32", shape=((None,) * 7)) | |
ltensor7 = TensorType("int64", shape=((None,) * 7)) | |
def tensor7( | |
name: Optional[str] = None, dtype: Optional["DTypeLike"] = None | |
) -> "TensorVariable": | |
"""Return a symbolic 7-D variable. | |
Parameters | |
---------- | |
name | |
A name to attach to this variable. | |
dtype | |
``None`` means to use `pytensor.config.floatX`. | |
""" | |
if dtype is None: | |
dtype = config.floatX | |
type = TensorType(dtype, shape=(None,) * 7) | |
return type(name) | |
tensor7s, ftensor7s, dtensor7s, itensor7s, ltensor7s = apply_across_args( | |
tensor7, ftensor7, dtensor7, itensor7, ltensor7 | |
) |