Skip to content

Commit b34742d

Browse files
committed
added __copy__ and __deepcopy__ methods to Model and added unit tests
1 parent 2856062 commit b34742d

File tree

2 files changed

+66
-0
lines changed

2 files changed

+66
-0
lines changed

pymc/model/core.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1570,6 +1570,22 @@ def __getitem__(self, key):
15701570

15711571
def __contains__(self, key):
15721572
return key in self.named_vars or self.name_for(key) in self.named_vars
1573+
1574+
def __copy__(self):
1575+
from pymc.model.fgraph import clone_model
1576+
check_for_gp_vars = [k for x in ['_rotated_', '_hsgp_coeffs_'] for k in self.named_vars.keys() if x in k]
1577+
if len(check_for_gp_vars) > 0:
1578+
raise Exception("Unable to clone Gaussian Process Variables")
1579+
1580+
return clone_model(self)
1581+
1582+
def __deepcopy__(self, _):
1583+
from pymc.model.fgraph import clone_model
1584+
check_for_gp_vars = [k for x in ['_rotated_', '_hsgp_coeffs_'] for k in self.named_vars.keys() if x in k]
1585+
if len(check_for_gp_vars) > 0:
1586+
raise Exception("Unable to clone Gaussian Process Variables")
1587+
1588+
return clone_model(self)
15731589

15741590
def replace_rvs_by_values(
15751591
self,

tests/model/test_core.py

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import copy
1415
import pickle
1516
import threading
1617
import traceback
@@ -1761,3 +1762,52 @@ def test_graphviz_call_function(self, var_names, filenames) -> None:
17611762
figsize=None,
17621763
dpi=300,
17631764
)
1765+
1766+
1767+
class TestModelCopy(unittest.TestCase):
1768+
@staticmethod
1769+
def simple_model() -> pm.Model:
1770+
with pm.Model() as simple_model:
1771+
error = pm.HalfNormal("error", 0.5)
1772+
alpha = pm.Normal("alpha", 0, 1)
1773+
pm.Normal("y", alpha, error)
1774+
return simple_model
1775+
1776+
@staticmethod
1777+
def gp_model() -> pm.Model:
1778+
with pm.Model() as gp_model:
1779+
ell = pm.Gamma("ell", alpha=2, beta=1)
1780+
cov = 2 * pm.gp.cov.ExpQuad(1, ell)
1781+
gp = pm.gp.Latent(cov_func=cov)
1782+
f = gp.prior("f", X=np.arange(10)[:, None])
1783+
pm.Normal("y", f * 2)
1784+
return gp_model
1785+
1786+
def test_copy_model(self) -> None:
1787+
simple_model = self.simple_model()
1788+
copy_simple_model = copy.copy(simple_model)
1789+
deepcopy_simple_model = copy.deepcopy(simple_model)
1790+
1791+
with simple_model:
1792+
simple_model_prior_predictive = pm.sample_prior_predictive(random_seed=42)
1793+
1794+
with copy_simple_model:
1795+
copy_simple_model_prior_predictive = pm.sample_prior_predictive(random_seed=42)
1796+
1797+
with deepcopy_simple_model:
1798+
deepcopy_simple_model_prior_predictive = pm.sample_prior_predictive(random_seed=42)
1799+
1800+
simple_model_prior_predictive_mean = simple_model_prior_predictive['prior']['y'].mean(('chain', 'draw'))
1801+
copy_simple_model_prior_predictive_mean = copy_simple_model_prior_predictive['prior']['y'].mean(('chain', 'draw'))
1802+
deepcopy_simple_model_prior_predictive_mean = deepcopy_simple_model_prior_predictive['prior']['y'].mean(('chain', 'draw'))
1803+
1804+
assert np.isclose(simple_model_prior_predictive_mean, copy_simple_model_prior_predictive_mean)
1805+
assert np.isclose(simple_model_prior_predictive_mean, deepcopy_simple_model_prior_predictive_mean)
1806+
1807+
def test_guassian_process_copy_failure(self) -> None:
1808+
gaussian_process_model = self.gp_model()
1809+
with pytest.raises(Exception) as e:
1810+
copy.copy(gaussian_process_model)
1811+
1812+
with pytest.raises(Exception) as e:
1813+
copy.deepcopy(gaussian_process_model)

0 commit comments

Comments
 (0)