主干Ai实验室后端项目

This commit is contained in:
BBIT-Kai
2025-09-05 09:37:47 +08:00
parent aa25f914ab
commit 4a0e79b35a
25 changed files with 628 additions and 0 deletions
+20
View File
@@ -0,0 +1,20 @@
from fastapi import FastAPI
from routers.Chat import router
from fastapi.middleware.cors import CORSMiddleware
app = FastAPI(title="BBIT_AI")
origins = [
"http://localhost:5173", # Vite dev 默认端口
"http://127.0.0.1:5173",
"http://s1.ronsunny.cn:8089",
"*" # ⚠️ 生产环境不要用
]
app.add_middleware(
CORSMiddleware,
allow_origins=origins,
allow_credentials=True,
allow_methods=["*"], # 必须包含 OPTIONS、GET 等
allow_headers=["*"],
)
app.include_router(router, prefix="/api/llm", tags=["chat"])
View File
+6
View File
@@ -0,0 +1,6 @@
from langchain_community.chat_models.tongyi import ChatTongyi
from utils.Tools import all_tools
llm = ChatTongyi(streaming=False, api_key="sk-fb46eefb6b404382a0a5325202e923a6")
llm_with_tools = llm.bind_tools(all_tools)
+22
View File
@@ -0,0 +1,22 @@
import psycopg
from langchain_postgres import PostgresChatMessageHistory
from psycopg_pool import ConnectionPool
from contextlib import contextmanager
# conn = psycopg.connect("postgresql://postgres:123456@10.10.10.9/ktor2")
database_name = "ai_chat_history"
pool = ConnectionPool("postgresql://postgres:123456@10.10.10.9/ktor2")
@contextmanager
def getConn():
with pool.connection() as temp:
temp.autocommit = True # 如果你想所有连接默认 autocommit
yield temp # 把 conn 暴露给外部使用
def init():
with getConn() as connection:
PostgresChatMessageHistory.create_tables(connection, database_name)
init()
+39
View File
@@ -0,0 +1,39 @@
import jwt
from jwt import PyJWTError
from uuid import UUID
from fastapi import Header, HTTPException, Depends
JWT_SECRET = "secret_jwt"
JWT_ALGORITHM = "HS256"
JWT_AUDIENCE = "snowflake-ink"
JWT_ISSUER = "https://snowflake.ink/"
def get_user_id_from_token(token: str = Header(..., alias="Authorization")) -> UUID:
"""
从 Authorization 头解析 token,并返回 user_id
假设前端传 Authorization: Bearer <token>
"""
if token.startswith("Bearer "):
token = token[7:]
else:
raise HTTPException(status_code=401, detail="Invalid token format")
try:
payload = jwt.decode(
token,
JWT_SECRET,
algorithms=[JWT_ALGORITHM],
audience=JWT_AUDIENCE,
issuer=JWT_ISSUER
)
except PyJWTError:
raise HTTPException(status_code=401, detail="Token is missing or invalid")
if payload.get("token_type") != "access_token":
raise HTTPException(status_code=401, detail="Invalid token type")
user_id = payload.get("user_id")
if not user_id:
raise HTTPException(status_code=401, detail="User ID not found in token")
return UUID(user_id)
+12
View File
@@ -0,0 +1,12 @@
from langchain_community.utilities import SQLDatabase
from langchain_community.agent_toolkits import create_sql_agent
from langchain_community.chat_models.tongyi import ChatTongyi
# SQLAlchemy URI
uri = "mssql+pyodbc://f8_db_test:APN^QPr!K9@122.114.58.23/f8_db_test?driver=ODBC+Driver+18+for+SQL+Server&TrustServerCertificate=yes"
# 建立数据库对象
ssDB = SQLDatabase.from_uri(
uri,
include_tables=["NONGHU_INFO"], # 不加 schema 前缀试试
schema="dbo" # 显式指定 schema
)
View File
+115
View File
@@ -0,0 +1,115 @@
import psycopg
from langchain_postgres import PostgresChatMessageHistory
from config.pgDb import database_name,getConn
from typing import List, Dict
# ————————————————————————————————————————————————————AI角色———————————————————————————————
def get_ai_personality(ai_id: str):
with getConn() as conn:
with conn.cursor() as cur:
cur.execute(
"SELECT ai_personality FROM ai_chat_profiles WHERE id = %s",
(ai_id,)
)
row = cur.fetchone()
if row:
return row[0]
else:
return "你是一个乐于助人的AI助手,请保持中文简洁回答用户。"
def get_all_ai(user_id: str) -> List[Dict]:
with getConn() as conn:
with conn.cursor() as cur:
# 查询用户角色
cur.execute(
"SELECT roles FROM users WHERE id = %s",
(user_id,)
)
role_row = cur.fetchone()
if not role_row:
return [] # 用户不存在
user_roles = role_row[0]
# 查询 AI 角色 JSON 字段包含用户角色
cur.execute(
"""
SELECT id, name, welcome_words
FROM ai_chat_profiles
WHERE availabel_roles::jsonb ?| %s
""",
(user_roles,) # user_roles 是 list,比如 ["a", "b", "c"]
)
rows = cur.fetchall()
return [
{
"id": row[0],
"name": row[1],
"welcomeWords": row[2],
}
for row in rows
]
# ————————————————————————————————————————————————————消息———————————————————————————————
def insert_message(session_id: str, isAI: bool, content: str):
with getConn() as conn:
history = PostgresChatMessageHistory(
database_name,
session_id,
sync_connection=conn
)
if isAI:
history.add_ai_message(content)
else:
history.add_user_message(content)
def get_history(session_id: str):
with getConn() as conn:
history = PostgresChatMessageHistory(
database_name,
session_id,
sync_connection=conn
)
simplified = []
for msg in history.messages:
simplified.append({
"type": msg.type,
"content": msg.content
})
return simplified
# ————————————————————————————————————————————————————会话———————————————————————————————
def insert_session(user_id: str,ai_id:str, session_id: str,session_title: str):
with getConn() as coon:
with coon.cursor() as cur:
cur.execute(
"INSERT INTO ai_chat_sessions (id ,user_id, ai_id, title, created_at, updated_at) VALUES (%s, %s, %s, %s, NOW(), NOW())",
(session_id, user_id, ai_id, session_title )
)
coon.commit()
def update_session_updated_at(session_id: str):
with getConn() as conn:
with conn.cursor() as cur:
cur.execute(
"UPDATE ai_chat_sessions SET updated_at = NOW() WHERE id = %s",
(session_id,)
)
conn.commit()
def get_sessions(user_id: str):
with getConn() as conn:
with conn.cursor() as cur:
cur.execute(
"SELECT id, title, updated_at FROM ai_chat_sessions WHERE user_id = %s ORDER BY updated_at DESC",
(user_id,)
)
sessions = cur.fetchall()
return [
{
"id": row[0],
"title": row[1],
"updated_at": row[2]
}
for row in sessions
]
View File
+27
View File
@@ -0,0 +1,27 @@
from config.llm import llm
from langchain.prompts import PromptTemplate
chatPrompt = PromptTemplate(
input_variables=["aiRole", "history", "userInput"],
template = """
你是一个人,用户画像为:{aiRole}
你需要基于你的角色性格,使用中文回答用户。
聊天历史:
{history}
用户最新输入:
{userInput}
最后,请注意,不要编造数据,不知道就说不知道,现在,请生成你的回复:
"""
)
chatChain = chatPrompt | llm
def get_chat_response(aiRole: str,history: str, userInput: str) -> str:
return chatChain.invoke({
"aiRole": aiRole,
"history": history,
"userInput": userInput
})
+139
View File
@@ -0,0 +1,139 @@
from langchain.prompts import PromptTemplate
from config.llm import llm
from config.ssDb import ssDB
from typing import Annotated
from typing_extensions import TypedDict
from langgraph.graph import StateGraph, START, END
from langgraph.graph.message import add_messages
from langchain_community.agent_toolkits import create_sql_agent
from langchain.prompts import PromptTemplate
from config.llm import llm
from config.ssDb import ssDB
from typing import Annotated
from langgraph.graph.message import add_messages
import os
from langchain_tavily import TavilySearch
from langgraph.prebuilt import ToolNode, tools_condition
from llm.chatLLM import get_chat_response
from typing import TypedDict
from langgraph.graph import StateGraph, END
from llm.summarizeLLM import getSummary
# -------- 定义状态 --------
class State(TypedDict):
userInput: str # 用户输入
source: str # 选择的数据来源:web 或 db 或 chat
infomation: str # 查询到的内容
aiRole: str # AI 角色
history: str # 聊天历史
reply: str # 最终回复
# -------- 定义节点 --------
# ------------------------------------------------------------------------ 路径选择 --------
pathSelectPrompt = PromptTemplate(
input_variables=["aiRole", "history", "userStr", "infomation"],
template = """
你是主干信息科技有限公司的业务员,是一家蚕桑服务公司,现在需要根据用户输入来判断应该使用哪种方式来回答用户的问题。
你有三种选择:
1. 如果用户的问题涉及最新的信息,比如新闻、事件、天气等涉及时间的内容时,请选择 "web
2. 如果用户的问题涉及具体的蚕桑业务(例如询问农户、订单、订种、租户)的数据库查询需求,请选择 "db"
3. 如果用户的问题是一般性的聊天或咨询,请选择 "chat"
请只返回 "web""db""chat" 之一,且不要添加任何其他解释。
用户最新输入:
{userStr}
请做出你的选择:
"""
)
pathSelectChain = pathSelectPrompt | llm
def decide_source(state: State, max_retry=3):
print("根据用户输入选择数据来源,用户输入:", state["userInput"])
"""根据用户输入选择数据来源"""
for _ in range(max_retry):
choice = pathSelectChain.invoke({
"aiRole": state["aiRole"],
"history": state["history"],
"userStr": state["userInput"],
}).content.strip().lower()
if choice in ["web", "db", "chat"]:
state["source"] = choice
break
else:
# 如果连续 max_retry 次都不合法,默认走 chat
state["source"] = "chat"
print("选择的数据来源是:", state["source"])
return state
# ------------------------------------------------------------------------ 上网查询 --------
os.environ["TAVILY_API_KEY"] = "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"
tool = TavilySearch(max_results=2)
def fetch_web(state: State):
result = tool.invoke(state["userInput"])
state["infomation"] = result.get("content") or result
print("调用了联网工具,结果是:", state["infomation"])
return state
# ------------------------------------------------------------------------ 数据库查询 --------
agent = create_sql_agent(
llm=llm,
db=ssDB,
agent_type="tool-calling",
verbose=True
)
def fetch_db(state: State):
state["infomation"] = agent.invoke({"input": state["userInput"]})["output"]
print("调用了数据库工具,结果是:", state["infomation"])
return state
# ------------------------------------------------------------------------ 整理结果 --------
def summarize_ai(state: State):
"""AI 总结输出"""
state["reply"] = getSummary(aiRole=state["aiRole"], history=state["history"], userInput= state["userInput"], infomation= state["infomation"])
return state
# ------------------------------------------------------------------------ 普通聊天 --------
def chat(state: State):
state["reply"] = get_chat_response(aiRole=state["aiRole"],history=state["history"], userInput= state["userInput"]).content
print("直接回复")
return state
# ------------------------------------------------------------------------ 构建有向图 --------
workflow = StateGraph(State)
workflow.add_node("decide", decide_source)
workflow.add_node("fetch_web", fetch_web)
workflow.add_node("fetch_db", fetch_db)
workflow.add_node("chat", chat)
workflow.add_node("summarize", summarize_ai)
workflow.set_entry_point("decide")
# 两条路径最后都汇合到 summarize
workflow.add_edge(START, "decide")
workflow.add_edge("fetch_web", "summarize")
workflow.add_edge("fetch_db", "summarize")
# 条件边:根据 source 决定走向
workflow.add_conditional_edges(
"decide",
lambda state: state["source"], # 返回 state["source"] 的值
{
"web": "fetch_web",
"chat": "chat",
"db": "fetch_db"
}
)
workflow.add_edge("summarize", END)
workflow.add_edge("chat", END)
graph = workflow.compile()
# 执行函数
def get_graph_output(aiRole:str,history: str, userInput: str) -> str:
final_state = graph.invoke({
"aiRole":aiRole,
"history": history,
"userInput": userInput,
})
return final_state["reply"]
+34
View File
@@ -0,0 +1,34 @@
from langchain.prompts import PromptTemplate
from config.llm import llm
summarizePrompt = PromptTemplate(
input_variables=["aiRole", "history", "userStr", "infomation"],
template = """
你是一个主干信息研发的 AI 助手,用户画像为:{aiRole}
请基于你的角色性格,保持中文简洁回答的,根据下方提示回答用户。
聊天历史:
···
{history}
···
用户最新输入:
···
{userStr}
···
给你的参考内容:
···
{infomation}
···
如果参考内容明显有问题,你要请用户重新描述问题,现在请生成你的回复:
"""
)
summarizeChain = summarizePrompt | llm
def getSummary(aiRole: str, history: str, userInput: str, infomation: str) -> str:
return summarizeChain.invoke({
"aiRole":aiRole,
"history": history,
"userStr": userInput,
"infomation": infomation
}).content
+19
View File
@@ -0,0 +1,19 @@
from langchain.prompts import PromptTemplate
from config.llm import llm
titlePrompt = PromptTemplate(
input_variables=["userStr"],
template = """
请将用户的这句话总结成一个简短、精准的对话标题,要求:
1. 不超过10个字(可根据需要调整长度)。
2. 直接概括本次对话的核心内容。
3. 避免使用笼统或无意义的词语,如“讨论”、“聊天”等。
4. 保持自然、易懂、专业或有趣(可根据场景调整风格)。
用户原话:"{userStr}"
"""
)
titleChain = titlePrompt | llm
def get_title(userInput: str):
return titleChain.invoke({"userStr": userInput}).content
+11
View File
@@ -0,0 +1,11 @@
from fastapi import FastAPI
from pydantic import BaseModel
from typing import Generic, TypeVar, Optional, List
from pydantic.generics import GenericModel
T = TypeVar("T")
# 定义通用响应结构
class BaseResponse(GenericModel, Generic[T]):
status: bool = True
message: str = "操作成功"
data: Optional[T] = None
+6
View File
@@ -0,0 +1,6 @@
from pydantic import BaseModel
class ChatRequest(BaseModel):
aiId: str
sessionId: str | None = None
userInput: str
View File
+63
View File
@@ -0,0 +1,63 @@
from models.ChatRequest import ChatRequest
from models.BaseResponse import BaseResponse
import uuid
import db.postgres as db
import uuid
import threading
from fastapi import APIRouter, Depends
from uuid import UUID
from config.security import get_user_id_from_token
router = APIRouter()
from llm.chatLLM import get_chat_response
from llm.titleChain import get_title
from llm.dataLLM import get_graph_output
# def async_db_task(func, *args, **kwargs):
# """将数据库操作放到后台线程执行"""
# threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True).start()
@router.get("/history")
def getHistory(sessionId: str):
return BaseResponse(data=db.get_history(sessionId))
@router.get("/aiList")
def getAiList(user_id: UUID = Depends(get_user_id_from_token)):
if not user_id:
return {"error": "userId is required"}
return BaseResponse(data=db.get_all_ai(user_id))
@router.get("/sessions")
def getSessions(user_id: UUID = Depends(get_user_id_from_token)):
if not user_id:
return {"error": "userId is required"}
return BaseResponse(data=db.get_sessions(user_id))
@router.post("/chat")
def chat(req: ChatRequest, user_id: UUID = Depends(get_user_id_from_token)):
if not user_id:
return {"error": "userId is required"}
if not req.aiId:
return {"error": "aiId is required"}
sessionName = get_title(req.userInput)
# 如果没有 sessionId 就新建
if not req.sessionId:
isNewSession = True
req.sessionId = str(uuid.uuid4())
db.insert_session(user_id,req.aiId, req.sessionId, sessionName)
else:
isNewSession = False
db.update_session_updated_at(req.sessionId)
# 插入用户消息
db.insert_message(req.sessionId, False, req.userInput)
# 调用 LLM
if req.aiId == "9d157dd1-921b-c768-5b90-3e903b50f6f9":
# 数据专家AI
answer = get_graph_output(aiRole=db.get_ai_personality(req.aiId),history=db.get_history(req.sessionId), userInput= req.userInput)
else:
answer = get_chat_response(aiRole=db.get_ai_personality(req.aiId),history=db.get_history(req.sessionId), userInput= req.userInput).content
# 插入 AI 回复
db.insert_message(req.sessionId, True, answer)
return BaseResponse(data={"sessionName":sessionName,"isNewSession":isNewSession,"content":answer,"sessionId": req.sessionId})
View File
+15
View File
@@ -0,0 +1,15 @@
from langchain_core.tools import tool
@tool
def add(a: int, b: int) -> int:
"""Adds a and b."""
return a + b
@tool
def multiply(a: int, b: int) -> int:
"""Multiplies a and b."""
return a * b
all_tools = [add, multiply]
View File