Skip to content

Commit 0728d33

Browse files
committed
Improve the suggested fix
1 parent af37218 commit 0728d33

File tree

2 files changed

+11
-12
lines changed

2 files changed

+11
-12
lines changed

utbot-framework/src/main/kotlin/org/utbot/framework/codegen/model/constructor/tree/CgMethodConstructor.kt

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1258,15 +1258,12 @@ internal class CgMethodConstructor(val context: CgContext) : CgContextOwner by c
12581258
val method = currentExecutable as MethodId
12591259
val containsFailureExecution = containsFailureExecution(testSet)
12601260

1261-
val expectedResultClassIdType = when {
1262-
method.returnType.isPrivate && !method.returnType.isPublic -> objectClassId
1263-
else -> method.returnType
1264-
}
1261+
val expectedResultClassId = wrapTypeIfRequired(method.returnType)
12651262

1266-
if (expectedResultClassIdType != voidClassId) {
1263+
if (expectedResultClassId != voidClassId) {
12671264
testArguments += CgParameterDeclaration(
1268-
expectedResultVarName, resultClassId(expectedResultClassIdType),
1269-
isReferenceType = containsFailureExecution || !expectedResultClassIdType.isPrimitive
1265+
expectedResultVarName, resultClassId(expectedResultClassId),
1266+
isReferenceType = containsFailureExecution || !expectedResultClassId.isPrimitive
12701267
)
12711268
}
12721269
if (containsFailureExecution) {

utbot-framework/src/main/kotlin/org/utbot/framework/codegen/model/constructor/util/CgStatementConstructor.kt

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,8 @@ interface CgStatementConstructor {
151151
fun declareVariable(type: ClassId, name: String): CgVariable
152152

153153
fun guardExpression(baseType: ClassId, expression: CgExpression): ExpressionWithType
154+
155+
fun wrapTypeIfRequired(baseType: ClassId): ClassId
154156
}
155157

156158
internal class CgStatementConstructorImpl(context: CgContext) :
@@ -385,6 +387,9 @@ internal class CgStatementConstructorImpl(context: CgContext) :
385387
updateVariableScope(it)
386388
}
387389

390+
override fun wrapTypeIfRequired(baseType: ClassId): ClassId =
391+
if (baseType.isAccessibleFrom(testClassPackageName)) baseType else objectClassId
392+
388393
// utils
389394

390395
private fun classRefOrNull(type: ClassId, expr: CgExpression): ClassId? {
@@ -444,11 +449,8 @@ internal class CgStatementConstructorImpl(context: CgContext) :
444449
if (call.executableId != mockMethodId) return guardExpression(baseType, call)
445450

446451
// call represents a call to mock() method
447-
return if (baseType.isAccessibleFrom(testClassPackageName)) {
448-
ExpressionWithType(baseType, call)
449-
} else {
450-
ExpressionWithType(objectClassId, call)
451-
}
452+
val wrappedType = wrapTypeIfRequired(baseType)
453+
return ExpressionWithType(wrappedType, call)
452454
}
453455

454456
override fun guardExpression(baseType: ClassId, expression: CgExpression): ExpressionWithType {

0 commit comments

Comments
 (0)