# File-viewer metadata from the original paste: 104 lines, 3.5 KiB, Python
# consumer.py
import asyncio
import json
import traceback
from urllib.parse import quote

import aio_pika

from config.rabbitMQ import *
from models.SentinelRecordRequest import SentinelRecordRequest
from service.vision import (
    process_all_vehicle_animal_image,
)


class MQClient:
    """Singleton RabbitMQ client supporting both publishing and consuming.

    ``MQClient()`` always returns the same instance (see ``__new__``).  Call
    :meth:`init` before :meth:`publish` / :meth:`consume_queue`, and
    :meth:`close` to stop consumers and release the connection.
    """

    _instance = None

    def __new__(cls):
        # Create the shared instance on first use; every later call returns it.
        if cls._instance is None:
            cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # Python runs __init__ on *every* MQClient() call, even when __new__
        # hands back the cached instance.  Without this guard a second
        # MQClient() would wipe the live singleton's connection, channel and
        # consumer-task list, orphaning the open connection and tasks.
        if getattr(self, "_initialized", False):
            return
        self._initialized = True
        self._connection = None  # set by init() via aio_pika.connect_robust
        self._channel = None  # channel shared by publishing and consuming
        self._consumer_tasks = []  # tasks created by start_all_consumer()

    # ---------------- Connection setup ----------------

    async def init(self, prefetch_count: int = 10):
        """Open the connection and channel at startup; no-op if already open.

        Args:
            prefetch_count: QoS prefetch limit set on the channel. Only applied
                when a new connection is opened by this call.
        """
        if self._connection is None:
            # Percent-encode the credentials so characters such as '@', ':'
            # or '/' inside the user name or password cannot break URL parsing.
            user = quote(str(RABBIT_USER), safe="")
            password = quote(str(RABBIT_PASSWORD), safe="")
            self._connection = await aio_pika.connect_robust(
                f"amqp://{user}:{password}@{RABBIT_HOST}/{SENTINEL_VHOST}"
            )
            self._channel = await self._connection.channel()
            await self._channel.set_qos(prefetch_count=prefetch_count)

    # ---------------- Publishing ----------------

    async def publish(self, queue_name: str, message_body: str):
        """Publish ``message_body`` as a persistent message to ``queue_name``.

        The queue is declared durable first (declaration is idempotent) so the
        default exchange has a queue to route the message to.

        Raises:
            RuntimeError: if init() has not been called, or after close().
        """
        if self._channel is None:
            raise RuntimeError("MQClient 未初始化")
        await self._channel.declare_queue(queue_name, durable=True)
        message = aio_pika.Message(
            body=message_body.encode(),
            delivery_mode=aio_pika.DeliveryMode.PERSISTENT,
        )
        await self._channel.default_exchange.publish(message, routing_key=queue_name)

    async def send_all_analysis(self, req: "SentinelRecordRequest"):
        """Serialize ``req`` to JSON and publish it to the full-analysis queue."""
        # NOTE(review): json.dumps(model_dump()) fails on non-JSON-native field
        # types (datetime, UUID, ...); if the model gains such fields, consider
        # model_dump(mode="json") — confirm the model is pydantic v2 first.
        await self.publish(
            SENTINEL_ANALYSIS_ALL_QUEUE_NAME, json.dumps(req.model_dump())
        )

    # ---------------- Consuming ----------------

    async def consume_queue(self, queue_name: str, process_func):
        """Consume ``queue_name`` indefinitely, passing each message to ``process_func``.

        Args:
            queue_name: durable queue to declare and consume from.
            process_func: async callable awaited with the JSON-decoded
                message body.

        Decode/processing errors are printed and swallowed *inside*
        ``message.process()``, so the context exits cleanly and the message
        is acknowledged rather than redelivered.

        Raises:
            RuntimeError: if init() has not been called, or after close().
        """
        if self._channel is None:
            raise RuntimeError("MQClient 未初始化")

        queue = await self._channel.declare_queue(queue_name, durable=True)

        async with queue.iterator() as queue_iter:
            async for message in queue_iter:
                async with message.process():
                    try:
                        data = json.loads(message.body.decode())
                        await process_func(data)
                    except Exception as e:
                        # Best-effort: log and drop the message instead of
                        # killing the consumer loop.
                        print(f"[MQ Consume Error] {e}")
                        traceback.print_exc()

    # ---------------- Full-analysis consumer ----------------

    async def start_all_consumer(self):
        """Start a background task consuming SENTINEL_ANALYSIS_ALL_QUEUE_NAME.

        Each message is turned into a SentinelRecordRequest and handed to
        process_all_vehicle_animal_image. The task is tracked so close() can
        cancel it.
        """

        async def _process(data: dict):
            req = SentinelRecordRequest(**data)
            await process_all_vehicle_animal_image(req)
            print(f"完成全局分析任务: {req}")

        task = asyncio.create_task(
            self.consume_queue(SENTINEL_ANALYSIS_ALL_QUEUE_NAME, _process)
        )
        self._consumer_tasks.append(task)

    # ---------------- Shutdown ----------------

    async def close(self):
        """Cancel consumers, then close the channel and connection.

        Consumer tasks are cancelled and awaited before the channel closes,
        and every handle is reset so a later init() opens a fresh connection
        and publish() reports "not initialised" instead of using a dead channel.
        """
        tasks, self._consumer_tasks = self._consumer_tasks, []
        for task in tasks:
            task.cancel()
        # Let cancelled consumers unwind (exit their queue iterators) while
        # the channel is still open; return_exceptions=True absorbs the
        # CancelledError and any error a consumer already died with.
        await asyncio.gather(*tasks, return_exceptions=True)

        channel, self._channel = self._channel, None
        connection, self._connection = self._connection, None
        if channel is not None:
            await channel.close()
        if connection is not None:
            await connection.close()


# ---------------- Global singleton ----------------
# Module-level shared client; MQClient() always returns this same object.
mq_client: MQClient = MQClient()