Skip to content

Commit 38731ad

Browse files
ferrinericardoV94
authored andcommitted
move clone_replace to a separate file
1 parent 2445327 commit 38731ad

File tree

14 files changed

+213
-206
lines changed

14 files changed

+213
-206
lines changed

pytensor/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,8 @@ def disable_log_handler(logger=pytensor_logger, handler=logging_default_handler)
7373
__api_version__ = 1
7474

7575
# isort: off
76-
from pytensor.graph.basic import Variable, clone_replace
76+
from pytensor.graph.basic import Variable
77+
from pytensor.graph.replace import clone_replace
7778

7879
# isort: on
7980

pytensor/compile/builders.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,13 @@
1616
Constant,
1717
NominalVariable,
1818
Variable,
19-
clone_replace,
2019
graph_inputs,
2120
io_connection_pattern,
2221
)
2322
from pytensor.graph.fg import FunctionGraph
2423
from pytensor.graph.null_type import NullType
2524
from pytensor.graph.op import HasInnerGraph, Op
25+
from pytensor.graph.replace import clone_replace
2626
from pytensor.graph.rewriting.basic import in2out, node_rewriter
2727
from pytensor.graph.utils import MissingInputError
2828
from pytensor.tensor.rewriting.shape import ShapeFeature

pytensor/graph/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@
77
Constant,
88
graph_inputs,
99
clone,
10-
clone_replace,
1110
ancestors,
1211
)
12+
from pytensor.graph.replace import clone_replace
1313
from pytensor.graph.op import Op
1414
from pytensor.graph.type import Type
1515
from pytensor.graph.fg import FunctionGraph

pytensor/graph/basic.py

Lines changed: 0 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1151,53 +1151,6 @@ def clone_get_equiv(
11511151
return memo
11521152

11531153

1154-
def clone_replace(
1155-
output: Collection[Variable],
1156-
replace: Optional[
1157-
Union[Iterable[Tuple[Variable, Variable]], Dict[Variable, Variable]]
1158-
] = None,
1159-
**rebuild_kwds,
1160-
) -> List[Variable]:
1161-
"""Clone a graph and replace subgraphs within it.
1162-
1163-
It returns a copy of the initial subgraph with the corresponding
1164-
substitutions.
1165-
1166-
Parameters
1167-
----------
1168-
output
1169-
PyTensor expression that represents the computational graph.
1170-
replace
1171-
Dictionary describing which subgraphs should be replaced by what.
1172-
rebuild_kwds
1173-
Keywords to `rebuild_collect_shared`.
1174-
1175-
"""
1176-
from pytensor.compile.function.pfunc import rebuild_collect_shared
1177-
1178-
items: Union[List[Tuple[Variable, Variable]], Tuple[Tuple[Variable, Variable], ...]]
1179-
if isinstance(replace, dict):
1180-
items = list(replace.items())
1181-
elif isinstance(replace, (list, tuple)):
1182-
items = replace
1183-
elif replace is None:
1184-
items = []
1185-
else:
1186-
raise ValueError(
1187-
"replace is neither a dictionary, list, "
1188-
f"tuple or None ! The value provided is {replace},"
1189-
f"of type {type(replace)}"
1190-
)
1191-
tmp_replace = [(x, x.type()) for x, y in items]
1192-
new_replace = [(x, y) for ((_, x), (_, y)) in zip(tmp_replace, items)]
1193-
_, _outs, _ = rebuild_collect_shared(output, [], tmp_replace, [], **rebuild_kwds)
1194-
1195-
# TODO Explain why we call it twice ?!
1196-
_, outs, _ = rebuild_collect_shared(_outs, [], new_replace, [], **rebuild_kwds)
1197-
1198-
return cast(List[Variable], outs)
1199-
1200-
12011154
def general_toposort(
12021155
outputs: Iterable[T],
12031156
deps: Callable[[T], Union[OrderedSet, List[T]]],

pytensor/graph/replace.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
from typing import (
2+
Collection,
3+
Dict,
4+
Iterable,
5+
List,
6+
Optional,
7+
Sequence,
8+
Tuple,
9+
Union,
10+
cast,
11+
)
12+
13+
from pytensor.graph.basic import Constant, Variable
14+
15+
16+
def clone_replace(
17+
output: Collection[Variable],
18+
replace: Optional[
19+
Union[Iterable[Tuple[Variable, Variable]], Dict[Variable, Variable]]
20+
] = None,
21+
**rebuild_kwds,
22+
) -> List[Variable]:
23+
"""Clone a graph and replace subgraphs within it.
24+
25+
It returns a copy of the initial subgraph with the corresponding
26+
substitutions.
27+
28+
Parameters
29+
----------
30+
output
31+
PyTensor expression that represents the computational graph.
32+
replace
33+
Dictionary describing which subgraphs should be replaced by what.
34+
rebuild_kwds
35+
Keywords to `rebuild_collect_shared`.
36+
37+
"""
38+
from pytensor.compile.function.pfunc import rebuild_collect_shared
39+
40+
items: Union[List[Tuple[Variable, Variable]], Tuple[Tuple[Variable, Variable], ...]]
41+
if isinstance(replace, dict):
42+
items = list(replace.items())
43+
elif isinstance(replace, (list, tuple)):
44+
items = replace
45+
elif replace is None:
46+
items = []
47+
else:
48+
raise ValueError(
49+
"replace is neither a dictionary, list, "
50+
f"tuple or None ! The value provided is {replace},"
51+
f"of type {type(replace)}"
52+
)
53+
tmp_replace = [(x, x.type()) for x, y in items]
54+
new_replace = [(x, y) for ((_, x), (_, y)) in zip(tmp_replace, items)]
55+
_, _outs, _ = rebuild_collect_shared(output, [], tmp_replace, [], **rebuild_kwds)
56+
57+
# TODO Explain why we call it twice ?!
58+
_, outs, _ = rebuild_collect_shared(_outs, [], new_replace, [], **rebuild_kwds)
59+
60+
return cast(List[Variable], outs)

pytensor/ifelse.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,9 @@
2020
from pytensor import as_symbolic
2121
from pytensor.compile import optdb
2222
from pytensor.configdefaults import config
23-
from pytensor.graph.basic import Apply, Variable, clone_replace, is_in_ancestors
23+
from pytensor.graph.basic import Apply, Variable, is_in_ancestors
2424
from pytensor.graph.op import _NoPythonOp
25+
from pytensor.graph.replace import clone_replace
2526
from pytensor.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter
2627
from pytensor.graph.type import HasDataType, HasShape
2728
from pytensor.tensor.shape import Reshape, Shape, SpecifyShape, Unbroadcast

pytensor/scan/basic.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,9 @@
66
from pytensor.compile.function.pfunc import construct_pfunc_ins_and_outs
77
from pytensor.compile.sharedvalue import SharedVariable, collect_new_shareds
88
from pytensor.configdefaults import config
9-
from pytensor.graph.basic import Constant, Variable, clone_replace, graph_inputs
9+
from pytensor.graph.basic import Constant, Variable, graph_inputs
1010
from pytensor.graph.op import get_test_value
11+
from pytensor.graph.replace import clone_replace
1112
from pytensor.graph.utils import MissingInputError, TestValueError
1213
from pytensor.scan.op import Scan, ScanInfo
1314
from pytensor.scan.utils import expand_empty, safe_new, until

pytensor/scan/op.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,13 +65,13 @@
6565
from pytensor.graph.basic import (
6666
Apply,
6767
Variable,
68-
clone_replace,
6968
equal_computations,
7069
graph_inputs,
7170
io_connection_pattern,
7271
)
7372
from pytensor.graph.features import NoOutputFromInplace
7473
from pytensor.graph.op import HasInnerGraph, Op
74+
from pytensor.graph.replace import clone_replace
7575
from pytensor.graph.utils import InconsistencyError, MissingInputError
7676
from pytensor.link.c.basic import CLinker
7777
from pytensor.link.c.exceptions import MissingGXX

pytensor/scan/rewriting.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
Apply,
1919
Constant,
2020
Variable,
21-
clone_replace,
2221
equal_computations,
2322
graph_inputs,
2423
io_toposort,
@@ -28,6 +27,7 @@
2827
from pytensor.graph.features import ReplaceValidate
2928
from pytensor.graph.fg import FunctionGraph
3029
from pytensor.graph.op import compute_test_value
30+
from pytensor.graph.replace import clone_replace
3131
from pytensor.graph.rewriting.basic import GraphRewriter, in2out, node_rewriter
3232
from pytensor.graph.rewriting.db import EquilibriumDB, SequenceDB
3333
from pytensor.graph.type import HasShape

pytensor/scan/utils.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,9 @@
1414
from pytensor import tensor as at
1515
from pytensor.compile.profiling import ProfileStats
1616
from pytensor.configdefaults import config
17-
from pytensor.graph.basic import (
18-
Constant,
19-
Variable,
20-
clone_replace,
21-
equal_computations,
22-
graph_inputs,
23-
)
17+
from pytensor.graph.basic import Constant, Variable, equal_computations, graph_inputs
2418
from pytensor.graph.op import get_test_value
19+
from pytensor.graph.replace import clone_replace
2520
from pytensor.graph.type import HasDataType
2621
from pytensor.graph.utils import TestValueError
2722
from pytensor.tensor.basic import AllocEmpty, cast

tests/graph/test_basic.py

Lines changed: 2 additions & 137 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import numpy as np
55
import pytest
66

7-
from pytensor import config, function, shared
7+
from pytensor import shared
88
from pytensor import tensor as at
99
from pytensor.graph.basic import (
1010
Apply,
@@ -15,7 +15,6 @@
1515
as_string,
1616
clone,
1717
clone_get_equiv,
18-
clone_replace,
1918
equal_computations,
2019
general_toposort,
2120
get_var_by_name,
@@ -30,18 +29,9 @@
3029
from pytensor.graph.op import Op
3130
from pytensor.graph.type import Type
3231
from pytensor.tensor.math import max_and_argmax
33-
from pytensor.tensor.type import (
34-
TensorType,
35-
dvector,
36-
fvector,
37-
iscalars,
38-
matrix,
39-
scalars,
40-
vector,
41-
)
32+
from pytensor.tensor.type import TensorType, iscalars, matrix, scalars, vector
4233
from pytensor.tensor.type_other import NoneConst
4334
from pytensor.tensor.var import TensorVariable
44-
from tests import unittest_tools as utt
4535
from tests.graph.utils import MyInnerGraphOp
4636

4737

@@ -557,131 +547,6 @@ def test_get_var_by_name():
557547
assert res == exp_res
558548

559549

560-
class TestCloneReplace:
561-
def test_cloning_no_replace_strict_copy_inputs(self):
562-
# This has nothing to do with scan, but it refers to the clone
563-
# function that scan uses internally and that pfunc uses now and
564-
# that users might want to use
565-
x = vector("x")
566-
y = vector("y")
567-
z = shared(0.25)
568-
569-
f1 = z * (x + y) ** 2 + 5
570-
f2 = clone_replace(f1, replace=None, rebuild_strict=True, copy_inputs_over=True)
571-
f2_inp = graph_inputs([f2])
572-
573-
assert z in f2_inp
574-
assert x in f2_inp
575-
assert y in f2_inp
576-
577-
def test_cloning_no_replace_strict_not_copy_inputs(self):
578-
# This has nothing to do with scan, but it refers to the clone
579-
# function that scan uses internally and that pfunc uses now and
580-
# that users might want to use
581-
x = vector("x")
582-
y = vector("y")
583-
z = shared(0.25)
584-
585-
f1 = z * (x + y) ** 2 + 5
586-
f2 = clone_replace(
587-
f1, replace=None, rebuild_strict=True, copy_inputs_over=False
588-
)
589-
f2_inp = graph_inputs([f2])
590-
591-
assert z not in f2_inp
592-
assert x not in f2_inp
593-
assert y not in f2_inp
594-
595-
def test_cloning_replace_strict_copy_inputs(self):
596-
# This has nothing to do with scan, but it refers to the clone
597-
# function that scan uses internally and that pfunc uses now and
598-
# that users might want to use
599-
x = vector("x")
600-
y = vector("y")
601-
y2 = vector("y2")
602-
z = shared(0.25)
603-
604-
f1 = z * (x + y) ** 2 + 5
605-
f2 = clone_replace(
606-
f1, replace={y: y2}, rebuild_strict=True, copy_inputs_over=True
607-
)
608-
f2_inp = graph_inputs([f2])
609-
assert z in f2_inp
610-
assert x in f2_inp
611-
assert y2 in f2_inp
612-
613-
def test_cloning_replace_not_strict_copy_inputs(self):
614-
# This has nothing to do with scan, but it refers to the clone
615-
# function that scan uses internally and that pfunc uses now and
616-
# that users might want to use
617-
x = vector("x")
618-
y = fvector("y")
619-
y2 = dvector("y2")
620-
z = shared(0.25)
621-
622-
f1 = z * (x + y) ** 2 + 5
623-
f2 = clone_replace(
624-
f1, replace={y: y2}, rebuild_strict=False, copy_inputs_over=True
625-
)
626-
f2_inp = graph_inputs([f2])
627-
assert z in f2_inp
628-
assert x in f2_inp
629-
assert y2 in f2_inp
630-
631-
def test_cloning_replace_strict_not_copy_inputs(self):
632-
# This has nothing to do with scan, but it refers to the clone
633-
# function that scan uses internally and that pfunc uses now and
634-
# that users might want to use
635-
x = vector("x")
636-
y = vector("y")
637-
y2 = vector("y2")
638-
z = shared(0.25)
639-
640-
f1 = z * (x + y) ** 2 + 5
641-
f2 = clone_replace(
642-
f1, replace=[(y, y2)], rebuild_strict=True, copy_inputs_over=False
643-
)
644-
f2_inp = graph_inputs([f2])
645-
assert z not in f2_inp
646-
assert x not in f2_inp
647-
assert y2 not in f2_inp
648-
649-
def test_cloning_replace_not_strict_not_copy_inputs(self):
650-
# This has nothing to do with scan, but it refers to the clone
651-
# function that scan uses internally and that pfunc uses now and
652-
# that users might want to use
653-
x = vector("x")
654-
y = fvector("y")
655-
y2 = dvector("y2")
656-
z = shared(0.25)
657-
658-
f1 = z * (x + y) ** 2 + 5
659-
f2 = clone_replace(
660-
f1, replace=[(y, y2)], rebuild_strict=False, copy_inputs_over=False
661-
)
662-
f2_inp = graph_inputs([f2])
663-
assert z not in f2_inp
664-
assert x not in f2_inp
665-
assert y2 not in f2_inp
666-
667-
def test_clone(self):
668-
def test(x, y, mention_y):
669-
if mention_y:
670-
d = 0.1 + 0 * y
671-
else:
672-
d = 0.1
673-
out = clone_replace(y, replace={x: x + d})
674-
return function([], out)()
675-
676-
x = shared(np.asarray(0.0, dtype=config.floatX))
677-
utt.assert_allclose(
678-
test(x, at.sum((x + 1) ** 2), mention_y=False), 1.21000003815
679-
)
680-
utt.assert_allclose(
681-
test(x, at.sum((x + 1) ** 2), mention_y=True), 1.21000003815
682-
)
683-
684-
685550
def test_clone_new_inputs():
686551
"""Make sure that `Apply.clone_with_new_inputs` properly handles `Type` changes."""
687552

0 commit comments

Comments
 (0)