Week0_训练框架

SyEic_L Lv4

最小PyTorch训练项目

  • dataset:
    • train_dataset
    • validation_dataset
  • dataloader
    • 如果是图片之类,需要先转成tensor,transform = transforms.ToTensor()
    • images.shape = [batch, channel, height, width](shape是属性,不是方法)
    • labels.shape = [batch],每个元素是类别索引(CrossEntropyLoss要求的target格式)
  • model
    • 继承自nn.Module
    • __init__中需要先调用super().__init__(),再定义不同的层;nn.Flatten()默认保留batch维,把每个样本展平成一维向量,即[batch, C*H*W]
    • 定义forward(self, x),最后记得return计算得到的输出
    • 通过model.train()和model.eval()切换训练/评估模式(会影响Dropout、BatchNorm等层的行为)
    • 调用model(images)时,nn.Module的__call__会执行model.forward(images)(以及注册的hooks),因此应直接调用model(...)而不是forward;返回值是未经softmax的logits
  • loss:criterion
    • 分类任务用CrossEntropyLoss等;它的输入是logits,内部已包含log_softmax,模型里不需要再加softmax
  • optimizer
    • 根据loss.backward()算出的梯度更新模型参数,例如Adam、SGD等
  • train loop
    • optimizer.zero_grad():清除上一batch的梯度
    • loss.backward():反向传播计算梯度
    • optimizer.step():根据梯度和lr更新参数
  • eval loop
    • 只需要计算loss、acc等,不需要更新模型参数;配合model.eval()和torch.no_grad()使用,省去梯度计算
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.optim as optim
import torch
from pathlib import Path
import matplotlib.pyplot as plt

def get_dataloader(root="./datasets", batch_size=128, download=True):
    """Build the FashionMNIST train and test DataLoaders.

    Args:
        root: directory passed to torchvision as the dataset root.
        batch_size: number of samples per batch, shared by both loaders.
        download: forwarded to ``datasets.FashionMNIST``; when True the data
            is fetched into ``root`` if it is not already present.

    Returns:
        A ``(train_loader, test_loader)`` tuple. The train loader reshuffles
        every epoch; the test loader iterates in a fixed order.
    """
    # Images must be converted to tensors before they can be batched.
    transform = transforms.ToTensor()

    def _make_loader(train, shuffle):
        """Create the FashionMNIST split selected by ``train`` and wrap it in a DataLoader."""
        dataset = datasets.FashionMNIST(
            root=root,
            train=train,
            download=download,
            transform=transform,
        )
        return DataLoader(dataset, batch_size=batch_size, shuffle=shuffle)

    train_loader = _make_loader(train=True, shuffle=True)
    test_loader = _make_loader(train=False, shuffle=False)
    return train_loader, test_loader

class MLP(nn.Module):
    """Two-layer perceptron: flatten -> Linear -> ReLU -> Linear, producing logits.

    The defaults (28*28 inputs, 10 outputs) fit 28x28 single-channel images
    with 10 classes, which is what this script trains on.

    Args:
        in_features: number of values per sample after flattening.
        hidden_features: width of the hidden layer.
        num_classes: number of output logits per sample.
    """

    def __init__(self, in_features=28 * 28, hidden_features=256, num_classes=10):
        super().__init__()
        # Attribute names are kept as-is because they become the state_dict
        # keys ("linear1.weight", ...) stored in saved checkpoints.
        self.flatten = nn.Flatten()
        self.linear1 = nn.Linear(in_features, hidden_features)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(hidden_features, num_classes)

    def forward(self, x):
        """Return logits of shape ``(batch, num_classes)`` for the batch ``x``.

        ``x`` is flattened from dim 1 onward (nn.Flatten default), so each
        sample must contain exactly ``in_features`` values.
        """
        x = self.flatten(x)
        x = self.relu(self.linear1(x))
        return self.linear2(x)


def _run_epoch(model, loader, criterion, device, optimizer=None):
    """Make one pass over ``loader``: train when ``optimizer`` is given, otherwise evaluate.

    Returns:
        ``(mean_loss, accuracy)`` averaged over every sample yielded by ``loader``.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    training = optimizer is not None
    # train(True) == model.train(), train(False) == model.eval().
    model.train(training)

    total_loss = 0.0
    correct = 0
    total = 0
    # Only track gradients when the weights are going to be updated.
    with torch.set_grad_enabled(training):
        for images, labels in loader:
            images = images.to(device)
            labels = labels.to(device)

            logits = model(images)
            loss = criterion(logits, labels)

            correct += (logits.argmax(dim=1) == labels).sum().item()
            batch_size = labels.size(0)
            total += batch_size
            # Weight by batch size so the epoch mean is per-sample even when the
            # last batch is smaller (assumes criterion averages over the batch).
            total_loss += loss.item() * batch_size

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

    if total == 0:
        raise ValueError("data loader yielded no samples")
    return total_loss / total, correct / total


def _plot_losses(train_losses, val_losses, out_path):
    """Save the per-epoch train/validation loss curves to ``out_path``."""
    epochs_range = range(1, len(train_losses) + 1)
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(epochs_range, train_losses, label="Train Loss")
    ax.plot(epochs_range, val_losses, label="Val Loss")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.set_title("Train and Val Loss")
    ax.legend()
    fig.savefig(out_path)
    plt.close(fig)


def main(epochs=5, lr=0.001, device=None):
    """Train the MLP, checkpoint the best epoch, and save a loss plot.

    Args:
        epochs: number of passes over the training loader.
        lr: Adam learning rate.
        device: torch device to run on; defaults to CUDA when available, else CPU.

    Side effects:
        Creates ``checkpoints/`` and ``plots/`` under the current directory.
        Writes ``checkpoints/best_model.pth`` whenever validation accuracy
        improves, and ``plots/loss_plot.png`` after training. Tensors are saved
        on ``device``; pass ``map_location`` to ``torch.load`` when loading the
        checkpoint on a different device.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    train_loader, test_loader = get_dataloader()

    model = MLP().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    ckpt_dir = Path("checkpoints")
    plot_dir = Path("plots")
    ckpt_dir.mkdir(exist_ok=True)
    plot_dir.mkdir(exist_ok=True)

    train_losses = []
    val_losses = []
    best_val_acc = 0.0

    for epoch in range(epochs):
        train_loss, train_acc = _run_epoch(
            model, train_loader, criterion, device, optimizer=optimizer
        )
        # NOTE(review): test_loader doubles as the validation set, and the
        # "best" checkpoint is selected on it, so the reported accuracy is
        # optimistically biased. Consider holding out a validation split
        # from the training data instead.
        val_loss, val_acc = _run_epoch(model, test_loader, criterion, device)

        train_losses.append(train_loss)
        val_losses.append(val_loss)

        print(f"Epoch {epoch+1}/{epochs}")
        print(f"Train Loss: {train_loss:.4f}\nTrain Acc: {train_acc:.4f}")
        print(f"Val Loss: {val_loss:.4f}\nVal Acc: {val_acc:.4f}\n")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save({
                "model_state_dict": model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "epoch": epoch + 1,
                "val_acc": val_acc,
            }, ckpt_dir / "best_model.pth")

    _plot_losses(train_losses, val_losses, plot_dir / "loss_plot.png")


if __name__ == "__main__":
    # Run training only when executed as a script, not when imported.
    main()
  • Title: Week0_训练框架
  • Author: SyEic_L
  • Created at : 2026-04-02 17:35:12
  • Updated at : 2026-04-04 09:46:52
  • Link: https://blog.syeicl.vip/2026/04/02/Week0-训练框架/
  • License: This work is licensed under CC BY-NC-SA 4.0.
Comments
On this page
Week0_训练框架