Skip to content

Commit b48834f

Browse files
authored
More flexible map_user and fold for new constructor nodes (#53)
* make fold.rs file * Split user_map steps * Fold for new constructor nodes
1 parent 205ee80 commit b48834f

File tree

6 files changed

+2534
-851
lines changed

6 files changed

+2534
-851
lines changed

ast/asdl_rs.py

Lines changed: 111 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -496,15 +496,29 @@ def visitModule(self, mod, depth):
496496
self.emit("pub trait Fold<U> {", depth)
497497
self.emit("type TargetU;", depth + 1)
498498
self.emit("type Error;", depth + 1)
499+
self.emit("type UserContext;", depth + 1)
499500
self.emit(
500501
"""
501-
fn map_user(&mut self, user: U) -> Result<Self::TargetU, Self::Error>;
502+
fn will_map_user(&mut self, user: &U) -> Self::UserContext;
502503
#[cfg(feature = "all-nodes-with-ranges")]
503-
fn map_user_cfg(&mut self, user: U) -> Result<Self::TargetU, Self::Error> {
504-
self.map_user(user)
504+
fn will_map_user_cfg(&mut self, user: &U) -> Self::UserContext {
505+
self.will_map_user(user)
505506
}
506507
#[cfg(not(feature = "all-nodes-with-ranges"))]
507-
fn map_user_cfg(&mut self, _user: crate::EmptyRange<U>) -> Result<crate::EmptyRange<Self::TargetU>, Self::Error> {
508+
fn will_map_user_cfg(&mut self, user: &crate::EmptyRange<U>) -> crate::EmptyRange<Self::TargetU> {
509+
crate::EmptyRange::default()
510+
}
511+
fn map_user(&mut self, user: U, context: Self::UserContext) -> Result<Self::TargetU, Self::Error>;
512+
#[cfg(feature = "all-nodes-with-ranges")]
513+
fn map_user_cfg(&mut self, user: U, context: Self::UserContext) -> Result<Self::TargetU, Self::Error> {
514+
self.map_user(user, context)
515+
}
516+
#[cfg(not(feature = "all-nodes-with-ranges"))]
517+
fn map_user_cfg(
518+
&mut self,
519+
_user: crate::EmptyRange<U>,
520+
_context: crate::EmptyRange<Self::TargetU>,
521+
) -> Result<crate::EmptyRange<Self::TargetU>, Self::Error> {
508522
Ok(crate::EmptyRange::default())
509523
}
510524
""",
@@ -532,17 +546,32 @@ def visitType(self, type, depth):
532546
self.emit(f"fold_{name}(self, node)", depth + 1)
533547
self.emit("}", depth)
534548

549+
if isinstance(type.value, asdl.Sum) and not is_simple(type.value):
550+
for cons in type.value.types:
551+
self.visit(cons, type, depth)
552+
553+
def visitConstructor(self, cons, type, depth):
554+
apply_u, apply_target_u = self.apply_generics(type.name, "U", "Self::TargetU")
555+
enum_name = rust_type_name(type.name)
556+
func_name = f"fold_{type.name}_{rust_field_name(cons.name)}"
557+
self.emit(
558+
f"fn {func_name}(&mut self, node: {enum_name}{cons.name}{apply_u}) -> Result<{enum_name}{cons.name}{apply_target_u}, Self::Error> {{",
559+
depth,
560+
)
561+
self.emit(f"{func_name}(self, node)", depth + 1)
562+
self.emit("}", depth)
563+
535564

536565
class FoldImplVisitor(EmitVisitor):
537566
def visitModule(self, mod, depth):
538567
for dfn in mod.dfns:
539568
self.visit(dfn, depth)
540569

541570
def visitType(self, type, depth=0):
542-
self.visit(type.value, type.name, depth)
571+
self.visit(type.value, type, depth)
543572

544-
def visitSum(self, sum, name, depth):
545-
type_info = self.type_info[name]
573+
def visitSum(self, sum, type, depth):
574+
name = type.name
546575
apply_t, apply_u, apply_target_u = self.apply_generics(
547576
name, "T", "U", "F::TargetU"
548577
)
@@ -568,27 +597,69 @@ def visitSum(self, sum, name, depth):
568597
self.emit("Ok(node) }", depth + 1)
569598
return
570599

571-
self.emit("match node {", depth + 1)
600+
self.emit("let folded = match node {", depth + 1)
572601
for cons in sum.types:
573-
fields_pattern = self.make_pattern(enum_name, cons.name, cons.fields)
574602
self.emit(
575-
f"{fields_pattern[0]} {{ {fields_pattern[1]}}} {fields_pattern[2]} => {{",
576-
depth + 2,
603+
f"{enum_name}::{cons.name}(cons) => {enum_name}::{cons.name}(Foldable::fold(cons, folder)?),",
604+
depth + 1,
577605
)
578606

579-
map_user_suffix = "" if type_info.has_attributes else "_cfg"
580-
self.emit(
581-
f"let range = folder.map_user{map_user_suffix}(range)?;", depth + 3
582-
)
607+
self.emit("};", depth + 1)
608+
self.emit("Ok(folded)", depth + 1)
609+
self.emit("}", depth)
583610

584-
self.gen_construction(
585-
fields_pattern[0], cons.fields, fields_pattern[2], depth + 3
586-
)
587-
self.emit("}", depth + 2)
611+
for cons in sum.types:
612+
self.visit(cons, type, depth)
613+
614+
def visitConstructor(self, cons, type, depth):
615+
apply_t, apply_u, apply_target_u = self.apply_generics(
616+
type.name, "T", "U", "F::TargetU"
617+
)
618+
enum_name = rust_type_name(type.name)
619+
620+
cons_type_name = f"{enum_name}{cons.name}"
621+
622+
self.emit(
623+
f"impl<T, U> Foldable<T, U> for {cons_type_name}{apply_t} {{", depth
624+
)
625+
self.emit(f"type Mapped = {cons_type_name}{apply_u};", depth + 1)
626+
self.emit(
627+
"fn fold<F: Fold<T, TargetU = U> + ?Sized>(self, folder: &mut F) -> Result<Self::Mapped, F::Error> {",
628+
depth + 1,
629+
)
630+
self.emit(
631+
f"folder.fold_{type.name}_{rust_field_name(cons.name)}(self)", depth + 2
632+
)
588633
self.emit("}", depth + 1)
589634
self.emit("}", depth)
590635

591-
def visitProduct(self, product, name, depth):
636+
self.emit(
637+
f"pub fn fold_{type.name}_{rust_field_name(cons.name)}<U, F: Fold<U> + ?Sized>(#[allow(unused)] folder: &mut F, node: {cons_type_name}{apply_u}) -> Result<{enum_name}{cons.name}{apply_target_u}, F::Error> {{",
638+
depth,
639+
)
640+
641+
type_info = self.type_info[type.name]
642+
643+
fields_pattern = self.make_pattern(cons.fields)
644+
645+
map_user_suffix = "" if type_info.has_attributes else "_cfg"
646+
self.emit(
647+
f"""
648+
let {cons_type_name} {{ {fields_pattern} }} = node;
649+
let context = folder.will_map_user{map_user_suffix}(&range);
650+
""",
651+
depth + 3,
652+
)
653+
self.fold_fields(cons.fields, depth + 3)
654+
self.emit(
655+
f"let range = folder.map_user{map_user_suffix}(range, context)?;",
656+
depth + 3,
657+
)
658+
self.composite_fields(f"{cons_type_name}", cons.fields, depth + 3)
659+
self.emit("}", depth + 2)
660+
661+
def visitProduct(self, product, type, depth):
662+
name = type.name
592663
apply_t, apply_u, apply_target_u = self.apply_generics(
593664
name, "T", "U", "F::TargetU"
594665
)
@@ -610,41 +681,47 @@ def visitProduct(self, product, name, depth):
610681
depth,
611682
)
612683

613-
fields_pattern = self.make_pattern(struct_name, struct_name, product.fields)
614-
self.emit(f"let {struct_name} {{ {fields_pattern[1]} }} = node;", depth + 1)
684+
fields_pattern = self.make_pattern(product.fields)
685+
self.emit(f"let {struct_name} {{ {fields_pattern} }} = node;", depth + 1)
615686

616687
map_user_suffix = "" if has_attributes else "_cfg"
617-
self.emit(f"let range = folder.map_user{map_user_suffix}(range)?;", depth + 3)
618688

619-
self.gen_construction(struct_name, product.fields, "", depth + 1)
689+
self.emit(
690+
f"let context = folder.will_map_user{map_user_suffix}(&range);", depth + 3
691+
)
692+
self.fold_fields(product.fields, depth + 1)
693+
self.emit(
694+
f"let range = folder.map_user{map_user_suffix}(range, context)?;", depth + 3
695+
)
696+
self.composite_fields(struct_name, product.fields, depth + 1)
620697

621698
self.emit("}", depth)
622699

623-
def make_pattern(self, rust_name, fieldname: str, fields):
624-
header = f"{rust_name}::{fieldname}({rust_name}{fieldname}"
625-
footer = ")"
626-
700+
def make_pattern(self, fields):
627701
body = ",".join(rust_field(f.name) for f in fields)
628702
if body:
629703
body += ","
630704
body += "range"
631705

632-
return header, body, footer
706+
return body
633707

634-
def gen_construction(self, header, fields, footer, depth):
708+
def fold_fields(self, fields, depth):
709+
for field in fields:
710+
name = rust_field(field.name)
711+
self.emit(f"let {name} = Foldable::fold({name}, folder)?;", depth + 1)
712+
713+
def composite_fields(self, header, fields, depth):
635714
self.emit(f"Ok({header} {{", depth)
636715
for field in fields:
637716
name = rust_field(field.name)
638-
self.emit(f"{name}: Foldable::fold({name}, folder)?,", depth + 1)
717+
self.emit(f"{name},", depth + 1)
639718
self.emit("range,", depth + 1)
640-
641-
self.emit(f"}}{footer})", depth)
719+
self.emit(f"}})", depth)
642720

643721

644722
class FoldModuleVisitor(EmitVisitor):
645723
def visitModule(self, mod):
646724
depth = 0
647-
self.emit("use crate::fold_helpers::Foldable;", depth)
648725
FoldTraitDefVisitor(self.file, self.type_info).visit(mod, depth)
649726
FoldImplVisitor(self.file, self.type_info).visit(mod, depth)
650727

ast/src/fold_helpers.rs renamed to ast/src/fold.rs

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
use crate::{builtin, fold::Fold, ConversionFlag};
1+
use super::generic::*;
2+
3+
use crate::{builtin, ConversionFlag};
24

35
pub trait Foldable<T, U> {
46
type Mapped;
@@ -49,7 +51,7 @@ where
4951

5052
macro_rules! simple_fold {
5153
($($t:ty),+$(,)?) => {
52-
$(impl<T, U> $crate::fold_helpers::Foldable<T, U> for $t {
54+
$(impl<T, U> $crate::fold::Foldable<T, U> for $t {
5355
type Mapped = Self;
5456
#[inline]
5557
fn fold<F: Fold<T, TargetU = U> + ?Sized>(
@@ -70,3 +72,5 @@ simple_fold!(
7072
ConversionFlag,
7173
builtin::Constant
7274
);
75+
76+
include!("gen/fold.rs");

0 commit comments

Comments
 (0)