Skip to content

Allow passing static shape to tensor creation helpers #118

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 14, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
198 changes: 157 additions & 41 deletions pytensor/tensor/type.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import TYPE_CHECKING, Iterable, Optional, Tuple, Union

import numpy as np
from typing_extensions import Literal

import pytensor
from pytensor import scalar as aes
Expand Down Expand Up @@ -775,9 +776,17 @@ def values_eq_approx_always_true(a, b):
)


def tensor(*args, **kwargs):
def tensor(
dtype: Optional["DTypeLike"] = None,
*args,
**kwargs,
) -> "TensorVariable":

if dtype is None:
dtype = config.floatX

name = kwargs.pop("name", None)
return TensorType(*args, **kwargs)(name=name)
return TensorType(dtype, *args, **kwargs)(name=name)


cscalar = TensorType("complex64", ())
Expand All @@ -794,7 +803,10 @@ def tensor(*args, **kwargs):
ulscalar = TensorType("uint64", ())


def scalar(name=None, dtype=None):
def scalar(
name: Optional[str] = None,
dtype: Optional["DTypeLike"] = None,
) -> "TensorVariable":
"""Return a symbolic scalar variable.

Parameters
Expand Down Expand Up @@ -832,20 +844,47 @@ def scalar(name=None, dtype=None):
lvector = TensorType("int64", shape=(None,))


def vector(name=None, dtype=None):
ST = Union[int, None]


def _validate_static_shape(shape, ndim: int) -> Tuple[ST, ...]:

if not isinstance(shape, tuple):
raise TypeError(f"Shape must be a tuple, got {type(shape)}")

if len(shape) != ndim:
raise ValueError(f"Shape must be a tuple of length {ndim}, got {shape}")

if not all(sh is None or isinstance(sh, int) for sh in shape):
raise TypeError(f"Shape entries must be None or integer, got {shape}")

return shape


def vector(
name: Optional[str] = None,
dtype: Optional["DTypeLike"] = None,
shape: Optional[Tuple[ST]] = (None,),
) -> "TensorVariable":
"""Return a symbolic vector variable.

Parameters
----------
dtype: numeric
None means to use pytensor.config.floatX.
name
A name to attach to this variable
shape
A tuple of static sizes for each dimension of the variable. By default, each dimension length is `None` which
allows that dimension to change size across evaluations.
dtype
Data type of tensor variable. By default, it's pytensor.config.floatX.

"""
if dtype is None:
dtype = config.floatX
type = TensorType(dtype, shape=(None,))

shape = _validate_static_shape(shape, ndim=1)

type = TensorType(dtype, shape=shape)
return type(name)


Expand All @@ -867,20 +906,28 @@ def vector(name=None, dtype=None):
lmatrix = TensorType("int64", shape=(None, None))


def matrix(name=None, dtype=None):
def matrix(
name: Optional[str] = None,
dtype: Optional["DTypeLike"] = None,
shape: Optional[Tuple[ST, ST]] = (None, None),
) -> "TensorVariable":
"""Return a symbolic matrix variable.

Parameters
----------
dtype: numeric
None means to use pytensor.config.floatX.
name
A name to attach to this variable.
A name to attach to this variable
shape
A tuple of static sizes for each dimension of the variable. By default, each dimension length is `None` which
allows that dimension to change size across evaluations.
dtype
Data type of tensor variable. By default, it's pytensor.config.floatX.

"""
if dtype is None:
dtype = config.floatX
type = TensorType(dtype, shape=(None, None))
shape = _validate_static_shape(shape, ndim=2)
type = TensorType(dtype, shape=shape)
return type(name)


Expand All @@ -902,20 +949,34 @@ def matrix(name=None, dtype=None):
lrow = TensorType("int64", shape=(1, None))


def row(name=None, dtype=None):
def row(
name: Optional[str] = None,
dtype: Optional["DTypeLike"] = None,
shape: Optional[Tuple[Literal[1], ST]] = (1, None),
) -> "TensorVariable":
"""Return a symbolic row variable (i.e. shape ``(1, None)``).

Parameters
----------
dtype: numeric type
None means to use pytensor.config.floatX.
name
A name to attach to this variable.
A name to attach to this variable
shape
A tuple of static sizes for each dimension of the variable. By default, each dimension length is `None` which
allows that dimension to change size across evaluations.
dtype
Data type of tensor variable. By default, it's pytensor.config.floatX.

"""
if dtype is None:
dtype = config.floatX
type = TensorType(dtype, shape=(1, None))
shape = _validate_static_shape(shape, ndim=2)

if shape[0] != 1:
raise ValueError(
f"The first dimension of a `row` must have shape 1, got {shape[0]}"
)

type = TensorType(dtype, shape=shape)
return type(name)


Expand All @@ -932,21 +993,31 @@ def row(name=None, dtype=None):


def col(
name: Optional[str] = None, dtype: Optional["DTypeLike"] = None
name: Optional[str] = None,
dtype: Optional["DTypeLike"] = None,
shape: Optional[Tuple[ST, Literal[1]]] = (None, 1),
) -> "TensorVariable":
"""Return a symbolic column variable (i.e. shape ``(None, 1)``).

Parameters
----------
name
A name to attach to this variable.
A name to attach to this variable
shape
A tuple of static sizes for each dimension of the variable. By default, each dimension length is `None` which
allows that dimension to change size across evaluations.
dtype
``None`` means to use `pytensor.config.floatX`.
Data type of tensor variable. By default, it's pytensor.config.floatX.

"""
if dtype is None:
dtype = config.floatX
type = TensorType(dtype, shape=(None, 1))
shape = _validate_static_shape(shape, ndim=2)
if shape[1] != 1:
raise ValueError(
f"The second dimension of a `col` must have shape 1, got {shape[1]}"
)
type = TensorType(dtype, shape=shape)
return type(name)


Expand All @@ -963,21 +1034,27 @@ def col(


def tensor3(
name: Optional[str] = None, dtype: Optional["DTypeLike"] = None
name: Optional[str] = None,
dtype: Optional["DTypeLike"] = None,
shape: Optional[Tuple[ST, ST, ST]] = (None, None, None),
) -> "TensorVariable":
"""Return a symbolic 3D variable.

Parameters
----------
name
A name to attach to this variable.
A name to attach to this variable
shape
A tuple of static sizes for each dimension of the variable. By default, each dimension length is `None` which
allows that dimension to change size across evaluations.
dtype
``None`` means to use `pytensor.config.floatX`.
Data type of tensor variable. By default, it's pytensor.config.floatX.

"""
if dtype is None:
dtype = config.floatX
type = TensorType(dtype, shape=(None, None, None))
shape = _validate_static_shape(shape, ndim=3)
type = TensorType(dtype, shape=shape)
return type(name)


Expand All @@ -996,21 +1073,27 @@ def tensor3(


def tensor4(
name: Optional[str] = None, dtype: Optional["DTypeLike"] = None
name: Optional[str] = None,
dtype: Optional["DTypeLike"] = None,
shape: Optional[Tuple[ST, ST, ST, ST]] = (None, None, None, None),
) -> "TensorVariable":
"""Return a symbolic 4D variable.

Parameters
----------
name
A name to attach to this variable.
A name to attach to this variable
shape
A tuple of static sizes for each dimension of the variable. By default, each dimension length is `None` which
allows that dimension to change size across evaluations.
dtype
``None`` means to use `pytensor.config.floatX`.
Data type of tensor variable. By default, it's pytensor.config.floatX.

"""
if dtype is None:
dtype = config.floatX
type = TensorType(dtype, shape=(None, None, None, None))
shape = _validate_static_shape(shape, ndim=4)
type = TensorType(dtype, shape=shape)
return type(name)


Expand All @@ -1029,21 +1112,27 @@ def tensor4(


def tensor5(
name: Optional[str] = None, dtype: Optional["DTypeLike"] = None
name: Optional[str] = None,
dtype: Optional["DTypeLike"] = None,
shape: Optional[Tuple[ST, ST, ST, ST, ST]] = (None, None, None, None, None),
) -> "TensorVariable":
"""Return a symbolic 5D variable.

Parameters
----------
name
A name to attach to this variable.
A name to attach to this variable
shape
A tuple of static sizes for each dimension of the variable. By default, each dimension length is `None` which
allows that dimension to change size across evaluations.
dtype
``None`` means to use `pytensor.config.floatX`.
Data type of tensor variable. By default, it's pytensor.config.floatX.

"""
if dtype is None:
dtype = config.floatX
type = TensorType(dtype, shape=(None, None, None, None, None))
shape = _validate_static_shape(shape, ndim=5)
type = TensorType(dtype, shape=shape)
return type(name)


Expand All @@ -1062,21 +1151,34 @@ def tensor5(


def tensor6(
name: Optional[str] = None, dtype: Optional["DTypeLike"] = None
name: Optional[str] = None,
dtype: Optional["DTypeLike"] = None,
shape: Optional[Tuple[ST, ST, ST, ST, ST, ST]] = (
None,
None,
None,
None,
None,
None,
),
) -> "TensorVariable":
"""Return a symbolic 6D variable.

Parameters
----------
name
A name to attach to this variable.
A name to attach to this variable
shape
A tuple of static sizes for each dimension of the variable. By default, each dimension length is `None` which
allows that dimension to change size across evaluations.
dtype
``None`` means to use `pytensor.config.floatX`.
Data type of tensor variable. By default, it's pytensor.config.floatX.

"""
if dtype is None:
dtype = config.floatX
type = TensorType(dtype, shape=(None,) * 6)
shape = _validate_static_shape(shape, ndim=6)
type = TensorType(dtype, shape=shape)
return type(name)


Expand All @@ -1095,21 +1197,35 @@ def tensor6(


def tensor7(
name: Optional[str] = None, dtype: Optional["DTypeLike"] = None
name: Optional[str] = None,
dtype: Optional["DTypeLike"] = None,
shape: Optional[Tuple[ST, ST, ST, ST, ST, ST, ST]] = (
None,
None,
None,
None,
None,
None,
None,
),
) -> "TensorVariable":
"""Return a symbolic 7-D variable.

Parameters
----------
name
A name to attach to this variable.
A name to attach to this variable
shape
A tuple of static sizes for each dimension of the variable. By default, each dimension length is `None` which
allows that dimension to change size across evaluations.
dtype
``None`` means to use `pytensor.config.floatX`.
Data type of tensor variable. By default, it's pytensor.config.floatX.

"""
if dtype is None:
dtype = config.floatX
type = TensorType(dtype, shape=(None,) * 7)
shape = _validate_static_shape(shape, ndim=7)
type = TensorType(dtype, shape=shape)
return type(name)


Expand Down
Loading