Skip to content

Commit 6b53a8a

Browse files
committed
Add Pickler typeclass to test
1 parent d28b702 commit 6b53a8a

File tree

2 files changed

+178
-56
lines changed

2 files changed

+178
-56
lines changed

tests/run/typeclass-derivation2.check

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
ListBuffer(0, 11, 0, 22, 0, 33, 1)
2+
Cons(11,Cons(22,Cons(33,Nil)))
3+
ListBuffer(0, 0, 11, 0, 22, 0, 33, 1, 0, 0, 11, 0, 22, 1, 1)
4+
Cons(Cons(11,Cons(22,Cons(33,Nil))),Cons(Cons(11,Cons(22,Nil)),Nil))

tests/run/typeclass-derivation2.scala

Lines changed: 174 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1+
import scala.collection.mutable
2+
3+
// Generic deriving infrastructure
14
object Deriving {
2-
import scala.typelevel._
35

46
enum Shape {
57
case Sum[Alts <: Tuple]
6-
case Product[T, Elems <: Tuple]
8+
case Case[T, Elems <: Tuple]
79
}
810

911
case class GenericCase[+T](ordinal: Int, elems: Array[Object])
@@ -14,88 +16,187 @@ object Deriving {
1416
}
1517

1618
abstract class HasShape[T, S <: Shape] extends GenericMapper[T]
19+
}
20+
21+
// A datatype
22+
enum Lst[+T] {
23+
case Cons(hd: T, tl: Lst[T])
24+
case Nil
25+
}
26+
27+
object Lst {
28+
// common compiler-generated infrastructure
29+
import Deriving._
30+
31+
type LstShape[T] = Shape.Sum[(
32+
Shape.Case[Cons[T], (T, Lst[T])],
33+
Shape.Case[Nil.type, Unit]
34+
)]
1735

18-
enum Lst[+T] {
19-
case Cons(hd: T, tl: Lst[T])
20-
case Nil
36+
implicit def lstShape[T]: HasShape[Lst[T], LstShape[T]] = new {
37+
def toGenericCase(xs: Lst[T]): GenericCase[Lst[T]] = xs match {
38+
case Cons(x, xs1) => GenericCase[Cons[T]](0, Array(x.asInstanceOf, xs1))
39+
case Nil => GenericCase[Nil.type](1, Array())
40+
}
41+
def fromGenericCase(c: GenericCase[Lst[T]]): Lst[T] = c.ordinal match {
42+
case 0 => Cons[T](c.elems(0).asInstanceOf, c.elems(1).asInstanceOf)
43+
case 1 => Nil
44+
}
2145
}
2246

23-
object Lst {
24-
type LstShape[T] = Shape.Sum[(
25-
Shape.Product[Cons[T], (T, Lst[T])],
26-
Shape.Product[Nil.type, Unit]
27-
)]
47+
// two clauses that could be generated from a `derives` clause
48+
implicit def LstEq[T: Eq]: Eq[Lst[T]] = Eq.derived
49+
implicit def LstPickler[T: Pickler]: Pickler[Lst[T]] = Pickler.derived
50+
}
2851

29-
implicit def lstShape[T]: HasShape[Lst[T], LstShape[T]] = new {
30-
def toGenericCase(xs: Lst[T]): GenericCase[Lst[T]] = xs match {
31-
case Cons(x, xs1) => GenericCase[Cons[T]](0, Array(x.asInstanceOf, xs1))
32-
case Nil => GenericCase[Nil.type](1, Array())
33-
}
34-
def fromGenericCase(c: GenericCase[Lst[T]]): Lst[T] = c.ordinal match {
35-
case 0 => Cons[T](c.elems(0).asInstanceOf, c.elems(1).asInstanceOf)
36-
case 1 => Nil
52+
// A typeclass
53+
trait Eq[T] {
54+
def equals(x: T, y: T): Boolean
55+
}
56+
57+
object Eq {
58+
import scala.typelevel._
59+
import Deriving._
60+
61+
inline def tryEquals[T](x: T, y: T) = implicit match {
62+
case eq: Eq[T] => eq.equals(x, y)
63+
}
64+
65+
inline def equalsElems[Elems <: Tuple](xs: Array[Object], ys: Array[Object], n: Int): Boolean =
66+
inline erasedValue[Elems] match {
67+
case _: (elem *: elems1) =>
68+
tryEquals[elem](xs(n).asInstanceOf, ys(n).asInstanceOf) &&
69+
equalsElems[elems1](xs, ys, n + 1)
70+
case _: Unit =>
71+
true
72+
}
73+
74+
inline def equalsCase[T, Elems <: Tuple](mapper: GenericMapper[T], x: T, y: T) =
75+
equalsElems[Elems](mapper.toGenericCase(x).elems, mapper.toGenericCase(y).elems, 0)
76+
77+
inline def equalsSum[T, Alts <: Tuple](mapper: GenericMapper[T], x: T, y: T): Boolean =
78+
inline erasedValue[Alts] match {
79+
case _: (Shape.Case[alt, elems] *: alts1) =>
80+
x match {
81+
case x: `alt` =>
82+
y match {
83+
case y: `alt` => equalsCase[T, elems](mapper, x, y)
84+
case _ => false
85+
}
86+
case _ => equalsSum[T, alts1](mapper, x, y)
3787
}
88+
case _: Unit =>
89+
false
90+
}
91+
92+
inline def derived[T, S <: Shape](implicit ev: HasShape[T, S]): Eq[T] = new {
93+
def equals(x: T, y: T): Boolean = inline erasedValue[S] match {
94+
case _: Shape.Sum[alts] =>
95+
equalsSum[T, alts](ev, x, y)
96+
case _: Shape.Case[_, elems] =>
97+
equalsCase[T, elems](ev, x, y)
3898
}
99+
}
39100

40-
implicit def LstEq[T: Eq]: Eq[Lst[T]] = Eq.derived
101+
implicit object IntEq extends Eq[Int] {
102+
def equals(x: Int, y: Int) = x == y
41103
}
104+
}
105+
106+
// Another typeclass
107+
trait Pickler[T] {
108+
def pickle(buf: mutable.ListBuffer[Int], x: T): Unit
109+
def unpickle(buf: mutable.ListBuffer[Int]): T
110+
}
111+
112+
object Pickler {
113+
import scala.typelevel._
114+
import Deriving._
115+
116+
def nextInt(buf: mutable.ListBuffer[Int]): Int = try buf.head finally buf.trimStart(1)
42117

43-
trait Eq[T] {
44-
def equals(x: T, y: T): Boolean
118+
inline def tryPickle[T](buf: mutable.ListBuffer[Int], x: T): Unit = implicit match {
119+
case pkl: Pickler[T] => pkl.pickle(buf, x)
45120
}
46121

47-
object Eq {
48-
inline def tryEq[T](x: T, y: T) = implicit match {
49-
case eq: Eq[T] => eq.equals(x, y)
122+
inline def pickleElems[Elems <: Tuple](buf: mutable.ListBuffer[Int], elems: Array[AnyRef], n: Int): Unit =
123+
inline erasedValue[Elems] match {
124+
case _: (elem *: elems1) =>
125+
tryPickle[elem](buf, elems(n).asInstanceOf[elem])
126+
pickleElems[elems1](buf, elems, n + 1)
127+
case _: Unit =>
50128
}
51129

52-
inline def deriveForSum[T, Alts <: Tuple](mapper: GenericMapper[T], x: T, y: T): Boolean =
53-
inline erasedValue[Alts] match {
54-
case _: (Shape.Product[alt, elems] *: alts1) =>
55-
x match {
56-
case x: `alt` =>
57-
y match {
58-
case y: `alt` => deriveForProduct[T, elems](mapper, x, y)
59-
case _ => false
60-
}
61-
case _ => deriveForSum[T, alts1](mapper, x, y)
130+
inline def pickleCase[T, Elems <: Tuple](mapper: GenericMapper[T], buf: mutable.ListBuffer[Int], x: T): Unit = {
131+
val c = mapper.toGenericCase(x)
132+
buf += c.ordinal
133+
pickleElems[Elems](buf, c.elems, 0)
134+
}
135+
136+
inline def pickleSum[T, Alts <: Tuple](mapper: GenericMapper[T], buf: mutable.ListBuffer[Int], x: T): Unit =
137+
inline erasedValue[Alts] match {
138+
case _: (Shape.Case[alt, elems] *: alts1) =>
139+
x match {
140+
case x: `alt` => pickleCase[T, elems](mapper, buf, x)
141+
case _ => pickleSum[T, alts1](mapper, buf, x)
62142
}
63143
case _: Unit =>
64-
false
65144
}
66145

67-
inline def deriveForProduct[T, Elems <: Tuple](mapper: GenericMapper[T], x: T, y: T) =
68-
deriveForTuple[Elems](0, mapper.toGenericCase(x).elems, mapper.toGenericCase(y).elems)
146+
inline def tryUnpickle[T](buf: mutable.ListBuffer[Int]): T = implicit match {
147+
case pkl: Pickler[T] => pkl.unpickle(buf)
148+
}
69149

70-
inline def deriveForTuple[Elems <: Tuple](n: Int, xs: Array[Object], ys: Array[Object]): Boolean =
71-
inline erasedValue[Elems] match {
72-
case _: (elem *: elems1) =>
73-
tryEq[elem](xs(n).asInstanceOf, ys(n).asInstanceOf) &&
74-
deriveForTuple[elems1](n + 1, xs, ys)
75-
case _: Unit =>
76-
true
77-
}
150+
inline def unpickleElems[Elems <: Tuple](buf: mutable.ListBuffer[Int], elems: Array[AnyRef], n: Int): Unit =
151+
inline erasedValue[Elems] match {
152+
case _: (elem *: elems1) =>
153+
elems(n) = tryUnpickle[elem](buf).asInstanceOf[AnyRef]
154+
unpickleElems[elems1](buf, elems, n + 1)
155+
case _: Unit =>
156+
}
78157

79-
inline def derived[T, S <: Shape](implicit ev: HasShape[T, S]): Eq[T] = new {
80-
def equals(x: T, y: T): Boolean = inline erasedValue[S] match {
81-
case _: Shape.Sum[alts] =>
82-
deriveForSum[T, alts](ev, x, y)
83-
case _: Shape.Product[_, elems] =>
84-
deriveForProduct[T, elems](ev, x, y)
85-
}
158+
inline def unpickleCase[T, Elems <: Tuple](mapper: GenericMapper[T], buf: mutable.ListBuffer[Int], ordinal: Int): T = {
159+
val elems = new Array[Object](constValue[Tuple.Size[Elems]])
160+
unpickleElems[Elems](buf, elems, 0)
161+
mapper.fromGenericCase(GenericCase(ordinal, elems))
162+
}
163+
164+
inline def unpickleSum[T, Alts <: Tuple](mapper: GenericMapper[T], buf: mutable.ListBuffer[Int], ordinal: Int, n: Int): T =
165+
inline erasedValue[Alts] match {
166+
case _: (Shape.Case[alt, elems] *: alts1) =>
167+
if (n == ordinal) unpickleCase[T, elems](mapper, buf, ordinal)
168+
else unpickleSum[T, alts1](mapper, buf, ordinal, n + 1)
169+
case _ =>
170+
throw new IndexOutOfBoundsException(s"unexpected ordinal number: $ordinal")
86171
}
87172

88-
implicit object eqInt extends Eq[Int] {
89-
def equals(x: Int, y: Int) = x == y
173+
inline def derived[T, S <: Shape](implicit ev: HasShape[T, S]): Pickler[T] = new {
174+
def pickle(buf: mutable.ListBuffer[Int], x: T): Unit = inline erasedValue[S] match {
175+
case _: Shape.Sum[alts] =>
176+
pickleSum[T, alts](ev, buf, x)
177+
case _: Shape.Case[_, elems] =>
178+
pickleCase[T, elems](ev, buf, x)
179+
}
180+
def unpickle(buf: mutable.ListBuffer[Int]): T = inline erasedValue[S] match {
181+
case _: Shape.Sum[alts] =>
182+
unpickleSum[T, alts](ev, buf, nextInt(buf), 0)
183+
case _: Shape.Case[_, elems] =>
184+
unpickleCase[T, elems](ev, buf, 0)
90185
}
91186
}
187+
188+
implicit object IntPickler extends Pickler[Int] {
189+
def pickle(buf: mutable.ListBuffer[Int], x: Int): Unit = buf += x
190+
def unpickle(buf: mutable.ListBuffer[Int]): Int = nextInt(buf)
191+
}
92192
}
93193

194+
// Tests
94195
object Test extends App {
95196
import Deriving._
96197
val eq = implicitly[Eq[Lst[Int]]]
97-
val xs = Lst.Cons(1, Lst.Cons(2, Lst.Cons(3, Lst.Nil)))
98-
val ys = Lst.Cons(1, Lst.Cons(2, Lst.Nil))
198+
val xs = Lst.Cons(11, Lst.Cons(22, Lst.Cons(33, Lst.Nil)))
199+
val ys = Lst.Cons(11, Lst.Cons(22, Lst.Nil))
99200
assert(eq.equals(xs, xs))
100201
assert(!eq.equals(xs, ys))
101202
assert(!eq.equals(ys, xs))
@@ -108,4 +209,21 @@ object Test extends App {
108209
assert(!eq2.equals(xss, yss))
109210
assert(!eq2.equals(yss, xss))
110211
assert(eq2.equals(yss, yss))
212+
213+
val buf = new mutable.ListBuffer[Int]
214+
val pkl = implicitly[Pickler[Lst[Int]]]
215+
pkl.pickle(buf, xs)
216+
println(buf)
217+
val xs1 = pkl.unpickle(buf)
218+
println(xs1)
219+
assert(xs1 == xs)
220+
assert(eq.equals(xs1, xs))
221+
222+
val pkl2 = implicitly[Pickler[Lst[Lst[Int]]]]
223+
pkl2.pickle(buf, xss)
224+
println(buf)
225+
val xss1 = pkl2.unpickle(buf)
226+
println(xss1)
227+
assert(xss == xss1)
228+
assert(eq2.equals(xss, xss1))
111229
}

0 commit comments

Comments
 (0)