36 changes: 36 additions & 0 deletions examples/prompt_based_extraction.py
@@ -0,0 +1,36 @@
from scrapeschema import FileExtractor, PDFParser
from scrapeschema.llm_client import LLMClient
import os
from dotenv import load_dotenv

load_dotenv() # Load environment variables from .env file

# Get the OpenAI API key from the environment variables
api_key = os.getenv("OPENAI_API_KEY")

def main():
# Path to the PDF file
pdf_name = "test.pdf"
curr_dir = os.path.dirname(os.path.abspath(__file__))
pdf_path = os.path.join(curr_dir, pdf_name)

# Create an LLMClient instance
llm_client = LLMClient(api_key)

# Create a PDFParser instance with the LLMClient
pdf_parser = PDFParser(llm_client)

# Create a FileExtractor instance with the PDF parser
pdf_extractor = FileExtractor(pdf_path, pdf_parser)

    # Define a custom prompt to extract only finance-related entities
    custom_prompt = """
    Insert in the schema only information related to the top 10 investments.
    """

# Extract entities from the PDF using the custom prompt
entities = pdf_extractor.extract_entities(prompt=custom_prompt)
print("Extracted Entities:", entities)

if __name__ == "__main__":
main()
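The optional prompt argument is also threaded through extract_relations in this PR, so the same setup can request prompt-filtered relations. A minimal sketch that reuses the pdf_extractor and custom_prompt defined in the example above (illustrative only, not part of the committed example file):

# Extract relations guided by the same custom prompt (optional parameter added in this PR)
relations = pdf_extractor.extract_relations(prompt=custom_prompt)
print("Extracted Relations:", relations)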
10 changes: 5 additions & 5 deletions scrapeschema/extractor.py
@@ -1,6 +1,6 @@
import logging
from abc import ABC, abstractmethod
from typing import List, Tuple, Dict, Any
from typing import List, Tuple, Dict, Any, Optional
from .primitives import Entity, Relation
from .parsers.base_parser import BaseParser
from .parsers.prompts import DELETE_PROMPT, UPDATE_ENTITIES_PROMPT, UPDATE_SCHEMA_PROMPT
@@ -53,12 +53,12 @@ def __init__(self, file_path: str, parser: BaseParser):
self.file_path = file_path
self.parser = parser

def extract_entities(self) -> List[Entity]:
new_entities = self.parser.extract_entities(self.file_path)
def extract_entities(self, prompt: Optional[str] = None) -> List[Entity]:
new_entities = self.parser.extract_entities(self.file_path, prompt)
return new_entities

def extract_relations(self) -> List[Relation]:
return self.parser.extract_relations(self.file_path)
def extract_relations(self, prompt: Optional[str] = None) -> List[Relation]:
return self.parser.extract_relations(self.file_path, prompt)


def entities_json_schema(self) -> Dict[str, Any]:
4 changes: 2 additions & 2 deletions scrapeschema/parsers/base_parser.py
@@ -21,11 +21,11 @@ def __init__(self, llm_client: LLMClient):
self._relations = []

@abstractmethod
def extract_entities(self, file_path: str) -> List[Entity]:
def extract_entities(self, file_path: str, prompt: Optional[str] = None) -> List[Entity]:
pass

@abstractmethod
def extract_relations(self, file_path: Optional[str] = None) -> List[Relation]:
def extract_relations(self, file_path: Optional[str] = None, prompt: Optional[str] = None) -> List[Relation]:
pass

@abstractmethod
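With the updated abstract signatures, every concrete parser must accept the extra prompt argument. A minimal sketch of what a custom subclass could look like under these assumptions (TextParser and its body are hypothetical, only the two prompt-aware methods are shown, and BaseParser's other abstract methods are omitted):

from typing import List, Optional

from scrapeschema.parsers.base_parser import BaseParser
from scrapeschema.primitives import Entity, Relation

class TextParser(BaseParser):
    """Hypothetical parser illustrating the new prompt-aware signatures."""

    def extract_entities(self, file_path: str, prompt: Optional[str] = None) -> List[Entity]:
        # Forward the optional user prompt to whatever extraction logic this parser uses
        raise NotImplementedError

    def extract_relations(self, file_path: Optional[str] = None, prompt: Optional[str] = None) -> List[Relation]:
        # The prompt can be used to restrict which relations are returned
        raise NotImplementedError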
61 changes: 34 additions & 27 deletions scrapeschema/parsers/pdf_parser.py
@@ -174,17 +174,20 @@ def __init__(self, llm_client: LLMClient):

super().__init__(llm_client)

def extract_entities(self, file_path: str) -> List[Entity]:
def extract_entities(self, file_path: str, prompt: Optional[str] = None) -> List[Entity]:
if not os.path.exists(file_path):
raise FileNotFoundError(f"PDF file not found: {file_path}")

new_entities = self._extract_entities_from_pdf(file_path)
new_entities = self._extract_entities_from_pdf(file_path, prompt)
return self.update_entities(new_entities)

def _extract_entities_from_pdf(self, file_path: str) -> List[Entity]:
def _extract_entities_from_pdf(self, file_path: str, prompt: Optional[str] = None) -> List[Entity]:
if prompt:
entities_json_schema = self.entities_json_schema(file_path, prompt)
else:
entities_json_schema = self.entities_json_schema(file_path)

entities = []
entities_json_schema = self.entities_json_schema(file_path)

def traverse_schema(schema: Dict[str, Any], parent_id: str = None):
if isinstance(schema, dict):
entity_id = parent_id if parent_id else schema.get('title', 'root')
@@ -233,14 +236,18 @@ def update_entities(self, new_entities: List[Entity]) -> List[Entity]:
# Update the parser's entities
self._entities = updated_entities

            # Log the updated entities
logging.info("Updated entities:")
for entity in updated_entities:
logging.info(entity.__dict__)
logging.info(f"Entities updated. New count: {len(updated_entities)}")
return updated_entities
except json.JSONDecodeError as e:
logging.error(f"JSONDecodeError: {e}")
logging.error("Error: Unable to parse the LLM response.")
return existing_entities

def extract_relations(self, file_path: Optional[str] = None) -> List[Relation]:
def extract_relations(self, file_path: Optional[str] = None, prompt: Optional[str] = None) -> List[Relation]:
"""
Extracts relations from a PDF file.

@@ -250,41 +257,38 @@ def extract_relations(self, file_path: Optional[str] = None) -> List[Relation]:
Returns:
List[Relation]: A list of extracted relations.
"""
if file_path is not None:
if not os.path.exists(file_path):
raise FileNotFoundError(f"PDF file not found: {file_path}")

if not self._entities or len(self._entities) == 0:
self.extract_entities(file_path)
if file_path is not None and not os.path.exists(file_path):
raise FileNotFoundError(f"PDF file not found: {file_path}")

if not self._entities or len(self._entities) == 0:
self.extract_entities(file_path, prompt)

relation_class_str = inspect.getsource(Relation)
relations_prompt = RELATIONS_PROMPT.format(
entities=json.dumps([e.__dict__ for e in self._entities], indent=2),
relation_class=relation_class_str
)

if prompt:
            # Append the user prompt to the relations prompt
            relations_prompt += f"\n\nExtract only the relations that are required by the following user prompt:\n\n{prompt}"


relations_answer_code = self.llm_client.get_response(relations_prompt)
relations_answer_code = self._extract_python_content(relations_answer_code)

# Create a new dictionary to store the local variables
local_vars = {}

# Execute the code in the context of local_vars
try:
exec(relations_answer_code, globals(), local_vars)
except Exception as e:
logging.error(f"Error executing relations code: {e}")
raise ValueError(f"The language model generated invalid code: {e}") from e

# Extract the relations from local_vars

relations_answer = local_vars.get('relations', [])

self._relations = relations_answer
logging.info(f"Extracted relations: {relations_answer_code}")

return self._relations
return self._relations


def plot_entities_schema(self, file_path: str) -> None:
"""
Plots the entities schema from a PDF file.
@@ -307,7 +311,7 @@ def plot_entities_schema(self, file_path: str) -> None:
logging.info("digraph_code_execution----------------------------------")
exec(digraph_code[9:-3])

def entities_json_schema(self, file_path: str) -> Dict[str, Any]:
def entities_json_schema(self, file_path: str, prompt: Optional[str] = None) -> Dict[str, Any]:
"""
Generates a JSON schema of entities from a PDF file.

@@ -323,7 +327,7 @@ def entities_json_schema(self, file_path: str) -> Dict[str, Any]:
base64_images = process_pdf(file_path)

if base64_images:
page_answers = self._generate_json_schema(base64_images)
page_answers = self._generate_json_schema(base64_images, prompt)
json_schema = self._merge_json_schemas(page_answers)
json_schema = self._extract_json_content(json_schema)

@@ -365,17 +369,20 @@ def _merge_digraphs_for_plot(self, page_answers: List[str]) -> str:
digraph_code = self.llm_client.get_response(digraph_prompt)
return digraph_code

def _generate_json_schema(self, base64_images: List[str]) -> List[str]:
def _generate_json_schema(self, base64_images: List[str], prompt: Optional[str] = None) -> List[str]:
page_answers = []
for page_num, base64_image in enumerate(base64_images, start=1):
prompt = f"{JSON_SCHEMA_PROMPT} (Page {page_num})"
if prompt:
customized_prompt = f"{JSON_SCHEMA_PROMPT} extract only what is required from the following prompt:\
{prompt} (Page {page_num})"
else:
customized_prompt = f"{JSON_SCHEMA_PROMPT} (Page {page_num})"

image_data = f"data:image/jpeg;base64,{base64_image}"
try:
answer = self.llm_client.get_response(prompt, image_url=image_data)
answer = self.llm_client.get_response(customized_prompt, image_url=image_data)
except ReadTimeout:
logging.warning("Request to OpenAI API timed out. Retrying...")
# Implement retry logic or skip to the next image
continue
answer = self._extract_json_content(answer)
page_answers.append(f"Page {page_num}: {answer}")
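The ReadTimeout branch above currently just logs a warning and skips the page; the inline comment leaves the retry logic as a follow-up. A minimal sketch of one way that retry could look (the helper name, retry counts, backoff, and the requests.exceptions import are assumptions, not part of this PR):

import logging
import time
from typing import Optional

from requests.exceptions import ReadTimeout  # assumed source of ReadTimeout

def get_response_with_retry(llm_client, prompt: str, image_url: str,
                            retries: int = 3, backoff: float = 2.0) -> Optional[str]:
    """Hypothetical helper: retry the per-page LLM call before giving up."""
    for attempt in range(1, retries + 1):
        try:
            return llm_client.get_response(prompt, image_url=image_url)
        except ReadTimeout:
            logging.warning(f"LLM request timed out (attempt {attempt}/{retries}).")
            time.sleep(backoff * attempt)  # simple linear backoff between attempts
    return None  # caller can skip this page if every attempt times out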