Skip to content

Commit d8f9f02

Browse files
authored
Merge pull request #13952 from dotty-staging/fix-13554-alt
move implementation of ordinal in enums to posttyper
2 parents dcd4249 + b2570dc commit d8f9f02

File tree

7 files changed

+55
-13
lines changed

7 files changed

+55
-13
lines changed

compiler/src/dotty/tools/dotc/ast/DesugarEnums.scala

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -179,13 +179,12 @@ object DesugarEnums {
179179
* }
180180
*/
181181
private def enumValueCreator(using Context) = {
182-
val fieldMethods = if isJavaEnum then Nil else ordinalMeth(Ident(nme.ordinalDollar_)) :: Nil
183182
val creator = New(Template(
184183
constr = emptyConstructor,
185184
parents = enumClassRef :: scalaRuntimeDot(tpnme.EnumValue) :: Nil,
186185
derived = Nil,
187186
self = EmptyValDef,
188-
body = fieldMethods
187+
body = Nil
189188
).withAttachment(ExtendsSingletonMirror, ()))
190189
DefDef(nme.DOLLAR_NEW,
191190
List(List(param(nme.ordinalDollar_, defn.IntType), param(nme.nameDollar, defn.StringType))),
@@ -270,8 +269,6 @@ object DesugarEnums {
270269
def param(name: TermName, typ: Type)(using Context): ValDef = param(name, TypeTree(typ))
271270
def param(name: TermName, tpt: Tree)(using Context): ValDef = ValDef(name, tpt, EmptyTree).withFlags(Param)
272271

273-
private def isJavaEnum(using Context): Boolean = enumClass.derivesFrom(defn.JavaEnumClass)
274-
275272
def ordinalMeth(body: Tree)(using Context): DefDef =
276273
DefDef(nme.ordinal, Nil, TypeTree(defn.IntType), body).withAddedFlags(Synthetic)
277274

@@ -290,10 +287,8 @@ object DesugarEnums {
290287
expandSimpleEnumCase(name, mods, definesLookups, span)
291288
else {
292289
val (tag, scaffolding) = nextOrdinal(name, CaseKind.Object, definesLookups)
293-
val impl1 = cpy.Template(impl)(
294-
parents = impl.parents :+ scalaRuntimeDot(tpnme.EnumValue),
295-
body = if isJavaEnum then Nil else ordinalMethLit(tag) :: Nil
296-
).withAttachment(ExtendsSingletonMirror, ())
290+
val impl1 = cpy.Template(impl)(parents = impl.parents :+ scalaRuntimeDot(tpnme.EnumValue), body = Nil)
291+
.withAttachment(ExtendsSingletonMirror, ())
297292
val vdef = ValDef(name, TypeTree(), New(impl1)).withMods(mods.withAddedFlags(EnumValue, span))
298293
flatTree(vdef :: scaffolding).withSpan(span)
299294
}

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -733,6 +733,7 @@ class Definitions {
733733
@tu lazy val NoneModule: Symbol = requiredModule("scala.None")
734734

735735
@tu lazy val EnumClass: ClassSymbol = requiredClass("scala.reflect.Enum")
736+
@tu lazy val Enum_ordinal: Symbol = EnumClass.requiredMethod(nme.ordinal)
736737

737738
@tu lazy val EnumValueSerializationProxyClass: ClassSymbol = requiredClass("scala.runtime.EnumValueSerializationProxy")
738739
@tu lazy val EnumValueSerializationProxyConstructor: TermSymbol =

compiler/src/dotty/tools/dotc/transform/SyntheticMembers.scala

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
6666
myCaseSymbols = defn.caseClassSynthesized
6767
myCaseModuleSymbols = myCaseSymbols.filter(_ ne defn.Any_equals)
6868
myEnumValueSymbols = List(defn.Product_productPrefix)
69-
myNonJavaEnumValueSymbols = myEnumValueSymbols :+ defn.Any_toString
69+
myNonJavaEnumValueSymbols = myEnumValueSymbols :+ defn.Any_toString :+ defn.Enum_ordinal
7070
}
7171

7272
def valueSymbols(using Context): List[Symbol] = { initSymbols; myValueSymbols }
@@ -132,6 +132,17 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
132132
else // assume owner is `val Foo = new MyEnum { def ordinal = 0 }`
133133
Literal(Constant(clazz.owner.name.toString))
134134

135+
def ordinalRef: Tree =
136+
if isSimpleEnumValue then // owner is `def $new(_$ordinal: Int, $name: String) = new MyEnum { ... }`
137+
ref(clazz.owner.paramSymss.head.find(_.name == nme.ordinalDollar_).get)
138+
else // val CaseN = new MyEnum { ... def ordinal: Int = n }
139+
val vdef = clazz.owner
140+
val parentEnum = vdef.owner.companionClass
141+
val children = parentEnum.children.zipWithIndex
142+
val candidate: Option[Int] = children.collectFirst { case (child, idx) if child == vdef => idx }
143+
assert(candidate.isDefined, i"could not find child for $vdef")
144+
Literal(Constant(candidate.get))
145+
135146
def toStringBody(vrefss: List[List[Tree]]): Tree =
136147
if (clazz.is(ModuleClass)) ownName
137148
else if (isNonJavaEnumValue) identifierRef
@@ -143,6 +154,7 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
143154
case nme.toString_ => toStringBody(vrefss)
144155
case nme.equals_ => equalsBody(vrefss.head.head)
145156
case nme.canEqual_ => canEqualBody(vrefss.head.head)
157+
case nme.ordinal => ordinalRef
146158
case nme.productArity => Literal(Constant(accessors.length))
147159
case nme.productPrefix if isEnumValue => nameRef
148160
case nme.productPrefix => ownName

compiler/src/dotty/tools/dotc/typer/Checking.scala

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1217,8 +1217,19 @@ trait Checking {
12171217
/** 1. Check that all case classes that extend `scala.reflect.Enum` are `enum` cases
12181218
* 2. Check that parameterised `enum` cases do not extend java.lang.Enum.
12191219
* 3. Check that only a static `enum` base class can extend java.lang.Enum.
1220+
* 4. Check that user does not implement an `ordinal` method in the body of an enum class.
12201221
*/
12211222
def checkEnum(cdef: untpd.TypeDef, cls: Symbol, firstParent: Symbol)(using Context): Unit = {
1223+
def existingDef(sym: Symbol, clazz: ClassSymbol)(using Context): Symbol = // adapted from SyntheticMembers
1224+
val existing = sym.matchingMember(clazz.thisType)
1225+
if existing != sym && !existing.is(Deferred) then existing else NoSymbol
1226+
def checkExistingOrdinal(using Context) =
1227+
val decl = existingDef(defn.Enum_ordinal, cls.asClass)
1228+
if decl.exists then
1229+
if decl.owner == cls then
1230+
report.error(em"the ordinal method of enum $cls can not be defined by the user", decl.srcPos)
1231+
else
1232+
report.error(em"enum $cls can not inherit the concrete ordinal method of ${decl.owner}", cdef.srcPos)
12221233
def isEnumAnonCls =
12231234
cls.isAnonymousClass
12241235
&& cls.owner.isTerm
@@ -1238,6 +1249,8 @@ trait Checking {
12381249
// this test allows inheriting from `Enum` by hand;
12391250
// see enum-List-control.scala.
12401251
report.error(ClassCannotExtendEnum(cls, firstParent), cdef.srcPos)
1252+
if cls.isEnumClass && !isJavaEnum then
1253+
checkExistingOrdinal
12411254
}
12421255

12431256
/** Check that the firstParent for an enum case derives from the declaring enum class, if not, adds it as a parent

tests/neg/enumsLabel-singleimpl.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,13 @@
11
enum Ordinalled {
22

3-
case A // error: method ordinal of type => Int needs `override` modifier
3+
case A
44

5-
def ordinal: Int = -1
5+
def ordinal: Int = -1 // error: the ordinal method of enum class Ordinalled can not be defined by the user
66

77
}
88

99
trait HasOrdinal { def ordinal: Int = 23 }
1010

11-
enum MyEnum extends HasOrdinal {
12-
case Foo // error: method ordinal of type => Int needs `override` modifier
11+
enum MyEnum extends HasOrdinal { // error: enum class MyEnum can not inherit the concrete ordinal method of trait HasOrdinal
12+
case Foo
1313
}

tests/pos/i13554.scala

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
object StatusCode:
2+
class Matcher
3+
4+
enum StatusCode(m: StatusCode.Matcher):
5+
case InternalServerError extends StatusCode(???)
6+

tests/pos/i13554a.scala

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
object StatusCode:
2+
enum Matcher:
3+
case ServerError extends Matcher
4+
end Matcher
5+
end StatusCode
6+
7+
enum StatusCode(code: Int, m: StatusCode.Matcher):
8+
case InternalServerError extends StatusCode(500, StatusCode.Matcher.ServerError)
9+
end StatusCode
10+
11+
object Main {
12+
def main(args: Array[String]): Unit = {
13+
println(StatusCode.InternalServerError)
14+
}
15+
}

0 commit comments

Comments
 (0)