From 83e9dcf149846a6dd7cb0f38e1aeea502b9778fb Mon Sep 17 00:00:00 2001 From: Lorenzo Padoan Date: Mon, 23 Sep 2024 20:16:45 +0200 Subject: [PATCH] feat(): merge schemas implementation --- examples/merge_schemas.py | 51 ++++++++++ scrapeschema/extractor.py | 145 +++++++++++++++++++++++++++- scrapeschema/parsers/base_parser.py | 40 +++++++- scrapeschema/parsers/pdf_parser.py | 19 ++-- scrapeschema/parsers/prompts.py | 16 +++ 5 files changed, 258 insertions(+), 13 deletions(-) create mode 100644 examples/merge_schemas.py diff --git a/examples/merge_schemas.py b/examples/merge_schemas.py new file mode 100644 index 0000000..1a4f2da --- /dev/null +++ b/examples/merge_schemas.py @@ -0,0 +1,51 @@ +from scrapeschema import FileExtractor, PDFParser +from scrapeschema.llm_client import LLMClient +import os +from dotenv import load_dotenv + +load_dotenv() # Load environment variables from .env file + +# Get the OpenAI API key from the environment variables +api_key = os.getenv("OPENAI_API_KEY") + +# get current directory +curr_dirr = os.path.dirname(os.path.abspath(__file__)) + +def main(): + # Path to the PDF file + pdf_name = "test.pdf" + pdf_path = os.path.join(curr_dirr, pdf_name) + + # Create a PDFParser instance with the API key + llm_client = LLMClient(api_key) + pdf_parser = PDFParser(llm_client) + + # Create a FileExtractor instance with the PDF parser + pdf_extractor = FileExtractor(pdf_path, pdf_parser) + + # Extract entities from the PDF + entities = pdf_extractor.extract_entities() + print("Extracted Entities:", entities) + + # Hardcoded schema to merge with + hardcoded_schema = { + "title": "Fund", + "type": "object", + "properties": { + "costCategory": { + "type": "object", + "properties": { + "costFlag": {"type": "string"}, + + } + } + } + } + + # Perform merge schema function + updated_schema = pdf_extractor.merge_schemas(hardcoded_schema) + + print("Updated Schema:", updated_schema) + +if __name__ == "__main__": + main() diff --git a/scrapeschema/extractor.py b/scrapeschema/extractor.py index 84e2d0a..d34318c 100644 --- a/scrapeschema/extractor.py +++ b/scrapeschema/extractor.py @@ -3,8 +3,9 @@ from typing import List, Tuple, Dict, Any from .primitives import Entity, Relation from .parsers.base_parser import BaseParser -from .parsers.prompts import DELETE_PROMPT, UPDATE_ENTITIES_PROMPT -import requests +from .parsers.prompts import DELETE_PROMPT, UPDATE_ENTITIES_PROMPT, UPDATE_SCHEMA_PROMPT +from .llm_client import LLMClient +from .parsers.prompts import UPDATE_SCHEMA_PROMPT import json logger = logging.getLogger(__name__) @@ -18,7 +19,7 @@ def extract_entities(self) -> List[Entity]: @abstractmethod def extract_relations(self) -> List[Relation]: pass - + @abstractmethod def entities_json_schema(self) -> Dict[str, Any]: pass @@ -27,6 +28,26 @@ def entities_json_schema(self) -> Dict[str, Any]: def update_entities(self, new_entities: List[Entity]) -> List[Entity]: pass + @abstractmethod + def merge_schemas(self, other_schema: Dict[str, Any]) -> None: + pass + + @abstractmethod + def delete_entity_or_relation(self, item_description: str) -> None: + pass + + @abstractmethod + def set_entities(self, entities: List[Entity]) -> None: + pass + + @abstractmethod + def set_relations(self, relations: List[Relation]) -> None: + pass + + @abstractmethod + def set_json_schema(self, schema: Dict[str, Any]) -> None: + pass + class FileExtractor(Extractor): def __init__(self, file_path: str, parser: BaseParser): self.file_path = file_path @@ -39,6 +60,7 @@ def extract_entities(self) -> List[Entity]: def extract_relations(self) -> List[Relation]: return self.parser.extract_relations(self.file_path) + def entities_json_schema(self) -> Dict[str, Any]: return self.parser.entities_json_schema(self.file_path) @@ -116,4 +138,119 @@ def update_entities(self, new_entities: List[Entity]) -> List[Entity]: return updated_entities except json.JSONDecodeError: logger.error("Error: Unable to parse the LLM response.") - return existing_entities \ No newline at end of file + return existing_entities + + def set_json_schema(self, schema: Dict[str, Any]) -> None: + """ + Set the JSON schema for the parser. + + Args: + schema (Dict[str, Any]): The JSON schema to set. + """ + self.parser.set_json_schema(schema) + + def set_entities(self, entities: List[Entity]) -> None: + """ + Set the entities for the parser. + + Args: + entities (List[Entity]): The list of entities to set. + """ + self.parser.set_entities(entities) + + def set_relations(self, relations: List[Relation]) -> None: + """ + Set the relations for the parser. + + Args: + relations (List[Relation]): The list of relations to set. + """ + self.parser.set_relations(relations) + + def get_entities(self) -> List[Entity]: + """ + Get the entities from the parser. + + Returns: + List[Entity]: The list of entities. + """ + return self.parser.get_entities() + + def get_relations(self) -> List[Relation]: + """ + Get the relations from the parser. + + Returns: + List[Relation]: The list of relations. + """ + return self.parser.get_relations() + + def get_json_schema(self) -> Dict[str, Any]: + """ + Get the JSON schema from the parser. + + Returns: + Dict[str, Any]: The JSON schema. + """ + return self.parser.get_json_schema() + + def merge_schemas(self, other_schema: Dict[str, Any]) -> Dict[str, Any]: + """ + Merges the current parser's schema with another schema. + + Args: + other_schema (Dict[str, Any]): The schema to merge with. + """ + def _merge_json_schemas(self, schema1: Dict[str, Any], schema2: Dict[str, Any]) -> Dict[str, Any]: + """ + Merges two JSON schemas using an API call to OpenAI. + + Args: + schema1 (Dict[str, Any]): The first JSON schema. + schema2 (Dict[str, Any]): The second JSON schema. + + Returns: + Dict[str, Any]: The merged JSON schema. + """ + + # Initialize the LLMClient (assuming the API key is set in the environment) + + llm_client = self.parser.llm_client + + # Prepare the prompt + prompt = UPDATE_SCHEMA_PROMPT.format( + existing_schema=json.dumps(schema1, indent=2), + new_schema=json.dumps(schema2, indent=2) + ) + + # Get the response from the LLM + response = llm_client.get_response(prompt) + + # Extract the JSON schema from the response + response = response.strip().strip('```json').strip('```') + try: + merged_schema = json.loads(response) + return merged_schema + except json.JSONDecodeError as e: + logger.error(f"JSONDecodeError: {e}") + logger.error("Error: Unable to parse the LLM response.") + return schema1 # Return the original schema in case of an error + + + if not self.parser.get_json_schema(): + logger.error("No JSON schema found in the parser.") + return + + # Merge JSON schemas + merged_schema = _merge_json_schemas(self, self.get_json_schema(), other_schema) + + self.set_json_schema(merged_schema) + # Re-extract entities and relations based on the merged schema + new_entities = self.parser.extract_entities_from_json_schema(merged_schema) + new_relations = self.parser.extract_relations() + + return self.get_json_schema() + + + ########################################################################################### + \ No newline at end of file diff --git a/scrapeschema/parsers/base_parser.py b/scrapeschema/parsers/base_parser.py index f49f115..635db8a 100644 --- a/scrapeschema/parsers/base_parser.py +++ b/scrapeschema/parsers/base_parser.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import List, Dict, Any +from typing import List, Dict, Any, Optional from ..primitives import Entity, Relation from ..llm_client import LLMClient @@ -16,6 +16,7 @@ def __init__(self, llm_client: LLMClient): "Content-Type": "application/json", "Authorization": f"Bearer {self.llm_client.get_api_key()}" } + self._json_schema = {} self._entities = [] self._relations = [] @@ -24,7 +25,7 @@ def extract_entities(self, file_path: str) -> List[Entity]: pass @abstractmethod - def extract_relations(self, file_path: str) -> List[Relation]: + def extract_relations(self, file_path: Optional[str] = None) -> List[Relation]: pass @abstractmethod @@ -61,6 +62,9 @@ def get_entities(self): def get_relations(self): return self._relations + def get_json_schema(self): + return self._json_schema + def set_entities(self, entities: List[Entity]): if not isinstance(entities, list) or not all(isinstance(entity, Entity) for entity in entities): raise TypeError("entities must be a List of Entity objects") @@ -70,3 +74,35 @@ def set_relations(self, relations: List[Relation]): if not isinstance(relations, list) or not all(isinstance(relation, Relation) for relation in relations): raise TypeError("relations must be a List of Relation objects") self._relations = relations + + def set_json_schema(self, schema: Dict[str, Any]): + self._json_schema = schema + + def extract_entities_from_json_schema(self, json_schema: Dict[str, Any]) -> List[Entity]: + """ + Extracts entities from a given JSON schema. + + Args: + json_schema (Dict[str, Any]): The JSON schema to extract entities from. + + Returns: + List[Entity]: A list of extracted entities. + """ + entities = [] + + def traverse_schema(schema: Dict[str, Any], parent_id: str = None): + if isinstance(schema, dict): + entity_id = parent_id if parent_id else schema.get('title', 'root') + entity_type = schema.get('type', 'object') + attributes = schema.get('properties', {}) + + if attributes: + entity = Entity(id=entity_id, type=entity_type, attributes=attributes) + entities.append(entity) + + for key, value in attributes.items(): + traverse_schema(value, key) + + traverse_schema(json_schema) + self.set_entities(entities) + return entities diff --git a/scrapeschema/parsers/pdf_parser.py b/scrapeschema/parsers/pdf_parser.py index fb4517a..bfecc34 100644 --- a/scrapeschema/parsers/pdf_parser.py +++ b/scrapeschema/parsers/pdf_parser.py @@ -4,7 +4,6 @@ import base64 import os import tempfile -import requests import json from .prompts import DIGRAPH_EXAMPLE_PROMPT, JSON_SCHEMA_PROMPT, RELATIONS_PROMPT, UPDATE_ENTITIES_PROMPT from PIL import Image @@ -13,6 +12,7 @@ import logging import re from ..llm_client import LLMClient +from requests.exceptions import ReadTimeout # Set up logging logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') @@ -240,7 +240,7 @@ def update_entities(self, new_entities: List[Entity]) -> List[Entity]: logging.error("Error: Unable to parse the LLM response.") return existing_entities - def extract_relations(self, file_path: str) -> List[Relation]: + def extract_relations(self, file_path: Optional[str] = None) -> List[Relation]: """ Extracts relations from a PDF file. @@ -250,11 +250,12 @@ def extract_relations(self, file_path: str) -> List[Relation]: Returns: List[Relation]: A list of extracted relations. """ - if not os.path.exists(file_path): - raise FileNotFoundError(f"PDF file not found: {file_path}") - - if not self._entities or len(self._entities) == 0: - self.extract_entities(file_path) + if file_path is not None: + if not os.path.exists(file_path): + raise FileNotFoundError(f"PDF file not found: {file_path}") + + if not self._entities or len(self._entities) == 0: + self.extract_entities(file_path) relation_class_str = inspect.getsource(Relation) relations_prompt = RELATIONS_PROMPT.format( @@ -330,6 +331,10 @@ def entities_json_schema(self, file_path: str) -> Dict[str, Any]: logging.info(json_schema) # json schema is a valid json schema but its a string convert it to a python dict entities_json_schema = json.loads(json_schema) + + # Assign the generated schema to self._json_schema + self._json_schema = entities_json_schema + return entities_json_schema def _generate_digraph(self, base64_images: List[str]) -> List[str]: diff --git a/scrapeschema/parsers/prompts.py b/scrapeschema/parsers/prompts.py index be819d3..c371ddd 100644 --- a/scrapeschema/parsers/prompts.py +++ b/scrapeschema/parsers/prompts.py @@ -202,4 +202,20 @@ Please provide the updated list of entities as a JSON array. Each entity should be a JSON object with 'id', 'type', and 'attributes' fields. Provide only the JSON array, wrapped in backticks (`) like ```json ... ``` and nothing else. +""" + +UPDATE_SCHEMA_PROMPT = """ +You need to update the json schema with the new one, avoiding duplicates and reconciling any conflicts. Here are the rules: + +1. If a new entity has the same ID as an existing entity, update the existing entity with any new or changed attributes. +2. Add any completely new entities that don't match with existing ones. +3. Try to maintain the base structure you have for the existing entities, adding new entities or updating existing entities +4. If exist entities is empty, copy the new entity into the existing entity as they are. +{existing_schema} + +With this json schema: +{new_schema} + +Please provide the updated json schema as a JSON object. +Provide only the JSON object, wrapped in backticks (`) like ```json ... ``` and nothing else. """ \ No newline at end of file