mirror of
https://github.com/luzhisheng/js_reverse.git
synced 2025-04-19 09:24:46 +08:00
110 lines
4.4 KiB
Python
110 lines
4.4 KiB
Python
import os
|
||
import xml.etree.ElementTree as ET
|
||
import logging
|
||
import shutil
|
||
|
||
# 提取图像并将其分类为训练集与训练测试集 :train validation
|
||
Images = os.listdir('JPEGImages')
|
||
Images = [i for i in Images if i.split('.')[-1] == 'png']
|
||
print('提取到有效jpg图片共{}张'.format(len(Images)))
|
||
# 按照分配率将图片分类 分类率:train/validation 可以自己修改,可以不改,看心情
|
||
distribution_rate = 0.9
|
||
|
||
# ------------------------------------------------------------
|
||
# ↓↓标注的类别,很重要很重要,这里必须要改,按照自己的类别去改|
|
||
|
||
classes = ['缺口']
|
||
|
||
# ↑↑标注的类别,很重要很重要,这里必须要改,按照自己的类别去改|
|
||
# ------------------------------------------------------------
|
||
|
||
|
||
# 正式移动图片到指定目录:.images 下, 并且生成训练索引 train.txt and val.txt 这一步会清空这两个文本的内容
|
||
# 正式移动图片到指定目录:.images 下, 并且生成训练索引 train.txt and val.txt 这一步会清空这两个文本的内容
|
||
train = Images[0: int(distribution_rate * len(Images))]
|
||
validation = Images[int(distribution_rate * len(Images)):]
|
||
if train == 0 or validation == 0:
|
||
raise FileExistsError('没有找到训练集的图片或测试集图片,请检查目录')
|
||
|
||
# 获取绝对路径。为了好看 把 \ 处理成 /
|
||
ab_path = os.path.dirname(os.path.abspath(__file__)).replace('\\', '/')
|
||
print(ab_path)
|
||
with open('train.txt', 'w', encoding='utf-8') as f:
|
||
for i in train:
|
||
f.write(ab_path + '/images/train/' + i + '\n')
|
||
shutil.copy('JPEGImages/' + i, 'images/train')
|
||
|
||
with open('val.txt', 'w', encoding='utf-8') as f:
|
||
for i in validation:
|
||
f.write(ab_path + '/images/val/' + i + '\n')
|
||
shutil.copy('JPEGImages/' + i, 'images/val')
|
||
|
||
print('图片移动/复制完成,训练索引 train.txt and val.txt 生成完毕')
|
||
|
||
# 预检测 xml与图片的对应关系,这里要求严格一一对应
|
||
|
||
|
||
xml_file = os.listdir('Annotations')
|
||
xml_file = [i for i in xml_file if i.split('.')[-1] == 'xml']
|
||
xml_file_check = [i.split('.')[0] + '.xml' for i in Images if i.split('.')[-1] == 'png']
|
||
if xml_file_check != xml_file:
|
||
raise FileExistsError('Annotations 中xml文件与JPEGImages图片不对应,请仔细检测!')
|
||
|
||
|
||
# 下面将 xml文件标注提取并生成label
|
||
def convert(size, box):
|
||
dw = 1. / (size[0])
|
||
dh = 1. / (size[1])
|
||
x = (box[0] + box[1]) / 2.0 - 1
|
||
y = (box[2] + box[3]) / 2.0 - 1
|
||
w = box[1] - box[0]
|
||
h = box[3] - box[2]
|
||
x = x * dw
|
||
w = w * dw
|
||
y = y * dh
|
||
h = h * dh
|
||
return x, y, w, h
|
||
|
||
|
||
def write_labels(xml_file_path, write_to_file_path):
|
||
with open(xml_file_path, 'r', encoding='utf-8') as f:
|
||
tree = ET.parse(f)
|
||
root = tree.getroot()
|
||
size = root.find('size')
|
||
w = int(size.find('width').text)
|
||
h = int(size.find('height').text)
|
||
with open(write_to_file_path, 'w', encoding='utf-8') as f2:
|
||
for obj in root.iter('object'):
|
||
xml_name = obj.find('name').text
|
||
if xml_name not in classes:
|
||
logging.warning('正在检索该对象不存在设定classes,应该引起重视')
|
||
continue
|
||
cls_id = classes.index(xml_name)
|
||
xmlbox = obj.find('bndbox')
|
||
b = (float(xmlbox.find('xmin').text), float(xmlbox.find('xmax').text), float(xmlbox.find('ymin').text),
|
||
float(xmlbox.find('ymax').text))
|
||
b1, b2, b3, b4 = b
|
||
# 标注越界修正
|
||
if b2 > w:
|
||
b2 = w
|
||
if b4 > h:
|
||
b4 = h
|
||
b = (b1, b2, b3, b4)
|
||
bb = convert((w, h), b)
|
||
write_message = str(cls_id) + " " + " ".join([str(a) for a in bb]) + '\n'
|
||
f2.write(write_message)
|
||
if not write_message:
|
||
logging.warning(
|
||
'未在标注图片的xml文件中取得分类内容,此警告应引起重视,可能意味着分类参数不匹配。classes错误')
|
||
|
||
|
||
for i in train:
|
||
write_labels('Annotations/' + i.split('.')[0] + '.xml', 'labels/train/{}'.format(i.split('.')[0] + '.txt'))
|
||
|
||
for i in validation:
|
||
write_labels('Annotations/' + i.split('.')[0] + '.xml', 'labels/val/{}'.format(i.split('.')[0] + '.txt'))
|
||
|
||
# 最后一步 在当前目录下生成索引
|
||
|
||
print('finish work')
|