Skip to content

Commit 8af33de

Browse files
committed
Add experimental pyo3-wrapper feature
1 parent 09cbef2 commit 8af33de

File tree

6 files changed

+9011
-0
lines changed

6 files changed

+9011
-0
lines changed

ast/Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@ unparse = ["rustpython-literal"]
1616
visitor = []
1717
all-nodes-with-ranges = []
1818

19+
# This feature is experimental
20+
# It reimplements AST types, but currently both slower than python AST types and limited to use in other API
21+
pyo3-wrapper = ["pyo3"]
22+
1923
[dependencies]
2024
rustpython-parser-core = { workspace = true }
2125
rustpython-literal = { workspace = true, optional = true }

ast/asdl_rs.py

Lines changed: 350 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1042,6 +1042,302 @@ def emit_to_pyo3_with_fields(self, cons, name):
10421042
)
10431043

10441044

1045+
class Pyo3StructVisitor(EmitVisitor):
1046+
"""Visitor to generate type-defs for AST."""
1047+
1048+
def __init__(self, namespace, *args, **kw):
1049+
self.namespace = namespace
1050+
self.borrow = True
1051+
super().__init__(*args, **kw)
1052+
1053+
@property
1054+
def generics(self):
1055+
if self.namespace == "ranged":
1056+
return "<TextRange>"
1057+
elif self.namespace == "located":
1058+
return "<SourceRange>"
1059+
else:
1060+
assert False, self.namespace
1061+
1062+
@property
1063+
def module_name(self):
1064+
name = f"rustpython_ast.{self.namespace}"
1065+
return name
1066+
1067+
@property
1068+
def ref_def(self):
1069+
return "&'static " if self.borrow else ""
1070+
1071+
@property
1072+
def ref(self):
1073+
return "&" if self.borrow else ""
1074+
1075+
def emit_class(self, name, rust_name, simple, base="super::AST"):
1076+
info = self.type_info[name]
1077+
if simple:
1078+
generics = ""
1079+
else:
1080+
generics = self.generics
1081+
if info.is_sum:
1082+
subclass = ", subclass"
1083+
body = ""
1084+
into = f"{rust_name}"
1085+
else:
1086+
subclass = ""
1087+
body = f"(pub {self.ref_def} crate::{rust_name}{generics})"
1088+
into = f"{rust_name}(node)"
1089+
1090+
self.emit(
1091+
textwrap.dedent(
1092+
f"""
1093+
#[pyclass(module="{self.module_name}", name="_{name}", extends={base}, frozen{subclass})]
1094+
#[derive(Clone, Debug)]
1095+
pub struct {rust_name} {body};
1096+
1097+
impl From<{self.ref_def} crate::{rust_name}{generics}> for {rust_name} {{
1098+
fn from({"" if body else "_"}node: {self.ref_def} crate::{rust_name}{generics}) -> Self {{
1099+
{into}
1100+
}}
1101+
}}
1102+
"""
1103+
),
1104+
0,
1105+
)
1106+
if subclass:
1107+
self.emit(
1108+
textwrap.dedent(
1109+
f"""
1110+
#[pymethods]
1111+
impl {rust_name} {{
1112+
#[new]
1113+
fn new() -> PyClassInitializer<Self> {{
1114+
PyClassInitializer::from(AST)
1115+
.add_subclass(Self)
1116+
}}
1117+
1118+
}}
1119+
impl ToPyObject for {rust_name} {{
1120+
fn to_object(&self, py: Python) -> PyObject {{
1121+
let initializer = PyClassInitializer::from(AST)
1122+
.add_subclass(self.clone());
1123+
Py::new(py, initializer).unwrap().into_py(py)
1124+
}}
1125+
}}
1126+
"""
1127+
),
1128+
0,
1129+
)
1130+
else:
1131+
if base != "super::AST":
1132+
add_subclass = f".add_subclass({base})"
1133+
else:
1134+
add_subclass = ""
1135+
self.emit(
1136+
textwrap.dedent(
1137+
f"""
1138+
impl ToPyObject for {rust_name} {{
1139+
fn to_object(&self, py: Python) -> PyObject {{
1140+
let initializer = PyClassInitializer::from(AST)
1141+
{add_subclass}
1142+
.add_subclass(self.clone());
1143+
Py::new(py, initializer).unwrap().into_py(py)
1144+
}}
1145+
}}
1146+
"""
1147+
),
1148+
0,
1149+
)
1150+
1151+
if not subclass:
1152+
self.emit_wrapper(rust_name)
1153+
1154+
def emit_getter(self, owner, type_name):
1155+
self.emit(
1156+
textwrap.dedent(
1157+
f"""
1158+
#[pymethods]
1159+
impl {type_name} {{
1160+
"""
1161+
),
1162+
0,
1163+
)
1164+
1165+
for field in owner.fields:
1166+
self.emit(
1167+
textwrap.dedent(
1168+
f"""
1169+
#[getter]
1170+
#[inline]
1171+
fn get_{field.name}(&self, py: Python) -> PyResult<PyObject> {{
1172+
self.0.{rust_field(field.name)}.to_pyo3_wrapper(py)
1173+
}}
1174+
"""
1175+
),
1176+
3,
1177+
)
1178+
1179+
self.emit(
1180+
textwrap.dedent(
1181+
"""
1182+
}
1183+
"""
1184+
),
1185+
0,
1186+
)
1187+
1188+
def emit_getattr(self, owner, type_name):
1189+
self.emit(
1190+
textwrap.dedent(
1191+
f"""
1192+
#[pymethods]
1193+
impl {type_name} {{
1194+
fn __getattr__(&self, py: Python, key: &str) -> PyResult<PyObject> {{
1195+
let object: Py<PyAny> = match key {{
1196+
"""
1197+
),
1198+
0,
1199+
)
1200+
1201+
for field in owner.fields:
1202+
self.emit(
1203+
f'"{field.name}" => self.0.{rust_field(field.name)}.to_pyo3_wrapper(py)?,',
1204+
3,
1205+
)
1206+
1207+
self.emit(
1208+
textwrap.dedent(
1209+
"""
1210+
_ => todo!(),
1211+
};
1212+
Ok(object)
1213+
}
1214+
}
1215+
"""
1216+
),
1217+
0,
1218+
)
1219+
1220+
def emit_wrapper(self, rust_name):
1221+
self.emit(
1222+
f"""
1223+
impl ToPyo3Wrapper for crate::{rust_name}{self.generics} {{
1224+
#[inline]
1225+
fn to_pyo3_wrapper(&'static self, py: Python) -> PyResult<Py<PyAny>> {{
1226+
Ok({rust_name}(self).to_object(py))
1227+
}}
1228+
}}
1229+
""",
1230+
0,
1231+
)
1232+
1233+
def visitModule(self, mod):
1234+
for dfn in mod.dfns:
1235+
self.visit(dfn)
1236+
1237+
def visitType(self, type, depth=0):
1238+
self.visit(type.value, type, depth)
1239+
1240+
def visitSum(self, sum, type, depth=0):
1241+
rust_name = rust_type_name(type.name)
1242+
1243+
simple = is_simple(sum)
1244+
self.emit_class(type.name, rust_name, simple)
1245+
1246+
if not simple:
1247+
self.emit(
1248+
f"""
1249+
impl ToPyo3Wrapper for crate::{rust_name}{self.generics} {{
1250+
#[inline]
1251+
fn to_pyo3_wrapper(&'static self, py: Python) -> PyResult<Py<PyAny>> {{
1252+
match &self {{
1253+
""",
1254+
0,
1255+
)
1256+
1257+
for cons in sum.types:
1258+
self.emit(f"Self::{cons.name}(cons) => cons.to_pyo3_wrapper(py),", 3)
1259+
1260+
self.emit(
1261+
"""
1262+
}
1263+
}
1264+
}
1265+
""",
1266+
0,
1267+
)
1268+
1269+
for cons in sum.types:
1270+
self.visit(cons, rust_name, simple, depth + 1)
1271+
1272+
def visitProduct(self, product, type, depth=0):
1273+
rust_name = rust_type_name(type.name)
1274+
self.emit_class(type.name, rust_name, False)
1275+
if self.borrow:
1276+
self.emit_getter(product, rust_name)
1277+
1278+
def visitConstructor(self, cons, parent, simple, depth):
1279+
if simple:
1280+
self.emit(
1281+
f"""
1282+
#[pyclass(module="{self.module_name}", name="_{cons.name}", extends={parent})]
1283+
pub struct {parent}{cons.name};
1284+
1285+
impl ToPyObject for {parent}{cons.name} {{
1286+
fn to_object(&self, py: Python) -> PyObject {{
1287+
let initializer = PyClassInitializer::from(AST)
1288+
.add_subclass({parent})
1289+
.add_subclass(Self);
1290+
Py::new(py, initializer).unwrap().into_py(py)
1291+
}}
1292+
}}
1293+
""",
1294+
depth,
1295+
)
1296+
else:
1297+
self.emit_class(
1298+
cons.name,
1299+
f"{parent}{cons.name}",
1300+
simple=False,
1301+
base=parent,
1302+
)
1303+
if self.borrow:
1304+
self.emit_getter(cons, f"{parent}{cons.name}")
1305+
1306+
1307+
class Pyo3PymoduleVisitor(EmitVisitor):
1308+
def __init__(self, namespace, *args, **kw):
1309+
self.namespace = namespace
1310+
super().__init__(*args, **kw)
1311+
1312+
def visitModule(self, mod):
1313+
for dfn in mod.dfns:
1314+
self.visit(dfn)
1315+
1316+
def visitType(self, type, depth=0):
1317+
self.visit(type.value, type.name, depth)
1318+
1319+
def visitProduct(self, product, name, depth=0):
1320+
rust_name = rust_type_name(name)
1321+
self.emit_fields(name, rust_name, False)
1322+
1323+
def visitSum(self, sum, name, depth):
1324+
rust_name = rust_type_name(name)
1325+
simple = is_simple(sum)
1326+
self.emit_fields(name, rust_name, True)
1327+
1328+
for cons in sum.types:
1329+
self.visit(cons, name, simple, depth)
1330+
1331+
def visitConstructor(self, cons, parent, simple, depth):
1332+
rust_name = rust_type_name(parent) + rust_type_name(cons.name)
1333+
self.emit_fields(cons.name, rust_name, simple)
1334+
1335+
def emit_fields(self, name, rust_name, simple):
1336+
self.emit(
1337+
f"super::init_type::<{rust_name}, crate::generic::{rust_name}>(py, m)?;", 1
1338+
)
1339+
1340+
10451341
class StdlibClassDefVisitor(EmitVisitor):
10461342
def visitModule(self, mod):
10471343
for dfn in mod.dfns:
@@ -1452,6 +1748,58 @@ def write_to_pyo3_simple(type_info, f):
14521748
)
14531749

14541750

1751+
def write_pyo3_wrapper(mod, type_info, namespace, f):
1752+
Pyo3StructVisitor(namespace, f, type_info).visit(mod)
1753+
1754+
if namespace == "located":
1755+
for type_info in type_info.values():
1756+
if not type_info.is_simple or not type_info.is_sum:
1757+
continue
1758+
1759+
rust_name = type_info.rust_sum_name
1760+
f.write(
1761+
f"""
1762+
impl ToPyo3Wrapper for crate::generic::{rust_name} {{
1763+
#[inline]
1764+
fn to_pyo3_wrapper(&self, py: Python) -> PyResult<Py<PyAny>> {{
1765+
match &self {{
1766+
""",
1767+
)
1768+
for cons in type_info.type.value.types:
1769+
f.write(
1770+
f"Self::{cons.name} => Ok({rust_name}{cons.name}.to_object(py)),",
1771+
)
1772+
f.write(
1773+
"""
1774+
}
1775+
}
1776+
}
1777+
""",
1778+
)
1779+
1780+
for cons in type_info.type.value.types:
1781+
f.write(
1782+
f"""
1783+
impl ToPyo3Wrapper for crate::generic::{rust_name}{cons.name} {{
1784+
#[inline]
1785+
fn to_pyo3_wrapper(&self, py: Python) -> PyResult<Py<PyAny>> {{
1786+
Ok({rust_name}{cons.name}.to_object(py))
1787+
}}
1788+
}}
1789+
"""
1790+
)
1791+
1792+
f.write(
1793+
"""
1794+
pub fn add_to_module(py: Python, m: &PyModule) -> PyResult<()> {
1795+
super::init_module(py, m)?;
1796+
"""
1797+
)
1798+
1799+
Pyo3PymoduleVisitor(namespace, f, type_info).visit(mod)
1800+
f.write("Ok(())\n}")
1801+
1802+
14551803
def write_ast_mod(mod, type_info, f):
14561804
f.write(
14571805
textwrap.dedent(
@@ -1498,6 +1846,8 @@ def main(
14981846
("located", p(write_located_def, mod, type_info)),
14991847
("visitor", p(write_visitor_def, mod, type_info)),
15001848
("to_pyo3", p(write_to_pyo3, mod, type_info)),
1849+
("pyo3_wrapper_located", p(write_pyo3_wrapper, mod, type_info, "located")),
1850+
("pyo3_wrapper_ranged", p(write_pyo3_wrapper, mod, type_info, "ranged")),
15011851
]:
15021852
with (ast_dir / f"{filename}.rs").open("w") as f:
15031853
f.write(auto_gen_msg)

0 commit comments

Comments
 (0)