升级新库

This commit is contained in:
BBIT-Kai
2025-12-31 17:49:17 +08:00
parent d6c7f209c7
commit 6136554562
14 changed files with 355 additions and 356 deletions
+3 -30
View File
@@ -1,10 +1,10 @@
# 使用官方 Python 镜像 FROM ubuntu:22.04
FROM python:3.10-slim
WORKDIR /app WORKDIR /app
RUN apt-get update && \ RUN apt-get update && \
apt-get install -y --no-install-recommends \ apt-get install -y --no-install-recommends \
ca-certificates \
libpq5 \ libpq5 \
unixodbc \ unixodbc \
curl \ curl \
@@ -20,31 +20,4 @@ RUN apt-get update && \
ACCEPT_EULA=Y apt-get install -y msodbcsql18 && \ ACCEPT_EULA=Y apt-get install -y msodbcsql18 && \
rm -rf /var/lib/apt/lists/* rm -rf /var/lib/apt/lists/*
COPY app/requirements.txt . COPY app/ /app
# 安装 Python 依赖
RUN pip install --no-cache-dir -i https://pypi.tuna.tsinghua.edu.cn/simple -r requirements.txt
RUN python -m pip uninstall -y opencv-python
RUN python -m pip install opencv-python-headless
# 复制并解压 JRE
COPY docker/OpenJDK17U-jre_x64_linux_hotspot_17.0.16_8.tar.gz /opt/
RUN tar -xzf /opt/OpenJDK17U-jre_x64_linux_hotspot_17.0.16_8.tar.gz -C /opt/ && \
rm /opt/OpenJDK17U-jre_x64_linux_hotspot_17.0.16_8.tar.gz
# 配置 Java 环境
ENV JAVA_HOME=/opt/jdk-17.0.16+8-jre
ENV PATH="$JAVA_HOME/bin:$PATH"
# 复制项目代码
COPY app/ .
# 复制 pyzxing 的 jar 文件到默认路径
COPY docker/javase-3.4.1-SNAPSHOT-jar-with-dependencies.jar /root/.local/pyzxing/javase-3.4.1-SNAPSHOT-jar-with-dependencies.jar
EXPOSE 13011
# 启动命令(使用 uvicorn 启动 FastAPI
CMD ["python", "app.py"]
+43 -35
View File
@@ -1,23 +1,15 @@
from langchain.prompts import PromptTemplate
from config.llm import llm
from config.ssDb import ssDBLC
from typing import Annotated
from typing_extensions import TypedDict
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from langchain_community.agent_toolkits import create_sql_agent
from langchain.prompts import PromptTemplate
from config.llm import llm
from config.ssDb import ssDBLC
from typing import Annotated
from langgraph.graph.message import add_messages
import os import os
from langchain_tavily import TavilySearch
from langgraph.prebuilt import ToolNode, tools_condition
from llm.chatLLM import get_chat_response
from typing import TypedDict from typing import TypedDict
from langchain_community.agent_toolkits import create_sql_agent
from langchain_core.prompts import PromptTemplate
from langchain_tavily import TavilySearch
from langgraph.graph import START
from langgraph.graph import StateGraph, END from langgraph.graph import StateGraph, END
from config.llm import llm
from config.ssDb import ssDBLC
from llm.chatLLM import get_chat_response
from llm.summarizeLLM import getSummary from llm.summarizeLLM import getSummary
@@ -30,6 +22,7 @@ class State(TypedDict):
history: str # 聊天历史 history: str # 聊天历史
reply: str # 最终回复 reply: str # 最终回复
# -------- 定义节点 -------- # -------- 定义节点 --------
# ------------------------------------------------------------------------ 路径选择 -------- # ------------------------------------------------------------------------ 路径选择 --------
@@ -45,19 +38,26 @@ pathSelectPrompt = PromptTemplate(
用户最新输入: 用户最新输入:
{userStr} {userStr}
请做出你的选择: 请做出你的选择:
""" """,
) )
pathSelectChain = pathSelectPrompt | llm pathSelectChain = pathSelectPrompt | llm
def decide_source(state: State, max_retry=3): def decide_source(state: State, max_retry=3):
print("根据用户输入选择数据来源,用户输入:", state["userInput"]) print("根据用户输入选择数据来源,用户输入:", state["userInput"])
"""根据用户输入选择数据来源""" """根据用户输入选择数据来源"""
for _ in range(max_retry): for _ in range(max_retry):
choice = pathSelectChain.invoke({ choice = (
pathSelectChain.invoke(
{
"aiRole": state["aiRole"], "aiRole": state["aiRole"],
"history": state["history"], "history": state["history"],
"userStr": state["userInput"], "userStr": state["userInput"],
}).content.strip().lower() }
)
.content.strip()
.lower()
)
if choice in ["web", "db", "chat"]: if choice in ["web", "db", "chat"]:
state["source"] = choice state["source"] = choice
break break
@@ -72,36 +72,45 @@ def decide_source(state: State, max_retry=3):
os.environ["TAVILY_API_KEY"] = "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" os.environ["TAVILY_API_KEY"] = "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"
tool = TavilySearch(max_results=2) tool = TavilySearch(max_results=2)
def fetch_web(state: State): def fetch_web(state: State):
result = tool.invoke(state["userInput"]) result = tool.invoke(state["userInput"])
state["infomation"] = result.get("content") or result state["infomation"] = result.get("content") or result
print("调用了联网工具,结果是:", state["infomation"]) print("调用了联网工具,结果是:", state["infomation"])
return state return state
# ------------------------------------------------------------------------ 数据库查询 -------- # ------------------------------------------------------------------------ 数据库查询 --------
agent = create_sql_agent( agent = create_sql_agent(llm=llm, db=ssDBLC, agent_type="tool-calling", verbose=True)
llm=llm,
db=ssDBLC,
agent_type="tool-calling",
verbose=True
)
def fetch_db(state: State): def fetch_db(state: State):
state["infomation"] = agent.invoke({"input": state["userInput"]})["output"] state["infomation"] = agent.invoke({"input": state["userInput"]})["output"]
print("调用了数据库工具,结果是:", state["infomation"]) print("调用了数据库工具,结果是:", state["infomation"])
return state return state
# ------------------------------------------------------------------------ 整理结果 -------- # ------------------------------------------------------------------------ 整理结果 --------
def summarize_ai(state: State): def summarize_ai(state: State):
"""AI 总结输出""" """AI 总结输出"""
state["reply"] = getSummary(aiRole=state["aiRole"], history=state["history"], userInput= state["userInput"], infomation= state["infomation"]) state["reply"] = getSummary(
aiRole=state["aiRole"],
history=state["history"],
userInput=state["userInput"],
infomation=state["infomation"],
)
return state return state
# ------------------------------------------------------------------------ 普通聊天 -------- # ------------------------------------------------------------------------ 普通聊天 --------
def chat(state: State): def chat(state: State):
state["reply"] = get_chat_response(aiRole=state["aiRole"],history=state["history"], userInput= state["userInput"]).content state["reply"] = get_chat_response(
aiRole=state["aiRole"], history=state["history"], userInput=state["userInput"]
).content
print("直接回复") print("直接回复")
return state return state
# ------------------------------------------------------------------------ 构建有向图 -------- # ------------------------------------------------------------------------ 构建有向图 --------
workflow = StateGraph(State) workflow = StateGraph(State)
workflow.add_node("decide", decide_source) workflow.add_node("decide", decide_source)
@@ -119,21 +128,20 @@ workflow.add_edge("fetch_db", "summarize")
workflow.add_conditional_edges( workflow.add_conditional_edges(
"decide", "decide",
lambda state: state["source"], # 返回 state["source"] 的值 lambda state: state["source"], # 返回 state["source"] 的值
{ {"web": "fetch_web", "chat": "chat", "db": "fetch_db"},
"web": "fetch_web",
"chat": "chat",
"db": "fetch_db"
}
) )
workflow.add_edge("summarize", END) workflow.add_edge("summarize", END)
workflow.add_edge("chat", END) workflow.add_edge("chat", END)
graph = workflow.compile() graph = workflow.compile()
# 执行函数 # 执行函数
def get_graph_output(aiRole: str, history: str, userInput: str) -> str: def get_graph_output(aiRole: str, history: str, userInput: str) -> str:
final_state = graph.invoke({ final_state = graph.invoke(
{
"aiRole": aiRole, "aiRole": aiRole,
"history": history, "history": history,
"userInput": userInput, "userInput": userInput,
}) }
)
return final_state["reply"] return final_state["reply"]
+57 -62
View File
@@ -1,31 +1,12 @@
from typing import Literal
from langchain_core.messages import AIMessage
from langchain_core.runnables import RunnableConfig
from langgraph.graph import END, START, MessagesState, StateGraph
from langgraph.prebuilt import ToolNode
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from config.llm import llm, llmThink
from langgraph.graph import StateGraph, END
from langchain.prompts import PromptTemplate
from config.llm import llm
from typing import Annotated
from typing_extensions import TypedDict
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from langchain_community.agent_toolkits import create_sql_agent
from langchain.prompts import PromptTemplate
from config.llm import llm
from typing import Annotated
from langgraph.graph.message import add_messages
import os
from langchain_tavily import TavilySearch
from langgraph.prebuilt import ToolNode, tools_condition
from llm.chatLLM import get_chat_response
from typing import TypedDict from typing import TypedDict
from langchain_core.prompts import PromptTemplate
from langgraph.graph import StateGraph, END from langgraph.graph import StateGraph, END
from llm.summarizeLLM import getSummary
import db.postgres as pgdb import db.postgres as pgdb
import db.sqlserver as sqlserver import db.sqlserver as sqlserver
from config.llm import llm
from config.llm import llmThink
# -------- 定义状态 -------- # -------- 定义状态 --------
@@ -69,7 +50,7 @@ gen_sql_prompt = PromptTemplate(
请直接输出完整可执行的 SQL 语句,不要任何其他文字或格式化,例如反引号或 ```sql。 请直接输出完整可执行的 SQL 语句,不要任何其他文字或格式化,例如反引号或 ```sql。
""" """,
) )
sqlChain = gen_sql_prompt | llm sqlChain = gen_sql_prompt | llm
@@ -89,32 +70,33 @@ fix_prompt = PromptTemplate(
# 输出要求 # 输出要求
只返回修正后的 SQL 语句,不包含任何额外的解释或说明。 只返回修正后的 SQL 语句,不包含任何额外的解释或说明。
""" """,
) )
fixSQLChain = fix_prompt | llm fixSQLChain = fix_prompt | llm
def sql(state: State): def sql(state: State):
if state["isFirstGenSQL"]: if state["isFirstGenSQL"]:
state['sql'] = sql_1(state) state["sql"] = sql_1(state)
else: else:
state['sql'] = sql_2(state) state["sql"] = sql_2(state)
for attempt in range(2): for attempt in range(2):
try: try:
# 执行 SQL # 执行 SQL
result = sqlserver.executeSQL(state['sql']) result = sqlserver.executeSQL(state["sql"])
state['sql_result'] = result state["sql_result"] = result
# print("SQL 执行成功,结果:", result) # print("SQL 执行成功,结果:", result)
break break
except Exception as e: except Exception as e:
error_msg = str(e) error_msg = str(e)
print(f"SQL 执行出错: {error_msg}") print(f"SQL 执行出错: {error_msg}")
# 调用 LLM 修正 SQL # 调用 LLM 修正 SQL
state['sql'] = fixSQLChain.invoke({ state["sql"] = fixSQLChain.invoke(
"sql": state['sql'], {
"sql": state["sql"],
"error_msg": error_msg, "error_msg": error_msg,
"table_info": state['table_info'], "table_info": state["table_info"],
"tenant_id": state['tenant_id'] "tenant_id": state["tenant_id"],
} }
).content ).content
# print(f"LLM 生成修正 SQL: {state['sql']}") # print(f"LLM 生成修正 SQL: {state['sql']}")
@@ -124,11 +106,13 @@ def sql(state: State):
def sql_1(state: State): def sql_1(state: State):
return sqlChain.invoke({ return sqlChain.invoke(
"table_info": state['table_info'], {
"table_info": state["table_info"],
"userInput": state["userInput"], "userInput": state["userInput"],
"tenant_id": state['tenant_id'] "tenant_id": state["tenant_id"],
}).content }
).content
improve_sql_prompt = PromptTemplate( improve_sql_prompt = PromptTemplate(
@@ -151,18 +135,20 @@ improve_sql_prompt = PromptTemplate(
7. 通常来说,不查询对用户来说意义不大的字段,比如主键、外键、id等。 7. 通常来说,不查询对用户来说意义不大的字段,比如主键、外键、id等。
8. 查询的SQL字段要用别名,取名参考描述。 8. 查询的SQL字段要用别名,取名参考描述。
9. 一般情况下,如果能限制租户Id(通常为tenantid 字段),则尽量加上WHERE tenantid = {tenant_id} 9. 一般情况下,如果能限制租户Id(通常为tenantid 字段),则尽量加上WHERE tenantid = {tenant_id}
""" """,
) )
improveSqlChain = improve_sql_prompt | llm improveSqlChain = improve_sql_prompt | llm
def sql_2(state: State): def sql_2(state: State):
return improveSqlChain.invoke({ return improveSqlChain.invoke(
"sql": state['sql'], {
"table_info": state['table_info'], "sql": state["sql"],
"table_info": state["table_info"],
"userInput": state["userInput"], "userInput": state["userInput"],
"tenant_id": state['tenant_id'] "tenant_id": state["tenant_id"],
}).content }
).content
# ------------------------------------------------------------------------ 路径选择 -------- # ------------------------------------------------------------------------ 路径选择 --------
@@ -199,7 +185,7 @@ chat → 无法直接生成 SQL,需要进一步解释或澄清。
回答内容仅限于db或者chat,请勿输出其他内容。 回答内容仅限于db或者chat,请勿输出其他内容。
你的回复: 你的回复:
""" """,
) )
pathSelectChain = pathSelectPrompt | llmThink pathSelectChain = pathSelectPrompt | llmThink
@@ -207,12 +193,18 @@ pathSelectChain = pathSelectPrompt | llmThink
def decide_source(state: State, max_retry=3): def decide_source(state: State, max_retry=3):
"""根据用户输入选择数据来源""" """根据用户输入选择数据来源"""
for _ in range(max_retry): for _ in range(max_retry):
choice = pathSelectChain.invoke({ choice = (
pathSelectChain.invoke(
{
"userInput": state["userInput"], "userInput": state["userInput"],
"table_info": state["table_info"], "table_info": state["table_info"],
"ai_service": state["ai_service"], "ai_service": state["ai_service"],
"sql": state["sql"] "sql": state["sql"],
}).content.strip().lower() }
)
.content.strip()
.lower()
)
print("根据用户输入选择数据来源,路径是:", choice) print("根据用户输入选择数据来源,路径是:", choice)
if choice in ["db", "chat"]: if choice in ["db", "chat"]:
state["path"] = choice state["path"] = choice
@@ -242,17 +234,16 @@ noChatPrompt = PromptTemplate(
3. 引导用户提出与你业务相关的问题。 3. 引导用户提出与你业务相关的问题。
4. 使用礼貌和友好的语气。 4. 使用礼貌和友好的语气。
你的回答: 你的回答:
""" """,
) )
noChatChain = noChatPrompt | llm noChatChain = noChatPrompt | llm
def chat(state: State): def chat(state: State):
state["reply"] = noChatChain.invoke({ state["reply"] = noChatChain.invoke(
"userInput": state["userInput"], {"userInput": state["userInput"], "ai_service": state["ai_service"]}
"ai_service": state["ai_service"] ).content
}).content
print("直接回复") print("直接回复")
return state return state
@@ -291,19 +282,21 @@ summarizePrompt = PromptTemplate(
2. 提供进一步可选的查询示例,基于当前的数据库表结构,引导用户提出更具体需求。 2. 提供进一步可选的查询示例,基于当前的数据库表结构,引导用户提出更具体需求。
你的回复: 你的回复:
""" """,
) )
summarizeChain = summarizePrompt | llm summarizeChain = summarizePrompt | llm
def summarize_ai(state: State): def summarize_ai(state: State):
"""AI 总结输出""" """AI 总结输出"""
state["reply"] = summarizeChain.invoke({ state["reply"] = summarizeChain.invoke(
{
"ai_role": state["ai_role"], "ai_role": state["ai_role"],
"sql": state['sql'], "sql": state["sql"],
"userInput": state['userInput'], "userInput": state["userInput"],
"table_info": state['table_info'], "table_info": state["table_info"],
}).content }
).content
return state return state
@@ -322,7 +315,7 @@ workflow.add_conditional_edges(
{ {
"db": "sql_1", "db": "sql_1",
"chat": "chat", "chat": "chat",
} },
) )
workflow.add_edge("summarize", END) workflow.add_edge("summarize", END)
workflow.add_edge("chat", END) workflow.add_edge("chat", END)
@@ -334,7 +327,8 @@ def get_db_agent_reply(aiId: str, userInput: str, tenant_id: str, sql: str = "")
json = pgdb.get_ai_personality(aiId) json = pgdb.get_ai_personality(aiId)
ai_service = json["业务"] ai_service = json["业务"]
ai_role = json["性格"] ai_role = json["性格"]
final_state = graph.invoke({ final_state = graph.invoke(
{
"ai_service": ai_service, "ai_service": ai_service,
"ai_role": ai_role, "ai_role": ai_role,
"table_info": pgdb.get_available_tables_str(aiId), "table_info": pgdb.get_available_tables_str(aiId),
@@ -342,5 +336,6 @@ def get_db_agent_reply(aiId: str, userInput: str, tenant_id: str, sql: str = "")
"userInput": userInput, "userInput": userInput,
"sql": sql, "sql": sql,
"isFirstGenSQL": sql == "", "isFirstGenSQL": sql == "",
}) }
)
return final_state return final_state
+64 -52
View File
@@ -1,34 +1,13 @@
from typing import List
from typing import Literal
from langchain_core.messages import AIMessage
from langchain_core.runnables import RunnableConfig
from langgraph.graph import END, START, MessagesState, StateGraph
from langgraph.prebuilt import ToolNode
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from config.llm import llm,llmThink
from langgraph.graph import StateGraph, END
from langchain.prompts import PromptTemplate
from config.llm import llm
from typing import Annotated
from typing_extensions import TypedDict
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from langchain_community.agent_toolkits import create_sql_agent
from langchain.prompts import PromptTemplate
from config.llm import llm
from typing import Annotated
from langgraph.graph.message import add_messages
import os
from langchain_tavily import TavilySearch
from langgraph.prebuilt import ToolNode, tools_condition
from llm.chatLLM import get_chat_response
from typing import TypedDict from typing import TypedDict
from langchain_core.prompts import PromptTemplate
from langgraph.graph import StateGraph, END from langgraph.graph import StateGraph, END
from llm.summarizeLLM import getSummary
import db.postgres as pgdb
import db.sqlserver as sqlserver
from typing import List, Dict
import db.milvus as milvus import db.milvus as milvus
import db.postgres as pgdb
from config.llm import llm
from config.llm import llmThink
# -------- 定义状态 -------- # -------- 定义状态 --------
@@ -48,6 +27,7 @@ class State(TypedDict):
userInput: str # 用户输入 userInput: str # 用户输入
reply: str # 最终回复 reply: str # 最终回复
# -------- 定义节点 -------- # -------- 定义节点 --------
# ------------------------------------------------------------------------ 向量数据库查询 -------- # ------------------------------------------------------------------------ 向量数据库查询 --------
@@ -65,23 +45,28 @@ gen_sql_prompt = PromptTemplate(
4. 确保关键词之间相互独立,不包含其他关键词。 4. 确保关键词之间相互独立,不包含其他关键词。
关键词之间用空格分隔。 关键词之间用空格分隔。
你的回答是: 你的回答是:
""" """,
) )
sqlChain = gen_sql_prompt | llm sqlChain = gen_sql_prompt | llm
def db_search(state: State): def db_search(state: State):
key_words = sqlChain.invoke({ key_words = sqlChain.invoke(
"userInput": state['userInput'], {
}).content "userInput": state["userInput"],
}
).content
print("关键词是:", key_words) print("关键词是:", key_words)
knowledge = milvus.get_knowledge_by_key_words(key_words, state['kn_bases']) knowledge = milvus.get_knowledge_by_key_words(key_words, state["kn_bases"])
print("知识库内容是:", knowledge) print("知识库内容是:", knowledge)
state["knowledge"] = knowledge state["knowledge"] = knowledge
ai_ids = [state['ai_id']] ai_ids = [state["ai_id"]]
memory = milvus.get_memory_by_key_words(key_words, ai_ids) memory = milvus.get_memory_by_key_words(key_words, ai_ids)
print("记忆是:", memory) print("记忆是:", memory)
state["memory"] = memory state["memory"] = memory
return state return state
# ------------------------------------------------------------------------ 意图分析 -------- # ------------------------------------------------------------------------ 意图分析 --------
pathSelectPrompt = PromptTemplate( pathSelectPrompt = PromptTemplate(
@@ -103,17 +88,25 @@ pathSelectPrompt = PromptTemplate(
判断规则如下: 判断规则如下:
如果用户最新回复与你的负责工作相关,需要去查知识库,输出“kn”;如果不相关,则输出“chat”,不要包含任何标点符号以及空格。 如果用户最新回复与你的负责工作相关,需要去查知识库,输出“kn”;如果不相关,则输出“chat”,不要包含任何标点符号以及空格。
你生成的结果: 你生成的结果:
""" """,
) )
pathSelectChain = pathSelectPrompt | llmThink pathSelectChain = pathSelectPrompt | llmThink
def decide_source(state: State, max_retry=3): def decide_source(state: State, max_retry=3):
"""根据用户输入选择数据来源""" """根据用户输入选择数据来源"""
for _ in range(max_retry): for _ in range(max_retry):
choice = pathSelectChain.invoke({ choice = (
pathSelectChain.invoke(
{
"userInput": state["userInput"], "userInput": state["userInput"],
"ai_service": state["ai_service"], "ai_service": state["ai_service"],
"history": state["history"], "history": state["history"],
}).content.strip().lower() }
)
.content.strip()
.lower()
)
print("根据用户输入选择数据来源,路径是:", choice) print("根据用户输入选择数据来源,路径是:", choice)
if choice in ["kn", "chat"]: if choice in ["kn", "chat"]:
state["path"] = choice state["path"] = choice
@@ -123,6 +116,7 @@ def decide_source(state: State, max_retry=3):
state["path"] = "chat" state["path"] = "chat"
return state return state
# ------------------------------------------------------------------------ !普通聊天 -------- # ------------------------------------------------------------------------ !普通聊天 --------
noChatPrompt = PromptTemplate( noChatPrompt = PromptTemplate(
input_variables=["ai_name", "ai_service", "ai_role", "history"], input_variables=["ai_name", "ai_service", "ai_role", "history"],
@@ -140,21 +134,26 @@ noChatPrompt = PromptTemplate(
4. 回复要简洁明了,避免冗长和复杂的表述。 4. 回复要简洁明了,避免冗长和复杂的表述。
你的回答: 你的回答:
""" """,
) )
noChatChain = noChatPrompt | llm noChatChain = noChatPrompt | llm
def chat(state: State): def chat(state: State):
state["reply"] = noChatChain.invoke({ state["reply"] = noChatChain.invoke(
{
"ai_name": state["ai_name"], "ai_name": state["ai_name"],
"ai_service": state["ai_service"], "ai_service": state["ai_service"],
"ai_role": state["ai_role"], "ai_role": state["ai_role"],
"history": state["history"], "history": state["history"],
"userStr": state["userInput"] "userStr": state["userInput"],
}).content }
).content
print("直接回复") print("直接回复")
return state return state
# ------------------------------------------------------------------------ 整理结果 -------- # ------------------------------------------------------------------------ 整理结果 --------
summarizePrompt = PromptTemplate( summarizePrompt = PromptTemplate(
@@ -177,32 +176,40 @@ summarizePrompt = PromptTemplate(
2. 回复的语气要结合你的性格特点。 2. 回复的语气要结合你的性格特点。
3. 确保回复内容清晰、简洁、有针对性。 3. 确保回复内容清晰、简洁、有针对性。
请生成你的回复: 请生成你的回复:
""" """,
) )
summarizeChain = summarizePrompt | llm summarizeChain = summarizePrompt | llm
def summarize_ai(state: State): def summarize_ai(state: State):
"""AI 总结输出""" """AI 总结输出"""
mem = state['memory'] mem = state["memory"]
if mem != "": if mem != "":
memStr = """ memStr = (
"""
这是给你参考的相关历史记忆: 这是给你参考的相关历史记忆:
<memory> <memory>
%s %s
</memory> </memory>
""" % mem # 这里用 % 把 mem 填进去 """
% mem
) # 这里用 % 把 mem 填进去
else: else:
memStr = "没有记忆内容" memStr = "没有记忆内容"
print("历史记录是:", state["history"]) print("历史记录是:", state["history"])
state["reply"] = summarizeChain.invoke({ state["reply"] = summarizeChain.invoke(
{
"ai_role": state["ai_role"], "ai_role": state["ai_role"],
"ai_name": state["ai_name"], "ai_name": state["ai_name"],
"history": state["history"], "history": state["history"],
"ai_service":state['ai_service'], "ai_service": state["ai_service"],
"knowledge": state["knowledge"], "knowledge": state["knowledge"],
"memory": memStr, "memory": memStr,
}).content }
).content
return state return state
# ------------------------------------------------------------------------ 构建有向图 -------- # ------------------------------------------------------------------------ 构建有向图 --------
workflow = StateGraph(State) workflow = StateGraph(State)
workflow.add_node("decide", decide_source) workflow.add_node("decide", decide_source)
@@ -217,15 +224,18 @@ workflow.add_conditional_edges(
{ {
"kn": "db_search", "kn": "db_search",
"chat": "chat", "chat": "chat",
} },
) )
workflow.add_edge("db_search", "summarize") workflow.add_edge("db_search", "summarize")
workflow.add_edge("summarize", END) workflow.add_edge("summarize", END)
workflow.add_edge("chat", END) workflow.add_edge("chat", END)
graph = workflow.compile() graph = workflow.compile()
# 执行函数 # 执行函数
def get_service_agent_reply(aiId:str, userInput: str,history:str, kn_bases:List[str]) : def get_service_agent_reply(
aiId: str, userInput: str, history: str, kn_bases: List[str]
):
json = pgdb.get_ai_personality(aiId) json = pgdb.get_ai_personality(aiId)
ai_service = json["业务"] ai_service = json["业务"]
ai_role = json["性格"] ai_role = json["性格"]
@@ -233,7 +243,8 @@ def get_service_agent_reply(aiId:str, userInput: str,history:str, kn_bases:List[
print("AI Name:", ai_name) print("AI Name:", ai_name)
print("AI Service:", ai_service) print("AI Service:", ai_service)
final_state = graph.invoke({ final_state = graph.invoke(
{
"ai_service": ai_service, "ai_service": ai_service,
"ai_role": ai_role, "ai_role": ai_role,
"ai_name": ai_name, "ai_name": ai_name,
@@ -242,5 +253,6 @@ def get_service_agent_reply(aiId:str, userInput: str,history:str, kn_bases:List[
"table_info": pgdb.get_available_tables_str(aiId), "table_info": pgdb.get_available_tables_str(aiId),
"userInput": userInput, "userInput": userInput,
"ai_id": aiId, "ai_id": aiId,
}) }
)
return final_state["reply"] return final_state["reply"]
+4 -7
View File
@@ -40,17 +40,14 @@ if sys.platform.lower() == "win32" or os.name.lower() == "nt":
def get_device_id_simple(): def get_device_id_simple():
try: hostname = os.getenv("HOST_NAME")
with open("/etc/machine-id") as f: if not hostname:
mid = f.read().strip()
if mid:
return mid
except Exception:
pass
hostname = socket.gethostname() hostname = socket.gethostname()
mac = uuid.getnode() mac = uuid.getnode()
mac_str = ":".join(f"{(mac >> ele) & 0xff:02x}" for ele in range(40, -1, -8)) mac_str = ":".join(f"{(mac >> ele) & 0xff:02x}" for ele in range(40, -1, -8))
return f"{hostname}|{mac_str}" return f"{hostname}|{mac_str}"
else:
return hostname
# todo 这里需要订阅状态信息 设备发送信息 这里回复 vue前端发送指令 后端发送指令 设备接收指令 # todo 这里需要订阅状态信息 设备发送信息 这里回复 vue前端发送指令 后端发送指令 设备接收指令
+13 -14
View File
@@ -1,9 +1,10 @@
from config.milvus import knVectorstore,memVectorstore
from langchain.schema import Document
from datetime import datetime from datetime import datetime
from typing import List from typing import List
from typing import List, Dict, Any from langchain_core.documents import Document
from config.milvus import knVectorstore, memVectorstore
def get_knowledge_by_key_words(key_words: str, kn_ids: List[str]) -> str: def get_knowledge_by_key_words(key_words: str, kn_ids: List[str]) -> str:
""" """
@@ -17,9 +18,7 @@ def get_knowledge_by_key_words(key_words: str, kn_ids: List[str]) -> str:
return "未找到相关的知识。" return "未找到相关的知识。"
result = knVectorstore.similarity_search( result = knVectorstore.similarity_search(
query=key_words, query=key_words, k=3, expr=expr # 可调节返回条数
k=3, # 可调节返回条数
expr=expr
) )
# 整理成字符串 # 整理成字符串
@@ -48,9 +47,7 @@ def get_memory_by_key_words(key_words: str, ai_ids: List[str]) -> str:
expr = "" # 不限制 kn_id todo 实际上应该不反悔任何内容 expr = "" # 不限制 kn_id todo 实际上应该不反悔任何内容
result = memVectorstore.similarity_search( result = memVectorstore.similarity_search(
query=key_words, query=key_words, k=5, expr=expr # 可调节返回条数
k=5, # 可调节返回条数
expr=expr
) )
# 整理成字符串 # 整理成字符串
@@ -64,12 +61,12 @@ def get_memory_by_key_words(key_words: str, ai_ids: List[str]) -> str:
# 拼成一个大字符串,用换行隔开 # 拼成一个大字符串,用换行隔开
combined_text = "\n\n".join(doc_texts) combined_text = "\n\n".join(doc_texts)
return combined_text return combined_text
def get_knowledge_by_base_id(base_id: str): def get_knowledge_by_base_id(base_id: str):
expr = f'kn_id == "{base_id}"' # base_id 会被替换 expr = f'kn_id == "{base_id}"' # base_id 会被替换
result = knVectorstore.similarity_search( result = knVectorstore.similarity_search(
query="", # 如果只想用过滤条件,可以传空字符串 query="", k=100, expr=expr # 如果只想用过滤条件,可以传空字符串
k=100,
expr=expr
) )
return [ return [
{ {
@@ -80,6 +77,7 @@ def get_knowledge_by_base_id(base_id: str):
for doc in result for doc in result
] ]
def add_knowledge(text: str, is_active: bool, base_id: str, user_id: str): def add_knowledge(text: str, is_active: bool, base_id: str, user_id: str):
docs = [ docs = [
Document( Document(
@@ -89,11 +87,12 @@ def add_knowledge(text: str, is_active: bool, base_id: str, user_id: str):
"created_by": str(user_id), "created_by": str(user_id),
"created_at": datetime.now().isoformat(), "created_at": datetime.now().isoformat(),
"is_active": is_active, "is_active": is_active,
} },
) )
] ]
return knVectorstore.add_documents(docs) return knVectorstore.add_documents(docs)
def add_memory(ai_id: str, mem: str, user_id: str, is_active: bool): def add_memory(ai_id: str, mem: str, user_id: str, is_active: bool):
docs = [ docs = [
Document( Document(
@@ -103,7 +102,7 @@ def add_memory(ai_id:str,mem: str, user_id: str,is_active: bool):
"created_by": str(user_id), "created_by": str(user_id),
"created_at": datetime.now().isoformat(), "created_at": datetime.now().isoformat(),
"is_active": is_active, "is_active": is_active,
} },
) )
] ]
return memVectorstore.add_documents(docs) return memVectorstore.add_documents(docs)
+6 -7
View File
@@ -1,6 +1,6 @@
from langchain_core.prompts import PromptTemplate
from config.llm import llm from config.llm import llm
from langchain.prompts import PromptTemplate
chatPrompt = PromptTemplate( chatPrompt = PromptTemplate(
input_variables=["aiRole", "history", "userInput"], input_variables=["aiRole", "history", "userInput"],
@@ -15,13 +15,12 @@ chatPrompt = PromptTemplate(
{userInput} {userInput}
最后,请注意,不要编造数据,不知道就说不知道,现在,请生成你的回复: 最后,请注意,不要编造数据,不知道就说不知道,现在,请生成你的回复:
""" """,
) )
chatChain = chatPrompt | llm chatChain = chatPrompt | llm
def get_chat_response(aiRole: str, history: str, userInput: str) -> str: def get_chat_response(aiRole: str, history: str, userInput: str) -> str:
return chatChain.invoke({ return chatChain.invoke(
"aiRole": aiRole, {"aiRole": aiRole, "history": history, "userInput": userInput}
"history": history, )
"userInput": userInput
})
+23 -9
View File
@@ -1,8 +1,8 @@
from langchain.prompts import PromptTemplate from langchain_core.prompts import PromptTemplate
from config.llm import llm,llmThink
import db.milvus as milvus import db.milvus as milvus
import db.postgres as pg import db.postgres as pg
import json from config.llm import llmThink
memPathPrompt = PromptTemplate( memPathPrompt = PromptTemplate(
input_variables=["ai_role", "CHAT_RECORD"], input_variables=["ai_role", "CHAT_RECORD"],
@@ -31,7 +31,7 @@ no:用户最新回复价值有限或几乎不会在未来业务中使用。
回复不要带任何标点符号以及空格、换行符。 回复不要带任何标点符号以及空格、换行符。
请给出你的判断结果: 请给出你的判断结果:
""" """,
) )
memPathChain = memPathPrompt | llmThink memPathChain = memPathPrompt | llmThink
memPrompt = PromptTemplate( memPrompt = PromptTemplate(
@@ -48,9 +48,11 @@ memPrompt = PromptTemplate(
4. 总结内容应包含时间,并确保时间是准确的。 4. 总结内容应包含时间,并确保时间是准确的。
5. 你需要针对你的业务场景{ai_role},展开对用户最后回复的总结。 5. 你需要针对你的业务场景{ai_role},展开对用户最后回复的总结。
请生成你的总结,以用户、时间开头: 请生成你的总结,以用户、时间开头:
""" """,
) )
memChain = memPrompt | llmThink memChain = memPrompt | llmThink
def take_memory(ai_id: str, sessionId: str, user_id: str, max_retry=3): def take_memory(ai_id: str, sessionId: str, user_id: str, max_retry=3):
"""根据用户输入选择数据来源""" """根据用户输入选择数据来源"""
history = pg.get_history_with_time(sessionId, 10) history = pg.get_history_with_time(sessionId, 10)
@@ -66,17 +68,29 @@ def take_memory(ai_id:str,sessionId: str,user_id:str, max_retry=3):
else: else:
ai_service = json["业务"] ai_service = json["业务"]
print("获取的描述是:", ai_service) print("获取的描述是:", ai_service)
choice = memPathChain.invoke({ choice = (
memPathChain.invoke(
{
"ai_role": ai_service, "ai_role": ai_service,
"CHAT_RECORD": history, "CHAT_RECORD": history,
}).content.strip().lower() }
)
.content.strip()
.lower()
)
print("记忆判断器判断的结果是:", choice) print("记忆判断器判断的结果是:", choice)
if choice == "yes": if choice == "yes":
# 对对话进行总结 # 对对话进行总结
memory = memChain.invoke({ memory = (
memChain.invoke(
{
"CHAT_RECORD": history, "CHAT_RECORD": history,
"ai_role": ai_service, "ai_role": ai_service,
}).content.strip().lower() }
)
.content.strip()
.lower()
)
print("记忆生成结果是:", memory) print("记忆生成结果是:", memory)
milvus.add_memory(mem=memory, user_id=user_id, is_active=True, ai_id=ai_id) milvus.add_memory(mem=memory, user_id=user_id, is_active=True, ai_id=ai_id)
return return
+14 -17
View File
@@ -1,8 +1,8 @@
from langchain_community.agent_toolkits import create_sql_agent
from langchain_core.prompts import PromptTemplate
from config.llm import llm from config.llm import llm
from langchain.prompts import PromptTemplate
from config.ssDb import ssDBLC from config.ssDb import ssDBLC
from langchain_community.agent_toolkits import create_sql_agent
# ______________________________________________________________SQL描述_____________________________________________________________________ # ______________________________________________________________SQL描述_____________________________________________________________________
sqlDescriptionPrompt = PromptTemplate( sqlDescriptionPrompt = PromptTemplate(
@@ -16,14 +16,14 @@ sqlDescriptionPrompt = PromptTemplate(
3. 不能有markdown语法 3. 不能有markdown语法
4. 要用业务语言描述,不能有专业语句例如SQL表名等 4. 要用业务语言描述,不能有专业语句例如SQL表名等
请生成你认为合适的标题,: 请生成你认为合适的标题,:
""" """,
) )
sqlDescriptionChain = sqlDescriptionPrompt | llm sqlDescriptionChain = sqlDescriptionPrompt | llm
def get_sql_description_response(sql: str) -> str: def get_sql_description_response(sql: str) -> str:
return sqlDescriptionChain.invoke({ return sqlDescriptionChain.invoke({"sql": sql})
"sql": sql
})
# ______________________________________________________________第一次生成SQL_____________________________________________________________________ # ______________________________________________________________第一次生成SQL_____________________________________________________________________
sqlPrompt = PromptTemplate( sqlPrompt = PromptTemplate(
@@ -34,15 +34,12 @@ sqlPrompt = PromptTemplate(
只需要返回SQL语句,不要任何解释。 只需要返回SQL语句,不要任何解释。
用户需求:{userInput} 用户需求:{userInput}
请生成SQL语句: 请生成SQL语句:
""" """,
) )
sqlChain = sqlPrompt | llm sqlChain = sqlPrompt | llm
agent = create_sql_agent( agent = create_sql_agent(llm=llm, db=ssDBLC, agent_type="tool-calling", verbose=True)
llm=llm,
db=ssDBLC,
agent_type="tool-calling",
verbose=True
)
# def get_chat_sql_response2( userInput: str) -> str: # def get_chat_sql_response2( userInput: str) -> str:
# return sqlChain.invoke({ # return sqlChain.invoke({
# "userInput": userInput # "userInput": userInput
@@ -50,6 +47,7 @@ agent = create_sql_agent(
def get_chat_sql_response(userInput: str) -> str: def get_chat_sql_response(userInput: str) -> str:
return agent.invoke({"input": userInput})["output"] return agent.invoke({"input": userInput})["output"]
# ______________________________________________________________改进SQL_____________________________________________________________________ # ______________________________________________________________改进SQL_____________________________________________________________________
sqlImprovePrompt = PromptTemplate( sqlImprovePrompt = PromptTemplate(
input_variables=["userInput", "sql"], input_variables=["userInput", "sql"],
@@ -59,11 +57,10 @@ sqlImprovePrompt = PromptTemplate(
只需要返回改进后的SQL语句,不要任何解释。 只需要返回改进后的SQL语句,不要任何解释。
已有SQL{sql} 已有SQL{sql}
用户需求:{userInput} 用户需求:{userInput}
""" """,
) )
sqlImproveChain = sqlImprovePrompt | llm sqlImproveChain = sqlImprovePrompt | llm
def get_chat_sql_improve_response(userInput: str) -> str: def get_chat_sql_improve_response(userInput: str) -> str:
return sqlImproveChain.invoke({ return sqlImproveChain.invoke({"userInput": userInput})
"userInput": userInput
})
+8 -5
View File
@@ -1,5 +1,5 @@
from langchain_core.prompts import PromptTemplate
from langchain.prompts import PromptTemplate
from config.llm import llm from config.llm import llm
summarizePrompt = PromptTemplate( summarizePrompt = PromptTemplate(
@@ -21,14 +21,17 @@ summarizePrompt = PromptTemplate(
{infomation} {infomation}
··· ···
如果参考内容明显有问题,你要请用户重新描述问题,现在请生成你的回复: 如果参考内容明显有问题,你要请用户重新描述问题,现在请生成你的回复:
""" """,
) )
summarizeChain = summarizePrompt | llm summarizeChain = summarizePrompt | llm
def getSummary(aiRole: str, history: str, userInput: str, infomation: str) -> str: def getSummary(aiRole: str, history: str, userInput: str, infomation: str) -> str:
return summarizeChain.invoke({ return summarizeChain.invoke(
{
"aiRole": aiRole, "aiRole": aiRole,
"history": history, "history": history,
"userStr": userInput, "userStr": userInput,
"infomation": infomation "infomation": infomation,
}).content }
).content
+1 -1
View File
@@ -1,7 +1,7 @@
import json import json
import re import re
from langchain.schema import HumanMessage from langchain_core.messages import HumanMessage
from config.llm import * from config.llm import *
+1 -1
View File
@@ -1,7 +1,7 @@
import json import json
import re import re
from langchain.schema import HumanMessage from langchain_core.messages import HumanMessage
from config.llm import * from config.llm import *
from llm.ticketLLM import decode_barcode from llm.ticketLLM import decode_barcode
+3 -2
View File
@@ -1,5 +1,5 @@
from langchain_core.prompts import PromptTemplate
from langchain.prompts import PromptTemplate
from config.llm import llm from config.llm import llm
titlePrompt = PromptTemplate( titlePrompt = PromptTemplate(
@@ -12,9 +12,10 @@ titlePrompt = PromptTemplate(
4. 保持自然、易懂、专业或有趣(可根据场景调整风格)。 4. 保持自然、易懂、专业或有趣(可根据场景调整风格)。
5. 不能出现标点符号。 5. 不能出现标点符号。
用户原话:"{userStr}" 用户原话:"{userStr}"
""" """,
) )
titleChain = titlePrompt | llm titleChain = titlePrompt | llm
def get_title(userInput: str): def get_title(userInput: str):
return titleChain.invoke({"userStr": userInput}).content return titleChain.invoke({"userStr": userInput}).content
+1
View File
@@ -23,6 +23,7 @@ python-multipart==0.0.20
aio_pika==9.5.7 aio_pika==9.5.7
ultralytics==8.3.227 ultralytics==8.3.227
redis==7.1.0 redis==7.1.0
aiomqtt==2.4.0
# MCP服务 # MCP服务
python-dotenv>=1.0.0 python-dotenv>=1.0.0
websockets>=11.0.3 websockets>=11.0.3