diff --git a/compiler/src/dotty/tools/dotc/Compiler.scala b/compiler/src/dotty/tools/dotc/Compiler.scala index 8ee016117de3..22c6c7f342fe 100644 --- a/compiler/src/dotty/tools/dotc/Compiler.scala +++ b/compiler/src/dotty/tools/dotc/Compiler.scala @@ -49,6 +49,7 @@ class Compiler { List(new FirstTransform, // Some transformations to put trees into a canonical form new CheckReentrant), // Internal use only: Check that compiled program has no data races involving global vars List(new CheckStatic, // Check restrictions that apply to @static members + new CheckPhantomCast, // Checks that no Phantom types in are in casts new ElimRepeated, // Rewrite vararg parameters and arguments new RefChecks, // Various checks mostly related to abstract members and overriding new NormalizeFlags, // Rewrite some definition flags diff --git a/compiler/src/dotty/tools/dotc/transform/CheckPhantomCast.scala b/compiler/src/dotty/tools/dotc/transform/CheckPhantomCast.scala new file mode 100644 index 000000000000..f704ebb6b02e --- /dev/null +++ b/compiler/src/dotty/tools/dotc/transform/CheckPhantomCast.scala @@ -0,0 +1,50 @@ +package dotty.tools.dotc +package transform + +import core._ +import dotty.tools.dotc.transform.TreeTransforms.{MiniPhaseTransform, TransformerInfo} +import Types._ +import Contexts.Context +import Symbols._ +import Decorators._ +import dotty.tools.dotc.ast.Trees._ +import dotty.tools.dotc.ast.tpd + + +/** A no-op transform to ensure that the compiled sources have no Phantom types in casts */ +class CheckPhantomCast extends MiniPhaseTransform { thisTransformer => + + override def phaseName = "checkPhantomCast" + + override def checkPostCondition(tree: tpd.Tree)(implicit ctx: Context): Unit = { + tree match { + case TypeApply(fun, targs) if fun.symbol eq defn.Any_asInstanceOf => assert(!containsPhantom(targs.head.tpe)) + case Bind(_, Typed(_, tpt)) => assert(!containsPhantom(tpt.tpe)) + case _ => + } + } + + override def transformTypeApply(tree: tpd.TypeApply)(implicit ctx: Context, info: TransformerInfo): tpd.Tree = { + if (tree.fun.symbol eq defn.Any_asInstanceOf) + checkNoPhantoms(tree.args.head) + tree + } + + override def transformBind(tree: tpd.Bind)(implicit ctx: Context, info: TransformerInfo): tpd.Tree = { + tree.body match { + case Typed(_, tpt) => checkNoPhantoms(tpt) + case _ => + } + tree + } + + private def checkNoPhantoms(tpTree: tpd.Tree)(implicit ctx: Context): Unit = { + if (containsPhantom(tpTree.tpe)) + ctx.error("Cannot cast type containing a phantom type", tpTree.pos) + } + + private def containsPhantom(tp: Type)(implicit ctx: Context): Boolean = new TypeAccumulator[Boolean] { + override def apply(x: Boolean, tp: Type): Boolean = x || tp.isPhantom || foldOver(false, tp) + }.apply(x = false, tp) + +} diff --git a/tests/neg/phantom-class-type-members.scala b/tests/neg/phantom-class-type-members.scala new file mode 100644 index 000000000000..480c2fb686cd --- /dev/null +++ b/tests/neg/phantom-class-type-members.scala @@ -0,0 +1,47 @@ +import Boo._ + +object Test { + def main(args: Array[String]): Unit = { + val a = new Bar() + foo(a.asInstanceOf[Foo{type T = BooNothing}].y) // error + + a match { + case a: Foo{type T = BooNothing} => a.y // error + } + + val b = new Baz + b.asInstanceOf[Foo{type T = BooAny}].z(any) // error + + b match { + case b: Foo{type T = BooAny} => a.z(any) // error + } + } + + def foo(x: BooNothing) = println("foo") + +} + +abstract class Foo { + type T <: BooAny + def y: T + def z(z: T): Unit +} + +class Bar extends Foo { + type T = BooAny + def y: T = any + def z(z: T) = () +} + +class Baz extends Foo { + type T = BooNothing + def y: T = nothing + def z(z: T) = () +} + +object Boo extends Phantom { + type BooAny = this.Any + type BooNothing = this.Nothing + def any: BooAny = assume + def nothing: BooNothing = assume +} diff --git a/tests/neg/phantom-class-type-parameters.scala b/tests/neg/phantom-class-type-parameters.scala new file mode 100644 index 000000000000..abed46b6e38a --- /dev/null +++ b/tests/neg/phantom-class-type-parameters.scala @@ -0,0 +1,34 @@ +import Boo._ + +object Test { + def main(args: Array[String]): Unit = { + val a = new Foo[BooAny](any) + foo(a.asInstanceOf[Foo[BooNothing]].x) // error + foo(a.asInstanceOf[Foo[BooNothing]].y) // error + + a match { + case a: Foo[BooNothing] => a.x // error + } + + val b = new Foo[BooNothing](a.asInstanceOf[Foo[BooNothing]].x) // error + b.asInstanceOf[Foo[BooAny]].z(any) // error + + b match { + case b: Foo[BooAny] => b.z(any) // error + } + } + + def foo(x: BooNothing) = println("foo") + +} + +class Foo[T <: BooAny](val x: T) { + def y: T = x + def z(z: T) = () +} + +object Boo extends Phantom { + type BooAny = this.Any + type BooNothing = this.Nothing + def any: BooAny = assume +}