@@ -496,15 +496,29 @@ def visitModule(self, mod, depth):
496
496
self .emit ("pub trait Fold<U> {" , depth )
497
497
self .emit ("type TargetU;" , depth + 1 )
498
498
self .emit ("type Error;" , depth + 1 )
499
+ self .emit ("type UserContext;" , depth + 1 )
499
500
self .emit (
500
501
"""
501
- fn map_user (&mut self, user: U) -> Result< Self::TargetU, Self::Error> ;
502
+ fn will_map_user (&mut self, user: & U) -> Self::UserContext ;
502
503
#[cfg(feature = "all-nodes-with-ranges")]
503
- fn map_user_cfg (&mut self, user: U) -> Result< Self::TargetU, Self::Error> {
504
- self.map_user (user)
504
+ fn will_map_user_cfg (&mut self, user: & U) -> Self::UserContext {
505
+ self.will_map_user (user)
505
506
}
506
507
#[cfg(not(feature = "all-nodes-with-ranges"))]
507
- fn map_user_cfg(&mut self, _user: crate::EmptyRange<U>) -> Result<crate::EmptyRange<Self::TargetU>, Self::Error> {
508
+ fn will_map_user_cfg(&mut self, user: &crate::EmptyRange<U>) -> crate::EmptyRange<Self::TargetU> {
509
+ crate::EmptyRange::default()
510
+ }
511
+ fn map_user(&mut self, user: U, context: Self::UserContext) -> Result<Self::TargetU, Self::Error>;
512
+ #[cfg(feature = "all-nodes-with-ranges")]
513
+ fn map_user_cfg(&mut self, user: U, context: Self::UserContext) -> Result<Self::TargetU, Self::Error> {
514
+ self.map_user(user, context)
515
+ }
516
+ #[cfg(not(feature = "all-nodes-with-ranges"))]
517
+ fn map_user_cfg(
518
+ &mut self,
519
+ _user: crate::EmptyRange<U>,
520
+ _context: crate::EmptyRange<Self::TargetU>,
521
+ ) -> Result<crate::EmptyRange<Self::TargetU>, Self::Error> {
508
522
Ok(crate::EmptyRange::default())
509
523
}
510
524
""" ,
@@ -532,17 +546,32 @@ def visitType(self, type, depth):
532
546
self .emit (f"fold_{ name } (self, node)" , depth + 1 )
533
547
self .emit ("}" , depth )
534
548
549
+ if isinstance (type .value , asdl .Sum ) and not is_simple (type .value ):
550
+ for cons in type .value .types :
551
+ self .visit (cons , type , depth )
552
+
553
+ def visitConstructor (self , cons , type , depth ):
554
+ apply_u , apply_target_u = self .apply_generics (type .name , "U" , "Self::TargetU" )
555
+ enum_name = rust_type_name (type .name )
556
+ func_name = f"fold_{ type .name } _{ rust_field_name (cons .name )} "
557
+ self .emit (
558
+ f"fn { func_name } (&mut self, node: { enum_name } { cons .name } { apply_u } ) -> Result<{ enum_name } { cons .name } { apply_target_u } , Self::Error> {{" ,
559
+ depth ,
560
+ )
561
+ self .emit (f"{ func_name } (self, node)" , depth + 1 )
562
+ self .emit ("}" , depth )
563
+
535
564
536
565
class FoldImplVisitor (EmitVisitor ):
537
566
def visitModule (self , mod , depth ):
538
567
for dfn in mod .dfns :
539
568
self .visit (dfn , depth )
540
569
541
570
def visitType (self , type , depth = 0 ):
542
- self .visit (type .value , type . name , depth )
571
+ self .visit (type .value , type , depth )
543
572
544
- def visitSum (self , sum , name , depth ):
545
- type_info = self . type_info [ name ]
573
+ def visitSum (self , sum , type , depth ):
574
+ name = type . name
546
575
apply_t , apply_u , apply_target_u = self .apply_generics (
547
576
name , "T" , "U" , "F::TargetU"
548
577
)
@@ -568,27 +597,69 @@ def visitSum(self, sum, name, depth):
568
597
self .emit ("Ok(node) }" , depth + 1 )
569
598
return
570
599
571
- self .emit ("match node {" , depth + 1 )
600
+ self .emit ("let folded = match node {" , depth + 1 )
572
601
for cons in sum .types :
573
- fields_pattern = self .make_pattern (enum_name , cons .name , cons .fields )
574
602
self .emit (
575
- f"{ fields_pattern [ 0 ] } {{ { fields_pattern [ 1 ] } }} { fields_pattern [ 2 ] } => {{ " ,
576
- depth + 2 ,
603
+ f"{ enum_name } :: { cons . name } (cons) => { enum_name } :: { cons . name } (Foldable::fold(cons, folder)?), " ,
604
+ depth + 1 ,
577
605
)
578
606
579
- map_user_suffix = "" if type_info .has_attributes else "_cfg"
580
- self .emit (
581
- f"let range = folder.map_user{ map_user_suffix } (range)?;" , depth + 3
582
- )
607
+ self .emit ("};" , depth + 1 )
608
+ self .emit ("Ok(folded)" , depth + 1 )
609
+ self .emit ("}" , depth )
583
610
584
- self .gen_construction (
585
- fields_pattern [0 ], cons .fields , fields_pattern [2 ], depth + 3
586
- )
587
- self .emit ("}" , depth + 2 )
611
+ for cons in sum .types :
612
+ self .visit (cons , type , depth )
613
+
614
+ def visitConstructor (self , cons , type , depth ):
615
+ apply_t , apply_u , apply_target_u = self .apply_generics (
616
+ type .name , "T" , "U" , "F::TargetU"
617
+ )
618
+ enum_name = rust_type_name (type .name )
619
+
620
+ cons_type_name = f"{ enum_name } { cons .name } "
621
+
622
+ self .emit (
623
+ f"impl<T, U> Foldable<T, U> for { cons_type_name } { apply_t } {{" , depth
624
+ )
625
+ self .emit (f"type Mapped = { cons_type_name } { apply_u } ;" , depth + 1 )
626
+ self .emit (
627
+ "fn fold<F: Fold<T, TargetU = U> + ?Sized>(self, folder: &mut F) -> Result<Self::Mapped, F::Error> {" ,
628
+ depth + 1 ,
629
+ )
630
+ self .emit (
631
+ f"folder.fold_{ type .name } _{ rust_field_name (cons .name )} (self)" , depth + 2
632
+ )
588
633
self .emit ("}" , depth + 1 )
589
634
self .emit ("}" , depth )
590
635
591
- def visitProduct (self , product , name , depth ):
636
+ self .emit (
637
+ f"pub fn fold_{ type .name } _{ rust_field_name (cons .name )} <U, F: Fold<U> + ?Sized>(#[allow(unused)] folder: &mut F, node: { cons_type_name } { apply_u } ) -> Result<{ enum_name } { cons .name } { apply_target_u } , F::Error> {{" ,
638
+ depth ,
639
+ )
640
+
641
+ type_info = self .type_info [type .name ]
642
+
643
+ fields_pattern = self .make_pattern (cons .fields )
644
+
645
+ map_user_suffix = "" if type_info .has_attributes else "_cfg"
646
+ self .emit (
647
+ f"""
648
+ let { cons_type_name } {{ { fields_pattern } }} = node;
649
+ let context = folder.will_map_user{ map_user_suffix } (&range);
650
+ """ ,
651
+ depth + 3 ,
652
+ )
653
+ self .fold_fields (cons .fields , depth + 3 )
654
+ self .emit (
655
+ f"let range = folder.map_user{ map_user_suffix } (range, context)?;" ,
656
+ depth + 3 ,
657
+ )
658
+ self .composite_fields (f"{ cons_type_name } " , cons .fields , depth + 3 )
659
+ self .emit ("}" , depth + 2 )
660
+
661
+ def visitProduct (self , product , type , depth ):
662
+ name = type .name
592
663
apply_t , apply_u , apply_target_u = self .apply_generics (
593
664
name , "T" , "U" , "F::TargetU"
594
665
)
@@ -610,41 +681,47 @@ def visitProduct(self, product, name, depth):
610
681
depth ,
611
682
)
612
683
613
- fields_pattern = self .make_pattern (struct_name , struct_name , product .fields )
614
- self .emit (f"let { struct_name } {{ { fields_pattern [ 1 ] } }} = node;" , depth + 1 )
684
+ fields_pattern = self .make_pattern (product .fields )
685
+ self .emit (f"let { struct_name } {{ { fields_pattern } }} = node;" , depth + 1 )
615
686
616
687
map_user_suffix = "" if has_attributes else "_cfg"
617
- self .emit (f"let range = folder.map_user{ map_user_suffix } (range)?;" , depth + 3 )
618
688
619
- self .gen_construction (struct_name , product .fields , "" , depth + 1 )
689
+ self .emit (
690
+ f"let context = folder.will_map_user{ map_user_suffix } (&range);" , depth + 3
691
+ )
692
+ self .fold_fields (product .fields , depth + 1 )
693
+ self .emit (
694
+ f"let range = folder.map_user{ map_user_suffix } (range, context)?;" , depth + 3
695
+ )
696
+ self .composite_fields (struct_name , product .fields , depth + 1 )
620
697
621
698
self .emit ("}" , depth )
622
699
623
- def make_pattern (self , rust_name , fieldname : str , fields ):
624
- header = f"{ rust_name } ::{ fieldname } ({ rust_name } { fieldname } "
625
- footer = ")"
626
-
700
+ def make_pattern (self , fields ):
627
701
body = "," .join (rust_field (f .name ) for f in fields )
628
702
if body :
629
703
body += ","
630
704
body += "range"
631
705
632
- return header , body , footer
706
+ return body
633
707
634
- def gen_construction (self , header , fields , footer , depth ):
708
+ def fold_fields (self , fields , depth ):
709
+ for field in fields :
710
+ name = rust_field (field .name )
711
+ self .emit (f"let { name } = Foldable::fold({ name } , folder)?;" , depth + 1 )
712
+
713
+ def composite_fields (self , header , fields , depth ):
635
714
self .emit (f"Ok({ header } {{" , depth )
636
715
for field in fields :
637
716
name = rust_field (field .name )
638
- self .emit (f"{ name } : Foldable::fold( { name } , folder)? ," , depth + 1 )
717
+ self .emit (f"{ name } ," , depth + 1 )
639
718
self .emit ("range," , depth + 1 )
640
-
641
- self .emit (f"}}{ footer } )" , depth )
719
+ self .emit (f"}})" , depth )
642
720
643
721
644
722
class FoldModuleVisitor (EmitVisitor ):
645
723
def visitModule (self , mod ):
646
724
depth = 0
647
- self .emit ("use crate::fold_helpers::Foldable;" , depth )
648
725
FoldTraitDefVisitor (self .file , self .type_info ).visit (mod , depth )
649
726
FoldImplVisitor (self .file , self .type_info ).visit (mod , depth )
650
727
0 commit comments