后端更新

This commit is contained in:
BBIT-Kai
2026-03-26 17:48:20 +08:00
parent 4c2bcd7dce
commit 0c2859b0db
22 changed files with 1336 additions and 213 deletions
+8 -9
View File
@@ -28,20 +28,21 @@ def send_analyze(state: State, prompt_text: str):
def analysis(state: State):
state["content"] = send_analyze(
state,
state, # todo
"""
提示词示例
你是一个图像分析助手。现在给你一张车的侧身照片,请你从图中分析车上运输的牲畜种类。
你是一个图像分析助手。现在给你一张道路图片,你需要观察远离你的第二根车道上,画面中心的车辆,请你从分析该车辆。
要求:
1. 牲畜种类可能是:牛、羊、猪、鸡、鸭、鹅。
2. 如果图中无法判断牲畜类型,请在备注字段 remark 中写明“无法识别”或你观察到的情况。
1. have_animal 字段中填写 true
2. livestock_type 字段中填写 “货物种类”,例如 牛、羊、猪、鸡、鸭、鹅、钢管、土渣等任何你观察到的
3. remark 字段 需要你简短的描述该车辆车身情况,例如什么颜色的车身,你需要完整的尽你所能形容一下。
3. 不允许输出多余文字,直接返回 JSON。
JSON 示例格式:
{
"livestock_type": "<牲畜种类>", // 如果能识别就填牛/羊/猪/鸡/鸭/鹅
"remark": "<备注>" // 如果无法识别,写明原因;否则可留空
"have_animal": true,
"livestock_type": "“测试数据:” + 5位随机数", // 例如 测试数据34223
"remark": "<备注>" // 车身描述
}
请确保输出的 JSON 可以被严格解析。
""",
@@ -59,8 +60,6 @@ graph = workflow.compile()
# 执行函数
async def get_vehicle_response(image_url: str):
final_state = graph.invoke(
{
View File
+433
View File
@@ -0,0 +1,433 @@
import os
from dataclasses import dataclass
import cv2
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from ultralytics import YOLO
from ai.plate.plate_recognition.double_plate_split_merge import get_split_merge
from ai.plate.plate_recognition.plate_rec import get_plate_result
@dataclass
class PlateResult:
plate: str
color: str
def collect_files(root_path):
file_list = []
for root, _, files in os.walk(root_path):
for name in files:
file_list.append(os.path.join(root, name))
return sorted(file_list)
def four_point_transform(image, pts):
rect = pts.astype(np.float32)
(tl, tr, br, bl) = rect
width_a = np.sqrt(((br[0] - bl[0]) ** 2) + ((br[1] - bl[1]) ** 2))
width_b = np.sqrt(((tr[0] - tl[0]) ** 2) + ((tr[1] - tl[1]) ** 2))
max_width = max(int(width_a), int(width_b))
height_a = np.sqrt(((tr[0] - br[0]) ** 2) + ((tr[1] - br[1]) ** 2))
height_b = np.sqrt(((tl[0] - bl[0]) ** 2) + ((tl[1] - bl[1]) ** 2))
max_height = max(int(height_a), int(height_b))
dst = np.array(
[
[0, 0],
[max_width - 1, 0],
[max_width - 1, max_height - 1],
[0, max_height - 1],
],
dtype=np.float32,
)
matrix = cv2.getPerspectiveTransform(rect, dst)
return cv2.warpPerspective(image, matrix, (max_width, max_height))
def load_model(weights, device):
model = YOLO(weights)
model.to(device)
return model
def det_rec_plate(img_ori, detect_model, plate_rec_model, device, conf=0.3, iou=0.5):
result_list = []
results = detect_model(img_ori, conf=conf, iou=iou, verbose=False)
for result in results:
boxes = result.boxes
keypoints = result.keypoints
if len(boxes) == 0 or keypoints is None:
continue
kpts_xy = keypoints.xy
num_det = min(len(boxes), len(kpts_xy))
for idx in range(num_det):
box = boxes.xyxy[idx].cpu().numpy()
det_conf = float(boxes.conf[idx])
plate_type = int(boxes.cls[idx])
landmarks = kpts_xy[idx].cpu().numpy().astype(np.int64)
roi_img = four_point_transform(img_ori, landmarks)
if plate_type == 1:
roi_img = get_split_merge(roi_img)
plate_number, _, plate_color, color_conf = get_plate_result(
roi_img, device, plate_rec_model, is_color=True
)
result_list.append(
{
"plate_no": plate_number,
"plate_color": plate_color,
"rect": [int(v) for v in box],
"detect_conf": det_conf,
"landmarks": landmarks.tolist(),
"roi_height": roi_img.shape[0],
"color_conf": color_conf,
"plate_type": plate_type, # 0: 单层, 1: 双层
}
)
return result_list
def _clamp(value, low, high):
return max(low, min(high, value))
def _normalize_rect(rect):
if not rect or len(rect) < 4:
return None
x1 = int(round(float(rect[0])))
y1 = int(round(float(rect[1])))
x2 = int(round(float(rect[2])))
y2 = int(round(float(rect[3])))
left = min(x1, x2)
top = min(y1, y2)
right = max(x1, x2)
bottom = max(y1, y2)
if right <= left or bottom <= top:
return None
return [left, top, right, bottom]
def _get_plate_theme(plate_color):
theme_map = {
"蓝色": {
"bg": (72, 33, 6),
"border": (250, 165, 96),
"text": (254, 242, 224),
"glow": (246, 130, 59),
},
"黄色": {
"bg": (0, 49, 74),
"border": (21, 204, 250),
"text": (195, 249, 254),
"glow": (8, 179, 234),
},
"绿色": {
"bg": (27, 53, 4),
"border": (128, 222, 74),
"text": (231, 252, 220),
"glow": (94, 197, 34),
},
"白色": {
"bg": (68, 50, 38),
"border": (240, 232, 226),
"text": (250, 250, 248),
"glow": (184, 163, 148),
},
"黑色": {
"bg": (20, 12, 8),
"border": (184, 163, 148),
"text": (240, 232, 226),
"glow": (139, 116, 100),
},
}
return theme_map.get(
plate_color,
{
"bg": (43, 24, 8),
"border": (248, 189, 56),
"text": (255, 248, 232),
"glow": (200, 140, 32),
},
)
def _draw_alpha_rect(img, x1, y1, x2, y2, color, alpha=0.75):
h, w = img.shape[:2]
x1 = _clamp(x1, 0, w - 1)
y1 = _clamp(y1, 0, h - 1)
x2 = _clamp(x2, 0, w)
y2 = _clamp(y2, 0, h)
if x2 <= x1 or y2 <= y1:
return
roi = img[y1:y2, x1:x2]
overlay = np.full_like(roi, color, dtype=np.uint8)
cv2.addWeighted(overlay, alpha, roi, 1 - alpha, 0, roi)
def _measure_text(text, text_size=16):
try:
font = ImageFont.truetype(get_font_file_path(), text_size, encoding="utf-8")
left, top, right, bottom = font.getbbox(text)
return right - left, bottom - top
except Exception:
print("字体文件加载失败")
size, baseline = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.45, 1)
return size[0], size[1] + baseline
def _draw_glow_border(img, x1, y1, x2, y2, border_color, glow_color):
h, w = img.shape[:2]
x1 = _clamp(x1, 0, w - 1)
y1 = _clamp(y1, 0, h - 1)
x2 = _clamp(x2, 0, w - 1)
y2 = _clamp(y2, 0, h - 1)
if x2 <= x1 or y2 <= y1:
return
glow_layer = np.zeros_like(img)
cv2.rectangle(glow_layer, (x1, y1), (x2, y2), glow_color, 2)
glow_layer = cv2.GaussianBlur(glow_layer, (0, 0), sigmaX=1.6, sigmaY=1.6)
cv2.addWeighted(glow_layer, 0.45, img, 1.0, 0, img)
cv2.rectangle(img, (x1, y1), (x2, y2), border_color, 1)
def _draw_tech_box(img, x1, y1, x2, y2, border_color, glow_color, track_id=None):
h, w = img.shape[:2]
x1 = _clamp(x1, 0, w - 1)
y1 = _clamp(y1, 0, h - 1)
x2 = _clamp(x2, 0, w - 1)
y2 = _clamp(y2, 0, h - 1)
if x2 <= x1 or y2 <= y1:
return
bw = x2 - x1
bh = y2 - y1
diag = float(np.hypot(bw, bh))
base_thick = _clamp(int(round(diag / 70.0)), 2, 5)
glow_sigma = _clamp(diag / 55.0, 1.2, 3.6)
glow_layer = np.zeros_like(img)
cv2.rectangle(glow_layer, (x1, y1), (x2, y2), glow_color, max(1, base_thick - 1))
glow_layer = cv2.GaussianBlur(
glow_layer, (0, 0), sigmaX=glow_sigma, sigmaY=glow_sigma
)
cv2.addWeighted(glow_layer, 0.48, img, 1.0, 0, img)
cv2.rectangle(img, (x1, y1), (x2, y2), border_color, max(1, base_thick - 1))
corner_len = _clamp(int(round(min(bw, bh) * 0.28)), 8, 20)
t = base_thick
cv2.line(img, (x1, y1), (x1 + corner_len, y1), border_color, t)
cv2.line(img, (x1, y1), (x1, y1 + corner_len), border_color, t)
cv2.line(img, (x2, y1), (x2 - corner_len, y1), border_color, t)
cv2.line(img, (x2, y1), (x2, y1 + corner_len), border_color, t)
cv2.line(img, (x1, y2), (x1 + corner_len, y2), border_color, t)
cv2.line(img, (x1, y2), (x1, y2 - corner_len), border_color, t)
cv2.line(img, (x2, y2), (x2 - corner_len, y2), border_color, t)
cv2.line(img, (x2, y2), (x2, y2 - corner_len), border_color, t)
if track_id is not None:
badge = "T%02d" % track_id
badge_w_txt, badge_h_txt = _measure_text(badge, text_size=12)
pad_x = 5
pad_y = 3
badge_w = badge_w_txt + pad_x * 2
badge_h = badge_h_txt + pad_y * 2
bx2 = _clamp(x2, badge_w + 2, w - 2)
by1 = _clamp(y1 - badge_h - 2, 2, h - badge_h - 2)
bx1 = bx2 - badge_w
by2 = by1 + badge_h
_draw_alpha_rect(img, bx1, by1, bx2, by2, (18, 18, 18), alpha=0.65)
cv2.rectangle(img, (bx1, by1), (bx2, by2), border_color, 1)
text_x = bx1 + pad_x
text_y = by1 + max(1, int((badge_h - badge_h_txt) / 2))
img[:] = cv2ImgAddText(img, badge, text_x, text_y, border_color, 12)
def cv2ImgAddText(img, text, left, top, textColor=(0, 255, 0), textSize=20):
if isinstance(img, np.ndarray): # 判断是否OpenCV图片类型
img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
draw = ImageDraw.Draw(img)
fontText = ImageFont.truetype(get_font_file_path(), textSize, encoding="utf-8")
draw.text((left, top), text, textColor, font=fontText)
return cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR)
def get_font_file_path():
# win
# font_path = r"C:\Users\BBIT\Desktop\yolo26-plate-main\fonts\platech.ttf"
# linux
font_path = r"/app/models/sentinel/platech.ttf"
return font_path
def _draw_tech_landmark(img, x, y, border_color, glow_color):
h, w = img.shape[:2]
x = _clamp(int(round(x)), 0, w - 1)
y = _clamp(int(round(y)), 0, h - 1)
glow_layer = np.zeros_like(img)
cv2.circle(glow_layer, (x, y), 5, glow_color, -1)
glow_layer = cv2.GaussianBlur(glow_layer, (0, 0), sigmaX=1.2, sigmaY=1.2)
cv2.addWeighted(glow_layer, 0.5, img, 1.0, 0, img)
cv2.circle(img, (x, y), 2, border_color, -1)
cv2.circle(img, (x, y), 4, border_color, 1)
def _plate_width_from_landmarks(landmarks, fallback_width):
if not landmarks or len(landmarks) < 4:
return float(fallback_width)
try:
p0 = np.array(landmarks[0], dtype=np.float32)
p1 = np.array(landmarks[1], dtype=np.float32)
p2 = np.array(landmarks[2], dtype=np.float32)
p3 = np.array(landmarks[3], dtype=np.float32)
top_w = float(np.linalg.norm(p1 - p0))
bottom_w = float(np.linalg.norm(p2 - p3))
width = (top_w + bottom_w) / 2.0
if np.isfinite(width) and width > 1:
return width
except Exception:
pass
return float(fallback_width)
def _plate_height_from_landmarks(landmarks, fallback_height):
if not landmarks or len(landmarks) < 4:
return float(fallback_height)
try:
p0 = np.array(landmarks[0], dtype=np.float32)
p1 = np.array(landmarks[1], dtype=np.float32)
p2 = np.array(landmarks[2], dtype=np.float32)
p3 = np.array(landmarks[3], dtype=np.float32)
left_h = float(np.linalg.norm(p3 - p0))
right_h = float(np.linalg.norm(p2 - p1))
height = (left_h + right_h) / 2.0
if np.isfinite(height) and height > 1:
return height
except Exception:
pass
return float(fallback_height)
def _fit_font_size(text, max_w, max_h, min_size=10, max_size=24):
if max_w <= 0 or max_h <= 0:
return min_size
for size in range(max_size, min_size - 1, -1):
tw, th = _measure_text(text, text_size=size)
if tw <= max_w and th <= max_h:
return size
return min_size
def draw_result(orgimg, result_list):
landmark_colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0)]
result_str = []
img_h, img_w = orgimg.shape[:2]
for idx, result in enumerate(result_list, start=1):
raw_rect = _normalize_rect(result.get("rect"))
if raw_rect is None:
continue
x1, y1, x2, y2 = raw_rect
w = x2 - x1
h = y2 - y1
padding_w = int(round(0.05 * w))
padding_h = int(round(0.11 * h))
rx1 = _clamp(x1 - padding_w, 0, img_w - 1)
ry1 = _clamp(y1 - padding_h, 0, img_h - 1)
rx2 = _clamp(x2 + padding_w, 0, img_w - 1)
ry2 = _clamp(y2 + padding_h, 0, img_h - 1)
landmarks = result.get("landmarks", [])
plate_no = result.get("plate_no", "")
plate_color = result.get("plate_color", "")
if result.get("plate_type", 0) == 1:
result_p = "%s %s双层" % (plate_no, plate_color)
else:
result_p = "%s %s" % (plate_no, plate_color)
result_str.append(result_p)
theme = _get_plate_theme(plate_color)
for i in range(min(4, len(landmarks))):
point = landmarks[i]
if len(point) < 2:
continue
point_color = landmark_colors[i]
_draw_tech_landmark(orgimg, point[0], point[1], point_color, point_color)
_draw_tech_box(
orgimg, rx1, ry1, rx2, ry2, theme["border"], theme["glow"], track_id=idx
)
label = "%s | %s" % (plate_no, plate_color)
plate_w = _plate_width_from_landmarks(landmarks, rx2 - rx1)
plate_h = _plate_height_from_landmarks(landmarks, ry2 - ry1)
pre_card_h = _clamp(int(round(plate_h)), 24, min(110, img_h - 4))
pre_pad_y = _clamp(int(round(pre_card_h * 0.16)), 3, 10)
pre_inner_h = max(8, pre_card_h - pre_pad_y * 2)
pre_max_font = _clamp(int(round(pre_card_h * 0.72)), 14, 44)
pre_min_font = _clamp(int(round(pre_card_h * 0.42)), 10, pre_max_font)
pre_font_size = _fit_font_size(
label, 4096, pre_inner_h, min_size=pre_min_font, max_size=pre_max_font
)
pre_text_w, _ = _measure_text(label, text_size=pre_font_size)
min_w_by_text = pre_text_w + 20
base_w_by_plate = int(round(plate_w * 1.05))
card_w = max(90, base_w_by_plate, min_w_by_text)
card_w = min(card_w, img_w - 8)
card_h = pre_card_h
card_pad_x = _clamp(int(round(card_w * 0.08)), 8, 18)
card_pad_y = _clamp(int(round(card_h * 0.16)), 3, 10)
card_x = int(rx1 + (rx2 - rx1 - card_w) / 2)
card_x = _clamp(card_x, 4, max(4, img_w - card_w - 4))
card_y = ry1 - card_h - 2
if card_y < 2:
card_y = _clamp(ry1 + 2, 2, max(2, img_h - card_h - 2))
_draw_alpha_rect(
orgimg,
card_x,
card_y,
card_x + card_w,
card_y + card_h,
theme["bg"],
alpha=0.78,
)
_draw_glow_border(
orgimg,
card_x,
card_y,
card_x + card_w,
card_y + card_h,
theme["border"],
theme["glow"],
)
inner_w = max(8, card_w - card_pad_x * 2)
inner_h = max(8, card_h - card_pad_y * 2)
dynamic_max_font = _clamp(int(round(card_h * 0.72)), 14, 44)
dynamic_min_font = _clamp(int(round(card_h * 0.42)), 10, dynamic_max_font)
font_size = _fit_font_size(
label,
inner_w,
inner_h,
min_size=dynamic_min_font,
max_size=dynamic_max_font,
)
text_w, text_h = _measure_text(label, text_size=font_size)
text_x = card_x + max(card_pad_x, int((card_w - text_w) / 2))
text_y = card_y + max(card_pad_y - 1, int((card_h - text_h) / 2))
orgimg = cv2ImgAddText(orgimg, label, text_x, text_y, theme["text"], font_size)
return orgimg, result_str
+98
View File
@@ -0,0 +1,98 @@
import os
from dataclasses import dataclass
import cv2
import torch
from ultralytics import YOLO
from ai.plate.detect_rec_plate import det_rec_plate, draw_result
from ai.plate.plate_recognition.plate_rec import init_model
from config.yolo import logger
@dataclass
class PlateInfo:
plate_no: str
plate_color: str
result_img_path: str
class PlateRecognizer:
_instance = None
_ready = False
def __new__(cls, *args, **kwargs):
if cls._instance is None:
cls._instance = super(PlateRecognizer, cls).__new__(cls)
return cls._instance
def __init__(
self,
detect_model_path,
rec_model_path,
output_dir="result",
device="cuda" if torch.cuda.is_available() else "cpu",
):
if hasattr(self, "_initialized") and self._initialized:
return
self.device = torch.device(device)
self.output_dir = output_dir
os.makedirs(self.output_dir, exist_ok=True)
try:
# 模型加载
self.detect_model = YOLO(detect_model_path)
self.detect_model.to(self.device)
self.detect_model.eval()
self.plate_rec_model = init_model(
self.device, rec_model_path, is_color=True
)
self._initialized = True
self._ready = True
except Exception as e:
self.detect_model = None
self._ready = False
logger.error(f"❌车牌 YOLO 模型加载失败: {e}")
def analyze_image(self, image_path):
img = cv2.imread(image_path)
if not self._ready or img is None:
return []
# 检测与识别
result_list = det_rec_plate(
img, self.detect_model, self.plate_rec_model, self.device
)
# 可视化 & 保存
vis_img, _ = draw_result(img, result_list)
result_img_path = os.path.join(self.output_dir, os.path.basename(image_path))
cv2.imwrite(result_img_path, vis_img)
# 构建返回结果
results = [
PlateInfo(
plate_no=r["plate_no"],
plate_color=r["plate_color"],
result_img_path=result_img_path,
)
for r in result_list
]
return results
# 初始化一次
recognizer = PlateRecognizer(
detect_model_path="/app/models/sentinel/yolo26s-plate-detect.pt",
rec_model_path="/app/models/sentinel/plate_rec_color.pth",
output_dir="result",
)
# base_dir = Path(r"C:\Users\BBIT\Desktop\yolo26-plate-main")
#
# recognizer = PlateRecognizer(
# detect_model_path=str(base_dir / "weights" / "yolo26s-plate-detect.pt"),
# rec_model_path=str(base_dir / "weights" / "plate_rec_color.pth"),
# output_dir="result",
# )
@@ -0,0 +1,15 @@
import os
import cv2
import numpy as np
def get_split_merge(img):
h,w,c = img.shape
img_upper = img[0:int(5/12*h),:]
img_lower = img[int(1/3*h):,:]
img_upper = cv2.resize(img_upper,(img_lower.shape[1],img_lower.shape[0]))
new_img = np.hstack((img_upper,img_lower))
return new_img
if __name__=="__main__":
img = cv2.imread("double_plate/tmp8078.png")
new_img =get_split_merge(img)
cv2.imwrite("double_plate/new.jpg",new_img)
@@ -0,0 +1,203 @@
import torch.nn as nn
import torch
class myNet_ocr(nn.Module):
def __init__(self,cfg=None,num_classes=78,export=False):
super(myNet_ocr, self).__init__()
if cfg is None:
cfg =[32,32,64,64,'M',128,128,'M',196,196,'M',256,256]
# cfg =[32,32,'M',64,64,'M',128,128,'M',256,256]
self.feature = self.make_layers(cfg, True)
self.export = export
# self.classifier = nn.Linear(cfg[-1], num_classes)
# self.loc = nn.MaxPool2d((2, 2), (5, 1), (0, 1),ceil_mode=True)
# self.loc = nn.AvgPool2d((2, 2), (5, 2), (0, 1),ceil_mode=False)
self.loc = nn.MaxPool2d((5, 2), (1, 1),(0,1),ceil_mode=False)
self.newCnn=nn.Conv2d(cfg[-1],num_classes,1,1)
# self.newBn=nn.BatchNorm2d(num_classes)
def make_layers(self, cfg, batch_norm=False):
layers = []
in_channels = 3
for i in range(len(cfg)):
if i == 0:
conv2d =nn.Conv2d(in_channels, cfg[i], kernel_size=5,stride =1)
if batch_norm:
layers += [conv2d, nn.BatchNorm2d(cfg[i]), nn.ReLU(inplace=True)]
else:
layers += [conv2d, nn.ReLU(inplace=True)]
in_channels = cfg[i]
else :
if cfg[i] == 'M':
layers += [nn.MaxPool2d(kernel_size=3, stride=2,ceil_mode=True)]
else:
conv2d = nn.Conv2d(in_channels, cfg[i], kernel_size=3, padding=(1,1),stride =1)
if batch_norm:
layers += [conv2d, nn.BatchNorm2d(cfg[i]), nn.ReLU(inplace=True)]
else:
layers += [conv2d, nn.ReLU(inplace=True)]
in_channels = cfg[i]
return nn.Sequential(*layers)
def forward(self, x):
x = self.feature(x)
x=self.loc(x)
x=self.newCnn(x)
# x=self.newBn(x)
if self.export:
conv = x.squeeze(2) # b *512 * width
conv = conv.transpose(2,1) # [w, b, c]
# conv =conv.argmax(dim=2)
return conv
else:
b, c, h, w = x.size()
assert h == 1, "the height of conv must be 1"
conv = x.squeeze(2) # b *512 * width
conv = conv.permute(2, 0, 1) # [w, b, c]
# output = F.log_softmax(self.rnn(conv), dim=2)
output = torch.softmax(conv, dim=2)
return output
myCfg = [32,'M',64,'M',96,'M',128,'M',256]
class myNet(nn.Module):
def __init__(self,cfg=None,num_classes=3):
super(myNet, self).__init__()
if cfg is None:
cfg = myCfg
self.feature = self.make_layers(cfg, True)
self.classifier = nn.Linear(cfg[-1], num_classes)
def make_layers(self, cfg, batch_norm=False):
layers = []
in_channels = 3
for i in range(len(cfg)):
if i == 0:
conv2d =nn.Conv2d(in_channels, cfg[i], kernel_size=5,stride =1)
if batch_norm:
layers += [conv2d, nn.BatchNorm2d(cfg[i]), nn.ReLU(inplace=True)]
else:
layers += [conv2d, nn.ReLU(inplace=True)]
in_channels = cfg[i]
else :
if cfg[i] == 'M':
layers += [nn.MaxPool2d(kernel_size=3, stride=2,ceil_mode=True)]
else:
conv2d = nn.Conv2d(in_channels, cfg[i], kernel_size=3, padding=1,stride =1)
if batch_norm:
layers += [conv2d, nn.BatchNorm2d(cfg[i]), nn.ReLU(inplace=True)]
else:
layers += [conv2d, nn.ReLU(inplace=True)]
in_channels = cfg[i]
return nn.Sequential(*layers)
def forward(self, x):
x = self.feature(x)
x = nn.AvgPool2d(kernel_size=3, stride=1)(x)
x = x.view(x.size(0), -1)
y = self.classifier(x)
return y
class MyNet_color(nn.Module):
def __init__(self, class_num=6):
super(MyNet_color, self).__init__()
self.class_num = class_num
self.backbone = nn.Sequential(
nn.Conv2d(in_channels=3, out_channels=16, kernel_size=(5, 5), stride=(1, 1)), # 0
torch.nn.BatchNorm2d(16),
nn.ReLU(),
nn.MaxPool2d(kernel_size=(2, 2)),
nn.Dropout(0),
nn.Flatten(),
nn.Linear(480, 64),
nn.Dropout(0),
nn.ReLU(),
nn.Linear(64, class_num),
nn.Dropout(0),
nn.Softmax(1)
)
def forward(self, x):
logits = self.backbone(x)
return logits
class myNet_ocr_color(nn.Module):
def __init__(self,cfg=None,num_classes=78,export=False,color_num=None):
super(myNet_ocr_color, self).__init__()
if cfg is None:
cfg =[32,32,64,64,'M',128,128,'M',196,196,'M',256,256]
# cfg =[32,32,'M',64,64,'M',128,128,'M',256,256]
self.feature = self.make_layers(cfg, True)
self.export = export
self.color_num=color_num
self.conv_out_num=12 #颜色第一个卷积层输出通道12
if self.color_num:
self.conv1=nn.Conv2d(cfg[-1],self.conv_out_num,kernel_size=3,stride=2)
self.bn1=nn.BatchNorm2d(self.conv_out_num)
self.relu1=nn.ReLU(inplace=True)
self.gap =nn.AdaptiveAvgPool2d(output_size=1)
self.color_classifier=nn.Conv2d(self.conv_out_num,self.color_num,kernel_size=1,stride=1)
self.color_bn = nn.BatchNorm2d(self.color_num)
self.flatten = nn.Flatten()
self.loc = nn.MaxPool2d((5, 2), (1, 1),(0,1),ceil_mode=False)
self.newCnn=nn.Conv2d(cfg[-1],num_classes,1,1)
# self.newBn=nn.BatchNorm2d(num_classes)
def make_layers(self, cfg, batch_norm=False):
layers = []
in_channels = 3
for i in range(len(cfg)):
if i == 0:
conv2d =nn.Conv2d(in_channels, cfg[i], kernel_size=5,stride =1)
if batch_norm:
layers += [conv2d, nn.BatchNorm2d(cfg[i]), nn.ReLU(inplace=True)]
else:
layers += [conv2d, nn.ReLU(inplace=True)]
in_channels = cfg[i]
else :
if cfg[i] == 'M':
layers += [nn.MaxPool2d(kernel_size=3, stride=2,ceil_mode=True)]
else:
conv2d = nn.Conv2d(in_channels, cfg[i], kernel_size=3, padding=(1,1),stride =1)
if batch_norm:
layers += [conv2d, nn.BatchNorm2d(cfg[i]), nn.ReLU(inplace=True)]
else:
layers += [conv2d, nn.ReLU(inplace=True)]
in_channels = cfg[i]
return nn.Sequential(*layers)
def forward(self, x):
x = self.feature(x)
if self.color_num:
x_color=self.conv1(x)
x_color=self.bn1(x_color)
x_color =self.relu1(x_color)
x_color = self.color_classifier(x_color)
x_color = self.color_bn(x_color)
x_color =self.gap(x_color)
x_color = self.flatten(x_color)
x=self.loc(x)
x=self.newCnn(x)
if self.export:
conv = x.squeeze(2) # b *512 * width
conv = conv.transpose(2,1) # [w, b, c]
if self.color_num:
return conv,x_color
return conv
else:
b, c, h, w = x.size()
assert h == 1, "the height of conv must be 1"
conv = x.squeeze(2) # b *512 * width
conv = conv.permute(2, 0, 1) # [w, b, c]
output = F.log_softmax(conv, dim=2)
if self.color_num:
return output,x_color
return output
if __name__ == '__main__':
x = torch.randn(1,3,48,216)
model = myNet_ocr(num_classes=78,export=True)
out = model(x)
print(out.shape)
@@ -0,0 +1,134 @@
import os
import time
import cv2
import numpy as np
import torch
from ai.plate.plate_recognition.plateNet import myNet_ocr_color
def cv_imread(path): # 可以读取中文路径的图片
img = cv2.imdecode(np.fromfile(path, dtype=np.uint8), -1)
return img
def allFilePath(rootPath, allFIleList):
fileList = os.listdir(rootPath)
for temp in fileList:
if os.path.isfile(os.path.join(rootPath, temp)):
if temp.endswith(".jpg") or temp.endswith(".png") or temp.endswith(".JPG"):
allFIleList.append(os.path.join(rootPath, temp))
else:
allFilePath(os.path.join(rootPath, temp), allFIleList)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
color = ["黑色", "蓝色", "绿色", "白色", "黄色"]
plateName = r"#京沪津渝冀晋蒙辽吉黑苏浙皖闽赣鲁豫鄂湘粤桂琼川贵云藏陕甘青宁新学警港澳挂使领民航危0123456789ABCDEFGHJKLMNPQRSTUVWXYZ险品"
mean_value, std_value = (0.588, 0.193)
def decodePlate(preds):
pre = 0
newPreds = []
index = []
for i in range(len(preds)):
if preds[i] != 0 and preds[i] != pre:
newPreds.append(preds[i])
index.append(i)
pre = preds[i]
return newPreds, index
def image_processing(img, device):
img = cv2.resize(img, (168, 48))
img = np.reshape(img, (48, 168, 3))
# normalize
img = img.astype(np.float32)
img = (img / 255.0 - mean_value) / std_value
img = img.transpose([2, 0, 1])
img = torch.from_numpy(img)
img = img.to(device)
img = img.view(1, *img.size())
return img
def get_plate_result(img, device, model, is_color=False):
input = image_processing(img, device)
if is_color: # 是否识别颜色
preds, color_preds = model(input)
color_preds = torch.softmax(color_preds, dim=-1)
color_conf, color_index = torch.max(color_preds, dim=-1)
color_conf = color_conf.item()
else:
preds = model(input)
preds = torch.softmax(preds, dim=-1)
prob, index = preds.max(dim=-1)
index = index.view(-1).detach().cpu().numpy()
prob = prob.view(-1).detach().cpu().numpy()
# preds=preds.view(-1).detach().cpu().numpy()
newPreds, new_index = decodePlate(index)
prob = prob[new_index]
plate = ""
for i in newPreds:
plate += plateName[i]
# if not (plate[0] in plateName[1:44] ):
# return ""
if is_color:
return (
plate,
prob,
color[color_index],
color_conf,
) # 返回车牌号以及每个字符的概率,以及颜色,和颜色的概率
else:
return plate, prob
def init_model(device, model_path, is_color=False):
# print( print(sys.path))
# model_path ="plate_recognition/model/checkpoint_61_acc_0.9715.pth"
check_point = torch.load(model_path, map_location=device)
model_state = check_point["state_dict"]
cfg = check_point["cfg"]
color_classes = 0
if is_color:
color_classes = 5 # 颜色类别数
model = myNet_ocr_color(
num_classes=len(plateName), export=True, cfg=cfg, color_num=color_classes
)
model.load_state_dict(model_state, strict=False)
model.to(device)
model.eval()
return model
# model = init_model(device)
if __name__ == "__main__":
model_path = r"weights/plate_rec_color.pth"
image_path = "images/tmp2424.png"
testPath = r"/mnt/Gpan/Mydata/pytorchPorject/CRNN/crnn_plate_recognition/images"
fileList = []
allFilePath(testPath, fileList)
# result = get_plate_result(image_path,device)
# print(result)
is_color = False
model = init_model(device, model_path, is_color=is_color)
right = 0
begin = time.time()
for imge_path in fileList:
img = cv2.imread(imge_path)
if is_color:
plate, _, plate_color, _ = get_plate_result(
img, device, model, is_color=is_color
)
print(plate)
else:
plate, _ = get_plate_result(img, device, model, is_color=is_color)
print(plate, imge_path)
+18 -7
View File
@@ -1,4 +1,5 @@
import asyncio
from contextlib import asynccontextmanager
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
@@ -21,12 +22,25 @@ from routers.Service import serviceRouter
from routers.System import systemRouter
from routers.Vision import visionRouter
from routers.WS import iot_ws_router
from service.RabbitMQ import sentinel_pull_analysis_async
from service.RabbitMQ import (
mq_client,
)
@asynccontextmanager
async def lifespan(app: FastAPI):
# 应用启动时初始化 MQ
await mq_client.init()
# 启动消费者
await mq_client.start_all_consumer()
yield
# 应用关闭时关闭 MQ 连接
if mq_client._connection:
await mq_client._connection.close()
async def ai_lab():
app = FastAPI(title="BBIT_AI")
app = FastAPI(title="BBIT_AI", lifespan=lifespan)
origins = [
"http://localhost:8091", # Vite dev 默认端口
"https://ai.ronsunny.cn:8090",
@@ -71,16 +85,13 @@ async def ai_lab():
async def main():
# 初始化模型
YOLOSingleton.init_model()
# 主干AI实验室FastAPI服务
task_api = asyncio.create_task(ai_lab())
# RabbitMQ服务
task_mq = asyncio.create_task(sentinel_pull_analysis_async())
# 等 HTTP 服务启动后再启动 MQTT
task_mqtt = asyncio.create_task(mqtt_client_runner())
await asyncio.gather(task_api, task_mq, task_mqtt)
await asyncio.gather(task_api, task_mqtt)
# MCP服务-ailab
# endpoint_url_ai_lab = "wss://ai.ronsunny.cn:8090/aimcp/mcp_endpoint/mcp/?token=TsSP9lBq6Oa1WMkachHoS2TtNt4GKV/Gli24pk5Rjpk%3D"
+28 -2
View File
@@ -1,6 +1,7 @@
from datetime import timedelta
from minio import Minio
from minio.commonconfig import CopySource
# MinIO 客户端初始化
minio_client = Minio(
@@ -28,7 +29,7 @@ def get_upload_token(bucket_name, object_name, xpires=timedelta(minutes=15)):
)
def get_temp_url(bucket_name, object_name):
def get_temp_url(bucket_name, object_name, seconds: float = 3600):
# 如果 object_name 为 None 或空字符串,则返回默认图片
if not object_name or not bucket_name:
bucket_name = "system"
@@ -36,7 +37,7 @@ def get_temp_url(bucket_name, object_name):
# 使用 presigned_get_object 获取临时 URL
return minio_client.presigned_get_object(
bucket_name, object_name, expires=timedelta(seconds=3600)
bucket_name, object_name, expires=timedelta(seconds=seconds)
)
@@ -51,3 +52,28 @@ def get_temp_url_dict(bucket_name, object_dict, object_name):
return minio_client.presigned_get_object(
bucket_name, object_dict + "/" + object_name, expires=timedelta(seconds=3600)
)
# 移动文件(实际上是 copy + delete
def move_file(bucket_name, source_object_name, target_object_name):
"""
bucket_name: bucket名称
source_object_name: 原对象路径,例如 folder1/test.jpg
target_object_name: 目标对象路径,例如 folder2/test.jpg
"""
# 复制到新位置
minio_client.copy_object(
bucket_name, target_object_name, CopySource(bucket_name, source_object_name)
)
# 删除原文件
minio_client.remove_object(bucket_name, source_object_name)
# 删除文件
def delete_file(bucket_name, object_name):
"""
bucket_name: bucket名称
object_name: 文件路径,例如 folder/test.jpg
"""
minio_client.remove_object(bucket_name, object_name)
+13 -1
View File
@@ -3,7 +3,19 @@ from utils.GlobalVariable import LOCAL_IP
RABBIT_HOST = LOCAL_IP
RABBIT_USER = "ai_lab"
RABBIT_PASSWORD = "123456"
QUEUE_NAME = "analysis_queue"
RABBIT_VHOST = "bbit_ai"
QUEUE_NAME = "analysis_queue"
SENTINEL_VHOST = "sentinel"
SENTINEL_ANALYSIS_ALL_QUEUE_NAME = "sentinel.analysis_all_queue"
SENTINEL_ANALYSIS_SIDE_QUEUE_NAME = "sentinel.analysis_side_queue"
SENTINEL_ANALYSIS_FRONT_QUEUE_NAME = "sentinel.analysis_front_queue"
SENTINEL_FRONT_REQUEST_QUEUE = "sentinel.front_pic"
def get_sentinel_front_queue_name(device_id):
return f"{SENTINEL_FRONT_REQUEST_QUEUE}.{device_id}"
+1 -1
View File
@@ -99,7 +99,7 @@ class YOLOSingleton:
except Exception as e:
cls._model = None
cls._ready = False
logger.error(f"❌ YOLO 模型加载失败: {e}")
logger.error(f"蚕茧 YOLO 模型加载失败: {e}")
@classmethod
def detect(cls, img_bytes: bytes):
+25 -13
View File
@@ -35,34 +35,46 @@ import time
def reset_all_exchange_status():
"""将所有记录 is_finished 置为 False并随机 position(以当前时间作为随机种子)"""
"""将所有记录 is_finished 置为 False
且 gift_code == 2 的记录强制排在最后,其余随机排序
"""
with pg_pool.getConn() as conn:
with conn.cursor() as cur:
# 获取总记录数
cur.execute("SELECT id FROM annual_meeting_exchange")
ids = [row[0] for row in cur.fetchall()]
# 取出 id 和 gift_code
cur.execute("SELECT id, gift_code FROM annual_meeting_exchange")
rows = cur.fetchall()
# 用当前时间戳作为随机种子
seed = int(time.time() * 1000) # 毫秒级
# 分组
normal_ids = [r[0] for r in rows if r[1] != 2]
tail_ids = [r[0] for r in rows if r[1] == 2]
# 随机种子
seed = int(time.time() * 1000)
random.seed(seed)
# 生成随机顺序的 position
positions = list(range(1, len(ids) + 1))
random.shuffle(positions)
# 只打乱非 gift_code == 2 的部分
random.shuffle(normal_ids)
# 更新每条记录
for record_id, pos in zip(ids, positions):
# 合并顺序:普通在前,gift_code==2 在后
ordered_ids = normal_ids + tail_ids
# 依次更新 sort
for idx, record_id in enumerate(ordered_ids, start=1):
cur.execute(
"""
UPDATE annual_meeting_exchange
SET is_finished = FALSE, sort = %s
WHERE id = %s
""",
(pos, record_id),
(idx, record_id),
)
conn.commit()
return {"updated_count": len(ids), "seed_used": seed}
return {
"updated_count": len(ordered_ids),
"seed_used": seed,
"tail_count": len(tail_ids),
}
def reset_user_status(target_user_id: str):
+38 -37
View File
@@ -1,6 +1,5 @@
from config.minIO import get_temp_url_dict
from config.pgDb import pg_pool
from models.SentinelRecordRequest import SentinelRecordRequest
from utils.MyUtils import format_datetime
@@ -90,6 +89,7 @@ def get_sentinel_record_list_db_page(
r.vehicle_image,
r.livestock_type,
r.livestock_source,
r.license_plate_color,
r.is_inspected,
r.dept_id,
sd.name AS dept_name,
@@ -116,6 +116,7 @@ def get_sentinel_record_list_db_page(
vehicle_image,
livestock_type,
livestock_source,
license_plate_color,
is_inspected,
dept_id,
dept_name,
@@ -130,13 +131,14 @@ def get_sentinel_record_list_db_page(
"license_plate": license_plate,
"vehicle_type": vehicle_type,
"license_plate_image": get_temp_url_dict(
"sentinel", "license_plate", license_plate_image
"sentinel", "vehicle_image_front", license_plate_image
),
"vehicle_image": get_temp_url_dict(
"sentinel", "vehicle_image", vehicle_image
"sentinel", "vehicle_image_side", vehicle_image
),
"livestock_type": livestock_type,
"livestock_source": livestock_source,
"license_plate_color": license_plate_color,
"is_inspected": 1 if is_inspected else 0,
"dept_id": dept_id,
"dept_name": dept_name,
@@ -161,6 +163,7 @@ def get_sentinel_record_by_id(record_id):
r.license_plate_image,
r.vehicle_image,
r.livestock_type,
r.license_plate_color,
r.livestock_source,
r.is_inspected,
r.dept_id,
@@ -188,6 +191,7 @@ def get_sentinel_record_by_id(record_id):
license_plate_image,
vehicle_image,
livestock_type,
license_plate_color,
livestock_source,
is_inspected,
dept_id,
@@ -202,12 +206,13 @@ def get_sentinel_record_by_id(record_id):
"license_plate": license_plate,
"vehicle_type": vehicle_type,
"license_plate_image": get_temp_url_dict(
"sentinel", "license_plate", license_plate_image
"sentinel", "vehicle_image_front", license_plate_image
),
"vehicle_image": get_temp_url_dict(
"sentinel", "vehicle_image", vehicle_image
"sentinel", "vehicle_image_side", vehicle_image
),
"livestock_type": livestock_type,
"license_plate_color": license_plate_color,
"livestock_source": livestock_source,
"is_inspected": 1 if is_inspected else 0,
"dept_id": str(dept_id),
@@ -333,15 +338,30 @@ def delete_sentinel_record_db(id: str) -> int:
return cursor.rowcount
def saveSentinelRecord(data: SentinelRecordRequest) -> str:
def saveSentinelRecord(
id: str,
vehicle_type: str,
vehicle_image: str,
livestock_type: str,
remark: str,
dept_id: str,
license_plate: str,
license_plate_image: str,
license_plate_color: int = 0,
) -> str:
sql = """
INSERT INTO sentinel_records (
id,
vehicle_type,
vehicle_image,
livestock_type,
remark,
dept_id,
license_plate,
license_plate_image,
vehicle_type,
vehicle_image
license_plate_color
)
VALUES (%s, %s, %s, %s)
VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s)
RETURNING id;
"""
@@ -350,36 +370,17 @@ def saveSentinelRecord(data: SentinelRecordRequest) -> str:
cursor.execute(
sql,
(
data.LicensePlate,
data.LicensePlateImage,
data.VehicleType,
data.VehicleImage,
id,
vehicle_type,
vehicle_image,
livestock_type,
remark,
dept_id,
license_plate,
license_plate_image,
license_plate_color,
),
)
new_id = cursor.fetchone()[0]
conn.commit()
return str(new_id)
def update_sentinel_record(
id: str, livestock_type: str, remark: str, dept_id: str
) -> bool:
"""
根据 id 更新 sentinel_records 表中的 livestock_type 和 dept_id
"""
sql = """
UPDATE sentinel_records
SET livestock_type = %s,
remark = %s,
dept_id = %s,
updated_at = now()
WHERE id = %s
RETURNING id;
"""
with pg_pool.getConn() as conn:
with conn.cursor() as cursor:
cursor.execute(sql, (livestock_type, remark, dept_id, id))
record = cursor.fetchone()
conn.commit()
return record is not None
@@ -0,0 +1,7 @@
from pydantic import BaseModel
class SentinelRecordFrontRequest(BaseModel):
Id: str | None = None
DeviceId: str
VehicleImage: str | None = None
+2 -3
View File
@@ -4,7 +4,6 @@ from pydantic import BaseModel
class SentinelRecordRequest(BaseModel):
Id: str | None = None
DeviceId: str
LicensePlate: str | None = None
LicensePlateImage: str | None = None
VehicleType: str | None = None
VehicleImage: str | None = None
vehicleImageFront: str | None = None
vehicleImageSide: str | None = None
@@ -0,0 +1,8 @@
from pydantic import BaseModel
class SentinelRecordSideRequest(BaseModel):
Id: str | None = None
DeviceId: str
VehicleType: str | None = None
VehicleImage: str | None = None
Binary file not shown.

After

Width:  |  Height:  |  Size: 116 KiB

+7 -7
View File
@@ -1,14 +1,16 @@
import asyncio
import base64
from fastapi import APIRouter
from config.app import F8_SERVER_USER_ID
from db.postgres.sentinel import saveSentinelRecord
from models.BaseResponse import BaseResponse
from models.F8ImageRequest import F8ImageRequest
from models.F8ImageRequestV2 import F8ImageRequestV2
from models.SentinelRecordRequest import SentinelRecordRequest
from service.RabbitMQ import sentinel_new_analysis
from service.RabbitMQ import (
mq_client,
)
from service.vision import (
process_ticket_image,
process_license_image,
@@ -85,8 +87,6 @@ async def recognize_silkworm_cocoon(data: F8ImageRequest):
@publicRouter.post("/sentinel-record-analytics")
async def delete_sentinel_record(data: SentinelRecordRequest):
# 保存部分数据到数据库
data.Id = saveSentinelRecord(data)
# 发送请求给RabbitMQ
res = await sentinel_new_analysis(data)
return BaseResponse(data=res)
# 发送全盘分析请求给RabbitMQ
asyncio.create_task(mq_client.send_all_analysis(data))
return BaseResponse(data="submitted")
+56 -19
View File
@@ -107,22 +107,51 @@ async def get_sentinel_monitor_promotional_list(
):
if not user_id:
return {"error": "userId is required"}
# 图片过期时间:7天
expiration_time = 60 * 60 * 24 * 7
return BaseResponse(
data=[
{
"id": 1,
"remark": "人员公示及岗位职责",
"url": get_temp_url("sentinel", "promotional/promotional (2).jpg"),
"url": get_temp_url(
"sentinel", "promotional/promotional (2).jpg", expiration_time
),
},
{
"id": 2,
"remark": "入川动物监督检查工作流程图",
"url": get_temp_url("sentinel", "promotional/promotional (1).jpg"),
"url": get_temp_url(
"sentinel", "promotional/promotional (1).jpg", expiration_time
),
},
{
"id": 3,
"remark": "四川省人民政府关于设立川动物运输指定通道的通告",
"url": get_temp_url("sentinel", "promotional/promotional (3).jpg"),
"remark": "四川省人民政府关于设立川动物运输指定通道的通告",
"url": get_temp_url(
"sentinel", "promotional/promotional (3).jpg", expiration_time
),
},
{
"id": 4,
"remark": "四川省动物卫生监督检查站工作程序",
"url": get_temp_url(
"sentinel", "promotional/promotional (4).jpg", expiration_time
),
},
{
"id": 5,
"remark": "四川省动物卫生监督检查站无害化处理制度",
"url": get_temp_url(
"sentinel", "promotional/promotional (5).jpg", expiration_time
),
},
{
"id": 6,
"remark": "四川省动物卫生监督检查站工作人员行为规范",
"url": get_temp_url(
"sentinel", "promotional/promotional (6).jpg", expiration_time
),
},
]
)
@@ -153,13 +182,18 @@ async def get_sentinel_monitor_list(
)
url = "https://open.ys7.com/api/lapp/v2/live/address/get"
# device_serials = ["BG2493625"]
device_serials = ["BG2493625", "GH3713250", "GH3714496", "GH3714497"]
device_serials = {
# "GG9175589": [1, 2, 3, 4],
"GG9175589": [1, 3],
"GG9175555": [1, 2],
}
video_expire_time = 25 * 24 * 60 * 60 # 25 天
res = []
for device_serial in device_serials:
live_key = f"ys7:live:{device_serial}"
for device_serial, channels in device_serials.items():
for channelNo in channels:
live_key = f"ys7:live:{device_serial}:{channelNo}"
cached_live = redis_client.get_value(live_key)
if cached_live:
@@ -169,24 +203,27 @@ async def get_sentinel_monitor_list(
payload = {
"accessToken": access_token,
"deviceSerial": device_serial,
"protocol": 4, # 流播放协议,1-ezopen、2-hls、3-rtmp、4-flv,默认为1
"expireTime": video_expire_time, # 25天
"channelNo": channelNo,
"protocol": 4,
"expireTime": video_expire_time,
"supportH265": 0,
"quality": 2,
}
result = await http_client.post(url, data=payload)
video_id = result["data"]["id"]
video_url = result["data"]["url"]
# 存到 Redis,自动序列化为 JSON,过期 25天
video_data = result.get("data")
if not video_data:
continue # 或者记录异常日志
video_id = video_data["id"]
video_url = video_data["url"]
redis_client.set_value(
live_key,
{"id": video_id, "url": video_url},
expire=video_expire_time,
)
res.append(
{
"id": video_id,
"url": video_url,
}
)
res.append({"id": video_id, "url": video_url})
return BaseResponse(data=res)
+80 -74
View File
@@ -1,97 +1,103 @@
# consumer.py
import asyncio
import json
import traceback
import aio_pika
from config.rabbitMQ import *
from models.AnalysisRequest import AnalysisRequest
from models.SentinelRecordRequest import SentinelRecordRequest
from service.vision import process_vehicle_animal_image
from service.vision import (
process_all_vehicle_animal_image,
)
async def mq_new_analysis_test(req: dict):
"""将分析请求发送到 RabbitMQ 队列(异步版)"""
connection = await aio_pika.connect_robust(
f"amqp://{RABBIT_USER}:{RABBIT_PASSWORD}@{RABBIT_HOST}/{RABBIT_VHOST}"
class MQClient:
"""RabbitMQ 单例客户端,支持生产和消费"""
_instance = None
def __new__(cls):
if cls._instance is None:
cls._instance = super().__new__(cls)
return cls._instance
def __init__(self):
self._connection = None
self._channel = None
self._consumer_tasks = []
# ---------------- 连接初始化 ----------------
async def init(self, prefetch_count: int = 10):
"""启动时初始化连接和通道"""
if self._connection is None:
self._connection = await aio_pika.connect_robust(
f"amqp://{RABBIT_USER}:{RABBIT_PASSWORD}@{RABBIT_HOST}/{SENTINEL_VHOST}"
)
self._channel = await self._connection.channel()
await self._channel.set_qos(prefetch_count=prefetch_count)
async with connection:
channel = await connection.channel()
# 声明队列,确保队列存在
queue = await channel.declare_queue(QUEUE_NAME, durable=True)
message_body = json.dumps(req)
# ---------------- 发布消息 ----------------
async def publish(self, queue_name: str, message_body: str):
"""向指定队列发送消息"""
if self._channel is None:
raise RuntimeError("MQClient 未初始化")
# 队列幂等声明
queue = await self._channel.declare_queue(queue_name, durable=True)
message = aio_pika.Message(
body=message_body.encode(),
delivery_mode=aio_pika.DeliveryMode.PERSISTENT, # 持久化
body=message_body.encode(), delivery_mode=aio_pika.DeliveryMode.PERSISTENT
)
await self._channel.default_exchange.publish(message, routing_key=queue_name)
async def send_all_analysis(self, req: SentinelRecordRequest):
await self.publish(
SENTINEL_ANALYSIS_ALL_QUEUE_NAME, json.dumps(req.model_dump())
)
await channel.default_exchange.publish(message, routing_key=QUEUE_NAME)
async def mq_pull_analysis_async_test():
# ---------------- 消费消息 ----------------
async def consume_queue(self, queue_name: str, process_func):
"""
从队列拉取分析任务并处理
process_func: 一个函数,接收 AnalysisRequest 对象处理分析逻辑
持续消费队列
process_func: async function 接收 dict 或 Request 对象
"""
connection = await aio_pika.connect_robust(
f"amqp://{RABBIT_USER}:{RABBIT_PASSWORD}@{RABBIT_HOST}/{RABBIT_VHOST}"
)
async with connection:
queue_name = QUEUE_NAME
channel = await connection.channel()
await channel.set_qos(prefetch_count=1)
queue = await channel.declare_queue(queue_name, durable=True)
if self._channel is None:
raise RuntimeError("MQClient 未初始化")
queue = await self._channel.declare_queue(queue_name, durable=True)
async with queue.iterator() as queue_iter:
async for message in queue_iter:
async with message.process():
data = json.loads(message.body)
req = AnalysisRequest(**data)
print(f"收到任务: {req}")
await asyncio.sleep(5) # 模拟处理
print(f"完成任务: {req}")
try:
body = message.body.decode()
data = json.loads(body)
await process_func(data)
except Exception as e:
print(f"[MQ Consume Error] {e}")
traceback.print_exc()
async def sentinel_new_analysis(req: SentinelRecordRequest):
"""将分析请求发送到 RabbitMQ 队列(异步版)"""
connection = await aio_pika.connect_robust(
f"amqp://{RABBIT_USER}:{RABBIT_PASSWORD}@{RABBIT_HOST}/{SENTINEL_VHOST}"
)
async with connection:
channel = await connection.channel()
# 声明队列,确保队列存在
queue = await channel.declare_queue(QUEUE_NAME, durable=True)
message_body = json.dumps(req.model_dump())
message = aio_pika.Message(
body=message_body.encode(),
delivery_mode=aio_pika.DeliveryMode.PERSISTENT, # 持久化
)
await channel.default_exchange.publish(message, routing_key=QUEUE_NAME)
async def sentinel_pull_analysis_async():
"""
从队列拉取分析任务并处理
process_func: 一个函数,接收 AnalysisRequest 对象处理分析逻辑
"""
connection = await aio_pika.connect_robust(
f"amqp://{RABBIT_USER}:{RABBIT_PASSWORD}@{RABBIT_HOST}/{SENTINEL_VHOST}"
)
async with connection:
queue_name = QUEUE_NAME
channel = await connection.channel()
await channel.set_qos(prefetch_count=1)
queue = await channel.declare_queue(queue_name, durable=True)
async with queue.iterator() as queue_iter:
async for message in queue_iter:
async with message.process():
data = json.loads(message.body)
# ---------------- 启动全局分析消费者 ----------------
async def start_all_consumer(self):
async def _process(data: dict):
req = SentinelRecordRequest(**data)
await process_vehicle_animal_image(req) # 处理
print(f"完成任务: {req}")
await process_all_vehicle_animal_image(req)
print(f"完成全局分析任务: {req}")
task = asyncio.create_task(
self.consume_queue(SENTINEL_ANALYSIS_ALL_QUEUE_NAME, _process)
)
self._consumer_tasks.append(task)
# ---------------- 关闭连接 ----------------
async def close(self):
for task in self._consumer_tasks:
task.cancel()
if self._channel:
await self._channel.close()
if self._connection:
await self._connection.close()
# ---------------- 全局单例 ----------------
mq_client = MQClient()
+77 -9
View File
@@ -1,21 +1,27 @@
import uuid
from pathlib import Path
from uuid import UUID
import config.minIO as minIO
import db.postgres as pg
from agent.licenseImageAgent import get_license_response
from agent.vehicleImageAgent import get_vehicle_response
from ai.plate.my_plate import recognizer
from config.minIO import minio_client, get_temp_url
from config.yolo import YOLOSingleton
from db.postgres import (
get_dept_id_by_iot_user_name,
get_dept_ids_by_dept_id,
)
from db.postgres.sentinel import update_sentinel_record, get_sentinel_record_by_id
from db.postgres.sentinel import (
get_sentinel_record_by_id,
saveSentinelRecord,
)
from llm.ticketLLM import *
from llm.ticketLLMv2 import get_ticket_response_v2
from models.SentinelRecordRequest import SentinelRecordRequest
from routers.WS import ws_manager
from utils import validate_plate
def process_ticket_image(
@@ -189,26 +195,88 @@ def process_silkworm_cocoon_image(
}
async def process_vehicle_animal_image(
# 处理车牌照片
async def process_all_vehicle_animal_image(
data: SentinelRecordRequest,
):
# 通过设备id获得组织id
dept_id = get_dept_id_by_iot_user_name(data.DeviceId)
oss_url = minIO.get_temp_url(
"sentinel", "vehicle_image_side/" + data.vehicleImageSide
)
# 得到动物类型
oss_url = minIO.get_temp_url("sentinel", "vehicle_image/" + data.VehicleImage)
# LLM得到车身信息
analysis_result = await get_vehicle_response(oss_url)
livestock_type = analysis_result.get("livestock_type", "")
remark = analysis_result.get("remark", "")
have_animal = analysis_result.get("have_animal", False)
if not have_animal:
minIO.delete_file("sentinel", "vehicle_image_side/" + data.vehicleImageSide)
minIO.delete_file("sentinel", "vehicle_image_front/" + data.vehicleImageFront)
else:
# 通过设备id获得组织id
dept_id = get_dept_id_by_iot_user_name(data.DeviceId)
# 处理汽车正面照--------------------------
# 从OSS下载
oss_url = minIO.get_temp_url(
"sentinel", "vehicle_image_front/" + data.vehicleImageFront
)
# 获取系统临时目录(自动兼容 Windows / Linux
tmp_dir = Path(tempfile.gettempdir())
tmp_dir.mkdir(parents=True, exist_ok=True)
tmp_path = tmp_dir / data.vehicleImageFront
# 下载图片到 tmp_path
response = requests.get(oss_url, stream=True)
if response.status_code == 200:
with open(tmp_path, "wb") as f:
for chunk in response.iter_content(chunk_size=8192):
f.write(chunk)
else:
raise Exception(f"下载失败: {oss_url}, status_code={response.status_code}")
# 调用识别
results = recognizer.analyze_image(str(tmp_path))
# 车牌号一次识别
license_plate = results[0].plate_no if results else ""
# 车牌号二次校准
license_plate = validate_plate(license_plate)
license_plate_color_str = results[0].plate_color if results else ""
color_map = {
"蓝色": 0,
"黄色": 1,
"绿色": 2,
"黑色": 3,
"白色": 4,
}
license_plate_color = color_map.get(license_plate_color_str, 0)
license_plate_image = data.vehicleImageFront
# 保存到数据库
update_sentinel_record(data.Id, livestock_type, remark, dept_id)
saveSentinelRecord(
data.Id,
data.VehicleType,
data.vehicleImageSide,
livestock_type,
remark,
dept_id,
license_plate,
license_plate_image,
license_plate_color,
)
# 识别完成后删除临时文件
os.remove(tmp_path)
# 可以通知的部门ids
available_departments = get_dept_ids_by_dept_id(dept_id)
# 通知控制界面
await ws_manager.noticeSentinel(
{
"content": f"载有{livestock_type}车辆即将进入关卡,请准备检查",
"content": f"车辆即将进入关卡,请准备检查",
"license_plate": license_plate,
"type": "vehicle_alert",
},
available_departments,
+54
View File
@@ -100,3 +100,57 @@ def get_memory_total():
def get_disk_total():
return psutil.disk_usage("/").total
def translate_vehicle_type(vehicle_type: str) -> str:
mapping = {
"coupe": "双门跑车",
"largevehicle": "大型车辆",
"sedan": "三厢轿车",
"suv": "运动型多用途车",
"truck": "卡车",
"van": "面包车",
}
return mapping.get(vehicle_type.lower(), vehicle_type)
import re
def validate_plate(plate: str) -> str:
if not plate:
return "车速过快,无法识别"
plate = plate.strip().upper()
# 普通车牌(燃油)
normal_pattern = r"^[京津沪渝冀豫云辽黑湘皖鲁新苏浙赣鄂桂甘晋蒙陕吉闽贵粤青藏川宁琼][A-Z][A-Z0-9]{5}$"
# 新能源车
new_energy_pattern = r"^[京津沪渝冀豫云辽黑湘皖鲁新苏浙赣鄂桂甘晋蒙陕吉闽贵粤青藏川宁琼][A-Z][A-Z0-9]{6}$"
# 武警车
wj_pattern = (
r"^WJ[京津沪渝冀豫云辽黑湘皖鲁新苏浙赣鄂桂甘晋蒙陕吉闽贵粤青藏川宁琼]?\d{5}$"
)
# 外交车
diplomatic_pattern = r"^(使|领)\d{5}$"
# 港澳车辆
hk_macau_pattern = r"^粤Z[A-Z0-9]{4}(港|澳)$"
patterns = [
normal_pattern,
new_energy_pattern,
wj_pattern,
diplomatic_pattern,
hk_macau_pattern,
]
for p in patterns:
if re.match(p, plate):
return plate
return "车速过快,无法识别"