仿生人AI服务端
This commit is contained in:
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,148 @@
|
||||
import os
|
||||
import yaml
|
||||
from collections.abc import Mapping
|
||||
from config.manage_api_client import init_service, get_server_config, get_agent_models
|
||||
|
||||
|
||||
def get_project_dir():
|
||||
"""获取项目根目录"""
|
||||
return os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + "/"
|
||||
|
||||
|
||||
def read_config(config_path):
|
||||
with open(config_path, "r", encoding="utf-8") as file:
|
||||
config = yaml.safe_load(file)
|
||||
return config
|
||||
|
||||
|
||||
def load_config():
|
||||
"""加载配置文件"""
|
||||
from core.utils.cache.manager import cache_manager, CacheType
|
||||
|
||||
# 检查缓存
|
||||
cached_config = cache_manager.get(CacheType.CONFIG, "main_config")
|
||||
if cached_config is not None:
|
||||
return cached_config
|
||||
|
||||
default_config_path = get_project_dir() + "config.yaml"
|
||||
custom_config_path = get_project_dir() + "data/.config.yaml"
|
||||
|
||||
# 加载默认配置
|
||||
default_config = read_config(default_config_path)
|
||||
custom_config = read_config(custom_config_path)
|
||||
|
||||
if custom_config.get("manager-api", {}).get("url"):
|
||||
config = get_config_from_api(custom_config)
|
||||
else:
|
||||
# 合并配置
|
||||
config = merge_configs(default_config, custom_config)
|
||||
# 初始化目录
|
||||
ensure_directories(config)
|
||||
|
||||
# 缓存配置
|
||||
cache_manager.set(CacheType.CONFIG, "main_config", config)
|
||||
return config
|
||||
|
||||
|
||||
def get_config_from_api(config):
|
||||
"""从Java API获取配置"""
|
||||
# 初始化API客户端
|
||||
init_service(config)
|
||||
|
||||
# 获取服务器配置
|
||||
config_data = get_server_config()
|
||||
if config_data is None:
|
||||
raise Exception("Failed to fetch server config from API")
|
||||
|
||||
config_data["read_config_from_api"] = True
|
||||
config_data["manager-api"] = {
|
||||
"url": config["manager-api"].get("url", ""),
|
||||
"secret": config["manager-api"].get("secret", ""),
|
||||
}
|
||||
# server的配置以本地为准
|
||||
if config.get("server"):
|
||||
config_data["server"] = {
|
||||
"ip": config["server"].get("ip", ""),
|
||||
"port": config["server"].get("port", ""),
|
||||
"http_port": config["server"].get("http_port", ""),
|
||||
"vision_explain": config["server"].get("vision_explain", ""),
|
||||
"auth_key": config["server"].get("auth_key", ""),
|
||||
}
|
||||
return config_data
|
||||
|
||||
|
||||
def get_private_config_from_api(config, device_id, client_id):
|
||||
"""从Java API获取私有配置"""
|
||||
return get_agent_models(device_id, client_id, config["selected_module"])
|
||||
|
||||
|
||||
def ensure_directories(config):
|
||||
"""确保所有配置路径存在"""
|
||||
dirs_to_create = set()
|
||||
project_dir = get_project_dir() # 获取项目根目录
|
||||
# 日志文件目录
|
||||
log_dir = config.get("log", {}).get("log_dir", "tmp")
|
||||
dirs_to_create.add(os.path.join(project_dir, log_dir))
|
||||
|
||||
# ASR/TTS模块输出目录
|
||||
for module in ["ASR", "TTS"]:
|
||||
if config.get(module) is None:
|
||||
continue
|
||||
for provider in config.get(module, {}).values():
|
||||
output_dir = provider.get("output_dir", "")
|
||||
if output_dir:
|
||||
dirs_to_create.add(output_dir)
|
||||
|
||||
# 根据selected_module创建模型目录
|
||||
selected_modules = config.get("selected_module", {})
|
||||
for module_type in ["ASR", "LLM", "TTS"]:
|
||||
selected_provider = selected_modules.get(module_type)
|
||||
if not selected_provider:
|
||||
continue
|
||||
if config.get(module) is None:
|
||||
continue
|
||||
if config.get(selected_provider) is None:
|
||||
continue
|
||||
provider_config = config.get(module_type, {}).get(selected_provider, {})
|
||||
output_dir = provider_config.get("output_dir")
|
||||
if output_dir:
|
||||
full_model_dir = os.path.join(project_dir, output_dir)
|
||||
dirs_to_create.add(full_model_dir)
|
||||
|
||||
# 统一创建目录(保留原data目录创建)
|
||||
for dir_path in dirs_to_create:
|
||||
try:
|
||||
os.makedirs(dir_path, exist_ok=True)
|
||||
except PermissionError:
|
||||
print(f"警告:无法创建目录 {dir_path},请检查写入权限")
|
||||
|
||||
|
||||
def merge_configs(default_config, custom_config):
|
||||
"""
|
||||
递归合并配置,custom_config优先级更高
|
||||
|
||||
Args:
|
||||
default_config: 默认配置
|
||||
custom_config: 用户自定义配置
|
||||
|
||||
Returns:
|
||||
合并后的配置
|
||||
"""
|
||||
if not isinstance(default_config, Mapping) or not isinstance(
|
||||
custom_config, Mapping
|
||||
):
|
||||
return custom_config
|
||||
|
||||
merged = dict(default_config)
|
||||
|
||||
for key, value in custom_config.items():
|
||||
if (
|
||||
key in merged
|
||||
and isinstance(merged[key], Mapping)
|
||||
and isinstance(value, Mapping)
|
||||
):
|
||||
merged[key] = merge_configs(merged[key], value)
|
||||
else:
|
||||
merged[key] = value
|
||||
|
||||
return merged
|
||||
@@ -0,0 +1,114 @@
|
||||
import os
|
||||
import sys
|
||||
from loguru import logger
|
||||
from config.config_loader import load_config
|
||||
from config.settings import check_config_file
|
||||
from datetime import datetime
|
||||
|
||||
SERVER_VERSION = "0.8.5"
|
||||
_logger_initialized = False
|
||||
|
||||
|
||||
def get_module_abbreviation(module_name, module_dict):
|
||||
"""获取模块名称的缩写,如果为空则返回00
|
||||
如果名称中包含下划线,则返回下划线后面的前两个字符
|
||||
"""
|
||||
module_value = module_dict.get(module_name, "")
|
||||
if not module_value:
|
||||
return "00"
|
||||
if "_" in module_value:
|
||||
parts = module_value.split("_")
|
||||
return parts[-1][:2] if parts[-1] else "00"
|
||||
return module_value[:2]
|
||||
|
||||
|
||||
def build_module_string(selected_module):
|
||||
"""构建模块字符串"""
|
||||
return (
|
||||
get_module_abbreviation("VAD", selected_module)
|
||||
+ get_module_abbreviation("ASR", selected_module)
|
||||
+ get_module_abbreviation("LLM", selected_module)
|
||||
+ get_module_abbreviation("TTS", selected_module)
|
||||
+ get_module_abbreviation("Memory", selected_module)
|
||||
+ get_module_abbreviation("Intent", selected_module)
|
||||
+ get_module_abbreviation("VLLM", selected_module)
|
||||
)
|
||||
|
||||
|
||||
def formatter(record):
|
||||
"""为没有 tag 的日志添加默认值,并处理动态模块字符串"""
|
||||
record["extra"].setdefault("tag", record["name"])
|
||||
# 如果没有设置 selected_module,使用默认值
|
||||
record["extra"].setdefault("selected_module", "00000000000000")
|
||||
# 将 selected_module 从 extra 提取到顶级,以支持 {selected_module} 格式
|
||||
record["selected_module"] = record["extra"]["selected_module"]
|
||||
return record["message"]
|
||||
|
||||
|
||||
def setup_logging():
|
||||
check_config_file()
|
||||
"""从配置文件中读取日志配置,并设置日志输出格式和级别"""
|
||||
config = load_config()
|
||||
log_config = config["log"]
|
||||
global _logger_initialized
|
||||
|
||||
# 第一次初始化时配置日志
|
||||
if not _logger_initialized:
|
||||
# 使用默认的模块字符串进行初始化
|
||||
logger.configure(
|
||||
extra={
|
||||
"selected_module": log_config.get("selected_module", "00000000000000"),
|
||||
}
|
||||
)
|
||||
|
||||
log_format = log_config.get(
|
||||
"log_format",
|
||||
"<green>{time:YYMMDD HH:mm:ss}</green>[{version}_{extra[selected_module]}][<light-blue>{extra[tag]}</light-blue>]-<level>{level}</level>-<light-green>{message}</light-green>",
|
||||
)
|
||||
log_format_file = log_config.get(
|
||||
"log_format_file",
|
||||
"{time:YYYY-MM-DD HH:mm:ss} - {version}_{extra[selected_module]} - {name} - {level} - {extra[tag]} - {message}",
|
||||
)
|
||||
log_format = log_format.replace("{version}", SERVER_VERSION)
|
||||
log_format_file = log_format_file.replace("{version}", SERVER_VERSION)
|
||||
|
||||
log_level = log_config.get("log_level", "INFO")
|
||||
log_dir = log_config.get("log_dir", "tmp")
|
||||
log_file = log_config.get("log_file", "server.log")
|
||||
data_dir = log_config.get("data_dir", "data")
|
||||
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
os.makedirs(data_dir, exist_ok=True)
|
||||
|
||||
# 配置日志输出
|
||||
logger.remove()
|
||||
|
||||
# 输出到控制台
|
||||
logger.add(sys.stdout, format=log_format, level=log_level, filter=formatter)
|
||||
|
||||
# 输出到文件 - 统一目录,按大小轮转
|
||||
# 日志文件完整路径
|
||||
log_file_path = os.path.join(log_dir, log_file)
|
||||
|
||||
# 添加日志处理器
|
||||
logger.add(
|
||||
log_file_path,
|
||||
format=log_format_file,
|
||||
level=log_level,
|
||||
filter=formatter,
|
||||
rotation="10 MB", # 每个文件最大10MB
|
||||
retention="30 days", # 保留30天
|
||||
compression=None,
|
||||
encoding="utf-8",
|
||||
enqueue=True, # 异步安全
|
||||
backtrace=True,
|
||||
diagnose=True,
|
||||
)
|
||||
_logger_initialized = True # 标记为已初始化
|
||||
|
||||
return logger
|
||||
|
||||
|
||||
def create_connection_logger(selected_module_str):
|
||||
"""为连接创建独立的日志器,绑定特定的模块字符串"""
|
||||
return logger.bind(selected_module=selected_module_str)
|
||||
@@ -0,0 +1,193 @@
|
||||
import os
|
||||
import time
|
||||
import base64
|
||||
from typing import Optional, Dict
|
||||
|
||||
import httpx
|
||||
|
||||
TAG = __name__
|
||||
|
||||
|
||||
class DeviceNotFoundException(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class DeviceBindException(Exception):
|
||||
def __init__(self, bind_code):
|
||||
self.bind_code = bind_code
|
||||
super().__init__(f"设备绑定异常,绑定码: {bind_code}")
|
||||
|
||||
|
||||
class ManageApiClient:
|
||||
_instance = None
|
||||
_client = None
|
||||
_secret = None
|
||||
|
||||
def __new__(cls, config):
|
||||
"""单例模式确保全局唯一实例,并支持传入配置参数"""
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
cls._init_client(config)
|
||||
return cls._instance
|
||||
|
||||
@classmethod
|
||||
def _init_client(cls, config):
|
||||
"""初始化持久化连接池"""
|
||||
cls.config = config.get("manager-api")
|
||||
|
||||
if not cls.config:
|
||||
raise Exception("manager-api配置错误")
|
||||
|
||||
if not cls.config.get("url") or not cls.config.get("secret"):
|
||||
raise Exception("manager-api的url或secret配置错误")
|
||||
|
||||
if "你" in cls.config.get("secret"):
|
||||
raise Exception("请先配置manager-api的secret")
|
||||
|
||||
cls._secret = cls.config.get("secret")
|
||||
cls.max_retries = cls.config.get("max_retries", 6) # 最大重试次数
|
||||
cls.retry_delay = cls.config.get("retry_delay", 10) # 初始重试延迟(秒)
|
||||
# NOTE(goody): 2025/4/16 http相关资源统一管理,后续可以增加线程池或者超时
|
||||
# 后续也可以统一配置apiToken之类的走通用的Auth
|
||||
cls._client = httpx.Client(
|
||||
base_url=cls.config.get("url"),
|
||||
headers={
|
||||
"User-Agent": f"PythonClient/2.0 (PID:{os.getpid()})",
|
||||
"Accept": "application/json",
|
||||
"Authorization": "Bearer " + cls._secret,
|
||||
},
|
||||
timeout=cls.config.get("timeout", 30), # 默认超时时间30秒
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _request(cls, method: str, endpoint: str, **kwargs) -> Dict:
|
||||
"""发送单次HTTP请求并处理响应"""
|
||||
endpoint = endpoint.lstrip("/")
|
||||
response = cls._client.request(method, endpoint, **kwargs)
|
||||
response.raise_for_status()
|
||||
|
||||
result = response.json()
|
||||
|
||||
# 处理API返回的业务错误
|
||||
if result.get("code") == 10041:
|
||||
raise DeviceNotFoundException(result.get("msg"))
|
||||
elif result.get("code") == 10042:
|
||||
raise DeviceBindException(result.get("msg"))
|
||||
elif result.get("code") != 0:
|
||||
raise Exception(f"API返回错误: {result.get('msg', '未知错误')}")
|
||||
|
||||
# 返回成功数据
|
||||
return result.get("data") if result.get("code") == 0 else None
|
||||
|
||||
@classmethod
|
||||
def _should_retry(cls, exception: Exception) -> bool:
|
||||
"""判断异常是否应该重试"""
|
||||
# 网络连接相关错误
|
||||
if isinstance(
|
||||
exception, (httpx.ConnectError, httpx.TimeoutException, httpx.NetworkError)
|
||||
):
|
||||
return True
|
||||
|
||||
# HTTP状态码错误
|
||||
if isinstance(exception, httpx.HTTPStatusError):
|
||||
status_code = exception.response.status_code
|
||||
return status_code in [408, 429, 500, 502, 503, 504]
|
||||
|
||||
return False
|
||||
|
||||
@classmethod
|
||||
def _execute_request(cls, method: str, endpoint: str, **kwargs) -> Dict:
|
||||
"""带重试机制的请求执行器"""
|
||||
retry_count = 0
|
||||
|
||||
while retry_count <= cls.max_retries:
|
||||
try:
|
||||
# 执行请求
|
||||
return cls._request(method, endpoint, **kwargs)
|
||||
except Exception as e:
|
||||
# 判断是否应该重试
|
||||
if retry_count < cls.max_retries and cls._should_retry(e):
|
||||
retry_count += 1
|
||||
print(
|
||||
f"{method} {endpoint} 请求失败,将在 {cls.retry_delay:.1f} 秒后进行第 {retry_count} 次重试"
|
||||
)
|
||||
time.sleep(cls.retry_delay)
|
||||
continue
|
||||
else:
|
||||
# 不重试,直接抛出异常
|
||||
raise
|
||||
|
||||
@classmethod
|
||||
def safe_close(cls):
|
||||
"""安全关闭连接池"""
|
||||
if cls._client:
|
||||
cls._client.close()
|
||||
cls._instance = None
|
||||
|
||||
|
||||
def get_server_config() -> Optional[Dict]:
|
||||
"""获取服务器基础配置"""
|
||||
return ManageApiClient._instance._execute_request("POST", "/config/server-base")
|
||||
|
||||
|
||||
def get_agent_models(
|
||||
mac_address: str, client_id: str, selected_module: Dict
|
||||
) -> Optional[Dict]:
|
||||
"""获取代理模型配置"""
|
||||
return ManageApiClient._instance._execute_request(
|
||||
"POST",
|
||||
"/config/agent-models",
|
||||
json={
|
||||
"macAddress": mac_address,
|
||||
"clientId": client_id,
|
||||
"selectedModule": selected_module,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def save_mem_local_short(mac_address: str, short_momery: str) -> Optional[Dict]:
|
||||
try:
|
||||
return ManageApiClient._instance._execute_request(
|
||||
"PUT",
|
||||
f"/agent/saveMemory/" + mac_address,
|
||||
json={
|
||||
"summaryMemory": short_momery,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"存储短期记忆到服务器失败: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def report(
|
||||
mac_address: str, session_id: str, chat_type: int, content: str, audio, report_time
|
||||
) -> Optional[Dict]:
|
||||
"""带熔断的业务方法示例"""
|
||||
if not content or not ManageApiClient._instance:
|
||||
return None
|
||||
try:
|
||||
return ManageApiClient._instance._execute_request(
|
||||
"POST",
|
||||
f"/agent/chat-history/report",
|
||||
json={
|
||||
"macAddress": mac_address,
|
||||
"sessionId": session_id,
|
||||
"chatType": chat_type,
|
||||
"content": content,
|
||||
"reportTime": report_time,
|
||||
"audioBase64": (
|
||||
base64.b64encode(audio).decode("utf-8") if audio else None
|
||||
),
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"TTS上报失败: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def init_service(config):
|
||||
ManageApiClient(config)
|
||||
|
||||
|
||||
def manage_api_http_safe_close():
|
||||
ManageApiClient.safe_close()
|
||||
@@ -0,0 +1,33 @@
|
||||
import os
|
||||
from config.config_loader import read_config, get_project_dir, load_config
|
||||
|
||||
|
||||
default_config_file = "config.yaml"
|
||||
config_file_valid = False
|
||||
|
||||
|
||||
def check_config_file():
|
||||
global config_file_valid
|
||||
if config_file_valid:
|
||||
return
|
||||
"""
|
||||
简化的配置检查,仅提示用户配置文件的使用情况
|
||||
"""
|
||||
custom_config_file = get_project_dir() + "data/." + default_config_file
|
||||
if not os.path.exists(custom_config_file):
|
||||
raise FileNotFoundError(
|
||||
"找不到data/.config.yaml文件,请按教程确认该配置文件是否存在"
|
||||
)
|
||||
|
||||
# 检查是否从API读取配置
|
||||
config = load_config()
|
||||
if config.get("read_config_from_api", False):
|
||||
print("从API读取配置")
|
||||
old_config_origin = read_config(custom_config_file)
|
||||
if old_config_origin.get("selected_module") is not None:
|
||||
error_msg = "您的配置文件好像既包含智控台的配置又包含本地配置:\n"
|
||||
error_msg += "\n建议您:\n"
|
||||
error_msg += "1、将根目录的config_from_api.yaml文件复制到data下,重命名为.config.yaml\n"
|
||||
error_msg += "2、按教程配置好接口地址和密钥\n"
|
||||
raise ValueError(error_msg)
|
||||
config_file_valid = True
|
||||
Reference in New Issue
Block a user