From dbd9c6e399e85950679afd2d641ea94aec4c5a79 Mon Sep 17 00:00:00 2001 From: Lorenzo Padoan Date: Tue, 24 Sep 2024 10:46:17 +0200 Subject: [PATCH] feat(): prompt based extraction --- examples/prompt_based_extraction.py | 36 +++++++++++++++++ scrapeschema/extractor.py | 10 ++--- scrapeschema/parsers/base_parser.py | 4 +- scrapeschema/parsers/pdf_parser.py | 61 ++++++++++++++++------------- 4 files changed, 77 insertions(+), 34 deletions(-) create mode 100644 examples/prompt_based_extraction.py diff --git a/examples/prompt_based_extraction.py b/examples/prompt_based_extraction.py new file mode 100644 index 0000000..83c1c3c --- /dev/null +++ b/examples/prompt_based_extraction.py @@ -0,0 +1,36 @@ +from scrapeschema import FileExtractor, PDFParser +from scrapeschema.llm_client import LLMClient +import os +from dotenv import load_dotenv + +load_dotenv() # Load environment variables from .env file + +# Get the OpenAI API key from the environment variables +api_key = os.getenv("OPENAI_API_KEY") + +def main(): + # Path to the PDF file + pdf_name = "test.pdf" + curr_dir = os.path.dirname(os.path.abspath(__file__)) + pdf_path = os.path.join(curr_dir, pdf_name) + + # Create an LLMClient instance + llm_client = LLMClient(api_key) + + # Create a PDFParser instance with the LLMClient + pdf_parser = PDFParser(llm_client) + + # Create a FileExtractor instance with the PDF parser + pdf_extractor = FileExtractor(pdf_path, pdf_parser) + + # Define a custom prompt to extract only financial-related entities + custom_prompt = """ + Insert in the schema only info related to the top 10 investment. + """ + + # Extract entities from the PDF using the custom prompt + entities = pdf_extractor.extract_entities(prompt=custom_prompt) + print("Extracted Entities:", entities) + +if __name__ == "__main__": + main() diff --git a/scrapeschema/extractor.py b/scrapeschema/extractor.py index d34318c..0d2fee6 100644 --- a/scrapeschema/extractor.py +++ b/scrapeschema/extractor.py @@ -1,6 +1,6 @@ import logging from abc import ABC, abstractmethod -from typing import List, Tuple, Dict, Any +from typing import List, Tuple, Dict, Any, Optional from .primitives import Entity, Relation from .parsers.base_parser import BaseParser from .parsers.prompts import DELETE_PROMPT, UPDATE_ENTITIES_PROMPT, UPDATE_SCHEMA_PROMPT @@ -53,12 +53,12 @@ def __init__(self, file_path: str, parser: BaseParser): self.file_path = file_path self.parser = parser - def extract_entities(self) -> List[Entity]: - new_entities = self.parser.extract_entities(self.file_path) + def extract_entities(self, prompt: Optional[str] = None) -> List[Entity]: + new_entities = self.parser.extract_entities(self.file_path, prompt) return new_entities - def extract_relations(self) -> List[Relation]: - return self.parser.extract_relations(self.file_path) + def extract_relations(self, prompt: Optional[str] = None) -> List[Relation]: + return self.parser.extract_relations(self.file_path, prompt) def entities_json_schema(self) -> Dict[str, Any]: diff --git a/scrapeschema/parsers/base_parser.py b/scrapeschema/parsers/base_parser.py index 635db8a..210d587 100644 --- a/scrapeschema/parsers/base_parser.py +++ b/scrapeschema/parsers/base_parser.py @@ -21,11 +21,11 @@ def __init__(self, llm_client: LLMClient): self._relations = [] @abstractmethod - def extract_entities(self, file_path: str) -> List[Entity]: + def extract_entities(self, file_path: str, prompt: Optional[str] = None) -> List[Entity]: pass @abstractmethod - def extract_relations(self, file_path: Optional[str] = None) -> 
List[Relation]: + def extract_relations(self, file_path: Optional[str] = None, prompt: Optional[str] = None) -> List[Relation]: pass @abstractmethod diff --git a/scrapeschema/parsers/pdf_parser.py b/scrapeschema/parsers/pdf_parser.py index bfecc34..ff63a22 100644 --- a/scrapeschema/parsers/pdf_parser.py +++ b/scrapeschema/parsers/pdf_parser.py @@ -174,17 +174,20 @@ def __init__(self, llm_client: LLMClient): super().__init__(llm_client) - def extract_entities(self, file_path: str) -> List[Entity]: + def extract_entities(self, file_path: str, prompt: Optional[str] = None) -> List[Entity]: if not os.path.exists(file_path): raise FileNotFoundError(f"PDF file not found: {file_path}") - new_entities = self._extract_entities_from_pdf(file_path) + new_entities = self._extract_entities_from_pdf(file_path, prompt) return self.update_entities(new_entities) - def _extract_entities_from_pdf(self, file_path: str) -> List[Entity]: + def _extract_entities_from_pdf(self, file_path: str, prompt: Optional[str] = None) -> List[Entity]: + if prompt: + entities_json_schema = self.entities_json_schema(file_path, prompt) + else: + entities_json_schema = self.entities_json_schema(file_path) + entities = [] - entities_json_schema = self.entities_json_schema(file_path) - def traverse_schema(schema: Dict[str, Any], parent_id: str = None): if isinstance(schema, dict): entity_id = parent_id if parent_id else schema.get('title', 'root') @@ -233,6 +236,10 @@ def update_entities(self, new_entities: List[Entity]) -> List[Entity]: # Update the parser's entities self._entities = updated_entities + # print the updated entities + logging.info("Updated entities:") + for entity in updated_entities: + logging.info(entity.__dict__) logging.info(f"Entities updated. New count: {len(updated_entities)}") return updated_entities except json.JSONDecodeError as e: @@ -240,7 +247,7 @@ def update_entities(self, new_entities: List[Entity]) -> List[Entity]: logging.error("Error: Unable to parse the LLM response.") return existing_entities - def extract_relations(self, file_path: Optional[str] = None) -> List[Relation]: + def extract_relations(self, file_path: Optional[str] = None, prompt: Optional[str] = None) -> List[Relation]: """ Extracts relations from a PDF file. @@ -250,41 +257,38 @@ def extract_relations(self, file_path: Optional[str] = None) -> List[Relation]: Returns: List[Relation]: A list of extracted relations. 
""" - if file_path is not None: - if not os.path.exists(file_path): - raise FileNotFoundError(f"PDF file not found: {file_path}") - - if not self._entities or len(self._entities) == 0: - self.extract_entities(file_path) + if file_path is not None and not os.path.exists(file_path): + raise FileNotFoundError(f"PDF file not found: {file_path}") + + if not self._entities or len(self._entities) == 0: + self.extract_entities(file_path, prompt) relation_class_str = inspect.getsource(Relation) relations_prompt = RELATIONS_PROMPT.format( entities=json.dumps([e.__dict__ for e in self._entities], indent=2), relation_class=relation_class_str ) - + if prompt: + #append to the relations_prompt the prompt + relations_prompt += f"\n\n Extract only the relations that are required from the following user prompt:\n\n{prompt}" + + relations_answer_code = self.llm_client.get_response(relations_prompt) relations_answer_code = self._extract_python_content(relations_answer_code) - # Create a new dictionary to store the local variables local_vars = {} - - # Execute the code in the context of local_vars try: exec(relations_answer_code, globals(), local_vars) except Exception as e: logging.error(f"Error executing relations code: {e}") raise ValueError(f"The language model generated invalid code: {e}") from e - - # Extract the relations from local_vars + relations_answer = local_vars.get('relations', []) - self._relations = relations_answer logging.info(f"Extracted relations: {relations_answer_code}") - return self._relations + return self._relations - def plot_entities_schema(self, file_path: str) -> None: """ Plots the entities schema from a PDF file. @@ -307,7 +311,7 @@ def plot_entities_schema(self, file_path: str) -> None: logging.info("digraph_code_execution----------------------------------") exec(digraph_code[9:-3]) - def entities_json_schema(self, file_path: str) -> Dict[str, Any]: + def entities_json_schema(self, file_path: str, prompt: Optional[str] = None) -> Dict[str, Any]: """ Generates a JSON schema of entities from a PDF file. @@ -323,7 +327,7 @@ def entities_json_schema(self, file_path: str) -> Dict[str, Any]: base64_images = process_pdf(file_path) if base64_images: - page_answers = self._generate_json_schema(base64_images) + page_answers = self._generate_json_schema(base64_images, prompt) json_schema = self._merge_json_schemas(page_answers) json_schema = self._extract_json_content(json_schema) @@ -365,17 +369,20 @@ def _merge_digraphs_for_plot(self, page_answers: List[str]) -> str: digraph_code = self.llm_client.get_response(digraph_prompt) return digraph_code - def _generate_json_schema(self, base64_images: List[str]) -> List[str]: + def _generate_json_schema(self, base64_images: List[str], prompt: Optional[str] = None) -> List[str]: page_answers = [] for page_num, base64_image in enumerate(base64_images, start=1): - prompt = f"{JSON_SCHEMA_PROMPT} (Page {page_num})" + if prompt: + customized_prompt = f"{JSON_SCHEMA_PROMPT} extract only what is required from the following prompt:\ + {prompt} (Page {page_num})" + else: + customized_prompt = f"{JSON_SCHEMA_PROMPT} (Page {page_num})" image_data = f"data:image/jpeg;base64,{base64_image}" try: - answer = self.llm_client.get_response(prompt, image_url=image_data) + answer = self.llm_client.get_response(customized_prompt, image_url=image_data) except ReadTimeout: logging.warning("Request to OpenAI API timed out. 
Retrying...") - # Implement retry logic or skip to the next image continue answer = self._extract_json_content(answer) page_answers.append(f"Page {page_num}: {answer}")