Skip to content

Fix #9439: Translate Java varargs ...T into T* instead of (T & Object)* #9451

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 28, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions compiler/src/dotty/tools/dotc/ast/Desugar.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1684,10 +1684,12 @@ object desugar {
Apply(Select(Apply(scalaDot(nme.StringContext), strs), id).withSpan(tree.span), elems)
case PostfixOp(t, op) =>
if ((ctx.mode is Mode.Type) && !isBackquoted(op) && op.name == tpnme.raw.STAR) {
val seqType = if (ctx.compilationUnit.isJava) defn.ArrayType else defn.SeqType
Annotated(
AppliedTypeTree(ref(seqType), t),
New(ref(defn.RepeatedAnnot.typeRef), Nil :: Nil))
if ctx.compilationUnit.isJava then
AppliedTypeTree(ref(defn.RepeatedParamType), t)
else
Annotated(
AppliedTypeTree(ref(defn.SeqType), t),
New(ref(defn.RepeatedAnnot.typeRef), Nil :: Nil))
}
else {
assert(ctx.mode.isExpr || ctx.reporter.errorsReported || ctx.mode.is(Mode.Interactive), ctx.mode)
Expand Down
2 changes: 2 additions & 0 deletions compiler/src/dotty/tools/dotc/core/Definitions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -399,6 +399,8 @@ class Definitions {
def runtimeMethodRef(name: PreName): TermRef = ScalaRuntimeModule.requiredMethodRef(name)
def ScalaRuntime_drop: Symbol = runtimeMethodRef(nme.drop).symbol
@tu lazy val ScalaRuntime__hashCode: Symbol = ScalaRuntimeModule.requiredMethod(nme._hashCode_)
@tu lazy val ScalaRuntime_toArray: Symbol = ScalaRuntimeModule.requiredMethod(nme.toArray)
@tu lazy val ScalaRuntime_toObjectArray: Symbol = ScalaRuntimeModule.requiredMethod(nme.toObjectArray)

@tu lazy val BoxesRunTimeModule: Symbol = requiredModule("scala.runtime.BoxesRunTime")
@tu lazy val BoxesRunTimeModule_externalEquals: Symbol = BoxesRunTimeModule.info.decl(nme.equals_).suchThat(toDenot(_).info.firstParamTypes.size == 2).symbol
Expand Down
55 changes: 41 additions & 14 deletions compiler/src/dotty/tools/dotc/core/classfile/ClassfileParser.scala
Original file line number Diff line number Diff line change
Expand Up @@ -280,15 +280,13 @@ class ClassfileParser(
addConstructorTypeParams(denot)
}

denot.info = pool.getType(in.nextChar)
val isVarargs = denot.is(Flags.Method) && (jflags & JAVA_ACC_VARARGS) != 0
denot.info = pool.getType(in.nextChar, isVarargs)
if (isEnum) denot.info = ConstantType(Constant(sym))
if (isConstructor) normalizeConstructorParams()
denot.info = translateTempPoly(parseAttributes(sym, denot.info))
denot.info = translateTempPoly(parseAttributes(sym, denot.info, isVarargs))
if (isConstructor) normalizeConstructorInfo()

if (denot.is(Flags.Method) && (jflags & JAVA_ACC_VARARGS) != 0)
denot.info = arrayToRepeated(denot.info)

if (ctx.explicitNulls) denot.info = JavaNullInterop.nullifyMember(denot.symbol, denot.info, isEnum)

// seal java enums
Expand Down Expand Up @@ -324,7 +322,7 @@ class ClassfileParser(
case BOOL_TAG => defn.BooleanType
}

private def sigToType(sig: SimpleName, owner: Symbol = null)(using Context): Type = {
private def sigToType(sig: SimpleName, owner: Symbol = null, isVarargs: Boolean = false)(using Context): Type = {
var index = 0
val end = sig.length
def accept(ch: Char): Unit = {
Expand Down Expand Up @@ -395,13 +393,42 @@ class ClassfileParser(
val elemtp = sig2type(tparams, skiptvs)
defn.ArrayOf(elemtp.translateJavaArrayElementType)
case '(' =>
// we need a method symbol. given in line 486 by calling getType(methodSym, ..)
def isMethodEnd(i: Int) = sig(i) == ')'
def isArray(i: Int) = sig(i) == '['

/** Is this a repeated parameter type?
* This is true if we're in a vararg method and this is the last parameter.
*/
def isRepeatedParam(i: Int): Boolean =
if !isVarargs then return false
var cur = i
// Repeated parameters are represented as arrays
if !isArray(cur) then return false
// Handle nested arrays: int[]...
while isArray(cur) do
cur += 1
// Simple check to see if we're the last parameter: there should be no
// array in the signature until the method end.
while !isMethodEnd(cur) do
if isArray(cur) then return false
cur += 1
true
end isRepeatedParam

val paramtypes = new ListBuffer[Type]()
var paramnames = new ListBuffer[TermName]()
while (sig(index) != ')') {
while !isMethodEnd(index) do
paramnames += nme.syntheticParamName(paramtypes.length)
paramtypes += objToAny(sig2type(tparams, skiptvs))
}
paramtypes += {
if isRepeatedParam(index) then
index += 1
val elemType = sig2type(tparams, skiptvs)
// `ElimRepeated` is responsible for correctly erasing this.
defn.RepeatedParamType.appliedTo(elemType)
else
objToAny(sig2type(tparams, skiptvs))
}

index += 1
val restype = sig2type(tparams, skiptvs)
JavaMethodType(paramnames.toList, paramtypes.toList, restype)
Expand Down Expand Up @@ -574,7 +601,7 @@ class ClassfileParser(
None // ignore malformed annotations
}

def parseAttributes(sym: Symbol, symtype: Type)(using Context): Type = {
def parseAttributes(sym: Symbol, symtype: Type, isVarargs: Boolean = false)(using Context): Type = {
var newType = symtype

def parseAttribute(): Unit = {
Expand All @@ -584,7 +611,7 @@ class ClassfileParser(
attrName match {
case tpnme.SignatureATTR =>
val sig = pool.getExternalName(in.nextChar)
newType = sigToType(sig, sym)
newType = sigToType(sig, sym, isVarargs)
if (ctx.debug && ctx.verbose)
println("" + sym + "; signature = " + sig + " type = " + newType)
case tpnme.SyntheticATTR =>
Expand Down Expand Up @@ -1103,8 +1130,8 @@ class ClassfileParser(
c
}

def getType(index: Int)(using Context): Type =
sigToType(getExternalName(index))
def getType(index: Int, isVarargs: Boolean = false)(using Context): Type =
sigToType(getExternalName(index), isVarargs = isVarargs)

def getSuperClass(index: Int)(using Context): Symbol = {
assert(index != 0, "attempt to parse java.lang.Object from classfile")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,21 +59,6 @@ object Scala2Unpickler {
denot.info = PolyType.fromParams(denot.owner.typeParams, denot.info)
}

/** Convert array parameters denoting a repeated parameter of a Java method
* to `RepeatedParamClass` types.
*/
def arrayToRepeated(tp: Type)(using Context): Type = tp match {
case tp: MethodType =>
val lastArg = tp.paramInfos.last
assert(lastArg isRef defn.ArrayClass)
tp.derivedLambdaType(
tp.paramNames,
tp.paramInfos.init :+ lastArg.translateParameterized(defn.ArrayClass, defn.RepeatedParamClass),
tp.resultType)
case tp: PolyType =>
tp.derivedLambdaType(tp.paramNames, tp.paramInfos, arrayToRepeated(tp.resultType))
}

def ensureConstructor(cls: ClassSymbol, scope: Scope)(using Context): Unit = {
if (scope.lookup(nme.CONSTRUCTOR) == NoSymbol) {
val constr = newDefaultConstructor(cls)
Expand Down
71 changes: 62 additions & 9 deletions compiler/src/dotty/tools/dotc/transform/ElimRepeated.scala
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@ object ElimRepeated {
val name: String = "elimRepeated"
}

/** A transformer that removes repeated parameters (T*) from all types, replacing
* them with Seq types.
/** A transformer that eliminates repeated parameters (T*) from all types, replacing
* them with Seq or Array types and adapting repeated arguments to conform to
* the transformed type if needed.
*/
class ElimRepeated extends MiniPhase with InfoTransformer { thisPhase =>
import ast.tpd._
Expand Down Expand Up @@ -55,9 +56,28 @@ class ElimRepeated extends MiniPhase with InfoTransformer { thisPhase =>
case tp @ MethodTpe(paramNames, paramTypes, resultType) =>
val resultType1 = elimRepeated(resultType)
val paramTypes1 =
if paramTypes.nonEmpty && paramTypes.last.isRepeatedParam then
val last = paramTypes.last.translateFromRepeated(toArray = tp.isJavaMethod)
paramTypes.init :+ last
val lastIdx = paramTypes.length - 1
if lastIdx >= 0 then
val last = paramTypes(lastIdx)
if last.isRepeatedParam then
val isJava = tp.isJavaMethod
// A generic Java varargs `T...` where `T` is unbounded is erased to
// `Object[]` in bytecode, we directly translate such a type to
// `Array[_ <: Object]` instead of `Array[_ <: T]` here. This allows
// the tree transformer of this phase to emit the correct adaptation
// for repeated arguments if needed (for example, an `Array[Int]` will
// be copied into an `Array[Object]`, see `adaptToArray`).
val last1 =
if isJava && {
val elemTp = last.elemType
elemTp.isInstanceOf[TypeParamRef] && elemTp.typeSymbol == defn.AnyClass
}
then
defn.ArrayOf(TypeBounds.upper(defn.ObjectType))
else
last.translateFromRepeated(toArray = isJava)
paramTypes.updated(lastIdx, last1)
else paramTypes
else paramTypes
tp.derivedLambdaType(paramNames, paramTypes1, resultType1)
case tp: PolyType =>
Expand All @@ -82,9 +102,10 @@ class ElimRepeated extends MiniPhase with InfoTransformer { thisPhase =>
case arg: Typed if isWildcardStarArg(arg) =>
val isJavaDefined = tree.fun.symbol.is(JavaDefined)
val tpe = arg.expr.tpe
if isJavaDefined && tpe.derivesFrom(defn.SeqClass) then
seqToArray(arg.expr)
else if !isJavaDefined && tpe.derivesFrom(defn.ArrayClass)
if isJavaDefined then
val pt = tree.fun.tpe.widen.firstParamTypes.last
adaptToArray(arg.expr, pt.elemType.bounds.hi)
else if tpe.derivesFrom(defn.ArrayClass) then
arrayToSeq(arg.expr)
else
arg.expr
Expand All @@ -107,7 +128,39 @@ class ElimRepeated extends MiniPhase with InfoTransformer { thisPhase =>
.appliedToType(elemType)
.appliedTo(tree, clsOf(elemClass.typeRef))

/** Convert Java array argument to Scala Seq */
/** Adapt a Seq or Array tree to be a subtype of `Array[_ <: $elemPt]`.
*
* @pre `elemPt` must either be a super type of the argument element type or `Object`.
* The special handling of `Object` is required to deal with the translation
* of generic Java varargs in `elimRepeated`.
*/
private def adaptToArray(tree: Tree, elemPt: Type)(implicit ctx: Context): Tree =
val elemTp = tree.tpe.elemType
val treeIsArray = tree.tpe.derivesFrom(defn.ArrayClass)
if elemTp <:< elemPt then
if treeIsArray then
tree // no adaptation needed
else
tree match
case SeqLiteral(elems, elemtpt) =>
JavaSeqLiteral(elems, elemtpt).withSpan(tree.span)
case _ =>
// Convert a Seq[T] to an Array[$elemPt]
ref(defn.DottyArraysModule)
.select(nme.seqToArray)
.appliedToType(elemPt)
.appliedTo(tree, clsOf(elemPt))
else if treeIsArray then
// Convert an Array[T] to an Array[Object]
ref(defn.ScalaRuntime_toObjectArray)
.appliedTo(tree)
else
// Convert a Seq[T] to an Array[Object]
ref(defn.ScalaRuntime_toArray)
.appliedToType(elemTp)
.appliedTo(tree)

/** Convert an Array into a scala.Seq */
private def arrayToSeq(tree: Tree)(using Context): Tree =
tpd.wrapArray(tree, tree.tpe.elemType)

Expand Down
2 changes: 1 addition & 1 deletion compiler/src/dotty/tools/dotc/typer/Typer.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1723,7 +1723,7 @@ class Typer extends Namer
checkedArgs = checkedArgs.mapconserve(arg =>
checkSimpleKinded(checkNoWildcard(arg)))
else if (ctx.compilationUnit.isJava)
if (tpt1.symbol eq defn.ArrayClass) || (tpt1.symbol eq defn.RepeatedParamClass) then
if (tpt1.symbol eq defn.ArrayClass) then
checkedArgs match {
case List(arg) =>
val elemtp = arg.tpe.translateJavaArrayElementType
Expand Down
2 changes: 1 addition & 1 deletion tests/neg/i533/Test.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@ object Test {
val x = new Array[Int](1)
x(0) = 10
println(JA.get(x)) // error
println(JA.getVarargs(x: _*)) // error
println(JA.getVarargs(x: _*)) // now OK.
}
}
1 change: 1 addition & 0 deletions tests/pos/arrays2.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,4 +24,5 @@ one warning found
// #2461
object arrays3 {
def apply[X <: AnyRef](xs : X*) : java.util.List[X] = java.util.Arrays.asList(xs: _*)
def apply2[X](xs : X*) : java.util.List[X] = java.util.Arrays.asList(xs: _*)
}
17 changes: 17 additions & 0 deletions tests/run/i9439.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
object Test {
// First example with a concrete type <: AnyVal
def main(args: Array[String]): Unit = {
val coll = new java.util.ArrayList[Int]()
java.util.Collections.addAll(coll, 5, 6)
println(coll.size())

foo(5, 6)
}

// Second example with an abstract type not known to be <: AnyRef
def foo[A](a1: A, a2: A): Unit = {
val coll = new java.util.ArrayList[A]()
java.util.Collections.addAll(coll, a1, a2)
println(coll.size())
}
}
10 changes: 10 additions & 0 deletions tests/run/java-varargs-2/A.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
class A {
public static void foo(int... args) {
}

public static <T> void gen(T... args) {
}

public static <T extends java.io.Serializable> void gen2(T... args) {
}
}
13 changes: 13 additions & 0 deletions tests/run/java-varargs-2/Test.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
object Test {
def main(args: Array[String]): Unit = {
A.foo(1)
A.foo(Array(1): _*)
A.foo(Seq(1): _*)
A.gen(1)
A.gen(Array(1): _*)
A.gen(Seq(1): _*)
A.gen2("")
A.gen2(Array(""): _*)
A.gen2(Seq(""): _*)
}
}
3 changes: 3 additions & 0 deletions tests/run/java-varargs/A_1.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,7 @@ public static void foo(int... args) {

public static <T> void gen(T... args) {
}

public static <T extends java.io.Serializable> void gen2(T... args) {
}
}
5 changes: 5 additions & 0 deletions tests/run/java-varargs/Test_2.scala
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
object Test {
def main(args: Array[String]): Unit = {
A_1.foo(1)
A_1.foo(Array(1): _*)
A_1.foo(Seq(1): _*)
A_1.gen(1)
A_1.gen(Array(1): _*)
A_1.gen(Seq(1): _*)
A_1.gen2("")
A_1.gen2(Array(""): _*)
A_1.gen2(Seq(""): _*)
}
}