Skip to content

Commit 111fae3

Browse files
Allow for StatsBijection.rmap of incomplete stat dicts
1 parent ac9652e commit 111fae3

File tree

2 files changed

+10
-2
lines changed

2 files changed

+10
-2
lines changed

pymc/step_methods/compound.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -198,9 +198,12 @@ def map(self, stats_list: Sequence[Mapping[str, Any]]) -> StatsDict:
198198
return stats_dict
199199

200200
def rmap(self, stats_dict: Mapping[str, Any]) -> StatsType:
201-
"""Split a global stats dict into a list of sampler-wise stats dicts."""
201+
"""Split a global stats dict into a list of sampler-wise stats dicts.
202+
203+
The ``stats_dict`` can be a subset of all sampler stats.
204+
"""
202205
stats_list = []
203206
for namemap in self._stat_groups:
204-
d = {statname: stats_dict[sname] for sname, statname in namemap}
207+
d = {statname: stats_dict[sname] for sname, statname in namemap if sname in stats_dict}
205208
stats_list.append(d)
206209
return stats_list

pymc/tests/step_methods/test_compound.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,3 +130,8 @@ def test_stats_bijection(self):
130130
assert isinstance(rev, list)
131131
assert len(rev) == len(stats_l)
132132
assert rev == stats_l
133+
# Also rmap incomplete dicts
134+
rev2 = bij.rmap({"sampler_1__a": 0})
135+
assert len(rev2) == 2
136+
assert len(rev2[0]) == 0
137+
assert len(rev2[1]) == 1

0 commit comments

Comments
 (0)