From 6a5db313d3eee74a1c14f4e73d2b73989899bc80 Mon Sep 17 00:00:00 2001 From: Lorenzo Padoan Date: Thu, 3 Oct 2024 12:29:39 +0200 Subject: [PATCH 1/2] feat(): implemented langraph logic for entities --- pyproject.toml | 4 +- requirements.txt | 8 + scrapontology/extractor.py | 3 + scrapontology/parsers/base_parser.py | 84 ---------- scrapontology/parsers/pdf_parser.py | 234 +++++++++++++++------------ scrapontology/parsers/prompts.py | 2 +- 6 files changed, 149 insertions(+), 186 deletions(-) create mode 100644 requirements.txt diff --git a/pyproject.toml b/pyproject.toml index bc17f17..eb2b688 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,9 +3,9 @@ name = "scrapontology" version = "0.0.1" description = "Library for extracting schemas and building ontologies from documents using LLM" authors = [ + { name = "Lorenzo Padoan", email = "lorenzo.padoan977@gmail.com" }, { name = "Marco Vinciguerra", email = "mvincig11@gmail.com" }, - { name = "Marco Perini", email = "perinim.98@gmail.com" }, - { name = "Lorenzo Padoan", email = "lorenzo.padoan977@gmail.com" } + { name = "Marco Perini", email = "perinim.98@gmail.com" } ] dependencies = [ diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..878ede6 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,8 @@ +certifi==2024.7.4 +charset-normalizer==3.3.2 +idna==3.8 +pillow==10.4.0 +python-dotenv==1.0.1 +requests==2.32.3 +urllib3==2.2.2 +langgraph==0.2.32 \ No newline at end of file diff --git a/scrapontology/extractor.py b/scrapontology/extractor.py index 0d2fee6..4136481 100644 --- a/scrapontology/extractor.py +++ b/scrapontology/extractor.py @@ -250,6 +250,9 @@ def _merge_json_schemas(self, schema1: Dict[str, Any], schema2: Dict[str, Any]) new_relations = self.parser.extract_relations() return self.get_json_schema() + + def extract_graph(self): + return self.parser.extract_graph() ########################################################################################### diff --git a/scrapontology/parsers/base_parser.py b/scrapontology/parsers/base_parser.py index 210d587..7c8bbb3 100644 --- a/scrapontology/parsers/base_parser.py +++ b/scrapontology/parsers/base_parser.py @@ -20,89 +20,5 @@ def __init__(self, llm_client: LLMClient): self._entities = [] self._relations = [] - @abstractmethod - def extract_entities(self, file_path: str, prompt: Optional[str] = None) -> List[Entity]: - pass - @abstractmethod - def extract_relations(self, file_path: Optional[str] = None, prompt: Optional[str] = None) -> List[Relation]: - pass - @abstractmethod - def entities_json_schema(self, file_path: str) -> Dict[str, Any]: - pass - - def get_api_key(self): - return self._api_key - - def get_headers(self): - return self._headers - - def get_model(self): - return self._model - - def get_temperature(self): - return self._temperature - - def get_inference_base_url(self): - return self._inference_base_url - - def set_api_key(self, api_key: str): - self._api_key = api_key - - def set_headers(self, headers: Dict[str, str]): - self._headers = headers - - def set_inference_base_url(self, inference_base_url: str): - self.inference_base_url = inference_base_url - - def get_entities(self): - return self._entities - - def get_relations(self): - return self._relations - - def get_json_schema(self): - return self._json_schema - - def set_entities(self, entities: List[Entity]): - if not isinstance(entities, list) or not all(isinstance(entity, Entity) for entity in entities): - raise TypeError("entities must be a List of Entity objects") - self._entities = entities - - def set_relations(self, relations: List[Relation]): - if not isinstance(relations, list) or not all(isinstance(relation, Relation) for relation in relations): - raise TypeError("relations must be a List of Relation objects") - self._relations = relations - - def set_json_schema(self, schema: Dict[str, Any]): - self._json_schema = schema - - def extract_entities_from_json_schema(self, json_schema: Dict[str, Any]) -> List[Entity]: - """ - Extracts entities from a given JSON schema. - - Args: - json_schema (Dict[str, Any]): The JSON schema to extract entities from. - - Returns: - List[Entity]: A list of extracted entities. - """ - entities = [] - - def traverse_schema(schema: Dict[str, Any], parent_id: str = None): - if isinstance(schema, dict): - entity_id = parent_id if parent_id else schema.get('title', 'root') - entity_type = schema.get('type', 'object') - attributes = schema.get('properties', {}) - - if attributes: - entity = Entity(id=entity_id, type=entity_type, attributes=attributes) - entities.append(entity) - - for key, value in attributes.items(): - traverse_schema(value, key) - - traverse_schema(json_schema) - self.set_entities(entities) - return entities diff --git a/scrapontology/parsers/pdf_parser.py b/scrapontology/parsers/pdf_parser.py index a17a398..b8d5f79 100644 --- a/scrapontology/parsers/pdf_parser.py +++ b/scrapontology/parsers/pdf_parser.py @@ -1,4 +1,4 @@ -from typing import List, Dict, Any, Optional +from typing import List, Dict, Any, Optional, Literal from .base_parser import BaseParser from ..primitives import Entity, Relation import base64 @@ -13,6 +13,9 @@ import re from ..llm_client import LLMClient from requests.exceptions import ReadTimeout +from langgraph.graph import StateGraph, START, END +from pydantic import BaseModel +from typing import Optional, List # Set up logging logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') @@ -117,44 +120,6 @@ def save_image_to_temp(image: Image.Image) -> str: image.save(temp_file.name, 'JPEG') return temp_file.name -def process_pdf(pdf_path: str) -> Optional[List[str]]: - """ - Processes a PDF file and converts each page to a base64 encoded image. - - Args: - pdf_path (str): The path to the PDF file. - - Returns: - List[str] or None: A list of base64 encoded images if successful, None otherwise. - """ - if not os.path.exists(pdf_path): - raise FileNotFoundError(f"PDF file not found: {pdf_path}") - - # Load PDF as images - images = load_pdf_as_images(pdf_path) - if not images: - return None - - base64_images = [] - - for page_num, image in enumerate(images, start=1): - temp_image_path = None - try: - # Save image to temporary file - temp_image_path = save_image_to_temp(image) - - # Convert image to base64 - base64_image = encode_image(temp_image_path) - base64_images.append(base64_image) - - except Exception as e: - logging.error(f"Error processing page {page_num}: {e}") - finally: - # Ensure temp file is deleted even in case of an error - if temp_image_path and os.path.exists(temp_image_path): - os.unlink(temp_image_path) - - return base64_images class PDFParser(BaseParser): """ @@ -174,42 +139,93 @@ def __init__(self, llm_client: LLMClient): super().__init__(llm_client) - def extract_entities(self, file_path: str, prompt: Optional[str] = None) -> List[Entity]: - if not os.path.exists(file_path): - raise FileNotFoundError(f"PDF file not found: {file_path}") - - if prompt: - entities_json_schema = self.entities_json_schema(file_path, prompt) - else: - entities_json_schema = self.entities_json_schema(file_path) - - # pass to the llm the entities json schema: - prompt = EXTRACT_ENTITIES_CODE_PROMPT.format(json_schema=str(entities_json_schema) , entity_class=str(inspect.getsource(Entity))) + class State_Entities(BaseModel): + file_path: Optional[str] = None + user_prompt_for_filter: Optional[str] = None + entities_code: Optional[str] = None + entity_class: Optional[str] = None + temp_entities: Optional[List[Entity]] = None + entities: Optional[List[Entity]] = None + base64_images: Optional[List[str]] = None + page_answers: Optional[List[str]] = None + entities_json_schema: Optional[Dict[str, Any]] = None + + # initialize the state with the avaible field at the start, the entity class definition + self.state_entities = State_Entities() + self.state_entities.entity_class = inspect.getsource(Entity) + + #nodes for the entities graph + builder_for_entities = StateGraph(State_Entities) + builder_for_entities.add_node("process_pdf", self._process_pdf) + builder_for_entities.add_node("generate_json_schemas", self._generate_json_schemas) + builder_for_entities.add_node("merge_json_schemas", self._merge_json_schemas) + builder_for_entities.add_node("generate_entities_code", self._generate_entities_code) + builder_for_entities.add_node("execute_entities_code", self._execute_entities_code) + builder_for_entities.add_node("assign_entities", self.update_entities) + + #edges for the entities graph + builder_for_entities.add_edge(START, "process_pdf") + builder_for_entities.add_edge("process_pdf", "generate_json_schemas") + builder_for_entities.add_edge("generate_json_schemas", "merge_json_schemas") + builder_for_entities.add_edge("merge_json_schemas", "generate_entities_code") + builder_for_entities.add_edge("generate_entities_code","execute_entities_code") + builder_for_entities.add_edge("execute_entities_code", "assign_entities") + builder_for_entities.add_edge("assign_entities", END) + + self.graph_for_entities = builder_for_entities.compile() + + + def extract_graph(self): + return self.graph_for_entities + + def _generate_entities_code(self, *_) -> str: + prompt = EXTRACT_ENTITIES_CODE_PROMPT.format(json_schema=str(self.state_entities.entities_json_schema) , entity_class=str(inspect.getsource(Entity))) entities_code = self.llm_client.get_response(prompt) # extract the python code from the entities_code remove the ```python and ``` entities_code = entities_code.replace("```python", "").replace("```", "") - # execute the code and get the entities + self.state_entities.entities_code = entities_code + return self.state_entities + + def _execute_entities_code(self, *_): local_vars = {} max_retries = 3 retry_count = 0 while retry_count < max_retries: try: - exec(entities_code, globals(), local_vars) + exec(self.state_entities.entities_code, globals(), local_vars) break # If successful, exit the loop except Exception as e: logging.error(f"Error executing entities code (attempt {retry_count + 1}): {e}") if retry_count == max_retries - 1: logging.error("Max retries reached. Unable to execute entities code.") break - fix_code_prompt = FIX_CODE_PROMPT.format(code=entities_code, error=str(e)) + fix_code_prompt = FIX_CODE_PROMPT.format(code=self.state_entities.entities_code, error=str(e)) fixed_code = self.llm_client.get_response(fix_code_prompt) - entities_code = fixed_code # Update entities_code with the fixed version + fixed_code = fixed_code.replace("```python", "").replace("```", "") + self.state_entities.entities_code = fixed_code # Update entities_code with the fixed version retry_count += 1 new_entities = local_vars.get('entities', []) + self.state_entities.temp_entities = new_entities - return self.update_entities(new_entities) + return self.state_entities + + + def extract_entities(self, file_path: str, prompt: Optional[str] = None) -> List[Entity]: + if not os.path.exists(file_path): + raise FileNotFoundError(f"PDF file not found: {file_path}") + + self.state_entities.file_path = file_path + + if prompt: + self.state_entities.user_prompt_for_filter = prompt + + + self.graph_for_entities.invoke(self.state_entities) + + return self.state_entities.entities + def _extract_json_content(self, input_string: str) -> str: @@ -226,12 +242,12 @@ def _extract_python_content(self, input_string: str) -> str: return match.group(1).strip() return "" - def update_entities(self, new_entities: List[Entity]) -> List[Entity]: + def update_entities(self, *_): existing_entities = self._entities prompt = UPDATE_ENTITIES_PROMPT.format( existing_entities=json.dumps([e.__dict__ for e in existing_entities], indent=2), - new_entities=json.dumps([e.__dict__ for e in new_entities], indent=2) + new_entities=json.dumps([e.__dict__ for e in self.state_entities.temp_entities], indent=2) ) response = self.llm_client.get_response(prompt) @@ -242,6 +258,7 @@ def update_entities(self, new_entities: List[Entity]) -> List[Entity]: updated_entities = [Entity(**entity_data) for entity_data in updated_entities_data] # Update the parser's entities + self.state_entities.entities = updated_entities self._entities = updated_entities # print the updated entities @@ -249,11 +266,11 @@ def update_entities(self, new_entities: List[Entity]) -> List[Entity]: for entity in updated_entities: logging.info(entity.__dict__) logging.info(f"Entities updated. New count: {len(updated_entities)}") - return updated_entities + return self.state_entities except json.JSONDecodeError as e: logging.error(f"JSONDecodeError: {e}") logging.error("Error: Unable to parse the LLM response.") - return existing_entities + return self.state_entities def extract_relations(self, file_path: Optional[str] = None, prompt: Optional[str] = None) -> List[Relation]: """ @@ -308,7 +325,7 @@ def plot_entities_schema(self, file_path: str) -> None: raise FileNotFoundError(f"PDF file not found: {file_path}") entities = [] - base64_images = process_pdf(file_path) + base64_images = self._process_pdf(file_path) if base64_images: page_answers = self._generate_digraph(base64_images) @@ -319,35 +336,6 @@ def plot_entities_schema(self, file_path: str) -> None: logging.info("digraph_code_execution----------------------------------") exec(digraph_code[9:-3]) - def entities_json_schema(self, file_path: str, prompt: Optional[str] = None) -> Dict[str, Any]: - """ - Generates a JSON schema of entities from a PDF file. - - Args: - file_path (str): The path to the PDF file. - - Returns: - Dict[str, Any]: The JSON schema of entities. - """ - if not os.path.exists(file_path): - raise FileNotFoundError(f"PDF file not found: {file_path}") - - base64_images = process_pdf(file_path) - - if base64_images: - page_answers = self._generate_json_schema(base64_images, prompt) - json_schema = self._merge_json_schemas(page_answers) - json_schema = self._extract_json_content(json_schema) - - logging.info("\n PDF JSON Schema:") - logging.info(json_schema) - # json schema is a valid json schema but its a string convert it to a python dict - entities_json_schema = json.loads(json_schema) - - # Assign the generated schema to self._json_schema - self._json_schema = entities_json_schema - - return entities_json_schema def _generate_digraph(self, base64_images: List[str]) -> List[str]: page_answers = [] @@ -377,12 +365,12 @@ def _merge_digraphs_for_plot(self, page_answers: List[str]) -> str: digraph_code = self.llm_client.get_response(digraph_prompt) return digraph_code - def _generate_json_schema(self, base64_images: List[str], prompt: Optional[str] = None) -> List[str]: + def _generate_json_schemas(self,*_): page_answers = [] - for page_num, base64_image in enumerate(base64_images, start=1): - if prompt: + for page_num, base64_image in enumerate(self.state_entities.base64_images, start=1): + if self.state_entities.user_prompt_for_filter: customized_prompt = f"{JSON_SCHEMA_PROMPT} extract only what is required from the following prompt:\ - {prompt} (Page {page_num})" + {self.state_entities.user_prompt_for_filter} (Page {page_num})" else: customized_prompt = f"{JSON_SCHEMA_PROMPT} (Page {page_num})" @@ -396,13 +384,61 @@ def _generate_json_schema(self, base64_images: List[str], prompt: Optional[str] page_answers.append(f"Page {page_num}: {answer}") logging.info(f"Processed page {page_num}") - return page_answers + self.state_entities.page_answers = page_answers + return self.state_entities - def _merge_json_schemas(self, page_answers: List[str]) -> str: - digraph_prompt = "Generate a unique json schema starting from the following \ - \n\n" + "\n\n".join(page_answers) + "\n\n \ + def _merge_json_schemas(self,*_): + json_schema_prompt = "Generate a unique json schema starting from the following \ + \n\n" + "\n\n".join(self.state_entities.page_answers) + "\n\n \ Remember to provide only the json schema without any comments, wrapped in backticks (`) like ```json ... ``` and nothing else." - digraph_code = self.llm_client.get_response(digraph_prompt) - return digraph_code + json_schema_answer = self.llm_client.get_response(json_schema_prompt) + json_schema = self._extract_json_content(json_schema_answer) + logging.info("\n PDF JSON Schema:") + logging.info(json_schema) + # json schema is a valid json schema but its a string convert it to a python dict + entities_json_schema = json.loads(json_schema) + # Assign the generated schema to self._json_schema + self.state_entities.entities_json_schema = entities_json_schema + return self.state_entities + + def _process_pdf(self, *_): #Optional[List[str]]: + """ + Processes a PDF file and converts each page to a base64 encoded image. + + Args: + pdf_path (str): The path to the PDF file. + + Returns: + List[str] or None: A list of base64 encoded images if successful, None otherwise. + """ + if not os.path.exists(self.state_entities.file_path): + raise FileNotFoundError(f"PDF file not found: {self.state_entities.file_path}") + + # Load PDF as images + images = load_pdf_as_images(self.state_entities.file_path) + if not images: + return None + + base64_images = [] + + for page_num, image in enumerate(images, start=1): + temp_image_path = None + try: + # Save image to temporary file + temp_image_path = save_image_to_temp(image) + + # Convert image to base64 + base64_image = encode_image(temp_image_path) + base64_images.append(base64_image) + + except Exception as e: + logging.error(f"Error processing page {page_num}: {e}") + finally: + # Ensure temp file is deleted even in case of an error + if temp_image_path and os.path.exists(temp_image_path): + os.unlink(temp_image_path) + + self.state_entities.base64_images = base64_images + return self.state_entities diff --git a/scrapontology/parsers/prompts.py b/scrapontology/parsers/prompts.py index a5e4f14..e0a6c9e 100644 --- a/scrapontology/parsers/prompts.py +++ b/scrapontology/parsers/prompts.py @@ -228,7 +228,7 @@ {entity_class} Takes as reference the following python code for building the entities: -from scrape_schema import Entity +from scrapontology.primitives import Entity # Define entities with nested attributes entities = [ From 8d5412fde77db78a26ef19ee37ddd4431df60198 Mon Sep 17 00:00:00 2001 From: Lorenzo Padoan Date: Thu, 3 Oct 2024 17:11:58 +0200 Subject: [PATCH 2/2] feat(): implemented langraph logic for relations --- scrapontology/extractor.py | 155 +++++++++++++-------------- scrapontology/parsers/base_parser.py | 99 ++++++++++++++++- scrapontology/parsers/pdf_parser.py | 150 ++++++++++++++++++-------- 3 files changed, 279 insertions(+), 125 deletions(-) diff --git a/scrapontology/extractor.py b/scrapontology/extractor.py index 4136481..9c439f2 100644 --- a/scrapontology/extractor.py +++ b/scrapontology/extractor.py @@ -21,48 +21,87 @@ def extract_relations(self) -> List[Relation]: pass @abstractmethod - def entities_json_schema(self) -> Dict[str, Any]: + def generate_entities_json_schema(self) -> Dict[str, Any]: pass @abstractmethod - def update_entities(self, new_entities: List[Entity]) -> List[Entity]: + def merge_schemas(self, other_schema: Dict[str, Any]) -> None: pass @abstractmethod - def merge_schemas(self, other_schema: Dict[str, Any]) -> None: + def delete_entity_or_relation(self, item_description: str) -> None: pass @abstractmethod - def delete_entity_or_relation(self, item_description: str) -> None: + def get_entities(self) -> List[Entity]: + pass + + @abstractmethod + def get_relations(self) -> List[Relation]: + pass + + @abstractmethod + def get_json_schema(self) -> Dict[str, Any]: pass @abstractmethod - def set_entities(self, entities: List[Entity]) -> None: + def get_entities_graph(self): pass @abstractmethod - def set_relations(self, relations: List[Relation]) -> None: + def get_relations_graph(self): pass @abstractmethod - def set_json_schema(self, schema: Dict[str, Any]) -> None: + def get_json_schema(self): pass class FileExtractor(Extractor): def __init__(self, file_path: str, parser: BaseParser): + """ + Initialize the FileExtractor. + + Args: + file_path (str): The path to the file to be processed. + parser (BaseParser): The parser to be used for extraction. + """ self.file_path = file_path self.parser = parser def extract_entities(self, prompt: Optional[str] = None) -> List[Entity]: + """ + Extract entities from the file. + + Args: + prompt (Optional[str]): An optional prompt to guide the extraction. + + Returns: + List[Entity]: A list of extracted entities. + """ new_entities = self.parser.extract_entities(self.file_path, prompt) return new_entities def extract_relations(self, prompt: Optional[str] = None) -> List[Relation]: + """ + Extract relations from the file. + + Args: + prompt (Optional[str]): An optional prompt to guide the extraction. + + Returns: + List[Relation]: A list of extracted relations. + """ return self.parser.extract_relations(self.file_path, prompt) + def generate_entities_json_schema(self) -> Dict[str, Any]: + """ + Generate a JSON schema for the entities. - def entities_json_schema(self) -> Dict[str, Any]: - return self.parser.entities_json_schema(self.file_path) + Returns: + Dict[str, Any]: The generated JSON schema. + """ + self.parser.generate_json_schema(self.file_path) + return self.parser.get_json_schema() def delete_entity_or_relation(self, item_description: str) -> None: entities_ids = [e.id for e in self.parser.get_entities()] @@ -109,63 +148,7 @@ def _delete_relation(self, relation_id: str) -> None: self.parser.set_relations(relations) logger.info(f"Relation '{name}' between '{source}' and '{target}' has been deleted.") - def update_entities(self, new_entities: List[Entity]) -> List[Entity]: - """ - Update the existing entities with new entities, integrating and deduplicating as necessary. - - :param new_entities: List of new entities to be integrated - :return: Updated list of entities - """ - existing_entities = self.parser.get_entities() - - # Prepare the prompt for the LLM - prompt = UPDATE_ENTITIES_PROMPT.format( - existing_entities=json.dumps([e.__dict__ for e in existing_entities], indent=2), - new_entities=json.dumps([e.__dict__ for e in new_entities], indent=2) - ) - - # Get the LLM response - response = self._get_llm_response(prompt) - - try: - updated_entities_data = json.loads(response) - updated_entities = [Entity(**entity_data) for entity_data in updated_entities_data] - - # Update the parser's entities - self.parser.set_entities(updated_entities) - - logger.info(f"Entities updated. New count: {len(updated_entities)}") - return updated_entities - except json.JSONDecodeError: - logger.error("Error: Unable to parse the LLM response.") - return existing_entities - - def set_json_schema(self, schema: Dict[str, Any]) -> None: - """ - Set the JSON schema for the parser. - - Args: - schema (Dict[str, Any]): The JSON schema to set. - """ - self.parser.set_json_schema(schema) - - def set_entities(self, entities: List[Entity]) -> None: - """ - Set the entities for the parser. - - Args: - entities (List[Entity]): The list of entities to set. - """ - self.parser.set_entities(entities) - - def set_relations(self, relations: List[Relation]) -> None: - """ - Set the relations for the parser. - - Args: - relations (List[Relation]): The list of relations to set. - """ - self.parser.set_relations(relations) + def get_entities(self) -> List[Entity]: """ @@ -185,14 +168,6 @@ def get_relations(self) -> List[Relation]: """ return self.parser.get_relations() - def get_json_schema(self) -> Dict[str, Any]: - """ - Get the JSON schema from the parser. - - Returns: - Dict[str, Any]: The JSON schema. - """ - return self.parser.get_json_schema() def merge_schemas(self, other_schema: Dict[str, Any]) -> Dict[str, Any]: """ @@ -250,10 +225,32 @@ def _merge_json_schemas(self, schema1: Dict[str, Any], schema2: Dict[str, Any]) new_relations = self.parser.extract_relations() return self.get_json_schema() - - def extract_graph(self): - return self.parser.extract_graph() - ########################################################################################### - \ No newline at end of file + def get_entities_graph(self): + """ + Retrieves the state graph for entities extraction. + + Returns: + Any: The entities state graph from the parser. + """ + return self.parser.get_entities_graph() + + def get_relations_graph(self): + """ + Retrieves the state graph for relations extraction. + + Returns: + Any: The relations state graph from the parser. + """ + return self.parser.get_relations_graph() + + def get_json_schema(self): + """ + Retrieves the JSON schema. + + Returns: + Dict[str, Any]: The current JSON schema from the parser. + """ + return self.parser.get_json_schema() + diff --git a/scrapontology/parsers/base_parser.py b/scrapontology/parsers/base_parser.py index 7c8bbb3..c44cd50 100644 --- a/scrapontology/parsers/base_parser.py +++ b/scrapontology/parsers/base_parser.py @@ -6,10 +6,10 @@ class BaseParser(ABC): def __init__(self, llm_client: LLMClient): """ - Initializes the PDFParser with an API key. + Initializes the BaseParser with an LLMClient. Args: - api_key (str): The API key for authentication. + llm_client (LLMClient): The LLM client for inference. """ self.llm_client = llm_client self._headers = { @@ -20,5 +20,100 @@ def __init__(self, llm_client: LLMClient): self._entities = [] self._relations = [] + @abstractmethod + def extract_entities(self, file_path: str, prompt: Optional[str] = None) -> List[Entity]: + """ + Extracts entities from the given file. + + Args: + file_path (str): The path to the file. + prompt (Optional[str]): An optional prompt to guide the extraction. + + Returns: + List[Entity]: A list of extracted entities. + """ + pass + + @abstractmethod + def extract_relations(self, file_path: Optional[str] = None, prompt: Optional[str] = None) -> List[Relation]: + """ + Extracts relations from the given file. + + Args: + file_path (Optional[str]): The path to the file. + prompt (Optional[str]): An optional prompt to guide the extraction. + + Returns: + List[Relation]: A list of extracted relations. + """ + pass + + @abstractmethod + def generate_json_schema(self, file_path: str) -> Dict[str, Any]: + """ + Generates a JSON schema for the entities based on the given file. + + Args: + file_path (str): The path to the file. + + Returns: + Dict[str, Any]: The generated JSON schema. + """ + pass + + @abstractmethod + def get_entities(self) -> List[Entity]: + """ + Retrieves the list of entities. + + Returns: + List[Entity]: The current list of entities. + """ + pass + + + @abstractmethod + def get_relations(self) -> List[Relation]: + """ + Retrieves the list of relations. + + Returns: + List[Relation]: The current list of relations. + """ + pass + + + @abstractmethod + def get_json_schema(self) -> Dict[str, Any]: + """ + Retrieves the JSON schema. + + Returns: + Dict[str, Any]: The current JSON schema. + """ + pass + + + @abstractmethod + def get_entities_graph(self): + """ + Retrieves the state graph for entities extraction. + + Returns: + Any: The entities state graph. + """ + pass + + @abstractmethod + def get_relations_graph(self): + """ + Retrieves the state graph for relations extraction. + + Returns: + Any: The relations state graph. + """ + pass + + diff --git a/scrapontology/parsers/pdf_parser.py b/scrapontology/parsers/pdf_parser.py index b8d5f79..938253c 100644 --- a/scrapontology/parsers/pdf_parser.py +++ b/scrapontology/parsers/pdf_parser.py @@ -174,9 +174,45 @@ class State_Entities(BaseModel): self.graph_for_entities = builder_for_entities.compile() + class State_Relations(BaseModel): + entities: Optional[List[Entity]] = None + user_prompt_for_filter: Optional[str] = None + relations_code: Optional[str] = None + relation_class: Optional[str] = None + relations: Optional[List[Relation]] = None + + self.state_relations = State_Relations() + self.state_relations.relation_class = inspect.getsource(Relation) + + builder_for_relations = StateGraph(State_Relations) + builder_for_relations.add_node("extract_relations", self._extract_relations_code) + builder_for_relations.add_node("execute_relations_code", self._execute_relations_code) + builder_for_relations.add_edge(START, "extract_relations") + builder_for_relations.add_edge("extract_relations", "execute_relations_code") + builder_for_relations.add_edge("execute_relations_code", END) + + self.graph_for_relations = builder_for_relations.compile() + + class State_Entities_Json_Schema(BaseModel): + file_path: Optional[str] = None + base64_images: Optional[List[str]] = None + page_answers: Optional[List[str]] = None + entities_json_schema: Optional[Dict[str, Any]] = None + + self.state_entities_json_schema = State_Entities_Json_Schema() + + + builder_for_entities_json_schema = StateGraph(State_Entities_Json_Schema) + builder_for_entities_json_schema.add_node("process_pdf", self._process_pdf) + builder_for_entities_json_schema.add_node("generate_json_schemas", self._generate_json_schemas) + builder_for_entities_json_schema.add_node("merge_json_schemas", self._merge_json_schemas) + builder_for_entities_json_schema.add_edge(START, "process_pdf") + builder_for_entities_json_schema.add_edge("process_pdf", "generate_json_schemas") + builder_for_entities_json_schema.add_edge("generate_json_schemas", "merge_json_schemas") + builder_for_entities_json_schema.add_edge("merge_json_schemas", END) + + self.graph_for_entities_json_schema = builder_for_entities_json_schema.compile() - def extract_graph(self): - return self.graph_for_entities def _generate_entities_code(self, *_) -> str: prompt = EXTRACT_ENTITIES_CODE_PROMPT.format(json_schema=str(self.state_entities.entities_json_schema) , entity_class=str(inspect.getsource(Entity))) @@ -225,8 +261,8 @@ def extract_entities(self, file_path: str, prompt: Optional[str] = None) -> List self.graph_for_entities.invoke(self.state_entities) return self.state_entities.entities - + def _extract_json_content(self, input_string: str) -> str: # Use regex to match content between ```json and ``` @@ -283,36 +319,48 @@ def extract_relations(self, file_path: Optional[str] = None, prompt: Optional[st List[Relation]: A list of extracted relations. """ if file_path is not None and not os.path.exists(file_path): + logging.error(f"PDF file not found: {file_path}") raise FileNotFoundError(f"PDF file not found: {file_path}") if not self._entities or len(self._entities) == 0: - self.extract_entities(file_path, prompt) + logging.error("Entities not found. Please extract entities first.") + raise ValueError("Entities not found. Please extract entities first.") + + self.graph_for_relations.invoke(self.state_relations) + return self.state_relations.relations + + + def _extract_relations_code(self, *_): relation_class_str = inspect.getsource(Relation) relations_prompt = RELATIONS_PROMPT.format( entities=json.dumps([e.__dict__ for e in self._entities], indent=2), - relation_class=relation_class_str + relation_class=self.state_relations.relation_class ) - if prompt: + if self.state_relations.user_prompt_for_filter: #append to the relations_prompt the prompt - relations_prompt += f"\n\n Extract only the relations that are required from the following user prompt:\n\n{prompt}" + relations_prompt += f"\n\n Extract only the relations that are required from the following user prompt:\n\n{self.state_relations.user_prompt_for_filter}" - relations_answer_code = self.llm_client.get_response(relations_prompt) - relations_answer_code = self._extract_python_content(relations_answer_code) + relations_code_answer = self.llm_client.get_response(relations_prompt) + relations_code = self._extract_python_content(relations_code_answer) + self.state_relations.relations_code = relations_code + return self.state_relations + def _execute_relations_code(self, *_): local_vars = {} try: - exec(relations_answer_code, globals(), local_vars) + exec(self.state_relations.relations_code, globals(), local_vars) except Exception as e: logging.error(f"Error executing relations code: {e}") raise ValueError(f"The language model generated invalid code: {e}") from e relations_answer = local_vars.get('relations', []) self._relations = relations_answer - logging.info(f"Extracted relations: {relations_answer_code}") + self.state_relations.relations = relations_answer + logging.info(f"Extracted relations: {self.state_relations.relations}") - return self._relations + return self.state_relations def plot_entities_schema(self, file_path: str) -> None: """ @@ -337,34 +385,6 @@ def plot_entities_schema(self, file_path: str) -> None: exec(digraph_code[9:-3]) - def _generate_digraph(self, base64_images: List[str]) -> List[str]: - page_answers = [] - for page_num, base64_image in enumerate(base64_images, start=1): - prompt = f"You are an AI specialized in creating python code for generating digraph graphviz, you have to create python code for creating a digraph with the relative entities with the relative attributes \ - (name_attribute : type) (i.e type is int,float,list[dict],dict,string,etc...) from a PDF screenshot.\ - in the digraph you have to represent the entities with their attributes names and types, \ - NOT THE VALUES OF THE ATTRIBUTES, IT'S EXTREMELY IMPORTANT. \ - you must provide only the code to generate the digraph, without any comments before or after the code.\ - Remember you don't have to insert the values of the attribute but only (name)\ - Remember the generated digraph must be a tree, following the hierarchy of the entities in the PDF\ - Remember to the deduplicate similar entities and to the remove the duplicate edges, you have to provide the best digraph\ - that represent the PDF document because the partial digraphs are generated from the same document but from different parts of the PDF\ - Remeber to follow a structure like this one: \n\n{DIGRAPH_EXAMPLE_PROMPT}\n\nHere a page to from a PDF screenshot (Page {page_num})" - - image_data = f"data:image/jpeg;base64,{base64_image}" - answer = self.llm_client.get_response(prompt, image_url=image_data) - page_answers.append(f"Page {page_num}: {answer}") - logging.info(f"Processed page {page_num}") - - return page_answers - - def _merge_digraphs_for_plot(self, page_answers: List[str]) -> str: - digraph_prompt = "Merge the partial digraphs that I provide to you merging together all the detected entities, \n\n" + "\n\n".join(page_answers) + \ - "\nYour answer digraph must be a tree and must contain only the code for a valid graphviz graph" - - digraph_code = self.llm_client.get_response(digraph_prompt) - return digraph_code - def _generate_json_schemas(self,*_): page_answers = [] for page_num, base64_image in enumerate(self.state_entities.base64_images, start=1): @@ -398,8 +418,9 @@ def _merge_json_schemas(self,*_): logging.info(json_schema) # json schema is a valid json schema but its a string convert it to a python dict entities_json_schema = json.loads(json_schema) - # Assign the generated schema to self._json_schema + self.state_entities.entities_json_schema = entities_json_schema + self._json_schema = entities_json_schema return self.state_entities def _process_pdf(self, *_): #Optional[List[str]]: @@ -412,11 +433,22 @@ def _process_pdf(self, *_): #Optional[List[str]]: Returns: List[str] or None: A list of base64 encoded images if successful, None otherwise. """ - if not os.path.exists(self.state_entities.file_path): - raise FileNotFoundError(f"PDF file not found: {self.state_entities.file_path}") + if (self.state_entities.file_path is None and + self.state_entities_json_schema.file_path is None): + raise FileNotFoundError(f"PDF file not found") + elif self.state_entities.file_path is None: + # check if the file exists + if not os.path.exists(self.state_entities_json_schema.file_path): + raise FileNotFoundError(f"PDF file not found: {self.state_entities_json_schema.file_path}") + file_path = self.state_entities_json_schema.file_path + elif self.state_entities_json_schema.file_path is None: + # check if the file exists + if not os.path.exists(self.state_entities.file_path): + raise FileNotFoundError(f"PDF file not found: {self.state_entities.file_path}") + file_path = self.state_entities.file_path # Load PDF as images - images = load_pdf_as_images(self.state_entities.file_path) + images = load_pdf_as_images(file_path) if not images: return None @@ -441,4 +473,34 @@ def _process_pdf(self, *_): #Optional[List[str]]: self.state_entities.base64_images = base64_images return self.state_entities + + def get_entities_graph(self): + return self.graph_for_entities + + def get_relations_graph(self): + return self.graph_for_relations + + def generate_json_schema(self, file_path: str) -> Dict[str, Any]: + if not os.path.exists(file_path): + raise FileNotFoundError(f"PDF file not found: {file_path}") + + self.state_entities_json_schema.file_path = file_path + self.graph_for_entities_json_schema.invoke(self.state_entities_json_schema) + + logging.info(f"Entities JSON Schema: {self._json_schema}") + return self._json_schema + + def get_json_schema_graph(self): + return self.graph_for_entities_json_schema + + def get_json_schema(self): + return self._json_schema + + def get_entities(self): + return self._entities + + def get_relations(self): + return self._relations + +