from typing import List from typing import TypedDict from langchain_core.prompts import PromptTemplate from langgraph.graph import StateGraph, END import db.milvus as milvus import db.postgres as pgdb from config.llm import llm from config.llm import llmThink # -------- 定义状态 -------- class State(TypedDict): path: str # 开始聊天选择的路径 memory: str # 记忆 knowledge: str # 知识库内容 history: str # 聊天历史 ai_id: str # AI id ai_name: str # AI 名称 ai_service: str # AI 角色 业务 ai_role: str # AI 角色 性格特点 kn_bases: List[str] # AI 所使用的知识库 userInput: str # 用户输入 reply: str # 最终回复 # -------- 定义节点 -------- # ------------------------------------------------------------------------ 向量数据库查询 -------- gen_sql_prompt = PromptTemplate( input_variables=["userInput"], template="""你的任务是对用户输入进行意图分析,并将其分解成方便进行知识向量数据库搜索的关键词。 以下是用户的输入: <用户输入> {userInput} 在提取关键词时,请遵循以下方法和要求: 1. 去除输入中的停用词(如“的”“是”“在”等)、语气词和无实际意义的符号。 2. 识别输入中的核心概念、实体和关键动作。 3. 尽量使用简洁、通用的词汇作为关键词。 4. 确保关键词之间相互独立,不包含其他关键词。 关键词之间用空格分隔。 你的回答是: """, ) sqlChain = gen_sql_prompt | llm def db_search(state: State): key_words = sqlChain.invoke( { "userInput": state["userInput"], } ).content print("关键词是:", key_words) knowledge = milvus.get_knowledge_by_key_words(key_words, state["kn_bases"]) print("知识库内容是:", knowledge) state["knowledge"] = knowledge ai_ids = [state["ai_id"]] memory = milvus.get_memory_by_key_words(key_words, ai_ids) print("记忆是:", memory) state["memory"] = memory return state # ------------------------------------------------------------------------ 意图分析 -------- pathSelectPrompt = PromptTemplate( input_variables=["userInput", "ai_service", "history"], template=""" 你是一个意图分类器,负责判断用户提问是否与你的工作相关,进而确定是否需要去查知识库。 以下是你负责的工作内容: {ai_service} 这是你们的对话历史: {history} 用户最新回复是: {userInput} 判断规则如下: 如果用户最新回复与你的负责工作相关,需要去查知识库,输出“kn”;如果不相关,则输出“chat”,不要包含任何标点符号以及空格。 你生成的结果: """, ) pathSelectChain = pathSelectPrompt | llmThink def decide_source(state: State, max_retry=3): """根据用户输入选择数据来源""" for _ in range(max_retry): choice = ( pathSelectChain.invoke( { "userInput": state["userInput"], "ai_service": state["ai_service"], "history": state["history"], } ) .content.strip() .lower() ) print("根据用户输入选择数据来源,路径是:", choice) if choice in ["kn", "chat"]: state["path"] = choice break else: # 如果连续 max_retry 次都不合法,默认走 chat state["path"] = "chat" return state # ------------------------------------------------------------------------ !普通聊天 -------- noChatPrompt = PromptTemplate( input_variables=["ai_name", "ai_service", "ai_role", "history"], template=""" 你的名字是:{ai_name},你负责的业务是{ai_service},你具有{ai_role}的性格特点。 这是你和用户的对话历史 {history} 在回复用户时,请遵循以下指南: 1. 回复要与AI角色业务相关,体现AI的专业能力。 2. 回复内容的语气和风格要符合AI角色性格特点。 3. 参考聊天历史,使回复具有连贯性和针对性。 4. 回复要简洁明了,避免冗长和复杂的表述。 你的回答: """, ) noChatChain = noChatPrompt | llm def chat(state: State): state["reply"] = noChatChain.invoke( { "ai_name": state["ai_name"], "ai_service": state["ai_service"], "ai_role": state["ai_role"], "history": state["history"], "userStr": state["userInput"], } ).content print("直接回复") return state # ------------------------------------------------------------------------ 整理结果 -------- summarizePrompt = PromptTemplate( input_variables=["ai_name", "ai_service", "ai_role", "history", "knowledge"], template=""" 你的任务是基于给定的AI名称、AI角色业务、AI角色性格特点和聊天历史来回复用户。请仔细阅读以下信息,并按照指示进行回复。 你的名字是:{ai_name},你负责的业务是{ai_service},你具有{ai_role}的性格特点。 这是你和用户的对话历史 {history} 这是给你参考的知识库: {knowledge} {memory} 在回复时,请遵循以下指南: 1. 回复内容要与你负责的业务相关。 2. 回复的语气要结合你的性格特点。 3. 确保回复内容清晰、简洁、有针对性。 请生成你的回复: """, ) summarizeChain = summarizePrompt | llm def summarize_ai(state: State): """AI 总结输出""" mem = state["memory"] if mem != "": memStr = ( """ 这是给你参考的相关历史记忆: %s """ % mem ) # 这里用 % 把 mem 填进去 else: memStr = "没有记忆内容" print("历史记录是:", state["history"]) state["reply"] = summarizeChain.invoke( { "ai_role": state["ai_role"], "ai_name": state["ai_name"], "history": state["history"], "ai_service": state["ai_service"], "knowledge": state["knowledge"], "memory": memStr, } ).content return state # ------------------------------------------------------------------------ 构建有向图 -------- workflow = StateGraph(State) workflow.add_node("decide", decide_source) workflow.add_node("db_search", db_search) workflow.add_node("chat", chat) workflow.add_node("summarize", summarize_ai) workflow.set_entry_point("decide") # 条件边:根据 path 决定走向 workflow.add_conditional_edges( "decide", lambda state: state["path"], # 返回 state["path"] 的值 { "kn": "db_search", "chat": "chat", }, ) workflow.add_edge("db_search", "summarize") workflow.add_edge("summarize", END) workflow.add_edge("chat", END) graph = workflow.compile() # 执行函数 def get_service_agent_reply( aiId: str, userInput: str, history: str, kn_bases: List[str] ): json = pgdb.get_ai_personality(aiId) ai_service = json["业务"] ai_role = json["性格"] ai_name = json["名字"] print("AI Name:", ai_name) print("AI Service:", ai_service) final_state = graph.invoke( { "ai_service": ai_service, "ai_role": ai_role, "ai_name": ai_name, "history": history, "kn_bases": kn_bases, "table_info": pgdb.get_available_tables_str(aiId), "userInput": userInput, "ai_id": aiId, } ) return final_state["reply"]