diff --git a/compiler/src/dotty/tools/dotc/ast/TreeMapWithImplicits.scala b/compiler/src/dotty/tools/dotc/ast/TreeMapWithImplicits.scala index 98376c497868..5e350ac1eec1 100644 --- a/compiler/src/dotty/tools/dotc/ast/TreeMapWithImplicits.scala +++ b/compiler/src/dotty/tools/dotc/ast/TreeMapWithImplicits.scala @@ -26,7 +26,7 @@ class TreeMapWithImplicits extends tpd.TreeMap { * - be tail-recursive where possible * - don't re-allocate trees where nothing has changed */ - def transformStats(stats: List[Tree], exprOwner: Symbol)(using Context): List[Tree] = { + override def transformStats(stats: List[Tree], exprOwner: Symbol)(using Context): List[Tree] = { @tailrec def traverse(curStats: List[Tree])(using Context): List[Tree] = { @@ -88,8 +88,14 @@ class TreeMapWithImplicits extends tpd.TreeMap { def localCtx = if (tree.hasType && tree.symbol.exists) ctx.withOwner(tree.symbol) else ctx try tree match { - case tree: Block => - super.transform(tree)(using nestedScopeCtx(tree.stats)) + case Block(stats, expr) => + inContext(nestedScopeCtx(stats)) { + if stats.exists(_.isInstanceOf[Import]) then + // need to transform stats and expr together to account for import visibility + val stats1 = transformStats(stats :+ expr, ctx.owner) + cpy.Block(tree)(stats1.init, stats1.last) + else super.transform(tree) + } case tree: DefDef => inContext(localCtx) { cpy.DefDef(tree)( @@ -100,8 +106,10 @@ class TreeMapWithImplicits extends tpd.TreeMap { } case EmptyValDef => tree - case _: PackageDef | _: MemberDef => + case _: MemberDef => super.transform(tree)(using localCtx) + case _: PackageDef => + super.transform(tree)(using ctx.withOwner(tree.symbol.moduleClass)) case impl @ Template(constr, parents, self, _) => cpy.Template(tree)( transformSub(constr), diff --git a/compiler/src/dotty/tools/dotc/ast/TreeTypeMap.scala b/compiler/src/dotty/tools/dotc/ast/TreeTypeMap.scala index fddde05d9a31..50373f044719 100644 --- a/compiler/src/dotty/tools/dotc/ast/TreeTypeMap.scala +++ b/compiler/src/dotty/tools/dotc/ast/TreeTypeMap.scala @@ -130,7 +130,7 @@ class TreeTypeMap( } } - override def transformStats(trees: List[tpd.Tree])(using Context): List[Tree] = + override def transformStats(trees: List[tpd.Tree], exprOwner: Symbol)(using Context): List[Tree] = transformDefs(trees)._2 def transformDefs[TT <: tpd.Tree](trees: List[TT])(using Context): (TreeTypeMap, List[TT]) = { diff --git a/compiler/src/dotty/tools/dotc/ast/Trees.scala b/compiler/src/dotty/tools/dotc/ast/Trees.scala index 23996ce14dfb..ba2dcc2e16d2 100644 --- a/compiler/src/dotty/tools/dotc/ast/Trees.scala +++ b/compiler/src/dotty/tools/dotc/ast/Trees.scala @@ -1332,7 +1332,7 @@ object Trees { case Assign(lhs, rhs) => cpy.Assign(tree)(transform(lhs), transform(rhs)) case Block(stats, expr) => - cpy.Block(tree)(transformStats(stats), transform(expr)) + cpy.Block(tree)(transformStats(stats, ctx.owner), transform(expr)) case If(cond, thenp, elsep) => cpy.If(tree)(transform(cond), transform(thenp), transform(elsep)) case Closure(env, meth, tpt) => @@ -1398,13 +1398,13 @@ object Trees { cpy.TypeDef(tree)(name, transform(rhs)) } case tree @ Template(constr, parents, self, _) if tree.derived.isEmpty => - cpy.Template(tree)(transformSub(constr), transform(tree.parents), Nil, transformSub(self), transformStats(tree.body)) + cpy.Template(tree)(transformSub(constr), transform(tree.parents), Nil, transformSub(self), transformStats(tree.body, tree.symbol)) case Import(expr, selectors) => cpy.Import(tree)(transform(expr), selectors) case Export(expr, selectors) => cpy.Export(tree)(transform(expr), selectors) case PackageDef(pid, stats) => - cpy.PackageDef(tree)(transformSub(pid), transformStats(stats)(using localCtx)) + cpy.PackageDef(tree)(transformSub(pid), transformStats(stats, pid.symbol.moduleClass)(using localCtx)) case Annotated(arg, annot) => cpy.Annotated(tree)(transform(arg), transform(annot)) case Thicket(trees) => @@ -1416,7 +1416,7 @@ object Trees { } } - def transformStats(trees: List[Tree])(using Context): List[Tree] = + def transformStats(trees: List[Tree], exprOwner: Symbol)(using Context): List[Tree] = transform(trees) def transform(trees: List[Tree])(using Context): List[Tree] = flatten(trees mapConserve (transform(_))) diff --git a/compiler/src/dotty/tools/dotc/ast/untpd.scala b/compiler/src/dotty/tools/dotc/ast/untpd.scala index 6db0905c8596..9a5b222617bd 100644 --- a/compiler/src/dotty/tools/dotc/ast/untpd.scala +++ b/compiler/src/dotty/tools/dotc/ast/untpd.scala @@ -659,7 +659,8 @@ object untpd extends Trees.Instance[Untyped] with UntypedTreeInfo { case ModuleDef(name, impl) => cpy.ModuleDef(tree)(name, transformSub(impl)) case tree: DerivingTemplate => - cpy.Template(tree)(transformSub(tree.constr), transform(tree.parents), transform(tree.derived), transformSub(tree.self), transformStats(tree.body)) + cpy.Template(tree)(transformSub(tree.constr), transform(tree.parents), + transform(tree.derived), transformSub(tree.self), transformStats(tree.body, tree.symbol)) case ParsedTry(expr, handler, finalizer) => cpy.ParsedTry(tree)(transform(expr), transform(handler), transform(finalizer)) case SymbolLit(str) => diff --git a/compiler/src/dotty/tools/dotc/core/Contexts.scala b/compiler/src/dotty/tools/dotc/core/Contexts.scala index 03593726da37..a1e4eb6a9ad6 100644 --- a/compiler/src/dotty/tools/dotc/core/Contexts.scala +++ b/compiler/src/dotty/tools/dotc/core/Contexts.scala @@ -538,11 +538,17 @@ object Contexts { case _ => new Typer } - override def toString: String = { - def iinfo(using Context) = if (ctx.importInfo == null) "" else i"${ctx.importInfo.selectors}%, %" - "Context(\n" + - (outersIterator.map(ctx => s" owner = ${ctx.owner}, scope = ${ctx.scope}, import = ${iinfo(using ctx)}").mkString("\n")) - } + override def toString: String = + def iinfo(using Context) = + if (ctx.importInfo == null) "" else i"${ctx.importInfo.selectors}%, %" + def cinfo(using Context) = + val core = s" owner = ${ctx.owner}, scope = ${ctx.scope}, import = ${iinfo(using ctx)}" + if (ctx ne NoContext) && (ctx.implicits ne ctx.outer.implicits) then + s"$core, implicits = ${ctx.implicits}" + else + core + s"""Context( + |${outersIterator.map(ctx => cinfo(using ctx)).mkString("\n\n")})""".stripMargin def settings: ScalaSettings = base.settings def definitions: Definitions = base.definitions diff --git a/compiler/src/dotty/tools/dotc/transform/MacroTransform.scala b/compiler/src/dotty/tools/dotc/transform/MacroTransform.scala index 87a5ef67bf6b..34ba86ce4c9e 100644 --- a/compiler/src/dotty/tools/dotc/transform/MacroTransform.scala +++ b/compiler/src/dotty/tools/dotc/transform/MacroTransform.scala @@ -37,7 +37,7 @@ abstract class MacroTransform extends Phase { ctx.fresh.setTree(tree).setOwner(owner) } - def transformStats(trees: List[Tree], exprOwner: Symbol)(using Context): List[Tree] = { + override def transformStats(trees: List[Tree], exprOwner: Symbol)(using Context): List[Tree] = { def transformStat(stat: Tree): Tree = stat match { case _: Import | _: DefTree => transform(stat) case _ => transform(stat)(using ctx.exprContext(stat, exprOwner)) diff --git a/compiler/src/dotty/tools/dotc/transform/TreeMapWithStages.scala b/compiler/src/dotty/tools/dotc/transform/TreeMapWithStages.scala index f678a1930758..c23e90352d10 100644 --- a/compiler/src/dotty/tools/dotc/transform/TreeMapWithStages.scala +++ b/compiler/src/dotty/tools/dotc/transform/TreeMapWithStages.scala @@ -46,7 +46,7 @@ abstract class TreeMapWithStages(@constructorOnly ictx: Context) extends TreeMap /** The quotation level of the definition of the locally defined symbol */ protected def levelOf(sym: Symbol): Int = levelOfMap.getOrElse(sym, 0) - /** Localy defined symbols seen so far by `StagingTransformer.transform` */ + /** Locally defined symbols seen so far by `StagingTransformer.transform` */ protected def localSymbols: List[Symbol] = enteredSyms /** If we are inside a quote or a splice */ @@ -74,7 +74,7 @@ abstract class TreeMapWithStages(@constructorOnly ictx: Context) extends TreeMap /** Transform the expression splice `splice` which contains the spliced `body`. */ protected def transformSplice(body: Tree, splice: Apply)(using Context): Tree - /** Transform the typee splice `splice` which contains the spliced `body`. */ + /** Transform the type splice `splice` which contains the spliced `body`. */ protected def transformSpliceType(body: Tree, splice: Select)(using Context): Tree override def transform(tree: Tree)(using Context): Tree = @@ -109,7 +109,7 @@ abstract class TreeMapWithStages(@constructorOnly ictx: Context) extends TreeMap try dropEmptyBlocks(quotedTree) match { case Spliced(t) => // '{ $x } --> x - // and adapt the refinment of `Quotes { type tasty: ... } ?=> Expr[T]` + // and adapt the refinement of `Quotes { type reflect: ... } ?=> Expr[T]` transform(t).asInstance(tree.tpe) case _ => transformQuotation(quotedTree, tree) } diff --git a/tests/pos-macros/i11479/Macro_1.scala b/tests/pos-macros/i11479/Macro_1.scala new file mode 100644 index 000000000000..f4a8c0d13767 --- /dev/null +++ b/tests/pos-macros/i11479/Macro_1.scala @@ -0,0 +1,9 @@ +trait Foo +given Foo: Foo with {} +inline def summonFoo(): Foo = scala.compiletime.summonInline[Foo] + +package p: + trait Bar + given Bar: Bar with {} + inline def summonBar(): Bar = scala.compiletime.summonInline[Bar] + diff --git a/tests/pos-macros/i11479/OtherTest_2.scala b/tests/pos-macros/i11479/OtherTest_2.scala new file mode 100644 index 000000000000..8963f4eaa068 --- /dev/null +++ b/tests/pos-macros/i11479/OtherTest_2.scala @@ -0,0 +1,2 @@ +package p +def test3: Unit = summonBar() diff --git a/tests/pos-macros/i11479/Test_2.scala b/tests/pos-macros/i11479/Test_2.scala new file mode 100644 index 000000000000..cec9d7824b01 --- /dev/null +++ b/tests/pos-macros/i11479/Test_2.scala @@ -0,0 +1,8 @@ +import p.{*, given} +def test: Unit = + summonFoo() + summonBar() + + + + diff --git a/tests/pos/i11538a.scala b/tests/pos/i11538a.scala new file mode 100644 index 000000000000..243900c43b44 --- /dev/null +++ b/tests/pos/i11538a.scala @@ -0,0 +1,23 @@ +package a: + + trait Printer[A]: + def print(a: A): Unit + + given Printer[String] with + def print(s: String) = println(s) + +package b: + + import a.{given, *} + + object test: + import scala.compiletime.{error, summonFrom} + + inline def summonStringPrinter = + summonFrom { + case given Printer[String] => () + case _ => error("Couldn't find a printer") + } + + val summoned = summon[Printer[String]] + val summonedFrom = summonStringPrinter diff --git a/tests/pos/i11538b.scala b/tests/pos/i11538b.scala new file mode 100644 index 000000000000..c061910b2e73 --- /dev/null +++ b/tests/pos/i11538b.scala @@ -0,0 +1,9 @@ +package a: + type Foo + given foo: Foo = ??? + +import a.{Foo, given} +object test: + inline def summonInlineFoo = scala.compiletime.summonInline[Foo] + val summoned = summon[Foo] + val summonedInline = summonInlineFoo diff --git a/tests/pos/i11557.scala b/tests/pos/i11557.scala new file mode 100644 index 000000000000..f886a98d1727 --- /dev/null +++ b/tests/pos/i11557.scala @@ -0,0 +1,12 @@ +type MyEncoder + +class MyContext: + given intEncoder: MyEncoder = ??? + +def doEncoding(ctx: MyContext): Unit = + import ctx.{*, given} + summon[MyEncoder] + summonInlineMyEncoder() + +inline def summonInlineMyEncoder(): Unit = + compiletime.summonInline[MyEncoder]