更新python后端
This commit is contained in:
Generated
+5
@@ -0,0 +1,5 @@
|
||||
# Default ignored files
|
||||
/shelf/
|
||||
/workspace.xml
|
||||
# Editor-based HTTP Client requests
|
||||
/httpRequests/
|
||||
Generated
+6
@@ -0,0 +1,6 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="AI Toolkit Settings">
|
||||
<option name="importsOfInterestPresent" value="true" />
|
||||
</component>
|
||||
</project>
|
||||
Generated
+12
@@ -0,0 +1,12 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<module type="PYTHON_MODULE" version="4">
|
||||
<component name="NewModuleRootManager">
|
||||
<content url="file://$MODULE_DIR$" />
|
||||
<orderEntry type="jdk" jdkName="bbit_ai_lab" jdkType="Python SDK" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
</component>
|
||||
<component name="PyDocumentationSettings">
|
||||
<option name="format" value="PLAIN" />
|
||||
<option name="myDocStringFormat" value="Plain" />
|
||||
</component>
|
||||
</module>
|
||||
Generated
+8
@@ -0,0 +1,8 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<module type="PYTHON_MODULE" version="4">
|
||||
<component name="NewModuleRootManager">
|
||||
<content url="file://$MODULE_DIR$" />
|
||||
<orderEntry type="jdk" jdkName="Python 3.12 (bbit_ai)" jdkType="Python SDK" />
|
||||
<orderEntry type="sourceFolder" forTests="false" />
|
||||
</component>
|
||||
</module>
|
||||
@@ -0,0 +1,6 @@
|
||||
<component name="InspectionProjectProfileManager">
|
||||
<settings>
|
||||
<option name="USE_PROJECT_PROFILE" value="false" />
|
||||
<version value="1.0" />
|
||||
</settings>
|
||||
</component>
|
||||
Generated
+7
@@ -0,0 +1,7 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="Black">
|
||||
<option name="sdkName" value="Python 3.12" />
|
||||
</component>
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="bbit_ai_lab" project-jdk-type="Python SDK" />
|
||||
</project>
|
||||
Generated
+8
@@ -0,0 +1,8 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="ProjectModuleManager">
|
||||
<modules>
|
||||
<module fileurl="file://$PROJECT_DIR$/.idea/app.iml" filepath="$PROJECT_DIR$/.idea/app.iml" />
|
||||
</modules>
|
||||
</component>
|
||||
</project>
|
||||
Generated
+7
@@ -0,0 +1,7 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="VcsDirectoryMappings">
|
||||
<mapping directory="$PROJECT_DIR$/../.." vcs="Git" />
|
||||
<mapping directory="$PROJECT_DIR$/.." vcs="Git" />
|
||||
</component>
|
||||
</project>
|
||||
+1
-2
@@ -30,8 +30,7 @@ RUN pip install --no-cache-dir -i https://pypi.tuna.tsinghua.edu.cn/simple -r re
|
||||
# 复制项目代码
|
||||
COPY app/ .
|
||||
|
||||
# 对外暴露端口(FastAPI 默认 8000)
|
||||
EXPOSE 8000
|
||||
EXPOSE 13011
|
||||
|
||||
# 启动命令(使用 uvicorn 启动 FastAPI)
|
||||
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "13011"]
|
||||
|
||||
@@ -0,0 +1,139 @@
|
||||
|
||||
from langchain.prompts import PromptTemplate
|
||||
from config.llm import llm
|
||||
from config.ssDb import ssDBLC
|
||||
from typing import Annotated
|
||||
from typing_extensions import TypedDict
|
||||
from langgraph.graph import StateGraph, START, END
|
||||
from langgraph.graph.message import add_messages
|
||||
from langchain_community.agent_toolkits import create_sql_agent
|
||||
from langchain.prompts import PromptTemplate
|
||||
from config.llm import llm
|
||||
from config.ssDb import ssDBLC
|
||||
from typing import Annotated
|
||||
from langgraph.graph.message import add_messages
|
||||
import os
|
||||
from langchain_tavily import TavilySearch
|
||||
from langgraph.prebuilt import ToolNode, tools_condition
|
||||
from llm.chatLLM import get_chat_response
|
||||
from typing import TypedDict
|
||||
from langgraph.graph import StateGraph, END
|
||||
from llm.summarizeLLM import getSummary
|
||||
|
||||
|
||||
# -------- 定义状态 --------
|
||||
class State(TypedDict):
|
||||
userInput: str # 用户输入
|
||||
source: str # 选择的数据来源:web 或 db 或 chat
|
||||
infomation: str # 查询到的内容
|
||||
aiRole: str # AI 角色
|
||||
history: str # 聊天历史
|
||||
reply: str # 最终回复
|
||||
|
||||
# -------- 定义节点 --------
|
||||
# ------------------------------------------------------------------------ 路径选择 --------
|
||||
|
||||
pathSelectPrompt = PromptTemplate(
|
||||
input_variables=["aiRole", "history", "userStr", "infomation"],
|
||||
template = """
|
||||
你是主干信息科技有限公司的业务员,是一家蚕桑服务公司,现在需要根据用户输入来判断应该使用哪种方式来回答用户的问题。
|
||||
你有三种选择:
|
||||
1. 如果用户的问题涉及最新的信息,比如新闻、事件、天气等涉及时间的内容时,请选择 "web
|
||||
2. 如果用户的问题涉及具体的蚕桑业务(例如询问农户、订单、订种、租户)的数据库查询需求,请选择 "db"
|
||||
3. 如果用户的问题是一般性的聊天或咨询,请选择 "chat"
|
||||
请只返回 "web"、"db" 或 "chat" 之一,且不要添加任何其他解释。
|
||||
用户最新输入:
|
||||
{userStr}
|
||||
请做出你的选择:
|
||||
"""
|
||||
)
|
||||
pathSelectChain = pathSelectPrompt | llm
|
||||
|
||||
def decide_source(state: State, max_retry=3):
|
||||
print("根据用户输入选择数据来源,用户输入:", state["userInput"])
|
||||
"""根据用户输入选择数据来源"""
|
||||
for _ in range(max_retry):
|
||||
choice = pathSelectChain.invoke({
|
||||
"aiRole": state["aiRole"],
|
||||
"history": state["history"],
|
||||
"userStr": state["userInput"],
|
||||
}).content.strip().lower()
|
||||
if choice in ["web", "db", "chat"]:
|
||||
state["source"] = choice
|
||||
break
|
||||
else:
|
||||
# 如果连续 max_retry 次都不合法,默认走 chat
|
||||
state["source"] = "chat"
|
||||
print("选择的数据来源是:", state["source"])
|
||||
return state
|
||||
|
||||
|
||||
# ------------------------------------------------------------------------ 上网查询 --------
|
||||
os.environ["TAVILY_API_KEY"] = "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"
|
||||
tool = TavilySearch(max_results=2)
|
||||
|
||||
def fetch_web(state: State):
|
||||
result = tool.invoke(state["userInput"])
|
||||
state["infomation"] = result.get("content") or result
|
||||
print("调用了联网工具,结果是:", state["infomation"])
|
||||
return state
|
||||
|
||||
# ------------------------------------------------------------------------ 数据库查询 --------
|
||||
agent = create_sql_agent(
|
||||
llm=llm,
|
||||
db=ssDBLC,
|
||||
agent_type="tool-calling",
|
||||
verbose=True
|
||||
)
|
||||
def fetch_db(state: State):
|
||||
state["infomation"] = agent.invoke({"input": state["userInput"]})["output"]
|
||||
print("调用了数据库工具,结果是:", state["infomation"])
|
||||
return state
|
||||
|
||||
# ------------------------------------------------------------------------ 整理结果 --------
|
||||
def summarize_ai(state: State):
|
||||
"""AI 总结输出"""
|
||||
state["reply"] = getSummary(aiRole=state["aiRole"], history=state["history"], userInput= state["userInput"], infomation= state["infomation"])
|
||||
return state
|
||||
|
||||
# ------------------------------------------------------------------------ 普通聊天 --------
|
||||
def chat(state: State):
|
||||
state["reply"] = get_chat_response(aiRole=state["aiRole"],history=state["history"], userInput= state["userInput"]).content
|
||||
print("直接回复")
|
||||
return state
|
||||
|
||||
# ------------------------------------------------------------------------ 构建有向图 --------
|
||||
workflow = StateGraph(State)
|
||||
workflow.add_node("decide", decide_source)
|
||||
workflow.add_node("fetch_web", fetch_web)
|
||||
workflow.add_node("fetch_db", fetch_db)
|
||||
workflow.add_node("chat", chat)
|
||||
workflow.add_node("summarize", summarize_ai)
|
||||
workflow.set_entry_point("decide")
|
||||
|
||||
# 两条路径最后都汇合到 summarize
|
||||
workflow.add_edge(START, "decide")
|
||||
workflow.add_edge("fetch_web", "summarize")
|
||||
workflow.add_edge("fetch_db", "summarize")
|
||||
# 条件边:根据 source 决定走向
|
||||
workflow.add_conditional_edges(
|
||||
"decide",
|
||||
lambda state: state["source"], # 返回 state["source"] 的值
|
||||
{
|
||||
"web": "fetch_web",
|
||||
"chat": "chat",
|
||||
"db": "fetch_db"
|
||||
}
|
||||
)
|
||||
workflow.add_edge("summarize", END)
|
||||
workflow.add_edge("chat", END)
|
||||
graph = workflow.compile()
|
||||
|
||||
# 执行函数
|
||||
def get_graph_output(aiRole:str,history: str, userInput: str) -> str:
|
||||
final_state = graph.invoke({
|
||||
"aiRole":aiRole,
|
||||
"history": history,
|
||||
"userInput": userInput,
|
||||
})
|
||||
return final_state["reply"]
|
||||
@@ -0,0 +1,346 @@
|
||||
from typing import Literal
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.graph import END, START, MessagesState, StateGraph
|
||||
from langgraph.prebuilt import ToolNode
|
||||
from langchain_community.agent_toolkits import SQLDatabaseToolkit
|
||||
from config.llm import llm, llmThink
|
||||
from langgraph.graph import StateGraph, END
|
||||
from langchain.prompts import PromptTemplate
|
||||
from config.llm import llm
|
||||
from typing import Annotated
|
||||
from typing_extensions import TypedDict
|
||||
from langgraph.graph import StateGraph, START, END
|
||||
from langgraph.graph.message import add_messages
|
||||
from langchain_community.agent_toolkits import create_sql_agent
|
||||
from langchain.prompts import PromptTemplate
|
||||
from config.llm import llm
|
||||
from typing import Annotated
|
||||
from langgraph.graph.message import add_messages
|
||||
import os
|
||||
from langchain_tavily import TavilySearch
|
||||
from langgraph.prebuilt import ToolNode, tools_condition
|
||||
from llm.chatLLM import get_chat_response
|
||||
from typing import TypedDict
|
||||
from langgraph.graph import StateGraph, END
|
||||
from llm.summarizeLLM import getSummary
|
||||
import db.postgres as pgdb
|
||||
import db.sqlserver as sqlserver
|
||||
|
||||
|
||||
# -------- 定义状态 --------
|
||||
class State(TypedDict):
|
||||
userInput: str # 用户输入
|
||||
path: str # 开始聊天选择的路径
|
||||
table_info: str # 可用表信息
|
||||
isFirstGenSQL: bool # 是否第一次生成SQL
|
||||
sql: str # 当前操作的SQL
|
||||
|
||||
ai_service: str # AI 角色 业务
|
||||
ai_role: str # AI 角色 性格特点
|
||||
tenant_id: str # 租户ID
|
||||
|
||||
history: str # 聊天历史
|
||||
reply: str # 最终回复
|
||||
|
||||
|
||||
# -------- 定义节点 --------
|
||||
# ------------------------------------------------------------------------ 数据库查询 --------
|
||||
|
||||
gen_sql_prompt = PromptTemplate(
|
||||
input_variables=["table_info", "userInput"],
|
||||
template="""
|
||||
# 角色
|
||||
你是一个企业 SQL Server 数据库 SQL 生成助手,负责根据用户问题生成相应的 SQL 语句。
|
||||
|
||||
# 已知信息
|
||||
可访问的表和字段:{table_info}
|
||||
|
||||
# 任务要求
|
||||
1. 根据用户提出的问题“{userInput}”生成 SQL 语句。
|
||||
2. 只能使用已知的表和字段。
|
||||
3. 输出完整可执行的 SQL 语句,不包含多余文字。
|
||||
4. 若 SQL 语句返回列表数据,需限制返回数量,最大为 15 条,使用 SQL Server 语法(TOP 15 或 OFFSET FETCH)。
|
||||
5. 若 SQL 是聚合查询(如 COUNT、SUM 等),无需限制行数。
|
||||
6. 在生成 SQL 时,如果需要根据身份证计算年龄,请使用 SQL Server 标准日期格式 SUBSTRING(idcard, 7, 8) 和 CONVERT(..., 112),不要使用拼接 / 或非标准日期格式。
|
||||
7. 通常来说,不查询对用户来说意义不大的字段,比如主键、外键、id等。
|
||||
8. 查询的SQL字段要用别名,取名参考描述。
|
||||
9. 一般情况下,如果能限制租户Id(通常为tenantid 字段),则尽量限制租户id = {tenant_id}。
|
||||
|
||||
|
||||
请直接输出完整可执行的 SQL 语句,不要任何其他文字或格式化,例如反引号或 ```sql。
|
||||
"""
|
||||
)
|
||||
sqlChain = gen_sql_prompt | llm
|
||||
|
||||
fix_prompt = PromptTemplate(
|
||||
input_variables=["sql", "error_msg", "table_info", "tenant_id"],
|
||||
template="""
|
||||
# 系统角色
|
||||
你是一位专业的 SQL Server的SQL语句纠错专家,擅长识别 SQL 语句中的语法错误和字段引用错误,并能对其进行修正。
|
||||
|
||||
# 任务
|
||||
根据提供的原始 SQL 语句、执行报错信息以及可用表和字段信息,修正 SQL 语句,确保其语法正确且引用的字段存在。
|
||||
|
||||
# 输入信息
|
||||
- 原始 SQL: {sql}
|
||||
- 执行报错: {error_msg}
|
||||
- 可用表和字段: {table_info}
|
||||
|
||||
# 输出要求
|
||||
只返回修正后的 SQL 语句,不包含任何额外的解释或说明。
|
||||
"""
|
||||
)
|
||||
fixSQLChain = fix_prompt | llm
|
||||
|
||||
|
||||
def sql(state: State):
|
||||
if state["isFirstGenSQL"]:
|
||||
state['sql'] = sql_1(state)
|
||||
else:
|
||||
state['sql'] = sql_2(state)
|
||||
for attempt in range(2):
|
||||
try:
|
||||
# 执行 SQL
|
||||
result = sqlserver.executeSQL(state['sql'])
|
||||
state['sql_result'] = result
|
||||
# print("SQL 执行成功,结果:", result)
|
||||
break
|
||||
except Exception as e:
|
||||
error_msg = str(e)
|
||||
print(f"SQL 执行出错: {error_msg}")
|
||||
# 调用 LLM 修正 SQL
|
||||
state['sql'] = fixSQLChain.invoke({
|
||||
"sql": state['sql'],
|
||||
"error_msg": error_msg,
|
||||
"table_info": state['table_info'],
|
||||
"tenant_id": state['tenant_id']
|
||||
}
|
||||
).content
|
||||
# print(f"LLM 生成修正 SQL: {state['sql']}")
|
||||
else:
|
||||
raise RuntimeError(f"SQL 多次纠错失败,最后 SQL: {state['sql']}")
|
||||
return state
|
||||
|
||||
|
||||
def sql_1(state: State):
|
||||
return sqlChain.invoke({
|
||||
"table_info": state['table_info'],
|
||||
"userInput": state["userInput"],
|
||||
"tenant_id": state['tenant_id']
|
||||
}).content
|
||||
|
||||
|
||||
improve_sql_prompt = PromptTemplate(
|
||||
input_variables=["table_info", "userInput", "tenant_id"],
|
||||
template="""
|
||||
# 角色
|
||||
你是一个企业 SQL Server 数据库 SQL 生成助手,负责根据用户问题改进相应的 SQL 语句。
|
||||
|
||||
# 已知信息
|
||||
当前 SQL 语句: {sql}
|
||||
可访问的表和字段: {table_info}
|
||||
|
||||
# 任务要求
|
||||
1. 根据用户提出的问题“{userInput}”以及当前的 SQL 语句进行改进。
|
||||
2. 只能使用已知的表和字段。
|
||||
3. 输出完整可执行的 SQL 语句,不包含多余文字。
|
||||
4. 若 SQL 语句返回列表数据,需限制返回数量,最大为 15 条,使用 SQL Server 语法(TOP 15 或 OFFSET FETCH)。
|
||||
5. 若 SQL 是聚合查询(如 COUNT、SUM 等),无需限制行数。
|
||||
6. 在生成 SQL 时,如果需要根据身份证计算年龄,请使用 SQL Server 标准日期格式 SUBSTRING(idcard, 7, 8) 和 CONVERT(..., 112),不要使用拼接 / 或非标准日期格式。
|
||||
7. 通常来说,不查询对用户来说意义不大的字段,比如主键、外键、id等。
|
||||
8. 查询的SQL字段要用别名,取名参考描述。
|
||||
9. 一般情况下,如果能限制租户Id(通常为tenantid 字段),则尽量加上WHERE tenantid = {tenant_id}。
|
||||
"""
|
||||
)
|
||||
improveSqlChain = improve_sql_prompt | llm
|
||||
|
||||
|
||||
def sql_2(state: State):
|
||||
return improveSqlChain.invoke({
|
||||
"sql": state['sql'],
|
||||
"table_info": state['table_info'],
|
||||
"userInput": state["userInput"],
|
||||
"tenant_id": state['tenant_id']
|
||||
}).content
|
||||
|
||||
|
||||
# ------------------------------------------------------------------------ 路径选择 --------
|
||||
|
||||
pathSelectPrompt = PromptTemplate(
|
||||
input_variables=["userInput", "table_info", "sql"],
|
||||
template="""
|
||||
你的任务是:
|
||||
|
||||
1. 根据用户输入的问题和已知的表结构数据,判断是否能够生成准确的 SQL 查询。
|
||||
|
||||
2. 首先仔细阅读以下表结构数据:
|
||||
<table_info>
|
||||
{table_info}
|
||||
</table_info>
|
||||
|
||||
2. 然后仔细阅读用户输入的问题:
|
||||
<userInput>
|
||||
{userInput}
|
||||
</userInput>
|
||||
|
||||
3. 请严格遵循以下规则:
|
||||
|
||||
只有在能够完全、明确、直接根据表结构生成正确 SQL 时,输出 db。
|
||||
|
||||
参考表结构或字段描述中出现的关键词:如果用户问题中出现的关键字段或概念在表结构中找不到明确对应关系,或者问题逻辑无法直接映射到表结构,输出 chat。
|
||||
|
||||
不允许假设额外存在的表、字段或数据,也不允许基于常识或推测生成 SQL。
|
||||
|
||||
输出必须严格二选一:
|
||||
|
||||
db → 可以直接生成 SQL。
|
||||
chat → 无法直接生成 SQL,需要进一步解释或澄清。
|
||||
|
||||
回答内容仅限于db或者chat,请勿输出其他内容。
|
||||
你的回复:
|
||||
"""
|
||||
)
|
||||
pathSelectChain = pathSelectPrompt | llmThink
|
||||
|
||||
|
||||
def decide_source(state: State, max_retry=3):
|
||||
"""根据用户输入选择数据来源"""
|
||||
for _ in range(max_retry):
|
||||
choice = pathSelectChain.invoke({
|
||||
"userInput": state["userInput"],
|
||||
"table_info": state["table_info"],
|
||||
"ai_service": state["ai_service"],
|
||||
"sql": state["sql"]
|
||||
}).content.strip().lower()
|
||||
print("根据用户输入选择数据来源,路径是:", choice)
|
||||
if choice in ["db", "chat"]:
|
||||
state["path"] = choice
|
||||
break
|
||||
else:
|
||||
# 如果连续 max_retry 次都不合法,默认走 chat
|
||||
state["path"] = "chat"
|
||||
return state
|
||||
|
||||
|
||||
# ------------------------------------------------------------------------ !普通聊天 --------
|
||||
noChatPrompt = PromptTemplate(
|
||||
input_variables=["userInput", "ai_service"],
|
||||
template="""
|
||||
你的任务是回复用户,告知用户你目前无法处理他们的回复,因为你的业务是特定领域的服务。请仔细阅读以下信息,并按照指示进行回复。
|
||||
用户的回复:
|
||||
<userInput>
|
||||
{userInput}
|
||||
</userInput>
|
||||
你的业务:
|
||||
<ai_service>
|
||||
{ai_service}
|
||||
</ai_service>
|
||||
在回复时,请遵循以下指南:
|
||||
1. 明确告知用户你无法处理当前回复。
|
||||
2. 提及你的业务是{ai_service}。
|
||||
3. 引导用户提出与你业务相关的问题。
|
||||
4. 使用礼貌和友好的语气。
|
||||
你的回答:
|
||||
"""
|
||||
)
|
||||
|
||||
noChatChain = noChatPrompt | llm
|
||||
|
||||
|
||||
def chat(state: State):
|
||||
state["reply"] = noChatChain.invoke({
|
||||
"userInput": state["userInput"],
|
||||
"ai_service": state["ai_service"]
|
||||
}).content
|
||||
print("直接回复")
|
||||
return state
|
||||
|
||||
|
||||
# ------------------------------------------------------------------------ 整理结果 --------
|
||||
|
||||
summarizePrompt = PromptTemplate(
|
||||
input_variables=["ai_role", "history", "userStr", "table_info"],
|
||||
template="""
|
||||
你是主干信息研发的AI助手,你的性格特点为:
|
||||
<ai_role>
|
||||
{ai_role}
|
||||
</ai_role>
|
||||
用户之前的提问为:
|
||||
<userInput>
|
||||
{userInput}
|
||||
</userInput>
|
||||
当前生成的SQL语句为:
|
||||
<sql>
|
||||
{sql}
|
||||
</sql>
|
||||
当前支持的数据库表与字段信息如下:
|
||||
<table_info>
|
||||
{table_info}
|
||||
</table_info>
|
||||
|
||||
你的核心任务是根据用户之前的提问和当前生成的SQL语句,引导用户理解当前SQL的含义,并询问是否需要修改或完善,同时提供进一步可选的查询示例,引导用户提出更具体的需求。
|
||||
|
||||
交流要求如下:
|
||||
- 先明确SQL的用途,再提出引导性问题。
|
||||
- 回答要简洁、易理解。
|
||||
- 回复内容不要出现SQL语句,不要对SQL进行解释,只需说,查询结果已生成,然后引导用户进一步提问。
|
||||
|
||||
任务流程如下:
|
||||
1. 询问用户是否需要对当前查询内容进行修改或完善。
|
||||
2. 提供进一步可选的查询示例,基于当前的数据库表结构,引导用户提出更具体需求。
|
||||
|
||||
你的回复:
|
||||
"""
|
||||
)
|
||||
summarizeChain = summarizePrompt | llm
|
||||
|
||||
|
||||
def summarize_ai(state: State):
|
||||
"""AI 总结输出"""
|
||||
state["reply"] = summarizeChain.invoke({
|
||||
"ai_role": state["ai_role"],
|
||||
"sql": state['sql'],
|
||||
"userInput": state['userInput'],
|
||||
"table_info": state['table_info'],
|
||||
}).content
|
||||
return state
|
||||
|
||||
|
||||
# ------------------------------------------------------------------------ 构建有向图 --------
|
||||
workflow = StateGraph(State)
|
||||
workflow.add_node("decide", decide_source)
|
||||
workflow.add_node("sql_1", sql)
|
||||
workflow.add_node("chat", chat)
|
||||
workflow.add_node("summarize", summarize_ai)
|
||||
workflow.set_entry_point("decide")
|
||||
workflow.add_edge("sql_1", "summarize")
|
||||
# 条件边:根据 path 决定走向
|
||||
workflow.add_conditional_edges(
|
||||
"decide",
|
||||
lambda state: state["path"], # 返回 state["path"] 的值
|
||||
{
|
||||
"db": "sql_1",
|
||||
"chat": "chat",
|
||||
}
|
||||
)
|
||||
workflow.add_edge("summarize", END)
|
||||
workflow.add_edge("chat", END)
|
||||
graph = workflow.compile()
|
||||
|
||||
|
||||
# 执行函数
|
||||
def get_db_agent_reply(aiId: str, userInput: str, tenant_id: str, sql: str = "") -> str:
|
||||
json = pgdb.get_ai_personality(aiId)
|
||||
ai_service = json["业务"]
|
||||
ai_role = json["性格"]
|
||||
final_state = graph.invoke({
|
||||
"ai_service": ai_service,
|
||||
"ai_role": ai_role,
|
||||
"table_info": pgdb.get_available_tables_str(aiId),
|
||||
"tenant_id": tenant_id,
|
||||
"userInput": userInput,
|
||||
"sql": sql,
|
||||
"isFirstGenSQL": sql == "",
|
||||
})
|
||||
return final_state
|
||||
@@ -0,0 +1,246 @@
|
||||
|
||||
from typing import Literal
|
||||
from langchain_core.messages import AIMessage
|
||||
from langchain_core.runnables import RunnableConfig
|
||||
from langgraph.graph import END, START, MessagesState, StateGraph
|
||||
from langgraph.prebuilt import ToolNode
|
||||
from langchain_community.agent_toolkits import SQLDatabaseToolkit
|
||||
from config.llm import llm,llmThink
|
||||
from langgraph.graph import StateGraph, END
|
||||
from langchain.prompts import PromptTemplate
|
||||
from config.llm import llm
|
||||
from typing import Annotated
|
||||
from typing_extensions import TypedDict
|
||||
from langgraph.graph import StateGraph, START, END
|
||||
from langgraph.graph.message import add_messages
|
||||
from langchain_community.agent_toolkits import create_sql_agent
|
||||
from langchain.prompts import PromptTemplate
|
||||
from config.llm import llm
|
||||
from typing import Annotated
|
||||
from langgraph.graph.message import add_messages
|
||||
import os
|
||||
from langchain_tavily import TavilySearch
|
||||
from langgraph.prebuilt import ToolNode, tools_condition
|
||||
from llm.chatLLM import get_chat_response
|
||||
from typing import TypedDict
|
||||
from langgraph.graph import StateGraph, END
|
||||
from llm.summarizeLLM import getSummary
|
||||
import db.postgres as pgdb
|
||||
import db.sqlserver as sqlserver
|
||||
from typing import List, Dict
|
||||
import db.milvus as milvus
|
||||
|
||||
|
||||
# -------- 定义状态 --------
|
||||
class State(TypedDict):
|
||||
path: str # 开始聊天选择的路径
|
||||
|
||||
memory:str # 记忆
|
||||
knowledge: str # 知识库内容
|
||||
history: str # 聊天历史
|
||||
|
||||
ai_id : str # AI id
|
||||
ai_name:str # AI 名称
|
||||
ai_service: str # AI 角色 业务
|
||||
ai_role: str # AI 角色 性格特点
|
||||
kn_bases: List[str] # AI 所使用的知识库
|
||||
|
||||
userInput: str # 用户输入
|
||||
reply: str # 最终回复
|
||||
|
||||
# -------- 定义节点 --------
|
||||
# ------------------------------------------------------------------------ 向量数据库查询 --------
|
||||
|
||||
gen_sql_prompt = PromptTemplate(
|
||||
input_variables=["userInput"],
|
||||
template = """你的任务是对用户输入进行意图分析,并将其分解成方便进行知识向量数据库搜索的关键词。
|
||||
以下是用户的输入:
|
||||
<用户输入>
|
||||
{userInput}
|
||||
</用户输入>
|
||||
在提取关键词时,请遵循以下方法和要求:
|
||||
1. 去除输入中的停用词(如“的”“是”“在”等)、语气词和无实际意义的符号。
|
||||
2. 识别输入中的核心概念、实体和关键动作。
|
||||
3. 尽量使用简洁、通用的词汇作为关键词。
|
||||
4. 确保关键词之间相互独立,不包含其他关键词。
|
||||
关键词之间用空格分隔。
|
||||
你的回答是:
|
||||
"""
|
||||
)
|
||||
sqlChain = gen_sql_prompt | llm
|
||||
def db_search(state: State):
|
||||
key_words = sqlChain.invoke({
|
||||
"userInput": state['userInput'],
|
||||
}).content
|
||||
print("关键词是:", key_words)
|
||||
knowledge = milvus.get_knowledge_by_key_words(key_words, state['kn_bases'])
|
||||
print("知识库内容是:", knowledge)
|
||||
state["knowledge"] = knowledge
|
||||
ai_ids = [state['ai_id']]
|
||||
memory = milvus.get_memory_by_key_words(key_words, ai_ids)
|
||||
print("记忆是:", memory)
|
||||
state["memory"] = memory
|
||||
return state
|
||||
|
||||
# ------------------------------------------------------------------------ 意图分析 --------
|
||||
|
||||
pathSelectPrompt = PromptTemplate(
|
||||
input_variables=[ "userInput","ai_service","history"],
|
||||
template = """
|
||||
你是一个意图分类器,负责判断用户提问是否与你的工作相关,进而确定是否需要去查知识库。
|
||||
以下是你负责的工作内容:
|
||||
<ai_service>
|
||||
{ai_service}
|
||||
</ai_service>
|
||||
这是你们的对话历史:
|
||||
<history>
|
||||
{history}
|
||||
</history>
|
||||
用户最新回复是:
|
||||
<userInput>
|
||||
{userInput}
|
||||
</userInput>
|
||||
判断规则如下:
|
||||
如果用户最新回复与你的负责工作相关,需要去查知识库,输出“kn”;如果不相关,则输出“chat”,不要包含任何标点符号以及空格。
|
||||
你生成的结果:
|
||||
"""
|
||||
)
|
||||
pathSelectChain = pathSelectPrompt | llmThink
|
||||
def decide_source(state: State, max_retry=3):
|
||||
"""根据用户输入选择数据来源"""
|
||||
for _ in range(max_retry):
|
||||
choice = pathSelectChain.invoke({
|
||||
"userInput": state["userInput"],
|
||||
"ai_service": state["ai_service"],
|
||||
"history": state["history"],
|
||||
}).content.strip().lower()
|
||||
print("根据用户输入选择数据来源,路径是:", choice)
|
||||
if choice in ["kn", "chat"]:
|
||||
state["path"] = choice
|
||||
break
|
||||
else:
|
||||
# 如果连续 max_retry 次都不合法,默认走 chat
|
||||
state["path"] = "chat"
|
||||
return state
|
||||
|
||||
# ------------------------------------------------------------------------ !普通聊天 --------
|
||||
noChatPrompt = PromptTemplate(
|
||||
input_variables=[ "ai_name", "ai_service", "ai_role", "history"],
|
||||
template = """
|
||||
你的名字是:{ai_name},你负责的业务是{ai_service},你具有{ai_role}的性格特点。
|
||||
|
||||
这是你和用户的对话历史
|
||||
<history>
|
||||
{history}
|
||||
</history>
|
||||
在回复用户时,请遵循以下指南:
|
||||
1. 回复要与AI角色业务相关,体现AI的专业能力。
|
||||
2. 回复内容的语气和风格要符合AI角色性格特点。
|
||||
3. 参考聊天历史,使回复具有连贯性和针对性。
|
||||
4. 回复要简洁明了,避免冗长和复杂的表述。
|
||||
|
||||
你的回答:
|
||||
"""
|
||||
)
|
||||
|
||||
noChatChain = noChatPrompt | llm
|
||||
def chat(state: State):
|
||||
state["reply"] = noChatChain.invoke({
|
||||
"ai_name": state["ai_name"],
|
||||
"ai_service": state["ai_service"],
|
||||
"ai_role": state["ai_role"],
|
||||
"history": state["history"],
|
||||
"userStr": state["userInput"]
|
||||
}).content
|
||||
print("直接回复")
|
||||
return state
|
||||
|
||||
# ------------------------------------------------------------------------ 整理结果 --------
|
||||
|
||||
summarizePrompt = PromptTemplate(
|
||||
input_variables=["ai_name", "ai_service", "ai_role", "history", "knowledge"],
|
||||
template = """
|
||||
你的任务是基于给定的AI名称、AI角色业务、AI角色性格特点和聊天历史来回复用户。请仔细阅读以下信息,并按照指示进行回复。
|
||||
你的名字是:{ai_name},你负责的业务是{ai_service},你具有{ai_role}的性格特点。
|
||||
|
||||
这是你和用户的对话历史
|
||||
<history>
|
||||
{history}
|
||||
</history>
|
||||
这是给你参考的知识库:
|
||||
<knowledge>
|
||||
{knowledge}
|
||||
</knowledge>
|
||||
{memory}
|
||||
在回复时,请遵循以下指南:
|
||||
1. 回复内容要与你负责的业务相关。
|
||||
2. 回复的语气要结合你的性格特点。
|
||||
3. 确保回复内容清晰、简洁、有针对性。
|
||||
请生成你的回复:
|
||||
"""
|
||||
)
|
||||
summarizeChain = summarizePrompt | llm
|
||||
def summarize_ai(state: State):
|
||||
"""AI 总结输出"""
|
||||
mem = state['memory']
|
||||
if mem != "":
|
||||
memStr = """
|
||||
这是给你参考的相关历史记忆:
|
||||
<memory>
|
||||
%s
|
||||
</memory>
|
||||
""" % mem # 这里用 % 把 mem 填进去
|
||||
else:
|
||||
memStr = "没有记忆内容"
|
||||
print("历史记录是:" ,state["history"])
|
||||
state["reply"] = summarizeChain.invoke({
|
||||
"ai_role":state["ai_role"],
|
||||
"ai_name":state["ai_name"],
|
||||
"history":state["history"],
|
||||
"ai_service":state['ai_service'],
|
||||
"knowledge": state["knowledge"],
|
||||
"memory": memStr,
|
||||
}).content
|
||||
return state
|
||||
|
||||
# ------------------------------------------------------------------------ 构建有向图 --------
|
||||
workflow = StateGraph(State)
|
||||
workflow.add_node("decide", decide_source)
|
||||
workflow.add_node("db_search", db_search)
|
||||
workflow.add_node("chat", chat)
|
||||
workflow.add_node("summarize", summarize_ai)
|
||||
workflow.set_entry_point("decide")
|
||||
# 条件边:根据 path 决定走向
|
||||
workflow.add_conditional_edges(
|
||||
"decide",
|
||||
lambda state: state["path"], # 返回 state["path"] 的值
|
||||
{
|
||||
"kn": "db_search",
|
||||
"chat": "chat",
|
||||
}
|
||||
)
|
||||
workflow.add_edge("db_search", "summarize")
|
||||
workflow.add_edge("summarize", END)
|
||||
workflow.add_edge("chat", END)
|
||||
graph = workflow.compile()
|
||||
|
||||
# 执行函数
|
||||
def get_service_agent_reply(aiId:str, userInput: str,history:str, kn_bases:List[str]) :
|
||||
json = pgdb.get_ai_personality(aiId)
|
||||
ai_service = json["业务"]
|
||||
ai_role = json["性格"]
|
||||
ai_name = json["名字"]
|
||||
print("AI Name:", ai_name)
|
||||
print("AI Service:", ai_service)
|
||||
|
||||
final_state = graph.invoke({
|
||||
"ai_service":ai_service,
|
||||
"ai_role":ai_role,
|
||||
"ai_name":ai_name,
|
||||
"history":history,
|
||||
"kn_bases":kn_bases,
|
||||
"table_info": pgdb.get_available_tables_str(aiId),
|
||||
"userInput": userInput,
|
||||
"ai_id": aiId,
|
||||
})
|
||||
return final_state["reply"]
|
||||
+13
-3
@@ -1,10 +1,15 @@
|
||||
from fastapi import FastAPI
|
||||
from routers.Chat import router
|
||||
from routers.Chat import chatRouter
|
||||
from routers.Report import reportRouter
|
||||
from routers.Datasource import reportDataRouter
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from routers.Knowledge import knowledgeRouter
|
||||
from routers.Service import serviceRouter
|
||||
from routers.Bot import botRouter
|
||||
app = FastAPI(title="BBIT_AI")
|
||||
|
||||
origins = [
|
||||
"http://localhost:5173", # Vite dev 默认端口
|
||||
"http://localhost:8090", # Vite dev 默认端口
|
||||
"http://127.0.0.1:5173",
|
||||
"http://s1.ronsunny.cn:8089",
|
||||
"*" # ⚠️ 生产环境不要用
|
||||
@@ -17,4 +22,9 @@ app.add_middleware(
|
||||
allow_methods=["*"], # 必须包含 OPTIONS、GET 等
|
||||
allow_headers=["*"],
|
||||
)
|
||||
app.include_router(router, prefix="/api/llm", tags=["chat"])
|
||||
app.include_router(chatRouter, prefix="/api/llm", tags=["chat"])
|
||||
app.include_router(reportRouter, prefix="/api/llm", tags=["chat"])
|
||||
app.include_router(knowledgeRouter, prefix="/api/llm", tags=["chat"])
|
||||
app.include_router(reportDataRouter, prefix="/api/llm", tags=["chat"])
|
||||
app.include_router(serviceRouter, prefix="/api/llm", tags=["chat"])
|
||||
app.include_router(botRouter, prefix="/api/llm", tags=["chat"])
|
||||
@@ -1,6 +1,34 @@
|
||||
|
||||
from langchain_community.chat_models.tongyi import ChatTongyi
|
||||
from utils.Tools import all_tools
|
||||
from langchain.chat_models import init_chat_model
|
||||
from langchain_openai import ChatOpenAI
|
||||
from openai import OpenAI
|
||||
import os
|
||||
from langchain_openai import OpenAIEmbeddings
|
||||
# 通义千文Key
|
||||
tongyiKey = "sk-9464b2498c184982a9fe9d2c2e725ab5"
|
||||
# DeepSeekKey
|
||||
deepseekKey = "sk-6129a200ae294b9f86553505191fa477"
|
||||
|
||||
llm = ChatTongyi(streaming=False, api_key="sk-fb46eefb6b404382a0a5325202e923a6")
|
||||
llm = ChatTongyi(streaming=False, api_key=tongyiKey)
|
||||
llm_with_tools = llm.bind_tools(all_tools)
|
||||
|
||||
llmThink = ChatOpenAI(
|
||||
api_key=tongyiKey,
|
||||
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
|
||||
model="qwen-max",
|
||||
stream = False
|
||||
)
|
||||
from langchain_community.embeddings import DashScopeEmbeddings
|
||||
embeddings = DashScopeEmbeddings(
|
||||
model="text-embedding-v3",
|
||||
dashscope_api_key= tongyiKey,
|
||||
)
|
||||
|
||||
# from langchain_deepseek import ChatDeepSeek
|
||||
# llm = ChatDeepSeek(
|
||||
# model="deepseek-reasoner",
|
||||
# api_key=deepseekKey,
|
||||
# api_base="https://api.deepseek.com"
|
||||
# )
|
||||
|
||||
@@ -0,0 +1,35 @@
|
||||
from langchain_milvus import BM25BuiltInFunction, Milvus
|
||||
from config.llm import embeddings
|
||||
|
||||
URI = "http://10.10.10.9:19530"
|
||||
|
||||
knVectorstore = Milvus(
|
||||
embedding_function=embeddings,
|
||||
connection_args={"uri": URI, "token": "root:Milvus", "db_name": "bbit_ai_lab"},
|
||||
collection_name="knowledge",
|
||||
index_params={"index_type": "FLAT", "metric_type": "L2"},
|
||||
consistency_level="Strong",
|
||||
auto_id=True,
|
||||
|
||||
primary_field = "id",
|
||||
text_field="text",
|
||||
vector_field="vector",
|
||||
partition_key_field = "kn_id",
|
||||
enable_dynamic_field = True,
|
||||
drop_old=False, # set to True if seeking to drop the collection with that name if it exists
|
||||
)
|
||||
memVectorstore = Milvus(
|
||||
embedding_function=embeddings,
|
||||
connection_args={"uri": URI, "token": "root:Milvus", "db_name": "bbit_ai_lab"},
|
||||
collection_name="memory",
|
||||
index_params={"index_type": "FLAT", "metric_type": "L2"},
|
||||
consistency_level="Strong",
|
||||
auto_id=True,
|
||||
|
||||
primary_field = "id",
|
||||
text_field="text",
|
||||
vector_field="vector",
|
||||
partition_key_field = "ai_id",
|
||||
enable_dynamic_field = True,
|
||||
drop_old=False, # set to True if seeking to drop the collection with that name if it exists
|
||||
)
|
||||
+64
-17
@@ -1,22 +1,69 @@
|
||||
|
||||
import psycopg
|
||||
from langchain_postgres import PostgresChatMessageHistory
|
||||
from psycopg_pool import ConnectionPool
|
||||
import logging
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from typing import Optional
|
||||
import psycopg
|
||||
from psycopg_pool import ConnectionPool
|
||||
|
||||
# conn = psycopg.connect("postgresql://postgres:123456@10.10.10.9/ktor2")
|
||||
database_name = "ai_chat_history"
|
||||
pool = ConnectionPool("postgresql://postgres:123456@10.10.10.9/ktor2")
|
||||
logger = logging.getLogger("PGPool")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
@contextmanager
|
||||
def getConn():
|
||||
with pool.connection() as temp:
|
||||
temp.autocommit = True # 如果你想所有连接默认 autocommit
|
||||
yield temp # 把 conn 暴露给外部使用
|
||||
class PGPool:
|
||||
"""
|
||||
PostgreSQL 连接池封装
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
uri: str,
|
||||
min_size: int = 1,
|
||||
max_size: int = 20,
|
||||
max_idle: int = 30,
|
||||
max_lifetime: int = 300,
|
||||
timeout: int = 10,
|
||||
check: bool = False,
|
||||
):
|
||||
"""
|
||||
:param uri: PostgreSQL 连接 URI
|
||||
"""
|
||||
self.uri = uri
|
||||
self.pool = ConnectionPool(
|
||||
self.uri,
|
||||
min_size=min_size,
|
||||
max_size=max_size,
|
||||
max_idle=max_idle,
|
||||
max_lifetime=max_lifetime,
|
||||
timeout=timeout,
|
||||
check=check,
|
||||
)
|
||||
|
||||
def init():
|
||||
with getConn() as connection:
|
||||
PostgresChatMessageHistory.create_tables(connection, database_name)
|
||||
|
||||
init()
|
||||
@contextmanager
|
||||
def getConn(self, retries: int = 2, delay: float = 1.0):
|
||||
"""
|
||||
获取数据库连接,带重试机制,自动健康检查。
|
||||
使用方式:
|
||||
with pg_pool.get_conn() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(...)
|
||||
"""
|
||||
attempt = 0
|
||||
while attempt <= retries:
|
||||
try:
|
||||
with self.pool.connection() as conn:
|
||||
conn.autocommit = True
|
||||
yield conn
|
||||
return
|
||||
except psycopg.OperationalError as e:
|
||||
logger.warning(f"数据库连接异常: {e}. 尝试重试 ({attempt+1}/{retries})")
|
||||
self.pool.check() # 丢掉坏连接,重新建
|
||||
attempt += 1
|
||||
time.sleep(delay)
|
||||
except Exception as e:
|
||||
logger.error(f"SQL执行异常: {e}")
|
||||
raise
|
||||
raise psycopg.OperationalError("无法获取数据库连接,多次重试失败")
|
||||
pg_pool = PGPool(
|
||||
uri="postgresql://postgres:123456@10.10.10.9/ktor2",
|
||||
min_size=1,
|
||||
max_size=20,
|
||||
)
|
||||
@@ -2,11 +2,81 @@ from langchain_community.utilities import SQLDatabase
|
||||
from langchain_community.agent_toolkits import create_sql_agent
|
||||
from langchain_community.chat_models.tongyi import ChatTongyi
|
||||
# SQLAlchemy URI
|
||||
uri = "mssql+pyodbc://f8_db_test:APN^QPr!K9@122.114.58.23/f8_db_test?driver=ODBC+Driver+18+for+SQL+Server&TrustServerCertificate=yes"
|
||||
uri = "mssql+pyodbc://f8_db_test:APN^QPr!K9@122.114.58.23/f8_db_test?driver=ODBC+Driver+18+for+SQL+Server&TrustServerCertificate=yes&Encrypt=no"
|
||||
|
||||
# 建立数据库对象
|
||||
ssDB = SQLDatabase.from_uri(
|
||||
ssDBLC = SQLDatabase.from_uri(
|
||||
uri,
|
||||
include_tables=["NONGHU_INFO"], # 不加 schema 前缀试试
|
||||
include_tables=["NONGHU_INFO","NONGHU_BLACKLIST",],
|
||||
schema="dbo" # 显式指定 schema
|
||||
)
|
||||
)
|
||||
import logging
|
||||
import time
|
||||
from contextlib import contextmanager
|
||||
from sqlalchemy import create_engine, text
|
||||
from sqlalchemy.exc import OperationalError as SQLOperationalError
|
||||
from urllib.parse import quote_plus
|
||||
|
||||
logger = logging.getLogger("MSSQLPool")
|
||||
logger.setLevel(logging.INFO)
|
||||
|
||||
class MSSQLPool:
|
||||
"""
|
||||
SQL Server 连接池封装
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
user: str,
|
||||
password: str,
|
||||
host: str,
|
||||
database: str,
|
||||
driver: str = "ODBC Driver 18 for SQL Server",
|
||||
encrypt: str = "no",
|
||||
trust_server_certificate: str = "yes",
|
||||
min_size: int = 1,
|
||||
max_size: int = 10,
|
||||
pool_timeout: int = 10,
|
||||
):
|
||||
self.user = user
|
||||
self.password = quote_plus(password) # 处理特殊字符
|
||||
self.host = host
|
||||
self.database = database
|
||||
self.driver = driver
|
||||
self.encrypt = encrypt
|
||||
self.trust_server_certificate = trust_server_certificate
|
||||
|
||||
self.engine = create_engine(
|
||||
f"mssql+pyodbc://{self.user}:{self.password}@{self.host}/{self.database}"
|
||||
f"?driver={self.driver}&Encrypt={self.encrypt}&TrustServerCertificate={self.trust_server_certificate}",
|
||||
pool_size=min_size,
|
||||
max_overflow=max_size - min_size,
|
||||
pool_timeout=pool_timeout,
|
||||
)
|
||||
|
||||
@contextmanager
|
||||
def getConn(self, retries: int = 2, delay: float = 1.0):
|
||||
"""
|
||||
获取连接,带重试
|
||||
"""
|
||||
attempt = 0
|
||||
while attempt <= retries:
|
||||
try:
|
||||
with self.engine.connect() as conn:
|
||||
yield conn
|
||||
return
|
||||
except SQLOperationalError as e:
|
||||
logger.warning(f"数据库连接异常: {e}. 尝试重试 ({attempt+1}/{retries})")
|
||||
attempt += 1
|
||||
time.sleep(delay)
|
||||
except Exception as e:
|
||||
logger.error(f"SQL执行异常: {e}")
|
||||
raise
|
||||
raise SQLOperationalError("无法获取数据库连接,多次重试失败")
|
||||
|
||||
mssql_pool = MSSQLPool(
|
||||
user="f8_db_test",
|
||||
password="APN^QPr!K9",
|
||||
host="122.114.58.23",
|
||||
database="f8_db_test",
|
||||
)
|
||||
|
||||
@@ -0,0 +1,109 @@
|
||||
from config.milvus import knVectorstore,memVectorstore
|
||||
from langchain.schema import Document
|
||||
from datetime import datetime
|
||||
from typing import List
|
||||
|
||||
from typing import List, Dict, Any
|
||||
|
||||
def get_knowledge_by_key_words(key_words: str, kn_ids: List[str]) -> str:
|
||||
"""
|
||||
根据关键词和 kn_ids 列表,在知识库中检索相关内容,并返回整理后的文本字符串
|
||||
"""
|
||||
# 构建过滤表达式:只查 kn_ids 范围内的
|
||||
if kn_ids:
|
||||
ids_expr = " or ".join([f'kn_id == "{kid}"' for kid in kn_ids])
|
||||
expr = f"({ids_expr})"
|
||||
else:
|
||||
return "未找到相关的知识。"
|
||||
|
||||
result = knVectorstore.similarity_search(
|
||||
query=key_words,
|
||||
k=3, # 可调节返回条数
|
||||
expr=expr
|
||||
)
|
||||
|
||||
# 整理成字符串
|
||||
doc_texts = []
|
||||
for idx, doc in enumerate(result, start=1):
|
||||
text = doc.page_content.strip()
|
||||
if text:
|
||||
# 可以加个编号,便于LLM区分
|
||||
doc_texts.append(f"[文档{idx}]: {text}")
|
||||
|
||||
# 拼成一个大字符串,用换行隔开
|
||||
combined_text = "\n\n".join(doc_texts)
|
||||
return combined_text
|
||||
|
||||
|
||||
def get_memory_by_key_words(key_words: str, ai_ids: List[str]) -> str:
|
||||
print("ai_id是:" , ai_ids)
|
||||
"""
|
||||
根据关键词和 ai_ids 列表,在知识库中检索相关内容,并返回整理后的文本字符串
|
||||
"""
|
||||
# 构建过滤表达式:只查 kn_ids 范围内的
|
||||
if ai_ids:
|
||||
ids_expr = " or ".join([f'ai_id == "{kid}"' for kid in ai_ids])
|
||||
expr = f"({ids_expr})"
|
||||
else:
|
||||
expr = "" # 不限制 kn_id todo 实际上应该不反悔任何内容
|
||||
|
||||
result = memVectorstore.similarity_search(
|
||||
query=key_words,
|
||||
k=5, # 可调节返回条数
|
||||
expr=expr
|
||||
)
|
||||
|
||||
# 整理成字符串
|
||||
doc_texts = []
|
||||
for idx, doc in enumerate(result, start=1):
|
||||
text = doc.page_content.strip()
|
||||
if text:
|
||||
# 可以加个编号,便于LLM区分
|
||||
doc_texts.append(f"[记忆{idx}]: {text}")
|
||||
|
||||
# 拼成一个大字符串,用换行隔开
|
||||
combined_text = "\n\n".join(doc_texts)
|
||||
return combined_text
|
||||
def get_knowledge_by_base_id(base_id: str):
|
||||
expr = f'kn_id == "{base_id}"' # base_id 会被替换
|
||||
result = knVectorstore.similarity_search(
|
||||
query="", # 如果只想用过滤条件,可以传空字符串
|
||||
k=100,
|
||||
expr=expr
|
||||
)
|
||||
return [
|
||||
{
|
||||
"id": str(doc.metadata["id"]),
|
||||
"text": doc.page_content,
|
||||
"is_active": doc.metadata["is_active"],
|
||||
}
|
||||
for doc in result
|
||||
]
|
||||
|
||||
def add_knowledge(text: str, is_active: bool, base_id: str, user_id: str):
|
||||
docs = [
|
||||
Document(
|
||||
page_content=text,
|
||||
metadata={
|
||||
"kn_id": str(base_id),
|
||||
"created_by": str(user_id),
|
||||
"created_at": datetime.now().isoformat(),
|
||||
"is_active": is_active,
|
||||
}
|
||||
)
|
||||
]
|
||||
return knVectorstore.add_documents(docs)
|
||||
|
||||
def add_memory(ai_id:str,mem: str, user_id: str,is_active: bool):
|
||||
docs = [
|
||||
Document(
|
||||
page_content=mem,
|
||||
metadata={
|
||||
"ai_id": str(ai_id),
|
||||
"created_by": str(user_id),
|
||||
"created_at": datetime.now().isoformat(),
|
||||
"is_active": is_active,
|
||||
}
|
||||
)
|
||||
]
|
||||
return memVectorstore.add_documents(docs)
|
||||
+324
-30
@@ -1,11 +1,16 @@
|
||||
import psycopg
|
||||
from langchain_postgres import PostgresChatMessageHistory
|
||||
from config.pgDb import database_name,getConn
|
||||
from config.pgDb import pg_pool
|
||||
from config.ssDb import mssql_pool
|
||||
from typing import List, Dict
|
||||
import json
|
||||
|
||||
# ————————————————————————————————————————————————————AI角色———————————————————————————————
|
||||
|
||||
database_name = "ai_chat_history"
|
||||
|
||||
|
||||
def get_ai_personality(ai_id: str):
|
||||
with getConn() as conn:
|
||||
with pg_pool.getConn() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"SELECT ai_personality FROM ai_chat_profiles WHERE id = %s",
|
||||
@@ -16,9 +21,33 @@ def get_ai_personality(ai_id: str):
|
||||
return row[0]
|
||||
else:
|
||||
return "你是一个乐于助人的AI助手,请保持中文简洁回答用户。"
|
||||
|
||||
def get_all_ai(user_id: str) -> List[Dict]:
|
||||
with getConn() as conn:
|
||||
|
||||
|
||||
def get_description(ai_id: str):
|
||||
with pg_pool.getConn() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"SELECT description FROM ai_chat_profiles WHERE id = %s",
|
||||
(ai_id,)
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if row:
|
||||
return row[0]
|
||||
else:
|
||||
return "你是一个乐于助人的AI助手,请保持中文简洁回答用户。"
|
||||
|
||||
|
||||
def get_ai_available_kn_bases(ai_id: str) -> List[str]:
|
||||
with pg_pool.getConn() as conn:
|
||||
result = conn.execute(
|
||||
"SELECT available_kn_bases FROM ai_chat_profiles WHERE id = %s",
|
||||
(ai_id,)
|
||||
)
|
||||
return result.fetchone()[0]
|
||||
|
||||
|
||||
def get_all_ai_bot(user_id: str, module: str) -> List[Dict]:
|
||||
with pg_pool.getConn() as conn:
|
||||
with conn.cursor() as cur:
|
||||
# 查询用户角色
|
||||
cur.execute(
|
||||
@@ -31,27 +60,43 @@ def get_all_ai(user_id: str) -> List[Dict]:
|
||||
user_roles = role_row[0]
|
||||
|
||||
# 查询 AI 角色 JSON 字段包含用户角色
|
||||
|
||||
cur.execute(
|
||||
"""
|
||||
SELECT id, name, welcome_words
|
||||
SELECT id, title, description, welcome_words, ai_personality, available_report_tables, available_kn_bases
|
||||
FROM ai_chat_profiles
|
||||
WHERE availabel_roles::jsonb ?| %s
|
||||
WHERE available_module = %s
|
||||
AND is_active = TRUE
|
||||
AND available_roles::jsonb ?| %s
|
||||
""",
|
||||
(user_roles,) # user_roles 是 list,比如 ["a", "b", "c"]
|
||||
(module, user_roles)
|
||||
)
|
||||
rows = cur.fetchall()
|
||||
return [
|
||||
{
|
||||
"id": row[0],
|
||||
"name": row[1],
|
||||
"welcomeWords": row[2],
|
||||
}
|
||||
for row in rows
|
||||
]
|
||||
result = []
|
||||
for row in rows:
|
||||
# row 索引对应 SELECT 字段顺序
|
||||
id_, title, description, welcome_words, ai_personality, available_report_tables, available_kn_bases = row
|
||||
|
||||
# 解析 JSON
|
||||
roles_json = ai_personality if ai_personality else {}
|
||||
result.append({
|
||||
"id": id_,
|
||||
"title": title,
|
||||
"description": description,
|
||||
"welcome_words": welcome_words,
|
||||
"name": roles_json.get("名字", ""),
|
||||
"role": roles_json.get("性格", ""),
|
||||
"service": roles_json.get("业务", ""),
|
||||
"available_report_tables": available_report_tables,
|
||||
"available_kn_bases": available_kn_bases
|
||||
})
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# ————————————————————————————————————————————————————消息———————————————————————————————
|
||||
def insert_message(session_id: str, isAI: bool, content: str):
|
||||
with getConn() as conn:
|
||||
with pg_pool.getConn() as conn:
|
||||
history = PostgresChatMessageHistory(
|
||||
database_name,
|
||||
session_id,
|
||||
@@ -62,34 +107,56 @@ def insert_message(session_id: str, isAI: bool, content: str):
|
||||
else:
|
||||
history.add_user_message(content)
|
||||
|
||||
|
||||
def get_history(session_id: str):
|
||||
with getConn() as conn:
|
||||
simplified = []
|
||||
with pg_pool.getConn() as conn:
|
||||
history = PostgresChatMessageHistory(
|
||||
database_name,
|
||||
session_id,
|
||||
sync_connection=conn
|
||||
)
|
||||
simplified = []
|
||||
for msg in history.messages:
|
||||
simplified.append({
|
||||
"type": msg.type,
|
||||
"content": msg.content
|
||||
})
|
||||
return simplified
|
||||
|
||||
|
||||
def get_history_with_time(session_id: str, number: int):
|
||||
simplified = []
|
||||
with pg_pool.getConn() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
f"SELECT message, created_at FROM ai_chat_history WHERE session_id = '{session_id}' ORDER BY created_at DESC LIMIT {number}")
|
||||
rows = cur.fetchall()
|
||||
simplified = []
|
||||
|
||||
for row in rows:
|
||||
msg_dict = row[0]
|
||||
simplified.append({
|
||||
"type": msg_dict.get("type"),
|
||||
"created_at": row[1].isoformat(),
|
||||
"content": msg_dict.get("data", {}).get("content")
|
||||
})
|
||||
|
||||
return simplified
|
||||
|
||||
|
||||
# ————————————————————————————————————————————————————会话———————————————————————————————
|
||||
def insert_session(user_id: str,ai_id:str, session_id: str,session_title: str):
|
||||
with getConn() as coon:
|
||||
def insert_session(user_id: str, ai_id: str, session_id: str, session_title: str, available_module):
|
||||
with pg_pool.getConn() as coon:
|
||||
with coon.cursor() as cur:
|
||||
cur.execute(
|
||||
"INSERT INTO ai_chat_sessions (id ,user_id, ai_id, title, created_at, updated_at) VALUES (%s, %s, %s, %s, NOW(), NOW())",
|
||||
(session_id, user_id, ai_id, session_title )
|
||||
"INSERT INTO ai_chat_sessions (id ,user_id, ai_id, title, available_module, created_at, updated_at) VALUES (%s, %s, %s, %s,%s, NOW(), NOW())",
|
||||
(session_id, user_id, ai_id, session_title, available_module)
|
||||
)
|
||||
coon.commit()
|
||||
|
||||
|
||||
def update_session_updated_at(session_id: str):
|
||||
with getConn() as conn:
|
||||
with pg_pool.getConn() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"UPDATE ai_chat_sessions SET updated_at = NOW() WHERE id = %s",
|
||||
@@ -97,13 +164,18 @@ def update_session_updated_at(session_id: str):
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
def get_sessions(user_id: str):
|
||||
with getConn() as conn:
|
||||
|
||||
def get_sessions(user_id: str, available_module: str):
|
||||
with pg_pool.getConn() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"SELECT id, title, updated_at FROM ai_chat_sessions WHERE user_id = %s ORDER BY updated_at DESC",
|
||||
(user_id,)
|
||||
"SELECT id, title, updated_at "
|
||||
"FROM ai_chat_sessions "
|
||||
"WHERE user_id = %s AND available_module = %s "
|
||||
"ORDER BY updated_at DESC",
|
||||
(user_id, available_module)
|
||||
)
|
||||
|
||||
sessions = cur.fetchall()
|
||||
return [
|
||||
{
|
||||
@@ -112,4 +184,226 @@ def get_sessions(user_id: str):
|
||||
"updated_at": row[2]
|
||||
}
|
||||
for row in sessions
|
||||
]
|
||||
]
|
||||
|
||||
|
||||
# ————————————————————————————————————————————————————报表———————————————————————————————
|
||||
def get_reports(user_id: str):
|
||||
with pg_pool.getConn() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"SELECT id, title FROM ai_reports WHERE created_by = %s AND is_masked = TRUE ORDER BY created_at DESC",
|
||||
(user_id,)
|
||||
)
|
||||
reports = cur.fetchall()
|
||||
return [
|
||||
{
|
||||
"id": row[0],
|
||||
"title": row[1]
|
||||
}
|
||||
for row in reports
|
||||
]
|
||||
|
||||
|
||||
def save_report(id: str, user_id: str, title: str, sql: str):
|
||||
with pg_pool.getConn() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"INSERT INTO ai_reports (id, title, sql, created_at, created_by , is_masked) VALUES (%s, %s, %s, NOW(), %s, FALSE) RETURNING id",
|
||||
(id, title, sql, user_id)
|
||||
)
|
||||
report_id = cur.fetchone()[0]
|
||||
conn.commit()
|
||||
return report_id
|
||||
|
||||
|
||||
def maked_report(report_id: str, title: str):
|
||||
with pg_pool.getConn() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"UPDATE ai_reports SET title = %s, is_masked = TRUE WHERE id = %s",
|
||||
(title, report_id)
|
||||
)
|
||||
conn.commit()
|
||||
|
||||
|
||||
def getSQL(reportId: str):
|
||||
with pg_pool.getConn() as conn:
|
||||
with conn.cursor() as cur:
|
||||
cur.execute(
|
||||
"SELECT sql FROM ai_reports WHERE id = %s",
|
||||
(reportId,)
|
||||
)
|
||||
row = cur.fetchone()
|
||||
if row:
|
||||
return row[0]
|
||||
else:
|
||||
return ""
|
||||
|
||||
|
||||
def get_available_tables_str(aiId: str):
|
||||
with pg_pool.getConn() as conn:
|
||||
with conn.cursor() as cur:
|
||||
# 1. 先取 AI 可用的数据库表
|
||||
cur.execute(
|
||||
"SELECT available_report_tables FROM ai_chat_profiles WHERE id = %s",
|
||||
(aiId,)
|
||||
)
|
||||
role_row = cur.fetchone()
|
||||
if not role_row:
|
||||
return "无数据库表可用"
|
||||
available_tables = role_row[0] # 假设是列表
|
||||
|
||||
if not available_tables:
|
||||
return "无数据库表可用"
|
||||
|
||||
# 2. 构造 IN 查询占位符
|
||||
placeholders = ','.join(['%s'] * len(available_tables))
|
||||
sql_query = f"""
|
||||
SELECT id, name, description
|
||||
FROM ai_reports_tables
|
||||
WHERE id IN ({placeholders}) AND is_active = TRUE
|
||||
"""
|
||||
cur.execute(sql_query, available_tables)
|
||||
tableIds = cur.fetchall()
|
||||
|
||||
# 3. 查询这些表的字段
|
||||
result = ""
|
||||
for table in tableIds:
|
||||
cur.execute(
|
||||
"SELECT name, type, description FROM ai_reports_fields WHERE table_id = %s AND is_active = TRUE",
|
||||
(table[0],)
|
||||
)
|
||||
columns = cur.fetchall()
|
||||
result += f"{table[1]}:{table[2]}\n"
|
||||
result += "字段名,数据类型,描述\n"
|
||||
for column in columns:
|
||||
result += f"{column[0]},{column[1]}, {column[2]}\n"
|
||||
result += "\n"
|
||||
return result
|
||||
|
||||
|
||||
# -------------------报表数据源------------------
|
||||
# 获取表
|
||||
def get_available_tables():
|
||||
with pg_pool.getConn() as conn:
|
||||
with conn.cursor() as cursor:
|
||||
cursor.execute(
|
||||
"SELECT id, name, description,is_active FROM ai_reports_tables",
|
||||
)
|
||||
return [{"id": row[0], "name": row[1], "description": row[2], "is_active": row[3]} for row in
|
||||
cursor.fetchall()]
|
||||
|
||||
|
||||
# 新增表
|
||||
def add_table(name, description, user_id):
|
||||
with pg_pool.getConn() as conn:
|
||||
with conn.cursor() as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO ai_reports_tables (name, description, create_by)
|
||||
VALUES (%s, %s, %s)
|
||||
RETURNING id
|
||||
""",
|
||||
(name, description, user_id)
|
||||
)
|
||||
new_id = cursor.fetchone()[0] # 取返回的 id
|
||||
return new_id
|
||||
|
||||
|
||||
# 获取字段
|
||||
def get_fields_by_table_id(table_id):
|
||||
with pg_pool.getConn() as conn:
|
||||
with conn.cursor() as cursor:
|
||||
cursor.execute(
|
||||
"SELECT id, name, type, description, is_active FROM ai_reports_fields WHERE table_id = %s",
|
||||
(table_id,),
|
||||
)
|
||||
return [{"id": row[0], "name": row[1], "type": row[2], "description": row[3], "is_active": row[4]} for row
|
||||
in cursor.fetchall()]
|
||||
|
||||
|
||||
# 新增字段
|
||||
def add_field(name, type, description, is_active, table_id, user_id):
|
||||
with pg_pool.getConn() as conn:
|
||||
with conn.cursor() as cursor:
|
||||
cursor.execute(
|
||||
"INSERT INTO ai_reports_fields (name,type,description, is_active, create_by, table_id) VALUES (%s, %s, %s, %s, %s, %s) RETURNING id",
|
||||
(name, type, description, is_active, user_id, table_id)
|
||||
)
|
||||
new_id = cursor.fetchone()[0] # 取返回的 id
|
||||
return new_id
|
||||
|
||||
|
||||
# 新增报表智能体
|
||||
def insert_bot(title: str, description: str, welcome_words: str, ai_personality: str, available_module: str,
|
||||
available_report_tables: str, available_kn_bases: str, user_id: str):
|
||||
with pg_pool.getConn() as conn:
|
||||
with conn.cursor() as cursor:
|
||||
available_roles = json.dumps(['user'])
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO ai_chat_profiles
|
||||
(available_module,available_roles, title, description, welcome_words, ai_personality, available_report_tables, available_kn_bases, created_by, created_at)
|
||||
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, now())
|
||||
RETURNING id
|
||||
""",
|
||||
(available_module, available_roles, title, description, welcome_words, ai_personality,
|
||||
available_report_tables, available_kn_bases, user_id)
|
||||
)
|
||||
report_id = cursor.fetchone()[0]
|
||||
return report_id
|
||||
|
||||
|
||||
# 更新报表智能体
|
||||
def update_bot(id: str, title: str, description: str, welcome_words: str, ai_personality: str, available_module: str,
|
||||
available_report_tables: str, available_kn_bases: str, user_id: str):
|
||||
with pg_pool.getConn() as conn:
|
||||
with conn.cursor() as cursor:
|
||||
cursor.execute("""
|
||||
UPDATE ai_chat_profiles
|
||||
SET title = %s,
|
||||
description = %s,
|
||||
ai_personality = %s,
|
||||
welcome_words = %s,
|
||||
available_report_tables = %s,
|
||||
available_kn_bases = %s,
|
||||
available_module = %s,
|
||||
updated_at = NOW(),
|
||||
updated_by = %s
|
||||
WHERE id = %s
|
||||
""",
|
||||
(title, description, ai_personality, welcome_words, available_report_tables,
|
||||
available_kn_bases, available_module, user_id, id)
|
||||
)
|
||||
|
||||
|
||||
# ————————————————————————————————————————————————————知识库———————————————————————————————
|
||||
def get_available_knowledge_bases(available_module: str):
|
||||
with pg_pool.getConn() as conn:
|
||||
with conn.cursor() as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
SELECT id, name, description, is_active
|
||||
FROM ai_knowledge
|
||||
WHERE available_module::jsonb @> %s::jsonb
|
||||
""",
|
||||
(f'["{available_module}"]',)
|
||||
)
|
||||
|
||||
return [{"id": row[0], "name": row[1], "description": row[2], "is_active": row[3]} for row in
|
||||
cursor.fetchall()]
|
||||
|
||||
|
||||
def add_knowledge_base(name, description, user_id):
|
||||
with pg_pool.getConn() as conn:
|
||||
with conn.cursor() as cursor:
|
||||
cursor.execute(
|
||||
"""
|
||||
INSERT INTO ai_knowledge (name, description, created_by, created_at)
|
||||
VALUES (%s, %s, %s, now())
|
||||
RETURNING id
|
||||
""",
|
||||
(name, description, user_id)
|
||||
)
|
||||
new_id = cursor.fetchone()[0] # 取返回的 id
|
||||
|
||||
@@ -0,0 +1,37 @@
|
||||
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy import text
|
||||
from config.pgDb import pg_pool
|
||||
from config.ssDb import mssql_pool
|
||||
from sqlalchemy import text
|
||||
|
||||
def executeSQL(sql: str):
|
||||
"""
|
||||
执行 SQL 并返回结果列表,每行是 dict
|
||||
"""
|
||||
with mssql_pool.getConn() as conn:
|
||||
result = conn.execute(text(sql))
|
||||
# SQLAlchemy 2.x 返回 Row 对象,转成 dict
|
||||
return [dict(row._mapping) for row in result]
|
||||
|
||||
def get_company_list(user_id: str):
|
||||
# 1️⃣ 从 PostgreSQL 获取 tenant_id
|
||||
with pg_pool.getConn() as pg_conn:
|
||||
with pg_conn.cursor() as cur:
|
||||
cur.execute("SELECT bbit_tenant_id FROM users WHERE id = %s", (user_id,))
|
||||
row = cur.fetchone()
|
||||
tenant_id = row[0] if row else None
|
||||
|
||||
# 2️⃣ 从 SQL Server 查询租户信息
|
||||
if tenant_id:
|
||||
query = text("SELECT Id, Name FROM dbo.POC_TENANTS WHERE Id = :tenant_id")
|
||||
params = {"tenant_id": tenant_id}
|
||||
else:
|
||||
query = text("SELECT Id, Name FROM dbo.POC_TENANTS")
|
||||
params = {}
|
||||
|
||||
with mssql_pool.getConn() as mssql_conn:
|
||||
result = mssql_conn.execute(query, params)
|
||||
return [{"id": str(row[0]), "name": row[1]} for row in result.fetchall()]
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ from langchain.prompts import PromptTemplate
|
||||
chatPrompt = PromptTemplate(
|
||||
input_variables=["aiRole", "history", "userInput"],
|
||||
template = """
|
||||
你是一个人,用户画像为:{aiRole}。
|
||||
你的用户画像为:{aiRole}。
|
||||
你需要基于你的角色性格,使用中文回答用户。
|
||||
|
||||
聊天历史:
|
||||
|
||||
@@ -0,0 +1,82 @@
|
||||
from langchain.prompts import PromptTemplate
|
||||
from config.llm import llm,llmThink
|
||||
import db.milvus as milvus
|
||||
import db.postgres as pg
|
||||
import json
|
||||
|
||||
memPathPrompt = PromptTemplate(
|
||||
input_variables=["ai_role", "CHAT_RECORD"],
|
||||
template = """
|
||||
你是一个记忆筛选器,负责判断最近对话的信息中,用户的回复内容是否对业务具有长期价值或潜在价值,或者可以帮助形成用户画像。
|
||||
首先,请仔细阅读以下关于你业务的描述:
|
||||
<ai_role>
|
||||
{ai_role}
|
||||
</ai_role>
|
||||
现在,请仔细阅读以下你与用户的聊天记录:
|
||||
<聊天记录>
|
||||
{CHAT_RECORD}
|
||||
</聊天记录>
|
||||
|
||||
请仔细考虑以下标准:
|
||||
|
||||
1. 长期价值:用户最新回复信息是否能为你的业务提供知识积累、经验总结、数据支持。
|
||||
|
||||
2. 相关性:用户最新回复是否与业务核心需求、目标、流程或潜在业务场景相关。
|
||||
|
||||
3. 潜在可用性:用户最新回复是否可能在未来的业务场景中被重复使用、参考或触发进一步操作。
|
||||
|
||||
你需要根据以上标准给出判断并得出"yes"或"no"
|
||||
yes:用户最新回复具有直接或潜在长期价值,值得保留。
|
||||
no:用户最新回复价值有限或几乎不会在未来业务中使用。
|
||||
|
||||
回复不要带任何标点符号以及空格、换行符。
|
||||
请给出你的判断结果:
|
||||
"""
|
||||
)
|
||||
memPathChain = memPathPrompt | llmThink
|
||||
memPrompt = PromptTemplate(
|
||||
input_variables=["CHAT_RECORD"],
|
||||
template = """
|
||||
你的任务是对给定的聊天记录进行关键信息的记忆总结。请仔细阅读以下聊天记录,并按照要求进行总结:
|
||||
<聊天记录>
|
||||
{CHAT_RECORD}
|
||||
</聊天记录>
|
||||
在总结时,请遵循以下指南:
|
||||
1. 提取聊天记录中的用户所说的关键信息,包括主要话题、重要观点、达成的共识或决定等。
|
||||
2. 用简洁明了的语言进行总结有价值的信息,避免冗长和复杂的表述。
|
||||
3. 确保总结内容准确反映聊天记录中用户的核心内容,并尽可能简短。
|
||||
4. 总结内容应包含时间,并确保时间是准确的。
|
||||
5. 你需要针对你的业务场景{ai_role},展开对用户最后回复的总结。
|
||||
请生成你的总结,以用户、时间开头:
|
||||
"""
|
||||
)
|
||||
memChain = memPrompt | llmThink
|
||||
def take_memory(ai_id:str,sessionId: str,user_id:str, max_retry=3):
|
||||
"""根据用户输入选择数据来源"""
|
||||
history = pg.get_history_with_time(sessionId,10)
|
||||
print("获取的历史记录:",history)
|
||||
ai_service = pg.get_description(ai_id)
|
||||
if ai_service == "":
|
||||
# AI描述没有描述,则取业务字段
|
||||
json = pg.get_ai_personality(ai_id)
|
||||
if json.get("业务", "") == "":
|
||||
# AI没有任何描述,无法对记忆价值进行判断
|
||||
print("AI没有任何描述,无法对记忆价值进行判断")
|
||||
return
|
||||
else:
|
||||
ai_service = json["业务"]
|
||||
print("获取的描述是:", ai_service)
|
||||
choice = memPathChain.invoke({
|
||||
"ai_role": ai_service,
|
||||
"CHAT_RECORD": history,
|
||||
}).content.strip().lower()
|
||||
print("记忆判断器判断的结果是:", choice)
|
||||
if choice == "yes":
|
||||
# 对对话进行总结
|
||||
memory = memChain.invoke({
|
||||
"CHAT_RECORD": history,
|
||||
"ai_role": ai_service,
|
||||
}).content.strip().lower()
|
||||
print("记忆生成结果是:", memory)
|
||||
milvus.add_memory(mem = memory,user_id = user_id, is_active = True, ai_id = ai_id)
|
||||
return
|
||||
@@ -0,0 +1,69 @@
|
||||
|
||||
from config.llm import llm
|
||||
from langchain.prompts import PromptTemplate
|
||||
from config.ssDb import ssDBLC
|
||||
from langchain_community.agent_toolkits import create_sql_agent
|
||||
|
||||
#______________________________________________________________SQL描述_____________________________________________________________________
|
||||
sqlDescriptionPrompt = PromptTemplate(
|
||||
input_variables=["sql"],
|
||||
template = """
|
||||
你是一个SQL专家,精通SQLServer数据库。请把一下SQL查询语句用通俗易懂的中文进行总结。
|
||||
SQL语句:{sql}
|
||||
有以下要求:
|
||||
1. 不要任何解释
|
||||
2. 不能有标点符号
|
||||
3. 不能有markdown语法
|
||||
4. 要用业务语言描述,不能有专业语句例如SQL表名等
|
||||
请生成你认为合适的标题,:
|
||||
"""
|
||||
)
|
||||
sqlDescriptionChain = sqlDescriptionPrompt | llm
|
||||
|
||||
def get_sql_description_response( sql: str) -> str:
|
||||
return sqlDescriptionChain.invoke({
|
||||
"sql": sql
|
||||
})
|
||||
|
||||
#______________________________________________________________第一次生成SQL_____________________________________________________________________
|
||||
sqlPrompt = PromptTemplate(
|
||||
input_variables=["userInput"],
|
||||
template = """
|
||||
你是一个SQL专家,精通SQLServer数据库。
|
||||
请根据用户的需求,生成相应的SQL查询语句。
|
||||
只需要返回SQL语句,不要任何解释。
|
||||
用户需求:{userInput}
|
||||
请生成SQL语句:
|
||||
"""
|
||||
)
|
||||
sqlChain = sqlPrompt | llm
|
||||
agent = create_sql_agent(
|
||||
llm=llm,
|
||||
db=ssDBLC,
|
||||
agent_type="tool-calling",
|
||||
verbose=True
|
||||
)
|
||||
# def get_chat_sql_response2( userInput: str) -> str:
|
||||
# return sqlChain.invoke({
|
||||
# "userInput": userInput
|
||||
# })
|
||||
def get_chat_sql_response( userInput: str) -> str:
|
||||
return agent.invoke({"input": userInput})["output"]
|
||||
|
||||
#______________________________________________________________改进SQL_____________________________________________________________________
|
||||
sqlImprovePrompt = PromptTemplate(
|
||||
input_variables=["userInput", "sql"],
|
||||
template = """
|
||||
你是一个SQL专家,精通SQLServer数据库。
|
||||
请根据用户的需求,改进已有的SQL查询语句。
|
||||
只需要返回改进后的SQL语句,不要任何解释。
|
||||
已有SQL:{sql}
|
||||
用户需求:{userInput}
|
||||
"""
|
||||
)
|
||||
sqlImproveChain = sqlImprovePrompt | llm
|
||||
|
||||
def get_chat_sql_improve_response( userInput: str) -> str:
|
||||
return sqlImproveChain.invoke({
|
||||
"userInput": userInput
|
||||
})
|
||||
@@ -10,6 +10,7 @@ titlePrompt = PromptTemplate(
|
||||
2. 直接概括本次对话的核心内容。
|
||||
3. 避免使用笼统或无意义的词语,如“讨论”、“聊天”等。
|
||||
4. 保持自然、易懂、专业或有趣(可根据场景调整风格)。
|
||||
5. 不能出现标点符号。
|
||||
用户原话:"{userStr}"
|
||||
"""
|
||||
)
|
||||
|
||||
@@ -0,0 +1,13 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
class AIProfilesRequest(BaseModel):
|
||||
id: str | None = None
|
||||
name: str
|
||||
available_kn_bases:list[str]
|
||||
available_report_tables:list[str]
|
||||
description: str
|
||||
role: str
|
||||
service: str
|
||||
welcome_words: str
|
||||
title: str
|
||||
available_module: str
|
||||
@@ -1,11 +1,12 @@
|
||||
from fastapi import FastAPI
|
||||
from pydantic import BaseModel
|
||||
from typing import Generic, TypeVar, Optional, List
|
||||
from pydantic.generics import GenericModel
|
||||
from pydantic import BaseModel
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
# 定义通用响应结构
|
||||
class BaseResponse(GenericModel, Generic[T]):
|
||||
class BaseResponse(BaseModel, Generic[T]):
|
||||
status: bool = True
|
||||
message: str = "操作成功"
|
||||
data: Optional[T] = None
|
||||
|
||||
@@ -0,0 +1,7 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
class ChatWithReportRequest(BaseModel):
|
||||
aiId: str
|
||||
companyId: str
|
||||
reportId: str | None = None
|
||||
userInput: str
|
||||
@@ -0,0 +1,6 @@
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
class KnowledgeAddRequest(BaseModel):
|
||||
text: str
|
||||
is_active: Optional[bool] = True
|
||||
knowledge_base_id: str
|
||||
@@ -0,0 +1,6 @@
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
|
||||
class KnowledgeBaseAddRequest(BaseModel):
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
@@ -0,0 +1,8 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
class ReportFieldAddRequest(BaseModel):
|
||||
name: str
|
||||
type: str
|
||||
description: str
|
||||
is_active: bool
|
||||
table_id: str
|
||||
@@ -0,0 +1,5 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
class ReportTableAddRequest(BaseModel):
|
||||
name: str
|
||||
description: str | None = None
|
||||
@@ -0,0 +1,4 @@
|
||||
from pydantic import BaseModel
|
||||
|
||||
class SaveReportRequest(BaseModel):
|
||||
reportId: str | None = None
|
||||
@@ -0,0 +1,76 @@
|
||||
|
||||
from models.ChatRequest import ChatRequest
|
||||
from models.ChatWithReportRequest import ChatWithReportRequest
|
||||
from models.BaseResponse import BaseResponse
|
||||
import uuid
|
||||
import db.postgres as pg
|
||||
import db.sqlserver as sqlserver
|
||||
import uuid
|
||||
import threading
|
||||
from fastapi import APIRouter, Depends
|
||||
from uuid import UUID
|
||||
from config.security import get_user_id_from_token
|
||||
botRouter = APIRouter()
|
||||
from llm.chatLLM import get_chat_response
|
||||
from llm.titleChain import get_title
|
||||
from llm.sqlLLM import get_sql_description_response,get_chat_sql_response,get_chat_sql_improve_response
|
||||
from models.SaveReportRequest import SaveReportRequest
|
||||
from models.AIProfilesRequest import AIProfilesRequest
|
||||
import json
|
||||
|
||||
@botRouter.get("/aiListForService")
|
||||
def getAiList(user_id: UUID = Depends(get_user_id_from_token)):
|
||||
if not user_id:
|
||||
return {"error": "userId is required"}
|
||||
return BaseResponse(data=pg.get_all_ai_bot(user_id,"service"))
|
||||
@botRouter.get("/aiListForReport")
|
||||
def getAiList(user_id: UUID = Depends(get_user_id_from_token)):
|
||||
if not user_id:
|
||||
return {"error": "userId is required"}
|
||||
return BaseResponse(data=pg.get_all_ai_bot(user_id,"report"))
|
||||
|
||||
@botRouter.get("/aiListForBot")
|
||||
def getAiList(user_id: UUID = Depends(get_user_id_from_token)):
|
||||
if not user_id:
|
||||
return {"error": "userId is required"}
|
||||
return BaseResponse(data=pg.get_all_ai_bot(user_id,"bot"))
|
||||
|
||||
# 保存智能体
|
||||
@botRouter.post("/saveBot")
|
||||
def saveReportBot(bot: AIProfilesRequest,user_id: UUID = Depends(get_user_id_from_token)):
|
||||
if not user_id:
|
||||
return {"error": "userId is required"}
|
||||
print(bot)
|
||||
ai_personality = {
|
||||
"名字":bot.name,
|
||||
"性格":bot.role,
|
||||
"业务":bot.service,
|
||||
}
|
||||
ai_personality_json = json.dumps(ai_personality, ensure_ascii=False)
|
||||
available_report_tables_json = json.dumps(bot.available_report_tables, ensure_ascii=False)
|
||||
available_kn_bases_json = json.dumps(bot.available_kn_bases, ensure_ascii=False)
|
||||
|
||||
if bot.id:
|
||||
pg.update_bot(
|
||||
id = bot.id,
|
||||
title =bot.title,
|
||||
description = bot.description,
|
||||
welcome_words = bot.welcome_words,
|
||||
ai_personality = ai_personality_json,
|
||||
available_kn_bases = available_kn_bases_json,
|
||||
available_report_tables = available_report_tables_json,
|
||||
available_module = bot.available_module,
|
||||
user_id = user_id
|
||||
)
|
||||
else:
|
||||
pg.insert_bot(
|
||||
title =bot.title,
|
||||
description = bot.description,
|
||||
welcome_words = bot.welcome_words,
|
||||
ai_personality = ai_personality_json,
|
||||
available_kn_bases = available_kn_bases_json,
|
||||
available_module = bot.available_module,
|
||||
available_report_tables = available_report_tables_json,
|
||||
user_id = user_id
|
||||
)
|
||||
return BaseResponse(data= None)
|
||||
+27
-29
@@ -1,63 +1,61 @@
|
||||
from models.ChatRequest import ChatRequest
|
||||
from models.ChatWithReportRequest import ChatWithReportRequest
|
||||
from models.BaseResponse import BaseResponse
|
||||
import uuid
|
||||
import db.postgres as db
|
||||
import db.postgres as pg
|
||||
import db.sqlserver as sqlserver
|
||||
import uuid
|
||||
import threading
|
||||
from fastapi import APIRouter, Depends
|
||||
from uuid import UUID
|
||||
from config.security import get_user_id_from_token
|
||||
router = APIRouter()
|
||||
chatRouter = APIRouter()
|
||||
from agent.dataAgent import get_graph_output
|
||||
from llm.chatLLM import get_chat_response
|
||||
from llm.titleChain import get_title
|
||||
from llm.dataLLM import get_graph_output
|
||||
|
||||
# def async_db_task(func, *args, **kwargs):
|
||||
# """将数据库操作放到后台线程执行"""
|
||||
# threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True).start()
|
||||
|
||||
@router.get("/history")
|
||||
from llm.sqlLLM import get_sql_description_response,get_chat_sql_response,get_chat_sql_improve_response
|
||||
from models.SaveReportRequest import SaveReportRequest
|
||||
from models.AIProfilesRequest import AIProfilesRequest
|
||||
import json
|
||||
# 对话历史记录
|
||||
@chatRouter.get("/history")
|
||||
def getHistory(sessionId: str):
|
||||
return BaseResponse(data=db.get_history(sessionId))
|
||||
return BaseResponse(data=pg.get_history(sessionId))
|
||||
|
||||
@router.get("/aiList")
|
||||
def getAiList(user_id: UUID = Depends(get_user_id_from_token)):
|
||||
if not user_id:
|
||||
return {"error": "userId is required"}
|
||||
return BaseResponse(data=db.get_all_ai(user_id))
|
||||
|
||||
@router.get("/sessions")
|
||||
# 对话列表
|
||||
@chatRouter.get("/sessionsForBot")
|
||||
def getSessions(user_id: UUID = Depends(get_user_id_from_token)):
|
||||
if not user_id:
|
||||
return {"error": "userId is required"}
|
||||
return BaseResponse(data=db.get_sessions(user_id))
|
||||
return BaseResponse(data=pg.get_sessions(user_id,'bot'))
|
||||
|
||||
@router.post("/chat")
|
||||
@chatRouter.post("/chatForBot")
|
||||
def chat(req: ChatRequest, user_id: UUID = Depends(get_user_id_from_token)):
|
||||
if not user_id:
|
||||
return {"error": "userId is required"}
|
||||
if not req.aiId:
|
||||
return {"error": "aiId is required"}
|
||||
sessionName = get_title(req.userInput)
|
||||
session_name = get_title(req.userInput)
|
||||
# 如果没有 sessionId 就新建
|
||||
if not req.sessionId:
|
||||
isNewSession = True
|
||||
is_new_session = True
|
||||
req.sessionId = str(uuid.uuid4())
|
||||
db.insert_session(user_id,req.aiId, req.sessionId, sessionName)
|
||||
pg.insert_session(user_id,req.aiId, req.sessionId, session_name,"bot")
|
||||
else:
|
||||
isNewSession = False
|
||||
db.update_session_updated_at(req.sessionId)
|
||||
is_new_session = False
|
||||
pg.update_session_updated_at(req.sessionId)
|
||||
|
||||
# 插入用户消息
|
||||
db.insert_message(req.sessionId, False, req.userInput)
|
||||
pg.insert_message(req.sessionId, False, req.userInput)
|
||||
|
||||
# 调用 LLM
|
||||
if req.aiId == "9d157dd1-921b-c768-5b90-3e903b50f6f9":
|
||||
# 数据专家AI
|
||||
answer = get_graph_output(aiRole=db.get_ai_personality(req.aiId),history=db.get_history(req.sessionId), userInput= req.userInput)
|
||||
answer = get_graph_output(aiRole=pg.get_ai_personality(req.aiId),history=pg.get_history(req.sessionId), userInput= req.userInput)
|
||||
else:
|
||||
answer = get_chat_response(aiRole=db.get_ai_personality(req.aiId),history=db.get_history(req.sessionId), userInput= req.userInput).content
|
||||
answer = get_chat_response(aiRole=pg.get_ai_personality(req.aiId),history=pg.get_history(req.sessionId), userInput= req.userInput).content
|
||||
# 插入 AI 回复
|
||||
db.insert_message(req.sessionId, True, answer)
|
||||
pg.insert_message(req.sessionId, True, answer)
|
||||
|
||||
return BaseResponse(data={"session_name":session_name,"isNewSession":is_new_session,"content":answer,"sessionId": req.sessionId})
|
||||
|
||||
return BaseResponse(data={"sessionName":sessionName,"isNewSession":isNewSession,"content":answer,"sessionId": req.sessionId})
|
||||
@@ -0,0 +1,54 @@
|
||||
from models.ChatRequest import ChatRequest
|
||||
from models.ChatWithReportRequest import ChatWithReportRequest
|
||||
from models.BaseResponse import BaseResponse
|
||||
import uuid
|
||||
import db.postgres as pg
|
||||
import db.sqlserver as sqlserver
|
||||
import uuid
|
||||
import threading
|
||||
from fastapi import APIRouter, Depends
|
||||
from uuid import UUID
|
||||
from config.security import get_user_id_from_token
|
||||
|
||||
reportDataRouter = APIRouter()
|
||||
from llm.chatLLM import get_chat_response
|
||||
from llm.titleChain import get_title
|
||||
from llm.sqlLLM import get_sql_description_response, get_chat_sql_response, get_chat_sql_improve_response
|
||||
from models.SaveReportRequest import SaveReportRequest
|
||||
from models.ReportTableAddRequest import ReportTableAddRequest
|
||||
from models.ReportFieldAddRequest import ReportFieldAddRequest
|
||||
|
||||
|
||||
# 获取表格列表
|
||||
@reportDataRouter.get("/tableList")
|
||||
def tableList(user_id: UUID = Depends(get_user_id_from_token)):
|
||||
if not user_id:
|
||||
return {"error": "userId is required"}
|
||||
return BaseResponse(data=pg.get_available_tables())
|
||||
|
||||
|
||||
# 获取字段列表
|
||||
@reportDataRouter.get("/fieldList")
|
||||
def fieldList(tableId: str, user_id: UUID = Depends(get_user_id_from_token)):
|
||||
if not user_id:
|
||||
return {"error": "userId is required"}
|
||||
if not tableId:
|
||||
return {"error": "tableId is required"}
|
||||
return BaseResponse(data=pg.get_fields_by_table_id(tableId))
|
||||
|
||||
|
||||
# 新增表
|
||||
@reportDataRouter.post("/addTable")
|
||||
def addTable(data: ReportTableAddRequest, user_id: UUID = Depends(get_user_id_from_token)):
|
||||
if not user_id:
|
||||
return {"error": "userId is required"}
|
||||
return BaseResponse(data=pg.add_table(data.name, data.description, user_id))
|
||||
|
||||
|
||||
# 新增字段
|
||||
@reportDataRouter.post("/addField")
|
||||
def addField(data: ReportFieldAddRequest, user_id: UUID = Depends(get_user_id_from_token)):
|
||||
if not user_id:
|
||||
return {"error": "userId is required"}
|
||||
return BaseResponse(
|
||||
data=pg.add_field(data.name, data.type, data.description, data.is_active, data.table_id, user_id))
|
||||
@@ -0,0 +1,48 @@
|
||||
from models.ChatRequest import ChatRequest
|
||||
from models.ChatWithReportRequest import ChatWithReportRequest
|
||||
from models.BaseResponse import BaseResponse
|
||||
import uuid
|
||||
import db.postgres as pg
|
||||
import db.sqlserver as sqlserver
|
||||
import db.milvus as milvus
|
||||
import uuid
|
||||
import threading
|
||||
from fastapi import APIRouter, Depends
|
||||
from uuid import UUID
|
||||
from config.security import get_user_id_from_token
|
||||
knowledgeRouter = APIRouter()
|
||||
from llm.chatLLM import get_chat_response
|
||||
from llm.titleChain import get_title
|
||||
from llm.sqlLLM import get_sql_description_response,get_chat_sql_response,get_chat_sql_improve_response
|
||||
from models.SaveReportRequest import SaveReportRequest
|
||||
from models.KnowledgeAddRequest import KnowledgeAddRequest
|
||||
from models.KnowledgeBaseAddRequest import KnowledgeBaseAddRequest
|
||||
# 获取知识库列表
|
||||
@knowledgeRouter.get("/knowledgeBaseListForService")
|
||||
def knowledge_base_list(user_id: UUID = Depends(get_user_id_from_token)):
|
||||
if not user_id:
|
||||
return {"error": "userId is required"}
|
||||
return BaseResponse(data=pg.get_available_knowledge_bases('service'))
|
||||
|
||||
# 新增知识库
|
||||
@knowledgeRouter.post("/addKnowledgeBase")
|
||||
def add_knowledge_base(data: KnowledgeBaseAddRequest, user_id: UUID = Depends(get_user_id_from_token)):
|
||||
if not user_id:
|
||||
return {"error": "userId is required"}
|
||||
return BaseResponse(data=pg.add_knowledge_base(data.name, data.description, user_id))
|
||||
|
||||
# 获取知识列表
|
||||
@knowledgeRouter.get("/knowledgeList")
|
||||
def knowledge_list(knowledgeBaseId: str, user_id: UUID = Depends(get_user_id_from_token)):
|
||||
if not user_id:
|
||||
return {"error": "userId is required"}
|
||||
if not knowledgeBaseId:
|
||||
return {"error": "knowledgeBaseId is required"}
|
||||
return BaseResponse(data=milvus.get_knowledge_by_base_id(knowledgeBaseId))
|
||||
|
||||
# 新增知识
|
||||
@knowledgeRouter.post("/addKnowledge")
|
||||
def add_knowledge(data: KnowledgeAddRequest, user_id: UUID = Depends(get_user_id_from_token)):
|
||||
if not user_id:
|
||||
return {"error": "userId is required"}
|
||||
return BaseResponse(data=milvus.add_knowledge(data.text, data.is_active, data.knowledge_base_id, user_id))
|
||||
@@ -0,0 +1,64 @@
|
||||
from models.ChatRequest import ChatRequest
|
||||
from models.ChatWithReportRequest import ChatWithReportRequest
|
||||
from models.BaseResponse import BaseResponse
|
||||
import uuid
|
||||
import db.postgres as pg
|
||||
import db.sqlserver as sqlserver
|
||||
import uuid
|
||||
import threading
|
||||
from fastapi import APIRouter, Depends
|
||||
from uuid import UUID
|
||||
from config.security import get_user_id_from_token
|
||||
reportRouter = APIRouter()
|
||||
from llm.chatLLM import get_chat_response
|
||||
from llm.titleChain import get_title
|
||||
from llm.sqlLLM import get_sql_description_response,get_chat_sql_response,get_chat_sql_improve_response
|
||||
from models.SaveReportRequest import SaveReportRequest
|
||||
from agent.dbAgent import get_db_agent_reply
|
||||
|
||||
# 报表列表
|
||||
@reportRouter.get("/reports")
|
||||
def getReports(user_id: UUID = Depends(get_user_id_from_token)):
|
||||
if not user_id:
|
||||
return {"error": "userId is required"}
|
||||
return BaseResponse(data=pg.get_reports(user_id))
|
||||
|
||||
# 报表详情
|
||||
@reportRouter.get("/report")
|
||||
def getReports(reportId:str):
|
||||
return BaseResponse(data={
|
||||
"sql":pg.getSQL(reportId),
|
||||
"tableData":sqlserver.executeSQL(pg.getSQL(reportId)),
|
||||
})
|
||||
# 保存报表
|
||||
@reportRouter.post("/saveReport")
|
||||
def saveReport(report:SaveReportRequest,user_id: UUID = Depends(get_user_id_from_token)):
|
||||
sql = pg.getSQL(report.reportId)
|
||||
# 生成描述
|
||||
title = get_sql_description_response(sql = sql)
|
||||
res = pg.maked_report(report_id=report.reportId,title=title.content)
|
||||
return BaseResponse(data=res)
|
||||
|
||||
# 使用langgraph
|
||||
@reportRouter.post("/chatWithReport")
|
||||
def chat(req: ChatWithReportRequest, user_id: UUID = Depends(get_user_id_from_token)):
|
||||
# 获取reportId
|
||||
if not req.reportId:
|
||||
# 新报表
|
||||
sql = ""
|
||||
else:
|
||||
# 基于之前的报表
|
||||
sql = pg.getSQL(req.reportId)
|
||||
result = get_db_agent_reply(aiId=req.aiId, userInput = req.userInput,tenant_id=req.companyId, sql = sql)
|
||||
sqlRes = result.get("sql", "")
|
||||
newReportId = str(uuid.uuid4())
|
||||
pg.save_report(id = newReportId, user_id = user_id, title="尚未收藏", sql=sqlRes)
|
||||
|
||||
if sqlRes.strip() != "":
|
||||
tableData = sqlserver.executeSQL(sqlRes)
|
||||
else:
|
||||
tableData = []
|
||||
return BaseResponse(data={"content":result["reply"],"sql":sqlRes,"reportId": newReportId, "tableData": tableData})
|
||||
@reportRouter.get("/companyList")
|
||||
def companyList(user_id: UUID = Depends(get_user_id_from_token)):
|
||||
return BaseResponse(data=sqlserver.get_company_list(user_id))
|
||||
@@ -0,0 +1,47 @@
|
||||
from models.ChatRequest import ChatRequest
|
||||
from models.BaseResponse import BaseResponse
|
||||
import uuid
|
||||
import db.postgres as pg
|
||||
import uuid
|
||||
from fastapi import APIRouter, Depends
|
||||
from uuid import UUID
|
||||
from config.security import get_user_id_from_token
|
||||
serviceRouter = APIRouter()
|
||||
from llm.titleChain import get_title
|
||||
from agent.serviceAgent import get_service_agent_reply
|
||||
from llm.memLLM import take_memory
|
||||
import utils.MyUtils as utils
|
||||
# 对话列表
|
||||
@serviceRouter.get("/sessionsForService")
|
||||
def getSessions(user_id: UUID = Depends(get_user_id_from_token)):
|
||||
if not user_id:
|
||||
return {"error": "userId is required"}
|
||||
return BaseResponse(data=pg.get_sessions(user_id,'service'))
|
||||
|
||||
# 对话
|
||||
@serviceRouter.post("/chatForService")
|
||||
def chat(req: ChatRequest, user_id: UUID = Depends(get_user_id_from_token)):
|
||||
if not user_id:
|
||||
return {"error": "userId is required"}
|
||||
if not req.aiId:
|
||||
return {"error": "aiId is required"}
|
||||
sessionName = get_title(req.userInput)
|
||||
# 如果没有 sessionId 就新建
|
||||
if not req.sessionId:
|
||||
isNewSession = True
|
||||
req.sessionId = str(uuid.uuid4())
|
||||
pg.insert_session(user_id,req.aiId, req.sessionId, sessionName, "service")
|
||||
else:
|
||||
isNewSession = False
|
||||
pg.update_session_updated_at(req.sessionId)
|
||||
|
||||
# 插入用户消息
|
||||
pg.insert_message(req.sessionId, False, req.userInput)
|
||||
|
||||
answer = get_service_agent_reply(aiId=req.aiId,history=pg.get_history_with_time(req.sessionId,6), userInput= req.userInput,kn_bases=pg.get_ai_available_kn_bases(req.aiId))
|
||||
# 插入 AI 回复
|
||||
pg.insert_message(req.sessionId, True, answer)
|
||||
|
||||
# 异步执行:记忆判断
|
||||
utils.async_db_task(take_memory,req.aiId,req.sessionId,user_id,)
|
||||
return BaseResponse(data={"sessionName":sessionName,"isNewSession":isNewSession,"content":answer,"sessionId": req.sessionId})
|
||||
@@ -0,0 +1,6 @@
|
||||
import threading
|
||||
|
||||
# 后台操作
|
||||
def async_db_task(func, *args, **kwargs):
|
||||
"""将数据库操作放到后台线程执行"""
|
||||
threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True).start()
|
||||
@@ -1,10 +1,13 @@
|
||||
fastapi==0.116.1
|
||||
langchain==0.3.27
|
||||
langchain_community==0.3.29
|
||||
langchain_milvus==0.2.1
|
||||
langchain_core==0.3.75
|
||||
langchain_postgres==0.0.15
|
||||
langchain_tavily==0.2.11
|
||||
langgraph==0.6.6
|
||||
langchain_openai===0.3.32
|
||||
langchain-milvus===0.2.1
|
||||
psycopg==3.2.9
|
||||
psycopg_pool==3.2.6
|
||||
pydantic==2.11.7
|
||||
|
||||
@@ -0,0 +1,18 @@
|
||||
import dspy
|
||||
lm = dspy.LM("openai/deepseek-chat", api_key="sk-6129a200ae294b9f86553505191fa477", api_base="https://api.deepseek.com")
|
||||
dspy.configure(lm=lm)
|
||||
|
||||
# print(lm("Say this is a test!", temperature=0.7)) # => ['This is a test!']
|
||||
# print(lm(messages=[{"role": "user", "content": "Say this is a test!"}])) # => ['This is a test!']
|
||||
|
||||
from typing import Literal
|
||||
|
||||
class Classify(dspy.Signature):
|
||||
"""Classify sentiment of a given sentence."""
|
||||
|
||||
sentence: str = dspy.InputField()
|
||||
sentiment: Literal["positive", "negative", "neutral"] = dspy.OutputField()
|
||||
confidence: float = dspy.OutputField()
|
||||
|
||||
classify = dspy.Predict(Classify)
|
||||
print(classify(sentence="This book was super fun to read, though not the last chapter."))
|
||||
@@ -0,0 +1,126 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "d029ad67",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain_milvus import BM25BuiltInFunction, Milvus\n",
|
||||
"from typing import List\n",
|
||||
"URI = \"http://10.10.10.9:19530\"\n",
|
||||
"tongyiKey = \"sk-9464b2498c184982a9fe9d2c2e725ab5\"\n",
|
||||
"from langchain_community.embeddings import DashScopeEmbeddings\n",
|
||||
"embeddings = DashScopeEmbeddings(\n",
|
||||
" model=\"text-embedding-v3\",\n",
|
||||
" dashscope_api_key= tongyiKey, \n",
|
||||
")\n",
|
||||
"memVectorstore = Milvus(\n",
|
||||
" embedding_function=embeddings,\n",
|
||||
" connection_args={\"uri\": URI, \"token\": \"root:Milvus\", \"db_name\": \"bbit_ai_lab\"},\n",
|
||||
" collection_name=\"memory\",\n",
|
||||
" index_params={\"index_type\": \"FLAT\", \"metric_type\": \"L2\"},\n",
|
||||
" consistency_level=\"Strong\",\n",
|
||||
" auto_id=True,\n",
|
||||
"\n",
|
||||
" primary_field = \"id\",\n",
|
||||
" text_field=\"text\",\n",
|
||||
" vector_field=\"vector\",\n",
|
||||
" partition_key_field = \"ai_id\",\n",
|
||||
" enable_dynamic_field = True,\n",
|
||||
" drop_old=False, # set to True if seeking to drop the collection with that name if it exists\n",
|
||||
")\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "a480053b",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"def get_memory_by_key_words(key_words: str, ai_ids: List[str]) -> str:\n",
|
||||
" print(\"ai_id是:\" , ai_ids)\n",
|
||||
" \"\"\"\n",
|
||||
" 根据关键词和 ai_ids 列表,在知识库中检索相关内容,并返回整理后的文本字符串\n",
|
||||
" \"\"\"\n",
|
||||
" # 构建过滤表达式:只查 kn_ids 范围内的\n",
|
||||
" if ai_ids:\n",
|
||||
" ids_expr = \" or \".join([f'ai_id == \"{kid}\"' for kid in ai_ids])\n",
|
||||
" expr = f\"({ids_expr})\"\n",
|
||||
" else:\n",
|
||||
" expr = \"\" # 不限制 kn_id todo 实际上应该不反悔任何内容\n",
|
||||
" \n",
|
||||
" result = knVectorstore.similarity_search(\n",
|
||||
" query=key_words,\n",
|
||||
" k=5, # 可调节返回条数\n",
|
||||
" expr=expr\n",
|
||||
" )\n",
|
||||
" \n",
|
||||
" # 整理成字符串\n",
|
||||
" doc_texts = []\n",
|
||||
" for idx, doc in enumerate(result, start=1):\n",
|
||||
" text = doc.page_content.strip()\n",
|
||||
" if text:\n",
|
||||
" # 可以加个编号,便于LLM区分\n",
|
||||
" doc_texts.append(f\"[记忆{idx}]: {text}\")\n",
|
||||
" \n",
|
||||
" # 拼成一个大字符串,用换行隔开\n",
|
||||
" combined_text = \"\\n\\n\".join(doc_texts)\n",
|
||||
" return combined_text"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "36759de5",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"ai_id是: ['3730f279-8b56-46ec-bde9-8a9e6c27f021']\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"ename": "NameError",
|
||||
"evalue": "name 'knVectorstore' is not defined",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||||
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
|
||||
"Cell \u001b[0;32mIn[3], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mget_memory_by_key_words\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43m共育室 部署 地方\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43m3730f279-8b56-46ec-bde9-8a9e6c27f021\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n",
|
||||
"Cell \u001b[0;32mIn[2], line 13\u001b[0m, in \u001b[0;36mget_memory_by_key_words\u001b[0;34m(key_words, ai_ids)\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 11\u001b[0m expr \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;66;03m# 不限制 kn_id todo 实际上应该不反悔任何内容\u001b[39;00m\n\u001b[0;32m---> 13\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[43mknVectorstore\u001b[49m\u001b[38;5;241m.\u001b[39msimilarity_search(\n\u001b[1;32m 14\u001b[0m query\u001b[38;5;241m=\u001b[39mkey_words,\n\u001b[1;32m 15\u001b[0m k\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m5\u001b[39m, \u001b[38;5;66;03m# 可调节返回条数\u001b[39;00m\n\u001b[1;32m 16\u001b[0m expr\u001b[38;5;241m=\u001b[39mexpr\n\u001b[1;32m 17\u001b[0m )\n\u001b[1;32m 19\u001b[0m \u001b[38;5;66;03m# 整理成字符串\u001b[39;00m\n\u001b[1;32m 20\u001b[0m doc_texts \u001b[38;5;241m=\u001b[39m []\n",
|
||||
"\u001b[0;31mNameError\u001b[0m: name 'knVectorstore' is not defined"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"get_memory_by_key_words(\"共育室 部署 地方\",[\"3730f279-8b56-46ec-bde9-8a9e6c27f021\"])"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "lang",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.18"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@@ -0,0 +1,114 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 21,
|
||||
"id": "d029ad67",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"[460823023525530114, 460823023525530115]"
|
||||
]
|
||||
},
|
||||
"execution_count": 21,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from langchain_milvus import BM25BuiltInFunction, Milvus\n",
|
||||
"URI = \"http://10.10.10.9:19530\"\n",
|
||||
"tongyiKey = \"sk-9464b2498c184982a9fe9d2c2e725ab5\"\n",
|
||||
"from langchain_community.embeddings import DashScopeEmbeddings\n",
|
||||
"embeddings = DashScopeEmbeddings(\n",
|
||||
" model=\"text-embedding-v3\",\n",
|
||||
" dashscope_api_key= tongyiKey, \n",
|
||||
")\n",
|
||||
"vectorstore = Milvus(\n",
|
||||
" embedding_function=embeddings,\n",
|
||||
" connection_args={\"uri\": URI, \"token\": \"root:Milvus\", \"db_name\": \"bbit_ai_lab\"},\n",
|
||||
" collection_name=\"knowledge\",\n",
|
||||
" index_params={\"index_type\": \"FLAT\", \"metric_type\": \"L2\"},\n",
|
||||
" consistency_level=\"Strong\",\n",
|
||||
" auto_id=True,\n",
|
||||
"\n",
|
||||
" primary_field = \"id\",\n",
|
||||
" text_field=\"text\",\n",
|
||||
" vector_field=\"vector\",\n",
|
||||
" partition_key_field = \"kn_id\",\n",
|
||||
" enable_dynamic_field = True,\n",
|
||||
" drop_old=False, # set to True if seeking to drop the collection with that name if it exists\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"from langchain.schema import Document\n",
|
||||
"\n",
|
||||
"docs = [\n",
|
||||
" Document(\n",
|
||||
" page_content=\"这是第一条文本\",\n",
|
||||
" metadata={\n",
|
||||
" \"kn_id\": \"8ecd1179-4194-4b80-bc39-5addc678df4b\",\n",
|
||||
" \"is_active\": True,\n",
|
||||
" }\n",
|
||||
" ),\n",
|
||||
" Document(\n",
|
||||
" page_content=\"这是第二条文本\",\n",
|
||||
" metadata={\n",
|
||||
" \"kn_id\": \"8ecd1179-4194-4b80-bc39-5addc678df4b\",\n",
|
||||
" \"is_active\": True,\n",
|
||||
" }\n",
|
||||
" )\n",
|
||||
"]\n",
|
||||
"\n",
|
||||
"vectorstore.add_documents(docs)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "a480053b",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"*这是第一条文本 [{'kn_id': '8ecd1179-4194-4b80-bc39-5addc678df4b', 'id': 460823023525530108, 'is_active': True}]\n",
|
||||
"*这是第一条文本 [{'kn_id': '8ecd1179-4194-4b80-bc39-5addc678df4b', 'id': 460823023525530110, 'is_active': True}]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"results = vectorstore.similarity_search(\n",
|
||||
" \"\",\n",
|
||||
" k=2,\n",
|
||||
" expr='kn_id == \"8ecd1179-4194-4b80-bc39-5addc678df4b\"',\n",
|
||||
")\n",
|
||||
"for res in results:\n",
|
||||
" print(f\"*{res.page_content} [{res.metadata}]\")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "lang",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.18"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
@@ -0,0 +1,260 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 11,
|
||||
"id": "dfb008fd",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"from openai import OpenAI\n",
|
||||
"from glob import glob\n",
|
||||
"from pymilvus import MilvusClient\n",
|
||||
"from tqdm import tqdm"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "eaa97ad1",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"client = OpenAI(\n",
|
||||
" api_key= \"sk-9464b2498c184982a9fe9d2c2e725ab5\", # 如果您没有配置环境变量,请在此处用您的API Key进行替换\n",
|
||||
" base_url=\"https://dashscope.aliyuncs.com/compatible-mode/v1\" # 百炼服务的base_url\n",
|
||||
")\n",
|
||||
"def emb_text(text):\n",
|
||||
" return client.embeddings.create(\n",
|
||||
" model=\"text-embedding-v4\",\n",
|
||||
" input=text,\n",
|
||||
" dimensions=1024, # 指定向量维度(仅 text-embedding-v3及 text-embedding-v4支持该参数)\n",
|
||||
" encoding_format=\"float\"\n",
|
||||
" ).data[0].embedding"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "9df315ea",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"1024\n",
|
||||
"[-0.017507297918200493, 0.02571254037320614, 0.02589302882552147, -0.02639283984899521, -0.013571279123425484, -0.0032158030662685633, -0.006428135093301535, 0.02458796463906765, -0.059366535395383835, 0.13083963096141815]\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# 测试\n",
|
||||
"test_embedding = emb_text(\"This is a test\")\n",
|
||||
"embedding_dim = len(test_embedding)\n",
|
||||
"print(embedding_dim)\n",
|
||||
"print(test_embedding[:10])\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "95d0a121",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Milvus数据库配置\n",
|
||||
"milvus_client = MilvusClient(uri=\"http://10.10.10.9:19530\")\n",
|
||||
"collection_name = \"my_rag_collection\"\n",
|
||||
"embedding_dim = 1024\n",
|
||||
"\n",
|
||||
"if milvus_client.has_collection(collection_name):\n",
|
||||
" milvus_client.drop_collection(collection_name)\n",
|
||||
"milvus_client.create_collection(\n",
|
||||
" collection_name=collection_name,\n",
|
||||
" dimension=embedding_dim,\n",
|
||||
" metric_type=\"IP\", # Inner product distance\n",
|
||||
" consistency_level=\"Bounded\", # Supported values are (`\"Strong\"`, `\"Session\"`, `\"Bounded\"`, `\"Eventually\"`). See https://milvus.io/docs/consistency.md#Consistency-Level for more details.\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "e09edfec",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Creating embeddings: 100%|██████████| 72/72 [00:11<00:00, 6.46it/s]\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"{'insert_count': 72, 'ids': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71], 'cost': 0}"
|
||||
]
|
||||
},
|
||||
"execution_count": 13,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"# 从文件中插入数据\n",
|
||||
"text_lines = []\n",
|
||||
"for file_path in glob(\"milvus_docs/en/faq/*.md\", recursive=True):\n",
|
||||
" with open(file_path, \"r\") as file:\n",
|
||||
" file_text = file.read()\n",
|
||||
"\n",
|
||||
" text_lines += file_text.split(\"# \")\n",
|
||||
"\n",
|
||||
"data = []\n",
|
||||
"\n",
|
||||
"for i, line in enumerate(tqdm(text_lines, desc=\"Creating embeddings\")):\n",
|
||||
" data.append({\"id\": i, \"vector\": emb_text(line), \"text\": line})\n",
|
||||
"\n",
|
||||
"milvus_client.insert(collection_name=collection_name, data=data)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 25,
|
||||
"id": "f3007553",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Milvus 是一个开源的向量数据库,主要用于高效地存储、管理和检索大规模的向量数据。它广泛应用于机器学习、推荐系统、图像识别等需要处理高维数据的场景。\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"question = \"milvus是什么,用中文回答\"\n",
|
||||
"search_res = milvus_client.search(\n",
|
||||
" collection_name=collection_name,\n",
|
||||
" data=[\n",
|
||||
" emb_text(question)\n",
|
||||
" ], # Use the `emb_text` function to convert the question to an embedding vector\n",
|
||||
" limit=3, # Return top 3 results\n",
|
||||
" search_params={\"metric_type\": \"IP\", \"params\": {}}, # Inner product distance\n",
|
||||
" output_fields=[\"text\"], # Return the text field\n",
|
||||
")\n",
|
||||
"import json\n",
|
||||
"# 获取答案\n",
|
||||
"retrieved_lines_with_distances = [\n",
|
||||
" (res[\"entity\"][\"text\"], res[\"distance\"]) for res in search_res[0]\n",
|
||||
"]\n",
|
||||
"context = \"\\n\".join(\n",
|
||||
" [line_with_distance[0] for line_with_distance in retrieved_lines_with_distances]\n",
|
||||
")\n",
|
||||
"SYSTEM_PROMPT = \"\"\"\n",
|
||||
"Human: You are an AI assistant. You are able to find answers to the questions from the contextual passage snippets provided.\n",
|
||||
"\"\"\"\n",
|
||||
"USER_PROMPT = f\"\"\"\n",
|
||||
"Use the following pieces of information enclosed in <context> tags to provide an answer to the question enclosed in <question> tags.\n",
|
||||
"<context>\n",
|
||||
"{context}\n",
|
||||
"</context>\n",
|
||||
"<question>\n",
|
||||
"{question}\n",
|
||||
"</question>\n",
|
||||
"\"\"\"\n",
|
||||
"response = client.chat.completions.create(\n",
|
||||
" model='qwen-turbo',\n",
|
||||
" messages=[\n",
|
||||
" {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
|
||||
" {\"role\": \"user\", \"content\": USER_PROMPT},\n",
|
||||
" ],\n",
|
||||
")\n",
|
||||
"print(response.choices[0].message.content)\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "077922d1",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"2025-09-15 15:12:53,649 [ERROR][handler]: RPC error: [drop_database], <MilvusException: (code=65535, message=can not drop default database)>, <Time:{'RPC start': '2025-09-15 15:12:53.638539', 'RPC error': '2025-09-15 15:12:53.649605'}> (decorators.py:140)\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Database 'default' already exists.\n",
|
||||
"Collection 'my_rag_collection' has been dropped.\n",
|
||||
"Collection 'bbit_ai_lab_knowledge' has been dropped.\n",
|
||||
"An error occurred: <MilvusException: (code=65535, message=can not drop default database)>\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from pymilvus import Collection, MilvusException, connections, db, utility\n",
|
||||
"\n",
|
||||
"conn = connections.connect(host=\"10.10.10.9\", port=19530)\n",
|
||||
"\n",
|
||||
"# Check if the database exists\n",
|
||||
"db_name = \"default\"\n",
|
||||
"\n",
|
||||
"try:\n",
|
||||
" existing_databases = db.list_database()\n",
|
||||
" if db_name in existing_databases:\n",
|
||||
" print(f\"Database '{db_name}' already exists.\")\n",
|
||||
"\n",
|
||||
" # Use the database context\n",
|
||||
" db.using_database(db_name)\n",
|
||||
"\n",
|
||||
" # Drop all collections in the database\n",
|
||||
" collections = utility.list_collections()\n",
|
||||
" for collection_name in collections:\n",
|
||||
" collection = Collection(name=collection_name)\n",
|
||||
" collection.drop()\n",
|
||||
" print(f\"Collection '{collection_name}' has been dropped.\")\n",
|
||||
"\n",
|
||||
" db.drop_database(db_name)\n",
|
||||
" print(f\"Database '{db_name}' has been deleted.\")\n",
|
||||
" else:\n",
|
||||
" print(f\"Database '{db_name}' does not exist.\")\n",
|
||||
" database = db.create_database(db_name)\n",
|
||||
" print(f\"Database '{db_name}' created successfully.\")\n",
|
||||
"except MilvusException as e:\n",
|
||||
" print(f\"An error occurred: {e}\")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "lang",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.10.18"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
||||
Reference in New Issue
Block a user