diff --git a/utbot-framework/src/main/kotlin/org/utbot/framework/codegen/model/constructor/tree/CgMethodConstructor.kt b/utbot-framework/src/main/kotlin/org/utbot/framework/codegen/model/constructor/tree/CgMethodConstructor.kt index 5eb26f42d7..c8571ed1dd 100644 --- a/utbot-framework/src/main/kotlin/org/utbot/framework/codegen/model/constructor/tree/CgMethodConstructor.kt +++ b/utbot-framework/src/main/kotlin/org/utbot/framework/codegen/model/constructor/tree/CgMethodConstructor.kt @@ -231,8 +231,10 @@ internal class CgMethodConstructor(val context: CgContext) : CgContextOwner by c val fieldAccessible = field.isAccessibleFrom(testClassPackageName) // prevValue is nullable if not accessible because of getStaticFieldValue(..) : Any? - val prevValue = newVar(CgClassId(field.type, isNullable = !fieldAccessible), - "prev${field.name.capitalize()}") { + val prevValue = newVar( + CgClassId(field.type, isNullable = !fieldAccessible), + "prev${field.name.capitalize()}" + ) { if (fieldAccessible) { declaringClass[field] } else { @@ -1198,7 +1200,8 @@ internal class CgMethodConstructor(val context: CgContext) : CgContextOwner by c it.variableName, // guard initializer to reuse typecast creation logic initializer = guardExpression(varType, nullLiteral()).expression, - isMutable = true) + isMutable = true, + ) } +tryWithMocksFinallyClosing } @@ -1253,10 +1256,13 @@ internal class CgMethodConstructor(val context: CgContext) : CgContextOwner by c } val method = currentExecutable as MethodId val containsFailureExecution = containsFailureExecution(testSet) - if (method.returnType != voidClassId) { + + val expectedResultClassId = wrapTypeIfRequired(method.returnType) + + if (expectedResultClassId != voidClassId) { testArguments += CgParameterDeclaration( - expectedResultVarName, resultClassId(method.returnType), - isReferenceType = containsFailureExecution || !method.returnType.isPrimitive + expectedResultVarName, resultClassId(expectedResultClassId), + isReferenceType = containsFailureExecution || !expectedResultClassId.isPrimitive ) } if (containsFailureExecution) { diff --git a/utbot-framework/src/main/kotlin/org/utbot/framework/codegen/model/constructor/util/CgStatementConstructor.kt b/utbot-framework/src/main/kotlin/org/utbot/framework/codegen/model/constructor/util/CgStatementConstructor.kt index 97a94e313b..ba1f8e1faa 100644 --- a/utbot-framework/src/main/kotlin/org/utbot/framework/codegen/model/constructor/util/CgStatementConstructor.kt +++ b/utbot-framework/src/main/kotlin/org/utbot/framework/codegen/model/constructor/util/CgStatementConstructor.kt @@ -151,6 +151,8 @@ interface CgStatementConstructor { fun declareVariable(type: ClassId, name: String): CgVariable fun guardExpression(baseType: ClassId, expression: CgExpression): ExpressionWithType + + fun wrapTypeIfRequired(baseType: ClassId): ClassId } internal class CgStatementConstructorImpl(context: CgContext) : @@ -385,6 +387,9 @@ internal class CgStatementConstructorImpl(context: CgContext) : updateVariableScope(it) } + override fun wrapTypeIfRequired(baseType: ClassId): ClassId = + if (baseType.isAccessibleFrom(testClassPackageName)) baseType else objectClassId + // utils private fun classRefOrNull(type: ClassId, expr: CgExpression): ClassId? { @@ -444,11 +449,8 @@ internal class CgStatementConstructorImpl(context: CgContext) : if (call.executableId != mockMethodId) return guardExpression(baseType, call) // call represents a call to mock() method - return if (baseType.isAccessibleFrom(testClassPackageName)) { - ExpressionWithType(baseType, call) - } else { - ExpressionWithType(objectClassId, call) - } + val wrappedType = wrapTypeIfRequired(baseType) + return ExpressionWithType(wrappedType, call) } override fun guardExpression(baseType: ClassId, expression: CgExpression): ExpressionWithType {