From 28cbbca11815bcbced81d403dda342012afdf064 Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Sat, 24 May 2025 00:49:48 +0800 Subject: [PATCH 1/5] hacky notebook --- Rewrite VI, why not.ipynb | 751 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 751 insertions(+) create mode 100644 Rewrite VI, why not.ipynb diff --git a/Rewrite VI, why not.ipynb b/Rewrite VI, why not.ipynb new file mode 100644 index 000000000..f2a6c3794 --- /dev/null +++ b/Rewrite VI, why not.ipynb @@ -0,0 +1,751 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 6, + "id": "9f946eb4", + "metadata": {}, + "outputs": [], + "source": [ + "import pytensor.tensor as pt\n", + "\n", + "import pymc as pm" + ] + }, + { + "cell_type": "code", + "execution_count": 61, + "id": "e746bc33", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Sampling: [X, alpha, beta, sigma, y]\n" + ] + } + ], + "source": [ + "with pm.Model() as m:\n", + " X = pm.Normal(\"X\", 0, 1, size=(100, 10))\n", + " alpha = pm.Normal(\"alpha\", 100, 10)\n", + " beta = pm.Normal(\"beta\", 0, 5, size=(10,))\n", + "\n", + " mu = alpha + X @ beta\n", + " sigma = pm.Exponential(\"sigma\", 1)\n", + " y = pm.Normal(\"y\", mu=mu, sigma=sigma)\n", + "\n", + " prior = pm.sample_prior_predictive()" + ] + }, + { + "cell_type": "code", + "execution_count": 62, + "id": "a8ca0161", + "metadata": {}, + "outputs": [], + "source": [ + "draw = 123\n", + "true_params = np.r_[\n", + " prior.prior.alpha.sel(chain=0, draw=draw).values, prior.prior.beta.sel(chain=0, draw=draw)\n", + "]\n", + "X_data = prior.prior.X.sel(chain=0, draw=draw).values\n", + "y_data = prior.prior.y.sel(chain=0, draw=draw).values" + ] + }, + { + "cell_type": "code", + "execution_count": 63, + "id": "b89f4031", + "metadata": {}, + "outputs": [], + "source": [ + "m_obs = pm.observe(pm.do(m, {X: X_data}), {\"y\": y_data})" + ] + }, + { + "cell_type": "code", + "execution_count": 168, + "id": "a42a84e4", + "metadata": {}, + "outputs": [], + "source": [ + "Parameter = pt.tensor\n", + "\n", + "draws = pt.tensor(\"draws\", shape=(), dtype=\"int64\")\n", + "\n", + "with pm.Model() as guide_model:\n", + " X = pm.Data(\"X\", X_data)\n", + " alpha_loc = Parameter(\"alpha_loc\", shape=())\n", + " alpha_scale = Parameter(\"alpha_scale\", shape=())\n", + " alpha_z = pm.Normal(\"alpha_z\", mu=0, sigma=1, shape=(draws,))\n", + " alpha = pm.Deterministic(\"alpha\", alpha_loc + alpha_scale * alpha_z)\n", + "\n", + " beta_loc = Parameter(\"beta_loc\", shape=(10,))\n", + " beta_scale = Parameter(\"beta_scale\", shape=(10,))\n", + " beta_z = pm.Normal(\"beta_z\", mu=0, sigma=1, shape=(draws, 10))\n", + " beta = pm.Deterministic(\"beta\", beta_loc + beta_scale * beta_z)\n", + "\n", + " mu = alpha + X @ beta\n", + "\n", + " sigma_loc = Parameter(\"sigma_loc\", shape=())\n", + " sigma_scale = Parameter(\"sigma_scale\", shape=())\n", + " sigma_z = pm.Normal(\"sigma_z\", 0, 1, shape=(draws,))\n", + " sigma = pm.Deterministic(\"sigma\", pt.softplus(sigma_loc + sigma_scale * sigma_z))\n", + "\n", + "# with pm.Model() as guide_model2:\n", + "# n = 10 + 1 + 1\n", + "# loc = Parameter(\"loc\", shape=(n,))\n", + "# chol_flat = Parameter(\"chol\", shape=(n * n-1, ))\n", + "# chol = pm.expand_packed_triangular(n, chol_flat)\n", + "# latent_mvn = pm.MvNormal(\"latent_mvn\", chol=chol)\n", + "\n", + "# pm.Deterministic(\"beta\", latent_mvn[:10])\n", + "# pm.Deterministic(\"alpha\", latent_mvn[10])\n", + "# pm.Deterministic(\"sigma\", pm.math.exp(latent_mvn[11]))" + ] + }, 
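The guide defined in the cell above is a mean-field approximation built on the reparameterization trick: every latent is expressed as loc + scale * z with z ~ N(0, 1), so a Monte Carlo estimate of the negative ELBO, E_q[log q(theta) - log p(theta, y)], stays differentiable with respect to the loc/scale parameters that the later cells optimize. A self-contained NumPy sketch of the same estimator on a toy normal-normal model (illustrative only — the helper names and the toy model are assumptions, not part of this patch):

import numpy as np

rng = np.random.default_rng(0)

def normal_logpdf(x, mu, sigma):
    # elementwise log N(x | mu, sigma**2)
    return -0.5 * ((x - mu) / sigma) ** 2 - np.log(sigma) - 0.5 * np.log(2 * np.pi)

# Toy model: theta ~ N(0, 1), y_i ~ N(theta, 1).
# Guide: q(theta) = N(loc, softplus(rho)**2), mirroring the pt.softplus trick above.
y_toy = rng.normal(0.7, 1.0, size=100)

def negative_elbo(loc, rho, draws=200):
    scale = np.log1p(np.exp(rho))          # softplus keeps the scale positive
    z = rng.standard_normal(draws)
    theta = loc + scale * z                # reparameterized draws from q
    logq = normal_logpdf(theta, loc, scale)
    logp = normal_logpdf(theta, 0.0, 1.0)  # prior term
    logp = logp + normal_logpdf(y_toy[None, :], theta[:, None], 1.0).sum(axis=1)  # likelihood term
    return (logq - logp).mean()            # Monte Carlo estimate of the negative ELBO

print(negative_elbo(0.0, 0.0))

The cells that follow build exactly this quantity symbolically — `guide_logq - model_logp`, averaged over draws — and let `pt.grad` differentiate it with respect to `alpha_loc`, `alpha_scale`, and the other guide parameters, which is what the plain gradient-descent loop further down minimizes.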
+ { + "cell_type": "code", + "execution_count": 169, + "id": "bffb1b69", + "metadata": {}, + "outputs": [], + "source": [ + "params = [alpha_loc, alpha_scale, beta_loc, beta_scale, sigma_loc, sigma_scale]" + ] + }, + { + "cell_type": "code", + "execution_count": 171, + "id": "e2c4dd95", + "metadata": {}, + "outputs": [], + "source": [ + "f_draw = pm.compile([*params, draws], guide_model.deterministics)" + ] + }, + { + "cell_type": "code", + "execution_count": 182, + "id": "bfe44b73", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[array([1.78150955]),\n", + " array([[ 0.22971278, 0.4621461 , 0.81535912, 0.62397751, 1.12162984,\n", + " 0.99310042, -0.04733258, 1.20791346, 0.61310399, 0.6248215 ]]),\n", + " array([1.04612599])]" + ] + }, + "execution_count": 182, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "f_draw(**param_dict, draws=1)" + ] + }, + { + "cell_type": "code", + "execution_count": 204, + "id": "e831865f", + "metadata": {}, + "outputs": [], + "source": [ + "init_dict = m_obs.initial_point()\n", + "init_dict = {k: np.expand_dims(v, 0) for k, v in init_dict.items()}\n", + "param_dict = {param.name: np.full(param.type.shape, 0.5) for param in params}" + ] + }, + { + "cell_type": "code", + "execution_count": 187, + "id": "a9f3fa0e", + "metadata": {}, + "outputs": [], + "source": [ + "from pytensor.graph.replace import graph_replace, vectorize_graph\n", + "\n", + "outputs = [m_obs.datalogp, m_obs.varlogp]\n", + "inputs = m_obs.value_vars\n", + "inputs_to_guide_rvs = {\n", + " model_value_var: guide_model[rv.name]\n", + " for rv, model_value_var in m_obs.rvs_to_values.items()\n", + " if rv not in m_obs.observed_RVs\n", + "}\n", + "model_logp = vectorize_graph(m_obs.logp(), inputs_to_guide_rvs)\n", + "guide_logq = graph_replace(guide_model.logp(), guide_model.values_to_rvs)\n", + "\n", + "elbo_loss = (guide_logq - model_logp).mean()\n", + "d_loss = pt.grad(elbo_loss, params)\n", + "\n", + "f_loss_dloss = pm.compile(params + [draws], [elbo_loss, *d_loss], trust_input=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 207, + "id": "6086f2cc", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "-993.9173656892265\n", + "-926.3296229583198\n", + "-1005.6712508438249\n", + "-919.6431050083108\n", + "-968.0575921556582\n", + "-987.30412574489\n", + "-964.7096352135038\n", + "-983.4362084632743\n", + "-910.8310347546233\n", + "-946.7555282915881\n", + "-943.9616597508198\n", + "-963.8766211135949\n", + "-916.6993999462212\n", + "-950.0475547851752\n", + "-967.001546433409\n", + "-943.9433767873104\n", + "-946.3015434524876\n", + "-934.0032018513697\n", + "-947.6452058569878\n", + "-893.224934118183\n", + "-979.4405814988748\n", + "-937.9997192780036\n", + "-931.1970724735944\n", + "-959.3588230003159\n", + "-932.1233734322371\n", + "-940.2791556640857\n", + "-969.4679954671045\n", + "-954.6606993395337\n", + "-982.9304227234845\n", + "-935.3404389982638\n", + "-982.6885250322749\n", + "-964.4628736035113\n", + "-939.0580477804804\n", + "-955.1672719181267\n", + "-982.0467504680682\n", + "-992.8688427985264\n", + "-967.8588846826874\n", + "-966.7655668600194\n", + "-949.3323016540423\n", + "-934.3919364553586\n", + "-1028.4493525361129\n", + "-982.4944127954707\n", + "-931.9404059809432\n", + "-981.3845063690508\n", + "-930.1688452196312\n", + "-952.0305908505434\n", + "-1012.7969628343917\n", + "-937.6379090307294\n", + "-926.5273721862864\n", + "-981.6665090046366\n", + 
"-980.4287957334458\n", + "-946.1849036479391\n", + "-969.4075653507354\n", + "-958.7555610935674\n", + "-1004.2802914265156\n", + "-963.0682736675113\n", + "-1000.664222867785\n", + "-934.7298341549889\n", + "-978.1088760452403\n", + "-944.0432566257858\n", + "-980.5576704437719\n", + "-969.5612703209024\n", + "-963.9839441465232\n", + "-966.4944698264674\n", + "-953.2780303578338\n", + "-953.5548215260793\n", + "-959.5050408283469\n", + "-989.7240037630827\n", + "-949.1398757808872\n", + "-964.8820244293524\n", + "-940.224485134473\n", + "-947.7591586789077\n", + "-964.679359616359\n", + "-917.0015516708557\n", + "-939.5619779661295\n", + "-993.9876761804422\n", + "-986.3960585411946\n", + "-975.6545633352575\n", + "-957.4304264740541\n", + "-970.7472516497804\n", + "-956.5223102168254\n", + "-969.825369100695\n", + "-987.9141239728567\n", + "-952.6424507446926\n", + "-958.6852758514054\n", + "-965.0105265156106\n", + "-962.943175120098\n", + "-946.9378229990178\n", + "-944.4992023963096\n", + "-947.4171539678032\n", + "-959.1841786549261\n", + "-931.9541021582517\n", + "-972.675454084302\n", + "-1030.6428849456215\n", + "-983.2067258579275\n", + "-1002.0505700409622\n", + "-1015.4906878612064\n", + "-962.2604231546592\n", + "-955.1833250683931\n", + "-938.1777960392025\n" + ] + } + ], + "source": [ + "learning_rate = 1e-4\n", + "for _ in range(100):\n", + " loss, *grads = f_loss_dloss(**param_dict, draws=100)\n", + " for (name, value), grad in zip(param_dict.items(), grads):\n", + " param_dict[name] = value - learning_rate * grad\n", + " print(loss)" + ] + }, + { + "cell_type": "code", + "execution_count": 208, + "id": "1251b6f1", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "{'alpha_loc': 0.6942395020077371,\n", + " 'alpha_scale': 0.5102778323184816,\n", + " 'beta_loc': array([0.49114368, 0.46379742, 0.52277764, 0.53227815, 0.48851862,\n", + " 0.50044113, 0.5441339 , 0.46643231, 0.47894475, 0.51122713]),\n", + " 'beta_scale': array([0.49807094, 0.49740464, 0.5005586 , 0.4991539 , 0.49749037,\n", + " 0.49825551, 0.49992182, 0.4981251 , 0.49745959, 0.49767778]),\n", + " 'sigma_loc': 4.2148859620931765,\n", + " 'sigma_scale': -5.964631972194452e-05}" + ] + }, + "execution_count": 208, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "param_dict" + ] + }, + { + "cell_type": "code", + "execution_count": 209, + "id": "11345fb4", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([83.43771778, -6.07738876, -3.3268889 , 8.38393732, 10.77212434,\n", + " -2.81776509, 0.46737085, 8.7204497 , -4.79822835, -3.47220908,\n", + " 8.76186526])" + ] + }, + "execution_count": 209, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "true_params" + ] + }, + { + "cell_type": "code", + "execution_count": 165, + "id": "566105c2", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array(273519.9558606)" + ] + }, + "execution_count": 165, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "f_loss(**param_dict, draws=100)" + ] + }, + { + "cell_type": "code", + "execution_count": 166, + "id": "c39cd24e", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[alpha_loc, alpha_scale, beta_loc, beta_scale, sigma_loc, sigma_scale]" + ] + }, + "execution_count": 166, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "params" + ] + }, + { + "cell_type": "code", + "execution_count": 106, + "id": "310c739f", + "metadata": {}, + 
"outputs": [ + { + "data": { + "text/plain": [ + "{alpha: alpha, beta: beta, sigma_log__: sigma}" + ] + }, + "execution_count": 106, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "inputs_to_guide_inputs" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "id": "2fe8da0a", + "metadata": {}, + "outputs": [], + "source": [ + "def compute_loss(m_obs, m_guide, beta=1.0):\n", + " return -m_obs.datalogp + beta * (m_guide.logp() - m_obs.varlogp)" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "b04b9b5c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Add [id A]\n", + " ├─ Neg [id B]\n", + " │ └─ Add [id C]\n", + " │ ├─ Sum{axes=None} [id D] '__logp'\n", + " │ │ └─ MakeVector{dtype='float64'} [id E]\n", + " │ │ └─ Sum{axes=None} [id F]\n", + " │ │ └─ Check{sigma > 0} [id G]\n", + " │ │ ├─ Sub [id H]\n", + " │ │ │ ├─ Sub [id I]\n", + " │ │ │ │ ├─ Mul [id J]\n", + " │ │ │ │ │ ├─ ExpandDims{axis=0} [id K]\n", + " │ │ │ │ │ │ └─ -0.5 [id L]\n", + " │ │ │ │ │ └─ Pow [id M]\n", + " │ │ │ │ │ ├─ True_div [id N]\n", + " │ │ │ │ │ │ ├─ Sub [id O]\n", + " │ │ │ │ │ │ │ ├─ [122.65317 ... .32067026] [id P]\n", + " │ │ │ │ │ │ │ └─ Add [id Q]\n", + " │ │ │ │ │ │ │ ├─ ExpandDims{axis=0} [id R]\n", + " │ │ │ │ │ │ │ │ └─ alpha [id S]\n", + " │ │ │ │ │ │ │ └─ Squeeze{axis=1} [id T]\n", + " │ │ │ │ │ │ │ └─ Blockwise{dot, (m,k),(k,n)->(m,n)} [id U]\n", + " │ │ │ │ │ │ │ ├─ X [id V]\n", + " │ │ │ │ │ │ │ └─ ExpandDims{axis=1} [id W]\n", + " │ │ │ │ │ │ │ └─ beta [id X]\n", + " │ │ │ │ │ │ └─ ExpandDims{axis=0} [id Y]\n", + " │ │ │ │ │ │ └─ Exp [id Z]\n", + " │ │ │ │ │ │ └─ sigma_log__ [id BA]\n", + " │ │ │ │ │ └─ ExpandDims{axis=0} [id BB]\n", + " │ │ │ │ │ └─ 2 [id BC]\n", + " │ │ │ │ └─ ExpandDims{axis=0} [id BD]\n", + " │ │ │ │ └─ Log [id BE]\n", + " │ │ │ │ └─ Sqrt [id BF]\n", + " │ │ │ │ └─ 6.283185307179586 [id BG]\n", + " │ │ │ └─ Log [id BH]\n", + " │ │ │ └─ ExpandDims{axis=0} [id Y]\n", + " │ │ │ └─ ···\n", + " │ │ └─ All{axes=None} [id BI]\n", + " │ │ └─ MakeVector{dtype='bool'} [id BJ]\n", + " │ │ └─ All{axes=None} [id BK]\n", + " │ │ └─ Gt [id BL]\n", + " │ │ ├─ ExpandDims{axis=0} [id Y]\n", + " │ │ │ └─ ···\n", + " │ │ └─ ExpandDims{axis=0} [id BM]\n", + " │ │ └─ 0 [id BN]\n", + " │ └─ 0.0 [id BO]\n", + " └─ Mul [id BP]\n", + " ├─ 1.0 [id BQ]\n", + " └─ Sub [id BR]\n", + " ├─ Sum{axes=None} [id BS] '__logp'\n", + " │ └─ MakeVector{dtype='float64'} [id BT]\n", + " │ ├─ Sum{axes=None} [id BU]\n", + " │ │ └─ Check{sigma > 0} [id BV] 'alpha_logprob'\n", + " │ │ ├─ Sub [id BW]\n", + " │ │ │ ├─ Sub [id BX]\n", + " │ │ │ │ ├─ Mul [id BY]\n", + " │ │ │ │ │ ├─ -0.5 [id BZ]\n", + " │ │ │ │ │ └─ Pow [id CA]\n", + " │ │ │ │ │ ├─ True_div [id CB]\n", + " │ │ │ │ │ │ ├─ Sub [id CC]\n", + " │ │ │ │ │ │ │ ├─ alpha [id CD]\n", + " │ │ │ │ │ │ │ └─ alpha_loc [id CE]\n", + " │ │ │ │ │ │ └─ alpha_scale [id CF]\n", + " │ │ │ │ │ └─ 2 [id CG]\n", + " │ │ │ │ └─ Log [id CH]\n", + " │ │ │ │ └─ Sqrt [id CI]\n", + " │ │ │ │ └─ 6.283185307179586 [id CJ]\n", + " │ │ │ └─ Log [id CK]\n", + " │ │ │ └─ alpha_scale [id CF]\n", + " │ │ └─ All{axes=None} [id CL]\n", + " │ │ └─ MakeVector{dtype='bool'} [id CM]\n", + " │ │ └─ All{axes=None} [id CN]\n", + " │ │ └─ Gt [id CO]\n", + " │ │ ├─ alpha_scale [id CF]\n", + " │ │ └─ 0 [id CP]\n", + " │ ├─ Sum{axes=None} [id CQ]\n", + " │ │ └─ Check{sigma > 0} [id CR] 'beta_logprob'\n", + " │ │ ├─ Sub [id CS]\n", + " │ │ │ ├─ Sub [id CT]\n", + " │ │ │ │ ├─ Mul [id CU]\n", + " │ │ │ │ 
│ ├─ ExpandDims{axis=0} [id CV]\n", + " │ │ │ │ │ │ └─ -0.5 [id CW]\n", + " │ │ │ │ │ └─ Pow [id CX]\n", + " │ │ │ │ │ ├─ True_div [id CY]\n", + " │ │ │ │ │ │ ├─ Sub [id CZ]\n", + " │ │ │ │ │ │ │ ├─ beta [id DA]\n", + " │ │ │ │ │ │ │ └─ beta_loc [id DB]\n", + " │ │ │ │ │ │ └─ beta_scale [id DC]\n", + " │ │ │ │ │ └─ ExpandDims{axis=0} [id DD]\n", + " │ │ │ │ │ └─ 2 [id DE]\n", + " │ │ │ │ └─ ExpandDims{axis=0} [id DF]\n", + " │ │ │ │ └─ Log [id DG]\n", + " │ │ │ │ └─ Sqrt [id DH]\n", + " │ │ │ │ └─ 6.283185307179586 [id DI]\n", + " │ │ │ └─ Log [id DJ]\n", + " │ │ │ └─ beta_scale [id DC]\n", + " │ │ └─ All{axes=None} [id DK]\n", + " │ │ └─ MakeVector{dtype='bool'} [id DL]\n", + " │ │ └─ All{axes=None} [id DM]\n", + " │ │ └─ Gt [id DN]\n", + " │ │ ├─ beta_scale [id DC]\n", + " │ │ └─ ExpandDims{axis=0} [id DO]\n", + " │ │ └─ 0 [id DP]\n", + " │ └─ Sum{axes=None} [id DQ]\n", + " │ └─ Add [id DR] 'sigma_log___logprob'\n", + " │ ├─ Check{sigma > 0} [id DS]\n", + " │ │ ├─ Sub [id DT]\n", + " │ │ │ ├─ Sub [id DU]\n", + " │ │ │ │ ├─ Mul [id DV]\n", + " │ │ │ │ │ ├─ -0.5 [id DW]\n", + " │ │ │ │ │ └─ Pow [id DX]\n", + " │ │ │ │ │ ├─ True_div [id DY]\n", + " │ │ │ │ │ │ ├─ Sub [id DZ]\n", + " │ │ │ │ │ │ │ ├─ Exp [id EA]\n", + " │ │ │ │ │ │ │ │ └─ sigma_log__ [id EB]\n", + " │ │ │ │ │ │ │ └─ sigma_loc [id EC]\n", + " │ │ │ │ │ │ └─ sigma_scale [id ED]\n", + " │ │ │ │ │ └─ 2 [id EE]\n", + " │ │ │ │ └─ Log [id EF]\n", + " │ │ │ │ └─ Sqrt [id EG]\n", + " │ │ │ │ └─ 6.283185307179586 [id EH]\n", + " │ │ │ └─ Log [id EI]\n", + " │ │ │ └─ sigma_scale [id ED]\n", + " │ │ └─ All{axes=None} [id EJ]\n", + " │ │ └─ MakeVector{dtype='bool'} [id EK]\n", + " │ │ └─ All{axes=None} [id EL]\n", + " │ │ └─ Gt [id EM]\n", + " │ │ ├─ sigma_scale [id ED]\n", + " │ │ └─ 0 [id EN]\n", + " │ └─ Identity [id EO] 'sigma_log___log_jacobian'\n", + " │ └─ sigma_log__ [id EB]\n", + " └─ Sum{axes=None} [id EP] '__logp'\n", + " └─ MakeVector{dtype='float64'} [id EQ]\n", + " ├─ Sum{axes=None} [id ER]\n", + " │ └─ Check{sigma > 0} [id ES] 'X_logprob'\n", + " │ ├─ Sub [id ET]\n", + " │ │ ├─ Sub [id EU]\n", + " │ │ │ ├─ Mul [id EV]\n", + " │ │ │ │ ├─ ExpandDims{axes=[0, 1]} [id EW]\n", + " │ │ │ │ │ └─ -0.5 [id EX]\n", + " │ │ │ │ └─ Pow [id EY]\n", + " │ │ │ │ ├─ True_div [id EZ]\n", + " │ │ │ │ │ ├─ Sub [id FA]\n", + " │ │ │ │ │ │ ├─ X [id V]\n", + " │ │ │ │ │ │ └─ [[0]] [id FB]\n", + " │ │ │ │ │ └─ [[1]] [id FC]\n", + " │ │ │ │ └─ ExpandDims{axes=[0, 1]} [id FD]\n", + " │ │ │ │ └─ 2 [id FE]\n", + " │ │ │ └─ ExpandDims{axes=[0, 1]} [id FF]\n", + " │ │ │ └─ Log [id FG]\n", + " │ │ │ └─ Sqrt [id FH]\n", + " │ │ │ └─ 6.283185307179586 [id FI]\n", + " │ │ └─ Log [id FJ]\n", + " │ │ └─ [[1]] [id FC]\n", + " │ └─ All{axes=None} [id FK]\n", + " │ └─ MakeVector{dtype='bool'} [id FL]\n", + " │ └─ All{axes=None} [id FM]\n", + " │ └─ Gt [id FN]\n", + " │ ├─ [[1]] [id FC]\n", + " │ └─ ExpandDims{axes=[0, 1]} [id FO]\n", + " │ └─ 0 [id FP]\n", + " ├─ Sum{axes=None} [id FQ]\n", + " │ └─ Check{sigma > 0} [id FR] 'alpha_logprob'\n", + " │ ├─ Sub [id FS]\n", + " │ │ ├─ Sub [id FT]\n", + " │ │ │ ├─ Mul [id FU]\n", + " │ │ │ │ ├─ -0.5 [id FV]\n", + " │ │ │ │ └─ Pow [id FW]\n", + " │ │ │ │ ├─ True_div [id FX]\n", + " │ │ │ │ │ ├─ Sub [id FY]\n", + " │ │ │ │ │ │ ├─ alpha [id S]\n", + " │ │ │ │ │ │ └─ 100 [id FZ]\n", + " │ │ │ │ │ └─ 10 [id GA]\n", + " │ │ │ │ └─ 2 [id GB]\n", + " │ │ │ └─ Log [id GC]\n", + " │ │ │ └─ Sqrt [id GD]\n", + " │ │ │ └─ 6.283185307179586 [id GE]\n", + " │ │ └─ Log [id GF]\n", + " │ │ └─ 10 [id GA]\n", + " │ └─ All{axes=None} 
[id GG]\n", + " │ └─ MakeVector{dtype='bool'} [id GH]\n", + " │ └─ All{axes=None} [id GI]\n", + " │ └─ Gt [id GJ]\n", + " │ ├─ 10 [id GA]\n", + " │ └─ 0 [id GK]\n", + " ├─ Sum{axes=None} [id GL]\n", + " │ └─ Check{sigma > 0} [id GM] 'beta_logprob'\n", + " │ ├─ Sub [id GN]\n", + " │ │ ├─ Sub [id GO]\n", + " │ │ │ ├─ Mul [id GP]\n", + " │ │ │ │ ├─ ExpandDims{axis=0} [id GQ]\n", + " │ │ │ │ │ └─ -0.5 [id GR]\n", + " │ │ │ │ └─ Pow [id GS]\n", + " │ │ │ │ ├─ True_div [id GT]\n", + " │ │ │ │ │ ├─ Sub [id GU]\n", + " │ │ │ │ │ │ ├─ beta [id X]\n", + " │ │ │ │ │ │ └─ [0] [id GV]\n", + " │ │ │ │ │ └─ [5] [id GW]\n", + " │ │ │ │ └─ ExpandDims{axis=0} [id GX]\n", + " │ │ │ │ └─ 2 [id GY]\n", + " │ │ │ └─ ExpandDims{axis=0} [id GZ]\n", + " │ │ │ └─ Log [id HA]\n", + " │ │ │ └─ Sqrt [id HB]\n", + " │ │ │ └─ 6.283185307179586 [id HC]\n", + " │ │ └─ Log [id HD]\n", + " │ │ └─ [5] [id GW]\n", + " │ └─ All{axes=None} [id HE]\n", + " │ └─ MakeVector{dtype='bool'} [id HF]\n", + " │ └─ All{axes=None} [id HG]\n", + " │ └─ Gt [id HH]\n", + " │ ├─ [5] [id GW]\n", + " │ └─ ExpandDims{axis=0} [id HI]\n", + " │ └─ 0 [id HJ]\n", + " └─ Sum{axes=None} [id HK]\n", + " └─ Add [id HL] 'sigma_log___logprob'\n", + " ├─ Check{mu >= 0} [id HM]\n", + " │ ├─ Switch [id HN]\n", + " │ │ ├─ Ge [id HO]\n", + " │ │ │ ├─ Exp [id HP]\n", + " │ │ │ │ └─ sigma_log__ [id BA]\n", + " │ │ │ └─ 0.0 [id HQ]\n", + " │ │ ├─ Sub [id HR]\n", + " │ │ │ ├─ Neg [id HS]\n", + " │ │ │ │ └─ Log [id HT]\n", + " │ │ │ │ └─ 1.0 [id HU]\n", + " │ │ │ └─ True_div [id HV]\n", + " │ │ │ ├─ Exp [id HP]\n", + " │ │ │ │ └─ ···\n", + " │ │ │ └─ 1.0 [id HU]\n", + " │ │ └─ -inf [id HW]\n", + " │ └─ All{axes=None} [id HX]\n", + " │ └─ MakeVector{dtype='bool'} [id HY]\n", + " │ └─ All{axes=None} [id HZ]\n", + " │ └─ Ge [id IA]\n", + " │ ├─ 1.0 [id HU]\n", + " │ └─ 0 [id IB]\n", + " └─ Identity [id IC] 'sigma_log___log_jacobian'\n", + " └─ sigma_log__ [id BA]\n" + ] + }, + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "compute_loss(m_obs, guide_model).dprint()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d155685f", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "98fe5067", + "metadata": {}, + "outputs": [], + "source": [ + "## TODO:\n", + "# 1. Create hyperparameters for mean field approx (mu + sigma of normals)\n", + "# 2. 
Replace in the logp" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e01dcf96", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From a5f6921578281aeb1ce7dfc8b871e54f81b2a46e Mon Sep 17 00:00:00 2001 From: Jesse Grabowski Date: Sun, 25 May 2025 19:29:23 +0800 Subject: [PATCH 2/5] re-run notebook, show it actually works --- Rewrite VI, why not.ipynb | 589 +++++--------------------------------- 1 file changed, 75 insertions(+), 514 deletions(-) diff --git a/Rewrite VI, why not.ipynb b/Rewrite VI, why not.ipynb index f2a6c3794..6dda37553 100644 --- a/Rewrite VI, why not.ipynb +++ b/Rewrite VI, why not.ipynb @@ -2,11 +2,12 @@ "cells": [ { "cell_type": "code", - "execution_count": 6, + "execution_count": 1, "id": "9f946eb4", "metadata": {}, "outputs": [], "source": [ + "import numpy as np\n", "import pytensor.tensor as pt\n", "\n", "import pymc as pm" @@ -14,7 +15,7 @@ }, { "cell_type": "code", - "execution_count": 61, + "execution_count": 2, "id": "e746bc33", "metadata": {}, "outputs": [ @@ -41,14 +42,16 @@ }, { "cell_type": "code", - "execution_count": 62, + "execution_count": 3, "id": "a8ca0161", "metadata": {}, "outputs": [], "source": [ "draw = 123\n", "true_params = np.r_[\n", - " prior.prior.alpha.sel(chain=0, draw=draw).values, prior.prior.beta.sel(chain=0, draw=draw)\n", + " prior.prior.alpha.sel(chain=0, draw=draw).values,\n", + " prior.prior.beta.sel(chain=0, draw=draw),\n", + " prior.prior.sigma.sel(chain=0, draw=draw),\n", "]\n", "X_data = prior.prior.X.sel(chain=0, draw=draw).values\n", "y_data = prior.prior.y.sel(chain=0, draw=draw).values" @@ -56,7 +59,7 @@ }, { "cell_type": "code", - "execution_count": 63, + "execution_count": 4, "id": "b89f4031", "metadata": {}, "outputs": [], @@ -66,7 +69,7 @@ }, { "cell_type": "code", - "execution_count": 168, + "execution_count": 5, "id": "a42a84e4", "metadata": {}, "outputs": [], @@ -108,7 +111,7 @@ }, { "cell_type": "code", - "execution_count": 169, + "execution_count": 6, "id": "bffb1b69", "metadata": {}, "outputs": [], @@ -118,53 +121,7 @@ }, { "cell_type": "code", - "execution_count": 171, - "id": "e2c4dd95", - "metadata": {}, - "outputs": [], - "source": [ - "f_draw = pm.compile([*params, draws], guide_model.deterministics)" - ] - }, - { - "cell_type": "code", - "execution_count": 182, - "id": "bfe44b73", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "[array([1.78150955]),\n", - " array([[ 0.22971278, 0.4621461 , 0.81535912, 0.62397751, 1.12162984,\n", - " 0.99310042, -0.04733258, 1.20791346, 0.61310399, 0.6248215 ]]),\n", - " array([1.04612599])]" - ] - }, - "execution_count": 182, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "f_draw(**param_dict, draws=1)" - ] - }, - { - "cell_type": "code", - "execution_count": 204, - "id": "e831865f", - "metadata": {}, - "outputs": [], - "source": [ - "init_dict = m_obs.initial_point()\n", - "init_dict = {k: np.expand_dims(v, 0) for k, v in init_dict.items()}\n", - "param_dict = {param.name: np.full(param.type.shape, 0.5) for param in params}" - ] - }, - { - "cell_type": "code", - 
"execution_count": 187, + "execution_count": 7, "id": "a9f3fa0e", "metadata": {}, "outputs": [], @@ -181,547 +138,151 @@ "model_logp = vectorize_graph(m_obs.logp(), inputs_to_guide_rvs)\n", "guide_logq = graph_replace(guide_model.logp(), guide_model.values_to_rvs)\n", "\n", - "elbo_loss = (guide_logq - model_logp).mean()\n", - "d_loss = pt.grad(elbo_loss, params)\n", + "negative_elbo = (guide_logq - model_logp).mean()\n", + "d_loss = pt.grad(negative_elbo, params)\n", "\n", - "f_loss_dloss = pm.compile(params + [draws], [elbo_loss, *d_loss], trust_input=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 207, - "id": "6086f2cc", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "-993.9173656892265\n", - "-926.3296229583198\n", - "-1005.6712508438249\n", - "-919.6431050083108\n", - "-968.0575921556582\n", - "-987.30412574489\n", - "-964.7096352135038\n", - "-983.4362084632743\n", - "-910.8310347546233\n", - "-946.7555282915881\n", - "-943.9616597508198\n", - "-963.8766211135949\n", - "-916.6993999462212\n", - "-950.0475547851752\n", - "-967.001546433409\n", - "-943.9433767873104\n", - "-946.3015434524876\n", - "-934.0032018513697\n", - "-947.6452058569878\n", - "-893.224934118183\n", - "-979.4405814988748\n", - "-937.9997192780036\n", - "-931.1970724735944\n", - "-959.3588230003159\n", - "-932.1233734322371\n", - "-940.2791556640857\n", - "-969.4679954671045\n", - "-954.6606993395337\n", - "-982.9304227234845\n", - "-935.3404389982638\n", - "-982.6885250322749\n", - "-964.4628736035113\n", - "-939.0580477804804\n", - "-955.1672719181267\n", - "-982.0467504680682\n", - "-992.8688427985264\n", - "-967.8588846826874\n", - "-966.7655668600194\n", - "-949.3323016540423\n", - "-934.3919364553586\n", - "-1028.4493525361129\n", - "-982.4944127954707\n", - "-931.9404059809432\n", - "-981.3845063690508\n", - "-930.1688452196312\n", - "-952.0305908505434\n", - "-1012.7969628343917\n", - "-937.6379090307294\n", - "-926.5273721862864\n", - "-981.6665090046366\n", - "-980.4287957334458\n", - "-946.1849036479391\n", - "-969.4075653507354\n", - "-958.7555610935674\n", - "-1004.2802914265156\n", - "-963.0682736675113\n", - "-1000.664222867785\n", - "-934.7298341549889\n", - "-978.1088760452403\n", - "-944.0432566257858\n", - "-980.5576704437719\n", - "-969.5612703209024\n", - "-963.9839441465232\n", - "-966.4944698264674\n", - "-953.2780303578338\n", - "-953.5548215260793\n", - "-959.5050408283469\n", - "-989.7240037630827\n", - "-949.1398757808872\n", - "-964.8820244293524\n", - "-940.224485134473\n", - "-947.7591586789077\n", - "-964.679359616359\n", - "-917.0015516708557\n", - "-939.5619779661295\n", - "-993.9876761804422\n", - "-986.3960585411946\n", - "-975.6545633352575\n", - "-957.4304264740541\n", - "-970.7472516497804\n", - "-956.5223102168254\n", - "-969.825369100695\n", - "-987.9141239728567\n", - "-952.6424507446926\n", - "-958.6852758514054\n", - "-965.0105265156106\n", - "-962.943175120098\n", - "-946.9378229990178\n", - "-944.4992023963096\n", - "-947.4171539678032\n", - "-959.1841786549261\n", - "-931.9541021582517\n", - "-972.675454084302\n", - "-1030.6428849456215\n", - "-983.2067258579275\n", - "-1002.0505700409622\n", - "-1015.4906878612064\n", - "-962.2604231546592\n", - "-955.1833250683931\n", - "-938.1777960392025\n" - ] - } - ], - "source": [ - "learning_rate = 1e-4\n", - "for _ in range(100):\n", - " loss, *grads = f_loss_dloss(**param_dict, draws=100)\n", - " for (name, value), grad in zip(param_dict.items(), grads):\n", - 
" param_dict[name] = value - learning_rate * grad\n", - " print(loss)" + "f_loss_dloss = pm.compile(params + [draws], [negative_elbo, *d_loss], trust_input=True)" ] }, { "cell_type": "code", - "execution_count": 208, - "id": "1251b6f1", + "execution_count": 8, + "id": "37ae25b1", "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "{'alpha_loc': 0.6942395020077371,\n", - " 'alpha_scale': 0.5102778323184816,\n", - " 'beta_loc': array([0.49114368, 0.46379742, 0.52277764, 0.53227815, 0.48851862,\n", - " 0.50044113, 0.5441339 , 0.46643231, 0.47894475, 0.51122713]),\n", - " 'beta_scale': array([0.49807094, 0.49740464, 0.5005586 , 0.4991539 , 0.49749037,\n", - " 0.49825551, 0.49992182, 0.4981251 , 0.49745959, 0.49767778]),\n", - " 'sigma_loc': 4.2148859620931765,\n", - " 'sigma_scale': -5.964631972194452e-05}" - ] - }, - "execution_count": 208, - "metadata": {}, - "output_type": "execute_result" - } - ], + "outputs": [], "source": [ - "param_dict" + "init_dict = m_obs.initial_point()\n", + "init_dict = {k: np.expand_dims(v, 0) for k, v in init_dict.items()}\n", + "param_dict = {param.name: np.full(param.type.shape, 0.5) for param in params}" ] }, { "cell_type": "code", - "execution_count": 209, - "id": "11345fb4", + "execution_count": 9, + "id": "6086f2cc", "metadata": {}, "outputs": [ { - "data": { - "text/plain": [ - "array([83.43771778, -6.07738876, -3.3268889 , 8.38393732, 10.77212434,\n", - " -2.81776509, 0.46737085, 8.7204497 , -4.79822835, -3.47220908,\n", - " 8.76186526])" - ] - }, - "execution_count": 209, - "metadata": {}, - "output_type": "execute_result" + "name": "stdout", + "output_type": "stream", + "text": [ + "-8391.5148860327435\r" + ] } ], "source": [ - "true_params" + "learning_rate = 1e-5\n", + "n_iter = 60_000\n", + "loss_history = np.empty(n_iter)\n", + "for i in range(n_iter):\n", + " loss, *grads = f_loss_dloss(**param_dict, draws=500)\n", + " loss_history[i] = loss\n", + " for (name, value), grad in zip(param_dict.items(), grads):\n", + " param_dict[name] = (value - learning_rate * grad).copy()\n", + " if i % 50 == 0:\n", + " print(loss, end=\"\\r\")\n", + " if i % 10_000 == 0 and i > 0:\n", + " learning_rate = min(learning_rate * 10, 1e-3)" ] }, { "cell_type": "code", - "execution_count": 165, - "id": "566105c2", + "execution_count": 10, + "id": "650c5e39", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "array(273519.9558606)" + "[]" ] }, - "execution_count": 165, + "execution_count": 10, "metadata": {}, "output_type": "execute_result" - } - ], - "source": [ - "f_loss(**param_dict, draws=100)" - ] - }, - { - "cell_type": "code", - "execution_count": 166, - "id": "c39cd24e", - "metadata": {}, - "outputs": [ + }, { "data": { + "image/png": 
"iVBORw0KGgoAAAANSUhEUgAAAj0AAAGdCAYAAAD5ZcJyAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAPQxJREFUeJzt3XtclWW+///34rRAAiSWgnhCOmpUNmCKVow5SqllO6fGqTHZ09gmpRM631JrcnQ8/EbH3c5d2ZRZTe5pdmPu8ZSimZpGKh4KD0kpKB6QUAOPHK/fH8qdS1HBXCzgfj0fj/vxYN3rs+77WlcU7+77uu7LYYwxAgAAaOJ8vN0AAACA+kDoAQAAtkDoAQAAtkDoAQAAtkDoAQAAtkDoAQAAtkDoAQAAtkDoAQAAtuDn7QY0FFVVVdq/f79CQkLkcDi83RwAAFALxhgdPXpU0dHR8vG5+LUcQs8Z+/fvV9u2bb3dDAAAcBny8/PVpk2bi9YQes4ICQmRdLrTQkNDvdwaAABQGyUlJWrbtq31d/xiCD1nVN/SCg0NJfQAANDI1GZoCgOZAQCALRB6AACALRB6AACALRB6AACALRB6AACALRB6AACALRB6AACALRB6AACALRB6AACALRB6AACALRB6AACALRB6AACALbDgqIfN2bBX2fuKdU9clLrFRni7OQAA2BZXejxsZc73eveLPG3bX+LtpgAAYGuEHgAAYAuEHgAAYAuEHgAAYAuEHgAAYAuEHgAAYAuEnnpivN0AAABsjtDjYQ6Ht1sAAAAkQg8AALAJQg8AALAFQg8AALAFQg8AALAFQk89MYb5WwAAeBOhx8OYvAUAQMNA6AEAALZA6AEAALZA6AEAALZA6AEAALZA6AEAALZA6AEAALZA6PEwByuOAgDQIHg89CxcuFBdu3ZVUFCQXC6XHnzwQbf39+zZo/vuu0/BwcFyuVx6+umnVVZW5laTnZ2tpKQkBQUFqXXr1ho3btx5D/tbuXKl4uPjFRgYqNjYWM2YMcPTXw0AADQifp48+Jw5czR06FBNnDhRd999t4wxys7Ott6vrKxUv3791KJFC61evVqHDh3SkCFDZIzR9OnTJUklJSXq3bu3evbsqfXr1ysnJ0cpKSkKDg7WiBEjJEm5ubnq27evhg4dqg8++EBr1qzRsGHD1KJFCw0cONCTXxEAADQSHgs9FRUVeuaZZzRlyhQ9/vjj1v4bbrjB+jkjI0Pbtm1Tfn6+oqOjJUl/+ctflJKSogkTJig0NFSzZ8/WqVOn9O6778rpdCouLk45OTmaNm2a0tPT5XA4NGPGDLVr106vvPKKJKljx47KysrS1KlTCT0AAECSB29vbdy4Ufv27ZOPj49uu+02tWrVSvfee6+2bt1q1WRmZiouLs4KPJKUnJys0tJSbdiwwapJSkqS0+l0q9m/f7/y8vKsmj59+ridPzk5WVlZWSovL6+xfaWlpSopKXHbAABA0+Wx0LNr1y5J0tixY/Xiiy9qwYIFCg8PV1JSkg4fPixJKigoUGRkpNvnwsPDFRAQoIKCggvWVL++VE1FRYWKiopqbN+kSZMUFhZmbW3btv2J3xgAADRkdQ49Y8eOlcPhuOiWlZWlqqoqSdKYMWM0cOBAxcfHa9asWXI4HProo4+s49U0u8kY47b/3JrqQcx1rTnbqFGjVFxcbG35+fl16YY6Y5F1AAC8q85jetLS0jRo0KCL1sTExOjo0aOSpE6dOln7nU6nYmNjtWfPHklSVFSU1q5d6/bZI0eOqLy83LpyExUVZV3RqVZYWChJl6zx8/NTREREjW10Op1ut8w8hQnrAAA0DHUOPS6XSy6X65J18fHxcjqd2rFjh+644w5JUnl5ufLy8tS+fXtJUmJioiZMmKADBw6oVatWkk4PbnY6nYqPj7dqRo8erbKyMgUEBFg10dHRiomJsWrmz5/vdv6MjAwlJCTI39+/rl8RAAA0QR4b0xMaGqrU1FS9/PLLysjI0I4dO/Tkk09Kkh566CFJUp8+fdSpUycNHjxYmzZt0qeffqqRI0dq6NChCg0NlSQ98sgjcjqdSklJ0ZYtWzR37lxNnDjRmrklSampqdq9e7fS09O1fft2vfPOO5o5c6ZGjhzpqa8HAAAaGY8+p2fKlCny8/PT4MGDdfLkSXXt2lXLly9XeHi4JMnX11cLFy7UsGHD1KNHDwUFBemRRx7R1KlTrWOEhYVp6dKlGj58uBISEhQeHq709HSlp6dbNR06dNCiRYv03HPP6bXXXlN0dLReffVVpqsDAACLw5z7aGObKikpUVhYmIqLi62rTFdC+j826+NN+zSmb0cNvSv2ih0XAADU7e83a2/VEyOyJQAA3kTo8TSmbwEA0CAQegAAgC0QegAAgC0QegAAgC0QegAAgC0QegAAgC0QeuoJT0MCAMC7CD0e5mDOOgAADQKhBwAA2AKhBwAA2AKhBwAA2AKhBwAA2AKhp54weQsAAO8i9HiYg8lbAAA0CIQeAABgC4QeAABgC4QeAABgC4QeAABgC4QeAABgC4SeesKCowAAeBehx8OYsQ4AQMNA6AEAALZA6AEAALZA6AEAALZA6AEAALZA6AEAALZA6KknhnXWAQDwKkKPh7HKOgAADQOhBwAA2AKhBwAA2AKhBwAA2AKhBwAA2AKhp56w4CgAAN5F6PEwB0uOAgDQIBB6AACALRB6AACALRB6AACALRB6AACALRB6AACALRB6AACALRB6PIwFRwEAaBgIPQAAwBYIPQAAwBYIPQAAwBYIPQAAwBYIPQAAwBYIPfXEsMw6AABeRejxMKasAwDQMBB6AACALRB6AACALRB6AACALRB6AACALRB66gmTtwAA8C6Ph56FCxeqa9euCgoKksvl0oMPPuj2vsPhOG+bMWOGW012draSkpIUFBSk1q1ba9y4cedNAV+5cqXi4+MVGBio2NjY847hPUzfAgCgIfDz5MHnzJmjoUOHauLEibr77rtljFF2dvZ5dbNmzdI999xjvQ4LC7N+LikpUe/evdWzZ0+tX79eOTk5SklJUXBwsEaMGCFJys3NVd++fTV06FB98MEHWrNmjYYNG6YWLVpo4MCBnvyKAACgkfBY6KmoqNAzzzyjKVOm6PHHH7f233DDDefVNm/eXFFRUTUeZ/bs2Tp16pTeffddOZ1OxcXFKScnR9OmTVN6erp1Zahdu3Z65ZVXJEkdO3ZUVlaWpk6dSugBAACSPHh7a+PGjdq3b598fHx02223qVWrVrr33nu1devW82rT0tLkcrnUpUsXzZgxQ1VVVdZ7mZmZSkpKktPptPYlJydr//79ysvLs2r69Onjdszk5GRlZWWpvLy8xvaVlpaqpKTEbQMAAE2Xx0LPrl27JEljx47Viy++qAULFig8PFxJSUk6fPiwVTd+/Hh99NFHWrZsmQYNGqQRI0Zo4sSJ1vsFBQWKjIx0O3b164KCgovWVFRUqKioqMb2TZo0SWFhYdbWtm3bn/6lAQBAg1Xn0DN27NgaBx+fvWVlZVlXa8aMGaOBAwcqPj
5es2bNksPh0EcffWQd78UXX1RiYqI6d+6sESNGaNy4cZoyZYrbOR3nrOVQPYj57P21qTnbqFGjVFxcbG35+fl17QoAANCI1HlMT1pamgYNGnTRmpiYGB09elSS1KlTJ2u/0+lUbGys9uzZc8HPduvWTSUlJTp48KAiIyMVFRVlXdGpVlhYKOnHKz4XqvHz81NERESN53E6nW63zDyNGesAAHhXnUOPy+WSy+W6ZF18fLycTqd27NihO+64Q5JUXl6uvLw8tW/f/oKf27RpkwIDA9W8eXNJUmJiokaPHq2ysjIFBARIkjIyMhQdHa2YmBirZv78+W7HycjIUEJCgvz9/ev6Fa8oFhwFAKBh8NiYntDQUKWmpurll19WRkaGduzYoSeffFKS9NBDD0mS5s+fr7feektbtmzRzp079fbbb2vMmDF64oknrKswjzzyiJxOp1JSUrRlyxbNnTtXEydOtGZuSVJqaqp2796t9PR0bd++Xe+8845mzpypkSNHeurrAQCARsajz+mZMmWK/Pz8NHjwYJ08eVJdu3bV8uXLFR4eLkny9/fX66+/rvT0dFVVVSk2Nlbjxo3T8OHDrWOEhYVp6dKlGj58uBISEhQeHq709HSlp6dbNR06dNCiRYv03HPP6bXXXlN0dLReffVVpqsDAACLw5z7aGObKikpUVhYmIqLixUaGnrFjjt6brb+Z+0epfe+Xk/3uu6KHRcAANTt7zdrbwEAAFsg9AAAAFsg9NQTbiICAOBdhB4PY8Y6AAANA6EHAADYAqEHAADYAqEHAADYAqEHAADYAqGnnhiWHAUAwKsIPR7GgqMAADQMhB4AAGALhB4AAGALhB4AAGALhB4AAGALhB4AAGALhJ56woKjAAB4F6HHwxwsOQoAQINA6AEAALZA6AEAALZA6AEAALZA6AEAALZA6AEAALZA6KknzFgHAMC7CD0exirrAAA0DIQeAABgC4QeAABgC4QeAABgC4QeAABgC4Se+sKKowAAeBWhx8OYvAUAQMNA6AEAALZA6AEAALZA6AEAALZA6AEAALZA6AEAALZA6KknTFgHAMC7CD0e5mDFUQAAGgRCDwAAsAVCDwAAsAVCDwAAsAVCDwAAsAVCDwAAsAVCTz1hkXUAALyL0AMAAGyB0AMAAGyB0AMAAGyB0AMAAGyB0AMAAGyB0FNPDEuOAgDgVYQeD2O9UQAAGgZCDwAAsAVCDwAAsAVCDwAAsAVCDwAAsAWPhZ4VK1bI4XDUuK1fv96q27Nnj+677z4FBwfL5XLp6aefVllZmduxsrOzlZSUpKCgILVu3Vrjxo2TOWcxq5UrVyo+Pl6BgYGKjY3VjBkzPPXVAABAI+TnqQN3795dBw4ccNv30ksvadmyZUpISJAkVVZWql+/fmrRooVWr16tQ4cOaciQITLGaPr06ZKkkpIS9e7dWz179tT69euVk5OjlJQUBQcHa8SIEZKk3Nxc9e3bV0OHDtUHH3ygNWvWaNiwYWrRooUGDhzoqa9YJyw4CgCAd3ks9AQEBCgqKsp6XV5ernnz5iktLU2OM/O4MzIytG3bNuXn5ys6OlqS9Je//EUpKSmaMGGCQkNDNXv2bJ06dUrvvvuunE6n4uLilJOTo2nTpik9PV0Oh0MzZsxQu3bt9Morr0iSOnbsqKysLE2dOtXrocch5qwDANAQ1NuYnnnz5qmoqEgpKSnWvszMTMXFxVmBR5KSk5NVWlqqDRs2WDVJSUlyOp1uNfv371deXp5V06dPH7fzJScnKysrS+Xl5TW2p7S0VCUlJW4bAABouuot9MycOVPJyclq27atta+goECRkZFudeHh4QoICFBBQcEFa6pfX6qmoqJCRUVFNbZn0qRJCgsLs7az2wUAAJqeOoeesWPHXnCAcvWWlZXl9pm9e/dqyZIlevzxx887nqOGRxYbY9z2n1tTPYi5rjVnGzVqlIqLi60tPz//Yl8bAAA0cnUe05OWlqZBgwZdtCYmJsbt9axZsxQREaH777/fbX9UVJTWrl3rtu/IkSMqLy+3rtxERUVZV3SqFRYWStIla/z8/BQREVFjG51Op9stMwAA0LTVOfS4XC65XK5a1xtjNGvWLD322GPy9/d3ey8xMVETJkzQgQMH1KpVK0mnBzc7nU7Fx8dbNaNHj1ZZWZkCAgKsmujoaCtcJSYmav78+W7HzsjIUEJCwnnn9BYmbwEA4F0eH9OzfPly5ebm1nhrq0+fPurUqZMGDx6sTZs26dNPP9XIkSM1dOhQhYaGSpIeeeQROZ1OpaSkaMuWLZo7d64mTpxozdySpNTUVO3evVvp6enavn273nnnHc2cOVMjR4709Ne7JBYcBQCgYfB46Jk5c6a6d++ujh07nveer6+vFi5cqMDAQPXo0UMPP/ywHnjgAU2dOtWqCQsL09KlS7V3714lJCRo2LBhSk9PV3p6ulXToUMHLVq0SCtWrFDnzp01fvx4vfrqq16frg4AABoOhzn30cY2VVJSorCwMBUXF1tXma6E8Qu2aebqXD3582v0/D03XrHjAgCAuv39Zu0tAABgC4QeAABgC4QeAABgC4SeesLIKQAAvIvQ42HMWAcAoGEg9AAAAFsg9AAAAFsg9AAAAFsg9AAAAFsg9AAAAFsg9NQTwzrrAAB4FaHHw1hlHQCAhoHQAwAAbIHQAwAAbIHQAwAAbIHQAwAAbIHQU1+YvAUAgFcRejzMwfQtAAAaBEIPAACwBUIPAACwBUIPAACwBUIPAACwBUIPAACwBUJPPWHGOgAA3kXo8TAmrAMA0DAQegAAgC0QegAAgC0QegAAgC0QegAAgC0QegAAgC0QeuqJMUxaBwDAmwg9nsacdQAAGgRCDwAAsAVCDwAAsAVCDwAAsAVCDwAAsAVCTz1h8hYAAN5F6PEwB9O3AABoEAg9AADAFgg9AADAFgg9AADAFgg9AADAFgg9AADAFgg99YQZ6wAAeBehx8N8z/RwFQ/qAQDAqwg9HubjOP2cHjIPAADeRejxMMeZ0MOVHgAAvIvQ42E+Zx7ITOgBAMC7CD0eVn17q7LKyw0BAMDmCD0e5utTPaaHKz0AAHgTocfDHNzeAgCgQSD0eJiPNZDZyw0BAMDmCD0exkBmAAAaBkKPh1lXerjUAwCAV3ks9KxYsUIOh6PGbf369VZdTe/PmDHD7VjZ2dlKSkpSUFCQWrdurXHjxp03MHjlypWKj49XYGCgYmNjzzuGt3B7CwCAhsHPUwfu3r27Dhw44LbvpZde0rJly5SQkOC2f9asWbrnnnus12FhYdbPJSUl6t27t3r27Kn169crJydHKSkpCg4O1ogRIyRJubm56tu3r4YOHaoPPvhAa9as0bBhw9SiRQsNHDjQU1+xVri9BQBAw+Cx0BMQEKCoqCjrdXl5uebNm6e0tDTrKcXVmjdv7lZ7ttmzZ+vUqVN699135XQ6FRcXp5ycHE2bNk3p6enWlaF27drplVdekSR17NhRWVlZmjp1qvdDjw/LUAAA0BDU25ieefPmqaioSCkpKee9l5aWJpfLpS5dumjGj
BmqqvrxSX6ZmZlKSkqS0+m09iUnJ2v//v3Ky8uzavr06eN2zOTkZGVlZam8vLzG9pSWlqqkpMRt8wSWoQAAoGGot9Azc+ZMJScnq23btm77x48fr48++kjLli3ToEGDNGLECE2cONF6v6CgQJGRkW6fqX5dUFBw0ZqKigoVFRXV2J5JkyYpLCzM2s5t15VSfXurkkE9AAB4VZ1Dz9ixYy84QLl6y8rKcvvM3r17tWTJEj3++OPnHe/FF19UYmKiOnfurBEjRmjcuHGaMmWKW825t8OqBzGfvb82NWcbNWqUiouLrS0/P7+WPVA3vgxkBgCgQajzmJ60tDQNGjToojUxMTFur2fNmqWIiAjdf//9lzx+t27dVFJSooMHDyoyMlJRUVHWFZ1qhYWFkn684nOhGj8/P0VERNR4HqfT6XbLzFOqZ2+xDAUAAN5V59DjcrnkcrlqXW+M0axZs/TYY4/J39//kvWbNm1SYGCgmjdvLklKTEzU6NGjVVZWpoCAAElSRkaGoqOjrXCVmJio+fPnux0nIyNDCQkJtTqnJ7EMBQAADYPHx/QsX75cubm5Nd7amj9/vt566y1t2bJFO3fu1Ntvv60xY8boiSeesK7CPPLII3I6nUpJSdGWLVs0d+5cTZw40Zq5JUmpqanavXu30tPTtX37dr3zzjuaOXOmRo4c6emvd0k8pwcAgIbBY1PWq82cOVPdu3dXx44dz3vP399fr7/+utLT01VVVaXY2FiNGzdOw4cPt2rCwsK0dOlSDR8+XAkJCQoPD1d6errS09Otmg4dOmjRokV67rnn9Nprryk6Olqvvvqq16erS5LPmVjJlR4AALzLYRhsIun0QxDDwsJUXFys0NDQK3bcf23ep2c+3Kwe10Zo9u+6XbHjAgCAuv39Zu0tD/tx7S0vNwQAAJsj9HiYDw8nBACgQSD0eBhrbwEA0DAQejzMwewtAAAaBEKPh/meudTDMhQAAHgXocfD/HxPh54KRjIDAOBVhB4Pc/qe7uKyCkIPAADeROjxsAA/Qg8AAA0BocfDCD0AADQMhB4Ps0JPJaEHAABvIvR4WMCZMT2lXOkBAMCrCD0exu0tAAAaBkKPhwX6+0o6faWHZ/UAAOA9hB4PCwn0s34+dqrCiy0BAMDeCD0e5vTzlfPMLa6SU+Vebg0AAPZF6KkHoUH+kgg9AAB4E6GnHlTf4jrK7S0AALyG0FMPqsfyfH+01MstAQDAvgg99aDwTNj525e7vdwSAADsi9BTD1qFBUqSrmt5lZdbAgCAfRF66sGvb28nSaqo5Dk9AAB4C6GnHjQLOP2Awn9k5eufG/Zqfd5hSdKx0poHNn/2TaGyztTU5MN1exTzwkKdKq+88o3V6adHG1O7gHayrFIny+rejlPlldr1/bE6f67a2l2HdKLs8gaGV1RWae+RE5d9bgBA4+R36RL8VGc/iXnkR1+d9/68tB7K2HpQuYeOa+HXB6z92WP7KNDfV/6+Pop5YaGucvppUJe2ent1riTpxpcWK29yP50sq1THPyx2O+aCp+5Qp1ahih29SJK08vc/V9KUFeed++3HEnT4RJn+3z+/1mOJ7ZW9r1ib9vxgHSP/8Akt+PqAFmafbldEcIDuuM6lV37VWcNmb9QnWwokSf/5q1v13D++0qAubTV54C3nnae8skpHTpTp9gmf6u4bW2r5N4WSpNgWwfrTA3FKjI2Qw+HQvh9O6kRpha6LDLE++13hMQ3479U6Xlap/xveQ39dtVOLsk+fd92YXtpz6ITe/SJPv+nWXt1iI/T//vmV/jdrrxY8dYeKT5br0bfX6rORP1dMRDM5HA49/GamNp75jrmT+urjjfvUPqKZfjkjUw6HZIy05Nm7dH3kVXI4HBf852qM0cnySjULqPlfo/LKKv1r837dGxelAD8ffZX/gzq3bS4/3x//X6Oqyijv0HF1cAVr75GTuvPPn0mSpj50q5ZtO6hRfW9U+4jg8449LWOHbmsXro6tQhUZ6rxoOwEApzlMbf+XvokrKSlRWFiYiouLFRoaekWPXVlldM2Z8IHay3juLvX5z1VebcPcYd01c3WuHvxZa/323Sxr/1d/6KNbx2W41f4yvo0S2ofr5ze01Pq8w3rq75sueNx/PNFN0c2DrJBzMZte6q3w4ABVVRn5+Dj03D82a+6mfW41MRHNtPjZu6xlT+Zu2qsuMVerTXizunxdAGh06vL3m9BzhidDjyTFvLDwih8TONctbcI0oHNrjV+wzdqXN7mfF1sEAJ5F6LkMng49p8ortWRrgbrFRqjrxE+v+PGBuvhlfBtNfehW63VFZZWMJP8zt95yi45r8ZYCPXFXrHx9uHUGoOEi9FwGT4eeC3nw9TXW+JJOrUL1r7Qe1tiXc+2a2FdHTpQp79BxFZ8st2633N7haq3LPaxvJ9wrf18f/XnxN3p9xc4az5c3uV+NV53+Z2hXPfLW2ho/M+fJ7vpZu+YqrahSaUWVUv+2QZm7Dmnaw7fqwZ+1kSTd+eflujrYqf/+9W2atjTnvNsvF2qLJC3ZWqD/+NuGi9a+++9dlDJrvfW6S0y4JGl93hG3uifuitW/Nu9Tv5uj9c6aXC1+9k7dEBmi42WVuueVVdp75KR+062dPvhyT43nad08SPt+OHnJtjdV//3Ibbrz2hbWrbvkmyL15uAEtxpjjKqMCEMAGgRCz2XwVuiRTs/W+v5oqR5KaGMNSC2rqFLRsVIF+fuqeTP/Og9UNcbow/X56ty2uZ75cJNyDh7T4mfv1I1RoTLGaOv+EnVqFSqfs/5w5RYdV8+pKzTr37uo5w0tVXKqXM38fd0G3v5UVVVGb6/epV8ltFNYM/9afY/q7179q3puX5ScKpePw6GrnLUfl19WUaUvdhbphqgQNfP3k4+PZCSFBp5uU3llle55ZZV2fn9ckjT+gTj9Mytf7/777co7dFz/9voX1rFm/CZeqR+4h7a3H0vQLzpFqrLKaMnWAjVv5q8bo0L1s/FLrZodf7pHTj9fK4TO/l1X9bjWpdvGZejIidPrtG39Y7JuenmJJCm99/ValfO9snafDno7J/ZVt0mfevxJ3107XK21uYfl5+NQxVmD8ockttcfB8R59NwAcCmEnsvgzdCDhmtd7mHFuJqpZUig2/7Dx8v06faDuu/WaGvwcG7RcYUG+iniKucFj2eM0YKvD+iu61pcMPRVVRnNXrdH/W9upfDggIu2zxijZdsL1Sk6VD0mL1e32Kv15a7DevuxBP3u/ayLfvZKm/GbnynxGpe27S/Rr9/6UtLpWXBB/r5qF9FMJ8sqdeREmaKbB9VruwA0bYSey0DoQVNz7qMM0npeq0+2HNDPb2ip+2+N1oDX1tRbW/72+O0aPHOdJKnHtRGa/btu9XZuAE0boecyEHpgR9W31nZO7CtfH4eSpnym3YdOP7hxTN+OmrBou0fOu3NiXw36a+Z5Y7K+nXCvlm47qLbhzdS8mb827jmi+26JVvHJ8kte9QJgT4Sey0DoAU4rPlGukEA/+fg4VHj0lNbnHtHv//mVTpRVKqV7
jMbef5O27CtW/+mr671tz/7iOiXfFKXo5kEKC/JXeWWVfB0Ot7FpAOyF0HMZCD3A5TPGqMOoHx/AWdPDGz1pzQt3q3XzII2em63/WXt6Zh7PJwLsgdBzGQg9wJVljFHh0VIFO/10orRCYc389cs3MpW9r7hezp/14i/08ryt6hvXSv1uaVUv5wRQ/wg9l4HQA3ieMUalFVVasrVAOwuP6dlfXK+yyir9fd0e/XH+tksf4Cd667EEdb8mQl/vLVbmziKl97nhvJr8wye0+9AJ3XGdy+PtAfDTEXouA6EH8K6Ji7brhxNl+vMvb9V3hce0Kud73RMXJT8fh1qGBmrUx1/r7+vy6609E/4tToO6tFP+4RNqf2axWgAND6HnMhB6gMbpl298YT2w0ZOCA3x1vKxST/78Gj1/z43W/vzDJ5Sx7aAeS2xvLeMBoP4Qei4DoQdo3E6VV+pkWaWaN/N3G1TtKTsn9lVZRZX1LKSftWuuj4f18Ph5Abgj9FwGQg/QdBSfLNetfzw9e2zXxL5yOOQWhCJDnTpY4tnlOyQpJqKZZg/tptbNg3SqvFI3j12iZ39xvbbuL9Z//qqznH6+Hm8D0NQRei4DoQdo+korKi8aNMorq+Tn43ALSH1vjtKi7AKPtGfrH5MV5O/Lc4aAn4DQcxkIPQAu5MX/y9YHX+7R6ud7asnWgxq/4MrPNFvw1B3WAx8H/qyN5mzcK0lK6R6jd7/I0z+e6KausRFX/LxAY0fouQyEHgC1VXyi3O3hi99OuFfXjfnE4+d999+76J01eXrwtta6Jy7KWuwWsDNCz2Ug9ACoqy92FqmDK1itwoK08/tj+nLXIf26Szu321XV65tJ0qNd22n2mSdG/1SDurTV+AfidKq8UiGB/lfkmEBjROi5DIQeAJ5QUVml5d8Uqvu1Ll3l9FPxiXI5fKRbxnpmmY7f9uigNuFBuuM6l65tcZUk6UDJKe05dEKJ13B7DE0PoecyEHoAeMvZV4Oqvdivo/608Mqucn/ndS69/9vbdeh4mVxXOa/osQFvIfRcBkIPAG85eqpcn2QX6OEubWWMUZWRfM+6RfbtwaPq/Z+rJEmLn71T97zy+U8+559/eYuuaXGV9hw+rgc6t+aJ02i0CD2XgdADoLFI/9/N+njjvit6zF43ttQL996ooABfBfr7KuFPyySxWj0aPkLPZSD0AGhsjDFyOBw6fLxM+384qfe+yFP2vmJ9U3D0ip3jpuhQLXjqDnUYtUjhzfy16Q99rtixgSuB0HMZCD0AmgpjjA4Un1J08yDr9ZVamuPB21pr4oM3M10eDUZd/n6zOh4ANDEOh8MKPNWv30lJUOvmQdr0Um/dEBlivdfrxpZ1OvbHm/bpxpcW6/ujpaqsMqqorLpi7QY8jSs9Z3ClB4CdnCqv1KHjZWrdPEh/WrBNb6/O1Z8H3qL/N+fryzrem4Pj1btjJEtqoN5xe+syEHoAQNq2v0SfbDmgZ3pdJ4fDodsnLJOvj0N9b26ld7/Iu+Tnf5XQVh1aBGvyJ98o47m7dP1ZV5UATyD0XAZCDwBcXPbeYlUaowdeW1Pnz77329uVdH0LD7QKdtdgxvTk5ORowIABcrlcCg0NVY8ePfTZZ5+51ezZs0f33XefgoOD5XK59PTTT6usrMytJjs7W0lJSQoKClLr1q01btw4nZvVVq5cqfj4eAUGBio2NlYzZszw5FcDANu5uU2YOrdtrp0T+9b5s0PeWaf5X+33QKuA2vPz5MH79eun66+/XsuXL1dQUJBeeeUV9e/fXzt37lRUVJQqKyvVr18/tWjRQqtXr9ahQ4c0ZMgQGWM0ffp0SacTXO/evdWzZ0+tX79eOTk5SklJUXBwsEaMGCFJys3NVd++fTV06FB98MEHWrNmjYYNG6YWLVpo4MCBnvyKAGA7vj4Ot+f3bNh9RAPf+OKSn3vq75v01N83KTLUqYMlpZKk3El9VXSsTC1CeEI0PM9jt7eKiorUokULrVq1Snfeeack6ejRowoNDdWyZcvUq1cvffLJJ+rfv7/y8/MVHR0tSfrwww+VkpKiwsJChYaG6o033tCoUaN08OBBOZ2n/6WYPHmypk+frr1798rhcOj555/XvHnztH37j49sT01N1VdffaXMzMxatZfbWwDw0739+S59ueuwlm0/WKfP/TK+jVK6x+im6FCeDo06aRC3tyIiItSxY0e9//77On78uCoqKvTmm28qMjJS8fHxkqTMzEzFxcVZgUeSkpOTVVpaqg0bNlg1SUlJVuCprtm/f7/y8vKsmj593B+YlZycrKysLJWXl3vqKwIAzvG7O2P19pAEtytBLWtxFeefG/aq//TVeuyddZ5sHmzOY7e3HA6Hli5dqgEDBigkJEQ+Pj6KjIzU4sWL1bx5c0lSQUGBIiMj3T4XHh6ugIAAFRQUWDUxMTFuNdWfKSgoUIcOHWo8TmRkpCoqKlRUVKRWrVqd177S0lKVlpZar0tKSn7qVwYAnOXs4PP90VJ1mbDskp/5/NsiLco+oGXbD2raw5092DrYUZ2v9IwdO1YOh+OiW1ZWlowxGjZsmFq2bKnPP/9c69at04ABA9S/f38dOHDAOl5NlzGrH61+oZrqO3J1rTnbpEmTFBYWZm1t27atY08AAGqrRYhTeZP7aVn6XZesHTZ7oz7euE83/WGxKquYYIwrp85XetLS0jRo0KCL1sTExGj58uVasGCBjhw5Yt1je/3117V06VK99957euGFFxQVFaW1a9e6ffbIkSMqLy+3rtxERUVZV32qFRYWStIla/z8/BQREVFjG0eNGqX09HTrdUlJCcEHADzs2pYh2jWxr3x8HMotOq6eU1dcsPZ4WaWuGb2IRU9xxdQ59LhcLrlcrkvWnThxQpLk4+N+McnHx0dVVacfW56YmKgJEybowIED1i2ojIwMOZ1Oa9xPYmKiRo8erbKyMgUEBFg10dHR1m2vxMREzZ8/3+08GRkZSkhIkL+/f43tczqdbuOEAAD1o/qpzR1cwW6BJuaFhTXWn73/qbuv1Yg+N3i2gWiyPDaQOTExUeHh4RoyZIi++uor5eTk6Pe//71yc3PVr9/pX/I+ffqoU6dOGjx4sDZt2qRPP/1UI0eO1NChQ62rQ4888oicTqdSUlK0ZcsWzZ07VxMnTlR6erp16yo1NVW7d+9Wenq6tm/frnfeeUczZ87UyJEjPfX1AABX2LrRvS5ZM335d4p5YSG3vXBZPBZ6XC6XFi9erGPHjunuu+9WQkKCVq9erX/961+69dZbJUm+vr5auHChAgMD1aNHDz388MN64IEHNHXqVOs4YWFhWrp0qfbu3auEhAQNGzZM6enpbremOnTooEWLFmnFihXq3Lmzxo8fr1dffZVn9ABAI9IyNFBznkzUv/eI0ddj+1y09prRV2bVeNgLy1CcwXN6AKBhKa+s0txN+3RjVIju/++al7648zqX/vZ413puGRqSBvGcHgAAfgp/Xx89nNBWt7RprlW/76k3B8efV/P5t0X6ctchxbywUCM/+soLrURjQugBADR47SK
aKfmmKGWOuvu89wb99UtJpx9wWFpRWd9NQyNC6AEANBqtwoIuOoU9//CJemwNGhtCDwCg0bnQSu8v/d/Wem4JGhNCDwCg0fH1cdQ4xT1z1yF1n/SpF1qExoDQAwBolFqGBipvcr/zbnftLz6laUtzvNQqNGSEHgBAo7fy9z93e/3qp9/qhxNl3mkMGixCDwCg0WsfEaxvJ9zrtu+eVz73UmvQUBF6AABNgr+vj9utroKSU15sDRoiQg8AoMkqOlbq7SagASH0AACalJTuMdbPCX9a5r2GoMEh9AAAmpSx99/k9rqKFdlxBqEHANDk3Nw6zPp5+TeFXmwJGhJCDwCgyZmX1sP6+XfvZ3mxJWhICD0AgCbH4XB4uwlogAg9AIAmKaF9uPXzsdIKL7YEDQWhBwDQJL0yqLP183/8jVtcIPQAAJqoNuHNrJ/XfHfIiy1BQ0HoAQDYwp5DJ7zdBHgZoQcA0GTlTupr/Tw1Y4cXW4KGgNADAGiyzp7FFeDHnzy74zcAANCk3XV9C0lSx1ahXm4JvI3QAwBo0lo3D5QkbdlX7OWWwNsIPQCAJu1UeZUkaXP+D95tCLyO0AMAaNI6nbmt5WRMj+3xGwAAaNKimwdJkkID/b3cEngboQcA0KSFBPpJkkpOlXu5JfA2Qg8AoEmrDj1HT7H+lt0RegAATVp16GHRURB6AABNWsiZsTxHT5XLGOPl1sCbCD0AgCYt2Hn6Sk+V+XH6OuyJ0AMAaNKa+ftaPx8v4xaXnRF6AABNmo+PQ80CTgef44zrsTVCDwCgyau+xXW8tNLLLYE3EXoAAE1ecPWVHm5v2RqhBwDQ5DULYNo6CD0AABu4KrD69hahx84IPQCAJq963a2Sk4QeOyP0AACavObNToeeA8UnvdwSeBOhBwDQ5F11ZvbW9OXfefxcB0tOqehYqcfPg7oj9AAAmrx9P/x4hWfE/36lXd8fu6LHP1ZaoZ3fH1OXCcvUdeKnSvjTMpVWXHp6fPo/NuuzHYXn7f+Pv2Up5oWFeuStL2v83PYDJRc8/rrcw/qu8KiMMW7LbhhjdOR4mfV6ypJv9P8t/sZ6T5JOlp1/zKoqo417jqiisvZPsy4oPqWYFxbqV29m1voz9cFhWIhEklRSUqKwsDAVFxcrNDTU280BAFxBn3/7vQbPXOe273//I1EPv5mp22Ou1rq8w9b+nRP7ytfHofELtmnm6twLHvP6yKuUc/Di4WnV73vqrimfSZLyJveTJFVWGZVVVKnjHxa71X48rLsefP2L847x69vbqX1EM03+5Bs5HNLZf7VzJ/XVt4XH1Oc/V12wDTdEhmjHwaMXbefZWjcPskJicICvjp8ThNa8cLd6TF5uvY5tEayBP2ujKUt2SJIiQ506WPLjla5l6Un6xbSVkqTvJtwrP98re72lLn+/CT1nEHoAoGmLeWGht5ugR7u20+y1e7zdDK9JjI3Q35/odkWPWZe/39zeAgDYwpRf3uLtJtg68EjSL+PbePX8hB4AgC1c2/Iqj58jc9TdWju6l8fPc6XdHnO1Nr7U+6I1/zWo808+z0Avhx4/r54dAIB6cnPrMN15nUsRwQE6WFKqzF2HJElrR/dSyxCnJGnxlgLN2bhXy7YX6jfd2ulPD9zsdoyFXx/QibIKdb/WpZ2Fx3THtS7Fjl6k1s2DtOaFu626vMn9ZIzRnsMnFBUWqMmffKNZa/JqbFfupL6qMtLcTft0b1yUPt60T+9/kaclz96lAyWnVFZRpQM/nFS32Ag5HNJ7X+Rp1bdFWv7NjwOgb23bXF/l/yDp9Diix95ZKyNp0oM368jxcvW7pZVVe7DklLpO/FSSNOHf4nR9ZIi6xFxttVuSyiurdN2YT6zPPPuL6zSgc2uNX7BNRcfK9Fhie40bEGe9v+a7IhUUn9KIj75y64PiE+WavvxbPZEUq5YhgbX65+RJjOk5gzE9AIDGoqyiSp9uP6jEayIUEugvXx+Hdh86rlZhQQrwu/RNnO8Kj2rP4RO6+8bIC9ZUVRk5HJLD4ah1uwqPntKMFbv0zC+uU1iQf60/91MwkPkyEHoAAGh8GMgMAABwDkIPAACwBUIPAACwBUIPAACwBUIPAACwBUIPAACwBY+GnpycHA0YMEAul0uhoaHq0aOHPvvsM7cah8Nx3jZjxgy3muzsbCUlJSkoKEitW7fWuHHjdO5M+5UrVyo+Pl6BgYGKjY097xgAAMDePBp6+vXrp4qKCi1fvlwbNmxQ586d1b9/fxUUFLjVzZo1SwcOHLC2IUOGWO+VlJSod+/eio6O1vr16zV9+nRNnTpV06ZNs2pyc3PVt29f3Xnnndq0aZNGjx6tp59+WnPmzPHk1wMAAI2Ixx5OWFRUpBYtWmjVqlW68847JUlHjx5VaGioli1bpl69Tq9N4nA4NHfuXD3wwAM1HueNN97QqFGjdPDgQTmdpx8TPnnyZE2fPl179+6Vw+HQ888/r3nz5mn79u3W51JTU/XVV18pMzOzVu3l4YQAADQ+DeLhhBEREerYsaPef/99HT9+XBUVFXrzzTcVGRmp+Ph4t9q0tDS5XC516dJFM2bMUFVVlfVeZmamkpKSrMAjScnJydq/f7/y8vKsmj59+rgdMzk5WVlZWSovL6+xfaWlpSopKXHbAABA0+WxBUcdDoeWLl2qAQMGKCQkRD4+PoqMjNTixYvVvHlzq278+PHq1auXgoKC9Omnn2rEiBEqKirSiy++KEkqKChQTEyM27EjIyOt9zp06KCCggJr39k1FRUVKioqUqtWrXSuSZMm6Y9//OOV/dIAAKDBqvOVnrFjx9Y4+PjsLSsrS8YYDRs2TC1bttTnn3+udevWacCAAerfv78OHDhgHe/FF19UYmKiOnfurBEjRmjcuHGaMmWK2znPXeys+o7c2ftrU3O2UaNGqbi42Nry8/Pr2hUAAKARqfOVnrS0NA0aNOiiNTExMVq+fLkWLFigI0eOWPfYXn/9dS1dulTvvfeeXnjhhRo/261bN5WUlOjgwYOKjIxUVFTUeQOfCwsLJf14xedCNX5+foqIiKjxPE6n0+2WGQAAaNrqHHpcLpdcLtcl606cOCFJ8vFxv5jk4+PjNmbnXJs2bVJgYKB1CywxMVGjR49WWVmZAgICJEkZGRmKjo62bnslJiZq/vz5bsfJyMhQQkKC/P1rt7R99ZUhxvYAANB4VP/drtW8LOMh33//vYmIiDAPPvig2bx5s9mxY4cZOXKk8ff3N5s3bzbGGDNv3jzz17/+1WRnZ5vvvvvOvPXWWyY0NNQ8/fTT1nF++OEHExkZaX7961+b7Oxs8/HHH5vQ0FAzdepUq2bXrl2mWbNm5rnnnjPbtm0zM2fONP7+/uaf//xnrdubn59vJLGxsbGxsbE1wi0/P/+Sf+s9NmVdkrKysjRmzBhrFtVNN92kP/zhD7r33nslSYsXL9aoUaP03XffqaqqSrGxsfrd736n4cOHy8/vx4tQ2dnZGj58uNatW6fw8HClpqbqD3/4g9t4nZUrV+q5557T1q1bFR0dreeff16pqam1bm
tVVZX279+vkJCQC44DuhwlJSVq27at8vPzmQpfC/RX7dFXtUdf1Q39VXv0Ve15qq+MMTp69Kiio6PPu7t0Lo+GHvD8n7qiv2qPvqo9+qpu6K/ao69qryH0FWtvAQAAWyD0AAAAWyD0eJjT6dTLL7/M9Phaor9qj76qPfqqbuiv2qOvaq8h9BVjegAAgC1wpQcAANgCoQcAANgCoQcAANgCoQcAANgCocfDXn/9dXXo0EGBgYGKj4/X559/7u0mXVGrVq3Sfffdp+joaDkcDv3f//2f2/vGGI0dO1bR0dEKCgrSz3/+c23dutWtprS0VE899ZRcLpeCg4N1//33a+/evW41R44c0eDBgxUWFqawsDANHjxYP/zwg1vNnj17dN999yk4OFgul0tPP/20ysrKPPG1L8ukSZPUpUsXhYSEqGXLlnrggQe0Y8cOtxr667Q33nhDt9xyi0JDQxUaGqrExER98skn1vv004VNmjRJDodDzz77rLWP/vrR2LFj5XA43LaoqCjrffrK3b59+/Sb3/xGERERatasmTp37qwNGzZY7ze6/qr14lSosw8//ND4+/ubt956y2zbts0888wzJjg42OzevdvbTbtiFi1aZMaMGWPmzJljJJm5c+e6vT958mQTEhJi5syZY7Kzs82vfvUr06pVK1NSUmLVpKammtatW5ulS5eajRs3mp49e5pbb73VVFRUWDX33HOPiYuLM1988YX54osvTFxcnOnfv7/1fkVFhYmLizM9e/Y0GzduNEuXLjXR0dEmLS3N431QW8nJyWbWrFlmy5YtZvPmzaZfv36mXbt25tixY1YN/XXavHnzzMKFC82OHTvMjh07zOjRo42/v7/ZsmWLMYZ+upB169aZmJgYc8stt5hnnnnG2k9//ejll182N910kzlw4IC1FRYWWu/TVz86fPiwad++vUlJSTFr1641ubm5ZtmyZea7776zahpbfxF6POj22283qampbvtuvPFG88ILL3ipRZ51buipqqoyUVFRZvLkyda+U6dOmbCwMDNjxgxjzOkFZf39/c2HH35o1ezbt8/4+PiYxYsXG2OM2bZtm5FkvvzyS6smMzPTSDLffPONMeZ0+PLx8TH79u2zav7+978bp9NpiouLPfJ9f6rCwkIjyaxcudIYQ39dSnh4uHn77bfppws4evSoue6668zSpUtNUlKSFXroL3cvv/yyufXWW2t8j75y9/zzz5s77rjjgu83xv7i9paHlJWVacOGDerTp4/b/j59+uiLL77wUqvqV25urgoKCtz6wOl0KikpyeqDDRs2qLy83K0mOjpacXFxVk1mZqbCwsLUtWtXq6Zbt24KCwtzq4mLi1N0dLRVk5ycrNLSUrdLsQ1JcXGxJOnqq6+WRH9dSGVlpT788EMdP35ciYmJ9NMFDB8+XP369dMvfvELt/301/m+/fZbRUdHq0OHDho0aJB27dolib4617x585SQkKCHHnpILVu21G233aa33nrLer8x9hehx0OKiopUWVmpyMhIt/2RkZEqKCjwUqvqV/X3vFgfFBQUKCAgQOHh4Retadmy5XnHb9mypVvNuecJDw9XQEBAg+xvY4zS09N1xx13KC4uThL9da7s7GxdddVVcjqdSk1N1dy5c9WpUyf6qQYffvihNm7cqEmTJp33Hv3lrmvXrnr//fe1ZMkSvfXWWyooKFD37t116NAh+uocu3bt0htvvKHrrrtOS5YsUWpqqp5++mm9//77khrn75ZfrStxWRwOh9trY8x5+5q6y+mDc2tqqr+cmoYiLS1NX3/9tVavXn3ee/TXaTfccIM2b96sH374QXPmzNGQIUO0cuVK63366bT8/Hw988wzysjIUGBg4AXr6K/T7r33Xuvnm2++WYmJibrmmmv03nvvqVu3bpLoq2pVVVVKSEjQxIkTJUm33Xabtm7dqjfeeEOPPfaYVdeY+osrPR7icrnk6+t7XgItLCw8L602VdUzIi7WB1FRUSorK9ORI0cuWnPw4MHzjv/999+71Zx7niNHjqi8vLzB9fdTTz2lefPm6bPPPlObNm2s/fSXu4CAAF177bVKSEjQpEmTdOutt+q//uu/6KdzbNiwQYWFhYqPj5efn5/8/Py0cuVKvfrqq/Lz87PaSX/VLDg4WDfffLO+/fZbfrfO0apVK3Xq1MltX8eOHbVnzx5JjfO/WYQeDwkICFB8fLyWLl3qtn/p0qXq3r27l1pVvzp06KCoqCi3PigrK9PKlSutPoiPj5e/v79bzYEDB7RlyxarJjExUcXFxVq3bp1Vs3btWhUXF7vVbNmyRQcOHLBqMjIy5HQ6FR8f79HvWVvGGKWlpenjjz/W8uXL1aFDB7f36a+LM8aotLSUfjpHr169lJ2drc2bN1tbQkKCHn30UW3evFmxsbH010WUlpZq+/btatWqFb9b5+jRo8d5j9XIyclR+/btJTXS/2bVesgz6qx6yvrMmTPNtm3bzLPPPmuCg4NNXl6et5t2xRw9etRs2rTJbNq0yUgy06ZNM5s2bbKm5U+ePNmEhYWZjz/+2GRnZ5tf//rXNU5nbNOmjVm2bJnZuHGjufvuu2ucznjLLbeYzMxMk5mZaW6++eYapzP26tXLbNy40Sxbtsy0adOmQU3/fPLJJ01YWJhZsWKF23TZEydOWDX012mjRo0yq1atMrm5uebrr782o0ePNj4+PiYjI8MYQz9dytmzt4yhv842YsQIs2LFCrNr1y7z5Zdfmv79+5uQkBDrv8v01Y/WrVtn/Pz8zIQJE8y3335rZs+ebZo1a2Y++OADq6ax9Rehx8Nee+010759exMQEGB+9rOfWdOTm4rPPvvMSDpvGzJkiDHm9JTGl19+2URFRRmn02nuuusuk52d7XaMkydPmrS0NHP11VeboKAg079/f7Nnzx63mkOHDplHH33UhISEmJCQEPPoo4+aI0eOuNXs3r3b9OvXzwQFBZmrr77apKWlmVOnTnny69dJTf0kycyaNcuqob9O++1vf2v9e9OiRQvTq1cvK/AYQz9dyrmhh/76UfVzZPz9/U10dLR58MEHzdatW6336St38+fPN3FxccbpdJobb7zR/PWvf3V7v7H1l8MYY2p/XQgAAKBxYkwPAACwBUIPAACwBUIPAACwBUIPAACwBUIPAACwBUIPAACwBUIPAACwBUIPAACwBUIPAACwBUIPAACwBUIPAACwBUIPAACwhf8fYOkktmwuj/oAAAAASUVORK5CYII=", "text/plain": [ - "[alpha_loc, alpha_scale, beta_loc, beta_scale, sigma_loc, sigma_scale]" + "
" ] }, - "execution_count": 166, "metadata": {}, - "output_type": "execute_result" + "output_type": "display_data" } ], "source": [ - "params" + "import matplotlib.pyplot as plt\n", + "\n", + "window_size = 100\n", + "kernel = np.full(window_size, 1 / window_size)\n", + "plt.plot(np.convolve(loss_history, kernel, mode=\"valid\"))" ] }, { "cell_type": "code", - "execution_count": 106, - "id": "310c739f", + "execution_count": 11, + "id": "1251b6f1", "metadata": {}, "outputs": [ { "data": { "text/plain": [ - "{alpha: alpha, beta: beta, sigma_log__: sigma}" + "array([102.12462346, 10.74446646, 0.51969095, -7.61435818,\n", + " 8.55366616, -8.5301462 , 0.69953323, -0.55440606,\n", + " -2.43179013, -5.36278597, -1.29241817, -7.09759975])" ] }, - "execution_count": 106, + "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "inputs_to_guide_inputs" + "np.r_[param_dict[\"alpha_loc\"], param_dict[\"beta_loc\"], param_dict[\"sigma_loc\"]]" ] }, { "cell_type": "code", - "execution_count": 22, - "id": "2fe8da0a", - "metadata": {}, - "outputs": [], - "source": [ - "def compute_loss(m_obs, m_guide, beta=1.0):\n", - " return -m_obs.datalogp + beta * (m_guide.logp() - m_obs.varlogp)" - ] - }, - { - "cell_type": "code", - "execution_count": 24, - "id": "b04b9b5c", + "execution_count": 12, + "id": "11345fb4", "metadata": {}, "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Add [id A]\n", - " ├─ Neg [id B]\n", - " │ └─ Add [id C]\n", - " │ ├─ Sum{axes=None} [id D] '__logp'\n", - " │ │ └─ MakeVector{dtype='float64'} [id E]\n", - " │ │ └─ Sum{axes=None} [id F]\n", - " │ │ └─ Check{sigma > 0} [id G]\n", - " │ │ ├─ Sub [id H]\n", - " │ │ │ ├─ Sub [id I]\n", - " │ │ │ │ ├─ Mul [id J]\n", - " │ │ │ │ │ ├─ ExpandDims{axis=0} [id K]\n", - " │ │ │ │ │ │ └─ -0.5 [id L]\n", - " │ │ │ │ │ └─ Pow [id M]\n", - " │ │ │ │ │ ├─ True_div [id N]\n", - " │ │ │ │ │ │ ├─ Sub [id O]\n", - " │ │ │ │ │ │ │ ├─ [122.65317 ... 
.32067026] [id P]\n", - " │ │ │ │ │ │ │ └─ Add [id Q]\n", - " │ │ │ │ │ │ │ ├─ ExpandDims{axis=0} [id R]\n", - " │ │ │ │ │ │ │ │ └─ alpha [id S]\n", - " │ │ │ │ │ │ │ └─ Squeeze{axis=1} [id T]\n", - " │ │ │ │ │ │ │ └─ Blockwise{dot, (m,k),(k,n)->(m,n)} [id U]\n", - " │ │ │ │ │ │ │ ├─ X [id V]\n", - " │ │ │ │ │ │ │ └─ ExpandDims{axis=1} [id W]\n", - " │ │ │ │ │ │ │ └─ beta [id X]\n", - " │ │ │ │ │ │ └─ ExpandDims{axis=0} [id Y]\n", - " │ │ │ │ │ │ └─ Exp [id Z]\n", - " │ │ │ │ │ │ └─ sigma_log__ [id BA]\n", - " │ │ │ │ │ └─ ExpandDims{axis=0} [id BB]\n", - " │ │ │ │ │ └─ 2 [id BC]\n", - " │ │ │ │ └─ ExpandDims{axis=0} [id BD]\n", - " │ │ │ │ └─ Log [id BE]\n", - " │ │ │ │ └─ Sqrt [id BF]\n", - " │ │ │ │ └─ 6.283185307179586 [id BG]\n", - " │ │ │ └─ Log [id BH]\n", - " │ │ │ └─ ExpandDims{axis=0} [id Y]\n", - " │ │ │ └─ ···\n", - " │ │ └─ All{axes=None} [id BI]\n", - " │ │ └─ MakeVector{dtype='bool'} [id BJ]\n", - " │ │ └─ All{axes=None} [id BK]\n", - " │ │ └─ Gt [id BL]\n", - " │ │ ├─ ExpandDims{axis=0} [id Y]\n", - " │ │ │ └─ ···\n", - " │ │ └─ ExpandDims{axis=0} [id BM]\n", - " │ │ └─ 0 [id BN]\n", - " │ └─ 0.0 [id BO]\n", - " └─ Mul [id BP]\n", - " ├─ 1.0 [id BQ]\n", - " └─ Sub [id BR]\n", - " ├─ Sum{axes=None} [id BS] '__logp'\n", - " │ └─ MakeVector{dtype='float64'} [id BT]\n", - " │ ├─ Sum{axes=None} [id BU]\n", - " │ │ └─ Check{sigma > 0} [id BV] 'alpha_logprob'\n", - " │ │ ├─ Sub [id BW]\n", - " │ │ │ ├─ Sub [id BX]\n", - " │ │ │ │ ├─ Mul [id BY]\n", - " │ │ │ │ │ ├─ -0.5 [id BZ]\n", - " │ │ │ │ │ └─ Pow [id CA]\n", - " │ │ │ │ │ ├─ True_div [id CB]\n", - " │ │ │ │ │ │ ├─ Sub [id CC]\n", - " │ │ │ │ │ │ │ ├─ alpha [id CD]\n", - " │ │ │ │ │ │ │ └─ alpha_loc [id CE]\n", - " │ │ │ │ │ │ └─ alpha_scale [id CF]\n", - " │ │ │ │ │ └─ 2 [id CG]\n", - " │ │ │ │ └─ Log [id CH]\n", - " │ │ │ │ └─ Sqrt [id CI]\n", - " │ │ │ │ └─ 6.283185307179586 [id CJ]\n", - " │ │ │ └─ Log [id CK]\n", - " │ │ │ └─ alpha_scale [id CF]\n", - " │ │ └─ All{axes=None} [id CL]\n", - " │ │ └─ MakeVector{dtype='bool'} [id CM]\n", - " │ │ └─ All{axes=None} [id CN]\n", - " │ │ └─ Gt [id CO]\n", - " │ │ ├─ alpha_scale [id CF]\n", - " │ │ └─ 0 [id CP]\n", - " │ ├─ Sum{axes=None} [id CQ]\n", - " │ │ └─ Check{sigma > 0} [id CR] 'beta_logprob'\n", - " │ │ ├─ Sub [id CS]\n", - " │ │ │ ├─ Sub [id CT]\n", - " │ │ │ │ ├─ Mul [id CU]\n", - " │ │ │ │ │ ├─ ExpandDims{axis=0} [id CV]\n", - " │ │ │ │ │ │ └─ -0.5 [id CW]\n", - " │ │ │ │ │ └─ Pow [id CX]\n", - " │ │ │ │ │ ├─ True_div [id CY]\n", - " │ │ │ │ │ │ ├─ Sub [id CZ]\n", - " │ │ │ │ │ │ │ ├─ beta [id DA]\n", - " │ │ │ │ │ │ │ └─ beta_loc [id DB]\n", - " │ │ │ │ │ │ └─ beta_scale [id DC]\n", - " │ │ │ │ │ └─ ExpandDims{axis=0} [id DD]\n", - " │ │ │ │ │ └─ 2 [id DE]\n", - " │ │ │ │ └─ ExpandDims{axis=0} [id DF]\n", - " │ │ │ │ └─ Log [id DG]\n", - " │ │ │ │ └─ Sqrt [id DH]\n", - " │ │ │ │ └─ 6.283185307179586 [id DI]\n", - " │ │ │ └─ Log [id DJ]\n", - " │ │ │ └─ beta_scale [id DC]\n", - " │ │ └─ All{axes=None} [id DK]\n", - " │ │ └─ MakeVector{dtype='bool'} [id DL]\n", - " │ │ └─ All{axes=None} [id DM]\n", - " │ │ └─ Gt [id DN]\n", - " │ │ ├─ beta_scale [id DC]\n", - " │ │ └─ ExpandDims{axis=0} [id DO]\n", - " │ │ └─ 0 [id DP]\n", - " │ └─ Sum{axes=None} [id DQ]\n", - " │ └─ Add [id DR] 'sigma_log___logprob'\n", - " │ ├─ Check{sigma > 0} [id DS]\n", - " │ │ ├─ Sub [id DT]\n", - " │ │ │ ├─ Sub [id DU]\n", - " │ │ │ │ ├─ Mul [id DV]\n", - " │ │ │ │ │ ├─ -0.5 [id DW]\n", - " │ │ │ │ │ └─ Pow [id DX]\n", - " │ │ │ │ │ ├─ True_div [id DY]\n", - " │ │ │ │ │ │ ├─ Sub [id DZ]\n", - " │ 
│ │ │ │ │ │ ├─ Exp [id EA]\n", - " │ │ │ │ │ │ │ │ └─ sigma_log__ [id EB]\n", - " │ │ │ │ │ │ │ └─ sigma_loc [id EC]\n", - " │ │ │ │ │ │ └─ sigma_scale [id ED]\n", - " │ │ │ │ │ └─ 2 [id EE]\n", - " │ │ │ │ └─ Log [id EF]\n", - " │ │ │ │ └─ Sqrt [id EG]\n", - " │ │ │ │ └─ 6.283185307179586 [id EH]\n", - " │ │ │ └─ Log [id EI]\n", - " │ │ │ └─ sigma_scale [id ED]\n", - " │ │ └─ All{axes=None} [id EJ]\n", - " │ │ └─ MakeVector{dtype='bool'} [id EK]\n", - " │ │ └─ All{axes=None} [id EL]\n", - " │ │ └─ Gt [id EM]\n", - " │ │ ├─ sigma_scale [id ED]\n", - " │ │ └─ 0 [id EN]\n", - " │ └─ Identity [id EO] 'sigma_log___log_jacobian'\n", - " │ └─ sigma_log__ [id EB]\n", - " └─ Sum{axes=None} [id EP] '__logp'\n", - " └─ MakeVector{dtype='float64'} [id EQ]\n", - " ├─ Sum{axes=None} [id ER]\n", - " │ └─ Check{sigma > 0} [id ES] 'X_logprob'\n", - " │ ├─ Sub [id ET]\n", - " │ │ ├─ Sub [id EU]\n", - " │ │ │ ├─ Mul [id EV]\n", - " │ │ │ │ ├─ ExpandDims{axes=[0, 1]} [id EW]\n", - " │ │ │ │ │ └─ -0.5 [id EX]\n", - " │ │ │ │ └─ Pow [id EY]\n", - " │ │ │ │ ├─ True_div [id EZ]\n", - " │ │ │ │ │ ├─ Sub [id FA]\n", - " │ │ │ │ │ │ ├─ X [id V]\n", - " │ │ │ │ │ │ └─ [[0]] [id FB]\n", - " │ │ │ │ │ └─ [[1]] [id FC]\n", - " │ │ │ │ └─ ExpandDims{axes=[0, 1]} [id FD]\n", - " │ │ │ │ └─ 2 [id FE]\n", - " │ │ │ └─ ExpandDims{axes=[0, 1]} [id FF]\n", - " │ │ │ └─ Log [id FG]\n", - " │ │ │ └─ Sqrt [id FH]\n", - " │ │ │ └─ 6.283185307179586 [id FI]\n", - " │ │ └─ Log [id FJ]\n", - " │ │ └─ [[1]] [id FC]\n", - " │ └─ All{axes=None} [id FK]\n", - " │ └─ MakeVector{dtype='bool'} [id FL]\n", - " │ └─ All{axes=None} [id FM]\n", - " │ └─ Gt [id FN]\n", - " │ ├─ [[1]] [id FC]\n", - " │ └─ ExpandDims{axes=[0, 1]} [id FO]\n", - " │ └─ 0 [id FP]\n", - " ├─ Sum{axes=None} [id FQ]\n", - " │ └─ Check{sigma > 0} [id FR] 'alpha_logprob'\n", - " │ ├─ Sub [id FS]\n", - " │ │ ├─ Sub [id FT]\n", - " │ │ │ ├─ Mul [id FU]\n", - " │ │ │ │ ├─ -0.5 [id FV]\n", - " │ │ │ │ └─ Pow [id FW]\n", - " │ │ │ │ ├─ True_div [id FX]\n", - " │ │ │ │ │ ├─ Sub [id FY]\n", - " │ │ │ │ │ │ ├─ alpha [id S]\n", - " │ │ │ │ │ │ └─ 100 [id FZ]\n", - " │ │ │ │ │ └─ 10 [id GA]\n", - " │ │ │ │ └─ 2 [id GB]\n", - " │ │ │ └─ Log [id GC]\n", - " │ │ │ └─ Sqrt [id GD]\n", - " │ │ │ └─ 6.283185307179586 [id GE]\n", - " │ │ └─ Log [id GF]\n", - " │ │ └─ 10 [id GA]\n", - " │ └─ All{axes=None} [id GG]\n", - " │ └─ MakeVector{dtype='bool'} [id GH]\n", - " │ └─ All{axes=None} [id GI]\n", - " │ └─ Gt [id GJ]\n", - " │ ├─ 10 [id GA]\n", - " │ └─ 0 [id GK]\n", - " ├─ Sum{axes=None} [id GL]\n", - " │ └─ Check{sigma > 0} [id GM] 'beta_logprob'\n", - " │ ├─ Sub [id GN]\n", - " │ │ ├─ Sub [id GO]\n", - " │ │ │ ├─ Mul [id GP]\n", - " │ │ │ │ ├─ ExpandDims{axis=0} [id GQ]\n", - " │ │ │ │ │ └─ -0.5 [id GR]\n", - " │ │ │ │ └─ Pow [id GS]\n", - " │ │ │ │ ├─ True_div [id GT]\n", - " │ │ │ │ │ ├─ Sub [id GU]\n", - " │ │ │ │ │ │ ├─ beta [id X]\n", - " │ │ │ │ │ │ └─ [0] [id GV]\n", - " │ │ │ │ │ └─ [5] [id GW]\n", - " │ │ │ │ └─ ExpandDims{axis=0} [id GX]\n", - " │ │ │ │ └─ 2 [id GY]\n", - " │ │ │ └─ ExpandDims{axis=0} [id GZ]\n", - " │ │ │ └─ Log [id HA]\n", - " │ │ │ └─ Sqrt [id HB]\n", - " │ │ │ └─ 6.283185307179586 [id HC]\n", - " │ │ └─ Log [id HD]\n", - " │ │ └─ [5] [id GW]\n", - " │ └─ All{axes=None} [id HE]\n", - " │ └─ MakeVector{dtype='bool'} [id HF]\n", - " │ └─ All{axes=None} [id HG]\n", - " │ └─ Gt [id HH]\n", - " │ ├─ [5] [id GW]\n", - " │ └─ ExpandDims{axis=0} [id HI]\n", - " │ └─ 0 [id HJ]\n", - " └─ Sum{axes=None} [id HK]\n", - " └─ Add [id HL] 'sigma_log___logprob'\n", - " 
├─ Check{mu >= 0} [id HM]\n", - " │ ├─ Switch [id HN]\n", - " │ │ ├─ Ge [id HO]\n", - " │ │ │ ├─ Exp [id HP]\n", - " │ │ │ │ └─ sigma_log__ [id BA]\n", - " │ │ │ └─ 0.0 [id HQ]\n", - " │ │ ├─ Sub [id HR]\n", - " │ │ │ ├─ Neg [id HS]\n", - " │ │ │ │ └─ Log [id HT]\n", - " │ │ │ │ └─ 1.0 [id HU]\n", - " │ │ │ └─ True_div [id HV]\n", - " │ │ │ ├─ Exp [id HP]\n", - " │ │ │ │ └─ ···\n", - " │ │ │ └─ 1.0 [id HU]\n", - " │ │ └─ -inf [id HW]\n", - " │ └─ All{axes=None} [id HX]\n", - " │ └─ MakeVector{dtype='bool'} [id HY]\n", - " │ └─ All{axes=None} [id HZ]\n", - " │ └─ Ge [id IA]\n", - " │ ├─ 1.0 [id HU]\n", - " │ └─ 0 [id IB]\n", - " └─ Identity [id IC] 'sigma_log___log_jacobian'\n", - " └─ sigma_log__ [id BA]\n" - ] - }, { "data": { "text/plain": [ - "" + "array([102.16132511, 10.75292414, 0.54980953, -7.64875998,\n", + " 8.5053264 , -8.56422778, 0.70840797, -0.57081651,\n", + " -2.45245893, -5.30737734, -1.33080016, 0.25923082])" ] }, - "execution_count": 24, + "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "compute_loss(m_obs, guide_model).dprint()" + "true_params" ] }, { - "cell_type": "code", - "execution_count": null, - "id": "d155685f", - "metadata": {}, - "outputs": [], - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "98fe5067", + "cell_type": "markdown", + "id": "0cf321e6", "metadata": {}, - "outputs": [], "source": [ - "## TODO:\n", - "# 1. Create hyperparameters for mean field approx (mu + sigma of normals)\n", - "# 2. Replace in the logp" + "## Todo:\n", + "\n", + "- Does this \"two models\" framework fit into what we already have?\n", + "- `model_to_mean_field` transformation\n", + "- rsample --> stochastic gradients? Or automatic reparameterization?\n", + "- More flexible optimizers..."
] }, { "cell_type": "code", "execution_count": null, - "id": "e01dcf96", + "id": "77786d86", "metadata": {}, "outputs": [], "source": [] From ea2c917374af22abc998024f99991763a9101294 Mon Sep 17 00:00:00 2001 From: Rob Zinkov Date: Sun, 8 Jun 2025 13:02:04 +0200 Subject: [PATCH 3/5] Adding overview notebook --- VI_Overview.ipynb | 563 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 563 insertions(+) create mode 100644 VI_Overview.ipynb diff --git a/VI_Overview.ipynb b/VI_Overview.ipynb new file mode 100644 index 000000000..5eb66fc86 --- /dev/null +++ b/VI_Overview.ipynb @@ -0,0 +1,563 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "c51c3c1a-553c-45e4-a92f-d75063187863", + "metadata": {}, + "source": [ + "# Variational Inference overview" + ] + }, + { + "cell_type": "markdown", + "id": "c0777ef6-dd90-4452-88af-69a53f0f7713", + "metadata": {}, + "source": [ + "## Existing Variational Inference implementation\n", + "\n", + "The best way to get a sense for the current implementation is to walk backwards from how it's used" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "3d5724fe-72db-4908-a464-46f7fac97309", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pymc as pm\n", + "import arviz as az" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "33437d00-60e6-4505-8b8e-8ebe64473c5b", + "metadata": {}, + "outputs": [], + "source": [ + "data = np.random.normal(size=10_000)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "894f9e31-90a1-4f13-b2bc-b75bdf996f78", + "metadata": {}, + "outputs": [], + "source": [ + "with pm.Model() as model:\n", + " d = pm.Data(\"data\", data)\n", + " batched_data = pm.Minibatch(d, batch_size=100)\n", + " x = pm.Normal(\"x\", 0., 1.)\n", + " y = pm.Normal(\"y\", x, total_size=len(data), observed=batched_data)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "15bc2997-8974-4d61-88ce-9423b215f84f", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "a695972a8ca3415f9f8dd118ae6288dc", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Finished [100%]: Average Loss = 144.77\n" + ] + } + ], + "source": [ + "with model:\n", + " idata = pm.fit(n=10_000, method=\"advi\") " + ] + }, + { + "cell_type": "markdown", + "id": "d311e2f2-f264-4cb2-9287-21d6f5aad3e3", + "metadata": {}, + "source": [ + "But what does fit do? It roughly dispatches on the method. So the above is roughly equalivalent to:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "ec3b637d-c6a2-46cb-99bc-87fc952bda3e", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "43686ad598a649b09a88b84480985eac", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Finished [100%]: Average Loss = 143.83\n" + ] + } + ], + "source": [ + "with model:\n", + " advi = pm.ADVI()\n", + " idata = advi.fit(n=100_000)" + ] + }, + { + "cell_type": "markdown", + "id": "bfbd2a63-b5da-41d3-a1d1-254934ad4923", + "metadata": {}, + "source": [ + "But what is this `ADVI` object? Well, if you look at it's implementation with the documentation removed, you see it's a type of `KLqp`\n", + "\n", + "````python\n", + "class ADVI(KLqp):\n", + " def __init__(self, *args, **kwargs):\n", + " super().__init__(MeanField(*args, **kwargs))\n", + "````\n", + "\n", + "So what's a `Klqp`? Look at it's implementation with the documentation removed, you see it's an Inference object\n", + "\n", + "````python\n", + "class KLqp(Inference):\n", + " def __init__(self, approx, beta=1.0):\n", + " super().__init__(KL, approx, None, beta=beta)\n", + "````\n", + "\n", + "So what's an `Inference` object? Look at it's implementation with the documentation removed we finally get a sense for what are the main abstraction we will be working with." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "ce19d0bd-8a6b-4877-a4ee-5ee1fa481da8", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "\u001b[0;31mInit signature:\u001b[0m \u001b[0mpm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mInference\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mop\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mapprox\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mDocstring:\u001b[0m \n", + "**Base class for Variational Inference**.\n", + "\n", + "Communicates Operator, Approximation and Test Function to build Objective Function\n", + "\n", + "Parameters\n", + "----------\n", + "op : Operator class #:class:`~pymc.variational.operators`\n", + "approx : Approximation class or instance #:class:`~pymc.variational.approximations`\n", + "tf : TestFunction instance #?\n", + "model : Model\n", + " PyMC Model\n", + "kwargs : kwargs passed to :class:`Operator` #:class:`~pymc.variational.operators`, optional\n", + "\u001b[0;31mFile:\u001b[0m ~/upstream/pymc/pymc/variational/inference.py\n", + "\u001b[0;31mType:\u001b[0m type\n", + "\u001b[0;31mSubclasses:\u001b[0m KLqp, ImplicitGradient" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "pm.Inference?" + ] + }, + { + "cell_type": "markdown", + "id": "73873c2f-11ed-43a3-b256-2e65816343b9", + "metadata": {}, + "source": [ + "Now things are falling into place. The `Inference` class is the way we perform variational inference. This is where the actual fit machinery lives. It also highlights what we need to do variational inference. We need a `Model`, an `Operator`, and an `Approximation`. We already know for `ADVI`, that the `Operator` is `KL` and the `Approximation` is `MeanField`.\n", + "\n", + "But what do these things mean? 
And how are they combined to perform inference?\n", + "\n", + "Well, the `__init__` method of `Inference` is where we can find our answer:" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "0a77e663-8982-4a8a-bafc-290da5f45838", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "\u001b[0;31mSignature:\u001b[0m \u001b[0mpm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mInference\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mop\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mapprox\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mDocstring:\u001b[0m Initialize self. See help(type(self)) for accurate signature.\n", + "\u001b[0;31mSource:\u001b[0m \n", + " \u001b[0;32mdef\u001b[0m \u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mop\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mapprox\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhist\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0masarray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mobjective\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mop\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mapprox\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtf\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstate\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mFile:\u001b[0m ~/upstream/pymc/pymc/variational/inference.py\n", + "\u001b[0;31mType:\u001b[0m function" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "pm.Inference.__init__??" + ] + }, + { + "cell_type": "markdown", + "id": "0564cc40-2d42-476d-8fa4-91a063bff433", + "metadata": {}, + "source": [ + "Alright, so let's go ahead and explore the operator `KL`:\n", + "\n", + "````python\n", + "class KL(Operator):\n", + " def __init__(self, approx, beta=1.0):\n", + " super().__init__(approx)\n", + " self.beta = pm.floatX(beta)\n", + "\n", + " def apply(self, f):\n", + " return -self.datalogp_norm + self.beta * (self.logq_norm - self.varlogp_norm)\n", + "````\n", + "\n", + "We see no `__call__`, but we see a call to the `__init__` of `Operator`. In the `apply` method, we see what looks like the negative ELBO.\n",
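+    "\n",
+    "Written out (a sketch based on the attribute names in `apply` above, ignoring the `_norm` normalization of each term): for latent variables $z$ and observed data $\\mathcal{D}$, `apply` builds the loss\n",
+    "\n",
+    "$$\n",
+    "\\mathcal{L}(q) = -\\mathbb{E}_{z \\sim q}\\left[\\log p(\\mathcal{D} \\mid z)\\right] + \\beta \\left( \\mathbb{E}_{z \\sim q}\\left[\\log q(z)\\right] - \\mathbb{E}_{z \\sim q}\\left[\\log p(z)\\right] \\right),\n",
+    "$$\n",
+    "\n",
+    "which for $\\beta = 1$ is exactly the negative ELBO, so minimizing it maximizes the ELBO.\n",
+    "\n",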
+    "For now, let's inline the `ADVI` case and see what we get:" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "1752272d-b32d-4bea-9c3c-1e331b7fda9a", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "objective = pm.operators.KL(pm.MeanField(model=model))(None)\n", + "objective" + ] + }, + { + "cell_type": "markdown", + "id": "cb70a473-6ca3-471c-b051-9a3cee666bd6", + "metadata": {}, + "source": [ + "So how'd that happen? Well, if you look at the `Operator` base class (the snippet above came from it), you see\n", + "\n", + "````python\n", + " objective_class = ObjectiveFunction\n", + "\n", + " def __call__(self, f=None):\n", + " if self.has_test_function:\n", + " if f is None:\n", + " raise ParametrizationError(f\"Operator {self} requires TestFunction\")\n", + " else:\n", + " if not isinstance(f, TestFunction):\n", + " f = TestFunction.from_function(f)\n", + " else:\n", + " if f is not None:\n", + " warnings.warn(f\"TestFunction for {self} is redundant and removed\", stacklevel=3)\n", + " else:\n", + " pass\n", + " f = TestFunction()\n", + " f.setup(self.approx)\n", + " return self.objective_class(self, f)\n", + "````\n", + "\n", + "Which finally brings us to `ObjectiveFunction`.\n", + "\n", + "This is the class that builds the actual loss function and compiles the step that updates the approximation's parameters." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "339f72bf-a40d-4f84-9161-99a20ad090cb", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "\u001b[0;31mSignature:\u001b[0m\n", + "\u001b[0mpm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mopvi\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mObjectiveFunction\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep_function\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mobj_n_mc\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mtf_n_mc\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mobj_optimizer\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m<\u001b[0m\u001b[0mfunction\u001b[0m \u001b[0madagrad_window\u001b[0m \u001b[0mat\u001b[0m \u001b[0;36m0x70ee648da480\u001b[0m\u001b[0;34m>\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mtest_optimizer\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m<\u001b[0m\u001b[0mfunction\u001b[0m \u001b[0madagrad_window\u001b[0m \u001b[0mat\u001b[0m \u001b[0;36m0x70ee648da480\u001b[0m\u001b[0;34m>\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mmore_obj_params\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mmore_tf_params\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mmore_updates\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mmore_replacements\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m 
\u001b[0mtotal_grad_norm_constraint\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mscore\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mcompile_kwargs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mfn_kwargs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mDocstring:\u001b[0m\n", + "Step function that should be called on each optimization step.\n", + "\n", + "Generally it solves the following problem:\n", + "\n", + ".. math::\n", + "\n", + " \\mathbf{\\lambda^{\\*}} = \\inf_{\\lambda} \\sup_{\\theta} t(\\mathbb{E}_{\\lambda}[(O^{p,q}f_{\\theta})(z)])\n", + "\n", + "Parameters\n", + "----------\n", + "obj_n_mc: `int`\n", + " Number of monte carlo samples used for approximation of objective gradients\n", + "tf_n_mc: `int`\n", + " Number of monte carlo samples used for approximation of test function gradients\n", + "obj_optimizer: function (grads, params) -> updates\n", + " Optimizer that is used for objective params\n", + "test_optimizer: function (grads, params) -> updates\n", + " Optimizer that is used for test function params\n", + "more_obj_params: `list`\n", + " Add custom params for objective optimizer\n", + "more_tf_params: `list`\n", + " Add custom params for test function optimizer\n", + "more_updates: `dict`\n", + " Add custom updates to resulting updates\n", + "total_grad_norm_constraint: `float`\n", + " Bounds gradient norm, prevents exploding gradient problem\n", + "score: `bool`\n", + " calculate loss on each step? Defaults to False for speed\n", + "compile_kwargs: `dict`\n", + " Add kwargs to pytensor.function (e.g. `{'profile': True}`)\n", + "fn_kwargs: dict\n", + " arbitrary kwargs passed to `pytensor.function`\n", + "\n", + " .. warning:: `fn_kwargs` is deprecated and will be removed in future versions\n", + "\n", + "more_replacements: `dict`\n", + " Apply custom replacements before calculating gradients\n", + "\n", + "Returns\n", + "-------\n", + "`pytensor.function`\n", + "\u001b[0;31mFile:\u001b[0m ~/upstream/pymc/pymc/variational/opvi.py\n", + "\u001b[0;31mType:\u001b[0m function" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "pm.opvi.ObjectiveFunction.step_function?" + ] + }, + { + "cell_type": "markdown", + "id": "83b5504c-6a99-40eb-8b45-b485a8618ecd", + "metadata": {}, + "source": [ + "## Proposed Improvements\n", + "\n", + "There is a lot to like here, but there is also a lot of indirection. Further, much of it isn't used for the `ADVI` case; it is all in service of `SVGD` and `ASVGD`.\n", + "\n", + "Moreover, the `Inference` class has to be aware of too many of these details. Ideally, `Inference` should be reworked to take in only a step function. It could be renamed `Trainer` to match what's in PyTorch Lightning.\n",
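+    "\n",
+    "For reference, the whole current loop assembled from these pieces is roughly the following (a sketch, assuming, per the docstring above, that `score=True` makes the compiled step function also return the loss):\n",
+    "\n",
+    "````python\n",
+    "objective = pm.operators.KL(pm.MeanField(model=model))(None)\n",
+    "step = objective.step_function(score=True)\n",
+    "losses = [step() for _ in range(10_000)]\n",
+    "````\n",
+    "\n",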
+    "I think forcing all `VI` through `OPVI` makes it more challenging to write and port new `VI` algorithms to `pymc`." + ] + }, + { + "cell_type": "markdown", + "id": "f7e83647-ac36-4043-8bb5-8fcf5581ffa3", + "metadata": {}, + "source": [ + "### PyTorch Lightning and Optax optimization" + ] + }, + { + "cell_type": "markdown", + "id": "ece80fe0-e332-444d-b1e3-761aca6a02e8", + "metadata": {}, + "source": [ + "How would this look? One possibility is having each Variational Inference technique encapsulated in an object that takes a model and optimizer as inputs and provides a step function as a method.\n", + "\n", + "````python\n", + "class ADVI(Inference):\n", + " def __init__(self, model=None, optimizers=None):\n", + " ...\n", + "\n", + " def step(self, batched_data):\n", + " ...\n", + " return loss\n", + "````\n", + "\n", + "This is then passed to a `Trainer` object for fitting:\n", + "\n", + "````python\n", + "with model:\n", + " trainer = Trainer(method=ADVI(), dataloader= ...)\n", + " trainer.fit(n=10_000)\n", + "````\n", + "\n", + "Under this setup, most of the optimization logic moves into the `__init__` and `step` methods. As for how those should be implemented, I think that can be handled separately, but something like optax might not be so bad. We could end up with code that resembles the below:\n", + "\n", + "````python\n", + "class ADVI(Inference):\n", + " def __init__(self, model=None, optimizers=None):\n", + " if model is None:\n", + " model = modelcontext(None)\n", + " if optimizers is None:\n", + " optimizers = [pm.opt.Adam(1e-3)]\n", + " self.optimizer = optimizers[0]\n", + " self.params = self.optimizer.init(model.basic_RVs)\n", + "\n", + " def step(self, batch):\n", + " loss = self.loss_function(self.params, batch)\n", + " grads = grad(loss)\n", + " self.params = self.optimizer.update(grads, self.params)\n", + " return loss\n", + "````" + ] + }, + { + "cell_type": "markdown", + "id": "d5572e2b-6a32-492d-9cb0-d003a14f490d", + "metadata": {}, + "source": [ + "### Model and Guide programs\n", + "\n", + "Additionally, it would be nice if we could easily support variational inference with guide programs à la pyro/numpyro.\n", + "\n", + "The way this could look is that we define both as `pymc` models and then pass them to an `SVI` method:\n", + "\n", + "````python\n", + "with pm.Model() as model:\n", + " x = pm.Normal(\"x\", 0, 1)\n", + " y = pm.Normal(\"y\", x, 1, observed=data)\n", + "\n", + "with pm.Model() as guide:\n", + " mu = pt.tensor(\"mu\", param=True)\n", + " sd = pt.tensor(\"sd\", param=True)\n", + " x = pm.Normal(\"x\", mu, sd)\n", + "\n", + "\n", + "with model:\n", + " trainer = Trainer(method=SVI(model, guide), dataloader= ...)\n", + " trainer.fit(n=10_000)\n", + "````\n", + "\n", + "Naturally, `SVI` is a very general inference method, and in fact we can re-define `ADVI` in terms of it.\n",
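+    "\n",
+    "Concretely, the loss inside `SVI.step` could be as small as the following sketch (it assumes a helper like the `get_logp_logq` function proposed later in this PR, which pairs each model value variable with the corresponding guide variable):\n",
+    "\n",
+    "````python\n",
+    "logp, logq = get_logp_logq(model, guide)\n",
+    "negative_elbo = (logq - logp).mean()  # minimizing this maximizes the ELBO\n",
+    "````\n",
+    "\n",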
+    "Following the lead of pyro/numpyro, we can have automatic guide generation:\n", + "\n", + "````python\n", + "with model:\n", + " guide = AutoGuide(model)\n", + " trainer = Trainer(method=SVI(model, guide), dataloader= ...)\n", + " trainer.fit(n=10_000)\n", + "````" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "50fcb3a1-4467-4ace-acdd-666e4f342984", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "pymc-dev", + "language": "python", + "name": "pymc-dev" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From a870c7c63f848825b8a0f879fc17324c446df4f9 Mon Sep 17 00:00:00 2001 From: Rob Zinkov  Date: Sun, 8 Jun 2025 13:38:48 +0200 Subject: [PATCH 4/5] Add minibatch proposal --- VI_Overview.ipynb | 39 +++++++++++++++++++++++++++++++++++++-- 1 file changed, 37 insertions(+), 2 deletions(-) diff --git a/VI_Overview.ipynb b/VI_Overview.ipynb index 5eb66fc86..8e3ee91e8 100644 --- a/VI_Overview.ipynb +++ b/VI_Overview.ipynb @@ -460,7 +460,7 @@ " def __init__(self, model=None, optimizers=None):\n", " ...\n", "\n", - " def step(self, batched_data):\n", + " def step(self, batch):\n", " ...\n", " return loss\n", "````\n", @@ -506,6 +506,7 @@ "\n", "````python\n", "with pm.Model() as model:\n", + " data = pm.Data(\"data\", ...)\n", " x = pm.Normal(\"x\", 0, 1)\n", " y = pm.Normal(\"y\", x, 1, observed=data)\n", "\n", @@ -530,10 +531,44 @@ "````" ] }, + { + "cell_type": "markdown", + "id": "7f97c341-e9bb-4301-b452-d006d6408cec", + "metadata": {}, + "source": [ + "### Reworking Minibatch\n", + "\n", + "Another small change we should consider is moving `pm.Minibatch` out of the model. Max already has a [proposal](https://github.com/pymc-devs/pymc/issues/7496) that I think can be adopted with only a few changes.\n", + "\n", + "Where before we explicitly minibatched the data inside the model, we would instead have dataloaders that stream batches into the model.\n", + "\n", + "````python\n", + "with pm.Model() as model:\n", + " data = pm.Data(\"data\", None)\n", + " x = pm.Normal(\"x\", 0, 1)\n", + " y = pm.Normal(\"y\", x, 1, observed=data)\n", + "\n", + "dataloader = pm.Dataloader(np.random.normal(size=10_000), batch_size=64)\n", + "\n", + "with model:\n", + " trainer = Trainer(method=ADVI(), dataloader=dataloader)\n", + " trainer.fit(n=10_000)\n", + "````\n", + "\n", + "Importantly, the model doesn't need to know about the dataloader.\n",
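+    "\n",
+    "For concreteness, `pm.Dataloader` does not exist yet; a minimal sketch of the intended contract is just a shuffling batch iterator:\n",
+    "\n",
+    "````python\n",
+    "class Dataloader:\n",
+    "    def __init__(self, data, batch_size, rng=None):\n",
+    "        self.data = np.asarray(data)\n",
+    "        self.batch_size = batch_size\n",
+    "        self.rng = rng if rng is not None else np.random.default_rng()\n",
+    "\n",
+    "    def __iter__(self):\n",
+    "        # one epoch per pass: shuffled minibatches of rows\n",
+    "        idx = self.rng.permutation(len(self.data))\n",
+    "        for i in range(0, len(idx), self.batch_size):\n",
+    "            yield self.data[idx[i : i + self.batch_size]]\n",
+    "````\n",
+    "\n",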
+    "We will need to tweak the inference object, but it's not so bad:\n",
+    "\n",
+    "````python\n",
+    "class ADVI(Inference):\n",
+    "    def step(self, batch):\n",
+    "        self.model.set_data(\"data\", batch)\n",
+    "        ...\n",
+    "````"
+   ]
+  },
  {
   "cell_type": "code",
   "execution_count": null,
-   "id": "50fcb3a1-4467-4ace-acdd-666e4f342984",
+   "id": "220ba769-fb8f-47a7-82b6-ab6ca13ad61e",
   "metadata": {},
   "outputs": [],
   "source": []

From b7aa31d1ab4de1a9976cbcb3381aeabc54e1ff4d Mon Sep 17 00:00:00 2001
From: Jesse Grabowski
Date: Sun, 8 Jun 2025 21:55:50 +0800
Subject: [PATCH 5/5] Example of an autoguide

---
 pymc/variational/autoguide.py       |  95 +++++++++++++++++++
 tests/variational/test_autoguide.py | 137 ++++++++++++++++++++++++++++
 2 files changed, 232 insertions(+)
 create mode 100644 pymc/variational/autoguide.py
 create mode 100644 tests/variational/test_autoguide.py

diff --git a/pymc/variational/autoguide.py b/pymc/variational/autoguide.py
new file mode 100644
index 000000000..be8906264
--- /dev/null
+++ b/pymc/variational/autoguide.py
@@ -0,0 +1,95 @@
+# Copyright 2025 - present The PyMC Developers
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import numpy as np
+import pytensor.tensor as pt
+
+from pytensor import Variable, graph_replace
+from pytensor.graph import vectorize_graph
+
+import pymc as pm
+
+from pymc.model.core import Model
+
+ModelVariable = Variable | str
+
+
+def AutoDiagonalNormal(model):
+    coords = model.coords
+    free_rvs = model.free_RVs
+    draws = pt.tensor("draws", shape=(), dtype="int64")
+
+    with Model(coords=coords) as guide_model:
+        for rv in free_rvs:
+            loc = pt.tensor(f"{rv.name}_loc", shape=rv.type.shape)
+            scale = pt.tensor(f"{rv.name}_scale", shape=rv.type.shape)
+            z = pm.Normal(
+                f"{rv.name}_z",
+                mu=0,
+                sigma=1,
+                shape=(draws, *rv.type.shape),
+                transform=model.rvs_to_transforms[rv],
+            )
+            pm.Deterministic(
+                rv.name, loc + scale * z, dims=model.named_vars_to_dims.get(rv.name, None)
+            )
+
+    return guide_model
+
+
+def AutoFullRankNormal(model):
+    # TODO: still needs tests. The packing and batched indexing below are a
+    # best-effort repair of an earlier draft that was marked broken; note that
+    # transforms on the model's free RVs are not handled here yet.
+    coords = model.coords
+    free_rvs = model.free_RVs
+    draws = pt.tensor("draws", shape=(), dtype="int64")
+
+    rv_sizes = [int(np.prod(rv.type.shape)) for rv in free_rvs]
+    total_size = int(np.sum(rv_sizes))
+    tril_size = total_size * (total_size + 1) // 2
+
+    locs = [pt.tensor(f"{rv.name}_loc", shape=rv.type.shape) for rv in free_rvs]
+    packed_L = pt.tensor("L", shape=(tril_size,), dtype="float64")
+    L = pm.expand_packed_triangular(total_size, packed_L)
+
+    with Model(coords=coords) as guide_model:
+        # One standard normal row per draw; the event dimension is total_size
+        z = pm.MvNormal("z", mu=np.zeros(total_size), cov=np.eye(total_size), size=(draws,))
+        loc = pt.concatenate([loc_rv.ravel() for loc_rv in locs])
+        params = loc + z @ L.T  # shape (draws, total_size)
+
+        cursor = 0
+
+        for rv, size in zip(free_rvs, rv_sizes):
+            pm.Deterministic(
+                rv.name,
+                params[:, cursor : cursor + size].reshape((-1, *rv.type.shape)),
+                dims=model.named_vars_to_dims.get(rv.name, None),
+            )
+            cursor += size
+
+    return guide_model
+
+
+def get_logp_logq(model, guide_model):
+    # Pair each of the model's value variables with the matching guide output,
+    # then vectorize the model logp over the guide's batched draws.
+    inputs_to_guide_rvs = {
+        model_value_var: guide_model[rv.name]
+        for rv, 
model_value_var in model.rvs_to_values.items() + if rv not in model.observed_RVs + } + + logp = vectorize_graph(model.logp(), inputs_to_guide_rvs) + logq = graph_replace(guide_model.logp(), guide_model.values_to_rvs) + + return logp, logq diff --git a/tests/variational/test_autoguide.py b/tests/variational/test_autoguide.py new file mode 100644 index 000000000..9ff53a38e --- /dev/null +++ b/tests/variational/test_autoguide.py @@ -0,0 +1,137 @@ +# Copyright 2025 - present The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +import pytensor.tensor as pt +import pytest + +import pymc as pm + +from pymc.variational.autoguide import AutoDiagonalNormal, AutoFullRankNormal, get_logp_logq + +Parameter = pt.tensor + + +@pytest.fixture(scope="module") +def X_y_params(): + """Generate synthetic data for testing.""" + + rng = np.random.default_rng(sum(map(ord, "autoguide_test"))) + + alpha = rng.normal(loc=100, scale=10) + beta = rng.normal(loc=0, scale=1, size=(10,)) + + true_params = { + "alpha": alpha, + "beta": beta, + } + + X_data = rng.normal(size=(100, 10)) + y_data = alpha + X_data @ beta + + return X_data, y_data, true_params + + +@pytest.fixture(scope="module") +def model(X_y_params): + X_data, y_data, _ = X_y_params + + with pm.Model() as model: + X = pm.Data("X", X_data) + alpha = pm.Normal("alpha", 100, 10) + beta = pm.Normal("beta", 0, 5, size=(10,)) + + mu = alpha + X @ beta + sigma = pm.Exponential("sigma", 1) + y = pm.Normal("y", mu=mu, sigma=sigma, observed=y_data) + + return model + + +@pytest.fixture(scope="module") +def target_guide_model(X_y_params): + X_data, *_ = X_y_params + + draws = pt.tensor("draws", shape=(), dtype="int64") + + with pm.Model() as guide_model: + X = pm.Data("X", X_data) + + alpha_loc = Parameter("alpha_loc", shape=()) + alpha_scale = Parameter("alpha_scale", shape=()) + alpha_z = pm.Normal("alpha_z", mu=0, sigma=1, shape=(draws,)) + alpha = pm.Deterministic("alpha", alpha_loc + alpha_scale * alpha_z) + + beta_loc = Parameter("beta_loc", shape=(10,)) + beta_scale = Parameter("beta_scale", shape=(10,)) + beta_z = pm.Normal("beta_z", mu=0, sigma=1, shape=(draws, 10)) + beta = pm.Deterministic("beta", beta_loc + beta_scale * beta_z) + + sigma_loc = Parameter("sigma_loc", shape=()) + sigma_scale = Parameter("sigma_scale", shape=()) + sigma_z = pm.Normal( + "sigma_z", 0, 1, shape=(draws,), transform=pm.distributions.transforms.log + ) + sigma = pm.Deterministic("sigma", sigma_loc + sigma_scale * sigma_z) + + return guide_model + + +def test_diagonal_normal_autoguide(model, target_guide_model, X_y_params): + guide_model = AutoDiagonalNormal(model) + + logp, logq = get_logp_logq(model, guide_model) + logp_target, logq_target = get_logp_logq(model, target_guide_model) + + inputs = pm.inputvars(logp) + target_inputs = pm.inputvars(logp_target) + + expected_locs = [f"{var}_loc" for var in ["alpha", "beta", "sigma"]] + expected_scales = [f"{var}_scale" for var in ["alpha", "beta", "sigma"]] + + expected_inputs = expected_locs + 
expected_scales + ["draws"] + name_to_input = {input.name: input for input in inputs} + name_to_target_input = {input.name: input for input in target_inputs} + + assert all(input.name in expected_inputs for input in inputs), ( + "Guide inputs do not match expected inputs" + ) + + negative_elbo = (logq - logp).mean() + negative_elbo_target = (logq_target - logp_target).mean() + + fn = pm.compile( + [name_to_input[input] for input in expected_inputs], negative_elbo, random_seed=69420 + ) + fn_target = pm.compile( + [name_to_target_input[input] for input in expected_inputs], + negative_elbo_target, + random_seed=69420, + ) + + test_inputs = { + "alpha_loc": np.zeros(()), + "alpha_scale": np.ones(()), + "beta_loc": np.zeros(10), + "beta_scale": np.ones(10), + "sigma_loc": np.zeros(()), + "sigma_scale": np.ones(()), + "draws": 100, + } + + np.testing.assert_allclose(fn(**test_inputs), fn_target(**test_inputs)) + + +def test_full_mv_normal_guide(model, X_y_params): + guide_model = AutoFullRankNormal(model) + logp, logq = get_logp_logq(model, guide_model)