From 15905f4ba8fc57101132856d458a05c40f3464e8 Mon Sep 17 00:00:00 2001 From: Dhruvanshu-Joshi Date: Sat, 13 Apr 2024 19:03:07 +0530 Subject: [PATCH 1/2] Support deriivation of input variable nodes only from output --- pytensor/graph/basic.py | 49 +++++++++++++++++++++++++++++++++++++++ tests/graph/test_basic.py | 15 ++++++++++++ 2 files changed, 64 insertions(+) diff --git a/pytensor/graph/basic.py b/pytensor/graph/basic.py index 82592c0c86..48b7d1a2a6 100644 --- a/pytensor/graph/basic.py +++ b/pytensor/graph/basic.py @@ -936,6 +936,55 @@ def graph_inputs( yield from (r for r in ancestors(graphs, blockers) if r.owner is None) +def explicit_graph_inputs( + graph: Variable[Any, Any] | Iterable[Variable[Any, Any]], +) -> Generator[Variable, None, None]: + """ + Get the root variables needed as inputs to a function that computes `graph` + + Parameters + ---------- + graph : TensorVariable + Output `Variable` instances for which to search backward through + owners. + + Returns + ------- + iterable + Generator of root Variables (without owner) needed to compile a function that evaluates `graphs`. + + Examples + -------- + + .. code-block:: python + + import pytensor + import pytensor.tensor as pt + from pytensor.graph.basic import explicit_graph_inputs + + x = pt.vector('x') + y = pt.constant(2) + z = pt.mul(x*y) + + inputs = list(explicit_graph_inputs(z)) + f = pytensor.function(inputs, z) + eval = f([1, 2, 3]) + + print(eval) + # [2. 4. 6.] + """ + from pytensor.compile.sharedvalue import SharedVariable + + if isinstance(graph, Variable): + graph = [graph] + + return ( + v + for v in graph_inputs(graph) + if isinstance(v, Variable) and not isinstance(v, Constant | SharedVariable) + ) + + def vars_between( ins: Collection[Variable], outs: Iterable[Variable] ) -> Generator[Variable, None, None]: diff --git a/tests/graph/test_basic.py b/tests/graph/test_basic.py index c3a9598e52..5dc9789727 100644 --- a/tests/graph/test_basic.py +++ b/tests/graph/test_basic.py @@ -18,6 +18,7 @@ clone, clone_get_equiv, equal_computations, + explicit_graph_inputs, general_toposort, get_var_by_name, graph_inputs, @@ -522,6 +523,20 @@ def test_graph_inputs(): assert res_list == [r3, r1, r2] +def test_explicit_graph_inputs(): + x = pt.fscalar() + y = pt.constant(2) + z = shared(1) + a = pt.sum(x + y + z) + b = pt.true_div(x, y) + + res = list(explicit_graph_inputs([a])) + res1 = list(explicit_graph_inputs(b)) + + assert res == [x] + assert res1 == [x] + + def test_variables_and_orphans(): r1, r2, r3 = MyVariable(1), MyVariable(2), MyVariable(3) o1 = MyOp(r1, r2) From 3637e3916332f2156f85c1e80a3b197e19817fbe Mon Sep 17 00:00:00 2001 From: Dhruvanshu-Joshi Date: Sun, 28 Apr 2024 19:02:50 +0530 Subject: [PATCH 2/2] Add helper for explicit_graph_inputs --- pytensor/graph/basic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytensor/graph/basic.py b/pytensor/graph/basic.py index 48b7d1a2a6..ac1716ee06 100644 --- a/pytensor/graph/basic.py +++ b/pytensor/graph/basic.py @@ -937,7 +937,7 @@ def graph_inputs( def explicit_graph_inputs( - graph: Variable[Any, Any] | Iterable[Variable[Any, Any]], + graph: Variable | Iterable[Variable], ) -> Generator[Variable, None, None]: """ Get the root variables needed as inputs to a function that computes `graph`