@@ -1163,3 +1163,124 @@ def from_store(
1163
1163
self .coords = coords
1164
1164
self .vars_to_dims = vars_to_dims
1165
1165
return self
1166
+
1167
+ def assert_model_and_step_are_compatible (
1168
+ self ,
1169
+ step : BlockedStep | CompoundStep ,
1170
+ model : Model ,
1171
+ vars : list [TensorVariable ] | None = None ,
1172
+ ):
1173
+ zarr_groups = set (self .root .group_keys ())
1174
+ arrays_ = itertools .chain (
1175
+ self .posterior .arrays (),
1176
+ self .constant_data .arrays () if "constant_data" in zarr_groups else [],
1177
+ self .observed_data .arrays () if "observed_data" in zarr_groups else [],
1178
+ )
1179
+ if self .include_transformed :
1180
+ arrays_ = itertools .chain (arrays_ , self .unconstrained_posterior .arrays ())
1181
+ arrays = list (arrays_ )
1182
+ zarr_varnames = []
1183
+ zarr_coords = {}
1184
+ zarr_vars_to_dims = {}
1185
+ zarr_deterministics = []
1186
+ zarr_free_vars = []
1187
+ for name , array in arrays :
1188
+ dims = array .attrs ["_ARRAY_DIMENSIONS" ]
1189
+ if dims [:2 ] == ["chain" , "draw" ]:
1190
+ # Random Variable
1191
+ zarr_vars_to_dims [name ] = dims [2 :]
1192
+ zarr_varnames .append (name )
1193
+ if array .attrs ["kind" ] == "freeRV" :
1194
+ zarr_free_vars .append (name )
1195
+ else :
1196
+ zarr_deterministics .append (name )
1197
+ elif len (dims ) == 1 and name == dims [0 ]:
1198
+ # Coordinate
1199
+ if name not in ["chain" , "draw" ]:
1200
+ zarr_coords [name ] = np .asarray (array )
1201
+ else :
1202
+ # Constant data or observation
1203
+ zarr_vars_to_dims [name ] = dims
1204
+ zarr_constant_data = (
1205
+ [name for name in self .constant_data .array_keys () if name not in zarr_coords ]
1206
+ if "constant_data" in zarr_groups
1207
+ else []
1208
+ )
1209
+ zarr_observed_data = (
1210
+ [name for name in self .observed_data .array_keys () if name not in zarr_coords ]
1211
+ if "observed_data" in zarr_groups
1212
+ else []
1213
+ )
1214
+ autogenerated_dims = {dim for dim in zarr_coords if re .search (r"_dim_\d+$" , dim )}
1215
+
1216
+ # Check deterministics, free RVs and transformed RVs
1217
+ _ , var_names = self .parse_varnames (model , vars )
1218
+ assert set (var_names ) == set (zarr_free_vars + zarr_deterministics ), (
1219
+ "The model deterministics and random variables given the sampled var_names "
1220
+ "do not match with the stored deterministics variables in the trace."
1221
+ )
1222
+ for name , array in arrays :
1223
+ if name not in zarr_free_vars or name not in zarr_deterministics :
1224
+ continue
1225
+ model_var = model [name ]
1226
+ assert np .dtype (model_var .dtype ) == np .dtype (array .dtype ), (
1227
+ "The model deterministics and random variables given the sampled "
1228
+ "var_names do not match with the stored deterministics variables in "
1229
+ "the trace."
1230
+ )
1231
+
1232
+ # Check coordinates
1233
+ assert (set (zarr_coords ) - set (autogenerated_dims )) == set (model .coords ) and all (
1234
+ np .array_equal (np .asarray (zarr_coords [dim ]), np .asarray (coord ))
1235
+ for dim , coord in model .coords .items ()
1236
+ ), "Model coordinates don't match the coordinates stored in the trace"
1237
+ vars_to_explicit_dims = {}
1238
+ for name , dims in zarr_vars_to_dims .items ():
1239
+ if len (dims ) == 0 or all (dim in autogenerated_dims for dim in dims ):
1240
+ # These variables wont be included in the named_vars_to_dims
1241
+ continue
1242
+ vars_to_explicit_dims [name ] = [
1243
+ dim if dim not in autogenerated_dims else None for dim in dims
1244
+ ]
1245
+ assert set (vars_to_explicit_dims ) == set (model .named_vars_to_dims ) and all (
1246
+ vars_to_explicit_dims [name ] == list (dims )
1247
+ for name , dims in model .named_vars_to_dims .items ()
1248
+ ), "Some model variables have different dimensions than those stored in the trace."
1249
+
1250
+ # Check constant data
1251
+ model_constant_data = find_constants (model )
1252
+ assert set (zarr_constant_data ) == set (model_constant_data ), (
1253
+ "The model constant data does not match with the stored constant data"
1254
+ )
1255
+ for name , model_data in model_constant_data .items ():
1256
+ assert np .array_equal (self .constant_data [name ], model_data , equal_nan = True ), (
1257
+ "The model constant data does not match with the stored constant data"
1258
+ )
1259
+
1260
+ # Check observed data
1261
+ model_observed_data = find_observations (model )
1262
+ assert set (zarr_observed_data ) == set (model_observed_data ), (
1263
+ "The model observed data does not match with the stored observed data"
1264
+ )
1265
+ for name , model_data in model_observed_data .items ():
1266
+ assert np .array_equal (self .observed_data [name ], model_data , equal_nan = True ), (
1267
+ "The model observed data does not match with the stored observed data"
1268
+ )
1269
+
1270
+ # Check sample stats given the step method
1271
+ stats_dtypes_shapes = get_stats_dtypes_shapes_from_steps (
1272
+ [step ] if isinstance (step , BlockedStep ) else step .methods
1273
+ )
1274
+ assert (set (stats_dtypes_shapes ) | {"chain" , "draw" }) == set (
1275
+ self .sample_stats .array_keys ()
1276
+ ), "The step method sample stats do not match the ones stored in the trace."
1277
+ for name , array in self .sample_stats .arrays ():
1278
+ if name in ("chain" , "draw" ):
1279
+ continue
1280
+ assert np .dtype (stats_dtypes_shapes [name ][0 ]) == np .dtype (array .dtype ), (
1281
+ "The step method sample stats do not match the ones stored in the trace."
1282
+ )
1283
+
1284
+ assert step .sampling_state .is_compatible (self ._sampling_state .sampling_state [0 ]), (
1285
+ "The step method sampling state class is incompatible with what's stored in the trace."
1286
+ )
0 commit comments