Skip to content

Commit 4c41fc4

Browse files
committed
Reuse ConstraintHandling for GADTMap
1 parent 04503d9 commit 4c41fc4

31 files changed

+930
-38
lines changed

compiler/src/dotty/tools/dotc/config/Printers.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ object Printers {
2020
val dottydoc: Printer = noPrinter
2121
val exhaustivity: Printer = noPrinter
2222
val gadts: Printer = noPrinter
23+
val gadtsConstr: Printer = noPrinter
2324
val hk: Printer = noPrinter
2425
val implicits: Printer = noPrinter
2526
val implicitsDetailed: Printer = noPrinter

compiler/src/dotty/tools/dotc/core/ConstraintHandling.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,13 @@ trait ConstraintHandling {
2323
def constr_println(msg: => String): Unit = constr.println(msg)
2424
def typr_println(msg: => String): Unit = typr.println(msg)
2525

26-
implicit val ctx: Context
26+
implicit def ctx: Context
2727

2828
protected def isSubType(tp1: Type, tp2: Type): Boolean
2929
protected def isSameType(tp1: Type, tp2: Type): Boolean
3030

31-
val state: TyperState
32-
import state.constraint
31+
protected def constraint: Constraint
32+
protected def constraint_=(c: Constraint): Unit
3333

3434
private[this] var addConstraintInvocations = 0
3535

compiler/src/dotty/tools/dotc/core/Contexts.scala

Lines changed: 195 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -480,7 +480,7 @@ object Contexts {
480480
def setTyper(typer: Typer): this.type = { this.scope = typer.scope; setTypeAssigner(typer) }
481481
def setImportInfo(importInfo: ImportInfo): this.type = { this.importInfo = importInfo; this }
482482
def setGadt(gadt: GADTMap): this.type = { this.gadt = gadt; this }
483-
def setFreshGADTBounds: this.type = setGadt(new GADTMap(gadt.bounds))
483+
def setFreshGADTBounds: this.type = setGadt(gadt.fresh)
484484
def setSearchHistory(searchHistory: SearchHistory): this.type = { this.searchHistory = searchHistory; this }
485485
def setTypeComparerFn(tcfn: Context => TypeComparer): this.type = { this.typeComparer = tcfn(this); this }
486486
private def setMoreProperties(moreProperties: Map[Key[Any], Any]): this.type = { this.moreProperties = moreProperties; this }
@@ -708,14 +708,201 @@ object Contexts {
708708
else assert(thread == Thread.currentThread(), "illegal multithreaded access to ContextBase")
709709
}
710710

711-
class GADTMap(initBounds: SimpleIdentityMap[Symbol, TypeBounds]) {
712-
private[this] var myBounds = initBounds
713-
def setBounds(sym: Symbol, b: TypeBounds): Unit =
714-
myBounds = myBounds.updated(sym, b)
715-
def bounds: SimpleIdentityMap[Symbol, TypeBounds] = myBounds
711+
sealed abstract class GADTMap {
712+
def addEmptyBounds(sym: Symbol)(implicit ctx: Context): Unit
713+
def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(implicit ctx: Context): Boolean
714+
def bounds(sym: Symbol)(implicit ctx: Context): TypeBounds
715+
def contains(sym: Symbol)(implicit ctx: Context): Boolean
716+
def debugBoundsDescription(implicit ctx: Context): String
717+
def fresh: GADTMap
716718
}
717719

718-
@sharable object EmptyGADTMap extends GADTMap(SimpleIdentityMap.Empty) {
719-
override def setBounds(sym: Symbol, b: TypeBounds): Unit = unsupported("EmptyGADTMap.setBounds")
720+
final class SmartGADTMap private (
721+
private[this] var myConstraint: Constraint,
722+
private[this] var mapping: SimpleIdentityMap[Symbol, TypeVar],
723+
private[this] var reverseMapping: SimpleIdentityMap[TypeParamRef, Symbol]
724+
) extends GADTMap with ConstraintHandling {
725+
import dotty.tools.dotc.config.Printers.{gadts, gadtsConstr}
726+
727+
def this() = this(
728+
myConstraint = new OrderingConstraint(SimpleIdentityMap.Empty, SimpleIdentityMap.Empty, SimpleIdentityMap.Empty),
729+
mapping = SimpleIdentityMap.Empty,
730+
reverseMapping = SimpleIdentityMap.Empty
731+
)
732+
733+
// TODO: clean up this dirty kludge
734+
private[this] var myCtx: Context = null
735+
implicit override def ctx = myCtx
736+
@forceInline private[this] final def inCtx[T](_ctx: Context)(op: => T) = {
737+
val savedCtx = myCtx
738+
myCtx = _ctx
739+
try op finally myCtx = savedCtx
740+
}
741+
742+
override protected def constraint = myConstraint
743+
override protected def constraint_=(c: Constraint) = myConstraint = c
744+
745+
override def isSubType(tp1: Type, tp2: Type): Boolean = ctx.typeComparer.isSubType(tp1, tp2)
746+
override def isSameType(tp1: Type, tp2: Type): Boolean = ctx.typeComparer.isSameType(tp1, tp2)
747+
748+
749+
override def addEmptyBounds(sym: Symbol)(implicit ctx: Context): Unit = tvar(sym)
750+
751+
override def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(implicit ctx: Context): Boolean = inCtx(ctx) {
752+
@annotation.tailrec def stripInst(tp: Type): Type = tp match {
753+
case tv: TypeVar =>
754+
val inst = instType(tv)
755+
if (inst.exists) stripInst(inst) else tv
756+
case _ => tp
757+
}
758+
759+
def cautiousSubtype(tp1: Type, tp2: Type, isSubtype: Boolean): Boolean = {
760+
val externalizedTp1 = removeTypeVars(tp1)
761+
val externalizedTp2 = removeTypeVars(tp2)
762+
763+
def descr = {
764+
def op = s"frozen_${if (isSubtype) "<:<" else ">:>"}"
765+
i"$tp1 $op $tp2\n\t$externalizedTp1 $op $externalizedTp2"
766+
}
767+
// gadts.println(descr)
768+
769+
val res =
770+
// TypeComparer.explain[Boolean](gadts.println) { implicit ctx =>
771+
if (isSubtype) externalizedTp1 frozen_<:< externalizedTp2
772+
else externalizedTp2 frozen_<:< externalizedTp1
773+
// }
774+
775+
gadts.println(i"$descr = $res")
776+
res
777+
}
778+
779+
def unify(tv: TypeVar, tp: Type): Unit = {
780+
gadts.println(i"manually unifying $tv with $tp")
781+
constraint = constraint.updateEntry(tv.origin, tp)
782+
}
783+
784+
val symTvar: TypeVar = stripInst(tvar(sym)) match {
785+
case tv: TypeVar => tv
786+
case inst =>
787+
gadts.println(i"instantiated: $sym -> $inst")
788+
// this is wrong in general, but "correct" due to a subtype check in TypeComparer#narrowGadtBounds
789+
return true
790+
}
791+
792+
val internalizedBound = insertTypeVars(bound)
793+
val res = stripInst(internalizedBound) match {
794+
case boundTvar: TypeVar =>
795+
if (boundTvar eq symTvar) true
796+
else if (isUpper) addLess(symTvar.origin, boundTvar.origin)
797+
else addLess(boundTvar.origin, symTvar.origin)
798+
case bound =>
799+
if (cautiousSubtype(symTvar, bound, isSubtype = !isUpper)) { unify(symTvar, bound); true }
800+
else if (isUpper) addUpperBound(symTvar.origin, bound)
801+
else addLowerBound(symTvar.origin, bound)
802+
}
803+
804+
gadts.println {
805+
val descr = if (isUpper) "upper" else "lower"
806+
val op = if (isUpper) "<:" else ">:"
807+
i"adding $descr bound $sym $op $bound = $res\t( $symTvar $op $internalizedBound )"
808+
}
809+
res
810+
}
811+
812+
override def bounds(sym: Symbol)(implicit ctx: Context): TypeBounds = inCtx(ctx) {
813+
mapping(sym) match {
814+
case null => null
815+
case tv =>
816+
val tb = constraint.fullBounds(tv.origin)
817+
val res = removeTypeVars(tb).asInstanceOf[TypeBounds]
818+
// gadts.println(i"gadt bounds $sym: $res")
819+
res
820+
}
821+
}
822+
823+
override def contains(sym: Symbol)(implicit ctx: Context): Boolean = mapping(sym) ne null
824+
825+
override def fresh: GADTMap = new SmartGADTMap(
826+
myConstraint,
827+
mapping,
828+
reverseMapping
829+
)
830+
831+
// ---- Private ----------------------------------------------------------
832+
833+
private[this] def tvar(sym: Symbol)(implicit ctx: Context): TypeVar = {
834+
mapping(sym) match {
835+
case tv: TypeVar =>
836+
tv
837+
case null =>
838+
val res = {
839+
import NameKinds.DepParamName
840+
// avoid registering the TypeVar with TyperState / TyperState#constraint
841+
// - we don't want TyperState instantiating these TypeVars
842+
// - we don't want TypeComparer constraining these TypeVars
843+
val poly = PolyType(DepParamName.fresh(sym.name.toTypeName) :: Nil)(
844+
pt => TypeBounds.empty :: Nil,
845+
pt => defn.AnyType)
846+
new TypeVar(poly.paramRefs.head, creatorState = null)
847+
}
848+
gadts.println(i"GADTMap: created tvar $sym -> $res")
849+
constraint = constraint.add(res.origin.binder, res :: Nil)
850+
mapping = mapping.updated(sym, res)
851+
reverseMapping = reverseMapping.updated(res.origin, sym)
852+
res
853+
}
854+
}
855+
856+
private def insertTypeVars(tp: Type, map: TypeMap = null)(implicit ctx: Context) = tp match {
857+
case tp: TypeRef =>
858+
val sym = tp.typeSymbol
859+
if (contains(sym)) tvar(sym) else tp
860+
case _ =>
861+
(if (map != null) map else new TypeVarInsertingMap()).mapOver(tp)
862+
}
863+
private final class TypeVarInsertingMap(implicit ctx: Context) extends TypeMap {
864+
override def apply(tp: Type): Type = insertTypeVars(tp, this)
865+
}
866+
867+
private def removeTypeVars(tp: Type, map: TypeMap = null)(implicit ctx: Context) = tp match {
868+
case tpr: TypeParamRef =>
869+
reverseMapping(tpr) match {
870+
case null => tpr
871+
case sym => sym.typeRef
872+
}
873+
case tv: TypeVar =>
874+
reverseMapping(tv.origin) match {
875+
case null => tv
876+
case sym => sym.typeRef
877+
}
878+
case _ =>
879+
(if (map != null) map else new TypeVarRemovingMap()).mapOver(tp)
880+
}
881+
private final class TypeVarRemovingMap(implicit ctx: Context) extends TypeMap {
882+
override def apply(tp: Type): Type = removeTypeVars(tp, this)
883+
}
884+
885+
// ---- Debug ------------------------------------------------------------
886+
887+
override def constr_println(msg: => String): Unit = gadtsConstr.println(msg)
888+
889+
override def debugBoundsDescription(implicit ctx: Context): String = {
890+
val sb = new mutable.StringBuilder
891+
sb ++= constraint.show
892+
sb += '\n'
893+
mapping.foreachBinding { case (sym, _) =>
894+
sb ++= i"$sym: ${bounds(sym)}\n"
895+
}
896+
sb.result
897+
}
898+
}
899+
900+
@sharable object EmptyGADTMap extends GADTMap {
901+
override def addEmptyBounds(sym: Symbol)(implicit ctx: Context): Unit = unsupported("EmptyGADTMap.addEmptyBounds")
902+
override def addBound(sym: Symbol, bound: Type, isUpper: Boolean)(implicit ctx: Context): Boolean = unsupported("EmptyGADTMap.addBound")
903+
override def bounds(sym: Symbol)(implicit ctx: Context): TypeBounds = null
904+
override def contains(sym: Symbol)(implicit ctx: Context) = false
905+
override def debugBoundsDescription(implicit ctx: Context): String = "EmptyGADTMap"
906+
override def fresh = new SmartGADTMap
720907
}
721908
}

compiler/src/dotty/tools/dotc/core/OrderingConstraint.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -325,8 +325,8 @@ class OrderingConstraint(private val boundsMap: ParamBounds,
325325
private def order(current: This, param1: TypeParamRef, param2: TypeParamRef)(implicit ctx: Context): This =
326326
if (param1 == param2 || current.isLess(param1, param2)) this
327327
else {
328-
assert(contains(param1))
329-
assert(contains(param2))
328+
assert(contains(param1), i"$param1")
329+
assert(contains(param2), i"$param2")
330330
val newUpper = param2 :: exclusiveUpper(param2, param1)
331331
val newLower = param1 :: exclusiveLower(param1, param2)
332332
val current1 = (current /: newLower)(upperLens.map(this, _, _, newUpper ::: _))

compiler/src/dotty/tools/dotc/core/Symbols.scala

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -214,7 +214,11 @@ trait Symbols { this: Context =>
214214
*/
215215
def newPatternBoundSymbol(name: Name, info: Type, pos: Position): Symbol = {
216216
val sym = newSymbol(owner, name, Case, info, coord = pos)
217-
if (name.isTypeName) gadt.setBounds(sym, info.bounds)
217+
if (name.isTypeName) {
218+
val bounds = info.bounds
219+
gadt.addBound(sym, bounds.lo, isUpper = false)
220+
gadt.addBound(sym, bounds.hi, isUpper = true)
221+
}
218222
sym
219223
}
220224

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

Lines changed: 51 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,9 @@ class TypeComparer(initctx: Context) extends ConstraintHandling {
2323
import TypeComparer._
2424
implicit val ctx: Context = initctx
2525

26-
val state: TyperState = ctx.typerState
27-
import state.constraint
26+
val state = ctx.typerState
27+
def constraint: Constraint = state.constraint
28+
def constraint_=(c: Constraint): Unit = state.constraint = c
2829

2930
private[this] var pendingSubTypes: mutable.Set[(Type, Type)] = null
3031
private[this] var recCount = 0
@@ -105,8 +106,9 @@ class TypeComparer(initctx: Context) extends ConstraintHandling {
105106
true
106107
}
107108

108-
protected def gadtBounds(sym: Symbol)(implicit ctx: Context): TypeBounds = ctx.gadt.bounds(sym)
109-
protected def gadtSetBounds(sym: Symbol, b: TypeBounds): Unit = ctx.gadt.setBounds(sym, b)
109+
protected def gadtBounds(sym: Symbol)(implicit ctx: Context) = ctx.gadt.bounds(sym)
110+
protected def gadtAddLowerBound(sym: Symbol, b: Type): Boolean = ctx.gadt.addBound(sym, b, isUpper = false)
111+
protected def gadtAddUpperBound(sym: Symbol, b: Type): Boolean = ctx.gadt.addBound(sym, b, isUpper = true)
110112

111113
protected def typeVarInstance(tvar: TypeVar)(implicit ctx: Context): Type = tvar.underlying
112114

@@ -136,7 +138,7 @@ class TypeComparer(initctx: Context) extends ConstraintHandling {
136138
finally this.approx = saved
137139
}
138140

139-
protected def isSubType(tp1: Type, tp2: Type): Boolean = isSubType(tp1, tp2, NoApprox)
141+
def isSubType(tp1: Type, tp2: Type): Boolean = isSubType(tp1, tp2, NoApprox)
140142

141143
protected def recur(tp1: Type, tp2: Type): Boolean = trace(s"isSubType ${traceInfo(tp1, tp2)} $approx", subtyping) {
142144

@@ -738,9 +740,28 @@ class TypeComparer(initctx: Context) extends ConstraintHandling {
738740
isSubArgs(args1, args2, tp1, tparams)
739741
case tycon1: TypeRef =>
740742
tycon2.dealiasKeepRefiningAnnots match {
741-
case tycon2: TypeRef if tycon1.symbol == tycon2.symbol =>
743+
case tycon2: TypeRef =>
744+
val tycon1sym = tycon1.symbol
745+
val tycon2sym = tycon2.symbol
746+
747+
var touchedGADTs = false
748+
def gadtBoundsContain(sym: Symbol, tp: Type): Boolean = {
749+
touchedGADTs = true
750+
val b = gadtBounds(sym)
751+
b != null && inFrozenConstraint {
752+
(b.lo =:= tp) && (b.hi =:= tp)
753+
}
754+
}
755+
756+
val res = (
757+
tycon1sym == tycon2sym ||
758+
gadtBoundsContain(tycon1sym, tycon2) ||
759+
gadtBoundsContain(tycon2sym, tycon1)
760+
) &&
742761
isSubType(tycon1.prefix, tycon2.prefix) &&
743762
isSubArgs(args1, args2, tp1, tparams)
763+
if (res && touchedGADTs) GADTused = true
764+
res
744765
case _ =>
745766
false
746767
}
@@ -1217,7 +1238,7 @@ class TypeComparer(initctx: Context) extends ConstraintHandling {
12171238
if (isUpper) TypeBounds(oldBounds.lo, oldBounds.hi & bound)
12181239
else TypeBounds(oldBounds.lo | bound, oldBounds.hi)
12191240
isSubType(newBounds.lo, newBounds.hi) &&
1220-
{ gadtSetBounds(tparam, newBounds); true }
1241+
(if (isUpper) gadtAddUpperBound(tparam, bound) else gadtAddLowerBound(tparam, bound))
12211242
}
12221243
}
12231244
}
@@ -1766,6 +1787,9 @@ class TypeComparer(initctx: Context) extends ConstraintHandling {
17661787
totalCount = 0
17671788
}
17681789
}
1790+
1791+
/** Returns last check's debug mode, if explicitly enabled. */
1792+
def lastTrace(): String = ""
17691793
}
17701794

17711795
object TypeComparer {
@@ -1797,11 +1821,21 @@ object TypeComparer {
17971821

17981822
val NoApprox: ApproxState = new ApproxState(0)
17991823

1824+
def explain[T](say: String => Unit)(op: Context => T)(implicit ctx: Context): T = {
1825+
val (res, explanation) = underlyingExplained(op)
1826+
say(explanation)
1827+
res
1828+
}
1829+
18001830
/** Show trace of comparison operations when performing `op` as result string */
18011831
def explained[T](op: Context => T)(implicit ctx: Context): String = {
1832+
underlyingExplained(op)._2
1833+
}
1834+
1835+
private def underlyingExplained[T](op: Context => T)(implicit ctx: Context): (T, String) = {
18021836
val nestedCtx = ctx.fresh.setTypeComparerFn(new ExplainingTypeComparer(_))
1803-
op(nestedCtx)
1804-
nestedCtx.typeComparer.toString
1837+
val res = op(nestedCtx)
1838+
(res, nestedCtx.typeComparer.lastTrace())
18051839
}
18061840
}
18071841

@@ -1825,9 +1859,14 @@ class TrackingTypeComparer(initctx: Context) extends TypeComparer(initctx) {
18251859
super.gadtBounds(sym)
18261860
}
18271861

1828-
override def gadtSetBounds(sym: Symbol, b: TypeBounds): Unit = {
1862+
override def gadtAddLowerBound(sym: Symbol, b: Type): Boolean = {
1863+
footprint += sym.typeRef
1864+
super.gadtAddLowerBound(sym, b)
1865+
}
1866+
1867+
override def gadtAddUpperBound(sym: Symbol, b: Type): Boolean = {
18291868
footprint += sym.typeRef
1830-
super.gadtSetBounds(sym, b)
1869+
super.gadtAddUpperBound(sym, b)
18311870
}
18321871

18331872
override def typeVarInstance(tvar: TypeVar)(implicit ctx: Context): Type = {
@@ -1928,5 +1967,5 @@ class ExplainingTypeComparer(initctx: Context) extends TypeComparer(initctx) {
19281967

19291968
override def copyIn(ctx: Context): ExplainingTypeComparer = new ExplainingTypeComparer(ctx)
19301969

1931-
override def toString: String = "Subtype trace:" + { try b.toString finally b.clear() }
1970+
override def lastTrace(): String = "Subtype trace:" + { try b.toString finally b.clear() }
19321971
}

0 commit comments

Comments
 (0)