Skip to content

DOC: "Using ModelBuilder class for deploying PyMC models" example inconsistent #571

Closed
@galenseilis

Description

@galenseilis

Issue with current documentation:

This is copied verbatim from the "Using ModelBuilder class for deploying PyMC models" example:

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc as pm
import xarray as xr

from numpy.random import RandomState

# Fixed seed for the example, plus a seeded NumPy Generator built from it.
RANDOM_SEED = 8927
rng = np.random.default_rng(seed=RANDOM_SEED)

# Global plotting style for ArviZ figures.
az.style.use("arviz-darkgrid")

from pymc_experimental.model_builder import ModelBuilder


class LinearModel(ModelBuilder):
    """Linear regression ``y ~ Normal(a + b * x, eps)`` packaged as a ModelBuilder.

    Reads the predictor from the "input" entry and the target from the
    "output" entry of the data it is given.
    """

    # Give the model a name
    _model_type = "LinearModel"
    # And a version
    version = "0.1"

    def build_model(self, model_config, data=None):
        """
        build_model creates the PyMC model: priors on the intercept `a`, the
        slope `b` and the noise scale `eps`, plus a Normal likelihood `y` for
        the observed outputs.

        Parameters:
        model_config: dictionary
            Prior hyperparameters. Keys read (default used when a key is
            missing): "a_mu_prior" (0.0), "a_sigma_prior" (1.0),
            "b_mu_prior" (0.0), "b_sigma_prior" (1.0), "eps_prior" (1.0).
        data: pd.DataFrame (or a mapping of names to pd.Series)
            Data we want our model fit on. Must provide "input" and "output"
            entries exposing `.values`; plain np.ndarray values do not, and
            the default of None fails on the first lookup below.
        """
        # Note that we do not have to define a with-context: presumably the
        # ModelBuilder base class enters a pm.Model context before calling this
        # method (confirm against the installed pymc_experimental version).

        # Create mutable data containers so the data can later be replaced via
        # pm.set_data (see _data_setter) without rebuilding the model
        x_data = pm.MutableData("x_data", data["input"].values)
        y_data = pm.MutableData("y_data", data["output"].values)

        # prior parameters (missing keys fall back to the defaults given here)
        a_mu_prior = model_config.get("a_mu_prior", 0.0)
        a_sigma_prior = model_config.get("a_sigma_prior", 1.0)
        b_mu_prior = model_config.get("b_mu_prior", 0.0)
        b_sigma_prior = model_config.get("b_sigma_prior", 1.0)
        eps_prior = model_config.get("eps_prior", 1.0)

        # priors
        a = pm.Normal("a", mu=a_mu_prior, sigma=a_sigma_prior)
        b = pm.Normal("b", mu=b_mu_prior, sigma=b_sigma_prior)
        # eps_prior is passed positionally, i.e. as the HalfNormal's sigma
        eps = pm.HalfNormal("eps", eps_prior)

        # Likelihood; `obs` is not used afterwards, creating it is what
        # registers the observed variable "y" in the model
        obs = pm.Normal("y", mu=a + b * x_data, sigma=eps, shape=x_data.shape, observed=y_data)

    def _data_setter(self, data: pd.DataFrame):
        """
        _data_setter works as a set_data for the model and updates the data whenever we need to.

        Parameters:
        data: pd.DataFrame
            It is the data we need to update for the model. Must have an
            "input" column; the "output" column is optional (presumably absent
            when predicting on new inputs), and y_data is only replaced when
            it is present.
        """
        # A DataFrame is required here (not a dict): `.columns` is used below.
        # self.model is presumably created by the ModelBuilder base class.
        with self.model:
            pm.set_data({"x_data": data["input"].values})
            if "output" in data.columns:
                pm.set_data({"y_data": data["output"].values})

    @classmethod
    def create_sample_input(cls):
        """
        Creates example input and parameters to test the model on.
        This is optional but useful.

        Returns:
        data: pd.DataFrame
            100 evenly spaced "input" values in [0, 1] and
            "output" = 0.3 * input + 0.5 plus standard-normal noise.
        model_config: dict
            Prior hyperparameters (same values as build_model's defaults).
        sampler_config: dict
            Sampler settings (draws, tune, chains, target_accept); presumably
            forwarded to the sampler by ModelBuilder - confirm.
        """

        x = np.linspace(start=0, stop=1, num=100)
        y = 0.3 * x + 0.5
        # NOTE(review): this draws from NumPy's global RNG rather than the
        # seeded module-level `rng`, so the sample data differs between runs;
        # `rng.normal(0, 1, len(x))` would make the example reproducible.
        y = y + np.random.normal(0, 1, len(x))
        data = pd.DataFrame({"input": x, "output": y})

        model_config = {
            "a_mu_prior": 0.0,
            "a_sigma_prior": 1.0,
            "b_mu_prior": 0.0,
            "b_sigma_prior": 1.0,
            "eps_prior": 1.0,
        }

        sampler_config = {
            "draws": 1_000,
            "tune": 1_000,
            "chains": 3,
            "target_accept": 0.95,
        }

        return data, model_config, sampler_config

# Build the example inputs and instantiate the model.
# NOTE(review): this passes three positional arguments to the constructor.
# The installed ModelBuilder.__init__ accepts at most two besides `self`
# ("takes from 1 to 3 positional arguments"), so this line raises TypeError.
# Presumably the newer API takes (model_config, sampler_config) and receives
# the data via fit(); build_model's signature would then need updating too.
# Confirm against the installed pymc_experimental version.
data, model_config, sampler_config = LinearModel.create_sample_input()
model = LinearModel(model_config, sampler_config, data)

When I run this code, I get the following error:

Traceback (most recent call last):
  File "/home/galen/projects/try-pymc-modelbuilder/canonical_example.py", line 97, in <module>
    model = LinearModel(model_config, sampler_config, data)
TypeError: ModelBuilder.__init__() takes from 1 to 3 positional arguments but 4 were given

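The error is raised by `ModelBuilder.__init__`. The traceback counts `self` among the "1 to 3 positional arguments", so the constructor accepts at most two positional arguments besides `self`, whereas the example passes three: `model_config`, `sampler_config`, and `data`. In addition, `build_model` does not have a parameter for `sampler_config`, and nothing in its definition uses it either, so the example never shows how `sampler_config` is meant to be consumed.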

Idea or request for content:

Please consider completing the example so that (1) it runs without errors and (2) it shows how `sampler_config` is intended to be used.

Metadata

Metadata

Assignees

No one assigned

    Labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions