仿生人AI服务端
This commit is contained in:
@@ -0,0 +1,16 @@
|
||||
import json
|
||||
|
||||
TAG = __name__
|
||||
|
||||
|
||||
async def handleAbortMessage(conn):
|
||||
conn.logger.bind(tag=TAG).info("Abort message received")
|
||||
# 设置成打断状态,会自动打断llm、tts任务
|
||||
conn.client_abort = True
|
||||
conn.clear_queues()
|
||||
# 打断客户端说话状态
|
||||
await conn.websocket.send(
|
||||
json.dumps({"type": "tts", "state": "stop", "session_id": conn.session_id})
|
||||
)
|
||||
conn.clearSpeakStatus()
|
||||
conn.logger.bind(tag=TAG).info("Abort message received-end")
|
||||
@@ -0,0 +1,153 @@
|
||||
import time
|
||||
import json
|
||||
import random
|
||||
import asyncio
|
||||
from core.utils.dialogue import Message
|
||||
from core.utils.util import audio_to_data
|
||||
from core.providers.tts.dto.dto import SentenceType
|
||||
from core.utils.wakeup_word import WakeupWordsConfig
|
||||
from core.handle.sendAudioHandle import sendAudioMessage, send_tts_message
|
||||
from core.utils.util import remove_punctuation_and_length, opus_datas_to_wav_bytes
|
||||
from core.providers.tools.device_mcp import (
|
||||
MCPClient,
|
||||
send_mcp_initialize_message,
|
||||
send_mcp_tools_list_request,
|
||||
)
|
||||
|
||||
TAG = __name__
|
||||
|
||||
WAKEUP_CONFIG = {
|
||||
"refresh_time": 10,
|
||||
"responses": [
|
||||
"我一直都在呢,您请说。",
|
||||
"在的呢,请随时吩咐我。",
|
||||
"来啦来啦,请告诉我吧。",
|
||||
"您请说,我正听着。",
|
||||
"请您讲话,我准备好了。",
|
||||
"请您说出指令吧。",
|
||||
"我认真听着呢,请讲。",
|
||||
"请问您需要什么帮助?",
|
||||
"我在这里,等候您的指令。",
|
||||
],
|
||||
}
|
||||
|
||||
# 创建全局的唤醒词配置管理器
|
||||
wakeup_words_config = WakeupWordsConfig()
|
||||
|
||||
# 用于防止并发调用wakeupWordsResponse的锁
|
||||
_wakeup_response_lock = asyncio.Lock()
|
||||
|
||||
|
||||
async def handleHelloMessage(conn, msg_json):
|
||||
"""处理hello消息"""
|
||||
audio_params = msg_json.get("audio_params")
|
||||
if audio_params:
|
||||
format = audio_params.get("format")
|
||||
conn.logger.bind(tag=TAG).info(f"客户端音频格式: {format}")
|
||||
conn.audio_format = format
|
||||
conn.welcome_msg["audio_params"] = audio_params
|
||||
features = msg_json.get("features")
|
||||
if features:
|
||||
conn.logger.bind(tag=TAG).info(f"客户端特性: {features}")
|
||||
conn.features = features
|
||||
if features.get("mcp"):
|
||||
conn.logger.bind(tag=TAG).info("客户端支持MCP")
|
||||
conn.mcp_client = MCPClient()
|
||||
# 发送初始化
|
||||
asyncio.create_task(send_mcp_initialize_message(conn))
|
||||
# 发送mcp消息,获取tools列表
|
||||
asyncio.create_task(send_mcp_tools_list_request(conn))
|
||||
|
||||
await conn.websocket.send(json.dumps(conn.welcome_msg))
|
||||
|
||||
|
||||
async def checkWakeupWords(conn, text):
|
||||
enable_wakeup_words_response_cache = conn.config[
|
||||
"enable_wakeup_words_response_cache"
|
||||
]
|
||||
|
||||
# 等待tts初始化,最多等待3秒
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < 3:
|
||||
if conn.tts:
|
||||
break
|
||||
await asyncio.sleep(0.1)
|
||||
else:
|
||||
return False
|
||||
|
||||
if not enable_wakeup_words_response_cache:
|
||||
return False
|
||||
|
||||
_, filtered_text = remove_punctuation_and_length(text)
|
||||
if filtered_text not in conn.config.get("wakeup_words"):
|
||||
return False
|
||||
|
||||
conn.just_woken_up = True
|
||||
await send_tts_message(conn, "start")
|
||||
|
||||
# 获取当前音色
|
||||
voice = getattr(conn.tts, "voice", "default")
|
||||
if not voice:
|
||||
voice = "default"
|
||||
|
||||
# 获取唤醒词回复配置
|
||||
response = wakeup_words_config.get_wakeup_response(voice)
|
||||
if not response or not response.get("file_path"):
|
||||
response = {
|
||||
"voice": "default",
|
||||
"file_path": "config/assets/wakeup_words_short.wav",
|
||||
"time": 0,
|
||||
"text": "我在这里哦!",
|
||||
}
|
||||
|
||||
# 获取音频数据
|
||||
opus_packets = audio_to_data(response.get("file_path"))
|
||||
# 播放唤醒词回复
|
||||
conn.client_abort = False
|
||||
|
||||
conn.logger.bind(tag=TAG).info(f"播放唤醒词回复: {response.get('text')}")
|
||||
await sendAudioMessage(conn, SentenceType.FIRST, opus_packets, response.get("text"))
|
||||
await sendAudioMessage(conn, SentenceType.LAST, [], None)
|
||||
|
||||
# 补充对话
|
||||
conn.dialogue.put(Message(role="assistant", content=response.get("text")))
|
||||
|
||||
# 检查是否需要更新唤醒词回复
|
||||
if time.time() - response.get("time", 0) > WAKEUP_CONFIG["refresh_time"]:
|
||||
if not _wakeup_response_lock.locked():
|
||||
asyncio.create_task(wakeupWordsResponse(conn))
|
||||
return True
|
||||
|
||||
|
||||
async def wakeupWordsResponse(conn):
|
||||
if not conn.tts:
|
||||
return
|
||||
|
||||
try:
|
||||
# 尝试获取锁,如果获取不到就返回
|
||||
if not await _wakeup_response_lock.acquire():
|
||||
return
|
||||
|
||||
# 从预定义回复列表中随机选择一个回复
|
||||
result = random.choice(WAKEUP_CONFIG["responses"])
|
||||
if not result or len(result) == 0:
|
||||
return
|
||||
|
||||
# 生成TTS音频
|
||||
tts_result = await asyncio.to_thread(conn.tts.to_tts, result)
|
||||
if not tts_result:
|
||||
return
|
||||
|
||||
# 获取当前音色
|
||||
voice = getattr(conn.tts, "voice", "default")
|
||||
|
||||
wav_bytes = opus_datas_to_wav_bytes(tts_result, sample_rate=16000)
|
||||
file_path = wakeup_words_config.generate_file_path(voice)
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(wav_bytes)
|
||||
# 更新配置
|
||||
wakeup_words_config.update_wakeup_response(voice, file_path, result)
|
||||
finally:
|
||||
# 确保在任何情况下都释放锁
|
||||
if _wakeup_response_lock.locked():
|
||||
_wakeup_response_lock.release()
|
||||
@@ -0,0 +1,207 @@
|
||||
import json
|
||||
import uuid
|
||||
import asyncio
|
||||
from core.utils.dialogue import Message
|
||||
from core.providers.tts.dto.dto import ContentType
|
||||
from core.handle.helloHandle import checkWakeupWords
|
||||
from plugins_func.register import Action, ActionResponse
|
||||
from core.handle.sendAudioHandle import send_stt_message
|
||||
from core.utils.util import remove_punctuation_and_length
|
||||
from core.providers.tts.dto.dto import TTSMessageDTO, SentenceType
|
||||
|
||||
TAG = __name__
|
||||
|
||||
|
||||
async def handle_user_intent(conn, text):
|
||||
# 预处理输入文本,处理可能的JSON格式
|
||||
try:
|
||||
if text.strip().startswith('{') and text.strip().endswith('}'):
|
||||
parsed_data = json.loads(text)
|
||||
if isinstance(parsed_data, dict) and "content" in parsed_data:
|
||||
text = parsed_data["content"] # 提取content用于意图分析
|
||||
conn.current_speaker = parsed_data.get("speaker") # 保留说话人信息
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
# 检查是否有明确的退出命令
|
||||
_, filtered_text = remove_punctuation_and_length(text)
|
||||
if await check_direct_exit(conn, filtered_text):
|
||||
return True
|
||||
|
||||
# 检查是否是唤醒词
|
||||
if await checkWakeupWords(conn, filtered_text):
|
||||
return True
|
||||
|
||||
if conn.intent_type == "function_call":
|
||||
# 使用支持function calling的聊天方法,不再进行意图分析
|
||||
return False
|
||||
# 使用LLM进行意图分析
|
||||
intent_result = await analyze_intent_with_llm(conn, text)
|
||||
if not intent_result:
|
||||
return False
|
||||
# 会话开始时生成sentence_id
|
||||
conn.sentence_id = str(uuid.uuid4().hex)
|
||||
# 处理各种意图
|
||||
return await process_intent_result(conn, intent_result, text)
|
||||
|
||||
|
||||
async def check_direct_exit(conn, text):
|
||||
"""检查是否有明确的退出命令"""
|
||||
_, text = remove_punctuation_and_length(text)
|
||||
cmd_exit = conn.cmd_exit
|
||||
for cmd in cmd_exit:
|
||||
if text == cmd:
|
||||
conn.logger.bind(tag=TAG).info(f"识别到明确的退出命令: {text}")
|
||||
await send_stt_message(conn, text)
|
||||
await conn.close()
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
async def analyze_intent_with_llm(conn, text):
|
||||
"""使用LLM分析用户意图"""
|
||||
if not hasattr(conn, "intent") or not conn.intent:
|
||||
conn.logger.bind(tag=TAG).warning("意图识别服务未初始化")
|
||||
return None
|
||||
|
||||
# 对话历史记录
|
||||
dialogue = conn.dialogue
|
||||
try:
|
||||
intent_result = await conn.intent.detect_intent(conn, dialogue.dialogue, text)
|
||||
return intent_result
|
||||
except Exception as e:
|
||||
conn.logger.bind(tag=TAG).error(f"意图识别失败: {str(e)}")
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def process_intent_result(conn, intent_result, original_text):
|
||||
"""处理意图识别结果"""
|
||||
try:
|
||||
# 尝试将结果解析为JSON
|
||||
intent_data = json.loads(intent_result)
|
||||
|
||||
# 检查是否有function_call
|
||||
if "function_call" in intent_data:
|
||||
# 直接从意图识别获取了function_call
|
||||
conn.logger.bind(tag=TAG).debug(
|
||||
f"检测到function_call格式的意图结果: {intent_data['function_call']['name']}"
|
||||
)
|
||||
function_name = intent_data["function_call"]["name"]
|
||||
if function_name == "continue_chat":
|
||||
return False
|
||||
|
||||
if function_name == "result_for_context":
|
||||
await send_stt_message(conn, original_text)
|
||||
conn.client_abort = False
|
||||
|
||||
def process_context_result():
|
||||
conn.dialogue.put(Message(role="user", content=original_text))
|
||||
|
||||
from core.utils.current_time import get_current_time_info
|
||||
|
||||
current_time, today_date, today_weekday, lunar_date = get_current_time_info()
|
||||
|
||||
# 构建带上下文的基础提示
|
||||
context_prompt = f"""当前时间:{current_time}
|
||||
今天日期:{today_date} ({today_weekday})
|
||||
今天农历:{lunar_date}
|
||||
|
||||
请根据以上信息回答用户的问题:{original_text}"""
|
||||
|
||||
response = conn.intent.replyResult(context_prompt, original_text)
|
||||
speak_txt(conn, response)
|
||||
|
||||
conn.executor.submit(process_context_result)
|
||||
return True
|
||||
|
||||
function_args = {}
|
||||
if "arguments" in intent_data["function_call"]:
|
||||
function_args = intent_data["function_call"]["arguments"]
|
||||
if function_args is None:
|
||||
function_args = {}
|
||||
# 确保参数是字符串格式的JSON
|
||||
if isinstance(function_args, dict):
|
||||
function_args = json.dumps(function_args)
|
||||
|
||||
function_call_data = {
|
||||
"name": function_name,
|
||||
"id": str(uuid.uuid4().hex),
|
||||
"arguments": function_args,
|
||||
}
|
||||
|
||||
await send_stt_message(conn, original_text)
|
||||
conn.client_abort = False
|
||||
|
||||
# 使用executor执行函数调用和结果处理
|
||||
def process_function_call():
|
||||
conn.dialogue.put(Message(role="user", content=original_text))
|
||||
|
||||
# 使用统一工具处理器处理所有工具调用
|
||||
try:
|
||||
result = asyncio.run_coroutine_threadsafe(
|
||||
conn.func_handler.handle_llm_function_call(
|
||||
conn, function_call_data
|
||||
),
|
||||
conn.loop,
|
||||
).result()
|
||||
except Exception as e:
|
||||
conn.logger.bind(tag=TAG).error(f"工具调用失败: {e}")
|
||||
result = ActionResponse(
|
||||
action=Action.ERROR, result=str(e), response=str(e)
|
||||
)
|
||||
|
||||
if result:
|
||||
if result.action == Action.RESPONSE: # 直接回复前端
|
||||
text = result.response
|
||||
if text is not None:
|
||||
speak_txt(conn, text)
|
||||
elif result.action == Action.REQLLM: # 调用函数后再请求llm生成回复
|
||||
text = result.result
|
||||
conn.dialogue.put(Message(role="tool", content=text))
|
||||
llm_result = conn.intent.replyResult(text, original_text)
|
||||
if llm_result is None:
|
||||
llm_result = text
|
||||
speak_txt(conn, llm_result)
|
||||
elif (
|
||||
result.action == Action.NOTFOUND
|
||||
or result.action == Action.ERROR
|
||||
):
|
||||
text = result.result
|
||||
if text is not None:
|
||||
speak_txt(conn, text)
|
||||
elif function_name != "play_music":
|
||||
# For backward compatibility with original code
|
||||
# 获取当前最新的文本索引
|
||||
text = result.response
|
||||
if text is None:
|
||||
text = result.result
|
||||
if text is not None:
|
||||
speak_txt(conn, text)
|
||||
|
||||
# 将函数执行放在线程池中
|
||||
conn.executor.submit(process_function_call)
|
||||
return True
|
||||
return False
|
||||
except json.JSONDecodeError as e:
|
||||
conn.logger.bind(tag=TAG).error(f"处理意图结果时出错: {e}")
|
||||
return False
|
||||
|
||||
|
||||
def speak_txt(conn, text):
|
||||
conn.tts.tts_text_queue.put(
|
||||
TTSMessageDTO(
|
||||
sentence_id=conn.sentence_id,
|
||||
sentence_type=SentenceType.FIRST,
|
||||
content_type=ContentType.ACTION,
|
||||
)
|
||||
)
|
||||
conn.tts.tts_one_sentence(conn, ContentType.TEXT, content_detail=text)
|
||||
conn.tts.tts_text_queue.put(
|
||||
TTSMessageDTO(
|
||||
sentence_id=conn.sentence_id,
|
||||
sentence_type=SentenceType.LAST,
|
||||
content_type=ContentType.ACTION,
|
||||
)
|
||||
)
|
||||
conn.dialogue.put(Message(role="assistant", content=text))
|
||||
@@ -0,0 +1,166 @@
|
||||
import time
|
||||
import json
|
||||
import asyncio
|
||||
from core.utils.util import audio_to_data
|
||||
from core.handle.abortHandle import handleAbortMessage
|
||||
from core.handle.intentHandler import handle_user_intent
|
||||
from core.utils.output_counter import check_device_output_limit
|
||||
from core.handle.sendAudioHandle import send_stt_message, SentenceType
|
||||
|
||||
TAG = __name__
|
||||
|
||||
|
||||
async def handleAudioMessage(conn, audio):
|
||||
# 当前片段是否有人说话
|
||||
have_voice = conn.vad.is_vad(conn, audio)
|
||||
# 如果设备刚刚被唤醒,短暂忽略VAD检测
|
||||
if hasattr(conn, "just_woken_up") and conn.just_woken_up:
|
||||
have_voice = False
|
||||
# 设置一个短暂延迟后恢复VAD检测
|
||||
conn.asr_audio.clear()
|
||||
if not hasattr(conn, "vad_resume_task") or conn.vad_resume_task.done():
|
||||
conn.vad_resume_task = asyncio.create_task(resume_vad_detection(conn))
|
||||
return
|
||||
# manual 模式下不打断正在播放的内容
|
||||
if have_voice:
|
||||
if conn.client_is_speaking and conn.client_listen_mode != "manual":
|
||||
await handleAbortMessage(conn)
|
||||
# 设备长时间空闲检测,用于say goodbye
|
||||
await no_voice_close_connect(conn, have_voice)
|
||||
# 接收音频
|
||||
await conn.asr.receive_audio(conn, audio, have_voice)
|
||||
|
||||
|
||||
async def resume_vad_detection(conn):
|
||||
# 等待2秒后恢复VAD检测
|
||||
await asyncio.sleep(2)
|
||||
conn.just_woken_up = False
|
||||
|
||||
|
||||
async def startToChat(conn, text):
|
||||
# 检查输入是否是JSON格式(包含说话人信息)
|
||||
speaker_name = None
|
||||
actual_text = text
|
||||
|
||||
try:
|
||||
# 尝试解析JSON格式的输入
|
||||
if text.strip().startswith("{") and text.strip().endswith("}"):
|
||||
data = json.loads(text)
|
||||
if "speaker" in data and "content" in data:
|
||||
speaker_name = data["speaker"]
|
||||
actual_text = data["content"]
|
||||
conn.logger.bind(tag=TAG).info(f"解析到说话人信息: {speaker_name}")
|
||||
|
||||
# 直接使用JSON格式的文本,不解析
|
||||
actual_text = text
|
||||
except (json.JSONDecodeError, KeyError):
|
||||
# 如果解析失败,继续使用原始文本
|
||||
pass
|
||||
|
||||
# 保存说话人信息到连接对象
|
||||
if speaker_name:
|
||||
conn.current_speaker = speaker_name
|
||||
else:
|
||||
conn.current_speaker = None
|
||||
|
||||
if conn.need_bind:
|
||||
await check_bind_device(conn)
|
||||
return
|
||||
|
||||
# 如果当日的输出字数大于限定的字数
|
||||
if conn.max_output_size > 0:
|
||||
if check_device_output_limit(
|
||||
conn.headers.get("device-id"), conn.max_output_size
|
||||
):
|
||||
await max_out_size(conn)
|
||||
return
|
||||
# manual 模式下不打断正在播放的内容
|
||||
if conn.client_is_speaking and conn.client_listen_mode != "manual":
|
||||
await handleAbortMessage(conn)
|
||||
|
||||
# 首先进行意图分析,使用实际文本内容
|
||||
intent_handled = await handle_user_intent(conn, actual_text)
|
||||
|
||||
if intent_handled:
|
||||
# 如果意图已被处理,不再进行聊天
|
||||
return
|
||||
|
||||
# 意图未被处理,继续常规聊天流程,使用实际文本内容
|
||||
await send_stt_message(conn, actual_text)
|
||||
conn.executor.submit(conn.chat, actual_text)
|
||||
|
||||
|
||||
async def no_voice_close_connect(conn, have_voice):
|
||||
if have_voice:
|
||||
conn.last_activity_time = time.time() * 1000
|
||||
return
|
||||
# 只有在已经初始化过时间戳的情况下才进行超时检查
|
||||
if conn.last_activity_time > 0.0:
|
||||
no_voice_time = time.time() * 1000 - conn.last_activity_time
|
||||
close_connection_no_voice_time = int(
|
||||
conn.config.get("close_connection_no_voice_time", 120)
|
||||
)
|
||||
if (
|
||||
not conn.close_after_chat
|
||||
and no_voice_time > 1000 * close_connection_no_voice_time
|
||||
):
|
||||
conn.close_after_chat = True
|
||||
conn.client_abort = False
|
||||
end_prompt = conn.config.get("end_prompt", {})
|
||||
if end_prompt and end_prompt.get("enable", True) is False:
|
||||
conn.logger.bind(tag=TAG).info("结束对话,无需发送结束提示语")
|
||||
await conn.close()
|
||||
return
|
||||
prompt = end_prompt.get("prompt")
|
||||
if not prompt:
|
||||
prompt = "请你以```时间过得真快```未来头,用富有感情、依依不舍的话来结束这场对话吧。!"
|
||||
await startToChat(conn, prompt)
|
||||
|
||||
|
||||
async def max_out_size(conn):
|
||||
# 播放超出最大输出字数的提示
|
||||
conn.client_abort = False
|
||||
text = "不好意思,我现在有点事情要忙,明天这个时候我们再聊,约好了哦!明天不见不散,拜拜!"
|
||||
await send_stt_message(conn, text)
|
||||
file_path = "config/assets/max_output_size.wav"
|
||||
opus_packets = audio_to_data(file_path)
|
||||
conn.tts.tts_audio_queue.put((SentenceType.LAST, opus_packets, text))
|
||||
conn.close_after_chat = True
|
||||
|
||||
|
||||
async def check_bind_device(conn):
|
||||
if conn.bind_code:
|
||||
# 确保bind_code是6位数字
|
||||
if len(conn.bind_code) != 6:
|
||||
conn.logger.bind(tag=TAG).error(f"无效的绑定码格式: {conn.bind_code}")
|
||||
text = "绑定码格式错误,请检查配置。"
|
||||
await send_stt_message(conn, text)
|
||||
return
|
||||
|
||||
text = f"请登录控制面板,输入{conn.bind_code},绑定设备。"
|
||||
await send_stt_message(conn, text)
|
||||
|
||||
# 播放提示音
|
||||
music_path = "config/assets/bind_code.wav"
|
||||
opus_packets = audio_to_data(music_path)
|
||||
conn.tts.tts_audio_queue.put((SentenceType.FIRST, opus_packets, text))
|
||||
|
||||
# 逐个播放数字
|
||||
for i in range(6): # 确保只播放6位数字
|
||||
try:
|
||||
digit = conn.bind_code[i]
|
||||
num_path = f"config/assets/bind_code/{digit}.wav"
|
||||
num_packets = audio_to_data(num_path)
|
||||
conn.tts.tts_audio_queue.put((SentenceType.MIDDLE, num_packets, None))
|
||||
except Exception as e:
|
||||
conn.logger.bind(tag=TAG).error(f"播放数字音频失败: {e}")
|
||||
continue
|
||||
conn.tts.tts_audio_queue.put((SentenceType.LAST, [], None))
|
||||
else:
|
||||
# 播放未绑定提示
|
||||
conn.client_abort = False
|
||||
text = f"没有找到该设备的版本信息,请正确配置 OTA地址,然后重新编译固件。"
|
||||
await send_stt_message(conn, text)
|
||||
music_path = "config/assets/bind_not_found.wav"
|
||||
opus_packets = audio_to_data(music_path)
|
||||
conn.tts.tts_audio_queue.put((SentenceType.LAST, opus_packets, text))
|
||||
@@ -0,0 +1,149 @@
|
||||
"""
|
||||
TTS上报功能已集成到ConnectionHandler类中。
|
||||
|
||||
上报功能包括:
|
||||
1. 每个连接对象拥有自己的上报队列和处理线程
|
||||
2. 上报线程的生命周期与连接对象绑定
|
||||
3. 使用ConnectionHandler.enqueue_tts_report方法进行上报
|
||||
|
||||
具体实现请参考core/connection.py中的相关代码。
|
||||
"""
|
||||
|
||||
import time
|
||||
|
||||
import opuslib_next
|
||||
|
||||
from config.manage_api_client import report as manage_report
|
||||
|
||||
TAG = __name__
|
||||
|
||||
|
||||
def report(conn, type, text, opus_data, report_time):
|
||||
"""执行聊天记录上报操作
|
||||
|
||||
Args:
|
||||
conn: 连接对象
|
||||
type: 上报类型,1为用户,2为智能体
|
||||
text: 合成文本
|
||||
opus_data: opus音频数据
|
||||
report_time: 上报时间
|
||||
"""
|
||||
try:
|
||||
if opus_data:
|
||||
audio_data = opus_to_wav(conn, opus_data)
|
||||
else:
|
||||
audio_data = None
|
||||
# 执行上报
|
||||
manage_report(
|
||||
mac_address=conn.device_id,
|
||||
session_id=conn.session_id,
|
||||
chat_type=type,
|
||||
content=text,
|
||||
audio=audio_data,
|
||||
report_time=report_time,
|
||||
)
|
||||
except Exception as e:
|
||||
conn.logger.bind(tag=TAG).error(f"聊天记录上报失败: {e}")
|
||||
|
||||
|
||||
def opus_to_wav(conn, opus_data):
|
||||
"""将Opus数据转换为WAV格式的字节流
|
||||
|
||||
Args:
|
||||
output_dir: 输出目录(保留参数以保持接口兼容)
|
||||
opus_data: opus音频数据
|
||||
|
||||
Returns:
|
||||
bytes: WAV格式的音频数据
|
||||
"""
|
||||
decoder = opuslib_next.Decoder(16000, 1) # 16kHz, 单声道
|
||||
pcm_data = []
|
||||
|
||||
for opus_packet in opus_data:
|
||||
try:
|
||||
pcm_frame = decoder.decode(opus_packet, 960) # 960 samples = 60ms
|
||||
pcm_data.append(pcm_frame)
|
||||
except opuslib_next.OpusError as e:
|
||||
conn.logger.bind(tag=TAG).error(f"Opus解码错误: {e}", exc_info=True)
|
||||
|
||||
if not pcm_data:
|
||||
raise ValueError("没有有效的PCM数据")
|
||||
|
||||
# 创建WAV文件头
|
||||
pcm_data_bytes = b"".join(pcm_data)
|
||||
num_samples = len(pcm_data_bytes) // 2 # 16-bit samples
|
||||
|
||||
# WAV文件头
|
||||
wav_header = bytearray()
|
||||
wav_header.extend(b"RIFF") # ChunkID
|
||||
wav_header.extend((36 + len(pcm_data_bytes)).to_bytes(4, "little")) # ChunkSize
|
||||
wav_header.extend(b"WAVE") # Format
|
||||
wav_header.extend(b"fmt ") # Subchunk1ID
|
||||
wav_header.extend((16).to_bytes(4, "little")) # Subchunk1Size
|
||||
wav_header.extend((1).to_bytes(2, "little")) # AudioFormat (PCM)
|
||||
wav_header.extend((1).to_bytes(2, "little")) # NumChannels
|
||||
wav_header.extend((16000).to_bytes(4, "little")) # SampleRate
|
||||
wav_header.extend((32000).to_bytes(4, "little")) # ByteRate
|
||||
wav_header.extend((2).to_bytes(2, "little")) # BlockAlign
|
||||
wav_header.extend((16).to_bytes(2, "little")) # BitsPerSample
|
||||
wav_header.extend(b"data") # Subchunk2ID
|
||||
wav_header.extend(len(pcm_data_bytes).to_bytes(4, "little")) # Subchunk2Size
|
||||
|
||||
# 返回完整的WAV数据
|
||||
return bytes(wav_header) + pcm_data_bytes
|
||||
|
||||
|
||||
def enqueue_tts_report(conn, text, opus_data):
|
||||
if not conn.read_config_from_api or conn.need_bind or not conn.report_tts_enable:
|
||||
return
|
||||
if conn.chat_history_conf == 0:
|
||||
return
|
||||
"""将TTS数据加入上报队列
|
||||
|
||||
Args:
|
||||
conn: 连接对象
|
||||
text: 合成文本
|
||||
opus_data: opus音频数据
|
||||
"""
|
||||
try:
|
||||
# 使用连接对象的队列,传入文本和二进制数据而非文件路径
|
||||
if conn.chat_history_conf == 2:
|
||||
conn.report_queue.put((2, text, opus_data, int(time.time())))
|
||||
conn.logger.bind(tag=TAG).debug(
|
||||
f"TTS数据已加入上报队列: {conn.device_id}, 音频大小: {len(opus_data)} "
|
||||
)
|
||||
else:
|
||||
conn.report_queue.put((2, text, None, int(time.time())))
|
||||
conn.logger.bind(tag=TAG).debug(
|
||||
f"TTS数据已加入上报队列: {conn.device_id}, 不上报音频"
|
||||
)
|
||||
except Exception as e:
|
||||
conn.logger.bind(tag=TAG).error(f"加入TTS上报队列失败: {text}, {e}")
|
||||
|
||||
|
||||
def enqueue_asr_report(conn, text, opus_data):
|
||||
if not conn.read_config_from_api or conn.need_bind or not conn.report_asr_enable:
|
||||
return
|
||||
if conn.chat_history_conf == 0:
|
||||
return
|
||||
"""将ASR数据加入上报队列
|
||||
|
||||
Args:
|
||||
conn: 连接对象
|
||||
text: 合成文本
|
||||
opus_data: opus音频数据
|
||||
"""
|
||||
try:
|
||||
# 使用连接对象的队列,传入文本和二进制数据而非文件路径
|
||||
if conn.chat_history_conf == 2:
|
||||
conn.report_queue.put((1, text, opus_data, int(time.time())))
|
||||
conn.logger.bind(tag=TAG).debug(
|
||||
f"ASR数据已加入上报队列: {conn.device_id}, 音频大小: {len(opus_data)} "
|
||||
)
|
||||
else:
|
||||
conn.report_queue.put((1, text, None, int(time.time())))
|
||||
conn.logger.bind(tag=TAG).debug(
|
||||
f"ASR数据已加入上报队列: {conn.device_id}, 不上报音频"
|
||||
)
|
||||
except Exception as e:
|
||||
conn.logger.bind(tag=TAG).debug(f"加入ASR上报队列失败: {text}, {e}")
|
||||
@@ -0,0 +1,253 @@
|
||||
import json
|
||||
import time
|
||||
import asyncio
|
||||
from core.utils import textUtils
|
||||
from core.utils.util import audio_to_data
|
||||
from core.providers.tts.dto.dto import SentenceType
|
||||
|
||||
TAG = __name__
|
||||
|
||||
|
||||
async def sendAudioMessage(conn, sentenceType, audios, text):
|
||||
if conn.tts.tts_audio_first_sentence:
|
||||
conn.logger.bind(tag=TAG).info(f"发送第一段语音: {text}")
|
||||
conn.tts.tts_audio_first_sentence = False
|
||||
await send_tts_message(conn, "start", None)
|
||||
|
||||
if sentenceType == SentenceType.FIRST:
|
||||
await send_tts_message(conn, "sentence_start", text)
|
||||
|
||||
await sendAudio(conn, audios)
|
||||
# 发送句子开始消息
|
||||
if sentenceType is not SentenceType.MIDDLE:
|
||||
conn.logger.bind(tag=TAG).info(f"发送音频消息: {sentenceType}, {text}")
|
||||
|
||||
# 发送结束消息(如果是最后一个文本)
|
||||
if conn.llm_finish_task and sentenceType == SentenceType.LAST:
|
||||
await send_tts_message(conn, "stop", None)
|
||||
conn.client_is_speaking = False
|
||||
if conn.close_after_chat:
|
||||
await conn.close()
|
||||
|
||||
|
||||
def calculate_timestamp_and_sequence(conn, start_time, packet_index, frame_duration=60):
|
||||
"""
|
||||
计算音频数据包的时间戳和序列号
|
||||
Args:
|
||||
conn: 连接对象
|
||||
start_time: 起始时间(性能计数器值)
|
||||
packet_index: 数据包索引
|
||||
frame_duration: 帧时长(毫秒),匹配 Opus 编码
|
||||
Returns:
|
||||
tuple: (timestamp, sequence)
|
||||
"""
|
||||
# 计算时间戳(使用播放位置计算)
|
||||
timestamp = int((start_time + packet_index * frame_duration / 1000) * 1000) % (
|
||||
2**32
|
||||
)
|
||||
|
||||
# 计算序列号
|
||||
if hasattr(conn, "audio_flow_control"):
|
||||
sequence = conn.audio_flow_control["sequence"]
|
||||
else:
|
||||
sequence = packet_index # 如果没有流控状态,直接使用索引
|
||||
|
||||
return timestamp, sequence
|
||||
|
||||
|
||||
async def _send_to_mqtt_gateway(conn, opus_packet, timestamp, sequence):
|
||||
"""
|
||||
发送带16字节头部的opus数据包给mqtt_gateway
|
||||
Args:
|
||||
conn: 连接对象
|
||||
opus_packet: opus数据包
|
||||
timestamp: 时间戳
|
||||
sequence: 序列号
|
||||
"""
|
||||
# 为opus数据包添加16字节头部
|
||||
header = bytearray(16)
|
||||
header[0] = 1 # type
|
||||
header[2:4] = len(opus_packet).to_bytes(2, "big") # payload length
|
||||
header[4:8] = sequence.to_bytes(4, "big") # sequence
|
||||
header[8:12] = timestamp.to_bytes(4, "big") # 时间戳
|
||||
header[12:16] = len(opus_packet).to_bytes(4, "big") # opus长度
|
||||
|
||||
# 发送包含头部的完整数据包
|
||||
complete_packet = bytes(header) + opus_packet
|
||||
await conn.websocket.send(complete_packet)
|
||||
|
||||
|
||||
# 播放音频
|
||||
async def sendAudio(conn, audios, frame_duration=60):
|
||||
"""
|
||||
发送单个opus包,支持流控
|
||||
Args:
|
||||
conn: 连接对象
|
||||
opus_packet: 单个opus数据包
|
||||
pre_buffer: 快速发送音频
|
||||
frame_duration: 帧时长(毫秒),匹配 Opus 编码
|
||||
"""
|
||||
if audios is None or len(audios) == 0:
|
||||
return
|
||||
|
||||
# 获取发送延迟配置
|
||||
send_delay = conn.config.get("tts_audio_send_delay", -1) / 1000.0
|
||||
|
||||
if isinstance(audios, bytes):
|
||||
if conn.client_abort:
|
||||
return
|
||||
|
||||
conn.last_activity_time = time.time() * 1000
|
||||
|
||||
# 获取或初始化流控状态
|
||||
if not hasattr(conn, "audio_flow_control"):
|
||||
conn.audio_flow_control = {
|
||||
"last_send_time": 0,
|
||||
"packet_count": 0,
|
||||
"start_time": time.perf_counter(),
|
||||
"sequence": 0, # 添加序列号
|
||||
}
|
||||
|
||||
flow_control = conn.audio_flow_control
|
||||
current_time = time.perf_counter()
|
||||
|
||||
if send_delay > 0:
|
||||
# 使用固定延迟
|
||||
await asyncio.sleep(send_delay)
|
||||
else:
|
||||
# 计算预期发送时间
|
||||
expected_time = flow_control["start_time"] + (
|
||||
flow_control["packet_count"] * frame_duration / 1000
|
||||
)
|
||||
delay = expected_time - current_time
|
||||
if delay > 0:
|
||||
await asyncio.sleep(delay)
|
||||
else:
|
||||
# 纠正误差
|
||||
flow_control["start_time"] += abs(delay)
|
||||
|
||||
if conn.conn_from_mqtt_gateway:
|
||||
# 计算时间戳和序列号
|
||||
timestamp, sequence = calculate_timestamp_and_sequence(
|
||||
conn,
|
||||
flow_control["start_time"],
|
||||
flow_control["packet_count"],
|
||||
frame_duration,
|
||||
)
|
||||
# 调用通用函数发送带头部的数据包
|
||||
await _send_to_mqtt_gateway(conn, audios, timestamp, sequence)
|
||||
else:
|
||||
# 直接发送opus数据包,不添加头部
|
||||
await conn.websocket.send(audios)
|
||||
|
||||
# 更新流控状态
|
||||
flow_control["packet_count"] += 1
|
||||
flow_control["sequence"] += 1
|
||||
flow_control["last_send_time"] = time.perf_counter()
|
||||
else:
|
||||
# 文件型音频走普通播放
|
||||
start_time = time.perf_counter()
|
||||
play_position = 0
|
||||
|
||||
# 执行预缓冲
|
||||
pre_buffer_frames = min(3, len(audios))
|
||||
for i in range(pre_buffer_frames):
|
||||
if conn.conn_from_mqtt_gateway:
|
||||
# 计算时间戳和序列号
|
||||
timestamp, sequence = calculate_timestamp_and_sequence(
|
||||
conn, start_time, i, frame_duration
|
||||
)
|
||||
# 调用通用函数发送带头部的数据包
|
||||
await _send_to_mqtt_gateway(conn, audios[i], timestamp, sequence)
|
||||
else:
|
||||
# 直接发送预缓冲包,不添加头部
|
||||
await conn.websocket.send(audios[i])
|
||||
remaining_audios = audios[pre_buffer_frames:]
|
||||
|
||||
# 播放剩余音频帧
|
||||
for i, opus_packet in enumerate(remaining_audios):
|
||||
if conn.client_abort:
|
||||
break
|
||||
|
||||
# 重置没有声音的状态
|
||||
conn.last_activity_time = time.time() * 1000
|
||||
|
||||
if send_delay > 0:
|
||||
# 固定延迟模式
|
||||
await asyncio.sleep(send_delay)
|
||||
else:
|
||||
# 计算预期发送时间
|
||||
expected_time = start_time + (play_position / 1000)
|
||||
current_time = time.perf_counter()
|
||||
delay = expected_time - current_time
|
||||
if delay > 0:
|
||||
await asyncio.sleep(delay)
|
||||
|
||||
if conn.conn_from_mqtt_gateway:
|
||||
# 计算时间戳和序列号(使用当前的数据包索引确保连续性)
|
||||
packet_index = pre_buffer_frames + i
|
||||
timestamp, sequence = calculate_timestamp_and_sequence(
|
||||
conn, start_time, packet_index, frame_duration
|
||||
)
|
||||
# 调用通用函数发送带头部的数据包
|
||||
await _send_to_mqtt_gateway(conn, opus_packet, timestamp, sequence)
|
||||
else:
|
||||
# 直接发送opus数据包,不添加头部
|
||||
await conn.websocket.send(opus_packet)
|
||||
|
||||
play_position += frame_duration
|
||||
|
||||
|
||||
async def send_tts_message(conn, state, text=None):
|
||||
"""发送 TTS 状态消息"""
|
||||
if text is None and state == "sentence_start":
|
||||
return
|
||||
message = {"type": "tts", "state": state, "session_id": conn.session_id}
|
||||
if text is not None:
|
||||
message["text"] = textUtils.check_emoji(text)
|
||||
|
||||
# TTS播放结束
|
||||
if state == "stop":
|
||||
# 播放提示音
|
||||
tts_notify = conn.config.get("enable_stop_tts_notify", False)
|
||||
if tts_notify:
|
||||
stop_tts_notify_voice = conn.config.get(
|
||||
"stop_tts_notify_voice", "config/assets/tts_notify.mp3"
|
||||
)
|
||||
audios = audio_to_data(stop_tts_notify_voice, is_opus=True)
|
||||
await sendAudio(conn, audios)
|
||||
# 清除服务端讲话状态
|
||||
conn.clearSpeakStatus()
|
||||
|
||||
# 发送消息到客户端
|
||||
await conn.websocket.send(json.dumps(message))
|
||||
|
||||
|
||||
async def send_stt_message(conn, text):
|
||||
"""发送 STT 状态消息"""
|
||||
end_prompt_str = conn.config.get("end_prompt", {}).get("prompt")
|
||||
if end_prompt_str and end_prompt_str == text:
|
||||
await send_tts_message(conn, "start")
|
||||
return
|
||||
|
||||
# 解析JSON格式,提取实际的用户说话内容
|
||||
display_text = text
|
||||
try:
|
||||
# 尝试解析JSON格式
|
||||
if text.strip().startswith("{") and text.strip().endswith("}"):
|
||||
parsed_data = json.loads(text)
|
||||
if isinstance(parsed_data, dict) and "content" in parsed_data:
|
||||
# 如果是包含说话人信息的JSON格式,只显示content部分
|
||||
display_text = parsed_data["content"]
|
||||
# 保存说话人信息到conn对象
|
||||
if "speaker" in parsed_data:
|
||||
conn.current_speaker = parsed_data["speaker"]
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
# 如果不是JSON格式,直接使用原始文本
|
||||
display_text = text
|
||||
stt_text = textUtils.get_string_no_punctuation_or_emoji(display_text)
|
||||
await conn.websocket.send(
|
||||
json.dumps({"type": "stt", "text": stt_text, "session_id": conn.session_id})
|
||||
)
|
||||
conn.client_is_speaking = True
|
||||
await send_tts_message(conn, "start")
|
||||
@@ -0,0 +1,14 @@
|
||||
from core.handle.textMessageHandlerRegistry import TextMessageHandlerRegistry
|
||||
from core.handle.textMessageProcessor import TextMessageProcessor
|
||||
|
||||
TAG = __name__
|
||||
|
||||
# 全局处理器注册表
|
||||
message_registry = TextMessageHandlerRegistry()
|
||||
|
||||
# 创建全局消息处理器实例
|
||||
message_processor = TextMessageProcessor(message_registry)
|
||||
|
||||
async def handleTextMessage(conn, message):
|
||||
"""处理文本消息"""
|
||||
await message_processor.process_message(conn, message)
|
||||
@@ -0,0 +1,16 @@
|
||||
from typing import Dict, Any
|
||||
|
||||
from core.handle.abortHandle import handleAbortMessage
|
||||
from core.handle.textMessageHandler import TextMessageHandler
|
||||
from core.handle.textMessageType import TextMessageType
|
||||
|
||||
|
||||
class AbortTextMessageHandler(TextMessageHandler):
|
||||
"""Abort消息处理器"""
|
||||
|
||||
@property
|
||||
def message_type(self) -> TextMessageType:
|
||||
return TextMessageType.ABORT
|
||||
|
||||
async def handle(self, conn, msg_json: Dict[str, Any]) -> None:
|
||||
await handleAbortMessage(conn)
|
||||
@@ -0,0 +1,16 @@
|
||||
from typing import Dict, Any
|
||||
|
||||
from core.handle.helloHandle import handleHelloMessage
|
||||
from core.handle.textMessageHandler import TextMessageHandler
|
||||
from core.handle.textMessageType import TextMessageType
|
||||
|
||||
|
||||
class HelloTextMessageHandler(TextMessageHandler):
|
||||
"""Hello消息处理器"""
|
||||
|
||||
@property
|
||||
def message_type(self) -> TextMessageType:
|
||||
return TextMessageType.HELLO
|
||||
|
||||
async def handle(self, conn, msg_json: Dict[str, Any]) -> None:
|
||||
await handleHelloMessage(conn, msg_json)
|
||||
@@ -0,0 +1,20 @@
|
||||
import asyncio
|
||||
from typing import Dict, Any
|
||||
|
||||
from core.handle.textMessageHandler import TextMessageHandler
|
||||
from core.handle.textMessageType import TextMessageType
|
||||
from core.providers.tools.device_iot import handleIotStatus, handleIotDescriptors
|
||||
|
||||
|
||||
class IotTextMessageHandler(TextMessageHandler):
|
||||
"""IOT消息处理器"""
|
||||
|
||||
@property
|
||||
def message_type(self) -> TextMessageType:
|
||||
return TextMessageType.IOT
|
||||
|
||||
async def handle(self, conn, msg_json: Dict[str, Any]) -> None:
|
||||
if "descriptors" in msg_json:
|
||||
asyncio.create_task(handleIotDescriptors(conn, msg_json["descriptors"]))
|
||||
if "states" in msg_json:
|
||||
asyncio.create_task(handleIotStatus(conn, msg_json["states"]))
|
||||
@@ -0,0 +1,63 @@
|
||||
import time
|
||||
from typing import Dict, Any
|
||||
|
||||
from core.handle.receiveAudioHandle import handleAudioMessage, startToChat
|
||||
from core.handle.reportHandle import enqueue_asr_report
|
||||
from core.handle.sendAudioHandle import send_stt_message, send_tts_message
|
||||
from core.handle.textMessageHandler import TextMessageHandler
|
||||
from core.handle.textMessageType import TextMessageType
|
||||
from core.utils.util import remove_punctuation_and_length
|
||||
|
||||
TAG = __name__
|
||||
|
||||
class ListenTextMessageHandler(TextMessageHandler):
|
||||
"""Listen消息处理器"""
|
||||
|
||||
@property
|
||||
def message_type(self) -> TextMessageType:
|
||||
return TextMessageType.LISTEN
|
||||
|
||||
async def handle(self, conn, msg_json: Dict[str, Any]) -> None:
|
||||
if "mode" in msg_json:
|
||||
conn.client_listen_mode = msg_json["mode"]
|
||||
conn.logger.bind(tag=TAG).debug(
|
||||
f"客户端拾音模式:{conn.client_listen_mode}"
|
||||
)
|
||||
if msg_json["state"] == "start":
|
||||
conn.client_have_voice = True
|
||||
conn.client_voice_stop = False
|
||||
elif msg_json["state"] == "stop":
|
||||
conn.client_have_voice = True
|
||||
conn.client_voice_stop = True
|
||||
if len(conn.asr_audio) > 0:
|
||||
await handleAudioMessage(conn, b"")
|
||||
elif msg_json["state"] == "detect":
|
||||
conn.client_have_voice = False
|
||||
conn.asr_audio.clear()
|
||||
if "text" in msg_json:
|
||||
conn.last_activity_time = time.time() * 1000
|
||||
original_text = msg_json["text"] # 保留原始文本
|
||||
filtered_len, filtered_text = remove_punctuation_and_length(
|
||||
original_text
|
||||
)
|
||||
|
||||
# 识别是否是唤醒词
|
||||
is_wakeup_words = filtered_text in conn.config.get("wakeup_words")
|
||||
# 是否开启唤醒词回复
|
||||
enable_greeting = conn.config.get("enable_greeting", True)
|
||||
|
||||
if is_wakeup_words and not enable_greeting:
|
||||
# 如果是唤醒词,且关闭了唤醒词回复,就不用回答
|
||||
await send_stt_message(conn, original_text)
|
||||
await send_tts_message(conn, "stop", None)
|
||||
conn.client_is_speaking = False
|
||||
elif is_wakeup_words:
|
||||
conn.just_woken_up = True
|
||||
# 上报纯文字数据(复用ASR上报功能,但不提供音频数据)
|
||||
enqueue_asr_report(conn, "嘿,你好呀", [])
|
||||
await startToChat(conn, "嘿,你好呀")
|
||||
else:
|
||||
# 上报纯文字数据(复用ASR上报功能,但不提供音频数据)
|
||||
enqueue_asr_report(conn, original_text, [])
|
||||
# 否则需要LLM对文字内容进行答复
|
||||
await startToChat(conn, original_text)
|
||||
@@ -0,0 +1,20 @@
|
||||
import asyncio
|
||||
from typing import Dict, Any
|
||||
|
||||
from core.handle.textMessageHandler import TextMessageHandler
|
||||
from core.handle.textMessageType import TextMessageType
|
||||
from core.providers.tools.device_mcp import handle_mcp_message
|
||||
|
||||
|
||||
class McpTextMessageHandler(TextMessageHandler):
|
||||
"""MCP消息处理器"""
|
||||
|
||||
@property
|
||||
def message_type(self) -> TextMessageType:
|
||||
return TextMessageType.MCP
|
||||
|
||||
async def handle(self, conn, msg_json: Dict[str, Any]) -> None:
|
||||
if "payload" in msg_json:
|
||||
asyncio.create_task(
|
||||
handle_mcp_message(conn, conn.mcp_client, msg_json["payload"])
|
||||
)
|
||||
@@ -0,0 +1,92 @@
|
||||
import asyncio
|
||||
import json
|
||||
from typing import Dict, Any
|
||||
|
||||
from core.handle.textMessageHandler import TextMessageHandler
|
||||
from core.handle.textMessageType import TextMessageType
|
||||
from core.providers.tools.device_mcp import handle_mcp_message
|
||||
|
||||
TAG = __name__
|
||||
|
||||
class ServerTextMessageHandler(TextMessageHandler):
|
||||
"""MCP消息处理器"""
|
||||
|
||||
@property
|
||||
def message_type(self) -> TextMessageType:
|
||||
return TextMessageType.SERVER
|
||||
|
||||
async def handle(self, conn, msg_json: Dict[str, Any]) -> None:
|
||||
# 如果配置是从API读取的,则需要验证secret
|
||||
if not conn.read_config_from_api:
|
||||
return
|
||||
# 获取post请求的secret
|
||||
post_secret = msg_json.get("content", {}).get("secret", "")
|
||||
secret = conn.config["manager-api"].get("secret", "")
|
||||
# 如果secret不匹配,则返回
|
||||
if post_secret != secret:
|
||||
await conn.websocket.send(
|
||||
json.dumps(
|
||||
{
|
||||
"type": "server",
|
||||
"status": "error",
|
||||
"message": "服务器密钥验证失败",
|
||||
}
|
||||
)
|
||||
)
|
||||
return
|
||||
# 动态更新配置
|
||||
if msg_json["action"] == "update_config":
|
||||
try:
|
||||
# 更新WebSocketServer的配置
|
||||
if not conn.server:
|
||||
await conn.websocket.send(
|
||||
json.dumps(
|
||||
{
|
||||
"type": "server",
|
||||
"status": "error",
|
||||
"message": "无法获取服务器实例",
|
||||
"content": {"action": "update_config"},
|
||||
}
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
if not await conn.server.update_config():
|
||||
await conn.websocket.send(
|
||||
json.dumps(
|
||||
{
|
||||
"type": "server",
|
||||
"status": "error",
|
||||
"message": "更新服务器配置失败",
|
||||
"content": {"action": "update_config"},
|
||||
}
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
# 发送成功响应
|
||||
await conn.websocket.send(
|
||||
json.dumps(
|
||||
{
|
||||
"type": "server",
|
||||
"status": "success",
|
||||
"message": "配置更新成功",
|
||||
"content": {"action": "update_config"},
|
||||
}
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
conn.logger.bind(tag=TAG).error(f"更新配置失败: {str(e)}")
|
||||
await conn.websocket.send(
|
||||
json.dumps(
|
||||
{
|
||||
"type": "server",
|
||||
"status": "error",
|
||||
"message": f"更新配置失败: {str(e)}",
|
||||
"content": {"action": "update_config"},
|
||||
}
|
||||
)
|
||||
)
|
||||
# 重启服务器
|
||||
elif msg_json["action"] == "restart":
|
||||
await conn.handle_restart(msg_json)
|
||||
@@ -0,0 +1,21 @@
|
||||
from abc import abstractmethod, ABC
|
||||
from typing import Dict, Any
|
||||
|
||||
from core.handle.textMessageType import TextMessageType
|
||||
|
||||
TAG = __name__
|
||||
|
||||
|
||||
class TextMessageHandler(ABC):
|
||||
"""消息处理器抽象基类"""
|
||||
|
||||
@abstractmethod
|
||||
async def handle(self, conn, msg_json: Dict[str, Any]) -> None:
|
||||
"""处理消息的抽象方法"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def message_type(self) -> TextMessageType:
|
||||
"""返回处理的消息类型"""
|
||||
pass
|
||||
@@ -0,0 +1,45 @@
|
||||
from typing import Dict, Optional
|
||||
|
||||
from core.handle.textHandler.abortMessageHandler import AbortTextMessageHandler
|
||||
from core.handle.textHandler.helloMessageHandler import HelloTextMessageHandler
|
||||
from core.handle.textHandler.iotMessageHandler import IotTextMessageHandler
|
||||
from core.handle.textHandler.listenMessageHandler import ListenTextMessageHandler
|
||||
from core.handle.textHandler.mcpMessageHandler import McpTextMessageHandler
|
||||
from core.handle.textMessageHandler import TextMessageHandler
|
||||
from core.handle.textHandler.serverMessageHandler import ServerTextMessageHandler
|
||||
|
||||
TAG = __name__
|
||||
|
||||
|
||||
class TextMessageHandlerRegistry:
|
||||
"""消息处理器注册表"""
|
||||
|
||||
def __init__(self):
|
||||
self._handlers: Dict[str, TextMessageHandler] = {}
|
||||
self._register_default_handlers()
|
||||
|
||||
def _register_default_handlers(self) -> None:
|
||||
"""注册默认的消息处理器"""
|
||||
handlers = [
|
||||
HelloTextMessageHandler(),
|
||||
AbortTextMessageHandler(),
|
||||
ListenTextMessageHandler(),
|
||||
IotTextMessageHandler(),
|
||||
McpTextMessageHandler(),
|
||||
ServerTextMessageHandler(),
|
||||
]
|
||||
|
||||
for handler in handlers:
|
||||
self.register_handler(handler)
|
||||
|
||||
def register_handler(self, handler: TextMessageHandler) -> None:
|
||||
"""注册消息处理器"""
|
||||
self._handlers[handler.message_type.value] = handler
|
||||
|
||||
def get_handler(self, message_type: str) -> Optional[TextMessageHandler]:
|
||||
"""获取消息处理器"""
|
||||
return self._handlers.get(message_type)
|
||||
|
||||
def get_supported_types(self) -> list:
|
||||
"""获取支持的消息类型"""
|
||||
return list(self._handlers.keys())
|
||||
@@ -0,0 +1,41 @@
|
||||
import json
|
||||
|
||||
from core.handle.textMessageHandlerRegistry import TextMessageHandlerRegistry
|
||||
|
||||
TAG = __name__
|
||||
|
||||
|
||||
class TextMessageProcessor:
|
||||
"""消息处理器主类"""
|
||||
|
||||
def __init__(self, registry: TextMessageHandlerRegistry):
|
||||
self.registry = registry
|
||||
|
||||
async def process_message(self, conn, message: str) -> None:
|
||||
"""处理消息的主入口"""
|
||||
try:
|
||||
# 解析JSON消息
|
||||
msg_json = json.loads(message)
|
||||
|
||||
# 处理JSON消息
|
||||
if isinstance(msg_json, dict):
|
||||
message_type = msg_json.get("type")
|
||||
|
||||
# 记录日志
|
||||
conn.logger.bind(tag=TAG).info(f"收到{message_type}消息:{message}")
|
||||
|
||||
# 获取并执行处理器
|
||||
handler = self.registry.get_handler(message_type)
|
||||
if handler:
|
||||
await handler.handle(conn, msg_json)
|
||||
else:
|
||||
conn.logger.bind(tag=TAG).error(f"收到未知类型消息:{message}")
|
||||
# 处理纯数字消息
|
||||
elif isinstance(msg_json, int):
|
||||
conn.logger.bind(tag=TAG).info(f"收到数字消息:{message}")
|
||||
await conn.websocket.send(message)
|
||||
|
||||
except json.JSONDecodeError:
|
||||
# 非JSON消息直接转发
|
||||
conn.logger.bind(tag=TAG).error(f"解析到错误的消息:{message}")
|
||||
await conn.websocket.send(message)
|
||||
@@ -0,0 +1,11 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class TextMessageType(Enum):
|
||||
"""消息类型枚举"""
|
||||
HELLO = "hello"
|
||||
ABORT = "abort"
|
||||
LISTEN = "listen"
|
||||
IOT = "iot"
|
||||
MCP = "mcp"
|
||||
SERVER = "server"
|
||||
Reference in New Issue
Block a user