diff --git a/README.md b/README.md
index 72febc99..a81ee5a6 100644
--- a/README.md
+++ b/README.md
@@ -230,7 +230,7 @@ chunk_spans = retrieve_context(query=user_prompt, num_chunks=5, config=my_config
 
 # Append a RAG instruction based on the user prompt and context to the message history
 messages = []  # Or start with an existing message history
-messages.append(add_context(user_prompt=user_prompt, context=chunk_spans))
+messages.append(add_context(user_prompt=user_prompt, context=chunk_spans, config=my_config))
 
 # Stream the RAG response and append it to the message history
 stream = rag(messages, config=my_config)
@@ -285,7 +285,7 @@ chunk_spans = retrieve_chunk_spans(chunks_reranked, config=my_config)
 from raglite import add_context
 
 messages = []  # Or start with an existing message history
-messages.append(add_context(user_prompt=user_prompt, context=chunk_spans))
+messages.append(add_context(user_prompt=user_prompt, context=chunk_spans, config=my_config))
 
 # Stream the RAG response and append it to the message history
 from raglite import rag
diff --git a/src/raglite/_eval.py b/src/raglite/_eval.py
index a40f9667..de372222 100644
--- a/src/raglite/_eval.py
+++ b/src/raglite/_eval.py
@@ -212,7 +212,7 @@ def answer_evals(
     contexts: list[list[str]] = []
     for eval_ in tqdm(evals, desc="Answering evals", unit="eval", dynamic_ncols=True):
         chunk_spans = retrieve_context(query=eval_.question, config=config)
-        messages = [add_context(user_prompt=eval_.question, context=chunk_spans)]
+        messages = [add_context(user_prompt=eval_.question, context=chunk_spans, config=config)]
         response = rag(messages, config=config)
         answer = "".join(response)
         answers.append(answer)
diff --git a/src/raglite/_mcp.py b/src/raglite/_mcp.py
index ec97364f..0a2a4c3f 100644
--- a/src/raglite/_mcp.py
+++ b/src/raglite/_mcp.py
@@ -27,7 +27,7 @@ def create_mcp_server(server_name: str, *, config: RAGLiteConfig) -> FastMCP[Any
     def kb(query: Query) -> str:
         """Answer a question with information from the knowledge base."""
         chunk_spans = retrieve_context(query, config=config)
-        rag_instruction = add_context(query, chunk_spans)
+        rag_instruction = add_context(query, chunk_spans, config)
         return rag_instruction["content"]
 
     @mcp.tool()
diff --git a/src/raglite/_rag.py b/src/raglite/_rag.py
index 25fc2bd4..2541e3f2 100644
--- a/src/raglite/_rag.py
+++ b/src/raglite/_rag.py
@@ -1,7 +1,9 @@
 """Retrieval-augmented generation."""
 
 import json
-from collections.abc import AsyncIterator, Callable, Iterator
+import logging
+import warnings
+from collections.abc import AsyncIterator, Callable, Iterator, Mapping, Sequence
 from typing import Any
 
 import numpy as np
@@ -19,6 +21,8 @@
 from raglite._search import retrieve_chunk_spans
 from raglite._typing import MetadataFilter
 
+logger = logging.getLogger(__name__)
+
 # The default RAG instruction template follows Anthropic's best practices [1].
 # [1] https://docs.anthropic.com/en/docs/build-with-claude/prompt-engineering/long-context-tips
 RAG_INSTRUCTION_TEMPLATE = """
@@ -61,9 +65,125 @@ def retrieve_context(
     return chunk_spans
 
 
+def _count_tokens(item: str) -> int:
+    """Estimate the number of tokens in an item."""
+    return len(item) // 3
+
+
+def _get_last_message_idx(messages: list[dict[str, str]], role: str) -> int | None:
+    """Get the index of the last message with a specified role."""
+    for i in range(len(messages) - 1, -1, -1):
+        if messages[i].get("role") == role:
+            return i
+    return None
+
+
+def _calculate_buffer_tokens(
+    messages: list[dict[str, str]] | None,
+    roles: list[str],
+    user_prompt: str | None,
+    template: str,
+) -> int:
+    """Calculate the number of tokens used by other messages."""
+    # Calculate already used tokens (buffer)
+    buffer = 0
+    # Triggered when using tool calls
+    if messages:
+        # Count used tokens by the last message of each role
+        for role in roles:
+            idx = _get_last_message_idx(messages, role)
+            if idx is not None:
+                buffer += _count_tokens(json.dumps(messages[idx]))
+        return buffer
+    # Triggered when using add_context
+    if user_prompt:
+        return _count_tokens(template.format(context="", user_prompt=user_prompt))
+    return 0
+
+
+def _cutoff_idx(token_counts: list[int], max_tokens: int, *, reverse: bool = False) -> int:
+    """Find the cutoff index in token counts to fit within max tokens."""
+    counts = token_counts[::-1] if reverse else token_counts
+    cum_tokens = np.cumsum(counts)
+    cutoff_idx = int(np.searchsorted(cum_tokens, max_tokens, side="right"))
+    return len(token_counts) - cutoff_idx if reverse else cutoff_idx
+
+
+def _get_token_counts(items: Sequence[str | ChunkSpan | Mapping[str, str]]) -> list[int]:
+    """Compute token counts for a list of items."""
+    return [
+        _count_tokens(item.to_xml())
+        if isinstance(item, ChunkSpan)
+        else _count_tokens(json.dumps(item, ensure_ascii=False))
+        if isinstance(item, dict)
+        else _count_tokens(item)
+        if isinstance(item, str)
+        else 0
+        for item in items
+    ]
+
+
+def _limit_chunkspans(
+    tool_chunk_spans: dict[str, list[ChunkSpan]],
+    config: RAGLiteConfig,
+    *,
+    messages: list[dict[str, str]] | None = None,
+    user_prompt: str | None = None,
+    template: str = RAG_INSTRUCTION_TEMPLATE,
+) -> dict[str, list[ChunkSpan]]:
+    """Limit chunk spans to fit within the context window."""
+    # Calculate already used tokens (buffer)
+    buffer = _calculate_buffer_tokens(
+        messages, ["user", "system", "assistant"], user_prompt, template
+    )
+    # Determine max tokens available for context
+    max_tokens = get_context_size(config) - buffer
+    # Compute token counts for all chunk spans per tool
+    tool_tokens_list: dict[str, list[int]] = {}
+    tool_total_tokens: dict[str, int] = {}
+    total_tokens = 0
+    total_chunk_spans = 0
+    for tool_id, chunk_spans in tool_chunk_spans.items():
+        tokens_list = _get_token_counts(chunk_spans)
+        tool_tokens_list[tool_id] = tokens_list
+        tool_total = sum(tokens_list)
+        tool_total_tokens[tool_id] = tool_total
+        total_tokens += tool_total
+        total_chunk_spans += len(chunk_spans)
+    # Early exit if we're already under the limit
+    if total_tokens <= max_tokens:
+        return tool_chunk_spans
+    # Allocate tokens proportionally and truncate
+    new_total_chunk_spans = 0
+    scale_ratio = max_tokens / total_tokens
+    limited_tool_chunk_spans: dict[str, list[ChunkSpan]] = {}
+    for tool_id, chunk_spans in tool_chunk_spans.items():
+        if not chunk_spans:
+            limited_tool_chunk_spans[tool_id] = []
+            continue
+        # Proportional allocation
+        tool_max_tokens = int(scale_ratio * tool_total_tokens[tool_id])
+        # Find cutoff point
+        cutoff_idx = _cutoff_idx(tool_tokens_list[tool_id], tool_max_tokens)
+        limited_tool_chunk_spans[tool_id] = chunk_spans[
+            :cutoff_idx
+        ]  # Keep only up to cutoff (ChunkSpans are ordered in descending relevance)
+        new_total_chunk_spans += cutoff_idx
+    # Log warning if chunks were dropped
+    if new_total_chunk_spans < total_chunk_spans:
+        logger.warning(
+            "RAG context was limited to %d out of %d chunks due to context window size. "
+            "Consider using a model with a bigger context window or reducing the number of retrieved chunks.",
+            new_total_chunk_spans,
+            total_chunk_spans,
+        )
+    return limited_tool_chunk_spans
+
+
 def add_context(
     user_prompt: str,
     context: list[ChunkSpan],
+    config: RAGLiteConfig,
     *,
     rag_instruction_template: str = RAG_INSTRUCTION_TEMPLATE,
 ) -> dict[str, str]:
@@ -73,11 +193,13 @@ def add_context(
 
     [1] https://docs.anthropic.com/en/docs/build-with-claude/prompt-engineering/long-context-tips
     """
+    # Limit context to fit within the context window.
+    limited_context = _limit_chunkspans({"temp": context}, config, user_prompt=user_prompt)["temp"]
     message = {
         "role": "user",
         "content": rag_instruction_template.format(
             context="\n".join(
-                chunk_span.to_xml(index=i + 1) for i, chunk_span in enumerate(context)
+                chunk_span.to_xml(index=i + 1) for i, chunk_span in enumerate(limited_context)
            ),
             user_prompt=user_prompt.strip(),
         ),
@@ -87,9 +209,31 @@
 
 def _clip(messages: list[dict[str, str]], max_tokens: int) -> list[dict[str, str]]:
     """Left clip a messages array to avoid hitting the context limit."""
-    cum_tokens = np.cumsum([len(message.get("content") or "") // 3 for message in messages][::-1])
-    first_message = -np.searchsorted(cum_tokens, max_tokens)
-    return messages[first_message:]
+    token_counts = _get_token_counts(messages)
+    cutoff_idx = _cutoff_idx(token_counts, max_tokens, reverse=True)
+    idx_user = _get_last_message_idx(messages, "user")
+    if cutoff_idx == len(messages) or (idx_user is not None and idx_user < cutoff_idx):
+        warnings.warn(
+            (
+                f"Context window of {max_tokens} tokens exceeded. "
+                "Consider using a model with a bigger context window or reducing the number of retrieved chunks."
+            ),
+            stacklevel=2,
+        )
+        # Try to include both last system and user messages if they fit together.
+        # If not, include just user if it fits, else return empty.
+        idx_system = _get_last_message_idx(messages, "system")
+        if (
+            idx_user is not None
+            and idx_system is not None
+            and idx_system < idx_user
+            and token_counts[idx_user] + token_counts[idx_system] <= max_tokens
+        ):
+            return [messages[idx_system], messages[idx_user]]
+        if idx_user is not None and token_counts[idx_user] <= max_tokens:
+            return [messages[idx_user]]
+        return []
+    return messages[cutoff_idx:]
 
 
 def _get_tools(
@@ -145,31 +289,35 @@ def _run_tools(
     tool_calls: list[ChatCompletionMessageToolCall],
     on_retrieval: Callable[[list[ChunkSpan]], None] | None,
     config: RAGLiteConfig,
+    *,
+    messages: list[dict[str, str]] | None,
 ) -> list[dict[str, Any]]:
     """Run tools to search the knowledge base for RAG context."""
+    tool_chunk_spans: dict[str, list[ChunkSpan]] = {}
     tool_messages: list[dict[str, Any]] = []
     for tool_call in tool_calls:
         if tool_call.function.name == "search_knowledge_base":
             kwargs = json.loads(tool_call.function.arguments)
             kwargs["config"] = config
-            chunk_spans = retrieve_context(**kwargs)
-            tool_messages.append(
-                {
-                    "role": "tool",
-                    "content": '{{"documents": [{elements}]}}'.format(
-                        elements=", ".join(
-                            chunk_span.to_json(index=i + 1)
-                            for i, chunk_span in enumerate(chunk_spans)
-                        )
-                    ),
-                    "tool_call_id": tool_call.id,
-                }
-            )
-            if chunk_spans and callable(on_retrieval):
-                on_retrieval(chunk_spans)
+            tool_chunk_spans[tool_call.id] = retrieve_context(**kwargs)
         else:
             error_message = f"Unknown function `{tool_call.function.name}`."
             raise ValueError(error_message)
+    tool_chunk_spans = _limit_chunkspans(tool_chunk_spans, config, messages=messages)
+    for tool_id, chunk_spans in tool_chunk_spans.items():
+        tool_messages.append(
+            {
+                "role": "tool",
+                "content": '{{"documents": [{elements}]}}'.format(
+                    elements=", ".join(
+                        chunk_span.to_json(index=i + 1) for i, chunk_span in enumerate(chunk_spans)
+                    )
+                ),
+                "tool_call_id": tool_id,
+            }
+        )
+        if chunk_spans and callable(on_retrieval):
+            on_retrieval(chunk_spans)
     return tool_messages
 
 
@@ -202,7 +350,7 @@
         # Add the tool call request to the message array.
         messages.append(response.choices[0].message.to_dict())  # type: ignore[arg-type,union-attr]
         # Run the tool calls to retrieve the RAG context and append the output to the message array.
-        messages.extend(_run_tools(tool_calls, on_retrieval, config))
+        messages.extend(_run_tools(tool_calls, on_retrieval, config, messages=messages))
         # Stream the assistant response.
         chunks = []
         stream = completion(model=config.llm, messages=_clip(messages, max_tokens), stream=True)
@@ -245,7 +393,7 @@
         messages.append(response.choices[0].message.to_dict())  # type: ignore[arg-type,union-attr]
         # Run the tool calls to retrieve the RAG context and append the output to the message array.
         # TODO: Make this async.
-        messages.extend(_run_tools(tool_calls, on_retrieval, config))
+        messages.extend(_run_tools(tool_calls, on_retrieval, config, messages=messages))
         # Asynchronously stream the assistant response.
         chunks = []
         async_stream = await acompletion(
diff --git a/src/raglite/_split_sentences.py b/src/raglite/_split_sentences.py
index 1cbeb188..d6ffda0f 100644
--- a/src/raglite/_split_sentences.py
+++ b/src/raglite/_split_sentences.py
@@ -17,7 +17,7 @@
 @cache
 def _load_sat() -> tuple[SaT, dict[str, Any]]:
     """Load a Segment any Text (SaT) model."""
-    sat = SaT("sat-3l-sm")  # This model makes the best trade-off between speed and accuracy.
+    sat = SaT("sat-1l-sm")  # This model makes the best trade-off between speed and accuracy.
sat_kwargs = {"stride": 128, "block_size": 256, "weighting": "hat"} return sat, sat_kwargs diff --git a/tests/test_rag.py b/tests/test_rag.py index 151cd96d..4c36463a 100644 --- a/tests/test_rag.py +++ b/tests/test_rag.py @@ -16,7 +16,7 @@ def test_rag_manual(raglite_test_config: RAGLiteConfig) -> None: # Answer a question with manual RAG. user_prompt = "How does Einstein define 'simultaneous events' in his special relativity paper?" chunk_spans = retrieve_context(query=user_prompt, config=raglite_test_config) - messages = [add_context(user_prompt, context=chunk_spans)] + messages = [add_context(user_prompt, context=chunk_spans, config=raglite_test_config)] stream = rag(messages, config=raglite_test_config) answer = "" for update in stream: @@ -42,7 +42,8 @@ def test_rag_auto_with_retrieval(raglite_test_config: RAGLiteConfig) -> None: # Verify that RAG context was retrieved automatically. assert [message["role"] for message in messages] == ["user", "assistant", "tool", "assistant"] assert json.loads(messages[-2]["content"]) - assert chunk_spans + if not raglite_test_config.llm.startswith("llama-cpp-python"): + assert chunk_spans assert all(isinstance(chunk_span, ChunkSpan) for chunk_span in chunk_spans) diff --git a/tests/test_split_sentences.py b/tests/test_split_sentences.py index b8de5300..6d82448b 100644 --- a/tests/test_split_sentences.py +++ b/tests/test_split_sentences.py @@ -26,8 +26,7 @@ def test_split_sentences() -> None: "They suggest rather that, as has\nalready been shown to the first order of small quantities, the same laws of\nelectrodynamics and optics will be valid for all frames of reference for which the\nequations of mechanics hold good.1 ", "We will raise this conjecture (the purport\nof which will hereafter be called the “Principle of Relativity”) to the status\n\nof a postulate, and also introduce another postulate, which is only apparently\nirreconcilable with the former, namely, that light is always propagated in empty\nspace with a definite velocity c which is independent of the state of motion of the\nemitting body. ", "These two postulates suffice for the attainment of a simple and\nconsistent theory of the electrodynamics of moving bodies based on Maxwell’s\ntheory for stationary bodies. ", # noqa: RUF001 - "The introduction of a “luminiferous ether” will\nprove to be superfluous inasmuch as the view here to be developed will not\nrequire an “absolutely stationary space” provided with special properties, nor\n1", - "The preceding memoir by Lorentz was not at this time known to the author.\n\n", + "The introduction of a “luminiferous ether” will\nprove to be superfluous inasmuch as the view here to be developed will not\nrequire an “absolutely stationary space” provided with special properties, nor\n1The preceding memoir by Lorentz was not at this time known to the author.\n\n", "assign a velocity-vector to a point of the empty space in which electromagnetic\nprocesses take place.\n\n", "The theory to be developed is based—like all electrodynamics—on the kine-\nmatics of the rigid body, since the assertions of any such theory have to do\nwith the relationships between rigid bodies (systems of co-ordinates), clocks,\nand electromagnetic processes. ", "Insufficient consideration of this circumstance\nlies at the root of the difficulties which the electrodynamics of moving bodies\nat present encounters.\n\n",