3
3
from typing import List , Tuple , Dict , Any
4
4
from .primitives import Entity , Relation
5
5
from .parsers .base_parser import BaseParser
6
- from .parsers .prompts import DELETE_PROMPT , UPDATE_ENTITIES_PROMPT
7
- import requests
6
+ from .parsers .prompts import DELETE_PROMPT , UPDATE_ENTITIES_PROMPT , UPDATE_SCHEMA_PROMPT
7
+ from .llm_client import LLMClient
8
+ from .parsers .prompts import UPDATE_SCHEMA_PROMPT
8
9
import json
9
10
10
11
logger = logging .getLogger (__name__ )
@@ -18,7 +19,7 @@ def extract_entities(self) -> List[Entity]:
18
19
@abstractmethod
def extract_relations(self) -> List[Relation]:
    """
    Abstract hook: produce the relations handled by this extractor.

    The base declaration has no behaviour of its own; concrete subclasses
    supply the implementation.

    Returns:
        List[Relation]: The relations, as declared by the signature.
    """
21
-
22
+
22
23
@abstractmethod
def entities_json_schema(self) -> Dict[str, Any]:
    """
    Abstract hook: produce the JSON schema describing the entities.

    The base declaration has no behaviour of its own; concrete subclasses
    supply the implementation.

    Returns:
        Dict[str, Any]: The schema, as declared by the signature.
    """
@@ -27,6 +28,26 @@ def entities_json_schema(self) -> Dict[str, Any]:
27
28
def update_entities(self, new_entities: List[Entity]) -> List[Entity]:
    """
    Interface stub for updating entities; it carries no implementation.

    Args:
        new_entities (List[Entity]): The entities supplied by the caller.

    Returns:
        List[Entity]: Declared by the signature; this stub itself returns
        nothing and is meant to be overridden.
    """
29
30
31
@abstractmethod
def merge_schemas(self, other_schema: Dict[str, Any]) -> None:
    """
    Abstract hook: merge ``other_schema`` into this extractor's schema.

    The base declaration has no behaviour of its own; concrete subclasses
    supply the implementation.

    Args:
        other_schema (Dict[str, Any]): The schema to merge with.

    NOTE(review): this declaration is annotated ``-> None`` while
    ``FileExtractor.merge_schemas`` in this file is annotated to return
    ``Dict[str, Any]``; confirm which contract is intended.
    """
34
+
35
@abstractmethod
def delete_entity_or_relation(self, item_description: str) -> None:
    """
    Abstract hook: delete an entity or relation identified by a description.

    The base declaration has no behaviour of its own; how the description
    is interpreted is left to concrete subclasses.

    Args:
        item_description (str): Text describing the item to delete.
    """
38
+
39
@abstractmethod
def set_entities(self, entities: List[Entity]) -> None:
    """
    Abstract hook: replace the extractor's entities.

    The base declaration has no behaviour of its own; concrete subclasses
    supply the implementation.

    Args:
        entities (List[Entity]): The list of entities to set.
    """
42
+
43
@abstractmethod
def set_relations(self, relations: List[Relation]) -> None:
    """
    Abstract hook: replace the extractor's relations.

    The base declaration has no behaviour of its own; concrete subclasses
    supply the implementation.

    Args:
        relations (List[Relation]): The list of relations to set.
    """
46
+
47
@abstractmethod
def set_json_schema(self, schema: Dict[str, Any]) -> None:
    """
    Abstract hook: replace the extractor's JSON schema.

    The base declaration has no behaviour of its own; concrete subclasses
    supply the implementation.

    Args:
        schema (Dict[str, Any]): The JSON schema to set.
    """
50
+
30
51
class FileExtractor (Extractor ):
31
52
def __init__ (self , file_path : str , parser : BaseParser ):
32
53
self .file_path = file_path
@@ -39,6 +60,7 @@ def extract_entities(self) -> List[Entity]:
39
60
def extract_relations(self) -> List[Relation]:
    """
    Ask the wrapped parser for the relations found in ``self.file_path``.

    Returns:
        List[Relation]: Whatever ``self.parser.extract_relations`` returns
        for this extractor's file path.
    """
    source_path = self.file_path
    relations = self.parser.extract_relations(source_path)
    return relations
41
62
63
+
42
64
def entities_json_schema(self) -> Dict[str, Any]:
    """
    Ask the wrapped parser for the entity JSON schema of ``self.file_path``.

    Returns:
        Dict[str, Any]: Whatever ``self.parser.entities_json_schema``
        returns for this extractor's file path.
    """
    source_path = self.file_path
    schema = self.parser.entities_json_schema(source_path)
    return schema
44
66
@@ -116,4 +138,119 @@ def update_entities(self, new_entities: List[Entity]) -> List[Entity]:
116
138
return updated_entities
117
139
except json .JSONDecodeError :
118
140
logger .error ("Error: Unable to parse the LLM response." )
119
- return existing_entities
141
+ return existing_entities
142
+
143
def set_json_schema(self, schema: Dict[str, Any]) -> None:
    """
    Hand ``schema`` to the wrapped parser as its JSON schema.

    This is a pure pass-through to ``self.parser.set_json_schema``.

    Args:
        schema (Dict[str, Any]): The JSON schema to set.
    """
    parser = self.parser
    parser.set_json_schema(schema)
151
+
152
def set_entities(self, entities: List[Entity]) -> None:
    """
    Hand ``entities`` to the wrapped parser.

    This is a pure pass-through to ``self.parser.set_entities``.

    Args:
        entities (List[Entity]): The list of entities to set.
    """
    parser = self.parser
    parser.set_entities(entities)
160
+
161
def set_relations(self, relations: List[Relation]) -> None:
    """
    Hand ``relations`` to the wrapped parser.

    This is a pure pass-through to ``self.parser.set_relations``.

    Args:
        relations (List[Relation]): The list of relations to set.
    """
    parser = self.parser
    parser.set_relations(relations)
169
+
170
def get_entities(self) -> List[Entity]:
    """
    Fetch the entities currently held by the wrapped parser.

    Returns:
        List[Entity]: The value returned by ``self.parser.get_entities``.
    """
    entities = self.parser.get_entities()
    return entities
178
+
179
def get_relations(self) -> List[Relation]:
    """
    Fetch the relations currently held by the wrapped parser.

    Returns:
        List[Relation]: The value returned by ``self.parser.get_relations``.
    """
    relations = self.parser.get_relations()
    return relations
187
+
188
def get_json_schema(self) -> Dict[str, Any]:
    """
    Fetch the JSON schema currently held by the wrapped parser.

    Returns:
        Dict[str, Any]: The value returned by
        ``self.parser.get_json_schema``.
    """
    schema = self.parser.get_json_schema()
    return schema
196
+
197
def merge_schemas(self, other_schema: Dict[str, Any]) -> Dict[str, Any]:
    """
    Merge the parser's current JSON schema with ``other_schema`` via the LLM.

    Both schemas are rendered into ``UPDATE_SCHEMA_PROMPT`` and sent to
    ``self.parser.llm_client``. The reply is parsed as JSON, stored on the
    parser through ``set_json_schema``, and the parser is then asked to
    re-extract entities and relations.

    Args:
        other_schema (Dict[str, Any]): The schema to merge with.

    Returns:
        Dict[str, Any]: The schema reported by ``get_json_schema`` after the
        merge. When the LLM reply is not a parseable JSON object the
        existing schema is kept. ``None`` is returned (after logging an
        error) when the parser has no schema to merge into.
    """

    def _strip_code_fence(text: str) -> str:
        """Remove a surrounding Markdown code fence (``` or ```json)."""
        text = text.strip()
        # Remove the fence as a literal prefix/suffix. ``str.strip(chars)``
        # treats its argument as a set of characters, not as a substring.
        if text.startswith("```"):
            text = text[3:]
            if text[:4].lower() == "json":
                text = text[4:]
        if text.endswith("```"):
            text = text[:-3]
        return text.strip()

    def _merge_json_schemas(
        schema1: Dict[str, Any], schema2: Dict[str, Any]
    ) -> Dict[str, Any]:
        """Ask the LLM to merge two schemas; fall back to ``schema1``."""
        prompt = UPDATE_SCHEMA_PROMPT.format(
            existing_schema=json.dumps(schema1, indent=2),
            new_schema=json.dumps(schema2, indent=2),
        )
        response = self.parser.llm_client.get_response(prompt)

        # A missing or non-text reply cannot be parsed; keep the old schema.
        if not isinstance(response, str):
            logger.error("Error: Unable to parse the LLM response.")
            return schema1

        try:
            merged = json.loads(_strip_code_fence(response))
        except json.JSONDecodeError as e:
            logger.error("JSONDecodeError: %s", e)
            logger.error("Error: Unable to parse the LLM response.")
            return schema1  # Return the original schema in case of an error

        # Valid JSON that is not an object (list, string, ...) is not a schema.
        if not isinstance(merged, dict):
            logger.error("Error: Unable to parse the LLM response.")
            return schema1
        return merged

    current_schema = self.get_json_schema()
    if not current_schema:
        logger.error("No JSON schema found in the parser.")
        return None

    merged_schema = _merge_json_schemas(current_schema, other_schema)
    self.set_json_schema(merged_schema)

    # Re-extract entities and relations based on the merged schema. The
    # return values are not needed here; the calls are kept as they were.
    self.parser.extract_entities_from_json_schema(merged_schema)
    # NOTE(review): extract_relations() elsewhere in this class is called
    # with self.file_path; confirm the no-argument call is intended.
    self.parser.extract_relations()

    return self.get_json_schema()
253
+
254
+
255
+ ###########################################################################################
256
+
0 commit comments