LeeBERT: Learned Early Exit for BERT with cross-level optimization

Pre-trained language models (PLMs) like BERT perform well on a wide range of natural language tasks. However, they are resource-hungry and computationally expensive in industrial scenarios. Thus, early exits are installed at each layer of BERT to perform adaptive computation: easier samples are predicted with the first few layers, speeding up inference. In this work, to improve efficiency without a performance drop, we propose a novel training scheme called Learned Early Exit for BERT (LeeBERT). First, we ask each exit to learn from the others, rather than learning only from the last layer. Second, the weights of the different loss terms are learned, thus balancing the different objectives. We formulate the training of LeeBERT as a bi-level optimization problem, and we propose a novel cross-level optimization (CLO) algorithm to improve the optimization results. Experiments on the GLUE benchmark show that our proposed methods improve on the state-of-the-art (SOTA) early exit methods for pre-trained models.


Introduction
The last couple of years have witnessed the rise of pre-trained language models (PLMs), such as BERT (Devlin et al., 2018), GPT (Radford et al., 2019), XLNet (Yang et al., 2019), and ALBERT (Lan et al., 2020). By pre-training on unlabeled corpora and fine-tuning on labeled task data, BERT-like models achieve considerable improvements on many Natural Language Processing (NLP) tasks, such as text classification, natural language inference (NLI), and sequence labeling.
However, these PLMs suffer from two problems. The first is efficiency. The state-of-the-art (SOTA) results of these models usually rely on very deep architectures with high computational demands, which impairs their practicality. Industrial settings such as general search engines or online medical consultation services generally process millions of requests per minute. What makes efficiency even more critical is that the traffic of online services varies drastically over time. For example, during flu season, the search requests of Dingxiangyuan 1 are ten times more numerous than usual, and for online shopping the number of requests during holidays is five to ten times that of workdays. Many servers need to be deployed to serve BERT in industrial settings, which is unaffordable for many companies.

* Contact: 52205901018@stu.ecnu.edu.cn.
Second, previous literature (Fan et al., 2020; Michel et al., 2019) pointed out that large PLMs with dozens of stacked Transformer layers are over-parameterized and can suffer from the "overthinking" problem (Kaya et al., 2019). That is, for many input samples, the representations at a shallow layer are enough to make a correct classification. In contrast, the final layer's representations may be overfitted or distracted by irrelevant features that do not generalize. The overthinking problem leads not only to poor generalization but also to wasted computation.
To address these issues, both industry and academia have devoted themselves to accelerating PLMs at inference time. Standard methods include direct network pruning (Zhu and Gupta, 2018; Fan et al., 2020; Michel et al., 2019), knowledge distillation (Sun et al., 2019; Sanh et al., 2019; Jiao et al., 2020), weight quantization (Zhang et al., 2020; Bai et al., 2020; Kim et al., 2021), and adaptive inference (Geng et al., 2021). Among them, adaptive inference has attracted much attention. Given that real-world data is usually composed of easy and difficult samples, adaptive inference aims to deal with easy examples using only a small part of a PLM, thus speeding up inference on average. The speed-up ratio can be controlled with certain hyper-parameters to cope with drastic changes in request traffic. Moreover, adaptive inference can address the overthinking problem and improve the model's generalization ability.
Early exiting is one of the most important adaptive inference methods (Bolukbasi et al., 2017). It implements adaptive inference by installing exits, i.e., intermediate prediction layers, at each layer of BERT and letting "easy" samples exit at shallow layers to speed up inference (Figure 1). Various early exiting strategies have been designed (Teerapittayanon et al., 2016; Kaya et al., 2019) to decide when to exit, given the predictions obtained from the current and previous layers.
The training procedure of early exiting architectures is essentially a multi-objective problem, since each exit tries to improve its own performance. The objectives of the different classifiers may conflict and interfere with one another (Phuong and Lampert, 2019; Yu et al., 2020). Thus, prior works incorporate a distillation loss that improves training by encouraging the early exits to mimic the output distribution of the last exit. The motivation is that the last exit has the maximum network capacity and should be more accurate than the earlier exits. In these works, only the last exit acts as a teacher, and the multiple objectives are uniformly weighted.
In this work, we propose a novel training mechanism called Learned Early Exit for BERT (LeeBERT). Our contributions are threefold. First, instead of learning only from the last exit, LeeBERT asks the exits to learn from each other. The motivation is that different layers extract features of varying granularity and thus have different perspectives on the sentence. Distilling knowledge from each other improves the expressiveness of the lower exits and alleviates the overfitting of the later exits. Second, to achieve optimal trade-offs between the different loss terms, their weights are treated as parameters and are learned along with the model parameters. The optimization of the learnable weights and the model parameters is formulated as a bi-level optimization problem and solved with gradient descent. Building upon previous literature, we propose a novel cross-level optimization (CLO) algorithm to better solve this bi-level problem.
Extensive experiments are conducted on the GLUE benchmark (Wang et al., 2018), showing that LeeBERT outperforms existing SOTA BERT early exiting methods, sometimes by a large margin. Ablation studies show that: (1) knowledge distillation among all the exits improves their performance, especially for the shallow ones; (2) our novel CLO algorithm learns more suitable weights and brings performance gains.
Our contributions are integrated into our LeeBERT framework and can be summarized as follows:
• We propose a novel training method for early exiting PLMs that asks each exit to learn from the others.
• We propose to find the optimal trade-off of different loss terms by assigning learnable weights.
• We propose a novel cross-level optimization (CLO) algorithm to learn the loss term weights better.

Preliminaries
In this section, we introduce the necessary background for BERT early exiting. Throughout this work, we consider multi-class classification with samples {(x_n, y_n) : x_n ∈ X, y_n ∈ Y, n = 1, 2, ..., N}, e.g., sentences, where the number of classes is K.

Backbone models
In this work, we adopt BERT and ALBERT as backbone models. BERT is a multi-layer Transformer (Vaswani et al., 2017) network, which is pre-trained in a self-supervised manner on a large corpus. ALBERT is more lightweight than BERT since it shares parameters across different layers, and the embedding matrix is factorized.

Early exiting architecture
As depicted in Figure 1, early exiting architectures are networks with exits at different Transformer layers. With M exits, M classifiers p_m : X → Δ_K (m = 1, 2, ..., M) are installed at M layers of BERT, each of which maps its input to the probability simplex Δ_K, i.e., the set of probability distributions over the K classes. Previous literature (Phuong and Lampert, 2019) thinks of p_1, ..., p_M as being ordered from least to most expressive. However, in terms of generalization ability, due to the overthinking problem, later layers may not be superior to shallow ones.

Figure 1: The training procedure of LeeBERT, which differs from the previous literature in two aspects. First, we let exits learn from each other, instead of only asking shallow exits to learn from the deepest exit. Second, the importance of each distillation loss term is learned along with the model parameters.
In principle, the classifiers may or may not share weights and computation, but in the most interesting and practically useful case, they share both.

Early exiting strategies
There are mainly three early exiting strategies for BERT. BranchyNet (Teerapittayanon et al., 2016), FastBERT, and DeeBERT calculate the entropy of the prediction probability distribution as a proxy for the confidence of the exiting classifiers. Shallow-Deep Nets (Kaya et al., 2019) and RightTool (Schwartz et al., 2020) leverage the softmax scores of the exiting classifiers' predictions: if the score of a particular class is dominant and large enough, the model exits. Recently, PABEE (Zhou et al., 2020) proposed a patience-based exiting strategy analogous to early stopping in model training: if the exits' predictions remain unchanged for a pre-defined number of consecutive layers (the patience), the model stops inference and exits. PABEE achieves SOTA results for BERT early exiting.
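As a minimal sketch, the patience-based decision rule can be written as follows (plain Python with NumPy; the function name, the per-layer logits, and the patience value below are illustrative, not the paper's actual configuration, and the streak-counting convention is one common reading of "predictions remain unchanged"):

```python
import numpy as np

def patience_exit(logits_per_layer, patience):
    """Patience-based early exiting in the style of PABEE.

    Scans the exits in depth order and stops as soon as `patience`
    consecutive exits agree on the predicted class.
    Returns (predicted_class, number_of_layers_used).
    """
    streak, last_pred = 0, None
    for depth, logits in enumerate(logits_per_layer, start=1):
        pred = int(np.argmax(logits))
        streak = streak + 1 if pred == last_pred else 1
        last_pred = pred
        if streak >= patience:
            return pred, depth  # exit early
    return last_pred, len(logits_per_layer)  # no early exit: use the last layer
```

For example, with per-layer logits `[[2, 1], [0, 5], [0, 4], [0, 3], [1, 2], [9, 0]]` and patience 3, the model exits at layer 4 with class 1; a larger patience forces it through all six layers.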
In this work, we mainly adopt PABEE's patience-based early exiting strategy. However, in the ablation studies, we show that our LeeBERT framework can also improve the inference performance of other exiting strategies.

Our LeeBERT framework
In this section, we introduce the proposed LeeBERT framework. First, we present our distillation-based loss design, and then we elaborate on how to optimize with learnable weights. Our main contribution is a novel training mechanism for BERT early exiting, which extends previous work (Phuong and Lampert, 2019) via mutual distillation and learned weights.

Classification loss
When receiving an input sample (x_n, y_n), each exit calculates the cross-entropy loss based on its prediction, and all the exits are optimized simultaneously with the summed loss, i.e.,

$\mathcal{L}_{CE} = \sum_{m=1}^{M} \mathcal{L}_{CE}^{(m)}, \quad \mathcal{L}_{CE}^{(m)} = -\sum_{n=1}^{N} \log [p_m(x_n)]_{y_n},$   (1)

where $[p_m(x_n)]_{y_n}$ is the probability that exit m assigns to the gold class. Note that the above objective directly assumes uniform weights for all M loss terms.
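As a toy illustration of the uniformly weighted objective, assuming each exit outputs a probability vector for the sample (the function name is ours, purely for illustration):

```python
import numpy as np

def summed_ce_loss(exit_probs, label):
    """Uniformly weighted sum of per-exit cross-entropy losses for one sample.

    exit_probs: list of M probability vectors, one per exit.
    label: index of the gold class.
    """
    return sum(-np.log(p[label]) for p in exit_probs)
```

Each exit contributes its own cross-entropy term, and the terms are simply added with weight 1.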

Distillation loss
To introduce our contribution, we first remind the reader of the classical distillation framework introduced in Hinton et al. (2015): assume we want a probabilistic classifier s (student) to learn from another classifier t (teacher). This can be achieved by minimizing the (temperature-scaled) cross-entropy between their prediction distributions,

$\ell_{dist}(s, t; x) = -\tau^2 \sum_{k=1}^{K} [t^{1/\tau}(x)]_k \log [s^{1/\tau}(x)]_k,$   (2)

where τ ∈ R_+ is the distillation temperature,

$[t^{1/\tau}(x)]_k = \frac{t_k(x)^{1/\tau}}{\sum_{j=1}^{K} t_j(x)^{1/\tau}}$

is the distribution obtained from the distribution t(x) by temperature scaling, and $[s^{1/\tau}(x)]_k$ is defined analogously. The temperature parameter controls the softness of the teacher's predictions: the higher the temperature, the more the difference between the largest and smallest values of the probability vector is suppressed. Temperature scaling compensates for the over-confidence of the network's outputs, i.e., they put too much probability mass on the top predicted class and too little on the others. The factor τ² in Eq 2 ensures that temperature scaling does not negatively affect the gradient magnitude.
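The temperature-scaled distillation loss of Eq 2 can be sketched in NumPy as follows (the helper names `temperature_scale` and `distill_loss` are our own, not from the paper):

```python
import numpy as np

def temperature_scale(p, tau):
    """Rescale probability vector p with temperature tau and renormalize."""
    q = p ** (1.0 / tau)
    return q / q.sum()

def distill_loss(student_p, teacher_p, tau):
    """Temperature-scaled cross-entropy between teacher and student (Eq 2).

    The tau**2 factor keeps gradient magnitudes comparable across temperatures.
    """
    s = temperature_scale(student_p, tau)
    t = temperature_scale(teacher_p, tau)
    return -tau ** 2 * np.sum(t * np.log(s))
```

At τ = 1 this reduces to the ordinary cross-entropy between the two distributions; larger τ flattens both distributions before comparing them.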
Returning to the early exiting architecture, we follow the same strategy as classical distillation but use the exits at different layers both as students and as teachers. For any exit m, let T(m) ⊆ {1, ..., M} (which could be empty) be the set of teacher exits it is meant to learn from. Then we define the overall distillation loss as

$\mathcal{L}_{dist} = \sum_{m=1}^{M} \sum_{t \in T(m)} \ell_{dist}(p_m, p_t; x).$   (4)

Previous work (Phuong and Lampert, 2019) considers using only the last exit as the teacher, from which all other exits learn. The usual belief is that deeper exits have more network capacity and are more accurate than the early exits. However, the overthinking phenomenon reveals that later exits may not be superior to earlier ones. Shallower exits may provide different perspectives on the semantics of the input sentence. Thus, to fully exploit the available information, later exits can also benefit from learning from early exits. With this motivation, we consider two settings:
Learn from Later Exits (LLE). In this setting, each exit learns from all of its later exits.
Learn from All Exits (LAE). In this setting, an exit learns from all other exits.
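The two teacher-set constructions can be sketched as follows (1-based exit indices as in the text; the function name is hypothetical):

```python
def teacher_set(m, M, mode):
    """Teacher exits T(m) for exit m (1-based) out of M exits.

    "LLE": exit m learns from all of its later exits.
    "LAE": exit m learns from every other exit.
    """
    if mode == "LLE":
        return list(range(m + 1, M + 1))
    if mode == "LAE":
        return [t for t in range(1, M + 1) if t != m]
    raise ValueError("unknown mode: " + mode)
```

For instance, with M = 4, exit 2 has teachers [3, 4] under LLE and [1, 3, 4] under LAE; the last exit has an empty teacher set under LLE.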

Weighted loss
Previous work considers uniform weights for the distillation and classification loss terms, which does not effectively account for the trade-offs among the multiple objectives. First, from the perspective of knowledge distillation, later exits should intuitively place little weight on the very early exits, since those have less to offer, and all exits should place higher importance on exits that are performant and not overfitting. Second, the different loss objectives usually compete, which may hurt the final results.
To address these issues, we propose to assign a set of learnable weights to our loss objective, updated via gradient descent along with the model parameters. We give weight w_i to the classification loss term of exit i and weight w_{m,t} to the distillation loss term of exit m learning from exit t, so the overall loss objective becomes

$\mathcal{L} = \sum_{i=1}^{M} w_i \mathcal{L}_{CE}^{(i)} + \sum_{m=1}^{M} \sum_{t \in T(m)} w_{m,t}\, \ell_{dist}(p_m, p_t; x).$   (5)

Note that Ω = {w_i, w_{m,t}} can be understood as a set of learnable training hyper-parameters.
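A toy sketch of this weighted objective for a single sample, with the learnable weights in Ω passed in as plain containers (all names are illustrative; a real implementation would keep the weights as trainable tensors and average over a batch):

```python
import numpy as np

def weighted_loss(exit_probs, label, w_cls, w_dist, teachers, tau):
    """Weighted classification + mutual-distillation objective for one sample.

    exit_probs: list of M probability vectors (0-based exit indices here).
    w_cls[i]: weight of exit i's cross-entropy term.
    w_dist[(m, t)]: weight of exit m distilling from teacher exit t.
    teachers[m]: list of teacher exits T(m).
    Together, w_cls and w_dist play the role of the learnable set Omega.
    """
    def scale(p):
        q = p ** (1.0 / tau)
        return q / q.sum()

    loss = sum(-w_cls[i] * np.log(exit_probs[i][label])
               for i in range(len(exit_probs)))
    for m, ts in teachers.items():
        for t in ts:
            s, te = scale(exit_probs[m]), scale(exit_probs[t])
            loss += -w_dist[(m, t)] * tau ** 2 * np.sum(te * np.log(s))
    return loss
```

Setting all weights to 1 recovers the uniformly weighted objective of the previous subsections.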

Single vs. Bi-level optimization
Assume we have two datasets D_1 and D_2, usually both subsets of the training set D_tr (D_1 may equal D_2). For a given Ω = {w_i, w_{m,t}}, the optimal network parameters Θ*(Ω) are derived on D_1, and the optimal Ω* is determined on D_2. For convenience, we denote the loss on dataset D as L_D(Θ, Ω), a function of the two sets of parameters. The optimization problem then becomes

$\min_{\Omega}\; \mathcal{L}_{D_2}(\Theta^*(\Omega), \Omega) \quad \text{s.t.} \quad \Theta^*(\Omega) = \arg\min_{\Theta} \mathcal{L}_{D_1}(\Theta, \Omega).$   (6)

Though this bi-level optimization accurately describes our problem, it is generally difficult to solve. One heuristic simplification is to let D_1 = D_2 = D_tr, which reduces the problem to a single-level optimization (SLO) that can be solved directly by stochastic gradient descent. This reduced formulation treats the learnable weights Ω simply as part of the model parameters. Despite its efficiency, the number of parameters in Ω is almost negligible compared with Θ, so the optimization is dominated by fitting Θ, resulting in inadequate solutions for Ω.
The most widely adopted optimization algorithm for this bi-level problem is the bi-level optimization (BLO) algorithm, which requires D_1 and D_2 to be a random split of D_tr. 2 Gradient descent then proceeds as

$\Theta \leftarrow \Theta - \eta_{\Theta} \nabla_{\Theta} \mathcal{L}_{D_1}(\Theta, \Omega), \qquad \Omega \leftarrow \Omega - \eta_{\Omega} \nabla_{\Omega} \mathcal{L}_{D_2}(\Theta, \Omega),$

that is, the parameters are updated in an interleaving fashion: one gradient-descent step of Θ on D_1, followed by one gradient-descent step of Ω on D_2. Note that the inner optimality condition Θ*(Ω) is not exactly satisfied in BLO due to the first-order approximation, which leads the gradient updates of Ω in wrong directions and can collapse the bi-level optimization.
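The interleaved BLO update can be sketched abstractly as follows, with gradient oracles standing in for ∇_Θ L_{D_1} and ∇_Ω L_{D_2} (the function name, signatures, and the toy quadratic gradients in the usage note are ours, purely for illustration):

```python
def blo_step(theta, omega, grad_theta_d1, grad_omega_d2, lr_t, lr_o):
    """One interleaved BLO update for scalar toy parameters:
    a theta step on D1, followed by an omega step on D2 at the new theta."""
    theta = theta - lr_t * grad_theta_d1(theta, omega)
    omega = omega - lr_o * grad_omega_d2(theta, omega)
    return theta, omega
```

For example, with toy gradients `g1 = lambda t, o: 2 * (t - o)` and `g2 = lambda t, o: -2 * (t - o)`, learning rates 0.1, and start point (1.0, 0.0), one step gives (0.8, 0.16): omega's step already sees the freshly updated theta.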

Cross-level optimization
We now propose our cross-level optimization algorithm. The gradient descent updates of Θ and Ω follow

$\Theta \leftarrow \Theta - \eta_{\Theta} \nabla_{\Theta} \mathcal{L}_{D_1}(\Theta, \Omega), \qquad \Omega \leftarrow \Omega - \eta_{\Omega} \left( \nabla_{\Omega} \mathcal{L}_{D_1}(\Theta, \Omega) + \nabla_{\Omega} \mathcal{L}_{D_2}(\Theta, \Omega) \right).$   (9)

The above updates are the core of our CLO algorithm, which we refer to as CLO-v1; they are derived and discussed in detail in the Appendix. The core idea of cross-level optimization is to draw gradient information from both splits of the training set, thus making the updates of Ω more reliable.
Note that updating Ω requires its gradients on both D_1 and D_2, so the computational complexity is higher than that of the BLO algorithm. We therefore propose a more efficient version of cross-level optimization (CLO-v2), which can also be found in the Appendix. We divide the training procedure into groups of C steps: Θ is updated solely on the training set for C − 1 steps, and the update follows Eq 9 for the remaining step. We call the hyper-parameter C the cross-level cycle length. CLO-v2 is more efficient than CLO-v1, and our experiments show that CLO-v2 works well and is comparable with CLO-v1.
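The CLO-v2 schedule can be sketched as follows, again with gradient oracles standing in for the loss gradients (names and the toy setup are illustrative, not the paper's implementation):

```python
def clo_v2_train(theta, omega, grad_theta_d1, grad_omega_d1, grad_omega_d2,
                 steps, C, lr_t, lr_o):
    """CLO-v2 schedule for scalar toy parameters: theta is updated on the
    training data at every step, while omega takes one cross-level step per
    cycle of C steps, using its gradients on both splits D1 and D2 (the
    CLO-v1 update)."""
    for step in range(steps):
        theta = theta - lr_t * grad_theta_d1(theta, omega)
        if (step + 1) % C == 0:  # the single cross-level step in this cycle
            omega = omega - lr_o * (grad_omega_d1(theta, omega)
                                    + grad_omega_d2(theta, omega))
    return theta, omega
```

With C = 1, every step is a cross-level step and the schedule degenerates to CLO-v1; larger C amortizes the extra gradient computations for Ω over the cycle.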

Tasks and Datasets
We evaluate our proposed approach on the classification tasks of the GLUE benchmark. We exclude the STS-B task since it is a regression task, and we exclude the WNLI task following previous work (Devlin et al., 2018; Jiao et al., 2020).

Backbone models
All of the experiments are built upon Google's BERT and ALBERT. We ensure a fair comparison by setting the hyper-parameters related to the PLM backbones the same as in HuggingFace Transformers (Wolf et al., 2020).

Baseline methods
We compare with previous BERT early exiting methods as well as other methods that speed up BERT inference.
Directly reducing layers. We experiment with directly using the first 6 and 9 layers of the original (AL)BERT with a single output layer on top, denoted (AL)BERT-6L and (AL)BERT-9L, respectively. These two baselines serve as a lower bound on performance, since they do not employ any additional technique.
Static model compression approaches.
Input-adaptive inference. This category includes the entropy-based method DeeBERT, the score-based method Shallow-Deep, and the patience-based exiting method PABEE as our baselines. We also include the results of FastBERT when it adopts PABEE's exiting strategy.

Table 1: Experimental results of models with the ALBERT backbone on the GLUE development and test sets. If not specified, LeeBERT and its variants (e.g., LeeBERT-LLE) are optimized using CLO-v2. The mean performance scores of 5 runs are reported. The speed-up ratio is averaged across 7 tasks. Best performances are bolded; "*" indicates that the performance gains are statistically significant.

Experimental settings
We implement LeeBERT on top of HuggingFace's Transformers. We conduct our experiments on a single Nvidia V100 16GB GPU.
Training. We add a linear output layer after each intermediate layer of the pre-trained BERT/ALBERT model as the internal classifier. Hyper-parameter tuning is done with cross-validation on the training set, so that the dev set information of the GLUE tasks is not revealed. We perform grid search over batch sizes {16, 32, 128}; learning rates {1e-5, 2e-5, 3e-5, 5e-5} for the model parameters Θ; and learning rates {1e-5, 1e-4, 1e-3, 5e-3} for the learnable weights Ω. The cross-level cycle length C is selected from {2, 4, 8}. We adopt the Adam optimizer. At each epoch, the training set is randomly split into D_1 and D_2 with a 5 : 5 ratio. We apply early stopping with patience 5 and evaluate the model on the dev set at the end of each epoch. We define the dev performance of our early exiting architecture as the average performance of all the exits, and we select the model with the best average performance in cross-validation.
We set CLO-v2 as the main optimization algorithm of LeeBERT, and LAE as the main distillation strategy. 4 To demonstrate that LeeBERT's distillation objectives are beneficial, we also train LeeBERT with the LLE strategy (LeeBERT-LLE). We additionally make the loss term weights in FastBERT learnable and train them with our CLO-v2 algorithm, i.e., FastBERT-CLO-v2.
To compare LeeBERT's CLO optimization procedure with baselines, we also train LeeBERT with (1) the single-level algorithm (LeeBERT-SLO) and (2) the bi-level algorithm (LeeBERT-BLO). To compare CLO-v1 and CLO-v2, we also train LeeBERT with CLO-v1, i.e., LeeBERT-CLO-v1. Besides, we include LeeBERT with randomly assigned discrete weights (LeeBERT-rand) and uniform weights (LeeBERT-uniform) as baselines, which serve to demonstrate that our optimization procedure is beneficial. The discrete weights are randomly selected from {1, 2, ..., 50} and normalized so that the loss terms at each exit have weights summing to 1.
Inference. Following prior work, inference with early exiting is done on a per-instance basis, i.e., the batch size for inference is set to 1. We believe this setting mimics the common latency-sensitive production scenario of processing individual requests from different users. We report the mean performance over 5 runs with different random seeds. For DeeBERT and Shallow-Deep, we set the threshold for the entropy or score such that the speed-up ratio is between 1.80x and 2.1x. For FastBERT and our LeeBERT, we mainly adopt PABEE's patience-based exiting strategy, and we compare the results when the patience is set to 4. How the patience parameter affects inference efficiency is also investigated for PABEE, FastBERT, and LeeBERT.

Table 1 reports the main results on GLUE with ALBERT as the backbone model. ALBERT is parameter- and memory-efficient due to its cross-layer parameter sharing strategy; however, it still has high inference latency. From Table 1 we can see that our approach outperforms all compared methods for improving inference efficiency while maintaining good performance, demonstrating the effectiveness of the proposed LeeBERT framework. Note that our system effectively enhances the original ALBERT and PABEE by a relatively large margin while speeding up inference by 1.97x. We also conduct experiments on the BERT backbone with the MNLI, MRPC, and SST-2 tasks, which can be found in the Appendix. To give more insight into how the early exits perform under different efficiency settings, we illustrate how the patience parameter affects the average number of inference layers (which is directly related to the speed-up ratio) (Figure 2) and the prediction performance (Figure 3). We also show that one can easily apply our LeeBERT framework to image classification tasks in the Appendix.

Analysis
We now analyze the main takeaways from Table 1 and our experiments in more depth.
Our LeeBERT can speed up inference. Figure 2 shows that on the MRPC task, with the same patience parameter, LeeBERT usually goes through fewer layers (on average) than PABEE and FastBERT, showing that LeeBERT can improve the efficiency of PLMs' early exiting.
Our knowledge distillation strategies are beneficial. Table 1 reveals that, among the distillation strategies, our LAE setting provides the best overall performance on GLUE. LeeBERT outperforms FastBERT-CLO-v2 on all tasks and exceeds LeeBERT-LLE on six of the seven tasks; on QNLI the results are comparable. This result shows that having the exits learn from each other is generally beneficial.
Our CLO algorithm brings performance gains. As a sanity check, LeeBERT-rand performs worse than all optimized LeeBERT models. Table 1 also shows that the SLO and BLO algorithms perform worse than our CLO. We can also see that CLO-v1 and CLO-v2 obtain comparable results. CLO-v1 seems to have a slight advantage on tasks with few samples, but the performance gaps are marginal. Since CLO-v2 is more efficient, we use it as our main optimization algorithm.
The patience-score curves differ across PLMs. Figures 3(a) and 3(b) show that different PLMs have quite different patience-score curves. For ALBERT, early exiting with PABEE's strategy can improve upon ALBERT-base fine-tuning, with the best performance obtained at patience 6, where the average number of inference layers is 8.11. This phenomenon shows that ALBERT-base may suffer from the overthinking problem. With the help of our distillation strategy and CLO optimization, the performance gain is considerable. Note that: (a) without distillation, the shallow exits' performances are significantly worse, and our distillation helps these exits improve; (b) with LeeBERT, the performances of the later exits are comparable to the earlier ones, since distillation alleviates the overthinking problem. In contrast, the patience-score curve for BERT is quite monotonic, suggesting that the overthinking problem is less severe there. Note that BERT's shallow exits are significantly worse than ALBERT's, and with LeeBERT, the shallow exits' performances are improved.

Training time costs. Table 2 presents the parameter counts and training time costs of LeeBERT compared with the original (AL)BERT, PABEE, and FastBERT. Although the exits need extra time for training, early exiting architectures can actually reduce the training time. Intuitively, the additional loss objectives can be regarded as additional parameter updating steps for the lower layers, thus speeding up model convergence.
LeeBERT-CLO-v1 requires a longer training time. Notably, LeeBERT's time costs are comparable with PABEE's and FastBERT's, even though it has more complicated gradient updating steps.
Working with different exiting strategies. Recall that our results are mainly obtained by adopting PABEE's patience-based exiting strategy. However, our LeeBERT framework is quite off-the-shelf and can be integrated with many other exiting strategies. 5 When using the entropy-based strategy, LeeBERT outperforms DeeBERT.

Conclusion and discussions
In this work, we propose a new framework for improving PLMs' early exiting. Our main contributions lie in two aspects. First, we argue that the exits should learn and distill knowledge from each other during training. Second, we propose that the training objectives of early exiting networks be weighted differently, with learnable weights. The learnable weights are optimized with our proposed cross-level optimization algorithm. Experiments on the GLUE benchmark datasets show that our framework improves PLMs' early exiting performance, especially under tight latency requirements. Our framework is easy to implement and can be adapted to various early exiting strategies. In the future, we want to explore novel exiting strategies that better guarantee exiting performance.

A Derivation of our cross-level optimization algorithm
Using the Lagrangian multiplier method, the Lagrangian function is

$\mathcal{L}(\Theta, \Omega, \lambda) = \mathcal{L}_{D_2}(\Theta, \Omega) + \lambda\, \mathcal{L}_{D_1}(\Theta, \Omega),$

where λ is the Lagrangian multiplier. Solving this Lagrangian by gradient descent, the updates of Θ and Ω become

$\Theta \leftarrow \Theta - \eta_{\Theta} \nabla_{\Theta} \mathcal{L}_{D_1}(\Theta, \Omega), \qquad \Omega \leftarrow \Omega - \eta_{\Omega} \left( \lambda \nabla_{\Omega} \mathcal{L}_{D_1}(\Theta, \Omega) + \nabla_{\Omega} \mathcal{L}_{D_2}(\Theta, \Omega) \right),$

and setting λ = 1 recovers the CLO-v1 updates of Eq 9. We formally present the CLO-v1 algorithm in Algorithm 1, and the CLO-v2 algorithm in Algorithm 2.

B Hyper-parameters for each task

Table 3 reports the important hyper-parameters of LeeBERT for each task. Note that our hyper-parameter search was done on the training set with cross-validation, so that the GLUE benchmark's dev set information was not revealed during training.

C Results with BERT backbone
We conduct experiments with the BERT backbone on three representative GLUE tasks: MNLI, MRPC, and SST-2. The results are reported in Table 5 and show that our LeeBERT framework works well with different types of PLMs.

D Patience-performance curves on SST-2
We also provide the patience-performance curves (Figure 4) on the SST-2 task, with the ALBERT and BERT backbones.

E Working with different exiting strategies
Our results are mainly obtained by adopting PABEE's patience-based exiting strategy. Here we demonstrate that LeeBERT can also work with other exiting strategies. Table 4 shows that LeeBERT improves DeeBERT with its entropy-based exiting method and outperforms Shallow-Deep with its max-prediction-score-based approach.

F LeeBERT is effective for image classification
To demonstrate the effectiveness of LeeBERT on image classification, we follow the experimental settings of Shallow-Deep (Kaya et al., 2019). We conduct experiments on two image classification datasets, CIFAR-10 and CIFAR-100 (Krizhevsky, 2009), with ResNet-56 as the backbone, following Kaya et al. (2019). After every two convolutional layers, an exiting classifier is added. We set the batch size to 128 and use the SGD optimizer with a learning rate of 0.1. We set the cross-level cycle length to 4, and the learning rate of the learnable weights Ω to 0.01. Table 6 reports the results. LeeBERT outperforms the full ResNet-56 on both tasks even while providing a 1.3x speed-up. Besides, it outperforms PABEE and DBT.