src/rca_accelerator_chatbot/api.py (19 additions, 1 deletion)
@@ -18,7 +18,7 @@
 from rca_accelerator_chatbot.settings import ModelSettings
 from rca_accelerator_chatbot.auth import authentification
 from rca_accelerator_chatbot.models import (
-    gen_model_provider, embed_model_provider, init_model_providers
+    gen_model_provider, embed_model_provider, rerank_model_provider, init_model_providers
 )
 
 app = FastAPI(title="RCAccelerator API")
@@ -34,6 +34,7 @@ class BaseModelSettings(BaseModel):
     max_tokens: int = Field(config.default_max_tokens, gt=1, le=1024)
     generative_model_name: str = Field("")
     embeddings_model_name: str = Field("")
+    rerank_model_name: str = Field("")
     profile_name: str = Field(CI_LOGS_PROFILE)
     enable_rerank: bool = Field(config.enable_rerank)
 
@@ -76,6 +77,15 @@ async def validate_settings(request: BaseModelSettings) -> BaseModelSettings:
             detail=f"Invalid embeddings model. Available: {available_embedding_models}"
         )
 
+    available_rerank_models = rerank_model_provider.all_model_names
+    if not request.rerank_model_name:
+        request.rerank_model_name = available_rerank_models[0]
+    elif request.rerank_model_name not in available_rerank_models:
+        raise HTTPException(
+            status_code=400,
+            detail=f"Invalid rerank model. Available: {available_rerank_models}"
+        )
+
     if request.profile_name not in [CI_LOGS_PROFILE, DOCS_PROFILE, RCA_FULL_PROFILE]:
         raise HTTPException(
             status_code=400,
@@ -224,12 +234,16 @@ async def process_prompt(
     embeddings_model_settings: ModelSettings = {
         "model": message_data.embeddings_model_name,
     }
+    rerank_model_settings: ModelSettings = {
+        "model": message_data.rerank_model_name,
+    }
 
     response = await handle_user_message_api(
         message_data.content,
         message_data.similarity_threshold,
         generative_model_settings,
         embeddings_model_settings,
+        rerank_model_settings,
         message_data.profile_name,
         message_data.enable_rerank,
     )
@@ -263,6 +277,9 @@ async def process_rca(
     embeddings_model_settings: ModelSettings = {
         "model": request.embeddings_model_name,
     }
+    rerank_model_settings: ModelSettings = {
+        "model": request.rerank_model_name,
+    }
 
     unique_items = {}
     for item in traceback_items:
@@ -278,6 +295,7 @@
         similarity_threshold=request.similarity_threshold,
         generative_model_settings=generative_model_settings,
         embeddings_model_settings=embeddings_model_settings,
+        rerank_model_settings=rerank_model_settings,
         profile_name=request.profile_name,
         enable_rerank=request.enable_rerank,
     )
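For reference, the rerank check added above follows the same fallback-or-reject pattern already used for the generative and embeddings models: an empty name defaults to the first available model, an unknown name is rejected with HTTP 400. A minimal sketch of that pattern as a helper is shown below; resolve_model_name is a hypothetical name for illustration only, not something this PR adds.

from fastapi import HTTPException

def resolve_model_name(requested: str, available: list[str], kind: str) -> str:
    """Return the requested model name, defaulting to the first available one.

    Hypothetical helper mirroring the validate_settings logic in this diff:
    an empty name falls back to available[0]; an unknown name raises HTTP 400.
    """
    if not requested:
        return available[0]
    if requested not in available:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid {kind} model. Available: {available}",
        )
    return requested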
src/rca_accelerator_chatbot/chat.py (7 additions, 2 deletions)
@@ -1,8 +1,9 @@
 """Handler for chat messages and responses."""
 from dataclasses import dataclass
-import chainlit as cl
-import httpx
 
 from openai.types.chat import ChatCompletionAssistantMessageParam
+import httpx
+import chainlit as cl
+
 from rca_accelerator_chatbot import constants
 from rca_accelerator_chatbot.prompt import build_prompt
@@ -330,6 +331,7 @@ async def handle_user_message_api( # pylint: disable=too-many-arguments
     similarity_threshold: float,
     generative_model_settings: ModelSettings,
     embeddings_model_settings: ModelSettings,
+    rerank_model_settings: ModelSettings,
     profile_name: str,
     enable_rerank: bool = True,
 ) -> MockMessage:
@@ -362,6 +364,9 @@ async def handle_user_message_api( # pylint: disable=too-many-arguments
             settings={
                 "enable_rerank": enable_rerank,
                 "rerank_top_n": config.rerank_top_n,
+                "rerank_model": rerank_model_settings["model"],
+                "generative_model": generative_model_settings["model"],
+                "embeddings_model": embeddings_model_settings["model"],
             },
         )
     except httpx.HTTPStatusError:
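A rough sketch of how a caller might invoke the extended handle_user_message_api signature after this change. The model names, threshold, and profile string are placeholders, and the assumption that each ModelSettings value is a mapping with a "model" key is taken from how this diff indexes them.

import asyncio

from rca_accelerator_chatbot.chat import handle_user_message_api

async def demo() -> None:
    # Placeholder model names and profile; substitute whatever the deployed
    # model providers and profiles actually expose.
    response = await handle_user_message_api(
        "Why did the integration job fail?",    # user message content
        0.7,                                    # similarity_threshold
        {"model": "example-generative-model"},  # generative_model_settings
        {"model": "example-embeddings-model"},  # embeddings_model_settings
        {"model": "example-rerank-model"},      # rerank_model_settings (new)
        "example-profile",                      # profile_name
        enable_rerank=True,
    )
    print(response)

# asyncio.run(demo())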
src/rca_accelerator_chatbot/models/embeddings.py (1 addition, 1 deletion)
@@ -1,7 +1,7 @@
 """Provider for the embedding model."""
 
-import chainlit as cl
 from openai import OpenAIError
+import chainlit as cl
 
 from rca_accelerator_chatbot.config import config
 from rca_accelerator_chatbot.models.model import ModelProvider
src/rca_accelerator_chatbot/models/generative.py (1 addition, 1 deletion)
@@ -1,6 +1,6 @@
 """Provider for the generative model."""
-import chainlit as cl
 from openai import OpenAIError
+import chainlit as cl
 
 from rca_accelerator_chatbot.settings import ModelSettings, ThreadMessages
 from rca_accelerator_chatbot.config import config
src/rca_accelerator_chatbot/vectordb.py (2 additions, 1 deletion)
@@ -1,9 +1,10 @@
 """Vector database client for RAG operations."""
 
 from typing import List
-import chainlit as cl
 
 from qdrant_client import QdrantClient
 from qdrant_client.http.exceptions import ApiException
+import chainlit as cl
+
 from rca_accelerator_chatbot.config import config
 