Skip to content

Drop old extension method syntax #9476

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Aug 18, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 3 additions & 65 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -837,36 +837,10 @@ object desugar {
*
* <module> val name: name$ = New(name$)
* <module> final class name$ extends parents { self: name.type => body }
*
* Special case for extension methods with collective parameters. Expand:
*
* given object name[tparams](x: T) extends parents { self => bpdy }
*
* to:
*
* given object name extends parents { self => body' }
*
* where every definition in `body` is expanded to an extension method
* taking type parameters `tparams` and a leading paramter `(x: T)`.
* See: collectiveExtensionBody
* TODO: drop this part
*/
def moduleDef(mdef: ModuleDef)(using Context): Tree = {
val impl = mdef.impl
val mods = mdef.mods
impl.constr match {
case DefDef(_, tparams, vparamss @ (vparam :: Nil) :: givenParamss, _, _) =>
// Transform collective extension
assert(mods.is(Given))
return moduleDef(
cpy.ModuleDef(mdef)(
mdef.name,
cpy.Template(impl)(
constr = emptyConstructor,
body = collectiveExtensionBody(impl.body, tparams, vparamss))))
case _ =>
}

val moduleName = normalizeName(mdef, impl).asTermName
def isEnumCase = mods.isEnumCase

Expand Down Expand Up @@ -921,46 +895,10 @@ object desugar {
vparams1 :: ext.vparamss ::: vparamss1
case _ =>
ext.vparamss ++ mdef.vparamss
).withMods(mdef.mods | Extension)
).withMods(mdef.mods | ExtensionMethod)
)
}

/** Transform the statements of a collective extension
* @param stats the original statements as they were parsed
* @param tparams the collective type parameters
* @param vparamss the collective value parameters, consisting
* of a single leading value parameter, followed by
* zero or more context parameter clauses
*
* Note: It is already assured by Parser.checkExtensionMethod that all
* statements conform to requirements.
*
* Each method in stats is transformed into an extension method. Example:
*
* extension on [Ts](x: T)(using C):
* def f(y: T) = ???
* def g(z: T) = f(z)
*
* is turned into
*
* extension:
* <extension> def f[Ts](x: T)(using C)(y: T) = ???
* <extension> def g[Ts](x: T)(using C)(z: T) = f(z)
*/
def collectiveExtensionBody(stats: List[Tree],
tparams: List[TypeDef], vparamss: List[List[ValDef]])(using Context): List[Tree] =
for stat <- stats yield
stat match
case mdef: DefDef =>
cpy.DefDef(mdef)(
name = mdef.name.toExtensionName,
tparams = tparams ++ mdef.tparams,
vparamss = vparamss ::: mdef.vparamss,
).withMods(mdef.mods | Extension)
case mdef =>
mdef
end collectiveExtensionBody

/** Transforms
*
* <mods> type $T >: Low <: Hi
Expand Down Expand Up @@ -997,7 +935,7 @@ object desugar {
report.error(IllegalRedefinitionOfStandardKind(kind, name), errPos)
name = name.errorName
}
if name.isExtensionName && !mdef.mods.is(Extension) then
if name.isExtensionName && !mdef.mods.is(ExtensionMethod) then
report.error(em"illegal method name: $name may not start with `extension_`", errPos)
name
}
Expand All @@ -1008,7 +946,7 @@ object desugar {
case impl: Template =>
if impl.parents.isEmpty then
impl.body.find {
case dd: DefDef if dd.mods.is(Extension) => true
case dd: DefDef if dd.mods.is(ExtensionMethod) => true
case _ => false
}
match
Expand Down
12 changes: 3 additions & 9 deletions compiler/src/dotty/tools/dotc/ast/Positioned.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ import util.Spans._
import util.{SourceFile, NoSource, SourcePosition, SrcPos}
import core.Contexts._
import core.Decorators._
import core.Flags.{JavaDefined, Extension}
import core.Flags.{JavaDefined, ExtensionMethod}
import core.StdNames.nme
import ast.Trees.mods
import annotation.constructorOnly
Expand Down Expand Up @@ -152,11 +152,6 @@ abstract class Positioned(implicit @constructorOnly src: SourceFile) extends Src
}
}

/** A hook that can be overridden if overlap checking in `checkPos` should be
* disabled for this node.
*/
def disableOverlapChecks = false

/** Check that all positioned items in this tree satisfy the following conditions:
* - Parent spans contain child spans
* - If item is a non-empty tree, it has a position
Expand All @@ -179,7 +174,7 @@ abstract class Positioned(implicit @constructorOnly src: SourceFile) extends Src
s"position error: position not set for $tree # ${tree.uniqueId}")
case _ =>
}
if (nonOverlapping && !disableOverlapChecks) {
if nonOverlapping then
this match {
case _: XMLBlock =>
// FIXME: Trees generated by the XML parser do not satisfy `checkPos`
Expand All @@ -197,7 +192,6 @@ abstract class Positioned(implicit @constructorOnly src: SourceFile) extends Src
}
lastPositioned = p
lastSpan = p.span
}
p.checkPos(nonOverlapping)
case m: untpd.Modifiers =>
m.annotations.foreach(check)
Expand All @@ -212,7 +206,7 @@ abstract class Positioned(implicit @constructorOnly src: SourceFile) extends Src
// Leave out tparams, they are copied with wrong positions from parent class
check(tree.mods)
check(tree.vparamss)
case tree: DefDef if tree.mods.is(Extension) =>
case tree: DefDef if tree.mods.is(ExtensionMethod) =>
tree.vparamss match {
case vparams1 :: vparams2 :: rest if !isLeftAssoc(tree.name) =>
check(tree.tparams)
Expand Down
5 changes: 0 additions & 5 deletions compiler/src/dotty/tools/dotc/ast/Trees.scala
Original file line number Diff line number Diff line change
Expand Up @@ -757,11 +757,6 @@ object Trees {
assert(tpt != genericEmptyTree)
def unforced: LazyTree[T] = preRhs
protected def force(x: Tree[T @uncheckedVariance]): Unit = preRhs = x

override def disableOverlapChecks = rawMods.is(Extension)
// disable order checks for extension methods as long as we parse
// type parameters both before and after the leading parameter section.
// TODO drop this once syntax of type parameters has settled.
}

/** mods class name template or
Expand Down
5 changes: 2 additions & 3 deletions compiler/src/dotty/tools/dotc/core/Flags.scala
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,7 @@ object Flags {
val (_, HasDefault @ _, _) = newFlags(27, "<hasdefault>")

/** An extension method, or a collective extension instance */
val (_, Extension @ _, _) = newFlags(28, "<extension>")
val (Extension @ _, ExtensionMethod @ _, _) = newFlags(28, "<extension>")

/** An inferable (`given`) parameter */
val (Given @ _, _, _) = newFlags(29, "given")
Expand Down Expand Up @@ -495,7 +495,7 @@ object Flags {

/** Flags that can apply to a module val */
val RetainedModuleValFlags: FlagSet = RetainedModuleValAndClassFlags |
Override | Final | Method | Implicit | Given | Lazy | Extension |
Override | Final | Method | Implicit | Given | Lazy |
Accessor | AbsOverride | StableRealizable | Captured | Synchronized | Erased

/** Flags that can apply to a module class */
Expand Down Expand Up @@ -527,7 +527,6 @@ object Flags {
val DeferredOrTypeParam: FlagSet = Deferred | TypeParam // type symbols without right-hand sides
val EnumValue: FlagSet = Enum | StableRealizable // A Scala enum value
val StableOrErased: FlagSet = Erased | StableRealizable // Assumed to be pure
val ExtensionMethod: FlagSet = Extension | Method
val FinalOrInline: FlagSet = Final | Inline
val FinalOrModuleClass: FlagSet = Final | ModuleClass // A module class or a final class
val EffectivelyFinalFlags: FlagSet = Final | Private
Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/core/SymDenotations.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1104,7 +1104,7 @@ object SymDenotations {
* provided the extension method appears in the same class.
*/
final def enclosingExtensionMethod(using Context): Symbol =
if this.isAllOf(ExtensionMethod) then symbol
if this.is(ExtensionMethod) then symbol
else if this.isClass then NoSymbol
else if this.exists then owner.enclosingExtensionMethod
else NoSymbol
Expand Down
63 changes: 12 additions & 51 deletions compiler/src/dotty/tools/dotc/parsing/Parsers.scala
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,6 @@ object Parsers {
|| defIntroTokens.contains(in.token)
|| allowedMods.contains(in.token)
|| in.isSoftModifierInModifierPosition && !excludedSoftModifiers.contains(in.name)
|| isIdent(nme.extension) && followingIsOldExtension()

def isStatSep: Boolean = in.isNewLine || in.token == SEMI

Expand Down Expand Up @@ -919,23 +918,10 @@ object Parsers {
skipParams()
lookahead.isIdent(nme.as)

def followingIsNewExtension() =
def followingIsExtension() =
val next = in.lookahead.token
next == LBRACKET || next == LPAREN

def followingIsOldExtension() =
val lookahead = in.LookaheadScanner()
lookahead.nextToken()
if lookahead.isIdent && !lookahead.isIdent(nme.on) then
lookahead.nextToken()
if lookahead.isNewLine then
lookahead.nextToken()
lookahead.isIdent(nme.on)
|| lookahead.token == LBRACE
|| lookahead.token == COLON

def followingIsExtension() = followingIsOldExtension() || followingIsNewExtension()

/* --------- OPERAND/OPERATOR STACK --------------------------------------- */

var opStack: List[OpInfo] = Nil
Expand Down Expand Up @@ -1311,7 +1297,7 @@ object Parsers {
case stat: MemberDef if !stat.name.isEmpty =>
if stat.name == nme.CONSTRUCTOR then in.token == THIS
else in.isIdent && in.name == stat.name.toTermName
case ModuleDef(_, Template(_, Nil, _, _)) | ExtMethods(_, _, _) =>
case ExtMethods(_, _, _) =>
in.token == IDENTIFIER && in.name == nme.extension
case PackageDef(pid: RefTree, _) =>
in.isIdent && in.name == pid.name
Expand Down Expand Up @@ -3305,7 +3291,7 @@ object Parsers {
def extParamss() =
try paramClause(0, prefix = true) :: Nil
finally
mods1 = addFlag(mods, Extension)
mods1 = addFlag(mods, ExtensionMethod)
if in.token == DOT then in.nextToken()
else
isInfix = true
Expand All @@ -3319,14 +3305,14 @@ object Parsers {
(Nil, Nil)
val ident = termIdent()
var name = ident.name.asTermName
if mods1.is(Extension) then name = name.toExtensionName
if mods1.is(ExtensionMethod) then name = name.toExtensionName
if isInfix && !name.isOperatorName then
val infixAnnot = Apply(wrapNew(scalaAnnotationDot(tpnme.infix)), Nil)
.withSpan(Span(start, start))
mods1 = mods1.withAddedAnnotation(infixAnnot)
val tparams =
if in.token == LBRACKET then
if mods1.is(Extension) then syntaxError("no type parameters allowed here")
if mods1.is(ExtensionMethod) then syntaxError("no type parameters allowed here")
typeParamClause(ParamOwner.Def)
else leadingTparams
val vparamss = paramClauses() match
Expand Down Expand Up @@ -3465,11 +3451,8 @@ object Parsers {
case GIVEN =>
givenDef(start, mods, atSpan(in.skipToken()) { Mod.Given() })
case _ =>
if isIdent(nme.extension) && followingIsOldExtension() then
extensionDef(start, mods)
else
syntaxErrorOrIncomplete(ExpectedStartOfTopLevelDefinition())
EmptyTree
syntaxErrorOrIncomplete(ExpectedStartOfTopLevelDefinition())
EmptyTree
}

/** ClassDef ::= id ClassConstr TemplateOpt
Expand Down Expand Up @@ -3563,9 +3546,9 @@ object Parsers {
def checkExtensionMethod(tparams: List[Tree],
vparamss: List[List[Tree]], stat: Tree): Unit = stat match {
case stat: DefDef =>
if stat.mods.is(Extension) && vparamss.nonEmpty then
if stat.mods.is(ExtensionMethod) && vparamss.nonEmpty then
syntaxError(i"no extension method allowed here since leading parameter was already given", stat.span)
else if !stat.mods.is(Extension) && vparamss.isEmpty then
else if !stat.mods.is(ExtensionMethod) && vparamss.isEmpty then
syntaxError(i"an extension method is required here", stat.span)
else if tparams.nonEmpty && stat.tparams.nonEmpty then
syntaxError(i"extension method cannot have type parameters since some were already given previously",
Expand Down Expand Up @@ -3614,28 +3597,6 @@ object Parsers {
finalizeDef(gdef, mods1, start)
}

/** ExtensionDef ::= [id] [‘on’ ExtParamClause {UsingParamClause}] TemplateBody
*/
def extensionDef(start: Offset, mods: Modifiers): ModuleDef =
in.nextToken()
val nameOffset = in.offset
val name = if isIdent && !isIdent(nme.on) then ident() else EmptyTermName
val (tparams, vparamss, extensionFlag) =
if isIdent(nme.on) then
in.nextToken()
val tparams = typeParamClauseOpt(ParamOwner.Def)
val extParams = paramClause(0, prefix = true)
val givenParamss = paramClauses(givenOnly = true)
(tparams, extParams :: givenParamss, Extension)
else
(Nil, Nil, EmptyFlags)
possibleTemplateStart()
if !in.isNestedStart then syntaxError("Extension without extension methods")
val templ = templateBodyOpt(makeConstructor(tparams, vparamss), Nil, Nil)
templ.body.foreach(checkExtensionMethod(tparams, vparamss, _))
val edef = atSpan(start, nameOffset, in.offset)(ModuleDef(name, templ))
finalizeDef(edef, addFlag(mods, Given | extensionFlag), start)

/** Extension ::= ‘extension’ [DefTypeParamClause] ‘(’ DefParam ‘)’
* {UsingParamClause} ExtMethods
*/
Expand Down Expand Up @@ -3816,7 +3777,7 @@ object Parsers {
stats ++= importClause(IMPORT, mkImport(outermost))
else if (in.token == EXPORT)
stats ++= importClause(EXPORT, Export.apply)
else if isIdent(nme.extension) && followingIsNewExtension() then
else if isIdent(nme.extension) && followingIsExtension() then
stats += extension()
else if isDefIntro(modifierTokens)
stats +++= defOrDcl(in.offset, defAnnotsMods(modifierTokens))
Expand Down Expand Up @@ -3870,7 +3831,7 @@ object Parsers {
stats ++= importClause(IMPORT, mkImport())
else if (in.token == EXPORT)
stats ++= importClause(EXPORT, Export.apply)
else if isIdent(nme.extension) && followingIsNewExtension() then
else if isIdent(nme.extension) && followingIsExtension() then
stats += extension()
else if (isDefIntro(modifierTokensOrCase))
stats +++= defOrDcl(in.offset, defAnnotsMods(modifierTokens))
Expand Down Expand Up @@ -3952,7 +3913,7 @@ object Parsers {
stats += expr(Location.InBlock)
else if in.token == IMPLICIT && !in.inModifierPosition() then
stats += closure(in.offset, Location.InBlock, modifiers(BitSet(IMPLICIT)))
else if isIdent(nme.extension) && followingIsNewExtension() then
else if isIdent(nme.extension) && followingIsExtension() then
stats += extension()
else if isDefIntro(localModifierTokens, excludedSoftModifiers = Set(nme.`opaque`)) then
stats +++= localDef(in.offset)
Expand Down
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/printing/RefinedPrinter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -745,7 +745,7 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
protected def nameIdText[T >: Untyped](tree: NameTree[T], dropExtension: Boolean = false): Text =
if (tree.hasType && tree.symbol.exists) {
var str = nameString(tree.symbol)
if tree.symbol.isExtensionMethod && dropExtension && str.startsWith("extension_") then
if tree.symbol.is(ExtensionMethod) && dropExtension && str.startsWith("extension_") then
str = str.drop("extension_".length)
tree match {
case tree: RefTree => withPos(str, tree.sourcePos)
Expand Down Expand Up @@ -788,7 +788,7 @@ class RefinedPrinter(_ctx: Context) extends PlainPrinter(_ctx) {
import untpd._
dclTextOr(tree) {
val defKeyword = modText(tree.mods, tree.symbol, keywordStr("def"), isType = false)
val isExtension = tree.hasType && tree.symbol.isExtensionMethod
val isExtension = tree.hasType && tree.symbol.is(ExtensionMethod)
withEnclosingDef(tree) {
val (prefix, vparamss) =
if isExtension then
Expand Down
10 changes: 0 additions & 10 deletions compiler/src/dotty/tools/dotc/transform/SymUtils.scala
Original file line number Diff line number Diff line change
Expand Up @@ -213,16 +213,6 @@ object SymUtils {
def isTypeSplice(using Context): Boolean =
self == defn.QuotedType_splice

/** Is symbol an extension method? Accessors are excluded since
* after the getters phase collective extension objects become accessors
*/
def isExtensionMethod(using Context): Boolean =
self.isAllOf(ExtensionMethod, butNot = Accessor)

/** Is symbol the module class of a collective extension object? */
def isCollectiveExtensionClass(using Context): Boolean =
self.is(ModuleClass) && self.sourceModule.is(Extension) && !self.sourceModule.isExtensionMethod

def isScalaStatic(using Context): Boolean =
self.hasAnnotation(defn.ScalaStaticAnnot)

Expand Down
4 changes: 2 additions & 2 deletions compiler/src/dotty/tools/dotc/typer/Applications.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2108,8 +2108,8 @@ trait Applications extends Compatibility {
}
def isExtension(tree: Tree): Boolean = methPart(tree) match {
case Inlined(call, _, _) => isExtension(call)
case tree @ Select(qual, nme.apply) => tree.symbol.is(Extension) || isExtension(qual)
case tree => tree.symbol.is(Extension)
case tree @ Select(qual, nme.apply) => tree.symbol.is(ExtensionMethod) || isExtension(qual)
case tree => tree.symbol.is(ExtensionMethod)
}
if (!isExtension(app))
report.error(em"not an extension method: $methodRef", receiver.srcPos)
Expand Down
Loading