更新python后端

This commit is contained in:
BBIT-Kai
2025-09-18 17:18:18 +08:00
parent 2fc209e6e6
commit de6a350da8
45 changed files with 2524 additions and 89 deletions
+5
View File
@@ -0,0 +1,5 @@
# Default ignored files
/shelf/
/workspace.xml
# Editor-based HTTP Client requests
/httpRequests/
+6
View File
@@ -0,0 +1,6 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="AI Toolkit Settings">
<option name="importsOfInterestPresent" value="true" />
</component>
</project>
+12
View File
@@ -0,0 +1,12 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="jdk" jdkName="bbit_ai_lab" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
<component name="PyDocumentationSettings">
<option name="format" value="PLAIN" />
<option name="myDocStringFormat" value="Plain" />
</component>
</module>
+8
View File
@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<module type="PYTHON_MODULE" version="4">
<component name="NewModuleRootManager">
<content url="file://$MODULE_DIR$" />
<orderEntry type="jdk" jdkName="Python 3.12 (bbit_ai)" jdkType="Python SDK" />
<orderEntry type="sourceFolder" forTests="false" />
</component>
</module>
+6
View File
@@ -0,0 +1,6 @@
<component name="InspectionProjectProfileManager">
<settings>
<option name="USE_PROJECT_PROFILE" value="false" />
<version value="1.0" />
</settings>
</component>
+7
View File
@@ -0,0 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="Black">
<option name="sdkName" value="Python 3.12" />
</component>
<component name="ProjectRootManager" version="2" project-jdk-name="bbit_ai_lab" project-jdk-type="Python SDK" />
</project>
+8
View File
@@ -0,0 +1,8 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="ProjectModuleManager">
<modules>
<module fileurl="file://$PROJECT_DIR$/.idea/app.iml" filepath="$PROJECT_DIR$/.idea/app.iml" />
</modules>
</component>
</project>
+7
View File
@@ -0,0 +1,7 @@
<?xml version="1.0" encoding="UTF-8"?>
<project version="4">
<component name="VcsDirectoryMappings">
<mapping directory="$PROJECT_DIR$/../.." vcs="Git" />
<mapping directory="$PROJECT_DIR$/.." vcs="Git" />
</component>
</project>
+1 -2
View File
@@ -30,8 +30,7 @@ RUN pip install --no-cache-dir -i https://pypi.tuna.tsinghua.edu.cn/simple -r re
# 复制项目代码
COPY app/ .
# 对外暴露端口(FastAPI 默认 8000
EXPOSE 8000
EXPOSE 13011
# 启动命令(使用 uvicorn 启动 FastAPI
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "13011"]
View File
+139
View File
@@ -0,0 +1,139 @@
from langchain.prompts import PromptTemplate
from config.llm import llm
from config.ssDb import ssDBLC
from typing import Annotated
from typing_extensions import TypedDict
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from langchain_community.agent_toolkits import create_sql_agent
from langchain.prompts import PromptTemplate
from config.llm import llm
from config.ssDb import ssDBLC
from typing import Annotated
from langgraph.graph.message import add_messages
import os
from langchain_tavily import TavilySearch
from langgraph.prebuilt import ToolNode, tools_condition
from llm.chatLLM import get_chat_response
from typing import TypedDict
from langgraph.graph import StateGraph, END
from llm.summarizeLLM import getSummary
# -------- 定义状态 --------
class State(TypedDict):
    """State dictionary passed between the nodes of the routing graph."""
    userInput: str # latest user input
    source: str # selected data source: "web", "db" or "chat"
    infomation: str # content fetched from the web or the database (key spelling is part of the schema)
    aiRole: str # AI role description
    history: str # chat history
    reply: str # final reply
# -------- 定义节点 --------
# ------------------------------------------------------------------------ 路径选择 --------
pathSelectPrompt = PromptTemplate(
input_variables=["aiRole", "history", "userStr", "infomation"],
template = """
你是主干信息科技有限公司的业务员,是一家蚕桑服务公司,现在需要根据用户输入来判断应该使用哪种方式来回答用户的问题。
你有三种选择:
1. 如果用户的问题涉及最新的信息,比如新闻、事件、天气等涉及时间的内容时,请选择 "web
2. 如果用户的问题涉及具体的蚕桑业务(例如询问农户、订单、订种、租户)的数据库查询需求,请选择 "db"
3. 如果用户的问题是一般性的聊天或咨询,请选择 "chat"
请只返回 "web""db""chat" 之一,且不要添加任何其他解释。
用户最新输入:
{userStr}
请做出你的选择:
"""
)
pathSelectChain = pathSelectPrompt | llm
def decide_source(state: State, max_retry=3):
    """Ask the routing LLM which data source should answer the user's input.

    The chain is invoked up to ``max_retry`` times until it answers with one
    of ``web``, ``db`` or ``chat``. If no valid answer is produced the source
    defaults to ``chat``.

    Args:
        state: Graph state; reads ``aiRole``, ``history`` and ``userInput``.
        max_retry: Maximum number of LLM calls before falling back to chat.

    Returns:
        The same state object with ``source`` set.
    """
    print("根据用户输入选择数据来源,用户输入:", state["userInput"])
    for _ in range(max_retry):
        raw_answer = pathSelectChain.invoke({
            "aiRole": state["aiRole"],
            "history": state["history"],
            "userStr": state["userInput"],
        }).content
        # The prompt presents the options wrapped in double quotes, so the
        # model may echo them quoted; strip surrounding quotes before matching.
        choice = raw_answer.strip().strip("\"'“”‘’").strip().lower()
        if choice in ("web", "db", "chat"):
            state["source"] = choice
            break
    else:
        # No valid answer after max_retry attempts: default to plain chat.
        state["source"] = "chat"
    print("选择的数据来源是:", state["source"])
    return state
# ------------------------------------------------------------------------ 上网查询 --------
# NOTE(review): a live API key is committed to source control here. It is kept
# only as a fallback so existing deployments keep working; rotate it and supply
# TAVILY_API_KEY through the environment instead. setdefault() lets a key that
# is already configured in the environment take precedence.
os.environ.setdefault("TAVILY_API_KEY", "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")
tool = TavilySearch(max_results=2)


def fetch_web(state: State):
    """Run a web search for the user's input and store the result in the state.

    Args:
        state: Graph state; reads ``userInput``.

    Returns:
        The same state object with ``infomation`` set to the result's
        ``content`` entry when present, otherwise to the raw tool result.
    """
    result = tool.invoke(state["userInput"])
    # Only mapping results expose .get(); anything else is stored as-is.
    content = result.get("content") if isinstance(result, dict) else None
    state["infomation"] = content or result
    print("调用了联网工具,结果是:", state["infomation"])
    return state
# ------------------------------------------------------------------------ 数据库查询 --------
# SQL agent over the SQL Server database object, created once at import time.
agent = create_sql_agent(llm=llm, db=ssDBLC, agent_type="tool-calling", verbose=True)


def fetch_db(state: State):
    """Let the SQL agent answer the user's input and store its output.

    Reads ``userInput`` and writes the agent's ``output`` value to
    ``infomation``; returns the same state object.
    """
    agent_result = agent.invoke({"input": state["userInput"]})
    state["infomation"] = agent_result["output"]
    print("调用了数据库工具,结果是:", state["infomation"])
    return state
# ------------------------------------------------------------------------ 整理结果 --------
def summarize_ai(state: State):
    """Turn the fetched information into the final reply via ``getSummary``.

    Reads ``aiRole``, ``history``, ``userInput`` and ``infomation``; writes
    ``reply`` and returns the same state object.
    """
    summary = getSummary(
        aiRole=state["aiRole"],
        history=state["history"],
        userInput=state["userInput"],
        infomation=state["infomation"],
    )
    state["reply"] = summary
    return state
# ------------------------------------------------------------------------ 普通聊天 --------
def chat(state: State):
    """Answer directly with the chat LLM, without any retrieval step.

    Writes the ``content`` of the chat response to ``reply`` and returns the
    same state object.
    """
    response = get_chat_response(
        aiRole=state["aiRole"],
        history=state["history"],
        userInput=state["userInput"],
    )
    state["reply"] = response.content
    print("直接回复")
    return state
# ------------------------------------------------------------------------ 构建有向图 --------
# Routing graph: decide -> (fetch_web | fetch_db) -> summarize -> END,
# or decide -> chat -> END.
workflow = StateGraph(State)
workflow.add_node("decide", decide_source)
workflow.add_node("fetch_web", fetch_web)
workflow.add_node("fetch_db", fetch_db)
workflow.add_node("chat", chat)
workflow.add_node("summarize", summarize_ai)
# The entry point is registered once; the extra add_edge(START, "decide")
# that used to follow duplicated this registration.
workflow.set_entry_point("decide")
# Both retrieval paths converge on summarize.
workflow.add_edge("fetch_web", "summarize")
workflow.add_edge("fetch_db", "summarize")
# Conditional edge: route according to the source chosen by decide_source.
workflow.add_conditional_edges(
    "decide",
    lambda state: state["source"],  # returns the value of state["source"]
    {
        "web": "fetch_web",
        "chat": "chat",
        "db": "fetch_db"
    }
)
workflow.add_edge("summarize", END)
workflow.add_edge("chat", END)
graph = workflow.compile()
# 执行函数
def get_graph_output(aiRole:str,history: str, userInput: str) -> str:
    """Run the routing graph once and return its final reply.

    Args:
        aiRole: AI role description forwarded to the nodes.
        history: Chat history forwarded to the nodes.
        userInput: Latest user input.

    Returns:
        The ``reply`` entry of the final graph state.
    """
    initial_state = {
        "aiRole": aiRole,
        "history": history,
        "userInput": userInput,
    }
    return graph.invoke(initial_state)["reply"]
+346
View File
@@ -0,0 +1,346 @@
from typing import Literal
from langchain_core.messages import AIMessage
from langchain_core.runnables import RunnableConfig
from langgraph.graph import END, START, MessagesState, StateGraph
from langgraph.prebuilt import ToolNode
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from config.llm import llm, llmThink
from langgraph.graph import StateGraph, END
from langchain.prompts import PromptTemplate
from config.llm import llm
from typing import Annotated
from typing_extensions import TypedDict
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from langchain_community.agent_toolkits import create_sql_agent
from langchain.prompts import PromptTemplate
from config.llm import llm
from typing import Annotated
from langgraph.graph.message import add_messages
import os
from langchain_tavily import TavilySearch
from langgraph.prebuilt import ToolNode, tools_condition
from llm.chatLLM import get_chat_response
from typing import TypedDict
from langgraph.graph import StateGraph, END
from llm.summarizeLLM import getSummary
import db.postgres as pgdb
import db.sqlserver as sqlserver
# -------- 定义状态 --------
class State(TypedDict):
    """State dictionary passed between the nodes of the SQL-report graph."""
    userInput: str  # latest user input
    path: str  # route chosen at the start of the conversation: "db" or "chat"
    table_info: str  # description of the tables/columns that may be queried
    isFirstGenSQL: bool  # True when no SQL exists yet and one must be generated from scratch
    sql: str  # SQL statement currently being worked on
    # Result of executing `sql`. The sql node writes this key, so it has to be
    # declared here to be part of the graph state.
    # NOTE(review): concrete type depends on sqlserver.executeSQL — confirm.
    sql_result: object
    ai_service: str  # AI role: business domain
    ai_role: str  # AI role: personality traits
    tenant_id: str  # tenant id
    history: str  # chat history
    reply: str  # final reply
# -------- 定义节点 --------
# ------------------------------------------------------------------------ 数据库查询 --------
# Prompt for first-pass SQL generation. input_variables lists every
# placeholder the template uses, including tenant_id.
gen_sql_prompt = PromptTemplate(
    input_variables=["table_info", "userInput", "tenant_id"],
    template="""
# 角色
你是一个企业 SQL Server 数据库 SQL 生成助手,负责根据用户问题生成相应的 SQL 语句。
# 已知信息
可访问的表和字段:{table_info}
# 任务要求
1. 根据用户提出的问题“{userInput}”生成 SQL 语句。
2. 只能使用已知的表和字段。
3. 输出完整可执行的 SQL 语句,不包含多余文字。
4. 若 SQL 语句返回列表数据,需限制返回数量,最大为 15 条,使用 SQL Server 语法(TOP 15 或 OFFSET FETCH)。
5. 若 SQL 是聚合查询(如 COUNT、SUM 等),无需限制行数。
6. 在生成 SQL 时,如果需要根据身份证计算年龄,请使用 SQL Server 标准日期格式 SUBSTRING(idcard, 7, 8) 和 CONVERT(..., 112),不要使用拼接 / 或非标准日期格式。
7. 通常来说,不查询对用户来说意义不大的字段,比如主键、外键、id等。
8. 查询的SQL字段要用别名,取名参考描述。
9. 一般情况下,如果能限制租户Id(通常为tenantid 字段),则尽量限制租户id = {tenant_id}
请直接输出完整可执行的 SQL 语句,不要任何其他文字或格式化,例如反引号或 ```sql。
"""
)
sqlChain = gen_sql_prompt | llm
fix_prompt = PromptTemplate(
input_variables=["sql", "error_msg", "table_info", "tenant_id"],
template="""
# 系统角色
你是一位专业的 SQL Server的SQL语句纠错专家,擅长识别 SQL 语句中的语法错误和字段引用错误,并能对其进行修正。
# 任务
根据提供的原始 SQL 语句、执行报错信息以及可用表和字段信息,修正 SQL 语句,确保其语法正确且引用的字段存在。
# 输入信息
- 原始 SQL: {sql}
- 执行报错: {error_msg}
- 可用表和字段: {table_info}
# 输出要求
只返回修正后的 SQL 语句,不包含任何额外的解释或说明。
"""
)
fixSQLChain = fix_prompt | llm
def sql(state: State):
    """Generate (or refine) a SQL statement, execute it, and repair it on failure.

    The statement is produced by ``sql_1`` when ``isFirstGenSQL`` is true and
    by ``sql_2`` otherwise. It is executed at most twice: if the first
    execution fails, the error message is handed to the repair chain and the
    corrected statement is tried once more.

    Args:
        state: Graph state; reads ``isFirstGenSQL``, ``table_info`` and
            ``tenant_id`` and writes ``sql`` and ``sql_result``.

    Returns:
        The same state object after a successful execution.

    Raises:
        RuntimeError: If the statement still fails on the last attempt; the
            underlying execution error is chained as the cause.
    """
    # NOTE(review): the statement is LLM-generated and executed verbatim —
    # confirm the database login used by sqlserver.executeSQL is read-only.
    max_attempts = 2
    state['sql'] = sql_1(state) if state["isFirstGenSQL"] else sql_2(state)
    last_error = None
    for attempt in range(max_attempts):
        try:
            state['sql_result'] = sqlserver.executeSQL(state['sql'])
            return state
        except Exception as e:
            last_error = e
            error_msg = str(e)
            print(f"SQL 执行出错: {error_msg}")
        # Only ask the LLM for a fix when the fix will actually be executed;
        # after the final attempt the call would be wasted and the error
        # below would report a statement that was never run.
        if attempt + 1 < max_attempts:
            state['sql'] = fixSQLChain.invoke({
                "sql": state['sql'],
                "error_msg": error_msg,
                "table_info": state['table_info'],
                "tenant_id": state['tenant_id']
            }).content
    raise RuntimeError(f"SQL 多次纠错失败,最后 SQL: {state['sql']}") from last_error
def sql_1(state: State):
    """Generate a first SQL statement from the user's question.

    Returns the ``content`` of the generation chain's response.
    """
    prompt_inputs = {
        "table_info": state['table_info'],
        "userInput": state["userInput"],
        "tenant_id": state['tenant_id'],
    }
    response = sqlChain.invoke(prompt_inputs)
    return response.content
# Prompt for refining an existing SQL statement. input_variables lists every
# placeholder the template uses, including the current sql.
improve_sql_prompt = PromptTemplate(
    input_variables=["sql", "table_info", "userInput", "tenant_id"],
    template="""
# 角色
你是一个企业 SQL Server 数据库 SQL 生成助手,负责根据用户问题改进相应的 SQL 语句。
# 已知信息
当前 SQL 语句: {sql}
可访问的表和字段: {table_info}
# 任务要求
1. 根据用户提出的问题“{userInput}”以及当前的 SQL 语句进行改进。
2. 只能使用已知的表和字段。
3. 输出完整可执行的 SQL 语句,不包含多余文字。
4. 若 SQL 语句返回列表数据,需限制返回数量,最大为 15 条,使用 SQL Server 语法(TOP 15 或 OFFSET FETCH)。
5. 若 SQL 是聚合查询(如 COUNT、SUM 等),无需限制行数。
6. 在生成 SQL 时,如果需要根据身份证计算年龄,请使用 SQL Server 标准日期格式 SUBSTRING(idcard, 7, 8) 和 CONVERT(..., 112),不要使用拼接 / 或非标准日期格式。
7. 通常来说,不查询对用户来说意义不大的字段,比如主键、外键、id等。
8. 查询的SQL字段要用别名,取名参考描述。
9. 一般情况下,如果能限制租户Id(通常为tenantid 字段),则尽量加上WHERE tenantid = {tenant_id}
"""
)
improveSqlChain = improve_sql_prompt | llm
def sql_2(state: State):
    """Refine the SQL statement already stored in the state.

    Returns the ``content`` of the improvement chain's response.
    """
    prompt_inputs = {
        "sql": state['sql'],
        "table_info": state['table_info'],
        "userInput": state["userInput"],
        "tenant_id": state['tenant_id'],
    }
    improved = improveSqlChain.invoke(prompt_inputs)
    return improved.content
# ------------------------------------------------------------------------ 路径选择 --------
pathSelectPrompt = PromptTemplate(
input_variables=["userInput", "table_info", "sql"],
template="""
你的任务是:
1. 根据用户输入的问题和已知的表结构数据,判断是否能够生成准确的 SQL 查询。
2. 首先仔细阅读以下表结构数据:
<table_info>
{table_info}
</table_info>
2. 然后仔细阅读用户输入的问题:
<userInput>
{userInput}
</userInput>
3. 请严格遵循以下规则:
只有在能够完全、明确、直接根据表结构生成正确 SQL 时,输出 db。
参考表结构或字段描述中出现的关键词:如果用户问题中出现的关键字段或概念在表结构中找不到明确对应关系,或者问题逻辑无法直接映射到表结构,输出 chat。
不允许假设额外存在的表、字段或数据,也不允许基于常识或推测生成 SQL。
输出必须严格二选一:
db → 可以直接生成 SQL。
chat → 无法直接生成 SQL,需要进一步解释或澄清。
回答内容仅限于db或者chat,请勿输出其他内容。
你的回复:
"""
)
pathSelectChain = pathSelectPrompt | llmThink
def decide_source(state: State, max_retry=3):
    """Decide whether the question can be answered with SQL ("db") or not ("chat").

    The routing chain is invoked up to ``max_retry`` times; the first answer
    that is exactly ``db`` or ``chat`` (after trimming and lower-casing) is
    stored in ``path``. Without a valid answer the path defaults to ``chat``.
    """
    for _ in range(max_retry):
        answer = pathSelectChain.invoke({
            "userInput": state["userInput"],
            "table_info": state["table_info"],
            "ai_service": state["ai_service"],
            "sql": state["sql"]
        })
        choice = answer.content.strip().lower()
        print("根据用户输入选择数据来源,路径是:", choice)
        if choice in ("db", "chat"):
            state["path"] = choice
            return state
    # Every attempt produced an invalid answer: fall back to chat.
    state["path"] = "chat"
    return state
# ------------------------------------------------------------------------ !普通聊天 --------
noChatPrompt = PromptTemplate(
input_variables=["userInput", "ai_service"],
template="""
你的任务是回复用户,告知用户你目前无法处理他们的回复,因为你的业务是特定领域的服务。请仔细阅读以下信息,并按照指示进行回复。
用户的回复:
<userInput>
{userInput}
</userInput>
你的业务:
<ai_service>
{ai_service}
</ai_service>
在回复时,请遵循以下指南:
1. 明确告知用户你无法处理当前回复。
2. 提及你的业务是{ai_service}
3. 引导用户提出与你业务相关的问题。
4. 使用礼貌和友好的语气。
你的回答:
"""
)
noChatChain = noChatPrompt | llm
def chat(state: State):
    """Reply that the request is outside this assistant's business scope.

    Invokes the refusal chain with ``userInput`` and ``ai_service`` and stores
    the response content in ``reply``.
    """
    prompt_inputs = {
        "userInput": state["userInput"],
        "ai_service": state["ai_service"],
    }
    response = noChatChain.invoke(prompt_inputs)
    state["reply"] = response.content
    print("直接回复")
    return state
# ------------------------------------------------------------------------ 整理结果 --------
# Prompt that turns the generated SQL into a user-facing follow-up message.
# input_variables matches the placeholders the template actually uses.
summarizePrompt = PromptTemplate(
    input_variables=["ai_role", "userInput", "sql", "table_info"],
    template="""
你是主干信息研发的AI助手,你的性格特点为:
<ai_role>
{ai_role}
</ai_role>
用户之前的提问为:
<userInput>
{userInput}
</userInput>
当前生成的SQL语句为:
<sql>
{sql}
</sql>
当前支持的数据库表与字段信息如下:
<table_info>
{table_info}
</table_info>
你的核心任务是根据用户之前的提问和当前生成的SQL语句,引导用户理解当前SQL的含义,并询问是否需要修改或完善,同时提供进一步可选的查询示例,引导用户提出更具体的需求。
交流要求如下:
- 先明确SQL的用途,再提出引导性问题。
- 回答要简洁、易理解。
- 回复内容不要出现SQL语句,不要对SQL进行解释,只需说,查询结果已生成,然后引导用户进一步提问。
任务流程如下:
1. 询问用户是否需要对当前查询内容进行修改或完善。
2. 提供进一步可选的查询示例,基于当前的数据库表结构,引导用户提出更具体需求。
你的回复:
"""
)
summarizeChain = summarizePrompt | llm
def summarize_ai(state: State):
    """Produce the user-facing reply that accompanies a generated SQL statement.

    Reads ``ai_role``, ``sql``, ``userInput`` and ``table_info``; writes the
    chain's response content to ``reply`` and returns the same state object.
    """
    prompt_inputs = {
        "ai_role": state["ai_role"],
        "sql": state['sql'],
        "userInput": state['userInput'],
        "table_info": state['table_info'],
    }
    summary = summarizeChain.invoke(prompt_inputs)
    state["reply"] = summary.content
    return state
# ------------------------------------------------------------------------ 构建有向图 --------
# Graph: decide -> sql_1 -> summarize -> END, or decide -> chat -> END.
workflow = StateGraph(State)
workflow.add_node("decide", decide_source)
# The node is registered as "sql_1" but runs the sql() function, which picks
# sql_1 or sql_2 internally.
workflow.add_node("sql_1", sql)
workflow.add_node("chat", chat)
workflow.add_node("summarize", summarize_ai)
workflow.set_entry_point("decide")
workflow.add_edge("sql_1", "summarize")
# Conditional edge: route according to the path chosen by decide_source.
workflow.add_conditional_edges(
    "decide",
    lambda state: state["path"], # returns the value of state["path"]
    {
        "db": "sql_1",
        "chat": "chat",
    }
)
workflow.add_edge("summarize", END)
workflow.add_edge("chat", END)
graph = workflow.compile()
# 执行函数
def get_db_agent_reply(aiId: str, userInput: str, tenant_id: str, sql: str = "") -> dict:
    """Run the SQL-report graph for one user message.

    Args:
        aiId: Id of the AI profile; used to load its personality and the
            tables it may query.
        userInput: Latest user input.
        tenant_id: Tenant id forwarded to the SQL prompts.
        sql: Previously generated SQL to refine; an empty string means a new
            statement is generated from scratch.

    Returns:
        The final graph state (a dict), not just the reply text.
    """
    personality = pgdb.get_ai_personality(aiId)
    if isinstance(personality, str):
        # get_ai_personality returns a plain default prompt string when the
        # profile row is missing; use it as the role and leave the business
        # description empty instead of failing on string indexing.
        ai_service = ""
        ai_role = personality
    else:
        ai_service = personality.get("业务", "")
        ai_role = personality.get("性格", "")
    final_state = graph.invoke({
        "ai_service": ai_service,
        "ai_role": ai_role,
        "table_info": pgdb.get_available_tables_str(aiId),
        "tenant_id": tenant_id,
        "userInput": userInput,
        "sql": sql,
        "isFirstGenSQL": sql == "",
    })
    return final_state
+246
View File
@@ -0,0 +1,246 @@
from typing import Literal
from langchain_core.messages import AIMessage
from langchain_core.runnables import RunnableConfig
from langgraph.graph import END, START, MessagesState, StateGraph
from langgraph.prebuilt import ToolNode
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from config.llm import llm,llmThink
from langgraph.graph import StateGraph, END
from langchain.prompts import PromptTemplate
from config.llm import llm
from typing import Annotated
from typing_extensions import TypedDict
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from langchain_community.agent_toolkits import create_sql_agent
from langchain.prompts import PromptTemplate
from config.llm import llm
from typing import Annotated
from langgraph.graph.message import add_messages
import os
from langchain_tavily import TavilySearch
from langgraph.prebuilt import ToolNode, tools_condition
from llm.chatLLM import get_chat_response
from typing import TypedDict
from langgraph.graph import StateGraph, END
from llm.summarizeLLM import getSummary
import db.postgres as pgdb
import db.sqlserver as sqlserver
from typing import List, Dict
import db.milvus as milvus
# -------- 定义状态 --------
class State(TypedDict):
    """State dictionary passed between the nodes of the service-agent graph."""
    path: str # route chosen at the start of the conversation: "kn" or "chat"
    memory:str # memory text retrieved for this AI
    knowledge: str # knowledge-base content retrieved for the question
    history: str # chat history
    ai_id : str # AI id
    ai_name:str # AI name
    ai_service: str # AI role: business domain
    ai_role: str # AI role: personality traits
    kn_bases: List[str] # knowledge bases this AI may use
    userInput: str # latest user input
    reply: str # final reply
# -------- 定义节点 --------
# ------------------------------------------------------------------------ 向量数据库查询 --------
gen_sql_prompt = PromptTemplate(
input_variables=["userInput"],
template = """你的任务是对用户输入进行意图分析,并将其分解成方便进行知识向量数据库搜索的关键词。
以下是用户的输入:
<用户输入>
{userInput}
</用户输入>
在提取关键词时,请遵循以下方法和要求:
1. 去除输入中的停用词(如“的”“是”“在”等)、语气词和无实际意义的符号。
2. 识别输入中的核心概念、实体和关键动作。
3. 尽量使用简洁、通用的词汇作为关键词。
4. 确保关键词之间相互独立,不包含其他关键词。
关键词之间用空格分隔。
你的回答是:
"""
)
sqlChain = gen_sql_prompt | llm
def db_search(state: State):
    """Extract search keywords and look up knowledge and memory for them.

    The keyword chain condenses ``userInput`` into keywords, which are then
    used to query the knowledge collection (restricted to ``kn_bases``) and
    the memory collection (restricted to this ``ai_id``). Results are written
    to ``knowledge`` and ``memory``.
    """
    keyword_response = sqlChain.invoke({"userInput": state['userInput']})
    key_words = keyword_response.content
    print("关键词是:", key_words)

    state["knowledge"] = milvus.get_knowledge_by_key_words(key_words, state['kn_bases'])
    print("知识库内容是:", state["knowledge"])

    state["memory"] = milvus.get_memory_by_key_words(key_words, [state['ai_id']])
    print("记忆是:", state["memory"])
    return state
# ------------------------------------------------------------------------ 意图分析 --------
pathSelectPrompt = PromptTemplate(
input_variables=[ "userInput","ai_service","history"],
template = """
你是一个意图分类器,负责判断用户提问是否与你的工作相关,进而确定是否需要去查知识库。
以下是你负责的工作内容:
<ai_service>
{ai_service}
</ai_service>
这是你们的对话历史:
<history>
{history}
</history>
用户最新回复是:
<userInput>
{userInput}
</userInput>
判断规则如下:
如果用户最新回复与你的负责工作相关,需要去查知识库,输出“kn”;如果不相关,则输出“chat”,不要包含任何标点符号以及空格。
你生成的结果:
"""
)
pathSelectChain = pathSelectPrompt | llmThink
def decide_source(state: State, max_retry=3):
    """Decide whether the knowledge base must be consulted ("kn") or not ("chat").

    The intent chain is invoked up to ``max_retry`` times; the first answer
    that is exactly ``kn`` or ``chat`` (after trimming and lower-casing) is
    stored in ``path``. Without a valid answer the path defaults to ``chat``.
    """
    for _ in range(max_retry):
        answer = pathSelectChain.invoke({
            "userInput": state["userInput"],
            "ai_service": state["ai_service"],
            "history": state["history"],
        })
        choice = answer.content.strip().lower()
        print("根据用户输入选择数据来源,路径是:", choice)
        if choice in ("kn", "chat"):
            state["path"] = choice
            return state
    # Every attempt produced an invalid answer: fall back to chat.
    state["path"] = "chat"
    return state
# ------------------------------------------------------------------------ !普通聊天 --------
noChatPrompt = PromptTemplate(
input_variables=[ "ai_name", "ai_service", "ai_role", "history"],
template = """
你的名字是:{ai_name},你负责的业务是{ai_service},你具有{ai_role}的性格特点。
这是你和用户的对话历史
<history>
{history}
</history>
在回复用户时,请遵循以下指南:
1. 回复要与AI角色业务相关,体现AI的专业能力。
2. 回复内容的语气和风格要符合AI角色性格特点。
3. 参考聊天历史,使回复具有连贯性和针对性。
4. 回复要简洁明了,避免冗长和复杂的表述。
你的回答:
"""
)
noChatChain = noChatPrompt | llm
def chat(state: State):
    """Reply in character without consulting the knowledge base.

    Invokes the chat chain with the AI's name, business, personality and the
    chat history, and stores the response content in ``reply``.
    """
    prompt_inputs = {
        "ai_name": state["ai_name"],
        "ai_service": state["ai_service"],
        "ai_role": state["ai_role"],
        "history": state["history"],
        # NOTE(review): the prompt template has no {userStr} placeholder, so
        # this value is presumably unused — confirm the latest user message is
        # already part of history.
        "userStr": state["userInput"],
    }
    response = noChatChain.invoke(prompt_inputs)
    state["reply"] = response.content
    print("直接回复")
    return state
# ------------------------------------------------------------------------ 整理结果 --------
# Prompt that answers the user from the retrieved knowledge and memory.
# input_variables lists every placeholder the template uses, including memory.
summarizePrompt = PromptTemplate(
    input_variables=["ai_name", "ai_service", "ai_role", "history", "knowledge", "memory"],
    template = """
你的任务是基于给定的AI名称、AI角色业务、AI角色性格特点和聊天历史来回复用户。请仔细阅读以下信息,并按照指示进行回复。
你的名字是:{ai_name},你负责的业务是{ai_service},你具有{ai_role}的性格特点。
这是你和用户的对话历史
<history>
{history}
</history>
这是给你参考的知识库:
<knowledge>
{knowledge}
</knowledge>
{memory}
在回复时,请遵循以下指南:
1. 回复内容要与你负责的业务相关。
2. 回复的语气要结合你的性格特点。
3. 确保回复内容清晰、简洁、有针对性。
请生成你的回复:
"""
)
summarizeChain = summarizePrompt | llm
def summarize_ai(state: State):
    """Compose the final reply from history, knowledge and (optional) memory.

    A non-empty ``memory`` is wrapped in a ``<memory>`` section for the
    prompt; an empty one is replaced by a short "no memory" notice. The
    chain's response content is stored in ``reply``.
    """
    memory_section_template = """
这是给你参考的相关历史记忆:
<memory>
%s
</memory>
"""
    mem = state['memory']
    # %-formatting fills the retrieved memory text into the section template.
    memStr = memory_section_template % mem if mem != "" else "没有记忆内容"
    print("历史记录是:" ,state["history"])
    prompt_inputs = {
        "ai_role": state["ai_role"],
        "ai_name": state["ai_name"],
        "history": state["history"],
        "ai_service": state['ai_service'],
        "knowledge": state["knowledge"],
        "memory": memStr,
    }
    state["reply"] = summarizeChain.invoke(prompt_inputs).content
    return state
# ------------------------------------------------------------------------ 构建有向图 --------
# Graph: decide -> db_search -> summarize -> END, or decide -> chat -> END.
workflow = StateGraph(State)
workflow.add_node("decide", decide_source)
workflow.add_node("db_search", db_search)
workflow.add_node("chat", chat)
workflow.add_node("summarize", summarize_ai)
workflow.set_entry_point("decide")
# Conditional edge: route according to the path chosen by decide_source.
workflow.add_conditional_edges(
    "decide",
    lambda state: state["path"], # returns the value of state["path"]
    {
        "kn": "db_search",
        "chat": "chat",
    }
)
workflow.add_edge("db_search", "summarize")
workflow.add_edge("summarize", END)
workflow.add_edge("chat", END)
graph = workflow.compile()
# 执行函数
def get_service_agent_reply(aiId:str, userInput: str,history:str, kn_bases:List[str]) :
    """Run the service-agent graph for one user message and return the reply.

    Args:
        aiId: Id of the AI profile; used to load its personality and to scope
            the memory search.
        userInput: Latest user input.
        history: Chat history forwarded to the prompts.
        kn_bases: Ids of the knowledge bases this AI may search.

    Returns:
        The ``reply`` entry of the final graph state.
    """
    personality = pgdb.get_ai_personality(aiId)
    if isinstance(personality, str):
        # get_ai_personality returns a plain default prompt string when the
        # profile row is missing; use it as the role instead of failing on
        # string indexing.
        ai_service = ""
        ai_role = personality
        ai_name = ""
    else:
        ai_service = personality.get("业务", "")
        ai_role = personality.get("性格", "")
        ai_name = personality.get("名字", "")
    print("AI Name:", ai_name)
    print("AI Service:", ai_service)
    # table_info is no longer fetched: this graph's State declares no such
    # key, so the extra database query only added latency.
    final_state = graph.invoke({
        "ai_service": ai_service,
        "ai_role": ai_role,
        "ai_name": ai_name,
        "history": history,
        "kn_bases": kn_bases,
        "userInput": userInput,
        "ai_id": aiId,
    })
    return final_state["reply"]
+13 -3
View File
@@ -1,10 +1,15 @@
from fastapi import FastAPI
from routers.Chat import router
from routers.Chat import chatRouter
from routers.Report import reportRouter
from routers.Datasource import reportDataRouter
from fastapi.middleware.cors import CORSMiddleware
from routers.Knowledge import knowledgeRouter
from routers.Service import serviceRouter
from routers.Bot import botRouter
app = FastAPI(title="BBIT_AI")
origins = [
"http://localhost:5173", # Vite dev 默认端口
"http://localhost:8090", # Vite dev 默认端口
"http://127.0.0.1:5173",
"http://s1.ronsunny.cn:8089",
"*" # ⚠️ 生产环境不要用
@@ -17,4 +22,9 @@ app.add_middleware(
allow_methods=["*"], # 必须包含 OPTIONS、GET 等
allow_headers=["*"],
)
app.include_router(router, prefix="/api/llm", tags=["chat"])
app.include_router(chatRouter, prefix="/api/llm", tags=["chat"])
app.include_router(reportRouter, prefix="/api/llm", tags=["chat"])
app.include_router(knowledgeRouter, prefix="/api/llm", tags=["chat"])
app.include_router(reportDataRouter, prefix="/api/llm", tags=["chat"])
app.include_router(serviceRouter, prefix="/api/llm", tags=["chat"])
app.include_router(botRouter, prefix="/api/llm", tags=["chat"])
+29 -1
View File
@@ -1,6 +1,34 @@
from langchain_community.chat_models.tongyi import ChatTongyi
from utils.Tools import all_tools
from langchain.chat_models import init_chat_model
from langchain_openai import ChatOpenAI
from openai import OpenAI
import os
from langchain_openai import OpenAIEmbeddings
# 通义千文Key
tongyiKey = "sk-9464b2498c184982a9fe9d2c2e725ab5"
# DeepSeekKey
deepseekKey = "sk-6129a200ae294b9f86553505191fa477"
llm = ChatTongyi(streaming=False, api_key="sk-fb46eefb6b404382a0a5325202e923a6")
llm = ChatTongyi(streaming=False, api_key=tongyiKey)
llm_with_tools = llm.bind_tools(all_tools)
llmThink = ChatOpenAI(
api_key=tongyiKey,
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
model="qwen-max",
stream = False
)
from langchain_community.embeddings import DashScopeEmbeddings
embeddings = DashScopeEmbeddings(
model="text-embedding-v3",
dashscope_api_key= tongyiKey,
)
# from langchain_deepseek import ChatDeepSeek
# llm = ChatDeepSeek(
# model="deepseek-reasoner",
# api_key=deepseekKey,
# api_base="https://api.deepseek.com"
# )
+35
View File
@@ -0,0 +1,35 @@
from langchain_milvus import BM25BuiltInFunction, Milvus
from config.llm import embeddings
URI = "http://10.10.10.9:19530"


def _build_vectorstore(collection_name, partition_key_field):
    """Create a Milvus vector store for one collection with the shared settings."""
    return Milvus(
        embedding_function=embeddings,
        connection_args={"uri": URI, "token": "root:Milvus", "db_name": "bbit_ai_lab"},
        collection_name=collection_name,
        index_params={"index_type": "FLAT", "metric_type": "L2"},
        consistency_level="Strong",
        auto_id=True,
        primary_field="id",
        text_field="text",
        vector_field="vector",
        partition_key_field=partition_key_field,
        enable_dynamic_field=True,
        drop_old=False,  # set to True to drop an existing collection with that name
    )


# Knowledge entries, partitioned by knowledge-base id.
knVectorstore = _build_vectorstore("knowledge", "kn_id")
# Memory entries, partitioned by AI id.
memVectorstore = _build_vectorstore("memory", "ai_id")
+64 -17
View File
@@ -1,22 +1,69 @@
import psycopg
from langchain_postgres import PostgresChatMessageHistory
from psycopg_pool import ConnectionPool
import logging
import time
from contextlib import contextmanager
from typing import Optional
import psycopg
from psycopg_pool import ConnectionPool
# conn = psycopg.connect("postgresql://postgres:123456@10.10.10.9/ktor2")
database_name = "ai_chat_history"
pool = ConnectionPool("postgresql://postgres:123456@10.10.10.9/ktor2")
logger = logging.getLogger("PGPool")
logger.setLevel(logging.INFO)
@contextmanager
def getConn():
with pool.connection() as temp:
temp.autocommit = True # 如果你想所有连接默认 autocommit
yield temp # 把 conn 暴露给外部使用
class PGPool:
"""
PostgreSQL 连接池封装
"""
def __init__(
    self,
    uri: str,
    min_size: int = 1,
    max_size: int = 20,
    max_idle: int = 30,
    max_lifetime: int = 300,
    timeout: int = 10,
    check: bool = False,
):
    """Create the underlying psycopg connection pool.

    :param uri: PostgreSQL connection URI
    :param min_size: forwarded to ``ConnectionPool(min_size=...)``
    :param max_size: forwarded to ``ConnectionPool(max_size=...)``
    :param max_idle: forwarded to ``ConnectionPool(max_idle=...)``
    :param max_lifetime: forwarded to ``ConnectionPool(max_lifetime=...)``
    :param timeout: forwarded to ``ConnectionPool(timeout=...)``
    :param check: when true, connections are health-checked before being
        handed out; a callable is forwarded unchanged
    """
    self.uri = uri
    # ConnectionPool's ``check`` argument is a callback, not a flag: translate
    # the boolean into the pool's built-in checker (or no checker at all).
    if callable(check):
        checker = check
    elif check:
        checker = ConnectionPool.check_connection
    else:
        checker = None
    self.pool = ConnectionPool(
        self.uri,
        min_size=min_size,
        max_size=max_size,
        max_idle=max_idle,
        max_lifetime=max_lifetime,
        timeout=timeout,
        check=checker,
    )
def init():
with getConn() as connection:
PostgresChatMessageHistory.create_tables(connection, database_name)
init()
@contextmanager
def getConn(self, retries: int = 2, delay: float = 1.0):
    """Yield a pooled connection in autocommit mode, retrying acquisition.

    Usage:
        with pg_pool.getConn() as conn:
            with conn.cursor() as cur:
                cur.execute(...)

    Only failures that happen before the connection is handed to the caller
    are retried. An error raised inside the caller's ``with`` block is
    re-raised as-is: a generator-based context manager may yield only once,
    so retrying after the yield would fail with a RuntimeError and hide the
    real error.

    :param retries: number of additional acquisition attempts after the first
    :param delay: seconds to sleep between attempts
    :raises psycopg.OperationalError: when no connection could be obtained
    """
    attempt = 0
    while True:
        handed_out = False
        try:
            with self.pool.connection() as conn:
                conn.autocommit = True
                handed_out = True
                yield conn
            return
        except psycopg.OperationalError as e:
            if handed_out:
                # The failure came from the caller's block (or from returning
                # the connection); it must not trigger a second yield.
                logger.error(f"SQL执行异常: {e}")
                raise
            if attempt >= retries:
                raise psycopg.OperationalError("无法获取数据库连接,多次重试失败") from e
            logger.warning(f"数据库连接异常: {e}. 尝试重试 ({attempt+1}/{retries})")
            self.pool.check()  # discard broken connections so new ones are created
            attempt += 1
            time.sleep(delay)
        except Exception as e:
            logger.error(f"SQL执行异常: {e}")
            raise
pg_pool = PGPool(
uri="postgresql://postgres:123456@10.10.10.9/ktor2",
min_size=1,
max_size=20,
)
+73 -3
View File
@@ -2,11 +2,81 @@ from langchain_community.utilities import SQLDatabase
from langchain_community.agent_toolkits import create_sql_agent
from langchain_community.chat_models.tongyi import ChatTongyi
# SQLAlchemy URI
uri = "mssql+pyodbc://f8_db_test:APN^QPr!K9@122.114.58.23/f8_db_test?driver=ODBC+Driver+18+for+SQL+Server&TrustServerCertificate=yes"
uri = "mssql+pyodbc://f8_db_test:APN^QPr!K9@122.114.58.23/f8_db_test?driver=ODBC+Driver+18+for+SQL+Server&TrustServerCertificate=yes&Encrypt=no"
# 建立数据库对象
ssDB = SQLDatabase.from_uri(
ssDBLC = SQLDatabase.from_uri(
uri,
include_tables=["NONGHU_INFO"], # 不加 schema 前缀试试
include_tables=["NONGHU_INFO","NONGHU_BLACKLIST",],
schema="dbo" # 显式指定 schema
)
import logging
import time
from contextlib import contextmanager
from sqlalchemy import create_engine, text
from sqlalchemy.exc import OperationalError as SQLOperationalError
from urllib.parse import quote_plus
logger = logging.getLogger("MSSQLPool")
logger.setLevel(logging.INFO)
class MSSQLPool:
    """SQL Server connection pool built on a SQLAlchemy engine."""

    def __init__(
        self,
        user: str,
        password: str,
        host: str,
        database: str,
        driver: str = "ODBC Driver 18 for SQL Server",
        encrypt: str = "no",
        trust_server_certificate: str = "yes",
        min_size: int = 1,
        max_size: int = 10,
        pool_timeout: int = 10,
    ):
        """Build the engine from the connection settings.

        ``min_size`` becomes the engine's ``pool_size`` and the difference to
        ``max_size`` becomes ``max_overflow``.
        """
        self.user = user
        self.password = quote_plus(password)  # URL-encode special characters
        self.host = host
        self.database = database
        self.driver = driver
        self.encrypt = encrypt
        self.trust_server_certificate = trust_server_certificate
        self.engine = create_engine(
            f"mssql+pyodbc://{self.user}:{self.password}@{self.host}/{self.database}"
            f"?driver={self.driver}&Encrypt={self.encrypt}&TrustServerCertificate={self.trust_server_certificate}",
            pool_size=min_size,
            max_overflow=max_size - min_size,
            pool_timeout=pool_timeout,
        )

    @contextmanager
    def getConn(self, retries: int = 2, delay: float = 1.0):
        """Yield an engine connection, retrying only the acquisition step.

        Errors raised inside the caller's ``with`` block are re-raised as-is:
        a generator-based context manager may yield only once, so retrying
        after the yield would fail with a RuntimeError and hide the real
        error.

        :param retries: number of additional acquisition attempts after the first
        :param delay: seconds to sleep between attempts
        :raises sqlalchemy.exc.OperationalError: the last connection error,
            once all attempts are used up
        """
        attempt = 0
        while True:
            handed_out = False
            try:
                with self.engine.connect() as conn:
                    handed_out = True
                    yield conn
                return
            except SQLOperationalError as e:
                if handed_out:
                    logger.error(f"SQL执行异常: {e}")
                    raise
                if attempt >= retries:
                    # SQLAlchemy's OperationalError cannot be constructed from
                    # a bare message, so the last real error is re-raised.
                    logger.error("无法获取数据库连接,多次重试失败")
                    raise
                logger.warning(f"数据库连接异常: {e}. 尝试重试 ({attempt+1}/{retries})")
                attempt += 1
                time.sleep(delay)
            except Exception as e:
                logger.error(f"SQL执行异常: {e}")
                raise
mssql_pool = MSSQLPool(
user="f8_db_test",
password="APN^QPr!K9",
host="122.114.58.23",
database="f8_db_test",
)
+109
View File
@@ -0,0 +1,109 @@
from config.milvus import knVectorstore,memVectorstore
from langchain.schema import Document
from datetime import datetime
from typing import List
from typing import List, Dict, Any
def _quote_filter_literal(value) -> str:
    """Return ``value`` as a double-quoted literal for a Milvus filter expression."""
    # Escape backslashes and double quotes so an id cannot terminate the
    # literal and alter the filter.
    escaped = str(value).replace("\\", "\\\\").replace('"', '\\"')
    return f'"{escaped}"'


def _any_of_filter(field: str, ids: List[str]) -> str:
    """Build ``(field == "a" or field == "b" ...)`` for the given ids."""
    clauses = [f"{field} == {_quote_filter_literal(item)}" for item in ids]
    return "(" + " or ".join(clauses) + ")"


def _format_hits(docs, label: str) -> str:
    """Join the non-empty page contents as numbered ``[label N]: text`` paragraphs."""
    parts = []
    for idx, doc in enumerate(docs, start=1):
        text = doc.page_content.strip()
        if text:
            # Numbering helps the LLM tell the entries apart.
            parts.append(f"[{label}{idx}]: {text}")
    return "\n\n".join(parts)


def get_knowledge_by_key_words(key_words: str, kn_ids: List[str]) -> str:
    """Search the knowledge collection within the given knowledge bases.

    Args:
        key_words: Query text for the similarity search.
        kn_ids: Knowledge-base ids the search is restricted to.

    Returns:
        Up to three hits formatted as one string, or a fixed "nothing found"
        message when ``kn_ids`` is empty.
    """
    if not kn_ids:
        return "未找到相关的知识。"
    # NOTE(review): entries are stored with an is_active flag that is not
    # filtered on here — confirm whether inactive entries should be returned.
    result = knVectorstore.similarity_search(
        query=key_words,
        k=3,  # number of hits to return
        expr=_any_of_filter("kn_id", kn_ids)
    )
    return _format_hits(result, "文档")


def get_memory_by_key_words(key_words: str, ai_ids: List[str]) -> str:
    """Search the memory collection for the given AIs.

    Args:
        key_words: Query text for the similarity search.
        ai_ids: AI ids whose memories may be searched.

    Returns:
        Up to five hits formatted as one string. An empty ``ai_ids`` yields an
        empty string: searching without a filter would return memories of
        every AI.
    """
    print("ai_id是:" , ai_ids)
    if not ai_ids:
        return ""
    result = memVectorstore.similarity_search(
        query=key_words,
        k=5,  # number of hits to return
        expr=_any_of_filter("ai_id", ai_ids)
    )
    return _format_hits(result, "记忆")


def get_knowledge_by_base_id(base_id: str):
    """List the entries of one knowledge base.

    Returns:
        A list of dicts with ``id``, ``text`` and ``is_active`` for up to 100
        entries of the given knowledge base.
    """
    expr = f"kn_id == {_quote_filter_literal(base_id)}"
    # NOTE(review): this relies on a similarity search with an empty query to
    # list entries by filter only — confirm the embedding backend accepts an
    # empty string.
    result = knVectorstore.similarity_search(
        query="",
        k=100,
        expr=expr
    )
    return [
        {
            "id": str(doc.metadata["id"]),
            "text": doc.page_content,
            "is_active": doc.metadata["is_active"],
        }
        for doc in result
    ]
def add_knowledge(text: str, is_active: bool, base_id: str, user_id: str):
    """Store one knowledge entry in the knowledge collection.

    The entry is tagged with its knowledge-base id, creator, creation
    timestamp (ISO format) and active flag. Returns whatever
    ``knVectorstore.add_documents`` returns.
    """
    metadata = {
        "kn_id": str(base_id),
        "created_by": str(user_id),
        "created_at": datetime.now().isoformat(),
        "is_active": is_active,
    }
    entry = Document(page_content=text, metadata=metadata)
    return knVectorstore.add_documents([entry])
def add_memory(ai_id:str,mem: str, user_id: str,is_active: bool):
    """Store one memory entry in the memory collection.

    The entry is tagged with the AI id, creator, creation timestamp (ISO
    format) and active flag. Returns whatever
    ``memVectorstore.add_documents`` returns.
    """
    metadata = {
        "ai_id": str(ai_id),
        "created_by": str(user_id),
        "created_at": datetime.now().isoformat(),
        "is_active": is_active,
    }
    entry = Document(page_content=mem, metadata=metadata)
    return memVectorstore.add_documents([entry])
+322 -28
View File
@@ -1,11 +1,16 @@
import psycopg
from langchain_postgres import PostgresChatMessageHistory
from config.pgDb import database_name,getConn
from config.pgDb import pg_pool
from config.ssDb import mssql_pool
from typing import List, Dict
import json
# ————————————————————————————————————————————————————AI角色———————————————————————————————
database_name = "ai_chat_history"
def get_ai_personality(ai_id: str):
with getConn() as conn:
with pg_pool.getConn() as conn:
with conn.cursor() as cur:
cur.execute(
"SELECT ai_personality FROM ai_chat_profiles WHERE id = %s",
@@ -17,8 +22,32 @@ def get_ai_personality(ai_id: str):
else:
return "你是一个乐于助人的AI助手,请保持中文简洁回答用户。"
def get_all_ai(user_id: str) -> List[Dict]:
with getConn() as conn:
def get_description(ai_id: str):
    """Return the ``description`` column of an AI profile.

    When no profile matches ``ai_id`` a default assistant prompt string is
    returned instead.
    """
    query = "SELECT description FROM ai_chat_profiles WHERE id = %s"
    with pg_pool.getConn() as conn, conn.cursor() as cur:
        cur.execute(query, (ai_id,))
        row = cur.fetchone()
        return row[0] if row else "你是一个乐于助人的AI助手,请保持中文简洁回答用户。"
def get_ai_available_kn_bases(ai_id: str) -> List[str]:
    """Return the knowledge bases an AI profile may use.

    Args:
        ai_id: Primary key of the row in ``ai_chat_profiles``.

    Returns:
        The ``available_kn_bases`` value of the profile, or an empty list
        when the profile does not exist or the column is NULL.
    """
    with pg_pool.getConn() as conn:
        row = conn.execute(
            "SELECT available_kn_bases FROM ai_chat_profiles WHERE id = %s",
            (ai_id,)
        ).fetchone()
    # fetchone() yields None for an unknown id; indexing it would raise.
    if not row or row[0] is None:
        return []
    return row[0]
def get_all_ai_bot(user_id: str, module: str) -> List[Dict]:
with pg_pool.getConn() as conn:
with conn.cursor() as cur:
# 查询用户角色
cur.execute(
@@ -31,27 +60,43 @@ def get_all_ai(user_id: str) -> List[Dict]:
user_roles = role_row[0]
# 查询 AI 角色 JSON 字段包含用户角色
cur.execute(
"""
SELECT id, name, welcome_words
SELECT id, title, description, welcome_words, ai_personality, available_report_tables, available_kn_bases
FROM ai_chat_profiles
WHERE availabel_roles::jsonb ?| %s
WHERE available_module = %s
AND is_active = TRUE
AND available_roles::jsonb ?| %s
""",
(user_roles,) # user_roles 是 list,比如 ["a", "b", "c"]
(module, user_roles)
)
rows = cur.fetchall()
return [
{
"id": row[0],
"name": row[1],
"welcomeWords": row[2],
}
for row in rows
]
result = []
for row in rows:
# row 索引对应 SELECT 字段顺序
id_, title, description, welcome_words, ai_personality, available_report_tables, available_kn_bases = row
# 解析 JSON
roles_json = ai_personality if ai_personality else {}
result.append({
"id": id_,
"title": title,
"description": description,
"welcome_words": welcome_words,
"name": roles_json.get("名字", ""),
"role": roles_json.get("性格", ""),
"service": roles_json.get("业务", ""),
"available_report_tables": available_report_tables,
"available_kn_bases": available_kn_bases
})
return result
# ————————————————————————————————————————————————————消息———————————————————————————————
def insert_message(session_id: str, isAI: bool, content: str):
with getConn() as conn:
with pg_pool.getConn() as conn:
history = PostgresChatMessageHistory(
database_name,
session_id,
@@ -62,34 +107,56 @@ def insert_message(session_id: str, isAI: bool, content: str):
else:
history.add_user_message(content)
def get_history(session_id: str):
with getConn() as conn:
simplified = []
with pg_pool.getConn() as conn:
history = PostgresChatMessageHistory(
database_name,
session_id,
sync_connection=conn
)
simplified = []
for msg in history.messages:
simplified.append({
"type": msg.type,
"content": msg.content
})
return simplified
def get_history_with_time(session_id: str, number: int):
    """Return the most recent messages of a session with their timestamps.

    Args:
        session_id: Session whose rows are read from ``ai_chat_history``.
        number: Maximum number of rows to return.

    Returns:
        A list of dicts with ``type``, ``created_at`` (ISO string) and
        ``content``, ordered newest first (``ORDER BY created_at DESC``).
    """
    with pg_pool.getConn() as conn:
        with conn.cursor() as cur:
            # Bind both values as parameters; they originate from request
            # data and must never be formatted into the SQL text.
            cur.execute(
                "SELECT message, created_at FROM ai_chat_history "
                "WHERE session_id = %s ORDER BY created_at DESC LIMIT %s",
                (session_id, number),
            )
            rows = cur.fetchall()
    simplified = []
    for msg_dict, created_at in rows:
        simplified.append({
            "type": msg_dict.get("type"),
            "created_at": created_at.isoformat(),
            "content": msg_dict.get("data", {}).get("content")
        })
    return simplified
# ————————————————————————————————————————————————————会话———————————————————————————————
def insert_session(user_id: str,ai_id:str, session_id: str,session_title: str):
with getConn() as coon:
def insert_session(user_id: str, ai_id: str, session_id: str, session_title: str, available_module):
with pg_pool.getConn() as coon:
with coon.cursor() as cur:
cur.execute(
"INSERT INTO ai_chat_sessions (id ,user_id, ai_id, title, created_at, updated_at) VALUES (%s, %s, %s, %s, NOW(), NOW())",
(session_id, user_id, ai_id, session_title )
"INSERT INTO ai_chat_sessions (id ,user_id, ai_id, title, available_module, created_at, updated_at) VALUES (%s, %s, %s, %s,%s, NOW(), NOW())",
(session_id, user_id, ai_id, session_title, available_module)
)
coon.commit()
def update_session_updated_at(session_id: str):
with getConn() as conn:
with pg_pool.getConn() as conn:
with conn.cursor() as cur:
cur.execute(
"UPDATE ai_chat_sessions SET updated_at = NOW() WHERE id = %s",
@@ -97,13 +164,18 @@ def update_session_updated_at(session_id: str):
)
conn.commit()
def get_sessions(user_id: str):
with getConn() as conn:
def get_sessions(user_id: str, available_module: str):
with pg_pool.getConn() as conn:
with conn.cursor() as cur:
cur.execute(
"SELECT id, title, updated_at FROM ai_chat_sessions WHERE user_id = %s ORDER BY updated_at DESC",
(user_id,)
"SELECT id, title, updated_at "
"FROM ai_chat_sessions "
"WHERE user_id = %s AND available_module = %s "
"ORDER BY updated_at DESC",
(user_id, available_module)
)
sessions = cur.fetchall()
return [
{
@@ -113,3 +185,225 @@ def get_sessions(user_id: str):
}
for row in sessions
]
# ————————————————————————————————————————————————————报表———————————————————————————————
def get_reports(user_id: str):
    """List the reports of a user that have ``is_masked = TRUE``.

    Returns:
        A list of ``{"id", "title"}`` dicts, newest first.
    """
    query = "SELECT id, title FROM ai_reports WHERE created_by = %s AND is_masked = TRUE ORDER BY created_at DESC"
    with pg_pool.getConn() as conn, conn.cursor() as cur:
        cur.execute(query, (user_id,))
        found = cur.fetchall()
        return [{"id": report_id, "title": title} for report_id, title in found]
def save_report(id: str, user_id: str, title: str, sql: str):
    """Insert a report row with ``is_masked = FALSE`` and commit.

    Returns:
        The id returned by the INSERT's ``RETURNING id`` clause.
    """
    statement = "INSERT INTO ai_reports (id, title, sql, created_at, created_by , is_masked) VALUES (%s, %s, %s, NOW(), %s, FALSE) RETURNING id"
    values = (id, title, sql, user_id)
    with pg_pool.getConn() as conn, conn.cursor() as cur:
        cur.execute(statement, values)
        inserted = cur.fetchone()
        report_id = inserted[0]
        conn.commit()
        return report_id
def maked_report(report_id: str, title: str):
    """Set a report's title, flag it ``is_masked = TRUE`` and commit."""
    statement = "UPDATE ai_reports SET title = %s, is_masked = TRUE WHERE id = %s"
    with pg_pool.getConn() as conn, conn.cursor() as cur:
        cur.execute(statement, (title, report_id))
        conn.commit()
def getSQL(reportId: str):
    """Return the SQL text stored for a report, or "" when no row matches."""
    with pg_pool.getConn() as conn, conn.cursor() as cur:
        cur.execute(
            "SELECT sql FROM ai_reports WHERE id = %s",
            (reportId,)
        )
        record = cur.fetchone()
        return record[0] if record else ""
def get_available_tables_str(aiId: str):
    """Describe the report tables an AI profile may query, as plain text.

    For each active table listed in the profile's
    ``available_report_tables`` the text holds the table name and
    description, a header line, and one line per active field.

    Returns:
        The assembled text, or "无数据库表可用" when the profile is missing
        or lists no tables.
    """
    with pg_pool.getConn() as conn, conn.cursor() as cur:
        # 1. Tables the profile is allowed to use.
        cur.execute(
            "SELECT available_report_tables FROM ai_chat_profiles WHERE id = %s",
            (aiId,)
        )
        profile_row = cur.fetchone()
        if not profile_row:
            return "无数据库表可用"
        allowed_ids = profile_row[0]  # assumed to be a list — TODO confirm
        if not allowed_ids:
            return "无数据库表可用"

        # 2. One placeholder per id for the IN (...) clause.
        placeholders = ','.join('%s' for _ in range(len(allowed_ids)))
        sql_query = f"""
SELECT id, name, description
FROM ai_reports_tables
WHERE id IN ({placeholders}) AND is_active = TRUE
"""
        cur.execute(sql_query, allowed_ids)
        tables = cur.fetchall()

        # 3. Fields of every table, collected as text pieces.
        pieces = []
        for table in tables:
            cur.execute(
                "SELECT name, type, description FROM ai_reports_fields WHERE table_id = %s AND is_active = TRUE",
                (table[0],)
            )
            fields = cur.fetchall()
            pieces.append(f"{table[1]}{table[2]}\n")
            pieces.append("字段名,数据类型,描述\n")
            pieces.extend(f"{field[0]},{field[1]}, {field[2]}\n" for field in fields)
            pieces.append("\n")
        return "".join(pieces)
# -------------------报表数据源------------------
# 获取表
def get_available_tables():
    """Return every row of ``ai_reports_tables`` as a dict.

    Returns:
        A list of ``{"id", "name", "description", "is_active"}`` dicts.
    """
    columns = ("id", "name", "description", "is_active")
    with pg_pool.getConn() as conn, conn.cursor() as cursor:
        cursor.execute(
            "SELECT id, name, description,is_active FROM ai_reports_tables",
        )
        return [
            {key: record[pos] for pos, key in enumerate(columns)}
            for record in cursor.fetchall()
        ]
# 新增表
def add_table(name, description, user_id):
    """Insert a report data-source table and return its generated id.

    Args:
        name: Table name.
        description: Table description.
        user_id: Stored in the ``create_by`` column.

    Returns:
        The id produced by ``RETURNING id``.
    """
    with pg_pool.getConn() as conn:
        with conn.cursor() as cursor:
            cursor.execute(
                """
INSERT INTO ai_reports_tables (name, description, create_by)
VALUES (%s, %s, %s)
RETURNING id
""",
                (name, description, user_id)
            )
            new_id = cursor.fetchone()[0]  # id returned by the INSERT
        # Commit explicitly, as the other write helpers in this module do
        # (e.g. save_report); otherwise the insert may be rolled back.
        conn.commit()
        return new_id
# 获取字段
def get_fields_by_table_id(table_id):
    """Return every field row of one report table as a dict.

    Returns:
        A list of ``{"id", "name", "type", "description", "is_active"}``
        dicts.
    """
    columns = ("id", "name", "type", "description", "is_active")
    with pg_pool.getConn() as conn, conn.cursor() as cursor:
        cursor.execute(
            "SELECT id, name, type, description, is_active FROM ai_reports_fields WHERE table_id = %s",
            (table_id,),
        )
        return [
            {key: record[pos] for pos, key in enumerate(columns)}
            for record in cursor.fetchall()
        ]
# 新增字段
def add_field(name, type, description, is_active, table_id, user_id):
    """Insert a field of a report table and return its generated id.

    Args:
        name: Field name.
        type: Field data type (parameter name kept for existing callers).
        description: Field description.
        is_active: Whether the field is enabled.
        table_id: Owning row in ``ai_reports_tables``.
        user_id: Stored in the ``create_by`` column.

    Returns:
        The id produced by ``RETURNING id``.
    """
    with pg_pool.getConn() as conn:
        with conn.cursor() as cursor:
            cursor.execute(
                "INSERT INTO ai_reports_fields (name,type,description, is_active, create_by, table_id) VALUES (%s, %s, %s, %s, %s, %s) RETURNING id",
                (name, type, description, is_active, user_id, table_id)
            )
            new_id = cursor.fetchone()[0]  # id returned by the INSERT
        # Commit explicitly, as the other write helpers in this module do
        # (e.g. save_report); otherwise the insert may be rolled back.
        conn.commit()
        return new_id
# 新增报表智能体
def insert_bot(title: str, description: str, welcome_words: str, ai_personality: str, available_module: str,
               available_report_tables: str, available_kn_bases: str, user_id: str):
    """Insert a new AI profile and return its generated id.

    The profile is always created with ``available_roles`` set to the JSON
    array ``["user"]``.

    Returns:
        The id produced by ``RETURNING id``.
    """
    with pg_pool.getConn() as conn:
        with conn.cursor() as cursor:
            available_roles = json.dumps(['user'])
            cursor.execute(
                """
INSERT INTO ai_chat_profiles
(available_module,available_roles, title, description, welcome_words, ai_personality, available_report_tables, available_kn_bases, created_by, created_at)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, now())
RETURNING id
""",
                (available_module, available_roles, title, description, welcome_words, ai_personality,
                 available_report_tables, available_kn_bases, user_id)
            )
            report_id = cursor.fetchone()[0]
        # Commit explicitly, as the other write helpers in this module do
        # (e.g. save_report); otherwise the insert may be rolled back.
        conn.commit()
        return report_id
# 更新报表智能体
def update_bot(id: str, title: str, description: str, welcome_words: str, ai_personality: str, available_module: str,
               available_report_tables: str, available_kn_bases: str, user_id: str):
    """Update an existing AI profile in place.

    Args:
        id: Primary key of the profile to update.
        user_id: Stored in ``updated_by``; ``updated_at`` is set to NOW().

    The remaining arguments overwrite the columns of the same name.
    """
    with pg_pool.getConn() as conn:
        with conn.cursor() as cursor:
            cursor.execute("""
UPDATE ai_chat_profiles
SET title = %s,
description = %s,
ai_personality = %s,
welcome_words = %s,
available_report_tables = %s,
available_kn_bases = %s,
available_module = %s,
updated_at = NOW(),
updated_by = %s
WHERE id = %s
""",
                           (title, description, ai_personality, welcome_words, available_report_tables,
                            available_kn_bases, available_module, user_id, id)
                           )
        # Commit explicitly, as the other write helpers in this module do
        # (e.g. maked_report); otherwise the update may be rolled back.
        conn.commit()
# ————————————————————————————————————————————————————知识库———————————————————————————————
def get_available_knowledge_bases(available_module: str):
    """List the knowledge bases whose module list contains a given module.

    Args:
        available_module: Module name that must appear in the row's
            ``available_module`` JSON array.

    Returns:
        A list of ``{"id", "name", "description", "is_active"}`` dicts.
    """
    # Serialize with json.dumps so names containing quotes or backslashes
    # still produce a valid JSON array for the @> containment test.
    module_filter = json.dumps([available_module], ensure_ascii=False)
    with pg_pool.getConn() as conn:
        with conn.cursor() as cursor:
            cursor.execute(
                """
SELECT id, name, description, is_active
FROM ai_knowledge
WHERE available_module::jsonb @> %s::jsonb
""",
                (module_filter,)
            )
            return [{"id": row[0], "name": row[1], "description": row[2], "is_active": row[3]} for row in
                    cursor.fetchall()]
def add_knowledge_base(name, description, user_id):
    """Insert a knowledge base and return its generated id.

    Args:
        name: Knowledge base name.
        description: Knowledge base description.
        user_id: Stored in the ``created_by`` column.

    Returns:
        The id produced by ``RETURNING id``.
    """
    with pg_pool.getConn() as conn:
        with conn.cursor() as cursor:
            cursor.execute(
                """
INSERT INTO ai_knowledge (name, description, created_by, created_at)
VALUES (%s, %s, %s, now())
RETURNING id
""",
                (name, description, user_id)
            )
            new_id = cursor.fetchone()[0]  # id returned by the INSERT
        # Commit explicitly, as the other write helpers in this module do
        # (e.g. save_report); otherwise the insert may be rolled back.
        conn.commit()
        # The router wraps this value in its response, so it must be returned.
        return new_id
+37
View File
@@ -0,0 +1,37 @@
from sqlalchemy import text
from sqlalchemy import text
from config.pgDb import pg_pool
from config.ssDb import mssql_pool
from sqlalchemy import text
def executeSQL(sql: str):
    """
    Run a SQL statement on the SQL Server pool and return the rows,
    each converted to a dict.
    """
    with mssql_pool.getConn() as conn:
        statement = text(sql)
        rows = conn.execute(statement)
        # Row objects expose their column mapping via ``_mapping``.
        records = []
        for row in rows:
            records.append(dict(row._mapping))
        return records
def get_company_list(user_id: str):
    """Return the tenants (companies) visible to a user.

    The user's ``bbit_tenant_id`` is read from PostgreSQL. When it is set,
    only that tenant is returned; otherwise every row of
    ``dbo.POC_TENANTS`` is returned.

    Returns:
        A list of ``{"id": str, "name": ...}`` dicts.
    """
    # Step 1: look up the tenant id in PostgreSQL.
    with pg_pool.getConn() as pg_conn, pg_conn.cursor() as cur:
        cur.execute("SELECT bbit_tenant_id FROM users WHERE id = %s", (user_id,))
        user_row = cur.fetchone()
        tenant_id = None if not user_row else user_row[0]

    # Step 2: query the tenant table in SQL Server.
    if not tenant_id:
        query = text("SELECT Id, Name FROM dbo.POC_TENANTS")
        params = {}
    else:
        query = text("SELECT Id, Name FROM dbo.POC_TENANTS WHERE Id = :tenant_id")
        params = {"tenant_id": tenant_id}

    with mssql_pool.getConn() as mssql_conn:
        tenants = mssql_conn.execute(query, params).fetchall()
        return [{"id": str(tenant[0]), "name": tenant[1]} for tenant in tenants]
+1 -1
View File
@@ -5,7 +5,7 @@ from langchain.prompts import PromptTemplate
chatPrompt = PromptTemplate(
input_variables=["aiRole", "history", "userInput"],
template = """
是一个人,用户画像为:{aiRole}
用户画像为:{aiRole}
你需要基于你的角色性格,使用中文回答用户。
聊天历史:
+82
View File
@@ -0,0 +1,82 @@
from langchain.prompts import PromptTemplate
from config.llm import llm,llmThink
import db.milvus as milvus
import db.postgres as pg
import json
# Prompt for the memory filter: given the business description (ai_role) and
# the recent chat record, the model must answer "yes" or "no" on whether the
# user's latest reply is worth keeping as long-term memory.
memPathPrompt = PromptTemplate(
    input_variables=["ai_role", "CHAT_RECORD"],
    template = """
你是一个记忆筛选器,负责判断最近对话的信息中,用户的回复内容是否对业务具有长期价值或潜在价值,或者可以帮助形成用户画像。
首先,请仔细阅读以下关于你业务的描述:
<ai_role>
{ai_role}
</ai_role>
现在,请仔细阅读以下你与用户的聊天记录:
<聊天记录>
{CHAT_RECORD}
</聊天记录>
请仔细考虑以下标准:
1. 长期价值:用户最新回复信息是否能为你的业务提供知识积累、经验总结、数据支持。
2. 相关性:用户最新回复是否与业务核心需求、目标、流程或潜在业务场景相关。
3. 潜在可用性:用户最新回复是否可能在未来的业务场景中被重复使用、参考或触发进一步操作。
你需要根据以上标准给出判断并得出"yes""no"
yes:用户最新回复具有直接或潜在长期价值,值得保留。
no:用户最新回复价值有限或几乎不会在未来业务中使用。
回复不要带任何标点符号以及空格、换行符。
请给出你的判断结果:
"""
)
# Chain that feeds the filter prompt to the llmThink model.
memPathChain = memPathPrompt | llmThink
# Prompt that summarizes the chat record into a memory entry. The template
# references both {CHAT_RECORD} and {ai_role}, so both are declared as inputs.
memPrompt = PromptTemplate(
    input_variables=["CHAT_RECORD", "ai_role"],
    template = """
你的任务是对给定的聊天记录进行关键信息的记忆总结。请仔细阅读以下聊天记录,并按照要求进行总结:
<聊天记录>
{CHAT_RECORD}
</聊天记录>
在总结时,请遵循以下指南:
1. 提取聊天记录中的用户所说的关键信息,包括主要话题、重要观点、达成的共识或决定等。
2. 用简洁明了的语言进行总结有价值的信息,避免冗长和复杂的表述。
3. 确保总结内容准确反映聊天记录中用户的核心内容,并尽可能简短。
4. 总结内容应包含时间,并确保时间是准确的。
5. 你需要针对你的业务场景{ai_role},展开对用户最后回复的总结。
请生成你的总结,以用户、时间开头:
"""
)
# Chain that feeds the summary prompt to the llmThink model.
memChain = memPrompt | llmThink
def take_memory(ai_id:str,sessionId: str,user_id:str, max_retry=3):
    """Decide whether the latest exchange is worth remembering and store it.

    Reads the last 10 messages of the session, asks the filter chain whether
    the user's latest reply has long-term value for the AI's business, and
    if the answer is "yes" stores an LLM-written summary in the memory
    vector store.

    Args:
        ai_id: AI profile the memory belongs to.
        sessionId: Chat session to read history from.
        user_id: Recorded as the creator of the memory.
        max_retry: Currently unused; kept for interface compatibility.
    """
    history = pg.get_history_with_time(sessionId,10)
    print("获取的历史记录:",history)
    ai_service = pg.get_description(ai_id)
    if not ai_service:
        # No description (empty or NULL): fall back to the "业务" entry of
        # the personality. The personality may also be a plain default
        # string, which has no such entry.
        personality = pg.get_ai_personality(ai_id)
        if isinstance(personality, dict):
            ai_service = personality.get("业务", "")
        else:
            ai_service = ""
        if not ai_service:
            # Nothing describes the AI, so memory value cannot be judged.
            print("AI没有任何描述,无法对记忆价值进行判断")
            return
    print("获取的描述是:", ai_service)
    choice = memPathChain.invoke({
        "ai_role": ai_service,
        "CHAT_RECORD": history,
    }).content.strip().lower()
    print("记忆判断器判断的结果是:", choice)
    if choice == "yes":
        # Summarize the conversation. Only whitespace is stripped: the
        # summary is stored content and must keep its original casing.
        memory = memChain.invoke({
            "CHAT_RECORD": history,
            "ai_role": ai_service,
        }).content.strip()
        print("记忆生成结果是:", memory)
        milvus.add_memory(mem = memory,user_id = user_id, is_active = True, ai_id = ai_id)
    return
+69
View File
@@ -0,0 +1,69 @@
from config.llm import llm
from langchain.prompts import PromptTemplate
from config.ssDb import ssDBLC
from langchain_community.agent_toolkits import create_sql_agent
#______________________________________________________________SQL描述_____________________________________________________________________
# Prompt that asks the model to summarize a SQL Server query in plain
# business language, to be used as a title.
sqlDescriptionPrompt = PromptTemplate(
    input_variables=["sql"],
    template = """
你是一个SQL专家,精通SQLServer数据库。请把一下SQL查询语句用通俗易懂的中文进行总结。
SQL语句:{sql}
有以下要求:
1. 不要任何解释
2. 不能有标点符号
3. 不能有markdown语法
4. 要用业务语言描述,不能有专业语句例如SQL表名等
请生成你认为合适的标题,:
"""
)
# Chain that feeds the description prompt to the llm model.
sqlDescriptionChain = sqlDescriptionPrompt | llm
def get_sql_description_response( sql: str) -> str:
    """Run the SQL-description chain on one SQL statement.

    Returns the chain's result unchanged.
    NOTE(review): the report router reads ``.content`` on this value, so it
    is presumably a message object rather than a ``str`` — confirm.
    """
    chain_input = {"sql": sql}
    return sqlDescriptionChain.invoke(chain_input)
#______________________________________________________________第一次生成SQL_____________________________________________________________________
# Prompt that asks the model to write a SQL Server query for a user request.
sqlPrompt = PromptTemplate(
    input_variables=["userInput"],
    template = """
你是一个SQL专家,精通SQLServer数据库。
请根据用户的需求,生成相应的SQL查询语句。
只需要返回SQL语句,不要任何解释。
用户需求:{userInput}
请生成SQL语句:
"""
)
# Chain that feeds the SQL-generation prompt to the llm model.
sqlChain = sqlPrompt | llm
# Tool-calling SQL agent bound to the ssDBLC database object; this is what
# get_chat_sql_response invokes.
agent = create_sql_agent(
    llm=llm,
    db=ssDBLC,
    agent_type="tool-calling",
    verbose=True
)
# def get_chat_sql_response2( userInput: str) -> str:
# return sqlChain.invoke({
# "userInput": userInput
# })
def get_chat_sql_response( userInput: str) -> str:
    """Send the user's request to the SQL agent and return its "output" entry."""
    agent_result = agent.invoke({"input": userInput})
    return agent_result["output"]
#______________________________________________________________改进SQL_____________________________________________________________________
# Prompt that asks the model to improve an existing SQL statement according
# to a user request; it needs both {sql} and {userInput}.
sqlImprovePrompt = PromptTemplate(
    input_variables=["userInput", "sql"],
    template = """
你是一个SQL专家,精通SQLServer数据库。
请根据用户的需求,改进已有的SQL查询语句。
只需要返回改进后的SQL语句,不要任何解释。
已有SQL{sql}
用户需求:{userInput}
"""
)
# Chain that feeds the SQL-improvement prompt to the llm model.
sqlImproveChain = sqlImprovePrompt | llm
def get_chat_sql_improve_response( userInput: str, sql: str = "") -> str:
    """Ask the model to improve an existing SQL statement.

    Args:
        userInput: The user's change request.
        sql: The existing SQL to improve. Defaults to "" so existing
            one-argument callers keep working.

    Returns:
        The result of invoking the improvement chain, unchanged.
    """
    # The prompt template declares both variables; omitting "sql" makes the
    # invoke fail on the missing input.
    return sqlImproveChain.invoke({
        "userInput": userInput,
        "sql": sql,
    })
+1
View File
@@ -10,6 +10,7 @@ titlePrompt = PromptTemplate(
2. 直接概括本次对话的核心内容。
3. 避免使用笼统或无意义的词语,如“讨论”、“聊天”等。
4. 保持自然、易懂、专业或有趣(可根据场景调整风格)。
5. 不能出现标点符号。
用户原话:"{userStr}"
"""
)
+13
View File
@@ -0,0 +1,13 @@
from pydantic import BaseModel
# Request body of the /saveBot endpoint: the editable fields of an AI profile.
class AIProfilesRequest(BaseModel):
    id: str | None = None  # when set, /saveBot updates this profile; otherwise it inserts a new one
    name: str  # stored in the ai_personality JSON under "名字"
    available_kn_bases:list[str]  # serialized to JSON before being stored
    available_report_tables:list[str]  # serialized to JSON before being stored
    description: str
    role: str  # stored in the ai_personality JSON under "性格"
    service: str  # stored in the ai_personality JSON under "业务"
    welcome_words: str
    title: str
    available_module: str
+3 -2
View File
@@ -1,11 +1,12 @@
from fastapi import FastAPI
from pydantic import BaseModel
from typing import Generic, TypeVar, Optional, List
from pydantic.generics import GenericModel
from pydantic import BaseModel
T = TypeVar("T")
# 定义通用响应结构
class BaseResponse(GenericModel, Generic[T]):
class BaseResponse(BaseModel, Generic[T]):
status: bool = True
message: str = "操作成功"
data: Optional[T] = None
@@ -0,0 +1,7 @@
from pydantic import BaseModel
# Request body of the /chatWithReport endpoint.
class ChatWithReportRequest(BaseModel):
    aiId: str
    companyId: str  # forwarded to the DB agent as tenant_id
    reportId: str | None = None  # when set, the chat starts from that report's stored SQL
    userInput: str
@@ -0,0 +1,6 @@
from pydantic import BaseModel
from typing import Optional
# Request body of the /addKnowledge endpoint.
class KnowledgeAddRequest(BaseModel):
    text: str  # content of the knowledge entry
    is_active: Optional[bool] = True
    knowledge_base_id: str  # knowledge base the entry is stored under
@@ -0,0 +1,6 @@
from pydantic import BaseModel
from typing import Optional
# Request body of the /addKnowledgeBase endpoint.
class KnowledgeBaseAddRequest(BaseModel):
    name: str
    description: Optional[str] = None
@@ -0,0 +1,8 @@
from pydantic import BaseModel
# Request body of the /addField endpoint: one field of a report table.
class ReportFieldAddRequest(BaseModel):
    name: str
    type: str  # data type of the field
    description: str
    is_active: bool
    table_id: str  # report table the field belongs to
@@ -0,0 +1,5 @@
from pydantic import BaseModel
# Request body of the /addTable endpoint.
class ReportTableAddRequest(BaseModel):
    name: str
    description: str | None = None
+4
View File
@@ -0,0 +1,4 @@
from pydantic import BaseModel
# Request body of the /saveReport endpoint.
class SaveReportRequest(BaseModel):
    reportId: str | None = None  # report whose stored SQL gets a generated title
+76
View File
@@ -0,0 +1,76 @@
from models.ChatRequest import ChatRequest
from models.ChatWithReportRequest import ChatWithReportRequest
from models.BaseResponse import BaseResponse
import uuid
import db.postgres as pg
import db.sqlserver as sqlserver
import uuid
import threading
from fastapi import APIRouter, Depends
from uuid import UUID
from config.security import get_user_id_from_token
botRouter = APIRouter()
from llm.chatLLM import get_chat_response
from llm.titleChain import get_title
from llm.sqlLLM import get_sql_description_response,get_chat_sql_response,get_chat_sql_improve_response
from models.SaveReportRequest import SaveReportRequest
from models.AIProfilesRequest import AIProfilesRequest
import json
@botRouter.get("/aiListForService")
def getAiList(user_id: UUID = Depends(get_user_id_from_token)):
    # Bots of the "service" module that pg.get_all_ai_bot returns for the caller.
    if user_id:
        bots = pg.get_all_ai_bot(user_id, "service")
        return BaseResponse(data=bots)
    return {"error": "userId is required"}
@botRouter.get("/aiListForReport")
def getAiList(user_id: UUID = Depends(get_user_id_from_token)):
    # Bots of the "report" module that pg.get_all_ai_bot returns for the caller.
    if user_id:
        bots = pg.get_all_ai_bot(user_id, "report")
        return BaseResponse(data=bots)
    return {"error": "userId is required"}
@botRouter.get("/aiListForBot")
def getAiList(user_id: UUID = Depends(get_user_id_from_token)):
    # Bots of the "bot" module that pg.get_all_ai_bot returns for the caller.
    if user_id:
        bots = pg.get_all_ai_bot(user_id, "bot")
        return BaseResponse(data=bots)
    return {"error": "userId is required"}
# 保存智能体
@botRouter.post("/saveBot")
def saveReportBot(bot: AIProfilesRequest,user_id: UUID = Depends(get_user_id_from_token)):
    # Create the profile, or update it when the request carries an id.
    if not user_id:
        return {"error": "userId is required"}
    print(bot)
    # name / role / service are packed into one JSON personality document.
    ai_personality_json = json.dumps(
        {"名字": bot.name, "性格": bot.role, "业务": bot.service},
        ensure_ascii=False,
    )
    available_report_tables_json = json.dumps(bot.available_report_tables, ensure_ascii=False)
    available_kn_bases_json = json.dumps(bot.available_kn_bases, ensure_ascii=False)

    # Arguments shared by the insert and the update call.
    shared_fields = {
        "title": bot.title,
        "description": bot.description,
        "welcome_words": bot.welcome_words,
        "ai_personality": ai_personality_json,
        "available_kn_bases": available_kn_bases_json,
        "available_report_tables": available_report_tables_json,
        "available_module": bot.available_module,
        "user_id": user_id,
    }
    if not bot.id:
        pg.insert_bot(**shared_fields)
    else:
        pg.update_bot(id=bot.id, **shared_fields)
    return BaseResponse(data= None)
+27 -29
View File
@@ -1,63 +1,61 @@
from models.ChatRequest import ChatRequest
from models.ChatWithReportRequest import ChatWithReportRequest
from models.BaseResponse import BaseResponse
import uuid
import db.postgres as db
import db.postgres as pg
import db.sqlserver as sqlserver
import uuid
import threading
from fastapi import APIRouter, Depends
from uuid import UUID
from config.security import get_user_id_from_token
router = APIRouter()
chatRouter = APIRouter()
from agent.dataAgent import get_graph_output
from llm.chatLLM import get_chat_response
from llm.titleChain import get_title
from llm.dataLLM import get_graph_output
# def async_db_task(func, *args, **kwargs):
# """将数据库操作放到后台线程执行"""
# threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True).start()
@router.get("/history")
from llm.sqlLLM import get_sql_description_response,get_chat_sql_response,get_chat_sql_improve_response
from models.SaveReportRequest import SaveReportRequest
from models.AIProfilesRequest import AIProfilesRequest
import json
# 对话历史记录
@chatRouter.get("/history")
def getHistory(sessionId: str):
return BaseResponse(data=db.get_history(sessionId))
return BaseResponse(data=pg.get_history(sessionId))
@router.get("/aiList")
def getAiList(user_id: UUID = Depends(get_user_id_from_token)):
if not user_id:
return {"error": "userId is required"}
return BaseResponse(data=db.get_all_ai(user_id))
@router.get("/sessions")
# 对话列表
@chatRouter.get("/sessionsForBot")
def getSessions(user_id: UUID = Depends(get_user_id_from_token)):
if not user_id:
return {"error": "userId is required"}
return BaseResponse(data=db.get_sessions(user_id))
return BaseResponse(data=pg.get_sessions(user_id,'bot'))
@router.post("/chat")
@chatRouter.post("/chatForBot")
def chat(req: ChatRequest, user_id: UUID = Depends(get_user_id_from_token)):
if not user_id:
return {"error": "userId is required"}
if not req.aiId:
return {"error": "aiId is required"}
sessionName = get_title(req.userInput)
session_name = get_title(req.userInput)
# 如果没有 sessionId 就新建
if not req.sessionId:
isNewSession = True
is_new_session = True
req.sessionId = str(uuid.uuid4())
db.insert_session(user_id,req.aiId, req.sessionId, sessionName)
pg.insert_session(user_id,req.aiId, req.sessionId, session_name,"bot")
else:
isNewSession = False
db.update_session_updated_at(req.sessionId)
is_new_session = False
pg.update_session_updated_at(req.sessionId)
# 插入用户消息
db.insert_message(req.sessionId, False, req.userInput)
pg.insert_message(req.sessionId, False, req.userInput)
# 调用 LLM
if req.aiId == "9d157dd1-921b-c768-5b90-3e903b50f6f9":
# 数据专家AI
answer = get_graph_output(aiRole=db.get_ai_personality(req.aiId),history=db.get_history(req.sessionId), userInput= req.userInput)
answer = get_graph_output(aiRole=pg.get_ai_personality(req.aiId),history=pg.get_history(req.sessionId), userInput= req.userInput)
else:
answer = get_chat_response(aiRole=db.get_ai_personality(req.aiId),history=db.get_history(req.sessionId), userInput= req.userInput).content
answer = get_chat_response(aiRole=pg.get_ai_personality(req.aiId),history=pg.get_history(req.sessionId), userInput= req.userInput).content
# 插入 AI 回复
db.insert_message(req.sessionId, True, answer)
pg.insert_message(req.sessionId, True, answer)
return BaseResponse(data={"session_name":session_name,"isNewSession":is_new_session,"content":answer,"sessionId": req.sessionId})
return BaseResponse(data={"sessionName":sessionName,"isNewSession":isNewSession,"content":answer,"sessionId": req.sessionId})
+54
View File
@@ -0,0 +1,54 @@
from models.ChatRequest import ChatRequest
from models.ChatWithReportRequest import ChatWithReportRequest
from models.BaseResponse import BaseResponse
import uuid
import db.postgres as pg
import db.sqlserver as sqlserver
import uuid
import threading
from fastapi import APIRouter, Depends
from uuid import UUID
from config.security import get_user_id_from_token
reportDataRouter = APIRouter()
from llm.chatLLM import get_chat_response
from llm.titleChain import get_title
from llm.sqlLLM import get_sql_description_response, get_chat_sql_response, get_chat_sql_improve_response
from models.SaveReportRequest import SaveReportRequest
from models.ReportTableAddRequest import ReportTableAddRequest
from models.ReportFieldAddRequest import ReportFieldAddRequest
# 获取表格列表
@reportDataRouter.get("/tableList")
def tableList(user_id: UUID = Depends(get_user_id_from_token)):
    # All report data-source tables.
    if user_id:
        tables = pg.get_available_tables()
        return BaseResponse(data=tables)
    return {"error": "userId is required"}
# 获取字段列表
@reportDataRouter.get("/fieldList")
def fieldList(tableId: str, user_id: UUID = Depends(get_user_id_from_token)):
    # Fields of one report table; both the caller and the table id are required.
    for present, label in ((user_id, "userId"), (tableId, "tableId")):
        if not present:
            return {"error": f"{label} is required"}
    fields = pg.get_fields_by_table_id(tableId)
    return BaseResponse(data=fields)
# 新增表
@reportDataRouter.post("/addTable")
def addTable(data: ReportTableAddRequest, user_id: UUID = Depends(get_user_id_from_token)):
    # Register a new report table and answer with pg.add_table's result.
    if user_id:
        created = pg.add_table(data.name, data.description, user_id)
        return BaseResponse(data=created)
    return {"error": "userId is required"}
# 新增字段
@reportDataRouter.post("/addField")
def addField(data: ReportFieldAddRequest, user_id: UUID = Depends(get_user_id_from_token)):
    # Register a new field on a report table and answer with pg.add_field's result.
    if user_id:
        created = pg.add_field(
            data.name, data.type, data.description, data.is_active, data.table_id, user_id
        )
        return BaseResponse(data=created)
    return {"error": "userId is required"}
+48
View File
@@ -0,0 +1,48 @@
from models.ChatRequest import ChatRequest
from models.ChatWithReportRequest import ChatWithReportRequest
from models.BaseResponse import BaseResponse
import uuid
import db.postgres as pg
import db.sqlserver as sqlserver
import db.milvus as milvus
import uuid
import threading
from fastapi import APIRouter, Depends
from uuid import UUID
from config.security import get_user_id_from_token
knowledgeRouter = APIRouter()
from llm.chatLLM import get_chat_response
from llm.titleChain import get_title
from llm.sqlLLM import get_sql_description_response,get_chat_sql_response,get_chat_sql_improve_response
from models.SaveReportRequest import SaveReportRequest
from models.KnowledgeAddRequest import KnowledgeAddRequest
from models.KnowledgeBaseAddRequest import KnowledgeBaseAddRequest
# 获取知识库列表
@knowledgeRouter.get("/knowledgeBaseListForService")
def knowledge_base_list(user_id: UUID = Depends(get_user_id_from_token)):
    # Knowledge bases whose module list contains 'service'.
    if user_id:
        bases = pg.get_available_knowledge_bases('service')
        return BaseResponse(data=bases)
    return {"error": "userId is required"}
# 新增知识库
@knowledgeRouter.post("/addKnowledgeBase")
def add_knowledge_base(data: KnowledgeBaseAddRequest, user_id: UUID = Depends(get_user_id_from_token)):
    # Create a knowledge base and answer with pg.add_knowledge_base's result.
    if user_id:
        created = pg.add_knowledge_base(data.name, data.description, user_id)
        return BaseResponse(data=created)
    return {"error": "userId is required"}
# 获取知识列表
@knowledgeRouter.get("/knowledgeList")
def knowledge_list(knowledgeBaseId: str, user_id: UUID = Depends(get_user_id_from_token)):
    # Entries of one knowledge base; both the caller and the base id are required.
    for present, label in ((user_id, "userId"), (knowledgeBaseId, "knowledgeBaseId")):
        if not present:
            return {"error": f"{label} is required"}
    entries = milvus.get_knowledge_by_base_id(knowledgeBaseId)
    return BaseResponse(data=entries)
# 新增知识
@knowledgeRouter.post("/addKnowledge")
def add_knowledge(data: KnowledgeAddRequest, user_id: UUID = Depends(get_user_id_from_token)):
    # Store one knowledge entry and answer with milvus.add_knowledge's result.
    if user_id:
        stored = milvus.add_knowledge(data.text, data.is_active, data.knowledge_base_id, user_id)
        return BaseResponse(data=stored)
    return {"error": "userId is required"}
+64
View File
@@ -0,0 +1,64 @@
from models.ChatRequest import ChatRequest
from models.ChatWithReportRequest import ChatWithReportRequest
from models.BaseResponse import BaseResponse
import uuid
import db.postgres as pg
import db.sqlserver as sqlserver
import uuid
import threading
from fastapi import APIRouter, Depends
from uuid import UUID
from config.security import get_user_id_from_token
reportRouter = APIRouter()
from llm.chatLLM import get_chat_response
from llm.titleChain import get_title
from llm.sqlLLM import get_sql_description_response,get_chat_sql_response,get_chat_sql_improve_response
from models.SaveReportRequest import SaveReportRequest
from agent.dbAgent import get_db_agent_reply
# 报表列表
@reportRouter.get("/reports")
def getReports(user_id: UUID = Depends(get_user_id_from_token)):
    # Reports that pg.get_reports returns for the caller.
    if user_id:
        reports = pg.get_reports(user_id)
        return BaseResponse(data=reports)
    return {"error": "userId is required"}
# 报表详情
@reportRouter.get("/report")
def getReports(reportId:str):
    # Report detail: the stored SQL plus the rows it currently produces.
    # The SQL is read once and reused for both response fields.
    sql = pg.getSQL(reportId)
    # getSQL returns "" for an unknown report; there is nothing to execute
    # then, so answer with an empty table instead of running an empty query.
    if sql.strip() != "":
        tableData = sqlserver.executeSQL(sql)
    else:
        tableData = []
    return BaseResponse(data={
        "sql":sql,
        "tableData":tableData,
    })
# 保存报表
@reportRouter.post("/saveReport")
def saveReport(report:SaveReportRequest,user_id: UUID = Depends(get_user_id_from_token)):
    # Save (bookmark) a report: generate a title from its SQL and flag it.
    if not user_id:
        return {"error": "userId is required"}
    # reportId is optional in the request model; without it there is no
    # report to title or flag.
    if not report.reportId:
        return {"error": "reportId is required"}
    sql = pg.getSQL(report.reportId)
    # Generate the description used as the report title.
    title = get_sql_description_response(sql = sql)
    res = pg.maked_report(report_id=report.reportId,title=title.content)
    return BaseResponse(data=res)
# 使用langgraph
@reportRouter.post("/chatWithReport")
def chat(req: ChatWithReportRequest, user_id: UUID = Depends(get_user_id_from_token)):
    # Start from the SQL of an existing report when a reportId is given,
    # otherwise from an empty statement (new report).
    sql = pg.getSQL(req.reportId) if req.reportId else ""
    result = get_db_agent_reply(aiId=req.aiId, userInput = req.userInput,tenant_id=req.companyId, sql = sql)
    sqlRes = result.get("sql", "")

    # Every reply is stored as a new report row with a placeholder title.
    newReportId = str(uuid.uuid4())
    pg.save_report(id = newReportId, user_id = user_id, title="尚未收藏", sql=sqlRes)

    # Only run the statement when the agent actually produced one.
    tableData = sqlserver.executeSQL(sqlRes) if sqlRes.strip() != "" else []
    payload = {"content":result["reply"],"sql":sqlRes,"reportId": newReportId, "tableData": tableData}
    return BaseResponse(data=payload)
@reportRouter.get("/companyList")
def companyList(user_id: UUID = Depends(get_user_id_from_token)):
    # Companies (tenants) that sqlserver.get_company_list returns for the caller.
    companies = sqlserver.get_company_list(user_id)
    return BaseResponse(data=companies)
+47
View File
@@ -0,0 +1,47 @@
from models.ChatRequest import ChatRequest
from models.BaseResponse import BaseResponse
import uuid
import db.postgres as pg
import uuid
from fastapi import APIRouter, Depends
from uuid import UUID
from config.security import get_user_id_from_token
serviceRouter = APIRouter()
from llm.titleChain import get_title
from agent.serviceAgent import get_service_agent_reply
from llm.memLLM import take_memory
import utils.MyUtils as utils
# 对话列表
@serviceRouter.get("/sessionsForService")
def getSessions(user_id: UUID = Depends(get_user_id_from_token)):
    # Sessions of the caller in the 'service' module.
    if user_id:
        sessions = pg.get_sessions(user_id, 'service')
        return BaseResponse(data=sessions)
    return {"error": "userId is required"}
# 对话
@serviceRouter.post("/chatForService")
def chat(req: ChatRequest, user_id: UUID = Depends(get_user_id_from_token)):
    # One chat turn with a service bot: record the user message, get the
    # agent reply, record it, then start the memory check in the background.
    if not user_id:
        return {"error": "userId is required"}
    if not req.aiId:
        return {"error": "aiId is required"}
    sessionName = get_title(req.userInput)

    # A request without a sessionId opens a new session.
    isNewSession = not req.sessionId
    if isNewSession:
        req.sessionId = str(uuid.uuid4())
        pg.insert_session(user_id,req.aiId, req.sessionId, sessionName, "service")
    else:
        pg.update_session_updated_at(req.sessionId)

    # Store the user message.
    pg.insert_message(req.sessionId, False, req.userInput)

    recent_history = pg.get_history_with_time(req.sessionId,6)
    kn_bases = pg.get_ai_available_kn_bases(req.aiId)
    answer = get_service_agent_reply(
        aiId=req.aiId,
        history=recent_history,
        userInput=req.userInput,
        kn_bases=kn_bases,
    )

    # Store the AI reply.
    pg.insert_message(req.sessionId, True, answer)
    # Run the memory decision on a background thread.
    utils.async_db_task(take_memory,req.aiId,req.sessionId,user_id,)
    payload = {"sessionName":sessionName,"isNewSession":isNewSession,"content":answer,"sessionId": req.sessionId}
    return BaseResponse(data=payload)
+6
View File
@@ -0,0 +1,6 @@
import threading
# 后台操作
def async_db_task(func, *args, **kwargs):
    """Run ``func(*args, **kwargs)`` on a daemon background thread.

    The call returns immediately; the thread is not joined and its result
    is discarded.
    """
    worker = threading.Thread(
        target=func,
        args=args,
        kwargs=kwargs,
        daemon=True,
    )
    worker.start()
+3
View File
@@ -1,10 +1,13 @@
fastapi==0.116.1
langchain==0.3.27
langchain_community==0.3.29
langchain_milvus==0.2.1
langchain_core==0.3.75
langchain_postgres==0.0.15
langchain_tavily==0.2.11
langgraph==0.6.6
langchain_openai===0.3.32
langchain-milvus===0.2.1
psycopg==3.2.9
psycopg_pool==3.2.6
pydantic==2.11.7
+18
View File
@@ -0,0 +1,18 @@
import os

import dspy

# Demo script: classify the sentiment of one sentence with DSPy on DeepSeek.
# The API key is read from the DEEPSEEK_API_KEY environment variable so that
# no credential is stored in the repository.
lm = dspy.LM(
    "openai/deepseek-chat",
    api_key=os.environ["DEEPSEEK_API_KEY"],
    api_base="https://api.deepseek.com",
)
dspy.configure(lm=lm)
# print(lm("Say this is a test!", temperature=0.7)) # => ['This is a test!']
# print(lm(messages=[{"role": "user", "content": "Say this is a test!"}])) # => ['This is a test!']
from typing import Literal


class Classify(dspy.Signature):
    """Classify sentiment of a given sentence."""
    sentence: str = dspy.InputField()
    sentiment: Literal["positive", "negative", "neutral"] = dspy.OutputField()
    confidence: float = dspy.OutputField()


classify = dspy.Predict(Classify)
print(classify(sentence="This book was super fun to read, though not the last chapter."))
+126
View File
@@ -0,0 +1,126 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "d029ad67",
"metadata": {},
"outputs": [],
"source": [
"from langchain_milvus import BM25BuiltInFunction, Milvus\n",
"from typing import List\n",
"URI = \"http://10.10.10.9:19530\"\n",
"tongyiKey = \"sk-9464b2498c184982a9fe9d2c2e725ab5\"\n",
"from langchain_community.embeddings import DashScopeEmbeddings\n",
"embeddings = DashScopeEmbeddings(\n",
" model=\"text-embedding-v3\",\n",
" dashscope_api_key= tongyiKey, \n",
")\n",
"memVectorstore = Milvus(\n",
" embedding_function=embeddings,\n",
" connection_args={\"uri\": URI, \"token\": \"root:Milvus\", \"db_name\": \"bbit_ai_lab\"},\n",
" collection_name=\"memory\",\n",
" index_params={\"index_type\": \"FLAT\", \"metric_type\": \"L2\"},\n",
" consistency_level=\"Strong\",\n",
" auto_id=True,\n",
"\n",
" primary_field = \"id\",\n",
" text_field=\"text\",\n",
" vector_field=\"vector\",\n",
" partition_key_field = \"ai_id\",\n",
" enable_dynamic_field = True,\n",
" drop_old=False, # set to True if seeking to drop the collection with that name if it exists\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "a480053b",
"metadata": {},
"outputs": [],
"source": [
"def get_memory_by_key_words(key_words: str, ai_ids: List[str]) -> str:\n",
" print(\"ai_id是:\" , ai_ids)\n",
" \"\"\"\n",
" 根据关键词和 ai_ids 列表,在知识库中检索相关内容,并返回整理后的文本字符串\n",
" \"\"\"\n",
" # 构建过滤表达式:只查 kn_ids 范围内的\n",
" if ai_ids:\n",
" ids_expr = \" or \".join([f'ai_id == \"{kid}\"' for kid in ai_ids])\n",
" expr = f\"({ids_expr})\"\n",
" else:\n",
" expr = \"\" # 不限制 kn_id todo 实际上应该不反悔任何内容\n",
" \n",
" result = knVectorstore.similarity_search(\n",
" query=key_words,\n",
" k=5, # 可调节返回条数\n",
" expr=expr\n",
" )\n",
" \n",
" # 整理成字符串\n",
" doc_texts = []\n",
" for idx, doc in enumerate(result, start=1):\n",
" text = doc.page_content.strip()\n",
" if text:\n",
" # 可以加个编号,便于LLM区分\n",
" doc_texts.append(f\"[记忆{idx}]: {text}\")\n",
" \n",
" # 拼成一个大字符串,用换行隔开\n",
" combined_text = \"\\n\\n\".join(doc_texts)\n",
" return combined_text"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "36759de5",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"ai_id是: ['3730f279-8b56-46ec-bde9-8a9e6c27f021']\n"
]
},
{
"ename": "NameError",
"evalue": "name 'knVectorstore' is not defined",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[3], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mget_memory_by_key_words\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43m共育室 部署 地方\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43m3730f279-8b56-46ec-bde9-8a9e6c27f021\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n",
"Cell \u001b[0;32mIn[2], line 13\u001b[0m, in \u001b[0;36mget_memory_by_key_words\u001b[0;34m(key_words, ai_ids)\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 11\u001b[0m expr \u001b[38;5;241m=\u001b[39m \u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m\"\u001b[39m \u001b[38;5;66;03m# 不限制 kn_id todo 实际上应该不反悔任何内容\u001b[39;00m\n\u001b[0;32m---> 13\u001b[0m result \u001b[38;5;241m=\u001b[39m \u001b[43mknVectorstore\u001b[49m\u001b[38;5;241m.\u001b[39msimilarity_search(\n\u001b[1;32m 14\u001b[0m query\u001b[38;5;241m=\u001b[39mkey_words,\n\u001b[1;32m 15\u001b[0m k\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m5\u001b[39m, \u001b[38;5;66;03m# 可调节返回条数\u001b[39;00m\n\u001b[1;32m 16\u001b[0m expr\u001b[38;5;241m=\u001b[39mexpr\n\u001b[1;32m 17\u001b[0m )\n\u001b[1;32m 19\u001b[0m \u001b[38;5;66;03m# 整理成字符串\u001b[39;00m\n\u001b[1;32m 20\u001b[0m doc_texts \u001b[38;5;241m=\u001b[39m []\n",
"\u001b[0;31mNameError\u001b[0m: name 'knVectorstore' is not defined"
]
}
],
"source": [
"get_memory_by_key_words(\"共育室 部署 地方\",[\"3730f279-8b56-46ec-bde9-8a9e6c27f021\"])"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "lang",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.18"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
+114
View File
@@ -0,0 +1,114 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 21,
"id": "d029ad67",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"[460823023525530114, 460823023525530115]"
]
},
"execution_count": 21,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from langchain_milvus import BM25BuiltInFunction, Milvus\n",
"URI = \"http://10.10.10.9:19530\"\n",
"tongyiKey = \"sk-9464b2498c184982a9fe9d2c2e725ab5\"\n",
"from langchain_community.embeddings import DashScopeEmbeddings\n",
"embeddings = DashScopeEmbeddings(\n",
" model=\"text-embedding-v3\",\n",
" dashscope_api_key= tongyiKey, \n",
")\n",
"vectorstore = Milvus(\n",
" embedding_function=embeddings,\n",
" connection_args={\"uri\": URI, \"token\": \"root:Milvus\", \"db_name\": \"bbit_ai_lab\"},\n",
" collection_name=\"knowledge\",\n",
" index_params={\"index_type\": \"FLAT\", \"metric_type\": \"L2\"},\n",
" consistency_level=\"Strong\",\n",
" auto_id=True,\n",
"\n",
" primary_field = \"id\",\n",
" text_field=\"text\",\n",
" vector_field=\"vector\",\n",
" partition_key_field = \"kn_id\",\n",
" enable_dynamic_field = True,\n",
" drop_old=False, # set to True if seeking to drop the collection with that name if it exists\n",
")\n",
"\n",
"from langchain.schema import Document\n",
"\n",
"docs = [\n",
" Document(\n",
" page_content=\"这是第一条文本\",\n",
" metadata={\n",
" \"kn_id\": \"8ecd1179-4194-4b80-bc39-5addc678df4b\",\n",
" \"is_active\": True,\n",
" }\n",
" ),\n",
" Document(\n",
" page_content=\"这是第二条文本\",\n",
" metadata={\n",
" \"kn_id\": \"8ecd1179-4194-4b80-bc39-5addc678df4b\",\n",
" \"is_active\": True,\n",
" }\n",
" )\n",
"]\n",
"\n",
"vectorstore.add_documents(docs)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "a480053b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"*这是第一条文本 [{'kn_id': '8ecd1179-4194-4b80-bc39-5addc678df4b', 'id': 460823023525530108, 'is_active': True}]\n",
"*这是第一条文本 [{'kn_id': '8ecd1179-4194-4b80-bc39-5addc678df4b', 'id': 460823023525530110, 'is_active': True}]\n"
]
}
],
"source": [
"results = vectorstore.similarity_search(\n",
" \"\",\n",
" k=2,\n",
" expr='kn_id == \"8ecd1179-4194-4b80-bc39-5addc678df4b\"',\n",
")\n",
"for res in results:\n",
" print(f\"*{res.page_content} [{res.metadata}]\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "lang",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.18"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
+260
View File
@@ -0,0 +1,260 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 11,
"id": "dfb008fd",
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"from openai import OpenAI\n",
"from glob import glob\n",
"from pymilvus import MilvusClient\n",
"from tqdm import tqdm"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "eaa97ad1",
"metadata": {},
"outputs": [],
"source": [
"client = OpenAI(\n",
" api_key= \"sk-9464b2498c184982a9fe9d2c2e725ab5\", # 如果您没有配置环境变量,请在此处用您的API Key进行替换\n",
" base_url=\"https://dashscope.aliyuncs.com/compatible-mode/v1\" # 百炼服务的base_url\n",
")\n",
"def emb_text(text):\n",
" return client.embeddings.create(\n",
" model=\"text-embedding-v4\",\n",
" input=text,\n",
" dimensions=1024, # 指定向量维度(仅 text-embedding-v3及 text-embedding-v4支持该参数)\n",
" encoding_format=\"float\"\n",
" ).data[0].embedding"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "9df315ea",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"1024\n",
"[-0.017507297918200493, 0.02571254037320614, 0.02589302882552147, -0.02639283984899521, -0.013571279123425484, -0.0032158030662685633, -0.006428135093301535, 0.02458796463906765, -0.059366535395383835, 0.13083963096141815]\n"
]
}
],
"source": [
"# 测试\n",
"test_embedding = emb_text(\"This is a test\")\n",
"embedding_dim = len(test_embedding)\n",
"print(embedding_dim)\n",
"print(test_embedding[:10])\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "95d0a121",
"metadata": {},
"outputs": [],
"source": [
"# Milvus数据库配置\n",
"milvus_client = MilvusClient(uri=\"http://10.10.10.9:19530\")\n",
"collection_name = \"my_rag_collection\"\n",
"embedding_dim = 1024\n",
"\n",
"if milvus_client.has_collection(collection_name):\n",
" milvus_client.drop_collection(collection_name)\n",
"milvus_client.create_collection(\n",
" collection_name=collection_name,\n",
" dimension=embedding_dim,\n",
" metric_type=\"IP\", # Inner product distance\n",
" consistency_level=\"Bounded\", # Supported values are (`\"Strong\"`, `\"Session\"`, `\"Bounded\"`, `\"Eventually\"`). See https://milvus.io/docs/consistency.md#Consistency-Level for more details.\n",
")"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "e09edfec",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Creating embeddings: 100%|██████████| 72/72 [00:11<00:00, 6.46it/s]\n"
]
},
{
"data": {
"text/plain": [
"{'insert_count': 72, 'ids': [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71], 'cost': 0}"
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# 从文件中插入数据\n",
"text_lines = []\n",
"for file_path in glob(\"milvus_docs/en/faq/*.md\", recursive=True):\n",
" with open(file_path, \"r\") as file:\n",
" file_text = file.read()\n",
"\n",
" text_lines += file_text.split(\"# \")\n",
"\n",
"data = []\n",
"\n",
"for i, line in enumerate(tqdm(text_lines, desc=\"Creating embeddings\")):\n",
" data.append({\"id\": i, \"vector\": emb_text(line), \"text\": line})\n",
"\n",
"milvus_client.insert(collection_name=collection_name, data=data)"
]
},
{
"cell_type": "code",
"execution_count": 25,
"id": "f3007553",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Milvus 是一个开源的向量数据库,主要用于高效地存储、管理和检索大规模的向量数据。它广泛应用于机器学习、推荐系统、图像识别等需要处理高维数据的场景。\n"
]
}
],
"source": [
"question = \"milvus是什么,用中文回答\"\n",
"search_res = milvus_client.search(\n",
" collection_name=collection_name,\n",
" data=[\n",
" emb_text(question)\n",
" ], # Use the `emb_text` function to convert the question to an embedding vector\n",
" limit=3, # Return top 3 results\n",
" search_params={\"metric_type\": \"IP\", \"params\": {}}, # Inner product distance\n",
" output_fields=[\"text\"], # Return the text field\n",
")\n",
"import json\n",
"# 获取答案\n",
"retrieved_lines_with_distances = [\n",
" (res[\"entity\"][\"text\"], res[\"distance\"]) for res in search_res[0]\n",
"]\n",
"context = \"\\n\".join(\n",
" [line_with_distance[0] for line_with_distance in retrieved_lines_with_distances]\n",
")\n",
"SYSTEM_PROMPT = \"\"\"\n",
"Human: You are an AI assistant. You are able to find answers to the questions from the contextual passage snippets provided.\n",
"\"\"\"\n",
"USER_PROMPT = f\"\"\"\n",
"Use the following pieces of information enclosed in <context> tags to provide an answer to the question enclosed in <question> tags.\n",
"<context>\n",
"{context}\n",
"</context>\n",
"<question>\n",
"{question}\n",
"</question>\n",
"\"\"\"\n",
"response = client.chat.completions.create(\n",
" model='qwen-turbo',\n",
" messages=[\n",
" {\"role\": \"system\", \"content\": SYSTEM_PROMPT},\n",
" {\"role\": \"user\", \"content\": USER_PROMPT},\n",
" ],\n",
")\n",
"print(response.choices[0].message.content)\n",
"\n",
"\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "077922d1",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"2025-09-15 15:12:53,649 [ERROR][handler]: RPC error: [drop_database], <MilvusException: (code=65535, message=can not drop default database)>, <Time:{'RPC start': '2025-09-15 15:12:53.638539', 'RPC error': '2025-09-15 15:12:53.649605'}> (decorators.py:140)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Database 'default' already exists.\n",
"Collection 'my_rag_collection' has been dropped.\n",
"Collection 'bbit_ai_lab_knowledge' has been dropped.\n",
"An error occurred: <MilvusException: (code=65535, message=can not drop default database)>\n"
]
}
],
"source": [
"from pymilvus import Collection, MilvusException, connections, db, utility\n",
"\n",
"conn = connections.connect(host=\"10.10.10.9\", port=19530)\n",
"\n",
"# Check if the database exists\n",
"db_name = \"default\"\n",
"\n",
"try:\n",
" existing_databases = db.list_database()\n",
" if db_name in existing_databases:\n",
" print(f\"Database '{db_name}' already exists.\")\n",
"\n",
" # Use the database context\n",
" db.using_database(db_name)\n",
"\n",
" # Drop all collections in the database\n",
" collections = utility.list_collections()\n",
" for collection_name in collections:\n",
" collection = Collection(name=collection_name)\n",
" collection.drop()\n",
" print(f\"Collection '{collection_name}' has been dropped.\")\n",
"\n",
" db.drop_database(db_name)\n",
" print(f\"Database '{db_name}' has been deleted.\")\n",
" else:\n",
" print(f\"Database '{db_name}' does not exist.\")\n",
" database = db.create_database(db_name)\n",
" print(f\"Database '{db_name}' created successfully.\")\n",
"except MilvusException as e:\n",
" print(f\"An error occurred: {e}\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "lang",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.18"
}
},
"nbformat": 4,
"nbformat_minor": 5
}