Skip to content

Commit 42d914e

Browse files
committed
Add support for some type aliases, when expanding context bounds for poly functions
1 parent 458fd29 commit 42d914e

File tree

3 files changed

+64
-27
lines changed

3 files changed

+64
-27
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ object desugar {
5555
/** An attachment key to indicate that a DefDef is a poly function apply
5656
* method definition.
5757
*/
58-
val PolyFunctionApply: Property.Key[Unit] = Property.StickyKey()
58+
val PolyFunctionApply: Property.Key[List[ValDef]] = Property.StickyKey()
5959

6060
/** What static check should be applied to a Match? */
6161
enum MatchCheck {
@@ -514,17 +514,25 @@ object desugar {
514514
case Nil =>
515515
params :: Nil
516516

517+
// TODO(kπ) is this enough? SHould this be a TreeTraverse-thing?
518+
def pushDownEvidenceParams(tree: Tree): Tree = tree match
519+
case Function(params, body) =>
520+
cpy.Function(tree)(params, pushDownEvidenceParams(body))
521+
case Block(stats, expr) =>
522+
cpy.Block(tree)(stats, pushDownEvidenceParams(expr))
523+
case tree =>
524+
val paramTpts = params.map(_.tpt)
525+
val paramNames = params.map(_.name)
526+
val paramsErased = params.map(_.mods.flags.is(Erased))
527+
makeContextualFunction(paramTpts, paramNames, tree, paramsErased).withSpan(tree.span)
528+
517529
if meth.hasAttachment(PolyFunctionApply) then
518530
meth.removeAttachment(PolyFunctionApply)
519-
val paramTpts = params.map(_.tpt)
520-
val paramNames = params.map(_.name)
521-
val paramsErased = params.map(_.mods.flags.is(Erased))
531+
// (kπ): deffer this until we can type the result?
522532
if ctx.mode.is(Mode.Type) then
523-
val ctxFunction = makeContextualFunction(paramTpts, paramNames, meth.tpt, paramsErased)
524-
cpy.DefDef(meth)(tpt = ctxFunction)
533+
cpy.DefDef(meth)(tpt = meth.tpt.withAttachment(PolyFunctionApply, params))
525534
else
526-
val ctxFunction = makeContextualFunction(paramTpts, paramNames, meth.rhs, paramsErased)
527-
cpy.DefDef(meth)(rhs = ctxFunction)
535+
cpy.DefDef(meth)(rhs = pushDownEvidenceParams(meth.rhs))
528536
else
529537
cpy.DefDef(meth)(paramss = recur(meth.paramss))
530538
end addEvidenceParams
@@ -1251,7 +1259,7 @@ object desugar {
12511259
RefinedTypeTree(ref(defn.PolyFunctionType), List(
12521260
DefDef(nme.apply, tparams :: vparams :: Nil, res, EmptyTree)
12531261
.withFlags(Synthetic)
1254-
.withAttachment(PolyFunctionApply, ())
1262+
.withAttachment(PolyFunctionApply, List.empty)
12551263
)).withSpan(tree.span)
12561264
end makePolyFunctionType
12571265

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 41 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,6 @@ import config.MigrationVersion
5353
import transform.CheckUnused.OriginalName
5454

5555
import scala.annotation.constructorOnly
56-
import dotty.tools.dotc.ast.desugar.PolyFunctionApply
5756

5857
object Typer {
5958

@@ -1958,7 +1957,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
19581957
untpd.InLambdaTypeTree(isResult = true, (tsyms, vsyms) =>
19591958
mt.resultType.substParams(mt, vsyms.map(_.termRef)).substParams(poly, tsyms.map(_.typeRef)))
19601959
val desugared @ Block(List(defdef), _) = desugar.makeClosure(tparams, inferredVParams, body, resultTpt, tree.span)
1961-
defdef.putAttachment(PolyFunctionApply, ())
1960+
defdef.putAttachment(desugar.PolyFunctionApply, List.empty)
19621961
typed(desugared, pt)
19631962
else
19641963
val msg =
@@ -1967,7 +1966,7 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
19671966
errorTree(EmptyTree, msg, tree.srcPos)
19681967
case _ =>
19691968
val desugared @ Block(List(defdef), _) = desugar.makeClosure(tparams, vparams, body, untpd.TypeTree(), tree.span)
1970-
defdef.putAttachment(PolyFunctionApply, ())
1969+
defdef.putAttachment(desugar.PolyFunctionApply, List.empty)
19711970
typed(desugared, pt)
19721971
end typedPolyFunctionValue
19731972

@@ -3580,30 +3579,57 @@ class Typer(@constructorOnly nestingLevel: Int = 0) extends Namer
35803579
case xtree => typedUnnamed(xtree)
35813580

35823581
val unsimplifiedType = result.tpe
3583-
simplify(result, pt, locked)
3584-
result.tpe.stripTypeVar match
3582+
val result1 = simplify(result, pt, locked)
3583+
result1.tpe.stripTypeVar match
35853584
case e: ErrorType if !unsimplifiedType.isErroneous => errorTree(xtree, e.msg, xtree.srcPos)
3586-
case _ => result
3585+
case _ => result1
35873586
catch case ex: TypeError =>
35883587
handleTypeError(ex)
35893588
}
35903589
}
35913590

3591+
private def pushDownDeferredEvidenceParams(tpe: Type, params: List[untpd.ValDef], span: Span)(using Context): Type = tpe.dealias match {
3592+
case tpe: MethodType =>
3593+
MethodType(tpe.paramNames)(paramNames => tpe.paramInfos, _ => pushDownDeferredEvidenceParams(tpe.resultType, params, span))
3594+
case tpe: PolyType =>
3595+
PolyType(tpe.paramNames)(paramNames => tpe.paramInfos, _ => pushDownDeferredEvidenceParams(tpe.resultType, params, span))
3596+
case tpe: RefinedType =>
3597+
// TODO(kπ): Doesn't seem right, but the PolyFunction ends up being a refinement
3598+
RefinedType(pushDownDeferredEvidenceParams(tpe.parent, params, span), tpe.refinedName, pushDownDeferredEvidenceParams(tpe.refinedInfo, params, span))
3599+
case tpe @ AppliedType(tycon, args) if defn.isFunctionType(tpe) && args.size > 1 =>
3600+
AppliedType(tpe.tycon, args.init :+ pushDownDeferredEvidenceParams(args.last, params, span))
3601+
case tpe =>
3602+
val paramNames = params.map(_.name)
3603+
val paramTpts = params.map(_.tpt)
3604+
val paramsErased = params.map(_.mods.flags.is(Erased))
3605+
val ctxFunction = desugar.makeContextualFunction(paramTpts, paramNames, untpd.TypedSplice(TypeTree(tpe.dealias)), paramsErased).withSpan(span)
3606+
typed(ctxFunction).tpe
3607+
}
3608+
3609+
private def addDownDeferredEvidenceParams(tree: Tree, pt: Type)(using Context): (Tree, Type) = {
3610+
tree.getAttachment(desugar.PolyFunctionApply) match
3611+
case Some(params) if params.nonEmpty =>
3612+
tree.removeAttachment(desugar.PolyFunctionApply)
3613+
val tpe = pushDownDeferredEvidenceParams(tree.tpe, params, tree.span)
3614+
TypeTree(tpe).withSpan(tree.span) -> tpe
3615+
case _ => tree -> pt
3616+
}
3617+
35923618
/** Interpolate and simplify the type of the given tree. */
3593-
protected def simplify(tree: Tree, pt: Type, locked: TypeVars)(using Context): tree.type =
3594-
if !tree.denot.isOverloaded then // for overloaded trees: resolve overloading before simplifying
3595-
if !tree.tpe.widen.isInstanceOf[MethodOrPoly] // wait with simplifying until method is fully applied
3596-
|| tree.isDef // ... unless tree is a definition
3619+
protected def simplify(tree: Tree, pt: Type, locked: TypeVars)(using Context): Tree =
3620+
val (tree1, pt1) = addDownDeferredEvidenceParams(tree, pt)
3621+
if !tree1.denot.isOverloaded then // for overloaded trees: resolve overloading before simplifying
3622+
if !tree1.tpe.widen.isInstanceOf[MethodOrPoly] // wait with simplifying until method is fully applied
3623+
|| tree1.isDef // ... unless tree is a definition
35973624
then
3598-
interpolateTypeVars(tree, pt, locked)
3599-
val simplified = tree.tpe.simplified
3600-
if !MatchType.thatReducesUsingGadt(tree.tpe) then // needs a GADT cast. i15743
3625+
interpolateTypeVars(tree1, pt1, locked)
3626+
val simplified = tree1.tpe.simplified
3627+
if !MatchType.thatReducesUsingGadt(tree1.tpe) then // needs a GADT cast. i15743
36013628
tree.overwriteType(simplified)
3602-
tree
3629+
tree1
36033630

36043631
protected def makeContextualFunction(tree: untpd.Tree, pt: Type)(using Context): Tree = {
36053632
val defn.FunctionOf(formals, _, true) = pt.dropDependentRefinement: @unchecked
3606-
println(i"make contextual function $tree / $pt")
36073633
val paramNamesOrNil = pt match
36083634
case RefinedType(_, _, rinfo: MethodType) => rinfo.paramNames
36093635
case _ => Nil

tests/pos/contextbounds-for-poly-functions.scala

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,12 @@ val less3: Comparer = [X: Ord as ord] => (x: X, y: X) => ord.compare(x, y) < 0
2828
// type Comparer2 = [X: Ord] => Cmp[X]
2929
// val less4: Comparer2 = [X: Ord] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0
3030

31-
// type CmpWeak[X] = (x: X, y: X) => Boolean
32-
// type Comparer2Weak = [X: Ord] => (x: X) => CmpWeak[X]
33-
// val less4: Comparer2Weak = [X: Ord] => (x: X) => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0
31+
type CmpWeak[X] = X => Boolean
32+
type Comparer2Weak = [X: Ord] => X => CmpWeak[X]
33+
val less4_0: [X: Ord] => X => X => Boolean =
34+
[X: Ord] => (x: X) => (y: X) => summon[Ord[X]].compare(x, y) < 0
35+
val less4: Comparer2Weak =
36+
[X: Ord] => (x: X) => (y: X) => summon[Ord[X]].compare(x, y) < 0
3437

3538
val less5 = [X: [X] =>> Ord[X]] => (x: X, y: X) => summon[Ord[X]].compare(x, y) < 0
3639

0 commit comments

Comments
 (0)