Skip to content

Commit 8c1a6dc

Browse files
committed
Factor out variance manipulation in TypeMap and TypeAccumulator
1 parent 89180b5 commit 8c1a6dc

File tree

1 file changed

+24
-41
lines changed

1 file changed

+24
-41
lines changed

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 24 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -3676,14 +3676,26 @@ object Types {
36763676

36773677
// ----- TypeMaps --------------------------------------------------------------------
36783678

3679-
abstract class TypeMap(implicit protected val ctx: Context) extends (Type => Type) { thisMap =>
3679+
/** Common base class of TypeMap and TypeAccumulator */
3680+
abstract class VariantTraversal {
3681+
protected[core] var variance = 1
3682+
3683+
@inline protected def atVariance[T](v: Int)(op: => T): T = {
3684+
val saved = variance
3685+
variance = v
3686+
val res = op
3687+
variance = saved
3688+
res
3689+
}
3690+
}
3691+
3692+
abstract class TypeMap(implicit protected val ctx: Context)
3693+
extends VariantTraversal with (Type => Type) { thisMap =>
36803694

36813695
protected def stopAtStatic = true
36823696

36833697
def apply(tp: Type): Type
36843698

3685-
protected[core] var variance = 1
3686-
36873699
protected def derivedSelect(tp: NamedType, pre: Type): Type =
36883700
tp.derivedSelect(pre)
36893701
protected def derivedRefinedType(tp: RefinedType, parent: Type, info: Type): Type =
@@ -3721,16 +3733,13 @@ object Types {
37213733
case tp: NamedType =>
37223734
if (stopAtStatic && tp.symbol.isStatic) tp
37233735
else {
3724-
val saved = variance
3725-
variance = variance max 0
3736+
val prefix1 = atVariance(variance max 0)(this(tp.prefix))
37263737
// A prefix is never contravariant. Even if say `p.A` is used in a contravariant
37273738
// context, we cannot assume contravariance for `p` because `p`'s lower
37283739
// bound might not have a binding for `A` (e.g. the lower bound could be `Nothing`).
37293740
// By contrast, covariance does translate to the prefix, since we have that
37303741
// if `p <: q` then `p.A <: q.A`, and well-formedness requires that `A` is a member
37313742
// of `p`'s upper bound.
3732-
val prefix1 = this(tp.prefix)
3733-
variance = saved
37343743
derivedSelect(tp, prefix1)
37353744
}
37363745
case _: ThisType
@@ -3741,11 +3750,7 @@ object Types {
37413750
derivedRefinedType(tp, this(tp.parent), this(tp.refinedInfo))
37423751

37433752
case tp: TypeAlias =>
3744-
val saved = variance
3745-
variance *= tp.variance
3746-
val alias1 = this(tp.alias)
3747-
variance = saved
3748-
derivedTypeAlias(tp, alias1)
3753+
derivedTypeAlias(tp, atVariance(variance * tp.variance)(this(tp.alias)))
37493754

37503755
case tp: TypeBounds =>
37513756
variance = -variance
@@ -3761,12 +3766,8 @@ object Types {
37613766
if (inst.exists) apply(inst) else tp
37623767

37633768
case tp: HKApply =>
3764-
def mapArg(arg: Type, tparam: ParamInfo): Type = {
3765-
val saved = variance
3766-
variance *= tparam.paramVariance
3767-
try this(arg)
3768-
finally variance = saved
3769-
}
3769+
def mapArg(arg: Type, tparam: ParamInfo): Type =
3770+
atVariance(variance * tparam.paramVariance)(this(arg))
37703771
derivedAppliedType(tp, this(tp.tycon),
37713772
tp.args.zipWithConserve(tp.typeParams)(mapArg))
37723773

@@ -3891,12 +3892,6 @@ object Types {
38913892
case _ => tp
38923893
}
38933894

3894-
protected def atVariance[T](v: Int)(op: => T): T = {
3895-
val saved = variance
3896-
variance = v
3897-
try op finally variance = saved
3898-
}
3899-
39003895
/** Derived selection.
39013896
* @pre the (upper bound of) prefix `pre` has a member named `tp.name`.
39023897
*/
@@ -4051,23 +4046,17 @@ object Types {
40514046

40524047
// ----- TypeAccumulators ----------------------------------------------------
40534048

4054-
abstract class TypeAccumulator[T](implicit protected val ctx: Context) extends ((T, Type) => T) {
4049+
abstract class TypeAccumulator[T](implicit protected val ctx: Context)
4050+
extends VariantTraversal with ((T, Type) => T) {
40554051

40564052
protected def stopAtStatic = true
40574053

40584054
def apply(x: T, tp: Type): T
40594055

40604056
protected def applyToAnnot(x: T, annot: Annotation): T = x // don't go into annotations
40614057

4062-
protected var variance = 1
4063-
4064-
protected final def applyToPrefix(x: T, tp: NamedType) = {
4065-
val saved = variance
4066-
variance = variance max 0 // see remark on NamedType case in TypeMap
4067-
val result = this(x, tp.prefix)
4068-
variance = saved
4069-
result
4070-
}
4058+
protected final def applyToPrefix(x: T, tp: NamedType) =
4059+
atVariance(variance max 0)(this(x, tp.prefix)) // see remark on NamedType case in TypeMap
40714060

40724061
def foldOver(x: T, tp: Type): T = tp match {
40734062
case tp: TypeRef =>
@@ -4088,13 +4077,7 @@ object Types {
40884077
this(this(x, tp.parent), tp.refinedInfo)
40894078

40904079
case bounds @ TypeBounds(lo, hi) =>
4091-
if (lo eq hi) {
4092-
val saved = variance
4093-
variance = variance * bounds.variance
4094-
val result = this(x, lo)
4095-
variance = saved
4096-
result
4097-
}
4080+
if (lo eq hi) atVariance(variance * bounds.variance)(this(x, lo))
40984081
else {
40994082
variance = -variance
41004083
val y = this(x, lo)

0 commit comments

Comments
 (0)