后端更新

This commit is contained in:
BBIT-Kai
2026-03-26 17:48:20 +08:00
parent 4c2bcd7dce
commit 0c2859b0db
22 changed files with 1336 additions and 213 deletions
View File
+433
View File
@@ -0,0 +1,433 @@
import os
from dataclasses import dataclass
import cv2
import numpy as np
from PIL import Image, ImageDraw, ImageFont
from ultralytics import YOLO
from ai.plate.plate_recognition.double_plate_split_merge import get_split_merge
from ai.plate.plate_recognition.plate_rec import get_plate_result
@dataclass
class PlateResult:
plate: str
color: str
def collect_files(root_path):
file_list = []
for root, _, files in os.walk(root_path):
for name in files:
file_list.append(os.path.join(root, name))
return sorted(file_list)
def four_point_transform(image, pts):
rect = pts.astype(np.float32)
(tl, tr, br, bl) = rect
width_a = np.sqrt(((br[0] - bl[0]) ** 2) + ((br[1] - bl[1]) ** 2))
width_b = np.sqrt(((tr[0] - tl[0]) ** 2) + ((tr[1] - tl[1]) ** 2))
max_width = max(int(width_a), int(width_b))
height_a = np.sqrt(((tr[0] - br[0]) ** 2) + ((tr[1] - br[1]) ** 2))
height_b = np.sqrt(((tl[0] - bl[0]) ** 2) + ((tl[1] - bl[1]) ** 2))
max_height = max(int(height_a), int(height_b))
dst = np.array(
[
[0, 0],
[max_width - 1, 0],
[max_width - 1, max_height - 1],
[0, max_height - 1],
],
dtype=np.float32,
)
matrix = cv2.getPerspectiveTransform(rect, dst)
return cv2.warpPerspective(image, matrix, (max_width, max_height))
def load_model(weights, device):
model = YOLO(weights)
model.to(device)
return model
def det_rec_plate(img_ori, detect_model, plate_rec_model, device, conf=0.3, iou=0.5):
result_list = []
results = detect_model(img_ori, conf=conf, iou=iou, verbose=False)
for result in results:
boxes = result.boxes
keypoints = result.keypoints
if len(boxes) == 0 or keypoints is None:
continue
kpts_xy = keypoints.xy
num_det = min(len(boxes), len(kpts_xy))
for idx in range(num_det):
box = boxes.xyxy[idx].cpu().numpy()
det_conf = float(boxes.conf[idx])
plate_type = int(boxes.cls[idx])
landmarks = kpts_xy[idx].cpu().numpy().astype(np.int64)
roi_img = four_point_transform(img_ori, landmarks)
if plate_type == 1:
roi_img = get_split_merge(roi_img)
plate_number, _, plate_color, color_conf = get_plate_result(
roi_img, device, plate_rec_model, is_color=True
)
result_list.append(
{
"plate_no": plate_number,
"plate_color": plate_color,
"rect": [int(v) for v in box],
"detect_conf": det_conf,
"landmarks": landmarks.tolist(),
"roi_height": roi_img.shape[0],
"color_conf": color_conf,
"plate_type": plate_type, # 0: 单层, 1: 双层
}
)
return result_list
def _clamp(value, low, high):
return max(low, min(high, value))
def _normalize_rect(rect):
if not rect or len(rect) < 4:
return None
x1 = int(round(float(rect[0])))
y1 = int(round(float(rect[1])))
x2 = int(round(float(rect[2])))
y2 = int(round(float(rect[3])))
left = min(x1, x2)
top = min(y1, y2)
right = max(x1, x2)
bottom = max(y1, y2)
if right <= left or bottom <= top:
return None
return [left, top, right, bottom]
def _get_plate_theme(plate_color):
theme_map = {
"蓝色": {
"bg": (72, 33, 6),
"border": (250, 165, 96),
"text": (254, 242, 224),
"glow": (246, 130, 59),
},
"黄色": {
"bg": (0, 49, 74),
"border": (21, 204, 250),
"text": (195, 249, 254),
"glow": (8, 179, 234),
},
"绿色": {
"bg": (27, 53, 4),
"border": (128, 222, 74),
"text": (231, 252, 220),
"glow": (94, 197, 34),
},
"白色": {
"bg": (68, 50, 38),
"border": (240, 232, 226),
"text": (250, 250, 248),
"glow": (184, 163, 148),
},
"黑色": {
"bg": (20, 12, 8),
"border": (184, 163, 148),
"text": (240, 232, 226),
"glow": (139, 116, 100),
},
}
return theme_map.get(
plate_color,
{
"bg": (43, 24, 8),
"border": (248, 189, 56),
"text": (255, 248, 232),
"glow": (200, 140, 32),
},
)
def _draw_alpha_rect(img, x1, y1, x2, y2, color, alpha=0.75):
h, w = img.shape[:2]
x1 = _clamp(x1, 0, w - 1)
y1 = _clamp(y1, 0, h - 1)
x2 = _clamp(x2, 0, w)
y2 = _clamp(y2, 0, h)
if x2 <= x1 or y2 <= y1:
return
roi = img[y1:y2, x1:x2]
overlay = np.full_like(roi, color, dtype=np.uint8)
cv2.addWeighted(overlay, alpha, roi, 1 - alpha, 0, roi)
def _measure_text(text, text_size=16):
try:
font = ImageFont.truetype(get_font_file_path(), text_size, encoding="utf-8")
left, top, right, bottom = font.getbbox(text)
return right - left, bottom - top
except Exception:
print("字体文件加载失败")
size, baseline = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.45, 1)
return size[0], size[1] + baseline
def _draw_glow_border(img, x1, y1, x2, y2, border_color, glow_color):
h, w = img.shape[:2]
x1 = _clamp(x1, 0, w - 1)
y1 = _clamp(y1, 0, h - 1)
x2 = _clamp(x2, 0, w - 1)
y2 = _clamp(y2, 0, h - 1)
if x2 <= x1 or y2 <= y1:
return
glow_layer = np.zeros_like(img)
cv2.rectangle(glow_layer, (x1, y1), (x2, y2), glow_color, 2)
glow_layer = cv2.GaussianBlur(glow_layer, (0, 0), sigmaX=1.6, sigmaY=1.6)
cv2.addWeighted(glow_layer, 0.45, img, 1.0, 0, img)
cv2.rectangle(img, (x1, y1), (x2, y2), border_color, 1)
def _draw_tech_box(img, x1, y1, x2, y2, border_color, glow_color, track_id=None):
h, w = img.shape[:2]
x1 = _clamp(x1, 0, w - 1)
y1 = _clamp(y1, 0, h - 1)
x2 = _clamp(x2, 0, w - 1)
y2 = _clamp(y2, 0, h - 1)
if x2 <= x1 or y2 <= y1:
return
bw = x2 - x1
bh = y2 - y1
diag = float(np.hypot(bw, bh))
base_thick = _clamp(int(round(diag / 70.0)), 2, 5)
glow_sigma = _clamp(diag / 55.0, 1.2, 3.6)
glow_layer = np.zeros_like(img)
cv2.rectangle(glow_layer, (x1, y1), (x2, y2), glow_color, max(1, base_thick - 1))
glow_layer = cv2.GaussianBlur(
glow_layer, (0, 0), sigmaX=glow_sigma, sigmaY=glow_sigma
)
cv2.addWeighted(glow_layer, 0.48, img, 1.0, 0, img)
cv2.rectangle(img, (x1, y1), (x2, y2), border_color, max(1, base_thick - 1))
corner_len = _clamp(int(round(min(bw, bh) * 0.28)), 8, 20)
t = base_thick
cv2.line(img, (x1, y1), (x1 + corner_len, y1), border_color, t)
cv2.line(img, (x1, y1), (x1, y1 + corner_len), border_color, t)
cv2.line(img, (x2, y1), (x2 - corner_len, y1), border_color, t)
cv2.line(img, (x2, y1), (x2, y1 + corner_len), border_color, t)
cv2.line(img, (x1, y2), (x1 + corner_len, y2), border_color, t)
cv2.line(img, (x1, y2), (x1, y2 - corner_len), border_color, t)
cv2.line(img, (x2, y2), (x2 - corner_len, y2), border_color, t)
cv2.line(img, (x2, y2), (x2, y2 - corner_len), border_color, t)
if track_id is not None:
badge = "T%02d" % track_id
badge_w_txt, badge_h_txt = _measure_text(badge, text_size=12)
pad_x = 5
pad_y = 3
badge_w = badge_w_txt + pad_x * 2
badge_h = badge_h_txt + pad_y * 2
bx2 = _clamp(x2, badge_w + 2, w - 2)
by1 = _clamp(y1 - badge_h - 2, 2, h - badge_h - 2)
bx1 = bx2 - badge_w
by2 = by1 + badge_h
_draw_alpha_rect(img, bx1, by1, bx2, by2, (18, 18, 18), alpha=0.65)
cv2.rectangle(img, (bx1, by1), (bx2, by2), border_color, 1)
text_x = bx1 + pad_x
text_y = by1 + max(1, int((badge_h - badge_h_txt) / 2))
img[:] = cv2ImgAddText(img, badge, text_x, text_y, border_color, 12)
def cv2ImgAddText(img, text, left, top, textColor=(0, 255, 0), textSize=20):
if isinstance(img, np.ndarray): # 判断是否OpenCV图片类型
img = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
draw = ImageDraw.Draw(img)
fontText = ImageFont.truetype(get_font_file_path(), textSize, encoding="utf-8")
draw.text((left, top), text, textColor, font=fontText)
return cv2.cvtColor(np.asarray(img), cv2.COLOR_RGB2BGR)
def get_font_file_path():
# win
# font_path = r"C:\Users\BBIT\Desktop\yolo26-plate-main\fonts\platech.ttf"
# linux
font_path = r"/app/models/sentinel/platech.ttf"
return font_path
def _draw_tech_landmark(img, x, y, border_color, glow_color):
h, w = img.shape[:2]
x = _clamp(int(round(x)), 0, w - 1)
y = _clamp(int(round(y)), 0, h - 1)
glow_layer = np.zeros_like(img)
cv2.circle(glow_layer, (x, y), 5, glow_color, -1)
glow_layer = cv2.GaussianBlur(glow_layer, (0, 0), sigmaX=1.2, sigmaY=1.2)
cv2.addWeighted(glow_layer, 0.5, img, 1.0, 0, img)
cv2.circle(img, (x, y), 2, border_color, -1)
cv2.circle(img, (x, y), 4, border_color, 1)
def _plate_width_from_landmarks(landmarks, fallback_width):
if not landmarks or len(landmarks) < 4:
return float(fallback_width)
try:
p0 = np.array(landmarks[0], dtype=np.float32)
p1 = np.array(landmarks[1], dtype=np.float32)
p2 = np.array(landmarks[2], dtype=np.float32)
p3 = np.array(landmarks[3], dtype=np.float32)
top_w = float(np.linalg.norm(p1 - p0))
bottom_w = float(np.linalg.norm(p2 - p3))
width = (top_w + bottom_w) / 2.0
if np.isfinite(width) and width > 1:
return width
except Exception:
pass
return float(fallback_width)
def _plate_height_from_landmarks(landmarks, fallback_height):
if not landmarks or len(landmarks) < 4:
return float(fallback_height)
try:
p0 = np.array(landmarks[0], dtype=np.float32)
p1 = np.array(landmarks[1], dtype=np.float32)
p2 = np.array(landmarks[2], dtype=np.float32)
p3 = np.array(landmarks[3], dtype=np.float32)
left_h = float(np.linalg.norm(p3 - p0))
right_h = float(np.linalg.norm(p2 - p1))
height = (left_h + right_h) / 2.0
if np.isfinite(height) and height > 1:
return height
except Exception:
pass
return float(fallback_height)
def _fit_font_size(text, max_w, max_h, min_size=10, max_size=24):
if max_w <= 0 or max_h <= 0:
return min_size
for size in range(max_size, min_size - 1, -1):
tw, th = _measure_text(text, text_size=size)
if tw <= max_w and th <= max_h:
return size
return min_size
def draw_result(orgimg, result_list):
landmark_colors = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (255, 255, 0)]
result_str = []
img_h, img_w = orgimg.shape[:2]
for idx, result in enumerate(result_list, start=1):
raw_rect = _normalize_rect(result.get("rect"))
if raw_rect is None:
continue
x1, y1, x2, y2 = raw_rect
w = x2 - x1
h = y2 - y1
padding_w = int(round(0.05 * w))
padding_h = int(round(0.11 * h))
rx1 = _clamp(x1 - padding_w, 0, img_w - 1)
ry1 = _clamp(y1 - padding_h, 0, img_h - 1)
rx2 = _clamp(x2 + padding_w, 0, img_w - 1)
ry2 = _clamp(y2 + padding_h, 0, img_h - 1)
landmarks = result.get("landmarks", [])
plate_no = result.get("plate_no", "")
plate_color = result.get("plate_color", "")
if result.get("plate_type", 0) == 1:
result_p = "%s %s双层" % (plate_no, plate_color)
else:
result_p = "%s %s" % (plate_no, plate_color)
result_str.append(result_p)
theme = _get_plate_theme(plate_color)
for i in range(min(4, len(landmarks))):
point = landmarks[i]
if len(point) < 2:
continue
point_color = landmark_colors[i]
_draw_tech_landmark(orgimg, point[0], point[1], point_color, point_color)
_draw_tech_box(
orgimg, rx1, ry1, rx2, ry2, theme["border"], theme["glow"], track_id=idx
)
label = "%s | %s" % (plate_no, plate_color)
plate_w = _plate_width_from_landmarks(landmarks, rx2 - rx1)
plate_h = _plate_height_from_landmarks(landmarks, ry2 - ry1)
pre_card_h = _clamp(int(round(plate_h)), 24, min(110, img_h - 4))
pre_pad_y = _clamp(int(round(pre_card_h * 0.16)), 3, 10)
pre_inner_h = max(8, pre_card_h - pre_pad_y * 2)
pre_max_font = _clamp(int(round(pre_card_h * 0.72)), 14, 44)
pre_min_font = _clamp(int(round(pre_card_h * 0.42)), 10, pre_max_font)
pre_font_size = _fit_font_size(
label, 4096, pre_inner_h, min_size=pre_min_font, max_size=pre_max_font
)
pre_text_w, _ = _measure_text(label, text_size=pre_font_size)
min_w_by_text = pre_text_w + 20
base_w_by_plate = int(round(plate_w * 1.05))
card_w = max(90, base_w_by_plate, min_w_by_text)
card_w = min(card_w, img_w - 8)
card_h = pre_card_h
card_pad_x = _clamp(int(round(card_w * 0.08)), 8, 18)
card_pad_y = _clamp(int(round(card_h * 0.16)), 3, 10)
card_x = int(rx1 + (rx2 - rx1 - card_w) / 2)
card_x = _clamp(card_x, 4, max(4, img_w - card_w - 4))
card_y = ry1 - card_h - 2
if card_y < 2:
card_y = _clamp(ry1 + 2, 2, max(2, img_h - card_h - 2))
_draw_alpha_rect(
orgimg,
card_x,
card_y,
card_x + card_w,
card_y + card_h,
theme["bg"],
alpha=0.78,
)
_draw_glow_border(
orgimg,
card_x,
card_y,
card_x + card_w,
card_y + card_h,
theme["border"],
theme["glow"],
)
inner_w = max(8, card_w - card_pad_x * 2)
inner_h = max(8, card_h - card_pad_y * 2)
dynamic_max_font = _clamp(int(round(card_h * 0.72)), 14, 44)
dynamic_min_font = _clamp(int(round(card_h * 0.42)), 10, dynamic_max_font)
font_size = _fit_font_size(
label,
inner_w,
inner_h,
min_size=dynamic_min_font,
max_size=dynamic_max_font,
)
text_w, text_h = _measure_text(label, text_size=font_size)
text_x = card_x + max(card_pad_x, int((card_w - text_w) / 2))
text_y = card_y + max(card_pad_y - 1, int((card_h - text_h) / 2))
orgimg = cv2ImgAddText(orgimg, label, text_x, text_y, theme["text"], font_size)
return orgimg, result_str
+98
View File
@@ -0,0 +1,98 @@
import os
from dataclasses import dataclass
import cv2
import torch
from ultralytics import YOLO
from ai.plate.detect_rec_plate import det_rec_plate, draw_result
from ai.plate.plate_recognition.plate_rec import init_model
from config.yolo import logger
@dataclass
class PlateInfo:
plate_no: str
plate_color: str
result_img_path: str
class PlateRecognizer:
_instance = None
_ready = False
def __new__(cls, *args, **kwargs):
if cls._instance is None:
cls._instance = super(PlateRecognizer, cls).__new__(cls)
return cls._instance
def __init__(
self,
detect_model_path,
rec_model_path,
output_dir="result",
device="cuda" if torch.cuda.is_available() else "cpu",
):
if hasattr(self, "_initialized") and self._initialized:
return
self.device = torch.device(device)
self.output_dir = output_dir
os.makedirs(self.output_dir, exist_ok=True)
try:
# 模型加载
self.detect_model = YOLO(detect_model_path)
self.detect_model.to(self.device)
self.detect_model.eval()
self.plate_rec_model = init_model(
self.device, rec_model_path, is_color=True
)
self._initialized = True
self._ready = True
except Exception as e:
self.detect_model = None
self._ready = False
logger.error(f"❌车牌 YOLO 模型加载失败: {e}")
def analyze_image(self, image_path):
img = cv2.imread(image_path)
if not self._ready or img is None:
return []
# 检测与识别
result_list = det_rec_plate(
img, self.detect_model, self.plate_rec_model, self.device
)
# 可视化 & 保存
vis_img, _ = draw_result(img, result_list)
result_img_path = os.path.join(self.output_dir, os.path.basename(image_path))
cv2.imwrite(result_img_path, vis_img)
# 构建返回结果
results = [
PlateInfo(
plate_no=r["plate_no"],
plate_color=r["plate_color"],
result_img_path=result_img_path,
)
for r in result_list
]
return results
# 初始化一次
recognizer = PlateRecognizer(
detect_model_path="/app/models/sentinel/yolo26s-plate-detect.pt",
rec_model_path="/app/models/sentinel/plate_rec_color.pth",
output_dir="result",
)
# base_dir = Path(r"C:\Users\BBIT\Desktop\yolo26-plate-main")
#
# recognizer = PlateRecognizer(
# detect_model_path=str(base_dir / "weights" / "yolo26s-plate-detect.pt"),
# rec_model_path=str(base_dir / "weights" / "plate_rec_color.pth"),
# output_dir="result",
# )
@@ -0,0 +1,15 @@
import os
import cv2
import numpy as np
def get_split_merge(img):
h,w,c = img.shape
img_upper = img[0:int(5/12*h),:]
img_lower = img[int(1/3*h):,:]
img_upper = cv2.resize(img_upper,(img_lower.shape[1],img_lower.shape[0]))
new_img = np.hstack((img_upper,img_lower))
return new_img
if __name__=="__main__":
img = cv2.imread("double_plate/tmp8078.png")
new_img =get_split_merge(img)
cv2.imwrite("double_plate/new.jpg",new_img)
@@ -0,0 +1,203 @@
import torch.nn as nn
import torch
class myNet_ocr(nn.Module):
def __init__(self,cfg=None,num_classes=78,export=False):
super(myNet_ocr, self).__init__()
if cfg is None:
cfg =[32,32,64,64,'M',128,128,'M',196,196,'M',256,256]
# cfg =[32,32,'M',64,64,'M',128,128,'M',256,256]
self.feature = self.make_layers(cfg, True)
self.export = export
# self.classifier = nn.Linear(cfg[-1], num_classes)
# self.loc = nn.MaxPool2d((2, 2), (5, 1), (0, 1),ceil_mode=True)
# self.loc = nn.AvgPool2d((2, 2), (5, 2), (0, 1),ceil_mode=False)
self.loc = nn.MaxPool2d((5, 2), (1, 1),(0,1),ceil_mode=False)
self.newCnn=nn.Conv2d(cfg[-1],num_classes,1,1)
# self.newBn=nn.BatchNorm2d(num_classes)
def make_layers(self, cfg, batch_norm=False):
layers = []
in_channels = 3
for i in range(len(cfg)):
if i == 0:
conv2d =nn.Conv2d(in_channels, cfg[i], kernel_size=5,stride =1)
if batch_norm:
layers += [conv2d, nn.BatchNorm2d(cfg[i]), nn.ReLU(inplace=True)]
else:
layers += [conv2d, nn.ReLU(inplace=True)]
in_channels = cfg[i]
else :
if cfg[i] == 'M':
layers += [nn.MaxPool2d(kernel_size=3, stride=2,ceil_mode=True)]
else:
conv2d = nn.Conv2d(in_channels, cfg[i], kernel_size=3, padding=(1,1),stride =1)
if batch_norm:
layers += [conv2d, nn.BatchNorm2d(cfg[i]), nn.ReLU(inplace=True)]
else:
layers += [conv2d, nn.ReLU(inplace=True)]
in_channels = cfg[i]
return nn.Sequential(*layers)
def forward(self, x):
x = self.feature(x)
x=self.loc(x)
x=self.newCnn(x)
# x=self.newBn(x)
if self.export:
conv = x.squeeze(2) # b *512 * width
conv = conv.transpose(2,1) # [w, b, c]
# conv =conv.argmax(dim=2)
return conv
else:
b, c, h, w = x.size()
assert h == 1, "the height of conv must be 1"
conv = x.squeeze(2) # b *512 * width
conv = conv.permute(2, 0, 1) # [w, b, c]
# output = F.log_softmax(self.rnn(conv), dim=2)
output = torch.softmax(conv, dim=2)
return output
myCfg = [32,'M',64,'M',96,'M',128,'M',256]
class myNet(nn.Module):
def __init__(self,cfg=None,num_classes=3):
super(myNet, self).__init__()
if cfg is None:
cfg = myCfg
self.feature = self.make_layers(cfg, True)
self.classifier = nn.Linear(cfg[-1], num_classes)
def make_layers(self, cfg, batch_norm=False):
layers = []
in_channels = 3
for i in range(len(cfg)):
if i == 0:
conv2d =nn.Conv2d(in_channels, cfg[i], kernel_size=5,stride =1)
if batch_norm:
layers += [conv2d, nn.BatchNorm2d(cfg[i]), nn.ReLU(inplace=True)]
else:
layers += [conv2d, nn.ReLU(inplace=True)]
in_channels = cfg[i]
else :
if cfg[i] == 'M':
layers += [nn.MaxPool2d(kernel_size=3, stride=2,ceil_mode=True)]
else:
conv2d = nn.Conv2d(in_channels, cfg[i], kernel_size=3, padding=1,stride =1)
if batch_norm:
layers += [conv2d, nn.BatchNorm2d(cfg[i]), nn.ReLU(inplace=True)]
else:
layers += [conv2d, nn.ReLU(inplace=True)]
in_channels = cfg[i]
return nn.Sequential(*layers)
def forward(self, x):
x = self.feature(x)
x = nn.AvgPool2d(kernel_size=3, stride=1)(x)
x = x.view(x.size(0), -1)
y = self.classifier(x)
return y
class MyNet_color(nn.Module):
def __init__(self, class_num=6):
super(MyNet_color, self).__init__()
self.class_num = class_num
self.backbone = nn.Sequential(
nn.Conv2d(in_channels=3, out_channels=16, kernel_size=(5, 5), stride=(1, 1)), # 0
torch.nn.BatchNorm2d(16),
nn.ReLU(),
nn.MaxPool2d(kernel_size=(2, 2)),
nn.Dropout(0),
nn.Flatten(),
nn.Linear(480, 64),
nn.Dropout(0),
nn.ReLU(),
nn.Linear(64, class_num),
nn.Dropout(0),
nn.Softmax(1)
)
def forward(self, x):
logits = self.backbone(x)
return logits
class myNet_ocr_color(nn.Module):
def __init__(self,cfg=None,num_classes=78,export=False,color_num=None):
super(myNet_ocr_color, self).__init__()
if cfg is None:
cfg =[32,32,64,64,'M',128,128,'M',196,196,'M',256,256]
# cfg =[32,32,'M',64,64,'M',128,128,'M',256,256]
self.feature = self.make_layers(cfg, True)
self.export = export
self.color_num=color_num
self.conv_out_num=12 #颜色第一个卷积层输出通道12
if self.color_num:
self.conv1=nn.Conv2d(cfg[-1],self.conv_out_num,kernel_size=3,stride=2)
self.bn1=nn.BatchNorm2d(self.conv_out_num)
self.relu1=nn.ReLU(inplace=True)
self.gap =nn.AdaptiveAvgPool2d(output_size=1)
self.color_classifier=nn.Conv2d(self.conv_out_num,self.color_num,kernel_size=1,stride=1)
self.color_bn = nn.BatchNorm2d(self.color_num)
self.flatten = nn.Flatten()
self.loc = nn.MaxPool2d((5, 2), (1, 1),(0,1),ceil_mode=False)
self.newCnn=nn.Conv2d(cfg[-1],num_classes,1,1)
# self.newBn=nn.BatchNorm2d(num_classes)
def make_layers(self, cfg, batch_norm=False):
layers = []
in_channels = 3
for i in range(len(cfg)):
if i == 0:
conv2d =nn.Conv2d(in_channels, cfg[i], kernel_size=5,stride =1)
if batch_norm:
layers += [conv2d, nn.BatchNorm2d(cfg[i]), nn.ReLU(inplace=True)]
else:
layers += [conv2d, nn.ReLU(inplace=True)]
in_channels = cfg[i]
else :
if cfg[i] == 'M':
layers += [nn.MaxPool2d(kernel_size=3, stride=2,ceil_mode=True)]
else:
conv2d = nn.Conv2d(in_channels, cfg[i], kernel_size=3, padding=(1,1),stride =1)
if batch_norm:
layers += [conv2d, nn.BatchNorm2d(cfg[i]), nn.ReLU(inplace=True)]
else:
layers += [conv2d, nn.ReLU(inplace=True)]
in_channels = cfg[i]
return nn.Sequential(*layers)
def forward(self, x):
x = self.feature(x)
if self.color_num:
x_color=self.conv1(x)
x_color=self.bn1(x_color)
x_color =self.relu1(x_color)
x_color = self.color_classifier(x_color)
x_color = self.color_bn(x_color)
x_color =self.gap(x_color)
x_color = self.flatten(x_color)
x=self.loc(x)
x=self.newCnn(x)
if self.export:
conv = x.squeeze(2) # b *512 * width
conv = conv.transpose(2,1) # [w, b, c]
if self.color_num:
return conv,x_color
return conv
else:
b, c, h, w = x.size()
assert h == 1, "the height of conv must be 1"
conv = x.squeeze(2) # b *512 * width
conv = conv.permute(2, 0, 1) # [w, b, c]
output = F.log_softmax(conv, dim=2)
if self.color_num:
return output,x_color
return output
if __name__ == '__main__':
x = torch.randn(1,3,48,216)
model = myNet_ocr(num_classes=78,export=True)
out = model(x)
print(out.shape)
@@ -0,0 +1,134 @@
import os
import time
import cv2
import numpy as np
import torch
from ai.plate.plate_recognition.plateNet import myNet_ocr_color
def cv_imread(path): # 可以读取中文路径的图片
img = cv2.imdecode(np.fromfile(path, dtype=np.uint8), -1)
return img
def allFilePath(rootPath, allFIleList):
fileList = os.listdir(rootPath)
for temp in fileList:
if os.path.isfile(os.path.join(rootPath, temp)):
if temp.endswith(".jpg") or temp.endswith(".png") or temp.endswith(".JPG"):
allFIleList.append(os.path.join(rootPath, temp))
else:
allFilePath(os.path.join(rootPath, temp), allFIleList)
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
color = ["黑色", "蓝色", "绿色", "白色", "黄色"]
plateName = r"#京沪津渝冀晋蒙辽吉黑苏浙皖闽赣鲁豫鄂湘粤桂琼川贵云藏陕甘青宁新学警港澳挂使领民航危0123456789ABCDEFGHJKLMNPQRSTUVWXYZ险品"
mean_value, std_value = (0.588, 0.193)
def decodePlate(preds):
pre = 0
newPreds = []
index = []
for i in range(len(preds)):
if preds[i] != 0 and preds[i] != pre:
newPreds.append(preds[i])
index.append(i)
pre = preds[i]
return newPreds, index
def image_processing(img, device):
img = cv2.resize(img, (168, 48))
img = np.reshape(img, (48, 168, 3))
# normalize
img = img.astype(np.float32)
img = (img / 255.0 - mean_value) / std_value
img = img.transpose([2, 0, 1])
img = torch.from_numpy(img)
img = img.to(device)
img = img.view(1, *img.size())
return img
def get_plate_result(img, device, model, is_color=False):
input = image_processing(img, device)
if is_color: # 是否识别颜色
preds, color_preds = model(input)
color_preds = torch.softmax(color_preds, dim=-1)
color_conf, color_index = torch.max(color_preds, dim=-1)
color_conf = color_conf.item()
else:
preds = model(input)
preds = torch.softmax(preds, dim=-1)
prob, index = preds.max(dim=-1)
index = index.view(-1).detach().cpu().numpy()
prob = prob.view(-1).detach().cpu().numpy()
# preds=preds.view(-1).detach().cpu().numpy()
newPreds, new_index = decodePlate(index)
prob = prob[new_index]
plate = ""
for i in newPreds:
plate += plateName[i]
# if not (plate[0] in plateName[1:44] ):
# return ""
if is_color:
return (
plate,
prob,
color[color_index],
color_conf,
) # 返回车牌号以及每个字符的概率,以及颜色,和颜色的概率
else:
return plate, prob
def init_model(device, model_path, is_color=False):
# print( print(sys.path))
# model_path ="plate_recognition/model/checkpoint_61_acc_0.9715.pth"
check_point = torch.load(model_path, map_location=device)
model_state = check_point["state_dict"]
cfg = check_point["cfg"]
color_classes = 0
if is_color:
color_classes = 5 # 颜色类别数
model = myNet_ocr_color(
num_classes=len(plateName), export=True, cfg=cfg, color_num=color_classes
)
model.load_state_dict(model_state, strict=False)
model.to(device)
model.eval()
return model
# model = init_model(device)
if __name__ == "__main__":
model_path = r"weights/plate_rec_color.pth"
image_path = "images/tmp2424.png"
testPath = r"/mnt/Gpan/Mydata/pytorchPorject/CRNN/crnn_plate_recognition/images"
fileList = []
allFilePath(testPath, fileList)
# result = get_plate_result(image_path,device)
# print(result)
is_color = False
model = init_model(device, model_path, is_color=is_color)
right = 0
begin = time.time()
for imge_path in fileList:
img = cv2.imread(imge_path)
if is_color:
plate, _, plate_color, _ = get_plate_result(
img, device, model, is_color=is_color
)
print(plate)
else:
plate, _ = get_plate_result(img, device, model, is_color=is_color)
print(plate, imge_path)