diff --git a/utbot-framework/src/main/kotlin/org/utbot/framework/codegen/domain/models/CgElement.kt b/utbot-framework/src/main/kotlin/org/utbot/framework/codegen/domain/models/CgElement.kt index 4144455d9f..2d0d82a431 100644 --- a/utbot-framework/src/main/kotlin/org/utbot/framework/codegen/domain/models/CgElement.kt +++ b/utbot-framework/src/main/kotlin/org/utbot/framework/codegen/domain/models/CgElement.kt @@ -888,7 +888,8 @@ class CgIfStatement( data class CgSwitchCaseLabel( val label: CgLiteral? = null, // have to be compile time constant (null for default label) - val statements: MutableList + val statements: List, + val addBreakStatementToEnd: Boolean = true // do not set this field to "true" value if you manually added "break" to statements ) : CgStatement data class CgSwitchCase( diff --git a/utbot-framework/src/main/kotlin/org/utbot/framework/codegen/renderer/CgJavaRenderer.kt b/utbot-framework/src/main/kotlin/org/utbot/framework/codegen/renderer/CgJavaRenderer.kt index b178b221ca..456e00b373 100644 --- a/utbot-framework/src/main/kotlin/org/utbot/framework/codegen/renderer/CgJavaRenderer.kt +++ b/utbot-framework/src/main/kotlin/org/utbot/framework/codegen/renderer/CgJavaRenderer.kt @@ -314,8 +314,10 @@ internal class CgJavaRenderer(context: CgRendererContext, printer: CgPrinter = C for (statement in element.statements) { statement.accept(this) } - // break statement in the end - CgBreakStatement.accept(this) + + if (element.addBreakStatementToEnd) { + CgBreakStatement.accept(this) + } } } diff --git a/utbot-framework/src/main/kotlin/org/utbot/framework/codegen/services/framework/MockFrameworkManager.kt b/utbot-framework/src/main/kotlin/org/utbot/framework/codegen/services/framework/MockFrameworkManager.kt index fb332a8a6d..2ea9dd8ada 100644 --- a/utbot-framework/src/main/kotlin/org/utbot/framework/codegen/services/framework/MockFrameworkManager.kt +++ b/utbot-framework/src/main/kotlin/org/utbot/framework/codegen/services/framework/MockFrameworkManager.kt @@ -22,6 +22,7 @@ import org.utbot.framework.codegen.domain.context.CgContext import org.utbot.framework.codegen.domain.context.CgContextOwner import org.utbot.framework.codegen.domain.models.CgAnonymousFunction import org.utbot.framework.codegen.domain.models.CgAssignment +import org.utbot.framework.codegen.domain.models.CgBreakStatement import org.utbot.framework.codegen.domain.models.CgConstructorCall import org.utbot.framework.codegen.domain.models.CgDeclaration import org.utbot.framework.codegen.domain.models.CgExecutableCall @@ -226,7 +227,6 @@ private class MockitoStaticMocker(context: CgContext, private val mocker: Object nameGenerator.variableName(MOCK_CLASS_COUNTER_NAME), CgConstructorCall(ConstructorId(atomicIntegerClassId, emptyList()), emptyList()) ) - +mockClassCounter val mocksExecutablesAnswers = mock .instances @@ -242,13 +242,19 @@ private class MockitoStaticMocker(context: CgContext, private val mocker: Object mocksExecutablesAnswers, mockClassCounter.variable ) + + if (mockConstructionInitializer.isMockClassCounterRequired) { + // We should insert the counter declaration only if we use this counter, for better readability. + +mockClassCounter + } + val mockedConstructionDeclaration = CgDeclaration( MockitoStaticMocking.mockedConstructionClassId, nameGenerator.variableName(MOCKED_CONSTRUCTION_NAME), - mockConstructionInitializer + mockConstructionInitializer.mockConstructionCall ) resources += mockedConstructionDeclaration - +CgAssignment(mockedConstructionDeclaration.variable, mockConstructionInitializer) + +CgAssignment(mockedConstructionDeclaration.variable, mockConstructionInitializer.mockConstructionCall) mockedStaticConstructions += classId } @@ -317,9 +323,9 @@ private class MockitoStaticMocker(context: CgContext, private val mocker: Object private fun mockConstruction( clazz: CgExpression, classId: ClassId, - mocksWhenAnswers: List>>, + mocksWhenAnswers: List>>, mockClassCounter: CgVariable - ): CgMethodCall { + ): MockConstructionBlock { val mockParameter = variableConstructor.declareParameter( classId, nameGenerator.variableName(classId.simpleName, isMock = true) @@ -333,6 +339,8 @@ private class MockitoStaticMocker(context: CgContext, private val mocker: Object for ((index, mockWhenAnswers) in mocksWhenAnswers.withIndex()) { val statements = mutableListOf() for ((executable, values) in mockWhenAnswers) { + // For now, all constructors are considered like void methods, but it is proposed to be changed + // for better constructors testing. if (executable.returnType == voidClassId) continue when (executable) { @@ -343,7 +351,7 @@ private class MockitoStaticMocker(context: CgContext, private val mocker: Object mocker.`when`(mockParameter[executable](*matchers))[thenReturnMethodId](*results) ) } - else -> error("Expected MethodId but got ConstructorId $executable") + is ConstructorId -> error("Expected MethodId but got ConstructorId $executable") } } @@ -352,15 +360,36 @@ private class MockitoStaticMocker(context: CgContext, private val mocker: Object val switchCase = CgSwitchCase(mockClassCounter[atomicIntegerGet](), caseLabels) + // If all switch-case labels are empty, + // it means we do not need this switch and mock counter itself at all. + val mockConstructionBody = if (caseLabels.map { it.statements }.all { it.isEmpty() }) { + emptyList() + } else { + listOf(switchCase, CgStatementExecutableCall(mockClassCounter[atomicIntegerGetAndIncrement]())) + } + val answersBlock = CgAnonymousFunction( voidClassId, listOf(mockParameter, contextParameter).map { CgParameterDeclaration(it, isVararg = false) }, - listOf(switchCase, CgStatementExecutableCall(mockClassCounter[atomicIntegerGetAndIncrement]())) + mockConstructionBody ) - return mockitoClassId[MockitoStaticMocking.mockConstructionMethodId](clazz, answersBlock) + return MockConstructionBlock( + mockitoClassId[MockitoStaticMocking.mockConstructionMethodId](clazz, answersBlock), + mockConstructionBody.isNotEmpty() + ) } + /** + * Represents a body for invocation of the [MockitoStaticMocking.mockConstructionMethodId] method and information + * whether we need to use a counter for different mocking invocations + * (i.e., on each mocking we expect possibly different results). + */ + private data class MockConstructionBlock( + val mockConstructionCall: CgMethodCall, + val isMockClassCounterRequired: Boolean + ) + private fun mockStatic(clazz: CgExpression): CgMethodCall = mockitoClassId[MockitoStaticMocking.mockStaticMethodId](clazz)