Skip to content

Commit a570e24

Browse files
authored
Merge pull request #2552 from dotty-staging/add-enum-eq
Add Enum Eq
2 parents 96854ef + bce0df9 commit a570e24

28 files changed

+659
-37
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

Lines changed: 47 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -72,9 +72,12 @@ object desugar {
7272
val defctx = ctx.outersIterator.dropWhile(_.scope eq ctx.scope).next
7373
var local = defctx.denotNamed(tp.name).suchThat(_ is ParamOrAccessor).symbol
7474
if (local.exists) (defctx.owner.thisType select local).dealias
75-
else throw new java.lang.Error(
76-
s"no matching symbol for ${tp.symbol.showLocated} in ${defctx.owner} / ${defctx.effectiveScope}"
77-
)
75+
else {
76+
def msg =
77+
s"no matching symbol for ${tp.symbol.showLocated} in ${defctx.owner} / ${defctx.effectiveScope}"
78+
if (ctx.reporter.errorsReported) new ErrorType(msg)
79+
else throw new java.lang.Error(msg)
80+
}
7881
case _ =>
7982
mapOver(tp)
8083
}
@@ -124,7 +127,7 @@ object desugar {
124127
else vdef
125128
}
126129

127-
def makeImplicitParameters(tpts: List[Tree], forPrimaryConstructor: Boolean)(implicit ctx: Context) =
130+
def makeImplicitParameters(tpts: List[Tree], forPrimaryConstructor: Boolean = false)(implicit ctx: Context) =
128131
for (tpt <- tpts) yield {
129132
val paramFlags: FlagSet = if (forPrimaryConstructor) PrivateLocalParamAccessor else Param
130133
val epname = EvidenceParamName.fresh()
@@ -265,7 +268,7 @@ object desugar {
265268
val mods = cdef.mods
266269
val companionMods = mods
267270
.withFlags((mods.flags & AccessFlags).toCommonFlags)
268-
.withMods(mods.mods.filter(!_.isInstanceOf[Mod.EnumCase]))
271+
.withMods(Nil)
269272

270273
val (constr1, defaultGetters) = defDef(constr0, isPrimaryConstructor = true) match {
271274
case meth: DefDef => (meth, Nil)
@@ -291,7 +294,7 @@ object desugar {
291294

292295
val isCaseClass = mods.is(Case) && !mods.is(Module)
293296
val isCaseObject = mods.is(Case) && mods.is(Module)
294-
val isEnum = mods.hasMod[Mod.Enum]
297+
val isEnum = mods.hasMod[Mod.Enum] && !mods.is(Module)
295298
val isEnumCase = isLegalEnumCase(cdef)
296299
val isValueClass = parents.nonEmpty && isAnyVal(parents.head)
297300
// This is not watertight, but `extends AnyVal` will be replaced by `inline` later.
@@ -326,10 +329,12 @@ object desugar {
326329

327330
val classTycon: Tree = new TypeRefTree // watching is set at end of method
328331

329-
def appliedRef(tycon: Tree) =
330-
(if (constrTparams.isEmpty) tycon
331-
else AppliedTypeTree(tycon, constrTparams map refOfDef))
332-
.withPos(cdef.pos.startPos)
332+
def appliedTypeTree(tycon: Tree, args: List[Tree]) =
333+
(if (args.isEmpty) tycon else AppliedTypeTree(tycon, args))
334+
.withPos(cdef.pos.startPos)
335+
336+
def appliedRef(tycon: Tree, tparams: List[TypeDef] = constrTparams) =
337+
appliedTypeTree(tycon, tparams map refOfDef)
333338

334339
// a reference to the class type bound by `cdef`, with type parameters coming from the constructor
335340
val classTypeRef = appliedRef(classTycon)
@@ -344,8 +349,7 @@ object desugar {
344349
else {
345350
ctx.error(i"explicit extends clause needed because type parameters of case and enum class differ"
346351
, cdef.pos.startPos)
347-
AppliedTypeTree(enumClassRef, constrTparams map (_ => anyRef))
348-
.withPos(cdef.pos.startPos)
352+
appliedTypeTree(enumClassRef, constrTparams map (_ => anyRef))
349353
}
350354
case _ =>
351355
enumClassRef
@@ -411,6 +415,31 @@ object desugar {
411415
if (isEnum)
412416
parents1 = parents1 :+ ref(defn.EnumType)
413417

418+
// The Eq instance for an Enum class. For an enum class
419+
//
420+
// enum class C[T1, ..., Tn]
421+
//
422+
// we generate:
423+
//
424+
// implicit def eqInstance[T1$1, ..., Tn$1, T1$2, ..., Tn$2](implicit
425+
// ev1: Eq[T1$1, T1$2], ..., evn: Eq[Tn$1, Tn$2]])
426+
// : Eq[C[T1$1, ..., Tn$1], C[T1$2, ..., Tn$2]] = Eq
427+
def eqInstance = {
428+
def append(tdef: TypeDef, str: String) = cpy.TypeDef(tdef)(name = tdef.name ++ str)
429+
val leftParams = derivedTparams.map(append(_, "$1"))
430+
val rightParams = derivedTparams.map(append(_, "$2"))
431+
val subInstances = (leftParams, rightParams).zipped.map((param1, param2) =>
432+
appliedRef(ref(defn.EqType), List(param1, param2)))
433+
DefDef(
434+
name = nme.eqInstance,
435+
tparams = leftParams ++ rightParams,
436+
vparamss = List(makeImplicitParameters(subInstances)),
437+
tpt = appliedTypeTree(ref(defn.EqType),
438+
appliedRef(classTycon, leftParams) :: appliedRef(classTycon, rightParams) :: Nil),
439+
rhs = ref(defn.EqModule.termRef)).withFlags(Synthetic | Implicit)
440+
}
441+
def eqInstances = if (isEnum) eqInstance :: Nil else Nil
442+
414443
// The thicket which is the desugared version of the companion object
415444
// synthetic object C extends parentTpt { defs }
416445
def companionDefs(parentTpt: Tree, defs: List[Tree]) =
@@ -420,6 +449,8 @@ object desugar {
420449
.withMods(companionMods | Synthetic))
421450
.withPos(cdef.pos).toList
422451

452+
val companionMeths = defaultGetters ::: eqInstances
453+
423454
// The companion object definitions, if a companion is needed, Nil otherwise.
424455
// companion definitions include:
425456
// 1. If class is a case class case class C[Ts](p1: T1, ..., pN: TN)(moreParams):
@@ -465,10 +496,10 @@ object desugar {
465496
DefDef(nme.unapply, derivedTparams, (unapplyParam :: Nil) :: Nil, TypeTree(), unapplyRHS)
466497
.withMods(synthetic)
467498
}
468-
companionDefs(parent, applyMeths ::: unapplyMeth :: defaultGetters)
499+
companionDefs(parent, applyMeths ::: unapplyMeth :: companionMeths)
469500
}
470-
else if (defaultGetters.nonEmpty)
471-
companionDefs(anyRef, defaultGetters)
501+
else if (companionMeths.nonEmpty)
502+
companionDefs(anyRef, companionMeths)
472503
else if (isValueClass) {
473504
constr0.vparamss match {
474505
case List(_ :: Nil) => companionDefs(anyRef, Nil)
@@ -739,7 +770,7 @@ object desugar {
739770
}
740771

741772
def makeImplicitFunction(formals: List[Type], body: Tree)(implicit ctx: Context): Tree = {
742-
val params = makeImplicitParameters(formals.map(TypeTree), forPrimaryConstructor = false)
773+
val params = makeImplicitParameters(formals.map(TypeTree))
743774
new ImplicitFunction(params, body)
744775
}
745776

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -547,6 +547,7 @@ class Definitions {
547547

548548
lazy val EqType = ctx.requiredClassRef("scala.Eq")
549549
def EqClass(implicit ctx: Context) = EqType.symbol.asClass
550+
def EqModule(implicit ctx: Context) = EqClass.companionModule
550551

551552
lazy val XMLTopScopeModuleRef = ctx.requiredModuleRef("scala.xml.TopScope")
552553

compiler/src/dotty/tools/dotc/core/StdNames.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -400,6 +400,7 @@ object StdNames {
400400
val ensureAccessible : N = "ensureAccessible"
401401
val enumTag: N = "enumTag"
402402
val eq: N = "eq"
403+
val eqInstance: N = "eqInstance"
403404
val equalsNumChar : N = "equalsNumChar"
404405
val equalsNumNum : N = "equalsNumNum"
405406
val equalsNumObject : N = "equalsNumObject"

compiler/src/dotty/tools/dotc/typer/Implicits.scala

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -656,7 +656,7 @@ trait Implicits { self: Typer =>
656656
if (!ctx.isAfterTyper && !assumedCanEqual(ltp, rtp)) {
657657
val res = inferImplicitArg(
658658
defn.EqType.appliedTo(ltp, rtp), msgFun => ctx.error(msgFun(""), pos), pos)
659-
implicits.println(i"Eq witness found: $res: ${res.tpe}")
659+
implicits.println(i"Eq witness found for $ltp / $rtp: $res: ${res.tpe}")
660660
}
661661

662662
/** Find an implicit parameter or conversion.
@@ -676,7 +676,7 @@ trait Implicits { self: Typer =>
676676
val isearch =
677677
if (ctx.settings.explainImplicits.value) new ExplainedImplicitSearch(pt, argument, pos)
678678
else new ImplicitSearch(pt, argument, pos)
679-
val result = isearch.bestImplicit
679+
val result = isearch.bestImplicit(contextual = true)
680680
result match {
681681
case result: SearchSuccess =>
682682
result.tstate.commit()
@@ -743,7 +743,7 @@ trait Implicits { self: Typer =>
743743
def typedImplicit(cand: Candidate)(implicit ctx: Context): SearchResult = track("typedImplicit") { ctx.traceIndented(i"typed implicit ${cand.ref}, pt = $pt, implicitsEnabled == ${ctx.mode is ImplicitsEnabled}", implicits, show = true) {
744744
assert(constr eq ctx.typerState.constraint)
745745
val ref = cand.ref
746-
var generated: Tree = tpd.ref(ref).withPos(pos.startPos)
746+
var generated: Tree = tpd.ref(ref).withPos(pos)
747747
if (!argument.isEmpty)
748748
generated = typedUnadapted(
749749
untpd.Apply(untpd.TypedSplice(generated), untpd.TypedSplice(argument) :: Nil),
@@ -759,13 +759,20 @@ trait Implicits { self: Typer =>
759759
case _ => false
760760
}
761761
}
762-
// Does there exist an implicit value of type `Eq[tp, tp]`?
763-
def hasEq(tp: Type): Boolean =
764-
new ImplicitSearch(defn.EqType.appliedTo(tp, tp), EmptyTree, pos).bestImplicit match {
765-
case result: SearchSuccess => result.ref.symbol != defn.Predef_eqAny
766-
case result: AmbiguousImplicits => true
767-
case _ => false
768-
}
762+
// Does there exist an implicit value of type `Eq[tp, tp]`
763+
// which is different from `eqAny`?
764+
def hasEq(tp: Type): Boolean = {
765+
def search(contextual: Boolean): Boolean =
766+
new ImplicitSearch(defn.EqType.appliedTo(tp, tp), EmptyTree, pos)
767+
.bestImplicit(contextual) match {
768+
case result: SearchSuccess =>
769+
result.ref.symbol != defn.Predef_eqAny ||
770+
contextual && search(contextual = false)
771+
case result: AmbiguousImplicits => true
772+
case _ => false
773+
}
774+
search(contextual = true)
775+
}
769776

770777
def validEqAnyArgs(tp1: Type, tp2: Type) = {
771778
List(tp1, tp2).foreach(fullyDefinedType(_, "eqAny argument", pos))
@@ -872,12 +879,15 @@ trait Implicits { self: Typer =>
872879
}
873880

874881
/** Find a unique best implicit reference */
875-
def bestImplicit: SearchResult = {
876-
searchImplicits(ctx.implicits.eligible(wildProto), contextual = true) match {
882+
def bestImplicit(contextual: Boolean): SearchResult = {
883+
val eligible =
884+
if (contextual) ctx.implicits.eligible(wildProto)
885+
else implicitScope(wildProto).eligible
886+
searchImplicits(eligible, contextual) match {
877887
case result: SearchSuccess => result
878888
case result: AmbiguousImplicits => result
879889
case result: SearchFailure =>
880-
searchImplicits(implicitScope(wildProto).eligible, contextual = false)
890+
if (contextual) bestImplicit(contextual = false) else result
881891
}
882892
}
883893

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1996,7 +1996,8 @@ class Typer extends Namer with TypeAssigner with Applications with Implicits wit
19961996
case _: RefTree | _: Literal
19971997
if !isVarPattern(tree) &&
19981998
!(tree.tpe <:< pt)(ctx.addMode(Mode.GADTflexible)) =>
1999-
checkCanEqual(pt, wtp, tree.pos)(ctx.retractMode(Mode.Pattern))
1999+
val tp1 :: tp2 :: Nil = harmonizeTypes(pt :: wtp :: Nil)
2000+
checkCanEqual(tp1, tp2, tree.pos)(ctx.retractMode(Mode.Pattern))
20002001
case _ =>
20012002
}
20022003
tree
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
---
2+
layout: doc-page
3+
title: "Automatic Tupling of Function Parameters"
4+
---
5+
6+
Say you have a list of pairs
7+
8+
val xs: List[(Int, Int)]
9+
10+
and you want to map `xs` to a list of `Int`s so that eich pair of numbers is mapped to
11+
their sum. Previously, the best way to do this was with a pattern-matching decomposition:
12+
13+
xs map {
14+
case (x, y) => x + y
15+
}
16+
17+
While correct, this is also inconvenient. Dotty now also allows:
18+
19+
xs.map {
20+
(x, y) => x + y
21+
}
22+
23+
or, equivalently:
24+
25+
xs.map(_ + _)
26+
27+
Generally, a function value with `n > 1` parameters is converted to a
28+
pattern-matching closure using `case` if the expected type is a unary
29+
function type of the form `((T_1, ..., T_n)) => U`.
30+
31+
32+
33+

docs/docs/reference/desugarEnums.md

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ comma separated simple cases into a sequence of cases.
8989

9090
case C <params> ...
9191

92-
expands analogous to a case class:
92+
expands analogous to a final case class:
9393

9494
final case class C <params> ...
9595

@@ -138,6 +138,15 @@ comma separated simple cases into a sequence of cases.
138138
Any modifiers or annotations on the original case extend to all expanded
139139
cases.
140140

141+
## Equality
142+
143+
An `enum` type contains a `scala.Eq` instance that restricts values of the `enum` type to
144+
be compared only to other values of the same enum type. Furtermore, generic
145+
`enum` types are comparable only if their type arguments are. For instance the
146+
`Option` enum type will get the following definition in its companion object:
147+
148+
implicit def eqOption[T, U](implicit ev1: Eq[T, U]): Eq[Option[T], Option[U]] = Eq
149+
141150
## Translation of Enumerations
142151

143152
Non-generic enum classes `E` that define one or more singleton cases
Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
---
2+
layout: doc-page
3+
title: "Implicit Function Types"
4+
---
5+
6+
An implicit funciton type describes functions with implicit parameters. Example:
7+
8+
type Contextual[T] = implicit Context => T
9+
10+
A value of implicit function type is applied to implicit arguments, in
11+
the same way a method with implicit parameters is applied. For instance:
12+
13+
implicit ctx: Context = ...
14+
15+
def f(x: Int): Contextual[Int] = ...
16+
17+
f(2) // is expanded to f(2)(ctx)
18+
19+
Conversely, if the expected type of an expression `E` is an implicit
20+
function type `implicit (T_1, ..., T_n) => U` and `E` is not already an
21+
implicit function value, `E` is converted to an implicit function value
22+
by rewriting to
23+
24+
implicit (x_1: T1, ..., x_n: Tn) => E
25+
26+
where the names `x_1`, ..., `x_n` are arbitrary. For example, continuing
27+
with the previous definitions,
28+
29+
def g(arg: Contextual[Int]) = ...
30+
31+
g(22) // is expanded to g { implicit ctx => 22 }
32+
33+
g(f(2)) // is expanded to g { implicit ctx => f(2)(ctx) }
34+
35+
g(implicit ctx => f(22)(ctx)) // is left as it is
36+
37+
Implicit function types have considerable expressive power. For
38+
instance, here is how they can support the "builder pattern", where
39+
the aim is to construct tables like this:
40+
41+
table {
42+
row {
43+
cell("top left")
44+
cell("top right")
45+
}
46+
row {
47+
cell("botttom left")
48+
cell("bottom right")
49+
}
50+
}
51+
52+
The idea is to define classes for `Table` and `Row` that allow
53+
addition of elements via `add`:
54+
55+
class Table {
56+
val rows = new ArrayBuffer[Row]
57+
def add(r: Row): Unit = rows += r
58+
override def toString = rows.mkString("Table(", ", ", ")")
59+
}
60+
61+
class Row {
62+
val cells = new ArrayBuffer[Cell]
63+
def add(c: Cell): Unit = cells += c
64+
override def toString = cells.mkString("Row(", ", ", ")")
65+
}
66+
67+
case class Cell(elem: String)
68+
69+
Then, the `table`, `row` and `cell` constructor methods can be defined
70+
in terms of implicit function types to avoid the plumbing boilerplate
71+
that would otherwise be necessary.
72+
73+
def table(init: implicit Table => Unit) = {
74+
implicit val t = new Table
75+
init
76+
t
77+
}
78+
79+
def row(init: implicit Row => Unit)(implicit t: Table) = {
80+
implicit val r = new Row
81+
init
82+
t.add(r)
83+
}
84+
85+
def cell(str: String)(implicit r: Row) =
86+
r.add(new Cell(str))
87+
88+
With that setup, the table construction code above compiles and expands to:
89+
90+
table { implicit $t: Table =>
91+
row { implicit $r: Row =>
92+
cell("top left")($r)
93+
cell("top right")($r)
94+
}($t)
95+
row { implicit $r: Row =>
96+
cell("botttom left")($r)
97+
cell("bottom right")($r)
98+
}($t)
99+
}

0 commit comments

Comments
 (0)