APP下载

基于关键点估计的抓取检测算法

2022-03-02关立文孙鑫磊

计算机工程与应用 2022年4期
关键词:中心点关键点损失

关立文,孙鑫磊,杨 佩

1.清华大学 机械工程系,北京100084

2.电子科技大学 机械与电气工程学院,成都611731

虽然抓取物体对于人类而言非常简单,但是对于机器人而言,可靠地抓取任意物体仍然是非常具有难度的。解决这个问题可以促进机器人在工业领域的应用,如零件组装、分拣、装箱等,同时也能够推进服务机器人的发展,通过增强机器人与周围环境的交互,满足人类的需求。机器人抓取涉及到物体感知、路径规划以及控制。得到要抓取的对象的位置以及对应的抓取姿态对于一个成功的抓取是非常重要的。物体的几何外形是决定抓取位置的主要因素,可以通过引入机器视觉来增加抓取检测网络的泛化能力。描述一个抓取的参数主要有三个,分别是抓取点的坐标、抓取手爪张开的宽度以及抓取的旋转角度。

抓取检测算法的研究在20 世纪80 年代就开始了,但早期的研究主要是针对抓取点的检测,无法提供一个准确的抓取描述。直到2011 年Jiang 等[1]提出了一种抓取矩形框的表示方法,如图1(a)所示,抓取矩形框由一个五维向量g=(x,y,w,h,θ)来表示,其中(x,y)表示抓取点在图像中的坐标,h表示平行爪夹持器张开的宽度,w表示手爪的宽度,θ表示抓取角度。用这样的方法表示抓取,那么抓取检测就可以将一个在空间中寻找抓取姿态与抓取点的问题,转换成在一个包含待抓取目标的图像中检测抓取矩形框的问题。针对这一问题,使用深度学习在图像上学习特征的方法在抓取检测中获得了很好的效果。

Lenz等[2]率先采用深度学习方法提取特征,使用基于滑动窗口检测的框架同时使用支持向量机(support vector machine,SVM)作为分类器,预测输入图像中是否存在合适的抓取位置,这种方法在康奈尔抓取数据集[3]上达到了73.9%的准确率。但是由于采用滑动窗口的方法导致在遍历可能存在抓取时消耗大量时间。

Redmon等[4]抛弃了滑动窗口机制,将整幅图像划分成N×N个单元网络,使用AlexNet 网络[5]直接在每个单元格中回归抓取框的参数以及可行抓取的概率,取其中概率最高的作为预测结果。这种方法在相同的数据集中达到了88.0%的准确率。

Guo 等[6]将参考矩形框引入到抓取检测当中,这是一种无向锚框,如图1(b)所示,这些参考矩形是图像在每个特定大小的区域中生成的具有相同面积、不同长宽比的矩形框。在他的研究中,并没有直接通过深度学习的方法检测抓取方向,而是通过一种融合视觉感知与触觉感知的模型预测可抓取性、抓取手爪的张开宽度以及抓取的方向。

Chu等[7]利用深度学习的方法来检测抓取框的位置以及抓取方向,他们使用了与Guo文章中相同的参考矩形框,用来回归抓取矩形框,同时将抓取角度看作抓取的语义信息,将角度按照不同的区间分成不同的类。网络检测时对回归的抓取矩形框旋转预测得到的角度类别所对应的角度,如图1(c),得到一个有向抓取矩形框的检测结果。

图1 基于锚框的抓取表示Fig.1 Anchor-based grasping representation

以上几个研究都是基于锚框的抓取检测算法。基于锚框的检测算法检测速度较慢,同时锚框的设计也影响着网络的性能。另外在Guo 与Chu 的研究中对水平抓取框与抓取角度分别进行检测,忽略了抓取角度是抓取框的几何属性而非语义属性的事实,这样做会导致抓取检测的准确度下降。因此本文提出了一种更加简单高效的方法,如图2 所示,用抓取框的中心点来表示一个抓取,同时在中心点处直接预测抓取检测框的尺寸和角度。这里抓取检测问题被简化成了一个关键点检测问题,本文将图片输入到一个全卷积网络中得到一张抓取热力图,在抓取热力图中的局部峰值就对应抓取检测框中心点的位置。同时在特征图上中心点对应的位置会预测抓取框的尺寸和角度。该模型在康奈尔抓取数据集上使用GTX1080TI显卡运行,达到了97.6%的准确率,并且达到了42 frame/s的检测速度,满足检测实时性的要求。本文主要做了以下工作:

图2 基于关键点的抓取框表示Fig.2 Key-point-based grasping representation

(1)设计了一种特征融合方法B-FPN,可以通过权重融合不同阶段的特征图,减少特征的丢失。

(2)设计了一个基于关键点估计的抓取检测网络,直接在特征图上预测抓取中心点位置以及抓取尺寸与抓取角度。

(3)使用了一种新的损失函数,能够在不增加模型复杂度的情况下避免由于正负样本不均衡带来的预测准确度下降。

1 方法介绍

1.1 目标检测的类别与特点

目标检测算法主要可以分为单阶段与两阶段两种类型,目前主流的两阶段目标检测算法以R-CNN(region-based convolutional neural networks)系列为主,比较成功的有Faster R-CNN[8]、Mask R-CNN[9]等。两阶段的目标检测算法首先通过一次粗回归得到ROI(region of interest)作为候选框样本输入到卷积神经网络中,通过精回归得到对应的目标检测框。R-CNN 系列算法虽然在性能上有比较大的提升,但是由于其在训练网络时的正负样本由传统算法生成,这限制了算法的检测速度。

以R-CNN算法为代表的两阶段检测算法由于RPN结构的存在,虽然检测的精度越来越高,但是检测速度却很难达到实时检测的需求。因此,研究人员提出了基于回归的单阶段目标检测算法。以YOLO(you only look once)系列为例,YOLO算法经历了从早期的YOLO[10]到YOLOv2[11]再到后来的YOLOv3[12],算法的准确率不断提高。在YOLOv3中也引入了anchor机制,并采用特征金字塔结构增强网络对多尺度目标的检测能力。

1.2 CenterNet目标检测

CenterNet[13]目标检测算法不同于R-CNN、YOLOv3[12]、SSD[14]等基于锚框的检测算法,它利用关键点估计的思想,通过检测目标框的中心点,然后回归检测框的其他属性,比如尺寸、姿态等,如图3所示边界框的尺寸与其他对象属性是从中心的关键点特征判断出来的,中心点以彩色显示。相比基于锚框的检测算法,CenterNet 的模型是端到端的,因此它更加简单、更加准确,检测速度也更快,其与不同算法的比较如图4所示。

图3 利用边界框的中心点建模Fig.3 Modelling object as center point of bounding box

图4 不同检测方法在COCO数据集上的速度-精度曲线图Fig.4 Speed-accuracy trade-off on COCO validation for different detectors

CenterNet 以目标的中心点来表示目标的位置,然后在特征图上中心点的位置回归出目标的其他属性,这样一来就将目标检测问题转换成一个关键点估计的问题。将图像传入到一个全卷积网络中,网络会输出一个热力图,热力图中峰值点的位置就是图像中目标的中心点位置,同时特征图上每一个峰值点的位置都会预测目标的尺寸信息。整个网络采用监督学习的方式来训练,并且不需要对检测结果进行附加的后处理操作。

2 算法原理

本文算法的整体框架如图5所示,本章主要从网络结构、损失函数以及训练策略三方面对基于关键点估计的抓取检测算法进行介绍。

图5 基于关键点估计的抓取检测算法框架Fig.5 Grasping detection algorithm based on key point estimation

2.1 网络结构

在CenterNet论文中采用Resnet101[15]作为特征提取网络,在上采样阶段,先用3×3 的深度可分离卷积改变图像的通道数,然后使用转置卷积进行上采样。最后得到相当于输入图像4倍下采样大小的特征图,相比于传统目标检测算法使用16 倍下采样作为特征图,较大的特征图更适合关键点估计。但是这里用到的特征图原论文中只使用了最后4 倍下采样的特征图进行目标检测,这会导致图像的一些特征丢失。为了充分利用卷积阶段的各个特征图,本文将Resnet101网络产生的4个特征图使用特征金字塔(feature pyramid networks,FPN)[16]的特征图融合方法进行融合。

但是常规的特征金字塔融合方法是直接将各个特征图进行融合的,没有考虑到不同的特征图对最后的目标检测性能的区别。研究表明,各个阶段的特征图对于最后融合的特征图的贡献是不同的,Tan 等[17]在EfficientDet 中提出了一种对各个特征图的加权特征融合方法BiFPN。在对特征图进行融合时,对每个输入的特征图增加一个权重,这个权重是可学习的,这样网络能够在训练过程中学习到特征图融合的权重,改变各个特征图对最后目标检测性能的贡献。

本文采用了快速标准化特征融合的方法来进行特征图的权重融合,其表达式如下:

在每一个wi后接一个Relu 函数来保证wi≥0 。其中ε=0.000 1,可以避免分母为0 而导致的数值不稳定。经过标准化之后每一个权重都落在了0 到1 之间,然后对不同的特征图Ii进行加权求和,得到的O就是融合特征图。

本文中的特征融合网络B-FPN的结构如图6所示,图中的圆形表示卷积操作,虚线箭头表示上采样,高层的特征图通过上采样与低层的特征图进行融合,得到最后的特征图输出O。

图6 B-FPN结构Fig.6 B-FPN structure

其计算公式如下:

与CenterNet中的Resnet101特征提取网络相比,本文使用的Resnet101+B-FPN 特征提取网络能够通过添加权重的方式进行特征图融合,减少特征的损失。

网络改进前后的特征提取网络如图7所示。图7(a)为CenterNet 中原本的特征提取网络,图7(b)为改进后的特征提取网络。在网络的上采样阶段在每一个转置卷积前加上了一个3×3 的深度可分离卷积来改变通道数,然后使用转置卷积进行上采样(如图中32 →16 的上采样过程中,黑色虚线箭头表示深度可分离卷积,红色实线箭头表示转置卷积上采样,在16 →8、8 →4 的上采样过程中,使用一个红色虚线箭头代替两个过程)。最后得到相当于输入图像4倍下采样大小的特征图,相比于传统目标检测算法使用16 倍下采样作为特征图,较大的特征图更适合关键点估计。

图7 特征提取网络Fig.7 Feature extraction network

特征图后接4 个通道,分别为关键点检测通道、关键点偏移量预测通道、抓取框尺寸预测通道以及抓取角度的预测通道。在本文的抓取检测算法中,因为不涉及对目标类别的分类,只需要检测是否可抓取,所以在抓取可行性热力图的通道数为1。在关键点偏移量预测通道中网络会预测每个点在x与y方向上的偏移量,因此其通道数为2。在抓取框尺寸预测通道中网络会预测抓取框的尺寸信息,分别w、h,通道数为2。最后在抓取角度预测通道,网络会预测抓取框的抓取角度θ,其通道数为1。

2.2 关键点估计及损失

设I∈RH×W×3为宽为W、高为H的输入图像,网络的输出是利用关键点估计生成的热力图其中R是输出特征图的下采样倍率(即尺寸缩放比例),本文中取4。C表示输出特征图的个数,在本文的抓取检测算法中,C=1,即可抓取类别。对于Ground Truth 的关键点K,其坐标为p∈R2,经过下采样之后在特征图上的位置为。本文通过使用二维高斯核将热力标签分散到热力图中。其中σp为尺度自适应标准差,其值为卷积核大小的,本文使用大小为的高斯核,w表示标注抓取框的宽度。如图8所示为高斯热力分布图。

图8 康奈尔抓取数据集中的物体及其抓取热力图Fig.8 Objects in Cornell grasp dataset and grasping heat map

热力图在训练时的损失函数使用改进的Focal Loss[18],其表达式如下:

其中α、β为超参数,在本实验中选择2 和4,N表示一张图片中关键点的个数。不考虑权重(1-Yxyc)β,可以将上述损失函数转换成以下形式:

当Pt的值比较接近于1 时,(1-Pt)α会比较小,这样损失函数的值也会变小;当Pt的值比较小时,表示当前样本为难分样本,对应的(1-Pt)α会比较大,这样一来网络在训练过程中会更加关注难分样本的分类。(1-Yxyc)β表示负样本的权重项,在传统的Focal Loss 中,对于预测值过高的负样本,网络会用来惩罚损失函数,但是在关键点检测中,期望越接近于中心点的位置其预测值越大,因此这里使用了(1-Yxyc)β权重项,当预测位置越接近中心点,其值就越小,损失函数也会越小。而对于远离中心的预测位置,该项不起作用。

2.3 关键点偏移及损失

因为在对图像进行下采样操作时存在量化操作,这使得Ground Truth的关键点会产生偏移,所以需要对关键点的位置进行回归。本文对每一个关键点的位置进行了局部偏移的预测,对于这个偏移量,使用L1 Loss 来训练,这里只计算关键点处的偏移损失,损失函数表达式如下:

其中,p表示中心点在原图中的坐标,R为图像的缩放尺度,本文取4,为量化操作后的坐标,N为正样本的数量。

2.4 目标尺寸预测及损失

因为在抓取检测算法中,目标框的bounding box不是水平矩形,所以无法用左上右下点的坐标来表示,这里需要用4 个点的坐标来表示bounding box。设表示目标k的bounding box 的4个角点的坐标,那么其中心点的位置为:

同时也可以计算出目标的尺寸信息。根据康奈尔数据标注的特点,参考抓取框的宽为参考抓取框的高为。可以估计出目标尺寸信息:

这里的损失函数也用到了L1 Loss函数,其表达式如下:

其中,为目标k的参考抓取框尺寸,可以表示为(w,h),w表示平行爪夹持器张开的宽度,h表示手爪的宽度,N为正样本的数量。

2.5 抓取角度预测及损失

在抓取检测算法中除了对抓取检测框尺寸进行预测,还需要预测抓取框的角度,本文中的抓取角度θ表示平行爪夹持器在图像平面中的投影与图像水平方向所成的夹角,范围为(0,π)。根据康奈尔抓取数据集上数据标注的规则,前两个点所连线段的方向代表平行爪夹持器手爪张开的方向,即抓取角度,因此可以用θ=来表示。计算角度损失使用的损失函数为L1损失函数,其表达式如下:

其中,为目标k的抓取角度,N表示正样本的数量。

2.6 总损失函数

因为本文没有对目标尺寸做归一化处理,直接选用原始像素的坐标,所以会导致Lsize的值较大。为了平衡损失函数的分布,需要在各个损失函数前添加权重,其表达式如下:

参考了文献[12]中的权重设置,在实验中,使用λsize=0.1,λoff=1,λtheta=1。

2.7 网络的预测

本文使用的输入图像大小为512×512,经过4 倍下采样操作之后特征图的大小为128×128,网络会在特征图的每一个位置预测6个值,分别为特征图每个点处的关键点热力图,偏移量δx、δy,尺寸预测值w、h,抓取角度预测量θ。

在抓取热力图中将所有点与其8 邻域内的所有点的预测值做比较,如果该点的值大于或等于其他8个邻近点则保留,最后保留满足之前所有要求的前100个峰值点。设为抓取热力图中检测到的n个预测关键点的集合,。其中每一个关键点的坐标都是以整数形式(i,)给出的,因此最后生成的抓取检测框的表示形式为:

其中,(δi,)表示关键点位置偏移量预测,表示抓取框的尺寸预测,表示抓取角度的预测值。所有的预测输出都是通过点估计直接产生的,不需要进行非极大值抑制或其他后处理操作。

3 实验结果及分析

3.1 实验条件

本文实验使用的操作系统是ubuntu16.04,处理器的型号为Intel®CoreTMi7-8700K,显卡型号为NVIDIA GeForce®GTX 1080Ti,采用NVIDIA CUDA9.0 加速工具箱。

3.2 模型测试

为了对新设计的Resnet101+B-FPN 模型的性能进行测试,在Pascal VOC数据集上进行实验。Pascal VOC是一个常用的目标检测数据集,其中包含20 个类别的16 551 张训练图片以及4 962 张测试图片。本文使用IoU阈值为0.5时的mAP作为评价指标。比较CenterNet在分别使用Resnet101和Resnet101+B-FPN时网络的表现。在实验中分别采用了两种分辨率的输入384×384,512×512。两个网络采用相同的训练策略,批大小为32,初始学习率设置为1.25E-4,总共迭代70次,其中在迭代次数达到45 和60 时将学习率减小为原来的1/10。实验结果如表1所示。

表1 Pascal VOC数据集上目标检测的结果Table 1 Object detection results on Pascal VOC dataset

从表1中可以看到,加入B-FPN之后网络的mAP提高了,并且在大分辨率的图片上mAP提高得更加明显,由此可见使用了特征融合方法融合特征图之后可以增强网络的性能。在后续的抓取检测网络中,用到的特征提取网络部分为Resnet101+B-FPN。

3.3 训练

本实验在康奈尔抓取数据集上进行训练和测试,该数据集包含240 个可抓取物品的885 张图片,在每张图片中,可行的抓取被表示成抓取矩形框。在训练时将数据集按照4∶1的比例划分成训练集与测试集。

在训练集上进行算法模型的训练,512×512的图像输入到网络当中,模型输出大小为128×128。在训练过程中,对训练集中的图像进行随机翻转、随机旋转、随机缩放、随机裁剪和色彩抖动等方法来进行数据增强,算法模型的优化器选择Adam。输入数据的批量大小为16,迭代步数为140,初始学习率设定为0.001,在训练步数达到60和80时,学习率减小为原来的1/10。

训练过程中损失函数的收敛趋势如图9所示。图9(a)表示抓取检测总损失函数(Ldet)的收敛趋势,图9(b)表示关键点估计损失函数(Lhm)的收敛趋势,图9(c)表示目标尺寸大小损失函数(Lsize)的收敛趋势,图9(d)表示目标中心点偏移损失函数(Loff)的收敛趋势,图9(e)表示角度预测损失函数(Ltheta)的收敛趋势。

图9 训练阶段损失Fig.9 Loss in training stage

3.4 评价指标及测试

测试时本文采用的仍然是康奈尔抓取数据集来评估抓取检测的性能。本实验采用Zhang等人[22]提出的评价指标,当抓取检测框满足以下两个条件时认为是一个正确的预测:

(1)预测抓取框与参考抓取框之间抓取角度的差值小于30°。

(2)Jaccard相似系数大于25%,其中Jaccard相似系数计算公式如下:

其中,g表示预测抓取框所围成的区域,表示参考抓取框所围成的区域,表示两个区域相交部分的面积,表示两个区域覆盖部分的面积。

模型在康奈尔抓取数据集中的验证集上的检测结果如图10所示。第一行为康奈尔抓取数据集中的物体,第二行为预测抓取热力图,第三行为抓取框的检测结果。

图10 抓取检测结果Fig.10 Results of grasp detection

表2 是本文网络模型与其他抓取网络检测模型的准确率与检测速度的对比表。

由表2可以发现,本文的模型在准确率与检测速度上都有提升,使用了无锚框的网络结构使检测速度更加快,同时使网络具有更强的鲁棒性。

表2 抓取位置检测算法准确率对比Table 2 Accuracy for grasping detection

4 结束语

为了加快抓取检测网络的检测速度以及增强对抓取角度的检测能力,本文提出了一种新的检测方法,通过关键点估计的方法检测抓取框的中心点位置,并且预测抓取的角度、尺寸等信息。一方面CenterNet 端到端的网络结构以及无需后处理的网络特点,能够加快网络的检测速度;另一方面高分辨率的特征图的输入能够检测到更多的可行抓取框。同时,基于关键点的检测思路与抓取检测任务更加匹配,在抓取检测时从抓取点去回归抓取框,相较于基于锚框的抓取检测算法能够更好地得到抓取轮廓,获得更多的可行抓取。在康奈尔抓取数据集中,相比于基于锚框的抓取检测算法,基于关键点估计的检测模型能够在保证较高准确率的同时达到较快的检测速度。实验结果显示本文的模型在验证集上有97.6%的准确率,并且能够达到42 frame/s的检测速度。

猜你喜欢

中心点关键点损失
肉兔育肥抓好七个关键点
建筑设计中的防火技术关键点
胖胖损失了多少元
一种基于标准差的K-medoids聚类算法
Scratch 3.9更新了什么?
如何设置造型中心点?
玉米抽穗前倒伏怎么办?怎么减少损失?
菜烧好了应该尽量马上吃
寻找视觉中心点
机械能守恒定律应用的关键点