42 lines
1.6 KiB
Python
42 lines
1.6 KiB
Python
'''
|
|
20250304更新
|
|
'''
|
|
from pathlib import Path
|
|
|
|
|
|
# source_classes = ['Gloves','Helmet','Non-Helmet','Person','Shoes','Vest','bare-arms']
|
|
|
|
# labels_path = '安全帽鞋头数据集/labels'
|
|
|
|
import os
|
|
|
|
# 类别名称列表
|
|
class_names = ['Gloves','Helmet','Non-Helmet','Person','Shoes','Vest','bare-arms'] # 替换为你的类别名称
|
|
target_class = ['Helmet', 'Non-Helmet', 'Shoes'] # 目标类别
|
|
|
|
def filter_labels(label_directory, output_label_directory):
|
|
if not os.path.exists(output_label_directory):
|
|
os.makedirs(output_label_directory) # 创建输出标签目录
|
|
|
|
# 遍历标签目录中的所有文件
|
|
for label_file in os.listdir(label_directory):
|
|
if label_file.endswith('.txt'): # 确保是标签文件
|
|
with open(os.path.join(label_directory, label_file), 'r') as file:
|
|
lines = file.readlines()
|
|
|
|
# 过滤标签,只保留目标类别
|
|
filtered_lines = []
|
|
for line in lines:
|
|
class_id, _, _, _, _ = map(float, line.split())
|
|
if class_names[int(class_id)] in target_class:
|
|
filtered_lines.append(line)
|
|
|
|
# 将过滤后的标签写入新的文件
|
|
with open(os.path.join(output_label_directory, label_file), 'w') as file:
|
|
file.writelines(filtered_lines)
|
|
|
|
# 示例用法
|
|
label_directory = '安全帽鞋头数据集/labels' # 替换为你的标签文件夹路径
|
|
output_label_directory = '安全帽鞋头数据集/labels_helmetShoeHead' # 替换为你希望保存过滤后标签的路径
|
|
filter_labels(label_directory, output_label_directory)
|