-
Notifications
You must be signed in to change notification settings - Fork 46
Move the Torch wrapper to the separate module #821
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
19 commits
Select commit
Hold shift + click to select a range
8248346
Added an initial solution
amandelpie ceda18a
Just renamed some weird names
amandelpie d413892
Removed unused settings
amandelpie f2c89df
Added minimal readme.md and fix some typos
amandelpie 24e92af
Merge branch 'main' into amandelpie/extract-torch-wrapper
amandelpie ef4d4be
Fixed some names
amandelpie 2b0a2b8
Fixed some names
amandelpie 5af2f41
Fixed gradle
amandelpie 22b4904
Handle engine not found exception from DJL
amandelpie 979b9f6
Handle engine not found exception from DJL
amandelpie ebfe04e
Handle engine not found exception from DJL
amandelpie e53b367
Handle engine not found exception from DJL
amandelpie a7d7e25
Merge branch 'main' into amandelpie/extract-torch-wrapper
amandelpie 5e9e145
Changed the setter
amandelpie c2c1d5b
Removed the comment
amandelpie ffd578b
Changed the visibility level for some dependencies
amandelpie a861916
Turn off utbot-analytics-torch module
amandelpie 44dc9d7
Turn off utbot-analytics-torch module
amandelpie 3c27cf0
Merge branch 'main' into amandelpie/extract-torch-wrapper
amandelpie File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
apply from: "${parent.projectDir}/gradle/include/jvm-project.gradle" | ||
|
||
configurations { | ||
torchmodels | ||
} | ||
|
||
def osName = System.getProperty('os.name').toLowerCase().split()[0] | ||
if (osName == "mac") osName = "macosx" | ||
String classifier = osName + "-x86_64" | ||
|
||
evaluationDependsOn(':utbot-framework') | ||
compileTestJava.dependsOn tasks.getByPath(':utbot-framework:testClasses') | ||
|
||
dependencies { | ||
api project(':utbot-analytics') | ||
testImplementation project(':utbot-sample') | ||
testImplementation group: 'junit', name: 'junit', version: junit4_version | ||
|
||
implementation group: 'org.bytedeco', name: 'arpack-ng', version: "3.7.0-1.5.4", classifier: "$classifier" | ||
amandelpie marked this conversation as resolved.
Show resolved
Hide resolved
|
||
implementation group: 'org.bytedeco', name: 'openblas', version: "0.3.10-1.5.4", classifier: "$classifier" | ||
implementation group: 'org.bytedeco', name: 'javacpp', version: javacpp_version, classifier: "$classifier" | ||
implementation group: 'org.jsoup', name: 'jsoup', version: jsoup_version | ||
|
||
implementation "ai.djl:api:$djl_api_version" | ||
implementation "ai.djl.pytorch:pytorch-engine:$djl_api_version" | ||
implementation "ai.djl.pytorch:pytorch-native-auto:$pytorch_native_version" | ||
|
||
testImplementation project(':utbot-framework').sourceSets.test.output | ||
} | ||
|
||
test { | ||
minHeapSize = "128m" | ||
maxHeapSize = "3072m" | ||
|
||
jvmArgs '-XX:MaxHeapSize=3072m' | ||
|
||
useJUnitPlatform() { | ||
excludeTags 'slow', 'IntegrationTest' | ||
} | ||
} | ||
|
||
processResources { | ||
configurations.torchmodels.resolvedConfiguration.resolvedArtifacts.each { artifact -> | ||
from(zipTree(artifact.getFile())) { | ||
into "models" | ||
} | ||
} | ||
} | ||
|
||
jar { | ||
dependsOn classes | ||
manifest { | ||
attributes 'Main-Class': 'org.utbot.QualityAnalysisKt' | ||
} | ||
|
||
dependsOn configurations.runtimeClasspath | ||
from { | ||
configurations.runtimeClasspath.collect { it.isDirectory() ? it : zipTree(it) } | ||
} | ||
|
||
duplicatesStrategy = DuplicatesStrategy.EXCLUDE | ||
zip64 = true | ||
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
To enable support of the `utbot-analytics-torch` models in `utbot-intellij` module the following steps should be made: | ||
|
||
- change the row `api project(':utbot-analytics-torch')` to the `api project(':utbot-analytics-torch')` in the `build.gradle` file in the `utbot-intellij` module | ||
amandelpie marked this conversation as resolved.
Show resolved
Hide resolved
|
||
- change the `pathSelectorType` in the `UtSettings.kt` to the `PathSelectorType.TORCH_SELECTOR` | ||
- don't forget the put the Torch model in the path ruled by the setting `modelPath` in the `UtSettings.kt` | ||
|
||
NOTE: for Windows you could obtain the error message related to the "engine not found problem" from DJL library during the Torch model initialization. | ||
The proposed solution from DJL authors includes the installation of the [Microsoft Visual C++ Redistributable.](https://docs.microsoft.com/en-us/cpp/windows/latest-supported-vc-redist?view=msvc-170) | ||
|
||
But at this moment it doesn't work on Windows at all. |
21 changes: 21 additions & 0 deletions
21
utbot-analytics-torch/src/main/kotlin/org/utbot/AnalyticsTorchConfiguration.kt
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
package org.utbot | ||
|
||
import org.utbot.analytics.EngineAnalyticsContext | ||
import org.utbot.features.FeatureExtractorFactoryImpl | ||
import org.utbot.features.FeatureProcessorWithStatesRepetitionFactory | ||
import org.utbot.predictors.StateRewardPredictorWithTorchModelsSupportFactoryImpl | ||
|
||
/** | ||
* The basic configuration of the utbot-analytics-torch module used in utbot-intellij and (as planned) in utbot-cli | ||
* to implement the hidden configuration initialization to avoid direct calls of this configuration and usage of utbot-analytics-torch imports. | ||
* | ||
* @see <a href="https://github.com/UnitTestBot/UTBotJava/issues/725"> | ||
* Issue: Enable utbot-analytics module in utbot-intellij module</a> | ||
*/ | ||
object AnalyticsTorchConfiguration { | ||
init { | ||
EngineAnalyticsContext.featureProcessorFactory = FeatureProcessorWithStatesRepetitionFactory() | ||
EngineAnalyticsContext.featureExtractorFactory = FeatureExtractorFactoryImpl() | ||
EngineAnalyticsContext.stateRewardPredictorFactory = StateRewardPredictorWithTorchModelsSupportFactoryImpl() | ||
} | ||
} |
11 changes: 11 additions & 0 deletions
11
utbot-analytics-torch/src/main/kotlin/org/utbot/predictors/StateRewardPredictorFactory.kt
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
package org.utbot.predictors | ||
|
||
import org.utbot.analytics.StateRewardPredictorFactory | ||
import org.utbot.framework.UtSettings | ||
|
||
/** | ||
* Creates [StateRewardPredictor], by checking the [UtSettings] configuration. | ||
*/ | ||
class StateRewardPredictorWithTorchModelsSupportFactoryImpl : StateRewardPredictorFactory { | ||
override operator fun invoke() = StateRewardPredictorTorch() | ||
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
48 changes: 48 additions & 0 deletions
48
utbot-analytics-torch/src/test/kotlin/org/utbot/predictors/NNStateRewardPredictorTest.kt
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
package org.utbot.predictors | ||
|
||
import org.junit.jupiter.api.Assertions.assertEquals | ||
import org.junit.jupiter.api.Disabled | ||
import org.junit.jupiter.api.Test | ||
import org.utbot.analytics.StateRewardPredictor | ||
import org.utbot.testcheckers.withModelPath | ||
import kotlin.system.measureNanoTime | ||
|
||
class NNStateRewardPredictorTest { | ||
amandelpie marked this conversation as resolved.
Show resolved
Hide resolved
|
||
@Test | ||
@Disabled("Just to see the performance of predictors") | ||
fun simpleTest() { | ||
withModelPath("src/test/resources") { | ||
val pred = StateRewardPredictorTorch() | ||
|
||
val features = listOf(0.0, 0.0) | ||
|
||
assertEquals(5.0, pred.predict(features)) | ||
} | ||
} | ||
|
||
@Disabled("Just to see the performance of predictors") | ||
@Test | ||
fun performanceTest() { | ||
val features = (1..13).map { 1.0 }.toList() | ||
withModelPath("models") { | ||
val averageTime = calcAverageTimeForModelPredict(::StateRewardPredictorTorch, 100, features) | ||
println(averageTime) | ||
} | ||
} | ||
|
||
private fun calcAverageTimeForModelPredict( | ||
model: () -> StateRewardPredictor, | ||
iterations: Int, | ||
features: List<Double> | ||
): Double { | ||
val pred = model() | ||
|
||
(1..iterations).map { | ||
pred.predict(features) | ||
} | ||
|
||
return (1..iterations) | ||
.map { measureNanoTime { pred.predict(features) } } | ||
.average() | ||
} | ||
} |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
<?xml version="1.0" encoding="UTF-8"?> | ||
<Configuration> | ||
<Appenders> | ||
<RollingFile name="FrameworkAppender" | ||
fileName="logs/utbot.log" | ||
filePattern="logs/utbot-%d{MM-dd-yyyy-HH-mm-ss}.log.gz" | ||
ignoreExceptions="false"> | ||
<PatternLayout pattern="%d{HH:mm:ss.SSS} | %-5level | %c{1} | %msg%n"/> | ||
<Policies> | ||
<OnStartupTriggeringPolicy/> | ||
<SizeBasedTriggeringPolicy size="100 MB"/> | ||
</Policies> | ||
</RollingFile> | ||
|
||
<Console name="Console" target="SYSTEM_OUT"> | ||
<PatternLayout pattern="%d{HH:mm:ss.SSS} | %-5level | %msg%n"/> | ||
</Console> | ||
</Appenders> | ||
<Loggers> | ||
<Logger name="smile" level="trace"> | ||
<AppenderRef ref="Console"/> | ||
</Logger> | ||
|
||
|
||
<Logger name="org.utbot.models" level="trace"/> | ||
|
||
<Logger name="org.utbot" level="debug"> | ||
<AppenderRef ref="FrameworkAppender"/> | ||
</Logger> | ||
|
||
<!-- uncomment to log solver check --> | ||
<!-- <Logger name="org.utbot.engine.pc" level="trace">--> | ||
<!-- <AppenderRef ref="Console"/>--> | ||
<!-- </Logger>--> | ||
|
||
<Root level="info"> | ||
<!-- <AppenderRef ref="Console"/>--> | ||
</Root> | ||
</Loggers> | ||
</Configuration> |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.