EBERT: Efficient BERT Inference with Dynamic Structured Pruning

Pruning has been demonstrated as an effective way of reducing the computational complexity of deep networks, especially CNNs for computer vision tasks. In this paper, we investigate the opportunity to accelerate the inference of large-scale pre-trained language models via pruning. We propose EBERT, a dynamic structured pruning algorithm for efficient BERT inference. Unlike previous methods that statically prune the model weights once and for all, EBERT dynamically determines and prunes the unimportant heads in multi-head self-attention layers and the unimportant structured computations in the feed-forward network for each input sample at run-time. Experimental results show that the proposed EBERT outperforms other state-of-the-art methods on different tasks.


Introduction
In the last few years, transformer-based (Vaswani et al., 2017) large-scale pre-trained language models, such as BERT (Devlin et al., 2019), RoBERTa, and GPT-3 (Brown et al., 2020), have achieved state-of-the-art results on many NLP tasks, including language understanding, question answering, and reading comprehension. Most recently, researchers have also successfully applied transformer-based models to computer vision tasks, achieving comparable or superior performance compared to traditional convolutional networks. For example, Carion et al. (2020) propose the detection transformer (DETR) for object detection, and Dosovitskiy et al. (2021) design a transformer-based model, namely the Vision Transformer (ViT), for image classification. However, due to their notable computational complexity and memory footprint, these models are difficult to deploy on hardware platforms with moderate computing and resource budgets. Therefore, how to reduce model complexity to enable efficient inference for large-scale pre-trained language models is a critical issue.
Pruning is a commonly used technique for network compression, and has been widely explored to reduce the computation and storage requirements of convolutional neural networks for computer vision tasks (Han et al., 2015, 2016; Li et al., 2016). But can transformer-based models benefit from pruning? Michel et al. (2019) observe that a large percentage of attention heads can be removed with negligible performance drop, which indicates that different heads in the same layer have different importance. Other work proposes a simple, deterministic first-order weight pruning method that can prune a large fraction of parameters with minimal accuracy loss. Although these methods reduce the memory footprint, they cannot achieve real performance gains on general-purpose hardware, such as GPGPUs, due to the unstructured sparsity left after pruning.
Adaptive inference strategies have also been proposed to accelerate BERT inference. They are based on two observations: 1) input samples usually have different levels of difficulty, so a given model may over-compute on simple samples while still failing on complex ones; 2) similar to convolutional neural networks, the lower and higher layers of a transformer extract different information, and the features produced by intermediate layers may already be sufficient for some samples. FastBERT and DeeBERT are two state-of-the-art adaptive inference models for compressing BERT. Both insert extra classification layers between the layers of the network. During inference, each input sample only goes through part of the model when the outputs of the extra classifiers meet predefined criteria such as entropy or uncertainty. Because the number of executed layers is reduced, real speedup can be achieved. However, skipping all the computations of the remaining layers may be harmful to accuracy.
In this paper, we propose EBERT, a hardware-friendly, simple yet effective algorithm that incorporates structured pruning with adaptive inference for efficient BERT inference. Specifically, EBERT inserts predictors for the self-attention sub-layer and the feed-forward sub-layer in each transformer block, as illustrated in Figure 1. During inference, the predictors dynamically determine which heads of the self-attention layers and which channels of the feed-forward network can be pruned according to the current input. Once a head or a channel is pruned, the corresponding computation and memory cost can be completely avoided. Compared with static pruning methods that permanently remove some parameters, this avoids pruning parameters that are important for the current input sample, which would cause a large performance drop. To the best of our knowledge, this is the first work to apply dynamic structured pruning to BERT. Experimental results on different benchmarks demonstrate that the proposed EBERT achieves a better trade-off between computation reduction and accuracy.

Related Work
Adaptive inference. As different input samples usually have different levels of difficulty, using a fixed-size model to process all samples may be sub-optimal in terms of computational efficiency. Therefore, the main goal of adaptive inference is to adaptively skip part of the computation for each input sample to reduce complexity. FastBERT adds student classifiers to the output of each transformer block and uses a self-distillation strategy to improve performance. The model architecture of DeeBERT is similar to FastBERT, but it uses the entropy of the output to decide whether to exit at an early stage. PABEE proposes a novel early-exit criterion that dynamically stops forward computation when the outputs of the internal classifiers remain unchanged for a pre-defined number of steps.
Pruning. Pruning is an intuitively simple yet effective technique for model compression, which removes unimportant computations based on a certain criterion. Michel et al. (2019) observe that a large percentage of attention heads can be removed with negligible performance loss and propose a greedy pruning algorithm. Compressing BERT (Gordon et al., 2020) explores the effect of unstructured weight pruning at different pruning levels and training stages. McCarley et al. (2019) investigate the relationship between structured pruning and task-specific distillation. SNIP proposes a structured pruning method that penalizes an entire residual module in the Transformer model toward an identity mapping.
Distillation. Knowledge distillation (Hinton et al., 2015) is an effective technique for obtaining light models from heavy models without sacrificing too much performance. DistilBERT (Sanh et al., 2019) leverages knowledge distillation at the pre-training phase to get a lighter pre-trained model, which is then directly fine-tuned on downstream tasks. BERT-PKD (Sun et al., 2019) proposes an incremental knowledge extraction process: apart from learning from the final output of the teacher model, the student model also patiently learns from intermediate layers. TinyBERT (Jiao et al., 2020) performs distillation at both the pre-training and the task-specific fine-tuning phase, and also uses data augmentation to improve the accuracy of the student model.

Methods
In this section, we first introduce the architecture of EBERT. As shown in Figure 1, it can be divided into a BERT branch and a predictor branch. We then describe training and inference in detail.

Figure 2: The details of the predictors in the MHA and FFN layers (panel (b): Predictor_f in FFN). Here we assume n = 4, h = 4 and d_i = 6. Shaded areas mark computations that can be skipped.

BERT Branch
The architecture of BERT consists of three parts: the embedding layer, multi-layer bidirectional Transformer encoders, and the task-specific classification layer. Given an input sentence S = [s_0, s_1, ..., s_n] of length n, where s_0 is usually the special classification token [CLS], the embedding layer transforms it into a sequence of vector representations:

E = Embedding(S)    (1)

Each Transformer encoder contains two sub-layers: a multi-head self-attention (MHA) layer and a position-wise fully connected feed-forward network (FFN):

A_i = LN(MHA(Z_{i-1}) + Z_{i-1})    (2)
Z_i = LN(FFN(A_i) + A_i)    (3)

where i = 1, 2, ..., L and Z_0 = E. LN is the Layer Normalization operation. The final component of BERT is a task-specific classification layer. It accepts the representation of the [CLS] token as input to generate the final result:

y = Classifier(Z_L[0, :])    (4)
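The residual-plus-LayerNorm structure of the two sub-layers can be sketched as follows. This is a minimal NumPy illustration of the pattern A_i = LN(MHA(Z) + Z), Z_i = LN(FFN(A) + A) under simplifying assumptions of ours (single-head attention, no learned LayerNorm parameters, toy dimensions), not the authors' implementation:

```python
import numpy as np

def layer_norm(x, eps=1e-5):
    # Normalize each token vector to zero mean / unit variance (no learned scale/shift).
    mu = x.mean(-1, keepdims=True)
    var = x.var(-1, keepdims=True)
    return (x - mu) / np.sqrt(var + eps)

def softmax(x):
    e = np.exp(x - x.max(-1, keepdims=True))
    return e / e.sum(-1, keepdims=True)

def self_attention(Z, Wq, Wk, Wv):
    # Single-head scaled dot-product attention, for brevity.
    Q, K, V = Z @ Wq, Z @ Wk, Z @ Wv
    scores = softmax(Q @ K.T / np.sqrt(K.shape[-1]))
    return scores @ V

def ffn(A, W1, W2):
    # Position-wise feed-forward network with ReLU.
    return np.maximum(A @ W1, 0.0) @ W2

def encoder_layer(Z, params):
    # A_i = LN(MHA(Z_{i-1}) + Z_{i-1});  Z_i = LN(FFN(A_i) + A_i)
    A = layer_norm(self_attention(Z, *params["attn"]) + Z)
    return layer_norm(ffn(A, *params["ffn"]) + A)

rng = np.random.default_rng(0)
d, n = 8, 5                      # hidden size and sequence length (toy values)
Z0 = rng.normal(size=(n, d))     # stands in for the embedding output E
params = {
    "attn": [rng.normal(size=(d, d)) for _ in range(3)],
    "ffn": [rng.normal(size=(d, 4 * d)), rng.normal(size=(4 * d, d))],
}
Z1 = encoder_layer(Z0, params)
print(Z1.shape)  # (5, 8)
```

Stacking L such layers and reading off Z_L[0, :] for the classifier completes the BERT branch described above.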

Predictor Branch
In order to prune unimportant heads and channels for each individual input sentence, we add predictors for the MHA and FFN in each layer, respectively. The predictor consists of two feed-forward layers, one batch normalization layer, and a ReLU activation layer, as depicted in Figure 2. The output t of the second feed-forward layer is transformed into a 0-1 mask m by a function f(·):

m = f(t),  t = Predictor(x)

where x = Z[0, :]. That is, the input of the predictor is only the [CLS] representation. This choice is based on two reasons. 1) Overhead. Although using the whole representation of the input sentence might improve the performance of the predictors, the amount of computation increases linearly with the sentence length n. When n is large, the computational overhead of the predictors can no longer be ignored.
2) Representation ability. Because the final hidden state of the [CLS] token in the last transformer block is used by the task-specific classifier to generate the classification result, we assume that the [CLS] representation encodes most of the useful information in the sentence. Note that the representation of the [CLS] token entering the first MHA is independent of the input sentence, so there we use the average pooling of the MHA input instead. Intuitively, t represents the probability of each head or channel being selected. In order to train the model end-to-end with back-propagation, the Gumbel-Softmax trick (Jang et al., 2017; Maddison et al., 2016) is adopted in our model. Given class probabilities π_1, π_2, ..., π_n, discrete samples z can be drawn as:

z = argmax_i (log π_i + g_i)    (5)

where g_i is a sample drawn from a Gumbel distribution. The Gumbel-Softmax trick replaces the argmax operation with a softmax function, which is a continuous, differentiable approximation of argmax:

y_i = exp((log π_i + g_i) / τ) / Σ_{j=1}^{n} exp((log π_j + g_j) / τ)    (6)

As the value of the mask m is binary (0 for prune and 1 for preserve), we can simplify the Gumbel-Softmax formulation (Verelst and Tuytelaars, 2020). For each output t[i] ∈ (−∞, ∞), we can convert it to probabilities π_1 and π_2 by using the sigmoid function σ:

π_1 = σ(t[i]),  π_2 = 1 − σ(t[i])    (7)

Substituting (7) into (6) with n = 2, we get:

y_1 = exp((log σ(t[i]) + g_1) / τ) / [exp((log σ(t[i]) + g_1) / τ) + exp((log(1 − σ(t[i])) + g_2) / τ)],  y_2 = 1 − y_1    (8)

As y_1 < y_2 means the head or channel will be pruned, the final formulation is:

m[i] = 1 if log σ(t[i]) + g_1 ≥ log(1 − σ(t[i])) + g_2, and 0 otherwise    (9)
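A minimal NumPy sketch of the predictor and the binary Gumbel decision described above. Layer sizes, variable names, and the inference-style batch-norm normalization are our assumptions, not the paper's code:

```python
import numpy as np

rng = np.random.default_rng(0)

def predictor(x, W1, W2, gamma=1.0, beta=0.0, eps=1e-5):
    # Feed-forward -> batch normalization (inference-style here) -> ReLU -> feed-forward.
    h = x @ W1
    h = gamma * (h - h.mean()) / np.sqrt(h.var() + eps) + beta
    h = np.maximum(h, 0.0)
    return h @ W2            # logits t, one entry per head (or per channel group)

def gumbel_sigmoid_mask(t, training=True):
    # Training: keep unit i iff log(sigma(t[i])) + g1 >= log(1 - sigma(t[i])) + g2,
    # i.e. y1 >= y2 in the two-class Gumbel-Softmax of the text.
    # Inference: the Gumbel noise is dropped, which reduces to a threshold t >= 0.
    if training:
        g1, g2 = rng.gumbel(size=t.shape), rng.gumbel(size=t.shape)
        pi1 = 1.0 / (1.0 + np.exp(-t))       # sigma(t)
        return (np.log(pi1) + g1 >= np.log(1.0 - pi1) + g2).astype(float)
    return (t >= 0).astype(float)

d, h = 6, 4                      # [CLS] width and number of heads (toy values)
x = rng.normal(size=d)           # [CLS] representation Z[0, :]
W1, W2 = rng.normal(size=(d, d)), rng.normal(size=(d, h))
t = predictor(x, W1, W2)
mask = gumbel_sigmoid_mask(t, training=False)  # one 0/1 decision per head
```

In a real training setup the hard decision would be combined with a straight-through (or relaxed softmax) gradient so that back-propagation can reach the predictor weights.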

Training
The entire training process can be divided into three stages: fine-tune the BERT branch, joint train both branches, and re-train the BERT branch.
Fine-tuning. In the first stage, only the BERT branch is fine-tuned on downstream tasks with the loss L_task. The training strategy is the same as for BERT (Devlin et al., 2019).
Joint Training. In this stage, we jointly train the pre-trained BERT branch and the randomly initialized predictor branch to make the average ratio of remaining floating-point operations (FLOPs) reach a target value C_t ∈ [0, 1]. To achieve this goal, we add a loss that minimizes the difference between the real computational cost of the whole network and C_t:

L_flops = (F_c / F_o − C_t)^2    (10)

where F_o is the FLOPs of the original network, and F_c is the average FLOPs of the current model over a mini-batch.
In addition to the FLOPs constraint, we also add extra loss terms to control the sparsity of each MHA and FFN, as in (11). The purpose is to avoid individual layers becoming so sparse that model accuracy suffers:

L_M = Σ_{i=1}^{L} (F_c^{M,i} / F_o^{M,i} − C_t)^2,  L_F = Σ_{i=1}^{L} (F_c^{F,i} / F_o^{F,i} − C_t)^2    (11)

where F_c^{M,i} and F_c^{F,i} (F_o^{M,i} and F_o^{F,i}) denote the average remaining (original) FLOPs of the i-th MHA and FFN, respectively. The overall objective is:

L = λ_1 · L_task + L_flops + λ_2 · (L_M + L_F)    (12)

where λ_1 and λ_2 control the magnitude of the task and sparsity losses, respectively.
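The training objective can be sketched as below. The global FLOPs loss follows (10) directly; the per-layer squared-deviation form of L_M / L_F is our assumption (the exact formulation is not fully specified here), and all function names are ours:

```python
def flops_loss(layer_flops_kept, layer_flops_full, c_t):
    # (10): squared gap between the network-wide remaining-FLOPs ratio and target C_t.
    f_c, f_o = sum(layer_flops_kept), sum(layer_flops_full)
    return (f_c / f_o - c_t) ** 2

def per_layer_sparsity_loss(layer_flops_kept, layer_flops_full, c_t):
    # Assumed form of L_M / L_F: penalize each layer's own remaining ratio
    # deviating from C_t, discouraging any single MHA/FFN from becoming too sparse.
    return sum((k / f - c_t) ** 2
               for k, f in zip(layer_flops_kept, layer_flops_full))

def total_loss(l_task, kept, full, c_t, lam1, lam2):
    # Combined objective with lam1 weighting the task loss and
    # lam2 weighting the sparsity losses, per the text.
    return (lam1 * l_task
            + flops_loss(kept, full, c_t)
            + lam2 * per_layer_sparsity_loss(kept, full, c_t))

full = [100.0, 100.0, 100.0]
kept = [50.0, 50.0, 50.0]            # every layer keeps exactly half its FLOPs
print(flops_loss(kept, full, 0.5))   # 0.0: the target C_t = 0.5 is met exactly
```

When every layer hits the target ratio, both penalty terms vanish and only the task loss drives training, which is the intended equilibrium of the joint-training stage.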
Re-training. As different input samples activate different subsets of heads, the total number of updates a particular head receives is smaller than in regular training, so the heads may not be trained sufficiently; the same holds for the channels in the FFNs. Therefore, in this stage, we freeze the parameters of the predictors and only re-train the BERT branch.

Inference
The computation flow during inference is shown in Figure 2. Given an input sequence, the predictor generates a mask using the representation of the [CLS] token. For the MHA, heads with mask '0' are not executed. For the FFN, as the matrix-matrix multiplication can be decomposed into multiple matrix-vector multiplications, we only need to perform the part of the computation whose vector mask is non-zero. Note that the exponential operation in (8) is typically expensive on hardware. Fortunately, this formulation can be simplified during inference by removing the Gumbel noise; f(·) can then be rewritten as a simple threshold:

m[i] = 1 if t[i] ≥ 0, and 0 otherwise    (13)

Experiments

Datasets. We evaluate on the GLUE benchmark as well as on SQuAD1.1 (Rajpurkar et al., 2016) and SQuAD2.0 (Rajpurkar et al., 2018), the latter two being large-scale reading comprehension datasets. SQuAD1.1 consists of more than 100k questions, and the answer to each question is a segment of text from the corresponding reading passage. SQuAD2.0 is more difficult, as it contains over 50k unanswerable questions. We mainly report Exact Match (EM) and F1 scores.

Implementation details. We apply the proposed method to both BERT-base and RoBERTa-base, and implement them with the HuggingFace Transformers library. The detailed settings of BERT and the predictors are shown in Table 1. Figure 3 shows the ratio of FLOPs and parameters of each operation in one encoder; we can see that the extra cost of the predictors is very small. All experiments are completed on a single Nvidia GeForce RTX 2080Ti GPU.

For the GLUE benchmark, at the backbone fine-tuning stage we set the batch size to 32, the learning rate to 3e-5, and the number of training epochs to 3, while other hyperparameters are kept unchanged from the library for all downstream tasks. During joint training, we use λ_1 = 4, λ_2 = 20 for BERT and λ_1 = 2, λ_2 = 10 for RoBERTa. The learning rates for the predictors' parameters are 0.02 and 0.01, respectively. The hyperparameters in the third stage are the same as in the first stage.
For SQuAD1.1 and SQuAD2.0, the batch size is 12, the learning rate is 3e-5, and the number of training epochs is 2. Other settings are consistent with those for BERT on the GLUE benchmark.
Baseline. To evaluate the effectiveness of EBERT, we implement a Top-k version of BERT in which f(·) is defined as in (14):

f(t)[i] = 1 if t[i] is among the k largest entries of t, and 0 otherwise    (14)

We keep the sparsity of each layer the same, so the value of k is determined by C_t. Moreover, for a given k, the sparsity is fixed, so no extra loss needs to be added; the training objective is just L_task. The training method is otherwise the same as for EBERT with Gumbel-Softmax.
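The Top-k selection rule above can be sketched in a few lines of NumPy (the function name is ours, not from the paper):

```python
import numpy as np

def topk_mask(t, k):
    # Keep exactly the k heads/channels with the largest predictor logits.
    mask = np.zeros_like(t)
    mask[np.argsort(t)[-k:]] = 1.0
    return mask

t = np.array([0.3, -1.2, 2.5, 0.7])
print(topk_mask(t, 2))   # [0. 0. 1. 1.]
```

Because exactly k units survive in every layer, the remaining-FLOPs ratio is constant by construction, which is why this baseline needs no sparsity loss.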
For convenience, in the following sections we use the subscript t for the Top-k version and the subscript g for the Gumbel-Softmax version.

Results on the GLUE benchmark
The main results of our proposed method on the development sets of the GLUE benchmark are shown in Figure 4. For BERT-base, the results of the Gumbel-Softmax version are always better than those of Top-k at the same or even a smaller ratio of remaining FLOPs on all four tasks. For example, with 50% of FLOPs remaining, EBERT_g only drops 0.6% on the QQP task, while EBERT_t drops 1.8%. On the MNLI task, the accuracy of EBERT_g with 77% remaining FLOPs is higher than that of EBERT_t with 81% remaining FLOPs. Figure 4(b) shows the performance of ERoBERTa, where we find a similar trend; e.g., with 50% remaining FLOPs, the performance of ERoBERTa_g is 2.3% higher than ERoBERTa_t on the MNLI task. This demonstrates the generality of our proposed method across different models.

Results on the SQuAD benchmark
To further demonstrate the generality of our method, we conduct experiments on the SQuAD v1.1 and v2.0 benchmarks, reading comprehension tasks in which the model needs to predict the answer text span within a passage for a given question. The results are shown in Figure 5. Similar to the observations in Figure 4, our approach achieves consistent improvements over the Top-k version at every ratio of remaining FLOPs. For instance, with 50% remaining FLOPs, EBERT_g improves the EM and F1 scores by 2.8% and 2.4% on SQuAD v1.1, respectively. On SQuAD v2.0, the improvements in EM and F1 are 3.3% and 3.4%.

Comparison with Other Methods
We compare our proposed EBERT with other state-of-the-art compression methods. For distillation methods, we compare with DistilBERT (Sanh et al., 2019), BERT-PKD (Sun et al., 2019), and BERT-of-Theseus. For pruning, we compare with SNIP. We also compare with two other dynamic methods: DeeBERT and PABEE. We do not compare with FastBERT, as it does not report results on the GLUE and SQuAD benchmarks.
Note that the other works do not report FLOPs. However, as all of these methods reduce computational cost by reducing the number of layers, either dynamically or statically, it is reasonable to derive FLOPs from the reported speedup or compression ratios under the assumption that FLOPs are proportional to the execution time of a specific layer. For example, as DistilBERT-6L has half the number of layers of BERT-base, we assume its ratio of remaining FLOPs is 50%.

Effect of Re-training

The training process of EBERT contains three stages: fine-tuning, joint training, and re-training. The purpose of re-training is to make each head and channel sufficiently trained. To evaluate the efficacy of this stage, we conduct experiments with RoBERTa on two tasks. The results are shown in Figure 6; the performance improvement is obvious. With 50% remaining FLOPs, the performance of the model is improved from 84.4% to 85.0% on MNLI and from 92.2% to 92.8% on SST-2. The average performance improvement on MNLI and SST-2 is 0.4% and 0.8%, respectively. Comparing these two results, we find that the improvement is more obvious on the smaller dataset. The reason for this phenomenon is that the parameters of the model are updated more frequently on large datasets, which makes the training of the model more sufficient at the joint training stage. As a result, re-training can be skipped for large datasets as a trade-off.

Mask Distribution
Like previous work, we investigate the distribution of the learned masks. Although EBERT dynamically generates a mask for each head and channel for different samples, some masks may be constant at all times, which means that those masks are input-independent. Figure 8 is a layer-wise visualization of the mask distribution in the MHA and FFN on the SST-2 task for masks that are 1) always one (on), 2) always zero (off), and 3) input-dependent. We can see that a large subset of the masks is input-dependent for both heads and channels, which indicates that our model learns to predict the importance of heads and channels for different input samples. For heads, the proportion of input-dependent masks is higher in the shallow layers. For channels, the 2nd, 5th, 8th and 11th layers have a higher proportion of input-dependent masks than the other layers.
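The always-on / always-off / input-dependent categorization above can be computed from masks collected over a dataset; a small sketch (naming is ours):

```python
import numpy as np

def classify_masks(masks):
    # masks: (num_samples, num_units) array of 0/1 decisions collected over a dataset.
    on = masks.all(axis=0)                     # always one  -> input-independent "on"
    off = (~masks.astype(bool)).all(axis=0)    # always zero -> input-independent "off"
    dependent = ~(on | off)                    # flips between samples -> input-dependent
    return int(on.sum()), int(off.sum()), int(dependent.sum())

masks = np.array([[1, 0, 1, 1],
                  [1, 0, 0, 1],
                  [1, 0, 1, 0]])
print(classify_masks(masks))  # (1, 1, 2)
```

Running this per layer over the heads (or channel groups) yields exactly the three proportions visualized in Figure 8.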

Layer Distribution
In Section 3.3, we add two extra losses, L_M and L_F, to prevent some layers from becoming too sparse. We conduct experiments on the SST-2 task with RoBERTa to verify the effectiveness of these constraints. Figure 7 shows the average number of non-pruned heads in the MHA and non-pruned channels in the FFN at different ratios of remaining FLOPs. We can see that the numbers across layers are quite close, which indicates that the average amount of computation per layer is similar. More importantly, this value is near the target C_t. For example, with 80% FLOPs remaining, the number of non-pruned heads is around 9, close to 80% of the number of heads in one MHA. Similarly, the number of non-pruned heads is around 4 to 5 with 40% FLOPs remaining. This phenomenon shows that L_M and L_F do limit the sparsity of each layer.

Conclusion and Future Works
In this paper, we propose EBERT, a novel pruning method for efficient BERT inference. With the help of the predictor branch, EBERT can dynamically prune unimportant heads in the MHA and unimportant channels in the FFN for each input sample at run-time. Experiments on the GLUE and SQuAD benchmarks demonstrate that EBERT achieves a better accuracy-efficiency trade-off than other compression methods.
As discussed in Section 4.1, the performance of our method on small datasets has large variance. Similar observations have also been reported in other works (e.g., SNIP). To improve the generality of our method, it would be interesting to find out the exact reason and a corresponding solution.