RW-KD: Sample-wise Loss Terms Re-Weighting for Knowledge Distillation

Knowledge Distillation (KD) is extensively used in Natural Language Processing to compress the pre-training and task-specific fine-tuning phases of large neural language models. A student model is trained to minimize a convex combination of the prediction loss over the labels and another over the teacher output. However, most existing works either fix the interpolating weight between the two losses a priori or vary the weight using heuristics. In this work, we propose a novel sample-wise loss weighting method, RW-KD. A meta-learner, trained simultaneously with the student, adaptively re-weights the two losses for each sample. We demonstrate, on 7 datasets of the GLUE benchmark, that RW-KD outperforms other loss re-weighting methods for KD.


Introduction
Knowledge Distillation (Ba and Caruana, 2014; Hinton et al., 2015) has proven highly effective for compressing a large-scale NLP model (Devlin et al., 2019; Radford et al., 2019), called the teacher in KD terms, into a smaller one, the student. A key factor behind KD's success is the use of teacher outputs as soft labels for supervising the training of the student (Müller et al., 2019; Yuan et al., 2020). The latter model is trained by jointly minimizing the losses on both hard and soft labels. The contribution of each loss term is conventionally controlled by a balancing hyperparameter.
However, recent studies suggest that the importance of hard and soft labels is sample-dependent (Tang et al., 2020; Zhou et al., 2021), and that only a subset of training samples is crucial for distillation (Li et al., 2018). For instance, teacher outputs may be of poor quality for some samples (Ghaddar et al., 2021a,b), yet highly informative for others (Cho and Kang, 2020). Researchers have also found that adjusting loss weights during training greatly benefits KD performance (Clark et al., 2019; Mukherjee and Awadallah, 2020; Jafari et al., 2021). However, in these works the contribution of the loss terms is heuristically decayed by an annealing factor, yet another hyperparameter.
We argue that using the same weights for all training samples, referred to in our work as single-weight, prevents exploiting the full advantage of KD, because each data sample may have different optimal weights for the loss terms. We propose a meta-learning approach to learn sample-wise weights for the loss terms. We revisit learning-to-weight approaches (Ren et al., 2018; Shu et al., 2019), initially proposed for down-weighting noisy samples, and adapt them to loss-term weighting in KD.
Experimental results show that our KD loss weighting scheme consistently outperforms its counterparts on 7 tasks from the GLUE benchmark (Wang et al., 2019). A fine-grained analysis of the learned weights shows that, compared to the baselines, our meta-learner explores a greater range of KD weights to find the sample-wise optimal values.

Related Work

Learning-to-weight approaches (Ren et al., 2018; Zhang et al., 2020) were mainly proposed to learn per-sample loss weights in order to discount noisy samples, thanks to an auxiliary meta-learner that re-weights the training samples of the main model. Such approaches often train the meta-learner on a clean validation set, or on small-loss training samples if no clean data is available. The meta-learner architecture varies from a simple multi-layer perceptron (MLP), as in Meta-Weight-Net (Shu et al., 2019), to an LSTM-based encoder, as in MentorNet (Jiang et al., 2018).
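To make the architecture concrete, below is an illustrative PyTorch sketch of a Meta-Weight-Net style MLP meta-learner, which maps a per-sample loss value to a weight in (0, 1); the hidden size and activations are assumptions for illustration, not the exact configuration of Shu et al. (2019).

```python
# Illustrative sketch of a Meta-Weight-Net style MLP meta-learner: it maps a
# per-sample loss value to a weight in (0, 1). Layer sizes are assumptions.
import torch.nn as nn

class WeightNet(nn.Module):
    def __init__(self, hidden=100):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(1, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 1),
            nn.Sigmoid(),
        )

    def forward(self, loss_value):
        # loss_value: tensor of shape (batch, 1), one loss per training sample
        return self.net(loss_value)
```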
The work of Jin et al. (2021) on multi-modal model compression with KD is the most similar to ours. The authors train an MLP meta-learner (Shu et al., 2019) on the validation set, which assigns sample-level weights to 3 loss terms, computed when text, image, and both modalities are given as input. In our work, we use a transformer-based meta-learner to estimate the sample-wise optimal weights for KD via gradient similarity (see Section 3.2).

Methodology
Let $T(\cdot)$ be a fine-tuned, fixed teacher, and $S_\theta(\cdot)$ the student model parameterized by $\theta$. Given a training set of $\{x_i, y_i\}|_{i=1}^{N}$ samples, where $x_i$ is a data sample and $y_i$ the respective label, vanilla KD (Hinton et al., 2015) consists of minimizing a weighted combination of two losses:

$$L(\theta) = \frac{1}{N}\sum_{i=1}^{N} (1 - \alpha)\, L_{CE}(y_i, S_\theta(x_i)) + \alpha\, L_{KD}(T(x_i), S_\theta(x_i)) \quad (1)$$

where $L_{CE}$ is a cross-entropy (CE) loss on the hard labels, and $L_{KD}$ is the Kullback-Leibler divergence (Kullback, 1997) between the teacher and student logits. $\alpha \in [0, 1]$ is a hyperparameter controlling the contribution of the two losses. For simplicity, we refer to $L_{CE}(y_i, S_\theta(x_i))$ as $L_{CE}(x_i)$ and to $L_{KD}(T(x_i), S_\theta(x_i))$ as $L_{KD}(x_i)$ hereafter.
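As a concrete reference, here is a minimal PyTorch sketch of this objective for one mini-batch; the temperature parameter tau is a common addition in KD implementations and is an assumption here, not part of the formula above.

```python
import torch
import torch.nn.functional as F

def vanilla_kd_loss(student_logits, teacher_logits, labels, alpha=0.5, tau=1.0):
    # Hard-label term: cross-entropy against the gold labels.
    l_ce = F.cross_entropy(student_logits, labels)
    # Soft-label term: KL divergence between teacher and student distributions.
    l_kd = F.kl_div(
        F.log_softmax(student_logits / tau, dim=-1),
        F.softmax(teacher_logits / tau, dim=-1),
        reduction="batchmean",
    ) * tau ** 2
    # alpha in [0, 1] controls the contribution of the two losses,
    # identically for every sample (the "single-weight" setting).
    return (1 - alpha) * l_ce + alpha * l_kd
```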
Reweighting KD. We propose a sample-wise re-weighting method for KD that learns a balance between the CE and KD losses for every training sample. The new training loss is computed as follows:

$$L(\theta) = \frac{1}{N}\sum_{i=1}^{N} \lambda_i^{CE}\, L_{CE}(x_i) + \lambda_i^{KD}\, L_{KD}(x_i) \quad (2)$$

where $\lambda_i^{CE}$ and $\lambda_i^{KD}$ are per-sample weights on the two loss terms.

Figure 1: Meta-reweight module. In Stage 1, the parameter $\hat{\theta}$ of the meta-student is updated to be a function of the perturbation $\epsilon$. In Stage 2, the optimal weights $\{\lambda^{CE}, \lambda^{KD}\}$ are estimated with the negative gradient of $L_{meta}$ w.r.t. $\epsilon$.
Finding the optimal weights for each loss term is intractable. Our solution is inspired by Koh and Liang (2017) and Ren et al. (2018), who investigate which training samples are most responsible for generalization performance. We follow this line of work and perturb the different losses during KD training to identify which loss is more influential and informative.

Meta-reweight KD
We define our problem as a meta-learning one and use the validation set to define a meta-learning loss function. Our meta-reweight module is depicted in Figure 1.
Meta-objective. The optimal selection of $\lambda = \{\lambda_i^{CE}, \lambda_i^{KD}\}|_{i=1}^{N}$ is derived from its performance on a meta dataset of M samples:

$$\lambda^{*} = \arg\min_{\lambda} L_{meta}\big(\theta^{*}(\lambda)\big), \quad \theta^{*}(\lambda) = \arg\min_{\theta} \frac{1}{N}\sum_{i=1}^{N} \lambda_i^{CE} L_{CE}(x_i) + \lambda_i^{KD} L_{KD}(x_i) \quad (3)$$

where $L_{meta}$ is the loss computed on samples from the meta dataset. Since computing the optimal $\lambda^{*}$ and $\theta^{*}$ requires two nested optimization loops, we adopt an online strategy to estimate $\lambda$ and update $\theta$ at each iteration.
Meta-reweighting. To derive the optimal weights on the two losses of each sample before updating the student, we use a meta model that computes the weights by taking a gradient step on the meta loss. First, at the beginning of every iteration, we initialize a meta-student $\hat{S}_\theta(\cdot)$ with the same parameters as the student model $S_\theta(\cdot)$.
Next, we feed a mini-batch of n training samples $X^{(n)} = \{x_i, y_i\}|_{i=1}^{n}$ to the meta-student, compute the CE and KD losses, then perturb their weights by $\epsilon_i^{CE}$ and $\epsilon_i^{KD}$ respectively for each example and calculate the weighted loss:

$$L(X^{(n)}; \theta_t, \epsilon) = \frac{1}{n}\sum_{i=1}^{n} \epsilon_i^{CE} L_{CE}(x_i) + \epsilon_i^{KD} L_{KD}(x_i) \quad (4)$$

where $\epsilon = \{\epsilon_i^{CE}, \epsilon_i^{KD}\}|_{i=1}^{n}$ is the collection of all perturbations. We then take a gradient step update on the current parameters $\theta_t$:

$$\hat{\theta}_t = \theta_t - \alpha \nabla_{\theta_t} L(X^{(n)}; \theta_t, \epsilon) \quad (5)$$

where $\alpha$ is the step size of the gradient descent. Next, we feed a mini-batch of meta examples $X^{(m)} = \{x_j, y_j\}|_{j=1}^{m}$ to the meta-student $\hat{S}_{\hat{\theta}_t}(\cdot)$ and compute the meta loss $L_{meta}(X^{(m)}; \hat{\theta}_t)$ as:

$$L_{meta}(X^{(m)}; \hat{\theta}_t) = \frac{1}{m}\sum_{j=1}^{m} L_{CE}\big(y_j, \hat{S}_{\hat{\theta}_t}(x_j)\big) \quad (6)$$

Since the parameter $\hat{\theta}_t$ of the meta-student is a function of $\epsilon$ (as $\nabla_{\theta_t} L$ is a function of $\epsilon$), we can directly compute the gradient of the meta loss w.r.t. $\epsilon$ via the chain rule, which in practice is handled by the automatic differentiation of deep learning frameworks such as PyTorch (Paszke et al., 2019). We take the negative gradients as the estimation of the weights:

$$u_i^{CE} = -\beta \frac{\partial L_{meta}(X^{(m)}; \hat{\theta}_t)}{\partial \epsilon_i^{CE}}, \quad u_i^{KD} = -\beta \frac{\partial L_{meta}(X^{(m)}; \hat{\theta}_t)}{\partial \epsilon_i^{KD}} \quad (7)$$

where $\beta$ is a scaling factor. We then normalize the weights $\{u_i^{CE}, u_i^{KD}\}$ for each training sample $x_i$ to make them positive and ensure they sum to 1, leading to:

$$\lambda_i^{CE} = \frac{\max(u_i^{CE}, 0)}{\max(u_i^{CE}, 0) + \max(u_i^{KD}, 0) + \delta}, \quad \lambda_i^{KD} = \frac{\max(u_i^{KD}, 0)}{\max(u_i^{CE}, 0) + \max(u_i^{KD}, 0) + \delta} \quad (8)$$

where $\delta$ = 1e-8 is a hyperparameter that helps training stability. Finally, we compute the final loss with the locally optimal weights of the two losses for each sample in the training mini-batch and update our student model $S_\theta(\cdot)$.

Algorithm 1: Knowledge Distillation with Meta-reweighting (input: $D_{train}$, $D_{meta}$, $S_\theta(\cdot)$, $T(\cdot)$).
The weights are estimated by computing the gradients of the meta loss w.r.t. the perturbations on the different losses; these gradients indicate the sensitivity of the meta loss when we perturb each loss used for training. By using these gradients as the weights of the different losses, we can adjust the impact of each loss towards better performance on the predefined meta dataset. The detailed pseudo-code is presented in Algorithm 1.
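For concreteness, the following is a minimal, self-contained PyTorch sketch of one meta-reweighting iteration (Stages 1 and 2 of Figure 1). A toy linear student is used so the meta-student update can be written functionally, keeping the perturbations inside the autograd graph; all names, shapes, and the toy model are illustrative assumptions, not the paper's implementation.

```python
import torch
import torch.nn.functional as F

def per_sample_losses(w, x, y, teacher_probs):
    logits = x @ w                                   # toy linear student
    l_ce = F.cross_entropy(logits, y, reduction="none")
    l_kd = F.kl_div(F.log_softmax(logits, dim=-1), teacher_probs,
                    reduction="none").sum(-1)        # per-sample KL
    return l_ce, l_kd

def meta_reweight_step(w, x_trn, y_trn, t_probs, x_meta, y_meta,
                       lr=0.1, beta=1.0, delta=1e-8):
    n = x_trn.size(0)
    # Stage 1: perturb the weight of each loss term and take one GD step.
    eps = torch.zeros(n, 2, requires_grad=True)      # [:, 0]=CE, [:, 1]=KD
    l_ce, l_kd = per_sample_losses(w, x_trn, y_trn, t_probs)
    perturbed = (eps[:, 0] * l_ce + eps[:, 1] * l_kd).mean()
    grad_w, = torch.autograd.grad(perturbed, w, create_graph=True)
    w_hat = w - lr * grad_w                          # meta-student: a function of eps
    # Stage 2: meta loss on held-out data; its negative eps-gradient gives weights.
    l_meta = F.cross_entropy(x_meta @ w_hat, y_meta)
    grad_eps, = torch.autograd.grad(l_meta, eps)
    u = torch.clamp(-beta * grad_eps, min=0.0)       # keep raw weights positive
    lam = u / (u.sum(dim=1, keepdim=True) + delta)   # normalize: sum to 1 per sample
    # Final weighted loss used to update the actual student parameters w.
    return (lam[:, 0] * l_ce + lam[:, 1] * l_kd).mean(), lam
```

In a full training loop, one would call meta_reweight_step on each mini-batch (e.g., with a hypothetical `w = torch.randn(16, 3, requires_grad=True)`) and then backpropagate the returned loss through the real student.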

Weight Estimation via Gradient Similarity
Next, we show the relation between the weight estimation and gradient similarity. To save space, we omit $u_i^{KD}$. The weight on the CE loss of the i-th example is the similarity between the gradient of the i-th example on the CE loss and the average gradient of a mini-batch of meta data computed for the meta loss at time step t. The computation of Eq. 7 by backpropagation can be rewritten as (full derivation in Appendix A):

$$u_i^{CE} = \frac{\alpha \beta}{n}\, J_1^{\top} J_2$$

where $J_1$ is the Jacobian vector of $L_{meta}$ w.r.t. $\hat{\theta}$, which indicates the direction of decrease of the loss on a mini-batch of meta data, and $J_2$ is the Jacobian vector of $L_{CE}$ of the i-th sample w.r.t. $\theta$, which indicates the direction of decrease of the CE loss of the i-th sample. Larger weights mean that moving along the $J_2$ direction is likely not only to reduce the training loss, but also to reduce the meta loss.
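As a sanity check of this identity, the following toy snippet (an illustrative construction, not from the paper) verifies with PyTorch autograd that, after one functional GD step, the negative ε-gradient of a meta loss equals the step size times the inner product of the meta-loss gradient and the per-sample gradient.

```python
# Toy check: with w_hat = w - a * eps * g_i, the negative eps-gradient of
# l_meta equals a * (grad of l_meta w.r.t. w_hat) . g_i, evaluated at eps = 0.
import torch

torch.manual_seed(0)
w = torch.randn(3)
g_i = torch.randn(3)                 # stands in for the CE gradient of sample i
a = 0.1                              # GD step size

eps = torch.zeros(1, requires_grad=True)
w_hat = w - a * eps[0] * g_i         # meta-student parameters as a function of eps
l_meta = (w_hat ** 2).sum()          # arbitrary smooth meta loss
grad_eps, = torch.autograd.grad(l_meta, eps)

j1 = 2 * w                           # gradient of l_meta w.r.t. w_hat at eps = 0
print(torch.allclose(-grad_eps, a * (j1 @ g_i).reshape(1)))  # True
```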

Dataset and Evaluation
We run experiments on 7 tasks from the GLUE benchmark (Wang et al., 2019): 2 single-sentence (CoLA and SST-2) and 5 sentence-pair (MRPC, RTE, QQP, QNLI, and MNLI) classification tasks. Following prior work, we report Matthews correlation on CoLA, F1 score on MRPC and QQP, and accuracy for the other tasks, all on the corresponding test sets.
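For reference, a minimal sketch of this per-task evaluation protocol using scikit-learn; the task keys and the label/prediction arrays are placeholders.

```python
from sklearn.metrics import matthews_corrcoef, f1_score, accuracy_score

def glue_metric(task, labels, preds):
    if task == "cola":
        return matthews_corrcoef(labels, preds)   # Matthews correlation
    if task in ("mrpc", "qqp"):
        return f1_score(labels, preds)            # F1 on the positive class
    return accuracy_score(labels, preds)          # SST-2, RTE, QNLI, MNLI
```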

Baselines
We compare RW-KD to 4 loss re-weighting methods:

• w/o KD In this setting, the KL loss weight (α) is always set to zero, i.e., the student is trained on hard labels only.
• Vanilla-KD Here, we select the best performing fixed α value for each task.

• ANL-KD The single weight α is annealed (gradually decayed) over the course of training (Clark et al., 2019; Jafari et al., 2021); we tune the decay schedule per task.

• WLS-KD Finally, we consider the recent dynamic re-weighting method of Zhou et al. (2021), where α is calculated per sample as a function of $L_{ce}^{s}$ and $L_{ce}^{t}$, the loss values on the hard label for the student and the teacher respectively.

Implementation
All models use a 12-layer BERT-base-uncased model (Devlin et al., 2019) as the teacher, and the pre-trained 6-layer DistilBERT (Sanh et al., 2019) as initialization for the student. We perform hyperparameter tuning and select the best performing models using early stopping on the dev sets.
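A minimal setup sketch consistent with this description, using the HuggingFace transformers library; the checkpoint names follow the cited models, while num_labels and other details are assumptions that vary per task.

```python
from transformers import AutoModelForSequenceClassification, AutoTokenizer

teacher = AutoModelForSequenceClassification.from_pretrained(
    "bert-base-uncased", num_labels=2)        # fine-tuned on the task, then frozen
student = AutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased", num_labels=2)  # 6-layer DistilBERT initialization
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

teacher.eval()                                # the teacher stays fixed during KD
for p in teacher.parameters():
    p.requires_grad = False
```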

Results
Table 1 shows the performance of the teacher, the baselines, and our method on the GLUE test sets. First, we notice that ANL-KD fails to perform as expected (only a 0.2% gain over Vanilla-KD), although we extensively tested different α decay schedules. It is worth mentioning that this approach was successful in multi-task KD when the teacher and the student are of the same size. Second, we observe that RW-KD outperforms the single-weight schemes (Vanilla-KD and ANL-KD) and the sample-wise WLS-KD method by 1.3%, 1.1%, and 0.6% respectively, averaged over all tasks.

To better understand why RW-KD performs better, we plot the weights learned by the meta-learner; Figure 2 shows their distributions. On one hand, we observe that the majority of WLS weights are concentrated below 0.3, and that the best α values for Vanilla-KD were around 0.5. On the other hand, our meta-learner mostly produces weights with either very high or very low values, and less frequently weights around 0.5 (e.g., on CoLA and RTE). Interestingly, this suggests that for many samples, only one of the hard- or soft-label losses is informative for the student. Consequently, a sample-wise loss weighting method appears to be a key component of KD.

Conclusion
In this paper, we show the importance of sample-wise loss term weighting in Knowledge Distillation and propose RW-KD, a method that performs this weighting and leads to better distillation performance on 7 GLUE tasks. Future work involves combining RW-KD with state-of-the-art KD methods that use extra loss terms, such as intermediate-layer similarity (Sanh et al., 2019; Jiao et al., 2020), attention matching (Sun et al., 2020), and adversarial (Rashid et al., 2021) losses. We expect that these methods can take full advantage of RW-KD, since they currently use single-weight loss terms. Beyond KD training, we will investigate applying our re-weighting method to Multi-task Learning (MTL) scenarios (Caruana, 1997; Lu et al., 2019; Stickland and Murray, 2019), where learning to balance the losses of different tasks is critical to benefit all tasks involved.

Appendix A

The computation of Eq. 7 by backpropagation expands as:

$$u_i^{CE} = -\beta \frac{\partial L_{meta}(X^{(m)}; \hat{\theta}_t)}{\partial \epsilon_i^{CE}} = -\beta \sum_{k=1}^{K} \frac{\partial L_{meta}}{\partial \hat{\theta}_{t,k}} \frac{\partial \hat{\theta}_{t,k}}{\partial \epsilon_i^{CE}}$$

where K is the number of parameters of the student. Since $\hat{\theta}_t$ is a function of $\epsilon$:

$$\hat{\theta}_t = \theta_t - \alpha \nabla_{\theta_t} L(X^{(n)}; \theta_t, \epsilon)$$

we can expand the middle term as $\frac{\partial \hat{\theta}_{t,k}}{\partial \epsilon_i^{CE}} = -\frac{\alpha}{n} \frac{\partial L_{CE}(x_i)}{\partial \theta_{t,k}}$, which yields $u_i^{CE} = \frac{\alpha \beta}{n} J_1^{\top} J_2$, where $J_1 = [\frac{\partial L_{meta}}{\partial \hat{\theta}_{t,1}}, \cdots, \frac{\partial L_{meta}}{\partial \hat{\theta}_{t,K}}]^{\top}$ is the Jacobian vector of $L_{meta}$ w.r.t. $\hat{\theta}$ on a mini-batch of meta data, and $J_2 = [\frac{\partial L_{CE}(x_i)}{\partial \theta_{t,1}}, \cdots, \frac{\partial L_{CE}(x_i)}{\partial \theta_{t,K}}]^{\top}$ is the Jacobian vector of $L_{CE}$ w.r.t. $\theta$ of the i-th training sample.