@@ -1042,6 +1042,302 @@ def emit_to_pyo3_with_fields(self, cons, name):
1042
1042
)
1043
1043
1044
1044
1045
+ class Pyo3StructVisitor (EmitVisitor ):
1046
+ """Visitor to generate type-defs for AST."""
1047
+
1048
+ def __init__ (self , namespace , * args , ** kw ):
1049
+ self .namespace = namespace
1050
+ self .borrow = True
1051
+ super ().__init__ (* args , ** kw )
1052
+
1053
+ @property
1054
+ def generics (self ):
1055
+ if self .namespace == "ranged" :
1056
+ return "<TextRange>"
1057
+ elif self .namespace == "located" :
1058
+ return "<SourceRange>"
1059
+ else :
1060
+ assert False , self .namespace
1061
+
1062
+ @property
1063
+ def module_name (self ):
1064
+ name = f"rustpython_ast.{ self .namespace } "
1065
+ return name
1066
+
1067
+ @property
1068
+ def ref_def (self ):
1069
+ return "&'static " if self .borrow else ""
1070
+
1071
+ @property
1072
+ def ref (self ):
1073
+ return "&" if self .borrow else ""
1074
+
1075
+ def emit_class (self , name , rust_name , simple , base = "super::AST" ):
1076
+ info = self .type_info [name ]
1077
+ if simple :
1078
+ generics = ""
1079
+ else :
1080
+ generics = self .generics
1081
+ if info .is_sum :
1082
+ subclass = ", subclass"
1083
+ body = ""
1084
+ into = f"{ rust_name } "
1085
+ else :
1086
+ subclass = ""
1087
+ body = f"(pub { self .ref_def } crate::{ rust_name } { generics } )"
1088
+ into = f"{ rust_name } (node)"
1089
+
1090
+ self .emit (
1091
+ textwrap .dedent (
1092
+ f"""
1093
+ #[pyclass(module="{ self .module_name } ", name="_{ name } ", extends={ base } , frozen{ subclass } )]
1094
+ #[derive(Clone, Debug)]
1095
+ pub struct { rust_name } { body } ;
1096
+
1097
+ impl From<{ self .ref_def } crate::{ rust_name } { generics } > for { rust_name } {{
1098
+ fn from({ "" if body else "_" } node: { self .ref_def } crate::{ rust_name } { generics } ) -> Self {{
1099
+ { into }
1100
+ }}
1101
+ }}
1102
+ """
1103
+ ),
1104
+ 0 ,
1105
+ )
1106
+ if subclass :
1107
+ self .emit (
1108
+ textwrap .dedent (
1109
+ f"""
1110
+ #[pymethods]
1111
+ impl { rust_name } {{
1112
+ #[new]
1113
+ fn new() -> PyClassInitializer<Self> {{
1114
+ PyClassInitializer::from(AST)
1115
+ .add_subclass(Self)
1116
+ }}
1117
+
1118
+ }}
1119
+ impl ToPyObject for { rust_name } {{
1120
+ fn to_object(&self, py: Python) -> PyObject {{
1121
+ let initializer = PyClassInitializer::from(AST)
1122
+ .add_subclass(self.clone());
1123
+ Py::new(py, initializer).unwrap().into_py(py)
1124
+ }}
1125
+ }}
1126
+ """
1127
+ ),
1128
+ 0 ,
1129
+ )
1130
+ else :
1131
+ if base != "super::AST" :
1132
+ add_subclass = f".add_subclass({ base } )"
1133
+ else :
1134
+ add_subclass = ""
1135
+ self .emit (
1136
+ textwrap .dedent (
1137
+ f"""
1138
+ impl ToPyObject for { rust_name } {{
1139
+ fn to_object(&self, py: Python) -> PyObject {{
1140
+ let initializer = PyClassInitializer::from(AST)
1141
+ { add_subclass }
1142
+ .add_subclass(self.clone());
1143
+ Py::new(py, initializer).unwrap().into_py(py)
1144
+ }}
1145
+ }}
1146
+ """
1147
+ ),
1148
+ 0 ,
1149
+ )
1150
+
1151
+ if not subclass :
1152
+ self .emit_wrapper (rust_name )
1153
+
1154
+ def emit_getter (self , owner , type_name ):
1155
+ self .emit (
1156
+ textwrap .dedent (
1157
+ f"""
1158
+ #[pymethods]
1159
+ impl { type_name } {{
1160
+ """
1161
+ ),
1162
+ 0 ,
1163
+ )
1164
+
1165
+ for field in owner .fields :
1166
+ self .emit (
1167
+ textwrap .dedent (
1168
+ f"""
1169
+ #[getter]
1170
+ #[inline]
1171
+ fn get_{ field .name } (&self, py: Python) -> PyResult<PyObject> {{
1172
+ self.0.{ rust_field (field .name )} .to_pyo3_wrapper(py)
1173
+ }}
1174
+ """
1175
+ ),
1176
+ 3 ,
1177
+ )
1178
+
1179
+ self .emit (
1180
+ textwrap .dedent (
1181
+ """
1182
+ }
1183
+ """
1184
+ ),
1185
+ 0 ,
1186
+ )
1187
+
1188
+ def emit_getattr (self , owner , type_name ):
1189
+ self .emit (
1190
+ textwrap .dedent (
1191
+ f"""
1192
+ #[pymethods]
1193
+ impl { type_name } {{
1194
+ fn __getattr__(&self, py: Python, key: &str) -> PyResult<PyObject> {{
1195
+ let object: Py<PyAny> = match key {{
1196
+ """
1197
+ ),
1198
+ 0 ,
1199
+ )
1200
+
1201
+ for field in owner .fields :
1202
+ self .emit (
1203
+ f'"{ field .name } " => self.0.{ rust_field (field .name )} .to_pyo3_wrapper(py)?,' ,
1204
+ 3 ,
1205
+ )
1206
+
1207
+ self .emit (
1208
+ textwrap .dedent (
1209
+ """
1210
+ _ => todo!(),
1211
+ };
1212
+ Ok(object)
1213
+ }
1214
+ }
1215
+ """
1216
+ ),
1217
+ 0 ,
1218
+ )
1219
+
1220
+ def emit_wrapper (self , rust_name ):
1221
+ self .emit (
1222
+ f"""
1223
+ impl ToPyo3Wrapper for crate::{ rust_name } { self .generics } {{
1224
+ #[inline]
1225
+ fn to_pyo3_wrapper(&'static self, py: Python) -> PyResult<Py<PyAny>> {{
1226
+ Ok({ rust_name } (self).to_object(py))
1227
+ }}
1228
+ }}
1229
+ """ ,
1230
+ 0 ,
1231
+ )
1232
+
1233
+ def visitModule (self , mod ):
1234
+ for dfn in mod .dfns :
1235
+ self .visit (dfn )
1236
+
1237
+ def visitType (self , type , depth = 0 ):
1238
+ self .visit (type .value , type , depth )
1239
+
1240
+ def visitSum (self , sum , type , depth = 0 ):
1241
+ rust_name = rust_type_name (type .name )
1242
+
1243
+ simple = is_simple (sum )
1244
+ self .emit_class (type .name , rust_name , simple )
1245
+
1246
+ if not simple :
1247
+ self .emit (
1248
+ f"""
1249
+ impl ToPyo3Wrapper for crate::{ rust_name } { self .generics } {{
1250
+ #[inline]
1251
+ fn to_pyo3_wrapper(&'static self, py: Python) -> PyResult<Py<PyAny>> {{
1252
+ match &self {{
1253
+ """ ,
1254
+ 0 ,
1255
+ )
1256
+
1257
+ for cons in sum .types :
1258
+ self .emit (f"Self::{ cons .name } (cons) => cons.to_pyo3_wrapper(py)," , 3 )
1259
+
1260
+ self .emit (
1261
+ """
1262
+ }
1263
+ }
1264
+ }
1265
+ """ ,
1266
+ 0 ,
1267
+ )
1268
+
1269
+ for cons in sum .types :
1270
+ self .visit (cons , rust_name , simple , depth + 1 )
1271
+
1272
+ def visitProduct (self , product , type , depth = 0 ):
1273
+ rust_name = rust_type_name (type .name )
1274
+ self .emit_class (type .name , rust_name , False )
1275
+ if self .borrow :
1276
+ self .emit_getter (product , rust_name )
1277
+
1278
+ def visitConstructor (self , cons , parent , simple , depth ):
1279
+ if simple :
1280
+ self .emit (
1281
+ f"""
1282
+ #[pyclass(module="{ self .module_name } ", name="_{ cons .name } ", extends={ parent } )]
1283
+ pub struct { parent } { cons .name } ;
1284
+
1285
+ impl ToPyObject for { parent } { cons .name } {{
1286
+ fn to_object(&self, py: Python) -> PyObject {{
1287
+ let initializer = PyClassInitializer::from(AST)
1288
+ .add_subclass({ parent } )
1289
+ .add_subclass(Self);
1290
+ Py::new(py, initializer).unwrap().into_py(py)
1291
+ }}
1292
+ }}
1293
+ """ ,
1294
+ depth ,
1295
+ )
1296
+ else :
1297
+ self .emit_class (
1298
+ cons .name ,
1299
+ f"{ parent } { cons .name } " ,
1300
+ simple = False ,
1301
+ base = parent ,
1302
+ )
1303
+ if self .borrow :
1304
+ self .emit_getter (cons , f"{ parent } { cons .name } " )
1305
+
1306
+
1307
+ class Pyo3PymoduleVisitor (EmitVisitor ):
1308
+ def __init__ (self , namespace , * args , ** kw ):
1309
+ self .namespace = namespace
1310
+ super ().__init__ (* args , ** kw )
1311
+
1312
+ def visitModule (self , mod ):
1313
+ for dfn in mod .dfns :
1314
+ self .visit (dfn )
1315
+
1316
+ def visitType (self , type , depth = 0 ):
1317
+ self .visit (type .value , type .name , depth )
1318
+
1319
+ def visitProduct (self , product , name , depth = 0 ):
1320
+ rust_name = rust_type_name (name )
1321
+ self .emit_fields (name , rust_name , False )
1322
+
1323
+ def visitSum (self , sum , name , depth ):
1324
+ rust_name = rust_type_name (name )
1325
+ simple = is_simple (sum )
1326
+ self .emit_fields (name , rust_name , True )
1327
+
1328
+ for cons in sum .types :
1329
+ self .visit (cons , name , simple , depth )
1330
+
1331
+ def visitConstructor (self , cons , parent , simple , depth ):
1332
+ rust_name = rust_type_name (parent ) + rust_type_name (cons .name )
1333
+ self .emit_fields (cons .name , rust_name , simple )
1334
+
1335
+ def emit_fields (self , name , rust_name , simple ):
1336
+ self .emit (
1337
+ f"super::init_type::<{ rust_name } , crate::generic::{ rust_name } >(py, m)?;" , 1
1338
+ )
1339
+
1340
+
1045
1341
class StdlibClassDefVisitor (EmitVisitor ):
1046
1342
def visitModule (self , mod ):
1047
1343
for dfn in mod .dfns :
@@ -1452,6 +1748,58 @@ def write_to_pyo3_simple(type_info, f):
1452
1748
)
1453
1749
1454
1750
1751
+ def write_pyo3_wrapper (mod , type_info , namespace , f ):
1752
+ Pyo3StructVisitor (namespace , f , type_info ).visit (mod )
1753
+
1754
+ if namespace == "located" :
1755
+ for type_info in type_info .values ():
1756
+ if not type_info .is_simple or not type_info .is_sum :
1757
+ continue
1758
+
1759
+ rust_name = type_info .rust_sum_name
1760
+ f .write (
1761
+ f"""
1762
+ impl ToPyo3Wrapper for crate::generic::{ rust_name } {{
1763
+ #[inline]
1764
+ fn to_pyo3_wrapper(&self, py: Python) -> PyResult<Py<PyAny>> {{
1765
+ match &self {{
1766
+ """ ,
1767
+ )
1768
+ for cons in type_info .type .value .types :
1769
+ f .write (
1770
+ f"Self::{ cons .name } => Ok({ rust_name } { cons .name } .to_object(py))," ,
1771
+ )
1772
+ f .write (
1773
+ """
1774
+ }
1775
+ }
1776
+ }
1777
+ """ ,
1778
+ )
1779
+
1780
+ for cons in type_info .type .value .types :
1781
+ f .write (
1782
+ f"""
1783
+ impl ToPyo3Wrapper for crate::generic::{ rust_name } { cons .name } {{
1784
+ #[inline]
1785
+ fn to_pyo3_wrapper(&self, py: Python) -> PyResult<Py<PyAny>> {{
1786
+ Ok({ rust_name } { cons .name } .to_object(py))
1787
+ }}
1788
+ }}
1789
+ """
1790
+ )
1791
+
1792
+ f .write (
1793
+ """
1794
+ pub fn add_to_module(py: Python, m: &PyModule) -> PyResult<()> {
1795
+ super::init_module(py, m)?;
1796
+ """
1797
+ )
1798
+
1799
+ Pyo3PymoduleVisitor (namespace , f , type_info ).visit (mod )
1800
+ f .write ("Ok(())\n }" )
1801
+
1802
+
1455
1803
def write_ast_mod (mod , type_info , f ):
1456
1804
f .write (
1457
1805
textwrap .dedent (
@@ -1498,6 +1846,8 @@ def main(
1498
1846
("located" , p (write_located_def , mod , type_info )),
1499
1847
("visitor" , p (write_visitor_def , mod , type_info )),
1500
1848
("to_pyo3" , p (write_to_pyo3 , mod , type_info )),
1849
+ ("pyo3_wrapper_located" , p (write_pyo3_wrapper , mod , type_info , "located" )),
1850
+ ("pyo3_wrapper_ranged" , p (write_pyo3_wrapper , mod , type_info , "ranged" )),
1501
1851
]:
1502
1852
with (ast_dir / f"{ filename } .rs" ).open ("w" ) as f :
1503
1853
f .write (auto_gen_msg )
0 commit comments