仿生人MCP接入点
This commit is contained in:
@@ -0,0 +1,5 @@
|
||||
"""
|
||||
工具模块
|
||||
"""
|
||||
|
||||
__version__ = "0.0.6"
|
||||
@@ -0,0 +1,117 @@
|
||||
import base64
|
||||
from Crypto.Cipher import AES
|
||||
from Crypto.Util.Padding import pad, unpad
|
||||
import hashlib
|
||||
|
||||
|
||||
def pad_key(key: str) -> bytes:
|
||||
"""
|
||||
填充密钥到指定长度(16、24或32位)
|
||||
@param key: 原始密钥字符串
|
||||
@return: 填充后的密钥字节数组
|
||||
"""
|
||||
key_bytes = key.encode("utf-8")
|
||||
key_length = len(key_bytes)
|
||||
|
||||
if key_length == 16 or key_length == 24 or key_length == 32:
|
||||
return key_bytes
|
||||
|
||||
# 如果密钥长度不足,用0填充;如果超过,截取前32位
|
||||
padded_key = bytearray(32)
|
||||
padded_key[: min(key_length, 32)] = key_bytes[: min(key_length, 32)]
|
||||
return bytes(padded_key)
|
||||
|
||||
|
||||
def encrypt(key: str, plain_text: str) -> str:
|
||||
"""
|
||||
AES加密
|
||||
@param key: 密钥(16位、24位或32位)
|
||||
@param plain_text: 待加密字符串
|
||||
@return: 加密后的Base64字符串
|
||||
"""
|
||||
try:
|
||||
# 确保密钥长度为16、24或32位
|
||||
key_bytes = pad_key(key)
|
||||
cipher = AES.new(key_bytes, AES.MODE_ECB)
|
||||
|
||||
# 对明文进行PKCS7填充
|
||||
try:
|
||||
plain_bytes = plain_text.encode("utf-8")
|
||||
padded_data = pad(plain_bytes, AES.block_size)
|
||||
except Exception as e:
|
||||
raise ValueError(f"明文编码或填充失败: {str(e)}")
|
||||
|
||||
# 加密
|
||||
try:
|
||||
encrypted_bytes = cipher.encrypt(padded_data)
|
||||
except Exception as e:
|
||||
raise ValueError(f"加密失败: {str(e)}")
|
||||
|
||||
# Base64编码
|
||||
try:
|
||||
return base64.b64encode(encrypted_bytes).decode("utf-8")
|
||||
except Exception as e:
|
||||
raise ValueError(f"Base64编码失败: {str(e)}")
|
||||
except ValueError as e:
|
||||
# 重新抛出ValueError,保持错误类型一致
|
||||
raise e
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"加密过程中发生未知错误: {str(e)}")
|
||||
|
||||
|
||||
def decrypt(key: str, encrypted_text: str) -> str:
|
||||
"""
|
||||
AES解密
|
||||
@param key: 密钥(16位、24位或32位)
|
||||
@param encrypted_text: 待解密的Base64字符串
|
||||
@return: 解密后的字符串
|
||||
"""
|
||||
try:
|
||||
# 确保密钥长度为16、24或32位
|
||||
key_bytes = pad_key(key)
|
||||
cipher = AES.new(key_bytes, AES.MODE_ECB)
|
||||
|
||||
# 解码Base64
|
||||
try:
|
||||
encrypted_bytes = base64.b64decode(encrypted_text)
|
||||
except Exception as e:
|
||||
return None
|
||||
|
||||
# 解密
|
||||
try:
|
||||
decrypted_bytes = cipher.decrypt(encrypted_bytes)
|
||||
except Exception as e:
|
||||
return None
|
||||
|
||||
# 去除PKCS7填充
|
||||
try:
|
||||
unpadded_data = unpad(decrypted_bytes, AES.block_size)
|
||||
except Exception as e:
|
||||
return None
|
||||
|
||||
return unpadded_data.decode("utf-8")
|
||||
except ValueError as e:
|
||||
# 重新抛出ValueError,保持错误类型一致
|
||||
return None
|
||||
except Exception as e:
|
||||
return None
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 测试代码
|
||||
test_key = "6a369b7f1bcf4d3e8d123ece38bb9627"
|
||||
test_text = '{"agentId": "test1"}'
|
||||
|
||||
print(f"原始文本: {test_text}")
|
||||
print(f"密钥: {test_key}")
|
||||
|
||||
# 加密
|
||||
encrypted = encrypt(test_key, test_text)
|
||||
print(f"加密结果: {encrypted}")
|
||||
|
||||
# 解密
|
||||
decrypted = decrypt(test_key, encrypted)
|
||||
print(f"解密结果: {decrypted}")
|
||||
|
||||
# 验证
|
||||
print(f"加解密一致性: {test_text == decrypted}")
|
||||
@@ -0,0 +1,131 @@
|
||||
"""
|
||||
配置管理工具
|
||||
"""
|
||||
|
||||
import os
|
||||
import configparser
|
||||
import uuid
|
||||
from typing import Optional
|
||||
from pathlib import Path
|
||||
|
||||
|
||||
class ConfigManager:
|
||||
"""配置管理器"""
|
||||
|
||||
def __init__(self, config_file: str = "data/.mcp-endpoint-server.cfg"):
|
||||
self.config_file = config_file
|
||||
self.config = configparser.ConfigParser()
|
||||
self._load_config()
|
||||
|
||||
def _load_config(self):
|
||||
"""加载配置文件"""
|
||||
if os.path.exists(self.config_file):
|
||||
self.config.read(self.config_file, encoding="utf-8")
|
||||
# 检查并生成key
|
||||
self._check_and_generate_key()
|
||||
else:
|
||||
# 如果配置文件不存在,从根目录拷贝
|
||||
self._copy_config_from_root()
|
||||
|
||||
def _copy_config_from_root(self):
|
||||
"""从根目录拷贝配置文件到data目录"""
|
||||
root_config = "mcp-endpoint-server.cfg"
|
||||
if os.path.exists(root_config):
|
||||
# 确保data目录存在
|
||||
os.makedirs(os.path.dirname(self.config_file), exist_ok=True)
|
||||
|
||||
# 拷贝配置文件
|
||||
import shutil
|
||||
|
||||
shutil.copy2(root_config, self.config_file)
|
||||
|
||||
# 重新加载配置
|
||||
self.config.read(self.config_file, encoding="utf-8")
|
||||
# 检查并生成key
|
||||
self._check_and_generate_key()
|
||||
else:
|
||||
# 如果根目录也没有配置文件,则创建默认配置
|
||||
self._create_default_config()
|
||||
|
||||
def _check_and_generate_key(self):
|
||||
"""检查key是否存在且长度足够,如果不足则生成新的"""
|
||||
try:
|
||||
current_key = self.config.get("server", "key", fallback="")
|
||||
if not current_key or len(current_key) < 32:
|
||||
# 生成32位随机密码
|
||||
new_key = self._generate_random_key()
|
||||
self.config.set("server", "key", new_key)
|
||||
|
||||
# 保存到配置文件
|
||||
with open(self.config_file, "w", encoding="utf-8") as f:
|
||||
self.config.write(f)
|
||||
|
||||
print(f"已自动生成新的32位密钥: {new_key}")
|
||||
except Exception as e:
|
||||
print(f"检查密钥时发生错误: {e}")
|
||||
|
||||
def _generate_random_key(self) -> str:
|
||||
"""生成指定长度的随机密钥"""
|
||||
# 使用UUID生成密钥,移除连字符
|
||||
return str(uuid.uuid4()).replace("-", "")
|
||||
|
||||
def _create_default_config(self):
|
||||
"""创建默认配置"""
|
||||
self.config["server"] = {
|
||||
"host": "127.0.0.1",
|
||||
"port": "8004",
|
||||
"debug": "false",
|
||||
"log_level": "INFO",
|
||||
"key": self._generate_random_key(), # 生成默认密钥
|
||||
}
|
||||
|
||||
self.config["websocket"] = {
|
||||
"max_connections": "1000",
|
||||
"ping_interval": "30",
|
||||
"ping_timeout": "10",
|
||||
"close_timeout": "10",
|
||||
}
|
||||
|
||||
self.config["security"] = {"allowed_origins": "*", "enable_cors": "true"}
|
||||
|
||||
self.config["logging"] = {
|
||||
"log_file": "logs/mcp_server.log",
|
||||
"max_file_size": "10MB",
|
||||
"backup_count": "5",
|
||||
}
|
||||
|
||||
# 确保目录存在
|
||||
os.makedirs(os.path.dirname(self.config_file), exist_ok=True)
|
||||
|
||||
# 保存默认配置
|
||||
with open(self.config_file, "w", encoding="utf-8") as f:
|
||||
self.config.write(f)
|
||||
|
||||
def get(self, section: str, key: str, default: Optional[str] = None) -> str:
|
||||
"""获取配置值"""
|
||||
try:
|
||||
return self.config.get(section, key)
|
||||
except (configparser.NoSectionError, configparser.NoOptionError):
|
||||
return default
|
||||
|
||||
def getint(self, section: str, key: str, default: int = 0) -> int:
|
||||
"""获取整数配置值"""
|
||||
try:
|
||||
return self.config.getint(section, key)
|
||||
except (configparser.NoSectionError, configparser.NoOptionError, ValueError):
|
||||
return default
|
||||
|
||||
def getboolean(self, section: str, key: str, default: bool = False) -> bool:
|
||||
"""获取布尔配置值"""
|
||||
try:
|
||||
return self.config.getboolean(section, key)
|
||||
except (configparser.NoSectionError, configparser.NoOptionError, ValueError):
|
||||
return default
|
||||
|
||||
def reload(self):
|
||||
"""重新加载配置"""
|
||||
self._load_config()
|
||||
|
||||
|
||||
# 全局配置实例
|
||||
config = ConfigManager()
|
||||
@@ -0,0 +1,224 @@
|
||||
"""
|
||||
JSON-RPC 2.0 协议封装类
|
||||
用于统一处理JSON-RPC消息的格式化和解析
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from dataclasses import dataclass, asdict
|
||||
|
||||
|
||||
@dataclass
|
||||
class JSONRPCError:
|
||||
"""JSON-RPC错误对象"""
|
||||
|
||||
code: int
|
||||
message: str
|
||||
data: Optional[Any] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class JSONRPCRequest:
|
||||
"""JSON-RPC请求对象"""
|
||||
|
||||
method: str
|
||||
params: Optional[Union[Dict, list]] = None
|
||||
id: Optional[Union[str, int]] = None
|
||||
jsonrpc: str = "2.0"
|
||||
|
||||
|
||||
@dataclass
|
||||
class JSONRPCResponse:
|
||||
"""JSON-RPC响应对象"""
|
||||
|
||||
result: Optional[Any] = None
|
||||
error: Optional[JSONRPCError] = None
|
||||
id: Optional[Union[str, int]] = None
|
||||
jsonrpc: str = "2.0"
|
||||
|
||||
|
||||
class JSONRPCProtocol:
|
||||
"""JSON-RPC 2.0 协议封装类"""
|
||||
|
||||
# 预定义错误码
|
||||
PARSE_ERROR = -32700
|
||||
INVALID_REQUEST = -32600
|
||||
METHOD_NOT_FOUND = -32601
|
||||
INVALID_PARAMS = -32602
|
||||
INTERNAL_ERROR = -32603
|
||||
|
||||
# 自定义错误码
|
||||
TOOL_NOT_CONNECTED = -32001
|
||||
FORWARD_FAILED = -32002
|
||||
CONNECTION_ERROR = -32003
|
||||
AUTHENTICATION_ERROR = -32004
|
||||
|
||||
@staticmethod
|
||||
def create_request(
|
||||
method: str,
|
||||
params: Optional[Union[Dict, list]] = None,
|
||||
request_id: Optional[Union[str, int]] = None,
|
||||
) -> JSONRPCRequest:
|
||||
"""创建JSON-RPC请求"""
|
||||
return JSONRPCRequest(
|
||||
method=method, params=params, id=request_id, jsonrpc="2.0"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def create_success_response(
|
||||
result: Any, request_id: Optional[Union[str, int]] = None
|
||||
) -> JSONRPCResponse:
|
||||
"""创建成功响应"""
|
||||
return JSONRPCResponse(result=result, id=request_id, jsonrpc="2.0")
|
||||
|
||||
@staticmethod
|
||||
def create_error_response(
|
||||
error_code: int,
|
||||
error_message: str,
|
||||
error_data: Optional[Any] = None,
|
||||
request_id: Optional[Union[str, int]] = None,
|
||||
) -> JSONRPCResponse:
|
||||
"""创建错误响应"""
|
||||
error = JSONRPCError(code=error_code, message=error_message, data=error_data)
|
||||
return JSONRPCResponse(error=error, id=request_id, jsonrpc="2.0")
|
||||
|
||||
@staticmethod
|
||||
def create_notification(
|
||||
method: str, params: Optional[Union[Dict, list]] = None
|
||||
) -> JSONRPCRequest:
|
||||
"""创建通知消息(无ID的请求)"""
|
||||
return JSONRPCRequest(method=method, params=params, id=None, jsonrpc="2.0")
|
||||
|
||||
@staticmethod
|
||||
def to_dict(obj: Union[JSONRPCRequest, JSONRPCResponse]) -> Dict:
|
||||
"""将对象转换为字典"""
|
||||
return asdict(obj)
|
||||
|
||||
@staticmethod
|
||||
def to_json(
|
||||
obj: Union[JSONRPCRequest, JSONRPCResponse], ensure_ascii: bool = False
|
||||
) -> str:
|
||||
"""将对象转换为JSON字符串"""
|
||||
return json.dumps(asdict(obj), ensure_ascii=ensure_ascii)
|
||||
|
||||
@staticmethod
|
||||
def parse_request(json_str: str) -> Optional[JSONRPCRequest]:
|
||||
"""解析JSON-RPC请求"""
|
||||
try:
|
||||
data = json.loads(json_str)
|
||||
if not isinstance(data, dict):
|
||||
return None
|
||||
|
||||
# 验证必需字段
|
||||
if "jsonrpc" not in data or data["jsonrpc"] != "2.0":
|
||||
return None
|
||||
if "method" not in data:
|
||||
return None
|
||||
|
||||
return JSONRPCRequest(
|
||||
method=data["method"],
|
||||
params=data.get("params"),
|
||||
id=data.get("id"),
|
||||
jsonrpc=data["jsonrpc"],
|
||||
)
|
||||
except (json.JSONDecodeError, KeyError, TypeError):
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def parse_response(json_str: str) -> Optional[JSONRPCResponse]:
|
||||
"""解析JSON-RPC响应"""
|
||||
try:
|
||||
data = json.loads(json_str)
|
||||
if not isinstance(data, dict):
|
||||
return None
|
||||
|
||||
# 验证必需字段
|
||||
if "jsonrpc" not in data or data["jsonrpc"] != "2.0":
|
||||
return None
|
||||
|
||||
# 检查是否有result或error字段
|
||||
has_result = "result" in data
|
||||
has_error = "error" in data
|
||||
|
||||
if not has_result and not has_error:
|
||||
return None
|
||||
if has_result and has_error:
|
||||
return None
|
||||
|
||||
response = JSONRPCResponse(id=data.get("id"), jsonrpc=data["jsonrpc"])
|
||||
|
||||
if has_result:
|
||||
response.result = data["result"]
|
||||
else:
|
||||
error_data = data["error"]
|
||||
response.error = JSONRPCError(
|
||||
code=error_data["code"],
|
||||
message=error_data["message"],
|
||||
data=error_data.get("data"),
|
||||
)
|
||||
|
||||
return response
|
||||
except (json.JSONDecodeError, KeyError, TypeError):
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def is_valid_request(json_str: str) -> bool:
|
||||
"""验证是否为有效的JSON-RPC请求"""
|
||||
return JSONRPCProtocol.parse_request(json_str) is not None
|
||||
|
||||
@staticmethod
|
||||
def is_valid_response(json_str: str) -> bool:
|
||||
"""验证是否为有效的JSON-RPC响应"""
|
||||
return JSONRPCProtocol.parse_response(json_str) is not None
|
||||
|
||||
@staticmethod
|
||||
def is_notification(json_str: str) -> bool:
|
||||
"""检查是否为通知消息(无ID的请求)"""
|
||||
request = JSONRPCProtocol.parse_request(json_str)
|
||||
return request is not None and request.id is None
|
||||
|
||||
|
||||
def create_tool_not_connected_error(
|
||||
request_id: Optional[Union[str, int]] = None, agent_id: Optional[str] = None
|
||||
) -> str:
|
||||
"""创建工具端未连接的错误消息"""
|
||||
error_data = (
|
||||
{"agent_id": agent_id, "details": "请求的工具端连接不存在或已断开"}
|
||||
if agent_id
|
||||
else None
|
||||
)
|
||||
|
||||
response = JSONRPCProtocol.create_error_response(
|
||||
error_code=JSONRPCProtocol.TOOL_NOT_CONNECTED,
|
||||
error_message="工具端未连接",
|
||||
error_data=error_data,
|
||||
request_id=request_id,
|
||||
)
|
||||
return JSONRPCProtocol.to_json(response, ensure_ascii=False)
|
||||
|
||||
|
||||
def create_forward_failed_error(
|
||||
request_id: Optional[Union[str, int]] = None, agent_id: Optional[str] = None
|
||||
) -> str:
|
||||
"""创建转发失败的错误消息"""
|
||||
error_data = (
|
||||
{"agent_id": agent_id, "details": "消息转发过程中发生错误"}
|
||||
if agent_id
|
||||
else None
|
||||
)
|
||||
|
||||
response = JSONRPCProtocol.create_error_response(
|
||||
error_code=JSONRPCProtocol.FORWARD_FAILED,
|
||||
error_message="转发消息失败",
|
||||
error_data=error_data,
|
||||
request_id=request_id,
|
||||
)
|
||||
return JSONRPCProtocol.to_json(response, ensure_ascii=False)
|
||||
|
||||
|
||||
def create_authentication_error(message: str = "认证失败") -> str:
|
||||
"""创建认证错误消息"""
|
||||
response = JSONRPCProtocol.create_error_response(
|
||||
error_code=JSONRPCProtocol.AUTHENTICATION_ERROR, error_message=message
|
||||
)
|
||||
return JSONRPCProtocol.to_json(response, ensure_ascii=False)
|
||||
@@ -0,0 +1,143 @@
|
||||
"""
|
||||
日志管理工具
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
from typing import Optional
|
||||
from loguru import logger
|
||||
from .config import config
|
||||
|
||||
# 版本号
|
||||
from . import __version__ as VERSION
|
||||
|
||||
|
||||
class InterceptHandler(logging.Handler):
|
||||
"""拦截标准库日志并转发到loguru"""
|
||||
|
||||
def emit(self, record):
|
||||
# 获取对应的loguru级别
|
||||
try:
|
||||
level = logger.level(record.levelname).name
|
||||
except ValueError:
|
||||
level = record.levelno
|
||||
|
||||
# 查找调用者
|
||||
frame, depth = sys._getframe(6), 6
|
||||
while frame and frame.f_code.co_filename == __file__:
|
||||
frame = frame.f_back
|
||||
depth += 1
|
||||
|
||||
logger.opt(depth=depth, exception=record.exc_info).log(
|
||||
level, record.getMessage()
|
||||
)
|
||||
|
||||
|
||||
class LoggerManager:
|
||||
"""日志管理器"""
|
||||
|
||||
def __init__(self):
|
||||
self._setup_logger()
|
||||
|
||||
def _setup_logger(self):
|
||||
"""设置日志器"""
|
||||
# 移除默认的处理器
|
||||
logger.remove()
|
||||
|
||||
# 自定义格式:时间[版本号][模块路径]-级别-消息
|
||||
# 不同部分使用不同颜色
|
||||
custom_format = (
|
||||
f"<green>{{time:YYMMDD HH:mm:ss}}</green>"
|
||||
f"<blue>[{VERSION}][{{name}}]</blue>"
|
||||
"<level>-{level}-</level>"
|
||||
"<green>{message}</green>"
|
||||
)
|
||||
|
||||
# 控制台处理器
|
||||
logger.add(
|
||||
sys.stdout,
|
||||
format=custom_format,
|
||||
level=config.get("server", "log_level", "INFO"),
|
||||
colorize=True,
|
||||
backtrace=True,
|
||||
diagnose=True,
|
||||
enqueue=True,
|
||||
catch=True,
|
||||
)
|
||||
|
||||
# 文件处理器(不带颜色)
|
||||
log_file = config.get("logging", "log_file", "logs/mcp_server.log")
|
||||
if log_file:
|
||||
# 确保日志目录存在
|
||||
os.makedirs(os.path.dirname(log_file), exist_ok=True)
|
||||
|
||||
# 文件格式(不带颜色)
|
||||
file_format = f"{{time:YYYY-MM-DD HH:mm:ss}} [{VERSION}][{{name}}] {{level}} - {{message}}"
|
||||
|
||||
# 获取文件大小限制
|
||||
max_file_size = config.get("logging", "max_file_size", "10MB")
|
||||
max_bytes = self._parse_size(max_file_size)
|
||||
|
||||
# 获取备份数量
|
||||
backup_count = config.getint("logging", "backup_count", 5)
|
||||
|
||||
logger.add(
|
||||
log_file,
|
||||
format=file_format,
|
||||
level=config.get("server", "log_level", "INFO"),
|
||||
rotation=max_bytes,
|
||||
retention=backup_count,
|
||||
compression="zip",
|
||||
encoding="utf-8",
|
||||
enqueue=True,
|
||||
catch=True,
|
||||
)
|
||||
print("Logger initialized", flush=True)
|
||||
logger.info("Logger test message")
|
||||
|
||||
def _parse_size(self, size_str: str) -> int:
|
||||
"""解析文件大小字符串"""
|
||||
size_str = size_str.upper()
|
||||
if size_str.endswith("MB"):
|
||||
return int(float(size_str[:-2]) * 1024 * 1024)
|
||||
elif size_str.endswith("KB"):
|
||||
return int(float(size_str[:-2]) * 1024)
|
||||
elif size_str.endswith("B"):
|
||||
return int(size_str[:-1])
|
||||
else:
|
||||
return int(size_str)
|
||||
|
||||
def get_logger(self):
|
||||
"""获取日志器"""
|
||||
return logger
|
||||
|
||||
def reload(self):
|
||||
"""重新加载日志配置"""
|
||||
self._setup_logger()
|
||||
|
||||
def setup_uvicorn_logging(self):
|
||||
"""设置uvicorn日志拦截"""
|
||||
import logging
|
||||
|
||||
# 拦截标准库日志
|
||||
logging.basicConfig(handlers=[InterceptHandler()], level=0, force=True)
|
||||
|
||||
# 直接禁用uvicorn.access日志
|
||||
logging.getLogger("uvicorn.access").disabled = True
|
||||
logging.getLogger("uvicorn.access").propagate = False
|
||||
|
||||
# 拦截其他uvicorn日志
|
||||
for name in logging.root.manager.loggerDict.keys():
|
||||
if not name.startswith("uvicorn.access"):
|
||||
logging.getLogger(name).handlers = []
|
||||
logging.getLogger(name).propagate = True
|
||||
|
||||
|
||||
# 全局日志管理器实例
|
||||
logger_manager = LoggerManager()
|
||||
|
||||
|
||||
def get_logger(name: str = "mcp_server"):
|
||||
"""获取日志器"""
|
||||
return logger.bind(name=name)
|
||||
@@ -0,0 +1,13 @@
|
||||
import socket
|
||||
|
||||
|
||||
def get_local_ip():
|
||||
try:
|
||||
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
|
||||
# Connect to Google's DNS servers
|
||||
s.connect(("8.8.8.8", 80))
|
||||
local_ip = s.getsockname()[0]
|
||||
s.close()
|
||||
return local_ip
|
||||
except Exception as e:
|
||||
return "127.0.0.1"
|
||||
Reference in New Issue
Block a user