246 lines
8.2 KiB
Python
246 lines
8.2 KiB
Python
|
|
from typing import Literal
|
|
from langchain_core.messages import AIMessage
|
|
from langchain_core.runnables import RunnableConfig
|
|
from langgraph.graph import END, START, MessagesState, StateGraph
|
|
from langgraph.prebuilt import ToolNode
|
|
from langchain_community.agent_toolkits import SQLDatabaseToolkit
|
|
from config.llm import llm,llmThink
|
|
from langgraph.graph import StateGraph, END
|
|
from langchain.prompts import PromptTemplate
|
|
from config.llm import llm
|
|
from typing import Annotated
|
|
from typing_extensions import TypedDict
|
|
from langgraph.graph import StateGraph, START, END
|
|
from langgraph.graph.message import add_messages
|
|
from langchain_community.agent_toolkits import create_sql_agent
|
|
from langchain.prompts import PromptTemplate
|
|
from config.llm import llm
|
|
from typing import Annotated
|
|
from langgraph.graph.message import add_messages
|
|
import os
|
|
from langchain_tavily import TavilySearch
|
|
from langgraph.prebuilt import ToolNode, tools_condition
|
|
from llm.chatLLM import get_chat_response
|
|
from typing import TypedDict
|
|
from langgraph.graph import StateGraph, END
|
|
from llm.summarizeLLM import getSummary
|
|
import db.postgres as pgdb
|
|
import db.sqlserver as sqlserver
|
|
from typing import List, Dict
|
|
import db.milvus as milvus
|
|
|
|
|
|
# -------- 定义状态 --------
|
|
class State(TypedDict):
|
|
path: str # 开始聊天选择的路径
|
|
|
|
memory:str # 记忆
|
|
knowledge: str # 知识库内容
|
|
history: str # 聊天历史
|
|
|
|
ai_id : str # AI id
|
|
ai_name:str # AI 名称
|
|
ai_service: str # AI 角色 业务
|
|
ai_role: str # AI 角色 性格特点
|
|
kn_bases: List[str] # AI 所使用的知识库
|
|
|
|
userInput: str # 用户输入
|
|
reply: str # 最终回复
|
|
|
|
# -------- 定义节点 --------
|
|
# ------------------------------------------------------------------------ 向量数据库查询 --------
|
|
|
|
gen_sql_prompt = PromptTemplate(
|
|
input_variables=["userInput"],
|
|
template = """你的任务是对用户输入进行意图分析,并将其分解成方便进行知识向量数据库搜索的关键词。
|
|
以下是用户的输入:
|
|
<用户输入>
|
|
{userInput}
|
|
</用户输入>
|
|
在提取关键词时,请遵循以下方法和要求:
|
|
1. 去除输入中的停用词(如“的”“是”“在”等)、语气词和无实际意义的符号。
|
|
2. 识别输入中的核心概念、实体和关键动作。
|
|
3. 尽量使用简洁、通用的词汇作为关键词。
|
|
4. 确保关键词之间相互独立,不包含其他关键词。
|
|
关键词之间用空格分隔。
|
|
你的回答是:
|
|
"""
|
|
)
|
|
sqlChain = gen_sql_prompt | llm
|
|
def db_search(state: State):
|
|
key_words = sqlChain.invoke({
|
|
"userInput": state['userInput'],
|
|
}).content
|
|
print("关键词是:", key_words)
|
|
knowledge = milvus.get_knowledge_by_key_words(key_words, state['kn_bases'])
|
|
print("知识库内容是:", knowledge)
|
|
state["knowledge"] = knowledge
|
|
ai_ids = [state['ai_id']]
|
|
memory = milvus.get_memory_by_key_words(key_words, ai_ids)
|
|
print("记忆是:", memory)
|
|
state["memory"] = memory
|
|
return state
|
|
|
|
# ------------------------------------------------------------------------ 意图分析 --------
|
|
|
|
pathSelectPrompt = PromptTemplate(
|
|
input_variables=[ "userInput","ai_service","history"],
|
|
template = """
|
|
你是一个意图分类器,负责判断用户提问是否与你的工作相关,进而确定是否需要去查知识库。
|
|
以下是你负责的工作内容:
|
|
<ai_service>
|
|
{ai_service}
|
|
</ai_service>
|
|
这是你们的对话历史:
|
|
<history>
|
|
{history}
|
|
</history>
|
|
用户最新回复是:
|
|
<userInput>
|
|
{userInput}
|
|
</userInput>
|
|
判断规则如下:
|
|
如果用户最新回复与你的负责工作相关,需要去查知识库,输出“kn”;如果不相关,则输出“chat”,不要包含任何标点符号以及空格。
|
|
你生成的结果:
|
|
"""
|
|
)
|
|
pathSelectChain = pathSelectPrompt | llmThink
|
|
def decide_source(state: State, max_retry=3):
|
|
"""根据用户输入选择数据来源"""
|
|
for _ in range(max_retry):
|
|
choice = pathSelectChain.invoke({
|
|
"userInput": state["userInput"],
|
|
"ai_service": state["ai_service"],
|
|
"history": state["history"],
|
|
}).content.strip().lower()
|
|
print("根据用户输入选择数据来源,路径是:", choice)
|
|
if choice in ["kn", "chat"]:
|
|
state["path"] = choice
|
|
break
|
|
else:
|
|
# 如果连续 max_retry 次都不合法,默认走 chat
|
|
state["path"] = "chat"
|
|
return state
|
|
|
|
# ------------------------------------------------------------------------ !普通聊天 --------
|
|
noChatPrompt = PromptTemplate(
|
|
input_variables=[ "ai_name", "ai_service", "ai_role", "history"],
|
|
template = """
|
|
你的名字是:{ai_name},你负责的业务是{ai_service},你具有{ai_role}的性格特点。
|
|
|
|
这是你和用户的对话历史
|
|
<history>
|
|
{history}
|
|
</history>
|
|
在回复用户时,请遵循以下指南:
|
|
1. 回复要与AI角色业务相关,体现AI的专业能力。
|
|
2. 回复内容的语气和风格要符合AI角色性格特点。
|
|
3. 参考聊天历史,使回复具有连贯性和针对性。
|
|
4. 回复要简洁明了,避免冗长和复杂的表述。
|
|
|
|
你的回答:
|
|
"""
|
|
)
|
|
|
|
noChatChain = noChatPrompt | llm
|
|
def chat(state: State):
|
|
state["reply"] = noChatChain.invoke({
|
|
"ai_name": state["ai_name"],
|
|
"ai_service": state["ai_service"],
|
|
"ai_role": state["ai_role"],
|
|
"history": state["history"],
|
|
"userStr": state["userInput"]
|
|
}).content
|
|
print("直接回复")
|
|
return state
|
|
|
|
# ------------------------------------------------------------------------ 整理结果 --------
|
|
|
|
summarizePrompt = PromptTemplate(
|
|
input_variables=["ai_name", "ai_service", "ai_role", "history", "knowledge"],
|
|
template = """
|
|
你的任务是基于给定的AI名称、AI角色业务、AI角色性格特点和聊天历史来回复用户。请仔细阅读以下信息,并按照指示进行回复。
|
|
你的名字是:{ai_name},你负责的业务是{ai_service},你具有{ai_role}的性格特点。
|
|
|
|
这是你和用户的对话历史
|
|
<history>
|
|
{history}
|
|
</history>
|
|
这是给你参考的知识库:
|
|
<knowledge>
|
|
{knowledge}
|
|
</knowledge>
|
|
{memory}
|
|
在回复时,请遵循以下指南:
|
|
1. 回复内容要与你负责的业务相关。
|
|
2. 回复的语气要结合你的性格特点。
|
|
3. 确保回复内容清晰、简洁、有针对性。
|
|
请生成你的回复:
|
|
"""
|
|
)
|
|
summarizeChain = summarizePrompt | llm
|
|
def summarize_ai(state: State):
|
|
"""AI 总结输出"""
|
|
mem = state['memory']
|
|
if mem != "":
|
|
memStr = """
|
|
这是给你参考的相关历史记忆:
|
|
<memory>
|
|
%s
|
|
</memory>
|
|
""" % mem # 这里用 % 把 mem 填进去
|
|
else:
|
|
memStr = "没有记忆内容"
|
|
print("历史记录是:" ,state["history"])
|
|
state["reply"] = summarizeChain.invoke({
|
|
"ai_role":state["ai_role"],
|
|
"ai_name":state["ai_name"],
|
|
"history":state["history"],
|
|
"ai_service":state['ai_service'],
|
|
"knowledge": state["knowledge"],
|
|
"memory": memStr,
|
|
}).content
|
|
return state
|
|
|
|
# ------------------------------------------------------------------------ 构建有向图 --------
|
|
workflow = StateGraph(State)
|
|
workflow.add_node("decide", decide_source)
|
|
workflow.add_node("db_search", db_search)
|
|
workflow.add_node("chat", chat)
|
|
workflow.add_node("summarize", summarize_ai)
|
|
workflow.set_entry_point("decide")
|
|
# 条件边:根据 path 决定走向
|
|
workflow.add_conditional_edges(
|
|
"decide",
|
|
lambda state: state["path"], # 返回 state["path"] 的值
|
|
{
|
|
"kn": "db_search",
|
|
"chat": "chat",
|
|
}
|
|
)
|
|
workflow.add_edge("db_search", "summarize")
|
|
workflow.add_edge("summarize", END)
|
|
workflow.add_edge("chat", END)
|
|
graph = workflow.compile()
|
|
|
|
# 执行函数
|
|
def get_service_agent_reply(aiId:str, userInput: str,history:str, kn_bases:List[str]) :
|
|
json = pgdb.get_ai_personality(aiId)
|
|
ai_service = json["业务"]
|
|
ai_role = json["性格"]
|
|
ai_name = json["名字"]
|
|
print("AI Name:", ai_name)
|
|
print("AI Service:", ai_service)
|
|
|
|
final_state = graph.invoke({
|
|
"ai_service":ai_service,
|
|
"ai_role":ai_role,
|
|
"ai_name":ai_name,
|
|
"history":history,
|
|
"kn_bases":kn_bases,
|
|
"table_info": pgdb.get_available_tables_str(aiId),
|
|
"userInput": userInput,
|
|
"ai_id": aiId,
|
|
})
|
|
return final_state["reply"] |