升级新库
This commit is contained in:
@@ -1,10 +1,10 @@
|
|||||||
# 使用官方 Python 镜像
|
FROM ubuntu:22.04
|
||||||
FROM python:3.10-slim
|
|
||||||
|
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
|
|
||||||
RUN apt-get update && \
|
RUN apt-get update && \
|
||||||
apt-get install -y --no-install-recommends \
|
apt-get install -y --no-install-recommends \
|
||||||
|
ca-certificates \
|
||||||
libpq5 \
|
libpq5 \
|
||||||
unixodbc \
|
unixodbc \
|
||||||
curl \
|
curl \
|
||||||
@@ -20,31 +20,4 @@ RUN apt-get update && \
|
|||||||
ACCEPT_EULA=Y apt-get install -y msodbcsql18 && \
|
ACCEPT_EULA=Y apt-get install -y msodbcsql18 && \
|
||||||
rm -rf /var/lib/apt/lists/*
|
rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
COPY app/requirements.txt .
|
COPY app/ /app
|
||||||
# 安装 Python 依赖
|
|
||||||
RUN pip install --no-cache-dir -i https://pypi.tuna.tsinghua.edu.cn/simple -r requirements.txt
|
|
||||||
RUN python -m pip uninstall -y opencv-python
|
|
||||||
RUN python -m pip install opencv-python-headless
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
# 复制并解压 JRE
|
|
||||||
COPY docker/OpenJDK17U-jre_x64_linux_hotspot_17.0.16_8.tar.gz /opt/
|
|
||||||
RUN tar -xzf /opt/OpenJDK17U-jre_x64_linux_hotspot_17.0.16_8.tar.gz -C /opt/ && \
|
|
||||||
rm /opt/OpenJDK17U-jre_x64_linux_hotspot_17.0.16_8.tar.gz
|
|
||||||
|
|
||||||
# 配置 Java 环境
|
|
||||||
ENV JAVA_HOME=/opt/jdk-17.0.16+8-jre
|
|
||||||
ENV PATH="$JAVA_HOME/bin:$PATH"
|
|
||||||
|
|
||||||
# 复制项目代码
|
|
||||||
COPY app/ .
|
|
||||||
# 复制 pyzxing 的 jar 文件到默认路径
|
|
||||||
COPY docker/javase-3.4.1-SNAPSHOT-jar-with-dependencies.jar /root/.local/pyzxing/javase-3.4.1-SNAPSHOT-jar-with-dependencies.jar
|
|
||||||
|
|
||||||
EXPOSE 13011
|
|
||||||
|
|
||||||
# 启动命令(使用 uvicorn 启动 FastAPI)
|
|
||||||
CMD ["python", "app.py"]
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,23 +1,15 @@
|
|||||||
|
|
||||||
from langchain.prompts import PromptTemplate
|
|
||||||
from config.llm import llm
|
|
||||||
from config.ssDb import ssDBLC
|
|
||||||
from typing import Annotated
|
|
||||||
from typing_extensions import TypedDict
|
|
||||||
from langgraph.graph import StateGraph, START, END
|
|
||||||
from langgraph.graph.message import add_messages
|
|
||||||
from langchain_community.agent_toolkits import create_sql_agent
|
|
||||||
from langchain.prompts import PromptTemplate
|
|
||||||
from config.llm import llm
|
|
||||||
from config.ssDb import ssDBLC
|
|
||||||
from typing import Annotated
|
|
||||||
from langgraph.graph.message import add_messages
|
|
||||||
import os
|
import os
|
||||||
from langchain_tavily import TavilySearch
|
|
||||||
from langgraph.prebuilt import ToolNode, tools_condition
|
|
||||||
from llm.chatLLM import get_chat_response
|
|
||||||
from typing import TypedDict
|
from typing import TypedDict
|
||||||
|
|
||||||
|
from langchain_community.agent_toolkits import create_sql_agent
|
||||||
|
from langchain_core.prompts import PromptTemplate
|
||||||
|
from langchain_tavily import TavilySearch
|
||||||
|
from langgraph.graph import START
|
||||||
from langgraph.graph import StateGraph, END
|
from langgraph.graph import StateGraph, END
|
||||||
|
|
||||||
|
from config.llm import llm
|
||||||
|
from config.ssDb import ssDBLC
|
||||||
|
from llm.chatLLM import get_chat_response
|
||||||
from llm.summarizeLLM import getSummary
|
from llm.summarizeLLM import getSummary
|
||||||
|
|
||||||
|
|
||||||
@@ -30,6 +22,7 @@ class State(TypedDict):
|
|||||||
history: str # 聊天历史
|
history: str # 聊天历史
|
||||||
reply: str # 最终回复
|
reply: str # 最终回复
|
||||||
|
|
||||||
|
|
||||||
# -------- 定义节点 --------
|
# -------- 定义节点 --------
|
||||||
# ------------------------------------------------------------------------ 路径选择 --------
|
# ------------------------------------------------------------------------ 路径选择 --------
|
||||||
|
|
||||||
@@ -45,19 +38,26 @@ pathSelectPrompt = PromptTemplate(
|
|||||||
用户最新输入:
|
用户最新输入:
|
||||||
{userStr}
|
{userStr}
|
||||||
请做出你的选择:
|
请做出你的选择:
|
||||||
"""
|
""",
|
||||||
)
|
)
|
||||||
pathSelectChain = pathSelectPrompt | llm
|
pathSelectChain = pathSelectPrompt | llm
|
||||||
|
|
||||||
|
|
||||||
def decide_source(state: State, max_retry=3):
|
def decide_source(state: State, max_retry=3):
|
||||||
print("根据用户输入选择数据来源,用户输入:", state["userInput"])
|
print("根据用户输入选择数据来源,用户输入:", state["userInput"])
|
||||||
"""根据用户输入选择数据来源"""
|
"""根据用户输入选择数据来源"""
|
||||||
for _ in range(max_retry):
|
for _ in range(max_retry):
|
||||||
choice = pathSelectChain.invoke({
|
choice = (
|
||||||
|
pathSelectChain.invoke(
|
||||||
|
{
|
||||||
"aiRole": state["aiRole"],
|
"aiRole": state["aiRole"],
|
||||||
"history": state["history"],
|
"history": state["history"],
|
||||||
"userStr": state["userInput"],
|
"userStr": state["userInput"],
|
||||||
}).content.strip().lower()
|
}
|
||||||
|
)
|
||||||
|
.content.strip()
|
||||||
|
.lower()
|
||||||
|
)
|
||||||
if choice in ["web", "db", "chat"]:
|
if choice in ["web", "db", "chat"]:
|
||||||
state["source"] = choice
|
state["source"] = choice
|
||||||
break
|
break
|
||||||
@@ -72,36 +72,45 @@ def decide_source(state: State, max_retry=3):
|
|||||||
os.environ["TAVILY_API_KEY"] = "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"
|
os.environ["TAVILY_API_KEY"] = "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"
|
||||||
tool = TavilySearch(max_results=2)
|
tool = TavilySearch(max_results=2)
|
||||||
|
|
||||||
|
|
||||||
def fetch_web(state: State):
|
def fetch_web(state: State):
|
||||||
result = tool.invoke(state["userInput"])
|
result = tool.invoke(state["userInput"])
|
||||||
state["infomation"] = result.get("content") or result
|
state["infomation"] = result.get("content") or result
|
||||||
print("调用了联网工具,结果是:", state["infomation"])
|
print("调用了联网工具,结果是:", state["infomation"])
|
||||||
return state
|
return state
|
||||||
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------------ 数据库查询 --------
|
# ------------------------------------------------------------------------ 数据库查询 --------
|
||||||
agent = create_sql_agent(
|
agent = create_sql_agent(llm=llm, db=ssDBLC, agent_type="tool-calling", verbose=True)
|
||||||
llm=llm,
|
|
||||||
db=ssDBLC,
|
|
||||||
agent_type="tool-calling",
|
|
||||||
verbose=True
|
|
||||||
)
|
|
||||||
def fetch_db(state: State):
|
def fetch_db(state: State):
|
||||||
state["infomation"] = agent.invoke({"input": state["userInput"]})["output"]
|
state["infomation"] = agent.invoke({"input": state["userInput"]})["output"]
|
||||||
print("调用了数据库工具,结果是:", state["infomation"])
|
print("调用了数据库工具,结果是:", state["infomation"])
|
||||||
return state
|
return state
|
||||||
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------------ 整理结果 --------
|
# ------------------------------------------------------------------------ 整理结果 --------
|
||||||
def summarize_ai(state: State):
|
def summarize_ai(state: State):
|
||||||
"""AI 总结输出"""
|
"""AI 总结输出"""
|
||||||
state["reply"] = getSummary(aiRole=state["aiRole"], history=state["history"], userInput= state["userInput"], infomation= state["infomation"])
|
state["reply"] = getSummary(
|
||||||
|
aiRole=state["aiRole"],
|
||||||
|
history=state["history"],
|
||||||
|
userInput=state["userInput"],
|
||||||
|
infomation=state["infomation"],
|
||||||
|
)
|
||||||
return state
|
return state
|
||||||
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------------ 普通聊天 --------
|
# ------------------------------------------------------------------------ 普通聊天 --------
|
||||||
def chat(state: State):
|
def chat(state: State):
|
||||||
state["reply"] = get_chat_response(aiRole=state["aiRole"],history=state["history"], userInput= state["userInput"]).content
|
state["reply"] = get_chat_response(
|
||||||
|
aiRole=state["aiRole"], history=state["history"], userInput=state["userInput"]
|
||||||
|
).content
|
||||||
print("直接回复")
|
print("直接回复")
|
||||||
return state
|
return state
|
||||||
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------------ 构建有向图 --------
|
# ------------------------------------------------------------------------ 构建有向图 --------
|
||||||
workflow = StateGraph(State)
|
workflow = StateGraph(State)
|
||||||
workflow.add_node("decide", decide_source)
|
workflow.add_node("decide", decide_source)
|
||||||
@@ -119,21 +128,20 @@ workflow.add_edge("fetch_db", "summarize")
|
|||||||
workflow.add_conditional_edges(
|
workflow.add_conditional_edges(
|
||||||
"decide",
|
"decide",
|
||||||
lambda state: state["source"], # 返回 state["source"] 的值
|
lambda state: state["source"], # 返回 state["source"] 的值
|
||||||
{
|
{"web": "fetch_web", "chat": "chat", "db": "fetch_db"},
|
||||||
"web": "fetch_web",
|
|
||||||
"chat": "chat",
|
|
||||||
"db": "fetch_db"
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
workflow.add_edge("summarize", END)
|
workflow.add_edge("summarize", END)
|
||||||
workflow.add_edge("chat", END)
|
workflow.add_edge("chat", END)
|
||||||
graph = workflow.compile()
|
graph = workflow.compile()
|
||||||
|
|
||||||
|
|
||||||
# 执行函数
|
# 执行函数
|
||||||
def get_graph_output(aiRole: str, history: str, userInput: str) -> str:
|
def get_graph_output(aiRole: str, history: str, userInput: str) -> str:
|
||||||
final_state = graph.invoke({
|
final_state = graph.invoke(
|
||||||
|
{
|
||||||
"aiRole": aiRole,
|
"aiRole": aiRole,
|
||||||
"history": history,
|
"history": history,
|
||||||
"userInput": userInput,
|
"userInput": userInput,
|
||||||
})
|
}
|
||||||
|
)
|
||||||
return final_state["reply"]
|
return final_state["reply"]
|
||||||
@@ -1,31 +1,12 @@
|
|||||||
from typing import Literal
|
|
||||||
from langchain_core.messages import AIMessage
|
|
||||||
from langchain_core.runnables import RunnableConfig
|
|
||||||
from langgraph.graph import END, START, MessagesState, StateGraph
|
|
||||||
from langgraph.prebuilt import ToolNode
|
|
||||||
from langchain_community.agent_toolkits import SQLDatabaseToolkit
|
|
||||||
from config.llm import llm, llmThink
|
|
||||||
from langgraph.graph import StateGraph, END
|
|
||||||
from langchain.prompts import PromptTemplate
|
|
||||||
from config.llm import llm
|
|
||||||
from typing import Annotated
|
|
||||||
from typing_extensions import TypedDict
|
|
||||||
from langgraph.graph import StateGraph, START, END
|
|
||||||
from langgraph.graph.message import add_messages
|
|
||||||
from langchain_community.agent_toolkits import create_sql_agent
|
|
||||||
from langchain.prompts import PromptTemplate
|
|
||||||
from config.llm import llm
|
|
||||||
from typing import Annotated
|
|
||||||
from langgraph.graph.message import add_messages
|
|
||||||
import os
|
|
||||||
from langchain_tavily import TavilySearch
|
|
||||||
from langgraph.prebuilt import ToolNode, tools_condition
|
|
||||||
from llm.chatLLM import get_chat_response
|
|
||||||
from typing import TypedDict
|
from typing import TypedDict
|
||||||
|
|
||||||
|
from langchain_core.prompts import PromptTemplate
|
||||||
from langgraph.graph import StateGraph, END
|
from langgraph.graph import StateGraph, END
|
||||||
from llm.summarizeLLM import getSummary
|
|
||||||
import db.postgres as pgdb
|
import db.postgres as pgdb
|
||||||
import db.sqlserver as sqlserver
|
import db.sqlserver as sqlserver
|
||||||
|
from config.llm import llm
|
||||||
|
from config.llm import llmThink
|
||||||
|
|
||||||
|
|
||||||
# -------- 定义状态 --------
|
# -------- 定义状态 --------
|
||||||
@@ -69,7 +50,7 @@ gen_sql_prompt = PromptTemplate(
|
|||||||
|
|
||||||
|
|
||||||
请直接输出完整可执行的 SQL 语句,不要任何其他文字或格式化,例如反引号或 ```sql。
|
请直接输出完整可执行的 SQL 语句,不要任何其他文字或格式化,例如反引号或 ```sql。
|
||||||
"""
|
""",
|
||||||
)
|
)
|
||||||
sqlChain = gen_sql_prompt | llm
|
sqlChain = gen_sql_prompt | llm
|
||||||
|
|
||||||
@@ -89,32 +70,33 @@ fix_prompt = PromptTemplate(
|
|||||||
|
|
||||||
# 输出要求
|
# 输出要求
|
||||||
只返回修正后的 SQL 语句,不包含任何额外的解释或说明。
|
只返回修正后的 SQL 语句,不包含任何额外的解释或说明。
|
||||||
"""
|
""",
|
||||||
)
|
)
|
||||||
fixSQLChain = fix_prompt | llm
|
fixSQLChain = fix_prompt | llm
|
||||||
|
|
||||||
|
|
||||||
def sql(state: State):
|
def sql(state: State):
|
||||||
if state["isFirstGenSQL"]:
|
if state["isFirstGenSQL"]:
|
||||||
state['sql'] = sql_1(state)
|
state["sql"] = sql_1(state)
|
||||||
else:
|
else:
|
||||||
state['sql'] = sql_2(state)
|
state["sql"] = sql_2(state)
|
||||||
for attempt in range(2):
|
for attempt in range(2):
|
||||||
try:
|
try:
|
||||||
# 执行 SQL
|
# 执行 SQL
|
||||||
result = sqlserver.executeSQL(state['sql'])
|
result = sqlserver.executeSQL(state["sql"])
|
||||||
state['sql_result'] = result
|
state["sql_result"] = result
|
||||||
# print("SQL 执行成功,结果:", result)
|
# print("SQL 执行成功,结果:", result)
|
||||||
break
|
break
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
error_msg = str(e)
|
error_msg = str(e)
|
||||||
print(f"SQL 执行出错: {error_msg}")
|
print(f"SQL 执行出错: {error_msg}")
|
||||||
# 调用 LLM 修正 SQL
|
# 调用 LLM 修正 SQL
|
||||||
state['sql'] = fixSQLChain.invoke({
|
state["sql"] = fixSQLChain.invoke(
|
||||||
"sql": state['sql'],
|
{
|
||||||
|
"sql": state["sql"],
|
||||||
"error_msg": error_msg,
|
"error_msg": error_msg,
|
||||||
"table_info": state['table_info'],
|
"table_info": state["table_info"],
|
||||||
"tenant_id": state['tenant_id']
|
"tenant_id": state["tenant_id"],
|
||||||
}
|
}
|
||||||
).content
|
).content
|
||||||
# print(f"LLM 生成修正 SQL: {state['sql']}")
|
# print(f"LLM 生成修正 SQL: {state['sql']}")
|
||||||
@@ -124,11 +106,13 @@ def sql(state: State):
|
|||||||
|
|
||||||
|
|
||||||
def sql_1(state: State):
|
def sql_1(state: State):
|
||||||
return sqlChain.invoke({
|
return sqlChain.invoke(
|
||||||
"table_info": state['table_info'],
|
{
|
||||||
|
"table_info": state["table_info"],
|
||||||
"userInput": state["userInput"],
|
"userInput": state["userInput"],
|
||||||
"tenant_id": state['tenant_id']
|
"tenant_id": state["tenant_id"],
|
||||||
}).content
|
}
|
||||||
|
).content
|
||||||
|
|
||||||
|
|
||||||
improve_sql_prompt = PromptTemplate(
|
improve_sql_prompt = PromptTemplate(
|
||||||
@@ -151,18 +135,20 @@ improve_sql_prompt = PromptTemplate(
|
|||||||
7. 通常来说,不查询对用户来说意义不大的字段,比如主键、外键、id等。
|
7. 通常来说,不查询对用户来说意义不大的字段,比如主键、外键、id等。
|
||||||
8. 查询的SQL字段要用别名,取名参考描述。
|
8. 查询的SQL字段要用别名,取名参考描述。
|
||||||
9. 一般情况下,如果能限制租户Id(通常为tenantid 字段),则尽量加上WHERE tenantid = {tenant_id}。
|
9. 一般情况下,如果能限制租户Id(通常为tenantid 字段),则尽量加上WHERE tenantid = {tenant_id}。
|
||||||
"""
|
""",
|
||||||
)
|
)
|
||||||
improveSqlChain = improve_sql_prompt | llm
|
improveSqlChain = improve_sql_prompt | llm
|
||||||
|
|
||||||
|
|
||||||
def sql_2(state: State):
|
def sql_2(state: State):
|
||||||
return improveSqlChain.invoke({
|
return improveSqlChain.invoke(
|
||||||
"sql": state['sql'],
|
{
|
||||||
"table_info": state['table_info'],
|
"sql": state["sql"],
|
||||||
|
"table_info": state["table_info"],
|
||||||
"userInput": state["userInput"],
|
"userInput": state["userInput"],
|
||||||
"tenant_id": state['tenant_id']
|
"tenant_id": state["tenant_id"],
|
||||||
}).content
|
}
|
||||||
|
).content
|
||||||
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------------ 路径选择 --------
|
# ------------------------------------------------------------------------ 路径选择 --------
|
||||||
@@ -199,7 +185,7 @@ chat → 无法直接生成 SQL,需要进一步解释或澄清。
|
|||||||
|
|
||||||
回答内容仅限于db或者chat,请勿输出其他内容。
|
回答内容仅限于db或者chat,请勿输出其他内容。
|
||||||
你的回复:
|
你的回复:
|
||||||
"""
|
""",
|
||||||
)
|
)
|
||||||
pathSelectChain = pathSelectPrompt | llmThink
|
pathSelectChain = pathSelectPrompt | llmThink
|
||||||
|
|
||||||
@@ -207,12 +193,18 @@ pathSelectChain = pathSelectPrompt | llmThink
|
|||||||
def decide_source(state: State, max_retry=3):
|
def decide_source(state: State, max_retry=3):
|
||||||
"""根据用户输入选择数据来源"""
|
"""根据用户输入选择数据来源"""
|
||||||
for _ in range(max_retry):
|
for _ in range(max_retry):
|
||||||
choice = pathSelectChain.invoke({
|
choice = (
|
||||||
|
pathSelectChain.invoke(
|
||||||
|
{
|
||||||
"userInput": state["userInput"],
|
"userInput": state["userInput"],
|
||||||
"table_info": state["table_info"],
|
"table_info": state["table_info"],
|
||||||
"ai_service": state["ai_service"],
|
"ai_service": state["ai_service"],
|
||||||
"sql": state["sql"]
|
"sql": state["sql"],
|
||||||
}).content.strip().lower()
|
}
|
||||||
|
)
|
||||||
|
.content.strip()
|
||||||
|
.lower()
|
||||||
|
)
|
||||||
print("根据用户输入选择数据来源,路径是:", choice)
|
print("根据用户输入选择数据来源,路径是:", choice)
|
||||||
if choice in ["db", "chat"]:
|
if choice in ["db", "chat"]:
|
||||||
state["path"] = choice
|
state["path"] = choice
|
||||||
@@ -242,17 +234,16 @@ noChatPrompt = PromptTemplate(
|
|||||||
3. 引导用户提出与你业务相关的问题。
|
3. 引导用户提出与你业务相关的问题。
|
||||||
4. 使用礼貌和友好的语气。
|
4. 使用礼貌和友好的语气。
|
||||||
你的回答:
|
你的回答:
|
||||||
"""
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
noChatChain = noChatPrompt | llm
|
noChatChain = noChatPrompt | llm
|
||||||
|
|
||||||
|
|
||||||
def chat(state: State):
|
def chat(state: State):
|
||||||
state["reply"] = noChatChain.invoke({
|
state["reply"] = noChatChain.invoke(
|
||||||
"userInput": state["userInput"],
|
{"userInput": state["userInput"], "ai_service": state["ai_service"]}
|
||||||
"ai_service": state["ai_service"]
|
).content
|
||||||
}).content
|
|
||||||
print("直接回复")
|
print("直接回复")
|
||||||
return state
|
return state
|
||||||
|
|
||||||
@@ -291,19 +282,21 @@ summarizePrompt = PromptTemplate(
|
|||||||
2. 提供进一步可选的查询示例,基于当前的数据库表结构,引导用户提出更具体需求。
|
2. 提供进一步可选的查询示例,基于当前的数据库表结构,引导用户提出更具体需求。
|
||||||
|
|
||||||
你的回复:
|
你的回复:
|
||||||
"""
|
""",
|
||||||
)
|
)
|
||||||
summarizeChain = summarizePrompt | llm
|
summarizeChain = summarizePrompt | llm
|
||||||
|
|
||||||
|
|
||||||
def summarize_ai(state: State):
|
def summarize_ai(state: State):
|
||||||
"""AI 总结输出"""
|
"""AI 总结输出"""
|
||||||
state["reply"] = summarizeChain.invoke({
|
state["reply"] = summarizeChain.invoke(
|
||||||
|
{
|
||||||
"ai_role": state["ai_role"],
|
"ai_role": state["ai_role"],
|
||||||
"sql": state['sql'],
|
"sql": state["sql"],
|
||||||
"userInput": state['userInput'],
|
"userInput": state["userInput"],
|
||||||
"table_info": state['table_info'],
|
"table_info": state["table_info"],
|
||||||
}).content
|
}
|
||||||
|
).content
|
||||||
return state
|
return state
|
||||||
|
|
||||||
|
|
||||||
@@ -322,7 +315,7 @@ workflow.add_conditional_edges(
|
|||||||
{
|
{
|
||||||
"db": "sql_1",
|
"db": "sql_1",
|
||||||
"chat": "chat",
|
"chat": "chat",
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
workflow.add_edge("summarize", END)
|
workflow.add_edge("summarize", END)
|
||||||
workflow.add_edge("chat", END)
|
workflow.add_edge("chat", END)
|
||||||
@@ -334,7 +327,8 @@ def get_db_agent_reply(aiId: str, userInput: str, tenant_id: str, sql: str = "")
|
|||||||
json = pgdb.get_ai_personality(aiId)
|
json = pgdb.get_ai_personality(aiId)
|
||||||
ai_service = json["业务"]
|
ai_service = json["业务"]
|
||||||
ai_role = json["性格"]
|
ai_role = json["性格"]
|
||||||
final_state = graph.invoke({
|
final_state = graph.invoke(
|
||||||
|
{
|
||||||
"ai_service": ai_service,
|
"ai_service": ai_service,
|
||||||
"ai_role": ai_role,
|
"ai_role": ai_role,
|
||||||
"table_info": pgdb.get_available_tables_str(aiId),
|
"table_info": pgdb.get_available_tables_str(aiId),
|
||||||
@@ -342,5 +336,6 @@ def get_db_agent_reply(aiId: str, userInput: str, tenant_id: str, sql: str = "")
|
|||||||
"userInput": userInput,
|
"userInput": userInput,
|
||||||
"sql": sql,
|
"sql": sql,
|
||||||
"isFirstGenSQL": sql == "",
|
"isFirstGenSQL": sql == "",
|
||||||
})
|
}
|
||||||
|
)
|
||||||
return final_state
|
return final_state
|
||||||
|
|||||||
@@ -1,34 +1,13 @@
|
|||||||
|
from typing import List
|
||||||
from typing import Literal
|
|
||||||
from langchain_core.messages import AIMessage
|
|
||||||
from langchain_core.runnables import RunnableConfig
|
|
||||||
from langgraph.graph import END, START, MessagesState, StateGraph
|
|
||||||
from langgraph.prebuilt import ToolNode
|
|
||||||
from langchain_community.agent_toolkits import SQLDatabaseToolkit
|
|
||||||
from config.llm import llm,llmThink
|
|
||||||
from langgraph.graph import StateGraph, END
|
|
||||||
from langchain.prompts import PromptTemplate
|
|
||||||
from config.llm import llm
|
|
||||||
from typing import Annotated
|
|
||||||
from typing_extensions import TypedDict
|
|
||||||
from langgraph.graph import StateGraph, START, END
|
|
||||||
from langgraph.graph.message import add_messages
|
|
||||||
from langchain_community.agent_toolkits import create_sql_agent
|
|
||||||
from langchain.prompts import PromptTemplate
|
|
||||||
from config.llm import llm
|
|
||||||
from typing import Annotated
|
|
||||||
from langgraph.graph.message import add_messages
|
|
||||||
import os
|
|
||||||
from langchain_tavily import TavilySearch
|
|
||||||
from langgraph.prebuilt import ToolNode, tools_condition
|
|
||||||
from llm.chatLLM import get_chat_response
|
|
||||||
from typing import TypedDict
|
from typing import TypedDict
|
||||||
|
|
||||||
|
from langchain_core.prompts import PromptTemplate
|
||||||
from langgraph.graph import StateGraph, END
|
from langgraph.graph import StateGraph, END
|
||||||
from llm.summarizeLLM import getSummary
|
|
||||||
import db.postgres as pgdb
|
|
||||||
import db.sqlserver as sqlserver
|
|
||||||
from typing import List, Dict
|
|
||||||
import db.milvus as milvus
|
import db.milvus as milvus
|
||||||
|
import db.postgres as pgdb
|
||||||
|
from config.llm import llm
|
||||||
|
from config.llm import llmThink
|
||||||
|
|
||||||
|
|
||||||
# -------- 定义状态 --------
|
# -------- 定义状态 --------
|
||||||
@@ -48,6 +27,7 @@ class State(TypedDict):
|
|||||||
userInput: str # 用户输入
|
userInput: str # 用户输入
|
||||||
reply: str # 最终回复
|
reply: str # 最终回复
|
||||||
|
|
||||||
|
|
||||||
# -------- 定义节点 --------
|
# -------- 定义节点 --------
|
||||||
# ------------------------------------------------------------------------ 向量数据库查询 --------
|
# ------------------------------------------------------------------------ 向量数据库查询 --------
|
||||||
|
|
||||||
@@ -65,23 +45,28 @@ gen_sql_prompt = PromptTemplate(
|
|||||||
4. 确保关键词之间相互独立,不包含其他关键词。
|
4. 确保关键词之间相互独立,不包含其他关键词。
|
||||||
关键词之间用空格分隔。
|
关键词之间用空格分隔。
|
||||||
你的回答是:
|
你的回答是:
|
||||||
"""
|
""",
|
||||||
)
|
)
|
||||||
sqlChain = gen_sql_prompt | llm
|
sqlChain = gen_sql_prompt | llm
|
||||||
|
|
||||||
|
|
||||||
def db_search(state: State):
|
def db_search(state: State):
|
||||||
key_words = sqlChain.invoke({
|
key_words = sqlChain.invoke(
|
||||||
"userInput": state['userInput'],
|
{
|
||||||
}).content
|
"userInput": state["userInput"],
|
||||||
|
}
|
||||||
|
).content
|
||||||
print("关键词是:", key_words)
|
print("关键词是:", key_words)
|
||||||
knowledge = milvus.get_knowledge_by_key_words(key_words, state['kn_bases'])
|
knowledge = milvus.get_knowledge_by_key_words(key_words, state["kn_bases"])
|
||||||
print("知识库内容是:", knowledge)
|
print("知识库内容是:", knowledge)
|
||||||
state["knowledge"] = knowledge
|
state["knowledge"] = knowledge
|
||||||
ai_ids = [state['ai_id']]
|
ai_ids = [state["ai_id"]]
|
||||||
memory = milvus.get_memory_by_key_words(key_words, ai_ids)
|
memory = milvus.get_memory_by_key_words(key_words, ai_ids)
|
||||||
print("记忆是:", memory)
|
print("记忆是:", memory)
|
||||||
state["memory"] = memory
|
state["memory"] = memory
|
||||||
return state
|
return state
|
||||||
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------------ 意图分析 --------
|
# ------------------------------------------------------------------------ 意图分析 --------
|
||||||
|
|
||||||
pathSelectPrompt = PromptTemplate(
|
pathSelectPrompt = PromptTemplate(
|
||||||
@@ -103,17 +88,25 @@ pathSelectPrompt = PromptTemplate(
|
|||||||
判断规则如下:
|
判断规则如下:
|
||||||
如果用户最新回复与你的负责工作相关,需要去查知识库,输出“kn”;如果不相关,则输出“chat”,不要包含任何标点符号以及空格。
|
如果用户最新回复与你的负责工作相关,需要去查知识库,输出“kn”;如果不相关,则输出“chat”,不要包含任何标点符号以及空格。
|
||||||
你生成的结果:
|
你生成的结果:
|
||||||
"""
|
""",
|
||||||
)
|
)
|
||||||
pathSelectChain = pathSelectPrompt | llmThink
|
pathSelectChain = pathSelectPrompt | llmThink
|
||||||
|
|
||||||
|
|
||||||
def decide_source(state: State, max_retry=3):
|
def decide_source(state: State, max_retry=3):
|
||||||
"""根据用户输入选择数据来源"""
|
"""根据用户输入选择数据来源"""
|
||||||
for _ in range(max_retry):
|
for _ in range(max_retry):
|
||||||
choice = pathSelectChain.invoke({
|
choice = (
|
||||||
|
pathSelectChain.invoke(
|
||||||
|
{
|
||||||
"userInput": state["userInput"],
|
"userInput": state["userInput"],
|
||||||
"ai_service": state["ai_service"],
|
"ai_service": state["ai_service"],
|
||||||
"history": state["history"],
|
"history": state["history"],
|
||||||
}).content.strip().lower()
|
}
|
||||||
|
)
|
||||||
|
.content.strip()
|
||||||
|
.lower()
|
||||||
|
)
|
||||||
print("根据用户输入选择数据来源,路径是:", choice)
|
print("根据用户输入选择数据来源,路径是:", choice)
|
||||||
if choice in ["kn", "chat"]:
|
if choice in ["kn", "chat"]:
|
||||||
state["path"] = choice
|
state["path"] = choice
|
||||||
@@ -123,6 +116,7 @@ def decide_source(state: State, max_retry=3):
|
|||||||
state["path"] = "chat"
|
state["path"] = "chat"
|
||||||
return state
|
return state
|
||||||
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------------ !普通聊天 --------
|
# ------------------------------------------------------------------------ !普通聊天 --------
|
||||||
noChatPrompt = PromptTemplate(
|
noChatPrompt = PromptTemplate(
|
||||||
input_variables=["ai_name", "ai_service", "ai_role", "history"],
|
input_variables=["ai_name", "ai_service", "ai_role", "history"],
|
||||||
@@ -140,21 +134,26 @@ noChatPrompt = PromptTemplate(
|
|||||||
4. 回复要简洁明了,避免冗长和复杂的表述。
|
4. 回复要简洁明了,避免冗长和复杂的表述。
|
||||||
|
|
||||||
你的回答:
|
你的回答:
|
||||||
"""
|
""",
|
||||||
)
|
)
|
||||||
|
|
||||||
noChatChain = noChatPrompt | llm
|
noChatChain = noChatPrompt | llm
|
||||||
|
|
||||||
|
|
||||||
def chat(state: State):
|
def chat(state: State):
|
||||||
state["reply"] = noChatChain.invoke({
|
state["reply"] = noChatChain.invoke(
|
||||||
|
{
|
||||||
"ai_name": state["ai_name"],
|
"ai_name": state["ai_name"],
|
||||||
"ai_service": state["ai_service"],
|
"ai_service": state["ai_service"],
|
||||||
"ai_role": state["ai_role"],
|
"ai_role": state["ai_role"],
|
||||||
"history": state["history"],
|
"history": state["history"],
|
||||||
"userStr": state["userInput"]
|
"userStr": state["userInput"],
|
||||||
}).content
|
}
|
||||||
|
).content
|
||||||
print("直接回复")
|
print("直接回复")
|
||||||
return state
|
return state
|
||||||
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------------ 整理结果 --------
|
# ------------------------------------------------------------------------ 整理结果 --------
|
||||||
|
|
||||||
summarizePrompt = PromptTemplate(
|
summarizePrompt = PromptTemplate(
|
||||||
@@ -177,32 +176,40 @@ summarizePrompt = PromptTemplate(
|
|||||||
2. 回复的语气要结合你的性格特点。
|
2. 回复的语气要结合你的性格特点。
|
||||||
3. 确保回复内容清晰、简洁、有针对性。
|
3. 确保回复内容清晰、简洁、有针对性。
|
||||||
请生成你的回复:
|
请生成你的回复:
|
||||||
"""
|
""",
|
||||||
)
|
)
|
||||||
summarizeChain = summarizePrompt | llm
|
summarizeChain = summarizePrompt | llm
|
||||||
|
|
||||||
|
|
||||||
def summarize_ai(state: State):
|
def summarize_ai(state: State):
|
||||||
"""AI 总结输出"""
|
"""AI 总结输出"""
|
||||||
mem = state['memory']
|
mem = state["memory"]
|
||||||
if mem != "":
|
if mem != "":
|
||||||
memStr = """
|
memStr = (
|
||||||
|
"""
|
||||||
这是给你参考的相关历史记忆:
|
这是给你参考的相关历史记忆:
|
||||||
<memory>
|
<memory>
|
||||||
%s
|
%s
|
||||||
</memory>
|
</memory>
|
||||||
""" % mem # 这里用 % 把 mem 填进去
|
"""
|
||||||
|
% mem
|
||||||
|
) # 这里用 % 把 mem 填进去
|
||||||
else:
|
else:
|
||||||
memStr = "没有记忆内容"
|
memStr = "没有记忆内容"
|
||||||
print("历史记录是:", state["history"])
|
print("历史记录是:", state["history"])
|
||||||
state["reply"] = summarizeChain.invoke({
|
state["reply"] = summarizeChain.invoke(
|
||||||
|
{
|
||||||
"ai_role": state["ai_role"],
|
"ai_role": state["ai_role"],
|
||||||
"ai_name": state["ai_name"],
|
"ai_name": state["ai_name"],
|
||||||
"history": state["history"],
|
"history": state["history"],
|
||||||
"ai_service":state['ai_service'],
|
"ai_service": state["ai_service"],
|
||||||
"knowledge": state["knowledge"],
|
"knowledge": state["knowledge"],
|
||||||
"memory": memStr,
|
"memory": memStr,
|
||||||
}).content
|
}
|
||||||
|
).content
|
||||||
return state
|
return state
|
||||||
|
|
||||||
|
|
||||||
# ------------------------------------------------------------------------ 构建有向图 --------
|
# ------------------------------------------------------------------------ 构建有向图 --------
|
||||||
workflow = StateGraph(State)
|
workflow = StateGraph(State)
|
||||||
workflow.add_node("decide", decide_source)
|
workflow.add_node("decide", decide_source)
|
||||||
@@ -217,15 +224,18 @@ workflow.add_conditional_edges(
|
|||||||
{
|
{
|
||||||
"kn": "db_search",
|
"kn": "db_search",
|
||||||
"chat": "chat",
|
"chat": "chat",
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
workflow.add_edge("db_search", "summarize")
|
workflow.add_edge("db_search", "summarize")
|
||||||
workflow.add_edge("summarize", END)
|
workflow.add_edge("summarize", END)
|
||||||
workflow.add_edge("chat", END)
|
workflow.add_edge("chat", END)
|
||||||
graph = workflow.compile()
|
graph = workflow.compile()
|
||||||
|
|
||||||
|
|
||||||
# 执行函数
|
# 执行函数
|
||||||
def get_service_agent_reply(aiId:str, userInput: str,history:str, kn_bases:List[str]) :
|
def get_service_agent_reply(
|
||||||
|
aiId: str, userInput: str, history: str, kn_bases: List[str]
|
||||||
|
):
|
||||||
json = pgdb.get_ai_personality(aiId)
|
json = pgdb.get_ai_personality(aiId)
|
||||||
ai_service = json["业务"]
|
ai_service = json["业务"]
|
||||||
ai_role = json["性格"]
|
ai_role = json["性格"]
|
||||||
@@ -233,7 +243,8 @@ def get_service_agent_reply(aiId:str, userInput: str,history:str, kn_bases:List[
|
|||||||
print("AI Name:", ai_name)
|
print("AI Name:", ai_name)
|
||||||
print("AI Service:", ai_service)
|
print("AI Service:", ai_service)
|
||||||
|
|
||||||
final_state = graph.invoke({
|
final_state = graph.invoke(
|
||||||
|
{
|
||||||
"ai_service": ai_service,
|
"ai_service": ai_service,
|
||||||
"ai_role": ai_role,
|
"ai_role": ai_role,
|
||||||
"ai_name": ai_name,
|
"ai_name": ai_name,
|
||||||
@@ -242,5 +253,6 @@ def get_service_agent_reply(aiId:str, userInput: str,history:str, kn_bases:List[
|
|||||||
"table_info": pgdb.get_available_tables_str(aiId),
|
"table_info": pgdb.get_available_tables_str(aiId),
|
||||||
"userInput": userInput,
|
"userInput": userInput,
|
||||||
"ai_id": aiId,
|
"ai_id": aiId,
|
||||||
})
|
}
|
||||||
|
)
|
||||||
return final_state["reply"]
|
return final_state["reply"]
|
||||||
@@ -40,17 +40,14 @@ if sys.platform.lower() == "win32" or os.name.lower() == "nt":
|
|||||||
|
|
||||||
|
|
||||||
def get_device_id_simple():
|
def get_device_id_simple():
|
||||||
try:
|
hostname = os.getenv("HOST_NAME")
|
||||||
with open("/etc/machine-id") as f:
|
if not hostname:
|
||||||
mid = f.read().strip()
|
|
||||||
if mid:
|
|
||||||
return mid
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
hostname = socket.gethostname()
|
hostname = socket.gethostname()
|
||||||
mac = uuid.getnode()
|
mac = uuid.getnode()
|
||||||
mac_str = ":".join(f"{(mac >> ele) & 0xff:02x}" for ele in range(40, -1, -8))
|
mac_str = ":".join(f"{(mac >> ele) & 0xff:02x}" for ele in range(40, -1, -8))
|
||||||
return f"{hostname}|{mac_str}"
|
return f"{hostname}|{mac_str}"
|
||||||
|
else:
|
||||||
|
return hostname
|
||||||
|
|
||||||
|
|
||||||
# todo 这里需要订阅状态信息 设备发送信息 这里回复 vue前端发送指令 后端发送指令 设备接收指令
|
# todo 这里需要订阅状态信息 设备发送信息 这里回复 vue前端发送指令 后端发送指令 设备接收指令
|
||||||
|
|||||||
+13
-14
@@ -1,9 +1,10 @@
|
|||||||
from config.milvus import knVectorstore,memVectorstore
|
|
||||||
from langchain.schema import Document
|
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from typing import List, Dict, Any
|
from langchain_core.documents import Document
|
||||||
|
|
||||||
|
from config.milvus import knVectorstore, memVectorstore
|
||||||
|
|
||||||
|
|
||||||
def get_knowledge_by_key_words(key_words: str, kn_ids: List[str]) -> str:
|
def get_knowledge_by_key_words(key_words: str, kn_ids: List[str]) -> str:
|
||||||
"""
|
"""
|
||||||
@@ -17,9 +18,7 @@ def get_knowledge_by_key_words(key_words: str, kn_ids: List[str]) -> str:
|
|||||||
return "未找到相关的知识。"
|
return "未找到相关的知识。"
|
||||||
|
|
||||||
result = knVectorstore.similarity_search(
|
result = knVectorstore.similarity_search(
|
||||||
query=key_words,
|
query=key_words, k=3, expr=expr # 可调节返回条数
|
||||||
k=3, # 可调节返回条数
|
|
||||||
expr=expr
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 整理成字符串
|
# 整理成字符串
|
||||||
@@ -48,9 +47,7 @@ def get_memory_by_key_words(key_words: str, ai_ids: List[str]) -> str:
|
|||||||
expr = "" # 不限制 kn_id todo 实际上应该不反悔任何内容
|
expr = "" # 不限制 kn_id todo 实际上应该不反悔任何内容
|
||||||
|
|
||||||
result = memVectorstore.similarity_search(
|
result = memVectorstore.similarity_search(
|
||||||
query=key_words,
|
query=key_words, k=5, expr=expr # 可调节返回条数
|
||||||
k=5, # 可调节返回条数
|
|
||||||
expr=expr
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# 整理成字符串
|
# 整理成字符串
|
||||||
@@ -64,12 +61,12 @@ def get_memory_by_key_words(key_words: str, ai_ids: List[str]) -> str:
|
|||||||
# 拼成一个大字符串,用换行隔开
|
# 拼成一个大字符串,用换行隔开
|
||||||
combined_text = "\n\n".join(doc_texts)
|
combined_text = "\n\n".join(doc_texts)
|
||||||
return combined_text
|
return combined_text
|
||||||
|
|
||||||
|
|
||||||
def get_knowledge_by_base_id(base_id: str):
|
def get_knowledge_by_base_id(base_id: str):
|
||||||
expr = f'kn_id == "{base_id}"' # base_id 会被替换
|
expr = f'kn_id == "{base_id}"' # base_id 会被替换
|
||||||
result = knVectorstore.similarity_search(
|
result = knVectorstore.similarity_search(
|
||||||
query="", # 如果只想用过滤条件,可以传空字符串
|
query="", k=100, expr=expr # 如果只想用过滤条件,可以传空字符串
|
||||||
k=100,
|
|
||||||
expr=expr
|
|
||||||
)
|
)
|
||||||
return [
|
return [
|
||||||
{
|
{
|
||||||
@@ -80,6 +77,7 @@ def get_knowledge_by_base_id(base_id: str):
|
|||||||
for doc in result
|
for doc in result
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def add_knowledge(text: str, is_active: bool, base_id: str, user_id: str):
|
def add_knowledge(text: str, is_active: bool, base_id: str, user_id: str):
|
||||||
docs = [
|
docs = [
|
||||||
Document(
|
Document(
|
||||||
@@ -89,11 +87,12 @@ def add_knowledge(text: str, is_active: bool, base_id: str, user_id: str):
|
|||||||
"created_by": str(user_id),
|
"created_by": str(user_id),
|
||||||
"created_at": datetime.now().isoformat(),
|
"created_at": datetime.now().isoformat(),
|
||||||
"is_active": is_active,
|
"is_active": is_active,
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
return knVectorstore.add_documents(docs)
|
return knVectorstore.add_documents(docs)
|
||||||
|
|
||||||
|
|
||||||
def add_memory(ai_id: str, mem: str, user_id: str, is_active: bool):
|
def add_memory(ai_id: str, mem: str, user_id: str, is_active: bool):
|
||||||
docs = [
|
docs = [
|
||||||
Document(
|
Document(
|
||||||
@@ -103,7 +102,7 @@ def add_memory(ai_id:str,mem: str, user_id: str,is_active: bool):
|
|||||||
"created_by": str(user_id),
|
"created_by": str(user_id),
|
||||||
"created_at": datetime.now().isoformat(),
|
"created_at": datetime.now().isoformat(),
|
||||||
"is_active": is_active,
|
"is_active": is_active,
|
||||||
}
|
},
|
||||||
)
|
)
|
||||||
]
|
]
|
||||||
return memVectorstore.add_documents(docs)
|
return memVectorstore.add_documents(docs)
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
|
from langchain_core.prompts import PromptTemplate
|
||||||
|
|
||||||
from config.llm import llm
|
from config.llm import llm
|
||||||
from langchain.prompts import PromptTemplate
|
|
||||||
|
|
||||||
chatPrompt = PromptTemplate(
|
chatPrompt = PromptTemplate(
|
||||||
input_variables=["aiRole", "history", "userInput"],
|
input_variables=["aiRole", "history", "userInput"],
|
||||||
@@ -15,13 +15,12 @@ chatPrompt = PromptTemplate(
|
|||||||
{userInput}
|
{userInput}
|
||||||
|
|
||||||
最后,请注意,不要编造数据,不知道就说不知道,现在,请生成你的回复:
|
最后,请注意,不要编造数据,不知道就说不知道,现在,请生成你的回复:
|
||||||
"""
|
""",
|
||||||
)
|
)
|
||||||
chatChain = chatPrompt | llm
|
chatChain = chatPrompt | llm
|
||||||
|
|
||||||
|
|
||||||
def get_chat_response(aiRole: str, history: str, userInput: str) -> str:
|
def get_chat_response(aiRole: str, history: str, userInput: str) -> str:
|
||||||
return chatChain.invoke({
|
return chatChain.invoke(
|
||||||
"aiRole": aiRole,
|
{"aiRole": aiRole, "history": history, "userInput": userInput}
|
||||||
"history": history,
|
)
|
||||||
"userInput": userInput
|
|
||||||
})
|
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
from langchain.prompts import PromptTemplate
|
from langchain_core.prompts import PromptTemplate
|
||||||
from config.llm import llm,llmThink
|
|
||||||
import db.milvus as milvus
|
import db.milvus as milvus
|
||||||
import db.postgres as pg
|
import db.postgres as pg
|
||||||
import json
|
from config.llm import llmThink
|
||||||
|
|
||||||
memPathPrompt = PromptTemplate(
|
memPathPrompt = PromptTemplate(
|
||||||
input_variables=["ai_role", "CHAT_RECORD"],
|
input_variables=["ai_role", "CHAT_RECORD"],
|
||||||
@@ -31,7 +31,7 @@ no:用户最新回复价值有限或几乎不会在未来业务中使用。
|
|||||||
|
|
||||||
回复不要带任何标点符号以及空格、换行符。
|
回复不要带任何标点符号以及空格、换行符。
|
||||||
请给出你的判断结果:
|
请给出你的判断结果:
|
||||||
"""
|
""",
|
||||||
)
|
)
|
||||||
memPathChain = memPathPrompt | llmThink
|
memPathChain = memPathPrompt | llmThink
|
||||||
memPrompt = PromptTemplate(
|
memPrompt = PromptTemplate(
|
||||||
@@ -48,9 +48,11 @@ memPrompt = PromptTemplate(
|
|||||||
4. 总结内容应包含时间,并确保时间是准确的。
|
4. 总结内容应包含时间,并确保时间是准确的。
|
||||||
5. 你需要针对你的业务场景{ai_role},展开对用户最后回复的总结。
|
5. 你需要针对你的业务场景{ai_role},展开对用户最后回复的总结。
|
||||||
请生成你的总结,以用户、时间开头:
|
请生成你的总结,以用户、时间开头:
|
||||||
"""
|
""",
|
||||||
)
|
)
|
||||||
memChain = memPrompt | llmThink
|
memChain = memPrompt | llmThink
|
||||||
|
|
||||||
|
|
||||||
def take_memory(ai_id: str, sessionId: str, user_id: str, max_retry=3):
|
def take_memory(ai_id: str, sessionId: str, user_id: str, max_retry=3):
|
||||||
"""根据用户输入选择数据来源"""
|
"""根据用户输入选择数据来源"""
|
||||||
history = pg.get_history_with_time(sessionId, 10)
|
history = pg.get_history_with_time(sessionId, 10)
|
||||||
@@ -66,17 +68,29 @@ def take_memory(ai_id:str,sessionId: str,user_id:str, max_retry=3):
|
|||||||
else:
|
else:
|
||||||
ai_service = json["业务"]
|
ai_service = json["业务"]
|
||||||
print("获取的描述是:", ai_service)
|
print("获取的描述是:", ai_service)
|
||||||
choice = memPathChain.invoke({
|
choice = (
|
||||||
|
memPathChain.invoke(
|
||||||
|
{
|
||||||
"ai_role": ai_service,
|
"ai_role": ai_service,
|
||||||
"CHAT_RECORD": history,
|
"CHAT_RECORD": history,
|
||||||
}).content.strip().lower()
|
}
|
||||||
|
)
|
||||||
|
.content.strip()
|
||||||
|
.lower()
|
||||||
|
)
|
||||||
print("记忆判断器判断的结果是:", choice)
|
print("记忆判断器判断的结果是:", choice)
|
||||||
if choice == "yes":
|
if choice == "yes":
|
||||||
# 对对话进行总结
|
# 对对话进行总结
|
||||||
memory = memChain.invoke({
|
memory = (
|
||||||
|
memChain.invoke(
|
||||||
|
{
|
||||||
"CHAT_RECORD": history,
|
"CHAT_RECORD": history,
|
||||||
"ai_role": ai_service,
|
"ai_role": ai_service,
|
||||||
}).content.strip().lower()
|
}
|
||||||
|
)
|
||||||
|
.content.strip()
|
||||||
|
.lower()
|
||||||
|
)
|
||||||
print("记忆生成结果是:", memory)
|
print("记忆生成结果是:", memory)
|
||||||
milvus.add_memory(mem=memory, user_id=user_id, is_active=True, ai_id=ai_id)
|
milvus.add_memory(mem=memory, user_id=user_id, is_active=True, ai_id=ai_id)
|
||||||
return
|
return
|
||||||
+14
-17
@@ -1,8 +1,8 @@
|
|||||||
|
from langchain_community.agent_toolkits import create_sql_agent
|
||||||
|
from langchain_core.prompts import PromptTemplate
|
||||||
|
|
||||||
from config.llm import llm
|
from config.llm import llm
|
||||||
from langchain.prompts import PromptTemplate
|
|
||||||
from config.ssDb import ssDBLC
|
from config.ssDb import ssDBLC
|
||||||
from langchain_community.agent_toolkits import create_sql_agent
|
|
||||||
|
|
||||||
# ______________________________________________________________SQL描述_____________________________________________________________________
|
# ______________________________________________________________SQL描述_____________________________________________________________________
|
||||||
sqlDescriptionPrompt = PromptTemplate(
|
sqlDescriptionPrompt = PromptTemplate(
|
||||||
@@ -16,14 +16,14 @@ sqlDescriptionPrompt = PromptTemplate(
|
|||||||
3. 不能有markdown语法
|
3. 不能有markdown语法
|
||||||
4. 要用业务语言描述,不能有专业语句例如SQL表名等
|
4. 要用业务语言描述,不能有专业语句例如SQL表名等
|
||||||
请生成你认为合适的标题,:
|
请生成你认为合适的标题,:
|
||||||
"""
|
""",
|
||||||
)
|
)
|
||||||
sqlDescriptionChain = sqlDescriptionPrompt | llm
|
sqlDescriptionChain = sqlDescriptionPrompt | llm
|
||||||
|
|
||||||
|
|
||||||
def get_sql_description_response(sql: str) -> str:
|
def get_sql_description_response(sql: str) -> str:
|
||||||
return sqlDescriptionChain.invoke({
|
return sqlDescriptionChain.invoke({"sql": sql})
|
||||||
"sql": sql
|
|
||||||
})
|
|
||||||
|
|
||||||
# ______________________________________________________________第一次生成SQL_____________________________________________________________________
|
# ______________________________________________________________第一次生成SQL_____________________________________________________________________
|
||||||
sqlPrompt = PromptTemplate(
|
sqlPrompt = PromptTemplate(
|
||||||
@@ -34,15 +34,12 @@ sqlPrompt = PromptTemplate(
|
|||||||
只需要返回SQL语句,不要任何解释。
|
只需要返回SQL语句,不要任何解释。
|
||||||
用户需求:{userInput}
|
用户需求:{userInput}
|
||||||
请生成SQL语句:
|
请生成SQL语句:
|
||||||
"""
|
""",
|
||||||
)
|
)
|
||||||
sqlChain = sqlPrompt | llm
|
sqlChain = sqlPrompt | llm
|
||||||
agent = create_sql_agent(
|
agent = create_sql_agent(llm=llm, db=ssDBLC, agent_type="tool-calling", verbose=True)
|
||||||
llm=llm,
|
|
||||||
db=ssDBLC,
|
|
||||||
agent_type="tool-calling",
|
|
||||||
verbose=True
|
|
||||||
)
|
|
||||||
# def get_chat_sql_response2( userInput: str) -> str:
|
# def get_chat_sql_response2( userInput: str) -> str:
|
||||||
# return sqlChain.invoke({
|
# return sqlChain.invoke({
|
||||||
# "userInput": userInput
|
# "userInput": userInput
|
||||||
@@ -50,6 +47,7 @@ agent = create_sql_agent(
|
|||||||
def get_chat_sql_response(userInput: str) -> str:
|
def get_chat_sql_response(userInput: str) -> str:
|
||||||
return agent.invoke({"input": userInput})["output"]
|
return agent.invoke({"input": userInput})["output"]
|
||||||
|
|
||||||
|
|
||||||
# ______________________________________________________________改进SQL_____________________________________________________________________
|
# ______________________________________________________________改进SQL_____________________________________________________________________
|
||||||
sqlImprovePrompt = PromptTemplate(
|
sqlImprovePrompt = PromptTemplate(
|
||||||
input_variables=["userInput", "sql"],
|
input_variables=["userInput", "sql"],
|
||||||
@@ -59,11 +57,10 @@ sqlImprovePrompt = PromptTemplate(
|
|||||||
只需要返回改进后的SQL语句,不要任何解释。
|
只需要返回改进后的SQL语句,不要任何解释。
|
||||||
已有SQL:{sql}
|
已有SQL:{sql}
|
||||||
用户需求:{userInput}
|
用户需求:{userInput}
|
||||||
"""
|
""",
|
||||||
)
|
)
|
||||||
sqlImproveChain = sqlImprovePrompt | llm
|
sqlImproveChain = sqlImprovePrompt | llm
|
||||||
|
|
||||||
|
|
||||||
def get_chat_sql_improve_response(userInput: str) -> str:
|
def get_chat_sql_improve_response(userInput: str) -> str:
|
||||||
return sqlImproveChain.invoke({
|
return sqlImproveChain.invoke({"userInput": userInput})
|
||||||
"userInput": userInput
|
|
||||||
})
|
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
|
from langchain_core.prompts import PromptTemplate
|
||||||
|
|
||||||
from langchain.prompts import PromptTemplate
|
|
||||||
from config.llm import llm
|
from config.llm import llm
|
||||||
|
|
||||||
summarizePrompt = PromptTemplate(
|
summarizePrompt = PromptTemplate(
|
||||||
@@ -21,14 +21,17 @@ summarizePrompt = PromptTemplate(
|
|||||||
{infomation}
|
{infomation}
|
||||||
···
|
···
|
||||||
如果参考内容明显有问题,你要请用户重新描述问题,现在请生成你的回复:
|
如果参考内容明显有问题,你要请用户重新描述问题,现在请生成你的回复:
|
||||||
"""
|
""",
|
||||||
)
|
)
|
||||||
summarizeChain = summarizePrompt | llm
|
summarizeChain = summarizePrompt | llm
|
||||||
|
|
||||||
|
|
||||||
def getSummary(aiRole: str, history: str, userInput: str, infomation: str) -> str:
|
def getSummary(aiRole: str, history: str, userInput: str, infomation: str) -> str:
|
||||||
return summarizeChain.invoke({
|
return summarizeChain.invoke(
|
||||||
|
{
|
||||||
"aiRole": aiRole,
|
"aiRole": aiRole,
|
||||||
"history": history,
|
"history": history,
|
||||||
"userStr": userInput,
|
"userStr": userInput,
|
||||||
"infomation": infomation
|
"infomation": infomation,
|
||||||
}).content
|
}
|
||||||
|
).content
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
|
|
||||||
from langchain.schema import HumanMessage
|
from langchain_core.messages import HumanMessage
|
||||||
|
|
||||||
from config.llm import *
|
from config.llm import *
|
||||||
|
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
import json
|
import json
|
||||||
import re
|
import re
|
||||||
|
|
||||||
from langchain.schema import HumanMessage
|
from langchain_core.messages import HumanMessage
|
||||||
|
|
||||||
from config.llm import *
|
from config.llm import *
|
||||||
from llm.ticketLLM import decode_barcode
|
from llm.ticketLLM import decode_barcode
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
|
from langchain_core.prompts import PromptTemplate
|
||||||
|
|
||||||
from langchain.prompts import PromptTemplate
|
|
||||||
from config.llm import llm
|
from config.llm import llm
|
||||||
|
|
||||||
titlePrompt = PromptTemplate(
|
titlePrompt = PromptTemplate(
|
||||||
@@ -12,9 +12,10 @@ titlePrompt = PromptTemplate(
|
|||||||
4. 保持自然、易懂、专业或有趣(可根据场景调整风格)。
|
4. 保持自然、易懂、专业或有趣(可根据场景调整风格)。
|
||||||
5. 不能出现标点符号。
|
5. 不能出现标点符号。
|
||||||
用户原话:"{userStr}"
|
用户原话:"{userStr}"
|
||||||
"""
|
""",
|
||||||
)
|
)
|
||||||
titleChain = titlePrompt | llm
|
titleChain = titlePrompt | llm
|
||||||
|
|
||||||
|
|
||||||
def get_title(userInput: str):
|
def get_title(userInput: str):
|
||||||
return titleChain.invoke({"userStr": userInput}).content
|
return titleChain.invoke({"userStr": userInput}).content
|
||||||
@@ -23,6 +23,7 @@ python-multipart==0.0.20
|
|||||||
aio_pika==9.5.7
|
aio_pika==9.5.7
|
||||||
ultralytics==8.3.227
|
ultralytics==8.3.227
|
||||||
redis==7.1.0
|
redis==7.1.0
|
||||||
|
aiomqtt==2.4.0
|
||||||
# MCP服务
|
# MCP服务
|
||||||
python-dotenv>=1.0.0
|
python-dotenv>=1.0.0
|
||||||
websockets>=11.0.3
|
websockets>=11.0.3
|
||||||
|
|||||||
Reference in New Issue
Block a user