2626
2727
2828def vector_search (
29- query : str | FloatMatrix , * , num_results : int = 3 , config : RAGLiteConfig | None = None
29+ query : str | FloatMatrix ,
30+ * ,
31+ num_results : int = 3 ,
32+ oversample : int = 8 ,
33+ config : RAGLiteConfig | None = None ,
3034) -> tuple [list [ChunkId ], list [float ]]:
3135 """Search chunks using ANN vector search."""
3236 # Read the config.
@@ -57,7 +61,9 @@ def vector_search(
5761 )
5862 distance = distance_func (query_embedding ).label ("distance" )
5963 results = session .exec (
60- select (ChunkEmbedding .chunk_id , distance ).order_by (distance ).limit (8 * num_results )
64+ select (ChunkEmbedding .chunk_id , distance )
65+ .order_by (distance )
66+ .limit (oversample * num_results )
6167 )
6268 chunk_ids_ , distance = zip (* results , strict = True )
6369 chunk_ids , similarity = np .asarray (chunk_ids_ ), 1.0 - np .asarray (distance )
@@ -70,7 +76,7 @@ def vector_search(
7076 from pynndescent import NNDescent
7177
7278 multi_vector_indices , distance = cast (NNDescent , index ).query (
73- query_embedding [np .newaxis , :], k = 8 * num_results
79+ query_embedding [np .newaxis , :], k = oversample * num_results
7480 )
7581 similarity = 1 - distance [0 , :]
7682 # Transform the multi-vector indices into chunk indices, and then to chunk ids.
@@ -105,36 +111,32 @@ def keyword_search(
105111 if db_backend == "postgresql" :
106112 # Convert the query to a tsquery [1].
107113 # [1] https://www.postgresql.org/docs/current/textsearch-controls.html
108- query_escaped = re .sub (r"[&|!():<>\" ]" , " " , query )
114+ query_escaped = re .sub (f"[ { re . escape ( string . punctuation ) } ]" , " " , query )
109115 tsv_query = " | " .join (query_escaped .split ())
110116 # Perform keyword search with tsvector.
111- statement = text (
112- """
117+ statement = text ("""
113118 SELECT id as chunk_id, ts_rank(to_tsvector('simple', body), to_tsquery('simple', :query)) AS score
114119 FROM chunk
115120 WHERE to_tsvector('simple', body) @@ to_tsquery('simple', :query)
116121 ORDER BY score DESC
117122 LIMIT :limit;
118- """
119- )
123+ """ )
120124 results = session .execute (statement , params = {"query" : tsv_query , "limit" : num_results })
121125 elif db_backend == "sqlite" :
122126 # Convert the query to an FTS5 query [1].
123127 # [1] https://www.sqlite.org/fts5.html#full_text_query_syntax
124- query_escaped = re .sub (f"[{ re .escape (string .punctuation )} ]" , "" , query )
128+ query_escaped = re .sub (f"[{ re .escape (string .punctuation )} ]" , " " , query )
125129 fts5_query = " OR " .join (query_escaped .split ())
126130 # Perform keyword search with FTS5. In FTS5, BM25 scores are negative [1], so we
127131 # negate them to make them positive.
128132 # [1] https://www.sqlite.org/fts5.html#the_bm25_function
129- statement = text (
130- """
133+ statement = text ("""
131134 SELECT chunk.id as chunk_id, -bm25(keyword_search_chunk_index) as score
132135 FROM chunk JOIN keyword_search_chunk_index ON chunk.rowid = keyword_search_chunk_index.rowid
133136 WHERE keyword_search_chunk_index MATCH :match
134137 ORDER BY score DESC
135138 LIMIT :limit;
136- """
137- )
139+ """ )
138140 results = session .execute (statement , params = {"match" : fts5_query , "limit" : num_results })
139141 # Unpack the results.
140142 results = list (results ) # type: ignore[assignment]
@@ -162,12 +164,12 @@ def reciprocal_rank_fusion(
162164
163165
164166def hybrid_search (
165- query : str , * , num_results : int = 3 , num_rerank : int = 100 , config : RAGLiteConfig | None = None
167+ query : str , * , num_results : int = 3 , oversample : int = 4 , config : RAGLiteConfig | None = None
166168) -> tuple [list [ChunkId ], list [float ]]:
167169 """Search chunks by combining ANN vector search with BM25 keyword search."""
168170 # Run both searches.
169- vs_chunk_ids , _ = vector_search (query , num_results = num_rerank , config = config )
170- ks_chunk_ids , _ = keyword_search (query , num_results = num_rerank , config = config )
171+ vs_chunk_ids , _ = vector_search (query , num_results = oversample * num_results , config = config )
172+ ks_chunk_ids , _ = keyword_search (query , num_results = oversample * num_results , config = config )
171173 # Combine the results with Reciprocal Rank Fusion (RRF).
172174 chunk_ids , hybrid_score = reciprocal_rank_fusion ([vs_chunk_ids , ks_chunk_ids ])
173175 chunk_ids , hybrid_score = chunk_ids [:num_results ], hybrid_score [:num_results ]
0 commit comments