Skip to content

Commit 649a9e4

Browse files
committed
Move Constraint#fullBounds to ConstraintHandler
1 parent 7008144 commit 649a9e4

13 files changed

+149
-50
lines changed

compiler/src/dotty/tools/dotc/core/Constraint.scala

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,12 @@ abstract class Constraint extends Showable {
4545
/** The parameters that are known to be greater wrt <: than `param` */
4646
def upper(param: TypeParamRef): List[TypeParamRef]
4747

48+
/** `lower`, except that `minLower.forall(tpr => !minLower.exists(_ <:< tpr))` */
49+
def minLower(param: TypeParamRef): List[TypeParamRef]
50+
51+
/** `upper`, except that `minUpper.forall(tpr => !minUpper.exists(tpr <:< _))` */
52+
def minUpper(param: TypeParamRef): List[TypeParamRef]
53+
4854
/** lower(param) \ lower(butNot) */
4955
def exclusiveLower(param: TypeParamRef, butNot: TypeParamRef): List[TypeParamRef]
5056

@@ -58,15 +64,6 @@ abstract class Constraint extends Showable {
5864
*/
5965
def nonParamBounds(param: TypeParamRef)(implicit ctx: Context): TypeBounds
6066

61-
/** The lower bound of `param` including all known-to-be-smaller parameters */
62-
def fullLowerBound(param: TypeParamRef)(implicit ctx: Context): Type
63-
64-
/** The upper bound of `param` including all known-to-be-greater parameters */
65-
def fullUpperBound(param: TypeParamRef)(implicit ctx: Context): Type
66-
67-
/** The bounds of `param` including all known-to-be-smaller and -greater parameters */
68-
def fullBounds(param: TypeParamRef)(implicit ctx: Context): TypeBounds
69-
7067
/** A new constraint which is derived from this constraint by adding
7168
* entries for all type parameters of `poly`.
7269
* @param tvars A list of type variables associated with the params,

compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,13 @@ package dotty.tools
22
package dotc
33
package core
44

5-
import Types._, Contexts._, Symbols._
5+
import Types._
6+
import Contexts._
7+
import Symbols._
68
import Decorators._
79
import config.Config
810
import config.Printers.{constr, typr}
11+
import dotty.tools.dotc.reporting.trace
912

1013
/** Methods for adding constraints and solving them.
1114
*
@@ -31,6 +34,8 @@ trait ConstraintHandling[AbstractContext] {
3134
protected def constraint: Constraint
3235
protected def constraint_=(c: Constraint): Unit
3336

37+
protected def externalize(param: TypeParamRef)(implicit ctx: Context): Type
38+
3439
private[this] var addConstraintInvocations = 0
3540

3641
/** If the constraint is frozen we cannot add new bounds to the constraint. */
@@ -66,6 +71,30 @@ trait ConstraintHandling[AbstractContext] {
6671
case tp => tp
6772
}
6873

74+
def nonParamBounds(param: TypeParamRef)(implicit ctx: Context): TypeBounds =
75+
constraint.nonParamBounds(param) match {
76+
case TypeAlias(tpr: TypeParamRef) => TypeAlias(externalize(tpr))
77+
case tb => tb
78+
}
79+
80+
def fullLowerBound(param: TypeParamRef)(implicit ctx: Context): Type =
81+
(nonParamBounds(param).lo /: constraint.minLower(param)) {
82+
(t, u) => t | externalize(u)
83+
}
84+
85+
def fullUpperBound(param: TypeParamRef)(implicit ctx: Context): Type =
86+
(nonParamBounds(param).hi /: constraint.minUpper(param)) {
87+
(t, u) => t & externalize(u)
88+
}
89+
90+
/** Full bounds of `param`, including other lower/upper params.
91+
*
92+
* Note that underlying operations perform subtype checks - for this reason, recursing on `fullBounds`
93+
* of some param when comparing types might lead to infinite recursion. Consider `bounds` instead.
94+
*/
95+
def fullBounds(param: TypeParamRef)(implicit ctx: Context): TypeBounds =
96+
nonParamBounds(param).derivedTypeBounds(fullLowerBound(param), fullUpperBound(param))
97+
6998
protected def addOneBound(param: TypeParamRef, bound: Type, isUpper: Boolean)(implicit actx: AbstractContext): Boolean =
7099
!constraint.contains(param) || {
71100
def occursIn(bound: Type): Boolean = {
@@ -261,7 +290,7 @@ trait ConstraintHandling[AbstractContext] {
261290
}
262291
constraint.entry(param) match {
263292
case _: TypeBounds =>
264-
val bound = if (fromBelow) constraint.fullLowerBound(param) else constraint.fullUpperBound(param)
293+
val bound = if (fromBelow) fullLowerBound(param) else fullUpperBound(param)
265294
val inst = avoidParam(bound)
266295
typr_println(s"approx ${param.show}, from below = $fromBelow, bound = ${bound.show}, inst = ${inst.show}")
267296
inst
@@ -312,7 +341,7 @@ trait ConstraintHandling[AbstractContext] {
312341
*/
313342
def instanceType(param: TypeParamRef, fromBelow: Boolean)(implicit actx: AbstractContext): Type = {
314343
val inst = approximation(param, fromBelow).simplified
315-
if (fromBelow) widenInferred(inst, constraint.fullUpperBound(param)) else inst
344+
if (fromBelow) widenInferred(inst, fullUpperBound(param)) else inst
316345
}
317346

318347
/** Constraint `c1` subsumes constraint `c2`, if under `c2` as constraint we have

compiler/src/dotty/tools/dotc/core/Contexts.scala

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -765,7 +765,15 @@ object Contexts {
765765
sealed abstract class GADTMap {
766766
def addEmptyBounds(sym: Symbol)(implicit ctx: Context): Unit
767767
def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(implicit ctx: Context): Boolean
768+
def isLess(sym1: Symbol, sym2: Symbol)(implicit ctx: Context): Boolean
768769
def bounds(sym: Symbol)(implicit ctx: Context): TypeBounds
770+
771+
/** Full bounds of `sym`, including TypeRefs to other lower/upper symbols.
772+
*
773+
* Note that underlying operations perform subtype checks - for this reason, recursing on `fullBounds`
774+
* of some symbol when comparing types might lead to infinite recursion. Consider `bounds` instead.
775+
*/
776+
def fullBounds(sym: Symbol)(implicit ctx: Context): TypeBounds
769777
def contains(sym: Symbol)(implicit ctx: Context): Boolean
770778
def approximation(sym: Symbol, fromBelow: Boolean)(implicit ctx: Context): Type
771779
def debugBoundsDescription(implicit ctx: Context): String
@@ -794,6 +802,12 @@ object Contexts {
794802
override protected def constraint = myConstraint
795803
override protected def constraint_=(c: Constraint) = myConstraint = c
796804

805+
override protected def externalize(param: TypeParamRef)(implicit ctx: Context): Type =
806+
reverseMapping(param) match {
807+
case sym: Symbol => sym.typeRef
808+
case null => param
809+
}
810+
797811
override def isSubType(tp1: Type, tp2: Type)(implicit ctx: Context): Boolean = ctx.typeComparer.isSubType(tp1, tp2)
798812
override def isSameType(tp1: Type, tp2: Type)(implicit ctx: Context): Boolean = ctx.typeComparer.isSameType(tp1, tp2)
799813

@@ -853,12 +867,21 @@ object Contexts {
853867
}, gadts)
854868
} finally boundAdditionInProgress = false
855869

870+
override def isLess(sym1: Symbol, sym2: Symbol)(implicit ctx: Context): Boolean =
871+
constraint.isLess(tvar(sym1).origin, tvar(sym2).origin)
872+
873+
override def fullBounds(sym: Symbol)(implicit ctx: Context): TypeBounds =
874+
mapping(sym) match {
875+
case null => null
876+
case tv => removeTypeVars(fullBounds(tv.origin)).asInstanceOf[TypeBounds]
877+
}
878+
856879
override def bounds(sym: Symbol)(implicit ctx: Context): TypeBounds = {
857880
mapping(sym) match {
858881
case null => null
859882
case tv =>
860883
def retrieveBounds: TypeBounds = {
861-
val tb = constraint.fullBounds(tv.origin)
884+
val tb = bounds(tv.origin)
862885
removeTypeVars(tb).asInstanceOf[TypeBounds]
863886
}
864887
(
@@ -870,10 +893,7 @@ object Contexts {
870893
boundCache = boundCache.updated(sym, bounds)
871894
bounds
872895
}
873-
).reporting({ res =>
874-
// i"gadt bounds $sym: $res"
875-
""
876-
}, gadts)
896+
)// .reporting({ res => i"gadt bounds $sym: $res" }, gadts)
877897
}
878898
}
879899

@@ -968,7 +988,7 @@ object Contexts {
968988
sb ++= constraint.show
969989
sb += '\n'
970990
mapping.foreachBinding { case (sym, _) =>
971-
sb ++= i"$sym: ${bounds(sym)}\n"
991+
sb ++= i"$sym: ${fullBounds(sym)}\n"
972992
}
973993
sb.result
974994
}
@@ -977,7 +997,9 @@ object Contexts {
977997
@sharable object EmptyGADTMap extends GADTMap {
978998
override def addEmptyBounds(sym: Symbol)(implicit ctx: Context): Unit = unsupported("EmptyGADTMap.addEmptyBounds")
979999
override def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(implicit ctx: Context): Boolean = unsupported("EmptyGADTMap.addBound")
1000+
override def isLess(sym1: Symbol, sym2: Symbol)(implicit ctx: Context): Boolean = unsupported("EmptyGADTMap.isLess")
9801001
override def bounds(sym: Symbol)(implicit ctx: Context): TypeBounds = null
1002+
override def fullBounds(sym: Symbol)(implicit ctx: Context): TypeBounds = null
9811003
override def contains(sym: Symbol)(implicit ctx: Context) = false
9821004
override def approximation(sym: Symbol, fromBelow: Boolean)(implicit ctx: Context): Type = unsupported("EmptyGADTMap.approximation")
9831005
override def debugBoundsDescription(implicit ctx: Context): String = "EmptyGADTMap"

compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -197,15 +197,6 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
197197
def nonParamBounds(param: TypeParamRef)(implicit ctx: Context): TypeBounds =
198198
entry(param).bounds
199199

200-
def fullLowerBound(param: TypeParamRef)(implicit ctx: Context): Type =
201-
(nonParamBounds(param).lo /: minLower(param))(_ | _)
202-
203-
def fullUpperBound(param: TypeParamRef)(implicit ctx: Context): Type =
204-
(nonParamBounds(param).hi /: minUpper(param))(_ & _)
205-
206-
def fullBounds(param: TypeParamRef)(implicit ctx: Context): TypeBounds =
207-
nonParamBounds(param).derivedTypeBounds(fullLowerBound(param), fullUpperBound(param))
208-
209200
def typeVarOfParam(param: TypeParamRef): Type = {
210201
val entries = boundsMap(param.binder)
211202
if (entries == null) NoType

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,8 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] {
3232
def constraint: Constraint = state.constraint
3333
def constraint_=(c: Constraint): Unit = state.constraint = c
3434

35+
override protected def externalize(param: TypeParamRef)(implicit ctx: Context): Type = param
36+
3537
private[this] var pendingSubTypes: mutable.Set[(Type, Type)] = null
3638
private[this] var recCount = 0
3739
private[this] var monitored = false
@@ -403,6 +405,16 @@ class TypeComparer(initctx: Context) extends ConstraintHandling[AbsentContext] {
403405
val gbounds2 = gadtBounds(tp2.symbol)
404406
(gbounds2 != null) &&
405407
(isSubTypeWhenFrozen(tp1, gbounds2.lo) ||
408+
(tp1 match {
409+
case tp1: NamedType if ctx.gadt.contains(tp1.symbol) =>
410+
// Note: since we approximate constrained types only with their non-param bounds,
411+
// we need to manually handle the case when we're comparing two constrained types,
412+
// one of which is constrained to be a subtype of another.
413+
// We do not need similar code in fourthTry, since we only need to care about
414+
// comparing two constrained types, and that case will be handled here first.
415+
ctx.gadt.isLess(tp1.symbol, tp2.symbol) && GADTusage(tp1.symbol) && GADTusage(tp2.symbol)
416+
case _ => false
417+
}) ||
406418
narrowGADTBounds(tp2, tp1, approx, isUpper = false)) &&
407419
GADTusage(tp2.symbol)
408420
}

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3790,10 +3790,10 @@ object Types {
37903790
def contextInfo(tp: Type): Type = tp match {
37913791
case tp: TypeParamRef =>
37923792
val constraint = ctx.typerState.constraint
3793-
if (constraint.entry(tp).exists) constraint.fullBounds(tp)
3793+
if (constraint.entry(tp).exists) ctx.typeComparer.fullBounds(tp)
37943794
else NoType
37953795
case tp: TypeRef =>
3796-
val bounds = ctx.gadt.bounds(tp.symbol)
3796+
val bounds = ctx.gadt.fullBounds(tp.symbol)
37973797
if (bounds == null) NoType else bounds
37983798
case tp: TypeVar =>
37993799
tp.underlying

compiler/src/dotty/tools/dotc/printing/Formatting.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ object Formatting {
169169
case sym: Symbol =>
170170
val info =
171171
if (ctx.gadt.contains(sym))
172-
sym.info & ctx.gadt.bounds(sym)
172+
sym.info & ctx.gadt.fullBounds(sym)
173173
else
174174
sym.info
175175
s"is a ${ctx.printer.kindString(sym)}${sym.showExtendedLocation}${addendum("bounds", info)}"
@@ -189,7 +189,7 @@ object Formatting {
189189
case param: TermParamRef => false
190190
case skolem: SkolemType => true
191191
case sym: Symbol =>
192-
ctx.gadt.contains(sym) && ctx.gadt.bounds(sym) != TypeBounds.empty
192+
ctx.gadt.contains(sym) && ctx.gadt.fullBounds(sym) != TypeBounds.empty
193193
case _ =>
194194
assert(false, "unreachable")
195195
false

compiler/src/dotty/tools/dotc/printing/PlainPrinter.scala

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,10 @@ class PlainPrinter(_ctx: Context) extends Printer {
206206
else {
207207
val constr = ctx.typerState.constraint
208208
val bounds =
209-
if (constr.contains(tp)) constr.fullBounds(tp.origin)(ctx.addMode(Mode.Printing))
209+
if (constr.contains(tp)) {
210+
val ctx0 = ctx.addMode(Mode.Printing)
211+
ctx0.typeComparer.fullBounds(tp.origin)(ctx0)
212+
}
210213
else TypeBounds.empty
211214
if (bounds.isTypeAlias) toText(bounds.lo) ~ (Str("^") provided ctx.settings.YprintDebug.value)
212215
else if (ctx.settings.YshowVarBounds.value) "(" ~ toText(tp.origin) ~ "?" ~ toText(bounds) ~ ")"

compiler/src/dotty/tools/dotc/typer/ErrorReporting.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,8 +128,8 @@ object ErrorReporting {
128128
case tp: TypeParamRef =>
129129
constraint.entry(tp) match {
130130
case bounds: TypeBounds =>
131-
if (variance < 0) apply(constraint.fullUpperBound(tp))
132-
else if (variance > 0) apply(constraint.fullLowerBound(tp))
131+
if (variance < 0) apply(ctx.typeComparer.fullUpperBound(tp))
132+
else if (variance > 0) apply(ctx.typeComparer.fullLowerBound(tp))
133133
else tp
134134
case NoType => tp
135135
case instType => apply(instType)

compiler/src/dotty/tools/dotc/typer/Implicits.scala

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -394,21 +394,29 @@ object Implicits {
394394
* what was expected
395395
*/
396396
override def clarify(tp: Type)(implicit ctx: Context): Type = {
397-
val map = new TypeMap {
398-
def apply(t: Type): Type = t match {
399-
case t: TypeParamRef =>
400-
constraint.entry(t) match {
401-
case NoType => t
402-
case bounds: TypeBounds => constraint.fullBounds(t)
403-
case t1 => t1
404-
}
405-
case t: TypeVar =>
406-
t.instanceOpt.orElse(apply(t.origin))
407-
case _ =>
408-
mapOver(t)
397+
val ctx0 = ctx
398+
locally {
399+
implicit val ctx = ctx0.fresh.setTyperState {
400+
val ts = ctx0.typerState.fresh()
401+
ts.constraint_=(constraint)(ctx0)
402+
ts
403+
}
404+
val map = new TypeMap {
405+
def apply(t: Type): Type = t match {
406+
case t: TypeParamRef =>
407+
constraint.entry(t) match {
408+
case NoType => t
409+
case bounds: TypeBounds => ctx.typeComparer.fullBounds(t)
410+
case t1 => t1
411+
}
412+
case t: TypeVar =>
413+
t.instanceOpt.orElse(apply(t.origin))
414+
case _ =>
415+
mapOver(t)
416+
}
409417
}
418+
map(tp)
410419
}
411-
map(tp)
412420
}
413421

414422
def explanation(implicit ctx: Context): String =

compiler/src/dotty/tools/dotc/typer/Inferencing.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -260,7 +260,7 @@ object Inferencing {
260260
* 0 if unconstrained, or constraint is from below and above.
261261
*/
262262
private def instDirection(param: TypeParamRef)(implicit ctx: Context): Int = {
263-
val constrained = ctx.typerState.constraint.fullBounds(param)
263+
val constrained = ctx.typeComparer.fullBounds(param)
264264
val original = param.binder.paramInfos(param.paramNum)
265265
val cmp = ctx.typeComparer
266266
val approxBelow =
@@ -295,7 +295,7 @@ object Inferencing {
295295
if (v == 1) tvar.instantiate(fromBelow = false)
296296
else if (v == -1) tvar.instantiate(fromBelow = true)
297297
else {
298-
val bounds = ctx.typerState.constraint.fullBounds(tvar.origin)
298+
val bounds = ctx.typeComparer.fullBounds(tvar.origin)
299299
if (bounds.hi <:< bounds.lo || bounds.hi.classSymbol.is(Final) || fromScala2x)
300300
tvar.instantiate(fromBelow = false)
301301
else {

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1087,7 +1087,7 @@ class Typer extends Namer
10871087
if (ctx.scope.lookup(b.name) == NoSymbol) ctx.enter(sym)
10881088
else ctx.error(new DuplicateBind(b, cdef), b.sourcePos)
10891089
if (!ctx.isAfterTyper) {
1090-
val bounds = ctx.gadt.bounds(sym)
1090+
val bounds = ctx.gadt.fullBounds(sym)
10911091
if (bounds != null) sym.info = bounds
10921092
}
10931093
b

tests/pos/gadt-accumulatable.scala

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
object `gadt-accumulatable` {
2+
sealed abstract class Or[+G,+B] extends Product with Serializable
3+
final case class Good[+G](g: G) extends Or[G,Nothing]
4+
final case class Bad[+B](b: B) extends Or[Nothing,B]
5+
6+
sealed trait Validation[+E] extends Product with Serializable
7+
case object Pass extends Validation[Nothing]
8+
case class Fail[E](error: E) extends Validation[E]
9+
10+
sealed abstract class Every[+T] protected (underlying: Vector[T]) extends /*PartialFunction[Int, T] with*/ Product with Serializable
11+
final case class One[+T](loneElement: T) extends Every[T](Vector(loneElement))
12+
final case class Many[+T](firstElement: T, secondElement: T, otherElements: T*) extends Every[T](firstElement +: secondElement +: Vector(otherElements: _*))
13+
14+
class Accumulatable[G, ERR, EVERY[_]] { }
15+
16+
def convertOrToAccumulatable[G, ERR, EVERY[b] <: Every[b]](accumulatable: G Or EVERY[ERR]): Accumulatable[G, ERR, EVERY] = {
17+
new Accumulatable[G, ERR, EVERY] {
18+
def when[OTHERERR >: ERR](validations: (G => Validation[OTHERERR])*): G Or Every[OTHERERR] = {
19+
accumulatable match {
20+
case Good(g) =>
21+
val results = validations flatMap (_(g) match { case Fail(x) => val z: OTHERERR = x; Seq(x); case Pass => Seq.empty})
22+
results.length match {
23+
case 0 => Good(g)
24+
case 1 => Bad(One(results.head))
25+
case _ =>
26+
val first = results.head
27+
val tail = results.tail
28+
val second = tail.head
29+
val rest = tail.tail
30+
Bad(Many(first, second, rest: _*))
31+
}
32+
case Bad(myBad) => Bad(myBad)
33+
}
34+
}
35+
}
36+
}
37+
}

0 commit comments

Comments
 (0)