基于RNN 模型与LSTM 模型的机器作诗研究*
2021-10-08武丽芬严学勇
武丽芬,严学勇,赵 吉
(1.晋中学院 数学系,山西 晋中 030619;2.中国联通晋中分公司,山西 晋中 030601;3.晋中学院 信息技术与工程系,山西 晋中 030619)
诗歌是人类文学皇冠上的明珠,让机器自动生成诗歌,在人工智能领域极具挑战性,传统的诗歌生成方法主要有Word Salada(词语沙拉)、基于模板和模式的方法、基于遗传算法的方法、基于摘要生成的方法、基于统计机器翻译的方法。这些传统方法非常依赖诗词领域专业知识,需要专家设计大量的人工规则,对生成诗词的格律和质量进行约束,同时迁移能力很差。随着深度学习技术的发展,诗歌生成的研究进入了一个崭新阶段,本文将唐代诗歌作为研究对象,对比了两种递归神经网络模型在唐诗写作方面的效果。
1 模型概述
1.1 循环神经网络(RNN)概述及其结构
循环神经网络(RNN)具有其特有的记忆功能,是除卷积神经网络之外深度学习中使用频率最高的一种网络结构。在传统神经网络中,各个网络层之间是全连接的,各层的各个神经元是独立无连接的,而RNN 网络则能把序列化的、与时间相关的输入数据间的关系表达出来。类似于在一句唐诗中,某个词的出现必定与前几个甚至前几句诗中的词有关,完美状态下这样的循环神经网络可以记忆任意长度的序列化信息,其结构示意图如图1 所示,Xt为输入数据,A 为记忆单元,X0输入到A 后得到的中间信息,再把中间信息输给X1,以此类推。
图1 RNN 结构展开示意图
1.2 长短期记忆网络(LSTM)概述及其结构
LSTM 本质是一种特殊的RNN,其优势在于不仅可以像RNN 一样记忆先前输入的数据特征,还可以选择性遗忘一些对当前模型预测结果作用不大的特征信息。LSTM 是由Hochreiter &Schmidhuber 提出的,之后Alex Graves 对其进行了改进和推广,在处理很多与时序性相关的问题上,LSTM 模型均被证明具备突出的预测能力。RNN 训练参数是采用梯度下降法计算曲线的局部最小值,当略微越过最低点时,在梯度相反方向上会出现较大的梯度,这就会使得出现反向远离要求的局部极小值,进入重复寻求最小值的过程[1]。随着计算次数的增加,梯度很快减小,再次接近这个最小值需要运算很长时间,出现梯度消失现象[2],而LSTM 网络很好地解决了RNN 网络这一缺陷,其结构示意图如图2 所示。
图2 LSTM 示意图
LSTM 是具有循环神经网络模块的链式形式[3],在LSTM 结构中存在3 个门单元,分别是输入门,忘记门,输出门[1]。如当前层输入数据xt进入和上一层输出的中间信息ht-1组合之后通过激活函数(sigmoid)得到一个ft值,这个ft值全部在0 到1 的范围内,ft内将会和上一层输出的Ct-1作乘法操作组合在一起,相当于进行一个有选择的丢弃。Ct是LSTM 网络当中永久维护更新的一个参数,由于每一阶段需要保留和丢弃的信息不一致,故每一阶段都需要得出当前阶段的Ct值。上述信息组合会得到一个it,it为当前层要保留的信息,it与Ct组合将控制参数Ct进行更新得到新的Ct,记为nCt,接着对本层Ct进行最终更新,更新方法为Ct=ft*Ct-1+it*nCt。更新后需要根据输入ht-1和xt来判断输出的状态特征,输入经过sigmoid 层得到判断条件,然后经过tanh 层得到一个-1 到1 之间的向量,该向量与输出门得到的判断条件相乘得到最终该LSTM 单元的输出。最终输出ht=ot*tanh(Ct),其中ot=σ(W0[ht-1,xt]+b0,为当前层输入数据xt和上一层输出的中间数据ht-1组合后得到的尚未有选择保留和遗忘的数据。
2 数据采集与预处理
2.1 数据获取
采用Python 网络爬虫技术在“全唐诗网”(https://www.gushiwen.org)提取目标数据,调用re.compile()方法提前编译提取的正则表达式,避免每次提取时重新编译导致效率降低,利用map()方法进行映射操作,将各自对应位置的元素用“:”连接起来,map 方法返回由全部唐诗字符串构成的列表。
2.2 数据预处理
对爬取到的数据进行预处理,舍弃字符串中包含的“-”“*”“_”等特殊符号,并通过原诗查询予以校正。把唐诗语料库中出现的汉字组成的词汇表导出到映射表,其键为汉字,值为该汉字映射的数值,由映射值构成唐诗数据。
3 模型实现
利用TensorFlow 的tf.contrib.rnn.BasicRNNCell()方法定义一个循环神经网络单元,将该单元堆叠为两层的神经网络。根据输出数据是否为空判断初始化神经网络的状态,利用tf.nn.embedding_lookup()构造输入矩阵,得到神经网络的输出,利用tf.nn.bias_add()将偏置项加到输出矩阵和权重矩阵的乘积上,得到网络最后输出。
3.1 模型训练
采用one-hot 编码转化,利用tf.nn.softmax_cross_entropy_with_logits()传入真实值和预测值(即神经网络的最后输出)计算交叉熵损失,这里得出的交叉熵损失为全部汉字的交叉熵损失,维度较高,因此利用tf.reduce_mean()将损失降维并求均值。再利用tf.train.AdamOptimizer(learning_rate).minimize(loss)进行梯度下降最小化损失。最后将损失均值,最小化损失等op 返回供后续模型调用或训练使用。
3.2 模型写诗
利用tf.nn.softmax()方法将神经网络最后输出转化为各词出现的概率以及神经网络当前状态。
3.3 模型训练与调用
创建名为inference 的模块,在该模块下创建名为tang_poems.py 的编码文件,封装一个名为run_training 的方法调用数据预处理得到的唐诗向量数据、汉字映射表及词汇表。
在预测值汉字转化方法中,需要传入神经网络的最终预测值以及词汇表,根据汉字使用频度,定义是一个长度为6030 的数组,数组元素代表对应词汇表中各汉字出现的概率。由于神经网络训练完成后权重矩阵一定,意味着当输入某个汉字时,神经网络预测出的下一个汉字是固定的。也就是说多次传入某个汉字调用模型写诗,模型所作的诗句是一定的。这显然与作诗初衷不符。若让机器能够创作出不同诗句,则必须在词预测上加入一些随机性的因素,每次预测不选择出现概率最高的汉字,而是将概率映射到一个区间上,在区间上随机采样,输出概率较大的汉字对应区间较大,被采样的概率也较大,但程序也有小概率会选择到其他字。
在实现上则需要将概率表进行升级,对概率表元素重新赋值,使其第i 个元素变为原概率表前i 个元素的和,生成一个服从均匀分布在[0,1)之间的随机值,寻找该随机值在新概率表中的位置,将该位置作为本次预测的结果,恰好实现随机又不完全随机的目的。例如原概率表为[0.1,0.2,0.18,0.25,0.21,0.06],升级后的新概率表为[0.1,0.3,0.48,0.73,0.94,1.0],生成的随机值为0.4,那么其在新概率表中索引到的下标为2,也就是根据随机值0.4 索引到概率区间[0,0.48]上,选取到该区间内的最后一个概率值,而这个概率值又是随机的(原概率表中各汉字对应的概率排列并无规律可言),即本质是在[0,0.48]这个概率区间上随机进行采样。
4 测试
4.1 交叉熵损失观察
系统本质上属于分类问题,因此采用交叉熵构造损失函数。交叉熵公式为:L=-y*logP+-(1-y)*log(1-P)。其中y 为神经网络最后一层的输出值,P 为该值经过softmax 激活后的概率值,概率值越大说明预测的事件越稳定,即熵值越低,而-logP 函数图像恰好拟合了熵值和概率值之间的映射关系,并且每个样本预测值为分类号0或1,0 表示预测结果不是该汉字,1 表示预测结果是该汉字。那么当y 为0 或1 时,L 中总会有一部分为0,另一部分恰好为交叉熵值,则可以一步计算出预测结果为0或1 的样本的交叉熵损失值。
4.2 两种模型对比测试
调用RNN 模型和LSTM 模型,各随机生成20 句唐诗,在Robo 3t 下观察RNN 与LSTM 各自作诗水平,RNN 模型平均交叉熵损失观察结果为6.53426236 429408;LSTM模型平均交叉熵损失观察结果为4.26054176463658。
4.3 LSTM 模型的写诗水平明显优于RNN 模型
(1)在交叉熵损失方面。LSTM 的平均交叉熵损失和最低交叉熵损失都显著优于RNN,RNN 的最低交叉熵损失均出现在第一代训练批次中,且在训练后期的损失值反而高于其平均损失,说明RNN 的损失函数在经过前几代训练后得到收敛。反观LSTM 的最低损失多数出现在训练后期,说明其训练过程更加有效,随着训练迭代次数不断增加能够持续降低损失值。
(2)在作诗水平方面。RNN 所作诗句明显不符合作诗预期,大多为一个字或三到四个字,并未连成完整诗句。而LSTM 所作诗句不仅对仗工整,且与唐诗相似度极高。测试结果有力证明在作诗方面LSTM 模型优于RNN模型。
5 结束语
通过模型测试循环神经网络(RNN)与长短期记忆网络(LSTM)在写诗方面的效果比较,得出LIST 模型明显优于RNN 模型。可见LSTM 模型善于建立时间序列数据间的非线性关系,适用于词句的预测。RNN 神经网络的训练过程采用BPTT 算法,随着递归层数的增加,计算梯度下降时,会出现梯度消失和梯度爆炸问题,严重影响模型效果,而LSTM 则从神经单元的结构上对此进行了优化[1],大大减小了其训练过程中梯度消失和梯度爆炸的问题。RNN 只能够处理短期依赖问题,而LSTM 得益于其特有的神经元内部结构使之既能够处理短期依赖问题,又能够处理长期依赖问题。在处理时序性信息时效果明显优于RNN,在自然语言处理应用方面比RNN 更高效。