Skip to content

Commit 792389a

Browse files
authored
Add support for top-level Kotlin functions #847 (#1147)
1 parent 651f7d4 commit 792389a

File tree

17 files changed

+190
-30
lines changed

17 files changed

+190
-30
lines changed

utbot-core/src/main/kotlin/org/utbot/common/KClassUtil.kt

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,6 @@ package org.utbot.common
22

33
import java.lang.reflect.InvocationTargetException
44
import java.lang.reflect.Method
5-
import kotlin.reflect.KClass
6-
7-
val Class<*>.nameOfPackage: String get() = `package`?.name?:""
85

96
/**
107
* Invokes [this] method of passed [obj] instance (null for static methods) with the passed [args] arguments.
@@ -16,7 +13,4 @@ fun Method.invokeCatching(obj: Any?, args: List<Any?>) = try {
1613
Result.success(invocation)
1714
} catch (e: InvocationTargetException) {
1815
Result.failure<Nothing>(e.targetException)
19-
}
20-
21-
val KClass<*>.allNestedClasses: List<KClass<*>>
22-
get() = listOf(this) + nestedClasses.flatMap { it.allNestedClasses }
16+
}

utbot-core/src/main/kotlin/org/utbot/common/ReflectionUtil.kt

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,4 +111,10 @@ val Class<*>.isFinal
111111
get() = Modifier.isFinal(modifiers)
112112

113113
val Class<*>.isProtected
114-
get() = Modifier.isProtected(modifiers)
114+
get() = Modifier.isProtected(modifiers)
115+
116+
val Class<*>.nameOfPackage: String
117+
get() = `package`?.name?:""
118+
119+
val Class<*>.allNestedClasses: List<Class<*>>
120+
get() = listOf(this) + this.declaredClasses.flatMap { it.allNestedClasses }

utbot-framework-api/src/main/kotlin/org/utbot/framework/plugin/api/util/IdUtil.kt

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,14 @@ import kotlin.reflect.KCallable
2424
import kotlin.reflect.KClass
2525
import kotlin.reflect.KFunction
2626
import kotlin.reflect.KProperty
27+
import kotlin.reflect.full.extensionReceiverParameter
2728
import kotlin.reflect.full.instanceParameter
29+
import kotlin.reflect.jvm.internal.impl.load.kotlin.header.KotlinClassHeader
2830
import kotlin.reflect.jvm.javaConstructor
2931
import kotlin.reflect.jvm.javaField
3032
import kotlin.reflect.jvm.javaGetter
3133
import kotlin.reflect.jvm.javaMethod
34+
import kotlin.reflect.jvm.kotlinFunction
3235

3336
// ClassId utils
3437

@@ -187,6 +190,14 @@ val ClassId.isDoubleType: Boolean
187190
val ClassId.isClassType: Boolean
188191
get() = this == classClassId
189192

193+
/**
194+
* Checks if the class is a Kotlin class with kind File (see [Metadata.kind] for more details)
195+
*/
196+
val ClassId.isKotlinFile: Boolean
197+
get() = jClass.annotations.filterIsInstance<Metadata>().singleOrNull()?.let {
198+
KotlinClassHeader.Kind.getById(it.kind) == KotlinClassHeader.Kind.FILE_FACADE
199+
} ?: false
200+
190201
val voidClassId = ClassId("void")
191202
val booleanClassId = ClassId("boolean")
192203
val byteClassId = ClassId("byte")
@@ -450,6 +461,12 @@ val MethodId.method: Method
450461
?: error("Can't find method $signature in ${declaringClass.name}")
451462
}
452463

464+
/**
465+
* See [KCallable.extensionReceiverParameter] for more details
466+
*/
467+
val MethodId.extensionReceiverParameterIndex: Int?
468+
get() = this.method.kotlinFunction?.extensionReceiverParameter?.index
469+
453470
// TODO: maybe cache it somehow in the future
454471
val ConstructorId.constructor: Constructor<*>
455472
get() {
@@ -504,6 +521,7 @@ val Method.displayName: String
504521

505522
val KCallable<*>.declaringClazz: Class<*>
506523
get() = when (this) {
524+
is KFunction<*> -> javaMethod?.declaringClass?.kotlin
507525
is CallableReference -> owner as? KClass<*>
508526
else -> instanceParameter?.type?.classifier as? KClass<*>
509527
}?.java ?: tryConstructor(this) ?: error("Can't get parent class for $this")
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
package org.utbot.examples.codegen
2+
3+
import org.junit.jupiter.api.Test
4+
import org.utbot.testcheckers.eq
5+
import org.utbot.tests.infrastructure.UtValueTestCaseChecker
6+
import kotlin.reflect.KFunction3
7+
8+
@Suppress("UNCHECKED_CAST")
9+
internal class FileWithTopLevelFunctionsTest : UtValueTestCaseChecker(testClass = FileWithTopLevelFunctionsReflectHelper.clazz.kotlin) {
10+
@Test
11+
fun topLevelSumTest() {
12+
check(
13+
::topLevelSum,
14+
eq(1),
15+
)
16+
}
17+
18+
@Test
19+
fun extensionOnBasicTypeTest() {
20+
check(
21+
Int::extensionOnBasicType,
22+
eq(1),
23+
)
24+
}
25+
26+
@Test
27+
fun extensionOnCustomClassTest() {
28+
check(
29+
// NB: cast is important here because we need to treat receiver as an argument to be able to check its content in matchers
30+
CustomClass::extensionOnCustomClass as KFunction3<*, CustomClass, CustomClass, Boolean>,
31+
eq(2),
32+
{ receiver, argument, result -> receiver === argument && result == true },
33+
{ receiver, argument, result -> receiver !== argument && result == false },
34+
additionalDependencies = dependenciesForClassExtensions
35+
)
36+
}
37+
38+
companion object {
39+
// Compilation of extension methods for ref objects produces call to
40+
// `kotlin.jvm.internal.Intrinsics::checkNotNullParameter`, so we need to add it to dependencies
41+
val dependenciesForClassExtensions = arrayOf<Class<*>>(kotlin.jvm.internal.Intrinsics::class.java)
42+
}
43+
}

utbot-framework/src/main/kotlin/org/utbot/framework/codegen/model/constructor/tree/CgCallableAccessManager.kt

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ import org.utbot.framework.plugin.api.FieldId
5151
import org.utbot.framework.plugin.api.MethodId
5252
import org.utbot.framework.plugin.api.UtExplicitlyThrownException
5353
import org.utbot.framework.plugin.api.util.exceptions
54+
import org.utbot.framework.plugin.api.util.extensionReceiverParameterIndex
55+
import org.utbot.framework.plugin.api.util.humanReadableName
5456
import org.utbot.framework.plugin.api.util.id
5557
import org.utbot.framework.plugin.api.util.isAbstract
5658
import org.utbot.framework.plugin.api.util.isArray
@@ -114,7 +116,7 @@ internal class CgCallableAccessManagerImpl(val context: CgContext) : CgCallableA
114116
override operator fun CgIncompleteMethodCall.invoke(vararg args: Any?): CgMethodCall {
115117
val resolvedArgs = args.resolve()
116118
val methodCall = if (method.canBeCalledWith(caller, resolvedArgs)) {
117-
CgMethodCall(caller, method, resolvedArgs.guardedForDirectCallOf(method))
119+
CgMethodCall(caller, method, resolvedArgs.guardedForDirectCallOf(method)).takeCallerFromArgumentsIfNeeded()
118120
} else {
119121
method.callWithReflection(caller, resolvedArgs)
120122
}
@@ -198,6 +200,29 @@ internal class CgCallableAccessManagerImpl(val context: CgContext) : CgCallableA
198200
else -> false
199201
}
200202

203+
/**
204+
* For Kotlin extension functions, real caller is one of the arguments in JVM method (and declaration class is omitted),
205+
* thus we should move it from arguments to caller
206+
*
207+
* For example, if we have `Int.f(a: Int)` declared in `Main.kt`, the JVM method signature will be `MainKt.f(Int, Int)`
208+
* and in Kotlin we should render this not like `MainKt.f(a, b)` but like `a.f(b)`
209+
*/
210+
private fun CgMethodCall.takeCallerFromArgumentsIfNeeded(): CgMethodCall {
211+
if (codegenLanguage == CodegenLanguage.KOTLIN) {
212+
// TODO: reflection calls for util and some of mockito methods produce exceptions => here we suppose that
213+
// methods for BuiltinClasses are not extensions by default (which should be true as long as we suppose them to be java methods)
214+
if (executableId.classId !is BuiltinClassId) {
215+
executableId.extensionReceiverParameterIndex?.let { receiverIndex ->
216+
require(caller == null) { "${executableId.humanReadableName} is an extension function but it already has a non-static caller provided" }
217+
val args = arguments.toMutableList()
218+
return CgMethodCall(args.removeAt(receiverIndex), executableId, args, typeParameters)
219+
}
220+
}
221+
}
222+
223+
return this
224+
}
225+
201226
private infix fun CgExpression.canBeArgOf(type: ClassId): Boolean {
202227
// TODO: SAT-1210 support generics so that we wouldn't need to check specific cases such as this one
203228
if (this is CgExecutableCall && (executableId == any || executableId == anyOfClass)) {

utbot-framework/src/main/kotlin/org/utbot/framework/codegen/model/visitor/CgAbstractRenderer.kt

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -131,8 +131,13 @@ internal abstract class CgAbstractRenderer(
131131
}
132132
}
133133

134+
/**
135+
* Returns true if one can call methods of this class without specifying a caller (for example if ClassId represents this instance)
136+
*/
137+
protected abstract val ClassId.methodsAreAccessibleAsTopLevel: Boolean
138+
134139
private val MethodId.accessibleByName: Boolean
135-
get() = (context.shouldOptimizeImports && this in context.importedStaticMethods) || classId == context.generatedClass
140+
get() = (context.shouldOptimizeImports && this in context.importedStaticMethods) || classId.methodsAreAccessibleAsTopLevel
136141

137142
override fun visit(element: CgElement) {
138143
val error =
@@ -654,8 +659,10 @@ internal abstract class CgAbstractRenderer(
654659
}
655660

656661
override fun visit(element: CgStaticFieldAccess) {
657-
print(element.declaringClass.asString())
658-
print(".")
662+
if (!element.declaringClass.methodsAreAccessibleAsTopLevel) {
663+
print(element.declaringClass.asString())
664+
print(".")
665+
}
659666
print(element.fieldName)
660667
}
661668

@@ -707,7 +714,10 @@ internal abstract class CgAbstractRenderer(
707714
if (caller != null) {
708715
// 'this' can be omitted, otherwise render caller
709716
if (caller !is CgThisInstance) {
717+
// TODO: we need parentheses for calls like (-1).inv(), do something smarter here
718+
if (caller !is CgVariable) print("(")
710719
caller.accept(this)
720+
if (caller !is CgVariable) print(")")
711721
renderAccess(caller)
712722
}
713723
} else {

utbot-framework/src/main/kotlin/org/utbot/framework/codegen/model/visitor/CgJavaRenderer.kt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,9 @@ internal class CgJavaRenderer(context: CgRendererContext, printer: CgPrinter = C
6464

6565
override val langPackage: String = "java.lang"
6666

67+
override val ClassId.methodsAreAccessibleAsTopLevel: Boolean
68+
get() = this == context.generatedClass
69+
6770
override fun visit(element: AbstractCgClass<*>) {
6871
for (annotation in element.annotations) {
6972
annotation.accept(this)

utbot-framework/src/main/kotlin/org/utbot/framework/codegen/model/visitor/CgKotlinRenderer.kt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ import org.utbot.framework.plugin.api.util.isProtected
5757
import org.utbot.framework.plugin.api.util.isPublic
5858
import org.utbot.framework.plugin.api.util.id
5959
import org.utbot.framework.plugin.api.util.isArray
60+
import org.utbot.framework.plugin.api.util.isKotlinFile
6061
import org.utbot.framework.plugin.api.util.isPrimitive
6162
import org.utbot.framework.plugin.api.util.isPrimitiveWrapper
6263
import org.utbot.framework.plugin.api.util.kClass
@@ -76,6 +77,10 @@ internal class CgKotlinRenderer(context: CgRendererContext, printer: CgPrinter =
7677

7778
override val langPackage: String = "kotlin"
7879

80+
override val ClassId.methodsAreAccessibleAsTopLevel: Boolean
81+
// NB: the order of operands is important as `isKotlinFile` uses reflection and thus can't be called on context.generatedClass
82+
get() = (this == context.generatedClass) || isKotlinFile
83+
7984
override fun visit(element: AbstractCgClass<*>) {
8085
for (annotation in element.annotations) {
8186
annotation.accept(this)

utbot-framework/src/main/kotlin/org/utbot/framework/plugin/api/SignatureUtil.kt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@ import kotlin.reflect.KFunction
44
import kotlin.reflect.KParameter
55
import kotlin.reflect.jvm.javaType
66

7+
// Note that rules for obtaining signature here should correlate with PsiMethod.signature()
78
fun KFunction<*>.signature() =
8-
Signature(this.name, this.parameters.filter { it.kind == KParameter.Kind.VALUE }.map { it.type.javaType.typeName })
9+
Signature(this.name, this.parameters.filter { it.kind != KParameter.Kind.INSTANCE }.map { it.type.javaType.typeName })
910

1011
data class Signature(val name: String, val parameterTypes: List<String?>) {
1112

utbot-framework/src/main/kotlin/org/utbot/framework/process/EngineMain.kt

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ import org.utbot.framework.plugin.api.util.UtContext
2121
import org.utbot.framework.plugin.api.util.executableId
2222
import org.utbot.framework.plugin.api.util.id
2323
import org.utbot.framework.plugin.api.util.jClass
24+
import org.utbot.framework.plugin.api.util.method
2425
import org.utbot.framework.plugin.services.JdkInfo
2526
import org.utbot.framework.process.generated.*
2627
import org.utbot.framework.util.ConflictTriggers
@@ -36,7 +37,7 @@ import org.utbot.summary.summarize
3637
import java.io.File
3738
import java.net.URLClassLoader
3839
import java.nio.file.Paths
39-
import kotlin.reflect.full.functions
40+
import kotlin.reflect.jvm.kotlinFunction
4041
import kotlin.time.Duration.Companion.seconds
4142

4243
private val messageFromMainTimeoutMillis = 120.seconds
@@ -158,8 +159,8 @@ private fun EngineProcessModel.setup(
158159
synchronizer.measureExecutionForTermination(findMethodsInClassMatchingSelected) { params ->
159160
val classId = kryoHelper.readObject<ClassId>(params.classId)
160161
val selectedSignatures = params.signatures.map { Signature(it.name, it.parametersTypes) }
161-
FindMethodsInClassMatchingSelectedResult(kryoHelper.writeObject(classId.jClass.kotlin.allNestedClasses.flatMap { clazz ->
162-
clazz.functions.sortedWith(compareBy { selectedSignatures.indexOf(it.signature()) })
162+
FindMethodsInClassMatchingSelectedResult(kryoHelper.writeObject(classId.jClass.allNestedClasses.flatMap { clazz ->
163+
clazz.id.allMethods.mapNotNull { it.method.kotlinFunction }.sortedWith(compareBy { selectedSignatures.indexOf(it.signature()) })
163164
.filter { it.signature().normalized() in selectedSignatures }
164165
.map { it.executableId }
165166
}))
@@ -168,7 +169,7 @@ private fun EngineProcessModel.setup(
168169
val classId = kryoHelper.readObject<ClassId>(params.classId)
169170
val bySignature = kryoHelper.readObject<Map<Signature, List<String>>>(params.bySignature)
170171
FindMethodParamNamesResult(kryoHelper.writeObject(
171-
classId.jClass.kotlin.allNestedClasses.flatMap { it.functions }
172+
classId.jClass.allNestedClasses.flatMap { clazz -> clazz.id.allMethods.mapNotNull { it.method.kotlinFunction } }
172173
.mapNotNull { method -> bySignature[method.signature()]?.let { params -> method.executableId to params } }
173174
.toMap()
174175
))

utbot-intellij/src/main/kotlin/org/utbot/intellij/plugin/process/EngineProcess.kt

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,16 @@ class EngineProcess(parent: Lifetime, val project: Project) {
238238
}
239239

240240
private fun MemberInfo.paramNames(): List<String> =
241-
(this.member as PsiMethod).parameterList.parameters.map { it.name }
241+
(this.member as PsiMethod).parameterList.parameters.map {
242+
if (it.name.startsWith("\$this"))
243+
// If member is Kotlin extension function, name of first argument isn't good for further usage,
244+
// so we better choose name based on type of receiver.
245+
//
246+
// There seems no API to check whether parameter is an extension receiver by PSI
247+
it.type.presentableText
248+
else
249+
it.name
250+
}
242251

243252
fun generate(
244253
mockInstalled: Boolean,

utbot-intellij/src/main/kotlin/org/utbot/intellij/plugin/ui/actions/GenerateTestsAction.kt

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import com.intellij.openapi.vfs.VirtualFile
1818
import com.intellij.psi.*
1919
import com.intellij.psi.util.PsiTreeUtil
2020
import com.intellij.refactoring.util.classMembers.MemberInfo
21+
import org.jetbrains.kotlin.asJava.findFacadeClass
2122
import org.jetbrains.kotlin.idea.core.getPackage
2223
import org.jetbrains.kotlin.idea.core.util.toPsiDirectory
2324
import org.jetbrains.kotlin.idea.core.util.toPsiFile
@@ -26,6 +27,7 @@ import org.utbot.intellij.plugin.util.extractFirstLevelMembers
2627
import org.utbot.intellij.plugin.util.isVisible
2728
import java.util.*
2829
import org.jetbrains.kotlin.j2k.getContainingClass
30+
import org.jetbrains.kotlin.psi.KtFile
2931
import org.jetbrains.kotlin.utils.addIfNotNull
3032
import org.utbot.framework.plugin.api.util.LockFile
3133
import org.utbot.intellij.plugin.models.packageName
@@ -218,7 +220,7 @@ class GenerateTestsAction : AnAction(), UpdateInBackground {
218220
}
219221

220222
private fun getAllClasses(directory: PsiDirectory): Set<PsiClass> {
221-
val allClasses = directory.files.flatMap { getClassesFromFile(it) }.toMutableSet()
223+
val allClasses = directory.files.flatMap { PsiElementHandler.makePsiElementHandler(it).getClassesFromFile(it) }.toMutableSet()
222224
for (subDir in directory.subdirectories) allClasses += getAllClasses(subDir)
223225
return allClasses
224226
}
@@ -231,15 +233,10 @@ class GenerateTestsAction : AnAction(), UpdateInBackground {
231233
if (!dirsArePackages) {
232234
return emptySet()
233235
}
234-
val allClasses = psiFiles.flatMap { getClassesFromFile(it) }.toMutableSet()
236+
val allClasses = psiFiles.flatMap { PsiElementHandler.makePsiElementHandler(it).getClassesFromFile(it) }.toMutableSet()
237+
allClasses.addAll(psiFiles.mapNotNull { (it as? KtFile)?.findFacadeClass() })
235238
for (psiDir in psiDirectories) allClasses += getAllClasses(psiDir)
236239

237240
return allClasses
238241
}
239-
240-
private fun getClassesFromFile(psiFile: PsiFile): List<PsiClass> {
241-
val psiElementHandler = PsiElementHandler.makePsiElementHandler(psiFile)
242-
return PsiTreeUtil.getChildrenOfTypeAsList(psiFile, psiElementHandler.classClass)
243-
.map { psiElementHandler.toPsi(it, PsiClass::class.java) }
244-
}
245242
}

utbot-intellij/src/main/kotlin/org/utbot/intellij/plugin/ui/utils/KotlinPsiElementHandler.kt

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@ package org.utbot.intellij.plugin.ui.utils
22

33
import com.intellij.psi.PsiClass
44
import com.intellij.psi.PsiElement
5+
import com.intellij.psi.PsiFile
6+
import com.intellij.psi.util.findParentOfType
7+
import org.jetbrains.kotlin.asJava.findFacadeClass
58
import org.jetbrains.kotlin.idea.testIntegration.KotlinCreateTestIntention
69
import org.jetbrains.kotlin.psi.KtClass
710
import org.jetbrains.kotlin.psi.KtClassOrObject
@@ -24,13 +27,27 @@ class KotlinPsiElementHandler(
2427
return element.toUElement()?.javaPsi as? T ?: error("Could not cast $element to $clazz")
2528
}
2629

27-
override fun isCreateTestActionAvailable(element: PsiElement): Boolean =
28-
getTarget(element)?.let { KotlinCreateTestIntention().applicabilityRange(it) != null } ?: false
30+
override fun getClassesFromFile(psiFile: PsiFile): List<PsiClass> {
31+
return listOfNotNull((psiFile as? KtFile)?.findFacadeClass()) + super.getClassesFromFile(psiFile)
32+
}
33+
34+
override fun isCreateTestActionAvailable(element: PsiElement): Boolean {
35+
getTarget(element)?.let {
36+
return KotlinCreateTestIntention().applicabilityRange(it) != null
37+
}
38+
return (element.containingFile as? KtFile)?.findFacadeClass() != null
39+
}
2940

3041
private fun getTarget(element: PsiElement?): KtNamedDeclaration? =
3142
element?.parentsWithSelf
3243
?.firstOrNull { it is KtClassOrObject || it is KtNamedDeclaration && it.parent is KtFile } as? KtNamedDeclaration
3344

34-
override fun containingClass(element: PsiElement): PsiClass? =
35-
element.parentsWithSelf.firstOrNull { it is KtClassOrObject }?.let { toPsi(it, PsiClass::class.java) }
45+
override fun containingClass(element: PsiElement): PsiClass? {
46+
element.findParentOfType<KtClassOrObject>(strict=false)?.let {
47+
return toPsi(it, PsiClass::class.java)
48+
}
49+
return element.findParentOfType<KtFile>(strict=false)?.findFacadeClass()?.let {
50+
toPsi(it, PsiClass::class.java)
51+
}
52+
}
3653
}

0 commit comments

Comments
 (0)