Skip to content

Commit 4ec5446

Browse files
committed
fix #3935: widen inferred enum types, precise factory method
1 parent b3f908d commit 4ec5446

File tree

7 files changed

+81
-25
lines changed

7 files changed

+81
-25
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -527,6 +527,12 @@ object desugar {
527527
// a reference to the class type bound by `cdef`, with type parameters coming from the constructor
528528
val classTypeRef = appliedRef(classTycon)
529529

530+
def applyResultTpt =
531+
if isEnumCase then
532+
classTypeRef
533+
else
534+
TypeTree()
535+
530536
// a reference to `enumClass`, with type parameters coming from the case constructor
531537
lazy val enumClassTypeRef =
532538
if (enumClass.typeParams.isEmpty)
@@ -605,7 +611,7 @@ object desugar {
605611
cpy.ValDef(vparam)(rhs = copyDefault(vparam)))
606612
val copyRestParamss = derivedVparamss.tail.nestedMap(vparam =>
607613
cpy.ValDef(vparam)(rhs = EmptyTree))
608-
DefDef(nme.copy, derivedTparams, copyFirstParams :: copyRestParamss, TypeTree(), creatorExpr)
614+
DefDef(nme.copy, derivedTparams, copyFirstParams :: copyRestParamss, applyResultTpt, creatorExpr)
609615
.withMods(Modifiers(Synthetic | constr1.mods.flags & copiedAccessFlags, constr1.mods.privateWithin)) :: Nil
610616
}
611617
}
@@ -656,15 +662,6 @@ object desugar {
656662
// For all other classes, the parent is AnyRef.
657663
val companions =
658664
if (isCaseClass) {
659-
// The return type of the `apply` method, and an (empty or singleton) list
660-
// of widening coercions
661-
val (applyResultTpt, widenDefs) =
662-
if (!isEnumCase)
663-
(TypeTree(), Nil)
664-
else if (parents.isEmpty || enumClass.typeParams.isEmpty)
665-
(enumClassTypeRef, Nil)
666-
else
667-
enumApplyResult(cdef, parents, derivedEnumParams, appliedRef(enumClassRef, derivedEnumParams))
668665

669666
// true if access to the apply method has to be restricted
670667
// i.e. if the case class constructor is either private or qualified private
@@ -695,8 +692,6 @@ object desugar {
695692
then anyRef
696693
else
697694
constrVparamss.foldRight(classTypeRef)((vparams, restpe) => Function(vparams map (_.tpt), restpe))
698-
def widenedCreatorExpr =
699-
widenDefs.foldLeft(creatorExpr)((rhs, meth) => Apply(Ident(meth.name), rhs :: Nil))
700695
val applyMeths =
701696
if (mods.is(Abstract)) Nil
702697
else {
@@ -709,9 +704,8 @@ object desugar {
709704
val appParamss =
710705
derivedVparamss.nestedZipWithConserve(constrVparamss)((ap, cp) =>
711706
ap.withMods(ap.mods | (cp.mods.flags & HasDefault)))
712-
val app = DefDef(nme.apply, derivedTparams, appParamss, applyResultTpt, widenedCreatorExpr)
713-
.withMods(appMods)
714-
app :: widenDefs
707+
DefDef(nme.apply, derivedTparams, appParamss, applyResultTpt, creatorExpr)
708+
.withMods(appMods) :: Nil
715709
}
716710
val unapplyMeth = {
717711
val hasRepeatedParam = constrVparamss.head.exists {
@@ -720,7 +714,7 @@ object desugar {
720714
val methName = if (hasRepeatedParam) nme.unapplySeq else nme.unapply
721715
val unapplyParam = makeSyntheticParameter(tpt = classTypeRef)
722716
val unapplyRHS = if (arity == 0) Literal(Constant(true)) else Ident(unapplyParam.name)
723-
val unapplyResTp = if (arity == 0) Literal(Constant(true)) else TypeTree()
717+
val unapplyResTp = if (arity == 0) Literal(Constant(true)) else applyResultTpt
724718
DefDef(methName, derivedTparams, (unapplyParam :: Nil) :: Nil, unapplyResTp, unapplyRHS)
725719
.withMods(synthetic)
726720
}

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1180,6 +1180,8 @@ object Types {
11801180
def widenSingletons(using Context): Type = dealias match {
11811181
case tp: SingletonType =>
11821182
tp.widen
1183+
case tp: (TypeRef | AppliedType) if tp.typeSymbol.isAllOf(EnumCase) =>
1184+
tp.parents.head
11831185
case tp: OrType =>
11841186
val tp1w = tp.widenSingletons
11851187
if (tp1w eq tp) this else tp1w

compiler/src/dotty/tools/dotc/parsing/Scanners.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1399,8 +1399,8 @@ object Scanners {
13991399

14001400
object IndentWidth {
14011401
private inline val MaxCached = 40
1402-
private val spaces = Array.tabulate(MaxCached + 1)(new Run(' ', _))
1403-
private val tabs = Array.tabulate(MaxCached + 1)(new Run('\t', _))
1402+
private val spaces = Array.tabulate[Run](MaxCached + 1)(new Run(' ', _)) // TODO: remove new after bootstrap
1403+
private val tabs = Array.tabulate[Run](MaxCached + 1)(new Run('\t', _)) // TODO: remove new after bootstrap
14041404

14051405
def Run(ch: Char, n: Int): Run =
14061406
if (n <= MaxCached && ch == ' ') spaces(n)

tests/pos/i3935.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
enum Foo3[T](x: T) {
2+
case Bar[S, T](y: T) extends Foo3[y.type](y)
3+
}
4+
5+
val foo: Foo3.Bar[Nothing, 3] = Foo3.Bar(3)
6+
val bar = foo
7+
8+
def baz[T](f: Foo3[T]): f.type = f
9+
10+
val qux = baz(bar) // existentials are back in Dotty?

tests/run-macros/i8007/Macro_3.scala

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,6 @@ object Eq {
6464
$ordx == $ordy && $elements($ordx).asInstanceOf[Eq[Any]].eqv($x, $y)
6565
}
6666
}
67-
6867
'{
6968
eqSum((x: T, y: T) => ${eqSumBody('x, 'y)})
7069
}
@@ -76,4 +75,4 @@ object Macro3 {
7675
extension [T](x: =>T) inline def === (y: =>T)(using eq: Eq[T]): Boolean = eq.eqv(x, y)
7776

7877
implicit inline def eqGen[T]: Eq[T] = ${ Eq.derived[T] }
79-
}
78+
}

tests/run-macros/i8007/Test_4.scala

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,22 @@ import Macro3.eqGen
66
case class Person(name: String, age: Int)
77

88
enum Opt[+T] {
9-
case Sm(t: T)
9+
case Sm[U](t: U) extends Opt[U]
1010
case Nn
1111
}
1212

13+
enum OptInfer[+T] {
14+
case Sm[+U](t: U) extends OptInfer[U]
15+
case Nn
16+
}
17+
18+
// simulation of Opt using case class hierarchy
19+
sealed abstract class OptCase[+T]
20+
object OptCase {
21+
final case class Sm[T](t: T) extends OptCase[T]
22+
case object Nn extends OptCase[Nothing]
23+
}
24+
1325
@main def Test() = {
1426
import Opt._
1527
import Eq.{given _, _}
@@ -30,15 +42,23 @@ enum Opt[+T] {
3042
println(t4) // false
3143
println
3244

33-
val t5 = Sm(23) === Sm(23)
45+
val t5 = Opt.Sm[Int](23) === Opt.Sm(23) // same behaviour as case class when using apply
3446
println(t5) // true
3547
println
3648

37-
val t6 = Sm(Person("Test", 23)) === Sm(Person("Test", 23))
49+
val t5_2 = OptCase.Sm[Int](23) === OptCase.Sm(23)
50+
println(t5_2) // true
51+
println
52+
53+
val t5_3 = OptInfer.Sm(23) === OptInfer.Sm(23) // covariant `Sm` case means we can avoid explicit type parameter
54+
println(t5_3) // true
55+
println
56+
57+
val t6 = Sm[Person](Person("Test", 23)) === Sm(Person("Test", 23))
3858
println(t6) // true
3959
println
4060

41-
val t7 = Sm(Person("Test", 23)) === Sm(Person("Test", 24))
61+
val t7 = Sm[Person](Person("Test", 23)) === Sm(Person("Test", 24))
4262
println(t7) // false
4363
println
44-
}
64+
}

tests/run/enum-precise.scala

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
enum NonEmptyList[+T]:
2+
case Many[+U](head: U, tail: NonEmptyList[U]) extends NonEmptyList[U]
3+
case One [+U](value: U) extends NonEmptyList[U]
4+
5+
enum Ast:
6+
case Binding(name: String, tpe: String)
7+
case Lambda(args: NonEmptyList[Binding], rhs: Ast) // reference to another case of the enum
8+
case Ident(name: String)
9+
case Apply(fn: Ast, args: NonEmptyList[Ast])
10+
11+
import NonEmptyList._
12+
import Ast._
13+
14+
// This example showcases the widening when inferring enum case types.
15+
// With scala 2 case class hierarchies, if One.apply(1) returns One[Int] and Many.apply(2, One(3)) returns Many[Int]
16+
// then the `foldRight` expression below would complain that Many[Binding] is not One[Binding]. With Scala 3 enums,
17+
// .apply on the companion returns the precise class, but type inference will widen to NonEmptyList[Binding] unless
18+
// the precise class is expected.
19+
def Bindings(arg: (String, String), args: (String, String)*): NonEmptyList[Binding] =
20+
def Bind(arg: (String, String)): Binding =
21+
val (name, tpe) = arg
22+
Binding(name, tpe)
23+
24+
args.foldRight(One[Binding](Bind(arg)))((arg, acc) => Many(Bind(arg), acc))
25+
26+
@main def Test: Unit =
27+
val OneOfOne: One[1] = One[1](1)
28+
val True = Lambda(Bindings("x" -> "T", "y" -> "T"), Ident("x"))
29+
val Const = Lambda(One(Binding("x", "T")), Lambda(One(Binding("y", "U")), Ident("x"))) // precise type is forwarded
30+
31+
assert(OneOfOne.value == 1)

0 commit comments

Comments
 (0)