From b9b8d30ebf94e7968d281bda2ac05f3a206d76e9 Mon Sep 17 00:00:00 2001 From: BBIT-Kai <2911862937@qq.com> Date: Mon, 29 Dec 2025 16:30:36 +0800 Subject: [PATCH] =?UTF-8?q?=E5=AE=8C=E5=96=84=E7=89=A7=E5=AE=89=E4=BA=91?= =?UTF-8?q?=E5=93=A8-=E5=90=8E=E7=AB=AF?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- bbit_ai/app/agent/vehicleImageAgent.py | 79 +++++++ bbit_ai/app/app.py | 11 +- bbit_ai/app/config/emqx.py | 143 ++++++++++++ bbit_ai/app/config/minIO.py | 2 +- bbit_ai/app/config/rabbitMQ.py | 4 +- bbit_ai/app/config/redis.py | 25 ++- bbit_ai/app/config/security.py | 39 +++- bbit_ai/app/db/postgres/iot.py | 211 ++++++++++++++++++ bbit_ai/app/db/postgres/sentinel.py | 53 +++++ bbit_ai/app/db/postgres/system.py | 48 +++- bbit_ai/app/db/postgres/ws_manager.py | 41 +++- bbit_ai/app/models/IotDeviceCommandRequest.py | 8 + bbit_ai/app/models/MqttTopic.py | 107 +++++++++ bbit_ai/app/models/SentinelRecordRequest.py | 10 + bbit_ai/app/routers/Iot.py | 143 +++++++++++- bbit_ai/app/routers/Public.py | 12 + bbit_ai/app/routers/RabbitMQ.py | 11 +- bbit_ai/app/routers/Service.py | 4 +- bbit_ai/app/routers/System.py | 3 + bbit_ai/app/routers/Vision.py | 2 +- bbit_ai/app/routers/WS.py | 32 ++- bbit_ai/app/service/RabbitMQ.py | 98 ++++++++ bbit_ai/app/service/vision.py | 29 +++ 23 files changed, 1074 insertions(+), 41 deletions(-) create mode 100644 bbit_ai/app/agent/vehicleImageAgent.py create mode 100644 bbit_ai/app/config/emqx.py create mode 100644 bbit_ai/app/models/IotDeviceCommandRequest.py create mode 100644 bbit_ai/app/models/MqttTopic.py create mode 100644 bbit_ai/app/models/SentinelRecordRequest.py create mode 100644 bbit_ai/app/service/RabbitMQ.py diff --git a/bbit_ai/app/agent/vehicleImageAgent.py b/bbit_ai/app/agent/vehicleImageAgent.py new file mode 100644 index 0000000..8fbaa69 --- /dev/null +++ b/bbit_ai/app/agent/vehicleImageAgent.py @@ -0,0 +1,79 @@ +import json +import re +from typing import TypedDict + +from langchain_core.messages import HumanMessage +from langgraph.graph import StateGraph, END + +from config.llm import llmVision + + +# -------- 定义状态 -------- +class State(TypedDict): + image_url: str # 图像 + content: str # 最终内容 + + +def send_analyze(state: State, prompt_text: str): + messages = [ + HumanMessage( + content=[ + {"type": "text", "text": prompt_text}, + {"type": "image_url", "image_url": {"url": state["image_url"]}}, + ] + ) + ] + return llmVision.invoke(messages).content + + +def analysis(state: State): + state["content"] = send_analyze( + state, + """ +提示词示例 +你是一个图像分析助手。现在给你一张车的侧身照片,请你从图中分析车上运输的牲畜种类。 + +要求: +1. 牲畜种类可能是:牛、羊、猪、鸡、鸭、鹅。 +2. 如果图中无法判断牲畜类型,请在备注字段 remark 中写明“无法识别”或你观察到的情况。 +3. 不允许输出多余文字,直接返回 JSON。 + +JSON 示例格式: + { + "livestock_type": "<牲畜种类>", // 如果能识别就填牛/羊/猪/鸡/鸭/鹅 + "remark": "<备注>" // 如果无法识别,写明原因;否则可留空 + } +请确保输出的 JSON 可以被严格解析。 +""", + ) + return state + + +# ------------------------------------------------------------------------ 构建有向图 -------- +workflow = StateGraph(State) +# 必须先从 START 指向 analysis +workflow.add_node("analysis", analysis) +workflow.set_entry_point("analysis") +workflow.add_edge("analysis", END) +graph = workflow.compile() + + +# 执行函数 + + +async def get_vehicle_response(image_url: str): + final_state = graph.invoke( + { + "image_url": image_url, + } + ) + # 去掉 ```json 和 ``` 包裹 + content_str = re.sub(r"^```json\s*|\s*```$", "", final_state["content"].strip()) + # 把 JSON 字符串转为字典 + try: + content_dict = json.loads(content_str) + except json.JSONDecodeError: + print("JSON解析失败") + content_dict = {} + + return content_dict diff --git a/bbit_ai/app/app.py b/bbit_ai/app/app.py index 8e0ff00..538ad5c 100644 --- a/bbit_ai/app/app.py +++ b/bbit_ai/app/app.py @@ -4,6 +4,7 @@ from fastapi import FastAPI from fastapi.middleware.cors import CORSMiddleware from uvicorn import Config, Server +from config.emqx import mqtt_client_async from config.yolo import YOLOSingleton from routers.Bot import botRouter from routers.Chat import chatRouter @@ -18,6 +19,7 @@ from routers.Service import serviceRouter from routers.System import systemRouter from routers.Vision import visionRouter from routers.WS import iot_ws_router +from service.RabbitMQ import sentinel_pull_analysis_async async def ai_lab(): @@ -63,6 +65,11 @@ async def main(): YOLOSingleton.init_model() # 主干AI实验室FastAPI服务 task_api = asyncio.create_task(ai_lab()) + # RabbitMQ服务 + task_mq = asyncio.create_task(sentinel_pull_analysis_async()) + # 等 HTTP 服务启动后再启动 MQTT + task_mqtt = asyncio.create_task(mqtt_client_async()) + await asyncio.gather(task_api, task_mq, task_mqtt) # MCP服务-ailab # endpoint_url_ai_lab = "wss://ai.ronsunny.cn:8090/aimcp/mcp_endpoint/mcp/?token=TsSP9lBq6Oa1WMkachHoS2TtNt4GKV/Gli24pk5Rjpk%3D" @@ -73,11 +80,7 @@ async def main(): # endpoint_url_ql = "wss://ai.ronsunny.cn:8090/aimcp/mcp_endpoint/mcp/?token=8ZmCzp7FzsbxwHOg2%2FvBQkxrC3QWJiI%2B4iTfouExinjcT8ZgLwQfFUtgcMInI7St" # task_mcp2 = asyncio.create_task(init_mcp_server(endpoint_url_ql)) - # RabbitMQ服务 - # task_mq = asyncio.create_task(mq_pull_analysis_async()) - # await asyncio.gather(task_api, task_mcp1, task_mcp2, task_mq) - await asyncio.gather(task_api) if __name__ == "__main__": diff --git a/bbit_ai/app/config/emqx.py b/bbit_ai/app/config/emqx.py new file mode 100644 index 0000000..f9efcdc --- /dev/null +++ b/bbit_ai/app/config/emqx.py @@ -0,0 +1,143 @@ +import json +import os +import socket +import ssl +import sys +import uuid + +from aiomqtt import Client + +from config.redis import redis_client +from models.MqttTopic import MqttTopic + +# ================= 配置区域 ================= +MQTT_BROKER = "ai.ronsunny.cn" +MQTT_PORT = 8093 +MQTT_PASSWORD = "123456" +TLS_CONTEXT = ssl.create_default_context() + +# 默认连接后要订阅的 topic 配置 +DEFAULT_SUBSCRIPTIONS = [ + MqttTopic.from_parts( + project=None, + domain="status", + device_type="edge", + device_id=None, + resource="info", + ) +] +# =========================================== + +DEVICE_ID = None +MQTT_CLIENT: Client | None = None # 全局客户端 + +# Windows 平台下切换到 SelectorEventLoop +if sys.platform.lower() == "win32" or os.name.lower() == "nt": + from asyncio import set_event_loop_policy, WindowsSelectorEventLoopPolicy + + set_event_loop_policy(WindowsSelectorEventLoopPolicy()) + + +def get_device_id_simple(): + try: + with open("/etc/machine-id") as f: + mid = f.read().strip() + if mid: + return mid + except Exception: + pass + hostname = socket.gethostname() + mac = uuid.getnode() + mac_str = ":".join(f"{(mac >> ele) & 0xff:02x}" for ele in range(40, -1, -8)) + return f"{hostname}|{mac_str}" + + +# todo 这里需要订阅状态信息 设备发送信息 这里回复 vue前端发送指令 后端发送指令 设备接收指令 +# ------------------ MQTT 封装 ------------------ + + +async def mqtt_publish( + project: str, + domain: str, + device_type: str, + device_id: str, + resource: str, + payload: str, + qos: int = 1, +): + """发布消息(使用全局客户端)""" + if not MQTT_CLIENT: + raise RuntimeError("MQTT client is not initialized") + topic = f"{project}/{domain}/{device_type}/{device_id}/{resource}" + await MQTT_CLIENT.publish(topic, payload, qos=qos) + print(f"Published to {topic}: {payload}") + + +async def mqtt_publish_multiple( + targets: list[dict], resource: str, payload: str, qos: int = 1 +): + """群发消息""" + for target in targets: + await mqtt_publish( + domain=target["domain"], + device_type=target["device_type"], + device_id=target["device_id"], + resource=resource, + payload=payload, + qos=qos, + ) + + +async def _mqtt_handle_messages(): + """后台循环处理消息""" + if not MQTT_CLIENT: + return + async for message in MQTT_CLIENT.messages: + topic = MqttTopic(message.topic) + print("收到消息:" + str(topic)) + + # 处理基础状态信息 + if topic.domain == "status" and topic.resource == "info": + payload = json.loads(message.payload.decode()) + redis_client.set_device_info(topic.device_id, payload) + + +async def mqtt_client_async(): + global DEVICE_ID, MQTT_CLIENT + DEVICE_ID = get_device_id_simple() + print("服务端EMQX账号:", DEVICE_ID) + async with Client( + MQTT_BROKER, + port=MQTT_PORT, + username=DEVICE_ID, + password=MQTT_PASSWORD, + tls_context=TLS_CONTEXT, + identifier=DEVICE_ID, + ) as client: + MQTT_CLIENT = client # 保存全局客户端 + print("MQTT client connected") + + # 订阅默认 topic + for topic in DEFAULT_SUBSCRIPTIONS: + await MQTT_CLIENT.subscribe(topic.to_topic()) + print(f"Subscribed to default topic: {topic.to_topic()}") + + # 启动消息处理循环 + await _mqtt_handle_messages() + + +# ------------------ 示例主程序 ------------------ + + +# async def main(): +# await mqtt_client_async() +# +# # 示例:发布消息 +# await mqtt_publish("status", "edge", DEVICE_ID, "heartbeat", '{"alive":true}') +# +# # 示例:群发 +# targets = [ +# {"domain": "cmd", "device_type": "edge", "device_id": "edge01"}, +# {"domain": "cmd", "device_type": "edge", "device_id": "edge02"}, +# ] +# await mqtt_publish_multiple(targets, "restart", '{"action":"restart"}') diff --git a/bbit_ai/app/config/minIO.py b/bbit_ai/app/config/minIO.py index 16e61d6..3b12aa7 100644 --- a/bbit_ai/app/config/minIO.py +++ b/bbit_ai/app/config/minIO.py @@ -22,7 +22,7 @@ def push_file(bucket_name, object_name, file_bytes, contents, content_type): ) -def get_upload_token(user_id, bucket_name, object_name, xpires=timedelta(minutes=15)): +def get_upload_token(bucket_name, object_name, xpires=timedelta(minutes=15)): return minio_client.presigned_put_object( bucket_name=bucket_name, object_name=object_name, expires=xpires ) diff --git a/bbit_ai/app/config/rabbitMQ.py b/bbit_ai/app/config/rabbitMQ.py index 51c911d..33fa827 100644 --- a/bbit_ai/app/config/rabbitMQ.py +++ b/bbit_ai/app/config/rabbitMQ.py @@ -1,7 +1,9 @@ from utils.GlobalVariable import LOCAL_IP RABBIT_HOST = LOCAL_IP -RABBIT_VHOST = "bbit_ai" RABBIT_USER = "ai_lab" RABBIT_PASSWORD = "123456" QUEUE_NAME = "analysis_queue" +RABBIT_VHOST = "bbit_ai" + +SENTINEL_VHOST = "sentinel" diff --git a/bbit_ai/app/config/redis.py b/bbit_ai/app/config/redis.py index 1a5cc19..11fa8d3 100644 --- a/bbit_ai/app/config/redis.py +++ b/bbit_ai/app/config/redis.py @@ -1,8 +1,10 @@ import redis -class RedisClient: +# ---------------- Redis Client ---------------- + +class RedisClient: def __init__(self, config_path="config.yaml"): self.redis = redis.Redis( "10.10.12.101", @@ -22,3 +24,24 @@ class RedisClient: def is_device_online(self, device_id: str) -> bool: key = f"device:online:{device_id}" return self.redis.exists(key) == 1 + + def set_device_info(self, device_id: str, info: dict): + """ + 存储完整设备信息到 redis hash + 将 bool 转为 int + """ + key = f"device:info:{device_id}" + + # 转换 bool 为 int + sanitized_info = { + k: (int(v) if isinstance(v, bool) else v) for k, v in info.items() + } + + self.redis.hmset(key, sanitized_info) + + def get_device_info(self, device_id: str) -> dict: + key = f"device:info:{device_id}" + return self.redis.hgetall(key) + + +redis_client = RedisClient() diff --git a/bbit_ai/app/config/security.py b/bbit_ai/app/config/security.py index 1770c43..32c4f1c 100644 --- a/bbit_ai/app/config/security.py +++ b/bbit_ai/app/config/security.py @@ -1,13 +1,16 @@ -import jwt -from jwt import PyJWTError from uuid import UUID -from fastapi import Header, HTTPException, Depends -JWT_SECRET = "secret_jwt" +import jwt +from fastapi import Header, HTTPException +from jwt import PyJWTError +from starlette.websockets import WebSocketDisconnect + +JWT_SECRET = "secret_jwt" JWT_ALGORITHM = "HS256" JWT_AUDIENCE = "snowflake-ink" JWT_ISSUER = "https://snowflake.ink/" + def get_user_id_from_token(token: str = Header(..., alias="Authorization")) -> UUID: """ 从 Authorization 头解析 token,并返回 user_id @@ -24,7 +27,7 @@ def get_user_id_from_token(token: str = Header(..., alias="Authorization")) -> U JWT_SECRET, algorithms=[JWT_ALGORITHM], audience=JWT_AUDIENCE, - issuer=JWT_ISSUER + issuer=JWT_ISSUER, ) except PyJWTError: raise HTTPException(status_code=401, detail="Token is missing or invalid") @@ -36,4 +39,28 @@ def get_user_id_from_token(token: str = Header(..., alias="Authorization")) -> U if not user_id: raise HTTPException(status_code=401, detail="User ID not found in token") - return UUID(user_id) \ No newline at end of file + return UUID(user_id) + + +def get_user_id_from_token_from_ws(token: str) -> UUID: + if token.startswith("Bearer "): + token = token[7:] + try: + payload = jwt.decode( + token, + JWT_SECRET, + algorithms=[JWT_ALGORITHM], + audience=JWT_AUDIENCE, + issuer=JWT_ISSUER, + ) + except PyJWTError: + raise WebSocketDisconnect() # token 无效就断开 + + if payload.get("token_type") != "access_token": + raise WebSocketDisconnect() + + user_id = payload.get("user_id") + if not user_id: + raise WebSocketDisconnect() + + return UUID(user_id) diff --git a/bbit_ai/app/db/postgres/iot.py b/bbit_ai/app/db/postgres/iot.py index 2d723ed..55cbeca 100644 --- a/bbit_ai/app/db/postgres/iot.py +++ b/bbit_ai/app/db/postgres/iot.py @@ -1,5 +1,6 @@ from hashlib import sha256 +from config.minIO import get_temp_url from config.pgDb import pg_pool from utils.MyUtils import format_datetime, is_valid_uuid @@ -201,3 +202,213 @@ def delete_device_db(id: str) -> int: cursor.execute("DELETE FROM iot_users WHERE id=%s;", (id,)) conn.commit() return cursor.rowcount + + +def delete_update_db(id: str) -> int: + with pg_pool.getConn() as conn: + with conn.cursor() as cursor: + cursor.execute( + "DELETE FROM iot_update WHERE id = %s;", + (id,), + ) + conn.commit() + return cursor.rowcount + + +def get_update_list_db_page( + page: int, + page_size: int, + id=None, + code=None, + dept_id=None, + startTime=None, + endTime=None, +): + offset = (page - 1) * page_size + + conditions = [] + params = [] + + if id is not None: + conditions.append("u.id::text LIKE %s") + params.append(f"%{id}%") + + # ---- 版本 / 升级代码 ---- + if code is not None: + conditions.append("u.code = %s") + params.append(code) + + # ---- 部门 ---- + if dept_id and is_valid_uuid(dept_id): + conditions.append("u.dept_id = %s") + params.append(dept_id) + + # ---- 时间过滤 ---- + if startTime: + conditions.append("u.created_at >= %s") + params.append(startTime) + + if endTime: + conditions.append("u.created_at <= %s") + params.append(endTime) + + where_clause = " WHERE " + " AND ".join(conditions) if conditions else "" + + with pg_pool.getConn() as conn: + with conn.cursor() as cursor: + + # ---- 总数 ---- + count_sql = f""" + SELECT COUNT(*) + FROM iot_update u + {where_clause}; + """ + cursor.execute(count_sql, params) + total = cursor.fetchone()[0] + + # ---- 列表 ---- + list_sql = f""" + SELECT + u.id, + u.code, + u.dept_id, + sd.name AS dept_name, + u.remark, + u.oss, + u.size, + u.created_at + FROM iot_update u + LEFT JOIN sys_dept sd ON u.dept_id = sd.id + {where_clause} + ORDER BY u.created_at DESC + LIMIT %s OFFSET %s; + """ + + cursor.execute(list_sql, params + [page_size, offset]) + rows = cursor.fetchall() + + result = [] + for r in rows: + ( + update_id, + code, + dept_id, + dept_name, + remark, + oss, + size, + created_at, + ) = r + + result.append( + { + "id": update_id, + "code": code, + "dept_id": dept_id, + "dept_name": dept_name, + "remark": remark, + "oss_url": oss, + "size": size, + "created_at": format_datetime(created_at), + } + ) + + return result, total + + +def insert_update(data: dict) -> str: + with pg_pool.getConn() as conn: + with conn.cursor() as cursor: + cursor.execute( + """ + INSERT INTO iot_update + (code, dept_id, remark, oss, size) + VALUES + (%s, %s, %s, %s, %s) + RETURNING id; + """, + ( + data.get("code"), + data.get("dept_id"), + data.get("remark"), + data.get("uploadId"), + data.get("size"), + ), + ) + update_id = cursor.fetchone()[0] + conn.commit() + return update_id + + +def get_update_package(device_id: str | None = None): + """ + 根据设备 ID 获取所属组织最新版本的更新包信息 + 返回示例: + { + "version": 1001, + "url": "https://xxx", + "notes": "更新内容描述" + } + """ + if not device_id: + return None + + sql_get_dept = """ + SELECT dept_id + FROM iot_users + WHERE name = %s + LIMIT 1 + """ + + sql_get_package = """ + SELECT code, oss, remark + FROM iot_update + WHERE dept_id = %s + ORDER BY code DESC + LIMIT 1 + """ + + with pg_pool.getConn() as conn: + with conn.cursor() as cursor: + # 1. 查询设备所属组织 + cursor.execute(sql_get_dept, (device_id,)) + row = cursor.fetchone() + if not row: + return None + + dept_id = row[0] + + # 2. 查询该组织最新更新包 + cursor.execute(sql_get_package, (dept_id,)) + row = cursor.fetchone() + if not row: + return None + + version, oss_path, content = row + return { + "version": version, + "url": get_temp_url("iot-update", oss_path), + "notes": content, + } + + +def getMaxCodeByDeptId(dept_id: str | None = None) -> int: + """ + 根据组织ID获取 iot_update_package 最大 code,并在结果上加 1 + 返回整数,如果没有记录则返回 1 + """ + if not dept_id: + return 0 # dept_id 为空直接返回初始版本号 1 + + sql = """ + SELECT MAX(code) + FROM iot_update + WHERE dept_id = %s + """ + + with pg_pool.getConn() as conn: + with conn.cursor() as cursor: + cursor.execute(sql, (dept_id,)) + row = cursor.fetchone() + max_code = row[0] if row and row[0] is not None else 0 + return max_code diff --git a/bbit_ai/app/db/postgres/sentinel.py b/bbit_ai/app/db/postgres/sentinel.py index 67e7f19..5074463 100644 --- a/bbit_ai/app/db/postgres/sentinel.py +++ b/bbit_ai/app/db/postgres/sentinel.py @@ -1,5 +1,6 @@ from config.minIO import get_temp_url_dict from config.pgDb import pg_pool +from models.SentinelRecordRequest import SentinelRecordRequest from utils.MyUtils import format_datetime @@ -259,3 +260,55 @@ def delete_sentinel_record_db(id: str) -> int: cursor.execute("DELETE FROM sentinel_records WHERE id=%s;", (id,)) conn.commit() return cursor.rowcount + + +def saveSentinelRecord(data: SentinelRecordRequest) -> str: + sql = """ + INSERT INTO sentinel_records ( + license_plate, + license_plate_image, + vehicle_type, + vehicle_image + ) + VALUES (%s, %s, %s, %s) + RETURNING id; + """ + + with pg_pool.getConn() as conn: + with conn.cursor() as cursor: + cursor.execute( + sql, + ( + data.LicensePlate, + data.LicensePlateImage, + data.VehicleType, + data.VehicleImage, + ), + ) + new_id = cursor.fetchone()[0] + conn.commit() + return str(new_id) + + +def update_sentinel_record( + id: str, livestock_type: str, remark: str, dept_id: str +) -> bool: + """ + 根据 id 更新 sentinel_records 表中的 livestock_type 和 dept_id + """ + sql = """ + UPDATE sentinel_records + SET livestock_type = %s, + remark = %s, + dept_id = %s, + updated_at = now() + WHERE id = %s + RETURNING id; + """ + + with pg_pool.getConn() as conn: + with conn.cursor() as cursor: + cursor.execute(sql, (livestock_type, remark, dept_id, id)) + record = cursor.fetchone() + conn.commit() + return record is not None diff --git a/bbit_ai/app/db/postgres/system.py b/bbit_ai/app/db/postgres/system.py index 5901580..cdf0624 100644 --- a/bbit_ai/app/db/postgres/system.py +++ b/bbit_ai/app/db/postgres/system.py @@ -942,11 +942,55 @@ def get_dept_ids_by_user_id(user_id: UUID) -> list: return dept_ids -def get_dept_id_by_user_id(user_id: UUID) -> list: - # 第一步:通过 user_id 查找其所属的 dept_id +def get_dept_id_by_user_id(user_id: str) -> str: + # 通过 user_id 查找其所属的 dept_id with pg_pool.getConn() as conn: with conn.cursor() as cursor: cursor.execute("SELECT dept_id FROM users WHERE id = %s", (user_id,)) dept_id = cursor.fetchone() dept_id = dept_id[0] + return str(dept_id) + + +def get_dept_id_by_iot_user_name(user_id: UUID) -> str: + # 通过 iot_user_id 查找其所属的 dept_id + with pg_pool.getConn() as conn: + with conn.cursor() as cursor: + cursor.execute("SELECT dept_id FROM iot_users WHERE name = %s", (user_id,)) + dept_id = cursor.fetchone() + dept_id = dept_id[0] return dept_id + + +from typing import List + + +def get_dept_ids_by_dept_id(dept_id: str) -> List[str]: + """ + 获取当前部门 ID 以及其所有父部门 ID(递归向上) + 返回顺序:从当前部门一直到最顶层父部门 + """ + with pg_pool.getConn() as conn: + with conn.cursor() as cursor: + cursor.execute( + """ + WITH RECURSIVE dept_tree AS ( + -- 起点:当前部门 + SELECT id, parent_id + FROM sys_dept + WHERE id = %s + + UNION ALL + + -- 向上递归找父部门 + SELECT d.id, d.parent_id + FROM sys_dept d + INNER JOIN dept_tree dt ON d.id = dt.parent_id + ) + SELECT id FROM dept_tree; + """, + (dept_id,), + ) + + rows = cursor.fetchall() + return [str(row[0]) for row in rows] diff --git a/bbit_ai/app/db/postgres/ws_manager.py b/bbit_ai/app/db/postgres/ws_manager.py index 9d5bc5c..16d80b8 100644 --- a/bbit_ai/app/db/postgres/ws_manager.py +++ b/bbit_ai/app/db/postgres/ws_manager.py @@ -1,25 +1,50 @@ import asyncio from typing import List +from uuid import UUID from fastapi import WebSocket class ConnectionManager: def __init__(self): - self.active_connections: List[WebSocket] = [] + self.active_connections: List[dict] = [] # 保存 websocket 和用户信息 self.lock = asyncio.Lock() - async def connect(self, websocket: WebSocket): + # proj_id:0:在线状态 1:畜牧车辆进入 + async def connect( + self, websocket: WebSocket, user_id: UUID, dept_id: str, proj_id: int + ): await websocket.accept() async with self.lock: - self.active_connections.append(websocket) + self.active_connections.append( + { + "ws": websocket, + "user_id": user_id, + "dept_id": dept_id, + "proj_id": proj_id, + } + ) async def disconnect(self, websocket: WebSocket): async with self.lock: - if websocket in self.active_connections: - self.active_connections.remove(websocket) + self.active_connections = [ + conn for conn in self.active_connections if conn["ws"] != websocket + ] - async def broadcast(self, message: dict): + async def noticeOnlineStatus(self, message: dict): async with self.lock: - for ws in self.active_connections: - await ws.send_json(message) + for conn in self.active_connections: + if conn["proj_id"] == 0: + await conn["ws"].send_json(message) + + async def noticeSentinel( + self, message: dict, target_departments: List[UUID] = None + ): + """ + target_departments: 指定哪些部门能收到消息 + """ + async with self.lock: + for conn in self.active_connections: + if target_departments: + if conn["proj_id"] == 1 and conn["dept_id"] in target_departments: + await conn["ws"].send_json(message) diff --git a/bbit_ai/app/models/IotDeviceCommandRequest.py b/bbit_ai/app/models/IotDeviceCommandRequest.py new file mode 100644 index 0000000..c2d438f --- /dev/null +++ b/bbit_ai/app/models/IotDeviceCommandRequest.py @@ -0,0 +1,8 @@ +from pydantic import BaseModel + + +class IotDeviceCommandRequest(BaseModel): + id: str | None = None + command: str | None = None + project: str | None = None + device_type: str | None = None diff --git a/bbit_ai/app/models/MqttTopic.py b/bbit_ai/app/models/MqttTopic.py new file mode 100644 index 0000000..1e30c35 --- /dev/null +++ b/bbit_ai/app/models/MqttTopic.py @@ -0,0 +1,107 @@ +from typing import Optional + + +class MqttTopic: + """ + 封装 MQTT topic,根据规则: + project/domain/deviceType/deviceId/resource + """ + + LEVELS = 5 + + def __init__(self, topic: str): + self.raw = str(topic) + parts = self.raw.split("/") + + # 不足的层级用 None 补齐,避免属性缺失 + parts += [None] * (self.LEVELS - len(parts)) + + self.project: Optional[str] = parts[0] + self.domain: Optional[str] = parts[1] + self.device_type: Optional[str] = parts[2] + self.device_id: Optional[str] = parts[3] + self.resource: Optional[str] = parts[4] + + @classmethod + def from_parts( + cls, + project: Optional[str] = None, + domain: Optional[str] = None, + device_type: Optional[str] = None, + device_id: Optional[str] = None, + resource: Optional[str] = None, + ) -> "MqttTopic": + """ + 通过结构化参数构造 topic + None -> '+' + """ + + def _v(v: Optional[str]) -> str: + return "+" if v is None else str(v) + + topic = "/".join( + map( + _v, + [ + project, + domain, + device_type, + device_id, + resource, + ], + ) + ) + return cls(topic) + + def to_topic(self) -> str: + """ + 根据当前字段生成 topic(允许 '+') + """ + + def _v(v: Optional[str]) -> str: + return "+" if v is None else v + + return "/".join( + map( + _v, + [ + self.project, + self.domain, + self.device_type, + self.device_id, + self.resource, + ], + ) + ) + + def build(self) -> str: + """ + 生成严格 topic(不允许 None / '+') + 用于 publish 场景 + """ + parts = [ + self.project, + self.domain, + self.device_type, + self.device_id, + self.resource, + ] + + if any(p in (None, "+") for p in parts): + raise ValueError( + f"Cannot build strict topic, wildcard exists: {self.to_topic()}" + ) + + return "/".join(parts) + + def is_wildcard(self) -> bool: + return "+" in self.to_topic() or "#" in self.to_topic() + + def __repr__(self): + return f"" + + def is_status(self) -> bool: + return self.domain == "status" + + def is_cmd(self) -> bool: + return self.domain == "cmd" diff --git a/bbit_ai/app/models/SentinelRecordRequest.py b/bbit_ai/app/models/SentinelRecordRequest.py new file mode 100644 index 0000000..8cc941b --- /dev/null +++ b/bbit_ai/app/models/SentinelRecordRequest.py @@ -0,0 +1,10 @@ +from pydantic import BaseModel + + +class SentinelRecordRequest(BaseModel): + Id: str | None = None + DeviceId: str + LicensePlate: str | None = None + LicensePlateImage: str | None = None + VehicleType: str | None = None + VehicleImage: str | None = None diff --git a/bbit_ai/app/routers/Iot.py b/bbit_ai/app/routers/Iot.py index e7d5142..b553226 100644 --- a/bbit_ai/app/routers/Iot.py +++ b/bbit_ai/app/routers/Iot.py @@ -1,17 +1,19 @@ +import uuid from uuid import UUID from fastapi import APIRouter from fastapi import Depends -from config.redis import RedisClient +from config.emqx import mqtt_publish +from config.minIO import get_upload_token +from config.redis import redis_client from db.postgres.iot import * from models.BaseResponse import BaseResponse from models.EMQXWebhook import EMQXWebhook +from models.IotDeviceCommandRequest import IotDeviceCommandRequest from routers.WS import ws_manager iot_router = APIRouter() -redis_client = RedisClient() - from config.security import get_user_id_from_token # -------------------- 设备接口 -------------------- @@ -25,14 +27,14 @@ async def emqx_webhook(data: EMQXWebhook): if event == "client.connected": redis_client.set_online(device_id) - await ws_manager.broadcast({"deviceId": device_id, "online": True}) + await ws_manager.noticeOnlineStatus({"deviceId": device_id, "online": True}) print(f"[ONLINE] {device_id}") elif event == "client.disconnected": redis_client.set_offline(device_id) - await ws_manager.broadcast({"deviceId": device_id, "online": False}) + await ws_manager.noticeOnlineStatus({"deviceId": device_id, "online": False}) print(f"[OFFLINE] {device_id}") @@ -68,6 +70,19 @@ async def get_device_list( device_id = d["name"] # 账号 d["online"] = redis_client.is_device_online(device_id) == 1 + info_json = redis_client.get_device_info(device_id) + d["version"] = info_json.get("version", "") + d["ip"] = info_json.get("ip", "") + d["hostname"] = info_json.get("hostname", "") + d["mac"] = info_json.get("mac", "") + d["os"] = info_json.get("os", "") + d["cpu"] = info_json.get("cpu", "") + d["memory_total"] = info_json.get("memory_total", "") + d["disk_total"] = info_json.get("disk_total", "") + d["last_seen"] = info_json.get("last_seen", "") + d["project"] = info_json.get("project", "") + d["device_type"] = info_json.get("deviceType", "") + return BaseResponse(data={"list": devices, "total": total}) @@ -121,3 +136,121 @@ async def delete_device( if deleted == 0: return BaseResponse(status=False, message="设备不存在", data=None) return BaseResponse(data=True) + + +@iot_router.get("/common/update/list") +async def get_update_list( + page: int = 1, + pageSize: int = 10, + id: str | None = None, + code: str | None = None, + dept_id: str | None = None, + startTime: str | None = None, + endTime: str | None = None, + user_id: UUID = Depends(get_user_id_from_token), +): + if not user_id: + return {"error": "userId is required"} + if code == "" or code is None: + code = None + else: + code = int(code) + + updates, total = get_update_list_db_page( + page, pageSize, id, code, dept_id, startTime, endTime + ) + + return BaseResponse(data={"list": updates, "total": total}) + + +@iot_router.post("/common/update") +async def create_update(data: dict, user_id: UUID = Depends(get_user_id_from_token)): + if not user_id: + return {"error": "userId is required"} + + dept_id = data.get("dept_id") + if not dept_id: + return {"error": "dept_id is required"} + + # 前端传来的版本号 + try: + new_code = int(data.get("code", 0)) + except (TypeError, ValueError): + return BaseResponse( + status=False, + message="无效的版本号", + data=None, + ) + + # 获取该组织当前最大版本号 + max_code = getMaxCodeByDeptId(dept_id) + + if new_code <= max_code: + return BaseResponse( + status=False, + message=f"新版本号必须大于当前最大版本号 {max_code}", + data=None, + ) + + # 插入数据库 + new_id = insert_update(data) + return BaseResponse(data={"id": new_id}) + + +@iot_router.delete("/common/update/{id}") +async def delete_update( + id: str, + user_id: UUID = Depends(get_user_id_from_token), +): + if not user_id: + return {"error": "userId is required"} + + deleted = delete_update_db(id) + if deleted == 0: + return BaseResponse(status=False, message="更新记录不存在", data=None) + + return BaseResponse(data=True) + + +@iot_router.get("/common/update/getUploadUrl") +def getUploadUrl( + user_id: UUID = Depends(get_user_id_from_token), +): + # 生成唯一文件名,避免覆盖 + object_name = f"{uuid.uuid4()}" + return BaseResponse( + data={ + "uploadUrl": get_upload_token("iot-update", object_name), + "id": object_name, + } + ) + + +@iot_router.get("/common/update/getMaxCodeByDeptId") +def updateGetMaxCodeByDeptId( + user_id: UUID = Depends(get_user_id_from_token), + dept_id: str | None = None, +): + # 生成唯一文件名,避免覆盖 + return BaseResponse(data=getMaxCodeByDeptId(dept_id)) + + +@iot_router.get("/common/update/check") +def getUploadUrl( + deviceID: str | None = None, +): + # 生成唯一文件名,避免覆盖 + return BaseResponse(data=get_update_package(deviceID)) + + +@iot_router.post("/common/device/command") +async def command( + data: IotDeviceCommandRequest, user_id: UUID = Depends(get_user_id_from_token) +): + if not user_id: + return {"error": "userId is required"} + + await mqtt_publish( + data.project, "cmd", data.device_type, data.id, data.command, "{}" + ) + return BaseResponse(data=None) diff --git a/bbit_ai/app/routers/Public.py b/bbit_ai/app/routers/Public.py index 81606ef..d39cb36 100644 --- a/bbit_ai/app/routers/Public.py +++ b/bbit_ai/app/routers/Public.py @@ -3,9 +3,12 @@ import base64 from fastapi import APIRouter from config.app import F8_SERVER_USER_ID +from db.postgres.sentinel import saveSentinelRecord from models.BaseResponse import BaseResponse from models.F8ImageRequest import F8ImageRequest from models.F8ImageRequestV2 import F8ImageRequestV2 +from models.SentinelRecordRequest import SentinelRecordRequest +from service.RabbitMQ import sentinel_new_analysis from service.vision import ( process_ticket_image, process_license_image, @@ -78,3 +81,12 @@ async def recognize_silkworm_cocoon(data: F8ImageRequest): return BaseResponse(data=json_data) except Exception as e: return BaseResponse(status=False, message=f"解析失败: {str(e)}", data=None) + + +@publicRouter.post("/sentinel-record-analytics") +async def delete_sentinel_record(data: SentinelRecordRequest): + # 保存部分数据到数据库 + data.Id = saveSentinelRecord(data) + # 发送请求给RabbitMQ + res = await sentinel_new_analysis(data) + return BaseResponse(data=res) diff --git a/bbit_ai/app/routers/RabbitMQ.py b/bbit_ai/app/routers/RabbitMQ.py index 4e2fc02..7cad8ca 100644 --- a/bbit_ai/app/routers/RabbitMQ.py +++ b/bbit_ai/app/routers/RabbitMQ.py @@ -1,12 +1,9 @@ from fastapi import APIRouter -from models.AnalysisRequest import AnalysisRequest -from service.Analyze import mq_new_analysis - rqRouter = APIRouter() -@rqRouter.post("/analyze") -def send_analysis_request(req: AnalysisRequest): - mq_new_analysis(req) - return {"status": "queued"} +# @rqRouter.post("/analyze") +# def send_analysis_request(req: AnalysisRequest): +# mq_new_analysis(req) +# return {"status": "queued"} diff --git a/bbit_ai/app/routers/Service.py b/bbit_ai/app/routers/Service.py index 58b368f..9b9bcc2 100644 --- a/bbit_ai/app/routers/Service.py +++ b/bbit_ai/app/routers/Service.py @@ -17,7 +17,7 @@ serviceRouter = APIRouter() # 对话列表 @serviceRouter.get("/sessionsForService") -def getSessions(user_id: UUID = Depends(get_user_id_from_token)): +async def getSessions(user_id: UUID = Depends(get_user_id_from_token)): if not user_id: return {"error": "userId is required"} return BaseResponse(data=pg.get_sessions(user_id, "service")) @@ -25,7 +25,7 @@ def getSessions(user_id: UUID = Depends(get_user_id_from_token)): # 对话 @serviceRouter.post("/chatForService") -def chat(req: ChatRequest, user_id: UUID = Depends(get_user_id_from_token)): +async def chat(req: ChatRequest, user_id: UUID = Depends(get_user_id_from_token)): if not user_id: return {"error": "userId is required"} if not req.aiId: diff --git a/bbit_ai/app/routers/System.py b/bbit_ai/app/routers/System.py index 6c5dc57..582aa09 100644 --- a/bbit_ai/app/routers/System.py +++ b/bbit_ai/app/routers/System.py @@ -113,6 +113,9 @@ async def menu_list(plat_id: int, user_id: UUID = Depends(get_user_id_from_token m["createTime"] = format_datetime(m.get("created_at")) m["updateTime"] = format_datetime(m.get("updated_at")) m["children"] = [] + # 删除created_at updated_at + m.pop("createTime", None) + m.pop("updateTime", None) # 5. 构建菜单树 tree = build_menu_tree(menus) diff --git a/bbit_ai/app/routers/Vision.py b/bbit_ai/app/routers/Vision.py index 0b6ea7e..f5231f2 100644 --- a/bbit_ai/app/routers/Vision.py +++ b/bbit_ai/app/routers/Vision.py @@ -162,7 +162,7 @@ def getIVASCUploadToken( ): # 生成唯一文件名,避免覆盖 object_name = f"raw/{uuid.uuid4()}" - return BaseResponse(data=get_upload_token(user_id, "video-sca", object_name)) + return BaseResponse(data=get_upload_token("video-sca", object_name)) @visionRouter.get("/getScVideoList") diff --git a/bbit_ai/app/routers/WS.py b/bbit_ai/app/routers/WS.py index e06cbdc..c5eb472 100644 --- a/bbit_ai/app/routers/WS.py +++ b/bbit_ai/app/routers/WS.py @@ -1,6 +1,8 @@ -from fastapi import APIRouter +from fastapi import APIRouter, Query from starlette.websockets import WebSocket, WebSocketDisconnect +from config.security import get_user_id_from_token_from_ws +from db.postgres import get_dept_id_by_user_id from db.postgres.ws_manager import ConnectionManager ws_manager = ConnectionManager() @@ -10,8 +12,13 @@ iot_ws_router = APIRouter() @iot_ws_router.websocket("/device-status") -async def websocket_device_status(websocket: WebSocket): - await ws_manager.connect(websocket) +async def websocket_device_status( + websocket: WebSocket, + token: str = Query(...), +): + user_id = get_user_id_from_token_from_ws(token) + dept_id = get_dept_id_by_user_id(user_id) # 查数据库或缓存 + await ws_manager.connect(websocket, user_id, dept_id, 0) print("[WS] client connected") try: @@ -21,3 +28,22 @@ async def websocket_device_status(websocket: WebSocket): except WebSocketDisconnect: await ws_manager.disconnect(websocket) print("[WS] client disconnected") + + +@iot_ws_router.websocket("/sentinel_record") +async def websocket_sentinel_record( + websocket: WebSocket, + token: str = Query(...), +): + user_id = get_user_id_from_token_from_ws(token) + dept_id = get_dept_id_by_user_id(user_id) # 查数据库或缓存 + print("user_id:", user_id) + print("dept_id:", dept_id) + print("已接入") + await ws_manager.connect(websocket, user_id, dept_id, 1) + + try: + while True: + await websocket.receive_text() + except WebSocketDisconnect: + await ws_manager.disconnect(websocket) diff --git a/bbit_ai/app/service/RabbitMQ.py b/bbit_ai/app/service/RabbitMQ.py new file mode 100644 index 0000000..69b0345 --- /dev/null +++ b/bbit_ai/app/service/RabbitMQ.py @@ -0,0 +1,98 @@ +# consumer.py +import asyncio +import json + +import aio_pika + +from config.rabbitMQ import * +from models.AnalysisRequest import AnalysisRequest +from models.SentinelRecordRequest import SentinelRecordRequest +from service.vision import process_vehicle_animal_image + + +async def mq_new_analysis_test(req: dict): + """将分析请求发送到 RabbitMQ 队列(异步版)""" + connection = await aio_pika.connect_robust( + f"amqp://{RABBIT_USER}:{RABBIT_PASSWORD}@{RABBIT_HOST}/{RABBIT_VHOST}" + ) + + async with connection: + channel = await connection.channel() + # 声明队列,确保队列存在 + queue = await channel.declare_queue(QUEUE_NAME, durable=True) + + message_body = json.dumps(req) + message = aio_pika.Message( + body=message_body.encode(), + delivery_mode=aio_pika.DeliveryMode.PERSISTENT, # 持久化 + ) + + await channel.default_exchange.publish(message, routing_key=QUEUE_NAME) + + +async def mq_pull_analysis_async_test(): + """ + 从队列拉取分析任务并处理 + process_func: 一个函数,接收 AnalysisRequest 对象处理分析逻辑 + """ + connection = await aio_pika.connect_robust( + f"amqp://{RABBIT_USER}:{RABBIT_PASSWORD}@{RABBIT_HOST}/{RABBIT_VHOST}" + ) + async with connection: + queue_name = QUEUE_NAME + channel = await connection.channel() + await channel.set_qos(prefetch_count=1) + queue = await channel.declare_queue(queue_name, durable=True) + + async with queue.iterator() as queue_iter: + async for message in queue_iter: + async with message.process(): + data = json.loads(message.body) + req = AnalysisRequest(**data) + print(f"收到任务: {req}") + await asyncio.sleep(5) # 模拟处理 + print(f"完成任务: {req}") + + +async def sentinel_new_analysis(req: SentinelRecordRequest): + """将分析请求发送到 RabbitMQ 队列(异步版)""" + connection = await aio_pika.connect_robust( + f"amqp://{RABBIT_USER}:{RABBIT_PASSWORD}@{RABBIT_HOST}/{SENTINEL_VHOST}" + ) + + async with connection: + channel = await connection.channel() + # 声明队列,确保队列存在 + queue = await channel.declare_queue(QUEUE_NAME, durable=True) + + message_body = json.dumps(req.model_dump()) + message = aio_pika.Message( + body=message_body.encode(), + delivery_mode=aio_pika.DeliveryMode.PERSISTENT, # 持久化 + ) + + await channel.default_exchange.publish(message, routing_key=QUEUE_NAME) + + +async def sentinel_pull_analysis_async(): + """ + 从队列拉取分析任务并处理 + process_func: 一个函数,接收 AnalysisRequest 对象处理分析逻辑 + """ + connection = await aio_pika.connect_robust( + f"amqp://{RABBIT_USER}:{RABBIT_PASSWORD}@{RABBIT_HOST}/{SENTINEL_VHOST}" + ) + async with connection: + queue_name = QUEUE_NAME + channel = await connection.channel() + await channel.set_qos(prefetch_count=1) + queue = await channel.declare_queue(queue_name, durable=True) + + async with queue.iterator() as queue_iter: + async for message in queue_iter: + async with message.process(): + data = json.loads(message.body) + req = SentinelRecordRequest(**data) + print(f"收到任务: {req}") + await process_vehicle_animal_image(req) # 处理 + print(f"完成任务: {req}") diff --git a/bbit_ai/app/service/vision.py b/bbit_ai/app/service/vision.py index 7ae5eb6..e03099c 100644 --- a/bbit_ai/app/service/vision.py +++ b/bbit_ai/app/service/vision.py @@ -4,10 +4,15 @@ from uuid import UUID import config.minIO as minIO import db.postgres as pg from agent.licenseImageAgent import get_license_response +from agent.vehicleImageAgent import get_vehicle_response from config.minIO import minio_client from config.yolo import YOLOSingleton +from db.postgres import get_dept_id_by_iot_user_name, get_dept_ids_by_dept_id +from db.postgres.sentinel import update_sentinel_record from llm.ticketLLM import * from llm.ticketLLMv2 import get_ticket_response_v2 +from models.SentinelRecordRequest import SentinelRecordRequest +from routers.WS import ws_manager def process_ticket_image( @@ -178,3 +183,27 @@ def process_silkworm_cocoon_image( "postprocess_time_ms": speed_json.get("postprocess"), "details": results_json.get("class_counts"), } + + +async def process_vehicle_animal_image( + data: SentinelRecordRequest, +): + # 通过设备id获得组织id + dept_id = get_dept_id_by_iot_user_name(data.DeviceId) + # 得到动物类型 + oss_url = minIO.get_temp_url("sentinel", "vehicle_image/" + data.VehicleImage) + analysis_result = await get_vehicle_response(oss_url) + livestock_type = analysis_result.get("livestock_type", "") + remark = analysis_result.get("remark", "") + + available_departments = get_dept_ids_by_dept_id(dept_id) + + await ws_manager.noticeSentinel( + { + "content": f"载有{livestock_type}的车辆即将进入关卡,请准备检查", + "type": "vehicle_alert", + }, + available_departments, + ) + # 保存到数据库 + return update_sentinel_record(data.Id, livestock_type, remark, dept_id)