基于强化学习的超参数优化方法
2020-04-10陈森朋陈修云
陈森朋,吴 佳,陈修云
(电子科技大学 信息与软件工程学院,成都 610054)
1 引 言
近年来,机器学习算法已成功应用于众多领域,但同时也面临着巨大挑战.诸如随机森林(Random Forest)[1]、XGBoost[2]和支持向量机(Support Vector Machines)[3]等机器学习算法在实际应用的过程中存在繁琐的超参数优化过程.
超参数优化对机器学习算法的性能起着至关重要的作用,然而机器学习算法的性能和超参数之间的函数关系尚不明确.在实际应用中,往往通过不断调整超参数的值来提高机器学习算法的实践性能.当机器学习算法的超参数空间较大时,优化过程将非常耗时和低效.因此,超参数优化成为了机器学习算法应用中的难点之一.
针对上述问题,本文提出了一种基于强化学习的超参数优化方法(图1).该方法将超参数优化问题抽象为序列决策过程,即分步选择待优化算法的超参数,这样超参数选择过程可建模为马尔科夫决策过程(Markov Decision Process-MDP),进而采用强化学习来求解.具体的,该方法利用长短时记忆神经网络(Long Short-Term Memory Neural Network,LSTM)[4]构建一个智能体(agent)来代替算法使用者设置超参数的值;然后,agent在训练集上训练算法模型,并在验证数据集上得到该算法模型的验证集性能,并以此为奖赏信号,利用策略梯度算法(Policy Gradient)[5]优化agent的决策.
本文结构如下:第2节介绍了超参数优化问题的定义及相关工作;第3节详细描述了本文所提出的超参数优化方法以及如何减小训练方差;第4节针对两个具有代表性的机器学习算法,将本文所提出的方法与五种常用超参数优化方法进行对比,并且讨论了agent结构和数据引导池的有效性;第5节总结全文并展望未来工作.
图1 基于强化学习的超参数优化方法Fig.1 Hyperparameter optimization method basedon deep reinforcement learning
2 背景及相关工作
超参数优化问题(HPO)的通常定义为:假设一个机器学习算法M有N个超参数,第n个超参数空间为Λn,那么算法的超参数搜索空间为Λ= Λ1×Λ2×…ΛN.Mλ表示超参数为λ的算法,其中向量λ∈Λ为算法M的一个超参数组合.当给定数据集D,HPO问题的优化目标为最优的超参数组合λ*:
λ*=argminE(Dtrain,Dvalid)~DL(Mλ,Dtrain,Dvalid)
(1)
其中,Dtrain和Dvalid分别表示训练集和验证集;L(Mλ,Dtrain,Dvalid)表示算法Mλ在数据集D上的交叉验证误差,以此作为损失函数值.
近年来,具有代表性的超参数优化方法有随机搜索(Random Search)、贝叶斯优化(Bayesian Optimization),TPE(Tree-structured Parzen Estimator)以及自适应协方差矩阵进化策略(CMA-ES)算法.随机搜索方法[6]在超参数搜索空间中随机采样,执行效率高且操作简单,经过多次搜索可以获得性能较好的超参数组合.然而,随机搜索方法稳定性较差,且只有在达到或接近最优值的超参数组合的比重超过5%时,搜索效率较高.自适应协方差矩阵进化策略(CMA-ES)算法[7]是一种基于进化算法的改进算法,主要用来解决非线性、非凸的优化问题,但算法运行具有一定的随机性,优化性能不稳定.贝叶斯优化[8,9]方法使用高斯过程对代理函数进行建模,以一组超参数λ为条件对优化目标y进行建模,形成先验模型P(y|λ).虽然该方法能够达到很好优化结果,但是随着迭代次数增加,优化过程耗费大量时间.文献[10]实验证明了基于高斯过程的贝叶斯优化方法在一些标准任务上优于随机搜索方法.另一种贝叶斯优化的变体是基于序列模型的优化方法(SMAC)[11],该方法使用随机森林对代理函数进行建模.与基于高斯过程的贝叶斯优化方法类似,TPE[12]是一种基于树状结构Parzen密度估计的非标准贝叶斯优化算法,也能达到很好的优化性能.
相比于上述工作,本文的创新点主要有以下几点:
1)将超参数优化问题抽象为序列决策问题并建模为MDP,分步选择超参数,提高优化效率;
2)采用强化学习智能体(agent),并使用策略梯度算法进行训练以避免直接求解超参数优化的黑盒目标函数,从而搜索到最优超参数组合;
3)提出数据引导池技术,降低训练方差,提高方法稳定性.
3 基于强化学习的超参数优化方法
3.1 整体结构
针对超参数优化问题(HPO),本文提出了一种基于强化学习的优化方法.该方法将超参数优化问题抽象为序列决策问题(即每次决策只选择一个超参数)是基于以下原因:
1)一个复杂问题通常通过分解成多个易于求解的子问题来解决.由于一个复杂机器学习算法具有巨大的超参数空间,同时进行所有超参数的选择极具困难.
2)相反的,如果agent分步进行超参数选择,整个搜索空间可大大缩小,从而提高搜索效率.
我们将上述的序列决策过程建模为MDP,即M=(S,A,P,R):
·S表示状态集合,st∈S,st表示t时刻环境的状态,即agent的输入;
·A表示动作集合,at∈A,at表示t时刻的agent选择的动作,即超参数选择;
·P表示在当前状态s下,执行动作a后,环境转移到下一状态的概率.在HPO问题中它是未知的;
·R表示reward函数,R:S×A→R,R表示在当前状态s下执行动作a的奖励值,即为超参数配置的验证集准确度.
Agent的目标是找到一个策略π:S→A使得累积收益最大化.Agent工作流程如下:对每一次迭代,agent以概率P为算法模型选择一组超参数λ;然后在训练数据集Dtrain上训练算法模型Mλ;最后将Mλ在验证数据集Dvalid上的准确率作为奖赏值,并利用策略梯度算法[5]来更新策略.经过多次训练,agent会以更高的概率选择准确率高的超参数配置.为了确保该方法具有更好的稳定性,提出了数据引导池以减小训练方差.
3.2 详细设计
3.2.1 Agent结构设计
根据3.1节,我们将超参数优化问题看作一个序列决策问题,即每个时刻针对某个超参数进行选择,因此不同时刻优化了不同的超参数,这样可以大大减少每次决策的搜索空间.为了更加清晰的说明序列选择超参数的优势,我们将进一步分析超参数优化的搜索空间.假设一个算法具有N个待优化的超参数.一种简单的方法是将超参数优化问题看作一个多臂机问题(multi-armed bandit problem),直接在整个超参数搜索空间中选择整个超参数配置,则决策的搜索空间为:Λ=Λ1×Λ2×…ΛN(×表示笛卡尔乘积).相反,如果我们将超参数优化问题作为序列决策问题,基于前一次决策顺序的选择每一个超参数,则决策的搜索空间为:Λ=Λ1∪Λ2∪…ΛN.显然,后者能够大大缩减超参数优化问题的搜索空间,从而提高优化效率.
为了适应顺序选择超参数的方法,我们将agent设计为自循环的结构.每次循环时,我们将agent上一次的输出作为agent下一次的输入,以保持超参数优化的整体性.同时,由于超参数之间可能存在相关性,也就是每个时刻的选择可能是相互关联的.若只将超参数优化问题分步进行,而不考虑超参数之间的内部关系,超参数的优化顺序则会成为一个影响因素.基于上述特点,我们利用LSTM构造了一个强化学习agent(图2).使用LSTM网络作为agent的核心结构的主要原因在于:LSTM网络独特的内部设计能够使agent保留或遗忘超参数之间的内在联系,从而有利于超参数选择,也避免了由于超参数优化顺序而造成的影响.尽管LSTM 网络的训练比较困难,但是LSTM网络被认为是解决时序问题的最好结构.
图2 Agent结构图Fig.2 Structure of agent
图2展示了agent内部结构,图中左边部分表示agent整体结构,右边部分 (“=”右)表示按时间步展开的agent结构.Agent的核心结构由3层LSTM网络构成,且输入、输出与LSTM网络之间各有一个全连接层,该全连接层用来调整前后输入和输出的维度.在每一时刻t(t∈[1,T],T为待优化模型的超参数个数),agent选择一个超参数at,并将at的one-hot编码作为下一时刻agent的输入,也就是t+1时刻状态st+1为at.在t=1时刻,agent输入状态s1为全1向量.
通过这样的设计,agent在不同时刻只需选择对应的超参数,减小了超参数的搜索空间.同时,由于将前一时刻的输出作为下一时刻的输入,使得采用LSTM网络作为核心结构的agent能够学习超参数之间的潜在关系.
3.2.2 Agent训练
策略梯度方法[5]使用逼近器(函数)来近似表示策略,通过不断计算策略期望的总奖赏并基于梯度来更新策略参数,最终收敛于最优策略.它的优点非常明显:能够直接优化策略的期望总奖赏,并以端对端的方式直接在策略空间中搜索最优策略,省去了繁琐的中间环节.因此,本文采用策略梯度方法训练agent.
假设θ表示agent的模型参数;R表示agent在每次选择超参数组合a1:T后,与所选择的超参数组合结合的待优化模型在验证数据集上的准确率.定义期望的总奖赏值为:
J(θ)=EP(a1:T;θ)[R]
(2)
其中,P(a1:T;θ)表示表示参数为θ的agent输出超参数组合a1:T的概率.
Agent的训练目标是找到一个合理的参数θ使得期望奖赏值J(θ)最大化:
(3)
(4)
(5)
其中,T为待优化算法的超参数个数;Ri为在第i个超参数组合下模型的k-折交叉验证结果;b是基准值,即模型交叉验证结果的指数移动平均值.
3.2.3 数据引导池(Boot Pool)模块
在使用本文所提出的方法进行超参数优化时,虽然添加了基线函数b减小训练误差,但是仍存在训练方差较大的问题,造成其优化结果稳定性较差.为此,我们提出了数据引导池模块.
数据引导池是一个固定大小的存储区域,用来保存最优的K条(top-K)超参数组合及对应奖励值.在agent训练过程中,引导池中的数据会根据新的采样数据进行实时更新,并定期提供给agent进行学习.若K过大,则使得引导过强,陷入前期较差的局部最优值;若K过小,则引导力度变弱,策略更多的进行探索,从而导致训练不稳定.事实上,通过对参数K的调整来平衡策略的利用和探索.
4 实验结果及分析
在实验中,我们将随机森林和XGBoost两种算法作为超参数优化对象,使用UCI数据库中的五个标准数据集作为实验数据集(表1).为了验证本文提出方法的性能,我们将本文所提出的方法与随机搜索优化方法、基于贝叶斯的优化方法、TPE优化方法、CM-AES优化方法和SMAC优化方法进行了对比.此外,通过一系列消融实验来验证agent结构和数据引导池的有效性.
4.1 实验细节
数据集:实验中,我们选择五个大小各异的UCI数据集作为优化任务(详细信息见表1).UCI数据集是常用的、种类丰富的数据集.在实验中,每个数据集按照8:2的比例分成训练集和测试集两部分.实验在训练集下采用5-折交叉验证的方法训练待优化模型;训练完成后,使用测试集测试超参数优化方法的最终性能.
参数设置:在实验中,所有参数均是选择多个随机种子中的最优参数.针对不同的优化任务,我们设置了不同的学习率α和数据引导池大小K(详细信息见表1).基准函数的折扣系数γ设置为0.8.以-0.2与0.2之间的随机值对网络中的权重进行初始化.
搜索空间:实验中我们选择对随机森林(6个超参数)和XGBoost(10个超参数)两种分类算法进行超参数优化(详细信息见表2),随机森林和XGBoost算法的具体实现基于scikit-learn[13].选择上述两种算法进行优化主要是由于:
1)文献[14]中评估了179种机器学习分类算法在UCI数据集上的表现,实验结果表明随机森林分类算法是最优的分类器;XGBoost算法具有更多的待优化超参数,并且解决分类任务具有很大的潜力;
表1 数据集信息及对应参数设置表
Table 1 Data sets information and parameter settings
编号数据集样本量特征数K学习率UCI-1Breast Cancer569280.0007UCI-2Optdigits5,6206480.0008UCI-3Crowdsourced Mapping10,8462880.001UCI-4Letter Recognition20,00016160.001UCI-5HTRU_217,898980.001
2)两种算法均属于先进的分类算法,广泛应用在数据科学竞赛和工业界.
表2 随机森林算法和XGBoost算法的超参数搜索空间
Table 2 Hyperparameters search spaces of the random forest
and the XGBoost
算法超参数范围间隔类型Random Forestn_estimators[100,1200]100intmax_depth[3,30]3intmin_samples_split[0,100]5intmin_samples_leaf[0,100]5intmax_features[0.1,0.9]0.1floatbootstrapTrue,False-boolXGBoostmax_depth[3,25]2intlearning_rate[0.01,0.1]0.01floatn_estimators[100,1200]100intgamma[0.05,1.0]0.01floatmin_child_weight[1,9]2intsubsample[0.5,1.0]0.1floatcolsample_bytree[0.5,1.0]0.1floatcolsample_bylevel[0.5,1.0]0.1floatreg_alpha[0.1,1.0]0.1floatreg_lambda[0.01,1.0]0.01float
4.2 Agent结构的有效性
本小节中,我们将验证agent结构的有效性,即验证将超参数优化问题作为序列决策问题的正确性.实验中,我们所提出的方法简称为BP-Agent,同时也设计了对比方法BP-FC:该方法使用全连接网络(FC)作为agent的核心结构,并且直接使用全连接网络一次输出所有超参数的选择,而不是逐步选择超参数.为了满足对比实验的公平性,我们确保BP-FC方法中的全连接网络的可训练参数的数量与本文提出的方法的可训练参数量大致相等.另外,该方法也采用了引导池技术(BP)来减小训练过程的方差.为充分利用计算资源,我们在UCI-(1-4)数据集上进行对比实验,每组对比实验独立执行3次,每种优化方法每次独立运行300分钟.实验结果如图3和图4所示.图中,分别展示了本文所提出的方法(BP-Agent)和对比方法(BP-FC)在验证集上的训练过程.我们可以看出:BP-FC方法使用全连接网络直接输出所有超参数的选择,在部分任务上具有优化效果,但优化效果较差,并且优化效率低;相比于BP-Agent方法,BP-Agent方法具有更好的优化效果和稳定性,也具有更高的优化效率.因此,上述实验证明将超参数优化问题序列化.并逐步选择超参数的agent设计有利于提高优化性能.
图3 不同agent结构在四个UCI数据集上优化随机森林的性能比较图Fig.3 Performance comparison of agents with different structures for optimizing Random forests on four UCI datasets
图4 不同agent结构在四个UCI数据集上优化XGBoost的性能比较图Fig.4 Performance comparison of agents with different structures for optimizing XGBoost on four UCI datasets
4.3 数据引导池模块对优化结果的影响
为了验证数据引导池的有效性,我们设计了BP-Agent方法(含有BP模块)与Agent方法(不含有BP模块)的对比实验.我们在UCI-(1-4)数据集下对随机森林和XGBoost算法的超参数进行优化,每种优化方法在每个优化任务上独立运行5次,对比5次优化的平均性能.
图5 BP-Agent和Agent方法在四个UCI数据集上优化随机森林的性能比较图Fig.5 Performance comparison of the BP-Agent and the Agent for optimizing Random forests on four UCI datasets
实验结果以箱型图的形式展示,如图5和图6所示.通过观察可以发现:Agent方法能够达到很好优化效果(即箱型图的中位数),但是其稳定性较差(即箱型图的触须);相比于Agent方法,BP-Agent方法具有更好的优化结果,并且其稳定性较好.因此,可以得出以下结论:添加方向引导池能够把握优化方向,增强方法的稳定性.
图6 BP-Agent和Agent方法在四个UCI数据集上优化XGBoost的性能比较图Fig.6 Performance comparison of the BP-Agent and the Agent for optimizing XGBoost on four UCI datasets
4.4 对比BP-Agent方法与其他优化方法
为了进一步验证本文所提出的方法,我们将其与常用的且具有代表性的五种优化方法(随机搜索,TPE,贝叶斯优化,CM-AES,SMAC)进行对比.除此之外,我们也将对比随机森林和XGBoost两个算法默认超参数配置的性能,默认的超参数配置基于scikit-learn[13].实验在UCI-(1-5)数据集上分别优化随机森林和XGBoost两个分类算法的超参数,因此共包含10个优化任务.同样的,为充分利用计算资源,每组对比实
1https://github.com/hyperopt/hyperopt-sklearn
2https://github.com/AIworx-Labs/chocolate
3https://github.com/mlindauer/SMAC3
验独立执行3次,每种优化方法每次独立运行300分钟.随机搜索、TPE和贝叶斯优化三种方法的具体实现基于Hyperopt1,CM-AES方法的具体实现基于Chocolate2,SMAC方法的具体实现基于SMAC33.
对比指标选取的是待优化模型在测试集上的错误率(如表“Err”所示).实验结果以3次对比实验的Err平均值和方差进行展示(详细实验结果见表3),不仅能够表示待优化模型在测试集上的准确度,还能够反映优化方法的稳定性.通过观察表中实验数据,可以看出:所有的优化方法在大部分优化任务上都能得到优于默认参数性能的超参数配置.具体的,在10个优化任务中,贝叶斯优化、CM-AES和SMAC三种优化方法都达到了很好优化结果,且具有很好的稳定性,而随机搜索和TPE两种优化方法的优化性能相对较差;相比之下,BP-Agent方法在8个优化任务中分别达到了最好的优化结果和稳定性.
表3 六种超参数优化方法的性能对比表
Table 3 Performance comparison of five
HPO optimization methods
数据集优化算法随机森林ErrXGBoostErrUCI-1随机搜索0.0774±0.02120.0862±0.0198TPE0.0594±0.01490.0563±0.0101贝叶斯优化0.0507±0.00610.0477±0.0096CM-AES0.0521±0.0050.0473±0.0083SMAC0.0479±0.0160.0561±0.037BP-Agent0.0472±0.00210.0452±0.0019默认参数0.05480.0523UCI-2随机搜索0.0725±0.01800.0443±0.0098TPE0.0562±0.01870.0403±0.0078贝叶斯优化0.0553±0.00220.0419±0.0028CM-AES0.0561±0.00690.0547±0.0041SMAC0.0566±0.00310.0434±0.0047BP-Agent0.0544±0.00150.0393±0.0016默认参数0.08110.0593UCI-3随机搜索0.0187±0.01910.0169±0.0098TPE0.0186±0.01650.0179±0.0078贝叶斯优化0.0169±0.00390.0160±0.0028CM-AES0.0165±0.00570.0167±0.0017SMAC0.0171±0.01030.0154±0.0035BP-Agent0.0160±0.00470.0151±0.0016默认参数0.03700.0214UCI-4随机搜索0.0520±0.07280.0619±0.0250TPE0.1239±0.07020.0570±0.0111贝叶斯优化0.0530±0.02940.0596±0.0045CM-AES0.0473±0.00690.0588±0.0057SMAC0.0471±0.00610.0603±0.0039BP-Agent0.0499±0.00550.0564±0.0028默认参数0.10110.1293UCI-5随机搜索0.0191±0.01020.0204±0.0157TPE0.0196±0.00790.0174±0.0103贝叶斯优化0.0153±0.00810.0162±0.0076CM-AES0.0141±0.00410.0159±0.0044SMAC0.0157±0.00490.0160±0.0053BP-Agent0.0131±0.00360.0128±0.0039默认参数0.02190.0202
另外,我们对实验结果进行统计检验.假设显著性水平α=0.05,检验结果显示:在具有优势的8个优化任务中,BP-Agent的性能提升均具有显著性差异(P<0.05).
上述实验表明本文所提的BP-Agent方法能够得到更好优化结果,且具有最好的稳定性.
4.5 讨论与分析
对于超参数优化问题,当前工作主要分类三类:基础搜索方法[10]、基于采样的方法[15,16]和基于梯度的方法[17-19].虽然当前新方法层出不穷,超参数优化问题仍面临以下难点:
1)优化目标属于黑盒函数.对于给定任务,超参数选择与性能表现之间的函数无法显式表达.
2)搜索空间巨大.由于每种待优化算法都有相应的超参数空间,选择的可能性是指数级的.
3)耗费巨大的资源.当评估所选择的超参数配置时,需要进行完整的训练过程并在测试集上测试最终性能,整个优化过程耗费大量计算资源和时间.
通过实验可以看出,本文所提出的方法能够在大部分任务达到最好的优化结果,并具有很好的稳定性.我们认为主要原因在于:在超参数选择过程中,由于逐个选择超参数,因此每次选择只需针对当前超参数的搜索空间进行探索,而不需要搜索整个超参数空间,这样可以极大地提高搜索效率;同时,我们选择LSTM网络作为agent的核心结构,使agent能够在分步决策过程中学习超参数选择的内在联系;另外,训练过程中添加了数据引导池(BP)模块,在一定程度上平衡了策略的探索和利用,使得优化方法性能更加稳定.
5 结束语
随着机器学习的广泛应用,快速高效的解决超参数优化问题(HPO)越来越重要.针对超参数优化问题(HPO),本文提出了一种基于强化学习的超参数优化方法.该方法将超参数优化问题看作序列决策问题,即将复杂问题分解为多个易于求解的子问题来解决.进一步将该问题抽象为MDP,利用强化学习算法来求解该问题.具体的,以LSTM网络为核心构造agent,逐步为待优化的机器学习算法选择超参数.Agent的动作(action)为超参数的选择;agent的输入,即状态(state)为上一时刻的动作选择;待优化算法在验证数据集上的准确率作为奖赏值(reward).
为了验证所提出方法的有效性,我们选择了五个UCI数据集,分别对随机森林和XGBoost这两种算法的超参数进行优化.通过对比随机搜索、TPE、贝叶斯优化、CM-AES和SMAC五种具有代表性的超参数优化方法,我们发现本文提出的方法在优化结果和稳定性上均优于对比方法.同时,一系列消融实验验证了agent结构和数据引导池的有效性.