更新python后端
This commit is contained in:
@@ -0,0 +1,64 @@
|
||||
from models.ChatRequest import ChatRequest
|
||||
from models.ChatWithReportRequest import ChatWithReportRequest
|
||||
from models.BaseResponse import BaseResponse
|
||||
import uuid
|
||||
import db.postgres as pg
|
||||
import db.sqlserver as sqlserver
|
||||
import uuid
|
||||
import threading
|
||||
from fastapi import APIRouter, Depends
|
||||
from uuid import UUID
|
||||
from config.security import get_user_id_from_token
|
||||
reportRouter = APIRouter()
|
||||
from llm.chatLLM import get_chat_response
|
||||
from llm.titleChain import get_title
|
||||
from llm.sqlLLM import get_sql_description_response,get_chat_sql_response,get_chat_sql_improve_response
|
||||
from models.SaveReportRequest import SaveReportRequest
|
||||
from agent.dbAgent import get_db_agent_reply
|
||||
|
||||
# 报表列表
|
||||
@reportRouter.get("/reports")
|
||||
def getReports(user_id: UUID = Depends(get_user_id_from_token)):
|
||||
if not user_id:
|
||||
return {"error": "userId is required"}
|
||||
return BaseResponse(data=pg.get_reports(user_id))
|
||||
|
||||
# 报表详情
|
||||
@reportRouter.get("/report")
|
||||
def getReports(reportId:str):
|
||||
return BaseResponse(data={
|
||||
"sql":pg.getSQL(reportId),
|
||||
"tableData":sqlserver.executeSQL(pg.getSQL(reportId)),
|
||||
})
|
||||
# 保存报表
|
||||
@reportRouter.post("/saveReport")
|
||||
def saveReport(report:SaveReportRequest,user_id: UUID = Depends(get_user_id_from_token)):
|
||||
sql = pg.getSQL(report.reportId)
|
||||
# 生成描述
|
||||
title = get_sql_description_response(sql = sql)
|
||||
res = pg.maked_report(report_id=report.reportId,title=title.content)
|
||||
return BaseResponse(data=res)
|
||||
|
||||
# 使用langgraph
|
||||
@reportRouter.post("/chatWithReport")
|
||||
def chat(req: ChatWithReportRequest, user_id: UUID = Depends(get_user_id_from_token)):
|
||||
# 获取reportId
|
||||
if not req.reportId:
|
||||
# 新报表
|
||||
sql = ""
|
||||
else:
|
||||
# 基于之前的报表
|
||||
sql = pg.getSQL(req.reportId)
|
||||
result = get_db_agent_reply(aiId=req.aiId, userInput = req.userInput,tenant_id=req.companyId, sql = sql)
|
||||
sqlRes = result.get("sql", "")
|
||||
newReportId = str(uuid.uuid4())
|
||||
pg.save_report(id = newReportId, user_id = user_id, title="尚未收藏", sql=sqlRes)
|
||||
|
||||
if sqlRes.strip() != "":
|
||||
tableData = sqlserver.executeSQL(sqlRes)
|
||||
else:
|
||||
tableData = []
|
||||
return BaseResponse(data={"content":result["reply"],"sql":sqlRes,"reportId": newReportId, "tableData": tableData})
|
||||
@reportRouter.get("/companyList")
|
||||
def companyList(user_id: UUID = Depends(get_user_id_from_token)):
|
||||
return BaseResponse(data=sqlserver.get_company_list(user_id))
|
||||
Reference in New Issue
Block a user