Skip to content

Commit 92969ed

Browse files
authored
Merge pull request #596 from alex-rakowski/config_checker
Update check_config workflow
2 parents 3dd902e + 4814164 commit 92969ed

File tree

1 file changed

+175
-95
lines changed

1 file changed

+175
-95
lines changed

py4DSTEM/utils/configuration_checker.py

Lines changed: 175 additions & 95 deletions
Original file line numberDiff line numberDiff line change
@@ -1,61 +1,96 @@
11
#### this file contains a function/s that will check if various
22
# libaries/compute options are available
33
import importlib
4-
from operator import mod
5-
6-
# list of modules we expect/may expect to be installed
7-
# as part of a standard py4DSTEM installation
8-
# this needs to be the import name e.g. import mp_api not mp-api
9-
modules = [
10-
"crystal4D",
11-
"cupy",
12-
"dask",
13-
"dill",
14-
"distributed",
15-
"gdown",
16-
"h5py",
17-
"ipyparallel",
18-
"jax",
19-
"matplotlib",
20-
"mp_api",
21-
"ncempy",
22-
"numba",
23-
"numpy",
24-
"pymatgen",
25-
"skimage",
26-
"sklearn",
27-
"scipy",
28-
"tensorflow",
29-
"tensorflow-addons",
30-
"tqdm",
31-
]
32-
33-
# currently this was copy and pasted from setup.py,
34-
# hopefully there's a programatic way to do this.
35-
module_depenencies = {
36-
"base": [
37-
"numpy",
38-
"scipy",
39-
"h5py",
40-
"ncempy",
41-
"matplotlib",
42-
"skimage",
43-
"sklearn",
44-
"tqdm",
45-
"dill",
46-
"gdown",
47-
"dask",
48-
"distributed",
49-
],
50-
"ipyparallel": ["ipyparallel", "dill"],
51-
"cuda": ["cupy"],
52-
"acom": ["pymatgen", "mp_api"],
53-
"aiml": ["tensorflow", "tensorflow-addons", "crystal4D"],
54-
"aiml-cuda": ["tensorflow", "tensorflow-addons", "crystal4D", "cupy"],
55-
"numba": ["numba"],
4+
from importlib.metadata import requires
5+
import re
6+
from importlib.util import find_spec
7+
8+
# need a mapping of pypi/conda names to import names
9+
import_mapping_dict = {
10+
"scikit-image": "skimage",
11+
"scikit-learn": "sklearn",
12+
"scikit-optimize": "skopt",
13+
"mp-api": "mp_api",
5614
}
5715

5816

17+
# programatically get all possible requirements in the import name style
18+
def get_modules_list():
19+
# Get the dependencies from the installed distribution
20+
dependencies = requires("py4DSTEM")
21+
22+
# Define a regular expression pattern for splitting on '>', '>=', '='
23+
delimiter_pattern = re.compile(r">=|>|==|<|<=")
24+
25+
# Extract only the module names without versions
26+
module_names = [
27+
delimiter_pattern.split(dependency.split(";")[0], 1)[0].strip()
28+
for dependency in dependencies
29+
]
30+
31+
# translate pypi names to import names e.g. scikit-image->skimage, mp-api->mp_api
32+
for index, module in enumerate(module_names):
33+
if module in import_mapping_dict.keys():
34+
module_names[index] = import_mapping_dict[module]
35+
36+
return module_names
37+
38+
39+
# programatically get all possible requirements in the import name style,
40+
# split into a dict where optional import names are keys
41+
def get_modules_dict():
42+
package_name = "py4DSTEM"
43+
# Get the dependencies from the installed distribution
44+
dependencies = requires(package_name)
45+
46+
# set the dictionary for modules and packages to go into
47+
# optional dependencies will be added as they are discovered
48+
modules_dict = {
49+
"base": [],
50+
}
51+
# loop over the dependencies
52+
for depend in dependencies:
53+
# all the optional have extra in the name
54+
# if its not there append it to base
55+
if "extra" not in depend:
56+
# String looks like: 'numpy>=1.19'
57+
modules_dict["base"].append(depend)
58+
59+
# if it has extra in the string
60+
else:
61+
# get the name of the optional name
62+
# depend looks like this 'numba>=0.49.1; extra == "numba"'
63+
# grab whatever is in the double quotes i.e. numba
64+
optional_name = re.search(r'"(.*?)"', depend).group(1)
65+
# if the optional name is not in the dict as a key i.e. first requirement of hte optional dependency
66+
if optional_name not in modules_dict:
67+
modules_dict[optional_name] = [depend]
68+
# if the optional_name is already in the dict then just append it to the list
69+
else:
70+
modules_dict[optional_name].append(depend)
71+
# STRIP all the versioning and semi-colons
72+
# Define a regular expression pattern for splitting on '>', '>=', '='
73+
delimiter_pattern = re.compile(r">=|>|==|<|<=")
74+
for key, val in modules_dict.items():
75+
# modules_dict[key] = [dependency.split(';')[0].split(' ')[0] for dependency in val]
76+
modules_dict[key] = [
77+
delimiter_pattern.split(dependency.split(";")[0], 1)[0].strip()
78+
for dependency in val
79+
]
80+
81+
# translate pypi names to import names e.g. scikit-image->skimage, mp-api->mp_api
82+
for key, val in modules_dict.items():
83+
for index, module in enumerate(val):
84+
if module in import_mapping_dict.keys():
85+
val[index] = import_mapping_dict[module]
86+
87+
return modules_dict
88+
89+
90+
# module_depenencies = get_modules_dict()
91+
modules = get_modules_list()
92+
93+
5994
#### Class and Functions to Create Coloured Strings ####
6095
class colours:
6196
CEND = "\x1b[0m"
@@ -140,6 +175,7 @@ def create_underline(s: str) -> str:
140175
### here I use the term state to define a boolean condition as to whether a libary/module was sucessfully imported/can be used
141176

142177

178+
# get the state of each modules as a dict key-val e.g. "numpy" : True
143179
def get_import_states(modules: list = modules) -> dict:
144180
"""
145181
Check the ability to import modules and store the results as a boolean value. Returns as a dict.
@@ -163,16 +199,17 @@ def get_import_states(modules: list = modules) -> dict:
163199
return import_states_dict
164200

165201

202+
# Check
166203
def get_module_states(state_dict: dict) -> dict:
167-
"""_summary_
168-
169-
Args:
170-
state_dict (dict): _description_
204+
"""
205+
given a state dict for all modules e.g. "numpy" : True,
206+
this parses through and checks if all modules required for a state are true
171207
172-
Returns:
173-
dict: _description_
208+
returns dict "base": True, "ai-ml": False etc.
174209
"""
175210

211+
# get the modules_dict
212+
module_depenencies = get_modules_dict()
176213
# create an empty dict to put module states into:
177214
module_states = {}
178215

@@ -196,13 +233,12 @@ def get_module_states(state_dict: dict) -> dict:
196233

197234

198235
def print_import_states(import_states: dict) -> None:
199-
"""_summary_
200-
201-
Args:
202-
import_states (dict): _description_
236+
"""
237+
print with colours if the library could be imported or not
238+
takes dict
239+
"numpy" : True -> prints success
240+
"pymatgen" : False -> prints failure
203241
204-
Returns:
205-
_type_: _description_
206242
"""
207243
# m is the name of the import module
208244
# state is whether it was importable
@@ -223,13 +259,11 @@ def print_import_states(import_states: dict) -> None:
223259

224260

225261
def print_module_states(module_states: dict) -> None:
226-
"""_summary_
227-
228-
Args:
229-
module_states (dict): _description_
230-
231-
Returns:
232-
_type_: _description_
262+
"""
263+
print with colours if all the imports required for module could be imported or not
264+
takes dict
265+
"base" : True -> prints success
266+
"ai-ml" : Fasle -> prints failure
233267
"""
234268
# Print out the state of all the modules in colour code
235269
# key is the name of a py4DSTEM Module
@@ -248,25 +282,33 @@ def print_module_states(module_states: dict) -> None:
248282
return None
249283

250284

251-
def perfrom_extra_checks(
285+
def perform_extra_checks(
252286
import_states: dict, verbose: bool, gratuitously_verbose: bool, **kwargs
253287
) -> None:
254288
"""_summary_
255289
256290
Args:
257-
import_states (dict): _description_
258-
verbose (bool): _description_
259-
gratuitously_verbose (bool): _description_
291+
import_states (dict): dict of modules and if they could be imported or not
292+
verbose (bool): will show module states and all import states
293+
gratuitously_verbose (bool): will run extra checks - Currently only for cupy
260294
261295
Returns:
262296
_type_: _description_
263297
"""
264-
265-
# print a output module
266-
extra_checks_message = "Running Extra Checks"
267-
extra_checks_message = create_bold(extra_checks_message)
268-
print(f"{extra_checks_message}")
269-
# For modules that import run any extra checks
298+
if gratuitously_verbose:
299+
# print a output module
300+
extra_checks_message = "Running Extra Checks"
301+
extra_checks_message = create_bold(extra_checks_message)
302+
print(f"{extra_checks_message}")
303+
# For modules that import run any extra checks
304+
# get all the dependencies
305+
dependencies = requires("py4DSTEM")
306+
# Extract only the module names with versions
307+
depends_with_requirements = [
308+
dependency.split(";")[0] for dependency in dependencies
309+
]
310+
# print(depends_with_requirements)
311+
# need to go from
270312
for key, val in import_states.items():
271313
if val:
272314
# s = create_underline(key.capitalize())
@@ -281,7 +323,10 @@ def perfrom_extra_checks(
281323
if gratuitously_verbose:
282324
s = create_underline(key.capitalize())
283325
print(s)
284-
print_no_extra_checks(key)
326+
# check
327+
generic_versions(
328+
key, depends_with_requires=depends_with_requirements
329+
)
285330
else:
286331
pass
287332

@@ -304,7 +349,7 @@ def import_tester(m: str) -> bool:
304349
# try and import the module
305350
try:
306351
importlib.import_module(m)
307-
except:
352+
except Exception:
308353
state = False
309354

310355
return state
@@ -324,6 +369,7 @@ def check_module_functionality(state_dict: dict) -> None:
324369

325370
# create an empty dict to put module states into:
326371
module_states = {}
372+
module_depenencies = get_modules_dict()
327373

328374
# key is the name of the module e.g. ACOM
329375
# val is a list of its dependencies
@@ -359,6 +405,45 @@ def check_module_functionality(state_dict: dict) -> None:
359405
#### ADDTIONAL CHECKS ####
360406

361407

408+
def generic_versions(module: str, depends_with_requires: list[str]) -> None:
409+
# module will be like numpy, skimage
410+
# depends_with_requires look like: numpy >= 19.0, scikit-image
411+
# get module_translated_name
412+
# mapping scikit-image : skimage
413+
for key, value in import_mapping_dict.items():
414+
# if skimage == skimage get scikit-image
415+
# print(f"{key = } - {value = } - {module = }")
416+
if module in value:
417+
module_depend_name = key
418+
break
419+
else:
420+
# if cant find mapping set the search name to the same
421+
module_depend_name = module
422+
# print(f"{module_depend_name = }")
423+
# find the requirement
424+
for depend in depends_with_requires:
425+
if module_depend_name in depend:
426+
spec_required = depend
427+
# print(f"{spec_required = }")
428+
# get the version installed
429+
spec_installed = find_spec(module)
430+
if spec_installed is None:
431+
s = f"{module} unable to import - {spec_required} required"
432+
s = create_failure(s)
433+
s = f"{s: <80}"
434+
print(s)
435+
436+
else:
437+
try:
438+
version = importlib.metadata.version(module_depend_name)
439+
except Exception:
440+
version = "Couldn't test version"
441+
s = f"{module} imported: {version = } - {spec_required} required"
442+
s = create_warning(s)
443+
s = f"{s: <80}"
444+
print(s)
445+
446+
362447
def check_cupy_gpu(gratuitously_verbose: bool, **kwargs):
363448
"""
364449
This function performs some additional tests which may be useful in
@@ -375,25 +460,18 @@ def check_cupy_gpu(gratuitously_verbose: bool, **kwargs):
375460
# check that CUDA is detected correctly
376461
cuda_availability = cp.cuda.is_available()
377462
if cuda_availability:
378-
s = " CUDA is Available "
463+
s = f" CUDA is Available "
379464
s = create_success(s)
380465
s = f"{s: <80}"
381466
print(s)
382467
else:
383-
s = " CUDA is Unavailable "
468+
s = f" CUDA is Unavailable "
384469
s = create_failure(s)
385470
s = f"{s: <80}"
386471
print(s)
387472

388473
# Count how many GPUs Cupy can detect
389-
# probably should change this to a while loop ...
390-
for i in range(24):
391-
try:
392-
d = cp.cuda.Device(i)
393-
hasattr(d, "attributes")
394-
except:
395-
num_gpus_detected = i
396-
break
474+
num_gpus_detected = cp.cuda.runtime.getDeviceCount()
397475

398476
# print how many GPUs were detected, filter for a couple of special conditons
399477
if num_gpus_detected == 0:
@@ -448,7 +526,9 @@ def print_no_extra_checks(m: str):
448526

449527

450528
# dict of extra check functions
451-
funcs_dict = {"cupy": check_cupy_gpu}
529+
funcs_dict = {
530+
"cupy": check_cupy_gpu,
531+
}
452532

453533

454534
#### main function used to check the configuration of the installation
@@ -493,7 +573,7 @@ def check_config(
493573

494574
print_import_states(states_dict)
495575

496-
perfrom_extra_checks(
576+
perform_extra_checks(
497577
import_states=states_dict,
498578
verbose=verbose,
499579
gratuitously_verbose=gratuitously_verbose,

0 commit comments

Comments
 (0)