diff --git a/README.md b/README.md
index 72febc99..a81ee5a6 100644
--- a/README.md
+++ b/README.md
@@ -230,7 +230,7 @@ chunk_spans = retrieve_context(query=user_prompt, num_chunks=5, config=my_config
 
 # Append a RAG instruction based on the user prompt and context to the message history
 messages = []  # Or start with an existing message history
-messages.append(add_context(user_prompt=user_prompt, context=chunk_spans))
+messages.append(add_context(user_prompt=user_prompt, context=chunk_spans, config=my_config))
 
 # Stream the RAG response and append it to the message history
 stream = rag(messages, config=my_config)
@@ -285,7 +285,7 @@ chunk_spans = retrieve_chunk_spans(chunks_reranked, config=my_config)
 from raglite import add_context
 
 messages = []  # Or start with an existing message history
-messages.append(add_context(user_prompt=user_prompt, context=chunk_spans))
+messages.append(add_context(user_prompt=user_prompt, context=chunk_spans, config=my_config))
 
 # Stream the RAG response and append it to the message history
 from raglite import rag
diff --git a/src/raglite/_eval.py b/src/raglite/_eval.py
index a40f9667..de372222 100644
--- a/src/raglite/_eval.py
+++ b/src/raglite/_eval.py
@@ -212,7 +212,7 @@ def answer_evals(
     contexts: list[list[str]] = []
     for eval_ in tqdm(evals, desc="Answering evals", unit="eval", dynamic_ncols=True):
         chunk_spans = retrieve_context(query=eval_.question, config=config)
-        messages = [add_context(user_prompt=eval_.question, context=chunk_spans)]
+        messages = [add_context(user_prompt=eval_.question, context=chunk_spans, config=config)]
         response = rag(messages, config=config)
         answer = "".join(response)
         answers.append(answer)
diff --git a/src/raglite/_mcp.py b/src/raglite/_mcp.py
index ec97364f..0a2a4c3f 100644
--- a/src/raglite/_mcp.py
+++ b/src/raglite/_mcp.py
@@ -27,7 +27,7 @@ def create_mcp_server(server_name: str, *, config: RAGLiteConfig) -> FastMCP[Any
     def kb(query: Query) -> str:
         """Answer a question with information from the knowledge base."""
         chunk_spans = retrieve_context(query, config=config)
-        rag_instruction = add_context(query, chunk_spans)
+        rag_instruction = add_context(query, chunk_spans, config)
         return rag_instruction["content"]
 
     @mcp.tool()
diff --git a/src/raglite/_rag.py b/src/raglite/_rag.py
index 25fc2bd4..2ab080d5 100644
--- a/src/raglite/_rag.py
+++ b/src/raglite/_rag.py
@@ -1,6 +1,8 @@
 """Retrieval-augmented generation."""
 
 import json
+import logging
+import warnings
 from collections.abc import AsyncIterator, Callable, Iterator
 from typing import Any
 
@@ -19,6 +21,8 @@
 from raglite._search import retrieve_chunk_spans
 from raglite._typing import MetadataFilter
 
+logger = logging.getLogger(__name__)
+
 # The default RAG instruction template follows Anthropic's best practices [1].
 # [1] https://docs.anthropic.com/en/docs/build-with-claude/prompt-engineering/long-context-tips
 RAG_INSTRUCTION_TEMPLATE = """
@@ -61,9 +65,84 @@ def retrieve_context(
     return chunk_spans
 
 
+def _count_tokens(item: str) -> int:
+    """Estimate the number of tokens in an item."""
+    return len(item) // 3
+
+
+def _get_last_message_idx(messages: list[dict[str, str]], role: str) -> int | None:
+    """Get the index of the last message with a specified role."""
+    return next(
+        (-i for i, m in enumerate(reversed(messages), 1) if m.get("role") == role),
+        None,
+    )  # Last message index
+
+
+def _limit_chunkspans(
+    tool_chunk_spans: dict[str, list[ChunkSpan]],
+    config: RAGLiteConfig,
+    *,
+    messages: list[dict[str, str]] | None = None,
+    user_prompt: str | None = None,
+    template: str = RAG_INSTRUCTION_TEMPLATE,
+) -> dict[str, list[ChunkSpan]]:
+    """Limit chunk spans to fit within the context window."""
+    # Calculate already used tokens (buffer)
+    buffer = 0
+    # Triggered when using tool calls
+    if messages:
+        # Count tokens in the last user, system and tool call messages
+        for role in ("user", "system", "assistant"):
+            idx = _get_last_message_idx(messages, role)
+            if idx is not None:
+                buffer += _count_tokens(json.dumps(messages[idx]))
+    # Triggered when using add_context
+    elif user_prompt:
+        buffer = _count_tokens(template.format(context="", user_prompt=user_prompt))
+    # Determine max tokens available for context
+    max_tokens = get_context_size(config) - buffer
+    # Compute token counts for all chunk spans per tool
+    tool_tokens_list: dict[str, list[int]] = {}
+    tool_total_tokens: dict[str, int] = {}
+    total_tokens = 0
+    for tool_id, chunk_spans in tool_chunk_spans.items():
+        tokens_list = [_count_tokens(chunk_span.to_xml()) for chunk_span in chunk_spans]
+        tool_tokens_list[tool_id] = tokens_list
+        tool_total = sum(tokens_list)
+        tool_total_tokens[tool_id] = tool_total
+        total_tokens += tool_total
+    # Early exit if we're already under the limit
+    if total_tokens <= max_tokens:
+        return tool_chunk_spans
+    # Allocate tokens proportionally and truncate
+    total_chunk_spans = sum(len(spans) for spans in tool_chunk_spans.values())
+    limited_tool_chunk_spans: dict[str, list[ChunkSpan]] = {}
+    for tool_id, chunk_spans in tool_chunk_spans.items():
+        if not chunk_spans:
+            limited_tool_chunk_spans[tool_id] = []
+            continue
+        # Proportional allocation
+        tool_max_tokens = max_tokens * tool_total_tokens[tool_id] // total_tokens
+        # Find cutoff point using cumulative sum
+        cum_tokens = np.cumsum(tool_tokens_list[tool_id])
+        cutoff_idx = np.searchsorted(cum_tokens, tool_max_tokens, side="right")
+        limited_tool_chunk_spans[tool_id] = chunk_spans[:cutoff_idx]
+    # Log warning if chunks were dropped
+    new_total_chunk_spans = sum(len(spans) for spans in limited_tool_chunk_spans.values())
+    if new_total_chunk_spans < total_chunk_spans:
+        logger.warning(
+            "RAG context was limited to %d out of %d chunks due to context window size. "
" + "Consider using a model with a bigger context window or reducing the number of retrieved chunks.", + new_total_chunk_spans, + total_chunk_spans, + ) + return limited_tool_chunk_spans + + def add_context( user_prompt: str, context: list[ChunkSpan], + config: RAGLiteConfig, *, rag_instruction_template: str = RAG_INSTRUCTION_TEMPLATE, ) -> dict[str, str]: @@ -73,11 +152,12 @@ def add_context( [1] https://docs.anthropic.com/en/docs/build-with-claude/prompt-engineering/long-context-tips """ + limited_context = _limit_chunkspans({"temp": context}, config, user_prompt=user_prompt)["temp"] message = { "role": "user", "content": rag_instruction_template.format( context="\n".join( - chunk_span.to_xml(index=i + 1) for i, chunk_span in enumerate(context) + chunk_span.to_xml(index=i + 1) for i, chunk_span in enumerate(limited_context) ), user_prompt=user_prompt.strip(), ), @@ -87,8 +167,23 @@ def add_context( def _clip(messages: list[dict[str, str]], max_tokens: int) -> list[dict[str, str]]: """Left clip a messages array to avoid hitting the context limit.""" - cum_tokens = np.cumsum([len(message.get("content") or "") // 3 for message in messages][::-1]) + cum_tokens = np.cumsum([_count_tokens(json.dumps(message)) for message in messages][::-1]) first_message = -np.searchsorted(cum_tokens, max_tokens) + idx = _get_last_message_idx(messages, "user") + if first_message == 0 or ( + idx is not None and idx < first_message + ): # No message fits or last user message (user query) would be clipped + warnings.warn( + ( + f"Context window of {max_tokens} tokens exceeded." + "Consider using a model with a bigger context window or reducing the number of retrieved chunks." + ), + stacklevel=2, + ) + # Return only the last user message if it fits. + if idx is not None and _count_tokens(json.dumps(messages[idx])) <= max_tokens: + return [messages[idx]] + return [] return messages[first_message:] @@ -145,31 +240,35 @@ def _run_tools( tool_calls: list[ChatCompletionMessageToolCall], on_retrieval: Callable[[list[ChunkSpan]], None] | None, config: RAGLiteConfig, + *, + messages: list[dict[str, str]] | None, ) -> list[dict[str, Any]]: """Run tools to search the knowledge base for RAG context.""" + tool_chunk_spans: dict[str, list[ChunkSpan]] = {} tool_messages: list[dict[str, Any]] = [] for tool_call in tool_calls: if tool_call.function.name == "search_knowledge_base": kwargs = json.loads(tool_call.function.arguments) kwargs["config"] = config - chunk_spans = retrieve_context(**kwargs) - tool_messages.append( - { - "role": "tool", - "content": '{{"documents": [{elements}]}}'.format( - elements=", ".join( - chunk_span.to_json(index=i + 1) - for i, chunk_span in enumerate(chunk_spans) - ) - ), - "tool_call_id": tool_call.id, - } - ) - if chunk_spans and callable(on_retrieval): - on_retrieval(chunk_spans) + tool_chunk_spans[tool_call.id] = retrieve_context(**kwargs) else: error_message = f"Unknown function `{tool_call.function.name}`." raise ValueError(error_message) + tool_chunk_spans = _limit_chunkspans(tool_chunk_spans, config, messages=messages) + for tool_id, chunk_spans in tool_chunk_spans.items(): + tool_messages.append( + { + "role": "tool", + "content": '{{"documents": [{elements}]}}'.format( + elements=", ".join( + chunk_span.to_json(index=i + 1) for i, chunk_span in enumerate(chunk_spans) + ) + ), + "tool_call_id": tool_id, + } + ) + if chunk_spans and callable(on_retrieval): + on_retrieval(chunk_spans) return tool_messages @@ -202,7 +301,7 @@ def rag( # Add the tool call request to the message array. 
         messages.append(response.choices[0].message.to_dict())  # type: ignore[arg-type,union-attr]
         # Run the tool calls to retrieve the RAG context and append the output to the message array.
-        messages.extend(_run_tools(tool_calls, on_retrieval, config))
+        messages.extend(_run_tools(tool_calls, on_retrieval, config, messages=messages))
         # Stream the assistant response.
         chunks = []
         stream = completion(model=config.llm, messages=_clip(messages, max_tokens), stream=True)
@@ -245,7 +344,7 @@ async def async_rag(
         messages.append(response.choices[0].message.to_dict())  # type: ignore[arg-type,union-attr]
         # Run the tool calls to retrieve the RAG context and append the output to the message array.
         # TODO: Make this async.
-        messages.extend(_run_tools(tool_calls, on_retrieval, config))
+        messages.extend(_run_tools(tool_calls, on_retrieval, config, messages=messages))
         # Asynchronously stream the assistant response.
         chunks = []
         async_stream = await acompletion(
diff --git a/tests/test_rag.py b/tests/test_rag.py
index 151cd96d..4c36463a 100644
--- a/tests/test_rag.py
+++ b/tests/test_rag.py
@@ -16,7 +16,7 @@ def test_rag_manual(raglite_test_config: RAGLiteConfig) -> None:
     # Answer a question with manual RAG.
     user_prompt = "How does Einstein define 'simultaneous events' in his special relativity paper?"
     chunk_spans = retrieve_context(query=user_prompt, config=raglite_test_config)
-    messages = [add_context(user_prompt, context=chunk_spans)]
+    messages = [add_context(user_prompt, context=chunk_spans, config=raglite_test_config)]
     stream = rag(messages, config=raglite_test_config)
     answer = ""
     for update in stream:
@@ -42,7 +42,8 @@ def test_rag_auto_with_retrieval(raglite_test_config: RAGLiteConfig) -> None:
     # Verify that RAG context was retrieved automatically.
     assert [message["role"] for message in messages] == ["user", "assistant", "tool", "assistant"]
    assert json.loads(messages[-2]["content"])
-    assert chunk_spans
+    if not raglite_test_config.llm.startswith("llama-cpp-python"):
+        assert chunk_spans
     assert all(isinstance(chunk_span, ChunkSpan) for chunk_span in chunk_spans)
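Note on the new truncation behavior: _limit_chunkspans budgets the model's context window across tool calls in proportion to each call's estimated token total, then keeps the longest prefix of chunk spans that fits each share, so one large retrieval cannot crowd the others out entirely. The standalone sketch below illustrates only that allocation idea with toy token counts; the allocate helper and the numbers are illustrative and not part of raglite's API.

# Sketch of the proportional allocation performed by _limit_chunkspans (hypothetical helper, toy numbers).
import numpy as np

def allocate(token_counts_per_tool: dict[str, list[int]], max_tokens: int) -> dict[str, int]:
    """Return how many chunk spans each tool call keeps within a max_tokens budget."""
    totals = {tool_id: sum(counts) for tool_id, counts in token_counts_per_tool.items()}
    grand_total = sum(totals.values())
    if grand_total <= max_tokens:
        # Everything fits: keep all chunk spans.
        return {tool_id: len(counts) for tool_id, counts in token_counts_per_tool.items()}
    kept: dict[str, int] = {}
    for tool_id, counts in token_counts_per_tool.items():
        # Each tool call gets a share of the budget proportional to its total size.
        budget = max_tokens * totals[tool_id] // grand_total
        # Keep the longest prefix of chunk spans whose cumulative size fits that share.
        kept[tool_id] = int(np.searchsorted(np.cumsum(counts), budget, side="right"))
    return kept

print(allocate({"call_1": [900, 800, 700], "call_2": [600, 500]}, max_tokens=2000))
# {'call_1': 1, 'call_2': 1}: call_1's share is 1371 tokens, call_2's is 628, so each keeps one span.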