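This article, based on material from the d2l (Dive into Deep Learning) project, covers how to save and load model parameters in deep learning, including best practices and solutions to common problems.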
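1. Why reading and writing model parameters matters

Taking PyTorch as an example, the torch.save() and torch.load() functions can write and read arbitrary Python objects that contain tensors: lists, dictionaries, a model's structure and parameters, optimizer state, and so on.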
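Here is a minimal sketch of that round trip. It assumes a temporary directory and the arbitrary file name objects.pt. It saves a tensor, a list of tensors, and a dict in one file, then reads them back:

```python
import os
import tempfile

import torch

x = torch.arange(4)
y = torch.zeros(4)

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, 'objects.pt')
    # torch.save pickles arbitrary Python containers that hold tensors
    torch.save({'x': x, 'pair': [x, y]}, path)
    # torch.load restores the same container structure
    loaded = torch.load(path)

print(loaded['x'], loaded['pair'][1])
```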
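Why save model parameters?

When a training run takes a long time, save the model parameters periodically. Then an unexpected failure, such as a crash or an interrupted job, does not wipe out all the computation done so far. This is essential for keeping a deep learning project robust and sustainable.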
2. Saving

2.1 Saving model parameters

Basic saving

Module.state_dict() returns a dictionary of the module's current state, which includes all of its weights and biases:

```python
import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(10, 5),
    nn.ReLU(),
    nn.Linear(5, 1)
)

# Save only the parameter dictionary, not the model object
torch.save(model.state_dict(), 'model.pth')
```
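A small sketch that prints each key and tensor shape shows what state_dict() contains for the model above. Each key combines the submodule's index in the Sequential with the parameter name. The ReLU at index 1 has no parameters, so it contributes no keys:

```python
import torch.nn as nn

model = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 1))

for name, tensor in model.state_dict().items():
    print(name, tuple(tensor.shape))
# 0.weight (5, 10)
# 0.bias (5,)
# 2.weight (1, 5)
# 2.bias (1,)
```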
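File extension convention: .pt and .pth are the conventional extensions for files written by torch.save(). The extension helps identify the file type but is not required. torch.save() accepts any file name.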
Saving the complete model (not recommended)

```python
# Pickles the whole nn.Module object, parameters included
torch.save(model, 'complete_model.pth')
```
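Why is saving the complete model not recommended?

torch.save(model) pickles the module object. Pickle stores only a reference to the model's class (its module path and name), not the class's source code. Loading therefore depends on the same code being importable from the same place, so later changes to that code can break it. This can lead to:

- Failure to load after the code is refactored, renamed, or moved
- Compatibility problems across PyTorch and project code versions
- Slightly larger files, since the pickled module objects are stored alongside the tensors
- Security risks: unpickling can execute arbitrary code, so complete-model files must be loaded with weights_only=False (required since PyTorch 2.6 made weights_only=True the default) and should only come from trusted sources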
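2.2 Saving the optimizer state

The optimizer's current state can be saved the same way, with its state_dict() method:

```python
import torch
import torch.optim as optim

# `model` is the network defined above
optimizer = optim.Adam(model.parameters(), lr=0.001)

torch.save(optimizer.state_dict(), 'optimizer.pth')
```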
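The following sketch shows what an optimizer's state_dict holds. It assumes a throwaway nn.Linear(3, 1) model and one dummy training step, both for illustration only. Hyperparameters such as lr live in param_groups. Per-parameter buffers (Adam's moment estimates) live in state, keyed by parameter index, and appear only after the first step:

```python
import torch
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(3, 1)
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Before any step, Adam has no per-parameter state yet
print(optimizer.state_dict()['state'])  # {}

# One dummy step populates the moment estimates
loss = model(torch.randn(4, 3)).pow(2).mean()
loss.backward()
optimizer.step()

sd = optimizer.state_dict()
print(list(sd.keys()))                # ['state', 'param_groups']
print(sd['param_groups'][0]['lr'])    # 0.001
print(sorted(sd['state'][0].keys()))  # ['exp_avg', 'exp_avg_sq', 'step']
```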
2.3 Saving jointly (recommended)

Basic joint saving:

Because torch.save() can save any Python object, you can put the model and optimizer states, plus bookkeeping values, in one dictionary and save them together:

```python
state = {
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'epoch': 10,
    'loss': 0.1
}
torch.save(state, 'checkpoint.pth')
```
Saving a full checkpoint:

```python
import time

import torch

def save_checkpoint(model, optimizer, epoch, loss, filepath):
    """Save a training checkpoint."""
    checkpoint = {
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'epoch': epoch,
        'loss': loss,
        'model_architecture': str(model),  # human-readable summary; not used for loading
        'timestamp': time.time()           # a plain float; no need to wrap it in a tensor
    }
    torch.save(checkpoint, filepath)
    print(f"Checkpoint saved to: {filepath}")
```
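What is a checkpoint?

A checkpoint is a snapshot of the model's current state. It usually includes the model and optimizer state dictionaries, the current epoch, and other hyperparameters such as the learning rate (which is also stored in the optimizer's param_groups). Checkpointing is standard practice in deep learning training.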
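The sketch below checks that a checkpoint restores both the weights and the optimizer's momentum buffers, which is what resuming training needs. It assumes a tiny nn.Linear model, SGD with momentum, and an in-memory io.BytesIO buffer in place of a file. The helper make_model_and_optimizer is hypothetical:

```python
import io

import torch
import torch.nn as nn
import torch.optim as optim

def make_model_and_optimizer():
    # Hypothetical factory: resuming requires rebuilding the same objects
    model = nn.Linear(3, 1)
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    return model, optimizer

torch.manual_seed(0)
model, optimizer = make_model_and_optimizer()
for _ in range(3):  # a few dummy steps so momentum buffers exist
    optimizer.zero_grad()
    model(torch.randn(8, 3)).pow(2).mean().backward()
    optimizer.step()

# Snapshot everything needed to resume
buffer = io.BytesIO()
torch.save({'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'epoch': 3}, buffer)
buffer.seek(0)

# Restore into freshly constructed objects
new_model, new_optimizer = make_model_and_optimizer()
checkpoint = torch.load(buffer)
new_model.load_state_dict(checkpoint['model_state_dict'])
new_optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
```

Without the optimizer state, the resumed run would restart momentum from zero even though the weights match.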
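3. Loading

3.1 Handling device compatibility

Device compatibility notes:

To load tensors that were saved on a GPU onto a CPU (for example, on a machine without CUDA), pass the map_location argument to torch.load():

```python
# Both forms map every saved tensor onto the CPU; the result is a state dict
state_dict = torch.load('model.pth', map_location=torch.device('cpu'))
state_dict = torch.load('model.pth', map_location='cpu')
```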
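Besides a string or a torch.device, map_location also accepts a dict that remaps device names, or a callable taking (storage, location). The small sketch below tries each form on an in-memory buffer. It runs on a CPU-only machine because the saved tensor is already on the CPU:

```python
import io

import torch

buffer = io.BytesIO()
torch.save({'w': torch.ones(2, 2)}, buffer)

devices = []
for map_location in ('cpu',
                     torch.device('cpu'),
                     {'cuda:0': 'cpu'},              # remap saved cuda:0 tensors to CPU
                     lambda storage, loc: storage):  # keep storages where they are deserialized (CPU)
    buffer.seek(0)
    state = torch.load(buffer, map_location=map_location)
    devices.append(state['w'].device.type)

print(devices)
```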
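3.2 Loading model parameters

Standard loading

torch.load() deserializes the file back into a state dictionary. Module.load_state_dict() then copies those tensors into a model with the same architecture:

```python
import torch
import torch.nn as nn

# Rebuild the same architecture that was used when saving
model = nn.Sequential(
    nn.Linear(10, 5),
    nn.ReLU(),
    nn.Linear(5, 1)
)

model.load_state_dict(torch.load('model.pth'))
model.eval()  # switch layers such as dropout and batch norm to inference behavior
```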
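Important:

When loading parameters, the model's structure must match the one used at save time. With the default strict=True, load_state_dict() raises a RuntimeError if any keys are missing or unexpected. A tensor shape mismatch raises an error regardless of strict.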
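The sketch below shows what happens on a mismatch. It uses a hypothetical "changed" model with an extra output layer appended. Strict loading raises, while strict=False loads the overlapping keys and reports the rest through the missing_keys and unexpected_keys fields of the return value:

```python
import torch.nn as nn

saved = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 1))
# Hypothetical modified architecture: extra ReLU + Linear at indices 3 and 4
changed = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 1),
                        nn.ReLU(), nn.Linear(1, 1))

state_dict = saved.state_dict()

# strict=True (the default) refuses missing keys
strict_failed = False
try:
    changed.load_state_dict(state_dict)
except RuntimeError:
    strict_failed = True
print('strict load failed:', strict_failed)

# strict=False loads what matches and reports the rest
result = changed.load_state_dict(state_dict, strict=False)
print(result.missing_keys)     # ['4.weight', '4.bias']
print(result.unexpected_keys)  # []
```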
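Loading the complete model

```python
# The model's class must be importable from the same module path as at save time.
# weights_only=False is required because the file holds a pickled nn.Module
# (the default became True in PyTorch 2.6); only do this for trusted files.
model = torch.load('complete_model.pth', weights_only=False)
```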
3.3 Loading the optimizer state

```python
# Construct the optimizer over the (already loaded) model's parameters first
optimizer = optim.Adam(model.parameters(), lr=0.001)
optimizer.load_state_dict(torch.load('optimizer.pth'))
```
3.4 Loading jointly

Basic joint loading:

```python
checkpoint = torch.load('checkpoint.pth')

model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

print(f"Resuming from epoch {epoch}, loss: {loss}")
```
Safe loading (with error handling):

```python
import torch

def load_checkpoint(model, optimizer, filepath):
    """Load a training checkpoint, falling back to a fresh start on failure."""
    try:
        checkpoint = torch.load(filepath, map_location='cpu')

        # Check that every expected entry is present
        required_keys = ['model_state_dict', 'optimizer_state_dict', 'epoch', 'loss']
        for key in required_keys:
            if key not in checkpoint:
                raise KeyError(f"Checkpoint file is missing required key: {key}")

        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        epoch = checkpoint['epoch']
        loss = checkpoint['loss']
        print(f"Checkpoint loaded: epoch={epoch}, loss={loss:.4f}")
        return epoch, loss

    except FileNotFoundError:
        print(f"Checkpoint file not found: {filepath}")
        return 0, float('inf')
    except Exception as e:
        print(f"Error while loading checkpoint: {e}")
        return 0, float('inf')
```
4. Advanced applications

4.1 Automatic saving during training

```python
import os
import time

import torch

class CheckpointManager:
    def __init__(self, save_dir='checkpoints', max_keep=5):
        self.save_dir = save_dir
        self.max_keep = max_keep  # number of most recent checkpoints to keep
        os.makedirs(save_dir, exist_ok=True)

    def save(self, model, optimizer, epoch, loss, metrics=None):
        """Save a checkpoint, then prune old ones."""
        timestamp = int(time.time())
        filename = f'checkpoint_epoch_{epoch}_{timestamp}.pth'
        filepath = os.path.join(self.save_dir, filename)

        checkpoint = {
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'epoch': epoch,
            'loss': loss,
            'timestamp': timestamp
        }
        if metrics:
            checkpoint['metrics'] = metrics

        torch.save(checkpoint, filepath)
        self._cleanup_old_checkpoints()
        return filepath

    def _cleanup_old_checkpoints(self):
        """Delete all but the max_keep most recently written checkpoints."""
        checkpoints = []
        for filename in os.listdir(self.save_dir):
            if filename.startswith('checkpoint_') and filename.endswith('.pth'):
                filepath = os.path.join(self.save_dir, filename)
                # getmtime (last write) is portable; on Unix getctime is the
                # inode-change time, not the creation time
                checkpoints.append((filepath, os.path.getmtime(filepath)))

        checkpoints.sort(key=lambda x: x[1], reverse=True)  # newest first
        for filepath, _ in checkpoints[self.max_keep:]:
            os.remove(filepath)
            print(f"Removed old checkpoint: {filepath}")


# Usage: num_epochs, model, optimizer and train_loss come from your training loop
checkpoint_manager = CheckpointManager()
for epoch in range(num_epochs):
    # ... train one epoch and compute train_loss ...
    if epoch % 10 == 0:
        checkpoint_manager.save(model, optimizer, epoch, train_loss)
```
4.2 Model version management

```python
import json
import os

import torch

class ModelVersionManager:
    def __init__(self, base_path='models'):
        self.base_path = base_path
        os.makedirs(base_path, exist_ok=True)

    def save_model(self, model, version, metadata=None):
        """Save one model version: weights plus optional JSON metadata."""
        version_dir = os.path.join(self.base_path, f'v{version}')
        os.makedirs(version_dir, exist_ok=True)

        model_path = os.path.join(version_dir, 'model.pth')
        torch.save(model.state_dict(), model_path)

        if metadata:
            metadata_path = os.path.join(version_dir, 'metadata.json')
            with open(metadata_path, 'w') as f:
                json.dump(metadata, f, indent=2)

        print(f"Model version v{version} saved to {version_dir}")

    def load_model(self, model, version):
        """Load the given version into `model`; return True on success."""
        model_path = os.path.join(self.base_path, f'v{version}', 'model.pth')
        if os.path.exists(model_path):
            model.load_state_dict(torch.load(model_path, map_location='cpu'))
            print(f"Loaded model version v{version}")
            return True
        else:
            print(f"Model version v{version} does not exist")
            return False


version_manager = ModelVersionManager()
metadata = {
    'accuracy': 0.95,
    'loss': 0.05,
    'training_time': '2 hours',
    'dataset': 'CIFAR-10'
}
version_manager.save_model(model, '1.0', metadata)
```
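4.3 Cross-platform compatibility

```python
import torch

def save_model_portable(model, filepath):
    """Save a CPU state dict loadable on CPU-only machines and older PyTorch."""
    # Copy tensors to CPU without moving the caller's model
    # (model.cpu() would move the model itself in place)
    cpu_state_dict = {k: v.detach().cpu() for k, v in model.state_dict().items()}
    # The legacy (non-zip) format is only needed if PyTorch < 1.6 must read the file
    torch.save(cpu_state_dict, filepath, _use_new_zipfile_serialization=False)
    print(f"Model saved in a portable format: {filepath}")


def load_model_safe(model, filepath):
    """Load a state dict, reporting key mismatches instead of failing on them."""
    try:
        state_dict = torch.load(filepath, map_location='cpu')

        model_keys = set(model.state_dict().keys())
        loaded_keys = set(state_dict.keys())
        missing_keys = model_keys - loaded_keys
        unexpected_keys = loaded_keys - model_keys

        if missing_keys:
            print(f"Warning: keys the model expects but the file lacks: {missing_keys}")
        if unexpected_keys:
            print(f"Warning: unexpected keys in the loaded state dict: {unexpected_keys}")

        # strict=False skips mismatched keys; shape mismatches still raise
        model.load_state_dict(state_dict, strict=False)
        print("Model loaded")
        return True

    except Exception as e:
        print(f"Failed to load model: {e}")
        return False
```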
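5. Common problems and solutions

5.1 Common errors and fixes

Device mismatch:

```python
import torch

# Fails on a CPU-only machine if 'model.pth' was saved from CUDA tensors:
# model.load_state_dict(torch.load('model.pth'))

# Fix: map the tensors to a device that exists on this machine
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.load_state_dict(torch.load('model.pth', map_location=device))
model.to(device)
```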
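Model structure mismatch:

```python
def load_partial_state_dict(model, state_dict):
    """Load only the entries whose name and shape match the model."""
    model_dict = model.state_dict()
    filtered_dict = {k: v for k, v in state_dict.items()
                     if k in model_dict and v.size() == model_dict[k].size()}
    # Start from the model's own tensors so unmatched entries keep their current values
    model_dict.update(filtered_dict)
    model.load_state_dict(model_dict)
    print(f"Loaded {len(filtered_dict)} parameters")
```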
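strict=False alone does not help when only the shapes differ. A typical case is a classifier head that changes from 10 to 5 classes. The scenario below is hypothetical and built from the 784-128 network used in the example later in this article. The sketch shows the error and then the name-and-shape filtering that load_partial_state_dict performs:

```python
import torch
import torch.nn as nn

# Hypothetical: pretrained 10-class model, new 5-class model with the same backbone
pretrained = nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10))
new_model = nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 5))

# strict=False tolerates missing/unexpected keys, but not shape mismatches
try:
    new_model.load_state_dict(pretrained.state_dict(), strict=False)
    shape_error = False
except RuntimeError:
    shape_error = True
print('shape mismatch raised:', shape_error)

# Keep only entries whose name and shape both match, then load the backbone
target = new_model.state_dict()
compatible = {k: v for k, v in pretrained.state_dict().items()
              if k in target and v.shape == target[k].shape}
new_model.load_state_dict(compatible, strict=False)
print(sorted(compatible))  # ['0.bias', '0.weight']
```

The new 5-class head keeps its fresh initialization and is then fine-tuned.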
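Out-of-memory problems:

A plain torch.load() reads the whole file into CPU memory before anything is copied, so copying parameters one at a time does not by itself lower peak memory. In PyTorch 2.1 and later, mmap=True memory-maps the file so tensor data is paged in on demand:

```python
import torch

def load_large_model(model, filepath):
    """Load a large checkpoint while keeping peak memory lower."""
    # mmap=True (PyTorch >= 2.1, zipfile format) avoids reading the whole file up front
    checkpoint = torch.load(filepath, map_location='cpu', mmap=True, weights_only=True)
    # Copy tensor by tensor into the existing parameters.
    # Note: named_parameters() skips buffers such as BatchNorm running stats;
    # use model.load_state_dict(checkpoint) if those must be loaded too.
    with torch.no_grad():
        for name, param in model.named_parameters():
            if name in checkpoint:
                param.copy_(checkpoint[name])
    print("Large model loaded")
```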
5.2 Performance optimization suggestions

Compressing the saved file trades extra CPU time for disk space. Floating-point weights typically compress only modestly, so measure before adopting it.

```python
import gzip
import pickle

def save_compressed_model(model, filepath):
    """Save a gzip-compressed state dict."""
    with gzip.open(filepath, 'wb') as f:
        pickle.dump(model.state_dict(), f)
    print(f"Compressed model saved: {filepath}")


def load_compressed_model(model, filepath):
    """Load a gzip-compressed state dict (plain pickle: only load trusted files)."""
    with gzip.open(filepath, 'rb') as f:
        state_dict = pickle.load(f)
    model.load_state_dict(state_dict)
    print(f"Compressed model loaded: {filepath}")
```
5.3 A practical example

```python
import os

import torch
import torch.nn as nn
import torch.optim as optim

# Assumes CheckpointManager (4.1) and load_checkpoint (3.4) are defined above

def training_with_checkpoints():
    """A complete training skeleton with checkpointing and resume."""
    model = nn.Sequential(
        nn.Linear(784, 128),
        nn.ReLU(),
        nn.Linear(128, 10)
    )
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()  # used by the real training step (omitted here)

    checkpoint_manager = CheckpointManager()  # creates 'checkpoints/' if needed
    start_epoch = 0
    best_loss = float('inf')

    # Resume from the most recently written checkpoint, if any
    checkpoint_files = [f for f in os.listdir('checkpoints') if f.endswith('.pth')]
    if checkpoint_files:
        latest_checkpoint = max(
            checkpoint_files,
            key=lambda x: os.path.getmtime(os.path.join('checkpoints', x)))
        start_epoch, best_loss = load_checkpoint(
            model, optimizer, os.path.join('checkpoints', latest_checkpoint))
        # This re-runs the saved epoch; use start_epoch + 1 to skip it

    num_epochs = 100
    for epoch in range(start_epoch, num_epochs):
        # Placeholder: replace with a real training epoch that returns its loss
        train_loss = 0.5

        if epoch % 10 == 0 or train_loss < best_loss:
            checkpoint_manager.save(model, optimizer, epoch, train_loss)

        if train_loss < best_loss:
            best_loss = train_loss
            torch.save(model.state_dict(), 'best_model.pth')

        print(f"Epoch {epoch}, Loss: {train_loss:.4f}")


if __name__ == '__main__':
    training_with_checkpoints()
```
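Summary

Reading and writing model parameters is a key skill in deep learning projects:

- Saving strategy: save the state_dict rather than the complete model, and use checkpoints to capture the training state
- Loading techniques: watch device compatibility (map_location) and handle structure mismatches
- Best practices: implement automatic saving, version management, and error handling
- Performance: consider compressed storage and memory efficiency

These techniques keep training runs resilient and models reproducible, and they are an important part of engineering deep learning systems. In real projects, set up a complete model-management workflow that covers version control, automatic backups, and disaster recovery.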