diff --git a/predict.py b/predict.py index 2075b78..03fb737 100644 --- a/predict.py +++ b/predict.py @@ -5,61 +5,34 @@ from keras.layers import Lambda from keras import backend as K import os import random -from PIL import Image -# 要预测的图片 -# image_path = "./sample/1691156257961.jpg" +""" + 孪生网络 对比模型 +""" +resize_height, resize_width, channel = 52, 52, 3 +weight = "./best.h5" + +output = Lambda(lambda x: K.abs(x[0] - x[1])) +model = load_model(weight, custom_objects={'output': output}) + + image_path = os.listdir("./data") # 随机选取一张图片 # inp = input('请输入图片名称:') image_path = "./data/" + random.choice(image_path) # image_path = './sample/' + inp -# print(image_path) - -""" - YOLOv3 分割模型 -""" weight = "./yolov3-tiny_17000.weights" cfg = "./yolov3-tiny.cfg" - +img = cv2.imread(image_path) # 加载模型 net = cv2.dnn.readNet(weight, cfg) -""" - 孪生网络 对比模型 -""" -resize_height, resize_width,channel = 52,52,3 - -# 自定义的损失和精度 -output = Lambda(lambda x: K.abs(x[0] - x[1])) -def contrastive_loss(y_true, y_pred): - margin = 1 - return K.mean(y_true * K.square(y_pred) + (1 - y_true) * K.square(K.maximum(margin - y_pred, 0))) - -def binary_accuracy(y_true, y_pred): - return K.mean(K.equal(y_true, K.cast(y_pred < 0.5, y_true.dtype))) -weight = "./best.h5" -# 加载模型 -# model = load_model(weight, custom_objects={'contrastive_loss': contrastive_loss, 'binary_accuracy': binary_accuracy}) -model = load_model(weight, custom_objects={'output': output}) - -classes = ["text"] - - -img = cv2.imread(image_path) - -cv2.namedWindow('display') - - -""" - YOLO 分割出内容 -""" height, width, channels = img.shape -blob = cv2.dnn.blobFromImage(img, 0.00392, (416, 416), (0, 0, 0), True, crop=False) +blob = cv2.dnn.blobFromImage(img, 0.00392, (416, 416), (0, 0, 0), True, crop=False) # 预处理 + net.setInput(blob) outs = net.forward(net.getUnconnectedOutLayersNames()) class_ids = [] confidences = [] boxes = [] - for out in outs: for detection in out: scores = detection[5:] @@ -90,93 +63,40 @@ new_boxes = [] for i in range(len(boxes)): if i in indexes: x, y, w, h = boxes[i] - + new_boxes.append([x, y, w, h]) -total = len(new_boxes) +up_img = sorted(new_boxes, key=lambda x_: x_[1])[0:len(new_boxes)//2] # 按照y排列 取出上面的 +up_img = sorted(up_img, key=lambda x_: x_[0]) # 按照x排序 -print('>>> 检测出:', total,'个字符') +location_up = {} +for i, j in enumerate(up_img): + location_up[i+1] = [img[j[1]:j[1]+j[3], j[0]:j[0]+j[2]].astype('float64') / 255.0, j] -top_total = total / 2 -# 如果取出来有小数 -if top_total % 1 != 0: - print('>>> YOLO分割有误,顶部字符数量与点选数量不匹配') - exit() - -top_total = int(top_total) +down_img = sorted(new_boxes, key=lambda x_: x_[1])[len(new_boxes)//2:] # 取出下面的 +down_img = sorted(down_img, key=lambda x_: x_[0]) # 按照x排序 -# 取出w最大的(需要点选的) -w_max_boxes = sorted(new_boxes, key=lambda x: x[2], reverse=True)[:top_total] +location_down = {} +for i, j in enumerate(down_img): + # location[i+1] = j + location_down[i+1] = [img[j[1]:j[1]+j[3], j[0]:j[0]+j[2]].astype('float64') / 255.0, j] -# 取出剩下的(需要对比的) -w_min_box = sorted(new_boxes, key=lambda x: x[2], reverse=True)[top_total:] +new_list = [] +for down_i, down_img_ in location_down.items(): + # 先是读取下面的图 + temp = [] + for up_i, up_img_ in location_up.items(): + down = np.expand_dims(cv2.resize(down_img_[0], (52, 52)), axis=0) + up = np.expand_dims(cv2.resize(up_img_[0], (52, 52)), axis=0) + predict = model.predict([down, up]) + temp.append(predict[0][0]) + temp_ = temp.index(max(temp)) + new_list.append([down_img_[1], temp_]) -# 按照从左到右排序w_min_box -w_min_box = sorted(w_min_box, key=lambda x: x[0]) - - -w_max_image = [] -w_min_image = [] - -# 分割出具体图像 -for i in range(top_total): - x, y, w, h = w_max_boxes[i] - p = cv2.resize(img[y:y+h, x:x+w], (resize_height, resize_width)) - - # cv2.imwrite('./1/1_{}.jpg'.format(i), p) - - w_max_image.append(p) - -for i in range(len(w_min_box)): - x, y, w, h = w_min_box[i] - p = cv2.resize(img[y:y+h, x:x+w], (resize_height, resize_width)) - - # cv2.imwrite('./1/2_{}.jpg'.format(i), p) - - w_min_image.append(p) - -# print(w_max_boxes) - -w_max_image_np = np.array(w_max_image) / 255 -w_min_image_np = np.array(w_min_image) / 255 - -select_index = [] - -# 开始挨个对比,取出最相似的 -for i in range(len(w_max_image_np)): - print('>>> 开始对比第', i+1, '个字符') - cv2.imwrite('./1/1_{}.jpg'.format(i), w_max_image_np[i] * 255) - - left_x = w_max_image_np[i] - num_index = 0 - cache_rate = 0 - for k in range(len(w_min_image_np)): - left_y = w_min_image_np[k] - predict = model.predict([left_x.reshape(1, resize_height, resize_width, channel), left_y.reshape(1, resize_height, resize_width, channel)]) - rate = predict[0][0] - if rate > cache_rate: - cv2.imwrite('./1/2_{}.jpg'.format(i), w_min_image_np[k] * 255) - num_index = k - - - - cache_rate = rate - - - select_index.append(num_index) - -print('>>> 对比完成,结果为:', select_index) - -location = [] - -for i in range(len(select_index)): - x, y, w, h = w_max_boxes[select_index[i]] - cv2.rectangle(img, (x, y), (x+w, y+h), color, thickness) - cv2.putText(img, str(i+1), (x, y+h), font, 1, color, thickness) - # 转换为图片坐标中心点 - location.append([x+w/2, y+h/2]) - -print('>>> 位置坐标:', location) - -cv2.imshow('display', img) -cv2.waitKey(0) \ No newline at end of file +for (box, pos) in new_list: + x, y, w, h = box[0], box[1], box[2], box[3] + x1, y1, x2, y2 = x, y, x + w, y + h + cv2.rectangle(img, (x1, y1), (x2, y2), (0, 0, 255)) + cv2.putText(img, str(pos+1), (x1, y1), cv2.FONT_HERSHEY_PLAIN, 1, color, thickness) +cv2.imshow('1', img) +cv2.waitKey(0)