Skip to content

Commit fdb4b7f

Browse files
authored
Merge pull request #205 from aws-samples/204-support-clickhouse-as-data-source
support clickhouse db as datasource
2 parents bd9b8bf + f752da4 commit fdb4b7f

File tree

8 files changed

+31
-40
lines changed

8 files changed

+31
-40
lines changed

application/api/enum.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ class ErrorEnum(Enum):
88
NOT_SUPPORTED = {1001: "Your query statement is currently not supported by the system"}
99
INVAILD_BEDROCK_MODEL_ID = {1002: f"Invalid bedrock model id.Vaild ids:{BEDROCK_MODEL_IDS}"}
1010
INVAILD_SESSION_ID = {1003: f"Invalid session id."}
11+
PROFILE_NOT_FOUND = {1004: "Profile name not found."}
1112
UNKNOWN_ERROR = {9999: "Unknown error."}
1213

1314
def get_code(self):

application/api/service.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,8 @@ def ask(question: Question) -> Answer:
142142
log_info = ""
143143

144144
all_profiles = ProfileManagement.get_all_profiles_with_info()
145+
if selected_profile not in all_profiles:
146+
raise BizException(ErrorEnum.PROFILE_NOT_FOUND)
145147
database_profile = all_profiles[selected_profile]
146148

147149
current_nlq_chain = NLQChain(selected_profile)

application/nlq/data_access/database.py

Lines changed: 15 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@ class RelationDatabase():
1313
'mysql': 'mysql+pymysql',
1414
'postgresql': 'postgresql+psycopg2',
1515
'redshift': 'postgresql+psycopg2',
16-
'starrocks': 'starrocks'
16+
'starrocks': 'starrocks',
17+
'clickhouse': 'clickhouse',
1718
# Add more mappings here for other databases
1819
}
1920

@@ -42,43 +43,20 @@ def test_connection(cls, db_type, user, password, host, port, db_name) -> bool:
4243

4344
@classmethod
4445
def get_all_schema_names_by_connection(cls, connection: ConnectConfigEntity):
45-
schemas = []
46-
if connection.db_type == 'postgresql':
47-
db_url = cls.get_db_url(connection.db_type, connection.db_user, connection.db_pwd, connection.db_host,
48-
connection.db_port, connection.db_name)
49-
engine = db.create_engine(db_url)
50-
# with engine.connect() as conn:
51-
# query = text("""
52-
# SELECT nspname AS schema_name
53-
# FROM pg_catalog.pg_namespace
54-
# WHERE nspname !~ '^pg_' AND nspname <> 'information_schema' AND nspname <> 'public'
55-
# AND has_schema_privilege(nspname, 'USAGE');
56-
# """)
57-
#
58-
# # Executing the query
59-
# result = conn.execute(query)
60-
# schemas = [row['schema_name'] for row in result.mappings()]
61-
# print(schemas)
62-
inspector = sqlalchemy.inspect(engine)
63-
schemas = inspector.get_schema_names()
64-
elif connection.db_type == 'redshift':
65-
db_url = cls.get_db_url(connection.db_type, connection.db_user, connection.db_pwd, connection.db_host,
66-
connection.db_port, connection.db_name)
67-
engine = db.create_engine(db_url)
68-
inspector = inspect(engine)
46+
db_type = connection.db_type
47+
db_url = cls.get_db_url(db_type, connection.db_user, connection.db_pwd, connection.db_host, connection.db_port,
48+
connection.db_name)
49+
engine = db.create_engine(db_url)
50+
inspector = inspect(engine)
51+
52+
if db_type == 'postgresql':
53+
schemas = [schema for schema in inspector.get_schema_names() if
54+
schema not in ('pg_catalog', 'information_schema', 'public')]
55+
elif db_type in ('redshift', 'mysql', 'starrocks', 'clickhouse'):
6956
schemas = inspector.get_schema_names()
70-
elif connection.db_type == 'mysql':
71-
db_url = cls.get_db_url(connection.db_type, connection.db_user, connection.db_pwd, connection.db_host,
72-
connection.db_port, connection.db_name)
73-
engine = db.create_engine(db_url)
74-
database_connect = sqlalchemy.inspect(engine)
75-
schemas = database_connect.get_schema_names()
76-
elif connection.db_type == 'starrocks':
77-
db_url = cls.get_db_url(connection.db_type, connection.db_user, connection.db_pwd, connection.db_host,
78-
connection.db_port, connection.db_name)
79-
engine = db.create_engine(db_url)
80-
database_connect = sqlalchemy.inspect(engine)
81-
schemas = database_connect.get_schema_names()
57+
else:
58+
raise ValueError("Unsupported database type")
59+
8260
return schemas
8361

8462
@classmethod

application/pages/2_🪙_Data_Connection_Management.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
'postgresql': 'PostgreSQL',
1414
'redshift': 'Redshift',
1515
'starrocks': 'StarRocks',
16+
'clickhouse': 'Clickhouse',
1617
}
1718

1819

application/requirements-api.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,4 +15,5 @@ langchain-core~=0.1.30
1515
sqlparse~=0.4.2
1616
pandas==2.0.3
1717
openpyxl
18-
starrocks==1.0.6
18+
starrocks==1.0.6
19+
clickhouse-sqlalchemy==0.2.6

application/requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,4 +14,5 @@ sqlparse~=0.4.2
1414
debugpy
1515
pandas==2.0.3
1616
openpyxl
17-
starrocks==1.0.6
17+
starrocks==1.0.6
18+
clickhouse-sqlalchemy==0.2.6

application/utils/prompt.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,11 @@
2222
question.When generating SQL, do not add double quotes or single quotes around table names. Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per StarRocks SQL.
2323
Never query for all columns from a table.""".format(top_k=TOP_K)
2424

25+
CLICKHOUSE_DIALECT_PROMPT_CLAUDE3="""
26+
You are a data analysis expert and proficient in Clickhouse. Given an input question, first create a syntactically correct Clickhouse query to run, then look at the results of the query and return the answer to the input question.
27+
Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per ClickHouse. You can order the results to return the most informative data in the database.
28+
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
29+
Pay attention to use today() function to get the current date, if the question involves "today". Pay attention to adapted to the table field type. Please follow the clickhouse syntax or function case specifications.If the field alias contains Chinese characters, please use double quotes to Wrap it.""".format(top_k=TOP_K)
2530

2631
AWS_REDSHIFT_DIALECT_PROMPT_CLAUDE3 = """You are a Amazon Redshift expert. Given an input question, first create a syntactically correct Redshift query to run, then look at the results of the query and return the answer to the input
2732
question.When generating SQL, do not add double quotes or single quotes around table names. Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per MySQL.

application/utils/prompts/generate_prompt.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from utils.prompt import POSTGRES_DIALECT_PROMPT_CLAUDE3, MYSQL_DIALECT_PROMPT_CLAUDE3, \
2-
DEFAULT_DIALECT_PROMPT, AGENT_COT_EXAMPLE, AWS_REDSHIFT_DIALECT_PROMPT_CLAUDE3, STARROCKS_DIALECT_PROMPT_CLAUDE3
2+
DEFAULT_DIALECT_PROMPT, AGENT_COT_EXAMPLE, AWS_REDSHIFT_DIALECT_PROMPT_CLAUDE3, STARROCKS_DIALECT_PROMPT_CLAUDE3, CLICKHOUSE_DIALECT_PROMPT_CLAUDE3
33
from utils.prompts import guidance_prompt
44
from utils.prompts import table_prompt
55
import logging
@@ -1909,6 +1909,8 @@ def generate_llm_prompt(ddl, hints, prompt_map, search_box, sql_examples=None, n
19091909
dialect_prompt = AWS_REDSHIFT_DIALECT_PROMPT_CLAUDE3
19101910
elif dialect == 'starrocks':
19111911
dialect_prompt = STARROCKS_DIALECT_PROMPT_CLAUDE3
1912+
elif dialect == 'clickhouse':
1913+
dialect_prompt = CLICKHOUSE_DIALECT_PROMPT_CLAUDE3
19121914
else:
19131915
dialect_prompt = DEFAULT_DIALECT_PROMPT
19141916

0 commit comments

Comments
 (0)