Skip to content

Commit 41f0181

Browse files
michaelosthegericardoV94
authored andcommitted
Fix typing in opvi and model_graph
1 parent 07adbc8 commit 41f0181

File tree

3 files changed

+24
-14
lines changed

3 files changed

+24
-14
lines changed

pymc/model_graph.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -76,12 +76,12 @@ def _expand(x):
7676
return reversed(_filter_non_parameter_inputs(x))
7777
return []
7878

79-
parents = {
80-
VarName(get_var_name(x))
81-
for x in walk(nodes=_filter_non_parameter_inputs(var), expand=_expand)
79+
parents = set()
80+
for x in walk(nodes=_filter_non_parameter_inputs(var), expand=_expand):
8281
# Only consider nodes that are in the named model variables.
83-
if x.name and x.name in self._all_var_names
84-
}
82+
vname = getattr(x, "name", None)
83+
if isinstance(vname, str) and vname in self._all_var_names:
84+
parents.add(VarName(vname))
8585

8686
return parents
8787

@@ -113,7 +113,7 @@ def vars_to_plot(self, var_names: Optional[Iterable[VarName]] = None) -> List[Va
113113
selected_ancestors.add(self.model.rvs_to_values[var])
114114

115115
# ordering of self._all_var_names is important
116-
return [VarName(var.name) for var in selected_ancestors]
116+
return [VarName(get_var_name(var)) for var in selected_ancestors]
117117

118118
def make_compute_graph(
119119
self, var_names: Optional[Iterable[VarName]] = None

pymc/variational/opvi.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@
5151
import itertools
5252
import warnings
5353

54-
from typing import Any
54+
from typing import Any, overload
5555

5656
import numpy as np
5757
import pytensor
@@ -980,17 +980,29 @@ def symbolic_random(self):
980980
"""
981981
raise NotImplementedError
982982

983-
@pytensor.config.change_flags(compute_test_value="off")
983+
@overload
984+
def set_size_and_deterministic(
985+
self, node: Variable, s, d: bool, more_replacements: dict | None = None
986+
) -> Variable:
987+
...
988+
989+
@overload
984990
def set_size_and_deterministic(
985-
self, node: Variable, s, d: bool, more_replacements: dict = None
991+
self, node: list[Variable], s, d: bool, more_replacements: dict | None = None
986992
) -> list[Variable]:
993+
...
994+
995+
@pytensor.config.change_flags(compute_test_value="off")
996+
def set_size_and_deterministic(
997+
self, node: Variable | list[Variable], s, d: bool, more_replacements: dict | None = None
998+
) -> Variable | list[Variable]:
987999
"""*Dev* - after node is sampled via :func:`symbolic_sample_over_posterior` or
9881000
:func:`symbolic_single_sample` new random generator can be allocated and applied to node
9891001
9901002
Parameters
9911003
----------
992-
node: :class:`Variable`
993-
PyTensor node with symbolically applied VI replacements
1004+
node
1005+
PyTensor node(s) with symbolically applied VI replacements
9941006
s: scalar
9951007
desired number of samples
9961008
d: bool or int
@@ -1000,7 +1012,7 @@ def set_size_and_deterministic(
10001012
10011013
Returns
10021014
-------
1003-
:class:`Variable` with applied replacements, ready to use
1015+
:class:`Variable` or list with applied replacements, ready to use
10041016
"""
10051017

10061018
flat2rand = self.make_size_and_deterministic_replacements(s, d, more_replacements)

scripts/run_mypy.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,9 @@
4343
pymc/model/fgraph.py
4444
pymc/model/transform/basic.py
4545
pymc/model/transform/conditioning.py
46-
pymc/model_graph.py
4746
pymc/printing.py
4847
pymc/pytensorf.py
4948
pymc/sampling/jax.py
50-
pymc/variational/opvi.py
5149
"""
5250

5351

0 commit comments

Comments
 (0)