Skip to content

Commit 7619aef

Browse files
authored
🔥 remove deprecated numpy function (#4244)
1 parent ad63b94 commit 7619aef

File tree

2 files changed

+3
-3
lines changed

2 files changed

+3
-3
lines changed

pymc3/distributions/multivariate.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -827,7 +827,7 @@ def random(self, point=None, size=None):
827827
"""
828828
nu, V = draw_values([self.nu, self.V], point=point, size=size)
829829
size = 1 if size is None else size
830-
return generate_samples(stats.wishart.rvs, np.asscalar(nu), V, broadcast_shape=(size,))
830+
return generate_samples(stats.wishart.rvs, nu.item(), V, broadcast_shape=(size,))
831831

832832
def logp(self, X):
833833
"""

pymc3/util.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818

1919
import xarray
2020
import arviz
21-
from numpy import asscalar, ndarray
21+
from numpy import ndarray
2222

2323
from theano.tensor import TensorVariable
2424

@@ -149,7 +149,7 @@ def get_repr_for_variable(variable, formatting="plain"):
149149
pass
150150
value = variable.eval()
151151
if not value.shape or value.shape == (1,):
152-
return asscalar(value)
152+
return value.item()
153153
return "array"
154154

155155
if formatting == "latex":

0 commit comments

Comments
 (0)