Skip to content

Commit cdcdb58

Browse files
authored
Allow copy and deepcopy of PYMC models (#7492)
1 parent 67f43ae commit cdcdb58

File tree

3 files changed

+91
-0
lines changed

3 files changed

+91
-0
lines changed

pymc/model/core.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1577,6 +1577,41 @@ def __getitem__(self, key):
15771577
def __contains__(self, key):
15781578
return key in self.named_vars or self.name_for(key) in self.named_vars
15791579

1580+
def __copy__(self):
1581+
return self.copy()
1582+
1583+
def __deepcopy__(self, _):
1584+
return self.copy()
1585+
1586+
def copy(self):
1587+
"""
1588+
Clone the model
1589+
1590+
To access variables in the cloned model use `cloned_model["var_name"]`.
1591+
1592+
Examples
1593+
--------
1594+
.. code-block:: python
1595+
1596+
import pymc as pm
1597+
import copy
1598+
1599+
with pm.Model() as m:
1600+
p = pm.Beta("p", 1, 1)
1601+
x = pm.Bernoulli("x", p=p, shape=(3,))
1602+
1603+
clone_m = copy.copy(m)
1604+
1605+
# Access cloned variables by name
1606+
clone_x = clone_m["x"]
1607+
1608+
# z will be part of clone_m but not m
1609+
z = pm.Deterministic("z", clone_x + 1)
1610+
"""
1611+
from pymc.model.fgraph import clone_model
1612+
1613+
return clone_model(self)
1614+
15801615
def replace_rvs_by_values(
15811616
self,
15821617
graphs: Sequence[TensorVariable],

pymc/model/fgraph.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import warnings
15+
1416
from copy import copy, deepcopy
1517

1618
import pytensor
@@ -158,6 +160,14 @@ def fgraph_from_model(
158160
"Nested sub-models cannot be converted to fgraph. Convert the parent model instead"
159161
)
160162

163+
if any(
164+
("_rotated_" in var_name or "_hsgp_coeffs_" in var_name) for var_name in model.named_vars
165+
):
166+
warnings.warn(
167+
"Detected variables likely created by GP objects. Further use of these old GP objects should be avoided as it may reintroduce variables from the old model. See issue: https://github.com/pymc-devs/pymc/issues/6883",
168+
UserWarning,
169+
)
170+
161171
# Collect PyTensor variables
162172
rvs_to_values = model.rvs_to_values
163173
rvs = list(rvs_to_values.keys())

tests/model/test_core.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import copy
1415
import pickle
1516
import threading
1617
import traceback
@@ -1761,3 +1762,48 @@ def test_graphviz_call_function(self, var_names, filenames) -> None:
17611762
figsize=None,
17621763
dpi=300,
17631764
)
1765+
1766+
1767+
class TestModelCopy:
1768+
@pytest.mark.parametrize("copy_method", (copy.copy, copy.deepcopy))
1769+
def test_copy_model(self, copy_method) -> None:
1770+
with pm.Model() as simple_model:
1771+
pm.Normal("y")
1772+
1773+
copy_simple_model = copy_method(simple_model)
1774+
1775+
with simple_model:
1776+
simple_model_prior_predictive = pm.sample_prior_predictive(samples=1, random_seed=42)
1777+
1778+
with copy_simple_model:
1779+
z = pm.Deterministic("z", copy_simple_model["y"] + 1)
1780+
copy_simple_model_prior_predictive = pm.sample_prior_predictive(
1781+
samples=1, random_seed=42
1782+
)
1783+
1784+
assert (
1785+
simple_model_prior_predictive["prior"]["y"].values
1786+
== copy_simple_model_prior_predictive["prior"]["y"].values
1787+
)
1788+
1789+
assert "z" in copy_simple_model.named_vars
1790+
assert "z" not in simple_model.named_vars
1791+
assert (
1792+
copy_simple_model_prior_predictive["prior"]["z"].values
1793+
== 1 + simple_model_prior_predictive["prior"]["y"].values
1794+
)
1795+
1796+
@pytest.mark.parametrize("copy_method", (copy.copy, copy.deepcopy))
1797+
def test_guassian_process_copy_failure(self, copy_method) -> None:
1798+
with pm.Model() as gaussian_process_model:
1799+
ell = pm.Gamma("ell", alpha=2, beta=1)
1800+
cov = 2 * pm.gp.cov.ExpQuad(1, ell)
1801+
gp = pm.gp.Latent(cov_func=cov)
1802+
f = gp.prior("f", X=np.arange(10)[:, None])
1803+
pm.Normal("y", f * 2)
1804+
1805+
with pytest.warns(
1806+
UserWarning,
1807+
match="Detected variables likely created by GP objects. Further use of these old GP objects should be avoided as it may reintroduce variables from the old model. See issue: https://github.com/pymc-devs/pymc/issues/6883",
1808+
):
1809+
copy_method(gaussian_process_model)

0 commit comments

Comments
 (0)