Skip to content

Commit f578b97

Browse files
committed
Fixes to support case id: T <- ...
This was not classified as a pattern binding before, so no filtering was applied.
1 parent 1228d0c commit f578b97

File tree

5 files changed

+30
-21
lines changed

5 files changed

+30
-21
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1279,16 +1279,19 @@ object desugar {
12791279
*/
12801280
def makeFor(mapName: TermName, flatMapName: TermName, enums: List[Tree], body: Tree): Tree = trace(i"make for ${ForYield(enums, body)}", show = true) {
12811281

1282-
/** Make a function value pat => body.
1283-
* If pat is a var pattern id: T then this gives (id: T) => body
1284-
* Otherwise this gives { case pat => body }, where `pat` is allowed to be
1285-
* refutable only if `checkMode` is MatchCheck.None.
1282+
/** Let `pat` be `gen`'s pattern. Make a function value `pat => body`.
1283+
* If `pat` is a var pattern `id: T` then this gives `(id: T) => body`.
1284+
* Otherwise this gives `{ case pat => body }`, where `pat` is checked to be
1285+
* irrefutable if `gen`'s checkMode is GenCheckMode.Check.
12861286
*/
1287-
def makeLambda(pat: Tree, body: Tree, checkMode: MatchCheck): Tree = pat match {
1288-
case IdPattern(named, tpt) =>
1289-
Function(derivedValDef(pat, named, tpt, EmptyTree, Modifiers(Param)) :: Nil, body)
1287+
def makeLambda(gen: GenFrom, body: Tree): Tree = gen.pat match {
1288+
case IdPattern(named, tpt) if gen.checkMode != GenCheckMode.FilterAlways =>
1289+
Function(derivedValDef(gen.pat, named, tpt, EmptyTree, Modifiers(Param)) :: Nil, body)
12901290
case _ =>
1291-
makeCaseLambda(CaseDef(pat, EmptyTree, body) :: Nil, checkMode)
1291+
val matchCheckMode =
1292+
if (gen.checkMode == GenCheckMode.Check) MatchCheck.IrrefutableGenFrom
1293+
else MatchCheck.None
1294+
makeCaseLambda(CaseDef(gen.pat, EmptyTree, body) :: Nil, matchCheckMode)
12921295
}
12931296

12941297
/** If `pat` is not an Identifier, a Typed(Ident, _), or a Bind, wrap
@@ -1360,16 +1363,20 @@ object desugar {
13601363
}
13611364
}
13621365

1363-
def needsFilter(gen: GenFrom): Boolean =
1364-
gen.checkMode != GenCheckMode.Filter ||
1365-
IdPattern.unapply(gen.pat).isDefined ||
1366-
isIrrefutable(gen.pat, gen.expr)
1366+
def needsNoFilter(gen: GenFrom): Boolean =
1367+
if (gen.checkMode == GenCheckMode.FilterAlways) // pattern was prefixed by `case`
1368+
isIrrefutable(gen.pat, gen.expr)
1369+
else (
1370+
gen.checkMode != GenCheckMode.FilterNow ||
1371+
IdPattern.unapply(gen.pat).isDefined ||
1372+
isIrrefutable(gen.pat, gen.expr)
1373+
)
13671374

13681375
/** rhs.name with a pattern filter on rhs unless `pat` is irrefutable when
13691376
* matched against `rhs`.
13701377
*/
13711378
def rhsSelect(gen: GenFrom, name: TermName) = {
1372-
val rhs = if (needsFilter(gen)) gen.expr else makePatFilter(gen.expr, gen.pat)
1379+
val rhs = if (needsNoFilter(gen)) gen.expr else makePatFilter(gen.expr, gen.pat)
13731380
Select(rhs, name)
13741381
}
13751382

@@ -1379,10 +1386,10 @@ object desugar {
13791386

13801387
enums match {
13811388
case (gen: GenFrom) :: Nil =>
1382-
Apply(rhsSelect(gen, mapName), makeLambda(gen.pat, body, checkMode(gen)))
1389+
Apply(rhsSelect(gen, mapName), makeLambda(gen, body))
13831390
case (gen: GenFrom) :: (rest @ (GenFrom(_, _, _) :: _)) =>
13841391
val cont = makeFor(mapName, flatMapName, rest, body)
1385-
Apply(rhsSelect(gen, flatMapName), makeLambda(gen.pat, cont, checkMode(gen)))
1392+
Apply(rhsSelect(gen, flatMapName), makeLambda(gen, cont))
13861393
case (gen: GenFrom) :: (rest @ GenAlias(_, _) :: _) =>
13871394
val (valeqs, rest1) = rest.span(_.isInstanceOf[GenAlias])
13881395
val pats = valeqs map { case GenAlias(pat, _) => pat }
@@ -1395,7 +1402,7 @@ object desugar {
13951402
val vfrom1 = new GenFrom(makeTuple(allpats), rhs1, GenCheckMode.Ignore)
13961403
makeFor(mapName, flatMapName, vfrom1 :: rest1, body)
13971404
case (gen: GenFrom) :: test :: rest =>
1398-
val filtered = Apply(rhsSelect(gen, nme.withFilter), makeLambda(gen.pat, test, MatchCheck.None))
1405+
val filtered = Apply(rhsSelect(gen, nme.withFilter), makeLambda(gen, test))
13991406
val genFrom = GenFrom(gen.pat, filtered, GenCheckMode.Ignore)
14001407
makeFor(mapName, flatMapName, genFrom :: rest, body)
14011408
case _ =>

compiler/src/dotty/tools/dotc/ast/untpd.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,8 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo {
121121
object GenCheckMode {
122122
val Ignore = new GenCheckMode(0) // neither filter nor check since filtering was done before
123123
val Check = new GenCheckMode(1) // check that pattern is irrefutable
124-
val Filter = new GenCheckMode(2) // filter out non-matching elements
124+
val FilterNow = new GenCheckMode(2) // filter out non-matching elements since we are not in -strict
125+
val FilterAlways = new GenCheckMode(3) // filter out non-matching elements since pattern is prefixed by `case`
125126
}
126127

127128
// ----- Modifiers -----------------------------------------------------

compiler/src/dotty/tools/dotc/parsing/Parsers.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1737,8 +1737,9 @@ object Parsers {
17371737
def generatorRest(pat: Tree, casePat: Boolean): GenFrom =
17381738
atSpan(startOffset(pat), accept(LARROW)) {
17391739
val checkMode =
1740-
if (casePat || !ctx.settings.strict.value) GenCheckMode.Filter // don't filter under -strict
1741-
else GenCheckMode.Check
1740+
if (casePat) GenCheckMode.FilterAlways
1741+
else if (ctx.settings.strict.value) GenCheckMode.Check
1742+
else GenCheckMode.FilterNow // filter for now, to keep backwards compat
17421743
GenFrom(pat, expr(), checkMode)
17431744
}
17441745

compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -568,7 +568,7 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
568568
case ForDo(enums, expr) =>
569569
forText(enums, expr, keywordStr(" do "))
570570
case GenFrom(pat, expr, checkMode) =>
571-
(Str("case ") provided checkMode == untpd.GenCheckMode.Filter) ~
571+
(Str("case ") provided checkMode == untpd.GenCheckMode.FilterAlways) ~
572572
toText(pat) ~ " <- " ~ toText(expr)
573573
case GenAlias(pat, expr) =>
574574
toText(pat) ~ " = " ~ toText(expr)

compiler/src/dotty/tools/dotc/typer/Checking.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -607,7 +607,7 @@ trait Checking {
607607
def fail(pat: Tree, pt: Type): Boolean = {
608608
var reportedPt = pt.dropAnnot(defn.UncheckedAnnot)
609609
if (!pat.tpe.isSingleton) reportedPt = reportedPt.widen
610-
val problem = if (pat.tpe <:< pt) "is more specialized than" else "does not match"
610+
val problem = if (pat.tpe <:< reportedPt) "is more specialized than" else "does not match"
611611
val fix = if (isPatDef) "`: @unchecked` after" else "`case ` before"
612612
ctx.errorOrMigrationWarning(
613613
ex"""pattern's type ${pat.tpe} $problem the right hand side expression's type $reportedPt

0 commit comments

Comments
 (0)