Skip to content

Commit 7d228b7

Browse files
thuydotmgiancastro
andauthored
Added focal_stats (#453)
* commented out warnings * removed calc_mean and calc_sum from documentation * added focal_stats * added stats functions * added tests * corrected docstrings * flake8 fixes * convolution_2d * flake8 fixes * remove commented code * fix focal docstrings * add convolution_2d and focal_stats to docs Co-authored-by: giancastro <giancastrok@gmail.com>
1 parent d89d302 commit 7d228b7

File tree

4 files changed

+137
-145
lines changed

4 files changed

+137
-145
lines changed

docs/source/reference/focal.rst

100644100755
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,9 @@ Focal Statistics
3131
.. autosummary::
3232
:toctree: _autosummary
3333

34+
xrspatial.convolution.convolution_2d
3435
xrspatial.convolution.annulus_kernel
3536
xrspatial.convolution.calc_cellsize
36-
xrspatial.focal.calc_mean
37-
xrspatial.focal.calc_sum
3837
xrspatial.convolution.circle_kernel
3938
xrspatial.focal.custom_kernel
39+
xrspatial.focal.focal_stats

xrspatial/convolution.py

Lines changed: 34 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
from functools import partial
22

33
import re
4-
import warnings
54

65
import numpy as np
76
import dask.array as da
7+
from xarray import DataArray
88

99
from numba import cuda, float32, prange, jit
1010

@@ -52,13 +52,8 @@ def _get_distance(distance_str):
5252
raise ValueError("Invalid distance.")
5353

5454
unit = DEFAULT_UNIT
55-
if len(splits) == 1:
56-
with warnings.catch_warnings():
57-
warnings.simplefilter('default')
58-
warnings.warn('Raster distance unit not provided. '
59-
'Use meter as default.', Warning)
6055

61-
elif len(splits) == 2:
56+
if len(splits) == 2:
6257
unit = splits[1]
6358

6459
number = splits[0]
@@ -163,10 +158,6 @@ def calc_cellsize(raster):
163158
unit = raster.attrs['unit']
164159
else:
165160
unit = DEFAULT_UNIT
166-
with warnings.catch_warnings():
167-
warnings.simplefilter('default')
168-
warnings.warn('Raster distance unit not provided. '
169-
'Use meter as default.', Warning)
170161

171162
cellsize_x, cellsize_y = get_dataarray_resolution(raster)
172163
cellsize_x = _to_meters(cellsize_x, unit)
@@ -286,25 +277,6 @@ def annulus_kernel(cellsize_x, cellsize_y, outer_radius, inner_radius):
286277
[0., 1., 1., 1., 1., 0., 1., 1., 1., 1., 0.],
287278
[0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]])
288279
"""
289-
# validate radii, convert to meters
290-
r2 = _get_distance(str(outer_radius))
291-
r1 = _get_distance(str(inner_radius))
292-
293-
# Validate that outer radius is indeed outer radius
294-
if r2 > r1:
295-
r_outer = r2
296-
r_inner = r1
297-
else:
298-
r_outer = r1
299-
r_inner = r2
300-
301-
if r_outer - r_inner < np.sqrt((cellsize_x / 2)**2 + (cellsize_y / 2)**2):
302-
with warnings.catch_warnings():
303-
warnings.simplefilter('default')
304-
warnings.warn(
305-
'Annulus radii are closer than cellsize distance.', Warning
306-
)
307-
308280
# Get the two circular kernels for the annulus
309281
kernel_outer = circle_kernel(cellsize_x, cellsize_y, outer_radius)
310282
kernel_inner = circle_kernel(cellsize_x, cellsize_y, inner_radius)
@@ -534,3 +506,35 @@ def convolve_2d(data, kernel):
534506
raise TypeError('Unsupported Array Type: {}'.format(type(data)))
535507

536508
return out
509+
510+
511+
def convolution_2d(agg, kernel):
    """
    Convolve a 2D raster with a kernel and return a DataArray.

    This is a thin wrapper around ``convolve_2d``: the convolution is
    computed on the raw array backing ``agg`` and the output is wrapped
    in a new DataArray that reuses ``agg``'s coordinates, dimensions and
    attributes. Convolution is frequently used for image processing,
    such as smoothing, sharpening, and edge detection.

    Parameters
    ----------
    agg : xarray.DataArray
        2D array of values to be convolved.
    kernel : array-like object
        Impulse kernel, determines the area over which the impulse
        function is applied for each cell.

    Returns
    -------
    convolve_agg : xarray.DataArray
        2D array representation of the impulse function, carrying the
        same coords, dims and attrs as ``agg``.
    """

    # the numerical work is delegated entirely to convolve_2d
    convolved = convolve_2d(agg.data, kernel)

    # re-attach the spatial metadata of the input raster
    metadata = dict(coords=agg.coords, dims=agg.dims, attrs=agg.attrs)
    return DataArray(convolved, **metadata)

xrspatial/focal.py

Lines changed: 79 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,19 @@
1-
from functools import partial
1+
import numpy as np
2+
import pandas as pd
3+
import xarray as xr
4+
import dask.array as da
25

6+
from functools import partial
37
from math import isnan
4-
5-
from numba import prange
6-
import numpy as np
8+
from numba import prange, cuda
79
from xarray import DataArray
8-
import dask.array as da
910

1011
try:
1112
import cupy
1213
except ImportError:
1314
class cupy(object):
1415
ndarray = False
1516

16-
from numba import cuda
1717
from xrspatial.utils import cuda_args
1818
from xrspatial.utils import has_cuda
1919
from xrspatial.utils import is_cupy_backed
@@ -208,119 +208,40 @@ def mean(agg, passes=1, excludes=[np.nan], name='mean'):
208208

209209

210210
@ngjit
211-
def calc_mean(array):
212-
"""
213-
Calculates the mean of an array.
211+
def _calc_mean(array):
212+
return np.nanmean(array)
214213

215-
Parameters
216-
----------
217-
array : numpy.Array
218-
Array of input values.
219214

220-
Returns
221-
-------
222-
array_sum : float
223-
Mean of input data.
215+
@ngjit
def _calc_sum(array):
    """Sum of ``array`` computed with ``np.nansum`` (NaN-aware)."""
    total = np.nansum(array)
    return total
224218

225-
Examples
226-
--------
227-
.. sourcecode:: python
228219

229-
>>> from xrspatial.focal import calc_mean
230-
>>> import numpy as np
231-
232-
>>> # 1D Array of Integers
233-
>>> array1 = np.array([1, 2, 3, 4, 5])
234-
>>> print(array1)
235-
[1 2 3 4 5]
236-
237-
>>> # Calculate Mean
238-
>>> array_mean = calc_mean(array1)
239-
>>> print(array_mean)
240-
3.0
241-
242-
>>> # 2D Array of Floats
243-
>>> array2 = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
244-
>>> print(array2)
245-
[[1. 2. 3.]
246-
[4. 5. 6.]]
247-
248-
>>> # Calculate Mean
249-
>>> array_mean = calc_mean(array2)
250-
>>> print(array_mean)
251-
3.5
252-
253-
>>> # 3D Array of Integers
254-
>>> array3 = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
255-
>>> print(array3)
256-
[[1 2 3]
257-
[4 5 6]
258-
[7 8 9]]
259-
260-
>>> # Calculate Mean
261-
>>> array_mean = calc_mean(array3)
262-
>>> print(array_mean)
263-
5.0
264-
"""
265-
return np.nanmean(array)
220+
@ngjit
def _calc_min(array):
    """Minimum of ``array`` computed with ``np.nanmin`` (NaN-aware)."""
    smallest = np.nanmin(array)
    return smallest
266223

267224

268225
@ngjit
269-
def calc_sum(array):
270-
"""
271-
Calculates the sum of an array.
226+
def _calc_max(array):
227+
return np.nanmax(array)
272228

273-
Parameters
274-
----------
275-
array : numpy.Array
276-
Array of input values.
277229

278-
Returns
279-
-------
280-
array_sum : float
281-
Sum of input data.
230+
@ngjit
def _calc_std(array):
    """Standard deviation of ``array`` via ``np.nanstd`` (NaN-aware)."""
    deviation = np.nanstd(array)
    return deviation
282233

283-
Examples
284-
--------
285-
.. sourcecode:: python
286234

287-
>>> from xrspatial.focal import calc_sum
288-
>>> import numpy as np
289-
290-
>>> # 1D Array of Integers
291-
>>> array1 = np.array([1, 2, 3, 4, 5])
292-
>>> print(array1)
293-
[1 2 3 4 5]
294-
295-
>>> # Calculate Sum
296-
>>> array_sum = calc_sum(array1)
297-
>>> print(array_sum)
298-
15
299-
300-
>>> # 2D Array of Floats
301-
>>> array2 = np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]])
302-
>>> print(array2)
303-
[[1. 2. 3.]
304-
[4. 5. 6.]]
305-
306-
>>> # Calculate Sum
307-
>>> array_sum = calc_sum(array2)
308-
>>> print(array_sum)
309-
21.0
310-
311-
>>> # 3D Array of Integers
312-
>>> array3 = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
313-
>>> print(array3)
314-
[[1 2 3]
315-
[4 5 6]
316-
[7 8 9]]
317-
318-
>>> # Calculate Sum
319-
>>> agg_sum = calc_sum(array3)
320-
>>> print(agg_sum)
321-
45
322-
"""
323-
return np.nansum(array)
235+
@ngjit
def _calc_range(array):
    """Spread of ``array``: ``_calc_max(array)`` minus ``_calc_min(array)``."""
    return _calc_max(array) - _calc_min(array)
240+
241+
242+
@ngjit
def _calc_var(array):
    """Variance of ``array`` computed with ``np.nanvar`` (NaN-aware)."""
    variance = np.nanvar(array)
    return variance
324245

325246

326247
@ngjit
@@ -419,17 +340,17 @@ def _apply(data, kernel, func):
419340
return out
420341

421342

422-
def apply(raster, kernel, func=calc_mean):
343+
def apply(raster, kernel, func=_calc_mean):
423344
"""
424-
Returns Mean filtered array using a user-created window.
345+
Returns custom function applied array using a user-created window.
425346
426347
Parameters
427348
----------
428349
raster : xarray.DataArray
429350
2D array of input values to be filtered.
430-
kernel : Numpy Array
351+
kernel : numpy.array
431352
2D array where values of 1 indicate the kernel.
432-
func : xrspatial.focal.calc_mean
353+
func : callable, default=xrspatial.focal._calc_mean
433354
Function which takes an input array and returns an array.
434355
435356
Returns
@@ -547,6 +468,52 @@ def apply(raster, kernel, func=calc_mean):
547468
return result
548469

549470

471+
def focal_stats(agg,
                kernel,
                stats_funcs=[
                    'mean', 'max', 'min', 'range', 'std', 'var', 'sum'
                ]):
    """
    Calculates statistics of the values within a specified focal
    neighborhood for each pixel in an input raster. The statistics
    types are Mean, Maximum, Minimum, Range, Standard deviation,
    Variance and Sum.

    Parameters
    ----------
    agg : xarray.DataArray
        2D array of input values to be analysed.
    kernel : numpy.array
        2D array where values of 1 indicate the kernel.
    stats_funcs : iterable of str
        Statistics types to be calculated, in the order they should
        appear in the output. Supported names are 'mean', 'max', 'min',
        'range', 'std', 'var' and 'sum'.
        Default set to ['mean', 'max', 'min', 'range', 'std', 'var', 'sum'].

    Returns
    -------
    stats_agg : xarray.DataArray of same type as `agg`
        3D array holding one focal result per requested statistic,
        concatenated along a new leading dimension named `stats` whose
        labels are the requested statistic names.

    Raises
    ------
    KeyError
        If `stats_funcs` contains a name that is not a supported
        statistic.
    """

    _function_mapping = {
        'mean': _calc_mean,
        'max': _calc_max,
        'min': _calc_min,
        'range': _calc_range,
        'std': _calc_std,
        'var': _calc_var,
        'sum': _calc_sum
    }

    # The names are needed twice (once to compute, once to label the new
    # dimension), so take a private copy. This keeps one-shot iterables
    # such as generators working and guarantees the shared default list
    # is never handed out or mutated.
    stats_names = list(stats_funcs)

    # Fail fast, before any (possibly expensive) focal pass is run, and
    # tell the caller which names are wrong and which are accepted.
    unknown = [name for name in stats_names if name not in _function_mapping]
    if unknown:
        raise KeyError(
            'Unsupported stats_funcs {}; expected any of {}'.format(
                unknown, list(_function_mapping)
            )
        )

    stats_aggs = [
        apply(agg, kernel, func=_function_mapping[name])
        for name in stats_names
    ]

    return xr.concat(stats_aggs, pd.Index(stats_names, name='stats'))
515+
516+
550517
@ngjit
551518
def _calc_hotspots_numpy(z_array):
552519
out = np.zeros_like(z_array, dtype=np.int8)

xrspatial/tests/test_focal.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from xrspatial.utils import ngjit
99

1010
from xrspatial import mean
11-
from xrspatial.focal import hotspots, apply
11+
from xrspatial.focal import hotspots, apply, focal_stats
1212
from xrspatial.convolution import (
1313
convolve_2d, calc_cellsize, circle_kernel, annulus_kernel
1414
)
@@ -287,6 +287,27 @@ def func_zero_cpu(x):
287287
assert e_info
288288

289289

290+
def test_focal_stats_cpu():
    # one 4x4 raster, exposed through a numpy-backed and a chunked
    # dask-backed DataArray so both backends see identical input
    values = np.arange(16).reshape(4, 4)
    numpy_agg = xr.DataArray(values)
    dask_numpy_agg = xr.DataArray(da.from_array(values, chunks=(3, 3)))

    # circular kernel of radius 1.5 on a unit-cellsize grid
    kernel = circle_kernel(1, 1, 1.5)

    # numpy case: result stays numpy-backed, gains one leading
    # dimension and keeps the raster shape on the trailing two
    numpy_focalstats = focal_stats(numpy_agg, kernel)
    assert isinstance(numpy_focalstats.data, np.ndarray)
    assert numpy_focalstats.ndim == 3
    assert numpy_focalstats.shape[1:] == numpy_agg.shape

    # dask case: result stays lazy / dask-backed
    dask_numpy_focalstats = focal_stats(dask_numpy_agg, kernel)
    assert isinstance(dask_numpy_focalstats.data, da.Array)

    # both backends must agree cell by cell, NaNs included
    matches = np.isclose(
        numpy_focalstats, dask_numpy_focalstats.compute(), equal_nan=True
    )
    assert matches.all()
309+
310+
290311
def test_hotspot():
291312
n, m = 10, 10
292313
data = np.zeros((n, m), dtype=float)

0 commit comments

Comments
 (0)