Files
2025-12-31 17:49:17 +08:00

67 lines
2.4 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from langchain_community.agent_toolkits import create_sql_agent
from langchain_core.prompts import PromptTemplate
from config.llm import llm
from config.ssDb import ssDBLC
# ______________________________________________________________SQL描述_____________________________________________________________________
sqlDescriptionPrompt = PromptTemplate(
input_variables=["sql"],
template="""
你是一个SQL专家,精通SQLServer数据库。请把一下SQL查询语句用通俗易懂的中文进行总结。
SQL语句:{sql}
有以下要求:
1. 不要任何解释
2. 不能有标点符号
3. 不能有markdown语法
4. 要用业务语言描述,不能有专业语句例如SQL表名等
请生成你认为合适的标题,:
""",
)
sqlDescriptionChain = sqlDescriptionPrompt | llm
def get_sql_description_response(sql: str) -> str:
return sqlDescriptionChain.invoke({"sql": sql})
# ______________________________________________________________第一次生成SQL_____________________________________________________________________
sqlPrompt = PromptTemplate(
input_variables=["userInput"],
template="""
你是一个SQL专家,精通SQLServer数据库。
请根据用户的需求,生成相应的SQL查询语句。
只需要返回SQL语句,不要任何解释。
用户需求:{userInput}
请生成SQL语句:
""",
)
sqlChain = sqlPrompt | llm
agent = create_sql_agent(llm=llm, db=ssDBLC, agent_type="tool-calling", verbose=True)
# def get_chat_sql_response2( userInput: str) -> str:
# return sqlChain.invoke({
# "userInput": userInput
# })
def get_chat_sql_response(userInput: str) -> str:
return agent.invoke({"input": userInput})["output"]
# ______________________________________________________________改进SQL_____________________________________________________________________
sqlImprovePrompt = PromptTemplate(
input_variables=["userInput", "sql"],
template="""
你是一个SQL专家,精通SQLServer数据库。
请根据用户的需求,改进已有的SQL查询语句。
只需要返回改进后的SQL语句,不要任何解释。
已有SQL{sql}
用户需求:{userInput}
""",
)
sqlImproveChain = sqlImprovePrompt | llm
def get_chat_sql_improve_response(userInput: str) -> str:
return sqlImproveChain.invoke({"userInput": userInput})