Skip to content

Commit 549b7fb

Browse files
5hv5hvnk, michaelosthege, twiecki
authored
Added model_builder as directed in PR #6023 on pymc (#64)
* added model_builder * added explanation * added tests * formatting * change in save and load method * updated save and load methods * fixed more errors * fixed path variable in save and load * Documentation * Documentation * fixed formatting and tests * fixed docstring * fixed minor issues * fixed minor issues * Update pymc_experimental/model_builder.py Co-authored-by: Michael Osthege <michael.osthege@outlook.com> * Update pymc_experimental/model_builder.py Co-authored-by: Michael Osthege <michael.osthege@outlook.com> * fixed spelling errors * Update pymc_experimental/model_builder.py Co-authored-by: Thomas Wiecki <thomas.wiecki@gmail.com> * Update pymc_experimental/model_builder.py * Update pymc_experimental/model_builder.py * added build method again * Update pymc_experimental/tests/test_model_builder.py * removed unecessary imports * changed arviz to az * fixed codes to pass build tests * linespace -> linspace * updated test_model_builder.py * updated model_builder.py * fixed overloading of test_fit() * fixed indentation in docstring * added some better examples * fixed test.yml * indetation * Apply suggestions from code review * Update pymc_experimental/model_builder.py * Update pymc_experimental/model_builder.py Co-authored-by: Michael Osthege <michael.osthege@outlook.com> Co-authored-by: Thomas Wiecki <thomas.wiecki@gmail.com>
1 parent 14e2406 commit 549b7fb

File tree

3 files changed

+433
-1
lines changed

3 files changed

+433
-1
lines changed

.github/workflows/test.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ jobs:
131131
# The ">-" in the next line replaces newlines with spaces (see https://stackoverflow.com/a/66809682).
132132
run: >-
133133
conda activate pymc-test-py37 &&
134-
python -m pytest -vv --cov=pymc_experimental --doctest-modules pymc_experimental --cov-append --cov-report=xml --cov-report term --durations=50 %TEST_SUBSET%
134+
python -m pytest -vv --cov=pymc_experimental --cov-append --cov-report=xml --cov-report term --durations=50 %TEST_SUBSET%
135135
- name: Upload coverage to Codecov
136136
uses: codecov/codecov-action@v2
137137
with:

pymc_experimental/model_builder.py

Lines changed: 310 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,310 @@
1+
import hashlib
2+
from pathlib import Path
3+
from typing import Dict, Union
4+
5+
import arviz as az
6+
import numpy as np
7+
import pandas as pd
8+
import pymc as pm
9+
10+
11+
class ModelBuilder(pm.Model):
    """
    ModelBuilder can be used to provide an easy-to-use API (similar to scikit-learn) for models
    and help with deployment.

    Extends the pymc.Model class.
    """

    # Subclasses should override both of these to identify the concrete model;
    # they are folded into the hash returned by id().
    _model_type = "BaseClass"
    version = "None"

    def __init__(
        self,
        model_config: Dict,
        sampler_config: Dict,
        data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None,
    ):
        """
        Initializes model configuration and sampler configuration for the model.

        Parameters
        ----------
        model_config : Dict
            Dictionary of parameters that initialise model configuration.
            Generated by the user-defined create_sample_input method.
        sampler_config : Dict
            Dictionary of parameters that initialise sampler configuration.
            Generated by the user-defined create_sample_input method.
        data : Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]], optional
            The data we need to train the model on.

        Examples
        --------
        >>> class LinearModel(ModelBuilder):
        >>>     ...
        >>> model = LinearModel(model_config, sampler_config)
        """
        super().__init__()
        self.model_config = model_config  # parameters for priors etc.
        self.sample_config = sampler_config  # parameters for sampling
        self.idata = None  # inference data object, populated by fit()
        self.data = data
        # Build immediately so the user-defined RVs are registered on this model.
        self.build()

    def build(self):
        """
        Builds the defined model.

        Calls the user-supplied ``build_model`` (defined in the subclass)
        inside this model's context so RVs are registered on ``self``.
        """
        with self:
            self.build_model(self.model_config, self.data)

    def _data_setter(
        self, data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]], x_only: bool = True
    ):
        """
        Sets new data in the model.

        Needs to be implemented by the user in the inherited class.

        Parameters
        ----------
        data : Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]]
            It is the data we need to set as idata for the model.
        x_only : bool
            If data only contains values of x and y is not present in the data.

        Examples
        --------
        >>> def _data_setter(self, data : pd.DataFrame):
        >>>     with self.model:
        >>>         pm.set_data({'x': data['input'].values})
        >>>         try: # if y values in new data
        >>>             pm.set_data({'y_data': data['output'].values})
        >>>         except: # dummies otherwise
        >>>             pm.set_data({'y_data': np.zeros(len(data))})
        """
        raise NotImplementedError

    @classmethod
    def create_sample_input(cls):
        """
        Needs to be implemented by the user in the inherited class.
        Returns examples for data, model_config, sampler_config.
        This is useful for understanding the required
        data structures for the user model.

        Examples
        --------
        >>> @classmethod
        >>> def create_sample_input(cls):
        >>>     x = np.linspace(start=1, stop=50, num=100)
        >>>     y = 5 * x + 3 + np.random.normal(0, 1, len(x)) * np.random.rand(100)*10 + np.random.rand(100)*6.4
        >>>     data = pd.DataFrame({'input': x, 'output': y})

        >>>     model_config = {
        >>>         'a_loc': 7,
        >>>         'a_scale': 3,
        >>>         'b_loc': 5,
        >>>         'b_scale': 3,
        >>>         'obs_error': 2,
        >>>     }

        >>>     sampler_config = {
        >>>         'draws': 1_000,
        >>>         'tune': 1_000,
        >>>         'chains': 1,
        >>>         'target_accept': 0.95,
        >>>     }
        >>>     return data, model_config, sampler_config
        """
        raise NotImplementedError

    def save(self, fname):
        """
        Saves inference data of the model.

        Parameters
        ----------
        fname : string
            This denotes the name with path from where idata should be saved.

        Examples
        --------
        >>> class LinearModel(ModelBuilder):
        >>>     ...
        >>> data, model_config, sampler_config = LinearModel.create_sample_input()
        >>> model = LinearModel(model_config, sampler_config)
        >>> idata = model.fit(data)
        >>> name = './mymodel.nc'
        >>> model.save(name)
        """
        file = Path(str(fname))
        self.idata.to_netcdf(file)

    @classmethod
    def load(cls, fname):
        """
        Loads inference data for the model.

        Parameters
        ----------
        fname : string
            This denotes the name with path from where idata should be loaded from.

        Returns
        -------
        Returns the inference data that is loaded from local system.

        Raises
        ------
        ValueError
            If the inference data that is loaded doesn't match with the model.

        Examples
        --------
        >>> class LinearModel(ModelBuilder):
        >>>     ...
        >>> name = './mymodel.nc'
        >>> imported_model = LinearModel.load(name)
        """
        filepath = Path(str(fname))
        data = az.from_netcdf(filepath)
        idata = data
        # Since there is an issue with attrs getting saved in netcdf format which
        # will be fixed in the future, the following part of the code is commented.
        # Link of issue -> https://github.com/arviz-devs/arviz/issues/2109
        # if model.idata.attrs is not None:
        #     if model.idata.attrs['id'] == self.idata.attrs['id']:
        #         self = model
        #         self.idata = data
        #         return self
        #     else:
        #         raise ValueError(
        #             f"The route '{file}' does not contain an inference data of the same model '{self.__name__}'"
        #         )
        return idata

    def fit(self, data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None):
        """
        As the name suggests fit can be used to fit a model using the data that is passed as a parameter.
        Sets attrs to inference data of the model.

        Parameters
        ----------
        data : Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]], optional
            It is the data we need to train the model on.

        Returns
        -------
        Returns inference data of the fitted model.

        Examples
        --------
        >>> data, model_config, sampler_config = LinearModel.create_sample_input()
        >>> model = LinearModel(model_config, sampler_config)
        >>> idata = model.fit(data)
        Auto-assigning NUTS sampler...
        Initializing NUTS using jitter+adapt_diag...
        """
        # New data replaces whatever the model was initialised with.
        if data is not None:
            self.data = data
            self._data_setter(data)

        # Rebuild if no random variables were registered yet
        # (e.g. the model was constructed without data).
        if not self.basic_RVs:
            self.build()

        with self:
            self.idata = pm.sample(**self.sample_config)
            self.idata.extend(pm.sample_prior_predictive())
            self.idata.extend(pm.sample_posterior_predictive(self.idata))

        # Record identifying metadata so a saved idata can be matched
        # back to the model/configuration that produced it.
        self.idata.attrs["id"] = self.id()
        self.idata.attrs["model_type"] = self._model_type
        self.idata.attrs["version"] = self.version
        # NOTE: key was previously misspelled as "sample_conifg".
        self.idata.attrs["sampler_config"] = self.sample_config
        self.idata.attrs["model_config"] = self.model_config
        return self.idata

    def predict(
        self,
        data_prediction: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None,
        point_estimate: bool = True,
    ):
        """
        Uses model to predict on unseen data.

        Parameters
        ----------
        data_prediction : Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]], optional
            It is the data we need to make prediction on using the model.
        point_estimate : bool
            Adds point like estimate used as mean passed as

        Returns
        -------
        Returns dictionary of sample's posterior predict.

        Examples
        --------
        >>> data, model_config, sampler_config = LinearModel.create_sample_input()
        >>> model = LinearModel(model_config, sampler_config)
        >>> idata = model.fit(data)
        >>> x_pred = []
        >>> prediction_data = pd.DataFrame({'input': x_pred})
        >>> # only point estimate
        >>> pred_mean = model.predict(prediction_data)
        >>> # samples
        >>> pred_samples = model.predict(prediction_data, point_estimate=False)
        """
        if data_prediction is not None:  # set new input data
            self._data_setter(data_prediction)

        # Use `with self:` for consistency with build()/fit() — this model
        # *is* the pymc model context.
        with self:  # sample with new input data
            post_pred = pm.sample_posterior_predictive(self.idata)

        # reshape output
        post_pred = self._extract_samples(post_pred)
        if point_estimate:  # average, if point-like estimate desired
            for key in post_pred:
                post_pred[key] = post_pred[key].mean(axis=0)

        return post_pred

    @staticmethod
    def _extract_samples(post_pred: az.data.inference_data.InferenceData) -> Dict[str, np.ndarray]:
        """
        This method can be used to extract samples from posterior predict.

        Parameters
        ----------
        post_pred : arviz InferenceData object

        Returns
        -------
        Dictionary of numpy arrays from InferenceData object
        """
        # [0] drops the chain dimension; one array of draws per variable.
        return {
            key: post_pred.posterior_predictive[key].to_numpy()[0]
            for key in post_pred.posterior_predictive
        }

    def id(self):
        """
        It creates a hash value to match the model version using last 16 characters of hash encoding.

        Returns
        -------
        Returns string of length 16 characters that contains a unique hash of the model.
        """
        hasher = hashlib.sha256()
        # Fold every identifying piece of configuration into the hash so any
        # change in config, version or model type yields a different id.
        hasher.update(str(self.model_config.values()).encode())
        hasher.update(self.version.encode())
        hasher.update(self._model_type.encode())
        hasher.update(str(self.sample_config.values()).encode())
        return hasher.hexdigest()[:16]

0 commit comments

Comments
 (0)