diff --git a/examples/.env.example b/examples/.env.example new file mode 100644 index 0000000..73e49d5 --- /dev/null +++ b/examples/.env.example @@ -0,0 +1,17 @@ +# OpenAI API Key +OPENAI_API_KEY=your_openai_api_key_here + +# PostgreSQL Database Configuration +POSTGRES_HOST=localhost +POSTGRES_PORT=5432 +POSTGRES_DB=your_database_name +POSTGRES_USER=your_postgres_username +POSTGRES_PASSWORD=your_secure_password + +# Example: +# OPENAI_API_KEY=sk-abcdefghijklmnopqrstuvwxyz123456 +# POSTGRES_HOST=db.example.com +# POSTGRES_PORT=5432 +# POSTGRES_DB=myapp_database +# POSTGRES_USER=db_user +# POSTGRES_PASSWORD=super_secret_password \ No newline at end of file diff --git a/examples/generate_tables_from_pdf.py b/examples/generate_tables_from_pdf.py new file mode 100644 index 0000000..2dbd144 --- /dev/null +++ b/examples/generate_tables_from_pdf.py @@ -0,0 +1,43 @@ +from scrapontology import FileExtractor, PDFParser +from scrapontology.llm_client import LLMClient +from scrapontology.db_client import PostgresDBClient +import os +from dotenv import load_dotenv + +def main(): + # Load environment variables + load_dotenv() + api_key = os.getenv("OPENAI_API_KEY") + + # Get current directory and set PDF path + curr_dir = os.path.dirname(os.path.abspath(__file__)) + pdf_name = "test.pdf" + pdf_path = os.path.join(curr_dir, pdf_name) + + # Create LLMClient and PDFParser instances + llm_client = LLMClient(api_key) + pdf_parser = PDFParser(llm_client) + + # Create DBClient instance + postgres_host = os.getenv("POSTGRES_HOST") + postgres_port = os.getenv("POSTGRES_PORT") + postgres_db = os.getenv("POSTGRES_DB") + postgres_user = os.getenv("POSTGRES_USER") + postgres_password = os.getenv("POSTGRES_PASSWORD") + db_client = PostgresDBClient(postgres_host, postgres_port, postgres_db, postgres_user, postgres_password) + + # Create FileExtractor instance + pdf_extractor = FileExtractor(pdf_path, pdf_parser, db_client) + + # Generate JSON schema from the PDF + json_schema = 
pdf_extractor.generate_entities_json_schema() + print("Generated JSON Schema:") + print(json_schema) + + # Create tables in the database + pdf_extractor.create_tables() + + print("Tables created successfully in the PostgreSQL database.") + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index eb2b688..60dbd2d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,6 +17,7 @@ dependencies = [ "requests==2.32.3", "urllib3==2.2.2", "langgraph==0.2.31", + "psycopg2==2.9.9" ] license = "MIT" diff --git a/requirements-dev.lock b/requirements-dev.lock index 2b3bf35..21b87cb 100644 --- a/requirements-dev.lock +++ b/requirements-dev.lock @@ -94,6 +94,8 @@ pluggy==1.5.0 # via pytest prettytable==3.11.0 # via pyecharts +psycopg2==2.9.9 + # via scrapontology pydantic==2.9.2 # via langchain-core # via langsmith diff --git a/requirements.lock b/requirements.lock index daf06eb..9191dd8 100644 --- a/requirements.lock +++ b/requirements.lock @@ -56,6 +56,8 @@ packaging==24.1 # via langchain-core pillow==10.4.0 # via scrapontology +psycopg2==2.9.9 + # via scrapontology pydantic==2.9.2 # via langchain-core # via langsmith diff --git a/requirements.txt b/requirements.txt index 878ede6..08b769a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,4 +5,5 @@ pillow==10.4.0 python-dotenv==1.0.1 requests==2.32.3 urllib3==2.2.2 -langgraph==0.2.32 \ No newline at end of file +langgraph==0.2.32 +psycopg2==2.9.9 \ No newline at end of file diff --git a/scrapontology/db_client.py b/scrapontology/db_client.py new file mode 100644 index 0000000..b4ff62d --- /dev/null +++ b/scrapontology/db_client.py @@ -0,0 +1,95 @@ +import os +from abc import ABC, abstractmethod +import psycopg2 +from psycopg2.extras import RealDictCursor +import logging +from pydantic_core import CoreSchema, core_schema +from typing import Any, Callable + + +logger = logging.getLogger(__name__) + +class DBClient(ABC): + @abstractmethod + def connect(self): + 
pass + + @abstractmethod + def disconnect(self): + pass + + @abstractmethod + def execute_query(self, query, params=None): + pass + +class PostgresDBClient(DBClient): + def __init__(self, host=None, port=None, database=None, user=None, password=None): + self.conn = None + self.cursor = None + self.host = host or os.getenv('POSTGRES_HOST', 'localhost') + self.port = port or os.getenv('POSTGRES_PORT', '5432') + self.database = database or os.getenv('POSTGRES_DB', 'scrapontology_test') + self.user = user or os.getenv('POSTGRES_USER', 'lurens') + self.password = password or os.getenv('POSTGRES_PASSWORD', 'cicciopasticco') + + @classmethod + def __get_pydantic_core_schema__(cls, _source_type: Any, _handler: Callable) -> CoreSchema: + return core_schema.any_schema() + + def connect(self): + try: + self.conn = psycopg2.connect( + host=self.host, + port=self.port, + database=self.database, + user=self.user, + password=self.password + ) + self.cursor = self.conn.cursor(cursor_factory=RealDictCursor) + logger.info("Successfully connected to PostgreSQL database") + except (Exception, psycopg2.Error) as error: + logger.error(f"Error while connecting to PostgreSQL: {error}") + + def disconnect(self): + if self.conn: + self.cursor.close() + self.conn.close() + logger.info("PostgreSQL connection is closed") + + def execute_query(self, query, params=None): + try: + self.cursor.execute(query, params) + self.conn.commit() + return self.cursor.fetchall() + except (Exception, psycopg2.Error) as error: + logger.error(f"Error executing query: {error}") + self.conn.rollback() + raise error # Re-raise the exception to propagate it + + +class Neo4jDBClient(DBClient): + def __init__(self): + # Placeholder for future implementation + pass + + def connect(self): + # Placeholder for future implementation + logger.info("Neo4j connection not yet implemented") + + def disconnect(self): + # Placeholder for future implementation + pass + + def execute_query(self, query, params=None): + # 
Placeholder for future implementation + logger.info("Neo4j query execution not yet implemented") + return None + +# Factory function to get the appropriate DB client +def get_db_client(db_type='postgres'): + if db_type.lower() == 'postgres': + return PostgresDBClient() + elif db_type.lower() == 'neo4j': + return Neo4jDBClient() + else: + raise ValueError(f"Unsupported database type: {db_type}") \ No newline at end of file diff --git a/scrapontology/extractor.py b/scrapontology/extractor.py index 9c439f2..bb908c0 100644 --- a/scrapontology/extractor.py +++ b/scrapontology/extractor.py @@ -3,11 +3,16 @@ from typing import List, Tuple, Dict, Any, Optional from .primitives import Entity, Relation from .parsers.base_parser import BaseParser -from .parsers.prompts import DELETE_PROMPT, UPDATE_ENTITIES_PROMPT, UPDATE_SCHEMA_PROMPT +from .parsers.prompts import DELETE_PROMPT, UPDATE_ENTITIES_PROMPT, UPDATE_SCHEMA_PROMPT, CREATE_TABLES_PROMPT from .llm_client import LLMClient from .parsers.prompts import UPDATE_SCHEMA_PROMPT +from .db_client import DBClient, PostgresDBClient import json - +from langgraph.graph import StateGraph, END, START +from typing import TypedDict, Literal +from scrapontology.db_client import PostgresDBClient +from scrapontology.llm_client import LLMClient +from pydantic import BaseModel logger = logging.getLogger(__name__) logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') @@ -56,17 +61,31 @@ def get_relations_graph(self): def get_json_schema(self): pass + @abstractmethod + def get_db_client(self): + pass + + @abstractmethod + def set_db_client(self, db_client: DBClient): + pass + + @abstractmethod + def create_tables(self): + pass + class FileExtractor(Extractor): - def __init__(self, file_path: str, parser: BaseParser): + def __init__(self, file_path: str, parser: BaseParser, db_client: Optional[DBClient] = None): """ Initialize the FileExtractor. Args: file_path (str): The path to the file to be processed. 
parser (BaseParser): The parser to be used for extraction. + db_client (Optional[DBClient]): The database client. Defaults to None. """ self.file_path = file_path self.parser = parser + self.db_client = db_client def extract_entities(self, prompt: Optional[str] = None) -> List[Entity]: """ @@ -253,4 +272,114 @@ def get_json_schema(self): Dict[str, Any]: The current JSON schema from the parser. """ return self.parser.get_json_schema() + + def get_db_client(self): + """ + Retrieves the DB client. + + Returns: + DBClient: The DB client. + """ + if not isinstance(self.db_client, PostgresDBClient): + logger.error("DB client is not a relational database client.") + raise ValueError("DB client is not a relational database client.") + + return self.db_client + + def set_db_client(self, db_client: DBClient): + """ + Sets the DB client. + + Args: + db_client (DBClient): The DB client. + """ + if not isinstance(db_client, DBClient): + logger.error("DB client is not a valid DB client.") + raise ValueError("DB client is not a valid DB client.") + + self.db_client = db_client + + def create_tables(self): + """ + Creates the tables in a relational database. 
+ """ + # checks if the db_client is a PostgresDBClient + if not isinstance(self.db_client, PostgresDBClient): + logger.error("DB client is not a relational database client.") + raise ValueError("DB client is not a relational database client.") + + # get the json schema + json_schema = self.get_json_schema() + + self.db_client.connect() + + class StateCreateTables(BaseModel): + json_schema: Optional[str] = None + sql_code: Optional[str] = None + retry: Optional[bool] = None + error: Optional[str] = None + retry_count: Optional[int] = 0 + + state_create_tables = StateCreateTables() + state_create_tables.json_schema = str(json_schema) + + + def generate_sql_code(state: StateCreateTables, *_) -> StateCreateTables: + if state.sql_code is None: + create_tables_prompt = CREATE_TABLES_PROMPT.format( + json_schema=json.dumps(state.json_schema, indent=2) + ) + sql_code = self.parser.llm_client.get_response(create_tables_prompt) + sql_code = sql_code.replace("```sql", "").replace("```", "").strip() + state.sql_code = sql_code + else: + create_tables_prompt_fixed = CREATE_TABLES_PROMPT.format( + json_schema=json.dumps(state.json_schema, indent=2) + ) + "You generated previously the following erroneous code: " + state.sql_code + "With the following error: " + state.error + " Please fix it, if the relation already exists in the database please just ignore it and do not create it again." 
+ + state.retry_count += 1 + sql_code = self.parser.llm_client.get_response(create_tables_prompt_fixed) + sql_code = sql_code.replace("```sql", "").replace("```", "").strip() + state.sql_code = sql_code + return state + + def execute_sql_code(state: StateCreateTables, *_) -> StateCreateTables: + try: + self.db_client.execute_query(state.sql_code) + state.retry = False + return state + except Exception as e: + print(f"Error executing SQL: {e}") + state.error = str(e) + state.retry = True + return state + + def retry_or_not(state: StateCreateTables, *_): + if state.retry and state.retry_count < 2: + return "generate_sql_code" + else: + return END + + # Build the graph + workflow = StateGraph(StateCreateTables) + + # Add nodes + workflow.add_node("generate_sql_code", generate_sql_code) + workflow.add_node("execute_sql_code", execute_sql_code) + + # Add edges + workflow.add_edge(START, "generate_sql_code") + workflow.add_edge("generate_sql_code", "execute_sql_code") + workflow.add_conditional_edges( + "execute_sql_code", + retry_or_not + ) + workflow.add_edge("execute_sql_code", END) + + # Compile the graph + graph = workflow.compile() + # Execute the graph + graph.invoke(state_create_tables) + self.db_client.disconnect() + logger.info("Tables created successfully.") diff --git a/scrapontology/llm_client.py b/scrapontology/llm_client.py index 3dbae63..f7ce9bf 100644 --- a/scrapontology/llm_client.py +++ b/scrapontology/llm_client.py @@ -1,8 +1,8 @@ import requests import logging from typing import Dict, Any, Optional, List -import json -from requests.adapters import HTTPAdapter +from typing import Any, Callable +from pydantic_core import CoreSchema, core_schema logger = logging.getLogger(__name__) @@ -30,6 +30,10 @@ def __init__(self, api_key: str, inference_base_url: str = "https://api.openai.c self.session = requests.Session() + + @classmethod + def __get_pydantic_core_schema__(cls, _source_type: Any, _handler: Callable) -> CoreSchema: + return 
core_schema.any_schema() + def get_api_key(self) -> str: return self._api_key diff --git a/scrapontology/parsers/prompts.py b/scrapontology/parsers/prompts.py index e0a6c9e..2d7a08f 100644 --- a/scrapontology/parsers/prompts.py +++ b/scrapontology/parsers/prompts.py @@ -255,4 +255,13 @@ {code} Please provide only the corrected Python code, nothing else before or after the code. +""" + +CREATE_TABLES_PROMPT = """ +How would you create table based on this json schema, remember to use the \ +normalization and all the 3 forms of normalizations, provide me only the postgres \ +sql code without any comment before and after the code. + +Json schema: +{json_schema} """ \ No newline at end of file