Skip to content

Commit e09e1cc

Browse files
committed
Allow collectve parameters for extension methods
1 parent 13da159 commit e09e1cc

File tree

5 files changed

+125
-8
lines changed

5 files changed

+125
-8
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -824,12 +824,37 @@ object desugar {
824824
* object name extends parents { self => body }
825825
*
826826
* to:
827+
*
827828
* <module> val name: name$ = New(name$)
828829
* <module> final class name$ extends parents { self: name.type => body }
830+
*
831+
* Special case for extension methods with collective parameters. Expand:
832+
*
833+
* given object name[tparams](x: T) extends parents { self => bpdy }
834+
*
835+
* to:
836+
*
837+
* given object name extends parents { self => body' }
838+
*
839+
* where every definition in `body` is expanded to an extension method
840+
* taking type parameters `tparams` and a leading paramter `(x: T)`.
841+
* See: makeExtensionDef
829842
*/
830843
def moduleDef(mdef: ModuleDef)(implicit ctx: Context): Tree = {
831844
val impl = mdef.impl
832845
val mods = mdef.mods
846+
impl.constr match {
847+
case DefDef(_, tparams, (vparams @ (vparam :: Nil)) :: _, _, _) =>
848+
assert(mods.is(Given))
849+
return moduleDef(
850+
cpy.ModuleDef(mdef)(
851+
mdef.name,
852+
cpy.Template(impl)(
853+
constr = emptyConstructor,
854+
body = impl.body.map(makeExtensionDef(_, tparams, vparams)))))
855+
case _ =>
856+
}
857+
833858
val moduleName = normalizeName(mdef, impl).asTermName
834859
def isEnumCase = mods.isEnumCase
835860

@@ -869,6 +894,36 @@ object desugar {
869894
}
870895
}
871896

897+
/** Given tpe parameters `Ts` (possibly empty) and a leading value parameter `(x: T)`,
898+
* map a method definition
899+
*
900+
* def foo [Us] paramss ...
901+
*
902+
* to
903+
*
904+
* <extension> def foo[Ts ++ Us](x: T) parammss ...
905+
*
906+
* If the given member `mdef` is not of this form, flag it as an error.
907+
*/
908+
909+
def makeExtensionDef(mdef: Tree, tparams: List[TypeDef], leadingParams: List[ValDef]) given (ctx: Context): Tree = {
910+
val allowed = "allowed here, since collective parameters are given"
911+
mdef match {
912+
case mdef: DefDef =>
913+
if (mdef.mods.is(Extension)) {
914+
ctx.error(em"No extension method $allowed", mdef.sourcePos)
915+
mdef
916+
}
917+
else cpy.DefDef(mdef)(tparams = tparams ++ mdef.tparams, vparamss = leadingParams :: Nil)
918+
.withFlags(Extension)
919+
case mdef: Import =>
920+
mdef
921+
case mdef =>
922+
ctx.error(em"Only methods $allowed", mdef.sourcePos)
923+
mdef
924+
}
925+
}
926+
872927
/** Transforms
873928
*
874929
* <mods> type $T >: Low <: Hi

compiler/src/dotty/tools/dotc/parsing/Parsers.scala

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2869,11 +2869,20 @@ object Parsers {
28692869
/** GivenDef ::= [id] [DefTypeParamClause] GivenBody
28702870
* GivenBody ::= [‘as ConstrApp {‘,’ ConstrApp }] {GivenParamClause} [TemplateBody]
28712871
* | ‘as’ Type {GivenParamClause} ‘=’ Expr
2872+
* | ‘(’ DefParam ‘)’ TemplateBody
28722873
*/
28732874
def instanceDef(newStyle: Boolean, start: Offset, mods: Modifiers, instanceMod: Mod) = atSpan(start, nameStart) {
28742875
var mods1 = addMod(mods, instanceMod)
28752876
val name = if (isIdent && (!newStyle || in.name != nme.as)) ident() else EmptyTermName
28762877
val tparams = typeParamClauseOpt(ParamOwner.Def)
2878+
var leadingParamss =
2879+
if (in.token == LPAREN)
2880+
try paramClause(prefix = true) :: Nil
2881+
finally {
2882+
newLineOptWhenFollowedBy(LBRACE)
2883+
if (in.token != LBRACE) syntaxErrorOrIncomplete("`{' expected")
2884+
}
2885+
else Nil
28772886
val parents =
28782887
if (!newStyle && in.token == FOR || isIdent(nme.as)) { // for the moment, accept both `given for` and `given as`
28792888
in.nextToken()
@@ -2889,11 +2898,15 @@ object Parsers {
28892898
}
28902899
else {
28912900
newLineOptWhenFollowedBy(LBRACE)
2892-
val tparams1 = tparams.map(tparam => tparam.withMods(tparam.mods | PrivateLocal))
2893-
val vparamss1 = vparamss.map(_.map(vparam =>
2894-
vparam.withMods(vparam.mods &~ Param | ParamAccessor | PrivateLocal)))
2901+
val (tparams1, vparamss1) =
2902+
if (leadingParamss.nonEmpty)
2903+
(tparams, leadingParamss)
2904+
else
2905+
(tparams.map(tparam => tparam.withMods(tparam.mods | PrivateLocal)),
2906+
vparamss.map(_.map(vparam =>
2907+
vparam.withMods(vparam.mods &~ Param | ParamAccessor | PrivateLocal))))
28952908
val templ = templateBodyOpt(makeConstructor(tparams1, vparamss1), parents, Nil)
2896-
if (tparams.isEmpty && vparamss.isEmpty) ModuleDef(name, templ)
2909+
if (tparams.isEmpty && vparamss1.isEmpty || leadingParamss.nonEmpty) ModuleDef(name, templ)
28972910
else TypeDef(name.toTypeName, templ)
28982911
}
28992912
finalizeDef(instDef, mods1, start)

docs/docs/internals/syntax.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,7 @@ EnumDef ::= id ClassConstr InheritClauses EnumBody
387387
GivenDef ::= [id] [DefTypeParamClause] GivenBody
388388
GivenBody ::= [‘as ConstrApp {‘,’ ConstrApp }] {GivenParamClause} [TemplateBody]
389389
| ‘as’ Type {GivenParamClause} ‘=’ Expr
390+
| ‘(’ DefParam ‘)’ TemplateBody
390391
Template ::= InheritClauses [TemplateBody] Template(constr, parents, self, stats)
391392
InheritClauses ::= [‘extends’ ConstrApps] [‘derives’ QualId {‘,’ QualId}]
392393
ConstrApps ::= ConstrApp {‘with’ ConstrApp}

docs/docs/reference/contextual/extension-methods.md

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ So `circle.circumference` translates to `CircleOps.circumference(circle)`, provi
8080

8181
### Given Instances for Extension Methods
8282

83-
Given instances that define extension methods can also be defined without a `for` clause. E.g.,
83+
Given instances that define extension methods can also be defined without an `as` clause. E.g.,
8484

8585
```scala
8686
given StringOps {
@@ -94,8 +94,33 @@ given {
9494
def (xs: List[T]) second[T] = xs.tail.head
9595
}
9696
```
97-
If such given instances are anonymous (as in the second clause), their name is synthesized from the name
98-
of the first defined extension method.
97+
If such given instances are anonymous (as in the second clause), their name is synthesized from the name of the first defined extension method.
98+
99+
### Given Instances with Collective Parameters
100+
101+
If a given instance has several extension methods one can pull out the left parameter section
102+
as well as any type parameters of these extension methods into the given instance itself.
103+
For instance, here is a given instance with two extension methods.
104+
```scala
105+
given ListOps {
106+
def (xs: List[T]) second[T]: T = xs.tail.head
107+
def (xs: List[T]) third[T]: T = xs.tail.tail.head
108+
}
109+
```
110+
The repetition in the parameters can be avoided by moving the parameters into the given instance itself. The following version is a shorthand for the code above.
111+
```scala
112+
given ListOps[T](xs: List[T]) {
113+
def second: T = xs.tail.head
114+
def third: T = xs.tail.tail.head
115+
}
116+
```
117+
This syntax just adds convenience at the definition site. Applications of such extension methods are exactly the same as if their parameters were repeated in each extension method.
118+
Examples:
119+
```scala
120+
val xs = List(1, 2, 3)
121+
xs.second[Int]
122+
ListOps.third[T](xs)
123+
```
99124

100125
### Operators
101126

@@ -143,4 +168,6 @@ to the [current syntax](../../internals/syntax.md).
143168
```
144169
DefSig ::= ...
145170
| ‘(’ DefParam ‘)’ [nl] id [DefTypeParamClause] DefParamClauses
171+
GivenBody ::= ...
172+
| ‘(’ DefParam ‘)’ TemplateBody
146173
```

tests/run/extmethods2.scala

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,25 @@ object Test extends App {
1414
}
1515

1616
test given TC()
17-
}
17+
18+
object A {
19+
given ListOps[T](xs: List[T]) {
20+
def second: T = xs.tail.head
21+
def third: T = xs.tail.tail.head
22+
}
23+
given (xs: List[Int]) {
24+
def prod = (1 /: xs)(_ * _)
25+
}
26+
}
27+
28+
object B {
29+
import given A._
30+
val xs = List(1, 2, 3)
31+
assert(xs.second[Int] == 2)
32+
assert(xs.third == 3)
33+
assert(A.ListOps.second[Int](xs) == 2)
34+
assert(A.ListOps.third(xs) == 3)
35+
assert(xs.prod == 6)
36+
}
37+
}
38+

0 commit comments

Comments
 (0)