Skip to content

Fix 1365: Fix bindings in patterns #1377

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 15, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 18 additions & 8 deletions src/dotty/tools/dotc/core/TypeComparer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -813,14 +813,24 @@ class TypeComparer(initctx: Context) extends DotClass with ConstraintHandling {
private def narrowGADTBounds(tr: NamedType, bound: Type, isUpper: Boolean): Boolean =
ctx.mode.is(Mode.GADTflexible) && {
val tparam = tr.symbol
typr.println(s"narrow gadt bound of $tparam: ${tparam.info} from ${if (isUpper) "above" else "below"} to $bound ${bound.isRef(tparam)}")
!bound.isRef(tparam) && {
val oldBounds = ctx.gadt.bounds(tparam)
val newBounds =
if (isUpper) TypeBounds(oldBounds.lo, oldBounds.hi & bound)
else TypeBounds(oldBounds.lo | bound, oldBounds.hi)
isSubType(newBounds.lo, newBounds.hi) &&
{ ctx.gadt.setBounds(tparam, newBounds); true }
typr.println(i"narrow gadt bound of $tparam: ${tparam.info} from ${if (isUpper) "above" else "below"} to $bound ${bound.isRef(tparam)}")
if (bound.isRef(tparam)) false
else bound match {
case bound: TypeRef
if bound.symbol.is(BindDefinedType) && ctx.gadt.bounds.contains(bound.symbol) &&
!tr.symbol.is(BindDefinedType) =>
// Avoid having pattern-bound types in gadt bounds,
// as these might be eliminated once the pattern is typechecked.
// Pattern-bound type symbols should be narrowed first, only if that fails
// should symbols in the environment be constrained.
narrowGADTBounds(bound, tr, !isUpper)
case _ =>
val oldBounds = ctx.gadt.bounds(tparam)
val newBounds =
if (isUpper) TypeBounds(oldBounds.lo, oldBounds.hi & bound)
else TypeBounds(oldBounds.lo | bound, oldBounds.hi)
isSubType(newBounds.lo, newBounds.hi) &&
Copy link
Contributor

@DarkDimius DarkDimius Jul 15, 2016

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if (!isSubType(newBounds.lo, newBounds.hi))
 ctx.gadt.setBounds(tparam, newBounds)

true

Is equivalent but less hacky.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would not describe it as hacky. The full code you propose would be:

 if (isSubType(newBounds.lo, newBounds.hi)) {
   ctx.gadt.setBounds(tparam, newBounds)
   true
 }
 else false

It's certainly longer. There are quite a few other places in the code base by now that use the same idiom.

{ ctx.gadt.setBounds(tparam, newBounds); true }
}
}

Expand Down
63 changes: 49 additions & 14 deletions src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -448,11 +448,8 @@ class Typer extends Namer with TypeAssigner with Applications with Implicits wit
return typed(untpd.Apply(untpd.TypedSplice(arg), tree.expr), pt)
case _ =>
}
case tref: TypeRef if tref.symbol.isClass && !ctx.isAfterTyper =>
val setBefore = ctx.mode is Mode.GADTflexible
tpt1.tpe.<:<(pt)(ctx.addMode(Mode.GADTflexible))
if (!setBefore) ctx.retractMode(Mode.GADTflexible)
case _ =>
if (!ctx.isAfterTyper) tpt1.tpe.<:<(pt)(ctx.addMode(Mode.GADTflexible))
}
ascription(tpt1, isWildcard = true)
}
Expand Down Expand Up @@ -762,17 +759,37 @@ class Typer extends Namer with TypeAssigner with Applications with Implicits wit
def typedCase(tree: untpd.CaseDef, pt: Type, selType: Type, gadtSyms: Set[Symbol])(implicit ctx: Context): CaseDef = track("typedCase") {
val originalCtx = ctx

def caseRest(pat: Tree)(implicit ctx: Context) = {
pat foreachSubTree {
case b: Bind =>
if (ctx.scope.lookup(b.name) == NoSymbol) ctx.enter(b.symbol)
else ctx.error(d"duplicate pattern variable: ${b.name}", b.pos)
case _ =>
/** - replace all references to symbols associated with wildcards by their GADT bounds
* - enter all symbols introduced by a Bind in current scope
*/
val indexPattern = new TreeMap {
val elimWildcardSym = new TypeMap {
def apply(t: Type) = t match {
case ref @ TypeRef(_, tpnme.WILDCARD) if ctx.gadt.bounds.contains(ref.symbol) =>
ctx.gadt.bounds(ref.symbol)
case TypeAlias(ref @ TypeRef(_, tpnme.WILDCARD)) if ctx.gadt.bounds.contains(ref.symbol) =>
ctx.gadt.bounds(ref.symbol)
case _ =>
mapOver(t)
}
}
override def transform(tree: Tree)(implicit ctx: Context) =
super.transform(tree.withType(elimWildcardSym(tree.tpe))) match {
case b: Bind =>
if (ctx.scope.lookup(b.name) == NoSymbol) ctx.enter(b.symbol)
else ctx.error(d"duplicate pattern variable: ${b.name}", b.pos)
b.symbol.info = elimWildcardSym(b.symbol.info)
b
case t => t
}
}

def caseRest(pat: Tree)(implicit ctx: Context) = {
val pat1 = indexPattern.transform(pat)
val guard1 = typedExpr(tree.guard, defn.BooleanType)
val body1 = ensureNoLocalRefs(typedExpr(tree.body, pt), pt, ctx.scope.toList)
.ensureConforms(pt)(originalCtx) // insert a cast if body does not conform to expected type if we disregard gadt bounds
assignType(cpy.CaseDef(tree)(pat, guard1, body1), body1)
assignType(cpy.CaseDef(tree)(pat1, guard1, body1), body1)
}

val gadtCtx =
Expand Down Expand Up @@ -963,11 +980,30 @@ class Typer extends Namer with TypeAssigner with Applications with Implicits wit
assignType(cpy.ByNameTypeTree(tree)(result1), result1)
}

/** Define a new symbol associated with a Bind or pattern wildcard and
* make it gadt narrowable.
*/
private def newPatternBoundSym(name: Name, info: Type, pos: Position)(implicit ctx: Context) = {
val flags = if (name.isTypeName) BindDefinedType else EmptyFlags
val sym = ctx.newSymbol(ctx.owner, name, flags | Case, info, coord = pos)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Case is described as:

 /** A case class or its companion object */
final val Case = commonFlag(17, "case")
final val CaseClass = Case.toTypeFlags
final val CaseVal = Case.toTermFlags

so I don't understand what it's supposed to do here

if (name.isTypeName) ctx.gadt.setBounds(sym, info.bounds)
sym
}

def typedTypeBoundsTree(tree: untpd.TypeBoundsTree)(implicit ctx: Context): TypeBoundsTree = track("typedTypeBoundsTree") {
val TypeBoundsTree(lo, hi) = desugar.typeBoundsTree(tree)
val lo1 = typed(lo)
val hi1 = typed(hi)
assignType(cpy.TypeBoundsTree(tree)(lo1, hi1), lo1, hi1)
val tree1 = assignType(cpy.TypeBoundsTree(tree)(lo1, hi1), lo1, hi1)
if (ctx.mode.is(Mode.Pattern)) {
// Associate a pattern-bound type symbol with the wildcard.
// The bounds of the type symbol can be constrained when comparing a pattern type
// with an expected type in typedTyped. The type symbol is eliminated once
// the enclosing pattern has been typechecked; see `indexPattern` in `typedCase`.
val wildcardSym = newPatternBoundSym(tpnme.WILDCARD, tree1.tpe, tree.pos)
tree1.withType(wildcardSym.typeRef)
}
else tree1
}

def typedBind(tree: untpd.Bind, pt: Type)(implicit ctx: Context): Tree = track("typedBind") {
Expand All @@ -983,8 +1019,7 @@ class Typer extends Namer with TypeAssigner with Applications with Implicits wit
tpd.cpy.UnApply(body1)(fn, Nil,
typed(untpd.Bind(tree.name, arg).withPos(tree.pos), arg.tpe) :: Nil)
case _ =>
val flags = if (tree.isType) BindDefinedType else EmptyFlags
val sym = ctx.newSymbol(ctx.owner, tree.name, flags | Case, body1.tpe, coord = tree.pos)
val sym = newPatternBoundSym(tree.name, body1.tpe, tree.pos)
assignType(cpy.Bind(tree)(tree.name, body1), sym)
}
}
Expand Down
13 changes: 13 additions & 0 deletions tests/pos/i1365.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import scala.collection.mutable.ArrayBuffer

trait Message[M]
class Script[S] extends ArrayBuffer[Message[S]] with Message[S]

class Test[A] {
def f(cmd: Message[A]): Unit = cmd match {
case s: Script[_] => s.iterator.foreach(x => f(x))
}
def g(cmd: Message[A]): Unit = cmd match {
case s: Script[z] => s.iterator.foreach(x => g(x))
}
}