from langchain_community.agent_toolkits import create_sql_agent from langchain_core.prompts import PromptTemplate from config.llm import llm from config.ssDb import ssDBLC # ______________________________________________________________SQL描述_____________________________________________________________________ sqlDescriptionPrompt = PromptTemplate( input_variables=["sql"], template=""" 你是一个SQL专家,精通SQLServer数据库。请把一下SQL查询语句用通俗易懂的中文进行总结。 SQL语句:{sql} 有以下要求: 1. 不要任何解释 2. 不能有标点符号 3. 不能有markdown语法 4. 要用业务语言描述,不能有专业语句例如SQL表名等 请生成你认为合适的标题,: """, ) sqlDescriptionChain = sqlDescriptionPrompt | llm def get_sql_description_response(sql: str) -> str: return sqlDescriptionChain.invoke({"sql": sql}) # ______________________________________________________________第一次生成SQL_____________________________________________________________________ sqlPrompt = PromptTemplate( input_variables=["userInput"], template=""" 你是一个SQL专家,精通SQLServer数据库。 请根据用户的需求,生成相应的SQL查询语句。 只需要返回SQL语句,不要任何解释。 用户需求:{userInput} 请生成SQL语句: """, ) sqlChain = sqlPrompt | llm agent = create_sql_agent(llm=llm, db=ssDBLC, agent_type="tool-calling", verbose=True) # def get_chat_sql_response2( userInput: str) -> str: # return sqlChain.invoke({ # "userInput": userInput # }) def get_chat_sql_response(userInput: str) -> str: return agent.invoke({"input": userInput})["output"] # ______________________________________________________________改进SQL_____________________________________________________________________ sqlImprovePrompt = PromptTemplate( input_variables=["userInput", "sql"], template=""" 你是一个SQL专家,精通SQLServer数据库。 请根据用户的需求,改进已有的SQL查询语句。 只需要返回改进后的SQL语句,不要任何解释。 已有SQL:{sql} 用户需求:{userInput} """, ) sqlImproveChain = sqlImprovePrompt | llm def get_chat_sql_improve_response(userInput: str) -> str: return sqlImproveChain.invoke({"userInput": userInput})