DiMS: Distilling Multiple Steps of Iterative Non-Autoregressive Transformers for Machine Translation



Introduction
Neural machine translation models typically follow an autoregressive decoding strategy, generating the target sentence one token at a time. This sequential nature makes the inference process slow and dependent on the output sequence length. To address this limitation, Gu et al. (2018) introduced the Non-Autoregressive Transformer (NAT). NAT generates the entire target sentence in parallel, reducing latency by an order of magnitude. NAT can be considered a member of a broader family of iterative non-autoregressive Transformers (iNATs) (Lee et al., 2020; Stern et al., 2019; Ghazvininejad et al., 2019) where the number of decoding steps is fixed and independent of the sequence length. By tuning the number of decoding steps, one can control the trade-off between speed and quality. While iNATs can be considered efficient alternatives to their autoregressive counterparts, Kasai et al. (2020b) show that autoregressive models can be sped up without loss in accuracy by combining shallow decoders with deep encoders. This diminishes the computational advantage of iNATs and challenges their motivation. The focus of recent work has thus shifted to designing single-step NAT models (Qian et al., 2021; Du et al., 2021).

* Equal contribution.

[Figure 1: DiMS training. The student is trained to match the predictions of the teacher after several iterative steps. The teacher is updated with an exponential moving average of the student.]
In order to preserve the enhancements obtained by multiple decoding iterations of iNATs, we introduce Distill Multiple Steps (DiMS), a distillation algorithm applicable to a wide range of iterative models. Given a pre-trained iNAT, referred to as the teacher, a student aims to replicate the behavior of multiple iterative steps of the teacher with one decoding pass. This process resembles the well-known knowledge distillation framework (Hinton et al., 2015). However, instead of reducing the number of parameters, we aim to decrease the number of decoding steps. The proposed distillation can be repeated iteratively, where at the end of each round the newly optimized student becomes the next teacher. While effective, iterative distillation is slow as it requires multiple rounds of training until convergence. Alternatively, we propose updating the parameters of the teacher with an exponential moving average (EMA) of the student. This gradually transfers the new knowledge learned by the student to the teacher and can be viewed as a continuous variant of iterative distillation. Figure 1 depicts the DiMS algorithm.
We demonstrate the effectiveness of our approach on several public datasets, showing that DiMS obtains substantial improvements on single-step translation with gains of up to 7.8 BLEU points on the distilled training dataset, while the gains on raw datasets are even greater. Notably, we are able to surpass many leading NAT models designed specifically for single-step translation. We further show that EMA considerably speeds up training and converges to accuracy comparable with iterative distillation in a fraction of the epochs.

Background
In this section, we lay out a formal framework for iNATs. We use the setup of Conditional Masked Language Models (CMLM). CMLM was first introduced by Ghazvininejad et al. (2019) and subsequently adopted in many iNAT models (Ghazvininejad et al., 2020b; Kasai et al., 2020a; Huang et al., 2021). The source sentence, target sentence, and target sequence length are denoted by x, y, and N, respectively.

Training
Given a partially masked reference sentence ỹ and the corresponding source context x, the model is trained to reveal all the masked positions simultaneously (Ghazvininejad et al., 2019). From a probabilistic perspective, this imposes a conditional independence assumption on the predicted tokens. Formally, the training loss is:

L_CMLM = −E_{ỹ∼M(y)} Σ_{i∈ξ(ỹ)} log p_θ(y_i | ỹ, x)

where M is a distribution over all partially masked target sentences and ξ is a function that returns the set of masked indices. The training objective above implicitly assumes access to the target sentence length. To resolve this issue, CMLM trains a parametric model, the length predictor, to predict the output length.
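The masked cross-entropy above can be sketched as follows. This is a minimal NumPy illustration under our own (hypothetical) function and variable names, not the Fairseq implementation: given one parallel decoder pass, the loss sums the negative log-probability of the reference token over the masked positions only.

```python
import numpy as np

def cmlm_loss(log_probs, target_ids, masked_idx):
    """Cross-entropy over the masked positions only.

    log_probs:  (N, V) per-position log-probabilities produced by one
                parallel decoder pass on the partially masked input ỹ.
    target_ids: (N,) reference token ids y.
    masked_idx: the set of masked indices ξ(ỹ).
    """
    # Sum of -log p_θ(y_i | ỹ, x) over the masked positions i.
    return -sum(float(log_probs[i, target_ids[i]]) for i in masked_idx)
```

Unmasked positions contribute nothing to the loss, matching the expectation over ξ(ỹ) in the objective.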

Inference
The inference begins by creating a template ỹ^(0) with Ñ masked tokens, where Ñ is the output of the length predictor. At iteration t of the inference, the model predicts the translation r^(t) given ỹ^(t−1) and x as inputs. Depending on the number of decoding iterations S, typically a linear unmasking policy is used where at each step the Ñ/S tokens with the highest probability are revealed. This process is repeated S times, resulting in a fully revealed sentence. In other words,

ỹ_i^(t) = r_i^(t)   if p_θ(r_i^(t) | ỹ^(t−1), x) is among the Ñ/S highest token probabilities at step t,

where p_θ denotes the output probability of the model. Otherwise, ỹ_i^(t) = ỹ_i^(t−1), i.e., the position keeps its previous value. Note that multiple length candidates can be considered (e.g. Ñ ± 1) with the average token probability as a ranking criterion. This is similar to beam search in autoregressive models but applied to the output sequence length; it is referred to as length beam.

Algorithm 1 DiMS
Require: dataset D, pre-trained model ϕ, hidden state loss factor λ, teacher steps n, EMA momentum µ, learning rate η
  θ_t, θ_s ← ϕ                      ▷ Initialize teacher and student
  while not converged do
    (x, y) ∼ D                      ▷ Sample data
    ỹ ∼ M(y)                        ▷ Sample masking
    p_t ← I_{θ_t}(x, ỹ, n)          ▷ Run the teacher for n iterative steps
    p_s ← I_{θ_s}(x, ỹ, 1)          ▷ Run the student for a single step
    θ_s ← θ_s − η ∇_{θ_s} L_DiMS    ▷ Gradient-based optimization of the student
    θ_t ← (1 − µ) θ_s + µ θ_t       ▷ EMA update of the teacher
  end while
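The iterative decoding loop with a linear unmasking policy can be sketched as follows. This is a toy NumPy sketch with our own naming; `model` stands in for one parallel decoder pass that returns per-position probabilities over the vocabulary.

```python
import numpy as np

MASK = -1  # sentinel id for a masked position

def mask_predict(model, x, length, steps):
    """Iterative mask-predict decoding with a linear unmasking policy.

    At each of the `steps` iterations, run one parallel decoder pass and
    reveal the ~length/steps still-masked positions whose predicted
    probability is highest. Revealed tokens stay fixed.
    """
    y = np.full(length, MASK)
    per_step = int(np.ceil(length / steps))
    for _ in range(steps):
        probs = model(x, y)            # (length, vocab) probabilities
        pred = probs.argmax(axis=-1)   # most likely token per position
        conf = probs.max(axis=-1)      # its probability, used as confidence
        masked = np.where(y == MASK)[0]
        if masked.size == 0:
            break
        # reveal the most confident still-masked positions
        reveal = masked[np.argsort(-conf[masked])[:per_step]]
        y[reveal] = pred[reveal]
    return y
```

With `steps=1` this reduces to fully parallel NAT decoding; larger `steps` trades latency for quality as described above.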

Distillation of Iterative Non-autoregressive Transformers
Increasing the number of decoding steps typically improves accuracy but diminishes the computational advantage of iNATs. Our objective is to reduce the number of decoding steps without degrading performance. More specifically, we want to condense the translation quality of multiple steps of a teacher into one decoding pass of a student. For instance, consider an iterative model (teacher) that uses eight decoding steps. By replicating four steps of the teacher with one decoding pass, two steps of the student would be sufficient to reach a similar performance. Standard knowledge distillation would have the teacher generate soft labels for all intermediate iterations and optimize the student to track the teacher's output with fewer steps, but doing such generation on the fly greatly increases the training cost. The generation can be moved to a pre-processing phase, at the cost of a large memory requirement. We propose to use partially masked reference sentences as an approximation to the intermediate predictions of the teacher, which eliminates the need for several decoding passes or a large memory capacity.
The distillation process starts by initializing the student and the teacher to the same pre-trained model with parameters ϕ, i.e., θ_s = θ_t = ϕ, where θ_s and θ_t denote the parameters of the student and teacher. Then, the teacher processes a partially masked sentence ỹ through n iterative steps with a linear unmasking policy. More precisely, i/n of the originally masked tokens are revealed up to step i, and after the final pass no masked token remains. This is similar to the inference procedure outlined in Section 2.2, but instead of starting from a fully masked sentence, it starts from a partially masked one. The student is optimized to match the teacher's soft labels, and a temperature is used to control the smoothness of the labels. With enough capacity, the student is expected to imitate the behavior of n consecutive steps of the teacher with one decoding pass.

Training Loss
We denote the output distribution after n iterative steps on the partially masked sentence ỹ by I_θ(ỹ, x, n), where θ represents the parameters of the model. The distillation loss can be described as:

L_KL = Σ_{i∈ξ(ỹ)} KL(p_{t,i} ‖ p_{s,i})

where p_t = I_{θ_t}(ỹ, x, n), p_s = I_{θ_s}(ỹ, x, 1), and the subscript i denotes the index in the sentence. Note that the teacher's soft labels do not all come from the same decoding iteration, i.e., whenever a token is revealed, the corresponding soft labels are fixed in p_t. Thus, the student receives labels from various decoding steps of the teacher. Figure 2 depicts the process the teacher follows to produce the labels for two iterative steps. From the student's point of view, the primary difference between DiMS and CMLM training (Section 2.1) is the use of soft labels generated by the teacher instead of the ground-truth tokens.
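The way the teacher assembles p_t can be sketched as follows. This is a hedged NumPy sketch with our own (hypothetical) names: the teacher runs n steps of linear unmasking starting from the partially masked ỹ, and each position's soft labels are frozen at the step in which the teacher reveals it.

```python
import numpy as np

MASK = -1  # sentinel id for a masked position

def teacher_soft_labels(model, x, y_masked, n):
    """Run the teacher for n steps with linear unmasking, freezing each
    masked position's soft labels at the step where it is revealed."""
    y = y_masked.copy()
    masked0 = np.where(y == MASK)[0]
    per_step = int(np.ceil(len(masked0) / n))
    labels = {}                          # position -> frozen soft labels
    for _ in range(n):
        probs = model(x, y)              # (length, vocab) probabilities
        masked = np.where(y == MASK)[0]
        if masked.size == 0:
            break
        conf = probs.max(axis=-1)
        reveal = masked[np.argsort(-conf[masked])[:per_step]]
        for i in reveal:
            labels[i] = probs[i]         # freeze soft labels at reveal time
            y[i] = probs[i].argmax()     # reveal the token for later steps
    return labels
```

The returned dictionary covers exactly the originally masked indices, so the student receives one target distribution per masked position, drawn from different teacher decoding steps.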
To facilitate the distillation, we combine the KL-divergence with the Euclidean distance between the last layer's hidden states of the teacher and the student. This transfers knowledge concealed within the hidden states that might not be discernible in the soft labels. We refer to this as the hidden state loss. Similar to the KL-divergence, the hidden state loss is computed over the masked indices.
To summarize, the DiMS training loss has two terms: i) the KL-divergence between the distributions predicted by the teacher and the student, and ii) the Euclidean distance between the last hidden states of the two models. Denoting the teacher's and student's last hidden states by e_t and e_s, the DiMS loss can be written formally as:

L_DiMS = Σ_{i∈ξ(ỹ)} [ KL(p_{t,i} ‖ p_{s,i}) + λ ‖e_{t,i} − e_{s,i}‖² ]

The hyper-parameter λ controls the contribution of the hidden state loss. When the distillation is completed, the student is used for inference.
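The combined loss can be sketched as follows; a small NumPy illustration with hypothetical names, summing the KL-divergence and the squared Euclidean hidden-state distance over the masked indices.

```python
import numpy as np

def kl_div(p, q):
    """KL(p || q) for discrete distributions (assumes strictly positive entries)."""
    return float(np.sum(p * (np.log(p) - np.log(q))))

def dims_loss(p_t, p_s, e_t, e_s, masked_idx, lam):
    """Sum over masked positions i of KL(p_t,i || p_s,i)
    plus lam * ||e_t,i - e_s,i||^2.

    p_t, p_s: (N, V) teacher / student distributions.
    e_t, e_s: (N, d) teacher / student last-layer hidden states.
    lam:      hyper-parameter controlling the hidden state loss term.
    """
    total = 0.0
    for i in masked_idx:
        total += kl_div(p_t[i], p_s[i])
        total += lam * float(np.sum((e_t[i] - e_s[i]) ** 2))
    return total
```

Setting `lam = 0` recovers the pure KL distillation loss; the ablation in Section 4 motivates a nonzero value.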

EMA Update of the Teacher
As the distillation progresses, the performance gap between multiple steps of the teacher and a single pass of the student shrinks, making the teacher's labels less informative. Two approaches can be considered to sustain the usefulness of the teacher's labels: i) increasing the number of the teacher's iterative steps, or ii) restarting the distillation with the recently optimized student as the new teacher and repeating this process several times. The former makes training more expensive as the number of sequential steps grows, and the latter requires repeated distillation rounds, leading to a longer training time.
Instead, we propose updating the teacher with the student's recently learned knowledge. As the student's single-step output approaches the teacher's multi-step output, the student's multi-step performance improves as well, and it is beneficial to use the improved student as the new teacher. However, replacing the teacher directly with the student would hurt training stability and can lead to a pathological solution that maps everything to a constant vector. This degenerate solution shortcuts the L_DiMS loss by setting it to its global minimum of zero. To alleviate this, we update the teacher with a slow exponential moving average of the student, which transfers the new knowledge learned by the student to the teacher in a controlled manner. The updated teacher then provides a better training target for the student, creating a positive feedback loop between the two models. The teacher also benefits from the ensembling effects of the EMA (Izmailov et al., 2018). Algorithm 1 outlines the steps for DiMS training with EMA.
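The EMA update in Algorithm 1 can be sketched parameter-wise; a minimal sketch with our own names, assuming the parameters are stored as a name-to-array mapping.

```python
import numpy as np

def ema_update(theta_t, theta_s, mu):
    """theta_t <- (1 - mu) * theta_s + mu * theta_t for every parameter.

    A momentum mu close to 1 (e.g. 0.999) moves the teacher toward the
    student slowly, which is what keeps the training stable and avoids
    the degenerate constant-vector solution."""
    return {name: (1.0 - mu) * theta_s[name] + mu * w
            for name, w in theta_t.items()}
```

Applied once per optimization step, this realizes the continuous variant of iterative distillation: the teacher drifts toward the improving student instead of being replaced in discrete rounds.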

Experimental Setup
We use Fairseq (Ott et al., 2019) for all the experiments and follow the default data splits. All models are Transformers with an encoder-decoder architecture, each having 6 layers and 512-dimensional hidden states. The Adam optimizer with an inverse square root learning rate scheduler is used, along with mixed precision. EMA and the hidden state loss are leveraged with two iterative steps of the teacher unless otherwise stated. We use early stopping based on the single-step BLEU score on the validation set. The final model is the average of the 5 best checkpoints. Dropout is disabled for the teacher and the student, since empirical improvements are observed. We conduct experiments on both the raw dataset and the distilled dataset obtained from an autoregressive model (Gu et al., 2018). Training is done with 4 Tesla V100 GPUs (32 GB), and we report all the hyper-parameters in Section C of the appendix. The extra computational cost of distillation is a small fraction of the original training; we report a detailed comparison in Section E of the appendix.

Main Results
Our main experiments are conducted on the WMT'14 En-De and WMT'16 En-Ro datasets with two models: i) CMLM, a pivotal work in the iNAT literature showing the effectiveness of conditional masked language models, and ii) CMLMC, a recent work improving CMLM by incorporating a correction mechanism. The corresponding official repositories are used to train the teachers. Both models exploit a length predictor conditioned on the encoder's hidden states. For CMLMC models, we use encoder-side masking and prediction to further boost the performance of the teacher. To keep the length predictor compatible with changes in the encoder, we retain the length predictor loss during distillation. Figure 3 contrasts the single-step BLEU score of students with teachers evaluated for various numbers of decoding steps. DiMS considerably improves the translation quality of single-step inference, reducing or eliminating the gap with multi-step inference. For example, on the WMT'14 De-En dataset, the single step of CMLMC+DiMS surpasses the teacher's 4-step performance. We compare our best single-step model with strong baselines in Table 1, showing the effectiveness of our approach. DiMS outperforms all cross-entropy-based models and makes them competitive with their alignment-based counterparts.

Results on an Alignment Based Model
To show the versatility of DiMS, we conduct experiments on alignment-based models leveraging the Connectionist Temporal Classification (CTC) objective (Graves et al., 2006). Imputer (Saharia et al., 2020) is among the few models that are both alignment-based and iterative. There is no official implementation of Imputer available, so we implement a version ourselves (denoted with †). Table 2 summarizes the results of DiMS applied to Imputer for both directions of the WMT'14 English-German dataset. While DiMS boosts the single-step translation of Imputer, it still falls behind the more recent alignment-based models mentioned in Table 1. However, we believe that if one incorporates the various tricks recently introduced for alignment-based models to build a stronger iterative model, DiMS can be an effective tool to further enhance its single-step translation. Details of Imputer training and distillation are explained in Section F of the appendix.

DiMS on Raw Dataset
The performance of the leading iNATs is at best similar to that of the autoregressive model used for sequence-level knowledge distillation. This limits the final performance of iNATs and makes training without distillation desirable (Huang et al., 2021). Table 3 shows that DiMS improves the raw performance by a large margin, even more than it improves the corresponding distilled variant. For instance, DiMS yields more than 12 BLEU points of improvement on single-step evaluation of CMLMC. For one decoding pass, when raw variants of CMLMC are distilled with DiMS, the performance is superior to training on the distilled dataset (without DiMS). This makes DiMS preferable to sequence-level knowledge distillation. Nevertheless, the best performance is obtained when the two distillation approaches are combined.

Unsupervised DiMS
In previous sections, we assumed access to a parallel dataset and fed a partially masked reference sentence to both the student and the teacher. Alternatively, one can use the teacher to generate synthetic target sentences during the distillation. This relaxes the dependence on references and enables using monolingual datasets for distillation. As usual, there is a trade-off between computation and sample quality, i.e., using more decoding passes leads to better data while increasing the computational requirements. We refer to this unsupervised distillation variant as U-DiMS. Note that unsupervised only refers to the distillation; training the teacher still requires access to a parallel dataset. The only distinction between U-DiMS and DiMS is the use of synthetic data generated by the teacher; the remaining parts are untouched. We run U-DiMS on WMT'14 De-En for CMLM and CMLMC using two iterative steps to generate the synthetic samples.

Ablation Studies
We conduct all the ablation studies on CMLM over WMT'16 En-Ro, as it is smaller than WMT'14, and use the validation set for evaluation.

Hidden State Loss
To investigate the effect of the hidden state loss, we conduct an ablation study in this section. The first block in Table 5 reports BLEU scores for the base DiMS model with and without this term. The single-step performance of the distilled model improves by over 2 BLEU points when this loss is leveraged. This supports the view that the hidden states contain extra information that is not available in the soft labels. The exact value of λ is selected based on a grid search reported in Section D of the appendix.

EMA
In order to establish the computational advantages of the slow-moving average, we compare it with running the base variant for nine iterative rounds. Figure 4 demonstrates that the EMA variant matches iterative distillation with far fewer updates (roughly equal to one round of the distillation).
We observed that it is essential to move the teacher toward the student slowly. For example, when µ ≤ 0.9, the collapse to a degenerate solution (explained in Section 3.2) occurs before the end of the first epoch. We plot the validation curve for various values of µ in Section B of the appendix showing the importance of the slow-moving average.

Teacher Decoding Steps
One hyper-parameter of the DiMS algorithm is the number of the teacher's decoding steps. To investigate its effect, we set it to 2, 4, and 8 while turning EMA on and off. The two bottom blocks of Table 5 show the results of this ablation. Although running the teacher for 4 decoding steps is superior without EMA, the gap disappears as soon as EMA is turned on. This shows that EMA can gradually improve the teacher and remove the need for several iterative steps. Thus, we find no reason to set this hyper-parameter larger than 2, as larger values only increase the distillation's computational cost.

Analysis
We study the effect of target sentence length on DiMS performance. The test set is divided into five equally sized buckets based on target length, and the BLEU scores are reported for each bucket in Figure 5. The main benefit of the iterative model manifests on long sentences. The reason might be that longer sentences require more context, and modeling it becomes challenging under the conditional independence assumption in NAT. Figure 5 shows that performance improves in every bucket. The improvement is most visible in the bucket with the highest average sentence length, which is also the bucket with the largest gap between the teacher's single- and multi-step evaluation.
We combine the length predictor objective with ours to account for changes in the encoder's parameters. Interestingly enough, DiMS improves the performance of the length predictor as depicted in Figure 6. This shows that the encoder benefits from the distillation as well. Table 6 shows a qualitative example from the WMT'14 De-En dataset. The improvements in samples are evident by comparing the predictions of the teacher and the student with the target sentence. We provide more qualitative examples in the appendix.

Related Works
Many techniques have been proposed for iterative non-autoregressive machine translation. Earlier attempts include denoising autoencoders (Lee et al., 2018) and insertion-deletion models (Stern et al., 2019). More recently, Ghazvininejad et al. (2019) introduced Mask-Predict, improving the performance of iNATs by employing a conditional masked language model. CMLMC (Huang et al., 2021) and SMART (Ghazvininejad et al., 2020b) improve CMLM by incorporating a correction mechanism. DisCo (Kasai et al., 2020b) is another variant, conditioning each token on an arbitrary subset of the other tokens. DiMS is entangled with the progress in this domain as it requires a pre-trained iterative teacher.
The position constraint in cross-entropy can make NAT training challenging; aligned cross-entropy (AXE) was therefore proposed, an objective that considers the best monotonic alignment between the target and the model's predictions. Du et al. (2021) relax the monotonic assumption and introduce Order-Agnostic Cross-Entropy (OAXE). CTC (Libovickỳ and Helcl, 2018) is a similar alignment-based objective that fixes the model output length and considers the various alignments leading to the same target. Imputer (Saharia et al., 2020) extends CTC to benefit from iterative refinement.
GLAT (Qian et al., 2021) shows that the optimization challenges of iNATs can be mitigated by a curriculum that focuses on sentences with only a few masked tokens in the early stages of training and gradually increases the masking ratio. ENGINE (Tu et al., 2020) assumes access to a pre-trained autoregressive model and optimizes a NAT model to maximize the likelihood under the probability distribution defined by the pre-trained model. Salimans and Ho (2021) apply a distillation technique similar to DiMS to generative models to decrease the number of steps required for generating high-quality images. In contrast to DiMS, their distillation is applied progressively; DiMS eliminates the need for progressive distillation by updating the teacher with EMA. Lastly, the proposed EMA bears some resemblance to self-supervised learning techniques (Grill et al., 2020; Caron et al., 2021; He et al., 2020) where two models are updated, one through gradient-based optimization and the other through EMA. Despite this similarity, the motivations are quite different: in self-supervised learning, EMA is proposed as a technique to remove large negative sets, whereas here EMA enhances the quality of the labels generated by the teacher.

Discussion
It is not completely clear why knowledge distillation works in general (Zhou et al., 2019; Huang et al., 2022a). When it comes to DiMS, however, we hypothesize that the labels generated by the teacher make the task simpler for the student. In other words, it is difficult for the model to close the gap between its single-step prediction and the ground truth, while distillation with teacher-generated labels reduces this gap. The importance of the gap between the labels and the model's capacity has been observed before (Mirzadeh et al., 2020).

Conclusion
We introduce DiMS, an effective distillation algorithm that enhances the single-step translation quality of a pre-trained iterative model. This is done by replicating the model's multi-step behavior through one decoding pass. The distillation can be repeated to achieve greater gains, but this increases the training time noticeably. We show that the same benefits are obtainable by setting the teacher as a moving average of the student while keeping the training time comparable to one round of the distillation. Experiments over raw and distilled datasets on four translation tasks for supervised and unsupervised variants validate the effectiveness and versatility of DiMS. Potential directions for future work include: i) the same family of iterative models has been applied to automatic speech recognition, so DiMS is applicable to this domain; ii) one can combine a pyramid of techniques introduced for iNATs to obtain a strong iterative model and make it computationally efficient via DiMS; iii) large monolingual sets can be used to distill models with U-DiMS.

Target
The antibodies hunt down any nicotine molecules in the bloodstream , neutralising them before they reached the brain , preventing a smoker from getting a nicotine hit .

Teacher
The antibodies hunt the nicotine molecules molecblood neutralize them before reach brain a smoker not experience high nicotine .

Student
The antibodies hunt the nicotine molecules in the blood and neutralize them before they reach the brain , so a smoker does not experience a nicotine high .

Limitations
While DiMS makes the cross-entropy-based family competitive with alignment-based variants, it still falls behind in some cases. Moreover, DiMS can improve the performance of models trained on raw data, but the best performance is still achieved when DiMS is applied on distilled datasets. Therefore, DiMS still depends on an autoregressive model for the best translation quality.

B EMA Momentum Effect
We showcase the importance of the slow-moving average in Figure 7. As we increase the momentum, training becomes more stable and leads to a better validation-set BLEU score.

C Hyper-parameters for Distillation
We use the same hyper-parameters for all the datasets.

D Ablation on Hidden State Loss Coefficient
The importance of the hidden state loss is shown in Section 4.6.1 of the main body. We conduct an ablation study in this section to find the optimal value of λ that controls the contribution of the hidden state loss.

E Computational Cost
During the distillation we have to run the teacher for two steps, which adds extra computation. Figure 9 compares the overall time for training and distillation on the De-En and En-Ro datasets, showing that the distillation time is an order of magnitude smaller than the training time. Note that the teacher runs in evaluation mode, so its activation maps are not kept in memory. The teacher can therefore be run with a larger batch size, which would further reduce the computational cost. We leave this as future work as it adds implementation complexity.

F Imputer Details
As mentioned in the main body, there is no official implementation of Imputer available online. Here, we explain the differences between our implementation and the original paper. Imputer proposes a pre-training phase where the model is optimized merely with the CTC objective; we find this unnecessary, as the model reaches better or competitive performance without it. Imputer leverages a unified decoder rather than the encoder-decoder architecture incorporated here. For Imputer training, computing the alignment with the highest probability is necessary. This increases the training cost, and Saharia et al. (2020) propose either a pre-processing stage or a stale copy of the active model to manage the extra computation. We compute the best alignment on the fly as it is still computationally feasible. Similar to Imputer inference, extra care is taken to ensure consecutive tokens are not unmasked in the same step. Instead of a Bernoulli masking policy during training, we use a block masking policy. For the distillation, Imputer mainly benefits from two iterative steps and the gains are not as significant after that; therefore, there is no incentive to use EMA.

Target
The rate of 3.1 per cent is indeed better than the previous year and is also better than in September , " however , we had hoped for more , " said Monika Felder -Bauer , acting branch manager of the Employment Agency in Sonthofen .

Teacher
Although the quota was better 3.1 better than last year and better than September , we would hoped more , " Monika -Bauer Deputy of the Labour Agency in Sonthofen .

Student
Although the quota at 3.1 % is better than last year and is also better than in September , " we would have hoped for more , " says Monika Felder -Bauer , deputy head of the Labour Agency in Sonthofen .

CMLM WMT'16 Ro-En
Target we must ask these people to learn the language , try to appropriate our values , to stop having one foot in europe and one in their home country , bringing the rest of the family including through marriages of convenience .
Teacher
let us ask these people to learn their language , try to take values , stop longer stand in europe and with one their country home origin , bringing the rest of family , including through convenience marriages .

Student
let us ask these people to learn the language , try to take over values , no longer stand in europe and with one in their home country , bringing the rest of the family , including through convenience marriages .

CMLMC WMT'14 De-En
Target Edward Snowden , the US intelligence whistleblower , has declared that he is willing to travel to Berlin to give evidence to the German parliament if the US National Security Agency and its director Keith Alexander fail to provide answers about its activities .
Teacher
Edward Snowden , the whistleblower of the US intelligence , has that he is to travel to Berlin and testify the German destag if the National Security Agency its director Keith Alexander not provide answers about their activities .

Student
Edward Snowden , the whistleblower of the US intelligence , has said that he is prepared to travel to Berlin and testify to the German destag if the American National Security Agency and its director Keith Alexander do not provide answers to their activities .

CMLMC WMT'16 Ro-En
Target during the routine control , the border policemen observed in the truck 's cab , a large travel bag , which contained personal things of many people , which is why they conducted a thorough check on the means of transport .
Teacher at specific control , border police officers a large travel of travel the cabvan , where there were things for several people , which is why carried out thorough control over the vehicle of transport .
Student at specific control , border police observed , in the cabin , a large travel getravel , where there were personal things for several people , which is why they carried out thorough control over the vehicle of transport .