diff --git a/src/pipe/add_schema.py b/src/pipe/add_schema.py index 827f25a..b722c06 100644 --- a/src/pipe/add_schema.py +++ b/src/pipe/add_schema.py @@ -6,6 +6,91 @@ from src.pipe.schema_repo import DatabaseSchema, DatabaseSchemaRepo +def _parse_schema_item(item: str) -> str | None: + """ + Parse and validate a schema item reference. + + Parameters + ---------- + item : str + Schema item reference (e.g., 'COLUMN:table.column'). + + Returns + ------- + str or None + Column reference if valid COLUMN item, None otherwise. + """ + if not isinstance(item, str) or ":" not in item: + return None + + parts = item.split(":", 1) + if len(parts) != 2: + return None + + item_type, item_ref = parts + + if "[*]" in item_ref or item_type != "COLUMN": + return None + + return item_ref + + +def _parse_column_ref(col_ref: str) -> tuple[str, str] | None: + """ + Parse a column reference into table and column names. + + Parameters + ---------- + col_ref : str + Column reference in format 'table.column'. + + Returns + ------- + tuple[str, str] or None + (table_name, col_name) if valid, None otherwise. + """ + if "." not in col_ref: + return None + + parts = col_ref.split(".", 1) + if len(parts) != 2: + return None + + return parts[0], parts[1] + + +def _get_foreign_key( + schema: DatabaseSchema, table_name: str, col_name: str +) -> str | None: + """ + Get foreign key reference for a column if it exists. + + Parameters + ---------- + schema : DatabaseSchema + The database schema. + table_name : str + Name of the table. + col_name : str + Name of the column. + + Returns + ------- + str or None + Foreign key reference if exists, None otherwise. + """ + if table_name not in schema.tables or col_name not in schema.tables[table_name]: + return None + + col_data = schema.tables[table_name][col_name] + if isinstance(col_data, dict) and "foreign_key" in col_data: + fk_ref = col_data["foreign_key"] + if isinstance(fk_ref, str): + return fk_ref + + return None + + def filter_schema(schema: DatabaseSchema, schema_items: list[str]) -> DatabaseSchema: """ Filter database schema to include only specified schema items. @@ -22,31 +107,38 @@ def filter_schema(schema: DatabaseSchema, schema_items: list[str]) -> DatabaseSc DatabaseSchema Filtered schema containing only the specified items and their foreign keys. """ + if not schema_items: + return DatabaseSchema() + + # Extract valid column references from schema items columns = set() for item in schema_items: - item_ref = item.split(":")[1] - if "[*]" in item_ref: - continue - if item.split(":")[0] == "COLUMN": - columns.add(item_ref) + col_ref = _parse_schema_item(item) + if col_ref: + columns.add(col_ref) + # Add foreign key references for col_ref in list(columns): - table_name = col_ref.split(".")[0] - col_name = col_ref.split(".")[1] - col_data = schema.tables[table_name][col_name] - if isinstance(col_data, dict) and "foreign_key" in col_data: - fk_ref = col_data["foreign_key"] - if isinstance(fk_ref, str): - columns.add(fk_ref) + parsed = _parse_column_ref(col_ref) + if not parsed: + continue + table_name, col_name = parsed + fk_ref = _get_foreign_key(schema, table_name, col_name) + if fk_ref: + columns.add(fk_ref) + + # Build filtered schema filtered_schema = DatabaseSchema() for table_name, table_columns in schema.tables.items(): - filtered_table_columns = {} - for col_name, col_data in table_columns.items(): - if f"{table_name}.{col_name}" in columns: - filtered_table_columns[col_name] = col_data - if len(filtered_table_columns) > 0: + filtered_table_columns = { + col_name: col_data + for col_name, col_data in table_columns.items() + if f"{table_name}.{col_name}" in columns + } + if filtered_table_columns: filtered_schema.tables[table_name] = filtered_table_columns + return filtered_schema diff --git a/src/pipe/llm_util.py b/src/pipe/llm_util.py index 07d1262..ed1445d 100644 --- a/src/pipe/llm_util.py +++ b/src/pipe/llm_util.py @@ -104,6 +104,44 @@ async def send_prompt( return content, usage +def _preprocess_json_string(text: str) -> str: + """ + Pre-process JSON string to fix common LLM formatting errors. + + Parameters + ---------- + text : str + Raw JSON text that may have formatting issues + + Returns + ------- + str + Cleaned JSON text + """ + # Strip whitespace + text = text.strip() + + # Fix common array termination issues like: ["item1", "item2".] + # Replace ".] with "] + text = re.sub(r'"\s*\.\s*\]', '"]', text) + + # Fix missing closing quotes before array end: ["item1", "item2] + # Find patterns like: "something] where ] should be "] + text = re.sub(r'([^"])\]', r'\1"]', text) + # But undo if we just added "" which would be wrong + text = text.replace('"""]', '"]') + + # Fix patterns like: "COLUMN:." or "TABLE:" (empty after colon) + # Remove items that are just "TYPE:." or "TYPE:" + text = re.sub(r'"\s*[A-Z]+\s*:\s*\.?\s*"', '""', text) + + # Remove empty strings from arrays + text = re.sub(r',\s*""\s*,', ",", text) # middle + text = re.sub(r'\[\s*""\s*,', "[", text) # start + text = re.sub(r',\s*""\s*\]', "]", text) # end + return re.sub(r'\[\s*""\s*\]', "[]", text) # only item + + def extract_json(text: str) -> dict[str, Any] | None: """ Extract JSON object from text with code blocks. @@ -121,13 +159,17 @@ def extract_json(text: str) -> dict[str, Any] | None: try: if "```json" in text: res = re.findall(r"```json([\s\S]*?)```", text) - json_res = json.loads(res[0]) + json_text = res[0] elif "```" in text: res = re.findall(r"```([\s\S]*?)```", text) - json_res = json.loads(res[0]) + json_text = res[0] else: - json_res = json.loads(text) - return json_res + json_text = text + + # Pre-process to fix common formatting errors + json_text = _preprocess_json_string(json_text) + + return json.loads(json_text) except Exception as e: logger.warning(f"Failed to extract json from: {text}, error={e}") return None @@ -172,6 +214,7 @@ def extract_object(text: str) -> Any | None: if obj is None: obj = eval_literal(text) if obj is None: - logger.error(f"Failed to extract object: {text}") + # Only log at debug level since callers typically handle None gracefully + logger.debug(f"Failed to extract object: {text}") obj = None return obj diff --git a/src/pipe/rank_schema_llm.py b/src/pipe/rank_schema_llm.py index 91b6c3f..c0b928a 100644 --- a/src/pipe/rank_schema_llm.py +++ b/src/pipe/rank_schema_llm.py @@ -7,6 +7,7 @@ from src.pipe.llm_util import extract_object from src.pipe.rank_schema_prompts.v1 import RANK_SCHEMA_ITEMS_V1 from src.pipe.schema_repo import DatabaseSchemaRepo +from src.utils.logging import logger class RankSchemaItems(PromptProcessor): @@ -31,8 +32,72 @@ def __init__( super().__init__(prop_name, openai_config=openai_config, model=model) self.schema_repo = DatabaseSchemaRepo(tables_path) - def _process_output(self, row: dict[str, Any], output: str) -> Any: - return extract_object(output) + def _sanitize_schema_item(self, item: str) -> str | None: + """ + Sanitize a schema item reference to ensure proper formatting. + + Parameters + ---------- + item : str + Schema item reference (e.g., "TABLE:[name]" or "COLUMN:[table].[col]") + + Returns + ------- + str or None + Sanitized schema item or None if invalid + """ + if not isinstance(item, str) or ":" not in item: + return None + + parts = item.split(":", 1) + if len(parts) != 2: + return None + + item_type, item_ref = parts + + # Skip empty references + if not item_ref or item_ref.strip() in ["", ".", "[.]"]: + return None + + # Ensure all opening brackets have closing brackets + bracket_count = item_ref.count("[") - item_ref.count("]") + if bracket_count > 0: + # Add missing closing brackets + item_ref = item_ref + ("]" * bracket_count) + elif bracket_count < 0: + # More closing than opening - invalid + return None + + return f"{item_type}:{item_ref}" + + def _process_output(self, row: dict[str, Any], output: str) -> list[str]: + result = extract_object(output) + + # Handle None or invalid output + if result is None or not isinstance(result, list): + logger.warning( + f"LLM returned invalid schema items for question_id={row.get('question_id')}, " + f"falling back to all schema items" + ) + # Fallback: return all schema items + return self.extract_schema_items(row) + + # Sanitize and filter out invalid items + sanitized_items = [] + for item in result: + sanitized = self._sanitize_schema_item(item) + if sanitized: + sanitized_items.append(sanitized) + + # If sanitization removed everything, fallback to all items + if not sanitized_items: + logger.warning( + f"All LLM schema items were invalid for question_id={row.get('question_id')}, " + f"falling back to all schema items" + ) + return self.extract_schema_items(row) + + return sanitized_items def extract_schema_items(self, row: dict[str, Any]) -> list[str]: """ diff --git a/src/pipe/rank_schema_prompts/v1.py b/src/pipe/rank_schema_prompts/v1.py index 9843298..2f24e70 100644 --- a/src/pipe/rank_schema_prompts/v1.py +++ b/src/pipe/rank_schema_prompts/v1.py @@ -1,40 +1,52 @@ """Schema ranking prompt template version 1.""" -RANK_SCHEMA_ITEMS_V1 = """ -You are given: - 1. A natural language question. - 2. A list of schema items of an underlying database. Each schema item is either - "TABLE:[table_name]" or "COLUMN:[table_name].[column_name] - -Task: -Filter the given list and return a subset of these items that are most relevant to the given question. -You can include at most 4 tables and at most 5 columns for each table. - -Example: -Question: “What is the name of the instructor who has the lowest salary?” -Schema Items: +RANK_SCHEMA_ITEMS_V1 = """You are a database schema analyzer. Your task is to identify which schema items are relevant for answering a given question. + +## Input Format +You will receive: +1. A natural language question +2. A list of schema items (tables and columns) from a database + +Each schema item follows this format: +- Tables: "TABLE:[table_name]" +- Columns: "COLUMN:[table_name].[column_name]" + +## Task +Select the schema items needed to answer the question. Choose: +- Maximum 4 tables +- Maximum 5 columns per table + +## Output Requirements +1. Return a valid JSON array of strings +2. Select items EXACTLY as they appear in the input list - do not modify them +3. Include only items that are relevant to answering the question +4. Ensure the output is valid JSON (properly quoted and bracketed) + +## Example + +Input Question: "What is the name of the instructor who has the lowest salary?" + +Input Schema Items: [ "TABLE:[department]", "COLUMN:[department].[name]", "TABLE:[instructor]", - "COLUMN:[instructor].[name]" - "COLUMN:[instructor].[salary]" + "COLUMN:[instructor].[name]", + "COLUMN:[instructor].[salary]", "COLUMN:[instructor].[age]" ] -Output: +Expected Output: [ "TABLE:[instructor]", - "COLUMN:[instructor].[name]" + "COLUMN:[instructor].[name]", "COLUMN:[instructor].[salary]" ] -Now filter the following list of Schema Items based on the given question. +## Your Turn + Question: {question} + Schema Items: {schema_items} -Instructions: -- Output only a valid list of strings. -- Do not include any additional text, explanations, or formatting.: -- All strings should be in double quotes. -""" +Output:"""