@@ -30,19 +30,19 @@ def plot_structure(
3030 camera_dist : Optional [float ] = None ,
3131 show_axes : bool = False ,
3232 perspective_axes : bool = True ,
33- figsize : Union [ tuple , list , np . ndarray ] = (8 , 8 ),
33+ figsize : tuple = (8 , 8 ),
3434 returnfig : bool = False ,
3535):
3636 """
37- Quick 3D plot of the untit cell /atomic structure.
37+ Quick 3D plot of the unit cell/atomic structure.
3838
3939 Args:
4040 orientation_matrix (array): (3,3) orientation matrix, where columns represent projection directions.
4141 zone_axis_lattice (array): (3,) projection direction in lattice indices
4242 proj_x_lattice (array): (3,) x-axis direction in lattice indices
4343 zone_axis_cartesian (array): (3,) cartesian projection direction
4444 proj_x_cartesian (array): (3,) cartesian projection direction
45- scale_markers (float): Size scaling for markers
45+ size_marker (float): Size scaling for markers
4646 tol_distance (float): Tolerance for repeating atoms on edges on cell boundaries.
4747 plot_limit (float): (2,3) numpy array containing x y z plot min and max in columns.
4848 Default is 1.1* unit cell dimensions.
@@ -98,20 +98,17 @@ def plot_structure(
9898 sub = pos [:, 0 ] < tol_distance
9999 pos = np .vstack ([pos , pos [sub , :] + np .array ([1 , 0 , 0 ])])
100100 ID = np .hstack ([ID , ID [sub ]])
101- if occ is not None :
102- occ = np .hstack ([occ , occ [sub ]])
101+ occ = np .hstack ([occ , occ [sub ]])
103102 # y tile
104103 sub = pos [:, 1 ] < tol_distance
105104 pos = np .vstack ([pos , pos [sub , :] + np .array ([0 , 1 , 0 ])])
106105 ID = np .hstack ([ID , ID [sub ]])
107- if occ is not None :
108- occ = np .hstack ([occ , occ [sub ]])
106+ occ = np .hstack ([occ , occ [sub ]])
109107 # z tile
110108 sub = pos [:, 2 ] < tol_distance
111109 pos = np .vstack ([pos , pos [sub , :] + np .array ([0 , 0 , 1 ])])
112110 ID = np .hstack ([ID , ID [sub ]])
113- if occ is not None :
114- occ = np .hstack ([occ , occ [sub ]])
111+ occ = np .hstack ([occ , occ [sub ]])
115112
116113 # Cartesian atomic positions
117114 xyz = pos @ self .lat_real
@@ -150,7 +147,7 @@ def plot_structure(
150147
151148 # atoms
152149 ID_all = np .unique (ID )
153- if occ is None :
150+ if np . all ( occ == 1.0 ) :
154151 for ID_plot in ID_all :
155152 sub = ID == ID_plot
156153 ax .scatter (
@@ -166,7 +163,7 @@ def plot_structure(
166163 # init
167164 tol = 1e-4
168165 num_seg = 180
169- radius = 0.7
166+ radius = size_marker / 800
170167 zp = np .zeros (num_seg + 1 )
171168
172169 mark = np .ones (xyz .shape [0 ], dtype = "bool" )
0 commit comments