mirror of
https://github.com/wlkjyh/dianxuan.git
synced 2025-04-06 08:14:32 +08:00
修复预测bug
This commit is contained in:
parent
7ab0bcabf6
commit
837e6085eb
166
predict.py
166
predict.py
@ -5,61 +5,34 @@ from keras.layers import Lambda
|
||||
from keras import backend as K
|
||||
import os
|
||||
import random
|
||||
from PIL import Image
|
||||
# 要预测的图片
|
||||
# image_path = "./sample/1691156257961.jpg"
|
||||
"""
|
||||
孪生网络 对比模型
|
||||
"""
|
||||
resize_height, resize_width, channel = 52, 52, 3
|
||||
weight = "./best.h5"
|
||||
|
||||
output = Lambda(lambda x: K.abs(x[0] - x[1]))
|
||||
model = load_model(weight, custom_objects={'output': output})
|
||||
|
||||
|
||||
image_path = os.listdir("./data")
|
||||
# 随机选取一张图片
|
||||
# inp = input('请输入图片名称:')
|
||||
image_path = "./data/" + random.choice(image_path)
|
||||
# image_path = './sample/' + inp
|
||||
# print(image_path)
|
||||
|
||||
"""
|
||||
YOLOv3 分割模型
|
||||
"""
|
||||
weight = "./yolov3-tiny_17000.weights"
|
||||
cfg = "./yolov3-tiny.cfg"
|
||||
|
||||
img = cv2.imread(image_path)
|
||||
# 加载模型
|
||||
net = cv2.dnn.readNet(weight, cfg)
|
||||
"""
|
||||
孪生网络 对比模型
|
||||
"""
|
||||
resize_height, resize_width,channel = 52,52,3
|
||||
|
||||
# 自定义的损失和精度
|
||||
output = Lambda(lambda x: K.abs(x[0] - x[1]))
|
||||
def contrastive_loss(y_true, y_pred):
|
||||
margin = 1
|
||||
return K.mean(y_true * K.square(y_pred) + (1 - y_true) * K.square(K.maximum(margin - y_pred, 0)))
|
||||
|
||||
def binary_accuracy(y_true, y_pred):
|
||||
return K.mean(K.equal(y_true, K.cast(y_pred < 0.5, y_true.dtype)))
|
||||
weight = "./best.h5"
|
||||
# 加载模型
|
||||
# model = load_model(weight, custom_objects={'contrastive_loss': contrastive_loss, 'binary_accuracy': binary_accuracy})
|
||||
model = load_model(weight, custom_objects={'output': output})
|
||||
|
||||
classes = ["text"]
|
||||
|
||||
|
||||
img = cv2.imread(image_path)
|
||||
|
||||
cv2.namedWindow('display')
|
||||
|
||||
|
||||
"""
|
||||
YOLO 分割出内容
|
||||
"""
|
||||
height, width, channels = img.shape
|
||||
blob = cv2.dnn.blobFromImage(img, 0.00392, (416, 416), (0, 0, 0), True, crop=False)
|
||||
blob = cv2.dnn.blobFromImage(img, 0.00392, (416, 416), (0, 0, 0), True, crop=False) # 预处理
|
||||
|
||||
net.setInput(blob)
|
||||
outs = net.forward(net.getUnconnectedOutLayersNames())
|
||||
class_ids = []
|
||||
confidences = []
|
||||
boxes = []
|
||||
|
||||
for out in outs:
|
||||
for detection in out:
|
||||
scores = detection[5:]
|
||||
@ -90,93 +63,40 @@ new_boxes = []
|
||||
for i in range(len(boxes)):
|
||||
if i in indexes:
|
||||
x, y, w, h = boxes[i]
|
||||
|
||||
|
||||
new_boxes.append([x, y, w, h])
|
||||
|
||||
total = len(new_boxes)
|
||||
up_img = sorted(new_boxes, key=lambda x_: x_[1])[0:len(new_boxes)//2] # 按照y排列 取出上面的
|
||||
up_img = sorted(up_img, key=lambda x_: x_[0]) # 按照x排序
|
||||
|
||||
print('>>> 检测出:', total,'个字符')
|
||||
location_up = {}
|
||||
for i, j in enumerate(up_img):
|
||||
location_up[i+1] = [img[j[1]:j[1]+j[3], j[0]:j[0]+j[2]].astype('float64') / 255.0, j]
|
||||
|
||||
top_total = total / 2
|
||||
# 如果取出来有小数
|
||||
if top_total % 1 != 0:
|
||||
print('>>> YOLO分割有误,顶部字符数量与点选数量不匹配')
|
||||
exit()
|
||||
|
||||
top_total = int(top_total)
|
||||
down_img = sorted(new_boxes, key=lambda x_: x_[1])[len(new_boxes)//2:] # 取出下面的
|
||||
down_img = sorted(down_img, key=lambda x_: x_[0]) # 按照x排序
|
||||
|
||||
# 取出w最大的(需要点选的)
|
||||
w_max_boxes = sorted(new_boxes, key=lambda x: x[2], reverse=True)[:top_total]
|
||||
location_down = {}
|
||||
for i, j in enumerate(down_img):
|
||||
# location[i+1] = j
|
||||
location_down[i+1] = [img[j[1]:j[1]+j[3], j[0]:j[0]+j[2]].astype('float64') / 255.0, j]
|
||||
|
||||
# 取出剩下的(需要对比的)
|
||||
w_min_box = sorted(new_boxes, key=lambda x: x[2], reverse=True)[top_total:]
|
||||
new_list = []
|
||||
for down_i, down_img_ in location_down.items():
|
||||
# 先是读取下面的图
|
||||
temp = []
|
||||
for up_i, up_img_ in location_up.items():
|
||||
down = np.expand_dims(cv2.resize(down_img_[0], (52, 52)), axis=0)
|
||||
up = np.expand_dims(cv2.resize(up_img_[0], (52, 52)), axis=0)
|
||||
predict = model.predict([down, up])
|
||||
temp.append(predict[0][0])
|
||||
temp_ = temp.index(max(temp))
|
||||
new_list.append([down_img_[1], temp_])
|
||||
|
||||
# 按照从左到右排序w_min_box
|
||||
w_min_box = sorted(w_min_box, key=lambda x: x[0])
|
||||
|
||||
|
||||
w_max_image = []
|
||||
w_min_image = []
|
||||
|
||||
# 分割出具体图像
|
||||
for i in range(top_total):
|
||||
x, y, w, h = w_max_boxes[i]
|
||||
p = cv2.resize(img[y:y+h, x:x+w], (resize_height, resize_width))
|
||||
|
||||
# cv2.imwrite('./1/1_{}.jpg'.format(i), p)
|
||||
|
||||
w_max_image.append(p)
|
||||
|
||||
for i in range(len(w_min_box)):
|
||||
x, y, w, h = w_min_box[i]
|
||||
p = cv2.resize(img[y:y+h, x:x+w], (resize_height, resize_width))
|
||||
|
||||
# cv2.imwrite('./1/2_{}.jpg'.format(i), p)
|
||||
|
||||
w_min_image.append(p)
|
||||
|
||||
# print(w_max_boxes)
|
||||
|
||||
w_max_image_np = np.array(w_max_image) / 255
|
||||
w_min_image_np = np.array(w_min_image) / 255
|
||||
|
||||
select_index = []
|
||||
|
||||
# 开始挨个对比,取出最相似的
|
||||
for i in range(len(w_max_image_np)):
|
||||
print('>>> 开始对比第', i+1, '个字符')
|
||||
cv2.imwrite('./1/1_{}.jpg'.format(i), w_max_image_np[i] * 255)
|
||||
|
||||
left_x = w_max_image_np[i]
|
||||
num_index = 0
|
||||
cache_rate = 0
|
||||
for k in range(len(w_min_image_np)):
|
||||
left_y = w_min_image_np[k]
|
||||
predict = model.predict([left_x.reshape(1, resize_height, resize_width, channel), left_y.reshape(1, resize_height, resize_width, channel)])
|
||||
rate = predict[0][0]
|
||||
if rate > cache_rate:
|
||||
cv2.imwrite('./1/2_{}.jpg'.format(i), w_min_image_np[k] * 255)
|
||||
num_index = k
|
||||
|
||||
|
||||
|
||||
cache_rate = rate
|
||||
|
||||
|
||||
select_index.append(num_index)
|
||||
|
||||
print('>>> 对比完成,结果为:', select_index)
|
||||
|
||||
location = []
|
||||
|
||||
for i in range(len(select_index)):
|
||||
x, y, w, h = w_max_boxes[select_index[i]]
|
||||
cv2.rectangle(img, (x, y), (x+w, y+h), color, thickness)
|
||||
cv2.putText(img, str(i+1), (x, y+h), font, 1, color, thickness)
|
||||
# 转换为图片坐标中心点
|
||||
location.append([x+w/2, y+h/2])
|
||||
|
||||
print('>>> 位置坐标:', location)
|
||||
|
||||
cv2.imshow('display', img)
|
||||
cv2.waitKey(0)
|
||||
for (box, pos) in new_list:
|
||||
x, y, w, h = box[0], box[1], box[2], box[3]
|
||||
x1, y1, x2, y2 = x, y, x + w, y + h
|
||||
cv2.rectangle(img, (x1, y1), (x2, y2), (0, 0, 255))
|
||||
cv2.putText(img, str(pos+1), (x1, y1), cv2.FONT_HERSHEY_PLAIN, 1, color, thickness)
|
||||
cv2.imshow('1', img)
|
||||
cv2.waitKey(0)
|
||||
|
Loading…
Reference in New Issue
Block a user