Skip to content

Fix #3248: support product-seq pattern #5989

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Mar 11, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 13 additions & 7 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -432,6 +432,12 @@ object desugar {
appliedTypeTree(tycon, targs)
}

def isRepeated(tree: Tree): Boolean = tree match {
case PostfixOp(_, Ident(tpnme.raw.STAR)) => true
case ByNameTypeTree(tree1) => isRepeated(tree1)
case _ => false
}

// a reference to the class type bound by `cdef`, with type parameters coming from the constructor
val classTypeRef = appliedRef(classTycon)

Expand Down Expand Up @@ -482,11 +488,6 @@ object desugar {
}
def enumTagMeths = if (isEnumCase) enumTagMeth(CaseKind.Class)._1 :: Nil else Nil
def copyMeths = {
def isRepeated(tree: Tree): Boolean = tree match {
case PostfixOp(_, Ident(tpnme.raw.STAR)) => true
case ByNameTypeTree(tree1) => isRepeated(tree1)
case _ => false
}
val hasRepeatedParam = constrVparamss.exists(_.exists {
case ValDef(_, tpt, _) => isRepeated(tpt)
})
Expand Down Expand Up @@ -560,7 +561,8 @@ object desugar {
// companion definitions include:
// 1. If class is a case class case class C[Ts](p1: T1, ..., pN: TN)(moreParams):
// def apply[Ts](p1: T1, ..., pN: TN)(moreParams) = new C[Ts](p1, ..., pN)(moreParams) (unless C is abstract)
// def unapply[Ts]($1: C[Ts]) = $1
// def unapply[Ts]($1: C[Ts]) = $1 // if not repeated
// def unapplySeq[Ts]($1: C[Ts]) = $1 // if repeated
// 2. The default getters of the constructor
// The parent of the companion object of a non-parameterized case class
// (T11, ..., T1N) => ... => (TM1, ..., TMN) => C
Expand Down Expand Up @@ -609,9 +611,13 @@ object desugar {
app :: widenDefs
}
val unapplyMeth = {
val hasRepeatedParam = constrVparamss.head.exists {
case ValDef(_, tpt, _) => isRepeated(tpt)
}
val methName = if (hasRepeatedParam) nme.unapplySeq else nme.unapply
val unapplyParam = makeSyntheticParameter(tpt = classTypeRef)
val unapplyRHS = if (arity == 0) Literal(Constant(true)) else Ident(unapplyParam.name)
DefDef(nme.unapply, derivedTparams, (unapplyParam :: Nil) :: Nil, TypeTree(), unapplyRHS)
DefDef(methName, derivedTparams, (unapplyParam :: Nil) :: Nil, TypeTree(), unapplyRHS)
.withMods(synthetic)
}
companionDefs(companionParent, applyMeths ::: unapplyMeth :: companionMembers)
Expand Down
27 changes: 25 additions & 2 deletions compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ import collection.mutable
import Symbols._, Contexts._, Types._, StdNames._, NameOps._
import ast.Trees._
import util.Spans._
import typer.Applications.{isProductMatch, isGetMatch, productSelectors}
import typer.Applications.{isProductMatch, isGetMatch, isProductSeqMatch, productSelectors, productArity}
import SymUtils._
import Flags._, Constants._
import Decorators._
Expand Down Expand Up @@ -286,6 +286,21 @@ object PatternMatcher {
matchElemsPlan(getResult, args, exact = true, onSuccess)
}

/** Plan for matching the sequence in `getResult`
*
* `getResult` is a product, where the last element is a sequence of elements.
*/
def unapplyProductSeqPlan(getResult: Symbol, args: List[Tree], arity: Int): Plan = {
assert(arity <= args.size + 1)
val selectors = productSelectors(getResult.info).map(ref(getResult).select(_))

val matchSeq =
letAbstract(selectors.last) { seqResult =>
unapplySeqPlan(seqResult, args.drop(arity - 1))
}
matchArgsPlan(selectors.take(arity - 1), args.take(arity - 1), matchSeq)
}

/** Plan for matching the result of an unapply against argument patterns `args` */
def unapplyPlan(unapp: Tree, args: List[Tree]): Plan = {
def caseClass = unapp.symbol.owner.linkedClass
Expand All @@ -306,12 +321,20 @@ object PatternMatcher {
.map(ref(unappResult).select(_))
matchArgsPlan(selectors, args, onSuccess)
}
else if (isProductSeqMatch(unapp.tpe.widen, args.length, unapp.sourcePos) && isUnapplySeq) {
val arity = productArity(unapp.tpe.widen, unapp.sourcePos)
unapplyProductSeqPlan(unappResult, args, arity)
}
else {
assert(isGetMatch(unapp.tpe))
val argsPlan = {
val get = ref(unappResult).select(nme.get, _.info.isParameterless)
val arity = productArity(get.tpe, unapp.sourcePos)
if (isUnapplySeq)
letAbstract(get)(unapplySeqPlan(_, args))
letAbstract(get) { getResult =>
if (arity > 0) unapplyProductSeqPlan(getResult, args, arity)
else unapplySeqPlan(getResult, args)
}
else
letAbstract(get) { getResult =>
val selectors =
Expand Down
49 changes: 38 additions & 11 deletions compiler/src/dotty/tools/dotc/transform/patmat/Space.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import ProtoTypes._
import transform.SymUtils._
import reporting.diagnostic.messages._
import config.Printers.{exhaustivity => debug}
import util.SourcePosition

/** Space logic for checking exhaustivity and unreachability of pattern matching
*
Expand Down Expand Up @@ -338,8 +339,13 @@ class SpaceEngine(implicit ctx: Context) extends SpaceLogic {
if (fun.symbol.name == nme.unapplySeq)
if (fun.symbol.owner == scalaSeqFactoryClass)
projectSeq(pats)
else
Prod(erase(pat.tpe.stripAnnots), fun.tpe, fun.symbol, projectSeq(pats) :: Nil, irrefutable(fun))
else {
val (arity, elemTp, resultTp) = unapplySeqInfo(fun.tpe.widen.finalResultType, fun.sourcePos)
if (elemTp.exists)
Prod(erase(pat.tpe.stripAnnots), fun.tpe, fun.symbol, projectSeq(pats) :: Nil, irrefutable(fun))
else
Prod(erase(pat.tpe.stripAnnots), fun.tpe, fun.symbol, pats.take(arity - 1).map(project) :+ projectSeq(pats.drop(arity - 1)), irrefutable(fun))
}
else
Prod(erase(pat.tpe.stripAnnots), fun.tpe, fun.symbol, pats.map(project), irrefutable(fun))
case Typed(pat @ UnApply(_, _, _), _) => project(pat)
Expand All @@ -354,6 +360,18 @@ class SpaceEngine(implicit ctx: Context) extends SpaceLogic {
Empty
}

private def unapplySeqInfo(resTp: Type, pos: SourcePosition)(implicit ctx: Context): (Int, Type, Type) = {
var resultTp = resTp
var elemTp = unapplySeqTypeElemTp(resultTp)
var arity = productArity(resultTp, pos)
if (!elemTp.exists && arity <= 0) {
resultTp = resTp.select(nme.get).finalResultType
elemTp = unapplySeqTypeElemTp(resultTp.widen)
arity = productSelectorTypes(resultTp, pos).size
}
(arity, elemTp, resultTp)
}

/* Erase pattern bound types with WildcardType */
def erase(tp: Type): Type = {
def isPatternTypeSymbol(sym: Symbol) = !sym.isClass && sym.is(Case)
Expand Down Expand Up @@ -424,17 +442,26 @@ class SpaceEngine(implicit ctx: Context) extends SpaceLogic {
List()
else {
val isUnapplySeq = unappSym.name == nme.unapplySeq
if (isProductMatch(mt.finalResultType, argLen) && !isUnapplySeq) {
productSelectors(mt.finalResultType).take(argLen)
.map(_.info.asSeenFrom(mt.finalResultType, mt.resultType.classSymbol).widenExpr)

if (isUnapplySeq) {
val (arity, elemTp, resultTp) = unapplySeqInfo(mt.finalResultType, unappSym.sourcePos)
if (elemTp.exists) scalaListType.appliedTo(elemTp) :: Nil
else {
val sels = productSeqSelectors(resultTp, arity, unappSym.sourcePos)
sels.init :+ scalaListType.appliedTo(sels.last)
}
}
else {
val resTp = mt.finalResultType.select(nme.get).finalResultType.widen
if (isUnapplySeq) scalaListType.appliedTo(resTp.argTypes.head) :: Nil
else if (argLen == 0) Nil
else if (isProductMatch(resTp, argLen))
productSelectors(resTp).map(_.info.asSeenFrom(resTp, resTp.classSymbol).widenExpr)
else resTp :: Nil
val arity = productArity(mt.finalResultType, unappSym.sourcePos)
if (arity > 0)
productSelectors(mt.finalResultType)
.map(_.info.asSeenFrom(mt.finalResultType, mt.resultType.classSymbol).widenExpr)
else {
val resTp = mt.finalResultType.select(nme.get).finalResultType.widen
val arity = productArity(resTp, unappSym.sourcePos)
if (argLen == 1) resTp :: Nil
else productSelectors(resTp).map(_.info.asSeenFrom(resTp, resTp.classSymbol).widenExpr)
}
}
}

Expand Down
115 changes: 65 additions & 50 deletions compiler/src/dotty/tools/dotc/typer/Applications.scala
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,22 @@ object Applications {

/** Does `tp` fit the "product match" conditions as an unapply result type
* for a pattern with `numArgs` subpatterns?
* This is the case of `tp` has members `_1` to `_N` where `N == numArgs`.
* This is the case if `tp` has members `_1` to `_N` where `N == numArgs`.
*/
def isProductMatch(tp: Type, numArgs: Int, errorPos: SourcePosition = NoSourcePosition)(implicit ctx: Context): Boolean =
numArgs > 0 && productArity(tp, errorPos) == numArgs

/** Does `tp` fit the "product-seq match" conditions as an unapply result type
* for a pattern with `numArgs` subpatterns?
* This is the case if (1) `tp` has members `_1` to `_N` where `N <= numArgs + 1`.
* (2) `tp._N` conforms to Seq match
*/
def isProductSeqMatch(tp: Type, numArgs: Int, errorPos: SourcePosition = NoSourcePosition)(implicit ctx: Context): Boolean = {
val arity = productArity(tp, errorPos)
arity > 0 && arity <= numArgs + 1 &&
unapplySeqTypeElemTp(productSelectorTypes(tp, errorPos).last).exists
}

/** Does `tp` fit the "get match" conditions as an unapply result type?
* This is the case of `tp` has a `get` member as well as a
* parameterless `isEmpty` member of result type `Boolean`.
Expand All @@ -60,6 +71,39 @@ object Applications {
extractorMemberType(tp, nme.isEmpty, errorPos).isRef(defn.BooleanClass) &&
extractorMemberType(tp, nme.get, errorPos).exists

/** If `getType` is of the form:
* ```
* {
* def lengthCompare(len: Int): Int // or, def length: Int
* def apply(i: Int): T = a(i)
* def drop(n: Int): scala.Seq[T]
* def toSeq: scala.Seq[T]
* }
* ```
* returns `T`, otherwise NoType.
*/
def unapplySeqTypeElemTp(getTp: Type)(implicit ctx: Context): Type = {
def lengthTp = ExprType(defn.IntType)
def lengthCompareTp = MethodType(List(defn.IntType), defn.IntType)
def applyTp(elemTp: Type) = MethodType(List(defn.IntType), elemTp)
def dropTp(elemTp: Type) = MethodType(List(defn.IntType), defn.SeqType.appliedTo(elemTp))
def toSeqTp(elemTp: Type) = ExprType(defn.SeqType.appliedTo(elemTp))

// the result type of `def apply(i: Int): T`
val elemTp = getTp.member(nme.apply).suchThat(_.info <:< applyTp(WildcardType)).info.resultType

def hasMethod(name: Name, tp: Type) =
getTp.member(name).suchThat(getTp.memberInfo(_) <:< tp).exists

val isValid =
elemTp.exists &&
(hasMethod(nme.lengthCompare, lengthCompareTp) || hasMethod(nme.length, lengthTp)) &&
hasMethod(nme.drop, dropTp(elemTp)) &&
hasMethod(nme.toSeq, toSeqTp(elemTp))

if (isValid) elemTp else NoType
}

def productSelectorTypes(tp: Type, errorPos: SourcePosition)(implicit ctx: Context): List[Type] = {
def tupleSelectors(n: Int, tp: Type): List[Type] = {
val sel = extractorMemberType(tp, nme.selectorName(n), errorPos)
Expand Down Expand Up @@ -92,57 +136,35 @@ object Applications {
else tp :: Nil
} else tp :: Nil

def productSeqSelectors(tp: Type, argsNum: Int, pos: SourcePosition)(implicit ctx: Context): List[Type] = {
val selTps = productSelectorTypes(tp, pos)
val arity = selTps.length
val elemTp = unapplySeqTypeElemTp(selTps.last)
(0 until argsNum).map(i => if (i < arity - 1) selTps(i) else elemTp).toList
}

def unapplyArgs(unapplyResult: Type, unapplyFn: Tree, args: List[untpd.Tree], pos: SourcePosition)(implicit ctx: Context): List[Type] = {

val unapplyName = unapplyFn.symbol.name
def seqSelector = defn.RepeatedParamType.appliedTo(unapplyResult.elemType :: Nil)
def getTp = extractorMemberType(unapplyResult, nme.get, pos)

def fail = {
ctx.error(UnapplyInvalidReturnType(unapplyResult, unapplyName), pos)
Nil
}

/** If `getType` is of the form:
* ```
* {
* def lengthCompare(len: Int): Int // or, def length: Int
* def apply(i: Int): T = a(i)
* def drop(n: Int): scala.Seq[T]
* def toSeq: scala.Seq[T]
* }
* ```
* returns `T`, otherwise NoType.
*/
def unapplySeqTypeElemTp(getTp: Type): Type = {
def lengthTp = ExprType(defn.IntType)
def lengthCompareTp = MethodType(List(defn.IntType), defn.IntType)
def applyTp(elemTp: Type) = MethodType(List(defn.IntType), elemTp)
def dropTp(elemTp: Type) = MethodType(List(defn.IntType), defn.SeqType.appliedTo(elemTp))
def toSeqTp(elemTp: Type) = defn.SeqType.appliedTo(elemTp)

// the result type of `def apply(i: Int): T`
val elemTp = getTp.member(nme.apply).suchThat(_.info <:< applyTp(WildcardType)).info.resultType

def hasMethod(name: Name, tp: Type) =
getTp.member(name).suchThat(getTp.memberInfo(_) <:< tp).exists

val isValid =
elemTp.exists &&
(hasMethod(nme.lengthCompare, lengthCompareTp) || hasMethod(nme.length, lengthTp)) &&
hasMethod(nme.drop, dropTp(elemTp)) &&
hasMethod(nme.toSeq, toSeqTp(elemTp))

if (isValid) elemTp else NoType
def unapplySeq(tp: Type)(fallback: => List[Type]): List[Type] = {
val elemTp = unapplySeqTypeElemTp(tp)
if (elemTp.exists) args.map(Function.const(elemTp))
else if (isProductSeqMatch(tp, args.length, pos)) productSeqSelectors(tp, args.length, pos)
else fallback
}

if (unapplyName == nme.unapplySeq) {
if (isGetMatch(unapplyResult, pos)) {
val elemTp = unapplySeqTypeElemTp(getTp)
if (elemTp.exists) args.map(Function.const(elemTp))
unapplySeq(unapplyResult) {
if (isGetMatch(unapplyResult, pos)) unapplySeq(getTp)(fail)
else fail
}
else fail
}
else {
assert(unapplyName == nme.unapply)
Expand Down Expand Up @@ -1076,19 +1098,12 @@ trait Applications extends Compatibility { self: Typer with Dynamic =>

var argTypes = unapplyArgs(unapplyApp.tpe, unapplyFn, args, tree.sourcePos)
for (argType <- argTypes) assert(!isBounds(argType), unapplyApp.tpe.show)
val bunchedArgs =
if (argTypes.nonEmpty && argTypes.last.isRepeatedParam)
args.lastOption match {
case Some(arg @ Typed(argSeq, _)) if untpd.isWildcardStarArg(arg) =>
args.init :+ argSeq
case _ =>
val (regularArgs, varArgs) = args.splitAt(argTypes.length - 1)
regularArgs :+ untpd.SeqLiteral(varArgs, untpd.TypeTree()).withSpan(tree.span)
}
else if (argTypes.lengthCompare(1) == 0 && args.lengthCompare(1) > 0 && ctx.canAutoTuple)
untpd.Tuple(args) :: Nil
else
args
val bunchedArgs = argTypes match {
case argType :: Nil =>
if (args.lengthCompare(1) > 0 && ctx.canAutoTuple) untpd.Tuple(args) :: Nil
else args
case _ => args
}
if (argTypes.length != bunchedArgs.length) {
ctx.error(UnapplyInvalidNumberOfArguments(qual, argTypes), tree.sourcePos)
argTypes = argTypes.take(args.length) ++
Expand Down
Loading