Skip to content

Commit 9abf4e0

Browse files
michaelosthegericardoV94
authored andcommitted
Remove unused merge_traces helper function
1 parent 9aeb6b5 commit 9abf4e0

File tree

3 files changed

+1
-79
lines changed

3 files changed

+1
-79
lines changed

pymc/backends/base.py

Lines changed: 1 addition & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -21,12 +21,11 @@
2121
import warnings
2222

2323
from abc import ABC
24-
from typing import List
2524

2625
import aesara.tensor as at
2726
import numpy as np
2827

29-
from pymc.backends.report import SamplerReport, merge_reports
28+
from pymc.backends.report import SamplerReport
3029
from pymc.model import modelcontext
3130
from pymc.util import get_var_name
3231

@@ -570,43 +569,6 @@ def points(self, chains=None):
570569
return itl.chain.from_iterable(self._straces[chain] for chain in chains)
571570

572571

573-
def merge_traces(mtraces: List[MultiTrace]) -> MultiTrace:
574-
"""Merge MultiTrace objects.
575-
576-
Parameters
577-
----------
578-
mtraces: list of MultiTraces
579-
Each instance should have unique chain numbers.
580-
581-
Raises
582-
------
583-
A ValueError is raised if any traces have overlapping chain numbers,
584-
or if chains are of different lengths.
585-
586-
Returns
587-
-------
588-
A MultiTrace instance with merged chains
589-
"""
590-
if len(mtraces) == 0:
591-
raise ValueError("Cannot merge an empty set of traces.")
592-
base_mtrace = mtraces[0]
593-
chain_len = len(base_mtrace)
594-
# check base trace
595-
if any(
596-
len(st) != chain_len for _, st in base_mtrace._straces.items()
597-
): # pylint: disable=line-too-long
598-
raise ValueError("Chains are of different lengths.")
599-
for new_mtrace in mtraces[1:]:
600-
for new_chain, strace in new_mtrace._straces.items():
601-
if new_chain in base_mtrace._straces:
602-
raise ValueError("Chains are not unique.")
603-
if len(strace) != chain_len:
604-
raise ValueError("Chains are of different lengths.")
605-
base_mtrace._straces[new_chain] = strace
606-
base_mtrace._report = merge_reports([trace.report for trace in mtraces])
607-
return base_mtrace
608-
609-
610572
def _squeeze_cat(results, combine, squeeze):
611573
"""Squeeze and concatenate the results depending on values of
612574
`combine` and `squeeze`."""

pymc/backends/report.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -206,12 +206,3 @@ def filter_warns(warnings):
206206
report._add_warnings(filter_warns(self._chain_warnings[chain]), chain)
207207

208208
return report
209-
210-
211-
def merge_reports(reports):
212-
report = SamplerReport()
213-
for rep in reports:
214-
report._add_warnings(rep._global_warnings)
215-
for chain in rep._chain_warnings:
216-
report._add_warnings(rep._chain_warnings[chain], chain)
217-
return report

pymc/tests/backends/test_ndarray.py

Lines changed: 0 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -123,37 +123,6 @@ def test_multitrace_nonunique(self):
123123
with pytest.raises(ValueError):
124124
base.MultiTrace([self.strace0, self.strace1])
125125

126-
def test_merge_traces_no_traces(self):
127-
with pytest.raises(ValueError):
128-
base.merge_traces([])
129-
130-
def test_merge_traces_diff_lengths(self):
131-
with self.model:
132-
strace0 = self.backend(self.name)
133-
strace0.setup(self.draws, 1)
134-
for i in range(self.draws):
135-
strace0.record(self.test_point)
136-
strace0.close()
137-
mtrace0 = base.MultiTrace([self.strace0])
138-
139-
with self.model:
140-
strace1 = self.backend(self.name)
141-
strace1.setup(2 * self.draws, 1)
142-
for i in range(2 * self.draws):
143-
strace1.record(self.test_point)
144-
strace1.close()
145-
mtrace1 = base.MultiTrace([strace1])
146-
147-
with pytest.raises(ValueError):
148-
base.merge_traces([mtrace0, mtrace1])
149-
150-
def test_merge_traces_nonunique(self):
151-
mtrace0 = base.MultiTrace([self.strace0])
152-
mtrace1 = base.MultiTrace([self.strace1])
153-
154-
with pytest.raises(ValueError):
155-
base.merge_traces([mtrace0, mtrace1])
156-
157126

158127
class TestMultiTrace_add_remove_values(bf.ModelBackendSampledTestCase):
159128
name = None

0 commit comments

Comments
 (0)