From 7fd3ba626d31b53cd71552a25523cbb3e6e2bf94 Mon Sep 17 00:00:00 2001 From: Amrit Krishnan Date: Thu, 6 Nov 2025 14:43:53 -0500 Subject: [PATCH 1/2] Fixes to llm output parsing when using LLM based ranking --- main.py | 3 - src/pipe/add_schema.py | 126 +++++++++++++++++++++++++---- src/pipe/llm_util.py | 53 ++++++++++-- src/pipe/rank_schema_llm.py | 69 +++++++++++++++- src/pipe/rank_schema_prompts/v1.py | 58 +++++++------ 5 files changed, 259 insertions(+), 50 deletions(-) diff --git a/main.py b/main.py index dc14393..61019b9 100644 --- a/main.py +++ b/main.py @@ -144,9 +144,6 @@ async def main() -> None: # Handle clean operation if args.clean: clean_data_directory(args.data) - # If no pipeline operation requested (only cleaning), exit - if not args.resd: - return # Run pipeline conf = MaskSqlConfig(args.data, args.resd, "full") diff --git a/src/pipe/add_schema.py b/src/pipe/add_schema.py index 827f25a..b722c06 100644 --- a/src/pipe/add_schema.py +++ b/src/pipe/add_schema.py @@ -6,6 +6,91 @@ from src.pipe.schema_repo import DatabaseSchema, DatabaseSchemaRepo +def _parse_schema_item(item: str) -> str | None: + """ + Parse and validate a schema item reference. + + Parameters + ---------- + item : str + Schema item reference (e.g., 'COLUMN:table.column'). + + Returns + ------- + str or None + Column reference if valid COLUMN item, None otherwise. + """ + if not isinstance(item, str) or ":" not in item: + return None + + parts = item.split(":", 1) + if len(parts) != 2: + return None + + item_type, item_ref = parts + + if "[*]" in item_ref or item_type != "COLUMN": + return None + + return item_ref + + +def _parse_column_ref(col_ref: str) -> tuple[str, str] | None: + """ + Parse a column reference into table and column names. + + Parameters + ---------- + col_ref : str + Column reference in format 'table.column'. + + Returns + ------- + tuple[str, str] or None + (table_name, col_name) if valid, None otherwise. + """ + if "." not in col_ref: + return None + + parts = col_ref.split(".", 1) + if len(parts) != 2: + return None + + return parts[0], parts[1] + + +def _get_foreign_key( + schema: DatabaseSchema, table_name: str, col_name: str +) -> str | None: + """ + Get foreign key reference for a column if it exists. + + Parameters + ---------- + schema : DatabaseSchema + The database schema. + table_name : str + Name of the table. + col_name : str + Name of the column. + + Returns + ------- + str or None + Foreign key reference if exists, None otherwise. + """ + if table_name not in schema.tables or col_name not in schema.tables[table_name]: + return None + + col_data = schema.tables[table_name][col_name] + if isinstance(col_data, dict) and "foreign_key" in col_data: + fk_ref = col_data["foreign_key"] + if isinstance(fk_ref, str): + return fk_ref + + return None + + def filter_schema(schema: DatabaseSchema, schema_items: list[str]) -> DatabaseSchema: """ Filter database schema to include only specified schema items. @@ -22,31 +107,38 @@ def filter_schema(schema: DatabaseSchema, schema_items: list[str]) -> DatabaseSc DatabaseSchema Filtered schema containing only the specified items and their foreign keys. """ + if not schema_items: + return DatabaseSchema() + + # Extract valid column references from schema items columns = set() for item in schema_items: - item_ref = item.split(":")[1] - if "[*]" in item_ref: - continue - if item.split(":")[0] == "COLUMN": - columns.add(item_ref) + col_ref = _parse_schema_item(item) + if col_ref: + columns.add(col_ref) + # Add foreign key references for col_ref in list(columns): - table_name = col_ref.split(".")[0] - col_name = col_ref.split(".")[1] - col_data = schema.tables[table_name][col_name] - if isinstance(col_data, dict) and "foreign_key" in col_data: - fk_ref = col_data["foreign_key"] - if isinstance(fk_ref, str): - columns.add(fk_ref) + parsed = _parse_column_ref(col_ref) + if not parsed: + continue + table_name, col_name = parsed + fk_ref = _get_foreign_key(schema, table_name, col_name) + if fk_ref: + columns.add(fk_ref) + + # Build filtered schema filtered_schema = DatabaseSchema() for table_name, table_columns in schema.tables.items(): - filtered_table_columns = {} - for col_name, col_data in table_columns.items(): - if f"{table_name}.{col_name}" in columns: - filtered_table_columns[col_name] = col_data - if len(filtered_table_columns) > 0: + filtered_table_columns = { + col_name: col_data + for col_name, col_data in table_columns.items() + if f"{table_name}.{col_name}" in columns + } + if filtered_table_columns: filtered_schema.tables[table_name] = filtered_table_columns + return filtered_schema diff --git a/src/pipe/llm_util.py b/src/pipe/llm_util.py index 7d2326d..a6a3827 100644 --- a/src/pipe/llm_util.py +++ b/src/pipe/llm_util.py @@ -103,6 +103,44 @@ async def send_prompt(prompt: str, model: str | None = None) -> tuple[str, str]: return content, usage +def _preprocess_json_string(text: str) -> str: + """ + Pre-process JSON string to fix common LLM formatting errors. + + Parameters + ---------- + text : str + Raw JSON text that may have formatting issues + + Returns + ------- + str + Cleaned JSON text + """ + # Strip whitespace + text = text.strip() + + # Fix common array termination issues like: ["item1", "item2".] + # Replace ".] with "] + text = re.sub(r'"\s*\.\s*\]', '"]', text) + + # Fix missing closing quotes before array end: ["item1", "item2] + # Find patterns like: "something] where ] should be "] + text = re.sub(r'([^"])\]', r'\1"]', text) + # But undo if we just added "" which would be wrong + text = text.replace('"""]', '"]') + + # Fix patterns like: "COLUMN:." or "TABLE:" (empty after colon) + # Remove items that are just "TYPE:." or "TYPE:" + text = re.sub(r'"\s*[A-Z]+\s*:\s*\.?\s*"', '""', text) + + # Remove empty strings from arrays + text = re.sub(r',\s*""\s*,', ",", text) # middle + text = re.sub(r'\[\s*""\s*,', "[", text) # start + text = re.sub(r',\s*""\s*\]', "]", text) # end + return re.sub(r'\[\s*""\s*\]', "[]", text) # only item + + def extract_json(text: str) -> dict[str, Any] | None: """ Extract JSON object from text with code blocks. @@ -120,13 +158,17 @@ def extract_json(text: str) -> dict[str, Any] | None: try: if "```json" in text: res = re.findall(r"```json([\s\S]*?)```", text) - json_res = json.loads(res[0]) + json_text = res[0] elif "```" in text: res = re.findall(r"```([\s\S]*?)```", text) - json_res = json.loads(res[0]) + json_text = res[0] else: - json_res = json.loads(text) - return json_res + json_text = text + + # Pre-process to fix common formatting errors + json_text = _preprocess_json_string(json_text) + + return json.loads(json_text) except Exception as e: logger.warning(f"Failed to extract json from: {text}, error={e}") return None @@ -171,6 +213,7 @@ def extract_object(text: str) -> Any | None: if obj is None: obj = eval_literal(text) if obj is None: - logger.error(f"Failed to extract object: {text}") + # Only log at debug level since callers typically handle None gracefully + logger.debug(f"Failed to extract object: {text}") obj = None return obj diff --git a/src/pipe/rank_schema_llm.py b/src/pipe/rank_schema_llm.py index c42088d..48e6662 100644 --- a/src/pipe/rank_schema_llm.py +++ b/src/pipe/rank_schema_llm.py @@ -6,6 +6,7 @@ from src.pipe.llm_util import extract_object from src.pipe.rank_schema_prompts.v1 import RANK_SCHEMA_ITEMS_V1 from src.pipe.schema_repo import DatabaseSchemaRepo +from src.utils.logging import logger class RankSchemaItems(PromptProcessor): @@ -26,8 +27,72 @@ def __init__(self, prop_name: str, tables_path: str, model: str) -> None: super().__init__(prop_name, model=model) self.schema_repo = DatabaseSchemaRepo(tables_path) - def _process_output(self, row: dict[str, Any], output: str) -> Any: - return extract_object(output) + def _sanitize_schema_item(self, item: str) -> str | None: + """ + Sanitize a schema item reference to ensure proper formatting. + + Parameters + ---------- + item : str + Schema item reference (e.g., "TABLE:[name]" or "COLUMN:[table].[col]") + + Returns + ------- + str or None + Sanitized schema item or None if invalid + """ + if not isinstance(item, str) or ":" not in item: + return None + + parts = item.split(":", 1) + if len(parts) != 2: + return None + + item_type, item_ref = parts + + # Skip empty references + if not item_ref or item_ref.strip() in ["", ".", "[.]"]: + return None + + # Ensure all opening brackets have closing brackets + bracket_count = item_ref.count("[") - item_ref.count("]") + if bracket_count > 0: + # Add missing closing brackets + item_ref = item_ref + ("]" * bracket_count) + elif bracket_count < 0: + # More closing than opening - invalid + return None + + return f"{item_type}:{item_ref}" + + def _process_output(self, row: dict[str, Any], output: str) -> list[str]: + result = extract_object(output) + + # Handle None or invalid output + if result is None or not isinstance(result, list): + logger.warning( + f"LLM returned invalid schema items for question_id={row.get('question_id')}, " + f"falling back to all schema items" + ) + # Fallback: return all schema items + return self.extract_schema_items(row) + + # Sanitize and filter out invalid items + sanitized_items = [] + for item in result: + sanitized = self._sanitize_schema_item(item) + if sanitized: + sanitized_items.append(sanitized) + + # If sanitization removed everything, fallback to all items + if not sanitized_items: + logger.warning( + f"All LLM schema items were invalid for question_id={row.get('question_id')}, " + f"falling back to all schema items" + ) + return self.extract_schema_items(row) + + return sanitized_items def extract_schema_items(self, row: dict[str, Any]) -> list[str]: """ diff --git a/src/pipe/rank_schema_prompts/v1.py b/src/pipe/rank_schema_prompts/v1.py index 9843298..2f24e70 100644 --- a/src/pipe/rank_schema_prompts/v1.py +++ b/src/pipe/rank_schema_prompts/v1.py @@ -1,40 +1,52 @@ """Schema ranking prompt template version 1.""" -RANK_SCHEMA_ITEMS_V1 = """ -You are given: - 1. A natural language question. - 2. A list of schema items of an underlying database. Each schema item is either - "TABLE:[table_name]" or "COLUMN:[table_name].[column_name] - -Task: -Filter the given list and return a subset of these items that are most relevant to the given question. -You can include at most 4 tables and at most 5 columns for each table. - -Example: -Question: “What is the name of the instructor who has the lowest salary?” -Schema Items: +RANK_SCHEMA_ITEMS_V1 = """You are a database schema analyzer. Your task is to identify which schema items are relevant for answering a given question. + +## Input Format +You will receive: +1. A natural language question +2. A list of schema items (tables and columns) from a database + +Each schema item follows this format: +- Tables: "TABLE:[table_name]" +- Columns: "COLUMN:[table_name].[column_name]" + +## Task +Select the schema items needed to answer the question. Choose: +- Maximum 4 tables +- Maximum 5 columns per table + +## Output Requirements +1. Return a valid JSON array of strings +2. Select items EXACTLY as they appear in the input list - do not modify them +3. Include only items that are relevant to answering the question +4. Ensure the output is valid JSON (properly quoted and bracketed) + +## Example + +Input Question: "What is the name of the instructor who has the lowest salary?" + +Input Schema Items: [ "TABLE:[department]", "COLUMN:[department].[name]", "TABLE:[instructor]", - "COLUMN:[instructor].[name]" - "COLUMN:[instructor].[salary]" + "COLUMN:[instructor].[name]", + "COLUMN:[instructor].[salary]", "COLUMN:[instructor].[age]" ] -Output: +Expected Output: [ "TABLE:[instructor]", - "COLUMN:[instructor].[name]" + "COLUMN:[instructor].[name]", "COLUMN:[instructor].[salary]" ] -Now filter the following list of Schema Items based on the given question. +## Your Turn + Question: {question} + Schema Items: {schema_items} -Instructions: -- Output only a valid list of strings. -- Do not include any additional text, explanations, or formatting.: -- All strings should be in double quotes. -""" +Output:""" From a4648170149d6c879adea1606cd3a91ea0ec62e5 Mon Sep 17 00:00:00 2001 From: Amrit Krishnan Date: Mon, 17 Nov 2025 09:02:23 -0500 Subject: [PATCH 2/2] Revert change to main.py --- main.py | 54 ++++++++++++++++++++++++++++++++++++------------------ 1 file changed, 36 insertions(+), 18 deletions(-) diff --git a/main.py b/main.py index 61019b9..9a44440 100644 --- a/main.py +++ b/main.py @@ -3,9 +3,10 @@ import argparse import asyncio import logging +import os from pathlib import Path -from config import MaskSqlConfig +from src.config import MaskSqlConfig from src.pipe.add_schema import AddFilteredSchema from src.pipe.add_symb_schema import AddSymbolicSchema from src.pipe.attack import AddInferenceAttack @@ -25,6 +26,7 @@ from src.pipe.repair_symb_sql import RepairSymbolicSQL from src.pipe.resdsql import AddResd from src.pipe.results import Results +from src.pipe.run_resdsql import RunResdsql from src.pipe.slm_sql import SlmSQL from src.pipe.symb_table import AddSymbolTable from src.pipe.unmask import AddConcreteSql @@ -95,10 +97,25 @@ def create_pipeline_stages(conf: MaskSqlConfig) -> list[JsonListProcessor]: List of pipeline stage objects to execute. """ if conf.resd: - rank_schema = [AddResd(conf.resd_path), RankSchemaResd(conf.tables_path)] + # Use RESDSQL for schema ranking + # RunResdsql will skip if output already exists (unless force=True) + device = os.environ.get("TORCH_DEVICE", "cpu") + rank_schema = [ + RunResdsql( + conf.tables_path, + conf.db_path, + conf.resd_path, + device=device, + ), + AddResd(conf.resd_path), + RankSchemaResd(conf.tables_path), + ] else: + # Use LLM-based schema ranking rank_schema = [ - RankSchemaItems("schema_items", conf.tables_path, model=conf.slm) + RankSchemaItems( + "schema_items", conf.tables_path, conf.openai, model=conf.slm + ) ] return [ LimitJson(), @@ -106,21 +123,21 @@ def create_pipeline_stages(conf: MaskSqlConfig) -> list[JsonListProcessor]: # ResdItemCount(), AddFilteredSchema(conf.tables_path), AddSymbolTable(conf.tables_path), - SlmSQL("slm_sql", model=conf.slm), - DetectValues("values", model=conf.slm), - LinkValues("value_links", model=conf.slm), + SlmSQL("slm_sql", conf.openai, model=conf.slm), + DetectValues("values", conf.openai, model=conf.slm), + LinkValues("value_links", conf.openai, model=conf.slm), CopyTransformer("value_links", "filtered_value_links"), - LinkSchema("schema_links", model=conf.slm), + LinkSchema("schema_links", conf.openai, model=conf.slm), CopyTransformer("schema_links", "filtered_schema_links"), AddSymbolicSchema(conf.tables_path), AddSymbolicQuestion(), - GenerateSymbolicSql("symbolic", model=conf.llm), - RepairSymbolicSQL("symbolic", model=conf.llm), + GenerateSymbolicSql("symbolic", conf.openai, model=conf.llm), + RepairSymbolicSQL("symbolic", conf.openai, model=conf.llm), AddConcreteSql(), ExecuteConcreteSql(conf.db_path), - RepairSQL("pred_sql", model=conf.slm), + RepairSQL("pred_sql", conf.openai, model=conf.slm), CalcExecAcc(conf.db_path, conf.policy), - AddInferenceAttack("attack", model=conf.llm), + AddInferenceAttack("attack", conf.openai, model=conf.llm), # PrintProps(['question', 'symbolic.question', 'attack']) Results(), ] @@ -129,24 +146,25 @@ def create_pipeline_stages(conf: MaskSqlConfig) -> list[JsonListProcessor]: async def main() -> None: """Run the MaskSQL main pipeline.""" parser = argparse.ArgumentParser(description="MaskSQL") - parser.add_argument( - "--data", type=str, required=False, help="Data directory", default="data" - ) - parser.add_argument("--resd", action="store_true", dest="resd", help="Use RESDSQL") parser.add_argument( "--clean", action="store_true", help="Clean intermediate files from data directory", ) + parser.add_argument( + "-c", "--config", default="configs/conf.yaml", help="Path to config file" + ) args = parser.parse_args() configure_logging() + # Load configuration + conf = MaskSqlConfig.from_yaml(args.config) + # Handle clean operation if args.clean: - clean_data_directory(args.data) + clean_data_directory(conf.data_dir) - # Run pipeline - conf = MaskSqlConfig(args.data, args.resd, "full") + # Create and run pipeline pipeline_stages = create_pipeline_stages(conf) pipeline = Pipeline(pipeline_stages) await pipeline.run(conf.input_path)