From 95b7a554e627acee24ce38d5038a7caa0f95f18a Mon Sep 17 00:00:00 2001 From: Elijah Williams Date: Mon, 18 Aug 2025 15:05:37 -0600 Subject: [PATCH 01/12] wip on a regenerate response functionality --- .../app/routers/index/sessions/__init__.py | 50 +++++++++++ llm-service/app/services/chat/chat.py | 63 +------------- .../app/services/chat/streaming_chat.py | 83 +++++++++++++++++-- .../chat_history/chat_history_manager.py | 16 ++++ .../chat_history/s3_chat_history_manager.py | 45 ++++++++++ .../simple_chat_history_manager.py | 60 ++++++++++++++ ui/src/api/chatApi.ts | 54 ++++++------ .../ChatOutput/ChatMessages/ChatMessage.tsx | 14 +++- .../ChatMessages/ChatMessageBody.tsx | 7 ++ .../ChatMessages/RegenerateButton.tsx | 67 +++++++++++++++ 10 files changed, 365 insertions(+), 94 deletions(-) create mode 100644 ui/src/pages/RagChatTab/ChatOutput/ChatMessages/RegenerateButton.tsx diff --git a/llm-service/app/routers/index/sessions/__init__.py b/llm-service/app/routers/index/sessions/__init__.py index 7a1c274af..efc4e62cb 100644 --- a/llm-service/app/routers/index/sessions/__init__.py +++ b/llm-service/app/routers/index/sessions/__init__.py @@ -66,6 +66,9 @@ from ....services.mlflow import rating_mlflow_log_metric, feedback_mlflow_log_table from ....services.query.chat_events import ChatEvent from ....services.session import rename_session +from ....services.metadata_apis import session_metadata_api +from ....services import llm_completion +from ....services.chat_history.chat_history_manager import RagMessage logger = logging.getLogger(__name__) router = APIRouter(prefix="/sessions/{session_id}", tags=["Sessions"]) @@ -152,6 +155,44 @@ def chat_history( ) +class RegenerateRequest(BaseModel): + message_id: str + + +@router.post( + "/chat-history/{message_id}/regenerate", + summary="Regenerate an assistant message by message ID and update chat history", +) +@exceptions.propagates +def regenerate_message(session_id: int, message_id: str, remote_user: Optional[str] = Header(None)) -> RagStudioChatMessage: + # Load session + session = session_metadata_api.get_session(session_id, user_name=remote_user) + + # Find existing message + messages: list[RagStudioChatMessage] = chat_history_manager.retrieve_chat_history(session_id=session_id) + target: Optional[RagStudioChatMessage] = next((m for m in messages if m.id == message_id), None) + if target is None: + raise HTTPException(status_code=404, detail="Message not found") + + # Regenerate assistant response for the same user message + completion = llm_completion.completion(session_id=session_id, question=target.rag_message.user, model_name=session.inference_model) + + updated = RagStudioChatMessage( + id=target.id, + session_id=session_id, + source_nodes=target.source_nodes, + inference_model=session.inference_model, + evaluations=[], + rag_message=RagMessage(user=target.rag_message.user, assistant=str(completion.message.content)), + timestamp=time.time(), + condensed_question=None, + ) + + # Persist update in-place + chat_history_manager.update_message(session_id=session_id, message_id=message_id, message=updated) + return updated + + @router.get( "/chat-history/{message_id}", summary="Returns a specific chat messages for the provided session.", @@ -286,12 +327,21 @@ def generate_stream() -> Generator[str, None, None]: try: executor = ThreadPoolExecutor(max_workers=1) + # If a response_id is provided in the request (e.g., regenerate), reuse it; else None + requested_response_id = None + try: + body_dict = request.model_dump() # Pydantic BaseModel + 
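        # A tidier alternative (sketch only, not part of this patch): declare the
        # field on the request model so it appears in the OpenAPI schema instead of
        # probing model_dump(). Model and field names below are illustrative
        # assumptions, not names confirmed by this diff:
        #
        #     class RagStudioChatRequest(BaseModel):  # hypothetical model name
        #         query: str
        #         configuration: RagPredictConfiguration
        #         response_id: Optional[str] = None  # set by the UI when regenerating
        #
        # With that in place, this whole try/except reduces to:
        #     requested_response_id = request.response_id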
requested_response_id = body_dict.get("response_id") + except Exception: + requested_response_id = None + future = executor.submit( stream_chat, session=session, query=request.query, configuration=configuration, user_name=remote_user, + response_id=requested_response_id, ) # If we get here and the cancel_event is set, the client has disconnected diff --git a/llm-service/app/services/chat/chat.py b/llm-service/app/services/chat/chat.py index 91102b00d..ec51e7849 100644 --- a/llm-service/app/services/chat/chat.py +++ b/llm-service/app/services/chat/chat.py @@ -40,22 +40,19 @@ import uuid from typing import Optional -from llama_index.core.chat_engine.types import AgentChatResponse - from app.ai.vector_stores.vector_store_factory import VectorStoreFactory from app.rag_types import RagPredictConfiguration -from app.services import evaluators, llm_completion -from app.services.chat.utils import retrieve_chat_history, format_source_nodes +from app.services import llm_completion +from app.services.chat.streaming_chat import finalize_response +from app.services.chat.utils import retrieve_chat_history from app.services.chat_history.chat_history_manager import ( - Evaluation, RagMessage, RagStudioChatMessage, chat_history_manager, ) from app.services.metadata_apis.session_metadata_api import Session -from app.services.mlflow import record_rag_mlflow_run, record_direct_llm_mlflow_run +from app.services.mlflow import record_direct_llm_mlflow_run from app.services.query import querier -from app.services.query.querier import get_nodes_from_output from app.services.query.query_configuration import QueryConfiguration logger = logging.getLogger(__name__) @@ -125,58 +122,6 @@ def _run_chat( ) -def finalize_response( - chat_response: AgentChatResponse, - condensed_question: str | None, - query: str, - query_configuration: QueryConfiguration, - response_id: str, - session: Session, - user_name: Optional[str], -) -> RagStudioChatMessage: - if condensed_question and (condensed_question.strip() == query.strip()): - condensed_question = None - - orig_source_nodes = chat_response.source_nodes - source_nodes = get_nodes_from_output(chat_response.response, session) - - # if node with id present in orig_source_nodes, then don't add it again - node_ids_present = set([node.node_id for node in orig_source_nodes]) - for node in source_nodes: - if node.node_id not in node_ids_present: - orig_source_nodes.append(node) - - chat_response.source_nodes = orig_source_nodes - - evaluations = [] - if len(chat_response.source_nodes) != 0: - relevance, faithfulness = evaluators.evaluate_response( - query, chat_response, session.inference_model - ) - evaluations.append(Evaluation(name="relevance", value=relevance)) - evaluations.append(Evaluation(name="faithfulness", value=faithfulness)) - response_source_nodes = format_source_nodes(chat_response) - new_chat_message = RagStudioChatMessage( - id=response_id, - session_id=session.id, - source_nodes=response_source_nodes, - inference_model=session.inference_model, - rag_message=RagMessage( - user=query, - assistant=chat_response.response, - ), - evaluations=evaluations, - timestamp=time.time(), - condensed_question=condensed_question, - ) - record_rag_mlflow_run( - new_chat_message, query_configuration, response_id, session, user_name - ) - chat_history_manager.append_to_history(session.id, [new_chat_message]) - - return new_chat_message - - def direct_llm_chat( session: Session, response_id: str, query: str, user_name: Optional[str] ) -> RagStudioChatMessage: diff --git 
a/llm-service/app/services/chat/streaming_chat.py b/llm-service/app/services/chat/streaming_chat.py index 9b4aab434..0b6f85734 100644 --- a/llm-service/app/services/chat/streaming_chat.py +++ b/llm-service/app/services/chat/streaming_chat.py @@ -47,23 +47,22 @@ from app.ai.vector_stores.vector_store_factory import VectorStoreFactory from app.rag_types import RagPredictConfiguration -from app.services import llm_completion, models -from app.services.chat.chat import finalize_response -from app.services.chat.utils import retrieve_chat_history +from app.services import llm_completion, models, evaluators +from app.services.chat.utils import retrieve_chat_history, format_source_nodes from app.services.chat_history.chat_history_manager import ( RagStudioChatMessage, RagMessage, - chat_history_manager, + chat_history_manager, Evaluation, ) from app.services.metadata_apis.session_metadata_api import Session -from app.services.mlflow import record_direct_llm_mlflow_run +from app.services.mlflow import record_direct_llm_mlflow_run, record_rag_mlflow_run from app.services.query import querier from app.services.query.chat_engine import ( FlexibleContextChatEngine, build_flexible_chat_engine, ) from app.services.query.querier import ( - build_retriever, + build_retriever, get_nodes_from_output, ) from app.services.query.query_configuration import QueryConfiguration @@ -73,6 +72,7 @@ def stream_chat( query: str, configuration: RagPredictConfiguration, user_name: Optional[str], + response_id: Optional[str] = None, ) -> Generator[ChatResponse, None, None]: query_configuration = QueryConfiguration( top_k=session.response_chunks, @@ -86,7 +86,22 @@ def stream_chat( use_streaming=not session.query_configuration.disable_streaming, ) - response_id = str(uuid.uuid4()) + response_id = response_id or str(uuid.uuid4()) + new_chat_message = RagStudioChatMessage( + id=response_id, + session_id=session.id, + source_nodes=[], + inference_model=session.inference_model, + evaluations=[], + rag_message=RagMessage( + user=query, + assistant="", + ), + timestamp=time.time(), + condensed_question=None, + ) + chat_history_manager.append_to_history(session.id, [new_chat_message]) + total_data_sources_size: int = sum( map( lambda ds_id: VectorStoreFactory.for_chunks(ds_id).size() or 0, @@ -217,4 +232,56 @@ def _stream_direct_llm_chat( timestamp=time.time(), condensed_question=None, ) - chat_history_manager.append_to_history(session.id, [new_chat_message]) + chat_history_manager.update_message(session.id, response_id, new_chat_message) + + +def finalize_response( + chat_response: AgentChatResponse, + condensed_question: str | None, + query: str, + query_configuration: QueryConfiguration, + response_id: str, + session: Session, + user_name: Optional[str], +) -> RagStudioChatMessage: + if condensed_question and (condensed_question.strip() == query.strip()): + condensed_question = None + + orig_source_nodes = chat_response.source_nodes + source_nodes = get_nodes_from_output(chat_response.response, session) + + # if node with id present in orig_source_nodes, then don't add it again + node_ids_present = set([node.node_id for node in orig_source_nodes]) + for node in source_nodes: + if node.node_id not in node_ids_present: + orig_source_nodes.append(node) + + chat_response.source_nodes = orig_source_nodes + + evaluations = [] + if len(chat_response.source_nodes) != 0: + relevance, faithfulness = evaluators.evaluate_response( + query, chat_response, session.inference_model + ) + evaluations.append(Evaluation(name="relevance", 
value=relevance)) + evaluations.append(Evaluation(name="faithfulness", value=faithfulness)) + response_source_nodes = format_source_nodes(chat_response) + new_chat_message = RagStudioChatMessage( + id=response_id, + session_id=session.id, + source_nodes=response_source_nodes, + inference_model=session.inference_model, + rag_message=RagMessage( + user=query, + assistant=chat_response.response, + ), + evaluations=evaluations, + timestamp=time.time(), + condensed_question=condensed_question, + ) + record_rag_mlflow_run( + new_chat_message, query_configuration, response_id, session, user_name + ) + chat_history_manager.update_message(session.id, response_id, new_chat_message) + + return new_chat_message diff --git a/llm-service/app/services/chat_history/chat_history_manager.py b/llm-service/app/services/chat_history/chat_history_manager.py index 24313a3b2..06ca2bd0a 100644 --- a/llm-service/app/services/chat_history/chat_history_manager.py +++ b/llm-service/app/services/chat_history/chat_history_manager.py @@ -84,6 +84,22 @@ def append_to_history( ) -> None: pass + @abstractmethod + def update_message( + self, session_id: int, message_id: str, message: RagStudioChatMessage + ) -> None: + """Update an existing message by ID for the given session. + + Implementations should overwrite both the user and assistant entries + corresponding to this message ID. + """ + pass + + @abstractmethod + def delete_message(self, session_id: int, message_id: str) -> None: + """Delete an existing message by ID for the given session.""" + pass + def _create_chat_history_manager() -> ChatHistoryManager: from app.services.chat_history.simple_chat_history_manager import ( diff --git a/llm-service/app/services/chat_history/s3_chat_history_manager.py b/llm-service/app/services/chat_history/s3_chat_history_manager.py index 974cf6063..349ef15b6 100644 --- a/llm-service/app/services/chat_history/s3_chat_history_manager.py +++ b/llm-service/app/services/chat_history/s3_chat_history_manager.py @@ -154,3 +154,48 @@ def append_to_history( f"Error appending to chat history for session {session_id}: {e}" ) raise + + def update_message( + self, session_id: int, message_id: str, message: RagStudioChatMessage + ) -> None: + """Update an existing message's content and metadata by ID in S3.""" + s3_key = self._get_s3_key(session_id) + try: + chat_history_data = self.retrieve_chat_history(session_id=session_id) + updated = False + for idx, existing in enumerate(chat_history_data): + if existing.id == message_id: + chat_history_data[idx] = message + updated = True + break + if not updated: + return + chat_history_json = json.dumps( + [m.model_dump() for m in chat_history_data] + ) + self.s3_client.put_object( + Bucket=self.bucket_name, Key=s3_key, Body=chat_history_json + ) + except Exception as e: + logger.error( + f"Error updating chat message {message.id} for session {session_id}: {e}" + ) + raise + + def delete_message(self, session_id: int, message_id: str) -> None: + """Delete a specific message by ID in S3-backed store.""" + s3_key = self._get_s3_key(session_id) + try: + chat_history_data = self.retrieve_chat_history(session_id=session_id) + chat_history_data = [m for m in chat_history_data if m.id != message_id] + chat_history_json = json.dumps( + [m.model_dump() for m in chat_history_data] + ) + self.s3_client.put_object( + Bucket=self.bucket_name, Key=s3_key, Body=chat_history_json + ) + except Exception as e: + logger.error( + f"Error deleting chat message {message_id} for session {session_id}: {e}" + ) + raise diff --git 
a/llm-service/app/services/chat_history/simple_chat_history_manager.py b/llm-service/app/services/chat_history/simple_chat_history_manager.py index 1176134fb..90b813e66 100644 --- a/llm-service/app/services/chat_history/simple_chat_history_manager.py +++ b/llm-service/app/services/chat_history/simple_chat_history_manager.py @@ -159,6 +159,66 @@ def append_to_history( ) store.persist(self._store_file(session_id)) + def update_message( + self, session_id: int, message_id: str, message: RagStudioChatMessage + ) -> None: + """Update an existing message's user/assistant content and metadata by ID.""" + store = self._store_for_session(session_id) + key = self._build_chat_key(session_id) + messages: list[ChatMessage] = store.get_messages(key) + + # Each logical message is stored as a pair: USER, ASSISTANT with same id + for i in range(0, len(messages), 2): + user_msg = messages[i] + if user_msg.additional_kwargs.get("id") == message_id: + # Update user content + user_msg.content = message.rag_message.user + # Update assistant content and metadata (next message) + if i + 1 < len(messages): + assistant_msg = messages[i + 1] + else: + assistant_msg = ChatMessage(role=MessageRole.ASSISTANT, content="") + messages.append(assistant_msg) + assistant_msg.content = message.rag_message.assistant + assistant_msg.additional_kwargs.update( + { + "id": message_id, + "source_nodes": message.source_nodes, + "inference_model": message.inference_model, + "evaluations": message.evaluations, + "timestamp": message.timestamp, + } + ) + # Persist updated list + store.delete_messages(key) + for m in messages: + store.add_message(key, m) + store.persist(self._store_file(session_id)) + return + + def delete_message(self, session_id: int, message_id: str) -> None: + """Delete both USER and ASSISTANT entries for a given message id.""" + store = self._store_for_session(session_id) + key = self._build_chat_key(session_id) + messages: list[ChatMessage] = store.get_messages(key) + + new_messages: list[ChatMessage] = [] + i = 0 + while i < len(messages): + user_msg = messages[i] + assistant_msg = messages[i + 1] if i + 1 < len(messages) else None + current_id = user_msg.additional_kwargs.get("id") + if current_id != message_id: + new_messages.append(user_msg) + if assistant_msg is not None: + new_messages.append(assistant_msg) + i += 2 + + store.delete_messages(key) + for m in new_messages: + store.add_message(key, m) + store.persist(self._store_file(session_id)) + @staticmethod def _build_chat_key(session_id: int) -> str: return "session_" + str(session_id) diff --git a/ui/src/api/chatApi.ts b/ui/src/api/chatApi.ts index 5ce249ee4..816df5246 100644 --- a/ui/src/api/chatApi.ts +++ b/ui/src/api/chatApi.ts @@ -87,6 +87,7 @@ export interface ChatMutationRequest { query: string; session_id: number; configuration: QueryConfiguration; + response_id?: string; } interface ChatHistoryRequestType { @@ -166,7 +167,7 @@ export interface ChatHistoryResponse { export const chatHistoryQuery = async ( request: ChatHistoryRequestType, - pageParam: number | undefined, + pageParam: number | undefined ): Promise => { const params = new URLSearchParams(); if (request.limit !== undefined) { @@ -178,13 +179,13 @@ export const chatHistoryQuery = async ( return await getRequest( `${llmServicePath}/sessions/${request.session_id.toString()}/chat-history?` + - params.toString(), + params.toString() ); }; export const appendPlaceholderToChatHistory = ( query: string, - cachedData?: InfiniteData, + cachedData?: InfiniteData ): InfiniteData => { if 
(!cachedData || cachedData.pages.length === 0) { const firstPage: ChatHistoryResponse = { @@ -199,7 +200,7 @@ export const appendPlaceholderToChatHistory = ( } const pageParams = cachedData.pageParams.map((pageParam, index) => - index > 0 && typeof pageParam === "number" ? ++pageParam : pageParam, + index > 0 && typeof pageParam === "number" ? ++pageParam : pageParam ); const pages = cachedData.pages.map((page) => { @@ -215,7 +216,7 @@ export const appendPlaceholderToChatHistory = ( const lastPage = pages[pages.length - 1]; const filteredLastPageData = lastPage.data.filter( - (chatMessage) => !isPlaceholder(chatMessage), + (chatMessage) => !isPlaceholder(chatMessage) ); return { pageParams, @@ -231,7 +232,7 @@ export const appendPlaceholderToChatHistory = ( export const replacePlaceholderInChatHistory = ( data: ChatMessageType, - cachedData?: InfiniteData, + cachedData?: InfiniteData ): InfiniteData => { if (!cachedData || cachedData.pages.length == 0) { return ( @@ -265,7 +266,7 @@ export const replacePlaceholderInChatHistory = ( }; export const createQueryConfiguration = ( - excludeKnowledgeBase: boolean, + excludeKnowledgeBase: boolean ): QueryConfiguration => { return { exclude_knowledge_base: excludeKnowledgeBase, @@ -296,7 +297,7 @@ const ratingMutation = async ({ }): Promise => { return await postRequest( `${llmServicePath}/sessions/${sessionId}/responses/${responseId}/rating`, - { rating }, + { rating } ); }; @@ -323,7 +324,7 @@ const feedbackMutation = async ({ }): Promise => { return await postRequest( `${llmServicePath}/sessions/${sessionId}/responses/${responseId}/feedback`, - { feedback }, + { feedback } ); }; @@ -344,7 +345,7 @@ export interface ChatEvent { const customChatMessage = ( variables: ChatMutationRequest, message: string, - prefix: string, + prefix: string ) => { const uuid = crypto.randomUUID(); const customMessage: ChatMessageType = { @@ -372,7 +373,7 @@ const canceledChatMessage = (variables: ChatMutationRequest) => { return customChatMessage( variables, "Request canceled by user", - CANCELED_PREFIX_ID, + CANCELED_PREFIX_ID ); }; @@ -385,7 +386,7 @@ interface StreamingChatCallbacks { const modifyPlaceholderInChatHistory = ( queryClient: QueryClient, variables: ChatMutationRequest, - replacementMessage: ChatMessageType, + replacementMessage: ChatMessageType ) => { queryClient.setQueryData>( chatHistoryQueryKey({ @@ -393,14 +394,14 @@ const modifyPlaceholderInChatHistory = ( offset: 0, }), (cachedData) => - replacePlaceholderInChatHistory(replacementMessage, cachedData), + replacePlaceholderInChatHistory(replacementMessage, cachedData) ); }; const handlePrepareController = ( getController: ((ctrl: AbortController) => void) | undefined, queryClient: QueryClient, - request: ChatMutationRequest, + request: ChatMutationRequest ) => { return (ctrl: AbortController) => { if (getController) { @@ -410,7 +411,7 @@ const handlePrepareController = ( modifyPlaceholderInChatHistory( queryClient, request, - canceledChatMessage(request), + canceledChatMessage(request) ); ctrl.signal.removeEventListener("abort", onAbort); }; @@ -428,10 +429,10 @@ const handleStreamingSuccess = ( | ((data: ChatMessageType, request?: unknown, context?: unknown) => unknown) | undefined, handleError: (request: ChatMutationRequest, error: Error) => void, - onError: ((error: Error) => void) | undefined, + onError: ((error: Error) => void) | undefined ) => { fetch( - `${llmServicePath}/sessions/${request.session_id.toString()}/chat-history/${messageId}`, + 
`${llmServicePath}/sessions/${request.session_id.toString()}/chat-history/${messageId}` ) .then(async (res) => { const message = (await res.json()) as ChatMessageType; @@ -439,7 +440,7 @@ const handleStreamingSuccess = ( chatHistoryQueryKey({ session_id: request.session_id, }), - (cachedData) => replacePlaceholderInChatHistory(message, cachedData), + (cachedData) => replacePlaceholderInChatHistory(message, cachedData) ); queryClient .invalidateQueries({ @@ -479,7 +480,7 @@ export const useStreamingChatMutation = ({ const handleGetController = handlePrepareController( getController, queryClient, - request, + request ); return streamChatMutation( @@ -487,7 +488,7 @@ export const useStreamingChatMutation = ({ onChunk, onEvent, convertError, - handleGetController, + handleGetController ); }, onMutate: (variables) => { @@ -496,7 +497,7 @@ export const useStreamingChatMutation = ({ session_id: variables.session_id, }), (cachedData) => - appendPlaceholderToChatHistory(variables.query, cachedData), + appendPlaceholderToChatHistory(variables.query, cachedData) ); }, onSuccess: (messageId, variables) => { @@ -509,7 +510,7 @@ export const useStreamingChatMutation = ({ queryClient, onSuccess, handleError, - onError, + onError ); }, onError: (error: Error, variables) => { @@ -524,7 +525,7 @@ const streamChatMutation = async ( onChunk: (chunk: string) => void, onEvent: (event: ChatEvent) => void, onError: (error: string) => void, - getController?: (ctrl: AbortController) => void, + getController?: (ctrl: AbortController) => void ): Promise => { const ctrl = new AbortController(); if (getController) { @@ -542,6 +543,7 @@ const streamChatMutation = async ( body: JSON.stringify({ query: request.query, configuration: request.configuration, + response_id: request.response_id, }), signal: ctrl.signal, onmessage(msg: EventSourceMessage) { @@ -567,7 +569,7 @@ const streamChatMutation = async ( } catch (error) { console.error("Error parsing message data:", error); onError( - `An error occurred while processing the response. Error message: ${JSON.stringify(msg)}. Error details: ${JSON.stringify(error)}.`, + `An error occurred while processing the response. Error message: ${JSON.stringify(msg)}. Error details: ${JSON.stringify(error)}.` ); ctrl.abort(); } @@ -592,13 +594,13 @@ const streamChatMutation = async ( onError("An error occurred: " + response.statusText); } }, - }, + } ); return responseId; }; export const getOnEvent = ( - setStreamedEvent: Dispatch>, + setStreamedEvent: Dispatch> ) => { return (event: ChatEvent) => { if (event.type === "done") { diff --git a/ui/src/pages/RagChatTab/ChatOutput/ChatMessages/ChatMessage.tsx b/ui/src/pages/RagChatTab/ChatOutput/ChatMessages/ChatMessage.tsx index db5d27eb8..0cd92b370 100644 --- a/ui/src/pages/RagChatTab/ChatOutput/ChatMessages/ChatMessage.tsx +++ b/ui/src/pages/RagChatTab/ChatOutput/ChatMessages/ChatMessage.tsx @@ -49,6 +49,8 @@ import UserQuestion from "pages/RagChatTab/ChatOutput/ChatMessages/UserQuestion. 
tsx";
import "../tableMarkdown.css"; import { ExclamationCircleTwoTone } from "@ant-design/icons"; import { ChatMessageBody } from "pages/RagChatTab/ChatOutput/ChatMessages/ChatMessageBody.tsx"; +import { useContext } from "react"; +import { RagChatContext } from "pages/RagChatTab/State/RagChatContext.tsx"; import { cdlAmber500 } from "src/cuix/variables.ts"; const isError = (data: ChatMessageType) => { @@ -100,6 +102,10 @@ const WarningMessage = ({ ); }; const ChatMessage = ({ data }: { data: ChatMessageType }) => { + const { activeSession } = useContext(RagChatContext); + const excludeKnowledgeBases = + !activeSession?.dataSourceIds || activeSession.dataSourceIds.length === 0; + if (isError(data)) { return ; } @@ -113,7 +119,13 @@ const ChatMessage = ({ data }: { data: ChatMessageType }) => { return ; } - return ; + return ( + + ); }; export default ChatMessage; diff --git a/ui/src/pages/RagChatTab/ChatOutput/ChatMessages/ChatMessageBody.tsx b/ui/src/pages/RagChatTab/ChatOutput/ChatMessages/ChatMessageBody.tsx index 0a4b8a32c..f08b87005 100644 --- a/ui/src/pages/RagChatTab/ChatOutput/ChatMessages/ChatMessageBody.tsx +++ b/ui/src/pages/RagChatTab/ChatOutput/ChatMessages/ChatMessageBody.tsx @@ -39,6 +39,7 @@ import { ChatMessageType, ChatEvent } from "src/api/chatApi.ts"; import UserQuestion from "pages/RagChatTab/ChatOutput/ChatMessages/UserQuestion.tsx"; import { Divider, Flex, Typography } from "antd"; +import RegenerateButton from "pages/RagChatTab/ChatOutput/ChatMessages/RegenerateButton.tsx"; import Images from "src/components/images/Images.ts"; import { cdlBlue500, cdlGray200 } from "src/cuix/variables.ts"; import { Evaluations } from "pages/RagChatTab/ChatOutput/ChatMessages/Evaluations.tsx"; @@ -51,9 +52,11 @@ import { MarkdownResponse } from "pages/RagChatTab/ChatOutput/ChatMessages/Markd export const ChatMessageBody = ({ data, streamedEvents, + excludeKnowledgeBase, }: { data: ChatMessageType; streamedEvents?: ChatEvent[]; + excludeKnowledgeBase: boolean; }) => { return (
@@ -99,6 +102,10 @@ export const ChatMessageBody = ({ + diff --git a/ui/src/pages/RagChatTab/ChatOutput/ChatMessages/RegenerateButton.tsx b/ui/src/pages/RagChatTab/ChatOutput/ChatMessages/RegenerateButton.tsx new file mode 100644 index 000000000..36bbf8ef5 --- /dev/null +++ b/ui/src/pages/RagChatTab/ChatOutput/ChatMessages/RegenerateButton.tsx @@ -0,0 +1,67 @@ +import { Button, Tooltip } from "antd"; +import { ReloadOutlined } from "@ant-design/icons"; +import { + ChatMessageType, + createQueryConfiguration, + getOnEvent, + useStreamingChatMutation, +} from "src/api/chatApi.ts"; +import { useStreamingChunkBuffer } from "src/hooks/useStreamingChunkBuffer.ts"; +import { useContext } from "react"; +import { RagChatContext } from "pages/RagChatTab/State/RagChatContext.tsx"; + +const RegenerateButton = ({ + message, + excludeKnowledgeBase, + onStarted, +}: { + message: ChatMessageType; + excludeKnowledgeBase: boolean; + onStarted?: () => void; +}) => { + const { + streamedChatState: [, setStreamedChat], + streamedEventState: [, setStreamedEvent], + streamedAbortControllerState: [, setStreamedAbortController], + } = useContext(RagChatContext); + // Use custom hook to handle batched streaming updates + const { onChunk, flush } = useStreamingChunkBuffer((chunks) => { + setStreamedChat((prev) => prev + chunks); + }); + + const streamChatMutation = useStreamingChatMutation({ + onChunk, + onEvent: getOnEvent(setStreamedEvent), + onSuccess: () => { + // Flush any remaining chunks before cleanup + flush(); + setStreamedChat(""); + }, + getController: (ctrl) => { + setStreamedAbortController(ctrl); + }, + }); + + const handleClick = () => { + onStarted?.(); + streamChatMutation.mutate({ + query: message.rag_message.user, + session_id: message.session_id, + configuration: createQueryConfiguration(excludeKnowledgeBase), + response_id: message.id, + }); + }; + + return ( + +
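      title="Regenerate response"
    >
      {/* The patch is truncated above; a minimal sketch of how the render
          plausibly continues, assuming antd's Button with the imported
          ReloadOutlined icon and TanStack Query's isPending flag. Everything
          beyond the imports shown earlier is an assumption, not the author's
          confirmed code. */}
      <Button
        type="text"
        size="small"
        icon={<ReloadOutlined />}
        onClick={handleClick}
        loading={streamChatMutation.isPending}
      />
    </Tooltip>
  );
};

export default RegenerateButton;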