from itertools import cycle
from torch.utils.data import Dataset, DataLoader
# from PIL import Image
import numpy as np
import cv2
import os


class CustomImageDataset(Dataset):
    """Object-detection dataset of images paired with YOLO-format label files.

    Expected layout::

        root_dir/
            <image_dir>/  <name>.<ext>   images readable by cv2.imread
            <label_dir>/  <name>.txt     one "class_id cx cy w h" line per box,
                                         with cx, cy, w, h normalised to [0, 1]

    Each image is paired with the label file that has the same base name.
    An image with no matching label file is returned with zero boxes.

    Each item is ``(image, label)``:

    * ``image`` is an RGB ``uint8`` array of shape (H, W, 3), or whatever
      ``transform`` returns for it.
    * ``label`` is a float64 array of shape (N, 5). Each row is
      ``[class_id, x_min, y_min, x_max, y_max]`` in absolute pixel
      coordinates, clipped to the image bounds.

    Args:
        root_dir: Dataset root directory, with or without a trailing separator.
        transform: Optional callable applied to the RGB image before it is
            returned. Boxes are NOT transformed, so geometric transforms
            (resize, crop, flip) will put image and boxes out of sync.
        image_dir: Name of the image sub-directory under ``root_dir``.
        label_dir: Name of the label sub-directory under ``root_dir``. The
            default keeps the original (misspelt) "lables" so existing
            dataset layouts keep working.
    """

    # Files in the image directory with other extensions (e.g. .DS_Store,
    # Thumbs.db) are ignored instead of crashing cv2 later.
    IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".bmp", ".tif", ".tiff", ".webp")

    def __init__(self, root_dir, transform=None, *, image_dir="images",
                 label_dir="lables") -> None:
        super().__init__()
        self.root_dir = root_dir
        self.transform = transform
        # os.path.join works whether or not root_dir ends with a separator.
        # Plain concatenation turned "data" + "images" into "dataimages".
        self.image_path = os.path.join(self.root_dir, image_dir)
        self.label_path = os.path.join(self.root_dir, label_dir)
        if not os.path.isdir(self.label_path):
            raise FileNotFoundError(f"Label directory not found: {self.label_path}")

        # os.listdir order is arbitrary, so sort for a deterministic index.
        image_names = sorted(
            name for name in os.listdir(self.image_path)
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
        )
        self.all_images = [os.path.join(self.image_path, name) for name in image_names]
        # Derive each label path from its image's stem rather than zipping
        # two independent directory listings, which can mis-pair files.
        self.all_lables = [
            os.path.join(self.label_path, os.path.splitext(name)[0] + ".txt")
            for name in image_names
        ]

    # Dataset length is the number of images; required by DataLoader.
    def __len__(self):
        return len(self.all_images)

    # Load one image and its boxes.
    def __getitem__(self, index):
        image_path = self.all_images[index]
        lable_path = self.all_lables[index]

        image = cv2.imread(image_path)
        # cv2.imread signals failure by returning None, not by raising.
        if image is None:
            raise FileNotFoundError(f"cv2.imread could not read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image_h, image_w = image.shape[:2]

        label = self._load_boxes(lable_path, image_w, image_h)

        if self.transform is not None:
            image = self.transform(image)
        return image, label

    @staticmethod
    def _load_boxes(label_path, image_w, image_h):
        """Parse a YOLO label file into an (N, 5) array of absolute boxes.

        Returns a (0, 5) array when the file is missing or has no boxes.
        Raises ValueError on a line that does not have exactly 5 fields.
        """
        rows = []
        if os.path.isfile(label_path):
            with open(label_path, "r", encoding="utf-8") as f:
                for line_no, line in enumerate(f, start=1):
                    fields = line.split()
                    if not fields:
                        continue  # tolerate blank / trailing empty lines
                    if len(fields) != 5:
                        raise ValueError(
                            f"{label_path}:{line_no}: expected 5 fields "
                            f"'class_id cx cy w h', got {len(fields)}"
                        )
                    class_id = int(fields[0])
                    cx, cy, w, h = map(float, fields[1:])
                    # Convert normalised centre/size to absolute pixels.
                    x_center = cx * image_w
                    y_center = cy * image_h
                    box_w = w * image_w
                    box_h = h * image_h
                    # Corner coordinates, clipped to the image bounds.
                    x_min = max(0, x_center - box_w / 2)
                    y_min = max(0, y_center - box_h / 2)
                    x_max = min(image_w, x_center + box_w / 2)
                    y_max = min(image_h, y_center + box_h / 2)
                    rows.append([class_id, x_min, y_min, x_max, y_max])
        # reshape keeps a consistent (N, 5) shape even when N == 0.
        return np.array(rows, dtype=np.float64).reshape(-1, 5)


def detection_collate(batch):
    """Collate (image, boxes) samples into a list of images and a list of boxes.

    The default collate cannot stack images of different sizes or box arrays
    with different numbers of rows. Defined at module level so it can be
    pickled for DataLoader worker processes.
    """
    images, labels = zip(*batch)
    return list(images), list(labels)


if __name__ == "__main__":
    # TODO: point these at the real train / test dataset roots.
    train_dataset = CustomImageDataset("")
    test_dataset = CustomImageDataset("")

    train_loader = DataLoader(
        train_dataset,
        batch_size=32,
        shuffle=True,
        num_workers=2,
        pin_memory=True,  # speeds up host-to-GPU transfer
        collate_fn=detection_collate,
    )

    test_loader = DataLoader(
        test_dataset,
        batch_size=32,
        shuffle=False,  # evaluation does not need a random order
        num_workers=2,
        collate_fn=detection_collate,
    )