Skip to content

Commit d2c95de

Browse files
authored
Merge pull request #77 from JuliaPsychometricsBazaar/improve-stateful-interface
Improve stateful interface
2 parents 75f083e + 3995bc8 commit d2c95de

File tree

15 files changed

+115
-79
lines changed

15 files changed

+115
-79
lines changed

.github/workflows/CI.yml

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@ jobs:
1818
fail-fast: false
1919
matrix:
2020
version:
21-
- '1.10'
2221
- '1.11'
2322
os:
2423
- ubuntu-latest
@@ -33,18 +32,8 @@ jobs:
3332
- uses: julia-actions/julia-downgrade-compat@v1
3433
if: ${{ matrix.version == '1.10' }}
3534
- uses: julia-actions/cache@v1
36-
- name: Set CAT packages to develop & resolve env
37-
run: |
38-
julia --project=test/ -e 'using Pkg;
39-
Pkg.develop(path=".");
40-
Pkg.resolve();
41-
Pkg.instantiate()'
42-
env:
43-
R_HOME: '*'
4435
- uses: julia-actions/julia-buildpkg@v1
45-
- name: Run tests
46-
run: |
47-
cd test && julia --project=. --code-coverage=user ./runtests.jl
36+
- uses: julia-actions/julia-runtest@v1
4837
- uses: julia-actions/julia-processcoverage@v1
4938
- uses: coverallsapp/github-action@v2
5039
with:
@@ -59,7 +48,7 @@ jobs:
5948
miniforge-version: latest
6049
- uses: julia-actions/setup-julia@v1
6150
with:
62-
version: '1.10'
51+
version: '1.11'
6352
- name: Set CAT packages to develop & resolve env
6453
run: |
6554
julia --project=docs/ -e 'using Pkg;

.github/workflows/benchmark_pr.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,10 @@ jobs:
1414

1515
steps:
1616
- uses: actions/checkout@v2
17-
- uses: julia-actions/setup-julia@v1
17+
- uses: julia-actions/setup-julia@v2
1818
with:
19-
version: "1.10"
20-
- uses: julia-actions/cache@v1
19+
version: "1.11"
20+
- uses: julia-actions/cache@v2
2121
- name: Extract Package Name from Project.toml
2222
id: extract-package-name
2323
run: |

Project.toml

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
3636

3737
[compat]
3838
Accessors = "^0.1.12"
39+
Aqua = "0.5.5, 0.6.5"
3940
AutoHashEquals = "2"
4041
ConstructionBase = "^1.2"
4142
DataFrames = "1.6.1"
@@ -47,16 +48,30 @@ FittedItemBanks = "^0.6.3"
4748
ForwardDiff = "0.10.24"
4849
HypothesisTests = "^0.10.12, ^0.11.0"
4950
Interpolations = "^0.14, ^0.15"
51+
JET = "^0.9"
5052
Lazy = "0.15"
5153
LogarithmicNumbers = "1"
5254
MacroTools = "^0.5.6"
5355
Measurements = "^2.10.0"
56+
Optim = "1.7.3"
5457
OrderedCollections = "^1.6"
5558
PsychometricsBazaarBase = "^0.8.1"
5659
Reexport = "1"
60+
ResumableFunctions = "^0.6"
5761
Setfield = "^1"
5862
StaticArrays = "1"
5963
StatsBase = "^0.34"
6064
StatsFuns = "^0.9.15, ^1"
65+
Test = "^1.11"
6166
UnPack = "1"
62-
julia = "^1.10"
67+
julia = "^1.11"
68+
69+
[extras]
70+
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
71+
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
72+
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
73+
ResumableFunctions = "c5292f4c-5179-55e1-98c5-05642aab7184"
74+
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
75+
76+
[targets]
77+
test = ["Aqua", "JET", "Optim", "ResumableFunctions", "Test"]

src/Comparison.jl

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,7 @@ end
159159

160160
#phase_func=nothing;
161161
function measure_all(comparison, system, cat, phase; kwargs...)
162-
@info "measure_all" phase comparison.phases
162+
@info "measure_all" phase system kwargs
163163
if !(phase in keys(comparison.phases))
164164
return
165165
end
@@ -189,6 +189,7 @@ end
189189
struct IncreaseItemBankSizeExecutionStrategy <: CatComparisonExecutionStrategy
190190
item_bank::AbstractItemBank
191191
sizes::AbstractVector{Int}
192+
responses::Vector # XXX: Type
192193
starting_responses::Int
193194
shuffle::Bool
194195
time_limit::Float64
@@ -205,24 +206,42 @@ function IncreaseItemBankSizeExecutionStrategy(item_bank, sizes)
205206
return IncreaseItemBankSizeExecutionStrategy(item_bank, sizes, 0, false, Inf)
206207
end
207208

209+
function init_cat(cat::Stateful.StatefulCat, item_bank)
210+
Stateful.set_item_bank!(cat, item_bank)
211+
cat
212+
end
213+
214+
function init_cat(cat, item_bank)
215+
cat(item_bank)
216+
end
217+
208218
function run_comparison(comparison::CatComparisonConfig{IncreaseItemBankSizeExecutionStrategy})
209219
strategy = comparison.strategy
210220
current_cats = collect(pairs(comparison.rules))
211-
next_current_cats = copy(current_cats)
221+
next_current_cats = []
212222
@info "sizes" strategy.sizes
213223
for size in strategy.sizes
214224
subsetted_item_bank = subset(strategy.item_bank, 1:size)
215225
empty!(next_current_cats)
216-
for (name, cat) in current_cats
217-
Stateful.set_item_bank!(cat, subsetted_item_bank)
218-
for _ in 1:(strategy.starting_responses)
219-
Stateful.next_item(cat)
226+
for (name, mk_cat) in current_cats
227+
init_time = @timed begin
228+
cat = init_cat(mk_cat, subsetted_item_bank)
220229
end
230+
response_add_time = @timed begin
231+
for idx in 1:(strategy.starting_responses)
232+
Stateful.add_response!(cat, idx, strategy.responses[idx])
233+
end
234+
end
235+
@info "responses" Stateful.get_responses(cat)
221236
measure_all(
222237
comparison,
223238
name,
224239
cat,
225-
:before_next_item
240+
:before_next_item,
241+
init_time = init_time.time,
242+
response_add_time = response_add_time.time,
243+
num_items=size,
244+
system_name=name
226245
)
227246
timed_next_item = @timed Stateful.next_item(cat)
228247
next_item = timed_next_item.value
@@ -232,14 +251,17 @@ function run_comparison(comparison::CatComparisonConfig{IncreaseItemBankSizeExec
232251
cat,
233252
:after_next_item,
234253
next_item = next_item,
235-
timing = timed_next_item
254+
timing = timed_next_item,
255+
num_items=size,
256+
system_name=name
236257
)
237-
@info "next_item" timed_next_item.time strategy.time_limit
258+
@info "next_item" name timed_next_item.time strategy.time_limit
238259
if timed_next_item.time < strategy.time_limit
239260
push!(next_current_cats, name => cat)
240261
end
241262
end
242-
current_cats, next_current_cats = next_current_cats, current_cats
263+
current_cats = next_current_cats
264+
next_current_cats = []
243265
end
244266
end
245267

src/ComputerAdaptiveTesting.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,11 @@ include("./hacks.jl")
55
using Pkg
66
using Reexport
77

8-
export ConfigBase, Responses, Aggregators, NextItemRules, TerminationConditions
9-
export CatConfig, Sim
8+
# Modules
9+
export ConfigBase, Responses, Aggregators
10+
export NextItemRules, TerminationConditions
11+
export CatConfig, Sim, DecisionTree
12+
export Stateful, Comparison
1013

1114
# Vendored dependencies
1215
include("./vendor/PushVectors.jl")

src/Stateful.jl

Lines changed: 24 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,10 @@ using ..CatConfig: CatLoopConfig, CatRules
66
using ..Responses: BareResponses, Response
77
using ..NextItemRules: compute_criteria, best_item
88

9+
export StatefulCat, StatefulCatConfig, run_cat
10+
public next_item, ranked_items, item_criteria
11+
public add_response!, rollback!, reset!, get_responses, get_ability
12+
913
## StatefulCat interface
1014
abstract type StatefulCat end
1115

@@ -56,61 +60,65 @@ end
5660
## TODO: Materialise the cat into a decsision tree
5761

5862
## Implementation for CatConfig
59-
struct StatefulCatConfig{ItemBankT <: AbstractItemBank} <: StatefulCat
63+
struct StatefulCatConfig{TrackedResponsesT <: TrackedResponses} <: StatefulCat
6064
rules::CatRules
61-
tracked_responses::TrackedResponses
62-
item_bank::Ref{ItemBankT}
65+
tracked_responses::Ref{TrackedResponsesT}
6366
end
6467

65-
function StatefulCatConfig(rules, item_bank)
68+
function StatefulCatConfig(rules::CatRules, item_bank::AbstractItemBank)
6669
bare_responses = BareResponses(ResponseType(item_bank))
6770
tracked_responses = TrackedResponses(
6871
bare_responses,
6972
item_bank,
7073
rules.ability_tracker
7174
)
72-
return StatefulCatConfig(rules, tracked_responses, Ref(item_bank))
75+
return StatefulCatConfig(rules, Ref(tracked_responses))
7376
end
7477

7578
function next_item(config::StatefulCatConfig)
76-
return best_item(config.rules.next_item, config.tracked_responses, config.item_bank[])
79+
return best_item(config.rules.next_item, config.tracked_responses[])
7780
end
7881

7982
function ranked_items(config::StatefulCatConfig)
8083
return sortperm(compute_criteria(
81-
config.rules.next_item, config.tracked_responses, config.item_bank[]))
84+
config.rules.next_item, config.tracked_responses[]))
8285
end
8386

8487
function item_criteria(config::StatefulCatConfig)
8588
return compute_criteria(
86-
config.rules.next_item, config.tracked_responses, config.item_bank[])
89+
config.rules.next_item, config.tracked_responses[])
8790
end
8891

8992
function add_response!(config::StatefulCatConfig, index, response)
93+
tracked_responses = config.tracked_responses[]
9094
Aggregators.add_response!(
91-
config.tracked_responses, Response(
92-
ResponseType(config.item_bank[]), index, response))
95+
tracked_responses, Response(
96+
ResponseType(tracked_responses.item_bank), index, response))
9397
end
9498

9599
function rollback!(config::StatefulCatConfig)
96-
pop_response!(config.tracked_responses)
100+
pop_response!(config.tracked_responses[])
97101
end
98102

99103
function reset!(config::StatefulCatConfig)
100-
empty!(config.tracked_responses)
104+
empty!(config.tracked_responses[])
101105
end
102106

103107
function set_item_bank!(config::StatefulCatConfig, item_bank)
104-
reset!(config)
105-
config.item_bank[] = item_bank
108+
bare_responses = BareResponses(ResponseType(item_bank))
109+
config.tracked_responses[] = TrackedResponses(
110+
bare_responses,
111+
item_bank,
112+
config.rules.ability_tracker
113+
)
106114
end
107115

108116
function get_responses(config::StatefulCatConfig)
109-
return config.tracked_responses.responses
117+
return config.tracked_responses[].responses
110118
end
111119

112120
function get_ability(config::StatefulCatConfig)
113-
return (config.rules.ability_estimator(config.tracked_responses), nothing)
121+
return (config.rules.ability_estimator(config.tracked_responses[]), nothing)
114122
end
115123

116124
## TODO: Implementation for MaterializedDecisionTree

src/next_item_rules/prelude/criteria.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,13 @@ function compute_criteria(
5050
for item_idx in eachindex(items)]
5151
end
5252

53+
function compute_criteria(
54+
criterion::ItemCriterion,
55+
responses::TrackedResponses,
56+
)
57+
compute_criteria(criterion, responses, responses.item_bank)
58+
end
59+
5360
function compute_criteria(
5461
rule::ItemStrategyNextItemRule{StrategyT, ItemCriterionT},
5562
responses,
@@ -58,6 +65,13 @@ function compute_criteria(
5865
compute_criteria(rule.criterion, responses, items)
5966
end
6067

68+
function compute_criteria(
69+
rule::ItemStrategyNextItemRule{StrategyT, ItemCriterionT},
70+
responses::TrackedResponses
71+
) where {StrategyT, ItemCriterionT <: ItemCriterion}
72+
compute_criteria(rule.criterion, responses)
73+
end
74+
6175
function compute_pointwise_criterion(
6276
ppic::PurePointwiseItemCriterion, tracked_responses, item_idx)
6377
compute_pointwise_criterion(ppic, ItemResponse(tracked_responses.item_bank, item_idx))

src/next_item_rules/prelude/next_item_rule.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,3 +51,7 @@ function ItemStrategyNextItemRule(bits...;
5151
return ItemStrategyNextItemRule(strategy, criterion)
5252
end
5353
end
54+
55+
function best_item(rule::NextItemRule, tracked_responses::TrackedResponses)
56+
best_item(rule, tracked_responses, tracked_responses.item_bank)
57+
end

src/next_item_rules/strategies/exhaustive.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,5 @@ function best_item(
3939
responses::TrackedResponses,
4040
items
4141
) where {ItemCriterionT <: ItemCriterion}
42-
#, rule.strategy.parallel
4342
exhaustive_search(rule.criterion, responses, items)[1]
44-
end
43+
end

test/Project.toml

Lines changed: 0 additions & 21 deletions
This file was deleted.

test/dummy.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
module Dummy
22

3-
using Accessors
43
using ComputerAdaptiveTesting.NextItemRules
54
using ComputerAdaptiveTesting.Aggregators
65
using ComputerAdaptiveTesting.Responses

test/format.jl

Lines changed: 0 additions & 8 deletions
This file was deleted.

test/runtests.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,5 +32,4 @@ using .Dummy
3232
include("./smoke.jl")
3333
include("./dt.jl")
3434
include("./stateful.jl")
35-
include("./format.jl")
3635
end

test/stateful.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,18 @@
11
@testset "Stateful" begin
2+
using ComputerAdaptiveTesting: CatRules
3+
using FittedItemBanks.DummyData: dummy_full
4+
using FittedItemBanks: OneDimContinuousDomain, SimpleItemBankSpec, StdModel3PL,
5+
BooleanResponse
6+
using ComputerAdaptiveTesting.TerminationConditions: FixedItemsTerminationCondition
7+
using ComputerAdaptiveTesting.NextItemRules: RandomNextItemRule
8+
using ComputerAdaptiveTesting: Stateful
9+
using ResumableFunctions
10+
using Test: @test, @testset
11+
12+
include("./dummy.jl")
13+
using .Dummy
14+
using Random
15+
216
rng = Random.default_rng(42)
317

418
# Create test data

test/tests_top.jl

Lines changed: 0 additions & 1 deletion
This file was deleted.

0 commit comments

Comments
 (0)