Skip to content

Commit e8ea6e9

Browse files
committed
Allow cross case references when access checking cases
1 parent b4258e2 commit e8ea6e9

File tree

2 files changed

+53
-32
lines changed

2 files changed

+53
-32
lines changed

compiler/src/dotty/tools/dotc/typer/Checking.scala

Lines changed: 45 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -785,42 +785,55 @@ trait Checking {
785785
* @param enumCtx the context immediately enclosing the corresponding enum
786786
*/
787787
private def checkEnumCaseRefsLegal(cdef: TypeDef, enumCtx: Context)(implicit ctx: Context): Unit = {
788-
def check(tree: Tree) = {
789-
// allow access to `sym` if a typedIdent just outside the enclosing enum
790-
// would have produced the same symbol without errors
791-
def allowAccess(name: Name, sym: Symbol): Boolean = {
792-
val testCtx = enumCtx.fresh.setNewTyperState()
793-
val ref = ctx.typer.typedIdent(untpd.Ident(name), WildcardType)(testCtx)
794-
ref.symbol == sym && !testCtx.reporter.hasErrors
788+
789+
def checkCaseOrDefault(stat: Tree, caseCtx: Context) = {
790+
791+
def check(tree: Tree) = {
792+
// allow access to `sym` if a typedIdent just outside the enclosing enum
793+
// would have produced the same symbol without errors
794+
def allowAccess(name: Name, sym: Symbol): Boolean = {
795+
val testCtx = caseCtx.fresh.setNewTyperState()
796+
val ref = ctx.typer.typedIdent(untpd.Ident(name), WildcardType)(testCtx)
797+
ref.symbol == sym && !testCtx.reporter.hasErrors
798+
}
799+
checkRefsLegal(tree, cdef.symbol, allowAccess, "enum case")
795800
}
796-
checkRefsLegal(tree, cdef.symbol, allowAccess, "enum case")
801+
802+
if (stat.symbol.is(Case))
803+
stat match {
804+
case TypeDef(_, Template(DefDef(_, tparams, vparamss, _, _), parents, _, _)) =>
805+
tparams.foreach(check)
806+
vparamss.foreach(_.foreach(check))
807+
parents.foreach(check)
808+
case vdef: ValDef =>
809+
vdef.rhs match {
810+
case Block((clsDef @ TypeDef(_, impl: Template)) :: Nil, _)
811+
if clsDef.symbol.isAnonymousClass =>
812+
impl.parents.foreach(check)
813+
case _ =>
814+
}
815+
case _ =>
816+
}
817+
else if (stat.symbol.is(Module) && stat.symbol.linkedClass.is(Case))
818+
stat match {
819+
case TypeDef(_, impl: Template) =>
820+
for ((defaultGetter @
821+
DefDef(DefaultGetterName(nme.CONSTRUCTOR, _), _, _, _, _)) <- impl.body)
822+
check(defaultGetter.rhs)
823+
case _ =>
824+
}
797825
}
826+
798827
cdef.rhs match {
799828
case impl: Template =>
800-
for (stat <- impl.body)
801-
if (stat.symbol.is(Case))
802-
stat match {
803-
case TypeDef(_, Template(DefDef(_, tparams, vparamss, _, _), parents, _, _)) =>
804-
tparams.foreach(check)
805-
vparamss.foreach(_.foreach(check))
806-
parents.foreach(check)
807-
case vdef: ValDef =>
808-
vdef.rhs match {
809-
case Block((clsDef @ TypeDef(_, impl: Template)) :: Nil, _)
810-
if clsDef.symbol.isAnonymousClass =>
811-
impl.parents.foreach(check)
812-
case _ =>
813-
}
814-
case _ =>
815-
}
816-
else if (stat.symbol.is(Module) && stat.symbol.linkedClass.is(Case))
817-
stat match {
818-
case TypeDef(_, impl: Template) =>
819-
for ((defaultGetter @
820-
DefDef(DefaultGetterName(nme.CONSTRUCTOR, _), _, _, _, _)) <- impl.body)
821-
check(defaultGetter.rhs)
822-
case _ =>
823-
}
829+
def isCase(stat: Tree) = stat match {
830+
case _: ValDef | _: TypeDef => stat.symbol.is(Case)
831+
case _ => false
832+
}
833+
val cases = for (stat <- impl.body if isCase(stat)) yield untpd.Ident(stat.symbol.name)
834+
val caseImport: Import = Import(ref(cdef.symbol), cases)
835+
val caseCtx = enumCtx.importContext(caseImport, caseImport.symbol)
836+
for (stat <- impl.body) checkCaseOrDefault(stat, caseCtx)
824837
case _ =>
825838
}
826839
}

tests/run/enums.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,14 @@ object Test5 {
100100
}
101101
}
102102

103+
object Test6 {
104+
enum Color(val x: Int) {
105+
case Green extends Color(3)
106+
case Red extends Color(2)
107+
case Violet extends Color(Green.x + Red.x)
108+
}
109+
}
110+
103111
object SerializationTest {
104112
object Types extends Enumeration { val X, Y = Value }
105113
class A extends java.io.Serializable { val types = Types.values }

0 commit comments

Comments
 (0)