@@ -30,7 +30,7 @@ def plot_structure(
3030 camera_dist : Optional [float ] = None ,
3131 show_axes : bool = False ,
3232 perspective_axes : bool = True ,
33- figsize : Union [ tuple , list , np . ndarray ] = (8 , 8 ),
33+ figsize : tuple = (8 , 8 ),
3434 returnfig : bool = False ,
3535):
3636 """
@@ -98,20 +98,17 @@ def plot_structure(
9898 sub = pos [:, 0 ] < tol_distance
9999 pos = np .vstack ([pos , pos [sub , :] + np .array ([1 , 0 , 0 ])])
100100 ID = np .hstack ([ID , ID [sub ]])
101- if occ is not None :
102- occ = np .hstack ([occ , occ [sub ]])
101+ occ = np .hstack ([occ , occ [sub ]])
103102 # y tile
104103 sub = pos [:, 1 ] < tol_distance
105104 pos = np .vstack ([pos , pos [sub , :] + np .array ([0 , 1 , 0 ])])
106105 ID = np .hstack ([ID , ID [sub ]])
107- if occ is not None :
108- occ = np .hstack ([occ , occ [sub ]])
106+ occ = np .hstack ([occ , occ [sub ]])
109107 # z tile
110108 sub = pos [:, 2 ] < tol_distance
111109 pos = np .vstack ([pos , pos [sub , :] + np .array ([0 , 0 , 1 ])])
112110 ID = np .hstack ([ID , ID [sub ]])
113- if occ is not None :
114- occ = np .hstack ([occ , occ [sub ]])
111+ occ = np .hstack ([occ , occ [sub ]])
115112
116113 # Cartesian atomic positions
117114 xyz = pos @ self .lat_real
@@ -150,7 +147,7 @@ def plot_structure(
150147
151148 # atoms
152149 ID_all = np .unique (ID )
153- if np .all (np . asarray ( occ ) == 1 ):
150+ if np .all (occ == 1.0 ):
154151 for ID_plot in ID_all :
155152 sub = ID == ID_plot
156153 ax .scatter (
0 commit comments