Skip to content

Commit 9cb34d1

Browse files
ricardoV94ferrine
andauthored
Make xhistogram dependency optional (#92)
* Make xhistogram dependency optional * Update setup.py * Update setup.py Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com> * Update setup.py * Update setup.py * pretty extras require * Update setup.py Co-authored-by: Ricardo Vieira <28983449+ricardoV94@users.noreply.github.com> * fix black Co-authored-by: Maxim Kochurov <maxim.v.kochurov@gmail.com>
1 parent c3634e1 commit 9cb34d1

File tree

3 files changed

+25
-3
lines changed

3 files changed

+25
-3
lines changed

pymc_experimental/distributions/histogram_utils.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717

1818
import numpy as np
1919
import pymc as pm
20-
import xhistogram.core
2120
from numpy.typing import ArrayLike
2221

2322
try:
@@ -26,13 +25,21 @@
2625
except ImportError:
2726
dask = None
2827

28+
try:
29+
import xhistogram.core
30+
except ImportError:
31+
xhistogram = None
32+
2933

3034
__all__ = ["quantile_histogram", "discrete_histogram", "histogram_approximation"]
3135

3236

3337
def quantile_histogram(
3438
data: ArrayLike, n_quantiles=1000, zero_inflation=False
3539
) -> Dict[str, ArrayLike]:
40+
if xhistogram is None:
41+
raise RuntimeError("quantile_histogram requires xhistogram package")
42+
3643
if dask and isinstance(data, (dask.dataframe.Series, dask.dataframe.DataFrame)):
3744
data = data.to_dask_array(lengths=True)
3845
if zero_inflation:
@@ -67,6 +74,9 @@ def quantile_histogram(
6774

6875

6976
def discrete_histogram(data: ArrayLike, min_count=None) -> Dict[str, ArrayLike]:
77+
if xhistogram is None:
78+
raise RuntimeError("discrete_histogram requires xhistogram package")
79+
7080
if dask and isinstance(data, (dask.dataframe.Series, dask.dataframe.DataFrame)):
7181
data = data.to_dask_array(lengths=True)
7282
mid, count_uniq = np.unique(data, return_counts=True)

requirements.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1 @@
11
pymc>=4.3.0
2-
xhistogram

setup.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import itertools
1516
import re
1617
from codecs import open
1718
from os.path import dirname, join, realpath
@@ -46,10 +47,14 @@
4647
LONG_DESCRIPTION = buff.read()
4748

4849
REQUIREMENTS_FILE = join(PROJECT_ROOT, "requirements.txt")
50+
DEV_REQUIREMENTS_FILE = join(PROJECT_ROOT, "requirements-dev.txt")
4951

5052
with open(REQUIREMENTS_FILE) as f:
5153
install_reqs = f.read().splitlines()
5254

55+
with open(DEV_REQUIREMENTS_FILE) as f:
56+
dev_install_reqs = f.read().splitlines()
57+
5358

5459
def get_version():
5560
VERSIONFILE = join("pymc_experimental", "__init__.py")
@@ -62,6 +67,14 @@ def get_version():
6267
raise RuntimeError(f"Unable to find version in {VERSIONFILE}.")
6368

6469

70+
extras_require = dict(
71+
dask_histogram=["dask[complete]", "xhistogram"],
72+
histogram=["xhistogram"],
73+
)
74+
extras_require["complete"] = sorted(set(itertools.chain.from_iterable(extras_require.values())))
75+
extras_require["dev"] = dev_install_reqs
76+
77+
6578
if __name__ == "__main__":
6679
setup(
6780
name=DISTNAME,
@@ -81,5 +94,5 @@ def get_version():
8194
classifiers=classifiers,
8295
python_requires=">=3.8",
8396
install_requires=install_reqs,
84-
extras_requires=dict(dask=["dask[all]"]),
97+
extras_require=extras_require,
8598
)

0 commit comments

Comments
 (0)