diff --git a/fred/fred_commands/_command_utils.py b/fred/fred_commands/_command_utils.py
index 95f883b..d9b91ed 100644
--- a/fred/fred_commands/_command_utils.py
+++ b/fred/fred_commands/_command_utils.py
@@ -1,6 +1,6 @@
 from typing import Type
 
-from regex import ENHANCEMATCH, match, escape
+from regex import ENHANCEMATCH, match, escape, search as re_search
 
 from ..config import Commands, Crashes, Misc
 from ..libraries.common import new_logger
@@ -8,8 +8,18 @@
 logger = new_logger("[Command/Crash Search]")
 
 
-def search(table: Type[Commands | Crashes], pattern: str, column: str, force_fuzzy: bool) -> (str | list[str], bool):
-    """Returns the top three results based on the result"""
+def search(table: Type[Commands | Crashes], pattern: str, column: str, force_fuzzy: bool) -> tuple[str | list[str], bool]:
+    """Look up `pattern` in `column` of `table`, exactly first, then fuzzily.
+
+    Returns `(value, True)` when `force_fuzzy` is unset and `table.fetch_by`
+    finds an exact match, where `value` is that row's `column` entry.
+    Otherwise returns `(names, False)`: the `name` of every row whose `column`
+    text fuzzily contains `pattern`, ordered by Levenshtein distance to the
+    pattern and then alphabetically.
+
+    Raises KeyError if `column` is not an attribute of `table`, or if a fuzzy
+    search is attempted with a pattern shorter than 2 characters.
+    """
 
     if column not in dir(table):
         raise KeyError(f"`{column}` is not a column in the {table.__name__} table!")
@@ -17,14 +27,33 @@ def search(table: Type[Commands | Crashes], pattern: str, column: str, force_fuz
     if not force_fuzzy and (exact_match := table.fetch_by(column, pattern)):
         return exact_match[column], True
 
-    fuzzy_pattern = rf".*(?:{escape(pattern)}){{e<={min(len(pattern) // 3, 6)}}}.*"
-    fuzzies: list[str] = [
-        item["name"]
-        for item in table.fetch_all()
-        if (item.get(column, None) is not None) and match(fuzzy_pattern, item[column], flags=ENHANCEMATCH)
-    ]
-    logger.info(fuzzies)
-    return fuzzies[:5], False
+    if len(pattern) < 2:
+        raise KeyError("Search pattern must be at least 2 characters long for fuzzy searching!")
+
+    # Allow one edit per three pattern characters, capped at six. re_search
+    # already scans the whole value, so no `.*` padding is needed.
+    max_edits = min(len(pattern) // 3, 6)
+    fuzzy_pattern = rf"(?:{escape(pattern)}){{e<={max_edits}}}"
+
+    scored_results: list[tuple[int, str]] = []
+    for item in table.fetch_all():
+        value = item.get(column)
+
+        # Skip rows with no text in this column, or no fuzzy hit
+        if not isinstance(value, str):
+            continue
+        if not re_search(fuzzy_pattern, value, flags=ENHANCEMATCH):
+            continue
+
+        scored_results.append((levenshtein(pattern, value), item["name"]))
+
+    # Tuples sort by score first, then alphabetically by name
+    scored_results.sort()
+    results = [name for _, name in scored_results]
+
+    # Return all results fitting fuzzy range
+    logger.info(results)
+    return results, False
 
 
 def get_search(table: Type[Commands | Crashes], pattern: str, column: str, force_fuzzy: bool) -> str:
@@ -54,3 +83,29 @@ def get_search(table: Type[Commands | Crashes], pattern: str, column: str, force
         response = e.args[0]
 
     return response
+
+
+def levenshtein(a: str, b: str) -> int:
+    """Return the Levenshtein (edit) distance between `a` and `b`.
+
+    Counts the fewest single-character insertions, deletions and
+    substitutions that turn `a` into `b`, keeping only two rows of the
+    dynamic-programming table at a time.
+    """
+    if a == b:
+        return 0
+    if not a:
+        return len(b)
+    if not b:
+        return len(a)
+
+    prev = list(range(len(b) + 1))
+    for i, ca in enumerate(a, start=1):
+        cur = [i]
+        for j, cb in enumerate(b, start=1):
+            insert = cur[j - 1] + 1
+            delete = prev[j] + 1
+            substitute = prev[j - 1] + (0 if ca == cb else 1)
+            cur.append(min(insert, delete, substitute))
+        prev = cur
+    return prev[-1]