22
33from collections import Counter , defaultdict
44from collections .abc import Callable , Hashable , Iterable , Iterator , Sequence
5- from typing import TYPE_CHECKING , Literal , TypeVar , Union , cast
5+ from typing import TYPE_CHECKING , Literal , TypeAlias , TypeVar , cast , overload
66
77import pandas as pd
88
99from xarray .core import dtypes
1010from xarray .core .dataarray import DataArray
1111from xarray .core .dataset import Dataset
12+ from xarray .core .datatree import DataTree
1213from xarray .core .utils import iterate_nested
1314from xarray .structure .alignment import AlignmentError
1415from xarray .structure .concat import concat
@@ -96,27 +97,28 @@ def _ensure_same_types(series, dim):
9697 raise TypeError (error_msg )
9798
9899
99- def _infer_concat_order_from_coords (datasets ):
100+ def _infer_concat_order_from_coords (datasets : list [ Dataset ] | list [ DataTree ] ):
100101 concat_dims = []
101- tile_ids = [() for ds in datasets ]
102+ tile_ids : list [ tuple [ int , ...]] = [() for ds in datasets ]
102103
103104 # All datasets have same variables because they've been grouped as such
104105 ds0 = datasets [0 ]
105106 for dim in ds0 .dims :
106107 # Check if dim is a coordinate dimension
107108 if dim in ds0 :
108109 # Need to read coordinate values to do ordering
109- indexes = [ds ._indexes .get (dim ) for ds in datasets ]
110- if any (index is None for index in indexes ):
111- error_msg = (
112- f"Every dimension requires a corresponding 1D coordinate "
113- f"and index for inferring concatenation order but the "
114- f"coordinate '{ dim } ' has no corresponding index"
115- )
116- raise ValueError (error_msg )
117-
118- # TODO (benbovy, flexible indexes): support flexible indexes?
119- indexes = [index .to_pandas_index () for index in indexes ]
110+ indexes : list [pd .Index ] = []
111+ for ds in datasets :
112+ index = ds ._indexes .get (dim )
113+ if index is None :
114+ error_msg = (
115+ f"Every dimension requires a corresponding 1D coordinate "
116+ f"and index for inferring concatenation order but the "
117+ f"coordinate '{ dim } ' has no corresponding index"
118+ )
119+ raise ValueError (error_msg )
120+ # TODO (benbovy, flexible indexes): support flexible indexes?
121+ indexes .append (index .to_pandas_index ())
120122
121123 # If dimension coordinate values are same on every dataset then
122124 # should be leaving this dimension alone (it's just a "bystander")
@@ -153,7 +155,7 @@ def _infer_concat_order_from_coords(datasets):
153155 rank = series .rank (
154156 method = "dense" , ascending = ascending , numeric_only = False
155157 )
156- order = rank .astype (int ).values - 1
158+ order = ( rank .astype (int ).values - 1 ). tolist ()
157159
158160 # Append positions along extra dimension to structure which
159161 # encodes the multi-dimensional concatenation order
@@ -163,10 +165,16 @@ def _infer_concat_order_from_coords(datasets):
163165 ]
164166
165167 if len (datasets ) > 1 and not concat_dims :
166- raise ValueError (
167- "Could not find any dimension coordinates to use to "
168- "order the datasets for concatenation"
169- )
168+ if any (isinstance (data , DataTree ) for data in datasets ):
169+ raise ValueError (
170+ "Did not find any dimension coordinates at root nodes "
171+ "to order the DataTree objects for concatenation"
172+ )
173+ else :
174+ raise ValueError (
175+ "Could not find any dimension coordinates to use to "
176+ "order the Dataset objects for concatenation"
177+ )
170178
171179 combined_ids = dict (zip (tile_ids , datasets , strict = True ))
172180
@@ -224,7 +232,7 @@ def _combine_nd(
224232
225233 Parameters
226234 ----------
227- combined_ids : Dict[Tuple[int, ...]], xarray.Dataset]
235+ combined_ids : Dict[Tuple[int, ...], xarray.Dataset | xarray.DataTree]
228236 Structure containing all datasets to be concatenated with "tile_IDs" as
229237 keys, which specify position within the desired final combined result.
230238 concat_dims : sequence of str
@@ -235,7 +243,7 @@ def _combine_nd(
235243
236244 Returns
237245 -------
238- combined_ds : xarray.Dataset
246+ combined_ds : xarray.Dataset | xarray.DataTree
239247 """
240248
241249 example_tile_id = next (iter (combined_ids .keys ()))
@@ -399,20 +407,74 @@ def _nested_combine(
399407 return combined
400408
401409
402- # Define type for arbitrarily-nested list of lists recursively:
403- DATASET_HYPERCUBE = Union [Dataset , Iterable ["DATASET_HYPERCUBE" ]]
410+ # Define types for arbitrarily-nested list of lists.
411+ # Mypy doesn't seem to handle overloads properly with recursive types, so we
412+ # explicitly expand the first handful of levels of recursion.
413+ DatasetLike : TypeAlias = DataArray | Dataset
414+ DatasetHyperCube : TypeAlias = (
415+ DatasetLike
416+ | Sequence [DatasetLike ]
417+ | Sequence [Sequence [DatasetLike ]]
418+ | Sequence [Sequence [Sequence [DatasetLike ]]]
419+ | Sequence [Sequence [Sequence [Sequence [DatasetLike ]]]]
420+ )
421+ DataTreeHyperCube : TypeAlias = (
422+ DataTree
423+ | Sequence [DataTree ]
424+ | Sequence [Sequence [DataTree ]]
425+ | Sequence [Sequence [Sequence [DataTree ]]]
426+ | Sequence [Sequence [Sequence [Sequence [DataTree ]]]]
427+ )
428+
429+
430+ @overload
431+ def combine_nested (
432+ datasets : DatasetHyperCube ,
433+ concat_dim : str
434+ | DataArray
435+ | list [str ]
436+ | Sequence [str | DataArray | pd .Index | None ]
437+ | None ,
438+ compat : str | CombineKwargDefault = ...,
439+ data_vars : str | CombineKwargDefault = ...,
440+ coords : str | CombineKwargDefault = ...,
441+ fill_value : object = ...,
442+ join : JoinOptions | CombineKwargDefault = ...,
443+ combine_attrs : CombineAttrsOptions = ...,
444+ ) -> Dataset : ...
445+
446+
447+ @overload
448+ def combine_nested (
449+ datasets : DataTreeHyperCube ,
450+ concat_dim : str
451+ | DataArray
452+ | list [str ]
453+ | Sequence [str | DataArray | pd .Index | None ]
454+ | None ,
455+ compat : str | CombineKwargDefault = ...,
456+ data_vars : str | CombineKwargDefault = ...,
457+ coords : str | CombineKwargDefault = ...,
458+ fill_value : object = ...,
459+ join : JoinOptions | CombineKwargDefault = ...,
460+ combine_attrs : CombineAttrsOptions = ...,
461+ ) -> DataTree : ...
404462
405463
406464def combine_nested (
407- datasets : DATASET_HYPERCUBE ,
408- concat_dim : str | DataArray | Sequence [str | DataArray | pd .Index | None ] | None ,
465+ datasets : DatasetHyperCube | DataTreeHyperCube ,
466+ concat_dim : str
467+ | DataArray
468+ | list [str ]
469+ | Sequence [str | DataArray | pd .Index | None ]
470+ | None ,
409471 compat : str | CombineKwargDefault = _COMPAT_DEFAULT ,
410472 data_vars : str | CombineKwargDefault = _DATA_VARS_DEFAULT ,
411473 coords : str | CombineKwargDefault = _COORDS_DEFAULT ,
412474 fill_value : object = dtypes .NA ,
413475 join : JoinOptions | CombineKwargDefault = _JOIN_DEFAULT ,
414476 combine_attrs : CombineAttrsOptions = "drop" ,
415- ) -> Dataset :
477+ ) -> Dataset | DataTree :
416478 """
417479 Explicitly combine an N-dimensional grid of datasets into one by using a
418480 succession of concat and merge operations along each dimension of the grid.
@@ -433,7 +495,7 @@ def combine_nested(
433495
434496 Parameters
435497 ----------
436- datasets : list or nested list of Dataset
498+ datasets : list or nested list of Dataset, DataArray or DataTree
437499 Dataset objects to combine.
438500 If concatenation or merging along more than one dimension is desired,
439501 then datasets must be supplied in a nested list-of-lists.
@@ -527,7 +589,7 @@ def combine_nested(
527589
528590 Returns
529591 -------
530- combined : xarray.Dataset
592+ combined : xarray.Dataset or xarray.DataTree
531593
532594 Examples
533595 --------
@@ -621,22 +683,29 @@ def combine_nested(
621683 concat
622684 merge
623685 """
624- mixed_datasets_and_arrays = any (
625- isinstance (obj , Dataset ) for obj in iterate_nested (datasets )
626- ) and any (
686+ any_datasets = any (isinstance (obj , Dataset ) for obj in iterate_nested (datasets ))
687+ any_unnamed_arrays = any (
627688 isinstance (obj , DataArray ) and obj .name is None
628689 for obj in iterate_nested (datasets )
629690 )
630- if mixed_datasets_and_arrays :
691+ if any_datasets and any_unnamed_arrays :
631692 raise ValueError ("Can't combine datasets with unnamed arrays." )
632693
633- if isinstance (concat_dim , str | DataArray ) or concat_dim is None :
634- concat_dim = [concat_dim ]
694+ any_datatrees = any (isinstance (obj , DataTree ) for obj in iterate_nested (datasets ))
695+ all_datatrees = all (isinstance (obj , DataTree ) for obj in iterate_nested (datasets ))
696+ if any_datatrees and not all_datatrees :
697+ raise ValueError ("Can't combine a mix of DataTree and non-DataTree objects." )
698+
699+ concat_dims = (
700+ [concat_dim ]
701+ if isinstance (concat_dim , str | DataArray ) or concat_dim is None
702+ else concat_dim
703+ )
635704
636705 # The IDs argument tells _nested_combine that datasets aren't yet sorted
637706 return _nested_combine (
638707 datasets ,
639- concat_dims = concat_dim ,
708+ concat_dims = concat_dims ,
640709 compat = compat ,
641710 data_vars = data_vars ,
642711 coords = coords ,
@@ -988,6 +1057,10 @@ def combine_by_coords(
9881057 Finally, if you attempt to combine a mix of unnamed DataArrays with either named
9891058 DataArrays or Datasets, a ValueError will be raised (as this is an ambiguous operation).
9901059 """
1060+ if any (isinstance (data_object , DataTree ) for data_object in data_objects ):
1061+ raise NotImplementedError (
1062+ "combine_by_coords() does not yet support DataTree objects."
1063+ )
9911064
9921065 if not data_objects :
9931066 return Dataset ()
@@ -1018,7 +1091,7 @@ def combine_by_coords(
10181091 # Must be a mix of unnamed dataarrays with either named dataarrays or with datasets
10191092 # Can't combine these as we wouldn't know whether to merge or concatenate the arrays
10201093 raise ValueError (
1021- "Can't automatically combine unnamed DataArrays with either named DataArrays or Datasets."
1094+ "Can't automatically combine unnamed DataArrays with named DataArrays or Datasets."
10221095 )
10231096 else :
10241097 # Promote any named DataArrays to single-variable Datasets to simplify combining
0 commit comments