@@ -40,6 +40,24 @@ def toy_y(toy_X):
40
40
return y
41
41
42
42
43
+ @pytest .fixture (scope = "module" )
44
+ def fitted_model_instance (toy_X , toy_y ):
45
+ sampler_config = {
46
+ "draws" : 500 ,
47
+ "tune" : 300 ,
48
+ "chains" : 2 ,
49
+ "target_accept" : 0.95 ,
50
+ }
51
+ model_config = {
52
+ "a" : {"loc" : 0 , "scale" : 10 },
53
+ "b" : {"loc" : 0 , "scale" : 10 },
54
+ "obs_error" : 2 ,
55
+ }
56
+ model = test_ModelBuilder (model_config = model_config , sampler_config = sampler_config )
57
+ model .fit (toy_X )
58
+ return model
59
+
60
+
43
61
class test_ModelBuilder (ModelBuilder ):
44
62
45
63
_model_type = "LinearModel"
@@ -103,16 +121,11 @@ def default_sampler_config(self) -> Dict:
103
121
"target_accept" : 0.95 ,
104
122
}
105
123
106
- @staticmethod
107
- def initial_build_and_fit (toy_X , toy_y , check_idata = True ) -> ModelBuilder :
108
- model_builder = test_ModelBuilder ()
109
- model_builder .idata = model_builder .fit (
110
- toy_X , toy_y , predictor_names = ["input" ], random_seed = 1234
111
- )
112
- if check_idata :
113
- assert model_builder .idata is not None
114
- assert "posterior" in model_builder .idata .groups ()
115
- return model_builder
124
+
125
+ def test_initial_build_and_fit (fitted_model_instance , check_idata = True ) -> ModelBuilder :
126
+ if check_idata :
127
+ assert fitted_model_instance .idata is not None
128
+ assert "posterior" in fitted_model_instance .idata .groups ()
116
129
117
130
118
131
def test_save_without_fit_raises_runtime_error ():
@@ -129,59 +142,62 @@ def test_empty_sampler_config_fit(toy_X, toy_y):
129
142
assert "posterior" in model_builder .idata .groups ()
130
143
131
144
132
- def test_fit (toy_X , toy_y ):
133
- model_builder = test_ModelBuilder .initial_build_and_fit (toy_X , toy_y )
134
- x_pred = np .random .uniform (low = 0 , high = 1 , size = 100 )
135
- prediction_data = pd .DataFrame ({"input" : x_pred })
136
- pred = model_builder .predict (prediction_data ["input" ])
137
- post_pred = model_builder .sample_posterior_predictive (
145
+ def test_fit (fitted_model_instance ):
146
+ prediction_data = pd .DataFrame ({"input" : np .random .uniform (low = 0 , high = 1 , size = 100 )})
147
+ pred = fitted_model_instance .predict (prediction_data ["input" ])
148
+ post_pred = fitted_model_instance .sample_posterior_predictive (
138
149
prediction_data ["input" ], extend_idata = True , combined = True
139
150
)
140
- post_pred [model_builder .output_var ].shape [0 ] == prediction_data .input .shape
151
+ post_pred [fitted_model_instance .output_var ].shape [0 ] == prediction_data .input .shape
152
+
153
+
154
+ def test_fit_no_y (toy_X ):
155
+ model_builder = test_ModelBuilder ()
156
+ model_builder .idata = model_builder .fit (X = toy_X )
157
+ assert model_builder .model is not None
158
+ assert model_builder .idata is not None
159
+ assert "posterior" in model_builder .idata .groups ()
141
160
142
161
143
162
@pytest .mark .skipif (
144
163
sys .platform == "win32" , reason = "Permissions for temp files not granted on windows CI."
145
164
)
146
- def test_save_load (toy_X , toy_y ):
147
- test_builder = test_ModelBuilder .initial_build_and_fit (toy_X , toy_y )
165
+ def test_save_load (fitted_model_instance ):
148
166
temp = tempfile .NamedTemporaryFile (mode = "w" , encoding = "utf-8" , delete = False )
149
- test_builder .save (temp .name )
167
+ fitted_model_instance .save (temp .name )
150
168
test_builder2 = test_ModelBuilder .load (temp .name )
151
- assert test_builder .idata .groups () == test_builder2 .idata .groups ()
169
+ assert fitted_model_instance .idata .groups () == test_builder2 .idata .groups ()
152
170
153
171
x_pred = np .random .uniform (low = 0 , high = 1 , size = 100 )
154
172
prediction_data = pd .DataFrame ({"input" : x_pred })
155
- pred1 = test_builder .predict (prediction_data ["input" ])
173
+ pred1 = fitted_model_instance .predict (prediction_data ["input" ])
156
174
pred2 = test_builder2 .predict (prediction_data ["input" ])
157
175
assert pred1 .shape == pred2 .shape
158
176
temp .close ()
159
177
160
178
161
- def test_predict (toy_X , toy_y ):
162
- model_builder = test_ModelBuilder .initial_build_and_fit (toy_X , toy_y )
179
+ def test_predict (fitted_model_instance ):
163
180
x_pred = np .random .uniform (low = 0 , high = 1 , size = 100 )
164
181
prediction_data = pd .DataFrame ({"input" : x_pred })
165
- pred = model_builder .predict (prediction_data ["input" ])
182
+ pred = fitted_model_instance .predict (prediction_data ["input" ])
166
183
# Perform elementwise comparison using numpy
167
184
assert type (pred ) == np .ndarray
168
185
assert len (pred ) > 0
169
186
170
187
171
188
@pytest .mark .parametrize ("combined" , [True , False ])
172
- def test_sample_posterior_predictive (toy_X , toy_y , combined ):
173
- model_builder = test_ModelBuilder .initial_build_and_fit (toy_X , toy_y )
189
+ def test_sample_posterior_predictive (fitted_model_instance , combined ):
174
190
n_pred = 100
175
191
x_pred = np .random .uniform (low = 0 , high = 1 , size = n_pred )
176
192
prediction_data = pd .DataFrame ({"input" : x_pred })
177
- pred = model_builder .sample_posterior_predictive (
193
+ pred = fitted_model_instance .sample_posterior_predictive (
178
194
prediction_data ["input" ], combined = combined , extend_idata = True
179
195
)
180
- chains = model_builder .idata .sample_stats .dims ["chain" ]
181
- draws = model_builder .idata .sample_stats .dims ["draw" ]
196
+ chains = fitted_model_instance .idata .sample_stats .dims ["chain" ]
197
+ draws = fitted_model_instance .idata .sample_stats .dims ["draw" ]
182
198
expected_shape = (n_pred , chains * draws ) if combined else (chains , draws , n_pred )
183
- assert pred [model_builder .output_var ].shape == expected_shape
184
- assert np .issubdtype (pred [model_builder .output_var ].dtype , np .floating )
199
+ assert pred [fitted_model_instance .output_var ].shape == expected_shape
200
+ assert np .issubdtype (pred [fitted_model_instance .output_var ].dtype , np .floating )
185
201
186
202
187
203
def test_id ():
0 commit comments