升级新库

This commit is contained in:
BBIT-Kai
2025-12-31 17:49:17 +08:00
parent d6c7f209c7
commit 6136554562
14 changed files with 355 additions and 356 deletions
+72 -77
View File
@@ -1,31 +1,12 @@
from typing import Literal
from langchain_core.messages import AIMessage
from langchain_core.runnables import RunnableConfig
from langgraph.graph import END, START, MessagesState, StateGraph
from langgraph.prebuilt import ToolNode
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from config.llm import llm, llmThink
from langgraph.graph import StateGraph, END
from langchain.prompts import PromptTemplate
from config.llm import llm
from typing import Annotated
from typing_extensions import TypedDict
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from langchain_community.agent_toolkits import create_sql_agent
from langchain.prompts import PromptTemplate
from config.llm import llm
from typing import Annotated
from langgraph.graph.message import add_messages
import os
from langchain_tavily import TavilySearch
from langgraph.prebuilt import ToolNode, tools_condition
from llm.chatLLM import get_chat_response
from typing import TypedDict
from langchain_core.prompts import PromptTemplate
from langgraph.graph import StateGraph, END
from llm.summarizeLLM import getSummary
import db.postgres as pgdb
import db.sqlserver as sqlserver
from config.llm import llm
from config.llm import llmThink
# -------- 定义状态 --------
@@ -69,7 +50,7 @@ gen_sql_prompt = PromptTemplate(
请直接输出完整可执行的 SQL 语句,不要任何其他文字或格式化,例如反引号或 ```sql。
"""
""",
)
sqlChain = gen_sql_prompt | llm
@@ -89,33 +70,34 @@ fix_prompt = PromptTemplate(
# 输出要求
只返回修正后的 SQL 语句,不包含任何额外的解释或说明。
"""
""",
)
fixSQLChain = fix_prompt | llm
def sql(state: State):
if state["isFirstGenSQL"]:
state['sql'] = sql_1(state)
state["sql"] = sql_1(state)
else:
state['sql'] = sql_2(state)
state["sql"] = sql_2(state)
for attempt in range(2):
try:
# 执行 SQL
result = sqlserver.executeSQL(state['sql'])
state['sql_result'] = result
result = sqlserver.executeSQL(state["sql"])
state["sql_result"] = result
# print("SQL 执行成功,结果:", result)
break
except Exception as e:
error_msg = str(e)
print(f"SQL 执行出错: {error_msg}")
# 调用 LLM 修正 SQL
state['sql'] = fixSQLChain.invoke({
"sql": state['sql'],
"error_msg": error_msg,
"table_info": state['table_info'],
"tenant_id": state['tenant_id']
}
state["sql"] = fixSQLChain.invoke(
{
"sql": state["sql"],
"error_msg": error_msg,
"table_info": state["table_info"],
"tenant_id": state["tenant_id"],
}
).content
# print(f"LLM 生成修正 SQL: {state['sql']}")
else:
@@ -124,11 +106,13 @@ def sql(state: State):
def sql_1(state: State):
return sqlChain.invoke({
"table_info": state['table_info'],
"userInput": state["userInput"],
"tenant_id": state['tenant_id']
}).content
return sqlChain.invoke(
{
"table_info": state["table_info"],
"userInput": state["userInput"],
"tenant_id": state["tenant_id"],
}
).content
improve_sql_prompt = PromptTemplate(
@@ -151,18 +135,20 @@ improve_sql_prompt = PromptTemplate(
7. 通常来说,不查询对用户来说意义不大的字段,比如主键、外键、id等。
8. 查询的SQL字段要用别名,取名参考描述。
9. 一般情况下,如果能限制租户Id(通常为tenantid 字段),则尽量加上WHERE tenantid = {tenant_id}
"""
""",
)
improveSqlChain = improve_sql_prompt | llm
def sql_2(state: State):
return improveSqlChain.invoke({
"sql": state['sql'],
"table_info": state['table_info'],
"userInput": state["userInput"],
"tenant_id": state['tenant_id']
}).content
return improveSqlChain.invoke(
{
"sql": state["sql"],
"table_info": state["table_info"],
"userInput": state["userInput"],
"tenant_id": state["tenant_id"],
}
).content
# ------------------------------------------------------------------------ 路径选择 --------
@@ -199,7 +185,7 @@ chat → 无法直接生成 SQL,需要进一步解释或澄清。
回答内容仅限于db或者chat,请勿输出其他内容。
你的回复:
"""
""",
)
pathSelectChain = pathSelectPrompt | llmThink
@@ -207,12 +193,18 @@ pathSelectChain = pathSelectPrompt | llmThink
def decide_source(state: State, max_retry=3):
"""根据用户输入选择数据来源"""
for _ in range(max_retry):
choice = pathSelectChain.invoke({
"userInput": state["userInput"],
"table_info": state["table_info"],
"ai_service": state["ai_service"],
"sql": state["sql"]
}).content.strip().lower()
choice = (
pathSelectChain.invoke(
{
"userInput": state["userInput"],
"table_info": state["table_info"],
"ai_service": state["ai_service"],
"sql": state["sql"],
}
)
.content.strip()
.lower()
)
print("根据用户输入选择数据来源,路径是:", choice)
if choice in ["db", "chat"]:
state["path"] = choice
@@ -242,17 +234,16 @@ noChatPrompt = PromptTemplate(
3. 引导用户提出与你业务相关的问题。
4. 使用礼貌和友好的语气。
你的回答:
"""
""",
)
noChatChain = noChatPrompt | llm
def chat(state: State):
state["reply"] = noChatChain.invoke({
"userInput": state["userInput"],
"ai_service": state["ai_service"]
}).content
state["reply"] = noChatChain.invoke(
{"userInput": state["userInput"], "ai_service": state["ai_service"]}
).content
print("直接回复")
return state
@@ -291,19 +282,21 @@ summarizePrompt = PromptTemplate(
2. 提供进一步可选的查询示例,基于当前的数据库表结构,引导用户提出更具体需求。
你的回复:
"""
""",
)
summarizeChain = summarizePrompt | llm
def summarize_ai(state: State):
"""AI 总结输出"""
state["reply"] = summarizeChain.invoke({
"ai_role": state["ai_role"],
"sql": state['sql'],
"userInput": state['userInput'],
"table_info": state['table_info'],
}).content
state["reply"] = summarizeChain.invoke(
{
"ai_role": state["ai_role"],
"sql": state["sql"],
"userInput": state["userInput"],
"table_info": state["table_info"],
}
).content
return state
@@ -322,7 +315,7 @@ workflow.add_conditional_edges(
{
"db": "sql_1",
"chat": "chat",
}
},
)
workflow.add_edge("summarize", END)
workflow.add_edge("chat", END)
@@ -334,13 +327,15 @@ def get_db_agent_reply(aiId: str, userInput: str, tenant_id: str, sql: str = "")
json = pgdb.get_ai_personality(aiId)
ai_service = json["业务"]
ai_role = json["性格"]
final_state = graph.invoke({
"ai_service": ai_service,
"ai_role": ai_role,
"table_info": pgdb.get_available_tables_str(aiId),
"tenant_id": tenant_id,
"userInput": userInput,
"sql": sql,
"isFirstGenSQL": sql == "",
})
final_state = graph.invoke(
{
"ai_service": ai_service,
"ai_role": ai_role,
"table_info": pgdb.get_available_tables_str(aiId),
"tenant_id": tenant_id,
"userInput": userInput,
"sql": sql,
"isFirstGenSQL": sql == "",
}
)
return final_state