Skip to content

Commit d96dc21

Browse files
committed
Reject nonsensical refinements and added neg tests
1 parent b7841ae commit d96dc21

File tree

2 files changed

+96
-68
lines changed

2 files changed

+96
-68
lines changed

compiler/src/dotty/tools/dotc/typer/Implicits.scala

Lines changed: 82 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -901,6 +901,16 @@ trait Implicits { self: Typer =>
901901
ref.withSpan(span)
902902
}
903903

904+
private def checkFormal(formal: Type)(implicit ctx: Context): Boolean = {
905+
@tailrec
906+
def loop(tp: Type): Boolean = tp match {
907+
case RefinedType(_, _, _: MethodicType) => false
908+
case RefinedType(parent, _, _) => loop(parent)
909+
case _ => true
910+
}
911+
loop(formal)
912+
}
913+
904914
/** An implied instance for a type of the form `Mirror.Product { type MirroredType = T }`
905915
* where `T` is a generic product type or a case object or an enum case.
906916
*/
@@ -955,85 +965,89 @@ trait Implicits { self: Typer =>
955965
}
956966
}
957967

958-
formal.member(tpnme.MirroredType).info match {
959-
case TypeBounds(mirroredType, _) => mirrorFor(mirroredType)
960-
case other => EmptyTree
961-
}
968+
if (!checkFormal(formal)) EmptyTree
969+
else
970+
formal.member(tpnme.MirroredType).info match {
971+
case TypeBounds(mirroredType, _) if checkFormal(formal) => mirrorFor(mirroredType)
972+
case other => EmptyTree
973+
}
962974
}
963975

964976
/** An implied instance for a type of the form `Mirror.Sum { type MirroredType = T }`
965977
* where `T` is a generic sum type.
966978
*/
967979
lazy val synthesizedSumMirror: SpecialHandler =
968980
(formal, span) => implicit ctx => {
969-
formal.member(tpnme.MirroredType).info match {
970-
case TypeBounds(mirroredType0, _) =>
971-
val mirroredType = mirroredType0.stripTypeVar
972-
if (mirroredType.classSymbol.isGenericSum) {
973-
val cls = mirroredType.classSymbol
974-
val elemLabels = cls.children.map(c => ConstantType(Constant(c.name.toString)))
975-
976-
def solve(sym: Symbol): Type = sym match {
977-
case caseClass: ClassSymbol =>
978-
assert(caseClass.is(Case))
979-
if (caseClass.is(Module))
980-
caseClass.sourceModule.termRef
981-
else {
982-
caseClass.primaryConstructor.info match {
983-
case info: PolyType =>
984-
// Compute the the full child type by solving the subtype constraint
985-
// `C[X1, ..., Xn] <: P`, where
986-
//
987-
// - P is the current `mirroredType`
988-
// - C is the child class, with type parameters X1, ..., Xn
989-
//
990-
// Contravariant type parameters are minimized, all other type parameters are maximized.
991-
def instantiate(implicit ctx: Context) = {
992-
val poly = constrained(info, untpd.EmptyTree)._1
993-
val resType = poly.finalResultType
994-
val target = mirroredType match {
995-
case tp: HKTypeLambda => tp.resultType
996-
case tp => tp
981+
if (!checkFormal(formal)) EmptyTree
982+
else
983+
formal.member(tpnme.MirroredType).info match {
984+
case TypeBounds(mirroredType0, _) =>
985+
val mirroredType = mirroredType0.stripTypeVar
986+
if (mirroredType.classSymbol.isGenericSum) {
987+
val cls = mirroredType.classSymbol
988+
val elemLabels = cls.children.map(c => ConstantType(Constant(c.name.toString)))
989+
990+
def solve(sym: Symbol): Type = sym match {
991+
case caseClass: ClassSymbol =>
992+
assert(caseClass.is(Case))
993+
if (caseClass.is(Module))
994+
caseClass.sourceModule.termRef
995+
else {
996+
caseClass.primaryConstructor.info match {
997+
case info: PolyType =>
998+
// Compute the the full child type by solving the subtype constraint
999+
// `C[X1, ..., Xn] <: P`, where
1000+
//
1001+
// - P is the current `mirroredType`
1002+
// - C is the child class, with type parameters X1, ..., Xn
1003+
//
1004+
// Contravariant type parameters are minimized, all other type parameters are maximized.
1005+
def instantiate(implicit ctx: Context) = {
1006+
val poly = constrained(info, untpd.EmptyTree)._1
1007+
val resType = poly.finalResultType
1008+
val target = mirroredType match {
1009+
case tp: HKTypeLambda => tp.resultType
1010+
case tp => tp
1011+
}
1012+
resType <:< target
1013+
val tparams = poly.paramRefs
1014+
val variances = caseClass.typeParams.map(_.paramVariance)
1015+
val instanceTypes = (tparams, variances).zipped.map((tparam, variance) =>
1016+
ctx.typeComparer.instanceType(tparam, fromBelow = variance < 0))
1017+
resType.substParams(poly, instanceTypes)
9971018
}
998-
resType <:< target
999-
val tparams = poly.paramRefs
1000-
val variances = caseClass.typeParams.map(_.paramVariance)
1001-
val instanceTypes = (tparams, variances).zipped.map((tparam, variance) =>
1002-
ctx.typeComparer.instanceType(tparam, fromBelow = variance < 0))
1003-
resType.substParams(poly, instanceTypes)
1004-
}
1005-
instantiate(ctx.fresh.setExploreTyperState().setOwner(caseClass))
1006-
case _ =>
1007-
caseClass.typeRef
1019+
instantiate(ctx.fresh.setExploreTyperState().setOwner(caseClass))
1020+
case _ =>
1021+
caseClass.typeRef
1022+
}
10081023
}
1009-
}
1010-
case child => child.termRef
1011-
}
1024+
case child => child.termRef
1025+
}
10121026

1013-
val (monoType, elemsType) = mirroredType match {
1014-
case mirroredType: HKTypeLambda =>
1015-
val elems = mirroredType.derivedLambdaType(
1016-
resType = TypeOps.nestedPairs(cls.children.map(solve))
1017-
)
1018-
val AppliedType(tycon, _) = mirroredType.resultType
1019-
val monoType = AppliedType(tycon, mirroredType.paramInfos)
1020-
(monoType, elems)
1021-
case _ =>
1022-
val elems = TypeOps.nestedPairs(cls.children.map(solve))
1023-
(mirroredType, elems)
1024-
}
1027+
val (monoType, elemsType) = mirroredType match {
1028+
case mirroredType: HKTypeLambda =>
1029+
val elems = mirroredType.derivedLambdaType(
1030+
resType = TypeOps.nestedPairs(cls.children.map(solve))
1031+
)
1032+
val AppliedType(tycon, _) = mirroredType.resultType
1033+
val monoType = AppliedType(tycon, mirroredType.paramInfos)
1034+
(monoType, elems)
1035+
case _ =>
1036+
val elems = TypeOps.nestedPairs(cls.children.map(solve))
1037+
(mirroredType, elems)
1038+
}
10251039

1026-
val mirrorType =
1027-
mirrorCore(defn.Mirror_SumClass, monoType, mirroredType, cls.name, formal)
1028-
.refinedWith(tpnme.MirroredElemTypes, TypeAlias(elemsType))
1029-
.refinedWith(tpnme.MirroredElemLabels, TypeAlias(TypeOps.nestedPairs(elemLabels)))
1030-
val mirrorRef =
1031-
if (cls.linkedClass.exists && !cls.is(Scala2x)) companionPath(mirroredType, span)
1032-
else anonymousMirror(monoType, ExtendsSumMirror, span)
1033-
mirrorRef.cast(mirrorType)
1034-
} else EmptyTree
1035-
case _ => EmptyTree
1036-
}
1040+
val mirrorType =
1041+
mirrorCore(defn.Mirror_SumClass, monoType, mirroredType, cls.name, formal)
1042+
.refinedWith(tpnme.MirroredElemTypes, TypeAlias(elemsType))
1043+
.refinedWith(tpnme.MirroredElemLabels, TypeAlias(TypeOps.nestedPairs(elemLabels)))
1044+
val mirrorRef =
1045+
if (cls.linkedClass.exists && !cls.is(Scala2x)) companionPath(mirroredType, span)
1046+
else anonymousMirror(monoType, ExtendsSumMirror, span)
1047+
mirrorRef.cast(mirrorType)
1048+
} else EmptyTree
1049+
case _ => EmptyTree
1050+
}
10371051
}
10381052

10391053
/** An implied instance for a type of the form `Mirror { type MirroredType = T }`

tests/neg/mirror-implicit-scope.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import scala.deriving._
2+
3+
object Test {
4+
class SomeClass
5+
case class ISB(i: Int, s: String, b: Boolean)
6+
case class BI(b: Boolean, i: Int)
7+
8+
val v0 = the[Mirror.ProductOf[ISB]] // OK
9+
val v1 = the[SomeClass & Mirror.ProductOf[ISB]] // error
10+
val v2 = the[Mirror.ProductOf[ISB] & Mirror.ProductOf[BI]] // error
11+
val v3 = the[Mirror.Product { type MirroredType = ISB ; def foo: Int }] // error
12+
val v4 = the[Mirror.Product { type MirroredType = ISB ; def foo(i: Int): Int }] // error
13+
val v5 = the[Mirror.Product { type MirroredType = ISB ; def foo[T](t: T): T }] // error // error
14+
}

0 commit comments

Comments
 (0)