import torch


def train_model(model, dataloader, criterion, optimizer, device, num_epochs=10):
    """Train ``model`` for ``num_epochs`` epochs, printing and returning metrics.

    Args:
        model: The ``torch.nn.Module`` to train; it is switched to training mode.
        dataloader: Iterable yielding ``(images, labels)`` batches. It does not
            need to support ``len()``.
        criterion: Loss function, called as ``criterion(outputs, labels)``.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device that each batch is moved to before the forward pass.
        num_epochs: Number of full passes over ``dataloader``.

    Returns:
        A ``(train_losses, train_accs)`` tuple of per-epoch lists.
        - Each loss is the mean of that epoch's per-batch loss values.
        - Each accuracy is the fraction of samples whose arg-max output equals
          the label.
        - Both are ``nan`` for an epoch in which ``dataloader`` yields no
          batches.
    """
    model.train()
    train_losses, train_accs = [], []
    for epoch in range(num_epochs):
        running_loss = 0.0
        num_batches = 0
        correct = 0
        total = 0
        for images, labels in dataloader:
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backpropagation.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Bookkeeping. The accuracy computation assumes outputs are
            # classification scores of shape (batch, num_classes) and labels
            # are integer class indices -- TODO confirm for all callers.
            running_loss += loss.item()
            num_batches += 1
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)

        # Count batches ourselves instead of using len(dataloader), so that
        # unsized iterables work and an empty loader does not divide by zero.
        epoch_loss = running_loss / num_batches if num_batches else float("nan")
        epoch_acc = correct / total if total else float("nan")
        train_losses.append(epoch_loss)
        train_accs.append(epoch_acc)
        print(
            f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}, "
            f"Acc: {epoch_acc:.4f}"
        )
    return train_losses, train_accs


def predict(model, test_loader, device):
    """Return the arg-max class prediction for every sample in ``test_loader``.

    The model is switched to evaluation mode and left in that mode. Gradients
    are disabled during inference.

    Args:
        model: The ``torch.nn.Module`` to run.
        test_loader: Iterable yielding ``(images, labels)`` batches. The labels
            are ignored.
        device: Device that each image batch is moved to before the forward
            pass.

    Returns:
        A 1-D CPU tensor of predicted class indices, in loader order. It is
        empty (dtype ``torch.long``) when ``test_loader`` yields no batches.
    """
    model.eval()
    batch_preds = []
    with torch.no_grad():
        for images, _labels in test_loader:
            outputs = model(images.to(device))
            # Assumes outputs are (batch, num_classes) scores -- TODO confirm.
            batch_preds.append(outputs.argmax(dim=1).cpu())
    if not batch_preds:
        return torch.empty(0, dtype=torch.long)
    return torch.cat(batch_preds)