Skip to content

Commit f6469b9

Browse files
committed
Implement some egglog rewrites
1 parent af7ed24 commit f6469b9

File tree

11 files changed

+1422
-0
lines changed

11 files changed

+1422
-0
lines changed
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
try:
2+
import egglog
3+
except ImportError:
4+
raise RuntimeError("egglog must be manually installed")
5+
6+
try:
7+
import frozendict
8+
except ImportError:
9+
raise RuntimeError("frozendict must be manually installed")
10+
11+
# Register rewrites
12+
import pytensor.sandbox.scrambled.rewrites.basic
13+
import pytensor.sandbox.scrambled.rewrites.op
14+
import pytensor.sandbox.scrambled.rewrites.tensorify
15+
from pytensor.sandbox.scrambled.basic import egraph
16+
17+
18+
__all__ = ("egraph",)

pytensor/sandbox/scrambled/basic.py

Lines changed: 306 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,306 @@
1+
from __future__ import annotations
2+
3+
from egglog import (
4+
EGraph,
5+
Expr,
6+
PyObject,
7+
String,
8+
StringLike,
9+
convert,
10+
converter,
11+
i64,
12+
i64Like,
13+
)
14+
15+
from pytensor import Variable
16+
from pytensor.graph import FunctionGraph
17+
18+
19+
egraph = EGraph()
20+
21+
22+
tensorify_ruleset = egraph.ruleset("tensorify")
23+
24+
25+
@egraph.class_
26+
class Int(Expr):
27+
def __init__(self, value: i64Like) -> None:
28+
...
29+
30+
@classmethod
31+
def var(cls, name: StringLike) -> Int:
32+
...
33+
34+
def __add__(self, other: Int) -> Int:
35+
...
36+
37+
def __sub__(self, other: Int) -> Int:
38+
...
39+
40+
def __eq__(self, i: Int) -> Int:
41+
...
42+
43+
# Egglog doesn't allow to override __ne__ for now
44+
# # def __ne__(self, i: Int) -> Int: ...
45+
46+
def __gt__(self, i: Int) -> Int:
47+
...
48+
49+
def __ge__(self, i: Int) -> Int:
50+
...
51+
52+
def __lt__(self, i: Int) -> Int:
53+
...
54+
55+
def __le__(self, i: Int) -> Int:
56+
...
57+
58+
@property
59+
def tensorify(self) -> PyObject:
60+
...
61+
62+
63+
converter(i64, Int, Int)
64+
65+
66+
@egraph.class_
67+
class IntTuple(Expr):
68+
def __init__(self, head: Int) -> None:
69+
...
70+
71+
@classmethod
72+
def empty(cls) -> IntTuple:
73+
...
74+
75+
@egraph.method(cost=1000)
76+
@classmethod
77+
def from_range(cls, i: Int, n: Int) -> IntTuple:
78+
...
79+
80+
def __add__(self, other: IntTuple) -> IntTuple:
81+
...
82+
83+
def __getitem__(self, i: Int) -> Int:
84+
...
85+
86+
@egraph.method(cost=1000)
87+
def length(self) -> Int:
88+
...
89+
90+
def insert(self, idx: Int, value: Int) -> IntTuple:
91+
...
92+
93+
def pop(self, idx: Int) -> IntTuple:
94+
...
95+
96+
@property
97+
def tensorify(self) -> PyObject:
98+
...
99+
100+
101+
converter(int, IntTuple, lambda i: IntTuple(Int(i64(i))))
102+
converter(i64, IntTuple, lambda i: IntTuple(Int(i)))
103+
converter(Int, IntTuple, lambda i: IntTuple(i))
104+
converter(
105+
tuple,
106+
IntTuple,
107+
lambda x: (
108+
IntTuple(convert(x[0], Int)) + convert(x[1:], IntTuple)
109+
if len(x) > 1
110+
else (IntTuple(convert(x[0], Int)) if x else IntTuple.empty())
111+
),
112+
)
113+
# converter(list, IntTuple, lambda x: convert(tuple(x), IntTuple)) # Not working!
114+
115+
116+
@egraph.class_
117+
class Tensor(Expr):
118+
def __init__(self, name: StringLike, shape: IntTuple = IntTuple.empty()) -> None:
119+
...
120+
121+
@classmethod
122+
def constant(cls, value: Int, shape: IntTuple = IntTuple.empty()) -> Tensor:
123+
...
124+
125+
@property
126+
def tensorify(self) -> PyObject:
127+
...
128+
129+
def __add__(self, other: Tensor) -> Tensor:
130+
...
131+
132+
def __sub__(self, other: Tensor) -> Tensor:
133+
...
134+
135+
def __mul__(self, other: Tensor) -> Tensor:
136+
...
137+
138+
def __pow__(self, other: Tensor) -> Tensor:
139+
...
140+
141+
def __neg__(self) -> Tensor:
142+
...
143+
144+
145+
@egraph.class_
146+
class TensorTuple(Expr):
147+
def __init__(self, value: Tensor) -> None:
148+
...
149+
150+
def __add__(self, other: TensorTuple) -> TensorTuple:
151+
...
152+
153+
@classmethod
154+
def empty(cls) -> TensorTuple:
155+
...
156+
157+
def __add__(self, other: TensorTuple) -> TensorTuple:
158+
...
159+
160+
def __getitem__(self, i: Int) -> Tensor:
161+
...
162+
163+
# __xor__ is used as a shorcut for broadcasting shape tuples
164+
def __xor__(self, other: TensorTuple) -> TensorTuple:
165+
...
166+
167+
@egraph.method(cost=1000)
168+
def length(self) -> Int:
169+
...
170+
171+
def insert(self, idx: Int, value: Tensor) -> TensorTuple:
172+
...
173+
174+
def pop(self, idx: Int) -> TensorTuple:
175+
...
176+
177+
@property
178+
def tensorify(self) -> PyObject:
179+
...
180+
181+
@egraph.method(cost=1000)
182+
@classmethod
183+
def from_int_tuple(cls, int_tuple: IntTuple) -> TensorTuple:
184+
...
185+
186+
@egraph.method(cost=1000)
187+
@classmethod
188+
def from_tensor_shape(
189+
cls, sh: TensorTuple, static_sh: IntTuple, idx: Int
190+
) -> TensorTuple:
191+
...
192+
193+
@property
194+
def tensorify(self) -> PyObject:
195+
...
196+
197+
198+
converter(i64, Tensor, lambda i: Tensor.constant(Int(i)))
199+
converter(int, Tensor, lambda i: Tensor.constant(Int(i64(i))))
200+
converter(i64, TensorTuple, lambda i: TensorTuple(Tensor.constant(Int(i))))
201+
converter(int, TensorTuple, lambda i: TensorTuple(Tensor.constant(Int(i64(i)))))
202+
converter(
203+
tuple,
204+
TensorTuple,
205+
lambda x: (
206+
TensorTuple(convert(x[0], Tensor)) + convert(x[1:], TensorTuple)
207+
if len(x) > 1
208+
else (TensorTuple(convert(x[0], Tensor)) if x else TensorTuple.empty())
209+
),
210+
)
211+
212+
213+
@egraph.class_
214+
class UnaryInOp(Expr):
215+
def __call__(self, x: Tensor) -> Tensor:
216+
...
217+
218+
@property
219+
def tensorify(self) -> PyObject:
220+
...
221+
222+
223+
@egraph.class_
224+
class BinaryInOp(Expr):
225+
def __call__(self, x: Tensor, y: Tensor) -> Tensor:
226+
...
227+
228+
@property
229+
def tensorify(self) -> PyObject:
230+
...
231+
232+
233+
@egraph.class_
234+
class VariadicInOp(Expr):
235+
def __call__(self, vars: TensorTuple) -> Tensor:
236+
...
237+
238+
@property
239+
def tensorify(self) -> PyObject:
240+
...
241+
242+
243+
@egraph.class_
244+
class VariadicInOutOp(Expr):
245+
def __call__(self, vars: TensorTuple) -> TensorTuple:
246+
...
247+
248+
@property
249+
def tensorify(self) -> PyObject:
250+
...
251+
252+
253+
@egraph.class_
254+
class ScalarOp(Expr):
255+
...
256+
257+
@property
258+
def tensorify(self) -> PyObject:
259+
...
260+
261+
262+
def eggify(*vars: Variable | FunctionGraph) -> tuple[Expr]:
263+
from pytensor.sandbox.scrambled.eggify.basic import eggify_fg
264+
265+
if len(vars) > 1 or isinstance(vars[0], Variable):
266+
fg = FunctionGraph(outputs=vars, clone=False)
267+
else:
268+
[fg] = vars
269+
return eggify_fg(fg)
270+
271+
272+
def rewrite_exprs(*exprs: Expr, epochs=100, verbose=False) -> tuple[Expr]:
273+
with egraph:
274+
initial_costs = []
275+
for expr in exprs:
276+
egraph.register(expr)
277+
initial_costs.append(egraph.extract(expr, include_cost=True)[1])
278+
279+
egraph.run(epochs)
280+
281+
new_exprs = []
282+
for expr, initial_cost in zip(exprs, initial_costs):
283+
new_expr, final_cost = egraph.extract(expr, include_cost=True)
284+
new_exprs.append(new_expr)
285+
if verbose:
286+
print(f"Cost: {initial_cost} -> {final_cost}")
287+
print(new_expr)
288+
print("")
289+
return tuple(new_exprs)
290+
291+
292+
def tensorify(*exprs: Expr) -> tuple[Variable]:
293+
with egraph:
294+
for expr in exprs:
295+
egraph.register(expr)
296+
egraph.run(100, ruleset=tensorify_ruleset)
297+
return tuple(egraph.eval(expr.tensorify) for expr in exprs)
298+
299+
300+
def egg_rewrite(
301+
*variables: Variable, epochs: int = 100, verbose: bool = False
302+
) -> tuple[Variable]:
303+
var_exprs = eggify(*variables)
304+
new_var_exprs = rewrite_exprs(*var_exprs, epochs=epochs, verbose=verbose)
305+
# TODO: Assert all root variables where present in fg
306+
return tensorify(*new_var_exprs)

pytensor/sandbox/scrambled/eggify/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)