PyTorch预训练模型加载实战:从.pth文件到迁移学习避坑指南
2026/4/18 4:43:41 网站建设 项目流程

1. 从零开始加载.pth文件的完整流程

第一次用PyTorch加载预训练模型时,我盯着那个.pth文件发呆了半小时——明明按照官方文档写的代码,却总是报各种奇怪的错误。后来才发现,从下载模型到加载权重,每个环节都藏着不少坑。下面我就用SqueezeNet为例,带你完整走一遍这个流程。

先说说最常见的网络下载问题。当你运行model = models.squeezenet1_1(pretrained=True)时,程序会尝试从PyTorch服务器下载模型文件。但在国内环境下,十次有九次会碰到这样的报错:

requests.exceptions.ConnectionError: ('Connection aborted.', TimeoutError(10060, '由于连接方在一段时间后没有正确答复或连接的主机没有反应,连接尝试失败。', None, 10060, None))

这时候别急着翻墙(注意:所有操作都应在合法合规前提下进行),我有更简单的解决方案。仔细观察报错信息,会发现类似这样的下载链接:

Downloading: "https://download.pytorch.org/models/squeezenet1_1-f364aa15.pth"

把这个链接复制到浏览器,如果打不开,试试去掉https://前缀,直接访问download.pytorch.org/models/squeezenet1_1-f364aa15.pth。我实测这个方法在移动宽带和电信网络下都能成功下载。

下载完成后,你可能会遇到SSL证书验证问题。这时候需要在代码开头加上:

import ssl ssl._create_default_https_context = ssl._create_unverified_context

不过要提醒的是,这只是一个临时解决方案,在生产环境中应该配置正确的证书验证方式。

2. 模型加载的两种姿势与常见陷阱

拿到.pth文件后,新手最容易犯的错误就是直接torch.load()整个文件。用这个命令加载后,一定要先用print看看内容结构:

import torch pthfile = 'squeezenet1_1-f364aa15.pth' net = torch.load(pthfile) print(type(net)) # 输出会是OrderedDict或nn.Module

如果是OrderedDict,说明只保存了权重参数;如果是nn.Module,则是完整模型结构+参数。对于官方预训练模型,通常都是前者。这时候正确的加载姿势是:

import torchvision.models as models # 先创建空模型结构 model = models.squeezenet1_1(pretrained=False) # 然后加载权重参数 model.load_state_dict(torch.load(pthfile))

这里有个隐藏的坑:如果模型结构不匹配,会报Missing key(s) in state_dict错误。我就曾经因为用了squeezenet1_0的结构加载1_1的权重,调试了半天找不到原因。

3. 迁移学习改造实战指南

现在来到最关键的迁移学习环节。假设我们要用SqueezeNet做10分类任务,通常的操作流程是:

  1. 冻结所有底层参数
  2. 替换最后的分类层
  3. 只训练新添加的层

代码看起来很简单:

import torch.nn as nn # 加载预训练模型 model = models.squeezenet1_1(pretrained=True) # 冻结所有参数 for param in model.parameters(): param.requires_grad = False # 修改分类器 model.classifier[1] = nn.Conv2d(512, 10, kernel_size=(1,1))

但运行后你可能会遇到一个诡异的错误:

RuntimeError: shape '[25, 1000]' is invalid for input of size 50

这是因为SqueezeNet内部还有个num_classes属性没改!这个坑官方文档可没提醒,是我踩了三次才发现的。完整解决方案是:

model.classifier[1] = nn.Conv2d(512, 10, kernel_size=(1,1)) model.num_classes = 10 # 这个千万别漏!

4. 参数冻结与解冻的高级技巧

在实际项目中,我们往往不需要冻结所有层。比如对于SqueezeNet,我会选择:

  • 完全冻结前3个fire模块(特征提取层)
  • 部分解冻最后2个fire模块(特征融合层)
  • 完全解冻分类器层

具体实现代码:

# 按名称选择性冻结 for name, param in model.named_parameters(): if 'features.0' in name or 'features.3' in name or 'features.6' in name: param.requires_grad = False elif 'features.9' in name or 'features.12' in name: param.requires_grad = True # 部分解冻 else: param.requires_grad = True # 完全解冻 # 查看哪些层需要更新 params_to_update = [] for name, param in model.named_parameters(): if param.requires_grad: params_to_update.append(param) print("可训练参数:", name)

这种分层冻结策略在我的花卉分类项目中,使验证准确率提升了12%。关键是要理解网络不同层的作用——前面的卷积层提取基础特征,后面的层组合高级特征。

5. 模型保存与加载的最佳实践

训练好的模型需要妥善保存。我推荐使用以下两种方式:

  1. 保存完整模型(结构+参数):
torch.save(model, 'full_model.pth')

加载时直接model = torch.load('full_model.pth')

  1. 只保存参数(推荐):
torch.save(model.state_dict(), 'params_only.pth')

加载时需要先创建结构:

model = models.squeezenet1_1(pretrained=False) model.load_state_dict(torch.load('params_only.pth'))

特别注意:如果用第一种方式保存,加载时可能因为类定义变化导致报错。有次我升级PyTorch版本后,之前保存的模型就加载失败了。所以生产环境强烈推荐第二种方式。

6. 跨设备加载的兼容性问题

当你在GPU训练后要在CPU部署,或者反过来,会遇到经典的RuntimeError: Attempting to deserialize object on CUDA device but torch.cuda.is_available() is False。解决方案是:

# GPU保存 → CPU加载 model.load_state_dict(torch.load('gpu_model.pth', map_location=torch.device('cpu'))) # CPU保存 → GPU加载 model.load_state_dict(torch.load('cpu_model.pth', map_location='cuda:0')) model = model.cuda()

还有个更智能的写法,适合不确定部署环境的情况:

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') model.load_state_dict(torch.load('model.pth', map_location=device))

7. 实战中的性能优化技巧

最后分享几个提升加载效率的技巧:

  1. 使用torch.save_use_new_zipfile_serialization参数可以减小文件体积:
torch.save(model.state_dict(), 'compressed.pth', _use_new_zipfile_serialization=False)
  1. 对于大型模型,可以分块加载:
from collections import OrderedDict state_dict = torch.load('huge_model.pth') new_state_dict = OrderedDict() for k, v in state_dict.items(): if k.startswith('features.0'): # 只加载特定部分 new_state_dict[k] = v model.load_state_dict(new_state_dict, strict=False)
  1. 使用torch.jit.trace可以加速模型加载:
example_input = torch.rand(1, 3, 224, 224) traced_model = torch.jit.trace(model, example_input) torch.jit.save(traced_model, 'traced_model.pt')

需要专业的网站建设服务?

联系我们获取免费的网站建设咨询和方案报价,让我们帮助您实现业务目标

立即咨询