diff --git a/flox/aggregate_flox.py b/flox/aggregate_flox.py
index 12750645..e7245aec 100644
--- a/flox/aggregate_flox.py
+++ b/flox/aggregate_flox.py
@@ -1,10 +1,102 @@
 from functools import partial
+from typing import Self
 
 import numpy as np
 
 from . import xrdtypes as dtypes
 from .xrutils import is_scalar, isnull, notnull
 
+MULTIARRAY_HANDLED_FUNCTIONS = {}
+
+
+class MultiArray:
+    arrays: tuple[np.ndarray, ...]
+
+    def __init__(self, arrays):
+        # Coerce to a tuple so the container matches the annotation and stays immutable.
+        # TODO: decide whether the component arrays need stricter type checking.
+        self.arrays = tuple(arrays)
+        assert all(self.arrays[0].shape == a.shape for a in self.arrays), (
+            "Expect all arrays to have the same shape"
+        )
+
+    def astype(self, dt, **kwargs):
+        return type(self)(tuple(array.astype(dt, **kwargs) for array in self.arrays))
+
+    def reshape(self, shape, **kwargs):
+        return type(self)(tuple(array.reshape(shape, **kwargs) for array in self.arrays))
+
+    def squeeze(self, axis=None):
+        return type(self)(tuple(array.squeeze(axis) for array in self.arrays))
+
+    def __array_function__(self, func, types, args, kwargs):
+        if func not in MULTIARRAY_HANDLED_FUNCTIONS:
+            return NotImplemented
+        # Standard NEP 18 guard; it also allows subclasses that don't override
+        # __array_function__ to handle MultiArray objects.
+        if not all(issubclass(t, MultiArray) for t in types):
+            return NotImplemented
+        return MULTIARRAY_HANDLED_FUNCTIONS[func](*args, **kwargs)
+
+    # The constructor guarantees that all component arrays share a shape, so
+    # shape and ndim are well-defined. dtype additionally assumes they share a
+    # dtype, which is not (yet) enforced.
+    @property
+    def dtype(self) -> np.dtype:
+        return self.arrays[0].dtype
+
+    @property
+    def shape(self) -> tuple[int, ...]:
+        return self.arrays[0].shape
+
+    @property
+    def ndim(self) -> int:
+        return self.arrays[0].ndim
+
+    def __getitem__(self, key) -> Self:
+        return type(self)(tuple(array[key] for array in self.arrays))
+
+
+def implements(numpy_function):
+    """Register an __array_function__ implementation for MultiArray objects."""
+
+    def decorator(func):
+        MULTIARRAY_HANDLED_FUNCTIONS[numpy_function] = func
+        return func
+
+    return decorator
+
+
+@implements(np.expand_dims)
+def expand_dims_MultiArray(multiarray, axis):
+    return MultiArray(tuple(np.expand_dims(a, axis) for a in multiarray.arrays))
+
+
+@implements(np.concatenate)
+def concatenate_MultiArray(multiarrays, axis=0):
+    n_arrays = len(multiarrays[0].arrays)
+    if any(len(ma.arrays) != n_arrays for ma in multiarrays[1:]):
+        # Concatenating MultiArrays with different numbers of component arrays
+        # has no meaningful interpretation.
+        raise NotImplementedError
+    # Join component-wise: the i-th array of every MultiArray is concatenated.
+    # np.concatenate raises if the component shapes disagree off-axis.
+    return MultiArray(
+        tuple(np.concatenate([ma.arrays[i] for ma in multiarrays], axis) for i in range(n_arrays))
+    )
+
+
+@implements(np.transpose)
+def transpose_MultiArray(multiarray, axes=None):
+    return MultiArray(tuple(np.transpose(a, axes) for a in multiarray.arrays))
+
 
 def _prepare_for_flox(group_idx, array):
     """
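Reviewer note: the class above relies on NEP 18 (`__array_function__`) dispatch, so numpy functions called on `MultiArray` inputs route to the handlers registered below it. A minimal standalone sketch of the behaviour this enables (illustrative values, not part of the patch):

```python
import numpy as np
from flox.aggregate_flox import MultiArray

ma = MultiArray((np.arange(6.0).reshape(2, 3), np.ones((2, 3))))

# np.concatenate sees MultiArray elements and dispatches to
# concatenate_MultiArray, which concatenates component-wise.
out = np.concatenate([ma, ma], axis=-1)
print(out.shape)             # (2, 6): the shape of each component array
print(len(out.arrays))       # 2: still two component arrays
print(out[0, :2].arrays[0])  # indexing maps across components: [0. 1.]
```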
diff --git a/flox/aggregations.py b/flox/aggregations.py
index 6246942b..975c3258 100644
--- a/flox/aggregations.py
+++ b/flox/aggregations.py
@@ -343,12 +343,99 @@ def _mean_finalize(sum_, count):
     )
 
 
+def var_chunk(group_idx, array, *, engine: str, axis=-1, size=None, fill_value=None, dtype=None):
+    from .aggregate_flox import MultiArray
+
+    # Calculate the count and sum per group; both are needed for the adjustment
+    # terms applied to the sums of squared deviations when chunks are combined.
+    array_lens = generic_aggregate(
+        group_idx,
+        array,
+        func="nanlen",
+        engine=engine,
+        axis=axis,
+        size=size,
+        fill_value=fill_value,
+        dtype=dtype,
+    )
+
+    array_sums = generic_aggregate(
+        group_idx,
+        array,
+        func="nansum",
+        engine=engine,
+        axis=axis,
+        size=size,
+        fill_value=fill_value,
+        dtype=dtype,
+    )
+
+    # Calculate the sum of squared deviations from each group's mean,
+    # the main ingredient of the variance.
+    with np.errstate(invalid="ignore", divide="ignore"):
+        array_means = array_sums / array_lens
+
+    sum_squared_deviations = generic_aggregate(
+        group_idx,
+        (array - array_means[..., group_idx]) ** 2,
+        func="nansum",
+        engine=engine,
+        axis=axis,
+        size=size,
+        fill_value=fill_value,
+        dtype=dtype,
+    )
+
+    return MultiArray((sum_squared_deviations, array_sums, array_lens))
+
+
+def _var_combine(array, axis, keepdims=True):
+    def clip_last(array):
+        """Return `array` without its last element along `axis`.
+        Included purely to tidy up the adj_terms expression.
+        """
+        not_last = [slice(None)] * array.ndim
+        not_last[axis[0]] = slice(None, -1)
+        return array[tuple(not_last)]
+
+    def clip_first(array):
+        """Return `array` without its first element along `axis`.
+        Included purely to tidy up the adj_terms expression.
+        """
+        not_first = [slice(None)] * array.ndim
+        not_first[axis[0]] = slice(1, None)
+        return array[tuple(not_first)]
+
+    assert len(axis) == 1, "combine is assumed to act along a single axis at a time"
+
+    from .aggregate_flox import MultiArray
+
+    sum_deviations, sum_X, sum_len = array.arrays
+
+    # Running totals needed for the cascading combination; the last element of
+    # each cumulative sum is never a merge target, so it is clipped below.
+    cumsum_X = np.cumsum(sum_X, axis=axis[0])
+    cumsum_len = np.cumsum(sum_len, axis=axis[0])
+
+    # Adjustment terms that correct the sums of squared deviations for the fact
+    # that each chunk has a different mean (the pairwise update of Chan et al.,
+    # applied cumulatively along the combine axis).
+    adj_terms = (
+        clip_last(cumsum_len) * clip_first(sum_X) - clip_first(sum_len) * clip_last(cumsum_X)
+    ) ** 2 / (clip_last(cumsum_len) * clip_first(sum_len) * (clip_last(cumsum_len) + clip_first(sum_len)))
+
+    # TODO: MultiArray probably belongs somewhere more central than aggregate_flox.
+    return MultiArray(
+        (
+            np.sum(sum_deviations, axis=axis, keepdims=keepdims)
+            + np.sum(adj_terms, axis=axis, keepdims=keepdims),  # sum of squared deviations
+            np.sum(sum_X, axis=axis, keepdims=keepdims),  # sum of array items
+            np.sum(sum_len, axis=axis, keepdims=keepdims),  # sum of array lengths
+        )
+    )
+
+
 # TODO: fix this for complex numbers
-def _var_finalize(sumsq, sum_, count, ddof=0):
-    with np.errstate(invalid="ignore", divide="ignore"):
-        result = (sumsq - (sum_**2 / count)) / (count - ddof)
-    result[count <= ddof] = np.nan
-    return result
+def _var_finalize(multiarray, ddof=0):
+    sum_squared_deviations, _, count = multiarray.arrays
+    with np.errstate(invalid="ignore", divide="ignore"):
+        result = sum_squared_deviations / (count - ddof)
+    result[count <= ddof] = np.nan
+    return result
 
 
 def _std_finalize(sumsq, sum_, count, ddof=0):
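Reviewer note: for two chunks, the `adj_terms` expression in `_var_combine` reduces to the pairwise merge of Chan et al., so the combined sum of squared deviations should match a direct computation. A quick standalone check of the algebra (plain numpy, illustrative values only):

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.normal(1e8, 1.0, 100)  # huge mean, tiny variance
b = rng.normal(1e8, 1.0, 50)

def chunk_stats(x):
    # Per-chunk intermediates, as produced by var_chunk:
    # (sum of squared deviations from the chunk mean, sum, count)
    return ((x - x.mean()) ** 2).sum(), x.sum(), x.size

m2a, sa, na = chunk_stats(a)
m2b, sb, nb = chunk_stats(b)

# Pairwise merge: correct the deviations for the two chunks' differing means
m2 = m2a + m2b + (na * sb - nb * sa) ** 2 / (na * nb * (na + nb))

whole = np.concatenate([a, b])
direct = ((whole - whole.mean()) ** 2).sum()
print(np.isclose(m2, direct))  # True
```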
@@ -366,14 +453,25 @@ def _std_finalize(sumsq, sum_, count, ddof=0):
     dtypes=(None, None, np.intp),
     final_dtype=np.floating,
 )
 nanvar = Aggregation(
     "nanvar",
-    chunk=("nansum_of_squares", "nansum", "nanlen"),
-    combine=("sum", "sum", "sum"),
+    chunk=var_chunk,
+    numpy="nanvar",
+    combine=(_var_combine,),
     finalize=_var_finalize,
     fill_value=0,
     final_fill_value=np.nan,
-    dtypes=(None, None, np.intp),
+    dtypes=(None,),
     final_dtype=np.floating,
 )
 std = Aggregation(
diff --git a/flox/core.py b/flox/core.py
index d773d189..8200cb97 100644
--- a/flox/core.py
+++ b/flox/core.py
@@ -45,6 +45,7 @@
     _initialize_aggregation,
     generic_aggregate,
     quantile_new_dims_func,
+    var_chunk,
 )
 from .cache import memoize
 from .lib import ArrayLayer, dask_array_type, sparse_array_type
@@ -1288,7 +1289,8 @@ def chunk_reduce(
     # optimize that out.
     previous_reduction: T_Func = ""
     for reduction, fv, kw, dt in zip(funcs, fill_values, kwargss, dtypes):
-        if empty:
+        # UGLY! but this is because var_chunk breaks our design assumptions
+        if empty and reduction is not var_chunk:
             result = np.full(shape=final_array_shape, fill_value=fv, like=array)
         elif is_nanlen(reduction) and is_nanlen(previous_reduction):
             result = results["intermediates"][-1]
@@ -1297,6 +1299,10 @@ def chunk_reduce(
             kw_func = dict(size=size, dtype=dt, fill_value=fv)
             kw_func.update(kw)
 
+            # UGLY! but this is because var_chunk breaks our design assumptions
+            if reduction is var_chunk:
+                kw_func.update(engine=engine)
+
             if callable(reduction):
                 # passing a custom reduction for npg to apply per-group is really slow!
                 # So this `reduction` has to do the groupby-aggregation
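Reviewer note: the special-casing in `chunk_reduce` exists because `var_chunk` is a single callable producing one `MultiArray` intermediate, rather than one reduction per intermediate, and it needs `engine` forwarded explicitly. Roughly, the call it ends up receiving looks like this (hand-picked values for illustration):

```python
import numpy as np
from flox.aggregations import var_chunk

group_idx = np.array([0, 0, 1, 1, 1])
array = np.array([1.0, 2.0, 3.0, 4.0, 5.0])

# chunk_reduce builds kw_func = dict(size=..., dtype=..., fill_value=...) and,
# for var_chunk only, adds engine=engine before calling the reduction.
result = var_chunk(
    group_idx, array, engine="flox", axis=-1, size=2, fill_value=0, dtype=array.dtype
)
# One MultiArray holding (sum of squared deviations, sums, counts) per group
print(result.arrays)
```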
diff --git a/tests/test_core.py b/tests/test_core.py
index 31c6ab5a..ef98ec04 100644
--- a/tests/test_core.py
+++ b/tests/test_core.py
@@ -2240,3 +2240,29 @@ def test_sparse_nan_fill_value_reductions(chunks, fill_value, shape, func):
     expected = np.expand_dims(npfunc(numpy_array, axis=-1), axis=-1)
     actual, *_ = groupby_reduce(array, by, func=func, axis=-1)
     assert_equal(actual, expected)
+
+
+# "nanvar" uses the updated chunk/combine functions; "var" still uses the old
+# algorithm for now. Expect to expand this to other functions once written.
+@pytest.mark.parametrize("func", ("nanvar", "var"))
+# Expect to expand this to other engines once written
+@pytest.mark.parametrize("engine", ("flox",))
+# The old algorithm should fail by 10e8; the new one should survive 10e12
+@pytest.mark.parametrize("offset", (0, 10e2, 10e4, 10e6, 10e8, 10e10, 10e12))
+def test_std_var_precision(func, engine, offset):
+    # Generate data with small variance and a big mean, then check that the
+    # reduction gives the same answer with and without the offset added.
+    n = 1000
+    array = np.linspace(-1, 1, n)  # has zero mean
+    labels = np.arange(n) % 2  # Ideally we'd parametrize this too.
+
+    no_offset, _ = groupby_reduce(array, labels, engine=engine, func=func)
+    with_offset, _ = groupby_reduce(array + offset, labels, engine=engine, func=func)
+
+    # Not sure yet how stringent to be here; the failure threshold in external
+    # tests depends on the dask chunksize, which may need further exploring.
+    tol = {"rtol": 1e-8, "atol": 1e-10}
+
+    assert_equal(no_offset, with_offset, tol)
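Reviewer note: for anyone reproducing the precision issue this test targets, the old intermediate form `sumsq - sum**2 / n` cancels catastrophically at large offsets, which is exactly what the shifted, chunk-merged formulation avoids. A standalone illustration (assumed numbers, independent of the test suite):

```python
import numpy as np

x = np.linspace(-1, 1, 1000) + 1e12  # tiny variance, huge mean
n = x.size

one_pass = (np.sum(x**2) - np.sum(x) ** 2 / n) / n  # old intermediates: catastrophic cancellation
two_pass = np.sum((x - x.mean()) ** 2) / n          # what var_chunk effectively computes

print(one_pass)   # garbage, can even come out negative
print(two_pass)   # ~0.334
print(np.var(x))  # ~0.334, agrees with two_pass
```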