Skip to content

Commit 38bc5d0

Browse files
authored
Merge pull request #18 from ScrapeGraphAI/5-merge-two-schemas
feat(): merge schemas implementation
2 parents 41286df + 83e9dcf commit 38bc5d0

File tree

5 files changed

+258
-13
lines changed

5 files changed

+258
-13
lines changed

examples/merge_schemas.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
from scrapeschema import FileExtractor, PDFParser
2+
from scrapeschema.llm_client import LLMClient
3+
import os
4+
from dotenv import load_dotenv
5+
6+
load_dotenv() # Load environment variables from .env file
7+
8+
# Get the OpenAI API key from the environment variables
9+
api_key = os.getenv("OPENAI_API_KEY")
10+
11+
# get current directory
12+
curr_dirr = os.path.dirname(os.path.abspath(__file__))
13+
14+
def main():
    """Extract entities from the sample PDF, then merge a fixed schema into the parser's schema."""
    # The sample PDF is looked up in the same directory as this script
    pdf_path = os.path.join(curr_dirr, "test.pdf")

    # Chain the pieces together: LLM client -> PDF parser -> file extractor
    pdf_extractor = FileExtractor(pdf_path, PDFParser(LLMClient(api_key)))

    # Extract entities from the PDF
    print("Extracted Entities:", pdf_extractor.extract_entities())

    # Hand-written schema that gets merged into the parser's schema
    cost_category = {
        "type": "object",
        "properties": {
            "costFlag": {"type": "string"},
        },
    }
    hardcoded_schema = {
        "title": "Fund",
        "type": "object",
        "properties": {"costCategory": cost_category},
    }

    # Merge and show the schema the extractor ends up with
    print("Updated Schema:", pdf_extractor.merge_schemas(hardcoded_schema))
49+
50+
if __name__ == "__main__":
51+
main()

scrapeschema/extractor.py

Lines changed: 141 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
from typing import List, Tuple, Dict, Any
44
from .primitives import Entity, Relation
55
from .parsers.base_parser import BaseParser
6-
from .parsers.prompts import DELETE_PROMPT, UPDATE_ENTITIES_PROMPT
7-
import requests
6+
from .parsers.prompts import DELETE_PROMPT, UPDATE_ENTITIES_PROMPT, UPDATE_SCHEMA_PROMPT
7+
from .llm_client import LLMClient
8+
from .parsers.prompts import UPDATE_SCHEMA_PROMPT
89
import json
910

1011
logger = logging.getLogger(__name__)
@@ -18,7 +19,7 @@ def extract_entities(self) -> List[Entity]:
1819
@abstractmethod
1920
def extract_relations(self) -> List[Relation]:
2021
pass
21-
22+
2223
@abstractmethod
2324
def entities_json_schema(self) -> Dict[str, Any]:
2425
pass
@@ -27,6 +28,26 @@ def entities_json_schema(self) -> Dict[str, Any]:
2728
def update_entities(self, new_entities: List[Entity]) -> List[Entity]:
2829
pass
2930

31+
@abstractmethod
def merge_schemas(self, other_schema: Dict[str, Any]) -> Dict[str, Any]:
    """
    Merge the extractor's current JSON schema with another schema.

    Abstract method: concrete extractors must implement it. The return
    annotation matches FileExtractor.merge_schemas, which hands the
    resulting schema back to the caller.

    Args:
        other_schema (Dict[str, Any]): The schema to merge with.

    Returns:
        Dict[str, Any]: The schema after the merge.
    """
    pass
34+
35+
@abstractmethod
def delete_entity_or_relation(self, item_description: str) -> None:
    """
    Abstract method: concrete extractors must implement it.

    NOTE(review): presumably removes the entity or relation that matches
    the description - confirm against the concrete implementation.

    Args:
        item_description (str): Description of the item to act on.
    """
    pass
38+
39+
@abstractmethod
def set_entities(self, entities: List[Entity]) -> None:
    """
    Abstract method: concrete extractors must implement it.

    FileExtractor implements it by forwarding the list to its parser.

    Args:
        entities (List[Entity]): The list of entities to set.
    """
    pass
42+
43+
@abstractmethod
def set_relations(self, relations: List[Relation]) -> None:
    """
    Abstract method: concrete extractors must implement it.

    FileExtractor implements it by forwarding the list to its parser.

    Args:
        relations (List[Relation]): The list of relations to set.
    """
    pass
46+
47+
@abstractmethod
def set_json_schema(self, schema: Dict[str, Any]) -> None:
    """
    Abstract method: concrete extractors must implement it.

    FileExtractor implements it by forwarding the schema to its parser.

    Args:
        schema (Dict[str, Any]): The JSON schema to set.
    """
    pass
50+
3051
class FileExtractor(Extractor):
3152
def __init__(self, file_path: str, parser: BaseParser):
3253
self.file_path = file_path
@@ -39,6 +60,7 @@ def extract_entities(self) -> List[Entity]:
3960
def extract_relations(self) -> List[Relation]:
4061
return self.parser.extract_relations(self.file_path)
4162

63+
4264
def entities_json_schema(self) -> Dict[str, Any]:
4365
return self.parser.entities_json_schema(self.file_path)
4466

@@ -116,4 +138,119 @@ def update_entities(self, new_entities: List[Entity]) -> List[Entity]:
116138
return updated_entities
117139
except json.JSONDecodeError:
118140
logger.error("Error: Unable to parse the LLM response.")
119-
return existing_entities
141+
return existing_entities
142+
143+
def set_json_schema(self, schema: Dict[str, Any]) -> None:
    """
    Forward a JSON schema to the underlying parser.

    Args:
        schema (Dict[str, Any]): Schema the parser should hold from now on.
    """
    parser = self.parser
    parser.set_json_schema(schema)
151+
152+
def set_entities(self, entities: List[Entity]) -> None:
    """
    Forward a list of entities to the underlying parser.

    Args:
        entities (List[Entity]): Entities the parser should hold from now on.
    """
    parser = self.parser
    parser.set_entities(entities)
160+
161+
def set_relations(self, relations: List[Relation]) -> None:
    """
    Forward a list of relations to the underlying parser.

    Args:
        relations (List[Relation]): Relations the parser should hold from now on.
    """
    parser = self.parser
    parser.set_relations(relations)
169+
170+
def get_entities(self) -> List[Entity]:
    """
    Read the entities currently held by the underlying parser.

    Returns:
        List[Entity]: Whatever the parser's get_entities() reports.
    """
    entities = self.parser.get_entities()
    return entities
178+
179+
def get_relations(self) -> List[Relation]:
    """
    Read the relations currently held by the underlying parser.

    Returns:
        List[Relation]: Whatever the parser's get_relations() reports.
    """
    relations = self.parser.get_relations()
    return relations
187+
188+
def get_json_schema(self) -> Dict[str, Any]:
    """
    Read the JSON schema currently held by the underlying parser.

    Returns:
        Dict[str, Any]: Whatever the parser's get_json_schema() reports.
    """
    schema = self.parser.get_json_schema()
    return schema
196+
197+
def merge_schemas(self, other_schema: Dict[str, Any]) -> Dict[str, Any]:
    """
    Merge the parser's current JSON schema with another schema.

    The merge itself is delegated to the LLM via UPDATE_SCHEMA_PROMPT. The
    result is stored on the parser, entities are rebuilt from it and a new
    relation extraction is triggered.

    Args:
        other_schema (Dict[str, Any]): The schema to merge with.

    Returns:
        Dict[str, Any]: The schema held by the parser after the merge. When
        the LLM reply cannot be parsed into a JSON object the current schema
        is kept unchanged. None is returned (after logging an error) when
        the parser holds no schema to merge into.
    """
    def _strip_code_fence(text: str) -> str:
        """Drop a surrounding Markdown fence (``` or ```json) from the reply."""
        text = text.strip()
        # Remove the fence as a prefix/suffix; str.strip("```json") would
        # instead strip any of the characters ` j s o n from both ends.
        if text.startswith("```"):
            text = text[3:]
            if text[:4].lower() == "json":
                text = text[4:]
        if text.endswith("```"):
            text = text[:-3]
        return text.strip()

    current_schema = self.parser.get_json_schema()
    if not current_schema:
        logger.error("No JSON schema found in the parser.")
        return None

    # Prepare the prompt
    prompt = UPDATE_SCHEMA_PROMPT.format(
        existing_schema=json.dumps(current_schema, indent=2),
        new_schema=json.dumps(other_schema, indent=2),
    )

    # Reuse the parser's LLM client rather than building a new one
    response = self.parser.llm_client.get_response(prompt)

    # Fall back to the current schema unless the reply is a JSON object
    merged_schema = current_schema
    try:
        candidate = json.loads(_strip_code_fence(response))
    except json.JSONDecodeError as e:
        logger.error("JSONDecodeError: %s", e)
        logger.error("Error: Unable to parse the LLM response.")
    else:
        if isinstance(candidate, dict):
            merged_schema = candidate
        else:
            # Valid JSON but not an object (e.g. a list) cannot be a schema
            logger.error("Error: LLM response is not a JSON object.")

    self.set_json_schema(merged_schema)

    # Re-extract entities and relations based on the merged schema; both
    # calls are made for their side effects, the return values are unused
    self.parser.extract_entities_from_json_schema(merged_schema)
    self.parser.extract_relations()

    return self.get_json_schema()
253+
254+
255+
###########################################################################################
256+

scrapeschema/parsers/base_parser.py

Lines changed: 38 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from abc import ABC, abstractmethod
2-
from typing import List, Dict, Any
2+
from typing import List, Dict, Any, Optional
33
from ..primitives import Entity, Relation
44
from ..llm_client import LLMClient
55

@@ -16,6 +16,7 @@ def __init__(self, llm_client: LLMClient):
1616
"Content-Type": "application/json",
1717
"Authorization": f"Bearer {self.llm_client.get_api_key()}"
1818
}
19+
self._json_schema = {}
1920
self._entities = []
2021
self._relations = []
2122

@@ -24,7 +25,7 @@ def extract_entities(self, file_path: str) -> List[Entity]:
2425
pass
2526

2627
@abstractmethod
27-
def extract_relations(self, file_path: str) -> List[Relation]:
28+
def extract_relations(self, file_path: Optional[str] = None) -> List[Relation]:
2829
pass
2930

3031
@abstractmethod
@@ -61,6 +62,9 @@ def get_entities(self):
6162
def get_relations(self):
6263
return self._relations
6364

65+
def get_json_schema(self) -> Dict[str, Any]:
    """Return the JSON schema currently stored on the parser."""
    return self._json_schema
67+
6468
def set_entities(self, entities: List[Entity]):
6569
if not isinstance(entities, list) or not all(isinstance(entity, Entity) for entity in entities):
6670
raise TypeError("entities must be a List of Entity objects")
@@ -70,3 +74,35 @@ def set_relations(self, relations: List[Relation]):
7074
if not isinstance(relations, list) or not all(isinstance(relation, Relation) for relation in relations):
7175
raise TypeError("relations must be a List of Relation objects")
7276
self._relations = relations
77+
78+
def set_json_schema(self, schema: Dict[str, Any]) -> None:
    """
    Store a JSON schema on the parser.

    Unlike set_entities and set_relations, the value is stored as given,
    without a type check.

    Args:
        schema (Dict[str, Any]): The JSON schema to store.
    """
    self._json_schema = schema
80+
81+
def extract_entities_from_json_schema(self, json_schema: Dict[str, Any]) -> List[Entity]:
    """
    Build entities from a JSON schema and store them on the parser.

    The schema is walked depth-first through its 'properties'. Every dict
    node with a non-empty 'properties' mapping becomes one Entity:
    the top-level node is identified by its 'title' (or 'root'), a nested
    node by the property key it is stored under. The entity type is the
    node's 'type' (default 'object') and its attributes are the node's
    'properties'.

    Args:
        json_schema (Dict[str, Any]): The JSON schema to extract entities from.

    Returns:
        List[Entity]: The extracted entities, parents before their children.
    """
    found = []

    def _walk(node: Any, key: Optional[str] = None) -> None:
        # Only dict nodes can describe an object with properties
        if not isinstance(node, dict):
            return

        properties = node.get('properties', {})
        if not properties:
            return

        # Nested nodes are named after their property key; the top-level
        # node falls back to its title
        identifier = key or node.get('title', 'root')
        found.append(
            Entity(
                id=identifier,
                type=node.get('type', 'object'),
                attributes=properties,
            )
        )

        for child_key, child_node in properties.items():
            _walk(child_node, child_key)

    _walk(json_schema)
    self.set_entities(found)
    return found

scrapeschema/parsers/pdf_parser.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
import base64
55
import os
66
import tempfile
7-
import requests
87
import json
98
from .prompts import DIGRAPH_EXAMPLE_PROMPT, JSON_SCHEMA_PROMPT, RELATIONS_PROMPT, UPDATE_ENTITIES_PROMPT
109
from PIL import Image
@@ -13,6 +12,7 @@
1312
import logging
1413
import re
1514
from ..llm_client import LLMClient
15+
from requests.exceptions import ReadTimeout
1616

1717
# Set up logging
1818
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
@@ -240,7 +240,7 @@ def update_entities(self, new_entities: List[Entity]) -> List[Entity]:
240240
logging.error("Error: Unable to parse the LLM response.")
241241
return existing_entities
242242

243-
def extract_relations(self, file_path: str) -> List[Relation]:
243+
def extract_relations(self, file_path: Optional[str] = None) -> List[Relation]:
244244
"""
245245
Extracts relations from a PDF file.
246246
@@ -250,11 +250,12 @@ def extract_relations(self, file_path: str) -> List[Relation]:
250250
Returns:
251251
List[Relation]: A list of extracted relations.
252252
"""
253-
if not os.path.exists(file_path):
254-
raise FileNotFoundError(f"PDF file not found: {file_path}")
255-
256-
if not self._entities or len(self._entities) == 0:
257-
self.extract_entities(file_path)
253+
if file_path is not None:
254+
if not os.path.exists(file_path):
255+
raise FileNotFoundError(f"PDF file not found: {file_path}")
256+
257+
if not self._entities or len(self._entities) == 0:
258+
self.extract_entities(file_path)
258259

259260
relation_class_str = inspect.getsource(Relation)
260261
relations_prompt = RELATIONS_PROMPT.format(
@@ -330,6 +331,10 @@ def entities_json_schema(self, file_path: str) -> Dict[str, Any]:
330331
logging.info(json_schema)
331332
# json_schema holds valid JSON, but it is still a string; convert it to a Python dict
332333
entities_json_schema = json.loads(json_schema)
334+
335+
# Assign the generated schema to self._json_schema
336+
self._json_schema = entities_json_schema
337+
333338
return entities_json_schema
334339

335340
def _generate_digraph(self, base64_images: List[str]) -> List[str]:

scrapeschema/parsers/prompts.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,4 +202,20 @@
202202
203203
Please provide the updated list of entities as a JSON array. Each entity should be a JSON object with 'id', 'type', and 'attributes' fields.
204204
Provide only the JSON array, wrapped in backticks (`) like ```json ... ``` and nothing else.
205+
"""
206+
207+
# Prompt used to merge two JSON schemas through the LLM.
# Placeholders filled with str.format: {existing_schema} and {new_schema},
# both expected to be JSON text. Any literal brace added to this template
# must be doubled, otherwise str.format treats it as a placeholder.
UPDATE_SCHEMA_PROMPT = """
You need to update the existing json schema with the new one, avoiding duplicates and reconciling any conflicts. Here are the rules:

1. If a new entity has the same ID as an existing entity, update the existing entity with any new or changed attributes.
2. Add any completely new entities that don't match with existing ones.
3. Try to maintain the base structure you have for the existing entities, adding new entities or updating existing entities.
4. If the existing json schema is empty, copy the new json schema as it is.

Existing json schema:
{existing_schema}

New json schema to merge into the existing one:
{new_schema}

Please provide the updated json schema as a JSON object.
Provide only the JSON object, wrapped in backticks (`) like ```json ... ``` and nothing else.
"""

0 commit comments

Comments
 (0)