主干Ai实验室后端项目
This commit is contained in:
@@ -0,0 +1,37 @@
|
|||||||
|
# 使用官方 Python 镜像
|
||||||
|
FROM python:3.10-slim
|
||||||
|
|
||||||
|
# 设置工作目录
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
# 复制依赖文件
|
||||||
|
COPY requirements.txt .
|
||||||
|
# 更新系统源,安装 PostgreSQL 和 ODBC 依赖,以及微软 SQL Server 驱动
|
||||||
|
# 安装基础依赖和 Microsoft ODBC 驱动依赖
|
||||||
|
RUN apt-get update && \
|
||||||
|
apt-get install -y --no-install-recommends \
|
||||||
|
libpq5 \
|
||||||
|
unixodbc \
|
||||||
|
curl \
|
||||||
|
gnupg \
|
||||||
|
apt-transport-https \
|
||||||
|
lsb-release && \
|
||||||
|
# 导入微软 GPG key(使用 keyrings 方式)
|
||||||
|
curl https://packages.microsoft.com/keys/microsoft.asc | gpg --dearmor > /usr/share/keyrings/microsoft.gpg && \
|
||||||
|
echo "deb [arch=amd64 signed-by=/usr/share/keyrings/microsoft.gpg] https://packages.microsoft.com/ubuntu/22.04/prod jammy main" > /etc/apt/sources.list.d/mssql-release.list && \
|
||||||
|
apt-get update && \
|
||||||
|
ACCEPT_EULA=Y apt-get install -y msodbcsql18 && \
|
||||||
|
rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
|
||||||
|
# 安装 Python 依赖
|
||||||
|
RUN pip install --no-cache-dir -i https://pypi.tuna.tsinghua.edu.cn/simple -r requirements.txt
|
||||||
|
|
||||||
|
# 复制项目代码
|
||||||
|
COPY app/ .
|
||||||
|
|
||||||
|
# 对外暴露端口(FastAPI 默认 8000)
|
||||||
|
EXPOSE 8000
|
||||||
|
|
||||||
|
# 启动命令(使用 uvicorn 启动 FastAPI)
|
||||||
|
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "13011"]
|
||||||
@@ -0,0 +1,20 @@
|
|||||||
|
from fastapi import FastAPI
|
||||||
|
from routers.Chat import router
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
app = FastAPI(title="BBIT_AI")
|
||||||
|
|
||||||
|
origins = [
|
||||||
|
"http://localhost:5173", # Vite dev 默认端口
|
||||||
|
"http://127.0.0.1:5173",
|
||||||
|
"http://s1.ronsunny.cn:8089",
|
||||||
|
"*" # ⚠️ 生产环境不要用
|
||||||
|
]
|
||||||
|
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=origins,
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"], # 必须包含 OPTIONS、GET 等
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
app.include_router(router, prefix="/api/llm", tags=["chat"])
|
||||||
@@ -0,0 +1,6 @@
|
|||||||
|
|
||||||
|
from langchain_community.chat_models.tongyi import ChatTongyi
|
||||||
|
from utils.Tools import all_tools
|
||||||
|
|
||||||
|
llm = ChatTongyi(streaming=False, api_key="sk-fb46eefb6b404382a0a5325202e923a6")
|
||||||
|
llm_with_tools = llm.bind_tools(all_tools)
|
||||||
@@ -0,0 +1,22 @@
|
|||||||
|
|
||||||
|
import psycopg
|
||||||
|
from langchain_postgres import PostgresChatMessageHistory
|
||||||
|
from psycopg_pool import ConnectionPool
|
||||||
|
from contextlib import contextmanager
|
||||||
|
|
||||||
|
# conn = psycopg.connect("postgresql://postgres:123456@10.10.10.9/ktor2")
|
||||||
|
database_name = "ai_chat_history"
|
||||||
|
pool = ConnectionPool("postgresql://postgres:123456@10.10.10.9/ktor2")
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def getConn():
|
||||||
|
with pool.connection() as temp:
|
||||||
|
temp.autocommit = True # 如果你想所有连接默认 autocommit
|
||||||
|
yield temp # 把 conn 暴露给外部使用
|
||||||
|
|
||||||
|
|
||||||
|
def init():
|
||||||
|
with getConn() as connection:
|
||||||
|
PostgresChatMessageHistory.create_tables(connection, database_name)
|
||||||
|
|
||||||
|
init()
|
||||||
@@ -0,0 +1,39 @@
|
|||||||
|
import jwt
|
||||||
|
from jwt import PyJWTError
|
||||||
|
from uuid import UUID
|
||||||
|
from fastapi import Header, HTTPException, Depends
|
||||||
|
|
||||||
|
JWT_SECRET = "secret_jwt"
|
||||||
|
JWT_ALGORITHM = "HS256"
|
||||||
|
JWT_AUDIENCE = "snowflake-ink"
|
||||||
|
JWT_ISSUER = "https://snowflake.ink/"
|
||||||
|
|
||||||
|
def get_user_id_from_token(token: str = Header(..., alias="Authorization")) -> UUID:
|
||||||
|
"""
|
||||||
|
从 Authorization 头解析 token,并返回 user_id
|
||||||
|
假设前端传 Authorization: Bearer <token>
|
||||||
|
"""
|
||||||
|
if token.startswith("Bearer "):
|
||||||
|
token = token[7:]
|
||||||
|
else:
|
||||||
|
raise HTTPException(status_code=401, detail="Invalid token format")
|
||||||
|
|
||||||
|
try:
|
||||||
|
payload = jwt.decode(
|
||||||
|
token,
|
||||||
|
JWT_SECRET,
|
||||||
|
algorithms=[JWT_ALGORITHM],
|
||||||
|
audience=JWT_AUDIENCE,
|
||||||
|
issuer=JWT_ISSUER
|
||||||
|
)
|
||||||
|
except PyJWTError:
|
||||||
|
raise HTTPException(status_code=401, detail="Token is missing or invalid")
|
||||||
|
|
||||||
|
if payload.get("token_type") != "access_token":
|
||||||
|
raise HTTPException(status_code=401, detail="Invalid token type")
|
||||||
|
|
||||||
|
user_id = payload.get("user_id")
|
||||||
|
if not user_id:
|
||||||
|
raise HTTPException(status_code=401, detail="User ID not found in token")
|
||||||
|
|
||||||
|
return UUID(user_id)
|
||||||
@@ -0,0 +1,12 @@
|
|||||||
|
from langchain_community.utilities import SQLDatabase
|
||||||
|
from langchain_community.agent_toolkits import create_sql_agent
|
||||||
|
from langchain_community.chat_models.tongyi import ChatTongyi
|
||||||
|
# SQLAlchemy URI
|
||||||
|
uri = "mssql+pyodbc://f8_db_test:APN^QPr!K9@122.114.58.23/f8_db_test?driver=ODBC+Driver+18+for+SQL+Server&TrustServerCertificate=yes"
|
||||||
|
|
||||||
|
# 建立数据库对象
|
||||||
|
ssDB = SQLDatabase.from_uri(
|
||||||
|
uri,
|
||||||
|
include_tables=["NONGHU_INFO"], # 不加 schema 前缀试试
|
||||||
|
schema="dbo" # 显式指定 schema
|
||||||
|
)
|
||||||
@@ -0,0 +1,115 @@
|
|||||||
|
import psycopg
|
||||||
|
from langchain_postgres import PostgresChatMessageHistory
|
||||||
|
from config.pgDb import database_name,getConn
|
||||||
|
from typing import List, Dict
|
||||||
|
# ————————————————————————————————————————————————————AI角色———————————————————————————————
|
||||||
|
|
||||||
|
def get_ai_personality(ai_id: str):
|
||||||
|
with getConn() as conn:
|
||||||
|
with conn.cursor() as cur:
|
||||||
|
cur.execute(
|
||||||
|
"SELECT ai_personality FROM ai_chat_profiles WHERE id = %s",
|
||||||
|
(ai_id,)
|
||||||
|
)
|
||||||
|
row = cur.fetchone()
|
||||||
|
if row:
|
||||||
|
return row[0]
|
||||||
|
else:
|
||||||
|
return "你是一个乐于助人的AI助手,请保持中文简洁回答用户。"
|
||||||
|
|
||||||
|
def get_all_ai(user_id: str) -> List[Dict]:
|
||||||
|
with getConn() as conn:
|
||||||
|
with conn.cursor() as cur:
|
||||||
|
# 查询用户角色
|
||||||
|
cur.execute(
|
||||||
|
"SELECT roles FROM users WHERE id = %s",
|
||||||
|
(user_id,)
|
||||||
|
)
|
||||||
|
role_row = cur.fetchone()
|
||||||
|
if not role_row:
|
||||||
|
return [] # 用户不存在
|
||||||
|
user_roles = role_row[0]
|
||||||
|
|
||||||
|
# 查询 AI 角色 JSON 字段包含用户角色
|
||||||
|
cur.execute(
|
||||||
|
"""
|
||||||
|
SELECT id, name, welcome_words
|
||||||
|
FROM ai_chat_profiles
|
||||||
|
WHERE availabel_roles::jsonb ?| %s
|
||||||
|
""",
|
||||||
|
(user_roles,) # user_roles 是 list,比如 ["a", "b", "c"]
|
||||||
|
)
|
||||||
|
rows = cur.fetchall()
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"id": row[0],
|
||||||
|
"name": row[1],
|
||||||
|
"welcomeWords": row[2],
|
||||||
|
}
|
||||||
|
for row in rows
|
||||||
|
]
|
||||||
|
|
||||||
|
# ————————————————————————————————————————————————————消息———————————————————————————————
|
||||||
|
def insert_message(session_id: str, isAI: bool, content: str):
|
||||||
|
with getConn() as conn:
|
||||||
|
history = PostgresChatMessageHistory(
|
||||||
|
database_name,
|
||||||
|
session_id,
|
||||||
|
sync_connection=conn
|
||||||
|
)
|
||||||
|
if isAI:
|
||||||
|
history.add_ai_message(content)
|
||||||
|
else:
|
||||||
|
history.add_user_message(content)
|
||||||
|
|
||||||
|
def get_history(session_id: str):
|
||||||
|
with getConn() as conn:
|
||||||
|
history = PostgresChatMessageHistory(
|
||||||
|
database_name,
|
||||||
|
session_id,
|
||||||
|
sync_connection=conn
|
||||||
|
)
|
||||||
|
simplified = []
|
||||||
|
for msg in history.messages:
|
||||||
|
simplified.append({
|
||||||
|
"type": msg.type,
|
||||||
|
"content": msg.content
|
||||||
|
})
|
||||||
|
|
||||||
|
return simplified
|
||||||
|
|
||||||
|
# ————————————————————————————————————————————————————会话———————————————————————————————
|
||||||
|
def insert_session(user_id: str,ai_id:str, session_id: str,session_title: str):
|
||||||
|
with getConn() as coon:
|
||||||
|
with coon.cursor() as cur:
|
||||||
|
cur.execute(
|
||||||
|
"INSERT INTO ai_chat_sessions (id ,user_id, ai_id, title, created_at, updated_at) VALUES (%s, %s, %s, %s, NOW(), NOW())",
|
||||||
|
(session_id, user_id, ai_id, session_title )
|
||||||
|
)
|
||||||
|
coon.commit()
|
||||||
|
|
||||||
|
def update_session_updated_at(session_id: str):
|
||||||
|
with getConn() as conn:
|
||||||
|
with conn.cursor() as cur:
|
||||||
|
cur.execute(
|
||||||
|
"UPDATE ai_chat_sessions SET updated_at = NOW() WHERE id = %s",
|
||||||
|
(session_id,)
|
||||||
|
)
|
||||||
|
conn.commit()
|
||||||
|
|
||||||
|
def get_sessions(user_id: str):
|
||||||
|
with getConn() as conn:
|
||||||
|
with conn.cursor() as cur:
|
||||||
|
cur.execute(
|
||||||
|
"SELECT id, title, updated_at FROM ai_chat_sessions WHERE user_id = %s ORDER BY updated_at DESC",
|
||||||
|
(user_id,)
|
||||||
|
)
|
||||||
|
sessions = cur.fetchall()
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"id": row[0],
|
||||||
|
"title": row[1],
|
||||||
|
"updated_at": row[2]
|
||||||
|
}
|
||||||
|
for row in sessions
|
||||||
|
]
|
||||||
@@ -0,0 +1,27 @@
|
|||||||
|
|
||||||
|
from config.llm import llm
|
||||||
|
from langchain.prompts import PromptTemplate
|
||||||
|
|
||||||
|
chatPrompt = PromptTemplate(
|
||||||
|
input_variables=["aiRole", "history", "userInput"],
|
||||||
|
template = """
|
||||||
|
你是一个人,用户画像为:{aiRole}。
|
||||||
|
你需要基于你的角色性格,使用中文回答用户。
|
||||||
|
|
||||||
|
聊天历史:
|
||||||
|
{history}
|
||||||
|
|
||||||
|
用户最新输入:
|
||||||
|
{userInput}
|
||||||
|
|
||||||
|
最后,请注意,不要编造数据,不知道就说不知道,现在,请生成你的回复:
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
chatChain = chatPrompt | llm
|
||||||
|
|
||||||
|
def get_chat_response(aiRole: str,history: str, userInput: str) -> str:
|
||||||
|
return chatChain.invoke({
|
||||||
|
"aiRole": aiRole,
|
||||||
|
"history": history,
|
||||||
|
"userInput": userInput
|
||||||
|
})
|
||||||
@@ -0,0 +1,139 @@
|
|||||||
|
|
||||||
|
from langchain.prompts import PromptTemplate
|
||||||
|
from config.llm import llm
|
||||||
|
from config.ssDb import ssDB
|
||||||
|
from typing import Annotated
|
||||||
|
from typing_extensions import TypedDict
|
||||||
|
from langgraph.graph import StateGraph, START, END
|
||||||
|
from langgraph.graph.message import add_messages
|
||||||
|
from langchain_community.agent_toolkits import create_sql_agent
|
||||||
|
from langchain.prompts import PromptTemplate
|
||||||
|
from config.llm import llm
|
||||||
|
from config.ssDb import ssDB
|
||||||
|
from typing import Annotated
|
||||||
|
from langgraph.graph.message import add_messages
|
||||||
|
import os
|
||||||
|
from langchain_tavily import TavilySearch
|
||||||
|
from langgraph.prebuilt import ToolNode, tools_condition
|
||||||
|
from llm.chatLLM import get_chat_response
|
||||||
|
from typing import TypedDict
|
||||||
|
from langgraph.graph import StateGraph, END
|
||||||
|
from llm.summarizeLLM import getSummary
|
||||||
|
|
||||||
|
|
||||||
|
# -------- 定义状态 --------
|
||||||
|
class State(TypedDict):
|
||||||
|
userInput: str # 用户输入
|
||||||
|
source: str # 选择的数据来源:web 或 db 或 chat
|
||||||
|
infomation: str # 查询到的内容
|
||||||
|
aiRole: str # AI 角色
|
||||||
|
history: str # 聊天历史
|
||||||
|
reply: str # 最终回复
|
||||||
|
|
||||||
|
# -------- 定义节点 --------
|
||||||
|
# ------------------------------------------------------------------------ 路径选择 --------
|
||||||
|
|
||||||
|
pathSelectPrompt = PromptTemplate(
|
||||||
|
input_variables=["aiRole", "history", "userStr", "infomation"],
|
||||||
|
template = """
|
||||||
|
你是主干信息科技有限公司的业务员,是一家蚕桑服务公司,现在需要根据用户输入来判断应该使用哪种方式来回答用户的问题。
|
||||||
|
你有三种选择:
|
||||||
|
1. 如果用户的问题涉及最新的信息,比如新闻、事件、天气等涉及时间的内容时,请选择 "web
|
||||||
|
2. 如果用户的问题涉及具体的蚕桑业务(例如询问农户、订单、订种、租户)的数据库查询需求,请选择 "db"
|
||||||
|
3. 如果用户的问题是一般性的聊天或咨询,请选择 "chat"
|
||||||
|
请只返回 "web"、"db" 或 "chat" 之一,且不要添加任何其他解释。
|
||||||
|
用户最新输入:
|
||||||
|
{userStr}
|
||||||
|
请做出你的选择:
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
pathSelectChain = pathSelectPrompt | llm
|
||||||
|
|
||||||
|
def decide_source(state: State, max_retry=3):
|
||||||
|
print("根据用户输入选择数据来源,用户输入:", state["userInput"])
|
||||||
|
"""根据用户输入选择数据来源"""
|
||||||
|
for _ in range(max_retry):
|
||||||
|
choice = pathSelectChain.invoke({
|
||||||
|
"aiRole": state["aiRole"],
|
||||||
|
"history": state["history"],
|
||||||
|
"userStr": state["userInput"],
|
||||||
|
}).content.strip().lower()
|
||||||
|
if choice in ["web", "db", "chat"]:
|
||||||
|
state["source"] = choice
|
||||||
|
break
|
||||||
|
else:
|
||||||
|
# 如果连续 max_retry 次都不合法,默认走 chat
|
||||||
|
state["source"] = "chat"
|
||||||
|
print("选择的数据来源是:", state["source"])
|
||||||
|
return state
|
||||||
|
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------------ 上网查询 --------
|
||||||
|
os.environ["TAVILY_API_KEY"] = "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"
|
||||||
|
tool = TavilySearch(max_results=2)
|
||||||
|
|
||||||
|
def fetch_web(state: State):
|
||||||
|
result = tool.invoke(state["userInput"])
|
||||||
|
state["infomation"] = result.get("content") or result
|
||||||
|
print("调用了联网工具,结果是:", state["infomation"])
|
||||||
|
return state
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------------ 数据库查询 --------
|
||||||
|
agent = create_sql_agent(
|
||||||
|
llm=llm,
|
||||||
|
db=ssDB,
|
||||||
|
agent_type="tool-calling",
|
||||||
|
verbose=True
|
||||||
|
)
|
||||||
|
def fetch_db(state: State):
|
||||||
|
state["infomation"] = agent.invoke({"input": state["userInput"]})["output"]
|
||||||
|
print("调用了数据库工具,结果是:", state["infomation"])
|
||||||
|
return state
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------------ 整理结果 --------
|
||||||
|
def summarize_ai(state: State):
|
||||||
|
"""AI 总结输出"""
|
||||||
|
state["reply"] = getSummary(aiRole=state["aiRole"], history=state["history"], userInput= state["userInput"], infomation= state["infomation"])
|
||||||
|
return state
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------------ 普通聊天 --------
|
||||||
|
def chat(state: State):
|
||||||
|
state["reply"] = get_chat_response(aiRole=state["aiRole"],history=state["history"], userInput= state["userInput"]).content
|
||||||
|
print("直接回复")
|
||||||
|
return state
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------------ 构建有向图 --------
|
||||||
|
workflow = StateGraph(State)
|
||||||
|
workflow.add_node("decide", decide_source)
|
||||||
|
workflow.add_node("fetch_web", fetch_web)
|
||||||
|
workflow.add_node("fetch_db", fetch_db)
|
||||||
|
workflow.add_node("chat", chat)
|
||||||
|
workflow.add_node("summarize", summarize_ai)
|
||||||
|
workflow.set_entry_point("decide")
|
||||||
|
|
||||||
|
# 两条路径最后都汇合到 summarize
|
||||||
|
workflow.add_edge(START, "decide")
|
||||||
|
workflow.add_edge("fetch_web", "summarize")
|
||||||
|
workflow.add_edge("fetch_db", "summarize")
|
||||||
|
# 条件边:根据 source 决定走向
|
||||||
|
workflow.add_conditional_edges(
|
||||||
|
"decide",
|
||||||
|
lambda state: state["source"], # 返回 state["source"] 的值
|
||||||
|
{
|
||||||
|
"web": "fetch_web",
|
||||||
|
"chat": "chat",
|
||||||
|
"db": "fetch_db"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
workflow.add_edge("summarize", END)
|
||||||
|
workflow.add_edge("chat", END)
|
||||||
|
graph = workflow.compile()
|
||||||
|
|
||||||
|
# 执行函数
|
||||||
|
def get_graph_output(aiRole:str,history: str, userInput: str) -> str:
|
||||||
|
final_state = graph.invoke({
|
||||||
|
"aiRole":aiRole,
|
||||||
|
"history": history,
|
||||||
|
"userInput": userInput,
|
||||||
|
})
|
||||||
|
return final_state["reply"]
|
||||||
@@ -0,0 +1,34 @@
|
|||||||
|
|
||||||
|
from langchain.prompts import PromptTemplate
|
||||||
|
from config.llm import llm
|
||||||
|
|
||||||
|
summarizePrompt = PromptTemplate(
|
||||||
|
input_variables=["aiRole", "history", "userStr", "infomation"],
|
||||||
|
template = """
|
||||||
|
你是一个主干信息研发的 AI 助手,用户画像为:{aiRole}。
|
||||||
|
请基于你的角色性格,保持中文简洁回答的,根据下方提示回答用户。
|
||||||
|
|
||||||
|
聊天历史:
|
||||||
|
···
|
||||||
|
{history}
|
||||||
|
···
|
||||||
|
用户最新输入:
|
||||||
|
···
|
||||||
|
{userStr}
|
||||||
|
···
|
||||||
|
给你的参考内容:
|
||||||
|
···
|
||||||
|
{infomation}
|
||||||
|
···
|
||||||
|
如果参考内容明显有问题,你要请用户重新描述问题,现在请生成你的回复:
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
summarizeChain = summarizePrompt | llm
|
||||||
|
|
||||||
|
def getSummary(aiRole: str, history: str, userInput: str, infomation: str) -> str:
|
||||||
|
return summarizeChain.invoke({
|
||||||
|
"aiRole":aiRole,
|
||||||
|
"history": history,
|
||||||
|
"userStr": userInput,
|
||||||
|
"infomation": infomation
|
||||||
|
}).content
|
||||||
@@ -0,0 +1,19 @@
|
|||||||
|
|
||||||
|
from langchain.prompts import PromptTemplate
|
||||||
|
from config.llm import llm
|
||||||
|
|
||||||
|
titlePrompt = PromptTemplate(
|
||||||
|
input_variables=["userStr"],
|
||||||
|
template = """
|
||||||
|
请将用户的这句话总结成一个简短、精准的对话标题,要求:
|
||||||
|
1. 不超过10个字(可根据需要调整长度)。
|
||||||
|
2. 直接概括本次对话的核心内容。
|
||||||
|
3. 避免使用笼统或无意义的词语,如“讨论”、“聊天”等。
|
||||||
|
4. 保持自然、易懂、专业或有趣(可根据场景调整风格)。
|
||||||
|
用户原话:"{userStr}"
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
titleChain = titlePrompt | llm
|
||||||
|
|
||||||
|
def get_title(userInput: str):
|
||||||
|
return titleChain.invoke({"userStr": userInput}).content
|
||||||
@@ -0,0 +1,11 @@
|
|||||||
|
from fastapi import FastAPI
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from typing import Generic, TypeVar, Optional, List
|
||||||
|
from pydantic.generics import GenericModel
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
# 定义通用响应结构
|
||||||
|
class BaseResponse(GenericModel, Generic[T]):
|
||||||
|
status: bool = True
|
||||||
|
message: str = "操作成功"
|
||||||
|
data: Optional[T] = None
|
||||||
@@ -0,0 +1,6 @@
|
|||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
class ChatRequest(BaseModel):
|
||||||
|
aiId: str
|
||||||
|
sessionId: str | None = None
|
||||||
|
userInput: str
|
||||||
@@ -0,0 +1,63 @@
|
|||||||
|
from models.ChatRequest import ChatRequest
|
||||||
|
from models.BaseResponse import BaseResponse
|
||||||
|
import uuid
|
||||||
|
import db.postgres as db
|
||||||
|
import uuid
|
||||||
|
import threading
|
||||||
|
from fastapi import APIRouter, Depends
|
||||||
|
from uuid import UUID
|
||||||
|
from config.security import get_user_id_from_token
|
||||||
|
router = APIRouter()
|
||||||
|
from llm.chatLLM import get_chat_response
|
||||||
|
from llm.titleChain import get_title
|
||||||
|
from llm.dataLLM import get_graph_output
|
||||||
|
|
||||||
|
# def async_db_task(func, *args, **kwargs):
|
||||||
|
# """将数据库操作放到后台线程执行"""
|
||||||
|
# threading.Thread(target=func, args=args, kwargs=kwargs, daemon=True).start()
|
||||||
|
|
||||||
|
@router.get("/history")
|
||||||
|
def getHistory(sessionId: str):
|
||||||
|
return BaseResponse(data=db.get_history(sessionId))
|
||||||
|
|
||||||
|
@router.get("/aiList")
|
||||||
|
def getAiList(user_id: UUID = Depends(get_user_id_from_token)):
|
||||||
|
if not user_id:
|
||||||
|
return {"error": "userId is required"}
|
||||||
|
return BaseResponse(data=db.get_all_ai(user_id))
|
||||||
|
|
||||||
|
@router.get("/sessions")
|
||||||
|
def getSessions(user_id: UUID = Depends(get_user_id_from_token)):
|
||||||
|
if not user_id:
|
||||||
|
return {"error": "userId is required"}
|
||||||
|
return BaseResponse(data=db.get_sessions(user_id))
|
||||||
|
|
||||||
|
@router.post("/chat")
|
||||||
|
def chat(req: ChatRequest, user_id: UUID = Depends(get_user_id_from_token)):
|
||||||
|
if not user_id:
|
||||||
|
return {"error": "userId is required"}
|
||||||
|
if not req.aiId:
|
||||||
|
return {"error": "aiId is required"}
|
||||||
|
sessionName = get_title(req.userInput)
|
||||||
|
# 如果没有 sessionId 就新建
|
||||||
|
if not req.sessionId:
|
||||||
|
isNewSession = True
|
||||||
|
req.sessionId = str(uuid.uuid4())
|
||||||
|
db.insert_session(user_id,req.aiId, req.sessionId, sessionName)
|
||||||
|
else:
|
||||||
|
isNewSession = False
|
||||||
|
db.update_session_updated_at(req.sessionId)
|
||||||
|
|
||||||
|
# 插入用户消息
|
||||||
|
db.insert_message(req.sessionId, False, req.userInput)
|
||||||
|
|
||||||
|
# 调用 LLM
|
||||||
|
if req.aiId == "9d157dd1-921b-c768-5b90-3e903b50f6f9":
|
||||||
|
# 数据专家AI
|
||||||
|
answer = get_graph_output(aiRole=db.get_ai_personality(req.aiId),history=db.get_history(req.sessionId), userInput= req.userInput)
|
||||||
|
else:
|
||||||
|
answer = get_chat_response(aiRole=db.get_ai_personality(req.aiId),history=db.get_history(req.sessionId), userInput= req.userInput).content
|
||||||
|
# 插入 AI 回复
|
||||||
|
db.insert_message(req.sessionId, True, answer)
|
||||||
|
|
||||||
|
return BaseResponse(data={"sessionName":sessionName,"isNewSession":isNewSession,"content":answer,"sessionId": req.sessionId})
|
||||||
@@ -0,0 +1,15 @@
|
|||||||
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def add(a: int, b: int) -> int:
|
||||||
|
"""Adds a and b."""
|
||||||
|
return a + b
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
def multiply(a: int, b: int) -> int:
|
||||||
|
"""Multiplies a and b."""
|
||||||
|
return a * b
|
||||||
|
|
||||||
|
all_tools = [add, multiply]
|
||||||
@@ -0,0 +1,15 @@
|
|||||||
|
fastapi==0.116.1
|
||||||
|
langchain==0.3.27
|
||||||
|
langchain_community==0.3.29
|
||||||
|
langchain_core==0.3.75
|
||||||
|
langchain_postgres==0.0.15
|
||||||
|
langchain_tavily==0.2.11
|
||||||
|
langgraph==0.6.6
|
||||||
|
psycopg==3.2.9
|
||||||
|
psycopg_pool==3.2.6
|
||||||
|
pydantic==2.11.7
|
||||||
|
PyJWT==2.10.1
|
||||||
|
typing_extensions==4.15.0
|
||||||
|
uvicorn[standard]
|
||||||
|
pyodbc==5.2.0
|
||||||
|
dashscope==1.24.2
|
||||||
@@ -0,0 +1,7 @@
|
|||||||
|
from sqlalchemy import create_engine, text # ✅ 注意这里导入 text
|
||||||
|
uri = "mssql+pyodbc://f8_db_test:APN^QPr!K9@122.114.58.23/f8_db_test?driver=ODBC+Driver+18+for+SQL+Server&TrustServerCertificate=yes"
|
||||||
|
|
||||||
|
engine = create_engine(uri)
|
||||||
|
with engine.connect() as conn:
|
||||||
|
result = conn.execute(text("SELECT count(*) FROM dbo.NONGHU_INFO"))
|
||||||
|
print(result.scalar())
|
||||||
@@ -0,0 +1,41 @@
|
|||||||
|
from langchain_community.chat_models.tongyi import ChatTongyi
|
||||||
|
from langchain_postgres import PostgresChatMessageHistory
|
||||||
|
from langchain.chains import LLMChain
|
||||||
|
from langchain.prompts import PromptTemplate
|
||||||
|
import uuid
|
||||||
|
import psycopg
|
||||||
|
from utils.tools import all_tools
|
||||||
|
# ------------------ 配置 PostgreSQL 聊天记录 ------------------
|
||||||
|
conn = sync_connection=psycopg.connect("postgresql://postgres:123456@10.10.10.9/ktor2")
|
||||||
|
database_name = "ai_chat_history"
|
||||||
|
PostgresChatMessageHistory.create_tables(conn, database_name)
|
||||||
|
|
||||||
|
history = PostgresChatMessageHistory(
|
||||||
|
database_name,
|
||||||
|
str(uuid.uuid4()), # session_id
|
||||||
|
sync_connection=conn
|
||||||
|
)
|
||||||
|
# ------------------ 配置 LLM ------------------
|
||||||
|
prompt = PromptTemplate(
|
||||||
|
input_variables=["question"],
|
||||||
|
template="""
|
||||||
|
请基于上下文,保持中文简洁回答用户:
|
||||||
|
上下文:{history},
|
||||||
|
用户:{userStr}
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
llm = ChatTongyi(streaming=False, api_key="sk-fb46eefb6b404382a0a5325202e923a6")
|
||||||
|
llm_with_tools = llm.bind_tools(all_tools)
|
||||||
|
# 建立链
|
||||||
|
chain = prompt | llm_with_tools
|
||||||
|
# chain = prompt | llm
|
||||||
|
|
||||||
|
# ------------------ 循环聊天 ------------------
|
||||||
|
while True:
|
||||||
|
userStr = input("用户: ")
|
||||||
|
history.add_user_message(userStr)
|
||||||
|
answer = chain.invoke({"history":history.messages,"userStr": userStr})
|
||||||
|
print("AI:", answer.content)
|
||||||
|
history.add_ai_message(answer)
|
||||||
|
|
||||||
|
|
||||||
Reference in New Issue
Block a user