65 lines
2.4 KiB
Python
65 lines
2.4 KiB
Python
from models.ChatRequest import ChatRequest
|
|
from models.ChatWithReportRequest import ChatWithReportRequest
|
|
from models.BaseResponse import BaseResponse
|
|
import uuid
|
|
import db.postgres as pg
|
|
import db.sqlserver as sqlserver
|
|
import uuid
|
|
import threading
|
|
from fastapi import APIRouter, Depends
|
|
from uuid import UUID
|
|
from config.security import get_user_id_from_token
|
|
# Router that collects the report-related endpoints defined in this module.
reportRouter = APIRouter()
|
|
from llm.chatLLM import get_chat_response
|
|
from llm.titleChain import get_title
|
|
from llm.sqlLLM import get_sql_description_response,get_chat_sql_response,get_chat_sql_improve_response
|
|
from models.SaveReportRequest import SaveReportRequest
|
|
from agent.dbAgent import get_db_agent_reply
|
|
|
|
# Report list
@reportRouter.get("/reports")
def getReports(user_id: UUID = Depends(get_user_id_from_token)):
    """Return the reports listed by ``pg.get_reports`` for the requesting user.

    ``user_id`` is resolved by the ``get_user_id_from_token`` dependency.

    Returns:
        ``BaseResponse`` wrapping the result of ``pg.get_reports(user_id)``, or
        the plain dict ``{"error": "userId is required"}`` when ``user_id`` is
        falsy.

    NOTE(review): the error path returns a bare dict rather than a
    ``BaseResponse``, so clients see two different response shapes.
    """
    if user_id:
        reports = pg.get_reports(user_id)
        return BaseResponse(data=reports)
    return {"error": "userId is required"}
|
|
|
|
# Report detail
@reportRouter.get("/report")
def getReports(reportId: str):
    """Return a report's stored SQL together with the rows that SQL produces.

    Args:
        reportId: identifier passed to ``pg.getSQL`` to look up the saved SQL.

    Returns:
        ``BaseResponse`` whose ``data`` has ``"sql"`` (the stored SQL) and
        ``"tableData"`` (result of ``sqlserver.executeSQL``, or ``[]`` when the
        stored SQL is empty/blank).

    NOTE(review): this function reuses the name ``getReports`` from the
    ``/reports`` handler above. Both routes are still registered by their
    decorators, but the earlier module-level name is shadowed; consider
    renaming (e.g. ``getReport``) once no caller imports it by name.

    NOTE(review): unlike the other endpoints here, this one has no
    ``get_user_id_from_token`` dependency, so any caller who knows a
    ``reportId`` can read its SQL and execute it — confirm this is intended.
    """
    # Look the SQL up once: the previous version queried Postgres twice, which
    # cost an extra round-trip and could return two different values.
    sql = pg.getSQL(reportId)

    # Mirror /chatWithReport: only send non-blank SQL to SQL Server.
    # Assumes pg.getSQL returns a str or None — TODO confirm.
    table_data = sqlserver.executeSQL(sql) if sql and sql.strip() else []

    return BaseResponse(data={
        "sql": sql,
        "tableData": table_data,
    })
|
|
# Save report
@reportRouter.post("/saveReport")
def saveReport(report: SaveReportRequest, user_id: UUID = Depends(get_user_id_from_token)):
    """Save a report, using an LLM-generated description of its SQL as the title.

    The stored SQL for ``report.reportId`` is fetched via ``pg.getSQL``. It is
    described by ``get_sql_description_response``. The ``.content`` of that
    response is passed as ``title`` to ``pg.maked_report``.

    Returns:
        ``BaseResponse`` wrapping whatever ``pg.maked_report`` returns.

    NOTE(review): ``user_id`` is resolved but never used, so there is no check
    that the caller owns ``report.reportId`` — confirm this is intended.
    """
    stored_sql = pg.getSQL(report.reportId)
    # Generate a human-readable description of the SQL to use as the title.
    description = get_sql_description_response(sql=stored_sql)
    outcome = pg.maked_report(report_id=report.reportId, title=description.content)
    return BaseResponse(data=outcome)
|
|
|
|
# Chat with a report via the DB agent (originally annotated as using LangGraph)
@reportRouter.post("/chatWithReport")
def chat(req: ChatWithReportRequest, user_id: UUID = Depends(get_user_id_from_token)):
    """Ask the DB agent a question, persist the resulting SQL as a new report, and run it.

    Steps:
    - If ``req.reportId`` is set, that report's SQL (``pg.getSQL``) is given to
      the agent as context; otherwise the agent starts from an empty SQL string.
    - A new report row is always saved under a fresh UUID with the title
      "尚未收藏" ("not yet bookmarked") and the SQL returned by the agent.
    - Non-blank SQL is executed with ``sqlserver.executeSQL``.

    Returns:
        ``BaseResponse`` whose ``data`` has ``"content"`` (the agent's
        ``"reply"``), ``"sql"``, ``"reportId"`` (the new report's id) and
        ``"tableData"`` (``[]`` when the SQL is blank).

    NOTE(review): ``req.reportId`` is not checked against ``user_id``, so a
    caller can continue from another user's report — confirm this is intended.
    """
    # No reportId means a brand-new report; otherwise build on the previous SQL.
    sql = pg.getSQL(req.reportId) if req.reportId else ""

    result = get_db_agent_reply(aiId=req.aiId, userInput=req.userInput, tenant_id=req.companyId, sql=sql)

    # `or ""` also covers an explicit None from the agent. The previous
    # `.get("sql", "")` only handled a missing key, so None was persisted and
    # then crashed on `.strip()` below.
    sqlRes = result.get("sql") or ""

    newReportId = str(uuid.uuid4())
    pg.save_report(id=newReportId, user_id=user_id, title="尚未收藏", sql=sqlRes)

    # Only execute SQL that contains something other than whitespace.
    tableData = sqlserver.executeSQL(sqlRes) if sqlRes.strip() else []

    return BaseResponse(data={"content": result["reply"], "sql": sqlRes, "reportId": newReportId, "tableData": tableData})
|
|
@reportRouter.get("/companyList")
def companyList(user_id: UUID = Depends(get_user_id_from_token)):
    """Return the company list that ``sqlserver.get_company_list`` yields for the requesting user.

    Returns:
        ``BaseResponse`` wrapping the result of ``sqlserver.get_company_list(user_id)``.
    """
    companies = sqlserver.get_company_list(user_id)
    return BaseResponse(data=companies)
|