@@ -546,17 +546,32 @@ def visitType(self, type, depth):
546
546
self .emit (f"fold_{ name } (self, node)" , depth + 1 )
547
547
self .emit ("}" , depth )
548
548
549
+ if isinstance (type .value , asdl .Sum ) and not is_simple (type .value ):
550
+ for cons in type .value .types :
551
+ self .visit (cons , type , depth )
552
+
553
+ def visitConstructor (self , cons , type , depth ):
554
+ apply_u , apply_target_u = self .apply_generics (type .name , "U" , "Self::TargetU" )
555
+ enum_name = rust_type_name (type .name )
556
+ func_name = f"fold_{ type .name } _{ rust_field_name (cons .name )} "
557
+ self .emit (
558
+ f"fn { func_name } (&mut self, node: { enum_name } { cons .name } { apply_u } ) -> Result<{ enum_name } { cons .name } { apply_target_u } , Self::Error> {{" ,
559
+ depth ,
560
+ )
561
+ self .emit (f"{ func_name } (self, node)" , depth + 1 )
562
+ self .emit ("}" , depth )
563
+
549
564
550
565
class FoldImplVisitor (EmitVisitor ):
551
566
def visitModule (self , mod , depth ):
552
567
for dfn in mod .dfns :
553
568
self .visit (dfn , depth )
554
569
555
570
def visitType (self , type , depth = 0 ):
556
- self .visit (type .value , type . name , depth )
571
+ self .visit (type .value , type , depth )
557
572
558
- def visitSum (self , sum , name , depth ):
559
- type_info = self . type_info [ name ]
573
+ def visitSum (self , sum , type , depth ):
574
+ name = type . name
560
575
apply_t , apply_u , apply_target_u = self .apply_generics (
561
576
name , "T" , "U" , "F::TargetU"
562
577
)
@@ -582,34 +597,69 @@ def visitSum(self, sum, name, depth):
582
597
self .emit ("Ok(node) }" , depth + 1 )
583
598
return
584
599
585
- self .emit ("match node {" , depth + 1 )
600
+ self .emit ("let folded = match node {" , depth + 1 )
586
601
for cons in sum .types :
587
- fields_pattern = self .make_pattern (enum_name , cons .name , cons .fields )
588
602
self .emit (
589
- f"{ fields_pattern [ 0 ] } {{ { fields_pattern [ 1 ] } }} { fields_pattern [ 2 ] } => {{ " ,
590
- depth + 2 ,
603
+ f"{ enum_name } :: { cons . name } (cons) => { enum_name } :: { cons . name } (Foldable::fold(cons, folder)?), " ,
604
+ depth + 1 ,
591
605
)
592
606
593
- map_user_suffix = "" if type_info .has_attributes else "_cfg"
594
- self .emit (
595
- f"let context = folder.will_map_user{ map_user_suffix } (&range);" ,
596
- depth + 3 ,
597
- )
598
- self .fold_fields (
599
- fields_pattern [0 ], cons .fields , fields_pattern [2 ], depth + 3
600
- )
601
- self .emit (
602
- f"let range = folder.map_user{ map_user_suffix } (range, context)?;" ,
603
- depth + 3 ,
604
- )
605
- self .composite_fields (
606
- fields_pattern [0 ], cons .fields , fields_pattern [2 ], depth + 3
607
- )
608
- self .emit ("}" , depth + 2 )
607
+ self .emit ("};" , depth + 1 )
608
+ self .emit ("Ok(folded)" , depth + 1 )
609
+ self .emit ("}" , depth )
610
+
611
+ for cons in sum .types :
612
+ self .visit (cons , type , depth )
613
+
614
+ def visitConstructor (self , cons , type , depth ):
615
+ apply_t , apply_u , apply_target_u = self .apply_generics (
616
+ type .name , "T" , "U" , "F::TargetU"
617
+ )
618
+ enum_name = rust_type_name (type .name )
619
+
620
+ cons_type_name = f"{ enum_name } { cons .name } "
621
+
622
+ self .emit (
623
+ f"impl<T, U> Foldable<T, U> for { cons_type_name } { apply_t } {{" , depth
624
+ )
625
+ self .emit (f"type Mapped = { cons_type_name } { apply_u } ;" , depth + 1 )
626
+ self .emit (
627
+ "fn fold<F: Fold<T, TargetU = U> + ?Sized>(self, folder: &mut F) -> Result<Self::Mapped, F::Error> {" ,
628
+ depth + 1 ,
629
+ )
630
+ self .emit (
631
+ f"folder.fold_{ type .name } _{ rust_field_name (cons .name )} (self)" , depth + 2
632
+ )
609
633
self .emit ("}" , depth + 1 )
610
634
self .emit ("}" , depth )
611
635
612
- def visitProduct (self , product , name , depth ):
636
+ self .emit (
637
+ f"pub fn fold_{ type .name } _{ rust_field_name (cons .name )} <U, F: Fold<U> + ?Sized>(#[allow(unused)] folder: &mut F, node: { cons_type_name } { apply_u } ) -> Result<{ enum_name } { cons .name } { apply_target_u } , F::Error> {{" ,
638
+ depth ,
639
+ )
640
+
641
+ type_info = self .type_info [type .name ]
642
+
643
+ fields_pattern = self .make_pattern (cons .fields )
644
+
645
+ map_user_suffix = "" if type_info .has_attributes else "_cfg"
646
+ self .emit (
647
+ f"""
648
+ let { cons_type_name } {{ { fields_pattern } }} = node;
649
+ let context = folder.will_map_user{ map_user_suffix } (&range);
650
+ """ ,
651
+ depth + 3 ,
652
+ )
653
+ self .fold_fields (cons .fields , depth + 3 )
654
+ self .emit (
655
+ f"let range = folder.map_user{ map_user_suffix } (range, context)?;" ,
656
+ depth + 3 ,
657
+ )
658
+ self .composite_fields (f"{ cons_type_name } " , cons .fields , depth + 3 )
659
+ self .emit ("}" , depth + 2 )
660
+
661
+ def visitProduct (self , product , type , depth ):
662
+ name = type .name
613
663
apply_t , apply_u , apply_target_u = self .apply_generics (
614
664
name , "T" , "U" , "F::TargetU"
615
665
)
@@ -631,45 +681,42 @@ def visitProduct(self, product, name, depth):
631
681
depth ,
632
682
)
633
683
634
- fields_pattern = self .make_pattern (struct_name , struct_name , product .fields )
635
- self .emit (f"let { struct_name } {{ { fields_pattern [ 1 ] } }} = node;" , depth + 1 )
684
+ fields_pattern = self .make_pattern (product .fields )
685
+ self .emit (f"let { struct_name } {{ { fields_pattern } }} = node;" , depth + 1 )
636
686
637
687
map_user_suffix = "" if has_attributes else "_cfg"
638
688
639
689
self .emit (
640
690
f"let context = folder.will_map_user{ map_user_suffix } (&range);" , depth + 3
641
691
)
642
- self .fold_fields (struct_name , product .fields , "" , depth + 1 )
692
+ self .fold_fields (product .fields , depth + 1 )
643
693
self .emit (
644
694
f"let range = folder.map_user{ map_user_suffix } (range, context)?;" , depth + 3
645
695
)
646
- self .composite_fields (struct_name , product .fields , "" , depth + 1 )
696
+ self .composite_fields (struct_name , product .fields , depth + 1 )
647
697
648
698
self .emit ("}" , depth )
649
699
650
- def make_pattern (self , rust_name , fieldname : str , fields ):
651
- header = f"{ rust_name } ::{ fieldname } ({ rust_name } { fieldname } "
652
- footer = ")"
653
-
700
+ def make_pattern (self , fields ):
654
701
body = "," .join (rust_field (f .name ) for f in fields )
655
702
if body :
656
703
body += ","
657
704
body += "range"
658
705
659
- return header , body , footer
706
+ return body
660
707
661
- def fold_fields (self , header , fields , footer , depth ):
708
+ def fold_fields (self , fields , depth ):
662
709
for field in fields :
663
710
name = rust_field (field .name )
664
711
self .emit (f"let { name } = Foldable::fold({ name } , folder)?;" , depth + 1 )
665
712
666
- def composite_fields (self , header , fields , footer , depth ):
713
+ def composite_fields (self , header , fields , depth ):
667
714
self .emit (f"Ok({ header } {{" , depth )
668
715
for field in fields :
669
716
name = rust_field (field .name )
670
717
self .emit (f"{ name } ," , depth + 1 )
671
718
self .emit ("range," , depth + 1 )
672
- self .emit (f"}}{ footer } )" , depth )
719
+ self .emit (f"}})" , depth )
673
720
674
721
675
722
class FoldModuleVisitor (EmitVisitor ):
0 commit comments