diff --git a/bbit_ai/app/agent/vehicleImageAgent.py b/bbit_ai/app/agent/vehicleImageAgent.py index 8fbaa69..6647154 100644 --- a/bbit_ai/app/agent/vehicleImageAgent.py +++ b/bbit_ai/app/agent/vehicleImageAgent.py @@ -28,20 +28,21 @@ def send_analyze(state: State, prompt_text: str): def analysis(state: State): state["content"] = send_analyze( - state, + state, # todo """ -提示词示例 -你是一个图像分析助手。现在给你一张车的侧身照片,请你从图中分析车上运输的牲畜种类。 +你是一个图像分析助手。现在给你一张道路图片,你需要观察远离你的第二根车道上,画面中心的车辆,请你从分析该车辆。 要求: -1. 牲畜种类可能是:牛、羊、猪、鸡、鸭、鹅。 -2. 如果图中无法判断牲畜类型,请在备注字段 remark 中写明“无法识别”或你观察到的情况。 +1. have_animal 字段中填写 true +2. livestock_type 字段中填写 “货物种类”,例如 牛、羊、猪、鸡、鸭、鹅、钢管、土渣等任何你观察到的 +3. remark 字段 需要你简短的描述该车辆车身情况,例如什么颜色的车身,你需要完整的尽你所能形容一下。 3. 不允许输出多余文字,直接返回 JSON。 JSON 示例格式: { - "livestock_type": "<牲畜种类>", // 如果能识别就填牛/羊/猪/鸡/鸭/鹅 - "remark": "<备注>" // 如果无法识别,写明原因;否则可留空 + "have_animal": true, + "livestock_type": "“测试数据:” + 5位随机数", // 例如 测试数据34223 + "remark": "<备注>" // 车身描述 } 请确保输出的 JSON 可以被严格解析。 """, @@ -59,8 +60,6 @@ graph = workflow.compile() # 执行函数 - - async def get_vehicle_response(image_url: str): final_state = graph.invoke( { diff --git a/bbit_ai/app/ai/plate/__init__.py b/bbit_ai/app/ai/plate/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/bbit_ai/app/ai/plate/detect_rec_plate.py b/bbit_ai/app/ai/plate/detect_rec_plate.py new file mode 100644 index 0000000..1c17fc7 --- /dev/null +++ b/bbit_ai/app/ai/plate/detect_rec_plate.py @@ -0,0 +1,433 @@ +import os +from dataclasses import dataclass + +import cv2 +import numpy as np +from PIL import Image, ImageDraw, ImageFont +from ultralytics import YOLO + +from ai.plate.plate_recognition.double_plate_split_merge import get_split_merge +from ai.plate.plate_recognition.plate_rec import get_plate_result + + +@dataclass +class PlateResult: + plate: str + color: str + + +def collect_files(root_path): + file_list = [] + for root, _, files in os.walk(root_path): + for name in files: + file_list.append(os.path.join(root, name)) + return sorted(file_list) + + +def four_point_transform(image, pts): + rect = pts.astype(np.float32) + (tl, tr, br, bl) = rect + + width_a = np.sqrt(((br[0] - bl[0]) ** 2) + ((br[1] - bl[1]) ** 2)) + width_b = np.sqrt(((tr[0] - tl[0]) ** 2) + ((tr[1] - tl[1]) ** 2)) + max_width = max(int(width_a), int(width_b)) + + height_a = np.sqrt(((tr[0] - br[0]) ** 2) + ((tr[1] - br[1]) ** 2)) + height_b = np.sqrt(((tl[0] - bl[0]) ** 2) + ((tl[1] - bl[1]) ** 2)) + max_height = max(int(height_a), int(height_b)) + + dst = np.array( + [ + [0, 0], + [max_width - 1, 0], + [max_width - 1, max_height - 1], + [0, max_height - 1], + ], + dtype=np.float32, + ) + matrix = cv2.getPerspectiveTransform(rect, dst) + return cv2.warpPerspective(image, matrix, (max_width, max_height)) + + +def load_model(weights, device): + model = YOLO(weights) + model.to(device) + return model + + +def det_rec_plate(img_ori, detect_model, plate_rec_model, device, conf=0.3, iou=0.5): + result_list = [] + results = detect_model(img_ori, conf=conf, iou=iou, verbose=False) + + for result in results: + boxes = result.boxes + keypoints = result.keypoints + if len(boxes) == 0 or keypoints is None: + continue + + kpts_xy = keypoints.xy + num_det = min(len(boxes), len(kpts_xy)) + for idx in range(num_det): + box = boxes.xyxy[idx].cpu().numpy() + det_conf = float(boxes.conf[idx]) + plate_type = int(boxes.cls[idx]) + + landmarks = kpts_xy[idx].cpu().numpy().astype(np.int64) + roi_img = four_point_transform(img_ori, landmarks) + if plate_type == 1: + roi_img = get_split_merge(roi_img) + + plate_number, _, plate_color, color_conf = get_plate_result( + roi_img, device, plate_rec_model, is_color=True + ) + + result_list.append( + { + "plate_no": plate_number, + "plate_color": plate_color, + "rect": [int(v) for v in box], + "detect_conf": det_conf, + "landmarks": landmarks.tolist(), + "roi_height": roi_img.shape[0], + "color_conf": color_conf, + "plate_type": plate_type, # 0: 单层, 1: 双层 + } + ) + return result_list + + +def _clamp(value, low, high): + return max(low, min(high, value)) + + +def _normalize_rect(rect): + if not rect or len(rect) < 4: + return None + x1 = int(round(float(rect[0]))) + y1 = int(round(float(rect[1]))) + x2 = int(round(float(rect[2]))) + y2 = int(round(float(rect[3]))) + left = min(x1, x2) + top = min(y1, y2) + right = max(x1, x2) + bottom = max(y1, y2) + if right <= left or bottom <= top: + return None + return [left, top, right, bottom] + + +def _get_plate_theme(plate_color): + theme_map = { + "蓝色": { + "bg": (72, 33, 6), + "border": (250, 165, 96), + "text": (254, 242, 224), + "glow": (246, 130, 59), + }, + "黄色": { + "bg": (0, 49, 74), + "border": (21, 204, 250), + "text": (195, 249, 254), + "glow": (8, 179, 234), + }, + "绿色": { + "bg": (27, 53, 4), + "border": (128, 222, 74), + "text": (231, 252, 220), + "glow": (94, 197, 34), + }, + "白色": { + "bg": (68, 50, 38), + "border": (240, 232, 226), + "text": (250, 250, 248), + "glow": (184, 163, 148), + }, + "黑色": { + "bg": (20, 12, 8), + "border": (184, 163, 148), + "text": (240, 232, 226), + "glow": (139, 116, 100), + }, + } + return theme_map.get( + plate_color, + { + "bg": (43, 24, 8), + "border": (248, 189, 56), + "text": (255, 248, 232), + "glow": (200, 140, 32), + }, + ) + + +def _draw_alpha_rect(img, x1, y1, x2, y2, color, alpha=0.75): + h, w = img.shape[:2] + x1 = _clamp(x1, 0, w - 1) + y1 = _clamp(y1, 0, h - 1) + x2 = _clamp(x2, 0, w) + y2 = _clamp(y2, 0, h) + if x2 <= x1 or y2 <= y1: + return + roi = img[y1:y2, x1:x2] + overlay = np.full_like(roi, color, dtype=np.uint8) + cv2.addWeighted(overlay, alpha, roi, 1 - alpha, 0, roi) + + +def _measure_text(text, text_size=16): + try: + font = ImageFont.truetype(get_font_file_path(), text_size, encoding="utf-8") + left, top, right, bottom = font.getbbox(text) + return right - left, bottom - top + except Exception: + print("字体文件加载失败") + size, baseline = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.45, 1) + return size[0], size[1] + baseline + + +def _draw_glow_border(img, x1, y1, x2, y2, border_color, glow_color): + h, w = img.shape[:2] + x1 = _clamp(x1, 0, w - 1) + y1 = _clamp(y1, 0, h - 1) + x2 = _clamp(x2, 0, w - 1) + y2 = _clamp(y2, 0, h - 1) + if x2 <= x1 or y2 <= y1: + return + glow_layer = np.zeros_like(img) + cv2.rectangle(glow_layer, (x1, y1), (x2, y2), glow_color, 2) + glow_layer = cv2.GaussianBlur(glow_layer, (0, 0), sigmaX=1.6, sigmaY=1.6) + cv2.addWeighted(glow_layer, 0.45, img, 1.0, 0, img) + cv2.rectangle(img, (x1, y1), (x2, y2), border_color, 1) + + +def _draw_tech_box(img, x1, y1, x2, y2, border_color, glow_color, track_id=None): + h, w = img.shape[:2] + x1 = _clamp(x1, 0, w - 1) + y1 = _clamp(y1, 0, h - 1) + x2 = _clamp(x2, 0, w - 1) + y2 = _clamp(y2, 0, h - 1) + if x2 <= x1 or y2 <= y1: + return + + bw = x2 - x1 + bh = y2 - y1 + diag = float(np.hypot(bw, bh)) + base_thick = _clamp(int(round(diag / 70.0)), 2, 5) + glow_sigma = _clamp(diag / 55.0, 1.2, 3.6) + + glow_layer = np.zeros_like(img) + cv2.rectangle(glow_layer, (x1, y1), (x2, y2), glow_color, max(1, base_thick - 1)) + glow_layer = cv2.GaussianBlur( + glow_layer, (0, 0), sigmaX=glow_sigma, sigmaY=glow_sigma + ) + cv2.addWeighted(glow_layer, 0.48, img, 1.0, 0, img) + cv2.rectangle(img, (x1, y1), (x2, y2), border_color, max(1, base_thick - 1)) + + corner_len = _clamp(int(round(min(bw, bh) * 0.28)), 8, 20) + t = base_thick + cv2.line(img, (x1, y1), (x1 + corner_len, y1), border_color, t) + cv2.line(img, (x1, y1), (x1, y1 + corner_len), border_color, t) + cv2.line(img, (x2, y1), (x2 - corner_len, y1), border_color, t) + cv2.line(img, (x2, y1), (x2, y1 + corner_len), border_color, t) + cv2.line(img, (x1, y2), (x1 + corner_len, y2), border_color, t) + cv2.line(img, (x1, y2), (x1, y2 - corner_len), border_color, t) + cv2.line(img, (x2, y2), (x2 - corner_len, y2), border_color, t) + cv2.line(img, (x2, y2), (x2, y2 - corner_len), border_color, t) + + if track_id is not None: + badge = "T%02d" % track_id + badge_w_txt, badge_h_txt = _measure_text(badge, text_size=12) + pad_x = 5 + pad_y = 3 + badge_w = badge_w_txt + pad_x * 2 + badge_h = badge_h_txt + pad_y * 2 + bx2 = _clamp(x2, badge_w + 2, w - 2) + by1 = _clamp(y1 - badge_h - 2, 2, h - badge_h - 2) + bx1 = bx2 - badge_w + by2 = by1 + badge_h + _draw_alpha_rect(img, bx1, by1, bx2, by2, (18, 18, 18), alpha=0.65) + cv2.rectangle(img, (bx1, by1), (bx2, by2), border_color, 1) + text_x = bx1 + pad_x + text_y = by1 + max(1, int((badge_h - badge_h_txt) / 2)) + img[:] = cv2ImgAddText(img, badge, text_x, text_y, border_color, 12) + + +def cv2ImgAddText(img, text, left, top, textColor=(0, 255, 0), textSize=20): + if isinstance(img, np.ndarray): # 判断是否OpenCV图片类型 + img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB)) + draw = ImageDraw.Draw(img) + fontText = ImageFont.truetype(get_font_file_path(), textSize, encoding="utf-8") + draw.text((left, top), text, textColor, font=fontText) + return cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR) + + +def get_font_file_path(): + # win + # font_path = r"C:\Users\BBIT\Desktop\yolo26-plate-main\fonts\platech.ttf" + # linux + font_path = r"/app/models/sentinel/platech.ttf" + + return font_path + + +def _draw_tech_landmark(img, x, y, border_color, glow_color): + h, w = img.shape[:2] + x = _clamp(int(round(x)), 0, w - 1) + y = _clamp(int(round(y)), 0, h - 1) + glow_layer = np.zeros_like(img) + cv2.circle(glow_layer, (x, y), 5, glow_color, -1) + glow_layer = cv2.GaussianBlur(glow_layer, (0, 0), sigmaX=1.2, sigmaY=1.2) + cv2.addWeighted(glow_layer, 0.5, img, 1.0, 0, img) + cv2.circle(img, (x, y), 2, border_color, -1) + cv2.circle(img, (x, y), 4, border_color, 1) + + +def _plate_width_from_landmarks(landmarks, fallback_width): + if not landmarks or len(landmarks) < 4: + return float(fallback_width) + try: + p0 = np.array(landmarks[0], dtype=np.float32) + p1 = np.array(landmarks[1], dtype=np.float32) + p2 = np.array(landmarks[2], dtype=np.float32) + p3 = np.array(landmarks[3], dtype=np.float32) + top_w = float(np.linalg.norm(p1 - p0)) + bottom_w = float(np.linalg.norm(p2 - p3)) + width = (top_w + bottom_w) / 2.0 + if np.isfinite(width) and width > 1: + return width + except Exception: + pass + return float(fallback_width) + + +def _plate_height_from_landmarks(landmarks, fallback_height): + if not landmarks or len(landmarks) < 4: + return float(fallback_height) + try: + p0 = np.array(landmarks[0], dtype=np.float32) + p1 = np.array(landmarks[1], dtype=np.float32) + p2 = np.array(landmarks[2], dtype=np.float32) + p3 = np.array(landmarks[3], dtype=np.float32) + left_h = float(np.linalg.norm(p3 - p0)) + right_h = float(np.linalg.norm(p2 - p1)) + height = (left_h + right_h) / 2.0 + if np.isfinite(height) and height > 1: + return height + except Exception: + pass + return float(fallback_height) + + +def _fit_font_size(text, max_w, max_h, min_size=10, max_size=24): + if max_w <= 0 or max_h <= 0: + return min_size + for size in range(max_size, min_size - 1, -1): + tw, th = _measure_text(text, text_size=size) + if tw <= max_w and th <= max_h: + return size + return min_size + + +def draw_result(orgimg, result_list): + landmark_colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0)] + result_str = [] + img_h, img_w = orgimg.shape[:2] + for idx, result in enumerate(result_list, start=1): + raw_rect = _normalize_rect(result.get("rect")) + if raw_rect is None: + continue + + x1, y1, x2, y2 = raw_rect + w = x2 - x1 + h = y2 - y1 + padding_w = int(round(0.05 * w)) + padding_h = int(round(0.11 * h)) + rx1 = _clamp(x1 - padding_w, 0, img_w - 1) + ry1 = _clamp(y1 - padding_h, 0, img_h - 1) + rx2 = _clamp(x2 + padding_w, 0, img_w - 1) + ry2 = _clamp(y2 + padding_h, 0, img_h - 1) + + landmarks = result.get("landmarks", []) + plate_no = result.get("plate_no", "") + plate_color = result.get("plate_color", "") + if result.get("plate_type", 0) == 1: + result_p = "%s %s双层" % (plate_no, plate_color) + else: + result_p = "%s %s" % (plate_no, plate_color) + result_str.append(result_p) + + theme = _get_plate_theme(plate_color) + for i in range(min(4, len(landmarks))): + point = landmarks[i] + if len(point) < 2: + continue + point_color = landmark_colors[i] + _draw_tech_landmark(orgimg, point[0], point[1], point_color, point_color) + + _draw_tech_box( + orgimg, rx1, ry1, rx2, ry2, theme["border"], theme["glow"], track_id=idx + ) + + label = "%s | %s" % (plate_no, plate_color) + plate_w = _plate_width_from_landmarks(landmarks, rx2 - rx1) + plate_h = _plate_height_from_landmarks(landmarks, ry2 - ry1) + pre_card_h = _clamp(int(round(plate_h)), 24, min(110, img_h - 4)) + pre_pad_y = _clamp(int(round(pre_card_h * 0.16)), 3, 10) + pre_inner_h = max(8, pre_card_h - pre_pad_y * 2) + pre_max_font = _clamp(int(round(pre_card_h * 0.72)), 14, 44) + pre_min_font = _clamp(int(round(pre_card_h * 0.42)), 10, pre_max_font) + pre_font_size = _fit_font_size( + label, 4096, pre_inner_h, min_size=pre_min_font, max_size=pre_max_font + ) + pre_text_w, _ = _measure_text(label, text_size=pre_font_size) + + min_w_by_text = pre_text_w + 20 + base_w_by_plate = int(round(plate_w * 1.05)) + card_w = max(90, base_w_by_plate, min_w_by_text) + card_w = min(card_w, img_w - 8) + + card_h = pre_card_h + card_pad_x = _clamp(int(round(card_w * 0.08)), 8, 18) + card_pad_y = _clamp(int(round(card_h * 0.16)), 3, 10) + + card_x = int(rx1 + (rx2 - rx1 - card_w) / 2) + card_x = _clamp(card_x, 4, max(4, img_w - card_w - 4)) + card_y = ry1 - card_h - 2 + if card_y < 2: + card_y = _clamp(ry1 + 2, 2, max(2, img_h - card_h - 2)) + + _draw_alpha_rect( + orgimg, + card_x, + card_y, + card_x + card_w, + card_y + card_h, + theme["bg"], + alpha=0.78, + ) + _draw_glow_border( + orgimg, + card_x, + card_y, + card_x + card_w, + card_y + card_h, + theme["border"], + theme["glow"], + ) + + inner_w = max(8, card_w - card_pad_x * 2) + inner_h = max(8, card_h - card_pad_y * 2) + dynamic_max_font = _clamp(int(round(card_h * 0.72)), 14, 44) + dynamic_min_font = _clamp(int(round(card_h * 0.42)), 10, dynamic_max_font) + font_size = _fit_font_size( + label, + inner_w, + inner_h, + min_size=dynamic_min_font, + max_size=dynamic_max_font, + ) + text_w, text_h = _measure_text(label, text_size=font_size) + text_x = card_x + max(card_pad_x, int((card_w - text_w) / 2)) + text_y = card_y + max(card_pad_y - 1, int((card_h - text_h) / 2)) + orgimg = cv2ImgAddText(orgimg, label, text_x, text_y, theme["text"], font_size) + + return orgimg, result_str diff --git a/bbit_ai/app/ai/plate/my_plate.py b/bbit_ai/app/ai/plate/my_plate.py new file mode 100644 index 0000000..d318f6d --- /dev/null +++ b/bbit_ai/app/ai/plate/my_plate.py @@ -0,0 +1,98 @@ +import os +from dataclasses import dataclass + +import cv2 +import torch +from ultralytics import YOLO + +from ai.plate.detect_rec_plate import det_rec_plate, draw_result +from ai.plate.plate_recognition.plate_rec import init_model +from config.yolo import logger + + +@dataclass +class PlateInfo: + plate_no: str + plate_color: str + result_img_path: str + + +class PlateRecognizer: + _instance = None + _ready = False + + def __new__(cls, *args, **kwargs): + if cls._instance is None: + cls._instance = super(PlateRecognizer, cls).__new__(cls) + return cls._instance + + def __init__( + self, + detect_model_path, + rec_model_path, + output_dir="result", + device="cuda" if torch.cuda.is_available() else "cpu", + ): + if hasattr(self, "_initialized") and self._initialized: + return + self.device = torch.device(device) + self.output_dir = output_dir + os.makedirs(self.output_dir, exist_ok=True) + + try: + # 模型加载 + self.detect_model = YOLO(detect_model_path) + self.detect_model.to(self.device) + self.detect_model.eval() + self.plate_rec_model = init_model( + self.device, rec_model_path, is_color=True + ) + self._initialized = True + self._ready = True + except Exception as e: + self.detect_model = None + self._ready = False + logger.error(f"❌车牌 YOLO 模型加载失败: {e}") + + def analyze_image(self, image_path): + img = cv2.imread(image_path) + if not self._ready or img is None: + return [] + + # 检测与识别 + result_list = det_rec_plate( + img, self.detect_model, self.plate_rec_model, self.device + ) + + # 可视化 & 保存 + vis_img, _ = draw_result(img, result_list) + result_img_path = os.path.join(self.output_dir, os.path.basename(image_path)) + cv2.imwrite(result_img_path, vis_img) + + # 构建返回结果 + results = [ + PlateInfo( + plate_no=r["plate_no"], + plate_color=r["plate_color"], + result_img_path=result_img_path, + ) + for r in result_list + ] + return results + + +# 初始化一次 +recognizer = PlateRecognizer( + detect_model_path="/app/models/sentinel/yolo26s-plate-detect.pt", + rec_model_path="/app/models/sentinel/plate_rec_color.pth", + output_dir="result", +) + + +# base_dir = Path(r"C:\Users\BBIT\Desktop\yolo26-plate-main") +# +# recognizer = PlateRecognizer( +# detect_model_path=str(base_dir / "weights" / "yolo26s-plate-detect.pt"), +# rec_model_path=str(base_dir / "weights" / "plate_rec_color.pth"), +# output_dir="result", +# ) diff --git a/bbit_ai/app/ai/plate/plate_recognition/double_plate_split_merge.py b/bbit_ai/app/ai/plate/plate_recognition/double_plate_split_merge.py new file mode 100644 index 0000000..24c6537 --- /dev/null +++ b/bbit_ai/app/ai/plate/plate_recognition/double_plate_split_merge.py @@ -0,0 +1,15 @@ +import os +import cv2 +import numpy as np +def get_split_merge(img): + h,w,c = img.shape + img_upper = img[0:int(5/12*h),:] + img_lower = img[int(1/3*h):,:] + img_upper = cv2.resize(img_upper,(img_lower.shape[1],img_lower.shape[0])) + new_img = np.hstack((img_upper,img_lower)) + return new_img + +if __name__=="__main__": + img = cv2.imread("double_plate/tmp8078.png") + new_img =get_split_merge(img) + cv2.imwrite("double_plate/new.jpg",new_img) diff --git a/bbit_ai/app/ai/plate/plate_recognition/plateNet.py b/bbit_ai/app/ai/plate/plate_recognition/plateNet.py new file mode 100644 index 0000000..ce9a982 --- /dev/null +++ b/bbit_ai/app/ai/plate/plate_recognition/plateNet.py @@ -0,0 +1,203 @@ +import torch.nn as nn +import torch + + +class myNet_ocr(nn.Module): + def __init__(self,cfg=None,num_classes=78,export=False): + super(myNet_ocr, self).__init__() + if cfg is None: + cfg =[32,32,64,64,'M',128,128,'M',196,196,'M',256,256] + # cfg =[32,32,'M',64,64,'M',128,128,'M',256,256] + self.feature = self.make_layers(cfg, True) + self.export = export + # self.classifier = nn.Linear(cfg[-1], num_classes) + # self.loc = nn.MaxPool2d((2, 2), (5, 1), (0, 1),ceil_mode=True) + # self.loc = nn.AvgPool2d((2, 2), (5, 2), (0, 1),ceil_mode=False) + self.loc = nn.MaxPool2d((5, 2), (1, 1),(0,1),ceil_mode=False) + self.newCnn=nn.Conv2d(cfg[-1],num_classes,1,1) + # self.newBn=nn.BatchNorm2d(num_classes) + def make_layers(self, cfg, batch_norm=False): + layers = [] + in_channels = 3 + for i in range(len(cfg)): + if i == 0: + conv2d =nn.Conv2d(in_channels, cfg[i], kernel_size=5,stride =1) + if batch_norm: + layers += [conv2d, nn.BatchNorm2d(cfg[i]), nn.ReLU(inplace=True)] + else: + layers += [conv2d, nn.ReLU(inplace=True)] + in_channels = cfg[i] + else : + if cfg[i] == 'M': + layers += [nn.MaxPool2d(kernel_size=3, stride=2,ceil_mode=True)] + else: + conv2d = nn.Conv2d(in_channels, cfg[i], kernel_size=3, padding=(1,1),stride =1) + if batch_norm: + layers += [conv2d, nn.BatchNorm2d(cfg[i]), nn.ReLU(inplace=True)] + else: + layers += [conv2d, nn.ReLU(inplace=True)] + in_channels = cfg[i] + return nn.Sequential(*layers) + + def forward(self, x): + x = self.feature(x) + x=self.loc(x) + x=self.newCnn(x) + # x=self.newBn(x) + if self.export: + conv = x.squeeze(2) # b *512 * width + conv = conv.transpose(2,1) # [w, b, c] + # conv =conv.argmax(dim=2) + return conv + else: + b, c, h, w = x.size() + assert h == 1, "the height of conv must be 1" + conv = x.squeeze(2) # b *512 * width + conv = conv.permute(2, 0, 1) # [w, b, c] + # output = F.log_softmax(self.rnn(conv), dim=2) + output = torch.softmax(conv, dim=2) + return output + +myCfg = [32,'M',64,'M',96,'M',128,'M',256] +class myNet(nn.Module): + def __init__(self,cfg=None,num_classes=3): + super(myNet, self).__init__() + if cfg is None: + cfg = myCfg + self.feature = self.make_layers(cfg, True) + self.classifier = nn.Linear(cfg[-1], num_classes) + def make_layers(self, cfg, batch_norm=False): + layers = [] + in_channels = 3 + for i in range(len(cfg)): + if i == 0: + conv2d =nn.Conv2d(in_channels, cfg[i], kernel_size=5,stride =1) + if batch_norm: + layers += [conv2d, nn.BatchNorm2d(cfg[i]), nn.ReLU(inplace=True)] + else: + layers += [conv2d, nn.ReLU(inplace=True)] + in_channels = cfg[i] + else : + if cfg[i] == 'M': + layers += [nn.MaxPool2d(kernel_size=3, stride=2,ceil_mode=True)] + else: + conv2d = nn.Conv2d(in_channels, cfg[i], kernel_size=3, padding=1,stride =1) + if batch_norm: + layers += [conv2d, nn.BatchNorm2d(cfg[i]), nn.ReLU(inplace=True)] + else: + layers += [conv2d, nn.ReLU(inplace=True)] + in_channels = cfg[i] + return nn.Sequential(*layers) + + def forward(self, x): + x = self.feature(x) + x = nn.AvgPool2d(kernel_size=3, stride=1)(x) + x = x.view(x.size(0), -1) + y = self.classifier(x) + return y + + +class MyNet_color(nn.Module): + def __init__(self, class_num=6): + super(MyNet_color, self).__init__() + self.class_num = class_num + self.backbone = nn.Sequential( + nn.Conv2d(in_channels=3, out_channels=16, kernel_size=(5, 5), stride=(1, 1)), # 0 + torch.nn.BatchNorm2d(16), + nn.ReLU(), + nn.MaxPool2d(kernel_size=(2, 2)), + nn.Dropout(0), + nn.Flatten(), + nn.Linear(480, 64), + nn.Dropout(0), + nn.ReLU(), + nn.Linear(64, class_num), + nn.Dropout(0), + nn.Softmax(1) + ) + + def forward(self, x): + logits = self.backbone(x) + + return logits + + +class myNet_ocr_color(nn.Module): + def __init__(self,cfg=None,num_classes=78,export=False,color_num=None): + super(myNet_ocr_color, self).__init__() + if cfg is None: + cfg =[32,32,64,64,'M',128,128,'M',196,196,'M',256,256] + # cfg =[32,32,'M',64,64,'M',128,128,'M',256,256] + self.feature = self.make_layers(cfg, True) + self.export = export + self.color_num=color_num + self.conv_out_num=12 #颜色第一个卷积层输出通道12 + if self.color_num: + self.conv1=nn.Conv2d(cfg[-1],self.conv_out_num,kernel_size=3,stride=2) + self.bn1=nn.BatchNorm2d(self.conv_out_num) + self.relu1=nn.ReLU(inplace=True) + self.gap =nn.AdaptiveAvgPool2d(output_size=1) + self.color_classifier=nn.Conv2d(self.conv_out_num,self.color_num,kernel_size=1,stride=1) + self.color_bn = nn.BatchNorm2d(self.color_num) + self.flatten = nn.Flatten() + self.loc = nn.MaxPool2d((5, 2), (1, 1),(0,1),ceil_mode=False) + self.newCnn=nn.Conv2d(cfg[-1],num_classes,1,1) + # self.newBn=nn.BatchNorm2d(num_classes) + def make_layers(self, cfg, batch_norm=False): + layers = [] + in_channels = 3 + for i in range(len(cfg)): + if i == 0: + conv2d =nn.Conv2d(in_channels, cfg[i], kernel_size=5,stride =1) + if batch_norm: + layers += [conv2d, nn.BatchNorm2d(cfg[i]), nn.ReLU(inplace=True)] + else: + layers += [conv2d, nn.ReLU(inplace=True)] + in_channels = cfg[i] + else : + if cfg[i] == 'M': + layers += [nn.MaxPool2d(kernel_size=3, stride=2,ceil_mode=True)] + else: + conv2d = nn.Conv2d(in_channels, cfg[i], kernel_size=3, padding=(1,1),stride =1) + if batch_norm: + layers += [conv2d, nn.BatchNorm2d(cfg[i]), nn.ReLU(inplace=True)] + else: + layers += [conv2d, nn.ReLU(inplace=True)] + in_channels = cfg[i] + return nn.Sequential(*layers) + + def forward(self, x): + x = self.feature(x) + if self.color_num: + x_color=self.conv1(x) + x_color=self.bn1(x_color) + x_color =self.relu1(x_color) + x_color = self.color_classifier(x_color) + x_color = self.color_bn(x_color) + x_color =self.gap(x_color) + x_color = self.flatten(x_color) + x=self.loc(x) + x=self.newCnn(x) + + if self.export: + conv = x.squeeze(2) # b *512 * width + conv = conv.transpose(2,1) # [w, b, c] + if self.color_num: + return conv,x_color + return conv + else: + b, c, h, w = x.size() + assert h == 1, "the height of conv must be 1" + conv = x.squeeze(2) # b *512 * width + conv = conv.permute(2, 0, 1) # [w, b, c] + output = F.log_softmax(conv, dim=2) + if self.color_num: + return output,x_color + return output + + +if __name__ == '__main__': + x = torch.randn(1,3,48,216) + model = myNet_ocr(num_classes=78,export=True) + out = model(x) + print(out.shape) \ No newline at end of file diff --git a/bbit_ai/app/ai/plate/plate_recognition/plate_rec.py b/bbit_ai/app/ai/plate/plate_recognition/plate_rec.py new file mode 100644 index 0000000..d64f9ac --- /dev/null +++ b/bbit_ai/app/ai/plate/plate_recognition/plate_rec.py @@ -0,0 +1,134 @@ +import os +import time + +import cv2 +import numpy as np +import torch + +from ai.plate.plate_recognition.plateNet import myNet_ocr_color + + +def cv_imread(path): # 可以读取中文路径的图片 + img = cv2.imdecode(np.fromfile(path, dtype=np.uint8), -1) + return img + + +def allFilePath(rootPath, allFIleList): + fileList = os.listdir(rootPath) + for temp in fileList: + if os.path.isfile(os.path.join(rootPath, temp)): + if temp.endswith(".jpg") or temp.endswith(".png") or temp.endswith(".JPG"): + allFIleList.append(os.path.join(rootPath, temp)) + else: + allFilePath(os.path.join(rootPath, temp), allFIleList) + + +device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") +color = ["黑色", "蓝色", "绿色", "白色", "黄色"] +plateName = r"#京沪津渝冀晋蒙辽吉黑苏浙皖闽赣鲁豫鄂湘粤桂琼川贵云藏陕甘青宁新学警港澳挂使领民航危0123456789ABCDEFGHJKLMNPQRSTUVWXYZ险品" +mean_value, std_value = (0.588, 0.193) + + +def decodePlate(preds): + pre = 0 + newPreds = [] + index = [] + for i in range(len(preds)): + if preds[i] != 0 and preds[i] != pre: + newPreds.append(preds[i]) + index.append(i) + pre = preds[i] + return newPreds, index + + +def image_processing(img, device): + img = cv2.resize(img, (168, 48)) + img = np.reshape(img, (48, 168, 3)) + + # normalize + img = img.astype(np.float32) + img = (img / 255.0 - mean_value) / std_value + img = img.transpose([2, 0, 1]) + img = torch.from_numpy(img) + + img = img.to(device) + img = img.view(1, *img.size()) + return img + + +def get_plate_result(img, device, model, is_color=False): + input = image_processing(img, device) + if is_color: # 是否识别颜色 + preds, color_preds = model(input) + color_preds = torch.softmax(color_preds, dim=-1) + color_conf, color_index = torch.max(color_preds, dim=-1) + color_conf = color_conf.item() + else: + preds = model(input) + preds = torch.softmax(preds, dim=-1) + prob, index = preds.max(dim=-1) + index = index.view(-1).detach().cpu().numpy() + prob = prob.view(-1).detach().cpu().numpy() + + # preds=preds.view(-1).detach().cpu().numpy() + newPreds, new_index = decodePlate(index) + prob = prob[new_index] + plate = "" + for i in newPreds: + plate += plateName[i] + # if not (plate[0] in plateName[1:44] ): + # return "" + if is_color: + return ( + plate, + prob, + color[color_index], + color_conf, + ) # 返回车牌号以及每个字符的概率,以及颜色,和颜色的概率 + else: + return plate, prob + + +def init_model(device, model_path, is_color=False): + # print( print(sys.path)) + # model_path ="plate_recognition/model/checkpoint_61_acc_0.9715.pth" + check_point = torch.load(model_path, map_location=device) + model_state = check_point["state_dict"] + cfg = check_point["cfg"] + color_classes = 0 + if is_color: + color_classes = 5 # 颜色类别数 + model = myNet_ocr_color( + num_classes=len(plateName), export=True, cfg=cfg, color_num=color_classes + ) + + model.load_state_dict(model_state, strict=False) + model.to(device) + model.eval() + return model + + +# model = init_model(device) +if __name__ == "__main__": + model_path = r"weights/plate_rec_color.pth" + image_path = "images/tmp2424.png" + testPath = r"/mnt/Gpan/Mydata/pytorchPorject/CRNN/crnn_plate_recognition/images" + fileList = [] + allFilePath(testPath, fileList) + # result = get_plate_result(image_path,device) + # print(result) + is_color = False + model = init_model(device, model_path, is_color=is_color) + right = 0 + begin = time.time() + + for imge_path in fileList: + img = cv2.imread(imge_path) + if is_color: + plate, _, plate_color, _ = get_plate_result( + img, device, model, is_color=is_color + ) + print(plate) + else: + plate, _ = get_plate_result(img, device, model, is_color=is_color) + print(plate, imge_path) diff --git a/bbit_ai/app/app.py b/bbit_ai/app/app.py index a63b018..f934c73 100644 --- a/bbit_ai/app/app.py +++ b/bbit_ai/app/app.py @@ -1,4 +1,5 @@ import asyncio +from contextlib import asynccontextmanager from fastapi import FastAPI, Request from fastapi.middleware.cors import CORSMiddleware @@ -21,12 +22,25 @@ from routers.Service import serviceRouter from routers.System import systemRouter from routers.Vision import visionRouter from routers.WS import iot_ws_router -from service.RabbitMQ import sentinel_pull_analysis_async +from service.RabbitMQ import ( + mq_client, +) + + +@asynccontextmanager +async def lifespan(app: FastAPI): + # 应用启动时初始化 MQ + await mq_client.init() + # 启动消费者 + await mq_client.start_all_consumer() + yield + # 应用关闭时关闭 MQ 连接 + if mq_client._connection: + await mq_client._connection.close() async def ai_lab(): - app = FastAPI(title="BBIT_AI") - + app = FastAPI(title="BBIT_AI", lifespan=lifespan) origins = [ "http://localhost:8091", # Vite dev 默认端口 "https://ai.ronsunny.cn:8090", @@ -71,16 +85,13 @@ async def ai_lab(): async def main(): - # 初始化模型 YOLOSingleton.init_model() # 主干AI实验室FastAPI服务 task_api = asyncio.create_task(ai_lab()) - # RabbitMQ服务 - task_mq = asyncio.create_task(sentinel_pull_analysis_async()) # 等 HTTP 服务启动后再启动 MQTT task_mqtt = asyncio.create_task(mqtt_client_runner()) - await asyncio.gather(task_api, task_mq, task_mqtt) + await asyncio.gather(task_api, task_mqtt) # MCP服务-ailab # endpoint_url_ai_lab = "wss://ai.ronsunny.cn:8090/aimcp/mcp_endpoint/mcp/?token=TsSP9lBq6Oa1WMkachHoS2TtNt4GKV/Gli24pk5Rjpk%3D" diff --git a/bbit_ai/app/config/minIO.py b/bbit_ai/app/config/minIO.py index 3b12aa7..4ed84d9 100644 --- a/bbit_ai/app/config/minIO.py +++ b/bbit_ai/app/config/minIO.py @@ -1,6 +1,7 @@ from datetime import timedelta from minio import Minio +from minio.commonconfig import CopySource # MinIO 客户端初始化 minio_client = Minio( @@ -28,7 +29,7 @@ def get_upload_token(bucket_name, object_name, xpires=timedelta(minutes=15)): ) -def get_temp_url(bucket_name, object_name): +def get_temp_url(bucket_name, object_name, seconds: float = 3600): # 如果 object_name 为 None 或空字符串,则返回默认图片 if not object_name or not bucket_name: bucket_name = "system" @@ -36,7 +37,7 @@ def get_temp_url(bucket_name, object_name): # 使用 presigned_get_object 获取临时 URL return minio_client.presigned_get_object( - bucket_name, object_name, expires=timedelta(seconds=3600) + bucket_name, object_name, expires=timedelta(seconds=seconds) ) @@ -51,3 +52,28 @@ def get_temp_url_dict(bucket_name, object_dict, object_name): return minio_client.presigned_get_object( bucket_name, object_dict + "/" + object_name, expires=timedelta(seconds=3600) ) + + +# 移动文件(实际上是 copy + delete) +def move_file(bucket_name, source_object_name, target_object_name): + """ + bucket_name: bucket名称 + source_object_name: 原对象路径,例如 folder1/test.jpg + target_object_name: 目标对象路径,例如 folder2/test.jpg + """ + # 复制到新位置 + minio_client.copy_object( + bucket_name, target_object_name, CopySource(bucket_name, source_object_name) + ) + + # 删除原文件 + minio_client.remove_object(bucket_name, source_object_name) + + +# 删除文件 +def delete_file(bucket_name, object_name): + """ + bucket_name: bucket名称 + object_name: 文件路径,例如 folder/test.jpg + """ + minio_client.remove_object(bucket_name, object_name) diff --git a/bbit_ai/app/config/rabbitMQ.py b/bbit_ai/app/config/rabbitMQ.py index 33fa827..1cd7712 100644 --- a/bbit_ai/app/config/rabbitMQ.py +++ b/bbit_ai/app/config/rabbitMQ.py @@ -3,7 +3,19 @@ from utils.GlobalVariable import LOCAL_IP RABBIT_HOST = LOCAL_IP RABBIT_USER = "ai_lab" RABBIT_PASSWORD = "123456" -QUEUE_NAME = "analysis_queue" + + RABBIT_VHOST = "bbit_ai" +QUEUE_NAME = "analysis_queue" + SENTINEL_VHOST = "sentinel" +SENTINEL_ANALYSIS_ALL_QUEUE_NAME = "sentinel.analysis_all_queue" +SENTINEL_ANALYSIS_SIDE_QUEUE_NAME = "sentinel.analysis_side_queue" +SENTINEL_ANALYSIS_FRONT_QUEUE_NAME = "sentinel.analysis_front_queue" + +SENTINEL_FRONT_REQUEST_QUEUE = "sentinel.front_pic" + + +def get_sentinel_front_queue_name(device_id): + return f"{SENTINEL_FRONT_REQUEST_QUEUE}.{device_id}" diff --git a/bbit_ai/app/config/yolo.py b/bbit_ai/app/config/yolo.py index 742cd1f..ce2c4cb 100644 --- a/bbit_ai/app/config/yolo.py +++ b/bbit_ai/app/config/yolo.py @@ -99,7 +99,7 @@ class YOLOSingleton: except Exception as e: cls._model = None cls._ready = False - logger.error(f"❌ YOLO 模型加载失败: {e}") + logger.error(f"❌蚕茧 YOLO 模型加载失败: {e}") @classmethod def detect(cls, img_bytes: bytes): diff --git a/bbit_ai/app/db/postgres/annual_meeting.py b/bbit_ai/app/db/postgres/annual_meeting.py index b42c75e..0c97631 100644 --- a/bbit_ai/app/db/postgres/annual_meeting.py +++ b/bbit_ai/app/db/postgres/annual_meeting.py @@ -35,34 +35,46 @@ import time def reset_all_exchange_status(): - """将所有记录 is_finished 置为 False,并随机 position(以当前时间作为随机种子)""" + """将所有记录 is_finished 置为 False, + 且 gift_code == 2 的记录强制排在最后,其余随机排序 + """ with pg_pool.getConn() as conn: with conn.cursor() as cur: - # 获取总记录数 - cur.execute("SELECT id FROM annual_meeting_exchange") - ids = [row[0] for row in cur.fetchall()] + # 取出 id 和 gift_code + cur.execute("SELECT id, gift_code FROM annual_meeting_exchange") + rows = cur.fetchall() - # 用当前时间戳作为随机种子 - seed = int(time.time() * 1000) # 毫秒级 + # 分组 + normal_ids = [r[0] for r in rows if r[1] != 2] + tail_ids = [r[0] for r in rows if r[1] == 2] + + # 随机种子 + seed = int(time.time() * 1000) random.seed(seed) - # 生成随机顺序的 position - positions = list(range(1, len(ids) + 1)) - random.shuffle(positions) + # 只打乱非 gift_code == 2 的部分 + random.shuffle(normal_ids) - # 更新每条记录 - for record_id, pos in zip(ids, positions): + # 合并顺序:普通在前,gift_code==2 在后 + ordered_ids = normal_ids + tail_ids + + # 依次更新 sort + for idx, record_id in enumerate(ordered_ids, start=1): cur.execute( """ UPDATE annual_meeting_exchange SET is_finished = FALSE, sort = %s WHERE id = %s """, - (pos, record_id), + (idx, record_id), ) conn.commit() - return {"updated_count": len(ids), "seed_used": seed} + return { + "updated_count": len(ordered_ids), + "seed_used": seed, + "tail_count": len(tail_ids), + } def reset_user_status(target_user_id: str): diff --git a/bbit_ai/app/db/postgres/sentinel.py b/bbit_ai/app/db/postgres/sentinel.py index 1bb4d30..f0bd53d 100644 --- a/bbit_ai/app/db/postgres/sentinel.py +++ b/bbit_ai/app/db/postgres/sentinel.py @@ -1,6 +1,5 @@ from config.minIO import get_temp_url_dict from config.pgDb import pg_pool -from models.SentinelRecordRequest import SentinelRecordRequest from utils.MyUtils import format_datetime @@ -90,6 +89,7 @@ def get_sentinel_record_list_db_page( r.vehicle_image, r.livestock_type, r.livestock_source, + r.license_plate_color, r.is_inspected, r.dept_id, sd.name AS dept_name, @@ -116,6 +116,7 @@ def get_sentinel_record_list_db_page( vehicle_image, livestock_type, livestock_source, + license_plate_color, is_inspected, dept_id, dept_name, @@ -130,13 +131,14 @@ def get_sentinel_record_list_db_page( "license_plate": license_plate, "vehicle_type": vehicle_type, "license_plate_image": get_temp_url_dict( - "sentinel", "license_plate", license_plate_image + "sentinel", "vehicle_image_front", license_plate_image ), "vehicle_image": get_temp_url_dict( - "sentinel", "vehicle_image", vehicle_image + "sentinel", "vehicle_image_side", vehicle_image ), "livestock_type": livestock_type, "livestock_source": livestock_source, + "license_plate_color": license_plate_color, "is_inspected": 1 if is_inspected else 0, "dept_id": dept_id, "dept_name": dept_name, @@ -161,6 +163,7 @@ def get_sentinel_record_by_id(record_id): r.license_plate_image, r.vehicle_image, r.livestock_type, + r.license_plate_color, r.livestock_source, r.is_inspected, r.dept_id, @@ -188,6 +191,7 @@ def get_sentinel_record_by_id(record_id): license_plate_image, vehicle_image, livestock_type, + license_plate_color, livestock_source, is_inspected, dept_id, @@ -202,12 +206,13 @@ def get_sentinel_record_by_id(record_id): "license_plate": license_plate, "vehicle_type": vehicle_type, "license_plate_image": get_temp_url_dict( - "sentinel", "license_plate", license_plate_image + "sentinel", "vehicle_image_front", license_plate_image ), "vehicle_image": get_temp_url_dict( - "sentinel", "vehicle_image", vehicle_image + "sentinel", "vehicle_image_side", vehicle_image ), "livestock_type": livestock_type, + "license_plate_color": license_plate_color, "livestock_source": livestock_source, "is_inspected": 1 if is_inspected else 0, "dept_id": str(dept_id), @@ -333,15 +338,30 @@ def delete_sentinel_record_db(id: str) -> int: return cursor.rowcount -def saveSentinelRecord(data: SentinelRecordRequest) -> str: +def saveSentinelRecord( + id: str, + vehicle_type: str, + vehicle_image: str, + livestock_type: str, + remark: str, + dept_id: str, + license_plate: str, + license_plate_image: str, + license_plate_color: int = 0, +) -> str: sql = """ INSERT INTO sentinel_records ( + id, + vehicle_type, + vehicle_image, + livestock_type, + remark, + dept_id, license_plate, license_plate_image, - vehicle_type, - vehicle_image + license_plate_color ) - VALUES (%s, %s, %s, %s) + VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s) RETURNING id; """ @@ -350,36 +370,17 @@ def saveSentinelRecord(data: SentinelRecordRequest) -> str: cursor.execute( sql, ( - data.LicensePlate, - data.LicensePlateImage, - data.VehicleType, - data.VehicleImage, + id, + vehicle_type, + vehicle_image, + livestock_type, + remark, + dept_id, + license_plate, + license_plate_image, + license_plate_color, ), ) new_id = cursor.fetchone()[0] conn.commit() return str(new_id) - - -def update_sentinel_record( - id: str, livestock_type: str, remark: str, dept_id: str -) -> bool: - """ - 根据 id 更新 sentinel_records 表中的 livestock_type 和 dept_id - """ - sql = """ - UPDATE sentinel_records - SET livestock_type = %s, - remark = %s, - dept_id = %s, - updated_at = now() - WHERE id = %s - RETURNING id; - """ - - with pg_pool.getConn() as conn: - with conn.cursor() as cursor: - cursor.execute(sql, (livestock_type, remark, dept_id, id)) - record = cursor.fetchone() - conn.commit() - return record is not None diff --git a/bbit_ai/app/models/SentinelRecordFrontRequest.py b/bbit_ai/app/models/SentinelRecordFrontRequest.py new file mode 100644 index 0000000..2aec87a --- /dev/null +++ b/bbit_ai/app/models/SentinelRecordFrontRequest.py @@ -0,0 +1,7 @@ +from pydantic import BaseModel + + +class SentinelRecordFrontRequest(BaseModel): + Id: str | None = None + DeviceId: str + VehicleImage: str | None = None diff --git a/bbit_ai/app/models/SentinelRecordRequest.py b/bbit_ai/app/models/SentinelRecordRequest.py index 8cc941b..5186306 100644 --- a/bbit_ai/app/models/SentinelRecordRequest.py +++ b/bbit_ai/app/models/SentinelRecordRequest.py @@ -4,7 +4,6 @@ from pydantic import BaseModel class SentinelRecordRequest(BaseModel): Id: str | None = None DeviceId: str - LicensePlate: str | None = None - LicensePlateImage: str | None = None VehicleType: str | None = None - VehicleImage: str | None = None + vehicleImageFront: str | None = None + vehicleImageSide: str | None = None diff --git a/bbit_ai/app/models/SentinelRecordSideRequest.py b/bbit_ai/app/models/SentinelRecordSideRequest.py new file mode 100644 index 0000000..2ce65bb --- /dev/null +++ b/bbit_ai/app/models/SentinelRecordSideRequest.py @@ -0,0 +1,8 @@ +from pydantic import BaseModel + + +class SentinelRecordSideRequest(BaseModel): + Id: str | None = None + DeviceId: str + VehicleType: str | None = None + VehicleImage: str | None = None diff --git a/bbit_ai/app/result/e707cc77bd1a4342b63e30321556a184.jpg b/bbit_ai/app/result/e707cc77bd1a4342b63e30321556a184.jpg new file mode 100644 index 0000000..78a9e6b Binary files /dev/null and b/bbit_ai/app/result/e707cc77bd1a4342b63e30321556a184.jpg differ diff --git a/bbit_ai/app/routers/Public.py b/bbit_ai/app/routers/Public.py index d39cb36..3cb8dc6 100644 --- a/bbit_ai/app/routers/Public.py +++ b/bbit_ai/app/routers/Public.py @@ -1,14 +1,16 @@ +import asyncio import base64 from fastapi import APIRouter from config.app import F8_SERVER_USER_ID -from db.postgres.sentinel import saveSentinelRecord from models.BaseResponse import BaseResponse from models.F8ImageRequest import F8ImageRequest from models.F8ImageRequestV2 import F8ImageRequestV2 from models.SentinelRecordRequest import SentinelRecordRequest -from service.RabbitMQ import sentinel_new_analysis +from service.RabbitMQ import ( + mq_client, +) from service.vision import ( process_ticket_image, process_license_image, @@ -85,8 +87,6 @@ async def recognize_silkworm_cocoon(data: F8ImageRequest): @publicRouter.post("/sentinel-record-analytics") async def delete_sentinel_record(data: SentinelRecordRequest): - # 保存部分数据到数据库 - data.Id = saveSentinelRecord(data) - # 发送请求给RabbitMQ - res = await sentinel_new_analysis(data) - return BaseResponse(data=res) + # 发送全盘分析请求给RabbitMQ + asyncio.create_task(mq_client.send_all_analysis(data)) + return BaseResponse(data="submitted") diff --git a/bbit_ai/app/routers/Sentinel.py b/bbit_ai/app/routers/Sentinel.py index 3777ecb..7f9cee3 100644 --- a/bbit_ai/app/routers/Sentinel.py +++ b/bbit_ai/app/routers/Sentinel.py @@ -107,22 +107,51 @@ async def get_sentinel_monitor_promotional_list( ): if not user_id: return {"error": "userId is required"} + # 图片过期时间:7天 + expiration_time = 60 * 60 * 24 * 7 return BaseResponse( data=[ { "id": 1, "remark": "人员公示及岗位职责", - "url": get_temp_url("sentinel", "promotional/promotional (2).jpg"), + "url": get_temp_url( + "sentinel", "promotional/promotional (2).jpg", expiration_time + ), }, { "id": 2, "remark": "入川动物监督检查工作流程图", - "url": get_temp_url("sentinel", "promotional/promotional (1).jpg"), + "url": get_temp_url( + "sentinel", "promotional/promotional (1).jpg", expiration_time + ), }, { "id": 3, - "remark": "四川省人民政府关于设立人川动物运输指定通道的通告", - "url": get_temp_url("sentinel", "promotional/promotional (3).jpg"), + "remark": "四川省人民政府关于设立入川动物运输指定通道的通告", + "url": get_temp_url( + "sentinel", "promotional/promotional (3).jpg", expiration_time + ), + }, + { + "id": 4, + "remark": "四川省动物卫生监督检查站工作程序", + "url": get_temp_url( + "sentinel", "promotional/promotional (4).jpg", expiration_time + ), + }, + { + "id": 5, + "remark": "四川省动物卫生监督检查站无害化处理制度", + "url": get_temp_url( + "sentinel", "promotional/promotional (5).jpg", expiration_time + ), + }, + { + "id": 6, + "remark": "四川省动物卫生监督检查站工作人员行为规范", + "url": get_temp_url( + "sentinel", "promotional/promotional (6).jpg", expiration_time + ), }, ] ) @@ -153,40 +182,48 @@ async def get_sentinel_monitor_list( ) url = "https://open.ys7.com/api/lapp/v2/live/address/get" - # device_serials = ["BG2493625"] - device_serials = ["BG2493625", "GH3713250", "GH3714496", "GH3714497"] + + device_serials = { + # "GG9175589": [1, 2, 3, 4], + "GG9175589": [1, 3], + "GG9175555": [1, 2], + } video_expire_time = 25 * 24 * 60 * 60 # 25 天 res = [] - for device_serial in device_serials: - live_key = f"ys7:live:{device_serial}" - cached_live = redis_client.get_value(live_key) + for device_serial, channels in device_serials.items(): + for channelNo in channels: + live_key = f"ys7:live:{device_serial}:{channelNo}" + cached_live = redis_client.get_value(live_key) + + if cached_live: + video_id = cached_live.get("id") + video_url = cached_live.get("url") + else: + payload = { + "accessToken": access_token, + "deviceSerial": device_serial, + "channelNo": channelNo, + "protocol": 4, + "expireTime": video_expire_time, + "supportH265": 0, + "quality": 2, + } + result = await http_client.post(url, data=payload) + + video_data = result.get("data") + if not video_data: + continue # 或者记录异常日志 + + video_id = video_data["id"] + video_url = video_data["url"] + + redis_client.set_value( + live_key, + {"id": video_id, "url": video_url}, + expire=video_expire_time, + ) + + res.append({"id": video_id, "url": video_url}) - if cached_live: - video_id = cached_live.get("id") - video_url = cached_live.get("url") - else: - payload = { - "accessToken": access_token, - "deviceSerial": device_serial, - "protocol": 4, # 流播放协议,1-ezopen、2-hls、3-rtmp、4-flv,默认为1 - "expireTime": video_expire_time, # 25天 - "supportH265": 0, - "quality": 2, - } - result = await http_client.post(url, data=payload) - video_id = result["data"]["id"] - video_url = result["data"]["url"] - # 存到 Redis,自动序列化为 JSON,过期 25天 - redis_client.set_value( - live_key, - {"id": video_id, "url": video_url}, - expire=video_expire_time, - ) - res.append( - { - "id": video_id, - "url": video_url, - } - ) return BaseResponse(data=res) diff --git a/bbit_ai/app/service/RabbitMQ.py b/bbit_ai/app/service/RabbitMQ.py index 826f956..b1a69f8 100644 --- a/bbit_ai/app/service/RabbitMQ.py +++ b/bbit_ai/app/service/RabbitMQ.py @@ -1,97 +1,103 @@ # consumer.py + import asyncio import json +import traceback import aio_pika from config.rabbitMQ import * -from models.AnalysisRequest import AnalysisRequest from models.SentinelRecordRequest import SentinelRecordRequest -from service.vision import process_vehicle_animal_image +from service.vision import ( + process_all_vehicle_animal_image, +) -async def mq_new_analysis_test(req: dict): - """将分析请求发送到 RabbitMQ 队列(异步版)""" - connection = await aio_pika.connect_robust( - f"amqp://{RABBIT_USER}:{RABBIT_PASSWORD}@{RABBIT_HOST}/{RABBIT_VHOST}" - ) +class MQClient: + """RabbitMQ 单例客户端,支持生产和消费""" - async with connection: - channel = await connection.channel() - # 声明队列,确保队列存在 - queue = await channel.declare_queue(QUEUE_NAME, durable=True) + _instance = None - message_body = json.dumps(req) + def __new__(cls): + if cls._instance is None: + cls._instance = super().__new__(cls) + return cls._instance + + def __init__(self): + self._connection = None + self._channel = None + self._consumer_tasks = [] + + # ---------------- 连接初始化 ---------------- + async def init(self, prefetch_count: int = 10): + """启动时初始化连接和通道""" + if self._connection is None: + self._connection = await aio_pika.connect_robust( + f"amqp://{RABBIT_USER}:{RABBIT_PASSWORD}@{RABBIT_HOST}/{SENTINEL_VHOST}" + ) + self._channel = await self._connection.channel() + await self._channel.set_qos(prefetch_count=prefetch_count) + + # ---------------- 发布消息 ---------------- + async def publish(self, queue_name: str, message_body: str): + """向指定队列发送消息""" + if self._channel is None: + raise RuntimeError("MQClient 未初始化") + # 队列幂等声明 + queue = await self._channel.declare_queue(queue_name, durable=True) message = aio_pika.Message( - body=message_body.encode(), - delivery_mode=aio_pika.DeliveryMode.PERSISTENT, # 持久化 + body=message_body.encode(), delivery_mode=aio_pika.DeliveryMode.PERSISTENT + ) + await self._channel.default_exchange.publish(message, routing_key=queue_name) + + async def send_all_analysis(self, req: SentinelRecordRequest): + await self.publish( + SENTINEL_ANALYSIS_ALL_QUEUE_NAME, json.dumps(req.model_dump()) ) - await channel.default_exchange.publish(message, routing_key=QUEUE_NAME) + # ---------------- 消费消息 ---------------- + async def consume_queue(self, queue_name: str, process_func): + """ + 持续消费队列 + process_func: async function 接收 dict 或 Request 对象 + """ + if self._channel is None: + raise RuntimeError("MQClient 未初始化") - -async def mq_pull_analysis_async_test(): - """ - 从队列拉取分析任务并处理 - process_func: 一个函数,接收 AnalysisRequest 对象处理分析逻辑 - """ - connection = await aio_pika.connect_robust( - f"amqp://{RABBIT_USER}:{RABBIT_PASSWORD}@{RABBIT_HOST}/{RABBIT_VHOST}" - ) - async with connection: - queue_name = QUEUE_NAME - channel = await connection.channel() - await channel.set_qos(prefetch_count=1) - queue = await channel.declare_queue(queue_name, durable=True) + queue = await self._channel.declare_queue(queue_name, durable=True) async with queue.iterator() as queue_iter: async for message in queue_iter: async with message.process(): - data = json.loads(message.body) - req = AnalysisRequest(**data) - print(f"收到任务: {req}") - await asyncio.sleep(5) # 模拟处理 - print(f"完成任务: {req}") + try: + body = message.body.decode() + data = json.loads(body) + await process_func(data) + except Exception as e: + print(f"[MQ Consume Error] {e}") + traceback.print_exc() + # ---------------- 启动全局分析消费者 ---------------- + async def start_all_consumer(self): + async def _process(data: dict): + req = SentinelRecordRequest(**data) + await process_all_vehicle_animal_image(req) + print(f"完成全局分析任务: {req}") -async def sentinel_new_analysis(req: SentinelRecordRequest): - """将分析请求发送到 RabbitMQ 队列(异步版)""" - connection = await aio_pika.connect_robust( - f"amqp://{RABBIT_USER}:{RABBIT_PASSWORD}@{RABBIT_HOST}/{SENTINEL_VHOST}" - ) - - async with connection: - channel = await connection.channel() - # 声明队列,确保队列存在 - queue = await channel.declare_queue(QUEUE_NAME, durable=True) - - message_body = json.dumps(req.model_dump()) - message = aio_pika.Message( - body=message_body.encode(), - delivery_mode=aio_pika.DeliveryMode.PERSISTENT, # 持久化 + task = asyncio.create_task( + self.consume_queue(SENTINEL_ANALYSIS_ALL_QUEUE_NAME, _process) ) + self._consumer_tasks.append(task) - await channel.default_exchange.publish(message, routing_key=QUEUE_NAME) + # ---------------- 关闭连接 ---------------- + async def close(self): + for task in self._consumer_tasks: + task.cancel() + if self._channel: + await self._channel.close() + if self._connection: + await self._connection.close() -async def sentinel_pull_analysis_async(): - """ - 从队列拉取分析任务并处理 - process_func: 一个函数,接收 AnalysisRequest 对象处理分析逻辑 - """ - connection = await aio_pika.connect_robust( - f"amqp://{RABBIT_USER}:{RABBIT_PASSWORD}@{RABBIT_HOST}/{SENTINEL_VHOST}" - ) - async with connection: - queue_name = QUEUE_NAME - channel = await connection.channel() - await channel.set_qos(prefetch_count=1) - queue = await channel.declare_queue(queue_name, durable=True) - - async with queue.iterator() as queue_iter: - async for message in queue_iter: - async with message.process(): - data = json.loads(message.body) - req = SentinelRecordRequest(**data) - await process_vehicle_animal_image(req) # 处理 - print(f"完成任务: {req}") +# ---------------- 全局单例 ---------------- +mq_client = MQClient() diff --git a/bbit_ai/app/service/vision.py b/bbit_ai/app/service/vision.py index 2f9f08c..6e03ba3 100644 --- a/bbit_ai/app/service/vision.py +++ b/bbit_ai/app/service/vision.py @@ -1,21 +1,27 @@ import uuid +from pathlib import Path from uuid import UUID import config.minIO as minIO import db.postgres as pg from agent.licenseImageAgent import get_license_response from agent.vehicleImageAgent import get_vehicle_response +from ai.plate.my_plate import recognizer from config.minIO import minio_client, get_temp_url from config.yolo import YOLOSingleton from db.postgres import ( get_dept_id_by_iot_user_name, get_dept_ids_by_dept_id, ) -from db.postgres.sentinel import update_sentinel_record, get_sentinel_record_by_id +from db.postgres.sentinel import ( + get_sentinel_record_by_id, + saveSentinelRecord, +) from llm.ticketLLM import * from llm.ticketLLMv2 import get_ticket_response_v2 from models.SentinelRecordRequest import SentinelRecordRequest from routers.WS import ws_manager +from utils import validate_plate def process_ticket_image( @@ -189,35 +195,97 @@ def process_silkworm_cocoon_image( } -async def process_vehicle_animal_image( +# 处理车牌照片 +async def process_all_vehicle_animal_image( data: SentinelRecordRequest, ): - # 通过设备id获得组织id - dept_id = get_dept_id_by_iot_user_name(data.DeviceId) + oss_url = minIO.get_temp_url( + "sentinel", "vehicle_image_side/" + data.vehicleImageSide + ) - # 得到动物类型 - oss_url = minIO.get_temp_url("sentinel", "vehicle_image/" + data.VehicleImage) + # LLM得到车身信息 analysis_result = await get_vehicle_response(oss_url) livestock_type = analysis_result.get("livestock_type", "") remark = analysis_result.get("remark", "") - # 保存到数据库 - update_sentinel_record(data.Id, livestock_type, remark, dept_id) - # 可以通知的部门ids - available_departments = get_dept_ids_by_dept_id(dept_id) + have_animal = analysis_result.get("have_animal", False) - # 通知控制界面 - await ws_manager.noticeSentinel( - { - "content": f"载有{livestock_type}的车辆即将进入关卡,请准备检查", - "type": "vehicle_alert", - }, - available_departments, - ) - # 通知大屏界面 - await ws_manager.noticeSentinelMonitorStatus( - { - "content": get_sentinel_record_by_id(data.Id), - "type": "vehicle_alert", - }, - available_departments, - ) + if not have_animal: + minIO.delete_file("sentinel", "vehicle_image_side/" + data.vehicleImageSide) + minIO.delete_file("sentinel", "vehicle_image_front/" + data.vehicleImageFront) + else: + # 通过设备id获得组织id + dept_id = get_dept_id_by_iot_user_name(data.DeviceId) + + # 处理汽车正面照-------------------------- + + # 从OSS下载 + oss_url = minIO.get_temp_url( + "sentinel", "vehicle_image_front/" + data.vehicleImageFront + ) + + # 获取系统临时目录(自动兼容 Windows / Linux) + tmp_dir = Path(tempfile.gettempdir()) + tmp_dir.mkdir(parents=True, exist_ok=True) + + tmp_path = tmp_dir / data.vehicleImageFront + + # 下载图片到 tmp_path + response = requests.get(oss_url, stream=True) + if response.status_code == 200: + with open(tmp_path, "wb") as f: + for chunk in response.iter_content(chunk_size=8192): + f.write(chunk) + else: + raise Exception(f"下载失败: {oss_url}, status_code={response.status_code}") + + # 调用识别 + results = recognizer.analyze_image(str(tmp_path)) + # 车牌号一次识别 + license_plate = results[0].plate_no if results else "" + # 车牌号二次校准 + license_plate = validate_plate(license_plate) + license_plate_color_str = results[0].plate_color if results else "" + color_map = { + "蓝色": 0, + "黄色": 1, + "绿色": 2, + "黑色": 3, + "白色": 4, + } + + license_plate_color = color_map.get(license_plate_color_str, 0) + license_plate_image = data.vehicleImageFront + # 保存到数据库 + saveSentinelRecord( + data.Id, + data.VehicleType, + data.vehicleImageSide, + livestock_type, + remark, + dept_id, + license_plate, + license_plate_image, + license_plate_color, + ) + # 识别完成后删除临时文件 + os.remove(tmp_path) + + # 可以通知的部门ids + available_departments = get_dept_ids_by_dept_id(dept_id) + # 通知控制界面 + await ws_manager.noticeSentinel( + { + "content": f"车辆即将进入关卡,请准备检查", + "license_plate": license_plate, + "type": "vehicle_alert", + }, + available_departments, + ) + # 通知大屏界面 + await ws_manager.noticeSentinelMonitorStatus( + { + "content": get_sentinel_record_by_id(data.Id), + "type": "vehicle_alert", + }, + available_departments, + ) diff --git a/bbit_ai/app/utils/MyUtils.py b/bbit_ai/app/utils/MyUtils.py index 2f62940..acb009d 100644 --- a/bbit_ai/app/utils/MyUtils.py +++ b/bbit_ai/app/utils/MyUtils.py @@ -100,3 +100,57 @@ def get_memory_total(): def get_disk_total(): return psutil.disk_usage("/").total + + +def translate_vehicle_type(vehicle_type: str) -> str: + mapping = { + "coupe": "双门跑车", + "largevehicle": "大型车辆", + "sedan": "三厢轿车", + "suv": "运动型多用途车", + "truck": "卡车", + "van": "面包车", + } + + return mapping.get(vehicle_type.lower(), vehicle_type) + + +import re + + +def validate_plate(plate: str) -> str: + if not plate: + return "车速过快,无法识别" + + plate = plate.strip().upper() + + # 普通车牌(燃油) + normal_pattern = r"^[京津沪渝冀豫云辽黑湘皖鲁新苏浙赣鄂桂甘晋蒙陕吉闽贵粤青藏川宁琼][A-Z][A-Z0-9]{5}$" + + # 新能源车 + new_energy_pattern = r"^[京津沪渝冀豫云辽黑湘皖鲁新苏浙赣鄂桂甘晋蒙陕吉闽贵粤青藏川宁琼][A-Z][A-Z0-9]{6}$" + + # 武警车 + wj_pattern = ( + r"^WJ[京津沪渝冀豫云辽黑湘皖鲁新苏浙赣鄂桂甘晋蒙陕吉闽贵粤青藏川宁琼]?\d{5}$" + ) + + # 外交车 + diplomatic_pattern = r"^(使|领)\d{5}$" + + # 港澳车辆 + hk_macau_pattern = r"^粤Z[A-Z0-9]{4}(港|澳)$" + + patterns = [ + normal_pattern, + new_energy_pattern, + wj_pattern, + diplomatic_pattern, + hk_macau_pattern, + ] + + for p in patterns: + if re.match(p, plate): + return plate + + return "车速过快,无法识别"