主干Ai实验室后端项目
This commit is contained in:
@@ -0,0 +1,63 @@
|
||||
from models.ChatRequest import ChatRequest
|
||||
from models.BaseResponse import BaseResponse
|
||||
import uuid
|
||||
import db.postgres as db
|
||||
import uuid
|
||||
import threading
|
||||
from fastapi import APIRouter, Depends
|
||||
from uuid import UUID
|
||||
from config.security import get_user_id_from_token
|
||||
router = APIRouter()
|
||||
from llm.chatLLM import get_chat_response
|
||||
from llm.titleChain import get_title
|
||||
from llm.dataLLM import get_graph_output
|
||||
|
||||
# def async_db_task(func, *args, **kwargs):
|
||||
# """将数据库操作放到后台线程执行"""
|
||||
# threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True).start()
|
||||
|
||||
@router.get("/history")
|
||||
def getHistory(sessionId: str):
|
||||
return BaseResponse(data=db.get_history(sessionId))
|
||||
|
||||
@router.get("/aiList")
|
||||
def getAiList(user_id: UUID = Depends(get_user_id_from_token)):
|
||||
if not user_id:
|
||||
return {"error": "userId is required"}
|
||||
return BaseResponse(data=db.get_all_ai(user_id))
|
||||
|
||||
@router.get("/sessions")
|
||||
def getSessions(user_id: UUID = Depends(get_user_id_from_token)):
|
||||
if not user_id:
|
||||
return {"error": "userId is required"}
|
||||
return BaseResponse(data=db.get_sessions(user_id))
|
||||
|
||||
@router.post("/chat")
|
||||
def chat(req: ChatRequest, user_id: UUID = Depends(get_user_id_from_token)):
|
||||
if not user_id:
|
||||
return {"error": "userId is required"}
|
||||
if not req.aiId:
|
||||
return {"error": "aiId is required"}
|
||||
sessionName = get_title(req.userInput)
|
||||
# 如果没有 sessionId 就新建
|
||||
if not req.sessionId:
|
||||
isNewSession = True
|
||||
req.sessionId = str(uuid.uuid4())
|
||||
db.insert_session(user_id,req.aiId, req.sessionId, sessionName)
|
||||
else:
|
||||
isNewSession = False
|
||||
db.update_session_updated_at(req.sessionId)
|
||||
|
||||
# 插入用户消息
|
||||
db.insert_message(req.sessionId, False, req.userInput)
|
||||
|
||||
# 调用 LLM
|
||||
if req.aiId == "9d157dd1-921b-c768-5b90-3e903b50f6f9":
|
||||
# 数据专家AI
|
||||
answer = get_graph_output(aiRole=db.get_ai_personality(req.aiId),history=db.get_history(req.sessionId), userInput= req.userInput)
|
||||
else:
|
||||
answer = get_chat_response(aiRole=db.get_ai_personality(req.aiId),history=db.get_history(req.sessionId), userInput= req.userInput).content
|
||||
# 插入 AI 回复
|
||||
db.insert_message(req.sessionId, True, answer)
|
||||
|
||||
return BaseResponse(data={"sessionName":sessionName,"isNewSession":isNewSession,"content":answer,"sessionId": req.sessionId})
|
||||
Reference in New Issue
Block a user