From 78ce8185a66d36b22d46bd549eb2af874833bab2 Mon Sep 17 00:00:00 2001 From: Lukas Piwowarski Date: Thu, 15 May 2025 13:32:30 +0200 Subject: [PATCH 1/2] Add rerank_model to API interface Let's add the option to pick up the rerank_model into the API. If no rerank model is specified, we pick the first available at the given endpoint. --- src/rca_accelerator_chatbot/api.py | 20 +++++++++++++++++++- src/rca_accelerator_chatbot/chat.py | 4 ++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/src/rca_accelerator_chatbot/api.py b/src/rca_accelerator_chatbot/api.py index 056435d..ccc9710 100644 --- a/src/rca_accelerator_chatbot/api.py +++ b/src/rca_accelerator_chatbot/api.py @@ -18,7 +18,7 @@ from rca_accelerator_chatbot.settings import ModelSettings from rca_accelerator_chatbot.auth import authentification from rca_accelerator_chatbot.models import ( - gen_model_provider, embed_model_provider, init_model_providers + gen_model_provider, embed_model_provider, rerank_model_provider, init_model_providers ) app = FastAPI(title="RCAccelerator API") @@ -34,6 +34,7 @@ class BaseModelSettings(BaseModel): max_tokens: int = Field(config.default_max_tokens, gt=1, le=1024) generative_model_name: str = Field("") embeddings_model_name: str = Field("") + rerank_model_name: str = Field("") profile_name: str = Field(CI_LOGS_PROFILE) enable_rerank: bool = Field(config.enable_rerank) @@ -76,6 +77,15 @@ async def validate_settings(request: BaseModelSettings) -> BaseModelSettings: detail=f"Invalid embeddings model. Available: {available_embedding_models}" ) + available_rerank_models = rerank_model_provider.all_model_names + if not request.rerank_model_name: + request.rerank_model_name = available_rerank_models[0] + elif request.rerank_model_name not in available_rerank_models: + raise HTTPException( + status_code=400, + detail=f"Invalid rerank model. Available: {available_rerank_models}" + ) + if request.profile_name not in [CI_LOGS_PROFILE, DOCS_PROFILE, RCA_FULL_PROFILE]: raise HTTPException( status_code=400, @@ -224,12 +234,16 @@ async def process_prompt( embeddings_model_settings: ModelSettings = { "model": message_data.embeddings_model_name, } + rerank_model_settings: ModelSettings = { + "model": message_data.rerank_model_name, + } response = await handle_user_message_api( message_data.content, message_data.similarity_threshold, generative_model_settings, embeddings_model_settings, + rerank_model_settings, message_data.profile_name, message_data.enable_rerank, ) @@ -263,6 +277,9 @@ async def process_rca( embeddings_model_settings: ModelSettings = { "model": request.embeddings_model_name, } + rerank_model_settings: ModelSettings = { + "model": request.rerank_model_name, + } unique_items = {} for item in traceback_items: @@ -278,6 +295,7 @@ async def process_rca( similarity_threshold=request.similarity_threshold, generative_model_settings=generative_model_settings, embeddings_model_settings=embeddings_model_settings, + rerank_model_settings=rerank_model_settings, profile_name=request.profile_name, enable_rerank=request.enable_rerank, ) diff --git a/src/rca_accelerator_chatbot/chat.py b/src/rca_accelerator_chatbot/chat.py index bcc4f73..9601848 100644 --- a/src/rca_accelerator_chatbot/chat.py +++ b/src/rca_accelerator_chatbot/chat.py @@ -330,6 +330,7 @@ async def handle_user_message_api( # pylint: disable=too-many-arguments similarity_threshold: float, generative_model_settings: ModelSettings, embeddings_model_settings: ModelSettings, + rerank_model_settings: ModelSettings, profile_name: str, enable_rerank: bool = True, ) -> MockMessage: @@ -362,6 +363,9 @@ async def handle_user_message_api( # pylint: disable=too-many-arguments settings={ "enable_rerank": enable_rerank, "rerank_top_n": config.rerank_top_n, + "rerank_model": rerank_model_settings["model"], + "generative_model": generative_model_settings["model"], + "embeddings_model": embeddings_model_settings["model"], }, ) except httpx.HTTPStatusError: From 158f618807ce31fc00e5e188251e8748655333a2 Mon Sep 17 00:00:00 2001 From: Lukas Piwowarski Date: Thu, 15 May 2025 13:37:20 +0200 Subject: [PATCH 2/2] Fix import order Tox is complaining about import order. Let's fix it. --- src/rca_accelerator_chatbot/chat.py | 5 +++-- src/rca_accelerator_chatbot/models/embeddings.py | 2 +- src/rca_accelerator_chatbot/models/generative.py | 2 +- src/rca_accelerator_chatbot/vectordb.py | 3 ++- 4 files changed, 7 insertions(+), 5 deletions(-) diff --git a/src/rca_accelerator_chatbot/chat.py b/src/rca_accelerator_chatbot/chat.py index 9601848..190f660 100644 --- a/src/rca_accelerator_chatbot/chat.py +++ b/src/rca_accelerator_chatbot/chat.py @@ -1,8 +1,9 @@ """Handler for chat messages and responses.""" from dataclasses import dataclass -import chainlit as cl -import httpx + from openai.types.chat import ChatCompletionAssistantMessageParam +import httpx +import chainlit as cl from rca_accelerator_chatbot import constants from rca_accelerator_chatbot.prompt import build_prompt diff --git a/src/rca_accelerator_chatbot/models/embeddings.py b/src/rca_accelerator_chatbot/models/embeddings.py index ab700e1..eb426a1 100644 --- a/src/rca_accelerator_chatbot/models/embeddings.py +++ b/src/rca_accelerator_chatbot/models/embeddings.py @@ -1,7 +1,7 @@ """Provider for the embedding model.""" -import chainlit as cl from openai import OpenAIError +import chainlit as cl from rca_accelerator_chatbot.config import config from rca_accelerator_chatbot.models.model import ModelProvider diff --git a/src/rca_accelerator_chatbot/models/generative.py b/src/rca_accelerator_chatbot/models/generative.py index 264a973..6887f3e 100644 --- a/src/rca_accelerator_chatbot/models/generative.py +++ b/src/rca_accelerator_chatbot/models/generative.py @@ -1,6 +1,6 @@ """Provider for the generative model.""" -import chainlit as cl from openai import OpenAIError +import chainlit as cl from rca_accelerator_chatbot.settings import ModelSettings, ThreadMessages from rca_accelerator_chatbot.config import config diff --git a/src/rca_accelerator_chatbot/vectordb.py b/src/rca_accelerator_chatbot/vectordb.py index 353c48a..8e1fcbc 100644 --- a/src/rca_accelerator_chatbot/vectordb.py +++ b/src/rca_accelerator_chatbot/vectordb.py @@ -1,9 +1,10 @@ """Vector database client for RAG operations.""" from typing import List -import chainlit as cl + from qdrant_client import QdrantClient from qdrant_client.http.exceptions import ApiException +import chainlit as cl from rca_accelerator_chatbot.config import config