Skip to content

Commit 951c0ef

Browse files
Variational inference/Bayesian neural network: fixed data dims in ann_input/ann_output (#506)
* fixed data dims * pre-commit Co-authored-by: Oriol (ZBook) <oriol.abril.pla@gmail.com>
1 parent 5288e05 commit 951c0ef

File tree

2 files changed

+225
-75
lines changed

2 files changed

+225
-75
lines changed

examples/variational_inference/bayesian_neural_network_advi.ipynb

Lines changed: 221 additions & 73 deletions
Large diffs are not rendered by default.

examples/variational_inference/bayesian_neural_network_advi.myst.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,8 +130,8 @@ def construct_nn(ann_input, ann_output):
130130
# "obs_id": np.arange(X_train.shape[0]),
131131
}
132132
with pm.Model(coords=coords) as neural_network:
133-
ann_input = pm.Data("ann_input", X_train, mutable=True)
134-
ann_output = pm.Data("ann_output", Y_train, mutable=True)
133+
ann_input = pm.Data("ann_input", X_train, mutable=True, dims=("obs_id", "train_cols"))
134+
ann_output = pm.Data("ann_output", Y_train, mutable=True, dims="obs_id")
135135
136136
# Weights from input to hidden layer
137137
weights_in_1 = pm.Normal(
@@ -157,6 +157,7 @@ def construct_nn(ann_input, ann_output):
157157
act_out,
158158
observed=ann_output,
159159
total_size=Y_train.shape[0], # IMPORTANT for minibatches
160+
dims="obs_id",
160161
)
161162
return neural_network
162163
@@ -340,6 +341,7 @@ You might argue that the above network isn't really deep, but note that we could
340341

341342
- This notebook was originally authored as a [blog post](https://twiecki.github.io/blog/2016/06/01/bayesian-deep-learning/) by Thomas Wiecki in 2016
342343
- Updated by Chris Fonnesbeck for PyMC v4 in 2022
344+
- Updated by Oriol Abril-Pla and Earl Bellinger in 2023
343345

344346
## Watermark
345347

0 commit comments

Comments
 (0)