88from langchain_core .prompts import PromptTemplate
99
1010from .base import BaseChatHandler , SlashCommandRoutingType
11+ from .learn import LearnChatHandler , Retriever
1112
1213PROMPT_TEMPLATE = """Given the following conversation and a follow up question, rephrase the follow up question to be a standalone question.
1314
1819CONDENSE_PROMPT = PromptTemplate .from_template (PROMPT_TEMPLATE )
1920
2021
22+ class CustomLearnException (Exception ):
23+ """Exception raised when Jupyter AI's /ask command is used without the required /learn command."""
24+
25+ def __init__ (self ):
26+ super ().__init__ (
27+ "Jupyter AI's default /ask command requires the default /learn command. "
28+ "If you are overriding /learn via the entry points API, be sure to also override or disable /ask."
29+ )
30+
31+
2132class AskChatHandler (BaseChatHandler ):
2233 """Processes messages prefixed with /ask. This actor will
2334 send the message as input to a RetrieverQA chain, that
@@ -33,12 +44,16 @@ class AskChatHandler(BaseChatHandler):
3344
3445 uses_llm = True
3546
36- def __init__ (self , retriever , * args , ** kwargs ):
47+ def __init__ (self , * args , ** kwargs ):
3748 super ().__init__ (* args , ** kwargs )
3849
39- self ._retriever = retriever
4050 self .parser .prog = "/ask"
4151 self .parser .add_argument ("query" , nargs = argparse .REMAINDER )
52+ learn_chat_handler = self .chat_handlers .get ("/learn" )
53+ if not isinstance (learn_chat_handler , LearnChatHandler ):
54+ raise CustomLearnException ()
55+
56+ self ._retriever = Retriever (learn_chat_handler = learn_chat_handler )
4257
4358 def create_llm_chain (
4459 self , provider : Type [BaseProvider ], provider_params : Dict [str , str ]
@@ -51,6 +66,7 @@ def create_llm_chain(
5166 memory = ConversationBufferWindowMemory (
5267 memory_key = "chat_history" , return_messages = True , k = 2
5368 )
69+
5470 self .llm_chain = ConversationalRetrievalChain .from_llm (
5571 self .llm ,
5672 self ._retriever ,
0 commit comments