仿生人AI服务端

This commit is contained in:
BBIT-Kai
2025-11-05 18:07:21 +08:00
parent 7ff894e875
commit 4c2ae9e809
190 changed files with 27776 additions and 0 deletions
+24
View File
@@ -0,0 +1,24 @@
import importlib
import logging
import os
import sys
import time
import wave
import uuid
from abc import ABC, abstractmethod
from typing import Optional, Tuple, List
from core.providers.asr.base import ASRProviderBase
from config.logger import setup_logging
# Tag for log records emitted from this module (the module's import name).
TAG = __name__
# Module-level logger obtained from the project's logging setup.
logger = setup_logging()
def create_instance(class_name: str, *args, **kwargs) -> ASRProviderBase:
    """Build the ASR provider implemented in ``core/providers/asr/<class_name>.py``.

    The provider file is looked up relative to the current working directory.
    When it exists, the matching module is imported (once) and its
    ``ASRProvider`` class is instantiated with the remaining arguments.

    Args:
        class_name: Provider module name, i.e. the configured ASR ``type``.
        *args: Positional arguments forwarded to ``ASRProvider``.
        **kwargs: Keyword arguments forwarded to ``ASRProvider``.

    Returns:
        The newly constructed provider instance.

    Raises:
        ValueError: If no provider file exists for ``class_name``.
    """
    provider_file = os.path.join("core", "providers", "asr", f"{class_name}.py")
    if not os.path.exists(provider_file):
        raise ValueError(f"不支持的ASR类型: {class_name},请检查该配置的type是否设置正确")

    lib_name = f"core.providers.asr.{class_name}"
    if lib_name not in sys.modules:
        sys.modules[lib_name] = importlib.import_module(lib_name)
    provider_module = sys.modules[lib_name]
    return provider_module.ASRProvider(*args, **kwargs)
+126
View File
@@ -0,0 +1,126 @@
import jwt
import time
import json
import os
from datetime import datetime, timedelta, timezone
from typing import Tuple, Optional
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
from cryptography.hazmat.primitives import padding
from cryptography.hazmat.backends import default_backend
import base64
class AuthToken:
    """Issue and verify device tokens.

    A token is an HS256-signed JWT whose only claim, ``data``, carries the
    AES-GCM encrypted JSON payload ``{"device_id": ..., "exp": ...}``.
    The outer JWT itself carries no ``exp`` claim; expiry is enforced by the
    check on the decrypted inner payload in :meth:`verify_token`.
    """

    def __init__(self, secret_key: str):
        # Keep the secret as bytes; it signs the JWT and seeds the AES key.
        self.secret_key = secret_key.encode()
        # 32 bytes -> AES-256.
        self.encryption_key = self._derive_key(32)

    def _derive_key(self, length: int) -> bytes:
        """Derive a ``length``-byte key from the secret with PBKDF2-HMAC-SHA256."""
        from cryptography.hazmat.primitives import hashes
        from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC

        # NOTE: the salt is a fixed placeholder; a production deployment
        # should use a randomly generated salt instead.
        key_deriver = PBKDF2HMAC(
            algorithm=hashes.SHA256(),
            length=length,
            salt=b"fixed_salt_placeholder",
            iterations=100000,
            backend=default_backend(),
        )
        return key_deriver.derive(self.secret_key)

    def _encrypt_payload(self, payload: dict) -> str:
        """Encrypt ``payload`` with AES-GCM.

        Returns:
            URL-safe Base64 text of ``IV (12 bytes) + ciphertext + tag``.
        """
        plaintext = json.dumps(payload).encode()
        nonce = os.urandom(12)  # fresh random IV for every token
        encryptor = Cipher(
            algorithms.AES(self.encryption_key),
            modes.GCM(nonce),
            backend=default_backend(),
        ).encryptor()
        sealed = encryptor.update(plaintext) + encryptor.finalize()
        # The authentication tag only becomes available after finalize().
        blob = nonce + sealed + encryptor.tag
        return base64.urlsafe_b64encode(blob).decode()

    def _decrypt_payload(self, encrypted_data: str) -> dict:
        """Reverse :meth:`_encrypt_payload` and return the decoded JSON object."""
        raw = base64.urlsafe_b64decode(encrypted_data.encode())
        # Layout: first 12 bytes IV, last 16 bytes tag, ciphertext in between.
        nonce, sealed, auth_tag = raw[:12], raw[12:-16], raw[-16:]
        decryptor = Cipher(
            algorithms.AES(self.encryption_key),
            modes.GCM(nonce, auth_tag),
            backend=default_backend(),
        ).decryptor()
        decoded = decryptor.update(sealed) + decryptor.finalize()
        return json.loads(decoded.decode())

    def generate_token(self, device_id: str) -> str:
        """Create a token for ``device_id`` that expires one hour from now.

        :param device_id: device identifier stored in the encrypted payload
        :return: the encoded JWT string
        """
        expires_at = datetime.now(timezone.utc) + timedelta(hours=1)
        inner_claims = {"device_id": device_id, "exp": expires_at.timestamp()}
        # The JWT carries nothing but the encrypted inner payload.
        wrapper = {"data": self._encrypt_payload(inner_claims)}
        return jwt.encode(wrapper, self.secret_key, algorithm="HS256")

    def verify_token(self, token: str) -> Tuple[bool, Optional[str]]:
        """Validate ``token``.

        :param token: JWT string produced by :meth:`generate_token`
        :return: ``(True, device_id)`` when valid, otherwise ``(False, None)``
        """
        try:
            # Signature check on the outer JWT.
            wrapper = jwt.decode(token, self.secret_key, algorithms=["HS256"])
            claims = self._decrypt_payload(wrapper["data"])
            # Expiry check on the decrypted payload.
            if claims["exp"] < time.time():
                return False, None
            return True, claims["device_id"]
        except (jwt.InvalidTokenError, json.JSONDecodeError):
            return False, None
        except Exception as e:  # any other failure also means "invalid"
            print(f"Token verification failed: {str(e)}")
            return False, None
+62
View File
@@ -0,0 +1,62 @@
"""
缓存配置管理
"""
from enum import Enum
from typing import Dict, Any, Optional
from dataclasses import dataclass
from .strategies import CacheStrategy
class CacheType(Enum):
    """Enumeration of the named cache categories."""

    LOCATION = "location"
    WEATHER = "weather"
    LUNAR = "lunar"
    INTENT = "intent"
    IP_INFO = "ip_info"
    CONFIG = "config"
    DEVICE_PROMPT = "device_prompt"
    VOICEPRINT_HEALTH = "voiceprint_health"  # voiceprint-recognition health check
@dataclass
class CacheConfig:
    """Settings for one cache: strategy, entry lifetime, capacity and sweep interval."""

    strategy: CacheStrategy = CacheStrategy.TTL
    ttl: Optional[float] = 300  # seconds; default is 5 minutes
    max_size: Optional[int] = 1000  # default cap of 1000 entries
    cleanup_interval: float = 60  # seconds between expiry sweeps

    @classmethod
    def for_type(cls, cache_type: CacheType) -> "CacheConfig":
        """Return the preset configuration for ``cache_type``.

        A fresh instance is built on every call. Cache types without a
        preset get the dataclass defaults.
        """
        # (strategy, ttl in seconds or None for manual invalidation, max_size)
        presets = {
            CacheType.LOCATION: (CacheStrategy.TTL, None, 1000),
            CacheType.IP_INFO: (CacheStrategy.TTL, 86400, 1000),  # 24 hours
            CacheType.WEATHER: (CacheStrategy.TTL, 28800, 1000),  # 8 hours
            CacheType.LUNAR: (CacheStrategy.TTL, 2592000, 365),  # 30 days
            CacheType.INTENT: (CacheStrategy.TTL_LRU, 600, 1000),  # 10 minutes
            CacheType.CONFIG: (CacheStrategy.FIXED_SIZE, None, 20),
            CacheType.DEVICE_PROMPT: (CacheStrategy.TTL, None, 1000),
            CacheType.VOICEPRINT_HEALTH: (CacheStrategy.TTL, 600, 100),  # 10 minutes
        }
        preset = presets.get(cache_type)
        if preset is None:
            return cls()
        strategy, ttl, max_size = preset
        return cls(strategy=strategy, ttl=ttl, max_size=max_size)
+216
View File
@@ -0,0 +1,216 @@
"""
全局缓存管理器
"""
import time
import threading
from typing import Any, Optional, Dict
from collections import OrderedDict
from .strategies import CacheStrategy, CacheEntry
from .config import CacheConfig, CacheType
class GlobalCacheManager:
    """Process-wide manager for a set of named in-memory caches.

    Each cache is identified by a ``CacheType`` plus an optional namespace
    and is created lazily on the first ``set``. Every cache has its own
    configuration, its own re-entrant lock and its own expiry-sweep
    timestamp. Hit/miss/eviction/cleanup counters are kept in ``_stats``.
    """

    def __init__(self):
        self._logger = None
        self._caches: Dict[str, Dict[str, CacheEntry]] = {}
        self._configs: Dict[str, CacheConfig] = {}
        self._locks: Dict[str, threading.RLock] = {}
        self._global_lock = threading.RLock()
        # Time of the most recent sweep of any cache (kept for compatibility).
        self._last_cleanup = time.time()
        # Time of the last expiry sweep per cache. A single shared timestamp
        # would let a frequently written cache starve the others of sweeps.
        self._cleanup_times: Dict[str, float] = {}
        self._stats = {"hits": 0, "misses": 0, "evictions": 0, "cleanups": 0}

    @property
    def logger(self):
        """Logger created on first use; the import is deferred to avoid a circular import."""
        if self._logger is None:
            from config.logger import setup_logging

            self._logger = setup_logging()
        return self._logger

    @staticmethod
    def _is_lru(config: CacheConfig) -> bool:
        """Return True when ``config`` selects a recency-ordered strategy."""
        return config.strategy in (CacheStrategy.LRU, CacheStrategy.TTL_LRU)

    def _get_cache_name(self, cache_type: CacheType, namespace: str = "") -> str:
        """Build the internal cache name: ``<type>`` or ``<type>:<namespace>``."""
        if namespace:
            return f"{cache_type.value}:{namespace}"
        return cache_type.value

    def _get_or_create_cache(
        self, cache_name: str, config: CacheConfig
    ) -> Dict[str, CacheEntry]:
        """Return the storage for ``cache_name``, creating it with ``config`` if needed."""
        with self._global_lock:
            if cache_name not in self._caches:
                # Publish config, lock and sweep time BEFORE the cache itself:
                # get()/delete() test membership in _caches without taking
                # the global lock and then index _configs/_locks directly.
                self._configs[cache_name] = config
                self._locks[cache_name] = threading.RLock()
                self._cleanup_times[cache_name] = time.time()
                self._caches[cache_name] = (
                    OrderedDict() if self._is_lru(config) else {}
                )
            return self._caches[cache_name]

    def set(
        self,
        cache_type: CacheType,
        key: str,
        value: Any,
        ttl: Optional[float] = None,
        namespace: str = "",
    ) -> None:
        """Store ``value`` under ``key``.

        Args:
            cache_type: Cache category; selects the preset configuration.
            key: Entry key.
            value: Value to store.
            ttl: Lifetime in seconds; ``None`` uses the cache's configured TTL.
            namespace: Optional sub-cache name within ``cache_type``.
        """
        cache_name = self._get_cache_name(cache_type, namespace)
        config = self._configs.get(cache_name) or CacheConfig.for_type(cache_type)
        cache = self._get_or_create_cache(cache_name, config)

        effective_ttl = ttl if ttl is not None else config.ttl

        with self._locks[cache_name]:
            entry = CacheEntry(value=value, timestamp=time.time(), ttl=effective_ttl)

            # For recency-ordered caches a rewritten key must move to the end;
            # for the other strategies it keeps its original insertion slot.
            if self._is_lru(config) and key in cache:
                del cache[key]
            cache[key] = entry

            # Over capacity: drop the entry at the front of the mapping, i.e.
            # the least recently used one for LRU caches and the oldest
            # inserted one otherwise.
            if config.max_size and len(cache) > config.max_size:
                oldest_key = next(iter(cache))
                del cache[oldest_key]
                self._stats["evictions"] += 1

            # Opportunistic, rate-limited sweep of expired entries.
            self._maybe_cleanup(cache_name)

    def get(
        self, cache_type: CacheType, key: str, namespace: str = ""
    ) -> Optional[Any]:
        """Return the cached value for ``key``, or ``None`` on a miss or expiry."""
        cache_name = self._get_cache_name(cache_type, namespace)
        if cache_name not in self._caches:
            self._stats["misses"] += 1
            return None

        cache = self._caches[cache_name]
        config = self._configs[cache_name]

        with self._locks[cache_name]:
            entry = cache.get(key)
            if entry is None:
                self._stats["misses"] += 1
                return None

            if entry.is_expired():
                del cache[key]
                self._stats["misses"] += 1
                return None

            entry.touch()
            if self._is_lru(config):
                # Mark as most recently used.
                cache.move_to_end(key)

            self._stats["hits"] += 1
            return entry.value

    def delete(self, cache_type: CacheType, key: str, namespace: str = "") -> bool:
        """Remove ``key``; return True if an entry was actually removed."""
        cache_name = self._get_cache_name(cache_type, namespace)
        if cache_name not in self._caches:
            return False

        cache = self._caches[cache_name]
        with self._locks[cache_name]:
            if key in cache:
                del cache[key]
                return True
            return False

    def clear(self, cache_type: CacheType, namespace: str = "") -> None:
        """Remove every entry of the given cache (the cache itself is kept)."""
        cache_name = self._get_cache_name(cache_type, namespace)
        if cache_name not in self._caches:
            return

        with self._locks[cache_name]:
            self._caches[cache_name].clear()

    def invalidate_pattern(
        self, cache_type: CacheType, pattern: str, namespace: str = ""
    ) -> int:
        """Remove every entry whose key contains ``pattern`` (substring match).

        Returns:
            The number of entries removed.
        """
        cache_name = self._get_cache_name(cache_type, namespace)
        if cache_name not in self._caches:
            return 0

        cache = self._caches[cache_name]
        with self._locks[cache_name]:
            # Materialise the matches first; the dict cannot change size
            # while it is being iterated.
            keys_to_delete = [key for key in cache if pattern in key]
            for key in keys_to_delete:
                del cache[key]
            return len(keys_to_delete)

    def _cleanup_expired(self, cache_name: str) -> int:
        """Delete all expired entries of ``cache_name``; return how many were removed."""
        if cache_name not in self._caches:
            return 0

        cache = self._caches[cache_name]
        with self._locks[cache_name]:
            expired_keys = [key for key, entry in cache.items() if entry.is_expired()]
            for key in expired_keys:
                del cache[key]
            return len(expired_keys)

    def _maybe_cleanup(self, cache_name: str):
        """Sweep ``cache_name`` if its own cleanup interval has elapsed."""
        config = self._configs.get(cache_name)
        if not config:
            return

        now = time.time()
        last_sweep = self._cleanup_times.get(cache_name, self._last_cleanup)
        if now - last_sweep > config.cleanup_interval:
            self._cleanup_times[cache_name] = now
            self._last_cleanup = now
            deleted = self._cleanup_expired(cache_name)
            if deleted > 0:
                self._stats["cleanups"] += 1
                self.logger.debug(f"清理缓存 {cache_name}: 删除 {deleted} 个过期条目")
# Create the global cache manager instance (built once, when this module is imported).
cache_manager = GlobalCacheManager()
+43
View File
@@ -0,0 +1,43 @@
"""
缓存策略和数据结构定义
"""
import time
from enum import Enum
from typing import Any, Optional
from dataclasses import dataclass
class CacheStrategy(Enum):
    """Enumeration of cache strategies."""

    TTL = "ttl"  # time-based expiry
    LRU = "lru"  # least recently used
    FIXED_SIZE = "fixed_size"  # fixed size
    TTL_LRU = "ttl_lru"  # hybrid of TTL and LRU
@dataclass
class CacheEntry:
    """A single cached value together with its bookkeeping data."""

    value: Any
    timestamp: float  # creation time, compared against time.time() in is_expired()
    ttl: Optional[float] = None  # lifetime in seconds; None means "never expires"
    access_count: int = 0
    # None is only a construction-time placeholder; __post_init__ replaces it
    # with ``timestamp``, hence Optional rather than a plain float.
    last_access: Optional[float] = None

    def __post_init__(self):
        if self.last_access is None:
            self.last_access = self.timestamp

    def is_expired(self) -> bool:
        """Return True once more than ``ttl`` seconds have passed since ``timestamp``."""
        if self.ttl is None:
            return False
        return time.time() - self.timestamp > self.ttl

    def touch(self):
        """Record an access: refresh ``last_access`` and bump ``access_count``."""
        self.last_access = time.time()
        self.access_count += 1
@@ -0,0 +1,68 @@
"""
时间工具模块
提供统一的时间获取功能
"""
import cnlunar
from datetime import datetime
# Maps English weekday names to their Chinese equivalents.
WEEKDAY_MAP = {
    "Monday": "星期一",
    "Tuesday": "星期二",
    "Wednesday": "星期三",
    "Thursday": "星期四",
    "Friday": "星期五",
    "Saturday": "星期六",
    "Sunday": "星期日",
}
def get_current_time() -> str:
    """Return the current local time formatted as ``HH:MM``."""
    now = datetime.now()
    return format(now, "%H:%M")
def get_current_date() -> str:
    """Return today's local date formatted as ``YYYY-MM-DD``."""
    today = datetime.now()
    return format(today, "%Y-%m-%d")
def get_current_weekday() -> str:
    """Return today's weekday as the Chinese name stored in ``WEEKDAY_MAP``.

    The English lookup key is derived from ``datetime.weekday()`` rather
    than ``strftime("%A")``: ``%A`` yields the locale's weekday name, so
    under a non-English ``LC_TIME`` the lookup would raise ``KeyError``.
    """
    # datetime.weekday(): Monday == 0 ... Sunday == 6.
    english_names = (
        "Monday",
        "Tuesday",
        "Wednesday",
        "Thursday",
        "Friday",
        "Saturday",
        "Sunday",
    )
    return WEEKDAY_MAP[english_names[datetime.now().weekday()]]
def get_current_lunar_date() -> str:
    """Return today's lunar date as Chinese text (year + month + day).

    The last character of the lunar month name is dropped before joining.
    Any failure while computing the date yields the fixed fallback text
    instead of an exception.
    """
    try:
        lunar = cnlunar.Lunar(datetime.now(), godType="8char")
        year_cn = lunar.lunarYearCn
        month_cn = lunar.lunarMonthCn[:-1]
        day_cn = lunar.lunarDayCn
        return f"{year_cn!s}{month_cn!s}{day_cn!s}"
    except Exception:
        return "农历获取失败"
def get_current_time_info() -> tuple:
    """Collect all current-time strings in one call.

    Returns:
        A 4-tuple ``(time HH:MM, date YYYY-MM-DD, weekday, lunar date)``.
    """
    return (
        get_current_time(),
        get_current_date(),
        get_current_weekday(),
        get_current_lunar_date(),
    )
+118
View File
@@ -0,0 +1,118 @@
import uuid
import re
from typing import List, Dict
from datetime import datetime
class Message:
    """One entry of a conversation: role, text and optional tool-call data."""

    def __init__(
        self,
        role: str,
        content: str = None,
        uniq_id: str = None,
        tool_calls=None,
        tool_call_id=None,
    ):
        # A random UUID string is generated only when no id is supplied.
        if uniq_id is None:
            uniq_id = str(uuid.uuid4())
        self.uniq_id = uniq_id
        self.role = role
        self.content = content
        self.tool_calls = tool_calls
        self.tool_call_id = tool_call_id
class Dialogue:
    """Ordered conversation history plus helpers that render it for an LLM call."""

    def __init__(self):
        self.dialogue: List["Message"] = []
        # Formatted once, when the dialogue object is created.
        self.current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    def put(self, message: "Message"):
        """Append ``message`` to the history."""
        self.dialogue.append(message)

    def getMessages(self, m, dialogue):
        """Append the dict form of message ``m`` to the list ``dialogue``.

        Three shapes are produced: a tool-call request (``tool_calls``
        only), a tool result (``tool_call_id`` + ``content``) and a plain
        ``role``/``content`` message.
        """
        if m.tool_calls is not None:
            dialogue.append({"role": m.role, "tool_calls": m.tool_calls})
        elif m.role == "tool":
            # A tool result without an id gets a freshly generated one.
            tool_call_id = (
                str(uuid.uuid4()) if m.tool_call_id is None else m.tool_call_id
            )
            dialogue.append(
                {
                    "role": m.role,
                    "tool_call_id": tool_call_id,
                    "content": m.content,
                }
            )
        else:
            dialogue.append({"role": m.role, "content": m.content})

    def get_llm_dialogue(self) -> List[Dict[str, str]]:
        """Render the history without memory or speaker information."""
        # Delegate so that every call path shares one rendering routine.
        return self.get_llm_dialogue_with_memory(None, None)

    def update_system_message(self, new_content: str):
        """Replace the first system message's content, or add a system message."""
        system_msg = next((msg for msg in self.dialogue if msg.role == "system"), None)
        if system_msg:
            system_msg.content = new_content
        else:
            self.put(Message(role="system", content=new_content))

    @staticmethod
    def _build_speakers_info(voiceprint_config) -> str:
        """Return the ``<speakers_info>`` section, or ``""`` when there is none.

        Each speaker entry is a string ``"<id>,<name>[,<description>]"``.
        Malformed entries are skipped. The section is assembled completely
        before it is returned, so a bad entry can never leave an unclosed
        tag in the prompt.
        """
        if not isinstance(voiceprint_config, dict):
            return ""
        speakers = voiceprint_config.get("speakers", [])
        if not speakers or not isinstance(speakers, (list, tuple)):
            return ""

        separator = "\uff1a"  # full-width colon between name and description
        pieces = ["\n\n<speakers_info>"]
        for speaker_str in speakers:
            if not isinstance(speaker_str, str):
                continue
            parts = speaker_str.split(",", 2)
            if len(parts) < 2:
                continue
            name = parts[1].strip()
            description = parts[2].strip() if len(parts) >= 3 else ""
            # Separate the two fields only when a description exists.
            if description:
                pieces.append(f"\n- {name}{separator}{description}")
            else:
                pieces.append(f"\n- {name}")
        pieces.append("\n\n</speakers_info>")
        return "".join(pieces)

    def get_llm_dialogue_with_memory(
        self, memory_str: str = None, voiceprint_config: dict = None
    ) -> List[Dict[str, str]]:
        """Render the history as a list of message dicts for the LLM.

        The first system message (if any) is emitted first, with the
        ``{{current_time}}`` placeholder replaced by the current ``HH:MM``,
        the speaker section appended, and the ``<memory>...</memory>`` span
        replaced by ``memory_str`` when it is not ``None``. All non-system
        messages follow in their original order.

        Args:
            memory_str: Text to place inside the ``<memory>`` tags.
            voiceprint_config: Dict whose ``"speakers"`` list describes the
                known speakers; ``None`` disables the speaker section.
        """
        dialogue = []

        system_message = next(
            (msg for msg in self.dialogue if msg.role == "system"), None
        )
        if system_message:
            enhanced_system_prompt = system_message.content.replace(
                "{{current_time}}", datetime.now().strftime("%H:%M")
            )
            enhanced_system_prompt += self._build_speakers_info(voiceprint_config)

            if memory_str is not None:
                memory_block = f"<memory>\n{memory_str}\n</memory>"
                # A callable replacement inserts the text literally; a plain
                # string would be treated as a template, so backslashes or
                # "\1" inside the memory text would corrupt it or raise.
                enhanced_system_prompt = re.sub(
                    r"<memory>.*?</memory>",
                    lambda _match: memory_block,
                    enhanced_system_prompt,
                    flags=re.DOTALL,
                )
            dialogue.append({"role": "system", "content": enhanced_system_prompt})

        # The raw system message was already handled above.
        for m in self.dialogue:
            if m.role != "system":
                self.getMessages(m, dialogue)
        return dialogue
+17
View File
@@ -0,0 +1,17 @@
import os
import sys
from config.logger import setup_logging
import importlib
# Module-level logger obtained from the project's logging setup.
logger = setup_logging()
def create_instance(class_name, *args, **kwargs):
    """Build the intent provider found at ``core/providers/intent/<name>/<name>.py``.

    The file is looked up relative to the current working directory. When
    present, the module is imported (once) and its ``IntentProvider`` class
    is instantiated with the remaining arguments.

    Raises:
        ValueError: If no provider file exists for ``class_name``.
    """
    provider_file = os.path.join(
        "core", "providers", "intent", class_name, f"{class_name}.py"
    )
    if not os.path.exists(provider_file):
        raise ValueError(f"不支持的intent类型: {class_name},请检查该配置的type是否设置正确")

    lib_name = f"core.providers.intent.{class_name}.{class_name}"
    if lib_name not in sys.modules:
        sys.modules[lib_name] = importlib.import_module(lib_name)
    provider_module = sys.modules[lib_name]
    return provider_module.IntentProvider(*args, **kwargs)
+23
View File
@@ -0,0 +1,23 @@
import os
import sys
# Add the project root (two directories above this file) to the front of the Python path.
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.abspath(os.path.join(current_dir, "..", ".."))
sys.path.insert(0, project_root)
from config.logger import setup_logging
import importlib
# Module-level logger obtained from the project's logging setup.
logger = setup_logging()
def create_instance(class_name, *args, **kwargs):
    """Build the LLM provider found at ``core/providers/llm/<name>/<name>.py``.

    The file is looked up relative to the current working directory. When
    present, the module is imported (once) and its ``LLMProvider`` class is
    instantiated with the remaining arguments.

    Raises:
        ValueError: If no provider file exists for ``class_name``.
    """
    provider_file = os.path.join(
        "core", "providers", "llm", class_name, f"{class_name}.py"
    )
    if not os.path.exists(provider_file):
        raise ValueError(f"不支持的LLM类型: {class_name},请检查该配置的type是否设置正确")

    lib_name = f"core.providers.llm.{class_name}.{class_name}"
    if lib_name not in sys.modules:
        sys.modules[lib_name] = importlib.import_module(lib_name)
    provider_module = sys.modules[lib_name]
    return provider_module.LLMProvider(*args, **kwargs)
+18
View File
@@ -0,0 +1,18 @@
import os
import sys
import importlib
from config.logger import setup_logging
# Module-level logger obtained from the project's logging setup.
logger = setup_logging()
def create_instance(class_name, *args, **kwargs):
    """Build the memory provider found at ``core/providers/memory/<name>/<name>.py``.

    The file is looked up relative to the current working directory. When
    present, the module is imported (once) and its ``MemoryProvider`` class
    is instantiated with the remaining arguments.

    Raises:
        ValueError: If no provider file exists for ``class_name``.
    """
    provider_file = os.path.join(
        "core", "providers", "memory", class_name, f"{class_name}.py"
    )
    if not os.path.exists(provider_file):
        raise ValueError(f"不支持的记忆服务类型: {class_name}")

    lib_name = f"core.providers.memory.{class_name}.{class_name}"
    if lib_name not in sys.modules:
        sys.modules[lib_name] = importlib.import_module(lib_name)
    provider_module = sys.modules[lib_name]
    return provider_module.MemoryProvider(*args, **kwargs)
@@ -0,0 +1,151 @@
from typing import Dict, Any
from config.logger import setup_logging
from core.utils import tts, llm, intent, memory, vad, asr
# Tag bound to log records emitted from this module (the module's import name).
TAG = __name__
# Module-level logger obtained from the project's logging setup.
logger = setup_logging()
def initialize_modules(
    logger,
    config: Dict[str, Any],
    init_vad=False,
    init_asr=False,
    init_llm=False,
    init_tts=False,
    init_memory=False,
    init_intent=False,
) -> Dict[str, Any]:
    """Create the requested component instances from ``config``.

    Components are built in the order TTS, LLM, Intent, Memory, VAD, ASR.
    For each one the selected module name is read from
    ``config["selected_module"]``; the provider type is that module's
    ``"type"`` setting when present, otherwise the module name itself.

    Args:
        logger: Logger used for the per-component success messages.
        config: Full configuration dictionary.
        init_vad / init_asr / init_llm / init_tts / init_memory / init_intent:
            Flags selecting which components to build.

    Returns:
        Dict[str, Any]: Built components keyed by lower-case component name;
        empty when no flag is set.
    """

    def provider_settings(section: str):
        """Return (selected name, its settings, resolved provider type)."""
        selected = config["selected_module"][section]
        settings = config[section][selected]
        provider_type = settings["type"] if "type" in settings else selected
        return selected, settings, provider_type

    modules = {}

    if init_tts:
        select_tts_module = config["selected_module"]["TTS"]
        modules["tts"] = initialize_tts(config)
        logger.bind(tag=TAG).info(f"初始化组件: tts成功 {select_tts_module}")

    if init_llm:
        select_llm_module, llm_settings, llm_type = provider_settings("LLM")
        modules["llm"] = llm.create_instance(llm_type, llm_settings)
        logger.bind(tag=TAG).info(f"初始化组件: llm成功 {select_llm_module}")

    if init_intent:
        select_intent_module, intent_settings, intent_type = provider_settings(
            "Intent"
        )
        modules["intent"] = intent.create_instance(intent_type, intent_settings)
        logger.bind(tag=TAG).info(f"初始化组件: intent成功 {select_intent_module}")

    if init_memory:
        select_memory_module, memory_settings, memory_type = provider_settings(
            "Memory"
        )
        # The memory provider additionally receives the optional summary setting.
        modules["memory"] = memory.create_instance(
            memory_type,
            memory_settings,
            config.get("summaryMemory", None),
        )
        logger.bind(tag=TAG).info(f"初始化组件: memory成功 {select_memory_module}")

    if init_vad:
        select_vad_module, vad_settings, vad_type = provider_settings("VAD")
        modules["vad"] = vad.create_instance(vad_type, vad_settings)
        logger.bind(tag=TAG).info(f"初始化组件: vad成功 {select_vad_module}")

    if init_asr:
        select_asr_module = config["selected_module"]["ASR"]
        modules["asr"] = initialize_asr(config)
        logger.bind(tag=TAG).info(f"初始化组件: asr成功 {select_asr_module}")

    return modules
def initialize_tts(config):
    """Create the TTS provider selected in ``config``.

    The provider type is the selected module's ``"type"`` setting when
    present, otherwise the module name. The third constructor argument is
    the ``delete_audio`` option (default ``True``), read as true when its
    lower-cased text is ``"true"``, ``"1"`` or ``"yes"``.
    """
    selected = config["selected_module"]["TTS"]
    tts_settings = config["TTS"][selected]
    tts_type = tts_settings["type"] if "type" in tts_settings else selected
    delete_audio = str(config.get("delete_audio", True)).lower() in (
        "true",
        "1",
        "yes",
    )
    return tts.create_instance(tts_type, tts_settings, delete_audio)
def initialize_asr(config):
    """Create the ASR provider selected in ``config`` and log completion.

    The provider type is the selected module's ``"type"`` setting when
    present, otherwise the module name. The third constructor argument is
    the ``delete_audio`` option (default ``True``), read as true when its
    lower-cased text is ``"true"``, ``"1"`` or ``"yes"``.
    """
    selected = config["selected_module"]["ASR"]
    asr_settings = config["ASR"][selected]
    asr_type = asr_settings["type"] if "type" in asr_settings else selected
    delete_audio = str(config.get("delete_audio", True)).lower() in (
        "true",
        "1",
        "yes",
    )
    new_asr = asr.create_instance(asr_type, asr_settings, delete_audio)
    logger.bind(tag=TAG).info("ASR模块初始化完成")
    return new_asr
def initialize_voiceprint(asr_instance, config):
    """Enable voiceprint recognition on ``asr_instance`` when it is configured.

    Returns:
        True when ``asr_instance.init_voiceprint`` ran without raising.
        False when the ``voiceprint`` section is missing or empty, lacks a
        ``url`` or ``speakers`` value, or initialisation raised.
    """
    voiceprint_config = config.get("voiceprint")
    if not voiceprint_config:
        return False

    # Both a service URL and a speaker list are required.
    is_complete = voiceprint_config.get("url") and voiceprint_config.get("speakers")
    if not is_complete:
        logger.bind(tag=TAG).warning("声纹识别配置不完整")
        return False

    try:
        asr_instance.init_voiceprint(voiceprint_config)
        logger.bind(tag=TAG).info("ASR模块声纹识别功能已动态启用")
        logger.bind(tag=TAG).info(f"配置说话人数量: {len(voiceprint_config['speakers'])}")
    except Exception as e:
        logger.bind(tag=TAG).error(f"动态初始化声纹识别功能失败: {str(e)}")
        return False
    return True
@@ -0,0 +1,132 @@
"""
Opus编码工具类
将PCM音频数据编码为Opus格式
"""
import logging
import traceback
import numpy as np
from opuslib_next import Encoder
from opuslib_next import constants
from typing import Optional, Callable, Any
class OpusEncoderUtils:
    """Streaming encoder that turns 16-bit PCM audio into Opus packets.

    Incoming PCM bytes may be split at arbitrary positions: incomplete
    frames, and even a dangling half sample, are buffered until the next
    call supplies the rest.
    """

    def __init__(self, sample_rate: int, channels: int, frame_size_ms: int):
        """Create the Opus encoder.

        Args:
            sample_rate: Sample rate in Hz.
            channels: Channel count (1 = mono, 2 = stereo).
            frame_size_ms: Frame duration in milliseconds.

        Raises:
            RuntimeError: If the underlying Opus encoder cannot be created.
            ValueError: If the parameters yield an empty frame.
        """
        self.sample_rate = sample_rate
        self.channels = channels
        self.frame_size_ms = frame_size_ms
        # Samples per channel in one frame = rate * duration(ms) / 1000.
        self.frame_size = (sample_rate * frame_size_ms) // 1000
        # Interleaved samples in one frame across all channels.
        self.total_frame_size = self.frame_size * channels

        self.bitrate = 24000  # bps
        self.complexity = 10  # highest quality setting

        # Samples waiting for a complete frame.
        self.buffer = np.array([], dtype=np.int16)
        # Trailing byte of a chunk that ended in the middle of a sample.
        self._pending_bytes = b""

        try:
            self.encoder = Encoder(
                sample_rate, channels, constants.APPLICATION_AUDIO
            )
            self.encoder.bitrate = self.bitrate
            self.encoder.complexity = self.complexity
            self.encoder.signal = constants.SIGNAL_VOICE  # tune for speech
        except Exception as e:
            logging.error(f"初始化Opus编码器失败: {e}")
            raise RuntimeError("初始化失败") from e

        # A non-positive frame size would make the frame loop in
        # encode_pcm_to_opus_stream spin forever, so reject it up front.
        if self.total_frame_size <= 0:
            raise ValueError(
                "frame size must be positive "
                f"(sample_rate={sample_rate}, channels={channels}, "
                f"frame_size_ms={frame_size_ms})"
            )

    def reset_state(self):
        """Reset the encoder and discard all buffered audio."""
        self.encoder.reset_state()
        self.buffer = np.array([], dtype=np.int16)
        self._pending_bytes = b""

    def encode_pcm_to_opus_stream(self, pcm_data: bytes, end_of_stream: bool, callback: Callable[[Any], Any]):
        """Encode PCM data to Opus in a streaming fashion.

        Every complete frame is encoded and handed to ``callback``.
        Leftover samples stay buffered for the next call. When
        ``end_of_stream`` is true the remainder is zero-padded to one final
        frame and the buffers are cleared.

        Args:
            pcm_data: Little-endian 16-bit PCM bytes; any length is accepted.
            end_of_stream: Whether this is the last chunk of the stream.
            callback: Called once with each encoded Opus packet.

        Returns:
            None. Packets are delivered through ``callback``.
        """
        # Re-attach a byte carried over from the previous chunk and hold
        # back a new dangling byte, so an odd-length chunk neither crashes
        # the conversion nor shifts the sample alignment.
        data = self._pending_bytes + pcm_data if self._pending_bytes else pcm_data
        usable = len(data) - (len(data) % 2)
        self._pending_bytes = bytes(data[usable:])

        new_samples = self._convert_bytes_to_shorts(data[:usable])
        self._validate_pcm_data(new_samples)
        self.buffer = np.append(self.buffer, new_samples)

        offset = 0
        # Emit every complete frame currently in the buffer.
        while offset <= len(self.buffer) - self.total_frame_size:
            frame = self.buffer[offset : offset + self.total_frame_size]
            output = self._encode(frame)
            if output:
                callback(output)
            offset += self.total_frame_size

        # Keep what did not fill a frame.
        self.buffer = self.buffer[offset:]

        if end_of_stream:
            if len(self.buffer) > 0:
                # Pad the tail with silence to a full frame.
                last_frame = np.zeros(self.total_frame_size, dtype=np.int16)
                last_frame[: len(self.buffer)] = self.buffer
                output = self._encode(last_frame)
                if output:
                    callback(output)
                self.buffer = np.array([], dtype=np.int16)
            # Half a sample cannot be encoded; drop it with the stream.
            self._pending_bytes = b""

    def _encode(self, frame: np.ndarray) -> Optional[bytes]:
        """Encode one frame; return the packet, or None if encoding failed."""
        try:
            frame_bytes = frame.tobytes()
            return self.encoder.encode(frame_bytes, self.frame_size)
        except Exception as e:
            logging.error(f"Opus编码失败: {e}")
            traceback.print_exc()
            return None

    def _convert_bytes_to_shorts(self, bytes_data: bytes) -> np.ndarray:
        """Convert little-endian 16-bit PCM bytes to an int16 array.

        ``bytes_data`` must contain an even number of bytes.
        """
        # "<i2" fixes the byte order the input is assumed to have; the cast
        # to native int16 is a no-op on little-endian hosts.
        return np.frombuffer(bytes_data, dtype="<i2").astype(np.int16, copy=False)

    def _validate_pcm_data(self, pcm_shorts: np.ndarray) -> None:
        """Log a warning when samples fall outside the 16-bit PCM range."""
        if pcm_shorts.dtype == np.int16:
            # int16 storage cannot hold an out-of-range value.
            return
        out_of_range = (pcm_shorts < -32768) | (pcm_shorts > 32767)
        if np.any(out_of_range):
            invalid_samples = pcm_shorts[out_of_range]
            logging.warning(f"发现无效PCM样本: {invalid_samples[:5]}...")

    def close(self):
        """Release resources; the body is a no-op and nothing is closed explicitly."""
        pass
@@ -0,0 +1,50 @@
import datetime
from typing import Dict, Tuple
# Global dict holding each device's daily output character count, keyed by (device_id, date).
_device_daily_output: Dict[Tuple[str, datetime.date], int] = {}
# Date of the most recent check; None until add_device_output has run once.
_last_check_date: datetime.date = None
def reset_device_output():
    """Forget the output counts of all devices.

    Intended to be called at midnight each day. The counter dict is
    emptied in place.
    """
    _device_daily_output.clear()
def get_device_output(device_id: str) -> int:
    """Return how many characters ``device_id`` has output today (0 if none recorded)."""
    today_key = (device_id, datetime.datetime.now().date())
    return _device_daily_output.get(today_key, 0)
def add_device_output(device_id: str, char_count: int):
    """Add ``char_count`` to today's output total of ``device_id``.

    On the first call, and whenever the calendar date has changed since the
    previous call, all stored counts are dropped before the new amount is
    recorded.
    """
    global _last_check_date

    today = datetime.datetime.now().date()
    # The initial None also compares unequal to any date, so this covers
    # both "never checked" and "the day rolled over".
    if _last_check_date != today:
        _device_daily_output.clear()
        _last_check_date = today

    today_key = (device_id, today)
    _device_daily_output[today_key] = (
        _device_daily_output.get(today_key, 0) + char_count
    )
def check_device_output_limit(device_id: str, max_output_size: int) -> bool:
    """Tell whether ``device_id`` has reached its daily output limit.

    :return: True when today's output is at or above ``max_output_size``;
        False otherwise, and always False for an empty ``device_id``
    """
    return bool(device_id) and get_device_output(device_id) >= max_output_size
+59
View File
@@ -0,0 +1,59 @@
import struct
def decode_opus_from_file(input_file, frame_duration_ms=60):
    """Read Opus packets from a p3 file.

    The file is a sequence of records. Each record is a 4-byte header
    (1 byte type, 1 byte reserved, 2 bytes big-endian payload length)
    followed by the Opus payload.

    Args:
        input_file: Path of the p3 file.
        frame_duration_ms: Duration of one packet in milliseconds, used
            only to compute the total duration (default 60).

    Returns:
        ``(packets, total_duration)``: the list of Opus payloads and the
        total duration in seconds.

    Raises:
        ValueError: If the file ends in the middle of a header or payload.
    """
    opus_datas = []

    with open(input_file, "rb") as f:
        while True:
            header = f.read(4)
            if not header:
                break  # clean end of file
            if len(header) < 4:
                # Report a clear error instead of an opaque struct.error.
                raise ValueError(f"Truncated header({len(header)} bytes) in the file.")

            _, _, data_len = struct.unpack(">BBH", header)

            opus_data = f.read(data_len)
            if len(opus_data) != data_len:
                raise ValueError(
                    f"Data length({len(opus_data)}) mismatch({data_len}) in the file."
                )

            opus_datas.append(opus_data)

    total_duration = (len(opus_datas) * frame_duration_ms) / 1000.0
    return opus_datas, total_duration
def decode_opus_from_bytes(input_bytes, frame_duration_ms=60):
    """Read Opus packets from in-memory p3 data.

    The data is a sequence of records. Each record is a 4-byte header
    (1 byte type, 1 byte reserved, 2 bytes big-endian payload length)
    followed by the Opus payload.

    Args:
        input_bytes: The p3 data.
        frame_duration_ms: Duration of one packet in milliseconds, used
            only to compute the total duration (default 60).

    Returns:
        ``(packets, total_duration)``: the list of Opus payloads and the
        total duration in seconds.

    Raises:
        ValueError: If the data ends in the middle of a header or payload.
    """
    import io

    opus_datas = []
    f = io.BytesIO(input_bytes)
    while True:
        header = f.read(4)
        if not header:
            break  # clean end of data
        if len(header) < 4:
            # Report a clear error instead of an opaque struct.error.
            raise ValueError(f"Truncated header({len(header)} bytes) in the bytes.")

        _, _, data_len = struct.unpack(">BBH", header)

        opus_data = f.read(data_len)
        if len(opus_data) != data_len:
            raise ValueError(
                f"Data length({len(opus_data)}) mismatch({data_len}) in the bytes."
            )

        opus_datas.append(opus_data)

    total_duration = (len(opus_datas) * frame_duration_ms) / 1000.0
    return opus_datas, total_duration
@@ -0,0 +1,241 @@
"""
系统提示词管理器模块
负责管理和更新系统提示词,包括快速初始化和异步增强功能
"""
import os
import cnlunar
from typing import Dict, Any
from config.logger import setup_logging
from jinja2 import Template
# Tag bound to log records emitted from this module (the module's import name).
TAG = __name__

# Maps English weekday names to their Chinese equivalents.
WEEKDAY_MAP = {
    "Monday": "星期一",
    "Tuesday": "星期二",
    "Wednesday": "星期三",
    "Thursday": "星期四",
    "Friday": "星期五",
    "Saturday": "星期六",
    "Sunday": "星期日",
}

# Emoji passed to the prompt template as ``emojiList``.
EMOJI_List = [
    "😶",
    "🙂",
    "😆",
    "😂",
    "😔",
    "😠",
    "😭",
    "😍",
    "😳",
    "😲",
    "😱",
    "🤔",
    "😉",
    "😎",
    "😌",
    "🤤",
    "😘",
    "😏",
    "😴",
    "😜",
    "🙄",
]
class PromptManager:
    """Build and cache system prompts.

    Two flavours are offered: a quick prompt that simply returns the user's
    configured text, and an enhanced prompt rendered from the base template
    with date, location and weather context.
    """

    def __init__(self, config: Dict[str, Any], logger=None):
        self.config = config
        self.logger = logger or setup_logging()
        self.base_prompt_template = None
        self.last_update_time = 0

        # The global cache manager is imported here, at construction time.
        from core.utils.cache.manager import cache_manager, CacheType

        self.cache_manager = cache_manager
        self.CacheType = CacheType

        self._load_base_template()

    def _load_base_template(self):
        """Load the base prompt template from the cache or, failing that, from disk.

        The path comes from ``config["prompt_template"]`` and defaults to
        ``agent-base-prompt.txt``. Failures are logged, never raised.
        """
        log = self.logger.bind(tag=TAG)
        try:
            template_path = self.config.get("prompt_template", "agent-base-prompt.txt")
            cache_key = f"prompt_template:{template_path}"

            cached_template = self.cache_manager.get(self.CacheType.CONFIG, cache_key)
            if cached_template is not None:
                self.base_prompt_template = cached_template
                log.debug("从缓存加载基础提示词模板")
                return

            # Cache miss: fall back to the file.
            if not os.path.exists(template_path):
                log.warning(f"未找到{template_path}文件")
                return

            with open(template_path, "r", encoding="utf-8") as f:
                template_content = f.read()

            self.cache_manager.set(self.CacheType.CONFIG, cache_key, template_content)
            self.base_prompt_template = template_content
            log.debug("成功加载基础提示词模板并缓存")
        except Exception as e:
            self.logger.bind(tag=TAG).error(f"加载提示词模板失败: {e}")

    def get_quick_prompt(self, user_prompt: str, device_id: str = None) -> str:
        """Return the device's cached prompt, or ``user_prompt`` when none is cached.

        NOTE(review): the fallback prompt is stored under ``CacheType.CONFIG``
        while the lookup reads ``CacheType.DEVICE_PROMPT``, so this method
        never reads back what it stores — confirm which cache type is meant.
        """
        cache_key = f"device_prompt:{device_id}"
        cached_prompt = self.cache_manager.get(self.CacheType.DEVICE_PROMPT, cache_key)
        if cached_prompt is not None:
            self.logger.bind(tag=TAG).debug(f"使用设备 {device_id} 的缓存提示词")
            return cached_prompt

        self.logger.bind(tag=TAG).debug(
            f"设备 {device_id} 无缓存提示词,使用传入的提示词"
        )

        # Remember the supplied prompt only when a device id is known.
        if device_id:
            self.cache_manager.set(
                self.CacheType.CONFIG, f"device_prompt:{device_id}", user_prompt
            )
            self.logger.bind(tag=TAG).debug(f"设备 {device_id} 的提示词已缓存")

        self.logger.bind(tag=TAG).info(f"使用快速提示词: {user_prompt[:50]}...")
        return user_prompt

    def _get_current_time_info(self) -> tuple:
        """Return ``(date, weekday, lunar date + newline)`` for the current moment."""
        from .current_time import get_current_date, get_current_weekday, get_current_lunar_date

        date_text = get_current_date()
        weekday_text = get_current_weekday()
        lunar_text = get_current_lunar_date() + "\n"
        return date_text, weekday_text, lunar_text

    def _get_location_info(self, client_ip: str) -> str:
        """Return the city for ``client_ip``, from the cache or via the IP lookup helper.

        The looked-up value is cached. Any failure is logged and yields the
        fixed "unknown location" text.
        """
        try:
            cached_location = self.cache_manager.get(self.CacheType.LOCATION, client_ip)
            if cached_location is not None:
                return cached_location

            # Cache miss: ask the IP lookup helper.
            from core.utils.util import get_ip_info

            ip_info = get_ip_info(client_ip, self.logger)
            location = f"{ip_info.get('city', '未知位置')}"

            self.cache_manager.set(self.CacheType.LOCATION, client_ip, location)
            return location
        except Exception as e:
            self.logger.bind(tag=TAG).error(f"获取位置信息失败: {e}")
            return "未知位置"

    def _get_weather_info(self, conn, location: str) -> str:
        """Return the weather report for ``location``, from the cache or the weather plugin.

        Only a result that is an ``ActionResponse`` is cached and returned.
        Anything else, and any exception, yields the fixed failure text.
        """
        failure_text = "天气信息获取失败"
        try:
            cached_weather = self.cache_manager.get(self.CacheType.WEATHER, location)
            if cached_weather is not None:
                return cached_weather

            # Cache miss: call the weather plugin.
            from plugins_func.functions.get_weather import get_weather
            from plugins_func.register import ActionResponse

            result = get_weather(conn, location=location, lang="zh_CN")
            if not isinstance(result, ActionResponse):
                return failure_text

            weather_report = result.result
            self.cache_manager.set(self.CacheType.WEATHER, location, weather_report)
            return weather_report
        except Exception as e:
            self.logger.bind(tag=TAG).error(f"获取天气信息失败: {e}")
            return failure_text

    def update_context_info(self, conn, client_ip: str):
        """Synchronously warm the location and weather caches for ``client_ip``."""
        try:
            location = self._get_location_info(client_ip)
            # Called for its caching side effect; the report itself is not used here.
            self._get_weather_info(conn, location)
            self.logger.bind(tag=TAG).info("上下文信息更新完成")
        except Exception as e:
            self.logger.bind(tag=TAG).error(f"更新上下文信息失败: {e}")

    def build_enhanced_prompt(
        self, user_prompt: str, device_id: str, client_ip: str = None, *args, **kwargs
    ) -> str:
        """Render the base template into an enhanced system prompt.

        Date information is computed fresh. Location and weather are read
        from the cache only (no lookup is triggered here). The rendered
        prompt is stored in the ``DEVICE_PROMPT`` cache for ``device_id``.
        ``user_prompt`` is returned unchanged when no template is loaded or
        rendering fails. Extra ``*args``/``**kwargs`` are forwarded to the
        template's ``render`` call.
        """
        if not self.base_prompt_template:
            return user_prompt

        try:
            # Time information is always current and never cached.
            today_date, today_weekday, lunar_date = self._get_current_time_info()

            local_address = ""
            weather_info = ""
            if client_ip:
                local_address = (
                    self.cache_manager.get(self.CacheType.LOCATION, client_ip) or ""
                )
                if local_address:
                    weather_info = (
                        self.cache_manager.get(self.CacheType.WEATHER, local_address)
                        or ""
                    )

            # current_time is rendered back as a literal placeholder so it
            # remains in the output text.
            enhanced_prompt = Template(self.base_prompt_template).render(
                base_prompt=user_prompt,
                current_time="{{current_time}}",
                today_date=today_date,
                today_weekday=today_weekday,
                lunar_date=lunar_date,
                local_address=local_address,
                weather_info=weather_info,
                emojiList=EMOJI_List,
                device_id=device_id,
                *args, **kwargs
            )

            self.cache_manager.set(
                self.CacheType.DEVICE_PROMPT,
                f"device_prompt:{device_id}",
                enhanced_prompt,
            )
            self.logger.bind(tag=TAG).info(
                f"构建增强提示词成功,长度: {len(enhanced_prompt)}"
            )
            return enhanced_prompt
        except Exception as e:
            self.logger.bind(tag=TAG).error(f"构建增强提示词失败: {e}")
            return user_prompt
+113
View File
@@ -0,0 +1,113 @@
import json
TAG = __name__
# Maps an emoji character found in LLM output to the emotion name that is
# sent to the client. Two emoji ("😂" and "😆") both map to "laughing".
EMOJI_MAP = {
    "😂": "laughing",
    "😭": "crying",
    "😠": "angry",
    "😔": "sad",
    "😍": "loving",
    "😲": "surprised",
    "😱": "shocked",
    "🤔": "thinking",
    "😌": "relaxed",
    "😴": "sleepy",
    "😜": "silly",
    "🙄": "confused",
    "😶": "neutral",
    "🙂": "happy",
    "😆": "laughing",
    "😳": "embarrassed",
    "😉": "winking",
    "😎": "cool",
    "🤤": "delicious",
    "😘": "kissy",
    "😏": "confident",
}
# Inclusive (start, end) code-point ranges that is_emoji() treats as emoji.
EMOJI_RANGES = [
    (0x1F600, 0x1F64F),  # Emoticons
    (0x1F300, 0x1F5FF),  # Miscellaneous Symbols and Pictographs
    (0x1F680, 0x1F6FF),  # Transport and Map Symbols
    (0x1F900, 0x1F9FF),  # Supplemental Symbols and Pictographs
    (0x1FA70, 0x1FAFF),  # Symbols and Pictographs Extended-A
    (0x2600, 0x26FF),  # Miscellaneous Symbols
    (0x2700, 0x27BF),  # Dingbats
]
def get_string_no_punctuation_or_emoji(s):
    """Strip leading and trailing whitespace, punctuation and emoji from ``s``.

    Characters in the middle of the string are left untouched.
    """
    symbols = list(s)
    total = len(symbols)

    # Advance past every strippable character at the front.
    left = 0
    while left < total and is_punctuation_or_emoji(symbols[left]):
        left += 1

    # Walk back from the end; ``right`` is an exclusive bound.
    right = total
    while right > left and is_punctuation_or_emoji(symbols[right - 1]):
        right -= 1

    return "".join(symbols[left:right])
def is_punctuation_or_emoji(char):
    """Return True if ``char`` is whitespace, a listed punctuation mark or an emoji.

    The full-width (Chinese) marks are written as Unicode escapes so the
    literals cannot be silently lost or normalised when the file is
    re-encoded; previously they had degraded to empty strings, which can
    never match a single character.

    Args:
        char: A single-character string.

    Returns:
        bool: True when the character should be stripped.
    """
    # Chinese and English punctuation to strip (full-width and half-width).
    punctuation_set = {
        "\uff0c",  # full-width comma
        ",",
        "\u3002",  # ideographic full stop
        ".",
        "\uff01",  # full-width exclamation mark
        "!",
        "\u201c",  # left double quotation mark
        "\u201d",  # right double quotation mark
        '"',
        "\uff1a",  # full-width colon
        ":",
        "-",
        "\uff0d",  # full-width hyphen-minus
        "\u3001",  # ideographic comma
        "[",
        "]",
        "\u3010",  # left black lenticular bracket
        "\u3011",  # right black lenticular bracket
    }
    if char.isspace() or char in punctuation_set:
        return True
    return is_emoji(char)
async def get_emotion(conn, text):
    """Send the emotion implied by the first known emoji in ``text``.

    The first character of ``text`` that is a key of ``EMOJI_MAP`` decides
    the emoji/emotion pair; when none is found the default is "🙂"/"happy".
    The pair is sent over ``conn.websocket`` as an ``llm`` JSON message.
    Send failures are logged as warnings and not re-raised.
    """
    matched = next((ch for ch in text if ch in EMOJI_MAP), None)
    if matched is None:
        emoji, emotion = "🙂", "happy"
    else:
        emoji, emotion = matched, EMOJI_MAP[matched]

    try:
        message = {
            "type": "llm",
            "text": emoji,
            "emotion": emotion,
            "session_id": conn.session_id,
        }
        await conn.websocket.send(json.dumps(message))
    except Exception as err:
        conn.logger.bind(tag=TAG).warning(f"发送情绪表情失败,错误:{err}")
def is_emoji(char):
    """Return True if the single character ``char`` lies in one of EMOJI_RANGES."""
    code_point = ord(char)
    for low, high in EMOJI_RANGES:
        if low <= code_point <= high:
            return True
    return False
def check_emoji(text):
    """Return ``text`` with every emoji and every newline character removed."""
    kept = []
    for ch in text:
        if is_emoji(ch) or ch == "\n":
            continue
        kept.append(ch)
    return ''.join(kept)
+138
View File
@@ -0,0 +1,138 @@
import os
import re
import sys
from config.logger import setup_logging
import importlib
logger = setup_logging()
# Punctuation that MarkdownCleaner.clean_markdown tolerates when deciding
# whether a text is "plain" (ASCII / whitespace / punctuation only).
# The full-width (Chinese) marks are written as Unicode escapes so they
# cannot be lost or normalised on re-encoding; they had previously degraded
# to empty strings, which never match a single character.
punctuation_set = {
    "\uff0c",  # full-width comma
    ",",
    "\u3002",  # ideographic full stop
    ".",
    "\uff01",  # full-width exclamation mark
    "!",
    "\u201c",  # left double quotation mark
    "\u201d",  # right double quotation mark
    '"',
    "\uff1a",  # full-width colon
    ":",
    "-",
    "\uff0d",  # full-width hyphen-minus
    "\u3001",  # ideographic comma
    "[",
    "]",
    "\u3010",  # left black lenticular bracket
    "\u3011",  # right black lenticular bracket
    "~",  # tilde
}
def create_instance(class_name, *args, **kwargs):
    """Create a TTS provider instance by module name.

    Looks for ``core/providers/tts/<class_name>.py`` relative to the current
    working directory, imports it once (cached in ``sys.modules``) and
    instantiates its ``TTSProvider`` with the given arguments.

    Raises:
        ValueError: If no such provider module file exists.
    """
    provider_file = os.path.join('core', 'providers', 'tts', f'{class_name}.py')
    if not os.path.exists(provider_file):
        raise ValueError(f"不支持的TTS类型: {class_name},请检查该配置的type是否设置正确")

    module_name = f'core.providers.tts.{class_name}'
    if module_name not in sys.modules:
        sys.modules[module_name] = importlib.import_module(module_name)
    return sys.modules[module_name].TTSProvider(*args, **kwargs)
class MarkdownCleaner:
"""
封装 Markdown 清理逻辑:直接用 MarkdownCleaner.clean_markdown(text) 即可
"""
# 公式字符
NORMAL_FORMULA_CHARS = re.compile(r'[a-zA-Z\\^_{}\+\-\(\)\[\]=]')
@staticmethod
def _replace_inline_dollar(m: re.Match) -> str:
"""
只要捕获到完整的 "$...$":
- 如果内部有典型公式字符 => 去掉两侧 $
- 否则 (纯数字/货币等) => 保留 "$...$"
"""
content = m.group(1)
if MarkdownCleaner.NORMAL_FORMULA_CHARS.search(content):
return content
else:
return m.group(0)
@staticmethod
def _replace_table_block(match: re.Match) -> str:
"""
当匹配到一个整段表格块时,回调该函数。
"""
block_text = match.group('table_block')
lines = block_text.strip('\n').split('\n')
parsed_table = []
for line in lines:
line_stripped = line.strip()
if re.match(r'^\|\s*[-:]+\s*(\|\s*[-:]+\s*)+\|?$', line_stripped):
continue
columns = [col.strip() for col in line_stripped.split('|') if col.strip() != '']
if columns:
parsed_table.append(columns)
if not parsed_table:
return ""
headers = parsed_table[0]
data_rows = parsed_table[1:] if len(parsed_table) > 1 else []
lines_for_tts = []
if len(parsed_table) == 1:
# 只有一行
only_line_str = ", ".join(parsed_table[0])
lines_for_tts.append(f"单行表格:{only_line_str}")
else:
lines_for_tts.append(f"表头是:{', '.join(headers)}")
for i, row in enumerate(data_rows, start=1):
row_str_list = []
for col_index, cell_val in enumerate(row):
if col_index < len(headers):
row_str_list.append(f"{headers[col_index]} = {cell_val}")
else:
row_str_list.append(cell_val)
lines_for_tts.append(f"{i} 行:{', '.join(row_str_list)}")
return "\n".join(lines_for_tts) + "\n"
# 预编译所有正则表达式(按执行频率排序)
# 这里要把 replace_xxx 的静态方法放在最前定义,以便在列表里能正确引用它们。
REGEXES = [
(re.compile(r'```.*?```', re.DOTALL), ''), # 代码块
(re.compile(r'^#+\s*', re.MULTILINE), ''), # 标题
(re.compile(r'(\*\*|__)(.*?)\1'), r'\2'), # 粗体
(re.compile(r'(\*|_)(?=\S)(.*?)(?<=\S)\1'), r'\2'), # 斜体
(re.compile(r'!\[.*?\]\(.*?\)'), ''), # 图片
(re.compile(r'\[(.*?)\]\(.*?\)'), r'\1'), # 链接
(re.compile(r'^\s*>+\s*', re.MULTILINE), ''), # 引用
(
re.compile(r'(?P<table_block>(?:^[^\n]*\|[^\n]*\n)+)', re.MULTILINE),
_replace_table_block
),
(re.compile(r'^\s*[*+-]\s*', re.MULTILINE), '- '), # 列表
(re.compile(r'\$\$.*?\$\$', re.DOTALL), ''), # 块级公式
(
re.compile(r'(?<![A-Za-z0-9])\$([^\n$]+)\$(?![A-Za-z0-9])'),
_replace_inline_dollar
),
(re.compile(r'\n{2,}'), '\n'), # 多余空行
]
@staticmethod
def clean_markdown(text: str) -> str:
"""
主入口方法:依序执行所有正则,移除或替换 Markdown 元素
"""
# 检查文本是否全为英文和基本标点符号
if text and all((c.isascii() or c.isspace() or c in punctuation_set) for c in text):
# 保留原始空格,直接返回
return text
for regex, replacement in MarkdownCleaner.REGEXES:
text = regex.sub(replacement, text)
return text.strip()
+542
View File
@@ -0,0 +1,542 @@
import re
import os
import json
import copy
import wave
import socket
import requests
import subprocess
import numpy as np
import opuslib_next
from io import BytesIO
from core.utils import p3
from pydub import AudioSegment
from typing import Callable, Any
TAG = __name__
# Maps an emotion name to the emoji character that represents it.
emoji_map = {
    "neutral": "😶",
    "happy": "🙂",
    "laughing": "😆",
    "funny": "😂",
    "sad": "😔",
    "angry": "😠",
    "crying": "😭",
    "loving": "😍",
    "embarrassed": "😳",
    "surprised": "😲",
    "shocked": "😱",
    "thinking": "🤔",
    "winking": "😉",
    "cool": "😎",
    "relaxed": "😌",
    "delicious": "🤤",
    "kissy": "😘",
    "confident": "😏",
    "sleepy": "😴",
    "silly": "😜",
    "confused": "🙄",
}
def get_local_ip():
    """Return the local IPv4 address used for outbound traffic.

    A UDP socket is "connected" to 8.8.8.8:80 so that the socket's own
    address can be read back with ``getsockname``. The socket is managed by
    a ``with`` block so it is closed even when ``connect`` or
    ``getsockname`` fails.

    Returns:
        str: The local IP address, or "127.0.0.1" if it cannot be determined.
    """
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
            # Connect to Google's DNS servers
            s.connect(("8.8.8.8", 80))
            return s.getsockname()[0]
    except OSError:
        return "127.0.0.1"
def is_private_ip(ip_addr):
    """Check whether an IP address is private, loopback or link-local.

    Parsing is delegated to the standard ``ipaddress`` module, so compressed
    IPv6 forms such as ``::1`` or ``fd12::1`` are recognised and malformed
    input such as ``10.999.1.1`` is rejected.

    Ranges treated as private:
        IPv4: 10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16, 127.0.0.0/8
              (loopback), 169.254.0.0/16 (link-local).
        IPv6: fc00::/7 (unique local), ::1 (loopback), fe80::/10
              (link-local). IPv4-mapped addresses are judged by their
              embedded IPv4 address.

    Args:
        ip_addr: The IP address to check, as a string.

    Returns:
        bool: True if the address is private; False otherwise, including for
        non-string or malformed input.
    """
    import ipaddress

    # ip_address() would also accept ints/bytes; only strings are addresses here.
    if not isinstance(ip_addr, str):
        return False

    try:
        addr = ipaddress.ip_address(ip_addr.strip())
    except ValueError:
        return False  # Invalid IP address format

    if addr.version == 6 and addr.ipv4_mapped is not None:
        addr = addr.ipv4_mapped

    if addr.is_loopback or addr.is_link_local:
        return True

    if addr.version == 4:
        private_networks = (
            ipaddress.ip_network("10.0.0.0/8"),
            ipaddress.ip_network("172.16.0.0/12"),
            ipaddress.ip_network("192.168.0.0/16"),
        )
    else:
        private_networks = (ipaddress.ip_network("fc00::/7"),)

    return any(addr in network for network in private_networks)
def get_ip_info(ip_addr, logger):
    """Look up the city for an IP address, using the global cache.

    Private addresses cannot be geolocated, so the lookup service is queried
    with an empty ``ip`` parameter for them. The result is always cached
    under the caller's original ``ip_addr`` so that later calls for the same
    address hit the cache (previously private addresses were stored under an
    empty key and never found again).

    Args:
        ip_addr: Client IP address.
        logger: Logger exposing ``bind(tag=...)``; used to report failures.

    Returns:
        dict: ``{"city": ...}`` on success, or ``{}`` on any failure.
    """
    try:
        # Import the global cache manager lazily.
        from core.utils.cache.manager import cache_manager, CacheType

        # Try the cache first.
        cached_ip_info = cache_manager.get(CacheType.IP_INFO, ip_addr)
        if cached_ip_info is not None:
            return cached_ip_info

        # Cache miss: call the API.
        query_ip = "" if is_private_ip(ip_addr) else ip_addr
        url = f"https://whois.pconline.com.cn/ipJson.jsp?json=true&ip={query_ip}"
        # A timeout keeps an unresponsive service from blocking the caller forever.
        resp = requests.get(url, timeout=5).json()
        ip_info = {"city": resp.get("city")}

        # Store under the same key that was used for the lookup.
        cache_manager.set(CacheType.IP_INFO, ip_addr, ip_info)
        return ip_info
    except Exception as e:
        logger.bind(tag=TAG).error(f"Error getting client ip info: {e}")
        return {}
def write_json_file(file_path, data):
    """Serialize ``data`` as JSON (UTF-8, non-ASCII kept, 4-space indent) to ``file_path``."""
    with open(file_path, "w", encoding="utf-8") as out:
        json.dump(data, out, ensure_ascii=False, indent=4)
def remove_punctuation_and_length(text):
    """Strip punctuation and spaces from ``text`` and report what is left.

    Removes ASCII punctuation, the full-width forms of the same characters,
    the ideographic full stop, and both half-width and full-width spaces.
    The full-width set is derived from the half-width one (code point +
    0xFEE0) rather than typed literally; the previous literal had degraded
    into ``"!"`` followed by a ``#`` comment, so no full-width punctuation
    was actually removed.

    Args:
        text: Input string.

    Returns:
        tuple[int, str]: ``(length, stripped_text)``. A stripped result of
        exactly "Yeah" is treated as empty and yields ``(0, "")``.
    """
    half_width_punctuations = r'!"#$%&\'()*+,-./:;<=>?@[\]^_`{|}~'
    # Full-width forms live exactly 0xFEE0 above their ASCII counterparts.
    full_width_punctuations = "".join(
        chr(ord(ch) + 0xFEE0) for ch in half_width_punctuations
    ) + "\u3002"  # ideographic full stop
    removable = (
        set(half_width_punctuations)
        | set(full_width_punctuations)
        | {" ", "\u3000"}  # half-width and full-width space
    )

    result = "".join(char for char in text if char not in removable)

    if result == "Yeah":
        return 0, ""
    return len(result), result
def check_model_key(modelType, modelKey):
    """Detect an API key that still holds the configuration placeholder.

    The previous check tested for an empty substring, which is contained in
    every string, so every key was reported as unset. The check now looks
    for the Chinese placeholder character "你".

    NOTE(review): the marker is inferred from the "你的" placeholder test in
    get_vision_url — confirm it matches the shipped config template.

    Args:
        modelType: Human-readable model category used in the message.
        modelKey: The configured API key.

    Returns:
        str | None: An error message when the key looks unset, else None.
    """
    if isinstance(modelKey, str) and "\u4f60" in modelKey:
        return f"配置错误: {modelType} 的 API key 未设置,当前值为: {modelKey}"
    return None
def parse_string_to_list(value, separator=";"):
    """Convert the input value into a list.

    Args:
        value: Input value; may be None, a string or a list.
        separator: Delimiter used to split strings, ";" by default.

    Returns:
        list: ``[]`` for None, the empty string or an unsupported type; the
        stripped non-empty pieces for a string; the same list object for a
        list.
    """
    if value is None or value == "":
        return []
    if isinstance(value, str):
        pieces = (piece.strip() for piece in value.split(separator))
        return [piece for piece in pieces if piece]
    if isinstance(value, list):
        return value
    return []
def check_ffmpeg_installed() -> bool:
    """Check that ffmpeg is installed and runnable in the current environment.

    Runs ``ffmpeg -version`` and inspects the output.

    Returns:
        bool: True when ffmpeg reports a version.

    Raises:
        ValueError: When ffmpeg is missing, fails to start (e.g. a missing
            shared library) or prints no recognisable version line. The
            message contains troubleshooting hints.
    """
    try:
        # Try to run the ffmpeg command.
        result = subprocess.run(
            ["ffmpeg", "-version"],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            check=True,  # a non-zero exit code raises CalledProcessError
        )
    except (subprocess.CalledProcessError, FileNotFoundError) as e:
        # Extract the error output.
        if isinstance(e, subprocess.CalledProcessError):
            stderr_output = (e.stderr or "").strip()
        else:
            stderr_output = str(e).strip()
        # Lower-case once: OS messages are capitalised ("No such file or
        # directory"), so a case-sensitive match never succeeded before.
        stderr_lower = stderr_output.lower()

        # Base error message.
        error_msg = [
            "❌ 检测到 ffmpeg 无法正常运行。\n",
            "建议您:",
            "1. 确认已正确激活 conda 环境;",
            "2. 查阅项目安装文档,了解如何在 conda 环境中安装 ffmpeg。\n",
        ]

        # Add targeted hints for known failure modes.
        executable_missing = isinstance(e, FileNotFoundError) or (
            "no such file or directory" in stderr_lower and "ffmpeg" in stderr_lower
        )
        if "libiconv.so.2" in stderr_output:
            error_msg.append("⚠️ 发现缺少依赖库:libiconv.so.2")
            error_msg.append("解决方法:在当前 conda 环境中执行:")
            error_msg.append("   conda install -c conda-forge libiconv\n")
        elif executable_missing:
            error_msg.append("⚠️ 系统未找到 ffmpeg 可执行文件。")
            error_msg.append("解决方法:在当前 conda 环境中执行:")
            error_msg.append("   conda install -c conda-forge ffmpeg\n")
        else:
            error_msg.append("错误详情:")
            error_msg.append(stderr_output or "未知错误。")

        # Raise with the detailed message, keeping the original cause.
        raise ValueError("\n".join(error_msg)) from e

    output = (result.stdout + result.stderr).lower()
    if "ffmpeg version" in output:
        return True

    # No version information in the output is also treated as a failure.
    raise ValueError("未检测到有效的 ffmpeg 版本输出。")
def extract_json_from_string(input_string):
    """Return the span from the first "{" to the last "}" in the string, or None.

    The match is greedy and spans newlines (``re.DOTALL``); the result is not
    validated as JSON.
    """
    found = re.search(r"(\{.*\})", input_string, re.DOTALL)
    return found.group(1) if found else None
def audio_to_data_stream(audio_file_path, is_opus=True, callback: Callable[[Any], Any]=None) -> None:
    """Decode an audio file and stream it frame by frame to ``callback``.

    The file is loaded with pydub (format taken from the file extension),
    converted to mono / 16 kHz / 16-bit samples and handed to
    ``pcm_to_data_stream``.

    Args:
        audio_file_path: Path of the audio file to read.
        is_opus: Forwarded to ``pcm_to_data_stream``.
        callback: Callable invoked once per frame.
    """
    # File extension without the leading dot selects the decoder format.
    extension = os.path.splitext(audio_file_path)[1]
    if extension:
        extension = extension.lstrip(".")

    # "-nostdin" stops FFmpeg from reading standard input, which would block.
    segment = AudioSegment.from_file(
        audio_file_path, format=extension, parameters=["-nostdin"]
    )

    # Mono / 16 kHz / 16-bit so the samples match the encoder settings.
    segment = segment.set_channels(1).set_frame_rate(16000).set_sample_width(2)

    pcm_to_data_stream(segment.raw_data, is_opus, callback)
def audio_to_data(audio_file_path: str, is_opus: bool = True) -> list[bytes]:
    """Convert an audio file into a list of Opus- or PCM-encoded frames.

    The file is decoded with pydub, converted to mono / 16 kHz / 16-bit
    samples and framed by ``pcm_to_data_stream``, so the list-returning and
    streaming code paths share one framing/encoding implementation instead
    of two copies of the same loop.

    Args:
        audio_file_path: Path of the audio file.
        is_opus: Whether to Opus-encode each frame; when False the padded
            raw PCM frames are returned.

    Returns:
        list[bytes]: One entry per 60 ms frame, in order.
    """
    # File extension without the leading dot selects the decoder format.
    file_type = os.path.splitext(audio_file_path)[1]
    if file_type:
        file_type = file_type.lstrip(".")

    # "-nostdin" stops FFmpeg from reading standard input, which would block.
    audio = AudioSegment.from_file(
        audio_file_path, format=file_type, parameters=["-nostdin"]
    )

    # Mono / 16 kHz / 16-bit so the samples match the encoder settings.
    audio = audio.set_channels(1).set_frame_rate(16000).set_sample_width(2)

    # Collect the frames produced by the shared streaming helper.
    datas: list[bytes] = []
    pcm_to_data_stream(audio.raw_data, is_opus, datas.append)
    return datas
def audio_bytes_to_data_stream(audio_bytes, file_type, is_opus, callback: Callable[[Any], Any]) -> None:
    """Stream in-memory audio bytes as Opus/PCM frames to ``callback``.

    "p3" input is passed straight to ``p3.decode_opus_from_bytes_stream``
    (whose result is returned). Every other format is decoded with pydub,
    converted to mono / 16 kHz / 16-bit and framed by
    ``pcm_to_data_stream``.
    """
    if file_type == "p3":
        # p3 is decoded directly by the p3 helper.
        return p3.decode_opus_from_bytes_stream(audio_bytes, callback)

    # Every other format goes through pydub.
    segment = AudioSegment.from_file(
        BytesIO(audio_bytes), format=file_type, parameters=["-nostdin"]
    )
    segment = segment.set_channels(1).set_frame_rate(16000).set_sample_width(2)
    pcm_to_data_stream(segment.raw_data, is_opus, callback)
def pcm_to_data_stream(raw_data, is_opus=True, callback: Callable[[Any], Any] = None):
    """Split raw PCM into 60 ms frames and pass each one to ``callback``.

    The frame size is computed for 16 kHz audio with 2 bytes per sample
    (960 samples / 1920 bytes per frame). The last frame is zero-padded to
    full length.

    The Opus encoder is only created when ``is_opus`` is true, so PCM
    passthrough no longer pays for (or depends on) an encoder it never uses.

    Args:
        raw_data: Raw PCM bytes.
        is_opus: Opus-encode each frame when True; emit padded PCM otherwise.
        callback: Callable invoked once per frame, in order. The ``None``
            default is kept for signature compatibility, but a callable is
            needed as soon as there is at least one frame.
    """
    # Frame parameters.
    frame_duration = 60  # 60ms per frame
    frame_size = int(16000 * frame_duration / 1000)  # 960 samples/frame
    frame_bytes = frame_size * 2  # 16bit = 2 bytes/sample

    # Only build the encoder when it is actually needed.
    encoder = (
        opuslib_next.Encoder(16000, 1, opuslib_next.APPLICATION_AUDIO)
        if is_opus
        else None
    )

    for offset in range(0, len(raw_data), frame_bytes):
        chunk = raw_data[offset : offset + frame_bytes]

        # Zero-pad a short final frame.
        if len(chunk) < frame_bytes:
            chunk += b"\x00" * (frame_bytes - len(chunk))

        if encoder is not None:
            # Same bytes the previous numpy round-trip handed to the encoder.
            frame_data = encoder.encode(bytes(chunk), frame_size)
        else:
            frame_data = chunk if isinstance(chunk, bytes) else bytes(chunk)
        callback(frame_data)
def opus_datas_to_wav_bytes(opus_datas, sample_rate=16000, channels=1):
    """Decode a list of Opus frames into the bytes of a WAV file.

    Each frame is decoded with a 60 ms frame size; the concatenated PCM is
    written as 16-bit samples at ``sample_rate`` / ``channels``.
    """
    decoder = opuslib_next.Decoder(sample_rate, channels)
    frame_duration = 60  # ms
    samples_per_frame = int(sample_rate * frame_duration / 1000)  # 960

    # Decode every packet to PCM and concatenate in order.
    pcm_bytes = b"".join(
        decoder.decode(packet, samples_per_frame) for packet in opus_datas
    )

    # Wrap the PCM in a WAV container held in memory.
    buffer = BytesIO()
    with wave.open(buffer, "wb") as wav_file:
        wav_file.setnchannels(channels)
        wav_file.setsampwidth(2)  # 16bit
        wav_file.setframerate(sample_rate)
        wav_file.writeframes(pcm_bytes)
    return buffer.getvalue()
def check_vad_update(before_config, new_config):
    """Report whether the selected VAD provider type differs between configs.

    Returns False when ``new_config`` selects no VAD module. A module's type
    is its ``"type"`` entry when present, otherwise the module name itself.
    """
    selected = new_config.get("selected_module")
    if selected is None or selected.get("VAD") is None:
        return False

    def _vad_type(config):
        module_name = config["selected_module"]["VAD"]
        settings = config["VAD"][module_name]
        return settings["type"] if "type" in settings else module_name

    return _vad_type(before_config) != _vad_type(new_config)
def check_asr_update(before_config, new_config):
    """Report whether the selected ASR provider type differs between configs.

    Returns False when ``new_config`` selects no ASR module. A module's type
    is its ``"type"`` entry when present, otherwise the module name itself.
    """
    selected = new_config.get("selected_module")
    if selected is None or selected.get("ASR") is None:
        return False

    def _asr_type(config):
        module_name = config["selected_module"]["ASR"]
        settings = config["ASR"][module_name]
        return settings["type"] if "type" in settings else module_name

    return _asr_type(before_config) != _asr_type(new_config)
def filter_sensitive_info(config: dict) -> dict:
    """Return a deep copy of ``config`` with sensitive values masked.

    A value is replaced by "***" when its key, lower-cased, contains any of
    the sensitive markers. Nested dicts, and dicts that are direct elements
    of lists, are processed recursively. The input is not modified.

    Args:
        config: The original configuration dictionary.

    Returns:
        The filtered configuration dictionary.
    """
    sensitive_keys = [
        "api_key",
        "personal_access_token",
        "access_token",
        "token",
        "secret",
        "access_key_secret",
        "secret_key",
    ]

    def _is_sensitive(key) -> bool:
        lowered = key.lower()
        return any(marker in lowered for marker in sensitive_keys)

    def _filter_dict(d: dict) -> dict:
        cleaned = {}
        for key, value in d.items():
            if _is_sensitive(key):
                cleaned[key] = "***"
                continue
            if isinstance(value, dict):
                cleaned[key] = _filter_dict(value)
                continue
            if isinstance(value, list):
                cleaned[key] = [
                    _filter_dict(item) if isinstance(item, dict) else item
                    for item in value
                ]
                continue
            cleaned[key] = value
        return cleaned

    return _filter_dict(copy.deepcopy(config))
def get_vision_url(config: dict) -> str:
    """Get the vision URL.

    Reads ``server.vision_explain`` from the config. When the value still
    contains the placeholder text "你的", a URL is built from the local IP
    and ``server.http_port`` (8003 by default) instead.

    Args:
        config: Configuration dictionary.

    Returns:
        str: The vision URL.
    """
    server_config = config["server"]
    configured = server_config.get("vision_explain", "")
    if "你的" not in configured:
        return configured

    local_ip = get_local_ip()
    port = int(server_config.get("http_port", 8003))
    return f"http://{local_ip}:{port}/mcp/vision/explain"
def is_valid_image_file(file_data: bytes) -> bool:
    """Check whether the data starts with a known image file signature.

    Recognises JPEG, PNG, GIF (87a/89a), BMP, TIFF (both byte orders) and
    WebP. WebP is a RIFF container, so the "RIFF" tag alone is not enough:
    the form type at bytes 8-11 must be "WEBP". Previously any RIFF file
    (e.g. WAV or AVI) was accepted as an image.

    Args:
        file_data: Binary content of the file.

    Returns:
        bool: True if the header matches a supported image format.
    """
    if not file_data:
        return False

    # Magic numbers of formats identified by a plain prefix.
    prefix_signatures = (
        b"\xff\xd8\xff",  # JPEG
        b"\x89PNG\r\n\x1a\n",  # PNG
        b"GIF87a",  # GIF
        b"GIF89a",  # GIF
        b"BM",  # BMP
        b"II*\x00",  # TIFF (little-endian)
        b"MM\x00*",  # TIFF (big-endian)
    )
    if file_data.startswith(prefix_signatures):
        return True

    # WebP: "RIFF" <4-byte size> "WEBP".
    return file_data[:4] == b"RIFF" and file_data[8:12] == b"WEBP"
def sanitize_tool_name(name: str) -> str:
    """Sanitize tool names for OpenAI compatibility.

    Every character that is not an ASCII letter, digit, underscore, hyphen
    or a character in U+4E00-U+9FFF is replaced by an underscore.
    """
    disallowed = re.compile(r"[^a-zA-Z0-9_\-\u4e00-\u9fff]")
    return disallowed.sub("_", name)
def validate_mcp_endpoint(mcp_endpoint: str) -> bool:
    """Validate the format of an MCP endpoint string.

    The endpoint is valid when it starts with "ws", contains neither "key"
    nor "call" (case-insensitive), and contains "/mcp/".

    Args:
        mcp_endpoint: The MCP endpoint string.

    Returns:
        bool: Whether the endpoint is valid.
    """
    # 1. Must start with "ws".
    if not mcp_endpoint.startswith("ws"):
        return False

    # 2. Must not contain "key" or "call" in any letter case.
    lowered = mcp_endpoint.lower()
    if any(word in lowered for word in ("key", "call")):
        return False

    # 3. Must contain "/mcp/".
    return "/mcp/" in mcp_endpoint
+19
View File
@@ -0,0 +1,19 @@
import importlib
import os
import sys
from core.providers.vad.base import VADProviderBase
from config.logger import setup_logging
TAG = __name__
logger = setup_logging()
def create_instance(class_name: str, *args, **kwargs) -> VADProviderBase:
    """Factory method that creates a VAD provider instance by module name.

    Looks for ``core/providers/vad/<class_name>.py`` relative to the current
    working directory, imports it once (cached in ``sys.modules``) and
    instantiates its ``VADProvider`` with the given arguments.

    Raises:
        ValueError: If no such provider module file exists.
    """
    provider_file = os.path.join("core", "providers", "vad", f"{class_name}.py")
    if not os.path.exists(provider_file):
        raise ValueError(f"不支持的VAD类型: {class_name},请检查该配置的type是否设置正确")

    module_name = f"core.providers.vad.{class_name}"
    if module_name not in sys.modules:
        sys.modules[module_name] = importlib.import_module(module_name)
    return sys.modules[module_name].VADProvider(*args, **kwargs)
+23
View File
@@ -0,0 +1,23 @@
import os
import sys
# Add the project root (two directory levels above this file) to the front
# of the Python path so the imports below resolve.
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.abspath(os.path.join(current_dir, "..", ".."))
sys.path.insert(0, project_root)
from config.logger import setup_logging
import importlib
logger = setup_logging()
def create_instance(class_name, *args, **kwargs):
    """Create a VLLM provider instance by module name.

    Looks for ``core/providers/vllm/<class_name>.py`` relative to the
    current working directory, imports it once (cached in ``sys.modules``)
    and instantiates its ``VLLMProvider`` with the given arguments.

    Raises:
        ValueError: If no such provider module file exists.
    """
    provider_file = os.path.join("core", "providers", "vllm", f"{class_name}.py")
    if not os.path.exists(provider_file):
        raise ValueError(f"不支持的VLLM类型: {class_name},请检查该配置的type是否设置正确")

    module_name = f"core.providers.vllm.{class_name}"
    if module_name not in sys.modules:
        sys.modules[module_name] = importlib.import_module(module_name)
    return sys.modules[module_name].VLLMProvider(*args, **kwargs)
@@ -0,0 +1,198 @@
import asyncio
import time
import aiohttp
import requests
from urllib.parse import urlparse, parse_qs
from typing import Optional, Dict
from config.logger import setup_logging
from core.utils.cache.manager import cache_manager
from core.utils.cache.config import CacheType
TAG = __name__
logger = setup_logging()
class VoiceprintProvider:
    """Voiceprint (speaker) recognition service provider.

    Configuration keys read in ``__init__``: ``url`` (service URL carrying a
    ``key`` query parameter), ``speakers`` (list of
    "speaker_id,name,description" strings) and ``similarity_threshold``.
    ``self.enabled`` is set to True only when the URL, key and at least one
    speaker id are present and the health check succeeds.
    """

    def __init__(self, config: dict):
        self.original_url = config.get("url", "")
        self.speakers = config.get("speakers", [])
        self.speaker_map = self._parse_speakers()
        # Similarity threshold for voiceprint recognition; defaults to 0.4.
        self.similarity_threshold = float(config.get("similarity_threshold", 0.4))

        # API address and key, filled in below when the URL can be parsed.
        self.api_url = None
        self.api_key = None
        self.speaker_ids = []

        if not self.original_url:
            logger.bind(tag=TAG).warning("声纹识别URL未配置,声纹识别将被禁用")
            self.enabled = False
        else:
            # Parse the URL and the key.
            parsed_url = urlparse(self.original_url)
            base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"

            # Extract the key from the query parameters.
            query_params = parse_qs(parsed_url.query)
            self.api_key = query_params.get('key', [''])[0]

            if not self.api_key:
                logger.bind(tag=TAG).error("URL中未找到key参数,声纹识别将被禁用")
                self.enabled = False
            else:
                # Build the address of the identify endpoint.
                self.api_url = f"{base_url}/voiceprint/identify"

                # Extract the speaker ids (first comma-separated field).
                for speaker_str in self.speakers:
                    try:
                        parts = speaker_str.split(",", 2)
                        if len(parts) >= 1:
                            speaker_id = parts[0].strip()
                            self.speaker_ids.append(speaker_id)
                    except Exception:
                        continue

                # Check whether any valid speaker is configured.
                if not self.speaker_ids:
                    logger.bind(tag=TAG).warning("未配置有效的说话人,声纹识别将被禁用")
                    self.enabled = False
                else:
                    # Run a health check to verify the server is reachable.
                    if self._check_server_health():
                        self.enabled = True
                        logger.bind(tag=TAG).info(f"声纹识别已启用: API={self.api_url}, 说话人={len(self.speaker_ids)}个, 相似度阈值={self.similarity_threshold}")
                    else:
                        self.enabled = False
                        logger.bind(tag=TAG).warning(f"声纹识别服务器不可用,声纹识别已禁用: {self.api_url}")

    def _parse_speakers(self) -> Dict[str, Dict[str, str]]:
        """Parse the speaker configuration.

        Each entry of ``self.speakers`` is split into at most three
        comma-separated fields; entries with fewer than three fields are
        ignored. Returns a map of speaker_id -> {"name", "description"}.
        """
        speaker_map = {}
        for speaker_str in self.speakers:
            try:
                parts = speaker_str.split(",", 2)
                if len(parts) >= 3:
                    speaker_id, name, description = parts[0].strip(), parts[1].strip(), parts[2].strip()
                    speaker_map[speaker_id] = {
                        "name": name,
                        "description": description
                    }
            except Exception as e:
                logger.bind(tag=TAG).warning(f"解析说话人配置失败: {speaker_str}, 错误: {e}")
        return speaker_map

    def _check_server_health(self) -> bool:
        """Check the health of the voiceprint server.

        The result is read from and written to the global cache under
        ``CacheType.VOICEPRINT_HEALTH``, keyed by API URL and key. The check
        is a blocking HTTP GET with a 3-second timeout; any failure counts
        as unhealthy.
        """
        if not self.api_url or not self.api_key:
            return False

        cache_key = f"{self.api_url}:{self.api_key}"

        # Check the cache first.
        cached_result = cache_manager.get(CacheType.VOICEPRINT_HEALTH, cache_key)
        if cached_result is not None:
            logger.bind(tag=TAG).debug(f"使用缓存的健康状态: {cached_result}")
            return cached_result

        # The cache entry is missing or expired.
        logger.bind(tag=TAG).info("执行声纹服务器健康检查")

        try:
            # Health-check URL.
            parsed_url = urlparse(self.api_url)
            health_url = f"{parsed_url.scheme}://{parsed_url.netloc}/voiceprint/health?key={self.api_key}"

            # Send the health-check request.
            response = requests.get(health_url, timeout=3)

            if response.status_code == 200:
                result = response.json()
                if result.get("status") == "healthy":
                    logger.bind(tag=TAG).info("声纹识别服务器健康检查通过")
                    is_healthy = True
                else:
                    logger.bind(tag=TAG).warning(f"声纹识别服务器状态异常: {result}")
                    is_healthy = False
            else:
                logger.bind(tag=TAG).warning(f"声纹识别服务器健康检查失败: HTTP {response.status_code}")
                is_healthy = False

        except requests.exceptions.ConnectTimeout:
            logger.bind(tag=TAG).warning("声纹识别服务器连接超时")
            is_healthy = False
        except requests.exceptions.ConnectionError:
            logger.bind(tag=TAG).warning("声纹识别服务器连接被拒绝")
            is_healthy = False
        except Exception as e:
            logger.bind(tag=TAG).warning(f"声纹识别服务器健康检查异常: {e}")
            is_healthy = False

        # Cache the result with the global cache manager.
        cache_manager.set(CacheType.VOICEPRINT_HEALTH, cache_key, is_healthy)
        logger.bind(tag=TAG).info(f"健康检查结果已缓存: {is_healthy}")

        return is_healthy

    async def identify_speaker(self, audio_data: bytes, session_id: str) -> Optional[str]:
        """Identify the speaker of an audio clip.

        Posts the audio as "audio.wav" together with the configured speaker
        ids to the identify endpoint (10-second total timeout).

        Returns:
            The configured speaker name on a match; "未知说话人" when the
            score is below the threshold or the returned id is not in the
            speaker map; None when the provider is disabled, the HTTP status
            is not 200, the request times out, or any other error occurs.
        """
        if not self.enabled or not self.api_url or not self.api_key:
            logger.bind(tag=TAG).debug("声纹识别功能已禁用或未配置,跳过识别")
            return None

        try:
            api_start_time = time.monotonic()

            # Request headers.
            headers = {
                'Authorization': f'Bearer {self.api_key}',
                'Accept': 'application/json'
            }

            # multipart/form-data payload.
            data = aiohttp.FormData()
            data.add_field('speaker_ids', ','.join(self.speaker_ids))
            data.add_field('file', audio_data, filename='audio.wav', content_type='audio/wav')

            timeout = aiohttp.ClientTimeout(total=10)

            # Network request.
            async with aiohttp.ClientSession(timeout=timeout) as session:
                async with session.post(self.api_url, headers=headers, data=data) as response:

                    if response.status == 200:
                        result = await response.json()
                        speaker_id = result.get("speaker_id")
                        score = result.get("score", 0)
                        total_elapsed_time = time.monotonic() - api_start_time

                        logger.bind(tag=TAG).info(f"声纹识别耗时: {total_elapsed_time:.3f}s")

                        # Similarity threshold check.
                        if score < self.similarity_threshold:
                            logger.bind(tag=TAG).warning(f"声纹识别相似度{score:.3f}低于阈值{self.similarity_threshold}")
                            return "未知说话人"

                        if speaker_id and speaker_id in self.speaker_map:
                            result_name = self.speaker_map[speaker_id]["name"]
                            logger.bind(tag=TAG).info(f"声纹识别成功: {result_name} (相似度: {score:.3f})")
                            return result_name
                        else:
                            logger.bind(tag=TAG).warning(f"未识别的说话人ID: {speaker_id}")
                            return "未知说话人"
                    else:
                        logger.bind(tag=TAG).error(f"声纹识别API错误: HTTP {response.status}")
                        return None

        except asyncio.TimeoutError:
            elapsed = time.monotonic() - api_start_time
            logger.bind(tag=TAG).error(f"声纹识别超时: {elapsed:.3f}s")
            return None
        except Exception as e:
            # NOTE(review): ``elapsed`` is computed here but never used.
            elapsed = time.monotonic() - api_start_time
            logger.bind(tag=TAG).error(f"声纹识别失败: {e}")
            return None
+140
View File
@@ -0,0 +1,140 @@
import os
import re
import yaml
import time
import hashlib
import portalocker
from typing import Dict
class FileLock:
    """Context manager that holds an exclusive portalocker lock on an open file.

    The lock is requested in non-blocking mode and retried every 0.1 s until
    it is acquired or ``timeout`` seconds have passed.
    """

    def __init__(self, file, timeout=5):
        self.file = file
        self.timeout = timeout
        self.start_time = None

    def __enter__(self):
        """Acquire the lock and return the file; raise TimeoutError on timeout."""
        self.start_time = time.time()
        lock_flags = portalocker.LOCK_EX | portalocker.LOCK_NB
        while True:
            try:
                portalocker.lock(self.file, lock_flags)
            except portalocker.LockException:
                waited = time.time() - self.start_time
                if waited > self.timeout:
                    raise TimeoutError("获取文件锁超时")
                time.sleep(0.1)
            else:
                return self.file

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Release the lock; exceptions from the ``with`` body propagate."""
        portalocker.unlock(self.file)
class WakeupWordsConfig:
    """Stores wake-word reply metadata in a YAML file, keyed by the MD5 of the voice name.

    Reads are cached in memory for a short TTL; file access is guarded by
    ``FileLock``.
    """

    def __init__(self):
        self.config_file = "data/.wakeup_words.yaml"
        self.assets_dir = "config/assets/wakeup_words"
        self._ensure_directories()
        self._config_cache = None
        self._last_load_time = 0
        self._cache_ttl = 1  # cache lifetime in seconds
        self._lock_timeout = 5  # file-lock timeout in seconds

    def _ensure_directories(self):
        """Make sure the config file's directory and the assets directory exist."""
        os.makedirs(os.path.dirname(self.config_file), exist_ok=True)
        os.makedirs(self.assets_dir, exist_ok=True)

    def _load_config(self) -> Dict:
        """Load the config file, served from the in-memory cache while it is fresh.

        Returns:
            The parsed mapping, or ``{}`` when the file is empty, does not
            contain a mapping, or cannot be read.
        """
        current_time = time.time()

        # Return the cached copy while it is still valid.
        if (
            self._config_cache is not None
            and current_time - self._last_load_time < self._cache_ttl
        ):
            return self._config_cache

        try:
            # "a+" creates the file if needed without truncating it.
            with open(self.config_file, "a+", encoding="utf-8") as f:
                with FileLock(f, timeout=self._lock_timeout):
                    f.seek(0)
                    content = f.read()
                    config = yaml.safe_load(content) if content else {}
                    # A comment-only or scalar document parses to a non-dict
                    # (e.g. None); callers index the result, so normalise it.
                    if not isinstance(config, dict):
                        config = {}
                    self._config_cache = config
                    self._last_load_time = current_time
                    return config
        except (TimeoutError, IOError) as e:
            print(f"加载配置文件失败: {e}")
            return {}
        except Exception as e:
            print(f"加载配置文件时发生未知错误: {e}")
            return {}

    def _save_config(self, config: Dict):
        """Write the config to disk under the file lock and refresh the cache.

        The file is opened without truncation and only emptied after the
        lock is held; opening with "w" would wipe the file before the lock
        was acquired, exposing an empty file to a concurrent reader.

        Raises:
            TimeoutError, IOError or any other exception from locking or
            writing; it is printed and re-raised.
        """
        try:
            with open(self.config_file, "a+", encoding="utf-8") as f:
                with FileLock(f, timeout=self._lock_timeout):
                    f.seek(0)
                    f.truncate()
                    yaml.dump(config, f, allow_unicode=True)
                    # Push the data out before the lock is released.
                    f.flush()
                    self._config_cache = config
                    self._last_load_time = time.time()
        except (TimeoutError, IOError) as e:
            print(f"保存配置文件失败: {e}")
            raise
        except Exception as e:
            print(f"保存配置文件时发生未知错误: {e}")
            raise

    def get_wakeup_response(self, voice: str) -> Dict:
        """Get the wake-word reply entry for a voice.

        Returns:
            The stored entry, or None when there is no entry for the voice,
            or its audio file is missing or smaller than 15 KiB.
        """
        voice = hashlib.md5(voice.encode()).hexdigest()
        config = self._load_config()

        if not config or voice not in config:
            return None

        # Reject missing or suspiciously small audio files.
        file_path = config[voice]["file_path"]
        if not os.path.exists(file_path) or os.stat(file_path).st_size < (15 * 1024):
            return None

        return config[voice]

    def update_wakeup_response(self, voice: str, file_path: str, text: str):
        """Store or replace the wake-word reply entry for a voice.

        Emoji in U+1F600-U+1F64F and U+1F900-U+1F9FF are removed from
        ``text`` before it is stored. Failures are printed and re-raised.
        """
        try:
            # Filter out emoji.
            filtered_text = re.sub(r'[\U0001F600-\U0001F64F\U0001F900-\U0001F9FF]', '', text)

            config = self._load_config()
            voice_hash = hashlib.md5(voice.encode()).hexdigest()
            config[voice_hash] = {
                "voice": voice,
                "file_path": file_path,
                "time": time.time(),
                "text": filtered_text,
            }
            self._save_config(config)
        except Exception as e:
            print(f"更新唤醒词回复配置失败: {e}")
            raise

    def generate_file_path(self, voice: str) -> str:
        """Build the audio file path for a voice, named after the MD5 of the voice.

        An existing file at that path is deleted first. Failures are printed
        and re-raised.
        """
        try:
            voice_hash = hashlib.md5(voice.encode()).hexdigest()
            file_path = os.path.join(self.assets_dir, f"{voice_hash}.wav")

            # Remove a stale file so the caller starts from a clean slate.
            if os.path.exists(file_path):
                try:
                    os.remove(file_path)
                except Exception as e:
                    print(f"删除已存在的音频文件失败: {e}")
                    raise

            return file_path
        except Exception as e:
            print(f"生成音频文件路径失败: {e}")
            raise