本文基于d2l项目内容整理,介绍深度学习中模型参数的保存与加载技术,包括最佳实践和常见问题的解决方案。

1. 模型参数读写的重要性

以 PyTorch 为例,可以使用torch.load()torch.save()函数读写包括张量在内的任意 Python 对象,如:列表、字典、模型的结构与参数、优化器状态等。

为什么需要保存模型参数?

运行耗时较长的训练时,最佳的做法是定期保存模型参数,防止因意外而丢失全部计算结果。这对于深度学习项目的稳定性和可持续性至关重要。


2. 模型参数的保存

2.1 保存模型参数

基本保存方法

使用Module.state_dict()方法可获取模型Module当前状态 (state) 的全部参数字典,包含了权重和偏置:

1
2
3
4
5
6
7
8
9
10
11
12
import torch
import torch.nn as nn

# 定义一个简单的模型
model = nn.Sequential(
nn.Linear(10, 5),
nn.ReLU(),
nn.Linear(5, 1)
)

# 保存模型参数
torch.save(model.state_dict(), 'model.pth')

文件扩展名约定:

.pth是 PyTorch 序列化文件扩展名的约定。这个扩展名帮助识别文件类型,但不是强制要求。

保存完整模型(不推荐)

1
2
# 直接保存完整模型(不推荐)
torch.save(model, 'complete_model.pth')

为什么不推荐保存完整模型?

由于模型与代码间可能的依赖关系变化,直接保存完整的模型不是推荐的做法。这种方式可能导致:

  • 代码结构变化时无法正确加载
  • 跨版本兼容性问题
  • 文件体积更大
  • 安全性风险

2.2 保存优化器状态

可以同样使用state_dict()方法保存当前的优化器状态:

1
2
3
4
5
6
7
import torch.optim as optim

# 定义优化器
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 保存优化器状态
torch.save(optimizer.state_dict(), 'optimizer.pth')

2.3 联合保存(推荐)

基本联合保存:

由于torch.save()函数支持保存任意 Python 对象,可用字典组织优化器和参数等状态后联合保存:

1
2
3
4
5
6
7
8
# 联合保存多个状态
state = {
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'epoch': 10,
'loss': 0.1
}
torch.save(state, 'checkpoint.pth')

完整检查点保存:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
import time

def save_checkpoint(model, optimizer, epoch, loss, filepath):
"""保存训练检查点"""
checkpoint = {
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'epoch': epoch,
'loss': loss,
'model_architecture': str(model),
'timestamp': torch.tensor(time.time())
}
torch.save(checkpoint, filepath)
print(f"检查点已保存到: {filepath}")

检查点 (Checkpoint) 的概念:

Checkpoint表示模型当前状态的快照。一般包括模型和优化器的状态字典、训练的轮次 (epoch)、学习率等其他超参数。这是深度学习训练中的标准做法。


3. 模型参数的加载

3.1 设备兼容性处理

设备兼容性注意事项:

若需要将来自 GPU 的模型加载到 CPU 上,应在load()函数中指定map_location参数:

1
2
3
4
5
# 将GPU模型加载到CPU
model = torch.load('model.pth', map_location=torch.device('cpu'))

# 或者使用字符串形式
model = torch.load('model.pth', map_location='cpu')

3.2 加载模型参数

标准加载方法

使用Module.load_state_dict()方法可从文件对象加载状态字典(反序列化):

1
2
3
4
5
6
7
8
9
10
# 创建模型实例(结构必须与保存时一致)
model = nn.Sequential(
nn.Linear(10, 5),
nn.ReLU(),
nn.Linear(5, 1)
)

# 加载模型参数
model.load_state_dict(torch.load('model.pth'))
model.eval() # 设置为评估模式

重要提醒:

在加载模型参数时,通常应确保模型结构与保存时一致。如果结构不匹配,会导致加载失败。

加载完整模型

1
2
# 加载完整模型(如果之前保存的是完整模型)
model = torch.load('complete_model.pth')

3.3 加载优化器状态

1
2
3
4
5
# 创建优化器实例
optimizer = optim.Adam(model.parameters(), lr=0.001)

# 加载优化器状态
optimizer.load_state_dict(torch.load('optimizer.pth'))

3.4 联合加载

基本联合加载:

1
2
3
4
5
6
7
8
9
10
11
12
# 加载检查点
checkpoint = torch.load('checkpoint.pth')

# 恢复模型和优化器状态
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

# 恢复训练信息
epoch = checkpoint['epoch']
loss = checkpoint['loss']

print(f"恢复到第 {epoch} 轮训练,损失值:{loss}")

安全加载(带错误处理):

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
def load_checkpoint(model, optimizer, filepath):
"""安全加载训练检查点"""
try:
checkpoint = torch.load(filepath, map_location='cpu')

# 检查必要的键是否存在
required_keys = ['model_state_dict', 'optimizer_state_dict', 'epoch', 'loss']
for key in required_keys:
if key not in checkpoint:
raise KeyError(f"检查点文件缺少必要的键: {key}")

# 加载状态
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

epoch = checkpoint['epoch']
loss = checkpoint['loss']

print(f"成功加载检查点: epoch={epoch}, loss={loss:.4f}")
return epoch, loss

except FileNotFoundError:
print(f"检查点文件不存在: {filepath}")
return 0, float('inf')
except Exception as e:
print(f"加载检查点时出错: {e}")
return 0, float('inf')

4. 高级应用

4.1 训练过程中的自动保存

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import os
import time

class CheckpointManager:
def __init__(self, save_dir='checkpoints', max_keep=5):
self.save_dir = save_dir
self.max_keep = max_keep
os.makedirs(save_dir, exist_ok=True)

def save(self, model, optimizer, epoch, loss, metrics=None):
"""保存检查点"""
timestamp = int(time.time())
filename = f'checkpoint_epoch_{epoch}_{timestamp}.pth'
filepath = os.path.join(self.save_dir, filename)

checkpoint = {
'model_state_dict': model.state_dict(),
'optimizer_state_dict': optimizer.state_dict(),
'epoch': epoch,
'loss': loss,
'timestamp': timestamp
}

if metrics:
checkpoint['metrics'] = metrics

torch.save(checkpoint, filepath)
self._cleanup_old_checkpoints()
return filepath

def _cleanup_old_checkpoints(self):
"""清理旧的检查点文件"""
checkpoints = []
for filename in os.listdir(self.save_dir):
if filename.startswith('checkpoint_') and filename.endswith('.pth'):
filepath = os.path.join(self.save_dir, filename)
checkpoints.append((filepath, os.path.getctime(filepath)))

# 按创建时间排序,保留最新的几个
checkpoints.sort(key=lambda x: x[1], reverse=True)
for filepath, _ in checkpoints[self.max_keep:]:
os.remove(filepath)
print(f"删除旧检查点: {filepath}")

# 使用示例
checkpoint_manager = CheckpointManager()

# 在训练循环中
for epoch in range(num_epochs):
# ... 训练代码 ...

if epoch % 10 == 0: # 每10个epoch保存一次
checkpoint_manager.save(model, optimizer, epoch, train_loss)

4.2 模型版本管理

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
class ModelVersionManager:
def __init__(self, base_path='models'):
self.base_path = base_path
os.makedirs(base_path, exist_ok=True)

def save_model(self, model, version, metadata=None):
"""保存模型版本"""
version_dir = os.path.join(self.base_path, f'v{version}')
os.makedirs(version_dir, exist_ok=True)

# 保存模型参数
model_path = os.path.join(version_dir, 'model.pth')
torch.save(model.state_dict(), model_path)

# 保存元数据
if metadata:
metadata_path = os.path.join(version_dir, 'metadata.json')
import json
with open(metadata_path, 'w') as f:
json.dump(metadata, f, indent=2)

print(f"模型版本 v{version} 已保存到 {version_dir}")

def load_model(self, model, version):
"""加载指定版本的模型"""
model_path = os.path.join(self.base_path, f'v{version}', 'model.pth')
if os.path.exists(model_path):
model.load_state_dict(torch.load(model_path, map_location='cpu'))
print(f"成功加载模型版本 v{version}")
return True
else:
print(f"模型版本 v{version} 不存在")
return False

# 使用示例
version_manager = ModelVersionManager()

# 保存模型版本
metadata = {
'accuracy': 0.95,
'loss': 0.05,
'training_time': '2 hours',
'dataset': 'CIFAR-10'
}
version_manager.save_model(model, '1.0', metadata)

4.3 跨平台兼容性

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
def save_model_portable(model, filepath):
"""保存可跨平台使用的模型"""
# 确保模型在CPU上
model_cpu = model.cpu()

# 保存时指定协议版本以确保兼容性
torch.save(model_cpu.state_dict(), filepath, _use_new_zipfile_serialization=False)

print(f"模型已保存为跨平台兼容格式: {filepath}")

def load_model_safe(model, filepath):
"""安全加载模型,处理各种异常情况"""
try:
# 尝试加载
state_dict = torch.load(filepath, map_location='cpu')

# 检查键的兼容性
model_keys = set(model.state_dict().keys())
loaded_keys = set(state_dict.keys())

missing_keys = model_keys - loaded_keys
unexpected_keys = loaded_keys - model_keys

if missing_keys:
print(f"警告: 模型中缺少以下键: {missing_keys}")
if unexpected_keys:
print(f"警告: 加载的状态字典中有未预期的键: {unexpected_keys}")

# 加载兼容的参数
model.load_state_dict(state_dict, strict=False)
print("模型加载成功")
return True

except Exception as e:
print(f"模型加载失败: {e}")
return False

5. 常见问题与解决方案

5.1 常见错误及解决方法

设备不匹配错误:

1
2
3
4
5
6
# 错误示例
model.load_state_dict(torch.load('model.pth')) # 可能出现CUDA/CPU不匹配

# 正确做法
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.load_state_dict(torch.load('model.pth', map_location=device))

模型结构不匹配:

1
2
3
4
5
6
7
8
9
# 部分加载,忽略不匹配的层
def load_partial_state_dict(model, state_dict):
model_dict = model.state_dict()
# 过滤掉不匹配的参数
filtered_dict = {k: v for k, v in state_dict.items()
if k in model_dict and v.size() == model_dict[k].size()}
model_dict.update(filtered_dict)
model.load_state_dict(model_dict)
print(f"成功加载 {len(filtered_dict)} 个参数")

内存不足问题:

1
2
3
4
5
6
7
8
9
10
# 对于大模型,可以分块加载
def load_large_model(model, filepath):
checkpoint = torch.load(filepath, map_location='cpu')

# 逐层加载以节省内存
for name, param in model.named_parameters():
if name in checkpoint:
param.data.copy_(checkpoint[name])

print("大模型加载完成")

5.2 性能优化建议

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
# 使用压缩保存大模型
def save_compressed_model(model, filepath):
"""保存压缩的模型文件"""
import gzip
import pickle

with gzip.open(filepath, 'wb') as f:
pickle.dump(model.state_dict(), f)

print(f"压缩模型已保存: {filepath}")

def load_compressed_model(model, filepath):
"""加载压缩的模型文件"""
import gzip
import pickle

with gzip.open(filepath, 'rb') as f:
state_dict = pickle.load(f)

model.load_state_dict(state_dict)
print(f"压缩模型已加载: {filepath}")

5.3 实际应用示例

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
def training_with_checkpoints():
"""带检查点的完整训练示例"""
# 模型和优化器初始化
model = nn.Sequential(
nn.Linear(784, 128),
nn.ReLU(),
nn.Linear(128, 10)
)
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

# 检查点管理器
checkpoint_manager = CheckpointManager()

# 尝试加载之前的检查点
start_epoch = 0
best_loss = float('inf')

checkpoint_files = [f for f in os.listdir('checkpoints') if f.endswith('.pth')]
if checkpoint_files:
latest_checkpoint = max(checkpoint_files, key=lambda x: os.path.getctime(os.path.join('checkpoints', x)))
start_epoch, best_loss = load_checkpoint(model, optimizer, os.path.join('checkpoints', latest_checkpoint))

# 训练循环
num_epochs = 100
for epoch in range(start_epoch, num_epochs):
# 训练代码...
train_loss = 0.5 # 假设的训练损失

# 定期保存检查点
if epoch % 10 == 0 or train_loss < best_loss:
checkpoint_manager.save(model, optimizer, epoch, train_loss)
if train_loss < best_loss:
best_loss = train_loss
# 保存最佳模型
torch.save(model.state_dict(), 'best_model.pth')

print(f"Epoch {epoch}, Loss: {train_loss:.4f}")

# 运行训练
if __name__ == '__main__':
training_with_checkpoints()

总结

模型参数的读写是深度学习项目中的关键技能:

  1. 保存策略:推荐保存state_dict而非完整模型,使用检查点管理训练状态
  2. 加载技巧:注意设备兼容性,处理结构不匹配问题
  3. 最佳实践:实现自动保存、版本管理和错误处理
  4. 性能优化:考虑压缩存储和内存效率

掌握这些技术能够确保训练过程的稳定性和模型的可重现性,是深度学习工程化的重要组成部分。在实际项目中,建议建立完善的模型管理流程,包括版本控制、自动备份和灾难恢复机制。