diff --git a/README_CN.md b/README_CN.md index 6770385..d947e8c 100644 --- a/README_CN.md +++ b/README_CN.md @@ -18,15 +18,23 @@ ### 用法 -#### 1. 输入用于授权的 databaseURI。目前支持 `mysql`、`postgresql`、`sqlite`、`sqlserver`、`oracle`,示例格式如下: +#### 1. 输入用于授权的 databaseURI。目前支持 `mysql`、`postgresql`、`sqlite`、`sqlserver`、`oracle`、`clickhouse`,示例格式如下: ```shell mysql+pymysql://root:123456@localhost:3306/test postgresql+psycopg2://postgres:123456@localhost:5432/test sqlite:///test.db mssql+pymssql://:@/?charset=utf8 oracle+oracledb://user:pass@hostname:port[/dbname][?service_name=[&key=value&key=value...]] +clickhouse://default:password@localhost:8123/default +clickhouse+connect://default:password@localhost:8123/mydatabase ``` +**ClickHouse/MyScale 特别说明:** +- 支持 ClickHouse 和 MyScale(基于 ClickHouse 的向量数据库) +- 使用 `clickhouse-connect` Python 驱动 +- 默认端口:8123(HTTP)、9000(Native) +- MyScale 是一个兼容 ClickHouse 的向量数据库,支持向量相似性搜索功能 + > **注意:** 此插件总是在 Docker 中运行,因此 `localhost` 始终指 Docker 内部网络,请尝试使用 `host.docker.internal` 代替。 #### 2. 使用 `SQL 执行` (SQL Execute) 工具从数据库查询数据。 @@ -55,8 +63,38 @@ URL 请求格式示例: curl -X POST 'https://daemon-plugin.dify.dev/o3wvwZfYFLU5iGopr5CxYmGaM5mWV7xf/sql' -H 'Content-Type: application/json' -d '{"query":"select * from test", "format": "md"}' ``` +#### 7. 如何打包plugin。 + +Step 1:install homebrew-dify +```shell +brew install langgenius/dify/dify +``` + +Step 2:install Dify cli +```shell +brew tap langgenius/dify +brew install dify +``` + +Step 3: 验证dify 安装 +```shell +dify --version +``` + +Step 4:package Dify database plugin +```shell +dify plugin package ./dify-plugin-database +``` + ### 更新日志 +#### 0.0.7 +1. 新增 ClickHouse/MyScale 数据库支持 +2. 使用 clickhouse-connect 0.10.0+ 驱动 +3. 支持所有现有的输出格式(JSON、CSV、YAML、XLSX、HTML、Markdown) +4. 支持表结构获取,包括 ClickHouse 特有的引擎、排序键、分区键等信息 +5. 完全兼容 MyScale 向量数据库 + #### 0.0.6 1. 支持在 `get table schema` 工具中获取更多信息,例如表和字段的注释、外键关联索引等 2. 
support special `schema` of `get table schema` tool @@ -86,4 +124,4 @@ curl -X POST 'https://daemon-plugin.dify.dev/o3wvwZfYFLU5iGopr5CxYmGaM5mWV7xf/sq ### 加群 -![1](_assets/contact.jpg) \ No newline at end of file +![1](_assets/contact.jpg) diff --git a/manifest.yaml b/manifest.yaml index e3f20e1..87e75fc 100644 --- a/manifest.yaml +++ b/manifest.yaml @@ -1,4 +1,4 @@ -version: 0.0.6 +version: 0.0.7 type: plugin author: hjlarry name: database diff --git a/provider/database.py b/provider/database.py index 713db44..f3882d0 100644 --- a/provider/database.py +++ b/provider/database.py @@ -3,17 +3,43 @@ from dify_plugin import ToolProvider from dify_plugin.errors.tool import ToolProviderCredentialValidationError from tools.sql_execute import SQLExecuteTool +from tools.db_utils import is_clickhouse_uri, parse_clickhouse_uri class DatabaseProvider(ToolProvider): def _validate_credentials(self, credentials: dict[str, Any]) -> None: if not credentials.get("db_uri"): return - query = "SELECT 1 FROM DUAL" if "oracle" in credentials.get("db_uri") else "SELECT 1" - try: - for _ in SQLExecuteTool.from_credentials(credentials).invoke( - tool_parameters={"query": query} - ): - pass - except Exception as e: - raise ToolProviderCredentialValidationError(str(e)) + + db_uri = credentials.get("db_uri") + + # 对于 ClickHouse/MyScale,使用原生验证 + if is_clickhouse_uri(db_uri): + try: + import clickhouse_connect + + config = parse_clickhouse_uri(db_uri) + + # 对于 8443 端口,自动添加 SSL 支持 + if config.get('port') == 8443: + config['secure'] = True + + # 尝试建立连接并执行测试查询 + client = clickhouse_connect.get_client(**config) + client.command("SELECT 1") + client.close() + + except ImportError: + raise ToolProviderCredentialValidationError("ClickHouse driver (clickhouse-connect) is not installed") + except Exception as e: + raise ToolProviderCredentialValidationError(f"ClickHouse connection failed: {str(e)}") + else: + # 对于其他数据库,使用原有的验证逻辑 + query = "SELECT 1 FROM DUAL" if "oracle" in db_uri else "SELECT 1" + try: + for _ in SQLExecuteTool.from_credentials(credentials).invoke( + tool_parameters={"query": query} + ): + pass + except Exception as e: + raise ToolProviderCredentialValidationError(str(e)) diff --git a/provider/database.yaml b/provider/database.yaml index 8f2260c..07fe22c 100644 --- a/provider/database.yaml +++ b/provider/database.yaml @@ -18,13 +18,13 @@ extra: credentials_for_provider: db_uri: help: - en_US: For example `mysql+pymysql://:@:/` - zh_Hans: 例如 `mysql+pymysql://:@:/` + en_US: For example `mysql+pymysql://:@:/`, `clickhouse://:@:/`, `myscale://:@:/` + zh_Hans: 例如 `mysql+pymysql://:@:/`、`clickhouse://:@:/`、`myscale://:@:/` label: en_US: Database URI zh_Hans: 数据库 URI placeholder: - en_US: Please enter the database URI - zh_Hans: 请输入数据库 URI + en_US: Please enter the database URI (supports MySQL, PostgreSQL, SQLite, SQL Server, Oracle, ClickHouse, MyScale) + zh_Hans: 请输入数据库 URI(支持 MySQL、PostgreSQL、SQLite、SQL Server、Oracle、ClickHouse、MyScale) required: false type: secret-input diff --git a/requirements.txt b/requirements.txt index 2b668fe..c9c3382 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,10 @@ dify_plugin>=0.1.0,<0.2.0 -records[all]~=0.6.0 +sqlalchemy~=2.0.0 pymysql~=1.1.1 pymssql~=2.3.2 oracledb~=2.2.1 psycopg2-binary~=2.9.10 cryptography~=44.0.2 -pandas~=2.2.3 \ No newline at end of file +pandas~=2.2.3 +clickhouse-connect>=0.10.0 +tabulate>=0.9.0 \ No newline at end of file diff --git a/tools/db_utils.py b/tools/db_utils.py index 89e5d62..9e5fc1f 100644 --- a/tools/db_utils.py +++ 
b/tools/db_utils.py @@ -65,4 +65,109 @@ def fix_db_uri_encoding(db_uri: str) -> str: except Exception: # 如果解析失败,返回原始 URI - return db_uri \ No newline at end of file + return db_uri + + +def is_clickhouse_uri(db_uri: str) -> bool: + """ + 检查是否为 ClickHouse/MyScale 数据库 URI + + Args: + db_uri: 数据库 URI + + Returns: + bool: 是否为 ClickHouse URI + """ + return db_uri.startswith(('clickhouse://', 'clickhouse+connect://', 'myscale://', 'myscale+connect://')) + + +def parse_clickhouse_uri(db_uri: str) -> dict: + """ + 解析 ClickHouse/MyScale 连接字符串 + + Args: + db_uri: ClickHouse/MyScale URI,格式如: + clickhouse://user:password@host:port/database + clickhouse+connect://user:password@host:port/database + myscale://user:password@host:port/database + myscale+connect://user:password@host:port/database + + Returns: + dict: 包含连接参数的字典 + """ + try: + # 移除协议前缀 + if db_uri.startswith('clickhouse+connect://'): + db_uri = db_uri.replace('clickhouse+connect://', '') + elif db_uri.startswith('clickhouse://'): + db_uri = db_uri.replace('clickhouse://', '') + elif db_uri.startswith('myscale+connect://'): + db_uri = db_uri.replace('myscale+connect://', '') + elif db_uri.startswith('myscale://'): + db_uri = db_uri.replace('myscale://', '') + + # 解析用户名、密码、主机、端口、数据库 + if '@' in db_uri: + auth_part, host_part = db_uri.split('@', 1) + + # 解析用户名和密码 + if ':' in auth_part: + username, password = auth_part.split(':', 1) + else: + username, password = auth_part, '' + + # 解析主机、端口和数据库 + if '/' in host_part: + host_db_part = host_part.split('/', 1) + host_port = host_db_part[0] + database = host_db_part[1] if len(host_db_part) > 1 else 'default' + else: + host_port = host_part + database = 'default' + + # 解析主机和端口 + if ':' in host_port: + host, port = host_port.split(':', 1) + else: + host = host_port + port = 8123 # ClickHouse 默认端口 + + return { + 'host': host, + 'port': int(port), + 'username': username, + 'password': password, + 'database': database + } + else: + # 没有认证信息的情况 + if '/' in db_uri: + host_db_part = db_uri.split('/', 1) + host_port = host_db_part[0] + database = host_db_part[1] if len(host_db_part) > 1 else 'default' + else: + host_port = db_uri + database = 'default' + + if ':' in host_port: + host, port = host_port.split(':', 1) + else: + host = host_port + port = 8123 + + return { + 'host': host, + 'port': int(port), + 'username': 'default', + 'password': '', + 'database': database + } + except Exception: + # 解析失败时返回默认配置 + return { + 'host': 'localhost', + 'port': 8123, + 'username': 'default', + 'password': '', + 'database': 'default' + } \ No newline at end of file diff --git a/tools/sql_execute.py b/tools/sql_execute.py index 770f2e1..d7119e9 100644 --- a/tools/sql_execute.py +++ b/tools/sql_execute.py @@ -3,11 +3,10 @@ import re import json -import records -from sqlalchemy import text +from sqlalchemy import create_engine, text from dify_plugin import Tool from dify_plugin.entities.tool import ToolInvokeMessage -from tools.db_utils import fix_db_uri_encoding +from tools.db_utils import fix_db_uri_encoding, is_clickhouse_uri, parse_clickhouse_uri class SQLExecuteTool(Tool): @@ -15,61 +14,267 @@ def _invoke(self, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessag db_uri = tool_parameters.get("db_uri") or self.runtime.credentials.get("db_uri") if not db_uri: raise ValueError("Database URI is not provided.") - - # 修复 db_uri 中的特殊字符编码问题 - db_uri = fix_db_uri_encoding(db_uri) - + query = tool_parameters.get("query").strip() format = tool_parameters.get("format", "json") config_options = 
tool_parameters.get("config_options") or "{}" - try: - config_options = json.loads(config_options) - except json.JSONDecodeError: - raise ValueError("Invalid JSON format for Connect Config") - db = records.Database(db_uri, **config_options) + # 检查是否为 ClickHouse/MyScale 数据库 + if is_clickhouse_uri(db_uri): + # 处理 ClickHouse/MyScale 连接 + config = parse_clickhouse_uri(db_uri) + + # 解析额外的配置选项 + try: + extra_options = json.loads(config_options) + config.update(extra_options) + except json.JSONDecodeError: + raise ValueError("Invalid JSON format for Connect Config") + + return self._handle_clickhouse_query(config, query, format) + else: + # 处理其他数据库类型(原有逻辑) + db_uri = fix_db_uri_encoding(db_uri) + + try: + config_options = json.loads(config_options) + except json.JSONDecodeError: + raise ValueError("Invalid JSON format for Connect Config") + engine = create_engine(db_uri, **config_options) + + return self._handle_sqlalchemy_query(engine, query, format) + + def _handle_clickhouse_query(self, config: dict, query: str, format: str) -> Generator[ToolInvokeMessage]: + """处理 ClickHouse/MyScale 查询""" try: - if re.match(r'^\s*(SELECT|WITH)\s+', query, re.IGNORECASE): - rows = db.query(query) - if format == "json": - result = rows.as_dict() - yield self.create_json_message({"result": result}) - elif format == "md": - result = str(rows.dataset) - yield self.create_text_message(result) - elif format == "csv": - result = rows.export("csv").encode() - yield self.create_blob_message( - result, meta={"mime_type": "text/csv", "filename": "result.csv"} - ) - elif format == "yaml": - result = rows.export("yaml").encode() - yield self.create_blob_message( - result, - meta={"mime_type": "text/yaml", "filename": "result.yaml"}, - ) - elif format == "xlsx": - result = rows.export("xlsx") - yield self.create_blob_message( - result, - meta={ - "mime_type": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", - "filename": "result.xlsx", - }, - ) - elif format == "html": - result = rows.export("html").encode() - yield self.create_blob_message( - result, - meta={"mime_type": "text/html", "filename": "result.html"}, + import clickhouse_connect + + # 建立 ClickHouse 连接 + client = clickhouse_connect.get_client(**config) + + try: + # 检查查询类型 + if re.match(r'^\s*(SELECT|WITH|SHOW|DESCRIBE|EXISTS)\s+', query, re.IGNORECASE): + # 查询类语句 + result = client.query(query) + + if format == "json": + # 转换为 JSON 格式 + data = [] + for row in result.result_rows: + row_dict = {} + for i, column_name in enumerate(result.column_names): + value = row[i] + # 处理日期和其他不可序列化的类型 + if hasattr(value, 'isoformat'): # datetime/date 对象 + value = value.isoformat() + elif hasattr(value, '__str__'): # 其他对象 + value = str(value) + row_dict[column_name] = value + data.append(row_dict) + yield self.create_json_message({"result": data}) + + elif format == "md": + # 生成 Markdown 表格 + if result.column_names and result.result_rows: + # 表头 + header = "| " + " | ".join(result.column_names) + " |" + separator = "| " + " | ".join(["---"] * len(result.column_names)) + " |" + + # 数据行 + rows = [] + for row in result.result_rows: + row_str = "| " + " | ".join(str(cell) if cell is not None else "NULL" for cell in row) + " |" + rows.append(row_str) + + markdown_table = "\n".join([header, separator] + rows) + yield self.create_text_message(markdown_table) + else: + yield self.create_text_message("Query returned no results") + + elif format == "csv": + # 生成 CSV 格式 + import io + import csv + + output = io.StringIO() + writer = csv.writer(output) + + # 写入表头 + if 
result.column_names:
+                            writer.writerow(result.column_names)
+
+                        # 写入数据
+                        for row in result.result_rows:
+                            writer.writerow(row)
+
+                        csv_data = output.getvalue()
+                        yield self.create_blob_message(
+                            csv_data.encode('utf-8'),
+                            meta={"mime_type": "text/csv", "filename": "result.csv"}
+                        )
+
+                    elif format == "yaml":
+                        # 生成 YAML 格式
+                        import yaml
+
+                        data = []
+                        for row in result.result_rows:
+                            row_dict = {}
+                            for i, column_name in enumerate(result.column_names):
+                                row_dict[column_name] = row[i]
+                            data.append(row_dict)
+
+                        yaml_data = yaml.dump({"result": data}, default_flow_style=False, allow_unicode=True)
+                        yield self.create_blob_message(
+                            yaml_data.encode('utf-8'),
+                            meta={"mime_type": "text/yaml", "filename": "result.yaml"}
+                        )
+
+                    elif format == "xlsx":
+                        # 生成 Excel 格式
+                        import pandas as pd
+
+                        # 创建 DataFrame
+                        df = pd.DataFrame(result.result_rows, columns=result.column_names)
+
+                        # 保存为 Excel
+                        import io
+                        output = io.BytesIO()
+                        with pd.ExcelWriter(output, engine='openpyxl') as writer:
+                            df.to_excel(writer, index=False, sheet_name='Results')
+
+                        yield self.create_blob_message(
+                            output.getvalue(),
+                            meta={
+                                "mime_type": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
+                                "filename": "result.xlsx",
+                            },
+                        )
+
+                    elif format == "html":
+                        # 生成 HTML 格式
+                        html_data = "<table>\n"
+
+                        # 表头
+                        if result.column_names:
+                            html_data += "<tr>"
+                            for col in result.column_names:
+                                html_data += f"<th>{col}</th>"
+                            html_data += "</tr>\n"
+
+                        # 数据
+                        html_data += "<tbody>\n"
+                        for row in result.result_rows:
+                            html_data += "<tr>"
+                            for cell in row:
+                                html_data += f"<td>{cell if cell is not None else 'NULL'}</td>"
+                            html_data += "</tr>\n"
+                        html_data += "</tbody>\n</table>"
+
+                        yield self.create_blob_message(
+                            html_data.encode('utf-8'),
+                            meta={"mime_type": "text/html", "filename": "result.html"},
+                        )
+
+                    else:
+                        raise ValueError(f"Unsupported format: {format}")
+
+                else:
+                    # 非查询语句(INSERT, UPDATE, DELETE, CREATE, DROP, etc.)
+                    command_result = client.command(query)
+                    affected_rows = command_result if isinstance(command_result, int) else 0
+                    yield self.create_text_message(
+                        f"Query executed successfully. Affected rows: {affected_rows}"
+                    )
+
+            finally:
+                client.close()
+
+        except ImportError:
+            raise ValueError("ClickHouse driver (clickhouse-connect) is not installed. Please add it to requirements.txt")
+        except Exception as e:
+            yield self.create_text_message(f"Error: {str(e)}")
+
+    def _handle_sqlalchemy_query(self, engine, query: str, format: str) -> Generator[ToolInvokeMessage]:
+        """处理 SQLAlchemy 支持的数据库查询"""
+        try:
+            with engine.connect() as conn:
+                if re.match(r'^\s*(SELECT|WITH)\s+', query, re.IGNORECASE):
+                    # 查询语句
+                    result = conn.execute(text(query))
+                    rows = result.fetchall()
+                    columns = result.keys()
+
+                    if format == "json":
+                        result_data = [dict(zip(columns, row)) for row in rows]
+                        yield self.create_json_message({"result": result_data})
+                    elif format == "md":
+                        from tabulate import tabulate
+                        if rows:
+                            table = tabulate(rows, headers=columns, tablefmt="pipe")
+                            yield self.create_text_message(table)
+                        else:
+                            yield self.create_text_message("Query returned no results")
+                    elif format == "csv":
+                        import io
+                        import csv
+                        output = io.StringIO()
+                        writer = csv.writer(output)
+                        writer.writerow(columns)
+                        writer.writerows(rows)
+                        csv_data = output.getvalue()
+                        yield self.create_blob_message(
+                            csv_data.encode('utf-8'),
+                            meta={"mime_type": "text/csv", "filename": "result.csv"}
+                        )
+                    elif format == "yaml":
+                        import yaml
+                        result_data = [dict(zip(columns, row)) for row in rows]
+                        yaml_data = yaml.dump({"result": result_data}, default_flow_style=False, allow_unicode=True)
+                        yield self.create_blob_message(
+                            yaml_data.encode('utf-8'),
+                            meta={"mime_type": "text/yaml", "filename": "result.yaml"},
+                        )
+                    elif format == "xlsx":
+                        import pandas as pd
+                        df = pd.DataFrame(rows, columns=columns)
+                        import io
+                        output = io.BytesIO()
+                        with pd.ExcelWriter(output, engine='openpyxl') as writer:
+                            df.to_excel(writer, index=False, sheet_name='Results')
+                        yield self.create_blob_message(
+                            output.getvalue(),
+                            meta={
+                                "mime_type": "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
+                                "filename": "result.xlsx",
+                            },
+                        )
+                    elif format == "html":
+                        html_data = "<table>\n"
+                        if columns:
+                            html_data += "<tr>"
+                            for col in columns:
+                                html_data += f"<th>{col}</th>"
+                            html_data += "</tr>\n"
+                        html_data += "<tbody>\n"
+                        for row in rows:
+                            html_data += "<tr>"
+                            for cell in row:
+                                html_data += f"<td>{cell if cell is not None else 'NULL'}</td>"
+                            html_data += "</tr>\n"
+                        html_data += "</tbody>\n</table>"
+                        yield self.create_blob_message(
+                            html_data.encode('utf-8'),
+                            meta={"mime_type": "text/html", "filename": "result.html"},
+                        )
+                    else:
+                        raise ValueError(f"Unsupported format: {format}")
                 else:
-                    raise ValueError(f"Unsupported format: {format}")
-            else:
-                with db.get_connection() as conn:
-                    trans = conn._conn.begin()
+                    # 非查询语句
+                    trans = conn.begin()
                     try:
-                        result = conn._conn.execute(text(query))
+                        result = conn.execute(text(query))
                         affected_rows = result.rowcount
                         trans.commit()
                         yield self.create_text_message(
@@ -79,4 +284,4 @@ def _invoke(self, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessag
                         trans.rollback()
                         yield self.create_text_message(f"Error: {str(e)}")
         finally:
-            db.close()
+            engine.dispose()
diff --git a/tools/table_schema.py b/tools/table_schema.py
index edada6a..7fa5f54 100644
--- a/tools/table_schema.py
+++ b/tools/table_schema.py
@@ -5,7 +5,7 @@
 from sqlalchemy import create_engine, inspect
 from dify_plugin import Tool
 from dify_plugin.entities.tool import ToolInvokeMessage
-from tools.db_utils import fix_db_uri_encoding
+from tools.db_utils import fix_db_uri_encoding, is_clickhouse_uri, parse_clickhouse_uri
 
 
 class QueryTool(Tool):
@@ -13,23 +13,137 @@ def _invoke(self, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessag
         db_uri = tool_parameters.get("db_uri") or self.runtime.credentials.get("db_uri")
         if not db_uri:
             raise ValueError("Database URI is not provided.")
-
-        # 修复 db_uri 中的特殊字符编码问题
-        db_uri = fix_db_uri_encoding(db_uri)
-
-        config_options = tool_parameters.get("config_options") or "{}"
-        try:
-            config_options = json.loads(config_options)
-        except json.JSONDecodeError:
-            raise ValueError("Invalid JSON format for Connect Config")
-        engine = create_engine(db_uri, **config_options)
-        inspector = inspect(engine)
 
         tables = tool_parameters.get("tables")
         schema = tool_parameters.get("schema")
         if not schema:  # sometimes the schema is empty string, it must be None
             schema = None
+
+        # 检查是否为 ClickHouse/MyScale 数据库
+        if is_clickhouse_uri(db_uri):
+            # 处理 ClickHouse/MyScale
+            config = parse_clickhouse_uri(db_uri)
+
+            config_options = tool_parameters.get("config_options") or "{}"
+            try:
+                extra_options = json.loads(config_options)
+                config.update(extra_options)
+            except json.JSONDecodeError:
+                raise ValueError("Invalid JSON format for Connect Config")
+
+            return self._get_clickhouse_schema(config, tables, schema)
+        else:
+            # 处理其他数据库类型(原有逻辑)
+            db_uri = fix_db_uri_encoding(db_uri)
+
+            config_options = tool_parameters.get("config_options") or "{}"
+            try:
+                config_options = json.loads(config_options)
+            except json.JSONDecodeError:
+                raise ValueError("Invalid JSON format for Connect Config")
+            engine = create_engine(db_uri, **config_options)
+            inspector = inspect(engine)
+
+            return self._get_sqlalchemy_schema(inspector, engine, tables, schema)
+
+    def _get_clickhouse_schema(self, config: dict, tables: str, schema: str) -> Generator[ToolInvokeMessage]:
+        """获取 ClickHouse/MyScale 表结构"""
+        try:
+            import clickhouse_connect
+
+            client = clickhouse_connect.get_client(**config)
+
+            try:
+                # 获取所有表名或使用指定的表名
+                if not tables:
+                    # 查询所有表名
+                    tables_query = "SELECT name FROM system.tables WHERE database = currentDatabase()"
+                    result = client.query(tables_query)
+                    tables_list = [row[0] for row in result.result_rows]
+                else:
+                    tables_list = [t.strip() for t in tables.split(",")]
+
+                schema_info = {}
+
+                for table_name in tables_list:
+                    try:
+                        # 获取表结构信息
+                        columns_query = f"""
+                        SELECT
+                            name,
+                            type,
+                            is_in_primary_key,
+                            is_in_sorting_key,
+                            is_in_partition_key,
+                            comment
+                        
FROM system.columns + WHERE database = currentDatabase() AND table = '{table_name}' + ORDER BY position + """ + + columns_result = client.query(columns_query) + + table_info = { + "table_name": table_name, + "columns": [], + "primary_keys": [], + "sorting_keys": [], + "partition_keys": [], + "engine": "", + "comment": "" + } + + # 获取表的引擎和注释 + try: + table_query = f""" + SELECT engine, comment + FROM system.tables + WHERE database = currentDatabase() AND name = '{table_name}' + """ + table_result = client.query(table_query) + if table_result.result_rows: + row = table_result.result_rows[0] + table_info["engine"] = row[0] + table_info["comment"] = row[1] or "" + except Exception: + pass + + # 处理列信息 + for row in columns_result.result_rows: + column_info = { + "name": row[0], + "type": row[1], + "nullable": "Nullable" in str(row[1]), + "default": None, + "comment": row[5] or "" + } + table_info["columns"].append(column_info) + + if row[2]: # is_in_primary_key + table_info["primary_keys"].append(row[0]) + if row[3]: # is_in_sorting_key + table_info["sorting_keys"].append(row[0]) + if row[4]: # is_in_partition_key + table_info["partition_keys"].append(row[0]) + + schema_info[table_name] = table_info + + except Exception as e: + schema_info[table_name] = f"Error getting table schema: {str(e)}" + + yield self.create_text_message(json.dumps(schema_info, ensure_ascii=False)) + + finally: + client.close() + + except ImportError: + raise ValueError("ClickHouse driver (clickhouse-connect) is not installed. Please add it to requirements.txt") + except Exception as e: + yield self.create_text_message(f"Error: {str(e)}") + + def _get_sqlalchemy_schema(self, inspector, engine, tables: str, schema: str) -> Generator[ToolInvokeMessage]: + """获取 SQLAlchemy 支持的数据库表结构""" tables = tables.split(",") if tables else inspector.get_table_names(schema=schema) schema_info = {} @@ -43,13 +157,13 @@ def _invoke(self, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessag "foreign_keys": [], "indexes": [] } - + # Get table comment try: table_info["comment"] = inspector.get_table_comment(table_name, schema=schema).get('text', '') except NotImplementedError: table_info["comment"] = "" - + # Get foreign keys try: for fk in inspector.get_foreign_keys(table_name, schema=schema): @@ -60,7 +174,7 @@ def _invoke(self, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessag }) except NotImplementedError: pass - + # Get indexes try: for idx in inspector.get_indexes(table_name, schema=schema): @@ -71,7 +185,7 @@ def _invoke(self, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessag }) except NotImplementedError: pass - + # Get columns try: columns = inspector.get_columns(table_name, schema=schema) @@ -85,7 +199,7 @@ def _invoke(self, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessag } for col in columns ] - + schema_info[table_name] = table_info except Exception as e: schema_info[table_name] = f"Error getting schema: {str(e)}" diff --git a/tools/text2sql.py b/tools/text2sql.py index 22f099f..510cc15 100644 --- a/tools/text2sql.py +++ b/tools/text2sql.py @@ -8,7 +8,7 @@ from dify_plugin import Tool from dify_plugin.entities.tool import ToolInvokeMessage from dify_plugin.entities.model.message import SystemPromptMessage, UserPromptMessage -from tools.db_utils import fix_db_uri_encoding +from tools.db_utils import fix_db_uri_encoding, is_clickhouse_uri, parse_clickhouse_uri SYSTEM_PROMPT_TEMPLATE = """ You are a {dialect} expert. 
Your task is to generate an executable {dialect} query based on the user's question. @@ -66,40 +66,53 @@ def _invoke(self, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessag db_uri = tool_parameters.get("db_uri") or self.runtime.credentials.get("db_uri") if not db_uri: raise ValueError("Database URI is not provided.") - - # 修复 db_uri 中的特殊字符编码问题 - db_uri = fix_db_uri_encoding(db_uri) - - config_options = tool_parameters.get("config_options") or "{}" - try: - config_options = json.loads(config_options) - except json.JSONDecodeError: - raise ValueError("Invalid JSON format for Connect Config") - engine = create_engine(db_uri, **config_options) - inspector = inspect(engine) - dialect = engine.dialect.name - - tables = tool_parameters.get("tables") - tables = tables.split(",") if tables else inspector.get_table_names() - schema_info = {} - with engine.connect() as _: - - for table_name in tables: - try: - columns = inspector.get_columns(table_name) - schema_info[table_name] = [ - { - "name": col["name"], - "type": str(col["type"]), - "nullable": col.get("nullable", True), - "default": col.get("default"), - "primary_key": col.get("primary_key", False), - } - for col in columns - ] - except Exception as e: - schema_info[table_name] = f"Error getting schema: {str(e)}" + # 检查是否为 ClickHouse/MyScale 数据库 + if is_clickhouse_uri(db_uri): + # 处理 ClickHouse/MyScale + config = parse_clickhouse_uri(db_uri) + config_options = tool_parameters.get("config_options") or "{}" + try: + extra_options = json.loads(config_options) + config.update(extra_options) + except json.JSONDecodeError: + raise ValueError("Invalid JSON format for Connect Config") + + dialect = "clickhouse" # 设置方言为 clickhouse + schema_info = self._get_clickhouse_schema(config) + else: + # 处理其他数据库类型(原有逻辑) + db_uri = fix_db_uri_encoding(db_uri) + + config_options = tool_parameters.get("config_options") or "{}" + try: + config_options = json.loads(config_options) + except json.JSONDecodeError: + raise ValueError("Invalid JSON format for Connect Config") + engine = create_engine(db_uri, **config_options) + inspector = inspect(engine) + dialect = engine.dialect.name + + tables = tool_parameters.get("tables") + tables = tables.split(",") if tables else inspector.get_table_names() + + schema_info = {} + with engine.connect() as _: + for table_name in tables: + try: + columns = inspector.get_columns(table_name) + schema_info[table_name] = [ + { + "name": col["name"], + "type": str(col["type"]), + "nullable": col.get("nullable", True), + "default": col.get("default"), + "primary_key": col.get("primary_key", False), + } + for col in columns + ] + except Exception as e: + schema_info[table_name] = f"Error getting schema: {str(e)}" prompt_messages = [ SystemPromptMessage(content=SYSTEM_PROMPT_TEMPLATE.format(dialect=dialect)), UserPromptMessage( @@ -115,3 +128,74 @@ def _invoke(self, tool_parameters: dict[str, Any]) -> Generator[ToolInvokeMessag stream=False, ) yield self.create_text_message(response.message.content) + + def _get_clickhouse_schema(self, config: dict) -> dict: + """获取 ClickHouse/MyScale 表结构信息""" + schema_info = {} + try: + import clickhouse_connect + + client = clickhouse_connect.get_client(**config) + + try: + # 获取所有表名 + tables_query = "SELECT name FROM system.tables WHERE database = currentDatabase()" + result = client.query(tables_query) + tables_list = [row[0] for row in result.result_rows] + + for table_name in tables_list: + try: + # 获取表结构信息 + columns_query = f""" + SELECT + name, + type, + is_in_primary_key, + 
is_in_sorting_key, + comment + FROM system.columns + WHERE database = currentDatabase() AND table = '{table_name}' + ORDER BY position + """ + columns_result = client.query(columns_query) + + schema_info[table_name] = [] + for row in columns_result.result_rows: + schema_info[table_name].append({ + "name": row[0], + "type": str(row[1]), + "nullable": "Nullable" in str(row[1]), + "default": None, + "primary_key": row[2] if len(row) > 2 else False, + }) + + except Exception as e: + schema_info[table_name] = f"Error getting table schema: {str(e)}" + + finally: + client.close() + + except ImportError: + raise ValueError("ClickHouse driver (clickhouse-connect) is not installed") + except Exception as e: + # 如果连接失败,返回一个示例表结构 + return { + "example_table": [ + { + "name": "id", + "type": "UInt64", + "nullable": False, + "default": None, + "primary_key": True, + }, + { + "name": "name", + "type": "String", + "nullable": True, + "default": None, + "primary_key": False, + } + ] + } + + return schema_info
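
---

For reviewers, a minimal usage sketch (not part of the patch) of the URI helpers this change adds in `tools/db_utils.py`. It assumes it is run from the plugin repository root so the `tools` package is importable; the URIs are the placeholder formats from the README.

```python
# Hypothetical quick check of the new helpers; run from the plugin repo root.
from tools.db_utils import is_clickhouse_uri, parse_clickhouse_uri

uri = "clickhouse://default:password@localhost:8123/default"
assert is_clickhouse_uri(uri)  # also true for clickhouse+connect:// and myscale:// URIs

config = parse_clickhouse_uri(uri)
# Expected: {'host': 'localhost', 'port': 8123, 'username': 'default',
#            'password': 'password', 'database': 'default'}
print(config)

# Port and database fall back to ClickHouse defaults (8123 / default) when omitted.
print(parse_clickhouse_uri("myscale://user:pass@myscale.example.com"))
```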
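Likewise, a hedged standalone sketch of the validation flow that `provider/database.py` now performs for ClickHouse/MyScale: build a `clickhouse-connect` client from the parsed parameters and run a trivial query. Host, port, and credentials below are placeholders, and the example assumes a reachable server.

```python
import clickhouse_connect

# Placeholder connection parameters, in the same shape parse_clickhouse_uri returns.
config = {
    "host": "localhost",
    "port": 8123,
    "username": "default",
    "password": "",
    "database": "default",
}

# Mirror the provider's behaviour: port 8443 is treated as an HTTPS (secure) endpoint.
if config["port"] == 8443:
    config["secure"] = True

client = clickhouse_connect.get_client(**config)
try:
    # Any exception raised here surfaces as a credential-validation error in the plugin.
    print(client.command("SELECT 1"))  # expected to print 1
finally:
    client.close()
```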