Skip to content

Commit 07adbc8

Browse files
michaelosthegericardoV94
authored andcommitted
Fix mypy errors
1 parent 9e6818b commit 07adbc8

File tree

2 files changed

+6
-7
lines changed

2 files changed

+6
-7
lines changed

pymc/gp/hsgp_approx.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,9 @@ def __init__(
189189
self._drop_first = drop_first
190190
self._m = m
191191
self._m_star = int(np.prod(self._m))
192-
self._L = L
192+
self._L: Optional[pt.TensorVariable] = None
193+
if L is not None:
194+
self._L = pt.as_tensor(L)
193195
self._c = c
194196

195197
super().__init__(mean_func=mean_func, cov_func=cov_func)
@@ -198,13 +200,13 @@ def __add__(self, other):
198200
raise NotImplementedError("Additive HSGPs aren't supported.")
199201

200202
@property
201-
def L(self):
203+
def L(self) -> pt.TensorVariable:
202204
if self._L is None:
203205
raise RuntimeError("Boundaries `L` required but still unset.")
204206
return self._L
205207

206208
@L.setter
207-
def L(self, value):
209+
def L(self, value: TensorLike):
208210
self._L = pt.as_tensor_variable(value)
209211

210212
def prior_linearized(self, Xs: TensorLike):
@@ -290,9 +292,7 @@ def prior_linearized(self, Xs: TensorLike):
290292
# If not provided, use Xs and c to set L
291293
if self._L is None:
292294
assert isinstance(self._c, (numbers.Real, np.ndarray, pt.TensorVariable))
293-
self.L = set_boundary(Xs, self._c)
294-
else:
295-
self.L = self._L
295+
self._L = pt.as_tensor(set_boundary(Xs, self._c))
296296

297297
eigvals = calc_eigenvalues(self.L, self._m, tl=pt)
298298
phi = calc_eigenvectors(Xs, self.L, eigvals, self._m, tl=pt)

scripts/run_mypy.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,6 @@ def check_no_unexpected_results(mypy_lines: Iterator[str]):
105105
Exits the process with non-zero exit code upon unexpected results.
106106
"""
107107
df = mypy_to_pandas(mypy_lines)
108-
109108
all_files = {
110109
str(fp).replace(str(DP_ROOT), "").strip(os.sep).replace(os.sep, "/")
111110
for fp in DP_ROOT.glob("pymc/**/*.py")

0 commit comments

Comments
 (0)