Skip to content

Commit dfdeef7

Browse files
authored
Explicitly keep track of indexes with merging (#3234)
* Explicitly keep track of indexes in merge.py * Typing fixes * More typing fixes * more typing fixes * fixup
1 parent 86fb71d commit dfdeef7

File tree

6 files changed

+352
-280
lines changed

6 files changed

+352
-280
lines changed

xarray/core/alignment.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -268,7 +268,7 @@ def align(
268268
all_indexes[dim].append(index)
269269

270270
if join == "override":
271-
objects = _override_indexes(list(objects), all_indexes, exclude)
271+
objects = _override_indexes(objects, all_indexes, exclude)
272272

273273
# We don't reindex over dimensions with all equal indexes for two reasons:
274274
# - It's faster for the usual case (already aligned objects).
@@ -365,26 +365,27 @@ def is_alignable(obj):
365365
targets = []
366366
no_key = object()
367367
not_replaced = object()
368-
for n, variables in enumerate(objects):
368+
for position, variables in enumerate(objects):
369369
if is_alignable(variables):
370-
positions.append(n)
370+
positions.append(position)
371371
keys.append(no_key)
372372
targets.append(variables)
373373
out.append(not_replaced)
374374
elif is_dict_like(variables):
375+
current_out = OrderedDict()
375376
for k, v in variables.items():
376-
if is_alignable(v) and k not in indexes:
377-
# Skip variables in indexes for alignment, because these
378-
# should be overwritten instead:
379-
# https://github.com/pydata/xarray/issues/725
380-
positions.append(n)
377+
if is_alignable(v):
378+
positions.append(position)
381379
keys.append(k)
382380
targets.append(v)
383-
out.append(OrderedDict(variables))
381+
current_out[k] = not_replaced
382+
else:
383+
current_out[k] = v
384+
out.append(current_out)
384385
elif raise_on_invalid:
385386
raise ValueError(
386387
"object to align is neither an xarray.Dataset, "
387-
"an xarray.DataArray nor a dictionary: %r" % variables
388+
"an xarray.DataArray nor a dictionary: {!r}".format(variables)
388389
)
389390
else:
390391
out.append(variables)
@@ -405,7 +406,10 @@ def is_alignable(obj):
405406
out[position][key] = aligned_obj
406407

407408
# something went wrong: we should have replaced all sentinel values
408-
assert all(arg is not not_replaced for arg in out)
409+
for arg in out:
410+
assert arg is not not_replaced
411+
if is_dict_like(arg):
412+
assert all(value is not not_replaced for value in arg.values())
409413

410414
return out
411415

xarray/core/computation.py

Lines changed: 18 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,13 @@
2424

2525
from . import duck_array_ops, utils
2626
from .alignment import deep_align
27-
from .merge import expand_and_merge_variables
27+
from .merge import merge_coordinates_without_align
2828
from .pycompat import dask_array_type
2929
from .utils import is_dict_like
3030
from .variable import Variable
3131

3232
if TYPE_CHECKING:
33+
from .coordinates import Coordinates # noqa
3334
from .dataset import Dataset
3435

3536
_DEFAULT_FROZEN_SET = frozenset() # type: frozenset
@@ -152,17 +153,16 @@ def result_name(objects: list) -> Any:
152153
return name
153154

154155

155-
def _get_coord_variables(args):
156-
input_coords = []
156+
def _get_coords_list(args) -> List["Coordinates"]:
157+
coords_list = []
157158
for arg in args:
158159
try:
159160
coords = arg.coords
160161
except AttributeError:
161162
pass # skip this argument
162163
else:
163-
coord_vars = getattr(coords, "variables", coords)
164-
input_coords.append(coord_vars)
165-
return input_coords
164+
coords_list.append(coords)
165+
return coords_list
166166

167167

168168
def build_output_coords(
@@ -185,32 +185,29 @@ def build_output_coords(
185185
-------
186186
OrderedDict of Variable objects with merged coordinates.
187187
"""
188-
input_coords = _get_coord_variables(args)
188+
coords_list = _get_coords_list(args)
189189

190-
if exclude_dims:
191-
input_coords = [
192-
OrderedDict(
193-
(k, v) for k, v in coord_vars.items() if exclude_dims.isdisjoint(v.dims)
194-
)
195-
for coord_vars in input_coords
196-
]
197-
198-
if len(input_coords) == 1:
190+
if len(coords_list) == 1 and not exclude_dims:
199191
# we can skip the expensive merge
200-
unpacked_input_coords, = input_coords
201-
merged = OrderedDict(unpacked_input_coords)
192+
unpacked_coords, = coords_list
193+
merged_vars = OrderedDict(unpacked_coords.variables)
202194
else:
203-
merged = expand_and_merge_variables(input_coords)
195+
# TODO: save these merged indexes, instead of re-computing them later
196+
merged_vars, unused_indexes = merge_coordinates_without_align(
197+
coords_list, exclude_dims=exclude_dims
198+
)
204199

205200
output_coords = []
206201
for output_dims in signature.output_core_dims:
207202
dropped_dims = signature.all_input_core_dims - set(output_dims)
208203
if dropped_dims:
209204
filtered = OrderedDict(
210-
(k, v) for k, v in merged.items() if dropped_dims.isdisjoint(v.dims)
205+
(k, v)
206+
for k, v in merged_vars.items()
207+
if dropped_dims.isdisjoint(v.dims)
211208
)
212209
else:
213-
filtered = merged
210+
filtered = merged_vars
214211
output_coords.append(filtered)
215212

216213
return output_coords

xarray/core/coordinates.py

Lines changed: 54 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
Hashable,
77
Iterator,
88
Mapping,
9-
Sequence,
109
Set,
10+
Sequence,
1111
Tuple,
1212
Union,
1313
cast,
@@ -17,11 +17,7 @@
1717

1818
from . import formatting, indexing
1919
from .indexes import Indexes
20-
from .merge import (
21-
expand_and_merge_variables,
22-
merge_coords,
23-
merge_coords_for_inplace_math,
24-
)
20+
from .merge import merge_coords, merge_coordinates_without_align
2521
from .utils import Frozen, ReprObject, either_dict_or_kwargs
2622
from .variable import Variable
2723

@@ -34,7 +30,7 @@
3430
_THIS_ARRAY = ReprObject("<this-array>")
3531

3632

37-
class AbstractCoordinates(Mapping[Hashable, "DataArray"]):
33+
class Coordinates(Mapping[Hashable, "DataArray"]):
3834
__slots__ = ()
3935

4036
def __getitem__(self, key: Hashable) -> "DataArray":
@@ -57,10 +53,10 @@ def indexes(self) -> Indexes:
5753

5854
@property
5955
def variables(self):
60-
raise NotImplementedError()
56+
raise NotImplementedError
6157

62-
def _update_coords(self, coords):
63-
raise NotImplementedError()
58+
def _update_coords(self, coords, indexes):
59+
raise NotImplementedError
6460

6561
def __iter__(self) -> Iterator["Hashable"]:
6662
# needs to be in the same order as the dataset variables
@@ -116,38 +112,38 @@ def to_index(self, ordered_dims: Sequence[Hashable] = None) -> pd.Index:
116112

117113
def update(self, other: Mapping[Hashable, Any]) -> None:
118114
other_vars = getattr(other, "variables", other)
119-
coords = merge_coords(
115+
coords, indexes = merge_coords(
120116
[self.variables, other_vars], priority_arg=1, indexes=self.indexes
121117
)
122-
self._update_coords(coords)
118+
self._update_coords(coords, indexes)
123119

124120
def _merge_raw(self, other):
125121
"""For use with binary arithmetic."""
126122
if other is None:
127123
variables = OrderedDict(self.variables)
124+
indexes = OrderedDict(self.indexes)
128125
else:
129-
# don't align because we already called xarray.align
130-
variables = expand_and_merge_variables([self.variables, other.variables])
131-
return variables
126+
variables, indexes = merge_coordinates_without_align([self, other])
127+
return variables, indexes
132128

133129
@contextmanager
134130
def _merge_inplace(self, other):
135131
"""For use with in-place binary arithmetic."""
136132
if other is None:
137133
yield
138134
else:
139-
# don't include indexes in priority_vars, because we didn't align
140-
# first
141-
priority_vars = OrderedDict(
142-
kv for kv in self.variables.items() if kv[0] not in self.dims
143-
)
144-
variables = merge_coords_for_inplace_math(
145-
[self.variables, other.variables], priority_vars=priority_vars
135+
# don't include indexes in prioritized, because we didn't align
136+
# first and we want indexes to be checked
137+
prioritized = {
138+
k: (v, None) for k, v in self.variables.items() if k not in self.indexes
139+
}
140+
variables, indexes = merge_coordinates_without_align(
141+
[self, other], prioritized
146142
)
147143
yield
148-
self._update_coords(variables)
144+
self._update_coords(variables, indexes)
149145

150-
def merge(self, other: "AbstractCoordinates") -> "Dataset":
146+
def merge(self, other: "Coordinates") -> "Dataset":
151147
"""Merge two sets of coordinates to create a new Dataset
152148
153149
The method implements the logic used for joining coordinates in the
@@ -173,13 +169,19 @@ def merge(self, other: "AbstractCoordinates") -> "Dataset":
173169

174170
if other is None:
175171
return self.to_dataset()
176-
else:
177-
other_vars = getattr(other, "variables", other)
178-
coords = expand_and_merge_variables([self.variables, other_vars])
179-
return Dataset._from_vars_and_coord_names(coords, set(coords))
172+
173+
if not isinstance(other, Coordinates):
174+
other = Dataset(coords=other).coords
175+
176+
coords, indexes = merge_coordinates_without_align([self, other])
177+
coord_names = set(coords)
178+
merged = Dataset._construct_direct(
179+
variables=coords, coord_names=coord_names, indexes=indexes
180+
)
181+
return merged
180182

181183

182-
class DatasetCoordinates(AbstractCoordinates):
184+
class DatasetCoordinates(Coordinates):
183185
"""Dictionary like container for Dataset coordinates.
184186
185187
Essentially an immutable OrderedDict with keys given by the array's
@@ -218,7 +220,11 @@ def to_dataset(self) -> "Dataset":
218220
"""
219221
return self._data._copy_listed(self._names)
220222

221-
def _update_coords(self, coords: Mapping[Hashable, Any]) -> None:
223+
def _update_coords(
224+
self,
225+
coords: "OrderedDict[Hashable, Variable]",
226+
indexes: Mapping[Hashable, pd.Index],
227+
) -> None:
222228
from .dataset import calculate_dimensions
223229

224230
variables = self._data._variables.copy()
@@ -234,7 +240,12 @@ def _update_coords(self, coords: Mapping[Hashable, Any]) -> None:
234240
self._data._variables = variables
235241
self._data._coord_names.update(new_coord_names)
236242
self._data._dims = dims
237-
self._data._indexes = None
243+
244+
# TODO(shoyer): once ._indexes is always populated by a dict, modify
245+
# it to update inplace instead.
246+
original_indexes = OrderedDict(self._data.indexes)
247+
original_indexes.update(indexes)
248+
self._data._indexes = original_indexes
238249

239250
def __delitem__(self, key: Hashable) -> None:
240251
if key in self:
@@ -251,7 +262,7 @@ def _ipython_key_completions_(self):
251262
]
252263

253264

254-
class DataArrayCoordinates(AbstractCoordinates):
265+
class DataArrayCoordinates(Coordinates):
255266
"""Dictionary like container for DataArray coordinates.
256267
257268
Essentially an OrderedDict with keys given by the array's
@@ -274,7 +285,11 @@ def _names(self) -> Set[Hashable]:
274285
def __getitem__(self, key: Hashable) -> "DataArray":
275286
return self._data._getitem_coord(key)
276287

277-
def _update_coords(self, coords) -> None:
288+
def _update_coords(
289+
self,
290+
coords: "OrderedDict[Hashable, Variable]",
291+
indexes: Mapping[Hashable, pd.Index],
292+
) -> None:
278293
from .dataset import calculate_dimensions
279294

280295
coords_plus_data = coords.copy()
@@ -285,7 +300,12 @@ def _update_coords(self, coords) -> None:
285300
"cannot add coordinates with new dimensions to " "a DataArray"
286301
)
287302
self._data._coords = coords
288-
self._data._indexes = None
303+
304+
# TODO(shoyer): once ._indexes is always populated by a dict, modify
305+
# it to update inplace instead.
306+
original_indexes = OrderedDict(self._data.indexes)
307+
original_indexes.update(indexes)
308+
self._data._indexes = original_indexes
289309

290310
@property
291311
def variables(self):

xarray/core/dataarray.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2519,7 +2519,7 @@ def func(self, other):
25192519
if not reflexive
25202520
else f(other_variable, self.variable)
25212521
)
2522-
coords = self.coords._merge_raw(other_coords)
2522+
coords, indexes = self.coords._merge_raw(other_coords)
25232523
name = self._result_name(other)
25242524

25252525
return self._replace(variable, coords, name)

0 commit comments

Comments
 (0)