
MobileNet network optimization based on convolutional block attention module

2022-05-05

ZHAO Shuxu, MEN Shiyao, YUAN Lin

(School of Electronics and Information Engineering, Lanzhou Jiaotong University, Lanzhou 730070, China)

Abstract: Deep learning technology is widely used in computer vision. In deep learning, a large amount of data is generally used to train the model weights in order to obtain a model with higher accuracy. However, massive data and complex model structures require more computing resources. Because in many application scenarios people can only carry and use mobile, portable devices, neural networks running on them are limited in computing resources, size and power consumption. Therefore, the efficient lightweight model MobileNet is used as the basic network for optimization in this study. First, the accuracy of the MobileNet model is improved by adding methods such as the convolutional block attention module (CBAM) and expansion convolution. Then, the MobileNet model is compressed with pruning and weight quantization algorithms based on weight magnitude. Afterwards, Python crawlers and data augmentation are employed to create a garbage classification dataset. Based on this model optimization strategy, a garbage classification application is deployed on mobile phones and Raspberry Pi devices, making garbage classification more convenient.

Key words: MobileNet; convolutional block attention module (CBAM); model pruning and quantization; edge machine learning

0 Introduction

Artificial intelligence has developed rapidly in many engineering fields through deep learning and machine learning in recent years. Deep learning has proven to be an effective way to achieve intelligence in the computer vision field: trained on a large amount of data, it can accurately solve problems such as image classification, object detection and image segmentation[1]. However, models have become increasingly complex in pursuit of better accuracy, resulting in higher computational requirements.

Therefore, lightweight network models have become a new direction for deep learning[2]. Currently, efficient models are mainly constructed following two ideas.

The first idea is to optimize existing high-performing models, i.e. model compression. The classic model compression algorithms are listed below.

1) Weight pruning prunes a trained model and is currently the most widely used method of model compression. It usually relies on an effective criterion for judging the importance of parameters, and cuts unimportant connections or filters to reduce model redundancy. Finally, the model is retrained to recover some of the accuracy lost during pruning[3].

2) Weight sharing and quantization is another efficient compression method and can further compress a pruned network. It reduces the number of bits needed to represent each weight by limiting the number of distinct effective weights and letting multiple connections share the same stored weight. K-means clustering is a commonly used weight-sharing method. Afterwards, the weight parameters are re-encoded according to source-coding principles from communication theory and then saved.

3) Huffman coding, an optimal prefix code generally used for lossless data compression, encodes source symbols with variable-length code words. Code words are assigned according to the probability of each symbol, so that more common symbols are represented with fewer bits.
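The weight sharing and Huffman coding steps above can be sketched together in Python with NumPy. This is a minimal illustration under assumptions not taken from the paper: the weights are synthetic, K-means is run in one dimension with 16 clusters and linearly spaced initial centroids, and Huffman coding is applied to the resulting cluster indices. It is not the implementation used in this study.

```python
import heapq
from collections import Counter

import numpy as np


def kmeans_1d(values, k, iters=50):
    """Cluster scalar weights into k shared values (linearly spaced initial centroids)."""
    centroids = np.linspace(values.min(), values.max(), k)
    for _ in range(iters):
        # Assign every weight to its nearest shared value
        assign = np.argmin(np.abs(values[:, None] - centroids[None, :]), axis=1)
        for c in range(k):
            members = values[assign == c]
            if members.size:
                centroids[c] = members.mean()
    assign = np.argmin(np.abs(values[:, None] - centroids[None, :]), axis=1)
    return centroids, assign


def huffman_code(symbols):
    """Build a prefix code in which frequent symbols get shorter bit strings."""
    freq = Counter(symbols)
    if len(freq) == 1:
        return {next(iter(freq)): "0"}
    # Heap entries: (frequency, tie-breaker, {symbol: code suffix built so far})
    heap = [(f, i, {s: ""}) for i, (s, f) in enumerate(freq.items())]
    heapq.heapify(heap)
    counter = len(heap)
    while len(heap) > 1:
        f1, _, c1 = heapq.heappop(heap)   # two least frequent subtrees
        f2, _, c2 = heapq.heappop(heap)
        merged = {s: "0" + code for s, code in c1.items()}
        merged.update({s: "1" + code for s, code in c2.items()})
        heapq.heappush(heap, (f1 + f2, counter, merged))
        counter += 1
    return heap[0][2]


rng = np.random.default_rng(0)
weights = rng.normal(0.0, 0.1, size=1000)            # synthetic layer weights
centroids, indices = kmeans_1d(weights, k=16)        # 16 shared values -> 4-bit indices
codebook = huffman_code(indices.tolist())
encoded_bits = sum(len(codebook[i]) for i in indices.tolist())
print(len(centroids), encoded_bits / len(indices))   # average bits per weight
```

With 16 clusters each weight can be stored as a 4-bit index into a table of shared values; Huffman coding then shortens the average index length whenever some clusters are used more often than others.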

The second idea is to build a lightweight network model directly by designing the convolution operation and network structure. Google proposed the MobileNet network model based on depthwise separable convolution, which in essence is a sparse representation of the convolution operation with less redundant information[4]. Compared with existing image classification networks such as Inception v3 and VGG16, MobileNet has a small size, few parameters and fast computation. Google also provides TensorFlow Lite, a seamless and convenient mobile deployment solution that supports MobileNet. The effectiveness of MobileNet has been demonstrated by extensive experimental data[4].

Therefore, MobileNet is used as the basic network for optimization research in this paper and is applied to a real application. The performance of the optimized model on the standard CIFAR-10 dataset is presented, and an Android application for garbage classification is built based on the optimization strategy.

1 Model optimization

1.1 Basic network

MobileNet is an efficient, lightweight neural network model proposed by Google for mobile and embedded devices[4]. It is built on a convolution called depthwise separable convolution[4], which factorizes a standard convolution into two operations: a depthwise convolution that extracts the spatial (plane) features of each input channel separately, and a 1×1 pointwise convolution that combines features across channels, as illustrated in Fig.1.

Fig.1 Depthwise separable convolution
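To make the savings concrete, the multiply-add and parameter counts of a standard convolution and of its depthwise separable factorization can be compared with a few lines of Python. The counts follow the cost formulas given in the MobileNet paper[4] (biases ignored, stride 1, "same" padding); the layer size below is a hypothetical example, not a layer taken from this study.

```python
def standard_conv_cost(dk, m, n, df):
    """Multiply-adds of a D_K x D_K convolution, M -> N channels, D_F x D_F output."""
    return dk * dk * m * n * df * df


def depthwise_separable_cost(dk, m, n, df):
    depthwise = dk * dk * m * df * df   # one D_K x D_K filter per input channel
    pointwise = m * n * df * df         # 1 x 1 convolution that mixes channels
    return depthwise + pointwise


def param_counts(dk, m, n):
    """(standard, depthwise separable) weight counts, biases ignored."""
    return dk * dk * m * n, dk * dk * m + m * n


# Hypothetical layer: 3x3 kernel, 512 -> 512 channels, 14x14 feature map
std = standard_conv_cost(3, 512, 512, 14)
sep = depthwise_separable_cost(3, 512, 512, 14)
print(sep / std, 1 / 512 + 1 / 9)   # both equal 1/N + 1/D_K^2
print(param_counts(3, 512, 512))
```

The cost ratio reduces to 1/N + 1/D_K², so with 3×3 kernels a depthwise separable layer needs roughly 8 to 9 times fewer multiply-adds than the standard convolution it replaces.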

Meanwhile, the model size and the number of parameters can be adjusted with two simple hyperparameters, the width multiplier and the resolution multiplier, to balance accuracy against efficiency. On the ImageNet dataset, MobileNet requires significantly fewer parameters and computations than other advanced models, while its accuracy is not dramatically lower than that of traditional high-performing networks[4]. As presented in Table 1, its classification accuracy on ImageNet is comparable to that of networks such as VGG16.
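As described in the MobileNet paper[4], the width multiplier α ∈ (0,1] uniformly thins the number of channels in every layer, and the resolution multiplier ρ ∈ (0,1] reduces the input resolution. With D_K the kernel size, M and N the numbers of input and output channels, and D_F the feature-map size, the multiply-add cost of one depthwise separable layer becomes

$$D_K \cdot D_K \cdot \alpha M \cdot \rho D_F \cdot \rho D_F + \alpha M \cdot \alpha N \cdot \rho D_F \cdot \rho D_F$$

so the computational cost falls roughly in proportion to α² and to ρ², which is how the two hyperparameters trade accuracy for efficiency.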

Table 1 MobileNet comparison to popular models

The same conclusion is reached in another set of experiments. Table 2 reports the fine-grained classification task on the Stanford Dogs dataset: the classification accuracy of MobileNet is close to that of Inception V3, while its computation and model size are greatly reduced[4].

Table 2 MobileNet for Stanford Dogs

To sum up, MobileNet is an effective lightweight model that maintains classification accuracy, which allows it to be deployed on edge terminals for machine learning inference. As a result, the computing resources and communication delay of the entire system can be reduced.

1.2 Solutions to overfitting

Neural networks are prone to overfitting. To prevent overfitting during training, an L2 regularization strategy is added to the depthwise separable convolution structure, after the depthwise convolution and before the pointwise convolution, as illustrated in Fig.2.

L2 regularization is a classic method for preventing overfitting in deep learning. It limits the complexity of the model by penalizing large weights, which shrinks the coefficients toward zero (unlike L1 regularization, it rarely makes them exactly zero). To do so, a penalty term Ω(ω) is added to the original loss function to restrict the model complexity[5]. Its mathematical expression is

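$$L(\omega;X,y)=L_1(\omega;X,y)+\lambda\,\Omega(\omega) \tag{1}$$

where L1 is the original (unregularized) loss, λ ≥ 0 controls the strength of the penalty, and for L2 regularization the penalty term is the squared L2 norm of the weights:

$$\Omega(\omega)=\|\omega\|_2^2=\sum_i \omega_i^2 \tag{2}$$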

Fig.2 L2 regularization in network structure

For the model weight coefficients ω, training seeks the minimizer of the original loss function, that is

$$\min_{\omega}\; L_1(\omega;X,y) \tag{3}$$

where X represents the input training data and y denotes the corresponding labels.

The complexity of the model is then limited by choosing a constant C and requiring

$$\sum_i \omega_i^2 \le C \tag{4}$$

That is, the sum of squares of all ω must not exceed C, and the goal is to minimize the training error L1 subject to this constraint[6]. Introducing a Lagrange multiplier for constraint (4) turns this constrained problem into the penalized form of Eq.(1), with λ playing the role of the multiplier: a smaller C corresponds to a larger λ.
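The effect of the penalty in Eq.(1) can be seen on a toy problem. The sketch below is an illustration under stated assumptions (a linear model with mean squared error as the data loss L1, synthetic data, plain gradient descent), not the network trained in this paper. The penalty λΣω² contributes the gradient 2λω, which pulls every weight toward zero at each step.

```python
import numpy as np


def l2_penalty(w, lam):
    # lambda * Omega(w), with Omega(w) = sum_i w_i^2 as in Eq.(2)
    return lam * float(np.sum(w ** 2))


def fit_linear(X, y, lam, lr=0.1, steps=500):
    """Minimize MSE(X w, y) + lambda * ||w||^2 by gradient descent."""
    w = np.zeros(X.shape[1])
    n = len(y)
    for _ in range(steps):
        residual = X @ w - y
        grad_data = 2.0 * X.T @ residual / n   # gradient of the data loss L1
        grad_penalty = 2.0 * lam * w           # gradient of lambda * ||w||^2
        w -= lr * (grad_data + grad_penalty)
    return w


rng = np.random.default_rng(0)
X = rng.normal(size=(50, 5))
true_w = np.array([3.0, -2.0, 0.5, 0.0, 1.0])   # synthetic ground truth
y = X @ true_w + 0.1 * rng.normal(size=50)

w_plain = fit_linear(X, y, lam=0.0)
w_l2 = fit_linear(X, y, lam=0.5)
print(np.sum(w_plain ** 2), np.sum(w_l2 ** 2))   # the regularized weights are smaller
```

Increasing λ (equivalently, decreasing C in Eq.(4)) shrinks the weight norm further, at the cost of a larger training error.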

1.3 Optimization of feature extraction

After overfitting has been addressed, the ability of the model to extract features needs to be improved. To avoid adding layers or increasing the complexity of the model structure as much as possible, expansion convolution and the CBAM are added to the feature extraction structure instead.

As illustrated in Fig.3, expansion convolution (also known as dilated convolution) is added to the network structure. The convolution generally used in neural networks operates on consecutive adjacent pixels, whereas expansion convolution samples its inputs at fixed intervals: its kernel taps are distributed more sparsely, while the number of parameters remains unchanged, as presented in Fig.4.

Fig.3 Expansion convolution in network structure

Fig.4 Expansion convolution

Expansion convolution enlarges the receptive field of the convolution kernel and captures more context information between image pixels without increasing the number of model parameters, which helps the network understand the context of an image better.
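A one-dimensional sketch makes the point about parameters concrete. It is a simplified illustration, not the layer used in this study: a kernel with k taps and dilation rate d spans k + (k - 1)(d - 1) inputs, yet still has only k weights. In two dimensions the same rule applies per axis, so a 3×3 kernel with rate 2 covers a 5×5 region; in Keras this spacing is what the `dilation_rate` argument of `Conv2D` controls.

```python
import numpy as np


def effective_kernel_size(k, d):
    # A k-tap kernel with dilation rate d spans k + (k - 1) * (d - 1) inputs
    return k + (k - 1) * (d - 1)


def dilated_conv1d(x, w, d):
    """'Valid' 1-D dilated convolution (cross-correlation, as in CNN libraries)."""
    k = len(w)
    span = effective_kernel_size(k, d)
    out = []
    for i in range(len(x) - span + 1):
        taps = x[i:i + span:d]   # every d-th input inside the span: k values
        out.append(float(np.dot(taps, w)))
    return np.array(out)


x = np.arange(8, dtype=float)     # input signal [0, 1, ..., 7]
w = np.array([1.0, 0.0, -1.0])    # the same 3 weights in both cases
print(dilated_conv1d(x, w, d=1))  # span 3
print(dilated_conv1d(x, w, d=2))  # span 5, still 3 parameters
```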

After expansion convolution is added, the CBAM is inserted at the end of each depthwise separable convolution block, as exhibited in Fig.5.

Fig.5 CBAM module in network structure

In this way, the feature extraction capability of the depthwise separable convolution block is significantly improved at the cost of only a small number of additional operations.

CBAM is a simple and effective attention module for feedforward convolutional neural networks[7]. Given an input feature map, the module infers attention maps along two separate dimensions, channel and spatial, and multiplies each attention map by the feature map to perform adaptive feature refinement, as exhibited in Fig.6. CBAM is a lightweight module that can be integrated seamlessly into any CNN architecture without requiring significant computing resources, and it can be trained end-to-end together with the base CNN[7]. Experimental results on the ImageNet-1K, MS COCO and VOC 2007 datasets demonstrate that the module improves the classification and detection performance of different models, which illustrates its effectiveness.

Fig.6 CBAM module
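The forward pass of the module can be sketched in NumPy. This is a simplified illustration following the structure described in [7], not this paper's code: the weights are random and untrained (hypothetical), biases and normalization are omitted, the reduction ratio r = 4 is a hypothetical choice for a small example, and the spatial branch uses a 7×7 kernel.

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def channel_attention(F, W0, W1):
    """M_c: shared MLP (C -> C/r -> C, ReLU) on average- and max-pooled descriptors."""
    def mlp(v):
        return np.maximum(v @ W0, 0.0) @ W1
    avg = F.mean(axis=(0, 1))   # (C,) global average pooling over H and W
    mx = F.max(axis=(0, 1))     # (C,) global max pooling over H and W
    return sigmoid(mlp(avg) + mlp(mx))


def spatial_attention(F, K):
    """M_s: k x k convolution over channel-wise average and max maps ('same' padding)."""
    desc = np.stack([F.mean(axis=2), F.max(axis=2)], axis=-1)   # (H, W, 2)
    k = K.shape[0]
    p = k // 2
    padded = np.pad(desc, ((p, p), (p, p), (0, 0)))             # zero padding
    H, W = F.shape[:2]
    logits = np.empty((H, W))
    for i in range(H):
        for j in range(W):
            logits[i, j] = np.sum(padded[i:i + k, j:j + k, :] * K)
    return sigmoid(logits)


def cbam(F, W0, W1, K):
    """Refine F of shape (H, W, C): channel attention first, then spatial attention."""
    F1 = F * channel_attention(F, W0, W1)[None, None, :]
    return F1 * spatial_attention(F1, K)[:, :, None]


rng = np.random.default_rng(0)
H, W, C, r = 8, 8, 16, 4
F = rng.normal(size=(H, W, C))                   # stand-in feature map
W0 = rng.normal(scale=0.1, size=(C, C // r))     # hypothetical untrained weights
W1 = rng.normal(scale=0.1, size=(C // r, C))
K = rng.normal(scale=0.1, size=(7, 7, 2))
out = cbam(F, W0, W1, K)
print(out.shape)   # (8, 8, 16): same shape as the input feature map
```

Because both attention maps lie in (0, 1) and are applied by element-wise multiplication, the module only re-weights features and leaves the tensor shape unchanged, which is why it can be appended to each depthwise separable block.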

1.4 Optimization strategies for training

1.4.1 Cross-entropy loss function

Once the model structure has been optimized, the training procedure needs to be considered. The loss function is the most basic and critical element of model training, and the cross-entropy loss function is commonly used to train neural networks. Cross-entropy describes how close two probability distributions are, here the true distribution and the predicted distribution, as expressed in

$$H(p,q)=-\sum_{x}p(x)\log q(x) \tag{5}$$

where p(x) denotes the probability distribution of the true classes and q(x) represents the predicted probability distribution. The smaller the cross-entropy, the closer the two distributions[8].

Cross-entropy is used as the loss function for classification problems in machine learning[9]. Over all sample data, the loss is

$$L=-\frac{1}{N}\sum_{i=1}^{N}\sum_{c=1}^{M} y_{ic}\log p_{ic} \tag{6}$$

where N represents the number of samples, M the number of classes, y_ic the sample label (1 if sample i belongs to class c and 0 otherwise), and p_ic the predicted probability that sample i belongs to class c. Suppose that p(x) and q(x) are two separate probability distributions of the same random variable X; the KL divergence can be used to measure the difference between them. When a network is trained, the input data and labels are fixed, so the true distribution p(x) is fixed and its information entropy is a constant. Since the KL divergence measures the difference between the true distribution p(x) and the predicted distribution q(x), a smaller value means a better prediction, so the KL divergence should be minimized. The cross-entropy equals the KL divergence plus this constant (the information entropy),

$$H(p,q)=H(p)+D_{\mathrm{KL}}(p\,\|\,q)$$

so minimizing the cross-entropy is equivalent to minimizing the KL divergence, and the cross-entropy is simpler to compute. Consequently, the cross-entropy loss function is commonly used in deep learning.
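The relationship between cross-entropy, entropy and KL divergence can be checked numerically. The distributions below are made-up examples; the batch loss assumes one-hot labels, as in Eq.(6).

```python
import math


def entropy(p):
    return -sum(pi * math.log(pi) for pi in p if pi > 0)


def cross_entropy(p, q):
    # H(p, q) = -sum_x p(x) log q(x), Eq.(5); terms with p(x) = 0 contribute nothing
    return -sum(pi * math.log(qi) for pi, qi in zip(p, q) if pi > 0)


def kl_divergence(p, q):
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def batch_cross_entropy(labels, probs):
    """Eq.(6): mean over N samples of -sum_c y_ic * log p_ic with one-hot labels."""
    return sum(cross_entropy(y, p) for y, p in zip(labels, probs)) / len(labels)


p = [0.7, 0.2, 0.1]   # "true" distribution (made-up)
q = [0.5, 0.3, 0.2]   # predicted distribution (made-up)
# H(p, q) = H(p) + KL(p || q); H(p) does not depend on the model
print(cross_entropy(p, q), entropy(p) + kl_divergence(p, q))
```

With one-hot labels H(p) is zero, so for each sample the cross-entropy and the KL divergence coincide exactly.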

1.4.2 Adam algorithm

The optimizer, in turn, updates the network parameters that affect model training and model output so that they approach or reach their optimal values, thereby minimizing the loss function. The optimizer is therefore as important as the loss function in a neural network. Here, the Adam algorithm replaces the SGD algorithm for iteratively optimizing the classification results. Adam is popular in deep learning because it quickly achieves good results. Experimental results indicate that Adam performs well in practice and compares favorably with other stochastic optimization methods (Fig.7). In the original paper, experiments verify that the convergence of the algorithm meets the expectations of the theoretical analysis.

For example, Adam was used to optimize logistic regression on MNIST, multilayer perceptrons on MNIST, and convolutional neural networks on the CIFAR-10 image recognition dataset. These experiments on large models and datasets demonstrate that Adam can effectively solve practical deep learning problems[10].

Fig.7 Adam on MNIST
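The Adam update rule can be written in a few lines. The sketch follows the algorithm as published by Kingma and Ba, with the default β1, β2 and ε values suggested there; the toy objective f(x) = (x - 3)² and the learning rate are hypothetical choices for illustration, not the settings used to train MobileNet in this paper.

```python
import numpy as np


def adam_minimize(grad_fn, x0, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8, steps=1000):
    x = np.asarray(x0, dtype=float).copy()
    m = np.zeros_like(x)   # first-moment (mean) estimate of the gradient
    v = np.zeros_like(x)   # second-moment (uncentered variance) estimate
    for t in range(1, steps + 1):
        g = grad_fn(x)
        m = beta1 * m + (1 - beta1) * g
        v = beta2 * v + (1 - beta2) * g * g
        m_hat = m / (1 - beta1 ** t)   # bias correction for the zero initialization
        v_hat = v / (1 - beta2 ** t)
        x = x - lr * m_hat / (np.sqrt(v_hat) + eps)
    return x


def grad(x):
    # Gradient of the toy objective f(x) = (x - 3)^2
    return 2.0 * (x - 3.0)


x_final = adam_minimize(grad, [0.0], lr=0.05, steps=2000)
print(x_final)   # close to the minimizer 3.0
```

Because the step is normalized by the square root of the second-moment estimate, its size is roughly bounded by the learning rate regardless of the gradient scale; the very first step has magnitude almost exactly lr.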

1.5 Model compression

With the above strategies, the accuracy of the model can be increased significantly. Next, pruning, a model compression algorithm, is applied to limit the size and number of parameters of the model. Pruning a trained model is currently the most widely used approach to model compression. As presented in Fig.8, the model weights are first trained, then pruned, and the model is finally retrained.

Fig.8 Pruning method

Weight pruning refers to eliminating unnecessary values in the weight tensor. First, the neural network parameters are sorted by weight magnitude. Then, some of the parameters are set to zero as needed, eliminating the low-weight connections between the layers of the network[11]. This brings the following advantages to the model.

1) Compression. The resulting sparse weight tensor can be stored compactly by keeping only the non-zero values and their coordinates.

2) Speed. Computations involving the zeroed weights of the sparse tensor can be skipped during model inference.
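Magnitude pruning and sparse storage, as described above, can be sketched in NumPy. This is an illustration under stated assumptions (a single synthetic weight tensor, one-shot pruning at sparsity 0.6, coordinate-list storage), not the pruning code used in this study; the retraining step that follows pruning is not shown.

```python
import numpy as np


def magnitude_prune(w, sparsity):
    """Zero out the `sparsity` fraction of entries with the smallest |w|."""
    flat = np.abs(w).ravel()
    k = int(round(sparsity * flat.size))   # number of weights to remove
    if k == 0:
        return w.copy()
    order = np.argsort(flat)               # indices sorted by ascending magnitude
    mask = np.ones(flat.size, dtype=bool)
    mask[order[:k]] = False
    return (w.ravel() * mask).reshape(w.shape)


def to_sparse(w):
    """Keep only the non-zero values and their coordinates (COO-style storage)."""
    coords = np.argwhere(w != 0)
    values = w[w != 0]
    return coords, values


rng = np.random.default_rng(0)
w = rng.normal(size=(10, 10))             # synthetic weight tensor
pruned = magnitude_prune(w, sparsity=0.6)
coords, values = to_sparse(pruned)
print(np.mean(pruned == 0), len(values))  # fraction pruned, number of weights kept
```

Storing coordinates costs extra memory per value, so the savings in practice depend on the sparsity level and on the sparse format chosen.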

Experiments show that pruning compresses the model several times on the MNIST dataset, as illustrated in Table 3. A CNN with three convolutional layers is built and trained on MNIST. The MNIST handwritten digit dataset comes from the National Institute of Standards and Technology of the United States and is one of the best-known public datasets; it is commonly used as an introductory example for deep learning. It contains 70 000 images of digits written by 250 people of different occupations, of which 60 000 form the training set and 10 000 the test set. Each image is a 28×28 pixel handwritten digit from 0 to 9. The model is trained on MNIST for 12 epochs and reaches a final accuracy of 0.992. After pruning with a sparsity of 0.6, the model size is reduced by a factor of nearly 4-5, while the accuracy of the model is not reduced.

Table 3 Pruning on MNIST

Moreover, the model size can be further reduced by combining pruning with other optimization techniques such as quantization. Quantization converts the weights from 32-bit floating point to 8-bit integers, which cuts the storage needed for the weights by a factor of four. Its performance on MNIST is verified experimentally, as presented in Table 4.
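A minimal sketch of 8-bit weight quantization is given below. It assumes one simple scheme, symmetric per-tensor quantization with a single scale factor, which is not necessarily the scheme used in this study or by any particular toolkit; the weights are synthetic.

```python
import numpy as np


def quantize_int8_symmetric(w):
    """Per-tensor symmetric quantization: w is approximated by scale * q, q in [-127, 127]."""
    max_abs = float(np.max(np.abs(w)))
    scale = max_abs / 127.0 if max_abs > 0 else 1.0
    q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
    return q, scale


def dequantize(q, scale):
    return q.astype(np.float32) * np.float32(scale)


rng = np.random.default_rng(0)
w = rng.normal(size=(256, 256)).astype(np.float32)   # synthetic float32 weights
q, scale = quantize_int8_symmetric(w)
w_hat = dequantize(q, scale)
print(w.nbytes / q.nbytes)                 # 4.0: 32-bit -> 8-bit storage
print(float(np.max(np.abs(w - w_hat))))    # rounding error of at most about scale / 2
```

The rounding error per weight is bounded by half the scale, so a few large-magnitude weights widen the range and make every other weight coarser; this sensitivity is consistent with the accuracy drop reported below when pruned weights are quantized directly.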

Table 4 Weight quantization on MNIST

2 Experiments

In this section, the original model and the optimized model are compared on the standard CIFAR-10 dataset.

2.1 CIFAR-10 dataset

CIFAR-10 is a computer vision dataset for general object recognition collected by Alex Krizhevsky, Vinod Nair and Geoffrey Hinton. It contains 60 000 RGB color images of 32×32 pixels in 10 classes, of which 50 000 form the training set and 10 000 the test set[12], as shown in Fig.9.

Fig.9 CIFAR10 data

2.2 Experimental results

2.2.1 MobileNet original model

First, the original MobileNet model is trained directly on the CIFAR-10 dataset. The training parameters are set as follows: 10 classes, a batch size of 64, 300 epochs and 782 iterations per epoch (50 000 training images divided into batches of 64, rounded up). Dropout is set to 0.2 and weight decay to 1e-4. After training, the accuracy on the test dataset stays at about 0.8, as illustrated in Fig.10 (after smoothing).
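The iteration count follows from the dataset and batch sizes. The short check below assumes that one iteration processes one batch and that the last, partial batch of each epoch is kept.

```python
import math

train_images = 50_000   # CIFAR-10 training set size
batch_size = 64
epochs = 300

steps_per_epoch = math.ceil(train_images / batch_size)   # the last batch holds only 16 images
total_steps = steps_per_epoch * epochs
print(steps_per_epoch, total_steps)   # 782 234600
```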

Fig.10 Accuracy of original model on test datasets

2.2.2 Result of L2 regularization optimization

To prevent overfitting, L2 regularization is added to the pointwise convolution, and the model is trained with the Adam algorithm and the cross-entropy loss function. In addition, the default weight initialization strategy glorot_uniform is replaced with he_normal[13-14]. With these changes, the accuracy of the model finally improves to 0.85, as exhibited in Fig.11 (after smoothing).
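One common explanation for preferring he_normal in ReLU networks is signal propagation, and it can be illustrated with a toy experiment. The sketch below is an assumption-laden illustration, not an analysis of the MobileNet used here: it uses a deep stack of fully connected ReLU layers with hypothetical sizes, and it draws he_normal samples from a plain normal distribution, whereas Keras' he_normal uses a truncated normal. With Glorot scaling (variance 1/fan_in when fan_in equals fan_out) each ReLU layer roughly halves the mean square of the signal; He scaling (variance 2/fan_in) compensates for that halving.

```python
import numpy as np


def glorot_uniform(fan_in, fan_out, shape, rng):
    # U(-limit, limit) with limit = sqrt(6 / (fan_in + fan_out))
    limit = np.sqrt(6.0 / (fan_in + fan_out))
    return rng.uniform(-limit, limit, size=shape)


def he_normal(fan_in, shape, rng):
    # N(0, 2 / fan_in); plain normal here, whereas Keras uses a truncated normal
    return rng.normal(0.0, np.sqrt(2.0 / fan_in), size=shape)


def signal_ratio(init, depth=20, width=256, batch=64, seed=0):
    """Mean square of the last pre-activation relative to the first in a ReLU MLP."""
    rng = np.random.default_rng(seed)
    x = rng.normal(size=(batch, width))
    first = None
    for _ in range(depth):
        if init == "he_normal":
            w = he_normal(width, (width, width), rng)
        else:
            w = glorot_uniform(width, width, (width, width), rng)
        y = x @ w
        if first is None:
            first = np.mean(y ** 2)
        x = np.maximum(y, 0.0)   # ReLU
    return np.mean(y ** 2) / first


print(signal_ratio("he_normal"), signal_ratio("glorot_uniform"))
```

In this toy setting the He-initialized stack keeps the signal at roughly its original scale, while the Glorot-initialized stack shrinks it by about 2^-19; whether this is the reason for the improvement observed in this paper is not established by the experiment above.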

Fig.11 Accuracy of optimized model on test datasets

2.2.3 Optimization results of feature extraction

With overfitting addressed, a CBAM module is added after each depthwise separable convolution to improve the ability of the model to extract features. In addition, the standard convolution of the first layer is replaced with expansion convolution to obtain a larger receptive field and more abstract features. As shown in Fig.12 (after smoothing), the model accuracy steadily increases to 0.90; the models are compared in Table 5.

Fig.12 Optimized model accuracy on test datasets using CBAM

Table 5 Model comparison of CBAM

2.2.4 Model compression

After more accurate model weights are obtained, the model is compressed. First, the model weights are sorted by magnitude. Then, the weights are pruned layer by layer with a pruning sparsity of 0.6. Next, the pruned model is trained for another 50 epochs to recover accuracy. After pruning and retraining, the model weights are converted from 32-bit floating point to 8-bit integers to explore further compression. However, this significantly reduces the accuracy of the model, because after pruning every remaining weight has a large influence on the inference results. Therefore, these weights cannot simply be quantized until a more effective quantization strategy is found.

Through the above processing, the model size is reduced by a factor of nearly 4-5, and the test results indicate that the accuracy of the model is not significantly reduced. More details are provided in Table 6.

Table 6 Comparison of quantization and pruning

2.3 Experimental summary

With the proposed optimization strategy, the accuracy of the model is improved from 0.8 to 0.85 by L2 regularization and the Adam algorithm, and then to 0.9 by CBAM and expansion convolution. Finally, the size of the model is effectively reduced through pruning and weight quantization, so that the model achieves a balance between accuracy and size. More details are provided in Table 7.

Table 7 Comparison of models

3 Application of garbage classification

Having demonstrated the effect of the optimization strategy on a standard dataset, garbage classification is taken as an example to illustrate the practical value of model optimization. First, images of garbage that frequently appears in public places are collected as a dataset. Then, the model is trained and compressed according to the model optimization strategy. Finally, the trained model is converted to the corresponding format, and the application is deployed on the Android system.

3.1 Experimental environment

The model training environment and the mobile phone deployment environment are presented in Tables 8 and 9.

Table 8 Model training environment

Table 9 Mobile deployment environment

3.2 Data preprocessing

Since there is no standard dataset for garbage classification, a training dataset has to be created. According to the regulations issued by Shanghai, household garbage is classified into four basic categories: recyclable garbage, hazardous garbage, wet garbage and dry garbage. The data preprocessing steps are described as follows.

1) Collect images. Images of each of the four garbage categories defined in the regulations are crawled from the Internet with Python, and photos taken in real life are also added to the dataset.

2) Delete unsuitable data. Invalid images and images containing too much interfering information are deleted, as illustrated in Fig.13.

Fig.13 Invalid data and interfering data

3) Split the data. The data of each class is divided into a training set and a test set at a ratio of 7∶3.

4) Data augmentation. Existing images are flipped or rotated to generate more data, which gives the neural network better generalization ability[15]. As presented in Figs.14 and 15, the samples are randomly cropped and flipped horizontally and vertically.

Fig.14 Random cropping of data

Fig.15 Horizontal and vertical flipping of data

In summary, after the above four steps, the garbage classification dataset contains 1 000 images in each class.
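Steps 3) and 4) can be sketched in NumPy. The sketch is illustrative only: the file names are hypothetical placeholders, the 256×256 stand-in image and the 224×224 crop size are assumptions, and rotation is not shown. Splitting each class separately keeps the 7∶3 ratio and the class balance in both subsets.

```python
import numpy as np


def split_per_class(items_by_class, train_ratio=0.7, seed=0):
    """Split every class separately so both subsets keep the class balance (7:3)."""
    rng = np.random.default_rng(seed)
    train, test = [], []
    for label, items in items_by_class.items():
        idx = rng.permutation(len(items))
        n_train = int(round(train_ratio * len(items)))
        train += [(items[i], label) for i in idx[:n_train]]
        test += [(items[i], label) for i in idx[n_train:]]
    return train, test


def random_crop(img, size, rng):
    h, w = img.shape[:2]
    top = rng.integers(0, h - size + 1)
    left = rng.integers(0, w - size + 1)
    return img[top:top + size, left:left + size]


def random_flip(img, rng):
    if rng.random() < 0.5:
        img = img[:, ::-1]   # horizontal flip (mirror left-right)
    if rng.random() < 0.5:
        img = img[::-1, :]   # vertical flip (upside down)
    return img


# Hypothetical placeholder file names standing in for 1 000 images per category
categories = ["recyclable", "hazardous", "wet", "dry"]
dataset = {name: [f"{name}_{i:04d}.jpg" for i in range(1000)] for name in categories}
train, test = split_per_class(dataset)
print(len(train), len(test))   # 2800 1200

rng = np.random.default_rng(1)
image = rng.random((256, 256, 3))   # stand-in for a decoded RGB image
augmented = random_flip(random_crop(image, 224, rng), rng)
print(augmented.shape)   # (224, 224, 3)
```

Augmentation is applied only after the split, so augmented copies of a training image never leak into the test set.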

3.3 Results and discussion

3.3.1 Model training

First, the trained and optimized model is tested on the PC. As illustrated in Figs.16-18, a napkin and a lead battery are correctly identified as dry garbage (dry score=0.999 98) and hazardous garbage (hazard score=0.997 801), respectively. In addition, TensorBoard is used to visualize the evaluation data of the classification results on the test set, as exhibited in Figs.19 and 20. The classification accuracy on the test dataset remains at 93%-95%, and the cross-entropy loss is stable in the range of 0.15-0.2.

Fig.16 Test data: napkin and lead battery

Fig.17 Napkin classification result

Fig.18 Lead battery classification result

Fig.19 Accuracy on test dataset

Fig.20 Cross entropy loss function result

3.3.2 Android mobile deployment

The TensorFlow Lite Python API is used to convert the model trained on the PC into a TFLite (.tflite) file. This file is then added to the Android project, and the application is deployed to the mobile phone through Android Studio. The classification results for an empty bottle and waste food are presented in Fig.21.

Compared with published papers on the MobileNet network[16-17], the experiment basically accomplishes the task. However, because of the limited size of the dataset and the limitations of sample collection, the model cannot fully learn all the features of garbage images, which may result in insufficient classification accuracy and poor robustness.

Fig.21 Classification results on phone

4 Conclusions

The lightweight model MobileNet is taken as the research object, and its feature extraction ability and parameter efficiency are optimized. The results demonstrate that the accuracy of the model on the standard CIFAR-10 dataset is improved and the model size is effectively reduced. In addition, a garbage image classification application is deployed on Android phones. In the future, more data could be obtained by building a big data collection platform for garbage classification, which would improve the robustness of the model.