80 lines
2.3 KiB
Python
80 lines
2.3 KiB
Python
from itertools import cycle
|
|
from torch.utils.data import Dataset, DataLoader
|
|
# from PIL import Image
|
|
import numpy as np
|
|
import cv2
|
|
import os
|
|
|
|
|
|
|
|
class CustomImageDataset(Dataset):
    """Object-detection dataset of images paired with YOLO-style label files.

    Expected layout under ``root_dir``::

        root_dir/
            images/   image files readable by ``cv2.imread``
            lables/   one text label file per image

    Images and label files are paired by position after sorting each
    directory listing by file name. Corresponding files must therefore sort
    into the same order (e.g. ``0001.jpg`` / ``0001.txt``).

    Each label line is ``class_id cx cy w h``. The box centre and size are
    given as fractions of the image width/height. ``__getitem__`` converts
    them to absolute pixel corners ``[class_id, x_min, y_min, x_max, y_max]``
    clipped to the image bounds.

    Args:
        root_dir: Dataset root directory.
        transform: Optional callable applied to the RGB image array. It is
            NOT applied to the boxes, so geometric transforms (resize, crop,
            flip) will leave the returned boxes out of sync with the image.
    """

    def __init__(self, root_dir, transform=None) -> None:
        super().__init__()

        self.root_dir = root_dir
        self.transform = transform

        # os.path.join works whether or not root_dir ends with a separator;
        # plain string concatenation turned "data" into "dataimages".
        self.image_path = os.path.join(self.root_dir, "images")
        # NOTE: the directory name "lables" (sic) is kept unchanged because an
        # existing on-disk dataset layout may already depend on it.
        self.label_path = os.path.join(self.root_dir, "lables")

        # os.listdir returns entries in arbitrary order. Sort both listings so
        # that index i refers to a matching image/label pair.
        self.all_images = [
            os.path.join(self.image_path, name)
            for name in sorted(os.listdir(self.image_path))
        ]
        # Label files must be resolved against label_path. The original code
        # joined them with image_path, pointing at non-existent files.
        self.all_lables = [
            os.path.join(self.label_path, name)
            for name in sorted(os.listdir(self.label_path))
        ]

    def __len__(self):
        """Return the number of images; ``DataLoader`` needs this for sampling."""
        return len(self.all_images)

    @staticmethod
    def _parse_labels(label_path, image_w, image_h):
        """Read a YOLO-format label file and return absolute, clipped boxes.

        Args:
            label_path: Path to a text file with ``class_id cx cy w h`` lines,
                where the coordinates are fractions of the image size.
            image_w: Image width in pixels.
            image_h: Image height in pixels.

        Returns:
            A float64 array of shape ``(num_boxes, 5)`` with rows
            ``[class_id, x_min, y_min, x_max, y_max]`` in pixels. It has
            shape ``(0, 5)`` when the file contains no boxes.
        """
        boxes = []
        with open(label_path, "r", encoding="utf-8") as f:
            for line in f:
                fields = line.split()
                if not fields:
                    # Tolerate blank lines, e.g. a trailing empty line.
                    continue

                class_id = int(fields[0])
                cx, cy, w, h = map(float, fields[1:])

                # Convert normalized centre/size to absolute pixel values.
                x_center = cx * image_w
                y_center = cy * image_h
                box_w = w * image_w
                box_h = h * image_h

                # Compute corner coordinates, clipped to the image bounds.
                x_min = max(0, x_center - box_w / 2)
                y_min = max(0, y_center - box_h / 2)
                x_max = min(image_w, x_center + box_w / 2)
                y_max = min(image_h, y_center + box_h / 2)

                boxes.append([class_id, x_min, y_min, x_max, y_max])

        # The fixed dtype and reshape keep the output consistently (k, 5)
        # float64, even for an empty file or a box whose values are all ints.
        return np.array(boxes, dtype=np.float64).reshape(-1, 5)

    def __getitem__(self, index):
        """Load sample ``index``.

        Returns:
            ``(image, label)``. ``image`` is the RGB image, passed through
            ``transform`` when one is set. ``label`` is a float64 array of
            shape ``(num_boxes, 5)`` with rows
            ``[class_id, x_min, y_min, x_max, y_max]`` in pixels.

        Raises:
            FileNotFoundError: If OpenCV cannot read the image file.
        """
        image_path = self.all_images[index]
        label_path = self.all_lables[index]

        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise FileNotFoundError(f"Could not read image: {image_path}")
        # OpenCV loads colour images in BGR order; convert to RGB.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image_h, image_w = image.shape[:2]

        label = self._parse_labels(label_path, image_w, image_h)

        # Apply the user-supplied transform; it was previously stored but ignored.
        if self.transform is not None:
            image = self.transform(image)

        return image, label
|
|
|
|
def detection_collate(batch):
    """Collate ``(image, label)`` samples into a pair of per-sample lists.

    The default ``DataLoader`` collate function stacks samples into single
    batched tensors. That fails when images differ in size or when images
    carry different numbers of boxes (label arrays of shape ``(k, 5)`` with
    varying ``k``). Keeping per-sample lists avoids the stacking.

    It is defined at module level rather than under the ``__main__`` guard,
    so worker processes can pickle it when ``num_workers > 0``.

    Returns:
        ``(images, labels)``: two lists in the same order as ``batch``.
    """
    images = [sample[0] for sample in batch]
    labels = [sample[1] for sample in batch]
    return images, labels


if __name__ == "__main__":
    # Placeholder roots: point these at the real train/test directories, each
    # containing "images" and "lables" sub-directories.
    train_dataset = CustomImageDataset("")
    test_dataset = CustomImageDataset("")

    train_loader = DataLoader(
        train_dataset,
        batch_size=32,
        shuffle=True,
        num_workers=2,
        pin_memory=True,  # pins batch tensors for faster host-to-GPU copies
        collate_fn=detection_collate,
    )

    test_loader = DataLoader(
        test_dataset,
        batch_size=32,
        shuffle=False,  # evaluation does not need randomized order
        num_workers=2,
        collate_fn=detection_collate,
    )