APP下载

SqueezeNet和动态网络手术的脱机手写汉字识别

2021-03-21周於川谭钦红奚川龙

小型微型计算机系统 2021年3期
关键词:剪枝手写准确率

周於川,谭钦红,奚川龙

(重庆邮电大学 通信与信息工程学院,重庆 400065)

1 引 言

脱机手写汉字识别在50年以来的研究中,在票据自动识别、手写汉字录入、自动化教学办公等领域具有重要意义;相较于于打印字,人们有着风格迥异的字体,更难于提取和识别特征.传统方法中,MQDF[1]和DLQDF[2]在CASIA[3]数据集上有好的效果,达到近93%的准确率,但传统方法逐渐达到瓶颈.基于卷积神经网络(CNN)的模型在模式识别方面效果很好,用于脱机手写汉字识别中的CNN模型也有许多,基于CNN的Fujitsu[4]模型在测试集ICDAR-2013以94.77%准确率获得冠军;轮换训练松弛卷积神经网络(ATR-CNN)[5]最新模型达到3.94%的错误率进一步缩小机器和人眼识别差距;HCCR-Gabor-GoogLeNet(HEGL)[6]在Fujitsu基础上修改后准确率达到96.58%,HEGL在损失部分速度和存储情况下达到96.74%的准确率;基于ResNet的倾斜校正网络[7]更是达到了98.4%准确率.

尽管基于CNN模型的手写汉字识别在准确率上取得了很大提升,但是其运算资源、功耗和存储空间要求大、参数多、训练复杂、难于进行分布式训练;很难将相应模型部署于硬件资源有限的ARM板和FPGA等嵌入式平台中.本文为实现有限资源条件下的手写汉字识别,在保证模型预测性能良好的情况下,尽可能减小模型的体量.

压缩CNN模型体积常用方法[8]有5类,分别是网络剪枝、参数共享、量化、网络蒸馏和紧凑网络设计,都可以得到明显的压缩效果.其中紧凑网络改进了网络参数量和计算量较多的卷积,Iandola等提出SqueezeNet[9],旷视科技提出的ShuffleNet[10],谷歌团队的MobileNet[11],以及Francois等提出的Xception[12]都是在卷积设计上做了相关工作.

其中SqueezeNet将FireModule引入AlexNet卷积模型,在保证准确率较好的情况下,将模型压缩50倍,并成功应用于嵌入式平台.本文对脱机手写汉字模型压缩进行研究,将SqueezeNet模型修改后,加入动态网络手术(Dynamic Network Surgery)[13]对模型参数压缩,包含裁剪和修复,压缩参数同时保证模型准确率.

2 SqueezeNet卷积神经网络模型

SqueezeNet是基于AlexNet的卷积神经网络模型,设计更少参数的CNN模型,减少大量参数的同时,仍然拥有接近AlexNet网络的准确率.

SqueezeNet核心在于FireModule,小卷积核代替部分大卷积核,当分别用5*5和3*3卷积核对5*5*1图像进行卷积,前者产生25个参数,25次计算,后者会产生18个参数,90次计算,但计算机读取内存的速度远慢于乘法计算,参数量少的小卷积核卷积速度更快,故文中采用1*1替代部分3*3会加快卷积速度,剩余3*3卷积核保证收敛速度.

如图1所示,分为Squeeze层和Expand层,其中Squeeze层是S个1*1卷积核的卷积层,Expand层是e1个1*1和e2个3*3卷积核的卷积层,其激活层都是ReLU.其中FireMoudle输入特征图大小为H*W*M,输出特征图大小为H*W*(e1+e2),变化的仅是维数,并未改变其分辨率.首先H*W*M的特征图经过Squeeze层,得到S个特征图,S均是小于M,起到压缩效果;在Expand层,H*W*S分别用e1个1*1卷积核和e2个3*3卷积核进行卷积,并将两部分卷积结果进行合并,得到H*W*(e1+e2)大小的输出结果,所选取e1+e2的值要求大于M,因此FireMoudle增加了输入的维数.其中S、e1、e2是可调参数,都是代表卷积核个数,也反映输出特征图的维数,文中取e1=e2=4S.

此外,模块中采用下采样操作来保证卷积层具有更大的激活函数,在有限网络参数的条件下保证模型精度.

3 改进SqueezeNet模型

3.1 网络结构

如图2所示,SqueezeNet网络结构设计思想与传统卷积神经网络结构类似,通过堆叠卷积操作来实现,只是SqueezeNet堆叠的是FireMoudle.

本文改进的SqueezeNet模型,在原模型的基础上进行3个部分的改进:1)将最大池化层加入下层FireMoudle层进行融合,改善小卷积核的过拟合问题,这个过程中保证最大池化

图2 原始和改进的SqueezeNet网络结构设计Fig.2 Original and improved SqueezeNet structure design

层特征图和融合的FireMoudle特征图大小匹配;2)针对FireMoudle层的特征图参数,采用动态压缩网络手术算法动态连接修剪、降低网络复杂度;3)采用L2范数约束[14]的Softmax代替原先的Softmax进行分类,通过正则化来实现更好的约束效果.模型参数见表1.

表1 改进SqueezeNet模型参数Table 1 Improved SqueezeNet module parameters

3.2 动态网络手术

常用的模型参数裁剪算法是通过阈值来删除不重要的参数来压缩CNN模型,但是参数重要性往往伴随着网络性能而变化,也就导致两个常见问题:1)有可能将重要的参数删除,降低模型精度;2)时间很长,收敛过慢.动态网络手术压缩模型,对参数进行调整,其流程采取剪枝和拼接结合、训练与压缩同步的方式,在减少大量参数的同时保证精度.此模型包含两部分,即剪枝和拼接,如图3所示.其中,剪枝是压缩网络模型;拼接是为了弥补在剪枝不正确而造成的精度损失,对不正确的剪枝进行恢复拼接.不仅提高学习效率,而且更好接近压缩极限.对于问题2,通过两个方式来加快训练速度:1)降低参数的删除概率,提高收敛速度;2)将FireMoudle和卷积层分开进行参数裁剪.

图3 动态网络手术策略Fig.3 Dynamic network surgery strategy

式(1)表明网络的损失函数:

(1)

L(Wk⊙Tk)是网络损失函数,⊙代表是矩阵哈达玛乘积;hk(w)是分类函数,判定重要就为1,否则为0;Tk是0-1矩阵,表明网络的连接状态,是否被剪枝.I代表矩阵Wk中的元素.

分类函数hk(w)如式(2)所示,参数的重要性以权值绝对值为基础,设置ak,bk2个阈值,其中bk=ak+Tk.

(2)

Wk和Tk确定后,通过式(3)来更新Wk的值,其中β为正向学习效率.式(3)不仅更新重要的参数,而且更新已被认定为不重要或对减少损失函数无效的参数,即对Tk中已被定为0的参数依然进行更新.

(3)

算法中剪枝和拼接是不断循环的过程,通过不断更改连接的权重Wk和Tk的值来实现,直到迭代次数iter达到预设值.动态网络算法步骤如表2所示.

3.3 融合算法

最大池化层和下层的FireMoudle层相融合[15],不仅改善小卷积核的过拟合问题;而且底层特征分辨率更高含有更多位置、细节信息,但噪声很多,而高层特征分辨率低,但是对细节感知能力差;将高层特征和底层特征进行融合会提高对小目标(手写汉字中的点)的检测效果;前层学习的特征映射可以被后层访问,整个网络公用一部分特性,使模型更紧凑.

融合方法[16]体现于图2,将池化层得到的特征图和后面

表2 动态网络手术算法步骤Table 2 Dynamic network surgery algorithm procedure FireMoudle得到的特征图进行融合,得到新的特征图,算法如式(4)所示,将池化层提取得到的特征图和其后FireMoudle得到的特征图进行融合,得到新的特征图.

(4)

其中n,i,j分别代表新特征图个数,池化层所提取的特征图个数和FireMoudle处理后的特征图个数.

3.4 L2约束的Softmax分类

Softmax对于给定的测试输入,通过假设函数针对估算出每一个类别概率值并归一化处理,得到类别的归一化概率值,如式(5)所示;在模式识别任务中,可以有效分离多个类别并且容易实现;但是也有明显缺点:1)如果类别过多,那么会出现匹配问题;2)受限于最大化条件概率的处理方式,其更适用于高质量图像,不适用于困难罕见图像.

图4 Softmax和L2-Softmax在mnist数据集特征 分布情况对比Fig.4 Comparison of Softmax and L2-Softmax feature distributions in the mnist dataset

当限制最后的隐藏层输出为2时,实现特征可视化,得到图4从左到右为Softmax和L2-Softmax在mnist数据集上得到的特征分布情况,L2-Softmax准确率要高于Softmax.

由于本文手写汉字识别类别较多,故本文采用L2范数约束的Softmax进行分类,加上范数约束条件后,同一类别图像在归一化特征空间更接近彼此,不同类别图像距离更远,给样本平均化的关注,可以很好地处理到质量较差的样本.

(5)

式(6)为L2-Softmax类别概率值归一化处理,其中f(xi)是规模为M的一张输入图像,yi表示第i个目标的类别描述,只有一个元素为1,f(xi)是最后的全连接层之前的d维特征描述量,C是类别的数量,W和b分别代表网络中可训练的权重和偏差.

(6)

在网络中实现L2约束如图5,Softmax直接对Softmax损失进行归一化处理得到概率值,而L2-Softmax对Softmax输出前引入L2格式化层和Scale层.其中L2格式化层将输入的特征x归一化为单位向量;Scale层根据给定参数α,将单位向量缩放到固定的半径,鉴于将参数a同其他网络参数同时训练所得值过大,本文直接将a固定为较小常量,效果更好.

图5 Softmax与L2-Softmax网络Fig.5 Softmax and L2-Softmax network

4 实 验

本文选择CASIA-HWDB1.1数据集作为模型训练集,其中的汉字更多变,更难识别,包含3755类汉字;ICDAR-2013竞赛数据集作为测试集.实验环境:操作系统是Ubuntu 18.04,CPU是Intel Core i7-8700K CPU@3.70GHzX12,GPU是NVIDIA GTX1080TI 16G,RAM是DDR4 3200 16G,采用深度学习框架是Tensorflow 1.4.0和Keras 2.1.0,基于python3.6.3.

4.1 脱机手写汉字集预处理

数据集中汉字字迹深浅不一,对识别准确率有影响,对图像进行增强对比度操作.

(7)

式(7)中Imax,Imin分别为原图像的最大、最小灰度像素值,I(x,y)为原图像像素值,D(x,y)为目标图像像素值.

图片尺寸过大会增加网络负担,过小会降低识别性能,通过最近邻插值法将汉字图像归一化为56×56大小.

结合梯度特征可以提高手写汉字识别的有效性和准确率[17].从0,π/4,π/2,3π/4,π,5π/4,3π/2,7π/4这8个方向提取手写汉字特征,可以涵盖汉字的横、竖、撇、捺等笔画.通过sobel算子得到水平和垂直方向的梯度,再根据平行四边形分解原 则得到八个方向的特征图,最后进行叠加得到平均梯度图像.

图6左上角为原始图像“的”,中间为图像增强处理后的图像,右上角为梯度图像叠加后的平均梯度图像,后面8幅图像为对应方向的梯度图像.

图6 对“的”预处理Fig.6 Pretreatment of “的”

4.2 在CASIA-HWDB1.1数据集中实验

本数据集中每个汉字大概含有300个样本,共计1121749个汉字,分为训练集和验证集;训练集中每个汉字包含250个样本,验证集中每个汉字包含50个样本.测试集ICDAR-2013依然采用32×32的输入尺寸.

表3 超参数设置Table 3 Improved SqueezeNet module parameters

对卷积神经网络超参数进行设置,如表3所示,其中FireMoudle中设置压缩比为0.5,3×3的filter个数占总个数比例为0.25.

4.3 实验结果与分析

表4中展示几种典型方法在ICDAR-2013数据集上的识别效果,MQDF-HIT和MQDF-THU是通过提取灰度化后字符图像的特征向量后,采用级联的MQDF分类器分类.CNN-Fujitsu作为ICDAR-2013汉字识别的冠军模型,根据4个CNN模型投票来产生最终输出结果.ATR-CNN采用松弛卷积神经网络识别手写汉字,即通过改变传统卷积中一个特征图内共享卷积核策略.HEG是通过10个改进后的GoogLeNet的投票结果来产生最终结果.

由表可得,本文所提模型的识别准确率和模型体积量都比以MQDF为代表的传统手写汉字识别更有优势;比卷积神经网络AlexNet、CNN-Fujitsu和ATR-CNN有小幅度准确率提升并降低了模型体积;其准确率仅低于Skew Correction ResNet的98.4%,但有更小的体积.

对比SqueezeNet模型直接剪枝后的结果,虽然模型体积变得很小,同时准确率降低较多;本文中采用的动态网络手术来剪枝并拼接被误删重要的参数,虽模型体积相对直接剪枝更大一些,达到了3.2MB,但准确率得到显著提高,达到96.03%;最后对输入图片进行增强和梯度提取后,在预处理输入的基础下,得到模型的准确率提高到96.32%.

表4 改进SqueezeNet模型参数Table 4 Improved SqueezeNet module parameters

注:SN表示SqueezeNet,SCR表示Skew-correction-ResNet,DNS表示动态网络手术,HEG表示HCCR-Ensemble-GoogLeNet

5 总 结

本文提出的模型是在卷积神经网络SqueezeNet的基础上,引入动态网络手术降低参数输入,加快训练和收敛并在保证精度的情况进行合理剪枝,进一步减少参数,将模型压缩到了3.2MB,采用L2约束Softmax分类函数加速汉字分类的收敛,速度得到提升,达到很好的效果.但是模型的准确率对比最新的模型偏低2.37%,后面会加入HWDB1.0训练集,提高准确率;模型在计算机上已经有较好的压缩和识别效果,故下一步将本模型部署在ARM板或FPGA等硬件资源限制有限的平台,实现对脱机手写汉字的识别,并评估准确率和速率等性能.

猜你喜欢

剪枝手写准确率
基于梯度追踪的结构化剪枝算法
基于YOLOv4模型剪枝的番茄缺陷在线检测
工业场景下基于秩信息对YOLOv4的剪枝
我手写我心
乳腺超声检查诊断乳腺肿瘤的特异度及准确率分析
多层螺旋CT技术诊断急性阑尾炎的效果及准确率分析
不同序列磁共振成像诊断脊柱损伤的临床准确率比较探讨
颈椎病患者使用X线平片和CT影像诊断的临床准确率比照观察
我手写我意
剪枝