仿生人AI服务端
This commit is contained in:
@@ -0,0 +1,16 @@
|
||||
from aiohttp import web
|
||||
from config.logger import setup_logging
|
||||
|
||||
|
||||
class BaseHandler:
|
||||
def __init__(self, config: dict):
|
||||
self.config = config
|
||||
self.logger = setup_logging()
|
||||
|
||||
def _add_cors_headers(self, response):
|
||||
"""添加CORS头信息"""
|
||||
response.headers["Access-Control-Allow-Headers"] = (
|
||||
"client-id, content-type, device-id"
|
||||
)
|
||||
response.headers["Access-Control-Allow-Credentials"] = "true"
|
||||
response.headers["Access-Control-Allow-Origin"] = "*"
|
||||
@@ -0,0 +1,201 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import time
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
from core.api.base_handler import BaseHandler
|
||||
from core.auth import AuthManager
|
||||
from core.utils.util import get_local_ip
|
||||
|
||||
TAG = __name__
|
||||
|
||||
|
||||
class OTAHandler(BaseHandler):
|
||||
def __init__(self, config: dict):
|
||||
super().__init__(config)
|
||||
auth_config = config["server"].get("auth", {})
|
||||
self.auth_enable = auth_config.get("enabled", False)
|
||||
# 设备白名单
|
||||
self.allowed_devices = set(auth_config.get("allowed_devices", []))
|
||||
secret_key = config["server"]["auth_key"]
|
||||
expire_seconds = auth_config.get("expire_seconds")
|
||||
self.auth = AuthManager(secret_key=secret_key, expire_seconds=expire_seconds)
|
||||
|
||||
def generate_password_signature(self, content: str, secret_key: str) -> str:
|
||||
"""生成MQTT密码签名
|
||||
|
||||
Args:
|
||||
content: 签名内容 (clientId + '|' + username)
|
||||
secret_key: 密钥
|
||||
|
||||
Returns:
|
||||
str: Base64编码的HMAC-SHA256签名
|
||||
"""
|
||||
try:
|
||||
hmac_obj = hmac.new(
|
||||
secret_key.encode("utf-8"), content.encode("utf-8"), hashlib.sha256
|
||||
)
|
||||
signature = hmac_obj.digest()
|
||||
return base64.b64encode(signature).decode("utf-8")
|
||||
except Exception as e:
|
||||
self.logger.bind(tag=TAG).error(f"生成MQTT密码签名失败: {e}")
|
||||
return ""
|
||||
|
||||
def _get_websocket_url(self, local_ip: str, port: int) -> str:
|
||||
"""获取websocket地址
|
||||
|
||||
Args:
|
||||
local_ip: 本地IP地址
|
||||
port: 端口号
|
||||
|
||||
Returns:
|
||||
str: websocket地址
|
||||
"""
|
||||
server_config = self.config["server"]
|
||||
websocket_config = server_config.get("websocket", "")
|
||||
|
||||
if "你的" not in websocket_config:
|
||||
return websocket_config
|
||||
else:
|
||||
return f"ws://{local_ip}:{port}/xiaozhi/v1/"
|
||||
|
||||
async def handle_post(self, request):
|
||||
"""处理 OTA POST 请求"""
|
||||
try:
|
||||
data = await request.text()
|
||||
self.logger.bind(tag=TAG).debug(f"OTA请求方法: {request.method}")
|
||||
self.logger.bind(tag=TAG).debug(f"OTA请求头: {request.headers}")
|
||||
self.logger.bind(tag=TAG).debug(f"OTA请求数据: {data}")
|
||||
|
||||
device_id = request.headers.get("device-id", "")
|
||||
if device_id:
|
||||
self.logger.bind(tag=TAG).info(f"OTA请求设备ID: {device_id}")
|
||||
else:
|
||||
raise Exception("OTA请求设备ID为空")
|
||||
|
||||
client_id = request.headers.get("client-id", "")
|
||||
if client_id:
|
||||
self.logger.bind(tag=TAG).info(f"OTA请求ClientID: {client_id}")
|
||||
else:
|
||||
raise Exception("OTA请求ClientID为空")
|
||||
|
||||
data_json = json.loads(data)
|
||||
|
||||
server_config = self.config["server"]
|
||||
port = int(server_config.get("port", 8000))
|
||||
local_ip = get_local_ip()
|
||||
|
||||
return_json = {
|
||||
"server_time": {
|
||||
"timestamp": int(round(time.time() * 1000)),
|
||||
"timezone_offset": server_config.get("timezone_offset", 8) * 60,
|
||||
},
|
||||
"firmware": {
|
||||
"version": data_json["application"].get("version", "1.0.0"),
|
||||
"url": "",
|
||||
},
|
||||
}
|
||||
|
||||
mqtt_gateway_endpoint = server_config.get("mqtt_gateway")
|
||||
|
||||
if mqtt_gateway_endpoint: # 如果配置了非空字符串
|
||||
# 尝试从请求数据中获取设备型号
|
||||
device_model = "default"
|
||||
try:
|
||||
if "device" in data_json and isinstance(data_json["device"], dict):
|
||||
device_model = data_json["device"].get("model", "default")
|
||||
elif "model" in data_json:
|
||||
device_model = data_json["model"]
|
||||
group_id = f"GID_{device_model}".replace(":", "_").replace(" ", "_")
|
||||
except Exception as e:
|
||||
self.logger.bind(tag=TAG).error(f"获取设备型号失败: {e}")
|
||||
group_id = "GID_default"
|
||||
|
||||
mac_address_safe = device_id.replace(":", "_")
|
||||
mqtt_client_id = f"{group_id}@@@{mac_address_safe}@@@{mac_address_safe}"
|
||||
|
||||
# 构建用户数据
|
||||
user_data = {"ip": "unknown"}
|
||||
try:
|
||||
user_data_json = json.dumps(user_data)
|
||||
username = base64.b64encode(user_data_json.encode("utf-8")).decode(
|
||||
"utf-8"
|
||||
)
|
||||
except Exception as e:
|
||||
self.logger.bind(tag=TAG).error(f"生成用户名失败: {e}")
|
||||
username = ""
|
||||
|
||||
# 生成密码
|
||||
password = ""
|
||||
signature_key = server_config.get("mqtt_signature_key", "")
|
||||
if signature_key:
|
||||
password = self.generate_password_signature(
|
||||
mqtt_client_id + "|" + username, signature_key
|
||||
)
|
||||
if not password:
|
||||
password = "" # 签名失败则留空,由设备决定是否允许无密码
|
||||
else:
|
||||
self.logger.bind(tag=TAG).warning("缺少MQTT签名密钥,密码留空")
|
||||
|
||||
# 构建MQTT配置(直接使用 mqtt_gateway 字符串)
|
||||
return_json["mqtt"] = {
|
||||
"endpoint": mqtt_gateway_endpoint,
|
||||
"client_id": mqtt_client_id,
|
||||
"username": username,
|
||||
"password": password,
|
||||
"publish_topic": "device-server",
|
||||
"subscribe_topic": f"devices/p2p/{mac_address_safe}",
|
||||
}
|
||||
self.logger.bind(tag=TAG).info(f"为设备 {device_id} 下发MQTT网关配置")
|
||||
|
||||
else: # 未配置 mqtt_gateway,下发 WebSocket
|
||||
# 如果开启了认证,则进行认证校验
|
||||
token = ""
|
||||
if self.auth_enable:
|
||||
if self.allowed_devices:
|
||||
if device_id not in self.allowed_devices:
|
||||
token = self.auth.generate_token(client_id, device_id)
|
||||
else:
|
||||
token = self.auth.generate_token(client_id, device_id)
|
||||
return_json["websocket"] = {
|
||||
# "url": self._get_websocket_url(local_ip, port),
|
||||
"url": "wss://ai.ronsunny.cn:8090/aibotws/xiaozhi/v1/",
|
||||
"token": token,
|
||||
}
|
||||
# self.logger.bind(tag=TAG).info(
|
||||
# f"未配置MQTT网关,为设备 {device_id} 下发WebSocket配置"
|
||||
# )
|
||||
self.logger.bind(tag=TAG).info(f"{return_json}")
|
||||
|
||||
response = web.Response(
|
||||
text=json.dumps(return_json, separators=(",", ":")),
|
||||
content_type="application/json",
|
||||
)
|
||||
except Exception as e:
|
||||
return_json = {"success": False, "message": "request error."}
|
||||
response = web.Response(
|
||||
text=json.dumps(return_json, separators=(",", ":")),
|
||||
content_type="application/json",
|
||||
)
|
||||
finally:
|
||||
self._add_cors_headers(response)
|
||||
return response
|
||||
|
||||
async def handle_get(self, request):
|
||||
"""处理 OTA GET 请求"""
|
||||
try:
|
||||
server_config = self.config["server"]
|
||||
local_ip = get_local_ip()
|
||||
port = int(server_config.get("port", 8000))
|
||||
websocket_url = self._get_websocket_url(local_ip, port)
|
||||
message = f"OTA接口运行正常,向设备发送的websocket地址是:{websocket_url}"
|
||||
response = web.Response(text=message, content_type="text/plain")
|
||||
except Exception as e:
|
||||
self.logger.bind(tag=TAG).error(f"OTA GET请求异常: {e}")
|
||||
response = web.Response(text="OTA接口异常", content_type="text/plain")
|
||||
finally:
|
||||
self._add_cors_headers(response)
|
||||
return response
|
||||
@@ -0,0 +1,182 @@
|
||||
import json
|
||||
import copy
|
||||
from aiohttp import web
|
||||
from config.logger import setup_logging
|
||||
from core.utils.util import get_vision_url, is_valid_image_file
|
||||
from core.utils.vllm import create_instance
|
||||
from config.config_loader import get_private_config_from_api
|
||||
from core.utils.auth import AuthToken
|
||||
import base64
|
||||
from typing import Tuple, Optional
|
||||
from plugins_func.register import Action
|
||||
|
||||
TAG = __name__
|
||||
|
||||
# 设置最大文件大小为5MB
|
||||
MAX_FILE_SIZE = 5 * 1024 * 1024
|
||||
|
||||
|
||||
class VisionHandler:
|
||||
def __init__(self, config: dict):
|
||||
self.config = config
|
||||
self.logger = setup_logging()
|
||||
# 初始化认证工具
|
||||
self.auth = AuthToken(config["server"]["auth_key"])
|
||||
|
||||
def _create_error_response(self, message: str) -> dict:
|
||||
"""创建统一的错误响应格式"""
|
||||
return {"success": False, "message": message}
|
||||
|
||||
def _verify_auth_token(self, request) -> Tuple[bool, Optional[str]]:
|
||||
"""验证认证token"""
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
if not auth_header.startswith("Bearer "):
|
||||
return False, None
|
||||
|
||||
token = auth_header[7:] # 移除"Bearer "前缀
|
||||
return self.auth.verify_token(token)
|
||||
|
||||
async def handle_post(self, request):
|
||||
"""处理 MCP Vision POST 请求"""
|
||||
response = None # 初始化response变量
|
||||
try:
|
||||
# 验证token
|
||||
is_valid, token_device_id = self._verify_auth_token(request)
|
||||
if not is_valid:
|
||||
response = web.Response(
|
||||
text=json.dumps(
|
||||
self._create_error_response("无效的认证token或token已过期")
|
||||
),
|
||||
content_type="application/json",
|
||||
status=401,
|
||||
)
|
||||
return response
|
||||
|
||||
# 获取请求头信息
|
||||
device_id = request.headers.get("Device-Id", "")
|
||||
client_id = request.headers.get("Client-Id", "")
|
||||
if device_id != token_device_id:
|
||||
raise ValueError("设备ID与token不匹配")
|
||||
# 解析multipart/form-data请求
|
||||
reader = await request.multipart()
|
||||
|
||||
# 读取question字段
|
||||
question_field = await reader.next()
|
||||
if question_field is None:
|
||||
raise ValueError("缺少问题字段")
|
||||
question = await question_field.text()
|
||||
self.logger.bind(tag=TAG).debug(f"Question: {question}")
|
||||
|
||||
# 读取图片文件
|
||||
image_field = await reader.next()
|
||||
if image_field is None:
|
||||
raise ValueError("缺少图片文件")
|
||||
|
||||
# 读取图片数据
|
||||
image_data = await image_field.read()
|
||||
if not image_data:
|
||||
raise ValueError("图片数据为空")
|
||||
|
||||
# 检查文件大小
|
||||
if len(image_data) > MAX_FILE_SIZE:
|
||||
raise ValueError(
|
||||
f"图片大小超过限制,最大允许{MAX_FILE_SIZE/1024/1024}MB"
|
||||
)
|
||||
|
||||
# 检查文件格式
|
||||
if not is_valid_image_file(image_data):
|
||||
raise ValueError(
|
||||
"不支持的文件格式,请上传有效的图片文件(支持JPEG、PNG、GIF、BMP、TIFF、WEBP格式)"
|
||||
)
|
||||
|
||||
# 将图片转换为base64编码
|
||||
image_base64 = base64.b64encode(image_data).decode("utf-8")
|
||||
|
||||
# 如果开启了智控台,则从智控台获取模型配置
|
||||
current_config = copy.deepcopy(self.config)
|
||||
read_config_from_api = current_config.get("read_config_from_api", False)
|
||||
if read_config_from_api:
|
||||
current_config = get_private_config_from_api(
|
||||
current_config,
|
||||
device_id,
|
||||
client_id,
|
||||
)
|
||||
|
||||
select_vllm_module = current_config["selected_module"].get("VLLM")
|
||||
if not select_vllm_module:
|
||||
raise ValueError("您还未设置默认的视觉分析模块")
|
||||
|
||||
vllm_type = (
|
||||
select_vllm_module
|
||||
if "type" not in current_config["VLLM"][select_vllm_module]
|
||||
else current_config["VLLM"][select_vllm_module]["type"]
|
||||
)
|
||||
|
||||
if not vllm_type:
|
||||
raise ValueError(f"无法找到VLLM模块对应的供应器{vllm_type}")
|
||||
|
||||
vllm = create_instance(
|
||||
vllm_type, current_config["VLLM"][select_vllm_module]
|
||||
)
|
||||
|
||||
result = vllm.response(question, image_base64)
|
||||
|
||||
return_json = {
|
||||
"success": True,
|
||||
"action": Action.RESPONSE.name,
|
||||
"response": result,
|
||||
}
|
||||
|
||||
response = web.Response(
|
||||
text=json.dumps(return_json, separators=(",", ":")),
|
||||
content_type="application/json",
|
||||
)
|
||||
except ValueError as e:
|
||||
self.logger.bind(tag=TAG).error(f"MCP Vision POST请求异常: {e}")
|
||||
return_json = self._create_error_response(str(e))
|
||||
response = web.Response(
|
||||
text=json.dumps(return_json, separators=(",", ":")),
|
||||
content_type="application/json",
|
||||
)
|
||||
except Exception as e:
|
||||
self.logger.bind(tag=TAG).error(f"MCP Vision POST请求异常: {e}")
|
||||
return_json = self._create_error_response("处理请求时发生错误")
|
||||
response = web.Response(
|
||||
text=json.dumps(return_json, separators=(",", ":")),
|
||||
content_type="application/json",
|
||||
)
|
||||
finally:
|
||||
if response:
|
||||
self._add_cors_headers(response)
|
||||
return response
|
||||
|
||||
async def handle_get(self, request):
|
||||
"""处理 MCP Vision GET 请求"""
|
||||
try:
|
||||
vision_explain = get_vision_url(self.config)
|
||||
if vision_explain and len(vision_explain) > 0 and "null" != vision_explain:
|
||||
message = (
|
||||
f"MCP Vision 接口运行正常,视觉解释接口地址是:{vision_explain}"
|
||||
)
|
||||
else:
|
||||
message = "MCP Vision 接口运行不正常,请打开data目录下的.config.yaml文件,找到【server.vision_explain】,设置好地址"
|
||||
|
||||
response = web.Response(text=message, content_type="text/plain")
|
||||
except Exception as e:
|
||||
self.logger.bind(tag=TAG).error(f"MCP Vision GET请求异常: {e}")
|
||||
return_json = self._create_error_response("服务器内部错误")
|
||||
response = web.Response(
|
||||
text=json.dumps(return_json, separators=(",", ":")),
|
||||
content_type="application/json",
|
||||
)
|
||||
finally:
|
||||
self._add_cors_headers(response)
|
||||
return response
|
||||
|
||||
def _add_cors_headers(self, response):
|
||||
"""添加CORS头信息"""
|
||||
response.headers["Access-Control-Allow-Headers"] = (
|
||||
"client-id, content-type, device-id"
|
||||
)
|
||||
response.headers["Access-Control-Allow-Credentials"] = "true"
|
||||
response.headers["Access-Control-Allow-Origin"] = "*"
|
||||
Reference in New Issue
Block a user