From 4543d083de10b3d64e80f3083f46b0f36192af7d Mon Sep 17 00:00:00 2001 From: Michal Raczycki Date: Fri, 10 Feb 2023 15:06:44 +0100 Subject: [PATCH 01/11] added dim inference from xarray, deprecation warning and unittest for the new feature --- pymc/data.py | 45 +++++++++++++++++++++++++++++++++++------ pymc/tests/test_data.py | 12 +++++++++++ 2 files changed, 51 insertions(+), 6 deletions(-) diff --git a/pymc/data.py b/pymc/data.py index b7d0dcac11..4cf7e07757 100644 --- a/pymc/data.py +++ b/pymc/data.py @@ -22,8 +22,10 @@ from typing import Dict, Optional, Sequence, Tuple, Union, cast import numpy as np +import pandas as pd import pytensor import pytensor.tensor as at +import xarray as xr from pytensor.compile.sharedvalue import SharedVariable from pytensor.raise_op import Assert @@ -205,7 +207,7 @@ def Minibatch(variable: TensorVariable, *variables: TensorVariable, batch_size: def determine_coords( model, - value, + value: Union[pd.DataFrame, pd.Series, xr.DataArray], dims: Optional[Sequence[Optional[str]]] = None, coords: Optional[Dict[str, Sequence]] = None, ) -> Tuple[Dict[str, Sequence], Sequence[Optional[str]]]: @@ -213,9 +215,9 @@ def determine_coords( if coords is None: coords = {} + dim_name = None # If value is a df or a series, we interpret the index as coords: if hasattr(value, "index"): - dim_name = None if dims is not None: dim_name = dims[0] if dim_name is None and value.index.name is not None: @@ -225,7 +227,6 @@ def determine_coords( # If value is a df, we also interpret the columns as coords: if hasattr(value, "columns"): - dim_name = None if dims is not None: dim_name = dims[1] if dim_name is None and value.columns.name is not None: @@ -233,6 +234,12 @@ def determine_coords( if dim_name is not None: coords[dim_name] = value.columns + if isinstance(value, xr.DataArray): + if dims is not None: + for dim in dims: + dim_name = dim + coords[dim_name] = value["dim"] + if isinstance(value, np.ndarray) and dims is not None: if len(dims) != value.ndim: raise pm.exceptions.ShapeError( @@ -259,6 +266,7 @@ def ConstantData( dims: Optional[Sequence[str]] = None, coords: Optional[Dict[str, Sequence]] = None, export_index_as_coords=False, + infer_dims_and_coords=False, **kwargs, ) -> TensorConstant: """Alias for ``pm.Data(..., mutable=False)``. @@ -266,12 +274,19 @@ def ConstantData( Registers the ``value`` as a :class:`~pytensor.tensor.TensorConstant` with the model. For more information, please reference :class:`pymc.Data`. """ + if export_index_as_coords: + infer_dims_and_coords = export_index_as_coords + warnings.warn( + "Deprecation warning: 'export_index_as_coords; is deprecated adn will be removed in future versions. Please use 'infer_dims_and_coords' instead.", + DeprecationWarning, + ) + var = Data( name, value, dims=dims, coords=coords, - export_index_as_coords=export_index_as_coords, + infer_dims_and_coords=infer_dims_and_coords, mutable=False, **kwargs, ) @@ -285,6 +300,7 @@ def MutableData( dims: Optional[Sequence[str]] = None, coords: Optional[Dict[str, Sequence]] = None, export_index_as_coords=False, + infer_dims_and_coords=False, **kwargs, ) -> SharedVariable: """Alias for ``pm.Data(..., mutable=True)``. @@ -292,12 +308,19 @@ def MutableData( Registers the ``value`` as a :class:`~pytensor.compile.sharedvalue.SharedVariable` with the model. For more information, please reference :class:`pymc.Data`. """ + if export_index_as_coords: + infer_dims_and_coords = export_index_as_coords + warnings.warn( + "Deprecation warning: 'export_index_as_coords; is deprecated adn will be removed in future versions. Please use 'infer_dims_and_coords' instead.", + DeprecationWarning, + ) + var = Data( name, value, dims=dims, coords=coords, - export_index_as_coords=export_index_as_coords, + infer_dims_and_coords=infer_dims_and_coords, mutable=True, **kwargs, ) @@ -311,6 +334,7 @@ def Data( dims: Optional[Sequence[str]] = None, coords: Optional[Dict[str, Sequence]] = None, export_index_as_coords=False, + infer_dims_and_coords=False, mutable: Optional[bool] = None, **kwargs, ) -> Union[SharedVariable, TensorConstant]: @@ -347,7 +371,8 @@ def Data( names. coords : dict, optional Coordinate values to set for new dimensions introduced by this ``Data`` variable. - export_index_as_coords : bool, default=False + export_index_as_coords : deprecated, previous version of "infer_dims_and_coords" + infer_dims_and_coords : bool, default=False If True, the ``Data`` container will try to infer what the coordinates and dimension names should be if there is an index in ``value``. mutable : bool, optional @@ -426,7 +451,15 @@ def Data( ) # Optionally infer coords and dims from the input value. + if export_index_as_coords: + infer_dims_and_coords = export_index_as_coords + warnings.warn( + "Deprecation warning: 'export_index_as_coords; is deprecated adn will be removed in future versions. Please use 'infer_dims_and_coords' instead.", + DeprecationWarning, + ) + + if infer_dims_and_coords: coords, dims = determine_coords(model, value, dims) if dims: diff --git a/pymc/tests/test_data.py b/pymc/tests/test_data.py index 1a13b2176a..ba062c734c 100644 --- a/pymc/tests/test_data.py +++ b/pymc/tests/test_data.py @@ -405,6 +405,18 @@ def test_implicit_coords_dataframe(self): assert "columns" in pmodel.coords assert pmodel.named_vars_to_dims == {"observations": ("rows", "columns")} + def test_implicit_coords_xarray(self): + xr = pytest.importorskip("xarray") + data = xr.DataArray([[1, 2, 3], [4, 5, 6]], dims=("y", "x")) + with pm.Model() as pmodel: + with pytest.warns(DeprecationWarning): + pm.ConstantData("observations", data, dims=("x", "y"), export_index_as_coords=True) + assert "x" in pmodel.coords + assert "y" in pmodel.coords + assert pmodel.named_vars_to_dims == {"observations": ("x", "y")} + assert pmodel.coords["x"] == [0, 1, 2] + assert pmodel.coords["y"] == [0, 1] + def test_data_kwargs(self): strict_value = True allow_downcast_value = False From 0c11abceeb3ca2445337b759f2262f5f0bc1c5f3 Mon Sep 17 00:00:00 2001 From: Michal Raczycki Date: Sun, 12 Feb 2023 11:01:22 +0100 Subject: [PATCH 02/11] fixed typo in warning --- pymc/data.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pymc/data.py b/pymc/data.py index 4cf7e07757..e67bc8c8db 100644 --- a/pymc/data.py +++ b/pymc/data.py @@ -277,7 +277,7 @@ def ConstantData( if export_index_as_coords: infer_dims_and_coords = export_index_as_coords warnings.warn( - "Deprecation warning: 'export_index_as_coords; is deprecated adn will be removed in future versions. Please use 'infer_dims_and_coords' instead.", + "Deprecation warning: 'export_index_as_coords; is deprecated and will be removed in future versions. Please use 'infer_dims_and_coords' instead.", DeprecationWarning, ) @@ -311,7 +311,7 @@ def MutableData( if export_index_as_coords: infer_dims_and_coords = export_index_as_coords warnings.warn( - "Deprecation warning: 'export_index_as_coords; is deprecated adn will be removed in future versions. Please use 'infer_dims_and_coords' instead.", + "Deprecation warning: 'export_index_as_coords; is deprecated and will be removed in future versions. Please use 'infer_dims_and_coords' instead.", DeprecationWarning, ) @@ -455,7 +455,7 @@ def Data( if export_index_as_coords: infer_dims_and_coords = export_index_as_coords warnings.warn( - "Deprecation warning: 'export_index_as_coords; is deprecated adn will be removed in future versions. Please use 'infer_dims_and_coords' instead.", + "Deprecation warning: 'export_index_as_coords; is deprecated and will be removed in future versions. Please use 'infer_dims_and_coords' instead.", DeprecationWarning, ) From 94b5b28ffa4f4fc7b782b0139d38fedc52316f07 Mon Sep 17 00:00:00 2001 From: Michal Raczycki Date: Wed, 15 Feb 2023 10:54:43 +0100 Subject: [PATCH 03/11] fixed accidental quotation around dim --- pymc/data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/data.py b/pymc/data.py index e67bc8c8db..73f17a010c 100644 --- a/pymc/data.py +++ b/pymc/data.py @@ -238,7 +238,7 @@ def determine_coords( if dims is not None: for dim in dims: dim_name = dim - coords[dim_name] = value["dim"] + coords[dim_name] = value[dim] if isinstance(value, np.ndarray) and dims is not None: if len(dims) != value.ndim: From 8396a05d3e35dfc997516c3cb9aa57e398afa5f5 Mon Sep 17 00:00:00 2001 From: Michal Raczycki Date: Wed, 15 Feb 2023 11:31:01 +0100 Subject: [PATCH 04/11] fixed failing assertions --- pymc/tests/test_data.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pymc/tests/test_data.py b/pymc/tests/test_data.py index ba062c734c..76e7767609 100644 --- a/pymc/tests/test_data.py +++ b/pymc/tests/test_data.py @@ -414,8 +414,8 @@ def test_implicit_coords_xarray(self): assert "x" in pmodel.coords assert "y" in pmodel.coords assert pmodel.named_vars_to_dims == {"observations": ("x", "y")} - assert pmodel.coords["x"] == [0, 1, 2] - assert pmodel.coords["y"] == [0, 1] + assert tuple(pmodel.coords["x"]) == (data.coords["x"],) + assert tuple(pmodel.coords["y"]) == (data.coords["y"],) def test_data_kwargs(self): strict_value = True From 451c3befc721c6081f3929a66b01a49271a699d7 Mon Sep 17 00:00:00 2001 From: Michal Raczycki Date: Fri, 17 Feb 2023 12:35:48 +0100 Subject: [PATCH 05/11] found and fixed cause of the failing test --- pymc/tests/test_data.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pymc/tests/test_data.py b/pymc/tests/test_data.py index 76e7767609..e330b38b57 100644 --- a/pymc/tests/test_data.py +++ b/pymc/tests/test_data.py @@ -414,8 +414,8 @@ def test_implicit_coords_xarray(self): assert "x" in pmodel.coords assert "y" in pmodel.coords assert pmodel.named_vars_to_dims == {"observations": ("x", "y")} - assert tuple(pmodel.coords["x"]) == (data.coords["x"],) - assert tuple(pmodel.coords["y"]) == (data.coords["y"],) + assert tuple(pmodel.coords["x"]) == (data.coords["x"]) + assert tuple(pmodel.coords["y"]) == (data.coords["y"]) def test_data_kwargs(self): strict_value = True From 675acbb07356d8adb4b9e775f76241394de15347 Mon Sep 17 00:00:00 2001 From: Michal Raczycki Date: Fri, 17 Feb 2023 16:31:22 +0100 Subject: [PATCH 06/11] changed the coords assertion according to suggested form --- pymc/tests/test_data.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pymc/tests/test_data.py b/pymc/tests/test_data.py index e330b38b57..88db4bb2b9 100644 --- a/pymc/tests/test_data.py +++ b/pymc/tests/test_data.py @@ -414,8 +414,8 @@ def test_implicit_coords_xarray(self): assert "x" in pmodel.coords assert "y" in pmodel.coords assert pmodel.named_vars_to_dims == {"observations": ("x", "y")} - assert tuple(pmodel.coords["x"]) == (data.coords["x"]) - assert tuple(pmodel.coords["y"]) == (data.coords["y"]) + assert tuple(pmodel.coords["x"]) == tuple(data.coords["x"].to_numpy()) + assert tuple(pmodel.coords["y"]) == tuple(data.coords["y"].to_numpy()) def test_data_kwargs(self): strict_value = True From 747685c139f63bfa09042f141cb0312c8e67315d Mon Sep 17 00:00:00 2001 From: Michal Raczycki Date: Sat, 18 Feb 2023 12:32:09 +0100 Subject: [PATCH 07/11] fixing mypy type missmatch --- pymc/data.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/pymc/data.py b/pymc/data.py index 73f17a010c..c3fd62f27b 100644 --- a/pymc/data.py +++ b/pymc/data.py @@ -238,7 +238,9 @@ def determine_coords( if dims is not None: for dim in dims: dim_name = dim - coords[dim_name] = value[dim] + # because coord is expected to be a sequence, we need to convert xarray + # using 'tolist()' function + coords[str(dim_name)] = value[dim].tolist() if isinstance(value, np.ndarray) and dims is not None: if len(dims) != value.ndim: From 45ff28b9c79f5c34a99dc38d00a394c481f2d1ad Mon Sep 17 00:00:00 2001 From: Michal Raczycki Date: Sun, 19 Feb 2023 18:15:02 +0100 Subject: [PATCH 08/11] working on getting the test to work --- pymc/data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/data.py b/pymc/data.py index c3fd62f27b..79c19ca3cf 100644 --- a/pymc/data.py +++ b/pymc/data.py @@ -240,7 +240,7 @@ def determine_coords( dim_name = dim # because coord is expected to be a sequence, we need to convert xarray # using 'tolist()' function - coords[str(dim_name)] = value[dim].tolist() + coords[str(dim_name)] = value[dim].to_numpy() if isinstance(value, np.ndarray) and dims is not None: if len(dims) != value.ndim: From 6741d4bd4f0efb59c118678b1d0c6134de8682bb Mon Sep 17 00:00:00 2001 From: Michal Raczycki Date: Mon, 20 Feb 2023 11:48:34 +0100 Subject: [PATCH 09/11] removed typecasting to string on dim_name, was causing the mypy to fail --- pymc/data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc/data.py b/pymc/data.py index 79c19ca3cf..ff569a4b47 100644 --- a/pymc/data.py +++ b/pymc/data.py @@ -240,7 +240,7 @@ def determine_coords( dim_name = dim # because coord is expected to be a sequence, we need to convert xarray # using 'tolist()' function - coords[str(dim_name)] = value[dim].to_numpy() + coords[dim_name] = value[dim].to_numpy() if isinstance(value, np.ndarray) and dims is not None: if len(dims) != value.ndim: From a56fedd7f7605bb98e86a0f59521696532c0824a Mon Sep 17 00:00:00 2001 From: Michal Raczycki Date: Mon, 20 Feb 2023 14:54:00 +0100 Subject: [PATCH 10/11] took care locally of mypy errors --- pymc/data.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/pymc/data.py b/pymc/data.py index ff569a4b47..6694152cb4 100644 --- a/pymc/data.py +++ b/pymc/data.py @@ -209,8 +209,8 @@ def determine_coords( model, value: Union[pd.DataFrame, pd.Series, xr.DataArray], dims: Optional[Sequence[Optional[str]]] = None, - coords: Optional[Dict[str, Sequence]] = None, -) -> Tuple[Dict[str, Sequence], Sequence[Optional[str]]]: + coords: Optional[Dict[str, Union[Sequence, np.ndarray]]] = None, +) -> Tuple[Dict[str, Union[Sequence, np.ndarray]], Sequence[Optional[str]]]: """Determines coordinate values from data or the model (via ``dims``).""" if coords is None: coords = {} @@ -240,7 +240,7 @@ def determine_coords( dim_name = dim # because coord is expected to be a sequence, we need to convert xarray # using 'tolist()' function - coords[dim_name] = value[dim].to_numpy() + coords[str(dim_name)] = value[dim].to_numpy() if isinstance(value, np.ndarray) and dims is not None: if len(dims) != value.ndim: @@ -266,7 +266,7 @@ def ConstantData( value, *, dims: Optional[Sequence[str]] = None, - coords: Optional[Dict[str, Sequence]] = None, + coords: Optional[Dict[str, Union[Sequence, np.ndarray]]] = None, export_index_as_coords=False, infer_dims_and_coords=False, **kwargs, @@ -300,7 +300,7 @@ def MutableData( value, *, dims: Optional[Sequence[str]] = None, - coords: Optional[Dict[str, Sequence]] = None, + coords: Optional[Dict[str, Union[Sequence, np.ndarray]]] = None, export_index_as_coords=False, infer_dims_and_coords=False, **kwargs, @@ -334,7 +334,7 @@ def Data( value, *, dims: Optional[Sequence[str]] = None, - coords: Optional[Dict[str, Sequence]] = None, + coords: Optional[Dict[str, Union[Sequence, np.ndarray]]] = None, export_index_as_coords=False, infer_dims_and_coords=False, mutable: Optional[bool] = None, From 4038a7dce7e930905cd677408eee9eb354dbcb9c Mon Sep 17 00:00:00 2001 From: Michael Osthege Date: Wed, 22 Feb 2023 11:34:29 +0100 Subject: [PATCH 11/11] Typo/formatting fixes --- pymc/data.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/pymc/data.py b/pymc/data.py index 6694152cb4..19216f4550 100644 --- a/pymc/data.py +++ b/pymc/data.py @@ -238,8 +238,7 @@ def determine_coords( if dims is not None: for dim in dims: dim_name = dim - # because coord is expected to be a sequence, we need to convert xarray - # using 'tolist()' function + # str is applied because dim entries may be None coords[str(dim_name)] = value[dim].to_numpy() if isinstance(value, np.ndarray) and dims is not None: @@ -373,7 +372,8 @@ def Data( names. coords : dict, optional Coordinate values to set for new dimensions introduced by this ``Data`` variable. - export_index_as_coords : deprecated, previous version of "infer_dims_and_coords" + export_index_as_coords : bool + Deprecated, previous version of "infer_dims_and_coords" infer_dims_and_coords : bool, default=False If True, the ``Data`` container will try to infer what the coordinates and dimension names should be if there is an index in ``value``. @@ -453,7 +453,6 @@ def Data( ) # Optionally infer coords and dims from the input value. - if export_index_as_coords: infer_dims_and_coords = export_index_as_coords warnings.warn(