From 28cbbca11815bcbced81d403dda342012afdf064 Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Sat, 24 May 2025 00:49:48 +0800 Subject: [PATCH 1/5] hacky notebook --- Rewrite VI, why not.ipynb | 751 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 751 insertions(+) create mode 100644 Rewrite VI, why not.ipynb diff --git a/Rewrite VI, why not.ipynb b/Rewrite VI, why not.ipynb new file mode 100644 index 000000000..f2a6c3794 --- /dev/null +++ b/Rewrite VI, why not.ipynb @@ -0,0 +1,751 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 6, + "id": "9f946eb4", + "metadata": {}, + "outputs": [], + "source": [ + "import pytensor.tensor as pt\n", + "\n", + "import pymc as pm" + ] + }, + { + "cell_type": "code", + "execution_count": 61, + "id": "e746bc33", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Sampling: [X, alpha, beta, sigma, y]\n" + ] + } + ], + "source": [ + "with pm.Model() as m:\n", + " X = pm.Normal(\"X\", 0, 1, size=(100, 10))\n", + " alpha = pm.Normal(\"alpha\", 100, 10)\n", + " beta = pm.Normal(\"beta\", 0, 5, size=(10,))\n", + "\n", + " mu = alpha + X @ beta\n", + " sigma = pm.Exponential(\"sigma\", 1)\n", + " y = pm.Normal(\"y\", mu=mu, sigma=sigma)\n", + "\n", + " prior = pm.sample_prior_predictive()" + ] + }, + { + "cell_type": "code", + "execution_count": 62, + "id": "a8ca0161", + "metadata": {}, + "outputs": [], + "source": [ + "draw = 123\n", + "true_params = np.r_[\n", + " prior.prior.alpha.sel(chain=0, draw=draw).values, prior.prior.beta.sel(chain=0, draw=draw)\n", + "]\n", + "X_data = prior.prior.X.sel(chain=0, draw=draw).values\n", + "y_data = prior.prior.y.sel(chain=0, draw=draw).values" + ] + }, + { + "cell_type": "code", + "execution_count": 63, + "id": "b89f4031", + "metadata": {}, + "outputs": [], + "source": [ + "m_obs = pm.observe(pm.do(m, {X: X_data}), {\"y\": y_data})" + ] + }, + { + "cell_type": "code", + "execution_count": 168, + "id": "a42a84e4", + 
"metadata": {}, + "outputs": [], + "source": [ + "Parameter = pt.tensor\n", + "\n", + "draws = pt.tensor(\"draws\", shape=(), dtype=\"int64\")\n", + "\n", + "with pm.Model() as guide_model:\n", + " X = pm.Data(\"X\", X_data)\n", + " alpha_loc = Parameter(\"alpha_loc\", shape=())\n", + " alpha_scale = Parameter(\"alpha_scale\", shape=())\n", + " alpha_z = pm.Normal(\"alpha_z\", mu=0, sigma=1, shape=(draws,))\n", + " alpha = pm.Deterministic(\"alpha\", alpha_loc + alpha_scale * alpha_z)\n", + "\n", + " beta_loc = Parameter(\"beta_loc\", shape=(10,))\n", + " beta_scale = Parameter(\"beta_scale\", shape=(10,))\n", + " beta_z = pm.Normal(\"beta_z\", mu=0, sigma=1, shape=(draws, 10))\n", + " beta = pm.Deterministic(\"beta\", beta_loc + beta_scale * beta_z)\n", + "\n", + " mu = alpha + X @ beta\n", + "\n", + " sigma_loc = Parameter(\"sigma_loc\", shape=())\n", + " sigma_scale = Parameter(\"sigma_scale\", shape=())\n", + " sigma_z = pm.Normal(\"sigma_z\", 0, 1, shape=(draws,))\n", + " sigma = pm.Deterministic(\"sigma\", pt.softplus(sigma_loc + sigma_scale * sigma_z))\n", + "\n", + "# with pm.Model() as guide_model2:\n", + "# n = 10 + 1 + 1\n", + "# loc = Parameter(\"loc\", shape=(n,))\n", + "# chol_flat = Parameter(\"chol\", shape=(n * n-1, ))\n", + "# chol = pm.expand_packed_triangular(n, chol_flat)\n", + "# latent_mvn = pm.MvNormal(\"latent_mvn\", chol=chol)\n", + "\n", + "# pm.Deterministic(\"beta\", latent_mvn[:10])\n", + "# pm.Deterministic(\"alpha\", latent_mvn[10])\n", + "# pm.Deterministic(\"sigma\", pm.math.exp(latent_mvn[11]))" + ] + }, + { + "cell_type": "code", + "execution_count": 169, + "id": "bffb1b69", + "metadata": {}, + "outputs": [], + "source": [ + "params = [alpha_loc, alpha_scale, beta_loc, beta_scale, sigma_loc, sigma_scale]" + ] + }, + { + "cell_type": "code", + "execution_count": 171, + "id": "e2c4dd95", + "metadata": {}, + "outputs": [], + "source": [ + "f_draw = pm.compile([*params, draws], guide_model.deterministics)" + ] + }, + { + "cell_type": 
"code", + "execution_count": 182, + "id": "bfe44b73", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[array([1.78150955]),\n", + " array([[ 0.22971278, 0.4621461 , 0.81535912, 0.62397751, 1.12162984,\n", + " 0.99310042, -0.04733258, 1.20791346, 0.61310399, 0.6248215 ]]),\n", + " array([1.04612599])]" + ] + }, + "execution_count": 182, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "f_draw(**param_dict, draws=1)" + ] + }, + { + "cell_type": "code", + "execution_count": 204, + "id": "e831865f", + "metadata": {}, + "outputs": [], + "source": [ + "init_dict = m_obs.initial_point()\n", + "init_dict = {k: np.expand_dims(v, 0) for k, v in init_dict.items()}\n", + "param_dict = {param.name: np.full(param.type.shape, 0.5) for param in params}" + ] + }, + { + "cell_type": "code", + "execution_count": 187, + "id": "a9f3fa0e", + "metadata": {}, + "outputs": [], + "source": [ + "from pytensor.graph.replace import graph_replace, vectorize_graph\n", + "\n", + "outputs = [m_obs.datalogp, m_obs.varlogp]\n", + "inputs = m_obs.value_vars\n", + "inputs_to_guide_rvs = {\n", + " model_value_var: guide_model[rv.name]\n", + " for rv, model_value_var in m_obs.rvs_to_values.items()\n", + " if rv not in m_obs.observed_RVs\n", + "}\n", + "model_logp = vectorize_graph(m_obs.logp(), inputs_to_guide_rvs)\n", + "guide_logq = graph_replace(guide_model.logp(), guide_model.values_to_rvs)\n", + "\n", + "elbo_loss = (guide_logq - model_logp).mean()\n", + "d_loss = pt.grad(elbo_loss, params)\n", + "\n", + "f_loss_dloss = pm.compile(params + [draws], [elbo_loss, *d_loss], trust_input=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 207, + "id": "6086f2cc", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "-993.9173656892265\n", + "-926.3296229583198\n", + "-1005.6712508438249\n", + "-919.6431050083108\n", + "-968.0575921556582\n", + "-987.30412574489\n", + "-964.7096352135038\n", + 
"-983.4362084632743\n", + "-910.8310347546233\n", + "-946.7555282915881\n", + "-943.9616597508198\n", + "-963.8766211135949\n", + "-916.6993999462212\n", + "-950.0475547851752\n", + "-967.001546433409\n", + "-943.9433767873104\n", + "-946.3015434524876\n", + "-934.0032018513697\n", + "-947.6452058569878\n", + "-893.224934118183\n", + "-979.4405814988748\n", + "-937.9997192780036\n", + "-931.1970724735944\n", + "-959.3588230003159\n", + "-932.1233734322371\n", + "-940.2791556640857\n", + "-969.4679954671045\n", + "-954.6606993395337\n", + "-982.9304227234845\n", + "-935.3404389982638\n", + "-982.6885250322749\n", + "-964.4628736035113\n", + "-939.0580477804804\n", + "-955.1672719181267\n", + "-982.0467504680682\n", + "-992.8688427985264\n", + "-967.8588846826874\n", + "-966.7655668600194\n", + "-949.3323016540423\n", + "-934.3919364553586\n", + "-1028.4493525361129\n", + "-982.4944127954707\n", + "-931.9404059809432\n", + "-981.3845063690508\n", + "-930.1688452196312\n", + "-952.0305908505434\n", + "-1012.7969628343917\n", + "-937.6379090307294\n", + "-926.5273721862864\n", + "-981.6665090046366\n", + "-980.4287957334458\n", + "-946.1849036479391\n", + "-969.4075653507354\n", + "-958.7555610935674\n", + "-1004.2802914265156\n", + "-963.0682736675113\n", + "-1000.664222867785\n", + "-934.7298341549889\n", + "-978.1088760452403\n", + "-944.0432566257858\n", + "-980.5576704437719\n", + "-969.5612703209024\n", + "-963.9839441465232\n", + "-966.4944698264674\n", + "-953.2780303578338\n", + "-953.5548215260793\n", + "-959.5050408283469\n", + "-989.7240037630827\n", + "-949.1398757808872\n", + "-964.8820244293524\n", + "-940.224485134473\n", + "-947.7591586789077\n", + "-964.679359616359\n", + "-917.0015516708557\n", + "-939.5619779661295\n", + "-993.9876761804422\n", + "-986.3960585411946\n", + "-975.6545633352575\n", + "-957.4304264740541\n", + "-970.7472516497804\n", + "-956.5223102168254\n", + "-969.825369100695\n", + "-987.9141239728567\n", + "-952.6424507446926\n", + 
"-958.6852758514054\n", + "-965.0105265156106\n", + "-962.943175120098\n", + "-946.9378229990178\n", + "-944.4992023963096\n", + "-947.4171539678032\n", + "-959.1841786549261\n", + "-931.9541021582517\n", + "-972.675454084302\n", + "-1030.6428849456215\n", + "-983.2067258579275\n", + "-1002.0505700409622\n", + "-1015.4906878612064\n", + "-962.2604231546592\n", + "-955.1833250683931\n", + "-938.1777960392025\n" + ] + } + ], + "source": [ + "learning_rate = 1e-4\n", + "for _ in range(100):\n", + " loss, *grads = f_loss_dloss(**param_dict, draws=100)\n", + " for (name, value), grad in zip(param_dict.items(), grads):\n", + " param_dict[name] = value - learning_rate * grad\n", + " print(loss)" + ] + }, + { + "cell_type": "code", + "execution_count": 208, + "id": "1251b6f1", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'alpha_loc': 0.6942395020077371,\n", + " 'alpha_scale': 0.5102778323184816,\n", + " 'beta_loc': array([0.49114368, 0.46379742, 0.52277764, 0.53227815, 0.48851862,\n", + " 0.50044113, 0.5441339 , 0.46643231, 0.47894475, 0.51122713]),\n", + " 'beta_scale': array([0.49807094, 0.49740464, 0.5005586 , 0.4991539 , 0.49749037,\n", + " 0.49825551, 0.49992182, 0.4981251 , 0.49745959, 0.49767778]),\n", + " 'sigma_loc': 4.2148859620931765,\n", + " 'sigma_scale': -5.964631972194452e-05}" + ] + }, + "execution_count": 208, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "param_dict" + ] + }, + { + "cell_type": "code", + "execution_count": 209, + "id": "11345fb4", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([83.43771778, -6.07738876, -3.3268889 , 8.38393732, 10.77212434,\n", + " -2.81776509, 0.46737085, 8.7204497 , -4.79822835, -3.47220908,\n", + " 8.76186526])" + ] + }, + "execution_count": 209, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "true_params" + ] + }, + { + "cell_type": "code", + "execution_count": 165, + "id": "566105c2", + "metadata": {}, 
+ "outputs": [ + { + "data": { + "text/plain": [ + "array(273519.9558606)" + ] + }, + "execution_count": 165, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "f_loss(**param_dict, draws=100)" + ] + }, + { + "cell_type": "code", + "execution_count": 166, + "id": "c39cd24e", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[alpha_loc, alpha_scale, beta_loc, beta_scale, sigma_loc, sigma_scale]" + ] + }, + "execution_count": 166, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "params" + ] + }, + { + "cell_type": "code", + "execution_count": 106, + "id": "310c739f", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{alpha: alpha, beta: beta, sigma_log__: sigma}" + ] + }, + "execution_count": 106, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "inputs_to_guide_inputs" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "2fe8da0a", + "metadata": {}, + "outputs": [], + "source": [ + "def compute_loss(m_obs, m_guide, beta=1.0):\n", + " return -m_obs.datalogp + beta * (m_guide.logp() - m_obs.varlogp)" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "b04b9b5c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Add [id A]\n", + " ├─ Neg [id B]\n", + " │ └─ Add [id C]\n", + " │ ├─ Sum{axes=None} [id D] '__logp'\n", + " │ │ └─ MakeVector{dtype='float64'} [id E]\n", + " │ │ └─ Sum{axes=None} [id F]\n", + " │ │ └─ Check{sigma > 0} [id G]\n", + " │ │ ├─ Sub [id H]\n", + " │ │ │ ├─ Sub [id I]\n", + " │ │ │ │ ├─ Mul [id J]\n", + " │ │ │ │ │ ├─ ExpandDims{axis=0} [id K]\n", + " │ │ │ │ │ │ └─ -0.5 [id L]\n", + " │ │ │ │ │ └─ Pow [id M]\n", + " │ │ │ │ │ ├─ True_div [id N]\n", + " │ │ │ │ │ │ ├─ Sub [id O]\n", + " │ │ │ │ │ │ │ ├─ [122.65317 ... 
.32067026] [id P]\n", + " │ │ │ │ │ │ │ └─ Add [id Q]\n", + " │ │ │ │ │ │ │ ├─ ExpandDims{axis=0} [id R]\n", + " │ │ │ │ │ │ │ │ └─ alpha [id S]\n", + " │ │ │ │ │ │ │ └─ Squeeze{axis=1} [id T]\n", + " │ │ │ │ │ │ │ └─ Blockwise{dot, (m,k),(k,n)->(m,n)} [id U]\n", + " │ │ │ │ │ │ │ ├─ X [id V]\n", + " │ │ │ │ │ │ │ └─ ExpandDims{axis=1} [id W]\n", + " │ │ │ │ │ │ │ └─ beta [id X]\n", + " │ │ │ │ │ │ └─ ExpandDims{axis=0} [id Y]\n", + " │ │ │ │ │ │ └─ Exp [id Z]\n", + " │ │ │ │ │ │ └─ sigma_log__ [id BA]\n", + " │ │ │ │ │ └─ ExpandDims{axis=0} [id BB]\n", + " │ │ │ │ │ └─ 2 [id BC]\n", + " │ │ │ │ └─ ExpandDims{axis=0} [id BD]\n", + " │ │ │ │ └─ Log [id BE]\n", + " │ │ │ │ └─ Sqrt [id BF]\n", + " │ │ │ │ └─ 6.283185307179586 [id BG]\n", + " │ │ │ └─ Log [id BH]\n", + " │ │ │ └─ ExpandDims{axis=0} [id Y]\n", + " │ │ │ └─ ···\n", + " │ │ └─ All{axes=None} [id BI]\n", + " │ │ └─ MakeVector{dtype='bool'} [id BJ]\n", + " │ │ └─ All{axes=None} [id BK]\n", + " │ │ └─ Gt [id BL]\n", + " │ │ ├─ ExpandDims{axis=0} [id Y]\n", + " │ │ │ └─ ···\n", + " │ │ └─ ExpandDims{axis=0} [id BM]\n", + " │ │ └─ 0 [id BN]\n", + " │ └─ 0.0 [id BO]\n", + " └─ Mul [id BP]\n", + " ├─ 1.0 [id BQ]\n", + " └─ Sub [id BR]\n", + " ├─ Sum{axes=None} [id BS] '__logp'\n", + " │ └─ MakeVector{dtype='float64'} [id BT]\n", + " │ ├─ Sum{axes=None} [id BU]\n", + " │ │ └─ Check{sigma > 0} [id BV] 'alpha_logprob'\n", + " │ │ ├─ Sub [id BW]\n", + " │ │ │ ├─ Sub [id BX]\n", + " │ │ │ │ ├─ Mul [id BY]\n", + " │ │ │ │ │ ├─ -0.5 [id BZ]\n", + " │ │ │ │ │ └─ Pow [id CA]\n", + " │ │ │ │ │ ├─ True_div [id CB]\n", + " │ │ │ │ │ │ ├─ Sub [id CC]\n", + " │ │ │ │ │ │ │ ├─ alpha [id CD]\n", + " │ │ │ │ │ │ │ └─ alpha_loc [id CE]\n", + " │ │ │ │ │ │ └─ alpha_scale [id CF]\n", + " │ │ │ │ │ └─ 2 [id CG]\n", + " │ │ │ │ └─ Log [id CH]\n", + " │ │ │ │ └─ Sqrt [id CI]\n", + " │ │ │ │ └─ 6.283185307179586 [id CJ]\n", + " │ │ │ └─ Log [id CK]\n", + " │ │ │ └─ alpha_scale [id CF]\n", + " │ │ └─ All{axes=None} [id CL]\n", + " │ │ 
└─ MakeVector{dtype='bool'} [id CM]\n", + " │ │ └─ All{axes=None} [id CN]\n", + " │ │ └─ Gt [id CO]\n", + " │ │ ├─ alpha_scale [id CF]\n", + " │ │ └─ 0 [id CP]\n", + " │ ├─ Sum{axes=None} [id CQ]\n", + " │ │ └─ Check{sigma > 0} [id CR] 'beta_logprob'\n", + " │ │ ├─ Sub [id CS]\n", + " │ │ │ ├─ Sub [id CT]\n", + " │ │ │ │ ├─ Mul [id CU]\n", + " │ │ │ │ │ ├─ ExpandDims{axis=0} [id CV]\n", + " │ │ │ │ │ │ └─ -0.5 [id CW]\n", + " │ │ │ │ │ └─ Pow [id CX]\n", + " │ │ │ │ │ ├─ True_div [id CY]\n", + " │ │ │ │ │ │ ├─ Sub [id CZ]\n", + " │ │ │ │ │ │ │ ├─ beta [id DA]\n", + " │ │ │ │ │ │ │ └─ beta_loc [id DB]\n", + " │ │ │ │ │ │ └─ beta_scale [id DC]\n", + " │ │ │ │ │ └─ ExpandDims{axis=0} [id DD]\n", + " │ │ │ │ │ └─ 2 [id DE]\n", + " │ │ │ │ └─ ExpandDims{axis=0} [id DF]\n", + " │ │ │ │ └─ Log [id DG]\n", + " │ │ │ │ └─ Sqrt [id DH]\n", + " │ │ │ │ └─ 6.283185307179586 [id DI]\n", + " │ │ │ └─ Log [id DJ]\n", + " │ │ │ └─ beta_scale [id DC]\n", + " │ │ └─ All{axes=None} [id DK]\n", + " │ │ └─ MakeVector{dtype='bool'} [id DL]\n", + " │ │ └─ All{axes=None} [id DM]\n", + " │ │ └─ Gt [id DN]\n", + " │ │ ├─ beta_scale [id DC]\n", + " │ │ └─ ExpandDims{axis=0} [id DO]\n", + " │ │ └─ 0 [id DP]\n", + " │ └─ Sum{axes=None} [id DQ]\n", + " │ └─ Add [id DR] 'sigma_log___logprob'\n", + " │ ├─ Check{sigma > 0} [id DS]\n", + " │ │ ├─ Sub [id DT]\n", + " │ │ │ ├─ Sub [id DU]\n", + " │ │ │ │ ├─ Mul [id DV]\n", + " │ │ │ │ │ ├─ -0.5 [id DW]\n", + " │ │ │ │ │ └─ Pow [id DX]\n", + " │ │ │ │ │ ├─ True_div [id DY]\n", + " │ │ │ │ │ │ ├─ Sub [id DZ]\n", + " │ │ │ │ │ │ │ ├─ Exp [id EA]\n", + " │ │ │ │ │ │ │ │ └─ sigma_log__ [id EB]\n", + " │ │ │ │ │ │ │ └─ sigma_loc [id EC]\n", + " │ │ │ │ │ │ └─ sigma_scale [id ED]\n", + " │ │ │ │ │ └─ 2 [id EE]\n", + " │ │ │ │ └─ Log [id EF]\n", + " │ │ │ │ └─ Sqrt [id EG]\n", + " │ │ │ │ └─ 6.283185307179586 [id EH]\n", + " │ │ │ └─ Log [id EI]\n", + " │ │ │ └─ sigma_scale [id ED]\n", + " │ │ └─ All{axes=None} [id EJ]\n", + " │ │ └─ MakeVector{dtype='bool'} 
[id EK]\n", + " │ │ └─ All{axes=None} [id EL]\n", + " │ │ └─ Gt [id EM]\n", + " │ │ ├─ sigma_scale [id ED]\n", + " │ │ └─ 0 [id EN]\n", + " │ └─ Identity [id EO] 'sigma_log___log_jacobian'\n", + " │ └─ sigma_log__ [id EB]\n", + " └─ Sum{axes=None} [id EP] '__logp'\n", + " └─ MakeVector{dtype='float64'} [id EQ]\n", + " ├─ Sum{axes=None} [id ER]\n", + " │ └─ Check{sigma > 0} [id ES] 'X_logprob'\n", + " │ ├─ Sub [id ET]\n", + " │ │ ├─ Sub [id EU]\n", + " │ │ │ ├─ Mul [id EV]\n", + " │ │ │ │ ├─ ExpandDims{axes=[0, 1]} [id EW]\n", + " │ │ │ │ │ └─ -0.5 [id EX]\n", + " │ │ │ │ └─ Pow [id EY]\n", + " │ │ │ │ ├─ True_div [id EZ]\n", + " │ │ │ │ │ ├─ Sub [id FA]\n", + " │ │ │ │ │ │ ├─ X [id V]\n", + " │ │ │ │ │ │ └─ [[0]] [id FB]\n", + " │ │ │ │ │ └─ [[1]] [id FC]\n", + " │ │ │ │ └─ ExpandDims{axes=[0, 1]} [id FD]\n", + " │ │ │ │ └─ 2 [id FE]\n", + " │ │ │ └─ ExpandDims{axes=[0, 1]} [id FF]\n", + " │ │ │ └─ Log [id FG]\n", + " │ │ │ └─ Sqrt [id FH]\n", + " │ │ │ └─ 6.283185307179586 [id FI]\n", + " │ │ └─ Log [id FJ]\n", + " │ │ └─ [[1]] [id FC]\n", + " │ └─ All{axes=None} [id FK]\n", + " │ └─ MakeVector{dtype='bool'} [id FL]\n", + " │ └─ All{axes=None} [id FM]\n", + " │ └─ Gt [id FN]\n", + " │ ├─ [[1]] [id FC]\n", + " │ └─ ExpandDims{axes=[0, 1]} [id FO]\n", + " │ └─ 0 [id FP]\n", + " ├─ Sum{axes=None} [id FQ]\n", + " │ └─ Check{sigma > 0} [id FR] 'alpha_logprob'\n", + " │ ├─ Sub [id FS]\n", + " │ │ ├─ Sub [id FT]\n", + " │ │ │ ├─ Mul [id FU]\n", + " │ │ │ │ ├─ -0.5 [id FV]\n", + " │ │ │ │ └─ Pow [id FW]\n", + " │ │ │ │ ├─ True_div [id FX]\n", + " │ │ │ │ │ ├─ Sub [id FY]\n", + " │ │ │ │ │ │ ├─ alpha [id S]\n", + " │ │ │ │ │ │ └─ 100 [id FZ]\n", + " │ │ │ │ │ └─ 10 [id GA]\n", + " │ │ │ │ └─ 2 [id GB]\n", + " │ │ │ └─ Log [id GC]\n", + " │ │ │ └─ Sqrt [id GD]\n", + " │ │ │ └─ 6.283185307179586 [id GE]\n", + " │ │ └─ Log [id GF]\n", + " │ │ └─ 10 [id GA]\n", + " │ └─ All{axes=None} [id GG]\n", + " │ └─ MakeVector{dtype='bool'} [id GH]\n", + " │ └─ All{axes=None} [id GI]\n", 
+ " │ └─ Gt [id GJ]\n", + " │ ├─ 10 [id GA]\n", + " │ └─ 0 [id GK]\n", + " ├─ Sum{axes=None} [id GL]\n", + " │ └─ Check{sigma > 0} [id GM] 'beta_logprob'\n", + " │ ├─ Sub [id GN]\n", + " │ │ ├─ Sub [id GO]\n", + " │ │ │ ├─ Mul [id GP]\n", + " │ │ │ │ ├─ ExpandDims{axis=0} [id GQ]\n", + " │ │ │ │ │ └─ -0.5 [id GR]\n", + " │ │ │ │ └─ Pow [id GS]\n", + " │ │ │ │ ├─ True_div [id GT]\n", + " │ │ │ │ │ ├─ Sub [id GU]\n", + " │ │ │ │ │ │ ├─ beta [id X]\n", + " │ │ │ │ │ │ └─ [0] [id GV]\n", + " │ │ │ │ │ └─ [5] [id GW]\n", + " │ │ │ │ └─ ExpandDims{axis=0} [id GX]\n", + " │ │ │ │ └─ 2 [id GY]\n", + " │ │ │ └─ ExpandDims{axis=0} [id GZ]\n", + " │ │ │ └─ Log [id HA]\n", + " │ │ │ └─ Sqrt [id HB]\n", + " │ │ │ └─ 6.283185307179586 [id HC]\n", + " │ │ └─ Log [id HD]\n", + " │ │ └─ [5] [id GW]\n", + " │ └─ All{axes=None} [id HE]\n", + " │ └─ MakeVector{dtype='bool'} [id HF]\n", + " │ └─ All{axes=None} [id HG]\n", + " │ └─ Gt [id HH]\n", + " │ ├─ [5] [id GW]\n", + " │ └─ ExpandDims{axis=0} [id HI]\n", + " │ └─ 0 [id HJ]\n", + " └─ Sum{axes=None} [id HK]\n", + " └─ Add [id HL] 'sigma_log___logprob'\n", + " ├─ Check{mu >= 0} [id HM]\n", + " │ ├─ Switch [id HN]\n", + " │ │ ├─ Ge [id HO]\n", + " │ │ │ ├─ Exp [id HP]\n", + " │ │ │ │ └─ sigma_log__ [id BA]\n", + " │ │ │ └─ 0.0 [id HQ]\n", + " │ │ ├─ Sub [id HR]\n", + " │ │ │ ├─ Neg [id HS]\n", + " │ │ │ │ └─ Log [id HT]\n", + " │ │ │ │ └─ 1.0 [id HU]\n", + " │ │ │ └─ True_div [id HV]\n", + " │ │ │ ├─ Exp [id HP]\n", + " │ │ │ │ └─ ···\n", + " │ │ │ └─ 1.0 [id HU]\n", + " │ │ └─ -inf [id HW]\n", + " │ └─ All{axes=None} [id HX]\n", + " │ └─ MakeVector{dtype='bool'} [id HY]\n", + " │ └─ All{axes=None} [id HZ]\n", + " │ └─ Ge [id IA]\n", + " │ ├─ 1.0 [id HU]\n", + " │ └─ 0 [id IB]\n", + " └─ Identity [id IC] 'sigma_log___log_jacobian'\n", + " └─ sigma_log__ [id BA]\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + 
"compute_loss(m_obs, guide_model).dprint()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d155685f", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "98fe5067", + "metadata": {}, + "outputs": [], + "source": [ + "## TODO:\n", + "# 1. Create hyperparameters for mean field approx (mu + sigma of normals)\n", + "# 2. Replace in the logp" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e01dcf96", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From a5f6921578281aeb1ce7dfc8b871e54f81b2a46e Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Sun, 25 May 2025 19:29:23 +0800 Subject: [PATCH 2/5] re-run notebook, show it actually works --- Rewrite VI, why not.ipynb | 589 +++++--------------------------------- 1 file changed, 75 insertions(+), 514 deletions(-) diff --git a/Rewrite VI, why not.ipynb b/Rewrite VI, why not.ipynb index f2a6c3794..6dda37553 100644 --- a/Rewrite VI, why not.ipynb +++ b/Rewrite VI, why not.ipynb @@ -2,11 +2,12 @@ "cells": [ { "cell_type": "code", - "execution_count": 6, + "execution_count": 1, "id": "9f946eb4", "metadata": {}, "outputs": [], "source": [ + "import numpy as np\n", "import pytensor.tensor as pt\n", "\n", "import pymc as pm" @@ -14,7 +15,7 @@ }, { "cell_type": "code", - "execution_count": 61, + "execution_count": 2, "id": "e746bc33", "metadata": {}, "outputs": [ @@ -41,14 +42,16 @@ }, { "cell_type": "code", - "execution_count": 62, + "execution_count": 3, "id": "a8ca0161", "metadata": {}, 
"outputs": [], "source": [ "draw = 123\n", "true_params = np.r_[\n", - " prior.prior.alpha.sel(chain=0, draw=draw).values, prior.prior.beta.sel(chain=0, draw=draw)\n", + " prior.prior.alpha.sel(chain=0, draw=draw).values,\n", + " prior.prior.beta.sel(chain=0, draw=draw),\n", + " prior.prior.sigma.sel(chain=0, draw=draw),\n", "]\n", "X_data = prior.prior.X.sel(chain=0, draw=draw).values\n", "y_data = prior.prior.y.sel(chain=0, draw=draw).values" @@ -56,7 +59,7 @@ }, { "cell_type": "code", - "execution_count": 63, + "execution_count": 4, "id": "b89f4031", "metadata": {}, "outputs": [], @@ -66,7 +69,7 @@ }, { "cell_type": "code", - "execution_count": 168, + "execution_count": 5, "id": "a42a84e4", "metadata": {}, "outputs": [], @@ -108,7 +111,7 @@ }, { "cell_type": "code", - "execution_count": 169, + "execution_count": 6, "id": "bffb1b69", "metadata": {}, "outputs": [], @@ -118,53 +121,7 @@ }, { "cell_type": "code", - "execution_count": 171, - "id": "e2c4dd95", - "metadata": {}, - "outputs": [], - "source": [ - "f_draw = pm.compile([*params, draws], guide_model.deterministics)" - ] - }, - { - "cell_type": "code", - "execution_count": 182, - "id": "bfe44b73", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[array([1.78150955]),\n", - " array([[ 0.22971278, 0.4621461 , 0.81535912, 0.62397751, 1.12162984,\n", - " 0.99310042, -0.04733258, 1.20791346, 0.61310399, 0.6248215 ]]),\n", - " array([1.04612599])]" - ] - }, - "execution_count": 182, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "f_draw(**param_dict, draws=1)" - ] - }, - { - "cell_type": "code", - "execution_count": 204, - "id": "e831865f", - "metadata": {}, - "outputs": [], - "source": [ - "init_dict = m_obs.initial_point()\n", - "init_dict = {k: np.expand_dims(v, 0) for k, v in init_dict.items()}\n", - "param_dict = {param.name: np.full(param.type.shape, 0.5) for param in params}" - ] - }, - { - "cell_type": "code", - "execution_count": 187, + "execution_count": 
7, "id": "a9f3fa0e", "metadata": {}, "outputs": [], @@ -181,547 +138,151 @@ "model_logp = vectorize_graph(m_obs.logp(), inputs_to_guide_rvs)\n", "guide_logq = graph_replace(guide_model.logp(), guide_model.values_to_rvs)\n", "\n", - "elbo_loss = (guide_logq - model_logp).mean()\n", - "d_loss = pt.grad(elbo_loss, params)\n", + "negative_elbo = (guide_logq - model_logp).mean()\n", + "d_loss = pt.grad(negative_elbo, params)\n", "\n", - "f_loss_dloss = pm.compile(params + [draws], [elbo_loss, *d_loss], trust_input=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 207, - "id": "6086f2cc", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "-993.9173656892265\n", - "-926.3296229583198\n", - "-1005.6712508438249\n", - "-919.6431050083108\n", - "-968.0575921556582\n", - "-987.30412574489\n", - "-964.7096352135038\n", - "-983.4362084632743\n", - "-910.8310347546233\n", - "-946.7555282915881\n", - "-943.9616597508198\n", - "-963.8766211135949\n", - "-916.6993999462212\n", - "-950.0475547851752\n", - "-967.001546433409\n", - "-943.9433767873104\n", - "-946.3015434524876\n", - "-934.0032018513697\n", - "-947.6452058569878\n", - "-893.224934118183\n", - "-979.4405814988748\n", - "-937.9997192780036\n", - "-931.1970724735944\n", - "-959.3588230003159\n", - "-932.1233734322371\n", - "-940.2791556640857\n", - "-969.4679954671045\n", - "-954.6606993395337\n", - "-982.9304227234845\n", - "-935.3404389982638\n", - "-982.6885250322749\n", - "-964.4628736035113\n", - "-939.0580477804804\n", - "-955.1672719181267\n", - "-982.0467504680682\n", - "-992.8688427985264\n", - "-967.8588846826874\n", - "-966.7655668600194\n", - "-949.3323016540423\n", - "-934.3919364553586\n", - "-1028.4493525361129\n", - "-982.4944127954707\n", - "-931.9404059809432\n", - "-981.3845063690508\n", - "-930.1688452196312\n", - "-952.0305908505434\n", - "-1012.7969628343917\n", - "-937.6379090307294\n", - "-926.5273721862864\n", - "-981.6665090046366\n", 
- "-980.4287957334458\n", - "-946.1849036479391\n", - "-969.4075653507354\n", - "-958.7555610935674\n", - "-1004.2802914265156\n", - "-963.0682736675113\n", - "-1000.664222867785\n", - "-934.7298341549889\n", - "-978.1088760452403\n", - "-944.0432566257858\n", - "-980.5576704437719\n", - "-969.5612703209024\n", - "-963.9839441465232\n", - "-966.4944698264674\n", - "-953.2780303578338\n", - "-953.5548215260793\n", - "-959.5050408283469\n", - "-989.7240037630827\n", - "-949.1398757808872\n", - "-964.8820244293524\n", - "-940.224485134473\n", - "-947.7591586789077\n", - "-964.679359616359\n", - "-917.0015516708557\n", - "-939.5619779661295\n", - "-993.9876761804422\n", - "-986.3960585411946\n", - "-975.6545633352575\n", - "-957.4304264740541\n", - "-970.7472516497804\n", - "-956.5223102168254\n", - "-969.825369100695\n", - "-987.9141239728567\n", - "-952.6424507446926\n", - "-958.6852758514054\n", - "-965.0105265156106\n", - "-962.943175120098\n", - "-946.9378229990178\n", - "-944.4992023963096\n", - "-947.4171539678032\n", - "-959.1841786549261\n", - "-931.9541021582517\n", - "-972.675454084302\n", - "-1030.6428849456215\n", - "-983.2067258579275\n", - "-1002.0505700409622\n", - "-1015.4906878612064\n", - "-962.2604231546592\n", - "-955.1833250683931\n", - "-938.1777960392025\n" - ] - } - ], - "source": [ - "learning_rate = 1e-4\n", - "for _ in range(100):\n", - " loss, *grads = f_loss_dloss(**param_dict, draws=100)\n", - " for (name, value), grad in zip(param_dict.items(), grads):\n", - " param_dict[name] = value - learning_rate * grad\n", - " print(loss)" + "f_loss_dloss = pm.compile(params + [draws], [negative_elbo, *d_loss], trust_input=True)" ] }, { "cell_type": "code", - "execution_count": 208, - "id": "1251b6f1", + "execution_count": 8, + "id": "37ae25b1", "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'alpha_loc': 0.6942395020077371,\n", - " 'alpha_scale': 0.5102778323184816,\n", - " 'beta_loc': array([0.49114368, 0.46379742, 0.52277764, 
0.53227815, 0.48851862,\n", - " 0.50044113, 0.5441339 , 0.46643231, 0.47894475, 0.51122713]),\n", - " 'beta_scale': array([0.49807094, 0.49740464, 0.5005586 , 0.4991539 , 0.49749037,\n", - " 0.49825551, 0.49992182, 0.4981251 , 0.49745959, 0.49767778]),\n", - " 'sigma_loc': 4.2148859620931765,\n", - " 'sigma_scale': -5.964631972194452e-05}" - ] - }, - "execution_count": 208, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "param_dict" + "init_dict = m_obs.initial_point()\n", + "init_dict = {k: np.expand_dims(v, 0) for k, v in init_dict.items()}\n", + "param_dict = {param.name: np.full(param.type.shape, 0.5) for param in params}" ] }, { "cell_type": "code", - "execution_count": 209, - "id": "11345fb4", + "execution_count": 9, + "id": "6086f2cc", "metadata": {}, "outputs": [ { - "data": { - "text/plain": [ - "array([83.43771778, -6.07738876, -3.3268889 , 8.38393732, 10.77212434,\n", - " -2.81776509, 0.46737085, 8.7204497 , -4.79822835, -3.47220908,\n", - " 8.76186526])" - ] - }, - "execution_count": 209, - "metadata": {}, - "output_type": "execute_result" + "name": "stdout", + "output_type": "stream", + "text": [ + "-8391.5148860327435\r" + ] } ], "source": [ - "true_params" + "learning_rate = 1e-5\n", + "n_iter = 60_000\n", + "loss_history = np.empty(n_iter)\n", + "for i in range(n_iter):\n", + " loss, *grads = f_loss_dloss(**param_dict, draws=500)\n", + " loss_history[i] = loss\n", + " for (name, value), grad in zip(param_dict.items(), grads):\n", + " param_dict[name] = (value - learning_rate * grad).copy()\n", + " if i % 50 == 0:\n", + " print(loss, end=\"\\r\")\n", + " if i % 10_000 == 0 and i > 0:\n", + " learning_rate = min(learning_rate * 10, 1e-3)" ] }, { "cell_type": "code", - "execution_count": 165, - "id": "566105c2", + "execution_count": 10, + "id": "650c5e39", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "array(273519.9558606)" + "[]" ] }, - "execution_count": 165, + "execution_count": 10, 
"metadata": {}, "output_type": "execute_result" - } - ], - "source": [ - "f_loss(**param_dict, draws=100)" - ] - }, - { - "cell_type": "code", - "execution_count": 166, - "id": "c39cd24e", - "metadata": {}, - "outputs": [ + }, { "data": { + "image/png": "", "text/plain": [ - "[alpha_loc, alpha_scale, beta_loc, beta_scale, sigma_loc, sigma_scale]" + "
" ] }, - "execution_count": 166, "metadata": {}, - "output_type": "execute_result" + "output_type": "display_data" } ], "source": [ - "params" + "import matplotlib.pyplot as plt\n", + "\n", + "window_size = 100\n", + "kernel = np.full(window_size, 1 / window_size)\n", + "plt.plot(np.convolve(loss_history, kernel, mode=\"valid\"))" ] }, { "cell_type": "code", - "execution_count": 106, - "id": "310c739f", + "execution_count": 11, + "id": "1251b6f1", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "{alpha: alpha, beta: beta, sigma_log__: sigma}" + "array([102.12462346, 10.74446646, 0.51969095, -7.61435818,\n", + " 8.55366616, -8.5301462 , 0.69953323, -0.55440606,\n", + " -2.43179013, -5.36278597, -1.29241817, -7.09759975])" ] }, - "execution_count": 106, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "inputs_to_guide_inputs" + "np.r_[param_dict[\"alpha_loc\"], param_dict[\"beta_loc\"], param_dict[\"sigma_loc\"]]" ] }, { "cell_type": "code", - "execution_count": 22, - "id": "2fe8da0a", - "metadata": {}, - "outputs": [], - "source": [ - "def compute_loss(m_obs, m_guide, beta=1.0):\n", - " return -m_obs.datalogp + beta * (m_guide.logp() - m_obs.varlogp)" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "id": "b04b9b5c", + "execution_count": 12, + "id": "11345fb4", "metadata": {}, "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Add [id A]\n", - " ├─ Neg [id B]\n", - " │ └─ Add [id C]\n", - " │ ├─ Sum{axes=None} [id D] '__logp'\n", - " │ │ └─ MakeVector{dtype='float64'} [id E]\n", - " │ │ └─ Sum{axes=None} [id F]\n", - " │ │ └─ Check{sigma > 0} [id G]\n", - " │ │ ├─ Sub [id H]\n", - " │ │ │ ├─ Sub [id I]\n", - " │ │ │ │ ├─ Mul [id J]\n", - " │ │ │ │ │ ├─ ExpandDims{axis=0} [id K]\n", - " │ │ │ │ │ │ └─ -0.5 [id L]\n", - " │ │ │ │ │ └─ Pow [id M]\n", - " │ │ │ │ │ ├─ True_div [id N]\n", - " │ │ │ │ │ │ ├─ Sub [id O]\n", - " │ │ │ │ │ │ │ ├─ [122.65317 ... 
.32067026] [id P]\n", - " │ │ │ │ │ │ │ └─ Add [id Q]\n", - " │ │ │ │ │ │ │ ├─ ExpandDims{axis=0} [id R]\n", - " │ │ │ │ │ │ │ │ └─ alpha [id S]\n", - " │ │ │ │ │ │ │ └─ Squeeze{axis=1} [id T]\n", - " │ │ │ │ │ │ │ └─ Blockwise{dot, (m,k),(k,n)->(m,n)} [id U]\n", - " │ │ │ │ │ │ │ ├─ X [id V]\n", - " │ │ │ │ │ │ │ └─ ExpandDims{axis=1} [id W]\n", - " │ │ │ │ │ │ │ └─ beta [id X]\n", - " │ │ │ │ │ │ └─ ExpandDims{axis=0} [id Y]\n", - " │ │ │ │ │ │ └─ Exp [id Z]\n", - " │ │ │ │ │ │ └─ sigma_log__ [id BA]\n", - " │ │ │ │ │ └─ ExpandDims{axis=0} [id BB]\n", - " │ │ │ │ │ └─ 2 [id BC]\n", - " │ │ │ │ └─ ExpandDims{axis=0} [id BD]\n", - " │ │ │ │ └─ Log [id BE]\n", - " │ │ │ │ └─ Sqrt [id BF]\n", - " │ │ │ │ └─ 6.283185307179586 [id BG]\n", - " │ │ │ └─ Log [id BH]\n", - " │ │ │ └─ ExpandDims{axis=0} [id Y]\n", - " │ │ │ └─ ···\n", - " │ │ └─ All{axes=None} [id BI]\n", - " │ │ └─ MakeVector{dtype='bool'} [id BJ]\n", - " │ │ └─ All{axes=None} [id BK]\n", - " │ │ └─ Gt [id BL]\n", - " │ │ ├─ ExpandDims{axis=0} [id Y]\n", - " │ │ │ └─ ···\n", - " │ │ └─ ExpandDims{axis=0} [id BM]\n", - " │ │ └─ 0 [id BN]\n", - " │ └─ 0.0 [id BO]\n", - " └─ Mul [id BP]\n", - " ├─ 1.0 [id BQ]\n", - " └─ Sub [id BR]\n", - " ├─ Sum{axes=None} [id BS] '__logp'\n", - " │ └─ MakeVector{dtype='float64'} [id BT]\n", - " │ ├─ Sum{axes=None} [id BU]\n", - " │ │ └─ Check{sigma > 0} [id BV] 'alpha_logprob'\n", - " │ │ ├─ Sub [id BW]\n", - " │ │ │ ├─ Sub [id BX]\n", - " │ │ │ │ ├─ Mul [id BY]\n", - " │ │ │ │ │ ├─ -0.5 [id BZ]\n", - " │ │ │ │ │ └─ Pow [id CA]\n", - " │ │ │ │ │ ├─ True_div [id CB]\n", - " │ │ │ │ │ │ ├─ Sub [id CC]\n", - " │ │ │ │ │ │ │ ├─ alpha [id CD]\n", - " │ │ │ │ │ │ │ └─ alpha_loc [id CE]\n", - " │ │ │ │ │ │ └─ alpha_scale [id CF]\n", - " │ │ │ │ │ └─ 2 [id CG]\n", - " │ │ │ │ └─ Log [id CH]\n", - " │ │ │ │ └─ Sqrt [id CI]\n", - " │ │ │ │ └─ 6.283185307179586 [id CJ]\n", - " │ │ │ └─ Log [id CK]\n", - " │ │ │ └─ alpha_scale [id CF]\n", - " │ │ └─ All{axes=None} [id CL]\n", - " │ │ 
└─ MakeVector{dtype='bool'} [id CM]\n", - " │ │ └─ All{axes=None} [id CN]\n", - " │ │ └─ Gt [id CO]\n", - " │ │ ├─ alpha_scale [id CF]\n", - " │ │ └─ 0 [id CP]\n", - " │ ├─ Sum{axes=None} [id CQ]\n", - " │ │ └─ Check{sigma > 0} [id CR] 'beta_logprob'\n", - " │ │ ├─ Sub [id CS]\n", - " │ │ │ ├─ Sub [id CT]\n", - " │ │ │ │ ├─ Mul [id CU]\n", - " │ │ │ │ │ ├─ ExpandDims{axis=0} [id CV]\n", - " │ │ │ │ │ │ └─ -0.5 [id CW]\n", - " │ │ │ │ │ └─ Pow [id CX]\n", - " │ │ │ │ │ ├─ True_div [id CY]\n", - " │ │ │ │ │ │ ├─ Sub [id CZ]\n", - " │ │ │ │ │ │ │ ├─ beta [id DA]\n", - " │ │ │ │ │ │ │ └─ beta_loc [id DB]\n", - " │ │ │ │ │ │ └─ beta_scale [id DC]\n", - " │ │ │ │ │ └─ ExpandDims{axis=0} [id DD]\n", - " │ │ │ │ │ └─ 2 [id DE]\n", - " │ │ │ │ └─ ExpandDims{axis=0} [id DF]\n", - " │ │ │ │ └─ Log [id DG]\n", - " │ │ │ │ └─ Sqrt [id DH]\n", - " │ │ │ │ └─ 6.283185307179586 [id DI]\n", - " │ │ │ └─ Log [id DJ]\n", - " │ │ │ └─ beta_scale [id DC]\n", - " │ │ └─ All{axes=None} [id DK]\n", - " │ │ └─ MakeVector{dtype='bool'} [id DL]\n", - " │ │ └─ All{axes=None} [id DM]\n", - " │ │ └─ Gt [id DN]\n", - " │ │ ├─ beta_scale [id DC]\n", - " │ │ └─ ExpandDims{axis=0} [id DO]\n", - " │ │ └─ 0 [id DP]\n", - " │ └─ Sum{axes=None} [id DQ]\n", - " │ └─ Add [id DR] 'sigma_log___logprob'\n", - " │ ├─ Check{sigma > 0} [id DS]\n", - " │ │ ├─ Sub [id DT]\n", - " │ │ │ ├─ Sub [id DU]\n", - " │ │ │ │ ├─ Mul [id DV]\n", - " │ │ │ │ │ ├─ -0.5 [id DW]\n", - " │ │ │ │ │ └─ Pow [id DX]\n", - " │ │ │ │ │ ├─ True_div [id DY]\n", - " │ │ │ │ │ │ ├─ Sub [id DZ]\n", - " │ │ │ │ │ │ │ ├─ Exp [id EA]\n", - " │ │ │ │ │ │ │ │ └─ sigma_log__ [id EB]\n", - " │ │ │ │ │ │ │ └─ sigma_loc [id EC]\n", - " │ │ │ │ │ │ └─ sigma_scale [id ED]\n", - " │ │ │ │ │ └─ 2 [id EE]\n", - " │ │ │ │ └─ Log [id EF]\n", - " │ │ │ │ └─ Sqrt [id EG]\n", - " │ │ │ │ └─ 6.283185307179586 [id EH]\n", - " │ │ │ └─ Log [id EI]\n", - " │ │ │ └─ sigma_scale [id ED]\n", - " │ │ └─ All{axes=None} [id EJ]\n", - " │ │ └─ MakeVector{dtype='bool'} 
[id EK]\n", - " │ │ └─ All{axes=None} [id EL]\n", - " │ │ └─ Gt [id EM]\n", - " │ │ ├─ sigma_scale [id ED]\n", - " │ │ └─ 0 [id EN]\n", - " │ └─ Identity [id EO] 'sigma_log___log_jacobian'\n", - " │ └─ sigma_log__ [id EB]\n", - " └─ Sum{axes=None} [id EP] '__logp'\n", - " └─ MakeVector{dtype='float64'} [id EQ]\n", - " ├─ Sum{axes=None} [id ER]\n", - " │ └─ Check{sigma > 0} [id ES] 'X_logprob'\n", - " │ ├─ Sub [id ET]\n", - " │ │ ├─ Sub [id EU]\n", - " │ │ │ ├─ Mul [id EV]\n", - " │ │ │ │ ├─ ExpandDims{axes=[0, 1]} [id EW]\n", - " │ │ │ │ │ └─ -0.5 [id EX]\n", - " │ │ │ │ └─ Pow [id EY]\n", - " │ │ │ │ ├─ True_div [id EZ]\n", - " │ │ │ │ │ ├─ Sub [id FA]\n", - " │ │ │ │ │ │ ├─ X [id V]\n", - " │ │ │ │ │ │ └─ [[0]] [id FB]\n", - " │ │ │ │ │ └─ [[1]] [id FC]\n", - " │ │ │ │ └─ ExpandDims{axes=[0, 1]} [id FD]\n", - " │ │ │ │ └─ 2 [id FE]\n", - " │ │ │ └─ ExpandDims{axes=[0, 1]} [id FF]\n", - " │ │ │ └─ Log [id FG]\n", - " │ │ │ └─ Sqrt [id FH]\n", - " │ │ │ └─ 6.283185307179586 [id FI]\n", - " │ │ └─ Log [id FJ]\n", - " │ │ └─ [[1]] [id FC]\n", - " │ └─ All{axes=None} [id FK]\n", - " │ └─ MakeVector{dtype='bool'} [id FL]\n", - " │ └─ All{axes=None} [id FM]\n", - " │ └─ Gt [id FN]\n", - " │ ├─ [[1]] [id FC]\n", - " │ └─ ExpandDims{axes=[0, 1]} [id FO]\n", - " │ └─ 0 [id FP]\n", - " ├─ Sum{axes=None} [id FQ]\n", - " │ └─ Check{sigma > 0} [id FR] 'alpha_logprob'\n", - " │ ├─ Sub [id FS]\n", - " │ │ ├─ Sub [id FT]\n", - " │ │ │ ├─ Mul [id FU]\n", - " │ │ │ │ ├─ -0.5 [id FV]\n", - " │ │ │ │ └─ Pow [id FW]\n", - " │ │ │ │ ├─ True_div [id FX]\n", - " │ │ │ │ │ ├─ Sub [id FY]\n", - " │ │ │ │ │ │ ├─ alpha [id S]\n", - " │ │ │ │ │ │ └─ 100 [id FZ]\n", - " │ │ │ │ │ └─ 10 [id GA]\n", - " │ │ │ │ └─ 2 [id GB]\n", - " │ │ │ └─ Log [id GC]\n", - " │ │ │ └─ Sqrt [id GD]\n", - " │ │ │ └─ 6.283185307179586 [id GE]\n", - " │ │ └─ Log [id GF]\n", - " │ │ └─ 10 [id GA]\n", - " │ └─ All{axes=None} [id GG]\n", - " │ └─ MakeVector{dtype='bool'} [id GH]\n", - " │ └─ All{axes=None} [id GI]\n", 
- " │ └─ Gt [id GJ]\n", - " │ ├─ 10 [id GA]\n", - " │ └─ 0 [id GK]\n", - " ├─ Sum{axes=None} [id GL]\n", - " │ └─ Check{sigma > 0} [id GM] 'beta_logprob'\n", - " │ ├─ Sub [id GN]\n", - " │ │ ├─ Sub [id GO]\n", - " │ │ │ ├─ Mul [id GP]\n", - " │ │ │ │ ├─ ExpandDims{axis=0} [id GQ]\n", - " │ │ │ │ │ └─ -0.5 [id GR]\n", - " │ │ │ │ └─ Pow [id GS]\n", - " │ │ │ │ ├─ True_div [id GT]\n", - " │ │ │ │ │ ├─ Sub [id GU]\n", - " │ │ │ │ │ │ ├─ beta [id X]\n", - " │ │ │ │ │ │ └─ [0] [id GV]\n", - " │ │ │ │ │ └─ [5] [id GW]\n", - " │ │ │ │ └─ ExpandDims{axis=0} [id GX]\n", - " │ │ │ │ └─ 2 [id GY]\n", - " │ │ │ └─ ExpandDims{axis=0} [id GZ]\n", - " │ │ │ └─ Log [id HA]\n", - " │ │ │ └─ Sqrt [id HB]\n", - " │ │ │ └─ 6.283185307179586 [id HC]\n", - " │ │ └─ Log [id HD]\n", - " │ │ └─ [5] [id GW]\n", - " │ └─ All{axes=None} [id HE]\n", - " │ └─ MakeVector{dtype='bool'} [id HF]\n", - " │ └─ All{axes=None} [id HG]\n", - " │ └─ Gt [id HH]\n", - " │ ├─ [5] [id GW]\n", - " │ └─ ExpandDims{axis=0} [id HI]\n", - " │ └─ 0 [id HJ]\n", - " └─ Sum{axes=None} [id HK]\n", - " └─ Add [id HL] 'sigma_log___logprob'\n", - " ├─ Check{mu >= 0} [id HM]\n", - " │ ├─ Switch [id HN]\n", - " │ │ ├─ Ge [id HO]\n", - " │ │ │ ├─ Exp [id HP]\n", - " │ │ │ │ └─ sigma_log__ [id BA]\n", - " │ │ │ └─ 0.0 [id HQ]\n", - " │ │ ├─ Sub [id HR]\n", - " │ │ │ ├─ Neg [id HS]\n", - " │ │ │ │ └─ Log [id HT]\n", - " │ │ │ │ └─ 1.0 [id HU]\n", - " │ │ │ └─ True_div [id HV]\n", - " │ │ │ ├─ Exp [id HP]\n", - " │ │ │ │ └─ ···\n", - " │ │ │ └─ 1.0 [id HU]\n", - " │ │ └─ -inf [id HW]\n", - " │ └─ All{axes=None} [id HX]\n", - " │ └─ MakeVector{dtype='bool'} [id HY]\n", - " │ └─ All{axes=None} [id HZ]\n", - " │ └─ Ge [id IA]\n", - " │ ├─ 1.0 [id HU]\n", - " │ └─ 0 [id IB]\n", - " └─ Identity [id IC] 'sigma_log___log_jacobian'\n", - " └─ sigma_log__ [id BA]\n" - ] - }, { "data": { "text/plain": [ - "" + "array([102.16132511, 10.75292414, 0.54980953, -7.64875998,\n", + " 8.5053264 , -8.56422778, 0.70840797, -0.57081651,\n", + " 
-2.45245893, -5.30737734, -1.33080016, 0.25923082])" ] }, - "execution_count": 24, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "compute_loss(m_obs, guide_model).dprint()" + "true_params" ] }, { - "cell_type": "code", - "execution_count": null, - "id": "d155685f", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "98fe5067", + "cell_type": "markdown", + "id": "0cf321e6", "metadata": {}, - "outputs": [], "source": [ - "## TODO:\n", - "# 1. Create hyperparameters for mean field approx (mu + sigma of normals)\n", - "# 2. Replace in the logp" + "## Todo:\n", + "\n", + "- Does this \"two models\" framework fit into what we already have?\n", + "- `model_to_mean_field` transformation\n", + "- rsample --> stochastic gradients? Or automatic reparameterization?\n", + "- More flexible optimizers..." ] }, { "cell_type": "code", "execution_count": null, - "id": "e01dcf96", + "id": "77786d86", "metadata": {}, "outputs": [], "source": [] From ea2c917374af22abc998024f99991763a9101294 Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Sun, 8 Jun 2025 13:02:04 +0200 Subject: [PATCH 3/5] Adding overview notebook --- VI_Overview.ipynb | 563 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 563 insertions(+) create mode 100644 VI_Overview.ipynb diff --git a/VI_Overview.ipynb b/VI_Overview.ipynb new file mode 100644 index 000000000..5eb66fc86 --- /dev/null +++ b/VI_Overview.ipynb @@ -0,0 +1,563 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "c51c3c1a-553c-45e4-a92f-d75063187863", + "metadata": {}, + "source": [ + "# Variational Inference overview" + ] + }, + { + "cell_type": "markdown", + "id": "c0777ef6-dd90-4452-88af-69a53f0f7713", + "metadata": {}, + "source": [ + "## Existing Variational Inference implementation\n", + "\n", + "The best way to get a sense for the current implementation is to walk backwards from how it's used" + ] + }, + { +
"cell_type": "code", + "execution_count": 1, + "id": "3d5724fe-72db-4908-a464-46f7fac97309", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pymc as pm\n", + "import arviz as az" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "33437d00-60e6-4505-8b8e-8ebe64473c5b", + "metadata": {}, + "outputs": [], + "source": [ + "data = np.random.normal(size=10_000)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "894f9e31-90a1-4f13-b2bc-b75bdf996f78", + "metadata": {}, + "outputs": [], + "source": [ + "with pm.Model() as model:\n", + " d = pm.Data(\"data\", data)\n", + " batched_data = pm.Minibatch(d, batch_size=100)\n", + " x = pm.Normal(\"x\", 0., 1.)\n", + " y = pm.Normal(\"y\", x, total_size=len(data), observed=batched_data)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "15bc2997-8974-4d61-88ce-9423b215f84f", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "a695972a8ca3415f9f8dd118ae6288dc", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Finished [100%]: Average Loss = 144.77\n" + ] + } + ], + "source": [ + "with model:\n", + " idata = pm.fit(n=10_000, method=\"advi\") " + ] + }, + { + "cell_type": "markdown", + "id": "d311e2f2-f264-4cb2-9287-21d6f5aad3e3", + "metadata": {}, + "source": [ + "But what does `fit` do? It roughly dispatches on the method. So the above is roughly equivalent to:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "ec3b637d-c6a2-46cb-99bc-87fc952bda3e", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "43686ad598a649b09a88b84480985eac", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Finished [100%]: Average Loss = 143.83\n" + ] + } + ], + "source": [ + "with model:\n", + " advi = pm.ADVI()\n", + " idata = advi.fit(n=100_000)" + ] + }, + { + "cell_type": "markdown", + "id": "bfbd2a63-b5da-41d3-a1d1-254934ad4923", + "metadata": {}, + "source": [ + "But what is this `ADVI` object? Well, if you look at its implementation with the documentation removed, you see it's a type of `KLqp`\n", + "\n", + "````python\n", + "class ADVI(KLqp):\n", + " def __init__(self, *args, **kwargs):\n", + " super().__init__(MeanField(*args, **kwargs))\n", + "````\n", + "\n", + "So what's a `KLqp`? Look at its implementation with the documentation removed, and you see it's an `Inference` object\n", + "\n", + "````python\n", + "class KLqp(Inference):\n", + " def __init__(self, approx, beta=1.0):\n", + " super().__init__(KL, approx, None, beta=beta)\n", + "````\n", + "\n", + "So what's an `Inference` object? Look at its implementation with the documentation removed and we finally get a sense of the main abstractions we will be working with." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "ce19d0bd-8a6b-4877-a4ee-5ee1fa481da8", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "\u001b[0;31mInit signature:\u001b[0m \u001b[0mpm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mInference\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mop\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mapprox\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mDocstring:\u001b[0m \n", + "**Base class for Variational Inference**.\n", + "\n", + "Communicates Operator, Approximation and Test Function to build Objective Function\n", + "\n", + "Parameters\n", + "----------\n", + "op : Operator class #:class:`~pymc.variational.operators`\n", + "approx : Approximation class or instance #:class:`~pymc.variational.approximations`\n", + "tf : TestFunction instance #?\n", + "model : Model\n", + " PyMC Model\n", + "kwargs : kwargs passed to :class:`Operator` #:class:`~pymc.variational.operators`, optional\n", + "\u001b[0;31mFile:\u001b[0m ~/upstream/pymc/pymc/variational/inference.py\n", + "\u001b[0;31mType:\u001b[0m type\n", + "\u001b[0;31mSubclasses:\u001b[0m KLqp, ImplicitGradient" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "pm.Inference?" + ] + }, + { + "cell_type": "markdown", + "id": "73873c2f-11ed-43a3-b256-2e65816343b9", + "metadata": {}, + "source": [ + "Now things are falling into place. The `Inference` class is the way we perform variational inference. This is where the actual fit machinery lives. It also highlights what we need to do variational inference. We need a `Model`, an `Operator`, and an `Approximation`. We already know for `ADVI`, that the `Operator` is `KL` and the `Approximation` is `MeanField`.\n", + "\n", + "But what do these things mean? 
And how are they combined to perform inference?\n", + "\n", + "Well the `__init__` method of `Inference` makes it where we can find our answer" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "0a77e663-8982-4a8a-bafc-290da5f45838", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "\u001b[0;31mSignature:\u001b[0m \u001b[0mpm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mInference\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mop\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mapprox\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mDocstring:\u001b[0m Initialize self. See help(type(self)) for accurate signature.\n", + "\u001b[0;31mSource:\u001b[0m \n", + " \u001b[0;32mdef\u001b[0m \u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mop\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mapprox\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhist\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0masarray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mobjective\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mop\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mapprox\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtf\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m 
\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstate\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mFile:\u001b[0m ~/upstream/pymc/pymc/variational/inference.py\n", + "\u001b[0;31mType:\u001b[0m function" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "pm.Inference.__init__??" + ] + }, + { + "cell_type": "markdown", + "id": "0564cc40-2d42-476d-8fa4-91a063bff433", + "metadata": {}, + "source": [ + "Alright, so let's go ahead and explore the operator `KL`\n", + "\n", + "````python\n", + "class KL(Operator):\n", + " def __init__(self, approx, beta=1.0):\n", + " super().__init__(approx)\n", + " self.beta = pm.floatX(beta)\n", + "\n", + " def apply(self, f):\n", + " return -self.datalogp_norm + self.beta * (self.logq_norm - self.varlogp_norm)\n", + "````\n", + "\n", + "We see no `__call__` but we see a call to the `__init__` of `Operator`. For the `apply` method we see what looks like the ELBO. Let's for now inline for `ADVI` case and see what we get" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "1752272d-b32d-4bea-9c3c-1e331b7fda9a", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "objective = pm.operators.KL(pm.MeanField(model=model))(None)\n", + "objective" + ] + }, + { + "cell_type": "markdown", + "id": "cb70a473-6ca3-471c-b051-9a3cee666bd6", + "metadata": {}, + "source": [ + "So how'd that happen? 
Well if you look in the `Operator` class you see\n", + "\n", + "````python\n", + " objective_class = ObjectiveFunction\n", + "\n", + " def __call__(self, f=None):\n", + " if self.has_test_function:\n", + " if f is None:\n", + " raise ParametrizationError(f\"Operator {self} requires TestFunction\")\n", + " else:\n", + " if not isinstance(f, TestFunction):\n", + " f = TestFunction.from_function(f)\n", + " else:\n", + " if f is not None:\n", + " warnings.warn(f\"TestFunction for {self} is redundant and removed\", stacklevel=3)\n", + " else:\n", + " pass\n", + " f = TestFunction()\n", + " f.setup(self.approx)\n", + " return self.objective_class(self, f)\n", + "````\n", + "\n", + "Which finally brings us to `ObjectiveFunction`\n", + "\n", + "This is the class that sets up the actual loss function and performs the updates on it, via its `step_function` method." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "339f72bf-a40d-4f84-9161-99a20ad090cb", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "\u001b[0;31mSignature:\u001b[0m\n", + "\u001b[0mpm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mopvi\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mObjectiveFunction\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep_function\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mobj_n_mc\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mtf_n_mc\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mobj_optimizer\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m<\u001b[0m\u001b[0mfunction\u001b[0m \u001b[0madagrad_window\u001b[0m \u001b[0mat\u001b[0m \u001b[0;36m0x70ee648da480\u001b[0m\u001b[0;34m>\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m 
\u001b[0mtest_optimizer\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m<\u001b[0m\u001b[0mfunction\u001b[0m \u001b[0madagrad_window\u001b[0m \u001b[0mat\u001b[0m \u001b[0;36m0x70ee648da480\u001b[0m\u001b[0;34m>\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mmore_obj_params\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mmore_tf_params\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mmore_updates\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mmore_replacements\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mtotal_grad_norm_constraint\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mscore\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mcompile_kwargs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mfn_kwargs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mDocstring:\u001b[0m\n", + "Step function that should be called on each optimization step.\n", + "\n", + "Generally it solves the following problem:\n", + "\n", + ".. 
math::\n", + "\n", + " \\mathbf{\\lambda^{\\*}} = \\inf_{\\lambda} \\sup_{\\theta} t(\\mathbb{E}_{\\lambda}[(O^{p,q}f_{\\theta})(z)])\n", + "\n", + "Parameters\n", + "----------\n", + "obj_n_mc: `int`\n", + " Number of monte carlo samples used for approximation of objective gradients\n", + "tf_n_mc: `int`\n", + " Number of monte carlo samples used for approximation of test function gradients\n", + "obj_optimizer: function (grads, params) -> updates\n", + " Optimizer that is used for objective params\n", + "test_optimizer: function (grads, params) -> updates\n", + " Optimizer that is used for test function params\n", + "more_obj_params: `list`\n", + " Add custom params for objective optimizer\n", + "more_tf_params: `list`\n", + " Add custom params for test function optimizer\n", + "more_updates: `dict`\n", + " Add custom updates to resulting updates\n", + "total_grad_norm_constraint: `float`\n", + " Bounds gradient norm, prevents exploding gradient problem\n", + "score: `bool`\n", + " calculate loss on each step? Defaults to False for speed\n", + "compile_kwargs: `dict`\n", + " Add kwargs to pytensor.function (e.g. `{'profile': True}`)\n", + "fn_kwargs: dict\n", + " arbitrary kwargs passed to `pytensor.function`\n", + "\n", + " .. warning:: `fn_kwargs` is deprecated and will be removed in future versions\n", + "\n", + "more_replacements: `dict`\n", + " Apply custom replacements before calculating gradients\n", + "\n", + "Returns\n", + "-------\n", + "`pytensor.function`\n", + "\u001b[0;31mFile:\u001b[0m ~/upstream/pymc/pymc/variational/opvi.py\n", + "\u001b[0;31mType:\u001b[0m function" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "pm.opvi.ObjectiveFunction.step_function?" + ] + }, + { + "cell_type": "markdown", + "id": "83b5504c-6a99-40eb-8b45-b485a8618ecd", + "metadata": {}, + "source": [ + "## Proposed Improvements\n", + "\n", + "There is a lot to like here, but there is also a lot of indirection. 
Further, much of it isn't used for the `ADVI` case. This is all in service of `SVGD` and `ASVGD`.\n", + "\n", + "Further, the `Inference` class has to be aware of too many of these details. Ideally the `Inference` should be reworked to only take in a step function. It could be re-named `Trainer` to match what's in PyTorch Lightning. I think forcing all `VI` through `OPVI` makes it more challenging to write and port new `VI` algorithms to `pymc`" + ] + }, + { + "cell_type": "markdown", + "id": "f7e83647-ac36-4043-8bb5-8fcf5581ffa3", + "metadata": {}, + "source": [ + "### PyTorch Lightning and Optax optimization" + ] + }, + { + "cell_type": "markdown", + "id": "ece80fe0-e332-444d-b1e3-761aca6a02e8", + "metadata": {}, + "source": [ + "How would this look? One possibility is having each Variational Inference technique encapsulated into an object that takes a model and optimizer as inputs and provides a step function as a method.\n", + "\n", + "````python\n", + "class ADVI(Inference):\n", + " def __init__(self, model=None, optimizers=None):\n", + " ...\n", + "\n", + " def step(self, batched_data):\n", + " ...\n", + " return loss\n", + "````\n", + "\n", + "This is then passed to a `Trainer` object for fitting\n", + "\n", + "````python\n", + "with model:\n", + " trainer = Trainer(method=ADVI(), dataloader= ...)\n", + " trainer.fit(n=10_000)\n", + "````\n", + "\n", + "Under this setup most of the optimization logic moves into the `__init__` and `step` methods. As for how those should be implemented, I think this can be handled separately. But something like optax might not be so bad. 
So we could end up with code that resembles the following\n", + "\n", + "````python\n", + "class ADVI(Inference):\n", + " def __init__(self, model=None, optimizers=None):\n", + " if model is None:\n", + " model = modelcontext(None)\n", + " if optimizers is None:\n", + " optimizers = [pm.opt.Adam(1e-3)]\n", + " self.optimizer = optimizers[0]\n", + " self.params = self.optimizer.init(model.basic_RVs)\n", + "\n", + " def step(self, batch):\n", + " loss = self.loss_function(self.params, batch)\n", + " grads = grad(loss)\n", + " self.params = self.optimizer.update(grads, self.params)\n", + " return loss\n", + "````" + ] + }, + { + "cell_type": "markdown", + "id": "d5572e2b-6a32-492d-9cb0-d003a14f490d", + "metadata": {}, + "source": [ + "### Model and Guide programs\n", + "\n", + "Additionally it would be nice if we could easily support variational inference with guide programs à la pyro/numpyro\n", + "\n", + "The way this could look is we define both as `pymc` models and then pass them to a `SVI` method\n", + "\n", + "````python\n", + "with pm.Model() as model:\n", + " x = pm.Normal(\"x\", 0, 1)\n", + " y = pm.Normal(\"y\", x, 1, observed=data)\n", + "\n", + "with pm.Model() as guide:\n", + " mu = pt.tensor(\"mu\", param=True)\n", + " sd = pt.tensor(\"sd\", param=True)\n", + " x = pm.Normal(\"x\", mu, sd)\n", + "\n", + "\n", + "with model:\n", + " trainer = Trainer(method=SVI(model, guide), dataloader= ...)\n", + " trainer.fit(n=10_000)\n", + "````\n", + "\n", + "Naturally, `SVI` is a very general inference method, and in fact we can re-define `ADVI` in terms of it. 
Following the lead of pyro/numpyro we can have a guide generation\n", + "\n", + "````python\n", + "with model:\n", + " guide = AutoGuide(model)\n", + " trainer = Trainer(method=SVI(model, guide), dataloader= ...)\n", + " trainer.fit(n=10_000)\n", + "````" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "50fcb3a1-4467-4ace-acdd-666e4f342984", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "pymc-dev", + "language": "python", + "name": "pymc-dev" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From a870c7c63f848825b8a0f879fc17324c446df4f9 Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Sun, 8 Jun 2025 13:38:48 +0200 Subject: [PATCH 4/5] Add minibatch proposal --- VI_Overview.ipynb | 39 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 37 insertions(+), 2 deletions(-) diff --git a/VI_Overview.ipynb b/VI_Overview.ipynb index 5eb66fc86..8e3ee91e8 100644 --- a/VI_Overview.ipynb +++ b/VI_Overview.ipynb @@ -460,7 +460,7 @@ " def __init__(self, model=None, optimizers=None):\n", " ...\n", "\n", - " def step(self, batched_data):\n", + " def step(self, batch):\n", " ...\n", " return loss\n", "````\n", @@ -506,6 +506,7 @@ "\n", "````python\n", "with pm.Model() as model:\n", + " data = pm.Data(\"data\", ...)\n", " x = pm.Normal(\"x\", 0, 1)\n", " y = pm.Normal(\"y\", x, 1, observed=data)\n", "\n", @@ -530,10 +531,44 @@ "````" ] }, + { + "cell_type": "markdown", + "id": "7f97c341-e9bb-4301-b452-d006d6408cec", + "metadata": {}, + "source": [ + "### Reworking Minibatch\n", + "\n", + "Another small change we should consider is moving `pm.Minibatch` out of the model. 
Max already has a [proposal](https://github.com/pymc-devs/pymc/issues/7496) that I think can be adopted with only a few changes.\n", + "\n", + "I think where before we explicitly minibatch the data, instead we have dataloaders that stream in updates to the model.\n", + "\n", + "````python\n", + "with pm.Model() as model:\n", + " data = pm.Data(\"data\", None)\n", + " x = pm.Normal(\"x\", 0, 1)\n", + " y = pm.Normal(\"y\", x, 1, observed=data)\n", + "\n", + "dataloader = pm.Dataloader(np.random.normal(10_000, 2), batch_size=64)\n", + "\n", + "with model:\n", + " trainer = Trainer(method=ADVI(), dataloader=dataloader)\n", + " trainer.fit(n=10_000)\n", + "````\n", + "\n", + "Importantly, the model doesn't need to know about the dataloader. We will need to tweak the inference object, but it's not so bad.\n", + "\n", + "````python\n", + "class ADVI(Inference):\n", + " def step(self, batch):\n", + " self.model.set_data(\"data\", batch)\n", + " ...\n", + "````" + ] + }, { "cell_type": "code", "execution_count": null, - "id": "50fcb3a1-4467-4ace-acdd-666e4f342984", + "id": "220ba769-fb8f-47a7-82b6-ab6ca13ad61e", "metadata": {}, "outputs": [], "source": [] From b7aa31d1ab4de1a9976cbcb3381aeabc54e1ff4d Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Sun, 8 Jun 2025 21:55:50 +0800 Subject: [PATCH 5/5] Example of an autoguide --- pymc/variational/autoguide.py | 95 +++++++++++++++++++ tests/variational/test_autoguide.py | 137 ++++++++++++++++++++++++++++ 2 files changed, 232 insertions(+) create mode 100644 pymc/variational/autoguide.py create mode 100644 tests/variational/test_autoguide.py diff --git a/pymc/variational/autoguide.py b/pymc/variational/autoguide.py new file mode 100644 index 000000000..be8906264 --- /dev/null +++ b/pymc/variational/autoguide.py @@ -0,0 +1,95 @@ +# Copyright 2025 - present The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+#   You may obtain a copy of the License at
+#
+#       http://www.apache.org/licenses/LICENSE-2.0
+#
+#   Unless required by applicable law or agreed to in writing, software
+#   distributed under the License is distributed on an "AS IS" BASIS,
+#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#   See the License for the specific language governing permissions and
+#   limitations under the License.
+
+import numpy as np
+import pytensor.tensor as pt
+
+from pytensor import Variable, graph_replace
+from pytensor.graph import vectorize_graph
+
+import pymc as pm
+
+from pymc.model.core import Model
+
+ModelVariable = Variable | str
+
+
+def AutoDiagonalNormal(model):
+    """Build a mean-field (diagonal) Normal guide model for ``model``.
+
+    For every free random variable ``rv`` of ``model`` the guide contains:
+
+    * two symbolic inputs ``"<rv>_loc"`` and ``"<rv>_scale"`` with the static
+      shape of ``rv`` (plain tensors, not registered in the guide model),
+    * a standard Normal ``"<rv>_z"`` of shape ``(draws, *rv.shape)`` that
+      reuses the transform ``model`` associates with ``rv``,
+    * a Deterministic named exactly like ``rv`` equal to ``loc + scale * z``.
+
+    ``draws`` is a scalar int64 symbolic input named ``"draws"``.
+
+    Parameters
+    ----------
+    model : Model
+        The model whose free random variables the guide should cover.
+
+    Returns
+    -------
+    Model
+        The guide model, created with the same coords as ``model``.
+    """
+    coords = model.coords
+    free_rvs = model.free_RVs
+    draws = pt.tensor("draws", shape=(), dtype="int64")
+
+    with Model(coords=coords) as guide_model:
+        for rv in free_rvs:
+            loc = pt.tensor(f"{rv.name}_loc", shape=rv.type.shape)
+            scale = pt.tensor(f"{rv.name}_scale", shape=rv.type.shape)
+            z = pm.Normal(
+                f"{rv.name}_z",
+                mu=0,
+                sigma=1,
+                shape=(draws, *rv.type.shape),
+                transform=model.rvs_to_transforms[rv],
+            )
+            # NOTE(review): the Deterministic has an extra leading ``draws`` axis that
+            # the dims copied from ``model`` do not describe -- confirm this is
+            # accepted for models that declare dims.
+            pm.Deterministic(
+                rv.name, loc + scale * z, dims=model.named_vars_to_dims.get(rv.name, None)
+            )
+
+    return guide_model
+
+
+def AutoFullRankNormal(model):
+    """Build a full-rank multivariate Normal guide model for ``model``.
+
+    All free random variables are flattened and stacked into one latent vector
+    of length ``n``.  Each draw is ``loc + L @ z`` with ``z ~ N(0, I_n)``, and the
+    result is split back into one Deterministic per random variable, named like
+    the original variable and shaped ``(draws, *rv.shape)``.
+
+    Symbolic inputs of the resulting graph are ``"<rv>_loc"`` for every free
+    random variable, ``"L"`` (the packed lower-triangular factor, a float64
+    vector of length ``n * (n + 1) // 2``) and the scalar int64 ``"draws"``.
+
+    Parameters
+    ----------
+    model : Model
+        The model whose free random variables the guide should cover.
+
+    Returns
+    -------
+    Model
+        The guide model, created with the same coords as ``model``.
+
+    Raises
+    ------
+    ValueError
+        If a free random variable has a dimension of unknown static length,
+        because the size of the joint latent vector could then not be computed.
+    """
+    # TODO: Still experimental.  Unlike AutoDiagonalNormal, the transforms of
+    # constrained variables are not carried over to the guide.
+
+    coords = model.coords
+    free_rvs = model.free_RVs
+    draws = pt.tensor("draws", shape=(), dtype="int64")
+
+    for rv in free_rvs:
+        if any(dim_length is None for dim_length in rv.type.shape):
+            raise ValueError(
+                f"AutoFullRankNormal requires static shapes, but {rv.name} has shape {rv.type.shape}"
+            )
+
+    # Built-in ints are required: np.prod(()) is a float, and
+    # expand_packed_triangular rejects NumPy integer scalars for ``n``.
+    rv_sizes = [int(np.prod(rv.type.shape)) for rv in free_rvs]
+    total_size = sum(rv_sizes)
+    tril_size = total_size * (total_size + 1) // 2
+
+    locs = [pt.tensor(f"{rv.name}_loc", shape=rv.type.shape) for rv in free_rvs]
+    packed_L = pt.tensor("L", shape=(tril_size,), dtype="float64")
+    L = pm.expand_packed_triangular(total_size, packed_L)
+
+    with Model(coords=coords) as guide_model:
+        # ``shape`` includes the support dimension, so z is (draws, total_size).
+        z = pm.MvNormal(
+            "z", mu=np.zeros(total_size), cov=np.eye(total_size), shape=(draws, total_size)
+        )
+        # Row i of ``z @ L.T`` equals ``L @ z[i]``; the flat loc broadcasts over draws.
+        params = pt.concatenate([loc.ravel() for loc in locs]) + z @ L.T
+
+        cursor = 0
+
+        for rv, size in zip(free_rvs, rv_sizes):
+            # Slice along the last (parameter) axis and keep the draws axis in front.
+            # NOTE(review): as in AutoDiagonalNormal, the dims copied from ``model`` do
+            # not describe the leading ``draws`` axis -- confirm for models with dims.
+            pm.Deterministic(
+                rv.name,
+                params[..., cursor : cursor + size].reshape((draws, *rv.type.shape)),
+                dims=model.named_vars_to_dims.get(rv.name, None),
+            )
+            cursor += size
+
+    return guide_model
+
+
+def get_logp_logq(model, guide_model):
+    """Return the symbolic ``logp`` of ``model`` and ``logq`` of ``guide_model``.
+
+    Parameters
+    ----------
+    model : Model
+        The target model.
+    guide_model : Model
+        A guide that exposes, under the same name, one variable for every
+        non-observed random variable of ``model``.
+
+    Returns
+    -------
+    logp : Variable
+        ``model.logp()`` vectorized with each value variable of a non-observed
+        random variable replaced by the same-named guide variable.
+    logq : Variable
+        ``guide_model.logp()`` with the guide's value variables replaced by the
+        guide's random variables.
+    """
+    inputs_to_guide_rvs = {
+        model_value_var: guide_model[rv.name]
+        for rv, model_value_var in model.rvs_to_values.items()
+        if rv not in model.observed_RVs
+    }
+
+    logp = vectorize_graph(model.logp(), inputs_to_guide_rvs)
+    logq = graph_replace(guide_model.logp(), guide_model.values_to_rvs)
+
+    return logp, logq
diff --git a/tests/variational/test_autoguide.py b/tests/variational/test_autoguide.py
new file mode 100644
index 000000000..9ff53a38e
--- /dev/null
+++ b/tests/variational/test_autoguide.py
@@ -0,0 +1,137 @@
+#   Copyright 2025 - present The PyMC Developers
+#
+#   Licensed under the Apache License, Version 2.0 (the "License");
+#   you may not use this file except in compliance with the License.
+#   You may obtain a copy of the License at
+#
+#       http://www.apache.org/licenses/LICENSE-2.0
+#
+#   Unless required by applicable law or agreed to in writing, software
+#   distributed under the License is distributed on an "AS IS" BASIS,
+#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#   See the License for the specific language governing permissions and
+#   limitations under the License.
+import numpy as np
+import pytensor.tensor as pt
+import pytest
+
+import pymc as pm
+
+from pymc.variational.autoguide import AutoDiagonalNormal, AutoFullRankNormal, get_logp_logq
+
+# Variational parameters are plain symbolic tensors; the alias only names their role.
+Parameter = pt.tensor
+
+
+@pytest.fixture(scope="module")
+def X_y_params():
+    """Generate synthetic data for testing."""
+
+    # Fixed seed derived from the characters of a string, so the data is reproducible.
+    rng = np.random.default_rng(sum(map(ord, "autoguide_test")))
+
+    alpha = rng.normal(loc=100, scale=10)
+    beta = rng.normal(loc=0, scale=1, size=(10,))
+
+    true_params = {
+        "alpha": alpha,
+        "beta": beta,
+    }
+
+    X_data = rng.normal(size=(100, 10))
+    # Noise-free linear response.
+    y_data = alpha + X_data @ beta
+
+    return X_data, y_data, true_params
+
+
+@pytest.fixture(scope="module")
+def model(X_y_params):
+    """Linear regression model with free variables alpha, beta and sigma, observing y."""
+    X_data, y_data, _ = X_y_params
+
+    with pm.Model() as model:
+        X = pm.Data("X", X_data)
+        alpha = pm.Normal("alpha", 100, 10)
+        beta = pm.Normal("beta", 0, 5, size=(10,))
+
+        mu = alpha + X @ beta
+        sigma = pm.Exponential("sigma", 1)
+        y = pm.Normal("y", mu=mu, sigma=sigma, observed=y_data)
+
+    return model
+
+
+@pytest.fixture(scope="module")
+def target_guide_model(X_y_params):
+    """Hand-written diagonal Normal guide used as the reference for AutoDiagonalNormal."""
+    X_data, *_ = X_y_params
+
+    draws = pt.tensor("draws", shape=(), dtype="int64")
+
+    with pm.Model() as guide_model:
+        X = pm.Data("X", X_data)
+
+        alpha_loc = Parameter("alpha_loc", shape=())
+        alpha_scale = Parameter("alpha_scale", shape=())
+        alpha_z = pm.Normal("alpha_z", mu=0, sigma=1, shape=(draws,))
+        alpha = pm.Deterministic("alpha", alpha_loc + alpha_scale * alpha_z)
+
+        beta_loc = Parameter("beta_loc", shape=(10,))
+        beta_scale = Parameter("beta_scale", shape=(10,))
+        beta_z = pm.Normal("beta_z", mu=0, sigma=1, shape=(draws, 10))
+        beta = pm.Deterministic("beta", beta_loc + beta_scale * beta_z)
+
+        sigma_loc = Parameter("sigma_loc", shape=())
+        sigma_scale = Parameter("sigma_scale", shape=())
+        # NOTE(review): the log transform presumably mirrors the transform the model
+        # assigns to its Exponential ``sigma``, which AutoDiagonalNormal copies -- confirm.
+        sigma_z = pm.Normal(
+            "sigma_z", 0, 1, shape=(draws,), transform=pm.distributions.transforms.log
+        )
+        sigma = pm.Deterministic("sigma", sigma_loc + sigma_scale * sigma_z)
+
+    return guide_model
+
+
+def test_diagonal_normal_autoguide(model, target_guide_model, X_y_params):
+    """Compare the negative ELBO from AutoDiagonalNormal with the hand-written guide.
+
+    Both graphs are compiled with the same ``random_seed`` and evaluated at the
+    same parameter values; the results must agree.
+    """
+    guide_model = AutoDiagonalNormal(model)
+
+    logp, logq = get_logp_logq(model, guide_model)
+    logp_target, logq_target = get_logp_logq(model, target_guide_model)
+
+    inputs = pm.inputvars(logp)
+    target_inputs = pm.inputvars(logp_target)
+
+    expected_locs = [f"{var}_loc" for var in ["alpha", "beta", "sigma"]]
+    expected_scales = [f"{var}_scale" for var in ["alpha", "beta", "sigma"]]
+
+    expected_inputs = expected_locs + expected_scales + ["draws"]
+    name_to_input = {input.name: input for input in inputs}
+    name_to_target_input = {input.name: input for input in target_inputs}
+
+    # Only checks that no unexpected input appears in the generated guide's logp.
+    assert all(input.name in expected_inputs for input in inputs), (
+        "Guide inputs do not match expected inputs"
+    )
+
+    negative_elbo = (logq - logp).mean()
+    negative_elbo_target = (logq_target - logp_target).mean()
+
+    # Inputs are passed in the order of ``expected_inputs`` for both functions.
+    fn = pm.compile(
+        [name_to_input[input] for input in expected_inputs], negative_elbo, random_seed=69420
+    )
+    fn_target = pm.compile(
+        [name_to_target_input[input] for input in expected_inputs],
+        negative_elbo_target,
+        random_seed=69420,
+    )
+
+    test_inputs = {
+        "alpha_loc": np.zeros(()),
+        "alpha_scale": np.ones(()),
+        "beta_loc": np.zeros(10),
+        "beta_scale": np.ones(10),
+        "sigma_loc": np.zeros(()),
+        "sigma_scale": np.ones(()),
+        "draws": 100,
+    }
+
+    np.testing.assert_allclose(fn(**test_inputs), fn_target(**test_inputs))
+
+
+def test_full_mv_normal_guide(model, X_y_params):
+    guide_model = AutoFullRankNormal(model)
+    logp, logq = get_logp_logq(model, guide_model)