Skip to content

Commit df8083a

Browse files
authored
Fix logic in _get_algorithm_definitions to avoid skipping algorithm definitions (#498)
1 parent 4c8b1c1 commit df8083a

File tree

1 file changed

+6
-8
lines changed

1 file changed

+6
-8
lines changed

ann_benchmarks/definitions.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -166,8 +166,7 @@ def _get_algorithm_definitions(point_type: str, distance_metric: str) -> Dict[st
166166
metric. For example, `ann_benchmarks.algorithms.nmslib` has two definitions for euclidean float
167167
data: specifically `SW-graph(nmslib)` and `hnsw(nmslib)`, even though the module is named nmslib.
168168
169-
If an algorithm has an 'any' distance metric is found for the specific point type, it is used
170-
regardless (and takes precendence) over if the distance metric is present.
169+
If an algorithm has an 'any' distance metric, it is also included.
171170
172171
Returns: A mapping from the algorithm name (not the algorithm class), to the algorithm definitions, i.e.:
173172
```
@@ -195,11 +194,10 @@ def _get_algorithm_definitions(point_type: str, distance_metric: str) -> Dict[st
195194
# param `_` is filename, not specific name
196195
for _, config in configs.items():
197196
c = []
198-
if "any" in config: # "any" branch must come first
199-
c = config["any"]
200-
elif distance_metric in config:
201-
c = config[distance_metric]
202-
197+
if "any" in config:
198+
c.extend(config["any"])
199+
if distance_metric in config:
200+
c.extend(config[distance_metric])
203201
for cc in c:
204202
definitions[cc.pop("name")] = cc
205203

@@ -359,4 +357,4 @@ def get_definitions(
359357
)
360358

361359

362-
return definitions
360+
return definitions

0 commit comments

Comments
 (0)