@@ -937,6 +937,111 @@ def emit_located_impl(self, info):
937
937
)
938
938
939
939
940
+ class ToPyo3AstVisitor (EmitVisitor ):
941
+ """Visitor to generate type-defs for AST."""
942
+
943
+ def __init__ (self , namespace , * args , ** kw ):
944
+ super ().__init__ (* args , ** kw )
945
+ self .namespace = namespace
946
+
947
+ @property
948
+ def generics (self ):
949
+ if self .namespace == "ranged" :
950
+ return "<TextRange>"
951
+ elif self .namespace == "located" :
952
+ return "<SourceRange>"
953
+ else :
954
+ assert False , self .namespace
955
+
956
+ def visitModule (self , mod ):
957
+ for dfn in mod .dfns :
958
+ self .visit (dfn )
959
+
960
+ def visitType (self , type , depth = 0 ):
961
+ self .visit (type .value , type .name , depth )
962
+
963
+ def visitProduct (self , product , name , depth = 0 ):
964
+ rust_name = rust_type_name (name )
965
+ self .emit_to_pyo3_with_fields (product , rust_name )
966
+
967
+ def visitSum (self , sum , name , depth = 0 ):
968
+ rust_name = rust_type_name (name )
969
+ simple = is_simple (sum )
970
+ if is_simple (sum ):
971
+ return
972
+
973
+ self .emit (
974
+ f"""
975
+ impl ToPyo3Ast for crate::generic::{ rust_name } { self .generics } {{
976
+ #[inline]
977
+ fn to_pyo3_ast(&self, { "_" if simple else "" } py: Python) -> PyResult<Py<PyAny>> {{
978
+ let instance = match &self {{
979
+ """ ,
980
+ 0 ,
981
+ )
982
+ for cons in sum .types :
983
+ self .emit (
984
+ f"""crate::{ rust_name } ::{ cons .name } (cons) => cons.to_pyo3_ast(py)?,""" ,
985
+ depth ,
986
+ )
987
+ self .emit (
988
+ """
989
+ };
990
+ Ok(instance)
991
+ }
992
+ }
993
+ """ ,
994
+ 0 ,
995
+ )
996
+
997
+ for cons in sum .types :
998
+ self .visit (cons , rust_name , depth )
999
+
1000
+ def visitConstructor (self , cons , parent , depth ):
1001
+ self .emit_to_pyo3_with_fields (cons , f"{ parent } { cons .name } " )
1002
+
1003
+ def emit_to_pyo3_with_fields (self , cons , name ):
1004
+ if cons .fields :
1005
+ self .emit (
1006
+ f"""
1007
+ impl ToPyo3Ast for crate::{ name } { self .generics } {{
1008
+ #[inline]
1009
+ fn to_pyo3_ast(&self, py: Python) -> PyResult<Py<PyAny>> {{
1010
+ let cache = Self::py_type_cache().get().unwrap();
1011
+ let instance = cache.0.call1(py, (
1012
+ """ ,
1013
+ 0 ,
1014
+ )
1015
+ for field in cons .fields :
1016
+ self .emit (
1017
+ f"self.{ rust_field (field .name )} .to_pyo3_ast(py)?," ,
1018
+ 3 ,
1019
+ )
1020
+ self .emit (
1021
+ """
1022
+ ))?;
1023
+ Ok(instance)
1024
+ }
1025
+ }
1026
+ """ ,
1027
+ 0 ,
1028
+ )
1029
+ else :
1030
+ self .emit (
1031
+ f"""
1032
+ impl ToPyo3Ast for crate::{ name } { self .generics } {{
1033
+ #[inline]
1034
+ fn to_pyo3_ast(&self, py: Python) -> PyResult<Py<PyAny>> {{
1035
+ let cache = Self::py_type_cache().get().unwrap();
1036
+ let instance = cache.0.call0(py)?;
1037
+ Ok(instance)
1038
+ }}
1039
+ }}
1040
+ """ ,
1041
+ 0 ,
1042
+ )
1043
+
1044
+
940
1045
class StdlibClassDefVisitor (EmitVisitor ):
941
1046
def visitModule (self , mod ):
942
1047
for dfn in mod .dfns :
@@ -1271,6 +1376,82 @@ def write_located_def(mod, type_info, f):
1271
1376
LocatedDefVisitor (f , type_info ).visit (mod )
1272
1377
1273
1378
1379
+ def write_pyo3_node (type_info , f ):
1380
+ def write (info : TypeInfo ):
1381
+ rust_name = info .rust_sum_name
1382
+ if info .is_simple :
1383
+ generics = ""
1384
+ else :
1385
+ generics = "<R>"
1386
+
1387
+ f .write (
1388
+ textwrap .dedent (
1389
+ f"""
1390
+ impl{ generics } Pyo3Node for crate::generic::{ rust_name } { generics } {{
1391
+ #[inline]
1392
+ fn py_type_cache() -> &'static OnceCell<(Py<PyAny>, Py<PyAny>)> {{
1393
+ static PY_TYPE: OnceCell<(Py<PyAny>, Py<PyAny>)> = OnceCell::new();
1394
+ &PY_TYPE
1395
+ }}
1396
+ }}
1397
+ """
1398
+ ),
1399
+ )
1400
+
1401
+ for info in type_info .values ():
1402
+ write (info )
1403
+
1404
+
1405
+ def write_to_pyo3 (mod , type_info , f ):
1406
+ write_pyo3_node (type_info , f )
1407
+ write_to_pyo3_simple (type_info , f )
1408
+
1409
+ for namespace in ("ranged" , "located" ):
1410
+ ToPyo3AstVisitor (namespace , f , type_info ).visit (mod )
1411
+
1412
+ f .write (
1413
+ """
1414
+ pub fn init(py: Python) -> PyResult<()> {
1415
+ let ast_module = PyModule::import(py, "_ast")?;
1416
+ """
1417
+ )
1418
+
1419
+ for info in type_info .values ():
1420
+ rust_name = info .rust_sum_name
1421
+ f .write (f"cache_py_type::<crate::generic::{ rust_name } >(ast_module)?;\n " )
1422
+ f .write ("Ok(())\n }" )
1423
+
1424
+
1425
+ def write_to_pyo3_simple (type_info , f ):
1426
+ for type_info in type_info .values ():
1427
+ if not type_info .is_sum :
1428
+ continue
1429
+ if not type_info .is_simple :
1430
+ continue
1431
+
1432
+ rust_name = type_info .rust_sum_name
1433
+ f .write (
1434
+ f"""
1435
+ impl ToPyo3Ast for crate::generic::{ rust_name } {{
1436
+ #[inline]
1437
+ fn to_pyo3_ast(&self, _py: Python) -> PyResult<Py<PyAny>> {{
1438
+ let cell = match &self {{
1439
+ """ ,
1440
+ )
1441
+ for cons in type_info .type .value .types :
1442
+ f .write (
1443
+ f"""crate::{ rust_name } ::{ cons .name } => crate::{ rust_name } { cons .name } ::py_type_cache(),""" ,
1444
+ )
1445
+ f .write (
1446
+ """
1447
+ };
1448
+ Ok(cell.get().unwrap().1.clone())
1449
+ }
1450
+ }
1451
+ """ ,
1452
+ )
1453
+
1454
+
1274
1455
def write_ast_mod (mod , type_info , f ):
1275
1456
f .write (
1276
1457
textwrap .dedent (
@@ -1316,6 +1497,7 @@ def main(
1316
1497
("ranged" , p (write_ranged_def , mod , type_info )),
1317
1498
("located" , p (write_located_def , mod , type_info )),
1318
1499
("visitor" , p (write_visitor_def , mod , type_info )),
1500
+ ("to_pyo3" , p (write_to_pyo3 , mod , type_info )),
1319
1501
]:
1320
1502
with (ast_dir / f"{ filename } .rs" ).open ("w" ) as f :
1321
1503
f .write (auto_gen_msg )
0 commit comments