升级新库
This commit is contained in:
@@ -1,31 +1,12 @@
|
||||
from typing import Literal
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.graph import END, START, MessagesState, StateGraph
|
||||
from langgraph.prebuilt import ToolNode
|
||||
from langchain_community.agent_toolkits import SQLDatabaseToolkit
|
||||
from config.llm import llm, llmThink
|
||||
from langgraph.graph import StateGraph, END
|
||||
from langchain.prompts import PromptTemplate
|
||||
from config.llm import llm
|
||||
from typing import Annotated
|
||||
from typing_extensions import TypedDict
|
||||
from langgraph.graph import StateGraph, START, END
|
||||
from langgraph.graph.message import add_messages
|
||||
from langchain_community.agent_toolkits import create_sql_agent
|
||||
from langchain.prompts import PromptTemplate
|
||||
from config.llm import llm
|
||||
from typing import Annotated
|
||||
from langgraph.graph.message import add_messages
|
||||
import os
|
||||
from langchain_tavily import TavilySearch
|
||||
from langgraph.prebuilt import ToolNode, tools_condition
|
||||
from llm.chatLLM import get_chat_response
|
||||
from typing import TypedDict
|
||||
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
from langgraph.graph import StateGraph, END
|
||||
from llm.summarizeLLM import getSummary
|
||||
|
||||
import db.postgres as pgdb
|
||||
import db.sqlserver as sqlserver
|
||||
from config.llm import llm
|
||||
from config.llm import llmThink
|
||||
|
||||
|
||||
# -------- 定义状态 --------
|
||||
@@ -69,7 +50,7 @@ gen_sql_prompt = PromptTemplate(
|
||||
|
||||
|
||||
请直接输出完整可执行的 SQL 语句,不要任何其他文字或格式化,例如反引号或 ```sql。
|
||||
"""
|
||||
""",
|
||||
)
|
||||
sqlChain = gen_sql_prompt | llm
|
||||
|
||||
@@ -89,33 +70,34 @@ fix_prompt = PromptTemplate(
|
||||
|
||||
# 输出要求
|
||||
只返回修正后的 SQL 语句,不包含任何额外的解释或说明。
|
||||
"""
|
||||
""",
|
||||
)
|
||||
fixSQLChain = fix_prompt | llm
|
||||
|
||||
|
||||
def sql(state: State):
|
||||
if state["isFirstGenSQL"]:
|
||||
state['sql'] = sql_1(state)
|
||||
state["sql"] = sql_1(state)
|
||||
else:
|
||||
state['sql'] = sql_2(state)
|
||||
state["sql"] = sql_2(state)
|
||||
for attempt in range(2):
|
||||
try:
|
||||
# 执行 SQL
|
||||
result = sqlserver.executeSQL(state['sql'])
|
||||
state['sql_result'] = result
|
||||
result = sqlserver.executeSQL(state["sql"])
|
||||
state["sql_result"] = result
|
||||
# print("SQL 执行成功,结果:", result)
|
||||
break
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
print(f"SQL 执行出错: {error_msg}")
|
||||
# 调用 LLM 修正 SQL
|
||||
state['sql'] = fixSQLChain.invoke({
|
||||
"sql": state['sql'],
|
||||
"error_msg": error_msg,
|
||||
"table_info": state['table_info'],
|
||||
"tenant_id": state['tenant_id']
|
||||
}
|
||||
state["sql"] = fixSQLChain.invoke(
|
||||
{
|
||||
"sql": state["sql"],
|
||||
"error_msg": error_msg,
|
||||
"table_info": state["table_info"],
|
||||
"tenant_id": state["tenant_id"],
|
||||
}
|
||||
).content
|
||||
# print(f"LLM 生成修正 SQL: {state['sql']}")
|
||||
else:
|
||||
@@ -124,11 +106,13 @@ def sql(state: State):
|
||||
|
||||
|
||||
def sql_1(state: State):
|
||||
return sqlChain.invoke({
|
||||
"table_info": state['table_info'],
|
||||
"userInput": state["userInput"],
|
||||
"tenant_id": state['tenant_id']
|
||||
}).content
|
||||
return sqlChain.invoke(
|
||||
{
|
||||
"table_info": state["table_info"],
|
||||
"userInput": state["userInput"],
|
||||
"tenant_id": state["tenant_id"],
|
||||
}
|
||||
).content
|
||||
|
||||
|
||||
improve_sql_prompt = PromptTemplate(
|
||||
@@ -151,18 +135,20 @@ improve_sql_prompt = PromptTemplate(
|
||||
7. 通常来说,不查询对用户来说意义不大的字段,比如主键、外键、id等。
|
||||
8. 查询的SQL字段要用别名,取名参考描述。
|
||||
9. 一般情况下,如果能限制租户Id(通常为tenantid 字段),则尽量加上WHERE tenantid = {tenant_id}。
|
||||
"""
|
||||
""",
|
||||
)
|
||||
improveSqlChain = improve_sql_prompt | llm
|
||||
|
||||
|
||||
def sql_2(state: State):
|
||||
return improveSqlChain.invoke({
|
||||
"sql": state['sql'],
|
||||
"table_info": state['table_info'],
|
||||
"userInput": state["userInput"],
|
||||
"tenant_id": state['tenant_id']
|
||||
}).content
|
||||
return improveSqlChain.invoke(
|
||||
{
|
||||
"sql": state["sql"],
|
||||
"table_info": state["table_info"],
|
||||
"userInput": state["userInput"],
|
||||
"tenant_id": state["tenant_id"],
|
||||
}
|
||||
).content
|
||||
|
||||
|
||||
# ------------------------------------------------------------------------ 路径选择 --------
|
||||
@@ -199,7 +185,7 @@ chat → 无法直接生成 SQL,需要进一步解释或澄清。
|
||||
|
||||
回答内容仅限于db或者chat,请勿输出其他内容。
|
||||
你的回复:
|
||||
"""
|
||||
""",
|
||||
)
|
||||
pathSelectChain = pathSelectPrompt | llmThink
|
||||
|
||||
@@ -207,12 +193,18 @@ pathSelectChain = pathSelectPrompt | llmThink
|
||||
def decide_source(state: State, max_retry=3):
|
||||
"""根据用户输入选择数据来源"""
|
||||
for _ in range(max_retry):
|
||||
choice = pathSelectChain.invoke({
|
||||
"userInput": state["userInput"],
|
||||
"table_info": state["table_info"],
|
||||
"ai_service": state["ai_service"],
|
||||
"sql": state["sql"]
|
||||
}).content.strip().lower()
|
||||
choice = (
|
||||
pathSelectChain.invoke(
|
||||
{
|
||||
"userInput": state["userInput"],
|
||||
"table_info": state["table_info"],
|
||||
"ai_service": state["ai_service"],
|
||||
"sql": state["sql"],
|
||||
}
|
||||
)
|
||||
.content.strip()
|
||||
.lower()
|
||||
)
|
||||
print("根据用户输入选择数据来源,路径是:", choice)
|
||||
if choice in ["db", "chat"]:
|
||||
state["path"] = choice
|
||||
@@ -242,17 +234,16 @@ noChatPrompt = PromptTemplate(
|
||||
3. 引导用户提出与你业务相关的问题。
|
||||
4. 使用礼貌和友好的语气。
|
||||
你的回答:
|
||||
"""
|
||||
""",
|
||||
)
|
||||
|
||||
noChatChain = noChatPrompt | llm
|
||||
|
||||
|
||||
def chat(state: State):
|
||||
state["reply"] = noChatChain.invoke({
|
||||
"userInput": state["userInput"],
|
||||
"ai_service": state["ai_service"]
|
||||
}).content
|
||||
state["reply"] = noChatChain.invoke(
|
||||
{"userInput": state["userInput"], "ai_service": state["ai_service"]}
|
||||
).content
|
||||
print("直接回复")
|
||||
return state
|
||||
|
||||
@@ -291,19 +282,21 @@ summarizePrompt = PromptTemplate(
|
||||
2. 提供进一步可选的查询示例,基于当前的数据库表结构,引导用户提出更具体需求。
|
||||
|
||||
你的回复:
|
||||
"""
|
||||
""",
|
||||
)
|
||||
summarizeChain = summarizePrompt | llm
|
||||
|
||||
|
||||
def summarize_ai(state: State):
|
||||
"""AI 总结输出"""
|
||||
state["reply"] = summarizeChain.invoke({
|
||||
"ai_role": state["ai_role"],
|
||||
"sql": state['sql'],
|
||||
"userInput": state['userInput'],
|
||||
"table_info": state['table_info'],
|
||||
}).content
|
||||
state["reply"] = summarizeChain.invoke(
|
||||
{
|
||||
"ai_role": state["ai_role"],
|
||||
"sql": state["sql"],
|
||||
"userInput": state["userInput"],
|
||||
"table_info": state["table_info"],
|
||||
}
|
||||
).content
|
||||
return state
|
||||
|
||||
|
||||
@@ -322,7 +315,7 @@ workflow.add_conditional_edges(
|
||||
{
|
||||
"db": "sql_1",
|
||||
"chat": "chat",
|
||||
}
|
||||
},
|
||||
)
|
||||
workflow.add_edge("summarize", END)
|
||||
workflow.add_edge("chat", END)
|
||||
@@ -334,13 +327,15 @@ def get_db_agent_reply(aiId: str, userInput: str, tenant_id: str, sql: str = "")
|
||||
json = pgdb.get_ai_personality(aiId)
|
||||
ai_service = json["业务"]
|
||||
ai_role = json["性格"]
|
||||
final_state = graph.invoke({
|
||||
"ai_service": ai_service,
|
||||
"ai_role": ai_role,
|
||||
"table_info": pgdb.get_available_tables_str(aiId),
|
||||
"tenant_id": tenant_id,
|
||||
"userInput": userInput,
|
||||
"sql": sql,
|
||||
"isFirstGenSQL": sql == "",
|
||||
})
|
||||
final_state = graph.invoke(
|
||||
{
|
||||
"ai_service": ai_service,
|
||||
"ai_role": ai_role,
|
||||
"table_info": pgdb.get_available_tables_str(aiId),
|
||||
"tenant_id": tenant_id,
|
||||
"userInput": userInput,
|
||||
"sql": sql,
|
||||
"isFirstGenSQL": sql == "",
|
||||
}
|
||||
)
|
||||
return final_state
|
||||
|
||||
Reference in New Issue
Block a user