Skip to content

Commit 5d41a9a

Browse files
authored
Merge pull request #107 from #61
Feature/61 질문 재정의 노드 강화
2 parents 002548e + 5e97671 commit 5d41a9a

File tree

9 files changed

+299
-6
lines changed

9 files changed

+299
-6
lines changed

.gitignore

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,4 +10,4 @@ test_lhm/
1010
.cursorignore
1111
.vscode
1212
table_info_db
13-
ko_reranker_local
13+
ko_reranker_local

interface/lang2sql.py

Lines changed: 23 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,10 @@
1111

1212
from llm_utils.connect_db import ConnectDB
1313
from llm_utils.graph import builder
14+
from llm_utils.enriched_graph import builder as enriched_builder
1415
from llm_utils.display_chart import DisplayChart
1516
from llm_utils.llm_response_parser import LLMResponseParser
1617

17-
1818
DEFAULT_QUERY = "고객 데이터를 기반으로 유니크한 유저 수를 카운트하는 쿼리"
1919
SIDEBAR_OPTIONS = {
2020
"show_total_token_usage": "Show Total Token Usage",
@@ -77,7 +77,10 @@ def execute_query(
7777

7878
graph = st.session_state.get("graph")
7979
if graph is None:
80-
graph = builder.compile()
80+
graph_builder = (
81+
enriched_builder if st.session_state.get("use_enriched") else builder
82+
)
83+
graph = graph_builder.compile()
8184
st.session_state["graph"] = graph
8285

8386
res = graph.invoke(
@@ -198,14 +201,29 @@ def should_show(_key: str) -> bool:
198201

199202
st.title("Lang2SQL")
200203

204+
# 워크플로우 선택(UI)
205+
use_enriched = st.sidebar.checkbox(
206+
"프로파일 추출 & 컨텍스트 보강 워크플로우 사용", value=False
207+
)
208+
201209
# 세션 상태 초기화
202-
if "graph" not in st.session_state:
203-
st.session_state["graph"] = builder.compile()
210+
if (
211+
"graph" not in st.session_state
212+
or st.session_state.get("use_enriched") != use_enriched
213+
):
214+
graph_builder = enriched_builder if use_enriched else builder
215+
st.session_state["graph"] = graph_builder.compile()
216+
217+
# 프로파일 추출 & 컨텍스트 보강 그래프
218+
st.session_state["use_enriched"] = use_enriched
204219
st.info("Lang2SQL이 성공적으로 시작되었습니다.")
205220

206221
# 새로고침 버튼 추가
207222
if st.sidebar.button("Lang2SQL 새로고침"):
208-
st.session_state["graph"] = builder.compile()
223+
graph_builder = (
224+
enriched_builder if st.session_state.get("use_enriched") else builder
225+
)
226+
st.session_state["graph"] = graph_builder.compile()
209227
st.sidebar.success("Lang2SQL이 성공적으로 새로고침되었습니다.")
210228

211229
user_query = st.text_area(

llm_utils/chains.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,14 @@
44
MessagesPlaceholder,
55
SystemMessagePromptTemplate,
66
)
7+
from pydantic import BaseModel, Field
78

89
from .llm_factory import get_llm
910

1011
from dotenv import load_dotenv
1112
from prompt.template_loader import get_prompt_template
1213

14+
1315
env_path = os.path.join(os.getcwd(), ".env")
1416

1517
if os.path.exists(env_path):
@@ -20,6 +22,16 @@
2022
llm = get_llm()
2123

2224

25+
class QuestionProfile(BaseModel):
26+
is_timeseries: bool = Field(description="시계열 분석 필요 여부")
27+
is_aggregation: bool = Field(description="집계 함수 필요 여부")
28+
has_filter: bool = Field(description="조건 필터 필요 여부")
29+
is_grouped: bool = Field(description="그룹화 필요 여부")
30+
has_ranking: bool = Field(description="정렬/순위 필요 여부")
31+
has_temporal_comparison: bool = Field(description="기간 비교 포함 여부")
32+
intent_type: str = Field(description="질문의 주요 의도 유형")
33+
34+
2335
def create_query_refiner_chain(llm):
2436
prompt = get_prompt_template("query_refiner_prompt")
2537
tool_choice_prompt = ChatPromptTemplate.from_messages(
@@ -72,8 +84,66 @@ def create_query_maker_chain(llm):
7284
return query_maker_prompt | llm
7385

7486

87+
def create_query_refiner_with_profile_chain(llm):
88+
prompt = get_prompt_template("query_refiner_prompt")
89+
90+
tool_choice_prompt = ChatPromptTemplate.from_messages(
91+
[
92+
SystemMessagePromptTemplate.from_template(prompt),
93+
MessagesPlaceholder(variable_name="user_input"),
94+
SystemMessagePromptTemplate.from_template(
95+
"다음은 사용자의 실제 사용 가능한 테이블 및 컬럼 정보입니다:"
96+
),
97+
MessagesPlaceholder(variable_name="searched_tables"),
98+
# 프로파일 정보 입력
99+
SystemMessagePromptTemplate.from_template(
100+
"다음은 사용자의 질문을 분석한 프로파일 정보입니다."
101+
),
102+
MessagesPlaceholder("profile_prompt"),
103+
SystemMessagePromptTemplate.from_template(
104+
"""
105+
위 사용자의 입력과 위 조건을 바탕으로
106+
분석 관점에서 **충분히 답변 가능한 형태**로
107+
"구체화된 질문"을 작성하세요.
108+
""",
109+
),
110+
]
111+
)
112+
113+
return tool_choice_prompt | llm
114+
115+
116+
def create_query_enrichment_chain(llm):
117+
prompt = get_prompt_template("query_enrichment_prompt")
118+
119+
enrichment_prompt = ChatPromptTemplate.from_messages(
120+
[
121+
SystemMessagePromptTemplate.from_template(prompt),
122+
]
123+
)
124+
125+
chain = enrichment_prompt | llm
126+
return chain
127+
128+
129+
def create_profile_extraction_chain(llm):
130+
prompt = get_prompt_template("profile_extraction_prompt")
131+
132+
profile_prompt = ChatPromptTemplate.from_messages(
133+
[
134+
SystemMessagePromptTemplate.from_template(prompt),
135+
]
136+
)
137+
138+
chain = profile_prompt | llm.with_structured_output(QuestionProfile)
139+
return chain
140+
141+
75142
query_refiner_chain = create_query_refiner_chain(llm)
76143
query_maker_chain = create_query_maker_chain(llm)
144+
profile_extraction_chain = create_profile_extraction_chain(llm)
145+
query_refiner_with_profile_chain = create_query_refiner_with_profile_chain(llm)
146+
query_enrichment_chain = create_query_enrichment_chain(llm)
77147

78148
if __name__ == "__main__":
79149
query_refiner_chain.invoke()

llm_utils/enriched_graph.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import json
2+
3+
from langgraph.graph import StateGraph, END
4+
from llm_utils.graph import (
5+
QueryMakerState,
6+
GET_TABLE_INFO,
7+
PROFILE_EXTRACTION,
8+
QUERY_REFINER,
9+
CONTEXT_ENRICHMENT,
10+
QUERY_MAKER,
11+
get_table_info_node,
12+
profile_extraction_node,
13+
query_refiner_with_profile_node,
14+
context_enrichment_node,
15+
query_maker_node,
16+
)
17+
18+
"""
19+
기본 워크플로우에 '프로파일 추출(PROFILE_EXTRACTION)'과 '컨텍스트 보강(CONTEXT_ENRICHMENT)'를
20+
추가한 확장된 그래프입니다.
21+
"""
22+
23+
# StateGraph 생성 및 구성
24+
builder = StateGraph(QueryMakerState)
25+
builder.set_entry_point(GET_TABLE_INFO)
26+
27+
# 노드 추가
28+
builder.add_node(GET_TABLE_INFO, get_table_info_node)
29+
builder.add_node(QUERY_REFINER, query_refiner_with_profile_node)
30+
builder.add_node(PROFILE_EXTRACTION, profile_extraction_node)
31+
builder.add_node(CONTEXT_ENRICHMENT, context_enrichment_node)
32+
builder.add_node(QUERY_MAKER, query_maker_node)
33+
34+
# 기본 엣지 설정
35+
builder.add_edge(GET_TABLE_INFO, PROFILE_EXTRACTION)
36+
builder.add_edge(PROFILE_EXTRACTION, QUERY_REFINER)
37+
builder.add_edge(QUERY_REFINER, CONTEXT_ENRICHMENT)
38+
builder.add_edge(CONTEXT_ENRICHMENT, QUERY_MAKER)
39+
40+
# QUERY_MAKER 노드 후 종료
41+
builder.add_edge(QUERY_MAKER, END)

llm_utils/graph.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,17 +11,23 @@
1111
from llm_utils.chains import (
1212
query_refiner_chain,
1313
query_maker_chain,
14+
query_refiner_with_profile_chain,
15+
profile_extraction_chain,
16+
query_enrichment_chain,
1417
)
1518

1619
from llm_utils.tools import get_info_from_db
1720
from llm_utils.retrieval import search_tables
21+
from llm_utils.utils import profile_to_text
1822

1923
# 노드 식별자 정의
2024
QUERY_REFINER = "query_refiner"
2125
GET_TABLE_INFO = "get_table_info"
2226
TOOL = "tool"
2327
TABLE_FILTER = "table_filter"
2428
QUERY_MAKER = "query_maker"
29+
PROFILE_EXTRACTION = "profile_extraction"
30+
CONTEXT_ENRICHMENT = "context_enrichment"
2531

2632

2733
# 상태 타입 정의 (추가 상태 정보와 메시지들을 포함)
@@ -31,12 +37,38 @@ class QueryMakerState(TypedDict):
3137
searched_tables: dict[str, dict[str, str]]
3238
best_practice_query: str
3339
refined_input: str
40+
question_profile: dict
3441
generated_query: str
3542
retriever_name: str
3643
top_n: int
3744
device: str
3845

3946

47+
# 노드 함수: PROFILE_EXTRACTION 노드
48+
def profile_extraction_node(state: QueryMakerState):
49+
"""
50+
자연어 쿼리로부터 질문 유형(PROFILE)을 추출하는 노드입니다.
51+
52+
이 노드는 주어진 자연어 쿼리에서 질문의 특성을 분석하여, 해당 질문이 시계열 분석, 집계 함수 사용, 조건 필터 필요 여부,
53+
그룹화, 정렬/순위, 기간 비교 등 다양한 특성을 갖는지 여부를 추출합니다.
54+
55+
추출된 정보는 `QuestionProfile` 모델에 맞춰 저장됩니다. `QuestionProfile` 모델의 필드는 다음과 같습니다:
56+
- `is_timeseries`: 시계열 분석 필요 여부
57+
- `is_aggregation`: 집계 함수 필요 여부
58+
- `has_filter`: 조건 필터 필요 여부
59+
- `is_grouped`: 그룹화 필요 여부
60+
- `has_ranking`: 정렬/순위 필요 여부
61+
- `has_temporal_comparison`: 기간 비교 포함 여부
62+
- `intent_type`: 질문의 주요 의도 유형
63+
64+
"""
65+
result = profile_extraction_chain.invoke({"question": state["messages"][0].content})
66+
67+
state["question_profile"] = result
68+
print("profile_extraction_node : ", result)
69+
return state
70+
71+
4072
# 노드 함수: QUERY_REFINER 노드
4173
def query_refiner_node(state: QueryMakerState):
4274
res = query_refiner_chain.invoke(
@@ -52,6 +84,80 @@ def query_refiner_node(state: QueryMakerState):
5284
return state
5385

5486

87+
# 노드 함수: QUERY_REFINER 노드
88+
def query_refiner_with_profile_node(state: QueryMakerState):
89+
"""
90+
자연어 쿼리로부터 질문 유형(PROFILE)을 사용해 자연어 질의를 확장하는 노드입니다.
91+
92+
"""
93+
94+
profile_bullets = profile_to_text(state["question_profile"])
95+
res = query_refiner_with_profile_chain.invoke(
96+
input={
97+
"user_input": [state["messages"][0].content],
98+
"user_database_env": [state["user_database_env"]],
99+
"best_practice_query": [state["best_practice_query"]],
100+
"searched_tables": [json.dumps(state["searched_tables"])],
101+
"profile_prompt": [profile_bullets],
102+
}
103+
)
104+
state["messages"].append(res)
105+
state["refined_input"] = res
106+
107+
print("refined_input before context enrichment : ", res.content)
108+
return state
109+
110+
111+
# 노드 함수: CONTEXT_ENRICHMENT 노드
112+
def context_enrichment_node(state: QueryMakerState):
113+
"""
114+
주어진 질문과 관련된 메타데이터를 기반으로 질문을 풍부하게 만드는 노드입니다.
115+
116+
이 함수는 `refined_question`, `profiles`, `related_tables` 정보를 이용하여 자연어 질문을 보강합니다.
117+
보강 과정에서는 질문의 의도를 유지하면서, 추가적인 세부 정보를 제공하거나 잘못된 용어를 수정합니다.
118+
119+
주요 작업:
120+
- 주어진 질문의 메타데이터 (`question_profile` 및 `searched_tables`)를 활용하여, 질문을 수정하거나 추가 정보를 삽입합니다.
121+
- 질문이 시계열 분석 또는 집계 함수 관련인 경우, 이를 명시적으로 강조합니다 (예: "지난 30일 동안").
122+
- 자연어에서 실제 열 이름 또는 값으로 잘못 매칭된 용어를 수정합니다 (예: ‘미국’ → ‘USA’).
123+
- 보강된 질문을 출력합니다.
124+
125+
Args:
126+
state (QueryMakerState): 쿼리와 관련된 상태 정보를 담고 있는 객체.
127+
상태 객체는 `refined_input`, `question_profile`, `searched_tables` 등의 정보를 포함합니다.
128+
129+
Returns:
130+
QueryMakerState: 보강된 질문이 포함된 상태 객체.
131+
132+
Example:
133+
Given the refined question "What are the total sales in the last month?",
134+
the function would enrich it with additional information such as:
135+
- Ensuring the time period is specified correctly.
136+
- Correcting any column names if necessary.
137+
- Returning the enriched version of the question.
138+
"""
139+
140+
searched_tables = state["searched_tables"]
141+
searched_tables_json = json.dumps(searched_tables, ensure_ascii=False, indent=2)
142+
143+
question_profile = state["question_profile"].model_dump()
144+
question_profile_json = json.dumps(question_profile, ensure_ascii=False, indent=2)
145+
146+
enriched_text = query_enrichment_chain.invoke(
147+
input={
148+
"refined_question": state["refined_input"],
149+
"profiles": question_profile_json,
150+
"related_tables": searched_tables_json,
151+
}
152+
)
153+
154+
state["refined_input"] = enriched_text
155+
state["messages"].append(enriched_text)
156+
print("After context enrichment : ", enriched_text.content)
157+
158+
return state
159+
160+
55161
def get_table_info_node(state: QueryMakerState):
56162
# retriever_name과 top_n을 이용하여 검색 수행
57163
documents_dict = search_tables(

llm_utils/utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
def profile_to_text(profile_obj) -> str:
2+
mapping = {
3+
"is_timeseries": "• 시계열 분석 필요",
4+
"is_aggregation": "• 집계 함수 필요",
5+
"has_filter": "• WHERE 조건 필요",
6+
"is_grouped": "• GROUP BY 필요",
7+
"has_ranking": "• 정렬/순위 필요",
8+
"has_temporal_comparison": "• 기간 비교 필요",
9+
}
10+
bullets = [
11+
text for field, text in mapping.items() if getattr(profile_obj, field, False)
12+
]
13+
intent = getattr(profile_obj, "intent_type", None)
14+
if intent:
15+
bullets.append(f"• 의도 유형 → {intent}")
16+
17+
return "\n".join(bullets)
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# Role
2+
3+
You are an assistant that analyzes a user question and extracts the following profiles as JSON:
4+
- is_timeseries (boolean)
5+
- is_aggregation (boolean)
6+
- has_filter (boolean)
7+
- is_grouped (boolean)
8+
- has_ranking (boolean)
9+
- has_temporal_comparison (boolean)
10+
- intent_type (one of: trend, lookup, comparison, distribution)
11+
12+
# Input
13+
14+
Question:
15+
{question}
16+
17+
# Output Example
18+
19+
The output must be a valid JSON matching the QuestionProfile schema.

0 commit comments

Comments
 (0)