Skip to content

Commit 3e1f445

Browse files
Add support for default arguments in product mirrors
1 parent 95266f2 commit 3e1f445

File tree

12 files changed

+264
-13
lines changed

12 files changed

+264
-13
lines changed

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -787,6 +787,7 @@ class Definitions {
787787
@tu lazy val MirrorClass: ClassSymbol = requiredClass("scala.deriving.Mirror")
788788
@tu lazy val Mirror_ProductClass: ClassSymbol = requiredClass("scala.deriving.Mirror.Product")
789789
@tu lazy val Mirror_Product_fromProduct: Symbol = Mirror_ProductClass.requiredMethod(nme.fromProduct)
790+
@tu lazy val Mirror_Product_defaultArgument: Symbol = Mirror_ProductClass.requiredMethod(nme.defaultArgument)
790791
@tu lazy val Mirror_SumClass: ClassSymbol = requiredClass("scala.deriving.Mirror.Sum")
791792
@tu lazy val Mirror_SingletonClass: ClassSymbol = requiredClass("scala.deriving.Mirror.Singleton")
792793
@tu lazy val Mirror_SingletonProxyClass: ClassSymbol = requiredClass("scala.deriving.Mirror.SingletonProxy")

compiler/src/dotty/tools/dotc/core/StdNames.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,7 @@ object StdNames {
368368
val LiteralAnnotArg: N = "LiteralAnnotArg"
369369
val Matchable: N = "Matchable"
370370
val MatchCase: N = "MatchCase"
371+
val MirroredElemHasDefaults: N = "MirroredElemHasDefaults"
371372
val MirroredElemTypes: N = "MirroredElemTypes"
372373
val MirroredElemLabels: N = "MirroredElemLabels"
373374
val MirroredLabel: N = "MirroredLabel"
@@ -452,6 +453,7 @@ object StdNames {
452453
val create: N = "create"
453454
val currentMirror: N = "currentMirror"
454455
val curried: N = "curried"
456+
val defaultArgument: N = "defaultArgument"
455457
val definitions: N = "definitions"
456458
val delayedInit: N = "delayedInit"
457459
val delayedInitArg: N = "delayedInit$body"

compiler/src/dotty/tools/dotc/transform/SyntheticMembers.scala

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import Decorators.*
99
import NameOps.*
1010
import Annotations.Annotation
1111
import typer.ProtoTypes.constrained
12-
import ast.untpd
12+
import ast.{tpd, untpd}
1313

1414
import util.Property
1515
import util.Spans.Span
@@ -547,6 +547,30 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
547547
New(classRefApplied, elems)
548548
end fromProductBody
549549

550+
def defaultArgumentBody(caseClass: Symbol, index: Tree, optInfo: Option[MirrorImpl.OfProduct])(using Context): Tree =
551+
val companionTree: Tree =
552+
val companion: Symbol = caseClass.companionModule
553+
val prefix: Type = optInfo.fold(NoPrefix)(_.pre)
554+
ref(TermRef(prefix, companion.asTerm))
555+
556+
def defaultArgumentGetter(idx: Int): Tree =
557+
val getterName = NameKinds.DefaultGetterName(nme.CONSTRUCTOR, idx)
558+
val getterDenot = companionTree.tpe.member(getterName)
559+
companionTree.select(TermRef(companionTree.tpe, getterName, getterDenot))
560+
561+
val withDefaultCases = for
562+
(acc, idx) <- caseClass.caseAccessors.zipWithIndex if acc.is(HasDefault)
563+
body = Typed(defaultArgumentGetter(idx), TypeTree(defn.AnyType)) // so match tree does try to find union of case types
564+
yield CaseDef(Literal(Constant(idx)), EmptyTree, body)
565+
566+
val withoutDefaultCase =
567+
val stringIndex = Apply(Select(index, nme.toString_), Nil)
568+
val nsee = tpd.resolveConstructor(defn.NoSuchElementExceptionType, List(stringIndex))
569+
CaseDef(Underscore(defn.IntType), EmptyTree, Throw(nsee))
570+
571+
Match(index, withDefaultCases :+ withoutDefaultCase)
572+
end defaultArgumentBody
573+
550574
/** For an enum T:
551575
*
552576
* def ordinal(x: MirroredMonoType) = x.ordinal
@@ -616,6 +640,12 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
616640
synthesizeDef(meth, vrefss => body(cls, vrefss.head.head))
617641
}
618642
}
643+
def overrideMethod(name: TermName, info: Type, cls: Symbol, body: (Symbol, Tree) => Context ?=> Tree, isExperimental: Boolean = false): Unit = {
644+
val meth = newSymbol(clazz, name, Synthetic | Method | Override, info, coord = clazz.coord)
645+
if isExperimental then meth.addAnnotation(defn.ExperimentalAnnot)
646+
meth.enteredAfter(thisPhase)
647+
newBody = newBody :+ synthesizeDef(meth, vrefss => body(cls, vrefss.head.head))
648+
}
619649
val linked = clazz.linkedClass
620650
lazy val monoType = {
621651
val existing = clazz.info.member(tpnme.MirroredMonoType).symbol
@@ -633,6 +663,9 @@ class SyntheticMembers(thisPhase: DenotTransformer) {
633663
addParent(defn.Mirror_ProductClass.typeRef)
634664
addMethod(nme.fromProduct, MethodType(defn.ProductClass.typeRef :: Nil, monoType.typeRef), cls,
635665
fromProductBody(_, _, optInfo).ensureConforms(monoType.typeRef)) // t4758.scala or i3381.scala are examples where a cast is needed
666+
if cls.primaryConstructor.hasDefaultParams then
667+
overrideMethod(nme.defaultArgument, MethodType(defn.IntType :: Nil, defn.AnyType), cls,
668+
defaultArgumentBody(_, _, optInfo), isExperimental = true)
636669
}
637670
def makeSumMirror(cls: Symbol, optInfo: Option[MirrorImpl.OfSum]) = {
638671
addParent(defn.Mirror_SumClass.typeRef)

compiler/src/dotty/tools/dotc/typer/Synthesizer.scala

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -409,25 +409,30 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context):
409409

410410
def makeProductMirror(pre: Type, cls: Symbol, tps: Option[List[Type]]): TreeWithErrors =
411411
val accessors = cls.caseAccessors
412-
val elemLabels = accessors.map(acc => ConstantType(Constant(acc.name.toString)))
413-
val typeElems = tps.getOrElse(accessors.map(mirroredType.resultType.memberInfo(_).widenExpr))
414-
val nestedPairs = TypeOps.nestedPairs(typeElems)
415-
val (monoType, elemsType) = mirroredType match
412+
val Seq(elemLabels, elemHasDefaults, elemTypes1) =
413+
Seq(
414+
accessors.map(acc => ConstantType(Constant(acc.name.toString))),
415+
accessors.map(acc => ConstantType(Constant(acc.is(HasDefault)))),
416+
tps.getOrElse(accessors.map(mirroredType.resultType.memberInfo(_).widenExpr))
417+
).map(TypeOps.nestedPairs)
418+
val (monoType, elemTypes) = mirroredType match
416419
case mirroredType: HKTypeLambda =>
417-
(mkMirroredMonoType(mirroredType), mirroredType.derivedLambdaType(resType = nestedPairs))
420+
(mkMirroredMonoType(mirroredType), mirroredType.derivedLambdaType(resType = elemTypes1))
418421
case _ =>
419-
(mirroredType, nestedPairs)
420-
val elemsLabels = TypeOps.nestedPairs(elemLabels)
421-
checkRefinement(formal, tpnme.MirroredElemTypes, elemsType, span)
422-
checkRefinement(formal, tpnme.MirroredElemLabels, elemsLabels, span)
422+
(mirroredType, elemTypes1)
423+
424+
checkRefinement(formal, tpnme.MirroredElemTypes, elemTypes, span)
425+
checkRefinement(formal, tpnme.MirroredElemLabels, elemLabels, span)
426+
checkRefinement(formal, tpnme.MirroredElemHasDefaults, elemHasDefaults, span)
423427
val mirrorType = formal.constrained_& {
424428
mirrorCore(defn.Mirror_ProductClass, monoType, mirroredType, cls.name)
425-
.refinedWith(tpnme.MirroredElemTypes, TypeAlias(elemsType))
426-
.refinedWith(tpnme.MirroredElemLabels, TypeAlias(elemsLabels))
429+
.refinedWith(tpnme.MirroredElemTypes, TypeAlias(elemTypes))
430+
.refinedWith(tpnme.MirroredElemLabels, TypeAlias(elemLabels))
431+
.refinedWith(tpnme.MirroredElemHasDefaults, TypeAlias(elemHasDefaults))
427432
}
428433
val mirrorRef =
429434
if cls.useCompanionAsProductMirror then companionPath(mirroredType, span)
430-
else if defn.isTupleClass(cls) then newTupleMirror(typeElems.size) // TODO: cls == defn.PairClass when > 22
435+
else if defn.isTupleClass(cls) then newTupleMirror(accessors.size) // TODO: cls == defn.PairClass when > 22
431436
else anonymousMirror(monoType, MirrorImpl.OfProduct(pre), span)
432437
withNoErrors(mirrorRef.cast(mirrorType).withSpan(span))
433438
end makeProductMirror

library/src/scala/deriving/Mirror.scala

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
package scala.deriving
22

3+
import java.util.NoSuchElementException
4+
import scala.annotation.experimental
5+
36
/** Mirrors allows typelevel access to enums, case classes and objects, and their sealed parents.
47
*/
58
sealed trait Mirror {
@@ -27,6 +30,14 @@ object Mirror {
2730

2831
/** Create a new instance of type `T` with elements taken from product `p`. */
2932
def fromProduct(p: scala.Product): MirroredMonoType
33+
34+
/** Whether each product element has a default value */
35+
@experimental type MirroredElemHasDefaults <: Tuple
36+
37+
/** The default argument of the product argument at given `index` */
38+
@experimental def defaultArgument(index: Int): Any =
39+
throw NoSuchElementException(String.valueOf(index))
40+
3041
}
3142

3243
trait Singleton extends Product {

project/MiMaFilters.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ object MiMaFilters {
2828
val LibraryForward: Map[String, Seq[ProblemFilter]] = Map(
2929
// Additions that require a new minor version of the library
3030
Build.previousDottyVersion -> Seq(
31+
ProblemFilters.exclude[DirectMissingMethodProblem]("scala.compiletime.testing.Error.defaultArgument"),
3132
),
3233

3334
// Additions since last LTS
@@ -62,6 +63,7 @@ object MiMaFilters {
6263
),
6364
)
6465
val TastyCore: Seq[ProblemFilter] = Seq(
66+
ProblemFilters.exclude[DirectMissingMethodProblem]("dotty.tools.tasty.TastyVersion.defaultArgument"),
6567
)
6668
val Interfaces: Seq[ProblemFilter] = Seq(
6769
)

tests/run-macros/i7987.check

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,5 @@ scala.deriving.Mirror.Product {
44
type MirroredLabel >: "Some" <: "Some"
55
type MirroredElemTypes >: scala.*:[scala.Int, scala.Tuple$package.EmptyTuple] <: scala.*:[scala.Int, scala.Tuple$package.EmptyTuple]
66
type MirroredElemLabels >: scala.*:["value", scala.Tuple$package.EmptyTuple] <: scala.*:["value", scala.Tuple$package.EmptyTuple]
7+
type MirroredElemHasDefaults >: scala.*:[false, scala.Tuple$package.EmptyTuple] <: scala.*:[false, scala.Tuple$package.EmptyTuple]
78
}
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
import scala.deriving._
2+
import scala.annotation.experimental
3+
import scala.quoted._
4+
5+
object MirrorOps:
6+
7+
inline def overridesDefaultArgument[T]: Boolean = ${ overridesDefaultArgumentImpl[T] }
8+
9+
def overridesDefaultArgumentImpl[T](using Quotes, Type[T]): Expr[Boolean] =
10+
import quotes.reflect.*
11+
val cls = TypeRepr.of[T].classSymbol.get
12+
val companion = cls.companionModule.moduleClass
13+
val methods = companion.declaredMethods
14+
15+
val experAnnotType = Symbol.requiredClass("scala.annotation.experimental").typeRef
16+
17+
Expr {
18+
methods.exists { m =>
19+
m.name == "defaultArgument" &&
20+
m.flags.is(Flags.Synthetic) &&
21+
m.annotations.exists(_.tpe <:< experAnnotType)
22+
}
23+
}
24+
25+
end MirrorOps
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
import scala.deriving._
2+
import scala.annotation.experimental
3+
import scala.quoted._
4+
5+
import MirrorOps.*
6+
7+
object Test extends App:
8+
9+
case class WithDefault(x: Int, y: Int = 1)
10+
assert(overridesDefaultArgument[WithDefault])
11+
12+
case class WithoutDefault(x: Int)
13+
assert(!overridesDefaultArgument[WithoutDefault])

tests/run-tasty-inspector/stdlibExperimentalDefinitions.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,10 @@ val experimentalDefinitionInLibrary = Set(
9696
"scala.Tuple$.Reverse", // can be stabilized in 3.5
9797
"scala.Tuple$.ReverseOnto", // can be stabilized in 3.5
9898
"scala.runtime.Tuples$.reverse", // can be stabilized in 3.5
99+
100+
// New APIs: Mirror support for default arguments
101+
"scala.deriving.Mirror$.Product.MirroredElemHasDefaults",
102+
"scala.deriving.Mirror$.Product.defaultArgument",
99103
)
100104

101105

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import scala.deriving._
2+
import scala.annotation.experimental
3+
4+
object Test extends App:
5+
6+
case class WithDefault(x: Int, y: Int = 1)
7+
val m = summon[Mirror.Of[WithDefault]]
8+
assert(m.defaultArgument(1) == 1)
9+
try
10+
m.defaultArgument(0)
11+
throw IllegalStateException("There should be no default argument")
12+
catch
13+
case ex: NoSuchElementException => assert(ex.getMessage == "0") // Ok
14+
15+
16+
case class WithCompanion(s: String = "hello")
17+
case object WithCompanion // => mirrors must be anonymous
18+
19+
val m2 = summon[Mirror.Of[WithCompanion]]
20+
assert(m2 ne WithCompanion)
21+
assert(m2.defaultArgument(0) == "hello")
22+
23+
24+
class Outer(val i: Int) {
25+
26+
case class Inner(x: Int, y: Int = i + 1)
27+
case object Inner
28+
29+
val m3 = summon[Mirror.Of[Inner]]
30+
assert(m3.defaultArgument(1) == i + 1)
31+
32+
def localTest(d: Double): Unit = {
33+
case class Local(x: Int = i, y: Double = d, z: Double = i + d)
34+
case object Local
35+
36+
val m4 = summon[Mirror.Of[Local]]
37+
assert(m4.defaultArgument(0) == i)
38+
assert(m4.defaultArgument(1) == d)
39+
assert(m4.defaultArgument(2) == i + d)
40+
}
41+
42+
}
43+
44+
val outer = Outer(3)
45+
val m5 = summon[Mirror.Of[outer.Inner]]
46+
assert(m5.defaultArgument(1) == 3 + 1)
47+
outer.localTest(9d)
48+
49+
50+
// new defaultArgument match tree should be able to unify different default value types
51+
case class Foo[T](x: Int = 0, y: String = "hi")
52+
53+
end Test
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
import scala.deriving.Mirror as M
2+
import scala.deriving.*
3+
import scala.Tuple.*
4+
import scala.compiletime.*
5+
import scala.compiletime.ops.int.S
6+
7+
trait Migration[-From, +To]:
8+
def apply(x: From): To
9+
10+
object Migration:
11+
12+
extension [From](x: From)
13+
def migrateTo[To](using m: Migration[From, To]): To = m(x)
14+
15+
given[T]: Migration[T, T] with
16+
override def apply(x: T): T = x
17+
18+
type IndexOf[Elems <: Tuple, X] <: Int = Elems match {
19+
case (X *: elems) => 0
20+
case (_ *: elems) => S[IndexOf[elems, X]]
21+
case EmptyTuple => Nothing
22+
}
23+
24+
inline def migrateElem[F,T, ToIdx <: Int](from: M.ProductOf[F], to: M.ProductOf[T])(x: Product): Any =
25+
26+
type Label = Elem[to.MirroredElemLabels, ToIdx]
27+
type FromIdx = IndexOf[from.MirroredElemLabels, Label]
28+
inline constValueOpt[FromIdx] match
29+
30+
case Some(fromIdx) =>
31+
type FromType = Elem[from.MirroredElemTypes, FromIdx]
32+
type ToType = Elem[to.MirroredElemTypes, ToIdx]
33+
summonFrom { case _: Migration[FromType, ToType] =>
34+
x.productElement(fromIdx).asInstanceOf[FromType].migrateTo[ToType]
35+
}
36+
37+
case None =>
38+
type HasDefault = Elem[to.MirroredElemHasDefaults, ToIdx]
39+
inline erasedValue[HasDefault] match
40+
case _: true => to.defaultArgument(constValue[ToIdx])
41+
case _: false => compiletime.error("An element has no equivalent or default")
42+
43+
44+
inline def migrateElems[F,T, ToIdx <: Int](from: M.ProductOf[F], to: M.ProductOf[T])(x: Product): Seq[Any] =
45+
inline erasedValue[ToIdx] match
46+
case _: Tuple.Size[to.MirroredElemLabels] => Seq()
47+
case _ => migrateElem[F,T,ToIdx](from, to)(x) +: migrateElems[F,T,S[ToIdx]](from, to)(x)
48+
49+
inline def migrateProduct[F,T](from: M.ProductOf[F], to: M.ProductOf[T])
50+
(x: Product): T =
51+
val elems = migrateElems[F, T, 0](from, to)(x)
52+
to.fromProduct(new Product:
53+
def canEqual(that: Any): Boolean = false
54+
def productArity: Int = elems.length
55+
def productElement(n: Int): Any = elems(n)
56+
)
57+
58+
inline def migration[F,T](using from: M.Of[F], to: M.Of[T]): Migration[F,T] = (x: F) =>
59+
inline from match
60+
case fromP: M.ProductOf[F] => inline to match
61+
case toP: M.ProductOf[T] => migrateProduct[F, T](fromP, toP)(x.asInstanceOf[Product])
62+
case _: M.SumOf[T] => compiletime.error("Cannot migrate sums")
63+
case _: M.SumOf[F] => compiletime.error("Cannot migrate sums")
64+
65+
end Migration
66+
67+
68+
import Migration.*
69+
object Test extends App:
70+
71+
case class A1(x: Int)
72+
case class A2(x: Int)
73+
given Migration[A1, A2] = migration
74+
assert(A1(2).migrateTo[A2] == A2(2))
75+
76+
case class B1(x: Int, y: String)
77+
case class B2(y: String, x: Int)
78+
given Migration[B1, B2] = migration
79+
assert(B1(5, "hi").migrateTo[B2] == B2("hi", 5))
80+
81+
case class C1(x: A1)
82+
case class C2(x: A2)
83+
given Migration[C1, C2] = migration
84+
assert(C1(A1(0)).migrateTo[C2] == C2(A2(0)))
85+
86+
case class D1(x: Double)
87+
case class D2(b: Boolean = true, x: Double)
88+
given Migration[D1, D2] = migration
89+
assert(D1(9).migrateTo[D2] == D2(true, 9))
90+
91+
case class E1(x: D1, y: D1)
92+
case class E2(y: D2, s: String = "hi", x: D2)
93+
given Migration[E1, E2] = migration
94+
assert(E1(D1(1), D1(2)).migrateTo[E2] == E2(D2(true, 2), "hi", D2(true, 1)))
95+
96+
// should only use default when needed
97+
case class F1(x: Int)
98+
case class F2(x: Int = 3)
99+
given Migration[F1, F2] = migration
100+
assert(F1(7).migrateTo[F2] == F2(7))
101+

0 commit comments

Comments
 (0)