from typing import Literal from langchain_core.messages import AIMessage from langchain_core.runnables import RunnableConfig from langgraph.graph import END, START, MessagesState, StateGraph from langgraph.prebuilt import ToolNode from langchain_community.agent_toolkits import SQLDatabaseToolkit from config.llm import llm, llmThink from langgraph.graph import StateGraph, END from langchain.prompts import PromptTemplate from config.llm import llm from typing import Annotated from typing_extensions import TypedDict from langgraph.graph import StateGraph, START, END from langgraph.graph.message import add_messages from langchain_community.agent_toolkits import create_sql_agent from langchain.prompts import PromptTemplate from config.llm import llm from typing import Annotated from langgraph.graph.message import add_messages import os from langchain_tavily import TavilySearch from langgraph.prebuilt import ToolNode, tools_condition from llm.chatLLM import get_chat_response from typing import TypedDict from langgraph.graph import StateGraph, END from llm.summarizeLLM import getSummary import db.postgres as pgdb import db.sqlserver as sqlserver # -------- 定义状态 -------- class State(TypedDict): userInput: str # 用户输入 path: str # 开始聊天选择的路径 table_info: str # 可用表信息 isFirstGenSQL: bool # 是否第一次生成SQL sql: str # 当前操作的SQL ai_service: str # AI 角色 业务 ai_role: str # AI 角色 性格特点 tenant_id: str # 租户ID history: str # 聊天历史 reply: str # 最终回复 # -------- 定义节点 -------- # ------------------------------------------------------------------------ 数据库查询 -------- gen_sql_prompt = PromptTemplate( input_variables=["table_info", "userInput"], template=""" # 角色 你是一个企业 SQL Server 数据库 SQL 生成助手,负责根据用户问题生成相应的 SQL 语句。 # 已知信息 可访问的表和字段:{table_info} # 任务要求 1. 根据用户提出的问题“{userInput}”生成 SQL 语句。 2. 只能使用已知的表和字段。 3. 输出完整可执行的 SQL 语句,不包含多余文字。 4. 若 SQL 语句返回列表数据,需限制返回数量,最大为 15 条,使用 SQL Server 语法(TOP 15 或 OFFSET FETCH)。 5. 若 SQL 是聚合查询(如 COUNT、SUM 等),无需限制行数。 6. 在生成 SQL 时,如果需要根据身份证计算年龄,请使用 SQL Server 标准日期格式 SUBSTRING(idcard, 7, 8) 和 CONVERT(..., 112),不要使用拼接 / 或非标准日期格式。 7. 通常来说,不查询对用户来说意义不大的字段,比如主键、外键、id等。 8. 查询的SQL字段要用别名,取名参考描述。 9. 一般情况下,如果能限制租户Id(通常为tenantid 字段),则尽量限制租户id = {tenant_id}。 请直接输出完整可执行的 SQL 语句,不要任何其他文字或格式化,例如反引号或 ```sql。 """ ) sqlChain = gen_sql_prompt | llm fix_prompt = PromptTemplate( input_variables=["sql", "error_msg", "table_info", "tenant_id"], template=""" # 系统角色 你是一位专业的 SQL Server的SQL语句纠错专家,擅长识别 SQL 语句中的语法错误和字段引用错误,并能对其进行修正。 # 任务 根据提供的原始 SQL 语句、执行报错信息以及可用表和字段信息,修正 SQL 语句,确保其语法正确且引用的字段存在。 # 输入信息 - 原始 SQL: {sql} - 执行报错: {error_msg} - 可用表和字段: {table_info} # 输出要求 只返回修正后的 SQL 语句,不包含任何额外的解释或说明。 """ ) fixSQLChain = fix_prompt | llm def sql(state: State): if state["isFirstGenSQL"]: state['sql'] = sql_1(state) else: state['sql'] = sql_2(state) for attempt in range(2): try: # 执行 SQL result = sqlserver.executeSQL(state['sql']) state['sql_result'] = result # print("SQL 执行成功,结果:", result) break except Exception as e: error_msg = str(e) print(f"SQL 执行出错: {error_msg}") # 调用 LLM 修正 SQL state['sql'] = fixSQLChain.invoke({ "sql": state['sql'], "error_msg": error_msg, "table_info": state['table_info'], "tenant_id": state['tenant_id'] } ).content # print(f"LLM 生成修正 SQL: {state['sql']}") else: raise RuntimeError(f"SQL 多次纠错失败,最后 SQL: {state['sql']}") return state def sql_1(state: State): return sqlChain.invoke({ "table_info": state['table_info'], "userInput": state["userInput"], "tenant_id": state['tenant_id'] }).content improve_sql_prompt = PromptTemplate( input_variables=["table_info", "userInput", "tenant_id"], template=""" # 角色 你是一个企业 SQL Server 数据库 SQL 生成助手,负责根据用户问题改进相应的 SQL 语句。 # 已知信息 当前 SQL 语句: {sql} 可访问的表和字段: {table_info} # 任务要求 1. 根据用户提出的问题“{userInput}”以及当前的 SQL 语句进行改进。 2. 只能使用已知的表和字段。 3. 输出完整可执行的 SQL 语句,不包含多余文字。 4. 若 SQL 语句返回列表数据,需限制返回数量,最大为 15 条,使用 SQL Server 语法(TOP 15 或 OFFSET FETCH)。 5. 若 SQL 是聚合查询(如 COUNT、SUM 等),无需限制行数。 6. 在生成 SQL 时,如果需要根据身份证计算年龄,请使用 SQL Server 标准日期格式 SUBSTRING(idcard, 7, 8) 和 CONVERT(..., 112),不要使用拼接 / 或非标准日期格式。 7. 通常来说,不查询对用户来说意义不大的字段,比如主键、外键、id等。 8. 查询的SQL字段要用别名,取名参考描述。 9. 一般情况下,如果能限制租户Id(通常为tenantid 字段),则尽量加上WHERE tenantid = {tenant_id}。 """ ) improveSqlChain = improve_sql_prompt | llm def sql_2(state: State): return improveSqlChain.invoke({ "sql": state['sql'], "table_info": state['table_info'], "userInput": state["userInput"], "tenant_id": state['tenant_id'] }).content # ------------------------------------------------------------------------ 路径选择 -------- pathSelectPrompt = PromptTemplate( input_variables=["userInput", "table_info", "sql"], template=""" 你的任务是: 1. 根据用户输入的问题和已知的表结构数据,判断是否能够生成准确的 SQL 查询。 2. 首先仔细阅读以下表结构数据: {table_info} 2. 然后仔细阅读用户输入的问题: {userInput} 3. 请严格遵循以下规则: 只有在能够完全、明确、直接根据表结构生成正确 SQL 时,输出 db。 参考表结构或字段描述中出现的关键词:如果用户问题中出现的关键字段或概念在表结构中找不到明确对应关系,或者问题逻辑无法直接映射到表结构,输出 chat。 不允许假设额外存在的表、字段或数据,也不允许基于常识或推测生成 SQL。 输出必须严格二选一: db → 可以直接生成 SQL。 chat → 无法直接生成 SQL,需要进一步解释或澄清。 回答内容仅限于db或者chat,请勿输出其他内容。 你的回复: """ ) pathSelectChain = pathSelectPrompt | llmThink def decide_source(state: State, max_retry=3): """根据用户输入选择数据来源""" for _ in range(max_retry): choice = pathSelectChain.invoke({ "userInput": state["userInput"], "table_info": state["table_info"], "ai_service": state["ai_service"], "sql": state["sql"] }).content.strip().lower() print("根据用户输入选择数据来源,路径是:", choice) if choice in ["db", "chat"]: state["path"] = choice break else: # 如果连续 max_retry 次都不合法,默认走 chat state["path"] = "chat" return state # ------------------------------------------------------------------------ !普通聊天 -------- noChatPrompt = PromptTemplate( input_variables=["userInput", "ai_service"], template=""" 你的任务是回复用户,告知用户你目前无法处理他们的回复,因为你的业务是特定领域的服务。请仔细阅读以下信息,并按照指示进行回复。 用户的回复: {userInput} 你的业务: {ai_service} 在回复时,请遵循以下指南: 1. 明确告知用户你无法处理当前回复。 2. 提及你的业务是{ai_service}。 3. 引导用户提出与你业务相关的问题。 4. 使用礼貌和友好的语气。 你的回答: """ ) noChatChain = noChatPrompt | llm def chat(state: State): state["reply"] = noChatChain.invoke({ "userInput": state["userInput"], "ai_service": state["ai_service"] }).content print("直接回复") return state # ------------------------------------------------------------------------ 整理结果 -------- summarizePrompt = PromptTemplate( input_variables=["ai_role", "history", "userStr", "table_info"], template=""" 你是主干信息研发的AI助手,你的性格特点为: {ai_role} 用户之前的提问为: {userInput} 当前生成的SQL语句为: {sql} 当前支持的数据库表与字段信息如下: {table_info} 你的核心任务是根据用户之前的提问和当前生成的SQL语句,引导用户理解当前SQL的含义,并询问是否需要修改或完善,同时提供进一步可选的查询示例,引导用户提出更具体的需求。 交流要求如下: - 先明确SQL的用途,再提出引导性问题。 - 回答要简洁、易理解。 - 回复内容不要出现SQL语句,不要对SQL进行解释,只需说,查询结果已生成,然后引导用户进一步提问。 任务流程如下: 1. 询问用户是否需要对当前查询内容进行修改或完善。 2. 提供进一步可选的查询示例,基于当前的数据库表结构,引导用户提出更具体需求。 你的回复: """ ) summarizeChain = summarizePrompt | llm def summarize_ai(state: State): """AI 总结输出""" state["reply"] = summarizeChain.invoke({ "ai_role": state["ai_role"], "sql": state['sql'], "userInput": state['userInput'], "table_info": state['table_info'], }).content return state # ------------------------------------------------------------------------ 构建有向图 -------- workflow = StateGraph(State) workflow.add_node("decide", decide_source) workflow.add_node("sql_1", sql) workflow.add_node("chat", chat) workflow.add_node("summarize", summarize_ai) workflow.set_entry_point("decide") workflow.add_edge("sql_1", "summarize") # 条件边:根据 path 决定走向 workflow.add_conditional_edges( "decide", lambda state: state["path"], # 返回 state["path"] 的值 { "db": "sql_1", "chat": "chat", } ) workflow.add_edge("summarize", END) workflow.add_edge("chat", END) graph = workflow.compile() # 执行函数 def get_db_agent_reply(aiId: str, userInput: str, tenant_id: str, sql: str = "") -> str: json = pgdb.get_ai_personality(aiId) ai_service = json["业务"] ai_role = json["性格"] final_state = graph.invoke({ "ai_service": ai_service, "ai_role": ai_role, "table_info": pgdb.get_available_tables_str(aiId), "tenant_id": tenant_id, "userInput": userInput, "sql": sql, "isFirstGenSQL": sql == "", }) return final_state