From 07f93032b35f41dab0e09d63f718f5feddf56c88 Mon Sep 17 00:00:00 2001 From: Liu Fengyun Date: Sun, 6 Dec 2020 16:34:55 +0100 Subject: [PATCH] Fix #5077: avoid pattern-bound type for selectors Given the following definition: trait Is[A] case object IsInt extends Is[Int] case object IsString extends Is[String] case class C[A](is: Is[A], value: A) and the pattern match: (x: C[_]) match case C(IsInt, i) => ... The typer (enhanced with GADT) will infer `C[A$1]` as the type of the pattern, where `A$1 =:= Int`. The patternMatcher generates the following code: case val x15: C[A$1] = C.unapply[A$1 @ A$1](x9.$asInstanceOf$[C[A$1]]) case val x16: Is[A$1] = x15._1 case val x17: A$1 = x15._2 // erase to `Int` if IsInt.==(x16) then { case val i: A$1 = x17 ... } Note that `x17` will have the erased type `Int`. This is incorrect: it may only assume the type `Int` if the test is true inside the block. If the test is false, we will get an type cast exception at runtime. To fix the problem, we replace pattern-bound types by `Any` for selector results: case val x15: C[A$1] = C.unapply[A$1 @ A$1](x9.$asInstanceOf$[C[A$1]]) case val x16: Is[A$1] = x15._1 case val x17: Any = x15._2 if IsInt.==(x16) then { case val i: A$1 = x17.$asInstanceOf$[A$1] ... } The patternMatcher will then use a type cast to pass the selector value for nested unapplys or assign it to bound variables. --- .../tools/dotc/transform/PatternMatcher.scala | 18 +++++++--- tests/run/i5077.check | 3 ++ tests/run/i5077.scala | 33 +++++++++++++++++++ 3 files changed, 49 insertions(+), 5 deletions(-) create mode 100644 tests/run/i5077.check create mode 100644 tests/run/i5077.scala diff --git a/compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala b/compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala index a20f11e3cf9c..1be6bbda98fb 100644 --- a/compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala +++ b/compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala @@ -94,14 +94,15 @@ object PatternMatcher { */ private val initializer = MutableSymbolMap[Tree]() - private def newVar(rhs: Tree, flags: FlagSet): TermSymbol = + private def newVar(rhs: Tree, flags: FlagSet, tpe: Type): TermSymbol = newSymbol(ctx.owner, PatMatStdBinderName.fresh(), Synthetic | Case | flags, - sanitize(rhs.tpe), coord = rhs.span) + sanitize(tpe), coord = rhs.span) // TODO: Drop Case once we use everywhere else `isPatmatGenerated`. /** The plan `let x = rhs in body(x)` where `x` is a fresh variable */ - private def letAbstract(rhs: Tree)(body: Symbol => Plan): Plan = { - val vble = newVar(rhs, EmptyFlags) + private def letAbstract(rhs: Tree, tpe: Type = NoType)(body: Symbol => Plan): Plan = { + val declTpe = if tpe.exists then tpe else rhs.tpe + val vble = newVar(rhs, EmptyFlags, declTpe) initializer(vble) = rhs LetPlan(vble, body(vble)) } @@ -223,6 +224,13 @@ object PatternMatcher { /** Plan for matching `scrutinee` symbol against `tree` pattern */ private def patternPlan(scrutinee: Symbol, tree: Tree, onSuccess: Plan): Plan = { + extension (tree: Tree) def avoidPatBoundType(): Type = + tree.tpe.widen match + case tref: TypeRef if tref.symbol.isPatternBound => + defn.AnyType + case _ => + tree.tpe + /** Plan for matching `selectors` against argument patterns `args` */ def matchArgsPlan(selectors: List[Tree], args: List[Tree], onSuccess: Plan): Plan = { /* For a case with arguments that have some test on them such as @@ -243,7 +251,7 @@ object PatternMatcher { */ def matchArgsSelectorsPlan(selectors: List[Tree], syms: List[Symbol]): Plan = selectors match { - case selector :: selectors1 => letAbstract(selector)(sym => matchArgsSelectorsPlan(selectors1, sym :: syms)) + case selector :: selectors1 => letAbstract(selector, selector.avoidPatBoundType())(sym => matchArgsSelectorsPlan(selectors1, sym :: syms)) case Nil => matchArgsPatternPlan(args, syms.reverse) } def matchArgsPatternPlan(args: List[Tree], syms: List[Symbol]): Plan = diff --git a/tests/run/i5077.check b/tests/run/i5077.check new file mode 100644 index 000000000000..71fb2bf10e1c --- /dev/null +++ b/tests/run/i5077.check @@ -0,0 +1,3 @@ +A String with length 4 +A String with length 4 +A String with length 4 diff --git a/tests/run/i5077.scala b/tests/run/i5077.scala new file mode 100644 index 000000000000..bf70cb0c0d19 --- /dev/null +++ b/tests/run/i5077.scala @@ -0,0 +1,33 @@ +trait Is[A] +case object IsInt extends Is[Int] +case object IsString extends Is[String] +case class C[A](is: Is[A], value: A) + +@main +def Test = { + val c_string: C[String] = C(IsString, "name") + val c_any: C[_] = c_string + val any: Any = c_string + + // Case 1: no error + // `IsInt.equals` might be overridden to match a value of `C[String]` + c_string match { + case C(IsInt, _) => println(s"An Int") // Can't possibly happen! + case C(IsString, s) => println(s"A String with length ${s.length}") + case _ => println("No match") + } + + // Case 2: Should match the second case and print the length of the string + c_any match { + case C(IsInt, i) if i < 10 => println(s"An Int less than 10") + case C(IsString, s) => println(s"A String with length ${s.length}") + case _ => println("No match") + } + + // Case 3: Same as above; should match the second case and print the length of the string + any match { + case C(IsInt, i) if i < 10 => println(s"An Int less than 10") + case C(IsString, s) => println(s"A String with length ${s.length}") + case _ => println("No match") + } +} \ No newline at end of file