Skip to content

Bypass eligible caches for implicit search under GADT constraints #14072

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Dec 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 27 additions & 19 deletions compiler/src/dotty/tools/dotc/core/GadtConstraint.scala
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ sealed abstract class GadtConstraint extends Showable {
*/
def contains(sym: Symbol)(using Context): Boolean

def isEmpty: Boolean
final def nonEmpty: Boolean = !isEmpty
/** GADT constraint narrows bounds of at least one variable */
def isNarrowing: Boolean

/** See [[ConstraintHandling.approximation]] */
def approximation(sym: Symbol, fromBelow: Boolean)(using Context): Type
Expand All @@ -61,13 +61,15 @@ final class ProperGadtConstraint private(
private var myConstraint: Constraint,
private var mapping: SimpleIdentityMap[Symbol, TypeVar],
private var reverseMapping: SimpleIdentityMap[TypeParamRef, Symbol],
private var wasConstrained: Boolean
) extends GadtConstraint with ConstraintHandling {
import dotty.tools.dotc.config.Printers.{gadts, gadtsConstr}

def this() = this(
myConstraint = new OrderingConstraint(SimpleIdentityMap.empty, SimpleIdentityMap.empty, SimpleIdentityMap.empty),
mapping = SimpleIdentityMap.empty,
reverseMapping = SimpleIdentityMap.empty
reverseMapping = SimpleIdentityMap.empty,
wasConstrained = false
)

/** Exposes ConstraintHandling.subsumes */
Expand Down Expand Up @@ -149,20 +151,24 @@ final class ProperGadtConstraint private(
if (ntTvar ne null) stripInternalTypeVar(ntTvar) else bound
case _ => bound
}
(
internalizedBound match {
case boundTvar: TypeVar =>
if (boundTvar eq symTvar) true
else if (isUpper) addLess(symTvar.origin, boundTvar.origin)
else addLess(boundTvar.origin, symTvar.origin)
case bound =>
addBoundTransitively(symTvar.origin, bound, isUpper)
}
).showing({

val saved = constraint
val result = internalizedBound match
case boundTvar: TypeVar =>
if (boundTvar eq symTvar) true
else if (isUpper) addLess(symTvar.origin, boundTvar.origin)
else addLess(boundTvar.origin, symTvar.origin)
case bound =>
addBoundTransitively(symTvar.origin, bound, isUpper)

gadts.println {
val descr = if (isUpper) "upper" else "lower"
val op = if (isUpper) "<:" else ">:"
i"adding $descr bound $sym $op $bound = $result"
}, gadts)
}

if constraint ne saved then wasConstrained = true
result
}

override def isLess(sym1: Symbol, sym2: Symbol)(using Context): Boolean =
Expand Down Expand Up @@ -193,6 +199,8 @@ final class ProperGadtConstraint private(

override def contains(sym: Symbol)(using Context): Boolean = mapping(sym) ne null

def isNarrowing: Boolean = wasConstrained

override def approximation(sym: Symbol, fromBelow: Boolean)(using Context): Type = {
val res = approximation(tvarOrError(sym).origin, fromBelow = fromBelow)
gadts.println(i"approximating $sym ~> $res")
Expand All @@ -202,19 +210,19 @@ final class ProperGadtConstraint private(
override def fresh: GadtConstraint = new ProperGadtConstraint(
myConstraint,
mapping,
reverseMapping
reverseMapping,
wasConstrained
)

def restore(other: GadtConstraint): Unit = other match {
case other: ProperGadtConstraint =>
this.myConstraint = other.myConstraint
this.mapping = other.mapping
this.reverseMapping = other.reverseMapping
this.wasConstrained = other.wasConstrained
case _ => ;
}

override def isEmpty: Boolean = mapping.size == 0

// ---- Protected/internal -----------------------------------------------

override protected def constraint = myConstraint
Expand Down Expand Up @@ -293,7 +301,7 @@ final class ProperGadtConstraint private(

override def isLess(sym1: Symbol, sym2: Symbol)(using Context): Boolean = unsupported("EmptyGadtConstraint.isLess")

override def isEmpty: Boolean = true
override def isNarrowing: Boolean = false

override def contains(sym: Symbol)(using Context) = false

Expand All @@ -304,7 +312,7 @@ final class ProperGadtConstraint private(

override def fresh = new ProperGadtConstraint
override def restore(other: GadtConstraint): Unit =
if (!other.isEmpty) sys.error("cannot restore a non-empty GADTMap")
assert(!other.isNarrowing, "cannot restore a non-empty GADTMap")
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Was this changed intentionally? I don't think we should allow restoring a non-empty GadtConstraint into an empty one. GadtConstraint is also a scope, and EmptyGadtConstraint doesn't allow constraining anything, so such a call to restore would lead to errors down the road.

With that being said, this is more of a nitpick than anything else, I think we should merge this PR and then I'll tweak GadtConstraint to define EmptyGadtConstraint a bit more gracefully.


override def debugBoundsDescription(using Context): String = "EmptyGadtConstraint"

Expand Down
33 changes: 23 additions & 10 deletions compiler/src/dotty/tools/dotc/typer/Implicits.scala
Original file line number Diff line number Diff line change
Expand Up @@ -330,11 +330,25 @@ object Implicits:
(this eq finalImplicits) || (outerImplicits eq finalImplicits)
}

private def combineEligibles(ownEligible: List[Candidate], outerEligible: List[Candidate]): List[Candidate] =
if ownEligible.isEmpty then outerEligible
else if outerEligible.isEmpty then ownEligible
else
val shadowed = ownEligible.map(_.ref.implicitName).toSet
ownEligible ::: outerEligible.filterConserve(cand => !shadowed.contains(cand.ref.implicitName))

def uncachedEligible(tp: Type)(using Context): List[Candidate] =
Stats.record("uncached eligible")
if monitored then record(s"check uncached eligible refs in irefCtx", refs.length)
val ownEligible = filterMatching(tp)
if isOuterMost then ownEligible
else combineEligibles(ownEligible, outerImplicits.uncachedEligible(tp))

/** The implicit references that are eligible for type `tp`. */
def eligible(tp: Type): List[Candidate] =
if (tp.hash == NotCached)
Stats.record(i"compute eligible not cached ${tp.getClass}")
Stats.record(i"compute eligible not cached")
Stats.record("compute eligible not cached")
computeEligible(tp)
else {
val eligibles = eligibleCache.lookup(tp)
Expand All @@ -354,14 +368,8 @@ object Implicits:
private def computeEligible(tp: Type): List[Candidate] = /*>|>*/ trace(i"computeEligible $tp in $refs%, %", implicitsDetailed) /*<|<*/ {
if (monitored) record(s"check eligible refs in irefCtx", refs.length)
val ownEligible = filterMatching(tp)
if (isOuterMost) ownEligible
else if ownEligible.isEmpty then outerImplicits.eligible(tp)
else
val outerEligible = outerImplicits.eligible(tp)
if outerEligible.isEmpty then ownEligible
else
val shadowed = ownEligible.map(_.ref.implicitName).toSet
ownEligible ::: outerEligible.filterConserve(cand => !shadowed.contains(cand.ref.implicitName))
if isOuterMost then ownEligible
else combineEligibles(ownEligible, outerImplicits.eligible(tp))
}

override def isAccessible(ref: TermRef)(using Context): Boolean =
Expand Down Expand Up @@ -1444,7 +1452,12 @@ trait Implicits:
NoMatchingImplicitsFailure
else
val eligible =
if contextual then ctx.implicits.eligible(wildProto)
if contextual then
if ctx.gadt.isNarrowing then
withoutMode(Mode.ImplicitsEnabled) {
ctx.implicits.uncachedEligible(wildProto)
}
else ctx.implicits.eligible(wildProto)
else implicitScope(wildProto).eligible
searchImplicit(eligible, contextual) match
case result: SearchSuccess =>
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1662,7 +1662,7 @@ class Typer extends Namer
val pat1 = indexPattern(tree).transform(pat)
val guard1 = typedExpr(tree.guard, defn.BooleanType)
var body1 = ensureNoLocalRefs(typedExpr(tree.body, pt1), pt1, ctx.scope.toList)
if ctx.gadt.nonEmpty then
if ctx.gadt.isNarrowing then
// Store GADT constraint to later retrieve it (in PostTyper, for now).
// GADT constraints are necessary to correctly check bounds of type app,
// see tests/pos/i12226 and issue #12226. It might be possible that this
Expand Down Expand Up @@ -3824,7 +3824,7 @@ class Typer extends Namer

pt match
case pt: SelectionProto =>
if ctx.gadt.nonEmpty then
if ctx.gadt.isNarrowing then
// try GADT approximation if we're trying to select a member
// Member lookup cannot take GADTs into account b/c of cache, so we
// approximate types based on GADT constraints instead. For an example,
Expand Down
3 changes: 3 additions & 0 deletions compiler/test/dotc/pos-test-pickling.blacklist
Original file line number Diff line number Diff line change
Expand Up @@ -75,4 +75,7 @@ i13842.scala
# GADT cast applied to singleton type difference
i4176-gadt.scala

# GADT difference
i13974a.scala

java-inherited-type1
4 changes: 2 additions & 2 deletions tests/neg/gadt-approximation-interaction.scala
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ object ImplicitConversion {

def foo[T](t: T, ev: T SUB Int) =
ev match { case SUB.Refl() =>
t ** 2 // error // implementation limitation
t ** 2
}

def bar[T](t: T, ev: T SUB Int) =
Expand All @@ -67,7 +67,7 @@ object GivenConversion {

def foo[T](t: T, ev: T SUB Int) =
ev match { case SUB.Refl() =>
t ** 2 // error (implementation limitation)
t ** 2
}

def bar[T](t: T, ev: T SUB Int) =
Expand Down
13 changes: 13 additions & 0 deletions tests/pos/i13974.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
object Test {
class C
class Use[A]
case class UseC() extends Use[C]
class ConversionTarget
implicit def convert(c: C): ConversionTarget = ???
def go[X](u: Use[X], x: X) =
u match {
case UseC() =>
//val y: C = x
x: ConversionTarget
}
}
12 changes: 12 additions & 0 deletions tests/pos/i13974a.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@

object Test2:
class Foo[+X]
enum SUB[-S, +T]:
case Refl[U]() extends SUB[U, U]
def f[A, B, C](sub : A SUB (B,C)) =
given Foo[A] = ???
val x = summon[Foo[A]]
sub match
case SUB.Refl() =>
val c: Foo[(B, C)] = summon[Foo[A]]
summon[Foo[(B, C)]]