Skip to content

Commit bcb4309

Browse files
committed
changed raise to warning, moved warning to low level clone_graph, added doc example, updated pytest
1 parent fe4e0c5 commit bcb4309

File tree

3 files changed

+70
-27
lines changed

3 files changed

+70
-27
lines changed

pymc/model/core.py

Lines changed: 40 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1574,31 +1574,57 @@ def __contains__(self, key):
15741574
def __copy__(self):
15751575
"""
15761576
Clone a pymc model by overiding the python copy method using the clone_model method from fgraph.
1577-
if guassian process variables are detected then an exception will be raised.
1577+
Constants are not cloned and if guassian process variables are detected then a warning will be triggered.
1578+
1579+
Examples
1580+
--------
1581+
.. code-block:: python
1582+
1583+
import pymc as pm
1584+
import copy
1585+
1586+
with pm.Model() as m:
1587+
p = pm.Beta("p", 1, 1)
1588+
x = pm.Bernoulli("x", p=p, shape=(3,))
1589+
1590+
clone_m = copy.copy(m)
1591+
1592+
# Access cloned variables by name
1593+
clone_x = clone_m["x"]
1594+
1595+
# z will be part of clone_m but not m
1596+
z = pm.Deterministic("z", clone_x + 1)
15781597
"""
15791598
from pymc.model.fgraph import clone_model
15801599

1581-
check_for_gp_vars = [
1582-
k for x in ["_rotated_", "_hsgp_coeffs_"] for k in self.named_vars.keys() if x in k
1583-
]
1584-
if len(check_for_gp_vars) > 0:
1585-
raise Exception("Unable to clone Gaussian Process Variables")
1586-
15871600
return clone_model(self)
15881601

15891602
def __deepcopy__(self, _):
15901603
"""
15911604
Clone a pymc model by overiding the python copy method using the clone_model method from fgraph.
1592-
if guassian process variables are detected then an exception will be raised.
1605+
Constants are not cloned and if guassian process variables are detected then a warning will be triggered.
1606+
1607+
Examples
1608+
--------
1609+
.. code-block:: python
1610+
1611+
import pymc as pm
1612+
import copy
1613+
1614+
with pm.Model() as m:
1615+
p = pm.Beta("p", 1, 1)
1616+
x = pm.Bernoulli("x", p=p, shape=(3,))
1617+
1618+
clone_m = copy.deepcopy(m)
1619+
1620+
# Access cloned variables by name
1621+
clone_x = clone_m["x"]
1622+
1623+
# z will be part of clone_m but not m
1624+
z = pm.Deterministic("z", clone_x + 1)
15931625
"""
15941626
from pymc.model.fgraph import clone_model
15951627

1596-
check_for_gp_vars = [
1597-
k for x in ["_rotated_", "_hsgp_coeffs_"] for k in self.named_vars.keys() if x in k
1598-
]
1599-
if len(check_for_gp_vars) > 0:
1600-
raise Exception("Unable to clone Gaussian Process Variables")
1601-
16021628
return clone_model(self)
16031629

16041630
def replace_rvs_by_values(

pymc/model/fgraph.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import warnings
15+
1416
from copy import copy, deepcopy
1517

1618
import pytensor
@@ -369,7 +371,7 @@ def clone_model(model: Model) -> Model:
369371
370372
Recreates a PyMC model with clones of the original variables.
371373
Shared variables will point to the same container but be otherwise different objects.
372-
Constants are not cloned.
374+
Constants are not cloned and if guassian process variables are detected then a warning will be triggered.
373375
374376
375377
Examples
@@ -391,6 +393,11 @@ def clone_model(model: Model) -> Model:
391393
z = pm.Deterministic("z", clone_x + 1)
392394
393395
"""
396+
check_for_gp_vars = [
397+
k for x in ["_rotated_", "_hsgp_coeffs_"] for k in model.named_vars.keys() if x in k
398+
]
399+
if len(check_for_gp_vars) > 0:
400+
warnings.warn("Unable to clone Gaussian Process Variables", UserWarning)
394401
return model_from_fgraph(fgraph_from_model(model)[0], mutate_fgraph=True)
395402

396403

tests/model/test_core.py

Lines changed: 22 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1764,15 +1764,15 @@ def test_graphviz_call_function(self, var_names, filenames) -> None:
17641764
)
17651765

17661766

1767-
class TestModelCopy(unittest.TestCase):
1767+
class TestModelCopy:
17681768
@staticmethod
17691769
def simple_model() -> pm.Model:
17701770
with pm.Model() as simple_model:
17711771
error = pm.HalfNormal("error", 0.5)
17721772
alpha = pm.Normal("alpha", 0, 1)
17731773
pm.Normal("y", alpha, error)
17741774
return simple_model
1775-
1775+
17761776
@staticmethod
17771777
def gp_model() -> pm.Model:
17781778
with pm.Model() as gp_model:
@@ -1782,7 +1782,7 @@ def gp_model() -> pm.Model:
17821782
f = gp.prior("f", X=np.arange(10)[:, None])
17831783
pm.Normal("y", f * 2)
17841784
return gp_model
1785-
1785+
17861786
def test_copy_model(self) -> None:
17871787
simple_model = self.simple_model()
17881788
copy_simple_model = copy.copy(simple_model)
@@ -1797,17 +1797,27 @@ def test_copy_model(self) -> None:
17971797
with deepcopy_simple_model:
17981798
deepcopy_simple_model_prior_predictive = pm.sample_prior_predictive(random_seed=42)
17991799

1800-
simple_model_prior_predictive_mean = simple_model_prior_predictive['prior']['y'].mean(('chain', 'draw'))
1801-
copy_simple_model_prior_predictive_mean = copy_simple_model_prior_predictive['prior']['y'].mean(('chain', 'draw'))
1802-
deepcopy_simple_model_prior_predictive_mean = deepcopy_simple_model_prior_predictive['prior']['y'].mean(('chain', 'draw'))
1800+
simple_model_prior_predictive_mean = simple_model_prior_predictive["prior"]["y"].mean(
1801+
("chain", "draw")
1802+
)
1803+
copy_simple_model_prior_predictive_mean = copy_simple_model_prior_predictive["prior"][
1804+
"y"
1805+
].mean(("chain", "draw"))
1806+
deepcopy_simple_model_prior_predictive_mean = deepcopy_simple_model_prior_predictive[
1807+
"prior"
1808+
]["y"].mean(("chain", "draw"))
18031809

1804-
assert np.isclose(simple_model_prior_predictive_mean, copy_simple_model_prior_predictive_mean)
1805-
assert np.isclose(simple_model_prior_predictive_mean, deepcopy_simple_model_prior_predictive_mean)
1810+
assert np.isclose(
1811+
simple_model_prior_predictive_mean, copy_simple_model_prior_predictive_mean
1812+
)
1813+
assert np.isclose(
1814+
simple_model_prior_predictive_mean, deepcopy_simple_model_prior_predictive_mean
1815+
)
18061816

18071817
def test_guassian_process_copy_failure(self) -> None:
18081818
gaussian_process_model = self.gp_model()
1809-
with pytest.raises(Exception) as e:
1819+
with pytest.warns(UserWarning):
18101820
copy.copy(gaussian_process_model)
1811-
1812-
with pytest.raises(Exception) as e:
1813-
copy.deepcopy(gaussian_process_model)
1821+
1822+
with pytest.warns(UserWarning):
1823+
copy.deepcopy(gaussian_process_model)

0 commit comments

Comments
 (0)