
Commit ad92a16

shoyer, dcherian, claude, star1327p, and keewis authored

support combine_nested on DataTree objects (#10849)

* support combine_nested on DataTree objects
* make mypy happy
* Rewrite test_iso8601_decode to work without cftime (#10914)
* Rewrite test_iso8601_decode to work without cftime
  Closes #10907
  Co-Authored-By: Claude <noreply@anthropic.com>
* silence warning
---------
Co-authored-by: Claude <noreply@anthropic.com>
* DOC: Correct minor grammar issues (#10915)
* skip ci based on a label (#10918)
* skip ci based on a label
* retrigger ci
* Optimize padding for coarsening. (#10921)
  Co-authored-by: Claude <noreply@anthropic.com>
* release v2025.11.0 (#10917)
* choose the right version number
* formatting
* move a breaking change entry to the right section
* remove the deprecations section
* missing entry
* remove the empty internals section
* another missing entry
* more missing entries
* remove empty line
* bad merge
* remove the empty documentation section
* newline
* use the old style of setting release dates
* set the release date
* formatting
* back to bold
* release summary
* new release section (#10926)
* new release section
* [skip-ci]
* update whats-new
* removed redundant whatsnew
---------
Co-authored-by: Deepak Cherian <dcherian@users.noreply.github.com>
Co-authored-by: Claude <noreply@anthropic.com>
Co-authored-by: Christine P. Chai <star1327p@gmail.com>
Co-authored-by: Justus Magin <keewis@users.noreply.github.com>

1 parent b957535 · commit ad92a16

File tree

3 files changed: +141 -39 lines changed

doc/whats-new.rst
xarray/structure/combine.py
xarray/tests/test_combine.py

doc/whats-new.rst

Lines changed: 3 additions & 0 deletions

@@ -14,6 +14,9 @@ v2025.11.1 (unreleased)
 New Features
 ~~~~~~~~~~~~
 
+- :py:func:`combine_nested` now supports :py:class:`DataTree` objects
+  (:pull:`10849`).
+  By `Stephan Hoyer <https://github.com/shoyer>`_.
 
 Breaking Changes
 ~~~~~~~~~~~~~~~~
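For orientation, a minimal usage sketch of the behavior described in this changelog entry, adapted from the test_datatree case added later in this commit (variable names are illustrative only):

import xarray as xr

# Two single-node trees holding the same variable; combine_nested concatenates
# them along the new "x" dimension and returns a DataTree rather than a Dataset.
trees = [xr.DataTree.from_dict({"foo": 0}), xr.DataTree.from_dict({"foo": 1})]
combined = xr.combine_nested(trees, concat_dim="x")

assert combined.identical(xr.DataTree.from_dict({"foo": ("x", [0, 1])}))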

xarray/structure/combine.py

Lines changed: 109 additions & 36 deletions

@@ -2,13 +2,14 @@
 
 from collections import Counter, defaultdict
 from collections.abc import Callable, Hashable, Iterable, Iterator, Sequence
-from typing import TYPE_CHECKING, Literal, TypeVar, Union, cast
+from typing import TYPE_CHECKING, Literal, TypeAlias, TypeVar, cast, overload
 
 import pandas as pd
 
 from xarray.core import dtypes
 from xarray.core.dataarray import DataArray
 from xarray.core.dataset import Dataset
+from xarray.core.datatree import DataTree
 from xarray.core.utils import iterate_nested
 from xarray.structure.alignment import AlignmentError
 from xarray.structure.concat import concat
@@ -96,27 +97,28 @@ def _ensure_same_types(series, dim):
             raise TypeError(error_msg)
 
 
-def _infer_concat_order_from_coords(datasets):
+def _infer_concat_order_from_coords(datasets: list[Dataset] | list[DataTree]):
     concat_dims = []
-    tile_ids = [() for ds in datasets]
+    tile_ids: list[tuple[int, ...]] = [() for ds in datasets]
 
     # All datasets have same variables because they've been grouped as such
     ds0 = datasets[0]
     for dim in ds0.dims:
         # Check if dim is a coordinate dimension
         if dim in ds0:
             # Need to read coordinate values to do ordering
-            indexes = [ds._indexes.get(dim) for ds in datasets]
-            if any(index is None for index in indexes):
-                error_msg = (
-                    f"Every dimension requires a corresponding 1D coordinate "
-                    f"and index for inferring concatenation order but the "
-                    f"coordinate '{dim}' has no corresponding index"
-                )
-                raise ValueError(error_msg)
-
-            # TODO (benbovy, flexible indexes): support flexible indexes?
-            indexes = [index.to_pandas_index() for index in indexes]
+            indexes: list[pd.Index] = []
+            for ds in datasets:
+                index = ds._indexes.get(dim)
+                if index is None:
+                    error_msg = (
+                        f"Every dimension requires a corresponding 1D coordinate "
+                        f"and index for inferring concatenation order but the "
+                        f"coordinate '{dim}' has no corresponding index"
+                    )
+                    raise ValueError(error_msg)
+                # TODO (benbovy, flexible indexes): support flexible indexes?
+                indexes.append(index.to_pandas_index())
 
             # If dimension coordinate values are same on every dataset then
             # should be leaving this dimension alone (it's just a "bystander")
@@ -153,7 +155,7 @@ def _infer_concat_order_from_coords(datasets):
             rank = series.rank(
                 method="dense", ascending=ascending, numeric_only=False
             )
-            order = rank.astype(int).values - 1
+            order = (rank.astype(int).values - 1).tolist()
 
             # Append positions along extra dimension to structure which
             # encodes the multi-dimensional concatenation order
@@ -163,10 +165,16 @@
             ]
 
     if len(datasets) > 1 and not concat_dims:
-        raise ValueError(
-            "Could not find any dimension coordinates to use to "
-            "order the datasets for concatenation"
-        )
+        if any(isinstance(data, DataTree) for data in datasets):
+            raise ValueError(
+                "Did not find any dimension coordinates at root nodes "
+                "to order the DataTree objects for concatenation"
+            )
+        else:
+            raise ValueError(
+                "Could not find any dimension coordinates to use to "
+                "order the Dataset objects for concatenation"
+            )
 
     combined_ids = dict(zip(tile_ids, datasets, strict=True))
 
@@ -224,7 +232,7 @@ def _combine_nd(
 
     Parameters
    ----------
-    combined_ids : Dict[Tuple[int, ...]], xarray.Dataset]
+    combined_ids : Dict[Tuple[int, ...]], xarray.Dataset | xarray.DataTree]
        Structure containing all datasets to be concatenated with "tile_IDs" as
        keys, which specify position within the desired final combined result.
    concat_dims : sequence of str
@@ -235,7 +243,7 @@
 
     Returns
    -------
-    combined_ds : xarray.Dataset
+    combined_ds : xarray.Dataset | xarray.DataTree
    """
 
    example_tile_id = next(iter(combined_ids.keys()))
@@ -399,20 +407,74 @@ def _nested_combine(
     return combined
 
 
-# Define type for arbitrarily-nested list of lists recursively:
-DATASET_HYPERCUBE = Union[Dataset, Iterable["DATASET_HYPERCUBE"]]
+# Define types for arbitrarily-nested list of lists.
+# Mypy doesn't seem to handle overloads properly with recursive types, so we
+# explicitly expand the first handful of levels of recursion.
+DatasetLike: TypeAlias = DataArray | Dataset
+DatasetHyperCube: TypeAlias = (
+    DatasetLike
+    | Sequence[DatasetLike]
+    | Sequence[Sequence[DatasetLike]]
+    | Sequence[Sequence[Sequence[DatasetLike]]]
+    | Sequence[Sequence[Sequence[Sequence[DatasetLike]]]]
+)
+DataTreeHyperCube: TypeAlias = (
+    DataTree
+    | Sequence[DataTree]
+    | Sequence[Sequence[DataTree]]
+    | Sequence[Sequence[Sequence[DataTree]]]
+    | Sequence[Sequence[Sequence[Sequence[DataTree]]]]
+)
+
+
+@overload
+def combine_nested(
+    datasets: DatasetHyperCube,
+    concat_dim: str
+    | DataArray
+    | list[str]
+    | Sequence[str | DataArray | pd.Index | None]
+    | None,
+    compat: str | CombineKwargDefault = ...,
+    data_vars: str | CombineKwargDefault = ...,
+    coords: str | CombineKwargDefault = ...,
+    fill_value: object = ...,
+    join: JoinOptions | CombineKwargDefault = ...,
+    combine_attrs: CombineAttrsOptions = ...,
+) -> Dataset: ...
+
+
+@overload
+def combine_nested(
+    datasets: DataTreeHyperCube,
+    concat_dim: str
+    | DataArray
+    | list[str]
+    | Sequence[str | DataArray | pd.Index | None]
+    | None,
+    compat: str | CombineKwargDefault = ...,
+    data_vars: str | CombineKwargDefault = ...,
+    coords: str | CombineKwargDefault = ...,
+    fill_value: object = ...,
+    join: JoinOptions | CombineKwargDefault = ...,
+    combine_attrs: CombineAttrsOptions = ...,
+) -> DataTree: ...
 
 
 def combine_nested(
-    datasets: DATASET_HYPERCUBE,
-    concat_dim: str | DataArray | Sequence[str | DataArray | pd.Index | None] | None,
+    datasets: DatasetHyperCube | DataTreeHyperCube,
+    concat_dim: str
+    | DataArray
+    | list[str]
+    | Sequence[str | DataArray | pd.Index | None]
+    | None,
     compat: str | CombineKwargDefault = _COMPAT_DEFAULT,
     data_vars: str | CombineKwargDefault = _DATA_VARS_DEFAULT,
     coords: str | CombineKwargDefault = _COORDS_DEFAULT,
     fill_value: object = dtypes.NA,
     join: JoinOptions | CombineKwargDefault = _JOIN_DEFAULT,
     combine_attrs: CombineAttrsOptions = "drop",
-) -> Dataset:
+) -> Dataset | DataTree:
     """
     Explicitly combine an N-dimensional grid of datasets into one by using a
     succession of concat and merge operations along each dimension of the grid.
@@ -433,7 +495,7 @@ def combine_nested(
 
     Parameters
     ----------
-    datasets : list or nested list of Dataset
+    datasets : list or nested list of Dataset, DataArray or DataTree
         Dataset objects to combine.
         If concatenation or merging along more than one dimension is desired,
         then datasets must be supplied in a nested list-of-lists.
@@ -527,7 +589,7 @@ def combine_nested(
 
     Returns
     -------
-    combined : xarray.Dataset
+    combined : xarray.Dataset or xarray.DataTree
 
     Examples
     --------
@@ -621,22 +683,29 @@ def combine_nested(
     concat
     merge
     """
-    mixed_datasets_and_arrays = any(
-        isinstance(obj, Dataset) for obj in iterate_nested(datasets)
-    ) and any(
+    any_datasets = any(isinstance(obj, Dataset) for obj in iterate_nested(datasets))
+    any_unnamed_arrays = any(
        isinstance(obj, DataArray) and obj.name is None
        for obj in iterate_nested(datasets)
    )
-    if mixed_datasets_and_arrays:
+    if any_datasets and any_unnamed_arrays:
        raise ValueError("Can't combine datasets with unnamed arrays.")
 
-    if isinstance(concat_dim, str | DataArray) or concat_dim is None:
-        concat_dim = [concat_dim]
+    any_datatrees = any(isinstance(obj, DataTree) for obj in iterate_nested(datasets))
+    all_datatrees = all(isinstance(obj, DataTree) for obj in iterate_nested(datasets))
+    if any_datatrees and not all_datatrees:
+        raise ValueError("Can't combine a mix of DataTree and non-DataTree objects.")
+
+    concat_dims = (
+        [concat_dim]
+        if isinstance(concat_dim, str | DataArray) or concat_dim is None
+        else concat_dim
+    )
 
    # The IDs argument tells _nested_combine that datasets aren't yet sorted
    return _nested_combine(
        datasets,
-        concat_dims=concat_dim,
+        concat_dims=concat_dims,
        compat=compat,
        data_vars=data_vars,
        coords=coords,
@@ -988,6 +1057,10 @@ def combine_by_coords(
     Finally, if you attempt to combine a mix of unnamed DataArrays with either named
     DataArrays or Datasets, a ValueError will be raised (as this is an ambiguous operation).
     """
+    if any(isinstance(data_object, DataTree) for data_object in data_objects):
+        raise NotImplementedError(
+            "combine_by_coords() does not yet support DataTree objects."
+        )
 
     if not data_objects:
         return Dataset()
@@ -1018,7 +1091,7 @@ def combine_by_coords(
         # Must be a mix of unnamed dataarrays with either named dataarrays or with datasets
         # Can't combine these as we wouldn't know whether to merge or concatenate the arrays
         raise ValueError(
-            "Can't automatically combine unnamed DataArrays with either named DataArrays or Datasets."
+            "Can't automatically combine unnamed DataArrays with named DataArrays or Datasets."
         )
     else:
         # Promote any named DataArrays to single-variable Datasets to simplify combining
xarray/tests/test_combine.py

Lines changed: 29 additions & 3 deletions

@@ -1,5 +1,6 @@
 from __future__ import annotations
 
+import re
 from itertools import product
 
 import numpy as np
@@ -8,6 +9,7 @@
 from xarray import (
     DataArray,
     Dataset,
+    DataTree,
     MergeError,
     combine_by_coords,
     combine_nested,
@@ -624,8 +626,8 @@ def test_auto_combine_2d_combine_attrs_kwarg(self):
                 datasets,
                 concat_dim=["dim1", "dim2"],
                 data_vars="all",
-                combine_attrs=combine_attrs,  # type: ignore[arg-type]
-            )
+                combine_attrs=combine_attrs,
+            )  # type: ignore[call-overload]
             assert_identical(result, expected)
 
     def test_combine_nested_missing_data_new_dim(self):
@@ -764,7 +766,21 @@ def test_nested_combine_mixed_datasets_arrays(self):
         with pytest.raises(
             ValueError, match=r"Can't combine datasets with unnamed arrays."
         ):
-            combine_nested(objs, "x")
+            combine_nested(objs, "x")  # type: ignore[arg-type]
+
+    def test_nested_combine_mixed_datatrees_and_datasets(self):
+        objs = [DataTree.from_dict({"foo": 0}), Dataset({"foo": 1})]
+        with pytest.raises(
+            ValueError,
+            match=r"Can't combine a mix of DataTree and non-DataTree objects.",
+        ):
+            combine_nested(objs, concat_dim="x")  # type: ignore[arg-type]
+
+    def test_datatree(self):
+        objs = [DataTree.from_dict({"foo": 0}), DataTree.from_dict({"foo": 1})]
+        expected = DataTree.from_dict({"foo": ("x", [0, 1])})
+        actual = combine_nested(objs, concat_dim="x")
+        assert expected.identical(actual)
 
 
 class TestCombineDatasetsbyCoords:
@@ -1210,6 +1226,16 @@ def test_combine_by_coords_all_dataarrays_with_the_same_name(self):
         expected = merge([named_da1, named_da2], compat="no_conflicts", join="outer")
         assert_identical(expected, actual)
 
+    def test_combine_by_coords_datatree(self):
+        tree = DataTree.from_dict({"/nested/foo": ("x", [10])}, coords={"x": [1]})
+        with pytest.raises(
+            NotImplementedError,
+            match=re.escape(
+                "combine_by_coords() does not yet support DataTree objects."
+            ),
+        ):
+            combine_by_coords([tree])  # type: ignore[list-item]
+
 
 class TestNewDefaults:
     def test_concat_along_existing_dim(self):
