Skip to content

Commit 1785fab

Browse files
Support derivation of only variable graph_inputs
1 parent 8bf4332 commit 1785fab

File tree

2 files changed

+24
-30
lines changed

2 files changed

+24
-30
lines changed

pytensor/graph/basic.py

Lines changed: 21 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -938,58 +938,50 @@ def graph_inputs(
938938

939939
def required_graph_inputs(
940940
graph: Variable[Any, Any] | Iterable[Variable[Any, Any]],
941-
) -> list[Variable[Any, Any]]:
941+
) -> Generator[Variable, None, None]:
942942
"""
943943
Get the inputs into PyTensor variables
944944
945945
Parameters
946946
----------
947-
graph: PyTensor `Variable` instances
947+
graph: PyTensor `Variable` instances
948948
Output `Variable` instances from which to search backward through
949949
owners.
950950
951951
Returns
952952
-------
953-
List of tensor variables that are input nodes with no owner, in the order
953+
Tensor variables that are input nodes with no owner, in the order
954954
found by a left-recursive depth-first search started at the nodes in `graphs`.
955955
956956
Examples
957957
--------
958-
>>> import pytensor as pt
959-
>>> x=pt.vector('x')
960-
>>> y=pt.constant('y')
961-
>>> z = pt.mul(x*y)
962-
>>> required_graph_inputs([a])
963-
[[[ 0 1 2 3]
964-
[ 4 5 6 7]
965-
[ 8 9 10 11]]
966-
967-
[[12 13 14 15]
968-
[16 17 18 19]
969-
[20 21 22 23]]]
970-
971-
972-
>>> pt.matrix_transpose(x).eval()
973-
[[[ 0 4 8]
974-
[ 1 5 9]
975-
[ 2 6 10]
976-
[ 3 7 11]]
977-
978-
[[12 16 20]
979-
[13 17 21]
980-
[14 18 22]
981-
[15 19 23]]]
958+
959+
.. code-block:: python
960+
961+
import pytensor
962+
import pytensor.tensor as pt
963+
964+
x = pt.vector('x')
965+
y = pt.constant(2)
966+
z = pt.mul(x*y)
967+
968+
pytensor.dprint(graph_inputs([z]))
969+
# x [id A]
970+
# 2 [id B]
971+
972+
pytensor.dprint(required_graph_inputs([z]))
973+
# x [id A]
982974
"""
983975
from pytensor.compile.sharedvalue import SharedVariable
984976

985977
if isinstance(graph, Variable):
986978
graph = [graph]
987979

988-
return [
980+
return (
989981
v
990982
for v in graph_inputs(graph)
991983
if isinstance(v, Variable) and not isinstance(v, Constant | SharedVariable)
992-
]
984+
)
993985

994986

995987
def vars_between(

tests/graph/test_basic.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -528,9 +528,11 @@ def test_required_graph_inputs():
528528
y = pt.constant(2)
529529
z = shared(1)
530530
a = pt.sum(x + y + z)
531+
b = pt.true_div(x, y)
531532

532533
res = list(required_graph_inputs([a]))
533-
assert res == [x]
534+
res1 = list(required_graph_inputs(b))
535+
assert res, res1 == [x]
534536

535537

536538
def test_variables_and_orphans():

0 commit comments

Comments
 (0)