Skip to content

Commit 3dd902e

Browse files
authored
Merge pull request #588 from alex-rakowski/fcu-net
Making FCU-Net compatible with 14.9
2 parents 4243c33 + 300aaa3 commit 3dd902e

File tree

4 files changed

+31
-15
lines changed

4 files changed

+31
-15
lines changed

py4DSTEM/braggvectors/diskdetection_aiml.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
import json
99
import shutil
1010
import numpy as np
11+
from pathlib import Path
12+
1113

1214
from scipy.ndimage import gaussian_filter
1315
from time import time
@@ -437,9 +439,9 @@ def find_Bragg_disks_aiml_serial(
437439
raise ImportError("Import Error: Please install crystal4D before proceeding")
438440

439441
# Make the peaks PointListArray
440-
# dtype = [('qx',float),('qy',float),('intensity',float)]
441-
peaks = BraggVectors(datacube.Rshape, datacube.Qshape)
442-
442+
dtype = [("qx", float), ("qy", float), ("intensity", float)]
443+
# peaks = BraggVectors(datacube.Rshape, datacube.Qshape)
444+
peaks = PointListArray(dtype=dtype, shape=(datacube.R_Nx, datacube.R_Ny))
443445
# check that the filtered DP is the right size for the probe kernel:
444446
if filter_function:
445447
assert callable(filter_function), "filter_function must be callable"
@@ -518,7 +520,7 @@ def find_Bragg_disks_aiml_serial(
518520
subpixel=subpixel,
519521
upsample_factor=upsample_factor,
520522
filter_function=filter_function,
521-
peaks=peaks.vectors_uncal.get_pointlist(Rx, Ry),
523+
peaks=peaks.get_pointlist(Rx, Ry),
522524
model_path=model_path,
523525
)
524526
t2 = time() - t0
@@ -884,7 +886,7 @@ def _get_latest_model(model_path=None):
884886
+ "https://www.tensorflow.org/install"
885887
+ "for more information"
886888
)
887-
from py4DSTEM.io.google_drive_downloader import download_file_from_google_drive
889+
from py4DSTEM.io.google_drive_downloader import gdrive_download
888890

889891
tf.keras.backend.clear_session()
890892

@@ -894,7 +896,12 @@ def _get_latest_model(model_path=None):
894896
except:
895897
pass
896898
# download the json file with the meta data
897-
download_file_from_google_drive("FCU-Net", "./tmp/model_metadata.json")
899+
gdrive_download(
900+
"FCU-Net",
901+
destination="./tmp/",
902+
filename="model_metadata.json",
903+
overwrite=True,
904+
)
898905
with open("./tmp/model_metadata.json") as f:
899906
metadata = json.load(f)
900907
file_id = metadata["file_id"]
@@ -918,7 +925,8 @@ def _get_latest_model(model_path=None):
918925
else:
919926
print("Checking the latest model on the cloud... \n")
920927
filename = file_path + file_type
921-
download_file_from_google_drive(file_id, filename)
928+
filename = Path(filename)
929+
gdrive_download(file_id, destination="./tmp", filename=filename.name)
922930
try:
923931
shutil.unpack_archive(filename, "./tmp", format="zip")
924932
except:

py4DSTEM/braggvectors/diskdetection_aiml_cuda.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -124,8 +124,9 @@ def find_Bragg_disks_aiml_CUDA(
124124
"""
125125

126126
# Make the peaks PointListArray
127-
# dtype = [('qx',float),('qy',float),('intensity',float)]
128-
peaks = BraggVectors(datacube.Rshape, datacube.Qshape)
127+
dtype = [("qx", float), ("qy", float), ("intensity", float)]
128+
# peaks = BraggVectors(datacube.Rshape, datacube.Qshape)
129+
peaks = PointListArray(dtype=dtype, shape=(datacube.R_Nx, datacube.R_Ny))
129130

130131
# check that the filtered DP is the right size for the probe kernel:
131132
if filter_function:
@@ -221,7 +222,7 @@ def find_Bragg_disks_aiml_CUDA(
221222
subpixel=subpixel,
222223
upsample_factor=upsample_factor,
223224
filter_function=filter_function,
224-
peaks=peaks.vectors_uncal.get_pointlist(Rx, Ry),
225+
peaks=peaks.get_pointlist(Rx, Ry),
225226
get_maximal_points=get_maximal_points,
226227
blocks=blocks,
227228
threads=threads,

py4DSTEM/io/google_drive_downloader.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@
5757
),
5858
"small_dm3_3Dstack": ("small_dm3_3Dstack.dm3", "1B-xX3F65JcWzAg0v7f1aVwnawPIfb5_o"),
5959
"FCU-Net": (
60-
"filename.name",
60+
"model_metadata.json",
6161
"1-KX0saEYfhZ9IJAOwabH38PCVtfXidJi",
6262
),
6363
"small_datacube": (
@@ -221,7 +221,8 @@ def gdrive_download(
221221
kwargs = {"fuzzy": True}
222222
if id_ in file_ids:
223223
f = file_ids[id_]
224-
filename = f[0]
224+
# Use the name in the collection filename passed
225+
filename = filename if filename is not None else f[0]
225226
kwargs["id"] = f[1]
226227

227228
# if its not in the list of files we expect

setup.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,12 +47,18 @@
4747
"ipyparallel": ["ipyparallel >= 6.2.4", "dill >= 0.3.3"],
4848
"cuda": ["cupy >= 10.0.0"],
4949
"acom": ["pymatgen >= 2022", "mp-api == 0.24.1"],
50-
"aiml": ["tensorflow == 2.4.1", "tensorflow-addons <= 0.14.0", "crystal4D"],
50+
"aiml": [
51+
"tensorflow <= 2.10.0",
52+
"tensorflow-addons <= 0.16.1",
53+
"crystal4D",
54+
"typeguard == 2.7",
55+
],
5156
"aiml-cuda": [
52-
"tensorflow == 2.4.1",
53-
"tensorflow-addons <= 0.14.0",
57+
"tensorflow <= 2.10.0",
58+
"tensorflow-addons <= 0.16.1",
5459
"crystal4D",
5560
"cupy >= 10.0.0",
61+
"typeguard == 2.7",
5662
],
5763
"numba": ["numba >= 0.49.1"],
5864
},

0 commit comments

Comments
 (0)