Skip to content

Commit 915e6a2

Browse files
authored
Better float32 sampling support for TruncatedNormal (#7026)
* manually inv transform; force rng same type * always upcast f64 and downcast to dtype of param * add comment * use class attr dtype * need else stmt for dtype * actually no need to downcast in this method * rm unused import
1 parent f8b142a commit 915e6a2

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

pymc/distributions/continuous.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -568,11 +568,13 @@ def rng_fn(
568568
upper: Union[np.ndarray, float],
569569
size: Optional[Union[List[int], int]],
570570
) -> np.ndarray:
571+
# Upcast to float64. (Caller will downcast to desired dtype if needed)
572+
# (Work-around for https://github.com/scipy/scipy/issues/15928)
571573
return stats.truncnorm.rvs(
572-
a=(lower - mu) / sigma,
573-
b=(upper - mu) / sigma,
574-
loc=mu,
575-
scale=sigma,
574+
a=((lower - mu) / sigma).astype("float64"),
575+
b=((upper - mu) / sigma).astype("float64"),
576+
loc=(mu).astype("float64"),
577+
scale=(sigma).astype("float64"),
576578
size=size,
577579
random_state=rng,
578580
)

0 commit comments

Comments
 (0)