仿生人AI服务端
This commit is contained in:
@@ -0,0 +1,354 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
import concurrent.futures
|
||||
from typing import Dict, Optional
|
||||
import aiohttp
|
||||
from tabulate import tabulate
|
||||
from core.utils.asr import create_instance as create_stt_instance
|
||||
|
||||
# Set the global log level to WARNING so INFO-level log output is suppressed
|
||||
logging.basicConfig(level=logging.WARNING)
|
||||
|
||||
description = "语音识别模型性能测试"
|
||||
|
||||
class ASRPerformanceTester:
    """Benchmark the configured ASR (speech-to-text) providers.

    Provider settings are read from ``*.config.yaml`` files under ``./data``
    and test audio from ``./config/assets`` (both relative to the current
    working directory). Each usable provider is run against every audio file
    and the average recognition time is printed as a table, fastest first.
    """

    # Seconds to wait for one recognition call before it counts as timed out.
    REQUEST_TIMEOUT = 10.0
    # A provider is reported as failed when fewer than this share of files succeed.
    MIN_SUCCESS_RATIO = 0.3
    # Only files strictly larger than this many bytes are used as test input.
    MIN_AUDIO_BYTES = 300 * 1024
    # Durations below this many seconds are treated as bogus measurements.
    MIN_VALID_DURATION = 0.001

    _TOKEN_FIELDS = ("access_token", "api_key", "token")
    _EMPTY_TOKEN_VALUES = frozenset({"", "none", "null"})
    _PLACEHOLDER_MARKERS = ("你的", "placeholder")

    def __init__(self):
        self.config = self._load_config_from_data_dir()
        self.test_wav_list = self._load_test_wav_files()
        self.results = {"stt": {}}

        # Print provider names only: the full config carries access tokens
        # that must not end up in console output.
        print(f"[DEBUG] 加载的ASR配置: {list(self.config.get('ASR', {}))}")
        print(f"[DEBUG] 音频文件数量: {len(self.test_wav_list)}")

    @staticmethod
    def _error_result(stt_name: str, error_type: str) -> Dict:
        """Build the result record used for a provider that could not be measured."""
        return {
            "name": stt_name,
            "type": "stt",
            "errors": 1,
            "error_type": error_type,
        }

    @staticmethod
    def _is_gateway_error(error: Exception) -> bool:
        """Return True when the exception text looks like an HTTP 502 failure."""
        error_msg = str(error).lower()
        return "502" in error_msg or "bad gateway" in error_msg

    @classmethod
    def _has_placeholder_token(cls, config) -> bool:
        """Return True when a credential field is empty or still a template value.

        Empty-like values (``""``, ``None``, ``"null"``) are matched exactly;
        the template markers are matched as substrings so that values such as
        ``"你的api_key"`` are recognised, like the LLM tester does.
        """
        if not isinstance(config, dict):
            return False
        for field in cls._TOKEN_FIELDS:
            if field not in config:
                continue
            value = str(config[field]).strip().lower()
            if value in cls._EMPTY_TOKEN_VALUES:
                return True
            if any(marker in value for marker in cls._PLACEHOLDER_MARKERS):
                return True
        return False

    def _load_config_from_data_dir(self) -> Dict:
        """Merge the ASR sections of every ``.config.yaml`` file under ``./data``.

        Returns:
            ``{"ASR": {...}}``; the inner mapping is empty when nothing was found.
            Unreadable or malformed files are reported and skipped.
        """
        config = {"ASR": {}}
        data_dir = os.path.join(os.getcwd(), "data")
        print(f"[DEBUG] 扫描配置文件目录: {data_dir}")

        try:
            import yaml
        except ImportError as e:
            print(f" 加载配置文件 {data_dir} 失败: {str(e)}")
            return config

        for root, _, files in os.walk(data_dir):
            for file in sorted(files):
                if not file.endswith(".config.yaml"):
                    continue
                file_path = os.path.join(root, file)
                try:
                    with open(file_path, "r", encoding="utf-8") as f:
                        file_config = yaml.safe_load(f)
                except Exception as e:
                    print(f" 加载配置文件 {file_path} 失败: {str(e)}")
                    continue

                # An empty YAML document loads as None.
                if not isinstance(file_config, dict):
                    continue
                # Accept both "ASR" and "asr" as the section key.
                asr_config = file_config.get("ASR") or file_config.get("asr")
                if isinstance(asr_config, dict) and asr_config:
                    config["ASR"].update(asr_config)
                    print(f"[DEBUG] 从 {file_path} 加载 ASR 配置成功")
        return config

    def _load_test_wav_files(self) -> list:
        """Read every sufficiently large file in ``./config/assets`` into memory.

        Returns:
            File contents as ``bytes``, ordered by file name so runs are
            repeatable. Sub-directories and files of at most
            ``MIN_AUDIO_BYTES`` bytes are skipped.
        """
        wav_root = os.path.join(os.getcwd(), "config", "assets")
        print(f"[DEBUG] 音频文件目录: {wav_root}")

        if not os.path.isdir(wav_root):
            print(f" 目录不存在: {wav_root}")
            return []

        file_list = sorted(os.listdir(wav_root))
        print(f"[DEBUG] 找到音频文件: {file_list}")

        test_wav_list = []
        for file_name in file_list:
            file_path = os.path.join(wav_root, file_name)
            if not os.path.isfile(file_path):
                continue
            if os.path.getsize(file_path) > self.MIN_AUDIO_BYTES:
                with open(file_path, "rb") as f:
                    test_wav_list.append(f.read())
        return test_wav_list

    async def _test_single_audio(self, stt_name: str, stt, audio_data: bytes) -> Optional[float]:
        """Time one recognition call.

        Returns:
            Elapsed seconds, or ``None`` when the provider returned no text,
            raised, or finished implausibly fast (< ``MIN_VALID_DURATION``).
        """
        try:
            # perf_counter is monotonic and high resolution, unlike time.time().
            start_time = time.perf_counter()
            text, _ = await stt.speech_to_text([audio_data], "1", stt.audio_format)
            duration = time.perf_counter() - start_time
        except Exception as e:
            if self._is_gateway_error(e):
                print(f"{stt_name} 遇到502错误")
            else:
                print(f"{stt_name} 识别异常: {str(e)}")
            return None

        if text is None:
            return None
        if abs(duration) < self.MIN_VALID_DURATION:
            print(f"{stt_name} 检测到异常时间: {duration:.6f}s (视为错误)")
            return None
        return duration

    def _recognize_blocking(self, stt_name: str, stt, audio_data: bytes) -> Optional[float]:
        """Worker-thread entry point: run one recognition on a private event loop."""
        return asyncio.run(self._test_single_audio(stt_name, stt, audio_data))

    async def _timed_recognition(self, stt_name: str, stt, audio_data: bytes) -> Optional[float]:
        """Run one recognition in a worker thread, bounded by ``REQUEST_TIMEOUT``.

        Raises:
            asyncio.TimeoutError: when the call does not finish in time.
        """
        loop = asyncio.get_running_loop()
        executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
        try:
            future = loop.run_in_executor(
                executor, self._recognize_blocking, stt_name, stt, audio_data
            )
            return await asyncio.wait_for(future, timeout=self.REQUEST_TIMEOUT)
        finally:
            # wait=False is essential: leaving a `with ThreadPoolExecutor()`
            # block joins the worker, which would stall the event loop until a
            # timed-out call finally returns and make the timeout useless.
            executor.shutdown(wait=False)

    async def _test_stt_with_timeout(self, stt_name: str, config: Dict) -> Dict:
        """Benchmark one provider against all test audio.

        Returns:
            On success a record with ``avg_time``, ``success_rate`` and
            ``errors == 0``; otherwise an error record (``errors == 1`` plus
            ``error_type``). This coroutine never raises.
        """
        try:
            if self._has_placeholder_token(config):
                print(f" STT {stt_name} 未配置有效access_token/api_key,已跳过")
                return self._error_result(stt_name, "配置错误")

            # Without this guard the probe below would fail with IndexError
            # and be misreported as a network error.
            if not self.test_wav_list:
                print(f" {stt_name} 没有可用的测试音频,已跳过")
                return self._error_result(stt_name, "无测试音频")

            module_type = config.get("type", stt_name)
            stt = create_stt_instance(module_type, config, delete_audio_file=True)
            stt.audio_format = "pcm"

            print(f" 测试 STT: {stt_name}")

            # Probe with the first file so unreachable providers fail fast.
            try:
                first_result = await self._timed_recognition(
                    stt_name, stt, self.test_wav_list[0]
                )
            except asyncio.TimeoutError:
                print(f" {stt_name} 连接超时({self.REQUEST_TIMEOUT:g}秒),跳过")
                return self._error_result(stt_name, "超时连接")
            except Exception as e:
                if self._is_gateway_error(e):
                    print(f" {stt_name} 遇到502错误,跳过")
                    return self._error_result(stt_name, "502网络错误")
                print(f" {stt_name} 连接异常: {str(e)}")
                return self._error_result(stt_name, "网络错误")

            if first_result is None:
                print(f" {stt_name} 连接失败")
                return self._error_result(stt_name, "网络错误")

            total_time = 0.0
            valid_tests = 0
            test_count = len(self.test_wav_list)

            for i, audio_data in enumerate(self.test_wav_list, 1):
                try:
                    duration = await self._timed_recognition(stt_name, stt, audio_data)
                except asyncio.TimeoutError:
                    print(f" {stt_name} [{i}/{test_count}] 超时({self.REQUEST_TIMEOUT:g}秒),跳过")
                    continue
                except Exception as e:
                    if self._is_gateway_error(e):
                        print(f" {stt_name} [{i}/{test_count}] 502错误,跳过")
                        return self._error_result(stt_name, "502网络错误")
                    print(f" {stt_name} [{i}/{test_count}] 异常: {str(e)}")
                    continue

                if duration is not None and duration > self.MIN_VALID_DURATION:
                    total_time += duration
                    valid_tests += 1
                    print(f" {stt_name} [{i}/{test_count}] 耗时: {duration:.2f}s")
                else:
                    print(f" {stt_name} [{i}/{test_count}] 测试失败(含0.000s异常)")

            if valid_tests == 0 or valid_tests < test_count * self.MIN_SUCCESS_RATIO:
                print(f" {stt_name} 成功测试过少({valid_tests}/{test_count}),可能网络不稳定")
                return self._error_result(stt_name, "网络错误")

            return {
                "name": stt_name,
                "type": "stt",
                "avg_time": total_time / valid_tests,
                "success_rate": f"{valid_tests}/{test_count}",
                "errors": 0,
            }

        except Exception as e:
            error_msg = str(e).lower()
            if self._is_gateway_error(e):
                error_type = "502网络错误"
            elif "timeout" in error_msg:
                error_type = "超时连接"
            else:
                error_type = "网络错误"
            print(f"⚠️ {stt_name} 测试失败: {str(e)}")
            return self._error_result(stt_name, error_type)

    def _print_results(self):
        """Print the result table: successful providers by speed, failures last."""
        print("\n" + "=" * 50)
        print("ASR 性能测试结果")
        print("=" * 50)

        stt_results = self.results.get("stt")
        if not stt_results:
            print("没有可用的测试结果")
            return

        headers = ["模型名称", "平均耗时(s)", "成功率", "状态"]

        ok_rows = []  # (sort key, table row)
        error_rows = []
        for name, data in stt_results.items():
            if data["errors"] == 0:
                row = [
                    name,
                    f"{data['avg_time']:.3f}",
                    data.get("success_rate", "N/A"),
                    "✅ 正常",
                ]
                ok_rows.append((data["avg_time"], row))
            else:
                error_type = data.get("error_type", "网络错误")
                error_rows.append([name, "-", "0/N", f"❌ {error_type}"])

        # Fastest average first; failed providers are appended afterwards.
        ok_rows.sort(key=lambda item: item[0])
        table_data = [row for _, row in ok_rows] + error_rows

        print(tabulate(table_data, headers=headers, tablefmt="grid"))
        print("\n测试说明:")
        print(f"- 超时控制:单个音频最大等待时间为{self.REQUEST_TIMEOUT:g}秒")
        print("- 错误处理:自动跳过502错误、超时和网络异常的模型")
        print("- 成功率:成功识别的音频数量/总测试音频数量")
        print("- 排序规则:按平均耗时从快到慢排序,错误模型排最后")
        print("\n测试完成!")

    async def run(self):
        """Benchmark every usable provider concurrently and print the table."""
        print("开始筛选可用ASR模块...")
        if not self.config.get("ASR"):
            print("配置中未找到 ASR 模块")
            return

        if not self.test_wav_list:
            print("未找到可用的测试音频文件,无法进行测试。")
            return

        all_tasks = []
        for stt_name, config in self.config["ASR"].items():
            if not isinstance(config, dict):
                print(f"ASR {stt_name} 配置格式无效,已跳过")
                continue
            if self._has_placeholder_token(config):
                print(f"ASR {stt_name} 未配置有效access_token/api_key,已跳过")
                continue

            print(f"添加 ASR 测试任务: {stt_name}")
            all_tasks.append(self._test_stt_with_timeout(stt_name, config))

        if not all_tasks:
            print("没有可用的ASR模块进行测试。")
            return

        print(f"\n找到 {len(all_tasks)} 个可用ASR模块")
        print("\n开始并发测试所有ASR模块...")
        all_results = await asyncio.gather(*all_tasks, return_exceptions=True)

        for result in all_results:
            if isinstance(result, dict) and result.get("type") == "stt":
                self.results["stt"][result["name"]] = result
            elif isinstance(result, BaseException):
                # Surface unexpected failures instead of dropping them silently.
                print(f"测试任务异常: {str(result)}")

        self._print_results()
|
||||
|
||||
|
||||
async def main():
    """Create the ASR benchmark and run it to completion."""
    await ASRPerformanceTester().run()


if __name__ == "__main__":
    asyncio.run(main())
|
||||
@@ -0,0 +1,544 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import statistics
|
||||
import time
|
||||
import concurrent.futures
|
||||
from typing import Dict, Optional
|
||||
import yaml
|
||||
import aiohttp
|
||||
from tabulate import tabulate
|
||||
from core.utils.llm import create_instance as create_llm_instance
|
||||
from config.settings import load_config
|
||||
|
||||
# Set the global log level to WARNING so INFO-level log output is suppressed
|
||||
logging.basicConfig(level=logging.WARNING)
|
||||
|
||||
description = "大语言模型性能测试"
|
||||
|
||||
|
||||
class LLMPerformanceTester:
    """Benchmark the configured LLM providers on a small set of chat prompts.

    Each provider receives every test sentence (prefixed by the system prompt)
    concurrently; time-to-first-token and total response time are averaged and
    printed as a table ordered by first-token latency.
    """

    # Seconds allowed for one sentence before it counts as timed out.
    REQUEST_TIMEOUT = 10.0
    # Seconds allowed for a whole provider (all sentences).
    MODEL_TIMEOUT = 30.0
    # Seconds allowed for the Ollama availability probe.
    OLLAMA_PROBE_TIMEOUT = 5.0
    # A provider is reported as failed below this share of successful sentences.
    MIN_SUCCESS_RATIO = 0.3

    _PLACEHOLDER_MARKERS = ("你的", "placeholder", "sk-xxx")
    # Providers report failures as ordinary text chunks containing these words.
    _ERROR_MARKERS = ("异常", "错误")
    _DEFAULT_SENTENCES = (
        "你好,我今天心情不太好,能安慰一下我吗?",
        "帮我查一下明天的天气如何?",
        "我想听一个有趣的故事,你能给我讲一个吗?",
        "现在几点了?今天是星期几?",
        "我想设置一个明天早上8点的闹钟提醒我开会",
    )
    # Template variables of the prompt file and the fixed test values used for them.
    _PROMPT_SUBSTITUTIONS = (
        ("{{base_prompt}}", "你是小智,一个聪明可爱的AI助手"),
        ("{{emojiList}}", "😀,😃,😄,😁,😊,😍,🤔,😮,😱,😢,😭,😴,😵,🤗,🙄"),
        ("{{current_time}}", "2024年8月17日 12:30:45"),
        ("{{today_date}}", "2024年8月17日"),
        ("{{today_weekday}}", "星期六"),
        ("{{lunar_date}}", "甲辰年七月十四"),
        ("{{local_address}}", "北京市"),
        ("{{weather_info}}", "今天晴,25-32℃"),
    )

    def __init__(self):
        self.config = load_config()
        # A full agent-style system prompt makes the test closer to real usage.
        self.system_prompt = self._load_system_prompt()
        module_test = self.config.get("module_test") or {}
        self.test_sentences = list(
            module_test.get("test_sentences") or self._DEFAULT_SENTENCES
        )
        self.results = {}

    @staticmethod
    def _error_result(llm_name: str, error_type: str) -> Dict:
        """Build the result record used for a provider that could not be measured."""
        return {
            "name": llm_name,
            "type": "llm",
            "errors": 1,
            "error_type": error_type,
        }

    @staticmethod
    def _is_gateway_error(message) -> bool:
        """Return True when the text looks like an HTTP 502 failure."""
        text = str(message).lower()
        return "502" in text or "bad gateway" in text

    @classmethod
    def _has_placeholder(cls, value) -> bool:
        """Return True when a config value still contains a template marker."""
        if value is None:
            return False
        text = str(value)
        return any(marker in text for marker in cls._PLACEHOLDER_MARKERS)

    def _load_system_prompt(self) -> str:
        """Load the prompt template and fill its variables with fixed test values.

        Falls back to a short built-in prompt when the file cannot be read.
        """
        try:
            prompt_file = os.path.join(
                os.path.dirname(os.path.dirname(__file__)),
                self.config.get("prompt_template", "agent-base-prompt.txt"),
            )
            with open(prompt_file, "r", encoding="utf-8") as f:
                content = f.read()
            for variable, value in self._PROMPT_SUBSTITUTIONS:
                content = content.replace(variable, value)
            return content
        except Exception as e:
            print(f"无法加载系统提示词文件: {e}")
            return "你是小智,一个聪明可爱的AI助手。请用温暖友善的语气回复用户。"

    def _collect_response_sync(self, llm, messages, llm_name, sentence_start, stop_event=None):
        """Drain a provider's streaming response (blocking; meant for a worker thread).

        Args:
            llm: Provider object exposing ``response(session_id, messages)``.
            messages: Chat messages to send.
            llm_name: Name used in log output.
            sentence_start: ``time.time()`` value taken when the request began.
            stop_event: Optional ``threading.Event``; once set, streaming stops
                at the next chunk so an abandoned request ends early.

        Returns:
            ``(chunks, first_token_time)``; ``first_token_time`` is ``None`` if
            no non-blank chunk arrived.

        Raises:
            Exception: when a chunk or a raised error looks like a provider
                failure (502 / error wording). Other errors yield partial results.
        """
        chunks = []
        first_token_time = None

        try:
            for chunk in llm.response("perf_test", messages):
                if stop_event is not None and stop_event.is_set():
                    break
                if chunk is None:
                    continue

                chunk_str = str(chunk)
                if "502" in chunk_str or any(m in chunk_str for m in self._ERROR_MARKERS):
                    print(f"{llm_name} 响应包含错误信息: {chunk_str.lower()}")
                    raise Exception(chunk_str)

                if first_token_time is None and chunk_str.strip() != "":
                    first_token_time = time.time() - sentence_start
                    print(f"{llm_name} 首个 Token: {first_token_time:.3f}s")
                chunks.append(chunk)
        except Exception as e:
            error_text = str(e)
            print(f"{llm_name} 响应收集异常: {error_text.lower()}")
            # Provider failures go to the caller; anything else keeps the
            # chunks collected so far.
            if self._is_gateway_error(error_text) or any(
                m in error_text for m in self._ERROR_MARKERS
            ):
                raise
            return chunks, first_token_time

        return chunks, first_token_time

    async def _check_ollama_service(self, base_url: str, model_name: str) -> bool:
        """Return True when the Ollama server answers and lists ``model_name``."""
        # Bound the probe so an unresponsive host cannot stall the whole run.
        timeout = aiohttp.ClientTimeout(total=self.OLLAMA_PROBE_TIMEOUT)
        async with aiohttp.ClientSession(timeout=timeout) as session:
            try:
                async with session.get(f"{base_url}/api/version") as response:
                    if response.status != 200:
                        print(f"Ollama 服务未启动或无法访问: {base_url}")
                        return False

                async with session.get(f"{base_url}/api/tags") as response:
                    if response.status != 200:
                        print("无法获取 Ollama 模型列表")
                        return False
                    data = await response.json()

                models = data.get("models", [])
                if not any(model.get("name") == model_name for model in models):
                    print(
                        f"Ollama 模型 {model_name} 未找到,请先使用 `ollama pull {model_name}` 下载"
                    )
                    return False
                return True
            except Exception as e:
                print(f"无法连接到 Ollama 服务: {str(e)}")
                return False

    async def _test_single_sentence(
        self, llm_name: str, llm, sentence: str
    ) -> Optional[Dict]:
        """Measure one prompt against one provider.

        Returns:
            A record with ``first_token_time`` and ``response_time`` on success,
            a ``502网络错误`` error record on a gateway failure, or ``None`` on
            timeout or any other failure.
        """
        import threading

        print(f"{llm_name} 开始测试: {sentence[:20]}...")
        sentence_start = time.time()
        messages = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": sentence},
        ]

        stop_event = threading.Event()
        loop = asyncio.get_running_loop()
        executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
        try:
            future = loop.run_in_executor(
                executor,
                self._collect_response_sync,
                llm,
                messages,
                llm_name,
                sentence_start,
                stop_event,
            )
            _, first_token_time = await asyncio.wait_for(
                future, timeout=self.REQUEST_TIMEOUT
            )
        except asyncio.TimeoutError:
            print(f"{llm_name} 测试超时({self.REQUEST_TIMEOUT:g}秒),跳过")
            return None
        except Exception as e:
            if self._is_gateway_error(e):
                print(f"{llm_name} 遇到502错误,跳过测试")
                return self._error_result(llm_name, "502网络错误")
            print(f"{llm_name} 句子测试失败: {str(e)}")
            return None
        finally:
            # A running thread cannot be cancelled, so ask it to stop and
            # detach. wait=False matters: joining the worker here (as a `with`
            # block does) would stall the event loop until the stream ends.
            stop_event.set()
            executor.shutdown(wait=False)

        response_time = time.time() - sentence_start
        print(f"{llm_name} 完成响应: {response_time:.3f}s")
        return {
            "name": llm_name,
            "type": "llm",
            "first_token_time": first_token_time,
            "response_time": response_time,
        }

    def _summarize_sentence_results(self, llm_name: str, sentence_results, total: int) -> Dict:
        """Fold per-sentence outcomes into one provider record.

        Args:
            llm_name: Provider name.
            sentence_results: Items returned by ``asyncio.gather`` - measurement
                dicts, error dicts, ``None`` or exceptions.
            total: Number of sentences attempted.
        """
        valid_results = []
        for result in sentence_results:
            if isinstance(result, dict):
                if result.get("errors"):
                    # A gateway failure on any sentence fails the provider.
                    return result
                valid_results.append(result)
            elif isinstance(result, Exception):
                if self._is_gateway_error(result):
                    print(f"{llm_name} 遇到502错误,跳过该句子测试")
                    return self._error_result(llm_name, "502网络错误")
                print(f"{llm_name} 句子测试异常: {result}")

        if not valid_results:
            print(f"{llm_name} 无有效数据,可能遇到网络问题或配置错误")
            return self._error_result(llm_name, "网络错误")

        if len(valid_results) < total * self.MIN_SUCCESS_RATIO:
            print(
                f"{llm_name} 成功测试句子过少({len(valid_results)}/{total}),可能网络不稳定或接口有问题"
            )
            return self._error_result(llm_name, "网络错误")

        first_token_times = [
            r["first_token_time"] for r in valid_results if r.get("first_token_time")
        ]
        response_times = [r["response_time"] for r in valid_results]

        # Drop outliers above mean + 3 standard deviations. The minimum never
        # exceeds the mean, so the filtered list cannot be empty.
        if len(response_times) > 1:
            limit = statistics.mean(response_times) + 3 * statistics.stdev(response_times)
            filtered_times = [t for t in response_times if t <= limit]
        else:
            filtered_times = response_times

        return {
            "name": llm_name,
            "type": "llm",
            "avg_response": sum(filtered_times) / len(filtered_times),
            "avg_first_token": (
                sum(first_token_times) / len(first_token_times)
                if first_token_times
                else 0
            ),
            "success_rate": f"{len(valid_results)}/{total}",
            "errors": 0,
        }

    async def _test_llm(self, llm_name: str, config: Dict) -> Dict:
        """Benchmark one provider; always returns a record, never raises."""
        try:
            if llm_name == "Ollama":
                # Ollama needs no api_key but must be running with the model pulled.
                base_url = config.get("base_url", "http://localhost:11434")
                model_name = config.get("model_name")
                if not model_name:
                    print("Ollama 未配置 model_name")
                    return self._error_result(llm_name, "配置错误")
                if not await self._check_ollama_service(base_url, model_name):
                    return self._error_result(llm_name, "网络错误")
            elif self._has_placeholder(config.get("api_key")):
                print(f"跳过未配置的 LLM: {llm_name}")
                return self._error_result(llm_name, "配置错误")

            # Older configs omit "type"; the section name is the type then.
            module_type = config.get("type", llm_name)
            llm = create_llm_instance(module_type, config)

            sentences = list(self.test_sentences)
            sentence_results = await asyncio.gather(
                *(self._test_single_sentence(llm_name, llm, s) for s in sentences),
                return_exceptions=True,
            )
            return self._summarize_sentence_results(
                llm_name, sentence_results, len(sentences)
            )
        except Exception as e:
            if self._is_gateway_error(e):
                print(f"LLM {llm_name} 遇到502错误,跳过测试")
            else:
                print(f"LLM {llm_name} 测试失败: {str(e)}")
            error_type = "超时连接" if "timeout" in str(e).lower() else "网络错误"
            return self._error_result(llm_name, error_type)

    async def _test_llm_with_deadline(self, llm_name: str, config: Dict) -> Dict:
        """Run ``_test_llm`` under ``MODEL_TIMEOUT``, keeping the provider name on failure."""
        try:
            return await asyncio.wait_for(
                self._test_llm(llm_name, config), timeout=self.MODEL_TIMEOUT
            )
        except asyncio.TimeoutError:
            print(f"{llm_name} 测试任务超时({self.MODEL_TIMEOUT:g}秒),跳过")
            return self._error_result(llm_name, "超时连接")
        except Exception as e:
            print(f"{llm_name} 测试任务异常: {str(e)}")
            return self._error_result(llm_name, "网络错误")

    def _print_results(self):
        """Print the result table: successful providers by first-token time, failures last."""
        print("\n" + "=" * 50)
        print("LLM 性能测试结果")
        print("=" * 50)

        if not self.results:
            print("没有可用的测试结果")
            return

        headers = ["模型名称", "平均响应时间(s)", "首Token时间(s)", "成功率", "状态"]
        failed_rate = f"0/{len(self.test_sentences)}"

        ok_rows = []  # (sort key, table row)
        error_rows = []
        for name, data in self.results.items():
            if data["errors"] == 0:
                has_first_token = data["avg_first_token"] > 0
                row = [
                    name,
                    f"{data['avg_response']:.3f}",
                    f"{data['avg_first_token']:.3f}" if has_first_token else "-",
                    data.get("success_rate", "N/A"),
                    "✅ 正常",
                ]
                # Providers without a first-token measurement sort last.
                sort_key = data["avg_first_token"] if has_first_token else float("inf")
                ok_rows.append((sort_key, row))
            else:
                error_type = data.get("error_type", "网络错误")
                error_rows.append([name, "-", "-", failed_rate, f"❌ {error_type}"])

        ok_rows.sort(key=lambda item: item[0])
        table_data = [row for _, row in ok_rows] + error_rows

        print(tabulate(table_data, headers=headers, tablefmt="grid"))
        print("\n测试说明:")
        print("- 测试内容:包含完整系统提示词的智能体对话场景")
        print(f"- 超时控制:单个请求最大等待时间为{self.REQUEST_TIMEOUT:g}秒")
        print("- 错误处理:自动跳过502错误和网络异常的模型")
        print("- 成功率:成功响应的句子数量/总测试句子数量")
        print("\n测试完成!")

    async def run(self):
        """Benchmark every usable provider concurrently and print the table."""
        print("开始筛选可用 LLM 模块...")

        all_tasks = []
        for llm_name, config in (self.config.get("LLM") or {}).items():
            if not isinstance(config, dict):
                print(f"LLM {llm_name} 配置格式无效,已跳过")
                continue

            if llm_name == "CozeLLM":
                if "你的" in str(config.get("bot_id", "")) or "你的" in str(
                    config.get("user_id", "")
                ):
                    print(f"LLM {llm_name} 未配置 bot_id/user_id,已跳过")
                    continue
            elif self._has_placeholder(config.get("api_key")):
                print(f"LLM {llm_name} 未配置 api_key,已跳过")
                continue

            if llm_name == "Ollama":
                base_url = config.get("base_url", "http://localhost:11434")
                model_name = config.get("model_name")
                if not model_name:
                    print("Ollama 未配置 model_name")
                    continue
                if not await self._check_ollama_service(base_url, model_name):
                    continue

            print(f"添加 LLM 测试任务: {llm_name}")
            all_tasks.append(self._test_llm_with_deadline(llm_name, config))

        print(f"\n找到 {len(all_tasks)} 个可用 LLM 模块")
        print("\n开始并发测试所有模块...\n")

        all_results = await asyncio.gather(*all_tasks, return_exceptions=True)

        for result in all_results:
            if isinstance(result, dict):
                # Failures are recorded too so they show up in the table.
                self.results[result["name"]] = result
            elif isinstance(result, Exception):
                print(f"测试结果处理异常: {str(result)}")

        print("\n生成测试报告...")
        self._print_results()
|
||||
|
||||
|
||||
async def main():
    """Create the LLM benchmark and run it to completion."""
    await LLMPerformanceTester().run()


if __name__ == "__main__":
    asyncio.run(main())
|
||||
@@ -0,0 +1,473 @@
|
||||
import asyncio
|
||||
import time
|
||||
import json
|
||||
import uuid
|
||||
import os
|
||||
import websockets
|
||||
import gzip
|
||||
import random
|
||||
from urllib import parse
|
||||
from tabulate import tabulate
|
||||
from config.settings import load_config
|
||||
import tempfile
|
||||
import wave
|
||||
import hmac
|
||||
import base64
|
||||
import hashlib
|
||||
from datetime import datetime
|
||||
from wsgiref.handlers import format_date_time
|
||||
from time import mktime
|
||||
description = "流式ASR首词延迟测试"
|
||||
try:
|
||||
import dashscope
|
||||
except ImportError:
|
||||
dashscope = None
|
||||
|
||||
class BaseASRTester:
    """Shared setup for streaming-ASR first-response latency testers.

    Loads the project configuration, picks the ``ASR`` sub-section named by
    ``config_key`` and reads the test audio from ``./config/assets``.
    Subclasses implement :meth:`test`.
    """

    # File extensions (lower case) accepted as test audio.
    AUDIO_EXTENSIONS = (".wav", ".pcm")

    def __init__(self, config_key: str):
        self.config = load_config()
        self.config_key = config_key
        # Tolerate a missing or null "ASR" section / provider entry.
        self.asr_config = (self.config.get("ASR") or {}).get(config_key) or {}
        self.test_audio_files = self._load_test_audio_files()
        self.results = []

    def _load_test_audio_files(self):
        """Read all ``.wav`` / ``.pcm`` files from ``./config/assets``.

        Returns:
            A list of ``{"data": bytes, "path": str, "name": str}`` dicts,
            ordered by file name so that "the first test file" is the same on
            every run (``os.listdir`` order is arbitrary).
        """
        audio_root = os.path.join(os.getcwd(), "config", "assets")
        test_files = []
        if not os.path.isdir(audio_root):
            return test_files

        for file_name in sorted(os.listdir(audio_root)):
            if not file_name.lower().endswith(self.AUDIO_EXTENSIONS):
                continue
            file_path = os.path.join(audio_root, file_name)
            if not os.path.isfile(file_path):
                continue
            with open(file_path, "rb") as f:
                test_files.append(
                    {"data": f.read(), "path": file_path, "name": file_name}
                )
        return test_files

    async def test(self, test_count=5):
        """Run ``test_count`` measurements; must be implemented by subclasses."""
        raise NotImplementedError

    def _calculate_result(self, service_name, latencies, test_count):
        """Average the successful latencies into a result record.

        Args:
            service_name: Display name of the service.
            latencies: One value per attempt; non-positive values mark failures.
            test_count: Number of attempts made.

        Returns:
            ``{"name", "latency", "status"}``; ``latency`` is 0 when every
            attempt failed.
        """
        valid_latencies = [value for value in latencies if value > 0]
        if valid_latencies:
            avg_latency = sum(valid_latencies) / len(valid_latencies)
            status = f"成功({len(valid_latencies)}/{test_count}次有效)"
        else:
            avg_latency = 0
            status = "失败: 所有测试均失败"
        return {"name": service_name, "latency": avg_latency, "status": status}
|
||||
|
||||
|
||||
class DoubaoStreamASRTester(BaseASRTester):
    """Measure the time to the first reply of the Doubao streaming ASR WebSocket API.

    One measurement covers: opening the WebSocket, sending the session request,
    receiving its reply, sending the whole test audio in one frame and
    receiving the first reply to it.
    """

    WS_URL = "wss://openspeech.bytedance.com/api/v3/sauc/bigmodel"
    # Seconds to wait for each server reply. Pings are disabled on the
    # connection, so without this a silent server would block forever.
    RECV_TIMEOUT = 10.0

    def __init__(self):
        super().__init__("DoubaoStreamASR")

    def _generate_header(self):
        """Return the 4-byte frame header used for requests."""
        header = bytearray()
        header.append((0x01 << 4) | 0x01)
        header.append((0x01 << 4) | 0x00)
        header.append((0x01 << 4) | 0x01)
        header.append(0x00)
        return header

    def _generate_audio_default_header(self):
        """Return the header used for the audio frame.

        NOTE(review): this is identical to the session-request header; confirm
        against the service's protocol reference whether audio frames require
        a different message-type nibble.
        """
        return self._generate_header()

    def _build_frame(self, header, payload: bytes) -> bytearray:
        """Concatenate header, 4-byte big-endian payload length and payload."""
        frame = bytearray(header)
        frame.extend(len(payload).to_bytes(4, "big"))
        frame.extend(payload)
        return frame

    @staticmethod
    def _extract_pcm(audio_data: bytes) -> bytes:
        """Return the raw sample data of ``audio_data``.

        RIFF/WAV input is parsed with :mod:`wave`, so files whose header is not
        exactly 44 bytes (extra chunks) are handled. If parsing fails the
        fixed 44-byte offset is used as a fallback; non-RIFF input is returned
        unchanged.
        """
        if not audio_data.startswith(b"RIFF"):
            return audio_data
        import io

        try:
            with wave.open(io.BytesIO(audio_data), "rb") as wav_file:
                return wav_file.readframes(wav_file.getnframes())
        except (wave.Error, EOFError):
            return audio_data[44:]

    def _parse_response(self, res: bytes) -> dict:
        """Decode a server frame.

        Returns:
            ``{"code", "msg_length", "payload_msg"}`` for error frames (message
            type ``0x0F``), ``{"payload_msg": ...}`` for frames whose bytes
            from offset 12 are JSON, otherwise ``{"error": <reason>}``.
        """
        try:
            if len(res) < 4:
                return {"error": "响应数据长度不足"}
            header = res[:4]
            message_type = header[1] >> 4
            if message_type == 0x0F:
                code = int.from_bytes(res[4:8], "big", signed=False)
                msg_length = int.from_bytes(res[8:12], "big", signed=False)
                error_msg = json.loads(res[12:].decode("utf-8"))
                return {
                    "code": code,
                    "msg_length": msg_length,
                    "payload_msg": error_msg,
                }
            try:
                json_data = res[12:].decode("utf-8")
                return {"payload_msg": json.loads(json_data)}
            except (UnicodeDecodeError, json.JSONDecodeError):
                return {"error": "JSON解析失败"}
        except Exception:
            return {"error": "解析响应失败"}

    def _raise_on_error_frame(self, raw, stage: str):
        """Raise when ``raw`` is an error frame with a code other than 1000."""
        result = self._parse_response(raw)
        if "code" in result and result["code"] != 1000:
            payload = result.get("payload_msg")
            detail = payload.get("error", "未知错误") if isinstance(payload, dict) else "未知错误"
            raise Exception(f"{stage}: {detail}")

    async def _measure_once(self, appid, access_token, uid, audio_request) -> float:
        """Run one connect / init / audio round trip and return its duration in seconds."""
        start_time = time.perf_counter()

        headers = {
            "X-Api-App-Key": appid,
            "X-Api-Access-Key": access_token,
            "X-Api-Resource-Id": "volc.bigasr.sauc.duration",
            "X-Api-Connect-Id": str(uuid.uuid4()),
        }
        request_params = {
            "app": {"appid": appid, "token": access_token},
            "user": {"uid": uid},
            "request": {
                "reqid": str(uuid.uuid4()),
                "workflow": "audio_in,resample,partition,vad,fe,decode,itn,nlu_punctuate",
                "show_utterances": True,
                "result_type": "single",
                "sequence": 1,
            },
            "audio": {
                "format": "pcm",
                "codec": "pcm",
                "rate": 16000,
                "language": "zh-CN",
                "bits": 16,
                "channel": 1,
                "sample_rate": 16000,
            },
        }

        async with websockets.connect(
            self.WS_URL,
            additional_headers=headers,
            max_size=1000000000,
            ping_interval=None,
            ping_timeout=None,
            close_timeout=10,
        ) as ws:
            init_payload = gzip.compress(str.encode(json.dumps(request_params)))
            await ws.send(self._build_frame(self._generate_header(), init_payload))

            init_res = await asyncio.wait_for(ws.recv(), timeout=self.RECV_TIMEOUT)
            self._raise_on_error_frame(init_res, "初始化失败")

            await ws.send(audio_request)
            first_chunk = await asyncio.wait_for(ws.recv(), timeout=self.RECV_TIMEOUT)
            latency = time.perf_counter() - start_time

            # An error frame is not a recognition result and must not be
            # counted as a successful measurement.
            self._raise_on_error_frame(first_chunk, "识别失败")
            return latency

    async def test(self, test_count=5):
        """Run ``test_count`` measurements with the first test audio file.

        Returns:
            The record built by ``_calculate_result``, or a failure record when
            audio or credentials are missing. Failed attempts are logged and
            recorded as latency 0.
        """
        if not self.test_audio_files:
            return {"name": "豆包流式ASR", "latency": 0, "status": "失败: 未找到测试音频"}
        if not self.asr_config:
            return {"name": "豆包流式ASR", "latency": 0, "status": "失败: 未配置"}

        appid = self.asr_config.get("appid")
        access_token = self.asr_config.get("access_token")
        if not appid or not access_token:
            return {"name": "豆包流式ASR", "latency": 0, "status": "失败: 未配置"}
        uid = self.asr_config.get("uid", "streaming_asr_service")

        # The audio frame is identical for every attempt: build it once,
        # outside the timed section.
        pcm = self._extract_pcm(self.test_audio_files[0]["data"])
        audio_request = self._build_frame(
            self._generate_audio_default_header(), gzip.compress(pcm)
        )

        latencies = []
        for i in range(test_count):
            try:
                latencies.append(
                    await self._measure_once(appid, access_token, uid, audio_request)
                )
            except Exception as e:
                print(f"[豆包ASR] 第{i+1}次测试失败: {str(e)}")
                latencies.append(0)

        return self._calculate_result("豆包流式ASR", latencies, test_count)
|
||||
|
||||
|
||||
class QwenASRFlashTester(BaseASRTester):
    """Measure first-response latency of the DashScope ``qwen3-asr-flash`` model.

    The sample audio is written to a temporary WAV file, submitted through
    ``dashscope.MultiModalConversation.call`` in streaming mode, and the time
    until the first streamed chunk arrives is recorded.
    """

    def __init__(self):
        super().__init__("Qwen3ASRFlash")

    async def _test_single(self, audio_file_info):
        """Run one recognition request and time the first streamed chunk.

        Args:
            audio_file_info: Mapping whose ``'data'`` entry holds the audio
                bytes. NOTE(review): assumed to be 16 kHz / 16-bit / mono PCM,
                optionally preceded by a 44-byte RIFF header - confirm against
                the loader in the base class.

        Returns:
            Seconds elapsed between entering this method and receiving the
            first chunk of the streamed response.

        Raises:
            Exception: Wraps every failure (missing API key, missing
                ``dashscope`` package, API error, empty stream). The original
                error is chained as ``__cause__``.
        """
        start_time = time.time()
        temp_file_path = None

        try:
            # Validate prerequisites first so no temp file is created in vain.
            # ``asr_config`` may be empty when only the env variable is set.
            api_key = (self.asr_config or {}).get("api_key") or os.getenv("DASHSCOPE_API_KEY")
            if not api_key:
                raise ValueError("未配置 api_key")

            if dashscope is None:
                raise RuntimeError("未安装 dashscope 库")

            audio_data = audio_file_info['data']
            # Same convention as the sibling testers: drop the WAV header so
            # it is not written into the new file as audio samples.
            if audio_data.startswith(b'RIFF'):
                audio_data = audio_data[44:]

            with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as f:
                temp_file_path = f.name

            with wave.open(temp_file_path, 'wb') as wav_file:
                wav_file.setnchannels(1)
                wav_file.setsampwidth(2)
                wav_file.setframerate(16000)
                wav_file.writeframes(audio_data)

            messages = [
                {
                    "role": "user",
                    "content": [
                        {"audio": temp_file_path}
                    ]
                }
            ]

            dashscope.api_key = api_key

            # Synchronous SDK call: it is not awaited, so it blocks the event
            # loop while running. Acceptable here because tests run one by one.
            response = dashscope.MultiModalConversation.call(
                model="qwen3-asr-flash",
                messages=messages,
                result_format="message",
                asr_options={"enable_lid": True, "enable_itn": False},
                stream=True
            )

            # Only the arrival time of the first chunk matters.
            for _chunk in response:
                return time.time() - start_time

            raise Exception("流式结束,未收到任何响应")

        except Exception as e:
            raise Exception(f"通义ASR流式失败: {str(e)}") from e

        finally:
            if temp_file_path and os.path.exists(temp_file_path):
                try:
                    os.unlink(temp_file_path)
                except OSError:
                    # Best-effort cleanup; a leftover temp file is harmless.
                    pass

    async def test(self, test_count=5):
        """Run ``test_count`` requests and summarise the latencies.

        Args:
            test_count: Number of requests to issue.

        Returns:
            Result dict with ``name``, ``latency`` and ``status`` keys, either
            an early failure record or the value of ``_calculate_result``.
        """
        if not self.test_audio_files:
            return {"name": "通义千问ASR", "latency": 0, "status": "失败: 未找到测试音频"}
        if not self.asr_config and not os.getenv("DASHSCOPE_API_KEY"):
            return {"name": "通义千问ASR", "latency": 0, "status": "失败: 未配置 api_key"}

        latencies = []
        for i in range(test_count):
            try:
                latency = await self._test_single(self.test_audio_files[0])
                latencies.append(latency)
            except Exception as e:
                # Report the reason, as the other testers do, instead of
                # silently recording a failed run.
                print(f"[通义ASR] 第{i+1}次测试失败: {str(e)}")
                latencies.append(0)

        return self._calculate_result("通义千问ASR", latencies, test_count)
|
||||
|
||||
|
||||
class XunfeiStreamASRTester(BaseASRTester):
    """First-packet latency tester for the iFlytek (Xunfei) streaming ASR WebSocket API."""

    def __init__(self):
        super().__init__("XunfeiStreamASR")

    def _create_url(self):
        """Return the WebSocket URL carrying the HMAC-SHA256 signed auth query.

        Reads ``api_secret`` and ``api_key`` from ``self.asr_config``.
        """
        host = "iat.cn-huabei-1.xf-yun.com"

        # RFC 1123 timestamp of the current time.
        date = format_date_time(mktime(datetime.now().timetuple()))

        # String to sign: host line, date line, request line.
        signature_origin = "\n".join((
            f"host: {host}",
            f"date: {date}",
            "GET /v1 HTTP/1.1",
        ))

        digest = hmac.new(
            self.asr_config["api_secret"].encode('utf-8'),
            signature_origin.encode('utf-8'),
            digestmod=hashlib.sha256,
        ).digest()
        signature = base64.b64encode(digest).decode(encoding='utf-8')

        api_key = self.asr_config["api_key"]
        authorization_origin = (
            f'api_key="{api_key}", algorithm="hmac-sha256", '
            f'headers="host date request-line", signature="{signature}"'
        )
        authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')

        # Auth parameters travel in the query string.
        query = parse.urlencode({
            "authorization": authorization,
            "date": date,
            "host": host,
        })
        return 'ws://' + host + '/v1' + '?' + query

    async def test(self, test_count=5):
        """Send the first audio frame ``test_count`` times and time the first reply.

        Args:
            test_count: Number of connections to open.

        Returns:
            Result dict with ``name``, ``latency`` and ``status`` keys.
        """
        service = "讯飞流式ASR"
        if not self.test_audio_files:
            return {"name": service, "latency": 0, "status": "失败: 未找到测试音频"}
        if not self.asr_config:
            return {"name": service, "latency": 0, "status": "失败: 未配置"}

        # All credentials must be present before any connection is attempted.
        missing = next(
            (key for key in ("app_id", "api_key", "api_secret") if key not in self.asr_config),
            None,
        )
        if missing is not None:
            return {"name": service, "latency": 0, "status": f"失败: 缺少配置项 {missing}"}

        latencies = []
        for i in range(test_count):
            try:
                ws_url = self._create_url()

                audio_data = self.test_audio_files[0]['data']
                if audio_data.startswith(b'RIFF'):
                    # Skip the WAV file header.
                    audio_data = audio_data[44:]

                cfg = self.asr_config
                first_frame_data = {
                    # status 0 marks the first frame of the stream.
                    "header": {"status": 0, "app_id": cfg["app_id"]},
                    "parameter": {
                        "iat": {
                            "domain": cfg.get("domain", "slm"),
                            "language": cfg.get("language", "zh_cn"),
                            "accent": cfg.get("accent", "mandarin"),
                            "dwa": cfg.get("dwa", "wpgs"),
                            "result": {
                                "encoding": "utf8",
                                "compress": "raw",
                                "format": "plain",
                            },
                        }
                    },
                    "payload": {
                        "audio": {
                            "audio": base64.b64encode(audio_data[:960]).decode('utf-8'),
                            "sample_rate": 16000,
                            "encoding": "raw",
                        }
                    },
                }

                # The clock starts before the connection is opened.
                start_time = time.time()

                async with websockets.connect(
                    ws_url,
                    max_size=1000000000,
                    ping_interval=None,
                    ping_timeout=None,
                    close_timeout=30,
                ) as ws:
                    await ws.send(json.dumps(first_frame_data, ensure_ascii=False))
                    print(f"[讯飞ASR] 第{i+1}次测试:已发送首帧,等待响应...")

                    # Whatever arrives first counts as the first packet; its
                    # content is not inspected.
                    try:
                        await asyncio.wait_for(ws.recv(), timeout=30.0)
                    except asyncio.TimeoutError:
                        print(f"[讯飞ASR] 第{i+1}次测试:响应超时")
                        raise Exception("获取响应超时")

                    latency = time.time() - start_time
                    latencies.append(latency)
                    print(f"[讯飞ASR] 第{i+1}次测试:收到首包响应,延迟: {latency:.3f}s")
            except Exception as e:
                print(f"[讯飞ASR] 第{i+1}次测试失败: {str(e)}")
                latencies.append(0)

        return self._calculate_result(service, latencies, test_count)
|
||||
|
||||
class ASRPerformanceSuite:
    """Holds the registered ASR testers, runs them in turn and prints a ranking table."""

    def __init__(self):
        self.testers = []
        self.results = []

    def register_tester(self, tester_class):
        """Instantiate ``tester_class`` and keep it; print a skip notice if that fails."""
        display_names = {
            "DoubaoStreamASRTester": "豆包流式ASR",
            "QwenASRFlashTester": "通义千问ASR",
            "XunfeiStreamASRTester": "讯飞流式ASR",
        }
        try:
            instance = tester_class()
            self.testers.append(instance)
            print(f"已注册测试器: {instance.config_key}")
        except Exception as exc:
            class_name = tester_class.__name__
            label = display_names.get(class_name, class_name)
            print(f"跳过 {label}: {str(exc)}")

    def _print_results(self, test_count):
        """Print the result table: successes by ascending latency, then failures."""
        if not self.results:
            print("没有有效的ASR测试结果")
            return

        banner = "=" * 60
        print("\n" + banner)
        print("流式ASR首词响应时间测试结果")
        print(banner)
        print(f"测试次数: 每个ASR服务测试 {test_count} 次")

        # A result counts as successful when its status text contains "成功".
        succeeded, failed = [], []
        for entry in self.results:
            bucket = succeeded if "成功" in entry["status"] else failed
            bucket.append(entry)
        succeeded.sort(key=lambda entry: entry["latency"])

        table_data = []
        for entry in succeeded + failed:
            if entry['latency'] > 0:
                shown_latency = f"{entry['latency']:.3f}s"
            else:
                shown_latency = "N/A"
            table_data.append([entry["name"], shown_latency, entry["status"]])

        print(tabulate(table_data, headers=["ASR服务", "首词延迟", "状态"], tablefmt="grid"))
        print("\n测试说明:")
        print("- 测量从发送请求到接收第一个有效识别文本的时间")
        print("- 超时控制: DashScope 默认超时,豆包 WebSocket 超时10秒")
        print("- 排序规则: 成功的按延迟升序,失败的排在后面")

    async def run(self, test_count=5):
        """Run every registered tester ``test_count`` times, then print the table."""
        print("开始流式ASR首词响应时间测试...")
        print(f"每个ASR服务测试次数: {test_count}次\n")

        self.results = []
        for tester in self.testers:
            print(f"\n--- 测试 {tester.config_key} ---")
            outcome = await tester.test(test_count)
            self.results.append(outcome)

        self._print_results(test_count)
|
||||
|
||||
|
||||
async def main():
    """Parse the ``--count`` option, register every ASR tester and run the suite."""
    import argparse

    cli = argparse.ArgumentParser(description="流式ASR首词响应时间测试工具")
    cli.add_argument("--count", type=int, default=5, help="测试次数")
    options = cli.parse_args()

    suite = ASRPerformanceSuite()
    # Registration order is also the order in which the testers run.
    for tester_class in (
        DoubaoStreamASRTester,
        QwenASRFlashTester,
        XunfeiStreamASRTester,
    ):
        suite.register_tester(tester_class)

    await suite.run(options.count)


if __name__ == "__main__":
    asyncio.run(main())
|
||||
@@ -0,0 +1,536 @@
|
||||
import asyncio
|
||||
import time
|
||||
import json
|
||||
import uuid
|
||||
import aiohttp
|
||||
import websockets
|
||||
import hmac
|
||||
import base64
|
||||
import hashlib
|
||||
import asyncio
|
||||
from urllib.parse import urlparse, urlencode
|
||||
from tabulate import tabulate
|
||||
from config.settings import load_config
|
||||
|
||||
# Short human-readable label of this module: "streaming TTS first-chunk latency test".
# NOTE(review): presumably read by an external test launcher - confirm against callers.
description = "流式TTS语音合成首词耗时测试"
|
||||
class StreamTTSPerformanceTester:
    """Benchmark the time-to-first-audio-chunk of the configured streaming TTS services.

    Each public ``test_*`` coroutine runs a service ``test_count`` times and
    returns a ``{"name", "latency", "status"}`` dict. A failed run is recorded
    as latency ``0`` and its reason is printed.
    """

    # Upper bound in seconds for one network wait; the printed test notes
    # promise a 10 second limit per request.
    RECV_TIMEOUT = 10

    def __init__(self):
        self.config = load_config()
        self.test_texts = [
            "你好,这是一句话。"
        ]
        self.results = []

    async def _recv(self, ws):
        """Return the next WebSocket message or raise after ``RECV_TIMEOUT`` seconds."""
        return await asyncio.wait_for(ws.recv(), timeout=self.RECV_TIMEOUT)

    async def _run_trials(self, service_name, measure_once, text, test_count):
        """Call ``measure_once(text)`` ``test_count`` times and summarise the latencies.

        Args:
            service_name: Display name used in the result and in failure messages.
            measure_once: Coroutine function returning one latency in seconds.
            text: Text to synthesise; falls back to the default test text.
            test_count: Number of runs.

        Returns:
            The dict produced by ``_calculate_result``.
        """
        text = text or self.test_texts[0]
        latencies = []
        for attempt in range(1, test_count + 1):
            try:
                latencies.append(await measure_once(text))
            except Exception as exc:
                # Print the reason so a failed run can be diagnosed; the type
                # name is included because e.g. timeouts have an empty message.
                print(f"[{service_name}] 第{attempt}次测试失败: {type(exc).__name__}: {exc}")
                latencies.append(0)
        return self._calculate_result(service_name, latencies, test_count)

    async def test_aliyun_tts(self, text=None, test_count=5):
        """Measure the Aliyun streaming TTS first-chunk latency over ``test_count`` runs."""
        return await self._run_trials("阿里云TTS", self._aliyun_first_chunk, text, test_count)

    async def _aliyun_first_chunk(self, text):
        """Return seconds from connecting to Aliyun until the first binary frame."""
        tts_config = self.config["TTS"]["AliyunStreamTTS"]
        appkey = tts_config["appkey"]
        token = tts_config["token"]
        voice = tts_config["voice"]
        host = tts_config["host"]
        ws_url = f"wss://{host}/ws/v1"
        task_id = str(uuid.uuid4())

        def build_header(name):
            return {
                "message_id": str(uuid.uuid4()),
                "task_id": task_id,
                "namespace": "FlowingSpeechSynthesizer",
                "name": name,
                "appkey": appkey,
            }

        start_time = time.time()
        # ``additional_headers`` is the keyword every other connect call in
        # this file uses; the previous ``extra_headers`` was inconsistent.
        async with websockets.connect(ws_url, additional_headers={"X-NLS-Token": token}) as ws:
            start_request = {
                "header": build_header("StartSynthesis"),
                "payload": {
                    "voice": voice,
                    "format": "pcm",
                    "sample_rate": 16000,
                    "volume": 50,
                    "speech_rate": 0,
                    "pitch_rate": 0,
                }
            }
            await ws.send(json.dumps(start_request))

            start_response = json.loads(await self._recv(ws))
            if start_response["header"]["name"] != "SynthesisStarted":
                raise Exception("启动合成失败")

            run_request = {
                "header": build_header("RunSynthesis"),
                "payload": {"text": text}
            }
            await ws.send(json.dumps(run_request))

            # Text frames are status events; the first binary frame is audio.
            while True:
                response = await self._recv(ws)
                if isinstance(response, bytes):
                    return time.time() - start_time
                if isinstance(response, str):
                    data = json.loads(response)
                    if data["header"]["name"] == "TaskFailed":
                        raise Exception(f"合成失败: {data['payload']['error_info']}")

    async def test_doubao_tts(self, text=None, test_count=5):
        """Measure the Volcano Engine (Doubao) TTS first-response latency over ``test_count`` runs."""
        return await self._run_trials("火山引擎TTS", self._doubao_first_chunk, text, test_count)

    @staticmethod
    def _doubao_frame(event, session_id, payload):
        """Serialise one binary request: 4-byte header, event id, session id, JSON payload."""
        header = bytes([
            (0b0001 << 4) | 0b0001,
            (0b0001 << 4) | 0b100,
            (0b0001 << 4) | 0b0000,
            0
        ])
        session_id_bytes = session_id.encode()
        body = json.dumps(payload).encode()
        return b"".join((
            header,
            event.to_bytes(4, "big", signed=True),
            len(session_id_bytes).to_bytes(4, "big", signed=True),
            session_id_bytes,
            len(body).to_bytes(4, "big", signed=True),
            body,
        ))

    async def _doubao_first_chunk(self, text):
        """Return seconds from connecting to Doubao until the first message arrives."""
        tts_config = self.config["TTS"]["HuoshanDoubleStreamTTS"]
        ws_url = tts_config["ws_url"]
        app_id = tts_config["appid"]
        access_token = tts_config["access_token"]
        resource_id = tts_config["resource_id"]
        speaker = tts_config["speaker"]

        start_time = time.time()
        ws_header = {
            "X-Api-App-Key": app_id,
            "X-Api-Access-Key": access_token,
            "X-Api-Resource-Id": resource_id,
            "X-Api-Connect-Id": str(uuid.uuid4()),
        }
        async with websockets.connect(ws_url, additional_headers=ws_header, max_size=1000000000) as ws:
            session_id = uuid.uuid4().hex

            # Session start request (event 1), then the text request (event 200).
            await ws.send(self._doubao_frame(1, session_id, {"speaker": speaker}))
            await ws.send(self._doubao_frame(200, session_id, {"text": text, "speaker": speaker}))

            # The first message of any kind stops the clock.
            await self._recv(ws)
            return time.time() - start_time

    async def test_paddlespeech_tts(self, text=None, test_count=5):
        """Measure the PaddleSpeech streaming TTS first-chunk latency over ``test_count`` runs."""
        return await self._run_trials("PaddleSpeechTTS", self._paddlespeech_first_chunk, text, test_count)

    async def _paddlespeech_first_chunk(self, text):
        """Return seconds from connecting to PaddleSpeech until the first data message."""
        tts_config = self.config["TTS"]["PaddleSpeechTTS"]
        tts_url = tts_config["url"]
        spk_id = tts_config["spk_id"]
        speed = tts_config["speed"]
        volume = tts_config["volume"]

        start_time = time.time()
        async with websockets.connect(tts_url) as ws:
            await ws.send(json.dumps({
                "task": "tts",
                "signal": "start"
            }))

            start_response = json.loads(await self._recv(ws))
            if start_response.get("status") != 0:
                raise Exception("连接失败")

            await ws.send(json.dumps({
                "text": text,
                "spk_id": spk_id,
                "speed": speed,
                "volume": volume
            }))

            await self._recv(ws)
            latency = time.time() - start_time

            # Best-effort session end: the measurement is already taken, so a
            # problem here must not turn the run into a failure.
            try:
                await ws.send(json.dumps({
                    "task": "tts",
                    "signal": "end"
                }))
                await self._recv(ws)
            except (asyncio.TimeoutError, websockets.exceptions.ConnectionClosed):
                pass

        return latency

    async def test_indexstream_tts(self, text=None, test_count=5):
        """Measure the IndexStream TTS first-chunk latency over ``test_count`` runs."""
        return await self._run_trials("IndexStreamTTS", self._indexstream_first_chunk, text, test_count)

    async def _indexstream_first_chunk(self, text):
        """Return seconds from issuing the POST until the first non-empty body chunk."""
        tts_config = self.config["TTS"]["IndexStreamTTS"]
        api_url = tts_config.get("api_url")
        voice = tts_config.get("voice")

        start_time = time.time()
        async with aiohttp.ClientSession() as session:
            payload = {"text": text, "character": voice}
            async with session.post(api_url, json=payload, timeout=self.RECV_TIMEOUT) as resp:
                if resp.status != 200:
                    raise Exception(f"请求失败: {resp.status}, {await resp.text()}")

                async for chunk in resp.content.iter_any():
                    data = chunk[0] if isinstance(chunk, (list, tuple)) else chunk
                    if not data:
                        continue
                    latency = time.time() - start_time
                    resp.close()
                    return latency

        # The body ended without any data; count the run as failed.
        raise Exception("未收到任何音频数据")

    async def test_linkerai_tts(self, text=None, test_count=5):
        """Measure the Linkerai TTS first-chunk latency over ``test_count`` runs."""
        return await self._run_trials("LinkeraiTTS", self._linkerai_first_chunk, text, test_count)

    async def _linkerai_first_chunk(self, text):
        """Return seconds from issuing the GET until the first body chunk."""
        tts_config = self.config["TTS"]["LinkeraiTTS"]
        api_url = tts_config["api_url"]
        access_token = tts_config["access_token"]
        voice = tts_config["voice"]

        start_time = time.time()
        async with aiohttp.ClientSession() as session:
            params = {
                "tts_text": text,
                "spk_id": voice,
                "frame_durition": 60,
                "stream": "true",
                "target_sr": 16000,
                "audio_format": "pcm",
                "instruct_text": "请生成一段自然流畅的语音",
            }
            headers = {
                "Authorization": f"Bearer {access_token}",
                "Content-Type": "application/json",
            }

            async with session.get(api_url, params=params, headers=headers, timeout=self.RECV_TIMEOUT) as resp:
                if resp.status != 200:
                    raise Exception(f"请求失败: {resp.status}, {await resp.text()}")

                async for _ in resp.content.iter_any():
                    return time.time() - start_time

        # The body ended without any data; count the run as failed.
        raise Exception("未收到任何音频数据")

    async def test_xunfei_tts(self, text=None, test_count=5):
        """Measure the Xunfei streaming TTS first-audio latency over ``test_count`` runs."""
        return await self._run_trials("讯飞TTS", self._xunfei_first_chunk, text, test_count)

    async def _xunfei_first_chunk(self, text):
        """Return seconds from sending the request until the first audio payload.

        Unlike the other services, the clock starts after the request has
        been sent, so connection setup time is not included.
        """
        # The config section name is ``XunFeiTTS``.
        tts_config = self.config["TTS"]["XunFeiTTS"]
        app_id = tts_config["app_id"]
        api_key = tts_config["api_key"]
        api_secret = tts_config["api_secret"]
        api_url = tts_config.get("api_url", "wss://cbm01.cn-huabei-1.xf-yun.com/v1/private/mcd9m97e6")
        voice = tts_config.get("voice", "x5_lingxiaoxuan_flow")

        auth_url = self._create_xunfei_auth_url(api_key, api_secret, api_url)

        async with websockets.connect(
            auth_url,
            ping_interval=30,
            ping_timeout=10,
            close_timeout=10,
            max_size=1000000000
        ) as ws:
            request = self._build_xunfei_request(app_id, text, voice)
            await ws.send(json.dumps(request))
            start_time = time.time()

            # Skip messages until one carries audio with status 1.
            while True:
                data = json.loads(await self._recv(ws))
                header = data.get("header", {})
                code = header.get("code")

                if code != 0:
                    message = header.get("message", "未知错误")
                    raise Exception(f"合成失败: {code} - {message}")

                audio_payload = data.get("payload", {}).get("audio", {})
                if audio_payload:
                    status = audio_payload.get("status", 0)
                    audio_data = audio_payload.get("audio", "")
                    if status == 1 and audio_data:
                        return time.time() - start_time

    def _create_xunfei_auth_url(self, api_key, api_secret, api_url):
        """Return ``api_url`` extended with the HMAC-SHA256 signed auth query string.

        Args:
            api_key: Key embedded in the authorization value.
            api_secret: Secret used as the HMAC key.
            api_url: WebSocket endpoint; its host and path are signed.
        """
        parsed_url = urlparse(api_url)
        host = parsed_url.netloc
        path = parsed_url.path

        # UTC time in RFC 1123 format.
        date = time.strftime('%a, %d %b %Y %H:%M:%S GMT', time.gmtime())

        signature_origin = f"host: {host}\ndate: {date}\nGET {path} HTTP/1.1"
        signature_sha = hmac.new(
            api_secret.encode('utf-8'),
            signature_origin.encode('utf-8'),
            digestmod=hashlib.sha256
        ).digest()
        signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')

        authorization_origin = f'api_key="{api_key}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
        authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')

        v = {
            "authorization": authorization,
            "date": date,
            "host": host
        }
        return api_url + '?' + urlencode(v)

    def _build_xunfei_request(self, app_id, text, voice):
        """Return the Xunfei TTS request dict; ``text`` is sent base64-encoded."""
        return {
            "header": {
                "app_id": app_id,
                "status": 2,
            },
            "parameter": {
                "oral": {
                    "oral_level": "mid",
                    "spark_assist": 1,
                    "stop_split": 0,
                    "remain": 0
                },
                "tts": {
                    "vcn": voice,
                    "speed": 50,
                    "volume": 50,
                    "pitch": 50,
                    "bgs": 0,
                    "reg": 0,
                    "rdn": 0,
                    "rhy": 0,
                    "audio": {
                        "encoding": "raw",
                        "sample_rate": 24000,
                        "channels": 1,
                        "bit_depth": 16,
                        "frame_size": 0
                    }
                }
            },
            "payload": {
                "text": {
                    "encoding": "utf8",
                    "compress": "raw",
                    "format": "plain",
                    "status": 2,
                    "seq": 1,
                    "text": base64.b64encode(text.encode('utf-8')).decode('utf-8')
                }
            }
        }

    def _calculate_result(self, service_name, latencies, test_count):
        """Average the positive latencies; zeros stand for failed runs.

        Returns:
            Dict with ``name``, ``latency`` (mean of the successful runs, or
            ``0`` when none succeeded) and ``status``.
        """
        valid_latencies = [value for value in latencies if value > 0]
        if valid_latencies:
            avg_latency = sum(valid_latencies) / len(valid_latencies)
            status = f"成功({len(valid_latencies)}/{test_count}次有效)"
        else:
            avg_latency = 0
            status = "失败: 所有测试均失败"
        return {"name": service_name, "latency": avg_latency, "status": status}

    def _print_results(self, test_text, test_count):
        """Print the result table: successes by ascending latency, then failures."""
        if not self.results:
            print("没有有效的TTS测试结果")
            return

        print(f"\n{'='*60}")
        print("流式TTS首词延迟测试结果")
        print(f"{'='*60}")
        print(f"测试文本: {test_text}")
        print(f"测试次数: 每个TTS服务测试 {test_count} 次")

        success_results = sorted(
            [r for r in self.results if "成功" in r["status"]],
            key=lambda x: x["latency"]
        )
        failed_results = [r for r in self.results if "成功" not in r["status"]]

        # A failed service has no latency; show "N/A" instead of "0.000".
        table_data = [
            [r["name"], f"{r['latency']:.3f}" if r["latency"] > 0 else "N/A", r["status"]]
            for r in success_results + failed_results
        ]

        print(tabulate(table_data, headers=["TTS服务", "首词延迟(秒)", "状态"], tablefmt="grid"))
        print("\n测试说明:测量从发送请求到接收第一个音频数据块的时间,取多次测试平均值")
        print("- 超时控制: 单个请求最大等待时间为10秒")
        print("- 错误处理: 无法连接和超时的列为网络错误")
        print("- 排序规则: 按平均耗时从快到慢排序")

    async def run(self, test_text=None, test_count=5):
        """Run every service test in sequence and print the summary table.

        Args:
            test_text: Text to synthesise; ``None`` selects the default text.
            test_count: Number of runs per TTS service.
        """
        test_text = test_text or self.test_texts[0]
        print("开始流式TTS首词延迟测试...")
        print(f"测试文本: {test_text}")
        print(f"每个TTS服务测试次数: {test_count}次")

        if not self.config.get("TTS"):
            print("配置文件中未找到TTS配置")
            return

        service_tests = [
            self.test_aliyun_tts,
            self.test_doubao_tts,
            self.test_paddlespeech_tts,
            self.test_linkerai_tts,
            self.test_indexstream_tts,
        ]
        # Xunfei is only tested when its config section exists.
        if self.config.get("TTS", {}).get("XunFeiTTS"):
            service_tests.append(self.test_xunfei_tts)

        self.results = []
        for service_test in service_tests:
            self.results.append(await service_test(test_text, test_count))

        self._print_results(test_text, test_count)
|
||||
|
||||
|
||||
async def main():
    """Parse ``--text`` / ``--count`` and run the streaming TTS latency test."""
    import argparse

    cli = argparse.ArgumentParser(description="流式TTS首词延迟测试工具")
    cli.add_argument("--text", help="要测试的文本内容")
    cli.add_argument("--count", type=int, default=5, help="每个TTS服务的测试次数")
    options = cli.parse_args()

    tester = StreamTTSPerformanceTester()
    await tester.run(options.text, options.count)


if __name__ == "__main__":
    import asyncio

    asyncio.run(main())
|
||||
@@ -0,0 +1,183 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from typing import Dict
|
||||
import yaml
|
||||
from tabulate import tabulate
|
||||
|
||||
# 确保从 core.utils.tts 导入 create_tts_instance
|
||||
from core.utils.tts import create_instance as create_tts_instance
|
||||
from config.settings import load_config
|
||||
|
||||
# 设置全局日志级别为 WARNING
|
||||
logging.basicConfig(level=logging.WARNING)
|
||||
|
||||
# Short human-readable label of this module: "non-streaming speech synthesis performance test".
# NOTE(review): presumably read by performance_tester.py - confirm against that caller.
description = "非流式语音合成性能测试"
|
||||
|
||||
|
||||
class TTSPerformanceTester:
    """Time non-streaming synthesis for every TTS module found in the configuration.

    Each module synthesises a short connectivity sample and then up to
    ``MAX_TEST_SENTENCES`` test sentences; the mean wall-clock time per
    sentence is reported in a table.
    """

    # Number of configured sentences synthesised per module. The same value
    # drives the loop, the average and the reported sentence count.
    MAX_TEST_SENTENCES = 3

    def __init__(self):
        self.config = load_config()
        self.test_sentences = self.config.get("module_test", {}).get(
            "test_sentences",
            [
                "永和九年,岁在癸丑,暮春之初;",
                "夫人之相与,俯仰一世,或取诸怀抱,悟言一室之内;或因寄所托,放浪形骸之外。虽趣舍万殊,静躁不同,",
                "每览昔人兴感之由,若合一契,未尝不临文嗟悼,不能喻之于怀。固知一死生为虚诞,齐彭殇为妄作。",
            ],
        )
        self.results = {}

    async def _test_tts(self, tts_name: str, config: Dict) -> Dict:
        """Measure the average synthesis time of one TTS module.

        Args:
            tts_name: Key of the module in the ``TTS`` config section.
            config: That module's configuration mapping.

        Returns:
            ``{"name", "avg_time", "errors": 0}`` on success, otherwise
            ``{"name", "errors": 1}``. No exception escapes this method.
        """
        try:
            # Skip modules whose credentials still hold template placeholders.
            # Non-string values are ignored rather than raising TypeError.
            token_fields = ["access_token", "api_key", "token"]
            if any(
                isinstance(config.get(field), str)
                and any(x in config[field] for x in ["你的", "placeholder"])
                for field in token_fields
            ):
                print(f"TTS {tts_name} 未配置access_token/api_key,已跳过")
                return {"name": tts_name, "errors": 1}

            sentences = self.test_sentences[: self.MAX_TEST_SENTENCES]
            if not sentences:
                # Nothing to synthesise; avoids a division by zero below.
                print(f"{tts_name} 没有可用的测试句子")
                return {"name": tts_name, "errors": 1}

            module_type = config.get("type", tts_name)
            tts = create_tts_instance(module_type, config, delete_audio_file=True)

            print(f"测试 TTS: {tts_name}")

            # Connectivity check: one short synthesis must produce a file.
            tmp_file = tts.generate_filename()
            await tts.text_to_speak("连接测试", tmp_file)

            if not tmp_file or not os.path.exists(tmp_file):
                print(f"{tts_name} 连接失败")
                return {"name": tts_name, "errors": 1}

            total_time = 0
            # One list feeds the loop, the progress counter and the divisor,
            # so the average always covers exactly the sentences that ran.
            test_count = len(sentences)

            for i, sentence in enumerate(sentences, 1):
                start = time.time()
                tmp_file = tts.generate_filename()
                await tts.text_to_speak(sentence, tmp_file)
                duration = time.time() - start
                total_time += duration

                if tmp_file and os.path.exists(tmp_file):
                    print(f"{tts_name} [{i}/{test_count}] 测试成功")
                else:
                    print(f"{tts_name} [{i}/{test_count}] 测试失败")
                    return {"name": tts_name, "errors": 1}

            return {
                "name": tts_name,
                "avg_time": total_time / test_count,
                "errors": 0,
            }

        except Exception as e:
            print(f"{tts_name} 测试失败: {str(e)}")
            return {"name": tts_name, "errors": 1}

    def _print_results(self):
        """Print the result table: working modules by ascending time, then failures."""
        if not self.results:
            print("没有有效的TTS测试结果")
            return

        headers = ["TTS模块", "平均耗时(秒)", "测试句子数", "状态"]
        sentence_count = len(self.test_sentences[: self.MAX_TEST_SENTENCES])

        valid_results = []
        error_results = []

        for name, data in self.results.items():
            if data["errors"] == 0:
                valid_results.append({
                    "name": name,
                    "avg_time": f"{data['avg_time']:.3f}",
                    "test_count": sentence_count,
                    "status": "✅ 正常",
                    # Numeric value kept separately for sorting.
                    "sort_key": data['avg_time']
                })
            else:
                # Every failure is labelled as a network error by default.
                error_type = "网络错误"
                error_results.append([name, "-", f"0/{sentence_count}", f"❌ {error_type}"])

        valid_results.sort(key=lambda x: x["sort_key"])

        table_data = [
            [result["name"], result["avg_time"], result["test_count"], result["status"]]
            for result in valid_results
        ]
        table_data.extend(error_results)

        print("\nTTS性能测试结果:")
        print(
            tabulate(
                table_data,
                headers=headers,
                tablefmt="grid",
                colalign=("left", "right", "right", "left"),
            )
        )
        print("\n测试说明:")
        print("- 超时控制: 单个请求最大等待时间为10秒")
        print("- 错误处理: 无法连接和超时的列为网络错误")
        print("- 排序规则: 按平均耗时从快到慢排序")

    async def run(self):
        """Test every configured TTS module concurrently and print the table."""
        print("开始TTS性能测试...")

        if not self.config.get("TTS"):
            print("配置文件中未找到TTS配置")
            return

        tasks = [
            self._test_tts(tts_name, config)
            for tts_name, config in self.config.get("TTS", {}).items()
        ]

        # ``_test_tts`` never raises, so gather cannot abort half-way.
        results = await asyncio.gather(*tasks)

        # Keep every result, failures included.
        for result in results:
            self.results[result["name"]] = result

        self._print_results()
|
||||
|
||||
|
||||
# Entry point kept for callers in performance_tester.py.
async def main():
    """Create a ``TTSPerformanceTester`` and run it once."""
    await TTSPerformanceTester().run()


if __name__ == "__main__":
    asyncio.run(main())
|
||||
@@ -0,0 +1,192 @@
|
||||
import time
|
||||
import asyncio
|
||||
import logging
|
||||
import statistics
|
||||
import base64
|
||||
from typing import Dict
|
||||
from tabulate import tabulate
|
||||
from core.utils.vllm import create_instance
|
||||
from config.settings import load_config
|
||||
|
||||
# 设置全局日志级别为WARNING,抑制INFO级别日志
|
||||
logging.basicConfig(level=logging.WARNING)
|
||||
|
||||
# Short human-readable label of this module: "vision recognition model performance test".
# NOTE(review): presumably read by an external test launcher - confirm against callers.
description = "视觉识别模型性能测试"
|
||||
|
||||
|
||||
class AsyncVisionPerformanceTester:
|
||||
def __init__(self):
    """Load the configuration and set the fixed test images and questions."""
    self.config = load_config()

    # Image paths are relative strings; every question is asked about every image.
    image_dir = "../../docs/images"
    self.test_images = [f"{image_dir}/demo1.png", f"{image_dir}/demo2.png"]
    self.test_questions = ["这张图片里有什么?", "请详细描述这张图片的内容"]

    # Results are grouped under the "vllm" key.
    self.results = dict(vllm={})
|
||||
|
||||
async def _test_vllm(self, vllm_name: str, config: Dict) -> Dict:
|
||||
"""异步测试单个视觉大模型性能"""
|
||||
try:
|
||||
# 检查API密钥配置
|
||||
if "api_key" in config and any(
|
||||
x in config["api_key"] for x in ["你的", "placeholder", "sk-xxx"]
|
||||
):
|
||||
print(f"⏭️ VLLM {vllm_name} 未配置api_key,已跳过")
|
||||
return {"name": vllm_name, "type": "vllm", "errors": 1}
|
||||
|
||||
# 获取实际类型(兼容旧配置)
|
||||
module_type = config.get("type", vllm_name)
|
||||
vllm = create_instance(module_type, config)
|
||||
|
||||
print(f"🖼️ 测试 VLLM: {vllm_name}")
|
||||
|
||||
# 创建所有测试任务
|
||||
test_tasks = []
|
||||
for question in self.test_questions:
|
||||
for image in self.test_images:
|
||||
test_tasks.append(
|
||||
self._test_single_vision(vllm_name, vllm, question, image)
|
||||
)
|
||||
|
||||
# 并发执行所有测试
|
||||
test_results = await asyncio.gather(*test_tasks)
|
||||
|
||||
# 处理结果
|
||||
valid_results = [r for r in test_results if r is not None]
|
||||
if not valid_results:
|
||||
print(f"⚠️ {vllm_name} 无有效数据,可能配置错误")
|
||||
return {"name": vllm_name, "type": "vllm", "errors": 1}
|
||||
|
||||
response_times = [r["response_time"] for r in valid_results]
|
||||
|
||||
# 过滤异常数据
|
||||
mean = statistics.mean(response_times)
|
||||
stdev = statistics.stdev(response_times) if len(response_times) > 1 else 0
|
||||
filtered_times = [t for t in response_times if t <= mean + 3 * stdev]
|
||||
|
||||
if len(filtered_times) < len(test_tasks) * 0.5:
|
||||
print(f"⚠️ {vllm_name} 有效数据不足,可能网络不稳定")
|
||||
return {"name": vllm_name, "type": "vllm", "errors": 1}
|
||||
|
||||
return {
|
||||
"name": vllm_name,
|
||||
"type": "vllm",
|
||||
"avg_response": sum(response_times) / len(response_times),
|
||||
"std_response": (
|
||||
statistics.stdev(response_times) if len(response_times) > 1 else 0
|
||||
),
|
||||
"errors": 0,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
print(f"⚠️ VLLM {vllm_name} 测试失败: {str(e)}")
|
||||
return {"name": vllm_name, "type": "vllm", "errors": 1}
|
||||
|
||||
async def _test_single_vision(
|
||||
self, vllm_name: str, vllm, question: str, image: str
|
||||
) -> Dict:
|
||||
"""测试单个视觉问题的性能"""
|
||||
try:
|
||||
print(f"📝 {vllm_name} 开始测试: {question[:20]}...")
|
||||
start_time = time.time()
|
||||
|
||||
# 读取图片并转换为base64
|
||||
with open(image, "rb") as image_file:
|
||||
image_data = image_file.read()
|
||||
image_base64 = base64.b64encode(image_data).decode("utf-8")
|
||||
|
||||
# 直接获取响应
|
||||
response = vllm.response(question, image_base64)
|
||||
response_time = time.time() - start_time
|
||||
print(f"✓ {vllm_name} 完成响应: {response_time:.3f}s")
|
||||
|
||||
return {
|
||||
"name": vllm_name,
|
||||
"type": "vllm",
|
||||
"response_time": response_time,
|
||||
}
|
||||
except Exception as e:
|
||||
print(f"⚠️ {vllm_name} 测试失败: {str(e)}")
|
||||
return None
|
||||
|
||||
def _print_results(self):
|
||||
"""打印测试结果"""
|
||||
vllm_table = []
|
||||
for name, data in self.results["vllm"].items():
|
||||
if data["errors"] == 0:
|
||||
stability = data["std_response"] / data["avg_response"]
|
||||
vllm_table.append(
|
||||
[
|
||||
name,
|
||||
f"{data['avg_response']:.3f}秒",
|
||||
f"{stability:.3f}",
|
||||
]
|
||||
)
|
||||
|
||||
if vllm_table:
|
||||
print("\n视觉大模型性能排行:\n")
|
||||
print(
|
||||
tabulate(
|
||||
vllm_table,
|
||||
headers=["模型名称", "响应耗时", "稳定性"],
|
||||
tablefmt="github",
|
||||
colalign=("left", "right", "right"),
|
||||
disable_numparse=True,
|
||||
)
|
||||
)
|
||||
else:
|
||||
print("\n⚠️ 没有可用的视觉大模型进行测试。")
|
||||
|
||||
async def run(self):
|
||||
"""执行全量异步测试"""
|
||||
print("🔍 开始筛选可用视觉大模型...")
|
||||
|
||||
if not self.test_images:
|
||||
print(f"\n⚠️ {self.image_root} 路径下没有图片文件,无法进行测试")
|
||||
return
|
||||
|
||||
# 创建所有测试任务
|
||||
all_tasks = []
|
||||
|
||||
# VLLM测试任务
|
||||
if self.config.get("VLLM") is not None:
|
||||
for vllm_name, config in self.config.get("VLLM", {}).items():
|
||||
if "api_key" in config and any(
|
||||
x in config["api_key"] for x in ["你的", "placeholder", "sk-xxx"]
|
||||
):
|
||||
print(f"⏭️ VLLM {vllm_name} 未配置api_key,已跳过")
|
||||
continue
|
||||
print(f"🖼️ 添加VLLM测试任务: {vllm_name}")
|
||||
all_tasks.append(self._test_vllm(vllm_name, config))
|
||||
|
||||
print(f"\n✅ 找到 {len(all_tasks)} 个可用视觉大模型")
|
||||
print(f"✅ 使用 {len(self.test_images)} 张测试图片")
|
||||
print(f"✅ 使用 {len(self.test_questions)} 个测试问题")
|
||||
print("\n⏳ 开始并发测试所有模型...\n")
|
||||
|
||||
# 并发执行所有测试任务
|
||||
all_results = await asyncio.gather(*all_tasks, return_exceptions=True)
|
||||
|
||||
# 处理结果
|
||||
for result in all_results:
|
||||
if isinstance(result, dict) and result["errors"] == 0:
|
||||
self.results["vllm"][result["name"]] = result
|
||||
|
||||
# 打印结果
|
||||
print("\n📊 生成测试报告...")
|
||||
self._print_results()
|
||||
|
||||
|
||||
async def main():
    """Create a vision performance tester and await its full run."""
    await AsyncVisionPerformanceTester().run()
|
||||
|
||||
|
||||
# Allow running this file directly as a script.
if __name__ == "__main__":
    asyncio.run(main())
|
||||
Reference in New Issue
Block a user