@@ -134,9 +134,17 @@ def __init__(
134
134
self .missing_data = np .any (np .isnan (self .X ))
135
135
self .m = self .bart .m
136
136
self .response = self .bart .response
137
+
137
138
shape = initial_values [value_bart .name ].shape
139
+
138
140
self .shape = 1 if len (shape ) == 1 else shape [0 ]
139
141
142
+ # Set trees_shape (dim for separate tree structures)
143
+ # and leaves_shape (dim for leaf node values)
144
+ # One of the two is always one, the other equal to self.shape
145
+ self .trees_shape = self .shape if self .bart .separate_trees else 1
146
+ self .leaves_shape = self .shape if not self .bart .separate_trees else 1
147
+
140
148
if self .bart .split_prior :
141
149
self .alpha_vec = self .bart .split_prior
142
150
else :
@@ -153,27 +161,31 @@ def __init__(
153
161
self .available_predictors = list (range (self .num_variates ))
154
162
155
163
# if data is binary
164
+ self .leaf_sd = np .ones ((self .trees_shape , self .leaves_shape ))
165
+
156
166
y_unique = np .unique (self .bart .Y )
157
167
if y_unique .size == 2 and np .all (y_unique == [0 , 1 ]):
158
- self .leaf_sd = 3 / self .m ** 0.5
168
+ self .leaf_sd * = 3 / self .m ** 0.5
159
169
else :
160
- self .leaf_sd = self .bart .Y .std () / self .m ** 0.5
170
+ self .leaf_sd * = self .bart .Y .std () / self .m ** 0.5
161
171
162
- self .running_sd = RunningSd (shape )
172
+ self .running_sd = [
173
+ RunningSd ((self .leaves_shape , self .num_observations )) for _ in range (self .trees_shape )
174
+ ]
163
175
164
- self .sum_trees = np .full (( self . shape , self . bart . Y . shape [ 0 ]), init_mean ). astype (
165
- config . floatX
166
- )
176
+ self .sum_trees = np .full (
177
+ ( self . trees_shape , self . leaves_shape , self . bart . Y . shape [ 0 ]), init_mean
178
+ ). astype ( config . floatX )
167
179
self .sum_trees_noi = self .sum_trees - init_mean
168
180
self .a_tree = Tree .new_tree (
169
181
leaf_node_value = init_mean / self .m ,
170
182
idx_data_points = np .arange (self .num_observations , dtype = "int32" ),
171
183
num_observations = self .num_observations ,
172
- shape = self .shape ,
184
+ shape = self .leaves_shape ,
173
185
split_rules = self .split_rules ,
174
186
)
175
187
176
- self .normal = NormalSampler (1 , self .shape )
188
+ self .normal = NormalSampler (1 , self .leaves_shape )
177
189
self .uniform = UniformSampler (0 , 1 )
178
190
self .prior_prob_leaf_node = compute_prior_probability (self .bart .alpha , self .bart .beta )
179
191
self .ssv = SampleSplittingVariable (self .alpha_vec )
@@ -188,8 +200,10 @@ def __init__(
188
200
self .indices = list (range (1 , num_particles ))
189
201
shared = make_shared_replacements (initial_values , vars , model )
190
202
self .likelihood_logp = logp (initial_values , [model .datalogp ], vars , shared )
191
- self .all_particles = [ParticleTree (self .a_tree ) for _ in range (self .m )]
192
- self .all_trees = np .array ([p .tree for p in self .all_particles ])
203
+ self .all_particles = [
204
+ [ParticleTree (self .a_tree ) for _ in range (self .m )] for _ in range (self .trees_shape )
205
+ ]
206
+ self .all_trees = np .array ([[p .tree for p in pl ] for pl in self .all_particles ])
193
207
self .lower = 0
194
208
self .iter = 0
195
209
super ().__init__ (vars , shared )
@@ -201,72 +215,75 @@ def astep(self, _):
201
215
tree_ids = range (self .lower , upper )
202
216
self .lower = upper if upper < self .m else 0
203
217
204
- for tree_id in tree_ids :
205
- self .iter += 1
206
- # Compute the sum of trees without the old tree that we are attempting to replace
207
- self .sum_trees_noi = self .sum_trees - self .all_particles [tree_id ].tree ._predict ()
208
- # Generate an initial set of particles
209
- # at the end we return one of these particles as the new tree
210
- particles = self .init_particles (tree_id )
211
-
212
- while True :
213
- # Sample each particle (try to grow each tree), except for the first one
214
- stop_growing = True
215
- for p in particles [1 :]:
216
- if p .sample_tree (
217
- self .ssv ,
218
- self .available_predictors ,
219
- self .prior_prob_leaf_node ,
220
- self .X ,
221
- self .missing_data ,
222
- self .sum_trees ,
223
- self .leaf_sd ,
224
- self .m ,
225
- self .response ,
226
- self .normal ,
227
- self .shape ,
228
- ):
229
- self .update_weight (p )
230
- if p .expansion_nodes :
231
- stop_growing = False
232
- if stop_growing :
233
- break
234
-
235
- # Normalize weights
236
- normalized_weights = self .normalize (particles [1 :])
237
-
238
- # Resample
239
- particles = self .resample (particles , normalized_weights )
240
-
241
- normalized_weights = self .normalize (particles )
242
- # Get the new particle and associated tree
243
- self .all_particles [tree_id ], new_tree = self .get_particle_tree (
244
- particles , normalized_weights
245
- )
246
- # Update the sum of trees
247
- new = new_tree ._predict ()
248
- self .sum_trees = self .sum_trees_noi + new
249
- # To reduce memory usage, we trim the tree
250
- self .all_trees [tree_id ] = new_tree .trim ()
251
-
252
- if self .tune :
253
- # Update the splitting variable and the splitting variable sampler
254
- if self .iter > self .m :
255
- self .ssv = SampleSplittingVariable (self .alpha_vec )
256
-
257
- for index in new_tree .get_split_variables ():
258
- self .alpha_vec [index ] += 1
259
-
260
- # update standard deviation at leaf nodes
261
- if self .iter > 2 :
262
- self .leaf_sd = self .running_sd .update (new )
263
- else :
264
- self .running_sd .update (new )
218
+ for odim in range (self .trees_shape ):
219
+ for tree_id in tree_ids :
220
+ self .iter += 1
221
+ # Compute the sum of trees without the old tree that we are attempting to replace
222
+ self .sum_trees_noi [odim ] = (
223
+ self .sum_trees [odim ] - self .all_particles [odim ][tree_id ].tree ._predict ()
224
+ )
225
+ # Generate an initial set of particles
226
+ # at the end we return one of these particles as the new tree
227
+ particles = self .init_particles (tree_id , odim )
228
+
229
+ while True :
230
+ # Sample each particle (try to grow each tree), except for the first one
231
+ stop_growing = True
232
+ for p in particles [1 :]:
233
+ if p .sample_tree (
234
+ self .ssv ,
235
+ self .available_predictors ,
236
+ self .prior_prob_leaf_node ,
237
+ self .X ,
238
+ self .missing_data ,
239
+ self .sum_trees [odim ],
240
+ self .leaf_sd [odim ],
241
+ self .m ,
242
+ self .response ,
243
+ self .normal ,
244
+ self .leaves_shape ,
245
+ ):
246
+ self .update_weight (p , odim )
247
+ if p .expansion_nodes :
248
+ stop_growing = False
249
+ if stop_growing :
250
+ break
251
+
252
+ # Normalize weights
253
+ normalized_weights = self .normalize (particles [1 :])
254
+
255
+ # Resample
256
+ particles = self .resample (particles , normalized_weights )
257
+
258
+ normalized_weights = self .normalize (particles )
259
+ # Get the new particle and associated tree
260
+ self .all_particles [odim ][tree_id ], new_tree = self .get_particle_tree (
261
+ particles , normalized_weights
262
+ )
263
+ # Update the sum of trees
264
+ new = new_tree ._predict ()
265
+ self .sum_trees [odim ] = self .sum_trees_noi [odim ] + new
266
+ # To reduce memory usage, we trim the tree
267
+ self .all_trees [odim ][tree_id ] = new_tree .trim ()
268
+
269
+ if self .tune :
270
+ # Update the splitting variable and the splitting variable sampler
271
+ if self .iter > self .m :
272
+ self .ssv = SampleSplittingVariable (self .alpha_vec )
273
+
274
+ for index in new_tree .get_split_variables ():
275
+ self .alpha_vec [index ] += 1
276
+
277
+ # update standard deviation at leaf nodes
278
+ if self .iter > 2 :
279
+ self .leaf_sd [odim ] = self .running_sd [odim ].update (new )
280
+ else :
281
+ self .running_sd [odim ].update (new )
265
282
266
- else :
267
- # update the variable inclusion
268
- for index in new_tree .get_split_variables ():
269
- variable_inclusion [index ] += 1
283
+ else :
284
+ # update the variable inclusion
285
+ for index in new_tree .get_split_variables ():
286
+ variable_inclusion [index ] += 1
270
287
271
288
if not self .tune :
272
289
self .bart .all_trees .append (self .all_trees )
@@ -331,23 +348,27 @@ def systematic(self, normalized_weights: npt.NDArray[np.float_]) -> npt.NDArray[
331
348
single_uniform = (self .uniform .rvs () + np .arange (lnw )) / lnw
332
349
return inverse_cdf (single_uniform , normalized_weights )
333
350
334
def init_particles(self, tree_id: int, odim: int) -> List[ParticleTree]:
    """
    Initialize particles.

    The first particle is the one currently stored at
    ``self.all_particles[odim][tree_id]``; it is followed by one fresh
    ``ParticleTree`` built from ``self.a_tree`` for every entry of
    ``self.indices``.
    """
    current: ParticleTree = self.all_particles[odim][tree_id]
    # The old tree does not grow so we update the weight only once
    self.update_weight(current, odim)

    fresh: List[ParticleTree] = [ParticleTree(self.a_tree) for _ in self.indices]
    return [current, *fresh]
343
360
344
def update_weight(self, particle: ParticleTree, odim: int) -> None:
    """
    Update the weight of a particle.

    The prediction of the particle's tree is added to slot ``odim`` (first
    axis) of ``self.sum_trees_noi``; every other slot receives a zero
    contribution. The flattened result is evaluated with
    ``self.likelihood_logp`` and the returned value is stored in
    ``particle.log_weight``.

    Parameters
    ----------
    particle : ParticleTree
        Particle whose ``log_weight`` is overwritten.
    odim : int
        Index along the first axis of ``self.sum_trees_noi`` that receives
        the particle's prediction.
    """
    # One-hot selector over the first axis. Built directly instead of slicing
    # a row out of np.identity(trees_shape), which would allocate a full
    # trees_shape x trees_shape matrix on every call.
    selector = np.zeros(self.trees_shape)
    selector[odim] = 1.0

    # The prediction is indexed as a 2-D array and broadcast against the
    # selector, giving an array with one slab per entry of the first axis.
    delta = selector[:, None, None] * particle.tree._predict()[None, :, :]

    particle.log_weight = self.likelihood_logp((self.sum_trees_noi + delta).flatten())
352
373
353
374
@staticmethod
0 commit comments