Skip to content

Commit 3ecc0a8

Browse files
authored
Merge pull request #12922 from dotty-staging/fix-12914
Map opaque types in arguments of inlined calls to proxies
2 parents ed70022 + 1c2a8de commit 3ecc0a8

File tree

3 files changed

+103
-37
lines changed

3 files changed

+103
-37
lines changed

compiler/src/dotty/tools/dotc/typer/Inliner.scala

Lines changed: 68 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -409,7 +409,7 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(using Context) {
409409

410410
private val methPart = funPart(call)
411411
private val callTypeArgs = typeArgss(call).flatten
412-
private val callValueArgss = termArgss(call)
412+
private val rawCallValueArgss = termArgss(call)
413413
private val inlinedMethod = methPart.symbol
414414
private val inlineCallPrefix =
415415
qualifier(methPart).orElse(This(inlinedMethod.enclosingClass.asClass))
@@ -581,31 +581,17 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(using Context) {
581581
case (from, to) if from.symbol == ref.symbol && from =:= ref => to
582582
}
583583

584-
/** If `binding` contains TermRefs that refer to objects with opaque
585-
* type aliases, add proxy definitions that expose these aliases
586-
* and substitute such TermRefs with theproxies. Example from pos/opaque-inline1.scala:
587-
*
588-
* object refined:
589-
* opaque type Positive = Int
590-
* inline def Positive(value: Int): Positive = f(value)
591-
* def f(x: Positive): Positive = x
592-
* def run: Unit = { val x = 9; val nine = refined.Positive(x) }
593-
*
594-
* This generates the following proxies:
595-
*
596-
* val $proxy1: refined.type{type Positive = Int} =
597-
* refined.$asInstanceOf$[refined.type{type Positive = Int}]
598-
* val refined$_this: ($proxy1 : refined.type{Positive = Int}) =
599-
* $proxy1
600-
*
601-
* and every reference to `refined` in the inlined expression is replaced by
602-
* `refined_$this`.
584+
/** If `tp` contains TermRefs that refer to objects with opaque
585+
* type aliases, add proxy definitions to `opaqueProxies` that expose these aliases.
603586
*/
604-
def accountForOpaques(binding: ValDef)(using Context): ValDef =
605-
binding.symbol.info.foreachPart {
587+
def addOpaqueProxies(tp: Type, span: Span, forThisProxy: Boolean)(using Context): Unit =
588+
tp.foreachPart {
606589
case ref: TermRef =>
607590
for cls <- ref.widen.classSymbols do
608-
if cls.containsOpaques && mapRef(ref).isEmpty then
591+
if cls.containsOpaques
592+
&& (forThisProxy || inlinedMethod.isContainedIn(cls))
593+
&& mapRef(ref).isEmpty
594+
then
609595
def openOpaqueAliases(selfType: Type): List[(Name, Type)] = selfType match
610596
case RefinedType(parent, rname, TypeAlias(alias)) =>
611597
val opaq = cls.info.member(rname).symbol
@@ -620,27 +606,67 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(using Context) {
620606
RefinedType(parent, refinement._1, TypeAlias(refinement._2))
621607
)
622608
val refiningSym = newSym(InlineBinderName.fresh(), Synthetic, refinedType).asTerm
623-
val refiningDef = ValDef(refiningSym, tpd.ref(ref).cast(refinedType)).withSpan(binding.span)
624-
inlining.println(i"add opaque alias proxy $refiningDef")
609+
val refiningDef = ValDef(refiningSym, tpd.ref(ref).cast(refinedType)).withSpan(span)
610+
inlining.println(i"add opaque alias proxy $refiningDef for $ref in $tp")
625611
bindingsBuf += refiningDef
626612
opaqueProxies += ((ref, refiningSym.termRef))
627613
case _ =>
628614
}
615+
616+
/** Map all TermRefs that match left element in `opaqueProxies` to the
617+
* corresponding right element.
618+
*/
619+
val mapOpaques = TreeTypeMap(
620+
typeMap = new TypeMap:
621+
override def stopAt = StopAt.Package
622+
def apply(t: Type) = mapOver {
623+
t match
624+
case ref: TermRef => mapRef(ref).getOrElse(ref)
625+
case _ => t
626+
}
627+
)
628+
629+
/** If `binding` contains TermRefs that refer to objects with opaque
630+
* type aliases, add proxy definitions that expose these aliases
631+
* and substitute such TermRefs with theproxies. Example from pos/opaque-inline1.scala:
632+
*
633+
* object refined:
634+
* opaque type Positive = Int
635+
* inline def Positive(value: Int): Positive = f(value)
636+
* def f(x: Positive): Positive = x
637+
* def run: Unit = { val x = 9; val nine = refined.Positive(x) }
638+
*
639+
* This generates the following proxies:
640+
*
641+
* val $proxy1: refined.type{type Positive = Int} =
642+
* refined.$asInstanceOf$[refined.type{type Positive = Int}]
643+
* val refined$_this: ($proxy1 : refined.type{Positive = Int}) =
644+
* $proxy1
645+
*
646+
* and every reference to `refined` in the inlined expression is replaced by
647+
* `refined_$this`.
648+
*/
649+
def accountForOpaques(binding: ValDef)(using Context): ValDef =
650+
addOpaqueProxies(binding.symbol.info, binding.span, forThisProxy = true)
629651
if opaqueProxies.isEmpty then binding
630652
else
631-
val mapType = new TypeMap:
632-
override def stopAt = StopAt.Package
633-
def apply(t: Type) = mapOver {
634-
t match
635-
case ref: TermRef => mapRef(ref).getOrElse(ref)
636-
case _ => t
637-
}
638-
binding.symbol.info = mapType(binding.symbol.info)
639-
val mapTree = TreeTypeMap(typeMap = mapType)
640-
mapTree.transform(binding).asInstanceOf[ValDef]
653+
binding.symbol.info = mapOpaques.typeMap(binding.symbol.info)
654+
mapOpaques.transform(binding).asInstanceOf[ValDef]
641655
.showing(i"transformed this binding exposing opaque aliases: $result", inlining)
642656
end accountForOpaques
643657

658+
/** If value argument contains references to objects that contain opaque types,
659+
* map them to their opaque proxies.
660+
*/
661+
def mapOpaquesInValueArg(arg: Tree)(using Context): Tree =
662+
val argType = arg.tpe.widen
663+
addOpaqueProxies(argType, arg.span, forThisProxy = false)
664+
if opaqueProxies.nonEmpty then
665+
val mappedType = mapOpaques.typeMap(argType)
666+
if mappedType ne argType then arg.cast(AndType(arg.tpe, mappedType))
667+
else arg
668+
else arg
669+
644670
private def canElideThis(tpe: ThisType): Boolean =
645671
inlineCallPrefix.tpe == tpe && ctx.owner.isContainedIn(tpe.cls)
646672
|| tpe.cls.isContainedIn(inlinedMethod)
@@ -773,7 +799,7 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(using Context) {
773799
def inlined(sourcePos: SrcPos): Tree = {
774800

775801
// Special handling of `requireConst` and `codeOf`
776-
callValueArgss match
802+
rawCallValueArgss match
777803
case (arg :: Nil) :: Nil =>
778804
if inlinedMethod == defn.Compiletime_requireConst then
779805
arg match
@@ -823,6 +849,11 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(using Context) {
823849
case TypeApply(fn, _) => paramTypess(fn, acc)
824850
case _ => acc
825851

852+
val callValueArgss = rawCallValueArgss.nestedMapConserve(mapOpaquesInValueArg)
853+
854+
if callValueArgss ne rawCallValueArgss then
855+
inlining.println(i"mapped value args = ${callValueArgss.flatten}%, %")
856+
826857
// Compute bindings for all parameters, appending them to bindingsBuf
827858
if !computeParamBindings(inlinedMethod.info, callTypeArgs, callValueArgss, paramTypess(call, Nil)) then
828859
return call
@@ -1254,7 +1285,7 @@ class Inliner(call: tpd.Tree, rhsToInline: tpd.Tree)(using Context) {
12541285
case fail: Implicits.SearchFailureType =>
12551286
false
12561287
case _ =>
1257-
//inliner.println(i"inferred implicit $sym: ${sym.info} with $evidence: ${evidence.tpe.widen}, ${evCtx.gadt.constraint}, ${evCtx.typerState.constraint}")
1288+
//inlining.println(i"inferred implicit $sym: ${sym.info} with $evidence: ${evidence.tpe.widen}, ${evCtx.gadt.constraint}, ${evCtx.typerState.constraint}")
12581289
newTermBinding(sym, evidence)
12591290
true
12601291
}

tests/run/i12914.check

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
ASD
2+
asd
3+
ASD
4+
asd
5+
ASD
6+
asd
7+
aSdaSdaSd
8+
aSdaSdaSd

tests/run/i12914.scala

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
2+
class opq:
3+
opaque type Str = java.lang.String
4+
object Str:
5+
def apply(s: String): Str = s
6+
inline def lower(s: Str): String = s.toLowerCase
7+
extension (s: Str)
8+
transparent inline def upper: String = s.toUpperCase
9+
inline def concat(xs: List[Str]): Str = String(xs.flatten.toArray)
10+
transparent inline def concat2(xs: List[Str]): Str = String(xs.flatten.toArray)
11+
12+
13+
@main def Test =
14+
val opq = new opq()
15+
import opq.*
16+
val a: Str = Str("aSd")
17+
println(a.upper)
18+
println(opq.lower(a))
19+
def b: Str = Str("aSd")
20+
println(b.upper)
21+
println(opq.lower(b))
22+
def c(): Str = Str("aSd")
23+
println(c().upper)
24+
println(opq.lower(c()))
25+
println(opq.concat(List(a, b, c())))
26+
println(opq.concat2(List(a, b, c())))
27+

0 commit comments

Comments
 (0)