Merged
19 changes: 11 additions & 8 deletions src/rca_accelerator_chatbot/api.py
@@ -12,15 +12,14 @@
from fastapi.security import APIKeyHeader
from pydantic import BaseModel, Field, HttpUrl

from rca_accelerator_chatbot.constants import (
CI_LOGS_PROFILE, DOCS_PROFILE, RCA_FULL_PROFILE
)
from rca_accelerator_chatbot.constants import CI_LOGS_PROFILE, DOCS_PROFILE, RCA_FULL_PROFILE
from rca_accelerator_chatbot.chat import handle_user_message_api
from rca_accelerator_chatbot.config import config
from rca_accelerator_chatbot.settings import ModelSettings
from rca_accelerator_chatbot.generation import discover_generative_model_names
from rca_accelerator_chatbot.embeddings import discover_embeddings_model_names
from rca_accelerator_chatbot.auth import authentification
from rca_accelerator_chatbot.models import (
gen_model_provider, embed_model_provider, init_model_providers
)

app = FastAPI(title="RCAccelerator API")

@@ -52,10 +51,14 @@ class RcaRequest(BaseModelSettings):
async def validate_settings(request: BaseModelSettings) -> BaseModelSettings:
"""Validate the settings for any request.
This function performs checks to ensure the API request is valid.
Some checks are performed asynchronously which is why we don't use
Some checks are performed asynchronously, which is why we don't use
the built-in Pydantic validators.
"""
available_generative_models = await discover_generative_model_names()
# Make sure we pull the latest info about running models. Note that the responses
# are cached by the providers for a certain amount of time.
await init_model_providers()

available_generative_models = gen_model_provider.all_model_names
if not request.generative_model_name:
request.generative_model_name = available_generative_models[0]
elif request.generative_model_name not in available_generative_models:
@@ -64,7 +67,7 @@ async def validate_settings(request: BaseModelSettings) -> BaseModelSettings:
detail=f"Invalid generative model. Available: {available_generative_models}"
)

available_embedding_models = await discover_embeddings_model_names()
available_embedding_models = embed_model_provider.all_model_names
if not request.embeddings_model_name:
request.embeddings_model_name = available_embedding_models[0]
elif request.embeddings_model_name not in available_embedding_models:
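
The api.py hunks replace the per-module discovery helpers (discover_generative_model_names, discover_embeddings_model_names) with provider singletons imported from rca_accelerator_chatbot.models. That module is not part of the hunks shown in this PR; the sketch below is only inferred from the call sites above, so the names, signatures and refresh behaviour are assumptions rather than the merged implementation.

# Hypothetical sketch of src/rca_accelerator_chatbot/models.py (not shown in
# this PR), inferred from how api.py, app.py and chat.py call it.
from rca_accelerator_chatbot.config import config


class ModelProvider:
    """Discovers and caches the model names served by one OpenAI-compatible endpoint."""

    def __init__(self, api_url: str, api_key: str) -> None:
        self.api_url = api_url
        self.api_key = api_key
        self.all_model_names: list[str] = []

    async def refresh(self) -> None:
        # Would query the endpoint's /models route and cache the names; the
        # real implementation (and its response caching) is not in this diff.
        raise NotImplementedError


# Module-level singletons imported throughout the PR; a rerank_model_provider
# is presumably created the same way (its configuration source is not visible here).
gen_model_provider = ModelProvider(config.generation_llm_api_url,
                                   config.generation_llm_api_key)
embed_model_provider = ModelProvider(config.embeddings_llm_api_url,
                                     config.embeddings_llm_api_key)


async def init_model_providers() -> None:
    """Refresh every provider; upstream responses may be cached for a while."""
    await gen_model_provider.refresh()
    await embed_model_provider.refresh()
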
21 changes: 14 additions & 7 deletions src/rca_accelerator_chatbot/app.py
@@ -9,8 +9,9 @@
from rca_accelerator_chatbot import constants
from rca_accelerator_chatbot.chat import handle_user_message
from rca_accelerator_chatbot.auth import authentification
from rca_accelerator_chatbot.generation import discover_generative_model_names
from rca_accelerator_chatbot.embeddings import discover_embeddings_model_names
from rca_accelerator_chatbot.models import (
init_model_providers, gen_model_provider, embed_model_provider, rerank_model_provider
)


@cl.set_chat_profiles
@@ -63,7 +64,6 @@ async def chat_profile() -> list[cl.ChatProfile]:
)
]


@cl.on_chat_start
async def init_chat():
"""
@@ -73,17 +73,18 @@ async def init_chat():
"""

cl.user_session.set("counter", 0)
await init_model_providers()
await setup_chat_settings()


async def setup_chat_settings():
"""
Set up the chat settings interface with model selection,
temperature, token limits, and other configuration options.
"""
generative_model_names = await discover_generative_model_names()
embeddings_model_names = await discover_embeddings_model_names()
if not generative_model_names or not embeddings_model_names:
generative_model_names = gen_model_provider.all_model_names
embeddings_model_names = embed_model_provider.all_model_names
rerank_model_names = rerank_model_provider.all_model_names
if not generative_model_names or not embeddings_model_names or not rerank_model_names:
await cl.Message(
content="No generative or embeddings model found. "
"Please check your configuration."
@@ -104,6 +105,12 @@ async def setup_chat_settings():
values=embeddings_model_names,
initial_index=0,
),
Select(
id="rerank_model",
label="Re-rank model",
values=rerank_model_names,
initial_index=0
),
Slider(
id="temperature",
label="Model Temperature",
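
app.py now seeds the settings panel from the providers and adds a "Re-rank model" selector. For orientation, a hedged sketch of how those widget values would reach chat.py: the on_settings_update hook is standard Chainlit, while storing them under the "settings" session key is an assumption based on the cl.user_session.get("settings") call removed from chat.py later in this diff.

# Hedged sketch: widget ids defined above become keys of the settings dict.
import chainlit as cl


@cl.on_settings_update
async def on_settings_update(settings: dict) -> None:
    # Chainlit passes widget values keyed by their ids, so downstream code can
    # read settings["generative_model"], settings["embeddings_model"] and the
    # new settings["rerank_model"] added by this PR.
    cl.user_session.set("settings", settings)
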
97 changes: 26 additions & 71 deletions src/rca_accelerator_chatbot/chat.py
@@ -1,20 +1,17 @@
"""Handler for chat messages and responses."""
from dataclasses import dataclass
import chainlit as cl
from chainlit.context import ChainlitContextException
import httpx
from openai.types.chat import ChatCompletionAssistantMessageParam

from rca_accelerator_chatbot import constants
from rca_accelerator_chatbot.prompt import build_prompt
from rca_accelerator_chatbot.vectordb import vector_store
from rca_accelerator_chatbot.generation import get_response
from rca_accelerator_chatbot.embeddings import (
get_num_tokens, generate_embedding,
get_rerank_score, get_default_embeddings_model_name
)
from rca_accelerator_chatbot.settings import ModelSettings, HistorySettings, ThreadMessages
from rca_accelerator_chatbot.config import config
from rca_accelerator_chatbot.models import (
gen_model_provider, embed_model_provider, rerank_model_provider
)
from rca_accelerator_chatbot.constants import (
DOCS_PROFILE,
RCA_FULL_PROFILE,
@@ -56,7 +53,9 @@ async def perform_multi_collection_search(
collections: A list of collections to search.
settings: The settings user provided through the UI.
"""
embedding = await generate_embedding(message_content, embeddings_model_name)
embedding = await embed_model_provider.generate_embedding(
message_content, embeddings_model_name
)
if embedding is None:
return []

@@ -69,7 +68,9 @@
for r in results:
r['collection'] = collection
if settings['enable_rerank']:
r['rerank_score'] = await get_rerank_score(message_content, r['text'])
r['rerank_score'] = await rerank_model_provider.get_rerank_score(
message_content, r['text'], settings["rerank_model"]
)
else:
r['rerank_score'] = None

@@ -119,49 +120,6 @@ def update_msg_count():
counter += 1
cl.user_session.set("counter", counter)


async def check_message_length(message_content: str) -> tuple[bool, str]:
"""
Check if the message content exceeds the token limit.

Args:
message_content: The content to check

Returns:
A tuple containing:
- bool: True if the message is within length limits, False otherwise
- str: Error message if the length check fails, empty string otherwise
"""
try:
num_required_tokens = await get_num_tokens(message_content,
await get_embeddings_model_name())
except httpx.HTTPStatusError as e:
cl.logger.error(e)
return False, "We've encountered an issue. Please try again later ..."

if num_required_tokens > config.embeddings_llm_max_context:
# Calculate the maximum character limit estimation for the embedding model.
approx_max_chars = round(
config.embeddings_llm_max_context * config.chars_per_token_estimation, -2)

error_message = (
"⚠️ **Your input is too lengthy!**\n We can process inputs of up "
f"to approximately {approx_max_chars} characters. The exact limit "
"may vary depending on the input type. For instance, plain text "
"inputs can be longer compared to logs or structured data "
"containing special characters (e.g., `[`, `]`, `:`, etc.).\n\n"
"To proceed, please:\n"
" - Focus on including only the most relevant details, and\n"
" - Shorten your input if possible."
" \n\n"
"To let you continue, we will reset the conversation history.\n"
"Please start over with a shorter input."
)
return False, error_message

return True, ""


async def print_debug_content(
settings: dict,
search_content: str,
@@ -192,8 +150,9 @@ async def print_debug_content(
)

# Display the number of tokens in the search content
num_t = await get_num_tokens(search_content,
await get_embeddings_model_name())
num_t = await embed_model_provider.get_num_tokens(
search_content, settings["embeddings_model"]
)
debug_content += f"**Number of tokens in search content:** {num_t}\n\n"

# Display vector DB debug information if debug mode is enabled
@@ -229,13 +188,13 @@ def _build_search_content(message_history: ThreadMessages,
if message['role'] == 'user':
previous_message_content += f"\n{message['content']}"

# Limit the size of the message content as this is passed as query to
# the reranking model. This is brute truncation, but can be improved
# Limit the size of the message content as this is passed as a query to
# the reranking model. This is brute truncation but can be improved
# when we handle message history better.
if settings['enable_rerank']:
max_text_len_from_history = (
config.reranking_model_max_context // 2 - len(
current_message)) - 1
rerank_model_provider.get_context_size(settings["rerank_model"]) // 2
- len(current_message)) - 1
return previous_message_content[
:max_text_len_from_history] + current_message
return previous_message_content + current_message
@@ -271,8 +230,8 @@ async def handle_user_message( # pylint: disable=too-many-locals,too-many-statem
settings)

# Check message length
is_valid_length, error_message = await check_message_length(
search_content)
is_valid_length, error_message = await embed_model_provider.check_message_length(
search_content, settings["embeddings_model"])
if not is_valid_length:
resp.content = error_message
# Reset message history to let the user try again
@@ -296,7 +255,7 @@ async def handle_user_message( # pylint: disable=too-many-locals,too-many-statem
try:
search_results = await perform_multi_collection_search(
search_content,
await get_embeddings_model_name(),
settings["embeddings_model"],
get_similarity_threshold(),
collections,
settings,
Expand All @@ -320,6 +279,7 @@ async def handle_user_message( # pylint: disable=too-many-locals,too-many-statem
keep_history=settings["keep_history"],
message_history=message_history,
),
settings["generative_model"],
)

if is_error_prompt:
@@ -337,7 +297,7 @@ async def handle_user_message( # pylint: disable=too-many-locals,too-many-statem
temperature = settings["temperature"]
else:
temperature = config.default_temperature_without_search_results
is_error = await get_response(
is_error = await gen_model_provider.get_response(
full_prompt,
resp,
{
@@ -379,7 +339,9 @@ async def handle_user_message_api( # pylint: disable=too-many-arguments
response = MockMessage(content="", urls=[])

# Check message length
is_valid_length, error_message = await check_message_length(message_content)
is_valid_length, error_message = await embed_model_provider.check_message_length(
message_content, embeddings_model_settings["model"]
)
if not is_valid_length:
response.content = error_message
return response
@@ -418,9 +380,10 @@ async def handle_user_message_api( # pylint: disable=too-many-arguments
keep_history=False,
message_history=[],
),
generative_model_settings["model"],
)
# Process user message and get AI response
is_error = await get_response(
is_error = await gen_model_provider.get_response(
full_prompt,
response,
generative_model_settings,
@@ -455,14 +418,6 @@ def get_similarity_threshold() -> float:
# If threshold is above 1, cap it at 1
return min(threshold, 1.0)

async def get_embeddings_model_name() -> str:
"""Get name of the embeddings model."""
try:
settings = cl.user_session.get("settings")
return settings.get("embeddings_model")
except ChainlitContextException:
return await get_default_embeddings_model_name()

def check_collections(collections_to_check: list[str]) -> str:
"""
Verify if the specified collections exist in the vector store.
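
chat.py drops its module-level helpers (check_message_length, get_embeddings_model_name) in favour of provider methods, and the model names now come straight from the UI settings. Below is a hedged sketch of what embed_model_provider.check_message_length could look like, reassembled from the deleted helper above; only the signature follows from the new call sites, the body is an assumption and the merged models.py may differ.

# Hypothetical provider method, reconstructed from the removed check_message_length().
import httpx
import chainlit as cl

from rca_accelerator_chatbot.config import config


class EmbeddingModelProvider:
    async def get_num_tokens(self, text: str, model_name: str) -> int:
        # Tokeniser call against the embeddings endpoint, as before (omitted).
        raise NotImplementedError

    async def check_message_length(
            self, message_content: str, model_name: str) -> tuple[bool, str]:
        """Return (is_valid, error_message), mirroring the removed helper."""
        try:
            num_required_tokens = await self.get_num_tokens(
                message_content, model_name)
        except httpx.HTTPStatusError as e:
            cl.logger.error(e)
            return False, "We've encountered an issue. Please try again later ..."

        if num_required_tokens > config.embeddings_llm_max_context:
            # Rough character estimate for the embedding model's context window.
            approx_max_chars = round(
                config.embeddings_llm_max_context
                * config.chars_per_token_estimation, -2)
            return False, (
                "⚠️ **Your input is too lengthy!** We can process inputs of up "
                f"to approximately {approx_max_chars} characters; please shorten it."
            )
        return True, ""
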
19 changes: 0 additions & 19 deletions src/rca_accelerator_chatbot/config.py
@@ -19,14 +19,10 @@ class Config:
generation_llm_api_url: str
generation_llm_api_key: str
enable_rerank: bool
reranking_model_name: str
reranking_model_api_key: str
reranking_model_api_url: str
reranking_model_max_context: int
embeddings_llm_api_url: str
embeddings_llm_api_key: str
embeddings_llm_max_context: int
generative_model_max_context: int
default_temperature: float
default_temperature_without_search_results: float
default_max_tokens: int
@@ -62,31 +58,16 @@ def from_env(cls) -> 'Config':
"GENERATION_LLM_API_KEY", ""),
enable_rerank=os.environ.get(
"ENABLE_RERANK", "true").lower() == "true",
reranking_model_name=os.environ.get(
"RERANKING_MODEL_NAME", "BAAI/bge-reranker-v2-m3"
),
reranking_model_api_url=os.environ.get(
"RERANKING_MODEL_API_URL", "http://localhost:8001/v1"
),
reranking_model_api_key=os.environ.get(
"RERANKING_MODEL_API_KEY", ""
),
reranking_model_max_context=int(os.environ.get(
"RERANKING_MODEL_MAX_CONTEXT",
8192,
)),
embeddings_llm_api_url=os.environ.get(
"EMBEDDINGS_LLM_API_URL", "http://localhost:8000/v1"),
embeddings_llm_api_key=os.environ.get(
"EMBEDDINGS_LLM_API_KEY", ""),
embeddings_llm_max_context=int(os.environ.get(
"EMBEDDINGS_LLM_MAX_CONTEXT",
8192,
)),
generative_model_max_context=int(os.environ.get(
"GENERATIVE_MODEL_MAX_CONTEXT",
32000,
)),
default_temperature=float(
os.environ.get("DEFAULT_MODEL_TEMPERATURE", 0.3)),
default_temperature_without_search_results=float(
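
config.py no longer carries the RERANKING_MODEL_* settings or GENERATIVE_MODEL_MAX_CONTEXT. Judging by the rerank_model_provider.get_context_size(settings["rerank_model"]) call in chat.py, per-model limits now live with the providers; the accessor below is a hedged sketch of that idea and entirely an assumption, since models.py is not among the hunks shown.

# Sketch only: per-model metadata on the provider replaces the removed env vars.
class RerankModelProvider:
    def __init__(self) -> None:
        # Context sizes keyed by model name, e.g. populated during
        # init_model_providers(); 8192 mirrors the old RERANKING_MODEL_MAX_CONTEXT default.
        self._context_sizes: dict[str, int] = {}

    def get_context_size(self, model_name: str, default: int = 8192) -> int:
        return self._context_sizes.get(model_name, default)
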