diff --git a/compiler/src/dotty/tools/dotc/ast/tpd.scala b/compiler/src/dotty/tools/dotc/ast/tpd.scala index 2f46106e1230..c4f6284c0e13 100644 --- a/compiler/src/dotty/tools/dotc/ast/tpd.scala +++ b/compiler/src/dotty/tools/dotc/ast/tpd.scala @@ -799,13 +799,29 @@ object tpd extends Trees.Instance[Type] with TypedTreeInfo { * owner to `to`, and continue until a non-weak owner is reached. */ def changeOwner(from: Symbol, to: Symbol)(using Context): ThisTree = { - @tailrec def loop(from: Symbol, froms: List[Symbol], tos: List[Symbol]): ThisTree = - if (from.isWeakOwner && !from.owner.isClass) - loop(from.owner, from :: froms, to :: tos) - else - //println(i"change owner ${from :: froms}%, % ==> $tos of $tree") - TreeTypeMap(oldOwners = from :: froms, newOwners = tos).apply(tree) - if (from == to) tree else loop(from, Nil, to :: Nil) + changeOwners(List(from),to) + } + + /** Change owner from all `froms` to `to`. If `from` is a weak owner, also change its + * owner to `to`, and continue until a non-weak owner is reached. + */ + def changeOwners(froms: List[Symbol], to: Symbol)(using Context): ThisTree = { + @tailrec def loop(froms: List[Symbol], processedFroms: List[Symbol], tos: List[Symbol]): ThisTree = + froms match + case from::rest => + if (from == to) + loop(rest, processedFroms, tos) + else + if (from.isWeakOwner && !from.owner.isClass) + loop(from.owner::rest, from :: processedFroms, to :: tos) + else + loop(rest, from::processedFroms, to :: tos) + case Nil => + if (processedFroms.isEmpty) + tree + else + TreeTypeMap(oldOwners = processedFroms, newOwners = tos).apply(tree) + loop(froms, Nil, Nil) } /** diff --git a/compiler/src/dotty/tools/dotc/quoted/QuoteUtils.scala b/compiler/src/dotty/tools/dotc/quoted/QuoteUtils.scala index 56c8d3347205..e0635792bb6e 100644 --- a/compiler/src/dotty/tools/dotc/quoted/QuoteUtils.scala +++ b/compiler/src/dotty/tools/dotc/quoted/QuoteUtils.scala @@ -9,24 +9,23 @@ import dotty.tools.dotc.core.Symbols._ object QuoteUtils: import tpd._ - /** Get the owner of a tree if it has one */ - def treeOwner(tree: Tree)(using Context): Option[Symbol] = { - val getCurrentOwner = new TreeAccumulator[Option[Symbol]] { - def apply(x: Option[Symbol], tree: tpd.Tree)(using Context): Option[Symbol] = - if (x.isDefined) x - else tree match { - case tree: DefTree => Some(tree.symbol.owner) - case _ => foldOver(x, tree) + /** Get the list of owners of a tree if it has one */ + def treeOwners(tree: Tree)(using Context): List[Symbol] = { + val getOwners = new TreeAccumulator[Map[Int,Symbol]] { + def apply(x: Map[Int,Symbol], tree: tpd.Tree)(using Context): Map[Int,Symbol] = + tree match { + case tree: DefTree => val owner = tree.symbol.owner + x.updated(owner.id, owner) + case _ => foldOver(x,tree) } } - getCurrentOwner(None, tree) + getOwners(Map.empty,tree).values.toList } + /** Changes the owner of the tree based on the current owner of the tree */ def changeOwnerOfTree(tree: Tree, owner: Symbol)(using Context): Tree = { - treeOwner(tree) match - case Some(oldOwner) if oldOwner != owner => tree.changeOwner(oldOwner, owner) - case _ => tree + tree.changeOwners(treeOwners(tree), owner) } end QuoteUtils diff --git a/tests/pos-macros/i10151/Macro_1.scala b/tests/pos-macros/i10151/Macro_1.scala new file mode 100644 index 000000000000..141da4856be6 --- /dev/null +++ b/tests/pos-macros/i10151/Macro_1.scala @@ -0,0 +1,93 @@ +package x + +import scala.quoted._ + +trait CB[T]: + def map[S](f: T=>S): CB[S] = ??? + def flatMap[S](f: T=>CB[S]): CB[S] = ??? + +class MyArr[AK,AV]: + def map1[BK,BV](f: ((AK,AV)) => (BK, BV)):MyArr[BK,BV] = ??? + def map1Out[BK, BV](f: ((AK,AV)) => CB[(BK,BV)]): CB[MyArr[BK,BV]] = ??? + +def await[T](x:CB[T]):T = ??? + +object CBM: + def pure[T](t:T):CB[T] = ??? + +object X: + + inline def process[T](inline f:T) = ${ + processImpl[T]('f) + } + + def processImpl[T:Type](f:Expr[T])(using qctx: QuoteContext):Expr[CB[T]] = + import qctx.reflect._ + + def transform(term:Term):Term = + term match + case Apply(TypeApply(Select(obj,"map1"),targs),args) => + val nArgs = args.map(x => shiftLambda(x)) + val nSelect = Select.unique(obj, "map1Out") + Apply(TypeApply(nSelect,targs),nArgs) + case Apply(TypeApply(Ident("await"),targs),args) => args.head + case a@Apply(x,List(y,z)) => + val mty=MethodType(List("y1"))( _ => List(y.tpe.widen), _ => Type[CB].unseal.tpe.appliedTo(a.tpe.widen)) + val mtz=MethodType(List("z1"))( _ => List(z.tpe.widen), _ => a.tpe.widen) + Apply( + TypeApply(Select.unique(transform(y),"flatMap"), + List(Inferred(a.tpe.widen)) + ), + List( + Lambda(mty, yArgs => + Apply( + TypeApply(Select.unique(transform(z),"map"), + List(Inferred(a.tpe.widen)) + ), + List( + Lambda(mtz, zArgs => { + val termYArgs = yArgs.asInstanceOf[List[Term]] + val termZArgs = zArgs.asInstanceOf[List[Term]] + Apply(x,List(termYArgs.head,termZArgs.head)) + }) + ) + ) + ) + ) + ) + case Block(stats, last) => Block(stats, transform(last)) + case Inlined(x,List(),body) => transform(body) + case l@Literal(x) => + l.seal match + case '{ $l: $L } => + '{ CBM.pure(${term.seal.cast[L]}) }.unseal + case other => + throw RuntimeException(s"Not supported $other") + + def shiftLambda(term:Term): Term = + term match + case lt@Lambda(params, body) => + val paramTypes = params.map(_.tpt.tpe) + val paramNames = params.map(_.name) + val mt = MethodType(paramNames)(_ => paramTypes, _ => Type[CB].unseal.tpe.appliedTo(body.tpe.widen) ) + Lambda(mt, args => changeArgs(params,args,transform(body)) ) + case Block(stats, last) => + Block(stats, shiftLambda(last)) + case _ => + throw RuntimeException("lambda expected") + + def changeArgs(oldArgs:List[Tree], newArgs:List[Tree], body:Term):Term = + val association: Map[Symbol, Term] = (oldArgs zip newArgs).foldLeft(Map.empty){ + case (m, (oldParam, newParam: Term)) => m.updated(oldParam.symbol, newParam) + case (m, (oldParam, newParam: Tree)) => throw RuntimeException("Term expected") + } + val changes = new TreeMap() { + override def transformTerm(tree:Term)(using Context): Term = + tree match + case ident@Ident(name) => association.getOrElse(ident.symbol, super.transformTerm(tree)) + case _ => super.transformTerm(tree) + } + changes.transformTerm(body) + + val r = transform(f.unseal).seal.cast[CB[T]] + r diff --git a/tests/pos-macros/i10151/Test_2.scala b/tests/pos-macros/i10151/Test_2.scala new file mode 100644 index 000000000000..45ae862ac54d --- /dev/null +++ b/tests/pos-macros/i10151/Test_2.scala @@ -0,0 +1,15 @@ +package x + + +object Main { + + def main(args:Array[String]):Unit = + val arr = new MyArr[Int,Int]() + val r = X.process{ + arr.map1( (x,y) => + ( 1, await(CBM.pure(x)) ) + ) + } + println("r") + +} diff --git a/tests/pos-macros/i10211/Macro_1.scala b/tests/pos-macros/i10211/Macro_1.scala new file mode 100644 index 000000000000..070124b37ff4 --- /dev/null +++ b/tests/pos-macros/i10211/Macro_1.scala @@ -0,0 +1,102 @@ +package x + +import scala.quoted._ + +trait CB[T]: + def map[S](f: T=>S): CB[S] = ??? + + +class MyArr[A]: + def map[B](f: A=>B):MyArr[B] = ??? + def mapOut[B](f: A=> CB[B]): CB[MyArr[B]] = ??? + def flatMap[B](f: A=>MyArr[B]):MyArr[B] = ??? + def flatMapOut[B](f: A=>CB[MyArr[B]]):MyArr[B] = ??? + def withFilter(p: A=>Boolean): MyArr[A] = ??? + def withFilterOut(p: A=>CB[Boolean]): DelayedWithFilter[A] = ??? + def map2[B](f: A=>B):MyArr[B] = ??? + +class DelayedWithFilter[A]: + def map[B](f: A=>B):MyArr[B] = ??? + def mapOut[B](f: A=> CB[B]): CB[MyArr[B]] = ??? + def flatMap[B](f: A=>MyArr[B]):MyArr[B] = ??? + def flatMapOut[B](f: A=>CB[MyArr[B]]): CB[MyArr[B]] = ??? + def map2[B](f: A=>B):CB[MyArr[B]] = ??? + + +def await[T](x:CB[T]):T = ??? + +object CBM: + def pure[T](t:T):CB[T] = ??? + def map[T,S](a:CB[T])(f:T=>S):CB[S] = ??? + +object X: + + inline def process[T](inline f:T) = ${ + processImpl[T]('f) + } + + def processImpl[T:Type](f:Expr[T])(using qctx: QuoteContext):Expr[CB[T]] = + import qctx.reflect._ + + def transform(term:Term):Term = + term match + case ap@Apply(TypeApply(Select(obj,name),targs),args) + if (name=="map"||name=="flatMap") => + obj match + case Apply(Select(obj1,"withFilter"),args1) => + val nObj = transform(obj) + transform(Apply(TypeApply(Select.unique(nObj,name),targs),args)) + case _ => + val nArgs = args.map(x => shiftLambda(x)) + val nSelect = Select.unique(obj, name+"Out") + Apply(TypeApply(nSelect,targs),nArgs) + case ap@Apply(Select(obj,"withFilter"),args) => + val nArgs = args.map(x => shiftLambda(x)) + val nSelect = Select.unique(obj, "withFilterOut") + Apply(nSelect,nArgs) + case ap@Apply(TypeApply(Select(obj,"map2"),targs),args) => + val nObj = transform(obj) + Apply(TypeApply( + Select.unique(nObj,"map2"), + List(Type[Int].unseal) + ), + args + ) + case Apply(TypeApply(Ident("await"),targs),args) => args.head + case Apply(Select(obj,"=="),List(b)) => + val tb = transform(b).seal.cast[CB[Int]] + val mt = MethodType(List("p"))(_ => List(b.tpe.widen), _ => Type[Boolean].unseal.tpe) + val mapLambda = Lambda(mt, x => Select.overloaded(obj,"==",List(),List(x.head.asInstanceOf[Term]))).seal.cast[Int=>Boolean] + '{ CBM.map($tb)($mapLambda) }.unseal + case Block(stats, last) => Block(stats, transform(last)) + case Inlined(x,List(),body) => transform(body) + case l@Literal(x) => + '{ CBM.pure(${term.seal}) }.unseal + case other => + throw RuntimeException(s"Not supported $other") + + def shiftLambda(term:Term): Term = + term match + case lt@Lambda(params, body) => + val paramTypes = params.map(_.tpt.tpe) + val paramNames = params.map(_.name) + val mt = MethodType(paramNames)(_ => paramTypes, _ => Type[CB].unseal.tpe.appliedTo(body.tpe.widen) ) + val r = Lambda(mt, args => changeArgs(params,args,transform(body)) ) + r + case _ => + throw RuntimeException("lambda expected") + + def changeArgs(oldArgs:List[Tree], newArgs:List[Tree], body:Term):Term = + val association: Map[Symbol, Term] = (oldArgs zip newArgs).foldLeft(Map.empty){ + case (m, (oldParam, newParam: Term)) => m.updated(oldParam.symbol, newParam) + case (m, (oldParam, newParam: Tree)) => throw RuntimeException("Term expected") + } + val changes = new TreeMap() { + override def transformTerm(tree:Term)(using Context): Term = + tree match + case ident@Ident(name) => association.getOrElse(ident.symbol, super.transformTerm(tree)) + case _ => super.transformTerm(tree) + } + changes.transformTerm(body) + + transform(f.unseal).seal.cast[CB[T]] diff --git a/tests/pos-macros/i10211/Test_2.scala b/tests/pos-macros/i10211/Test_2.scala new file mode 100644 index 000000000000..b6d83cf86489 --- /dev/null +++ b/tests/pos-macros/i10211/Test_2.scala @@ -0,0 +1,18 @@ +package x + + +object Main { + + def main(args:Array[String]):Unit = + val arr1 = new MyArr[Int]() + val arr2 = new MyArr[Int]() + val r = X.process{ + arr1.withFilter(x => x == await(CBM.pure(1))) + .flatMap(x => + arr2.withFilter( y => y == await(CBM.pure(2)) ). + map2( y => x + y ) + ) + } + println(r) + +}