Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 109 additions & 17 deletions src/pipe/add_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,91 @@
from src.pipe.schema_repo import DatabaseSchema, DatabaseSchemaRepo


def _parse_schema_item(item: str) -> str | None:
"""
Parse and validate a schema item reference.

Parameters
----------
item : str
Schema item reference (e.g., 'COLUMN:table.column').

Returns
-------
str or None
Column reference if valid COLUMN item, None otherwise.
"""
if not isinstance(item, str) or ":" not in item:
return None

parts = item.split(":", 1)
if len(parts) != 2:
return None

item_type, item_ref = parts

if "[*]" in item_ref or item_type != "COLUMN":
return None

return item_ref


def _parse_column_ref(col_ref: str) -> tuple[str, str] | None:
"""
Parse a column reference into table and column names.

Parameters
----------
col_ref : str
Column reference in format 'table.column'.

Returns
-------
tuple[str, str] or None
(table_name, col_name) if valid, None otherwise.
"""
if "." not in col_ref:
return None

parts = col_ref.split(".", 1)
if len(parts) != 2:
return None

return parts[0], parts[1]


def _get_foreign_key(
schema: DatabaseSchema, table_name: str, col_name: str
) -> str | None:
"""
Get foreign key reference for a column if it exists.

Parameters
----------
schema : DatabaseSchema
The database schema.
table_name : str
Name of the table.
col_name : str
Name of the column.

Returns
-------
str or None
Foreign key reference if exists, None otherwise.
"""
if table_name not in schema.tables or col_name not in schema.tables[table_name]:
return None

col_data = schema.tables[table_name][col_name]
if isinstance(col_data, dict) and "foreign_key" in col_data:
fk_ref = col_data["foreign_key"]
if isinstance(fk_ref, str):
return fk_ref

return None


def filter_schema(schema: DatabaseSchema, schema_items: list[str]) -> DatabaseSchema:
"""
Filter database schema to include only specified schema items.
Expand All @@ -22,31 +107,38 @@ def filter_schema(schema: DatabaseSchema, schema_items: list[str]) -> DatabaseSc
DatabaseSchema
Filtered schema containing only the specified items and their foreign keys.
"""
if not schema_items:
return DatabaseSchema()

# Extract valid column references from schema items
columns = set()
for item in schema_items:
item_ref = item.split(":")[1]
if "[*]" in item_ref:
continue
if item.split(":")[0] == "COLUMN":
columns.add(item_ref)
col_ref = _parse_schema_item(item)
if col_ref:
columns.add(col_ref)

# Add foreign key references
for col_ref in list(columns):
table_name = col_ref.split(".")[0]
col_name = col_ref.split(".")[1]
col_data = schema.tables[table_name][col_name]
if isinstance(col_data, dict) and "foreign_key" in col_data:
fk_ref = col_data["foreign_key"]
if isinstance(fk_ref, str):
columns.add(fk_ref)
parsed = _parse_column_ref(col_ref)
if not parsed:
continue

table_name, col_name = parsed
fk_ref = _get_foreign_key(schema, table_name, col_name)
if fk_ref:
columns.add(fk_ref)

# Build filtered schema
filtered_schema = DatabaseSchema()
for table_name, table_columns in schema.tables.items():
filtered_table_columns = {}
for col_name, col_data in table_columns.items():
if f"{table_name}.{col_name}" in columns:
filtered_table_columns[col_name] = col_data
if len(filtered_table_columns) > 0:
filtered_table_columns = {
col_name: col_data
for col_name, col_data in table_columns.items()
if f"{table_name}.{col_name}" in columns
}
if filtered_table_columns:
filtered_schema.tables[table_name] = filtered_table_columns

return filtered_schema


Expand Down
53 changes: 48 additions & 5 deletions src/pipe/llm_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,44 @@ async def send_prompt(
return content, usage


def _preprocess_json_string(text: str) -> str:
"""
Pre-process JSON string to fix common LLM formatting errors.

Parameters
----------
text : str
Raw JSON text that may have formatting issues

Returns
-------
str
Cleaned JSON text
"""
# Strip whitespace
text = text.strip()

# Fix common array termination issues like: ["item1", "item2".]
# Replace ".] with "]
text = re.sub(r'"\s*\.\s*\]', '"]', text)

# Fix missing closing quotes before array end: ["item1", "item2]
# Find patterns like: "something] where ] should be "]
text = re.sub(r'([^"])\]', r'\1"]', text)
# But undo if we just added "" which would be wrong
text = text.replace('"""]', '"]')

# Fix patterns like: "COLUMN:." or "TABLE:" (empty after colon)
# Remove items that are just "TYPE:." or "TYPE:"
text = re.sub(r'"\s*[A-Z]+\s*:\s*\.?\s*"', '""', text)

# Remove empty strings from arrays
text = re.sub(r',\s*""\s*,', ",", text) # middle
text = re.sub(r'\[\s*""\s*,', "[", text) # start
text = re.sub(r',\s*""\s*\]', "]", text) # end
return re.sub(r'\[\s*""\s*\]', "[]", text) # only item


def extract_json(text: str) -> dict[str, Any] | None:
"""
Extract JSON object from text with code blocks.
Expand All @@ -121,13 +159,17 @@ def extract_json(text: str) -> dict[str, Any] | None:
try:
if "```json" in text:
res = re.findall(r"```json([\s\S]*?)```", text)
json_res = json.loads(res[0])
json_text = res[0]
elif "```" in text:
res = re.findall(r"```([\s\S]*?)```", text)
json_res = json.loads(res[0])
json_text = res[0]
else:
json_res = json.loads(text)
return json_res
json_text = text

# Pre-process to fix common formatting errors
json_text = _preprocess_json_string(json_text)

return json.loads(json_text)
except Exception as e:
logger.warning(f"Failed to extract json from: {text}, error={e}")
return None
Expand Down Expand Up @@ -172,6 +214,7 @@ def extract_object(text: str) -> Any | None:
if obj is None:
obj = eval_literal(text)
if obj is None:
logger.error(f"Failed to extract object: {text}")
# Only log at debug level since callers typically handle None gracefully
logger.debug(f"Failed to extract object: {text}")
obj = None
return obj
69 changes: 67 additions & 2 deletions src/pipe/rank_schema_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from src.pipe.llm_util import extract_object
from src.pipe.rank_schema_prompts.v1 import RANK_SCHEMA_ITEMS_V1
from src.pipe.schema_repo import DatabaseSchemaRepo
from src.utils.logging import logger


class RankSchemaItems(PromptProcessor):
Expand All @@ -31,8 +32,72 @@ def __init__(
super().__init__(prop_name, openai_config=openai_config, model=model)
self.schema_repo = DatabaseSchemaRepo(tables_path)

def _process_output(self, row: dict[str, Any], output: str) -> Any:
return extract_object(output)
def _sanitize_schema_item(self, item: str) -> str | None:
"""
Sanitize a schema item reference to ensure proper formatting.

Parameters
----------
item : str
Schema item reference (e.g., "TABLE:[name]" or "COLUMN:[table].[col]")

Returns
-------
str or None
Sanitized schema item or None if invalid
"""
if not isinstance(item, str) or ":" not in item:
return None

parts = item.split(":", 1)
if len(parts) != 2:
return None

item_type, item_ref = parts

# Skip empty references
if not item_ref or item_ref.strip() in ["", ".", "[.]"]:
return None

# Ensure all opening brackets have closing brackets
bracket_count = item_ref.count("[") - item_ref.count("]")
if bracket_count > 0:
# Add missing closing brackets
item_ref = item_ref + ("]" * bracket_count)
elif bracket_count < 0:
# More closing than opening - invalid
return None

return f"{item_type}:{item_ref}"

def _process_output(self, row: dict[str, Any], output: str) -> list[str]:
result = extract_object(output)

# Handle None or invalid output
if result is None or not isinstance(result, list):
logger.warning(
f"LLM returned invalid schema items for question_id={row.get('question_id')}, "
f"falling back to all schema items"
)
# Fallback: return all schema items
return self.extract_schema_items(row)

# Sanitize and filter out invalid items
sanitized_items = []
for item in result:
sanitized = self._sanitize_schema_item(item)
if sanitized:
sanitized_items.append(sanitized)

# If sanitization removed everything, fallback to all items
if not sanitized_items:
logger.warning(
f"All LLM schema items were invalid for question_id={row.get('question_id')}, "
f"falling back to all schema items"
)
return self.extract_schema_items(row)

return sanitized_items

def extract_schema_items(self, row: dict[str, Any]) -> list[str]:
"""
Expand Down
58 changes: 35 additions & 23 deletions src/pipe/rank_schema_prompts/v1.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,52 @@
"""Schema ranking prompt template version 1."""

RANK_SCHEMA_ITEMS_V1 = """
You are given:
1. A natural language question.
2. A list of schema items of an underlying database. Each schema item is either
"TABLE:[table_name]" or "COLUMN:[table_name].[column_name]

Task:
Filter the given list and return a subset of these items that are most relevant to the given question.
You can include at most 4 tables and at most 5 columns for each table.

Example:
Question: “What is the name of the instructor who has the lowest salary?”
Schema Items:
RANK_SCHEMA_ITEMS_V1 = """You are a database schema analyzer. Your task is to identify which schema items are relevant for answering a given question.

## Input Format
You will receive:
1. A natural language question
2. A list of schema items (tables and columns) from a database

Each schema item follows this format:
- Tables: "TABLE:[table_name]"
- Columns: "COLUMN:[table_name].[column_name]"

## Task
Select the schema items needed to answer the question. Choose:
- Maximum 4 tables
- Maximum 5 columns per table

## Output Requirements
1. Return a valid JSON array of strings
2. Select items EXACTLY as they appear in the input list - do not modify them
3. Include only items that are relevant to answering the question
4. Ensure the output is valid JSON (properly quoted and bracketed)

## Example

Input Question: "What is the name of the instructor who has the lowest salary?"

Input Schema Items:
[
"TABLE:[department]",
"COLUMN:[department].[name]",
"TABLE:[instructor]",
"COLUMN:[instructor].[name]"
"COLUMN:[instructor].[salary]"
"COLUMN:[instructor].[name]",
"COLUMN:[instructor].[salary]",
"COLUMN:[instructor].[age]"
]

Output:
Expected Output:
[
"TABLE:[instructor]",
"COLUMN:[instructor].[name]"
"COLUMN:[instructor].[name]",
"COLUMN:[instructor].[salary]"
]

Now filter the following list of Schema Items based on the given question.
## Your Turn

Question: {question}

Schema Items: {schema_items}

Instructions:
- Output only a valid list of strings.
- Do not include any additional text, explanations, or formatting.:
- All strings should be in double quotes.
"""
Output:"""
Loading