Skip to content

Commit 0e5037e

Browse files
authored
Merge pull request #2 from aws-samples/dev
support claude3 inference
2 parents a0601e8 + 17a5f76 commit 0e5037e

File tree

3 files changed: 87 additions and 129 deletions.

application/pages/1_🌍_Natural_Language_Querying.py

Lines changed: 10 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,7 @@
1010
from nlq.business.nlq_chain import NLQChain
1111
from nlq.business.profile import ProfileManagement
1212
from utils.database import get_db_url_dialect
13-
from utils.llm import claude_to_sql, create_vector_embedding_with_bedrock, retrieve_results_from_opensearch, \
13+
from utils.llm import claude3_to_sql, create_vector_embedding_with_bedrock, retrieve_results_from_opensearch, \
1414
upload_results_to_opensearch
1515

1616
def sample_question_clicked(sample):
@@ -128,10 +128,8 @@ def main():
128128
if 'nlq_chain' not in st.session_state:
129129
st.session_state['nlq_chain'] = None
130130

131-
bedrock_model_ids = ['anthropic.claude-v2:1', 'anthropic.claude-3-sonnet-20240229-v1:0',
132-
'anthropic.claude-v2', 'anthropic.claude-v1',
133-
'meta.llama2-70b-chat-v1', 'mistral.mistral-7b-instruct-v0:2',
134-
'mistral.mixtral-8x7b-instruct-v0:1']
131+
bedrock_model_ids = ['anthropic.claude-3-sonnet-20240229-v1:0', 'anthropic.claude-3-haiku-20240307-v1:0',
132+
'anthropic.claude-v2:1']
135133

136134
with st.sidebar:
137135
st.title('Setting')
@@ -147,14 +145,6 @@ def main():
147145
st.session_state['option'] = st.selectbox("Choose your option", ["Text2SQL"])
148146
model_type = st.selectbox("Choose your model", bedrock_model_ids)
149147
model_provider = None
150-
# Commented because if only for specific customer's demo
151-
# model_provider = st.text_input("Model Provider", value='replicate')
152-
# model_type = st.selectbox("Model", ['CodeLlama', 'DeepSeek'])
153-
# model_type_def = {
154-
# 'DeepSeek': 'kcaverly/deepseek-coder-33b-instruct-gguf:ea964345066a8868e43aca432f314822660b72e29cab6b4b904b779014fe58fd',
155-
# 'CodeLlama': 'meta/codellama-34b-instruct:eeb928567781f4e90d2aba57a51baef235de53f907c214a4ab42adabf5bb9736',
156-
# }
157-
# model_type = model_type_def[model_type]
158148

159149
use_rag = st.checkbox("Using RAG from Q/A Embedding", True)
160150
visualize_results = st.checkbox("Visualize Results", True)
@@ -276,13 +266,13 @@ def main():
276266
conn_name = database_profile['conn_name']
277267
db_url = ConnectionManagement.get_db_url_by_name(conn_name)
278268
database_profile['db_url'] = db_url
279-
response = claude_to_sql(database_profile['tables_info'],
280-
database_profile['hints'],
281-
search_box,
282-
model_id=model_type,
283-
examples=retrieve_result,
284-
dialect=get_db_url_dialect(database_profile['db_url']),
285-
model_provider=model_provider)
269+
response = claude3_to_sql(database_profile['tables_info'],
270+
database_profile['hints'],
271+
search_box,
272+
model_id=model_type,
273+
examples=retrieve_result,
274+
dialect=get_db_url_dialect(database_profile['db_url']),
275+
model_provider=model_provider)
286276

287277
logger.info(f'got llm response: {response}')
288278
current_nlq_chain.set_generated_sql_response(response)

application/utils/llm.py

Lines changed: 60 additions & 109 deletions
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,8 @@
55
from botocore.config import Config
66
from opensearchpy import OpenSearch
77
from utils import opensearch
8+
from utils.prompt import POSTGRES_DIALECT_PROMPT_CLAUDE3, MYSQL_DIALECT_PROMPT_CLAUDE3, \
9+
DEFAULT_DIALECT_PROMPT
810
import os
911
from loguru import logger
1012

@@ -32,144 +34,93 @@ def get_bedrock_client():
3234
return bedrock
3335

3436

35-
def sqlcoder(SQLCODER_API_ENDPOINT, payload):
36-
headers = {
37-
'Content-Type': 'application/json'
38-
}
39-
response = requests.post(SQLCODER_API_ENDPOINT, headers=headers, data=json.dumps(payload))
40-
return response
41-
42-
43-
def invoke_model(payload, model_id):
44-
body = json.dumps(payload)
45-
46-
accept = 'application/json'
47-
contentType = 'application/json'
37+
def invoke_model_claude3(model_id, system_prompt, messages, max_tokens):
38+
body = json.dumps(
39+
{
40+
"anthropic_version": "bedrock-2023-05-31",
41+
"max_tokens": max_tokens,
42+
"system": system_prompt,
43+
"messages": messages,
44+
"temperature": 0.01
45+
}
46+
)
4847

49-
response = get_bedrock_client().invoke_model(body=body, modelId=model_id, accept=accept, contentType=contentType)
48+
response = get_bedrock_client().invoke_model(body=body, modelId=model_id)
5049
response_body = json.loads(response.get('body').read())
51-
# logger.info(f'{response_body=}')
52-
if 'anthropic.claude' in model_id:
53-
result_key = 'completion'
54-
elif 'meta.llama2' in model_id:
55-
result_key = 'generation'
56-
return response_body[result_key]
50+
51+
return response_body
5752

5853

5954
def claude_select_table():
6055
pass
6156

6257

63-
DEFAULT_DIALECT_PROMPT = '''You are a data analyst who writes SQL statements.'''
64-
65-
TOP_K = 100
66-
POSTGRES_DIALECT_PROMPT = """You are a PostgreSQL expert. Given an input question, first create a syntactically correct PostgreSQL query to run, then look at the results of the query and return the answer to the input question.
67-
Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per PostgreSQL. You can order the results to return the most informative data in the database.
68-
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
69-
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
70-
Pay attention to use CURRENT_DATE function to get the current date, if the question involves "today".""".format(top_k=TOP_K)
71-
72-
MYSQL_DIALECT_PROMPT = """You are a MySQL expert. Given an input question, create a syntactically correct MySQL query to run.
73-
Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per MySQL. You can order the results to return the most informative data in the database.
74-
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in backticks (`) to denote them as delimited identifiers.
75-
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
76-
Pay attention to use CURDATE() function to get the current date, if the question involves "today".""".format(top_k=TOP_K)
77-
78-
def claude_to_sql(ddl, hints, search_box, examples=None, model_id='anthropic.claude-v2:1', dialect='mysql', model_provider=None):
58+
def generate_prompt(ddl, hints, search_box, examples=None, model_id=None, dialect='mysql'):
7959
long_string = ""
8060
for table_name, table_data in ddl.items():
81-
8261
ddl_string = table_data["col_a"] if 'col_a' in table_data else table_data["ddl"]
83-
long_string += "-- {}表:{}\n".format(table_name, table_data["tbl_a"] if 'tbl_a' in table_data else table_data["description"])
62+
long_string += "-- {}表:{}\n".format(table_name, table_data["tbl_a"] if 'tbl_a' in table_data else table_data[
63+
"description"])
8464
long_string += ddl_string
8565
long_string += "\n"
8666

8767
ddl = long_string
8868

8969
logger.info(f'{dialect=}')
9070
if dialect == 'postgresql':
91-
dialect_prompt = POSTGRES_DIALECT_PROMPT
71+
dialect_prompt = POSTGRES_DIALECT_PROMPT_CLAUDE3
9272
elif dialect == 'mysql':
93-
dialect_prompt = MYSQL_DIALECT_PROMPT
73+
dialect_prompt = MYSQL_DIALECT_PROMPT_CLAUDE3
9474
elif dialect == 'redshift':
95-
dialect_prompt = '''You are a Amazon Redshift expert. Given an input question, first create a syntactically correct Redshift query to run, then look at the results of the query and return the answer to the input question.'''
75+
dialect_prompt = '''You are a Amazon Redshift expert. Given an input question, first create a syntactically
76+
correct Redshift query to run, then look at the results of the query and return the answer to the input
77+
question.'''
9678
else:
9779
dialect_prompt = DEFAULT_DIALECT_PROMPT
9880

9981
example_prompt = ""
10082
if not examples:
101-
prompt = '''Human:
102-
{dialect_prompt}
103-
Here is DDL of the database you are working on:
104-
```sql
105-
{ddl}
106-
```
107-
Please do not perform any modifications to SQL tables.
108-
Absolutely do not output any columns, tables, or other information that is not mentioned in the database. Ensure that the program runs without errors.
109-
Here are some hints:
110-
{hints}
111-
You need to answer the question: "{question}" in SQL. Please give the SQL statement that can answer the question. Aside from giving the SQL answer, concisely explain yourself after giving the answer in same language as the question.
112-
Assistant:'''.format(dialect_prompt=dialect_prompt, ddl=ddl, hints=hints, question=search_box)
113-
logger.info(f'{prompt=}')
114-
claude_prompt = prompt
83+
claude_prompt = '''Here is DDL of the database you are working on:
84+
```
85+
{ddl}
86+
```
87+
Here are some hints: {hints}
88+
You need to answer the question: "{question}" in SQL.
89+
'''.format(ddl=ddl, hints=hints, question=search_box)
11590
else:
11691
# assemble examples into a string
11792
for item in examples:
11893
example_prompt += "Q: " + item['_source']['text'] + "\n"
11994
example_prompt += "A: ```sql\n" + item['_source']['sql'] + "```\n"
12095

121-
claude_prompt = '''Human:
122-
{dialect_prompt}
123-
Here is DDL of the database you are working on:
124-
```sql
125-
{ddl}
126-
```
127-
Please do not perform any modifications to SQL tables.
128-
Absolutely do not output any columns, tables, or other information that is not mentioned in the database. Ensure that the program runs without errors.
129-
Here are some hints:
130-
{hints}
131-
DO NOT use window function in another function's argument.
132-
Also, here are some examples of generating SQL using natural language:
133-
{examples}
134-
Now, you need to answer the question: "{question}" in SQL. Please give the SQL statement that can answer the question. Aside from giving the SQL answer, concisely explain yourself after giving the answer in same language as the question.
135-
Assistant:'''.format(dialect_prompt=dialect_prompt, ddl=ddl, hints=hints, examples=example_prompt, question=search_box)
136-
137-
llama_prompt = '''[INST]{dialect_prompt}[/INST]
138-
Here is DDL of the database you are working on:
139-
```sql
140-
{ddl}
141-
```
142-
Please do not perform any modifications to SQL tables.
143-
Absolutely do not output any columns, tables, or other information that is not mentioned in the database. Ensure that the program runs without errors.
144-
Here are some hints:
145-
{hints}
146-
Also, here are some examples of generating SQL using natural language:
147-
{examples}
148-
Now, you need to answer the question: "{question}" in SQL.
149-
Please give the SQL statement that can answer the question.
150-
SQL Query:'''.format(dialect_prompt=dialect_prompt, ddl=ddl, hints=hints, examples=example_prompt, question=search_box)
151-
152-
if model_provider == 'replicate':
153-
response = None
154-
# from utils.opensource_llm import invoke_model_replicate
155-
# response = invoke_model_replicate(model_id, ddl, hints, dialect_prompt, example_prompt, search_box)
156-
else:
157-
payload = {
158-
# "prompt": prompt,
159-
# "max_tokens_to_sample": 1024,
160-
"temperature": 0.01,
161-
"top_p": 0.9,
162-
}
163-
if 'anthropic.claude' in model_id:
164-
payload['max_tokens_to_sample'] = 1024
165-
payload['prompt'] = claude_prompt
166-
elif 'meta.llama2' in model_id:
167-
payload['max_gen_len'] = 1024
168-
payload['prompt'] = llama_prompt
169-
logger.info(f"{payload['prompt']=}")
170-
print(payload['prompt'])
171-
response = invoke_model(payload, model_id=model_id)
172-
return response
96+
claude_prompt = '''Here is DDL of the database you are working on:
97+
```
98+
{ddl}
99+
```
100+
Here are some hints: {hints}
101+
Also, here are some examples of generating SQL using natural language:
102+
{examples}
103+
Now, you need to answer the question: "{question}" in SQL.
104+
'''.format(ddl=ddl, hints=hints, examples=example_prompt, question=search_box)
105+
106+
return claude_prompt, dialect_prompt
107+
108+
109+
@logger.catch
110+
def claude3_to_sql(ddl, hints, search_box, examples=None, model_id=None, dialect='mysql', model_provider=None):
111+
user_prompt, system_prompt = generate_prompt(ddl, hints, search_box, examples, model_id, dialect=dialect)
112+
113+
max_tokens = 2048
114+
115+
# Prompt with user turn only.
116+
user_message = {"role": "user", "content": user_prompt}
117+
messages = [user_message]
118+
logger.info(f'{system_prompt=}')
119+
logger.info(f'{messages=}')
120+
response = invoke_model_claude3(model_id, system_prompt, messages, max_tokens)
121+
final_response = response.get("content")[0].get("text")
122+
123+
return final_response
173124

174125

175126
def create_vector_embedding_with_bedrock(text, index_name):
@@ -210,9 +161,9 @@ def retrieve_results_from_opensearch(index_name, region_name, domain, opensearch
210161
"query": {
211162
"bool": {
212163
"filter": {
213-
"match_phrase": {
164+
"match_phrase": {
214165
"profile": profile_name
215-
}
166+
}
216167
},
217168
"must": [
218169
{

application/utils/prompt.py

Lines changed: 17 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,17 @@
1+
DEFAULT_DIALECT_PROMPT = '''You are a data analyst who writes SQL statements.'''
2+
3+
TOP_K = 100
4+
5+
POSTGRES_DIALECT_PROMPT_CLAUDE3 = """Given an input question, first create a syntactically correct PostgreSQL query to run.
6+
Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per PostgreSQL.
7+
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers.
8+
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
9+
Pay attention to use CURRENT_DATE function to get the current date, if the question involves "today". Aside from giving the SQL answer, concisely explain yourself after giving the answer
10+
in the same language as the question.""".format(top_k=TOP_K)
11+
12+
MYSQL_DIALECT_PROMPT_CLAUDE3 = """Given an input question, create a syntactically correct MySQL query to run.
13+
Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per MySQL.
14+
Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in backticks (`) to denote them as delimited identifiers.
15+
Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table.
16+
Pay attention to use CURDATE() function to get the current date, if the question involves "today". Aside from giving the SQL answer, concisely explain yourself after giving the answer
17+
in the same language as the question.""".format(top_k=TOP_K)

0 commit comments

Comments
 (0)