@@ -1092,3 +1092,74 @@ def to_inferencedata(self, save_warmup: bool = False) -> az.InferenceData:
1092
1092
data .attrs = make_attrs (attrs = attrs , library = pymc )
1093
1093
groups [name ] = data .load () if az .rcParams ["data.load" ] == "eager" else data
1094
1094
return az .InferenceData (** groups )
1095
+
1096
+ @classmethod
1097
+ def from_store (
1098
+ cls : type ["ZarrTrace" ],
1099
+ store : BaseStore | MutableMapping ,
1100
+ synchronizer : Synchronizer | None = None ,
1101
+ ) -> "ZarrTrace" :
1102
+ if not _zarr_available :
1103
+ raise RuntimeError ("You must install zarr to be able to create ZarrTrace instances" )
1104
+ self : ZarrTrace = object .__new__ (cls )
1105
+ self .root = zarr .group (
1106
+ store = store ,
1107
+ overwrite = False ,
1108
+ synchronizer = synchronizer ,
1109
+ )
1110
+ self .synchronizer = synchronizer
1111
+ self .compressor = default_compressor
1112
+
1113
+ groups = set (self .root .group_keys ())
1114
+ assert groups >= {
1115
+ "posterior" ,
1116
+ "sample_stats" ,
1117
+ "warmup_posterior" ,
1118
+ "warmup_sample_stats" ,
1119
+ "constant_data" ,
1120
+ "observed_data" ,
1121
+ "_sampling_state" ,
1122
+ }
1123
+
1124
+ if "posterior" in groups :
1125
+ for _ , array in self .posterior .arrays ():
1126
+ dims = array .attrs .get ("_ARRAY_DIMENSIONS" , [])
1127
+ if len (dims ) >= 2 and dims [1 ] == "draw" :
1128
+ draws_per_chunk = int (array .chunks [1 ])
1129
+ break
1130
+ else :
1131
+ draws_per_chunk = 1
1132
+
1133
+ self .draws_per_chunk = int (draws_per_chunk )
1134
+ assert self .draws_per_chunk >= 1
1135
+
1136
+ self .include_transformed = "unconstrained_posterior" in groups
1137
+ arrays = itertools .chain (
1138
+ self .posterior .arrays (),
1139
+ self .constant_data .arrays (),
1140
+ self .observed_data .arrays (),
1141
+ )
1142
+ if self .include_transformed :
1143
+ arrays = itertools .chain (arrays , self .unconstrained_posterior .arrays ())
1144
+ varnames = []
1145
+ coords = {}
1146
+ vars_to_dims = {}
1147
+ for name , array in arrays :
1148
+ dims = array .attrs ["_ARRAY_DIMENSIONS" ]
1149
+ if dims [:2 ] == ["chain" , "draw" ]:
1150
+ # Random Variable
1151
+ vars_to_dims [name ] = dims [2 :]
1152
+ varnames .append (name )
1153
+ elif len (dims ) == 1 and name == dims [0 ]:
1154
+ # Coordinate
1155
+ # We store all model coordinates, which means we have to exclude chain
1156
+ # and draw
1157
+ if name not in ["chain" , "draw" ]:
1158
+ coords [name ] = np .asarray (array )
1159
+ else :
1160
+ # Constant data or observation
1161
+ vars_to_dims [name ] = dims
1162
+ self .varnames = varnames
1163
+ self .coords = coords
1164
+ self .vars_to_dims = vars_to_dims
1165
+ return self
0 commit comments