Skip to content

Commit f338416

Browse files
committed
Implement extension methods as implicit value classes
1 parent 7ace3d7 commit f338416

File tree

3 files changed

+88
-36
lines changed

3 files changed

+88
-36
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

Lines changed: 74 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,7 @@ object desugar {
300300
def isAnyVal(tree: Tree): Boolean = tree match {
301301
case Ident(tpnme.AnyVal) => true
302302
case Select(qual, tpnme.AnyVal) => isScala(qual)
303+
case TypedSplice(tree) => tree.tpe.isRef(defn.AnyValClass)
303304
case _ => false
304305
}
305306
def isScala(tree: Tree): Boolean = tree match {
@@ -770,66 +771,105 @@ object desugar {
770771
(elimTypeDefs.transform(tree), bindingsBuf.toList)
771772
}
772773

773-
/** augment <name> <type-params> <params> extends <parents> { <body>} }
774+
/** augment <type-pattern> <params> extends <parents> { <body>} }
774775
* ->
775-
* implicit class <deconame> <type-params> ($this: name <type-args>) <params>
776-
* extends <parents> { <body1> }
777-
*
778-
* augment <type-param> <params> extends <parents> { <body>} }
779-
* ->
780-
* implicit class <deconame> <type-param> ($this: <type-arg>) <params>
776+
* implicit class <deconame> <type-params> ($this: <decorated>) <combined-params>
781777
* extends <parents> { <body1> }
782778
*
783779
* where
784780
*
785-
* <deconame> = <name>To<parent>$<n> where <parent> is first extended class name
786-
* = <name>Augmentation$<n> if no such <parent> exists
781+
* (<decorated>, <type-params0>) = decomposeTypePattern(<type-pattern>)
782+
* (<type-params>, <evidence-params>) = desugarTypeBindings(<type-params0>)
783+
* <combined-params> = <params> concatenated with <evidence-params> in one clause
784+
* <deconame> = <from>To<parent>_in_<location>$$<n> where <parent> is first extended class name
785+
*
786+
* = <from>Augmentation_in_<location>$$<n> if no such <parent> exists
787+
* <from> = underlying type name of <decorated>
788+
* <location> = flat name of enclosing toplevel class
787789
* <n> = counter making prefix unique
788-
* <type-args> = references to <type-params>
789790
* <body1> = <body> with each occurrence of unqualified `this` substituted by `$this`.
791+
*
792+
* ~~~~~~~~~~~~~~~~~~~~~~~~~~~~
793+
*
794+
* augment <type-pattern> <params> { <body> }
795+
* ->
796+
* implicit class <deconame> <type-params> ($this: <decorated>)
797+
* extends AnyVal { <body2> }
798+
*
799+
* where
800+
*
801+
* <body2> = <body1> where each method definition gets <combined-params> as last parameter section.
802+
* <deconame>, <type-params> are as above.
790803
*/
791804
def augmentation(tree: Augment)(implicit ctx: Context): Tree = {
792805
val Augment(augmented, impl) = tree
793-
val constr @ DefDef(_, Nil, vparamss, _, _) = impl.constr
806+
val isSimpleExtension =
807+
impl.parents.isEmpty &&
808+
impl.self.isEmpty &&
809+
impl.body.forall(_.isInstanceOf[DefDef])
794810
val (decorated, bindings) = decomposeTypePattern(augmented)
795-
val firstParam = ValDef(nme.SELF, decorated, EmptyTree).withFlags(Private | Local | ParamAccessor)
796-
val constr1 =
797-
cpy.DefDef(constr)(
798-
tparams = bindings.map(_.withFlags(Param | Private | Local)),
799-
vparamss = (firstParam :: Nil) :: vparamss)
800-
val substThis = new UntypedTreeMap {
801-
override def transform(tree: Tree)(implicit ctx: Context): Tree = tree match {
802-
case This(Ident(tpnme.EMPTY)) => Ident(nme.SELF).withPos(tree.pos)
803-
case _ => super.transform(tree)
804-
}
805-
}
811+
val (typeParams, evidenceParams) =
812+
desugarTypeBindings(bindings, forPrimaryConstructor = !isSimpleExtension)
806813
val decoName = {
807-
def clsName(tree: Tree): String = tree match {
808-
case Apply(tycon, args) => clsName(tycon)
809-
case TypeApply(tycon, args) => clsName(tycon)
810-
case Select(pre, nme.CONSTRUCTOR) => clsName(pre)
811-
case New(tpt) => clsName(tpt)
812-
case AppliedTypeTree(tycon, _) => clsName(tycon)
813-
case tree: RefTree if tree.name.isTypeName => tree.name.toString
814-
case Parens(tree) => clsName(tree)
815-
case tree: TypeDef => tree.name.toString
816-
case _ => ""
817-
}
814+
def clsName(tree: Tree): String = leadingName("", tree)
818815
val fromName = clsName(augmented)
819816
val toName = impl.parents match {
820817
case parent :: _ if !clsName(parent).isEmpty => "To" + clsName(parent)
821818
case _ => str.Augmentation
822819
}
823820
s"${fromName}${toName}_in_${ctx.owner.topLevelClass.flatName}"
824821
}
822+
823+
val firstParam = ValDef(nme.SELF, decorated, EmptyTree).withFlags(Private | Local | ParamAccessor)
824+
var constr1 =
825+
cpy.DefDef(impl.constr)(
826+
tparams = typeParams.map(_.withFlags(Param | Private | Local)),
827+
vparamss = (firstParam :: Nil) :: impl.constr.vparamss)
828+
var parents1 = impl.parents
829+
var body1 = substThis.transform(impl.body)
830+
if (isSimpleExtension) {
831+
constr1 = cpy.DefDef(constr1)(vparamss = constr1.vparamss.take(1))
832+
parents1 = ref(defn.AnyValType) :: Nil
833+
body1 = body1.map {
834+
case ddef: DefDef =>
835+
def resetFlags(vdef: ValDef) =
836+
vdef.withMods(vdef.mods &~ PrivateLocalParamAccessor | Param)
837+
val originalParams = impl.constr.vparamss.headOption.getOrElse(Nil).map(resetFlags)
838+
addEvidenceParams(addEvidenceParams(ddef, originalParams), evidenceParams)
839+
}
840+
}
841+
else
842+
constr1 = addEvidenceParams(constr1, evidenceParams)
843+
825844
val icls =
826845
TypeDef(UniqueName.fresh(decoName.toTermName).toTypeName,
827-
cpy.Template(impl)(constr = constr1, body = substThis.transform(impl.body)))
846+
cpy.Template(impl)(constr = constr1, parents = parents1, body = body1))
828847
.withFlags(Implicit)
829848
desugr.println(i"desugar $augmented --> $icls")
830849
classDef(icls)
831850
}
832851

852+
private val substThis = new UntypedTreeMap {
853+
override def transform(tree: Tree)(implicit ctx: Context): Tree = tree match {
854+
case This(Ident(tpnme.EMPTY)) => Ident(nme.SELF).withPos(tree.pos)
855+
case _ => super.transform(tree)
856+
}
857+
}
858+
859+
private val leadingName = new UntypedTreeAccumulator[String] {
860+
override def apply(x: String, tree: Tree)(implicit ctx: Context): String =
861+
if (x.isEmpty)
862+
tree match {
863+
case Select(pre, nme.CONSTRUCTOR) => foldOver(x, pre)
864+
case tree: RefTree if tree.name.isTypeName => tree.name.toString
865+
case tree: TypeDef => tree.name.toString
866+
case tree: Tuple => "Tuple"
867+
case tree: Function => "Function"
868+
case _ => foldOver(x, tree)
869+
}
870+
else x
871+
}
872+
833873
def defTree(tree: Tree)(implicit ctx: Context): Tree = tree match {
834874
case tree: ValDef => valDef(tree)
835875
case tree: TypeDef => if (tree.isClassDef) classDef(tree) else tree

tests/neg/augment.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,8 @@ object augments {
2929

3030
// Specific trait implementations
3131

32-
augment List[Int] {
33-
import java.lang._
32+
augment List[Int] { self => // error: `def` expected
33+
import java.lang._ // error: `def` expected
3434
def maxx = (0 /: this)(_ `max` _)
3535
}
3636

tests/pos/augment.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,10 @@ object augments {
4747
def isSquare: Boolean = implicitly[Eql[T]].eql(this.width, this.height)
4848
}
4949

50+
augment Rectangle[type T](implicit ev: Eql[T]) {
51+
def isNotSquare: Boolean = !implicitly[Eql[T]].eql(this.width, this.height)
52+
}
53+
5054
// Simple generic augments
5155

5256
augment (type T) {
@@ -63,6 +67,10 @@ object augments {
6367
def === (that: T): Boolean = implicitly[Eql[T]].eql(this, that)
6468
}
6569

70+
augment (type T)(implicit ev: Eql[T]) {
71+
def ==== (that: T): Boolean = implicitly[Eql[T]].eql(this, that)
72+
}
73+
6674
augment Rectangle[type T: Eql] extends HasEql[Rectangle[T]] {
6775
def === (that: Rectangle[T]) =
6876
this.x === that.x &&
@@ -97,8 +105,12 @@ object Test extends App {
97105
val r2 = Rectangle(0, 0, 2, 3)
98106
println(r1.isSquare)
99107
println(r2.isSquare)
108+
println(r2.isNotSquare)
109+
println(r1.isNotSquare)
100110
println(r1 === r1)
101111
println(r1 === r2)
112+
println(1 ==== 1)
113+
println(1 ==== 2)
102114
println(List(1, 2, 3).second)
103115
println(List(List(1), List(2, 3)).flattened)
104116
println(List(List(1), List(2, 3)).flattened.maxx)

0 commit comments

Comments
 (0)