@@ -1041,6 +1041,299 @@ def emit_to_pyo3_with_fields(self, cons, name):
1041
1041
0 ,
1042
1042
)
1043
1043
1044
+ class Pyo3StructVisitor (EmitVisitor ):
1045
+ """Visitor to generate type-defs for AST."""
1046
+
1047
+ def __init__ (self , namespace , * args , ** kw ):
1048
+ self .namespace = namespace
1049
+ self .borrow = True
1050
+ super ().__init__ (* args , ** kw )
1051
+
1052
+ @property
1053
+ def generics (self ):
1054
+ if self .namespace == "ranged" :
1055
+ return "<TextRange>"
1056
+ elif self .namespace == "located" :
1057
+ return "<SourceRange>"
1058
+ else :
1059
+ assert False , self .namespace
1060
+
1061
+ @property
1062
+ def module_name (self ):
1063
+ name = f"rustpython_ast.{ self .namespace } "
1064
+ return name
1065
+
1066
+ @property
1067
+ def ref_def (self ):
1068
+ return "&'static " if self .borrow else ""
1069
+
1070
+ @property
1071
+ def ref (self ):
1072
+ return "&" if self .borrow else ""
1073
+
1074
+ def emit_class (self , name , rust_name , simple , base = "super::AST" ):
1075
+ info = self .type_info [name ]
1076
+ if simple :
1077
+ generics = ""
1078
+ else :
1079
+ generics = self .generics
1080
+ if info .is_sum :
1081
+ subclass = ", subclass"
1082
+ body = ""
1083
+ into = f"{ rust_name } "
1084
+ else :
1085
+ subclass = ""
1086
+ body = f"(pub { self .ref_def } crate::{ rust_name } { generics } )"
1087
+ into = f"{ rust_name } (node)"
1088
+
1089
+ self .emit (
1090
+ textwrap .dedent (
1091
+ f"""
1092
+ #[pyclass(module="{ self .module_name } ", name="_{ name } ", extends={ base } , frozen{ subclass } )]
1093
+ #[derive(Clone, Debug)]
1094
+ pub struct { rust_name } { body } ;
1095
+
1096
+ impl From<{ self .ref_def } crate::{ rust_name } { generics } > for { rust_name } {{
1097
+ fn from({ "" if body else "_" } node: { self .ref_def } crate::{ rust_name } { generics } ) -> Self {{
1098
+ { into }
1099
+ }}
1100
+ }}
1101
+ """
1102
+ ),
1103
+ 0 ,
1104
+ )
1105
+ if subclass :
1106
+ self .emit (
1107
+ textwrap .dedent (
1108
+ f"""
1109
+ #[pymethods]
1110
+ impl { rust_name } {{
1111
+ #[new]
1112
+ fn new() -> PyClassInitializer<Self> {{
1113
+ PyClassInitializer::from(AST)
1114
+ .add_subclass(Self)
1115
+ }}
1116
+
1117
+ }}
1118
+ impl ToPyObject for { rust_name } {{
1119
+ fn to_object(&self, py: Python) -> PyObject {{
1120
+ let initializer = PyClassInitializer::from(AST)
1121
+ .add_subclass(self.clone());
1122
+ Py::new(py, initializer).unwrap().into_py(py)
1123
+ }}
1124
+ }}
1125
+ """
1126
+ ),
1127
+ 0 ,
1128
+ )
1129
+ else :
1130
+ if base != "super::AST" :
1131
+ add_subclass = f".add_subclass({ base } )"
1132
+ else :
1133
+ add_subclass = ""
1134
+ self .emit (
1135
+ textwrap .dedent (
1136
+ f"""
1137
+ impl ToPyObject for { rust_name } {{
1138
+ fn to_object(&self, py: Python) -> PyObject {{
1139
+ let initializer = PyClassInitializer::from(AST)
1140
+ { add_subclass }
1141
+ .add_subclass(self.clone());
1142
+ Py::new(py, initializer).unwrap().into_py(py)
1143
+ }}
1144
+ }}
1145
+ """
1146
+ ),
1147
+ 0 ,
1148
+ )
1149
+
1150
+ if not subclass :
1151
+ self .emit_wrapper (rust_name )
1152
+
1153
+ def emit_getter (self , owner , type_name ):
1154
+ self .emit (
1155
+ textwrap .dedent (
1156
+ f"""
1157
+ #[pymethods]
1158
+ impl { type_name } {{
1159
+ """
1160
+ ),
1161
+ 0 ,
1162
+ )
1163
+
1164
+ for field in owner .fields :
1165
+ self .emit (
1166
+ textwrap .dedent (
1167
+ f"""
1168
+ #[getter]
1169
+ #[inline]
1170
+ fn get_{ field .name } (&self, py: Python) -> PyResult<PyObject> {{
1171
+ self.0.{ rust_field (field .name )} .to_pyo3_wrapper(py)
1172
+ }}
1173
+ """
1174
+ ),
1175
+ 3 ,
1176
+ )
1177
+
1178
+ self .emit (
1179
+ textwrap .dedent (
1180
+ """
1181
+ }
1182
+ """
1183
+ ),
1184
+ 0 ,
1185
+ )
1186
+
1187
+ def emit_getattr (self , owner , type_name ):
1188
+ self .emit (
1189
+ textwrap .dedent (
1190
+ f"""
1191
+ #[pymethods]
1192
+ impl { type_name } {{
1193
+ fn __getattr__(&self, py: Python, key: &str) -> PyResult<PyObject> {{
1194
+ let object: Py<PyAny> = match key {{
1195
+ """
1196
+ ),
1197
+ 0 ,
1198
+ )
1199
+
1200
+ for field in owner .fields :
1201
+ self .emit (
1202
+ f'"{ field .name } " => self.0.{ rust_field (field .name )} .to_pyo3_wrapper(py)?,' ,
1203
+ 3 ,
1204
+ )
1205
+
1206
+ self .emit (
1207
+ textwrap .dedent (
1208
+ """
1209
+ _ => todo!(),
1210
+ };
1211
+ Ok(object)
1212
+ }
1213
+ }
1214
+ """
1215
+ ),
1216
+ 0 ,
1217
+ )
1218
+
1219
+ def emit_wrapper (self , rust_name ):
1220
+ self .emit (
1221
+ f"""
1222
+ impl ToPyo3Wrapper for crate::{ rust_name } { self .generics } {{
1223
+ #[inline]
1224
+ fn to_pyo3_wrapper(&'static self, py: Python) -> PyResult<Py<PyAny>> {{
1225
+ Ok({ rust_name } (self).to_object(py))
1226
+ }}
1227
+ }}
1228
+ """ ,
1229
+ 0 ,
1230
+ )
1231
+
1232
+ def visitModule (self , mod ):
1233
+ for dfn in mod .dfns :
1234
+ self .visit (dfn )
1235
+
1236
+ def visitType (self , type , depth = 0 ):
1237
+ self .visit (type .value , type , depth )
1238
+
1239
+ def visitSum (self , sum , type , depth = 0 ):
1240
+ rust_name = rust_type_name (type .name )
1241
+
1242
+ simple = is_simple (sum )
1243
+ self .emit_class (type .name , rust_name , simple )
1244
+
1245
+ if not simple :
1246
+ self .emit (
1247
+ f"""
1248
+ impl ToPyo3Wrapper for crate::{ rust_name } { self .generics } {{
1249
+ #[inline]
1250
+ fn to_pyo3_wrapper(&'static self, py: Python) -> PyResult<Py<PyAny>> {{
1251
+ match &self {{
1252
+ """ ,
1253
+ 0 ,
1254
+ )
1255
+
1256
+ for cons in sum .types :
1257
+ self .emit (f"Self::{ cons .name } (cons) => cons.to_pyo3_wrapper(py)," , 3 )
1258
+
1259
+ self .emit (
1260
+ """
1261
+ }
1262
+ }
1263
+ }
1264
+ """ ,
1265
+ 0 ,
1266
+ )
1267
+
1268
+ for cons in sum .types :
1269
+ self .visit (cons , rust_name , simple , depth + 1 )
1270
+
1271
+ def visitProduct (self , product , type , depth = 0 ):
1272
+ rust_name = rust_type_name (type .name )
1273
+ self .emit_class (type .name , rust_name , False )
1274
+ if self .borrow :
1275
+ self .emit_getter (product , rust_name )
1276
+
1277
+ def visitConstructor (self , cons , parent , simple , depth ):
1278
+ if simple :
1279
+ self .emit (
1280
+ f"""
1281
+ #[pyclass(module="{ self .module_name } ", name="_{ cons .name } ", extends={ parent } )]
1282
+ pub struct { parent } { cons .name } ;
1283
+
1284
+ impl ToPyObject for { parent } { cons .name } {{
1285
+ fn to_object(&self, py: Python) -> PyObject {{
1286
+ let initializer = PyClassInitializer::from(AST)
1287
+ .add_subclass({ parent } )
1288
+ .add_subclass(Self);
1289
+ Py::new(py, initializer).unwrap().into_py(py)
1290
+ }}
1291
+ }}
1292
+ """ ,
1293
+ depth ,
1294
+ )
1295
+ else :
1296
+ self .emit_class (
1297
+ cons .name ,
1298
+ f"{ parent } { cons .name } " ,
1299
+ simple = False ,
1300
+ base = parent ,
1301
+ )
1302
+ if self .borrow :
1303
+ self .emit_getter (cons , f"{ parent } { cons .name } " )
1304
+
1305
+
1306
+ class Pyo3PymoduleVisitor (EmitVisitor ):
1307
+ def __init__ (self , namespace , * args , ** kw ):
1308
+ self .namespace = namespace
1309
+ super ().__init__ (* args , ** kw )
1310
+
1311
+ def visitModule (self , mod ):
1312
+ for dfn in mod .dfns :
1313
+ self .visit (dfn )
1314
+
1315
+ def visitType (self , type , depth = 0 ):
1316
+ self .visit (type .value , type .name , depth )
1317
+
1318
+ def visitProduct (self , product , name , depth = 0 ):
1319
+ rust_name = rust_type_name (name )
1320
+ self .emit_fields (name , rust_name , False )
1321
+
1322
+ def visitSum (self , sum , name , depth ):
1323
+ rust_name = rust_type_name (name )
1324
+ simple = is_simple (sum )
1325
+ self .emit_fields (name , rust_name , True )
1326
+
1327
+ for cons in sum .types :
1328
+ self .visit (cons , name , simple , depth )
1329
+
1330
+ def visitConstructor (self , cons , parent , simple , depth ):
1331
+ rust_name = rust_type_name (parent ) + rust_type_name (cons .name )
1332
+ self .emit_fields (cons .name , rust_name , simple )
1333
+
1334
+ def emit_fields (self , name , rust_name , simple ):
1335
+ self .emit (f"super::init_type::<{ rust_name } , crate::generic::{ rust_name } >(py, ast_module, m)?;" , 1 )
1336
+
1044
1337
1045
1338
class StdlibClassDefVisitor (EmitVisitor ):
1046
1339
def visitModule (self , mod ):
@@ -1452,6 +1745,60 @@ def write_to_pyo3_simple(type_info, f):
1452
1745
)
1453
1746
1454
1747
1748
+ def write_pyo3_wrapper (mod , type_info , namespace , f ):
1749
+ Pyo3StructVisitor (namespace , f , type_info ).visit (mod )
1750
+
1751
+ if namespace == "located" :
1752
+ for type_info in type_info .values ():
1753
+ if not type_info .is_simple or not type_info .is_sum :
1754
+ continue
1755
+
1756
+ rust_name = type_info .rust_sum_name
1757
+ f .write (
1758
+ f"""
1759
+ impl ToPyo3Wrapper for crate::generic::{ rust_name } {{
1760
+ #[inline]
1761
+ fn to_pyo3_wrapper(&self, py: Python) -> PyResult<Py<PyAny>> {{
1762
+ match &self {{
1763
+ """ ,
1764
+ )
1765
+ for cons in type_info .type .value .types :
1766
+ f .write (
1767
+ f"Self::{ cons .name } => Ok({ rust_name } { cons .name } .to_object(py))," ,
1768
+ )
1769
+ f .write (
1770
+ """
1771
+ }
1772
+ }
1773
+ }
1774
+ """ ,
1775
+ )
1776
+
1777
+ for cons in type_info .type .value .types :
1778
+ f .write (
1779
+ f"""
1780
+ impl ToPyo3Wrapper for crate::generic::{ rust_name } { cons .name } {{
1781
+ #[inline]
1782
+ fn to_pyo3_wrapper(&self, py: Python) -> PyResult<Py<PyAny>> {{
1783
+ Ok({ rust_name } { cons .name } .to_object(py))
1784
+ }}
1785
+ }}
1786
+ """
1787
+ )
1788
+
1789
+ f .write (
1790
+ """
1791
+ pub fn add_to_module(py: Python, m: &PyModule) -> PyResult<()> {
1792
+ super::init_module(py, m)?;
1793
+
1794
+ let ast_module = PyModule::import(py, "_ast")?;
1795
+ """
1796
+ )
1797
+
1798
+ Pyo3PymoduleVisitor (namespace , f , type_info ).visit (mod )
1799
+ f .write ("Ok(())\n }" )
1800
+
1801
+
1455
1802
def write_ast_mod (mod , type_info , f ):
1456
1803
f .write (
1457
1804
textwrap .dedent (
@@ -1498,6 +1845,8 @@ def main(
1498
1845
("located" , p (write_located_def , mod , type_info )),
1499
1846
("visitor" , p (write_visitor_def , mod , type_info )),
1500
1847
("to_pyo3" , p (write_to_pyo3 , mod , type_info )),
1848
+ ("pyo3_wrapper_located" , p (write_pyo3_wrapper , mod , type_info , "located" )),
1849
+ ("pyo3_wrapper_ranged" , p (write_pyo3_wrapper , mod , type_info , "ranged" )),
1501
1850
]:
1502
1851
with (ast_dir / f"{ filename } .rs" ).open ("w" ) as f :
1503
1852
f .write (auto_gen_msg )
0 commit comments