@@ -215,12 +215,17 @@ object Inlines:
215
215
cls match {
216
216
case cls @ tpd.TypeDef (_, impl : Template ) =>
217
217
val clsOverriddenSyms = cls.symbol.info.decls.toList.flatMap(_.allOverriddenSymbols).toSet
218
- val inlineDefs = inlineTraitAncestors(cls).foldLeft(List .empty[Tree ])(
219
- (defs, parent) =>
220
- val overriddenSymbols = clsOverriddenSyms ++ defs.flatMap(_.symbol.allOverriddenSymbols)
221
- defs ::: InlineParentTrait (parent)(using ctx.withOwner(cls.symbol)).expandDefs(overriddenSymbols)
222
- )
223
- val impl1 = cpy.Template (impl)(body = inlineDefs ::: impl.body)
218
+ val newDefs = inContext(ctx.withOwner(cls.symbol)) {
219
+ inlineTraitAncestors(cls).foldLeft((List .empty[Tree ], impl.body)){
220
+ case ((inlineDefs, childDefs), parent) =>
221
+ val parentTraitInliner = InlineParentTrait (parent)
222
+ val overriddenSymbols = clsOverriddenSyms ++ inlineDefs.flatMap(_.symbol.allOverriddenSymbols)
223
+ val inlinedDefs1 = inlineDefs ::: parentTraitInliner.expandDefs(overriddenSymbols)
224
+ val childDefs1 = parentTraitInliner.adaptDefs(childDefs) // TODO do this outside of inlining: we need to adapt ALL references to inlined stuff
225
+ (inlinedDefs1, childDefs1)
226
+ }
227
+ }
228
+ val impl1 = cpy.Template (impl)(body = newDefs._1 ::: newDefs._2)
224
229
cpy.TypeDef (cls)(rhs = impl1)
225
230
case _ =>
226
231
cls
@@ -540,6 +545,8 @@ object Inlines:
540
545
}
541
546
end expandDefs
542
547
548
+ def adaptDefs (definitions : List [Tree ]): List [Tree ] = definitions.mapconserve(defsAdapter(_))
549
+
543
550
protected class InlineTraitTypeMap extends InlinerTypeMap {
544
551
override def apply (t : Type ) = super .apply(t) match {
545
552
case t : ThisType if t.cls == parentSym => childThisType
@@ -698,6 +705,32 @@ object Inlines:
698
705
// TODO make version of inlined that does not return bindings?
699
706
Inlined (tpd.ref(parentSym), Nil , inlined(rhs)._2).withSpan(parent.span)
700
707
708
+ private val defsAdapter =
709
+ val typeMap = new DeepTypeMap {
710
+ override def apply (tp : Type ): Type = tp match {
711
+ case TypeRef (_, sym : Symbol ) if innerClassNewSyms.contains(sym) =>
712
+ TypeRef (childThisType, innerClassNewSyms(sym))
713
+ case _ =>
714
+ mapOver(tp)
715
+ }
716
+ }
717
+ def treeMap (tree : Tree ) = tree match {
718
+ case ident : Ident if innerClassNewSyms.contains(ident.symbol) =>
719
+ Ident (innerClassNewSyms(ident.symbol).namedType)
720
+ case tdef : TypeDef if tdef.symbol.isClass =>
721
+ tdef.symbol.info = typeMap(tdef.symbol.info)
722
+ tdef
723
+ case tree =>
724
+ tree
725
+ }
726
+ new TreeTypeMap (
727
+ typeMap = typeMap,
728
+ treeMap = treeMap,
729
+ substFrom = substFrom,
730
+ substTo = substTo,
731
+ )
732
+ end defsAdapter
733
+
701
734
private class ParamAccessorsMapper :
702
735
private val paramAccessorsTrees : mutable.Map [Symbol , Map [Name , Tree ]] = mutable.Map .empty
703
736
private val paramAccessorsNewNames : mutable.Map [(Symbol , Name ), Name ] = mutable.Map .empty
0 commit comments