-
-
Notifications
You must be signed in to change notification settings - Fork 2.1k
Allow copy and deepcopy of PYMC models #7492
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 7 commits
b34742d
fe4e0c5
bcb4309
33c5766
88fde25
90419cb
07106ec
fb00f85
d057a9d
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,6 +11,7 @@ | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
import copy | ||
import pickle | ||
import threading | ||
import traceback | ||
|
@@ -1761,3 +1762,57 @@ def test_graphviz_call_function(self, var_names, filenames) -> None: | |
figsize=None, | ||
dpi=300, | ||
) | ||
|
||
|
||
class TestModelCopy: | ||
@pytest.mark.parametrize("copy_method", (copy.copy, copy.deepcopy)) | ||
def test_copy_model(self, copy_method) -> None: | ||
with pm.Model() as simple_model: | ||
error = pm.HalfNormal("error", 0.5) | ||
alpha = pm.Normal("alpha", 0, 1) | ||
pm.Normal("y", alpha, error) | ||
|
||
copy_simple_model = copy_method(simple_model) | ||
|
||
with simple_model: | ||
simple_model_prior_predictive = pm.sample_prior_predictive(samples=1, random_seed=42) | ||
|
||
with copy_simple_model: | ||
copy_simple_model_prior_predictive = pm.sample_prior_predictive( | ||
samples=1, random_seed=42 | ||
) | ||
|
||
simple_model_prior_predictive_val = simple_model_prior_predictive["prior"]["y"].values | ||
copy_simple_model_prior_predictive_val = copy_simple_model_prior_predictive["prior"][ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Just compare directly, no need to assign to separate variables, that are almost as verbose as the way they are accessed |
||
"y" | ||
].values | ||
|
||
assert simple_model_prior_predictive_val == copy_simple_model_prior_predictive_val | ||
|
||
with copy_simple_model: | ||
z = pm.Deterministic("z", copy_simple_model["alpha"] + 1) | ||
copy_simple_model_prior_predictive = pm.sample_prior_predictive( | ||
samples=1, random_seed=42 | ||
) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You can do this above, and call |
||
|
||
assert "z" in copy_simple_model.named_vars | ||
assert "z" not in simple_model.named_vars | ||
assert ( | ||
copy_simple_model_prior_predictive["prior"]["z"].values | ||
== 1 + copy_simple_model_prior_predictive["prior"]["alpha"].values | ||
) | ||
|
||
@pytest.mark.parametrize("copy_method", (copy.copy, copy.deepcopy)) | ||
def test_guassian_process_copy_failure(self, copy_method) -> None: | ||
with pm.Model() as gaussian_process_model: | ||
ell = pm.Gamma("ell", alpha=2, beta=1) | ||
cov = 2 * pm.gp.cov.ExpQuad(1, ell) | ||
gp = pm.gp.Latent(cov_func=cov) | ||
f = gp.prior("f", X=np.arange(10)[:, None]) | ||
pm.Normal("y", f * 2) | ||
|
||
with pytest.warns( | ||
UserWarning, | ||
match="Detected variables likely created by GP objects. Further use of these old GP objects should be avoided as it may reintroduce variables from the old model. See issue: https://github.com/pymc-devs/pymc/issues/6883", | ||
): | ||
copy_method(gaussian_process_model) |
Uh oh!
There was an error while loading. Please reload this page.