Learning to Rewrite for Non-Autoregressive Neural Machine Translation

Non-autoregressive neural machine translation, which removes the dependence on previous target tokens from the inputs of the decoder, has achieved impressive inference speedups, but at the cost of inferior accuracy. Previous works employ iterative decoding to improve the translation by applying multiple refinement iterations. However, a serious drawback of these approaches is their weakness in recognizing erroneous translation pieces. In this paper, we propose an architecture named RewriteNAT that explicitly learns to rewrite erroneous translation pieces. Specifically, RewriteNAT uses a locator module to locate the erroneous pieces, which are then revised into correct ones by a revisor module. To keep the data distribution consistent with iterative decoding, an iterative training strategy is employed to further improve the capacity for rewriting. Extensive experiments on several widely used benchmarks show that RewriteNAT achieves better performance while significantly reducing decoding time compared with previous iterative decoding strategies. In particular, RewriteNAT obtains results competitive with autoregressive translation on the WMT14 En-De, En-Fr and WMT16 Ro-En translation benchmarks.


Introduction
State-of-the-art neural machine translation (NMT) systems use autoregressive decoding, where the decoder generates a target sentence word by word and each word depends on the previously generated ones (Bahdanau et al., 2015; Gehring et al., 2017; Vaswani et al., 2017). Instead of decoding sequentially as in autoregressive translation (AT), non-autoregressive neural machine translation (NAT) (Gu et al., 2018; Guo et al., 2019; Ma et al., 2019; Sun et al., 2019; Ghazvininejad et al., 2020a; Ding et al., 2021a,b) generates the whole target sentence simultaneously. To enable parallel decoding, NAT imposes a conditional independence assumption among the words of the target sentence, yielding significantly faster inference than AT. However, since the intrinsic dependencies within the target sentence are omitted, NAT suffers from a severe inconsistency problem, leading to inferior translation quality, especially when capturing the highly multimodal distribution of target translations (Gu et al., 2018). To tackle this fundamental problem, iterative decoding (Lee et al., 2018; Ghazvininejad et al., 2019; Gu et al., 2019; Guo et al., 2020b; Ghazvininejad et al., 2020b) has been proposed to improve NAT by repeatedly refining the previously generated translation. Instead of forcing NAT to generate an accurate translation in a single pass, these approaches are expected to revise incorrect translation pieces over several refinements (Xia et al., 2017; Geng et al., 2018). With iterative decoding, NAT further boosts translation quality, narrowing the performance gap between NAT and AT models.

Figure 1: Illustration of the difference in masking words between (a) conventional masked LM-based NAT (Ghazvininejad et al., 2019) and (b) our proposed REWRITENAT. Instead of using inefficient heuristic rules that may mask correct words in some cases (e.g., y_1 and y_4), REWRITENAT uses an additional locator module that learns to explicitly identify erroneous translation pieces (e.g., ŷ_2 ŷ_3), which are annotated with a special symbol (i.e., [MASK]).

1 Our code is publicly available at https://github.com/xwgeng/RewriteNAT.
However, existing iterative NAT models are weak at identifying erroneous words. The dominant approach to identifying mistakes is the mask-predict algorithm (Ghazvininejad et al., 2019; Guo et al., 2020b), which employs inefficient heuristic rules to roughly select the least confident words as the erroneous ones. In some cases, mask-predict may mistakenly rewrite correct words while keeping erroneous ones, which then act as noise and hurt subsequent iterations. Moreover, without explicitly classifying translated words as right or wrong, these models decode for a constant number of iterations, which hinders further improvement of inference speed. Besides, the decoder inputs of prevailing iterative NAT models (Kasai et al., 2020; Guo et al., 2020b) mostly come from the ground truth during training, whereas target sentences generated at different refinement steps are used as decoder inputs during inference, creating a discrepancy that can hurt performance.
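The confidence-based heuristic that mask-predict relies on can be sketched in a few lines of pure Python. This is a minimal illustration, assuming the linear-decay schedule of Ghazvininejad et al. (2019), in which n = N·(T−t)/T of the N target tokens are re-masked at iteration t of T; truncating n to an integer is an assumption, and the tokens and probabilities are hypothetical. The example shows the weakness described above: a wrong word predicted with high confidence (the final "children") is never selected for revision.

```python
from typing import List

MASK = "[MASK]"


def num_to_mask(length: int, iteration: int, total_iterations: int) -> int:
    """Linear-decay schedule n = N * (T - t) / T, truncated to an integer."""
    return int(length * (total_iterations - iteration) / total_iterations)


def mask_least_confident(tokens: List[str], probs: List[float], n: int) -> List[str]:
    """Re-mask the n tokens that received the lowest model probability."""
    # Positions ordered from least to most confident.
    order = sorted(range(len(tokens)), key=lambda i: probs[i])
    to_mask = set(order[:n])
    return [MASK if i in to_mask else tok for i, tok in enumerate(tokens)]


# Hypothetical hypothesis: the repeated "children" at the end is wrong,
# but the model is very confident about it, so it is never re-masked.
tokens = ["Regional", "craftsmen", "are", "children", "children"]
probs = [0.90, 0.80, 0.60, 0.30, 0.95]
n = num_to_mask(len(tokens), iteration=3, total_iterations=5)  # 2
print(mask_least_confident(tokens, probs, n))
# ['Regional', 'craftsmen', '[MASK]', '[MASK]', 'children']
```

Because the choice depends only on the model's own probabilities, nothing forces confidently wrong tokens to be revisited; REWRITENAT replaces this rule with a learned locator.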
In this paper, we propose an architecture named REWRITENAT, which explicitly learns to rewrite erroneous translation pieces. Specifically, we introduce a locator module to locate incorrect words within the previously generated translation. The located words are masked out and revised by the revisor module in the subsequent refinement. We frame learning to rewrite, which comprises two steps, locate and revise, as an iterative training procedure in which both operations are supervised by comparing the generated translation with the ground truth. This iterative training keeps training consistent with iterative decoding and further improves the training procedure. Experimental results on several typical machine translation datasets demonstrate that REWRITENAT achieves consistent improvements over iterative decoding baselines with substantially less decoding time. Further analysis shows that REWRITENAT prefers to generate "easy" words in early decoding iterations and leaves more complicated choices for later.

Autoregressive Machine Translation
Autoregressive neural machine translation (AT) draws much attention due to its convenience and effectiveness on various machine translation tasks (Sutskever et al., 2014; Cho et al., 2014; Bahdanau et al., 2015). Given a source sentence $X = \{x_1, \dots, x_{T'}\}$ and a target sentence $Y = \{y_1, \dots, y_T\}$, AT decomposes the translation distribution $p_{\mathrm{AT}}(Y \mid X)$ into a chain of conditional probabilities in a unidirectional manner:

$$p_{\mathrm{AT}}(Y \mid X) = \prod_{t=1}^{T} p(y_t \mid y_{<t}, X) \tag{1}$$

where $y_{<t}$ represents the set of tokens generated before time step $t$, and $T'$ and $T$ are the lengths of the source and target sequences, respectively. Since AT generates the translation autoregressively, it suffers from slow inference.

Non-Autoregressive Machine Translation
To alleviate this issue, NAT (Gu et al., 2018) removes the sequential dependencies within the target sentence and generates all target words simultaneously. NAT models the conditional probability $p_{\mathrm{NAT}}(Y \mid X)$ of translating $X$ into $Y$ as a product of conditionally independent per-step distributions:

$$p_{\mathrm{NAT}}(Y \mid X) = \prod_{t=1}^{T} p(y_t \mid X) \tag{2}$$

Since each target word $y_t$ depends only on the source sentence $X$, the distributions $p(y_t \mid X)$ can be computed in parallel at inference time. Nevertheless, this desirable property of parallel decoding comes at the cost of largely sacrificed translation quality. Because the intrinsic dependencies within the target sentence ($y_t$ depends on $y_{<t}$) are removed from the decoder input, NAT is weak at exploiting the inherent sentence structure for prediction. Hence, NAT has to infer such target-side information by itself, conditioned merely on source-side information. In contrast, AT produces the current target word conditioned on the previously generated words, which provide strong target-side context. Consequently, with less and weaker information, NAT suffers from inferior translation quality.
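The difference between the two factorizations can be seen on a toy example. The pure-Python sketch below uses a hypothetical joint distribution over three valid two-word translations (the distribution and all function names are illustrative assumptions, not part of REWRITENAT). Picking each position independently from its marginal, as NAT does, mixes two references into "thank thanks", whereas conditioning on the prefix, as AT does, yields the consistent "thank you"; this is the inconsistency problem in miniature.

```python
from collections import defaultdict
from typing import Dict, List, Tuple

# Hypothetical joint distribution p(Y | X) over three valid 2-word translations.
JOINT: Dict[Tuple[str, ...], float] = {
    ("thank", "you"): 0.4,
    ("many", "thanks"): 0.3,
    ("big", "thanks"): 0.3,
}


def word_dist(prefix: Tuple[str, ...], position: int) -> Dict[str, float]:
    """Unnormalized distribution of the word at `position` given `prefix`.

    With an empty prefix this is the marginal p(y_t | X); normalizing would
    not change the argmax, so it is skipped.
    """
    dist: Dict[str, float] = defaultdict(float)
    for seq, p in JOINT.items():
        if seq[:len(prefix)] == prefix:
            dist[seq[position]] += p
    return dist


def argmax(dist: Dict[str, float]) -> str:
    return max(dist, key=lambda w: dist[w])


def decode_nat(length: int = 2) -> List[str]:
    # NAT: every position uses p(y_t | X) only, so the choices are
    # independent and could be made in parallel.
    return [argmax(word_dist((), t)) for t in range(length)]


def decode_at(length: int = 2) -> List[str]:
    # AT: step t conditions on the prefix y_<t, so decoding is sequential.
    prefix: Tuple[str, ...] = ()
    for t in range(length):
        prefix += (argmax(word_dist(prefix, t)),)
    return list(prefix)


print(decode_nat())  # ['thank', 'thanks']  (mixes two references)
print(decode_at())   # ['thank', 'you']
```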

Architecture
As depicted in Figure 2, our proposed REWRITENAT consists of three major components: an encoder, a revisor and a locator. The encoder is a Transformer encoder comprising N_e Transformer blocks (Vaswani et al., 2017) that converts the source sentence into contextual representations, similar to previous work (Gu et al., 2018). The revisor and locator, which together form the decoder, are employed to revise and locate the incorrect words within the previously generated translation, respectively. We elaborate on the revisor and locator below.

Figure 2: Architecture of our proposed REWRITENAT model, which consists of three major components: an encoder, a revisor and a locator. The encoder converts the source sentence into contextual representations. During decoding, the revisor converts the erroneous words annotated as "[MASK]" into correct ones, while the locator identifies the incorrect words within the previously generated hypothesis by classifying each word into one of two classes: revise and keep. M refinements, each of which applies a revisor followed by a locator to the previously located hypothesis, are performed to obtain the final translation. We take an English→German example whose source sentence is "Thank you .". REWRITENAT applies two refinements to the initial hypothesis, which consists only of "[MASK]". Decoding then terminates because the locator classifies the entire sequence as keep, meaning that no word needs to be revised.

Revisor
Given the translation $Y^r$ altered by the locator, the revisor converts the erroneous pieces into correct ones, conditioned on the source sentence. In particular, it is expected to predict the correct words at positions annotated as "[MASK]", given the context of the remaining translation. Notably, the revisor takes an input consisting merely of "[MASK]" tokens as its initial input, meaning that the whole input has to be revised. Given the hypothesis $Y^r = \{y^r_1, \dots, y^r_T\}$, we leverage a stack of Transformer blocks (Vaswani et al., 2017; Gu et al., 2018) to generate the corresponding representations $H^r = \{h^r_1, \dots, h^r_T\}$, attending to the source representations $H^e$:

$$H^r = \mathrm{TransformerStack}_r(Y^r, H^e) \tag{3}$$

where $\mathrm{TransformerStack}_r(\cdot)$ represents the stack of $N_r$ Transformer blocks of the revisor. Subsequently, the representations in $H^r$ at positions holding the special symbol "[MASK]" are fed to a classifier $\pi_r$ to generate the target words:

$$\pi_r(y_t \mid Y^r, X) = \mathrm{softmax}(W_r h^r_t + b_r) \tag{4}$$

where the weight matrix $W_r$ and the bias vector $b_r$ are trainable parameters. The words generated by $\pi_r$ substitute the incorrect words annotated as "[MASK]", yielding the revised translation $Y^l = \{y^l_1, \dots, y^l_T\}$:

$$y^l_t = \begin{cases} \arg\max_{y} \pi_r(y \mid Y^r, X) & \text{if } y^r_t = \text{[MASK]} \\ y^r_t & \text{otherwise} \end{cases} \tag{5}$$

where $Y^l$ is fed to the locator.
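The substitution step that turns $Y^r$ into $Y^l$ can be sketched in pure Python. This is a minimal illustration, assuming the revisor's output at each position is available as a word-to-probability dictionary (a hypothetical stand-in for the softmax over the vocabulary) and that the most probable word is chosen; the sentence fragment is taken from the case study, while the probabilities are made up.

```python
from typing import Dict, List

MASK = "[MASK]"


def revise(y_r: List[str], revisor_probs: List[Dict[str, float]]) -> List[str]:
    """Produce Y^l from Y^r: argmax of the revisor at [MASK] positions,
    the input word everywhere else."""
    if len(y_r) != len(revisor_probs):
        raise ValueError("one distribution per position is required")
    y_l: List[str] = []
    for token, probs in zip(y_r, revisor_probs):
        if token == MASK:
            y_l.append(max(probs, key=lambda w: probs[w]))
        else:
            y_l.append(token)  # unmasked words are copied unchanged
    return y_l


y_r = ["Regional", "craftsmen", MASK, MASK, "the", "children"]
# Hypothetical revisor outputs; distributions at unmasked positions are ignored.
revisor_probs = [{}, {}, {"will": 0.7, "are": 0.3},
                 {"assist": 0.8, "children": 0.2}, {}, {}]
print(revise(y_r, revisor_probs))
# ['Regional', 'craftsmen', 'will', 'assist', 'the', 'children']
```

In a real model the revisor scores every position in parallel; only the masked positions are read, which is what lets the unmasked context stay fixed.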

Locator
Given the previously generated translation as input, the locator identifies the incorrect words within the entire sequence, conditioned on the source sentence. Each word in the translation is categorized into one of two types: revise (1) or keep (0). Based on the resulting classification, the previous translation is altered into a new form, which is then fed to the revisor: the words labeled "revise" are replaced by a special symbol, denoted "[MASK]", while the remaining words are kept. Given the previously generated translation $Y^l = \{y^l_1, \dots, y^l_T\}$ to be located, a stack of Transformer blocks (Vaswani et al., 2017; Gu et al., 2018) transforms the input translation $Y^l$ into a sequence of hidden states $H^l = \{h^l_1, \dots, h^l_T\}$, conditioned on the source contextual representations $H^e$:

$$H^l = \mathrm{TransformerStack}_l(Y^l, H^e) \tag{6}$$

where $\mathrm{TransformerStack}_l(\cdot)$ represents the stack of $N_l$ Transformer blocks of the locator. Using the hidden states $H^l$ as input, an additional classifier $\pi_l$ decides whether the previously generated word $y^l_t$ at step $t$ needs to be revised:

$$\pi_l(l_t \mid Y^l, X) = \mathrm{softmax}(W_l h^l_t + b_l), \quad l_t \in \{0, 1\} \tag{7}$$

where the weight matrix $W_l$ and the bias vector $b_l$ are trainable parameters. Using the classifier $\pi_l$, the input translation $Y^l$ is converted into an annotation sequence $L = \{l_1, \dots, l_T\}$ with $l_t = \arg\max_{c \in \{0,1\}} \pi_l(c \mid Y^l, X)$. Subsequently, based on the annotation $L$, the translation $Y^l$ is altered into $Y^r = \{y^r_1, \dots, y^r_T\}$:

$$y^r_t = \begin{cases} \text{[MASK]} & \text{if } l_t = 1 \\ y^l_t & \text{otherwise} \end{cases} \tag{8}$$

where $Y^r$ is used as the input of the revisor.
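The locator's classify-then-mask step can be sketched the same way. This minimal pure-Python version assumes the locator's output is given as the probability of class "revise" at each position (hypothetical numbers below); with two classes, taking the argmax is equivalent to testing whether that probability exceeds 0.5.

```python
from typing import List, Tuple

MASK = "[MASK]"
KEEP, REVISE = 0, 1


def locate(y_l: List[str], p_revise: List[float]) -> Tuple[List[int], List[str]]:
    """Label each position keep (0) or revise (1) and mask the revise positions.

    p_revise[t] is the locator's probability of "revise" at position t;
    argmax over {keep, revise} equals p_revise[t] > 0.5 (ties go to keep).
    """
    labels = [REVISE if p > 0.5 else KEEP for p in p_revise]
    y_r = [MASK if label == REVISE else tok for tok, label in zip(y_l, labels)]
    return labels, y_r


y_l = ["Regional", "craftsmen", "are", "children", "children", "children"]
p_revise = [0.10, 0.05, 0.70, 0.90, 0.80, 0.20]  # hypothetical locator output
labels, y_r = locate(y_l, p_revise)
print(labels)  # [0, 0, 1, 1, 1, 0]
print(y_r)     # ['Regional', 'craftsmen', '[MASK]', '[MASK]', '[MASK]', 'children']
```

Unlike mask-predict, the number of masked positions is not fixed by a schedule; it is whatever the locator decides, including zero.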

Training
To keep the data distribution consistent with iterative decoding at inference time, we adopt an iterative training strategy to teach REWRITENAT to rewrite, as described in Algorithm 1. During training, at the m-th refinement, which includes a revise and a locate operation, we compare the previously generated translations (i.e., $\hat{Y}^r_m$ and $\hat{Y}^l_m$) with the ground truth (i.e., $Y$) to identify erroneous translation pieces, and construct two types of supervision signals (i.e., $q(\hat{Y}^r_m)$ and $z(\hat{Y}^l_m)$) to guide the learning of the revisor and locator modules, respectively. With iterative training over $M$ refinements, the training objective $\mathcal{L}(\theta)$ can be formalized as:

$$\mathcal{L}(\theta) = \sum_{m=1}^{M} \left[ \mathcal{L}_r(\hat{Y}^r_m) + \mathcal{L}_l(\hat{Y}^l_m) \right]$$

where the translations $\hat{Y}^r_m$ and $\hat{Y}^l_m$ at the m-th refinement step are generated from the output distributions of the locator $\pi_l(\cdot \mid \hat{Y}^l_{m-1}, X)$ and the revisor $\pi_r(\cdot \mid \hat{Y}^r_m, X)$, respectively. During training, the generated translations $\hat{Y}^r_m$ and $\hat{Y}^l_m$ have the same length as the ground truth $Y$. When calculating the revisor objective, we use $q(\hat{Y}^r)$ as a weight vector so that optimization concentrates only on the incorrect words (annotated as [MASK] in $\hat{Y}^r$) and omits the losses of correctly generated ones:

$$\mathcal{L}_r(\hat{Y}^r) = \sum_{t=1}^{T} q_t(\hat{Y}^r) \log \pi_r(y_t \mid \hat{Y}^r, X), \quad q_t(\hat{Y}^r) = \begin{cases} 1 & \text{if } \hat{y}^r_t = \text{[MASK]} \\ 0 & \text{otherwise} \end{cases}$$

The locator target $z(\hat{Y}^l)$ is a vector in which the positions where the translation $\hat{Y}^l$ differs from the ground truth $Y$ are labeled revise (1), while the remaining positions are labeled keep (0):

$$\mathcal{L}_l(\hat{Y}^l) = \sum_{t=1}^{T} \log \pi_l\big(z_t(\hat{Y}^l) \mid \hat{Y}^l, X\big), \quad z_t(\hat{Y}^l) = \begin{cases} 1 & \text{if } \hat{y}^l_t \neq y_t \\ 0 & \text{otherwise} \end{cases}$$

Algorithm 1 Iterative Training of REWRITENAT
1: Input: Parallel training dataset (X, Y), revisor module π^r_θ, locator module π^l_θ, maximum refinement steps M, learning rate γ
2: repeat
3: Generate Ŷ^l_m using π_r(·|Ŷ^r_m, X) as Eq. 5
7: Generate Ŷ^r_m using π_l(·|Ŷ^l_m, X) as Eq. 8
9: Update model parameters θ ← θ + γ∇_θ L
13: until convergence
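The two supervision signals and the per-refinement objective terms can be sketched in pure Python. This is a minimal illustration under stated assumptions: model outputs are hypothetical dictionaries and probabilities rather than real network outputs, and the terms are written as log-likelihoods to be maximized, matching the gradient-ascent update θ ← θ + γ∇θL in Algorithm 1.

```python
import math
from typing import Dict, List

MASK = "[MASK]"


def revisor_weights(y_r_hat: List[str]) -> List[int]:
    """q_t = 1 where the revisor input holds [MASK], 0 elsewhere."""
    return [1 if tok == MASK else 0 for tok in y_r_hat]


def locator_targets(y_l_hat: List[str], y: List[str]) -> List[int]:
    """z_t = 1 (revise) where the hypothesis differs from the reference, else 0 (keep)."""
    if len(y_l_hat) != len(y):
        raise ValueError("training hypotheses have the reference length")
    return [1 if h != ref else 0 for h, ref in zip(y_l_hat, y)]


def revisor_log_likelihood(q: List[int], revisor_probs: List[Dict[str, float]],
                           y: List[str]) -> float:
    # Positions with q_t = 0 are skipped, i.e. weighted by zero.
    return sum(math.log(probs[ref])
               for q_t, probs, ref in zip(q, revisor_probs, y) if q_t == 1)


def locator_log_likelihood(z: List[int], p_revise: List[float]) -> float:
    # Log-probability of the target class at every position.
    return sum(math.log(p if z_t == 1 else 1.0 - p)
               for z_t, p in zip(z, p_revise))


y = ["will", "assist", "the", "children"]          # reference
y_r_hat = [MASK, "assist", "the", "children"]      # revisor input
y_l_hat = ["are", "assist", "the", "children"]     # revisor output, locator input
print(revisor_weights(y_r_hat))      # [1, 0, 0, 0]
print(locator_targets(y_l_hat, y))   # [1, 0, 0, 0]
```

Because the targets are recomputed from the model's own hypotheses at every refinement, the locator is trained on the same kind of imperfect inputs it sees at inference time.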

Inference
During training, REWRITENAT generates translations of the same length as the ground truth, whereas at inference we apply REWRITENAT to a sequence of "[MASK]" tokens whose length is predicted by a length classifier (Lee et al., 2018). Decoding stops when the locator module classifies the entire sentence as keep, or when the classifications of two consecutive refinements are identical (a.k.a. dynamic halting).

Following Devlin et al. (2019), we sample weights from N(0, 0.02), set biases to zero, and set the layer normalization parameters to β = 0 and γ = 1. For regularization, we use dropout (En↔De and En↔Ro: 0.3, En→Fr: 0.1, En→Zh: 0.25), 0.01 L2 weight decay, and a label-smoothed cross-entropy loss with λ = 0.1. We adopt the Adam optimizer (Kingma and Ba, 2015) with β1 = 0.9, β2 = 0.98 and ε = 1e−8. The learning rate is scheduled using inverse_sqrt with a maximum learning rate of 0.0005 and 10,000 warmup steps, except for TRANSFORMER, which uses 4,000 warmup steps. All models are trained on 8 Tesla V100 GPUs for 300,000 updates with an effective batch size of 128,000 tokens, except for En→Fr, where we make 500,000 updates to account for the data size. During decoding, we use a beam size of b = 5 for autoregressive decoding, while length beam (Ghazvininejad et al., 2019) is applied for the non-autoregressive models.

Table 3: Average number of iterations ("Iters.") and performance ("BLEU") of REWRITENAT on the large-scale WMT17 En→Zh and WMT14 En→Fr datasets.
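The inference procedure with dynamic halting described in the Inference section can be sketched as a short loop. This is a minimal pure-Python illustration, assuming hypothetical callables `revisor` (mapping $Y^r$ to $Y^l$) and `locator` (mapping $Y^l$ to keep/revise labels) that already condition on the source; the predicted length and the iteration cap are passed in, and counting one revisor plus one locator pass as one iteration is an assumption.

```python
from typing import Callable, List, Optional, Tuple

MASK = "[MASK]"
# Hypothetical interfaces: revisor maps Y^r -> Y^l, locator maps Y^l -> labels
# (0 = keep, 1 = revise). Both are assumed to condition on the source X.
Revisor = Callable[[List[str]], List[str]]
Locator = Callable[[List[str]], List[int]]


def rewrite_decode(revisor: Revisor, locator: Locator,
                   length: int, max_iters: int) -> Tuple[List[str], int]:
    """Iterative revise-and-locate decoding with dynamic halting.

    Returns the final translation and the number of iterations used.
    """
    if max_iters < 1:
        raise ValueError("max_iters must be at least 1")
    y_r = [MASK] * length                # initial hypothesis: all [MASK]
    prev_labels: Optional[List[int]] = None
    for it in range(1, max_iters + 1):
        y_l = revisor(y_r)               # fill the masked positions
        labels = locator(y_l)            # decide what to revise next
        # Dynamic halting: everything is "keep", or the labels repeat.
        if not any(labels) or labels == prev_labels:
            return y_l, it
        y_r = [MASK if lab else tok for tok, lab in zip(y_l, labels)]
        prev_labels = labels
    return y_l, max_iters
```

As a usage example, scripted stand-ins (not real models) that replay a trajectory similar to the case study stop after three iterations even though up to ten are allowed:

```python
outputs = iter([["craftsmen", "are", "children", "children"],
                ["craftsmen", "are", "assist", "children"],
                ["craftsmen", "will", "assist", "children"]])
labels = iter([[0, 0, 1, 1], [0, 1, 0, 0], [0, 0, 0, 0]])
print(rewrite_decode(lambda y_r: next(outputs), lambda y_l: next(labels), 4, 10))
# (['craftsmen', 'will', 'assist', 'children'], 3)
```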

Decoding Speed
As shown above, REWRITENAT obtains substantial improvements over strong iterative NAT baselines while reducing the number of iterations. Here we compare the models in terms of speedup relative to TRANSFORMER, as depicted in Figure 3. REWRITENAT clearly achieves the same performance as the iterative NAT baselines with a substantially higher speedup. When the maximum number of iterations is set to 2 (i.e., T = 2), REWRITENAT obtains a result competitive with CMLM and TRANSFORMER with b = 1 (i.e., 27.03 vs. 27.05), but with a higher speedup (i.e., 7.02×). The performance of REWRITENAT benefits considerably from increasing T up to T = 4. In particular, REWRITENAT with T = 4 achieves a result comparable to TRANSFORMER with b = 5 (27.77 vs. 27.82) at a 3.86× speedup. Furthermore, REWRITENAT with T = 4 outperforms the strongest baseline, SMART (27.77 vs. 27.56), while using about half of its decoding time. Beyond T = 4, the performance gain is marginal, while the speedup decreases only slightly owing to dynamic halting.

Word Repetitions
Because NAT removes the sequential dependencies within the target sentence, it is seriously weak at modeling highly multimodal distributions (Gu et al., 2018), which often manifests as word repetitions in the generated translations. To evaluate multimodality, we follow Ghazvininejad et al. (2019) and measure the percentage of consecutive repeated words as a proxy metric. As shown in Table 4, the proportion of repeated words produced by REWRITENAT is significantly lower than that of the most closely related baseline, CMLM, especially when decoding with a single iteration (-6.05%). At the same time, REWRITENAT achieves substantially higher BLEU than CMLM. These results demonstrate the superiority of REWRITENAT over CMLM in alleviating word repetitions.

Table 4: The performance ("BLEU") and percentage of repeated words ("Reps") when decoding with different numbers of iterations on the WMT14 En→De test set. For REWRITENAT, T denotes the maximum number of iterations allowed during decoding.
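The repetition proxy can be computed with a few lines of pure Python. This sketch counts a token as repeated when it equals the token immediately before it; using the total number of tokens as the denominator and operating on already-tokenized (BPE-merged) output are assumptions, since the exact counting convention of the evaluation script is not given here. The sentences are hypothetical fragments modeled on the case study.

```python
from typing import List


def repetition_rate(sentences: List[List[str]]) -> float:
    """Percentage of tokens identical to the token immediately before them."""
    total = sum(len(s) for s in sentences)
    if total == 0:
        return 0.0
    repeats = sum(
        1 for s in sentences for prev, cur in zip(s, s[1:]) if prev == cur
    )
    return 100.0 * repeats / total


corpus = [
    ["Regional", "craftsmen", "are", "children", "children", "children"],
    ["to", "5", "p.m", "."],
]
print(repetition_rate(corpus))  # 20.0  (2 repeats out of 10 tokens)
```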

Effect of Weight Sharing
To evaluate the effect of sharing weights between the revisor and locator modules, we conduct further experiments. As shown in Table 5, REWRITENAT with shared parameters (i.e., + w/ sharing) shows a slight decrease in performance (27.54 vs. 27.83) on the WMT14 En→De translation task, but still surpasses the most closely related baseline (i.e., CMLM). Moreover, REWRITENAT with weight sharing consumes fewer decoding iterations, leading to slightly faster inference.

Iterations vs. Length
As described above, REWRITENAT requires significantly fewer decoding iterations than previous iterative NAT models. To explore the impact of sentence length, we compare the number of required iterations with the length of the target sentences, as illustrated in Figure 4. REWRITENAT clearly adapts the number of iterations to the sentence: in general, as target sentences grow longer, REWRITENAT requires more iterations to produce the translation.

Analysis on Part-of-Speech
Having established the effectiveness of REWRITENAT, we investigate whether the number of iterations shows any preference towards different parts of speech 6. For each part of speech 7, we calculate the average percentage of iterations required to produce words of that part of speech. As shown in Figure 5, REWRITENAT tends to generate punctuation (i.e., PUNC) early in decoding. Nouns are the next easiest to predict. Conditioned on the generated nouns, other parts of speech (e.g., CONJ, ADJ, DET, ADV, PREP), which often act as modifiers, tend to appear next in the generated translation. Finally, the most difficult words for REWRITENAT to generate are verbs (i.e., VERB) and particles (i.e., PRT). These observations are consistent with the easy-first generation hypothesis: early decoding iterations mostly generate the words that are easiest to predict based on the input data (Emelianenko et al., 2019).
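One way to compute the per-tag statistic is sketched below in pure Python. It assumes, as a labeled interpretation rather than the authors' exact definition, that each word is assigned the iteration in which it was last rewritten, expressed as a percentage of the total iterations used for its sentence, and that these percentages are averaged per POS tag; the input records are hypothetical.

```python
from collections import defaultdict
from typing import Dict, List, Tuple

# One record per sentence (hypothetical format): POS tag of each word,
# iteration in which the word was last rewritten, total iterations used.
Sentence = Tuple[List[str], List[int], int]


def avg_iteration_percentage(sentences: List[Sentence]) -> Dict[str, float]:
    """Average (iteration / total iterations) * 100 for each POS tag."""
    sums: Dict[str, float] = defaultdict(float)
    counts: Dict[str, int] = defaultdict(int)
    for tags, iters, total in sentences:
        for tag, it in zip(tags, iters):
            sums[tag] += 100.0 * it / total
            counts[tag] += 1
    return {tag: sums[tag] / counts[tag] for tag in sums}


data = [
    (["NOUN", "VERB", "PUNC"], [1, 3, 1], 3),
    (["NOUN", "PUNC"], [2, 1], 2),
]
print(avg_iteration_percentage(data))
# VERB: 100.0, NOUN: about 66.7, PUNC: about 41.7
```

Normalizing by each sentence's own iteration count keeps long sentences, which need more iterations, from dominating the comparison across tags.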

Case Study
As illustrated in Figure 6, we present a translation example to compare REWRITENAT with CMLM. The maximum number of decoding iterations is set to 10. REWRITENAT generates a reasonable translation within 3 decoding iterations and then terminates decoding automatically thanks to the locator module. In addition, the erroneous translation pieces (e.g., "are children children") are accurately identified. In contrast, the strong CMLM baseline struggles to handle the incorrect ones. Consequently, CMLM generally spends more decoding iterations than REWRITENAT, yet achieves inferior performance. These results confirm the effectiveness and efficiency of the proposed REWRITENAT.

6 The STANFORD CORENLP TOOLKIT (Manning et al., 2014) is used to annotate the translation output with parts of speech.
7 PUNC-punctuation, NOUN-noun, PRT-particle, DET-determiner, CONJ-conjunction, ADJ-adjective, ADV-adverb, PREP-preposition, VERB-verb.
Related Work
Gu et al. (2018) first proposed NAT to generate the translation in parallel, boosting inference speed. To mitigate the performance degradation, a series of works were proposed to strengthen the capacity for capturing dependencies among output words, including adding a lite autoregressive module (Kaiser et al., 2018), training with well-designed objectives (Guo et al., 2019; Libovický and Helcl, 2018; Shao et al., 2020; Ghazvininejad et al., 2020a; Du et al., 2021), modeling with latent variables (Ma et al., 2019) and mimicking the hidden states of an autoregressive teacher. Despite these improvements, decoding inconsistency can still be observed in the translation. To eliminate such errors, iterative decoding (Xia et al., 2017; Geng et al., 2018) was proposed to polish the previously generated translation over multiple iterations. As an early alternative, Lee et al. (2018) corrected the original non-autoregressive output by passing it multiple times through a denoising autoencoder. Instead of generating in the discrete space of sentences, continuous latent variables were utilized to improve iterative refinement. Subsequently, Ghazvininejad et al. (2019) introduced mask-predict, which first generates target words non-autoregressively and then repeatedly masks out and regenerates the subset of words that the model is least confident about (Ghazvininejad et al., 2020b; Guo et al., 2020b). However, previous iterative NAT approaches share a fundamental weakness in identifying erroneous translation pieces. Specifically, iterative NAT models based on mask-predict use heuristic rules that treat the least confident words as the ones to be revised, but relying only on the probability distribution of the generated translation makes it difficult to classify words correctly. Although LevT (Gu et al., 2019) alleviates the issue to some extent by adopting two basic operations (i.e., insertion and deletion), a serious discrepancy in the input data distribution between training and decoding remains, because the iterative strategy is used in decoding but not in training. To address these issues, REWRITENAT adopts an additional locator module specialized in identifying erroneous translation pieces, and uses an iterative training strategy to keep the data distribution consistent with iterative decoding.

SOURCE: Den Kindern stehen regionale Handwerker von 11 bis 17 Uhr helfend zur Seite .
CMLM
1∼8: Regional craftsmen are at their children from 11 a.m. to 5 p.m .
9: Regional craftsmen are assist their children from 11 a.m. to 5 p.m .
10: Regional craftsmen are helping the children from 11 a.m. to 5 p.m .
REWRITENAT
1: Regional craftsmen are children children children from 11 a.m. to 5 p.m .
2: Regional craftsmen are assist the children from 11 a.m. to 5 p.m .
3: Regional craftsmen will assist the children from 11 a.m. to 5 p.m .
Figure 6: An example from WMT14 De→En translation that illustrates how REWRITENAT and CMLM generate text with iterative decoding. The translation pieces to be revised in the next iteration are marked with strikethrough, and the erroneous ones within the final translation are underlined. Notably, we remove the BPE separators from the generated translations, which can lead to unreasonable words (e.g., cra@@ fts@@ from → craftsfrom).

Conclusion
In this work, we propose an architecture named REWRITENAT, which explicitly learns to rewrite erroneous translation pieces, and we train it with an iterative training strategy. Extensive experimental results demonstrate that REWRITENAT achieves remarkable improvements over previous iterative NAT models with significantly fewer decoding iterations. Further analysis reveals that the generation order of REWRITENAT, measured by the percentage of decoding iterations, is consistent with the easy-first hypothesis.