Skip to content

Commit 906e142

Browse files
michaelosthegericardoV94
authored andcommitted
Fix type hints here and there
1 parent 7a00b88 commit 906e142

File tree

6 files changed

+12
-9
lines changed

6 files changed

+12
-9
lines changed

pytensor/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def as_symbolic(x: Any, name: str | None = None, **kwargs) -> Variable:
108108

109109

110110
@singledispatch
111-
def _as_symbolic(x, **kwargs) -> Variable:
111+
def _as_symbolic(x: Any, **kwargs) -> Variable:
112112
from pytensor.tensor import as_tensor_variable
113113

114114
return as_tensor_variable(x, **kwargs)

pytensor/graph/basic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1302,8 +1302,8 @@ def clone_node_and_cache(
13021302

13031303

13041304
def clone_get_equiv(
1305-
inputs: Sequence[Variable],
1306-
outputs: Sequence[Variable],
1305+
inputs: Iterable[Variable],
1306+
outputs: Reversible[Variable],
13071307
copy_inputs: bool = True,
13081308
copy_orphans: bool = True,
13091309
memo: dict[Union[Apply, Variable, "Op"], Union[Apply, Variable, "Op"]]

pytensor/tensor/random/op.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import warnings
22
from collections.abc import Sequence
33
from copy import copy
4-
from typing import cast
4+
from typing import Any, cast
55

66
import numpy as np
77

@@ -218,6 +218,7 @@ def _infer_shape(
218218

219219
from pytensor.tensor.extra_ops import broadcast_shape_iter
220220

221+
supp_shape: tuple[Any]
221222
if self.ndim_supp == 0:
222223
supp_shape = ()
223224
else:

pytensor/tensor/random/utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,9 @@ def explicit_expand_dims(
147147
return new_params
148148

149149

150-
def compute_batch_shape(params, ndims_params: Sequence[int]) -> TensorVariable:
150+
def compute_batch_shape(
151+
params: Sequence[TensorVariable], ndims_params: Sequence[int]
152+
) -> TensorVariable:
151153
params = explicit_expand_dims(params, ndims_params)
152154
batch_params = [
153155
param[(..., *(0,) * core_ndim)]

pytensor/tensor/shape.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -144,14 +144,14 @@ def c_code_cache_version(self):
144144
_shape = Shape()
145145

146146

147-
def shape(x: np.ndarray | Number | Variable) -> Variable:
147+
def shape(x: np.ndarray | Number | Variable) -> TensorVariable:
148148
"""Return the shape of `x`."""
149149
if not isinstance(x, Variable):
150150
# The following is a type error in Python 3.9 but not 3.12.
151151
# Thus we need to ignore unused-ignore on 3.12.
152152
x = ptb.as_tensor_variable(x) # type: ignore[arg-type,unused-ignore]
153153

154-
return cast(Variable, _shape(x))
154+
return cast(TensorVariable, _shape(x))
155155

156156

157157
@_get_vector_length.register(Shape) # type: ignore
@@ -195,7 +195,7 @@ def shape_tuple(x: TensorVariable) -> tuple[Variable, ...]:
195195
# TODO: Why not use uint64?
196196
res += (pytensor.scalar.ScalarConstant(pytensor.scalar.int64, shape_val),)
197197
else:
198-
res += (symbolic_shape[i],) # type: ignore
198+
res += (symbolic_shape[i],)
199199

200200
return res
201201

pytensor/tensor/type.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def clone(
138138
shape = self.shape
139139
return type(self)(dtype, shape, name=self.name)
140140

141-
def filter(self, data, strict=False, allow_downcast=None):
141+
def filter(self, data, strict=False, allow_downcast=None) -> np.ndarray:
142142
"""Convert `data` to something which can be associated to a `TensorVariable`.
143143
144144
This function is not meant to be called in user code. It is for

0 commit comments

Comments
 (0)