One Teacher is Enough? Pre-trained Language Model Distillation from Multiple Teachers

Pre-trained language models (PLMs) have achieved great success in NLP. However, their huge model sizes hinder their application in many practical systems. Knowledge distillation is a popular technique to compress PLMs, which learns a small student model from a large teacher PLM. However, the knowledge learned from a single teacher may be limited and even biased, resulting in a low-quality student model. In this paper, we propose a multi-teacher knowledge distillation framework named MT-BERT for pre-trained language model compression, which can train a high-quality student model from multiple teacher PLMs. In MT-BERT we design a multi-teacher co-finetuning method to jointly finetune multiple teacher PLMs on downstream tasks with shared pooling and prediction layers, aligning their output spaces for better collaborative teaching. In addition, we propose a multi-teacher hidden loss and a multi-teacher distillation loss to transfer the useful knowledge in both hidden states and soft labels from multiple teacher PLMs to the student model. Experiments on three benchmark datasets validate the effectiveness of MT-BERT in compressing PLMs.


Introduction
Pre-trained language models (PLMs) such as BERT and RoBERTa have achieved notable success in various NLP tasks (Devlin et al., 2019; Liu et al., 2019). However, many PLMs have a huge model size and high computational complexity, making it difficult to deploy them to low-latency and high-concurrency online systems or to devices with limited computational resources (Wu et al., 2021).
Knowledge distillation is a widely used technique for compressing large-scale pre-trained language models (Sun et al., 2019). For example, Sanh et al. (2019) proposed DistilBERT, which compresses BERT by transferring knowledge from the soft labels predicted by the teacher model to the student model with a distillation loss. Jiao et al. (2020) proposed TinyBERT, which aligns the hidden states and the attention heatmaps between the student and teacher models. These methods usually learn the student model from a single teacher model (Gou et al., 2020). However, the knowledge and supervision provided by a single teacher model may be insufficient to learn an accurate student model, and the student model may also inherit the bias of the teacher model (Bhardwaj et al., 2020). Fortunately, many different large PLMs such as BERT (Devlin et al., 2019), RoBERTa (Liu et al., 2019) and UniLM (Dong et al., 2019) are available off-the-shelf. These PLMs may encode complementary knowledge because they usually have different configurations and are trained on different corpora with different self-supervision tasks (Qiu et al., 2020). Thus, incorporating multiple pre-trained language models into knowledge distillation has the potential to learn better student models.
In this paper, we present a multi-teacher knowledge distillation method named MT-BERT for pre-trained language model compression. In MT-BERT, we propose a multi-teacher co-finetuning framework to jointly finetune multiple teacher models with a shared pooling and prediction module, aligning their output hidden states for better collaborative student teaching. In addition, we propose a multi-teacher hidden loss and a multi-teacher distillation loss to transfer the useful knowledge in both hidden states and soft labels from multiple teacher models to the student model. Experiments on three benchmark datasets show that MT-BERT can effectively improve the quality of student models for PLM compression and outperform many single-teacher knowledge distillation methods.

MT-BERT
Next, we introduce the details of our multi-teacher knowledge distillation method MT-BERT for pre-trained language model compression (code is available at https://github.com/wuch15/MT-BERT). We first introduce the multi-teacher co-finetuning framework that jointly finetunes multiple teacher models on downstream tasks, and then introduce the multi-teacher distillation framework that collaboratively teaches the student with multiple teachers.

Multi-Teacher Co-Finetuning
Researchers have found that distilling the knowledge in the hidden states of a teacher model is important for effective student teaching (Sun et al., 2019). However, since different teacher PLMs are separately pre-trained with different settings, finetuning them independently may lead to inconsistency in their feature spaces, which is not optimal for transferring knowledge from the hidden states of multiple teachers. Thus, we design a multi-teacher co-finetuning framework to obtain some uniformity among the hidden states output by the last layer of different teacher models for better collaborative student teaching, as shown in Fig. 1. Assume there are $N$ teacher models, and denote the hidden states output by the top layer of the $i$-th teacher as $\mathbf{H}_i$. We use a shared pooling layer (in MT-BERT we use attentive pooling because it performs better than average pooling and the "[CLS]" token embedding) to summarize each hidden matrix $\mathbf{H}_i$ into a unified text embedding, and then use a shared dense layer to convert it into a soft probability vector $\hat{\mathbf{y}}_i$. Finally, we jointly optimize the summation of the task-specific losses of all teacher models, i.e., $\sum_{i=1}^{N}\mathrm{CE}(\mathbf{y}, \hat{\mathbf{y}}_i)$, where $\mathrm{CE}(\cdot,\cdot)$ stands for the cross-entropy loss and $\mathbf{y}$ is the ground-truth label. Since the pooling and prediction layers are shared among different teachers, the feature spaces of the output hidden states from different teacher PLMs can be aligned, which helps them collaborate better for student teaching.
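To make the co-finetuning step concrete, below is a minimal PyTorch sketch under the assumption that all teachers are HuggingFace-style encoders returning `last_hidden_state` with the same hidden dimension; the class and variable names (AttentivePooling, CoFinetuneTeachers, hidden_dim) are illustrative and not taken from the released code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class AttentivePooling(nn.Module):
    """Summarizes a sequence of hidden states into one vector via learned attention."""
    def __init__(self, hidden_dim: int, query_dim: int = 200):
        super().__init__()
        self.proj = nn.Linear(hidden_dim, query_dim)
        self.query = nn.Linear(query_dim, 1, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # hidden_states: [batch, seq_len, hidden_dim]
        scores = self.query(torch.tanh(self.proj(hidden_states)))   # [batch, seq_len, 1]
        weights = torch.softmax(scores, dim=1)
        return (weights * hidden_states).sum(dim=1)                  # [batch, hidden_dim]


class CoFinetuneTeachers(nn.Module):
    """Jointly finetunes N teachers with a shared pooling and shared prediction layer."""
    def __init__(self, teachers: nn.ModuleList, hidden_dim: int, num_classes: int):
        super().__init__()
        self.teachers = teachers                              # e.g., BERT, RoBERTa, UniLM encoders
        self.pooling = AttentivePooling(hidden_dim)           # shared among all teachers
        self.classifier = nn.Linear(hidden_dim, num_classes)  # shared prediction layer

    def forward(self, input_ids, attention_mask, labels):
        loss, logits_per_teacher = 0.0, []
        for teacher in self.teachers:
            # H_i: last-layer hidden states of the i-th teacher
            hidden = teacher(input_ids, attention_mask=attention_mask).last_hidden_state
            logits = self.classifier(self.pooling(hidden))
            loss = loss + F.cross_entropy(logits, labels)     # sum of per-teacher CE losses
            logits_per_teacher.append(logits)
        return loss, logits_per_teacher
```

Because the pooling and classification parameters are shared, gradients from every teacher's task loss flow through the same output head, which is what pulls the teachers' last-layer representations toward a common space.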

Multi-Teacher Knowledge Distillation
Next, we introduce our proposed multi-teacher knowledge distillation framework, which is shown in Fig. 2. Two loss functions are used for knowledge distillation, i.e., a multi-teacher hidden loss and a multi-teacher distillation loss.
The multi-teacher hidden loss aims to transfer knowledge in the hidden states of multiple teachers.
Assume there are $N$ teacher PLMs, each with $T \times K$ Transformer layers. They collaboratively teach a student model with $K$ layers, and each layer in the student model corresponds to $T$ layers in the teacher PLMs (here we assume that all teacher models have the same number of layers; we will explore generalizing MT-BERT to scenarios where teacher models have different architectures in future work). Denote the hidden states output by the $j$-th layer of the student model as $\mathbf{H}^s_j$, and the corresponding hidden states output by the $(T \times j)$-th layer of the $i$-th teacher model as $\mathbf{H}^i_{Tj}$. Following (Sun et al., 2019), we apply the mean squared error (MSE) to the hidden states of corresponding layers in the student and teacher models to encourage the student model to have similar functions to the teacher models. The multi-teacher hidden loss $\mathcal{L}_{MT\text{-}Hid}$ is formulated as follows:
$$\mathcal{L}_{MT\text{-}Hid} = \sum_{i=1}^{N}\sum_{j=1}^{K} \mathrm{MSE}\big(\mathbf{H}^s_j \mathbf{W}_{ij}, \mathbf{H}^i_{Tj}\big),$$
where $\mathbf{W}_{ij}$ is a learnable transformation matrix.
The multi-teacher distillation loss aims to transfer the knowledge in the soft labels output by multiple teachers to the student. The predictions of different teachers on the same sample may differ in correctness and confidence. Thus, it may be suboptimal to simply ensemble (Fukuda et al., 2017) or choose (Yuan et al., 2020) soft labels without the help of task labels. Since the labels of training samples are available in task-specific knowledge distillation, we propose a distillation loss weighting method that assigns different weights to different samples, where the weights are based on the loss of the corresponding teacher's prediction against the gold labels. More specifically, the multi-teacher distillation loss $\mathcal{L}_{MT\text{-}Dis}$ is formulated as follows:
$$\mathcal{L}_{MT\text{-}Dis} = \sum_{i=1}^{N} w_i\, \mathrm{CE}\big(\hat{\mathbf{y}}_i^{(t)}, \hat{\mathbf{y}}_s^{(t)}\big),$$
where $\hat{\mathbf{y}}_i^{(t)}$ and $\hat{\mathbf{y}}_s^{(t)}$ are the predictions of the $i$-th teacher and the student softened by the temperature coefficient $t$, and $w_i$ is the per-sample weight of the $i$-th teacher computed from its loss against the gold label. In this way, if a teacher's prediction on a certain sample is closer to the ground-truth label, its corresponding distillation loss gains a higher weight.
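The sketch below illustrates the two losses in PyTorch. The paper excerpt describes the sample-wise weighting only qualitatively, so the softmax over negative teacher losses and the KL-divergence form of the softened distillation term used here are assumptions, one plausible instantiation rather than the exact released implementation; all function and argument names are illustrative.

```python
import torch
import torch.nn.functional as F


def multi_teacher_hidden_loss(student_hiddens, teacher_hiddens, projections):
    """student_hiddens: K tensors of shape [batch, seq_len, d_student].
    teacher_hiddens: N lists of K tensors [batch, seq_len, d_teacher] (every T-th teacher layer).
    projections: flat list of N*K nn.Linear maps (the W_ij) from d_student to d_teacher."""
    k = len(student_hiddens)
    loss = 0.0
    for i, teacher_layers in enumerate(teacher_hiddens):
        for j, (h_s, h_t) in enumerate(zip(student_hiddens, teacher_layers)):
            w_ij = projections[i * k + j]
            loss = loss + F.mse_loss(w_ij(h_s), h_t)
    return loss


def multi_teacher_distillation_loss(student_logits, teacher_logits, labels, t=1.0):
    """Weights each teacher's softened prediction per sample by how well it fits the gold label."""
    # Per-sample cross-entropy of every teacher against the gold labels: [N, batch].
    teacher_ce = torch.stack(
        [F.cross_entropy(logits, labels, reduction="none") for logits in teacher_logits])
    # Assumed weighting: softmax over negative losses, so a smaller loss gives a larger weight.
    weights = torch.softmax(-teacher_ce, dim=0)
    loss = 0.0
    for i, logits in enumerate(teacher_logits):
        # KL divergence between temperature-softened student and teacher distributions.
        kl = F.kl_div(F.log_softmax(student_logits / t, dim=-1),
                      F.softmax(logits / t, dim=-1),
                      reduction="none").sum(dim=-1)            # [batch]
        loss = loss + (weights[i] * kl).mean()
    return loss
```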
Following (Tang et al., 2019; Lu et al., 2020), we also incorporate the gold labels to compute a task-specific loss $\mathcal{L}_{Task}$ based on the predictions of the student model, i.e., $\mathcal{L}_{Task} = \mathrm{CE}(\mathbf{y}, \hat{\mathbf{y}}_s)$. The final loss function $\mathcal{L}$ for learning the student model is the summation of the multi-teacher hidden loss, the multi-teacher distillation loss and the task-specific loss, which is formulated as follows:
$$\mathcal{L} = \mathcal{L}_{MT\text{-}Hid} + \mathcal{L}_{MT\text{-}Dis} + \mathcal{L}_{Task}.$$
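Putting the three terms together, a minimal sketch of the student objective looks as follows; it reuses the illustrative loss functions from the previous snippet, and the argument names are hypothetical.

```python
import torch.nn.functional as F


def student_training_loss(student_logits, student_hiddens,
                          teacher_logits, teacher_hiddens,
                          labels, projections, t=1.0):
    """L = L_MT-Hid + L_MT-Dis + L_Task, using the sketch functions defined above."""
    l_hid = multi_teacher_hidden_loss(student_hiddens, teacher_hiddens, projections)
    l_dis = multi_teacher_distillation_loss(student_logits, teacher_logits, labels, t=t)
    l_task = F.cross_entropy(student_logits, labels)  # direct supervision from gold labels
    return l_hid + l_dis + l_task
```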

Datasets and Experimental Settings
We conduct experiments on three benchmark datasets with different sizes. The first one is SST-2 (Socher et al., 2013), which is a benchmark for text sentiment classification. The second one is RTE (Bentivogli et al., 2009), which is a widely used dataset for natural language inference. The third one is the MIND dataset (Wu et al., 2020c), which is a large-scale public English news dataset. We perform the news topic classification task on this dataset. The detailed statistics of the three datasets are shown in Table 1.
In our experiments, we use the pre-trained 12-layer BERT, RoBERTa and UniLM models as teachers, and learn 6-layer and 4-layer student models. We use the token embeddings and the first 4 or 6 Transformer layers of UniLM to initialize the parameters of the student model. The pooling layer is implemented by an attention network (Yang et al., 2016; Wu et al., 2020a). The temperature coefficient t is set to 1. The attention query dimension in the attentive pooling layer is 200. The optimizer we use is Adam (Kingma and Ba, 2015). The teacher model learning rate is 2e-6 and the student model learning rate is 5e-6. The batch size is 64. Following previous work, we report the accuracy score on the SST-2 and RTE datasets. In addition, since the news topics in the MIND dataset are highly imbalanced, following (Wu et al., 2020b) we report both accuracy and macro-F1 scores. Each experiment is independently repeated 5 times, and the average scores are reported.
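For reference, the hyperparameters stated above can be collected into a single configuration; this is only a restatement of the values in the text, and the dictionary keys are illustrative.

```python
# Experimental settings reported in this section (key names are illustrative).
hyperparams = {
    "teacher_layers": 12,            # BERT, RoBERTa and UniLM teachers
    "student_layers": [4, 6],        # student model sizes evaluated
    "temperature_t": 1.0,
    "attention_query_dim": 200,      # attentive pooling query dimension
    "optimizer": "Adam",
    "teacher_learning_rate": 2e-6,
    "student_learning_rate": 5e-6,
    "batch_size": 64,
    "repetitions": 5,                # each experiment repeated 5 times, averages reported
}
```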

Performance Evaluation
We compare the performance of MT-BERT with two groups of baselines. The first group includes the 12-layer versions of the teacher models, i.e., BERT (Devlin et al., 2019), RoBERTa (Liu et al., 2019) and UniLM (Dong et al., 2019). The second group includes 6-layer and 4-layer student models learned by single-teacher knowledge distillation methods. The results are summarized in Table 2 (for the baseline methods, we take their originally reported results on the SST-2 and RTE datasets and run their codes to obtain results on the MIND dataset). Referring to this table, we find that MT-BERT consistently outperforms all the single-teacher knowledge distillation methods compared here. This is because the knowledge provided by a single teacher model may be insufficient, and incorporating the complementary knowledge encoded in multiple teacher models helps learn a better student model.
In addition, compared with the teacher models, MT-BERT has far fewer parameters while its performance is comparable or even better. This shows that MT-BERT can effectively inherit the knowledge of multiple teacher models even though the model size is significantly compressed.
We also compare MT-BERT with several multi-teacher knowledge distillation methods proposed in the computer vision field that ensemble the outputs of different teachers for student teaching (You et al., 2017). The results are shown in Fig. 3. We find that our MT-BERT performs better than these ensemble-based multi-teacher knowledge distillation methods. This is because these methods do not consider the correctness of the teacher models' predictions on a specific sample and cannot transfer the useful knowledge encoded in the intermediate layers, which may not be optimal for collaborative knowledge distillation from multiple teachers.

Effectiveness of Multiple Teachers
Next, we study the effectiveness of using multiple teacher PLMs for knowledge distillation. We compare the performance of the 6-layer student model distilled from different combinations of teacher models. The results are summarized in Table 3. They show that using multiple teacher PLMs achieves better performance than using a single one. This is because different teacher models encode complementary knowledge, and combining them provides better supervision for the student model. In addition, combining all three teacher PLMs further improves the performance of the student model, which validates the effectiveness of MT-BERT in distilling knowledge from multiple teacher models.

Ablation Study
We study the effectiveness of the two important techniques in MT-BERT, i.e., the multi-teacher co-finetuning framework and the distillation loss weighting method. We compare MT-BERT and its variants with one of these modules removed, as shown in Fig. 4. The student model has 6 layers. We find that the multi-teacher co-finetuning framework is very important. This is because the hidden states of different teacher models can lie in very different spaces, and jointly finetuning multiple teachers with shared pooling and prediction layers aligns their output hidden spaces for better collaborative student teaching. In addition, the distillation loss weighting method is also useful. This is because the predictions of different teachers on the same sample may differ in correctness, and focusing on the more reliable predictions helps distill accurate student models.
We also verify the effectiveness of the different loss functions in MT-BERT, as shown in Fig. 5. We find that the task loss is very important. This is because in our experiments the corpora for task-specific distillation are not large, and the direct supervision from task labels is useful. In addition, the distillation loss is also important, which indicates that transferring the knowledge in soft labels plays a critical role in knowledge distillation. Moreover, the hidden loss is also helpful, which shows that the hidden states of different teacher models can provide useful knowledge for student model learning.

Conclusion
In this paper, we propose a multi-teacher knowledge distillation method named MT-BERT for pre-trained language model compression, which can learn a small but strong student model from multiple teacher PLMs in a collaborative way. We propose a multi-teacher co-finetuning framework to align the output hidden states of multiple teacher models for better collaborative student teaching. In addition, we design a multi-teacher hidden loss and a multi-teacher distillation loss to transfer the useful knowledge in both the hidden states and the predictions of multiple teacher models to the student model. Extensive experiments on three benchmark datasets show that MT-BERT can effectively improve the performance of pre-trained language model compression and outperform many single-teacher knowledge distillation methods.