仿生人AI服务端
This commit is contained in:
@@ -0,0 +1,24 @@
|
||||
import importlib
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import wave
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional, Tuple, List
|
||||
from core.providers.asr.base import ASRProviderBase
|
||||
from config.logger import setup_logging
|
||||
|
||||
# Module name, kept as a tag value for log records.
TAG = __name__

# Logger returned by the project's logging setup.
logger = setup_logging()
|
||||
|
||||
def create_instance(class_name: str, *args, **kwargs) -> ASRProviderBase:
    """Factory: build the ASR provider implemented in ``core/providers/asr/<class_name>.py``.

    The provider file is looked up relative to the current working directory.
    Extra positional and keyword arguments are forwarded to the provider's
    ``ASRProvider`` constructor.

    Raises:
        ValueError: if no provider file exists for ``class_name``.
    """
    provider_file = os.path.join('core', 'providers', 'asr', f'{class_name}.py')
    if not os.path.exists(provider_file):
        raise ValueError(f"不支持的ASR类型: {class_name},请检查该配置的type是否设置正确")

    lib_name = f'core.providers.asr.{class_name}'
    # Import the module only once; later calls reuse the cached module object.
    if lib_name not in sys.modules:
        sys.modules[lib_name] = importlib.import_module(f'{lib_name}')
    provider_cls = sys.modules[lib_name].ASRProvider
    return provider_cls(*args, **kwargs)
|
||||
@@ -0,0 +1,126 @@
|
||||
import jwt
|
||||
import time
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Tuple, Optional
|
||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||
from cryptography.hazmat.primitives import padding
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
import base64
|
||||
|
||||
|
||||
class AuthToken:
    """Issue and verify device tokens.

    A token is an HS256-signed JWT whose single ``data`` claim holds an
    AES-256-GCM encrypted JSON payload (``device_id`` and ``exp``). The
    encrypted blob layout is ``IV (12 bytes) + ciphertext + tag (16 bytes)``,
    URL-safe base64 encoded.
    """

    _IV_SIZE = 12  # bytes of random IV prepended to the ciphertext
    _TAG_SIZE = 16  # bytes of GCM authentication tag appended to the ciphertext

    def __init__(self, secret_key: str):
        """Store the signing secret and derive the AES key from it."""
        self.secret_key = secret_key.encode()  # bytes form, used for JWT signing
        # Derive a fixed-length encryption key (32 bytes for AES-256).
        self.encryption_key = self._derive_key(32)

    def _derive_key(self, length: int) -> bytes:
        """Derive a ``length``-byte key from the secret with PBKDF2-HMAC-SHA256."""
        from cryptography.hazmat.primitives import hashes
        from cryptography.hazmat.primitives.kdf.pbkdf2 import PBKDF2HMAC

        # NOTE: the salt is a fixed placeholder; a production deployment
        # should use a randomly generated salt instead.
        salt = b"fixed_salt_placeholder"
        kdf = PBKDF2HMAC(
            algorithm=hashes.SHA256(),
            length=length,
            salt=salt,
            iterations=100000,
            backend=default_backend(),
        )
        return kdf.derive(self.secret_key)

    def _encrypt_payload(self, payload: dict) -> str:
        """Encrypt ``payload`` (as JSON) with AES-GCM and return it base64-encoded."""
        payload_json = json.dumps(payload)

        # A fresh random IV per token.
        iv = os.urandom(self._IV_SIZE)
        cipher = Cipher(
            algorithms.AES(self.encryption_key),
            modes.GCM(iv),
            backend=default_backend(),
        )
        encryptor = cipher.encryptor()

        ciphertext = encryptor.update(payload_json.encode()) + encryptor.finalize()
        tag = encryptor.tag

        # Layout: IV + ciphertext + tag.
        encrypted_data = iv + ciphertext + tag
        return base64.urlsafe_b64encode(encrypted_data).decode()

    def _decrypt_payload(self, encrypted_data: str) -> dict:
        """Decrypt a blob produced by ``_encrypt_payload`` and return the payload dict.

        Raises:
            ValueError: if the decoded blob is too short to hold an IV and a tag.
        """
        data = base64.urlsafe_b64decode(encrypted_data.encode())
        # Reject truncated input explicitly instead of slicing out a
        # malformed IV/tag and handing it to the cipher.
        if len(data) < self._IV_SIZE + self._TAG_SIZE:
            raise ValueError("encrypted payload is too short")

        iv = data[: self._IV_SIZE]
        tag = data[-self._TAG_SIZE :]
        ciphertext = data[self._IV_SIZE : -self._TAG_SIZE]

        cipher = Cipher(
            algorithms.AES(self.encryption_key),
            modes.GCM(iv, tag),
            backend=default_backend(),
        )
        decryptor = cipher.decryptor()

        plaintext = decryptor.update(ciphertext) + decryptor.finalize()
        return json.loads(plaintext.decode())

    def generate_token(self, device_id: str) -> str:
        """Generate a token for ``device_id`` that expires one hour from now.

        :param device_id: device identifier stored in the encrypted payload
        :return: the encoded JWT string
        """
        expire_time = datetime.now(timezone.utc) + timedelta(hours=1)

        # Inner payload: encrypted as a whole so it is opaque to the holder.
        payload = {"device_id": device_id, "exp": expire_time.timestamp()}
        encrypted_payload = self._encrypt_payload(payload)

        # Outer payload only carries the encrypted blob.
        outer_payload = {"data": encrypted_payload}
        token = jwt.encode(outer_payload, self.secret_key, algorithm="HS256")
        return token

    def verify_token(self, token: str) -> Tuple[bool, Optional[str]]:
        """Verify ``token``.

        :param token: JWT string produced by ``generate_token``
        :return: ``(True, device_id)`` when valid, otherwise ``(False, None)``
        """
        try:
            # The outer JWT check covers the signature. The outer payload has
            # no ``exp`` claim, so expiry is enforced on the inner payload below.
            outer_payload = jwt.decode(token, self.secret_key, algorithms=["HS256"])

            inner_payload = self._decrypt_payload(outer_payload["data"])

            if inner_payload["exp"] < time.time():
                return False, None

            return True, inner_payload["device_id"]

        except jwt.InvalidTokenError:
            return False, None
        except json.JSONDecodeError:
            return False, None
        except Exception as e:  # any other failure (bad base64, bad tag, missing keys, ...)
            import logging

            # Report through logging rather than print so the failure reaches
            # the configured log handlers.
            logging.getLogger(__name__).warning("Token verification failed: %s", e)
            return False, None
|
||||
+62
@@ -0,0 +1,62 @@
|
||||
"""
|
||||
缓存配置管理
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from typing import Dict, Any, Optional
|
||||
from dataclasses import dataclass
|
||||
from .strategies import CacheStrategy
|
||||
|
||||
|
||||
class CacheType(Enum):
    """Enumeration of the named cache categories."""

    LOCATION = "location"
    WEATHER = "weather"
    LUNAR = "lunar"
    INTENT = "intent"
    IP_INFO = "ip_info"
    CONFIG = "config"
    DEVICE_PROMPT = "device_prompt"
    # Voiceprint-recognition health check results.
    VOICEPRINT_HEALTH = "voiceprint_health"
|
||||
|
||||
|
||||
@dataclass
class CacheConfig:
    """Settings for one cache: strategy, TTL, size cap and cleanup interval."""

    strategy: CacheStrategy = CacheStrategy.TTL
    ttl: Optional[float] = 300  # seconds; default 5 minutes
    max_size: Optional[int] = 1000  # default cap of 1000 entries
    cleanup_interval: float = 60  # seconds between cleanups

    @classmethod
    def for_type(cls, cache_type: CacheType) -> "CacheConfig":
        """Return the preset configuration for ``cache_type``.

        Types without a preset get the default configuration.
        """
        # (strategy, ttl seconds or None for manual invalidation, max_size)
        presets = {
            CacheType.LOCATION: (CacheStrategy.TTL, None, 1000),  # manual invalidation
            CacheType.IP_INFO: (CacheStrategy.TTL, 86400, 1000),  # 24 hours
            CacheType.WEATHER: (CacheStrategy.TTL, 28800, 1000),  # 8 hours
            CacheType.LUNAR: (CacheStrategy.TTL, 2592000, 365),  # 30 days
            CacheType.INTENT: (CacheStrategy.TTL_LRU, 600, 1000),  # 10 minutes
            CacheType.CONFIG: (CacheStrategy.FIXED_SIZE, None, 20),  # manual invalidation
            CacheType.DEVICE_PROMPT: (CacheStrategy.TTL, None, 1000),  # manual invalidation
            CacheType.VOICEPRINT_HEALTH: (CacheStrategy.TTL, 600, 100),  # 10 minutes
        }
        preset = presets.get(cache_type)
        if preset is None:
            return cls()
        strategy, ttl, max_size = preset
        return cls(strategy=strategy, ttl=ttl, max_size=max_size)
|
||||
+216
@@ -0,0 +1,216 @@
|
||||
"""
|
||||
全局缓存管理器
|
||||
"""
|
||||
|
||||
import time
|
||||
import threading
|
||||
from typing import Any, Optional, Dict
|
||||
from collections import OrderedDict
|
||||
from .strategies import CacheStrategy, CacheEntry
|
||||
from .config import CacheConfig, CacheType
|
||||
|
||||
|
||||
class GlobalCacheManager:
    """In-memory cache manager holding one named cache per (type, namespace).

    Each cache has its own configuration and its own re-entrant lock.
    Entries expire by TTL; caches with an LRU strategy keep entries ordered
    by most recent use and evict the least recently used one when full.
    """

    def __init__(self):
        self._logger = None
        self._caches: Dict[str, Dict[str, CacheEntry]] = {}
        self._configs: Dict[str, CacheConfig] = {}
        self._locks: Dict[str, threading.RLock] = {}
        self._global_lock = threading.RLock()
        # Time of the most recent cleanup of any cache.
        self._last_cleanup = time.time()
        # Per-cache time of the last cleanup, so that a busy cache cannot
        # keep postponing the cleanup of the other caches.
        self._cleanup_times: Dict[str, float] = {}
        self._stats = {"hits": 0, "misses": 0, "evictions": 0, "cleanups": 0}

    @property
    def logger(self):
        """Lazily created logger (imported on first use to avoid a circular import)."""
        if self._logger is None:
            from config.logger import setup_logging

            self._logger = setup_logging()
        return self._logger

    def _get_cache_name(self, cache_type: CacheType, namespace: str = "") -> str:
        """Build the cache name: ``<type>`` or ``<type>:<namespace>``."""
        if namespace:
            return f"{cache_type.value}:{namespace}"
        return cache_type.value

    def _get_or_create_cache(
        self, cache_name: str, config: CacheConfig
    ) -> Dict[str, CacheEntry]:
        """Return the cache named ``cache_name``, creating it with ``config`` if needed."""
        with self._global_lock:
            if cache_name not in self._caches:
                uses_lru = config.strategy in [CacheStrategy.LRU, CacheStrategy.TTL_LRU]
                self._caches[cache_name] = OrderedDict() if uses_lru else {}
                self._configs[cache_name] = config
                self._locks[cache_name] = threading.RLock()
                self._cleanup_times[cache_name] = time.time()
            return self._caches[cache_name]

    def set(
        self,
        cache_type: CacheType,
        key: str,
        value: Any,
        ttl: Optional[float] = None,
        namespace: str = "",
    ) -> None:
        """Store ``value`` under ``key``.

        Args:
            cache_type: cache category.
            key: entry key.
            value: value to store.
            ttl: per-entry TTL in seconds; when ``None`` the cache's
                configured TTL is used.
            namespace: optional sub-cache name within the category.
        """
        cache_name = self._get_cache_name(cache_type, namespace)
        config = self._configs.get(cache_name) or CacheConfig.for_type(cache_type)
        cache = self._get_or_create_cache(cache_name, config)

        effective_ttl = ttl if ttl is not None else config.ttl

        with self._locks[cache_name]:
            entry = CacheEntry(value=value, timestamp=time.time(), ttl=effective_ttl)

            # LRU caches: drop an existing key first so the new entry lands
            # at the most-recently-used end.
            if config.strategy in [CacheStrategy.LRU, CacheStrategy.TTL_LRU]:
                if key in cache:
                    del cache[key]
            cache[key] = entry

            # Enforce the size cap by removing the first key in iteration
            # order: the least recently used entry for LRU caches, the
            # first-inserted entry otherwise.
            if config.max_size and len(cache) > config.max_size:
                victim_key = next(iter(cache))
                del cache[victim_key]
                self._stats["evictions"] += 1

        # Opportunistically purge expired entries.
        self._maybe_cleanup(cache_name)

    def get(
        self, cache_type: CacheType, key: str, namespace: str = ""
    ) -> Optional[Any]:
        """Return the cached value for ``key``, or ``None`` on a miss or expired entry."""
        cache_name = self._get_cache_name(cache_type, namespace)

        if cache_name not in self._caches:
            self._stats["misses"] += 1
            return None

        cache = self._caches[cache_name]
        config = self._configs[cache_name]

        with self._locks[cache_name]:
            if key not in cache:
                self._stats["misses"] += 1
                return None

            entry = cache[key]

            # Expired entries are removed on access and count as a miss.
            if entry.is_expired():
                del cache[key]
                self._stats["misses"] += 1
                return None

            entry.touch()

            # LRU caches: re-insert so the entry moves to the most-recent end.
            if config.strategy in [CacheStrategy.LRU, CacheStrategy.TTL_LRU]:
                del cache[key]
                cache[key] = entry

            self._stats["hits"] += 1
            return entry.value

    def delete(self, cache_type: CacheType, key: str, namespace: str = "") -> bool:
        """Delete ``key``; return ``True`` if an entry was removed."""
        cache_name = self._get_cache_name(cache_type, namespace)

        if cache_name not in self._caches:
            return False

        cache = self._caches[cache_name]

        with self._locks[cache_name]:
            if key in cache:
                del cache[key]
                return True
            return False

    def clear(self, cache_type: CacheType, namespace: str = "") -> None:
        """Remove every entry of the given cache (no-op if it does not exist)."""
        cache_name = self._get_cache_name(cache_type, namespace)

        if cache_name not in self._caches:
            return

        with self._locks[cache_name]:
            self._caches[cache_name].clear()

    def invalidate_pattern(
        self, cache_type: CacheType, pattern: str, namespace: str = ""
    ) -> int:
        """Delete every entry whose key contains ``pattern``; return the number deleted."""
        cache_name = self._get_cache_name(cache_type, namespace)

        if cache_name not in self._caches:
            return 0

        cache = self._caches[cache_name]

        with self._locks[cache_name]:
            # Collect first: the dict must not change size while iterated.
            keys_to_delete = [key for key in cache if pattern in key]
            for key in keys_to_delete:
                del cache[key]

        return len(keys_to_delete)

    def _cleanup_expired(self, cache_name: str) -> int:
        """Delete all expired entries of ``cache_name``; return the number deleted."""
        if cache_name not in self._caches:
            return 0

        cache = self._caches[cache_name]

        with self._locks[cache_name]:
            expired_keys = [key for key, entry in cache.items() if entry.is_expired()]
            for key in expired_keys:
                del cache[key]

        return len(expired_keys)

    def _maybe_cleanup(self, cache_name: str):
        """Purge expired entries of ``cache_name`` if its cleanup interval has elapsed."""
        config = self._configs.get(cache_name)
        if not config:
            return

        now = time.time()
        last_cleanup = self._cleanup_times.get(cache_name, self._last_cleanup)
        if now - last_cleanup > config.cleanup_interval:
            self._cleanup_times[cache_name] = now
            self._last_cleanup = now
            deleted = self._cleanup_expired(cache_name)
            if deleted > 0:
                self._stats["cleanups"] += 1
                self.logger.debug(f"清理缓存 {cache_name}: 删除 {deleted} 个过期条目")
|
||||
|
||||
|
||||
# 创建全局缓存管理器实例
|
||||
# Module-level GlobalCacheManager instance created at import time.
cache_manager = GlobalCacheManager()
|
||||
+43
@@ -0,0 +1,43 @@
|
||||
"""
|
||||
缓存策略和数据结构定义
|
||||
"""
|
||||
|
||||
import time
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
class CacheStrategy(Enum):
    """Enumeration of cache eviction / expiry strategies."""

    # Time-based expiry.
    TTL = "ttl"
    # Least recently used.
    LRU = "lru"
    # Fixed size.
    FIXED_SIZE = "fixed_size"
    # Combined TTL + LRU.
    TTL_LRU = "ttl_lru"
|
||||
|
||||
|
||||
@dataclass
class CacheEntry:
    """One cached value together with its expiry and access bookkeeping."""

    value: Any
    timestamp: float  # creation time, seconds since the epoch
    ttl: Optional[float] = None  # time to live in seconds; None = never expires
    access_count: int = 0
    # Time of the last access; defaults to ``timestamp`` (see __post_init__).
    last_access: Optional[float] = None

    def __post_init__(self):
        if self.last_access is None:
            self.last_access = self.timestamp

    def is_expired(self) -> bool:
        """Return True when a TTL is set and more than ``ttl`` seconds have passed."""
        if self.ttl is None:
            return False
        return time.time() - self.timestamp > self.ttl

    def touch(self):
        """Record an access: refresh ``last_access`` and bump ``access_count``."""
        self.last_access = time.time()
        self.access_count += 1
|
||||
@@ -0,0 +1,68 @@
|
||||
"""
|
||||
时间工具模块
|
||||
提供统一的时间获取功能
|
||||
"""
|
||||
|
||||
import cnlunar
|
||||
from datetime import datetime
|
||||
|
||||
# English weekday name -> Chinese weekday label.
WEEKDAY_MAP = dict(
    Monday="星期一",
    Tuesday="星期二",
    Wednesday="星期三",
    Thursday="星期四",
    Friday="星期五",
    Saturday="星期六",
    Sunday="星期日",
)
|
||||
|
||||
|
||||
def get_current_time() -> str:
    """Return the current local time formatted as ``HH:MM``."""
    now = datetime.now()
    return f"{now:%H:%M}"
|
||||
|
||||
|
||||
def get_current_date() -> str:
    """Return today's local date formatted as ``YYYY-MM-DD``."""
    today = datetime.now()
    return f"{today:%Y-%m-%d}"
|
||||
|
||||
|
||||
def get_current_weekday() -> str:
    """Return today's weekday as its Chinese label from ``WEEKDAY_MAP``.

    The English key is taken from ``datetime.weekday()`` (Monday == 0)
    instead of ``strftime("%A")``: ``%A`` produces the locale's day name, so
    under a non-English locale it would not match the English keys of
    ``WEEKDAY_MAP`` and the lookup would raise ``KeyError``.
    """
    english_names = (
        "Monday",
        "Tuesday",
        "Wednesday",
        "Thursday",
        "Friday",
        "Saturday",
        "Sunday",
    )
    now = datetime.now()
    return WEEKDAY_MAP[english_names[now.weekday()]]
|
||||
|
||||
|
||||
def get_current_lunar_date() -> str:
    """Return today's lunar-calendar date as a string.

    Any failure while computing the lunar date yields the fixed fallback
    string instead of an exception.
    """
    try:
        lunar = cnlunar.Lunar(datetime.now(), godType="8char")
        year_cn = lunar.lunarYearCn
        # The last character of the month name is dropped before joining.
        month_cn = lunar.lunarMonthCn[:-1]
        day_cn = lunar.lunarDayCn
        return "%s年%s%s" % (year_cn, month_cn, day_cn)
    except Exception:
        return "农历获取失败"
|
||||
|
||||
|
||||
def get_current_time_info() -> tuple:
    """Return ``(current time, today's date, today's weekday, lunar date)`` as strings."""
    # Tuple elements are evaluated left to right, matching the original call order.
    return (
        get_current_time(),
        get_current_date(),
        get_current_weekday(),
        get_current_lunar_date(),
    )
|
||||
@@ -0,0 +1,118 @@
|
||||
import uuid
|
||||
import re
|
||||
from typing import List, Dict
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class Message:
    """A single chat message (system / user / assistant / tool)."""

    def __init__(
        self,
        role: str,
        content: str = None,
        uniq_id: str = None,
        tool_calls=None,
        tool_call_id=None,
    ):
        # A random UUID is assigned when the caller supplies no id.
        self.uniq_id = uniq_id if uniq_id is not None else str(uuid.uuid4())
        self.role = role
        self.content = content
        self.tool_calls = tool_calls
        self.tool_call_id = tool_call_id


class Dialogue:
    """Ordered list of messages with helpers to render it for an LLM call."""

    def __init__(self):
        self.dialogue: List[Message] = []
        # Creation time of this dialogue.
        self.current_time = datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    def put(self, message: Message):
        """Append ``message`` to the dialogue."""
        self.dialogue.append(message)

    def getMessages(self, m, dialogue):
        """Append the dict form of message ``m`` to the list ``dialogue``."""
        if m.tool_calls is not None:
            dialogue.append({"role": m.role, "tool_calls": m.tool_calls})
        elif m.role == "tool":
            dialogue.append(
                {
                    "role": m.role,
                    # A tool message without an id gets a random one.
                    "tool_call_id": (
                        str(uuid.uuid4()) if m.tool_call_id is None else m.tool_call_id
                    ),
                    "content": m.content,
                }
            )
        else:
            dialogue.append({"role": m.role, "content": m.content})

    def get_llm_dialogue(self) -> List[Dict[str, str]]:
        """Render the dialogue without memory or speaker information."""
        # Delegates so that every call path shares one rendering routine.
        return self.get_llm_dialogue_with_memory(None, None)

    def update_system_message(self, new_content: str):
        """Replace the first system message's content, or add a system message."""
        system_msg = next((msg for msg in self.dialogue if msg.role == "system"), None)
        if system_msg:
            system_msg.content = new_content
        else:
            self.put(Message(role="system", content=new_content))

    @staticmethod
    def _format_speakers_info(voiceprint_config) -> str:
        """Build the ``<speakers_info>`` section, or "" when no speakers are configured.

        Each speaker entry is a string ``"<id>,<name>[,<description>]"``;
        entries that are not strings or have fewer than two fields are skipped.
        """
        if not voiceprint_config:
            return ""
        try:
            speakers = voiceprint_config.get("speakers") or []
        except AttributeError:
            # Not a mapping: treat as "no speaker configuration".
            return ""
        if not speakers or not isinstance(speakers, (list, tuple)):
            return ""

        lines = []
        for speaker_str in speakers:
            if not isinstance(speaker_str, str):
                continue
            parts = speaker_str.split(",", 2)
            if len(parts) < 2:
                continue
            name = parts[1].strip()
            # A missing description is rendered as an empty string.
            description = parts[2].strip() if len(parts) >= 3 else ""
            lines.append(f"\n- {name}:{description}")
        return "\n\n<speakers_info>" + "".join(lines) + "\n\n</speakers_info>"

    def get_llm_dialogue_with_memory(
        self, memory_str: str = None, voiceprint_config: dict = None
    ) -> List[Dict[str, str]]:
        """Render the dialogue as a list of message dicts for the LLM.

        The first system message (if any) is emitted first, with
        ``{{current_time}}`` replaced by the current ``HH:MM`` time, the
        speaker section appended, and, when ``memory_str`` is not ``None``,
        the ``<memory>...</memory>`` section replaced by ``memory_str``.
        All non-system messages follow in order.
        """
        dialogue = []

        system_message = next(
            (msg for msg in self.dialogue if msg.role == "system"), None
        )

        if system_message:
            # A system message without content is treated as an empty prompt.
            enhanced_system_prompt = system_message.content or ""
            enhanced_system_prompt = enhanced_system_prompt.replace(
                "{{current_time}}", datetime.now().strftime("%H:%M")
            )

            enhanced_system_prompt += self._format_speakers_info(voiceprint_config)

            if memory_str is not None:
                memory_block = f"<memory>\n{memory_str}\n</memory>"
                # A callable replacement inserts the text verbatim; a plain
                # replacement string would have its backslashes interpreted
                # as regex escapes (e.g. "\1", "\n").
                enhanced_system_prompt = re.sub(
                    r"<memory>.*?</memory>",
                    lambda _match: memory_block,
                    enhanced_system_prompt,
                    flags=re.DOTALL,
                )
            dialogue.append({"role": "system", "content": enhanced_system_prompt})

        for m in self.dialogue:
            if m.role != "system":  # the system message was handled above
                self.getMessages(m, dialogue)

        return dialogue
|
||||
@@ -0,0 +1,17 @@
|
||||
import os
|
||||
import sys
|
||||
from config.logger import setup_logging
|
||||
import importlib
|
||||
|
||||
# Logger returned by the project's logging setup.
logger = setup_logging()
|
||||
|
||||
|
||||
def create_instance(class_name, *args, **kwargs):
    """Build the intent provider in ``core/providers/intent/<class_name>/<class_name>.py``.

    The provider file is looked up relative to the current working directory.
    Extra arguments are forwarded to the provider's ``IntentProvider``
    constructor.

    Raises:
        ValueError: if no provider file exists for ``class_name``.
    """
    provider_file = os.path.join('core', 'providers', 'intent', class_name, f'{class_name}.py')
    if not os.path.exists(provider_file):
        raise ValueError(f"不支持的intent类型: {class_name},请检查该配置的type是否设置正确")

    lib_name = f'core.providers.intent.{class_name}.{class_name}'
    # Import the module only once; later calls reuse the cached module object.
    if lib_name not in sys.modules:
        sys.modules[lib_name] = importlib.import_module(f'{lib_name}')
    provider_cls = sys.modules[lib_name].IntentProvider
    return provider_cls(*args, **kwargs)
|
||||
@@ -0,0 +1,23 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
# Put the project root (two directories above this file) at the front of
# the Python import path.
current_dir = os.path.dirname(os.path.abspath(__file__))
project_root = os.path.abspath(os.path.join(current_dir, os.pardir, os.pardir))
sys.path.insert(0, project_root)
|
||||
|
||||
from config.logger import setup_logging
|
||||
import importlib
|
||||
|
||||
# Logger returned by the project's logging setup.
logger = setup_logging()
|
||||
|
||||
|
||||
def create_instance(class_name, *args, **kwargs):
    """Build the LLM provider in ``core/providers/llm/<class_name>/<class_name>.py``.

    The provider file is looked up relative to the current working directory.
    Extra arguments are forwarded to the provider's ``LLMProvider``
    constructor.

    Raises:
        ValueError: if no provider file exists for ``class_name``.
    """
    provider_file = os.path.join('core', 'providers', 'llm', class_name, f'{class_name}.py')
    if not os.path.exists(provider_file):
        raise ValueError(f"不支持的LLM类型: {class_name},请检查该配置的type是否设置正确")

    lib_name = f'core.providers.llm.{class_name}.{class_name}'
    # Import the module only once; later calls reuse the cached module object.
    if lib_name not in sys.modules:
        sys.modules[lib_name] = importlib.import_module(f'{lib_name}')
    provider_cls = sys.modules[lib_name].LLMProvider
    return provider_cls(*args, **kwargs)
|
||||
@@ -0,0 +1,18 @@
|
||||
import os
|
||||
import sys
|
||||
import importlib
|
||||
from config.logger import setup_logging
|
||||
|
||||
# Logger returned by the project's logging setup.
logger = setup_logging()
|
||||
|
||||
|
||||
def create_instance(class_name, *args, **kwargs):
    """Build the memory provider in ``core/providers/memory/<class_name>/<class_name>.py``.

    The provider file is looked up relative to the current working directory.
    Extra arguments are forwarded to the provider's ``MemoryProvider``
    constructor.

    Raises:
        ValueError: if no provider file exists for ``class_name``.
    """
    provider_file = os.path.join(
        "core", "providers", "memory", class_name, f"{class_name}.py"
    )
    if not os.path.exists(provider_file):
        raise ValueError(f"不支持的记忆服务类型: {class_name}")

    lib_name = f"core.providers.memory.{class_name}.{class_name}"
    # Import the module only once; later calls reuse the cached module object.
    if lib_name not in sys.modules:
        sys.modules[lib_name] = importlib.import_module(f"{lib_name}")
    provider_cls = sys.modules[lib_name].MemoryProvider
    return provider_cls(*args, **kwargs)
|
||||
@@ -0,0 +1,151 @@
|
||||
from typing import Dict, Any
|
||||
from config.logger import setup_logging
|
||||
from core.utils import tts, llm, intent, memory, vad, asr
|
||||
|
||||
# Module name, passed as the ``tag`` when binding the logger below.
TAG = __name__

# Logger returned by the project's logging setup.
logger = setup_logging()
|
||||
|
||||
|
||||
def initialize_modules(
    logger,
    config: Dict[str, Any],
    init_vad=False,
    init_asr=False,
    init_llm=False,
    init_tts=False,
    init_memory=False,
    init_intent=False,
) -> Dict[str, Any]:
    """Create the requested provider modules from ``config``.

    Args:
        logger: logger used to report each initialized component.
        config: configuration dict; ``config["selected_module"]`` names the
            chosen provider for each component.
        init_vad / init_asr / init_llm / init_tts / init_memory / init_intent:
            flags selecting which components to initialize.

    Returns:
        Dict[str, Any]: the initialized modules keyed by ``"tts"``, ``"llm"``,
        ``"intent"``, ``"memory"``, ``"vad"`` and ``"asr"`` (only those requested).
    """

    def _resolve_type(section: str, selected: str) -> str:
        # A provider entry may override its implementation via "type";
        # otherwise the selected name itself is the type.
        section_cfg = config[section][selected]
        return section_cfg["type"] if "type" in section_cfg else selected

    modules = {}

    # TTS
    if init_tts:
        select_tts_module = config["selected_module"]["TTS"]
        modules["tts"] = initialize_tts(config)
        logger.bind(tag=TAG).info(f"初始化组件: tts成功 {select_tts_module}")

    # LLM
    if init_llm:
        select_llm_module = config["selected_module"]["LLM"]
        llm_type = _resolve_type("LLM", select_llm_module)
        modules["llm"] = llm.create_instance(
            llm_type,
            config["LLM"][select_llm_module],
        )
        logger.bind(tag=TAG).info(f"初始化组件: llm成功 {select_llm_module}")

    # Intent
    if init_intent:
        select_intent_module = config["selected_module"]["Intent"]
        intent_type = _resolve_type("Intent", select_intent_module)
        modules["intent"] = intent.create_instance(
            intent_type,
            config["Intent"][select_intent_module],
        )
        logger.bind(tag=TAG).info(f"初始化组件: intent成功 {select_intent_module}")

    # Memory
    if init_memory:
        select_memory_module = config["selected_module"]["Memory"]
        memory_type = _resolve_type("Memory", select_memory_module)
        modules["memory"] = memory.create_instance(
            memory_type,
            config["Memory"][select_memory_module],
            config.get("summaryMemory", None),
        )
        logger.bind(tag=TAG).info(f"初始化组件: memory成功 {select_memory_module}")

    # VAD
    if init_vad:
        select_vad_module = config["selected_module"]["VAD"]
        vad_type = _resolve_type("VAD", select_vad_module)
        modules["vad"] = vad.create_instance(
            vad_type,
            config["VAD"][select_vad_module],
        )
        logger.bind(tag=TAG).info(f"初始化组件: vad成功 {select_vad_module}")

    # ASR
    if init_asr:
        select_asr_module = config["selected_module"]["ASR"]
        modules["asr"] = initialize_asr(config)
        logger.bind(tag=TAG).info(f"初始化组件: asr成功 {select_asr_module}")

    return modules
|
||||
|
||||
|
||||
def initialize_tts(config):
    """Create the TTS provider selected in ``config["selected_module"]["TTS"]``."""
    selected = config["selected_module"]["TTS"]
    provider_cfg = config["TTS"][selected]
    # An explicit "type" in the provider entry overrides the selected name.
    tts_type = provider_cfg["type"] if "type" in provider_cfg else selected
    # "delete_audio" may be a bool or a string; accept common truthy spellings.
    delete_audio = str(config.get("delete_audio", True)).lower() in ("true", "1", "yes")
    return tts.create_instance(tts_type, provider_cfg, delete_audio)
|
||||
|
||||
|
||||
def initialize_asr(config):
    """Create the ASR provider selected in ``config["selected_module"]["ASR"]``."""
    selected = config["selected_module"]["ASR"]
    provider_cfg = config["ASR"][selected]
    # An explicit "type" in the provider entry overrides the selected name.
    asr_type = provider_cfg["type"] if "type" in provider_cfg else selected
    # "delete_audio" may be a bool or a string; accept common truthy spellings.
    delete_audio = str(config.get("delete_audio", True)).lower() in ("true", "1", "yes")
    new_asr = asr.create_instance(asr_type, provider_cfg, delete_audio)
    logger.bind(tag=TAG).info("ASR模块初始化完成")
    return new_asr
|
||||
|
||||
|
||||
def initialize_voiceprint(asr_instance, config):
    """Enable voiceprint recognition on ``asr_instance`` from ``config["voiceprint"]``.

    Returns:
        bool: True when voiceprint recognition was enabled; False when it is
        not configured, the configuration lacks ``url`` or ``speakers``, or
        initialization raised.
    """
    voiceprint_config = config.get("voiceprint")
    if not voiceprint_config:
        return False

    # Both a service URL and at least one speaker are required.
    if not (voiceprint_config.get("url") and voiceprint_config.get("speakers")):
        logger.bind(tag=TAG).warning("声纹识别配置不完整")
        return False

    try:
        asr_instance.init_voiceprint(voiceprint_config)
        logger.bind(tag=TAG).info("ASR模块声纹识别功能已动态启用")
        logger.bind(tag=TAG).info(f"配置说话人数量: {len(voiceprint_config['speakers'])}")
        return True
    except Exception as e:
        logger.bind(tag=TAG).error(f"动态初始化声纹识别功能失败: {str(e)}")
        return False
|
||||
|
||||
@@ -0,0 +1,132 @@
|
||||
"""
|
||||
Opus编码工具类
|
||||
将PCM音频数据编码为Opus格式
|
||||
"""
|
||||
|
||||
import logging
|
||||
import traceback
|
||||
import numpy as np
|
||||
from opuslib_next import Encoder
|
||||
from opuslib_next import constants
|
||||
from typing import Optional, Callable, Any
|
||||
|
||||
class OpusEncoderUtils:
    """Streaming encoder from 16-bit PCM to Opus packets."""

    def __init__(self, sample_rate: int, channels: int, frame_size_ms: int):
        """Create the Opus encoder.

        Args:
            sample_rate: sample rate in Hz.
            channels: channel count (1 = mono, 2 = stereo).
            frame_size_ms: frame duration in milliseconds.

        Raises:
            ValueError: if the frame would contain no samples.
            RuntimeError: if the underlying encoder cannot be created.
        """
        self.sample_rate = sample_rate
        self.channels = channels
        self.frame_size_ms = frame_size_ms
        # Samples per channel per frame = sample rate * frame duration (ms) / 1000.
        self.frame_size = (sample_rate * frame_size_ms) // 1000
        # Samples per frame across all channels.
        self.total_frame_size = self.frame_size * channels

        # A zero-sample frame would make the framing loop in
        # encode_pcm_to_opus_stream spin forever.
        if self.frame_size <= 0:
            raise ValueError(
                f"frame size must be positive "
                f"(sample_rate={sample_rate}, frame_size_ms={frame_size_ms})"
            )

        self.bitrate = 24000  # bps
        self.complexity = 10  # highest quality

        # Samples received but not yet encoded.
        self.buffer = np.array([], dtype=np.int16)
        # Trailing byte of a chunk that ended in the middle of a 16-bit sample.
        self._pending_bytes = b""

        try:
            self.encoder = Encoder(
                sample_rate, channels, constants.APPLICATION_AUDIO  # audio mode
            )
            self.encoder.bitrate = self.bitrate
            self.encoder.complexity = self.complexity
            self.encoder.signal = constants.SIGNAL_VOICE  # tuned for voice
        except Exception as e:
            logging.error(f"初始化Opus编码器失败: {e}")
            raise RuntimeError("初始化失败") from e

    def reset_state(self):
        """Reset the encoder and drop any buffered audio."""
        self.encoder.reset_state()
        self.buffer = np.array([], dtype=np.int16)
        self._pending_bytes = b""

    def encode_pcm_to_opus_stream(self, pcm_data: bytes, end_of_stream: bool, callback: Callable[[Any], Any]):
        """Encode PCM bytes to Opus incrementally.

        Complete frames are encoded as soon as enough samples are buffered;
        each encoded packet is passed to ``callback``. Leftover samples are
        kept for the next call. Nothing is returned.

        Args:
            pcm_data: 16-bit PCM bytes; a chunk may end in the middle of a
                sample, in which case the odd byte is carried to the next call.
            end_of_stream: when True, the remaining samples are zero-padded
                to a full frame and encoded, and the buffer is cleared.
            callback: called with each encoded Opus packet.
        """
        # Re-attach a byte left over from the previous chunk, and hold back a
        # new trailing odd byte: np.frombuffer rejects a byte count that is
        # not a multiple of the 2-byte sample size.
        data = self._pending_bytes + bytes(pcm_data)
        usable = len(data) - (len(data) % 2)
        self._pending_bytes = data[usable:]

        new_samples = self._convert_bytes_to_shorts(data[:usable])
        self._validate_pcm_data(new_samples)

        self.buffer = np.append(self.buffer, new_samples)

        offset = 0
        # Encode every complete frame currently buffered.
        while offset <= len(self.buffer) - self.total_frame_size:
            frame = self.buffer[offset : offset + self.total_frame_size]
            output = self._encode(frame)
            if output:
                callback(output)
            offset += self.total_frame_size

        # Keep the samples that did not fill a frame.
        self.buffer = self.buffer[offset:]

        if end_of_stream:
            # An incomplete sample at the very end cannot be encoded.
            self._pending_bytes = b""
            if len(self.buffer) > 0:
                # Pad the final partial frame with zeros.
                last_frame = np.zeros(self.total_frame_size, dtype=np.int16)
                last_frame[: len(self.buffer)] = self.buffer

                output = self._encode(last_frame)
                if output:
                    callback(output)
                self.buffer = np.array([], dtype=np.int16)

    def _encode(self, frame: np.ndarray) -> Optional[bytes]:
        """Encode one frame; return the packet, or None if encoding failed."""
        try:
            frame_bytes = frame.tobytes()
            encoded = self.encoder.encode(frame_bytes, self.frame_size)
            return encoded
        except Exception as e:
            logging.error(f"Opus编码失败: {e}")
            traceback.print_exc()
            return None

    def _convert_bytes_to_shorts(self, bytes_data: bytes) -> np.ndarray:
        """View ``bytes_data`` as an int16 array (16-bit PCM)."""
        # NOTE(review): np.int16 is native-endian; input is assumed to be
        # little-endian PCM, which matches only on little-endian hosts.
        return np.frombuffer(bytes_data, dtype=np.int16)

    def _validate_pcm_data(self, pcm_shorts: np.ndarray) -> None:
        """Log a warning if any sample lies outside the 16-bit range.

        For the int16 arrays produced by ``_convert_bytes_to_shorts`` no
        sample can be out of range, so this only fires for wider dtypes.
        """
        out_of_range = (pcm_shorts < -32768) | (pcm_shorts > 32767)
        if np.any(out_of_range):
            invalid_samples = pcm_shorts[out_of_range]
            logging.warning(f"发现无效PCM样本: {invalid_samples[:5]}...")

    def close(self):
        """Release the encoder (no explicit action is performed here)."""
        pass
|
||||
@@ -0,0 +1,50 @@
|
||||
import datetime
|
||||
from typing import Dict, Tuple
|
||||
|
||||
# Global counter: output character count keyed by (device id, date).
_device_daily_output: Dict[Tuple[str, datetime.date], int] = {}

# Date seen by the most recent add_device_output call (None before the first call).
_last_check_date: datetime.date = None
|
||||
|
||||
|
||||
def reset_device_output():
    """Reset the daily output counters of all devices.

    Intended to be called at midnight each day.
    """
    _device_daily_output.clear()
|
||||
|
||||
|
||||
def get_device_output(device_id: str) -> int:
    """Return the number of characters output by ``device_id`` today (0 if none)."""
    today = datetime.datetime.now().date()
    counter_key = (device_id, today)
    return _device_daily_output.get(counter_key, 0)
|
||||
|
||||
|
||||
def add_device_output(device_id: str, char_count: int):
    """Add ``char_count`` to today's output counter of ``device_id``."""
    global _last_check_date

    today = datetime.datetime.now().date()

    # First call, or the date changed since the last call: start from
    # empty counters.
    if _last_check_date != today:
        _device_daily_output.clear()
        _last_check_date = today

    counter_key = (device_id, today)
    _device_daily_output[counter_key] = (
        _device_daily_output.get(counter_key, 0) + char_count
    )
|
||||
|
||||
|
||||
def check_device_output_limit(device_id: str, max_output_size: int) -> bool:
    """Tell whether ``device_id`` has reached its daily output limit.

    :return: True if today's output is at or above ``max_output_size``;
        False otherwise, and always False for an empty device id.
    """
    if device_id:
        return get_device_output(device_id) >= max_output_size
    return False
|
||||
@@ -0,0 +1,59 @@
|
||||
import struct
|
||||
|
||||
def decode_opus_from_file(input_file):
    """
    Decode Opus packets from a p3 file.

    The file is read as a sequence of records. Each record is a 4-byte
    big-endian header ([1 byte type, 1 byte reserved, 2 bytes payload length])
    followed by that many bytes of Opus data.

    Args:
        input_file: Path of the p3 file to read.

    Returns:
        tuple: ``(opus_datas, total_duration)`` where ``opus_datas`` is the list
        of Opus packets (bytes) and ``total_duration`` is the duration in
        seconds, counting every packet as one 60 ms frame.

    Raises:
        ValueError: If the file ends in the middle of a header or a payload.
    """
    header_size = 4
    frame_duration_ms = 60  # every packet is counted as one 60 ms frame

    opus_datas = []
    with open(input_file, 'rb') as f:
        while True:
            header = f.read(header_size)
            if not header:
                break
            # A partial header means the file is truncated; report it the same
            # way as a truncated payload instead of leaking struct.error.
            if len(header) != header_size:
                raise ValueError(
                    f"Header length({len(header)}) mismatch({header_size}) in the file."
                )

            _, _, data_len = struct.unpack('>BBH', header)

            opus_data = f.read(data_len)
            if len(opus_data) != data_len:
                raise ValueError(f"Data length({len(opus_data)}) mismatch({data_len}) in the file.")

            opus_datas.append(opus_data)

    total_duration = (len(opus_datas) * frame_duration_ms) / 1000.0
    return opus_datas, total_duration
|
||||
|
||||
def decode_opus_from_bytes(input_bytes):
    """
    Decode Opus packets from p3 binary data.

    The data is read as a sequence of records. Each record is a 4-byte
    big-endian header ([1 byte type, 1 byte reserved, 2 bytes payload length])
    followed by that many bytes of Opus data.

    Args:
        input_bytes: The raw p3 content as a bytes-like object.

    Returns:
        tuple: ``(opus_datas, total_duration)`` where ``opus_datas`` is the list
        of Opus packets (bytes) and ``total_duration`` is the duration in
        seconds, counting every packet as one 60 ms frame.

    Raises:
        ValueError: If the data ends in the middle of a header or a payload.
    """
    import io

    header_size = 4
    frame_duration_ms = 60  # every packet is counted as one 60 ms frame

    opus_datas = []
    f = io.BytesIO(input_bytes)
    while True:
        header = f.read(header_size)
        if not header:
            break
        # A partial header means the data is truncated; report it the same
        # way as a truncated payload instead of leaking struct.error.
        if len(header) != header_size:
            raise ValueError(
                f"Header length({len(header)}) mismatch({header_size}) in the bytes."
            )
        _, _, data_len = struct.unpack('>BBH', header)
        opus_data = f.read(data_len)
        if len(opus_data) != data_len:
            raise ValueError(f"Data length({len(opus_data)}) mismatch({data_len}) in the bytes.")
        opus_datas.append(opus_data)

    total_duration = (len(opus_datas) * frame_duration_ms) / 1000.0
    return opus_datas, total_duration
|
||||
@@ -0,0 +1,241 @@
|
||||
"""
|
||||
系统提示词管理器模块
|
||||
负责管理和更新系统提示词,包括快速初始化和异步增强功能
|
||||
"""
|
||||
|
||||
import os
|
||||
import cnlunar
|
||||
from typing import Dict, Any
|
||||
from config.logger import setup_logging
|
||||
from jinja2 import Template
|
||||
|
||||
# Tag bound to log records emitted from this module.
TAG = __name__

# English weekday name -> Chinese weekday name.
WEEKDAY_MAP = {
    "Monday": "星期一",
    "Tuesday": "星期二",
    "Wednesday": "星期三",
    "Thursday": "星期四",
    "Friday": "星期五",
    "Saturday": "星期六",
    "Sunday": "星期日",
}

# Emoji passed to the prompt template as ``emojiList``.
EMOJI_List = [
    "😶",
    "🙂",
    "😆",
    "😂",
    "😔",
    "😠",
    "😭",
    "😍",
    "😳",
    "😲",
    "😱",
    "🤔",
    "😉",
    "😎",
    "😌",
    "🤤",
    "😘",
    "😏",
    "😴",
    "😜",
    "🙄",
]
|
||||
|
||||
|
||||
class PromptManager:
    """System prompt manager: loads the base template and builds/caches device prompts."""

    def __init__(self, config: Dict[str, Any], logger=None):
        """Store the config and logger, bind the global cache, and load the base template.

        Args:
            config: Configuration dictionary; ``prompt_template`` is read from it.
            logger: Optional logger; ``setup_logging()`` is used when omitted.
        """
        self.config = config
        self.logger = logger or setup_logging()
        self.base_prompt_template = None
        self.last_update_time = 0

        # Import the global cache manager
        from core.utils.cache.manager import cache_manager, CacheType

        self.cache_manager = cache_manager
        self.CacheType = CacheType

        self._load_base_template()

    def _load_base_template(self):
        """Load the base prompt template from the cache, falling back to the file.

        Any exception is logged and swallowed, leaving ``base_prompt_template``
        unchanged.
        """
        try:
            template_path = self.config.get("prompt_template", "agent-base-prompt.txt")
            cache_key = f"prompt_template:{template_path}"

            # Try the cache first
            cached_template = self.cache_manager.get(self.CacheType.CONFIG, cache_key)
            if cached_template is not None:
                self.base_prompt_template = cached_template
                self.logger.bind(tag=TAG).debug("从缓存加载基础提示词模板")
                return

            # Cache miss: read the template from the file
            if os.path.exists(template_path):
                with open(template_path, "r", encoding="utf-8") as f:
                    template_content = f.read()

                # Store in the cache (original author's note: CONFIG entries do not
                # expire automatically and must be invalidated manually)
                self.cache_manager.set(
                    self.CacheType.CONFIG, cache_key, template_content
                )
                self.base_prompt_template = template_content
                self.logger.bind(tag=TAG).debug("成功加载基础提示词模板并缓存")
            else:
                self.logger.bind(tag=TAG).warning(f"未找到{template_path}文件")
        except Exception as e:
            self.logger.bind(tag=TAG).error(f"加载提示词模板失败: {e}")

    def get_quick_prompt(self, user_prompt: str, device_id: str = None) -> str:
        """Quickly obtain a system prompt (using the user configuration).

        Returns the prompt cached for the device under ``DEVICE_PROMPT`` when
        one exists; otherwise returns ``user_prompt`` unchanged.
        """
        device_cache_key = f"device_prompt:{device_id}"
        cached_device_prompt = self.cache_manager.get(
            self.CacheType.DEVICE_PROMPT, device_cache_key
        )
        if cached_device_prompt is not None:
            self.logger.bind(tag=TAG).debug(f"使用设备 {device_id} 的缓存提示词")
            return cached_device_prompt
        else:
            self.logger.bind(tag=TAG).debug(
                f"设备 {device_id} 无缓存提示词,使用传入的提示词"
            )

        # Use the prompt passed in and cache it (only when a device id is given)
        # NOTE(review): this writes under CacheType.CONFIG while the lookup above
        # reads CacheType.DEVICE_PROMPT, so this entry is never read back by this
        # method — confirm whether that is intended.
        if device_id:
            device_cache_key = f"device_prompt:{device_id}"
            self.cache_manager.set(self.CacheType.CONFIG, device_cache_key, user_prompt)
            self.logger.bind(tag=TAG).debug(f"设备 {device_id} 的提示词已缓存")

        self.logger.bind(tag=TAG).info(f"使用快速提示词: {user_prompt[:50]}...")
        return user_prompt

    def _get_current_time_info(self) -> tuple:
        """Return ``(today_date, today_weekday, lunar_date)``.

        The lunar date has a trailing newline appended.
        """
        from .current_time import get_current_date, get_current_weekday, get_current_lunar_date

        today_date = get_current_date()
        today_weekday = get_current_weekday()
        lunar_date = get_current_lunar_date() + "\n"

        return today_date, today_weekday, lunar_date

    def _get_location_info(self, client_ip: str) -> str:
        """Return the city for ``client_ip``, using the LOCATION cache.

        Returns "未知位置" when the lookup raises.
        """
        try:
            # Try the cache first
            cached_location = self.cache_manager.get(self.CacheType.LOCATION, client_ip)
            if cached_location is not None:
                return cached_location

            # Cache miss: query the IP-info API
            from core.utils.util import get_ip_info

            ip_info = get_ip_info(client_ip, self.logger)
            city = ip_info.get("city", "未知位置")
            location = f"{city}"

            # Store in the cache
            self.cache_manager.set(self.CacheType.LOCATION, client_ip, location)
            return location
        except Exception as e:
            self.logger.bind(tag=TAG).error(f"获取位置信息失败: {e}")
            return "未知位置"

    def _get_weather_info(self, conn, location: str) -> str:
        """Return the weather report for ``location``, using the WEATHER cache.

        Returns "天气信息获取失败" when the plugin does not return an
        ``ActionResponse`` or when the lookup raises.
        """
        try:
            # Try the cache first
            cached_weather = self.cache_manager.get(self.CacheType.WEATHER, location)
            if cached_weather is not None:
                return cached_weather

            # Cache miss: fetch through the get_weather plugin function
            from plugins_func.functions.get_weather import get_weather
            from plugins_func.register import ActionResponse

            # Call the get_weather function
            result = get_weather(conn, location=location, lang="zh_CN")
            if isinstance(result, ActionResponse):
                weather_report = result.result
                self.cache_manager.set(self.CacheType.WEATHER, location, weather_report)
                return weather_report
            return "天气信息获取失败"

        except Exception as e:
            self.logger.bind(tag=TAG).error(f"获取天气信息失败: {e}")
            return "天气信息获取失败"

    def update_context_info(self, conn, client_ip: str):
        """Synchronously refresh the cached location and weather for ``client_ip``."""
        try:
            # Fetch the location (through the global cache)
            local_address = self._get_location_info(client_ip)
            # Fetch the weather (through the global cache); the result is only cached here
            self._get_weather_info(conn, local_address)
            self.logger.bind(tag=TAG).info(f"上下文信息更新完成")

        except Exception as e:
            self.logger.bind(tag=TAG).error(f"更新上下文信息失败: {e}")

    def build_enhanced_prompt(
        self, user_prompt: str, device_id: str, client_ip: str = None, *args, **kwargs
    ) -> str:
        """Build the enhanced system prompt from the base template.

        Renders the base template with the user prompt, current date
        information, and any cached location/weather, then stores the result
        under ``DEVICE_PROMPT`` for the device.

        Returns:
            The rendered prompt, or ``user_prompt`` unchanged when no base
            template is loaded or rendering fails.
        """
        if not self.base_prompt_template:
            return user_prompt

        try:
            # Get fresh time information (not cached)
            today_date, today_weekday, lunar_date = (
                self._get_current_time_info()
            )

            # Context information is read from the cache only
            local_address = ""
            weather_info = ""

            if client_ip:
                # Location (from the global cache)
                local_address = (
                    self.cache_manager.get(self.CacheType.LOCATION, client_ip) or ""
                )

                # Weather (from the global cache)
                if local_address:
                    weather_info = (
                        self.cache_manager.get(self.CacheType.WEATHER, local_address)
                        or ""
                    )

            # Substitute the template variables
            template = Template(self.base_prompt_template)
            enhanced_prompt = template.render(
                base_prompt=user_prompt,
                # The literal "{{current_time}}" is passed so the placeholder text
                # survives rendering — presumably filled in later; confirm with callers.
                current_time="{{current_time}}",
                today_date=today_date,
                today_weekday=today_weekday,
                lunar_date=lunar_date,
                local_address=local_address,
                weather_info=weather_info,
                emojiList=EMOJI_List,
                device_id=device_id,
                *args, **kwargs
            )
            device_cache_key = f"device_prompt:{device_id}"
            self.cache_manager.set(
                self.CacheType.DEVICE_PROMPT, device_cache_key, enhanced_prompt
            )
            self.logger.bind(tag=TAG).info(
                f"构建增强提示词成功,长度: {len(enhanced_prompt)}"
            )
            return enhanced_prompt

        except Exception as e:
            self.logger.bind(tag=TAG).error(f"构建增强提示词失败: {e}")
            return user_prompt
|
||||
@@ -0,0 +1,113 @@
|
||||
import json
|
||||
|
||||
# Tag bound to log records emitted from this module.
TAG = __name__
# Emoji character -> emotion name sent to the client by get_emotion.
EMOJI_MAP = {
    "😂": "laughing",
    "😭": "crying",
    "😠": "angry",
    "😔": "sad",
    "😍": "loving",
    "😲": "surprised",
    "😱": "shocked",
    "🤔": "thinking",
    "😌": "relaxed",
    "😴": "sleepy",
    "😜": "silly",
    "🙄": "confused",
    "😶": "neutral",
    "🙂": "happy",
    "😆": "laughing",
    "😳": "embarrassed",
    "😉": "winking",
    "😎": "cool",
    "🤤": "delicious",
    "😘": "kissy",
    "😏": "confident",
}
# Inclusive (start, end) code point ranges that is_emoji treats as emoji.
EMOJI_RANGES = [
    (0x1F600, 0x1F64F),
    (0x1F300, 0x1F5FF),
    (0x1F680, 0x1F6FF),
    (0x1F900, 0x1F9FF),
    (0x1FA70, 0x1FAFF),
    (0x2600, 0x26FF),
    (0x2700, 0x27BF),
]
|
||||
|
||||
|
||||
def get_string_no_punctuation_or_emoji(s):
    """Strip whitespace, punctuation and emoji from both ends of ``s``.

    Characters in the middle of the string are left untouched.

    Args:
        s: The text to trim.

    Returns:
        The trimmed string (empty if every character is strippable).
    """
    chars = list(s)
    total = len(chars)

    # Index of the first character to keep (``total`` when there is none).
    first_kept = next(
        (idx for idx, ch in enumerate(chars) if not is_punctuation_or_emoji(ch)),
        total,
    )

    # Walk backwards from the end to the last character to keep.
    last_kept = total - 1
    while last_kept >= first_kept and is_punctuation_or_emoji(chars[last_kept]):
        last_kept -= 1

    return "".join(chars[first_kept : last_kept + 1])
|
||||
|
||||
|
||||
def is_punctuation_or_emoji(char):
    """Check whether ``char`` is whitespace, listed punctuation, or an emoji.

    Args:
        char: A single character.

    Returns:
        True if the character is whitespace, is in the punctuation set below,
        or is reported as an emoji by ``is_emoji``; False otherwise.
    """
    # Chinese and English punctuation to strip (full-width and half-width).
    # The full-width forms are written as escapes so they cannot be silently
    # folded into their ASCII twins, which had turned each pair into duplicates.
    punctuation_set = {
        "\uff0c",
        ",",  # full-width (Chinese) comma + ASCII comma
        "。",
        ".",  # Chinese full stop + ASCII period
        "\uff01",
        "!",  # full-width (Chinese) exclamation mark + ASCII exclamation mark
        "“",
        "”",
        '"',  # Chinese double quotes + ASCII double quote
        "\uff1a",
        ":",  # full-width (Chinese) colon + ASCII colon
        "-",
        "\uff0d",  # ASCII hyphen + full-width hyphen
        "、",  # Chinese enumeration comma
        "[",
        "]",  # square brackets
        "【",
        "】",  # Chinese lenticular brackets
    }
    if char.isspace() or char in punctuation_set:
        return True
    return is_emoji(char)
|
||||
|
||||
|
||||
async def get_emotion(conn, text):
    """Send the emotion found in ``text`` to the client over the websocket.

    The first character of ``text`` that is a key of ``EMOJI_MAP`` decides the
    emoji and emotion; when none is found, "🙂" / "happy" is sent. A failure
    to send is logged as a warning and otherwise ignored.

    Args:
        conn: Connection object; its ``websocket``, ``session_id`` and
            ``logger`` attributes are used.
        text: Text to scan for an emoji.
    """
    emoji, emotion = "🙂", "happy"
    matched = next((ch for ch in text if ch in EMOJI_MAP), None)
    if matched is not None:
        emoji, emotion = matched, EMOJI_MAP[matched]

    try:
        message = json.dumps(
            {
                "type": "llm",
                "text": emoji,
                "emotion": emotion,
                "session_id": conn.session_id,
            }
        )
        await conn.websocket.send(message)
    except Exception as e:
        conn.logger.bind(tag=TAG).warning(f"发送情绪表情失败,错误:{e}")
    return
|
||||
|
||||
|
||||
def is_emoji(char):
    """Check whether ``char`` falls inside one of the ``EMOJI_RANGES``.

    Args:
        char: A single character (``ord`` is applied to it).

    Returns:
        True if its code point lies within any inclusive range, else False.
    """
    code_point = ord(char)
    for low, high in EMOJI_RANGES:
        if low <= code_point <= high:
            return True
    return False
|
||||
|
||||
|
||||
def check_emoji(text):
    """Return ``text`` with every emoji and every newline character removed."""
    kept = []
    for ch in text:
        if is_emoji(ch) or ch == "\n":
            continue
        kept.append(ch)
    return "".join(kept)
|
||||
@@ -0,0 +1,138 @@
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
from config.logger import setup_logging
|
||||
import importlib
|
||||
|
||||
logger = setup_logging()
|
||||
|
||||
# Punctuation treated as "basic" by MarkdownCleaner.clean_markdown.
# The full-width forms are written as escapes so they cannot be silently folded
# into their ASCII twins, which had turned each pair into duplicates.
punctuation_set = {
    "\uff0c",
    ",",  # full-width (Chinese) comma + ASCII comma
    "。",
    ".",  # Chinese full stop + ASCII period
    "\uff01",
    "!",  # full-width (Chinese) exclamation mark + ASCII exclamation mark
    "“",
    "”",
    '"',  # Chinese double quotes + ASCII double quote
    "\uff1a",
    ":",  # full-width (Chinese) colon + ASCII colon
    "-",
    "\uff0d",  # ASCII hyphen + full-width hyphen
    "、",  # Chinese enumeration comma
    "[",
    "]",  # square brackets
    "【",
    "】",  # Chinese lenticular brackets
    "~",  # tilde
}
|
||||
|
||||
def create_instance(class_name, *args, **kwargs):
    """Create a TTS provider instance.

    Args:
        class_name: Provider module name under ``core/providers/tts``.
        *args: Positional arguments forwarded to ``TTSProvider``.
        **kwargs: Keyword arguments forwarded to ``TTSProvider``.

    Returns:
        A ``TTSProvider`` instance from ``core.providers.tts.<class_name>``.

    Raises:
        ValueError: If no ``<class_name>.py`` exists under
            ``core/providers/tts`` (relative to the working directory).
    """
    provider_file = os.path.join('core', 'providers', 'tts', f'{class_name}.py')
    if not os.path.exists(provider_file):
        raise ValueError(f"不支持的TTS类型: {class_name},请检查该配置的type是否设置正确")

    lib_name = f'core.providers.tts.{class_name}'
    if lib_name not in sys.modules:
        sys.modules[lib_name] = importlib.import_module(f'{lib_name}')
    provider_module = sys.modules[lib_name]
    return provider_module.TTSProvider(*args, **kwargs)
|
||||
|
||||
|
||||
class MarkdownCleaner:
    """
    Markdown clean-up logic; call ``MarkdownCleaner.clean_markdown(text)``.
    """

    # Characters whose presence marks the content of "$...$" as a formula
    NORMAL_FORMULA_CHARS = re.compile(r'[a-zA-Z\\^_{}\+\-\(\)\[\]=]')

    @staticmethod
    def _replace_inline_dollar(m: re.Match) -> str:
        """
        Replacement callback for a complete "$...$" match:
        - content containing typical formula characters => drop both "$"
        - otherwise (plain numbers / currency etc.) => keep "$...$" as is
        """
        inner = m.group(1)
        looks_like_formula = MarkdownCleaner.NORMAL_FORMULA_CHARS.search(inner)
        return inner if looks_like_formula else m.group(0)

    @staticmethod
    def _replace_table_block(match: re.Match) -> str:
        """
        Replacement callback for a whole table block: turn it into spoken lines.
        """
        separator_row = r'^\|\s*[-:]+\s*(\|\s*[-:]+\s*)+\|?$'

        rows = []
        for raw_line in match.group('table_block').strip('\n').split('\n'):
            stripped = raw_line.strip()
            # The |---|---| separator row carries no content
            if re.match(separator_row, stripped):
                continue
            cells = [cell.strip() for cell in stripped.split('|') if cell.strip() != '']
            if cells:
                rows.append(cells)

        if not rows:
            return ""

        headers, body = rows[0], rows[1:]

        if not body:
            # Only one row
            spoken = [f"单行表格:{', '.join(headers)}"]
        else:
            spoken = [f"表头是:{', '.join(headers)}"]
            for row_no, row in enumerate(body, start=1):
                # Cells beyond the header width are read without a column name
                parts = [
                    f"{headers[idx]} = {cell}" if idx < len(headers) else cell
                    for idx, cell in enumerate(row)
                ]
                spoken.append(f"第 {row_no} 行:{', '.join(parts)}")

        return "\n".join(spoken) + "\n"

    # All regular expressions are precompiled and applied in this order.
    # The replacement callbacks above must be defined first so that the list
    # can refer to them.
    REGEXES = [
        (re.compile(r'```.*?```', re.DOTALL), ''),  # code blocks
        (re.compile(r'^#+\s*', re.MULTILINE), ''),  # headings
        (re.compile(r'(\*\*|__)(.*?)\1'), r'\2'),  # bold
        (re.compile(r'(\*|_)(?=\S)(.*?)(?<=\S)\1'), r'\2'),  # italic
        (re.compile(r'!\[.*?\]\(.*?\)'), ''),  # images
        (re.compile(r'\[(.*?)\]\(.*?\)'), r'\1'),  # links
        (re.compile(r'^\s*>+\s*', re.MULTILINE), ''),  # block quotes
        (
            re.compile(r'(?P<table_block>(?:^[^\n]*\|[^\n]*\n)+)', re.MULTILINE),
            _replace_table_block
        ),
        (re.compile(r'^\s*[*+-]\s*', re.MULTILINE), '- '),  # list items
        (re.compile(r'\$\$.*?\$\$', re.DOTALL), ''),  # block formulas
        (
            re.compile(r'(?<![A-Za-z0-9])\$([^\n$]+)\$(?![A-Za-z0-9])'),
            _replace_inline_dollar
        ),
        (re.compile(r'\n{2,}'), '\n'),  # surplus blank lines
    ]

    @staticmethod
    def clean_markdown(text: str) -> str:
        """
        Main entry point: apply every regex in order to remove or replace
        Markdown elements. Non-empty text made only of ASCII characters,
        whitespace and ``punctuation_set`` members is returned unchanged.
        """
        def _is_plain(ch):
            return ch.isascii() or ch.isspace() or ch in punctuation_set

        # English-only text keeps its original spacing and is returned as is
        if text and all(map(_is_plain, text)):
            return text

        cleaned = text
        for pattern, replacement in MarkdownCleaner.REGEXES:
            cleaned = pattern.sub(replacement, cleaned)
        return cleaned.strip()
|
||||
@@ -0,0 +1,542 @@
|
||||
import re
|
||||
import os
|
||||
import json
|
||||
import copy
|
||||
import wave
|
||||
import socket
|
||||
import requests
|
||||
import subprocess
|
||||
import numpy as np
|
||||
import opuslib_next
|
||||
from io import BytesIO
|
||||
from core.utils import p3
|
||||
from pydub import AudioSegment
|
||||
from typing import Callable, Any
|
||||
|
||||
# Tag bound to log records emitted from this module.
TAG = __name__
# Emotion name -> emoji character.
emoji_map = {
    "neutral": "😶",
    "happy": "🙂",
    "laughing": "😆",
    "funny": "😂",
    "sad": "😔",
    "angry": "😠",
    "crying": "😭",
    "loving": "😍",
    "embarrassed": "😳",
    "surprised": "😲",
    "shocked": "😱",
    "thinking": "🤔",
    "winking": "😉",
    "cool": "😎",
    "relaxed": "😌",
    "delicious": "🤤",
    "kissy": "😘",
    "confident": "😏",
    "sleepy": "😴",
    "silly": "😜",
    "confused": "🙄",
}
|
||||
|
||||
|
||||
def get_local_ip():
    """Return the local IP address of the outbound network interface.

    A UDP socket is connected towards Google's DNS server address and the
    socket's own address is read back.

    Returns:
        str: The local IP address, or "127.0.0.1" if it cannot be determined.
    """
    try:
        # The context manager closes the socket even when connect() or
        # getsockname() raises, so no descriptor is leaked on failure.
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as s:
            # Connect to Google's DNS servers
            s.connect(("8.8.8.8", 80))
            return s.getsockname()[0]
    except Exception:
        # Best effort: fall back to loopback on any failure.
        return "127.0.0.1"
|
||||
|
||||
|
||||
def is_private_ip(ip_addr):
    """
    Check if an IP address is a private IP address (compatible with IPv4 and IPv6).

    Addresses are parsed with the standard ``ipaddress`` module, so compressed
    IPv6 forms such as "::1" are recognised and out-of-range octets are
    rejected. The following ranges count as private:

    - IPv4: 10.0.0.0/8, 172.16.0.0/12, 192.168.0.0/16, 127.0.0.0/8 (loopback),
      169.254.0.0/16 (link-local)
    - IPv6: fc00::/7 (unique local), ::1 (loopback), fe80::/10 (link-local)

    @param {string} ip_addr - The IP address to check.
    @return {bool} True if the IP address is private, False otherwise
        (including when ``ip_addr`` is not a valid IP address string).
    """
    import ipaddress

    if not isinstance(ip_addr, str):
        return False

    try:
        address = ipaddress.ip_address(ip_addr)
    except ValueError:
        return False  # Invalid IP address format

    if address.version == 4:
        private_networks = (
            "10.0.0.0/8",
            "172.16.0.0/12",
            "192.168.0.0/16",
            "127.0.0.0/8",
            "169.254.0.0/16",
        )
    else:
        private_networks = ("fc00::/7", "::1/128", "fe80::/10")

    return any(address in ipaddress.ip_network(net) for net in private_networks)
|
||||
|
||||
|
||||
def get_ip_info(ip_addr, logger):
    """Look up location information for an IP address, using the IP_INFO cache.

    Args:
        ip_addr: The client IP address.
        logger: Logger used to report failures (``logger.bind(tag=...)``).

    Returns:
        dict: ``{"city": ...}`` as reported by the lookup service, or ``{}``
        when the lookup fails for any reason.
    """
    try:
        # Import the global cache manager
        from core.utils.cache.manager import cache_manager, CacheType

        # Try the cache first
        cached_ip_info = cache_manager.get(CacheType.IP_INFO, ip_addr)
        if cached_ip_info is not None:
            return cached_ip_info

        # Cache miss: call the API. Private addresses are queried with an empty
        # "ip" parameter (presumably the service then uses the caller's public
        # address — confirm against the service documentation).
        query_ip = "" if is_private_ip(ip_addr) else ip_addr
        url = f"https://whois.pconline.com.cn/ipJson.jsp?json=true&ip={query_ip}"
        # A timeout keeps an unresponsive service from blocking the caller forever.
        resp = requests.get(url, timeout=10).json()
        ip_info = {"city": resp.get("city")}

        # Store under the caller's original address so that the cache lookup
        # above (which uses that same address) can actually hit next time.
        cache_manager.set(CacheType.IP_INFO, ip_addr, ip_info)
        return ip_info
    except Exception as e:
        logger.bind(tag=TAG).error(f"Error getting client ip info: {e}")
        return {}
|
||||
|
||||
|
||||
def write_json_file(file_path, data):
    """Write ``data`` to ``file_path`` as JSON.

    The file is written as UTF-8 with a 4-space indent, and non-ASCII
    characters are kept as is rather than escaped.

    Args:
        file_path: Destination path; an existing file is overwritten.
        data: JSON-serialisable object.
    """
    dump_options = {"ensure_ascii": False, "indent": 4}
    with open(file_path, "w", encoding="utf-8") as fh:
        json.dump(data, fh, **dump_options)
|
||||
|
||||
|
||||
def remove_punctuation_and_length(text):
    """Remove punctuation and spaces from ``text`` and report the remaining length.

    Both half-width (ASCII) and full-width punctuation are removed, together
    with the Chinese full stop, the ASCII space and the full-width space.

    Args:
        text: The text to clean.

    Returns:
        tuple: ``(length, cleaned_text)``. When the cleaned text is exactly
        "Yeah", ``(0, "")`` is returned instead.
    """
    half_width_punctuations = r'!"#$%&\'()*+,-./:;<=>?@[\]^_`{|}~'
    # Full-width forms sit at a fixed offset (0xFEE0) from their ASCII twins.
    # Deriving them avoids storing the characters literally in the source,
    # where they had been folded to ASCII and left only "!" in effect.
    full_width_punctuations = "".join(
        chr(ord(char) + 0xFEE0) for char in half_width_punctuations
    )
    ideographic_full_stop = "\u3002"  # "。"
    space = " "  # half-width space
    full_width_space = "\u3000"  # full-width space

    removable = set(
        half_width_punctuations
        + full_width_punctuations
        + ideographic_full_stop
        + space
        + full_width_space
    )

    # Drop full-width and half-width punctuation as well as spaces
    result = "".join(char for char in text if char not in removable)

    # NOTE(review): "Yeah" is treated as empty input; the reason is not visible
    # in this file — confirm against callers.
    if result == "Yeah":
        return 0, ""
    return len(result), result
|
||||
|
||||
|
||||
def check_model_key(modelType, modelKey):
    """Detect an API key that still contains the character "你".

    Args:
        modelType: Name of the model type, used in the message.
        modelKey: The configured API key string.

    Returns:
        An error message string if ``modelKey`` contains "你", else None.
    """
    placeholder_marker = "你"
    if placeholder_marker not in modelKey:
        return None
    return f"配置错误: {modelType} 的 API key 未设置,当前值为: {modelKey}"
|
||||
|
||||
|
||||
def parse_string_to_list(value, separator=";"):
    """
    Convert the input value to a list.

    Args:
        value: Input value; may be None, a string or a list.
        separator: Separator used to split a string, a semicolon by default.

    Returns:
        list: A list is returned as is; a string is split on ``separator``
        with each item stripped and empty items dropped; anything else
        (including None) yields an empty list.
    """
    if isinstance(value, list):
        return value
    if isinstance(value, str):
        stripped_items = (piece.strip() for piece in value.split(separator))
        return [piece for piece in stripped_items if piece]
    return []
|
||||
|
||||
|
||||
def check_ffmpeg_installed() -> bool:
    """
    Check whether ffmpeg is installed and executable in the current environment.

    Returns:
        bool: True if ``ffmpeg -version`` runs and prints a version banner.

    Raises:
        ValueError: If ffmpeg is missing, fails to run (for example because of
            a missing shared library), or prints no version information. The
            message contains installation hints.
    """
    try:
        # Try to run the ffmpeg command
        result = subprocess.run(
            ["ffmpeg", "-version"],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            check=True,  # a non-zero exit code raises CalledProcessError
        )

        output = (result.stdout + result.stderr).lower()
        if "ffmpeg version" in output:
            return True

        # No version banner is treated as a failure as well
        raise ValueError("未检测到有效的 ffmpeg 版本输出。")

    except (subprocess.CalledProcessError, FileNotFoundError) as e:
        # Extract the error output
        if isinstance(e, subprocess.CalledProcessError):
            stderr_output = (e.stderr or "").strip()
        else:
            stderr_output = str(e).strip()
        # Compare case-insensitively: Python reports a missing executable as
        # "No such file or directory" (capital N), which a lowercase substring
        # test against the raw text would never match.
        stderr_lower = stderr_output.lower()

        # Base error message
        error_msg = [
            "❌ 检测到 ffmpeg 无法正常运行。\n",
            "建议您:",
            "1. 确认已正确激活 conda 环境;",
            "2. 查阅项目安装文档,了解如何在 conda 环境中安装 ffmpeg。\n",
        ]

        # Extra hints for specific, recognisable failures
        if "libiconv.so.2" in stderr_output:
            error_msg.append("⚠️ 发现缺少依赖库:libiconv.so.2")
            error_msg.append("解决方法:在当前 conda 环境中执行:")
            error_msg.append(" conda install -c conda-forge libiconv\n")
        elif "no such file or directory" in stderr_lower and "ffmpeg" in stderr_lower:
            error_msg.append("⚠️ 系统未找到 ffmpeg 可执行文件。")
            error_msg.append("解决方法:在当前 conda 环境中执行:")
            error_msg.append(" conda install -c conda-forge ffmpeg\n")
        else:
            error_msg.append("错误详情:")
            error_msg.append(stderr_output or "未知错误。")

        # Raise with the detailed message, keeping the original cause
        raise ValueError("\n".join(error_msg)) from e
|
||||
|
||||
|
||||
def extract_json_from_string(input_string):
    """Extract the JSON-looking part of a string.

    The match is greedy: it runs from the first "{" to the last "}" and may
    span several lines.

    Args:
        input_string: Text that may contain a JSON object.

    Returns:
        The matched substring, or None when no "{...}" span is found.
    """
    found = re.search(r"(\{.*\})", input_string, re.DOTALL)
    return found.group(1) if found else None
|
||||
|
||||
|
||||
def audio_to_data_stream(audio_file_path, is_opus=True, callback: Callable[[Any], Any]=None) -> None:
    """Decode an audio file and stream it frame by frame to ``callback``.

    The audio is converted to mono, 16 kHz, 16-bit samples and handed to
    ``pcm_to_data_stream``, which produces the frames.

    Args:
        audio_file_path: Path of the audio file; its extension selects the format.
        is_opus: Forwarded to ``pcm_to_data_stream``.
        callback: Forwarded to ``pcm_to_data_stream``, which calls it per frame.
    """
    # The file extension (without the dot) is the format passed to pydub
    _, extension = os.path.splitext(audio_file_path)
    file_type = extension.lstrip(".") if extension else extension

    # -nostdin: do not read from standard input, otherwise FFmpeg would block
    segment = AudioSegment.from_file(
        audio_file_path, format=file_type, parameters=["-nostdin"]
    )

    # Mono / 16 kHz / 16-bit so that the data matches the encoder
    segment = segment.set_channels(1)
    segment = segment.set_frame_rate(16000)
    segment = segment.set_sample_width(2)

    pcm_to_data_stream(segment.raw_data, is_opus, callback)
|
||||
|
||||
def audio_to_data(audio_file_path: str, is_opus: bool = True) -> list[bytes]:
    """
    Convert an audio file into a list of Opus- or PCM-encoded frames.

    The audio is converted to mono, 16 kHz, 16-bit samples. Framing and
    encoding are delegated to ``pcm_to_data_stream`` so that the list-based
    and the streaming variants cannot drift apart.

    Args:
        audio_file_path: Path of the audio file; its extension selects the format.
        is_opus: Whether to Opus-encode the frames (raw PCM frames otherwise).

    Returns:
        The frames, in order, as a list of bytes objects.
    """
    # The file extension (without the dot) is the format passed to pydub
    file_type = os.path.splitext(audio_file_path)[1]
    if file_type:
        file_type = file_type.lstrip(".")
    # -nostdin: do not read from standard input, otherwise FFmpeg would block
    audio = AudioSegment.from_file(
        audio_file_path, format=file_type, parameters=["-nostdin"]
    )

    # Mono / 16 kHz / 16-bit so that the data matches the encoder
    audio = audio.set_channels(1).set_frame_rate(16000).set_sample_width(2)

    # Collect every frame produced by the shared framing/encoding helper
    datas = []
    pcm_to_data_stream(audio.raw_data, is_opus, datas.append)
    return datas
|
||||
|
||||
def audio_bytes_to_data_stream(audio_bytes, file_type, is_opus, callback: Callable[[Any], Any]) -> None:
    """
    Convert in-memory audio data to opus/pcm frames; supports wav, mp3 and p3.

    Args:
        audio_bytes: The raw content of the audio file.
        file_type: Format name; "p3" is decoded by the ``p3`` module, anything
            else is decoded through pydub.
        is_opus: Forwarded to ``pcm_to_data_stream`` for non-p3 input.
        callback: Called once per produced frame/packet.
    """
    if file_type == "p3":
        # Decode directly with p3. Prefer the streaming decoder when the module
        # provides one; otherwise fall back to the list-based decoder so that
        # p3 input does not fail with AttributeError.
        stream_decoder = getattr(p3, "decode_opus_from_bytes_stream", None)
        if stream_decoder is not None:
            return stream_decoder(audio_bytes, callback)
        opus_datas, _ = p3.decode_opus_from_bytes(audio_bytes)
        for opus_data in opus_datas:
            callback(opus_data)
        return None
    else:
        # Other formats go through pydub
        audio = AudioSegment.from_file(
            BytesIO(audio_bytes), format=file_type, parameters=["-nostdin"]
        )
        audio = audio.set_channels(1).set_frame_rate(16000).set_sample_width(2)
        raw_data = audio.raw_data
        pcm_to_data_stream(raw_data, is_opus, callback)
|
||||
|
||||
|
||||
def pcm_to_data_stream(raw_data, is_opus=True, callback: Callable[[Any], Any] = None):
    """Split 16 kHz / mono / 16-bit PCM data into 60 ms frames and emit them.

    Each frame holds 960 samples (1920 bytes); a short final frame is padded
    with zero bytes. Frames are Opus-encoded when ``is_opus`` is true and are
    passed on as raw PCM bytes otherwise.

    Args:
        raw_data: PCM bytes (16-bit samples).
        is_opus: Whether to Opus-encode each frame.
        callback: Called with every frame as bytes. Although it defaults to
            None, it is called unconditionally for every frame.
    """
    # Encoding parameters
    frame_duration = 60  # 60ms per frame
    frame_size = int(16000 * frame_duration / 1000)  # 960 samples/frame
    frame_bytes = frame_size * 2  # 16bit = 2 bytes/sample

    # The Opus encoder is only needed (and only created) for Opus output
    encoder = (
        opuslib_next.Encoder(16000, 1, opuslib_next.APPLICATION_AUDIO)
        if is_opus
        else None
    )

    # Process the data frame by frame (the last frame may be zero-padded)
    for i in range(0, len(raw_data), frame_bytes):
        chunk = raw_data[i : i + frame_bytes]

        # Pad a short final frame with zeros
        if len(chunk) < frame_bytes:
            chunk += b"\x00" * (frame_bytes - len(chunk))

        frame = chunk if isinstance(chunk, bytes) else bytes(chunk)

        if is_opus:
            callback(encoder.encode(frame, frame_size))
        else:
            callback(frame)
|
||||
|
||||
def opus_datas_to_wav_bytes(opus_datas, sample_rate=16000, channels=1):
    """
    Decode a list of Opus frames into the bytes of a WAV file.

    Args:
        opus_datas: Iterable of Opus frames, decoded as 60 ms frames.
        sample_rate: Sample rate for the decoder and the WAV header.
        channels: Channel count for the decoder and the WAV header.

    Returns:
        bytes: A complete WAV file (16-bit samples) held in memory.
    """
    decoder = opuslib_next.Decoder(sample_rate, channels)

    frame_duration = 60  # ms
    samples_per_frame = int(sample_rate * frame_duration / 1000)

    # Decode frame by frame, in order, and concatenate the PCM bytes
    pcm_bytes = b"".join(
        decoder.decode(opus_frame, samples_per_frame) for opus_frame in opus_datas
    )

    # Write the WAV container into an in-memory buffer
    wav_buffer = BytesIO()
    with wave.open(wav_buffer, "wb") as wav_file:
        wav_file.setnchannels(channels)
        wav_file.setsampwidth(2)  # 16bit
        wav_file.setframerate(sample_rate)
        wav_file.writeframes(pcm_bytes)
    return wav_buffer.getvalue()
|
||||
|
||||
def check_vad_update(before_config, new_config):
    """Tell whether the effective VAD type differs between two configs.

    The effective type of a config is the ``type`` field of its selected VAD
    module, or the module name itself when that field is absent.

    Args:
        before_config: The configuration currently in use.
        new_config: The incoming configuration.

    Returns:
        bool: True if the VAD type changed. False if it did not, or if
        ``new_config`` selects no VAD module.
    """
    selected = new_config.get("selected_module")
    if selected is None or selected.get("VAD") is None:
        return False

    def _effective_type(config):
        module_name = config["selected_module"]["VAD"]
        settings = config["VAD"][module_name]
        return settings["type"] if "type" in settings else module_name

    return _effective_type(before_config) != _effective_type(new_config)
|
||||
|
||||
|
||||
def check_asr_update(before_config, new_config):
    """Tell whether the effective ASR type differs between two configs.

    The effective type of a config is the ``type`` field of its selected ASR
    module, or the module name itself when that field is absent.

    Args:
        before_config: The configuration currently in use.
        new_config: The incoming configuration.

    Returns:
        bool: True if the ASR type changed. False if it did not, or if
        ``new_config`` selects no ASR module.
    """
    selected = new_config.get("selected_module")
    if selected is None or selected.get("ASR") is None:
        return False

    def _effective_type(config):
        module_name = config["selected_module"]["ASR"]
        settings = config["ASR"][module_name]
        return settings["type"] if "type" in settings else module_name

    return _effective_type(before_config) != _effective_type(new_config)
|
||||
|
||||
|
||||
def filter_sensitive_info(config: dict) -> dict:
    """
    Mask sensitive values in a configuration dictionary.

    A key is sensitive when its lowercase form contains one of the marker
    words below; its value is replaced by "***". Nested dictionaries, and
    dictionaries placed directly inside lists, are processed recursively.

    Args:
        config: The original configuration dictionary (left unmodified).

    Returns:
        A filtered deep copy of the configuration.
    """
    sensitive_keys = [
        "api_key",
        "personal_access_token",
        "access_token",
        "token",
        "secret",
        "access_key_secret",
        "secret_key",
    ]

    def _is_sensitive(key) -> bool:
        return any(marker in key.lower() for marker in sensitive_keys)

    def _filter_value(value):
        if isinstance(value, dict):
            return _filter_dict(value)
        if isinstance(value, list):
            # Only dictionaries that are direct list items are filtered
            return [_filter_dict(item) if isinstance(item, dict) else item for item in value]
        return value

    def _filter_dict(d: dict) -> dict:
        return {
            key: "***" if _is_sensitive(key) else _filter_value(value)
            for key, value in d.items()
        }

    return _filter_dict(copy.deepcopy(config))
|
||||
|
||||
|
||||
def get_vision_url(config: dict) -> str:
    """Return the vision URL.

    When the configured ``server.vision_explain`` value contains "你的", a URL
    is built from the local IP address and ``server.http_port`` (8003 by
    default) instead.

    Args:
        config: Configuration dictionary with a ``server`` section.

    Returns:
        str: The vision URL.
    """
    server_config = config["server"]
    configured_url = server_config.get("vision_explain", "")
    if "你的" not in configured_url:
        return configured_url

    local_ip = get_local_ip()
    port = int(server_config.get("http_port", 8003))
    return f"http://{local_ip}:{port}/mcp/vision/explain"
|
||||
|
||||
|
||||
def is_valid_image_file(file_data: bytes) -> bool:
    """
    Check whether the file data starts like a known image format.

    Only the leading magic bytes are inspected; the rest of the data is not
    validated. Recognised formats: JPEG, PNG, GIF, BMP, TIFF and WebP.

    Args:
        file_data: The binary content of the file.

    Returns:
        bool: True if the data begins with a known image signature, else False.
    """
    # Magic numbers (file headers) of common image formats
    simple_signatures = (
        b"\xff\xd8\xff",  # JPEG
        b"\x89PNG\r\n\x1a\n",  # PNG
        b"GIF87a",  # GIF
        b"GIF89a",  # GIF
        b"BM",  # BMP
        b"II*\x00",  # TIFF (little-endian)
        b"MM\x00*",  # TIFF (big-endian)
    )
    if file_data.startswith(simple_signatures):
        return True

    # WebP is a RIFF container whose form type at offset 8 is "WEBP". Checking
    # "RIFF" alone would also accept other RIFF files such as WAV or AVI.
    return file_data[:4] == b"RIFF" and file_data[8:12] == b"WEBP"
|
||||
|
||||
|
||||
def sanitize_tool_name(name: str) -> str:
    """Sanitize tool names for OpenAI compatibility.

    Every character other than ASCII letters, digits, underscore, hyphen and
    characters in the range U+4E00-U+9FFF is replaced by an underscore.
    """
    disallowed = re.compile(r"[^a-zA-Z0-9_\-\u4e00-\u9fff]")
    return disallowed.sub("_", name)
|
||||
|
||||
|
||||
def validate_mcp_endpoint(mcp_endpoint: str) -> bool:
    """Validate the format of an MCP endpoint string.

    The endpoint is accepted only when all three conditions hold:
      1. it starts with ``ws``;
      2. it contains neither ``key`` nor ``call`` (case-insensitive);
      3. it contains ``/mcp/``.

    Args:
        mcp_endpoint: MCP endpoint string.

    Returns:
        bool: True if the endpoint passes every check, otherwise False.
    """
    # 1. Must start with "ws".
    if not mcp_endpoint.startswith("ws"):
        return False

    # 2. Must not contain "key" or "call" in any letter case.
    lowered = mcp_endpoint.lower()
    if any(word in lowered for word in ("key", "call")):
        return False

    # 3. Must contain the "/mcp/" path segment.
    return "/mcp/" in mcp_endpoint
|
||||
@@ -0,0 +1,19 @@
|
||||
import importlib
|
||||
import os
|
||||
import sys
|
||||
from core.providers.vad.base import VADProviderBase
|
||||
from config.logger import setup_logging
|
||||
|
||||
TAG = __name__
|
||||
logger = setup_logging()
|
||||
|
||||
|
||||
def create_instance(class_name: str, *args, **kwargs) -> VADProviderBase:
    """Factory that builds a VAD provider instance.

    The provider is looked up as ``core/providers/vad/<class_name>.py``,
    a path relative to the current working directory. Its module is imported
    (once) and its ``VADProvider`` class is instantiated.

    Args:
        class_name: Name of the provider module (the file name without ``.py``).
        *args: Positional arguments forwarded to ``VADProvider``.
        **kwargs: Keyword arguments forwarded to ``VADProvider``.

    Returns:
        The newly created ``VADProvider`` instance.

    Raises:
        ValueError: If no provider file exists for ``class_name``.
    """
    provider_file = os.path.join("core", "providers", "vad", f"{class_name}.py")
    if not os.path.exists(provider_file):
        raise ValueError(f"不支持的VAD类型: {class_name},请检查该配置的type是否设置正确")

    module_path = f"core.providers.vad.{class_name}"
    # Import only on first use; later calls reuse the entry in sys.modules.
    if module_path not in sys.modules:
        sys.modules[module_path] = importlib.import_module(f"{module_path}")
    provider_module = sys.modules[module_path]
    return provider_module.VADProvider(*args, **kwargs)
|
||||
@@ -0,0 +1,23 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
# 添加项目根目录到Python路径
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
project_root = os.path.abspath(os.path.join(current_dir, "..", ".."))
|
||||
sys.path.insert(0, project_root)
|
||||
|
||||
from config.logger import setup_logging
|
||||
import importlib
|
||||
|
||||
logger = setup_logging()
|
||||
|
||||
|
||||
def create_instance(class_name, *args, **kwargs):
    """Factory that builds a VLLM provider instance.

    The provider is looked up as ``core/providers/vllm/<class_name>.py``,
    a path relative to the current working directory. Its module is imported
    (once) and its ``VLLMProvider`` class is instantiated.

    Args:
        class_name: Name of the provider module (the file name without ``.py``).
        *args: Positional arguments forwarded to ``VLLMProvider``.
        **kwargs: Keyword arguments forwarded to ``VLLMProvider``.

    Returns:
        The newly created ``VLLMProvider`` instance.

    Raises:
        ValueError: If no provider file exists for ``class_name``.
    """
    provider_file = os.path.join("core", "providers", "vllm", f"{class_name}.py")
    if not os.path.exists(provider_file):
        raise ValueError(f"不支持的VLLM类型: {class_name},请检查该配置的type是否设置正确")

    module_path = f"core.providers.vllm.{class_name}"
    # Import only on first use; later calls reuse the entry in sys.modules.
    if module_path not in sys.modules:
        sys.modules[module_path] = importlib.import_module(f"{module_path}")
    provider_module = sys.modules[module_path]
    return provider_module.VLLMProvider(*args, **kwargs)
|
||||
@@ -0,0 +1,198 @@
|
||||
import asyncio
|
||||
import time
|
||||
import aiohttp
|
||||
import requests
|
||||
from urllib.parse import urlparse, parse_qs
|
||||
from typing import Optional, Dict
|
||||
from config.logger import setup_logging
|
||||
from core.utils.cache.manager import cache_manager
|
||||
from core.utils.cache.config import CacheType
|
||||
|
||||
TAG = __name__
|
||||
logger = setup_logging()
|
||||
|
||||
|
||||
class VoiceprintProvider:
    """Client for a remote voiceprint (speaker identification) service.

    The provider is configured with a service URL carrying a ``key`` query
    parameter and a list of ``"id,name,description"`` speaker strings. It is
    enabled only when the URL, the key and at least one speaker ID are
    present and the service health check succeeds.
    """

    def __init__(self, config: dict):
        """Read the configuration and decide whether identification is enabled.

        Args:
            config: Mapping read with the keys ``url`` (service URL whose query
                string carries ``key``), ``speakers`` (list of
                ``"id,name,description"`` strings) and
                ``similarity_threshold`` (default 0.4).
        """
        self.original_url = config.get("url", "")
        self.speakers = config.get("speakers", [])
        self.speaker_map = self._parse_speakers()
        # Similarity threshold for accepting a match, default 0.4.
        self.similarity_threshold = float(config.get("similarity_threshold", 0.4))

        # Filled in below once the URL has been parsed successfully.
        self.api_url = None
        self.api_key = None
        self.speaker_ids = []
        # Stays False unless every check below passes.
        self.enabled = False

        if not self.original_url:
            logger.bind(tag=TAG).warning("声纹识别URL未配置,声纹识别将被禁用")
            return

        # Split the configured URL into base address and key.
        parsed_url = urlparse(self.original_url)
        base_url = f"{parsed_url.scheme}://{parsed_url.netloc}"

        # The key is taken from the query string.
        query_params = parse_qs(parsed_url.query)
        self.api_key = query_params.get('key', [''])[0]
        if not self.api_key:
            logger.bind(tag=TAG).error("URL中未找到key参数,声纹识别将被禁用")
            return

        # Address of the identify endpoint.
        self.api_url = f"{base_url}/voiceprint/identify"

        self.speaker_ids = self._parse_speaker_ids()
        if not self.speaker_ids:
            logger.bind(tag=TAG).warning("未配置有效的说话人,声纹识别将被禁用")
            return

        # Verify that the server is reachable before enabling the feature.
        if self._check_server_health():
            self.enabled = True
            logger.bind(tag=TAG).info(f"声纹识别已启用: API={self.api_url}, 说话人={len(self.speaker_ids)}个, 相似度阈值={self.similarity_threshold}")
        else:
            logger.bind(tag=TAG).warning(f"声纹识别服务器不可用,声纹识别已禁用: {self.api_url}")

    def _parse_speaker_ids(self) -> list:
        """Collect the speaker IDs (first comma-separated field) from the config.

        Blank IDs are skipped so that an empty or malformed entry can neither
        make the provider look configured nor be sent to the service.
        Non-string entries are skipped silently here; ``_parse_speakers``
        already logs a warning for them.

        Returns:
            list: Stripped, non-empty speaker IDs in configuration order.
        """
        speaker_ids = []
        for speaker_str in self.speakers:
            if not isinstance(speaker_str, str):
                continue
            speaker_id = speaker_str.split(",", 2)[0].strip()
            if speaker_id:
                speaker_ids.append(speaker_id)
        return speaker_ids

    def _parse_speakers(self) -> Dict[str, Dict[str, str]]:
        """Parse the configured speakers into a lookup table.

        Each entry is expected to look like ``"id,name,description"``; the
        description may itself contain commas. Entries with fewer than three
        fields or with a blank ID are ignored.

        Returns:
            Mapping of speaker ID to ``{"name": ..., "description": ...}``.
        """
        speaker_map = {}
        for speaker_str in self.speakers:
            try:
                parts = speaker_str.split(",", 2)
                if len(parts) >= 3:
                    speaker_id, name, description = parts[0].strip(), parts[1].strip(), parts[2].strip()
                    # A blank ID can never be matched by identify_speaker.
                    if not speaker_id:
                        continue
                    speaker_map[speaker_id] = {
                        "name": name,
                        "description": description
                    }
            except Exception as e:
                logger.bind(tag=TAG).warning(f"解析说话人配置失败: {speaker_str}, 错误: {e}")
        return speaker_map

    def _check_server_health(self) -> bool:
        """Check whether the voiceprint server reports itself as healthy.

        The result (healthy or not) is stored through the global cache
        manager under ``CacheType.VOICEPRINT_HEALTH`` and reused while a
        cached value is available.

        Returns:
            bool: True if the server answered HTTP 200 with
            ``{"status": "healthy"}``, otherwise False.
        """
        if not self.api_url or not self.api_key:
            return False

        cache_key = f"{self.api_url}:{self.api_key}"

        # Reuse a cached result when there is one.
        cached_result = cache_manager.get(CacheType.VOICEPRINT_HEALTH, cache_key)
        if cached_result is not None:
            logger.bind(tag=TAG).debug(f"使用缓存的健康状态: {cached_result}")
            return cached_result

        # Cache expired or missing: perform a real check.
        logger.bind(tag=TAG).info("执行声纹服务器健康检查")

        try:
            parsed_url = urlparse(self.api_url)
            health_url = f"{parsed_url.scheme}://{parsed_url.netloc}/voiceprint/health"

            # The key was URL-decoded by parse_qs, so it is passed via
            # ``params`` to be encoded again instead of being pasted raw
            # into the query string.
            response = requests.get(health_url, params={"key": self.api_key}, timeout=3)

            if response.status_code == 200:
                result = response.json()
                if result.get("status") == "healthy":
                    logger.bind(tag=TAG).info("声纹识别服务器健康检查通过")
                    is_healthy = True
                else:
                    logger.bind(tag=TAG).warning(f"声纹识别服务器状态异常: {result}")
                    is_healthy = False
            else:
                logger.bind(tag=TAG).warning(f"声纹识别服务器健康检查失败: HTTP {response.status_code}")
                is_healthy = False

        except requests.exceptions.ConnectTimeout:
            logger.bind(tag=TAG).warning("声纹识别服务器连接超时")
            is_healthy = False
        except requests.exceptions.ConnectionError:
            logger.bind(tag=TAG).warning("声纹识别服务器连接被拒绝")
            is_healthy = False
        except Exception as e:
            logger.bind(tag=TAG).warning(f"声纹识别服务器健康检查异常: {e}")
            is_healthy = False

        # Store the outcome through the global cache manager.
        cache_manager.set(CacheType.VOICEPRINT_HEALTH, cache_key, is_healthy)
        logger.bind(tag=TAG).info(f"健康检查结果已缓存: {is_healthy}")

        return is_healthy

    async def identify_speaker(self, audio_data: bytes, session_id: str) -> Optional[str]:
        """Identify the speaker of an audio clip.

        Args:
            audio_data: Audio bytes, uploaded as ``audio.wav`` with content
                type ``audio/wav``.
            session_id: Session identifier; currently not used by this method.

        Returns:
            The configured speaker name on a successful match; the string
            ``"未知说话人"`` when the score is below the similarity threshold
            or the returned ID is not configured; ``None`` when the provider
            is disabled or the request fails (HTTP error, timeout, or any
            other exception).
        """
        if not self.enabled or not self.api_url or not self.api_key:
            logger.bind(tag=TAG).debug("声纹识别功能已禁用或未配置,跳过识别")
            return None

        # Bound before the try block so the timeout handler can always use it.
        api_start_time = time.monotonic()
        try:
            # Request headers.
            headers = {
                'Authorization': f'Bearer {self.api_key}',
                'Accept': 'application/json'
            }

            # multipart/form-data payload.
            data = aiohttp.FormData()
            data.add_field('speaker_ids', ','.join(self.speaker_ids))
            data.add_field('file', audio_data, filename='audio.wav', content_type='audio/wav')

            timeout = aiohttp.ClientTimeout(total=10)

            # Network request.
            async with aiohttp.ClientSession(timeout=timeout) as session:
                async with session.post(self.api_url, headers=headers, data=data) as response:

                    if response.status == 200:
                        result = await response.json()
                        speaker_id = result.get("speaker_id")
                        score = result.get("score", 0)
                        total_elapsed_time = time.monotonic() - api_start_time

                        logger.bind(tag=TAG).info(f"声纹识别耗时: {total_elapsed_time:.3f}s")

                        # Reject matches below the similarity threshold.
                        if score < self.similarity_threshold:
                            logger.bind(tag=TAG).warning(f"声纹识别相似度{score:.3f}低于阈值{self.similarity_threshold}")
                            return "未知说话人"

                        if speaker_id and speaker_id in self.speaker_map:
                            result_name = self.speaker_map[speaker_id]["name"]
                            logger.bind(tag=TAG).info(f"声纹识别成功: {result_name} (相似度: {score:.3f})")
                            return result_name
                        else:
                            logger.bind(tag=TAG).warning(f"未识别的说话人ID: {speaker_id}")
                            return "未知说话人"
                    else:
                        logger.bind(tag=TAG).error(f"声纹识别API错误: HTTP {response.status}")
                        return None

        except asyncio.TimeoutError:
            elapsed = time.monotonic() - api_start_time
            logger.bind(tag=TAG).error(f"声纹识别超时: {elapsed:.3f}s")
            return None
        except Exception as e:
            logger.bind(tag=TAG).error(f"声纹识别失败: {e}")
            return None
|
||||
|
||||
@@ -0,0 +1,140 @@
|
||||
import os
|
||||
import re
|
||||
import yaml
|
||||
import time
|
||||
import hashlib
|
||||
import portalocker
|
||||
from typing import Dict
|
||||
|
||||
|
||||
class FileLock:
    """Context manager that holds an exclusive ``portalocker`` lock on a file.

    Acquisition is non-blocking and retried every 0.1 seconds until it
    succeeds or ``timeout`` seconds have passed.
    """

    def __init__(self, file, timeout=5):
        """Store the file object to lock and the acquisition timeout (seconds)."""
        self.file = file
        self.timeout = timeout
        # Set when __enter__ starts waiting for the lock.
        self.start_time = None

    def __enter__(self):
        """Acquire the lock and return the file object.

        Raises:
            TimeoutError: If the lock is still unavailable after more than
                ``timeout`` seconds.
        """
        self.start_time = time.time()
        lock_flags = portalocker.LOCK_EX | portalocker.LOCK_NB
        while True:
            try:
                portalocker.lock(self.file, lock_flags)
            except portalocker.LockException:
                waited = time.time() - self.start_time
                if waited > self.timeout:
                    raise TimeoutError("获取文件锁超时")
                # Lock is held elsewhere; back off briefly before retrying.
                time.sleep(0.1)
            else:
                return self.file

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Release the lock; exceptions from the with-body are not suppressed."""
        portalocker.unlock(self.file)
|
||||
|
||||
|
||||
class WakeupWordsConfig:
    """Store for wake-word reply settings, kept in a YAML file.

    Entries are keyed by the MD5 hex digest of the voice name. File access
    is guarded by ``FileLock`` and reads are cached for a short time.
    """

    def __init__(self):
        """Set the file locations and create the required directories."""
        self.config_file = "data/.wakeup_words.yaml"
        self.assets_dir = "config/assets/wakeup_words"
        self._ensure_directories()
        self._config_cache = None
        self._last_load_time = 0
        self._cache_ttl = 1  # cache lifetime in seconds
        self._lock_timeout = 5  # file lock timeout in seconds

    def _ensure_directories(self):
        """Make sure the config file's directory and the assets directory exist."""
        os.makedirs(os.path.dirname(self.config_file), exist_ok=True)
        os.makedirs(self.assets_dir, exist_ok=True)

    def _load_config(self) -> Dict:
        """Load the configuration, served from the cache while it is fresh.

        Returns:
            Dict: The parsed configuration. An empty dict is returned when
            the file is empty, does not contain a mapping, or cannot be read.
        """
        current_time = time.time()

        # Serve from the cache while it is still valid.
        if (
            self._config_cache is not None
            and current_time - self._last_load_time < self._cache_ttl
        ):
            return self._config_cache

        try:
            # "a+" creates the file when it does not exist yet.
            with open(self.config_file, "a+", encoding="utf-8") as f:
                with FileLock(f, timeout=self._lock_timeout):
                    f.seek(0)
                    content = f.read()
                    loaded = yaml.safe_load(content) if content else {}
                    # safe_load returns None for whitespace/comment-only
                    # files and may return a non-mapping; callers index the
                    # result as a dict, so normalise both cases to {}.
                    config = loaded if isinstance(loaded, dict) else {}
                    self._config_cache = config
                    self._last_load_time = current_time
                    return config
        except (TimeoutError, IOError) as e:
            print(f"加载配置文件失败: {e}")
            return {}
        except Exception as e:
            print(f"加载配置文件时发生未知错误: {e}")
            return {}

    def _save_config(self, config: Dict):
        """Write the configuration to the file under the file lock.

        Args:
            config: Configuration mapping to persist.

        Raises:
            Exception: Any error raised while locking or writing is printed
                and re-raised.
        """
        try:
            # Open with "a+" rather than "w": "w" truncates the file before
            # the lock is held, so a lock timeout would wipe existing data.
            with open(self.config_file, "a+", encoding="utf-8") as f:
                with FileLock(f, timeout=self._lock_timeout):
                    # Discard the old content only once the lock is ours.
                    f.seek(0)
                    f.truncate()
                    yaml.dump(config, f, allow_unicode=True)
                    # Flush before the lock is released so the next lock
                    # holder does not read a partially written file.
                    f.flush()
                    self._config_cache = config
                    self._last_load_time = time.time()
        except (TimeoutError, IOError) as e:
            print(f"保存配置文件失败: {e}")
            raise
        except Exception as e:
            print(f"保存配置文件时发生未知错误: {e}")
            raise

    def get_wakeup_response(self, voice: str) -> Dict:
        """Return the wake-word reply settings for a voice.

        Args:
            voice: Voice name; looked up by its MD5 hex digest.

        Returns:
            The stored entry, or ``None`` when there is no usable entry for
            the voice, or its audio file is missing or smaller than 15 KiB.
        """
        voice = hashlib.md5(voice.encode()).hexdigest()
        config = self._load_config()

        if not config or voice not in config:
            return None

        entry = config[voice]
        # An entry without a usable file path counts as missing.
        file_path = entry.get("file_path") if isinstance(entry, dict) else None
        if not file_path:
            return None

        # Check that the audio file exists and is large enough.
        if not os.path.exists(file_path) or os.stat(file_path).st_size < (15 * 1024):
            return None

        return entry

    def update_wakeup_response(self, voice: str, file_path: str, text: str):
        """Create or replace the wake-word reply settings for a voice.

        Args:
            voice: Voice name; stored under its MD5 hex digest.
            file_path: Path of the reply audio file.
            text: Reply text; characters in the two emoji ranges of the
                filter pattern are removed before saving.

        Raises:
            Exception: Any error is printed and re-raised.
        """
        try:
            # Strip emoji.
            filtered_text = re.sub(r'[\U0001F600-\U0001F64F\U0001F900-\U0001F9FF]', '', text)

            # Work on a copy so a failed save cannot leave an unsaved entry
            # in the cached configuration.
            config = dict(self._load_config())
            voice_hash = hashlib.md5(voice.encode()).hexdigest()
            config[voice_hash] = {
                "voice": voice,
                "file_path": file_path,
                "time": time.time(),
                "text": filtered_text,
            }
            self._save_config(config)
        except Exception as e:
            print(f"更新唤醒词回复配置失败: {e}")
            raise

    def generate_file_path(self, voice: str) -> str:
        """Build the audio file path for a voice; the file name is the voice's MD5 digest.

        An existing file at that path is deleted first.

        Args:
            voice: Voice name.

        Returns:
            str: ``<assets_dir>/<md5(voice)>.wav``.

        Raises:
            Exception: Any error (for example a failed delete) is printed
                and re-raised.
        """
        try:
            # Hash of the voice name.
            voice_hash = hashlib.md5(voice.encode()).hexdigest()
            file_path = os.path.join(self.assets_dir, f"{voice_hash}.wav")

            # Remove a file left over from an earlier run.
            if os.path.exists(file_path):
                try:
                    os.remove(file_path)
                except Exception as e:
                    print(f"删除已存在的音频文件失败: {e}")
                    raise

            return file_path
        except Exception as e:
            print(f"生成音频文件路径失败: {e}")
            raise
|
||||
Reference in New Issue
Block a user