From fe21c517883ad0a5bccb8f76df6815734654b962 Mon Sep 17 00:00:00 2001 From: Michal Raczycki Date: Thu, 3 Aug 2023 16:30:25 +0200 Subject: [PATCH 1/5] including ModelBuilder intro notebook --- ...delBuilder in PyMC-Marketing context.ipynb | 17374 ++++++++++++++++ 1 file changed, 17374 insertions(+) create mode 100644 examples/model_builder/ModelBuilder in PyMC-Marketing context.ipynb diff --git a/examples/model_builder/ModelBuilder in PyMC-Marketing context.ipynb b/examples/model_builder/ModelBuilder in PyMC-Marketing context.ipynb new file mode 100644 index 000000000..11b7e71f8 --- /dev/null +++ b/examples/model_builder/ModelBuilder in PyMC-Marketing context.ipynb @@ -0,0 +1,17374 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "1d74584c", + "metadata": {}, + "source": [ + "# Deploying MMMs and CLVs in Production: Saving and Loading Models" + ] + }, + { + "cell_type": "markdown", + "id": "3222ef04", + "metadata": {}, + "source": [ + "In this article, we'll tackle the historically challenging process of deploying Bayesian models built with PyMC. Introducing a revolutionary deployment module, we bring unprecedented simplicity and efficiency to the deployment of PyMC models. As we prioritize user-friendly solutions, let's delve into how this innovation can significantly elevate your data science projects." + ] + }, + { + "cell_type": "markdown", + "id": "ddb28436", + "metadata": {}, + "source": [ + "\n", + "The recent release of PyMC-Marketing by [PyMC Labs](https://www.pymc-labs.io) has proved to be a big hit [(PyMC-Marketing)](https://www.pymc-labs.io/blog-posts/pymc-marketing-a-bayesian-approach-to-marketing-data-science/). One theme kept recurring in the feedback: many of you have been requesting an easy and robust way of deploying models to production. It’s been a long-standing problem with PyMC (and most other PPLs). There’s no obvious way to do it, and whichever approach you try proves to be tricky. 
That is why we’re happy to announce the release of `ModelBuilder`, a brand-new PyMC-experimental module that addresses this need and improves on the deployment process significantly.\n", + "\n", + "The ModelBuilder module is a new feature of PyMC-based models. It provides two easy-to-use methods, `save()` and `load()`, that can be used after the model has been fit. `save()` allows easy preservation of the model in NetCDF format, and `load()` gives one-line replication of the original model. Users can control the prior settings with `model_config` and customize the sampling process using `sampler_config`. The default values of both work just fine, so the first time give it a try without changing them, and provide your own `model_config` and `sampler_config` afterwards if you want to customize the model further for your use case!\n" + ] + }, + { + "cell_type": "markdown", + "id": "a808e36a", + "metadata": {}, + "source": [ + "For this notebook I'll use the example model from the [MMM Example Notebook](https://www.pymc-marketing.io/en/stable/notebooks/mmm/mmm_example.html), but omit the details of data generation and the plotting functionality, since they're out of scope for this introduction. I highly recommend looking at that part as well, but for now let's focus on today's topic: groundbreaking deployment improvements in PyMC-Marketing!" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "id": "1050a937", + "metadata": {}, + "outputs": [], + "source": [ + "import arviz as az\n", + "import numpy as np\n", + "import pandas as pd\n", + "\n", + "from pymc_marketing.mmm import DelayedSaturatedMMM" + ] + }, + { + "cell_type": "markdown", + "id": "f37d808e", + "metadata": {}, + "source": [ + "Let's load the dataset:" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "id": "b7b1193f", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
date_weekyx1x2event_1event_2dayofyeartsin_order_1cos_order_1sin_order_2cos_order_2
02018-04-023984.6622370.3185800.0000000.00.09200.999930-0.011826-0.023651-0.999720
12018-04-093762.8717940.1123880.0000000.00.09910.991269-0.131859-0.261414-0.965227
22018-04-164466.9673880.2924000.0000000.00.010620.968251-0.249981-0.484089-0.875019
32018-04-233864.2193730.0713990.0000000.00.011330.931210-0.364483-0.678820-0.734304
42018-04-304441.6252780.3867450.0000000.00.012040.880683-0.473706-0.834370-0.551205
.......................................
1742021-08-023553.5461480.0330240.0000000.00.0214174-0.513901-0.8578490.8816990.471812
1752021-08-095565.5096820.1656150.8633490.00.0221175-0.613230-0.7899050.9687860.247898
1762021-08-164137.6514850.1718820.0000000.00.0228176-0.703677-0.7105200.9999530.009676
1772021-08-234479.0413510.2802570.0000000.00.0235177-0.783934-0.6208440.973402-0.229104
1782021-08-304675.9734390.4388570.0000000.00.0242178-0.852837-0.5221780.890665-0.454661
\n", + "

179 rows × 12 columns

\n", + "
" + ], + "text/plain": [ + " date_week y x1 x2 event_1 event_2 dayofyear \\\n", + "0 2018-04-02 3984.662237 0.318580 0.000000 0.0 0.0 92 \n", + "1 2018-04-09 3762.871794 0.112388 0.000000 0.0 0.0 99 \n", + "2 2018-04-16 4466.967388 0.292400 0.000000 0.0 0.0 106 \n", + "3 2018-04-23 3864.219373 0.071399 0.000000 0.0 0.0 113 \n", + "4 2018-04-30 4441.625278 0.386745 0.000000 0.0 0.0 120 \n", + ".. ... ... ... ... ... ... ... \n", + "174 2021-08-02 3553.546148 0.033024 0.000000 0.0 0.0 214 \n", + "175 2021-08-09 5565.509682 0.165615 0.863349 0.0 0.0 221 \n", + "176 2021-08-16 4137.651485 0.171882 0.000000 0.0 0.0 228 \n", + "177 2021-08-23 4479.041351 0.280257 0.000000 0.0 0.0 235 \n", + "178 2021-08-30 4675.973439 0.438857 0.000000 0.0 0.0 242 \n", + "\n", + " t sin_order_1 cos_order_1 sin_order_2 cos_order_2 \n", + "0 0 0.999930 -0.011826 -0.023651 -0.999720 \n", + "1 1 0.991269 -0.131859 -0.261414 -0.965227 \n", + "2 2 0.968251 -0.249981 -0.484089 -0.875019 \n", + "3 3 0.931210 -0.364483 -0.678820 -0.734304 \n", + "4 4 0.880683 -0.473706 -0.834370 -0.551205 \n", + ".. ... ... ... ... ... 
\n", + "174 174 -0.513901 -0.857849 0.881699 0.471812 \n", + "175 175 -0.613230 -0.789905 0.968786 0.247898 \n", + "176 176 -0.703677 -0.710520 0.999953 0.009676 \n", + "177 177 -0.783934 -0.620844 0.973402 -0.229104 \n", + "178 178 -0.852837 -0.522178 0.890665 -0.454661 \n", + "\n", + "[179 rows x 12 columns]" + ] + }, + "execution_count": 32, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "url = \"https://raw.githubusercontent.com/pymc-labs/pymc-marketing/main/datasets/mmm_example.csv\"\n", + "df = pd.read_csv(url)\n", + "df" + ] + }, + { + "cell_type": "markdown", + "id": "87deb70d", + "metadata": {}, + "source": [ + "Our model needs a much smaller dataset, though. Many of the original features were only used to generate others; now that our target variable has been computed, we can filter out the columns we no longer need:" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "id": "52b6d127", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
date_weekyx1x2event_1event_2dayofyeart
02018-04-023984.6622370.3185800.00.00.0920
12018-04-093762.8717940.1123880.00.00.0991
22018-04-164466.9673880.2924000.00.00.01062
32018-04-233864.2193730.0713990.00.00.01133
42018-04-304441.6252780.3867450.00.00.01204
\n", + "
" + ], + "text/plain": [ + " date_week y x1 x2 event_1 event_2 dayofyear t\n", + "0 2018-04-02 3984.662237 0.318580 0.0 0.0 0.0 92 0\n", + "1 2018-04-09 3762.871794 0.112388 0.0 0.0 0.0 99 1\n", + "2 2018-04-16 4466.967388 0.292400 0.0 0.0 0.0 106 2\n", + "3 2018-04-23 3864.219373 0.071399 0.0 0.0 0.0 113 3\n", + "4 2018-04-30 4441.625278 0.386745 0.0 0.0 0.0 120 4" + ] + }, + "execution_count": 33, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "columns_to_keep = [\n", + " \"date_week\",\n", + " \"y\",\n", + " \"x1\",\n", + " \"x2\",\n", + " \"event_1\",\n", + " \"event_2\",\n", + " \"dayofyear\",\n", + "]\n", + "seed: int = sum(map(ord, \"mmm\"))\n", + "rng = np.random.default_rng(seed=seed)\n", + "\n", + "data = df[columns_to_keep].copy()\n", + "\n", + "data[\"t\"] = range(df.shape[0])\n", + "data.head()" + ] + }, + { + "cell_type": "markdown", + "id": "9518a885", + "metadata": {}, + "source": [ + "## _Model Creation_\n", + "With our dataset ready, we could proceed straight to the model definition. First, though, to show the full potential of one of the new features, `model_config`, we will use some of our data to define the prior for the sigma parameter of each channel. `model_config` is a customizable dictionary whose keys correspond to priors within the model and whose values are dictionaries holding the parameters needed to initialize them. Later on we'll see that the `save()` method preserves the priors contained in `model_config`, allowing complete replication of our model." + ] + }, + { + "cell_type": "markdown", + "id": "4b52b2c1", + "metadata": {}, + "source": [ + "### model_config" + ] + }, + { + "cell_type": "markdown", + "id": "41021a72", + "metadata": {}, + "source": [ + "The `default_model_config` attribute of every model inheriting from `ModelBuilder` lets you see which priors are available for customization. 
To see it simply initialize a dummy model:" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "id": "284bd558", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'intercept': {'mu': 0, 'sigma': 2},\n", + " 'beta_channel': {'sigma': 2, 'dims': ('channel',)},\n", + " 'alpha': {'alpha': 1, 'beta': 3, 'dims': ('channel',)},\n", + " 'lam': {'alpha': 3, 'beta': 1, 'dims': ('channel',)},\n", + " 'sigma': {'sigma': 2},\n", + " 'gamma_control': {'mu': 0, 'sigma': 2, 'dims': ('control',)},\n", + " 'mu': {'dims': ('date',)},\n", + " 'likelihood': {'dims': ('date',)},\n", + " 'gamma_fourier': {'mu': 0, 'b': 1, 'dims': 'fourier_mode'}}" + ] + }, + "execution_count": 34, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dummy_model = DelayedSaturatedMMM(date_column = '', channel_columns= '', adstock_max_lag = 4)\n", + "dummy_model.default_model_config" + ] + }, + { + "cell_type": "markdown", + "id": "f0fd248f", + "metadata": {}, + "source": [ + "You can change only the prior parameters that you wish, no need to alter all of them, unless you'd like to!\n", + "In this case we'll just simply replace our sigma for beta_channel with our computed one:" + ] + }, + { + "cell_type": "markdown", + "id": "19f075f0-4d3d-4509-a9c6-f15efdb9293d", + "metadata": {}, + "source": [ + "First, let's compute the share of spend per channel:" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "id": "4785596a-e333-4cd0-af15-1332e97b66d5", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "x1 0.65632\n", + "x2 0.34368\n", + "dtype: float64" + ] + }, + "execution_count": 35, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "total_spend_per_channel = data[[\"x1\", \"x2\"]].sum(axis=0)\n", + "\n", + "spend_share = total_spend_per_channel / total_spend_per_channel.sum()\n", + "\n", + "spend_share" + ] + }, + { + "cell_type": "markdown", + "id": "40d17642-1e21-4adc-97f4-633eede87915", + 
"metadata": {}, + "source": [ + "Next, we specify the `sigma` parameter per channel:" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "id": "672b36bd-3d08-46df-85b8-d67d3ade75d7", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[2.1775326025486734, 1.140260877391939]" + ] + }, + "execution_count": 36, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# The scale necessary to make a HalfNormal distribution have unit variance\n", + "HALFNORMAL_SCALE = 1 / np.sqrt(1 - 2 / np.pi)\n", + "\n", + "n_channels = 2\n", + "\n", + "prior_sigma = HALFNORMAL_SCALE * n_channels * spend_share.to_numpy()\n", + "\n", + "prior_sigma.tolist()" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "id": "bac9f587", + "metadata": {}, + "outputs": [], + "source": [ + "custom_beta_channel_prior = {'beta_channel': {'sigma': prior_sigma, 'dims': ('channel',)}}\n", + "my_model_config = dummy_model.default_model_config | custom_beta_channel_prior" + ] + }, + { + "cell_type": "markdown", + "id": "1aa435bf", + "metadata": {}, + "source": [ + "As mentioned in the original notebook: \"_For the prior specification there is no right or wrong answer. It all depends on the data, the context and the assumptions you are willing to make. It is always recommended to do some prior predictive sampling and sensitivity analysis to check the impact of the priors on the posterior. We skip this here for the sake of simplicity. If you are not sure about specific priors, the `DelayedSaturatedMMM` class has some default priors that you can use as a starting point._\"" + ] + }, + { + "cell_type": "markdown", + "id": "f195a79e", + "metadata": {}, + "source": [ + "The second feature that we can use for model definition is `sampler_config`. Similar to `model_config`, it's a dictionary that gets saved, and it contains the keyword arguments you'd usually pass to `fit()`. 
It's not mandatory to create your own `sampler_config`; if not provided, both `model_config` and `sampler_config` will default to the forms specified by PyMC Labs experts, which allows for the usage of all model functionalities. The default `sampler_config` is left empty because the default sampling parameters usually prove sufficient for a start." + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "id": "0ab8140c", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{}" + ] + }, + "execution_count": 38, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dummy_model.default_sampler_config" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "id": "bf5a50f4", + "metadata": {}, + "outputs": [], + "source": [ + "my_sampler_config = {\n", + " 'tune': 1000,\n", + " 'draws': 1000,\n", + " 'chains': 4,\n", + " 'target_accept': 0.95,\n", + "}" + ] + }, + { + "cell_type": "markdown", + "id": "f3bfe090", + "metadata": {}, + "source": [ + "Let's finally assemble our model!" + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "id": "c7bd6909", + "metadata": {}, + "outputs": [], + "source": [ + "mmm = DelayedSaturatedMMM(\n", + " model_config=my_model_config,\n", + " sampler_config=my_sampler_config,\n", + " date_column=\"date_week\",\n", + " channel_columns=[\"x1\", \"x2\"],\n", + " control_columns=[\n", + " \"event_1\",\n", + " \"event_2\",\n", + " \"t\",\n", + " ],\n", + " adstock_max_lag=8,\n", + " yearly_seasonality=2,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "54095b1a", + "metadata": {}, + "source": [ + "An important thing to note here is that in the new version of `DelayedSaturatedMMM`, we don't pass our dataset to the class constructor itself. This is because it supports `sklearn` transformers and validations, which require the usual X, y split and typically expect the data to be passed to the `fit()` method." 
+ ] + }, + { + "cell_type": "markdown", + "id": "dec9b1b0", + "metadata": {}, + "source": [ + "## _Model Fitting_" + ] + }, + { + "cell_type": "markdown", + "id": "d5e64562-ba78-4497-a0f8-123b4bc88b79", + "metadata": {}, + "source": [ + "Let's split the dataset:" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "id": "ff23006b-a55b-4a22-9f34-4eeaddf47486", + "metadata": {}, + "outputs": [], + "source": [ + "X = data.drop('y', axis=1)\n", + "y = data['y']" + ] + }, + { + "cell_type": "markdown", + "id": "403e3ed2", + "metadata": {}, + "source": [ + "All that's left now is to finally fit the model:\n", + "\n", + "As you can see below, you can still pass the sampler kwargs directly to the `fit()` method. However, only the kwargs passed through `sampler_config` will be saved, so only those will be available after loading the model." + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "id": "0f6ab0a8", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Auto-assigning NUTS sampler...\n", + "Initializing NUTS using jitter+adapt_diag...\n", + "Multiprocess sampling (4 chains in 4 jobs)\n", + "NUTS: [intercept, beta_channel, alpha, lam, sigma, gamma_control, gamma_fourier]\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "
\n", + " \n", + " 100.00% [8000/8000 00:29<00:00 Sampling 4 chains, 0 divergences]\n", + "
\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 30 seconds.\n", + "Sampling: [alpha, beta_channel, gamma_control, gamma_fourier, intercept, lam, likelihood, sigma]\n", + "Sampling: [likelihood]\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "
\n", + " \n", + " 100.00% [4000/4000 00:00<00:00]\n", + "
\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "
\n", + "
\n", + "
arviz.InferenceData
\n", + "
\n", + "
    \n", + " \n", + "
  • \n", + " \n", + " \n", + "
    \n", + "
    \n", + "
      \n", + "
      \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
      <xarray.Dataset>\n",
      +       "Dimensions:                    (chain: 4, draw: 1000, control: 3,\n",
      +       "                                fourier_mode: 4, channel: 2, date: 179)\n",
      +       "Coordinates:\n",
      +       "  * chain                      (chain) int64 0 1 2 3\n",
      +       "  * draw                       (draw) int64 0 1 2 3 4 5 ... 995 996 997 998 999\n",
      +       "  * control                    (control) <U7 'event_1' 'event_2' 't'\n",
      +       "  * fourier_mode               (fourier_mode) <U11 'sin_order_1' ... 'cos_ord...\n",
      +       "  * channel                    (channel) <U2 'x1' 'x2'\n",
      +       "  * date                       (date) <U10 '2018-04-02' ... '2021-08-30'\n",
      +       "Data variables: (12/13)\n",
      +       "    intercept                  (chain, draw) float64 0.3381 0.3361 ... 0.3479\n",
      +       "    gamma_control              (chain, draw, control) float64 0.2968 ... 0.00...\n",
      +       "    gamma_fourier              (chain, draw, fourier_mode) float64 -0.002612 ...\n",
      +       "    beta_channel               (chain, draw, channel) float64 0.3718 ... 0.2628\n",
      +       "    alpha                      (chain, draw, channel) float64 0.4162 ... 0.2007\n",
      +       "    lam                        (chain, draw, channel) float64 4.26 ... 2.758\n",
      +       "    ...                         ...\n",
      +       "    channel_adstock            (chain, draw, date, channel) float64 0.1868 .....\n",
      +       "    channel_adstock_saturated  (chain, draw, date, channel) float64 0.3781 .....\n",
      +       "    channel_contributions      (chain, draw, date, channel) float64 0.1406 .....\n",
      +       "    control_contributions      (chain, draw, date, control) float64 0.0 ... 0...\n",
      +       "    fourier_contributions      (chain, draw, date, fourier_mode) float64 -0.0...\n",
      +       "    mu                         (chain, draw, date) float64 0.4769 ... 0.5924\n",
      +       "Attributes:\n",
      +       "    created_at:                 2023-08-03T11:09:14.027598\n",
      +       "    arviz_version:              0.16.1\n",
      +       "    inference_library:          pymc\n",
      +       "    inference_library_version:  5.6.1\n",
      +       "    sampling_time:              29.821417093276978\n",
      +       "    tuning_steps:               1000

      \n", + "
    \n", + "
    \n", + "
  • \n", + " \n", + "
  • \n", + " \n", + " \n", + "
    \n", + "
    \n", + "
      \n", + "
      \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
      <xarray.Dataset>\n",
      +       "Dimensions:     (chain: 4, draw: 1000, date: 179)\n",
      +       "Coordinates:\n",
      +       "  * chain       (chain) int64 0 1 2 3\n",
      +       "  * draw        (draw) int64 0 1 2 3 4 5 6 7 ... 992 993 994 995 996 997 998 999\n",
      +       "  * date        (date) <U10 '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
      +       "Data variables:\n",
      +       "    likelihood  (chain, draw, date) float64 0.5039 0.439 ... 0.5873 0.6072\n",
      +       "Attributes:\n",
      +       "    created_at:                 2023-08-03T11:09:15.416164\n",
      +       "    arviz_version:              0.16.1\n",
      +       "    inference_library:          pymc\n",
      +       "    inference_library_version:  5.6.1

      \n", + "
    \n", + "
    \n", + "
  • \n", + " \n", + "
  • \n", + " \n", + " \n", + "
    \n", + "
    \n", + "
      \n", + "
      \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
      <xarray.Dataset>\n",
      +       "Dimensions:                (chain: 4, draw: 1000)\n",
      +       "Coordinates:\n",
      +       "  * chain                  (chain) int64 0 1 2 3\n",
      +       "  * draw                   (draw) int64 0 1 2 3 4 5 ... 994 995 996 997 998 999\n",
      +       "Data variables: (12/17)\n",
      +       "    process_time_diff      (chain, draw) float64 0.008982 0.009454 ... 0.009203\n",
      +       "    step_size_bar          (chain, draw) float64 0.05522 0.05522 ... 0.06921\n",
      +       "    step_size              (chain, draw) float64 0.05862 0.05862 ... 0.07543\n",
      +       "    acceptance_rate        (chain, draw) float64 0.9987 0.9934 ... 0.8931 0.9167\n",
      +       "    index_in_trajectory    (chain, draw) int64 -12 20 43 -17 ... 42 20 -43 -29\n",
      +       "    tree_depth             (chain, draw) int64 6 6 6 6 7 6 6 7 ... 6 5 5 6 7 6 6\n",
      +       "    ...                     ...\n",
      +       "    energy_error           (chain, draw) float64 -0.01181 -0.02828 ... -0.01478\n",
      +       "    perf_counter_diff      (chain, draw) float64 0.009316 0.01008 ... 0.009355\n",
      +       "    n_steps                (chain, draw) float64 63.0 63.0 63.0 ... 63.0 63.0\n",
      +       "    diverging              (chain, draw) bool False False False ... False False\n",
      +       "    perf_counter_start     (chain, draw) float64 3.949e+06 ... 3.949e+06\n",
      +       "    lp                     (chain, draw) float64 355.1 355.2 ... 355.6 351.6\n",
      +       "Attributes:\n",
      +       "    created_at:                 2023-08-03T11:09:14.040220\n",
      +       "    arviz_version:              0.16.1\n",
      +       "    inference_library:          pymc\n",
      +       "    inference_library_version:  5.6.1\n",
      +       "    sampling_time:              29.821417093276978\n",
      +       "    tuning_steps:               1000

      \n", + "
    \n", + "
    \n", + "
  • \n", + " \n", + "
  • \n", + " \n", + " \n", + "
    \n", + "
    \n", + "
      \n", + "
      \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
      <xarray.Dataset>\n",
      +       "Dimensions:                    (chain: 1, draw: 500, fourier_mode: 4,\n",
      +       "                                date: 179, channel: 2, control: 3)\n",
      +       "Coordinates:\n",
      +       "  * chain                      (chain) int64 0\n",
      +       "  * draw                       (draw) int64 0 1 2 3 4 5 ... 495 496 497 498 499\n",
      +       "  * fourier_mode               (fourier_mode) <U11 'sin_order_1' ... 'cos_ord...\n",
      +       "  * date                       (date) <U10 '2018-04-02' ... '2021-08-30'\n",
      +       "  * channel                    (channel) <U2 'x1' 'x2'\n",
      +       "  * control                    (control) <U7 'event_1' 'event_2' 't'\n",
      +       "Data variables: (12/13)\n",
      +       "    gamma_fourier              (chain, draw, fourier_mode) float64 1.28 ... 1...\n",
      +       "    intercept                  (chain, draw) float64 1.178 -2.005 ... 2.533\n",
      +       "    fourier_contributions      (chain, draw, date, fourier_mode) float64 1.28...\n",
      +       "    mu                         (chain, draw, date) float64 0.8703 ... -44.96\n",
      +       "    channel_contributions      (chain, draw, date, channel) float64 0.4222 .....\n",
      +       "    control_contributions      (chain, draw, date, control) float64 0.0 ... -...\n",
      +       "    ...                         ...\n",
      +       "    gamma_control              (chain, draw, control) float64 1.547 ... -0.2666\n",
      +       "    channel_adstock            (chain, draw, date, channel) float64 0.3087 .....\n",
      +       "    alpha                      (chain, draw, channel) float64 0.03434 ... 0.1455\n",
      +       "    lam                        (chain, draw, channel) float64 2.558 ... 4.149\n",
      +       "    channel_adstock_saturated  (chain, draw, date, channel) float64 0.3755 .....\n",
      +       "    sigma                      (chain, draw) float64 0.05319 1.662 ... 1.212\n",
      +       "Attributes:\n",
      +       "    created_at:                 2023-08-03T11:09:15.159471\n",
      +       "    arviz_version:              0.16.1\n",
      +       "    inference_library:          pymc\n",
      +       "    inference_library_version:  5.6.1

      \n", + "
    \n", + "
    \n", + "
  • \n", + " \n", + "
  • \n", + " \n", + " \n", + "
    \n", + "
    \n", + "
      \n", + "
      \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
      <xarray.Dataset>\n",
      +       "Dimensions:     (chain: 1, draw: 500, date: 179)\n",
      +       "Coordinates:\n",
      +       "  * chain       (chain) int64 0\n",
      +       "  * draw        (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499\n",
      +       "  * date        (date) <U10 '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
      +       "Data variables:\n",
      +       "    likelihood  (chain, draw, date) float64 0.8725 3.274 6.447 ... -45.43 -45.69\n",
      +       "Attributes:\n",
      +       "    created_at:                 2023-08-03T11:09:15.164125\n",
      +       "    arviz_version:              0.16.1\n",
      +       "    inference_library:          pymc\n",
      +       "    inference_library_version:  5.6.1

      \n", + "
    \n", + "
    \n", + "
  • \n", + " \n", + "
  • \n", + " \n", + " \n", + "
    \n", + "
    \n", + "
      \n", + "
      \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
      <xarray.Dataset>\n",
      +       "Dimensions:     (date: 179)\n",
      +       "Coordinates:\n",
      +       "  * date        (date) <U10 '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
      +       "Data variables:\n",
      +       "    likelihood  (date) float64 0.4794 0.4527 0.5374 ... 0.4978 0.5388 0.5625\n",
      +       "Attributes:\n",
      +       "    created_at:                 2023-08-03T11:09:14.043853\n",
      +       "    arviz_version:              0.16.1\n",
      +       "    inference_library:          pymc\n",
      +       "    inference_library_version:  5.6.1

      \n", + "
    \n", + "
    \n", + "
  • \n", + " \n", + "
  • \n", + " \n", + " \n", + "
    \n", + "
    \n", + "
      \n", + "
      \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
      <xarray.Dataset>\n",
      +       "Dimensions:       (date: 179, channel: 2, control: 3, fourier_mode: 4)\n",
      +       "Coordinates:\n",
      +       "  * date          (date) <U10 '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
      +       "  * channel       (channel) <U2 'x1' 'x2'\n",
      +       "  * control       (control) <U7 'event_1' 'event_2' 't'\n",
      +       "  * fourier_mode  (fourier_mode) <U11 'sin_order_1' ... 'cos_order_2'\n",
      +       "Data variables:\n",
      +       "    channel_data  (date, channel) float64 0.3196 0.0 0.1128 ... 0.0 0.4403 0.0\n",
      +       "    target        (date) float64 0.4794 0.4527 0.5374 ... 0.4978 0.5388 0.5625\n",
      +       "    control_data  (date, control) float64 0.0 0.0 0.0 0.0 ... 0.0 0.0 178.0\n",
      +       "    fourier_data  (date, fourier_mode) float64 0.9999 -0.01183 ... -0.4547\n",
      +       "Attributes:\n",
      +       "    created_at:                 2023-08-03T11:09:14.045029\n",
      +       "    arviz_version:              0.16.1\n",
      +       "    inference_library:          pymc\n",
      +       "    inference_library_version:  5.6.1

      \n", + "
    \n", + "
    \n", + "
  • \n", + " \n", + "
  • \n", + " \n", + " \n", + "
    \n", + "
    \n", + "
      \n", + "
      \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
      <xarray.Dataset>\n",
      +       "Dimensions:    (index: 179)\n",
      +       "Coordinates:\n",
      +       "  * index      (index) int64 0 1 2 3 4 5 6 7 ... 171 172 173 174 175 176 177 178\n",
      +       "Data variables:\n",
      +       "    date_week  (index) object '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
      +       "    x1         (index) float64 0.3186 0.1124 0.2924 ... 0.1719 0.2803 0.4389\n",
      +       "    x2         (index) float64 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.8633 0.0 0.0 0.0\n",
      +       "    event_1    (index) float64 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0\n",
      +       "    event_2    (index) float64 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0\n",
      +       "    dayofyear  (index) int64 92 99 106 113 120 127 ... 207 214 221 228 235 242\n",
      +       "    t          (index) int64 0 1 2 3 4 5 6 7 ... 171 172 173 174 175 176 177 178\n",
      +       "    y          (index) float64 3.985e+03 3.763e+03 ... 4.479e+03 4.676e+03

      \n", + "
    \n", + "
    \n", + "
  • \n", + " \n", + "
\n", + "
\n", + " " + ], + "text/plain": [ + "Inference data with groups:\n", + "\t> posterior\n", + "\t> posterior_predictive\n", + "\t> sample_stats\n", + "\t> prior\n", + "\t> prior_predictive\n", + "\t> observed_data\n", + "\t> constant_data\n", + "\t> fit_data" + ] + }, + "execution_count": 42, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "mmm.fit(X=X, y=y, random_seed=rng)" + ] + }, + { + "cell_type": "markdown", + "id": "c29a6461", + "metadata": {}, + "source": [ + "The `fit()` method automatically builds the model using the priors from `model_config`, and assigns the created model to our instance. You can access it as a normal attribute." + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "id": "c6b8e2af", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "pymc.model.Model" + ] + }, + "execution_count": 43, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "type(mmm.model)" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "id": "f046ee2c", + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "clusterdate (179) x channel (2)\n", + "\n", + "date (179) x channel (2)\n", + "\n", + "\n", + "clusterdate (179)\n", + "\n", + "date (179)\n", + "\n", + "\n", + "clusterchannel (2)\n", + "\n", + "channel (2)\n", + "\n", + "\n", + "clusterdate (179) x control (3)\n", + "\n", + "date (179) x control (3)\n", + "\n", + "\n", + "clustercontrol (3)\n", + "\n", + "control (3)\n", + "\n", + "\n", + "clusterdate (179) x fourier_mode (4)\n", + "\n", + "date (179) x fourier_mode (4)\n", + "\n", + "\n", + "clusterfourier_mode (4)\n", + "\n", + "fourier_mode (4)\n", + "\n", + "\n", + "\n", + "channel_contributions\n", + "\n", + "channel_contributions\n", + "~\n", + "Deterministic\n", + "\n", + "\n", + "\n", + "mu\n", + "\n", + "mu\n", + "~\n", + "Deterministic\n", + "\n", + "\n", + "\n", + 
"channel_contributions->mu\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "channel_data\n", + "\n", + "channel_data\n", + "~\n", + "MutableData\n", + "\n", + "\n", + "\n", + "channel_adstock\n", + "\n", + "channel_adstock\n", + "~\n", + "Deterministic\n", + "\n", + "\n", + "\n", + "channel_data->channel_adstock\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "channel_adstock_saturated\n", + "\n", + "channel_adstock_saturated\n", + "~\n", + "Deterministic\n", + "\n", + "\n", + "\n", + "channel_adstock->channel_adstock_saturated\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "channel_adstock_saturated->channel_contributions\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "target\n", + "\n", + "target\n", + "~\n", + "MutableData\n", + "\n", + "\n", + "\n", + "likelihood\n", + "\n", + "likelihood\n", + "~\n", + "Normal\n", + "\n", + "\n", + "\n", + "mu->likelihood\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "likelihood->target\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "intercept\n", + "\n", + "intercept\n", + "~\n", + "Normal\n", + "\n", + "\n", + "\n", + "intercept->mu\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "sigma\n", + "\n", + "sigma\n", + "~\n", + "HalfNormal\n", + "\n", + "\n", + "\n", + "sigma->likelihood\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "lam\n", + "\n", + "lam\n", + "~\n", + "Gamma\n", + "\n", + "\n", + "\n", + "lam->channel_adstock_saturated\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "alpha\n", + "\n", + "alpha\n", + "~\n", + "Beta\n", + "\n", + "\n", + "\n", + "alpha->channel_adstock\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "beta_channel\n", + "\n", + "beta_channel\n", + "~\n", + "HalfNormal\n", + "\n", + "\n", + "\n", + "beta_channel->channel_contributions\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "control_data\n", + "\n", + "control_data\n", + "~\n", + "MutableData\n", + "\n", + "\n", + "\n", + "control_contributions\n", + "\n", + "control_contributions\n", + "~\n", + "Deterministic\n", + "\n", + "\n", + "\n", + 
"control_data->control_contributions\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "control_contributions->mu\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "gamma_control\n", + "\n", + "gamma_control\n", + "~\n", + "Normal\n", + "\n", + "\n", + "\n", + "gamma_control->control_contributions\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "fourier_contributions\n", + "\n", + "fourier_contributions\n", + "~\n", + "Deterministic\n", + "\n", + "\n", + "\n", + "fourier_contributions->mu\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "fourier_data\n", + "\n", + "fourier_data\n", + "~\n", + "MutableData\n", + "\n", + "\n", + "\n", + "fourier_data->fourier_contributions\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "gamma_fourier\n", + "\n", + "gamma_fourier\n", + "~\n", + "Laplace\n", + "\n", + "\n", + "\n", + "gamma_fourier->fourier_contributions\n", + "\n", + "\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 44, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "mmm.graphviz()" + ] + }, + { + "cell_type": "markdown", + "id": "c804b600", + "metadata": {}, + "source": [ + "posterior trace can be accessed by `fit_result` attribute" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "id": "66903965", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.Dataset>\n",
+       "Dimensions:                    (chain: 4, draw: 1000, control: 3,\n",
+       "                                fourier_mode: 4, channel: 2, date: 179)\n",
+       "Coordinates:\n",
+       "  * chain                      (chain) int64 0 1 2 3\n",
+       "  * draw                       (draw) int64 0 1 2 3 4 5 ... 995 996 997 998 999\n",
+       "  * control                    (control) <U7 'event_1' 'event_2' 't'\n",
+       "  * fourier_mode               (fourier_mode) <U11 'sin_order_1' ... 'cos_ord...\n",
+       "  * channel                    (channel) <U2 'x1' 'x2'\n",
+       "  * date                       (date) <U10 '2018-04-02' ... '2021-08-30'\n",
+       "Data variables: (12/13)\n",
+       "    intercept                  (chain, draw) float64 0.3381 0.3361 ... 0.3479\n",
+       "    gamma_control              (chain, draw, control) float64 0.2968 ... 0.00...\n",
+       "    gamma_fourier              (chain, draw, fourier_mode) float64 -0.002612 ...\n",
+       "    beta_channel               (chain, draw, channel) float64 0.3718 ... 0.2628\n",
+       "    alpha                      (chain, draw, channel) float64 0.4162 ... 0.2007\n",
+       "    lam                        (chain, draw, channel) float64 4.26 ... 2.758\n",
+       "    ...                         ...\n",
+       "    channel_adstock            (chain, draw, date, channel) float64 0.1868 .....\n",
+       "    channel_adstock_saturated  (chain, draw, date, channel) float64 0.3781 .....\n",
+       "    channel_contributions      (chain, draw, date, channel) float64 0.1406 .....\n",
+       "    control_contributions      (chain, draw, date, control) float64 0.0 ... 0...\n",
+       "    fourier_contributions      (chain, draw, date, fourier_mode) float64 -0.0...\n",
+       "    mu                         (chain, draw, date) float64 0.4769 ... 0.5924\n",
+       "Attributes:\n",
+       "    created_at:                 2023-08-03T11:09:14.027598\n",
+       "    arviz_version:              0.16.1\n",
+       "    inference_library:          pymc\n",
+       "    inference_library_version:  5.6.1\n",
+       "    sampling_time:              29.821417093276978\n",
+       "    tuning_steps:               1000
" + ], + "text/plain": [ + "\n", + "Dimensions: (chain: 4, draw: 1000, control: 3,\n", + " fourier_mode: 4, channel: 2, date: 179)\n", + "Coordinates:\n", + " * chain (chain) int64 0 1 2 3\n", + " * draw (draw) int64 0 1 2 3 4 5 ... 995 996 997 998 999\n", + " * control (control) \n", + "
\n", + "
arviz.InferenceData
\n", + "
\n", + "
    \n", + " \n", + "
  • \n", + " \n", + " \n", + "
    \n", + "
    \n", + "
      \n", + "
      \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
      <xarray.Dataset>\n",
      +       "Dimensions:                    (chain: 4, draw: 1000, control: 3,\n",
      +       "                                fourier_mode: 4, channel: 2, date: 179)\n",
      +       "Coordinates:\n",
      +       "  * chain                      (chain) int64 0 1 2 3\n",
      +       "  * draw                       (draw) int64 0 1 2 3 4 5 ... 995 996 997 998 999\n",
      +       "  * control                    (control) <U7 'event_1' 'event_2' 't'\n",
      +       "  * fourier_mode               (fourier_mode) <U11 'sin_order_1' ... 'cos_ord...\n",
      +       "  * channel                    (channel) <U2 'x1' 'x2'\n",
      +       "  * date                       (date) <U10 '2018-04-02' ... '2021-08-30'\n",
      +       "Data variables: (12/13)\n",
      +       "    intercept                  (chain, draw) float64 0.3381 0.3361 ... 0.3479\n",
      +       "    gamma_control              (chain, draw, control) float64 0.2968 ... 0.00...\n",
      +       "    gamma_fourier              (chain, draw, fourier_mode) float64 -0.002612 ...\n",
      +       "    beta_channel               (chain, draw, channel) float64 0.3718 ... 0.2628\n",
      +       "    alpha                      (chain, draw, channel) float64 0.4162 ... 0.2007\n",
      +       "    lam                        (chain, draw, channel) float64 4.26 ... 2.758\n",
      +       "    ...                         ...\n",
      +       "    channel_adstock            (chain, draw, date, channel) float64 0.1868 .....\n",
      +       "    channel_adstock_saturated  (chain, draw, date, channel) float64 0.3781 .....\n",
      +       "    channel_contributions      (chain, draw, date, channel) float64 0.1406 .....\n",
      +       "    control_contributions      (chain, draw, date, control) float64 0.0 ... 0...\n",
      +       "    fourier_contributions      (chain, draw, date, fourier_mode) float64 -0.0...\n",
      +       "    mu                         (chain, draw, date) float64 0.4769 ... 0.5924\n",
      +       "Attributes:\n",
      +       "    created_at:                 2023-08-03T11:09:14.027598\n",
      +       "    arviz_version:              0.16.1\n",
      +       "    inference_library:          pymc\n",
      +       "    inference_library_version:  5.6.1\n",
      +       "    sampling_time:              29.821417093276978\n",
      +       "    tuning_steps:               1000

      \n", + "
    \n", + "
    \n", + "
  • \n", + " \n", + "
  • \n", + " \n", + " \n", + "
    \n", + "
    \n", + "
      \n", + "
      \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
      <xarray.Dataset>\n",
      +       "Dimensions:     (chain: 4, draw: 1000, date: 179)\n",
      +       "Coordinates:\n",
      +       "  * chain       (chain) int64 0 1 2 3\n",
      +       "  * draw        (draw) int64 0 1 2 3 4 5 6 7 ... 992 993 994 995 996 997 998 999\n",
      +       "  * date        (date) <U10 '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
      +       "Data variables:\n",
      +       "    likelihood  (chain, draw, date) float64 0.5039 0.439 ... 0.5873 0.6072\n",
      +       "Attributes:\n",
      +       "    created_at:                 2023-08-03T11:09:15.416164\n",
      +       "    arviz_version:              0.16.1\n",
      +       "    inference_library:          pymc\n",
      +       "    inference_library_version:  5.6.1

      \n", + "
    \n", + "
    \n", + "
  • \n", + " \n", + "
  • \n", + " \n", + " \n", + "
    \n", + "
    \n", + "
      \n", + "
      \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
      <xarray.Dataset>\n",
      +       "Dimensions:                (chain: 4, draw: 1000)\n",
      +       "Coordinates:\n",
      +       "  * chain                  (chain) int64 0 1 2 3\n",
      +       "  * draw                   (draw) int64 0 1 2 3 4 5 ... 994 995 996 997 998 999\n",
      +       "Data variables: (12/17)\n",
      +       "    process_time_diff      (chain, draw) float64 0.008982 0.009454 ... 0.009203\n",
      +       "    step_size_bar          (chain, draw) float64 0.05522 0.05522 ... 0.06921\n",
      +       "    step_size              (chain, draw) float64 0.05862 0.05862 ... 0.07543\n",
      +       "    acceptance_rate        (chain, draw) float64 0.9987 0.9934 ... 0.8931 0.9167\n",
      +       "    index_in_trajectory    (chain, draw) int64 -12 20 43 -17 ... 42 20 -43 -29\n",
      +       "    tree_depth             (chain, draw) int64 6 6 6 6 7 6 6 7 ... 6 5 5 6 7 6 6\n",
      +       "    ...                     ...\n",
      +       "    energy_error           (chain, draw) float64 -0.01181 -0.02828 ... -0.01478\n",
      +       "    perf_counter_diff      (chain, draw) float64 0.009316 0.01008 ... 0.009355\n",
      +       "    n_steps                (chain, draw) float64 63.0 63.0 63.0 ... 63.0 63.0\n",
      +       "    diverging              (chain, draw) bool False False False ... False False\n",
      +       "    perf_counter_start     (chain, draw) float64 3.949e+06 ... 3.949e+06\n",
      +       "    lp                     (chain, draw) float64 355.1 355.2 ... 355.6 351.6\n",
      +       "Attributes:\n",
      +       "    created_at:                 2023-08-03T11:09:14.040220\n",
      +       "    arviz_version:              0.16.1\n",
      +       "    inference_library:          pymc\n",
      +       "    inference_library_version:  5.6.1\n",
      +       "    sampling_time:              29.821417093276978\n",
      +       "    tuning_steps:               1000

      \n", + "
    \n", + "
    \n", + "
  • \n", + " \n", + "
  • \n", + " \n", + " \n", + "
    \n", + "
    \n", + "
      \n", + "
      \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
      <xarray.Dataset>\n",
      +       "Dimensions:                    (chain: 1, draw: 500, fourier_mode: 4,\n",
      +       "                                date: 179, channel: 2, control: 3)\n",
      +       "Coordinates:\n",
      +       "  * chain                      (chain) int64 0\n",
      +       "  * draw                       (draw) int64 0 1 2 3 4 5 ... 495 496 497 498 499\n",
      +       "  * fourier_mode               (fourier_mode) <U11 'sin_order_1' ... 'cos_ord...\n",
      +       "  * date                       (date) <U10 '2018-04-02' ... '2021-08-30'\n",
      +       "  * channel                    (channel) <U2 'x1' 'x2'\n",
      +       "  * control                    (control) <U7 'event_1' 'event_2' 't'\n",
      +       "Data variables: (12/13)\n",
      +       "    gamma_fourier              (chain, draw, fourier_mode) float64 1.28 ... 1...\n",
      +       "    intercept                  (chain, draw) float64 1.178 -2.005 ... 2.533\n",
      +       "    fourier_contributions      (chain, draw, date, fourier_mode) float64 1.28...\n",
      +       "    mu                         (chain, draw, date) float64 0.8703 ... -44.96\n",
      +       "    channel_contributions      (chain, draw, date, channel) float64 0.4222 .....\n",
      +       "    control_contributions      (chain, draw, date, control) float64 0.0 ... -...\n",
      +       "    ...                         ...\n",
      +       "    gamma_control              (chain, draw, control) float64 1.547 ... -0.2666\n",
      +       "    channel_adstock            (chain, draw, date, channel) float64 0.3087 .....\n",
      +       "    alpha                      (chain, draw, channel) float64 0.03434 ... 0.1455\n",
      +       "    lam                        (chain, draw, channel) float64 2.558 ... 4.149\n",
      +       "    channel_adstock_saturated  (chain, draw, date, channel) float64 0.3755 .....\n",
      +       "    sigma                      (chain, draw) float64 0.05319 1.662 ... 1.212\n",
      +       "Attributes:\n",
      +       "    created_at:                 2023-08-03T11:09:15.159471\n",
      +       "    arviz_version:              0.16.1\n",
      +       "    inference_library:          pymc\n",
      +       "    inference_library_version:  5.6.1

      \n", + "
    \n", + "
    \n", + "
  • \n", + " \n", + "
  • \n", + " \n", + " \n", + "
    \n", + "
    \n", + "
      \n", + "
      \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
      <xarray.Dataset>\n",
      +       "Dimensions:     (chain: 1, draw: 500, date: 179)\n",
      +       "Coordinates:\n",
      +       "  * chain       (chain) int64 0\n",
      +       "  * draw        (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499\n",
      +       "  * date        (date) <U10 '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
      +       "Data variables:\n",
      +       "    likelihood  (chain, draw, date) float64 0.8725 3.274 6.447 ... -45.43 -45.69\n",
      +       "Attributes:\n",
      +       "    created_at:                 2023-08-03T11:09:15.164125\n",
      +       "    arviz_version:              0.16.1\n",
      +       "    inference_library:          pymc\n",
      +       "    inference_library_version:  5.6.1

      \n", + "
    \n", + "
    \n", + "
  • \n", + " \n", + "
  • \n", + " \n", + " \n", + "
    \n", + "
    \n", + "
      \n", + "
      \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
      <xarray.Dataset>\n",
      +       "Dimensions:     (date: 179)\n",
      +       "Coordinates:\n",
      +       "  * date        (date) <U10 '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
      +       "Data variables:\n",
      +       "    likelihood  (date) float64 0.4794 0.4527 0.5374 ... 0.4978 0.5388 0.5625\n",
      +       "Attributes:\n",
      +       "    created_at:                 2023-08-03T11:09:14.043853\n",
      +       "    arviz_version:              0.16.1\n",
      +       "    inference_library:          pymc\n",
      +       "    inference_library_version:  5.6.1

      \n", + "
    \n", + "
    \n", + "
  • \n", + " \n", + "
  • \n", + " \n", + " \n", + "
    \n", + "
    \n", + "
      \n", + "
      \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
      <xarray.Dataset>\n",
      +       "Dimensions:       (date: 179, channel: 2, control: 3, fourier_mode: 4)\n",
      +       "Coordinates:\n",
      +       "  * date          (date) <U10 '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
      +       "  * channel       (channel) <U2 'x1' 'x2'\n",
      +       "  * control       (control) <U7 'event_1' 'event_2' 't'\n",
      +       "  * fourier_mode  (fourier_mode) <U11 'sin_order_1' ... 'cos_order_2'\n",
      +       "Data variables:\n",
      +       "    channel_data  (date, channel) float64 0.3196 0.0 0.1128 ... 0.0 0.4403 0.0\n",
      +       "    target        (date) float64 0.4794 0.4527 0.5374 ... 0.4978 0.5388 0.5625\n",
      +       "    control_data  (date, control) float64 0.0 0.0 0.0 0.0 ... 0.0 0.0 178.0\n",
      +       "    fourier_data  (date, fourier_mode) float64 0.9999 -0.01183 ... -0.4547\n",
      +       "Attributes:\n",
      +       "    created_at:                 2023-08-03T11:09:14.045029\n",
      +       "    arviz_version:              0.16.1\n",
      +       "    inference_library:          pymc\n",
      +       "    inference_library_version:  5.6.1

      \n", + "
    \n", + "
    \n", + "
  • \n", + " \n", + "
  • \n", + " \n", + " \n", + "
    \n", + "
    \n", + "
      \n", + "
      \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
      <xarray.Dataset>\n",
      +       "Dimensions:    (index: 179)\n",
      +       "Coordinates:\n",
      +       "  * index      (index) int64 0 1 2 3 4 5 6 7 ... 171 172 173 174 175 176 177 178\n",
      +       "Data variables:\n",
      +       "    date_week  (index) object '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
      +       "    x1         (index) float64 0.3186 0.1124 0.2924 ... 0.1719 0.2803 0.4389\n",
      +       "    x2         (index) float64 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.8633 0.0 0.0 0.0\n",
      +       "    event_1    (index) float64 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0\n",
      +       "    event_2    (index) float64 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0\n",
      +       "    dayofyear  (index) int64 92 99 106 113 120 127 ... 207 214 221 228 235 242\n",
      +       "    t          (index) int64 0 1 2 3 4 5 6 7 ... 171 172 173 174 175 176 177 178\n",
      +       "    y          (index) float64 3.985e+03 3.763e+03 ... 4.479e+03 4.676e+03

      \n", + "
    \n", + "
    \n", + "
  • \n", + " \n", + "
\n", + " \n", + " " + ], + "text/plain": [ + "Inference data with groups:\n", + "\t> posterior\n", + "\t> posterior_predictive\n", + "\t> sample_stats\n", + "\t> prior\n", + "\t> prior_predictive\n", + "\t> observed_data\n", + "\t> constant_data\n", + "\t> fit_data" + ] + }, + "execution_count": 46, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "mmm.idata" + ] + }, + { + "cell_type": "markdown", + "id": "8b433c7f-0f0d-40b2-bcfb-a19555b528bd", + "metadata": {}, + "source": [ + "## `Save` and `load`" + ] + }, + { + "cell_type": "markdown", + "id": "7b0a35f4", + "metadata": {}, + "source": [ + "All the data passed to the model on initialisation is stored in `idata.attrs`. This will be used later in the `save()` method to convert both this data and all the fit data into the netCDF format." + ] + }, + { + "cell_type": "markdown", + "id": "45948f46", + "metadata": {}, + "source": [ + "Simply specify the path to which you'd like to save your model:" + ] + }, + { + "cell_type": "code", + "execution_count": 47, + "id": "b3abe93a", + "metadata": {}, + "outputs": [], + "source": [ + "mmm.save('my_saved_model.nc')" + ] + }, + { + "cell_type": "markdown", + "id": "8a5eba79", + "metadata": {}, + "source": [ + "And pass it to the `load()` method when it's needed again on the target system:" + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "id": "0421bae8", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/michalraczycki/Documents/pymc-marketing/.conda/envs/pymc-marketing/lib/python3.10/site-packages/arviz/data/inference_data.py:153: UserWarning: fit_data group is not defined in the InferenceData scheme\n", + " warnings.warn(\n" + ] + } + ], + "source": [ + "loaded_model = DelayedSaturatedMMM.load('my_saved_model.nc')" + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "id": "a8b666d3", + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + 
"\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "clusterdate (179) x channel (2)\n", + "\n", + "date (179) x channel (2)\n", + "\n", + "\n", + "clusterdate (179)\n", + "\n", + "date (179)\n", + "\n", + "\n", + "clusterchannel (2)\n", + "\n", + "channel (2)\n", + "\n", + "\n", + "clusterdate (179) x control (3)\n", + "\n", + "date (179) x control (3)\n", + "\n", + "\n", + "clustercontrol (3)\n", + "\n", + "control (3)\n", + "\n", + "\n", + "clusterdate (179) x fourier_mode (4)\n", + "\n", + "date (179) x fourier_mode (4)\n", + "\n", + "\n", + "clusterfourier_mode (4)\n", + "\n", + "fourier_mode (4)\n", + "\n", + "\n", + "\n", + "channel_contributions\n", + "\n", + "channel_contributions\n", + "~\n", + "Deterministic\n", + "\n", + "\n", + "\n", + "mu\n", + "\n", + "mu\n", + "~\n", + "Deterministic\n", + "\n", + "\n", + "\n", + "channel_contributions->mu\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "channel_data\n", + "\n", + "channel_data\n", + "~\n", + "MutableData\n", + "\n", + "\n", + "\n", + "channel_adstock\n", + "\n", + "channel_adstock\n", + "~\n", + "Deterministic\n", + "\n", + "\n", + "\n", + "channel_data->channel_adstock\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "channel_adstock_saturated\n", + "\n", + "channel_adstock_saturated\n", + "~\n", + "Deterministic\n", + "\n", + "\n", + "\n", + "channel_adstock->channel_adstock_saturated\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "channel_adstock_saturated->channel_contributions\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "target\n", + "\n", + "target\n", + "~\n", + "MutableData\n", + "\n", + "\n", + "\n", + "likelihood\n", + "\n", + "likelihood\n", + "~\n", + "Normal\n", + "\n", + "\n", + "\n", + "mu->likelihood\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "likelihood->target\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "intercept\n", + "\n", + "intercept\n", + "~\n", + "Normal\n", + "\n", + "\n", + "\n", + "intercept->mu\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "sigma\n", + "\n", + 
"sigma\n", + "~\n", + "HalfNormal\n", + "\n", + "\n", + "\n", + "sigma->likelihood\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "lam\n", + "\n", + "lam\n", + "~\n", + "Gamma\n", + "\n", + "\n", + "\n", + "lam->channel_adstock_saturated\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "alpha\n", + "\n", + "alpha\n", + "~\n", + "Beta\n", + "\n", + "\n", + "\n", + "alpha->channel_adstock\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "beta_channel\n", + "\n", + "beta_channel\n", + "~\n", + "HalfNormal\n", + "\n", + "\n", + "\n", + "beta_channel->channel_contributions\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "control_data\n", + "\n", + "control_data\n", + "~\n", + "MutableData\n", + "\n", + "\n", + "\n", + "control_contributions\n", + "\n", + "control_contributions\n", + "~\n", + "Deterministic\n", + "\n", + "\n", + "\n", + "control_data->control_contributions\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "control_contributions->mu\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "gamma_control\n", + "\n", + "gamma_control\n", + "~\n", + "Normal\n", + "\n", + "\n", + "\n", + "gamma_control->control_contributions\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "fourier_contributions\n", + "\n", + "fourier_contributions\n", + "~\n", + "Deterministic\n", + "\n", + "\n", + "\n", + "fourier_contributions->mu\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "fourier_data\n", + "\n", + "fourier_data\n", + "~\n", + "MutableData\n", + "\n", + "\n", + "\n", + "fourier_data->fourier_contributions\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "gamma_fourier\n", + "\n", + "gamma_fourier\n", + "~\n", + "Laplace\n", + "\n", + "\n", + "\n", + "gamma_fourier->fourier_contributions\n", + "\n", + "\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 49, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "loaded_model.graphviz()" + ] + }, + { + "cell_type": "code", + "execution_count": 50, + "id": "cfb64a2c", + "metadata": {}, + 
"outputs": [ + { + "data": { + "text/html": [ + "\n", + "
\n", + "
\n", + "
arviz.InferenceData
\n", + "
\n", + "
    \n", + " \n", + "
  • \n", + " \n", + " \n", + "
    \n", + "
    \n", + "
      \n", + "
      \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
      <xarray.Dataset>\n",
      +       "Dimensions:                    (chain: 4, draw: 1000, control: 3,\n",
      +       "                                fourier_mode: 4, channel: 2, date: 179)\n",
      +       "Coordinates:\n",
      +       "  * chain                      (chain) int64 0 1 2 3\n",
      +       "  * draw                       (draw) int64 0 1 2 3 4 5 ... 995 996 997 998 999\n",
      +       "  * control                    (control) object 'event_1' 'event_2' 't'\n",
      +       "  * fourier_mode               (fourier_mode) object 'sin_order_1' ... 'cos_o...\n",
      +       "  * channel                    (channel) object 'x1' 'x2'\n",
      +       "  * date                       (date) object '2018-04-02' ... '2021-08-30'\n",
      +       "Data variables: (12/13)\n",
      +       "    intercept                  (chain, draw) float64 ...\n",
      +       "    gamma_control              (chain, draw, control) float64 ...\n",
      +       "    gamma_fourier              (chain, draw, fourier_mode) float64 ...\n",
      +       "    beta_channel               (chain, draw, channel) float64 ...\n",
      +       "    alpha                      (chain, draw, channel) float64 ...\n",
      +       "    lam                        (chain, draw, channel) float64 ...\n",
      +       "    ...                         ...\n",
      +       "    channel_adstock            (chain, draw, date, channel) float64 ...\n",
      +       "    channel_adstock_saturated  (chain, draw, date, channel) float64 ...\n",
      +       "    channel_contributions      (chain, draw, date, channel) float64 ...\n",
      +       "    control_contributions      (chain, draw, date, control) float64 ...\n",
      +       "    fourier_contributions      (chain, draw, date, fourier_mode) float64 ...\n",
      +       "    mu                         (chain, draw, date) float64 ...\n",
      +       "Attributes:\n",
      +       "    created_at:                 2023-08-03T11:09:14.027598\n",
      +       "    arviz_version:              0.16.1\n",
      +       "    inference_library:          pymc\n",
      +       "    inference_library_version:  5.6.1\n",
      +       "    sampling_time:              29.821417093276978\n",
      +       "    tuning_steps:               1000

      \n", + "
    \n", + "
    \n", + "
  • \n", + " \n", + "
  • \n", + " \n", + " \n", + "
    \n", + "
    \n", + "
      \n", + "
      \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
      <xarray.Dataset>\n",
      +       "Dimensions:     (chain: 4, draw: 1000, date: 179)\n",
      +       "Coordinates:\n",
      +       "  * chain       (chain) int64 0 1 2 3\n",
      +       "  * draw        (draw) int64 0 1 2 3 4 5 6 7 ... 992 993 994 995 996 997 998 999\n",
      +       "  * date        (date) object '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
      +       "Data variables:\n",
      +       "    likelihood  (chain, draw, date) float64 ...\n",
      +       "Attributes:\n",
      +       "    created_at:                 2023-08-03T11:09:15.416164\n",
      +       "    arviz_version:              0.16.1\n",
      +       "    inference_library:          pymc\n",
      +       "    inference_library_version:  5.6.1

      \n", + "
    \n", + "
    \n", + "
  • \n", + " \n", + "
  • \n", + " \n", + " \n", + "
    \n", + "
    \n", + "
      \n", + "
      \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
      <xarray.Dataset>\n",
      +       "Dimensions:                (chain: 4, draw: 1000)\n",
      +       "Coordinates:\n",
      +       "  * chain                  (chain) int64 0 1 2 3\n",
      +       "  * draw                   (draw) int64 0 1 2 3 4 5 ... 994 995 996 997 998 999\n",
      +       "Data variables: (12/17)\n",
      +       "    process_time_diff      (chain, draw) float64 ...\n",
      +       "    step_size_bar          (chain, draw) float64 ...\n",
      +       "    step_size              (chain, draw) float64 ...\n",
      +       "    acceptance_rate        (chain, draw) float64 ...\n",
      +       "    index_in_trajectory    (chain, draw) int64 ...\n",
      +       "    tree_depth             (chain, draw) int64 ...\n",
      +       "    ...                     ...\n",
      +       "    energy_error           (chain, draw) float64 ...\n",
      +       "    perf_counter_diff      (chain, draw) float64 ...\n",
      +       "    n_steps                (chain, draw) float64 ...\n",
      +       "    diverging              (chain, draw) bool ...\n",
      +       "    perf_counter_start     (chain, draw) float64 ...\n",
      +       "    lp                     (chain, draw) float64 ...\n",
      +       "Attributes:\n",
      +       "    created_at:                 2023-08-03T11:09:14.040220\n",
      +       "    arviz_version:              0.16.1\n",
      +       "    inference_library:          pymc\n",
      +       "    inference_library_version:  5.6.1\n",
      +       "    sampling_time:              29.821417093276978\n",
      +       "    tuning_steps:               1000

      \n", + "
    \n", + "
    \n", + "
  • \n", + " \n", + "
  • \n", + " \n", + " \n", + "
    \n", + "
    \n", + "
      \n", + "
      \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
      <xarray.Dataset>\n",
      +       "Dimensions:                    (chain: 1, draw: 500, fourier_mode: 4,\n",
      +       "                                date: 179, channel: 2, control: 3)\n",
      +       "Coordinates:\n",
      +       "  * chain                      (chain) int64 0\n",
      +       "  * draw                       (draw) int64 0 1 2 3 4 5 ... 495 496 497 498 499\n",
      +       "  * fourier_mode               (fourier_mode) object 'sin_order_1' ... 'cos_o...\n",
      +       "  * date                       (date) object '2018-04-02' ... '2021-08-30'\n",
      +       "  * channel                    (channel) object 'x1' 'x2'\n",
      +       "  * control                    (control) object 'event_1' 'event_2' 't'\n",
      +       "Data variables: (12/13)\n",
      +       "    gamma_fourier              (chain, draw, fourier_mode) float64 ...\n",
      +       "    intercept                  (chain, draw) float64 ...\n",
      +       "    fourier_contributions      (chain, draw, date, fourier_mode) float64 ...\n",
      +       "    mu                         (chain, draw, date) float64 ...\n",
      +       "    channel_contributions      (chain, draw, date, channel) float64 ...\n",
      +       "    control_contributions      (chain, draw, date, control) float64 ...\n",
      +       "    ...                         ...\n",
      +       "    gamma_control              (chain, draw, control) float64 ...\n",
      +       "    channel_adstock            (chain, draw, date, channel) float64 ...\n",
      +       "    alpha                      (chain, draw, channel) float64 ...\n",
      +       "    lam                        (chain, draw, channel) float64 ...\n",
      +       "    channel_adstock_saturated  (chain, draw, date, channel) float64 ...\n",
      +       "    sigma                      (chain, draw) float64 ...\n",
      +       "Attributes:\n",
      +       "    created_at:                 2023-08-03T11:09:15.159471\n",
      +       "    arviz_version:              0.16.1\n",
      +       "    inference_library:          pymc\n",
      +       "    inference_library_version:  5.6.1

      \n", + "
    \n", + "
    \n", + "
  • \n", + " \n", + "
  • \n", + " \n", + " \n", + "
    \n", + "
    \n", + "
      \n", + "
      \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
      <xarray.Dataset>\n",
      +       "Dimensions:     (chain: 1, draw: 500, date: 179)\n",
      +       "Coordinates:\n",
      +       "  * chain       (chain) int64 0\n",
      +       "  * draw        (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499\n",
      +       "  * date        (date) object '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
      +       "Data variables:\n",
      +       "    likelihood  (chain, draw, date) float64 ...\n",
      +       "Attributes:\n",
      +       "    created_at:                 2023-08-03T11:09:15.164125\n",
      +       "    arviz_version:              0.16.1\n",
      +       "    inference_library:          pymc\n",
      +       "    inference_library_version:  5.6.1

      \n", + "
    \n", + "
    \n", + "
  • \n", + " \n", + "
  • \n", + " \n", + " \n", + "
    \n", + "
    \n", + "
      \n", + "
      \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
      <xarray.Dataset>\n",
      +       "Dimensions:     (date: 179)\n",
      +       "Coordinates:\n",
      +       "  * date        (date) object '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
      +       "Data variables:\n",
      +       "    likelihood  (date) float64 ...\n",
      +       "Attributes:\n",
      +       "    created_at:                 2023-08-03T11:09:14.043853\n",
      +       "    arviz_version:              0.16.1\n",
      +       "    inference_library:          pymc\n",
      +       "    inference_library_version:  5.6.1

      \n", + "
    \n", + "
    \n", + "
  • \n", + " \n", + "
  • \n", + " \n", + " \n", + "
    \n", + "
    \n", + "
      \n", + "
      \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
      <xarray.Dataset>\n",
      +       "Dimensions:       (date: 179, channel: 2, control: 3, fourier_mode: 4)\n",
      +       "Coordinates:\n",
      +       "  * date          (date) object '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
      +       "  * channel       (channel) object 'x1' 'x2'\n",
      +       "  * control       (control) object 'event_1' 'event_2' 't'\n",
      +       "  * fourier_mode  (fourier_mode) object 'sin_order_1' ... 'cos_order_2'\n",
      +       "Data variables:\n",
      +       "    channel_data  (date, channel) float64 ...\n",
      +       "    target        (date) float64 ...\n",
      +       "    control_data  (date, control) float64 ...\n",
      +       "    fourier_data  (date, fourier_mode) float64 ...\n",
      +       "Attributes:\n",
      +       "    created_at:                 2023-08-03T11:09:14.045029\n",
      +       "    arviz_version:              0.16.1\n",
      +       "    inference_library:          pymc\n",
      +       "    inference_library_version:  5.6.1

      \n", + "
    \n", + "
    \n", + "
  • \n", + " \n", + "
  • \n", + " \n", + " \n", + "
    \n", + "
    \n", + "
      \n", + "
      \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
      <xarray.Dataset>\n",
      +       "Dimensions:    (index: 179)\n",
      +       "Coordinates:\n",
      +       "  * index      (index) int64 0 1 2 3 4 5 6 7 ... 171 172 173 174 175 176 177 178\n",
      +       "Data variables:\n",
      +       "    date_week  (index) object '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
      +       "    x1         (index) float64 0.3186 0.1124 0.2924 ... 0.1719 0.2803 0.4389\n",
      +       "    x2         (index) float64 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.8633 0.0 0.0 0.0\n",
      +       "    event_1    (index) float64 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0\n",
      +       "    event_2    (index) float64 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0\n",
      +       "    dayofyear  (index) int64 92 99 106 113 120 127 ... 207 214 221 228 235 242\n",
      +       "    t          (index) int64 0 1 2 3 4 5 6 7 ... 171 172 173 174 175 176 177 178\n",
      +       "    y          (index) float64 3.985e+03 3.763e+03 ... 4.479e+03 4.676e+03

      \n", + "
    \n", + "
    \n", + "
  • \n", + " \n", + "
\n", + "
\n", + " " + ], + "text/plain": [ + "Inference data with groups:\n", + "\t> posterior\n", + "\t> posterior_predictive\n", + "\t> sample_stats\n", + "\t> prior\n", + "\t> prior_predictive\n", + "\t> observed_data\n", + "\t> constant_data\n", + "\t> fit_data" + ] + }, + "execution_count": 50, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "loaded_model.idata" + ] + }, + { + "cell_type": "markdown", + "id": "ab64be46-7fe5-4f39-b72b-36da1419f809", + "metadata": {}, + "source": [ + "A model loaded in this way is ready to be used for sampling and prediction, and has access to all previous samples and data." + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "id": "dd59d056-6ac7-431c-85ee-99e7a8eefd8a", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Sampling: [likelihood]\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "
\n", + " \n", + " 100.00% [4000/4000 00:00<00:00]\n", + "
\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.Dataset>\n",
+       "Dimensions:     (chain: 4, draw: 1000, date: 179)\n",
+       "Coordinates:\n",
+       "  * chain       (chain) int64 0 1 2 3\n",
+       "  * draw        (draw) int64 0 1 2 3 4 5 6 7 ... 992 993 994 995 996 997 998 999\n",
+       "  * date        (date) <U10 '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
+       "Data variables:\n",
+       "    likelihood  (chain, draw, date) float64 0.4907 0.4282 ... 0.5548 0.5396\n",
+       "Attributes:\n",
+       "    created_at:                 2023-08-03T11:09:22.139450\n",
+       "    arviz_version:              0.16.1\n",
+       "    inference_library:          pymc\n",
+       "    inference_library_version:  5.6.1
" + ], + "text/plain": [ + "\n", + "Dimensions: (chain: 4, draw: 1000, date: 179)\n", + "Coordinates:\n", + " * chain (chain) int64 0 1 2 3\n", + " * draw (draw) int64 0 1 2 3 4 5 6 7 ... 992 993 994 995 996 997 998 999\n", + " * date (date) " + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "az.plot_ppc(loaded_model.idata);" + ] + }, + { + "cell_type": "markdown", + "id": "e8e807f9", + "metadata": {}, + "source": [ + "## Summary:" + ] + }, + { + "cell_type": "markdown", + "id": "61f232c1", + "metadata": {}, + "source": [ + "In summary, this article introduces the revolutionary ModelBuilder, a new PyMC-experimental module that simplifies the deployment of PyMC Bayesian models. It addresses a historic challenge faced by users of PyMC and most PPLs by offering a user-friendly and efficient approach to model deployment. The ModelBuilder provides two straightforward methods, save() and load(), which streamline the model preservation and replication process post fitting. Users are offered flexibility in controlling the prior settings with model_config and customizing the sampling process via sampler_config.\n", + "\n", + "The use of an example model from the MMM Example Notebook demonstrates the practical implementation of ModelBuilder, emphasizing its ability to enhance model sharing among teams without the necessity for extensive domain knowledge about the model. The deployment improvements in PyMC-Marketing brought about by ModelBuilder are not only user-friendly but also significantly enhance efficiency, making PyMC models more accessible for a wider audience." + ] + }, + { + "cell_type": "markdown", + "id": "b8ad333d", + "metadata": {}, + "source": [ + "Even though this introduction is using `DelayedSaturatedMMM`, functionalities from `ModelBuilder` are available in the CLV models as well." 
+ ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "pymc-marketing", + "language": "python", + "name": "pymc-marketing" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From 414da20cb7ce5c8cdb07adc4f00ff1be6d72ada7 Mon Sep 17 00:00:00 2001 From: Michal Raczycki Date: Thu, 17 Aug 2023 16:11:13 +0200 Subject: [PATCH 2/5] notebook formatting --- .../howto/ModelBuilder_usage_example.ipynb | 17466 ++++++++++++++++ .../howto/ModelBuilder_usage_example.myst.md | 279 + 2 files changed, 17745 insertions(+) create mode 100644 examples/howto/ModelBuilder_usage_example.ipynb create mode 100644 examples/howto/ModelBuilder_usage_example.myst.md diff --git a/examples/howto/ModelBuilder_usage_example.ipynb b/examples/howto/ModelBuilder_usage_example.ipynb new file mode 100644 index 000000000..d9a409205 --- /dev/null +++ b/examples/howto/ModelBuilder_usage_example.ipynb @@ -0,0 +1,17466 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "ad9920da-e687-408d-b99f-060a99c0b561", + "metadata": {}, + "source": [ + "(ModelBuilder usage example)=\n", + "# ModelBuilder usage example\n", + "\n", + ":::{post} Aug 18, 2023\n", + ":tags: ModelBuilder, model deployment,\n", + ":category: intermediate, tutorial\n", + ":author: Michał Raczycki\n", + ":::" + ] + }, + { + "cell_type": "markdown", + "id": "1d74584c", + "metadata": {}, + "source": [ + "# Deploying MMMs and CLVs in Production: Saving and Loading Models" + ] + }, + { + "cell_type": "markdown", + "id": "3222ef04", + "metadata": {}, + "source": [ + "In this article, we'll tackle the historically challenging process of deploying Bayesian models built with PyMC. 
Introducing a revolutionary deployment module, we bring unprecedented simplicity and efficiency to the deployment of PyMC models. As we prioritize user-friendly solutions, let's delve into how this innovation can significantly elevate your data science projects." + ] + }, + { + "cell_type": "markdown", + "id": "ddb28436", + "metadata": {}, + "source": [ + "\n", + "The recent release of PyMC-Marketing by [PyMC Labs](https://www.pymc-labs.io) has proven to be a big hit [(PyMC-Marketing)](https://www.pymc-labs.io/blog-posts/pymc-marketing-a-bayesian-approach-to-marketing-data-science/). One theme kept recurring in the feedback: many of you have been asking for an easy and robust way to deploy models to production. This has been a long-standing problem with PyMC (and most other Probabilistic Programming Languages): there’s no obvious way to do it, and whichever approach you try proves to be tricky. That is why we’re happy to announce the release of `ModelBuilder`, a brand new PyMC-experimental module that addresses this need and significantly improves the deployment process.\n", + "\n", + "The ModelBuilder module is a new feature of PyMC-based models. It provides two easy-to-use methods, save() and load(), that can be used after the model has been fit. save() preserves the model in .netcdf format, and load() gives one-line replication of the original model. Users can control the prior settings with model_config and customize the sampling process using sampler_config.
The default values work just fine, so the first time give it a try without changing them, and provide your own model_config and sampler_config afterwards if you want to customize the model further for your use case!\n" + ] + }, + { + "cell_type": "markdown", + "id": "a808e36a", + "metadata": {}, + "source": [ + "For this notebook I'll use the example model from the [MMM Example Notebook](https://www.pymc-marketing.io/en/stable/notebooks/mmm/mmm_example.html), but omit the details of data generation and plotting, since they're out of scope for this introduction. I highly recommend looking at that part as well, but for now let's focus on today's topic: groundbreaking deployment improvements in PyMC-Marketing!" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "1050a937", + "metadata": {}, + "outputs": [], + "source": [ + "import arviz as az\n", + "import numpy as np\n", + "import pandas as pd\n", + "\n", + "from pymc_marketing.mmm import DelayedSaturatedMMM" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "e2c3f4c4-1d74-4ae2-9ff2-c13cbcf7fe54", + "metadata": {}, + "outputs": [], + "source": [ + "az.style.use(\"arviz-darkgrid\")" + ] + }, + { + "cell_type": "markdown", + "id": "f37d808e", + "metadata": {}, + "source": [ + "Let's load the dataset:" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "b7b1193f", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + " date_week y x1 x2 event_1 event_2 dayofyear \\\n", + "0 2018-04-02 3984.662237 0.318580 0.000000 0.0 0.0 92 \n", + "1 2018-04-09 3762.871794 0.112388 0.000000 0.0 0.0 99 \n", + "2 2018-04-16 4466.967388 0.292400 0.000000 0.0 0.0 106 \n", + "3 2018-04-23 3864.219373 0.071399 0.000000 0.0 0.0 113 \n", + "4 2018-04-30 4441.625278 0.386745 0.000000 0.0 0.0 120 \n", + ".. ... ... ... ... ... ... ... \n", + "174 2021-08-02 3553.546148 0.033024 0.000000 0.0 0.0 214 \n", + "175 2021-08-09 5565.509682 0.165615 0.863349 0.0 0.0 221 \n", + "176 2021-08-16 4137.651485 0.171882 0.000000 0.0 0.0 228 \n", + "177 2021-08-23 4479.041351 0.280257 0.000000 0.0 0.0 235 \n", + "178 2021-08-30 4675.973439 0.438857 0.000000 0.0 0.0 242 \n", + "\n", + " t sin_order_1 cos_order_1 sin_order_2 cos_order_2 \n", + "0 0 0.999930 -0.011826 -0.023651 -0.999720 \n", + "1 1 0.991269 -0.131859 -0.261414 -0.965227 \n", + "2 2 0.968251 -0.249981 -0.484089 -0.875019 \n", + "3 3 0.931210 -0.364483 -0.678820 -0.734304 \n", + "4 4 0.880683 -0.473706 -0.834370 -0.551205 \n", + ".. ... ... ... ... ... 
\n", + "174 174 -0.513901 -0.857849 0.881699 0.471812 \n", + "175 175 -0.613230 -0.789905 0.968786 0.247898 \n", + "176 176 -0.703677 -0.710520 0.999953 0.009676 \n", + "177 177 -0.783934 -0.620844 0.973402 -0.229104 \n", + "178 178 -0.852837 -0.522178 0.890665 -0.454661 \n", + "\n", + "[179 rows x 12 columns]" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "url = \"https://raw.githubusercontent.com/pymc-labs/pymc-marketing/main/datasets/mmm_example.csv\"\n", + "df = pd.read_csv(url)\n", + "df" + ] + }, + { + "cell_type": "markdown", + "id": "87deb70d", + "metadata": {}, + "source": [ + "For our model we need a much smaller dataset. Many of the original features were only used to generate others; now that our target variable is computed, we can filter out the columns we no longer need:" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "52b6d127", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + " date_week y x1 x2 event_1 event_2 dayofyear t\n", + "0 2018-04-02 3984.662237 0.318580 0.0 0.0 0.0 92 0\n", + "1 2018-04-09 3762.871794 0.112388 0.0 0.0 0.0 99 1\n", + "2 2018-04-16 4466.967388 0.292400 0.0 0.0 0.0 106 2\n", + "3 2018-04-23 3864.219373 0.071399 0.0 0.0 0.0 113 3\n", + "4 2018-04-30 4441.625278 0.386745 0.0 0.0 0.0 120 4" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "columns_to_keep = [\n", + " \"date_week\",\n", + " \"y\",\n", + " \"x1\",\n", + " \"x2\",\n", + " \"event_1\",\n", + " \"event_2\",\n", + " \"dayofyear\",\n", + "]\n", + "seed: int = sum(map(ord, \"mmm\"))\n", + "rng = np.random.default_rng(seed=seed)\n", + "\n", + "data = df[columns_to_keep].copy()\n", + "\n", + "data[\"t\"] = range(df.shape[0])\n", + "data.head()" + ] + }, + { + "cell_type": "markdown", + "id": "9518a885", + "metadata": {}, + "source": [ + "## _Model Creation_\n", + "After we have our dataset ready, we could proceed straight to our model definition, but first to show the full potential of one of the new features: `model_config` we need to use some of our data to define our prior for sigma parameter for each of the channels. `model_config` is a customizable dictionary with keys corresponding to priors within the model, and values containing a dictionaries with parameters necessary to initialize them. Later on we'll learn that through the `save()` method we can preserve our priors contained inside the `model_config`, to allow complete replication of our model." + ] + }, + { + "cell_type": "markdown", + "id": "4b52b2c1", + "metadata": {}, + "source": [ + "### model_config" + ] + }, + { + "cell_type": "markdown", + "id": "41021a72", + "metadata": {}, + "source": [ + "`default_model_config` attribute of every model inheriting from `ModelBuilder` will allow you to see which priors are available for customization. 
To see it simply initialize a dummy model:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "284bd558", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'intercept': {'mu': 0, 'sigma': 2},\n", + " 'beta_channel': {'sigma': 2, 'dims': ('channel',)},\n", + " 'alpha': {'alpha': 1, 'beta': 3, 'dims': ('channel',)},\n", + " 'lam': {'alpha': 3, 'beta': 1, 'dims': ('channel',)},\n", + " 'sigma': {'sigma': 2},\n", + " 'gamma_control': {'mu': 0, 'sigma': 2, 'dims': ('control',)},\n", + " 'mu': {'dims': ('date',)},\n", + " 'likelihood': {'dims': ('date',)},\n", + " 'gamma_fourier': {'mu': 0, 'b': 1, 'dims': 'fourier_mode'}}" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dummy_model = DelayedSaturatedMMM(date_column=\"\", channel_columns=\"\", adstock_max_lag=4)\n", + "dummy_model.default_model_config" + ] + }, + { + "cell_type": "markdown", + "id": "f0fd248f", + "metadata": {}, + "source": [ + "You can change only the prior parameters that you wish, no need to alter all of them, unless you'd like to!\n", + "In this case we'll just simply replace our sigma for beta_channel with our computed one:" + ] + }, + { + "cell_type": "markdown", + "id": "19f075f0-4d3d-4509-a9c6-f15efdb9293d", + "metadata": {}, + "source": [ + "First, let's compute the share of spend per channel:" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "4785596a-e333-4cd0-af15-1332e97b66d5", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "x1 0.65632\n", + "x2 0.34368\n", + "dtype: float64" + ] + }, + "execution_count": 6, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "total_spend_per_channel = data[[\"x1\", \"x2\"]].sum(axis=0)\n", + "\n", + "spend_share = total_spend_per_channel / total_spend_per_channel.sum()\n", + "\n", + "spend_share" + ] + }, + { + "cell_type": "markdown", + "id": "40d17642-1e21-4adc-97f4-633eede87915", + 
"metadata": {}, + "source": [ + "Next, we specify the `sigma`parameter per channel:" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "672b36bd-3d08-46df-85b8-d67d3ade75d7", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[2.1775326025486734, 1.140260877391939]" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# The scale necessary to make a HalfNormal distribution have unit variance\n", + "HALFNORMAL_SCALE = 1 / np.sqrt(1 - 2 / np.pi)\n", + "\n", + "n_channels = 2\n", + "\n", + "prior_sigma = HALFNORMAL_SCALE * n_channels * spend_share.to_numpy()\n", + "\n", + "prior_sigma.tolist()" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "bac9f587", + "metadata": {}, + "outputs": [], + "source": [ + "custom_beta_channel_prior = {\"beta_channel\": {\"sigma\": prior_sigma, \"dims\": (\"channel\",)}}\n", + "my_model_config = dummy_model.default_model_config | custom_beta_channel_prior" + ] + }, + { + "cell_type": "markdown", + "id": "1aa435bf", + "metadata": {}, + "source": [ + "As mentioned in the original notebook: \"_For the prior specification there is no right or wrong answer. It all depends on the data, the context and the assumptions you are willing to make. It is always recommended to do some prior predictive sampling and sensitivity analysis to check the impact of the priors on the posterior. We skip this here for the sake of simplicity. If you are not sure about specific priors, the `DelayedSaturatedMMM` class has some default priors that you can use as a starting point._\"" + ] + }, + { + "cell_type": "markdown", + "id": "f195a79e", + "metadata": {}, + "source": [ + "The second feature that we can use for model definition is `sampler_config`. Similar to `model_config`, it's a dictionary that gets saved and contains things you'd usually pass to the `fit()` kwargs. 
It's not mandatory to create your own `sampler_config`; if not provided, both `model_config` and `sampler_config` will default to the forms specified by PyMC Labs experts, which allows for the usage of all model functionalities. The default `sampler_config` is left empty because the default sampling parameters usually prove sufficient for a start." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "0ab8140c", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{}" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dummy_model.default_sampler_config" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "bf5a50f4", + "metadata": {}, + "outputs": [], + "source": [ + "my_sampler_config = {\n", + " \"tune\": 1000,\n", + " \"draws\": 1000,\n", + " \"chains\": 4,\n", + " \"target_accept\": 0.95,\n", + "}" + ] + }, + { + "cell_type": "markdown", + "id": "f3bfe090", + "metadata": {}, + "source": [ + "Let's finally assemble our model!" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "c7bd6909", + "metadata": {}, + "outputs": [], + "source": [ + "mmm = DelayedSaturatedMMM(\n", + " model_config=my_model_config,\n", + " sampler_config=my_sampler_config,\n", + " date_column=\"date_week\",\n", + " channel_columns=[\"x1\", \"x2\"],\n", + " control_columns=[\n", + " \"event_1\",\n", + " \"event_2\",\n", + " \"t\",\n", + " ],\n", + " adstock_max_lag=8,\n", + " yearly_seasonality=2,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "54095b1a", + "metadata": {}, + "source": [ + "An important thing to note here is that in the new version of `DelayedSaturatedMMM`, we don't pass our dataset to the class constructor itself. This is due to a reason I've mentioned before - it supports `sklearn` transformers and validations that require a usual X, y split and typically expect the data to be passed to the `fit()` method." 
+ ] + }, + { + "cell_type": "markdown", + "id": "dec9b1b0", + "metadata": {}, + "source": [ + "## _Model Fitting_" + ] + }, + { + "cell_type": "markdown", + "id": "d5e64562-ba78-4497-a0f8-123b4bc88b79", + "metadata": {}, + "source": [ + "Let's split the dataset:" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "ff23006b-a55b-4a22-9f34-4eeaddf47486", + "metadata": {}, + "outputs": [], + "source": [ + "X = data.drop(\"y\", axis=1)\n", + "y = data[\"y\"]" + ] + }, + { + "cell_type": "markdown", + "id": "403e3ed2", + "metadata": {}, + "source": [ + "All that's left now is to finally fit the model:\n", + "\n", + "As you can see below, you can still pass the sampler kwargs directly to the `fit()` method. However, only the kwargs passed via `sampler_config` will be saved, so only those will be available after loading the model." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "0f6ab0a8", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Auto-assigning NUTS sampler...\n", + "Initializing NUTS using jitter+adapt_diag...\n", + "Multiprocess sampling (4 chains in 4 jobs)\n", + "NUTS: [intercept, beta_channel, alpha, lam, sigma, gamma_control, gamma_fourier]\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "
\n", + " \n", + " 100.00% [8000/8000 00:27<00:00 Sampling 4 chains, 0 divergences]\n", + "
\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 28 seconds.\n", + "Sampling: [alpha, beta_channel, gamma_control, gamma_fourier, intercept, lam, likelihood, sigma]\n", + "Sampling: [likelihood]\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "
\n", + " \n", + " 100.00% [4000/4000 00:00<00:00]\n", + "
\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "
\n", + "
\n", + "
arviz.InferenceData
\n", + "
\n", + "
    \n", + " \n", + "
  • \n", + " \n", + " \n", + "
    \n", + "
    \n", + "
      \n", + "
      \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
      <xarray.Dataset>\n",
      +       "Dimensions:                    (chain: 4, draw: 1000, control: 3,\n",
      +       "                                fourier_mode: 4, channel: 2, date: 179)\n",
      +       "Coordinates:\n",
      +       "  * chain                      (chain) int64 0 1 2 3\n",
      +       "  * draw                       (draw) int64 0 1 2 3 4 5 ... 995 996 997 998 999\n",
      +       "  * control                    (control) <U7 'event_1' 'event_2' 't'\n",
      +       "  * fourier_mode               (fourier_mode) <U11 'sin_order_1' ... 'cos_ord...\n",
      +       "  * channel                    (channel) <U2 'x1' 'x2'\n",
      +       "  * date                       (date) <U10 '2018-04-02' ... '2021-08-30'\n",
      +       "Data variables: (12/13)\n",
      +       "    intercept                  (chain, draw) float64 0.3381 0.3361 ... 0.3479\n",
      +       "    gamma_control              (chain, draw, control) float64 0.2968 ... 0.00...\n",
      +       "    gamma_fourier              (chain, draw, fourier_mode) float64 -0.002612 ...\n",
      +       "    beta_channel               (chain, draw, channel) float64 0.3718 ... 0.2628\n",
      +       "    alpha                      (chain, draw, channel) float64 0.4162 ... 0.2007\n",
      +       "    lam                        (chain, draw, channel) float64 4.26 ... 2.758\n",
      +       "    ...                         ...\n",
      +       "    channel_adstock            (chain, draw, date, channel) float64 0.1868 .....\n",
      +       "    channel_adstock_saturated  (chain, draw, date, channel) float64 0.3781 .....\n",
      +       "    channel_contributions      (chain, draw, date, channel) float64 0.1406 .....\n",
      +       "    control_contributions      (chain, draw, date, control) float64 0.0 ... 0...\n",
      +       "    fourier_contributions      (chain, draw, date, fourier_mode) float64 -0.0...\n",
      +       "    mu                         (chain, draw, date) float64 0.4769 ... 0.5924\n",
      +       "Attributes:\n",
      +       "    created_at:                 2023-08-17T13:57:12.916984\n",
      +       "    arviz_version:              0.16.1\n",
      +       "    inference_library:          pymc\n",
      +       "    inference_library_version:  5.6.1\n",
      +       "    sampling_time:              27.989259004592896\n",
      +       "    tuning_steps:               1000

      \n", + "
    \n", + "
    \n", + "
  • \n", + " \n", + "
  • \n", + " \n", + " \n", + "
    \n", + "
    \n", + "
      \n", + "
      \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
      <xarray.Dataset>\n",
      +       "Dimensions:     (chain: 4, draw: 1000, date: 179)\n",
      +       "Coordinates:\n",
      +       "  * chain       (chain) int64 0 1 2 3\n",
      +       "  * draw        (draw) int64 0 1 2 3 4 5 6 7 ... 992 993 994 995 996 997 998 999\n",
      +       "  * date        (date) <U10 '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
      +       "Data variables:\n",
      +       "    likelihood  (chain, draw, date) float64 0.5148 0.4369 ... 0.4858 0.6176\n",
      +       "Attributes:\n",
      +       "    created_at:                 2023-08-17T13:57:14.741457\n",
      +       "    arviz_version:              0.16.1\n",
      +       "    inference_library:          pymc\n",
      +       "    inference_library_version:  5.6.1

      \n", + "
    \n", + "
    \n", + "
  • \n", + " \n", + "
  • \n", + " \n", + " \n", + "
    \n", + "
    \n", + "
      \n", + "
      \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
      <xarray.Dataset>\n",
      +       "Dimensions:                (chain: 4, draw: 1000)\n",
      +       "Coordinates:\n",
      +       "  * chain                  (chain) int64 0 1 2 3\n",
      +       "  * draw                   (draw) int64 0 1 2 3 4 5 ... 994 995 996 997 998 999\n",
      +       "Data variables: (12/17)\n",
      +       "    step_size              (chain, draw) float64 0.05862 0.05862 ... 0.07543\n",
      +       "    lp                     (chain, draw) float64 355.1 355.2 ... 355.6 351.6\n",
      +       "    n_steps                (chain, draw) float64 63.0 63.0 63.0 ... 63.0 63.0\n",
      +       "    largest_eigval         (chain, draw) float64 nan nan nan nan ... nan nan nan\n",
      +       "    energy                 (chain, draw) float64 -344.2 -349.9 ... -350.2 -347.5\n",
      +       "    acceptance_rate        (chain, draw) float64 0.9987 0.9934 ... 0.8931 0.9167\n",
      +       "    ...                     ...\n",
      +       "    diverging              (chain, draw) bool False False False ... False False\n",
      +       "    smallest_eigval        (chain, draw) float64 nan nan nan nan ... nan nan nan\n",
      +       "    max_energy_error       (chain, draw) float64 -0.06059 -0.06074 ... 0.284\n",
      +       "    index_in_trajectory    (chain, draw) int64 -12 20 43 -17 ... 42 20 -43 -29\n",
      +       "    energy_error           (chain, draw) float64 -0.01181 -0.02828 ... -0.01478\n",
      +       "    tree_depth             (chain, draw) int64 6 6 6 6 7 6 6 7 ... 6 5 5 6 7 6 6\n",
      +       "Attributes:\n",
      +       "    created_at:                 2023-08-17T13:57:12.932375\n",
      +       "    arviz_version:              0.16.1\n",
      +       "    inference_library:          pymc\n",
      +       "    inference_library_version:  5.6.1\n",
      +       "    sampling_time:              27.989259004592896\n",
      +       "    tuning_steps:               1000

      \n", + "
    \n", + "
    \n", + "
  • \n", + " \n", + "
  • \n", + " \n", + " \n", + "
    \n", + "
    \n", + "
      \n", + "
      \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
      <xarray.Dataset>\n",
      +       "Dimensions:                    (chain: 1, draw: 500, channel: 2, control: 3,\n",
      +       "                                date: 179, fourier_mode: 4)\n",
      +       "Coordinates:\n",
      +       "  * chain                      (chain) int64 0\n",
      +       "  * draw                       (draw) int64 0 1 2 3 4 5 ... 495 496 497 498 499\n",
      +       "  * channel                    (channel) <U2 'x1' 'x2'\n",
      +       "  * control                    (control) <U7 'event_1' 'event_2' 't'\n",
      +       "  * date                       (date) <U10 '2018-04-02' ... '2021-08-30'\n",
      +       "  * fourier_mode               (fourier_mode) <U11 'sin_order_1' ... 'cos_ord...\n",
      +       "Data variables: (12/13)\n",
      +       "    beta_channel               (chain, draw, channel) float64 0.6143 ... 0.9639\n",
      +       "    sigma                      (chain, draw) float64 2.034 5.255 ... 0.07125\n",
      +       "    gamma_control              (chain, draw, control) float64 -0.9767 ... 0.0...\n",
      +       "    control_contributions      (chain, draw, date, control) float64 -0.0 ... ...\n",
      +       "    mu                         (chain, draw, date) float64 5.528 7.076 ... 8.396\n",
      +       "    channel_adstock_saturated  (chain, draw, date, channel) float64 0.3626 .....\n",
      +       "    ...                         ...\n",
      +       "    fourier_contributions      (chain, draw, date, fourier_mode) float64 1.64...\n",
      +       "    lam                        (chain, draw, channel) float64 2.667 ... 1.438\n",
      +       "    intercept                  (chain, draw) float64 2.51 -3.898 ... -2.654\n",
      +       "    gamma_fourier              (chain, draw, fourier_mode) float64 1.647 ... ...\n",
      +       "    channel_adstock            (chain, draw, date, channel) float64 0.2848 .....\n",
      +       "    alpha                      (chain, draw, channel) float64 0.1089 ... 0.2786\n",
      +       "Attributes:\n",
      +       "    created_at:                 2023-08-17T13:57:14.422347\n",
      +       "    arviz_version:              0.16.1\n",
      +       "    inference_library:          pymc\n",
      +       "    inference_library_version:  5.6.1

      \n", + "
    \n", + "
    \n", + "
  • \n", + " \n", + "
  • \n", + " \n", + " \n", + "
    \n", + "
    \n", + "
      \n", + "
      \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
      <xarray.Dataset>\n",
      +       "Dimensions:     (chain: 1, draw: 500, date: 179)\n",
      +       "Coordinates:\n",
      +       "  * chain       (chain) int64 0\n",
      +       "  * draw        (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499\n",
      +       "  * date        (date) <U10 '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
      +       "Data variables:\n",
      +       "    likelihood  (chain, draw, date) float64 5.221 8.352 5.313 ... 7.793 8.528\n",
      +       "Attributes:\n",
      +       "    created_at:                 2023-08-17T13:57:14.428434\n",
      +       "    arviz_version:              0.16.1\n",
      +       "    inference_library:          pymc\n",
      +       "    inference_library_version:  5.6.1

      \n", + "
    \n", + "
    \n", + "
  • \n", + " \n", + "
  • \n", + " \n", + " \n", + "
    \n", + "
    \n", + "
      \n", + "
      \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
      <xarray.Dataset>\n",
      +       "Dimensions:     (date: 179)\n",
      +       "Coordinates:\n",
      +       "  * date        (date) <U10 '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
      +       "Data variables:\n",
      +       "    likelihood  (date) float64 0.4794 0.4527 0.5374 ... 0.4978 0.5388 0.5625\n",
      +       "Attributes:\n",
      +       "    created_at:                 2023-08-17T13:57:12.937052\n",
      +       "    arviz_version:              0.16.1\n",
      +       "    inference_library:          pymc\n",
      +       "    inference_library_version:  5.6.1

      \n", + "
    \n", + "
    \n", + "
  • \n", + " \n", + "
  • \n", + " \n", + " \n", + "
    \n", + "
    \n", + "
      \n", + "
      \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
      <xarray.Dataset>\n",
      +       "Dimensions:       (date: 179, channel: 2, control: 3, fourier_mode: 4)\n",
      +       "Coordinates:\n",
      +       "  * date          (date) <U10 '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
      +       "  * channel       (channel) <U2 'x1' 'x2'\n",
      +       "  * control       (control) <U7 'event_1' 'event_2' 't'\n",
      +       "  * fourier_mode  (fourier_mode) <U11 'sin_order_1' ... 'cos_order_2'\n",
      +       "Data variables:\n",
      +       "    channel_data  (date, channel) float64 0.3196 0.0 0.1128 ... 0.0 0.4403 0.0\n",
      +       "    target        (date) float64 0.4794 0.4527 0.5374 ... 0.4978 0.5388 0.5625\n",
      +       "    control_data  (date, control) float64 0.0 0.0 0.0 0.0 ... 0.0 0.0 178.0\n",
      +       "    fourier_data  (date, fourier_mode) float64 0.9999 -0.01183 ... -0.4547\n",
      +       "Attributes:\n",
      +       "    created_at:                 2023-08-17T13:57:12.938885\n",
      +       "    arviz_version:              0.16.1\n",
      +       "    inference_library:          pymc\n",
      +       "    inference_library_version:  5.6.1

      \n", + "
    \n", + "
    \n", + "
  • \n", + " \n", + "
  • \n", + " \n", + " \n", + "
    \n", + "
    \n", + "
      \n", + "
      \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
      <xarray.Dataset>\n",
      +       "Dimensions:    (index: 179)\n",
      +       "Coordinates:\n",
      +       "  * index      (index) int64 0 1 2 3 4 5 6 7 ... 171 172 173 174 175 176 177 178\n",
      +       "Data variables:\n",
      +       "    date_week  (index) object '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
      +       "    x1         (index) float64 0.3186 0.1124 0.2924 ... 0.1719 0.2803 0.4389\n",
      +       "    x2         (index) float64 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.8633 0.0 0.0 0.0\n",
      +       "    event_1    (index) float64 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0\n",
      +       "    event_2    (index) float64 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0\n",
      +       "    dayofyear  (index) int64 92 99 106 113 120 127 ... 207 214 221 228 235 242\n",
      +       "    t          (index) int64 0 1 2 3 4 5 6 7 ... 171 172 173 174 175 176 177 178\n",
      +       "    y          (index) float64 3.985e+03 3.763e+03 ... 4.479e+03 4.676e+03

      \n", + "
    \n", + "
    \n", + "
  • \n", + " \n", + "
\n", + "
\n", + " " + ], + "text/plain": [ + "Inference data with groups:\n", + "\t> posterior\n", + "\t> posterior_predictive\n", + "\t> sample_stats\n", + "\t> prior\n", + "\t> prior_predictive\n", + "\t> observed_data\n", + "\t> constant_data\n", + "\t> fit_data" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "mmm.fit(X=X, y=y, random_seed=rng)" + ] + }, + { + "cell_type": "markdown", + "id": "c29a6461", + "metadata": {}, + "source": [ + "The `fit()` method automatically builds the model using the priors from `model_config`, and assigns the created model to our instance. You can access it as a normal attribute." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "c6b8e2af", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "pymc.model.Model" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "type(mmm.model)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "f046ee2c", + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "clusterdate (179) x channel (2)\n", + "\n", + "date (179) x channel (2)\n", + "\n", + "\n", + "clusterdate (179)\n", + "\n", + "date (179)\n", + "\n", + "\n", + "clusterchannel (2)\n", + "\n", + "channel (2)\n", + "\n", + "\n", + "clusterdate (179) x control (3)\n", + "\n", + "date (179) x control (3)\n", + "\n", + "\n", + "clustercontrol (3)\n", + "\n", + "control (3)\n", + "\n", + "\n", + "clusterdate (179) x fourier_mode (4)\n", + "\n", + "date (179) x fourier_mode (4)\n", + "\n", + "\n", + "clusterfourier_mode (4)\n", + "\n", + "fourier_mode (4)\n", + "\n", + "\n", + "\n", + "channel_adstock_saturated\n", + "\n", + "channel_adstock_saturated\n", + "~\n", + "Deterministic\n", + "\n", + "\n", + "\n", + "channel_contributions\n", + "\n", + "channel_contributions\n", + "~\n", + 
"Deterministic\n", + "\n", + "\n", + "\n", + "channel_adstock_saturated->channel_contributions\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "mu\n", + "\n", + "mu\n", + "~\n", + "Deterministic\n", + "\n", + "\n", + "\n", + "channel_contributions->mu\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "channel_data\n", + "\n", + "channel_data\n", + "~\n", + "MutableData\n", + "\n", + "\n", + "\n", + "channel_adstock\n", + "\n", + "channel_adstock\n", + "~\n", + "Deterministic\n", + "\n", + "\n", + "\n", + "channel_data->channel_adstock\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "channel_adstock->channel_adstock_saturated\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "likelihood\n", + "\n", + "likelihood\n", + "~\n", + "Normal\n", + "\n", + "\n", + "\n", + "target\n", + "\n", + "target\n", + "~\n", + "MutableData\n", + "\n", + "\n", + "\n", + "likelihood->target\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "mu->likelihood\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "sigma\n", + "\n", + "sigma\n", + "~\n", + "HalfNormal\n", + "\n", + "\n", + "\n", + "sigma->likelihood\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "intercept\n", + "\n", + "intercept\n", + "~\n", + "Normal\n", + "\n", + "\n", + "\n", + "intercept->mu\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "beta_channel\n", + "\n", + "beta_channel\n", + "~\n", + "HalfNormal\n", + "\n", + "\n", + "\n", + "beta_channel->channel_contributions\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "alpha\n", + "\n", + "alpha\n", + "~\n", + "Beta\n", + "\n", + "\n", + "\n", + "alpha->channel_adstock\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "lam\n", + "\n", + "lam\n", + "~\n", + "Gamma\n", + "\n", + "\n", + "\n", + "lam->channel_adstock_saturated\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "control_data\n", + "\n", + "control_data\n", + "~\n", + "MutableData\n", + "\n", + "\n", + "\n", + "control_contributions\n", + "\n", + "control_contributions\n", + "~\n", + "Deterministic\n", + "\n", + "\n", + "\n", + 
"control_data->control_contributions\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "control_contributions->mu\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "gamma_control\n", + "\n", + "gamma_control\n", + "~\n", + "Normal\n", + "\n", + "\n", + "\n", + "gamma_control->control_contributions\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "fourier_data\n", + "\n", + "fourier_data\n", + "~\n", + "MutableData\n", + "\n", + "\n", + "\n", + "fourier_contributions\n", + "\n", + "fourier_contributions\n", + "~\n", + "Deterministic\n", + "\n", + "\n", + "\n", + "fourier_data->fourier_contributions\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "fourier_contributions->mu\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "gamma_fourier\n", + "\n", + "gamma_fourier\n", + "~\n", + "Laplace\n", + "\n", + "\n", + "\n", + "gamma_fourier->fourier_contributions\n", + "\n", + "\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "mmm.graphviz()" + ] + }, + { + "cell_type": "markdown", + "id": "c804b600", + "metadata": {}, + "source": [ + "The posterior trace can be accessed via the `fit_result` attribute:" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "66903965", + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.Dataset>\n",
+       "Dimensions:                    (chain: 4, draw: 1000, control: 3,\n",
+       "                                fourier_mode: 4, channel: 2, date: 179)\n",
+       "Coordinates:\n",
+       "  * chain                      (chain) int64 0 1 2 3\n",
+       "  * draw                       (draw) int64 0 1 2 3 4 5 ... 995 996 997 998 999\n",
+       "  * control                    (control) <U7 'event_1' 'event_2' 't'\n",
+       "  * fourier_mode               (fourier_mode) <U11 'sin_order_1' ... 'cos_ord...\n",
+       "  * channel                    (channel) <U2 'x1' 'x2'\n",
+       "  * date                       (date) <U10 '2018-04-02' ... '2021-08-30'\n",
+       "Data variables: (12/13)\n",
+       "    intercept                  (chain, draw) float64 0.3381 0.3361 ... 0.3479\n",
+       "    gamma_control              (chain, draw, control) float64 0.2968 ... 0.00...\n",
+       "    gamma_fourier              (chain, draw, fourier_mode) float64 -0.002612 ...\n",
+       "    beta_channel               (chain, draw, channel) float64 0.3718 ... 0.2628\n",
+       "    alpha                      (chain, draw, channel) float64 0.4162 ... 0.2007\n",
+       "    lam                        (chain, draw, channel) float64 4.26 ... 2.758\n",
+       "    ...                         ...\n",
+       "    channel_adstock            (chain, draw, date, channel) float64 0.1868 .....\n",
+       "    channel_adstock_saturated  (chain, draw, date, channel) float64 0.3781 .....\n",
+       "    channel_contributions      (chain, draw, date, channel) float64 0.1406 .....\n",
+       "    control_contributions      (chain, draw, date, control) float64 0.0 ... 0...\n",
+       "    fourier_contributions      (chain, draw, date, fourier_mode) float64 -0.0...\n",
+       "    mu                         (chain, draw, date) float64 0.4769 ... 0.5924\n",
+       "Attributes:\n",
+       "    created_at:                 2023-08-17T13:57:12.916984\n",
+       "    arviz_version:              0.16.1\n",
+       "    inference_library:          pymc\n",
+       "    inference_library_version:  5.6.1\n",
+       "    sampling_time:              27.989259004592896\n",
+       "    tuning_steps:               1000
" + ], + "text/plain": [ + "\n", + "Dimensions: (chain: 4, draw: 1000, control: 3,\n", + " fourier_mode: 4, channel: 2, date: 179)\n", + "Coordinates:\n", + " * chain (chain) int64 0 1 2 3\n", + " * draw (draw) int64 0 1 2 3 4 5 ... 995 996 997 998 999\n", + " * control (control) \n", + "
\n", + "
arviz.InferenceData
\n", + "
\n", + "
    \n", + " \n", + "
  • \n", + " \n", + " \n", + "
    \n", + "
    \n", + "
      \n", + "
      \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
      <xarray.Dataset>\n",
      +       "Dimensions:                    (chain: 4, draw: 1000, control: 3,\n",
      +       "                                fourier_mode: 4, channel: 2, date: 179)\n",
      +       "Coordinates:\n",
      +       "  * chain                      (chain) int64 0 1 2 3\n",
      +       "  * draw                       (draw) int64 0 1 2 3 4 5 ... 995 996 997 998 999\n",
      +       "  * control                    (control) <U7 'event_1' 'event_2' 't'\n",
      +       "  * fourier_mode               (fourier_mode) <U11 'sin_order_1' ... 'cos_ord...\n",
      +       "  * channel                    (channel) <U2 'x1' 'x2'\n",
      +       "  * date                       (date) <U10 '2018-04-02' ... '2021-08-30'\n",
      +       "Data variables: (12/13)\n",
      +       "    intercept                  (chain, draw) float64 0.3381 0.3361 ... 0.3479\n",
      +       "    gamma_control              (chain, draw, control) float64 0.2968 ... 0.00...\n",
      +       "    gamma_fourier              (chain, draw, fourier_mode) float64 -0.002612 ...\n",
      +       "    beta_channel               (chain, draw, channel) float64 0.3718 ... 0.2628\n",
      +       "    alpha                      (chain, draw, channel) float64 0.4162 ... 0.2007\n",
      +       "    lam                        (chain, draw, channel) float64 4.26 ... 2.758\n",
      +       "    ...                         ...\n",
      +       "    channel_adstock            (chain, draw, date, channel) float64 0.1868 .....\n",
      +       "    channel_adstock_saturated  (chain, draw, date, channel) float64 0.3781 .....\n",
      +       "    channel_contributions      (chain, draw, date, channel) float64 0.1406 .....\n",
      +       "    control_contributions      (chain, draw, date, control) float64 0.0 ... 0...\n",
      +       "    fourier_contributions      (chain, draw, date, fourier_mode) float64 -0.0...\n",
      +       "    mu                         (chain, draw, date) float64 0.4769 ... 0.5924\n",
      +       "Attributes:\n",
      +       "    created_at:                 2023-08-17T13:57:12.916984\n",
      +       "    arviz_version:              0.16.1\n",
      +       "    inference_library:          pymc\n",
      +       "    inference_library_version:  5.6.1\n",
      +       "    sampling_time:              27.989259004592896\n",
      +       "    tuning_steps:               1000

      \n", + "
    \n", + "
    \n", + "
  • \n", + " \n", + "
  • \n", + " \n", + " \n", + "
    \n", + "
    \n", + "
      \n", + "
      \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
      <xarray.Dataset>\n",
      +       "Dimensions:     (chain: 4, draw: 1000, date: 179)\n",
      +       "Coordinates:\n",
      +       "  * chain       (chain) int64 0 1 2 3\n",
      +       "  * draw        (draw) int64 0 1 2 3 4 5 6 7 ... 992 993 994 995 996 997 998 999\n",
      +       "  * date        (date) <U10 '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
      +       "Data variables:\n",
      +       "    likelihood  (chain, draw, date) float64 0.5148 0.4369 ... 0.4858 0.6176\n",
      +       "Attributes:\n",
      +       "    created_at:                 2023-08-17T13:57:14.741457\n",
      +       "    arviz_version:              0.16.1\n",
      +       "    inference_library:          pymc\n",
      +       "    inference_library_version:  5.6.1

      \n", + "
    \n", + "
    \n", + "
  • \n", + " \n", + "
  • \n", + " \n", + " \n", + "
    \n", + "
    \n", + "
      \n", + "
      \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
      <xarray.Dataset>\n",
      +       "Dimensions:                (chain: 4, draw: 1000)\n",
      +       "Coordinates:\n",
      +       "  * chain                  (chain) int64 0 1 2 3\n",
      +       "  * draw                   (draw) int64 0 1 2 3 4 5 ... 994 995 996 997 998 999\n",
      +       "Data variables: (12/17)\n",
      +       "    step_size              (chain, draw) float64 0.05862 0.05862 ... 0.07543\n",
      +       "    lp                     (chain, draw) float64 355.1 355.2 ... 355.6 351.6\n",
      +       "    n_steps                (chain, draw) float64 63.0 63.0 63.0 ... 63.0 63.0\n",
      +       "    largest_eigval         (chain, draw) float64 nan nan nan nan ... nan nan nan\n",
      +       "    energy                 (chain, draw) float64 -344.2 -349.9 ... -350.2 -347.5\n",
      +       "    acceptance_rate        (chain, draw) float64 0.9987 0.9934 ... 0.8931 0.9167\n",
      +       "    ...                     ...\n",
      +       "    diverging              (chain, draw) bool False False False ... False False\n",
      +       "    smallest_eigval        (chain, draw) float64 nan nan nan nan ... nan nan nan\n",
      +       "    max_energy_error       (chain, draw) float64 -0.06059 -0.06074 ... 0.284\n",
      +       "    index_in_trajectory    (chain, draw) int64 -12 20 43 -17 ... 42 20 -43 -29\n",
      +       "    energy_error           (chain, draw) float64 -0.01181 -0.02828 ... -0.01478\n",
      +       "    tree_depth             (chain, draw) int64 6 6 6 6 7 6 6 7 ... 6 5 5 6 7 6 6\n",
      +       "Attributes:\n",
      +       "    created_at:                 2023-08-17T13:57:12.932375\n",
      +       "    arviz_version:              0.16.1\n",
      +       "    inference_library:          pymc\n",
      +       "    inference_library_version:  5.6.1\n",
      +       "    sampling_time:              27.989259004592896\n",
      +       "    tuning_steps:               1000

      \n", + "
    \n", + "
    \n", + "
  • \n", + " \n", + "
  • \n", + " \n", + " \n", + "
    \n", + "
    \n", + "
      \n", + "
      \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
      <xarray.Dataset>\n",
      +       "Dimensions:                    (chain: 1, draw: 500, channel: 2, control: 3,\n",
      +       "                                date: 179, fourier_mode: 4)\n",
      +       "Coordinates:\n",
      +       "  * chain                      (chain) int64 0\n",
      +       "  * draw                       (draw) int64 0 1 2 3 4 5 ... 495 496 497 498 499\n",
      +       "  * channel                    (channel) <U2 'x1' 'x2'\n",
      +       "  * control                    (control) <U7 'event_1' 'event_2' 't'\n",
      +       "  * date                       (date) <U10 '2018-04-02' ... '2021-08-30'\n",
      +       "  * fourier_mode               (fourier_mode) <U11 'sin_order_1' ... 'cos_ord...\n",
      +       "Data variables: (12/13)\n",
      +       "    beta_channel               (chain, draw, channel) float64 0.6143 ... 0.9639\n",
      +       "    sigma                      (chain, draw) float64 2.034 5.255 ... 0.07125\n",
      +       "    gamma_control              (chain, draw, control) float64 -0.9767 ... 0.0...\n",
      +       "    control_contributions      (chain, draw, date, control) float64 -0.0 ... ...\n",
      +       "    mu                         (chain, draw, date) float64 5.528 7.076 ... 8.396\n",
      +       "    channel_adstock_saturated  (chain, draw, date, channel) float64 0.3626 .....\n",
      +       "    ...                         ...\n",
      +       "    fourier_contributions      (chain, draw, date, fourier_mode) float64 1.64...\n",
      +       "    lam                        (chain, draw, channel) float64 2.667 ... 1.438\n",
      +       "    intercept                  (chain, draw) float64 2.51 -3.898 ... -2.654\n",
      +       "    gamma_fourier              (chain, draw, fourier_mode) float64 1.647 ... ...\n",
      +       "    channel_adstock            (chain, draw, date, channel) float64 0.2848 .....\n",
      +       "    alpha                      (chain, draw, channel) float64 0.1089 ... 0.2786\n",
      +       "Attributes:\n",
      +       "    created_at:                 2023-08-17T13:57:14.422347\n",
      +       "    arviz_version:              0.16.1\n",
      +       "    inference_library:          pymc\n",
      +       "    inference_library_version:  5.6.1

      \n", + "
    \n", + "
    \n", + "
  • \n", + " \n", + "
  • \n", + " \n", + " \n", + "
    \n", + "
    \n", + "
      \n", + "
      \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
      <xarray.Dataset>\n",
      +       "Dimensions:     (chain: 1, draw: 500, date: 179)\n",
      +       "Coordinates:\n",
      +       "  * chain       (chain) int64 0\n",
      +       "  * draw        (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499\n",
      +       "  * date        (date) <U10 '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
      +       "Data variables:\n",
      +       "    likelihood  (chain, draw, date) float64 5.221 8.352 5.313 ... 7.793 8.528\n",
      +       "Attributes:\n",
      +       "    created_at:                 2023-08-17T13:57:14.428434\n",
      +       "    arviz_version:              0.16.1\n",
      +       "    inference_library:          pymc\n",
      +       "    inference_library_version:  5.6.1

      \n", + "
    \n", + "
    \n", + "
  • \n", + " \n", + "
  • \n", + " \n", + " \n", + "
    \n", + "
    \n", + "
      \n", + "
      \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
      <xarray.Dataset>\n",
      +       "Dimensions:     (date: 179)\n",
      +       "Coordinates:\n",
      +       "  * date        (date) <U10 '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
      +       "Data variables:\n",
      +       "    likelihood  (date) float64 0.4794 0.4527 0.5374 ... 0.4978 0.5388 0.5625\n",
      +       "Attributes:\n",
      +       "    created_at:                 2023-08-17T13:57:12.937052\n",
      +       "    arviz_version:              0.16.1\n",
      +       "    inference_library:          pymc\n",
      +       "    inference_library_version:  5.6.1

      \n", + "
    \n", + "
    \n", + "
  • \n", + " \n", + "
  • \n", + " \n", + " \n", + "
    \n", + "
    \n", + "
      \n", + "
      \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
      <xarray.Dataset>\n",
      +       "Dimensions:       (date: 179, channel: 2, control: 3, fourier_mode: 4)\n",
      +       "Coordinates:\n",
      +       "  * date          (date) <U10 '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
      +       "  * channel       (channel) <U2 'x1' 'x2'\n",
      +       "  * control       (control) <U7 'event_1' 'event_2' 't'\n",
      +       "  * fourier_mode  (fourier_mode) <U11 'sin_order_1' ... 'cos_order_2'\n",
      +       "Data variables:\n",
      +       "    channel_data  (date, channel) float64 0.3196 0.0 0.1128 ... 0.0 0.4403 0.0\n",
      +       "    target        (date) float64 0.4794 0.4527 0.5374 ... 0.4978 0.5388 0.5625\n",
      +       "    control_data  (date, control) float64 0.0 0.0 0.0 0.0 ... 0.0 0.0 178.0\n",
      +       "    fourier_data  (date, fourier_mode) float64 0.9999 -0.01183 ... -0.4547\n",
      +       "Attributes:\n",
      +       "    created_at:                 2023-08-17T13:57:12.938885\n",
      +       "    arviz_version:              0.16.1\n",
      +       "    inference_library:          pymc\n",
      +       "    inference_library_version:  5.6.1

      \n", + "
    \n", + "
    \n", + "
  • \n", + " \n", + "
  • \n", + " \n", + " \n", + "
    \n", + "
    \n", + "
      \n", + "
      \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
      <xarray.Dataset>\n",
      +       "Dimensions:    (index: 179)\n",
      +       "Coordinates:\n",
      +       "  * index      (index) int64 0 1 2 3 4 5 6 7 ... 171 172 173 174 175 176 177 178\n",
      +       "Data variables:\n",
      +       "    date_week  (index) object '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
      +       "    x1         (index) float64 0.3186 0.1124 0.2924 ... 0.1719 0.2803 0.4389\n",
      +       "    x2         (index) float64 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.8633 0.0 0.0 0.0\n",
      +       "    event_1    (index) float64 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0\n",
      +       "    event_2    (index) float64 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0\n",
      +       "    dayofyear  (index) int64 92 99 106 113 120 127 ... 207 214 221 228 235 242\n",
      +       "    t          (index) int64 0 1 2 3 4 5 6 7 ... 171 172 173 174 175 176 177 178\n",
      +       "    y          (index) float64 3.985e+03 3.763e+03 ... 4.479e+03 4.676e+03

      \n", + "
    \n", + "
    \n", + "
  • \n", + " \n", + "
\n", + " \n", + " " + ], + "text/plain": [ + "Inference data with groups:\n", + "\t> posterior\n", + "\t> posterior_predictive\n", + "\t> sample_stats\n", + "\t> prior\n", + "\t> prior_predictive\n", + "\t> observed_data\n", + "\t> constant_data\n", + "\t> fit_data" + ] + }, + "execution_count": 17, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "mmm.idata" + ] + }, + { + "cell_type": "markdown", + "id": "8b433c7f-0f0d-40b2-bcfb-a19555b528bd", + "metadata": {}, + "source": [ + "## `Save` and `load`" + ] + }, + { + "cell_type": "markdown", + "id": "7b0a35f4", + "metadata": {}, + "source": [ + "All the data passed to the model on initialisation is stored in `idata.attrs`. This will be used later in the `save()` method to convert both this data and all the fit data into the netCDF format." + ] + }, + { + "cell_type": "markdown", + "id": "45948f46", + "metadata": {}, + "source": [ + "Simply specify the path to which you'd like to save your model:" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "id": "b3abe93a", + "metadata": {}, + "outputs": [], + "source": [ + "mmm.save(\"my_saved_model.nc\")" + ] + }, + { + "cell_type": "markdown", + "id": "8a5eba79", + "metadata": {}, + "source": [ + "And pass it to the `load()` method when it's needed again on the target system:" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "0421bae8", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/michalraczycki/Documents/pymc-marketing/.conda/envs/pymc-marketing/lib/python3.10/site-packages/arviz/data/inference_data.py:153: UserWarning: fit_data group is not defined in the InferenceData scheme\n", + " warnings.warn(\n" + ] + } + ], + "source": [ + "loaded_model = DelayedSaturatedMMM.load(\"my_saved_model.nc\")" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "id": "a8b666d3", + "metadata": {}, + "outputs": [ + { + "data": { + "image/svg+xml": [ + "\n", + 
"\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "clusterdate (179) x channel (2)\n", + "\n", + "date (179) x channel (2)\n", + "\n", + "\n", + "clusterdate (179)\n", + "\n", + "date (179)\n", + "\n", + "\n", + "clusterchannel (2)\n", + "\n", + "channel (2)\n", + "\n", + "\n", + "clusterdate (179) x control (3)\n", + "\n", + "date (179) x control (3)\n", + "\n", + "\n", + "clustercontrol (3)\n", + "\n", + "control (3)\n", + "\n", + "\n", + "clusterdate (179) x fourier_mode (4)\n", + "\n", + "date (179) x fourier_mode (4)\n", + "\n", + "\n", + "clusterfourier_mode (4)\n", + "\n", + "fourier_mode (4)\n", + "\n", + "\n", + "\n", + "channel_adstock_saturated\n", + "\n", + "channel_adstock_saturated\n", + "~\n", + "Deterministic\n", + "\n", + "\n", + "\n", + "channel_contributions\n", + "\n", + "channel_contributions\n", + "~\n", + "Deterministic\n", + "\n", + "\n", + "\n", + "channel_adstock_saturated->channel_contributions\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "mu\n", + "\n", + "mu\n", + "~\n", + "Deterministic\n", + "\n", + "\n", + "\n", + "channel_contributions->mu\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "channel_data\n", + "\n", + "channel_data\n", + "~\n", + "MutableData\n", + "\n", + "\n", + "\n", + "channel_adstock\n", + "\n", + "channel_adstock\n", + "~\n", + "Deterministic\n", + "\n", + "\n", + "\n", + "channel_data->channel_adstock\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "channel_adstock->channel_adstock_saturated\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "likelihood\n", + "\n", + "likelihood\n", + "~\n", + "Normal\n", + "\n", + "\n", + "\n", + "target\n", + "\n", + "target\n", + "~\n", + "MutableData\n", + "\n", + "\n", + "\n", + "likelihood->target\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "mu->likelihood\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "sigma\n", + "\n", + "sigma\n", + "~\n", + "HalfNormal\n", + "\n", + "\n", + "\n", + "sigma->likelihood\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "intercept\n", + "\n", + 
"intercept\n", + "~\n", + "Normal\n", + "\n", + "\n", + "\n", + "intercept->mu\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "beta_channel\n", + "\n", + "beta_channel\n", + "~\n", + "HalfNormal\n", + "\n", + "\n", + "\n", + "beta_channel->channel_contributions\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "alpha\n", + "\n", + "alpha\n", + "~\n", + "Beta\n", + "\n", + "\n", + "\n", + "alpha->channel_adstock\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "lam\n", + "\n", + "lam\n", + "~\n", + "Gamma\n", + "\n", + "\n", + "\n", + "lam->channel_adstock_saturated\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "control_data\n", + "\n", + "control_data\n", + "~\n", + "MutableData\n", + "\n", + "\n", + "\n", + "control_contributions\n", + "\n", + "control_contributions\n", + "~\n", + "Deterministic\n", + "\n", + "\n", + "\n", + "control_data->control_contributions\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "control_contributions->mu\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "gamma_control\n", + "\n", + "gamma_control\n", + "~\n", + "Normal\n", + "\n", + "\n", + "\n", + "gamma_control->control_contributions\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "fourier_data\n", + "\n", + "fourier_data\n", + "~\n", + "MutableData\n", + "\n", + "\n", + "\n", + "fourier_contributions\n", + "\n", + "fourier_contributions\n", + "~\n", + "Deterministic\n", + "\n", + "\n", + "\n", + "fourier_data->fourier_contributions\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "fourier_contributions->mu\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "gamma_fourier\n", + "\n", + "gamma_fourier\n", + "~\n", + "Laplace\n", + "\n", + "\n", + "\n", + "gamma_fourier->fourier_contributions\n", + "\n", + "\n", + "\n", + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "execution_count": 20, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "loaded_model.graphviz()" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "id": "cfb64a2c", + "metadata": {}, + "outputs": 
[ + { + "data": { + "text/html": [ + "\n", + "
\n", + "
\n", + "
arviz.InferenceData
\n", + "
\n", + "
    \n", + " \n", + "
  • \n", + " \n", + " \n", + "
    \n", + "
    \n", + "
      \n", + "
      \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
      <xarray.Dataset>\n",
      +       "Dimensions:                    (chain: 4, draw: 1000, control: 3,\n",
      +       "                                fourier_mode: 4, channel: 2, date: 179)\n",
      +       "Coordinates:\n",
      +       "  * chain                      (chain) int64 0 1 2 3\n",
      +       "  * draw                       (draw) int64 0 1 2 3 4 5 ... 995 996 997 998 999\n",
      +       "  * control                    (control) object 'event_1' 'event_2' 't'\n",
      +       "  * fourier_mode               (fourier_mode) object 'sin_order_1' ... 'cos_o...\n",
      +       "  * channel                    (channel) object 'x1' 'x2'\n",
      +       "  * date                       (date) object '2018-04-02' ... '2021-08-30'\n",
      +       "Data variables: (12/13)\n",
      +       "    intercept                  (chain, draw) float64 ...\n",
      +       "    gamma_control              (chain, draw, control) float64 ...\n",
      +       "    gamma_fourier              (chain, draw, fourier_mode) float64 ...\n",
      +       "    beta_channel               (chain, draw, channel) float64 ...\n",
      +       "    alpha                      (chain, draw, channel) float64 ...\n",
      +       "    lam                        (chain, draw, channel) float64 ...\n",
      +       "    ...                         ...\n",
      +       "    channel_adstock            (chain, draw, date, channel) float64 ...\n",
      +       "    channel_adstock_saturated  (chain, draw, date, channel) float64 ...\n",
      +       "    channel_contributions      (chain, draw, date, channel) float64 ...\n",
      +       "    control_contributions      (chain, draw, date, control) float64 ...\n",
      +       "    fourier_contributions      (chain, draw, date, fourier_mode) float64 ...\n",
      +       "    mu                         (chain, draw, date) float64 ...\n",
      +       "Attributes:\n",
      +       "    created_at:                 2023-08-17T13:57:12.916984\n",
      +       "    arviz_version:              0.16.1\n",
      +       "    inference_library:          pymc\n",
      +       "    inference_library_version:  5.6.1\n",
      +       "    sampling_time:              27.989259004592896\n",
      +       "    tuning_steps:               1000

      \n", + "
    \n", + "
    \n", + "
  • \n", + " \n", + "
  • \n", + " \n", + " \n", + "
    \n", + "
    \n", + "
      \n", + "
      \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
      <xarray.Dataset>\n",
      +       "Dimensions:     (chain: 4, draw: 1000, date: 179)\n",
      +       "Coordinates:\n",
      +       "  * chain       (chain) int64 0 1 2 3\n",
      +       "  * draw        (draw) int64 0 1 2 3 4 5 6 7 ... 992 993 994 995 996 997 998 999\n",
      +       "  * date        (date) object '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
      +       "Data variables:\n",
      +       "    likelihood  (chain, draw, date) float64 ...\n",
      +       "Attributes:\n",
      +       "    created_at:                 2023-08-17T13:57:14.741457\n",
      +       "    arviz_version:              0.16.1\n",
      +       "    inference_library:          pymc\n",
      +       "    inference_library_version:  5.6.1

      \n", + "
    \n", + "
    \n", + "
  • \n", + " \n", + "
  • \n", + " \n", + " \n", + "
    \n", + "
    \n", + "
      \n", + "
      \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
      <xarray.Dataset>\n",
      +       "Dimensions:                (chain: 4, draw: 1000)\n",
      +       "Coordinates:\n",
      +       "  * chain                  (chain) int64 0 1 2 3\n",
      +       "  * draw                   (draw) int64 0 1 2 3 4 5 ... 994 995 996 997 998 999\n",
      +       "Data variables: (12/17)\n",
      +       "    step_size              (chain, draw) float64 ...\n",
      +       "    lp                     (chain, draw) float64 ...\n",
      +       "    n_steps                (chain, draw) float64 ...\n",
      +       "    largest_eigval         (chain, draw) float64 ...\n",
      +       "    energy                 (chain, draw) float64 ...\n",
      +       "    acceptance_rate        (chain, draw) float64 ...\n",
      +       "    ...                     ...\n",
      +       "    diverging              (chain, draw) bool ...\n",
      +       "    smallest_eigval        (chain, draw) float64 ...\n",
      +       "    max_energy_error       (chain, draw) float64 ...\n",
      +       "    index_in_trajectory    (chain, draw) int64 ...\n",
      +       "    energy_error           (chain, draw) float64 ...\n",
      +       "    tree_depth             (chain, draw) int64 ...\n",
      +       "Attributes:\n",
      +       "    created_at:                 2023-08-17T13:57:12.932375\n",
      +       "    arviz_version:              0.16.1\n",
      +       "    inference_library:          pymc\n",
      +       "    inference_library_version:  5.6.1\n",
      +       "    sampling_time:              27.989259004592896\n",
      +       "    tuning_steps:               1000

      \n", + "
    \n", + "
    \n", + "
  • \n", + " \n", + "
  • \n", + " \n", + " \n", + "
    \n", + "
    \n", + "
      \n", + "
      \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
      <xarray.Dataset>\n",
      +       "Dimensions:                    (chain: 1, draw: 500, channel: 2, control: 3,\n",
      +       "                                date: 179, fourier_mode: 4)\n",
      +       "Coordinates:\n",
      +       "  * chain                      (chain) int64 0\n",
      +       "  * draw                       (draw) int64 0 1 2 3 4 5 ... 495 496 497 498 499\n",
      +       "  * channel                    (channel) object 'x1' 'x2'\n",
      +       "  * control                    (control) object 'event_1' 'event_2' 't'\n",
      +       "  * date                       (date) object '2018-04-02' ... '2021-08-30'\n",
      +       "  * fourier_mode               (fourier_mode) object 'sin_order_1' ... 'cos_o...\n",
      +       "Data variables: (12/13)\n",
      +       "    beta_channel               (chain, draw, channel) float64 ...\n",
      +       "    sigma                      (chain, draw) float64 ...\n",
      +       "    gamma_control              (chain, draw, control) float64 ...\n",
      +       "    control_contributions      (chain, draw, date, control) float64 ...\n",
      +       "    mu                         (chain, draw, date) float64 ...\n",
      +       "    channel_adstock_saturated  (chain, draw, date, channel) float64 ...\n",
      +       "    ...                         ...\n",
      +       "    fourier_contributions      (chain, draw, date, fourier_mode) float64 ...\n",
      +       "    lam                        (chain, draw, channel) float64 ...\n",
      +       "    intercept                  (chain, draw) float64 ...\n",
      +       "    gamma_fourier              (chain, draw, fourier_mode) float64 ...\n",
      +       "    channel_adstock            (chain, draw, date, channel) float64 ...\n",
      +       "    alpha                      (chain, draw, channel) float64 ...\n",
      +       "Attributes:\n",
      +       "    created_at:                 2023-08-17T13:57:14.422347\n",
      +       "    arviz_version:              0.16.1\n",
      +       "    inference_library:          pymc\n",
      +       "    inference_library_version:  5.6.1

      \n", + "
    \n", + "
    \n", + "
  • \n", + " \n", + "
  • \n", + " \n", + " \n", + "
    \n", + "
    \n", + "
      \n", + "
      \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
      <xarray.Dataset>\n",
      +       "Dimensions:     (chain: 1, draw: 500, date: 179)\n",
      +       "Coordinates:\n",
      +       "  * chain       (chain) int64 0\n",
      +       "  * draw        (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499\n",
      +       "  * date        (date) object '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
      +       "Data variables:\n",
      +       "    likelihood  (chain, draw, date) float64 ...\n",
      +       "Attributes:\n",
      +       "    created_at:                 2023-08-17T13:57:14.428434\n",
      +       "    arviz_version:              0.16.1\n",
      +       "    inference_library:          pymc\n",
      +       "    inference_library_version:  5.6.1

      \n", + "
    \n", + "
    \n", + "
  • \n", + " \n", + "
  • \n", + " \n", + " \n", + "
    \n", + "
    \n", + "
      \n", + "
      \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
      <xarray.Dataset>\n",
      +       "Dimensions:     (date: 179)\n",
      +       "Coordinates:\n",
      +       "  * date        (date) object '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
      +       "Data variables:\n",
      +       "    likelihood  (date) float64 ...\n",
      +       "Attributes:\n",
      +       "    created_at:                 2023-08-17T13:57:12.937052\n",
      +       "    arviz_version:              0.16.1\n",
      +       "    inference_library:          pymc\n",
      +       "    inference_library_version:  5.6.1

      \n", + "
    \n", + "
    \n", + "
  • \n", + " \n", + "
  • \n", + " \n", + " \n", + "
    \n", + "
    \n", + "
      \n", + "
      \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
      <xarray.Dataset>\n",
      +       "Dimensions:       (date: 179, channel: 2, control: 3, fourier_mode: 4)\n",
      +       "Coordinates:\n",
      +       "  * date          (date) object '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
      +       "  * channel       (channel) object 'x1' 'x2'\n",
      +       "  * control       (control) object 'event_1' 'event_2' 't'\n",
      +       "  * fourier_mode  (fourier_mode) object 'sin_order_1' ... 'cos_order_2'\n",
      +       "Data variables:\n",
      +       "    channel_data  (date, channel) float64 ...\n",
      +       "    target        (date) float64 ...\n",
      +       "    control_data  (date, control) float64 ...\n",
      +       "    fourier_data  (date, fourier_mode) float64 ...\n",
      +       "Attributes:\n",
      +       "    created_at:                 2023-08-17T13:57:12.938885\n",
      +       "    arviz_version:              0.16.1\n",
      +       "    inference_library:          pymc\n",
      +       "    inference_library_version:  5.6.1

      \n", + "
    \n", + "
    \n", + "
  • \n", + " \n", + "
  • \n", + " \n", + " \n", + "
    \n", + "
    \n", + "
      \n", + "
      \n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
      <xarray.Dataset>\n",
      +       "Dimensions:    (index: 179)\n",
      +       "Coordinates:\n",
      +       "  * index      (index) int64 0 1 2 3 4 5 6 7 ... 171 172 173 174 175 176 177 178\n",
      +       "Data variables:\n",
      +       "    date_week  (index) object '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
      +       "    x1         (index) float64 0.3186 0.1124 0.2924 ... 0.1719 0.2803 0.4389\n",
      +       "    x2         (index) float64 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.8633 0.0 0.0 0.0\n",
      +       "    event_1    (index) float64 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0\n",
      +       "    event_2    (index) float64 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0\n",
      +       "    dayofyear  (index) int64 92 99 106 113 120 127 ... 207 214 221 228 235 242\n",
      +       "    t          (index) int64 0 1 2 3 4 5 6 7 ... 171 172 173 174 175 176 177 178\n",
      +       "    y          (index) float64 3.985e+03 3.763e+03 ... 4.479e+03 4.676e+03

      \n", + "
    \n", + "
    \n", + "
  • \n", + " \n", + "
\n", + "
\n", + " " + ], + "text/plain": [ + "Inference data with groups:\n", + "\t> posterior\n", + "\t> posterior_predictive\n", + "\t> sample_stats\n", + "\t> prior\n", + "\t> prior_predictive\n", + "\t> observed_data\n", + "\t> constant_data\n", + "\t> fit_data" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "loaded_model.idata" + ] + }, + { + "cell_type": "markdown", + "id": "ab64be46-7fe5-4f39-b72b-36da1419f809", + "metadata": {}, + "source": [ + "A model loaded in this way is ready to be used for sampling and prediction, and has access to all previous samples and data." + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "dd59d056-6ac7-431c-85ee-99e7a8eefd8a", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Sampling: [likelihood]\n" + ] + }, + { + "data": { + "text/html": [ + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + "
\n", + " \n", + " 100.00% [4000/4000 00:00<00:00]\n", + "
\n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "
<xarray.Dataset>\n",
+       "Dimensions:     (chain: 4, draw: 1000, date: 179)\n",
+       "Coordinates:\n",
+       "  * chain       (chain) int64 0 1 2 3\n",
+       "  * draw        (draw) int64 0 1 2 3 4 5 6 7 ... 992 993 994 995 996 997 998 999\n",
+       "  * date        (date) <U10 '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
+       "Data variables:\n",
+       "    likelihood  (chain, draw, date) float64 0.4907 0.4282 ... 0.5548 0.5396\n",
+       "Attributes:\n",
+       "    created_at:                 2023-08-17T13:57:22.302381\n",
+       "    arviz_version:              0.16.1\n",
+       "    inference_library:          pymc\n",
+       "    inference_library_version:  5.6.1
" + ], + "text/plain": [ + "\n", + "Dimensions: (chain: 4, draw: 1000, date: 179)\n", + "Coordinates:\n", + " * chain (chain) int64 0 1 2 3\n", + " * draw (draw) int64 0 1 2 3 4 5 6 7 ... 992 993 994 995 996 997 998 999\n", + " * date (date) " + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "az.plot_ppc(loaded_model.idata);" + ] + }, + { + "cell_type": "markdown", + "id": "e8e807f9", + "metadata": {}, + "source": [ + "## Summary:" + ] + }, + { + "cell_type": "markdown", + "id": "61f232c1", + "metadata": {}, + "source": [ + "In summary, this article introduces the revolutionary ModelBuilder, a new [PyMC-experimental](https://github.com/pymc-devs/pymc-experimental) module that simplifies the deployment of PyMC Bayesian models. It addresses a historic challenge faced by users of PyMC and most PPLs by offering a user-friendly and efficient approach to model deployment. The ModelBuilder provides two straightforward methods, save() and load(), which streamline the model preservation and replication process post fitting. Users are offered flexibility in controlling the prior settings with `model_config` and customizing the sampling process via `sampler_config`.\n", + "\n", + "The use of an example model from the [MMM Example Notebook](https://www.pymc-marketing.io/en/stable/notebooks/index.html) demonstrates the practical implementation of `ModelBuilder`, emphasizing its ability to enhance model sharing among teams without the necessity for extensive domain knowledge about the model. The deployment improvements in [PyMC-Marketing](https://github.com/pymc-labs/pymc-marketing) brought about by ModelBuilder are not only user-friendly but also significantly enhance efficiency, making PyMC models more accessible for a wider audience." 
+ ] + }, + { + "cell_type": "markdown", + "id": "b8ad333d", + "metadata": {}, + "source": [ + "Even though this introduction is using `DelayedSaturatedMMM`, functionalities from `ModelBuilder` are available in the CLV models as well." + ] + }, + { + "cell_type": "markdown", + "id": "de822885-2fb4-4ad1-aaf9-7659771f7363", + "metadata": {}, + "source": [ + "## Authors\n", + "- Authored by [Michał Raczycki](https://github.com/michaelraczycki) in August 2023" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "07ec32d2-2f38-47ec-a200-b9d7258c3ac5", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Last updated: Thu Aug 17 2023\n", + "\n", + "Python implementation: CPython\n", + "Python version : 3.10.12\n", + "IPython version : 8.14.0\n", + "\n", + "pytensor: 2.12.3\n", + "aeppl : not installed\n", + "xarray : 2023.7.0\n", + "\n", + "numpy : 1.25.1\n", + "pandas: 2.0.3\n", + "arviz : 0.16.1\n", + "\n", + "Watermark: 2.4.3\n", + "\n" + ] + } + ], + "source": [ + "%load_ext watermark\n", + "%watermark -n -u -v -iv -w -p pytensor,aeppl,xarray" + ] + }, + { + "cell_type": "markdown", + "id": "7e453d63-39f3-456c-a099-2c780d3ba4a9", + "metadata": {}, + "source": [ + ":::\n", + "{include} ../page_footer.md\n", + ":::" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "pymc-marketing", + "language": "python", + "name": "pymc-marketing" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/howto/ModelBuilder_usage_example.myst.md b/examples/howto/ModelBuilder_usage_example.myst.md new file mode 100644 index 000000000..3ea5bbe16 --- /dev/null +++ b/examples/howto/ModelBuilder_usage_example.myst.md @@ -0,0 +1,279 @@ +--- 
+jupytext: + text_representation: + extension: .md + format_name: myst + format_version: 0.13 +kernelspec: + display_name: pymc-marketing + language: python + name: pymc-marketing +--- + +(ModelBuilder usage example)= +# ModelBuilder usage example + +:::{post} Aug 18, 2023 +:tags: ModelBuilder, model deployment, +:category: intermediate, tutorial +:author: Michał Raczycki +::: + ++++ + +# Deploying MMMs and CLVs in Production: Saving and Loading Models + ++++ + +In this article, we'll tackle the historically challenging process of deploying Bayesian models built with PyMC. Introducing a revolutionary deployment module, we bring unprecedented simplicity and efficiency to the deployment of PyMC models. As we prioritize user-friendly solutions, let's delve into how this innovation can significantly elevate your data science projects. + ++++ + + +The recent release of PyMC-Marketing by [PyMC Labs](https://www.pymc-labs.io) proved to be a big hit [(PyMC-Marketing)](https://www.pymc-labs.io/blog-posts/pymc-marketing-a-bayesian-approach-to-marketing-data-science/). A recurring theme in the feedback was that many of you have been requesting an easy and robust way of deploying models to production. This has been a long-standing problem with PyMC (and most other probabilistic programming languages): there is no obvious way to do it, and whichever approach you try proves to be tricky. That is why we’re happy to announce the release of `ModelBuilder`, a brand-new PyMC-experimental module that addresses this need and significantly improves the deployment process. + +The `ModelBuilder` module is a new feature of PyMC-based models. It provides two easy-to-use methods, `save()` and `load()`, that can be used after the model has been fit. `save()` preserves the model in the netCDF format, and `load()` replicates the original model in one line.
Users can control the prior settings with `model_config` and customize the sampling process with `sampler_config`. The default values work fine, so try the model first without changing them, and provide your own `model_config` and `sampler_config` afterwards if you want to customize it further for your use case! + ++++ + +For this notebook I'll use the example model from the [MMM Example Notebook](https://www.pymc-marketing.io/en/stable/notebooks/mmm/mmm_example.html), but omit the details of data generation and plotting, since they're out of scope for this introduction. I highly recommend reading that part as well, but for now let's focus on today's topic: groundbreaking deployment improvements in PyMC-Marketing! + +```{code-cell} ipython3 +import arviz as az +import numpy as np +import pandas as pd + +from pymc_marketing.mmm import DelayedSaturatedMMM +``` + +```{code-cell} ipython3 +az.style.use("arviz-darkgrid") +``` + +Let's load the dataset: + +```{code-cell} ipython3 +url = "https://raw.githubusercontent.com/pymc-labs/pymc-marketing/main/datasets/mmm_example.csv" +df = pd.read_csv(url) +df +``` + +For our model we need a much smaller dataset. Many of the original features were only used to generate others; now that our target variable is computed, we can filter out the columns we don't need: + +```{code-cell} ipython3 +columns_to_keep = [ + "date_week", + "y", + "x1", + "x2", + "event_1", + "event_2", + "dayofyear", +] +seed: int = sum(map(ord, "mmm")) +rng = np.random.default_rng(seed=seed) + +data = df[columns_to_keep].copy() + +data["t"] = range(df.shape[0]) +data.head() +``` + +## _Model Creation_ +With our dataset ready, we could proceed straight to the model definition. But first, to show the full potential of one of the new features, `model_config`, we will use some of our data to define the prior for the `sigma` parameter of each channel.
`model_config` is a customizable dictionary whose keys correspond to priors within the model and whose values are dictionaries of the parameters needed to initialize them. Later on we'll learn that the `save()` method preserves the priors contained in `model_config`, which allows complete replication of our model. + ++++ + +### model_config + ++++ + +The `default_model_config` attribute of every model inheriting from `ModelBuilder` shows which priors are available for customization. To see it, simply initialize a dummy model: + +```{code-cell} ipython3 +dummy_model = DelayedSaturatedMMM(date_column="", channel_columns="", adstock_max_lag=4) +dummy_model.default_model_config +``` + +You can change only the prior parameters that you wish; there is no need to alter all of them unless you'd like to! +In this case we'll simply replace the `sigma` of `beta_channel` with our computed one: + ++++ + +First, let's compute the share of spend per channel: + +```{code-cell} ipython3 +total_spend_per_channel = data[["x1", "x2"]].sum(axis=0) + +spend_share = total_spend_per_channel / total_spend_per_channel.sum() + +spend_share +``` + +Next, we specify the `sigma` parameter per channel: + +```{code-cell} ipython3 +# The scale necessary to make a HalfNormal distribution have unit variance +HALFNORMAL_SCALE = 1 / np.sqrt(1 - 2 / np.pi) + +n_channels = 2 + +prior_sigma = HALFNORMAL_SCALE * n_channels * spend_share.to_numpy() + +prior_sigma.tolist() +``` + +```{code-cell} ipython3 +custom_beta_channel_prior = {"beta_channel": {"sigma": prior_sigma, "dims": ("channel",)}} +my_model_config = dummy_model.default_model_config | custom_beta_channel_prior +``` + +As mentioned in the original notebook: "_For the prior specification there is no right or wrong answer. It all depends on the data, the context and the assumptions you are willing to make.
It is always recommended to do some prior predictive sampling and sensitivity analysis to check the impact of the priors on the posterior. We skip this here for the sake of simplicity. If you are not sure about specific priors, the `DelayedSaturatedMMM` class has some default priors that you can use as a starting point._" + ++++ + +The second feature that we can use for model definition is `sampler_config`. Similar to `model_config`, it's a dictionary that gets saved and contains things you'd usually pass to the `fit()` kwargs. It's not mandatory to create your own `sampler_config`; if not provided, both `model_config` and `sampler_config` will default to the forms specified by PyMC Labs experts, which allows for the usage of all model functionalities. The default `sampler_config` is left empty because the default sampling parameters usually prove sufficient for a start. + +```{code-cell} ipython3 +dummy_model.default_sampler_config +``` + +```{code-cell} ipython3 +my_sampler_config = { + "tune": 1000, + "draws": 1000, + "chains": 4, + "target_accept": 0.95, +} +``` + +Let's finally assemble our model! + +```{code-cell} ipython3 +mmm = DelayedSaturatedMMM( + model_config=my_model_config, + sampler_config=my_sampler_config, + date_column="date_week", + channel_columns=["x1", "x2"], + control_columns=[ + "event_1", + "event_2", + "t", + ], + adstock_max_lag=8, + yearly_seasonality=2, +) +``` + +An important thing to note here is that in the new version of `DelayedSaturatedMMM`, we don't pass our dataset to the class constructor itself. This is due to a reason I've mentioned before - it supports `sklearn` transformers and validations that require a usual X, y split and typically expect the data to be passed to the `fit()` method. 
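
The prior override built earlier can be checked by hand without fitting anything. The sketch below is only an illustration in plain Python: the spend shares and the two default entries are copied (rounded) from the notebook's printed output rather than read from the library, and the variable names are stand-ins. It shows that the HalfNormal scale gives unit variance, and that the `|` dict union (Python 3.9+) is shallow, which is why the override repeats `dims`.

```python
import math

# Var[HalfNormal(sigma)] = sigma**2 * (1 - 2 / pi), so this scale gives unit variance.
HALFNORMAL_SCALE = 1 / math.sqrt(1 - 2 / math.pi)
unit_variance = HALFNORMAL_SCALE**2 * (1 - 2 / math.pi)

# Rounded spend shares copied by hand from the notebook output (assumption:
# they match what `spend_share` prints for the example dataset).
spend_share = {"x1": 0.65632, "x2": 0.34368}
n_channels = len(spend_share)
prior_sigma = [HALFNORMAL_SCALE * n_channels * share for share in spend_share.values()]

# Hand-copied subset of `default_model_config`, for illustration only.
default_config = {
    "intercept": {"mu": 0, "sigma": 2},
    "beta_channel": {"sigma": 2, "dims": ("channel",)},
}
override = {"beta_channel": {"sigma": prior_sigma, "dims": ("channel",)}}

# Dict union is shallow: the whole "beta_channel" entry is replaced,
# while untouched keys such as "intercept" keep their defaults.
merged = default_config | override

# Leaving "dims" out of the override would drop it from the merged entry.
incomplete = default_config | {"beta_channel": {"sigma": prior_sigma}}

print(prior_sigma)
print(merged["beta_channel"])
print("dims" in incomplete["beta_channel"])
```

Because the merge replaces each top-level entry wholesale, any key you override should carry every field the model expects for that prior, not just the field you want to change.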
+ ++++ + +## _Model Fitting_ + ++++ + +Let's split the dataset: + +```{code-cell} ipython3 +X = data.drop("y", axis=1) +y = data["y"] +``` + +All that's left now is to finally fit the model: + +As you can see below, you can still pass sampler kwargs (here `random_seed`) directly to the `fit()` method. However, only the kwargs passed through `sampler_config` will be saved, so only these will be available after loading the model. + +```{code-cell} ipython3 +mmm.fit(X=X, y=y, random_seed=rng) +``` + +The `fit()` method automatically builds the model using the priors from `model_config`, and assigns the created model to our instance. You can access it as a normal attribute. + +```{code-cell} ipython3 +type(mmm.model) +``` + +```{code-cell} ipython3 +mmm.graphviz() +``` + +The posterior trace can be accessed through the `fit_result` attribute: + +```{code-cell} ipython3 +mmm.fit_result +``` + +If you wish to inspect the entire inference data, use the `idata` attribute. Within `idata`, you can find the entire dataset passed to the model under `fit_data`. + +```{code-cell} ipython3 +mmm.idata +``` + +## `Save` and `load` + ++++ + +All the data passed to the model on initialisation is stored in `idata.attrs`. The `save()` method uses it to write both this data and all the fit data to the netCDF format. + ++++ + +Simply specify the path to which you'd like to save your model: + +```{code-cell} ipython3 +mmm.save("my_saved_model.nc") +``` + +And pass it to the `load()` method when it's needed again on the target system: + +```{code-cell} ipython3 +loaded_model = DelayedSaturatedMMM.load("my_saved_model.nc") +``` + +```{code-cell} ipython3 +loaded_model.graphviz() +``` + +```{code-cell} ipython3 +loaded_model.idata +``` + +A model loaded in this way is ready to be used for sampling and prediction, and has access to all previous samples and data. 
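
On the target system it can help to fail early with a clear message when the saved file is missing, before any model code runs. The sketch below assumes a hypothetical helper, `load_saved_model`, that is not part of PyMC-Marketing; it accepts any loader callable, so in this notebook you would pass `DelayedSaturatedMMM.load`.

```python
from pathlib import Path
from typing import Any, Callable


def load_saved_model(path: str, loader: Callable[[str], Any]) -> Any:
    """Check that `path` points to an existing file, then call `loader` on it.

    Hypothetical convenience wrapper for illustration only; `loader` would
    typically be a model class's `load` method.
    """
    file = Path(path)
    if not file.is_file():
        raise FileNotFoundError(f"no saved model found at {file}")
    return loader(str(file))
```

With the model from this notebook, the call might look like `load_saved_model("my_saved_model.nc", DelayedSaturatedMMM.load)`.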
+ +```{code-cell} ipython3 +with loaded_model.model: + new_predictions = loaded_model.sample_posterior_predictive( + X, extend_idata=True, combined=False, random_seed=rng + ) +new_predictions +``` + +```{code-cell} ipython3 +az.plot_ppc(loaded_model.idata); +``` + +## Summary: + ++++ + +In summary, this article introduces the revolutionary ModelBuilder, a new [PyMC-experimental](https://github.com/pymc-devs/pymc-experimental) module that simplifies the deployment of PyMC Bayesian models. It addresses a historic challenge faced by users of PyMC and most PPLs by offering a user-friendly and efficient approach to model deployment. The ModelBuilder provides two straightforward methods, save() and load(), which streamline the model preservation and replication process post fitting. Users are offered flexibility in controlling the prior settings with `model_config` and customizing the sampling process via `sampler_config`. + +The use of an example model from the [MMM Example Notebook](https://www.pymc-marketing.io/en/stable/notebooks/index.html) demonstrates the practical implementation of `ModelBuilder`, emphasizing its ability to enhance model sharing among teams without the necessity for extensive domain knowledge about the model. The deployment improvements in [PyMC-Marketing](https://github.com/pymc-labs/pymc-marketing) brought about by ModelBuilder are not only user-friendly but also significantly enhance efficiency, making PyMC models more accessible for a wider audience. + ++++ + +Even though this introduction is using `DelayedSaturatedMMM`, functionalities from `ModelBuilder` are available in the CLV models as well. 
+ ++++ + +## Authors +- Authored by [Michał Raczycki](https://github.com/michaelraczycki) in August 2023 + +```{code-cell} ipython3 +%load_ext watermark +%watermark -n -u -v -iv -w -p pytensor,aeppl,xarray +``` + +::: +{include} ../page_footer.md +::: From 431dcd4402cddd3cd561285382e45aec05520f3a Mon Sep 17 00:00:00 2001 From: Michal Raczycki Date: Wed, 23 Aug 2023 10:45:37 +0200 Subject: [PATCH 3/5] removing initial notebook, updating model_builder.ipynb --- .../howto/ModelBuilder_usage_example.ipynb | 17466 ---------------- .../howto/ModelBuilder_usage_example.myst.md | 279 - examples/howto/model_builder.ipynb | 692 + examples/howto/model_builder.myst.md | 350 + 4 files changed, 1042 insertions(+), 17745 deletions(-) delete mode 100644 examples/howto/ModelBuilder_usage_example.ipynb delete mode 100644 examples/howto/ModelBuilder_usage_example.myst.md create mode 100644 examples/howto/model_builder.ipynb create mode 100644 examples/howto/model_builder.myst.md diff --git a/examples/howto/ModelBuilder_usage_example.ipynb b/examples/howto/ModelBuilder_usage_example.ipynb deleted file mode 100644 index d9a409205..000000000 --- a/examples/howto/ModelBuilder_usage_example.ipynb +++ /dev/null @@ -1,17466 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "ad9920da-e687-408d-b99f-060a99c0b561", - "metadata": {}, - "source": [ - "(ModelBuilder usage example)=\n", - "# ModelBuilder usage example\n", - "\n", - ":::{post} Aug 18, 2023\n", - ":tags: ModelBuilder, model deployment,\n", - ":category: intermediate, tutorial\n", - ":author: Michał Raczycki\n", - ":::" - ] - }, - { - "cell_type": "markdown", - "id": "1d74584c", - "metadata": {}, - "source": [ - "# Deploying MMMs and CLVs in Production: Saving and Loading Models" - ] - }, - { - "cell_type": "markdown", - "id": "3222ef04", - "metadata": {}, - "source": [ - "In this article, we'll tackle the historically challenging process of deploying Bayesian models built with PyMC. 
Introducing a revolutionary deployment module, we bring unprecedented simplicity and efficiency to the deployment of PyMC models. As we prioritize user-friendly solutions, let's delve into how this innovation can significantly elevate your data science projects." - ] - }, - { - "cell_type": "markdown", - "id": "ddb28436", - "metadata": {}, - "source": [ - "\n", - "Recent release of PyMC-Marketing by [Labs](https://www.pymc-labs.io) proves to be a big hit [(PyMC-Marketing)](https://www.pymc-labs.io/blog-posts/pymc-marketing-a-bayesian-approach-to-marketing-data-science/). In the feedback one could see an ongoing theme, many of you have been requesting easy and robust way of deploying models to production. It’s been a long-standing problem with PyMC ( and most other Probabilistic Programming Languages). The reason for that is that there’s no obvious way, and doesn’t matter which approach you try it proves to be tricky. That is why we’re happy to announce the release of `ModelBuilder`, brand new PyMC-experimental module that addresses this need, and improves on the deployment process significantly.\n", - "\n", - "The ModelBuilder module is a new feature of PyMC based models. It provides 2 easy-to-use methods: save() and load() that can be used after the model has been fit.save() allow easy preservation of the model to .netcdf format, and load() gives one-line replication of the original model. Users can control the prior settings with model_config, and customize the sampling process using sampler_config. 
Default values of those are working just fine, so first time give it a try without changing, and provide your own model_config and model_sampler if afterwards you want to try to customize it more for your use case!\n" - ] - }, - { - "cell_type": "markdown", - "id": "a808e36a", - "metadata": {}, - "source": [ - "For this notebook I'll use the example model used in [MMM Example Notebook](https://www.pymc-marketing.io/en/stable/notebooks/mmm/mmm_example.html), but ommit the details of data generation and plotting functionalities, since they're out of scope for this introduction, I highly recommend to see that part as well, but for now let's focus on today's topic: Groundbreaking deployment improvements in PyMC-Marketing!" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "1050a937", - "metadata": {}, - "outputs": [], - "source": [ - "import arviz as az\n", - "import numpy as np\n", - "import pandas as pd\n", - "\n", - "from pymc_marketing.mmm import DelayedSaturatedMMM" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "e2c3f4c4-1d74-4ae2-9ff2-c13cbcf7fe54", - "metadata": {}, - "outputs": [], - "source": [ - "az.style.use(\"arviz-darkgrid\")" - ] - }, - { - "cell_type": "markdown", - "id": "f37d808e", - "metadata": {}, - "source": [ - "Let's load the dataset:" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "b7b1193f", - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
date_weekyx1x2event_1event_2dayofyeartsin_order_1cos_order_1sin_order_2cos_order_2
02018-04-023984.6622370.3185800.0000000.00.09200.999930-0.011826-0.023651-0.999720
12018-04-093762.8717940.1123880.0000000.00.09910.991269-0.131859-0.261414-0.965227
22018-04-164466.9673880.2924000.0000000.00.010620.968251-0.249981-0.484089-0.875019
32018-04-233864.2193730.0713990.0000000.00.011330.931210-0.364483-0.678820-0.734304
42018-04-304441.6252780.3867450.0000000.00.012040.880683-0.473706-0.834370-0.551205
.......................................
1742021-08-023553.5461480.0330240.0000000.00.0214174-0.513901-0.8578490.8816990.471812
1752021-08-095565.5096820.1656150.8633490.00.0221175-0.613230-0.7899050.9687860.247898
1762021-08-164137.6514850.1718820.0000000.00.0228176-0.703677-0.7105200.9999530.009676
1772021-08-234479.0413510.2802570.0000000.00.0235177-0.783934-0.6208440.973402-0.229104
1782021-08-304675.9734390.4388570.0000000.00.0242178-0.852837-0.5221780.890665-0.454661
\n", - "

179 rows × 12 columns

\n", - "
" - ], - "text/plain": [ - " date_week y x1 x2 event_1 event_2 dayofyear \\\n", - "0 2018-04-02 3984.662237 0.318580 0.000000 0.0 0.0 92 \n", - "1 2018-04-09 3762.871794 0.112388 0.000000 0.0 0.0 99 \n", - "2 2018-04-16 4466.967388 0.292400 0.000000 0.0 0.0 106 \n", - "3 2018-04-23 3864.219373 0.071399 0.000000 0.0 0.0 113 \n", - "4 2018-04-30 4441.625278 0.386745 0.000000 0.0 0.0 120 \n", - ".. ... ... ... ... ... ... ... \n", - "174 2021-08-02 3553.546148 0.033024 0.000000 0.0 0.0 214 \n", - "175 2021-08-09 5565.509682 0.165615 0.863349 0.0 0.0 221 \n", - "176 2021-08-16 4137.651485 0.171882 0.000000 0.0 0.0 228 \n", - "177 2021-08-23 4479.041351 0.280257 0.000000 0.0 0.0 235 \n", - "178 2021-08-30 4675.973439 0.438857 0.000000 0.0 0.0 242 \n", - "\n", - " t sin_order_1 cos_order_1 sin_order_2 cos_order_2 \n", - "0 0 0.999930 -0.011826 -0.023651 -0.999720 \n", - "1 1 0.991269 -0.131859 -0.261414 -0.965227 \n", - "2 2 0.968251 -0.249981 -0.484089 -0.875019 \n", - "3 3 0.931210 -0.364483 -0.678820 -0.734304 \n", - "4 4 0.880683 -0.473706 -0.834370 -0.551205 \n", - ".. ... ... ... ... ... 
\n", - "174 174 -0.513901 -0.857849 0.881699 0.471812 \n", - "175 175 -0.613230 -0.789905 0.968786 0.247898 \n", - "176 176 -0.703677 -0.710520 0.999953 0.009676 \n", - "177 177 -0.783934 -0.620844 0.973402 -0.229104 \n", - "178 178 -0.852837 -0.522178 0.890665 -0.454661 \n", - "\n", - "[179 rows x 12 columns]" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "url = \"https://raw.githubusercontent.com/pymc-labs/pymc-marketing/main/datasets/mmm_example.csv\"\n", - "df = pd.read_csv(url)\n", - "df" - ] - }, - { - "cell_type": "markdown", - "id": "87deb70d", - "metadata": {}, - "source": [ - "But for our model we need much smaller dataset, many of the previous features were contributing to generation of others, now as our target variable is computed we can filter out not needed columns:" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "52b6d127", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
date_weekyx1x2event_1event_2dayofyeart
02018-04-023984.6622370.3185800.00.00.0920
12018-04-093762.8717940.1123880.00.00.0991
22018-04-164466.9673880.2924000.00.00.01062
32018-04-233864.2193730.0713990.00.00.01133
42018-04-304441.6252780.3867450.00.00.01204
\n", - "
" - ], - "text/plain": [ - " date_week y x1 x2 event_1 event_2 dayofyear t\n", - "0 2018-04-02 3984.662237 0.318580 0.0 0.0 0.0 92 0\n", - "1 2018-04-09 3762.871794 0.112388 0.0 0.0 0.0 99 1\n", - "2 2018-04-16 4466.967388 0.292400 0.0 0.0 0.0 106 2\n", - "3 2018-04-23 3864.219373 0.071399 0.0 0.0 0.0 113 3\n", - "4 2018-04-30 4441.625278 0.386745 0.0 0.0 0.0 120 4" - ] - }, - "execution_count": 4, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "columns_to_keep = [\n", - " \"date_week\",\n", - " \"y\",\n", - " \"x1\",\n", - " \"x2\",\n", - " \"event_1\",\n", - " \"event_2\",\n", - " \"dayofyear\",\n", - "]\n", - "seed: int = sum(map(ord, \"mmm\"))\n", - "rng = np.random.default_rng(seed=seed)\n", - "\n", - "data = df[columns_to_keep].copy()\n", - "\n", - "data[\"t\"] = range(df.shape[0])\n", - "data.head()" - ] - }, - { - "cell_type": "markdown", - "id": "9518a885", - "metadata": {}, - "source": [ - "## _Model Creation_\n", - "After we have our dataset ready, we could proceed straight to our model definition, but first to show the full potential of one of the new features: `model_config` we need to use some of our data to define our prior for sigma parameter for each of the channels. `model_config` is a customizable dictionary with keys corresponding to priors within the model, and values containing a dictionaries with parameters necessary to initialize them. Later on we'll learn that through the `save()` method we can preserve our priors contained inside the `model_config`, to allow complete replication of our model." - ] - }, - { - "cell_type": "markdown", - "id": "4b52b2c1", - "metadata": {}, - "source": [ - "### model_config" - ] - }, - { - "cell_type": "markdown", - "id": "41021a72", - "metadata": {}, - "source": [ - "`default_model_config` attribute of every model inheriting from `ModelBuilder` will allow you to see which priors are available for customization. 
To see it simply initialize a dummy model:" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "284bd558", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'intercept': {'mu': 0, 'sigma': 2},\n", - " 'beta_channel': {'sigma': 2, 'dims': ('channel',)},\n", - " 'alpha': {'alpha': 1, 'beta': 3, 'dims': ('channel',)},\n", - " 'lam': {'alpha': 3, 'beta': 1, 'dims': ('channel',)},\n", - " 'sigma': {'sigma': 2},\n", - " 'gamma_control': {'mu': 0, 'sigma': 2, 'dims': ('control',)},\n", - " 'mu': {'dims': ('date',)},\n", - " 'likelihood': {'dims': ('date',)},\n", - " 'gamma_fourier': {'mu': 0, 'b': 1, 'dims': 'fourier_mode'}}" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "dummy_model = DelayedSaturatedMMM(date_column=\"\", channel_columns=\"\", adstock_max_lag=4)\n", - "dummy_model.default_model_config" - ] - }, - { - "cell_type": "markdown", - "id": "f0fd248f", - "metadata": {}, - "source": [ - "You can change only the prior parameters that you wish, no need to alter all of them, unless you'd like to!\n", - "In this case we'll just simply replace our sigma for beta_channel with our computed one:" - ] - }, - { - "cell_type": "markdown", - "id": "19f075f0-4d3d-4509-a9c6-f15efdb9293d", - "metadata": {}, - "source": [ - "First, let's compute the share of spend per channel:" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "4785596a-e333-4cd0-af15-1332e97b66d5", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "x1 0.65632\n", - "x2 0.34368\n", - "dtype: float64" - ] - }, - "execution_count": 6, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "total_spend_per_channel = data[[\"x1\", \"x2\"]].sum(axis=0)\n", - "\n", - "spend_share = total_spend_per_channel / total_spend_per_channel.sum()\n", - "\n", - "spend_share" - ] - }, - { - "cell_type": "markdown", - "id": "40d17642-1e21-4adc-97f4-633eede87915", - 
"metadata": {}, - "source": [ - "Next, we specify the `sigma`parameter per channel:" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "672b36bd-3d08-46df-85b8-d67d3ade75d7", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[2.1775326025486734, 1.140260877391939]" - ] - }, - "execution_count": 7, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# The scale necessary to make a HalfNormal distribution have unit variance\n", - "HALFNORMAL_SCALE = 1 / np.sqrt(1 - 2 / np.pi)\n", - "\n", - "n_channels = 2\n", - "\n", - "prior_sigma = HALFNORMAL_SCALE * n_channels * spend_share.to_numpy()\n", - "\n", - "prior_sigma.tolist()" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "bac9f587", - "metadata": {}, - "outputs": [], - "source": [ - "custom_beta_channel_prior = {\"beta_channel\": {\"sigma\": prior_sigma, \"dims\": (\"channel\",)}}\n", - "my_model_config = dummy_model.default_model_config | custom_beta_channel_prior" - ] - }, - { - "cell_type": "markdown", - "id": "1aa435bf", - "metadata": {}, - "source": [ - "As mentioned in the original notebook: \"_For the prior specification there is no right or wrong answer. It all depends on the data, the context and the assumptions you are willing to make. It is always recommended to do some prior predictive sampling and sensitivity analysis to check the impact of the priors on the posterior. We skip this here for the sake of simplicity. If you are not sure about specific priors, the `DelayedSaturatedMMM` class has some default priors that you can use as a starting point._\"" - ] - }, - { - "cell_type": "markdown", - "id": "f195a79e", - "metadata": {}, - "source": [ - "The second feature that we can use for model definition is `sampler_config`. Similar to `model_config`, it's a dictionary that gets saved and contains things you'd usually pass to the `fit()` kwargs. 
It's not mandatory to create your own `sampler_config`; if not provided, both `model_config` and `sampler_config` will default to the forms specified by PyMC Labs experts, which allows for the usage of all model functionalities. The default `sampler_config` is left empty because the default sampling parameters usually prove sufficient for a start." - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "0ab8140c", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{}" - ] - }, - "execution_count": 9, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "dummy_model.default_sampler_config" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "bf5a50f4", - "metadata": {}, - "outputs": [], - "source": [ - "my_sampler_config = {\n", - " \"tune\": 1000,\n", - " \"draws\": 1000,\n", - " \"chains\": 4,\n", - " \"target_accept\": 0.95,\n", - "}" - ] - }, - { - "cell_type": "markdown", - "id": "f3bfe090", - "metadata": {}, - "source": [ - "Let's finally assemble our model!" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "c7bd6909", - "metadata": {}, - "outputs": [], - "source": [ - "mmm = DelayedSaturatedMMM(\n", - " model_config=my_model_config,\n", - " sampler_config=my_sampler_config,\n", - " date_column=\"date_week\",\n", - " channel_columns=[\"x1\", \"x2\"],\n", - " control_columns=[\n", - " \"event_1\",\n", - " \"event_2\",\n", - " \"t\",\n", - " ],\n", - " adstock_max_lag=8,\n", - " yearly_seasonality=2,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "54095b1a", - "metadata": {}, - "source": [ - "An important thing to note here is that in the new version of `DelayedSaturatedMMM`, we don't pass our dataset to the class constructor itself. This is due to a reason I've mentioned before - it supports `sklearn` transformers and validations that require a usual X, y split and typically expect the data to be passed to the `fit()` method." 
- ] - }, - { - "cell_type": "markdown", - "id": "dec9b1b0", - "metadata": {}, - "source": [ - "## _Model Fitting_" - ] - }, - { - "cell_type": "markdown", - "id": "d5e64562-ba78-4497-a0f8-123b4bc88b79", - "metadata": {}, - "source": [ - "Let's split the dataset:" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "ff23006b-a55b-4a22-9f34-4eeaddf47486", - "metadata": {}, - "outputs": [], - "source": [ - "X = data.drop(\"y\", axis=1)\n", - "y = data[\"y\"]" - ] - }, - { - "cell_type": "markdown", - "id": "403e3ed2", - "metadata": {}, - "source": [ - "All that's left now is to finally fit the model:\n", - "\n", - "As you can see below, you can still pass the sampler kwargs directly to `fit()` method. However, only those kwargs passed using `sampler_config` will be saved. Therefore, only these will be available after loading the model." - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "0f6ab0a8", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Auto-assigning NUTS sampler...\n", - "Initializing NUTS using jitter+adapt_diag...\n", - "Multiprocess sampling (4 chains in 4 jobs)\n", - "NUTS: [intercept, beta_channel, alpha, lam, sigma, gamma_control, gamma_fourier]\n" - ] - }, - { - "data": { - "text/html": [ - "\n", - "\n" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "\n", - "
\n", - " \n", - " 100.00% [8000/8000 00:27<00:00 Sampling 4 chains, 0 divergences]\n", - "
\n", - " " - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 28 seconds.\n", - "Sampling: [alpha, beta_channel, gamma_control, gamma_fourier, intercept, lam, likelihood, sigma]\n", - "Sampling: [likelihood]\n" - ] - }, - { - "data": { - "text/html": [ - "\n", - "\n" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "\n", - "
\n", - " \n", - " 100.00% [4000/4000 00:00<00:00]\n", - "
\n", - " " - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "\n", - "
\n", - "
\n", - "
arviz.InferenceData
\n", - "
\n", - "
    \n", - " \n", - "
  • \n", - " \n", - " \n", - "
    \n", - "
    \n", - "
      \n", - "
      \n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "
      <xarray.Dataset>\n",
      -       "Dimensions:                    (chain: 4, draw: 1000, control: 3,\n",
      -       "                                fourier_mode: 4, channel: 2, date: 179)\n",
      -       "Coordinates:\n",
      -       "  * chain                      (chain) int64 0 1 2 3\n",
      -       "  * draw                       (draw) int64 0 1 2 3 4 5 ... 995 996 997 998 999\n",
      -       "  * control                    (control) <U7 'event_1' 'event_2' 't'\n",
      -       "  * fourier_mode               (fourier_mode) <U11 'sin_order_1' ... 'cos_ord...\n",
      -       "  * channel                    (channel) <U2 'x1' 'x2'\n",
      -       "  * date                       (date) <U10 '2018-04-02' ... '2021-08-30'\n",
      -       "Data variables: (12/13)\n",
      -       "    intercept                  (chain, draw) float64 0.3381 0.3361 ... 0.3479\n",
      -       "    gamma_control              (chain, draw, control) float64 0.2968 ... 0.00...\n",
      -       "    gamma_fourier              (chain, draw, fourier_mode) float64 -0.002612 ...\n",
      -       "    beta_channel               (chain, draw, channel) float64 0.3718 ... 0.2628\n",
      -       "    alpha                      (chain, draw, channel) float64 0.4162 ... 0.2007\n",
      -       "    lam                        (chain, draw, channel) float64 4.26 ... 2.758\n",
      -       "    ...                         ...\n",
      -       "    channel_adstock            (chain, draw, date, channel) float64 0.1868 .....\n",
      -       "    channel_adstock_saturated  (chain, draw, date, channel) float64 0.3781 .....\n",
      -       "    channel_contributions      (chain, draw, date, channel) float64 0.1406 .....\n",
      -       "    control_contributions      (chain, draw, date, control) float64 0.0 ... 0...\n",
      -       "    fourier_contributions      (chain, draw, date, fourier_mode) float64 -0.0...\n",
      -       "    mu                         (chain, draw, date) float64 0.4769 ... 0.5924\n",
      -       "Attributes:\n",
      -       "    created_at:                 2023-08-17T13:57:12.916984\n",
      -       "    arviz_version:              0.16.1\n",
      -       "    inference_library:          pymc\n",
      -       "    inference_library_version:  5.6.1\n",
      -       "    sampling_time:              27.989259004592896\n",
      -       "    tuning_steps:               1000

      \n", - "
    \n", - "
    \n", - "
  • \n", - " \n", - "
  • \n", - " \n", - " \n", - "
    \n", - "
    \n", - "
      \n", - "
      \n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "
      <xarray.Dataset>\n",
      -       "Dimensions:     (chain: 4, draw: 1000, date: 179)\n",
      -       "Coordinates:\n",
      -       "  * chain       (chain) int64 0 1 2 3\n",
      -       "  * draw        (draw) int64 0 1 2 3 4 5 6 7 ... 992 993 994 995 996 997 998 999\n",
      -       "  * date        (date) <U10 '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
      -       "Data variables:\n",
      -       "    likelihood  (chain, draw, date) float64 0.5148 0.4369 ... 0.4858 0.6176\n",
      -       "Attributes:\n",
      -       "    created_at:                 2023-08-17T13:57:14.741457\n",
      -       "    arviz_version:              0.16.1\n",
      -       "    inference_library:          pymc\n",
      -       "    inference_library_version:  5.6.1

      \n", - "
    \n", - "
    \n", - "
  • \n", - " \n", - "
  • \n", - " \n", - " \n", - "
    \n", - "
    \n", - "
      \n", - "
      \n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "
      <xarray.Dataset>\n",
      -       "Dimensions:                (chain: 4, draw: 1000)\n",
      -       "Coordinates:\n",
      -       "  * chain                  (chain) int64 0 1 2 3\n",
      -       "  * draw                   (draw) int64 0 1 2 3 4 5 ... 994 995 996 997 998 999\n",
      -       "Data variables: (12/17)\n",
      -       "    step_size              (chain, draw) float64 0.05862 0.05862 ... 0.07543\n",
      -       "    lp                     (chain, draw) float64 355.1 355.2 ... 355.6 351.6\n",
      -       "    n_steps                (chain, draw) float64 63.0 63.0 63.0 ... 63.0 63.0\n",
      -       "    largest_eigval         (chain, draw) float64 nan nan nan nan ... nan nan nan\n",
      -       "    energy                 (chain, draw) float64 -344.2 -349.9 ... -350.2 -347.5\n",
      -       "    acceptance_rate        (chain, draw) float64 0.9987 0.9934 ... 0.8931 0.9167\n",
      -       "    ...                     ...\n",
      -       "    diverging              (chain, draw) bool False False False ... False False\n",
      -       "    smallest_eigval        (chain, draw) float64 nan nan nan nan ... nan nan nan\n",
      -       "    max_energy_error       (chain, draw) float64 -0.06059 -0.06074 ... 0.284\n",
      -       "    index_in_trajectory    (chain, draw) int64 -12 20 43 -17 ... 42 20 -43 -29\n",
      -       "    energy_error           (chain, draw) float64 -0.01181 -0.02828 ... -0.01478\n",
      -       "    tree_depth             (chain, draw) int64 6 6 6 6 7 6 6 7 ... 6 5 5 6 7 6 6\n",
      -       "Attributes:\n",
      -       "    created_at:                 2023-08-17T13:57:12.932375\n",
      -       "    arviz_version:              0.16.1\n",
      -       "    inference_library:          pymc\n",
      -       "    inference_library_version:  5.6.1\n",
      -       "    sampling_time:              27.989259004592896\n",
      -       "    tuning_steps:               1000

      \n", - "
    \n", - "
    \n", - "
  • \n", - " \n", - "
  • \n", - " \n", - " \n", - "
    \n", - "
    \n", - "
      \n", - "
      \n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "
      <xarray.Dataset>\n",
      -       "Dimensions:                    (chain: 1, draw: 500, channel: 2, control: 3,\n",
      -       "                                date: 179, fourier_mode: 4)\n",
      -       "Coordinates:\n",
      -       "  * chain                      (chain) int64 0\n",
      -       "  * draw                       (draw) int64 0 1 2 3 4 5 ... 495 496 497 498 499\n",
      -       "  * channel                    (channel) <U2 'x1' 'x2'\n",
      -       "  * control                    (control) <U7 'event_1' 'event_2' 't'\n",
      -       "  * date                       (date) <U10 '2018-04-02' ... '2021-08-30'\n",
      -       "  * fourier_mode               (fourier_mode) <U11 'sin_order_1' ... 'cos_ord...\n",
      -       "Data variables: (12/13)\n",
      -       "    beta_channel               (chain, draw, channel) float64 0.6143 ... 0.9639\n",
      -       "    sigma                      (chain, draw) float64 2.034 5.255 ... 0.07125\n",
      -       "    gamma_control              (chain, draw, control) float64 -0.9767 ... 0.0...\n",
      -       "    control_contributions      (chain, draw, date, control) float64 -0.0 ... ...\n",
      -       "    mu                         (chain, draw, date) float64 5.528 7.076 ... 8.396\n",
      -       "    channel_adstock_saturated  (chain, draw, date, channel) float64 0.3626 .....\n",
      -       "    ...                         ...\n",
      -       "    fourier_contributions      (chain, draw, date, fourier_mode) float64 1.64...\n",
      -       "    lam                        (chain, draw, channel) float64 2.667 ... 1.438\n",
      -       "    intercept                  (chain, draw) float64 2.51 -3.898 ... -2.654\n",
      -       "    gamma_fourier              (chain, draw, fourier_mode) float64 1.647 ... ...\n",
      -       "    channel_adstock            (chain, draw, date, channel) float64 0.2848 .....\n",
      -       "    alpha                      (chain, draw, channel) float64 0.1089 ... 0.2786\n",
      -       "Attributes:\n",
      -       "    created_at:                 2023-08-17T13:57:14.422347\n",
      -       "    arviz_version:              0.16.1\n",
      -       "    inference_library:          pymc\n",
      -       "    inference_library_version:  5.6.1

      \n", - "
    \n", - "
    \n", - "
  • \n", - " \n", - "
  • \n", - " \n", - " \n", - "
    \n", - "
    \n", - "
      \n", - "
      \n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "
      <xarray.Dataset>\n",
      -       "Dimensions:     (chain: 1, draw: 500, date: 179)\n",
      -       "Coordinates:\n",
      -       "  * chain       (chain) int64 0\n",
      -       "  * draw        (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499\n",
      -       "  * date        (date) <U10 '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
      -       "Data variables:\n",
      -       "    likelihood  (chain, draw, date) float64 5.221 8.352 5.313 ... 7.793 8.528\n",
      -       "Attributes:\n",
      -       "    created_at:                 2023-08-17T13:57:14.428434\n",
      -       "    arviz_version:              0.16.1\n",
      -       "    inference_library:          pymc\n",
      -       "    inference_library_version:  5.6.1

      \n", - "
    \n", - "
    \n", - "
  • \n", - " \n", - "
  • \n", - " \n", - " \n", - "
    \n", - "
    \n", - "
      \n", - "
      \n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "
      <xarray.Dataset>\n",
      -       "Dimensions:     (date: 179)\n",
      -       "Coordinates:\n",
      -       "  * date        (date) <U10 '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
      -       "Data variables:\n",
      -       "    likelihood  (date) float64 0.4794 0.4527 0.5374 ... 0.4978 0.5388 0.5625\n",
      -       "Attributes:\n",
      -       "    created_at:                 2023-08-17T13:57:12.937052\n",
      -       "    arviz_version:              0.16.1\n",
      -       "    inference_library:          pymc\n",
      -       "    inference_library_version:  5.6.1

      \n", - "
    \n", - "
    \n", - "
  • \n", - " \n", - "
  • \n", - " \n", - " \n", - "
    \n", - "
    \n", - "
      \n", - "
      \n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "
      <xarray.Dataset>\n",
      -       "Dimensions:       (date: 179, channel: 2, control: 3, fourier_mode: 4)\n",
      -       "Coordinates:\n",
      -       "  * date          (date) <U10 '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
      -       "  * channel       (channel) <U2 'x1' 'x2'\n",
      -       "  * control       (control) <U7 'event_1' 'event_2' 't'\n",
      -       "  * fourier_mode  (fourier_mode) <U11 'sin_order_1' ... 'cos_order_2'\n",
      -       "Data variables:\n",
      -       "    channel_data  (date, channel) float64 0.3196 0.0 0.1128 ... 0.0 0.4403 0.0\n",
      -       "    target        (date) float64 0.4794 0.4527 0.5374 ... 0.4978 0.5388 0.5625\n",
      -       "    control_data  (date, control) float64 0.0 0.0 0.0 0.0 ... 0.0 0.0 178.0\n",
      -       "    fourier_data  (date, fourier_mode) float64 0.9999 -0.01183 ... -0.4547\n",
      -       "Attributes:\n",
      -       "    created_at:                 2023-08-17T13:57:12.938885\n",
      -       "    arviz_version:              0.16.1\n",
      -       "    inference_library:          pymc\n",
      -       "    inference_library_version:  5.6.1

      \n", - "
    \n", - "
    \n", - "
  • \n", - " \n", - "
  • \n", - " \n", - " \n", - "
    \n", - "
    \n", - "
      \n", - "
      \n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "
      <xarray.Dataset>\n",
      -       "Dimensions:    (index: 179)\n",
      -       "Coordinates:\n",
      -       "  * index      (index) int64 0 1 2 3 4 5 6 7 ... 171 172 173 174 175 176 177 178\n",
      -       "Data variables:\n",
      -       "    date_week  (index) object '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
      -       "    x1         (index) float64 0.3186 0.1124 0.2924 ... 0.1719 0.2803 0.4389\n",
      -       "    x2         (index) float64 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.8633 0.0 0.0 0.0\n",
      -       "    event_1    (index) float64 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0\n",
      -       "    event_2    (index) float64 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0\n",
      -       "    dayofyear  (index) int64 92 99 106 113 120 127 ... 207 214 221 228 235 242\n",
      -       "    t          (index) int64 0 1 2 3 4 5 6 7 ... 171 172 173 174 175 176 177 178\n",
      -       "    y          (index) float64 3.985e+03 3.763e+03 ... 4.479e+03 4.676e+03

      \n", - "
    \n", - "
    \n", - "
  • \n", - " \n", - "
\n", - "
\n", - " " - ], - "text/plain": [ - "Inference data with groups:\n", - "\t> posterior\n", - "\t> posterior_predictive\n", - "\t> sample_stats\n", - "\t> prior\n", - "\t> prior_predictive\n", - "\t> observed_data\n", - "\t> constant_data\n", - "\t> fit_data" - ] - }, - "execution_count": 13, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "mmm.fit(X=X, y=y, random_seed=rng)" - ] - }, - { - "cell_type": "markdown", - "id": "c29a6461", - "metadata": {}, - "source": [ - "The `fit()` method automatically builds the model using the priors from `model_config`, and assigns the created model to our instance. You can access it as a normal attribute." - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "c6b8e2af", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "pymc.model.Model" - ] - }, - "execution_count": 14, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "type(mmm.model)" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": "f046ee2c", - "metadata": {}, - "outputs": [ - { - "data": { - "image/svg+xml": [ - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "clusterdate (179) x channel (2)\n", - "\n", - "date (179) x channel (2)\n", - "\n", - "\n", - "clusterdate (179)\n", - "\n", - "date (179)\n", - "\n", - "\n", - "clusterchannel (2)\n", - "\n", - "channel (2)\n", - "\n", - "\n", - "clusterdate (179) x control (3)\n", - "\n", - "date (179) x control (3)\n", - "\n", - "\n", - "clustercontrol (3)\n", - "\n", - "control (3)\n", - "\n", - "\n", - "clusterdate (179) x fourier_mode (4)\n", - "\n", - "date (179) x fourier_mode (4)\n", - "\n", - "\n", - "clusterfourier_mode (4)\n", - "\n", - "fourier_mode (4)\n", - "\n", - "\n", - "\n", - "channel_adstock_saturated\n", - "\n", - "channel_adstock_saturated\n", - "~\n", - "Deterministic\n", - "\n", - "\n", - "\n", - "channel_contributions\n", - "\n", - "channel_contributions\n", - "~\n", - 
"Deterministic\n", - "\n", - "\n", - "\n", - "channel_adstock_saturated->channel_contributions\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "mu\n", - "\n", - "mu\n", - "~\n", - "Deterministic\n", - "\n", - "\n", - "\n", - "channel_contributions->mu\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "channel_data\n", - "\n", - "channel_data\n", - "~\n", - "MutableData\n", - "\n", - "\n", - "\n", - "channel_adstock\n", - "\n", - "channel_adstock\n", - "~\n", - "Deterministic\n", - "\n", - "\n", - "\n", - "channel_data->channel_adstock\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "channel_adstock->channel_adstock_saturated\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "likelihood\n", - "\n", - "likelihood\n", - "~\n", - "Normal\n", - "\n", - "\n", - "\n", - "target\n", - "\n", - "target\n", - "~\n", - "MutableData\n", - "\n", - "\n", - "\n", - "likelihood->target\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "mu->likelihood\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "sigma\n", - "\n", - "sigma\n", - "~\n", - "HalfNormal\n", - "\n", - "\n", - "\n", - "sigma->likelihood\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "intercept\n", - "\n", - "intercept\n", - "~\n", - "Normal\n", - "\n", - "\n", - "\n", - "intercept->mu\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "beta_channel\n", - "\n", - "beta_channel\n", - "~\n", - "HalfNormal\n", - "\n", - "\n", - "\n", - "beta_channel->channel_contributions\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "alpha\n", - "\n", - "alpha\n", - "~\n", - "Beta\n", - "\n", - "\n", - "\n", - "alpha->channel_adstock\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "lam\n", - "\n", - "lam\n", - "~\n", - "Gamma\n", - "\n", - "\n", - "\n", - "lam->channel_adstock_saturated\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "control_data\n", - "\n", - "control_data\n", - "~\n", - "MutableData\n", - "\n", - "\n", - "\n", - "control_contributions\n", - "\n", - "control_contributions\n", - "~\n", - "Deterministic\n", - "\n", - "\n", - "\n", - 
"control_data->control_contributions\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "control_contributions->mu\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "gamma_control\n", - "\n", - "gamma_control\n", - "~\n", - "Normal\n", - "\n", - "\n", - "\n", - "gamma_control->control_contributions\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "fourier_data\n", - "\n", - "fourier_data\n", - "~\n", - "MutableData\n", - "\n", - "\n", - "\n", - "fourier_contributions\n", - "\n", - "fourier_contributions\n", - "~\n", - "Deterministic\n", - "\n", - "\n", - "\n", - "fourier_data->fourier_contributions\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "fourier_contributions->mu\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "gamma_fourier\n", - "\n", - "gamma_fourier\n", - "~\n", - "Laplace\n", - "\n", - "\n", - "\n", - "gamma_fourier->fourier_contributions\n", - "\n", - "\n", - "\n", - "\n", - "\n" - ], - "text/plain": [ - "" - ] - }, - "execution_count": 15, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "mmm.graphviz()" - ] - }, - { - "cell_type": "markdown", - "id": "c804b600", - "metadata": {}, - "source": [ - "The posterior trace can be accessed via the `fit_result` attribute:" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "id": "66903965", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "
<xarray.Dataset>\n",
-       "Dimensions:                    (chain: 4, draw: 1000, control: 3,\n",
-       "                                fourier_mode: 4, channel: 2, date: 179)\n",
-       "Coordinates:\n",
-       "  * chain                      (chain) int64 0 1 2 3\n",
-       "  * draw                       (draw) int64 0 1 2 3 4 5 ... 995 996 997 998 999\n",
-       "  * control                    (control) <U7 'event_1' 'event_2' 't'\n",
-       "  * fourier_mode               (fourier_mode) <U11 'sin_order_1' ... 'cos_ord...\n",
-       "  * channel                    (channel) <U2 'x1' 'x2'\n",
-       "  * date                       (date) <U10 '2018-04-02' ... '2021-08-30'\n",
-       "Data variables: (12/13)\n",
-       "    intercept                  (chain, draw) float64 0.3381 0.3361 ... 0.3479\n",
-       "    gamma_control              (chain, draw, control) float64 0.2968 ... 0.00...\n",
-       "    gamma_fourier              (chain, draw, fourier_mode) float64 -0.002612 ...\n",
-       "    beta_channel               (chain, draw, channel) float64 0.3718 ... 0.2628\n",
-       "    alpha                      (chain, draw, channel) float64 0.4162 ... 0.2007\n",
-       "    lam                        (chain, draw, channel) float64 4.26 ... 2.758\n",
-       "    ...                         ...\n",
-       "    channel_adstock            (chain, draw, date, channel) float64 0.1868 .....\n",
-       "    channel_adstock_saturated  (chain, draw, date, channel) float64 0.3781 .....\n",
-       "    channel_contributions      (chain, draw, date, channel) float64 0.1406 .....\n",
-       "    control_contributions      (chain, draw, date, control) float64 0.0 ... 0...\n",
-       "    fourier_contributions      (chain, draw, date, fourier_mode) float64 -0.0...\n",
-       "    mu                         (chain, draw, date) float64 0.4769 ... 0.5924\n",
-       "Attributes:\n",
-       "    created_at:                 2023-08-17T13:57:12.916984\n",
-       "    arviz_version:              0.16.1\n",
-       "    inference_library:          pymc\n",
-       "    inference_library_version:  5.6.1\n",
-       "    sampling_time:              27.989259004592896\n",
-       "    tuning_steps:               1000
" - ], - "text/plain": [ - "\n", - "Dimensions: (chain: 4, draw: 1000, control: 3,\n", - " fourier_mode: 4, channel: 2, date: 179)\n", - "Coordinates:\n", - " * chain (chain) int64 0 1 2 3\n", - " * draw (draw) int64 0 1 2 3 4 5 ... 995 996 997 998 999\n", - " * control (control) \n", - "
\n", - "
arviz.InferenceData
\n", - "
\n", - "
    \n", - " \n", - "
  • \n", - " \n", - " \n", - "
    \n", - "
    \n", - "
      \n", - "
      \n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "
      <xarray.Dataset>\n",
      -       "Dimensions:                    (chain: 4, draw: 1000, control: 3,\n",
      -       "                                fourier_mode: 4, channel: 2, date: 179)\n",
      -       "Coordinates:\n",
      -       "  * chain                      (chain) int64 0 1 2 3\n",
      -       "  * draw                       (draw) int64 0 1 2 3 4 5 ... 995 996 997 998 999\n",
      -       "  * control                    (control) <U7 'event_1' 'event_2' 't'\n",
      -       "  * fourier_mode               (fourier_mode) <U11 'sin_order_1' ... 'cos_ord...\n",
      -       "  * channel                    (channel) <U2 'x1' 'x2'\n",
      -       "  * date                       (date) <U10 '2018-04-02' ... '2021-08-30'\n",
      -       "Data variables: (12/13)\n",
      -       "    intercept                  (chain, draw) float64 0.3381 0.3361 ... 0.3479\n",
      -       "    gamma_control              (chain, draw, control) float64 0.2968 ... 0.00...\n",
      -       "    gamma_fourier              (chain, draw, fourier_mode) float64 -0.002612 ...\n",
      -       "    beta_channel               (chain, draw, channel) float64 0.3718 ... 0.2628\n",
      -       "    alpha                      (chain, draw, channel) float64 0.4162 ... 0.2007\n",
      -       "    lam                        (chain, draw, channel) float64 4.26 ... 2.758\n",
      -       "    ...                         ...\n",
      -       "    channel_adstock            (chain, draw, date, channel) float64 0.1868 .....\n",
      -       "    channel_adstock_saturated  (chain, draw, date, channel) float64 0.3781 .....\n",
      -       "    channel_contributions      (chain, draw, date, channel) float64 0.1406 .....\n",
      -       "    control_contributions      (chain, draw, date, control) float64 0.0 ... 0...\n",
      -       "    fourier_contributions      (chain, draw, date, fourier_mode) float64 -0.0...\n",
      -       "    mu                         (chain, draw, date) float64 0.4769 ... 0.5924\n",
      -       "Attributes:\n",
      -       "    created_at:                 2023-08-17T13:57:12.916984\n",
      -       "    arviz_version:              0.16.1\n",
      -       "    inference_library:          pymc\n",
      -       "    inference_library_version:  5.6.1\n",
      -       "    sampling_time:              27.989259004592896\n",
      -       "    tuning_steps:               1000

      \n", - "
    \n", - "
    \n", - "
  • \n", - " \n", - "
  • \n", - " \n", - " \n", - "
    \n", - "
    \n", - "
      \n", - "
      \n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "
      <xarray.Dataset>\n",
      -       "Dimensions:     (chain: 4, draw: 1000, date: 179)\n",
      -       "Coordinates:\n",
      -       "  * chain       (chain) int64 0 1 2 3\n",
      -       "  * draw        (draw) int64 0 1 2 3 4 5 6 7 ... 992 993 994 995 996 997 998 999\n",
      -       "  * date        (date) <U10 '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
      -       "Data variables:\n",
      -       "    likelihood  (chain, draw, date) float64 0.5148 0.4369 ... 0.4858 0.6176\n",
      -       "Attributes:\n",
      -       "    created_at:                 2023-08-17T13:57:14.741457\n",
      -       "    arviz_version:              0.16.1\n",
      -       "    inference_library:          pymc\n",
      -       "    inference_library_version:  5.6.1

      \n", - "
    \n", - "
    \n", - "
  • \n", - " \n", - "
  • \n", - " \n", - " \n", - "
    \n", - "
    \n", - "
      \n", - "
      \n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "
      <xarray.Dataset>\n",
      -       "Dimensions:                (chain: 4, draw: 1000)\n",
      -       "Coordinates:\n",
      -       "  * chain                  (chain) int64 0 1 2 3\n",
      -       "  * draw                   (draw) int64 0 1 2 3 4 5 ... 994 995 996 997 998 999\n",
      -       "Data variables: (12/17)\n",
      -       "    step_size              (chain, draw) float64 0.05862 0.05862 ... 0.07543\n",
      -       "    lp                     (chain, draw) float64 355.1 355.2 ... 355.6 351.6\n",
      -       "    n_steps                (chain, draw) float64 63.0 63.0 63.0 ... 63.0 63.0\n",
      -       "    largest_eigval         (chain, draw) float64 nan nan nan nan ... nan nan nan\n",
      -       "    energy                 (chain, draw) float64 -344.2 -349.9 ... -350.2 -347.5\n",
      -       "    acceptance_rate        (chain, draw) float64 0.9987 0.9934 ... 0.8931 0.9167\n",
      -       "    ...                     ...\n",
      -       "    diverging              (chain, draw) bool False False False ... False False\n",
      -       "    smallest_eigval        (chain, draw) float64 nan nan nan nan ... nan nan nan\n",
      -       "    max_energy_error       (chain, draw) float64 -0.06059 -0.06074 ... 0.284\n",
      -       "    index_in_trajectory    (chain, draw) int64 -12 20 43 -17 ... 42 20 -43 -29\n",
      -       "    energy_error           (chain, draw) float64 -0.01181 -0.02828 ... -0.01478\n",
      -       "    tree_depth             (chain, draw) int64 6 6 6 6 7 6 6 7 ... 6 5 5 6 7 6 6\n",
      -       "Attributes:\n",
      -       "    created_at:                 2023-08-17T13:57:12.932375\n",
      -       "    arviz_version:              0.16.1\n",
      -       "    inference_library:          pymc\n",
      -       "    inference_library_version:  5.6.1\n",
      -       "    sampling_time:              27.989259004592896\n",
      -       "    tuning_steps:               1000

      \n", - "
    \n", - "
    \n", - "
  • \n", - " \n", - "
  • \n", - " \n", - " \n", - "
    \n", - "
    \n", - "
      \n", - "
      \n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "
      <xarray.Dataset>\n",
      -       "Dimensions:                    (chain: 1, draw: 500, channel: 2, control: 3,\n",
      -       "                                date: 179, fourier_mode: 4)\n",
      -       "Coordinates:\n",
      -       "  * chain                      (chain) int64 0\n",
      -       "  * draw                       (draw) int64 0 1 2 3 4 5 ... 495 496 497 498 499\n",
      -       "  * channel                    (channel) <U2 'x1' 'x2'\n",
      -       "  * control                    (control) <U7 'event_1' 'event_2' 't'\n",
      -       "  * date                       (date) <U10 '2018-04-02' ... '2021-08-30'\n",
      -       "  * fourier_mode               (fourier_mode) <U11 'sin_order_1' ... 'cos_ord...\n",
      -       "Data variables: (12/13)\n",
      -       "    beta_channel               (chain, draw, channel) float64 0.6143 ... 0.9639\n",
      -       "    sigma                      (chain, draw) float64 2.034 5.255 ... 0.07125\n",
      -       "    gamma_control              (chain, draw, control) float64 -0.9767 ... 0.0...\n",
      -       "    control_contributions      (chain, draw, date, control) float64 -0.0 ... ...\n",
      -       "    mu                         (chain, draw, date) float64 5.528 7.076 ... 8.396\n",
      -       "    channel_adstock_saturated  (chain, draw, date, channel) float64 0.3626 .....\n",
      -       "    ...                         ...\n",
      -       "    fourier_contributions      (chain, draw, date, fourier_mode) float64 1.64...\n",
      -       "    lam                        (chain, draw, channel) float64 2.667 ... 1.438\n",
      -       "    intercept                  (chain, draw) float64 2.51 -3.898 ... -2.654\n",
      -       "    gamma_fourier              (chain, draw, fourier_mode) float64 1.647 ... ...\n",
      -       "    channel_adstock            (chain, draw, date, channel) float64 0.2848 .....\n",
      -       "    alpha                      (chain, draw, channel) float64 0.1089 ... 0.2786\n",
      -       "Attributes:\n",
      -       "    created_at:                 2023-08-17T13:57:14.422347\n",
      -       "    arviz_version:              0.16.1\n",
      -       "    inference_library:          pymc\n",
      -       "    inference_library_version:  5.6.1

      \n", - "
    \n", - "
    \n", - "
  • \n", - " \n", - "
  • \n", - " \n", - " \n", - "
    \n", - "
    \n", - "
      \n", - "
      \n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "
      <xarray.Dataset>\n",
      -       "Dimensions:     (chain: 1, draw: 500, date: 179)\n",
      -       "Coordinates:\n",
      -       "  * chain       (chain) int64 0\n",
      -       "  * draw        (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499\n",
      -       "  * date        (date) <U10 '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
      -       "Data variables:\n",
      -       "    likelihood  (chain, draw, date) float64 5.221 8.352 5.313 ... 7.793 8.528\n",
      -       "Attributes:\n",
      -       "    created_at:                 2023-08-17T13:57:14.428434\n",
      -       "    arviz_version:              0.16.1\n",
      -       "    inference_library:          pymc\n",
      -       "    inference_library_version:  5.6.1

      \n", - "
    \n", - "
    \n", - "
  • \n", - " \n", - "
  • \n", - " \n", - " \n", - "
    \n", - "
    \n", - "
      \n", - "
      \n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "
      <xarray.Dataset>\n",
      -       "Dimensions:     (date: 179)\n",
      -       "Coordinates:\n",
      -       "  * date        (date) <U10 '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
      -       "Data variables:\n",
      -       "    likelihood  (date) float64 0.4794 0.4527 0.5374 ... 0.4978 0.5388 0.5625\n",
      -       "Attributes:\n",
      -       "    created_at:                 2023-08-17T13:57:12.937052\n",
      -       "    arviz_version:              0.16.1\n",
      -       "    inference_library:          pymc\n",
      -       "    inference_library_version:  5.6.1

      \n", - "
    \n", - "
    \n", - "
  • \n", - " \n", - "
  • \n", - " \n", - " \n", - "
    \n", - "
    \n", - "
      \n", - "
      \n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "
      <xarray.Dataset>\n",
      -       "Dimensions:       (date: 179, channel: 2, control: 3, fourier_mode: 4)\n",
      -       "Coordinates:\n",
      -       "  * date          (date) <U10 '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
      -       "  * channel       (channel) <U2 'x1' 'x2'\n",
      -       "  * control       (control) <U7 'event_1' 'event_2' 't'\n",
      -       "  * fourier_mode  (fourier_mode) <U11 'sin_order_1' ... 'cos_order_2'\n",
      -       "Data variables:\n",
      -       "    channel_data  (date, channel) float64 0.3196 0.0 0.1128 ... 0.0 0.4403 0.0\n",
      -       "    target        (date) float64 0.4794 0.4527 0.5374 ... 0.4978 0.5388 0.5625\n",
      -       "    control_data  (date, control) float64 0.0 0.0 0.0 0.0 ... 0.0 0.0 178.0\n",
      -       "    fourier_data  (date, fourier_mode) float64 0.9999 -0.01183 ... -0.4547\n",
      -       "Attributes:\n",
      -       "    created_at:                 2023-08-17T13:57:12.938885\n",
      -       "    arviz_version:              0.16.1\n",
      -       "    inference_library:          pymc\n",
      -       "    inference_library_version:  5.6.1

      \n", - "
    \n", - "
    \n", - "
  • \n", - " \n", - "
  • \n", - " \n", - " \n", - "
    \n", - "
    \n", - "
      \n", - "
      \n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "
      <xarray.Dataset>\n",
      -       "Dimensions:    (index: 179)\n",
      -       "Coordinates:\n",
      -       "  * index      (index) int64 0 1 2 3 4 5 6 7 ... 171 172 173 174 175 176 177 178\n",
      -       "Data variables:\n",
      -       "    date_week  (index) object '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
      -       "    x1         (index) float64 0.3186 0.1124 0.2924 ... 0.1719 0.2803 0.4389\n",
      -       "    x2         (index) float64 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.8633 0.0 0.0 0.0\n",
      -       "    event_1    (index) float64 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0\n",
      -       "    event_2    (index) float64 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0\n",
      -       "    dayofyear  (index) int64 92 99 106 113 120 127 ... 207 214 221 228 235 242\n",
      -       "    t          (index) int64 0 1 2 3 4 5 6 7 ... 171 172 173 174 175 176 177 178\n",
      -       "    y          (index) float64 3.985e+03 3.763e+03 ... 4.479e+03 4.676e+03

      \n", - "
    \n", - "
    \n", - "
  • \n", - " \n", - "
\n", - " \n", - " " - ], - "text/plain": [ - "Inference data with groups:\n", - "\t> posterior\n", - "\t> posterior_predictive\n", - "\t> sample_stats\n", - "\t> prior\n", - "\t> prior_predictive\n", - "\t> observed_data\n", - "\t> constant_data\n", - "\t> fit_data" - ] - }, - "execution_count": 17, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "mmm.idata" - ] - }, - { - "cell_type": "markdown", - "id": "8b433c7f-0f0d-40b2-bcfb-a19555b528bd", - "metadata": {}, - "source": [ - "## `Save` and `load`" - ] - }, - { - "cell_type": "markdown", - "id": "7b0a35f4", - "metadata": {}, - "source": [ - "All the data passed to the model on initialisation is stored in `idata.attrs`. This will be used later in the `save()` method to convert both this data and all the fit data into the netCDF format." - ] - }, - { - "cell_type": "markdown", - "id": "45948f46", - "metadata": {}, - "source": [ - "Simply specify the path to which you'd like to save your model:" - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "id": "b3abe93a", - "metadata": {}, - "outputs": [], - "source": [ - "mmm.save(\"my_saved_model.nc\")" - ] - }, - { - "cell_type": "markdown", - "id": "8a5eba79", - "metadata": {}, - "source": [ - "And pass it to the `load()` method when it's needed again on the target system:" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "id": "0421bae8", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/michalraczycki/Documents/pymc-marketing/.conda/envs/pymc-marketing/lib/python3.10/site-packages/arviz/data/inference_data.py:153: UserWarning: fit_data group is not defined in the InferenceData scheme\n", - " warnings.warn(\n" - ] - } - ], - "source": [ - "loaded_model = DelayedSaturatedMMM.load(\"my_saved_model.nc\")" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "id": "a8b666d3", - "metadata": {}, - "outputs": [ - { - "data": { - "image/svg+xml": [ - "\n", - 
"\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "clusterdate (179) x channel (2)\n", - "\n", - "date (179) x channel (2)\n", - "\n", - "\n", - "clusterdate (179)\n", - "\n", - "date (179)\n", - "\n", - "\n", - "clusterchannel (2)\n", - "\n", - "channel (2)\n", - "\n", - "\n", - "clusterdate (179) x control (3)\n", - "\n", - "date (179) x control (3)\n", - "\n", - "\n", - "clustercontrol (3)\n", - "\n", - "control (3)\n", - "\n", - "\n", - "clusterdate (179) x fourier_mode (4)\n", - "\n", - "date (179) x fourier_mode (4)\n", - "\n", - "\n", - "clusterfourier_mode (4)\n", - "\n", - "fourier_mode (4)\n", - "\n", - "\n", - "\n", - "channel_adstock_saturated\n", - "\n", - "channel_adstock_saturated\n", - "~\n", - "Deterministic\n", - "\n", - "\n", - "\n", - "channel_contributions\n", - "\n", - "channel_contributions\n", - "~\n", - "Deterministic\n", - "\n", - "\n", - "\n", - "channel_adstock_saturated->channel_contributions\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "mu\n", - "\n", - "mu\n", - "~\n", - "Deterministic\n", - "\n", - "\n", - "\n", - "channel_contributions->mu\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "channel_data\n", - "\n", - "channel_data\n", - "~\n", - "MutableData\n", - "\n", - "\n", - "\n", - "channel_adstock\n", - "\n", - "channel_adstock\n", - "~\n", - "Deterministic\n", - "\n", - "\n", - "\n", - "channel_data->channel_adstock\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "channel_adstock->channel_adstock_saturated\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "likelihood\n", - "\n", - "likelihood\n", - "~\n", - "Normal\n", - "\n", - "\n", - "\n", - "target\n", - "\n", - "target\n", - "~\n", - "MutableData\n", - "\n", - "\n", - "\n", - "likelihood->target\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "mu->likelihood\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "sigma\n", - "\n", - "sigma\n", - "~\n", - "HalfNormal\n", - "\n", - "\n", - "\n", - "sigma->likelihood\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "intercept\n", - "\n", - 
"intercept\n", - "~\n", - "Normal\n", - "\n", - "\n", - "\n", - "intercept->mu\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "beta_channel\n", - "\n", - "beta_channel\n", - "~\n", - "HalfNormal\n", - "\n", - "\n", - "\n", - "beta_channel->channel_contributions\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "alpha\n", - "\n", - "alpha\n", - "~\n", - "Beta\n", - "\n", - "\n", - "\n", - "alpha->channel_adstock\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "lam\n", - "\n", - "lam\n", - "~\n", - "Gamma\n", - "\n", - "\n", - "\n", - "lam->channel_adstock_saturated\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "control_data\n", - "\n", - "control_data\n", - "~\n", - "MutableData\n", - "\n", - "\n", - "\n", - "control_contributions\n", - "\n", - "control_contributions\n", - "~\n", - "Deterministic\n", - "\n", - "\n", - "\n", - "control_data->control_contributions\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "control_contributions->mu\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "gamma_control\n", - "\n", - "gamma_control\n", - "~\n", - "Normal\n", - "\n", - "\n", - "\n", - "gamma_control->control_contributions\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "fourier_data\n", - "\n", - "fourier_data\n", - "~\n", - "MutableData\n", - "\n", - "\n", - "\n", - "fourier_contributions\n", - "\n", - "fourier_contributions\n", - "~\n", - "Deterministic\n", - "\n", - "\n", - "\n", - "fourier_data->fourier_contributions\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "fourier_contributions->mu\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "gamma_fourier\n", - "\n", - "gamma_fourier\n", - "~\n", - "Laplace\n", - "\n", - "\n", - "\n", - "gamma_fourier->fourier_contributions\n", - "\n", - "\n", - "\n", - "\n", - "\n" - ], - "text/plain": [ - "" - ] - }, - "execution_count": 20, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "loaded_model.graphviz()" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "id": "cfb64a2c", - "metadata": {}, - "outputs": 
[ - { - "data": { - "text/html": [ - "\n", - "
arviz.InferenceData
      <xarray.Dataset>\n",
      -       "Dimensions:                    (chain: 4, draw: 1000, control: 3,\n",
      -       "                                fourier_mode: 4, channel: 2, date: 179)\n",
      -       "Coordinates:\n",
      -       "  * chain                      (chain) int64 0 1 2 3\n",
      -       "  * draw                       (draw) int64 0 1 2 3 4 5 ... 995 996 997 998 999\n",
      -       "  * control                    (control) object 'event_1' 'event_2' 't'\n",
      -       "  * fourier_mode               (fourier_mode) object 'sin_order_1' ... 'cos_o...\n",
      -       "  * channel                    (channel) object 'x1' 'x2'\n",
      -       "  * date                       (date) object '2018-04-02' ... '2021-08-30'\n",
      -       "Data variables: (12/13)\n",
      -       "    intercept                  (chain, draw) float64 ...\n",
      -       "    gamma_control              (chain, draw, control) float64 ...\n",
      -       "    gamma_fourier              (chain, draw, fourier_mode) float64 ...\n",
      -       "    beta_channel               (chain, draw, channel) float64 ...\n",
      -       "    alpha                      (chain, draw, channel) float64 ...\n",
      -       "    lam                        (chain, draw, channel) float64 ...\n",
      -       "    ...                         ...\n",
      -       "    channel_adstock            (chain, draw, date, channel) float64 ...\n",
      -       "    channel_adstock_saturated  (chain, draw, date, channel) float64 ...\n",
      -       "    channel_contributions      (chain, draw, date, channel) float64 ...\n",
      -       "    control_contributions      (chain, draw, date, control) float64 ...\n",
      -       "    fourier_contributions      (chain, draw, date, fourier_mode) float64 ...\n",
      -       "    mu                         (chain, draw, date) float64 ...\n",
      -       "Attributes:\n",
      -       "    created_at:                 2023-08-17T13:57:12.916984\n",
      -       "    arviz_version:              0.16.1\n",
      -       "    inference_library:          pymc\n",
      -       "    inference_library_version:  5.6.1\n",
      -       "    sampling_time:              27.989259004592896\n",
      -       "    tuning_steps:               1000

      <xarray.Dataset>\n",
      -       "Dimensions:     (chain: 4, draw: 1000, date: 179)\n",
      -       "Coordinates:\n",
      -       "  * chain       (chain) int64 0 1 2 3\n",
      -       "  * draw        (draw) int64 0 1 2 3 4 5 6 7 ... 992 993 994 995 996 997 998 999\n",
      -       "  * date        (date) object '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
      -       "Data variables:\n",
      -       "    likelihood  (chain, draw, date) float64 ...\n",
      -       "Attributes:\n",
      -       "    created_at:                 2023-08-17T13:57:14.741457\n",
      -       "    arviz_version:              0.16.1\n",
      -       "    inference_library:          pymc\n",
      -       "    inference_library_version:  5.6.1

      <xarray.Dataset>\n",
      -       "Dimensions:                (chain: 4, draw: 1000)\n",
      -       "Coordinates:\n",
      -       "  * chain                  (chain) int64 0 1 2 3\n",
      -       "  * draw                   (draw) int64 0 1 2 3 4 5 ... 994 995 996 997 998 999\n",
      -       "Data variables: (12/17)\n",
      -       "    step_size              (chain, draw) float64 ...\n",
      -       "    lp                     (chain, draw) float64 ...\n",
      -       "    n_steps                (chain, draw) float64 ...\n",
      -       "    largest_eigval         (chain, draw) float64 ...\n",
      -       "    energy                 (chain, draw) float64 ...\n",
      -       "    acceptance_rate        (chain, draw) float64 ...\n",
      -       "    ...                     ...\n",
      -       "    diverging              (chain, draw) bool ...\n",
      -       "    smallest_eigval        (chain, draw) float64 ...\n",
      -       "    max_energy_error       (chain, draw) float64 ...\n",
      -       "    index_in_trajectory    (chain, draw) int64 ...\n",
      -       "    energy_error           (chain, draw) float64 ...\n",
      -       "    tree_depth             (chain, draw) int64 ...\n",
      -       "Attributes:\n",
      -       "    created_at:                 2023-08-17T13:57:12.932375\n",
      -       "    arviz_version:              0.16.1\n",
      -       "    inference_library:          pymc\n",
      -       "    inference_library_version:  5.6.1\n",
      -       "    sampling_time:              27.989259004592896\n",
      -       "    tuning_steps:               1000

      <xarray.Dataset>\n",
      -       "Dimensions:                    (chain: 1, draw: 500, channel: 2, control: 3,\n",
      -       "                                date: 179, fourier_mode: 4)\n",
      -       "Coordinates:\n",
      -       "  * chain                      (chain) int64 0\n",
      -       "  * draw                       (draw) int64 0 1 2 3 4 5 ... 495 496 497 498 499\n",
      -       "  * channel                    (channel) object 'x1' 'x2'\n",
      -       "  * control                    (control) object 'event_1' 'event_2' 't'\n",
      -       "  * date                       (date) object '2018-04-02' ... '2021-08-30'\n",
      -       "  * fourier_mode               (fourier_mode) object 'sin_order_1' ... 'cos_o...\n",
      -       "Data variables: (12/13)\n",
      -       "    beta_channel               (chain, draw, channel) float64 ...\n",
      -       "    sigma                      (chain, draw) float64 ...\n",
      -       "    gamma_control              (chain, draw, control) float64 ...\n",
      -       "    control_contributions      (chain, draw, date, control) float64 ...\n",
      -       "    mu                         (chain, draw, date) float64 ...\n",
      -       "    channel_adstock_saturated  (chain, draw, date, channel) float64 ...\n",
      -       "    ...                         ...\n",
      -       "    fourier_contributions      (chain, draw, date, fourier_mode) float64 ...\n",
      -       "    lam                        (chain, draw, channel) float64 ...\n",
      -       "    intercept                  (chain, draw) float64 ...\n",
      -       "    gamma_fourier              (chain, draw, fourier_mode) float64 ...\n",
      -       "    channel_adstock            (chain, draw, date, channel) float64 ...\n",
      -       "    alpha                      (chain, draw, channel) float64 ...\n",
      -       "Attributes:\n",
      -       "    created_at:                 2023-08-17T13:57:14.422347\n",
      -       "    arviz_version:              0.16.1\n",
      -       "    inference_library:          pymc\n",
      -       "    inference_library_version:  5.6.1

      <xarray.Dataset>\n",
      -       "Dimensions:     (chain: 1, draw: 500, date: 179)\n",
      -       "Coordinates:\n",
      -       "  * chain       (chain) int64 0\n",
      -       "  * draw        (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499\n",
      -       "  * date        (date) object '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
      -       "Data variables:\n",
      -       "    likelihood  (chain, draw, date) float64 ...\n",
      -       "Attributes:\n",
      -       "    created_at:                 2023-08-17T13:57:14.428434\n",
      -       "    arviz_version:              0.16.1\n",
      -       "    inference_library:          pymc\n",
      -       "    inference_library_version:  5.6.1

      <xarray.Dataset>\n",
      -       "Dimensions:     (date: 179)\n",
      -       "Coordinates:\n",
      -       "  * date        (date) object '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
      -       "Data variables:\n",
      -       "    likelihood  (date) float64 ...\n",
      -       "Attributes:\n",
      -       "    created_at:                 2023-08-17T13:57:12.937052\n",
      -       "    arviz_version:              0.16.1\n",
      -       "    inference_library:          pymc\n",
      -       "    inference_library_version:  5.6.1

      <xarray.Dataset>\n",
      -       "Dimensions:       (date: 179, channel: 2, control: 3, fourier_mode: 4)\n",
      -       "Coordinates:\n",
      -       "  * date          (date) object '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
      -       "  * channel       (channel) object 'x1' 'x2'\n",
      -       "  * control       (control) object 'event_1' 'event_2' 't'\n",
      -       "  * fourier_mode  (fourier_mode) object 'sin_order_1' ... 'cos_order_2'\n",
      -       "Data variables:\n",
      -       "    channel_data  (date, channel) float64 ...\n",
      -       "    target        (date) float64 ...\n",
      -       "    control_data  (date, control) float64 ...\n",
      -       "    fourier_data  (date, fourier_mode) float64 ...\n",
      -       "Attributes:\n",
      -       "    created_at:                 2023-08-17T13:57:12.938885\n",
      -       "    arviz_version:              0.16.1\n",
      -       "    inference_library:          pymc\n",
      -       "    inference_library_version:  5.6.1

      <xarray.Dataset>\n",
      -       "Dimensions:    (index: 179)\n",
      -       "Coordinates:\n",
      -       "  * index      (index) int64 0 1 2 3 4 5 6 7 ... 171 172 173 174 175 176 177 178\n",
      -       "Data variables:\n",
      -       "    date_week  (index) object '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
      -       "    x1         (index) float64 0.3186 0.1124 0.2924 ... 0.1719 0.2803 0.4389\n",
      -       "    x2         (index) float64 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.8633 0.0 0.0 0.0\n",
      -       "    event_1    (index) float64 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0\n",
      -       "    event_2    (index) float64 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0\n",
      -       "    dayofyear  (index) int64 92 99 106 113 120 127 ... 207 214 221 228 235 242\n",
      -       "    t          (index) int64 0 1 2 3 4 5 6 7 ... 171 172 173 174 175 176 177 178\n",
      -       "    y          (index) float64 3.985e+03 3.763e+03 ... 4.479e+03 4.676e+03

\n", - " " - ], - "text/plain": [ - "Inference data with groups:\n", - "\t> posterior\n", - "\t> posterior_predictive\n", - "\t> sample_stats\n", - "\t> prior\n", - "\t> prior_predictive\n", - "\t> observed_data\n", - "\t> constant_data\n", - "\t> fit_data" - ] - }, - "execution_count": 21, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "loaded_model.idata" - ] - }, - { - "cell_type": "markdown", - "id": "ab64be46-7fe5-4f39-b72b-36da1419f809", - "metadata": {}, - "source": [ - "A model loaded in this way is ready to be used for sampling and prediction, and has access to all previous samples and data." - ] - }, - { - "cell_type": "code", - "execution_count": 22, - "id": "dd59d056-6ac7-431c-85ee-99e7a8eefd8a", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Sampling: [likelihood]\n" - ] - }, - { - "data": { - "text/html": [ - "\n", - "\n" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "\n", - "
\n", - " \n", - " 100.00% [4000/4000 00:00<00:00]\n", - "
\n", - " " - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "
<xarray.Dataset>\n",
-       "Dimensions:     (chain: 4, draw: 1000, date: 179)\n",
-       "Coordinates:\n",
-       "  * chain       (chain) int64 0 1 2 3\n",
-       "  * draw        (draw) int64 0 1 2 3 4 5 6 7 ... 992 993 994 995 996 997 998 999\n",
-       "  * date        (date) <U10 '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
-       "Data variables:\n",
-       "    likelihood  (chain, draw, date) float64 0.4907 0.4282 ... 0.5548 0.5396\n",
-       "Attributes:\n",
-       "    created_at:                 2023-08-17T13:57:22.302381\n",
-       "    arviz_version:              0.16.1\n",
-       "    inference_library:          pymc\n",
-       "    inference_library_version:  5.6.1
" - ], - "text/plain": [ - "\n", - "Dimensions: (chain: 4, draw: 1000, date: 179)\n", - "Coordinates:\n", - " * chain (chain) int64 0 1 2 3\n", - " * draw (draw) int64 0 1 2 3 4 5 6 7 ... 992 993 994 995 996 997 998 999\n", - " * date (date) " - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "az.plot_ppc(loaded_model.idata);" - ] - }, - { - "cell_type": "markdown", - "id": "e8e807f9", - "metadata": {}, - "source": [ - "## Summary:" - ] - }, - { - "cell_type": "markdown", - "id": "61f232c1", - "metadata": {}, - "source": [ - "In summary, this article introduces the revolutionary ModelBuilder, a new [PyMC-experimental](https://github.com/pymc-devs/pymc-experimental) module that simplifies the deployment of PyMC Bayesian models. It addresses a historic challenge faced by users of PyMC and most PPLs by offering a user-friendly and efficient approach to model deployment. The ModelBuilder provides two straightforward methods, save() and load(), which streamline the model preservation and replication process post fitting. Users are offered flexibility in controlling the prior settings with `model_config` and customizing the sampling process via `sampler_config`.\n", - "\n", - "The use of an example model from the [MMM Example Notebook](https://www.pymc-marketing.io/en/stable/notebooks/index.html) demonstrates the practical implementation of `ModelBuilder`, emphasizing its ability to enhance model sharing among teams without the necessity for extensive domain knowledge about the model. The deployment improvements in [PyMC-Marketing](https://github.com/pymc-labs/pymc-marketing) brought about by ModelBuilder are not only user-friendly but also significantly enhance efficiency, making PyMC models more accessible for a wider audience." 
- ] - }, - { - "cell_type": "markdown", - "id": "b8ad333d", - "metadata": {}, - "source": [ - "Even though this introduction is using `DelayedSaturatedMMM`, functionalities from `ModelBuilder` are available in the CLV models as well." - ] - }, - { - "cell_type": "markdown", - "id": "de822885-2fb4-4ad1-aaf9-7659771f7363", - "metadata": {}, - "source": [ - "## Authors\n", - "- Authored by [Michał Raczycki](https://github.com/michaelraczycki) in August 2023" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "id": "07ec32d2-2f38-47ec-a200-b9d7258c3ac5", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Last updated: Thu Aug 17 2023\n", - "\n", - "Python implementation: CPython\n", - "Python version : 3.10.12\n", - "IPython version : 8.14.0\n", - "\n", - "pytensor: 2.12.3\n", - "aeppl : not installed\n", - "xarray : 2023.7.0\n", - "\n", - "numpy : 1.25.1\n", - "pandas: 2.0.3\n", - "arviz : 0.16.1\n", - "\n", - "Watermark: 2.4.3\n", - "\n" - ] - } - ], - "source": [ - "%load_ext watermark\n", - "%watermark -n -u -v -iv -w -p pytensor,aeppl,xarray" - ] - }, - { - "cell_type": "markdown", - "id": "7e453d63-39f3-456c-a099-2c780d3ba4a9", - "metadata": {}, - "source": [ - ":::\n", - "{include} ../page_footer.md\n", - ":::" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "pymc-marketing", - "language": "python", - "name": "pymc-marketing" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.12" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/examples/howto/ModelBuilder_usage_example.myst.md b/examples/howto/ModelBuilder_usage_example.myst.md deleted file mode 100644 index 3ea5bbe16..000000000 --- a/examples/howto/ModelBuilder_usage_example.myst.md +++ /dev/null @@ -1,279 +0,0 @@ ---- 
-jupytext: - text_representation: - extension: .md - format_name: myst - format_version: 0.13 -kernelspec: - display_name: pymc-marketing - language: python - name: pymc-marketing ---- - -(ModelBuilder usage example)= -# ModelBuilder usage example - -:::{post} Aug 18, 2023 -:tags: ModelBuilder, model deployment, -:category: intermediate, tutorial -:author: Michał Raczycki -::: - -+++ - -# Deploying MMMs and CLVs in Production: Saving and Loading Models - -+++ - -In this article, we'll tackle the historically challenging process of deploying Bayesian models built with PyMC. Introducing a revolutionary deployment module, we bring unprecedented simplicity and efficiency to the deployment of PyMC models. As we prioritize user-friendly solutions, let's delve into how this innovation can significantly elevate your data science projects. - -+++ - - -The recent release of PyMC-Marketing by [PyMC Labs](https://www.pymc-labs.io) has proved to be a big hit [(PyMC-Marketing)](https://www.pymc-labs.io/blog-posts/pymc-marketing-a-bayesian-approach-to-marketing-data-science/). A recurring theme in the feedback was that many of you have been requesting an easy and robust way of deploying models to production. It’s been a long-standing problem with PyMC (and most other Probabilistic Programming Languages). The reason is that there’s no obvious way to do it, and whichever approach you try proves to be tricky. That is why we’re happy to announce the release of `ModelBuilder`, a brand-new PyMC-experimental module that addresses this need and significantly improves the deployment process. - -The `ModelBuilder` module is a new feature of PyMC-based models. It provides two easy-to-use methods, `save()` and `load()`, that can be used after the model has been fit. `save()` allows easy preservation of the model in the netCDF format, and `load()` gives one-line replication of the original model. 
Users can control the prior settings with `model_config`, and customize the sampling process using `sampler_config`. The default values work well, so the first time give it a try without changing them, and provide your own `model_config` and `sampler_config` afterwards if you want to customize the model further for your use case! - -+++ - -For this notebook I'll use the example model from the [MMM Example Notebook](https://www.pymc-marketing.io/en/stable/notebooks/mmm/mmm_example.html), but omit the details of data generation and plotting, since they're out of scope for this introduction. I highly recommend reading that part as well, but for now let's focus on today's topic: groundbreaking deployment improvements in PyMC-Marketing! - -```{code-cell} ipython3 -import arviz as az -import numpy as np -import pandas as pd - -from pymc_marketing.mmm import DelayedSaturatedMMM -``` - -```{code-cell} ipython3 -az.style.use("arviz-darkgrid") -``` - -Let's load the dataset: - -```{code-cell} ipython3 -url = "https://raw.githubusercontent.com/pymc-labs/pymc-marketing/main/datasets/mmm_example.csv" -df = pd.read_csv(url) -df -``` - -For our model, however, we need a much smaller dataset. Many of the original features only contributed to generating others; now that our target variable is computed, we can filter out the columns we don't need: - -```{code-cell} ipython3 -columns_to_keep = [ - "date_week", - "y", - "x1", - "x2", - "event_1", - "event_2", - "dayofyear", -] -seed: int = sum(map(ord, "mmm")) -rng = np.random.default_rng(seed=seed) - -data = df[columns_to_keep].copy() - -data["t"] = range(df.shape[0]) -data.head() -``` - -## _Model Creation_ -After we have our dataset ready, we could proceed straight to our model definition, but first, to show the full potential of one of the new features, `model_config`, we need to use some of our data to define the prior for the `sigma` parameter of each channel. 
`model_config` is a customizable dictionary with keys corresponding to priors within the model, and values containing dictionaries of the parameters necessary to initialize them. Later on we'll learn that through the `save()` method we can preserve the priors contained inside the `model_config`, to allow complete replication of our model. - -+++ - -### model_config - -+++ - -The `default_model_config` attribute of every model inheriting from `ModelBuilder` lets you see which priors are available for customization. To see it, simply initialize a dummy model: - -```{code-cell} ipython3 -dummy_model = DelayedSaturatedMMM(date_column="", channel_columns="", adstock_max_lag=4) -dummy_model.default_model_config -``` - -You can change only the prior parameters that you wish; there's no need to alter all of them unless you'd like to! -In this case we'll simply replace the `sigma` of `beta_channel` with our computed one: - -+++ - -First, let's compute the share of spend per channel: - -```{code-cell} ipython3 -total_spend_per_channel = data[["x1", "x2"]].sum(axis=0) - -spend_share = total_spend_per_channel / total_spend_per_channel.sum() - -spend_share -``` - -Next, we specify the `sigma` parameter per channel: - -```{code-cell} ipython3 -# The scale necessary to make a HalfNormal distribution have unit variance -HALFNORMAL_SCALE = 1 / np.sqrt(1 - 2 / np.pi) - -n_channels = 2 - -prior_sigma = HALFNORMAL_SCALE * n_channels * spend_share.to_numpy() - -prior_sigma.tolist() -``` - -```{code-cell} ipython3 -custom_beta_channel_prior = {"beta_channel": {"sigma": prior_sigma, "dims": ("channel",)}} -my_model_config = dummy_model.default_model_config | custom_beta_channel_prior -``` - -As mentioned in the original notebook: "_For the prior specification there is no right or wrong answer. It all depends on the data, the context and the assumptions you are willing to make. 
It is always recommended to do some prior predictive sampling and sensitivity analysis to check the impact of the priors on the posterior. We skip this here for the sake of simplicity. If you are not sure about specific priors, the `DelayedSaturatedMMM` class has some default priors that you can use as a starting point._" - -+++ - -The second feature that we can use for model definition is `sampler_config`. Similar to `model_config`, it's a dictionary that gets saved, and it contains the arguments you'd usually pass to `fit()` as kwargs. It's not mandatory to create your own `sampler_config`; if not provided, both `model_config` and `sampler_config` will default to the forms specified by PyMC Labs experts, which allows for the usage of all model functionalities. The default `sampler_config` is left empty because the default sampling parameters usually prove sufficient for a start. - -```{code-cell} ipython3 -dummy_model.default_sampler_config -``` - -```{code-cell} ipython3 -my_sampler_config = { - "tune": 1000, - "draws": 1000, - "chains": 4, - "target_accept": 0.95, -} -``` - -Let's finally assemble our model! - -```{code-cell} ipython3 -mmm = DelayedSaturatedMMM( - model_config=my_model_config, - sampler_config=my_sampler_config, - date_column="date_week", - channel_columns=["x1", "x2"], - control_columns=[ - "event_1", - "event_2", - "t", - ], - adstock_max_lag=8, - yearly_seasonality=2, -) -``` - -An important thing to note here is that in the new version of `DelayedSaturatedMMM`, we don't pass our dataset to the class constructor itself. This is because it supports `sklearn` transformers and validations, which require the usual X, y split and typically expect the data to be passed to the `fit()` method. 
- -+++ - -## _Model Fitting_ - -+++ - -Let's split the dataset: - -```{code-cell} ipython3 -X = data.drop("y", axis=1) -y = data["y"] -``` - -All that's left now is to finally fit the model: - -As you can see below, you can still pass the sampler kwargs directly to the `fit()` method. However, only those kwargs passed using `sampler_config` will be saved. Therefore, only these will be available after loading the model. - -```{code-cell} ipython3 -mmm.fit(X=X, y=y, random_seed=rng) -``` - -The `fit()` method automatically builds the model using the priors from `model_config`, and assigns the created model to our instance. You can access it as a normal attribute. - -```{code-cell} ipython3 -type(mmm.model) -``` - -```{code-cell} ipython3 -mmm.graphviz() -``` - -The posterior trace can be accessed via the `fit_result` attribute: - -```{code-cell} ipython3 -mmm.fit_result -``` - -If you wish to inspect the entire inference data, use the `idata` attribute. Within `idata`, you can find the entire dataset passed to the model under `fit_data`. - -```{code-cell} ipython3 -mmm.idata -``` - -## `Save` and `load` - -+++ - -All the data passed to the model on initialisation is stored in `idata.attrs`. This will be used later in the `save()` method to convert both this data and all the fit data into the netCDF format. - -+++ - -Simply specify the path to which you'd like to save your model: - -```{code-cell} ipython3 -mmm.save("my_saved_model.nc") -``` - -And pass it to the `load()` method when it's needed again on the target system: - -```{code-cell} ipython3 -loaded_model = DelayedSaturatedMMM.load("my_saved_model.nc") -``` - -```{code-cell} ipython3 -loaded_model.graphviz() -``` - -```{code-cell} ipython3 -loaded_model.idata -``` - -A model loaded in this way is ready to be used for sampling and prediction, and has access to all previous samples and data. 
- -```{code-cell} ipython3 -with loaded_model.model: - new_predictions = loaded_model.sample_posterior_predictive( - X, extend_idata=True, combined=False, random_seed=rng - ) -new_predictions -``` - -```{code-cell} ipython3 -az.plot_ppc(loaded_model.idata); -``` - -## Summary: - -+++ - -In summary, this article introduces the revolutionary ModelBuilder, a new [PyMC-experimental](https://github.com/pymc-devs/pymc-experimental) module that simplifies the deployment of PyMC Bayesian models. It addresses a historic challenge faced by users of PyMC and most PPLs by offering a user-friendly and efficient approach to model deployment. The ModelBuilder provides two straightforward methods, save() and load(), which streamline the model preservation and replication process post fitting. Users are offered flexibility in controlling the prior settings with `model_config` and customizing the sampling process via `sampler_config`. - -The use of an example model from the [MMM Example Notebook](https://www.pymc-marketing.io/en/stable/notebooks/index.html) demonstrates the practical implementation of `ModelBuilder`, emphasizing its ability to enhance model sharing among teams without the necessity for extensive domain knowledge about the model. The deployment improvements in [PyMC-Marketing](https://github.com/pymc-labs/pymc-marketing) brought about by ModelBuilder are not only user-friendly but also significantly enhance efficiency, making PyMC models more accessible for a wider audience. - -+++ - -Even though this introduction is using `DelayedSaturatedMMM`, functionalities from `ModelBuilder` are available in the CLV models as well. 
- -+++ - -## Authors -- Authored by [Michał Raczycki](https://github.com/michaelraczycki) in August 2023 - -```{code-cell} ipython3 -%load_ext watermark -%watermark -n -u -v -iv -w -p pytensor,aeppl,xarray -``` - -::: -{include} ../page_footer.md -::: diff --git a/examples/howto/model_builder.ipynb b/examples/howto/model_builder.ipynb new file mode 100644 index 000000000..0caab67fc --- /dev/null +++ b/examples/howto/model_builder.ipynb @@ -0,0 +1,692 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "8f00588f-6a28-4d93-b072-f464f78aae40", + "metadata": {}, + "source": [ + "# Using ModelBuilder class for deploying PyMC models \n", + ":::{post} Feb 22, 2023\n", + ":tags: deployment\n", + ":category: Advanced\n", + ":author: Shashank Kirtania, Thomas Wiecki, Michał Raczycki\n", + ":::" + ] + }, + { + "cell_type": "markdown", + "id": "1fdfb702-b893-4e63-8354-935f9742fdde", + "metadata": {}, + "source": [ + "## Motivation" + ] + }, + { + "cell_type": "markdown", + "id": "13a7e0ee-506c-4d5d-adb3-a52cc24cac50", + "metadata": {}, + "source": [ + "Many users face difficulty in deploying their PyMC models to production because deploying/saving/loading a user-created model is not well standardized. One of the reasons behind this is there is no direct way to save or load a model in PyMC like scikit-learn or TensorFlow. The new `ModelBuilder` class is aimed to improve this workflow by providing a scikit-learn inspired API to wrap your PyMC models.\n", + "\n", + "The new `ModelBuilder` class allows users to use methods to `fit()`, `predict()`, `save()`, `load()`. Users can create any model they want, inherit the `ModelBuilder` class, and use predefined methods." + ] + }, + { + "cell_type": "markdown", + "id": "94832375-dc7e-4b4f-ad2e-87363fc363db", + "metadata": {}, + "source": [ + "Let's go through the full workflow, starting with a simple linear regression PyMC model as it's usually written. Of course, this model is just a place-holder for your own model." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "48e35045", + "metadata": {}, + "outputs": [], + "source": [ + "from typing import Dict, List, Optional, Tuple, Union\n", + "\n", + "import arviz as az\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import pandas as pd\n", + "import pymc as pm\n", + "import xarray as xr\n", + "\n", + "from numpy.random import RandomState\n", + "\n", + "%config InlineBackend.figure_format = 'retina'\n", + "RANDOM_SEED = 8927\n", + "\n", + "rng = np.random.default_rng(RANDOM_SEED)\n", + "az.style.use(\"arviz-darkgrid\")" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "d6eccf65", + "metadata": {}, + "outputs": [], + "source": [ + "# Generate data\n", + "x = np.linspace(start=0, stop=1, num=100)\n", + "y = 0.3 * x + 0.5 + rng.normal(0, 1, len(x))" + ] + }, + { + "cell_type": "markdown", + "id": "291452ed", + "metadata": {}, + "source": [ + "## Standard syntax\n", + "Usually a PyMC model will have this form:" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "84d07dc6", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Auto-assigning NUTS sampler...\n", + "Initializing NUTS using jitter+adapt_diag...\n", + "Initializing NUTS using jitter+adapt_diag...\n", + "Multiprocess sampling (4 chains in 4 jobs)\n", + "NUTS: [a, b, eps]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 10 seconds.\n", + "Sampling: [a, b, eps, y_model]\n" + ] + } + ], + "source": [ + "with pm.Model() as model:\n", + " # priors\n", + " a = pm.Normal(\"a\", mu=0, sigma=1)\n", + " b = pm.Normal(\"b\", mu=0, sigma=1)\n", + " eps = pm.HalfNormal(\"eps\", 1.0)\n", + "\n", + " # observed data\n", + " y_model = pm.Normal(\"y_model\", mu=a + b * x, sigma=eps, observed=y)\n", + "\n", + " # Fitting\n", + " idata = pm.sample()\n", + 
" idata.extend(pm.sample_prior_predictive())\n", + "\n", + " # posterior predict\n", + " idata.extend(pm.sample_posterior_predictive(idata))" + ] + }, + { + "cell_type": "markdown", + "id": "eda28484", + "metadata": {}, + "source": [ + "How would we deploy this model? Save the fitted model, load it on an instance, and predict? Not so simple.\n", + "\n", + "`ModelBuilder` is built for this purpose. It is currently part of the `pymc-experimental` package which we can pip install with `pip install pymc-experimental`. As the name implies, this feature is still experimental and subject to change." + ] + }, + { + "cell_type": "markdown", + "id": "213ee05a", + "metadata": {}, + "source": [ + "## Model builder class" + ] + }, + { + "cell_type": "markdown", + "id": "36214695-5fb1-4450-a3ea-789f2e965746", + "metadata": {}, + "source": [ + "Let's import the `ModelBuilder` class." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "a1f5fa98-53d8-459a-827b-fa5179861918", + "metadata": {}, + "outputs": [], + "source": [ + "from pymc_experimental.model_builder import ModelBuilder" + ] + }, + { + "cell_type": "markdown", + "id": "ef0412fe-5aae-4bfa-8a1f-0e1e3762fc5f", + "metadata": {}, + "source": [ + "To define our desired model we inherit from the `ModelBuilder` class. There are a couple of methods we need to define." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "1c4b4575-630c-45a7-9eee-4790adf8924f", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "class LinearModel(ModelBuilder):\n", + " # Give the model a name\n", + " _model_type = \"LinearModel\"\n", + "\n", + " # And a version\n", + " version = \"0.1\"\n", + "\n", + " def build_model(self, X: pd.DataFrame, y: Union[pd.Series, np.ndarray], **kwargs):\n", + " \"\"\"\n", + " build_model creates the PyMC model\n", + "\n", + " Parameters:\n", + " model_config: dictionary\n", + " it is a dictionary with all the parameters that we need in our model example: a_loc, a_scale, b_loc\n", + " data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]]\n", + " Data we want our model fit on.\n", + " \"\"\"\n", + " # Check the type of X and y and adjust access accordingly\n", + " X_values = X[\"input\"].values\n", + " y_values = y.values if isinstance(y, pd.Series) else y\n", + " self._generate_and_preprocess_model_data(X_values, y_values)\n", + "\n", + " with pm.Model(coords=self.model_coords) as self.model:\n", + "\n", + " # Create mutable data containers\n", + " x_data = pm.MutableData(\"x_data\", X_values)\n", + " y_data = pm.MutableData(\"y_data\", y_values)\n", + "\n", + " # prior parameters\n", + " a_mu_prior = self.model_config.get(\"a_mu_prior\", 0.0)\n", + " a_sigma_prior = self.model_config.get(\"a_sigma_prior\", 1.0)\n", + " b_mu_prior = self.model_config.get(\"b_mu_prior\", 0.0)\n", + " b_sigma_prior = self.model_config.get(\"b_sigma_prior\", 1.0)\n", + " eps_prior = self.model_config.get(\"eps_prior\", 1.0)\n", + "\n", + " # priors\n", + " a = pm.Normal(\"a\", mu=a_mu_prior, sigma=a_sigma_prior)\n", + " b = pm.Normal(\"b\", mu=b_mu_prior, sigma=b_sigma_prior)\n", + " eps = pm.HalfNormal(\"eps\", eps_prior)\n", + "\n", + " obs = pm.Normal(\"y\", mu=a + b * x_data, sigma=eps, shape=x_data.shape, observed=y_data)\n", + "\n", + " def _data_setter(\n", + " self, X: Union[pd.DataFrame, 
np.ndarray], y: Union[pd.Series, np.ndarray] = None\n", + " ):\n", + " if isinstance(X, pd.DataFrame):\n", + " x_values = X[\"input\"].values\n", + " else:\n", + " # Assuming \"input\" is the first column\n", + " x_values = X[:, 0]\n", + "\n", + " with self.model:\n", + " pm.set_data({\"x_data\": x_values})\n", + " if y is not None:\n", + " pm.set_data({\"y_data\": y.values if isinstance(y, pd.Series) else y})\n", + "\n", + " @property\n", + " def default_model_config(self) -> Dict:\n", + " \"\"\"\n", + " default_model_config is a property that returns a dictionary with all the prior values we want to build the model with.\n", + " It supports more complex data structures like lists, dictionaries, etc.\n", + " It will be passed to the class instance on initialization, in case the user doesn't provide any model_config of their own.\n", + " \"\"\"\n", + " model_config: Dict = {\n", + " \"a_mu_prior\": 0.0,\n", + " \"a_sigma_prior\": 1.0,\n", + " \"b_mu_prior\": 0.0,\n", + " \"b_sigma_prior\": 1.0,\n", + " \"eps_prior\": 1.0,\n", + " }\n", + " return model_config\n", + "\n", + " @property\n", + " def default_sampler_config(self) -> Dict:\n", + " \"\"\"\n", + " default_sampler_config is a property that returns a dictionary with all most important sampler parameters.\n", + " It will be used in case the user doesn't provide any sampler_config of their own.\n", + " \"\"\"\n", + " sampler_config: Dict = {\n", + " \"draws\": 1_000,\n", + " \"tune\": 1_000,\n", + " \"chains\": 3,\n", + " \"target_accept\": 0.95,\n", + " }\n", + " return sampler_config\n", + "\n", + " @property\n", + " def output_var(self):\n", + " return \"y\"\n", + "\n", + " @property\n", + " def _serializable_model_config(self) -> Dict[str, Union[int, float, Dict]]:\n", + " \"\"\"\n", + " _serializable_model_config is a property that returns a dictionary with all the model parameters that we want to save.\n", + " as some of the data structures are not json serializable, we need to convert them to json 
serializable objects.\n", + " Some models will need them, others can just define them to return the model_config.\n", + " \"\"\"\n", + " return self.model_config\n", + "\n", + " def _save_input_params(self, idata) -> None:\n", + " \"\"\"\n", + " Saves any additional model parameters (other than the dataset) to the idata object.\n", + "\n", + " These parameters are stored within `idata.attrs` using keys that correspond to the parameter names.\n", + " If you don't need to store any extra parameters, you can leave this method unimplemented.\n", + "\n", + " Example:\n", + " For saving customer IDs provided as an 'customer_ids' input to the model:\n", + " self.customer_ids = customer_ids.values #this line is done outside of the function, preferably at the initialization of the model object.\n", + " idata.attrs[\"customer_ids\"] = json.dumps(self.customer_ids.tolist()) # Convert numpy array to a JSON-serializable list.\n", + " \"\"\"\n", + " pass\n", + "\n", + " pass\n", + "\n", + " def _generate_and_preprocess_model_data(\n", + " self, X: Union[pd.DataFrame, pd.Series], y: Union[pd.Series, np.ndarray]\n", + " ) -> None:\n", + " \"\"\"\n", + " Depending on the model, we might need to preprocess the data before fitting the model.\n", + " all required preprocessing and conditional assignments should be defined here.\n", + " \"\"\"\n", + " self.model_coords = None # in our case we're not using coords, but if we were, we would define them here, or later on in the function, if extracting them from the data.\n", + " # as we don't do any data preprocessing, we just assign the data givenin by the user. Note that it's very basic model,\n", + " # and usually we would need to do some preprocessing, or generate the coords from the data.\n", + " self.X = X\n", + " self.y = y" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "aa682cee-58b0-4c51-b5fd-f99d6afaea69", + "metadata": {}, + "source": [ + "Now we can create the `LinearModel` object. 
First step we need to take care of, is data generation:" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "8658306c-f1eb-45a7-9c71-3fcee06183bc", + "metadata": {}, + "outputs": [], + "source": [ + "X = pd.DataFrame(data=np.linspace(start=0, stop=1, num=100), columns=[\"input\"])\n", + "y = 0.3 * x + 0.5\n", + "y = y + np.random.normal(0, 1, len(x))\n", + "\n", + "model = LinearModel()" + ] + }, + { + "cell_type": "markdown", + "id": "294cf57b-b51f-4c77-8e0b-5adaf0a63f2b", + "metadata": {}, + "source": [ + "After making the object of class `LinearModel` we can fit the model using the `.fit()` method." + ] + }, + { + "cell_type": "markdown", + "id": "3d4dead3", + "metadata": {}, + "source": [ + "## Fitting to data" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "e50eb992", + "metadata": {}, + "source": [ + "The `fit()` method takes one argument `data` on which we need to fit the model. The meta-data is saved in the `InferenceData` object where also the trace is stored. These are the fields that are stored:\n", + "\n", + "* `id` : This is a unique id given to a model based on model_config, sample_conifg, version, and model_type. Users can use it to check if the model matches to another model they have defined.\n", + "* `model_type` : Model type tells us what kind of model it is. This in this case it outputs **Linear Model** \n", + "* `version` : In case you want to improve on models, you can keep track of model by their version. 
As the version changes the unique hash in the `id` also changes.\n", + "* `sampler_config` : It stores the values of the sampler configuration set by the user for this particular model.\n", + "* `model_config` : It stores the values of the model configuration set by the user for this particular model.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "a3708a8f-40f6-4a04-bcbf-284397f25450", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Auto-assigning NUTS sampler...\n", + "Initializing NUTS using jitter+adapt_diag...\n", + "Initializing NUTS using jitter+adapt_diag...\n", + "Multiprocess sampling (3 chains in 4 jobs)\n", + "NUTS: [a, b, eps]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Sampling 3 chains for 1_000 tune and 1_000 draw iterations (3_000 + 3_000 draws total) took 7 seconds.\n", + "We recommend running at least 4 chains for robust computation of convergence diagnostics\n", + "Sampling: [a, b, eps, y]\n" + ] + } + ], + "source": [ + "idata = model.fit(X, y)" + ] + }, + { + "cell_type": "markdown", + "id": "ac975628", + "metadata": {}, + "source": [ + "## Saving model to file" + ] + }, + { + "cell_type": "markdown", + "id": "1649556a-13b6-409f-ac09-5c4b7e0277b7", + "metadata": {}, + "source": [ + "After fitting the model, we can save it to a file so that it can be shared and reused.\n", + "Both `save()` and `load()` are one-line calls, with the following syntax." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "a965d738-60c5-4b4b-b872-f2613621851b", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "fname = \"linear_model_v1.nc\"\n", + "model.save(fname)" + ] + }, + { + "cell_type": "markdown", + "id": "490e8802-0395-42c7-a01a-18d9af272320", + "metadata": {}, + "source": [ + "This saves a file at the given path, and the name
\n", + "A NetCDF `.nc` file that stores the inference data of the model." + ] + }, + { + "cell_type": "markdown", + "id": "cf072612", + "metadata": {}, + "source": [ + "## Loading a model" + ] + }, + { + "cell_type": "markdown", + "id": "3e188eb0-c42e-4cd5-b70c-568d9cde71f0", + "metadata": {}, + "source": [ + "Now if we wanted to deploy this model, or just have other people use it to predict data, they need two things:\n", + "1. the `LinearModel` class (probably in a .py file)\n", + "2. the linear_model_v1.nc file\n", + "\n", + "With these, you can easily load a fitted model in a different environment (e.g. production):" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "fe2bccf2-1707-4b21-803b-50716e9298c3", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/Users/michalraczycki/Documents/pymc-marketing/.conda/envs/pymc-marketing/lib/python3.10/site-packages/arviz/data/inference_data.py:153: UserWarning: fit_data group is not defined in the InferenceData scheme\n", + " warnings.warn(\n" + ] + } + ], + "source": [ + "model_2 = LinearModel.load(fname)" + ] + }, + { + "cell_type": "markdown", + "id": "27ac7c8f", + "metadata": {}, + "source": [ + "Note that `load()` is a class-method, we do not need to instantiate the `LinearModel` object." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "b67f25d6", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "__main__.LinearModel" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "type(model_2)" + ] + }, + { + "cell_type": "markdown", + "id": "3dc1840c", + "metadata": {}, + "source": [ + "## Prediction" + ] + }, + { + "cell_type": "markdown", + "id": "1d7254f1-7a59-4623-a128-8a1dd48d0407", + "metadata": {}, + "source": [ + "Next we might want to predict on new data. 
The `predict()` method allows users to do posterior prediction with the fitted model on new data.\n", + "\n", + "Our first task is to create data on which we need to predict." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "3ecc8694-db5e-4d45-b8e0-78608b7eaa83", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "numpy.ndarray" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "x_pred = np.random.uniform(low=1, high=2, size=10)\n", + "prediction_data = pd.DataFrame({\"input\": x_pred})\n", + "type(prediction_data[\"input\"].values)" + ] + }, + { + "cell_type": "markdown", + "id": "1b155d2d-0211-4d85-8b60-a728a62e3743", + "metadata": {}, + "source": [ + "`ModelBuilder` provides two methods for prediction:\n", + "1. point estimates (the mean) with `predict()`\n", + "2. full posterior prediction (samples) with `predict_posterior()`" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "6926eba3-52ed-4c6c-b58f-f2e0bba7b45a", + "metadata": {}, + "outputs": [], + "source": [ + "pred_mean = model_2.predict(prediction_data)\n", + "# samples\n", + "pred_samples = model_2.predict_posterior(prediction_data)" + ] + }, + { + "cell_type": "markdown", + "id": "cfb595b5-e237-4099-b16d-f00c4448307e", + "metadata": {}, + "source": [ + "After using the `predict()`, we can plot our data and see graphically how satisfactory our `LinearModel` is." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "a5375a1c-ed19-4e06-9d9f-74369877cac2", + "metadata": {}, + "outputs": [ + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": { + "image/png": { + "height": 711, + "width": 711 + } + }, + "output_type": "display_data" + } + ], + "source": [ + "fig, ax = plt.subplots(figsize=(7, 7))\n", + "posterior = az.extract(idata, num_samples=20)\n", + "x_plot = xr.DataArray(np.linspace(1, 2, 100))\n", + "y_plot = posterior[\"b\"] * x_plot + posterior[\"a\"]\n", + "Line2 = ax.plot(x_plot, y_plot.transpose(), color=\"C1\")\n", + "Line1 = ax.plot(x_pred, pred_mean, \"x\")\n", + "ax.set(title=\"Posterior predictive regression lines\", xlabel=\"x\", ylabel=\"y\")\n", + "ax.legend(\n", + " handles=[Line1[0], Line2[0]], labels=[\"predicted average\", \"inferred regression line\"], loc=0\n", + ");" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "07fb64ed-f707-4e19-9e27-b0c2700c04f6", + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Last updated: Wed Aug 23 2023\n", + "\n", + "Python implementation: CPython\n", + "Python version : 3.10.12\n", + "IPython version : 8.14.0\n", + "\n", + "pymc_experimental: 0.0.11\n", + "\n", + "xarray : 2023.7.0\n", + "arviz : 0.16.1\n", + "pandas : 2.0.3\n", + "numpy : 1.25.1\n", + "matplotlib: 3.7.2\n", + "pymc : 5.6.1\n", + "\n", + "Watermark: 2.4.3\n", + "\n" + ] + } + ], + "source": [ + "%load_ext watermark\n", + "%watermark -n -u -v -iv -w -p pymc_experimental" + ] + }, + { + "cell_type": "markdown", + "id": "4917782b", + "metadata": {}, + "source": [ + "## Authors\n", + "* Authored by Shashank Kirtania and Thomas Wiecki in 2023.\n", + "* Modified and updated by Michał Raczycki in 08/2023" + ] + }, + { + "cell_type": "markdown", + "id": "dab6cda6", + "metadata": {}, + "source": [ + ":::{include} ../page_footer.md\n", + ":::" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "pymc-marketing", + "language": "python", + "name": "pymc-marketing" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + 
"file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + }, + "toc-autonumbering": false, + "toc-showmarkdowntxt": true + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/howto/model_builder.myst.md b/examples/howto/model_builder.myst.md new file mode 100644 index 000000000..0747511aa --- /dev/null +++ b/examples/howto/model_builder.myst.md @@ -0,0 +1,350 @@ +--- +jupytext: + text_representation: + extension: .md + format_name: myst + format_version: 0.13 +kernelspec: + display_name: pymc-marketing + language: python + name: pymc-marketing +--- + +# Using ModelBuilder class for deploying PyMC models +:::{post} Feb 22, 2023 +:tags: deployment +:category: Advanced +:author: Shashank Kirtania, Thomas Wiecki, Michał Raczycki +::: + ++++ + +## Motivation + ++++ + +Many users face difficulty in deploying their PyMC models to production because deploying/saving/loading a user-created model is not well standardized. One of the reasons behind this is there is no direct way to save or load a model in PyMC like scikit-learn or TensorFlow. The new `ModelBuilder` class is aimed to improve this workflow by providing a scikit-learn inspired API to wrap your PyMC models. + +The new `ModelBuilder` class allows users to use methods to `fit()`, `predict()`, `save()`, `load()`. Users can create any model they want, inherit the `ModelBuilder` class, and use predefined methods. + ++++ + +Let's go through the full workflow, starting with a simple linear regression PyMC model as it's usually written. Of course, this model is just a place-holder for your own model. 
+ +```{code-cell} ipython3 +from typing import Dict, List, Optional, Tuple, Union + +import arviz as az +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import pymc as pm +import xarray as xr + +from numpy.random import RandomState + +%config InlineBackend.figure_format = 'retina' +RANDOM_SEED = 8927 + +rng = np.random.default_rng(RANDOM_SEED) +az.style.use("arviz-darkgrid") +``` + +```{code-cell} ipython3 +# Generate data +x = np.linspace(start=0, stop=1, num=100) +y = 0.3 * x + 0.5 + rng.normal(0, 1, len(x)) +``` + +## Standard syntax +Usually a PyMC model will have this form: + +```{code-cell} ipython3 +with pm.Model() as model: + # priors + a = pm.Normal("a", mu=0, sigma=1) + b = pm.Normal("b", mu=0, sigma=1) + eps = pm.HalfNormal("eps", 1.0) + + # observed data + y_model = pm.Normal("y_model", mu=a + b * x, sigma=eps, observed=y) + + # Fitting + idata = pm.sample() + idata.extend(pm.sample_prior_predictive()) + + # posterior predict + idata.extend(pm.sample_posterior_predictive(idata)) +``` + +How would we deploy this model? Save the fitted model, load it on an instance, and predict? Not so simple. + +`ModelBuilder` is built for this purpose. It is currently part of the `pymc-experimental` package which we can pip install with `pip install pymc-experimental`. As the name implies, this feature is still experimental and subject to change. + ++++ + +## Model builder class + ++++ + +Let's import the `ModelBuilder` class. + +```{code-cell} ipython3 +from pymc_experimental.model_builder import ModelBuilder +``` + +To define our desired model we inherit from the `ModelBuilder` class. There are a couple of methods we need to define. 
+ +```{code-cell} ipython3 +:tags: [] + +class LinearModel(ModelBuilder): + # Give the model a name + _model_type = "LinearModel" + + # And a version + version = "0.1" + + def build_model(self, X: pd.DataFrame, y: Union[pd.Series, np.ndarray], **kwargs): + """ + build_model creates the PyMC model + + Parameters: + model_config: dictionary + it is a dictionary with all the parameters that we need in our model example: a_loc, a_scale, b_loc + data: Dict[str, Union[np.ndarray, pd.DataFrame, pd.Series]] + Data we want our model fit on. + """ + # Check the type of X and y and adjust access accordingly + X_values = X["input"].values + y_values = y.values if isinstance(y, pd.Series) else y + self._generate_and_preprocess_model_data(X_values, y_values) + + with pm.Model(coords=self.model_coords) as self.model: + + # Create mutable data containers + x_data = pm.MutableData("x_data", X_values) + y_data = pm.MutableData("y_data", y_values) + + # prior parameters + a_mu_prior = self.model_config.get("a_mu_prior", 0.0) + a_sigma_prior = self.model_config.get("a_sigma_prior", 1.0) + b_mu_prior = self.model_config.get("b_mu_prior", 0.0) + b_sigma_prior = self.model_config.get("b_sigma_prior", 1.0) + eps_prior = self.model_config.get("eps_prior", 1.0) + + # priors + a = pm.Normal("a", mu=a_mu_prior, sigma=a_sigma_prior) + b = pm.Normal("b", mu=b_mu_prior, sigma=b_sigma_prior) + eps = pm.HalfNormal("eps", eps_prior) + + obs = pm.Normal("y", mu=a + b * x_data, sigma=eps, shape=x_data.shape, observed=y_data) + + def _data_setter( + self, X: Union[pd.DataFrame, np.ndarray], y: Union[pd.Series, np.ndarray] = None + ): + if isinstance(X, pd.DataFrame): + x_values = X["input"].values + else: + # Assuming "input" is the first column + x_values = X[:, 0] + + with self.model: + pm.set_data({"x_data": x_values}) + if y is not None: + pm.set_data({"y_data": y.values if isinstance(y, pd.Series) else y}) + + @property + def default_model_config(self) -> Dict: + """ + default_model_config is a 
property that returns a dictionary with all the prior values we want to build the model with. + It supports more complex data structures like lists, dictionaries, etc. + It will be passed to the class instance on initialization, in case the user doesn't provide any model_config of their own. + """ + model_config: Dict = { + "a_mu_prior": 0.0, + "a_sigma_prior": 1.0, + "b_mu_prior": 0.0, + "b_sigma_prior": 1.0, + "eps_prior": 1.0, + } + return model_config + + @property + def default_sampler_config(self) -> Dict: + """ + default_sampler_config is a property that returns a dictionary with all most important sampler parameters. + It will be used in case the user doesn't provide any sampler_config of their own. + """ + sampler_config: Dict = { + "draws": 1_000, + "tune": 1_000, + "chains": 3, + "target_accept": 0.95, + } + return sampler_config + + @property + def output_var(self): + return "y" + + @property + def _serializable_model_config(self) -> Dict[str, Union[int, float, Dict]]: + """ + _serializable_model_config is a property that returns a dictionary with all the model parameters that we want to save. + as some of the data structures are not json serializable, we need to convert them to json serializable objects. + Some models will need them, others can just define them to return the model_config. + """ + return self.model_config + + def _save_input_params(self, idata) -> None: + """ + Saves any additional model parameters (other than the dataset) to the idata object. + + These parameters are stored within `idata.attrs` using keys that correspond to the parameter names. + If you don't need to store any extra parameters, you can leave this method unimplemented. + + Example: + For saving customer IDs provided as an 'customer_ids' input to the model: + self.customer_ids = customer_ids.values #this line is done outside of the function, preferably at the initialization of the model object. 
+ idata.attrs["customer_ids"] = json.dumps(self.customer_ids.tolist()) # Convert numpy array to a JSON-serializable list. + """ + pass + + pass + + def _generate_and_preprocess_model_data( + self, X: Union[pd.DataFrame, pd.Series], y: Union[pd.Series, np.ndarray] + ) -> None: + """ + Depending on the model, we might need to preprocess the data before fitting the model. + all required preprocessing and conditional assignments should be defined here. + """ + self.model_coords = None # in our case we're not using coords, but if we were, we would define them here, or later on in the function, if extracting them from the data. + # as we don't do any data preprocessing, we just assign the data givenin by the user. Note that it's very basic model, + # and usually we would need to do some preprocessing, or generate the coords from the data. + self.X = X + self.y = y +``` + +Now we can create the `LinearModel` object. First step we need to take care of, is data generation: + +```{code-cell} ipython3 +X = pd.DataFrame(data=np.linspace(start=0, stop=1, num=100), columns=["input"]) +y = 0.3 * x + 0.5 +y = y + np.random.normal(0, 1, len(x)) + +model = LinearModel() +``` + +After making the object of class `LinearModel` we can fit the model using the `.fit()` method. + ++++ + +## Fitting to data + ++++ + +The `fit()` method takes one argument `data` on which we need to fit the model. The meta-data is saved in the `InferenceData` object where also the trace is stored. These are the fields that are stored: + +* `id` : This is a unique id given to a model based on model_config, sample_conifg, version, and model_type. Users can use it to check if the model matches to another model they have defined. +* `model_type` : Model type tells us what kind of model it is. This in this case it outputs **Linear Model** +* `version` : In case you want to improve on models, you can keep track of model by their version. As the version changes the unique hash in the `id` also changes. 
+* `sampler_config` : It stores the values of the sampler configuration set by the user for this particular model. +* `model_config` : It stores the values of the model configuration set by the user for this particular model. + +```{code-cell} ipython3 +idata = model.fit(X, y) +``` + +## Saving model to file + ++++ + +After fitting the model, we can save it to a file so that it can be shared and reused. +Both `save()` and `load()` are one-line calls, with the following syntax. + +```{code-cell} ipython3 +:tags: [] + +fname = "linear_model_v1.nc" +model.save(fname) +``` + +This saves a file at the given path, and the name
+A NetCDF `.nc` file that stores the inference data of the model. + ++++ + +## Loading a model + ++++ + +Now if we wanted to deploy this model, or just have other people use it to predict data, they need two things: +1. the `LinearModel` class (probably in a .py file) +2. the `linear_model_v1.nc` file + +With these, you can easily load a fitted model in a different environment (e.g. production): + +```{code-cell} ipython3 +model_2 = LinearModel.load(fname) +``` + +Note that `load()` is a class method, so we do not need to instantiate a `LinearModel` object first. + +```{code-cell} ipython3 +type(model_2) +``` + +## Prediction + ++++ + +Next we might want to predict on new data. The `predict()` method allows users to do posterior prediction with the fitted model on new data. + +Our first task is to create the data we want predictions for. + +```{code-cell} ipython3 +x_pred = np.random.uniform(low=1, high=2, size=10) +prediction_data = pd.DataFrame({"input": x_pred}) +type(prediction_data["input"].values) +``` + +`ModelBuilder` provides two methods for prediction: +1. point estimates (the mean) with `predict()` +2. full posterior prediction (samples) with `predict_posterior()` + +```{code-cell} ipython3 +pred_mean = model_2.predict(prediction_data) +# samples +pred_samples = model_2.predict_posterior(prediction_data) +``` + +After calling `predict()`, we can plot the predictions and see graphically how well our `LinearModel` performs.
+ +```{code-cell} ipython3 +fig, ax = plt.subplots(figsize=(7, 7)) +posterior = az.extract(idata, num_samples=20) +x_plot = xr.DataArray(np.linspace(1, 2, 100)) +y_plot = posterior["b"] * x_plot + posterior["a"] +Line2 = ax.plot(x_plot, y_plot.transpose(), color="C1") +Line1 = ax.plot(x_pred, pred_mean, "x") +ax.set(title="Posterior predictive regression lines", xlabel="x", ylabel="y") +ax.legend( + handles=[Line1[0], Line2[0]], labels=["predicted average", "inferred regression line"], loc=0 +); +``` + +```{code-cell} ipython3 +%load_ext watermark +%watermark -n -u -v -iv -w -p pymc_experimental +``` + +## Authors +* Authored by Shashank Kirtania and Thomas Wiecki in 2023. +* Modified and updated by Michał Raczycki in 08/2023 + ++++ + +:::{include} ../page_footer.md +::: From 92b2a5f68b3fc3f6e48fca6ff7cb1b09a7ca4433 Mon Sep 17 00:00:00 2001 From: Michal Raczycki Date: Fri, 1 Sep 2023 09:48:40 +0200 Subject: [PATCH 4/5] removing the mb folder, mb intro in examples/howto --- ...delBuilder in PyMC-Marketing context.ipynb | 17374 ---------------- 1 file changed, 17374 deletions(-) delete mode 100644 examples/model_builder/ModelBuilder in PyMC-Marketing context.ipynb diff --git a/examples/model_builder/ModelBuilder in PyMC-Marketing context.ipynb b/examples/model_builder/ModelBuilder in PyMC-Marketing context.ipynb deleted file mode 100644 index 11b7e71f8..000000000 --- a/examples/model_builder/ModelBuilder in PyMC-Marketing context.ipynb +++ /dev/null @@ -1,17374 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "1d74584c", - "metadata": {}, - "source": [ - "# Deploying MMMs and CLVs in Production: Saving and Loading Models" - ] - }, - { - "cell_type": "markdown", - "id": "3222ef04", - "metadata": {}, - "source": [ - "In this article, we'll tackle the historically challenging process of deploying Bayesian models built with PyMC. Introducing a revolutionary deployment module, we bring unprecedented simplicity and efficiency to the deployment of PyMC models. 
As we prioritize user-friendly solutions, let's delve into how this innovation can significantly elevate your data science projects." - ] - }, - { - "cell_type": "markdown", - "id": "ddb28436", - "metadata": {}, - "source": [ - "\n", - "Recent release of PyMC-Marketing by [Labs](https://www.pymc-labs.io) proves to be a big hit [(PyMC-Marketing)](https://www.pymc-labs.io/blog-posts/pymc-marketing-a-bayesian-approach-to-marketing-data-science/). In the feedback one could see an ongoing theme, many of you have been requesting easy and robust way of deploying models to production. It’s been a long-standing problem with PyMC ( and most other PPLs). The reason for that is that there’s no obvious way, and doesn’t matter which approach you try it proves to be tricky. That is why we’re happy to announce the release of `ModelBuilder`, brand new PyMC-experimental module that addresses this need, and improves on the deployment process significantly.\n", - "\n", - "The ModelBuilder module is a new feature of PyMC based models. It provides 2 easy-to-use methods: save() and load() that can be used after the model has been fit.save() allow easy preservation of the model to .netcdf format, and load() gives one-line replication of the original model. Users can control the prior settings with model_config, and customize the sampling process using sampler_config. 
Default values of those are working just fine, so first time give it a try without changing, and provide your own model_config and model_sampler if afterwards you want to try to customize it more for your use case!\n" - ] - }, - { - "cell_type": "markdown", - "id": "a808e36a", - "metadata": {}, - "source": [ - "For this notebook I'll use the example model used in [MMM Example Notebook](https://www.pymc-marketing.io/en/stable/notebooks/mmm/mmm_example.html), but ommit the details of data generation and plotting functionalities, since they're out of scope for this introduction, I highly recommend to see that part as well, but for now let's focus on today's topic: Groundbreaking deployment improvements in PyMC-Marketing!" - ] - }, - { - "cell_type": "code", - "execution_count": 31, - "id": "1050a937", - "metadata": {}, - "outputs": [], - "source": [ - "import arviz as az\n", - "import numpy as np\n", - "import pandas as pd\n", - "\n", - "from pymc_marketing.mmm import DelayedSaturatedMMM" - ] - }, - { - "cell_type": "markdown", - "id": "f37d808e", - "metadata": {}, - "source": [ - "Let's load the dataset:" - ] - }, - { - "cell_type": "code", - "execution_count": 32, - "id": "b7b1193f", - "metadata": { - "scrolled": true - }, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
date_weekyx1x2event_1event_2dayofyeartsin_order_1cos_order_1sin_order_2cos_order_2
02018-04-023984.6622370.3185800.0000000.00.09200.999930-0.011826-0.023651-0.999720
12018-04-093762.8717940.1123880.0000000.00.09910.991269-0.131859-0.261414-0.965227
22018-04-164466.9673880.2924000.0000000.00.010620.968251-0.249981-0.484089-0.875019
32018-04-233864.2193730.0713990.0000000.00.011330.931210-0.364483-0.678820-0.734304
42018-04-304441.6252780.3867450.0000000.00.012040.880683-0.473706-0.834370-0.551205
.......................................
1742021-08-023553.5461480.0330240.0000000.00.0214174-0.513901-0.8578490.8816990.471812
1752021-08-095565.5096820.1656150.8633490.00.0221175-0.613230-0.7899050.9687860.247898
1762021-08-164137.6514850.1718820.0000000.00.0228176-0.703677-0.7105200.9999530.009676
1772021-08-234479.0413510.2802570.0000000.00.0235177-0.783934-0.6208440.973402-0.229104
1782021-08-304675.9734390.4388570.0000000.00.0242178-0.852837-0.5221780.890665-0.454661
\n", - "

179 rows × 12 columns

\n", - "
" - ], - "text/plain": [ - " date_week y x1 x2 event_1 event_2 dayofyear \\\n", - "0 2018-04-02 3984.662237 0.318580 0.000000 0.0 0.0 92 \n", - "1 2018-04-09 3762.871794 0.112388 0.000000 0.0 0.0 99 \n", - "2 2018-04-16 4466.967388 0.292400 0.000000 0.0 0.0 106 \n", - "3 2018-04-23 3864.219373 0.071399 0.000000 0.0 0.0 113 \n", - "4 2018-04-30 4441.625278 0.386745 0.000000 0.0 0.0 120 \n", - ".. ... ... ... ... ... ... ... \n", - "174 2021-08-02 3553.546148 0.033024 0.000000 0.0 0.0 214 \n", - "175 2021-08-09 5565.509682 0.165615 0.863349 0.0 0.0 221 \n", - "176 2021-08-16 4137.651485 0.171882 0.000000 0.0 0.0 228 \n", - "177 2021-08-23 4479.041351 0.280257 0.000000 0.0 0.0 235 \n", - "178 2021-08-30 4675.973439 0.438857 0.000000 0.0 0.0 242 \n", - "\n", - " t sin_order_1 cos_order_1 sin_order_2 cos_order_2 \n", - "0 0 0.999930 -0.011826 -0.023651 -0.999720 \n", - "1 1 0.991269 -0.131859 -0.261414 -0.965227 \n", - "2 2 0.968251 -0.249981 -0.484089 -0.875019 \n", - "3 3 0.931210 -0.364483 -0.678820 -0.734304 \n", - "4 4 0.880683 -0.473706 -0.834370 -0.551205 \n", - ".. ... ... ... ... ... 
\n", - "174 174 -0.513901 -0.857849 0.881699 0.471812 \n", - "175 175 -0.613230 -0.789905 0.968786 0.247898 \n", - "176 176 -0.703677 -0.710520 0.999953 0.009676 \n", - "177 177 -0.783934 -0.620844 0.973402 -0.229104 \n", - "178 178 -0.852837 -0.522178 0.890665 -0.454661 \n", - "\n", - "[179 rows x 12 columns]" - ] - }, - "execution_count": 32, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "url = \"https://raw.githubusercontent.com/pymc-labs/pymc-marketing/main/datasets/mmm_example.csv\"\n", - "df = pd.read_csv(url)\n", - "df" - ] - }, - { - "cell_type": "markdown", - "id": "87deb70d", - "metadata": {}, - "source": [ - "But for our model we need much smaller dataset, many of the previous features were contributing to generation of others, now as our target variable is computed we can filter out not needed columns:" - ] - }, - { - "cell_type": "code", - "execution_count": 33, - "id": "52b6d127", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - " \n", - "
date_weekyx1x2event_1event_2dayofyeart
02018-04-023984.6622370.3185800.00.00.0920
12018-04-093762.8717940.1123880.00.00.0991
22018-04-164466.9673880.2924000.00.00.01062
32018-04-233864.2193730.0713990.00.00.01133
42018-04-304441.6252780.3867450.00.00.01204
\n", - "
" - ], - "text/plain": [ - " date_week y x1 x2 event_1 event_2 dayofyear t\n", - "0 2018-04-02 3984.662237 0.318580 0.0 0.0 0.0 92 0\n", - "1 2018-04-09 3762.871794 0.112388 0.0 0.0 0.0 99 1\n", - "2 2018-04-16 4466.967388 0.292400 0.0 0.0 0.0 106 2\n", - "3 2018-04-23 3864.219373 0.071399 0.0 0.0 0.0 113 3\n", - "4 2018-04-30 4441.625278 0.386745 0.0 0.0 0.0 120 4" - ] - }, - "execution_count": 33, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "columns_to_keep = [\n", - " \"date_week\",\n", - " \"y\",\n", - " \"x1\",\n", - " \"x2\",\n", - " \"event_1\",\n", - " \"event_2\",\n", - " \"dayofyear\",\n", - "]\n", - "seed: int = sum(map(ord, \"mmm\"))\n", - "rng =np.random.default_rng(seed=seed)\n", - "\n", - "data = df[columns_to_keep].copy()\n", - "\n", - "data[\"t\"] = range(df.shape[0])\n", - "data.head()" - ] - }, - { - "cell_type": "markdown", - "id": "9518a885", - "metadata": {}, - "source": [ - "## _Model Creation_\n", - "After we have our dataset ready, we could proceed straight to our model definition, but first to show the full potential of one of the new features: `model_config` we need to use some of our data to define our prior for sigma parameter for each of the channels. `model_config` is a customizable dictionary with keys corresponding to priors within the model, and values containing a dictionaries with parameters necessary to initialize them. Later on we'll learn that through the `save()` method we can preserve our priors contained inside the `model_config`, to allow complete replication of our model." - ] - }, - { - "cell_type": "markdown", - "id": "4b52b2c1", - "metadata": {}, - "source": [ - "### model_config" - ] - }, - { - "cell_type": "markdown", - "id": "41021a72", - "metadata": {}, - "source": [ - "`default_model_config` attribute of every model inheriting from `ModelBuilder` will allow you to see which priors are available for customization. 
To see it simply initialize a dummy model:" - ] - }, - { - "cell_type": "code", - "execution_count": 34, - "id": "284bd558", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'intercept': {'mu': 0, 'sigma': 2},\n", - " 'beta_channel': {'sigma': 2, 'dims': ('channel',)},\n", - " 'alpha': {'alpha': 1, 'beta': 3, 'dims': ('channel',)},\n", - " 'lam': {'alpha': 3, 'beta': 1, 'dims': ('channel',)},\n", - " 'sigma': {'sigma': 2},\n", - " 'gamma_control': {'mu': 0, 'sigma': 2, 'dims': ('control',)},\n", - " 'mu': {'dims': ('date',)},\n", - " 'likelihood': {'dims': ('date',)},\n", - " 'gamma_fourier': {'mu': 0, 'b': 1, 'dims': 'fourier_mode'}}" - ] - }, - "execution_count": 34, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "dummy_model = DelayedSaturatedMMM(date_column = '', channel_columns= '', adstock_max_lag = 4)\n", - "dummy_model.default_model_config" - ] - }, - { - "cell_type": "markdown", - "id": "f0fd248f", - "metadata": {}, - "source": [ - "You can change only the prior parameters that you wish, no need to alter all of them, unless you'd like to!\n", - "In this case we'll just simply replace our sigma for beta_channel with our computed one:" - ] - }, - { - "cell_type": "markdown", - "id": "19f075f0-4d3d-4509-a9c6-f15efdb9293d", - "metadata": {}, - "source": [ - "First, let's compute the share of spend per channel:" - ] - }, - { - "cell_type": "code", - "execution_count": 35, - "id": "4785596a-e333-4cd0-af15-1332e97b66d5", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "x1 0.65632\n", - "x2 0.34368\n", - "dtype: float64" - ] - }, - "execution_count": 35, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "total_spend_per_channel = data[[\"x1\", \"x2\"]].sum(axis=0)\n", - "\n", - "spend_share = total_spend_per_channel / total_spend_per_channel.sum()\n", - "\n", - "spend_share" - ] - }, - { - "cell_type": "markdown", - "id": "40d17642-1e21-4adc-97f4-633eede87915", - 
"metadata": {}, - "source": [ - "Next, we specify the `sigma`parameter per channel:" - ] - }, - { - "cell_type": "code", - "execution_count": 36, - "id": "672b36bd-3d08-46df-85b8-d67d3ade75d7", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[2.1775326025486734, 1.140260877391939]" - ] - }, - "execution_count": 36, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "# The scale necessary to make a HalfNormal distribution have unit variance\n", - "HALFNORMAL_SCALE = 1 / np.sqrt(1 - 2 / np.pi)\n", - "\n", - "n_channels = 2\n", - "\n", - "prior_sigma = HALFNORMAL_SCALE * n_channels * spend_share.to_numpy()\n", - "\n", - "prior_sigma.tolist()" - ] - }, - { - "cell_type": "code", - "execution_count": 37, - "id": "bac9f587", - "metadata": {}, - "outputs": [], - "source": [ - "custom_beta_channel_prior = {'beta_channel': {'sigma': prior_sigma, 'dims': ('channel',)}}\n", - "my_model_config = dummy_model.default_model_config| custom_beta_channel_prior" - ] - }, - { - "cell_type": "markdown", - "id": "1aa435bf", - "metadata": {}, - "source": [ - "As mentioned in the original notebook: \"_For the prior specification there is no right or wrong answer. It all depends on the data, the context and the assumptions you are willing to make. It is always recommended to do some prior predictive sampling and sensitivity analysis to check the impact of the priors on the posterior. We skip this here for the sake of simplicity. If you are not sure about specific priors, the `DelayedSaturatedMMM` class has some default priors that you can use as a starting point._\"" - ] - }, - { - "cell_type": "markdown", - "id": "f195a79e", - "metadata": {}, - "source": [ - "The second feature that we can use for model definition is `sampler_config`. Similar to `model_config`, it's a dictionary that gets saved and contains things you'd usually pass to the `fit()` kwargs. 
It's not mandatory to create your own `sampler_config`; if not provided, both `model_config` and `sampler_config` will default to the forms specified by PyMC Labs experts, which allows for the usage of all model functionalities. The default `sampler_config` is left empty because the default sampling parameters usually prove sufficient for a start." - ] - }, - { - "cell_type": "code", - "execution_count": 38, - "id": "0ab8140c", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{}" - ] - }, - "execution_count": 38, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "dummy_model.default_sampler_config" - ] - }, - { - "cell_type": "code", - "execution_count": 39, - "id": "bf5a50f4", - "metadata": {}, - "outputs": [], - "source": [ - "my_sampler_config = {\n", - " 'tune':1000,\n", - " 'draws':1000,\n", - " 'chains':4,\n", - " 'target_accept':0.95,\n", - "}" - ] - }, - { - "cell_type": "markdown", - "id": "f3bfe090", - "metadata": {}, - "source": [ - "Let's finally assemble our model!" - ] - }, - { - "cell_type": "code", - "execution_count": 40, - "id": "c7bd6909", - "metadata": {}, - "outputs": [], - "source": [ - "mmm = DelayedSaturatedMMM(\n", - " model_config = my_model_config,\n", - " sampler_config = my_sampler_config,\n", - " date_column=\"date_week\",\n", - " channel_columns=[\"x1\", \"x2\"],\n", - " control_columns=[\n", - " \"event_1\",\n", - " \"event_2\",\n", - " \"t\",\n", - " ],\n", - " adstock_max_lag=8,\n", - " yearly_seasonality=2,\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "54095b1a", - "metadata": {}, - "source": [ - "An important thing to note here is that in the new version of `DelayedSaturatedMMM`, we don't pass our dataset to the class constructor itself. This is due to a reason I've mentioned before - it supports `sklearn` transformers and validations that require a usual X, y split and typically expect the data to be passed to the `fit()` method." 
- ] - }, - { - "cell_type": "markdown", - "id": "dec9b1b0", - "metadata": {}, - "source": [ - "## _Model Fitting_" - ] - }, - { - "cell_type": "markdown", - "id": "d5e64562-ba78-4497-a0f8-123b4bc88b79", - "metadata": {}, - "source": [ - "Let's split the dataset:" - ] - }, - { - "cell_type": "code", - "execution_count": 41, - "id": "ff23006b-a55b-4a22-9f34-4eeaddf47486", - "metadata": {}, - "outputs": [], - "source": [ - "X = data.drop('y',axis=1)\n", - "y = data['y']" - ] - }, - { - "cell_type": "markdown", - "id": "403e3ed2", - "metadata": {}, - "source": [ - "All that's left now is to finally fit the model:\n", - "\n", - "As you can see below, you can still pass the sampler kwargs directly to `fit()` method. However, only those kwargs passed using `sampler_config` will be saved. Therefore, only these will be available after loading the model." - ] - }, - { - "cell_type": "code", - "execution_count": 42, - "id": "0f6ab0a8", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Auto-assigning NUTS sampler...\n", - "Initializing NUTS using jitter+adapt_diag...\n", - "Multiprocess sampling (4 chains in 4 jobs)\n", - "NUTS: [intercept, beta_channel, alpha, lam, sigma, gamma_control, gamma_fourier]\n" - ] - }, - { - "data": { - "text/html": [ - "\n", - "\n" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "\n", - "
\n", - " \n", - " 100.00% [8000/8000 00:29<00:00 Sampling 4 chains, 0 divergences]\n", - "
\n", - " " - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Sampling 4 chains for 1_000 tune and 1_000 draw iterations (4_000 + 4_000 draws total) took 30 seconds.\n", - "Sampling: [alpha, beta_channel, gamma_control, gamma_fourier, intercept, lam, likelihood, sigma]\n", - "Sampling: [likelihood]\n" - ] - }, - { - "data": { - "text/html": [ - "\n", - "\n" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "\n", - "
\n", - " \n", - " 100.00% [4000/4000 00:00<00:00]\n", - "
\n", - " " - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "\n", - "
\n", - "
\n", - "
arviz.InferenceData
\n", - "
\n", - "
    \n", - " \n", - "
  • \n", - " \n", - " \n", - "
    \n", - "
    \n", - "
      \n", - "
      \n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "
      <xarray.Dataset>\n",
      -       "Dimensions:                    (chain: 4, draw: 1000, control: 3,\n",
      -       "                                fourier_mode: 4, channel: 2, date: 179)\n",
      -       "Coordinates:\n",
      -       "  * chain                      (chain) int64 0 1 2 3\n",
      -       "  * draw                       (draw) int64 0 1 2 3 4 5 ... 995 996 997 998 999\n",
      -       "  * control                    (control) <U7 'event_1' 'event_2' 't'\n",
      -       "  * fourier_mode               (fourier_mode) <U11 'sin_order_1' ... 'cos_ord...\n",
      -       "  * channel                    (channel) <U2 'x1' 'x2'\n",
      -       "  * date                       (date) <U10 '2018-04-02' ... '2021-08-30'\n",
      -       "Data variables: (12/13)\n",
      -       "    intercept                  (chain, draw) float64 0.3381 0.3361 ... 0.3479\n",
      -       "    gamma_control              (chain, draw, control) float64 0.2968 ... 0.00...\n",
      -       "    gamma_fourier              (chain, draw, fourier_mode) float64 -0.002612 ...\n",
      -       "    beta_channel               (chain, draw, channel) float64 0.3718 ... 0.2628\n",
      -       "    alpha                      (chain, draw, channel) float64 0.4162 ... 0.2007\n",
      -       "    lam                        (chain, draw, channel) float64 4.26 ... 2.758\n",
      -       "    ...                         ...\n",
      -       "    channel_adstock            (chain, draw, date, channel) float64 0.1868 .....\n",
      -       "    channel_adstock_saturated  (chain, draw, date, channel) float64 0.3781 .....\n",
      -       "    channel_contributions      (chain, draw, date, channel) float64 0.1406 .....\n",
      -       "    control_contributions      (chain, draw, date, control) float64 0.0 ... 0...\n",
      -       "    fourier_contributions      (chain, draw, date, fourier_mode) float64 -0.0...\n",
      -       "    mu                         (chain, draw, date) float64 0.4769 ... 0.5924\n",
      -       "Attributes:\n",
      -       "    created_at:                 2023-08-03T11:09:14.027598\n",
      -       "    arviz_version:              0.16.1\n",
      -       "    inference_library:          pymc\n",
      -       "    inference_library_version:  5.6.1\n",
      -       "    sampling_time:              29.821417093276978\n",
      -       "    tuning_steps:               1000

      \n", - "
    \n", - "
    \n", - "
  • \n", - " \n", - "
  • \n", - " \n", - " \n", - "
    \n", - "
    \n", - "
      \n", - "
      \n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "
      <xarray.Dataset>\n",
      -       "Dimensions:     (chain: 4, draw: 1000, date: 179)\n",
      -       "Coordinates:\n",
      -       "  * chain       (chain) int64 0 1 2 3\n",
      -       "  * draw        (draw) int64 0 1 2 3 4 5 6 7 ... 992 993 994 995 996 997 998 999\n",
      -       "  * date        (date) <U10 '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
      -       "Data variables:\n",
      -       "    likelihood  (chain, draw, date) float64 0.5039 0.439 ... 0.5873 0.6072\n",
      -       "Attributes:\n",
      -       "    created_at:                 2023-08-03T11:09:15.416164\n",
      -       "    arviz_version:              0.16.1\n",
      -       "    inference_library:          pymc\n",
      -       "    inference_library_version:  5.6.1

      \n", - "
    \n", - "
    \n", - "
  • \n", - " \n", - "
  • \n", - " \n", - " \n", - "
    \n", - "
    \n", - "
      \n", - "
      \n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "
      <xarray.Dataset>\n",
      -       "Dimensions:                (chain: 4, draw: 1000)\n",
      -       "Coordinates:\n",
      -       "  * chain                  (chain) int64 0 1 2 3\n",
      -       "  * draw                   (draw) int64 0 1 2 3 4 5 ... 994 995 996 997 998 999\n",
      -       "Data variables: (12/17)\n",
      -       "    process_time_diff      (chain, draw) float64 0.008982 0.009454 ... 0.009203\n",
      -       "    step_size_bar          (chain, draw) float64 0.05522 0.05522 ... 0.06921\n",
      -       "    step_size              (chain, draw) float64 0.05862 0.05862 ... 0.07543\n",
      -       "    acceptance_rate        (chain, draw) float64 0.9987 0.9934 ... 0.8931 0.9167\n",
      -       "    index_in_trajectory    (chain, draw) int64 -12 20 43 -17 ... 42 20 -43 -29\n",
      -       "    tree_depth             (chain, draw) int64 6 6 6 6 7 6 6 7 ... 6 5 5 6 7 6 6\n",
      -       "    ...                     ...\n",
      -       "    energy_error           (chain, draw) float64 -0.01181 -0.02828 ... -0.01478\n",
      -       "    perf_counter_diff      (chain, draw) float64 0.009316 0.01008 ... 0.009355\n",
      -       "    n_steps                (chain, draw) float64 63.0 63.0 63.0 ... 63.0 63.0\n",
      -       "    diverging              (chain, draw) bool False False False ... False False\n",
      -       "    perf_counter_start     (chain, draw) float64 3.949e+06 ... 3.949e+06\n",
      -       "    lp                     (chain, draw) float64 355.1 355.2 ... 355.6 351.6\n",
      -       "Attributes:\n",
      -       "    created_at:                 2023-08-03T11:09:14.040220\n",
      -       "    arviz_version:              0.16.1\n",
      -       "    inference_library:          pymc\n",
      -       "    inference_library_version:  5.6.1\n",
      -       "    sampling_time:              29.821417093276978\n",
      -       "    tuning_steps:               1000

      \n", - "
    \n", - "
    \n", - "
  • \n", - " \n", - "
  • \n", - " \n", - " \n", - "
    \n", - "
    \n", - "
      \n", - "
      \n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "
      <xarray.Dataset>\n",
      -       "Dimensions:                    (chain: 1, draw: 500, fourier_mode: 4,\n",
      -       "                                date: 179, channel: 2, control: 3)\n",
      -       "Coordinates:\n",
      -       "  * chain                      (chain) int64 0\n",
      -       "  * draw                       (draw) int64 0 1 2 3 4 5 ... 495 496 497 498 499\n",
      -       "  * fourier_mode               (fourier_mode) <U11 'sin_order_1' ... 'cos_ord...\n",
      -       "  * date                       (date) <U10 '2018-04-02' ... '2021-08-30'\n",
      -       "  * channel                    (channel) <U2 'x1' 'x2'\n",
      -       "  * control                    (control) <U7 'event_1' 'event_2' 't'\n",
      -       "Data variables: (12/13)\n",
      -       "    gamma_fourier              (chain, draw, fourier_mode) float64 1.28 ... 1...\n",
      -       "    intercept                  (chain, draw) float64 1.178 -2.005 ... 2.533\n",
      -       "    fourier_contributions      (chain, draw, date, fourier_mode) float64 1.28...\n",
      -       "    mu                         (chain, draw, date) float64 0.8703 ... -44.96\n",
      -       "    channel_contributions      (chain, draw, date, channel) float64 0.4222 .....\n",
      -       "    control_contributions      (chain, draw, date, control) float64 0.0 ... -...\n",
      -       "    ...                         ...\n",
      -       "    gamma_control              (chain, draw, control) float64 1.547 ... -0.2666\n",
      -       "    channel_adstock            (chain, draw, date, channel) float64 0.3087 .....\n",
      -       "    alpha                      (chain, draw, channel) float64 0.03434 ... 0.1455\n",
      -       "    lam                        (chain, draw, channel) float64 2.558 ... 4.149\n",
      -       "    channel_adstock_saturated  (chain, draw, date, channel) float64 0.3755 .....\n",
      -       "    sigma                      (chain, draw) float64 0.05319 1.662 ... 1.212\n",
      -       "Attributes:\n",
      -       "    created_at:                 2023-08-03T11:09:15.159471\n",
      -       "    arviz_version:              0.16.1\n",
      -       "    inference_library:          pymc\n",
      -       "    inference_library_version:  5.6.1

      \n", - "
    \n", - "
    \n", - "
  • \n", - " \n", - "
  • \n", - " \n", - " \n", - "
    \n", - "
    \n", - "
      \n", - "
      \n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "
      <xarray.Dataset>\n",
      -       "Dimensions:     (chain: 1, draw: 500, date: 179)\n",
      -       "Coordinates:\n",
      -       "  * chain       (chain) int64 0\n",
      -       "  * draw        (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499\n",
      -       "  * date        (date) <U10 '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
      -       "Data variables:\n",
      -       "    likelihood  (chain, draw, date) float64 0.8725 3.274 6.447 ... -45.43 -45.69\n",
      -       "Attributes:\n",
      -       "    created_at:                 2023-08-03T11:09:15.164125\n",
      -       "    arviz_version:              0.16.1\n",
      -       "    inference_library:          pymc\n",
      -       "    inference_library_version:  5.6.1

      \n", - "
    \n", - "
    \n", - "
  • \n", - " \n", - "
  • \n", - " \n", - " \n", - "
    \n", - "
    \n", - "
      \n", - "
      \n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "
      <xarray.Dataset>\n",
      -       "Dimensions:     (date: 179)\n",
      -       "Coordinates:\n",
      -       "  * date        (date) <U10 '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
      -       "Data variables:\n",
      -       "    likelihood  (date) float64 0.4794 0.4527 0.5374 ... 0.4978 0.5388 0.5625\n",
      -       "Attributes:\n",
      -       "    created_at:                 2023-08-03T11:09:14.043853\n",
      -       "    arviz_version:              0.16.1\n",
      -       "    inference_library:          pymc\n",
      -       "    inference_library_version:  5.6.1

      \n", - "
    \n", - "
    \n", - "
  • \n", - " \n", - "
  • \n", - " \n", - " \n", - "
    \n", - "
    \n", - "
      \n", - "
      \n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "
      <xarray.Dataset>\n",
      -       "Dimensions:       (date: 179, channel: 2, control: 3, fourier_mode: 4)\n",
      -       "Coordinates:\n",
      -       "  * date          (date) <U10 '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
      -       "  * channel       (channel) <U2 'x1' 'x2'\n",
      -       "  * control       (control) <U7 'event_1' 'event_2' 't'\n",
      -       "  * fourier_mode  (fourier_mode) <U11 'sin_order_1' ... 'cos_order_2'\n",
      -       "Data variables:\n",
      -       "    channel_data  (date, channel) float64 0.3196 0.0 0.1128 ... 0.0 0.4403 0.0\n",
      -       "    target        (date) float64 0.4794 0.4527 0.5374 ... 0.4978 0.5388 0.5625\n",
      -       "    control_data  (date, control) float64 0.0 0.0 0.0 0.0 ... 0.0 0.0 178.0\n",
      -       "    fourier_data  (date, fourier_mode) float64 0.9999 -0.01183 ... -0.4547\n",
      -       "Attributes:\n",
      -       "    created_at:                 2023-08-03T11:09:14.045029\n",
      -       "    arviz_version:              0.16.1\n",
      -       "    inference_library:          pymc\n",
      -       "    inference_library_version:  5.6.1

      \n", - "
    \n", - "
    \n", - "
  • \n", - " \n", - "
  • \n", - " \n", - " \n", - "
    \n", - "
    \n", - "
      \n", - "
      \n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "
      <xarray.Dataset>\n",
      -       "Dimensions:    (index: 179)\n",
      -       "Coordinates:\n",
      -       "  * index      (index) int64 0 1 2 3 4 5 6 7 ... 171 172 173 174 175 176 177 178\n",
      -       "Data variables:\n",
      -       "    date_week  (index) object '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
      -       "    x1         (index) float64 0.3186 0.1124 0.2924 ... 0.1719 0.2803 0.4389\n",
      -       "    x2         (index) float64 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.8633 0.0 0.0 0.0\n",
      -       "    event_1    (index) float64 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0\n",
      -       "    event_2    (index) float64 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0\n",
      -       "    dayofyear  (index) int64 92 99 106 113 120 127 ... 207 214 221 228 235 242\n",
      -       "    t          (index) int64 0 1 2 3 4 5 6 7 ... 171 172 173 174 175 176 177 178\n",
      -       "    y          (index) float64 3.985e+03 3.763e+03 ... 4.479e+03 4.676e+03

      \n", - "
    \n", - "
    \n", - "
  • \n", - " \n", - "
\n", - "
\n", - " " - ], - "text/plain": [ - "Inference data with groups:\n", - "\t> posterior\n", - "\t> posterior_predictive\n", - "\t> sample_stats\n", - "\t> prior\n", - "\t> prior_predictive\n", - "\t> observed_data\n", - "\t> constant_data\n", - "\t> fit_data" - ] - }, - "execution_count": 42, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "mmm.fit(X=X, y=y, random_seed=rng)" - ] - }, - { - "cell_type": "markdown", - "id": "c29a6461", - "metadata": {}, - "source": [ - "The `fit()` method automatically builds the model using the priors from `model_config`, and assigns the created model to our instance. You can access it as a normal attribute." - ] - }, - { - "cell_type": "code", - "execution_count": 43, - "id": "c6b8e2af", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "pymc.model.Model" - ] - }, - "execution_count": 43, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "type(mmm.model)" - ] - }, - { - "cell_type": "code", - "execution_count": 44, - "id": "f046ee2c", - "metadata": {}, - "outputs": [ - { - "data": { - "image/svg+xml": [ - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "clusterdate (179) x channel (2)\n", - "\n", - "date (179) x channel (2)\n", - "\n", - "\n", - "clusterdate (179)\n", - "\n", - "date (179)\n", - "\n", - "\n", - "clusterchannel (2)\n", - "\n", - "channel (2)\n", - "\n", - "\n", - "clusterdate (179) x control (3)\n", - "\n", - "date (179) x control (3)\n", - "\n", - "\n", - "clustercontrol (3)\n", - "\n", - "control (3)\n", - "\n", - "\n", - "clusterdate (179) x fourier_mode (4)\n", - "\n", - "date (179) x fourier_mode (4)\n", - "\n", - "\n", - "clusterfourier_mode (4)\n", - "\n", - "fourier_mode (4)\n", - "\n", - "\n", - "\n", - "channel_contributions\n", - "\n", - "channel_contributions\n", - "~\n", - "Deterministic\n", - "\n", - "\n", - "\n", - "mu\n", - "\n", - "mu\n", - "~\n", - "Deterministic\n", - "\n", - "\n", - "\n", - 
"channel_contributions->mu\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "channel_data\n", - "\n", - "channel_data\n", - "~\n", - "MutableData\n", - "\n", - "\n", - "\n", - "channel_adstock\n", - "\n", - "channel_adstock\n", - "~\n", - "Deterministic\n", - "\n", - "\n", - "\n", - "channel_data->channel_adstock\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "channel_adstock_saturated\n", - "\n", - "channel_adstock_saturated\n", - "~\n", - "Deterministic\n", - "\n", - "\n", - "\n", - "channel_adstock->channel_adstock_saturated\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "channel_adstock_saturated->channel_contributions\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "target\n", - "\n", - "target\n", - "~\n", - "MutableData\n", - "\n", - "\n", - "\n", - "likelihood\n", - "\n", - "likelihood\n", - "~\n", - "Normal\n", - "\n", - "\n", - "\n", - "mu->likelihood\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "likelihood->target\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "intercept\n", - "\n", - "intercept\n", - "~\n", - "Normal\n", - "\n", - "\n", - "\n", - "intercept->mu\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "sigma\n", - "\n", - "sigma\n", - "~\n", - "HalfNormal\n", - "\n", - "\n", - "\n", - "sigma->likelihood\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "lam\n", - "\n", - "lam\n", - "~\n", - "Gamma\n", - "\n", - "\n", - "\n", - "lam->channel_adstock_saturated\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "alpha\n", - "\n", - "alpha\n", - "~\n", - "Beta\n", - "\n", - "\n", - "\n", - "alpha->channel_adstock\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "beta_channel\n", - "\n", - "beta_channel\n", - "~\n", - "HalfNormal\n", - "\n", - "\n", - "\n", - "beta_channel->channel_contributions\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "control_data\n", - "\n", - "control_data\n", - "~\n", - "MutableData\n", - "\n", - "\n", - "\n", - "control_contributions\n", - "\n", - "control_contributions\n", - "~\n", - "Deterministic\n", - "\n", - "\n", - "\n", - 
"control_data->control_contributions\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "control_contributions->mu\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "gamma_control\n", - "\n", - "gamma_control\n", - "~\n", - "Normal\n", - "\n", - "\n", - "\n", - "gamma_control->control_contributions\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "fourier_contributions\n", - "\n", - "fourier_contributions\n", - "~\n", - "Deterministic\n", - "\n", - "\n", - "\n", - "fourier_contributions->mu\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "fourier_data\n", - "\n", - "fourier_data\n", - "~\n", - "MutableData\n", - "\n", - "\n", - "\n", - "fourier_data->fourier_contributions\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "gamma_fourier\n", - "\n", - "gamma_fourier\n", - "~\n", - "Laplace\n", - "\n", - "\n", - "\n", - "gamma_fourier->fourier_contributions\n", - "\n", - "\n", - "\n", - "\n", - "\n" - ], - "text/plain": [ - "" - ] - }, - "execution_count": 44, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "mmm.graphviz()" - ] - }, - { - "cell_type": "markdown", - "id": "c804b600", - "metadata": {}, - "source": [ - "posterior trace can be accessed by `fit_result` attribute" - ] - }, - { - "cell_type": "code", - "execution_count": 45, - "id": "66903965", - "metadata": {}, - "outputs": [ - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "
<xarray.Dataset>\n",
-       "Dimensions:                    (chain: 4, draw: 1000, control: 3,\n",
-       "                                fourier_mode: 4, channel: 2, date: 179)\n",
-       "Coordinates:\n",
-       "  * chain                      (chain) int64 0 1 2 3\n",
-       "  * draw                       (draw) int64 0 1 2 3 4 5 ... 995 996 997 998 999\n",
-       "  * control                    (control) <U7 'event_1' 'event_2' 't'\n",
-       "  * fourier_mode               (fourier_mode) <U11 'sin_order_1' ... 'cos_ord...\n",
-       "  * channel                    (channel) <U2 'x1' 'x2'\n",
-       "  * date                       (date) <U10 '2018-04-02' ... '2021-08-30'\n",
-       "Data variables: (12/13)\n",
-       "    intercept                  (chain, draw) float64 0.3381 0.3361 ... 0.3479\n",
-       "    gamma_control              (chain, draw, control) float64 0.2968 ... 0.00...\n",
-       "    gamma_fourier              (chain, draw, fourier_mode) float64 -0.002612 ...\n",
-       "    beta_channel               (chain, draw, channel) float64 0.3718 ... 0.2628\n",
-       "    alpha                      (chain, draw, channel) float64 0.4162 ... 0.2007\n",
-       "    lam                        (chain, draw, channel) float64 4.26 ... 2.758\n",
-       "    ...                         ...\n",
-       "    channel_adstock            (chain, draw, date, channel) float64 0.1868 .....\n",
-       "    channel_adstock_saturated  (chain, draw, date, channel) float64 0.3781 .....\n",
-       "    channel_contributions      (chain, draw, date, channel) float64 0.1406 .....\n",
-       "    control_contributions      (chain, draw, date, control) float64 0.0 ... 0...\n",
-       "    fourier_contributions      (chain, draw, date, fourier_mode) float64 -0.0...\n",
-       "    mu                         (chain, draw, date) float64 0.4769 ... 0.5924\n",
-       "Attributes:\n",
-       "    created_at:                 2023-08-03T11:09:14.027598\n",
-       "    arviz_version:              0.16.1\n",
-       "    inference_library:          pymc\n",
-       "    inference_library_version:  5.6.1\n",
-       "    sampling_time:              29.821417093276978\n",
-       "    tuning_steps:               1000
" - ], - "text/plain": [ - "\n", - "Dimensions: (chain: 4, draw: 1000, control: 3,\n", - " fourier_mode: 4, channel: 2, date: 179)\n", - "Coordinates:\n", - " * chain (chain) int64 0 1 2 3\n", - " * draw (draw) int64 0 1 2 3 4 5 ... 995 996 997 998 999\n", - " * control (control) \n", - "
\n", - "
arviz.InferenceData
\n", - "
\n", - "
    \n", - " \n", - "
  • \n", - " \n", - " \n", - "
    \n", - "
    \n", - "
      \n", - "
      \n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "
      <xarray.Dataset>\n",
      -       "Dimensions:                    (chain: 4, draw: 1000, control: 3,\n",
      -       "                                fourier_mode: 4, channel: 2, date: 179)\n",
      -       "Coordinates:\n",
      -       "  * chain                      (chain) int64 0 1 2 3\n",
      -       "  * draw                       (draw) int64 0 1 2 3 4 5 ... 995 996 997 998 999\n",
      -       "  * control                    (control) <U7 'event_1' 'event_2' 't'\n",
      -       "  * fourier_mode               (fourier_mode) <U11 'sin_order_1' ... 'cos_ord...\n",
      -       "  * channel                    (channel) <U2 'x1' 'x2'\n",
      -       "  * date                       (date) <U10 '2018-04-02' ... '2021-08-30'\n",
      -       "Data variables: (12/13)\n",
      -       "    intercept                  (chain, draw) float64 0.3381 0.3361 ... 0.3479\n",
      -       "    gamma_control              (chain, draw, control) float64 0.2968 ... 0.00...\n",
      -       "    gamma_fourier              (chain, draw, fourier_mode) float64 -0.002612 ...\n",
      -       "    beta_channel               (chain, draw, channel) float64 0.3718 ... 0.2628\n",
      -       "    alpha                      (chain, draw, channel) float64 0.4162 ... 0.2007\n",
      -       "    lam                        (chain, draw, channel) float64 4.26 ... 2.758\n",
      -       "    ...                         ...\n",
      -       "    channel_adstock            (chain, draw, date, channel) float64 0.1868 .....\n",
      -       "    channel_adstock_saturated  (chain, draw, date, channel) float64 0.3781 .....\n",
      -       "    channel_contributions      (chain, draw, date, channel) float64 0.1406 .....\n",
      -       "    control_contributions      (chain, draw, date, control) float64 0.0 ... 0...\n",
      -       "    fourier_contributions      (chain, draw, date, fourier_mode) float64 -0.0...\n",
      -       "    mu                         (chain, draw, date) float64 0.4769 ... 0.5924\n",
      -       "Attributes:\n",
      -       "    created_at:                 2023-08-03T11:09:14.027598\n",
      -       "    arviz_version:              0.16.1\n",
      -       "    inference_library:          pymc\n",
      -       "    inference_library_version:  5.6.1\n",
      -       "    sampling_time:              29.821417093276978\n",
      -       "    tuning_steps:               1000

      \n", - "
    \n", - "
    \n", - "
  • \n", - " \n", - "
  • \n", - " \n", - " \n", - "
    \n", - "
    \n", - "
      \n", - "
      \n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "
      <xarray.Dataset>\n",
      -       "Dimensions:     (chain: 4, draw: 1000, date: 179)\n",
      -       "Coordinates:\n",
      -       "  * chain       (chain) int64 0 1 2 3\n",
      -       "  * draw        (draw) int64 0 1 2 3 4 5 6 7 ... 992 993 994 995 996 997 998 999\n",
      -       "  * date        (date) <U10 '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
      -       "Data variables:\n",
      -       "    likelihood  (chain, draw, date) float64 0.5039 0.439 ... 0.5873 0.6072\n",
      -       "Attributes:\n",
      -       "    created_at:                 2023-08-03T11:09:15.416164\n",
      -       "    arviz_version:              0.16.1\n",
      -       "    inference_library:          pymc\n",
      -       "    inference_library_version:  5.6.1

      \n", - "
    \n", - "
    \n", - "
  • \n", - " \n", - "
  • \n", - " \n", - " \n", - "
    \n", - "
    \n", - "
      \n", - "
      \n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "
      <xarray.Dataset>\n",
      -       "Dimensions:                (chain: 4, draw: 1000)\n",
      -       "Coordinates:\n",
      -       "  * chain                  (chain) int64 0 1 2 3\n",
      -       "  * draw                   (draw) int64 0 1 2 3 4 5 ... 994 995 996 997 998 999\n",
      -       "Data variables: (12/17)\n",
      -       "    process_time_diff      (chain, draw) float64 0.008982 0.009454 ... 0.009203\n",
      -       "    step_size_bar          (chain, draw) float64 0.05522 0.05522 ... 0.06921\n",
      -       "    step_size              (chain, draw) float64 0.05862 0.05862 ... 0.07543\n",
      -       "    acceptance_rate        (chain, draw) float64 0.9987 0.9934 ... 0.8931 0.9167\n",
      -       "    index_in_trajectory    (chain, draw) int64 -12 20 43 -17 ... 42 20 -43 -29\n",
      -       "    tree_depth             (chain, draw) int64 6 6 6 6 7 6 6 7 ... 6 5 5 6 7 6 6\n",
      -       "    ...                     ...\n",
      -       "    energy_error           (chain, draw) float64 -0.01181 -0.02828 ... -0.01478\n",
      -       "    perf_counter_diff      (chain, draw) float64 0.009316 0.01008 ... 0.009355\n",
      -       "    n_steps                (chain, draw) float64 63.0 63.0 63.0 ... 63.0 63.0\n",
      -       "    diverging              (chain, draw) bool False False False ... False False\n",
      -       "    perf_counter_start     (chain, draw) float64 3.949e+06 ... 3.949e+06\n",
      -       "    lp                     (chain, draw) float64 355.1 355.2 ... 355.6 351.6\n",
      -       "Attributes:\n",
      -       "    created_at:                 2023-08-03T11:09:14.040220\n",
      -       "    arviz_version:              0.16.1\n",
      -       "    inference_library:          pymc\n",
      -       "    inference_library_version:  5.6.1\n",
      -       "    sampling_time:              29.821417093276978\n",
      -       "    tuning_steps:               1000

      \n", - "
    \n", - "
    \n", - "
  • \n", - " \n", - "
  • \n", - " \n", - " \n", - "
    \n", - "
    \n", - "
      \n", - "
      \n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "
      <xarray.Dataset>\n",
      -       "Dimensions:                    (chain: 1, draw: 500, fourier_mode: 4,\n",
      -       "                                date: 179, channel: 2, control: 3)\n",
      -       "Coordinates:\n",
      -       "  * chain                      (chain) int64 0\n",
      -       "  * draw                       (draw) int64 0 1 2 3 4 5 ... 495 496 497 498 499\n",
      -       "  * fourier_mode               (fourier_mode) <U11 'sin_order_1' ... 'cos_ord...\n",
      -       "  * date                       (date) <U10 '2018-04-02' ... '2021-08-30'\n",
      -       "  * channel                    (channel) <U2 'x1' 'x2'\n",
      -       "  * control                    (control) <U7 'event_1' 'event_2' 't'\n",
      -       "Data variables: (12/13)\n",
      -       "    gamma_fourier              (chain, draw, fourier_mode) float64 1.28 ... 1...\n",
      -       "    intercept                  (chain, draw) float64 1.178 -2.005 ... 2.533\n",
      -       "    fourier_contributions      (chain, draw, date, fourier_mode) float64 1.28...\n",
      -       "    mu                         (chain, draw, date) float64 0.8703 ... -44.96\n",
      -       "    channel_contributions      (chain, draw, date, channel) float64 0.4222 .....\n",
      -       "    control_contributions      (chain, draw, date, control) float64 0.0 ... -...\n",
      -       "    ...                         ...\n",
      -       "    gamma_control              (chain, draw, control) float64 1.547 ... -0.2666\n",
      -       "    channel_adstock            (chain, draw, date, channel) float64 0.3087 .....\n",
      -       "    alpha                      (chain, draw, channel) float64 0.03434 ... 0.1455\n",
      -       "    lam                        (chain, draw, channel) float64 2.558 ... 4.149\n",
      -       "    channel_adstock_saturated  (chain, draw, date, channel) float64 0.3755 .....\n",
      -       "    sigma                      (chain, draw) float64 0.05319 1.662 ... 1.212\n",
      -       "Attributes:\n",
      -       "    created_at:                 2023-08-03T11:09:15.159471\n",
      -       "    arviz_version:              0.16.1\n",
      -       "    inference_library:          pymc\n",
      -       "    inference_library_version:  5.6.1

      \n", - "
    \n", - "
    \n", - "
  • \n", - " \n", - "
  • \n", - " \n", - " \n", - "
    \n", - "
    \n", - "
      \n", - "
      \n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "
      <xarray.Dataset>\n",
      -       "Dimensions:     (chain: 1, draw: 500, date: 179)\n",
      -       "Coordinates:\n",
      -       "  * chain       (chain) int64 0\n",
      -       "  * draw        (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499\n",
      -       "  * date        (date) <U10 '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
      -       "Data variables:\n",
      -       "    likelihood  (chain, draw, date) float64 0.8725 3.274 6.447 ... -45.43 -45.69\n",
      -       "Attributes:\n",
      -       "    created_at:                 2023-08-03T11:09:15.164125\n",
      -       "    arviz_version:              0.16.1\n",
      -       "    inference_library:          pymc\n",
      -       "    inference_library_version:  5.6.1

      \n", - "
    \n", - "
    \n", - "
  • \n", - " \n", - "
  • \n", - " \n", - " \n", - "
    \n", - "
    \n", - "
      \n", - "
      \n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "
      <xarray.Dataset>\n",
      -       "Dimensions:     (date: 179)\n",
      -       "Coordinates:\n",
      -       "  * date        (date) <U10 '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
      -       "Data variables:\n",
      -       "    likelihood  (date) float64 0.4794 0.4527 0.5374 ... 0.4978 0.5388 0.5625\n",
      -       "Attributes:\n",
      -       "    created_at:                 2023-08-03T11:09:14.043853\n",
      -       "    arviz_version:              0.16.1\n",
      -       "    inference_library:          pymc\n",
      -       "    inference_library_version:  5.6.1

      \n", - "
    \n", - "
    \n", - "
  • \n", - " \n", - "
  • \n", - " \n", - " \n", - "
    \n", - "
    \n", - "
      \n", - "
      \n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "
      <xarray.Dataset>\n",
      -       "Dimensions:       (date: 179, channel: 2, control: 3, fourier_mode: 4)\n",
      -       "Coordinates:\n",
      -       "  * date          (date) <U10 '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
      -       "  * channel       (channel) <U2 'x1' 'x2'\n",
      -       "  * control       (control) <U7 'event_1' 'event_2' 't'\n",
      -       "  * fourier_mode  (fourier_mode) <U11 'sin_order_1' ... 'cos_order_2'\n",
      -       "Data variables:\n",
      -       "    channel_data  (date, channel) float64 0.3196 0.0 0.1128 ... 0.0 0.4403 0.0\n",
      -       "    target        (date) float64 0.4794 0.4527 0.5374 ... 0.4978 0.5388 0.5625\n",
      -       "    control_data  (date, control) float64 0.0 0.0 0.0 0.0 ... 0.0 0.0 178.0\n",
      -       "    fourier_data  (date, fourier_mode) float64 0.9999 -0.01183 ... -0.4547\n",
      -       "Attributes:\n",
      -       "    created_at:                 2023-08-03T11:09:14.045029\n",
      -       "    arviz_version:              0.16.1\n",
      -       "    inference_library:          pymc\n",
      -       "    inference_library_version:  5.6.1

      \n", - "
    \n", - "
    \n", - "
  • \n", - " \n", - "
  • \n", - " \n", - " \n", - "
    \n", - "
    \n", - "
      \n", - "
      \n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "
      <xarray.Dataset>\n",
      -       "Dimensions:    (index: 179)\n",
      -       "Coordinates:\n",
      -       "  * index      (index) int64 0 1 2 3 4 5 6 7 ... 171 172 173 174 175 176 177 178\n",
      -       "Data variables:\n",
      -       "    date_week  (index) object '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
      -       "    x1         (index) float64 0.3186 0.1124 0.2924 ... 0.1719 0.2803 0.4389\n",
      -       "    x2         (index) float64 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.8633 0.0 0.0 0.0\n",
      -       "    event_1    (index) float64 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0\n",
      -       "    event_2    (index) float64 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0\n",
      -       "    dayofyear  (index) int64 92 99 106 113 120 127 ... 207 214 221 228 235 242\n",
      -       "    t          (index) int64 0 1 2 3 4 5 6 7 ... 171 172 173 174 175 176 177 178\n",
      -       "    y          (index) float64 3.985e+03 3.763e+03 ... 4.479e+03 4.676e+03

      \n", - "
    \n", - "
    \n", - "
  • \n", - " \n", - "
\n", - " \n", - " " - ], - "text/plain": [ - "Inference data with groups:\n", - "\t> posterior\n", - "\t> posterior_predictive\n", - "\t> sample_stats\n", - "\t> prior\n", - "\t> prior_predictive\n", - "\t> observed_data\n", - "\t> constant_data\n", - "\t> fit_data" - ] - }, - "execution_count": 46, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "mmm.idata" - ] - }, - { - "cell_type": "markdown", - "id": "8b433c7f-0f0d-40b2-bcfb-a19555b528bd", - "metadata": {}, - "source": [ - "## `Save` and `load`" - ] - }, - { - "cell_type": "markdown", - "id": "7b0a35f4", - "metadata": {}, - "source": [ - "All the data passed to the model on initialisation is stored in `idata.attrs`. This will be used later in the `save()` method to convert both this data and all the fit data into the netCDF format." - ] - }, - { - "cell_type": "markdown", - "id": "45948f46", - "metadata": {}, - "source": [ - "Simply specify the path to which you'd like to save your model:" - ] - }, - { - "cell_type": "code", - "execution_count": 47, - "id": "b3abe93a", - "metadata": {}, - "outputs": [], - "source": [ - "mmm.save('my_saved_model.nc')" - ] - }, - { - "cell_type": "markdown", - "id": "8a5eba79", - "metadata": {}, - "source": [ - "And pass it to the `load()` method when it's needed again on the target system:" - ] - }, - { - "cell_type": "code", - "execution_count": 48, - "id": "0421bae8", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/michalraczycki/Documents/pymc-marketing/.conda/envs/pymc-marketing/lib/python3.10/site-packages/arviz/data/inference_data.py:153: UserWarning: fit_data group is not defined in the InferenceData scheme\n", - " warnings.warn(\n" - ] - } - ], - "source": [ - "loaded_model = DelayedSaturatedMMM.load('my_saved_model.nc')" - ] - }, - { - "cell_type": "code", - "execution_count": 49, - "id": "a8b666d3", - "metadata": {}, - "outputs": [ - { - "data": { - "image/svg+xml": [ - "\n", - 
"\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "clusterdate (179) x channel (2)\n", - "\n", - "date (179) x channel (2)\n", - "\n", - "\n", - "clusterdate (179)\n", - "\n", - "date (179)\n", - "\n", - "\n", - "clusterchannel (2)\n", - "\n", - "channel (2)\n", - "\n", - "\n", - "clusterdate (179) x control (3)\n", - "\n", - "date (179) x control (3)\n", - "\n", - "\n", - "clustercontrol (3)\n", - "\n", - "control (3)\n", - "\n", - "\n", - "clusterdate (179) x fourier_mode (4)\n", - "\n", - "date (179) x fourier_mode (4)\n", - "\n", - "\n", - "clusterfourier_mode (4)\n", - "\n", - "fourier_mode (4)\n", - "\n", - "\n", - "\n", - "channel_contributions\n", - "\n", - "channel_contributions\n", - "~\n", - "Deterministic\n", - "\n", - "\n", - "\n", - "mu\n", - "\n", - "mu\n", - "~\n", - "Deterministic\n", - "\n", - "\n", - "\n", - "channel_contributions->mu\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "channel_data\n", - "\n", - "channel_data\n", - "~\n", - "MutableData\n", - "\n", - "\n", - "\n", - "channel_adstock\n", - "\n", - "channel_adstock\n", - "~\n", - "Deterministic\n", - "\n", - "\n", - "\n", - "channel_data->channel_adstock\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "channel_adstock_saturated\n", - "\n", - "channel_adstock_saturated\n", - "~\n", - "Deterministic\n", - "\n", - "\n", - "\n", - "channel_adstock->channel_adstock_saturated\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "channel_adstock_saturated->channel_contributions\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "target\n", - "\n", - "target\n", - "~\n", - "MutableData\n", - "\n", - "\n", - "\n", - "likelihood\n", - "\n", - "likelihood\n", - "~\n", - "Normal\n", - "\n", - "\n", - "\n", - "mu->likelihood\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "likelihood->target\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "intercept\n", - "\n", - "intercept\n", - "~\n", - "Normal\n", - "\n", - "\n", - "\n", - "intercept->mu\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "sigma\n", - "\n", - 
"sigma\n", - "~\n", - "HalfNormal\n", - "\n", - "\n", - "\n", - "sigma->likelihood\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "lam\n", - "\n", - "lam\n", - "~\n", - "Gamma\n", - "\n", - "\n", - "\n", - "lam->channel_adstock_saturated\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "alpha\n", - "\n", - "alpha\n", - "~\n", - "Beta\n", - "\n", - "\n", - "\n", - "alpha->channel_adstock\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "beta_channel\n", - "\n", - "beta_channel\n", - "~\n", - "HalfNormal\n", - "\n", - "\n", - "\n", - "beta_channel->channel_contributions\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "control_data\n", - "\n", - "control_data\n", - "~\n", - "MutableData\n", - "\n", - "\n", - "\n", - "control_contributions\n", - "\n", - "control_contributions\n", - "~\n", - "Deterministic\n", - "\n", - "\n", - "\n", - "control_data->control_contributions\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "control_contributions->mu\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "gamma_control\n", - "\n", - "gamma_control\n", - "~\n", - "Normal\n", - "\n", - "\n", - "\n", - "gamma_control->control_contributions\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "fourier_contributions\n", - "\n", - "fourier_contributions\n", - "~\n", - "Deterministic\n", - "\n", - "\n", - "\n", - "fourier_contributions->mu\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "fourier_data\n", - "\n", - "fourier_data\n", - "~\n", - "MutableData\n", - "\n", - "\n", - "\n", - "fourier_data->fourier_contributions\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "gamma_fourier\n", - "\n", - "gamma_fourier\n", - "~\n", - "Laplace\n", - "\n", - "\n", - "\n", - "gamma_fourier->fourier_contributions\n", - "\n", - "\n", - "\n", - "\n", - "\n" - ], - "text/plain": [ - "" - ] - }, - "execution_count": 49, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "loaded_model.graphviz()" - ] - }, - { - "cell_type": "code", - "execution_count": 50, - "id": "cfb64a2c", - "metadata": {}, - 
"outputs": [ - { - "data": { - "text/html": [ - "\n", - "
\n", - "
\n", - "
arviz.InferenceData
\n", - "
\n", - "
    \n", - " \n", - "
  • \n", - " \n", - " \n", - "
    \n", - "
    \n", - "
      \n", - "
      \n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "
      <xarray.Dataset>\n",
      -       "Dimensions:                    (chain: 4, draw: 1000, control: 3,\n",
      -       "                                fourier_mode: 4, channel: 2, date: 179)\n",
      -       "Coordinates:\n",
      -       "  * chain                      (chain) int64 0 1 2 3\n",
      -       "  * draw                       (draw) int64 0 1 2 3 4 5 ... 995 996 997 998 999\n",
      -       "  * control                    (control) object 'event_1' 'event_2' 't'\n",
      -       "  * fourier_mode               (fourier_mode) object 'sin_order_1' ... 'cos_o...\n",
      -       "  * channel                    (channel) object 'x1' 'x2'\n",
      -       "  * date                       (date) object '2018-04-02' ... '2021-08-30'\n",
      -       "Data variables: (12/13)\n",
      -       "    intercept                  (chain, draw) float64 ...\n",
      -       "    gamma_control              (chain, draw, control) float64 ...\n",
      -       "    gamma_fourier              (chain, draw, fourier_mode) float64 ...\n",
      -       "    beta_channel               (chain, draw, channel) float64 ...\n",
      -       "    alpha                      (chain, draw, channel) float64 ...\n",
      -       "    lam                        (chain, draw, channel) float64 ...\n",
      -       "    ...                         ...\n",
      -       "    channel_adstock            (chain, draw, date, channel) float64 ...\n",
      -       "    channel_adstock_saturated  (chain, draw, date, channel) float64 ...\n",
      -       "    channel_contributions      (chain, draw, date, channel) float64 ...\n",
      -       "    control_contributions      (chain, draw, date, control) float64 ...\n",
      -       "    fourier_contributions      (chain, draw, date, fourier_mode) float64 ...\n",
      -       "    mu                         (chain, draw, date) float64 ...\n",
      -       "Attributes:\n",
      -       "    created_at:                 2023-08-03T11:09:14.027598\n",
      -       "    arviz_version:              0.16.1\n",
      -       "    inference_library:          pymc\n",
      -       "    inference_library_version:  5.6.1\n",
      -       "    sampling_time:              29.821417093276978\n",
      -       "    tuning_steps:               1000

      \n", - "
    \n", - "
    \n", - "
  • \n", - " \n", - "
  • \n", - " \n", - " \n", - "
    \n", - "
    \n", - "
      \n", - "
      \n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "
      <xarray.Dataset>\n",
      -       "Dimensions:     (chain: 4, draw: 1000, date: 179)\n",
      -       "Coordinates:\n",
      -       "  * chain       (chain) int64 0 1 2 3\n",
      -       "  * draw        (draw) int64 0 1 2 3 4 5 6 7 ... 992 993 994 995 996 997 998 999\n",
      -       "  * date        (date) object '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
      -       "Data variables:\n",
      -       "    likelihood  (chain, draw, date) float64 ...\n",
      -       "Attributes:\n",
      -       "    created_at:                 2023-08-03T11:09:15.416164\n",
      -       "    arviz_version:              0.16.1\n",
      -       "    inference_library:          pymc\n",
      -       "    inference_library_version:  5.6.1

      \n", - "
    \n", - "
    \n", - "
  • \n", - " \n", - "
  • \n", - " \n", - " \n", - "
    \n", - "
    \n", - "
      \n", - "
      \n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "
      <xarray.Dataset>\n",
      -       "Dimensions:                (chain: 4, draw: 1000)\n",
      -       "Coordinates:\n",
      -       "  * chain                  (chain) int64 0 1 2 3\n",
      -       "  * draw                   (draw) int64 0 1 2 3 4 5 ... 994 995 996 997 998 999\n",
      -       "Data variables: (12/17)\n",
      -       "    process_time_diff      (chain, draw) float64 ...\n",
      -       "    step_size_bar          (chain, draw) float64 ...\n",
      -       "    step_size              (chain, draw) float64 ...\n",
      -       "    acceptance_rate        (chain, draw) float64 ...\n",
      -       "    index_in_trajectory    (chain, draw) int64 ...\n",
      -       "    tree_depth             (chain, draw) int64 ...\n",
      -       "    ...                     ...\n",
      -       "    energy_error           (chain, draw) float64 ...\n",
      -       "    perf_counter_diff      (chain, draw) float64 ...\n",
      -       "    n_steps                (chain, draw) float64 ...\n",
      -       "    diverging              (chain, draw) bool ...\n",
      -       "    perf_counter_start     (chain, draw) float64 ...\n",
      -       "    lp                     (chain, draw) float64 ...\n",
      -       "Attributes:\n",
      -       "    created_at:                 2023-08-03T11:09:14.040220\n",
      -       "    arviz_version:              0.16.1\n",
      -       "    inference_library:          pymc\n",
      -       "    inference_library_version:  5.6.1\n",
      -       "    sampling_time:              29.821417093276978\n",
      -       "    tuning_steps:               1000

      \n", - "
    \n", - "
    \n", - "
  • \n", - " \n", - "
  • \n", - " \n", - " \n", - "
    \n", - "
    \n", - "
      \n", - "
      \n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "
      <xarray.Dataset>\n",
      -       "Dimensions:                    (chain: 1, draw: 500, fourier_mode: 4,\n",
      -       "                                date: 179, channel: 2, control: 3)\n",
      -       "Coordinates:\n",
      -       "  * chain                      (chain) int64 0\n",
      -       "  * draw                       (draw) int64 0 1 2 3 4 5 ... 495 496 497 498 499\n",
      -       "  * fourier_mode               (fourier_mode) object 'sin_order_1' ... 'cos_o...\n",
      -       "  * date                       (date) object '2018-04-02' ... '2021-08-30'\n",
      -       "  * channel                    (channel) object 'x1' 'x2'\n",
      -       "  * control                    (control) object 'event_1' 'event_2' 't'\n",
      -       "Data variables: (12/13)\n",
      -       "    gamma_fourier              (chain, draw, fourier_mode) float64 ...\n",
      -       "    intercept                  (chain, draw) float64 ...\n",
      -       "    fourier_contributions      (chain, draw, date, fourier_mode) float64 ...\n",
      -       "    mu                         (chain, draw, date) float64 ...\n",
      -       "    channel_contributions      (chain, draw, date, channel) float64 ...\n",
      -       "    control_contributions      (chain, draw, date, control) float64 ...\n",
      -       "    ...                         ...\n",
      -       "    gamma_control              (chain, draw, control) float64 ...\n",
      -       "    channel_adstock            (chain, draw, date, channel) float64 ...\n",
      -       "    alpha                      (chain, draw, channel) float64 ...\n",
      -       "    lam                        (chain, draw, channel) float64 ...\n",
      -       "    channel_adstock_saturated  (chain, draw, date, channel) float64 ...\n",
      -       "    sigma                      (chain, draw) float64 ...\n",
      -       "Attributes:\n",
      -       "    created_at:                 2023-08-03T11:09:15.159471\n",
      -       "    arviz_version:              0.16.1\n",
      -       "    inference_library:          pymc\n",
      -       "    inference_library_version:  5.6.1

      \n", - "
    \n", - "
    \n", - "
  • \n", - " \n", - "
  • \n", - " \n", - " \n", - "
    \n", - "
    \n", - "
      \n", - "
      \n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "
      <xarray.Dataset>\n",
      -       "Dimensions:     (chain: 1, draw: 500, date: 179)\n",
      -       "Coordinates:\n",
      -       "  * chain       (chain) int64 0\n",
      -       "  * draw        (draw) int64 0 1 2 3 4 5 6 7 ... 492 493 494 495 496 497 498 499\n",
      -       "  * date        (date) object '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
      -       "Data variables:\n",
      -       "    likelihood  (chain, draw, date) float64 ...\n",
      -       "Attributes:\n",
      -       "    created_at:                 2023-08-03T11:09:15.164125\n",
      -       "    arviz_version:              0.16.1\n",
      -       "    inference_library:          pymc\n",
      -       "    inference_library_version:  5.6.1

      \n", - "
    \n", - "
    \n", - "
  • \n", - " \n", - "
  • \n", - " \n", - " \n", - "
    \n", - "
    \n", - "
      \n", - "
      \n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "
      <xarray.Dataset>\n",
      -       "Dimensions:     (date: 179)\n",
      -       "Coordinates:\n",
      -       "  * date        (date) object '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
      -       "Data variables:\n",
      -       "    likelihood  (date) float64 ...\n",
      -       "Attributes:\n",
      -       "    created_at:                 2023-08-03T11:09:14.043853\n",
      -       "    arviz_version:              0.16.1\n",
      -       "    inference_library:          pymc\n",
      -       "    inference_library_version:  5.6.1

      \n", - "
    \n", - "
    \n", - "
  • \n", - " \n", - "
  • \n", - " \n", - " \n", - "
    \n", - "
    \n", - "
      \n", - "
      \n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "
      <xarray.Dataset>\n",
      -       "Dimensions:       (date: 179, channel: 2, control: 3, fourier_mode: 4)\n",
      -       "Coordinates:\n",
      -       "  * date          (date) object '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
      -       "  * channel       (channel) object 'x1' 'x2'\n",
      -       "  * control       (control) object 'event_1' 'event_2' 't'\n",
      -       "  * fourier_mode  (fourier_mode) object 'sin_order_1' ... 'cos_order_2'\n",
      -       "Data variables:\n",
      -       "    channel_data  (date, channel) float64 ...\n",
      -       "    target        (date) float64 ...\n",
      -       "    control_data  (date, control) float64 ...\n",
      -       "    fourier_data  (date, fourier_mode) float64 ...\n",
      -       "Attributes:\n",
      -       "    created_at:                 2023-08-03T11:09:14.045029\n",
      -       "    arviz_version:              0.16.1\n",
      -       "    inference_library:          pymc\n",
      -       "    inference_library_version:  5.6.1

      \n", - "
    \n", - "
    \n", - "
  • \n", - " \n", - "
  • \n", - " \n", - " \n", - "
    \n", - "
    \n", - "
      \n", - "
      \n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "
      <xarray.Dataset>\n",
      -       "Dimensions:    (index: 179)\n",
      -       "Coordinates:\n",
      -       "  * index      (index) int64 0 1 2 3 4 5 6 7 ... 171 172 173 174 175 176 177 178\n",
      -       "Data variables:\n",
      -       "    date_week  (index) object '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
      -       "    x1         (index) float64 0.3186 0.1124 0.2924 ... 0.1719 0.2803 0.4389\n",
      -       "    x2         (index) float64 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.8633 0.0 0.0 0.0\n",
      -       "    event_1    (index) float64 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0\n",
      -       "    event_2    (index) float64 0.0 0.0 0.0 0.0 0.0 0.0 ... 0.0 0.0 0.0 0.0 0.0\n",
      -       "    dayofyear  (index) int64 92 99 106 113 120 127 ... 207 214 221 228 235 242\n",
      -       "    t          (index) int64 0 1 2 3 4 5 6 7 ... 171 172 173 174 175 176 177 178\n",
      -       "    y          (index) float64 3.985e+03 3.763e+03 ... 4.479e+03 4.676e+03

      \n", - "
    \n", - "
    \n", - "
  • \n", - " \n", - "
\n", - "
\n", - " " - ], - "text/plain": [ - "Inference data with groups:\n", - "\t> posterior\n", - "\t> posterior_predictive\n", - "\t> sample_stats\n", - "\t> prior\n", - "\t> prior_predictive\n", - "\t> observed_data\n", - "\t> constant_data\n", - "\t> fit_data" - ] - }, - "execution_count": 50, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "loaded_model.idata" - ] - }, - { - "cell_type": "markdown", - "id": "ab64be46-7fe5-4f39-b72b-36da1419f809", - "metadata": {}, - "source": [ - "A model loaded in this way is ready to be used for sampling and prediction, and has access to all previous samples and data." - ] - }, - { - "cell_type": "code", - "execution_count": 51, - "id": "dd59d056-6ac7-431c-85ee-99e7a8eefd8a", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Sampling: [likelihood]\n" - ] - }, - { - "data": { - "text/html": [ - "\n", - "\n" - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "\n", - "
\n", - " \n", - " 100.00% [4000/4000 00:00<00:00]\n", - "
\n", - " " - ], - "text/plain": [ - "" - ] - }, - "metadata": {}, - "output_type": "display_data" - }, - { - "data": { - "text/html": [ - "
\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "\n", - "
<xarray.Dataset>\n",
-       "Dimensions:     (chain: 4, draw: 1000, date: 179)\n",
-       "Coordinates:\n",
-       "  * chain       (chain) int64 0 1 2 3\n",
-       "  * draw        (draw) int64 0 1 2 3 4 5 6 7 ... 992 993 994 995 996 997 998 999\n",
-       "  * date        (date) <U10 '2018-04-02' '2018-04-09' ... '2021-08-30'\n",
-       "Data variables:\n",
-       "    likelihood  (chain, draw, date) float64 0.4907 0.4282 ... 0.5548 0.5396\n",
-       "Attributes:\n",
-       "    created_at:                 2023-08-03T11:09:22.139450\n",
-       "    arviz_version:              0.16.1\n",
-       "    inference_library:          pymc\n",
-       "    inference_library_version:  5.6.1
" - ], - "text/plain": [ - "\n", - "Dimensions: (chain: 4, draw: 1000, date: 179)\n", - "Coordinates:\n", - " * chain (chain) int64 0 1 2 3\n", - " * draw (draw) int64 0 1 2 3 4 5 6 7 ... 992 993 994 995 996 997 998 999\n", - " * date (date) " - ] - }, - "metadata": {}, - "output_type": "display_data" - } - ], - "source": [ - "az.plot_ppc(loaded_model.idata);" - ] - }, - { - "cell_type": "markdown", - "id": "e8e807f9", - "metadata": {}, - "source": [ - "## Summary:" - ] - }, - { - "cell_type": "markdown", - "id": "61f232c1", - "metadata": {}, - "source": [ - "In summary, this article introduces the revolutionary ModelBuilder, a new PyMC-experimental module that simplifies the deployment of PyMC Bayesian models. It addresses a historic challenge faced by users of PyMC and most PPLs by offering a user-friendly and efficient approach to model deployment. The ModelBuilder provides two straightforward methods, save() and load(), which streamline the model preservation and replication process post fitting. Users are offered flexibility in controlling the prior settings with model_config and customizing the sampling process via sampler_config.\n", - "\n", - "The use of an example model from the MMM Example Notebook demonstrates the practical implementation of ModelBuilder, emphasizing its ability to enhance model sharing among teams without the necessity for extensive domain knowledge about the model. The deployment improvements in PyMC-Marketing brought about by ModelBuilder are not only user-friendly but also significantly enhance efficiency, making PyMC models more accessible for a wider audience." - ] - }, - { - "cell_type": "markdown", - "id": "b8ad333d", - "metadata": {}, - "source": [ - "Even though this introduction is using `DelayedSaturatedMMM`, functionalities from `ModelBuilder` are available in the CLV models as well." 
- ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "pymc-marketing", - "language": "python", - "name": "pymc-marketing" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.12" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} From d12b00b66f05d54f6e91788bdf36bfcea7f09390 Mon Sep 17 00:00:00 2001 From: Michal Raczycki Date: Fri, 1 Sep 2023 11:23:06 +0200 Subject: [PATCH 5/5] updating .myst.md file --- examples/howto/model_builder.myst.md | 1 - 1 file changed, 1 deletion(-) diff --git a/examples/howto/model_builder.myst.md b/examples/howto/model_builder.myst.md index 477c91e98..0747511aa 100644 --- a/examples/howto/model_builder.myst.md +++ b/examples/howto/model_builder.myst.md @@ -344,7 +344,6 @@ ax.legend( * Authored by Shashank Kirtania and Thomas Wiecki in 2023. * Modified and updated by Michał Raczycki in 08/2023 - +++ :::{include} ../page_footer.md