1
1
"""Utility function for variable selection and bart interpretability."""
2
2
3
+ import warnings
4
+
3
5
import arviz as az
4
6
import matplotlib .pyplot as plt
5
7
import numpy as np
@@ -287,6 +289,7 @@ def plot_dependence(
287
289
y_mins .append (np .min (y_pred ))
288
290
new_y .append (np .array (y_pred ).T )
289
291
292
+ new_y = np .array (new_y )
290
293
if func is not None :
291
294
new_y = [func (nyi ) for nyi in new_y ]
292
295
shape = 1
@@ -299,6 +302,14 @@ def plot_dependence(
299
302
fig , axes = plt .subplots (1 , len (var_idx ) * shape , sharey = sharey , figsize = figsize )
300
303
elif isinstance (grid , tuple ):
301
304
fig , axes = plt .subplots (grid [0 ], grid [1 ], sharey = sharey , figsize = figsize )
305
+ grid_size = grid [0 ] * grid [1 ]
306
+ n_plots = new_y .squeeze ().shape [0 ]
307
+ if n_plots > grid_size :
308
+ warnings .warn ("The grid is smaller than the number of available variables to plot" )
309
+ elif n_plots < grid_size :
310
+ for i in range (n_plots , grid [0 ] * grid [1 ]):
311
+ fig .delaxes (axes .flatten ()[i ])
312
+ axes = axes .flatten ()[:n_plots ]
302
313
axes = np .ravel (axes )
303
314
else :
304
315
axes = [ax ]
@@ -307,10 +318,6 @@ def plot_dependence(
307
318
x_idx = 0
308
319
y_idx = 0
309
320
for ax in axes : # pylint: disable=redefined-argument-from-local
310
- if x_idx >= len (var_idx ):
311
- ax .set_axis_off ()
312
- fig .delaxes (ax )
313
-
314
321
nyi = new_y [x_idx ][y_idx ]
315
322
nxi = new_x_target [x_idx ]
316
323
var = var_idx [x_idx ]
0 commit comments