Description
Is it possible to define two BART random variables (RVs) with different X and Y inputs in the same pymc-bart model? For example, could BART be used to estimate the nonlinear relationship for each link of a network? When I try a simple example with two independent BART components, I run into an error that suggests this may not be possible, so I wanted to confirm.
Assume we have Y1 = f1(X1) + e1 and Y2 = f2(X2) + e2, where, for simplicity, e1 and e2 are independent errors and X1 and X2 are distinct. I understand that you can pass `size=2` when defining a BART RV, as in the heteroskedasticity example, but here the two components have different X and Y inputs. My guess is that Y is needed to set the BART priors and X is used for the splitting rules. Is there a way to specify this model, or is it not currently possible? My use case will eventually involve correlated errors, but that should be easy enough to incorporate with an LKJ prior and an MvNormal likelihood.
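For the eventual correlated-error case, the only change to the data-generating process is that the two independent noise draws become one bivariate normal draw per observation, (e1, e2) ~ MvNormal(0, Σ). A minimal numpy sketch of that simulation follows, reusing the same f1 and f2 as the example in this post. The function name `simulate_correlated`, the correlation `rho = 0.5`, and the unit error scales are illustrative assumptions, not part of the original setup.

```python
import numpy as np


def simulate_correlated(n, rho, sigma1=1.0, sigma2=1.0, seed=None):
    """Hypothetical helper: simulate Y1 = f1(X1) + e1 and Y2 = f2(X2) + e2
    with corr(e1, e2) = rho. rho and the sigma defaults are assumptions."""
    rng = np.random.default_rng(seed)
    X1 = rng.uniform(-2, 2, size=(n, 2))
    X2 = rng.uniform(-2, 2, size=(n, 2))
    # Same mean functions as in the independent example
    f1 = X1[:, 0]
    f2 = (X2[:, 0] - 0.5 * X2[:, 0] ** 3) / 4
    # 2x2 error covariance built from the scales and the correlation
    cov = np.array([[sigma1**2, rho * sigma1 * sigma2],
                    [rho * sigma1 * sigma2, sigma2**2]])
    # One bivariate error draw per observation, shape (n, 2)
    E = rng.multivariate_normal(mean=np.zeros(2), cov=cov, size=n)
    Y1 = f1 + E[:, 0]
    Y2 = f2 + E[:, 1]
    return X1, X2, Y1, Y2, E


X1, X2, Y1, Y2, E = simulate_correlated(n=200, rho=0.5, seed=0)
# Sample correlation of the errors; close to rho, not exactly equal
print(np.corrcoef(E[:, 0], E[:, 1])[0, 1])
```

Setting `rho=0` recovers the independent-error case used in the example.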
Here's a simple example. Thanks in advance for any help you can provide.
```python
### Simulate Data ###
import numpy as np
import pymc as pm
import pymc_bart as pmb

sigma1, sigma2 = 1, 1
n = 200
# Predictor variables
X11 = np.random.uniform(-2,2,n)
X12 = np.random.uniform(-2,2,n)
X21 = np.random.uniform(-2,2,n)
X22 = np.random.uniform(-2,2,n)
X1=np.c_[X11,X12]
X2=np.c_[X21,X22]
# True mean functions (each depends only on the first column of its X)
f1 = X1[:,0]
f2 = (X2[:,0]-0.5*X2[:,0]**3)/4
# Independent standard-normal noise for each equation
zeta1 = np.random.normal(0,1,n)
zeta2 = np.random.normal(0,1,n)
Y1 = f1 + sigma1*zeta1
Y2 = f2 + sigma2*zeta2

### Specify Model ###
netBART = pm.Model()
with netBART:
    # One BART term per equation, each with its own X and Y
    f1=pmb.BART("f1",X=X1,Y=Y1)
    f2=pmb.BART("f2",X=X2,Y=Y2)
    mu1=pm.Deterministic("mu1",f1)
    mu2=pm.Deterministic("mu2",f2)
    # Error variances with inverse-gamma priors; sigma is the square root
    sigma1Sq=pm.InverseGamma("sigma1Sq",1,1)
    sigma2Sq=pm.InverseGamma("sigma2Sq",1,1)
    sigma1=pm.Deterministic("sigma1",sigma1Sq**0.5)
    sigma2=pm.Deterministic("sigma2",sigma2Sq**0.5)
    # Independent Gaussian likelihoods for the two equations
    like1=pm.Normal("like1",mu=mu1,sigma=sigma1,observed=Y1)
    like2=pm.Normal("like2",mu=mu2,sigma=sigma2,observed=Y2)
    idata=pm.sample(draws=100,tune=50,chains=4)
```
Output:
```
Only 100 samples in chain.
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[116], line 16
     13 like1=pm.Normal("like1",mu=mu1,sigma=sigma1,observed=Y1)
     14 like2=pm.Normal("like2",mu=mu2,sigma=sigma2,observed=Y2)
---> 16 idata=pm.sample(draws=100,tune=50,chains=4)

File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/pymc/sampling/mcmc.py:452, in sample(draws, step, init, n_init, initvals, trace, chains, cores, tune, progressbar, model, random_seed, discard_tuned_samples, compute_convergence_checks, callback, jitter_max_retries, return_inferencedata, keep_warning_stat, idata_kwargs, mp_ctx, **kwargs)
    449 auto_nuts_init = False
    451 initial_points = None
--> 452 step = assign_step_methods(model, step, methods=pm.STEP_METHODS, step_kwargs=kwargs)
    454 if isinstance(step, list):
    455 step = CompoundStep(step)

File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/pymc/sampling/mcmc.py:194, in assign_step_methods(model, step, methods, step_kwargs)
    186 selected = max(
    187 methods,
    188 key=lambda method, var=rv_var, has_gradient=has_gradient: method._competence(
    189 var, has_gradient
    190 ),
    191 )
    192 selected_steps[selected].append(var)
--> 194 return instantiate_steppers(model, steps, selected_steps, step_kwargs)

File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/pymc/sampling/mcmc.py:112, in instantiate_steppers(model, steps, selected_steps, step_kwargs)
    110 args = step_kwargs.get(step_class.name, {})
    111 used_keys.add(step_class.name)
--> 112 step = step_class(vars=vars, model=model, **args)
    113 steps.append(step)
    115 unused_args = set(step_kwargs).difference(used_keys)

File ~/miniconda3/envs/pymc_env/lib/python3.10/site-packages/pymc/step_methods/compound.py:89, in BlockedStep.__new__(cls, *args, **kwargs)
     86 step = super().__new__(cls)
     87 # If we don't return the instance we have to manually
     88 # call __init__
---> 89 step.__init__([var], *args, **kwargs)
     90 # Hack for creating the class correctly when unpickling.
     91 step.__newargs = ([var],) + args, kwargs

TypeError: PGBART.__init__() got an unexpected keyword argument 'blocked'
```
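Reading the traceback, the failure appears to happen while PyMC assigns step methods, not in the BART definitions themselves: the last frame, in `BlockedStep.__new__`, calls `step.__init__([var], *args, **kwargs)` for a single variable, and `PGBART.__init__` rejects the `blocked` keyword it receives there. That suggests the error is triggered when more than one variable is handed to PGBART at once. The sketch below is a simplified, hypothetical mimic of that dispatch pattern; `BlockedStepSketch`, `OldStyleStep`, and `NewStyleStep` are invented for illustration and are not PyMC or pymc-bart code. It assumes, based only on the error message, that the base class adds a `blocked` entry to the keyword arguments before splitting into one step per variable.

```python
class BlockedStepSketch:
    """Hypothetical, simplified mimic of a step-method base class (not PyMC code)."""

    def __new__(cls, *args, **kwargs):
        # Assumption: the base class fills in a default `blocked` flag
        kwargs.setdefault("blocked", False)
        step_vars = kwargs.pop("vars")
        if not kwargs["blocked"] and len(step_vars) > 1:
            # One step per variable. Because a list (not an instance) is returned,
            # Python will not call __init__, so it is called manually here and
            # receives the modified kwargs, including `blocked`.
            steps = []
            for var in step_vars:
                step = super().__new__(cls)
                step.__init__([var], *args, **kwargs)
                steps.append(step)
            return steps
        # Single variable: return the instance; Python then calls __init__
        # with the caller's original arguments, which contain no `blocked` key
        return super().__new__(cls)


class OldStyleStep(BlockedStepSketch):
    # Hypothetical step that accepts only the keywords it knows about
    def __init__(self, vars, model=None):
        self.vars = vars


class NewStyleStep(BlockedStepSketch):
    # Hypothetical step that tolerates extra keywords such as `blocked`
    def __init__(self, vars, model=None, **kwargs):
        self.vars = vars


one = OldStyleStep(vars=["f1"])        # works: a single variable
two = NewStyleStep(vars=["f1", "f2"])  # works: one step per variable
try:
    OldStyleStep(vars=["f1", "f2"])    # fails like the traceback above
except TypeError as err:
    print(err)  # message mentions the unexpected keyword 'blocked'
```

In this mimic, a single BART-like variable never sees `blocked`, while two variables do, which matches the observation that the error appears only once a second BART RV is added.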