升级新库

This commit is contained in:
BBIT-Kai
2025-12-31 17:49:17 +08:00
parent d6c7f209c7
commit 6136554562
14 changed files with 355 additions and 356 deletions
+59 -51
View File
@@ -1,41 +1,34 @@
from langchain.prompts import PromptTemplate
from config.llm import llm
from config.ssDb import ssDBLC
from typing import Annotated
from typing_extensions import TypedDict
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from langchain_community.agent_toolkits import create_sql_agent
from langchain.prompts import PromptTemplate
from config.llm import llm
from config.ssDb import ssDBLC
from typing import Annotated
from langgraph.graph.message import add_messages
import os
from langchain_tavily import TavilySearch
from langgraph.prebuilt import ToolNode, tools_condition
from llm.chatLLM import get_chat_response
from typing import TypedDict
from langchain_community.agent_toolkits import create_sql_agent
from langchain_core.prompts import PromptTemplate
from langchain_tavily import TavilySearch
from langgraph.graph import START
from langgraph.graph import StateGraph, END
from config.llm import llm
from config.ssDb import ssDBLC
from llm.chatLLM import get_chat_response
from llm.summarizeLLM import getSummary
# -------- 定义状态 --------
class State(TypedDict):
userInput: str # 用户输入
source: str # 选择的数据来源:web 或 db 或 chat
infomation: str # 查询到的内容
aiRole: str # AI 角色
history: str # 聊天历史
reply: str # 最终回复
userInput: str # 用户输入
source: str # 选择的数据来源:web 或 db 或 chat
infomation: str # 查询到的内容
aiRole: str # AI 角色
history: str # 聊天历史
reply: str # 最终回复
# -------- 定义节点 --------
# ------------------------------------------------------------------------ 路径选择 --------
pathSelectPrompt = PromptTemplate(
input_variables=["aiRole", "history", "userStr", "infomation"],
template = """
template="""
你是主干信息科技有限公司的业务员,是一家蚕桑服务公司,现在需要根据用户输入来判断应该使用哪种方式来回答用户的问题。
你有三种选择:
1. 如果用户的问题涉及最新的信息,比如新闻、事件、天气等涉及时间的内容时,请选择 "web
@@ -45,19 +38,26 @@ pathSelectPrompt = PromptTemplate(
用户最新输入:
{userStr}
请做出你的选择:
"""
""",
)
pathSelectChain = pathSelectPrompt | llm
def decide_source(state: State, max_retry=3):
print("根据用户输入选择数据来源,用户输入:", state["userInput"])
"""根据用户输入选择数据来源"""
for _ in range(max_retry):
choice = pathSelectChain.invoke({
"aiRole": state["aiRole"],
"history": state["history"],
"userStr": state["userInput"],
}).content.strip().lower()
choice = (
pathSelectChain.invoke(
{
"aiRole": state["aiRole"],
"history": state["history"],
"userStr": state["userInput"],
}
)
.content.strip()
.lower()
)
if choice in ["web", "db", "chat"]:
state["source"] = choice
break
@@ -72,36 +72,45 @@ def decide_source(state: State, max_retry=3):
os.environ["TAVILY_API_KEY"] = "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"
tool = TavilySearch(max_results=2)
def fetch_web(state: State):
result = tool.invoke(state["userInput"])
state["infomation"] = result.get("content") or result
state["infomation"] = result.get("content") or result
print("调用了联网工具,结果是:", state["infomation"])
return state
# ------------------------------------------------------------------------ 数据库查询 --------
agent = create_sql_agent(
llm=llm,
db=ssDBLC,
agent_type="tool-calling",
verbose=True
)
agent = create_sql_agent(llm=llm, db=ssDBLC, agent_type="tool-calling", verbose=True)
def fetch_db(state: State):
state["infomation"] = agent.invoke({"input": state["userInput"]})["output"]
print("调用了数据库工具,结果是:", state["infomation"])
return state
# ------------------------------------------------------------------------ 整理结果 --------
def summarize_ai(state: State):
"""AI 总结输出"""
state["reply"] = getSummary(aiRole=state["aiRole"], history=state["history"], userInput= state["userInput"], infomation= state["infomation"])
state["reply"] = getSummary(
aiRole=state["aiRole"],
history=state["history"],
userInput=state["userInput"],
infomation=state["infomation"],
)
return state
# ------------------------------------------------------------------------ 普通聊天 --------
def chat(state: State):
state["reply"] = get_chat_response(aiRole=state["aiRole"],history=state["history"], userInput= state["userInput"]).content
state["reply"] = get_chat_response(
aiRole=state["aiRole"], history=state["history"], userInput=state["userInput"]
).content
print("直接回复")
return state
# ------------------------------------------------------------------------ 构建有向图 --------
workflow = StateGraph(State)
workflow.add_node("decide", decide_source)
@@ -119,21 +128,20 @@ workflow.add_edge("fetch_db", "summarize")
workflow.add_conditional_edges(
"decide",
lambda state: state["source"], # 返回 state["source"] 的值
{
"web": "fetch_web",
"chat": "chat",
"db": "fetch_db"
}
{"web": "fetch_web", "chat": "chat", "db": "fetch_db"},
)
workflow.add_edge("summarize", END)
workflow.add_edge("chat", END)
graph = workflow.compile()
# 执行函数
def get_graph_output(aiRole:str,history: str, userInput: str) -> str:
final_state = graph.invoke({
"aiRole":aiRole,
"history": history,
"userInput": userInput,
})
return final_state["reply"]
def get_graph_output(aiRole: str, history: str, userInput: str) -> str:
final_state = graph.invoke(
{
"aiRole": aiRole,
"history": history,
"userInput": userInput,
}
)
return final_state["reply"]
+72 -77
View File
@@ -1,31 +1,12 @@
from typing import Literal
from langchain_core.messages import AIMessage
from langchain_core.runnables import RunnableConfig
from langgraph.graph import END, START, MessagesState, StateGraph
from langgraph.prebuilt import ToolNode
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from config.llm import llm, llmThink
from langgraph.graph import StateGraph, END
from langchain.prompts import PromptTemplate
from config.llm import llm
from typing import Annotated
from typing_extensions import TypedDict
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from langchain_community.agent_toolkits import create_sql_agent
from langchain.prompts import PromptTemplate
from config.llm import llm
from typing import Annotated
from langgraph.graph.message import add_messages
import os
from langchain_tavily import TavilySearch
from langgraph.prebuilt import ToolNode, tools_condition
from llm.chatLLM import get_chat_response
from typing import TypedDict
from langchain_core.prompts import PromptTemplate
from langgraph.graph import StateGraph, END
from llm.summarizeLLM import getSummary
import db.postgres as pgdb
import db.sqlserver as sqlserver
from config.llm import llm
from config.llm import llmThink
# -------- 定义状态 --------
@@ -69,7 +50,7 @@ gen_sql_prompt = PromptTemplate(
请直接输出完整可执行的 SQL 语句,不要任何其他文字或格式化,例如反引号或 ```sql。
"""
""",
)
sqlChain = gen_sql_prompt | llm
@@ -89,33 +70,34 @@ fix_prompt = PromptTemplate(
# 输出要求
只返回修正后的 SQL 语句,不包含任何额外的解释或说明。
"""
""",
)
fixSQLChain = fix_prompt | llm
def sql(state: State):
if state["isFirstGenSQL"]:
state['sql'] = sql_1(state)
state["sql"] = sql_1(state)
else:
state['sql'] = sql_2(state)
state["sql"] = sql_2(state)
for attempt in range(2):
try:
# 执行 SQL
result = sqlserver.executeSQL(state['sql'])
state['sql_result'] = result
result = sqlserver.executeSQL(state["sql"])
state["sql_result"] = result
# print("SQL 执行成功,结果:", result)
break
except Exception as e:
error_msg = str(e)
print(f"SQL 执行出错: {error_msg}")
# 调用 LLM 修正 SQL
state['sql'] = fixSQLChain.invoke({
"sql": state['sql'],
"error_msg": error_msg,
"table_info": state['table_info'],
"tenant_id": state['tenant_id']
}
state["sql"] = fixSQLChain.invoke(
{
"sql": state["sql"],
"error_msg": error_msg,
"table_info": state["table_info"],
"tenant_id": state["tenant_id"],
}
).content
# print(f"LLM 生成修正 SQL: {state['sql']}")
else:
@@ -124,11 +106,13 @@ def sql(state: State):
def sql_1(state: State):
return sqlChain.invoke({
"table_info": state['table_info'],
"userInput": state["userInput"],
"tenant_id": state['tenant_id']
}).content
return sqlChain.invoke(
{
"table_info": state["table_info"],
"userInput": state["userInput"],
"tenant_id": state["tenant_id"],
}
).content
improve_sql_prompt = PromptTemplate(
@@ -151,18 +135,20 @@ improve_sql_prompt = PromptTemplate(
7. 通常来说,不查询对用户来说意义不大的字段,比如主键、外键、id等。
8. 查询的SQL字段要用别名,取名参考描述。
9. 一般情况下,如果能限制租户Id(通常为tenantid 字段),则尽量加上WHERE tenantid = {tenant_id}
"""
""",
)
improveSqlChain = improve_sql_prompt | llm
def sql_2(state: State):
return improveSqlChain.invoke({
"sql": state['sql'],
"table_info": state['table_info'],
"userInput": state["userInput"],
"tenant_id": state['tenant_id']
}).content
return improveSqlChain.invoke(
{
"sql": state["sql"],
"table_info": state["table_info"],
"userInput": state["userInput"],
"tenant_id": state["tenant_id"],
}
).content
# ------------------------------------------------------------------------ 路径选择 --------
@@ -199,7 +185,7 @@ chat → 无法直接生成 SQL,需要进一步解释或澄清。
回答内容仅限于db或者chat,请勿输出其他内容。
你的回复:
"""
""",
)
pathSelectChain = pathSelectPrompt | llmThink
@@ -207,12 +193,18 @@ pathSelectChain = pathSelectPrompt | llmThink
def decide_source(state: State, max_retry=3):
"""根据用户输入选择数据来源"""
for _ in range(max_retry):
choice = pathSelectChain.invoke({
"userInput": state["userInput"],
"table_info": state["table_info"],
"ai_service": state["ai_service"],
"sql": state["sql"]
}).content.strip().lower()
choice = (
pathSelectChain.invoke(
{
"userInput": state["userInput"],
"table_info": state["table_info"],
"ai_service": state["ai_service"],
"sql": state["sql"],
}
)
.content.strip()
.lower()
)
print("根据用户输入选择数据来源,路径是:", choice)
if choice in ["db", "chat"]:
state["path"] = choice
@@ -242,17 +234,16 @@ noChatPrompt = PromptTemplate(
3. 引导用户提出与你业务相关的问题。
4. 使用礼貌和友好的语气。
你的回答:
"""
""",
)
noChatChain = noChatPrompt | llm
def chat(state: State):
state["reply"] = noChatChain.invoke({
"userInput": state["userInput"],
"ai_service": state["ai_service"]
}).content
state["reply"] = noChatChain.invoke(
{"userInput": state["userInput"], "ai_service": state["ai_service"]}
).content
print("直接回复")
return state
@@ -291,19 +282,21 @@ summarizePrompt = PromptTemplate(
2. 提供进一步可选的查询示例,基于当前的数据库表结构,引导用户提出更具体需求。
你的回复:
"""
""",
)
summarizeChain = summarizePrompt | llm
def summarize_ai(state: State):
"""AI 总结输出"""
state["reply"] = summarizeChain.invoke({
"ai_role": state["ai_role"],
"sql": state['sql'],
"userInput": state['userInput'],
"table_info": state['table_info'],
}).content
state["reply"] = summarizeChain.invoke(
{
"ai_role": state["ai_role"],
"sql": state["sql"],
"userInput": state["userInput"],
"table_info": state["table_info"],
}
).content
return state
@@ -322,7 +315,7 @@ workflow.add_conditional_edges(
{
"db": "sql_1",
"chat": "chat",
}
},
)
workflow.add_edge("summarize", END)
workflow.add_edge("chat", END)
@@ -334,13 +327,15 @@ def get_db_agent_reply(aiId: str, userInput: str, tenant_id: str, sql: str = "")
json = pgdb.get_ai_personality(aiId)
ai_service = json["业务"]
ai_role = json["性格"]
final_state = graph.invoke({
"ai_service": ai_service,
"ai_role": ai_role,
"table_info": pgdb.get_available_tables_str(aiId),
"tenant_id": tenant_id,
"userInput": userInput,
"sql": sql,
"isFirstGenSQL": sql == "",
})
final_state = graph.invoke(
{
"ai_service": ai_service,
"ai_role": ai_role,
"table_info": pgdb.get_available_tables_str(aiId),
"tenant_id": tenant_id,
"userInput": userInput,
"sql": sql,
"isFirstGenSQL": sql == "",
}
)
return final_state
+105 -93
View File
@@ -1,59 +1,39 @@
from typing import Literal
from langchain_core.messages import AIMessage
from langchain_core.runnables import RunnableConfig
from langgraph.graph import END, START, MessagesState, StateGraph
from langgraph.prebuilt import ToolNode
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from config.llm import llm,llmThink
from langgraph.graph import StateGraph, END
from langchain.prompts import PromptTemplate
from config.llm import llm
from typing import Annotated
from typing_extensions import TypedDict
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from langchain_community.agent_toolkits import create_sql_agent
from langchain.prompts import PromptTemplate
from config.llm import llm
from typing import Annotated
from langgraph.graph.message import add_messages
import os
from langchain_tavily import TavilySearch
from langgraph.prebuilt import ToolNode, tools_condition
from llm.chatLLM import get_chat_response
from typing import List
from typing import TypedDict
from langchain_core.prompts import PromptTemplate
from langgraph.graph import StateGraph, END
from llm.summarizeLLM import getSummary
import db.postgres as pgdb
import db.sqlserver as sqlserver
from typing import List, Dict
import db.milvus as milvus
import db.postgres as pgdb
from config.llm import llm
from config.llm import llmThink
# -------- 定义状态 --------
class State(TypedDict):
path: str # 开始聊天选择的路径
path: str # 开始聊天选择的路径
memory:str # 记忆
knowledge: str # 知识库内容
history: str # 聊天历史
memory: str # 记忆
knowledge: str # 知识库内容
history: str # 聊天历史
ai_id : str # AI id
ai_name:str # AI 名称
ai_service: str # AI 角色 业务
ai_role: str # AI 角色 性格特点
kn_bases: List[str] # AI 所使用的知识库
ai_id: str # AI id
ai_name: str # AI 名称
ai_service: str # AI 角色 业务
ai_role: str # AI 角色 性格特点
kn_bases: List[str] # AI 所使用的知识库
userInput: str # 用户输入
reply: str # 最终回复
userInput: str # 用户输入
reply: str # 最终回复
# -------- 定义节点 --------
# ------------------------------------------------------------------------ 向量数据库查询 --------
gen_sql_prompt = PromptTemplate(
input_variables=["userInput"],
template = """你的任务是对用户输入进行意图分析,并将其分解成方便进行知识向量数据库搜索的关键词。
template="""你的任务是对用户输入进行意图分析,并将其分解成方便进行知识向量数据库搜索的关键词。
以下是用户的输入:
<用户输入>
{userInput}
@@ -65,28 +45,33 @@ gen_sql_prompt = PromptTemplate(
4. 确保关键词之间相互独立,不包含其他关键词。
关键词之间用空格分隔。
你的回答是:
"""
""",
)
sqlChain = gen_sql_prompt | llm
def db_search(state: State):
key_words = sqlChain.invoke({
"userInput": state['userInput'],
}).content
key_words = sqlChain.invoke(
{
"userInput": state["userInput"],
}
).content
print("关键词是:", key_words)
knowledge = milvus.get_knowledge_by_key_words(key_words, state['kn_bases'])
knowledge = milvus.get_knowledge_by_key_words(key_words, state["kn_bases"])
print("知识库内容是:", knowledge)
state["knowledge"] = knowledge
ai_ids = [state['ai_id']]
ai_ids = [state["ai_id"]]
memory = milvus.get_memory_by_key_words(key_words, ai_ids)
print("记忆是:", memory)
state["memory"] = memory
return state
# ------------------------------------------------------------------------ 意图分析 --------
pathSelectPrompt = PromptTemplate(
input_variables=[ "userInput","ai_service","history"],
template = """
input_variables=["userInput", "ai_service", "history"],
template="""
你是一个意图分类器,负责判断用户提问是否与你的工作相关,进而确定是否需要去查知识库。
以下是你负责的工作内容:
<ai_service>
@@ -103,17 +88,25 @@ pathSelectPrompt = PromptTemplate(
判断规则如下:
如果用户最新回复与你的负责工作相关,需要去查知识库,输出“kn”;如果不相关,则输出“chat”,不要包含任何标点符号以及空格。
你生成的结果:
"""
""",
)
pathSelectChain = pathSelectPrompt | llmThink
def decide_source(state: State, max_retry=3):
"""根据用户输入选择数据来源"""
for _ in range(max_retry):
choice = pathSelectChain.invoke({
"userInput": state["userInput"],
"ai_service": state["ai_service"],
"history": state["history"],
}).content.strip().lower()
choice = (
pathSelectChain.invoke(
{
"userInput": state["userInput"],
"ai_service": state["ai_service"],
"history": state["history"],
}
)
.content.strip()
.lower()
)
print("根据用户输入选择数据来源,路径是:", choice)
if choice in ["kn", "chat"]:
state["path"] = choice
@@ -123,10 +116,11 @@ def decide_source(state: State, max_retry=3):
state["path"] = "chat"
return state
# ------------------------------------------------------------------------ !普通聊天 --------
noChatPrompt = PromptTemplate(
input_variables=[ "ai_name", "ai_service", "ai_role", "history"],
template = """
input_variables=["ai_name", "ai_service", "ai_role", "history"],
template="""
你的名字是:{ai_name},你负责的业务是{ai_service},你具有{ai_role}的性格特点。
这是你和用户的对话历史
@@ -140,26 +134,31 @@ noChatPrompt = PromptTemplate(
4. 回复要简洁明了,避免冗长和复杂的表述。
你的回答:
"""
""",
)
noChatChain = noChatPrompt | llm
def chat(state: State):
state["reply"] = noChatChain.invoke({
"ai_name": state["ai_name"],
"ai_service": state["ai_service"],
"ai_role": state["ai_role"],
"history": state["history"],
"userStr": state["userInput"]
}).content
state["reply"] = noChatChain.invoke(
{
"ai_name": state["ai_name"],
"ai_service": state["ai_service"],
"ai_role": state["ai_role"],
"history": state["history"],
"userStr": state["userInput"],
}
).content
print("直接回复")
return state
# ------------------------------------------------------------------------ 整理结果 --------
summarizePrompt = PromptTemplate(
input_variables=["ai_name", "ai_service", "ai_role", "history", "knowledge"],
template = """
template="""
你的任务是基于给定的AI名称、AI角色业务、AI角色性格特点和聊天历史来回复用户。请仔细阅读以下信息,并按照指示进行回复。
你的名字是:{ai_name},你负责的业务是{ai_service},你具有{ai_role}的性格特点。
@@ -177,32 +176,40 @@ summarizePrompt = PromptTemplate(
2. 回复的语气要结合你的性格特点。
3. 确保回复内容清晰、简洁、有针对性。
请生成你的回复:
"""
""",
)
summarizeChain = summarizePrompt | llm
def summarize_ai(state: State):
"""AI 总结输出"""
mem = state['memory']
mem = state["memory"]
if mem != "":
memStr = """
memStr = (
"""
这是给你参考的相关历史记忆:
<memory>
%s
</memory>
""" % mem # 这里用 % 把 mem 填进去
"""
% mem
) # 这里用 % 把 mem 填进去
else:
memStr = "没有记忆内容"
print("历史记录是:" ,state["history"])
state["reply"] = summarizeChain.invoke({
"ai_role":state["ai_role"],
"ai_name":state["ai_name"],
"history":state["history"],
"ai_service":state['ai_service'],
"knowledge": state["knowledge"],
"memory": memStr,
}).content
print("历史记录是:", state["history"])
state["reply"] = summarizeChain.invoke(
{
"ai_role": state["ai_role"],
"ai_name": state["ai_name"],
"history": state["history"],
"ai_service": state["ai_service"],
"knowledge": state["knowledge"],
"memory": memStr,
}
).content
return state
# ------------------------------------------------------------------------ 构建有向图 --------
workflow = StateGraph(State)
workflow.add_node("decide", decide_source)
@@ -217,30 +224,35 @@ workflow.add_conditional_edges(
{
"kn": "db_search",
"chat": "chat",
}
},
)
workflow.add_edge("db_search", "summarize")
workflow.add_edge("summarize", END)
workflow.add_edge("chat", END)
graph = workflow.compile()
# 执行函数
def get_service_agent_reply(aiId:str, userInput: str,history:str, kn_bases:List[str]) :
def get_service_agent_reply(
aiId: str, userInput: str, history: str, kn_bases: List[str]
):
json = pgdb.get_ai_personality(aiId)
ai_service = json["业务"]
ai_role = json["性格"]
ai_role = json["性格"]
ai_name = json["名字"]
print("AI Name:", ai_name)
print("AI Service:", ai_service)
final_state = graph.invoke({
"ai_service":ai_service,
"ai_role":ai_role,
"ai_name":ai_name,
"history":history,
"kn_bases":kn_bases,
"table_info": pgdb.get_available_tables_str(aiId),
"userInput": userInput,
"ai_id": aiId,
})
return final_state["reply"]
final_state = graph.invoke(
{
"ai_service": ai_service,
"ai_role": ai_role,
"ai_name": ai_name,
"history": history,
"kn_bases": kn_bases,
"table_info": pgdb.get_available_tables_str(aiId),
"userInput": userInput,
"ai_id": aiId,
}
)
return final_state["reply"]