Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 53 additions & 11 deletions fred/fred_commands/_command_utils.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,50 @@
from typing import Type

from regex import ENHANCEMATCH, match, escape
from regex import ENHANCEMATCH, match, escape, search as re_search

from ..config import Commands, Crashes, Misc
from ..libraries.common import new_logger

logger = new_logger("[Command/Crash Search]")


def search(table: Type[Commands | Crashes], pattern: str, column: str, force_fuzzy: bool) -> (str | list[str], bool):
"""Returns the top three results based on the result"""
def search(table: Type[Commands | Crashes], pattern: str, column: str, force_fuzzy: bool) -> tuple[str | list[str], bool]:

if column not in dir(table):
raise KeyError(f"`{column}` is not a column in the {table.__name__} table!")

if not force_fuzzy and (exact_match := table.fetch_by(column, pattern)):
return exact_match[column], True

fuzzy_pattern = rf".*(?:{escape(pattern)}){{e<={min(len(pattern) // 3, 6)}}}.*"
fuzzies: list[str] = [
item["name"]
for item in table.fetch_all()
if (item.get(column, None) is not None) and match(fuzzy_pattern, item[column], flags=ENHANCEMATCH)
]
logger.info(fuzzies)
return fuzzies[:5], False
if len(pattern) < 2:
raise KeyError("Search pattern must be at least 2 characters long for fuzzy searching!")

# Set fuzzy range - (1/3 pattern length, max 6)
max_edits = min(len(pattern) // 3, 6)
substring_pattern = rf".*(?:{escape(pattern)}){{e<={max_edits}}}.*"

scored_results: list[tuple[int, str]] = []
for item in table.fetch_all():
value = item.get(column)

# Filter non matching strings
if not isinstance(value, str):
continue
if not re_search(substring_pattern, value, flags=ENHANCEMATCH):
continue

# add levenshtein score
score = levenshtein(pattern, value)
scored_results.append((score, item["name"]))

# Sort by score, then alphabetically
scored_results.sort(key=lambda x: (x[0], x[1]))
results = [name for _, name in scored_results]

# Return all results fitting fuzzy range
logger.info(results)
return results, False



def get_search(table: Type[Commands | Crashes], pattern: str, column: str, force_fuzzy: bool) -> str:
Expand Down Expand Up @@ -54,3 +74,25 @@ def get_search(table: Type[Commands | Crashes], pattern: str, column: str, force
response = e.args[0]

return response


# Levenshtein distance algorithm
def levenshtein(a: str, b: str) -> int:
if a == b:
return 0
la, lb = len(a), len(b)
if la == 0:
return lb
if lb == 0:
return la

prev = list(range(lb + 1))
for i, ca in enumerate(a, start=1):
cur = [i] + [0] * lb
for j, cb in enumerate(b, start=1):
ins = cur[j - 1] + 1
delete = prev[j] + 1
sub = prev[j - 1] + (0 if ca == cb else 1)
cur[j] = min(ins, delete, sub)
prev = cur
return prev[lb]