first commit
This commit is contained in:
commit
b74f83079f
17
001FCN.py
Normal file
17
001FCN.py
Normal file
@ -0,0 +1,17 @@
|
||||
import torch.nn as nn
|
||||
|
||||
class FisrFCN(nn.Module):
    """Small fully connected network: in_c -> 128 -> 64 -> out_c.

    ReLU follows the first two linear layers; the output of the last
    linear layer is returned without an activation.
    """

    def __init__(self, in_c, out_c) -> None:
        super().__init__()

        # Attribute names double as state_dict keys, and the creation order
        # fixes which random numbers initialise which layer.
        self.fc1 = nn.Linear(in_c, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, out_c)

        self.relu = nn.ReLU()

    def forward(self, x):
        """Apply the two hidden layers with ReLU, then the output layer."""
        hidden = x
        for layer in (self.fc1, self.fc2):
            hidden = self.relu(layer(hidden))
        return self.fc3(hidden)
|
||||
80
002自定义数据集.py
Normal file
80
002自定义数据集.py
Normal file
@ -0,0 +1,80 @@
|
||||
from itertools import cycle
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
# from PIL import Image
|
||||
import numpy as np
|
||||
import cv2
|
||||
import os
|
||||
|
||||
|
||||
|
||||
class CustomImageDataset(Dataset):
    """Images paired with YOLO-format label files.

    ``root_dir`` must contain an ``images`` directory and a ``lables``
    directory. Files are paired by position after sorting both directory
    listings by name, so image and label files should share base names.

    Each item is ``(image, label)``:

    * ``image`` is the RGB image read with OpenCV, passed through
      ``transform`` when one was given.
    * ``label`` is a float array of shape ``(num_boxes, 5)`` whose rows are
      ``[class_id, x_min, y_min, x_max, y_max]`` in pixel coordinates of the
      image as read from disk, clipped to the image bounds.
    """

    def __init__(self, root_dir, transform=None) -> None:
        super().__init__()

        self.root_dir = root_dir
        self.transform = transform

        # os.path.join gives the same result as string concatenation when
        # root_dir is empty or already ends with a separator, and also works
        # when it does not.
        self.image_path = os.path.join(self.root_dir, "images")
        # NOTE(review): "lables" is the directory name this class reads;
        # presumably the on-disk folder is spelled this way - confirm before
        # renaming.
        self.label_path = os.path.join(self.root_dir, "lables")

        # os.listdir returns entries in arbitrary order; sort both listings
        # so that index i refers to the matching image/label pair.
        self.all_images = [
            os.path.join(self.image_path, name)
            for name in sorted(os.listdir(self.image_path))
        ]
        # Label files live in label_path (not image_path).
        self.all_lables = [
            os.path.join(self.label_path, name)
            for name in sorted(os.listdir(self.label_path))
        ]

    # __len__ is required by the Dataset protocol.
    def __len__(self):
        return len(self.all_images)

    # Load one image and its boxes.
    def __getitem__(self, index):
        image_path = self.all_images[index]
        lable_path = self.all_lables[index]

        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise OSError(f"cannot read image: {image_path}")
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        image_h, image_w = image.shape[:2]

        boxes = []
        with open(lable_path, "r", encoding="utf-8") as f:
            for line in f:
                fields = line.split()
                if not fields:
                    # Tolerate blank lines (e.g. a trailing newline).
                    continue

                class_id = int(fields[0])
                cx, cy, w, h = map(float, fields[1:])

                # Normalised centre/size -> absolute pixels.
                x_center = cx * image_w
                y_center = cy * image_h
                box_w = w * image_w
                box_h = h * image_h

                # Corner coordinates, clipped to the image.
                x_min = max(0, x_center - box_w / 2)
                y_min = max(0, y_center - box_h / 2)
                x_max = min(image_w, x_center + box_w / 2)
                y_max = min(image_h, y_center + box_h / 2)

                boxes.append([class_id, x_min, y_min, x_max, y_max])

        # reshape keeps a consistent (0, 5) shape for images without boxes.
        label = np.array(boxes, dtype=np.float64).reshape(-1, 5)

        if self.transform is not None:
            image = self.transform(image)

        return image, label
|
||||
|
||||
if __name__ == "__main__":
    # Demo wiring: build a train and a test loader over the same root.
    loader_options = dict(batch_size=32, shuffle=True, num_workers=2)

    train_dataset = CustomImageDataset("")
    test_dataset = CustomImageDataset("")

    # pin_memory speeds up host-to-GPU transfers.
    train_loader = DataLoader(train_dataset, pin_memory=True, **loader_options)

    test_loader = DataLoader(test_dataset, **loader_options)
|
||||
43
003训练模型预测模型.py
Normal file
43
003训练模型预测模型.py
Normal file
@ -0,0 +1,43 @@
|
||||
import torch
|
||||
|
||||
|
||||
|
||||
def train_model(model, dataloader, criterion, optimizer, device, num_epochs=10):
    """Train ``model`` for ``num_epochs`` passes over ``dataloader``.

    Args:
        model: module to optimise; put into training mode before the loop.
        dataloader: iterable yielding ``(images, labels)`` batches whose
            elements support ``.to(device)``.
        criterion: callable ``criterion(outputs, labels)`` returning a
            scalar loss tensor.
        optimizer: optimizer holding ``model``'s parameters.
        device: device the batches are moved to.
        num_epochs: number of passes over ``dataloader``.

    Returns:
        list[float]: mean per-batch loss of every epoch, in order.

    Raises:
        ValueError: if ``dataloader`` yields no batches.
    """
    model.train()
    train_losses = []

    for epoch in range(num_epochs):
        running_loss = 0.0
        num_batches = 0
        for images, labels in dataloader:
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward pass and parameter update.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Bookkeeping.
            running_loss += loss.item()
            num_batches += 1

        if num_batches == 0:
            # Fail with a clear message instead of a bare ZeroDivisionError.
            raise ValueError("dataloader produced no batches")

        epoch_loss = running_loss / num_batches
        train_losses.append(epoch_loss)

        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")

    return train_losses
|
||||
|
||||
|
||||
def predict(model, test_loader, device):
    """Run ``model`` over ``test_loader`` without tracking gradients.

    Args:
        model: module to evaluate; switched to eval mode.
        test_loader: iterable yielding ``(images, labels)`` batches; only
            the images are used.
        device: device the images are moved to before the forward pass.

    Returns:
        list: the model output for each batch, in loader order.
    """
    model.eval()

    predictions = []
    with torch.no_grad():
        for images, _ in test_loader:
            images = images.to(device)
            predictions.append(model(images))

    return predictions
|
||||
|
||||
328
yolo_common.py
Normal file
328
yolo_common.py
Normal file
@ -0,0 +1,328 @@
|
||||
import os
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.utils.data import Dataset, DataLoader
|
||||
from torchvision import transforms
|
||||
import cv2
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.patches as patches
|
||||
from tqdm import tqdm
|
||||
|
||||
# 配置参数
|
||||
class Config:
    """Settings shared by the dataset, model and training loop.

    Evaluating this class body creates ``checkpoint_dir`` on disk if it
    does not exist yet.
    """

    # ---- Dataset ----
    data_dir: str = 'path/to/your/dataset'  # replace with your dataset path
    image_dir: str = os.path.join(data_dir, 'images')
    label_dir: str = os.path.join(data_dir, 'labels')
    class_names: list = [  # replace with your class names
        'class1',
        'class2',
        'class3',
    ]
    num_classes: int = len(class_names)

    # ---- Model ----
    grid_size: int = 7     # feature-map grid size
    input_size: int = 224  # input image size

    # ---- Training ----
    batch_size: int = 32
    epochs: int = 50
    lr: float = 0.001
    device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
    checkpoint_dir: str = 'checkpoints'
    os.makedirs(checkpoint_dir, exist_ok=True)
|
||||
|
||||
class YOLODataset(Dataset):
    """Images with YOLO-format labels, encoded as an S x S target grid.

    Args:
        img_dir: directory containing the image files.
        label_dir: directory containing one label text file per image; each
            line holds ``class_id x_center y_center width height``.
        img_size: images are resized to ``img_size x img_size``.
        transform: optional callable applied to the resized float32 image
            (HWC, values divided by 255).
        grid_size: grid resolution S; defaults to ``Config.grid_size``.
        num_classes: number of classes C; defaults to ``Config.num_classes``.

    Files are paired by position after sorting both directory listings by
    name, so image and label files should share base names.
    """

    def __init__(self, img_dir, label_dir, img_size=224, transform=None,
                 grid_size=None, num_classes=None) -> None:
        self.img_dir = img_dir
        self.label_dir = label_dir
        self.img_size = img_size
        self.transform = transform
        self.grid_size = Config.grid_size if grid_size is None else grid_size
        self.num_classes = Config.num_classes if num_classes is None else num_classes
        # os.listdir order is arbitrary; sort so index i pairs an image with
        # its own label file.
        self.img_files = [
            os.path.join(self.img_dir, name)
            for name in sorted(os.listdir(self.img_dir))
        ]
        self.label_files = [
            os.path.join(self.label_dir, name)
            for name in sorted(os.listdir(self.label_dir))
        ]

    def __len__(self):
        return len(self.img_files)

    def __getitem__(self, index):
        """Return ``(image, target)`` for the sample at ``index``.

        ``image`` is a channel-first tensor; ``target`` has shape
        ``[S, S, C + 5]``.
        """
        img_file = self.img_files[index]
        lable_file = self.label_files[index]

        img = cv2.imread(img_file)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising.
            raise OSError(f"cannot read image: {img_file}")
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        orig_h, orig_w = img.shape[:2]

        label = []
        with open(lable_file, 'r', encoding="utf-8") as f:
            for line in f:
                fields = line.split()
                if not fields:
                    # Tolerate blank lines (e.g. a trailing newline).
                    continue
                class_id, x_center, y_center, width, height = map(float, fields)
                label.append([class_id, x_center, y_center, width, height])

        # Image preprocessing.
        img = cv2.resize(img, (self.img_size, self.img_size))
        img = img.astype(np.float32) / 255.0

        # Encode the boxes on the grid.
        target = self._create_target_grid(label, orig_w, orig_h)

        if self.transform:
            img = self.transform(img)

        if isinstance(img, torch.Tensor):
            # A tensor-producing transform such as torchvision's ToTensor has
            # already moved channels first; permuting again would scramble
            # the axes. Assumes tensor transforms return CHW - TODO confirm
            # for custom transforms.
            img_tensor = img
        else:
            img_tensor = torch.as_tensor(img).permute(2, 0, 1)

        return img_tensor, target

    def _create_target_grid(self, boxes, orig_w, orig_h):
        """Build the grid target tensor of shape ``[S, S, C + 5]``.

        Each cell holding a box centre stores
        ``[class one-hot (C), x_offset, y_offset, width, height, 1]`` where
        the offsets are relative to the cell and width/height are expressed
        in cell units. ``orig_w`` and ``orig_h`` are accepted but not used,
        because the box coordinates are already normalised.
        """
        S = self.grid_size
        C = self.num_classes
        target = torch.zeros((S, S, C + 5))

        for box in boxes:
            class_id, x_center, y_center, width, height = box

            # Cell containing the centre. A coordinate of exactly 1.0 would
            # index one past the last cell, so clamp into [0, S - 1].
            grid_x = min(S - 1, max(0, int(x_center * S)))
            grid_y = min(S - 1, max(0, int(y_center * S)))

            # Centre position relative to the cell's top-left corner.
            x_offset = x_center * S - grid_x
            y_offset = y_center * S - grid_y

            # Width/height in units of grid cells.
            width = width * S
            height = height * S

            # Cell layout: [class_one_hot, x, y, w, h, confidence]
            class_one_hot = torch.zeros(C)
            class_one_hot[int(class_id)] = 1

            target[grid_y, grid_x, :C] = class_one_hot
            target[grid_y, grid_x, C:C + 4] = torch.tensor(
                [x_offset, y_offset, width, height]
            )
            target[grid_y, grid_x, -1] = 1  # object present

        return target
|
||||
|
||||
|
||||
# YOLO模型架构 (简化版)
|
||||
class TinyYOLO(nn.Module):
    """Simplified YOLO-style detector.

    Six conv/batch-norm/LeakyReLU/max-pool stages feed a two-layer fully
    connected head whose output is reshaped to ``[batch, S, S, C + 5]``.

    Args:
        grid_size: output grid resolution S; defaults to ``Config.grid_size``.
        num_classes: number of classes C; defaults to ``Config.num_classes``.
        input_size: side length of the square input images; defaults to
            ``Config.input_size``.

    Raises:
        ValueError: if ``input_size`` is too small to survive the six
            pooling stages.
    """

    def __init__(self, grid_size=None, num_classes=None, input_size=None):
        super().__init__()
        self.S = Config.grid_size if grid_size is None else grid_size
        self.C = Config.num_classes if num_classes is None else num_classes
        input_size = Config.input_size if input_size is None else input_size

        # Feature-extraction backbone. The layers are flattened into one
        # Sequential so state_dict keys stay "features.<index>.*".
        channels = (3, 16, 32, 64, 128, 256, 512)
        layers = []
        for in_ch, out_ch in zip(channels[:-1], channels[1:]):
            layers += [
                nn.Conv2d(in_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.1),
                nn.MaxPool2d(2, 2),
            ]
        self.features = nn.Sequential(*layers)

        # Each of the six 2x2 pools floor-halves the spatial size, so the
        # feature map side is input_size // 64. It depends on the image
        # size, not on the grid size.
        feature_side = input_size // 64
        if feature_side < 1:
            raise ValueError(
                f"input_size must be at least 64, got {input_size}"
            )

        # Detection head.
        self.detector = nn.Sequential(
            nn.Flatten(),
            nn.Linear(512 * feature_side ** 2, 1024),
            nn.LeakyReLU(0.1),
            nn.Dropout(0.5),
            nn.Linear(1024, self.S * self.S * (self.C + 5)),
        )

    def forward(self, x):
        """Return raw predictions of shape ``[batch, S, S, C + 5]``."""
        x = self.features(x)
        x = self.detector(x)
        x = x.view(-1, self.S, self.S, self.C + 5)
        return x
|
||||
|
||||
|
||||
# 训练函数
|
||||
def train(model, dataloader, criterion, optimizer, epoch, device):
    """Run one training epoch and return its average loss.

    Every batch is moved to ``device`` and cast to float. The returned
    value is the sum of ``batch loss * batch size`` divided by
    ``len(dataloader.dataset)``. A tqdm progress bar labelled with the
    epoch number shows the latest batch loss.
    """
    model.train()
    weighted_loss_sum = 0.0
    batches = tqdm(dataloader, desc=f'Epoch {epoch+1}/{Config.epochs}', leave=False)

    for images, targets in batches:
        images, targets = images.to(device).float(), targets.to(device).float()

        optimizer.zero_grad()
        loss = criterion(model(images), targets)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        weighted_loss_sum += batch_loss * images.size(0)
        batches.set_postfix(loss=batch_loss)

    return weighted_loss_sum / len(dataloader.dataset)
|
||||
|
||||
# 保存检查点
|
||||
def save_checkpoint(model, optimizer, epoch, loss, path):
    """Write the training state to ``path`` with ``torch.save``.

    The saved dict has the keys ``epoch``, ``model_state_dict``,
    ``optimizer_state_dict`` and ``loss``.
    """
    checkpoint = dict(
        epoch=epoch,
        model_state_dict=model.state_dict(),
        optimizer_state_dict=optimizer.state_dict(),
        loss=loss,
    )
    torch.save(checkpoint, path)
|
||||
|
||||
# 可视化预测结果
|
||||
def visualize_prediction(image, prediction, threshold=0.5):
    """Draw the boxes predicted for a single image.

    Args:
        image: channel-first image tensor; ``image.shape[1]`` is used as the
            image size in pixels.
        prediction: raw model output; only ``prediction[0]`` is drawn, and
            it is indexed as ``[row, col, C + 5]`` with the cell layout
            ``[classes (C), x, y, w, h, confidence]``.
        threshold: minimum confidence (a probability) for a cell's box to
            be drawn.

    The raw values are decoded the way ``YOLOLoss`` encodes them: confidence
    is a logit (trained with BCE-with-logits), centre offsets are trained
    through a sigmoid, and sizes through their magnitude.
    """
    S = Config.grid_size
    C = Config.num_classes
    img_size = image.shape[1]
    cell_px = img_size / S  # pixels per grid cell
    fig, ax = plt.subplots(1)
    ax.imshow(image.permute(1, 2, 0).cpu().numpy())

    cells = prediction[0]
    for i in range(S):
        for j in range(S):
            cell = cells[i, j]

            # Confidence is a logit; convert it before comparing against a
            # probability threshold.
            confidence = torch.sigmoid(cell[-1]).item()
            if confidence < threshold:
                continue

            # Box parameters: offsets within the cell, sizes in cell units.
            x_off, y_off = torch.sigmoid(cell[C:C + 2]).tolist()
            w_cells, h_cells = torch.abs(cell[C + 2:C + 4]).tolist()
            x = (j + x_off) * cell_px
            y = (i + y_off) * cell_px
            w = w_cells * cell_px
            h = h_cells * cell_px

            # Most likely class.
            class_id = torch.argmax(cell[:C]).item()
            class_name = Config.class_names[class_id]

            # Draw the bounding box and its caption.
            rect = patches.Rectangle(
                (x - w / 2, y - h / 2), w, h,
                linewidth=2, edgecolor='r', facecolor='none'
            )
            ax.add_patch(rect)
            ax.text(x - w / 2, y - h / 2, f'{class_name} {confidence:.2f}',
                    color='white', fontsize=8, bbox=dict(facecolor='red', alpha=0.5))

    plt.show()
|
||||
|
||||
|
||||
# 自定义损失函数
|
||||
class YOLOLoss(nn.Module):
    """YOLO-style loss over a ``[batch, S, S, C + 5]`` grid.

    The last axis is laid out as ``[classes (C), x, y, w, h, confidence]``.

    Args:
        num_classes: number of classes C; defaults to ``Config.num_classes``.
    """

    def __init__(self, num_classes=None):
        super().__init__()
        self.num_classes = Config.num_classes if num_classes is None else num_classes
        self.mse = nn.MSELoss(reduction='sum')
        self.bce = nn.BCEWithLogitsLoss(reduction='sum')
        self.lambda_coord = 5
        self.lambda_noobj = 0.5
        # Keeps the gradient of sqrt finite when a predicted size is 0.
        self.eps = 1e-6

    def forward(self, preds, targets):
        """Return the total loss divided by the batch size.

        Args:
            preds: raw model output, same shape as ``targets``.
            targets: grid targets whose last channel is 1 for cells that
                contain an object and 0 otherwise.
        """
        C = self.num_classes

        # Cells with / without an object.
        target_conf = targets[..., -1]
        pred_conf = preds[..., -1]
        obj_mask = target_conf == 1
        noobj_mask = target_conf == 0

        # Confidence loss (predictions are logits).
        obj_loss = self.bce(pred_conf[obj_mask], target_conf[obj_mask])
        noobj_loss = self.bce(pred_conf[noobj_mask], target_conf[noobj_mask])
        confidence_loss = obj_loss + self.lambda_noobj * noobj_loss

        # Localisation and classification only apply to object cells.
        obj_preds = preds[obj_mask]
        obj_targets = targets[obj_mask]

        # Box centre loss: x, y sit right after the C class channels.
        center_loss = self.mse(
            torch.sigmoid(obj_preds[..., C:C + 2]),
            obj_targets[..., C:C + 2],
        )

        # Box size loss on square roots: w, h follow x, y.
        wh_loss = self.mse(
            torch.sqrt(torch.abs(obj_preds[..., C + 2:C + 4]) + self.eps),
            torch.sqrt(torch.abs(obj_targets[..., C + 2:C + 4]) + self.eps),
        )

        coord_loss = center_loss + wh_loss

        # Classification loss over the first C channels.
        class_loss = self.bce(obj_preds[..., :C], obj_targets[..., :C])

        total_loss = (
            self.lambda_coord * coord_loss +
            confidence_loss +
            class_loss
        ) / preds.shape[0]

        return total_loss
|
||||
|
||||
# 主函数
|
||||
def main():
    """Train TinyYOLO on the configured dataset, then show one prediction.

    A checkpoint is written whenever the epoch loss improves on the best
    loss so far, and once more as ``final_model.pth`` after the last epoch.
    """
    config = Config()

    # Data
    dataset = YOLODataset(
        img_dir=config.image_dir,
        label_dir=config.label_dir,
        img_size=config.input_size,
        transform=transforms.Compose([transforms.ToTensor()]),
    )
    dataloader = DataLoader(
        dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=4,
    )

    # Model, loss and optimizer
    model = TinyYOLO().to(config.device)
    criterion = YOLOLoss()
    optimizer = optim.Adam(model.parameters(), lr=config.lr)

    # Training loop
    best_loss = float('inf')
    for epoch in range(config.epochs):
        train_loss = train(model, dataloader, criterion, optimizer, epoch, config.device)
        print(f'Epoch [{epoch+1}/{config.epochs}] Loss: {train_loss:.4f}')

        # Keep a checkpoint of every new best epoch.
        if train_loss < best_loss:
            best_loss = train_loss
            best_name = f'best_model_epoch{epoch+1}_loss{train_loss:.4f}.pth'
            save_checkpoint(
                model, optimizer, epoch, train_loss,
                os.path.join(config.checkpoint_dir, best_name),
            )

    print('训练完成!保存最终模型...')
    final_path = os.path.join(config.checkpoint_dir, 'final_model.pth')
    save_checkpoint(model, optimizer, config.epochs, train_loss, final_path)

    # Visualise the prediction for the first sample.
    model.eval()
    sample_image, sample_target = dataset[0]
    batch = sample_image.unsqueeze(0).to(config.device)
    with torch.no_grad():
        prediction = model(batch)
    visualize_prediction(sample_image, prediction.cpu())
|
||||
|
||||
# Run the training pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()
|
||||
Loading…
Reference in New Issue
Block a user