From 5886ab20b770a577b3a3965f0272993572600474 Mon Sep 17 00:00:00 2001 From: Michal Raczycki Date: Fri, 31 Mar 2023 09:35:30 +0100 Subject: [PATCH 1/6] new branch after failed rebasing --- pymc_experimental/model_builder.py | 100 +++++++++++------- pymc_experimental/tests/test_model_builder.py | 80 ++++++++------ 2 files changed, 106 insertions(+), 74 deletions(-) diff --git a/pymc_experimental/model_builder.py b/pymc_experimental/model_builder.py index 6d2513349..191b1b1e7 100644 --- a/pymc_experimental/model_builder.py +++ b/pymc_experimental/model_builder.py @@ -15,6 +15,7 @@ import hashlib import json +from abc import abstractmethod from pathlib import Path from typing import Dict, Union @@ -24,12 +25,10 @@ import pymc as pm -class ModelBuilder(pm.Model): +class ModelBuilder: """ ModelBuilder can be used to provide an easy-to-use API (similar to scikit-learn) for models and help with deployment. - - Extends the pymc.Model class. """ _model_type = "BaseClass" @@ -38,8 +37,8 @@ class ModelBuilder(pm.Model): def __init__( self, model_config: Dict, - sampler_config: Dict, - data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None, + data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]], + sampler_config: Dict = None, ): """ Initializes model configuration and sampler configuration for the model @@ -48,10 +47,10 @@ def __init__( ---------- model_config : Dictionary dictionary of parameters that initialise model configuration. Generated by the user defined create_sample_input method. - sampler_config : Dictionary - dictionary of parameters that initialise sampler configuration. Generated by the user defined create_sample_input method. data : Dictionary It is the data we need to train the model on. + sampler_config : Dictionary + dictionary of parameters that initialise sampler configuration. Generated by the user defined create_sample_input method. Examples -------- >>> class LinearModel(ModelBuilder): @@ -60,20 +59,23 @@ def __init__( """ super().__init__() + if sampler_config is None: + sampler_config = {} self.model_config = model_config # parameters for priors etc. - self.sample_config = sampler_config # parameters for sampling - self.idata = None # inference data object + self.sampler_config = sampler_config # parameters for sampling self.data = data - self.build() + self.idata = ( + None # inference data object placeholder, idata is generated during build execution + ) - def build(self): + def build(self) -> None: """ Builds the defined model. """ - with self: - self.build_model(self.model_config, self.data) + self.build_model(self, self.model_config, self.data) + @abstractmethod def _data_setter( self, data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]], x_only: bool = True ): @@ -100,8 +102,10 @@ def _data_setter( raise NotImplementedError - @classmethod - def create_sample_input(cls): + # need a discussion if it's really needed. + @staticmethod + @abstractmethod + def create_sample_input(): """ Needs to be implemented by the user in the inherited class. Returns examples for data, model_config, sampler_config. @@ -135,7 +139,7 @@ def create_sample_input(cls): raise NotImplementedError - def save(self, fname): + def save(self, fname: str) -> None: """ Saves inference data of the model. @@ -159,8 +163,9 @@ def save(self, fname): self.idata.to_netcdf(file) @classmethod - def load(cls, fname): + def load(cls, fname: str): """ + Creates a ModelBuilder instance from a file, Loads inference data for the model. Parameters @@ -170,7 +175,7 @@ def load(cls, fname): Returns ------- - Returns the inference data that is loaded from local system. + Returns an instance of ModelBuilder. Raises ------ @@ -187,22 +192,29 @@ def load(cls, fname): filepath = Path(str(fname)) idata = az.from_netcdf(filepath) - self = cls( - json.loads(idata.attrs["model_config"]), - json.loads(idata.attrs["sampler_config"]), - idata.fit_data.to_dataframe(), + if "sampler_config" in idata.attrs: + sampler_config = json.loads(idata.attrs["sampler_config"]) + else: + sampler_config = {} + model_builder = cls( + model_config=json.loads(idata.attrs["model_config"]), + sampler_config=sampler_config, + data=idata.fit_data.to_dataframe(), ) - self.idata = idata - if self.id != idata.attrs["id"]: + model_builder.idata = idata + model_builder.build() + if model_builder.id != idata.attrs["id"]: raise ValueError( - f"The file '{fname}' does not contain an inference data of the same model or configuration as '{self._model_type}'" + f"The file '{fname}' does not contain an inference data of the same model or configuration as '{cls._model_type}'" ) - return self + return model_builder - def fit(self, data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None): + def fit( + self, data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None + ) -> az.InferenceData: """ - As the name suggests fit can be used to fit a model using the data that is passed as a parameter. + Fit a model using the data passed as a parameter. Sets attrs to inference data of the model. Parameter @@ -225,20 +237,22 @@ def fit(self, data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None if data is not None: self.data = data - self._data_setter(data) - - if self.basic_RVs == []: self.build() + self._data_setter(data) - with self: - self.idata = pm.sample(**self.sample_config) + with self.model: + if self.sampler_config: + self.idata = pm.sample(**self.sampler_config) + else: + self.idata = pm.sample() self.idata.extend(pm.sample_prior_predictive()) self.idata.extend(pm.sample_posterior_predictive(self.idata)) self.idata.attrs["id"] = self.id self.idata.attrs["model_type"] = self._model_type self.idata.attrs["version"] = self.version - self.idata.attrs["sampler_config"] = json.dumps(self.sample_config) + if self.sampler_config: + self.idata.attrs["sampler_config"] = json.dumps(self.sampler_config) self.idata.attrs["model_config"] = json.dumps(self.model_config) self.idata.add_groups(fit_data=self.data.to_xarray()) return self.idata @@ -246,7 +260,8 @@ def fit(self, data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None def predict( self, data_prediction: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None, - ): + extend_idata: bool = True, + ) -> dict: """ Uses model to predict on unseen data and return point prediction of all the samples @@ -254,6 +269,8 @@ def predict( --------- data_prediction : Dictionary of string and either of numpy array, pandas dataframe or pandas Series It is the data we need to make prediction on using the model. + extend_idata : Boolean determining whether the predictions should be added to inference data object. + Defaults to True. Returns ------- @@ -275,7 +292,8 @@ def predict( with self.model: # sample with new input data post_pred = pm.sample_posterior_predictive(self.idata) - + if extend_idata: + self.idata.extend(post_pred) # reshape output post_pred = self._extract_samples(post_pred) for key in post_pred: @@ -286,7 +304,8 @@ def predict( def predict_posterior( self, data_prediction: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] = None, - ): + extend_idata: bool = True, + ) -> Dict[str, np.array]: """ Uses model to predict samples on unseen data. @@ -294,8 +313,8 @@ def predict_posterior( --------- data_prediction : Dictionary of string and either of numpy array, pandas dataframe or pandas Series It is the data we need to make prediction on using the model. - point_estimate : bool - Adds point like estimate used as mean passed as + extend_idata : Boolean determining whether the predictions should be added to inference data object. + Defaults to True. Returns ------- @@ -317,6 +336,8 @@ def predict_posterior( with self.model: # sample with new input data post_pred = pm.sample_posterior_predictive(self.idata) + if extend_idata: + self.idata.extend(post_pred) # reshape output post_pred = self._extract_samples(post_pred) @@ -357,5 +378,4 @@ def id(self) -> str: hasher.update(str(self.model_config.values()).encode()) hasher.update(self.version.encode()) hasher.update(self._model_type.encode()) - # hasher.update(str(self.sample_config.values()).encode()) return hasher.hexdigest()[:16] diff --git a/pymc_experimental/tests/test_model_builder.py b/pymc_experimental/tests/test_model_builder.py index 1dd67e621..21c6890d2 100644 --- a/pymc_experimental/tests/test_model_builder.py +++ b/pymc_experimental/tests/test_model_builder.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. - import hashlib import sys import tempfile @@ -29,26 +28,39 @@ class test_ModelBuilder(ModelBuilder): _model_type = "LinearModel" version = "0.1" - def build_model(self, model_config, data=None): - if data is not None: - x = pm.MutableData("x", data["input"].values) - y_data = pm.MutableData("y_data", data["output"].values) - - # prior parameters - a_loc = model_config["a_loc"] - a_scale = model_config["a_scale"] - b_loc = model_config["b_loc"] - b_scale = model_config["b_scale"] - obs_error = model_config["obs_error"] - - # priors - a = pm.Normal("a", a_loc, sigma=a_scale) - b = pm.Normal("b", b_loc, sigma=b_scale) - obs_error = pm.HalfNormal("σ_model_fmc", obs_error) - - # observed data - if data is not None: - y_model = pm.Normal("y_model", a + b * x, obs_error, shape=x.shape, observed=y_data) + def build_model( + self, + model_instance: ModelBuilder, + model_config: dict, + data: dict = None, + sampler_config: dict = None, + ): + model_instance.model_config = model_config + model_instance.data = data + self.model_config = model_config + self.sampler_config = sampler_config + self.data = data + + with pm.Model() as model_instance.model: + if data is not None: + x = pm.MutableData("x", data["input"].values) + y_data = pm.MutableData("y_data", data["output"].values) + + # prior parameters + a_loc = model_config["a_loc"] + a_scale = model_config["a_scale"] + b_loc = model_config["b_loc"] + b_scale = model_config["b_scale"] + obs_error = model_config["obs_error"] + + # priors + a = pm.Normal("a", a_loc, sigma=a_scale) + b = pm.Normal("b", b_loc, sigma=b_scale) + obs_error = pm.HalfNormal("σ_model_fmc", obs_error) + + # observed data + if data is not None: + y_model = pm.Normal("y_model", a + b * x, obs_error, shape=x.shape, observed=y_data) def _data_setter(self, data: pd.DataFrame): with self.model: @@ -57,7 +69,7 @@ def _data_setter(self, data: pd.DataFrame): pm.set_data({"y_data": data["output"].values}) @classmethod - def create_sample_input(cls): + def create_sample_input(self): x = np.linspace(start=0, stop=1, num=100) y = 5 * x + 3 y = y + np.random.normal(0, 1, len(x)) @@ -81,14 +93,14 @@ def create_sample_input(cls): return data, model_config, sampler_config @staticmethod - def initial_build_and_fit(check_idata=True): + def initial_build_and_fit(check_idata=True) -> ModelBuilder: data, model_config, sampler_config = test_ModelBuilder.create_sample_input() - model = test_ModelBuilder(model_config, sampler_config, data) - model.fit() + model_builder = test_ModelBuilder(model_config, sampler_config, data) + model_builder.idata = model_builder.fit(data=data) if check_idata: - assert model.idata is not None - assert "posterior" in model.idata.groups() - return model + assert model_builder.idata is not None + assert "posterior" in model_builder.idata.groups() + return model_builder def test_fit(): @@ -105,16 +117,16 @@ def test_fit(): sys.platform == "win32", reason="Permissions for temp files not granted on windows CI." ) def test_save_load(): - model = test_ModelBuilder.initial_build_and_fit(False) + test_builder = test_ModelBuilder.initial_build_and_fit() temp = tempfile.NamedTemporaryFile(mode="w", encoding="utf-8", delete=False) - model.save(temp.name) - model2 = test_ModelBuilder.load(temp.name) - assert model.idata.groups() == model2.idata.groups() + test_builder.save(temp.name) + test_builder2 = test_ModelBuilder.load(temp.name) + assert test_builder.idata.groups() == test_builder2.idata.groups() x_pred = np.random.uniform(low=0, high=1, size=100) prediction_data = pd.DataFrame({"input": x_pred}) - pred1 = model.predict(prediction_data) - pred2 = model2.predict(prediction_data) + pred1 = test_builder.predict(prediction_data) + pred2 = test_builder2.predict(prediction_data) assert pred1["y_model"].shape == pred2["y_model"].shape temp.close() From 4b96e814c9aafee4451c66561bfaa8fe58ff167d Mon Sep 17 00:00:00 2001 From: Michal Raczycki Date: Fri, 31 Mar 2023 15:03:34 +0100 Subject: [PATCH 2/6] making model config optional, implementing requested changes --- pymc_experimental/model_builder.py | 26 +++++++------- pymc_experimental/tests/test_model_builder.py | 34 +++++++++++++++---- 2 files changed, 39 insertions(+), 21 deletions(-) diff --git a/pymc_experimental/model_builder.py b/pymc_experimental/model_builder.py index 191b1b1e7..8aa81e051 100644 --- a/pymc_experimental/model_builder.py +++ b/pymc_experimental/model_builder.py @@ -36,8 +36,8 @@ class ModelBuilder: def __init__( self, - model_config: Dict, data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]], + model_config: Dict = None, sampler_config: Dict = None, ): """ @@ -61,6 +61,8 @@ def __init__( super().__init__() if sampler_config is None: sampler_config = {} + if model_config is None: + model_config = {} self.model_config = model_config # parameters for priors etc. self.sampler_config = sampler_config # parameters for sampling self.data = data @@ -73,7 +75,12 @@ def build(self) -> None: Builds the defined model. """ - self.build_model(self, self.model_config, self.data) + self.build_model( + model_instance=self, + data=self.data, + model_config=self.model_config, + sampler_config=self.sampler_config, + ) @abstractmethod def _data_setter( @@ -102,7 +109,6 @@ def _data_setter( raise NotImplementedError - # need a discussion if it's really needed. @staticmethod @abstractmethod def create_sample_input(): @@ -192,13 +198,9 @@ def load(cls, fname: str): filepath = Path(str(fname)) idata = az.from_netcdf(filepath) - if "sampler_config" in idata.attrs: - sampler_config = json.loads(idata.attrs["sampler_config"]) - else: - sampler_config = {} model_builder = cls( model_config=json.loads(idata.attrs["model_config"]), - sampler_config=sampler_config, + sampler_config=json.loads(idata.attrs["sampler_config"]), data=idata.fit_data.to_dataframe(), ) model_builder.idata = idata @@ -241,18 +243,14 @@ def fit( self._data_setter(data) with self.model: - if self.sampler_config: - self.idata = pm.sample(**self.sampler_config) - else: - self.idata = pm.sample() + self.idata = pm.sample(**self.sampler_config) self.idata.extend(pm.sample_prior_predictive()) self.idata.extend(pm.sample_posterior_predictive(self.idata)) self.idata.attrs["id"] = self.id self.idata.attrs["model_type"] = self._model_type self.idata.attrs["version"] = self.version - if self.sampler_config: - self.idata.attrs["sampler_config"] = json.dumps(self.sampler_config) + self.idata.attrs["sampler_config"] = json.dumps(self.sampler_config) self.idata.attrs["model_config"] = json.dumps(self.model_config) self.idata.add_groups(fit_data=self.data.to_xarray()) return self.idata diff --git a/pymc_experimental/tests/test_model_builder.py b/pymc_experimental/tests/test_model_builder.py index 21c6890d2..f69347f56 100644 --- a/pymc_experimental/tests/test_model_builder.py +++ b/pymc_experimental/tests/test_model_builder.py @@ -31,15 +31,13 @@ class test_ModelBuilder(ModelBuilder): def build_model( self, model_instance: ModelBuilder, + data: dict, model_config: dict, - data: dict = None, - sampler_config: dict = None, + sampler_config: dict, ): model_instance.model_config = model_config model_instance.data = data - self.model_config = model_config - self.sampler_config = sampler_config - self.data = data + model_instance.sampler_config = sampler_config with pm.Model() as model_instance.model: if data is not None: @@ -95,7 +93,9 @@ def create_sample_input(self): @staticmethod def initial_build_and_fit(check_idata=True) -> ModelBuilder: data, model_config, sampler_config = test_ModelBuilder.create_sample_input() - model_builder = test_ModelBuilder(model_config, sampler_config, data) + model_builder = test_ModelBuilder( + model_config=model_config, sampler_config=sampler_config, data=data + ) model_builder.idata = model_builder.fit(data=data) if check_idata: assert model_builder.idata is not None @@ -103,6 +103,26 @@ def initial_build_and_fit(check_idata=True) -> ModelBuilder: return model_builder +def test_empty_model_config(): + data, model_config, sampler_config = test_ModelBuilder.create_sample_input() + sampler_config = {} + model_builder = test_ModelBuilder( + model_config=model_config, sampler_config=sampler_config, data=data + ) + model_builder.idata = model_builder.fit(data=data) + assert model_builder.idata is not None + assert "posterior" in model_builder.idata.groups() + + +def test_empty_model_config(): + data, model_config, sampler_config = test_ModelBuilder.create_sample_input() + model_config = {} + model_builder = test_ModelBuilder( + model_config=model_config, sampler_config=sampler_config, data=data + ) + assert model_builder.model_config == {} + + def test_fit(): model = test_ModelBuilder.initial_build_and_fit() x_pred = np.random.uniform(low=0, high=1, size=100) @@ -175,7 +195,7 @@ def test_extract_samples(): def test_id(): data, model_config, sampler_config = test_ModelBuilder.create_sample_input() - model = test_ModelBuilder(model_config, sampler_config, data) + model = test_ModelBuilder(model_config=model_config, sampler_config=sampler_config, data=data) expected_id = hashlib.sha256( str(model_config.values()).encode() + model.version.encode() + model._model_type.encode() From a8453cfcec6ba07e040b265471bbd0f55b26bd6a Mon Sep 17 00:00:00 2001 From: Michal Raczycki Date: Fri, 31 Mar 2023 16:30:58 +0100 Subject: [PATCH 3/6] removed super().__init__() --- pymc_experimental/model_builder.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pymc_experimental/model_builder.py b/pymc_experimental/model_builder.py index 8aa81e051..a3567950c 100644 --- a/pymc_experimental/model_builder.py +++ b/pymc_experimental/model_builder.py @@ -58,7 +58,6 @@ def __init__( >>> model = LinearModel(model_config, sampler_config) """ - super().__init__() if sampler_config is None: sampler_config = {} if model_config is None: From 5e0706489cf259b4d9a9bccceee8a9c7751e891b Mon Sep 17 00:00:00 2001 From: Michal Raczycki Date: Fri, 31 Mar 2023 16:46:08 +0100 Subject: [PATCH 4/6] removed build method, used direct calls to build_model instead --- pymc_experimental/model_builder.py | 23 +++++++++-------------- 1 file changed, 9 insertions(+), 14 deletions(-) diff --git a/pymc_experimental/model_builder.py b/pymc_experimental/model_builder.py index a3567950c..1d87c9359 100644 --- a/pymc_experimental/model_builder.py +++ b/pymc_experimental/model_builder.py @@ -69,18 +69,6 @@ def __init__( None # inference data object placeholder, idata is generated during build execution ) - def build(self) -> None: - """ - Builds the defined model. - """ - - self.build_model( - model_instance=self, - data=self.data, - model_config=self.model_config, - sampler_config=self.sampler_config, - ) - @abstractmethod def _data_setter( self, data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]], x_only: bool = True @@ -236,10 +224,17 @@ def fit( Initializing NUTS using jitter+adapt_diag... """ + # If a new data was provided, assign it to the model if data is not None: self.data = data - self.build() - self._data_setter(data) + + self.build_model( + model_instance=self, + data=self.data, + model_config=self.model_config, + sampler_config=self.sampler_config, + ) + self._data_setter(data) with self.model: self.idata = pm.sample(**self.sampler_config) From 14581505e24313d89bf6c531c78cff25a751f79e Mon Sep 17 00:00:00 2001 From: Michal Raczycki Date: Fri, 31 Mar 2023 16:59:16 +0100 Subject: [PATCH 5/6] refactoring build_model, renaming to 'build' --- pymc_experimental/model_builder.py | 15 +++------ pymc_experimental/tests/test_model_builder.py | 33 +++++++------------ 2 files changed, 17 insertions(+), 31 deletions(-) diff --git a/pymc_experimental/model_builder.py b/pymc_experimental/model_builder.py index 1d87c9359..c4992c177 100644 --- a/pymc_experimental/model_builder.py +++ b/pymc_experimental/model_builder.py @@ -45,11 +45,11 @@ def __init__( Parameters ---------- - model_config : Dictionary + model_config : Dictionary, optional dictionary of parameters that initialise model configuration. Generated by the user defined create_sample_input method. - data : Dictionary + data : Dictionary, required It is the data we need to train the model on. - sampler_config : Dictionary + sampler_config : Dictionary, optional dictionary of parameters that initialise sampler configuration. Generated by the user defined create_sample_input method. Examples -------- @@ -191,7 +191,7 @@ def load(cls, fname: str): data=idata.fit_data.to_dataframe(), ) model_builder.idata = idata - model_builder.build() + model_builder.build_model() if model_builder.id != idata.attrs["id"]: raise ValueError( f"The file '{fname}' does not contain an inference data of the same model or configuration as '{cls._model_type}'" @@ -228,12 +228,7 @@ def fit( if data is not None: self.data = data - self.build_model( - model_instance=self, - data=self.data, - model_config=self.model_config, - sampler_config=self.sampler_config, - ) + self.build_model() self._data_setter(data) with self.model: diff --git a/pymc_experimental/tests/test_model_builder.py b/pymc_experimental/tests/test_model_builder.py index f69347f56..0bbec37d4 100644 --- a/pymc_experimental/tests/test_model_builder.py +++ b/pymc_experimental/tests/test_model_builder.py @@ -28,28 +28,19 @@ class test_ModelBuilder(ModelBuilder): _model_type = "LinearModel" version = "0.1" - def build_model( - self, - model_instance: ModelBuilder, - data: dict, - model_config: dict, - sampler_config: dict, - ): - model_instance.model_config = model_config - model_instance.data = data - model_instance.sampler_config = sampler_config - - with pm.Model() as model_instance.model: - if data is not None: - x = pm.MutableData("x", data["input"].values) - y_data = pm.MutableData("y_data", data["output"].values) + def build_model(self): + + with pm.Model() as self.model: + if self.data is not None: + x = pm.MutableData("x", self.data["input"].values) + y_data = pm.MutableData("y_data", self.data["output"].values) # prior parameters - a_loc = model_config["a_loc"] - a_scale = model_config["a_scale"] - b_loc = model_config["b_loc"] - b_scale = model_config["b_scale"] - obs_error = model_config["obs_error"] + a_loc = self.model_config["a_loc"] + a_scale = self.model_config["a_scale"] + b_loc = self.model_config["b_loc"] + b_scale = self.model_config["b_scale"] + obs_error = self.model_config["obs_error"] # priors a = pm.Normal("a", a_loc, sigma=a_scale) @@ -57,7 +48,7 @@ def build_model( obs_error = pm.HalfNormal("σ_model_fmc", obs_error) # observed data - if data is not None: + if self.data is not None: y_model = pm.Normal("y_model", a + b * x, obs_error, shape=x.shape, observed=y_data) def _data_setter(self, data: pd.DataFrame): From b660afd1d87b42f690e58297a780b3392cd30aa0 Mon Sep 17 00:00:00 2001 From: Michal Raczycki Date: Fri, 31 Mar 2023 17:01:14 +0100 Subject: [PATCH 6/6] renamed build_model to 'build' --- pymc_experimental/model_builder.py | 4 ++-- pymc_experimental/tests/test_model_builder.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pymc_experimental/model_builder.py b/pymc_experimental/model_builder.py index c4992c177..3f48f3f92 100644 --- a/pymc_experimental/model_builder.py +++ b/pymc_experimental/model_builder.py @@ -191,7 +191,7 @@ def load(cls, fname: str): data=idata.fit_data.to_dataframe(), ) model_builder.idata = idata - model_builder.build_model() + model_builder.build() if model_builder.id != idata.attrs["id"]: raise ValueError( f"The file '{fname}' does not contain an inference data of the same model or configuration as '{cls._model_type}'" @@ -228,7 +228,7 @@ def fit( if data is not None: self.data = data - self.build_model() + self.build() self._data_setter(data) with self.model: diff --git a/pymc_experimental/tests/test_model_builder.py b/pymc_experimental/tests/test_model_builder.py index 0bbec37d4..20845fee8 100644 --- a/pymc_experimental/tests/test_model_builder.py +++ b/pymc_experimental/tests/test_model_builder.py @@ -28,7 +28,7 @@ class test_ModelBuilder(ModelBuilder): _model_type = "LinearModel" version = "0.1" - def build_model(self): + def build(self): with pm.Model() as self.model: if self.data is not None: