Skip to content

Commit 880a57a

Browse files
Dhruvanshu-JoshiricardoV94
authored andcommitted
remove list_of_nodes in favor of similar applys_between
1 parent 8c157a2 commit 880a57a

File tree

3 files changed

+2
-46
lines changed

3 files changed

+2
-46
lines changed

pytensor/graph/basic.py

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1789,38 +1789,6 @@ def view_roots(node: Variable) -> list[Variable]:
17891789
return [node]
17901790

17911791

1792-
def list_of_nodes(
1793-
inputs: Collection[Variable], outputs: Iterable[Variable]
1794-
) -> list[Apply]:
1795-
r"""Return the `Apply` nodes of the graph between `inputs` and `outputs`.
1796-
1797-
Parameters
1798-
----------
1799-
inputs : list of Variable
1800-
Input `Variable`\s.
1801-
outputs : list of Variable
1802-
Output `Variable`\s.
1803-
1804-
"""
1805-
1806-
def expand(o: Apply) -> list[Apply]:
1807-
return [
1808-
inp.owner
1809-
for inp in o.inputs
1810-
if inp.owner and not any(i in inp.owner.outputs for i in inputs)
1811-
]
1812-
1813-
return list(
1814-
cast(
1815-
Iterable[Apply],
1816-
walk(
1817-
[o.owner for o in outputs if o.owner],
1818-
expand,
1819-
),
1820-
)
1821-
)
1822-
1823-
18241792
def apply_depends_on(apply: Apply, depends_on: Apply | Collection[Apply]) -> bool:
18251793
"""Determine if any `depends_on` is in the graph given by ``apply``.
18261794

pytensor/scalar/basic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from pytensor import printing
2525
from pytensor.configdefaults import config
2626
from pytensor.gradient import DisconnectedType, grad_undefined
27-
from pytensor.graph.basic import Apply, Constant, Variable, clone, list_of_nodes
27+
from pytensor.graph.basic import Apply, Constant, Variable, applys_between, clone
2828
from pytensor.graph.fg import FunctionGraph
2929
from pytensor.graph.op import HasInnerGraph
3030
from pytensor.graph.rewriting.basic import MergeOptimizer
@@ -4125,7 +4125,7 @@ def c_support_code_apply(self, node, name):
41254125

41264126
def prepare_node(self, node, storage_map, compute_map, impl):
41274127
if impl not in self.prepare_node_called:
4128-
for n in list_of_nodes(self.inputs, self.outputs):
4128+
for n in applys_between(self.inputs, self.outputs):
41294129
n.op.prepare_node(n, None, None, impl)
41304130
self.prepare_node_called.add(impl)
41314131

tests/graph/test_basic.py

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
get_var_by_name,
2424
graph_inputs,
2525
io_toposort,
26-
list_of_nodes,
2726
orphans_between,
2827
truncated_graph_inputs,
2928
variable_depends_on,
@@ -567,17 +566,6 @@ def test_ops():
567566
assert res_list == [o3.owner, o2.owner, o1.owner]
568567

569568

570-
def test_list_of_nodes():
571-
r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3)
572-
o1 = MyOp(r1, r2)
573-
o1.name = "o1"
574-
o2 = MyOp(r3, o1)
575-
o2.name = "o2"
576-
577-
res = list_of_nodes([r1, r2], [o2])
578-
assert res == [o2.owner, o1.owner]
579-
580-
581569
def test_apply_depends_on():
582570
r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3)
583571
o1 = MyOp(r1, r2)

0 commit comments

Comments
 (0)