深度学习系列之第七课卷积神经网络_CNN_调整学习率
2026/4/3 13:49:30 网站建设 项目流程

目录

简介

一、调整学习率

1.有序调整学习率

1.1StepLR(等间隔调整学习率)

1.2MultiStepLR(多间隔调整学习率)

1.3 ExponentialLR (指数衰减调整学习率)

1.4CosineAnnealing (余弦退火函数调整学习率)

2.自适应调整

2.1ReduceLROnPlateau (根据指标调整学习率)

3.自定义调整

3.1LambdaLR (自定义调整学习率)

二、代码分析

1. 导入必要的库

2. 数据预处理部分

3. 自定义数据集类

4. 数据加载器

5. 设备配置

6. 定义 CNN 模型

7. 训练函数

8. 测试函数

9. 训练配置和执行

简介

之前我们对数据进行增强、有保存和使用最佳模型,今天我们再对模型进行最后的优化,就是调整我们的学习率,在这之前我们一直使用的是固定的学习率来训练模型。

深度学习系列之第五课卷积神经网络_CNN_如何训练自己的数据集(暨食物分类案例)

[深度学习之第六课卷积神经网络 (CNN)如何保存和使用最优模型][_CNN 1]

一、调整学习率

Pytorch学习率调整策略通过 torch.optim.lr_sheduler 接口实现。并提供3种调整方法:

(1)有序调整:等间隔调整(Step),多间隔调整(MultiStep),指数衰减(Exponential),余弦退火(CosineAnnealing);

(2)自适应调整:依训练状况伺机而变,通过监测某个指标的变化情况(loss、accuracy),当该指标不怎么变化时,就是调整学习率的时机(ReduceLROnPlateau); (

(3)自定义调整:通过自定义关于epoch的lambda函数调整学习率(LambdaLR)。

1.有序调整学习率

1.1StepLR(等间隔调整学习率)
torch.optim.lr_scheduler.StepLR(optimizer,step_size,gamma=0.1)

参数:

optimizer: 神经网络训练中使用的优化器,如optimizer=torch.optim.Adam(…)

step_size(int): 学习率下降间隔数,单位是epoch,而不是iteration.

gamma(float):学习率调整倍数,默认为0.1 每训练step_size个epoch,学习率调整为lr=lr*gamma.

1.2MultiStepLR(多间隔调整学习率)
torch.optim.lr_shceduler.MultiStepLR(optimizer,milestones,gamma=0.1)

参数:

milestone(list): 一个列表参数,表示多个学习率需要调整的epoch值,如milestones=[10, 30, 80].

1.3 ExponentialLR (指数衰减调整学习率)
torch.optim.lr_scheduler.ExponentialLR(optimizer,gamma)

参数:

gamma(float):学习率调整倍数的底数,指数为epoch,初始值我lr, 倍数为

1.4CosineAnnealing (余弦退火函数调整学习率)
torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,T_max,eta_min=0)

参数:

Tmax(int):学习率下降到最小值时的epoch数,即当epoch=T_max时,学习率下降到余弦函数最小值,当epoch>T_max时,学习率将增大;

etamin: 学习率调整的最小值,即epoch=Tmax时,lrmin=etamin, 默认为0.

2.自适应调整

2.1ReduceLROnPlateau (根据指标调整学习率)
torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,mode='min',factor=0.1,patience=10,verbose=False,threshold=0.0001,threshold_mode='rel',cooldown=0,min_lr=0,eps=1e-08)

3.自定义调整

3.1LambdaLR (自定义调整学习率)
torch.optim.lr_scheduler.LambdaLR(optimizer,lr_lambda)

参数:

lr_lambda(function or list): 自定义计算学习率调整倍数的函数,通常时epoch的函数,当有多个参数组时,设为list.

二、代码分析

1. 导入必要的库
importtorchfromtorch.utils.dataimportDataLoader,DatasetfromPILimportImagefromtorchvisionimporttra

需要专业的网站建设服务?

联系我们获取免费的网站建设咨询和方案报价,让我们帮助您实现业务目标

立即咨询