43 lines
1.1 KiB
Python
43 lines
1.1 KiB
Python
import torch
|
|
|
|
|
|
|
|
def train_model(model, dataloader, criterion, optimizer, device, num_epochs=10):
    """Train ``model`` for ``num_epochs`` epochs and record per-epoch metrics.

    Args:
        model: The ``torch.nn.Module`` to train; it is switched to training mode.
        dataloader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device each batch is moved to before the forward pass.
        num_epochs: Number of passes over ``dataloader``.

    Returns:
        A ``(train_losses, train_accs)`` tuple of per-epoch lists.
        ``train_losses`` holds the mean per-batch loss for each epoch.
        ``train_accs`` holds the fraction of correctly classified samples.
        An epoch that sees no batches records 0.0 for both.
    """
    model.train()
    train_losses, train_accs = [], []

    for epoch in range(num_epochs):
        running_loss = 0.0
        correct = 0
        total = 0
        # Count batches while iterating so loaders without __len__ also work.
        num_batches = 0

        for images, labels in dataloader:
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backpropagation.
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Accumulate statistics.
            running_loss += loss.item()
            num_batches += 1
            # NOTE(review): assumes outputs are (batch, num_classes) class
            # scores and labels are class indices — confirm against callers.
            preds = outputs.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += labels.size(0)

        # Guard against an empty loader, which would otherwise divide by zero.
        epoch_loss = running_loss / num_batches if num_batches else 0.0
        epoch_acc = correct / total if total else 0.0
        train_losses.append(epoch_loss)
        train_accs.append(epoch_acc)

        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}, Acc: {epoch_acc:.4f}")

    return train_losses, train_accs
|
|
|
|
|
|
def predict(model, test_loader, device):
    """Run ``model`` over ``test_loader`` and return the predicted class indices.

    Args:
        model: The ``torch.nn.Module`` to evaluate; it is switched to eval mode.
        test_loader: Iterable yielding ``(images, labels)`` batches. The
            labels are not used for prediction.
        device: Device the images are moved to before the forward pass.

    Returns:
        A 1-D ``torch.long`` tensor on the CPU, holding the argmax prediction
        for every sample in loader order. It is empty if the loader yields
        no batches.
    """
    model.eval()
    batch_preds = []

    # Disable autograd: no gradients are needed for inference.
    with torch.no_grad():
        for images, _labels in test_loader:
            images = images.to(device)
            outputs = model(images)
            # NOTE(review): assumes outputs are (batch, num_classes) class
            # scores — confirm against the model in use.
            batch_preds.append(outputs.argmax(dim=1).cpu())

    # torch.cat rejects an empty list, so handle the no-batch case explicitly.
    if not batch_preds:
        return torch.empty(0, dtype=torch.long)
    return torch.cat(batch_preds)
|
|
|