完善牧安云哨-后端
This commit is contained in:
@@ -0,0 +1,143 @@
|
||||
import json
|
||||
import os
|
||||
import socket
|
||||
import ssl
|
||||
import sys
|
||||
import uuid
|
||||
|
||||
from aiomqtt import Client
|
||||
|
||||
from config.redis import redis_client
|
||||
from models.MqttTopic import MqttTopic
|
||||
|
||||
# ================= 配置区域 =================
|
||||
MQTT_BROKER = "ai.ronsunny.cn"
|
||||
MQTT_PORT = 8093
|
||||
MQTT_PASSWORD = "123456"
|
||||
TLS_CONTEXT = ssl.create_default_context()
|
||||
|
||||
# 默认连接后要订阅的 topic 配置
|
||||
DEFAULT_SUBSCRIPTIONS = [
|
||||
MqttTopic.from_parts(
|
||||
project=None,
|
||||
domain="status",
|
||||
device_type="edge",
|
||||
device_id=None,
|
||||
resource="info",
|
||||
)
|
||||
]
|
||||
# ===========================================
|
||||
|
||||
DEVICE_ID = None
|
||||
MQTT_CLIENT: Client | None = None # 全局客户端
|
||||
|
||||
# Windows 平台下切换到 SelectorEventLoop
|
||||
if sys.platform.lower() == "win32" or os.name.lower() == "nt":
|
||||
from asyncio import set_event_loop_policy, WindowsSelectorEventLoopPolicy
|
||||
|
||||
set_event_loop_policy(WindowsSelectorEventLoopPolicy())
|
||||
|
||||
|
||||
def get_device_id_simple():
|
||||
try:
|
||||
with open("/etc/machine-id") as f:
|
||||
mid = f.read().strip()
|
||||
if mid:
|
||||
return mid
|
||||
except Exception:
|
||||
pass
|
||||
hostname = socket.gethostname()
|
||||
mac = uuid.getnode()
|
||||
mac_str = ":".join(f"{(mac >> ele) & 0xff:02x}" for ele in range(40, -1, -8))
|
||||
return f"{hostname}|{mac_str}"
|
||||
|
||||
|
||||
# todo 这里需要订阅状态信息 设备发送信息 这里回复 vue前端发送指令 后端发送指令 设备接收指令
|
||||
# ------------------ MQTT 封装 ------------------
|
||||
|
||||
|
||||
async def mqtt_publish(
|
||||
project: str,
|
||||
domain: str,
|
||||
device_type: str,
|
||||
device_id: str,
|
||||
resource: str,
|
||||
payload: str,
|
||||
qos: int = 1,
|
||||
):
|
||||
"""发布消息(使用全局客户端)"""
|
||||
if not MQTT_CLIENT:
|
||||
raise RuntimeError("MQTT client is not initialized")
|
||||
topic = f"{project}/{domain}/{device_type}/{device_id}/{resource}"
|
||||
await MQTT_CLIENT.publish(topic, payload, qos=qos)
|
||||
print(f"Published to {topic}: {payload}")
|
||||
|
||||
|
||||
async def mqtt_publish_multiple(
|
||||
targets: list[dict], resource: str, payload: str, qos: int = 1
|
||||
):
|
||||
"""群发消息"""
|
||||
for target in targets:
|
||||
await mqtt_publish(
|
||||
domain=target["domain"],
|
||||
device_type=target["device_type"],
|
||||
device_id=target["device_id"],
|
||||
resource=resource,
|
||||
payload=payload,
|
||||
qos=qos,
|
||||
)
|
||||
|
||||
|
||||
async def _mqtt_handle_messages():
|
||||
"""后台循环处理消息"""
|
||||
if not MQTT_CLIENT:
|
||||
return
|
||||
async for message in MQTT_CLIENT.messages:
|
||||
topic = MqttTopic(message.topic)
|
||||
print("收到消息:" + str(topic))
|
||||
|
||||
# 处理基础状态信息
|
||||
if topic.domain == "status" and topic.resource == "info":
|
||||
payload = json.loads(message.payload.decode())
|
||||
redis_client.set_device_info(topic.device_id, payload)
|
||||
|
||||
|
||||
async def mqtt_client_async():
|
||||
global DEVICE_ID, MQTT_CLIENT
|
||||
DEVICE_ID = get_device_id_simple()
|
||||
print("服务端EMQX账号:", DEVICE_ID)
|
||||
async with Client(
|
||||
MQTT_BROKER,
|
||||
port=MQTT_PORT,
|
||||
username=DEVICE_ID,
|
||||
password=MQTT_PASSWORD,
|
||||
tls_context=TLS_CONTEXT,
|
||||
identifier=DEVICE_ID,
|
||||
) as client:
|
||||
MQTT_CLIENT = client # 保存全局客户端
|
||||
print("MQTT client connected")
|
||||
|
||||
# 订阅默认 topic
|
||||
for topic in DEFAULT_SUBSCRIPTIONS:
|
||||
await MQTT_CLIENT.subscribe(topic.to_topic())
|
||||
print(f"Subscribed to default topic: {topic.to_topic()}")
|
||||
|
||||
# 启动消息处理循环
|
||||
await _mqtt_handle_messages()
|
||||
|
||||
|
||||
# ------------------ 示例主程序 ------------------
|
||||
|
||||
|
||||
# async def main():
|
||||
# await mqtt_client_async()
|
||||
#
|
||||
# # 示例:发布消息
|
||||
# await mqtt_publish("status", "edge", DEVICE_ID, "heartbeat", '{"alive":true}')
|
||||
#
|
||||
# # 示例:群发
|
||||
# targets = [
|
||||
# {"domain": "cmd", "device_type": "edge", "device_id": "edge01"},
|
||||
# {"domain": "cmd", "device_type": "edge", "device_id": "edge02"},
|
||||
# ]
|
||||
# await mqtt_publish_multiple(targets, "restart", '{"action":"restart"}')
|
||||
@@ -22,7 +22,7 @@ def push_file(bucket_name, object_name, file_bytes, contents, content_type):
|
||||
)
|
||||
|
||||
|
||||
def get_upload_token(user_id, bucket_name, object_name, xpires=timedelta(minutes=15)):
|
||||
def get_upload_token(bucket_name, object_name, xpires=timedelta(minutes=15)):
|
||||
return minio_client.presigned_put_object(
|
||||
bucket_name=bucket_name, object_name=object_name, expires=xpires
|
||||
)
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
from utils.GlobalVariable import LOCAL_IP
|
||||
|
||||
RABBIT_HOST = LOCAL_IP
|
||||
RABBIT_VHOST = "bbit_ai"
|
||||
RABBIT_USER = "ai_lab"
|
||||
RABBIT_PASSWORD = "123456"
|
||||
QUEUE_NAME = "analysis_queue"
|
||||
RABBIT_VHOST = "bbit_ai"
|
||||
|
||||
SENTINEL_VHOST = "sentinel"
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import redis
|
||||
|
||||
|
||||
class RedisClient:
|
||||
# ---------------- Redis Client ----------------
|
||||
|
||||
|
||||
class RedisClient:
|
||||
def __init__(self, config_path="config.yaml"):
|
||||
self.redis = redis.Redis(
|
||||
"10.10.12.101",
|
||||
@@ -22,3 +24,24 @@ class RedisClient:
|
||||
def is_device_online(self, device_id: str) -> bool:
|
||||
key = f"device:online:{device_id}"
|
||||
return self.redis.exists(key) == 1
|
||||
|
||||
def set_device_info(self, device_id: str, info: dict):
|
||||
"""
|
||||
存储完整设备信息到 redis hash
|
||||
将 bool 转为 int
|
||||
"""
|
||||
key = f"device:info:{device_id}"
|
||||
|
||||
# 转换 bool 为 int
|
||||
sanitized_info = {
|
||||
k: (int(v) if isinstance(v, bool) else v) for k, v in info.items()
|
||||
}
|
||||
|
||||
self.redis.hmset(key, sanitized_info)
|
||||
|
||||
def get_device_info(self, device_id: str) -> dict:
|
||||
key = f"device:info:{device_id}"
|
||||
return self.redis.hgetall(key)
|
||||
|
||||
|
||||
redis_client = RedisClient()
|
||||
|
||||
@@ -1,13 +1,16 @@
|
||||
import jwt
|
||||
from jwt import PyJWTError
|
||||
from uuid import UUID
|
||||
from fastapi import Header, HTTPException, Depends
|
||||
|
||||
JWT_SECRET = "secret_jwt"
|
||||
import jwt
|
||||
from fastapi import Header, HTTPException
|
||||
from jwt import PyJWTError
|
||||
from starlette.websockets import WebSocketDisconnect
|
||||
|
||||
JWT_SECRET = "secret_jwt"
|
||||
JWT_ALGORITHM = "HS256"
|
||||
JWT_AUDIENCE = "snowflake-ink"
|
||||
JWT_ISSUER = "https://snowflake.ink/"
|
||||
|
||||
|
||||
def get_user_id_from_token(token: str = Header(..., alias="Authorization")) -> UUID:
|
||||
"""
|
||||
从 Authorization 头解析 token,并返回 user_id
|
||||
@@ -24,7 +27,7 @@ def get_user_id_from_token(token: str = Header(..., alias="Authorization")) -> U
|
||||
JWT_SECRET,
|
||||
algorithms=[JWT_ALGORITHM],
|
||||
audience=JWT_AUDIENCE,
|
||||
issuer=JWT_ISSUER
|
||||
issuer=JWT_ISSUER,
|
||||
)
|
||||
except PyJWTError:
|
||||
raise HTTPException(status_code=401, detail="Token is missing or invalid")
|
||||
@@ -36,4 +39,28 @@ def get_user_id_from_token(token: str = Header(..., alias="Authorization")) -> U
|
||||
if not user_id:
|
||||
raise HTTPException(status_code=401, detail="User ID not found in token")
|
||||
|
||||
return UUID(user_id)
|
||||
return UUID(user_id)
|
||||
|
||||
|
||||
def get_user_id_from_token_from_ws(token: str) -> UUID:
|
||||
if token.startswith("Bearer "):
|
||||
token = token[7:]
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
JWT_SECRET,
|
||||
algorithms=[JWT_ALGORITHM],
|
||||
audience=JWT_AUDIENCE,
|
||||
issuer=JWT_ISSUER,
|
||||
)
|
||||
except PyJWTError:
|
||||
raise WebSocketDisconnect() # token 无效就断开
|
||||
|
||||
if payload.get("token_type") != "access_token":
|
||||
raise WebSocketDisconnect()
|
||||
|
||||
user_id = payload.get("user_id")
|
||||
if not user_id:
|
||||
raise WebSocketDisconnect()
|
||||
|
||||
return UUID(user_id)
|
||||
|
||||
Reference in New Issue
Block a user