diff --git a/compiler/src/dotty/tools/dotc/ast/Desugar.scala b/compiler/src/dotty/tools/dotc/ast/Desugar.scala index 7121f7e7370d..e0beb323da51 100644 --- a/compiler/src/dotty/tools/dotc/ast/Desugar.scala +++ b/compiler/src/dotty/tools/dotc/ast/Desugar.scala @@ -736,7 +736,19 @@ object desugar { companionDefs(anyRef, applyMeths ::: unapplyMeth :: toStringMeth :: companionMembers) } - else if (companionMembers.nonEmpty || companionDerived.nonEmpty || isEnum) + else if (isEnum) + val isSingletonEnum = companionMembers.forall { + case _ : PatDef => true + case _ : ModuleDef => true + case _ => false + } + val enumCompClass = + if (isSingletonEnum) defn.SingletonEnumCompanionClass.typeRef + else defn.EnumCompanionClass.typeRef + val clsWithArgs = appliedTypeTree(Ident(className), impliedTparams.map(_ => WildcardTypeBoundsTree())) + val parent = appliedTypeTree(ref(enumCompClass), clsWithArgs :: Nil) + companionDefs(parent, companionMembers) + else if (companionMembers.nonEmpty || companionDerived.nonEmpty) companionDefs(anyRef, companionMembers) else if (isValueClass) companionDefs(anyRef, Nil) diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index 464c7900a54f..1fe6ad038059 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -740,6 +740,8 @@ class Definitions { @tu lazy val EnumClass: ClassSymbol = requiredClass("scala.reflect.Enum") @tu lazy val Enum_ordinal: Symbol = EnumClass.requiredMethod(nme.ordinal) + @tu lazy val EnumCompanionClass: ClassSymbol = requiredClass("scala.reflect.EnumCompanion") + @tu lazy val SingletonEnumCompanionClass: ClassSymbol = requiredClass("scala.reflect.SingletonEnumCompanion") @tu lazy val EnumValueSerializationProxyClass: ClassSymbol = requiredClass("scala.runtime.EnumValueSerializationProxy") @tu lazy val EnumValueSerializationProxyConstructor: TermSymbol = diff --git a/library/src/scala/reflect/Enum.scala b/library/src/scala/reflect/Enum.scala index 92efa34cf430..ffd38a0754d4 100644 --- a/library/src/scala/reflect/Enum.scala +++ b/library/src/scala/reflect/Enum.scala @@ -5,3 +5,11 @@ package scala.reflect /** A number uniquely identifying a case of an enum */ def ordinal: Int + +/** A base trait of all Scala enum companion definitions */ +@annotation.transparentTrait trait EnumCompanion[E <: Enum] extends AnyRef + +/** A base trait of all Scala singleton enum companion definitions */ +@annotation.transparentTrait trait SingletonEnumCompanion[E <: Enum] extends EnumCompanion[E]: + def values : Array[E] + def valueOf(name : String) : E \ No newline at end of file diff --git a/tests/init/neg/enum-desugared.scala b/tests/init/neg/enum-desugared.scala index eb80f112a06c..9438124efcfc 100644 --- a/tests/init/neg/enum-desugared.scala +++ b/tests/init/neg/enum-desugared.scala @@ -8,7 +8,7 @@ sealed abstract class ErrorMessageID($name: String, _$ordinal: Int) def errorNumber: Int = this.ordinal() - 2 } -object ErrorMessageID { +object ErrorMessageID extends scala.reflect.EnumCompanion[ErrorMessageID]{ final val LazyErrorId = $new(0, "LazyErrorId") final val NoExplanationID = $new(1, "NoExplanationID") diff --git a/tests/run/enum-reflect-companion.scala b/tests/run/enum-reflect-companion.scala new file mode 100644 index 000000000000..7d7c4d59447c --- /dev/null +++ b/tests/run/enum-reflect-companion.scala @@ -0,0 +1,48 @@ +import scala.reflect.{EnumCompanion, SingletonEnumCompanion} +enum Foo1: + case Baz, Bar + +val check1 = summon[Foo1.type <:< SingletonEnumCompanion[Foo1]] + +enum Foo2[T]: + case Baz extends Foo2[1] + case Bar extends Foo2[2] + +val check2 = summon[Foo2.type <:< SingletonEnumCompanion[Foo2[?]]] + +enum Foo3[A, B[_]]: + case Baz extends Foo3[Int, List] + case Bar extends Foo3[Int, List] + +val check3 = summon[Foo3.type <:< SingletonEnumCompanion[Foo3[?, ?]]] + +extension [T <: reflect.Enum](enumCompanion : SingletonEnumCompanion[T]) + def check(arg : T) : Unit = assert(enumCompanion.values.contains(arg)) + +enum Foo4: + case Yes + case No(whyNot: String) + case Skip + +val check4 = summon[Foo4.type <:< EnumCompanion[Foo4]] + +@main def Test : Unit = + Foo3.check(Foo3.Bar) + (Foo3 : AnyRef) match + case _ : SingletonEnumCompanion[?] => + case _ : EnumCompanion[?] => assert(false) + case _ => assert(false) + + (Foo4 : AnyRef) match + case _ : SingletonEnumCompanion[?] => assert(false) + case _ : EnumCompanion[?] => + case _ => assert(false) + +enum Foo5: + case Baz, Bar + +trait Hello +object Foo5 extends Hello + +//TODO: fix implementation so this would work +//val check5 = summon[Foo5.type <:< scala.reflect.EnumCompanion[Foo5] with Hello]