Skip to content

Commit 611dcc2

Browse files
authored
rustpython_ast + pyo3 (#25)
1 parent 53de75e commit 611dcc2

File tree

6 files changed

+3658
-0
lines changed

6 files changed

+3658
-0
lines changed

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ log = "0.4.16"
3030
num-complex = "0.4.0"
3131
num-bigint = "0.4.3"
3232
num-traits = "0.2"
33+
pyo3 = { version = "0.18.3" }
3334
rand = "0.8.5"
3435
serde = "1.0"
3536
static_assertions = "1.1"

ast/Cargo.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,6 @@ rustpython-literal = { workspace = true, optional = true }
2323
is-macro = { workspace = true }
2424
num-bigint = { workspace = true }
2525
static_assertions = "1.1.0"
26+
num-complex = { workspace = true }
27+
once_cell = { workspace = true }
28+
pyo3 = { workspace = true, optional = true, features = ["num-bigint", "num-complex"] }

ast/asdl_rs.py

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -937,6 +937,111 @@ def emit_located_impl(self, info):
937937
)
938938

939939

940+
class ToPyo3AstVisitor(EmitVisitor):
941+
"""Visitor to generate type-defs for AST."""
942+
943+
def __init__(self, namespace, *args, **kw):
944+
super().__init__(*args, **kw)
945+
self.namespace = namespace
946+
947+
@property
948+
def generics(self):
949+
if self.namespace == "ranged":
950+
return "<TextRange>"
951+
elif self.namespace == "located":
952+
return "<SourceRange>"
953+
else:
954+
assert False, self.namespace
955+
956+
def visitModule(self, mod):
957+
for dfn in mod.dfns:
958+
self.visit(dfn)
959+
960+
def visitType(self, type, depth=0):
961+
self.visit(type.value, type.name, depth)
962+
963+
def visitProduct(self, product, name, depth=0):
964+
rust_name = rust_type_name(name)
965+
self.emit_to_pyo3_with_fields(product, rust_name)
966+
967+
def visitSum(self, sum, name, depth=0):
968+
rust_name = rust_type_name(name)
969+
simple = is_simple(sum)
970+
if is_simple(sum):
971+
return
972+
973+
self.emit(
974+
f"""
975+
impl ToPyo3Ast for crate::generic::{rust_name}{self.generics} {{
976+
#[inline]
977+
fn to_pyo3_ast(&self, {"_" if simple else ""}py: Python) -> PyResult<Py<PyAny>> {{
978+
let instance = match &self {{
979+
""",
980+
0,
981+
)
982+
for cons in sum.types:
983+
self.emit(
984+
f"""crate::{rust_name}::{cons.name}(cons) => cons.to_pyo3_ast(py)?,""",
985+
depth,
986+
)
987+
self.emit(
988+
"""
989+
};
990+
Ok(instance)
991+
}
992+
}
993+
""",
994+
0,
995+
)
996+
997+
for cons in sum.types:
998+
self.visit(cons, rust_name, depth)
999+
1000+
def visitConstructor(self, cons, parent, depth):
1001+
self.emit_to_pyo3_with_fields(cons, f"{parent}{cons.name}")
1002+
1003+
def emit_to_pyo3_with_fields(self, cons, name):
1004+
if cons.fields:
1005+
self.emit(
1006+
f"""
1007+
impl ToPyo3Ast for crate::{name}{self.generics} {{
1008+
#[inline]
1009+
fn to_pyo3_ast(&self, py: Python) -> PyResult<Py<PyAny>> {{
1010+
let cache = Self::py_type_cache().get().unwrap();
1011+
let instance = cache.0.call1(py, (
1012+
""",
1013+
0,
1014+
)
1015+
for field in cons.fields:
1016+
self.emit(
1017+
f"self.{rust_field(field.name)}.to_pyo3_ast(py)?,",
1018+
3,
1019+
)
1020+
self.emit(
1021+
"""
1022+
))?;
1023+
Ok(instance)
1024+
}
1025+
}
1026+
""",
1027+
0,
1028+
)
1029+
else:
1030+
self.emit(
1031+
f"""
1032+
impl ToPyo3Ast for crate::{name}{self.generics} {{
1033+
#[inline]
1034+
fn to_pyo3_ast(&self, py: Python) -> PyResult<Py<PyAny>> {{
1035+
let cache = Self::py_type_cache().get().unwrap();
1036+
let instance = cache.0.call0(py)?;
1037+
Ok(instance)
1038+
}}
1039+
}}
1040+
""",
1041+
0,
1042+
)
1043+
1044+
9401045
class StdlibClassDefVisitor(EmitVisitor):
9411046
def visitModule(self, mod):
9421047
for dfn in mod.dfns:
@@ -1271,6 +1376,82 @@ def write_located_def(mod, type_info, f):
12711376
LocatedDefVisitor(f, type_info).visit(mod)
12721377

12731378

1379+
def write_pyo3_node(type_info, f):
1380+
def write(info: TypeInfo):
1381+
rust_name = info.rust_sum_name
1382+
if info.is_simple:
1383+
generics = ""
1384+
else:
1385+
generics = "<R>"
1386+
1387+
f.write(
1388+
textwrap.dedent(
1389+
f"""
1390+
impl{generics} Pyo3Node for crate::generic::{rust_name}{generics} {{
1391+
#[inline]
1392+
fn py_type_cache() -> &'static OnceCell<(Py<PyAny>, Py<PyAny>)> {{
1393+
static PY_TYPE: OnceCell<(Py<PyAny>, Py<PyAny>)> = OnceCell::new();
1394+
&PY_TYPE
1395+
}}
1396+
}}
1397+
"""
1398+
),
1399+
)
1400+
1401+
for info in type_info.values():
1402+
write(info)
1403+
1404+
1405+
def write_to_pyo3(mod, type_info, f):
1406+
write_pyo3_node(type_info, f)
1407+
write_to_pyo3_simple(type_info, f)
1408+
1409+
for namespace in ("ranged", "located"):
1410+
ToPyo3AstVisitor(namespace, f, type_info).visit(mod)
1411+
1412+
f.write(
1413+
"""
1414+
pub fn init(py: Python) -> PyResult<()> {
1415+
let ast_module = PyModule::import(py, "_ast")?;
1416+
"""
1417+
)
1418+
1419+
for info in type_info.values():
1420+
rust_name = info.rust_sum_name
1421+
f.write(f"cache_py_type::<crate::generic::{rust_name}>(ast_module)?;\n")
1422+
f.write("Ok(())\n}")
1423+
1424+
1425+
def write_to_pyo3_simple(type_info, f):
1426+
for type_info in type_info.values():
1427+
if not type_info.is_sum:
1428+
continue
1429+
if not type_info.is_simple:
1430+
continue
1431+
1432+
rust_name = type_info.rust_sum_name
1433+
f.write(
1434+
f"""
1435+
impl ToPyo3Ast for crate::generic::{rust_name} {{
1436+
#[inline]
1437+
fn to_pyo3_ast(&self, _py: Python) -> PyResult<Py<PyAny>> {{
1438+
let cell = match &self {{
1439+
""",
1440+
)
1441+
for cons in type_info.type.value.types:
1442+
f.write(
1443+
f"""crate::{rust_name}::{cons.name} => crate::{rust_name}{cons.name}::py_type_cache(),""",
1444+
)
1445+
f.write(
1446+
"""
1447+
};
1448+
Ok(cell.get().unwrap().1.clone())
1449+
}
1450+
}
1451+
""",
1452+
)
1453+
1454+
12741455
def write_ast_mod(mod, type_info, f):
12751456
f.write(
12761457
textwrap.dedent(
@@ -1316,6 +1497,7 @@ def main(
13161497
("ranged", p(write_ranged_def, mod, type_info)),
13171498
("located", p(write_located_def, mod, type_info)),
13181499
("visitor", p(write_visitor_def, mod, type_info)),
1500+
("to_pyo3", p(write_to_pyo3, mod, type_info)),
13191501
]:
13201502
with (ast_dir / f"{filename}.rs").open("w") as f:
13211503
f.write(auto_gen_msg)

0 commit comments

Comments
 (0)