"""Report endpoints: list, detail, save, chat-driven SQL generation, company list."""
from models.ChatRequest import ChatRequest
from models.ChatWithReportRequest import ChatWithReportRequest
from models.BaseResponse import BaseResponse
import uuid
import db.postgres as pg
import db.sqlserver as sqlserver
import threading
from fastapi import APIRouter, Depends
from uuid import UUID
from config.security import get_user_id_from_token
from llm.chatLLM import get_chat_response
from llm.titleChain import get_title
from llm.sqlLLM import get_sql_description_response, get_chat_sql_response, get_chat_sql_improve_response
from models.SaveReportRequest import SaveReportRequest
from agent.dbAgent import get_db_agent_reply

# NOTE(review): ChatRequest, threading, get_chat_response, get_title,
# get_chat_sql_response and get_chat_sql_improve_response are not used in this
# module; they are kept in case other code relies on them being imported here.

reportRouter = APIRouter()


def _has_sql(sql) -> bool:
    """Return True when *sql* is a non-blank string worth sending to SQL Server."""
    # assumes stored SQL is a str or None — TODO confirm pg.getSQL's contract
    return bool(sql) and sql.strip() != ""


# Report list
@reportRouter.get("/reports")
def getReports(user_id: UUID = Depends(get_user_id_from_token)):
    """Return all reports belonging to the authenticated user."""
    if not user_id:
        return {"error": "userId is required"}
    return BaseResponse(data=pg.get_reports(user_id))


# Report detail
@reportRouter.get("/report")
def getReport(reportId: str):
    """Return a report's stored SQL together with the rows it currently yields.

    The SQL is looked up once. It is only executed when non-blank. ``chat``
    saves reports with empty SQL, and executing "" would fail.
    """
    # NOTE(review): no user check here — any caller can read any report by id.
    sql = pg.getSQL(reportId)
    table_data = sqlserver.executeSQL(sql) if _has_sql(sql) else []
    return BaseResponse(data={
        "sql": sql,
        "tableData": table_data,
    })


# Save (favorite) a report
@reportRouter.post("/saveReport")
def saveReport(report: SaveReportRequest, user_id: UUID = Depends(get_user_id_from_token)):
    """Generate an LLM description of the report's SQL and mark the report as saved.

    The generated description (``.content`` of the LLM response) becomes the
    report title.
    """
    # NOTE(review): user_id is not used, so ownership of report.reportId is not
    # verified here — confirm whether pg.maked_report enforces it.
    sql = pg.getSQL(report.reportId)
    # Generate a description to use as the title
    title = get_sql_description_response(sql=sql)
    res = pg.maked_report(report_id=report.reportId, title=title.content)
    return BaseResponse(data=res)


# Chat-driven report generation via the LangGraph DB agent
@reportRouter.post("/chatWithReport")
def chat(req: ChatWithReportRequest, user_id: UUID = Depends(get_user_id_from_token)):
    """Ask the DB agent for a reply/SQL, persist it as a new report, and run the SQL.

    If ``req.reportId`` is set, that report's SQL is passed to the agent as the
    starting point; otherwise the agent starts from an empty SQL string.
    Every call creates a new report row with a fresh UUID.
    """
    if not req.reportId:
        # New report
        sql = ""
    else:
        # Build on an existing report
        sql = pg.getSQL(req.reportId)

    result = get_db_agent_reply(aiId=req.aiId, userInput=req.userInput, tenant_id=req.companyId, sql=sql)
    # `or ""` also covers an explicit None from the agent, which would otherwise
    # crash on .strip() below after the report row had already been written.
    sql_res = result.get("sql") or ""

    new_report_id = str(uuid.uuid4())
    pg.save_report(id=new_report_id, user_id=user_id, title="尚未收藏", sql=sql_res)

    table_data = sqlserver.executeSQL(sql_res) if _has_sql(sql_res) else []
    return BaseResponse(data={"content": result["reply"], "sql": sql_res, "reportId": new_report_id, "tableData": table_data})


@reportRouter.get("/companyList")
def companyList(user_id: UUID = Depends(get_user_id_from_token)):
    """Return the companies available to the authenticated user."""
    return BaseResponse(data=sqlserver.get_company_list(user_id))