Skip to content

GMTDataArrayAccessor: Enable grid operations on the current xarray.DataArray object directly #3854

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 16 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 111 additions & 1 deletion pygmt/xarray/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,19 @@
import xarray as xr
from pygmt.enums import GridRegistration, GridType
from pygmt.exceptions import GMTInvalidInput
from pygmt.src.grdinfo import grdinfo
from pygmt.src import (
dimfilter,
grdclip,
grdcut,
grdfill,
grdfilter,
grdgradient,
grdhisteq,
grdinfo,
grdproject,
grdsample,
grdtrack,
)


@xr.register_dataarray_accessor("gmt")
Expand All @@ -23,6 +35,11 @@
- ``registration``: Grid registration type :class:`pygmt.enums.GridRegistration`.
- ``gtype``: Grid coordinate system type :class:`pygmt.enums.GridType`.

The *gmt* accessor also provides a set of grid-operation methods that enables
applying GMT's grid processing functionalities directly to the current
:class:`xarray.DataArray` object. See the summary table below for the list of
available methods.

Notes
-----
When accessed the first time, the *gmt* accessor will first be initialized to the
Expand Down Expand Up @@ -150,6 +167,19 @@
>>> zval.gmt.gtype = GridType.GEOGRAPHIC
>>> zval.gmt.registration, zval.gmt.gtype
(<GridRegistration.GRIDLINE: 0>, <GridType.GEOGRAPHIC: 1>)

Instead of calling a grid-processing function and passing the
:class:`xarray.DataArray` object as an input, you can call the corresponding method
directly on the object. For example, the following two are equivalent:

>>> from pygmt.datasets import load_earth_relief
>>> grid = load_earth_relief(resolution="30m", region=[10, 30, 15, 25])
>>> # Create a new grid from an input grid. Set all values below 1,000 to
>>> # 0 and all values above 1,500 to 10,000.
>>> # Option 1:
>>> new_grid = pygmt.grdclip(grid=grid, below=[1000, 0], above=[1500, 10000])
>>> # Option 2:
>>> new_grid = grid.gmt.clip(below=[1000, 0], above=[1500, 10000])
"""

def __init__(self, xarray_obj: xr.DataArray):
Expand Down Expand Up @@ -204,3 +234,83 @@
)
raise GMTInvalidInput(msg)
self._gtype = GridType(value)

def dimfilter(self, **kwargs) -> xr.DataArray:
"""
Directional filtering of a grid in the space domain.

See the :func:`pygmt.dimfilter` function for available parameters.
"""
return dimfilter(grid=self._obj, **kwargs)

Check warning on line 244 in pygmt/xarray/accessor.py

View check run for this annotation

Codecov / codecov/patch

pygmt/xarray/accessor.py#L244

Added line #L244 was not covered by tests
Comment on lines +238 to +244
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Shorter way to declare new methods is to put this in the __init__ method:

import functools

def __init__(self, xarray_obj: xr.DataArray):
    ...
    self.dimfilter = functools.partial(dimfilter, grid=self._obj)
    self.dimfilter.__doc__ = dimfilter.__doc__

This would preserve the full docs too, e.g. output from help(grid.gmt.dimfilter)

Help on partial in module functools:

functools.partial(<function dimfilter at 0x7fd07...e:  [190. 981.]
    long_name:     elevation (m))
    Directional filtering of grids in the space domain.
    
    Filter a grid in the space (or time) domain by
    dividing the given filter circle into the given number of sectors,
    applying one of the selected primary convolution or non-convolution
    filters to each sector, and choosing the final outcome according to the
    selected secondary filter. It computes distances using Cartesian or
    Spherical geometries. The output grid can optionally be generated as a
    subregion of the input and/or with a new increment using ``spacing``,
    which may add an "extra space" in the input data to prevent edge
    effects for the output grid. If the filter is low-pass, then the output
    may be less frequently sampled than the input. :func:`pygmt.dimfilter`
    will not produce a smooth output as other spatial filters
    do because it returns a minimum median out of *N* medians of *N*
    sectors. The output can be rough unless the input data are noise-free.
    Thus, an additional filtering (e.g., Gaussian via :func:`pygmt.grdfilter`)
    of the DiM-filtered data is generally recommended.
    
    Full option list at :gmt-docs:`dimfilter.html`
    
    **Aliases:**
    
    .. hlist::
       :columns: 3
    
       - D = distance
       - F = filter
       - I = spacing
       - N = sectors
       - R = region
       - V = verbose

There might be a nicer way to wrap things (maybe using https://docs.python.org/3/library/functools.html#functools.update_wrapper?), but haven't played around with it too much.

Copy link
Member Author

@seisman seisman May 10, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tried your solution above and found two issues:

  1. self.dimfilter is an attribute of the accessor, so it's not shown on the documentation (https://pygmt-dev--3854.org.readthedocs.build/en/3854/api/generated/pygmt.GMTDataArrayAccessor.html)
  2. The help docs still show grid as its first parameter, which may be more confusing?


def clip(self, **kwargs) -> xr.DataArray:
"""
Clip the range of grid values.

See the :func:`pygmt.grdclip` function for available parameters.
"""
return grdclip(grid=self._obj, **kwargs)

def cut(self, **kwargs) -> xr.DataArray:
"""
Extract subregion from a grid or image or a slice from a cube.

See the :func:`pygmt.grdcut` function for available parameters.
"""
return grdcut(grid=self._obj, **kwargs)

Check warning on line 260 in pygmt/xarray/accessor.py

View check run for this annotation

Codecov / codecov/patch

pygmt/xarray/accessor.py#L260

Added line #L260 was not covered by tests

def equalize_hist(self, **kwargs) -> xr.DataArray:
"""
Perform histogram equalization for a grid.

See the :meth:`pygmt.grdhisteq.equalize_grid` method for available parameters.
"""
return grdhisteq.equalize_grid(grid=self._obj, **kwargs)

Check warning on line 268 in pygmt/xarray/accessor.py

View check run for this annotation

Codecov / codecov/patch

pygmt/xarray/accessor.py#L268

Added line #L268 was not covered by tests

def fill(self, **kwargs) -> xr.DataArray:
"""
Interpolate across holes in the grid.

See the :func:`pygmt.grdfill` function for available parameters.
"""
return grdfill(grid=self._obj, **kwargs)

Check warning on line 276 in pygmt/xarray/accessor.py

View check run for this annotation

Codecov / codecov/patch

pygmt/xarray/accessor.py#L276

Added line #L276 was not covered by tests

def filter(self, **kwargs) -> xr.DataArray:
"""
Filter a grid in the space (or time) domain.

See the :func:`pygmt.grdfilter` function for available parameters.
"""
return grdfilter(grid=self._obj, **kwargs)

Check warning on line 284 in pygmt/xarray/accessor.py

View check run for this annotation

Codecov / codecov/patch

pygmt/xarray/accessor.py#L284

Added line #L284 was not covered by tests

def gradient(self, **kwargs) -> xr.DataArray:
"""
Compute directional gradients from a grid.

See the :func:`pygmt.grdgradient` function for available parameters.
"""
return grdgradient(grid=self._obj, **kwargs)

Check warning on line 292 in pygmt/xarray/accessor.py

View check run for this annotation

Codecov / codecov/patch

pygmt/xarray/accessor.py#L292

Added line #L292 was not covered by tests

def project(self, **kwargs) -> xr.DataArray:
"""
Forward and inverse map transformation of grids.

See the :func:`pygmt.grdproject` function for available parameters.
"""
return grdproject(grid=self._obj, **kwargs)

Check warning on line 300 in pygmt/xarray/accessor.py

View check run for this annotation

Codecov / codecov/patch

pygmt/xarray/accessor.py#L300

Added line #L300 was not covered by tests

def sample(self, **kwargs) -> xr.DataArray:
"""
Resample a grid onto a new lattice.

See the :func:`pygmt.grdsample` function for available parameters.
"""
return grdsample(grid=self._obj, **kwargs)

Check warning on line 308 in pygmt/xarray/accessor.py

View check run for this annotation

Codecov / codecov/patch

pygmt/xarray/accessor.py#L308

Added line #L308 was not covered by tests

def track(self, **kwargs) -> xr.DataArray:
"""
Sample a grid at specified locations.

See the :func:`pygmt.grdtrack` function for available parameters.
"""
return grdtrack(grid=self._obj, **kwargs)

Check warning on line 316 in pygmt/xarray/accessor.py

View check run for this annotation

Codecov / codecov/patch

pygmt/xarray/accessor.py#L316

Added line #L316 was not covered by tests
Loading