Skip to content

Commit 4c12ed3

Browse files
liufengyunallanrenucci
authored andcommitted
address review: try length if lengthCompare unavailable
1 parent 6999ca8 commit 4c12ed3

File tree

4 files changed

+69
-23
lines changed

4 files changed

+69
-23
lines changed

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -422,6 +422,8 @@ class Definitions {
422422
def Seq_drop(implicit ctx: Context) = Seq_dropR.symbol
423423
lazy val Seq_lengthCompareR = SeqClass.requiredMethodRef(nme.lengthCompare)
424424
def Seq_lengthCompare(implicit ctx: Context) = Seq_lengthCompareR.symbol
425+
lazy val Seq_lengthR = SeqClass.requiredMethodRef(nme.length)
426+
def Seq_length(implicit ctx: Context) = Seq_lengthR.symbol
425427
lazy val Seq_toSeqR = SeqClass.requiredMethodRef(nme.toSeq)
426428
def Seq_toSeq(implicit ctx: Context) = Seq_toSeqR.symbol
427429

compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -643,11 +643,18 @@ object PatternMatcher {
643643
case EqualTest(tree) =>
644644
tree.equal(scrutinee)
645645
case LengthTest(len, exact) =>
646-
scrutinee
647-
.select(defn.Seq_lengthCompare.matchingMember(scrutinee.tpe))
648-
.appliedTo(Literal(Constant(len)))
649-
.select(if (exact) defn.Int_== else defn.Int_>=)
650-
.appliedTo(Literal(Constant(0)))
646+
val lengthCompareSym = defn.Seq_lengthCompare.matchingMember(scrutinee.tpe)
647+
if (lengthCompareSym.exists)
648+
scrutinee
649+
.select(defn.Seq_lengthCompare.matchingMember(scrutinee.tpe))
650+
.appliedTo(Literal(Constant(len)))
651+
.select(if (exact) defn.Int_== else defn.Int_>=)
652+
.appliedTo(Literal(Constant(0)))
653+
else // try length
654+
scrutinee
655+
.select(defn.Seq_length.matchingMember(scrutinee.tpe))
656+
.select(if (exact) defn.Int_== else defn.Int_>=)
657+
.appliedTo(Literal(Constant(len)))
651658
case TypeTest(tpt) =>
652659
val expectedTp = tpt.tpe
653660

compiler/src/dotty/tools/dotc/typer/Applications.scala

Lines changed: 23 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -100,29 +100,34 @@ object Applications {
100100
Nil
101101
}
102102

103-
def validUnapplySeqType(getTp: Type): Boolean = {
104-
def superType(elemTp: Type) = {
105-
val tps = List(
106-
MethodType(List("len".toTermName))(_ => defn.IntType :: Nil, _ => defn.IntType),
107-
MethodType(List("i".toTermName))(_ => defn.IntType :: Nil, _ => elemTp),
108-
MethodType(List("n".toTermName))(_ => defn.IntType :: Nil, _ => defn.SeqType.appliedTo(elemTp)),
109-
ExprType(defn.SeqType.appliedTo(elemTp)),
110-
)
111-
val names = List(nme.lengthCompare, nme.apply, nme.drop, nme.toSeq)
112-
RefinedType.make(defn.AnyType, names, tps)
113-
}
114-
getTp <:< superType(WildcardType) && {
115-
val seqArg = extractorMemberType(getTp, nme.toSeq).elemType.hiBound
116-
getTp <:< superType(seqArg)
117-
}
103+
def unapplySeqTypeElemTp(getTp: Type): Type = {
104+
val lengthTp = ExprType(defn.IntType)
105+
val lengthCompareTp = MethodType(List("len".toTermName))(_ => defn.IntType :: Nil, _ => defn.IntType)
106+
def applyTp(elemTp: Type) = MethodType(List("i".toTermName))(_ => defn.IntType :: Nil, _ => elemTp)
107+
def dropTp(elemTp: Type) = MethodType(List("n".toTermName))(_ => defn.IntType :: Nil, _ => defn.SeqType.appliedTo(elemTp))
108+
def toSeqTp(elemTp: Type) = ExprType(defn.SeqType.appliedTo(elemTp))
109+
110+
val elemTp = getTp.member(nme.apply).suchThat(_.info <:< applyTp(WildcardType)).info.resultType
111+
112+
def names1 = List(nme.lengthCompare, nme.apply, nme.drop, nme.toSeq)
113+
def types1 = List(lengthCompareTp, applyTp(elemTp), dropTp(elemTp), toSeqTp(elemTp))
114+
115+
def names2 = List(nme.length, nme.apply, nme.drop, nme.toSeq)
116+
def types2 = List(lengthTp, applyTp(elemTp), dropTp(elemTp), toSeqTp(elemTp))
117+
118+
val valid = getTp <:< RefinedType.make(defn.AnyType, names1, types1) ||
119+
getTp <:< RefinedType.make(defn.AnyType, names2, types2)
120+
121+
if (valid) elemTp else NoType
118122
}
119123

124+
def validUnapplySeqType(getTp: Type): Boolean = unapplySeqTypeElemTp(getTp).exists
125+
120126
if (unapplyName == nme.unapplySeq) {
121127
if (unapplyResult derivesFrom defn.SeqClass) seqSelector :: Nil
122128
else if (isGetMatch(unapplyResult, pos) && validUnapplySeqType(getTp)) {
123-
val seqArg = extractorMemberType(getTp, nme.toSeq).elemType.hiBound
124-
if (seqArg.exists) args.map(Function.const(seqArg))
125-
else fail
129+
val elemTp = unapplySeqTypeElemTp(getTp)
130+
args.map(Function.const(elemTp))
126131
}
127132
else fail
128133
}

tests/run/i4984c.scala

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
object Array2 {
2+
def unapplySeq(x: Array[Int]): Data = new Data
3+
4+
final class Data {
5+
def isEmpty: Boolean = false
6+
def get: Data = this
7+
def length: Int = 2
8+
def apply(i: Int): Int = 3
9+
def drop(n: Int): scala.Seq[Int] = Seq(2, 5)
10+
def toSeq: scala.Seq[Int] = Seq(6, 7)
11+
}
12+
}
13+
14+
object Test {
15+
def test1(xs: Array[Int]): Int = xs match {
16+
case Array2(x, y) => x + y
17+
}
18+
19+
def test2(xs: Array[Int]): Seq[Int] = xs match {
20+
case Array2(x, y, xs:_*) => xs
21+
}
22+
23+
def test3(xs: Array[Int]): Seq[Int] = xs match {
24+
case Array2(xs:_*) => xs
25+
}
26+
27+
def main(args: Array[String]): Unit = {
28+
test1(Array(3, 5))
29+
test2(Array(3, 5))
30+
test3(Array(3, 5))
31+
}
32+
}

0 commit comments

Comments
 (0)