Skip to content

Commit 3982de0

Browse files
committed
add --count and --batch args for data_export.py
1 parent 19d9f8c commit 3982de0

File tree

2 files changed

+11
-4
lines changed

2 files changed

+11
-4
lines changed

ann_benchmarks/results.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ def load_all_results(dataset: Optional[str] = None,
8484
Yields:
8585
tuple: A tuple containing properties as a dictionary and an h5py file object.
8686
"""
87-
for root, _, files in os.walk(build_result_filepath(dataset, count)):
87+
for root, _, files in os.walk(build_result_filepath(dataset, count, batch_mode=batch_mode)):
8888
for filename in files:
8989
if os.path.splitext(filename)[-1] != ".hdf5":
9090
continue
@@ -110,4 +110,4 @@ def get_unique_algorithms() -> Set[str]:
110110
for batch_mode in [False, True]:
111111
for properties, _ in load_all_results(batch_mode=batch_mode):
112112
algorithms.add(properties["algo"])
113-
return algorithms
113+
return algorithms

data_export.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,21 @@
99
parser = argparse.ArgumentParser()
1010
parser.add_argument("--output", help="Path to the output file", required=True)
1111
parser.add_argument("--recompute", action="store_true", help="Recompute metrics")
12+
parser.add_argument(
13+
"-k", "--count", default=10, type=int, help="The number of near neighbours to search for"
14+
)
15+
parser.add_argument("--batch", action="store_true", help="Batch mode")
1216
args = parser.parse_args()
1317

1418
datasets = DATASETS.keys()
1519
dfs = []
1620
for dataset_name in datasets:
1721
print("Looking at dataset", dataset_name)
18-
if len(list(load_all_results(dataset_name))) > 0:
19-
results = load_all_results(dataset_name)
22+
if len(list(load_all_results(dataset_name,
23+
count=args.count,
24+
batch_mode=args.batch
25+
))) > 0:
26+
results = load_all_results(dataset_name, count=args.count, batch_mode=args.batch)
2027
dataset, _ = get_dataset(dataset_name)
2128
results = compute_metrics_all_runs(dataset, results, args.recompute)
2229
for res in results:

0 commit comments

Comments
 (0)