diff --git a/python/stempy/image/__init__.py b/python/stempy/image/__init__.py index dee8d95d..94db98a0 100644 --- a/python/stempy/image/__init__.py +++ b/python/stempy/image/__init__.py @@ -404,7 +404,7 @@ def com_v1_kernel( position_indices: np.ndarray, scan_shape: Tuple[int, int], frame_shape: Tuple[int, int], - crop_to: Union[Tuple[int, int], None] = None, + crop_to: Union[int, None] = None, init_center: Union[Tuple[int, int], None] = None, replace_nans: bool = True, ) -> np.ndarray: @@ -432,12 +432,8 @@ def com_v1_kernel( # Cropping if crop_to is not None: if init_center is not None: - xmin = init_center[0] - crop_to[0] - xmax = init_center[0] + crop_to[0] - ymin = init_center[1] - crop_to[1] - ymax = init_center[1] + crop_to[1] - - mask = (x > xmin) & (x <= xmax) & (y > ymin) & (y <= ymax) + r = np.sqrt((x - init_center[0])**2 + (y - init_center[1])**2) + mask = (r < crop_to) position_indices = position_indices[mask] x = x[mask] y = y[mask] @@ -454,13 +450,8 @@ def com_v1_kernel( event_center_x = centers_x[position_indices] event_center_y = centers_y[position_indices] - - mask = ( - (x > event_center_x - crop_to[0]) - & (x <= event_center_x + crop_to[0]) - & (y > event_center_y - crop_to[1]) - & (y <= event_center_y + crop_to[1]) - ) + r = np.sqrt((x - event_center_x)**2 + (y - event_center_y)**2) + mask = (r < crop_to) position_indices = position_indices[mask] x = x[mask] y = y[mask] @@ -552,8 +543,8 @@ def _com_sparse_v0(array, crop_to=None, init_center=None, replace_nans=True): if crop_to is not None: # Crop around the initial center - keep = (x > comx0 - crop_to[0]) & (x <= comx0 + crop_to[0]) & (y > comy0 - crop_to[1]) & ( - y <= comy0 + crop_to[1]) + r = np.sqrt((x - comx0)**2 + (y - comy0)**2) + keep = (r < crop_to) x = x[keep] y = y[keep] mm = len(x) @@ -594,9 +585,9 @@ def com_sparse( :param array: A SparseArray of the electron counted data :type array: stempy.io.SparseArray - :param crop_to: optional; The size of the region to crop around initial full frame COM for improved COM near + :param crop_to: optional; The radius from the center to crop around initial full frame COM for improved COM near the zero beam - :type crop_to: tuple of ints of length 2 + :type crop_to: int :param init_center: optional; The initial center to use before cropping. If this is not set then cropping will be applied around the center of mass of the each full frame. :type init_center: tuple of ints of length 2 diff --git a/tests/test_image.py b/tests/test_image.py index a964c41c..731836a4 100644 --- a/tests/test_image.py +++ b/tests/test_image.py @@ -58,12 +58,12 @@ def test_com_sparse_parameters(simulate_sparse_array, version): assert round(com0[0,].mean()) == 30 # Test crop_to input. Initial COM should be full frame COM - com1 = com_sparse(sp, crop_to=(10, 10), version=version) + com1 = com_sparse(sp, crop_to=10, version=version) assert round(com1[0,].mean()) == 30 # Test crop_to and init_center input. # No counts will be in the center so all positions will be np.nan - com2 = com_sparse(sp, crop_to=(10, 10), init_center=(1, 1), version=version) + com2 = com_sparse(sp, crop_to=10, init_center=(1, 1), version=version) assert np.isnan(com2[0,0,0]) @@ -75,17 +75,17 @@ def test_com_sparse_version_comparison(sparse_array_small): {"crop_to": None, "init_center": None, "replace_nans": True}, {"crop_to": None, "init_center": None, "replace_nans": False}, { - "crop_to": (frame_x // 2, frame_y // 2), + "crop_to": frame_x // 2, "init_center": None, "replace_nans": True, }, { - "crop_to": (frame_x // 2, frame_y // 2), + "crop_to": frame_x // 2, "init_center": (frame_x // 2, frame_y // 2), "replace_nans": True, }, { - "crop_to": (frame_x // 2, frame_y // 2), + "crop_to": frame_x // 2, "init_center": (frame_x // 2, frame_y // 2), "replace_nans": False, },