后端新增《蚕茧识别V2》模块
This commit is contained in:
@@ -0,0 +1,156 @@
|
||||
import logging
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image, ImageDraw, ImageFont
|
||||
from ultralytics import YOLO
|
||||
|
||||
from utils import MyUtils
|
||||
|
||||
|
||||
def draw_annotations(
|
||||
img_bgr,
|
||||
boxes,
|
||||
labels,
|
||||
confidences=None,
|
||||
font_path="/usr/share/fonts/truetype/wqy/wqy-zenhei.ttc",
|
||||
):
|
||||
"""
|
||||
绘制带中文、置信度、不同颜色的标注框
|
||||
:param img_bgr: np.ndarray, BGR 图像
|
||||
:param boxes: list of xyxy
|
||||
:param labels: list of str
|
||||
:param confidences: list of float, 可选,用于显示置信度
|
||||
:param font_path: 字体路径
|
||||
:return: np.ndarray, BGR 标注图
|
||||
"""
|
||||
h, w, _ = img_bgr.shape
|
||||
line_width = max(2, int(w / 400))
|
||||
font_size = max(20, int(w / 40))
|
||||
|
||||
img = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
|
||||
pil_img = Image.fromarray(img)
|
||||
draw = ImageDraw.Draw(pil_img)
|
||||
font = ImageFont.truetype(font_path, font_size)
|
||||
|
||||
# 定义类别颜色映射
|
||||
color_map = {
|
||||
"正茧": "green",
|
||||
"双宫茧": "grey",
|
||||
"黄斑茧": "red",
|
||||
"毛茧": "white",
|
||||
"蛆壳茧": "purple",
|
||||
}
|
||||
|
||||
for idx, (b, label) in enumerate(zip(boxes, labels), start=1):
|
||||
conf = confidences[idx - 1] if confidences else None
|
||||
display_text = f"{label}#{idx}"
|
||||
if conf is not None:
|
||||
display_text += f" {conf:.2f}"
|
||||
|
||||
x1, y1, x2, y2 = map(int, b)
|
||||
box_color = color_map.get(label, "red")
|
||||
|
||||
# 画框
|
||||
draw.rectangle([x1, y1, x2, y2], outline=box_color, width=line_width)
|
||||
|
||||
# 文本边界
|
||||
bbox = draw.textbbox((0, 0), display_text, font=font)
|
||||
text_width = bbox[2] - bbox[0]
|
||||
text_height = bbox[3] - bbox[1]
|
||||
|
||||
# 文本背景
|
||||
draw.rectangle(
|
||||
[x1, y1 - text_height - 4, x1 + text_width + 4, y1], fill="black"
|
||||
)
|
||||
draw.text((x1 + 2, y1 - text_height - 2), display_text, fill="white", font=font)
|
||||
|
||||
annotated = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)
|
||||
return annotated
|
||||
|
||||
|
||||
logger = logging.getLogger("yolo_service")
|
||||
|
||||
|
||||
class YOLOSingleton:
|
||||
_model = None
|
||||
_ready = False
|
||||
|
||||
@classmethod
|
||||
def instance(cls):
|
||||
"""获取模型单例"""
|
||||
if cls._model is None:
|
||||
logger.warning("模型尚未初始化")
|
||||
return cls._model
|
||||
|
||||
@classmethod
|
||||
def is_ready(cls):
|
||||
return cls._ready
|
||||
|
||||
@classmethod
|
||||
def init_model(cls):
|
||||
"""初始化模型,失败时不抛出异常"""
|
||||
try:
|
||||
cls._model = YOLO("/app/models/yolo/yolo_silkworm_cocoon_detect_v1.pt")
|
||||
cls._model.to("cuda" if torch.cuda.is_available() else "cpu")
|
||||
cls._ready = True
|
||||
logger.info("✅ YOLO 模型加载完成")
|
||||
except Exception as e:
|
||||
cls._model = None
|
||||
cls._ready = False
|
||||
logger.error(f"❌ YOLO 模型加载失败: {e}")
|
||||
|
||||
@classmethod
|
||||
def detect(cls, img_bytes: bytes):
|
||||
if not cls._ready or cls._model is None:
|
||||
raise RuntimeError("模型未加载或不可用")
|
||||
|
||||
label_map = {
|
||||
"normal": "正茧",
|
||||
"double_pupa": "双宫茧",
|
||||
"spot": "黄斑茧",
|
||||
"hairy": "毛茧",
|
||||
"maggot_shell": "蛆壳茧",
|
||||
}
|
||||
|
||||
img = cv2.imdecode(np.frombuffer(img_bytes, np.uint8), cv2.IMREAD_COLOR)
|
||||
results = cls._model(img, conf=0.45)
|
||||
r = results[0]
|
||||
boxes = r.boxes
|
||||
names = cls._model.names
|
||||
|
||||
total = len(boxes)
|
||||
class_counts = {}
|
||||
confidences = []
|
||||
box_list = []
|
||||
label_list = []
|
||||
|
||||
for idx, b in enumerate(boxes):
|
||||
cls_id = int(b.cls)
|
||||
conf = float(b.conf)
|
||||
label_en = names.get(cls_id, str(cls_id))
|
||||
label_cn = label_map.get(label_en, label_en)
|
||||
class_counts[label_cn] = class_counts.get(label_cn, 0) + 1
|
||||
confidences.append(conf)
|
||||
box_list.append(b.xyxy[0])
|
||||
label_list.append(label_cn)
|
||||
|
||||
# 用 PIL 绘制中文
|
||||
annotated = draw_annotations(img, box_list, label_list, confidences)
|
||||
|
||||
_, buffer = cv2.imencode(".jpg", annotated)
|
||||
img_bytes_out = buffer.tobytes()
|
||||
|
||||
result_json = {
|
||||
"total_objects": total,
|
||||
"class_counts": class_counts,
|
||||
"min_confidence": MyUtils.safe_round(min(confidences), 4),
|
||||
"max_confidence": MyUtils.safe_round(max(confidences), 4),
|
||||
"avg_confidence": (
|
||||
MyUtils.safe_round(sum(confidences) / len(confidences), 4)
|
||||
),
|
||||
"speed_ms": r.speed, # 直接来自 YOLO
|
||||
}
|
||||
|
||||
return img_bytes_out, result_json
|
||||
Reference in New Issue
Block a user