11
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
# See the License for the specific language governing permissions and
13
13
# limitations under the License.
14
+ import copy
14
15
import pickle
15
16
import threading
16
17
import traceback
@@ -1761,3 +1762,52 @@ def test_graphviz_call_function(self, var_names, filenames) -> None:
1761
1762
figsize = None ,
1762
1763
dpi = 300 ,
1763
1764
)
1765
+
1766
+
1767
+ class TestModelCopy (unittest .TestCase ):
1768
+ @staticmethod
1769
+ def simple_model () -> pm .Model :
1770
+ with pm .Model () as simple_model :
1771
+ error = pm .HalfNormal ("error" , 0.5 )
1772
+ alpha = pm .Normal ("alpha" , 0 , 1 )
1773
+ pm .Normal ("y" , alpha , error )
1774
+ return simple_model
1775
+
1776
+ @staticmethod
1777
+ def gp_model () -> pm .Model :
1778
+ with pm .Model () as gp_model :
1779
+ ell = pm .Gamma ("ell" , alpha = 2 , beta = 1 )
1780
+ cov = 2 * pm .gp .cov .ExpQuad (1 , ell )
1781
+ gp = pm .gp .Latent (cov_func = cov )
1782
+ f = gp .prior ("f" , X = np .arange (10 )[:, None ])
1783
+ pm .Normal ("y" , f * 2 )
1784
+ return gp_model
1785
+
1786
+ def test_copy_model (self ) -> None :
1787
+ simple_model = self .simple_model ()
1788
+ copy_simple_model = copy .copy (simple_model )
1789
+ deepcopy_simple_model = copy .deepcopy (simple_model )
1790
+
1791
+ with simple_model :
1792
+ simple_model_prior_predictive = pm .sample_prior_predictive (random_seed = 42 )
1793
+
1794
+ with copy_simple_model :
1795
+ copy_simple_model_prior_predictive = pm .sample_prior_predictive (random_seed = 42 )
1796
+
1797
+ with deepcopy_simple_model :
1798
+ deepcopy_simple_model_prior_predictive = pm .sample_prior_predictive (random_seed = 42 )
1799
+
1800
+ simple_model_prior_predictive_mean = simple_model_prior_predictive ['prior' ]['y' ].mean (('chain' , 'draw' ))
1801
+ copy_simple_model_prior_predictive_mean = copy_simple_model_prior_predictive ['prior' ]['y' ].mean (('chain' , 'draw' ))
1802
+ deepcopy_simple_model_prior_predictive_mean = deepcopy_simple_model_prior_predictive ['prior' ]['y' ].mean (('chain' , 'draw' ))
1803
+
1804
+ assert np .isclose (simple_model_prior_predictive_mean , copy_simple_model_prior_predictive_mean )
1805
+ assert np .isclose (simple_model_prior_predictive_mean , deepcopy_simple_model_prior_predictive_mean )
1806
+
1807
+ def test_guassian_process_copy_failure (self ) -> None :
1808
+ gaussian_process_model = self .gp_model ()
1809
+ with pytest .raises (Exception ) as e :
1810
+ copy .copy (gaussian_process_model )
1811
+
1812
+ with pytest .raises (Exception ) as e :
1813
+ copy .deepcopy (gaussian_process_model )
0 commit comments