diff --git a/pytensor/graph/basic.py b/pytensor/graph/basic.py index 82592c0c86..ac1716ee06 100644 --- a/pytensor/graph/basic.py +++ b/pytensor/graph/basic.py @@ -936,6 +936,55 @@ def graph_inputs( yield from (r for r in ancestors(graphs, blockers) if r.owner is None) +def explicit_graph_inputs( + graph: Variable | Iterable[Variable], +) -> Generator[Variable, None, None]: + """ + Get the root variables needed as inputs to a function that computes `graph` + + Parameters + ---------- + graph : TensorVariable + Output `Variable` instances for which to search backward through + owners. + + Returns + ------- + iterable + Generator of root Variables (without owner) needed to compile a function that evaluates `graphs`. + + Examples + -------- + + .. code-block:: python + + import pytensor + import pytensor.tensor as pt + from pytensor.graph.basic import explicit_graph_inputs + + x = pt.vector('x') + y = pt.constant(2) + z = pt.mul(x*y) + + inputs = list(explicit_graph_inputs(z)) + f = pytensor.function(inputs, z) + eval = f([1, 2, 3]) + + print(eval) + # [2. 4. 6.] + """ + from pytensor.compile.sharedvalue import SharedVariable + + if isinstance(graph, Variable): + graph = [graph] + + return ( + v + for v in graph_inputs(graph) + if isinstance(v, Variable) and not isinstance(v, Constant | SharedVariable) + ) + + def vars_between( ins: Collection[Variable], outs: Iterable[Variable] ) -> Generator[Variable, None, None]: diff --git a/tests/graph/test_basic.py b/tests/graph/test_basic.py index c3a9598e52..5dc9789727 100644 --- a/tests/graph/test_basic.py +++ b/tests/graph/test_basic.py @@ -18,6 +18,7 @@ clone, clone_get_equiv, equal_computations, + explicit_graph_inputs, general_toposort, get_var_by_name, graph_inputs, @@ -522,6 +523,20 @@ def test_graph_inputs(): assert res_list == [r3, r1, r2] +def test_explicit_graph_inputs(): + x = pt.fscalar() + y = pt.constant(2) + z = shared(1) + a = pt.sum(x + y + z) + b = pt.true_div(x, y) + + res = list(explicit_graph_inputs([a])) + res1 = list(explicit_graph_inputs(b)) + + assert res == [x] + assert res1 == [x] + + def test_variables_and_orphans(): r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3) o1 = MyOp(r1, r2)