Skip to content

Commit 91df766

Browse files
committed
Generalized type class derivation for higher kinded type classes
* Mirrors of data types of kinds other than * are now generated. The Mirror type member MirroredTypeConstructor has been renamed to MirroredType: this is the type which is used to select the "naturally" kinded Mirror (ie. the Mirror with the kind that matches the kind of the data type). * Data types can now have derives clauses for type class which are indexed by types of the same kind, for all kinds. Data types continue to support derives clauses for type classes indexed by types of lower kinds than the data type via polymorphic derived members.
1 parent 930ca64 commit 91df766

File tree

6 files changed

+189
-139
lines changed

6 files changed

+189
-139
lines changed

compiler/src/dotty/tools/dotc/core/StdNames.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -343,7 +343,7 @@ object StdNames {
343343
val MirroredElemLabels: N = "MirroredElemLabels"
344344
val MirroredLabel: N = "MirroredLabel"
345345
val MirroredMonoType: N = "MirroredMonoType"
346-
val MirroredTypeConstructor: N = "MirroredTypeConstructor"
346+
val MirroredType: N = "MirroredType"
347347
val Modifiers: N = "Modifiers"
348348
val NestedAnnotArg: N = "NestedAnnotArg"
349349
val NoFlags: N = "NoFlags"

compiler/src/dotty/tools/dotc/typer/Deriving.scala

Lines changed: 50 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -89,47 +89,57 @@ trait Deriving { this: Typer =>
8989
val typeClass = derivedType.classSymbol
9090
val nparams = typeClass.typeParams.length
9191

92-
// A matrix of all parameter combinations of current class parameters
93-
// and derived typeclass parameters.
94-
// Rows: parameters of current class
95-
// Columns: parameters of typeclass
96-
97-
// Running example: typeclass: class TC[X, Y, Z], deriving class: class A[T, U]
98-
// clsParamss =
99-
// T_X T_Y T_Z
100-
// U_X U_Y U_Z
101-
val clsParamss: List[List[TypeSymbol]] = cls.typeParams.map { tparam =>
102-
if (nparams == 0) Nil
103-
else if (nparams == 1) tparam :: Nil
104-
else typeClass.typeParams.map(tcparam =>
105-
tparam.copy(name = s"${tparam.name}_$$_${tcparam.name}".toTypeName)
106-
.asInstanceOf[TypeSymbol])
107-
}
108-
val firstKindedParamss = clsParamss.filter {
109-
case param :: _ => !param.info.isLambdaSub
110-
case nil => false
111-
}
92+
lazy val clsTpe = cls.typeRef.EtaExpand(cls.typeParams)
93+
if (nparams == 1 && clsTpe.hasSameKindAs(typeClass.typeParams.head.info)) {
94+
// A "natural" type class instance ... the kind of the data type
95+
// matches the kind of the unique type class type parameter
96+
97+
val resultType = derivedType.appliedTo(clsTpe)
98+
val instanceInfo = ExprType(resultType)
99+
addDerivedInstance(originalType.typeSymbol.name, instanceInfo, derived.sourcePos)
100+
} else {
101+
// A matrix of all parameter combinations of current class parameters
102+
// and derived typeclass parameters.
103+
// Rows: parameters of current class
104+
// Columns: parameters of typeclass
105+
106+
// Running example: typeclass: class TC[X, Y, Z], deriving class: class A[T, U]
107+
// clsParamss =
108+
// T_X T_Y T_Z
109+
// U_X U_Y U_Z
110+
val clsParamss: List[List[TypeSymbol]] = cls.typeParams.map { tparam =>
111+
if (nparams == 0) Nil
112+
else if (nparams == 1) tparam :: Nil
113+
else typeClass.typeParams.map(tcparam =>
114+
tparam.copy(name = s"${tparam.name}_$$_${tcparam.name}".toTypeName)
115+
.asInstanceOf[TypeSymbol])
116+
}
117+
val firstKindedParamss = clsParamss.filter {
118+
case param :: _ => !param.info.isLambdaSub
119+
case nil => false
120+
}
112121

113-
// The types of the required evidence parameters. In the running example:
114-
// TC[T_X, T_Y, T_Z], TC[U_X, U_Y, U_Z]
115-
val evidenceParamInfos =
116-
for (row <- firstKindedParamss)
117-
yield derivedType.appliedTo(row.map(_.typeRef))
118-
119-
// The class instances in the result type. Running example:
120-
// A[T_X, U_X], A[T_Y, U_Y], A[T_Z, U_Z]
121-
val resultInstances =
122-
for (n <- List.range(0, nparams))
123-
yield cls.typeRef.appliedTo(clsParamss.map(row => row(n).typeRef))
124-
125-
// TC[A[T_X, U_X], A[T_Y, U_Y], A[T_Z, U_Z]]
126-
val resultType = derivedType.appliedTo(resultInstances)
127-
128-
val clsParams: List[TypeSymbol] = clsParamss.flatten
129-
val instanceInfo =
130-
if (clsParams.isEmpty) ExprType(resultType)
131-
else PolyType.fromParams(clsParams, ImplicitMethodType(evidenceParamInfos, resultType))
132-
addDerivedInstance(originalType.typeSymbol.name, instanceInfo, derived.sourcePos)
122+
// The types of the required evidence parameters. In the running example:
123+
// TC[T_X, T_Y, T_Z], TC[U_X, U_Y, U_Z]
124+
val evidenceParamInfos =
125+
for (row <- firstKindedParamss)
126+
yield derivedType.appliedTo(row.map(_.typeRef))
127+
128+
// The class instances in the result type. Running example:
129+
// A[T_X, U_X], A[T_Y, U_Y], A[T_Z, U_Z]
130+
val resultInstances =
131+
for (n <- List.range(0, nparams))
132+
yield cls.typeRef.appliedTo(clsParamss.map(row => row(n).typeRef))
133+
134+
// TC[A[T_X, U_X], A[T_Y, U_Y], A[T_Z, U_Z]]
135+
val resultType = derivedType.appliedTo(resultInstances)
136+
137+
val clsParams: List[TypeSymbol] = clsParamss.flatten
138+
val instanceInfo =
139+
if (clsParams.isEmpty) ExprType(resultType)
140+
else PolyType.fromParams(clsParams, ImplicitMethodType(evidenceParamInfos, resultType))
141+
addDerivedInstance(originalType.typeSymbol.name, instanceInfo, derived.sourcePos)
142+
}
133143
}
134144

135145
/** Create symbols for derived instances and infrastructure,

compiler/src/dotty/tools/dotc/typer/Implicits.scala

Lines changed: 126 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -867,16 +867,10 @@ trait Implicits { self: Typer =>
867867
* MirroredTypeConstrictor = <tycon>
868868
* MirroredLabel = <label> }
869869
*/
870-
private def mirrorCore(parent: Type, monoType: Type, label: Name)(implicit ctx: Context) = {
871-
val mirroredType = monoType match {
872-
case monoType @ AppliedType(tycon, targs) if targs.forall(_.isInstanceOf[TypeBounds]) =>
873-
EtaExpansion(tycon)
874-
case _ =>
875-
monoType
876-
}
870+
private def mirrorCore(parent: Type, monoType: Type, mirroredType: Type, label: Name)(implicit ctx: Context) = {
877871
parent
878872
.refinedWith(tpnme.MirroredMonoType, TypeAlias(monoType))
879-
.refinedWith(tpnme.MirroredTypeConstructor, TypeAlias(mirroredType))
873+
.refinedWith(tpnme.MirroredType, TypeAlias(mirroredType))
880874
.refinedWith(tpnme.MirroredLabel, TypeAlias(ConstantType(Constant(label.toString))))
881875
}
882876

@@ -892,106 +886,158 @@ trait Implicits { self: Typer =>
892886
*/
893887
lazy val synthesizedProductMirror: SpecialHandler =
894888
(formal: Type, span: Span) => implicit (ctx: Context) => {
895-
def mirrorFor(monoType: Type): Tree = monoType match {
896-
case AndType(tp1, tp2) =>
897-
mirrorFor(tp1).orElse(mirrorFor(tp2))
898-
case _ =>
899-
if (monoType.termSymbol.is(CaseVal)) {
900-
val module = monoType.termSymbol
901-
val modulePath = pathFor(monoType).withSpan(span)
902-
if (module.info.classSymbol.is(Scala2x)) {
903-
val mirrorType = mirrorCore(defn.Mirror_SingletonProxyType, monoType, module.name)
904-
val mirrorRef = New(defn.Mirror_SingletonProxyType, modulePath :: Nil)
905-
mirrorRef.cast(mirrorType)
889+
def mirrorFor(mirroredType0: Type): Tree = {
890+
val mirroredType = mirroredType0.stripTypeVar
891+
mirroredType match {
892+
case AndType(tp1, tp2) =>
893+
mirrorFor(tp1).orElse(mirrorFor(tp2))
894+
case _ =>
895+
if (mirroredType.termSymbol.is(CaseVal)) {
896+
val module = mirroredType.termSymbol
897+
val modulePath = pathFor(mirroredType).withSpan(span)
898+
if (module.info.classSymbol.is(Scala2x)) {
899+
val mirrorType = mirrorCore(defn.Mirror_SingletonProxyType, mirroredType, mirroredType, module.name)
900+
val mirrorRef = New(defn.Mirror_SingletonProxyType, modulePath :: Nil)
901+
mirrorRef.cast(mirrorType)
902+
}
903+
else {
904+
val mirrorType = mirrorCore(defn.Mirror_SingletonType, mirroredType, mirroredType, module.name)
905+
modulePath.cast(mirrorType)
906+
}
906907
}
907-
else {
908-
val mirrorType = mirrorCore(defn.Mirror_SingletonType, monoType, module.name)
909-
modulePath.cast(mirrorType)
908+
else if (mirroredType.classSymbol.isGenericProduct) {
909+
val cls = mirroredType.classSymbol
910+
val accessors = cls.caseAccessors.filterNot(_.is(PrivateLocal))
911+
val elemLabels = accessors.map(acc => ConstantType(Constant(acc.name.toString)))
912+
val (monoType, elemTypes) = mirroredType match {
913+
case mirroredType: HKTypeLambda =>
914+
val elems =
915+
mirroredType.derivedLambdaType(
916+
resType = TypeOps.nestedPairs(accessors.map(mirroredType.memberInfo(_).widenExpr))
917+
)
918+
val AppliedType(tycon, _) = mirroredType.resultType
919+
val monoType = AppliedType(tycon, mirroredType.paramInfos)
920+
(monoType, elems)
921+
case _ =>
922+
val elems = TypeOps.nestedPairs(accessors.map(mirroredType.memberInfo(_).widenExpr))
923+
(mirroredType, elems)
924+
}
925+
val mirrorType =
926+
mirrorCore(defn.Mirror_ProductType, monoType, mirroredType, cls.name)
927+
.refinedWith(tpnme.MirroredElemTypes, TypeAlias(elemTypes))
928+
.refinedWith(tpnme.MirroredElemLabels, TypeAlias(TypeOps.nestedPairs(elemLabels)))
929+
val mirrorRef =
930+
if (cls.is(Scala2x)) anonymousMirror(monoType, ExtendsProductMirror, span)
931+
else companionPath(mirroredType, span)
932+
mirrorRef.cast(mirrorType)
910933
}
911-
}
912-
else if (monoType.classSymbol.isGenericProduct) {
913-
val cls = monoType.classSymbol
914-
val accessors = cls.caseAccessors.filterNot(_.is(PrivateLocal))
915-
val elemTypes = accessors.map(monoType.memberInfo(_).widenExpr)
916-
val elemLabels = accessors.map(acc => ConstantType(Constant(acc.name.toString)))
917-
val mirrorType =
918-
mirrorCore(defn.Mirror_ProductType, monoType, cls.name)
919-
.refinedWith(tpnme.MirroredElemTypes, TypeAlias(TypeOps.nestedPairs(elemTypes)))
920-
.refinedWith(tpnme.MirroredElemLabels, TypeAlias(TypeOps.nestedPairs(elemLabels)))
921-
val mirrorRef =
922-
if (cls.is(Scala2x)) anonymousMirror(monoType, ExtendsProductMirror, span)
923-
else companionPath(monoType, span)
924-
mirrorRef.cast(mirrorType)
925-
}
926-
else EmptyTree
934+
else EmptyTree
935+
}
927936
}
928-
formal.member(tpnme.MirroredMonoType).info match {
929-
case monoAlias @ TypeAlias(monoType) => mirrorFor(monoType)
930-
case _ => EmptyTree
937+
938+
formal.member(tpnme.MirroredType).info match {
939+
case TypeAlias(mirroredType) => mirrorFor(mirroredType)
940+
case TypeBounds(mirroredType, _) => mirrorFor(mirroredType)
941+
case other => EmptyTree
931942
}
932943
}
933944

934945
/** An implied instance for a type of the form `Mirror.Sum { type MirroredMonoType = T }`
935946
* where `T` is a generic sum type.
936947
*/
937948
lazy val synthesizedSumMirror: SpecialHandler =
938-
(formal: Type, span: Span) => implicit (ctx: Context) =>
939-
formal.member(tpnme.MirroredMonoType).info match {
940-
case TypeAlias(monoType) if monoType.classSymbol.isGenericSum =>
941-
val cls = monoType.classSymbol
942-
val elemTypes = cls.children.map {
949+
(formal: Type, span: Span) => implicit (ctx: Context) => {
950+
def mirrorFor(mirroredType0: Type): Tree = {
951+
val mirroredType = mirroredType0.stripTypeVar
952+
if (mirroredType.classSymbol.isGenericSum) {
953+
val cls = mirroredType.classSymbol
954+
val elemLabels = cls.children.map(c => ConstantType(Constant(c.name.toString)))
955+
956+
def solve(sym: Symbol): Type = sym match {
943957
case caseClass: ClassSymbol =>
944958
assert(caseClass.is(Case))
945959
if (caseClass.is(Module))
946960
caseClass.sourceModule.termRef
947-
else caseClass.primaryConstructor.info match {
948-
case info: PolyType =>
949-
// Compute the the full child type by solving the subtype constraint
950-
// `C[X1, ..., Xn] <: P`, where
951-
//
952-
// - P is the current `monoType`
953-
// - C is the child class, with type parameters X1, ..., Xn
954-
//
955-
// Contravariant type parameters are minimized, all other type parameters are maximized.
956-
def instantiate(implicit ctx: Context) = {
957-
val poly = constrained(info, untpd.EmptyTree)._1
958-
val resType = poly.finalResultType
959-
resType <:< monoType
960-
val tparams = poly.paramRefs
961-
val variances = caseClass.typeParams.map(_.paramVariance)
962-
val instanceTypes = (tparams, variances).zipped.map((tparam, variance) =>
963-
ctx.typeComparer.instanceType(tparam, fromBelow = variance < 0))
964-
resType.substParams(poly, instanceTypes)
965-
}
966-
instantiate(ctx.fresh.setExploreTyperState().setOwner(caseClass))
967-
case _ =>
968-
caseClass.typeRef
961+
else {
962+
caseClass.primaryConstructor.info match {
963+
case info: PolyType =>
964+
// Compute the the full child type by solving the subtype constraint
965+
// `C[X1, ..., Xn] <: P`, where
966+
//
967+
// - P is the current `mirroredType`
968+
// - C is the child class, with type parameters X1, ..., Xn
969+
//
970+
// Contravariant type parameters are minimized, all other type parameters are maximized.
971+
def instantiate(implicit ctx: Context) = {
972+
val poly = constrained(info, untpd.EmptyTree)._1
973+
val resType = poly.finalResultType
974+
val target = mirroredType match {
975+
case tp: HKTypeLambda => tp.resultType
976+
case tp => tp
977+
}
978+
resType <:< target
979+
val tparams = poly.paramRefs
980+
val variances = caseClass.typeParams.map(_.paramVariance)
981+
val instanceTypes = (tparams, variances).zipped.map((tparam, variance) =>
982+
ctx.typeComparer.instanceType(tparam, fromBelow = variance < 0))
983+
resType.substParams(poly, instanceTypes)
984+
}
985+
instantiate(ctx.fresh.setExploreTyperState().setOwner(caseClass))
986+
case _ =>
987+
caseClass.typeRef
988+
}
969989
}
970990
case child => child.termRef
971991
}
992+
993+
val (monoType, elemTypes) = mirroredType match {
994+
case mirroredType: HKTypeLambda =>
995+
val elems = mirroredType.derivedLambdaType(
996+
resType = TypeOps.nestedPairs(cls.children.map(solve))
997+
)
998+
val AppliedType(tycon, _) = mirroredType.resultType
999+
val monoType = AppliedType(tycon, mirroredType.paramInfos)
1000+
(monoType, elems)
1001+
case _ =>
1002+
val elems = TypeOps.nestedPairs(cls.children.map(solve))
1003+
(mirroredType, elems)
1004+
}
1005+
9721006
val mirrorType =
973-
mirrorCore(defn.Mirror_SumType, monoType, cls.name)
974-
.refinedWith(tpnme.MirroredElemTypes, TypeAlias(TypeOps.nestedPairs(elemTypes)))
1007+
mirrorCore(defn.Mirror_SumType, monoType, mirroredType, cls.name)
1008+
.refinedWith(tpnme.MirroredElemTypes, TypeAlias(elemTypes))
1009+
.refinedWith(tpnme.MirroredElemLabels, TypeAlias(TypeOps.nestedPairs(elemLabels)))
9751010
val mirrorRef =
976-
if (cls.linkedClass.exists && !cls.is(Scala2x)) companionPath(monoType, span)
1011+
if (cls.linkedClass.exists && !cls.is(Scala2x)) companionPath(mirroredType, span)
9771012
else anonymousMirror(monoType, ExtendsSumMirror, span)
9781013
mirrorRef.cast(mirrorType)
979-
case _ =>
980-
EmptyTree
1014+
} else EmptyTree
9811015
}
9821016

1017+
formal.member(tpnme.MirroredType).info match {
1018+
case TypeAlias(mirroredType) => mirrorFor(mirroredType)
1019+
case TypeBounds(mirroredType, _) => mirrorFor(mirroredType)
1020+
case _ => EmptyTree
1021+
}
1022+
}
1023+
9831024
/** An implied instance for a type of the form `Mirror { type MirroredMonoType = T }`
9841025
* where `T` is a generic sum or product or singleton type.
9851026
*/
9861027
lazy val synthesizedMirror: SpecialHandler =
987-
(formal: Type, span: Span) => implicit (ctx: Context) =>
988-
formal.member(tpnme.MirroredMonoType).info match {
989-
case monoAlias @ TypeAlias(monoType) =>
990-
if (monoType.termSymbol.is(CaseVal) || monoType.classSymbol.isGenericProduct)
991-
synthesizedProductMirror(formal, span)(ctx)
992-
else
993-
synthesizedSumMirror(formal, span)(ctx)
1028+
(formal: Type, span: Span) => implicit (ctx: Context) => {
1029+
def mirrorFor(mirroredType: Type): Tree =
1030+
if (mirroredType.termSymbol.is(CaseVal) || mirroredType.classSymbol.isGenericProduct)
1031+
synthesizedProductMirror(formal, span)(ctx)
1032+
else
1033+
synthesizedSumMirror(formal, span)(ctx)
1034+
1035+
formal.member(tpnme.MirroredType).info match {
1036+
case TypeAlias(mirroredType) => mirrorFor(mirroredType)
1037+
case TypeBounds(mirroredType, _) => mirrorFor(mirroredType)
1038+
case _ => EmptyTree
9941039
}
1040+
}
9951041

9961042
private var mySpecialHandlers: SpecialHandlers = null
9971043

0 commit comments

Comments
 (0)