Skip to content

Commit 845fb6d

Browse files
authored
Merge pull request #332 from guziy/shiftdata_fix
Shiftdata fix
2 parents e981c14 + 5afc170 commit 845fb6d

File tree

2 files changed

+61
-21
lines changed

2 files changed

+61
-21
lines changed

lib/mpl_toolkits/basemap/__init__.py

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -526,7 +526,8 @@ def with_transform(self,x,y,data,*args,**kwargs):
526526
# shift data to map projection region for
527527
# cylindrical and pseudo-cylindrical projections.
528528
if self.projection in _cylproj or self.projection in _pseudocyl:
529-
x, data = self.shiftdata(x, data)
529+
x, data = self.shiftdata(x, data,
530+
fix_wrap_around=plotfunc.__name__ not in ["scatter"])
530531
# convert lat/lon coords to map projection coords.
531532
x, y = self(x,y)
532533
return plotfunc(self,x,y,data,*args,**kwargs)
@@ -544,7 +545,7 @@ def with_transform(self,x,y,*args,**kwargs):
544545
# cylindrical and pseudo-cylindrical projections.
545546
if self.projection in _cylproj or self.projection in _pseudocyl:
546547
if x.ndim == 1:
547-
x = self.shiftdata(x)
548+
x = self.shiftdata(x, fix_wrap_around=plotfunc.__name__ not in ["scatter"])
548549
elif x.ndim == 0:
549550
if x > 180:
550551
x = x - 360.
@@ -4723,7 +4724,7 @@ def _ax_plt_from_kw(self, kw):
47234724
_ax = plt.gca()
47244725
return _ax, plt
47254726

4726-
def shiftdata(self,lonsin,datain=None,lon_0=None):
4727+
def shiftdata(self,lonsin,datain=None,lon_0=None,fix_wrap_around=True):
47274728
"""
47284729
Shift longitudes (and optionally data) so that they match map projection region.
47294730
Only valid for cylindrical/pseudo-cylindrical global projections and data
@@ -4746,6 +4747,13 @@ def shiftdata(self,lonsin,datain=None,lon_0=None):
47464747
datain original 1-d or 2-d data. Default None.
47474748
lon_0 center of map projection region. Defaut None,
47484749
given by current map projection.
4750+
fix_wrap_around if True reindex (if required) longitudes (and data) to
4751+
avoid jumps caused by remapping of longitudes of
4752+
points from outside of the [lon_0-180, lon_0+180]
4753+
interval back into the interval.
4754+
If False do not reindex longitudes and data, but do
4755+
make sure that longitudes are in the
4756+
[lon_0-180, lon_0+180] range.
47494757
============== ====================================================
47504758
47514759
if datain given, returns ``dataout,lonsout`` (data and longitudes shifted to fit in interval
@@ -4784,7 +4792,7 @@ def shiftdata(self,lonsin,datain=None,lon_0=None):
47844792

47854793
# if no shift necessary, itemindex will be
47864794
# empty, so don't do anything
4787-
if itemindex:
4795+
if fix_wrap_around and itemindex:
47884796
# check to see if cyclic (wraparound) point included
47894797
# if so, remove it.
47904798
if np.abs(lonsin1[0]-lonsin1[-1]) < 1.e-4:
@@ -4811,13 +4819,7 @@ def shiftdata(self,lonsin,datain=None,lon_0=None):
48114819
datain_save[:,1:] = datain
48124820
datain_save[:,0] = datain[:,-1]
48134821
datain = datain_save
4814-
# mask points outside
4815-
# map region so they don't wrap back in the domain.
4816-
mask = np.logical_or(lonsin<lon_0-180,lonsin>lon_0+180)
4817-
lonsin = np.where(mask,1.e30,lonsin)
4818-
if datain is not None and mask.any():
4819-
# superimpose on existing mask
4820-
datain = ma.masked_where(mask, datain)
4822+
48214823
# 1-d data.
48224824
elif lonsin.ndim == 1:
48234825
nlons = len(lonsin)
@@ -4832,7 +4834,7 @@ def shiftdata(self,lonsin,datain=None,lon_0=None):
48324834
else:
48334835
itemindex = 0
48344836

4835-
if itemindex:
4837+
if fix_wrap_around and itemindex:
48364838
# check to see if cyclic (wraparound) point included
48374839
# if so, remove it.
48384840
if np.abs(lonsin[0]-lonsin[-1]) < 1.e-4:
@@ -4856,12 +4858,14 @@ def shiftdata(self,lonsin,datain=None,lon_0=None):
48564858
datain_save[1:] = datain
48574859
datain_save[0] = datain[-1]
48584860
datain = datain_save
4859-
# mask points outside
4860-
# map region so they don't wrap back in the domain.
4861-
mask = np.logical_or(lonsin<lon_0-180,lonsin>lon_0+180)
4862-
lonsin = np.where(mask,1.e30,lonsin)
4863-
if datain is not None and mask.any():
4864-
datain = ma.masked_where(mask, datain)
4861+
4862+
# mask points outside
4863+
# map region so they don't wrap back in the domain.
4864+
mask = np.logical_or(lonsin<lon_0-180,lonsin>lon_0+180)
4865+
lonsin = np.where(mask,1.e30,lonsin)
4866+
if datain is not None and mask.any():
4867+
datain = ma.masked_where(mask, datain)
4868+
48654869
if datain is not None:
48664870
return lonsin, datain
48674871
else:

lib/mpl_toolkits/basemap/test.py

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def test_cylindrical(self):
3838
def test_nan(self):
3939
B = Basemap()
4040
u,v,lat,lon=self.make_array()
41-
# Set one element to 0, so that the vector magnitude is 0.
41+
# Set one element to 0, so that the vector magnitude is 0.
4242
u[1,1] = 0.
4343
ru, rv = B.rotate_vector(u,v, lon, lat)
4444
assert not np.isnan(ru).any()
@@ -117,6 +117,42 @@ def _get_2d_lons(self, lons1d):
117117
lats = [10, ] * len(lons1d)
118118
return np.meshgrid(lons1d, lats)[0]
119119

120+
def test_non_monotonous_longitudes(self):
121+
"""
122+
when called for scatter, the longitudes passed to shiftdata are
123+
not necessarily monotonous...
124+
"""
125+
lons = [179, 180, 180, 0, 290, 10, 320, -150, 350, -250, 250]
126+
bm = Basemap(lon_0=0)
127+
128+
# before, having several break points would cause the exception,
129+
# inside the shiftdata method called from scatter method.
130+
self.assertRaises(ValueError, bm.shiftdata, lons, fix_wrap_around=True)
131+
132+
lons_new = bm.shiftdata(lons, fix_wrap_around=False)
133+
134+
# Check if the modified longitudes are inside of the projection region
135+
for lon in lons_new:
136+
assert lon >= bm.projparams["lon_0"] - 180
137+
assert lon <= bm.projparams["lon_0"] + 180
138+
139+
140+
def test_shiftdata_on_monotonous_lons(self):
141+
"""
142+
Test that shiftdata with fix_wrap_around keyword added works as before,
143+
when it is True
144+
"""
145+
146+
bm = Basemap(lon_0=0)
147+
148+
lons_in = [120, 140, 160, 180, 200, 220]
149+
lons_out_expect = [-160, -140, 120, 140, 160, 180]
150+
lons_out = bm.shiftdata(lons_in, fix_wrap_around=True)
151+
152+
assert_almost_equal(lons_out, lons_out_expect)
153+
154+
155+
120156
def test_2_points_should_work(self):
121157
"""
122158
Shiftdata should work with 2 points
@@ -160,7 +196,7 @@ def test_less_than_n_by_3_points_should_work(self):
160196
lonsout = bm.shiftdata(lonsin[:, :2])
161197
assert_almost_equal(lonsout_expected, lonsout)
162198

163-
@skipIf(PY3 and LooseVersion(pyproj.__version__) <= LooseVersion("1.9.4"),
199+
@skipIf(PY3 and LooseVersion(pyproj.__version__) <= LooseVersion("1.9.4"),
164200
"Test skipped in Python 3.x with pyproj version 1.9.4 and below.")
165201
class TestProjectCoords(TestCase):
166202
def get_data(self):
@@ -224,7 +260,7 @@ def test():
224260
import unittest
225261

226262
from mpl_toolkits.basemap.diagnostic import package_versions
227-
263+
228264
if '--verbose' in sys.argv or '-v' in sys.argv:
229265
pkg_vers = package_versions()
230266
print('Basemaps installed package versions:')

0 commit comments

Comments
 (0)