221 lines
6.2 KiB
Python
221 lines
6.2 KiB
Python
import logging
|
||
|
||
import cv2
|
||
import numpy as np
|
||
import torch
|
||
from ultralytics import YOLO
|
||
import os
|
||
|
||
from PIL import Image, ImageDraw, ImageFont
|
||
from PIL import Image, ImageDraw, ImageFont
|
||
import numpy as np
|
||
import cv2
|
||
import io
|
||
|
||
# Run from the project root so relative paths (e.g. the YOLO weight file
# loaded by YOLOSingleton.init_model) resolve. Set the YOLO_WORKDIR
# environment variable to point at a different checkout; the default is the
# original hard-coded location.
os.chdir(os.environ.get("YOLO_WORKDIR", "/home/bbit/mine/yolo/"))
|
||
from PIL import Image, ImageDraw, ImageFont
|
||
import cv2
|
||
import numpy as np
|
||
|
||
def draw_annotations(img_bgr, boxes, labels, confidences=None,
                     font_path="/usr/share/fonts/truetype/wqy/wqy-zenhei.ttc"):
    """Draw coloured bounding boxes with numbered captions onto a BGR image.

    Captions are rendered with PIL using ``font_path`` so that Chinese labels
    can be displayed. Each caption reads ``"<label>#<n>"`` (n counts from 1),
    followed by the confidence with two decimals when one is available.

    :param img_bgr: np.ndarray, BGR image (H x W x 3)
    :param boxes: sequence of boxes, each ``(x1, y1, x2, y2)`` in pixels
    :param labels: sequence of str, one per box; the outline colour is looked
        up by label and unknown labels fall back to red
    :param confidences: optional sequence of float aligned with ``boxes``;
        boxes without a matching (non-None) entry are captioned without a score
    :param font_path: path to a TrueType/TTC font covering the label glyphs
    :return: np.ndarray, a new annotated BGR image (``img_bgr`` is not modified)
    """
    width = img_bgr.shape[1]
    # Scale stroke width and font size with the image width so annotations
    # stay legible on large images.
    line_width = max(2, int(width / 400))
    font_size = max(20, int(width / 40))

    pil_img = Image.fromarray(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB))
    draw = ImageDraw.Draw(pil_img)
    font = ImageFont.truetype(font_path, font_size)

    # Per-class outline colours (keys are the Chinese display labels).
    color_map = {
        "正茧": "green",
        "双宫茧": "grey",
        "黄斑茧": "red",
        "毛茧": "white",
        "蛆壳茧": "purple",
    }
    pad = 2  # gap between the caption text and its black background edge

    # `is not None` instead of truthiness so numpy arrays are accepted too.
    n_conf = len(confidences) if confidences is not None else 0

    for idx, (b, label) in enumerate(zip(boxes, labels), start=1):
        conf = confidences[idx - 1] if idx <= n_conf else None
        display_text = f"{label}#{idx}"
        if conf is not None:
            display_text += f" {conf:.2f}"

        x1, y1, x2, y2 = map(int, b)
        draw.rectangle([x1, y1, x2, y2], outline=color_map.get(label, "red"),
                       width=line_width)

        # textbbox is measured from origin (0, 0); its left/top values are the
        # glyph offset that draw.text applies, so they are subtracted below to
        # keep the text inside its background.
        left, top, right, bottom = draw.textbbox((0, 0), display_text, font=font)
        bg_w = (right - left) + 2 * pad
        bg_h = (bottom - top) + 2 * pad

        # Caption goes above the box; if that would leave the image (box at
        # the top edge), place it just inside the box instead.
        bg_y = y1 - bg_h if y1 >= bg_h else max(0, y1)
        # Keep the caption within the right edge of the image as well.
        bg_x = max(0, min(x1, width - bg_w))

        draw.rectangle([bg_x, bg_y, bg_x + bg_w, bg_y + bg_h], fill="black")
        draw.text((bg_x + pad - left, bg_y + pad - top), display_text,
                  fill="white", font=font)

    return cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)
|
||
|
||
|
||
|
||
|
||
logger = logging.getLogger("yolo_service")


class YOLOSingleton:
    """Class-level holder for the single YOLO cocoon-detection model.

    Call :meth:`init_model` to load the weights; :meth:`detect` refuses to
    run until that has succeeded.
    """

    _model = None   # loaded YOLO model, or None when not (successfully) loaded
    _ready = False  # True only after init_model() succeeded

    # Bound on the class so the service keeps logging to "yolo_service" even
    # if the module-level ``logger`` name is rebound later in the file.
    _log = logging.getLogger("yolo_service")

    # Model class names (English) -> Chinese display labels.
    _LABEL_MAP = {
        "normal": "正茧",
        "double_pupa": "双宫茧",
        "spot": "黄斑茧",
        "hairy": "毛茧",
        "maggot_shell": "蛆壳茧",
    }

    @classmethod
    def instance(cls):
        """Return the loaded model, or None (logging a warning) if not loaded."""
        if cls._model is None:
            cls._log.warning("模型尚未初始化")
        return cls._model

    @classmethod
    def is_ready(cls) -> bool:
        """Return True once init_model() has loaded the model successfully."""
        return cls._ready

    @classmethod
    def init_model(cls, weights: str = "runs/detect/train13/weights/best.pt"):
        """Load the YOLO weights, moving the model to CUDA if available, else CPU.

        Never raises: on failure the error is logged with its traceback and
        the singleton is left unloaded (``is_ready()`` returns False).

        :param weights: path to the weight file; a relative path resolves
            against the current working directory.
        """
        try:
            model = YOLO(weights)
            model.to("cuda" if torch.cuda.is_available() else "cpu")
        except Exception as e:
            # Deliberate best-effort load: callers check is_ready() instead.
            cls._model = None
            cls._ready = False
            cls._log.exception("❌ YOLO 模型加载失败: %s", e)
            return
        cls._model = model
        cls._ready = True
        cls._log.info("✅ YOLO 模型加载完成")

    @classmethod
    def detect(cls, img_bytes: bytes, conf: float = 0.45):
        """Detect cocoons in an encoded image and annotate the result.

        :param img_bytes: encoded image bytes, decoded with ``cv2.imdecode``.
        :param conf: confidence threshold forwarded to the model.
        :return: ``(annotated_jpeg_bytes, result_json)``. ``result_json`` holds
            ``total_objects``, per-label ``class_counts`` (Chinese labels),
            ``min/max/avg_confidence`` rounded to 4 places (0.0 when nothing
            is detected), and ``speed_ms`` copied from the result's ``speed``.
        :raises RuntimeError: if the model is not loaded or JPEG encoding fails.
        :raises ValueError: if ``img_bytes`` is empty or cannot be decoded.
        """
        if not cls._ready or cls._model is None:
            raise RuntimeError("模型未加载或不可用")
        # cv2.imdecode asserts on an empty buffer, so reject it up front.
        if not img_bytes:
            raise ValueError("无法解码输入图片")

        img = cv2.imdecode(np.frombuffer(img_bytes, np.uint8), cv2.IMREAD_COLOR)
        # imdecode reports undecodable data by returning None, not by raising.
        if img is None:
            raise ValueError("无法解码输入图片")

        r = cls._model(img, conf=conf)[0]
        boxes = r.boxes
        names = cls._model.names

        class_counts = {}
        confidences = []
        box_list = []
        label_list = []
        for b in boxes:
            cls_id = int(b.cls)
            label_en = names.get(cls_id, str(cls_id))
            label_cn = cls._LABEL_MAP.get(label_en, label_en)
            class_counts[label_cn] = class_counts.get(label_cn, 0) + 1
            confidences.append(float(b.conf))
            # Plain Python floats, so no (possibly GPU) tensor is handed to
            # the drawing code.
            box_list.append(b.xyxy[0].tolist())
            label_list.append(label_cn)

        # Chinese labels are drawn via PIL inside draw_annotations.
        annotated = draw_annotations(img, box_list, label_list, confidences)

        ok, buffer = cv2.imencode(".jpg", annotated)
        if not ok:
            raise RuntimeError("结果图片编码失败")

        result_json = {
            "total_objects": len(boxes),
            "class_counts": class_counts,
            "min_confidence": round(min(confidences), 4) if confidences else 0.0,
            "max_confidence": round(max(confidences), 4) if confidences else 0.0,
            "avg_confidence": (
                round(sum(confidences) / len(confidences), 4) if confidences else 0.0
            ),
            "speed_ms": r.speed,  # taken as-is from the YOLO result object
        }

        return buffer.tobytes(), result_json
|
||
|
||
|
||
|
||
import logging
from pathlib import Path

# Logging configuration for the command-line smoke test below. Note that this
# rebinds the module-level name ``logger`` to the "yolo_test" logger.
_LOG_FORMAT = "%(asctime)s [%(levelname)s] %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger("yolo_test")
import os
|
||
|
||
def save_image(img_bytes: bytes, save_path: str) -> None:
    """Write encoded image bytes (e.g. the bytes returned by detect()) to disk.

    Missing parent directories are created. A bare file name with no
    directory part is written relative to the current working directory.

    :param img_bytes: encoded image data, written verbatim.
    :param save_path: destination path, e.g. '/home/bbit/mine/yolo/output/result.jpg'.
    """
    # No @staticmethod here: this is a module-level function, and a
    # staticmethod object is not callable on Python < 3.10.
    parent = os.path.dirname(save_path)
    # os.makedirs("") raises FileNotFoundError, so only create a directory
    # when the path actually contains one.
    if parent:
        os.makedirs(parent, exist_ok=True)

    with open(save_path, "wb") as f:
        f.write(img_bytes)
|
||
|
||
def main():
    """Smoke test: load the model, run detection on one sample image, save it.

    Prints the detection statistics as JSON and writes the annotated image to
    '/home/bbit/mine/yolo/output/result.jpg'. Failures are logged rather than
    raised.
    """
    import json

    YOLOSingleton.init_model()
    if not YOLOSingleton.is_ready():
        logger.error("模型未初始化成功,退出")
        return

    # Sample image used for the smoke test.
    img_path = Path("/home/bbit/mine/yolo/valid/raw_1.jpg")
    if not img_path.exists():
        logger.error(f"图片不存在: {img_path}")
        return

    img_bytes = img_path.read_bytes()

    try:
        img_bytes_out, result = YOLOSingleton.detect(img_bytes)
        logger.info("检测结果:")
        # ensure_ascii=False keeps the Chinese class names readable instead of
        # printing them as \uXXXX escapes.
        print(json.dumps(result, indent=4, ensure_ascii=False))
        save_image(img_bytes_out, '/home/bbit/mine/yolo/output/result.jpg')
    except Exception as e:
        # Top-level boundary: keep the traceback in the log instead of only
        # the exception message.
        logger.exception("检测失败: %s", e)


if __name__ == "__main__":
    main()
|