前言

在英伟达深度学习的教程中,我们学习了对美国手语数据集的模型训练,我们需要用到保存已经写的模型,然后加载它,进行图像的预测。

正文

首先,要知道2个概念:
保存模型 = 保存权重
加载模型 = 恢复权重

其次,保存并加载模型,有两种方法:
① 保存 model,保存内容有 结构+权重+路径
② 保存 model.state_dict(),保存内容只有 权重

第一种保存加载模型方式:保存model(不推荐)

首先,创建一个save.py文件,用来定义模型和保存模型
我们定义了一个名为 BaiQing类卷积块

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
class BaiQing(nn.Module):
def __init__(self, in_ch, out_ch, dropout_p):
kernel_size = 3
super().__init__()

self.model = nn.Sequential(
nn.Conv2d(in_ch, out_ch, kernel_size, stride=1, padding=1),
nn.BatchNorm2d(out_ch),
nn.ReLU(),
nn.Dropout(dropout_p),
nn.MaxPool2d(2, stride=2)
)

def forward(self, x):
return self.model(x)

要注意的是,该类初始化了3个量,后续会提到
然后,我们定义一个基于 BaiQing卷积块基础模型base_model

1
2
3
4
5
6
7
8
9
10
11
12
13
IMG_CHS = 1     # 图片通道数(Channels)
n_classes = 25 # 分类类别数量
flattened_img_size = 75 * 3 * 3
base_model = nn.Sequential(
BaiQing(IMG_CHS, 25, 0),
BaiQing(25, 50, 0.2),
BaiQing(50, 75, 0.2),
nn.Flatten(),
nn.Linear(flattened_img_size, 512),
nn.Dropout(0.3),
nn.ReLU(),
nn.Linear(512, n_classes)
)

没错,用同一个类 BaiQing 创建了 3 个不同的层实例,请注意,后续会提到

重点来了,第一种保存模型方法,直接保存model

1
torch.save(base_model, 'model.pth', weights_only=False)

我们将 基础模型base_model 直接保存,命名为 model.pth 文件
然后,我们创建一个 load.py 文件用来调用模型
开始调用 model.pth 之前,因为base_model中有用到BaiQing类,所以需要现在新的文件内再次定义,或者直接导包

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
class BaiQing(nn.Module):
def __init__(self, in_ch, out_ch, dropout_p):
kernel_size = 3
super().__init__()

self.model = nn.Sequential(
nn.Conv2d(in_ch, out_ch, kernel_size, stride=1, padding=1),
nn.BatchNorm2d(out_ch),
nn.ReLU(),
nn.Dropout(dropout_p),
nn.MaxPool2d(2, stride=2)
)

def forward(self, x):
return self.model(x)

model = torch.load("model.pth")

于是我们就加载了模型了

‼️ 但是,这种方法不推荐:虽然看起来非常方便,但是问题很多

  1. 如果改了文件名或者类名,就直接无法运行
  2. 更换conda环境,PyTorch版本不对,也可能加载失败
  3. 模型的结构更改需要回到save文件,不可控,像薛定谔的猫
  4. 写起来比较省事,可以用于临时代码实验

第二种保存加载模型方式:保存model.state_dict()(推荐,权威)

我们同样创建 save.py文件load.py文件
save.py文件 里创建好 卷积块基础模型

使用第二种保存模型方法,只保存权重,保存model.state_dict()

1
torch.save(model.state_dict(), "model.pth")

然后进入 load.py文件
由于model.pth 内现在只保存了模型的权重,所以我们需要把模型实例化:

1
2
3
4
5
6
# 模型实例化
model = BaiQing()
# 载入模型
model.load_state_dict(torch.load("model.pth", map_location=device, weights_only=False))
# 模型评估模式
model.eval()

这样是正常的保存和载入的步骤
但是,我们很快会发现:

1
2
3
4
5
6
7
8
class BaiQing(nn.Module):
def __init__(self, in_ch, out_ch, dropout_p):
·
·
·
model = BaiQing()
·
·

没错,我们的卷积块,初始化了3个变量,而实例化的时候,没有传入任何参数,这会导致报错:

1
TypeError: __init__() missing 3 required positional arguments: 'in_ch', 'out_ch', and 'dropout_p'

因为我们在save.py文件中的模型,BaiQing类初始化了3个量,所以我们不能直接在load.py文件中,直接传入一个量的参数,进行加载模型
所以需要再load.py文件中添加一个步骤:再次定义一个 base_model基础模型

完整代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
class BaiQing(nn.Module):
def __init__(self, in_ch, out_ch, dropout_p):
kernel_size = 3
super().__init__()

self.model = nn.Sequential(
nn.Conv2d(in_ch, out_ch, kernel_size, stride=1, padding=1),
nn.BatchNorm2d(out_ch),
nn.ReLU(),
nn.Dropout(dropout_p),
nn.MaxPool2d(2, stride=2)
)

def forward(self, x):
return self.model(x)

IMG_CHS = 1 # 图片通道数(Channels)
n_classes = 25 # 分类类别数量

base_model = nn.Sequential(
BaiQing(IMG_CHS, 25, 0),
BaiQing(25, 50, 0.2),
BaiQing(50, 75, 0.2),
nn.Flatten(),
nn.Linear(75 * 3 * 3, 512),
nn.Dropout(0.3),
nn.ReLU(),
nn.Linear(512, n_classes)
)

state_dict = torch.load('model.pth', map_location=device, weights_only=False)
base_model.load_state_dict(state_dict)

首先定义类 BaiQing卷积块 和 基础模型base_model
再把权重加载到 state_dict 字典中
最后,利用 load_state_dict(state_dict) 方法,成功的在load.py文件中,把模型加载到base_model内提供预测

总结

两种方法保存和加载模型,前者便捷,但不安全;后者麻烦,但更权威。
强调后者:只保存权重,所以需要实例化;若卷积块存在多个参数,且被模型多次初始化,则需要重新定义一个模型,单独初始化,再实例化模型,加载权重

AI给了我一个非常生动形象的比喻:
保存 model 就是把整个人冷冻起来,包括:

  • 他的大脑记忆 (权重)
  • 身体结构 (模型结构)
  • 身份证地址 (类定义的路径)
  • 当时的环境 (代码文件名、模块位置)

加载 model 时,解释把这整个人解冻:
虽然人原样复活,但是如果:

  • 换了房间 (代码路径改变)
  • 换了国家 (PyTorch/Conda版本环境错了)
  • 改了名字 (类名/文件名变了)

则解冻失败,直接报错

再来反观保存 model.state_dict()
只保存大脑里的记忆:

  • 包括:技能、经验、记忆、神经链接强度(权重)
  • 不包括:他穿的衣服,站在哪,谁生的

加载 state_dict()
相当于先造了一个身体(state_dict)
再把记忆(权重)灌进去
只要:

  • 身体结构一致
  • 大脑接口一致

则加载成功

映射到上文我写的代码,不难发现:
身体骨架 就是 模型结构(代码)
大脑记忆 就是 state_dict()
冷冻整个人 就是 torch.save(model)