Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions examples/merge_schemas.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""Example: extract entities from a PDF, then merge a hand-written schema into the extracted one."""
from scrapeschema import FileExtractor, PDFParser
from scrapeschema.llm_client import LLMClient
import os
from dotenv import load_dotenv

load_dotenv()  # Load environment variables from .env file

# OpenAI API key comes from the environment (.env is loaded above).
api_key = os.getenv("OPENAI_API_KEY")

# Directory containing this script; the sample PDF is expected to live next to it.
curr_dir = os.path.dirname(os.path.abspath(__file__))


def main():
    """Extract entities from test.pdf and merge a hardcoded schema into the extracted schema."""
    # Fail fast with a clear message instead of an opaque auth error later.
    if not api_key:
        raise RuntimeError("OPENAI_API_KEY is not set; export it or add it to a .env file")

    # Path to the PDF file
    pdf_path = os.path.join(curr_dir, "test.pdf")

    # Build the pipeline: LLM client -> PDF parser -> file extractor.
    llm_client = LLMClient(api_key)
    pdf_parser = PDFParser(llm_client)
    pdf_extractor = FileExtractor(pdf_path, pdf_parser)

    # Extract entities from the PDF.
    # NOTE(review): merge_schemas below requires the parser's JSON schema to be
    # populated — presumably this happens during extraction; verify in the parser.
    entities = pdf_extractor.extract_entities()
    print("Extracted Entities:", entities)

    # Hardcoded schema to merge with the extracted one.
    hardcoded_schema = {
        "title": "Fund",
        "type": "object",
        "properties": {
            "costCategory": {
                "type": "object",
                "properties": {
                    "costFlag": {"type": "string"},
                },
            },
        },
    }

    # Merge the hardcoded schema into the extractor's current schema.
    updated_schema = pdf_extractor.merge_schemas(hardcoded_schema)
    print("Updated Schema:", updated_schema)


if __name__ == "__main__":
    main()
145 changes: 141 additions & 4 deletions scrapeschema/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
from typing import List, Tuple, Dict, Any
from .primitives import Entity, Relation
from .parsers.base_parser import BaseParser
from .parsers.prompts import DELETE_PROMPT, UPDATE_ENTITIES_PROMPT
import requests
from .parsers.prompts import DELETE_PROMPT, UPDATE_ENTITIES_PROMPT, UPDATE_SCHEMA_PROMPT
from .llm_client import LLMClient
from .parsers.prompts import UPDATE_SCHEMA_PROMPT
import json

logger = logging.getLogger(__name__)
Expand All @@ -18,7 +19,7 @@ def extract_entities(self) -> List[Entity]:
@abstractmethod
def extract_relations(self) -> List[Relation]:
    """Extract relations from the underlying source; implemented by subclasses."""
    pass

@abstractmethod
def entities_json_schema(self) -> Dict[str, Any]:
    """Return a JSON schema describing the extracted entities; implemented by subclasses."""
    pass
Expand All @@ -27,6 +28,26 @@ def entities_json_schema(self) -> Dict[str, Any]:
def update_entities(self, new_entities: List[Entity]) -> List[Entity]:
    """Reconcile *new_entities* with the stored entities and return the updated list; implemented by subclasses."""
    pass

@abstractmethod
def merge_schemas(self, other_schema: Dict[str, Any]) -> Dict[str, Any]:
    """Merge *other_schema* into the current JSON schema and return the merged schema.

    Implemented by subclasses. The return annotation is Dict[str, Any] (not None)
    to match the concrete FileExtractor.merge_schemas, which returns the merged schema.
    """
    pass

@abstractmethod
def delete_entity_or_relation(self, item_description: str) -> None:
    """Delete the entity or relation identified by *item_description*; implemented by subclasses."""
    pass

@abstractmethod
def set_entities(self, entities: List[Entity]) -> None:
    """Replace the stored entities with *entities*; implemented by subclasses."""
    pass

@abstractmethod
def set_relations(self, relations: List[Relation]) -> None:
    """Replace the stored relations with *relations*; implemented by subclasses."""
    pass

@abstractmethod
def set_json_schema(self, schema: Dict[str, Any]) -> None:
    """Replace the stored JSON schema with *schema*; implemented by subclasses."""
    pass

class FileExtractor(Extractor):
def __init__(self, file_path: str, parser: BaseParser):
self.file_path = file_path
Expand All @@ -39,6 +60,7 @@ def extract_entities(self) -> List[Entity]:
def extract_relations(self) -> List[Relation]:
    """Delegate relation extraction for this file to the underlying parser."""
    parser = self.parser
    return parser.extract_relations(self.file_path)


def entities_json_schema(self) -> Dict[str, Any]:
return self.parser.entities_json_schema(self.file_path)

Expand Down Expand Up @@ -116,4 +138,119 @@ def update_entities(self, new_entities: List[Entity]) -> List[Entity]:
return updated_entities
except json.JSONDecodeError:
logger.error("Error: Unable to parse the LLM response.")
return existing_entities
return existing_entities

def set_json_schema(self, schema: Dict[str, Any]) -> None:
"""
Set the JSON schema for the parser.

Args:
schema (Dict[str, Any]): The JSON schema to set.
"""
self.parser.set_json_schema(schema)

def set_entities(self, entities: List[Entity]) -> None:
    """Store *entities* on the underlying parser.

    Args:
        entities (List[Entity]): Entities to hand to the parser.
    """
    parser = self.parser
    parser.set_entities(entities)

def set_relations(self, relations: List[Relation]) -> None:
    """Store *relations* on the underlying parser.

    Args:
        relations (List[Relation]): Relations to hand to the parser.
    """
    parser = self.parser
    parser.set_relations(relations)

def get_entities(self) -> List[Entity]:
    """Return the entities held by the underlying parser.

    Returns:
        List[Entity]: Current entity list.
    """
    parser = self.parser
    return parser.get_entities()

def get_relations(self) -> List[Relation]:
    """Return the relations held by the underlying parser.

    Returns:
        List[Relation]: Current relation list.
    """
    parser = self.parser
    return parser.get_relations()

def get_json_schema(self) -> Dict[str, Any]:
"""
Get the JSON schema from the parser.

Returns:
Dict[str, Any]: The JSON schema.
"""
return self.parser.get_json_schema()

def merge_schemas(self, other_schema: Dict[str, Any]) -> Dict[str, Any]:
    """
    Merge the parser's current JSON schema with *other_schema* via the LLM.

    Both schemas are serialized into UPDATE_SCHEMA_PROMPT and sent to the
    parser's LLM client. The merged result replaces the parser's schema, and
    entities/relations are re-extracted from it.

    Args:
        other_schema (Dict[str, Any]): The schema to merge into the current one.

    Returns:
        Dict[str, Any]: The merged JSON schema, or None when the parser has no
        schema to merge into (an error is logged in that case).
    """

    def _strip_code_fence(text: str) -> str:
        # The LLM is asked to wrap its answer in ```json ... ```. Remove the
        # fence explicitly: the previous str.strip('```json') treated its
        # argument as a *set of characters* and could eat legitimate leading
        # or trailing backtick/j/s/o/n characters of the payload.
        text = text.strip()
        if text.startswith("```json"):
            text = text[len("```json"):]
        elif text.startswith("```"):
            text = text[3:]
        if text.endswith("```"):
            text = text[:-3]
        return text.strip()

    def _merge_json_schemas(schema1: Dict[str, Any], schema2: Dict[str, Any]) -> Dict[str, Any]:
        """Ask the LLM to merge two JSON schemas; fall back to schema1 on parse failure."""
        llm_client = self.parser.llm_client

        prompt = UPDATE_SCHEMA_PROMPT.format(
            existing_schema=json.dumps(schema1, indent=2),
            new_schema=json.dumps(schema2, indent=2),
        )

        response = _strip_code_fence(llm_client.get_response(prompt))
        try:
            return json.loads(response)
        except json.JSONDecodeError as e:
            logger.error(f"JSONDecodeError: {e}")
            logger.error("Error: Unable to parse the LLM response.")
            return schema1  # Return the original schema in case of an error

    if not self.parser.get_json_schema():
        logger.error("No JSON schema found in the parser.")
        return

    merged_schema = _merge_json_schemas(self.get_json_schema(), other_schema)
    self.set_json_schema(merged_schema)

    # Re-extract entities and relations so they reflect the merged schema;
    # both calls update the parser's internal state, so the return values
    # are intentionally not bound here.
    self.parser.extract_entities_from_json_schema(merged_schema)
    self.parser.extract_relations()

    return self.get_json_schema()


###########################################################################################

40 changes: 38 additions & 2 deletions scrapeschema/parsers/base_parser.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import List, Dict, Any
from typing import List, Dict, Any, Optional
from ..primitives import Entity, Relation
from ..llm_client import LLMClient

Expand All @@ -16,6 +16,7 @@ def __init__(self, llm_client: LLMClient):
"Content-Type": "application/json",
"Authorization": f"Bearer {self.llm_client.get_api_key()}"
}
self._json_schema = {}
self._entities = []
self._relations = []

Expand All @@ -24,7 +25,7 @@ def extract_entities(self, file_path: str) -> List[Entity]:
pass

@abstractmethod
def extract_relations(self, file_path: str) -> List[Relation]:
def extract_relations(self, file_path: Optional[str] = None) -> List[Relation]:
pass

@abstractmethod
Expand Down Expand Up @@ -61,6 +62,9 @@ def get_entities(self):
def get_relations(self):
    """Return the relations currently stored on this parser."""
    relations = self._relations
    return relations

def get_json_schema(self):
    """Return the JSON schema currently stored on this parser."""
    schema = self._json_schema
    return schema

def set_entities(self, entities: List[Entity]):
if not isinstance(entities, list) or not all(isinstance(entity, Entity) for entity in entities):
raise TypeError("entities must be a List of Entity objects")
Expand All @@ -70,3 +74,35 @@ def set_relations(self, relations: List[Relation]):
if not isinstance(relations, list) or not all(isinstance(relation, Relation) for relation in relations):
raise TypeError("relations must be a List of Relation objects")
self._relations = relations

def set_json_schema(self, schema: Dict[str, Any]):
    """Store *schema* as this parser's JSON schema (no validation is performed)."""
    self._json_schema = schema

def extract_entities_from_json_schema(self, json_schema: Dict[str, Any]) -> List[Entity]:
    """
    Extracts entities from a given JSON schema.

    Every (possibly nested) dict in the schema that declares "properties"
    becomes one Entity. The root entity is identified by the schema's
    "title" (falling back to 'root'); nested objects are identified by the
    property key under which they appear. The extracted list also replaces
    this parser's stored entities via set_entities().

    Args:
        json_schema (Dict[str, Any]): The JSON schema to extract entities from.

    Returns:
        List[Entity]: A list of extracted entities.
    """
    entities: List[Entity] = []

    def traverse_schema(schema: Dict[str, Any], parent_id: Optional[str] = None):
        # Non-dict nodes (e.g. leaf values or malformed fragments) are ignored.
        if isinstance(schema, dict):
            entity_id = parent_id if parent_id else schema.get('title', 'root')
            entity_type = schema.get('type', 'object')
            attributes = schema.get('properties', {})

            # Only objects that actually declare properties become entities.
            if attributes:
                entities.append(Entity(id=entity_id, type=entity_type, attributes=attributes))

            # Recurse into each property, using its key as the child's id.
            for key, value in attributes.items():
                traverse_schema(value, key)

    traverse_schema(json_schema)
    self.set_entities(entities)
    return entities
19 changes: 12 additions & 7 deletions scrapeschema/parsers/pdf_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import base64
import os
import tempfile
import requests
import json
from .prompts import DIGRAPH_EXAMPLE_PROMPT, JSON_SCHEMA_PROMPT, RELATIONS_PROMPT, UPDATE_ENTITIES_PROMPT
from PIL import Image
Expand All @@ -13,6 +12,7 @@
import logging
import re
from ..llm_client import LLMClient
from requests.exceptions import ReadTimeout

# Set up logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
Expand Down Expand Up @@ -240,7 +240,7 @@ def update_entities(self, new_entities: List[Entity]) -> List[Entity]:
logging.error("Error: Unable to parse the LLM response.")
return existing_entities

def extract_relations(self, file_path: str) -> List[Relation]:
def extract_relations(self, file_path: Optional[str] = None) -> List[Relation]:
"""
Extracts relations from a PDF file.

Expand All @@ -250,11 +250,12 @@ def extract_relations(self, file_path: str) -> List[Relation]:
Returns:
List[Relation]: A list of extracted relations.
"""
if not os.path.exists(file_path):
raise FileNotFoundError(f"PDF file not found: {file_path}")

if not self._entities or len(self._entities) == 0:
self.extract_entities(file_path)
if file_path is not None:
if not os.path.exists(file_path):
raise FileNotFoundError(f"PDF file not found: {file_path}")

if not self._entities or len(self._entities) == 0:
self.extract_entities(file_path)

relation_class_str = inspect.getsource(Relation)
relations_prompt = RELATIONS_PROMPT.format(
Expand Down Expand Up @@ -330,6 +331,10 @@ def entities_json_schema(self, file_path: str) -> Dict[str, Any]:
logging.info(json_schema)
# json schema is a valid json schema but its a string convert it to a python dict
entities_json_schema = json.loads(json_schema)

# Assign the generated schema to self._json_schema
self._json_schema = entities_json_schema

return entities_json_schema

def _generate_digraph(self, base64_images: List[str]) -> List[str]:
Expand Down
16 changes: 16 additions & 0 deletions scrapeschema/parsers/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,4 +202,20 @@

Please provide the updated list of entities as a JSON array. Each entity should be a JSON object with 'id', 'type', and 'attributes' fields.
Provide only the JSON array, wrapped in backticks (`) like ```json ... ``` and nothing else.
"""

# Prompt used by FileExtractor.merge_schemas: asks the LLM to merge an existing
# JSON schema with a new one. Placeholders: {existing_schema}, {new_schema}.
UPDATE_SCHEMA_PROMPT = """
You need to merge an existing JSON schema with a new one, avoiding duplicates and reconciling any conflicts. Here are the rules:

1. If a new entity has the same ID as an existing entity, update the existing entity with any new or changed attributes.
2. Add any completely new entities that don't match existing ones.
3. Try to maintain the base structure of the existing schema, adding new entities or updating existing entities.
4. If the existing schema is empty, copy the new schema as it is.

This is the existing JSON schema:
{existing_schema}

Merge it with this JSON schema:
{new_schema}

Please provide the updated json schema as a JSON object.
Provide only the JSON object, wrapped in backticks (`) like ```json ... ``` and nothing else.
"""