From 4bea62dc28b93c4c347439dae9865a767363e3d5 Mon Sep 17 00:00:00 2001 From: Marco Gorelli Date: Sat, 24 Apr 2021 15:47:32 +0100 Subject: [PATCH 1/2] use coords in dependent density regression --- .../dependent_density_regression.ipynb | 717 ++++-------------- 1 file changed, 156 insertions(+), 561 deletions(-) diff --git a/examples/mixture_models/dependent_density_regression.ipynb b/examples/mixture_models/dependent_density_regression.ipynb index 0cd667489..d9f8b420d 100644 --- a/examples/mixture_models/dependent_density_regression.ipynb +++ b/examples/mixture_models/dependent_density_regression.ipynb @@ -2,16 +2,7 @@ "cells": [ { "cell_type": "markdown", - "metadata": { - "papermill": { - "duration": 0.021485, - "end_time": "2020-12-18T15:39:09.385692", - "exception": false, - "start_time": "2020-12-18T15:39:09.364207", - "status": "completed" - }, - "tags": [] - }, + "metadata": {}, "source": [ "# Dependent density regression\n", "In another [example](dp_mix.ipynb), we showed how to use Dirichlet processes to perform Bayesian nonparametric density estimation. This example expands on the previous one, illustrating dependent density regression.\n", @@ -21,29 +12,14 @@ }, { "cell_type": "code", - "execution_count": 1, - "metadata": { - "execution": { - "iopub.execute_input": "2020-12-18T15:39:09.433230Z", - "iopub.status.busy": "2020-12-18T15:39:09.432485Z", - "iopub.status.idle": "2020-12-18T15:39:16.049724Z", - "shell.execute_reply": "2020-12-18T15:39:16.048768Z" - }, - "papermill": { - "duration": 6.644193, - "end_time": "2020-12-18T15:39:16.049848", - "exception": false, - "start_time": "2020-12-18T15:39:09.405655", - "status": "completed" - }, - "tags": [] - }, + "execution_count": 2, + "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Running on PyMC3 v3.10.0\n" + "Running on PyMC3 v3.11.2\n" ] } ], @@ -64,23 +40,8 @@ }, { "cell_type": "code", - "execution_count": 2, - "metadata": { - "execution": { - "iopub.execute_input": "2020-12-18T15:39:16.118903Z", - "iopub.status.busy": "2020-12-18T15:39:16.116655Z", - "iopub.status.idle": "2020-12-18T15:39:16.119741Z", - "shell.execute_reply": "2020-12-18T15:39:16.120266Z" - }, - "papermill": { - "duration": 0.048777, - "end_time": "2020-12-18T15:39:16.120430", - "exception": false, - "start_time": "2020-12-18T15:39:16.071653", - "status": "completed" - }, - "tags": [] - }, + "execution_count": 3, + "metadata": {}, "outputs": [], "source": [ "%config InlineBackend.figure_format = 'retina'\n", @@ -93,39 +54,15 @@ }, { "cell_type": "markdown", - "metadata": { - "papermill": { - "duration": 0.020257, - "end_time": "2020-12-18T15:39:16.160652", - "exception": false, - "start_time": "2020-12-18T15:39:16.140395", - "status": "completed" - }, - "tags": [] - }, + "metadata": {}, "source": [ "We will use the LIDAR data set from Larry Wasserman's excellent book, [_All of Nonparametric Statistics_](http://www.stat.cmu.edu/~larry/all-of-nonpar/). We standardize the data set to improve the rate of convergence of our samples." ] }, { "cell_type": "code", - "execution_count": 3, - "metadata": { - "execution": { - "iopub.execute_input": "2020-12-18T15:39:16.210611Z", - "iopub.status.busy": "2020-12-18T15:39:16.209819Z", - "iopub.status.idle": "2020-12-18T15:39:16.708850Z", - "shell.execute_reply": "2020-12-18T15:39:16.708223Z" - }, - "papermill": { - "duration": 0.527789, - "end_time": "2020-12-18T15:39:16.708982", - "exception": false, - "start_time": "2020-12-18T15:39:16.181193", - "status": "completed" - }, - "tags": [] - }, + "execution_count": 4, + "metadata": {}, "outputs": [], "source": [ "DATA_URI = \"http://www.stat.cmu.edu/~larry/all-of-nonpar/=data/lidar.dat\"\n", @@ -142,23 +79,8 @@ }, { "cell_type": "code", - "execution_count": 4, - "metadata": { - "execution": { - "iopub.execute_input": "2020-12-18T15:39:16.761187Z", - "iopub.status.busy": "2020-12-18T15:39:16.760225Z", - "iopub.status.idle": "2020-12-18T15:39:16.769609Z", - "shell.execute_reply": "2020-12-18T15:39:16.770070Z" - }, - "papermill": { - "duration": 0.040034, - "end_time": "2020-12-18T15:39:16.770185", - "exception": false, - "start_time": "2020-12-18T15:39:16.730151", - "status": "completed" - }, - "tags": [] - }, + "execution_count": 5, + "metadata": {}, "outputs": [ { "data": { @@ -236,7 +158,7 @@ "4 396 -0.059913 -1.655168 0.818631" ] }, - "execution_count": 4, + "execution_count": 5, "metadata": {}, "output_type": "execute_result" } @@ -247,39 +169,15 @@ }, { "cell_type": "markdown", - "metadata": { - "papermill": { - "duration": 0.021121, - "end_time": "2020-12-18T15:39:16.811381", - "exception": false, - "start_time": "2020-12-18T15:39:16.790260", - "status": "completed" - }, - "tags": [] - }, + "metadata": {}, "source": [ "We plot the LIDAR data below." ] }, { "cell_type": "code", - "execution_count": 5, - "metadata": { - "execution": { - "iopub.execute_input": "2020-12-18T15:39:16.858458Z", - "iopub.status.busy": "2020-12-18T15:39:16.857709Z", - "iopub.status.idle": "2020-12-18T15:39:17.276172Z", - "shell.execute_reply": "2020-12-18T15:39:17.276774Z" - }, - "papermill": { - "duration": 0.444885, - "end_time": "2020-12-18T15:39:17.276934", - "exception": false, - "start_time": "2020-12-18T15:39:16.832049", - "status": "completed" - }, - "tags": [] - }, + "execution_count": 6, + "metadata": {}, "outputs": [ { "data": { @@ -311,16 +209,7 @@ }, { "cell_type": "markdown", - "metadata": { - "papermill": { - "duration": 0.022903, - "end_time": "2020-12-18T15:39:17.324348", - "exception": false, - "start_time": "2020-12-18T15:39:17.301445", - "status": "completed" - }, - "tags": [] - }, + "metadata": {}, "source": [ "This data set has a two interesting properties that make it useful for illustrating dependent density regression.\n", "\n", @@ -332,23 +221,8 @@ }, { "cell_type": "code", - "execution_count": 6, - "metadata": { - "execution": { - "iopub.execute_input": "2020-12-18T15:39:17.394681Z", - "iopub.status.busy": "2020-12-18T15:39:17.393643Z", - "iopub.status.idle": "2020-12-18T15:39:18.315621Z", - "shell.execute_reply": "2020-12-18T15:39:18.314834Z" - }, - "papermill": { - "duration": 0.968079, - "end_time": "2020-12-18T15:39:18.315769", - "exception": false, - "start_time": "2020-12-18T15:39:17.347690", - "status": "completed" - }, - "tags": [] - }, + "execution_count": 7, + "metadata": {}, "outputs": [], "source": [ "fig, (scatter_ax, hist_ax) = plt.subplots(ncols=2, figsize=(16, 6))\n", @@ -392,23 +266,8 @@ }, { "cell_type": "code", - "execution_count": 7, - "metadata": { - "execution": { - "iopub.execute_input": "2020-12-18T15:39:18.367248Z", - "iopub.status.busy": "2020-12-18T15:39:18.366431Z", - "iopub.status.idle": "2020-12-18T15:39:22.854668Z", - "shell.execute_reply": "2020-12-18T15:39:22.853747Z" - }, - "papermill": { - "duration": 4.515607, - "end_time": "2020-12-18T15:39:22.854784", - "exception": false, - "start_time": "2020-12-18T15:39:18.339177", - "status": "completed" - }, - "tags": [] - }, + "execution_count": 8, + "metadata": {}, "outputs": [ { "data": { @@ -1442,7 +1301,7 @@ "" ] }, - "execution_count": 7, + "execution_count": 8, "metadata": {}, "output_type": "execute_result" } @@ -1453,16 +1312,7 @@ }, { "cell_type": "markdown", - "metadata": { - "papermill": { - "duration": 0.026205, - "end_time": "2020-12-18T15:39:22.908143", - "exception": false, - "start_time": "2020-12-18T15:39:22.881938", - "status": "completed" - }, - "tags": [] - }, + "metadata": {}, "source": [ "As we slice the data with a window sliding along the x-axis in the left plot, the empirical distribution of the y-values of the points in the window varies in the right plot. An important aspect of this approach is that the density estimates that correspond to close values of the predictor are similar.\n", "\n", @@ -1489,23 +1339,8 @@ }, { "cell_type": "code", - "execution_count": 8, - "metadata": { - "execution": { - "iopub.execute_input": "2020-12-18T15:39:22.972869Z", - "iopub.status.busy": "2020-12-18T15:39:22.970896Z", - "iopub.status.idle": "2020-12-18T15:39:22.973674Z", - "shell.execute_reply": "2020-12-18T15:39:22.974219Z" - }, - "papermill": { - "duration": 0.038654, - "end_time": "2020-12-18T15:39:22.974358", - "exception": false, - "start_time": "2020-12-18T15:39:22.935704", - "status": "completed" - }, - "tags": [] - }, + "execution_count": 9, + "metadata": {}, "outputs": [], "source": [ "def norm_cdf(z):\n", @@ -1520,23 +1355,8 @@ }, { "cell_type": "code", - "execution_count": 9, - "metadata": { - "execution": { - "iopub.execute_input": "2020-12-18T15:39:23.038435Z", - "iopub.status.busy": "2020-12-18T15:39:23.037472Z", - "iopub.status.idle": "2020-12-18T15:39:53.434884Z", - "shell.execute_reply": "2020-12-18T15:39:53.433651Z" - }, - "papermill": { - "duration": 30.433368, - "end_time": "2020-12-18T15:39:53.435014", - "exception": false, - "start_time": "2020-12-18T15:39:23.001646", - "status": "completed" - }, - "tags": [] - }, + "execution_count": 10, + "metadata": {}, "outputs": [], "source": [ "N, _ = df.shape\n", @@ -1545,26 +1365,17 @@ "std_range = df.std_range.values[:, np.newaxis]\n", "std_logratio = df.std_logratio.values\n", "\n", - "with pm.Model() as model:\n", - " alpha = pm.Normal(\"alpha\", 0.0, 5.0, shape=K)\n", + "with pm.Model(coords={\"N\": np.arange(N), \"K\": np.arange(K) + 1}) as model:\n", + " alpha = pm.Normal(\"alpha\", 0.0, 5.0, dims=\"K\")\n", " beta = pm.Normal(\"beta\", 0.0, 5.0, shape=(1, K))\n", " x = pm.Data(\"x\", std_range)\n", " v = norm_cdf(alpha + pm.math.dot(x, beta))\n", - " w = pm.Deterministic(\"w\", stick_breaking(v))" + " w = pm.Deterministic(\"w\", stick_breaking(v), dims=[\"N\", \"K\"])" ] }, { "cell_type": "markdown", - "metadata": { - "papermill": { - "duration": 0.026426, - "end_time": "2020-12-18T15:39:53.491766", - "exception": false, - "start_time": "2020-12-18T15:39:53.465340", - "status": "completed" - }, - "tags": [] - }, + "metadata": {}, "source": [ "We have defined `x` as a `pm.Data` container in order to use `PyMC3`'s posterior prediction capabilities later.\n", "\n", @@ -1588,23 +1399,8 @@ }, { "cell_type": "code", - "execution_count": 10, - "metadata": { - "execution": { - "iopub.execute_input": "2020-12-18T15:39:53.557290Z", - "iopub.status.busy": "2020-12-18T15:39:53.556204Z", - "iopub.status.idle": "2020-12-18T15:39:53.787741Z", - "shell.execute_reply": "2020-12-18T15:39:53.787110Z" - }, - "papermill": { - "duration": 0.26821, - "end_time": "2020-12-18T15:39:53.787869", - "exception": false, - "start_time": "2020-12-18T15:39:53.519659", - "status": "completed" - }, - "tags": [] - }, + "execution_count": 11, + "metadata": {}, "outputs": [], "source": [ "with model:\n", @@ -1615,39 +1411,15 @@ }, { "cell_type": "markdown", - "metadata": { - "papermill": { - "duration": 0.030827, - "end_time": "2020-12-18T15:39:53.849720", - "exception": false, - "start_time": "2020-12-18T15:39:53.818893", - "status": "completed" - }, - "tags": [] - }, + "metadata": {}, "source": [ "Finally, we place the prior $\\tau_i \\sim \\textrm{Gamma}(1, 1)$ on the component precisions." ] }, { "cell_type": "code", - "execution_count": 11, - "metadata": { - "execution": { - "iopub.execute_input": "2020-12-18T15:39:53.921597Z", - "iopub.status.busy": "2020-12-18T15:39:53.920393Z", - "iopub.status.idle": "2020-12-18T15:40:49.643976Z", - "shell.execute_reply": "2020-12-18T15:40:49.644515Z" - }, - "papermill": { - "duration": 55.765002, - "end_time": "2020-12-18T15:40:49.644639", - "exception": false, - "start_time": "2020-12-18T15:39:53.879637", - "status": "completed" - }, - "tags": [] - }, + "execution_count": 12, + "metadata": {}, "outputs": [ { "data": { @@ -1665,177 +1437,177 @@ "\n", "\n", "cluster20\n", - "\n", - "20\n", + "\n", + "20\n", "\n", "\n", "cluster1 x 20\n", - "\n", - "1 x 20\n", + "\n", + "1 x 20\n", "\n", "\n", "cluster221 x 1\n", - "\n", - "221 x 1\n", + "\n", + "221 x 1\n", "\n", "\n", "cluster221 x 20\n", - "\n", - "221 x 20\n", + "\n", + "221 x 20\n", "\n", "\n", "cluster221\n", - "\n", - "221\n", - "\n", - "\n", - "\n", - "tau\n", - "\n", - "tau\n", - "~\n", - "Gamma\n", - "\n", - "\n", - "\n", - "obs\n", - "\n", - "obs\n", - "~\n", - "NormalMixture\n", - "\n", - "\n", - "\n", - "tau->obs\n", - "\n", - "\n", + "\n", + "221\n", "\n", "\n", - "\n", + "\n", "gamma\n", - "\n", - "gamma\n", - "~\n", - "Normal\n", + "\n", + "gamma\n", + "~\n", + "Normal\n", "\n", "\n", "\n", "mu\n", - "\n", - "mu\n", - "~\n", - "Deterministic\n", + "\n", + "mu\n", + "~\n", + "Deterministic\n", "\n", "\n", - "\n", + "\n", "gamma->mu\n", - "\n", - "\n", + "\n", + "\n", "\n", "\n", - "\n", + "\n", "alpha\n", - "\n", - "alpha\n", - "~\n", - "Normal\n", + "\n", + "alpha\n", + "~\n", + "Normal\n", "\n", "\n", "\n", "w\n", - "\n", - "w\n", - "~\n", - "Deterministic\n", + "\n", + "w\n", + "~\n", + "Deterministic\n", "\n", "\n", - "\n", + "\n", "alpha->w\n", - "\n", - "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "tau\n", + "\n", + "tau\n", + "~\n", + "Gamma\n", + "\n", + "\n", + "\n", + "obs\n", + "\n", + "obs\n", + "~\n", + "NormalMixture\n", + "\n", + "\n", + "\n", + "tau->obs\n", + "\n", + "\n", "\n", "\n", "\n", "beta\n", - "\n", - "beta\n", - "~\n", - "Normal\n", + "\n", + "beta\n", + "~\n", + "Normal\n", "\n", "\n", - "\n", + "\n", "beta->w\n", - "\n", - "\n", + "\n", + "\n", "\n", "\n", "\n", "delta\n", - "\n", - "delta\n", - "~\n", - "Normal\n", + "\n", + "delta\n", + "~\n", + "Normal\n", "\n", "\n", - "\n", + "\n", "delta->mu\n", - "\n", - "\n", + "\n", + "\n", "\n", "\n", "\n", "x\n", - "\n", - "x\n", - "~\n", - "Data\n", + "\n", + "x\n", + "~\n", + "Data\n", "\n", "\n", - "\n", + "\n", "x->w\n", - "\n", - "\n", + "\n", + "\n", "\n", "\n", - "\n", + "\n", "x->mu\n", - "\n", - "\n", + "\n", + "\n", "\n", "\n", - "\n", + "\n", "w->obs\n", - "\n", - "\n", + "\n", + "\n", "\n", "\n", - "\n", + "\n", "mu->obs\n", - "\n", - "\n", + "\n", + "\n", "\n", "\n", - "\n", + "\n", "y\n", - "\n", - "y\n", - "~\n", - "Data\n", + "\n", + "y\n", + "~\n", + "Data\n", "\n", "\n", "\n", "obs->y\n", - "\n", - "\n", + "\n", + "\n", "\n", "\n", "\n" ], "text/plain": [ - "" + "" ] }, - "execution_count": 11, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } @@ -1851,39 +1623,15 @@ }, { "cell_type": "markdown", - "metadata": { - "papermill": { - "duration": 0.027851, - "end_time": "2020-12-18T15:40:49.700524", - "exception": false, - "start_time": "2020-12-18T15:40:49.672673", - "status": "completed" - }, - "tags": [] - }, + "metadata": {}, "source": [ "We now sample from the dependent density regression model." ] }, { "cell_type": "code", - "execution_count": 12, - "metadata": { - "execution": { - "iopub.execute_input": "2020-12-18T15:40:49.763101Z", - "iopub.status.busy": "2020-12-18T15:40:49.762286Z", - "iopub.status.idle": "2020-12-18T15:48:50.089080Z", - "shell.execute_reply": "2020-12-18T15:48:50.087914Z" - }, - "papermill": { - "duration": 480.361556, - "end_time": "2020-12-18T15:48:50.089232", - "exception": false, - "start_time": "2020-12-18T15:40:49.727676", - "status": "completed" - }, - "tags": [] - }, + "execution_count": 13, + "metadata": {}, "outputs": [ { "data": { @@ -1903,7 +1651,7 @@ " }\n", " \n", " \n", - " 100.00% [30000/30000 03:12<00:00 Sampling chain 0, 0 divergences]\n", + " 100.00% [30000/30000 02:48<00:00 Sampling chain 0, 0 divergences]\n", " \n", " " ], @@ -1932,7 +1680,7 @@ " }\n", " \n", " \n", - " 100.00% [30000/30000 03:06<00:00 Sampling chain 1, 0 divergences]\n", + " 100.00% [30000/30000 02:43<00:00 Sampling chain 1, 0 divergences]\n", " \n", " " ], @@ -1950,79 +1698,20 @@ "\n", "with model:\n", " step = pm.Metropolis()\n", - " trace = pm.sample(SAMPLES, step=step, tune=BURN, random_seed=SEED, return_inferencedata=True)" + " trace = pm.sample(SAMPLES, tune=BURN, step=step, random_seed=SEED, return_inferencedata=True)" ] }, { "cell_type": "markdown", - "metadata": { - "papermill": { - "duration": 0.029328, - "end_time": "2020-12-18T15:48:50.149503", - "exception": false, - "start_time": "2020-12-18T15:48:50.120175", - "status": "completed" - }, - "tags": [] - }, + "metadata": {}, "source": [ "To verify that truncation did not unduly influence our results, we plot the largest posterior expected mixture weight for each component. (In this model, each point has a mixture weight for each component, so we plot the maximum mixture weight for each component across all data points in order to judge if the component exerts any influence on the posterior.)" ] }, - { - "cell_type": "code", - "execution_count": 13, - "metadata": { - "execution": { - "iopub.execute_input": "2020-12-18T15:48:50.215378Z", - "iopub.status.busy": "2020-12-18T15:48:50.214431Z", - "iopub.status.idle": "2020-12-18T15:48:50.218196Z", - "shell.execute_reply": "2020-12-18T15:48:50.218804Z" - }, - "papermill": { - "duration": 0.040946, - "end_time": "2020-12-18T15:48:50.218924", - "exception": false, - "start_time": "2020-12-18T15:48:50.177978", - "status": "completed" - }, - "tags": [] - }, - "outputs": [ - { - "data": { - "text/plain": [ - "(2, 20000, 221, 20)" - ] - }, - "execution_count": 13, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "trace.posterior[\"w\"].shape # (n_chains, n_samples, N, K)" - ] - }, { "cell_type": "code", "execution_count": 14, - "metadata": { - "execution": { - "iopub.execute_input": "2020-12-18T15:48:50.285924Z", - "iopub.status.busy": "2020-12-18T15:48:50.284921Z", - "iopub.status.idle": "2020-12-18T15:48:50.844573Z", - "shell.execute_reply": "2020-12-18T15:48:50.845231Z" - }, - "papermill": { - "duration": 0.5966, - "end_time": "2020-12-18T15:48:50.845425", - "exception": false, - "start_time": "2020-12-18T15:48:50.248825", - "status": "completed" - }, - "tags": [] - }, + "metadata": {}, "outputs": [ { "data": { @@ -2043,10 +1732,8 @@ "source": [ "fig, ax = plt.subplots(figsize=(8, 6))\n", "\n", - "w_posterior = trace.posterior[\"w\"].data\n", - "# reshape to (n_chains*n_samples, N, K)\n", - "w_posterior = w_posterior.reshape(-1, w_posterior.shape[2], w_posterior.shape[3])\n", - "ax.bar(np.arange(K) + 1, w_posterior.mean(axis=0).max(axis=0))\n", + "max_mixture_weights = trace.posterior[\"w\"].mean((\"chain\", \"draw\")).max(\"N\")\n", + "ax.bar(max_mixture_weights.coords.to_index(), max_mixture_weights)\n", "\n", "ax.set_xlim(1 - 0.5, K + 0.5)\n", "ax.set_xticks(np.arange(0, K, 2) + 1)\n", @@ -2057,16 +1744,7 @@ }, { "cell_type": "markdown", - "metadata": { - "papermill": { - "duration": 0.031097, - "end_time": "2020-12-18T15:48:50.911683", - "exception": false, - "start_time": "2020-12-18T15:48:50.880586", - "status": "completed" - }, - "tags": [] - }, + "metadata": {}, "source": [ "Since only three mixture components have appreciable posterior expected weight for any data point, we can be fairly certain that truncation did not unduly influence our results. (If most components had appreciable posterior expected weight, truncation may have influenced the results, and we would have increased the number of components and sampled again.)\n", "\n", @@ -2076,28 +1754,13 @@ { "cell_type": "code", "execution_count": 15, - "metadata": { - "execution": { - "iopub.execute_input": "2020-12-18T15:48:51.606803Z", - "iopub.status.busy": "2020-12-18T15:48:51.130587Z", - "iopub.status.idle": "2020-12-18T15:52:56.720126Z", - "shell.execute_reply": "2020-12-18T15:52:56.720612Z" - }, - "papermill": { - "duration": 245.777219, - "end_time": "2020-12-18T15:52:56.720777", - "exception": false, - "start_time": "2020-12-18T15:48:50.943558", - "status": "completed" - }, - "tags": [] - }, + "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "/opt/conda/lib/python3.7/site-packages/pymc3/sampling.py:1691: UserWarning: samples parameter is smaller than nchains times ndraws, some draws and/or chains may not be represented in the returned posterior predictive sample\n", + "/opt/conda/lib/python3.7/site-packages/pymc3/sampling.py:1690: UserWarning: samples parameter is smaller than nchains times ndraws, some draws and/or chains may not be represented in the returned posterior predictive sample\n", " \"samples parameter is smaller than nchains times ndraws, some draws \"\n" ] }, @@ -2119,7 +1782,7 @@ " }\n", " \n", " \n", - " 100.00% [5000/5000 03:21<00:00]\n", + " 100.00% [5000/5000 02:36<00:00]\n", " \n", " " ], @@ -2143,16 +1806,7 @@ }, { "cell_type": "markdown", - "metadata": { - "papermill": { - "duration": 0.032508, - "end_time": "2020-12-18T15:52:56.786756", - "exception": false, - "start_time": "2020-12-18T15:52:56.754248", - "status": "completed" - }, - "tags": [] - }, + "metadata": {}, "source": [ "Below we plot the posterior expected value and the 95% posterior credible interval." ] @@ -2160,22 +1814,7 @@ { "cell_type": "code", "execution_count": 16, - "metadata": { - "execution": { - "iopub.execute_input": "2020-12-18T15:52:56.864375Z", - "iopub.status.busy": "2020-12-18T15:52:56.863260Z", - "iopub.status.idle": "2020-12-18T15:52:57.294728Z", - "shell.execute_reply": "2020-12-18T15:52:57.295278Z" - }, - "papermill": { - "duration": 0.475588, - "end_time": "2020-12-18T15:52:57.295445", - "exception": false, - "start_time": "2020-12-18T15:52:56.819857", - "status": "completed" - }, - "tags": [] - }, + "metadata": {}, "outputs": [ { "data": { @@ -2217,16 +1856,7 @@ }, { "cell_type": "markdown", - "metadata": { - "papermill": { - "duration": 0.037029, - "end_time": "2020-12-18T15:52:57.371211", - "exception": false, - "start_time": "2020-12-18T15:52:57.334182", - "status": "completed" - }, - "tags": [] - }, + "metadata": {}, "source": [ "The model has fit the linear components of the data well, and also accomodated its heteroskedasticity. This flexibility, along with the ability to modularly specify the conditional mixture weights and conditional component densities, makes dependent density regression an extremely useful nonparametric Bayesian model.\n", "\n", @@ -2239,43 +1869,28 @@ }, { "cell_type": "code", - "execution_count": 18, - "metadata": { - "execution": { - "iopub.execute_input": "2020-12-18T15:53:09.107982Z", - "iopub.status.busy": "2020-12-18T15:53:09.106908Z", - "iopub.status.idle": "2020-12-18T15:53:09.194464Z", - "shell.execute_reply": "2020-12-18T15:53:09.193888Z" - }, - "papermill": { - "duration": 0.137799, - "end_time": "2020-12-18T15:53:09.194592", - "exception": false, - "start_time": "2020-12-18T15:53:09.056793", - "status": "completed" - }, - "tags": [] - }, + "execution_count": 17, + "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Last updated: Fri Dec 18 2020\n", + "Last updated: Sat Apr 24 2021\n", "\n", "Python implementation: CPython\n", "Python version : 3.7.6\n", "IPython version : 7.13.0\n", "\n", "pandas : 1.1.5\n", - "seaborn : 0.10.0\n", "matplotlib: 3.2.1\n", - "theano : 1.0.5\n", - "arviz : 0.10.0\n", + "seaborn : 0.10.0\n", + "theano : 1.1.2\n", + "arviz : 0.11.2\n", "numpy : 1.17.5\n", - "pymc3 : 3.10.0\n", + "pymc3 : 3.11.2\n", "\n", - "Watermark: 2.1.0\n", + "Watermark: 2.2.0\n", "\n" ] } @@ -2288,16 +1903,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "papermill": { - "duration": 0.042162, - "end_time": "2020-12-18T15:53:09.280948", - "exception": false, - "start_time": "2020-12-18T15:53:09.238786", - "status": "completed" - }, - "tags": [] - }, + "metadata": {}, "outputs": [], "source": [] } @@ -2319,17 +1925,6 @@ "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.6" - }, - "papermill": { - "duration": 844.57815, - "end_time": "2020-12-18T15:53:09.535936", - "environment_variables": {}, - "exception": null, - "input_path": "__notebook__.ipynb", - "output_path": "__notebook__.ipynb", - "parameters": {}, - "start_time": "2020-12-18T15:39:04.957786", - "version": "2.1.0" } }, "nbformat": 4, From 1de39c2e286ff926238cdd31707128aade37a7de Mon Sep 17 00:00:00 2001 From: Marco Gorelli Date: Sat, 24 Apr 2021 19:10:27 +0100 Subject: [PATCH 2/2] use {one} in coords --- .../dependent_density_regression.ipynb | 210 +++++++++--------- 1 file changed, 105 insertions(+), 105 deletions(-) diff --git a/examples/mixture_models/dependent_density_regression.ipynb b/examples/mixture_models/dependent_density_regression.ipynb index d9f8b420d..c5628c0d5 100644 --- a/examples/mixture_models/dependent_density_regression.ipynb +++ b/examples/mixture_models/dependent_density_regression.ipynb @@ -1359,15 +1359,15 @@ "metadata": {}, "outputs": [], "source": [ - "N, _ = df.shape\n", + "N = len(df)\n", "K = 20\n", "\n", "std_range = df.std_range.values[:, np.newaxis]\n", "std_logratio = df.std_logratio.values\n", "\n", - "with pm.Model(coords={\"N\": np.arange(N), \"K\": np.arange(K) + 1}) as model:\n", + "with pm.Model(coords={\"N\": np.arange(N), \"K\": np.arange(K) + 1, \"one\": [1]}) as model:\n", " alpha = pm.Normal(\"alpha\", 0.0, 5.0, dims=\"K\")\n", - " beta = pm.Normal(\"beta\", 0.0, 5.0, shape=(1, K))\n", + " beta = pm.Normal(\"beta\", 0.0, 5.0, dims=(\"one\", \"K\"))\n", " x = pm.Data(\"x\", std_range)\n", " v = norm_cdf(alpha + pm.math.dot(x, beta))\n", " w = pm.Deterministic(\"w\", stick_breaking(v), dims=[\"N\", \"K\"])" @@ -1404,8 +1404,8 @@ "outputs": [], "source": [ "with model:\n", - " gamma = pm.Normal(\"gamma\", 0.0, 10.0, shape=K)\n", - " delta = pm.Normal(\"delta\", 0.0, 10.0, shape=(1, K))\n", + " gamma = pm.Normal(\"gamma\", 0.0, 10.0, dims=\"K\")\n", + " delta = pm.Normal(\"delta\", 0.0, 10.0, dims=(\"one\", \"K\"))\n", " mu = pm.Deterministic(\"mu\", gamma + pm.math.dot(x, delta))" ] }, @@ -1437,174 +1437,174 @@ "\n", "\n", "cluster20\n", - "\n", - "20\n", + "\n", + "20\n", "\n", "\n", "cluster1 x 20\n", - "\n", - "1 x 20\n", + "\n", + "1 x 20\n", "\n", "\n", "cluster221 x 1\n", - "\n", - "221 x 1\n", + "\n", + "221 x 1\n", "\n", "\n", "cluster221 x 20\n", - "\n", - "221 x 20\n", + "\n", + "221 x 20\n", "\n", "\n", "cluster221\n", - "\n", - "221\n", + "\n", + "221\n", "\n", - "\n", + "\n", "\n", - "gamma\n", - "\n", - "gamma\n", - "~\n", - "Normal\n", + "tau\n", + "\n", + "tau\n", + "~\n", + "Gamma\n", "\n", - "\n", - "\n", - "mu\n", - "\n", - "mu\n", - "~\n", - "Deterministic\n", + "\n", + "\n", + "obs\n", + "\n", + "obs\n", + "~\n", + "NormalMixture\n", "\n", - "\n", - "\n", - "gamma->mu\n", - "\n", - "\n", + "\n", + "\n", + "tau->obs\n", + "\n", + "\n", "\n", "\n", "\n", "alpha\n", - "\n", - "alpha\n", - "~\n", - "Normal\n", + "\n", + "alpha\n", + "~\n", + "Normal\n", "\n", "\n", "\n", "w\n", - "\n", - "w\n", - "~\n", - "Deterministic\n", + "\n", + "w\n", + "~\n", + "Deterministic\n", "\n", "\n", - "\n", + "\n", "alpha->w\n", - "\n", - "\n", + "\n", + "\n", "\n", - "\n", + "\n", "\n", - "tau\n", - "\n", - "tau\n", - "~\n", - "Gamma\n", + "gamma\n", + "\n", + "gamma\n", + "~\n", + "Normal\n", "\n", - "\n", - "\n", - "obs\n", - "\n", - "obs\n", - "~\n", - "NormalMixture\n", + "\n", + "\n", + "mu\n", + "\n", + "mu\n", + "~\n", + "Deterministic\n", "\n", - "\n", - "\n", - "tau->obs\n", - "\n", - "\n", + "\n", + "\n", + "gamma->mu\n", + "\n", + "\n", "\n", "\n", "\n", "beta\n", - "\n", - "beta\n", - "~\n", - "Normal\n", + "\n", + "beta\n", + "~\n", + "Normal\n", "\n", "\n", - "\n", + "\n", "beta->w\n", - "\n", - "\n", + "\n", + "\n", "\n", "\n", "\n", "delta\n", - "\n", - "delta\n", - "~\n", - "Normal\n", + "\n", + "delta\n", + "~\n", + "Normal\n", "\n", "\n", "\n", "delta->mu\n", - "\n", - "\n", + "\n", + "\n", "\n", "\n", "\n", "x\n", - "\n", - "x\n", - "~\n", - "Data\n", + "\n", + "x\n", + "~\n", + "Data\n", "\n", "\n", - "\n", + "\n", "x->w\n", - "\n", - "\n", + "\n", + "\n", "\n", "\n", - "\n", + "\n", "x->mu\n", - "\n", - "\n", + "\n", + "\n", "\n", "\n", "\n", "w->obs\n", - "\n", - "\n", + "\n", + "\n", "\n", "\n", "\n", "mu->obs\n", - "\n", - "\n", + "\n", + "\n", "\n", "\n", "\n", "y\n", - "\n", - "y\n", - "~\n", - "Data\n", + "\n", + "y\n", + "~\n", + "Data\n", "\n", "\n", "\n", "obs->y\n", - "\n", - "\n", + "\n", + "\n", "\n", "\n", "\n" ], "text/plain": [ - "" + "" ] }, "execution_count": 12, @@ -1614,7 +1614,7 @@ ], "source": [ "with model:\n", - " tau = pm.Gamma(\"tau\", 1.0, 1.0, shape=K)\n", + " tau = pm.Gamma(\"tau\", 1.0, 1.0, dims=\"K\")\n", " y = pm.Data(\"y\", std_logratio)\n", " obs = pm.NormalMixture(\"obs\", w, mu, tau=tau, observed=y)\n", "\n", @@ -1651,7 +1651,7 @@ " }\n", " \n", " \n", - " 100.00% [30000/30000 02:48<00:00 Sampling chain 0, 0 divergences]\n", + " 100.00% [30000/30000 02:56<00:00 Sampling chain 0, 0 divergences]\n", " \n", " " ], @@ -1680,7 +1680,7 @@ " }\n", " \n", " \n", - " 100.00% [30000/30000 02:43<00:00 Sampling chain 1, 0 divergences]\n", + " 100.00% [30000/30000 02:50<00:00 Sampling chain 1, 0 divergences]\n", " \n", " " ], @@ -1782,7 +1782,7 @@ " }\n", " \n", " \n", - " 100.00% [5000/5000 02:36<00:00]\n", + " 100.00% [5000/5000 03:23<00:00]\n", " \n", " " ], @@ -1882,13 +1882,13 @@ "Python version : 3.7.6\n", "IPython version : 7.13.0\n", "\n", - "pandas : 1.1.5\n", - "matplotlib: 3.2.1\n", - "seaborn : 0.10.0\n", - "theano : 1.1.2\n", - "arviz : 0.11.2\n", "numpy : 1.17.5\n", + "matplotlib: 3.2.1\n", "pymc3 : 3.11.2\n", + "arviz : 0.11.2\n", + "theano : 1.1.2\n", + "pandas : 1.1.5\n", + "seaborn : 0.10.0\n", "\n", "Watermark: 2.2.0\n", "\n"