Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions examples/.env.example
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
# OpenAI API Key
OPENAI_API_KEY=your_openai_api_key_here

# PostgreSQL Database Configuration
POSTGRES_HOST=localhost
POSTGRES_PORT=5432
POSTGRES_DB=your_database_name
POSTGRES_USER=your_postgres_username
POSTGRES_PASSWORD=your_secure_password

# Example:
# OPENAI_API_KEY=sk-abcdefghijklmnopqrstuvwxyz123456
# POSTGRES_HOST=db.example.com
# POSTGRES_PORT=5432
# POSTGRES_DB=myapp_database
# POSTGRES_USER=db_user
# POSTGRES_PASSWORD=super_secret_password
43 changes: 43 additions & 0 deletions examples/generate_tables_from_pdf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from scrapontology import FileExtractor, PDFParser
from scrapontology.llm_client import LLMClient
from scrapontology.db_client import PostgresDBClient
import os
from dotenv import load_dotenv

def main():
    """Generate a JSON schema from a local PDF and create matching Postgres tables.

    Reads configuration (OpenAI key and Postgres connection settings) from the
    environment, then drives a FileExtractor end to end.

    Raises:
        RuntimeError: If OPENAI_API_KEY is not configured.
        FileNotFoundError: If the example PDF is missing.
    """
    # Load variables from a local .env file (see examples/.env.example).
    load_dotenv()
    api_key = os.getenv("OPENAI_API_KEY")
    if not api_key:
        # Fail fast with a clear message instead of an opaque API error later.
        raise RuntimeError(
            "OPENAI_API_KEY is not set; copy examples/.env.example to .env and fill it in."
        )

    # Resolve the PDF next to this script so the example works from any CWD.
    curr_dir = os.path.dirname(os.path.abspath(__file__))
    pdf_name = "test.pdf"
    pdf_path = os.path.join(curr_dir, pdf_name)
    if not os.path.isfile(pdf_path):
        raise FileNotFoundError(f"Example PDF not found: {pdf_path}")

    # Wire up the LLM-backed PDF parser.
    llm_client = LLMClient(api_key)
    pdf_parser = PDFParser(llm_client)

    # Postgres connection settings; PostgresDBClient applies its own defaults
    # for any value that is None.
    postgres_host = os.getenv("POSTGRES_HOST")
    postgres_port = os.getenv("POSTGRES_PORT")
    postgres_db = os.getenv("POSTGRES_DB")
    postgres_user = os.getenv("POSTGRES_USER")
    postgres_password = os.getenv("POSTGRES_PASSWORD")
    db_client = PostgresDBClient(postgres_host, postgres_port, postgres_db, postgres_user, postgres_password)

    # Create FileExtractor instance.
    pdf_extractor = FileExtractor(pdf_path, pdf_parser, db_client)

    # Derive a JSON schema of the entities found in the PDF.
    json_schema = pdf_extractor.generate_entities_json_schema()
    print("Generated JSON Schema:")
    print(json_schema)

    # Create the corresponding tables in the database.
    pdf_extractor.create_tables()

    print("Tables created successfully in the PostgreSQL database.")


if __name__ == "__main__":
    main()
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ dependencies = [
"requests==2.32.3",
"urllib3==2.2.2",
"langgraph==0.2.31",
"psycopg2==2.9.9"
]

license = "MIT"
Expand Down
2 changes: 2 additions & 0 deletions requirements-dev.lock
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,8 @@ pluggy==1.5.0
# via pytest
prettytable==3.11.0
# via pyecharts
psycopg2==2.9.9
# via scrapontology
pydantic==2.9.2
# via langchain-core
# via langsmith
Expand Down
2 changes: 2 additions & 0 deletions requirements.lock
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,8 @@ packaging==24.1
# via langchain-core
pillow==10.4.0
# via scrapontology
psycopg2==2.9.9
# via scrapontology
pydantic==2.9.2
# via langchain-core
# via langsmith
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ pillow==10.4.0
python-dotenv==1.0.1
requests==2.32.3
urllib3==2.2.2
langgraph==0.2.32
langgraph==0.2.32
psycopg2==2.9.9
95 changes: 95 additions & 0 deletions scrapontology/db_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import os
from abc import ABC, abstractmethod
import psycopg2
from psycopg2.extras import RealDictCursor
import logging
from pydantic_core import CoreSchema, core_schema
from typing import Any, Callable


logger = logging.getLogger(__name__)

class DBClient(ABC):
    """Abstract interface every database client must implement."""

    @abstractmethod
    def connect(self):
        """Open a connection to the backing database."""

    @abstractmethod
    def disconnect(self):
        """Close the connection and release its resources."""

    @abstractmethod
    def execute_query(self, query, params=None):
        """Run *query* with optional *params* and return any results."""

class PostgresDBClient(DBClient):
    """PostgreSQL implementation of DBClient backed by psycopg2."""

    def __init__(self, host=None, port=None, database=None, user=None, password=None):
        # Connection and cursor are created lazily in connect().
        self.conn = None
        self.cursor = None
        # Fall back to environment variables (then literals) when args are omitted.
        self.host = host or os.getenv('POSTGRES_HOST', 'localhost')
        # BUG FIX: the env var name was misspelled 'POSTGRE1S_PORT', so the
        # POSTGRES_PORT environment variable was never read.
        self.port = port or os.getenv('POSTGRES_PORT', '5432')
        self.database = database or os.getenv('POSTGRES_DB', 'scrapontology_test')
        # NOTE(review): these literal fallbacks look like leftover developer
        # credentials; they should not ship in source — confirm and remove.
        self.user = user or os.getenv('POSTGRES_USER', 'lurens')
        self.password = password or os.getenv('POSTGRES_PASSWORD', 'cicciopasticco')

    @classmethod
    def __get_pydantic_core_schema__(cls, _source_type: Any, _handler: Callable) -> CoreSchema:
        # Present this client to pydantic as an opaque "any" value so models
        # can hold it without validating or serializing its internals.
        return core_schema.any_schema()

    def connect(self):
        """Open a connection and a RealDictCursor; logs (does not raise) on failure."""
        try:
            self.conn = psycopg2.connect(
                host=self.host,
                port=self.port,
                database=self.database,
                user=self.user,
                password=self.password
            )
            self.cursor = self.conn.cursor(cursor_factory=RealDictCursor)
            logger.info("Successfully connected to PostgreSQL database")
        except (Exception, psycopg2.Error) as error:
            logger.error(f"Error while connecting to PostgreSQL: {error}")

    def disconnect(self):
        """Close the cursor and connection if they were opened."""
        if self.conn:
            # Guard: cursor may still be None if connect() failed partway.
            if self.cursor:
                self.cursor.close()
            self.conn.close()
            logger.info("PostgreSQL connection is closed")

    def execute_query(self, query, params=None):
        """Execute *query*, commit, and return its rows (None for non-SELECT).

        Rolls back the transaction and re-raises on failure.
        """
        try:
            self.cursor.execute(query, params)
            self.conn.commit()
            # BUG FIX: fetchall() raises on statements that produce no result
            # set (e.g. CREATE TABLE — exactly what create_tables() runs);
            # cursor.description is None in that case.
            if self.cursor.description is None:
                return None
            return self.cursor.fetchall()
        except (Exception, psycopg2.Error) as error:
            logger.error(f"Error executing query: {error}")
            self.conn.rollback()
            raise error  # Re-raise so callers (e.g. the retry loop) can react.


class Neo4jDBClient(DBClient):
    """Stub Neo4j client; every operation is a placeholder for now."""

    def __init__(self):
        # Nothing to configure until the real implementation lands.
        pass

    def connect(self):
        # Not implemented yet; just record that it was attempted.
        logger.info("Neo4j connection not yet implemented")

    def disconnect(self):
        # Nothing to tear down yet.
        pass

    def execute_query(self, query, params=None):
        # Not implemented yet; report and return no rows.
        logger.info("Neo4j query execution not yet implemented")
        return None

# Factory: build the client that matches the requested backend.
def get_db_client(db_type='postgres'):
    """Return a DB client instance for *db_type* ('postgres' or 'neo4j').

    Raises:
        ValueError: For any unrecognized database type.
    """
    kind = db_type.lower()
    if kind == 'postgres':
        return PostgresDBClient()
    if kind == 'neo4j':
        return Neo4jDBClient()
    raise ValueError(f"Unsupported database type: {db_type}")
135 changes: 132 additions & 3 deletions scrapontology/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,16 @@
from typing import List, Tuple, Dict, Any, Optional
from .primitives import Entity, Relation
from .parsers.base_parser import BaseParser
from .parsers.prompts import DELETE_PROMPT, UPDATE_ENTITIES_PROMPT, UPDATE_SCHEMA_PROMPT
from .parsers.prompts import DELETE_PROMPT, UPDATE_ENTITIES_PROMPT, UPDATE_SCHEMA_PROMPT, CREATE_TABLES_PROMPT
from .llm_client import LLMClient
from .parsers.prompts import UPDATE_SCHEMA_PROMPT
from .db_client import DBClient, PostgresDBClient
import json

from langgraph.graph import StateGraph, END, START
from typing import TypedDict, Literal
from scrapontology.db_client import PostgresDBClient
from scrapontology.llm_client import LLMClient
from pydantic import BaseModel
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

Expand Down Expand Up @@ -56,17 +61,31 @@ def get_relations_graph(self):
def get_json_schema(self):
pass

@abstractmethod
def get_db_client(self):
"""Return the database client associated with this extractor."""
pass

@abstractmethod
def set_db_client(self, db_client: DBClient):
"""Replace the extractor's database client with *db_client*."""
pass

@abstractmethod
def create_tables(self):
"""Create database tables derived from the extracted schema."""
pass

class FileExtractor(Extractor):
def __init__(self, file_path: str, parser: BaseParser):
def __init__(self, file_path: str, parser: BaseParser, db_client: Optional[DBClient] = None):
"""
Initialize the FileExtractor.

Args:
file_path (str): The path to the file to be processed.
parser (BaseParser): The parser to be used for extraction.
db_client (Optional[DBClient]): The database client. Defaults to None;
required only when database-backed methods (e.g. create_tables) are used.
"""
# Plain attribute assignment; nothing is validated until the client is used.
self.file_path = file_path
self.parser = parser
self.db_client = db_client

def extract_entities(self, prompt: Optional[str] = None) -> List[Entity]:
"""
Expand Down Expand Up @@ -253,4 +272,114 @@ def get_json_schema(self):
Dict[str, Any]: The current JSON schema from the parser.
"""
return self.parser.get_json_schema()

def get_db_client(self):
    """
    Return the configured DB client.

    Returns:
        DBClient: The DB client.

    Raises:
        ValueError: If the client is not a PostgresDBClient (the only
            relational client currently supported).
    """
    if not isinstance(self.db_client, PostgresDBClient):
        logger.error("DB client is not a relational database client.")
        # BUG FIX: a stray trailing backslash after this raise continued the
        # statement into the following blank line.
        raise ValueError("DB client is not a relational database client.")

    return self.db_client

def set_db_client(self, db_client: DBClient):
    """
    Replace the extractor's DB client.

    Args:
        db_client (DBClient): The client to install.

    Raises:
        ValueError: If *db_client* is not a DBClient instance.
    """
    if isinstance(db_client, DBClient):
        self.db_client = db_client
        return
    logger.error("DB client is not a valid DB client.")
    raise ValueError("DB client is not a valid DB client.")

def create_tables(self):
    """
    Generate PostgreSQL DDL from the extractor's JSON schema via the LLM and
    execute it, feeding any execution error back to the model for up to two
    regeneration attempts (a small langgraph state machine drives the loop).

    Raises:
        ValueError: If the configured DB client is not a PostgresDBClient.
    """
    # Only relational (Postgres) clients are supported here.
    if not isinstance(self.db_client, PostgresDBClient):
        logger.error("DB client is not a relational database client.")
        raise ValueError("DB client is not a relational database client.")

    json_schema = self.get_json_schema()

    self.db_client.connect()

    class StateCreateTables(BaseModel):
        # Serialized JSON schema the DDL is generated from.
        json_schema: Optional[str] = None
        # Last SQL emitted by the LLM.
        sql_code: Optional[str] = None
        # Whether the last execution failed and should be retried.
        retry: Optional[bool] = None
        # Error message from the last failed execution.
        error: Optional[str] = None
        # Number of regeneration attempts made so far.
        retry_count: Optional[int] = 0

    state_create_tables = StateCreateTables()
    # BUG FIX: serialize the schema to JSON exactly once. The original stored
    # str(json_schema) and then json.dumps()-ed that *string* again, so the
    # LLM received a quoted/escaped blob instead of the schema.
    state_create_tables.json_schema = json.dumps(json_schema, indent=2, default=str)

    def generate_sql_code(state: StateCreateTables, *_) -> StateCreateTables:
        base_prompt = CREATE_TABLES_PROMPT.format(json_schema=state.json_schema)
        if state.sql_code is None:
            prompt = base_prompt
        else:
            # Retry: feed the failing SQL and its error back for a fix.
            state.retry_count += 1
            prompt = (
                base_prompt
                + "You generated previously the following erroneous code: "
                + state.sql_code
                + " With the following error: "
                + (state.error or "")  # guard: error may be unset
                + " Please fix it, if the relation already exists in the "
                  "database please just ignore it and do not create it again."
            )
        sql_code = self.parser.llm_client.get_response(prompt)
        # Strip Markdown code fences the model may wrap the SQL in.
        state.sql_code = sql_code.replace("```sql", "").replace("```", "").strip()
        return state

    # BUG FIX: the original annotated this node as
    # -> Literal["success", "failure"] but it returns the state object.
    def execute_sql_code(state: StateCreateTables, *_) -> StateCreateTables:
        try:
            self.db_client.execute_query(state.sql_code)
            state.retry = False
        except Exception as e:
            # Use the module logger instead of print for consistency.
            logger.error(f"Error executing SQL: {e}")
            state.error = str(e)
            state.retry = True
        return state

    def retry_or_not(state: StateCreateTables, *_):
        # Regenerate at most twice, then stop regardless of outcome.
        if state.retry and state.retry_count < 2:
            return "generate_sql_code"
        return END

    # Build the graph.
    workflow = StateGraph(StateCreateTables)
    workflow.add_node("generate_sql_code", generate_sql_code)
    workflow.add_node("execute_sql_code", execute_sql_code)
    workflow.add_edge(START, "generate_sql_code")
    workflow.add_edge("generate_sql_code", "execute_sql_code")
    # The conditional edge fully decides what follows execution; the original
    # also added a direct execute_sql_code -> END edge, which duplicated the
    # END route.
    workflow.add_conditional_edges("execute_sql_code", retry_or_not)

    graph = workflow.compile()
    graph.invoke(state_create_tables)
    self.db_client.disconnect()
    logger.info("Tables created successfully.")
8 changes: 6 additions & 2 deletions scrapontology/llm_client.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import requests
import logging
from typing import Dict, Any, Optional, List
import json
from requests.adapters import HTTPAdapter
from typing import Any, Callable
from pydantic_core import CoreSchema, core_schema

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -30,6 +30,10 @@ def __init__(self, api_key: str, inference_base_url: str = "https://api.openai.c

self.session = requests.Session()

@classmethod
def __get_pydantic_core_schema__(cls, _source_type: Any, _handler: Callable) -> CoreSchema:
# Present this client to pydantic as an opaque "any" value so models can
# hold it without validating or serializing its internals.
return core_schema.any_schema()

def get_api_key(self) -> str:
"""Return the API key stored on this client (``_api_key``)."""
return self._api_key

Expand Down
9 changes: 9 additions & 0 deletions scrapontology/parsers/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,4 +255,13 @@
{code}
Please provide only the corrected Python code, nothing else before or after the code.
"""

# Prompt asking the LLM for normalized PostgreSQL DDL; {json_schema} is filled
# in by FileExtractor.create_tables(). The original wording was garbled
# ("create table based on this json schema ... all the 3 forms of
# normalizations"), which degrades LLM output quality.
CREATE_TABLES_PROMPT = """
Create PostgreSQL tables based on the JSON schema below. Normalize the design, \
applying the first three normal forms. Provide only the PostgreSQL SQL code, \
with no comments or other text before or after the code.
Json schema:
{json_schema}
"""