Skip to content

Commit 9df4db9

Browse files
committed
Add experimental pyo3-wrapper feature
1 parent 611dcc2 commit 9df4db9

File tree

6 files changed

+9038
-0
lines changed

6 files changed

+9038
-0
lines changed

ast/Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@ unparse = ["rustpython-literal"]
1616
visitor = []
1717
all-nodes-with-ranges = []
1818

19+
# This feature is experimental
20+
# It reimplements AST types, but currently both slower than python AST types and limited to use in other API
21+
pyo3-wrapper = ["pyo3"]
22+
1923
[dependencies]
2024
rustpython-parser-core = { workspace = true }
2125
rustpython-literal = { workspace = true, optional = true }

ast/asdl_rs.py

Lines changed: 349 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1041,6 +1041,299 @@ def emit_to_pyo3_with_fields(self, cons, name):
10411041
0,
10421042
)
10431043

1044+
class Pyo3StructVisitor(EmitVisitor):
1045+
"""Visitor to generate type-defs for AST."""
1046+
1047+
def __init__(self, namespace, *args, **kw):
1048+
self.namespace = namespace
1049+
self.borrow = True
1050+
super().__init__(*args, **kw)
1051+
1052+
@property
1053+
def generics(self):
1054+
if self.namespace == "ranged":
1055+
return "<TextRange>"
1056+
elif self.namespace == "located":
1057+
return "<SourceRange>"
1058+
else:
1059+
assert False, self.namespace
1060+
1061+
@property
1062+
def module_name(self):
1063+
name = f"rustpython_ast.{self.namespace}"
1064+
return name
1065+
1066+
@property
1067+
def ref_def(self):
1068+
return "&'static " if self.borrow else ""
1069+
1070+
@property
1071+
def ref(self):
1072+
return "&" if self.borrow else ""
1073+
1074+
def emit_class(self, name, rust_name, simple, base="super::AST"):
1075+
info = self.type_info[name]
1076+
if simple:
1077+
generics = ""
1078+
else:
1079+
generics = self.generics
1080+
if info.is_sum:
1081+
subclass = ", subclass"
1082+
body = ""
1083+
into = f"{rust_name}"
1084+
else:
1085+
subclass = ""
1086+
body = f"(pub {self.ref_def} crate::{rust_name}{generics})"
1087+
into = f"{rust_name}(node)"
1088+
1089+
self.emit(
1090+
textwrap.dedent(
1091+
f"""
1092+
#[pyclass(module="{self.module_name}", name="_{name}", extends={base}, frozen{subclass})]
1093+
#[derive(Clone, Debug)]
1094+
pub struct {rust_name} {body};
1095+
1096+
impl From<{self.ref_def} crate::{rust_name}{generics}> for {rust_name} {{
1097+
fn from({"" if body else "_"}node: {self.ref_def} crate::{rust_name}{generics}) -> Self {{
1098+
{into}
1099+
}}
1100+
}}
1101+
"""
1102+
),
1103+
0,
1104+
)
1105+
if subclass:
1106+
self.emit(
1107+
textwrap.dedent(
1108+
f"""
1109+
#[pymethods]
1110+
impl {rust_name} {{
1111+
#[new]
1112+
fn new() -> PyClassInitializer<Self> {{
1113+
PyClassInitializer::from(AST)
1114+
.add_subclass(Self)
1115+
}}
1116+
1117+
}}
1118+
impl ToPyObject for {rust_name} {{
1119+
fn to_object(&self, py: Python) -> PyObject {{
1120+
let initializer = PyClassInitializer::from(AST)
1121+
.add_subclass(self.clone());
1122+
Py::new(py, initializer).unwrap().into_py(py)
1123+
}}
1124+
}}
1125+
"""
1126+
),
1127+
0,
1128+
)
1129+
else:
1130+
if base != "super::AST":
1131+
add_subclass = f".add_subclass({base})"
1132+
else:
1133+
add_subclass = ""
1134+
self.emit(
1135+
textwrap.dedent(
1136+
f"""
1137+
impl ToPyObject for {rust_name} {{
1138+
fn to_object(&self, py: Python) -> PyObject {{
1139+
let initializer = PyClassInitializer::from(AST)
1140+
{add_subclass}
1141+
.add_subclass(self.clone());
1142+
Py::new(py, initializer).unwrap().into_py(py)
1143+
}}
1144+
}}
1145+
"""
1146+
),
1147+
0,
1148+
)
1149+
1150+
if not subclass:
1151+
self.emit_wrapper(rust_name)
1152+
1153+
def emit_getter(self, owner, type_name):
1154+
self.emit(
1155+
textwrap.dedent(
1156+
f"""
1157+
#[pymethods]
1158+
impl {type_name} {{
1159+
"""
1160+
),
1161+
0,
1162+
)
1163+
1164+
for field in owner.fields:
1165+
self.emit(
1166+
textwrap.dedent(
1167+
f"""
1168+
#[getter]
1169+
#[inline]
1170+
fn get_{field.name}(&self, py: Python) -> PyResult<PyObject> {{
1171+
self.0.{rust_field(field.name)}.to_pyo3_wrapper(py)
1172+
}}
1173+
"""
1174+
),
1175+
3,
1176+
)
1177+
1178+
self.emit(
1179+
textwrap.dedent(
1180+
"""
1181+
}
1182+
"""
1183+
),
1184+
0,
1185+
)
1186+
1187+
def emit_getattr(self, owner, type_name):
1188+
self.emit(
1189+
textwrap.dedent(
1190+
f"""
1191+
#[pymethods]
1192+
impl {type_name} {{
1193+
fn __getattr__(&self, py: Python, key: &str) -> PyResult<PyObject> {{
1194+
let object: Py<PyAny> = match key {{
1195+
"""
1196+
),
1197+
0,
1198+
)
1199+
1200+
for field in owner.fields:
1201+
self.emit(
1202+
f'"{field.name}" => self.0.{rust_field(field.name)}.to_pyo3_wrapper(py)?,',
1203+
3,
1204+
)
1205+
1206+
self.emit(
1207+
textwrap.dedent(
1208+
"""
1209+
_ => todo!(),
1210+
};
1211+
Ok(object)
1212+
}
1213+
}
1214+
"""
1215+
),
1216+
0,
1217+
)
1218+
1219+
def emit_wrapper(self, rust_name):
1220+
self.emit(
1221+
f"""
1222+
impl ToPyo3Wrapper for crate::{rust_name}{self.generics} {{
1223+
#[inline]
1224+
fn to_pyo3_wrapper(&'static self, py: Python) -> PyResult<Py<PyAny>> {{
1225+
Ok({rust_name}(self).to_object(py))
1226+
}}
1227+
}}
1228+
""",
1229+
0,
1230+
)
1231+
1232+
def visitModule(self, mod):
1233+
for dfn in mod.dfns:
1234+
self.visit(dfn)
1235+
1236+
def visitType(self, type, depth=0):
1237+
self.visit(type.value, type, depth)
1238+
1239+
def visitSum(self, sum, type, depth=0):
1240+
rust_name = rust_type_name(type.name)
1241+
1242+
simple = is_simple(sum)
1243+
self.emit_class(type.name, rust_name, simple)
1244+
1245+
if not simple:
1246+
self.emit(
1247+
f"""
1248+
impl ToPyo3Wrapper for crate::{rust_name}{self.generics} {{
1249+
#[inline]
1250+
fn to_pyo3_wrapper(&'static self, py: Python) -> PyResult<Py<PyAny>> {{
1251+
match &self {{
1252+
""",
1253+
0,
1254+
)
1255+
1256+
for cons in sum.types:
1257+
self.emit(f"Self::{cons.name}(cons) => cons.to_pyo3_wrapper(py),", 3)
1258+
1259+
self.emit(
1260+
"""
1261+
}
1262+
}
1263+
}
1264+
""",
1265+
0,
1266+
)
1267+
1268+
for cons in sum.types:
1269+
self.visit(cons, rust_name, simple, depth + 1)
1270+
1271+
def visitProduct(self, product, type, depth=0):
1272+
rust_name = rust_type_name(type.name)
1273+
self.emit_class(type.name, rust_name, False)
1274+
if self.borrow:
1275+
self.emit_getter(product, rust_name)
1276+
1277+
def visitConstructor(self, cons, parent, simple, depth):
1278+
if simple:
1279+
self.emit(
1280+
f"""
1281+
#[pyclass(module="{self.module_name}", name="_{cons.name}", extends={parent})]
1282+
pub struct {parent}{cons.name};
1283+
1284+
impl ToPyObject for {parent}{cons.name} {{
1285+
fn to_object(&self, py: Python) -> PyObject {{
1286+
let initializer = PyClassInitializer::from(AST)
1287+
.add_subclass({parent})
1288+
.add_subclass(Self);
1289+
Py::new(py, initializer).unwrap().into_py(py)
1290+
}}
1291+
}}
1292+
""",
1293+
depth,
1294+
)
1295+
else:
1296+
self.emit_class(
1297+
cons.name,
1298+
f"{parent}{cons.name}",
1299+
simple=False,
1300+
base=parent,
1301+
)
1302+
if self.borrow:
1303+
self.emit_getter(cons, f"{parent}{cons.name}")
1304+
1305+
1306+
class Pyo3PymoduleVisitor(EmitVisitor):
1307+
def __init__(self, namespace, *args, **kw):
1308+
self.namespace = namespace
1309+
super().__init__(*args, **kw)
1310+
1311+
def visitModule(self, mod):
1312+
for dfn in mod.dfns:
1313+
self.visit(dfn)
1314+
1315+
def visitType(self, type, depth=0):
1316+
self.visit(type.value, type.name, depth)
1317+
1318+
def visitProduct(self, product, name, depth=0):
1319+
rust_name = rust_type_name(name)
1320+
self.emit_fields(name, rust_name, False)
1321+
1322+
def visitSum(self, sum, name, depth):
1323+
rust_name = rust_type_name(name)
1324+
simple = is_simple(sum)
1325+
self.emit_fields(name, rust_name, True)
1326+
1327+
for cons in sum.types:
1328+
self.visit(cons, name, simple, depth)
1329+
1330+
def visitConstructor(self, cons, parent, simple, depth):
1331+
rust_name = rust_type_name(parent) + rust_type_name(cons.name)
1332+
self.emit_fields(cons.name, rust_name, simple)
1333+
1334+
def emit_fields(self, name, rust_name, simple):
1335+
self.emit(f"super::init_type::<{rust_name}, crate::generic::{rust_name}>(py, ast_module, m)?;", 1)
1336+
10441337

10451338
class StdlibClassDefVisitor(EmitVisitor):
10461339
def visitModule(self, mod):
@@ -1452,6 +1745,60 @@ def write_to_pyo3_simple(type_info, f):
14521745
)
14531746

14541747

1748+
def write_pyo3_wrapper(mod, type_info, namespace, f):
1749+
Pyo3StructVisitor(namespace, f, type_info).visit(mod)
1750+
1751+
if namespace == "located":
1752+
for type_info in type_info.values():
1753+
if not type_info.is_simple or not type_info.is_sum:
1754+
continue
1755+
1756+
rust_name = type_info.rust_sum_name
1757+
f.write(
1758+
f"""
1759+
impl ToPyo3Wrapper for crate::generic::{rust_name} {{
1760+
#[inline]
1761+
fn to_pyo3_wrapper(&self, py: Python) -> PyResult<Py<PyAny>> {{
1762+
match &self {{
1763+
""",
1764+
)
1765+
for cons in type_info.type.value.types:
1766+
f.write(
1767+
f"Self::{cons.name} => Ok({rust_name}{cons.name}.to_object(py)),",
1768+
)
1769+
f.write(
1770+
"""
1771+
}
1772+
}
1773+
}
1774+
""",
1775+
)
1776+
1777+
for cons in type_info.type.value.types:
1778+
f.write(
1779+
f"""
1780+
impl ToPyo3Wrapper for crate::generic::{rust_name}{cons.name} {{
1781+
#[inline]
1782+
fn to_pyo3_wrapper(&self, py: Python) -> PyResult<Py<PyAny>> {{
1783+
Ok({rust_name}{cons.name}.to_object(py))
1784+
}}
1785+
}}
1786+
"""
1787+
)
1788+
1789+
f.write(
1790+
"""
1791+
pub fn add_to_module(py: Python, m: &PyModule) -> PyResult<()> {
1792+
super::init_module(py, m)?;
1793+
1794+
let ast_module = PyModule::import(py, "_ast")?;
1795+
"""
1796+
)
1797+
1798+
Pyo3PymoduleVisitor(namespace, f, type_info).visit(mod)
1799+
f.write("Ok(())\n}")
1800+
1801+
14551802
def write_ast_mod(mod, type_info, f):
14561803
f.write(
14571804
textwrap.dedent(
@@ -1498,6 +1845,8 @@ def main(
14981845
("located", p(write_located_def, mod, type_info)),
14991846
("visitor", p(write_visitor_def, mod, type_info)),
15001847
("to_pyo3", p(write_to_pyo3, mod, type_info)),
1848+
("pyo3_wrapper_located", p(write_pyo3_wrapper, mod, type_info, "located")),
1849+
("pyo3_wrapper_ranged", p(write_pyo3_wrapper, mod, type_info, "ranged")),
15011850
]:
15021851
with (ast_dir / f"{filename}.rs").open("w") as f:
15031852
f.write(auto_gen_msg)

0 commit comments

Comments
 (0)