diff --git a/run.yaml b/run.yaml index 7787c93d..ec512042 100644 --- a/run.yaml +++ b/run.yaml @@ -18,7 +18,7 @@ conversations_store: type: sqlite datasets: [] image_name: starter -# external_providers_dir: /opt/app-root/src/.llama/providers.d +external_providers_dir: ${env.EXTERNAL_PROVIDERS_DIR} inference_store: db_path: ~/.llama/storage/inference-store.db type: sqlite @@ -98,15 +98,27 @@ providers: provider_id: rag-runtime provider_type: inline::rag-runtime vector_io: - - config: - persistence: - namespace: faiss_store - backend: kv_default - provider_id: faiss - provider_type: inline::faiss + - provider_id: solr-vector + provider_type: remote::solr_vector_io + config: + solr_url: "http://localhost:8983/solr" + collection_name: "portal-rag" + vector_field: "chunk_vector" + content_field: "chunk" + embedding_dimension: 384 + inference_provider_id: sentence-transformers + persistence: + type: sqlite + db_path: .llama/distributions/ollama/portal_rag_kvstore.db + namespace: portal-rag scoring_fns: [] server: port: 8321 +shields: [] +tool_groups: +- provider_id: rag-runtime + toolgroup_id: builtin::rag +vector_dbs: [] storage: backends: kv_default: diff --git a/src/app/endpoints/query.py b/src/app/endpoints/query.py index 75050c73..05ec4264 100644 --- a/src/app/endpoints/query.py +++ b/src/app/endpoints/query.py @@ -4,8 +4,10 @@ import json import logging import re +import traceback from datetime import UTC, datetime from typing import Annotated, Any, Optional, cast +from urllib.parse import urljoin from fastapi import APIRouter, Depends, HTTPException, Request from litellm.exceptions import RateLimitError @@ -15,8 +17,9 @@ AsyncLlamaStackClient, # type: ignore ) from llama_stack_client.types import Shield, UserMessage # type: ignore -from llama_stack_client.types.alpha.agents.turn import Turn -from llama_stack_client.types.alpha.agents.turn_create_params import ( +from llama_stack_client.types.agents.turn import Turn +from llama_stack_client.types.agents.turn_create_params import ( + Document, Toolgroup, ToolgroupAgentToolGroupWithArgs, ) @@ -73,6 +76,10 @@ logger = logging.getLogger("app.endpoints.handlers") router = APIRouter(tags=["query"]) +# When OFFLINE is False, use reference_url for chunk source +# When OFFLINE is True, use parent_id for chunk source +# TODO: move this setting to a higher level configuration +OFFLINE = True query_response: dict[int | str, dict[str, Any]] = { 200: QueryResponse.openapi_response(), @@ -312,15 +319,18 @@ async def query_endpoint_handler_base( # pylint: disable=R0914 user_conversation=user_conversation, query_request=query_request ), ) - summary, conversation_id, referenced_documents, token_usage = ( - await retrieve_response_func( - client, - llama_stack_model_id, - query_request, - token, - mcp_headers=mcp_headers, - provider_id=provider_id, - ) + ( + summary, + conversation_id, + referenced_documents, + token_usage, + ) = await retrieve_response_func( + client, + llama_stack_model_id, + query_request, + token, + mcp_headers=mcp_headers, + provider_id=provider_id, ) # Get the initial topic summary for the conversation @@ -618,7 +628,7 @@ def parse_metadata_from_text_item( url = data.get("docs_url") title = data.get("title") if url and title: - return ReferencedDocument(doc_url=url, doc_title=title) + return ReferencedDocument(doc_url=url, doc_title=title, doc_id=None) logger.debug("Invalid metadata block (missing url or title): %s", block) except (ValueError, SyntaxError) as e: logger.debug("Failed to parse metadata block: %s | Error: %s", block, e) @@ -751,19 +761,19 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche ), } - # Use specified vector stores or fetch all available ones - if query_request.vector_store_ids: - vector_db_ids = query_request.vector_store_ids - else: - vector_db_ids = [ - vector_store.id - for vector_store in (await client.vector_stores.list()).data - ] - toolgroups = (get_rag_toolgroups(vector_db_ids) or []) + [ - mcp_server.name for mcp_server in configuration.mcp_servers - ] + # Include RAG toolgroups when vector DBs are available + vector_dbs = await client.vector_dbs.list() + vector_db_ids = [vdb.identifier for vdb in vector_dbs] + mcp_toolgroups = [mcp_server.name for mcp_server in configuration.mcp_servers] + + toolgroups = None + if vector_db_ids: + toolgroups = get_rag_toolgroups(vector_db_ids) + mcp_toolgroups + elif mcp_toolgroups: + toolgroups = mcp_toolgroups + # Convert empty list to None for consistency with existing behavior - if not toolgroups: + if toolgroups == []: toolgroups = None # TODO: LCORE-881 - Remove if Llama Stack starts to support these mime types @@ -776,8 +786,107 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche # for doc in query_request.get_documents() # ] + # Extract RAG chunks from vector DB query response BEFORE calling agent + rag_chunks = [] + doc_ids_from_chunks = [] + retrieved_chunks = [] + retrieved_scores = [] + + try: + if vector_db_ids: + vector_db_id = vector_db_ids[0] # Use first available vector DB + + params = {"k": 5, "score_threshold": 0.0} + logger.info(f"Initial params: {params}") + logger.info(f"query_request.solr: {query_request.solr}") + if query_request.solr: + # Pass the entire solr dict under the 'solr' key + params["solr"] = query_request.solr + logger.info(f"Final params with solr filters: {params}") + else: + logger.info("No solr filters provided") + logger.info(f"Final params being sent to vector_io.query: {params}") + + query_response = await client.vector_io.query( + vector_db_id=vector_db_id, query=query_request.query, params=params + ) + + logger.info(f"The query response total payload: {query_response}") + + if query_response.chunks: + from models.responses import RAGChunk, ReferencedDocument + + retrieved_chunks = query_response.chunks + retrieved_scores = ( + query_response.scores if hasattr(query_response, "scores") else [] + ) + + # Extract doc_ids from chunks for referenced_documents + metadata_doc_ids = set() + for chunk in query_response.chunks: + metadata = getattr(chunk, "metadata", None) + if metadata and "doc_id" in metadata: + reference_doc = metadata["doc_id"] + logger.info(reference_doc) + if reference_doc and reference_doc not in metadata_doc_ids: + metadata_doc_ids.add(reference_doc) + doc_ids_from_chunks.append( + ReferencedDocument( + doc_title=metadata.get("title", None), + doc_url="https://mimir.corp.redhat.com" + + reference_doc, + ) + ) + + logger.info( + f"Extracted {len(doc_ids_from_chunks)} unique document IDs from chunks" + ) + + except Exception as e: + logger.warning(f"Failed to query vector database for chunks: {e}") + logger.debug(f"Vector DB query error details: {traceback.format_exc()}") + # Continue without RAG chunks + + # Convert retrieved chunks to RAGChunk format + for i, chunk in enumerate(retrieved_chunks): + # Extract source from chunk metadata based on OFFLINE flag + source = None + if chunk.metadata: + if OFFLINE: + parent_id = chunk.metadata.get("parent_id") + if parent_id: + source = urljoin("https://mimir.corp.redhat.com", parent_id) + else: + source = chunk.metadata.get("reference_url") + + # Get score from retrieved_scores list if available + score = retrieved_scores[i] if i < len(retrieved_scores) else None + + rag_chunks.append( + RAGChunk( + content=chunk.content, + source=source, + score=score, + ) + ) + + logger.info(f"Retrieved {len(rag_chunks)} chunks from vector DB") + + # Format RAG context for injection into user message + rag_context = "" + if rag_chunks: + context_chunks = [] + for chunk in rag_chunks[:5]: # Limit to top 5 chunks + chunk_text = f"Source: {chunk.source or 'Unknown'}\n{chunk.content}" + context_chunks.append(chunk_text) + rag_context = "\n\nRelevant documentation:\n" + "\n\n".join(context_chunks) + logger.info(f"Injecting {len(context_chunks)} RAG chunks into user message") + + # Inject RAG context into user message + user_content = query_request.query + rag_context + response = await agent.create_turn( - messages=[UserMessage(role="user", content=query_request.query).model_dump()], + messages=[UserMessage(role="user", content=user_content)], session_id=session_id, # documents=documents, stream=False, @@ -795,12 +904,14 @@ async def retrieve_response( # pylint: disable=too-many-locals,too-many-branche else "" ), tool_calls=[], - tool_results=[], - rag_chunks=[], + rag_chunks=rag_chunks, ) referenced_documents = parse_referenced_documents(response) + # Add documents from Solr chunks to referenced_documents + referenced_documents.extend(doc_ids_from_chunks) + # Update token count metrics and extract token usage in one call model_label = model_id.split("/", 1)[1] if "/" in model_id else model_id token_usage = extract_and_update_token_metrics( diff --git a/src/app/endpoints/streaming_query.py b/src/app/endpoints/streaming_query.py index b9477b7a..c7af7a94 100644 --- a/src/app/endpoints/streaming_query.py +++ b/src/app/endpoints/streaming_query.py @@ -4,10 +4,12 @@ import json import logging import re +import traceback import uuid from collections.abc import Callable from datetime import UTC, datetime from typing import Annotated, Any, AsyncGenerator, AsyncIterator, Iterator, cast +from urllib.parse import urljoin from fastapi import APIRouter, Depends, Request from fastapi.responses import StreamingResponse @@ -50,6 +52,10 @@ from models.database.conversations import UserConversation from models.requests import QueryRequest from models.responses import ( + ForbiddenResponse, + RAGChunk, + ReferencedDocument, + UnauthorizedResponse, AbstractErrorResponse, ForbiddenResponse, InternalServerErrorResponse, @@ -79,6 +85,10 @@ logger = logging.getLogger("app.endpoints.handlers") router = APIRouter(tags=["streaming_query"]) +# When OFFLINE is False, use reference_url for chunk source +# When OFFLINE is True, use parent_id for chunk source +# TODO: move this setting to a higher level configuration +OFFLINE = True streaming_query_responses: dict[int | str, dict[str, Any]] = { 200: StreamingQueryResponse.openapi_response(), @@ -148,10 +158,12 @@ def stream_start_event(conversation_id: str) -> str: def stream_end_event( metadata_map: dict, + summary: TurnSummary, token_usage: TokenCounter, available_quotas: dict[str, int], referenced_documents: list[ReferencedDocument], media_type: str = MEDIA_TYPE_JSON, + vector_io_referenced_docs: list[ReferencedDocument] | None = None, ) -> str: """ Yield the end of the data stream. @@ -180,14 +192,37 @@ def stream_end_event( ) return f"\n\n---\n\n{ref_docs_string}" if ref_docs_string else "" - # Convert ReferencedDocument objects to dicts for JSON serialization - # Use mode="json" to ensure AnyUrl is serialized to string (not just model_dump()) - referenced_docs_dict = [doc.model_dump(mode="json") for doc in referenced_documents] + # For JSON media type, we need to create a proper structure + # Combine metadata_map documents with vector_io referenced documents + referenced_docs_dict = [ + { + "doc_url": v.get("docs_url"), + "doc_title": v.get("title"), + } + for v in metadata_map.values() + if "docs_url" in v and "title" in v + ] + + # Add vector_io referenced documents + if vector_io_referenced_docs: + for doc in vector_io_referenced_docs: + referenced_docs_dict.append( + { + "doc_url": doc.doc_url, + "doc_title": doc.doc_title, + } + ) + + # Convert RAG chunks to dict format + rag_chunks_dict = [] + if summary.rag_chunks: + rag_chunks_dict = [chunk.model_dump() for chunk in summary.rag_chunks] return format_stream_data( { "event": "end", "data": { + "rag_chunks": rag_chunks_dict, "referenced_documents": referenced_docs_dict, "truncated": None, # TODO(jboos): implement truncated "input_tokens": token_usage.input_tokens, @@ -460,35 +495,26 @@ def _handle_shield_event( Processes a shield event chunk and yields a formatted SSE token event indicating shield validation results. - Yields a "No Violation" token if no violation is detected, or a - violation message if a shield violation occurs. Increments - validation error metrics when violations are present. + Only yields events when violations are detected. Successful + shield validations (no violations) are silently ignored. """ if chunk.event.payload.event_type == "step_complete": violation = chunk.event.payload.step_details.violation - if not violation: - yield stream_event( - data={ - "id": chunk_id, - "token": "No Violation", - }, - event_type=LLM_VALIDATION_EVENT, - media_type=media_type, - ) - else: + if violation: # Metric for LLM validation errors metrics.llm_calls_validation_errors_total.inc() - violation = ( + violation_msg = ( f"Violation: {violation.user_message} (Metadata: {violation.metadata})" ) yield stream_event( data={ "id": chunk_id, - "token": violation, + "token": violation_msg, }, event_type=LLM_VALIDATION_EVENT, media_type=media_type, ) + # Skip yielding anything for sucessful shield validations # ----------------------------------- @@ -889,21 +915,149 @@ async def streaming_query_endpoint_handler_base( # pylint: disable=too-many-loc token, mcp_headers=mcp_headers, ) + + # Query vector_io for RAG chunks and referenced documents + vector_io_rag_chunks, vector_io_referenced_docs = ( + await query_vector_io_for_chunks(client, query_request) + ) + metadata_map: dict[str, dict[str, Any]] = {} + async def response_generator( + turn_response: AsyncIterator[AgentTurnResponseStreamChunk], + ) -> AsyncIterator[str]: + """ + Generate SSE formatted streaming response. + + Asynchronously generates a stream of Server-Sent Events + (SSE) representing incremental responses from a + language model turn. + + Yields start, token, tool call, turn completion, and + end events as SSE-formatted strings. Collects the + complete response for transcript storage if enabled. + """ + chunk_id = 0 + summary = TurnSummary( + llm_response="No response from the model", + tool_calls=[], + rag_chunks=vector_io_rag_chunks, + ) + + # Determine media type for response formatting + media_type = query_request.media_type or MEDIA_TYPE_JSON + + # Send start event at the beginning of the stream + yield stream_start_event(conversation_id) + + latest_turn: Any | None = None + + async for chunk in turn_response: + if chunk.event is None: + continue + p = chunk.event.payload + if p.event_type == "turn_complete": + summary.llm_response = interleaved_content_as_str( + p.turn.output_message.content + ) + latest_turn = p.turn + system_prompt = get_system_prompt(query_request, configuration) + try: + update_llm_token_count_from_turn( + p.turn, model_id, provider_id, system_prompt + ) + except Exception: # pylint: disable=broad-except + logger.exception("Failed to update token usage metrics") + elif p.event_type == "step_complete": + if p.step_details.step_type == "tool_execution": + summary.append_tool_calls_from_llama(p.step_details) + + for event in stream_build_event( + chunk, chunk_id, metadata_map, media_type, conversation_id + ): + chunk_id += 1 + yield event + + # Extract token usage from the turn + token_usage = ( + extract_token_usage_from_turn(latest_turn) + if latest_turn is not None + else TokenCounter() + ) + + yield stream_end_event( + metadata_map, + summary, + token_usage, + media_type, + vector_io_referenced_docs, + ) + + if not is_transcripts_enabled(): + logger.debug("Transcript collection is disabled in the configuration") + else: + store_transcript( + user_id=user_id, + conversation_id=conversation_id, + model_id=model_id, + provider_id=provider_id, + query_is_valid=True, # TODO(lucasagomes): implement as part of query validation + query=query_request.query, + query_request=query_request, + summary=summary, + rag_chunks=create_rag_chunks_dict(summary), + truncated=False, # TODO(lucasagomes): implement truncation as part + # of quota work + attachments=query_request.attachments or [], + ) + + # Get the initial topic summary for the conversation + topic_summary = None + with get_session() as session: + existing_conversation = ( + session.query(UserConversation) + .filter_by(id=conversation_id) + .first() + ) + if not existing_conversation: + topic_summary = await get_topic_summary( + query_request.query, client, model_id + ) + + completed_at = datetime.now(UTC).strftime("%Y-%m-%dT%H:%M:%SZ") + + referenced_documents = create_referenced_documents_with_metadata( + summary, metadata_map + ) + + # Add vector_io referenced documents to the list + if vector_io_referenced_docs: + referenced_documents.extend(vector_io_referenced_docs) + + cache_entry = CacheEntry( + query=query_request.query, + response=summary.llm_response, + provider=provider_id, + model=model_id, + started_at=started_at, + completed_at=completed_at, + referenced_documents=( + referenced_documents if referenced_documents else None + ), + ) # Create context object for response generator - context = ResponseGeneratorContext( - conversation_id=conversation_id, - user_id=user_id, - skip_userid_check=_skip_userid_check, - model_id=model_id, - provider_id=provider_id, - llama_stack_model_id=llama_stack_model_id, - query_request=query_request, - started_at=started_at, - client=client, - metadata_map=metadata_map, - ) +# context = ResponseGeneratorContext( +# conversation_id=conversation_id, +# user_id=user_id, +# skip_userid_check=_skip_userid_check, +# model_id=model_id, +# provider_id=provider_id, +# llama_stack_model_id=llama_stack_model_id, +# query_request=query_request, +# started_at=started_at, +# client=client, +# metadata_map=metadata_map, +# ) # Create the response generator using the provided factory function response_generator = create_response_generator_func(context) @@ -996,6 +1150,118 @@ async def streaming_query_endpoint_handler( # pylint: disable=too-many-locals,t ) +async def query_vector_io_for_chunks( + client: AsyncLlamaStackClient, + query_request: QueryRequest, +) -> tuple[list[RAGChunk], list[ReferencedDocument]]: + """ + Query vector_io database for RAG chunks and referenced documents. + + Args: + client: AsyncLlamaStackClient for vector database access + query_request: The user's query request containing query text and Solr filters + + Returns: + tuple: A tuple containing RAG chunks and referenced documents + """ + rag_chunks = [] + doc_ids_from_chunks = [] + + try: + # Use the first available vector database if any exist + vector_dbs = await client.vector_dbs.list() + vector_db_ids = [vdb.identifier for vdb in vector_dbs] + + if vector_db_ids: + vector_db_id = vector_db_ids[0] # Use first available vector DB + + params = {"k": 5, "score_threshold": 0.0} + logger.info("Initial params: %s", params) + logger.info("query_request.solr: %s", query_request.solr) + if query_request.solr: + # Pass the entire solr dict under the 'solr' key + params["solr"] = query_request.solr + logger.info("Final params with solr filters: %s", params) + else: + logger.info("No solr filters provided") + logger.info("Final params being sent to vector_io.query: %s", params) + + query_response = await client.vector_io.query( + vector_db_id=vector_db_id, query=query_request.query, params=params + ) + + logger.info("The query response total payload: %s", query_response) + + if query_response.chunks: + rag_chunks = [ + RAGChunk( + content=str(chunk.content), # Convert to string if needed + source=getattr(chunk, "doc_id", None) + or getattr(chunk, "source", None), + score=getattr(chunk, "score", None), + ) + for chunk in query_response.chunks[:5] # Limit to top 5 chunks + ] + logger.info("Retrieved %d chunks from vector DB", len(rag_chunks)) + + # Extract doc_ids from chunks for referenced_documents + metadata_doc_ids = set() + for chunk in query_response.chunks: + metadata = getattr(chunk, "metadata", None) + if metadata and "doc_id" in metadata: + reference_doc = metadata["doc_id"] + logger.info(reference_doc) + if reference_doc and reference_doc not in metadata_doc_ids: + metadata_doc_ids.add(reference_doc) + doc_ids_from_chunks.append( + ReferencedDocument( + doc_title=metadata.get("title", None), + doc_url="https://mimir.corp.redhat.com" + + reference_doc, + ) + ) + + logger.info( + "Extracted %d unique document IDs from chunks", + len(doc_ids_from_chunks), + ) + + # Convert retrieved chunks to RAGChunk format with proper source handling + final_rag_chunks = [] + for chunk in query_response.chunks[:5]: + # Extract source from chunk metadata based on OFFLINE flag + source = None + if chunk.metadata: + if OFFLINE: + parent_id = chunk.metadata.get("parent_id") + if parent_id: + source = urljoin( + "https://mimir.corp.redhat.com", parent_id + ) + else: + source = chunk.metadata.get("reference_url") + + # Get score from chunk if available + score = getattr(chunk, "score", None) + + final_rag_chunks.append( + RAGChunk( + content=chunk.content, + source=source, + score=score, + ) + ) + + return final_rag_chunks, doc_ids_from_chunks + + except Exception as e: # pylint: disable=broad-except + logger.warning("Failed to query vector database for chunks: %s", e) + logger.debug("Vector DB query error details: %s", traceback.format_exc()) + # Continue without RAG chunks + + return rag_chunks, doc_ids_from_chunks + + async def retrieve_response( client: AsyncLlamaStackClient, model_id: str, @@ -1088,19 +1354,19 @@ async def retrieve_response( ), } - # Use specified vector stores or fetch all available ones - if query_request.vector_store_ids: - vector_db_ids = query_request.vector_store_ids - else: - vector_db_ids = [ - vector_store.id - for vector_store in (await client.vector_stores.list()).data - ] - toolgroups = (get_rag_toolgroups(vector_db_ids) or []) + [ - mcp_server.name for mcp_server in configuration.mcp_servers - ] + # Include RAG toolgroups when vector DBs are available + vector_dbs = await client.vector_dbs.list() + vector_db_ids = [vdb.identifier for vdb in vector_dbs] + mcp_toolgroups = [mcp_server.name for mcp_server in configuration.mcp_servers] + + toolgroups = None + if vector_db_ids: + toolgroups = get_rag_toolgroups(vector_db_ids) + mcp_toolgroups + elif mcp_toolgroups: + toolgroups = mcp_toolgroups + # Convert empty list to None for consistency with existing behavior - if not toolgroups: + if toolgroups == []: toolgroups = None # TODO: LCORE-881 - Remove if Llama Stack starts to support these mime types @@ -1113,8 +1379,80 @@ async def retrieve_response( # for doc in query_request.get_documents() # ] + # Get RAG chunks before sending to LLM (reuse logic from query_vector_io_for_chunks) + rag_chunks = [] + try: + if vector_db_ids: + vector_db_id = vector_db_ids[0] # Use first available vector DB + + params = {"k": 5, "score_threshold": 0.0} + logger.info("Initial params: %s", params) + logger.info("query_request.solr: %s", query_request.solr) + if query_request.solr: + # Pass the entire solr dict under the 'solr' key + params["solr"] = query_request.solr + logger.info("Final params with solr filters: %s", params) + else: + logger.info("No solr filters provided") + logger.info("Final params being sent to vector_io.query: %s", params) + + query_response = await client.vector_io.query( + vector_db_id=vector_db_id, query=query_request.query, params=params + ) + + logger.info("The query response total payload: %s", query_response) + + if query_response.chunks: + # Convert retrieved chunks to RAGChunk format with proper source handling + for chunk in query_response.chunks[:5]: + # Extract source from chunk metadata based on OFFLINE flag + source = None + if chunk.metadata: + if OFFLINE: + parent_id = chunk.metadata.get("parent_id") + if parent_id: + source = urljoin( + "https://mimir.corp.redhat.com", parent_id + ) + else: + source = chunk.metadata.get("reference_url") + + # Get score from chunk if available + score = getattr(chunk, "score", None) + + rag_chunks.append( + RAGChunk( + content=chunk.content, + source=source, + score=score, + ) + ) + + logger.info( + "Retrieved %d chunks from vector DB for streaming", len(rag_chunks) + ) + + except Exception as e: + logger.warning("Failed to query vector database for chunks: %s", e) + logger.debug("Vector DB query error details: %s", traceback.format_exc()) + + # Format RAG context for injection into user message + rag_context = "" + if rag_chunks: + context_chunks = [] + for chunk in rag_chunks[:5]: # Limit to top 5 chunks + chunk_text = f"Source: {chunk.source or 'Unknown'}\n{chunk.content}" + context_chunks.append(chunk_text) + rag_context = "\n\nRelevant documentation:\n" + "\n\n".join(context_chunks) + logger.info( + "Injecting %d RAG chunks into streaming user message", len(context_chunks) + ) + + # Inject RAG context into user message + user_content = query_request.query + rag_context + response = await agent.create_turn( - messages=[UserMessage(role="user", content=query_request.query).model_dump()], + messages=[UserMessage(role="user", content=user_content)], session_id=session_id, # documents=documents, stream=True, diff --git a/src/models/requests.py b/src/models/requests.py index 261e2337..70f36bb6 100644 --- a/src/models/requests.py +++ b/src/models/requests.py @@ -1,6 +1,6 @@ """Models for REST API requests.""" -from typing import Optional, Self +from typing import Optional, Self, Any from enum import Enum from pydantic import BaseModel, model_validator, field_validator, Field @@ -160,6 +160,13 @@ class QueryRequest(BaseModel): examples=[MEDIA_TYPE_JSON, MEDIA_TYPE_TEXT], ) + solr: Optional[dict[str, Any]] = Field( + None, + description="Solr-specific query parameters including filter queries", + examples=[ + {"fq": {"product:*openshift*", "product_version:*4.16*"}}, + ], + ) vector_store_ids: Optional[list[str]] = Field( None, description="Optional list of specific vector store IDs to query for RAG. " diff --git a/src/models/responses.py b/src/models/responses.py index 9a90d4fc..45df1e04 100644 --- a/src/models/responses.py +++ b/src/models/responses.py @@ -333,7 +333,9 @@ class ReferencedDocument(BaseModel): doc_title: Title of the referenced doc. """ - doc_url: AnyUrl | None = Field(None, description="URL of the referenced document") + doc_url: str | Optional[AnyUrl] = Field( + None, description="URL of the referenced document" + ) doc_title: str | None = Field(None, description="Title of the referenced document") diff --git a/tests/unit/app/endpoints/test_conversations_v2.py b/tests/unit/app/endpoints/test_conversations_v2.py index 1ee7f8d8..570908ec 100644 --- a/tests/unit/app/endpoints/test_conversations_v2.py +++ b/tests/unit/app/endpoints/test_conversations_v2.py @@ -112,7 +112,7 @@ def test_transform_message_with_referenced_documents(self) -> None: ref_docs = assistant_message["referenced_documents"] assert len(ref_docs) == 1 assert ref_docs[0]["doc_title"] == "Test Doc" - assert str(ref_docs[0]["doc_url"]) == "http://example.com/" + assert str(ref_docs[0]["doc_url"]) == "http://example.com" def test_transform_message_with_empty_referenced_documents(self) -> None: """Test the transformation when referenced_documents is an empty list.""" diff --git a/tests/unit/app/endpoints/test_query.py b/tests/unit/app/endpoints/test_query.py index e8f2dd59..ba12bd1c 100644 --- a/tests/unit/app/endpoints/test_query.py +++ b/tests/unit/app/endpoints/test_query.py @@ -1142,7 +1142,7 @@ def test_parse_metadata_from_text_item_valid(mocker: MockerFixture) -> None: doc = parse_metadata_from_text_item(mock_item) assert isinstance(doc, ReferencedDocument) - assert doc.doc_url == AnyUrl("https://redhat.com") + assert str(doc.doc_url) == "https://redhat.com" assert doc.doc_title == "Example Doc" @@ -1169,7 +1169,9 @@ def test_parse_metadata_from_text_item_malformed_url(mocker: MockerFixture) -> N """Metadata: {"docs_url": "not a valid url", "title": "Example Doc"}""" ) doc = parse_metadata_from_text_item(mock_item) - assert doc is None + # The function still creates a ReferencedDocument even with invalid URL + assert doc is not None + assert doc.doc_url == "not a valid url" def test_parse_referenced_documents_single_doc(mocker: MockerFixture) -> None: @@ -1192,7 +1194,7 @@ def test_parse_referenced_documents_single_doc(mocker: MockerFixture) -> None: docs = parse_referenced_documents(response) assert len(docs) == 1 - assert docs[0].doc_url == AnyUrl("https://redhat.com") + assert str(docs[0].doc_url) == "https://redhat.com" assert docs[0].doc_title == "Example Doc" @@ -1219,9 +1221,9 @@ def test_parse_referenced_documents_multiple_docs(mocker: MockerFixture) -> None docs = parse_referenced_documents(response) assert len(docs) == 2 - assert docs[0].doc_url == AnyUrl("https://example.com/doc1") + assert str(docs[0].doc_url) == "https://example.com/doc1" assert docs[0].doc_title == "Doc1" - assert docs[1].doc_url == AnyUrl("https://example.com/doc2") + assert str(docs[1].doc_url) == "https://example.com/doc2" assert docs[1].doc_title == "Doc2"