GAML-BERT: Improving BERT Early Exiting by Gradient Aligned Mutual Learning

In this work, we propose a novel framework, Gradient Aligned Mutual Learning BERT (GAML-BERT), for improving the early exiting of BERT. GAML-BERT's contributions are two-fold. First, we conduct a set of pilot experiments, which show that mutual knowledge distillation between a shallow exit and a deep exit leads to better performances for both. From this observation, we use mutual learning to improve BERT's early exiting performances, that is, we ask each exit of a multi-exit BERT to distill knowledge from the others. Second, we propose GA, a novel training method that aligns the gradients from the knowledge distillation loss with those from the cross-entropy loss. Extensive experiments are conducted on the GLUE benchmark, which show that our GAML-BERT significantly outperforms the state-of-the-art (SOTA) BERT early exiting methods.


Introduction
Since BERT (Devlin et al., 2018), pre-trained language models (PLMs) have dominated the field of natural language processing (NLP). Recent years have witnessed the rise of many PLMs, such as GPT (Radford et al., 2019), XLNet (Yang et al., 2019), and ALBERT (Lan et al., 2020). These BERT-style models achieve considerable improvements in many NLP tasks, such as text classification, natural language inference (NLI), and sequence labeling, by pre-training on unlabeled corpora and fine-tuning on labeled tasks. However, PLMs are notorious for being gigantic and slow in both training and inference. Their significant inference latencies pose great challenges to deployment in real-time applications, such as chat-bots and search engines.
In addition, previous literature (Fan et al., 2020; Michel et al., 2019) finds that large PLMs with dozens of Transformer layers are over-parameterized and suffer from the "overthinking" problem (Kaya et al., 2019). That is, for many input samples, their representations at a shallow layer are enough to make a correct classification, and the final layer's representations may be too overfitted to generalize well. The overthinking problem leads to not only poor generalization but also wasted computation.
To address the above two issues, a branch of literature focuses on making PLMs' inference more efficient via network pruning (Zhu and Gupta, 2018; Fan et al., 2020; Michel et al., 2019; Zhu et al., 2021; Zhu, 2021c,a), knowledge distillation (Sun et al., 2019b; Sanh et al., 2019; Jiao et al., 2020), weight quantization (Bai et al., 2020; Kim et al., 2021), and adaptive inference. Among these, adaptive inference has drawn much attention. Adaptive inference is designed to process simple examples with only the shallow layers of BERT and to predict more difficult queries with deeper layers, thus significantly speeding up inference on average while maintaining high accuracy. The speed-up ratio can be controlled with certain hyper-parameters to handle drastic changes in request traffic. What is more, adaptive inference can address the over-thinking problem and improve the model's generalization ability.
Early exiting is one of the most important adaptive inference methods (Bolukbasi et al., 2017). As depicted in Figure 1, it implements adaptive inference by installing an early exit, i.e., an intermediate prediction layer, at each layer of BERT and exiting "easy" samples early to speed up inference. At the training stage, all the exits are jointly optimized with BERT's parameters. At the inference stage, there are two different settings. First, in budgeted exiting mode, the model makes predictions with a fixed exit for all queries. This mode deals with heavy traffic by assigning a shallower exit for prediction. The other is dynamic exiting mode, where a strategy is designed to decide whether to exit at each layer given the predictions obtained so far (from previous and current layers) (Teerapittayanon et al., 2016; Kaya et al., 2019). In this mode, different samples can exit at different depths.

Knowledge distillation (KD) (Hinton et al., 2015) is of essential importance for improving early exiting performances. The traditional belief about KD is that the teacher model (usually a stronger model) teaches a lower-capacity student model "dark knowledge" by providing soft targets. As an application of this traditional belief, Phuong and Lampert (2019) improve the training procedure by incorporating KD losses, which encourage the early exits to mimic the output distribution of the last exit. Yuan et al. (2019) challenge the common belief about KD by revealing that knowledge distillation is actually a learned label smoothing (LS) regularization (Szegedy et al., 2016), and that a weaker teacher can also enhance a stronger student's performance via knowledge distillation. Zhu (2021b) shows that asking all the exits to learn from one another (mutual learning) is beneficial for early exiting. Sun et al. (2019a) introduce dense pairwise knowledge matching operations at certain intermediate layers of deep convolutional networks during training, which are demonstrated to be beneficial for the whole network's generalization. However, Zhu (2021b) proposes this mutual learning framework intuitively and does not fully explore the underlying motivations, while Sun et al. (2019a) focus on improving the whole network and do not investigate how mutual learning (or knowledge matching) affects each layer's exiting performance.
In this work, we first conduct a series of exploratory experiments called pairwise mutual learning (PML). PML selects two exits of BERT and considers the fine-tuning of these exits with or without different knowledge distillation settings. Our experiments show that the following three approaches can improve the shallower exit's performance: (1) adding supervision to the deeper exit; (2) KD from the deeper exit; (3) further asking the deeper exit to learn from the shallower exit. In addition, via the above approaches, the deeper exit's performance also improves. Our experimental findings are consistent with and complement the conclusions of Szegedy et al. (2016) and Sun et al. (2019a). Thus we adopt the mutual learning (ML) framework to enhance the training of early exits. ML asks all the exits to learn from one another (depicted in Figure 1), thus fully releasing the knowledge transfer and regularization capabilities of KD.
Further, by analyzing the directions of the gradients from the cross-entropy loss and the distillation loss (denoted as $g_{CE}$ and $g_{KD}$), we find that $g_{KD}$ is often in a conflicting direction with $g_{CE}$. We hypothesize that $g_{KD \perp CE}$, the part of $g_{KD}$ that is orthogonal to $g_{CE}$, will distract the model from moving toward the optimum. Thus we propose a novel optimization mechanism called Gradient Alignment (GA). As depicted in Figure 2, GA projects $g_{KD}$ onto the direction of $g_{CE}$. We propose two versions of GA, GA-soft and GA-hard, which only differ in how they deal with $g_{KD}$ when the angle between the two gradients is larger than 90°.
We call our framework Gradient Aligned Mutual Learning for BERT (GAML-BERT). Extensive experiments are conducted on the GLUE benchmark (Wang et al., 2018) and show that GAML-BERT outperforms existing SOTA BERT early exiting methods, sometimes by a large margin. Deeper analysis and ablation studies yield the following main takeaways: (a) knowledge distillation among the exits can improve their performances, especially for the shallow ones; (b) our gradient alignment method can improve the training procedure and thus improve the model's generalization performance.

Figure 2: Two scenarios of our proposed gradient alignment method. Note that when the gradients' angle is larger than 90°, $g^{P}_{KD}$ is opposite to $g_{CE}$.
Our contributions are summarized as follows:
• We conduct exploratory experiments to demonstrate that mutual knowledge distillation between two exits of different depths is beneficial for both.
• We propose a novel gradient alignment method for better optimization.

Preliminaries
In this section, we introduce the necessary background for BERT early exiting. Throughout this work, we consider the case of multi-class classification with samples $\{(x_i, y_i) : x_i \in \mathcal{X}, y_i \in \mathcal{Y}, i = 1, 2, ..., N\}$, e.g., sentences, where the number of classes is $K$.

Backbone models
In this work, we adopt BERT as the backbone model. BERT is a multi-layer Transformer (Vaswani et al., 2017) network, which is pre-trained in a self-supervised manner on a large corpus.

Early-exiting Architecture
As depicted in Figure 1, early exiting architectures are networks with exits at each transformer layer. With $M$ exits, $M$ classifiers $f_m(x; \theta_m) : \mathcal{X} \rightarrow \Delta^K$ ($m = 1, 2, ..., M$) are installed at the $M$ layers of BERT, each of which maps its input to the probability simplex $\Delta^K$, i.e., the set of probability distributions over the $K$ classes. All the parameters of the transformer layers and exits are denoted as $\Theta$.
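To make the architecture concrete, below is a minimal PyTorch sketch of a multi-exit encoder. The class and argument names (MultiExitEncoder, layers, hidden_size) are our own illustrative choices under the above definitions, not the paper's actual implementation.

```python
import torch
import torch.nn as nn

class MultiExitEncoder(nn.Module):
    """A stack of M transformer layers with one classifier (exit) f_m per layer."""
    def __init__(self, layers, hidden_size, num_classes):
        super().__init__()
        self.layers = nn.ModuleList(layers)  # the M transformer layers
        # One linear exit per layer, mapping hidden states to K-class logits.
        self.exits = nn.ModuleList(
            nn.Linear(hidden_size, num_classes) for _ in layers
        )

    def forward(self, hidden_states):
        logits_per_exit = []
        for layer, exit_head in zip(self.layers, self.exits):
            hidden_states = layer(hidden_states)
            # Classify from the [CLS] position (index 0) at every layer.
            logits_per_exit.append(exit_head(hidden_states[:, 0]))
        return logits_per_exit  # list of M logit tensors
```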

Training
At the training stage, all the exits are jointly optimized with a summed loss function. Following Huang et al. (2017), the loss function is the weighted average of the cross-entropy (CE) losses:

$$\mathcal{L}^{CE} = \frac{\sum_{m=1}^{M} m \cdot \mathcal{L}^{CE}_m}{\sum_{m=1}^{M} m}, \quad (1)$$

where $\mathcal{L}^{CE}_m = \mathcal{L}^{CE}_m(y, f_m(x; \theta_m))$ denotes the cross-entropy loss of the $m$-th exit. Note that the weight $m$ corresponds to the relative inference cost of exit $m$.
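A minimal sketch of Eq. 1, assuming the list of per-exit logits produced by the MultiExitEncoder sketch above; the function name is illustrative.

```python
import torch
import torch.nn.functional as F

def multi_exit_ce_loss(logits_per_exit, labels):
    """Weighted average of per-exit CE losses: sum_m m * L_m / sum_m m (Eq. 1)."""
    weights = torch.arange(1, len(logits_per_exit) + 1, dtype=torch.float)
    losses = torch.stack(
        [F.cross_entropy(logits, labels) for logits in logits_per_exit]
    )
    return (weights.to(losses.device) * losses).sum() / weights.sum()
```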

Inference
At inference, the multi-exit BERT can operate in two different modes, depending on whether the computational budget to classify an example is known or not.
Budgeted Exiting. If the computational budget is known, we can directly appoint a suitable exit of BERT, $f_m(x; \theta_m)$, to predict all queries.
Dynamic Exiting. Under this mode, after receiving a query input $x$, the model predicts with the classifiers $f_1(x; \theta_1), f_2(x; \theta_2), ...$ in turn in a single forward pass, reusing computation where possible. It continues to do so until it receives a signal to stop early at an exit $M' < M$, or arrives at the last exit $M$. At this point, it outputs the final prediction based on the current and previous predictions. Note that under this early exit setting, different samples might exit at different layers.
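As an illustration, here is a sketch of dynamic exiting with a patience-based stopping signal (the PABEE-style strategy the paper mainly adopts); it reuses the MultiExitEncoder sketch above, and the function name and patience default are our assumptions.

```python
import torch

@torch.no_grad()
def dynamic_exit_predict(model, hidden_states, patience=2):
    """Run exits in turn; stop once `patience` consecutive exits agree."""
    prev_pred, streak = None, 0
    for layer, exit_head in zip(model.layers, model.exits):
        hidden_states = layer(hidden_states)  # reuse computation layer by layer
        pred = exit_head(hidden_states[:, 0]).argmax(-1)
        streak = streak + 1 if (prev_pred is not None and pred.equal(prev_pred)) else 0
        if streak >= patience:  # predictions stable -> exit early at M' < M
            return pred
        prev_pred = pred
    return prev_pred  # fell through to the last exit M
```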

Pilot experiments
In this section, we examine the effects of mutual learning among exits by conducting a series of pilot experiments called pairwise mutual learning (PML). In the PML experiments, we select two exits $(i, j)$ ($i < j$, i.e., exit $i$ is shallower than exit $j$). We consider the following settings (a loss sketch follows the list):
• Directly fine-tuning exit $x$ ($x = i, j$). In this setting, we provide supervision signals to exit $x$ and fine-tune it along with the BERT parameters.
• Fine-tuning exits $i$ and $j$ jointly. That is, we sum up the losses of exits $i$ and $j$ during training.
• Fine-tuning exits $i$ and $j$ jointly, and asking exit $i$ to learn from exit $j$.
• Fine-tuning exits $i$ and $j$ jointly, and asking exit $j$ to learn from exit $i$.
• Fine-tuning exits $i$ and $j$ jointly, and asking the two exits to learn from each other.
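A minimal sketch of the PML objectives above for an exit pair $(i, j)$. The `mode` strings, the weight `alpha`, and the temperature-free KL form are our assumptions for illustration.

```python
import torch.nn.functional as F

def kd(student_logits, teacher_logits):
    # KL(teacher || student); the teacher is detached so it only provides soft targets.
    return F.kl_div(F.log_softmax(student_logits, dim=-1),
                    F.softmax(teacher_logits.detach(), dim=-1),
                    reduction="batchmean")

def pml_loss(logits_i, logits_j, labels, mode, alpha=0.1):
    # Joint fine-tuning: sum of the two exits' CE losses.
    loss = F.cross_entropy(logits_i, labels) + F.cross_entropy(logits_j, labels)
    if mode in ("i_from_j", "mutual"):  # shallow exit i learns from deep exit j
        loss = loss + alpha * kd(logits_i, logits_j)
    if mode in ("j_from_i", "mutual"):  # deep exit j learns from shallow exit i
        loss = loss + alpha * kd(logits_j, logits_i)
    return loss
```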
We conduct the above PML experiments on the CoLA and SST-2 datasets from the GLUE benchmark (Wang et al., 2018). We select two exit pairs, (1, 12) and (6, 12). The performance metrics follow GLUE. Detailed experimental settings are reported in the Appendix. Table 1 reports the results of our pilot experiments. From the results, we can see that:

Result analysis
• Exit 12 benefits from KD from exit 6, demonstrating that the last exit, as a strong student, still obtains performance improvements when it receives knowledge distillation from a much weaker teacher. This observation is consistent with Yuan et al. (2019).
• When a lower exit (1 or 6) is fine-tuned jointly with exit 12 (with no KD), both exits' performances improve. This observation is consistent with Sun et al. (2019a). Intuitively, letting the intermediate layers receive supervision signals improves the lower layers' representation capabilities, thus helping the last exit. For the lower layers, receiving the top layers' gradient signals is beneficial for their optimization, so the lower exit's performance can also be significantly boosted.
• It is expected for a lower exit to improve significantly when it receives KD signals from exit 12, since it receives superior knowledge from the latter (Hinton et al., 2015). However, the lower exits' performances further improve when we introduce mutual KD between the exit pair. Mutual learning improves not only the last exit (consistent with Sun et al. (2019a)) but also the lower exits. We believe the lower exits' extra performance gains come from: (a) a better top layer, whose gradient signals are therefore better; (b) mutual learning driving the behaviors of the exit pair to be more similar, which acts like a regularization that helps to improve generalization performance.

Mutual learning
In light of the above analysis, and following Zhu (2021b) and Sun et al. (2019a), we adopt the mutual learning framework (Figure 1) to explore the potential of early exits. That is, all the exiting classifiers learn from one another. The loss terms from this fully mutual learning framework are added to the cross-entropy losses in Eq. 1, and the loss objective becomes

$$\mathcal{L} = \mathcal{L}^{CE} + \alpha \mathcal{L}^{KD}, \quad (2)$$

where $\mathcal{L}^{KD}$ is given by

$$\mathcal{L}^{KD} = \sum_{m=1}^{M} \sum_{n=1, n \neq m}^{M} \mathrm{KL}\left(f_n(x; \theta_n) \,\|\, f_m(x; \theta_m)\right). \quad (3)$$

The ML framework is different from FastBERT in two aspects. First, FastBERT employs a two-stage learning mechanism, where the optimization with KD is separated from the optimization with the cross-entropy loss, and the BERT backbone is frozen during the optimization with KD losses. Second, FastBERT only asks the lower exits to distill knowledge from the last layer. In the ML framework, each exit receives regularization from all other exits, thus fully exploiting the regularization potential of knowledge distillation. In this work, we re-run FastBERT, and the experimental results will demonstrate that the ML framework outperforms it.
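A sketch of the fully mutual learning term of Eq. 3, where every exit distills from every other exit in a single stage, added to the CE loss of Eq. 1 as in Eq. 2. Detaching the teacher distribution and the pair-count normalization are our assumptions.

```python
import torch.nn.functional as F

def mutual_kd_loss(logits_per_exit):
    """Eq. 3: sum of pairwise KL terms; exit m is the student, exit n the teacher."""
    loss, M = 0.0, len(logits_per_exit)
    for m in range(M):
        for n in range(M):
            if n == m:
                continue
            loss = loss + F.kl_div(
                F.log_softmax(logits_per_exit[m], dim=-1),
                F.softmax(logits_per_exit[n].detach(), dim=-1),
                reduction="batchmean",
            )
    return loss / (M * (M - 1))  # normalization over pairs is an assumption

# Usage (Eq. 2), reusing multi_exit_ce_loss from the Eq. 1 sketch:
# total = multi_exit_ce_loss(logits, labels) + alpha * mutual_kd_loss(logits)
```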

Gradient alignment
The ML training mechanism can stabilize the training process and lead to better optimization by imposing rich regularizations over all the exits. However, during experiments, we still observe that the weight of the KD loss, $\alpha$, has a large impact on model performance, and better performances are achieved by setting $\alpha$ to be relatively small (e.g., 0.1 or 0.2). Our observations are consistent with the experimental observations of Yang et al. (2020) and Sun et al. (2019b). Intuitively, it seems that the KD objective conflicts with the CE loss to a certain degree.
To visualize the interactions between the cross-entropy loss and the KD loss, we separately compute the gradients derived from the two loss objectives, $g_{KD} = \nabla_{\Theta} \mathcal{L}^{KD}$ and $g_{CE} = \nabla_{\Theta} \mathcal{L}^{CE}$. Their angle is given by

$$\gamma = \arccos \frac{\langle g_{KD},\, g_{CE} \rangle}{\|g_{KD}\| \, \|g_{CE}\|}.$$

During BERT fine-tuning on SST-2 with the mutual learning training objective, we calculate the angle between $g_{KD}$ and $g_{CE}$ at each training step. The distribution of the gradient angles is plotted in Figure 3. We can see that in about half of the optimization steps, $\gamma$ is larger than 90°, meaning that the two gradients have conflicting directions for optimization. This observation suggests that we may obtain better convergence if we can somehow align the two gradients, and it naturally leads to the following hypothesis:

Hypothesis 1 (H1): Dropping the part of KD's gradient that conflicts with CE's gradient, and keeping only the part aligned with the latter, can improve the trained model's performance.
Thus, we propose a novel optimization method, gradient alignment (GA), to align KD's gradient $g_{KD}$ with CE's gradient $g_{CE}$. GA is depicted in Figure 2. When we project $g_{KD}$ onto $g_{CE}$, the projected vector is given by

$$g^{P}_{KD} = \frac{\langle g_{KD},\, g_{CE} \rangle}{\|g_{CE}\|^{2}} \, g_{CE}.$$

Denote the final modified gradient as $g_{GA}$. We propose two versions of GA, as follows.

GA-soft. When the angle $\gamma$ is larger than 90°, $g^{P}_{KD}$ is still added to $g_{CE}$. Thus $g_{GA}$ is given by

$$g_{GA} = g_{CE} + g^{P}_{KD}.$$

In GA-soft, when the angle $\gamma$ is larger than 90°, $g^{P}_{KD}$ is in the opposite direction of $g_{CE}$ and thus might slow down or reverse the gradient descent direction.
GA-hard. When the angle $\gamma$ is larger than 90°, $g^{P}_{KD}$ is not added to $g_{CE}$. Thus $g_{GA}$ is given by

$$g_{GA} = \begin{cases} g_{CE} + g^{P}_{KD}, & \gamma \le 90^{\circ}, \\ g_{CE}, & \gamma > 90^{\circ}. \end{cases}$$

In GA-hard, when the angle $\gamma$ is larger than 90°, we discard $g^{P}_{KD}$.

Our proposed method, GA, is intuitively sound. In GA, the gradient descent direction strictly follows $g_{CE}$, and we discard the part of $g_{KD}$ that is orthogonal to $g_{CE}$, thus eliminating the conflicting signals from $g_{KD}$. What is more, the projected gradient $g^{P}_{KD}$ can help to adjust the pace of gradient descent. When the angle $\gamma$ is smaller than 90°, the projected gradient $g^{P}_{KD}$ is added to $g_{CE}$. In this scenario, the gradients have similar directions; the optimizer is quite sure of the optimization direction, and it should move with a larger step. When the angle $\gamma$ is larger than 90°, $g^{P}_{KD}$ is in the opposite direction of $g_{CE}$. On the one hand, $g^{P}_{KD}$ can be seen as a regularization of $g_{CE}$, preventing it from falling into a local optimum or jumping away from the optimum. On the other hand, $g^{P}_{KD}$ might slow down convergence. Thus, we leave the selection between GA-soft and GA-hard as a hyper-parameter.
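A minimal sketch of GA over flattened gradient vectors. Obtaining $g_{KD}$ and $g_{CE}$ via two separate backward passes, plus the function name and `eps` guard, are our assumptions about an implementation, not the authors' released code.

```python
import torch

def gradient_align(g_kd, g_ce, hard=False, eps=1e-12):
    """Align a flattened KD gradient with a flattened CE gradient (GA-soft/GA-hard)."""
    dot = torch.dot(g_kd, g_ce)          # sign of dot encodes whether gamma > 90 deg
    if hard and dot < 0:                 # GA-hard: angle > 90 deg, drop g_KD entirely
        return g_ce
    # Projection of g_KD onto g_CE: (<g_KD, g_CE> / ||g_CE||^2) * g_CE
    g_p = (dot / (g_ce.norm() ** 2 + eps)) * g_ce
    return g_ce + g_p                    # GA-soft keeps the (possibly opposing) projection

# Usage sketch: flatten per-parameter grads from the two losses into single vectors,
# align them, then scatter the aligned gradient back into param.grad before
# optimizer.step().
```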

Datasets
We evaluate our proposed approach on the classification tasks of the GLUE benchmark. We exclude the STS-B task since it is a regression task, and we exclude the WNLI task following previous work (Devlin et al., 2018).

Backbone models
All of our experiments are built upon the Google BERT (Devlin et al., 2019). We ensure fair comparison by setting the hyper-parameters related to the PLM backbones to be the same as in HuggingFace Transformers (Wolf et al., 2020).

Baseline methods
We compare with previous BERT early exiting methods, as well as with other methods that speed up BERT inference.
Directly reducing layers. We experiment with directly using the first 6 layers of the original (AL)BERT with a single output layer on top, denoted by (AL)BERT-6L. This baseline serves as a lower bound for the performance metrics since it does not employ any additional technique.
Static model compression approaches. For knowledge distillation, we include DistilBERT (Sanh et al., 2019) and BERT-PKD (Sun et al., 2019b). For model parameter pruning, we include the results of LayerDrop (Fan et al., 2020) and attention head pruning (Michel et al., 2019) on ALBERT. For module replacing, we include BERT-of-Theseus.
Early exiting approaches. We compare our method with previous state-of-the-art BERT early exiting approaches, under both budgeted exiting mode and dynamic exiting mode. For dynamic exiting mode, we compare with: (a) the entropy-based method DeeBERT; (b) the score-based method Shallow-Deep; (c) the patience-based method PABEE; and (d) FastBERT adopting PABEE's exiting strategy. For budgeted exiting mode, we compare with: (a) multi-exit BERT fine-tuned with the loss objective given by Eq. 1, which DeeBERT and PABEE adopt; (b) FastBERT.

Experimental settings
We implement our GAML-BERT and other baseline methods based on HuggingFace's Transformers (Wolf et al., 2020). We conduct our experiments on a single Nvidia V100 16GB GPU.
Training. We add a linear output layer after each intermediate layer of the pre-trained BERT/ALBERT model as the internal classifier. Hyper-parameter tuning is done in a cross-validation fashion on the training set, so that the dev sets of the GLUE tasks remain blind for measuring model generalization. We perform grid search over batch sizes {16, 32, 128}, learning rates {1e-5, 2e-5, 3e-5, 5e-5} for the model parameters $\Theta$, warm-up steps of {0.8, 1.0, 1.2} times the number of steps in an epoch, and values of the weight $\alpha$ (from Eq. 2) in {0.1, 0.2, 0.3, 0.5, 0.8}. We adopt the Adam optimizer. We apply an early stopping mechanism with patience 5 and evaluate the model on the validation set (from cross-validation) after each epoch. Moreover, we define the dev performance of our early exiting architecture as the average performance of all the exits, and we select the model checkpoint with the best average performance in cross-validation.
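For reference, the search grid described above written out as a config sketch; the dict layout and key names are illustrative, not the authors' actual tooling.

```python
# Grid-search space from the Training paragraph above.
search_grid = {
    "batch_size": [16, 32, 128],
    "learning_rate": [1e-5, 2e-5, 3e-5, 5e-5],
    "warmup_epochs_factor": [0.8, 1.0, 1.2],  # x steps-per-epoch warm-up steps
    "kd_weight_alpha": [0.1, 0.2, 0.3, 0.5, 0.8],  # alpha in Eq. 2
}
```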
Dynamic exiting mode inference. Following prior work, dynamic exiting mode inference is done on a per-instance basis, i.e., the batch size for inference is set to 1. We believe this setting mimics the common latency-sensitive production scenario of processing individual requests of varying difficulty from different users. We adjust the hyper-parameters of each dynamic exiting method such that the speed-up ratio is between 1.80x and 2.1x.
Budgeted exiting mode inference. In this setting, a multi-exit model is forced to output its prediction at a given exit. The results under this mode are mainly reported in figures depicting the relation between the depth of the exit and the performance scores.

Overall Comparison
Table 2 reports the main results on GLUE with BERT as the backbone model under the dynamic exiting inference mode. From Table 2, we can see that our full model GAML-BERT, especially with GA-soft, outperforms all previous methods at improving inference efficiency while maintaining good performance, demonstrating the effectiveness of the proposed GAML-BERT framework. Note that Table 2 shows that GAML-BERT with GA-soft outperforms GA-hard consistently and by a clear margin on CoLA, MRPC, and RTE. Thus, we will refer to GAML-BERT with GA-soft as our GAML-BERT model.
Although our work mainly focuses on NLP tasks, we also show that one can easily apply our GAML-BERT framework to image classification tasks in the Appendix.

Analysis
We now analyze the main takeaways from Table 2 and our experiments in more depth.
Our GAML-BERT method can improve the performances of early exiting, especially at shallow exits. To demonstrate our method's effectiveness and how it improves the shallow exits, we conduct budgeted exiting inference on each task and plot the relationship between layer depth and performance score in Figure 4. On CoLA, except for the first four exits, which have zero scores, all the exits of GAML-BERT outperform the multi-exit BERT trained with Eq. 1. Similar observations can be made on SST-2. Note that the performance margins at the shallow exits are more significant than those at the deep exits, showing that our model is effective at improving the shallow early exits' performances.
The ML strategies are beneficial. Table 2 reveals that the ML training objective provides the best performances on GLUE among the knowledge distillation strategies. ML-BERT consistently outperforms FastBERT, and GAML-BERT outperforms all the baseline models, including FastBERT trained with our gradient alignment method. These experimental results validate our findings in Section 3 that mutual learning can help to improve the early exits to the greatest extent. The ML method imposes rich regularizations over all exits, thus improving the performances of early exiting.

Figure 4: The depth-score curves for different early exiting strategies on (a) CoLA and (b) SST-2. The x-axis is the depth of the exit (i.e., the number of layers before entering this exit); the y-axis is the performance metric following GLUE (Wang et al., 2018). We also add the performance of BERT-base for comparison.
Our GA algorithm brings performance gains. From Table 2, we can see that our full model, GAML-BERT, consistently outperforms ML-BERT, sometimes by a large margin. In addition, we also combine FastBERT with our GA method, denoted by FastBERT-GA. (Note that in FastBERT-GA, the KD loss terms are added to the CE loss terms, and the fine-tuning is done in a single stage, which is different from FastBERT's two-stage procedure.) We can see that FastBERT-GA also consistently outperforms FastBERT. These ablation results empirically support hypothesis H1 and demonstrate the effectiveness of our GA method.
Our GA algorithm makes the model less sensitive to α. In this group of experiments, we show how changes in $\alpha$ (in Eq. 2) affect the performances of early exiting architectures, with or without our GA method. Table 3 reports the performance comparison on the CoLA task. We can clearly see that GAML-BERT's performance changes are much smaller than those of ML-BERT under different values of $\alpha$. In addition, GAML-BERT outperforms ML-BERT under all values of $\alpha$, showing that GA can effectively stabilize the training with KD and lead to better optimization.
How our GAML-BERT method affects training time costs. Table 4 presents the training time costs for GAML-BERT compared with the original BERT and PABEE. Note that we adopt an early stopping mechanism for training; thus the training time costs are measured by the average number of steps until early stopping. First, although exits introduce extra parameters and extra computation during training, early exiting architectures actually reduce training time. Intuitively, the additional loss objectives can be regarded as additional parameter updating steps for the lower layers, thus speeding up model convergence. Second, ML-BERT converges earlier than PABEE, demonstrating the regularization functionality of KD. Third, GA further accelerates the convergence of ML-BERT. Intuitively, GA eliminates the conflicting factors of KD, thus leading to faster convergence.
How our GAML-BERT method performs under different dynamic exiting strategies. Note that the main experimental results in Table 2 are obtained by adopting PABEE's patience-based dynamic exiting strategy. However, our GAML-BERT model is off-the-shelf in the sense that it can be easily adapted to other exiting strategies. Table 5 reports GAML-BERT's performances under other dynamic exiting strategies, showing that its improvements are not tied to a specific strategy.

Conclusion
In this work, we propose GAML-BERT, a novel framework for improving PLMs' early exiting. Our contributions are two-fold. First, we conduct a series of exploratory experiments, which show that mutual knowledge distillation between a pair of exits can boost the performances of both. Following this observation, it is natural to apply mutual learning (ML), that is, asking all the exits to learn from each other, to enhance the performances of BERT early exiting. Second, we propose GA, a novel training method that aligns the gradients from the knowledge distillation loss with those from the cross-entropy loss. Experiments on the GLUE benchmark datasets show that our framework can improve PLMs' early exiting performances, especially under high latency requirements. In addition, our framework is off-the-shelf and can be adapted to various early exiting strategies.

A The over-thinking problem for BERT

In Figure 5, we demonstrate the over-thinking problem of BERT on SST-2 and CoLA. To obtain the performance of BERT-base's layer i, we insert a classifier at layer i and fine-tune BERT-base on the train set. We can see that the last layer does not obtain the best performances.

B A review of early exiting strategies
There are mainly three types of dynamic exiting strategies for BERT. BranchyNet (Teerapittayanon et al., 2016), FastBERT, and DeeBERT calculate the entropy of the prediction probability distribution as a proxy for the confidence of the exiting classifiers to enable dynamic early exiting. Shallow-Deep Nets (Kaya et al., 2019) and RightTool (Schwartz et al., 2020) leverage the softmax scores of the exiting classifiers' predictions: if the score of a particular class is dominant and large enough, the model exits. Recently, PABEE proposed a patience-based dynamic exiting strategy analogous to early stopping in model training: if the exits' predictions remain unchanged for a predefined number of times (the patience), the model stops inference and exits. PABEE achieves SOTA results for BERT early exiting.
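A sketch of the three families of per-exit stopping criteria reviewed above; the threshold values and function names are illustrative hyper-parameters, not values from the cited papers.

```python
import torch
import torch.nn.functional as F

def should_exit_entropy(logits, threshold=0.3):
    # BranchyNet / FastBERT / DeeBERT: exit if the prediction entropy is low.
    probs = F.softmax(logits, dim=-1)
    entropy = -(probs * torch.log(probs + 1e-12)).sum(-1)
    return bool(entropy < threshold)

def should_exit_score(logits, threshold=0.9):
    # Shallow-Deep / RightTool: exit if one class's softmax score dominates.
    return bool(F.softmax(logits, dim=-1).max() > threshold)

def should_exit_patience(preds_so_far, patience=2):
    # PABEE: exit if the last `patience` exits all made the same prediction.
    return len(preds_so_far) >= patience and len(set(preds_so_far[-patience:])) == 1
```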
In this work, we mainly adopt PABEE's patience-based early exiting strategy. However, in ablation studies, we show that our GAML-BERT framework can also improve the inference performance of other exiting strategies.
C Hyper-parameter settings

C.1 Hyper-parameters for each task in the pilot experiments

Table 6 reports the important hyper-parameters of BERT for each task in the pilot experiments.
C.2 Hyper-parameters for each task in the main experiments

Table 7 reports the important hyper-parameters of GAML-BERT for each task. Note that our hyper-parameter search was done on the training set with cross-validation, so that the GLUE benchmark's dev set information was not revealed during training.

D GAML-BERT is effective for image classification
To demonstrate the effectiveness of GAML-BERT on image classification tasks, we follow the experimental settings in Shallow-Deep (Kaya et al., 2019). We conduct experiments on two image classification datasets, CIFAR-10 and CIFAR-100 (Krizhevsky, 2009). The ResNet-56 model (He et al., 2016) serves as the backbone, and we compare GAML-BERT with PABEE and DBT (Phuong and Lampert, 2019). We place an exiting classifier after every two convolutional layers. We set the batch size to 128 and use an SGD optimizer with a learning rate of 0.1. Table 8 reports the results. GAML-BERT outperforms the original ResNet-56 on both tasks even while providing a 1.3x speed-up. Besides, it outperforms PABEE and DBT.

Figure 5: The over-thinking problem of BERT-base. To obtain the performance of BERT-base's layer i, we insert a classifier at layer i and fine-tune BERT-base on the train set. The metric is MCC for CoLA and ACC for SST-2. For CoLA, the highest score is obtained by layer 11. For SST-2, the highest score is obtained by layer 9, and deeper layers have lower performance scores.