Skip to content

Commit b12151f

Browse files
committed
Filter only for generators starting with case.
But wait with this for now, since we can't cross-compile easily otherwise. So currently this is enabled only under -strict.
1 parent f697d08 commit b12151f

File tree

5 files changed

+130
-55
lines changed

5 files changed

+130
-55
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

Lines changed: 49 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -32,8 +32,19 @@ object desugar {
3232
*/
3333
val DerivingCompanion: Property.Key[SourcePosition] = new Property.Key
3434

35-
/** An attachment for match expressions generated from a PatDef */
36-
val PatDefMatch: Property.Key[Unit] = new Property.Key
35+
/** An attachment for match expressions generated from a PatDef or GenFrom.
36+
* Value of key == one of IrrefutablePatDef, IrrefutableGenFrom
37+
*/
38+
val CheckIrrefutable: Property.Key[MatchCheck] = new Property.StickyKey
39+
40+
/** What static check should be applied to a Match (none, irrefutable, exhaustive) */
41+
class MatchCheck(val n: Int) extends AnyVal
42+
object MatchCheck {
43+
val None = new MatchCheck(0)
44+
val Exhaustive = new MatchCheck(1)
45+
val IrrefutablePatDef = new MatchCheck(2)
46+
val IrrefutableGenFrom = new MatchCheck(3)
47+
}
3748

3849
/** Info of a variable in a pattern: The named tree and its type */
3950
private type VarInfo = (NameTree, Tree)
@@ -925,6 +936,22 @@ object desugar {
925936
}
926937
}
927938

939+
/** The selector of a match, which depends of the given `checkMode`.
940+
* @param sel the original selector
941+
* @return if `checkMode` is
942+
* - None : sel @unchecked
943+
* - Exhaustive : sel
944+
* - IrrefutablePatDef,
945+
* IrrefutableGenFrom: sel @unchecked with attachment `CheckIrrefutable -> checkMode`
946+
*/
947+
def makeSelector(sel: Tree, checkMode: MatchCheck)(implicit ctx: Context): Tree =
948+
if (checkMode == MatchCheck.Exhaustive) sel
949+
else {
950+
val sel1 = Annotated(sel, New(ref(defn.UncheckedAnnotType)))
951+
if (checkMode != MatchCheck.None) sel1.pushAttachment(CheckIrrefutable, checkMode)
952+
sel1
953+
}
954+
928955
/** If `pat` is a variable pattern,
929956
*
930957
* val/var/lazy val p = e
@@ -959,11 +986,6 @@ object desugar {
959986
// - `pat` is a tuple of N variables or wildcard patterns like `(x1, x2, ..., xN)`
960987
val tupleOptimizable = forallResults(rhs, isMatchingTuple)
961988

962-
def rhsUnchecked = {
963-
val rhs1 = makeAnnotated("scala.unchecked", rhs)
964-
rhs1.pushAttachment(PatDefMatch, ())
965-
rhs1
966-
}
967989
val vars =
968990
if (tupleOptimizable) // include `_`
969991
pat match {
@@ -976,7 +998,7 @@ object desugar {
976998
val caseDef = CaseDef(pat, EmptyTree, makeTuple(ids))
977999
val matchExpr =
9781000
if (tupleOptimizable) rhs
979-
else Match(rhsUnchecked, caseDef :: Nil)
1001+
else Match(makeSelector(rhs, MatchCheck.IrrefutablePatDef), caseDef :: Nil)
9801002
vars match {
9811003
case Nil =>
9821004
matchExpr
@@ -1125,14 +1147,10 @@ object desugar {
11251147
*
11261148
* (x$1, ..., x$n) => (x$0, ..., x${n-1} @unchecked?) match { cases }
11271149
*/
1128-
def makeCaseLambda(cases: List[CaseDef], nparams: Int = 1, unchecked: Boolean = true)(implicit ctx: Context): Function = {
1150+
def makeCaseLambda(cases: List[CaseDef], checkMode: MatchCheck, nparams: Int = 1)(implicit ctx: Context): Function = {
11291151
val params = (1 to nparams).toList.map(makeSyntheticParameter(_))
11301152
val selector = makeTuple(params.map(p => Ident(p.name)))
1131-
1132-
if (unchecked)
1133-
Function(params, Match(Annotated(selector, New(ref(defn.UncheckedAnnotType))), cases))
1134-
else
1135-
Function(params, Match(selector, cases))
1153+
Function(params, Match(makeSelector(selector, checkMode), cases))
11361154
}
11371155

11381156
/** Map n-ary function `(p1, ..., pn) => body` where n != 1 to unary function as follows:
@@ -1262,13 +1280,14 @@ object desugar {
12621280

12631281
/** Make a function value pat => body.
12641282
* If pat is a var pattern id: T then this gives (id: T) => body
1265-
* Otherwise this gives { case pat => body }
1283+
* Otherwise this gives { case pat => body }, where `pat` is allowed to be
1284+
* refutable only if `checkMode` is MatchCheck.None.
12661285
*/
1267-
def makeLambda(pat: Tree, body: Tree): Tree = pat match {
1286+
def makeLambda(pat: Tree, body: Tree, checkMode: MatchCheck): Tree = pat match {
12681287
case IdPattern(named, tpt) =>
12691288
Function(derivedValDef(pat, named, tpt, EmptyTree, Modifiers(Param)) :: Nil, body)
12701289
case _ =>
1271-
makeCaseLambda(CaseDef(pat, EmptyTree, body) :: Nil)
1290+
makeCaseLambda(CaseDef(pat, EmptyTree, body) :: Nil, checkMode)
12721291
}
12731292

12741293
/** If `pat` is not an Identifier, a Typed(Ident, _), or a Bind, wrap
@@ -1314,7 +1333,7 @@ object desugar {
13141333
val cases = List(
13151334
CaseDef(pat, EmptyTree, Literal(Constant(true))),
13161335
CaseDef(Ident(nme.WILDCARD), EmptyTree, Literal(Constant(false))))
1317-
Apply(Select(rhs, nme.withFilter), makeCaseLambda(cases))
1336+
Apply(Select(rhs, nme.withFilter), makeCaseLambda(cases, MatchCheck.None))
13181337
}
13191338

13201339
/** Is pattern `pat` irrefutable when matched against `rhs`?
@@ -1353,26 +1372,30 @@ object desugar {
13531372
Select(rhs, name)
13541373
}
13551374

1375+
def checkMode(gen: GenFrom) =
1376+
if (gen.filtering) MatchCheck.None // refutable paterns were already eliminated in filter step
1377+
else MatchCheck.IrrefutableGenFrom
1378+
13561379
enums match {
13571380
case (gen: GenFrom) :: Nil =>
1358-
Apply(rhsSelect(gen, mapName), makeLambda(gen.pat, body))
1381+
Apply(rhsSelect(gen, mapName), makeLambda(gen.pat, body, checkMode(gen)))
13591382
case (gen: GenFrom) :: (rest @ (GenFrom(_, _, _) :: _)) =>
13601383
val cont = makeFor(mapName, flatMapName, rest, body)
1361-
Apply(rhsSelect(gen, flatMapName), makeLambda(gen.pat, cont))
1362-
case (gen @ GenFrom(pat, rhs, _)) :: (rest @ GenAlias(_, _) :: _) =>
1384+
Apply(rhsSelect(gen, flatMapName), makeLambda(gen.pat, cont, checkMode(gen)))
1385+
case (gen: GenFrom) :: (rest @ GenAlias(_, _) :: _) =>
13631386
val (valeqs, rest1) = rest.span(_.isInstanceOf[GenAlias])
13641387
val pats = valeqs map { case GenAlias(pat, _) => pat }
13651388
val rhss = valeqs map { case GenAlias(_, rhs) => rhs }
1366-
val (defpat0, id0) = makeIdPat(pat)
1389+
val (defpat0, id0) = makeIdPat(gen.pat)
13671390
val (defpats, ids) = (pats map makeIdPat).unzip
13681391
val pdefs = (valeqs, defpats, rhss).zipped.map(makePatDef(_, Modifiers(), _, _))
1369-
val rhs1 = makeFor(nme.map, nme.flatMap, GenFrom(defpat0, rhs, gen.filtering) :: Nil, Block(pdefs, makeTuple(id0 :: ids)))
1370-
val allpats = pat :: pats
1392+
val rhs1 = makeFor(nme.map, nme.flatMap, GenFrom(defpat0, gen.expr, gen.filtering) :: Nil, Block(pdefs, makeTuple(id0 :: ids)))
1393+
val allpats = gen.pat :: pats
13711394
val vfrom1 = new GenFrom(makeTuple(allpats), rhs1, filtering = false)
13721395
makeFor(mapName, flatMapName, vfrom1 :: rest1, body)
13731396
case (gen: GenFrom) :: test :: rest =>
1374-
val filtered = Apply(rhsSelect(gen, nme.withFilter), makeLambda(gen.pat, test))
1375-
val genFrom = new GenFrom(gen.pat, filtered, filtering = false)
1397+
val filtered = Apply(rhsSelect(gen, nme.withFilter), makeLambda(gen.pat, test, MatchCheck.None))
1398+
val genFrom = GenFrom(gen.pat, filtered, filtering = false)
13761399
makeFor(mapName, flatMapName, genFrom :: rest, body)
13771400
case _ =>
13781401
EmptyTree //may happen for erroneous input

compiler/src/dotty/tools/dotc/typer/Checking.scala

Lines changed: 29 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -601,42 +601,48 @@ trait Checking {
601601
* This means `pat` is either marked @unchecked or `pt` conforms to the
602602
* pattern's type. If pattern is an UnApply, do the check recursively.
603603
*/
604-
def checkIrrefutable(pat: Tree, pt: Type)(implicit ctx: Context): Boolean = {
605-
patmatch.println(i"check irrefutable $pat: ${pat.tpe} against $pt")
604+
def checkIrrefutable(pat: Tree, pt: Type, isPatDef: Boolean)(implicit ctx: Context): Boolean = {
606605

607606
def check(pat: Tree, pt: Type): Boolean = {
608607
if (pt <:< pat.tpe)
609608
true
610609
else {
610+
var reportedPt = pt.dropAnnot(defn.UncheckedAnnot)
611+
if (!pat.tpe.isSingleton) reportedPt = reportedPt.widen
612+
val fix = if (isPatDef) "`: @unchecked` after" else "`case ` before"
611613
ctx.errorOrMigrationWarning(
612-
ex"""pattern's type ${pat.tpe} is more specialized than the right hand side expression's type ${pt.dropAnnot(defn.UncheckedAnnot)}
614+
ex"""pattern's type ${pat.tpe} is more specialized than the right hand side expression's type $reportedPt
613615
|
614-
|If the narrowing is intentional, this can be communicated by writing `: @unchecked` after the full pattern.${err.rewriteNotice}""",
616+
|If the narrowing is intentional, this can be communicated by writing $fix the full pattern.${err.rewriteNotice}""",
615617
pat.sourcePos)
616618
false
617619
}
618620
}
619621

620-
!ctx.settings.strict.value || // only in -strict mode for now since mitigations work only after this PR
621-
pat.tpe.widen.hasAnnotation(defn.UncheckedAnnot) || {
622-
pat match {
623-
case Bind(_, pat1) =>
624-
checkIrrefutable(pat1, pt)
625-
case UnApply(fn, _, pats) =>
626-
check(pat, pt) && {
627-
val argPts = unapplyArgs(fn.tpe.finalResultType, fn, pats, pat.sourcePos)
628-
pats.corresponds(argPts)(checkIrrefutable)
629-
}
630-
case Alternative(pats) =>
631-
pats.forall(checkIrrefutable(_, pt))
632-
case Typed(arg, tpt) =>
633-
check(pat, pt) && checkIrrefutable(arg, pt)
634-
case Ident(nme.WILDCARD) =>
635-
true
636-
case _ =>
637-
check(pat, pt)
622+
def recur(pat: Tree, pt: Type): Boolean =
623+
!ctx.settings.strict.value || // only in -strict mode for now since mitigations work only after this PR
624+
pat.tpe.widen.hasAnnotation(defn.UncheckedAnnot) || {
625+
patmatch.println(i"check irrefutable $pat: ${pat.tpe} against $pt")
626+
pat match {
627+
case Bind(_, pat1) =>
628+
recur(pat1, pt)
629+
case UnApply(fn, _, pats) =>
630+
check(pat, pt) && {
631+
val argPts = unapplyArgs(fn.tpe.finalResultType, fn, pats, pat.sourcePos)
632+
pats.corresponds(argPts)(recur)
633+
}
634+
case Alternative(pats) =>
635+
pats.forall(recur(_, pt))
636+
case Typed(arg, tpt) =>
637+
check(pat, pt) && recur(arg, pt)
638+
case Ident(nme.WILDCARD) =>
639+
true
640+
case _ =>
641+
check(pat, pt)
642+
}
638643
}
639-
}
644+
645+
recur(pat, pt)
640646
}
641647

642648
/** Check that `path` is a legal prefix for an import or export clause */

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1029,19 +1029,26 @@ class Typer extends Namer
10291029
}
10301030
else {
10311031
val (protoFormals, _) = decomposeProtoFunction(pt, 1)
1032-
val unchecked = pt.isRef(defn.PartialFunctionClass)
1033-
typed(desugar.makeCaseLambda(tree.cases, protoFormals.length, unchecked).withSpan(tree.span), pt)
1032+
val checkMode =
1033+
if (pt.isRef(defn.PartialFunctionClass)) desugar.MatchCheck.None
1034+
else desugar.MatchCheck.Exhaustive
1035+
typed(desugar.makeCaseLambda(tree.cases, checkMode, protoFormals.length).withSpan(tree.span), pt)
10341036
}
10351037
case _ =>
10361038
if (tree.isInline) checkInInlineContext("inline match", tree.posd)
10371039
val sel1 = typedExpr(tree.selector)
10381040
val selType = fullyDefinedType(sel1.tpe, "pattern selector", tree.span).widen
10391041
val result = typedMatchFinish(tree, sel1, selType, tree.cases, pt)
10401042
result match {
1041-
case Match(sel, CaseDef(pat, _, _) :: _)
1042-
if (tree.selector.removeAttachment(desugar.PatDefMatch).isDefined) =>
1043-
if (!checkIrrefutable(pat, sel.tpe) && ctx.scala2Mode)
1044-
patch(Span(pat.span.end), ": @unchecked")
1043+
case Match(sel, CaseDef(pat, _, _) :: _) =>
1044+
tree.selector.removeAttachment(desugar.CheckIrrefutable) match {
1045+
case Some(checkMode) =>
1046+
val isPatDef = checkMode == desugar.MatchCheck.IrrefutablePatDef
1047+
if (!checkIrrefutable(pat, sel.tpe, isPatDef) && ctx.settings.migration.value)
1048+
if (isPatDef) patch(Span(pat.span.end), ": @unchecked")
1049+
else patch(Span(pat.span.start), "case ")
1050+
case _ =>
1051+
}
10451052
case _ =>
10461053
}
10471054
result

compiler/test/dotty/tools/dotc/CompilationTests.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,7 @@ class CompilationTests extends ParallelTesting {
194194
compileFilesInDir("tests/run-custom-args/Yretain-trees", defaultOptions and "-Yretain-trees"),
195195
compileFile("tests/run-custom-args/tuple-cons.scala", allowDeepSubtypes),
196196
compileFile("tests/run-custom-args/i5256.scala", allowDeepSubtypes),
197+
compileFile("tests/run-custom-args/fors.scala", defaultOptions and "-strict"),
197198
compileFile("tests/run-custom-args/no-useless-forwarders.scala", defaultOptions and "-Xmixin-force-forwarders:false"),
198199
compileFilesInDir("tests/run", defaultOptions)
199200
).checkRuns()

tests/neg/zipped.scala

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
// This test shows some un-intuitive behavior of the `zipped` method.
2+
object Test {
3+
val xs: List[Int] = ???
4+
5+
// 1. This works, since withFilter is not defined on Tuple3zipped. Instead,
6+
// an implicit conversion from Tuple3zipped to Traversable[(Int, Int, Int)] is inserted.
7+
// The subsequent map operation has the right type for this Traversable.
8+
(xs, xs, xs).zipped
9+
.withFilter( (x: (Int, Int, Int)) => x match { case (x, y, z) => true } ) // OK
10+
.map( (x: (Int, Int, Int)) => x match { case (x, y, z) => x + y + z }) // OK
11+
12+
13+
// 2. This works as well, because of auto untupling i.e. `case` is inserted.
14+
// But it does not work in Scala2.
15+
(xs, xs, xs).zipped
16+
.withFilter( (x: (Int, Int, Int)) => x match { case (x, y, z) => true } ) // OK
17+
.map( (x: Int, y: Int, z: Int) => x + y + z ) // OK
18+
// works, because of auto untupling i.e. `case` is inserted
19+
// does not work in Scala2
20+
21+
// 3. Now, without withFilter, it's the opposite, we need the 3 parameter map.
22+
(xs, xs, xs).zipped
23+
.map( (x: Int, y: Int, z: Int) => x + y + z ) // OK
24+
25+
// 4. The single parameter map does not work.
26+
(xs, xs, xs).zipped
27+
.map( (x: (Int, Int, Int)) => x match { case (x, y, z) => x + y + z }) // error
28+
29+
// 5. If we leave out the parameter type, we get a "Wrong number of parameters" error instead
30+
(xs, xs, xs).zipped
31+
.map( x => x match { case (x, y, z) => x + y + z }) // error
32+
33+
// This means that the following works in Dotty in normal mode, since a `withFilter`
34+
// is inserted. But it does no work under -strict. And it will not work in Scala 3.1.
35+
// The reason is that without -strict, the code below is mapped to (1), but with -strict
36+
// it is mapped to (5).
37+
for ((x, y, z) <- (xs, xs, xs).zipped) yield x + y + z
38+
}

0 commit comments

Comments
 (0)