Skip to content

Commit e7a2286

Browse files
committed
refactor function extractor in union
1 parent 8f197ac commit e7a2286

File tree

2 files changed

+24
-29
lines changed

2 files changed

+24
-29
lines changed

compiler/src/dotty/tools/dotc/core/Types.scala

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1651,6 +1651,26 @@ object Types {
16511651
case _ => resultType
16521652
}
16531653

1654+
/** Find the function type in union.
1655+
* If there are multiple function types, NoType is returned.
1656+
*/
1657+
def findFuntionTypeInUnion(using Context): Type = this match {
1658+
case t: OrType =>
1659+
val t1 = t.tp1.findFuntionTypeInUnion
1660+
if t1 == NoType then t.tp2.findFuntionTypeInUnion else
1661+
val t2 = t.tp2.findFuntionTypeInUnion
1662+
// Returen NoType if the union contains multiple function types
1663+
if t2 == NoType then t1 else NoType
1664+
case t: TypeParamRef =>
1665+
ctx.typerState.constraint.entry(t).bounds.hi.findFuntionTypeInUnion
1666+
case t if defn.isNonRefinedFunction(t) =>
1667+
t
1668+
case t @ SAMType(_: MethodType) =>
1669+
t
1670+
case _ =>
1671+
NoType
1672+
}
1673+
16541674
/** This type seen as a TypeBounds */
16551675
final def bounds(using Context): TypeBounds = this match {
16561676
case tp: TypeBounds => tp

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 4 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1125,19 +1125,6 @@ class Typer extends Namer
11251125
newTypeVar(apply(bounds.orElse(TypeBounds.empty)).bounds)
11261126
case _ => mapOver(t)
11271127
}
1128-
def extractInUnion(t: Type): Seq[Type] = t match {
1129-
case t: OrType =>
1130-
extractInUnion(t.tp1) ++ extractInUnion(t.tp2)
1131-
case t: TypeParamRef =>
1132-
extractInUnion(ctx.typerState.constraint.entry(t).bounds.hi)
1133-
case t if defn.isNonRefinedFunction(t) =>
1134-
Seq(t)
1135-
case SAMType(_: MethodType) =>
1136-
Seq(t)
1137-
case _ =>
1138-
Nil
1139-
}
1140-
def defaultResult = (List.tabulate(defaultArity)(alwaysWildcardType), untpd.TypeTree())
11411128

11421129
val pt1 = pt.stripTypeVar.dealias
11431130
if (pt1 ne pt1.dropDependentRefinement)
@@ -1148,11 +1135,7 @@ class Typer extends Namer
11481135
|is a curried dependent context function type. Such types are not yet supported.""",
11491136
tree.srcPos)
11501137

1151-
val elems = extractInUnion(pt1)
1152-
if elems.length != 1 then
1153-
// The union type containing multiple function types is ignored
1154-
defaultResult
1155-
else elems.head match {
1138+
pt1.findFuntionTypeInUnion match {
11561139
case pt1 if defn.isNonRefinedFunction(pt1) =>
11571140
// if expected parameter type(s) are wildcards, approximate from below.
11581141
// if expected result type is a wildcard, approximate from above.
@@ -1165,7 +1148,7 @@ class Typer extends Namer
11651148
else
11661149
typeTree(restpe))
11671150
case _ =>
1168-
defaultResult
1151+
(List.tabulate(defaultArity)(alwaysWildcardType), untpd.TypeTree())
11691152
}
11701153
}
11711154

@@ -1410,22 +1393,14 @@ class Typer extends Namer
14101393
}
14111394

14121395
def typedClosure(tree: untpd.Closure, pt: Type)(using Context): Tree = {
1413-
def extractInUnion(t: Type): Seq[Type] = t match {
1414-
case t: OrType =>
1415-
extractInUnion(t.tp1) ++ extractInUnion(t.tp2)
1416-
case SAMType(_) =>
1417-
Seq(t)
1418-
case _ =>
1419-
Nil
1420-
}
14211396
val env1 = tree.env mapconserve (typed(_))
14221397
val meth1 = typedUnadapted(tree.meth)
14231398
val target =
14241399
if (tree.tpt.isEmpty)
14251400
meth1.tpe.widen match {
14261401
case mt: MethodType =>
1427-
extractInUnion(pt) match {
1428-
case Seq(pt @ SAMType(sam))
1402+
pt.findFuntionTypeInUnion match {
1403+
case pt @ SAMType(sam)
14291404
if !defn.isFunctionType(pt) && mt <:< sam =>
14301405
// SAMs of the form C[?] where C is a class cannot be conversion targets.
14311406
// The resulting class `class $anon extends C[?] {...}` would be illegal,

0 commit comments

Comments
 (0)