328 lines
11 KiB
Python
328 lines
11 KiB
Python
import os
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.optim as optim
|
|
from torch.utils.data import Dataset, DataLoader
|
|
from torchvision import transforms
|
|
import cv2
|
|
import matplotlib.pyplot as plt
|
|
import matplotlib.patches as patches
|
|
from tqdm import tqdm
|
|
|
|
# Configuration
class Config:
    """Class-level settings shared by the dataset, model, loss and training loop.

    Every setting is a plain class attribute, so it can be read from the class
    (``Config.grid_size``) or from an instance (``Config().lr``).

    Note: defining this class creates ``checkpoint_dir`` on disk, so importing
    the module has that filesystem side effect.
    """

    # --- Dataset -----------------------------------------------------------
    data_dir = 'path/to/your/dataset'  # replace with your dataset root
    image_dir = os.path.join(data_dir, 'images')
    label_dir = os.path.join(data_dir, 'labels')
    class_names = ['class1', 'class2', 'class3']  # replace with your class names
    num_classes = len(class_names)

    # --- Model -------------------------------------------------------------
    grid_size = 7     # S: the output is an S x S grid of cells
    input_size = 224  # images are resized to input_size x input_size

    # --- Training ----------------------------------------------------------
    batch_size = 32
    epochs = 50
    lr = 1e-3
    if torch.cuda.is_available():
        device = 'cuda'
    else:
        device = 'cpu'
    checkpoint_dir = 'checkpoints'

    # Runs once, when the class body executes (i.e. at import time).
    os.makedirs(checkpoint_dir, exist_ok=True)
|
|
|
|
class YOLODataset(Dataset):
    """Images with YOLO text labels, encoded as an ``[S, S, C + 5]`` target grid.

    Each image ``<stem>.<ext>`` in ``img_dir`` is paired with the label file
    ``<stem>.txt`` in ``label_dir``. Each non-empty label line is parsed as
    ``class_id x_center y_center width height``; coordinates are assumed to be
    normalized to [0, 1] (the usual YOLO convention; not validated here).

    Args:
        img_dir: Directory containing the images.
        label_dir: Directory containing the ``.txt`` label files.
        img_size: Side length the images are resized to (square).
        transform: Optional callable applied to the resized float32 HWC image
            in [0, 1]. If it returns a ``torch.Tensor`` the result is assumed
            to already be CHW (as ``transforms.ToTensor`` produces); otherwise
            it is treated as an HWC array and permuted to CHW.
        grid_size: Grid size S; defaults to ``Config.grid_size``.
        num_classes: Number of classes C; defaults to ``Config.num_classes``.
    """

    def __init__(self, img_dir, label_dir, img_size=224, transform=None,
                 grid_size=None, num_classes=None) -> None:
        self.img_dir = img_dir
        self.label_dir = label_dir
        self.img_size = img_size
        self.transform = transform
        self.grid_size = Config.grid_size if grid_size is None else grid_size
        self.num_classes = Config.num_classes if num_classes is None else num_classes

        # os.listdir order is arbitrary, so listing the two directories
        # independently and pairing by index can attach the wrong label to an
        # image. Sort the images (skipping hidden files and sub-directories)
        # and derive each label path from the image's file stem instead.
        self.img_files = sorted(
            os.path.join(self.img_dir, name)
            for name in os.listdir(self.img_dir)
            if not name.startswith('.')
            and os.path.isfile(os.path.join(self.img_dir, name))
        )
        self.label_files = [
            os.path.join(
                self.label_dir,
                os.path.splitext(os.path.basename(path))[0] + '.txt',
            )
            for path in self.img_files
        ]

    def __len__(self):
        return len(self.img_files)

    def __getitem__(self, index):
        """Return ``(image, target)`` for sample ``index``.

        ``image`` is a CHW tensor (float32 in [0, 1] unless ``transform``
        changes that); ``target`` is the grid built by ``_create_target_grid``.

        Raises:
            FileNotFoundError: if OpenCV cannot read the image file.
        """
        img_file = self.img_files[index]
        label_file = self.label_files[index]

        img = cv2.imread(img_file)
        if img is None:  # cv2.imread reports failure by returning None
            raise FileNotFoundError(f'Could not read image: {img_file}')
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        orig_h, orig_w = img.shape[:2]

        boxes = self._read_labels(label_file)

        # Preprocess: square resize, then scale pixel values to [0, 1].
        img = cv2.resize(img, (self.img_size, self.img_size))
        img = img.astype(np.float32) / 255.0

        target = self._create_target_grid(boxes, orig_w, orig_h)

        if self.transform is not None:
            img = self.transform(img)
        if isinstance(img, torch.Tensor):
            # Tensor-producing transforms such as ToTensor already return CHW;
            # permuting again would scramble the axes.
            image = img
        else:
            image = torch.as_tensor(np.ascontiguousarray(img)).permute(2, 0, 1)
        return image, target

    @staticmethod
    def _read_labels(label_file):
        """Parse a YOLO label file into ``[class_id, x, y, w, h]`` float rows.

        A missing label file is treated as an image without objects, and
        blank lines are skipped.
        """
        if not os.path.exists(label_file):
            return []
        boxes = []
        with open(label_file, 'r', encoding='utf-8') as f:
            for line in f:
                fields = line.split()
                if not fields:
                    continue
                class_id, x_center, y_center, width, height = map(float, fields)
                boxes.append([class_id, x_center, y_center, width, height])
        return boxes

    def _create_target_grid(self, boxes, orig_w, orig_h):
        """Build the grid target tensor of shape ``[S, S, C + 5]``.

        Per-cell channel layout: ``[class one-hot (C), x_offset, y_offset,
        w, h, objectness]``.

        - x/y offsets are the box centre relative to its owning cell, in [0, 1].
        - w/h are the normalized box size multiplied by S (i.e. in cell units).
        - If several boxes fall into the same cell, the last one wins.

        ``orig_w``/``orig_h`` are unused because the labels are already
        normalized; they are kept for signature compatibility.
        """
        S = self.grid_size
        C = self.num_classes
        target = torch.zeros((S, S, C + 5))

        for class_id, x_center, y_center, width, height in boxes:
            # Clamp so a centre lying exactly on the right/bottom border
            # (== 1.0) maps to the last cell instead of indexing out of range.
            grid_x = min(max(int(x_center * S), 0), S - 1)
            grid_y = min(max(int(y_center * S), 0), S - 1)

            # Centre position inside the owning cell.
            x_offset = x_center * S - grid_x
            y_offset = y_center * S - grid_y

            cell = target[grid_y, grid_x]
            cell[:C] = 0  # overwrite any earlier box's one-hot in this cell
            cell[int(class_id)] = 1
            cell[C:C + 4] = torch.tensor([x_offset, y_offset, width * S, height * S])
            cell[-1] = 1  # objectness: this cell contains an object

        return target
|
|
|
|
|
|
# YOLO model architecture (simplified)
class TinyYOLO(nn.Module):
    """Small YOLO-v1-style detector: conv backbone + fully connected head.

    ``forward`` maps ``[batch, 3, H, W]`` images to raw (un-activated)
    predictions of shape ``[batch, S, S, C + 5]``; the per-cell channel order
    follows the dataset targets ``[classes (C), x, y, w, h, objectness]``.

    Args:
        grid_size: Grid size S; defaults to ``Config.grid_size``.
        num_classes: Number of classes C; defaults to ``Config.num_classes``.
    """

    def __init__(self, grid_size=None, num_classes=None):
        super().__init__()
        self.S = Config.grid_size if grid_size is None else grid_size
        self.C = Config.num_classes if num_classes is None else num_classes

        # Feature-extraction backbone: six Conv-BN-LeakyReLU-MaxPool stages.
        channels = (3, 16, 32, 64, 128, 256, 512)
        layers = []
        for in_ch, out_ch in zip(channels[:-1], channels[1:]):
            layers += [
                nn.Conv2d(in_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.1),
                nn.MaxPool2d(2, 2),
            ]
        # Five stride-2 pools already bring a 224 input down to 7x7 (= S for
        # the default config); a sixth MaxPool would shrink it to 3x3 and no
        # longer match the head. An adaptive pool in the last slot guarantees
        # an S x S map for any input size. It has no parameters, so
        # state_dict keys are unchanged.
        layers[-1] = nn.AdaptiveAvgPool2d((self.S, self.S))
        self.features = nn.Sequential(*layers)

        # Detection head: the flattened 512 x S x S feature map -> all cell predictions.
        self.detector = nn.Sequential(
            nn.Flatten(),
            nn.Linear(512 * self.S * self.S, 1024),
            nn.LeakyReLU(0.1),
            nn.Dropout(0.5),
            nn.Linear(1024, self.S * self.S * (self.C + 5)),
        )

    def forward(self, x):
        """Return raw predictions of shape ``[batch, S, S, C + 5]``."""
        x = self.features(x)
        x = self.detector(x)
        return x.view(-1, self.S, self.S, self.C + 5)
|
|
|
|
|
|
# Training for one epoch
def train(model, dataloader, criterion, optimizer, epoch, device):
    """Run one training epoch over ``dataloader`` and return its average loss.

    Each batch's loss is weighted by the batch size and the total is divided by
    ``len(dataloader.dataset)``; that yields a per-sample mean when
    ``criterion`` returns a batch-averaged loss (assumed, depends on caller).

    Args:
        model: Network to train (switched to train mode).
        dataloader: Yields ``(images, targets)`` batches.
        criterion: Callable ``criterion(outputs, targets) -> scalar tensor``.
        optimizer: Optimizer stepping ``model``'s parameters.
        epoch: Zero-based epoch index (only used for the progress label).
        device: Device the batches are moved to.
    """
    model.train()
    weighted_total = 0.0
    bar = tqdm(dataloader, desc=f'Epoch {epoch+1}/{Config.epochs}', leave=False)

    for batch_images, batch_targets in bar:
        batch_images = batch_images.to(device).float()
        batch_targets = batch_targets.to(device).float()

        optimizer.zero_grad()
        batch_loss = criterion(model(batch_images), batch_targets)
        batch_loss.backward()
        optimizer.step()

        loss_value = batch_loss.item()
        weighted_total += loss_value * batch_images.size(0)
        bar.set_postfix(loss=loss_value)

    return weighted_total / len(dataloader.dataset)
|
|
|
|
# Checkpoint saving
def save_checkpoint(model, optimizer, epoch, loss, path):
    """Write the training state to ``path`` with ``torch.save``.

    The saved dict has the keys ``'epoch'``, ``'model_state_dict'``,
    ``'optimizer_state_dict'`` and ``'loss'`` (in that order).
    """
    state = dict(
        epoch=epoch,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        loss=loss,
    )
    torch.save(state, path)
|
|
|
|
# Visualize predictions
def visualize_prediction(image, prediction, threshold=0.5):
    """Draw the boxes predicted for a single image and show the figure.

    Args:
        image: CHW image tensor whose values ``imshow`` can display
            (e.g. float in [0, 1]).
        prediction: Raw model output for this image, either batched
            ``[1, S, S, C + 5]`` (as returned by the model) or ``[S, S, C + 5]``.
        threshold: Minimum objectness *probability* (after sigmoid) for a cell
            to be drawn.

    Decoding follows how YOLOLoss treats the channels:
    - objectness and class scores are logits (trained with BCE-with-logits);
    - x/y are cell-relative offsets passed through a sigmoid;
    - w/h are box sizes in grid-cell units, compared by magnitude (``|w|``).
    """
    S = Config.grid_size
    C = Config.num_classes
    img_h, img_w = image.shape[1], image.shape[2]
    cell_w = img_w / S
    cell_h = img_h / S

    pred = prediction[0] if prediction.dim() == 4 else prediction
    pred = pred.detach().cpu()

    fig, ax = plt.subplots(1)
    ax.imshow(image.permute(1, 2, 0).cpu().numpy())

    for i in range(S):  # grid row (y)
        for j in range(S):  # grid column (x)
            cell = pred[i, j]
            # The raw objectness value is a logit; compare a probability
            # against the threshold.
            confidence = torch.sigmoid(cell[-1]).item()
            if confidence < threshold:
                continue

            # Box centre in pixels (cell index + in-cell offset) and size in pixels.
            x_off, y_off = torch.sigmoid(cell[C:C + 2]).tolist()
            w_cells, h_cells = cell[C + 2:C + 4].abs().tolist()
            x = (j + x_off) * cell_w
            y = (i + y_off) * cell_h
            w = w_cells * cell_w
            h = h_cells * cell_h

            # Predicted class: argmax over the class logits.
            class_id = torch.argmax(cell[:C]).item()
            class_name = Config.class_names[class_id]

            # Draw the box (matplotlib anchors rectangles at the top-left corner).
            ax.add_patch(patches.Rectangle(
                (x - w / 2, y - h / 2), w, h,
                linewidth=2, edgecolor='r', facecolor='none',
            ))
            ax.text(x - w / 2, y - h / 2, f'{class_name} {confidence:.2f}',
                    color='white', fontsize=8, bbox=dict(facecolor='red', alpha=0.5))

    plt.show()
|
|
|
|
|
|
# Custom loss function
class YOLOLoss(nn.Module):
    """YOLO-v1-style loss over ``[batch, S, S, C + 5]`` predictions and targets.

    The per-cell layout matches the dataset targets:
    ``[class one-hot (C), x, y, w, h, objectness]``.

    Terms (each summed over cells, then the total is divided by batch size):
    - objectness: BCE-with-logits on every cell; empty cells are weighted by
      ``lambda_noobj``;
    - box centre: MSE between ``sigmoid(pred x, y)`` and the target offsets,
      on object cells only, weighted by ``lambda_coord``;
    - box size: MSE between ``sqrt(|w|), sqrt(|h|)`` of prediction and target,
      on object cells only, weighted by ``lambda_coord``;
    - class: BCE-with-logits on the class channels, object cells only.

    Args:
        num_classes: Number of classes C; defaults to ``Config.num_classes``.
    """

    def __init__(self, num_classes=None):
        super().__init__()
        self.num_classes = Config.num_classes if num_classes is None else num_classes
        self.mse = nn.MSELoss(reduction='sum')
        self.bce = nn.BCEWithLogitsLoss(reduction='sum')
        self.lambda_coord = 5
        self.lambda_noobj = 0.5
        # The derivative of sqrt is infinite at 0, which turns a zero width
        # or height into NaN gradients; a small offset keeps it finite.
        self.eps = 1e-6

    def forward(self, preds, targets):
        """Return the batch-averaged scalar loss."""
        C = self.num_classes

        # Cells that do / do not contain an object.
        obj_mask = targets[..., -1] == 1
        noobj_mask = targets[..., -1] == 0

        # Objectness loss
        obj_loss = self.bce(preds[..., -1][obj_mask], targets[..., -1][obj_mask])
        noobj_loss = self.bce(preds[..., -1][noobj_mask], targets[..., -1][noobj_mask])
        confidence_loss = obj_loss + self.lambda_noobj * noobj_loss

        # Localization loss (object cells only). Box channels come after the
        # C class channels: x, y at [C, C+2), w, h at [C+2, C+4).
        obj_preds = preds[obj_mask]
        obj_targets = targets[obj_mask]

        center_loss = self.mse(torch.sigmoid(obj_preds[..., C:C + 2]),
                               obj_targets[..., C:C + 2])
        wh_loss = self.mse(
            torch.sqrt(torch.abs(obj_preds[..., C + 2:C + 4]) + self.eps),
            torch.sqrt(torch.abs(obj_targets[..., C + 2:C + 4]) + self.eps),
        )
        coord_loss = center_loss + wh_loss

        # Classification loss (object cells only)
        class_loss = self.bce(obj_preds[..., :C], obj_targets[..., :C])

        total_loss = (
            self.lambda_coord * coord_loss
            + confidence_loss
            + class_loss
        ) / preds.shape[0]

        return total_loss
|
|
|
|
# Entry point
def main():
    """Train TinyYOLO on the dataset described by ``Config``.

    Saves a checkpoint whenever the epoch loss improves and a final checkpoint
    at the end, then visualizes the model's prediction for the first sample.

    Raises:
        ValueError: if the image directory contains no images.
    """
    config = Config()

    # Data
    transform = transforms.Compose([
        transforms.ToTensor(),
    ])

    dataset = YOLODataset(
        img_dir=config.image_dir,
        label_dir=config.label_dir,
        img_size=config.input_size,
        transform=transform,
    )
    if len(dataset) == 0:
        # A shuffling DataLoader would otherwise fail with an obscure
        # sampler error.
        raise ValueError(f'No images found in {config.image_dir}')

    dataloader = DataLoader(
        dataset, batch_size=config.batch_size, shuffle=True, num_workers=4
    )

    # Model, loss function and optimizer
    model = TinyYOLO().to(config.device)
    criterion = YOLOLoss()
    optimizer = optim.Adam(model.parameters(), lr=config.lr)

    # Training loop
    best_loss = float('inf')
    # Pre-bound so the final save below still works when config.epochs == 0.
    train_loss = float('nan')
    for epoch in range(config.epochs):
        train_loss = train(model, dataloader, criterion, optimizer, epoch, config.device)
        print(f'Epoch [{epoch+1}/{config.epochs}] Loss: {train_loss:.4f}')

        # Save a checkpoint each time the epoch loss improves.
        if train_loss < best_loss:
            best_loss = train_loss
            checkpoint_path = os.path.join(
                config.checkpoint_dir, f'best_model_epoch{epoch+1}_loss{train_loss:.4f}.pth'
            )
            save_checkpoint(model, optimizer, epoch, train_loss, checkpoint_path)

    print('训练完成!保存最终模型...')
    save_checkpoint(model, optimizer, config.epochs, train_loss,
                    os.path.join(config.checkpoint_dir, 'final_model.pth'))

    # Visualize a prediction
    model.eval()
    sample_image, _ = dataset[0]
    with torch.no_grad():
        prediction = model(sample_image.unsqueeze(0).to(config.device))
    visualize_prediction(sample_image, prediction.cpu())
|
|
|
|
if __name__ == "__main__":
    # Start training only when run as a script, not when imported.
    main()