@@ -4,7 +4,7 @@ module Comparison
4
4
# Should be kept in mind and kept distinct or code reuse
5
5
6
6
using StatsBase
7
- using FittedItemBanks: AbstractItemBank, ResponseType
7
+ using FittedItemBanks: AbstractItemBank, ResponseType, subset
8
8
using .. Responses
9
9
using .. CatConfig: CatLoopConfig, CatRules
10
10
using .. Aggregators: TrackedResponses, add_response!, Speculator, Aggregators, track!,
@@ -14,11 +14,11 @@ using Base: Iterators
14
14
15
15
using HypothesisTests
16
16
using EffectSizes
17
- using DataFrames
17
+ using DataFrames: DataFrame
18
18
using ComputerAdaptiveTesting: Stateful
19
19
20
20
export run_random_comparison, run_comparison
21
- export CatComparisonExecutionStrategy# , IncreaseItemBankSizeExecutionStrategy
21
+ export CatComparisonExecutionStrategy, IncreaseItemBankSizeExecutionStrategy
22
22
# export FollowOneExecutionStrategy, RunIndependentlyExecutionStrategy
23
23
# export DecisionTreeExecutionStrategy
24
24
export ReplayResponsesExecutionStrategy
83
83
84
84
abstract type CatComparisonExecutionStrategy end
85
85
86
- Base. @kwdef struct CatComparisonConfig{StrategyT <: CatComparisonExecutionStrategy }
86
+ struct CatComparisonConfig{
87
+ StrategyT <: CatComparisonExecutionStrategy , PhasesT <: NamedTuple }
87
88
"""
88
89
A named tuple with the (named) CatRules (or compatible) to be compared
89
90
"""
@@ -99,13 +100,42 @@ Base.@kwdef struct CatComparisonConfig{StrategyT <: CatComparisonExecutionStrate
99
100
measurements::Vector{}
100
101
=#
101
102
"""
102
- Which phases to run and/or call the callback on
103
+ The phases to run, optionally paired with a callback
103
104
"""
104
- phases:: Set{Symbol} = Set ((:before_next_item , :after_next_item ))
105
- """
106
- The callback which should take a named tuple with information at different phases
107
- """
108
- callback:: Any
105
+ phases:: PhasesT
106
+ end
107
+
108
+ """
109
+ CatComparisonConfig(;
110
+ rules::NamedTuple{Symbol, StatefulCat},
111
+ strategy::CatComparisonExecutionStrategy,
112
+ phases::Union{NamedTuple{Symbol, Callable}, Tuple{Symbol}},
113
+ callback::Callable
114
+ ) -> CatComparisonConfig
115
+
116
+ CatComparisonConfig sets up an evaluation-oriented comparison between different CAT systems.
117
+
118
+ Specify the comparison by listing: CAT systems in `rules`, a `NamedTuple` which gives
119
+ identifiers to implementations of the `StatefulCat` interface; the `strategy` to use,
120
+ an implementation of `CatComparisonExecutionStrategy`; the `phases` to run, listed
121
+ either as a `NamedTuple` of phase names and corresponding callbacks, as a `Tuple` of
122
+ phase names, or as `nothing` for the default phases; and a `callback` which will be
123
+ used as a fallback for phases that are not given a callback of their own.
124
+
125
+ The exact phases depend on the strategy used. See their individual documentation for more.
126
+ """
127
+ function CatComparisonConfig (; rules, strategy, phases = nothing , callback = nothing )
128
+ if callback === nothing
129
+ callback = (info; kwargs... ) -> nothing
130
+ end
131
+ if phases === nothing
132
+ phases = (:before_next_item , :after_next_item )
133
+ end
134
+ # TODO : normalize phases into named tuple
135
+ if ! (phases isa NamedTuple)
136
+ phases = NamedTuple ((phase => callback for phase in phases))
137
+ end
138
+ CatComparisonConfig (rules, strategy, phases)
109
139
end
110
140
111
141
# Comparison scenarios:
129
159
130
160
# phase_func=nothing;
131
161
function measure_all (comparison, system, cat, phase; kwargs... )
132
- if ! (phase in comparison. phases)
162
+ @info " measure_all" phase comparison. phases
163
+ if ! (phase in keys (comparison. phases))
133
164
return
134
165
end
166
+ callback = comparison. phases[phase]
135
167
strategy = comparison. strategy
136
168
#= measurement_results = []
137
169
for measurement in comparison.measurements
@@ -145,7 +177,7 @@ function measure_all(comparison, system, cat, phase; kwargs...)
145
177
#end
146
178
push!(measurement_results, result)
147
179
end=#
148
- comparison . callback ((;
180
+ callback ((;
149
181
phase,
150
182
system,
151
183
cat,
@@ -158,30 +190,56 @@ struct IncreaseItemBankSizeExecutionStrategy <: CatComparisonExecutionStrategy
158
190
item_bank:: AbstractItemBank
159
191
sizes:: AbstractVector{Int}
160
192
starting_responses:: Int
193
+ shuffle:: Bool
194
+ time_limit:: Float64
195
+
196
+ function IncreaseItemBankSizeExecutionStrategy (item_bank, sizes, args... )
197
+ if any ((size > length (item_bank) for size in sizes))
198
+ error (" IncreaseItemBankSizeExecutionStrategy: No subset size can be greater than the number of items available in the item bank" )
199
+ end
200
+ new (item_bank, sizes, args... )
201
+ end
161
202
end
162
203
163
204
function IncreaseItemBankSizeExecutionStrategy (item_bank, sizes)
164
- return IncreaseItemBankSizeExecutionStrategy (item_bank, sizes, 0 )
205
+ return IncreaseItemBankSizeExecutionStrategy (item_bank, sizes, 0 , false , Inf )
165
206
end
166
207
167
- function run_comparison (strategy:: IncreaseItemBankSizeExecutionStrategy , config)
208
+ function run_comparison (comparison:: CatComparisonConfig{IncreaseItemBankSizeExecutionStrategy} )
209
+ strategy = comparison. strategy
210
+ current_cats = collect (pairs (comparison. rules))
211
+ next_current_cats = copy (current_cats)
212
+ @info " sizes" strategy. sizes
168
213
for size in strategy. sizes
169
- subsetted_item_bank = subset (strategy. item_bank, size)
170
- responses = TrackedResponses (
171
- BareResponses (ResponseType (strategy. item_bank)),
172
- subsetted_item_bank,
173
- config. ability_tracker
174
- )
175
- for _ in 1 : (strategy. starting_responses)
176
- next_item = config. next_item (responses, subsetted_item_bank)
177
- add_response! (responses,
178
- Response (ResponseType (subsetted_item_bank), next_item, rand (Bool)))
214
+ subsetted_item_bank = subset (strategy. item_bank, 1 : size)
215
+ empty! (next_current_cats)
216
+ for (name, cat) in current_cats
217
+ Stateful. set_item_bank! (cat, subsetted_item_bank)
218
+ for _ in 1 : (strategy. starting_responses)
219
+ Stateful. next_item (cat)
220
+ end
221
+ measure_all (
222
+ comparison,
223
+ name,
224
+ cat,
225
+ :before_next_item
226
+ )
227
+ timed_next_item = @timed Stateful. next_item (cat)
228
+ next_item = timed_next_item. value
229
+ measure_all (
230
+ comparison,
231
+ name,
232
+ cat,
233
+ :after_next_item ,
234
+ next_item = next_item,
235
+ timing = timed_next_item
236
+ )
237
+ @info " next_item" timed_next_item. time strategy. time_limit
238
+ if timed_next_item. time < strategy. time_limit
239
+ push! (next_current_cats, name => cat)
240
+ end
179
241
end
180
- measure_all (config, :before_next_item , before_next_item; responses = responses)
181
- timed_next_item = @timed config. next_item (responses, item_bank)
182
- next_item = timed_next_item. value
183
- measure_all (config, :after_next_item , after_next_item;
184
- responses = responses, next_item = next_item)
242
+ current_cats, next_current_cats = next_current_cats, current_cats
185
243
end
186
244
end
187
245
0 commit comments