diff --git a/examples/diagnostics_and_criticism/model_comparison.ipynb b/examples/diagnostics_and_criticism/model_comparison.ipynb deleted file mode 100644 index 1e70de829..000000000 --- a/examples/diagnostics_and_criticism/model_comparison.ipynb +++ /dev/null @@ -1,556 +0,0 @@ -{ - "cells": [ - { - "cell_type": "code", - "execution_count": 1, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Runing on PyMC3 v3.11.0\n" - ] - } - ], - "source": [ - "import arviz as az\n", - "import matplotlib.pyplot as plt\n", - "import numpy as np\n", - "import pymc3 as pm\n", - "\n", - "print(f\"Runing on PyMC3 v{pm.__version__}\")" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "metadata": {}, - "outputs": [], - "source": [ - "%load_ext watermark\n", - "az.style.use(\"arviz-darkgrid\")" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "# Model comparison\n", - "\n", - "To demonstrate the use of model comparison criteria in PyMC3, we implement the **8 schools** example from Section 5.5 of Gelman et al (2003), which attempts to infer the effects of coaching on SAT scores of students from 8 schools. Below, we fit a **pooled model**, which assumes a single fixed effect across all schools, and a **hierarchical model** that allows for a random effect that partially pools the data." - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "The data include the observed treatment effects and associated standard deviations in the 8 schools." - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "metadata": {}, - "outputs": [], - "source": [ - "J = 8\n", - "y = np.array([28, 8, -3, 7, -1, 1, 18, 12])\n", - "sigma = np.array([15, 10, 16, 11, 9, 11, 10, 18])" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "### Pooled model" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "Auto-assigning NUTS sampler...\n", - "Initializing NUTS using jitter+adapt_diag...\n", - "Multiprocess sampling (2 chains in 2 jobs)\n", - "NUTS: [mu]\n" - ] - }, - { - "data": { - "text/html": [ - "\n", - "
\n", - " | rank | \n", - "loo | \n", - "p_loo | \n", - "d_loo | \n", - "weight | \n", - "se | \n", - "dse | \n", - "warning | \n", - "loo_scale | \n", - "
---|---|---|---|---|---|---|---|---|---|
pooled | \n", - "0 | \n", - "-30.569563 | \n", - "0.680583 | \n", - "0.000000 | \n", - "1.0 | \n", - "1.105191 | \n", - "0.0000 | \n", - "False | \n", - "log | \n", - "
hierarchical | \n", - "1 | \n", - "-30.754275 | \n", - "1.113869 | \n", - "0.184711 | \n", - "0.0 | \n", - "1.045108 | \n", - "0.2397 | \n", - "False | \n", - "log | \n", - "
\n", - " | mean | \n", - "sd | \n", - "hdi_3% | \n", - "hdi_97% | \n", - "mcse_mean | \n", - "mcse_sd | \n", - "ess_mean | \n", - "ess_sd | \n", - "ess_bulk | \n", - "ess_tail | \n", - "r_hat | \n", - "
---|---|---|---|---|---|---|---|---|---|---|---|
betas[0] | \n", - "0.09 | \n", - "0.11 | \n", - "-0.1 | \n", - "0.30 | \n", - "0.0 | \n", - "0.0 | \n", - "1621.30 | \n", - "1304.44 | \n", - "1622.59 | \n", - "1219.60 | \n", - "1.0 | \n", - "
betas[1] | \n", - "1.07 | \n", - "0.13 | \n", - "0.8 | \n", - "1.29 | \n", - "0.0 | \n", - "0.0 | \n", - "1733.47 | \n", - "1682.04 | \n", - "1736.02 | \n", - "1377.74 | \n", - "1.0 | \n", - "
\n", - " | mean | \n", - "sd | \n", - "hdi_3% | \n", - "hdi_97% | \n", - "mcse_mean | \n", - "mcse_sd | \n", - "ess_mean | \n", - "ess_sd | \n", - "ess_bulk | \n", - "ess_tail | \n", - "r_hat | \n", - "
---|---|---|---|---|---|---|---|---|---|---|---|
alpha | \n", - "0.96 | \n", - "0.11 | \n", - "0.75 | \n", - "1.16 | \n", - "0.00 | \n", - "0.0 | \n", - "9813.79 | \n", - "9781.97 | \n", - "9816.05 | \n", - "6783.11 | \n", - "1.0 | \n", - "
beta[0] | \n", - "1.10 | \n", - "0.12 | \n", - "0.89 | \n", - "1.33 | \n", - "0.00 | \n", - "0.0 | \n", - "8841.92 | \n", - "8797.67 | \n", - "8856.21 | \n", - "7109.65 | \n", - "1.0 | \n", - "
beta[1] | \n", - "2.99 | \n", - "0.53 | \n", - "1.95 | \n", - "3.95 | \n", - "0.01 | \n", - "0.0 | \n", - "7878.01 | \n", - "7765.26 | \n", - "7880.25 | \n", - "6515.70 | \n", - "1.0 | \n", - "
sigma | \n", - "1.07 | \n", - "0.08 | \n", - "0.92 | \n", - "1.21 | \n", - "0.00 | \n", - "0.0 | \n", - "8651.16 | \n", - "8475.93 | \n", - "8901.69 | \n", - "6633.66 | \n", - "1.0 | \n", - "
<xarray.Dataset>\n", - "Dimensions: (chain: 2, disasters_missing_dim_0: 2, draw: 10000)\n", - "Coordinates:\n", - " * chain (chain) int64 0 1\n", - " * draw (draw) int64 0 1 2 3 4 ... 9995 9996 9997 9998 9999\n", - " * disasters_missing_dim_0 (disasters_missing_dim_0) int64 0 1\n", - "Data variables:\n", - " switchpoint (chain, draw) int64 1891 1891 1891 ... 1892 1891\n", - " disasters_missing (chain, draw, disasters_missing_dim_0) int64 7 ....\n", - " early_rate (chain, draw) float64 3.025 3.076 ... 3.307 3.005\n", - " late_rate (chain, draw) float64 0.877 0.8663 ... 0.802 0.9272\n", - "Attributes:\n", - " created_at: 2021-02-08T06:29:28.922616\n", - " arviz_version: 0.11.0\n", - " inference_library: pymc3\n", - " inference_library_version: 3.11.0\n", - " sampling_time: 33.77551817893982\n", - " tuning_steps: 1000
array([0, 1])
array([ 0, 1, 2, ..., 9997, 9998, 9999])
array([0, 1])
array([[1891, 1891, 1891, ..., 1889, 1889, 1889],\n", - " [1892, 1886, 1889, ..., 1893, 1892, 1891]])
array([[[7, 0],\n", - " [6, 0],\n", - " [5, 1],\n", - " ...,\n", - " [1, 0],\n", - " [2, 1],\n", - " [0, 1]],\n", - "\n", - " [[3, 1],\n", - " [3, 1],\n", - " [2, 0],\n", - " ...,\n", - " [5, 1],\n", - " [5, 2],\n", - " [3, 0]]])
array([[3.02490339, 3.07599078, 3.56514373, ..., 2.61447178, 2.61447178,\n", - " 3.44918562],\n", - " [3.27857005, 3.03935203, 3.30460144, ..., 3.16269593, 3.30703433,\n", - " 3.00495012]])
array([[0.8769546 , 0.86634116, 0.95610952, ..., 0.826213 , 0.826213 ,\n", - " 0.91611314],\n", - " [0.9603648 , 1.01766597, 0.92037428, ..., 0.82312568, 0.80196379,\n", - " 0.92723499]])
<xarray.Dataset>\n", - "Dimensions: (chain: 2, disasters_dim_0: 111, draw: 10000)\n", - "Coordinates:\n", - " * chain (chain) int64 0 1\n", - " * draw (draw) int64 0 1 2 3 4 5 ... 9994 9995 9996 9997 9998 9999\n", - " * disasters_dim_0 (disasters_dim_0) int64 0 1 2 3 4 5 ... 106 107 108 109 110\n", - "Data variables:\n", - " disasters (chain, draw, disasters_dim_0) float64 -1.775 ... -1.003\n", - "Attributes:\n", - " created_at: 2021-02-08T06:29:30.682570\n", - " arviz_version: 0.11.0\n", - " inference_library: pymc3\n", - " inference_library_version: 3.11.0
array([0, 1])
array([ 0, 1, 2, ..., 9997, 9998, 9999])
array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,\n", - " 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27,\n", - " 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41,\n", - " 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,\n", - " 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69,\n", - " 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83,\n", - " 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97,\n", - " 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110])
array([[[-1.77544061, -2.27799936, -1.77544061, ..., -1.00825466,\n", - " -0.8769546 , -1.00825466],\n", - " [-1.75953639, -2.24534725, -1.75953639, ..., -1.00981766,\n", - " -0.86634116, -1.00981766],\n", - " [-1.65838008, -1.99661362, -1.65838008, ..., -1.00099233,\n", - " -0.95610952, -1.00099233],\n", - " ...,\n", - " [-1.94827729, -2.59665312, -1.94827729, ..., -1.01711567,\n", - " -0.826213 , -1.01711567],\n", - " [-1.94827729, -2.59665312, -1.94827729, ..., -1.01711567,\n", - " -0.826213 , -1.01711567],\n", - " [-1.67468685, -2.04598661, -1.67468685, ..., -1.00372855,\n", - " -0.91611314, -1.00372855]],\n", - "\n", - " [[-1.70699441, -2.12902496, -1.70699441, ..., -1.00080687,\n", - " -0.9603648 , -1.00080687],\n", - " [-1.77082848, -2.26862205, -1.77082848, ..., -1.00015423,\n", - " -1.01766597, -1.00015423],\n", - " [-1.70139178, -2.11551382, -1.70139178, ..., -1.00334915,\n", - " -0.92037428, -1.00334915],\n", - " ...,\n", - " [-1.73505054, -2.19306364, -1.73505054, ..., -1.01777206,\n", - " -0.82312568, -1.01777206],\n", - " [-1.7008809 , -2.114267 , -1.7008809 , ..., -1.02265561,\n", - " -0.80196379, -1.02265561],\n", - " [-1.78196007, -2.29113702, -1.78196007, ..., -1.00278324,\n", - " -0.92723499, -1.00278324]]])
<xarray.Dataset>\n", - "Dimensions: (accept_dim_0: 2, accepted_dim_0: 2, chain: 2, draw: 10000, scaling_dim_0: 2)\n", - "Coordinates:\n", - " * chain (chain) int64 0 1\n", - " * draw (draw) int64 0 1 2 3 4 5 ... 9995 9996 9997 9998 9999\n", - " * accepted_dim_0 (accepted_dim_0) int64 0 1\n", - " * scaling_dim_0 (scaling_dim_0) int64 0 1\n", - " * accept_dim_0 (accept_dim_0) int64 0 1\n", - "Data variables:\n", - " perf_counter_start (chain, draw) float64 481.6 481.6 481.6 ... 498.6 498.6\n", - " accepted (chain, draw, accepted_dim_0) bool False True ... True\n", - " diverging (chain, draw) bool False False False ... False False\n", - " step_size (chain, draw) float64 0.8683 0.8683 ... 1.078 1.078\n", - " tree_size (chain, draw) float64 3.0 1.0 3.0 3.0 ... 3.0 3.0 3.0\n", - " step_size_bar (chain, draw) float64 1.135 1.135 1.135 ... 1.112 1.112\n", - " energy (chain, draw) float64 180.8 177.0 178.7 ... 179.2 176.4\n", - " depth (chain, draw) int64 2 1 2 2 1 2 2 2 ... 2 1 2 2 1 2 2 2\n", - " scaling (chain, draw, scaling_dim_0) float64 1.464 ... 2.358\n", - " process_time_diff (chain, draw) float64 0.000612 0.00028 ... 0.000477\n", - " lp (chain, draw) float64 -177.8 -177.0 ... -179.1 -175.6\n", - " perf_counter_diff (chain, draw) float64 0.0006293 0.0002809 ... 0.0004771\n", - " accept (chain, draw, accept_dim_0) float64 0.3994 ... 2.969\n", - " energy_error (chain, draw) float64 -0.005702 -0.01118 ... -0.2271\n", - " max_energy_error (chain, draw) float64 1.129 -0.01118 ... 0.174 -0.2271\n", - " mean_tree_accept (chain, draw) float64 0.7362 1.0 ... 0.9566 0.9841\n", - "Attributes:\n", - " created_at: 2021-02-08T06:29:28.933359\n", - " arviz_version: 0.11.0\n", - " inference_library: pymc3\n", - " inference_library_version: 3.11.0\n", - " sampling_time: 33.77551817893982\n", - " tuning_steps: 1000
array([0, 1])
array([ 0, 1, 2, ..., 9997, 9998, 9999])
array([0, 1])
array([0, 1])
array([0, 1])
array([[481.62788461, 481.6291168 , 481.62996704, ..., 495.36697623,\n", - " 495.36820486, 495.36928442],\n", - " [484.98033663, 484.98164721, 484.98357028, ..., 498.62839408,\n", - " 498.62960144, 498.63077901]])
array([[[False, True],\n", - " [ True, False],\n", - " [ True, False],\n", - " ...,\n", - " [False, False],\n", - " [ True, False],\n", - " [ True, False]],\n", - "\n", - " [[False, True],\n", - " [False, True],\n", - " [ True, True],\n", - " ...,\n", - " [ True, True],\n", - " [ True, True],\n", - " [ True, True]]])
array([[False, False, False, ..., False, False, False],\n", - " [False, False, False, ..., False, False, False]])
array([[0.86829042, 0.86829042, 0.86829042, ..., 0.86829042, 0.86829042,\n", - " 0.86829042],\n", - " [1.0782656 , 1.0782656 , 1.0782656 , ..., 1.0782656 , 1.0782656 ,\n", - " 1.0782656 ]])
array([[3., 1., 3., ..., 1., 1., 3.],\n", - " [3., 3., 3., ..., 3., 3., 3.]])
array([[1.13544036, 1.13544036, 1.13544036, ..., 1.13544036, 1.13544036,\n", - " 1.13544036],\n", - " [1.11179432, 1.11179432, 1.11179432, ..., 1.11179432, 1.11179432,\n", - " 1.11179432]])
array([[180.84583537, 177.02264999, 178.74471241, ..., 177.91502384,\n", - " 179.90576987, 177.29210912],\n", - " [177.80556826, 177.9962228 , 178.65102567, ..., 179.71597262,\n", - " 179.20684298, 176.38309815]])
array([[2, 1, 2, ..., 1, 1, 2],\n", - " [2, 2, 2, ..., 2, 2, 2]])
array([[[1.4641 , 2.662 ],\n", - " [1.4641 , 2.662 ],\n", - " [1.4641 , 2.662 ],\n", - " ...,\n", - " [1.4641 , 2.662 ],\n", - " [1.4641 , 2.662 ],\n", - " [1.4641 , 2.662 ]],\n", - "\n", - " [[1.331 , 2.35794769],\n", - " [1.331 , 2.35794769],\n", - " [1.331 , 2.35794769],\n", - " ...,\n", - " [1.331 , 2.35794769],\n", - " [1.331 , 2.35794769],\n", - " [1.331 , 2.35794769]]])
array([[0.000612, 0.00028 , 0.000512, ..., 0.00048 , 0.000462, 0.000682],\n", - " [0.000659, 0.000737, 0.000566, ..., 0.000642, 0.000668, 0.000477]])
array([[-177.84833139, -176.96998605, -177.62208958, ..., -177.33637597,\n", - " -178.41132849, -176.00996052],\n", - " [-177.04631713, -177.91064831, -176.3128248 , ..., -178.69018227,\n", - " -179.11188422, -175.55994681]])
array([[0.00062935, 0.00028094, 0.00051247, ..., 0.00048048, 0.00046108,\n", - " 0.00102844],\n", - " [0.00066692, 0.00077655, 0.00056549, ..., 0.00066893, 0.00068386,\n", - " 0.00047712]])
array([[[0.39939657, 1.09440673],\n", - " [2.31412349, 0.720068 ],\n", - " [1.68987729, 0.05915239],\n", - " ...,\n", - " [0. , 0.89222621],\n", - " [0.34131396, 0.18869271],\n", - " [2.92985378, 0.28010575]],\n", - "\n", - " [[0.27690895, 0.63139657],\n", - " [0.03377025, 0.43825176],\n", - " [2.89674824, 1.64828846],\n", - " ...,\n", - " [0.82312568, 1. ],\n", - " [0.41156284, 2.70066777],\n", - " [5.68686693, 2.96930198]]])
array([[-0.00570162, -0.01118174, 0.30225096, ..., -0.01023484,\n", - " 0. , -0.42615453],\n", - " [ 0.04211059, 0.01095353, 0.05170346, ..., 0. ,\n", - " 0.17401783, -0.22711861]])
array([[ 1.12866425, -0.01118174, 0.7498719 , ..., -0.01023484,\n", - " 0.35802402, -0.42615453],\n", - " [ 0.35504171, 0.02885821, 0.82411198, ..., 0.35797759,\n", - " 0.17401783, -0.22711861]])
array([[0.7361729 , 1. , 0.70539919, ..., 1. , 0.69905628,\n", - " 1. ],\n", - " [0.84416844, 0.98705606, 0.70683012, ..., 0.70433264, 0.95660635,\n", - " 0.98413996]])
<xarray.Dataset>\n", - "Dimensions: (disasters_dim_0: 111)\n", - "Coordinates:\n", - " * disasters_dim_0 (disasters_dim_0) int64 0 1 2 3 4 5 ... 106 107 108 109 110\n", - "Data variables:\n", - " disasters (disasters_dim_0) float64 4.0 5.0 4.0 0.0 ... 1.0 0.0 1.0\n", - "Attributes:\n", - " created_at: 2021-02-08T06:29:30.683682\n", - " arviz_version: 0.11.0\n", - " inference_library: pymc3\n", - " inference_library_version: 3.11.0
array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,\n", - " 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27,\n", - " 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41,\n", - " 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55,\n", - " 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69,\n", - " 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, 80, 81, 82, 83,\n", - " 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97,\n", - " 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110])
array([ 4., 5., 4., 0., 1., 4., 3., 4., 0., 6., 3., 3., 4.,\n", - " 0., 2., 6., 3., 3., 5., 4., 5., 3., 1., 4., 4., 1.,\n", - " 5., 5., 3., 4., 2., 5., 2., 2., 3., 4., 2., 1., 3.,\n", - " nan, 2., 1., 1., 1., 1., 3., 0., 0., 1., 0., 1., 1.,\n", - " 0., 0., 3., 1., 0., 3., 2., 2., 0., 1., 1., 1., 0.,\n", - " 1., 0., 1., 0., 0., 0., 2., 1., 0., 0., 0., 1., 1.,\n", - " 0., 2., 3., 3., 1., nan, 2., 1., 1., 1., 1., 2., 4.,\n", - " 2., 0., 0., 1., 4., 0., 0., 0., 1., 0., 0., 0., 0.,\n", - " 0., 1., 0., 0., 1., 0., 1.])