Skip to content

Commit 75f083e

Browse files
authored
Merge pull request #76 from JuliaPsychometricsBazaar/refactor-next-item-rules
Refactor next item rules
2 parents 16feb96 + 7a0ada6 commit 75f083e

27 files changed

+375
-469
lines changed

Project.toml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ FittedItemBanks = "3f797b09-34e4-41d7-acf6-3302ae3248a5"
1616
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1717
HypothesisTests = "09f84164-cd44-5f33-b23f-e6b0d136a0d5"
1818
Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
19-
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
2019
Lazy = "50d2b5c4-7a5e-59d5-8109-a42b560f39c0"
2120
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
2221
LogarithmicNumbers = "aa2f6b4e-9042-5d33-9679-40d3a6b85899"
@@ -48,13 +47,12 @@ FittedItemBanks = "^0.6.3"
4847
ForwardDiff = "0.10.24"
4948
HypothesisTests = "^0.10.12, ^0.11.0"
5049
Interpolations = "^0.14, ^0.15"
51-
KernelAbstractions = "^0.9.22"
5250
Lazy = "0.15"
5351
LogarithmicNumbers = "1"
5452
MacroTools = "^0.5.6"
5553
Measurements = "^2.10.0"
5654
OrderedCollections = "^1.6"
57-
PsychometricsBazaarBase = "^0.8.0"
55+
PsychometricsBazaarBase = "^0.8.1"
5856
Reexport = "1"
5957
Setfield = "^1"
6058
StaticArrays = "1"

src/Comparison.jl

Lines changed: 87 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ module Comparison
44
# Should be kept in mind and kept distinct or code reuse
55

66
using StatsBase
7-
using FittedItemBanks: AbstractItemBank, ResponseType
7+
using FittedItemBanks: AbstractItemBank, ResponseType, subset
88
using ..Responses
99
using ..CatConfig: CatLoopConfig, CatRules
1010
using ..Aggregators: TrackedResponses, add_response!, Speculator, Aggregators, track!,
@@ -14,11 +14,11 @@ using Base: Iterators
1414

1515
using HypothesisTests
1616
using EffectSizes
17-
using DataFrames
17+
using DataFrames: DataFrame
1818
using ComputerAdaptiveTesting: Stateful
1919

2020
export run_random_comparison, run_comparison
21-
export CatComparisonExecutionStrategy#, IncreaseItemBankSizeExecutionStrategy
21+
export CatComparisonExecutionStrategy, IncreaseItemBankSizeExecutionStrategy
2222
#export FollowOneExecutionStrategy, RunIndependentlyExecutionStrategy
2323
#export DecisionTreeExecutionStrategy
2424
export ReplayResponsesExecutionStrategy
@@ -83,7 +83,8 @@ end
8383

8484
abstract type CatComparisonExecutionStrategy end
8585

86-
Base.@kwdef struct CatComparisonConfig{StrategyT <: CatComparisonExecutionStrategy}
86+
struct CatComparisonConfig{
87+
StrategyT <: CatComparisonExecutionStrategy, PhasesT <: NamedTuple}
8788
"""
8889
A named tuple with the (named) CatRules (or compatable) to be compared
8990
"""
@@ -99,13 +100,42 @@ Base.@kwdef struct CatComparisonConfig{StrategyT <: CatComparisonExecutionStrate
99100
measurements::Vector{}
100101
=#
101102
"""
102-
Which phases to run and/or call the callback on
103+
The phases to run, optionally paired with a callback
103104
"""
104-
phases::Set{Symbol} = Set((:before_next_item, :after_next_item))
105-
"""
106-
The callback which should take a named tuple with information at different phases
107-
"""
108-
callback::Any
105+
phases::PhasesT
106+
end
107+
108+
"""
109+
CatComparisonConfig(;
110+
rules::NamedTuple{Symbol, StatefulCat},
111+
strategy::CatComparisonExecutionStrategy,
112+
phases::Union{NamedTuple{Symbol, Callable}, Tuple{Symbol}},
113+
callback::Callable
114+
) -> CatComparisonConfig
115+
116+
CatComparisonConfig sets up a evaluation-oriented comparison between different CAT systems.
117+
118+
Specify the comparison by listing: CAT systems in `rules`, a `NamedTuple` which gives
119+
identifiers to implementations of the `StatefulCat` interface; the `strategy` to use,
120+
an implementation of `CatComparisonExecutionStrategy`; the `phases` to run listed as
121+
either as a `NamedTuple` with names of phases and corresponding callbacks or `nothing` a
122+
`Tuple` of phases to run; and a `callback` which will be used as a fallback in cases where
123+
no callback is provided.
124+
125+
The exact phases depend on the strategy used. See their individual documentation for more.
126+
"""
127+
function CatComparisonConfig(; rules, strategy, phases = nothing, callback = nothing)
128+
if callback === nothing
129+
callback = (info; kwargs...) -> nothing
130+
end
131+
if phases === nothing
132+
phases = (:before_next_item, :after_next_item)
133+
end
134+
# TODO: normalize phases into named tuple
135+
if !(phases isa NamedTuple)
136+
phases = NamedTuple((phase => callback for phase in phases))
137+
end
138+
CatComparisonConfig(rules, strategy, phases)
109139
end
110140

111141
# Comparison scenarios:
@@ -129,9 +159,11 @@ end
129159

130160
#phase_func=nothing;
131161
function measure_all(comparison, system, cat, phase; kwargs...)
132-
if !(phase in comparison.phases)
162+
@info "measure_all" phase comparison.phases
163+
if !(phase in keys(comparison.phases))
133164
return
134165
end
166+
callback = comparison.phases[phase]
135167
strategy = comparison.strategy
136168
#=measurement_results = []
137169
for measurement in comparison.measurements
@@ -145,7 +177,7 @@ function measure_all(comparison, system, cat, phase; kwargs...)
145177
#end
146178
push!(measurement_results, result)
147179
end=#
148-
comparison.callback((;
180+
callback((;
149181
phase,
150182
system,
151183
cat,
@@ -158,30 +190,56 @@ struct IncreaseItemBankSizeExecutionStrategy <: CatComparisonExecutionStrategy
158190
item_bank::AbstractItemBank
159191
sizes::AbstractVector{Int}
160192
starting_responses::Int
193+
shuffle::Bool
194+
time_limit::Float64
195+
196+
function IncreaseItemBankSizeExecutionStrategy(item_bank, sizes, args...)
197+
if any((size > length(item_bank) for size in sizes))
198+
error("IncreaseItemBankSizeExecutionStrategy: No subset size can be greater than the number of items available in the item bank")
199+
end
200+
new(item_bank, sizes, args...)
201+
end
161202
end
162203

163204
function IncreaseItemBankSizeExecutionStrategy(item_bank, sizes)
164-
return IncreaseItemBankSizeExecutionStrategy(item_bank, sizes, 0)
205+
return IncreaseItemBankSizeExecutionStrategy(item_bank, sizes, 0, false, Inf)
165206
end
166207

167-
function run_comparison(strategy::IncreaseItemBankSizeExecutionStrategy, config)
208+
function run_comparison(comparison::CatComparisonConfig{IncreaseItemBankSizeExecutionStrategy})
209+
strategy = comparison.strategy
210+
current_cats = collect(pairs(comparison.rules))
211+
next_current_cats = copy(current_cats)
212+
@info "sizes" strategy.sizes
168213
for size in strategy.sizes
169-
subsetted_item_bank = subset(strategy.item_bank, size)
170-
responses = TrackedResponses(
171-
BareResponses(ResponseType(strategy.item_bank)),
172-
subsetted_item_bank,
173-
config.ability_tracker
174-
)
175-
for _ in 1:(strategy.starting_responses)
176-
next_item = config.next_item(responses, subsetted_item_bank)
177-
add_response!(responses,
178-
Response(ResponseType(subsetted_item_bank), next_item, rand(Bool)))
214+
subsetted_item_bank = subset(strategy.item_bank, 1:size)
215+
empty!(next_current_cats)
216+
for (name, cat) in current_cats
217+
Stateful.set_item_bank!(cat, subsetted_item_bank)
218+
for _ in 1:(strategy.starting_responses)
219+
Stateful.next_item(cat)
220+
end
221+
measure_all(
222+
comparison,
223+
name,
224+
cat,
225+
:before_next_item
226+
)
227+
timed_next_item = @timed Stateful.next_item(cat)
228+
next_item = timed_next_item.value
229+
measure_all(
230+
comparison,
231+
name,
232+
cat,
233+
:after_next_item,
234+
next_item = next_item,
235+
timing = timed_next_item
236+
)
237+
@info "next_item" timed_next_item.time strategy.time_limit
238+
if timed_next_item.time < strategy.time_limit
239+
push!(next_current_cats, name => cat)
240+
end
179241
end
180-
measure_all(config, :before_next_item, before_next_item; responses = responses)
181-
timed_next_item = @timed config.next_item(responses, item_bank)
182-
next_item = timed_next_item.value
183-
measure_all(config, :after_next_item, after_next_item;
184-
responses = responses, next_item = next_item)
242+
current_cats, next_current_cats = next_current_cats, current_cats
185243
end
186244
end
187245

src/Stateful.jl

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ end
5959
struct StatefulCatConfig{ItemBankT <: AbstractItemBank} <: StatefulCat
6060
rules::CatRules
6161
tracked_responses::TrackedResponses
62-
item_bank::ItemBankT
62+
item_bank::Ref{ItemBankT}
6363
end
6464

6565
function StatefulCatConfig(rules, item_bank)
@@ -69,26 +69,27 @@ function StatefulCatConfig(rules, item_bank)
6969
item_bank,
7070
rules.ability_tracker
7171
)
72-
return StatefulCatConfig(rules, tracked_responses, item_bank)
72+
return StatefulCatConfig(rules, tracked_responses, Ref(item_bank))
7373
end
7474

7575
function next_item(config::StatefulCatConfig)
76-
return best_item(config.rules.next_item, config.tracked_responses, config.item_bank)
76+
return best_item(config.rules.next_item, config.tracked_responses, config.item_bank[])
7777
end
7878

7979
function ranked_items(config::StatefulCatConfig)
8080
return sortperm(compute_criteria(
81-
config.rules.next_item, config.tracked_responses, config.item_bank))
81+
config.rules.next_item, config.tracked_responses, config.item_bank[]))
8282
end
8383

8484
function item_criteria(config::StatefulCatConfig)
8585
return compute_criteria(
86-
config.rules.next_item, config.tracked_responses, config.item_bank)
86+
config.rules.next_item, config.tracked_responses, config.item_bank[])
8787
end
8888

8989
function add_response!(config::StatefulCatConfig, index, response)
9090
Aggregators.add_response!(
91-
config.tracked_responses, Response(ResponseType(config.item_bank), index, response))
91+
config.tracked_responses, Response(
92+
ResponseType(config.item_bank[]), index, response))
9293
end
9394

9495
function rollback!(config::StatefulCatConfig)
@@ -99,6 +100,11 @@ function reset!(config::StatefulCatConfig)
99100
empty!(config.tracked_responses)
100101
end
101102

103+
function set_item_bank!(config::StatefulCatConfig, item_bank)
104+
reset!(config)
105+
config.item_bank[] = item_bank
106+
end
107+
102108
function get_responses(config::StatefulCatConfig)
103109
return config.tracked_responses.responses
104110
end

src/TerminationConditions.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,7 @@ function (condition::SimpleFunctionTerminationCondition)(responses::TrackedRespo
4141
end
4242

4343
struct RunForeverTerminationCondition <: TerminationCondition end
44-
function (condition::RunForeverTerminationCondition)(responses::TrackedResponses,
45-
items::AbstractItemBank)
44+
function (condition::RunForeverTerminationCondition)(::TrackedResponses, ::AbstractItemBank)
4645
return false
4746
end
4847

src/aggregators/Aggregators.jl

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ import PsychometricsBazaarBase.IntegralCoeffs
2626

2727
export AbilityEstimator, TrackedResponses
2828
export AbilityTracker, NullAbilityTracker, PointAbilityTracker, GriddedAbilityTracker
29-
export ClosedFormNormalAbilityTracker, MultiAbilityTracker, track!
29+
export ClosedFormNormalAbilityTracker, track!
3030
export response_expectation,
3131
add_response!, pop_response!, expectation, distribution_estimator
3232
export PointAbilityEstimator, PriorAbilityEstimator, LikelihoodAbilityEstimator
@@ -91,12 +91,16 @@ function AbilityTracker(bits...; integrator = nothing, ability_estimator = nothi
9191
end
9292
end
9393

94-
function compatible_tracker(bits...; integrator, ability_estimator, prefer_tracked)
95-
ability_tracker = AbilityTracker(bits...; ability_estimator = ability_estimator)
96-
if ability_tracker isa GriddedAbilityTracker &&
94+
function find_ability_tracker(ability_tracker, typ, integrator)
95+
if ability_tracker isa typ &&
9796
ability_tracker.integrator === integrator
9897
return ability_tracker
9998
end
99+
end
100+
101+
function compatible_tracker(bits...; integrator, ability_estimator, prefer_tracked)
102+
ability_tracker = AbilityTracker(bits...; ability_estimator = ability_estimator)
103+
@returnsome find_ability_tracker(ability_tracker, GriddedAbilityTracker, integrator)
100104
if prefer_tracked
101105
return AbilityTracker(bits...;
102106
integrator = integrator,

src/aggregators/ability_tracker.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,6 @@ include("./ability_trackers/grid.jl")
8080
include("./ability_trackers/point.jl")
8181
include("./ability_trackers/closed_form_normal.jl")
8282
include("./ability_trackers/laplace.jl")
83-
include("./ability_trackers/multi.jl")
8483

8584
"""
8685
This method returns a tracked point estimate if it is has the given ability

src/aggregators/ability_trackers/grid.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,14 @@ end
99

1010
function GriddedAbilityTracker(ability_estimator::DistributionAbilityEstimator,
1111
integrator::FixedGridIntegrator)
12-
GriddedAbilityTracker(ability_estimator, integrator, fill(NaN, length(integrator.grid)))
12+
GriddedAbilityTracker(ability_estimator, integrator, fill(1.0, length(integrator.grid)))
1313
end
1414

15+
find_grid(integrator::FixedGridIntegrator) = integrator.grid
16+
find_grid(integrator::PreallocatedFixedGridIntegrator) = integrator.inner.grid
17+
1518
function track!(responses, ability_tracker::GriddedAbilityTracker)
1619
ability_pdf = pdf(ability_tracker.ability_estimator, responses)
17-
ability_tracker.cur_ability .= ability_pdf.(ability_tracker.integrator.grid)
20+
grid = find_grid(ability_tracker.integrator)
21+
ability_tracker.cur_ability .= ability_pdf.(grid)
1822
end

src/aggregators/ability_trackers/multi.jl

Lines changed: 0 additions & 9 deletions
This file was deleted.

src/next_item_rules/NextItemRules.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,7 @@ export MatrixScalarizer, DeterminantScalarizer, TraceScalarizer
4747
export AbilityCovarianceStateMultiCriterion, StateMultiCriterion, ItemMultiCriterion
4848
export InformationMatrixCriteria
4949
export ScalarizedStateCriteron, ScalarizedItemCriteron
50+
export DRuleItemCriterion, TRuleItemCriterion
5051

5152
# Prelude
5253
include("./prelude/abstract.jl")
@@ -61,18 +62,18 @@ include("./strategies/exhaustive.jl")
6162
# Combinators
6263
include("./combinators/expectation.jl")
6364
include("./combinators/scalarizers.jl")
65+
include("./combinators/likelihood.jl")
6466

6567
# Criteria
6668
include("./criteria/item/information_special.jl")
6769
include("./criteria/item/information_support.jl")
6870
include("./criteria/item/information.jl")
6971
include("./criteria/item/urry.jl")
7072
include("./criteria/state/ability_variance.jl")
73+
include("./criteria/pointwise/kl.jl")
7174

7275
# Porcelain
76+
include("./porcelain/porcelain.jl")
7377
include("./porcelain/aliases.jl")
7478

75-
# Experimental
76-
include("./experimental/ka.jl")
77-
7879
end
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
struct LikelihoodWeightedItemCriterion{
2+
PointwiseItemCriterionT <: PointwiseItemCriterion,
3+
AbilityIntegratorT <: AbilityIntegrator,
4+
AbilityEstimatorT <: DistributionAbilityEstimator
5+
} <: ItemCriterion
6+
criterion::PointwiseItemCriterionT
7+
integrator::AbilityIntegratorT
8+
estimator::AbilityEstimatorT
9+
end
10+
11+
function compute_criterion(
12+
lwic::LikelihoodWeightedItemCriterion,
13+
tracked_responses::TrackedResponses,
14+
item_idx
15+
)
16+
func = FunctionProduct(
17+
pdf(lwic.estimator, tracked_responses), lwic.criterion(tracked_responses, item_idx))
18+
lwic.integrator(func, 0, lwic.estimator, tracked_responses)
19+
end

0 commit comments

Comments
 (0)