Skip to content

Commit ea58b66

Browse files
smarterKordyjan
authored andcommitted
Properly handle SAM types with wildcards
[Cherry-picked 89735d0][modified] When typing a closure with an expected type containing a wildcard, the closure type itself should not contain wildcards, because it might be expanded to an anonymous class extending the closure type (this happens on non-JVM backends as well as on the JVM itself in situations where a SAM trait does not compile down to a SAM interface). We were already approximating wildcards in the method type returned by the SAMType extractor, but to fix this issue we had to change the extractor to perform the approximation on the expected type itself to generate a valid parent type. The SAMType extractor now returns both the approximated parent type and the type of the method itself. The wildcard approximation analysis relies on a new `VarianceMap` opaque type extracted from Inferencing#variances. Fixes #16065. Fixes #18096.
1 parent f68b617 commit ea58b66

File tree

10 files changed

+207
-129
lines changed

10 files changed

+207
-129
lines changed

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -744,6 +744,7 @@ class Definitions {
744744
@tu lazy val StringContextModule_processEscapes: Symbol = StringContextModule.requiredMethod(nme.processEscapes)
745745

746746
@tu lazy val PartialFunctionClass: ClassSymbol = requiredClass("scala.PartialFunction")
747+
@tu lazy val PartialFunction_apply: Symbol = PartialFunctionClass.requiredMethod(nme.apply)
747748
@tu lazy val PartialFunction_isDefinedAt: Symbol = PartialFunctionClass.requiredMethod(nme.isDefinedAt)
748749
@tu lazy val PartialFunction_applyOrElse: Symbol = PartialFunctionClass.requiredMethod(nme.applyOrElse)
749750

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 128 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ import CheckRealizable._
2121
import Variances.{Variance, setStructuralVariances, Invariant}
2222
import typer.Nullables
2323
import util.Stats._
24-
import util.SimpleIdentitySet
24+
import util.{SimpleIdentityMap, SimpleIdentitySet}
2525
import ast.tpd._
2626
import ast.TreeTypeMap
2727
import printing.Texts._
@@ -1741,7 +1741,7 @@ object Types {
17411741
t
17421742
case t if defn.isErasedFunctionType(t) =>
17431743
t
1744-
case t @ SAMType(_) =>
1744+
case t @ SAMType(_, _) =>
17451745
t
17461746
case _ =>
17471747
NoType
@@ -5497,104 +5497,119 @@ object Types {
54975497
* A type is a SAM type if it is a reference to a class or trait, which
54985498
*
54995499
* - has a single abstract method with a method type (ExprType
5500-
* and PolyType not allowed!) whose result type is not an implicit function type
5501-
* and which is not marked inline.
5500+
* and PolyType not allowed!) according to `possibleSamMethods`.
55025501
* - can be instantiated without arguments or with just () as argument.
55035502
*
5504-
* The pattern `SAMType(sam)` matches a SAM type, where `sam` is the
5505-
* type of the single abstract method.
5503+
* The pattern `SAMType(samMethod, samParent)` matches a SAM type, where `samMethod` is the
5504+
* type of the single abstract method and `samParent` is a subtype of the matched
5505+
* SAM type which has been stripped of wildcards to turn it into a valid parent
5506+
* type.
55065507
*/
55075508
object SAMType {
5508-
def zeroParamClass(tp: Type)(using Context): Type = tp match {
5509+
/** If possible, return a type which is both a subtype of `origTp` and a type
5510+
* application of `samClass` where none of the type arguments are
5511+
* wildcards (thus making it a valid parent type), otherwise return
5512+
* NoType.
5513+
*
5514+
* A wildcard in the original type will be replaced by its upper or lower bound in a way
5515+
* that maximizes the number of possible implementations of `samMeth`. For example,
5516+
* java.util.function defines an interface equivalent to:
5517+
*
5518+
* trait Function[T, R]:
5519+
* def apply(t: T): R
5520+
*
5521+
* and it usually appears with wildcards to compensate for the lack of
5522+
* definition-site variance in Java:
5523+
*
5524+
* (x => x.toInt): Function[? >: String, ? <: Int]
5525+
*
5526+
* When typechecking this lambda, we need to approximate the wildcards to find
5527+
* a valid parent type for our lambda to extend. We can see that in `apply`,
5528+
* `T` only appears contravariantly and `R` only appears covariantly, so by
5529+
* minimizing the first parameter and maximizing the second, we maximize the
5530+
* number of valid implementations of `apply` which lets us implement the lambda
5531+
* with a closure equivalent to:
5532+
*
5533+
* new Function[String, Int] { def apply(x: String): Int = x.toInt }
5534+
*
5535+
* If a type parameter appears invariantly or does not appear at all in `samMeth`, then
5536+
* we arbitrarily pick the upper-bound.
5537+
*/
5538+
def samParent(origTp: Type, samClass: Symbol, samMeth: Symbol)(using Context): Type =
5539+
val tp = origTp.baseType(samClass)
5540+
if !(tp <:< origTp) then NoType
5541+
else tp match
5542+
case tp @ AppliedType(tycon, args) if tp.hasWildcardArg =>
5543+
val accu = new TypeAccumulator[VarianceMap[Symbol]]:
5544+
def apply(vmap: VarianceMap[Symbol], t: Type): VarianceMap[Symbol] = t match
5545+
case tp: TypeRef if tp.symbol.isAllOf(ClassTypeParam) =>
5546+
vmap.recordLocalVariance(tp.symbol, variance)
5547+
case _ =>
5548+
foldOver(vmap, t)
5549+
val vmap = accu(VarianceMap.empty, samMeth.info)
5550+
val tparams = tycon.typeParamSymbols
5551+
val args1 = args.zipWithConserve(tparams):
5552+
case (arg @ TypeBounds(lo, hi), tparam) =>
5553+
val v = vmap.computedVariance(tparam)
5554+
if v.uncheckedNN < 0 then lo
5555+
else hi
5556+
case (arg, _) => arg
5557+
tp.derivedAppliedType(tycon, args1)
5558+
case _ =>
5559+
tp
5560+
5561+
def samClass(tp: Type)(using Context): Symbol = tp match
55095562
case tp: ClassInfo =>
5510-
def zeroParams(tp: Type): Boolean = tp.stripPoly match {
5563+
def zeroParams(tp: Type): Boolean = tp.stripPoly match
55115564
case mt: MethodType => mt.paramInfos.isEmpty && !mt.resultType.isInstanceOf[MethodType]
55125565
case et: ExprType => true
55135566
case _ => false
5514-
}
5515-
// `ContextFunctionN` does not have constructors
5516-
val ctor = tp.cls.primaryConstructor
5517-
if (!ctor.exists || zeroParams(ctor.info)) tp
5518-
else NoType
5567+
val cls = tp.cls
5568+
val validCtor =
5569+
val ctor = cls.primaryConstructor
5570+
// `ContextFunctionN` does not have constructors
5571+
!ctor.exists || zeroParams(ctor.info)
5572+
val isInstantiable = !cls.isOneOf(FinalOrSealed) && (tp.appliedRef <:< tp.selfType)
5573+
if validCtor && isInstantiable then tp.cls
5574+
else NoSymbol
55195575
case tp: AppliedType =>
5520-
zeroParamClass(tp.superType)
5576+
samClass(tp.superType)
55215577
case tp: TypeRef =>
5522-
zeroParamClass(tp.underlying)
5578+
samClass(tp.underlying)
55235579
case tp: RefinedType =>
5524-
zeroParamClass(tp.underlying)
5580+
samClass(tp.underlying)
55255581
case tp: TypeBounds =>
5526-
zeroParamClass(tp.underlying)
5582+
samClass(tp.underlying)
55275583
case tp: TypeVar =>
5528-
zeroParamClass(tp.underlying)
5584+
samClass(tp.underlying)
55295585
case tp: AnnotatedType =>
5530-
zeroParamClass(tp.underlying)
5531-
case _ =>
5532-
NoType
5533-
}
5534-
def isInstantiatable(tp: Type)(using Context): Boolean = zeroParamClass(tp) match {
5535-
case cinfo: ClassInfo if !cinfo.cls.isOneOf(FinalOrSealed) =>
5536-
val selfType = cinfo.selfType.asSeenFrom(tp, cinfo.cls)
5537-
tp <:< selfType
5586+
samClass(tp.underlying)
55385587
case _ =>
5539-
false
5540-
}
5541-
def unapply(tp: Type)(using Context): Option[MethodType] =
5542-
if (isInstantiatable(tp)) {
5543-
val absMems = tp.possibleSamMethods
5544-
if (absMems.size == 1)
5545-
absMems.head.info match {
5546-
case mt: MethodType if !mt.isParamDependent &&
5547-
mt.resultType.isValueTypeOrWildcard =>
5548-
val cls = tp.classSymbol
5549-
5550-
// Given a SAM type such as:
5551-
//
5552-
// import java.util.function.Function
5553-
// Function[? >: String, ? <: Int]
5554-
//
5555-
// the single abstract method will have type:
5556-
//
5557-
// (x: Function[? >: String, ? <: Int]#T): Function[? >: String, ? <: Int]#R
5558-
//
5559-
// which is not implementable outside of the scope of Function.
5560-
//
5561-
// To avoid this kind of issue, we approximate references to
5562-
// parameters of the SAM type by their bounds, this way in the
5563-
// above example we get:
5564-
//
5565-
// (x: String): Int
5566-
val approxParams = new ApproximatingTypeMap {
5567-
def apply(tp: Type): Type = tp match {
5568-
case tp: TypeRef if tp.symbol.isAllOf(ClassTypeParam) && tp.symbol.owner == cls =>
5569-
tp.info match {
5570-
case info: AliasingBounds =>
5571-
mapOver(info.alias)
5572-
case TypeBounds(lo, hi) =>
5573-
range(atVariance(-variance)(apply(lo)), apply(hi))
5574-
case _ =>
5575-
range(defn.NothingType, defn.AnyType) // should happen only in error cases
5576-
}
5577-
case _ =>
5578-
mapOver(tp)
5579-
}
5580-
}
5581-
val approx =
5582-
if ctx.owner.isContainedIn(cls) then mt
5583-
else approxParams(mt).asInstanceOf[MethodType]
5584-
Some(approx)
5588+
NoSymbol
5589+
5590+
def unapply(tp: Type)(using Context): Option[(MethodType, Type)] =
5591+
val cls = samClass(tp)
5592+
if cls.exists then
5593+
val absMems =
5594+
if tp.isRef(defn.PartialFunctionClass) then
5595+
// To maintain compatibility with 2.x, we treat PartialFunction specially,
5596+
// pretending it is a SAM type. In the future it would be better to merge
5597+
// Function and PartialFunction, have Function1 contain a isDefinedAt method
5598+
// def isDefinedAt(x: T) = true
5599+
// and overwrite that method whenever the function body is a sequence of
5600+
// case clauses.
5601+
List(defn.PartialFunction_apply)
5602+
else
5603+
tp.possibleSamMethods.map(_.symbol)
5604+
if absMems.lengthCompare(1) == 0 then
5605+
val samMethSym = absMems.head
5606+
val parent = samParent(tp, cls, samMethSym)
5607+
samMethSym.asSeenFrom(parent).info match
5608+
case mt: MethodType if !mt.isParamDependent && mt.resultType.isValueTypeOrWildcard =>
5609+
Some(mt, parent)
55855610
case _ =>
55865611
None
5587-
}
5588-
else if (tp isRef defn.PartialFunctionClass)
5589-
// To maintain compatibility with 2.x, we treat PartialFunction specially,
5590-
// pretending it is a SAM type. In the future it would be better to merge
5591-
// Function and PartialFunction, have Function1 contain a isDefinedAt method
5592-
// def isDefinedAt(x: T) = true
5593-
// and overwrite that method whenever the function body is a sequence of
5594-
// case clauses.
5595-
absMems.find(_.symbol.name == nme.apply).map(_.info.asInstanceOf[MethodType])
55965612
else None
5597-
}
55985613
else None
55995614
}
56005615

@@ -6427,6 +6442,37 @@ object Types {
64276442
}
64286443
}
64296444

6445+
object VarianceMap:
6446+
/** An immutable map representing the variance of keys of type `K` */
6447+
opaque type VarianceMap[K <: AnyRef] <: AnyRef = SimpleIdentityMap[K, Integer]
6448+
def empty[K <: AnyRef]: VarianceMap[K] = SimpleIdentityMap.empty[K]
6449+
extension [K <: AnyRef](vmap: VarianceMap[K])
6450+
/** The backing map used to implement this VarianceMap. */
6451+
inline def underlying: SimpleIdentityMap[K, Integer] = vmap
6452+
6453+
/** Return a new map taking into account that K appears in a
6454+
* {co,contra,in}-variant position if `localVariance` is {positive,negative,zero}.
6455+
*/
6456+
def recordLocalVariance(k: K, localVariance: Int): VarianceMap[K] =
6457+
val previousVariance = vmap(k)
6458+
if previousVariance == null then
6459+
vmap.updated(k, localVariance)
6460+
else if previousVariance == localVariance || previousVariance == 0 then
6461+
vmap
6462+
else
6463+
vmap.updated(k, 0)
6464+
6465+
/** Return the variance of `k`:
6466+
* - A positive value means that `k` appears only covariantly.
6467+
* - A negative value means that `k` appears only contravariantly.
6468+
* - A zero value means that `k` appears both covariantly and
6469+
* contravariantly, or appears invariantly.
6470+
* - A null value means that `k` does not appear at all.
6471+
*/
6472+
def computedVariance(k: K): Integer | Null =
6473+
vmap(k)
6474+
export VarianceMap.VarianceMap
6475+
64306476
// ----- Name Filters --------------------------------------------------
64316477

64326478
/** A name filter selects or discards a member name of a type `pre`.

compiler/src/dotty/tools/dotc/transform/ExpandSAMs.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,10 @@ class ExpandSAMs extends MiniPhase:
5050
tree // it's a plain function
5151
case tpe if defn.isContextFunctionType(tpe) =>
5252
tree
53-
case tpe @ SAMType(_) if tpe.isRef(defn.PartialFunctionClass) =>
53+
case SAMType(_, tpe) if tpe.isRef(defn.PartialFunctionClass) =>
5454
val tpe1 = checkRefinements(tpe, fn)
5555
toPartialFunction(tree, tpe1)
56-
case tpe @ SAMType(_) if ExpandSAMs.isPlatformSam(tpe.classSymbol.asClass) =>
56+
case SAMType(_, tpe) if ExpandSAMs.isPlatformSam(tpe.classSymbol.asClass) =>
5757
checkRefinements(tpe, fn)
5858
tree
5959
case tpe =>

compiler/src/dotty/tools/dotc/typer/Applications.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -696,7 +696,7 @@ trait Applications extends Compatibility {
696696

697697
def SAMargOK =
698698
defn.isFunctionNType(argtpe1) && formal.match
699-
case SAMType(sam) => argtpe <:< sam.toFunctionType(isJava = formal.classSymbol.is(JavaDefined))
699+
case SAMType(samMeth, samParent) => argtpe <:< samMeth.toFunctionType(isJava = samParent.classSymbol.is(JavaDefined))
700700
case _ => false
701701

702702
isCompatible(argtpe, formal)
@@ -2080,7 +2080,7 @@ trait Applications extends Compatibility {
20802080
* new java.io.ObjectOutputStream(f)
20812081
*/
20822082
pt match {
2083-
case SAMType(mtp) =>
2083+
case SAMType(mtp, _) =>
20842084
narrowByTypes(alts, mtp.paramInfos, mtp.resultType)
20852085
case _ =>
20862086
// pick any alternatives that are not methods since these might be convertible

compiler/src/dotty/tools/dotc/typer/Inferencing.scala

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -411,7 +411,7 @@ object Inferencing {
411411
val vs = variances(tp)
412412
val patternBindings = new mutable.ListBuffer[(Symbol, TypeParamRef)]
413413
val gadtBounds = ctx.gadt.symbols.map(ctx.gadt.bounds(_).nn)
414-
vs foreachBinding { (tvar, v) =>
414+
vs.underlying foreachBinding { (tvar, v) =>
415415
if !tvar.isInstantiated then
416416
// if the tvar is covariant/contravariant (v == 1/-1, respectively) in the input type tp
417417
// then it is safe to instantiate if it doesn't occur in any of the GADT bounds.
@@ -444,8 +444,6 @@ object Inferencing {
444444
res
445445
}
446446

447-
type VarianceMap = SimpleIdentityMap[TypeVar, Integer]
448-
449447
/** All occurrences of type vars in `tp` that satisfy predicate
450448
* `include` mapped to their variances (-1/0/1) in both `tp` and
451449
* `pt.finalResultType`, where
@@ -469,23 +467,18 @@ object Inferencing {
469467
*
470468
* we want to instantiate U to x.type right away. No need to wait further.
471469
*/
472-
private def variances(tp: Type, pt: Type = WildcardType)(using Context): VarianceMap = {
470+
private def variances(tp: Type, pt: Type = WildcardType)(using Context): VarianceMap[TypeVar] = {
473471
Stats.record("variances")
474472
val constraint = ctx.typerState.constraint
475473

476-
object accu extends TypeAccumulator[VarianceMap] {
474+
object accu extends TypeAccumulator[VarianceMap[TypeVar]]:
477475
def setVariance(v: Int) = variance = v
478-
def apply(vmap: VarianceMap, t: Type): VarianceMap = t match {
476+
def apply(vmap: VarianceMap[TypeVar], t: Type): VarianceMap[TypeVar] = t match
479477
case t: TypeVar
480478
if !t.isInstantiated && accCtx.typerState.constraint.contains(t) =>
481-
val v = vmap(t)
482-
if (v == null) vmap.updated(t, variance)
483-
else if (v == variance || v == 0) vmap
484-
else vmap.updated(t, 0)
479+
vmap.recordLocalVariance(t, variance)
485480
case _ =>
486481
foldOver(vmap, t)
487-
}
488-
}
489482

490483
/** Include in `vmap` type variables occurring in the constraints of type variables
491484
* already in `vmap`. Specifically:
@@ -497,10 +490,10 @@ object Inferencing {
497490
* bounds as non-variant.
498491
* Do this in a fixpoint iteration until `vmap` stabilizes.
499492
*/
500-
def propagate(vmap: VarianceMap): VarianceMap = {
493+
def propagate(vmap: VarianceMap[TypeVar]): VarianceMap[TypeVar] = {
501494
var vmap1 = vmap
502495
def traverse(tp: Type) = { vmap1 = accu(vmap1, tp) }
503-
vmap.foreachBinding { (tvar, v) =>
496+
vmap.underlying.foreachBinding { (tvar, v) =>
504497
val param = tvar.origin
505498
constraint.entry(param) match
506499
case TypeBounds(lo, hi) =>
@@ -516,7 +509,7 @@ object Inferencing {
516509
if (vmap1 eq vmap) vmap else propagate(vmap1)
517510
}
518511

519-
propagate(accu(accu(SimpleIdentityMap.empty, tp), pt.finalResultType))
512+
propagate(accu(accu(VarianceMap.empty, tp), pt.finalResultType))
520513
}
521514

522515
/** Run the transformation after dealiasing but return the original type if it was a no-op. */
@@ -642,7 +635,7 @@ trait Inferencing { this: Typer =>
642635
if !tvar.isInstantiated then
643636
// isInstantiated needs to be checked again, since previous interpolations could already have
644637
// instantiated `tvar` through unification.
645-
val v = vs(tvar)
638+
val v = vs.computedVariance(tvar)
646639
if v == null then buf += ((tvar, 0))
647640
else if v.intValue != 0 then buf += ((tvar, v.intValue))
648641
else comparing(cmp =>

0 commit comments

Comments
 (0)