diff --git a/utbot-framework/src/main/kotlin/org/utbot/framework/codegen/domain/models/builders/SpringTestClassModelBuilder.kt b/utbot-framework/src/main/kotlin/org/utbot/framework/codegen/domain/models/builders/SpringTestClassModelBuilder.kt index edf3ea6593..a890671e6a 100644 --- a/utbot-framework/src/main/kotlin/org/utbot/framework/codegen/domain/models/builders/SpringTestClassModelBuilder.kt +++ b/utbot-framework/src/main/kotlin/org/utbot/framework/codegen/domain/models/builders/SpringTestClassModelBuilder.kt @@ -19,6 +19,8 @@ import org.utbot.framework.plugin.api.UtNullModel import org.utbot.framework.plugin.api.UtPrimitiveModel import org.utbot.framework.plugin.api.UtVoidModel import org.utbot.framework.plugin.api.isMockModel +import org.utbot.framework.plugin.api.UtStatementCallModel +import org.utbot.framework.plugin.api.UtDirectSetFieldModel import org.utbot.framework.plugin.api.util.SpringModelUtils.isAutowiredFromContext import org.utbot.framework.plugin.api.canBeSpied @@ -59,7 +61,7 @@ class SpringTestClassModelBuilder(val context: CgContext) : (execution.stateBefore.parameters + execution.stateBefore.thisInstance) .filterNotNull() - .forEach { model -> stateBeforeDependentModels += collectDependentModels(model) } + .forEach { model -> stateBeforeDependentModels += collectAutowiredModels(model) } } } } @@ -88,11 +90,24 @@ class SpringTestClassModelBuilder(val context: CgContext) : ) } + private fun collectAutowiredModels(model: UtModel): Set { + val allDependentModels = mutableSetOf() + + collectRecursively(model, allDependentModels) + + return allDependentModels + } + + private fun collectRecursively(model: UtModel, allDependentModels: MutableSet){ + if(!allDependentModels.add(model.wrap())){ + return + } + collectDependentModels(model).forEach { collectRecursively(it.model, allDependentModels) } + } + private fun collectDependentModels(model: UtModel): Set { val dependentModels = mutableSetOf() - dependentModels.add(model.wrap()) - when (model) { is UtNullModel, is UtPrimitiveModel, @@ -118,6 +133,16 @@ class SpringTestClassModelBuilder(val context: CgContext) : is UtAssembleModel -> { model.instantiationCall.instance?.let { dependentModels.add(it.wrap()) } model.instantiationCall.params.forEach { dependentModels.add(it.wrap()) } + + if(model.isAutowiredFromContext()) { + model.modificationsChain.forEach { stmt -> + stmt.instance?.let { dependentModels.add(it.wrap()) } + when (stmt) { + is UtStatementCallModel -> stmt.params.forEach { dependentModels.add(it.wrap()) } + is UtDirectSetFieldModel -> dependentModels.add(stmt.fieldModel.wrap()) + } + } + } } }