From b0ebe6ad30ce2584aa221b3ed8d10042bd9e97ac Mon Sep 17 00:00:00 2001 From: Nicolas Stucki Date: Fri, 10 Jun 2016 11:14:17 +0200 Subject: [PATCH] Fix #856: Handle try/catch cases as catch cases if possible. Previously they were all lifted into a match with the came cases. Now the first cases are handled directly by by the catch. If one of the cases can not be handled the old scheme is applied to to it and all subsequent cases. --- src/dotty/tools/dotc/Compiler.scala | 3 +- .../tools/dotc/transform/PatternMatcher.scala | 76 +--------- .../dotc/transform/TryCatchPatterns.scala | 99 +++++++++++++ tests/neg/tryPatternMatchError.scala | 35 +++++ tests/run/tryPatternMatch.check | 20 +++ tests/run/tryPatternMatch.scala | 139 ++++++++++++++++++ 6 files changed, 297 insertions(+), 75 deletions(-) create mode 100644 src/dotty/tools/dotc/transform/TryCatchPatterns.scala create mode 100644 tests/neg/tryPatternMatchError.scala create mode 100644 tests/run/tryPatternMatch.check create mode 100644 tests/run/tryPatternMatch.scala diff --git a/src/dotty/tools/dotc/Compiler.scala b/src/dotty/tools/dotc/Compiler.scala index 3844f42a715c..ce9280d827b8 100644 --- a/src/dotty/tools/dotc/Compiler.scala +++ b/src/dotty/tools/dotc/Compiler.scala @@ -57,7 +57,8 @@ class Compiler { new TailRec, // Rewrite tail recursion to loops new LiftTry, // Put try expressions that might execute on non-empty stacks into their own methods new ClassOf), // Expand `Predef.classOf` calls. - List(new PatternMatcher, // Compile pattern matches + List(new TryCatchPatterns, // Compile cases in try/catch + new PatternMatcher, // Compile pattern matches new ExplicitOuter, // Add accessors to outer classes from nested ones. new ExplicitSelf, // Make references to non-trivial self types explicit as casts new CrossCastAnd, // Normalize selections involving intersection types. diff --git a/src/dotty/tools/dotc/transform/PatternMatcher.scala b/src/dotty/tools/dotc/transform/PatternMatcher.scala index fd89696a83a3..974053769233 100644 --- a/src/dotty/tools/dotc/transform/PatternMatcher.scala +++ b/src/dotty/tools/dotc/transform/PatternMatcher.scala @@ -1,6 +1,8 @@ package dotty.tools.dotc package transform +import scala.language.postfixOps + import TreeTransforms._ import core.Denotations._ import core.SymDenotations._ @@ -53,19 +55,6 @@ class PatternMatcher extends MiniPhaseTransform with DenotTransformer {thisTrans translated.ensureConforms(tree.tpe) } - - override def transformTry(tree: tpd.Try)(implicit ctx: Context, info: TransformerInfo): tpd.Tree = { - val selector = - ctx.newSymbol(ctx.owner, ctx.freshName("ex").toTermName, Flags.Synthetic | Flags.Case, defn.ThrowableType, coord = tree.pos) - val sel = Ident(selector.termRef).withPos(tree.pos) - val rethrow = tpd.CaseDef(EmptyTree, EmptyTree, Throw(ref(selector))) - val newCases = tpd.CaseDef( - Bind(selector, Underscore(selector.info).withPos(tree.pos)), - EmptyTree, - transformMatch(tpd.Match(sel, tree.cases ::: rethrow :: Nil))) - cpy.Try(tree)(tree.expr, newCases :: Nil, tree.finalizer) - } - class Translator(implicit ctx: Context) { def translator = { @@ -1264,27 +1253,6 @@ class PatternMatcher extends MiniPhaseTransform with DenotTransformer {thisTrans t } - /** Is this pattern node a catch-all or type-test pattern? */ - def isCatchCase(cdef: CaseDef) = cdef match { - case CaseDef(Typed(Ident(nme.WILDCARD), tpt), EmptyTree, _) => - isSimpleThrowable(tpt.tpe) - case CaseDef(Bind(_, Typed(Ident(nme.WILDCARD), tpt)), EmptyTree, _) => - isSimpleThrowable(tpt.tpe) - case _ => - isDefaultCase(cdef) - } - - private def isSimpleThrowable(tp: Type)(implicit ctx: Context): Boolean = tp match { - case tp @ TypeRef(pre, _) => - val sym = tp.symbol - (pre == NoPrefix || pre.widen.typeSymbol.isStatic) && - (sym.derivesFrom(defn.ThrowableClass)) && /* bq */ !(sym is Flags.Trait) - case _ => - false - } - - - /** Implement a pattern match by turning its cases (including the implicit failure case) * into the corresponding (monadic) extractors, and combining them with the `orElse` combinator. * @@ -1335,46 +1303,6 @@ class PatternMatcher extends MiniPhaseTransform with DenotTransformer {thisTrans Block(List(ValDef(selectorSym, sel)), combined) } - // return list of typed CaseDefs that are supported by the backend (typed/bind/wildcard) - // we don't have a global scrutinee -- the caught exception must be bound in each of the casedefs - // there's no need to check the scrutinee for null -- "throw null" becomes "throw new NullPointerException" - // try to simplify to a type-based switch, or fall back to a catch-all case that runs a normal pattern match - // unlike translateMatch, we type our result before returning it - /*def translateTry(caseDefs: List[CaseDef], pt: Type, pos: Position): List[CaseDef] = - // if they're already simple enough to be handled by the back-end, we're done - if (caseDefs forall isCatchCase) caseDefs - else { - val swatches = { // switch-catches - val bindersAndCases = caseDefs map { caseDef => - // generate a fresh symbol for each case, hoping we'll end up emitting a type-switch (we don't have a global scrut there) - // if we fail to emit a fine-grained switch, have to do translateCase again with a single scrutSym (TODO: uniformize substitution on treemakers so we can avoid this) - val caseScrutSym = freshSym(pos, pureType(defn.ThrowableType)) - (caseScrutSym, propagateSubstitution(translateCase(caseScrutSym, pt)(caseDef), EmptySubstitution)) - } - - for(cases <- emitTypeSwitch(bindersAndCases, pt).toList - if cases forall isCatchCase; // must check again, since it's not guaranteed -- TODO: can we eliminate this? e.g., a type test could test for a trait or a non-trivial prefix, which are not handled by the back-end - cse <- cases) yield /*fixerUpper(matchOwner, pos)*/(cse).asInstanceOf[CaseDef] - } - - val catches = if (swatches.nonEmpty) swatches else { - val scrutSym = freshSym(pos, pureType(defn.ThrowableType)) - val casesNoSubstOnly = caseDefs map { caseDef => (propagateSubstitution(translateCase(scrutSym, pt)(caseDef), EmptySubstitution))} - - val exSym = freshSym(pos, pureType(defn.ThrowableType), "ex") - - List( - CaseDef( - Bind(exSym, Ident(??? /*nme.WILDCARD*/)), // TODO: does this need fixing upping? - EmptyTree, - combineCasesNoSubstOnly(ref(exSym), scrutSym, casesNoSubstOnly, pt, matchOwner, Some((scrut: Symbol) => Throw(ref(exSym)))) - ) - ) - } - - /*typer.typedCases(*/catches/*, defn.ThrowableType, WildcardType)*/ - }*/ - /** The translation of `pat if guard => body` has two aspects: * 1) the substitution due to the variables bound by patterns * 2) the combination of the extractor calls using `flatMap`. diff --git a/src/dotty/tools/dotc/transform/TryCatchPatterns.scala b/src/dotty/tools/dotc/transform/TryCatchPatterns.scala new file mode 100644 index 000000000000..9a6ecef51e6b --- /dev/null +++ b/src/dotty/tools/dotc/transform/TryCatchPatterns.scala @@ -0,0 +1,99 @@ +package dotty.tools.dotc +package transform + +import core.Symbols._ +import core.StdNames._ +import ast.Trees._ +import core.Types._ +import dotty.tools.dotc.core.Decorators._ +import dotty.tools.dotc.core.Flags +import dotty.tools.dotc.core.Contexts.Context +import dotty.tools.dotc.transform.TreeTransforms.{MiniPhaseTransform, TransformerInfo} +import dotty.tools.dotc.util.Positions.Position + +/** Compiles the cases that can not be handled by primitive catch cases as a common pattern match. + * + * The following code: + * ``` + * try { } + * catch { + * // Cases that can be handled by catch + * // Cases starting with first one that can't be handled by catch + * } + * ``` + * will become: + * ``` + * try { } + * catch { + * + * case e => e match { + * + * } + * } + * ``` + * + * Cases that are not supported include: + * - Applies and unapplies + * - Idents + * - Alternatives + * - `case _: T =>` where `T` is not `Throwable` + * + */ +class TryCatchPatterns extends MiniPhaseTransform { + import dotty.tools.dotc.ast.tpd._ + + def phaseName: String = "tryCatchPatterns" + + override def runsAfter = Set(classOf[ElimRepeated]) + + override def checkPostCondition(tree: Tree)(implicit ctx: Context): Unit = tree match { + case Try(_, cases, _) => + cases.foreach { + case CaseDef(Typed(_, _), guard, _) => assert(guard.isEmpty, "Try case should not contain a guard.") + case CaseDef(Bind(_, _), guard, _) => assert(guard.isEmpty, "Try case should not contain a guard.") + case c => + assert(isDefaultCase(c), "Pattern in Try should be Bind, Typed or default case.") + } + case _ => + } + + override def transformTry(tree: Try)(implicit ctx: Context, info: TransformerInfo): Tree = { + val (tryCases, patternMatchCases) = tree.cases.span(isCatchCase) + val fallbackCase = mkFallbackPatterMatchCase(patternMatchCases, tree.pos) + cpy.Try(tree)(cases = tryCases ++ fallbackCase) + } + + /** Is this pattern node a catch-all or type-test pattern? */ + private def isCatchCase(cdef: CaseDef)(implicit ctx: Context): Boolean = cdef match { + case CaseDef(Typed(Ident(nme.WILDCARD), tpt), EmptyTree, _) => isSimpleThrowable(tpt.tpe) + case CaseDef(Bind(_, Typed(Ident(nme.WILDCARD), tpt)), EmptyTree, _) => isSimpleThrowable(tpt.tpe) + case _ => isDefaultCase(cdef) + } + + private def isSimpleThrowable(tp: Type)(implicit ctx: Context): Boolean = tp match { + case tp @ TypeRef(pre, _) => + (pre == NoPrefix || pre.widen.typeSymbol.isStatic) && // Does not require outer class check + !tp.symbol.is(Flags.Trait) && // Traits not supported by JVM + tp.derivesFrom(defn.ThrowableClass) + case _ => + false + } + + private def mkFallbackPatterMatchCase(patternMatchCases: List[CaseDef], pos: Position)( + implicit ctx: Context, info: TransformerInfo): Option[CaseDef] = { + if (patternMatchCases.isEmpty) None + else { + val exName = ctx.freshName("ex").toTermName + val fallbackSelector = + ctx.newSymbol(ctx.owner, exName, Flags.Synthetic | Flags.Case, defn.ThrowableType, coord = pos) + val sel = Ident(fallbackSelector.termRef).withPos(pos) + val rethrow = CaseDef(EmptyTree, EmptyTree, Throw(ref(fallbackSelector))) + Some(CaseDef( + Bind(fallbackSelector, Underscore(fallbackSelector.info).withPos(pos)), + EmptyTree, + transformFollowing(Match(sel, patternMatchCases ::: rethrow :: Nil))) + ) + } + } + +} diff --git a/tests/neg/tryPatternMatchError.scala b/tests/neg/tryPatternMatchError.scala new file mode 100644 index 000000000000..fe12a62329be --- /dev/null +++ b/tests/neg/tryPatternMatchError.scala @@ -0,0 +1,35 @@ +import java.io.IOException +import java.lang.NullPointerException +import java.lang.IllegalArgumentException + +object IAE { + def unapply(e: Exception): Option[String] = + if (e.isInstanceOf[IllegalArgumentException]) Some(e.getMessage) + else None +} + +object EX extends Exception + +trait ExceptionTrait extends Exception + +object Test { + def main(args: Array[String]): Unit = { + var a: Int = 1 + try { + throw new IllegalArgumentException() + } catch { + case e: IOException if e.getMessage == null => + case e: NullPointerException => + case e: IndexOutOfBoundsException => + case _: NoSuchElementException => + case _: ExceptionTrait => + case _: NoSuchElementException if a <= 1 => + case _: NullPointerException | _:IOException => + case `a` => // This case should probably emmit an error + case e: Int => // error + case EX => + case IAE(msg) => + case e: IllegalArgumentException => + } + } +} diff --git a/tests/run/tryPatternMatch.check b/tests/run/tryPatternMatch.check new file mode 100644 index 000000000000..44f7b7d5ac10 --- /dev/null +++ b/tests/run/tryPatternMatch.check @@ -0,0 +1,20 @@ +success 1 +success 2 +success 3 +success 4 +success 5 +success 6 +success 7 +success 8 +success 9.1 +success 9.2 +IllegalArgumentException: abc +IllegalArgumentException +NullPointerException | IOException +NoSuchElementException +EX +InnerException +NullPointerException +ExceptionTrait +ClassCastException +TimeoutException escaped diff --git a/tests/run/tryPatternMatch.scala b/tests/run/tryPatternMatch.scala new file mode 100644 index 000000000000..06b469d4d33d --- /dev/null +++ b/tests/run/tryPatternMatch.scala @@ -0,0 +1,139 @@ +import java.io.IOException +import java.util.concurrent.TimeoutException + +object IAE { + def unapply(e: Exception): Option[String] = + if (e.isInstanceOf[IllegalArgumentException] && e.getMessage != null) Some(e.getMessage) + else None +} + +object EX extends Exception { + val msg = "a" + class InnerException extends Exception(msg) +} + +trait ExceptionTrait extends Exception + +trait TestTrait { + type ExceptionType <: Exception + + def traitTest(): Unit = { + try { + throw new IOException + } catch { + case _: ExceptionType => println("success 9.2") + case _ => println("failed 9.2") + } + } +} + +object Test extends TestTrait { + type ExceptionType = IOException + + def main(args: Array[String]): Unit = { + var a: Int = 1 + + try { + throw new Exception("abc") + } catch { + case _: Exception => println("success 1") + case _ => println("failed 1") + } + + try { + throw new Exception("abc") + } catch { + case e: Exception => println("success 2") + case _ => println("failed 2") + } + + try { + throw new Exception("abc") + } catch { + case e: Exception if e.getMessage == "abc" => println("success 3") + case _ => println("failed 3") + } + + try { + throw new Exception("abc") + } catch { + case e: Exception if e.getMessage == "" => println("failed 4") + case _ => println("success 4") + } + + try { + throw EX + } catch { + case EX => println("success 5") + case _ => println("failed 5") + } + + try { + throw new EX.InnerException + } catch { + case _: EX.InnerException => println("success 6") + case _ => println("failed 6") + } + + try { + throw new NullPointerException + } catch { + case _: NullPointerException | _:IOException => println("success 7") + case _ => println("failed 7") + } + + try { + throw new ExceptionTrait {} + } catch { + case _: ExceptionTrait => println("success 8") + case _ => println("failed 8") + } + + try { + throw new IOException + } catch { + case _: ExceptionType => println("success 9.1") + case _ => println("failed 9.1") + } + + traitTest() // test 9.2 + + def testThrow(throwIt: => Unit): Unit = { + try { + throwIt + } catch { + // These cases will be compiled as catch cases + case e: NullPointerException => println("NullPointerException") + case e: IndexOutOfBoundsException => println("IndexOutOfBoundsException") + case _: NoSuchElementException => println("NoSuchElementException") + case _: EX.InnerException => println("InnerException") + // All the following will be compiled as a match + case IAE(msg) => println("IllegalArgumentException: " + msg) + case _: ExceptionTrait => println("ExceptionTrait") + case e: IOException if e.getMessage == null => println("IOException") + case _: NullPointerException | _:IOException => println("NullPointerException | IOException") + case `a` => println("`a`") + case EX => println("EX") + case e: IllegalArgumentException => println("IllegalArgumentException") + case _: ClassCastException => println("ClassCastException") + } + } + + testThrow(throw new IllegalArgumentException("abc")) + testThrow(throw new IllegalArgumentException()) + testThrow(throw new IOException("abc")) + testThrow(throw new NoSuchElementException()) + testThrow(throw EX) + testThrow(throw new EX.InnerException) + testThrow(throw new NullPointerException()) + testThrow(throw new ExceptionTrait {}) + testThrow(throw a.asInstanceOf[Throwable]) + try { + testThrow(throw new TimeoutException) + println("TimeoutException did not escape") + } catch { + case _: TimeoutException => println("TimeoutException escaped") + } + } + +}