diff --git a/compiler/src/dotty/tools/dotc/printing/DecompilerPrinter.scala b/compiler/src/dotty/tools/dotc/printing/DecompilerPrinter.scala index de0da060a01f..9630bdc98c87 100644 --- a/compiler/src/dotty/tools/dotc/printing/DecompilerPrinter.scala +++ b/compiler/src/dotty/tools/dotc/printing/DecompilerPrinter.scala @@ -9,12 +9,12 @@ import dotty.tools.dotc.core.StdNames.nme import dotty.tools.dotc.core.Flags._ import dotty.tools.dotc.core.Symbols._ import dotty.tools.dotc.core.StdNames._ - +import dotty.tools.dotc.core.Annotations.Annotation class DecompilerPrinter(_ctx: Context) extends RefinedPrinter(_ctx) { - override protected def filterModTextAnnots(annots: List[untpd.Tree]): List[untpd.Tree] = - super.filterModTextAnnots(annots).filter(_.tpe != defn.SourceFileAnnotType) + override protected def dropAnnotForModText(sym: Symbol): Boolean = + super.dropAnnotForModText(sym) || sym == defn.SourceFileAnnot override protected def blockToText[T >: Untyped](block: Block[T]): Text = block match { diff --git a/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala b/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala index d19834955908..82c2fb6c9a0b 100644 --- a/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala +++ b/compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala @@ -10,6 +10,7 @@ import Symbols._ import NameOps._ import TypeErasure.ErasedValueType import Contexts.Context +import Annotations.Annotation import Denotations._ import SymDenotations._ import StdNames.{nme, tpnme} @@ -633,7 +634,9 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) { def Modifiers(sym: Symbol)(implicit ctx: Context): Modifiers = untpd.Modifiers( sym.flags & (if (sym.isType) ModifierFlags | VarianceFlags else ModifierFlags), if (sym.privateWithin.exists) sym.privateWithin.asType.name else tpnme.EMPTY, - sym.annotations map (_.tree)) + sym.annotations.filterNot(ann => dropAnnotForModText(ann.symbol)).map(_.tree)) + + protected def dropAnnotForModText(sym: Symbol): Boolean = sym == defn.BodyAnnot protected def optAscription[T >: Untyped](tpt: Tree[T]): Text = optText(tpt)(": " ~ _) @@ -757,14 +760,12 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) { if (homogenizedView && mods.flags.isTypeFlags) flagMask &~= Implicit // drop implicit from classes val flags = (if (sym.exists) sym.flags else (mods.flags)) & flagMask val flagsText = if (flags.isEmpty) "" else keywordStr(flags.toString) - val annotations = filterModTextAnnots( - if (sym.exists) sym.annotations.filterNot(_.isInstanceOf[Annotations.BodyAnnotation]).map(_.tree) - else mods.annotations) + val annotations = + if (sym.exists) sym.annotations.filterNot(ann => dropAnnotForModText(ann.symbol)).map(_.tree) + else mods.annotations.filterNot(tree => dropAnnotForModText(tree.symbol)) Text(annotations.map(annotText), " ") ~~ flagsText ~~ (Str(kw) provided !suppressKw) } - protected def filterModTextAnnots(annots: List[untpd.Tree]): List[untpd.Tree] = annots - def optText(name: Name)(encl: Text => Text): Text = if (name.isEmpty) "" else encl(toText(name)) diff --git a/compiler/src/dotty/tools/dotc/typer/ConstFold.scala b/compiler/src/dotty/tools/dotc/typer/ConstFold.scala index 68a5d05f5f26..882a2274a5c5 100644 --- a/compiler/src/dotty/tools/dotc/typer/ConstFold.scala +++ b/compiler/src/dotty/tools/dotc/typer/ConstFold.scala @@ -20,7 +20,7 @@ object ConstFold { def apply(tree: Tree)(implicit ctx: Context): Tree = finish(tree) { tree match { case Apply(Select(xt, op), yt :: Nil) => - xt.tpe.widenTermRefExpr match { + xt.tpe.widenTermRefExpr.normalized match { case ConstantType(x) => yt.tpe.widenTermRefExpr match { case ConstantType(y) => foldBinop(op, x, y) @@ -42,7 +42,7 @@ object ConstFold { */ def apply(tree: Tree, pt: Type)(implicit ctx: Context): Tree = finish(apply(tree)) { - tree.tpe.widenTermRefExpr match { + tree.tpe.widenTermRefExpr.normalized match { case ConstantType(x) => x convertTo pt case _ => null } diff --git a/compiler/src/dotty/tools/dotc/typer/Inliner.scala b/compiler/src/dotty/tools/dotc/typer/Inliner.scala index 5b1c9fe47ec4..1aef548cd2ae 100644 --- a/compiler/src/dotty/tools/dotc/typer/Inliner.scala +++ b/compiler/src/dotty/tools/dotc/typer/Inliner.scala @@ -461,24 +461,19 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) { } // Drop unused bindings - val matchBindings = reducer.matchBindingsBuf.toList - val (finalBindings, finalExpansion) = dropUnusedDefs(bindingsBuf.toList ++ matchBindings, expansion1) - val (finalMatchBindings, finalArgBindings) = finalBindings.partition(matchBindings.contains(_)) + val (finalBindings, finalExpansion) = dropUnusedDefs(bindingsBuf.toList, expansion1) if (inlinedMethod == defn.Typelevel_error) issueError() // Take care that only argument bindings go into `bindings`, since positions are // different for bindings from arguments and bindings from body. - tpd.Inlined(call, finalArgBindings, seq(finalMatchBindings, finalExpansion)) + tpd.Inlined(call, finalBindings, finalExpansion) } } /** A utility object offering methods for rewriting inlined code */ object reducer { - /** Additional bindings established by reducing match expressions */ - val matchBindingsBuf = new mutable.ListBuffer[MemberDef] - /** An extractor for terms equivalent to `new C(args)`, returning the class `C`, * a list of bindings, and the arguments `args`. Can see inside blocks and Inlined nodes and can * follow a reference to an inline value binding to its right hand side. @@ -599,7 +594,7 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) { def unapply(tree: Trees.Ident[_])(implicit ctx: Context): Option[Tree] = { def search(buf: mutable.ListBuffer[MemberDef]) = buf.find(_.name == tree.name) if (paramProxies.contains(tree.typeOpt)) - search(bindingsBuf).orElse(search(matchBindingsBuf)) match { + search(bindingsBuf) match { case Some(vdef: ValDef) if vdef.symbol.is(Inline) => Some(integrate(vdef.rhs, vdef.symbol)) case Some(ddef: DefDef) => @@ -611,7 +606,7 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) { } object ConstantValue { - def unapply(tree: Tree)(implicit ctx: Context): Option[Any] = tree.tpe.widenTermRefExpr match { + def unapply(tree: Tree)(implicit ctx: Context): Option[Any] = tree.tpe.widenTermRefExpr.normalized match { case ConstantType(Constant(x)) => Some(x) case _ => None } @@ -662,7 +657,7 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) { * for the pattern-bound variables and the RHS of the selected case. * Returns `None` if no case was selected. */ - type MatchRedux = Option[(List[MemberDef], untpd.Tree)] + type MatchRedux = Option[(List[MemberDef], tpd.Tree)] /** Reduce an inline match * @param mtch the match tree @@ -674,7 +669,7 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) { * @return optionally, if match can be reduced to a matching case: A pair of * bindings for all pattern-bound variables and the RHS of the case. */ - def reduceInlineMatch(scrutinee: Tree, scrutType: Type, cases: List[untpd.CaseDef], typer: Typer)(implicit ctx: Context): MatchRedux = { + def reduceInlineMatch(scrutinee: Tree, scrutType: Type, cases: List[CaseDef], typer: Typer)(implicit ctx: Context): MatchRedux = { val isImplicit = scrutinee.isEmpty val gadtSyms = typer.gadtSyms(scrutType) @@ -712,7 +707,7 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) { val getBoundVars = new TreeAccumulator[List[TypeSymbol]] { def apply(syms: List[TypeSymbol], t: Tree)(implicit ctx: Context) = { val syms1 = t match { - case t: Bind if t.symbol.isType && t.name != tpnme.WILDCARD => + case t: Bind if t.symbol.isType => t.symbol.asType :: syms case _ => syms @@ -739,7 +734,7 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) { // ConstraintHandler#approximation does. However, this only works for constrained paramrefs // not GADT-bound variables. Hopefully we will get some way to improve this when we // re-implement GADTs in terms of constraints. - bindingsBuf += TypeDef(bv) + if (bv.name != nme.WILDCARD) bindingsBuf += TypeDef(bv) } reducePattern(bindingsBuf, scrut, pat1) } @@ -805,7 +800,7 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) { val scrutineeSym = newSym(InlineScrutineeName.fresh(), Synthetic, scrutType).asTerm val scrutineeBinding = normalizeBinding(ValDef(scrutineeSym, scrutinee)) - def reduceCase(cdef: untpd.CaseDef): MatchRedux = { + def reduceCase(cdef: CaseDef): MatchRedux = { val caseBindingsBuf = new mutable.ListBuffer[MemberDef]() def guardOK(implicit ctx: Context) = cdef.guard.isEmpty || { val guardCtx = ctx.fresh.setNewScope @@ -824,7 +819,7 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) { None } - def recur(cases: List[untpd.CaseDef]): MatchRedux = cases match { + def recur(cases: List[CaseDef]): MatchRedux = cases match { case Nil => None case cdef :: cases1 => reduceCase(cdef) `orElse` recur(cases1) } @@ -895,14 +890,15 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) { super.typedMatchFinish(tree, sel, wideSelType, cases, pt) else { val selType = if (sel.isEmpty) wideSelType else sel.tpe - reduceInlineMatch(sel, selType, cases, this) match { - case Some((caseBindings, rhs)) => - var rhsCtx = ctx.fresh.setNewScope - for (binding <- caseBindings) { - matchBindingsBuf += binding - rhsCtx.enter(binding.symbol) - } - typedExpr(rhs, pt)(rhsCtx) + reduceInlineMatch(sel, selType, cases.asInstanceOf[List[CaseDef]], this) match { + case Some((caseBindings, rhs0)) => + val (usedBindings, rhs1) = dropUnusedDefs(caseBindings, rhs0) + val rhs = seq(usedBindings, rhs1) + inlining.println(i"""--- reduce: + |$tree + |--- to: + |$rhs""") + typedExpr(rhs, pt) case None => def guardStr(guard: untpd.Tree) = if (guard.isEmpty) "" else i" if $guard" def patStr(cdef: untpd.CaseDef) = i"case ${cdef.pat}${guardStr(cdef.guard)}" @@ -993,7 +989,7 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(implicit ctx: Context) { val dealiasedType = dealias(t.tpe) val t1 = t match { case t: RefTree => - if (boundTypes.contains(t.symbol)) TypeTree(dealiasedType).withPos(t.pos) + if (t.name != nme.WILDCARD && boundTypes.contains(t.symbol)) TypeTree(dealiasedType).withPos(t.pos) else t.withType(dealiasedType) case t: DefTree => t.symbol.info = dealias(t.symbol.info) diff --git a/compiler/test/dotc/pos-test-pickling.blacklist b/compiler/test/dotc/pos-test-pickling.blacklist index 22a41321bf77..28accd4f5cfa 100644 --- a/compiler/test/dotc/pos-test-pickling.blacklist +++ b/compiler/test/dotc/pos-test-pickling.blacklist @@ -12,6 +12,7 @@ i4125.scala implicit-dep.scala inline-access-levels inline-rewrite.scala +inline-caseclass.scala macro-with-array macro-with-type matchtype.scala diff --git a/compiler/test/dotc/run-from-tasty.blacklist b/compiler/test/dotc/run-from-tasty.blacklist index ad83b4eb8b8b..a2fbf589a9c8 100644 --- a/compiler/test/dotc/run-from-tasty.blacklist +++ b/compiler/test/dotc/run-from-tasty.blacklist @@ -6,3 +6,6 @@ puzzle.scala # Need to print empty tree for implicit match implicitMatch.scala +typeclass-derivation1.scala +typeclass-derivation2.scala + diff --git a/compiler/test/dotc/run-test-pickling.blacklist b/compiler/test/dotc/run-test-pickling.blacklist index 4e30708a3def..91ab8c1a70a8 100644 --- a/compiler/test/dotc/run-test-pickling.blacklist +++ b/compiler/test/dotc/run-test-pickling.blacklist @@ -9,3 +9,5 @@ t8133b tuples1.scala tuples1a.scala implicitMatch.scala +typeclass-derivation1.scala +typeclass-derivation2.scala diff --git a/tests/run/typeclass-derivation1.scala b/tests/run/typeclass-derivation1.scala new file mode 100644 index 000000000000..47e457e8ed10 --- /dev/null +++ b/tests/run/typeclass-derivation1.scala @@ -0,0 +1,100 @@ +object Deriving { + import scala.typelevel._ + + sealed trait Shape + + class HasSumShape[T, S <: Tuple] + + abstract class HasProductShape[T, Xs <: Tuple] { + def toProduct(x: T): Xs + def fromProduct(x: Xs): T + } + + enum Lst[+T] { + case Cons(hd: T, tl: Lst[T]) + case Nil + } + + object Lst { + implicit def lstShape[T]: HasSumShape[Lst[T], (Cons[T], Nil.type)] = new HasSumShape + + implicit def consShape[T]: HasProductShape[Lst.Cons[T], (T, Lst[T])] = new { + def toProduct(xs: Lst.Cons[T]) = (xs.hd, xs.tl) + def fromProduct(xs: (T, Lst[T])): Lst.Cons[T] = Lst.Cons(xs(0), xs(1)).asInstanceOf + } + + implicit def nilShape[T]: HasProductShape[Lst.Nil.type, Unit] = new { + def toProduct(xs: Lst.Nil.type) = () + def fromProduct(xs: Unit) = Lst.Nil + } + + implicit def LstEq[T: Eq]: Eq[Lst[T]] = Eq.derivedForSum + implicit def ConsEq[T: Eq]: Eq[Cons[T]] = Eq.derivedForProduct + implicit def NilEq[T]: Eq[Nil.type] = Eq.derivedForProduct + } + + trait Eq[T] { + def equals(x: T, y: T): Boolean + } + + object Eq { + inline def tryEq[T](x: T, y: T) = implicit match { + case eq: Eq[T] => eq.equals(x, y) + } + + inline def deriveForSum[Alts <: Tuple](x: Any, y: Any): Boolean = inline erasedValue[Alts] match { + case _: (alt *: alts1) => + x match { + case x: `alt` => + y match { + case y: `alt` => tryEq[alt](x, y) + case _ => false + } + case _ => deriveForSum[alts1](x, y) + } + case _: Unit => + false + } + + inline def deriveForProduct[Elems <: Tuple](xs: Elems, ys: Elems): Boolean = inline erasedValue[Elems] match { + case _: (elem *: elems1) => + val xs1 = xs.asInstanceOf[elem *: elems1] + val ys1 = ys.asInstanceOf[elem *: elems1] + tryEq[elem](xs1.head, ys1.head) && + deriveForProduct[elems1](xs1.tail, ys1.tail) + case _: Unit => + true + } + + inline def derivedForSum[T, Alts <: Tuple](implicit ev: HasSumShape[T, Alts]): Eq[T] = new { + def equals(x: T, y: T): Boolean = deriveForSum[Alts](x, y) + } + + inline def derivedForProduct[T, Elems <: Tuple](implicit ev: HasProductShape[T, Elems]): Eq[T] = new { + def equals(x: T, y: T): Boolean = deriveForProduct[Elems](ev.toProduct(x), ev.toProduct(y)) + } + + implicit object eqInt extends Eq[Int] { + def equals(x: Int, y: Int) = x == y + } + } +} + +object Test extends App { + import Deriving._ + val eq = implicitly[Eq[Lst[Int]]] + val xs = Lst.Cons(1, Lst.Cons(2, Lst.Cons(3, Lst.Nil))) + val ys = Lst.Cons(1, Lst.Cons(2, Lst.Nil)) + assert(eq.equals(xs, xs)) + assert(!eq.equals(xs, ys)) + assert(!eq.equals(ys, xs)) + assert(eq.equals(ys, ys)) + + val eq2 = implicitly[Eq[Lst[Lst[Int]]]] + val xss = Lst.Cons(xs, Lst.Cons(ys, Lst.Nil)) + val yss = Lst.Cons(xs, Lst.Nil) + assert(eq2.equals(xss, xss)) + assert(!eq2.equals(xss, yss)) + assert(!eq2.equals(yss, xss)) + assert(eq2.equals(yss, yss)) +} \ No newline at end of file diff --git a/tests/run/typeclass-derivation2.check b/tests/run/typeclass-derivation2.check new file mode 100644 index 000000000000..4ec92ef1a1e0 --- /dev/null +++ b/tests/run/typeclass-derivation2.check @@ -0,0 +1,8 @@ +ListBuffer(0, 11, 0, 22, 0, 33, 1) +Cons(11,Cons(22,Cons(33,Nil))) +ListBuffer(0, 0, 11, 0, 22, 0, 33, 1, 0, 0, 11, 0, 22, 1, 1) +Cons(Cons(11,Cons(22,Cons(33,Nil))),Cons(Cons(11,Cons(22,Nil)),Nil)) +ListBuffer(1, 2) +Pair(1,2) +Cons(hd = 11, tl = Cons(hd = 22, tl = Cons(hd = 33, tl = Nil()))) +Cons(hd = Cons(hd = 11, tl = Cons(hd = 22, tl = Cons(hd = 33, tl = Nil()))), tl = Cons(hd = Cons(hd = 11, tl = Cons(hd = 22, tl = Nil())), tl = Nil())) diff --git a/tests/run/typeclass-derivation2.scala b/tests/run/typeclass-derivation2.scala new file mode 100644 index 000000000000..69863715976f --- /dev/null +++ b/tests/run/typeclass-derivation2.scala @@ -0,0 +1,427 @@ +import scala.collection.mutable +import scala.annotation.tailrec + +trait Deriving { + import Deriving._ + + /** A mirror of case with ordinal number `ordinal` and elements as given by `Product` */ + def mirror(ordinal: Int, product: Product): Mirror = + new Mirror(this, ordinal, product) + + /** A mirror with elements given as an array */ + def mirror(ordinal: Int, elems: Array[AnyRef]): Mirror = + mirror(ordinal, new ArrayProduct(elems)) + + /** A mirror with an initial empty array of `numElems` elements, to be filled in. */ + def mirror(ordinal: Int, numElems: Int): Mirror = + mirror(ordinal, new Array[AnyRef](numElems)) + + /** A mirror of a case with no elements */ + def mirror(ordinal: Int): Mirror = + mirror(ordinal, EmptyProduct) + + /** The case and element labels of the described ADT as encoded strings. */ + protected def caseLabels: Array[String] + + private final val separator = '\000' + + private def label(ordinal: Int, idx: Int): String = { + val labels = caseLabels(ordinal) + @tailrec def separatorPos(from: Int): Int = + if (from == labels.length || labels(from) == separator) from + else separatorPos(from + 1) + @tailrec def findLabel(count: Int, idx: Int): String = + if (idx == labels.length) "" + else if (count == 0) labels.substring(idx, separatorPos(idx)) + else findLabel(if (labels(idx) == separator) count - 1 else count, idx + 1) + findLabel(idx, 0) + } +} + +// Generic deriving infrastructure +object Deriving { + + /** A generic representation of a case in an ADT + * @param deriving The companion object of the ADT + * @param ordinal The ordinal value of the case in the list of the ADT's cases + * @param elems The elements of the case + */ + class Mirror(val deriving: Deriving, val ordinal: Int, val elems: Product) { + + /** The `n`'th element of this generic case */ + def apply(n: Int): Any = elems.productElement(n) + + /** The name of the constructor of the case reflected by this mirror */ + def caseLabel: String = deriving.label(ordinal, 0) + + /** The label of the `n`'th element of the case reflected by this mirror */ + def elementLabel(n: Int) = deriving.label(ordinal, n + 1) + } + + /** A class for mapping between an ADT value and + * the case mirror that represents the value. + */ + abstract class Reflected[T] { + + /** The case mirror corresponding to ADT instance `x` */ + def reflect(x: T): Mirror + + /** The ADT instance corresponding to given `mirror` */ + def reify(mirror: Mirror): T + + /** The companion object of the ADT */ + def deriving: Deriving + } + + /** The shape of an ADT. + * This is eithe a product (`Case`) or a sum (`Cases`) of products. + */ + enum Shape { + + /** A sum with alternative types `Alts` */ + case Cases[Alts <: Tuple] + + /** A product type `T` with element types `Elems` */ + case Case[T, Elems <: Tuple] + } + + /** Every generic derivation starts with a typeclass instance of this type. + * It informs that type `T` has shape `S` and also implements runtime reflection on `T`. + */ + abstract class Shaped[T, S <: Shape] extends Reflected[T] + + /** Helper class to turn arrays into products */ + private class ArrayProduct(val elems: Array[AnyRef]) extends Product { + def canEqual(that: Any): Boolean = true + def productElement(n: Int) = elems(n) + def productArity = elems.length + override def productIterator: Iterator[Any] = elems.iterator + def update(n: Int, x: Any) = elems(n) = x.asInstanceOf[AnyRef] + } + + /** Helper object */ + private object EmptyProduct extends Product { + def canEqual(that: Any): Boolean = true + def productElement(n: Int) = throw new IndexOutOfBoundsException + def productArity = 0 + } +} + +// An algebraic datatype +enum Lst[+T] // derives Eq, Pickler +{ + case Cons(hd: T, tl: Lst[T]) + case Nil +} + +object Lst extends Deriving { + // common compiler-generated infrastructure + import Deriving._ + + type Shape[T] = Shape.Cases[( + Shape.Case[Cons[T], (T, Lst[T])], + Shape.Case[Nil.type, Unit] + )] + + val NilMirror = mirror(1) + + implicit def lstShape[T]: Shaped[Lst[T], Shape[T]] = new { + def reflect(xs: Lst[T]): Mirror = xs match { + case xs: Cons[T] => mirror(0, xs) + case Nil => NilMirror + } + def reify(c: Mirror): Lst[T] = c.ordinal match { + case 0 => Cons[T](c(0).asInstanceOf, c(1).asInstanceOf) + case 1 => Nil + } + def deriving = Lst + } + + protected val caseLabels = Array("Cons\000hd\000tl", "Nil") + + // three clauses that could be generated from a `derives` clause + implicit def LstEq[T: Eq]: Eq[Lst[T]] = Eq.derived + implicit def LstPickler[T: Pickler]: Pickler[Lst[T]] = Pickler.derived + implicit def LstShow[T: Show]: Show[Lst[T]] = Show.derived +} + +// A simple product type +case class Pair[T](x: T, y: T) // derives Eq, Pickler + +object Pair extends Deriving { + // common compiler-generated infrastructure + import Deriving._ + + type Shape[T] = Shape.Case[Pair[T], (T, T)] + + implicit def pairShape[T]: Shaped[Pair[T], Shape[T]] = new { + def reflect(xy: Pair[T]) = + mirror(0, xy) + def reify(c: Mirror): Pair[T] = + Pair(c(0).asInstanceOf, c(1).asInstanceOf) + def deriving = Pair + } + + protected val caseLabels = Array("Pair\000x\000y") + + // two clauses that could be generated from a `derives` clause + implicit def PairEq[T: Eq]: Eq[Pair[T]] = Eq.derived + implicit def PairPickler[T: Pickler]: Pickler[Pair[T]] = Pickler.derived +} + +// A typeclass +trait Eq[T] { + def eql(x: T, y: T): Boolean +} + +object Eq { + import scala.typelevel._ + import Deriving._ + + inline def tryEql[T](x: T, y: T) = implicit match { + case eq: Eq[T] => eq.eql(x, y) + } + + inline def eqlElems[Elems <: Tuple](xs: Mirror, ys: Mirror, n: Int): Boolean = + inline erasedValue[Elems] match { + case _: (elem *: elems1) => + tryEql[elem](xs(n).asInstanceOf, ys(n).asInstanceOf) && + eqlElems[elems1](xs, ys, n + 1) + case _: Unit => + true + } + + inline def eqlCase[T, Elems <: Tuple](r: Reflected[T], x: T, y: T) = + eqlElems[Elems](r.reflect(x), r.reflect(y), 0) + + inline def eqlCases[T, Alts <: Tuple](r: Reflected[T], x: T, y: T): Boolean = + inline erasedValue[Alts] match { + case _: (Shape.Case[alt, elems] *: alts1) => + x match { + case x: `alt` => + y match { + case y: `alt` => eqlCase[T, elems](r, x, y) + case _ => false + } + case _ => eqlCases[T, alts1](r, x, y) + } + case _: Unit => + false + } + + inline def derived[T, S <: Shape](implicit ev: Shaped[T, S]): Eq[T] = new { + def eql(x: T, y: T): Boolean = inline erasedValue[S] match { + case _: Shape.Cases[alts] => + eqlCases[T, alts](ev, x, y) + case _: Shape.Case[_, elems] => + eqlCase[T, elems](ev, x, y) + } + } + + implicit object IntEq extends Eq[Int] { + def eql(x: Int, y: Int) = x == y + } +} + +// Another typeclass +trait Pickler[T] { + def pickle(buf: mutable.ListBuffer[Int], x: T): Unit + def unpickle(buf: mutable.ListBuffer[Int]): T +} + +object Pickler { + import scala.typelevel._ + import Deriving._ + + def nextInt(buf: mutable.ListBuffer[Int]): Int = try buf.head finally buf.trimStart(1) + + inline def tryPickle[T](buf: mutable.ListBuffer[Int], x: T): Unit = implicit match { + case pkl: Pickler[T] => pkl.pickle(buf, x) + } + + inline def pickleElems[Elems <: Tuple](buf: mutable.ListBuffer[Int], elems: Mirror, n: Int): Unit = + inline erasedValue[Elems] match { + case _: (elem *: elems1) => + tryPickle[elem](buf, elems(n).asInstanceOf[elem]) + pickleElems[elems1](buf, elems, n + 1) + case _: Unit => + } + + inline def pickleCase[T, Elems <: Tuple](r: Reflected[T], buf: mutable.ListBuffer[Int], x: T): Unit = + pickleElems[Elems](buf, r.reflect(x), 0) + + inline def pickleCases[T, Alts <: Tuple](r: Reflected[T], buf: mutable.ListBuffer[Int], x: T, n: Int): Unit = + inline erasedValue[Alts] match { + case _: (Shape.Case[alt, elems] *: alts1) => + x match { + case x: `alt` => + buf += n + pickleCase[T, elems](r, buf, x) + case _ => + pickleCases[T, alts1](r, buf, x, n + 1) + } + case _: Unit => + } + + inline def tryUnpickle[T](buf: mutable.ListBuffer[Int]): T = implicit match { + case pkl: Pickler[T] => pkl.unpickle(buf) + } + + inline def unpickleElems[Elems <: Tuple](buf: mutable.ListBuffer[Int], elems: Array[AnyRef], n: Int): Unit = + inline erasedValue[Elems] match { + case _: (elem *: elems1) => + elems(n) = tryUnpickle[elem](buf).asInstanceOf[AnyRef] + unpickleElems[elems1](buf, elems, n + 1) + case _: Unit => + } + + inline def unpickleCase[T, Elems <: Tuple](r: Reflected[T], buf: mutable.ListBuffer[Int], ordinal: Int): T = { + inline val size = constValue[Tuple.Size[Elems]] + inline if (size == 0) + r.reify(r.deriving.mirror(ordinal)) + else { + val elems = new Array[Object](size) + unpickleElems[Elems](buf, elems, 0) + r.reify(r.deriving.mirror(ordinal, elems)) + } + } + + inline def unpickleCases[T, Alts <: Tuple](r: Reflected[T], buf: mutable.ListBuffer[Int], ordinal: Int, n: Int): T = + inline erasedValue[Alts] match { + case _: (Shape.Case[_, elems] *: alts1) => + if (n == ordinal) unpickleCase[T, elems](r, buf, ordinal) + else unpickleCases[T, alts1](r, buf, ordinal, n + 1) + case _ => + throw new IndexOutOfBoundsException(s"unexpected ordinal number: $ordinal") + } + + inline def derived[T, S <: Shape](implicit ev: Shaped[T, S]): Pickler[T] = new { + def pickle(buf: mutable.ListBuffer[Int], x: T): Unit = inline erasedValue[S] match { + case _: Shape.Cases[alts] => + pickleCases[T, alts](ev, buf, x, 0) + case _: Shape.Case[_, elems] => + pickleCase[T, elems](ev, buf, x) + } + def unpickle(buf: mutable.ListBuffer[Int]): T = inline erasedValue[S] match { + case _: Shape.Cases[alts] => + unpickleCases[T, alts](ev, buf, nextInt(buf), 0) + case _: Shape.Case[_, elems] => + unpickleCase[T, elems](ev, buf, 0) + } + } + + implicit object IntPickler extends Pickler[Int] { + def pickle(buf: mutable.ListBuffer[Int], x: Int): Unit = buf += x + def unpickle(buf: mutable.ListBuffer[Int]): Int = nextInt(buf) + } +} + +// A third typeclass, making use of labels +trait Show[T] { + def show(x: T): String +} +object Show { + import scala.typelevel._ + import Deriving._ + + inline def tryShow[T](x: T): String = implicit match { + case s: Show[T] => s.show(x) + } + + inline def showElems[Elems <: Tuple](elems: Mirror, n: Int): List[String] = + inline erasedValue[Elems] match { + case _: (elem *: elems1) => + val formal = elems.elementLabel(n) + val actual = tryShow[elem](elems(n).asInstanceOf) + s"$formal = $actual" :: showElems[elems1](elems, n + 1) + case _: Unit => + Nil + } + + inline def showCase[T, Elems <: Tuple](r: Reflected[T], x: T): String = { + val mirror = r.reflect(x) + val args = showElems[Elems](mirror, 0).mkString(", ") + s"${mirror.caseLabel}($args)" + } + + inline def showCases[T, Alts <: Tuple](r: Reflected[T], x: T): String = + inline erasedValue[Alts] match { + case _: (Shape.Case[alt, elems] *: alts1) => + x match { + case x: `alt` => showCase[T, elems](r, x) + case _ => showCases[T, alts1](r, x) + } + case _: Unit => + throw new MatchError(x) + } + + inline def derived[T, S <: Shape](implicit ev: Shaped[T, S]): Show[T] = new { + def show(x: T): String = inline erasedValue[S] match { + case _: Shape.Cases[alts] => + showCases[T, alts](ev, x) + case _: Shape.Case[_, elems] => + showCase[T, elems](ev, x) + } + } + + implicit object IntShow extends Show[Int] { + def show(x: Int): String = x.toString + } +} + +// Tests +object Test extends App { + import Deriving._ + val eq = implicitly[Eq[Lst[Int]]] + val xs = Lst.Cons(11, Lst.Cons(22, Lst.Cons(33, Lst.Nil))) + val ys = Lst.Cons(11, Lst.Cons(22, Lst.Nil)) + assert(eq.eql(xs, xs)) + assert(!eq.eql(xs, ys)) + assert(!eq.eql(ys, xs)) + assert(eq.eql(ys, ys)) + + val eq2 = implicitly[Eq[Lst[Lst[Int]]]] + val xss = Lst.Cons(xs, Lst.Cons(ys, Lst.Nil)) + val yss = Lst.Cons(xs, Lst.Nil) + assert(eq2.eql(xss, xss)) + assert(!eq2.eql(xss, yss)) + assert(!eq2.eql(yss, xss)) + assert(eq2.eql(yss, yss)) + + val buf = new mutable.ListBuffer[Int] + val pkl = implicitly[Pickler[Lst[Int]]] + pkl.pickle(buf, xs) + println(buf) + val xs1 = pkl.unpickle(buf) + println(xs1) + assert(xs1 == xs) + assert(eq.eql(xs1, xs)) + + val pkl2 = implicitly[Pickler[Lst[Lst[Int]]]] + pkl2.pickle(buf, xss) + println(buf) + val xss1 = pkl2.unpickle(buf) + println(xss1) + assert(xss == xss1) + assert(eq2.eql(xss, xss1)) + + val p1 = Pair(1, 2) + val p2 = Pair(1, 2) + val p3 = Pair(2, 1) + val eqp = implicitly[Eq[Pair[Int]]] + assert(eqp.eql(p1, p2)) + assert(!eqp.eql(p2, p3)) + + val pklp = implicitly[Pickler[Pair[Int]]] + pklp.pickle(buf, p1) + println(buf) + val p1a = pklp.unpickle(buf) + println(p1a) + assert(p1 == p1a) + assert(eqp.eql(p1, p1a)) + + def showPrintln[T: Show](x: T): Unit = + println(implicitly[Show[T]].show(x)) + showPrintln(xs) + showPrintln(xss) +} \ No newline at end of file