Skip to content

Commit 159268e

Browse files
authored
plot_dependence fix bug when setting grid (#71)
1 parent 6644cb4 commit 159268e

File tree

1 file changed

+11
-4
lines changed

1 file changed

+11
-4
lines changed

pymc_bart/utils.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""Utility function for variable selection and bart interpretability."""
22

3+
import warnings
4+
35
import arviz as az
46
import matplotlib.pyplot as plt
57
import numpy as np
@@ -287,6 +289,7 @@ def plot_dependence(
287289
y_mins.append(np.min(y_pred))
288290
new_y.append(np.array(y_pred).T)
289291

292+
new_y = np.array(new_y)
290293
if func is not None:
291294
new_y = [func(nyi) for nyi in new_y]
292295
shape = 1
@@ -299,6 +302,14 @@ def plot_dependence(
299302
fig, axes = plt.subplots(1, len(var_idx) * shape, sharey=sharey, figsize=figsize)
300303
elif isinstance(grid, tuple):
301304
fig, axes = plt.subplots(grid[0], grid[1], sharey=sharey, figsize=figsize)
305+
grid_size = grid[0] * grid[1]
306+
n_plots = new_y.squeeze().shape[0]
307+
if n_plots > grid_size:
308+
warnings.warn("The grid is smaller than the number of available variables to plot")
309+
elif n_plots < grid_size:
310+
for i in range(n_plots, grid[0] * grid[1]):
311+
fig.delaxes(axes.flatten()[i])
312+
axes = axes.flatten()[:n_plots]
302313
axes = np.ravel(axes)
303314
else:
304315
axes = [ax]
@@ -307,10 +318,6 @@ def plot_dependence(
307318
x_idx = 0
308319
y_idx = 0
309320
for ax in axes: # pylint: disable=redefined-argument-from-local
310-
if x_idx >= len(var_idx):
311-
ax.set_axis_off()
312-
fig.delaxes(ax)
313-
314321
nyi = new_y[x_idx][y_idx]
315322
nxi = new_x_target[x_idx]
316323
var = var_idx[x_idx]

0 commit comments

Comments
 (0)