|
5 | 5 | from botocore.config import Config
|
6 | 6 | from opensearchpy import OpenSearch
|
7 | 7 | from utils import opensearch
|
| 8 | +from utils.prompt import POSTGRES_DIALECT_PROMPT_CLAUDE3, MYSQL_DIALECT_PROMPT_CLAUDE3, \ |
| 9 | + DEFAULT_DIALECT_PROMPT |
8 | 10 | import os
|
9 | 11 | from loguru import logger
|
10 | 12 |
|
@@ -32,144 +34,93 @@ def get_bedrock_client():
|
32 | 34 | return bedrock
|
33 | 35 |
|
34 | 36 |
|
35 |
| -def sqlcoder(SQLCODER_API_ENDPOINT, payload): |
36 |
| - headers = { |
37 |
| - 'Content-Type': 'application/json' |
38 |
| - } |
39 |
| - response = requests.post(SQLCODER_API_ENDPOINT, headers=headers, data=json.dumps(payload)) |
40 |
| - return response |
41 |
| - |
42 |
| - |
43 |
| -def invoke_model(payload, model_id): |
44 |
| - body = json.dumps(payload) |
45 |
| - |
46 |
| - accept = 'application/json' |
47 |
| - contentType = 'application/json' |
def invoke_model_claude3(model_id, system_prompt, messages, max_tokens, temperature=0.01):
    """Invoke an Anthropic Claude 3 model through the Bedrock runtime client.

    Args:
        model_id: Model identifier, passed to Bedrock as ``modelId``.
        system_prompt: Text placed in the request's ``system`` field.
        messages: Conversation turns placed in the request's ``messages`` field.
        max_tokens: Value placed in the request's ``max_tokens`` field.
        temperature: Sampling temperature placed in the request body. Defaults
            to 0.01, the value that used to be hard-coded here.

    Returns:
        The JSON-decoded response body.
    """
    body = json.dumps(
        {
            "anthropic_version": "bedrock-2023-05-31",
            "max_tokens": max_tokens,
            "system": system_prompt,
            "messages": messages,
            "temperature": temperature,
        }
    )

    response = get_bedrock_client().invoke_model(body=body, modelId=model_id)
    # The response body is a stream; read it fully before decoding the JSON.
    response_body = json.loads(response.get('body').read())

    return response_body
57 | 52 |
|
58 | 53 |
|
def claude_select_table():
    """Placeholder with no behaviour yet; always returns ``None``."""
    return None
|
61 | 56 |
|
62 | 57 |
|
63 |
| -DEFAULT_DIALECT_PROMPT = '''You are a data analyst who writes SQL statements.''' |
64 |
| - |
65 |
| -TOP_K = 100 |
66 |
| -POSTGRES_DIALECT_PROMPT = """You are a PostgreSQL expert. Given an input question, first create a syntactically correct PostgreSQL query to run, then look at the results of the query and return the answer to the input question. |
67 |
| -Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per PostgreSQL. You can order the results to return the most informative data in the database. |
68 |
| -Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in double quotes (") to denote them as delimited identifiers. |
69 |
| -Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table. |
70 |
| -Pay attention to use CURRENT_DATE function to get the current date, if the question involves "today".""".format(top_k=TOP_K) |
71 |
| - |
72 |
| -MYSQL_DIALECT_PROMPT = """You are a MySQL expert. Given an input question, create a syntactically correct MySQL query to run. |
73 |
| -Unless the user specifies in the question a specific number of examples to obtain, query for at most {top_k} results using the LIMIT clause as per MySQL. You can order the results to return the most informative data in the database. |
74 |
| -Never query for all columns from a table. You must query only the columns that are needed to answer the question. Wrap each column name in backticks (`) to denote them as delimited identifiers. |
75 |
| -Pay attention to use only the column names you can see in the tables below. Be careful to not query for columns that do not exist. Also, pay attention to which column is in which table. |
76 |
| -Pay attention to use CURDATE() function to get the current date, if the question involves "today".""".format(top_k=TOP_K) |
77 |
| - |
78 |
| -def claude_to_sql(ddl, hints, search_box, examples=None, model_id='anthropic.claude-v2:1', dialect='mysql', model_provider=None): |
def generate_prompt(ddl, hints, search_box, examples=None, model_id=None, dialect='mysql'):
    """Build the user prompt and the dialect-specific system prompt.

    Args:
        ddl: Mapping of table name to a dict. The DDL text is read from the
            ``col_a`` key (falling back to ``ddl``) and the table description
            from the ``tbl_a`` key (falling back to ``description``).
        hints: Hint text interpolated into the user prompt.
        search_box: The natural-language question to be answered in SQL.
        examples: Optional sequence of items; each one is read as
            ``item['_source']['text']`` and ``item['_source']['sql']``.
        model_id: Not used by this function; kept so existing callers that
            pass it keep working.
        dialect: ``'postgresql'``, ``'mysql'`` or ``'redshift'``. Any other
            value selects ``DEFAULT_DIALECT_PROMPT``.

    Returns:
        A ``(user_prompt, dialect_prompt)`` tuple.

    Raises:
        KeyError: If a table entry has neither key of a fallback pair.
    """
    # One "-- <table>表:<description>" header followed by the DDL text per table.
    table_sections = []
    for table_name, table_data in ddl.items():
        ddl_string = table_data["col_a"] if 'col_a' in table_data else table_data["ddl"]
        description = table_data["tbl_a"] if 'tbl_a' in table_data else table_data["description"]
        table_sections.append("-- {}表:{}\n{}\n".format(table_name, description, ddl_string))
    ddl_text = "".join(table_sections)

    logger.info(f'{dialect=}')
    if dialect == 'postgresql':
        dialect_prompt = POSTGRES_DIALECT_PROMPT_CLAUDE3
    elif dialect == 'mysql':
        dialect_prompt = MYSQL_DIALECT_PROMPT_CLAUDE3
    elif dialect == 'redshift':
        # Adjacent literals keep this a single-line prompt; wrapping it inside a
        # triple-quoted string would embed newlines and source indentation.
        dialect_prompt = (
            "You are a Amazon Redshift expert. Given an input question, first create a syntactically "
            "correct Redshift query to run, then look at the results of the query and return the answer "
            "to the input question."
        )
    else:
        dialect_prompt = DEFAULT_DIALECT_PROMPT

    # Assemble the prompt from explicit lines so that the function's own
    # indentation never ends up inside the text sent to the model.
    prompt_lines = [
        "Here is DDL of the database you are working on:",
        "```",
        ddl_text,
        "```",
        "Here are some hints: {}".format(hints),
    ]
    if examples:
        example_prompt = "".join(
            "Q: {}\nA: ```sql\n{}```\n".format(item['_source']['text'], item['_source']['sql'])
            for item in examples
        )
        prompt_lines.append("Also, here are some examples of generating SQL using natural language:")
        prompt_lines.append(example_prompt)
        prompt_lines.append('Now, you need to answer the question: "{}" in SQL.'.format(search_box))
    else:
        prompt_lines.append('You need to answer the question: "{}" in SQL.'.format(search_box))

    claude_prompt = "\n".join(prompt_lines) + "\n"

    return claude_prompt, dialect_prompt
| 107 | + |
| 108 | + |
@logger.catch
def claude3_to_sql(ddl, hints, search_box, examples=None, model_id=None, dialect='mysql', model_provider=None):
    """Ask a Claude 3 model to answer a question in SQL and return its text reply.

    The prompts come from ``generate_prompt``. The request is sent through
    ``invoke_model_claude3`` as a single user turn with a 2048-token limit.
    The return value is the ``text`` of the first entry in the response's
    ``content`` list. ``model_provider`` is accepted but not used.
    """
    prompt_text, system_prompt = generate_prompt(ddl, hints, search_box, examples, model_id, dialect=dialect)

    # Only one user turn is sent; there is no prior conversation.
    messages = [{"role": "user", "content": prompt_text}]
    logger.info(f'{system_prompt=}')
    logger.info(f'{messages=}')

    reply = invoke_model_claude3(model_id, system_prompt, messages, 2048)
    first_block = reply.get("content")[0]
    return first_block.get("text")
173 | 124 |
|
174 | 125 |
|
175 | 126 | def create_vector_embedding_with_bedrock(text, index_name):
|
@@ -210,9 +161,9 @@ def retrieve_results_from_opensearch(index_name, region_name, domain, opensearch
|
210 | 161 | "query": {
|
211 | 162 | "bool": {
|
212 | 163 | "filter": {
|
213 |
| - "match_phrase": { |
| 164 | + "match_phrase": { |
214 | 165 | "profile": profile_name
|
215 |
| - } |
| 166 | + } |
216 | 167 | },
|
217 | 168 | "must": [
|
218 | 169 | {
|
|
0 commit comments