139 lines
5.3 KiB
Python
139 lines
5.3 KiB
Python
|
|
from langchain.prompts import PromptTemplate
|
|
from config.llm import llm
|
|
from config.ssDb import ssDB
|
|
from typing import Annotated
|
|
from typing_extensions import TypedDict
|
|
from langgraph.graph import StateGraph, START, END
|
|
from langgraph.graph.message import add_messages
|
|
from langchain_community.agent_toolkits import create_sql_agent
|
|
from langchain.prompts import PromptTemplate
|
|
from config.llm import llm
|
|
from config.ssDb import ssDB
|
|
from typing import Annotated
|
|
from langgraph.graph.message import add_messages
|
|
import os
|
|
from langchain_tavily import TavilySearch
|
|
from langgraph.prebuilt import ToolNode, tools_condition
|
|
from llm.chatLLM import get_chat_response
|
|
from typing import TypedDict
|
|
from langgraph.graph import StateGraph, END
|
|
from llm.summarizeLLM import getSummary
|
|
|
|
|
|
# -------- 定义状态 --------
|
|
class State(TypedDict):
|
|
userInput: str # 用户输入
|
|
source: str # 选择的数据来源:web 或 db 或 chat
|
|
infomation: str # 查询到的内容
|
|
aiRole: str # AI 角色
|
|
history: str # 聊天历史
|
|
reply: str # 最终回复
|
|
|
|
# -------- 定义节点 --------
|
|
# ------------------------------------------------------------------------ 路径选择 --------
|
|
|
|
pathSelectPrompt = PromptTemplate(
|
|
input_variables=["aiRole", "history", "userStr", "infomation"],
|
|
template = """
|
|
你是主干信息科技有限公司的业务员,是一家蚕桑服务公司,现在需要根据用户输入来判断应该使用哪种方式来回答用户的问题。
|
|
你有三种选择:
|
|
1. 如果用户的问题涉及最新的信息,比如新闻、事件、天气等涉及时间的内容时,请选择 "web
|
|
2. 如果用户的问题涉及具体的蚕桑业务(例如询问农户、订单、订种、租户)的数据库查询需求,请选择 "db"
|
|
3. 如果用户的问题是一般性的聊天或咨询,请选择 "chat"
|
|
请只返回 "web"、"db" 或 "chat" 之一,且不要添加任何其他解释。
|
|
用户最新输入:
|
|
{userStr}
|
|
请做出你的选择:
|
|
"""
|
|
)
|
|
pathSelectChain = pathSelectPrompt | llm
|
|
|
|
def decide_source(state: State, max_retry=3):
|
|
print("根据用户输入选择数据来源,用户输入:", state["userInput"])
|
|
"""根据用户输入选择数据来源"""
|
|
for _ in range(max_retry):
|
|
choice = pathSelectChain.invoke({
|
|
"aiRole": state["aiRole"],
|
|
"history": state["history"],
|
|
"userStr": state["userInput"],
|
|
}).content.strip().lower()
|
|
if choice in ["web", "db", "chat"]:
|
|
state["source"] = choice
|
|
break
|
|
else:
|
|
# 如果连续 max_retry 次都不合法,默认走 chat
|
|
state["source"] = "chat"
|
|
print("选择的数据来源是:", state["source"])
|
|
return state
|
|
|
|
|
|
# ------------------------------------------------------------------------ 上网查询 --------
|
|
os.environ["TAVILY_API_KEY"] = "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"
|
|
tool = TavilySearch(max_results=2)
|
|
|
|
def fetch_web(state: State):
|
|
result = tool.invoke(state["userInput"])
|
|
state["infomation"] = result.get("content") or result
|
|
print("调用了联网工具,结果是:", state["infomation"])
|
|
return state
|
|
|
|
# ------------------------------------------------------------------------ 数据库查询 --------
|
|
agent = create_sql_agent(
|
|
llm=llm,
|
|
db=ssDB,
|
|
agent_type="tool-calling",
|
|
verbose=True
|
|
)
|
|
def fetch_db(state: State):
|
|
state["infomation"] = agent.invoke({"input": state["userInput"]})["output"]
|
|
print("调用了数据库工具,结果是:", state["infomation"])
|
|
return state
|
|
|
|
# ------------------------------------------------------------------------ 整理结果 --------
|
|
def summarize_ai(state: State):
|
|
"""AI 总结输出"""
|
|
state["reply"] = getSummary(aiRole=state["aiRole"], history=state["history"], userInput= state["userInput"], infomation= state["infomation"])
|
|
return state
|
|
|
|
# ------------------------------------------------------------------------ 普通聊天 --------
|
|
def chat(state: State):
|
|
state["reply"] = get_chat_response(aiRole=state["aiRole"],history=state["history"], userInput= state["userInput"]).content
|
|
print("直接回复")
|
|
return state
|
|
|
|
# ------------------------------------------------------------------------ 构建有向图 --------
|
|
workflow = StateGraph(State)
|
|
workflow.add_node("decide", decide_source)
|
|
workflow.add_node("fetch_web", fetch_web)
|
|
workflow.add_node("fetch_db", fetch_db)
|
|
workflow.add_node("chat", chat)
|
|
workflow.add_node("summarize", summarize_ai)
|
|
workflow.set_entry_point("decide")
|
|
|
|
# 两条路径最后都汇合到 summarize
|
|
workflow.add_edge(START, "decide")
|
|
workflow.add_edge("fetch_web", "summarize")
|
|
workflow.add_edge("fetch_db", "summarize")
|
|
# 条件边:根据 source 决定走向
|
|
workflow.add_conditional_edges(
|
|
"decide",
|
|
lambda state: state["source"], # 返回 state["source"] 的值
|
|
{
|
|
"web": "fetch_web",
|
|
"chat": "chat",
|
|
"db": "fetch_db"
|
|
}
|
|
)
|
|
workflow.add_edge("summarize", END)
|
|
workflow.add_edge("chat", END)
|
|
graph = workflow.compile()
|
|
|
|
# 执行函数
|
|
def get_graph_output(aiRole:str,history: str, userInput: str) -> str:
|
|
final_state = graph.invoke({
|
|
"aiRole":aiRole,
|
|
"history": history,
|
|
"userInput": userInput,
|
|
})
|
|
return final_state["reply"] |