Skip to content

Commit 04e2b6b

Browse files
authored
Merge pull request #706 from tcpekin/small_bug
Small typo, docstring update, and fixed a bug.
2 parents 9f46139 + ec01593 commit 04e2b6b

File tree

1 file changed

+8
-11
lines changed

1 file changed

+8
-11
lines changed

py4DSTEM/process/diffraction/crystal_viz.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -30,19 +30,19 @@ def plot_structure(
3030
camera_dist: Optional[float] = None,
3131
show_axes: bool = False,
3232
perspective_axes: bool = True,
33-
figsize: Union[tuple, list, np.ndarray] = (8, 8),
33+
figsize: tuple = (8, 8),
3434
returnfig: bool = False,
3535
):
3636
"""
37-
Quick 3D plot of the untit cell /atomic structure.
37+
Quick 3D plot of the unit cell/atomic structure.
3838
3939
Args:
4040
orientation_matrix (array): (3,3) orientation matrix, where columns represent projection directions.
4141
zone_axis_lattice (array): (3,) projection direction in lattice indices
4242
proj_x_lattice (array): (3,) x-axis direction in lattice indices
4343
zone_axis_cartesian (array): (3,) cartesian projection direction
4444
proj_x_cartesian (array): (3,) cartesian projection direction
45-
scale_markers (float): Size scaling for markers
45+
size_marker (float): Size scaling for markers
4646
tol_distance (float): Tolerance for repeating atoms on edges on cell boundaries.
4747
plot_limit (float): (2,3) numpy array containing x y z plot min and max in columns.
4848
Default is 1.1* unit cell dimensions.
@@ -98,20 +98,17 @@ def plot_structure(
9898
sub = pos[:, 0] < tol_distance
9999
pos = np.vstack([pos, pos[sub, :] + np.array([1, 0, 0])])
100100
ID = np.hstack([ID, ID[sub]])
101-
if occ is not None:
102-
occ = np.hstack([occ, occ[sub]])
101+
occ = np.hstack([occ, occ[sub]])
103102
# y tile
104103
sub = pos[:, 1] < tol_distance
105104
pos = np.vstack([pos, pos[sub, :] + np.array([0, 1, 0])])
106105
ID = np.hstack([ID, ID[sub]])
107-
if occ is not None:
108-
occ = np.hstack([occ, occ[sub]])
106+
occ = np.hstack([occ, occ[sub]])
109107
# z tile
110108
sub = pos[:, 2] < tol_distance
111109
pos = np.vstack([pos, pos[sub, :] + np.array([0, 0, 1])])
112110
ID = np.hstack([ID, ID[sub]])
113-
if occ is not None:
114-
occ = np.hstack([occ, occ[sub]])
111+
occ = np.hstack([occ, occ[sub]])
115112

116113
# Cartesian atomic positions
117114
xyz = pos @ self.lat_real
@@ -150,7 +147,7 @@ def plot_structure(
150147

151148
# atoms
152149
ID_all = np.unique(ID)
153-
if occ is None:
150+
if np.all(occ == 1.0):
154151
for ID_plot in ID_all:
155152
sub = ID == ID_plot
156153
ax.scatter(
@@ -166,7 +163,7 @@ def plot_structure(
166163
# init
167164
tol = 1e-4
168165
num_seg = 180
169-
radius = 0.7
166+
radius = size_marker / 800
170167
zp = np.zeros(num_seg + 1)
171168

172169
mark = np.ones(xyz.shape[0], dtype="bool")

0 commit comments

Comments
 (0)