diff --git a/py4DSTEM/process/phase/parallax.py b/py4DSTEM/process/phase/parallax.py index d3d41447c..e490f1a85 100644 --- a/py4DSTEM/process/phase/parallax.py +++ b/py4DSTEM/process/phase/parallax.py @@ -745,15 +745,17 @@ def preprocess( if force_transpose: force_rotation_angle_deg *= -1 - aberrations_basis, aberrations_basis_du, aberrations_basis_dv = ( - calculate_aberration_gradient_basis( - aberrations_mn, - sampling, - self._region_of_interest_shape, - self._wavelength, - rotation_angle=np.deg2rad(force_rotation_angle_deg), - xp=xp, - ) + ( + aberrations_basis, + aberrations_basis_du, + aberrations_basis_dv, + ) = calculate_aberration_gradient_basis( + aberrations_mn, + sampling, + self._region_of_interest_shape, + self._wavelength, + rotation_angle=np.deg2rad(force_rotation_angle_deg), + xp=xp, ) # shifts @@ -828,7 +830,6 @@ def preprocess( ) else: - self._recon_BF = ( self._stack_mean * mask_inv + xp.mean(self._stack_BF_shifted * self._stack_mask, axis=0) @@ -2392,14 +2393,18 @@ def aberration_fit( asnumpy = self._asnumpy # Initial estimate - shifts_Ang, rotation_rad, aberrations_C1, aberrations_C12a, aberrations_C12b = ( - self._aberration_fit_polar_decomposition( - self._xy_shifts, - self._scan_sampling, - self._probe_angles, - force_transpose=force_transpose, - force_rotation_angle_deg=force_rotation_angle_deg, - ) + ( + shifts_Ang, + rotation_rad, + aberrations_C1, + aberrations_C12a, + aberrations_C12b, + ) = self._aberration_fit_polar_decomposition( + self._xy_shifts, + self._scan_sampling, + self._probe_angles, + force_transpose=force_transpose, + force_rotation_angle_deg=force_rotation_angle_deg, ) self.aberrations_C1 = aberrations_C1 @@ -2532,14 +2537,15 @@ def calculate_CTF(alpha_shape, *coefs): ) ) - self._aberrations_coefs, fitted_shifts_Ang = ( - self._aberration_fit_deltas_and_increment( - shifts_Ang, - fitted_shifts_Ang, - gradients, - self._aberrations_coefs, - indices, - ) + ( + self._aberrations_coefs, + fitted_shifts_Ang, + ) = self._aberration_fit_deltas_and_increment( + shifts_Ang, + fitted_shifts_Ang, + gradients, + self._aberrations_coefs, + indices, ) if force_transpose: @@ -2689,7 +2695,6 @@ def calculate_CTF(alpha_shape, *coefs): row_index += 1 if plot_BF_shifts_comparison: - scale_arrows = kwargs.pop("scale_arrows", 1) plot_arrow_freq = kwargs.pop("plot_arrow_freq", 1) @@ -2982,7 +2987,9 @@ def depth_section( if use_CTF_fit: sin_chi = xp.sin( - self._calculate_CTF((nx, ny), (sx, sy), *self._aberrations_coefs) + self._calculate_CTF( + (nx, ny), (sx, sy), self._aberrations_mn, self._aberrations_coefs + ) ) else: sin_chi = xp.sin((xp.pi * self._wavelength * self.aberrations_C1) * kra2)