Skip to content

Commit 8cfd999

Browse files
authored
Add is_supported_aggregation (#455)
1 parent 839b4d0 commit 8cfd999

File tree

3 files changed

+24
-3
lines changed

3 files changed

+24
-3
lines changed

flox/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
"""Top-level module for flox ."""
44

55
from . import cache
6-
from .aggregations import Aggregation, Scan
6+
from .aggregations import Aggregation, Scan, is_supported_aggregation
77
from .core import (
88
groupby_reduce,
99
groupby_scan,
@@ -36,4 +36,5 @@ def _get_version():
3636
"set_options",
3737
"ReindexStrategy",
3838
"ReindexArrayType",
39+
"is_supported_aggregation",
3940
]

flox/aggregations.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
from . import aggregate_flox, aggregate_npg, xrutils
1616
from . import xrdtypes as dtypes
17-
from .lib import sparse_array_type
17+
from .lib import dask_array_type, sparse_array_type
1818

1919
if TYPE_CHECKING:
2020
FuncTuple = tuple[Callable | str, ...]
@@ -895,3 +895,20 @@ def _initialize_aggregation(
895895
agg.simple_combine = tuple(simple_combine)
896896

897897
return agg
898+
899+
900+
def is_supported_aggregation(array, func: str) -> bool:
901+
if isinstance(array, dask_array_type):
902+
array = array._meta
903+
904+
if isinstance(array, sparse_array_type):
905+
from flox.core import _is_sparse_supported_reduction
906+
907+
return _is_sparse_supported_reduction(func)
908+
909+
module, *_ = type(array).__module__.split(".")
910+
911+
if module in ["numpy", "cubed"]:
912+
return func in AGGREGATIONS
913+
else:
914+
return False

flox/core.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,10 @@ def _postprocess_numbagg(result, *, func, fill_value, size, seen_groups):
209209
if needs_masking:
210210
mask = np.isin(groups, seen_groups, assume_unique=True, invert=True)
211211
if mask.any():
212-
result[..., groups[mask]] = fill_value
212+
if isinstance(result, sparse_array_type):
213+
result.fill_value = fill_value
214+
else:
215+
result[..., groups[mask]] = fill_value
213216
return result
214217

215218

0 commit comments

Comments
 (0)