diff --git a/compiler/src/dotty/tools/dotc/staging/CrossStageSafety.scala b/compiler/src/dotty/tools/dotc/staging/CrossStageSafety.scala index 219b428ca8d4..98e060488f43 100644 --- a/compiler/src/dotty/tools/dotc/staging/CrossStageSafety.scala +++ b/compiler/src/dotty/tools/dotc/staging/CrossStageSafety.scala @@ -107,30 +107,37 @@ class CrossStageSafety extends TreeMapWithStages { val stripAnnotsDeep: TypeMap = new TypeMap: def apply(tp: Type): Type = mapOver(tp.stripAnnots) - val contextWithQuote = - if level == 0 then contextWithQuoteTypeTags(taggedTypes)(using quoteContext) - else quoteContext - val body1 = transform(body)(using contextWithQuote) - val body2 = + def transformBody() = + val contextWithQuote = + if level == 0 then contextWithQuoteTypeTags(taggedTypes)(using quoteContext) + else quoteContext + val transformedBody = transform(body)(using contextWithQuote) taggedTypes.getTypeTags match - case Nil => body1 - case tags => tpd.Block(tags, body1).withSpan(body.span) + case Nil => transformedBody + case tags => tpd.Block(tags, transformedBody).withSpan(body.span) if body.isTerm then + val transformedBody = transformBody() // `quoted.runtime.Expr.quote[T]()` --> `quoted.runtime.Expr.quote[T2]()` val TypeApply(fun, targs) = quote.fun: @unchecked val targs2 = targs.map(targ => TypeTree(healType(quote.fun.srcPos)(stripAnnotsDeep(targ.tpe)))) - cpy.Apply(quote)(cpy.TypeApply(quote.fun)(fun, targs2), body2 :: Nil) + cpy.Apply(quote)(cpy.TypeApply(quote.fun)(fun, targs2), transformedBody :: Nil) else - val quotes = quote.args.mapConserve(transform) body.tpe match - case tp @ TypeRef(x: TermRef, _) if tp.symbol == defn.QuotedType_splice => + case DirectTypeOf(termRef) => // Optimization: `quoted.Type.of[x.Underlying](quotes)` --> `x` - ref(x) + ref(termRef).withSpan(quote.span) case _ => - // `quoted.Type.of[](quotes)` --> `quoted.Type.of[](quotes)` - val TypeApply(fun, _) = quote.fun: @unchecked - cpy.Apply(quote)(cpy.TypeApply(quote.fun)(fun, body2 :: Nil), quotes) + transformBody() match + case DirectTypeOf.Healed(termRef) => + // Optimization: `quoted.Type.of[@SplicedType type T = x.Underlying; T](quotes)` --> `x` + ref(termRef).withSpan(quote.span) + case transformedBody => + val quotes = quote.args.mapConserve(transform) + // `quoted.Type.of[](quotes)` --> `quoted.Type.of[](quotes)` + val TypeApply(fun, _) = quote.fun: @unchecked + cpy.Apply(quote)(cpy.TypeApply(quote.fun)(fun, transformedBody :: Nil), quotes) + } /** Transform splice diff --git a/compiler/src/dotty/tools/dotc/staging/DirectTypeOf.scala b/compiler/src/dotty/tools/dotc/staging/DirectTypeOf.scala new file mode 100644 index 000000000000..488d8ff2a88e --- /dev/null +++ b/compiler/src/dotty/tools/dotc/staging/DirectTypeOf.scala @@ -0,0 +1,25 @@ +package dotty.tools.dotc.staging + +import dotty.tools.dotc.ast.{tpd, untpd} +import dotty.tools.dotc.core.Contexts._ +import dotty.tools.dotc.core.Symbols._ +import dotty.tools.dotc.core.Types._ + +object DirectTypeOf: + import tpd.* + + /** Matches `x.Underlying` and extracts the TermRef to `x` */ + def unapply(tpe: Type)(using Context): Option[TermRef] = tpe match + case tp @ TypeRef(x: TermRef, _) if tp.symbol == defn.QuotedType_splice => Some(x) + case _ => None + + object Healed: + /** Matches `{ @SplicedType type T = x.Underlying; T }` and extracts the TermRef to `x` */ + def unapply(body: Tree)(using Context): Option[TermRef] = + body match + case Block(List(tdef: TypeDef), tpt: TypeTree) => + tpt.tpe match + case tpe: TypeRef if tpe.typeSymbol == tdef.symbol => + DirectTypeOf.unapply(tdef.rhs.tpe.hiBound) + case _ => None + case _ => None diff --git a/compiler/src/dotty/tools/dotc/staging/HealType.scala b/compiler/src/dotty/tools/dotc/staging/HealType.scala index 22008f381c32..7907c2e47542 100644 --- a/compiler/src/dotty/tools/dotc/staging/HealType.scala +++ b/compiler/src/dotty/tools/dotc/staging/HealType.scala @@ -32,7 +32,12 @@ class HealType(pos: SrcPos)(using Context) extends TypeMap { def apply(tp: Type): Type = tp match case tp: TypeRef => - healTypeRef(tp) + tp.underlying match + case TypeAlias(alias) + if !tp.symbol.isTypeSplice && !tp.typeSymbol.hasAnnotation(defn.QuotedRuntime_SplicedTypeAnnot) => + this.apply(alias) + case _ => + healTypeRef(tp) case tp @ TermRef(NoPrefix, _) if !tp.symbol.isStatic && level > levelOf(tp.symbol) => levelError(tp.symbol, tp, pos) case tp: AnnotatedType => @@ -46,11 +51,11 @@ class HealType(pos: SrcPos)(using Context) extends TypeMap { checkNotWildcardSplice(tp) if level == 0 then tp else getQuoteTypeTags.getTagRef(prefix) case prefix: TermRef if !prefix.symbol.isStatic && level > levelOf(prefix.symbol) => - dealiasAndTryHeal(prefix.symbol, tp, pos) + tryHeal(prefix.symbol, tp, pos) case NoPrefix if level > levelOf(tp.symbol) && !tp.typeSymbol.hasAnnotation(defn.QuotedRuntime_SplicedTypeAnnot) => - dealiasAndTryHeal(tp.symbol, tp, pos) + tryHeal(tp.symbol, tp, pos) case prefix: ThisType if level > levelOf(prefix.cls) && !tp.symbol.isStatic => - dealiasAndTryHeal(tp.symbol, tp, pos) + tryHeal(tp.symbol, tp, pos) case _ => mapOver(tp) @@ -59,11 +64,6 @@ class HealType(pos: SrcPos)(using Context) extends TypeMap { case (tb: TypeBounds) :: _ => report.error(em"Cannot splice $splice because it is a wildcard type", pos) case _ => - private def dealiasAndTryHeal(sym: Symbol, tp: TypeRef, pos: SrcPos): Type = - val tp1 = tp.dealias - if tp1 != tp then apply(tp1) - else tryHeal(tp.symbol, tp, pos) - /** Try to heal reference to type `T` used in a higher level than its definition. * Returns a reference to a type tag generated by `QuoteTypeTags` that contains a * reference to a type alias containing the equivalent of `${summon[quoted.Type[T]]}`. diff --git a/tests/pos-macros/i8100b.scala b/tests/pos-macros/i8100b.scala new file mode 100644 index 000000000000..ecba10e439d2 --- /dev/null +++ b/tests/pos-macros/i8100b.scala @@ -0,0 +1,37 @@ +import scala.quoted.* + +def f[T](using t: Type[T])(using Quotes) = + '{ + // @SplicedType type t$1 = t.Underlying + type T2 = T // type T2 = t$1 + ${ + + val t0: T = ??? + val t1: T2 = ??? // val t1: T = ??? + val tp1 = Type.of[T] // val tp1 = t + val tp2 = Type.of[T2] // val tp2 = t + '{ + // @SplicedType type t$2 = t.Underlying + val t3: T = ??? // val t3: t$2 = ??? + val t4: T2 = ??? // val t4: t$2 = ??? + } + } + } + +def g(using Quotes) = + '{ + type U + type U2 = U + ${ + + val u1: U = ??? + val u2: U2 = ??? // val u2: U = ??? + + val tp1 = Type.of[U] // val tp1 = Type.of[U] + val tp2 = Type.of[U2] // val tp2 = Type.of[U] + '{ + val u3: U = ??? + val u4: U2 = ??? // val u4: U = ??? + } + } + } diff --git a/tests/run-macros/i12392.check b/tests/run-macros/i12392.check index 54c7f5d06c3f..92bbfa65fb49 100644 --- a/tests/run-macros/i12392.check +++ b/tests/run-macros/i12392.check @@ -1 +1 @@ -scala.Option[scala.Predef.String] to scala.Option[scala.Int] +scala.Option[java.lang.String] to scala.Option[scala.Int] diff --git a/tests/run-staging/quote-nested-3.check b/tests/run-staging/quote-nested-3.check index 63bdda5c6c4c..c3dfba2d8abe 100644 --- a/tests/run-staging/quote-nested-3.check +++ b/tests/run-staging/quote-nested-3.check @@ -1,7 +1,7 @@ { - type T = scala.Predef.String + type T = java.lang.String val x: java.lang.String = "foo" - val z: T = x + val z: java.lang.String = x (x: java.lang.String) } diff --git a/tests/run-staging/quote-nested-4.check b/tests/run-staging/quote-nested-4.check index 895bd0ddc914..d31b6394dccd 100644 --- a/tests/run-staging/quote-nested-4.check +++ b/tests/run-staging/quote-nested-4.check @@ -1,5 +1,5 @@ ((q: scala.quoted.Quotes) ?=> { - val t: scala.quoted.Type[scala.Predef.String] = scala.quoted.Type.of[scala.Predef.String](q) + val t: scala.quoted.Type[java.lang.String] = scala.quoted.Type.of[java.lang.String](q) - (t: scala.quoted.Type[scala.Predef.String]) + (t: scala.quoted.Type[java.lang.String]) }) diff --git a/tests/run-staging/quote-nested-6.check b/tests/run-staging/quote-nested-6.check index 05c2bd4eb00c..2ae8b0d26e47 100644 --- a/tests/run-staging/quote-nested-6.check +++ b/tests/run-staging/quote-nested-6.check @@ -1,7 +1,7 @@ { - type T[X] = scala.List[X] + type T[X] = [A >: scala.Nothing <: scala.Any] => scala.collection.immutable.List[A][X] val x: java.lang.String = "foo" - val z: T[scala.Predef.String] = scala.List.apply[java.lang.String](x) + val z: [X >: scala.Nothing <: scala.Any] => scala.collection.immutable.List[X][java.lang.String] = scala.List.apply[java.lang.String](x) (x: java.lang.String) } diff --git a/tests/run-staging/quote-owners-2.check b/tests/run-staging/quote-owners-2.check index 323ce64b7bc7..49c09271779c 100644 --- a/tests/run-staging/quote-owners-2.check +++ b/tests/run-staging/quote-owners-2.check @@ -2,7 +2,7 @@ def ff: scala.Int = { val a: scala.collection.immutable.List[scala.Int] = { type T = scala.collection.immutable.List[scala.Int] - val b: T = scala.Nil.::[scala.Int](3) + val b: scala.collection.immutable.List[scala.Int] = scala.Nil.::[scala.Int](3) (b: scala.collection.immutable.List[scala.Int]) } diff --git a/tests/run-staging/quote-unrolled-foreach.check b/tests/run-staging/quote-unrolled-foreach.check index 8e58ab8eed51..3a72cd1b1311 100644 --- a/tests/run-staging/quote-unrolled-foreach.check +++ b/tests/run-staging/quote-unrolled-foreach.check @@ -8,7 +8,7 @@ } }) -((arr: scala.Array[scala.Predef.String], f: scala.Function1[scala.Predef.String, scala.Unit]) => { +((arr: scala.Array[java.lang.String], f: scala.Function1[java.lang.String, scala.Unit]) => { val size: scala.Int = arr.length var i: scala.Int = 0 while (i.<(size)) { @@ -18,7 +18,7 @@ } }) -((arr: scala.Array[scala.Predef.String], f: scala.Function1[scala.Predef.String, scala.Unit]) => { +((arr: scala.Array[java.lang.String], f: scala.Function1[java.lang.String, scala.Unit]) => { val size: scala.Int = arr.length var i: scala.Int = 0 while (i.<(size)) { @@ -41,7 +41,7 @@ ((arr: scala.Array[scala.Int], f: scala.Function1[scala.Int, scala.Unit]) => { val size: scala.Int = arr.length var i: scala.Int = 0 - if (size.%(3).!=(0)) throw new scala.Exception("...") else () + if (size.%(3).!=(0)) throw new java.lang.Exception("...") else () while (i.<(size)) { f.apply(arr.apply(i)) f.apply(arr.apply(i.+(1))) @@ -53,7 +53,7 @@ ((arr: scala.Array[scala.Int], f: scala.Function1[scala.Int, scala.Unit]) => { val size: scala.Int = arr.length var i: scala.Int = 0 - if (size.%(4).!=(0)) throw new scala.Exception("...") else () + if (size.%(4).!=(0)) throw new java.lang.Exception("...") else () while (i.<(size)) { f.apply(arr.apply(i.+(0))) f.apply(arr.apply(i.+(1))) diff --git a/tests/run-staging/shonan-hmm-simple.check b/tests/run-staging/shonan-hmm-simple.check index da437646482d..cbef88812dcd 100644 --- a/tests/run-staging/shonan-hmm-simple.check +++ b/tests/run-staging/shonan-hmm-simple.check @@ -6,7 +6,7 @@ Complex(4,3) 10 ((arr1: scala.Array[scala.Int], arr2: scala.Array[scala.Int]) => { - if (arr1.length.!=(arr2.length)) throw new scala.Exception("...") else () + if (arr1.length.!=(arr2.length)) throw new java.lang.Exception("...") else () var sum: scala.Int = 0 var i: scala.Int = 0 while (i.<(scala.Predef.intArrayOps(arr1).size)) { @@ -22,13 +22,13 @@ Complex(4,3) 10 ((arr: scala.Array[scala.Int]) => { - if (arr.length.!=(5)) throw new scala.Exception("...") else () + if (arr.length.!=(5)) throw new java.lang.Exception("...") else () arr.apply(0).+(arr.apply(2)).+(arr.apply(4)) }) 10 ((arr: scala.Array[Complex[scala.Int]]) => { - if (arr.length.!=(4)) throw new scala.Exception("...") else () + if (arr.length.!=(4)) throw new java.lang.Exception("...") else () Complex.apply[scala.Int](0.-(arr.apply(0).im).+(0.-(arr.apply(2).im)).+(arr.apply(3).re.*(2)), arr.apply(0).re.+(arr.apply(2).re).+(arr.apply(3).im.*(2))) }) Complex(4,3) diff --git a/tests/run-staging/shonan-hmm.check b/tests/run-staging/shonan-hmm.check index fa5206904962..9cb77f850155 100644 --- a/tests/run-staging/shonan-hmm.check +++ b/tests/run-staging/shonan-hmm.check @@ -35,8 +35,8 @@ List(25, 30, 20, 43, 44) ((vout: scala.Array[scala.Int], a: scala.Array[scala.Array[scala.Int]], v: scala.Array[scala.Int]) => { - if (3.!=(vout.length)) throw new scala.IndexOutOfBoundsException("3") else () - if (2.!=(v.length)) throw new scala.IndexOutOfBoundsException("2") else () + if (3.!=(vout.length)) throw new java.lang.IndexOutOfBoundsException("3") else () + if (2.!=(v.length)) throw new java.lang.IndexOutOfBoundsException("2") else () vout.update(0, 0.+(v.apply(0).*(a.apply(0).apply(0))).+(v.apply(1).*(a.apply(0).apply(1)))) vout.update(1, 0.+(v.apply(0).*(a.apply(1).apply(0))).+(v.apply(1).*(a.apply(1).apply(1)))) vout.update(2, 0.+(v.apply(0).*(a.apply(2).apply(0))).+(v.apply(1).*(a.apply(2).apply(1)))) @@ -95,8 +95,8 @@ List(25, 30, 20, 43, 44) array } ((vout: scala.Array[scala.Int], v: scala.Array[scala.Int]) => { - if (5.!=(vout.length)) throw new scala.IndexOutOfBoundsException("5") else () - if (5.!=(v.length)) throw new scala.IndexOutOfBoundsException("5") else () + if (5.!=(vout.length)) throw new java.lang.IndexOutOfBoundsException("5") else () + if (5.!=(v.length)) throw new java.lang.IndexOutOfBoundsException("5") else () vout.update(0, 0.+(v.apply(0).*(5)).+(v.apply(1).*(0)).+(v.apply(2).*(0)).+(v.apply(3).*(5)).+(v.apply(4).*(0))) vout.update(1, 0.+(v.apply(0).*(0)).+(v.apply(1).*(0)).+(v.apply(2).*(10)).+(v.apply(3).*(0)).+(v.apply(4).*(0))) vout.update(2, 0.+(v.apply(0).*(0)).+(v.apply(1).*(10)).+(v.apply(2).*(0)).+(v.apply(3).*(0)).+(v.apply(4).*(0))) @@ -158,8 +158,8 @@ List(25, 30, 20, 43, 44) array } ((vout: scala.Array[scala.Int], v: scala.Array[scala.Int]) => { - if (5.!=(vout.length)) throw new scala.IndexOutOfBoundsException("5") else () - if (5.!=(v.length)) throw new scala.IndexOutOfBoundsException("5") else () + if (5.!=(vout.length)) throw new java.lang.IndexOutOfBoundsException("5") else () + if (5.!=(v.length)) throw new java.lang.IndexOutOfBoundsException("5") else () vout.update(0, v.apply(0).*(5).+(v.apply(3).*(5))) vout.update(1, v.apply(2).*(10)) vout.update(2, v.apply(1).*(10)) @@ -221,8 +221,8 @@ List(25, 30, 20, 43, 44) array } ((vout: scala.Array[scala.Int], v: scala.Array[scala.Int]) => { - if (5.!=(vout.length)) throw new scala.IndexOutOfBoundsException("5") else () - if (5.!=(v.length)) throw new scala.IndexOutOfBoundsException("5") else () + if (5.!=(vout.length)) throw new java.lang.IndexOutOfBoundsException("5") else () + if (5.!=(v.length)) throw new java.lang.IndexOutOfBoundsException("5") else () vout.update(0, v.apply(0).*(5).+(v.apply(3).*(5))) vout.update(1, v.apply(2).*(10)) vout.update(2, v.apply(1).*(10)) @@ -243,8 +243,8 @@ List(25, 30, 20, 43, 44) ((vout: scala.Array[scala.Int], v: scala.Array[scala.Int]) => { - if (5.!=(vout.length)) throw new scala.IndexOutOfBoundsException("5") else () - if (5.!=(v.length)) throw new scala.IndexOutOfBoundsException("5") else () + if (5.!=(vout.length)) throw new java.lang.IndexOutOfBoundsException("5") else () + if (5.!=(v.length)) throw new java.lang.IndexOutOfBoundsException("5") else () vout.update(0, v.apply(0).*(5).+(v.apply(3).*(5))) vout.update(1, v.apply(2).*(10)) vout.update(2, v.apply(1).*(10)) @@ -282,8 +282,8 @@ List(25, 30, 20, 43, 44) array } ((vout: scala.Array[scala.Int], v: scala.Array[scala.Int]) => { - if (5.!=(vout.length)) throw new scala.IndexOutOfBoundsException("5") else () - if (5.!=(v.length)) throw new scala.IndexOutOfBoundsException("5") else () + if (5.!=(vout.length)) throw new java.lang.IndexOutOfBoundsException("5") else () + if (5.!=(v.length)) throw new java.lang.IndexOutOfBoundsException("5") else () vout.update(0, v.apply(0).*(5).+(v.apply(3).*(5))) vout.update(1, v.apply(2).*(10)) vout.update(2, v.apply(1).*(10))