Skip to content

Commit 16feb96

Browse files
authored
More next item rules refactoring (#75)
* More next item rules refactoring * Criteria => MultiCriterion * Get rid of most functors and convert into named methods * Introduce abstract types for pointwise item criteria * Add tests for Stateful * Use best_item in StatefulCatConfig * Fix up benchmark
1 parent 67a2587 commit 16feb96

File tree

20 files changed

+271
-130
lines changed

20 files changed

+271
-130
lines changed

benchmark/benchmarks.jl

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,11 @@ function prepare_4pls(group)
4848
tracked_responses = TrackedResponses(BareResponses(ResponseType(item_bank)),
4949
item_bank,
5050
NullAbilityTracker())
51-
group["$(est_nick)_point_mepv_bare"] = @benchmarkable ($next_item_rule)(
52-
$tracked_responses, $item_bank)
51+
group["$(est_nick)_point_mepv_bare"] = @benchmarkable best_item(
52+
$next_item_rule,
53+
$tracked_responses,
54+
$item_bank
55+
)
5356
bare_responses = BareResponses(
5457
ResponseType(item_bank),
5558
response_idxs,
@@ -60,8 +63,11 @@ function prepare_4pls(group)
6063
bare_responses,
6164
item_bank,
6265
NullAbilityTracker())
63-
group["$(est_nick)_point_mepv_10"] = @benchmarkable ($next_item_rule)(
64-
$tracked_responses, $item_bank)
66+
group["$(est_nick)_point_mepv_10"] = @benchmarkable best_item(
67+
$next_item_rule,
68+
$tracked_responses,
69+
$item_bank
70+
)
6571
end
6672
return group
6773
end

src/Sim.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ using FittedItemBanks: AbstractItemBank, ResponseType
55
using ..Responses
66
using ..CatConfig: CatLoopConfig, CatRules
77
using ..Aggregators: TrackedResponses, add_response!, Speculator, Aggregators
8-
using ..NextItemRules: compute_criteria
8+
using ..NextItemRules: compute_criteria, best_item
99

1010
export run_cat, prompt_response, auto_responder
1111

@@ -56,7 +56,7 @@ function run_cat(cat_config::CatLoopConfig{RulesT},
5656
"Best items"
5757
end criteria
5858
try
59-
next_index = next_item(responses, item_bank)
59+
next_index = best_item(next_item, responses, item_bank)
6060
catch exc
6161
if isa(exc, NextItemError)
6262
@warn "Terminating early due to error getting next item" err=sprint(

src/Stateful.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ using FittedItemBanks: AbstractItemBank, ResponseType
44
using ..Aggregators: TrackedResponses, Aggregators
55
using ..CatConfig: CatLoopConfig, CatRules
66
using ..Responses: BareResponses, Response
7-
using ..NextItemRules: compute_criteria
7+
using ..NextItemRules: compute_criteria, best_item
88

99
## StatefulCat interface
1010
abstract type StatefulCat end
@@ -73,7 +73,7 @@ function StatefulCatConfig(rules, item_bank)
7373
end
7474

7575
function next_item(config::StatefulCatConfig)
76-
return config.rules.next_item(config.tracked_responses, config.item_bank)
76+
return best_item(config.rules.next_item, config.tracked_responses, config.item_bank)
7777
end
7878

7979
function ranked_items(config::StatefulCatConfig)

src/decision_tree/DecisionTree.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ function generate_dt_cat(config::DecisionTreeGenerationConfig, item_bank)
128128
while true
129129
track!(responses, config.ability_tracker)
130130
ability = config.ability_estimator(responses)
131-
next_item = config.next_item(responses, item_bank)
131+
next_item = best_item(config.next_item, responses, item_bank)
132132

133133
insert!(decision_tree_result, responses.responses, ability, next_item)
134134
if state_tree.cur_depth == state_tree.max_depth

src/next_item_rules/NextItemRules.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,17 +39,18 @@ export RandomNextItemRule
3939
export ExhaustiveSearch
4040
export catr_next_item_aliases
4141
export preallocate
42-
export compute_criteria
42+
export compute_criteria, compute_criterion, compute_multi_criterion,
43+
compute_pointwise_criterion
44+
export best_item
4345
export PointResponseExpectation, DistributionResponseExpectation
4446
export MatrixScalarizer, DeterminantScalarizer, TraceScalarizer
45-
export AbilityCovarianceStateCriteria, StateCriteria, ItemCriteria
47+
export AbilityCovarianceStateMultiCriterion, StateMultiCriterion, ItemMultiCriterion
4648
export InformationMatrixCriteria
4749
export ScalarizedStateCriteron, ScalarizedItemCriteron
4850

4951
# Prelude
5052
include("./prelude/abstract.jl")
5153
include("./prelude/next_item_rule.jl")
52-
include("./prelude/strategy.jl")
5354
include("./prelude/criteria.jl")
5455
include("./prelude/preallocate.jl")
5556

src/next_item_rules/combinators/expectation.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -96,14 +96,16 @@ function init_thread(::ExpectationBasedItemCriterion, responses::TrackedResponse
9696
end
9797

9898
function _generic_criterion(criterion::StateCriterion, tracked_responses, item_idx)
99-
criterion(tracked_responses)
99+
compute_criterion(criterion, tracked_responses)
100100
end
101101
# TODO: Support init_thread for wrapped ItemCriterion
102102
function _generic_criterion(criterion::ItemCriterion, tracked_responses, item_idx)
103-
criterion(tracked_responses, item_idx)
103+
compute_criterion(criterion, tracked_responses, item_idx)
104104
end
105105

106-
function (item_criterion::ExpectationBasedItemCriterion)(speculator::Speculator,
106+
function compute_criterion(
107+
item_criterion::ExpectationBasedItemCriterion,
108+
speculator::Speculator,
107109
tracked_responses::TrackedResponses,
108110
item_idx)
109111
exp_resp = Aggregators.response_expectation(item_criterion.response_expectation,
Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,57 +1,58 @@
11
struct DeterminantScalarizer <: MatrixScalarizer end
2-
(::DeterminantScalarizer)(mat) = det(mat)
2+
scalarize(::DeterminantScalarizer, mat) = det(mat)
33

44
struct TraceScalarizer <: MatrixScalarizer end
5-
(::TraceScalarizer)(mat) = tr(mat)
5+
scalarize(::TraceScalarizer, mat) = tr(mat)
66

77
struct ScalarizedItemCriteron{
8-
ItemCriteriaT <: ItemCriteria,
8+
ItemMultiCriterionT <: ItemMultiCriterion,
99
MatrixScalarizerT <: MatrixScalarizer
1010
} <: ItemCriterion
11-
criteria::ItemCriteriaT
11+
criteria::ItemMultiCriterionT
1212
scalarizer::MatrixScalarizerT
1313
end
1414

15-
function (ssc::ScalarizedItemCriteron)(tracked_responses, item_idx)
16-
res = ssc.criteria(
17-
init_thread(ssc.criteria, tracked_responses), tracked_responses, item_idx) |>
18-
ssc.scalarizer
19-
if !should_minimize(ssc.criteria)
20-
res = -res
21-
end
22-
res
23-
end
24-
2515
struct ScalarizedStateCriteron{
26-
StateCriteriaT <: StateCriteria,
16+
StateMultiCriterionT <: StateMultiCriterion,
2717
MatrixScalarizerT <: MatrixScalarizer
2818
} <: StateCriterion
29-
criteria::StateCriteriaT
19+
criteria::StateMultiCriterionT
3020
scalarizer::MatrixScalarizerT
3121
end
3222

33-
function (ssc::ScalarizedStateCriteron)(tracked_responses)
34-
res = ssc.criteria(tracked_responses) |> ssc.scalarizer
23+
function compute_criterion(ssc::Union{ScalarizedItemCriteron, ScalarizedStateCriteron},
24+
tracked_responses::TrackedResponses, item_idx...)
25+
res = scalarize(
26+
ssc.scalarizer,
27+
compute_multi_criterion(
28+
ssc.criteria,
29+
init_thread(ssc.criteria, tracked_responses),
30+
tracked_responses,
31+
item_idx...
32+
)
33+
)
3534
if !should_minimize(ssc.criteria)
3635
res = -res
3736
end
3837
res
3938
end
4039

41-
struct WeightedStateCriteria{InnerT <: StateCriteria} <: StateCriteria
40+
struct WeightedStateMultiCriterion{InnerT <: StateMultiCriterion} <: StateMultiCriterion
4241
weights::Vector{Float64}
4342
criteria::InnerT
4443
end
4544

46-
function (wsc::WeightedStateCriteria)(tracked_responses, item_idx)
45+
function compute_multi_criterion(
46+
wsc::WeightedStateMultiCriterion, tracked_responses::TrackedResponses, item_idx)
4747
wsc.weights' * wsc.criteria(tracked_responses, item_idx) * wsc.weights
4848
end
4949

50-
struct WeightedItemCriteria{InnerT <: ItemCriteria} <: ItemCriteria
50+
struct WeightedItemMultiCriterion{InnerT <: ItemMultiCriterion} <: ItemMultiCriterion
5151
weights::Vector{Float64}
5252
criteria::InnerT
5353
end
5454

55-
function (wsc::WeightedItemCriteria)(tracked_responses, item_idx)
55+
function compute_multi_criterion(
56+
wsc::WeightedItemMultiCriterion, tracked_responses::TrackedResponses, item_idx)
5657
wsc.weights' * wsc.criteria(tracked_responses, item_idx) * wsc.weights
5758
end

src/next_item_rules/criteria/item/information.jl

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,15 +9,17 @@ function InformationItemCriterion(ability_estimator)
99
InformationItemCriterion(ability_estimator, expected_item_information)
1010
end
1111

12-
function (item_criterion::InformationItemCriterion)(tracked_responses::TrackedResponses,
12+
function compute_criterion(
13+
item_criterion::InformationItemCriterion, tracked_responses::TrackedResponses,
1314
item_idx)
1415
ability = maybe_tracked_ability_estimate(tracked_responses,
1516
item_criterion.ability_estimator)
1617
ir = ItemResponse(tracked_responses.item_bank, item_idx)
1718
return -item_criterion.expected_item_information(ir, ability)
1819
end
1920

20-
struct InformationMatrixCriteria{AbilityEstimatorT <: AbilityEstimator, F} <: ItemCriteria
21+
struct InformationMatrixCriteria{AbilityEstimatorT <: AbilityEstimator, F} <:
22+
ItemMultiCriterion
2123
ability_estimator::AbilityEstimatorT
2224
expected_item_information::F
2325
end
@@ -35,7 +37,8 @@ function init_thread(item_criterion::InformationMatrixCriteria,
3537
responses_information(responses.item_bank, responses.responses, ability)
3638
end
3739

38-
function (item_criterion::InformationMatrixCriteria)(acc_info::Matrix{Float64},
40+
function compute_multi_criterion(
41+
item_criterion::InformationMatrixCriteria, acc_info::Matrix{Float64},
3942
tracked_responses::TrackedResponses,
4043
item_idx)
4144
# TODO: Add in information from the prior

src/next_item_rules/criteria/item/urry.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@ function raw_difficulty(item_bank, item_idx)
1414
item_params(item_bank, item_idx).difficulty
1515
end
1616

17-
function (item_criterion::UrryItemCriterion)(tracked_responses::TrackedResponses, item_idx)
17+
function compute_criterion(
18+
item_criterion::UrryItemCriterion, tracked_responses::TrackedResponses, item_idx)
1819
ability = maybe_tracked_ability_estimate(tracked_responses,
1920
item_criterion.ability_estimator)
2021
diff = raw_difficulty(tracked_responses.item_bank, item_idx)

src/next_item_rules/criteria/state/ability_variance.jl

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -36,18 +36,20 @@ function AbilityVarianceStateCriterion(bits...)
3636
return AbilityVarianceStateCriterion(dist_est, integrator, skip_zero)
3737
end
3838

39-
function (criterion::AbilityVarianceStateCriterion)(tracked_responses::TrackedResponses)::Float64
39+
function compute_criterion(criterion::AbilityVarianceStateCriterion,
40+
tracked_responses::TrackedResponses)::Float64
4041
# XXX: Not sure if the estimator should come from somewhere else here
4142
denom = normdenom(criterion.integrator,
4243
criterion.dist_est,
4344
tracked_responses)
4445
if denom == 0.0 && criterion.skip_zero
4546
return Inf
4647
end
47-
criterion(DomainType(tracked_responses.item_bank), tracked_responses, denom)
48+
compute_criterion(
49+
criterion, DomainType(tracked_responses.item_bank), tracked_responses, denom)
4850
end
4951

50-
function (criterion::AbilityVarianceStateCriterion)(
52+
function compute_criterion(criterion::AbilityVarianceStateCriterion,
5153
::Union{OneDimContinuousDomain, DiscreteDomain},
5254
tracked_responses::TrackedResponses,
5355
denom)::Float64
@@ -59,9 +61,12 @@ function (criterion::AbilityVarianceStateCriterion)(
5961
)
6062
end
6163

62-
function (criterion::AbilityVarianceStateCriterion)(::Vector,
64+
function compute_criterion(
65+
criterion::AbilityVarianceStateCriterion,
66+
::Vector,
6367
tracked_responses::TrackedResponses,
64-
denom)::Float64
68+
denom
69+
)::Float64
6570
# XXX: Not quite sure about this --- is it useful, the MIRT rules cover this case
6671
mean = expectation(IntegralCoeffs.id,
6772
ndims(tracked_responses.item_bank),
@@ -77,25 +82,26 @@ function (criterion::AbilityVarianceStateCriterion)(::Vector,
7782
denom)
7883
end
7984

80-
struct AbilityCovarianceStateCriteria{
85+
struct AbilityCovarianceStateMultiCriterion{
8186
DistEstT <: DistributionAbilityEstimator,
8287
IntegratorT <: AbilityIntegrator
83-
} <: StateCriteria
88+
} <: StateMultiCriterion
8489
dist_est::DistEstT
8590
integrator::IntegratorT
8691
skip_zero::Bool
8792
end
8893

89-
function AbilityCovarianceStateCriteria(bits...)
94+
function AbilityCovarianceStateMultiCriterion(bits...)
9095
skip_zero = false
9196
@requiresome (dist_est, integrator) = _get_dist_est_and_integrator(bits...)
92-
return AbilityCovarianceStateCriteria(dist_est, integrator, skip_zero)
97+
return AbilityCovarianceStateMultiCriterion(dist_est, integrator, skip_zero)
9398
end
9499

95100
# XXX: Should be at type level
96-
should_minimize(::AbilityCovarianceStateCriteria) = true
101+
should_minimize(::AbilityCovarianceStateMultiCriterion) = true
97102

98-
function (criteria::AbilityCovarianceStateCriteria)(
103+
function compute_multi_criterion(
104+
criteria::AbilityCovarianceStateMultiCriterion,
99105
tracked_responses::TrackedResponses,
100106
denom = normdenom(criteria.integrator,
101107
criteria.dist_est,

src/next_item_rules/prelude/abstract.jl

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,22 +4,32 @@ $(TYPEDEF)
44
55
Abstract base type for all item selection rules. All descendants of this type
66
are expected to implement the interface
7-
`(rule::NextItemRule)(responses::TrackedResponses, items::AbstractItemBank)::Int`
7+
`(::NextItemRule)(responses::TrackedResponses, items::AbstractItemBank)::Int`.
8+
9+
In practice, all adaptive rules in this package use `ItemStrategyNextItemRule`.
810
911
$(FUNCTIONNAME)(bits...; ability_estimator=nothing, parallel=true)
1012
1113
Implicit constructor for $(FUNCTIONNAME). Uses any given `NextItemRule` or
12-
delegates to `ItemStrategyNextItemRule`.
14+
delegates to `ItemStrategyNextItemRule` the default instance.
1315
"""
1416
abstract type NextItemRule <: CatConfigBase end
1517

1618
"""
1719
$(TYPEDEF)
20+
21+
Abstract type for next item strategies, tightly coupled with `ItemStrategyNextItemRule`.
22+
All descendants of this type are expected to implement the interface
23+
`(rule::ItemStrategyNextItemRule{::NextItemStrategy, ::ItemCriterion})(responses::TrackedResponses,
24+
items) where {ItemCriterionT <: }
25+
`(strategy::NextItemStrategy)(; parallel=true)::NextItemStrategy`
1826
"""
1927
abstract type NextItemStrategy <: CatConfigBase end
2028

2129
"""
2230
$(TYPEDEF)
31+
32+
Abstract type for next item criteria
2333
"""
2434
abstract type ItemCriterion <: CatConfigBase end
2535

@@ -28,6 +38,13 @@ $(TYPEDEF)
2838
"""
2939
abstract type StateCriterion <: CatConfigBase end
3040

41+
"""
42+
$(TYPEDEF)
43+
"""
44+
abstract type PointwiseItemCriterion <: CatConfigBase end
45+
46+
abstract type PurePointwiseItemCriterion <: PointwiseItemCriterion end
47+
3148
abstract type MatrixScalarizer end
32-
abstract type StateCriteria end
33-
abstract type ItemCriteria end
49+
abstract type StateMultiCriterion end
50+
abstract type ItemMultiCriterion end

0 commit comments

Comments
 (0)