Skip to content

Commit 8e2aae2

Browse files
committed
Refine widening of enumCases
Widen whenever it is possible to do so while still conforming to expected type.
1 parent c42db70 commit 8e2aae2

File tree

4 files changed

+57
-46
lines changed

4 files changed

+57
-46
lines changed

compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala

Lines changed: 42 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -286,18 +286,53 @@ trait ConstraintHandling {
286286
}
287287
}
288288

289+
/** If `tp` is an intersection such that some operands are super trait instances
290+
* and others are not, replace as many super trait instances as possible with Any
291+
* as long as the result is still a subtype of `bound`. But fall back to the
292+
* original type if the resulting widened type is a supertype of all dropped
293+
* types (since in this case the type was not a true intersection of super traits
294+
* and other types to start with).
295+
*/
296+
def dropSuperTraits(tp: Type, bound: Type)(using Context): Type =
297+
var kept: Set[Type] = Set() // types to keep since otherwise bound would not fit
298+
var dropped: List[Type] = List() // the types dropped so far, last one on top
299+
300+
def dropOneSuperTrait(tp: Type): Type =
301+
val tpd = tp.dealias
302+
if tpd.typeSymbol.isSuperTrait && !tpd.isLambdaSub && !kept.contains(tpd) then
303+
dropped = tpd :: dropped
304+
defn.AnyType
305+
else tpd match
306+
case AndType(tp1, tp2) =>
307+
val tp1w = dropOneSuperTrait(tp1)
308+
if tp1w ne tp1 then tp1w & tp2
309+
else
310+
val tp2w = dropOneSuperTrait(tp2)
311+
if tp2w ne tp2 then tp1 & tp2w
312+
else tpd
313+
case _ =>
314+
tp
315+
316+
def recur(tp: Type): Type =
317+
val tpw = dropOneSuperTrait(tp)
318+
if tpw eq tp then tp
319+
else if tpw <:< bound then recur(tpw)
320+
else
321+
kept += dropped.head
322+
dropped = dropped.tail
323+
recur(tp)
324+
325+
val tpw = recur(tp)
326+
if (tpw eq tp) || dropped.forall(_ frozen_<:< tpw) then tp else tpw
327+
end dropSuperTraits
328+
289329
/** Widen inferred type `inst` with upper `bound`, according to the following rules:
290330
* 1. If `inst` is a singleton type, or a union containing some singleton types,
291331
* widen (all) the singleton type(s), provided the result is a subtype of `bound`
292332
* (i.e. `inst.widenSingletons <:< bound` succeeds with satisfiable constraint)
293333
* 2. If `inst` is a union type, approximate the union type from above by an intersection
294334
* of all common base types, provided the result is a subtype of `bound`.
295-
* 3. If `inst` is an intersection such that some operands are super trait instances
296-
* and others are not, replace as many super trait instances as possible with Any
297-
* as long as the result is still a subtype of `bound`. But fall back to the
298-
* original type if the resulting widened type is a supertype of all dropped
299-
* types (since in this case the type was not a true intersection of super traits
300-
* and other types to start with).
335+
* 3. drop super traits from intersections (see @dropSuperTraits)
301336
*
302337
* Don't do these widenings if `bound` is a subtype of `scala.Singleton`.
303338
* Also, if the result of these widenings is a TypeRef to a module class,
@@ -308,40 +343,6 @@ trait ConstraintHandling {
308343
* as those could leak the annotation to users (see run/inferred-repeated-result).
309344
*/
310345
def widenInferred(inst: Type, bound: Type)(using Context): Type =
311-
312-
def dropSuperTraits(tp: Type): Type =
313-
var kept: Set[Type] = Set() // types to keep since otherwise bound would not fit
314-
var dropped: List[Type] = List() // the types dropped so far, last one on top
315-
316-
def dropOneSuperTrait(tp: Type): Type =
317-
val tpd = tp.dealias
318-
if tpd.typeSymbol.isSuperTrait && !tpd.isLambdaSub && !kept.contains(tpd) then
319-
dropped = tpd :: dropped
320-
defn.AnyType
321-
else tpd match
322-
case AndType(tp1, tp2) =>
323-
val tp1w = dropOneSuperTrait(tp1)
324-
if tp1w ne tp1 then tp1w & tp2
325-
else
326-
val tp2w = dropOneSuperTrait(tp2)
327-
if tp2w ne tp2 then tp1 & tp2w
328-
else tpd
329-
case _ =>
330-
tp
331-
332-
def recur(tp: Type): Type =
333-
val tpw = dropOneSuperTrait(tp)
334-
if tpw eq tp then tp
335-
else if tpw <:< bound then recur(tpw)
336-
else
337-
kept += dropped.head
338-
dropped = dropped.tail
339-
recur(tp)
340-
341-
val tpw = recur(tp)
342-
if (tpw eq tp) || dropped.forall(_ frozen_<:< tpw) then tp else tpw
343-
end dropSuperTraits
344-
345346
def widenOr(tp: Type) =
346347
val tpw = tp.widenUnion
347348
if (tpw ne tp) && (tpw <:< bound) then tpw else tp
@@ -356,7 +357,7 @@ trait ConstraintHandling {
356357

357358
val wideInst =
358359
if isSingleton(bound) then inst
359-
else dropSuperTraits(widenOr(widenSingle(inst)))
360+
else dropSuperTraits(widenOr(widenSingle(inst)), bound)
360361
wideInst match
361362
case wideInst: TypeRef if wideInst.symbol.is(Module) =>
362363
TermRef(wideInst.prefix, wideInst.symbol.sourceModule)

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2625,6 +2625,9 @@ object TypeComparer {
26252625
def widenInferred(inst: Type, bound: Type)(using Context): Type =
26262626
comparing(_.widenInferred(inst, bound))
26272627

2628+
def dropSuperTraits(tp: Type, bound: Type)(using Context): Type =
2629+
comparing(_.dropSuperTraits(tp, bound))
2630+
26282631
def constrainPatternType(pat: Type, scrut: Type)(using Context): Boolean =
26292632
comparing(_.constrainPatternType(pat, scrut))
26302633

compiler/src/dotty/tools/dotc/typer/Applications.scala

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1123,12 +1123,14 @@ trait Applications extends Compatibility {
11231123
def isEnumApply = sym.name == nme.apply && sym.owner.linkedClass.isEnumCase
11241124
if sym.is(Synthetic) && (isEnumApply || isEnumCopy)
11251125
&& tree.tpe.classSymbol.isEnumCase
1126-
&& !pt.isInstanceOf[FunProto]
1127-
&& !pt.classSymbol.isEnumCase
1126+
&& tree.tpe.widen.isValueType
11281127
then
1129-
Typed(tree, TypeTree(tree.tpe.parents.reduceLeft(TypeComparer.andType(_, _))))
1130-
else
1131-
tree
1128+
val widened = TypeComparer.dropSuperTraits(
1129+
tree.tpe.parents.reduceLeft(TypeComparer.andType(_, _)),
1130+
pt)
1131+
if widened <:< pt then Typed(tree, TypeTree(widened))
1132+
else tree
1133+
else tree
11321134

11331135
/** Does `state` contain a "NotAMember" or "MissingIdent" message as
11341136
* first pending error message? That message would be

tests/pos/enum-widen.scala

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,9 @@ object test:
1313
x = None
1414
xc = None
1515

16+
enum Nat:
17+
case Z
18+
case S[N <: Z.type | S[_]](pred: N)
19+
import Nat._
1620

21+
val two = S(S(Z))

0 commit comments

Comments
 (0)