Skip to content

Commit a7c4d79

Browse files
making y optional, adding fitted model fixture to make tests more efficient
1 parent b2a9f3f commit a7c4d79

File tree

2 files changed

+51
-34
lines changed

2 files changed

+51
-34
lines changed

pymc_experimental/model_builder.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -424,7 +424,7 @@ def load(cls, fname: str):
424424
def fit(
425425
self,
426426
X: pd.DataFrame,
427-
y: pd.Series,
427+
y: Optional[pd.Series] = None,
428428
progressbar: bool = True,
429429
predictor_names: List[str] = None,
430430
random_seed: RandomState = None,
@@ -464,7 +464,8 @@ def fit(
464464
"""
465465
if predictor_names is None:
466466
predictor_names = []
467-
467+
if y is None:
468+
y = np.zeros(X.shape[0])
468469
y = pd.DataFrame({self.output_var: y})
469470
self.generate_and_preprocess_model_data(X, y.values.flatten())
470471
self.build_model(self.X, self.y)

pymc_experimental/tests/test_model_builder.py

Lines changed: 48 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,24 @@ def toy_y(toy_X):
4040
return y
4141

4242

43+
@pytest.fixture(scope="module")
def fitted_model_instance(toy_X, toy_y):
    """Fit one ``test_ModelBuilder`` and share it across the tests in this module.

    The fixture is module-scoped so that ``fit`` runs once and every test
    requesting the fixture reuses the same fitted instance.

    Parameters
    ----------
    toy_X
        Predictor fixture handed to ``fit``.
    toy_y
        Target fixture handed to ``fit``.

    Returns
    -------
    test_ModelBuilder
        The model instance after ``fit`` has been called on it.
    """
    sampler_config = {
        "draws": 500,
        "tune": 300,
        "chains": 2,
        "target_accept": 0.95,
    }
    model_config = {
        "a": {"loc": 0, "scale": 10},
        "b": {"loc": 0, "scale": 10},
        "obs_error": 2,
    }
    model = test_ModelBuilder(model_config=model_config, sampler_config=sampler_config)
    # Pass the targets explicitly: when ``y`` is omitted, ``fit`` substitutes an
    # all-zero target, which is exercised separately by ``test_fit_no_y``.
    model.fit(toy_X, toy_y)
    return model
59+
60+
4361
class test_ModelBuilder(ModelBuilder):
4462

4563
_model_type = "LinearModel"
@@ -103,16 +121,11 @@ def default_sampler_config(self) -> Dict:
103121
"target_accept": 0.95,
104122
}
105123

106-
@staticmethod
107-
def initial_build_and_fit(toy_X, toy_y, check_idata=True) -> ModelBuilder:
108-
model_builder = test_ModelBuilder()
109-
model_builder.idata = model_builder.fit(
110-
toy_X, toy_y, predictor_names=["input"], random_seed=1234
111-
)
112-
if check_idata:
113-
assert model_builder.idata is not None
114-
assert "posterior" in model_builder.idata.groups()
115-
return model_builder
124+
125+
def test_initial_build_and_fit(fitted_model_instance, check_idata=True) -> None:
    """Check that the fitted model carries inference data with a posterior group.

    Parameters
    ----------
    fitted_model_instance
        Fitted model provided by the fixture of the same name.
    check_idata
        When false, the inference-data assertions are skipped and the test
        does nothing.
    """
    if not check_idata:
        return
    idata = fitted_model_instance.idata
    assert idata is not None
    assert "posterior" in idata.groups()
116129

117130

118131
def test_save_without_fit_raises_runtime_error():
@@ -129,59 +142,62 @@ def test_empty_sampler_config_fit(toy_X, toy_y):
129142
assert "posterior" in model_builder.idata.groups()
130143

131144

132-
def test_fit(toy_X, toy_y):
133-
model_builder = test_ModelBuilder.initial_build_and_fit(toy_X, toy_y)
134-
x_pred = np.random.uniform(low=0, high=1, size=100)
135-
prediction_data = pd.DataFrame({"input": x_pred})
136-
pred = model_builder.predict(prediction_data["input"])
137-
post_pred = model_builder.sample_posterior_predictive(
145+
def test_fit(fitted_model_instance):
    """Check that posterior-predictive output has one entry per prediction row."""
    prediction_data = pd.DataFrame({"input": np.random.uniform(low=0, high=1, size=100)})
    # Smoke check only: ``predict`` must run on the fitted model without raising.
    fitted_model_instance.predict(prediction_data["input"])
    post_pred = fitted_model_instance.sample_posterior_predictive(
        prediction_data["input"], extend_idata=True, combined=True
    )
    # Compare row counts (not a row count against a shape tuple) and actually
    # assert the result; a bare comparison expression verifies nothing.
    n_rows = prediction_data.input.shape[0]
    assert post_pred[fitted_model_instance.output_var].shape[0] == n_rows
152+
153+
154+
def test_fit_no_y(toy_X):
    """Check that ``fit`` works when only predictors are given (no ``y``)."""
    builder = test_ModelBuilder()
    builder.idata = builder.fit(X=toy_X)

    # Fitting must have produced both a model and inference data.
    assert builder.model is not None
    inference_data = builder.idata
    assert inference_data is not None
    assert "posterior" in inference_data.groups()
141160

142161

143162
@pytest.mark.skipif(
144163
sys.platform == "win32", reason="Permissions for temp files not granted on windows CI."
145164
)
146-
def test_save_load(toy_X, toy_y):
147-
test_builder = test_ModelBuilder.initial_build_and_fit(toy_X, toy_y)
165+
def test_save_load(fitted_model_instance):
    """Round-trip the fitted model through ``save``/``load`` and compare both copies.

    The saved and re-loaded models must expose the same inference-data groups
    and produce predictions of the same shape for the same inputs.
    """
    import os

    temp = tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", delete=False)
    try:
        fitted_model_instance.save(temp.name)
        test_builder2 = test_ModelBuilder.load(temp.name)
        assert fitted_model_instance.idata.groups() == test_builder2.idata.groups()

        x_pred = np.random.uniform(low=0, high=1, size=100)
        prediction_data = pd.DataFrame({"input": x_pred})
        pred1 = fitted_model_instance.predict(prediction_data["input"])
        pred2 = test_builder2.predict(prediction_data["input"])
        assert pred1.shape == pred2.shape
    finally:
        # Release the handle even when an assertion above fails.
        temp.close()
        # delete=False means the file outlives the handle; remove it explicitly
        # so repeated test runs do not accumulate temp files.
        os.remove(temp.name)
159177

160178

161-
def test_predict(toy_X, toy_y):
162-
model_builder = test_ModelBuilder.initial_build_and_fit(toy_X, toy_y)
179+
def test_predict(fitted_model_instance):
    """Check that ``predict`` returns a non-empty NumPy array for new inputs."""
    x_pred = np.random.uniform(low=0, high=1, size=100)
    prediction_data = pd.DataFrame({"input": x_pred})
    pred = fitted_model_instance.predict(prediction_data["input"])
    # isinstance is the idiomatic type check and also accepts ndarray subclasses.
    assert isinstance(pred, np.ndarray)
    assert len(pred) > 0
169186

170187

171188
@pytest.mark.parametrize("combined", [True, False])
172-
def test_sample_posterior_predictive(toy_X, toy_y, combined):
173-
model_builder = test_ModelBuilder.initial_build_and_fit(toy_X, toy_y)
189+
def test_sample_posterior_predictive(fitted_model_instance, combined):
    """Check shape and dtype of posterior-predictive samples.

    With ``combined`` true the expected shape is ``(n_pred, chains * draws)``;
    otherwise it is ``(chains, draws, n_pred)``.
    """
    n_pred = 100
    prediction_data = pd.DataFrame(
        {"input": np.random.uniform(low=0, high=1, size=n_pred)}
    )
    pred = fitted_model_instance.sample_posterior_predictive(
        prediction_data["input"], combined=combined, extend_idata=True
    )

    # Chain and draw counts come from the sampler statistics of the fitted model.
    stats_dims = fitted_model_instance.idata.sample_stats.dims
    chains, draws = stats_dims["chain"], stats_dims["draw"]
    if combined:
        expected_shape = (n_pred, chains * draws)
    else:
        expected_shape = (chains, draws, n_pred)

    samples = pred[fitted_model_instance.output_var]
    assert samples.shape == expected_shape
    assert np.issubdtype(samples.dtype, np.floating)
185201

186202

187203
def test_id():

0 commit comments

Comments
 (0)