11#### this file contains a function/s that will check if various
22# libaries/compute options are available
33import importlib
4- from operator import mod
5-
6- # list of modules we expect/may expect to be installed
7- # as part of a standard py4DSTEM installation
8- # this needs to be the import name e.g. import mp_api not mp-api
9- modules = [
10- "crystal4D" ,
11- "cupy" ,
12- "dask" ,
13- "dill" ,
14- "distributed" ,
15- "gdown" ,
16- "h5py" ,
17- "ipyparallel" ,
18- "jax" ,
19- "matplotlib" ,
20- "mp_api" ,
21- "ncempy" ,
22- "numba" ,
23- "numpy" ,
24- "pymatgen" ,
25- "skimage" ,
26- "sklearn" ,
27- "scipy" ,
28- "tensorflow" ,
29- "tensorflow-addons" ,
30- "tqdm" ,
31- ]
32-
33- # currently this was copy and pasted from setup.py,
34- # hopefully there's a programatic way to do this.
35- module_depenencies = {
36- "base" : [
37- "numpy" ,
38- "scipy" ,
39- "h5py" ,
40- "ncempy" ,
41- "matplotlib" ,
42- "skimage" ,
43- "sklearn" ,
44- "tqdm" ,
45- "dill" ,
46- "gdown" ,
47- "dask" ,
48- "distributed" ,
49- ],
50- "ipyparallel" : ["ipyparallel" , "dill" ],
51- "cuda" : ["cupy" ],
52- "acom" : ["pymatgen" , "mp_api" ],
53- "aiml" : ["tensorflow" , "tensorflow-addons" , "crystal4D" ],
54- "aiml-cuda" : ["tensorflow" , "tensorflow-addons" , "crystal4D" , "cupy" ],
55- "numba" : ["numba" ],
4+ from importlib .metadata import requires
5+ import re
6+ from importlib .util import find_spec
7+
8+ # need a mapping of pypi/conda names to import names
9+ import_mapping_dict = {
10+ "scikit-image" : "skimage" ,
11+ "scikit-learn" : "sklearn" ,
12+ "scikit-optimize" : "skopt" ,
13+ "mp-api" : "mp_api" ,
5614}
5715
5816
17+ # programatically get all possible requirements in the import name style
18+ def get_modules_list ():
19+ # Get the dependencies from the installed distribution
20+ dependencies = requires ("py4DSTEM" )
21+
22+ # Define a regular expression pattern for splitting on '>', '>=', '='
23+ delimiter_pattern = re .compile (r">=|>|==|<|<=" )
24+
25+ # Extract only the module names without versions
26+ module_names = [
27+ delimiter_pattern .split (dependency .split (";" )[0 ], 1 )[0 ].strip ()
28+ for dependency in dependencies
29+ ]
30+
31+ # translate pypi names to import names e.g. scikit-image->skimage, mp-api->mp_api
32+ for index , module in enumerate (module_names ):
33+ if module in import_mapping_dict .keys ():
34+ module_names [index ] = import_mapping_dict [module ]
35+
36+ return module_names
37+
38+
39+ # programatically get all possible requirements in the import name style,
40+ # split into a dict where optional import names are keys
41+ def get_modules_dict ():
42+ package_name = "py4DSTEM"
43+ # Get the dependencies from the installed distribution
44+ dependencies = requires (package_name )
45+
46+ # set the dictionary for modules and packages to go into
47+ # optional dependencies will be added as they are discovered
48+ modules_dict = {
49+ "base" : [],
50+ }
51+ # loop over the dependencies
52+ for depend in dependencies :
53+ # all the optional have extra in the name
54+ # if its not there append it to base
55+ if "extra" not in depend :
56+ # String looks like: 'numpy>=1.19'
57+ modules_dict ["base" ].append (depend )
58+
59+ # if it has extra in the string
60+ else :
61+ # get the name of the optional name
62+ # depend looks like this 'numba>=0.49.1; extra == "numba"'
63+ # grab whatever is in the double quotes i.e. numba
64+ optional_name = re .search (r'"(.*?)"' , depend ).group (1 )
65+ # if the optional name is not in the dict as a key i.e. first requirement of hte optional dependency
66+ if optional_name not in modules_dict :
67+ modules_dict [optional_name ] = [depend ]
68+ # if the optional_name is already in the dict then just append it to the list
69+ else :
70+ modules_dict [optional_name ].append (depend )
71+ # STRIP all the versioning and semi-colons
72+ # Define a regular expression pattern for splitting on '>', '>=', '='
73+ delimiter_pattern = re .compile (r">=|>|==|<|<=" )
74+ for key , val in modules_dict .items ():
75+ # modules_dict[key] = [dependency.split(';')[0].split(' ')[0] for dependency in val]
76+ modules_dict [key ] = [
77+ delimiter_pattern .split (dependency .split (";" )[0 ], 1 )[0 ].strip ()
78+ for dependency in val
79+ ]
80+
81+ # translate pypi names to import names e.g. scikit-image->skimage, mp-api->mp_api
82+ for key , val in modules_dict .items ():
83+ for index , module in enumerate (val ):
84+ if module in import_mapping_dict .keys ():
85+ val [index ] = import_mapping_dict [module ]
86+
87+ return modules_dict
88+
89+
90+ # module_depenencies = get_modules_dict()
91+ modules = get_modules_list ()
92+
93+
5994#### Class and Functions to Create Coloured Strings ####
6095class colours :
6196 CEND = "\x1b [0m"
@@ -140,6 +175,7 @@ def create_underline(s: str) -> str:
140175### here I use the term state to define a boolean condition as to whether a libary/module was sucessfully imported/can be used
141176
142177
178+ # get the state of each modules as a dict key-val e.g. "numpy" : True
143179def get_import_states (modules : list = modules ) -> dict :
144180 """
145181 Check the ability to import modules and store the results as a boolean value. Returns as a dict.
@@ -163,16 +199,17 @@ def get_import_states(modules: list = modules) -> dict:
163199 return import_states_dict
164200
165201
202+ # Check
166203def get_module_states (state_dict : dict ) -> dict :
167- """_summary_
168-
169- Args:
170- state_dict (dict): _description_
204+ """
205+ given a state dict for all modules e.g. "numpy" : True,
206+ this parses through and checks if all modules required for a state are true
171207
172- Returns:
173- dict: _description_
208+ returns dict "base": True, "ai-ml": False etc.
174209 """
175210
211+ # get the modules_dict
212+ module_depenencies = get_modules_dict ()
176213 # create an empty dict to put module states into:
177214 module_states = {}
178215
@@ -196,13 +233,12 @@ def get_module_states(state_dict: dict) -> dict:
196233
197234
198235def print_import_states (import_states : dict ) -> None :
199- """_summary_
200-
201- Args:
202- import_states (dict): _description_
236+ """
237+ print with colours if the library could be imported or not
238+ takes dict
239+ "numpy" : True -> prints success
240+ "pymatgen" : False -> prints failure
203241
204- Returns:
205- _type_: _description_
206242 """
207243 # m is the name of the import module
208244 # state is whether it was importable
@@ -223,13 +259,11 @@ def print_import_states(import_states: dict) -> None:
223259
224260
225261def print_module_states (module_states : dict ) -> None :
226- """_summary_
227-
228- Args:
229- module_states (dict): _description_
230-
231- Returns:
232- _type_: _description_
262+ """
263+ print with colours if all the imports required for module could be imported or not
264+ takes dict
265+ "base" : True -> prints success
266+ "ai-ml" : Fasle -> prints failure
233267 """
234268 # Print out the state of all the modules in colour code
235269 # key is the name of a py4DSTEM Module
@@ -248,25 +282,33 @@ def print_module_states(module_states: dict) -> None:
248282 return None
249283
250284
251- def perfrom_extra_checks (
285+ def perform_extra_checks (
252286 import_states : dict , verbose : bool , gratuitously_verbose : bool , ** kwargs
253287) -> None :
254288 """_summary_
255289
256290 Args:
257- import_states (dict): _description_
258- verbose (bool): _description_
259- gratuitously_verbose (bool): _description_
291+ import_states (dict): dict of modules and if they could be imported or not
292+ verbose (bool): will show module states and all import states
293+ gratuitously_verbose (bool): will run extra checks - Currently only for cupy
260294
261295 Returns:
262296 _type_: _description_
263297 """
264-
265- # print a output module
266- extra_checks_message = "Running Extra Checks"
267- extra_checks_message = create_bold (extra_checks_message )
268- print (f"{ extra_checks_message } " )
269- # For modules that import run any extra checks
298+ if gratuitously_verbose :
299+ # print a output module
300+ extra_checks_message = "Running Extra Checks"
301+ extra_checks_message = create_bold (extra_checks_message )
302+ print (f"{ extra_checks_message } " )
303+ # For modules that import run any extra checks
304+ # get all the dependencies
305+ dependencies = requires ("py4DSTEM" )
306+ # Extract only the module names with versions
307+ depends_with_requirements = [
308+ dependency .split (";" )[0 ] for dependency in dependencies
309+ ]
310+ # print(depends_with_requirements)
311+ # need to go from
270312 for key , val in import_states .items ():
271313 if val :
272314 # s = create_underline(key.capitalize())
@@ -281,7 +323,10 @@ def perfrom_extra_checks(
281323 if gratuitously_verbose :
282324 s = create_underline (key .capitalize ())
283325 print (s )
284- print_no_extra_checks (key )
326+ # check
327+ generic_versions (
328+ key , depends_with_requires = depends_with_requirements
329+ )
285330 else :
286331 pass
287332
@@ -304,7 +349,7 @@ def import_tester(m: str) -> bool:
304349 # try and import the module
305350 try :
306351 importlib .import_module (m )
307- except :
352+ except Exception :
308353 state = False
309354
310355 return state
@@ -324,6 +369,7 @@ def check_module_functionality(state_dict: dict) -> None:
324369
325370 # create an empty dict to put module states into:
326371 module_states = {}
372+ module_depenencies = get_modules_dict ()
327373
328374 # key is the name of the module e.g. ACOM
329375 # val is a list of its dependencies
@@ -359,6 +405,45 @@ def check_module_functionality(state_dict: dict) -> None:
359405#### ADDTIONAL CHECKS ####
360406
361407
408+ def generic_versions (module : str , depends_with_requires : list [str ]) -> None :
409+ # module will be like numpy, skimage
410+ # depends_with_requires look like: numpy >= 19.0, scikit-image
411+ # get module_translated_name
412+ # mapping scikit-image : skimage
413+ for key , value in import_mapping_dict .items ():
414+ # if skimage == skimage get scikit-image
415+ # print(f"{key = } - {value = } - {module = }")
416+ if module in value :
417+ module_depend_name = key
418+ break
419+ else :
420+ # if cant find mapping set the search name to the same
421+ module_depend_name = module
422+ # print(f"{module_depend_name = }")
423+ # find the requirement
424+ for depend in depends_with_requires :
425+ if module_depend_name in depend :
426+ spec_required = depend
427+ # print(f"{spec_required = }")
428+ # get the version installed
429+ spec_installed = find_spec (module )
430+ if spec_installed is None :
431+ s = f"{ module } unable to import - { spec_required } required"
432+ s = create_failure (s )
433+ s = f"{ s : <80} "
434+ print (s )
435+
436+ else :
437+ try :
438+ version = importlib .metadata .version (module_depend_name )
439+ except Exception :
440+ version = "Couldn't test version"
441+ s = f"{ module } imported: { version = } - { spec_required } required"
442+ s = create_warning (s )
443+ s = f"{ s : <80} "
444+ print (s )
445+
446+
362447def check_cupy_gpu (gratuitously_verbose : bool , ** kwargs ):
363448 """
364449 This function performs some additional tests which may be useful in
@@ -375,25 +460,18 @@ def check_cupy_gpu(gratuitously_verbose: bool, **kwargs):
375460 # check that CUDA is detected correctly
376461 cuda_availability = cp .cuda .is_available ()
377462 if cuda_availability :
378- s = " CUDA is Available "
463+ s = f " CUDA is Available "
379464 s = create_success (s )
380465 s = f"{ s : <80} "
381466 print (s )
382467 else :
383- s = " CUDA is Unavailable "
468+ s = f " CUDA is Unavailable "
384469 s = create_failure (s )
385470 s = f"{ s : <80} "
386471 print (s )
387472
388473 # Count how many GPUs Cupy can detect
389- # probably should change this to a while loop ...
390- for i in range (24 ):
391- try :
392- d = cp .cuda .Device (i )
393- hasattr (d , "attributes" )
394- except :
395- num_gpus_detected = i
396- break
474+ num_gpus_detected = cp .cuda .runtime .getDeviceCount ()
397475
398476 # print how many GPUs were detected, filter for a couple of special conditons
399477 if num_gpus_detected == 0 :
@@ -448,7 +526,9 @@ def print_no_extra_checks(m: str):
448526
449527
450528# dict of extra check functions
451- funcs_dict = {"cupy" : check_cupy_gpu }
529+ funcs_dict = {
530+ "cupy" : check_cupy_gpu ,
531+ }
452532
453533
454534#### main function used to check the configuration of the installation
@@ -493,7 +573,7 @@ def check_config(
493573
494574 print_import_states (states_dict )
495575
496- perfrom_extra_checks (
576+ perform_extra_checks (
497577 import_states = states_dict ,
498578 verbose = verbose ,
499579 gratuitously_verbose = gratuitously_verbose ,
0 commit comments