diff --git a/compiler/src/dotty/tools/dotc/staging/CrossStageSafety.scala b/compiler/src/dotty/tools/dotc/staging/CrossStageSafety.scala
index 219b428ca8d4..98e060488f43 100644
--- a/compiler/src/dotty/tools/dotc/staging/CrossStageSafety.scala
+++ b/compiler/src/dotty/tools/dotc/staging/CrossStageSafety.scala
@@ -107,30 +107,37 @@ class CrossStageSafety extends TreeMapWithStages {
val stripAnnotsDeep: TypeMap = new TypeMap:
def apply(tp: Type): Type = mapOver(tp.stripAnnots)
- val contextWithQuote =
- if level == 0 then contextWithQuoteTypeTags(taggedTypes)(using quoteContext)
- else quoteContext
- val body1 = transform(body)(using contextWithQuote)
- val body2 =
+ def transformBody() =
+ val contextWithQuote =
+ if level == 0 then contextWithQuoteTypeTags(taggedTypes)(using quoteContext)
+ else quoteContext
+ val transformedBody = transform(body)(using contextWithQuote)
taggedTypes.getTypeTags match
- case Nil => body1
- case tags => tpd.Block(tags, body1).withSpan(body.span)
+ case Nil => transformedBody
+ case tags => tpd.Block(tags, transformedBody).withSpan(body.span)
if body.isTerm then
+ val transformedBody = transformBody()
// `quoted.runtime.Expr.quote[T](
)` --> `quoted.runtime.Expr.quote[T2]()`
val TypeApply(fun, targs) = quote.fun: @unchecked
val targs2 = targs.map(targ => TypeTree(healType(quote.fun.srcPos)(stripAnnotsDeep(targ.tpe))))
- cpy.Apply(quote)(cpy.TypeApply(quote.fun)(fun, targs2), body2 :: Nil)
+ cpy.Apply(quote)(cpy.TypeApply(quote.fun)(fun, targs2), transformedBody :: Nil)
else
- val quotes = quote.args.mapConserve(transform)
body.tpe match
- case tp @ TypeRef(x: TermRef, _) if tp.symbol == defn.QuotedType_splice =>
+ case DirectTypeOf(termRef) =>
// Optimization: `quoted.Type.of[x.Underlying](quotes)` --> `x`
- ref(x)
+ ref(termRef).withSpan(quote.span)
case _ =>
- // `quoted.Type.of[](quotes)` --> `quoted.Type.of[](quotes)`
- val TypeApply(fun, _) = quote.fun: @unchecked
- cpy.Apply(quote)(cpy.TypeApply(quote.fun)(fun, body2 :: Nil), quotes)
+ transformBody() match
+ case DirectTypeOf.Healed(termRef) =>
+ // Optimization: `quoted.Type.of[@SplicedType type T = x.Underlying; T](quotes)` --> `x`
+ ref(termRef).withSpan(quote.span)
+ case transformedBody =>
+ val quotes = quote.args.mapConserve(transform)
+ // `quoted.Type.of[](quotes)` --> `quoted.Type.of[](quotes)`
+ val TypeApply(fun, _) = quote.fun: @unchecked
+ cpy.Apply(quote)(cpy.TypeApply(quote.fun)(fun, transformedBody :: Nil), quotes)
+
}
/** Transform splice
diff --git a/compiler/src/dotty/tools/dotc/staging/DirectTypeOf.scala b/compiler/src/dotty/tools/dotc/staging/DirectTypeOf.scala
new file mode 100644
index 000000000000..488d8ff2a88e
--- /dev/null
+++ b/compiler/src/dotty/tools/dotc/staging/DirectTypeOf.scala
@@ -0,0 +1,25 @@
+package dotty.tools.dotc.staging
+
+import dotty.tools.dotc.ast.{tpd, untpd}
+import dotty.tools.dotc.core.Contexts._
+import dotty.tools.dotc.core.Symbols._
+import dotty.tools.dotc.core.Types._
+
+object DirectTypeOf:
+ import tpd.*
+
+ /** Matches `x.Underlying` and extracts the TermRef to `x` */
+ def unapply(tpe: Type)(using Context): Option[TermRef] = tpe match
+ case tp @ TypeRef(x: TermRef, _) if tp.symbol == defn.QuotedType_splice => Some(x)
+ case _ => None
+
+ object Healed:
+ /** Matches `{ @SplicedType type T = x.Underlying; T }` and extracts the TermRef to `x` */
+ def unapply(body: Tree)(using Context): Option[TermRef] =
+ body match
+ case Block(List(tdef: TypeDef), tpt: TypeTree) =>
+ tpt.tpe match
+ case tpe: TypeRef if tpe.typeSymbol == tdef.symbol =>
+ DirectTypeOf.unapply(tdef.rhs.tpe.hiBound)
+ case _ => None
+ case _ => None
diff --git a/compiler/src/dotty/tools/dotc/staging/HealType.scala b/compiler/src/dotty/tools/dotc/staging/HealType.scala
index 22008f381c32..7907c2e47542 100644
--- a/compiler/src/dotty/tools/dotc/staging/HealType.scala
+++ b/compiler/src/dotty/tools/dotc/staging/HealType.scala
@@ -32,7 +32,12 @@ class HealType(pos: SrcPos)(using Context) extends TypeMap {
def apply(tp: Type): Type =
tp match
case tp: TypeRef =>
- healTypeRef(tp)
+ tp.underlying match
+ case TypeAlias(alias)
+ if !tp.symbol.isTypeSplice && !tp.typeSymbol.hasAnnotation(defn.QuotedRuntime_SplicedTypeAnnot) =>
+ this.apply(alias)
+ case _ =>
+ healTypeRef(tp)
case tp @ TermRef(NoPrefix, _) if !tp.symbol.isStatic && level > levelOf(tp.symbol) =>
levelError(tp.symbol, tp, pos)
case tp: AnnotatedType =>
@@ -46,11 +51,11 @@ class HealType(pos: SrcPos)(using Context) extends TypeMap {
checkNotWildcardSplice(tp)
if level == 0 then tp else getQuoteTypeTags.getTagRef(prefix)
case prefix: TermRef if !prefix.symbol.isStatic && level > levelOf(prefix.symbol) =>
- dealiasAndTryHeal(prefix.symbol, tp, pos)
+ tryHeal(prefix.symbol, tp, pos)
case NoPrefix if level > levelOf(tp.symbol) && !tp.typeSymbol.hasAnnotation(defn.QuotedRuntime_SplicedTypeAnnot) =>
- dealiasAndTryHeal(tp.symbol, tp, pos)
+ tryHeal(tp.symbol, tp, pos)
case prefix: ThisType if level > levelOf(prefix.cls) && !tp.symbol.isStatic =>
- dealiasAndTryHeal(tp.symbol, tp, pos)
+ tryHeal(tp.symbol, tp, pos)
case _ =>
mapOver(tp)
@@ -59,11 +64,6 @@ class HealType(pos: SrcPos)(using Context) extends TypeMap {
case (tb: TypeBounds) :: _ => report.error(em"Cannot splice $splice because it is a wildcard type", pos)
case _ =>
- private def dealiasAndTryHeal(sym: Symbol, tp: TypeRef, pos: SrcPos): Type =
- val tp1 = tp.dealias
- if tp1 != tp then apply(tp1)
- else tryHeal(tp.symbol, tp, pos)
-
/** Try to heal reference to type `T` used in a higher level than its definition.
* Returns a reference to a type tag generated by `QuoteTypeTags` that contains a
* reference to a type alias containing the equivalent of `${summon[quoted.Type[T]]}`.
diff --git a/tests/pos-macros/i8100b.scala b/tests/pos-macros/i8100b.scala
new file mode 100644
index 000000000000..ecba10e439d2
--- /dev/null
+++ b/tests/pos-macros/i8100b.scala
@@ -0,0 +1,37 @@
+import scala.quoted.*
+
+def f[T](using t: Type[T])(using Quotes) =
+ '{
+ // @SplicedType type t$1 = t.Underlying
+ type T2 = T // type T2 = t$1
+ ${
+
+ val t0: T = ???
+ val t1: T2 = ??? // val t1: T = ???
+ val tp1 = Type.of[T] // val tp1 = t
+ val tp2 = Type.of[T2] // val tp2 = t
+ '{
+ // @SplicedType type t$2 = t.Underlying
+ val t3: T = ??? // val t3: t$2 = ???
+ val t4: T2 = ??? // val t4: t$2 = ???
+ }
+ }
+ }
+
+def g(using Quotes) =
+ '{
+ type U
+ type U2 = U
+ ${
+
+ val u1: U = ???
+ val u2: U2 = ??? // val u2: U = ???
+
+ val tp1 = Type.of[U] // val tp1 = Type.of[U]
+ val tp2 = Type.of[U2] // val tp2 = Type.of[U]
+ '{
+ val u3: U = ???
+ val u4: U2 = ??? // val u4: U = ???
+ }
+ }
+ }
diff --git a/tests/run-macros/i12392.check b/tests/run-macros/i12392.check
index 54c7f5d06c3f..92bbfa65fb49 100644
--- a/tests/run-macros/i12392.check
+++ b/tests/run-macros/i12392.check
@@ -1 +1 @@
-scala.Option[scala.Predef.String] to scala.Option[scala.Int]
+scala.Option[java.lang.String] to scala.Option[scala.Int]
diff --git a/tests/run-staging/quote-nested-3.check b/tests/run-staging/quote-nested-3.check
index 63bdda5c6c4c..c3dfba2d8abe 100644
--- a/tests/run-staging/quote-nested-3.check
+++ b/tests/run-staging/quote-nested-3.check
@@ -1,7 +1,7 @@
{
- type T = scala.Predef.String
+ type T = java.lang.String
val x: java.lang.String = "foo"
- val z: T = x
+ val z: java.lang.String = x
(x: java.lang.String)
}
diff --git a/tests/run-staging/quote-nested-4.check b/tests/run-staging/quote-nested-4.check
index 895bd0ddc914..d31b6394dccd 100644
--- a/tests/run-staging/quote-nested-4.check
+++ b/tests/run-staging/quote-nested-4.check
@@ -1,5 +1,5 @@
((q: scala.quoted.Quotes) ?=> {
- val t: scala.quoted.Type[scala.Predef.String] = scala.quoted.Type.of[scala.Predef.String](q)
+ val t: scala.quoted.Type[java.lang.String] = scala.quoted.Type.of[java.lang.String](q)
- (t: scala.quoted.Type[scala.Predef.String])
+ (t: scala.quoted.Type[java.lang.String])
})
diff --git a/tests/run-staging/quote-nested-6.check b/tests/run-staging/quote-nested-6.check
index 05c2bd4eb00c..2ae8b0d26e47 100644
--- a/tests/run-staging/quote-nested-6.check
+++ b/tests/run-staging/quote-nested-6.check
@@ -1,7 +1,7 @@
{
- type T[X] = scala.List[X]
+ type T[X] = [A >: scala.Nothing <: scala.Any] => scala.collection.immutable.List[A][X]
val x: java.lang.String = "foo"
- val z: T[scala.Predef.String] = scala.List.apply[java.lang.String](x)
+ val z: [X >: scala.Nothing <: scala.Any] => scala.collection.immutable.List[X][java.lang.String] = scala.List.apply[java.lang.String](x)
(x: java.lang.String)
}
diff --git a/tests/run-staging/quote-owners-2.check b/tests/run-staging/quote-owners-2.check
index 323ce64b7bc7..49c09271779c 100644
--- a/tests/run-staging/quote-owners-2.check
+++ b/tests/run-staging/quote-owners-2.check
@@ -2,7 +2,7 @@
def ff: scala.Int = {
val a: scala.collection.immutable.List[scala.Int] = {
type T = scala.collection.immutable.List[scala.Int]
- val b: T = scala.Nil.::[scala.Int](3)
+ val b: scala.collection.immutable.List[scala.Int] = scala.Nil.::[scala.Int](3)
(b: scala.collection.immutable.List[scala.Int])
}
diff --git a/tests/run-staging/quote-unrolled-foreach.check b/tests/run-staging/quote-unrolled-foreach.check
index 8e58ab8eed51..3a72cd1b1311 100644
--- a/tests/run-staging/quote-unrolled-foreach.check
+++ b/tests/run-staging/quote-unrolled-foreach.check
@@ -8,7 +8,7 @@
}
})
-((arr: scala.Array[scala.Predef.String], f: scala.Function1[scala.Predef.String, scala.Unit]) => {
+((arr: scala.Array[java.lang.String], f: scala.Function1[java.lang.String, scala.Unit]) => {
val size: scala.Int = arr.length
var i: scala.Int = 0
while (i.<(size)) {
@@ -18,7 +18,7 @@
}
})
-((arr: scala.Array[scala.Predef.String], f: scala.Function1[scala.Predef.String, scala.Unit]) => {
+((arr: scala.Array[java.lang.String], f: scala.Function1[java.lang.String, scala.Unit]) => {
val size: scala.Int = arr.length
var i: scala.Int = 0
while (i.<(size)) {
@@ -41,7 +41,7 @@
((arr: scala.Array[scala.Int], f: scala.Function1[scala.Int, scala.Unit]) => {
val size: scala.Int = arr.length
var i: scala.Int = 0
- if (size.%(3).!=(0)) throw new scala.Exception("...") else ()
+ if (size.%(3).!=(0)) throw new java.lang.Exception("...") else ()
while (i.<(size)) {
f.apply(arr.apply(i))
f.apply(arr.apply(i.+(1)))
@@ -53,7 +53,7 @@
((arr: scala.Array[scala.Int], f: scala.Function1[scala.Int, scala.Unit]) => {
val size: scala.Int = arr.length
var i: scala.Int = 0
- if (size.%(4).!=(0)) throw new scala.Exception("...") else ()
+ if (size.%(4).!=(0)) throw new java.lang.Exception("...") else ()
while (i.<(size)) {
f.apply(arr.apply(i.+(0)))
f.apply(arr.apply(i.+(1)))
diff --git a/tests/run-staging/shonan-hmm-simple.check b/tests/run-staging/shonan-hmm-simple.check
index da437646482d..cbef88812dcd 100644
--- a/tests/run-staging/shonan-hmm-simple.check
+++ b/tests/run-staging/shonan-hmm-simple.check
@@ -6,7 +6,7 @@ Complex(4,3)
10
((arr1: scala.Array[scala.Int], arr2: scala.Array[scala.Int]) => {
- if (arr1.length.!=(arr2.length)) throw new scala.Exception("...") else ()
+ if (arr1.length.!=(arr2.length)) throw new java.lang.Exception("...") else ()
var sum: scala.Int = 0
var i: scala.Int = 0
while (i.<(scala.Predef.intArrayOps(arr1).size)) {
@@ -22,13 +22,13 @@ Complex(4,3)
10
((arr: scala.Array[scala.Int]) => {
- if (arr.length.!=(5)) throw new scala.Exception("...") else ()
+ if (arr.length.!=(5)) throw new java.lang.Exception("...") else ()
arr.apply(0).+(arr.apply(2)).+(arr.apply(4))
})
10
((arr: scala.Array[Complex[scala.Int]]) => {
- if (arr.length.!=(4)) throw new scala.Exception("...") else ()
+ if (arr.length.!=(4)) throw new java.lang.Exception("...") else ()
Complex.apply[scala.Int](0.-(arr.apply(0).im).+(0.-(arr.apply(2).im)).+(arr.apply(3).re.*(2)), arr.apply(0).re.+(arr.apply(2).re).+(arr.apply(3).im.*(2)))
})
Complex(4,3)
diff --git a/tests/run-staging/shonan-hmm.check b/tests/run-staging/shonan-hmm.check
index fa5206904962..9cb77f850155 100644
--- a/tests/run-staging/shonan-hmm.check
+++ b/tests/run-staging/shonan-hmm.check
@@ -35,8 +35,8 @@ List(25, 30, 20, 43, 44)
((vout: scala.Array[scala.Int], a: scala.Array[scala.Array[scala.Int]], v: scala.Array[scala.Int]) => {
- if (3.!=(vout.length)) throw new scala.IndexOutOfBoundsException("3") else ()
- if (2.!=(v.length)) throw new scala.IndexOutOfBoundsException("2") else ()
+ if (3.!=(vout.length)) throw new java.lang.IndexOutOfBoundsException("3") else ()
+ if (2.!=(v.length)) throw new java.lang.IndexOutOfBoundsException("2") else ()
vout.update(0, 0.+(v.apply(0).*(a.apply(0).apply(0))).+(v.apply(1).*(a.apply(0).apply(1))))
vout.update(1, 0.+(v.apply(0).*(a.apply(1).apply(0))).+(v.apply(1).*(a.apply(1).apply(1))))
vout.update(2, 0.+(v.apply(0).*(a.apply(2).apply(0))).+(v.apply(1).*(a.apply(2).apply(1))))
@@ -95,8 +95,8 @@ List(25, 30, 20, 43, 44)
array
}
((vout: scala.Array[scala.Int], v: scala.Array[scala.Int]) => {
- if (5.!=(vout.length)) throw new scala.IndexOutOfBoundsException("5") else ()
- if (5.!=(v.length)) throw new scala.IndexOutOfBoundsException("5") else ()
+ if (5.!=(vout.length)) throw new java.lang.IndexOutOfBoundsException("5") else ()
+ if (5.!=(v.length)) throw new java.lang.IndexOutOfBoundsException("5") else ()
vout.update(0, 0.+(v.apply(0).*(5)).+(v.apply(1).*(0)).+(v.apply(2).*(0)).+(v.apply(3).*(5)).+(v.apply(4).*(0)))
vout.update(1, 0.+(v.apply(0).*(0)).+(v.apply(1).*(0)).+(v.apply(2).*(10)).+(v.apply(3).*(0)).+(v.apply(4).*(0)))
vout.update(2, 0.+(v.apply(0).*(0)).+(v.apply(1).*(10)).+(v.apply(2).*(0)).+(v.apply(3).*(0)).+(v.apply(4).*(0)))
@@ -158,8 +158,8 @@ List(25, 30, 20, 43, 44)
array
}
((vout: scala.Array[scala.Int], v: scala.Array[scala.Int]) => {
- if (5.!=(vout.length)) throw new scala.IndexOutOfBoundsException("5") else ()
- if (5.!=(v.length)) throw new scala.IndexOutOfBoundsException("5") else ()
+ if (5.!=(vout.length)) throw new java.lang.IndexOutOfBoundsException("5") else ()
+ if (5.!=(v.length)) throw new java.lang.IndexOutOfBoundsException("5") else ()
vout.update(0, v.apply(0).*(5).+(v.apply(3).*(5)))
vout.update(1, v.apply(2).*(10))
vout.update(2, v.apply(1).*(10))
@@ -221,8 +221,8 @@ List(25, 30, 20, 43, 44)
array
}
((vout: scala.Array[scala.Int], v: scala.Array[scala.Int]) => {
- if (5.!=(vout.length)) throw new scala.IndexOutOfBoundsException("5") else ()
- if (5.!=(v.length)) throw new scala.IndexOutOfBoundsException("5") else ()
+ if (5.!=(vout.length)) throw new java.lang.IndexOutOfBoundsException("5") else ()
+ if (5.!=(v.length)) throw new java.lang.IndexOutOfBoundsException("5") else ()
vout.update(0, v.apply(0).*(5).+(v.apply(3).*(5)))
vout.update(1, v.apply(2).*(10))
vout.update(2, v.apply(1).*(10))
@@ -243,8 +243,8 @@ List(25, 30, 20, 43, 44)
((vout: scala.Array[scala.Int], v: scala.Array[scala.Int]) => {
- if (5.!=(vout.length)) throw new scala.IndexOutOfBoundsException("5") else ()
- if (5.!=(v.length)) throw new scala.IndexOutOfBoundsException("5") else ()
+ if (5.!=(vout.length)) throw new java.lang.IndexOutOfBoundsException("5") else ()
+ if (5.!=(v.length)) throw new java.lang.IndexOutOfBoundsException("5") else ()
vout.update(0, v.apply(0).*(5).+(v.apply(3).*(5)))
vout.update(1, v.apply(2).*(10))
vout.update(2, v.apply(1).*(10))
@@ -282,8 +282,8 @@ List(25, 30, 20, 43, 44)
array
}
((vout: scala.Array[scala.Int], v: scala.Array[scala.Int]) => {
- if (5.!=(vout.length)) throw new scala.IndexOutOfBoundsException("5") else ()
- if (5.!=(v.length)) throw new scala.IndexOutOfBoundsException("5") else ()
+ if (5.!=(vout.length)) throw new java.lang.IndexOutOfBoundsException("5") else ()
+ if (5.!=(v.length)) throw new java.lang.IndexOutOfBoundsException("5") else ()
vout.update(0, v.apply(0).*(5).+(v.apply(3).*(5)))
vout.update(1, v.apply(2).*(10))
vout.update(2, v.apply(1).*(10))