From bd6801bb7239d572bca0a3970f34269fe261f69d Mon Sep 17 00:00:00 2001 From: bishabosha Date: Thu, 19 Nov 2020 13:36:30 +0100 Subject: [PATCH] Enable one encoding of recursive gadt to work with inline match If the recursive part is fixed to a subtype of the union of the cases of the enum, enable inline match to reduce cases. Notes: this encoding could be supported by a compiletime.Refract[S] type to split cases of a sum type. --- .../src/dotty/tools/dotc/typer/Inliner.scala | 19 +++++++++- tests/run/enum-nat.scala | 37 +++++++++++++++++++ 2 files changed, 54 insertions(+), 2 deletions(-) create mode 100644 tests/run/enum-nat.scala diff --git a/compiler/src/dotty/tools/dotc/typer/Inliner.scala b/compiler/src/dotty/tools/dotc/typer/Inliner.scala index 5c9a1d7ca05a..d8270ce76f1b 100644 --- a/compiler/src/dotty/tools/dotc/typer/Inliner.scala +++ b/compiler/src/dotty/tools/dotc/typer/Inliner.scala @@ -337,6 +337,21 @@ object Inliner { def codeOf(arg: Tree, pos: SrcPos)(using Context): Tree = Literal(Constant(arg.show)).withSpan(pos.span) } + + extension (tp: Type) { + + /** same as widenTermRefExpr, but preserves modules and singleton enum values */ + private final def widenInlineScrutinee(using Context): Type = tp.stripTypeVar match { + case tp: TermRef => + val sym = tp.termSymbol + if sym.isAllOf(EnumCase, butNot=JavaDefined) || sym.is(Module) then tp + else if !tp.isOverloaded then tp.underlying.widenExpr.widenInlineScrutinee + else tp + case _ => tp + } + + } + } /** Produces an inlined version of `call` via its `inlined` method. @@ -1003,7 +1018,7 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(using Context) { * scrutinee as RHS and type that corresponds to RHS. */ def newTermBinding(sym: TermSymbol, rhs: Tree): Unit = { - val copied = sym.copy(info = rhs.tpe.widenTermRefExpr, coord = sym.coord, flags = sym.flags &~ Case).asTerm + val copied = sym.copy(info = rhs.tpe.widenInlineScrutinee, coord = sym.coord, flags = sym.flags &~ Case).asTerm caseBindingMap += ((sym, ValDef(copied, constToLiteral(rhs)).withSpan(sym.span))) } @@ -1121,7 +1136,7 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(using Context) { def reduceSubPatterns(pats: List[Tree], selectors: List[Tree]): Boolean = (pats, selectors) match { case (Nil, Nil) => true case (pat :: pats1, selector :: selectors1) => - val elem = newSym(InlineBinderName.fresh(), Synthetic, selector.tpe.widenTermRefExpr).asTerm + val elem = newSym(InlineBinderName.fresh(), Synthetic, selector.tpe.widenInlineScrutinee).asTerm val rhs = constToLiteral(selector) elem.defTree = rhs caseBindingMap += ((NoSymbol, ValDef(elem, rhs).withSpan(elem.span))) diff --git a/tests/run/enum-nat.scala b/tests/run/enum-nat.scala new file mode 100644 index 000000000000..47bddcc665cd --- /dev/null +++ b/tests/run/enum-nat.scala @@ -0,0 +1,37 @@ +import Nat._ +import compiletime._ + +enum Nat: + case Zero + case Succ[N <: Nat.Refract](n: N) + +object Nat: + type Refract = Zero.type | Succ[_] + +inline def toIntTypeLevel[N <: Nat]: Int = inline erasedValue[N] match + case _: Zero.type => 0 + case _: Succ[n] => toIntTypeLevel[n] + 1 + +inline def toInt[N <: Nat.Refract](inline nat: N): Int = inline nat match + case nat: Zero.type => 0 + case nat: Succ[n] => toInt(nat.n) + 1 + +inline def toIntUnapply[N <: Nat.Refract](inline nat: N): Int = inline nat match + case Zero => 0 + case Succ(n) => toIntUnapply(n) + 1 + +inline def toIntTypeTailRec[N <: Nat, Acc <: Int]: Int = inline erasedValue[N] match + case _: Zero.type => constValue[Acc] + case _: Succ[n] => toIntTypeTailRec[n, S[Acc]] + +inline def toIntErased[N <: Nat.Refract](inline nat: N): Int = toIntTypeTailRec[N, 0] + +@main def Test: Unit = + println("erased value:") + assert(toIntTypeLevel[Succ[Succ[Succ[Zero.type]]]] == 3) + println("type test:") + assert(toInt(Succ(Succ(Succ(Zero)))) == 3) + println("unapply:") + assert(toIntUnapply(Succ(Succ(Succ(Zero)))) == 3) + println("infer erased:") + assert(toIntErased(Succ(Succ(Succ(Zero)))) == 3)