@@ -122,7 +122,7 @@ def test_init(
122
122
assert beta .eval ().shape == input_std .shape
123
123
# r2 rv is only created if r2 std is not None
124
124
assert ("beta::r2" in model .named_vars ) == (r2_std is not None ), set (model .named_vars )
125
- # phi is only created if variable importances is not None and there is more than one var
125
+ # phi is only created if variable importance is not None and there is more than one var
126
126
assert ("beta::phi" in model .named_vars ) == (
127
127
"variables_importance" in phi_args or "importance_concentration" in phi_args
128
128
), set (model .named_vars )
@@ -204,3 +204,34 @@ def test_limit_case_requires_std_0(self, model: pm.Model):
204
204
positive_probs = [0.5 , 1 ],
205
205
positive_probs_std = [0.3 , 0.1 ],
206
206
)
207
+
208
+ def test_limit_case_creates_masked_vars (self , model : pm .Model , centered : bool ):
209
+ model .add_coord ("a" , range (2 ))
210
+ pmx .distributions .R2D2M2CP (
211
+ "beta0" ,
212
+ 1 ,
213
+ [1 , 1 ],
214
+ dims = "a" ,
215
+ r2 = 0.8 ,
216
+ positive_probs = [0.5 , 1 ],
217
+ positive_probs_std = [0.3 , 0 ],
218
+ centered = centered ,
219
+ )
220
+ pmx .distributions .R2D2M2CP (
221
+ "beta1" ,
222
+ 1 ,
223
+ [1 , 1 ],
224
+ dims = "a" ,
225
+ r2 = 0.8 ,
226
+ positive_probs = [0.5 , 0 ],
227
+ positive_probs_std = [0.3 , 0 ],
228
+ centered = centered ,
229
+ )
230
+ if not centered :
231
+ assert "beta0::raw::masked" in model .named_vars , model .named_vars
232
+ assert "beta1::raw::masked" in model .named_vars , model .named_vars
233
+ else :
234
+ assert "beta0::masked" in model .named_vars , model .named_vars
235
+ assert "beta1::masked" in model .named_vars , model .named_vars
236
+ assert "beta1::psi::masked" in model .named_vars
237
+ assert "beta0::psi::masked" in model .named_vars
0 commit comments