Skip to content

Commit ca23ba7

Browse files
authored
Add linearStateRewardPredictor and fix names (#302)
Add linearStateRewardPredictor and fix names
1 parent a17b4e6 commit ca23ba7

File tree

18 files changed

+162
-40
lines changed

18 files changed

+162
-40
lines changed

docs/jlearch/jlearch-architecture.md

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ classDiagram
3737
UtBotSymbolicEngine *-- InterproceduralUnitGraph
3838
3939
class Predictors
40-
class NNStateRewardPredictor
40+
class StateRewardPredictor
4141
class NNRewardGuidedSelector
4242
4343
@@ -50,12 +50,13 @@ classDiagram
5050
5151
UtBotSymbolicEngine *-- BasePathSelector
5252
53-
Predictors o-- NNStateRewardPredictor
53+
Predictors o-- StateRewardPredictor
5454
NNRewardGuidedSelector ..> Predictors
5555
NNRewardGuidedSelector *-- FeatureExtractor
5656
57-
NNStateRewardPredictorSmile --|> NNStateRewardPredictor
58-
NNStateRewardPredictorTorch --|> NNStateRewardPredictor
57+
NNStateRewardPredictorSmile --|> StateRewardPredictor
58+
StateRewardPredictorTorch --|> StateRewardPredictor
59+
LinearStateRewardPredictor --|> StateRewardPredictor
5960
6061
NNStateRewardGuidedSelectorWithRecalculationWeight --|> NNRewardGuidedSelector
6162
NNStateRewardGuidedSelectorWithoutRecalculationWeight --|> NNRewardGuidedSelector
@@ -129,12 +130,13 @@ For creating `FeatureExtractor`, it uses `FeatureExtractorFactory` from `EngineA
129130
It is interface in framework-module, that allows to use implementation from analytics module.
130131
* `extractFeatures(state: ExecutionState)` - create features list for state and store it in `state.features`. Now we extract all features, which were described in [paper](https://files.sri.inf.ethz.ch/website/papers/ccs21-learch.pdf). In feature, we can extend the feature list by other features, for example, NeuroSMT.
131132

132-
# NNStateRewardPredictor
133+
# StateRewardPredictor
133134

134-
Interface for reward predictors. Now it has two implementations in `analytics` module:
135+
Interface for reward predictors. Now it has three implementations in `analytics` module:
135136

136137
* `NNStateRewardPredictorSmile`: it uses our own format to store feedforward neural network, and it uses `Smile` library to do multiplication of matrix.
137138
* `NNStateRewardPredictorTorch`: it assumed that a model is any type of model in `pt` format. It uses the `Deep Java library` to use such models.
139+
* `LinearStateRewardPredictor`: it uses our own format to store weights vector: line of doubles, separated by comma with bias as last weight.
138140

139141
It should be created at the beginning of work and stored at `Predictors` class to be used in `NNRewardGuidedSelector` from the `framework` module.
140142

utbot-analytics/src/main/kotlin/org/utbot/predictors/FeedForwardNetwork.kt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
package org.utbot.predictors
22

3+
import org.utbot.predictors.util.ModelBuildingException
34
import smile.math.matrix.Matrix
45
import kotlin.math.max
56

@@ -26,7 +27,7 @@ internal fun buildModel(nnJson: NNJson): FeedForwardNetwork {
2627
operations.add {
2728
when (nnJson.activationLayers[i]) {
2829
ActivationFunctions.ReLU -> reLU(it)
29-
else -> error("Unsupported activation")
30+
else -> throw ModelBuildingException("Unsupported activation")
3031
}
3132
}
3233
}
Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
package org.utbot.predictors
22

3+
import org.utbot.analytics.StateRewardPredictor
34
import mu.KotlinLogging
4-
import org.utbot.analytics.UtBotAbstractPredictor
55
import org.utbot.framework.PathSelectorType
66
import org.utbot.framework.UtSettings
7+
import org.utbot.predictors.util.PredictorLoadingException
8+
import org.utbot.predictors.util.WeightsLoadingException
9+
import org.utbot.predictors.util.splitByCommaIntoDoubleArray
10+
import smile.math.MathEx.dot
711
import smile.math.matrix.Matrix
812
import java.io.File
913

@@ -16,32 +20,39 @@ private val logger = KotlinLogging.logger {}
1620
*/
1721
private fun loadWeights(path: String): Matrix {
1822
val weightsFile = File("${UtSettings.rewardModelPath}/${path}")
23+
lateinit var weightsArray: DoubleArray
1924

20-
if (!weightsFile.exists()) {
21-
error("There is no file with weights with path: ${weightsFile.absolutePath}")
22-
}
25+
try {
26+
if (!weightsFile.exists()) {
27+
error("There is no file with weights with path: ${weightsFile.absolutePath}")
28+
}
2329

24-
val weightsArray = weightsFile.readText().splitByCommaIntoDoubleArray()
30+
weightsArray = weightsFile.readText().splitByCommaIntoDoubleArray()
31+
} catch (e: Exception) {
32+
throw WeightsLoadingException(e)
33+
}
2534

2635
return Matrix(weightsArray)
2736
}
2837

29-
class LinearStateRewardPredictor(weightsPath: String = DEFAULT_WEIGHT_PATH) :
30-
UtBotAbstractPredictor<List<List<Double>>, List<Double>> {
38+
class LinearStateRewardPredictor(weightsPath: String = DEFAULT_WEIGHT_PATH, scalerPath: String = DEFAULT_SCALER_PATH) :
39+
StateRewardPredictor {
3140
private lateinit var weights: Matrix
41+
private lateinit var scaler: StandardScaler
3242

3343
init {
3444
try {
3545
weights = loadWeights(weightsPath)
36-
} catch (e: Exception) {
46+
scaler = loadScaler(scalerPath)
47+
} catch (e: PredictorLoadingException) {
3748
logger.info(e) {
3849
"Error while initialization of LinearStateRewardPredictor. Changing pathSelectorType on INHERITORS_SELECTOR"
3950
}
4051
UtSettings.pathSelectorType = PathSelectorType.INHERITORS_SELECTOR
4152
}
4253
}
4354

44-
override fun predict(input: List<List<Double>>): List<Double> {
55+
fun predict(input: List<List<Double>>): List<Double> {
4556
// add 1 to each feature vector
4657
val matrixValues = input
4758
.map { (it + 1.0).toDoubleArray() }
@@ -51,4 +62,11 @@ class LinearStateRewardPredictor(weightsPath: String = DEFAULT_WEIGHT_PATH) :
5162

5263
return X.mm(weights).col(0).toList()
5364
}
65+
66+
override fun predict(input: List<Double>): Double {
67+
var inputArray = Matrix(input.toDoubleArray()).sub(scaler.mean).div(scaler.variance).col(0)
68+
inputArray += 1.0
69+
70+
return dot(inputArray, weights.col(0))
71+
}
5472
}

utbot-analytics/src/main/kotlin/org/utbot/predictors/NNJson.kt

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package org.utbot.predictors
22

33
import com.google.gson.Gson
44
import org.utbot.framework.UtSettings
5+
import org.utbot.predictors.util.ModelLoadingException
56
import java.io.FileReader
67
import java.nio.file.Paths
78

@@ -33,10 +34,16 @@ data class NNJson(
3334

3435
internal fun loadModel(path: String): NNJson {
3536
val modelFile = Paths.get(UtSettings.rewardModelPath, path).toFile()
36-
val nnJson: NNJson =
37-
Gson().fromJson(FileReader(modelFile), NNJson::class.java) ?: run {
38-
error("Empty model")
39-
}
37+
lateinit var nnJson: NNJson
38+
39+
try {
40+
nnJson =
41+
Gson().fromJson(FileReader(modelFile), NNJson::class.java) ?: run {
42+
error("Empty model")
43+
}
44+
} catch (e: Exception) {
45+
throw ModelLoadingException(e)
46+
}
4047

4148
return nnJson
4249
}

utbot-analytics/src/main/kotlin/org/utbot/predictors/NNStateRewardPredictor.kt

Lines changed: 0 additions & 8 deletions
This file was deleted.

utbot-analytics/src/main/kotlin/org/utbot/predictors/NNStateRewardPredictorBase.kt

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,28 @@
11
package org.utbot.predictors
22

33
import mu.KotlinLogging
4+
import org.utbot.analytics.StateRewardPredictor
45
import org.utbot.framework.PathSelectorType
56
import org.utbot.framework.UtSettings
7+
import org.utbot.predictors.util.PredictorLoadingException
68
import smile.math.matrix.Matrix
79

810
private const val DEFAULT_MODEL_PATH = "nn.json"
9-
private const val DEFAULT_SCALER_PATH = "scaler.txt"
1011

1112
private val logger = KotlinLogging.logger {}
1213

1314
private fun getModel(path: String) = buildModel(loadModel(path))
1415

1516
class NNStateRewardPredictorBase(modelPath: String = DEFAULT_MODEL_PATH, scalerPath: String = DEFAULT_SCALER_PATH) :
16-
NNStateRewardPredictor {
17+
StateRewardPredictor {
1718
private lateinit var nn: FeedForwardNetwork
1819
private lateinit var scaler: StandardScaler
1920

2021
init {
2122
try {
2223
nn = getModel(modelPath)
2324
scaler = loadScaler(scalerPath)
24-
} catch (e: Exception) {
25+
} catch (e: PredictorLoadingException) {
2526
logger.info(e) {
2627
"Error while initialization of NNStateRewardPredictorBase. Changing pathSelectorType on INHERITORS_SELECTOR"
2728
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package org.utbot.predictors
2+
3+
import org.utbot.analytics.StateRewardPredictorFactory
4+
import org.utbot.framework.StateRewardPredictorType
5+
import org.utbot.framework.UtSettings
6+
7+
/**
8+
* Creates [StateRewardPredictor], by checking the [UtSettings] configuration.
9+
*/
10+
class StateRewardPredictorFactoryImpl : StateRewardPredictorFactory {
11+
override operator fun invoke() = when (UtSettings.stateRewardPredictorType) {
12+
StateRewardPredictorType.BASE -> NNStateRewardPredictorBase()
13+
StateRewardPredictorType.TORCH -> StateRewardPredictorTorch()
14+
StateRewardPredictorType.LINEAR -> LinearStateRewardPredictor()
15+
}
16+
}

utbot-analytics/src/main/kotlin/org/utbot/predictors/NNStateRewardPredictorTorch.kt renamed to utbot-analytics/src/main/kotlin/org/utbot/predictors/StateRewardPredictorTorch.kt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,12 @@ import ai.djl.ndarray.NDArray
66
import ai.djl.ndarray.NDList
77
import ai.djl.translate.Translator
88
import ai.djl.translate.TranslatorContext
9+
import org.utbot.analytics.StateRewardPredictor
910
import org.utbot.framework.UtSettings
1011
import java.io.Closeable
1112
import java.nio.file.Paths
1213

13-
class NNStateRewardPredictorTorch : NNStateRewardPredictor, Closeable {
14+
class StateRewardPredictorTorch : StateRewardPredictor, Closeable {
1415
val model: Model = Model.newInstance("model")
1516

1617
init {
Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,23 @@
11
package org.utbot.predictors
22

33
import org.utbot.framework.UtSettings
4+
import org.utbot.predictors.util.ScalerLoadingException
5+
import org.utbot.predictors.util.splitByCommaIntoDoubleArray
46
import smile.math.matrix.Matrix
57
import java.nio.file.Paths
68

9+
10+
internal const val DEFAULT_SCALER_PATH = "scaler.txt"
11+
712
data class StandardScaler(val mean: Matrix?, val variance: Matrix?)
813

914
internal fun loadScaler(path: String): StandardScaler =
10-
Paths.get(UtSettings.rewardModelPath, path).toFile().bufferedReader().use {
11-
val mean = it.readLine()?.splitByCommaIntoDoubleArray() ?: error("There is not mean in $path")
12-
val variance = it.readLine()?.splitByCommaIntoDoubleArray() ?: error("There is not variance in $path")
13-
StandardScaler(Matrix(mean), Matrix(variance))
15+
try {
16+
Paths.get(UtSettings.rewardModelPath, path).toFile().bufferedReader().use {
17+
val mean = it.readLine()?.splitByCommaIntoDoubleArray() ?: error("There is not mean in $path")
18+
val variance = it.readLine()?.splitByCommaIntoDoubleArray() ?: error("There is not variance in $path")
19+
StandardScaler(Matrix(mean), Matrix(variance))
20+
}
21+
} catch (e: Exception) {
22+
throw ScalerLoadingException(e)
1423
}
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
package org.utbot.predictors.util
2+
3+
sealed class PredictorLoadingException(msg: String?, cause: Throwable? = null) : Exception(msg, cause)
4+
5+
class WeightsLoadingException(e: Throwable) : PredictorLoadingException("Error while loading weights", e)
6+
7+
class ModelLoadingException(e: Throwable) : PredictorLoadingException("Error while loading model", e)
8+
9+
class ScalerLoadingException(e: Throwable) : PredictorLoadingException("Error while loading scaler", e)
10+
11+
class ModelBuildingException(msg: String) : PredictorLoadingException(msg)

utbot-analytics/src/main/kotlin/org/utbot/predictors/util.kt renamed to utbot-analytics/src/main/kotlin/org/utbot/predictors/util/StringUtils.kt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
package org.utbot.predictors
1+
package org.utbot.predictors.util
22

33
fun String.splitByCommaIntoDoubleArray() =
44
try {

utbot-analytics/src/test/kotlin/org/utbot/predictors/LinearStateRewardPredictorTest.kt

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,4 +31,15 @@ class LinearStateRewardPredictorTest {
3131
}
3232
}
3333
}
34+
35+
@Test
36+
fun simpleTestNotBatch() {
37+
withRewardModelPath("src/test/resources") {
38+
val pred = LinearStateRewardPredictor()
39+
40+
val features = listOf(2.0, 3.0)
41+
42+
assertEquals(6.0, pred.predict(features))
43+
}
44+
}
3445
}

utbot-analytics/src/test/kotlin/org/utbot/predictors/NNStateRewardPredictorTest.kt

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import org.junit.jupiter.api.Assertions.assertEquals
44
import org.junit.jupiter.api.Disabled
55
import org.junit.jupiter.api.Test
66
import org.utbot.examples.withPathSelectorType
7+
import org.utbot.analytics.StateRewardPredictor
78
import org.utbot.examples.withRewardModelPath
89
import org.utbot.framework.PathSelectorType
910
import org.utbot.framework.UtSettings
@@ -32,13 +33,13 @@ class NNStateRewardPredictorTest {
3233

3334

3435
withRewardModelPath("models") {
35-
val averageTime = calcAverageTimeForModelPredict(::NNStateRewardPredictorTorch, 100, features)
36+
val averageTime = calcAverageTimeForModelPredict(::StateRewardPredictorTorch, 100, features)
3637
println(averageTime)
3738
}
3839
}
3940

4041
private fun calcAverageTimeForModelPredict(
41-
model: () -> NNStateRewardPredictor,
42+
model: () -> StateRewardPredictor,
4243
iterations: Int,
4344
features: List<Double>
4445
): Double {

utbot-framework-api/src/main/kotlin/org/utbot/framework/UtSettings.kt

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,11 @@ object UtSettings {
114114
*/
115115
var nnRewardGuidedSelectorType: NNRewardGuidedSelectorType by getEnumProperty(NNRewardGuidedSelectorType.WITHOUT_RECALCULATION)
116116

117+
/**
118+
* Type of [StateRewardPredictor]
119+
*/
120+
var stateRewardPredictorType: StateRewardPredictorType by getEnumProperty(StateRewardPredictorType.BASE)
121+
117122
/**
118123
* Steps limit for path selector.
119124
*/
@@ -422,3 +427,23 @@ enum class NNRewardGuidedSelectorType {
422427
*/
423428
WITHOUT_RECALCULATION
424429
}
430+
431+
/**
432+
* Enum to specify [StateRewardPredictor], see implementations for details
433+
*/
434+
enum class StateRewardPredictorType {
435+
/**
436+
* [NNStateRewardPredictorBase]
437+
*/
438+
BASE,
439+
440+
/**
441+
* [StateRewardPredictorTorch]
442+
*/
443+
TORCH,
444+
445+
/**
446+
* [NNStateRewardPredictorBase]
447+
*/
448+
LINEAR
449+
}

utbot-framework/src/main/kotlin/org/utbot/analytics/EngineAnalyticsContext.kt

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,4 +27,10 @@ object EngineAnalyticsContext {
2727
NNRewardGuidedSelectorType.WITHOUT_RECALCULATION -> NNRewardGuidedSelectorWithoutRecalculationFactory()
2828
NNRewardGuidedSelectorType.WITH_RECALCULATION -> NNRewardGuidedSelectorWithRecalculationFactory()
2929
}
30+
31+
var stateRewardPredictorFactory: StateRewardPredictorFactory = object : StateRewardPredictorFactory {
32+
override fun invoke(): StateRewardPredictor {
33+
error("NNStateRewardPredictor factory wasn't provided")
34+
}
35+
}
3036
}
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
package org.utbot.analytics
2+
3+
/**
4+
* Interface, which should predict reward for state by features list.
5+
*/
6+
interface StateRewardPredictor : UtBotAbstractPredictor<List<Double>, Double>
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
package org.utbot.analytics
2+
3+
/**
4+
* Encapsulates creation of [StateRewardPredictor]
5+
*/
6+
interface StateRewardPredictorFactory {
7+
operator fun invoke(): StateRewardPredictor
8+
}

0 commit comments

Comments
 (0)