更新python后端
This commit is contained in:
@@ -0,0 +1,76 @@
|
||||
|
||||
from models.ChatRequest import ChatRequest
|
||||
from models.ChatWithReportRequest import ChatWithReportRequest
|
||||
from models.BaseResponse import BaseResponse
|
||||
import uuid
|
||||
import db.postgres as pg
|
||||
import db.sqlserver as sqlserver
|
||||
import uuid
|
||||
import threading
|
||||
from fastapi import APIRouter, Depends
|
||||
from uuid import UUID
|
||||
from config.security import get_user_id_from_token
|
||||
botRouter = APIRouter()
|
||||
from llm.chatLLM import get_chat_response
|
||||
from llm.titleChain import get_title
|
||||
from llm.sqlLLM import get_sql_description_response,get_chat_sql_response,get_chat_sql_improve_response
|
||||
from models.SaveReportRequest import SaveReportRequest
|
||||
from models.AIProfilesRequest import AIProfilesRequest
|
||||
import json
|
||||
|
||||
@botRouter.get("/aiListForService")
|
||||
def getAiList(user_id: UUID = Depends(get_user_id_from_token)):
|
||||
if not user_id:
|
||||
return {"error": "userId is required"}
|
||||
return BaseResponse(data=pg.get_all_ai_bot(user_id,"service"))
|
||||
@botRouter.get("/aiListForReport")
|
||||
def getAiList(user_id: UUID = Depends(get_user_id_from_token)):
|
||||
if not user_id:
|
||||
return {"error": "userId is required"}
|
||||
return BaseResponse(data=pg.get_all_ai_bot(user_id,"report"))
|
||||
|
||||
@botRouter.get("/aiListForBot")
|
||||
def getAiList(user_id: UUID = Depends(get_user_id_from_token)):
|
||||
if not user_id:
|
||||
return {"error": "userId is required"}
|
||||
return BaseResponse(data=pg.get_all_ai_bot(user_id,"bot"))
|
||||
|
||||
# 保存智能体
|
||||
@botRouter.post("/saveBot")
|
||||
def saveReportBot(bot: AIProfilesRequest,user_id: UUID = Depends(get_user_id_from_token)):
|
||||
if not user_id:
|
||||
return {"error": "userId is required"}
|
||||
print(bot)
|
||||
ai_personality = {
|
||||
"名字":bot.name,
|
||||
"性格":bot.role,
|
||||
"业务":bot.service,
|
||||
}
|
||||
ai_personality_json = json.dumps(ai_personality, ensure_ascii=False)
|
||||
available_report_tables_json = json.dumps(bot.available_report_tables, ensure_ascii=False)
|
||||
available_kn_bases_json = json.dumps(bot.available_kn_bases, ensure_ascii=False)
|
||||
|
||||
if bot.id:
|
||||
pg.update_bot(
|
||||
id = bot.id,
|
||||
title =bot.title,
|
||||
description = bot.description,
|
||||
welcome_words = bot.welcome_words,
|
||||
ai_personality = ai_personality_json,
|
||||
available_kn_bases = available_kn_bases_json,
|
||||
available_report_tables = available_report_tables_json,
|
||||
available_module = bot.available_module,
|
||||
user_id = user_id
|
||||
)
|
||||
else:
|
||||
pg.insert_bot(
|
||||
title =bot.title,
|
||||
description = bot.description,
|
||||
welcome_words = bot.welcome_words,
|
||||
ai_personality = ai_personality_json,
|
||||
available_kn_bases = available_kn_bases_json,
|
||||
available_module = bot.available_module,
|
||||
available_report_tables = available_report_tables_json,
|
||||
user_id = user_id
|
||||
)
|
||||
return BaseResponse(data= None)
|
||||
+27
-29
@@ -1,63 +1,61 @@
|
||||
from models.ChatRequest import ChatRequest
|
||||
from models.ChatWithReportRequest import ChatWithReportRequest
|
||||
from models.BaseResponse import BaseResponse
|
||||
import uuid
|
||||
import db.postgres as db
|
||||
import db.postgres as pg
|
||||
import db.sqlserver as sqlserver
|
||||
import uuid
|
||||
import threading
|
||||
from fastapi import APIRouter, Depends
|
||||
from uuid import UUID
|
||||
from config.security import get_user_id_from_token
|
||||
router = APIRouter()
|
||||
chatRouter = APIRouter()
|
||||
from agent.dataAgent import get_graph_output
|
||||
from llm.chatLLM import get_chat_response
|
||||
from llm.titleChain import get_title
|
||||
from llm.dataLLM import get_graph_output
|
||||
|
||||
# def async_db_task(func, *args, **kwargs):
|
||||
# """将数据库操作放到后台线程执行"""
|
||||
# threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True).start()
|
||||
|
||||
@router.get("/history")
|
||||
from llm.sqlLLM import get_sql_description_response,get_chat_sql_response,get_chat_sql_improve_response
|
||||
from models.SaveReportRequest import SaveReportRequest
|
||||
from models.AIProfilesRequest import AIProfilesRequest
|
||||
import json
|
||||
# 对话历史记录
|
||||
@chatRouter.get("/history")
|
||||
def getHistory(sessionId: str):
|
||||
return BaseResponse(data=db.get_history(sessionId))
|
||||
return BaseResponse(data=pg.get_history(sessionId))
|
||||
|
||||
@router.get("/aiList")
|
||||
def getAiList(user_id: UUID = Depends(get_user_id_from_token)):
|
||||
if not user_id:
|
||||
return {"error": "userId is required"}
|
||||
return BaseResponse(data=db.get_all_ai(user_id))
|
||||
|
||||
@router.get("/sessions")
|
||||
# 对话列表
|
||||
@chatRouter.get("/sessionsForBot")
|
||||
def getSessions(user_id: UUID = Depends(get_user_id_from_token)):
|
||||
if not user_id:
|
||||
return {"error": "userId is required"}
|
||||
return BaseResponse(data=db.get_sessions(user_id))
|
||||
return BaseResponse(data=pg.get_sessions(user_id,'bot'))
|
||||
|
||||
@router.post("/chat")
|
||||
@chatRouter.post("/chatForBot")
|
||||
def chat(req: ChatRequest, user_id: UUID = Depends(get_user_id_from_token)):
|
||||
if not user_id:
|
||||
return {"error": "userId is required"}
|
||||
if not req.aiId:
|
||||
return {"error": "aiId is required"}
|
||||
sessionName = get_title(req.userInput)
|
||||
session_name = get_title(req.userInput)
|
||||
# 如果没有 sessionId 就新建
|
||||
if not req.sessionId:
|
||||
isNewSession = True
|
||||
is_new_session = True
|
||||
req.sessionId = str(uuid.uuid4())
|
||||
db.insert_session(user_id,req.aiId, req.sessionId, sessionName)
|
||||
pg.insert_session(user_id,req.aiId, req.sessionId, session_name,"bot")
|
||||
else:
|
||||
isNewSession = False
|
||||
db.update_session_updated_at(req.sessionId)
|
||||
is_new_session = False
|
||||
pg.update_session_updated_at(req.sessionId)
|
||||
|
||||
# 插入用户消息
|
||||
db.insert_message(req.sessionId, False, req.userInput)
|
||||
pg.insert_message(req.sessionId, False, req.userInput)
|
||||
|
||||
# 调用 LLM
|
||||
if req.aiId == "9d157dd1-921b-c768-5b90-3e903b50f6f9":
|
||||
# 数据专家AI
|
||||
answer = get_graph_output(aiRole=db.get_ai_personality(req.aiId),history=db.get_history(req.sessionId), userInput= req.userInput)
|
||||
answer = get_graph_output(aiRole=pg.get_ai_personality(req.aiId),history=pg.get_history(req.sessionId), userInput= req.userInput)
|
||||
else:
|
||||
answer = get_chat_response(aiRole=db.get_ai_personality(req.aiId),history=db.get_history(req.sessionId), userInput= req.userInput).content
|
||||
answer = get_chat_response(aiRole=pg.get_ai_personality(req.aiId),history=pg.get_history(req.sessionId), userInput= req.userInput).content
|
||||
# 插入 AI 回复
|
||||
db.insert_message(req.sessionId, True, answer)
|
||||
pg.insert_message(req.sessionId, True, answer)
|
||||
|
||||
return BaseResponse(data={"session_name":session_name,"isNewSession":is_new_session,"content":answer,"sessionId": req.sessionId})
|
||||
|
||||
return BaseResponse(data={"sessionName":sessionName,"isNewSession":isNewSession,"content":answer,"sessionId": req.sessionId})
|
||||
@@ -0,0 +1,54 @@
|
||||
from models.ChatRequest import ChatRequest
|
||||
from models.ChatWithReportRequest import ChatWithReportRequest
|
||||
from models.BaseResponse import BaseResponse
|
||||
import uuid
|
||||
import db.postgres as pg
|
||||
import db.sqlserver as sqlserver
|
||||
import uuid
|
||||
import threading
|
||||
from fastapi import APIRouter, Depends
|
||||
from uuid import UUID
|
||||
from config.security import get_user_id_from_token
|
||||
|
||||
reportDataRouter = APIRouter()
|
||||
from llm.chatLLM import get_chat_response
|
||||
from llm.titleChain import get_title
|
||||
from llm.sqlLLM import get_sql_description_response, get_chat_sql_response, get_chat_sql_improve_response
|
||||
from models.SaveReportRequest import SaveReportRequest
|
||||
from models.ReportTableAddRequest import ReportTableAddRequest
|
||||
from models.ReportFieldAddRequest import ReportFieldAddRequest
|
||||
|
||||
|
||||
# 获取表格列表
|
||||
@reportDataRouter.get("/tableList")
|
||||
def tableList(user_id: UUID = Depends(get_user_id_from_token)):
|
||||
if not user_id:
|
||||
return {"error": "userId is required"}
|
||||
return BaseResponse(data=pg.get_available_tables())
|
||||
|
||||
|
||||
# 获取字段列表
|
||||
@reportDataRouter.get("/fieldList")
|
||||
def fieldList(tableId: str, user_id: UUID = Depends(get_user_id_from_token)):
|
||||
if not user_id:
|
||||
return {"error": "userId is required"}
|
||||
if not tableId:
|
||||
return {"error": "tableId is required"}
|
||||
return BaseResponse(data=pg.get_fields_by_table_id(tableId))
|
||||
|
||||
|
||||
# 新增表
|
||||
@reportDataRouter.post("/addTable")
|
||||
def addTable(data: ReportTableAddRequest, user_id: UUID = Depends(get_user_id_from_token)):
|
||||
if not user_id:
|
||||
return {"error": "userId is required"}
|
||||
return BaseResponse(data=pg.add_table(data.name, data.description, user_id))
|
||||
|
||||
|
||||
# 新增字段
|
||||
@reportDataRouter.post("/addField")
|
||||
def addField(data: ReportFieldAddRequest, user_id: UUID = Depends(get_user_id_from_token)):
|
||||
if not user_id:
|
||||
return {"error": "userId is required"}
|
||||
return BaseResponse(
|
||||
data=pg.add_field(data.name, data.type, data.description, data.is_active, data.table_id, user_id))
|
||||
@@ -0,0 +1,48 @@
|
||||
from models.ChatRequest import ChatRequest
|
||||
from models.ChatWithReportRequest import ChatWithReportRequest
|
||||
from models.BaseResponse import BaseResponse
|
||||
import uuid
|
||||
import db.postgres as pg
|
||||
import db.sqlserver as sqlserver
|
||||
import db.milvus as milvus
|
||||
import uuid
|
||||
import threading
|
||||
from fastapi import APIRouter, Depends
|
||||
from uuid import UUID
|
||||
from config.security import get_user_id_from_token
|
||||
knowledgeRouter = APIRouter()
|
||||
from llm.chatLLM import get_chat_response
|
||||
from llm.titleChain import get_title
|
||||
from llm.sqlLLM import get_sql_description_response,get_chat_sql_response,get_chat_sql_improve_response
|
||||
from models.SaveReportRequest import SaveReportRequest
|
||||
from models.KnowledgeAddRequest import KnowledgeAddRequest
|
||||
from models.KnowledgeBaseAddRequest import KnowledgeBaseAddRequest
|
||||
# 获取知识库列表
|
||||
@knowledgeRouter.get("/knowledgeBaseListForService")
|
||||
def knowledge_base_list(user_id: UUID = Depends(get_user_id_from_token)):
|
||||
if not user_id:
|
||||
return {"error": "userId is required"}
|
||||
return BaseResponse(data=pg.get_available_knowledge_bases('service'))
|
||||
|
||||
# 新增知识库
|
||||
@knowledgeRouter.post("/addKnowledgeBase")
|
||||
def add_knowledge_base(data: KnowledgeBaseAddRequest, user_id: UUID = Depends(get_user_id_from_token)):
|
||||
if not user_id:
|
||||
return {"error": "userId is required"}
|
||||
return BaseResponse(data=pg.add_knowledge_base(data.name, data.description, user_id))
|
||||
|
||||
# 获取知识列表
|
||||
@knowledgeRouter.get("/knowledgeList")
|
||||
def knowledge_list(knowledgeBaseId: str, user_id: UUID = Depends(get_user_id_from_token)):
|
||||
if not user_id:
|
||||
return {"error": "userId is required"}
|
||||
if not knowledgeBaseId:
|
||||
return {"error": "knowledgeBaseId is required"}
|
||||
return BaseResponse(data=milvus.get_knowledge_by_base_id(knowledgeBaseId))
|
||||
|
||||
# 新增知识
|
||||
@knowledgeRouter.post("/addKnowledge")
|
||||
def add_knowledge(data: KnowledgeAddRequest, user_id: UUID = Depends(get_user_id_from_token)):
|
||||
if not user_id:
|
||||
return {"error": "userId is required"}
|
||||
return BaseResponse(data=milvus.add_knowledge(data.text, data.is_active, data.knowledge_base_id, user_id))
|
||||
@@ -0,0 +1,64 @@
|
||||
from models.ChatRequest import ChatRequest
|
||||
from models.ChatWithReportRequest import ChatWithReportRequest
|
||||
from models.BaseResponse import BaseResponse
|
||||
import uuid
|
||||
import db.postgres as pg
|
||||
import db.sqlserver as sqlserver
|
||||
import uuid
|
||||
import threading
|
||||
from fastapi import APIRouter, Depends
|
||||
from uuid import UUID
|
||||
from config.security import get_user_id_from_token
|
||||
reportRouter = APIRouter()
|
||||
from llm.chatLLM import get_chat_response
|
||||
from llm.titleChain import get_title
|
||||
from llm.sqlLLM import get_sql_description_response,get_chat_sql_response,get_chat_sql_improve_response
|
||||
from models.SaveReportRequest import SaveReportRequest
|
||||
from agent.dbAgent import get_db_agent_reply
|
||||
|
||||
# 报表列表
|
||||
@reportRouter.get("/reports")
|
||||
def getReports(user_id: UUID = Depends(get_user_id_from_token)):
|
||||
if not user_id:
|
||||
return {"error": "userId is required"}
|
||||
return BaseResponse(data=pg.get_reports(user_id))
|
||||
|
||||
# 报表详情
|
||||
@reportRouter.get("/report")
|
||||
def getReports(reportId:str):
|
||||
return BaseResponse(data={
|
||||
"sql":pg.getSQL(reportId),
|
||||
"tableData":sqlserver.executeSQL(pg.getSQL(reportId)),
|
||||
})
|
||||
# 保存报表
|
||||
@reportRouter.post("/saveReport")
|
||||
def saveReport(report:SaveReportRequest,user_id: UUID = Depends(get_user_id_from_token)):
|
||||
sql = pg.getSQL(report.reportId)
|
||||
# 生成描述
|
||||
title = get_sql_description_response(sql = sql)
|
||||
res = pg.maked_report(report_id=report.reportId,title=title.content)
|
||||
return BaseResponse(data=res)
|
||||
|
||||
# 使用langgraph
|
||||
@reportRouter.post("/chatWithReport")
|
||||
def chat(req: ChatWithReportRequest, user_id: UUID = Depends(get_user_id_from_token)):
|
||||
# 获取reportId
|
||||
if not req.reportId:
|
||||
# 新报表
|
||||
sql = ""
|
||||
else:
|
||||
# 基于之前的报表
|
||||
sql = pg.getSQL(req.reportId)
|
||||
result = get_db_agent_reply(aiId=req.aiId, userInput = req.userInput,tenant_id=req.companyId, sql = sql)
|
||||
sqlRes = result.get("sql", "")
|
||||
newReportId = str(uuid.uuid4())
|
||||
pg.save_report(id = newReportId, user_id = user_id, title="尚未收藏", sql=sqlRes)
|
||||
|
||||
if sqlRes.strip() != "":
|
||||
tableData = sqlserver.executeSQL(sqlRes)
|
||||
else:
|
||||
tableData = []
|
||||
return BaseResponse(data={"content":result["reply"],"sql":sqlRes,"reportId": newReportId, "tableData": tableData})
|
||||
@reportRouter.get("/companyList")
|
||||
def companyList(user_id: UUID = Depends(get_user_id_from_token)):
|
||||
return BaseResponse(data=sqlserver.get_company_list(user_id))
|
||||
@@ -0,0 +1,47 @@
|
||||
from models.ChatRequest import ChatRequest
|
||||
from models.BaseResponse import BaseResponse
|
||||
import uuid
|
||||
import db.postgres as pg
|
||||
import uuid
|
||||
from fastapi import APIRouter, Depends
|
||||
from uuid import UUID
|
||||
from config.security import get_user_id_from_token
|
||||
serviceRouter = APIRouter()
|
||||
from llm.titleChain import get_title
|
||||
from agent.serviceAgent import get_service_agent_reply
|
||||
from llm.memLLM import take_memory
|
||||
import utils.MyUtils as utils
|
||||
# 对话列表
|
||||
@serviceRouter.get("/sessionsForService")
|
||||
def getSessions(user_id: UUID = Depends(get_user_id_from_token)):
|
||||
if not user_id:
|
||||
return {"error": "userId is required"}
|
||||
return BaseResponse(data=pg.get_sessions(user_id,'service'))
|
||||
|
||||
# 对话
|
||||
@serviceRouter.post("/chatForService")
|
||||
def chat(req: ChatRequest, user_id: UUID = Depends(get_user_id_from_token)):
|
||||
if not user_id:
|
||||
return {"error": "userId is required"}
|
||||
if not req.aiId:
|
||||
return {"error": "aiId is required"}
|
||||
sessionName = get_title(req.userInput)
|
||||
# 如果没有 sessionId 就新建
|
||||
if not req.sessionId:
|
||||
isNewSession = True
|
||||
req.sessionId = str(uuid.uuid4())
|
||||
pg.insert_session(user_id,req.aiId, req.sessionId, sessionName, "service")
|
||||
else:
|
||||
isNewSession = False
|
||||
pg.update_session_updated_at(req.sessionId)
|
||||
|
||||
# 插入用户消息
|
||||
pg.insert_message(req.sessionId, False, req.userInput)
|
||||
|
||||
answer = get_service_agent_reply(aiId=req.aiId,history=pg.get_history_with_time(req.sessionId,6), userInput= req.userInput,kn_bases=pg.get_ai_available_kn_bases(req.aiId))
|
||||
# 插入 AI 回复
|
||||
pg.insert_message(req.sessionId, True, answer)
|
||||
|
||||
# 异步执行:记忆判断
|
||||
utils.async_db_task(take_memory,req.aiId,req.sessionId,user_id,)
|
||||
return BaseResponse(data={"sessionName":sessionName,"isNewSession":isNewSession,"content":answer,"sessionId": req.sessionId})
|
||||
Reference in New Issue
Block a user