更新python后端
This commit is contained in:
@@ -0,0 +1,139 @@
|
||||
|
||||
from langchain.prompts import PromptTemplate
|
||||
from config.llm import llm
|
||||
from config.ssDb import ssDBLC
|
||||
from typing import Annotated
|
||||
from typing_extensions import TypedDict
|
||||
from langgraph.graph import StateGraph, START, END
|
||||
from langgraph.graph.message import add_messages
|
||||
from langchain_community.agent_toolkits import create_sql_agent
|
||||
from langchain.prompts import PromptTemplate
|
||||
from config.llm import llm
|
||||
from config.ssDb import ssDBLC
|
||||
from typing import Annotated
|
||||
from langgraph.graph.message import add_messages
|
||||
import os
|
||||
from langchain_tavily import TavilySearch
|
||||
from langgraph.prebuilt import ToolNode, tools_condition
|
||||
from llm.chatLLM import get_chat_response
|
||||
from typing import TypedDict
|
||||
from langgraph.graph import StateGraph, END
|
||||
from llm.summarizeLLM import getSummary
|
||||
|
||||
|
||||
# -------- 定义状态 --------
|
||||
class State(TypedDict):
|
||||
userInput: str # 用户输入
|
||||
source: str # 选择的数据来源:web 或 db 或 chat
|
||||
infomation: str # 查询到的内容
|
||||
aiRole: str # AI 角色
|
||||
history: str # 聊天历史
|
||||
reply: str # 最终回复
|
||||
|
||||
# -------- 定义节点 --------
|
||||
# ------------------------------------------------------------------------ 路径选择 --------
|
||||
|
||||
pathSelectPrompt = PromptTemplate(
|
||||
input_variables=["aiRole", "history", "userStr", "infomation"],
|
||||
template = """
|
||||
你是主干信息科技有限公司的业务员,是一家蚕桑服务公司,现在需要根据用户输入来判断应该使用哪种方式来回答用户的问题。
|
||||
你有三种选择:
|
||||
1. 如果用户的问题涉及最新的信息,比如新闻、事件、天气等涉及时间的内容时,请选择 "web
|
||||
2. 如果用户的问题涉及具体的蚕桑业务(例如询问农户、订单、订种、租户)的数据库查询需求,请选择 "db"
|
||||
3. 如果用户的问题是一般性的聊天或咨询,请选择 "chat"
|
||||
请只返回 "web"、"db" 或 "chat" 之一,且不要添加任何其他解释。
|
||||
用户最新输入:
|
||||
{userStr}
|
||||
请做出你的选择:
|
||||
"""
|
||||
)
|
||||
pathSelectChain = pathSelectPrompt | llm
|
||||
|
||||
def decide_source(state: State, max_retry=3):
|
||||
print("根据用户输入选择数据来源,用户输入:", state["userInput"])
|
||||
"""根据用户输入选择数据来源"""
|
||||
for _ in range(max_retry):
|
||||
choice = pathSelectChain.invoke({
|
||||
"aiRole": state["aiRole"],
|
||||
"history": state["history"],
|
||||
"userStr": state["userInput"],
|
||||
}).content.strip().lower()
|
||||
if choice in ["web", "db", "chat"]:
|
||||
state["source"] = choice
|
||||
break
|
||||
else:
|
||||
# 如果连续 max_retry 次都不合法,默认走 chat
|
||||
state["source"] = "chat"
|
||||
print("选择的数据来源是:", state["source"])
|
||||
return state
|
||||
|
||||
|
||||
# ------------------------------------------------------------------------ 上网查询 --------
|
||||
os.environ["TAVILY_API_KEY"] = "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"
|
||||
tool = TavilySearch(max_results=2)
|
||||
|
||||
def fetch_web(state: State):
|
||||
result = tool.invoke(state["userInput"])
|
||||
state["infomation"] = result.get("content") or result
|
||||
print("调用了联网工具,结果是:", state["infomation"])
|
||||
return state
|
||||
|
||||
# ------------------------------------------------------------------------ 数据库查询 --------
|
||||
agent = create_sql_agent(
|
||||
llm=llm,
|
||||
db=ssDBLC,
|
||||
agent_type="tool-calling",
|
||||
verbose=True
|
||||
)
|
||||
def fetch_db(state: State):
|
||||
state["infomation"] = agent.invoke({"input": state["userInput"]})["output"]
|
||||
print("调用了数据库工具,结果是:", state["infomation"])
|
||||
return state
|
||||
|
||||
# ------------------------------------------------------------------------ 整理结果 --------
|
||||
def summarize_ai(state: State):
|
||||
"""AI 总结输出"""
|
||||
state["reply"] = getSummary(aiRole=state["aiRole"], history=state["history"], userInput= state["userInput"], infomation= state["infomation"])
|
||||
return state
|
||||
|
||||
# ------------------------------------------------------------------------ 普通聊天 --------
|
||||
def chat(state: State):
|
||||
state["reply"] = get_chat_response(aiRole=state["aiRole"],history=state["history"], userInput= state["userInput"]).content
|
||||
print("直接回复")
|
||||
return state
|
||||
|
||||
# ------------------------------------------------------------------------ 构建有向图 --------
|
||||
workflow = StateGraph(State)
|
||||
workflow.add_node("decide", decide_source)
|
||||
workflow.add_node("fetch_web", fetch_web)
|
||||
workflow.add_node("fetch_db", fetch_db)
|
||||
workflow.add_node("chat", chat)
|
||||
workflow.add_node("summarize", summarize_ai)
|
||||
workflow.set_entry_point("decide")
|
||||
|
||||
# 两条路径最后都汇合到 summarize
|
||||
workflow.add_edge(START, "decide")
|
||||
workflow.add_edge("fetch_web", "summarize")
|
||||
workflow.add_edge("fetch_db", "summarize")
|
||||
# 条件边:根据 source 决定走向
|
||||
workflow.add_conditional_edges(
|
||||
"decide",
|
||||
lambda state: state["source"], # 返回 state["source"] 的值
|
||||
{
|
||||
"web": "fetch_web",
|
||||
"chat": "chat",
|
||||
"db": "fetch_db"
|
||||
}
|
||||
)
|
||||
workflow.add_edge("summarize", END)
|
||||
workflow.add_edge("chat", END)
|
||||
graph = workflow.compile()
|
||||
|
||||
# 执行函数
|
||||
def get_graph_output(aiRole:str,history: str, userInput: str) -> str:
|
||||
final_state = graph.invoke({
|
||||
"aiRole":aiRole,
|
||||
"history": history,
|
||||
"userInput": userInput,
|
||||
})
|
||||
return final_state["reply"]
|
||||
@@ -0,0 +1,346 @@
|
||||
from typing import Literal
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.graph import END, START, MessagesState, StateGraph
|
||||
from langgraph.prebuilt import ToolNode
|
||||
from langchain_community.agent_toolkits import SQLDatabaseToolkit
|
||||
from config.llm import llm, llmThink
|
||||
from langgraph.graph import StateGraph, END
|
||||
from langchain.prompts import PromptTemplate
|
||||
from config.llm import llm
|
||||
from typing import Annotated
|
||||
from typing_extensions import TypedDict
|
||||
from langgraph.graph import StateGraph, START, END
|
||||
from langgraph.graph.message import add_messages
|
||||
from langchain_community.agent_toolkits import create_sql_agent
|
||||
from langchain.prompts import PromptTemplate
|
||||
from config.llm import llm
|
||||
from typing import Annotated
|
||||
from langgraph.graph.message import add_messages
|
||||
import os
|
||||
from langchain_tavily import TavilySearch
|
||||
from langgraph.prebuilt import ToolNode, tools_condition
|
||||
from llm.chatLLM import get_chat_response
|
||||
from typing import TypedDict
|
||||
from langgraph.graph import StateGraph, END
|
||||
from llm.summarizeLLM import getSummary
|
||||
import db.postgres as pgdb
|
||||
import db.sqlserver as sqlserver
|
||||
|
||||
|
||||
# -------- 定义状态 --------
|
||||
class State(TypedDict):
|
||||
userInput: str # 用户输入
|
||||
path: str # 开始聊天选择的路径
|
||||
table_info: str # 可用表信息
|
||||
isFirstGenSQL: bool # 是否第一次生成SQL
|
||||
sql: str # 当前操作的SQL
|
||||
|
||||
ai_service: str # AI 角色 业务
|
||||
ai_role: str # AI 角色 性格特点
|
||||
tenant_id: str # 租户ID
|
||||
|
||||
history: str # 聊天历史
|
||||
reply: str # 最终回复
|
||||
|
||||
|
||||
# -------- 定义节点 --------
|
||||
# ------------------------------------------------------------------------ 数据库查询 --------
|
||||
|
||||
gen_sql_prompt = PromptTemplate(
|
||||
input_variables=["table_info", "userInput"],
|
||||
template="""
|
||||
# 角色
|
||||
你是一个企业 SQL Server 数据库 SQL 生成助手,负责根据用户问题生成相应的 SQL 语句。
|
||||
|
||||
# 已知信息
|
||||
可访问的表和字段:{table_info}
|
||||
|
||||
# 任务要求
|
||||
1. 根据用户提出的问题“{userInput}”生成 SQL 语句。
|
||||
2. 只能使用已知的表和字段。
|
||||
3. 输出完整可执行的 SQL 语句,不包含多余文字。
|
||||
4. 若 SQL 语句返回列表数据,需限制返回数量,最大为 15 条,使用 SQL Server 语法(TOP 15 或 OFFSET FETCH)。
|
||||
5. 若 SQL 是聚合查询(如 COUNT、SUM 等),无需限制行数。
|
||||
6. 在生成 SQL 时,如果需要根据身份证计算年龄,请使用 SQL Server 标准日期格式 SUBSTRING(idcard, 7, 8) 和 CONVERT(..., 112),不要使用拼接 / 或非标准日期格式。
|
||||
7. 通常来说,不查询对用户来说意义不大的字段,比如主键、外键、id等。
|
||||
8. 查询的SQL字段要用别名,取名参考描述。
|
||||
9. 一般情况下,如果能限制租户Id(通常为tenantid 字段),则尽量限制租户id = {tenant_id}。
|
||||
|
||||
|
||||
请直接输出完整可执行的 SQL 语句,不要任何其他文字或格式化,例如反引号或 ```sql。
|
||||
"""
|
||||
)
|
||||
sqlChain = gen_sql_prompt | llm
|
||||
|
||||
fix_prompt = PromptTemplate(
|
||||
input_variables=["sql", "error_msg", "table_info", "tenant_id"],
|
||||
template="""
|
||||
# 系统角色
|
||||
你是一位专业的 SQL Server的SQL语句纠错专家,擅长识别 SQL 语句中的语法错误和字段引用错误,并能对其进行修正。
|
||||
|
||||
# 任务
|
||||
根据提供的原始 SQL 语句、执行报错信息以及可用表和字段信息,修正 SQL 语句,确保其语法正确且引用的字段存在。
|
||||
|
||||
# 输入信息
|
||||
- 原始 SQL: {sql}
|
||||
- 执行报错: {error_msg}
|
||||
- 可用表和字段: {table_info}
|
||||
|
||||
# 输出要求
|
||||
只返回修正后的 SQL 语句,不包含任何额外的解释或说明。
|
||||
"""
|
||||
)
|
||||
fixSQLChain = fix_prompt | llm
|
||||
|
||||
|
||||
def sql(state: State):
|
||||
if state["isFirstGenSQL"]:
|
||||
state['sql'] = sql_1(state)
|
||||
else:
|
||||
state['sql'] = sql_2(state)
|
||||
for attempt in range(2):
|
||||
try:
|
||||
# 执行 SQL
|
||||
result = sqlserver.executeSQL(state['sql'])
|
||||
state['sql_result'] = result
|
||||
# print("SQL 执行成功,结果:", result)
|
||||
break
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
print(f"SQL 执行出错: {error_msg}")
|
||||
# 调用 LLM 修正 SQL
|
||||
state['sql'] = fixSQLChain.invoke({
|
||||
"sql": state['sql'],
|
||||
"error_msg": error_msg,
|
||||
"table_info": state['table_info'],
|
||||
"tenant_id": state['tenant_id']
|
||||
}
|
||||
).content
|
||||
# print(f"LLM 生成修正 SQL: {state['sql']}")
|
||||
else:
|
||||
raise RuntimeError(f"SQL 多次纠错失败,最后 SQL: {state['sql']}")
|
||||
return state
|
||||
|
||||
|
||||
def sql_1(state: State):
|
||||
return sqlChain.invoke({
|
||||
"table_info": state['table_info'],
|
||||
"userInput": state["userInput"],
|
||||
"tenant_id": state['tenant_id']
|
||||
}).content
|
||||
|
||||
|
||||
improve_sql_prompt = PromptTemplate(
|
||||
input_variables=["table_info", "userInput", "tenant_id"],
|
||||
template="""
|
||||
# 角色
|
||||
你是一个企业 SQL Server 数据库 SQL 生成助手,负责根据用户问题改进相应的 SQL 语句。
|
||||
|
||||
# 已知信息
|
||||
当前 SQL 语句: {sql}
|
||||
可访问的表和字段: {table_info}
|
||||
|
||||
# 任务要求
|
||||
1. 根据用户提出的问题“{userInput}”以及当前的 SQL 语句进行改进。
|
||||
2. 只能使用已知的表和字段。
|
||||
3. 输出完整可执行的 SQL 语句,不包含多余文字。
|
||||
4. 若 SQL 语句返回列表数据,需限制返回数量,最大为 15 条,使用 SQL Server 语法(TOP 15 或 OFFSET FETCH)。
|
||||
5. 若 SQL 是聚合查询(如 COUNT、SUM 等),无需限制行数。
|
||||
6. 在生成 SQL 时,如果需要根据身份证计算年龄,请使用 SQL Server 标准日期格式 SUBSTRING(idcard, 7, 8) 和 CONVERT(..., 112),不要使用拼接 / 或非标准日期格式。
|
||||
7. 通常来说,不查询对用户来说意义不大的字段,比如主键、外键、id等。
|
||||
8. 查询的SQL字段要用别名,取名参考描述。
|
||||
9. 一般情况下,如果能限制租户Id(通常为tenantid 字段),则尽量加上WHERE tenantid = {tenant_id}。
|
||||
"""
|
||||
)
|
||||
improveSqlChain = improve_sql_prompt | llm
|
||||
|
||||
|
||||
def sql_2(state: State):
|
||||
return improveSqlChain.invoke({
|
||||
"sql": state['sql'],
|
||||
"table_info": state['table_info'],
|
||||
"userInput": state["userInput"],
|
||||
"tenant_id": state['tenant_id']
|
||||
}).content
|
||||
|
||||
|
||||
# ------------------------------------------------------------------------ 路径选择 --------
|
||||
|
||||
pathSelectPrompt = PromptTemplate(
|
||||
input_variables=["userInput", "table_info", "sql"],
|
||||
template="""
|
||||
你的任务是:
|
||||
|
||||
1. 根据用户输入的问题和已知的表结构数据,判断是否能够生成准确的 SQL 查询。
|
||||
|
||||
2. 首先仔细阅读以下表结构数据:
|
||||
<table_info>
|
||||
{table_info}
|
||||
</table_info>
|
||||
|
||||
2. 然后仔细阅读用户输入的问题:
|
||||
<userInput>
|
||||
{userInput}
|
||||
</userInput>
|
||||
|
||||
3. 请严格遵循以下规则:
|
||||
|
||||
只有在能够完全、明确、直接根据表结构生成正确 SQL 时,输出 db。
|
||||
|
||||
参考表结构或字段描述中出现的关键词:如果用户问题中出现的关键字段或概念在表结构中找不到明确对应关系,或者问题逻辑无法直接映射到表结构,输出 chat。
|
||||
|
||||
不允许假设额外存在的表、字段或数据,也不允许基于常识或推测生成 SQL。
|
||||
|
||||
输出必须严格二选一:
|
||||
|
||||
db → 可以直接生成 SQL。
|
||||
chat → 无法直接生成 SQL,需要进一步解释或澄清。
|
||||
|
||||
回答内容仅限于db或者chat,请勿输出其他内容。
|
||||
你的回复:
|
||||
"""
|
||||
)
|
||||
pathSelectChain = pathSelectPrompt | llmThink
|
||||
|
||||
|
||||
def decide_source(state: State, max_retry=3):
|
||||
"""根据用户输入选择数据来源"""
|
||||
for _ in range(max_retry):
|
||||
choice = pathSelectChain.invoke({
|
||||
"userInput": state["userInput"],
|
||||
"table_info": state["table_info"],
|
||||
"ai_service": state["ai_service"],
|
||||
"sql": state["sql"]
|
||||
}).content.strip().lower()
|
||||
print("根据用户输入选择数据来源,路径是:", choice)
|
||||
if choice in ["db", "chat"]:
|
||||
state["path"] = choice
|
||||
break
|
||||
else:
|
||||
# 如果连续 max_retry 次都不合法,默认走 chat
|
||||
state["path"] = "chat"
|
||||
return state
|
||||
|
||||
|
||||
# ------------------------------------------------------------------------ !普通聊天 --------
|
||||
noChatPrompt = PromptTemplate(
|
||||
input_variables=["userInput", "ai_service"],
|
||||
template="""
|
||||
你的任务是回复用户,告知用户你目前无法处理他们的回复,因为你的业务是特定领域的服务。请仔细阅读以下信息,并按照指示进行回复。
|
||||
用户的回复:
|
||||
<userInput>
|
||||
{userInput}
|
||||
</userInput>
|
||||
你的业务:
|
||||
<ai_service>
|
||||
{ai_service}
|
||||
</ai_service>
|
||||
在回复时,请遵循以下指南:
|
||||
1. 明确告知用户你无法处理当前回复。
|
||||
2. 提及你的业务是{ai_service}。
|
||||
3. 引导用户提出与你业务相关的问题。
|
||||
4. 使用礼貌和友好的语气。
|
||||
你的回答:
|
||||
"""
|
||||
)
|
||||
|
||||
noChatChain = noChatPrompt | llm
|
||||
|
||||
|
||||
def chat(state: State):
|
||||
state["reply"] = noChatChain.invoke({
|
||||
"userInput": state["userInput"],
|
||||
"ai_service": state["ai_service"]
|
||||
}).content
|
||||
print("直接回复")
|
||||
return state
|
||||
|
||||
|
||||
# ------------------------------------------------------------------------ 整理结果 --------
|
||||
|
||||
summarizePrompt = PromptTemplate(
|
||||
input_variables=["ai_role", "history", "userStr", "table_info"],
|
||||
template="""
|
||||
你是主干信息研发的AI助手,你的性格特点为:
|
||||
<ai_role>
|
||||
{ai_role}
|
||||
</ai_role>
|
||||
用户之前的提问为:
|
||||
<userInput>
|
||||
{userInput}
|
||||
</userInput>
|
||||
当前生成的SQL语句为:
|
||||
<sql>
|
||||
{sql}
|
||||
</sql>
|
||||
当前支持的数据库表与字段信息如下:
|
||||
<table_info>
|
||||
{table_info}
|
||||
</table_info>
|
||||
|
||||
你的核心任务是根据用户之前的提问和当前生成的SQL语句,引导用户理解当前SQL的含义,并询问是否需要修改或完善,同时提供进一步可选的查询示例,引导用户提出更具体的需求。
|
||||
|
||||
交流要求如下:
|
||||
- 先明确SQL的用途,再提出引导性问题。
|
||||
- 回答要简洁、易理解。
|
||||
- 回复内容不要出现SQL语句,不要对SQL进行解释,只需说,查询结果已生成,然后引导用户进一步提问。
|
||||
|
||||
任务流程如下:
|
||||
1. 询问用户是否需要对当前查询内容进行修改或完善。
|
||||
2. 提供进一步可选的查询示例,基于当前的数据库表结构,引导用户提出更具体需求。
|
||||
|
||||
你的回复:
|
||||
"""
|
||||
)
|
||||
summarizeChain = summarizePrompt | llm
|
||||
|
||||
|
||||
def summarize_ai(state: State):
|
||||
"""AI 总结输出"""
|
||||
state["reply"] = summarizeChain.invoke({
|
||||
"ai_role": state["ai_role"],
|
||||
"sql": state['sql'],
|
||||
"userInput": state['userInput'],
|
||||
"table_info": state['table_info'],
|
||||
}).content
|
||||
return state
|
||||
|
||||
|
||||
# ------------------------------------------------------------------------ 构建有向图 --------
|
||||
workflow = StateGraph(State)
|
||||
workflow.add_node("decide", decide_source)
|
||||
workflow.add_node("sql_1", sql)
|
||||
workflow.add_node("chat", chat)
|
||||
workflow.add_node("summarize", summarize_ai)
|
||||
workflow.set_entry_point("decide")
|
||||
workflow.add_edge("sql_1", "summarize")
|
||||
# 条件边:根据 path 决定走向
|
||||
workflow.add_conditional_edges(
|
||||
"decide",
|
||||
lambda state: state["path"], # 返回 state["path"] 的值
|
||||
{
|
||||
"db": "sql_1",
|
||||
"chat": "chat",
|
||||
}
|
||||
)
|
||||
workflow.add_edge("summarize", END)
|
||||
workflow.add_edge("chat", END)
|
||||
graph = workflow.compile()
|
||||
|
||||
|
||||
# 执行函数
|
||||
def get_db_agent_reply(aiId: str, userInput: str, tenant_id: str, sql: str = "") -> str:
|
||||
json = pgdb.get_ai_personality(aiId)
|
||||
ai_service = json["业务"]
|
||||
ai_role = json["性格"]
|
||||
final_state = graph.invoke({
|
||||
"ai_service": ai_service,
|
||||
"ai_role": ai_role,
|
||||
"table_info": pgdb.get_available_tables_str(aiId),
|
||||
"tenant_id": tenant_id,
|
||||
"userInput": userInput,
|
||||
"sql": sql,
|
||||
"isFirstGenSQL": sql == "",
|
||||
})
|
||||
return final_state
|
||||
@@ -0,0 +1,246 @@
|
||||
|
||||
from typing import Literal
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.graph import END, START, MessagesState, StateGraph
|
||||
from langgraph.prebuilt import ToolNode
|
||||
from langchain_community.agent_toolkits import SQLDatabaseToolkit
|
||||
from config.llm import llm,llmThink
|
||||
from langgraph.graph import StateGraph, END
|
||||
from langchain.prompts import PromptTemplate
|
||||
from config.llm import llm
|
||||
from typing import Annotated
|
||||
from typing_extensions import TypedDict
|
||||
from langgraph.graph import StateGraph, START, END
|
||||
from langgraph.graph.message import add_messages
|
||||
from langchain_community.agent_toolkits import create_sql_agent
|
||||
from langchain.prompts import PromptTemplate
|
||||
from config.llm import llm
|
||||
from typing import Annotated
|
||||
from langgraph.graph.message import add_messages
|
||||
import os
|
||||
from langchain_tavily import TavilySearch
|
||||
from langgraph.prebuilt import ToolNode, tools_condition
|
||||
from llm.chatLLM import get_chat_response
|
||||
from typing import TypedDict
|
||||
from langgraph.graph import StateGraph, END
|
||||
from llm.summarizeLLM import getSummary
|
||||
import db.postgres as pgdb
|
||||
import db.sqlserver as sqlserver
|
||||
from typing import List, Dict
|
||||
import db.milvus as milvus
|
||||
|
||||
|
||||
# -------- 定义状态 --------
|
||||
class State(TypedDict):
|
||||
path: str # 开始聊天选择的路径
|
||||
|
||||
memory:str # 记忆
|
||||
knowledge: str # 知识库内容
|
||||
history: str # 聊天历史
|
||||
|
||||
ai_id : str # AI id
|
||||
ai_name:str # AI 名称
|
||||
ai_service: str # AI 角色 业务
|
||||
ai_role: str # AI 角色 性格特点
|
||||
kn_bases: List[str] # AI 所使用的知识库
|
||||
|
||||
userInput: str # 用户输入
|
||||
reply: str # 最终回复
|
||||
|
||||
# -------- 定义节点 --------
|
||||
# ------------------------------------------------------------------------ 向量数据库查询 --------
|
||||
|
||||
gen_sql_prompt = PromptTemplate(
|
||||
input_variables=["userInput"],
|
||||
template = """你的任务是对用户输入进行意图分析,并将其分解成方便进行知识向量数据库搜索的关键词。
|
||||
以下是用户的输入:
|
||||
<用户输入>
|
||||
{userInput}
|
||||
</用户输入>
|
||||
在提取关键词时,请遵循以下方法和要求:
|
||||
1. 去除输入中的停用词(如“的”“是”“在”等)、语气词和无实际意义的符号。
|
||||
2. 识别输入中的核心概念、实体和关键动作。
|
||||
3. 尽量使用简洁、通用的词汇作为关键词。
|
||||
4. 确保关键词之间相互独立,不包含其他关键词。
|
||||
关键词之间用空格分隔。
|
||||
你的回答是:
|
||||
"""
|
||||
)
|
||||
sqlChain = gen_sql_prompt | llm
|
||||
def db_search(state: State):
|
||||
key_words = sqlChain.invoke({
|
||||
"userInput": state['userInput'],
|
||||
}).content
|
||||
print("关键词是:", key_words)
|
||||
knowledge = milvus.get_knowledge_by_key_words(key_words, state['kn_bases'])
|
||||
print("知识库内容是:", knowledge)
|
||||
state["knowledge"] = knowledge
|
||||
ai_ids = [state['ai_id']]
|
||||
memory = milvus.get_memory_by_key_words(key_words, ai_ids)
|
||||
print("记忆是:", memory)
|
||||
state["memory"] = memory
|
||||
return state
|
||||
|
||||
# ------------------------------------------------------------------------ 意图分析 --------
|
||||
|
||||
pathSelectPrompt = PromptTemplate(
|
||||
input_variables=[ "userInput","ai_service","history"],
|
||||
template = """
|
||||
你是一个意图分类器,负责判断用户提问是否与你的工作相关,进而确定是否需要去查知识库。
|
||||
以下是你负责的工作内容:
|
||||
<ai_service>
|
||||
{ai_service}
|
||||
</ai_service>
|
||||
这是你们的对话历史:
|
||||
<history>
|
||||
{history}
|
||||
</history>
|
||||
用户最新回复是:
|
||||
<userInput>
|
||||
{userInput}
|
||||
</userInput>
|
||||
判断规则如下:
|
||||
如果用户最新回复与你的负责工作相关,需要去查知识库,输出“kn”;如果不相关,则输出“chat”,不要包含任何标点符号以及空格。
|
||||
你生成的结果:
|
||||
"""
|
||||
)
|
||||
pathSelectChain = pathSelectPrompt | llmThink
|
||||
def decide_source(state: State, max_retry=3):
|
||||
"""根据用户输入选择数据来源"""
|
||||
for _ in range(max_retry):
|
||||
choice = pathSelectChain.invoke({
|
||||
"userInput": state["userInput"],
|
||||
"ai_service": state["ai_service"],
|
||||
"history": state["history"],
|
||||
}).content.strip().lower()
|
||||
print("根据用户输入选择数据来源,路径是:", choice)
|
||||
if choice in ["kn", "chat"]:
|
||||
state["path"] = choice
|
||||
break
|
||||
else:
|
||||
# 如果连续 max_retry 次都不合法,默认走 chat
|
||||
state["path"] = "chat"
|
||||
return state
|
||||
|
||||
# ------------------------------------------------------------------------ !普通聊天 --------
|
||||
noChatPrompt = PromptTemplate(
|
||||
input_variables=[ "ai_name", "ai_service", "ai_role", "history"],
|
||||
template = """
|
||||
你的名字是:{ai_name},你负责的业务是{ai_service},你具有{ai_role}的性格特点。
|
||||
|
||||
这是你和用户的对话历史
|
||||
<history>
|
||||
{history}
|
||||
</history>
|
||||
在回复用户时,请遵循以下指南:
|
||||
1. 回复要与AI角色业务相关,体现AI的专业能力。
|
||||
2. 回复内容的语气和风格要符合AI角色性格特点。
|
||||
3. 参考聊天历史,使回复具有连贯性和针对性。
|
||||
4. 回复要简洁明了,避免冗长和复杂的表述。
|
||||
|
||||
你的回答:
|
||||
"""
|
||||
)
|
||||
|
||||
noChatChain = noChatPrompt | llm
|
||||
def chat(state: State):
|
||||
state["reply"] = noChatChain.invoke({
|
||||
"ai_name": state["ai_name"],
|
||||
"ai_service": state["ai_service"],
|
||||
"ai_role": state["ai_role"],
|
||||
"history": state["history"],
|
||||
"userStr": state["userInput"]
|
||||
}).content
|
||||
print("直接回复")
|
||||
return state
|
||||
|
||||
# ------------------------------------------------------------------------ 整理结果 --------
|
||||
|
||||
summarizePrompt = PromptTemplate(
|
||||
input_variables=["ai_name", "ai_service", "ai_role", "history", "knowledge"],
|
||||
template = """
|
||||
你的任务是基于给定的AI名称、AI角色业务、AI角色性格特点和聊天历史来回复用户。请仔细阅读以下信息,并按照指示进行回复。
|
||||
你的名字是:{ai_name},你负责的业务是{ai_service},你具有{ai_role}的性格特点。
|
||||
|
||||
这是你和用户的对话历史
|
||||
<history>
|
||||
{history}
|
||||
</history>
|
||||
这是给你参考的知识库:
|
||||
<knowledge>
|
||||
{knowledge}
|
||||
</knowledge>
|
||||
{memory}
|
||||
在回复时,请遵循以下指南:
|
||||
1. 回复内容要与你负责的业务相关。
|
||||
2. 回复的语气要结合你的性格特点。
|
||||
3. 确保回复内容清晰、简洁、有针对性。
|
||||
请生成你的回复:
|
||||
"""
|
||||
)
|
||||
summarizeChain = summarizePrompt | llm
|
||||
def summarize_ai(state: State):
|
||||
"""AI 总结输出"""
|
||||
mem = state['memory']
|
||||
if mem != "":
|
||||
memStr = """
|
||||
这是给你参考的相关历史记忆:
|
||||
<memory>
|
||||
%s
|
||||
</memory>
|
||||
""" % mem # 这里用 % 把 mem 填进去
|
||||
else:
|
||||
memStr = "没有记忆内容"
|
||||
print("历史记录是:" ,state["history"])
|
||||
state["reply"] = summarizeChain.invoke({
|
||||
"ai_role":state["ai_role"],
|
||||
"ai_name":state["ai_name"],
|
||||
"history":state["history"],
|
||||
"ai_service":state['ai_service'],
|
||||
"knowledge": state["knowledge"],
|
||||
"memory": memStr,
|
||||
}).content
|
||||
return state
|
||||
|
||||
# ------------------------------------------------------------------------ 构建有向图 --------
|
||||
workflow = StateGraph(State)
|
||||
workflow.add_node("decide", decide_source)
|
||||
workflow.add_node("db_search", db_search)
|
||||
workflow.add_node("chat", chat)
|
||||
workflow.add_node("summarize", summarize_ai)
|
||||
workflow.set_entry_point("decide")
|
||||
# 条件边:根据 path 决定走向
|
||||
workflow.add_conditional_edges(
|
||||
"decide",
|
||||
lambda state: state["path"], # 返回 state["path"] 的值
|
||||
{
|
||||
"kn": "db_search",
|
||||
"chat": "chat",
|
||||
}
|
||||
)
|
||||
workflow.add_edge("db_search", "summarize")
|
||||
workflow.add_edge("summarize", END)
|
||||
workflow.add_edge("chat", END)
|
||||
graph = workflow.compile()
|
||||
|
||||
# 执行函数
|
||||
def get_service_agent_reply(aiId:str, userInput: str,history:str, kn_bases:List[str]) :
|
||||
json = pgdb.get_ai_personality(aiId)
|
||||
ai_service = json["业务"]
|
||||
ai_role = json["性格"]
|
||||
ai_name = json["名字"]
|
||||
print("AI Name:", ai_name)
|
||||
print("AI Service:", ai_service)
|
||||
|
||||
final_state = graph.invoke({
|
||||
"ai_service":ai_service,
|
||||
"ai_role":ai_role,
|
||||
"ai_name":ai_name,
|
||||
"history":history,
|
||||
"kn_bases":kn_bases,
|
||||
"table_info": pgdb.get_available_tables_str(aiId),
|
||||
"userInput": userInput,
|
||||
"ai_id": aiId,
|
||||
})
|
||||
return final_state["reply"]
|
||||
Reference in New Issue
Block a user