diff --git a/src/rca_accelerator_chatbot/api.py b/src/rca_accelerator_chatbot/api.py index 0544a49..056435d 100644 --- a/src/rca_accelerator_chatbot/api.py +++ b/src/rca_accelerator_chatbot/api.py @@ -12,15 +12,14 @@ from fastapi.security import APIKeyHeader from pydantic import BaseModel, Field, HttpUrl -from rca_accelerator_chatbot.constants import ( - CI_LOGS_PROFILE, DOCS_PROFILE, RCA_FULL_PROFILE -) +from rca_accelerator_chatbot.constants import CI_LOGS_PROFILE, DOCS_PROFILE, RCA_FULL_PROFILE from rca_accelerator_chatbot.chat import handle_user_message_api from rca_accelerator_chatbot.config import config from rca_accelerator_chatbot.settings import ModelSettings -from rca_accelerator_chatbot.generation import discover_generative_model_names -from rca_accelerator_chatbot.embeddings import discover_embeddings_model_names from rca_accelerator_chatbot.auth import authentification +from rca_accelerator_chatbot.models import ( + gen_model_provider, embed_model_provider, init_model_providers +) app = FastAPI(title="RCAccelerator API") @@ -52,10 +51,14 @@ class RcaRequest(BaseModelSettings): async def validate_settings(request: BaseModelSettings) -> BaseModelSettings: """Validate the settings for any request. This function performs checks to ensure the API request is valid. - Some checks are performed asynchronously which is why we don't use + Some checks are performed asynchronously, which is why we don't use the built-in Pydantic validators. """ - available_generative_models = await discover_generative_model_names() + # Make sure we pull the latest info about running models. Note that the responses + # are cached by the providers for a certain amount of time. + await init_model_providers() + + available_generative_models = gen_model_provider.all_model_names if not request.generative_model_name: request.generative_model_name = available_generative_models[0] elif request.generative_model_name not in available_generative_models: @@ -64,7 +67,7 @@ async def validate_settings(request: BaseModelSettings) -> BaseModelSettings: detail=f"Invalid generative model. 
Available: {available_generative_models}" ) - available_embedding_models = await discover_embeddings_model_names() + available_embedding_models = embed_model_provider.all_model_names if not request.embeddings_model_name: request.embeddings_model_name = available_embedding_models[0] elif request.embeddings_model_name not in available_embedding_models: diff --git a/src/rca_accelerator_chatbot/app.py b/src/rca_accelerator_chatbot/app.py index ca996bf..076a454 100644 --- a/src/rca_accelerator_chatbot/app.py +++ b/src/rca_accelerator_chatbot/app.py @@ -9,8 +9,9 @@ from rca_accelerator_chatbot import constants from rca_accelerator_chatbot.chat import handle_user_message from rca_accelerator_chatbot.auth import authentification -from rca_accelerator_chatbot.generation import discover_generative_model_names -from rca_accelerator_chatbot.embeddings import discover_embeddings_model_names +from rca_accelerator_chatbot.models import ( + init_model_providers, gen_model_provider, embed_model_provider, rerank_model_provider +) @cl.set_chat_profiles @@ -63,7 +64,6 @@ async def chat_profile() -> list[cl.ChatProfile]: ) ] - @cl.on_chat_start async def init_chat(): """ @@ -73,17 +73,18 @@ async def init_chat(): """ cl.user_session.set("counter", 0) + await init_model_providers() await setup_chat_settings() - async def setup_chat_settings(): """ Set up the chat settings interface with model selection, temperature, token limits, and other configuration options. """ - generative_model_names = await discover_generative_model_names() - embeddings_model_names = await discover_embeddings_model_names() - if not generative_model_names or not embeddings_model_names: + generative_model_names = gen_model_provider.all_model_names + embeddings_model_names = embed_model_provider.all_model_names + rerank_model_names = rerank_model_provider.all_model_names + if not generative_model_names or not embeddings_model_names or not rerank_model_names: await cl.Message( content="No generative, embeddings, or re-rank model found. " "Please check your configuration."
@@ -104,6 +105,12 @@ async def setup_chat_settings(): values=embeddings_model_names, initial_index=0, ), + Select( + id="rerank_model", + label="Re-rank model", + values=rerank_model_names, + initial_index=0 + ), Slider( id="temperature", label="Model Temperature", diff --git a/src/rca_accelerator_chatbot/chat.py b/src/rca_accelerator_chatbot/chat.py index ffb06ab..bcc4f73 100644 --- a/src/rca_accelerator_chatbot/chat.py +++ b/src/rca_accelerator_chatbot/chat.py @@ -1,20 +1,17 @@ """Handler for chat messages and responses.""" from dataclasses import dataclass import chainlit as cl -from chainlit.context import ChainlitContextException import httpx from openai.types.chat import ChatCompletionAssistantMessageParam from rca_accelerator_chatbot import constants from rca_accelerator_chatbot.prompt import build_prompt from rca_accelerator_chatbot.vectordb import vector_store -from rca_accelerator_chatbot.generation import get_response -from rca_accelerator_chatbot.embeddings import ( - get_num_tokens, generate_embedding, - get_rerank_score, get_default_embeddings_model_name -) from rca_accelerator_chatbot.settings import ModelSettings, HistorySettings, ThreadMessages from rca_accelerator_chatbot.config import config +from rca_accelerator_chatbot.models import ( + gen_model_provider, embed_model_provider, rerank_model_provider +) from rca_accelerator_chatbot.constants import ( DOCS_PROFILE, RCA_FULL_PROFILE, @@ -56,7 +53,9 @@ async def perform_multi_collection_search( collections: A list of collections to search. settings: The settings user provided through the UI. """ - embedding = await generate_embedding(message_content, embeddings_model_name) + embedding = await embed_model_provider.generate_embedding( + message_content, embeddings_model_name + ) if embedding is None: return [] @@ -69,7 +68,9 @@ async def perform_multi_collection_search( for r in results: r['collection'] = collection if settings['enable_rerank']: - r['rerank_score'] = await get_rerank_score(message_content, r['text']) + r['rerank_score'] = await rerank_model_provider.get_rerank_score( + message_content, r['text'], settings["rerank_model"] + ) else: r['rerank_score'] = None @@ -119,49 +120,6 @@ def update_msg_count(): counter += 1 cl.user_session.set("counter", counter) - -async def check_message_length(message_content: str) -> tuple[bool, str]: - """ - Check if the message content exceeds the token limit. - - Args: - message_content: The content to check - - Returns: - A tuple containing: - - bool: True if the message is within length limits, False otherwise - - str: Error message if the length check fails, empty string otherwise - """ - try: - num_required_tokens = await get_num_tokens(message_content, - await get_embeddings_model_name()) - except httpx.HTTPStatusError as e: - cl.logger.error(e) - return False, "We've encountered an issue. Please try again later ..." - - if num_required_tokens > config.embeddings_llm_max_context: - # Calculate the maximum character limit estimation for the embedding model. - approx_max_chars = round( - config.embeddings_llm_max_context * config.chars_per_token_estimation, -2) - - error_message = ( - "⚠️ **Your input is too lengthy!**\n We can process inputs of up " - f"to approximately {approx_max_chars} characters. The exact limit " - "may vary depending on the input type. 
For instance, plain text " - "inputs can be longer compared to logs or structured data " - "containing special characters (e.g., `[`, `]`, `:`, etc.).\n\n" - "To proceed, please:\n" - " - Focus on including only the most relevant details, and\n" - " - Shorten your input if possible." - " \n\n" - "To let you continue, we will reset the conversation history.\n" - "Please start over with a shorter input." - ) - return False, error_message - - return True, "" - - async def print_debug_content( settings: dict, search_content: str, @@ -192,8 +150,9 @@ async def print_debug_content( ) # Display the number of tokens in the search content - num_t = await get_num_tokens(search_content, - await get_embeddings_model_name()) + num_t = await embed_model_provider.get_num_tokens( + search_content, settings["embeddings_model"] + ) debug_content += f"**Number of tokens in search content:** {num_t}\n\n" # Display vector DB debug information if debug mode is enabled @@ -229,13 +188,13 @@ def _build_search_content(message_history: ThreadMessages, if message['role'] == 'user': previous_message_content += f"\n{message['content']}" - # Limit the size of the message content as this is passed as query to - # the reranking model. This is brute truncation, but can be improved + # Limit the size of the message content as this is passed as a query to + # the reranking model. This is brute truncation but can be improved # when we handle message history better. if settings['enable_rerank']: max_text_len_from_history = ( - config.reranking_model_max_context // 2 - len( - current_message)) - 1 + rerank_model_provider.get_context_size(settings["rerank_model"]) // 2 + - len(current_message)) - 1 return previous_message_content[ :max_text_len_from_history] + current_message return previous_message_content + current_message @@ -271,8 +230,8 @@ async def handle_user_message( # pylint: disable=too-many-locals,too-many-statem settings) # Check message length - is_valid_length, error_message = await check_message_length( - search_content) + is_valid_length, error_message = await embed_model_provider.check_message_length( + search_content, settings["embeddings_model"]) if not is_valid_length: resp.content = error_message # Reset message history to let the user try again @@ -296,7 +255,7 @@ async def handle_user_message( # pylint: disable=too-many-locals,too-many-statem try: search_results = await perform_multi_collection_search( search_content, - await get_embeddings_model_name(), + settings["embeddings_model"], get_similarity_threshold(), collections, settings, @@ -320,6 +279,7 @@ async def handle_user_message( # pylint: disable=too-many-locals,too-many-statem keep_history=settings["keep_history"], message_history=message_history, ), + settings["generative_model"], ) if is_error_prompt: @@ -337,7 +297,7 @@ async def handle_user_message( # pylint: disable=too-many-locals,too-many-statem temperature = settings["temperature"] else: temperature = config.default_temperature_without_search_results - is_error = await get_response( + is_error = await gen_model_provider.get_response( full_prompt, resp, { @@ -379,7 +339,9 @@ async def handle_user_message_api( # pylint: disable=too-many-arguments response = MockMessage(content="", urls=[]) # Check message length - is_valid_length, error_message = await check_message_length(message_content) + is_valid_length, error_message = await embed_model_provider.check_message_length( + message_content, embeddings_model_settings["model"] + ) if not is_valid_length: response.content = error_message return 
response @@ -418,9 +380,10 @@ async def handle_user_message_api( # pylint: disable=too-many-arguments keep_history=False, message_history=[], ), + generative_model_settings["model"], ) # Process user message and get AI response - is_error = await get_response( + is_error = await gen_model_provider.get_response( full_prompt, response, generative_model_settings, @@ -455,14 +418,6 @@ def get_similarity_threshold() -> float: # If threshold is above 1, cap it at 1 return min(threshold, 1.0) -async def get_embeddings_model_name() -> str: - """Get name of the embeddings model.""" - try: - settings = cl.user_session.get("settings") - return settings.get("embeddings_model") - except ChainlitContextException: - return await get_default_embeddings_model_name() - def check_collections(collections_to_check: list[str]) -> str: """ Verify if the specified collections exist in the vector store. diff --git a/src/rca_accelerator_chatbot/config.py b/src/rca_accelerator_chatbot/config.py index 537749c..1d2e56e 100644 --- a/src/rca_accelerator_chatbot/config.py +++ b/src/rca_accelerator_chatbot/config.py @@ -19,14 +19,10 @@ class Config: generation_llm_api_url: str generation_llm_api_key: str enable_rerank: bool - reranking_model_name: str reranking_model_api_key: str reranking_model_api_url: str - reranking_model_max_context: int embeddings_llm_api_url: str embeddings_llm_api_key: str - embeddings_llm_max_context: int - generative_model_max_context: int default_temperature: float default_temperature_without_search_results: float default_max_tokens: int @@ -62,31 +58,16 @@ def from_env(cls) -> 'Config': "GENERATION_LLM_API_KEY", ""), enable_rerank=os.environ.get( "ENABLE_RERANK", "true").lower() == "true", - reranking_model_name=os.environ.get( - "RERANKING_MODEL_NAME", "BAAI/bge-reranker-v2-m3" - ), reranking_model_api_url=os.environ.get( "RERANKING_MODEL_API_URL", "http://localhost:8001/v1" ), reranking_model_api_key=os.environ.get( "RERANKING_MODEL_API_KEY", "" ), - reranking_model_max_context=int(os.environ.get( - "RERANKING_MODEL_MAX_CONTEXT", - 8192, - )), embeddings_llm_api_url=os.environ.get( "EMBEDDINGS_LLM_API_URL", "http://localhost:8000/v1"), embeddings_llm_api_key=os.environ.get( "EMBEDDINGS_LLM_API_KEY", ""), - embeddings_llm_max_context=int(os.environ.get( - "EMBEDDINGS_LLM_MAX_CONTEXT", - 8192, - )), - generative_model_max_context=int(os.environ.get( - "GENERATIVE_MODEL_MAX_CONTEXT", - 32000, - )), default_temperature=float( os.environ.get("DEFAULT_MODEL_TEMPERATURE", 0.3)), default_temperature_without_search_results=float( diff --git a/src/rca_accelerator_chatbot/embeddings.py b/src/rca_accelerator_chatbot/embeddings.py deleted file mode 100644 index 848b69e..0000000 --- a/src/rca_accelerator_chatbot/embeddings.py +++ /dev/null @@ -1,157 +0,0 @@ -"""Embedding generation and vector search functionality.""" - -from typing import List -from urllib.parse import urlparse - -import chainlit as cl -import httpx -from openai import AsyncOpenAI, OpenAIError - -from rca_accelerator_chatbot.config import config -from rca_accelerator_chatbot.generation import extract_model_ids - -# Initialize embedding LLM client -emb_llm = AsyncOpenAI( - base_url=config.embeddings_llm_api_url, - organization="", - api_key=config.embeddings_llm_api_key, -) - -async def discover_embeddings_model_names() -> List[str]: - """Discover available embedding LLM models.""" - models = await emb_llm.models.list() - return extract_model_ids(models) - -async def get_default_embeddings_model_name() -> str: - """Get name of the default 
embeddings model.""" - models = await discover_embeddings_model_names() - return models[0] - -async def generate_embedding( - text: str, model_name: str -) -> None | List[float]: - """Generate embeddings for the given text using the specified model.""" - try: - embedding_response = await emb_llm.embeddings.create( - model=model_name, input=text, encoding_format="float" - ) - - if not embedding_response: - cl.logger.error( - "Failed to get embeddings: " + "No response from model %s", model_name - ) - return None - if not embedding_response.data or len(embedding_response.data) == 0: - cl.logger.error( - "Failed to get embeddings: " + "Empty response for model %s", model_name - ) - return None - - return embedding_response.data[0].embedding - except OpenAIError as e: - cl.logger.error("Error generating embeddings: %s", str(e)) - return None - - -async def get_num_tokens( - prompt: str, - model: str, - llm_url: str = config.embeddings_llm_api_url, - api_key: str = config.embeddings_llm_api_key, -) -> int: - """Retrieve the number of tokens required to process the prompt. - - This function calls the /tokenize API endpoint to get the number of - tokens the input will be transformed into when processed by the specified - model (default is the embedding model). - - Args: - prompt: The input text for which to calculate the token count. - model: The model to use for tokenization. - llm_url: The URL of the model. - api_key: The API key used for authentication. - - Raises: - HTTPStatusError: If the response from the /tokenize API endpoint is - not 200 status code. - """ - headers = {"Content-Type": "application/json", "Authorization": f"Bearer {api_key}"} - - data = { - "model": model, - "prompt": prompt, - } - - llm_url_parse = urlparse(llm_url) - tokenize_url = f"{llm_url_parse.scheme}://{llm_url_parse.netloc}/tokenize" - - async with httpx.AsyncClient() as client: - response = await client.post(tokenize_url, headers=headers, json=data) - - if response.status_code == 200: - response_data = response.json() - return response_data["count"] - - response.raise_for_status() - - return 0 - -async def get_rerank_score( - query_text: str, - search_content: str, - model: str = config.reranking_model_name, - reranking_model_url: str = config.reranking_model_api_url, - reranking_model_api_key: str = config.reranking_model_api_key, -) -> float: - """Contact a re-rank model and get a more precise re-ranking score for the search content. - - This function calls the /rerank API endpoint to calculate a new more accurate - score for the search content. First it chunks the search content to fit - the context of the re-rank model, and then it calculates the score for each - such a chunk. The final score is the maximum re-rank score out of all the - chunks. - - Args: - query_text: query text that the search content should be related to. - search_content: Is a chunk retrieved from the vector database. - model: Name of the model to use for re-ranking. - reranking_model_url: URL of the re-rank model. - reranking_model_api_key: API key for the re-rank model. - - Raises: - HTTPStatusError: If the response from the /rerank API endpoint is - not 200 status code. - """ - - headers = { - "Content-Type": "application/json", - "Authorization": f"Bearer {reranking_model_api_key}" - } - - # If the search_content is too big, we have to split it. We use half of the - # reranking_model_max_content because we have to leave space for the user's - # input. 
- max_chunk_size = config.reranking_model_max_context // 2 - sub_chunks = [ - search_content[i:i + max_chunk_size] - for i in range(0, len(search_content), max_chunk_size) - ] - - data = { - "model": model, - "query": query_text, - "documents": sub_chunks, - } - rerank_url = f"{reranking_model_url}/rerank" - async with httpx.AsyncClient() as client: - response = await client.post(rerank_url, headers=headers, json=data) - - if response.status_code == 200: - response_data = response.json() - if len(response_data["results"]) == 0: - return .0 - return response_data["results"][0].get("relevance_score", .0) - - response.raise_for_status() - - return .0 diff --git a/src/rca_accelerator_chatbot/generation.py b/src/rca_accelerator_chatbot/generation.py deleted file mode 100644 index ce13874..0000000 --- a/src/rca_accelerator_chatbot/generation.py +++ /dev/null @@ -1,123 +0,0 @@ -"""Text generation with large language models.""" - -import chainlit as cl -from openai import AsyncOpenAI, OpenAIError - -from rca_accelerator_chatbot.settings import ModelSettings, ThreadMessages -from rca_accelerator_chatbot.config import config -from rca_accelerator_chatbot.constants import ( - DOCS_PROFILE, RCA_FULL_PROFILE, CI_LOGS_PROFILE -) - -# Initialize generative LLM client -gen_llm = AsyncOpenAI( - base_url=config.generation_llm_api_url, - organization='', - api_key=config.generation_llm_api_key, -) - - -async def discover_generative_model_names() -> list[str]: - """Discover available generative LLM models.""" - models = await gen_llm.models.list() - return extract_model_ids(models) - - -def extract_model_ids(models) -> list[str]: - """Extracts model IDs from the models list.""" - model_ids = [] - for model in models.data: - model_ids.append(model.id) - if not model_ids: - cl.logger.error("No models available.") - return [] - return model_ids - - -def _handle_context_size_limit(err: OpenAIError, - is_api: bool = False) -> str: - if 'reduce the length of the messages or completion' in err.message: - if not is_api : - cl.user_session.set('message_history', '') - return 'Request size with history exceeded limit, ' \ - 'Please start a new thread.' - return str(err) - - -async def get_response(user_message: ThreadMessages, # pylint: disable=too-many-arguments - response_msg: cl.Message, - model_settings: ModelSettings, - is_api: bool = False, - stream_response: bool = True, - step: cl.Step = None) -> bool: - """Send a user's message and generate a response using the LLM. - - Args: - user_message: The user's input message object. - response_msg: The message object to populate with the LLM's - generated response or an error message if something goes wrong. - model_settings: A dictionary containing LLM configuration. - stream_response: Indicates whether we want to stream the response or - get the process in a single chunk. - is_api: Indicates whether the function is called from the API or not. - step: Optional step object to stream reasoning content to. - - Returns: - bool indicating whether the function was successful or not. 
- """ - is_error = True - - try: - if stream_response: - async for stream_resp in await gen_llm.chat.completions.create( - messages=user_message, stream=stream_response, - **model_settings - ): - if stream_resp.choices and len(stream_resp.choices) > 0: - delta = stream_resp.choices[0].delta - - # Stream content to the response message - if token := delta.content or "": - await response_msg.stream_token(token) - - # Stream reasoning content to the step if it exists - if step and hasattr(delta, "reasoning_content") and delta.reasoning_content: - await step.stream_token(delta.reasoning_content) - else: - response = await gen_llm.chat.completions.create( - messages=user_message, stream=stream_response, - **model_settings - ) - response_msg.content = response.choices[0].message.content or "" - - # If we have a step and reasoning content, update the step output - message = response.choices[0].message - if (step and hasattr(message, "reasoning_content") - and message.reasoning_content): - step.output = response.choices[0].message.reasoning_content - - is_error = False - except OpenAIError as e: - err_msg = _handle_context_size_limit(e, is_api) - if not is_api: - cl.logger.error("Error in process_message_and_get_response: %s", - err_msg) - response_msg.content = ( - f"I encountered an error while generating a response: {err_msg}." - ) - return is_error - - -def get_system_prompt_per_profile(profile_name: str) -> str: - """Get the system prompt for the specified profile. - - Args: - profile_name: The name of the profile for which to get the system prompt. - Returns: - The system prompt for the specified profile. - """ - if profile_name == DOCS_PROFILE: - return config.docs_system_prompt - if profile_name in [CI_LOGS_PROFILE, RCA_FULL_PROFILE]: - return config.ci_logs_system_prompt + config.jira_formatting_syntax_prompt - return config.ci_logs_system_prompt diff --git a/src/rca_accelerator_chatbot/models/__init__.py b/src/rca_accelerator_chatbot/models/__init__.py new file mode 100644 index 0000000..bcfa16e --- /dev/null +++ b/src/rca_accelerator_chatbot/models/__init__.py @@ -0,0 +1,25 @@ +"""Communicate with various models needed by the Chatbot. + +This package initializes three singleton instances used to communicate with: + - the generative model + - the embedding model + - the rerank model + +These singletons are initialized using values provided by the user via the +environment variables (see config module). Note that init_model_providers() +must be called by the consumer of this package in order for the all provider's +functionality to be fully available. 
+""" +import asyncio +from rca_accelerator_chatbot.models.generative import gen_model_provider +from rca_accelerator_chatbot.models.embeddings import embed_model_provider +from rca_accelerator_chatbot.models.rerank import rerank_model_provider + + +async def init_model_providers(): + """Initialize all providers""" + await asyncio.gather( + gen_model_provider.init(), + embed_model_provider.init(), + rerank_model_provider.init(), + ) diff --git a/src/rca_accelerator_chatbot/models/embeddings.py b/src/rca_accelerator_chatbot/models/embeddings.py new file mode 100644 index 0000000..ab700e1 --- /dev/null +++ b/src/rca_accelerator_chatbot/models/embeddings.py @@ -0,0 +1,41 @@ +"""Provider for the embedding model.""" + +import chainlit as cl +from openai import OpenAIError + +from rca_accelerator_chatbot.config import config +from rca_accelerator_chatbot.models.model import ModelProvider + + +class EmbeddingsModelProvider(ModelProvider): + """Embeddings model provider""" + + def __init__(self, + base_url: str = config.embeddings_llm_api_url, + api_key: str = config.embeddings_llm_api_key): + super().__init__(base_url, api_key) + + async def generate_embedding(self, text: str, model_name: str) -> None | list[float]: + """Generate embeddings for the given text using the specified model.""" + try: + embedding_response = await self.llm.embeddings.create( + model=model_name, input=text, encoding_format="float" + ) + + if not embedding_response: + cl.logger.error( + "Failed to get embeddings: " + "No response from model %s", model_name + ) + return None + if not embedding_response.data or len(embedding_response.data) == 0: + cl.logger.error( + "Failed to get embeddings: " + "Empty response for model %s", model_name + ) + return None + + return embedding_response.data[0].embedding + except OpenAIError as e: + cl.logger.error("Error generating embeddings: %s", str(e)) + return None + +embed_model_provider = EmbeddingsModelProvider() diff --git a/src/rca_accelerator_chatbot/models/generative.py b/src/rca_accelerator_chatbot/models/generative.py new file mode 100644 index 0000000..264a973 --- /dev/null +++ b/src/rca_accelerator_chatbot/models/generative.py @@ -0,0 +1,93 @@ +"""Provider for the generative model.""" +import chainlit as cl +from openai import OpenAIError + +from rca_accelerator_chatbot.settings import ModelSettings, ThreadMessages +from rca_accelerator_chatbot.config import config +from rca_accelerator_chatbot.models.model import ModelProvider + + +class GenerativeModelProvider(ModelProvider): + """Generative model provider""" + + def __init__(self, + base_url: str = config.generation_llm_api_url, + api_key: str = config.generation_llm_api_key): + super().__init__(base_url, api_key) + + # pylint: disable=too-many-arguments + async def get_response(self, + user_message: ThreadMessages, + response_msg: cl.Message, + model_settings: ModelSettings, + is_api: bool = False, + stream_response: bool = True, + step: cl.Step = None) -> bool: + """Send a user's message and generate a response using the LLM. + + Args: + user_message: The user's input message object. + response_msg: The message object to populate with the LLM's + generated response or an error message if something goes wrong. + model_settings: A dictionary containing LLM configuration. + stream_response: Indicates whether we want to stream the response or + get the process in a single chunk. + is_api: Indicates whether the function is called from the API or not. + step: Optional step object to stream reasoning content to. 
+ + Returns: + bool indicating whether the function was successful or not. + """ + is_error = True + + try: + if stream_response: + async for stream_resp in await self.llm.chat.completions.create( + messages=user_message, stream=stream_response, + **model_settings + ): + if stream_resp.choices and len(stream_resp.choices) > 0: + delta = stream_resp.choices[0].delta + + # Stream content to the response message + if token := delta.content or "": + await response_msg.stream_token(token) + + # Stream reasoning content to the step if it exists + if step and hasattr(delta, "reasoning_content") and delta.reasoning_content: + await step.stream_token(delta.reasoning_content) + else: + response = await self.llm.chat.completions.create( + messages=user_message, stream=stream_response, + **model_settings + ) + response_msg.content = response.choices[0].message.content or "" + + # If we have a step and reasoning content, update the step output + message = response.choices[0].message + if (step and hasattr(message, "reasoning_content") + and message.reasoning_content): + step.output = response.choices[0].message.reasoning_content + + is_error = False + except OpenAIError as e: + err_msg = GenerativeModelProvider._handle_context_size_limit(e, is_api) + if not is_api: + cl.logger.error("Error in process_message_and_get_response: %s", + err_msg) + response_msg.content = ( + f"I encountered an error while generating a response: {err_msg}." + ) + return is_error + + @staticmethod + def _handle_context_size_limit(err: OpenAIError, + is_api: bool = False) -> str: + if 'reduce the length of the messages or completion' in err.message: + if not is_api : + cl.user_session.set('message_history', '') + return 'Request size with history exceeded limit, ' \ + 'Please start a new thread.' + return str(err) + +gen_model_provider = GenerativeModelProvider() diff --git a/src/rca_accelerator_chatbot/models/model.py b/src/rca_accelerator_chatbot/models/model.py new file mode 100644 index 0000000..6b446a7 --- /dev/null +++ b/src/rca_accelerator_chatbot/models/model.py @@ -0,0 +1,146 @@ +"""Base class for the model providers""" +import time +from functools import cached_property +from urllib.parse import urlparse + +from openai import AsyncOpenAI +import httpx + +from rca_accelerator_chatbot.config import config + + +class ModelProvider: + """Base class for the model providers + + It is required to call init() after the class is initialized to ensure that + all the fields are populated and all functionality is available. + """ + + def __init__(self, base_url: str, api_key: str, cache_timeout: int = 60): + self.base_url = base_url + self.api_key = api_key + self.llm = AsyncOpenAI( + base_url=self.base_url, + organization='', + api_key=self.api_key, + ) + + self._api_response = None + self._api_response_time = 0 + self.cache_timeout = cache_timeout + + async def init(self) -> None: + """Initialize the object with the data from the /models api response + + This function must be called after the object is instantiated to ensure + that all the functionality is unlocked. It caches the response from /models + api response. 
+ """ + cache_expired = time.time() > self._api_response_time + self.cache_timeout + if self._api_response is not None and not cache_expired: + return + + # List available models at the API endpoint + models_page = await self.llm.models.list() + self._api_response_time = time.time() + self._api_response = [model.model_dump() for model in models_page.data] + if not self._api_response: + raise RuntimeError(f"No model discovered at {self.base_url}") + + @cached_property + def all_model_names(self) -> list[str]: + """Return all models available at the base_url endpoint.""" + if self._api_response is None: + raise RuntimeError(f"{__class__} was not initialized - init().") + + return [response["id"] for response in self._api_response] + + def get_context_size(self, model_name: str) -> int: + """Return max context size of the model""" + if self._api_response is None: + raise RuntimeError(f"{__class__} was not initialized - init().") + + if model_name not in self.all_model_names: + raise RuntimeError(f"{model_name} model is not available at {self.base_url}") + + for response in self._api_response: + if response["id"] == model_name: + return response["max_model_len"] + + return 0 + + async def get_num_tokens(self, prompt: str, model_name: str) -> int: + """Retrieve the number of tokens required to process the prompt. + + This function calls the /tokenize API endpoint to get the number of + tokens the input will be transformed into when processed by the specified + model (default is the embedding model). + + Args: + prompt: The input text for which to calculate the token count. + model_name: The name of the model whose tokenizer we will use + + Raises: + HTTPStatusError: If the response from the /tokenize API endpoint is + not 200 status code. + """ + headers = {"Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}"} + + data = { + "model": model_name, + "prompt": prompt, + } + + llm_url_parse = urlparse(self.base_url) + tokenize_url = f"{llm_url_parse.scheme}://{llm_url_parse.netloc}/tokenize" + + async with httpx.AsyncClient() as client: + response = await client.post(tokenize_url, headers=headers, json=data) + + if response.status_code == 200: + response_data = response.json() + return response_data["count"] + + response.raise_for_status() + + return 0 + + async def check_message_length(self, message_content: str, model_name: str) -> tuple[bool, str]: + """Check if the message content exceeds the token limit. + + Args: + message_content: The content to check + model_name: The name of the model that should be used for tokenization + + Returns: + A tuple containing: + - bool: True if the message is within length limits, False otherwise + - str: Error message if the length check fails, empty string otherwise + """ + try: + num_required_tokens = await self.get_num_tokens(message_content, model_name) + except httpx.HTTPStatusError: + return False, "We've encountered an issue. Please try again later ..." + + embed_model_max_context = self.get_context_size(model_name) + if num_required_tokens > embed_model_max_context: + # Calculate the maximum character limit estimation for the embedding model. + approx_max_chars = round( + embed_model_max_context * config.chars_per_token_estimation, -2) + + error_message = ( + "⚠️ **Your input is too lengthy!**\n We can process inputs of up " + f"to approximately {approx_max_chars} characters. The exact limit " + "may vary depending on the input type. 
For instance, plain text " + "inputs can be longer compared to logs or structured data " + "containing special characters (e.g., `[`, `]`, `:`, etc.).\n\n" + "To proceed, please:\n" + " - Focus on including only the most relevant details, and\n" + " - Shorten your input if possible." + " \n\n" + "To let you continue, we will reset the conversation history.\n" + "Please start over with a shorter input." + ) + return False, error_message + + return True, "" diff --git a/src/rca_accelerator_chatbot/models/rerank.py b/src/rca_accelerator_chatbot/models/rerank.py new file mode 100644 index 0000000..6e2ab3a --- /dev/null +++ b/src/rca_accelerator_chatbot/models/rerank.py @@ -0,0 +1,73 @@ +"""Provider for the rerank model.""" +import httpx + +from rca_accelerator_chatbot.config import config +from rca_accelerator_chatbot.models.model import ModelProvider + + +class RerankModelProvider(ModelProvider): + """Rerank model provider""" + + def __init__(self, + base_url: str = config.reranking_model_api_url, + api_key: str = config.reranking_model_api_key): + super().__init__(base_url, api_key) + + async def get_rerank_score( + self, + query_text: str, + search_content: str, + model: str, + ) -> float: + """Contact a re-rank model and get a more precise re-ranking score for the search content. + + This function calls the /rerank API endpoint to calculate a new, more accurate + score for the search content. First it chunks the search content to fit + the context of the re-rank model, and then it calculates the score for each + such chunk. The final score is the maximum re-rank score out of all the + chunks. + + Args: + query_text: Query text that the search content should be related to. + search_content: A chunk retrieved from the vector database. + model: Name of the model to use for re-ranking. + + Raises: + HTTPStatusError: If the response from the /rerank API endpoint is + not a 200 status code. + """ + + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {self.api_key}" + } + + # If the search_content is too big, we have to split it. We use half of the + # re-rank model's max context because we have to leave space for the user's + # input.
+ max_chunk_size = self.get_context_size(model) // 2 + sub_chunks = [ + search_content[i:i + max_chunk_size] + for i in range(0, len(search_content), max_chunk_size) + ] + + data = { + "model": model, + "query": query_text, + "documents": sub_chunks, + } + rerank_url = f"{self.base_url}/rerank" + async with httpx.AsyncClient() as client: + response = await client.post(rerank_url, headers=headers, json=data) + + if response.status_code == 200: + response_data = response.json() + if len(response_data["results"]) == 0: + return .0 + return response_data["results"][0].get("relevance_score", .0) + + response.raise_for_status() + + return .0 + +rerank_model_provider = RerankModelProvider() diff --git a/src/rca_accelerator_chatbot/prompt.py b/src/rca_accelerator_chatbot/prompt.py index d1d4f45..e6fabdf 100644 --- a/src/rca_accelerator_chatbot/prompt.py +++ b/src/rca_accelerator_chatbot/prompt.py @@ -6,10 +6,11 @@ from rca_accelerator_chatbot.config import config from rca_accelerator_chatbot.constants import ( - NO_RESULTS_FOUND, SEARCH_RESULTS_TEMPLATE, SEARCH_RESULT_TRUNCATED_CHUNK + NO_RESULTS_FOUND, SEARCH_RESULTS_TEMPLATE, SEARCH_RESULT_TRUNCATED_CHUNK, + DOCS_PROFILE, CI_LOGS_PROFILE, RCA_FULL_PROFILE ) +from rca_accelerator_chatbot.models import gen_model_provider from rca_accelerator_chatbot.settings import HistorySettings, ThreadMessages -from rca_accelerator_chatbot.generation import get_system_prompt_per_profile def search_result_to_str(search_result: dict) -> str: @@ -40,6 +41,7 @@ async def build_prompt( user_message: str, profile_name: str, history_settings: HistorySettings, + model_name: str, ) -> (bool, ThreadMessages): """Generate a full prompt that gets sent to the generative model. @@ -59,7 +61,7 @@ async def build_prompt( user_message: The user's message content history_settings: Settings for the message history profile_name: The name of the profile for which to generate the prompt - + model_name: The name of the generative model Returns: tuple: A tuple containing two elements: - List of messages that make up the full prompt. Each return value @@ -84,7 +86,7 @@ async def build_prompt( full_prompt_len += len(str(full_prompt)) # Calculate the maximum character limit estimation for the generative model. - approx_max_chars = (config.generative_model_max_context * + approx_max_chars = (gen_model_provider.get_context_size(model_name) * config.generative_model_max_context_percentage * config.chars_per_token_estimation) @@ -138,3 +140,17 @@ async def build_prompt( )) return is_error, full_prompt + +def get_system_prompt_per_profile(profile_name: str) -> str: + """Get the system prompt for the specified profile. + + Args: + profile_name: The name of the profile for which to get the system prompt. + Returns: + The system prompt for the specified profile. + """ + if profile_name == DOCS_PROFILE: + return config.docs_system_prompt + if profile_name in [CI_LOGS_PROFILE, RCA_FULL_PROFILE]: + return config.ci_logs_system_prompt + config.jira_formatting_syntax_prompt + return config.ci_logs_system_prompt
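
Usage sketch: a minimal example, assuming the configured generation, embeddings, and re-rank endpoints are reachable and serve at least one model each, of how the models package introduced by this patch is meant to be driven. The input strings are illustrative only. init_model_providers() populates the cached /models responses, and the three provider singletons replace the old generation.py and embeddings.py helpers.

import asyncio

from rca_accelerator_chatbot.models import (
    init_model_providers, gen_model_provider, embed_model_provider, rerank_model_provider
)


async def main() -> None:
    # Required before using the providers: fetches and caches the /models responses.
    await init_model_providers()

    # Model discovery replaces the old discover_*_model_names() helpers.
    gen_model = gen_model_provider.all_model_names[0]
    embed_model = embed_model_provider.all_model_names[0]
    rerank_model = rerank_model_provider.all_model_names[0]
    print("Generative model context size:", gen_model_provider.get_context_size(gen_model))

    # Token-limit checking and embedding generation now live on the embeddings provider.
    text = "the build job failed with an out-of-memory error"  # illustrative input
    ok, err = await embed_model_provider.check_message_length(text, embed_model)
    if not ok:
        print(err)
        return
    embedding = await embed_model_provider.generate_embedding(text, embed_model)
    print("Embedding length:", None if embedding is None else len(embedding))

    # Re-rank scoring now takes the model name explicitly instead of reading it from config.
    score = await rerank_model_provider.get_rerank_score(
        text, "worker was killed by the OOM killer during the compile step", rerank_model
    )
    print("Re-rank score:", score)


if __name__ == "__main__":
    asyncio.run(main())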