Skip to content

Commit 439a0d8

Browse files
authored
Move the Torch wrapper to the separate module (#821)
* Added an initial solution * Just renamed some weird names * Removed unused settings * Added minimal readme.md and fix some typos * Fixed some names * Fixed some names * Fixed gradle * Handle engine not found exception from DJL * Handle engine not found exception from DJL * Handle engine not found exception from DJL * Handle engine not found exception from DJL * Changed the setter * Removed the comment * Changed the visibility level for some dependencies
1 parent 779e053 commit 439a0d8

File tree

25 files changed

+317
-116
lines changed

25 files changed

+317
-116
lines changed

settings.gradle

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ include 'utbot-sample'
1818
include 'utbot-fuzzers'
1919
include 'utbot-junit-contest'
2020
include 'utbot-analytics'
21+
include 'utbot-analytics-torch'
2122
include 'utbot-cli'
2223
include 'utbot-api'
2324
include 'utbot-instrumentation'

utbot-analytics-torch/build.gradle

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
apply from: "${parent.projectDir}/gradle/include/jvm-project.gradle"
2+
3+
configurations {
4+
torchmodels
5+
}
6+
7+
def osName = System.getProperty('os.name').toLowerCase().split()[0]
8+
if (osName == "mac") osName = "macosx"
9+
String classifier = osName + "-x86_64"
10+
11+
evaluationDependsOn(':utbot-framework')
12+
compileTestJava.dependsOn tasks.getByPath(':utbot-framework:testClasses')
13+
14+
dependencies {
15+
api project(':utbot-analytics')
16+
testImplementation project(':utbot-sample')
17+
testImplementation group: 'junit', name: 'junit', version: junit4_version
18+
19+
implementation group: 'org.bytedeco', name: 'arpack-ng', version: "3.7.0-1.5.4", classifier: "$classifier"
20+
implementation group: 'org.bytedeco', name: 'openblas', version: "0.3.10-1.5.4", classifier: "$classifier"
21+
implementation group: 'org.bytedeco', name: 'javacpp', version: javacpp_version, classifier: "$classifier"
22+
implementation group: 'org.jsoup', name: 'jsoup', version: jsoup_version
23+
24+
implementation "ai.djl:api:$djl_api_version"
25+
implementation "ai.djl.pytorch:pytorch-engine:$djl_api_version"
26+
implementation "ai.djl.pytorch:pytorch-native-auto:$pytorch_native_version"
27+
28+
testImplementation project(':utbot-framework').sourceSets.test.output
29+
}
30+
31+
test {
32+
minHeapSize = "128m"
33+
maxHeapSize = "3072m"
34+
35+
jvmArgs '-XX:MaxHeapSize=3072m'
36+
37+
useJUnitPlatform() {
38+
excludeTags 'slow', 'IntegrationTest'
39+
}
40+
}
41+
42+
processResources {
43+
configurations.torchmodels.resolvedConfiguration.resolvedArtifacts.each { artifact ->
44+
from(zipTree(artifact.getFile())) {
45+
into "models"
46+
}
47+
}
48+
}
49+
50+
jar {
51+
dependsOn classes
52+
manifest {
53+
attributes 'Main-Class': 'org.utbot.QualityAnalysisKt'
54+
}
55+
56+
dependsOn configurations.runtimeClasspath
57+
from {
58+
configurations.runtimeClasspath.collect { it.isDirectory() ? it : zipTree(it) }
59+
}
60+
61+
duplicatesStrategy = DuplicatesStrategy.EXCLUDE
62+
zip64 = true
63+
}

utbot-analytics-torch/readme.md

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
To enable support of the `utbot-analytics-torch` models in `utbot-intellij` module the following steps should be made:
2+
3+
- change the row `api project(':utbot-analytics-torch')` to the `api project(':utbot-analytics-torch')` in the `build.gradle` file in the `utbot-intellij` module
4+
- change the `pathSelectorType` in the `UtSettings.kt` to the `PathSelectorType.TORCH_SELECTOR`
5+
- don't forget the put the Torch model in the path ruled by the setting `modelPath` in the `UtSettings.kt`
6+
7+
NOTE: for Windows you could obtain the error message related to the "engine not found problem" from DJL library during the Torch model initialization.
8+
The proposed solution from DJL authors includes the installation of the [Microsoft Visual C++ Redistributable.](https://docs.microsoft.com/en-us/cpp/windows/latest-supported-vc-redist?view=msvc-170)
9+
10+
But at this moment it doesn't work on Windows at all.
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
package org.utbot
2+
3+
import org.utbot.analytics.EngineAnalyticsContext
4+
import org.utbot.features.FeatureExtractorFactoryImpl
5+
import org.utbot.features.FeatureProcessorWithStatesRepetitionFactory
6+
import org.utbot.predictors.StateRewardPredictorWithTorchModelsSupportFactoryImpl
7+
8+
/**
9+
* The basic configuration of the utbot-analytics-torch module used in utbot-intellij and (as planned) in utbot-cli
10+
* to implement the hidden configuration initialization to avoid direct calls of this configuration and usage of utbot-analytics-torch imports.
11+
*
12+
* @see <a href="https://github.com/UnitTestBot/UTBotJava/issues/725">
13+
* Issue: Enable utbot-analytics module in utbot-intellij module</a>
14+
*/
15+
object AnalyticsTorchConfiguration {
16+
init {
17+
EngineAnalyticsContext.featureProcessorFactory = FeatureProcessorWithStatesRepetitionFactory()
18+
EngineAnalyticsContext.featureExtractorFactory = FeatureExtractorFactoryImpl()
19+
EngineAnalyticsContext.stateRewardPredictorFactory = StateRewardPredictorWithTorchModelsSupportFactoryImpl()
20+
}
21+
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
package org.utbot.predictors
2+
3+
import org.utbot.analytics.StateRewardPredictorFactory
4+
import org.utbot.framework.UtSettings
5+
6+
/**
7+
* Creates [StateRewardPredictor], by checking the [UtSettings] configuration.
8+
*/
9+
class StateRewardPredictorWithTorchModelsSupportFactoryImpl : StateRewardPredictorFactory {
10+
override operator fun invoke() = StateRewardPredictorTorch()
11+
}

utbot-analytics/src/main/kotlin/org/utbot/predictors/StateRewardPredictorTorch.kt renamed to utbot-analytics-torch/src/main/kotlin/org/utbot/predictors/StateRewardPredictorTorch.kt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,11 @@ import java.io.Closeable
1212
import java.nio.file.Paths
1313

1414
class StateRewardPredictorTorch : StateRewardPredictor, Closeable {
15-
val model: Model = Model.newInstance("model")
15+
val model: Model
1616

1717
init {
18-
model.load(Paths.get(UtSettings.rewardModelPath, "model.pt1"))
18+
model = Model.newInstance("model")
19+
model.load(Paths.get(UtSettings.modelPath, "model.pt1"))
1920
}
2021

2122
private val predictor: Predictor<List<Float>, Float> = model.newPredictor(object : Translator<List<Float>, Float> {
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
package org.utbot.predictors
2+
3+
import org.junit.jupiter.api.Assertions.assertEquals
4+
import org.junit.jupiter.api.Disabled
5+
import org.junit.jupiter.api.Test
6+
import org.utbot.analytics.StateRewardPredictor
7+
import org.utbot.testcheckers.withModelPath
8+
import kotlin.system.measureNanoTime
9+
10+
class NNStateRewardPredictorTest {
11+
@Test
12+
@Disabled("Just to see the performance of predictors")
13+
fun simpleTest() {
14+
withModelPath("src/test/resources") {
15+
val pred = StateRewardPredictorTorch()
16+
17+
val features = listOf(0.0, 0.0)
18+
19+
assertEquals(5.0, pred.predict(features))
20+
}
21+
}
22+
23+
@Disabled("Just to see the performance of predictors")
24+
@Test
25+
fun performanceTest() {
26+
val features = (1..13).map { 1.0 }.toList()
27+
withModelPath("models") {
28+
val averageTime = calcAverageTimeForModelPredict(::StateRewardPredictorTorch, 100, features)
29+
println(averageTime)
30+
}
31+
}
32+
33+
private fun calcAverageTimeForModelPredict(
34+
model: () -> StateRewardPredictor,
35+
iterations: Int,
36+
features: List<Double>
37+
): Double {
38+
val pred = model()
39+
40+
(1..iterations).map {
41+
pred.predict(features)
42+
}
43+
44+
return (1..iterations)
45+
.map { measureNanoTime { pred.predict(features) } }
46+
.average()
47+
}
48+
}
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
<?xml version="1.0" encoding="UTF-8"?>
2+
<Configuration>
3+
<Appenders>
4+
<RollingFile name="FrameworkAppender"
5+
fileName="logs/utbot.log"
6+
filePattern="logs/utbot-%d{MM-dd-yyyy-HH-mm-ss}.log.gz"
7+
ignoreExceptions="false">
8+
<PatternLayout pattern="%d{HH:mm:ss.SSS} | %-5level | %c{1} | %msg%n"/>
9+
<Policies>
10+
<OnStartupTriggeringPolicy/>
11+
<SizeBasedTriggeringPolicy size="100 MB"/>
12+
</Policies>
13+
</RollingFile>
14+
15+
<Console name="Console" target="SYSTEM_OUT">
16+
<PatternLayout pattern="%d{HH:mm:ss.SSS} | %-5level | %msg%n"/>
17+
</Console>
18+
</Appenders>
19+
<Loggers>
20+
<Logger name="smile" level="trace">
21+
<AppenderRef ref="Console"/>
22+
</Logger>
23+
24+
25+
<Logger name="org.utbot.models" level="trace"/>
26+
27+
<Logger name="org.utbot" level="debug">
28+
<AppenderRef ref="FrameworkAppender"/>
29+
</Logger>
30+
31+
<!-- uncomment to log solver check -->
32+
<!-- <Logger name="org.utbot.engine.pc" level="trace">-->
33+
<!-- <AppenderRef ref="Console"/>-->
34+
<!-- </Logger>-->
35+
36+
<Root level="info">
37+
<!-- <AppenderRef ref="Console"/>-->
38+
</Root>
39+
</Loggers>
40+
</Configuration>

utbot-analytics/build.gradle

Lines changed: 7 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,13 @@ evaluationDependsOn(':utbot-framework')
1212
compileTestJava.dependsOn tasks.getByPath(':utbot-framework:testClasses')
1313

1414
dependencies {
15-
implementation(project(":utbot-api"))
16-
implementation(project(":utbot-core"))
17-
implementation(project(":utbot-summary"))
18-
implementation(project(":utbot-framework-api"))
19-
implementation(project(":utbot-fuzzers"))
20-
implementation(project(":utbot-instrumentation"))
21-
implementation(project(":utbot-framework"))
15+
api(project(":utbot-api"))
16+
api(project(":utbot-core"))
17+
api(project(":utbot-summary"))
18+
api(project(":utbot-framework-api"))
19+
api(project(":utbot-fuzzers"))
20+
api(project(":utbot-instrumentation"))
21+
api(project(":utbot-framework"))
2222
testImplementation project(':utbot-sample')
2323
testImplementation group: 'junit', name: 'junit', version: junit4_version
2424

@@ -38,20 +38,12 @@ dependencies {
3838
implementation group: 'tech.tablesaw', name: 'tablesaw-jsplot', version: '0.38.2'
3939

4040
implementation group: 'org.apache.commons', name: 'commons-text', version: '1.9'
41-
4241
implementation group: 'com.github.javaparser', name: 'javaparser-core', version: '3.22.1'
4342

44-
implementation group: 'org.jsoup', name: 'jsoup', version: jsoup_version
45-
46-
implementation "ai.djl:api:$djl_api_version"
47-
implementation "ai.djl.pytorch:pytorch-engine:$djl_api_version"
48-
implementation "ai.djl.pytorch:pytorch-native-auto:$pytorch_native_version"
49-
5043
testImplementation project(':utbot-framework').sourceSets.test.output
5144
}
5245

5346
test {
54-
5547
minHeapSize = "128m"
5648
maxHeapSize = "3072m"
5749

utbot-analytics/src/main/kotlin/org/utbot/predictors/LinearStateRewardPredictor.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ private val logger = KotlinLogging.logger {}
1919
* Last weight is bias
2020
*/
2121
private fun loadWeights(path: String): Matrix {
22-
val weightsFile = File("${UtSettings.rewardModelPath}/${path}")
22+
val weightsFile = File("${UtSettings.modelPath}/${path}")
2323
lateinit var weightsArray: DoubleArray
2424

2525
try {

utbot-analytics/src/main/kotlin/org/utbot/predictors/NNJson.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ data class NNJson(
3333
}
3434

3535
internal fun loadModel(path: String): NNJson {
36-
val modelFile = Paths.get(UtSettings.rewardModelPath, path).toFile()
36+
val modelFile = Paths.get(UtSettings.modelPath, path).toFile()
3737
lateinit var nnJson: NNJson
3838

3939
try {

utbot-analytics/src/main/kotlin/org/utbot/predictors/StateRewardPredictorFactory.kt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ import org.utbot.framework.UtSettings
1010
class StateRewardPredictorFactoryImpl : StateRewardPredictorFactory {
1111
override operator fun invoke() = when (UtSettings.stateRewardPredictorType) {
1212
StateRewardPredictorType.BASE -> NNStateRewardPredictorBase()
13-
StateRewardPredictorType.TORCH -> StateRewardPredictorTorch()
1413
StateRewardPredictorType.LINEAR -> LinearStateRewardPredictor()
1514
}
1615
}

utbot-analytics/src/main/kotlin/org/utbot/predictors/scalers.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ data class StandardScaler(val mean: Matrix?, val variance: Matrix?)
1313

1414
internal fun loadScaler(path: String): StandardScaler =
1515
try {
16-
Paths.get(UtSettings.rewardModelPath, path).toFile().bufferedReader().use {
16+
Paths.get(UtSettings.modelPath, path).toFile().bufferedReader().use {
1717
val mean = it.readLine()?.splitByCommaIntoDoubleArray() ?: error("There is not mean in $path")
1818
val variance = it.readLine()?.splitByCommaIntoDoubleArray() ?: error("There is not variance in $path")
1919
StandardScaler(Matrix(mean), Matrix(variance))

utbot-analytics/src/test/kotlin/org/utbot/predictors/LinearStateRewardPredictorTest.kt

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@ import org.junit.jupiter.api.Test
55
import org.utbot.framework.PathSelectorType
66
import org.utbot.framework.UtSettings
77
import org.utbot.testcheckers.withPathSelectorType
8-
import org.utbot.testcheckers.withRewardModelPath
8+
import org.utbot.testcheckers.withModelPath
99

1010
class LinearStateRewardPredictorTest {
1111
@Test
1212
fun simpleTest() {
13-
withRewardModelPath("src/test/resources") {
13+
withModelPath("src/test/resources") {
1414
val pred = LinearStateRewardPredictor()
1515

1616
val features = listOf(
@@ -24,8 +24,8 @@ class LinearStateRewardPredictorTest {
2424

2525
@Test
2626
fun wrongFormatTest() {
27-
withRewardModelPath("src/test/resources") {
28-
withPathSelectorType(PathSelectorType.NN_REWARD_GUIDED_SELECTOR) {
27+
withModelPath("src/test/resources") {
28+
withPathSelectorType(PathSelectorType.ML_SELECTOR) {
2929
LinearStateRewardPredictor("wrong_format_linear.txt")
3030
assertEquals(PathSelectorType.INHERITORS_SELECTOR, UtSettings.pathSelectorType)
3131
}
@@ -34,7 +34,7 @@ class LinearStateRewardPredictorTest {
3434

3535
@Test
3636
fun simpleTestNotBatch() {
37-
withRewardModelPath("src/test/resources") {
37+
withModelPath("src/test/resources") {
3838
val pred = LinearStateRewardPredictor()
3939

4040
val features = listOf(2.0, 3.0)

0 commit comments

Comments
 (0)