Skip to content

Commit 94b3ac7

Browse files
Rename required_inputs_graph to explicit_input_graphs
1 parent 1785fab commit 94b3ac7

File tree

2 files changed

+7
-7
lines changed

2 files changed

+7
-7
lines changed

pytensor/graph/basic.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -936,15 +936,15 @@ def graph_inputs(
936936
yield from (r for r in ancestors(graphs, blockers) if r.owner is None)
937937

938938

939-
def required_graph_inputs(
939+
def explicit_graph_inputs(
940940
graph: Variable[Any, Any] | Iterable[Variable[Any, Any]],
941941
) -> Generator[Variable, None, None]:
942942
"""
943943
Get the inputs into PyTensor variables
944944
945945
Parameters
946946
----------
947-
graph: PyTensor `Variable` instances
947+
graph: TensorVariable
948948
Output `Variable` instances from which to search backward through
949949
owners.
950950
@@ -969,7 +969,7 @@ def required_graph_inputs(
969969
# x [id A]
970970
# 2 [id B]
971971
972-
pytensor.dprint(required_graph_inputs([z]))
972+
pytensor.dprint(explicit_graph_inputs([z]))
973973
# x [id A]
974974
"""
975975
from pytensor.compile.sharedvalue import SharedVariable

tests/graph/test_basic.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,13 @@
1818
clone,
1919
clone_get_equiv,
2020
equal_computations,
21+
explicit_graph_inputs,
2122
general_toposort,
2223
get_var_by_name,
2324
graph_inputs,
2425
io_toposort,
2526
list_of_nodes,
2627
orphans_between,
27-
required_graph_inputs,
2828
truncated_graph_inputs,
2929
variable_depends_on,
3030
vars_between,
@@ -523,15 +523,15 @@ def test_graph_inputs():
523523
assert res_list == [r3, r1, r2]
524524

525525

526-
def test_required_graph_inputs():
526+
def test_explicit_graph_inputs():
527527
x = pt.fscalar()
528528
y = pt.constant(2)
529529
z = shared(1)
530530
a = pt.sum(x + y + z)
531531
b = pt.true_div(x, y)
532532

533-
res = list(required_graph_inputs([a]))
534-
res1 = list(required_graph_inputs(b))
533+
res = list(explicit_graph_inputs([a]))
534+
res1 = list(explicit_graph_inputs(b))
535535
assert res, res1 == [x]
536536

537537

0 commit comments

Comments
 (0)