diff --git a/pymc_experimental/distributions/histogram_utils.py b/pymc_experimental/distributions/histogram_utils.py index 1829e9f95..608615ae2 100644 --- a/pymc_experimental/distributions/histogram_utils.py +++ b/pymc_experimental/distributions/histogram_utils.py @@ -17,7 +17,6 @@ import numpy as np import pymc as pm -import xhistogram.core from numpy.typing import ArrayLike try: @@ -26,6 +25,11 @@ except ImportError: dask = None +try: + import xhistogram.core +except ImportError: + xhistogram = None + __all__ = ["quantile_histogram", "discrete_histogram", "histogram_approximation"] @@ -33,6 +37,9 @@ def quantile_histogram( data: ArrayLike, n_quantiles=1000, zero_inflation=False ) -> Dict[str, ArrayLike]: + if xhistogram is None: + raise RuntimeError("quantile_histogram requires xhistogram package") + if dask and isinstance(data, (dask.dataframe.Series, dask.dataframe.DataFrame)): data = data.to_dask_array(lengths=True) if zero_inflation: @@ -67,6 +74,9 @@ def quantile_histogram( def discrete_histogram(data: ArrayLike, min_count=None) -> Dict[str, ArrayLike]: + if xhistogram is None: + raise RuntimeError("discrete_histogram requires xhistogram package") + if dask and isinstance(data, (dask.dataframe.Series, dask.dataframe.DataFrame)): data = data.to_dask_array(lengths=True) mid, count_uniq = np.unique(data, return_counts=True) diff --git a/requirements.txt b/requirements.txt index e820d8cbe..48d2355af 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1 @@ pymc>=4.3.0 -xhistogram diff --git a/setup.py b/setup.py index a89029664..c142cc867 100644 --- a/setup.py +++ b/setup.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +import itertools import re from codecs import open from os.path import dirname, join, realpath @@ -46,10 +47,14 @@ LONG_DESCRIPTION = buff.read() REQUIREMENTS_FILE = join(PROJECT_ROOT, "requirements.txt") +DEV_REQUIREMENTS_FILE = join(PROJECT_ROOT, "requirements-dev.txt") with open(REQUIREMENTS_FILE) as f: install_reqs = f.read().splitlines() +with open(DEV_REQUIREMENTS_FILE) as f: + dev_install_reqs = f.read().splitlines() + def get_version(): VERSIONFILE = join("pymc_experimental", "__init__.py") @@ -62,6 +67,14 @@ def get_version(): raise RuntimeError(f"Unable to find version in {VERSIONFILE}.") +extras_require = dict( + dask_histogram=["dask[complete]", "xhistogram"], + histogram=["xhistogram"], +) +extras_require["complete"] = sorted(set(itertools.chain.from_iterable(extras_require.values()))) +extras_require["dev"] = dev_install_reqs + + if __name__ == "__main__": setup( name=DISTNAME, @@ -81,5 +94,5 @@ def get_version(): classifiers=classifiers, python_requires=">=3.8", install_requires=install_reqs, - extras_requires=dict(dask=["dask[all]"]), + extras_require=extras_require, )