|
| 1 | +"""Example demonstrating how to use the ContextGroundingVectorStore class with LangChain.""" |
| 2 | + |
| 3 | +import argparse |
| 4 | +import asyncio |
| 5 | +from pprint import pprint |
| 6 | +from typing import Any |
| 7 | + |
| 8 | +from dotenv import find_dotenv, load_dotenv |
| 9 | +from langchain_core.language_models.chat_models import BaseChatModel |
| 10 | +from langchain_core.output_parsers import StrOutputParser |
| 11 | +from langchain_core.prompts import ChatPromptTemplate |
| 12 | +from langchain_core.runnables import RunnablePassthrough |
| 13 | +from langchain_core.vectorstores import VectorStore |
| 14 | +from uipath_langchain.chat.models import UiPathAzureChatOpenAI |
| 15 | +from uipath_langchain.vectorstores.context_grounding_vectorstore import ( |
| 16 | + ContextGroundingVectorStore, |
| 17 | +) |
| 18 | + |
| 19 | + |
def create_retrieval_chain(vectorstore: VectorStore, model: BaseChatModel, k: int = 3):
    """Create a retrieval chain using a vector store.

    Args:
        vectorstore: Vector store to use for the chain
        model: LangChain language model to use for the chain
        k: Number of documents to retrieve for each query

    Returns:
        A callable that takes a query string and returns a dict with the
        generated answer under ``"result"`` and the retrieved documents
        under ``"source_documents"``.
    """
    # Create a retriever from the vector store
    retriever = vectorstore.as_retriever(
        search_kwargs={"k": k},
    )

    # Create a prompt template
    template = """Answer the question based on the following context:
    {context}
    Question: {question}
    """
    prompt = ChatPromptTemplate.from_template(template)

    # Answer-generation chain. The retrieved documents are injected explicitly
    # (rather than wiring the retriever into the chain) so that retrieval runs
    # exactly once per query — the original invoked the retriever twice, which
    # doubled the retrieval cost and could ground the answer on different
    # documents than the ones returned as sources.
    answer_chain = prompt | model | StrOutputParser()

    def retrieval_chain(query: str) -> dict[str, Any]:
        # Retrieve once; reuse the same docs for the prompt context and the result.
        docs = retriever.invoke(query)
        answer = answer_chain.invoke({"context": docs, "question": query})
        return {"result": answer, "source_documents": docs}

    return retrieval_chain
| 60 | + |
| 61 | + |
async def main(index_name: str, query: str, k: int = 3):
    """Run a simple example of ContextGroundingVectorStore.

    Args:
        index_name: Name of the context-grounding index to search.
        query: Question to retrieve documents for and answer.
        k: Number of documents to retrieve per search.
    """
    load_dotenv(find_dotenv())

    vectorstore = ContextGroundingVectorStore(
        index_name=index_name,
    )

    # NOTE: the original overwrote the `query` parameter with a hard-coded
    # string here, silently ignoring the --query CLI argument; it also
    # hard-coded k=5 instead of honoring the `k` parameter. Both are fixed.

    # Perform semantic searches with distance scores
    docs_with_scores = await vectorstore.asimilarity_search_with_score(query=query, k=k)
    print("==== Docs with distance scores ====")
    pprint(
        [
            {"page_content": doc.page_content, "distance_score": distance_score}
            for doc, distance_score in docs_with_scores
        ]
    )

    # Perform a similarity search with relevance scores
    docs_with_relevance_scores = (
        await vectorstore.asimilarity_search_with_relevance_scores(query=query, k=k)
    )
    print("==== Docs with relevance scores ====")
    pprint(
        [
            {"page_content": doc.page_content, "relevance_score": relevance_score}
            for doc, relevance_score in docs_with_relevance_scores
        ]
    )

    # Build the model used by the retrieval chain
    model = UiPathAzureChatOpenAI(
        model="gpt-4o-2024-08-06",
        max_retries=3,
    )

    retrieval_chain = create_retrieval_chain(
        vectorstore=vectorstore,
        model=model,
        k=k,
    )

    # Run the retrieval chain and show the answer with its source documents
    result = retrieval_chain(query)
    print("==== Retrieval chain result ====")
    print(f"Query: {query}")
    print(f"Answer: {result['result']}")
    print("\nSource Documents:")
    for i, doc in enumerate(result["source_documents"]):
        print(f"\nDocument {i + 1}:")
        print(f"Content: {doc.page_content[:100]}...")
        print(
            f"Source: {doc.metadata.get('source', 'N/A')}, Page Number: {doc.metadata.get('page_number', '0')}"
        )
| 119 | + |
if __name__ == "__main__":
    # CLI entry point: collect the index name, query, and retrieval count,
    # then run the async example.
    parser = argparse.ArgumentParser()
    cli_options = (
        ("--index_name", str, "ECCN", "The name of the index to use"),
        (
            "--query",
            str,
            "What is the ECCN for a laptop?",
            "The query for which documents will be retrieved",
        ),
        ("--k", int, 3, "The number of documents to retrieve"),
    )
    for flag, value_type, fallback, description in cli_options:
        parser.add_argument(flag, type=value_type, default=fallback, help=description)
    opts = parser.parse_args()
    asyncio.run(main(opts.index_name, opts.query, opts.k))