DQ-BART: Efficient Sequence-to-Sequence Model via Joint Distillation and Quantization

Large-scale pre-trained sequence-to-sequence models like BART and T5 achieve state-of-the-art performance on many generative NLP tasks. However, such models pose a great challenge in resource-constrained scenarios owing to their large memory requirements and high latency. To alleviate this issue, we propose to jointly distill and quantize the model, where knowledge is transferred from the full-precision teacher model to the quantized and distilled low-precision student model. Empirical analyses show that, despite the challenging nature of generative tasks, we achieve a 16.5x model footprint compression ratio with little performance drop relative to the full-precision counterparts on multiple summarization and QA datasets. We further push the limit of the compression ratio to 27.7x and present the performance-efficiency trade-off for generative tasks using pre-trained models. To the best of our knowledge, this is the first work aiming to effectively distill and quantize sequence-to-sequence pre-trained models for language generation tasks.


Introduction
Pretrained sequence-to-sequence (seq2seq) models such as BART (Lewis et al., 2020) and T5 (Raffel et al., 2020; Xue et al., 2021) have shown great success in various natural language processing (NLP) tasks, such as text summarization (Nallapati et al., 2016; See et al., 2017; Narayan et al., 2018), machine translation, question answering (Fan et al., 2019), and information extraction (Zhou et al., 2021). However, such large-scale pre-trained language models come with hundreds of millions of parameters: Lewis et al. (2020) trained a BART model with 400M parameters, while Raffel et al. (2020) pushed the limit to 11 billion parameters in T5.

† Work done during an internship at AWS AI Labs. § Equal contribution.
The continual growth in model sizes leads to significant demand in both computation and memory resources during inference, and poses a huge challenge for deployment, especially in real-time and/or resource-constrained scenarios. This motivates researchers to compress large pre-trained models to be smaller and faster while retaining strong performance. Among existing compression approaches such as weight sharing (Dehghani et al., 2019; Lan et al., 2020), low-rank approximation (Ma et al., 2019; Lan et al., 2020), and pruning (Michel et al., 2019), quantization approaches have received attention recently since they reduce model footprints by using lower bits for the weight values without changing the carefully designed model architecture. Most prior work on transformer quantization focused on BERT-based transformers (Zhang et al., 2020; Zafrir et al., 2019; Bai et al., 2021), whereas quantization of encoder-decoder transformers remains insufficiently studied. Prato et al. (2020) achieve 8-bit quantization for a seq2seq transformer without significant loss of performance, but low-bit quantization proved difficult for this model (see the 4-bit results in Table 2 of their work) due to the accumulation of quantization errors in seq2seq models. Moreover, their work did not target quantizing large-scale pre-trained language models, nor could it be applied to NLP tasks other than machine translation. Meanwhile, model distillation, which transfers knowledge from a large teacher model to a smaller student model, has been widely investigated for BERT compression (Sanh et al., 2019; Jiao et al., 2020).
Recently, Shleifer and Rush (2020) applied the "shrink and fine-tune" distillation method to BART, yet their work focuses on the distillation methodology for text summarization only. Moreover, their approach did not yield a significant model footprint reduction, one of the most challenging issues in deploying large models in resource-constrained scenarios.
In this work, we try to address the challenge of building a more efficient seq2seq model by answering two research questions: first, how well does a quantized seq2seq model perform on various tasks? Second, how do we combine quantization and distillation to push the limit of compressing the seq2seq model without significant performance losses on challenging tasks like summarization and question answering? To this end, we propose a joint distillation and quantization framework, which efficiently transfers the knowledge from a full-precision teacher seq2seq model to a student with fewer layers and ultra-low bits for encoding its parameters. Experimental results on BART show that the proposed models reduce the model footprint by 16.5x while preserving competitive performance on multiple language generation benchmarks, and further illustrate the performance-efficiency trade-off of compressing seq2seq models up to 27.7x smaller. To the best of our knowledge, this is the first work aiming to effectively distill and quantize seq2seq pre-trained models for language generation tasks.

Distilling and Quantizing BART
In this section, we consider two directions for reducing the size of our generative language model: quantization ( §2.1) and distillation ( §2.2). We apply distillation-aware training ( §2.3) to train a quantized and distilled low-precision model as a student model to emulate the full-precision teacher model.

Quantization
Quantization refers to the operation of mapping a real (high-precision) number to its low-precision counterpart in order to reduce the model footprint. There has been extensive study on applying quantization to training neural networks. Quantization schemes include, e.g., linear quantization (Hubara et al., 2016, 2017; Jacob et al., 2018), non-linear quantization (Li and Sa, 2019), approximation-based quantization (Lin et al., 2016), and loss-aware quantization (Hou and Kwok, 2018). In our work, we use the approximation-based method with linear quantization, following Zhang et al. (2020).
Quantizing BART We applied quantization to the weights of all the hidden layers and most of the embeddings. Following previous work (Zhang et al., 2020), we did not quantize positional embeddings and quantized activations only to 8 bits.
Weight Quantization We now dive into the mathematical details of how the weights in BART models are quantized. Let w_t ∈ R^{n_t} denote the vector obtained by stacking all the columns of the full-precision weight matrix W_t that we wish to quantize at iteration t. To quantize w_t, we look for a scaling factor (also known as the quantization step) α_t and a low-precision vector b_t, so as to replace the full-precision weight w_t with α_t b_t. When quantizing with more than 2 bits, we apply the commonly used symmetric linear quantization, with

α_t = ||w_t||_∞ / th, where th = 2^{n_b − 1} − 1

and n_b is the number of bits we use for quantization. Then b_t is obtained by b_t = round(w_t / α_t). When quantizing with 2 bits, we use the approximation-based TWN method (Li et al., 2016). The mathematical details are provided in Appendix A.
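The symmetric linear scheme above can be sketched in a few lines of numpy. This is an illustrative sketch, not the authors' implementation; the max-abs choice of scaling factor and the function name are our assumptions.

```python
import numpy as np

def linear_quantize(w, n_bits):
    """Symmetric linear quantization of a weight vector w (sketch).

    th is the largest representable integer magnitude at n_bits, and
    alpha (the quantization step) scales w into the range [-th, th].
    """
    th = 2 ** (n_bits - 1) - 1
    alpha = np.abs(w).max() / th                 # assumed max-abs scaling
    b = np.clip(np.round(w / alpha), -th, th)    # low-precision integers
    return alpha, b.astype(np.int8)

w = np.array([0.4, -1.2, 0.05, 0.9])
alpha, b = linear_quantize(w, n_bits=4)
w_hat = alpha * b  # dequantized approximation of w
```

At 4 bits (th = 7) the reconstruction error of each entry is bounded by half a quantization step, alpha / 2.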

Distillation
The second task we consider is knowledge distillation, where we train a smaller student model to mimic the behavior of a larger teacher model; specifically, we want to reproduce the output logits, attentions, and hidden states of the teacher model. Following Shleifer and Rush (2020), we initialize the student model by copying the weights from maximally spaced layers of the teacher model; e.g., when initializing a 3-layer student encoder (decoder) from a 6-layer teacher encoder (decoder), we copy the 0th, 3rd, and 5th layers from the teacher to the student. When copying only 1 layer, we choose the last instead of the first, which we found empirically to yield better performance. Different from Shleifer and Rush (2020), who only distill the decoder, we distill both the encoder and the decoder. After initialization, we fine-tune the student model with the combined objective of task loss and distillation loss, i.e., L_data + L_dist, with

L_dist = L_logits + L_att + L_hid,

where the terms on the right-hand side are MSE losses measuring the difference between the student and teacher with regard to output logits, attention scores (including encoder attention, decoder attention, and cross attention), and hidden states (including all encoder and decoder layers).1 We include the details of the loss in Appendix B for completeness.

Table 1 (caption): We abbreviate the number of bits for weights, word embeddings, and activations as "W-E-A (#bits)", followed by the number of encoder and decoder layers as "E-D (#layers)". We use rouge-{1,2,L} as evaluation metrics (Lin, 2004). We found that distillation-aware quantized models achieve comparable or even better performance than the full-precision models, and that combining quantization and distillation, e.g., going from "2-2-8 6-6" to "2-2-8 6-3", gives a further boost in model footprint compression ratio without significant sacrifice in performance. See §3.2 for details.
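The "maximally spaced layers" initialization can be sketched as follows. The helper name and the exact spacing rule (even spacing with ceiling rounding, which reproduces the 0/3/5 example above) are our assumptions, not the authors' published code.

```python
import math

def maximally_spaced_layers(n_teacher, n_student):
    """Pick n_student teacher-layer indices, maximally spaced across
    n_teacher layers, for copy-initialization of the student (sketch).

    With a single student layer, the last teacher layer is chosen,
    which the paper reports works better than the first.
    """
    if n_student == 1:
        return [n_teacher - 1]
    # spread indices evenly from layer 0 to the last layer
    step = (n_teacher - 1) / (n_student - 1)
    return [math.ceil(i * step) for i in range(n_student)]
```

For a 6-layer teacher and a 3-layer student this yields [0, 3, 5], matching the example in the text.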

Distillation-aware quantization
To fine-tune our quantized and distilled model, we use the technique of distillation-aware quantization with a teacher-student architecture from Zhang et al. (2020).2 We treat the quantized and distilled low-precision model as a student trained to emulate the full-precision teacher model. Meanwhile, we also keep the full-precision distilled counterpart of the student model for parameter updates. At each iteration, we first quantize the full-precision student model to obtain its quantized version, then run the forward pass with the low-precision student model and compute the task loss as well as the distillation losses discussed in §2.2. Finally, we use these losses to update the parameters of the full-precision student model.
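The per-iteration procedure amounts to a straight-through-estimator update, which can be sketched with a toy scalar example. The function names and the gradient-callback interface are ours; the real implementation operates on full transformer parameters and the combined task plus distillation loss.

```python
import numpy as np

def quantize(w, n_bits=8):
    """Symmetric linear quantization (same scheme as Sec. 2.1, sketch)."""
    th = 2 ** (n_bits - 1) - 1
    alpha = np.abs(w).max() / th
    return alpha * np.clip(np.round(w / alpha), -th, th)

def train_step(w_full, grad_fn, lr=1e-2, n_bits=8):
    """One distillation-aware quantization step (sketch):
    1) quantize the full-precision student weights,
    2) run forward/backward with the *quantized* weights,
    3) apply the gradient to the *full-precision* copy
       (straight-through estimator)."""
    w_q = quantize(w_full, n_bits)  # low-precision student
    grad = grad_fn(w_q)             # grads of task + distillation losses
    return w_full - lr * grad       # update the full-precision copy
```

With a toy quadratic loss, repeated `train_step` calls drive the quantized weights toward the target, illustrating why gradients computed on the low-precision model can still train the full-precision copy.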

Experimental Setup
We followed the standard splits of these datasets; statistics can be found in Appendix C. For ELI5, we reproduced the authors' implementation to train a dense retriever that retrieves 10 supporting documents from Wikipedia for each question. Additional details can be found in Appendix D. As our target is efficient seq2seq generative models, we used base-sized BART for the summarization and question answering tasks. For machine translation, we used mBART-large due to the lack of pretrained base-sized multilingual BART models. We reused existing models3 and fine-tuned our own models on end tasks when no open-sourced model was available. We trained quantization-only models for 10 epochs and distilled-and-quantized models for 20 epochs. We used a batch size of 128 and a learning rate of 3 × 10^-5 with 5% linear warmup, and selected the best model based on rouge-L scores on the development set. We set generation hyperparameters following previous work. All experiments were performed on A100 GPUs.

Figure 1 (caption): Performance versus compression ratio for the models in Table 1. Green dots are quantization only, and purple dots are distillation + quantization. Performance degradation is minimal as the compression ratio grows, especially below 20x.

DQ-BART Results and Discussions
We summarized the main results in Table 1 and visualized the performance on text summarization on the CNN/DailyMail dataset in Figure 1. Additional visualizations are in Appendix E. We found that:

1. Direct quantization performs poorly on generation tasks: the rouge-L score drops ∼50-75% relative to the baseline.

2. The 8-bit distillation-aware quantized models ("8-8-8 6-6") achieve comparable or even better performance than the full-precision models across all tasks, signaling that 8 bits is not too challenging for generative models like BART, similar to the findings for BERT (Zhang et al., 2020).

3. We were able to achieve a 13.6x model size compression ratio when using 2-bit quantization, at the cost of a slight performance drop for the summarization tasks and no performance drop at all for the long-form QA task.

4. Combining quantization and distillation gives a further boost in model compression ratio without significant further sacrifice in performance. For example, when using 2-bit quantization, cutting the layers of the decoder in half (from "2-2-8 6-6" to "2-2-8 6-3") costs < 0.5 rouge-L across all tasks while yielding another 2.9x compression.

5. When pushing the compression rate to the limit ("2-2-8 1-1"), we were able to achieve a 27.7x compression ratio while still preserving reasonable performance. We observed a rouge-L drop of 5.

Thus, for certain tasks a large model compression ratio does not lead to a significant performance drop, while for others the drop can be huge, suggesting that the specific compression ratio to use should be decided task by task with the performance-efficiency trade-off in mind.

DQ-mBART for Translation
We further extend our study to see how distillation and quantization work for mBART, a deeper multilingual model. We experimented with mBART-large on the WMT English-Romanian translation task (Bojar et al., 2016). The results are in Table 2. We found that distillation-aware quantization alone yields reasonably good performance, similar to the findings for DQ-BART (Table 1). However, performance drops substantially when performing 2-bit quantization with distillation, possibly because the accumulated distillation and quantization errors become more significant in deeper models, and because of the challenging nature of machine translation.
Future work may explore how to improve the performance of joint distillation and quantization for deep models in low-bit settings.

We also want to understand how much is gained by joint distillation and quantization compared with the distillation-only method (Shleifer and Rush, 2020). To do so, we trained distillation-only models and compared them with DQ-BART models of similar size. From Table 3, we found that joint distillation and quantization performs much better across all tasks, signaling the large gain from the joint approach. An additional ablation study on "Shrink and Finetune" can be found in Appendix F.

Conclusion
Transformer-based pre-trained seq2seq language models like BART have greatly advanced the state of the art in a range of NLP tasks. Yet, these extremely large-scale models pose a challenge in resource-constrained scenarios. To alleviate this issue, we proposed DQ-BART, a jointly distilled and quantized BART model. Empirical results show that, despite the difficult nature of language generation tasks, we achieve a 16.5x model footprint compression ratio with little performance drop on three generative benchmarks, and further present the performance-efficiency trade-off for seq2seq models up to a 27.7x compression ratio. Additionally, we studied distillation and quantization for mBART on a machine translation task, and highlighted the challenge of joint low-bit quantization with distillation for deeper models on cross-lingual tasks. To the best of our knowledge, our method is the first to apply joint quantization and distillation to pretrained language models, and this is the first work aiming to effectively distill and quantize seq2seq pretrained models for language generation tasks. We hope this work opens the door to developing and applying efficient seq2seq language models. We leave additional compression methods, such as attention head pruning (Michel et al., 2019) and sequence-level distillation (Kim and Rush, 2016), as well as the measurement of latency improvements in various settings, for future work. Our code is available at https://github.com/amazon-research/dq-bart/.

A Details of TWN Quantization
When quantizing with 2 bits (also known as ternarization), following Zhang et al. (2020), we apply the TWN method (Li et al., 2016). To quantize w, we look for a scaling factor α > 0 and b ∈ {−1, 0, 1}^n such that w ≈ αb, where n is the dimension of w. To minimize the quantization error, we solve the following optimization problem:

α*, b* = argmin_{α>0, b∈{−1,0,1}^n} ||w − αb||².

Denote by ∆ a threshold, and let I_∆(x) be the element-wise indicator function with I_∆(x) = 1 if |x| > ∆ and I_∆(x) = 0 otherwise. Then, according to Hou and Kwok (2018), the solution to the above optimization problem is reached at

b* = sign(w) ⊙ I_∆(w), α* = ||w ⊙ I_∆(w)||_1 / ||I_∆(w)||_1,

where ⊙ is element-wise multiplication and ||·||_1 is the l_1-norm. To approximate this result, we set ∆* = 0.7 ||w||_1 / dim(w) and then compute α* and b* accordingly.
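The closed-form TWN solution above can be sketched in numpy. The function name is ours; the 0.7 · mean(|w|) threshold is the approximation stated in the text.

```python
import numpy as np

def twn_ternarize(w):
    """TWN ternarization sketch (Li et al., 2016): zero out entries
    below a threshold, keep only the signs of the rest, and rescale
    by alpha (the mean magnitude of the surviving entries)."""
    delta = 0.7 * np.abs(w).mean()               # = 0.7 * ||w||_1 / dim(w)
    mask = (np.abs(w) > delta).astype(w.dtype)   # I_delta(w)
    b = np.sign(w) * mask                        # entries in {-1, 0, 1}
    alpha = np.abs(w * mask).sum() / max(mask.sum(), 1.0)
    return alpha, b
```

For example, for w = [0.9, -0.8, 0.05, -0.1] the threshold is 0.7 · 0.4625 ≈ 0.324, so only the first two entries survive, giving b = [1, -1, 0, 0] and alpha = (0.9 + 0.8) / 2 = 0.85.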

B Details of Distillation Losses
The distillation loss is defined as follows:

L_dist = L_logits + L_att + L_hid.

In this section we go through each part of the loss. We denote by ϕ_enc(·), ϕ_dec(·) the functions that map the index of an encoder/decoder layer of the student model to the index of the teacher-model layer it is trained to emulate, as discussed in §2.2, and we use l^S_enc, l^S_dec to denote the number of encoder and decoder layers of the student model. To illustrate, if l^S_enc = 3 and l^S_dec = 2, we would have ϕ_enc(0, 1, 2) = 0, 3, 5 and ϕ_dec(0, 1) = 0, 5. For simplicity, we use the superscripts ·^S and ·^T to distinguish counterparts from the student model and the teacher model, respectively. Next, we explain each part of the distillation loss.
Firstly, L_logits is the mean squared error (MSE) between the output logits of the student model and those of the teacher model, i.e.,

L_logits = MSE(logits^S, logits^T).

Secondly, L_att is the attention distillation loss, which is the sum of the distillation losses of the encoder attentions (EA), decoder attentions (DA), and cross attentions (CA), i.e.,

L_att = Σ_{i=0}^{l^S_enc − 1} MSE(EA^S_i, EA^T_{ϕ_enc(i)}) + Σ_{i=0}^{l^S_dec − 1} [MSE(DA^S_i, DA^T_{ϕ_dec(i)}) + MSE(CA^S_i, CA^T_{ϕ_dec(i)})],

with the subscripts i and ϕ(i) specifying the indices of the layers. Finally, L_hid is the distillation loss between the hidden states of student layers and teacher layers, which include the encoder hidden states (EHS) and decoder hidden states (DHS):

L_hid = Σ_{i=0}^{l^S_enc − 1} MSE(EHS^S_i, EHS^T_{ϕ_enc(i)}) + Σ_{i=0}^{l^S_dec − 1} MSE(DHS^S_i, DHS^T_{ϕ_dec(i)}).
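The three loss terms above can be sketched as a single function. The dictionary layout mapping names like "logits" and "enc_att" to per-layer arrays is our assumption for illustration; in practice these tensors come from the model's forward pass.

```python
import numpy as np

def mse(a, b):
    """Mean squared error between two arrays."""
    return float(np.mean((a - b) ** 2))

def distillation_loss(student, teacher, phi_enc, phi_dec):
    """L_dist = L_logits + L_att + L_hid (sketch).

    student/teacher map names ("logits", "enc_att", "enc_hid",
    "dec_att", "cross_att", "dec_hid") to lists of per-layer arrays;
    phi_enc/phi_dec map student layer i to teacher layer phi(i)."""
    loss = mse(student["logits"], teacher["logits"])       # L_logits
    for i, j in enumerate(phi_enc):                        # encoder terms
        loss += mse(student["enc_att"][i], teacher["enc_att"][j])
        loss += mse(student["enc_hid"][i], teacher["enc_hid"][j])
    for i, j in enumerate(phi_dec):                        # decoder terms
        loss += mse(student["dec_att"][i], teacher["dec_att"][j])
        loss += mse(student["cross_att"][i], teacher["cross_att"][j])
        loss += mse(student["dec_hid"][i], teacher["dec_hid"][j])
    return loss
```

When the student's outputs exactly match the mapped teacher layers, the loss is zero, as expected from a sum of MSE terms.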

D ELI5 Additional Details
In this section, we present additional details for the ELI5 dataset.

D.1 Dense Retriever
We were not able to find a public version of the supporting documents for ELI5, and thus followed the authors' implementation4 (https://yjernite.github.io/lfqa.html) to train a dense retriever that retrieves support documents from Wikipedia. Our trained retriever achieves performance similar to the one reported in the authors' implementation (recall: ours 0.3273, reported 0.3247).

D.2 Evaluating ELI5 Results
We use the ROUGE-SCORE package5 to calculate rouge scores throughout the paper. However, as the authors of ELI5 pointed out,4 the original rouge implementation used in the ELI5 and BART papers performs additional normalization. For consistency, we also report results for ELI5 using the same ROUGE-SCORE package, which differs from the one used in ELI5/BART. Here we compare the performance of our trained ELI5 baseline model with the public one, using the rouge implementation from the ELI5/BART papers. The results in Table 5 show that the performance of our base-sized model is close to the reported large-sized one. This signals that our baseline model for ELI5 is well trained.

F Comparisons on "Shrink and Finetune"
We benchmarked the performance of three randomly picked models with the "Shrink and Finetune" scheme proposed by Shleifer and Rush (2020). We ran the models using the same hyperparameter settings used elsewhere in this paper. The results are shown in Table 6. We found that when using distillation losses between the teacher and the student, performance is slightly better than with the "Shrink and Finetune" method under our setting. This signals that guidance from the teacher is important for a quantized and distilled model to learn well.

Table 6: Performance comparison between the loss used in this paper and the "Shrink and Finetune" loss from Shleifer and Rush (2020).