Cross Attention Augmented Transducer Networks for Simultaneous Translation

This paper proposes a novel architecture, Cross Attention Augmented Transducer (CAAT), for simultaneous translation. The framework aims to jointly optimize the policy and translation models. To effectively consider all possible READ-WRITE simultaneous translation action paths, we adapt the online automatic speech recognition (ASR) model, RNN-T, but remove the strong monotonic constraint, which is critical for the translation task to consider reordering. To make CAAT work, we introduce a novel latency loss whose expectation can be optimized by a forward-backward algorithm. We implement CAAT with Transformer while the general CAAT architecture can also be implemented with other attention-based encoder-decoder frameworks. Experiments on both speech-to-text (S2T) and text-to-text (T2T) simultaneous translation tasks show that CAAT achieves significantly better latency-quality trade-offs compared to the state-of-the-art simultaneous translation approaches.


Introduction
Simultaneous translation, which starts to translate input sentences before they are finished, is of importance to many real-life applications such as teleconference systems and time-sensitive spoken document analysis and conversion. While substantial progress has been made on offline machine translation (Wu et al., 2016; Vaswani et al., 2017; Hassan et al., 2018), more research on simultaneous translation is still highly desirable. Central to the task is performing high-quality low-latency translation, which involves the key challenges of developing optimal policies for the READ-WRITE action paths as well as generating high-quality target sequences based only on partial source sequences.
This paper aims to optimize the policy and translation model jointly, by expanding target sequences with blank symbols for READ actions. The loss function can be defined as the negative log-likelihood (NLL) of the marginal distribution over all expanded paths. A similar problem in automatic speech recognition (ASR) has been tackled by RNN-T (Recurrent Neural Network Transducer) (Graves, 2012) with an efficient forward-backward algorithm. However, RNN-T is trained on a monotonic alignment between source and target sequences, which is not suitable for simultaneous translation, as it cannot properly handle reordering. On the other hand, the forward-backward algorithm is not available for attention-based encoder-decoder (Bahdanau et al., 2015) architectures, including Transformer (Vaswani et al., 2017), due to the deep coupling between source contexts and target history contexts.
To solve this problem, we separate the cross attention mechanism from the target history representation in the attention-based encoder-decoder. The resulting model can also be viewed as RNN-T with the joiner augmented by a cross attention mechanism, and we call it the Cross Attention Augmented Transducer (CAAT). However, the cross attention mechanism removes the alignment constraint of RNN-T, which originally encouraged an appropriate latency. To keep latency under control, a latency loss must be jointly minimized. Both the NLL loss and the latency loss can be efficiently optimized by a forward-backward algorithm.
The main contributions of this paper are threefold: (1) We propose a novel architecture, Cross Attention Augmented Transducer, which jointly optimizes the policy and translation model by considering all possible READ-WRITE simultaneous translation action paths. (2) We introduce a novel latency loss whose expectation can be optimized by a forward-backward algorithm. Training with this latency loss ensures that the latency of the CAAT simultaneous translation model stays under control. (3) The proposed model achieves significantly better latency-quality trade-offs compared to the state-of-the-art simultaneous translation approaches.

Related Work
Recent work on simultaneous translation falls into two categories. The first category uses a fixed policy for the READ/WRITE actions. Cho and Esipova (2016) propose simultaneous translation with the wait-if-* policy for an offline model. Ma et al. (2019) propose a wait-k policy for both training and inference. The second category includes models with a flexible policy that is learned and/or adaptive to the current context. Gu et al. (2017) introduce an agent trained by reinforcement learning from interaction with a pre-trained offline neural machine translation model. Zheng et al. (2019a) train the agent by supervised learning with label sequences generated via the rank of golden target words given partial input. A special subcategory of flexible policies jointly optimizes the policy and translation model via monotonic attention customized to the translation model, e.g., Monotonic Infinite Lookback (MILk) attention (Arivazhagan et al., 2019) on the RNN encoder-decoder (Bahdanau et al., 2015) and Monotonic Multihead Attention (MMA) (Ma et al., 2020c) on Transformer (Vaswani et al., 2017).
End-to-end speech-to-text (S2T) simultaneous translation has been investigated in (Ma et al., 2020b,d; Ren et al., 2020), among which Ma et al. (2020b) adapt latency metrics from T2T simultaneous translation to S2T simultaneous translation, and experiment with both fixed and flexible policies. Ma et al. (2020d) study the effect of speech block processing on S2T simultaneous translation. Ren et al. (2020) experiment with the wait-k policy based on a source language CTC segmenter.
In our work, we optimize the marginal distribution of all expanded paths motivated by RNN-T (Graves, 2012). Unlike RNN-T, the CAAT model removes the monotonic constraint, which is critical for considering reordering in machine translation tasks. The optimization of our latency loss is motivated by Sequence Discriminative Training in ASR (Povey, 2005).

Notations and formulation
Let $x$ and $y$ denote the source and target sequence, and $f$ and $g$ the encoder and decoder function, respectively. For simultaneous translation, let $a_j$ denote the length of the source sequence processed when deciding the target token $y_j$. The policy of simultaneous translation is denoted as an action sequence $p \in \{R, W\}^{|x|+|y|}$, where $R$ denotes the READ action and $W$ the WRITE action. If each READ action is replaced with a blank symbol $\varnothing$, the policy can also be represented by the expanded target sequence $\hat{y} \in (V \cup \{\varnothing\})^{|x|+|y|}$, where $V$ is the vocabulary of the target language. Note that removing all $\varnothing$ in $\hat{y}$ recovers the original target sequence $y$. The set of all possible expansions of $y$ is denoted as $H(x, y)$.
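As a concrete illustration, the correspondence between action sequences and blank-expanded target sequences can be sketched in a few lines of Python (the function names and the blank token string are ours, for illustration only):

```python
BLANK = "<b>"  # stands in for the blank symbol used for READ actions

def actions_to_expanded(actions, y):
    """Convert a READ/WRITE action sequence into the blank-expanded
    target sequence: each READ becomes a blank, each WRITE emits the
    next target token."""
    expanded, j = [], 0
    for a in actions:
        if a == "R":
            expanded.append(BLANK)
        else:  # "W"
            expanded.append(y[j])
            j += 1
    assert j == len(y), "a valid policy must WRITE every target token"
    return expanded

def collapse(expanded):
    """Remove all blanks, recovering the original target sequence y."""
    return [t for t in expanded if t != BLANK]
```

For a source of length 3 and target of length 2, any valid path has exactly $|x| + |y| = 5$ actions, and collapsing its expansion recovers $y$.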
As shown in Figure 1(b), to calculate $P(\hat{y}_k | h_{i_k}, y_{<j_k})$, RNN-T divides the decoder into a predictor and a joiner, where the predictor, denoted $f_{pred}$, produces the target history representation (Eq. (2)), and the joiner produces the output probability $\Pr(y|i,j)$ from the joint representations of the predictor and encoder (Eq. (3)).
Though named RNN Transducer, other sequence processing architectures also work well as the encoder or predictor, e.g., Transformer (Zhang et al., 2020; Yeh et al., 2019). Online decoding is natural for RNN-T as long as the encoder works on streaming input, which makes RNN-T widely adopted in both online and offline ASR tasks.
One drawback of RNN-T is that it is based on a monotonic alignment between the input and output sequences, making it unsuitable for sequence-to-sequence tasks with reordering, e.g., machine translation. The goal of simultaneous translation is to achieve high translation quality and low latency. A natural loss function hence combines the NLL loss of the marginal conditional distribution and the expectation of a latency metric over all possible expanded paths:

$$L_{NLL}(x, y) = -\log \sum_{\hat{y} \in H(x,y)} \Pr(\hat{y}|x)$$

$$L_{latency}(x, y) = \sum_{\hat{y} \in H(x,y)} \Pr(\hat{y}|y, x)\, l(\hat{y})$$

where $\Pr(\hat{y}|y, x) = \Pr(\hat{y}|x) / \sum_{\hat{y}' \in H(x,y)} \Pr(\hat{y}'|x)$ is the posterior of an expanded path $\hat{y} \in H(x, y)$ of the target sequence $y$, and $l(\hat{y})$ is the latency loss for path $\hat{y}$.
As the total number of expanded paths is exponential in $|x| + |y|$, computing the marginal probability $\sum_{\hat{y} \in H(x,y)} \Pr(\hat{y}|x)$ is non-trivial. RNN-T solves this with a forward-backward algorithm (Graves, 2012), which inherently requires paths in the lattice to be mergeable; that is, the representations at the same location on different paths must be identical. Conventional attention-based encoder-decoder architectures, as in Figure 1(a), do not satisfy this requirement. Take Figure 2 as an example, where we denote by $s^n_i$ the representation of the $i$-th decoder step on the expanded path $\hat{y}^n$: the decoder states at output step 2 with different history paths, $s^1_2$ for the red path $\hat{y}^1$ and $s^2_2$ for the blue path $\hat{y}^2$, are not identical. This is due to the coupling of the source and previous target representations by the attention mechanism in the decoder. The same problem exists in Transformer, from the coupling of self-attention and encoder-decoder cross attention in each block.
To solve this, we separate the cross attention mechanism from the target history representation, which is similar to the joiner and predictor in RNN-T. The novel architecture, shown in Figure 1(c), can be viewed as an extended version of RNN-T with the joiner augmented by a cross attention mechanism, and is named Cross Attention Augmented Transducer (CAAT). Different from RNN-T, the joiner in CAAT is a complex architecture with attention mechanisms as in Eq. (6). Note that $s_{i,j}$ is independent of previous nodes $s_{i',j'}$ on the path $\hat{y}$, so the same location reached from different paths in Figure 2 produces the same state representation. By analyzing the diffusion of the output probability through the lattice in Figure 2, we can find that $\Pr(y|x)$ is equal to the sum of probabilities over the nodes of any top-right to bottom-left diagonal. Defining the forward variable $\alpha(i, j)$ as the probability of outputting $y_{[1:j]}$ during $x_{[1:i]}$, and the backward variable $\beta(i, j)$ as the probability of outputting $y_{[j+1:|y|]}$ during $x_{[i:|x|]}$, we obtain the marginal likelihood

$$\Pr(y|x) = \sum_{(i,j):\, i+j=m} \alpha(i, j)\, \beta(i, j),$$

where $1 \le m \le |x| + |y|$. The detailed derivation of the NLL loss of CAAT can be found in Appendix A.1.
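To make the lattice computation concrete, here is a small sketch (ours, not the paper's implementation) of the transducer-style forward-backward recursion over per-node probabilities, with toy probability grids standing in for the joiner's softmax outputs. It also exposes the diagonal identity stated above:

```python
def forward_backward(p_tok, p_blank, I, J):
    """Transducer lattice forward-backward.
    p_tok[i][j]   = Pr(y_{j+1} | i, j): probability of WRITING the next token
    p_blank[i][j] = Pr(blank   | i, j): probability of a READ action
    Grids are 1-indexed in i (1..I) and 0-indexed in j (0..J).
    Returns (alpha, beta, Pr(y|x))."""
    alpha = [[0.0] * (J + 1) for _ in range(I + 1)]
    beta = [[0.0] * (J + 1) for _ in range(I + 1)]
    alpha[1][0] = 1.0  # start: one decision step read, nothing written
    for i in range(1, I + 1):
        for j in range(0, J + 1):
            if i > 1:   # arrive via a READ from (i-1, j)
                alpha[i][j] += alpha[i - 1][j] * p_blank[i - 1][j]
            if j > 0:   # arrive via a WRITE from (i, j-1)
                alpha[i][j] += alpha[i][j - 1] * p_tok[i][j - 1]
    beta[I][J] = p_blank[I][J]  # terminal blank closes every path
    for i in range(I, 0, -1):
        for j in range(J, -1, -1):
            if i == I and j == J:
                continue
            b = 0.0
            if i < I:
                b += beta[i + 1][j] * p_blank[i][j]
            if j < J:
                b += beta[i][j + 1] * p_tok[i][j]
            beta[i][j] = b
    return alpha, beta, alpha[I][J] * p_blank[I][J]
```

On any top-right to bottom-left diagonal, summing $\alpha(i,j)\beta(i,j)$ reproduces the same $\Pr(y|x)$, which is what the NLL and latency losses exploit.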
The proposed CAAT can be implemented with a variety of attention-based encoder-decoder frameworks. In this paper, we implement CAAT with Transformer, by dividing the Transformer decoder into the predictor and joiner modules. As shown in Figure 3, the predictor and joiner together have the same number of Transformer blocks as a conventional Transformer decoder, but there are no cross-attention blocks in the predictor module and no self-attention blocks in the joiner.

Multi-Step Decision
The CAAT architecture gains the ability to handle source-target reordering at the cost of an expensive joiner, whose complexity during training is $O(|x| \cdot |y|)$. For RNN-T, the joiner is cheap because only a softmax operates at $O(|x| \cdot |y|)$. For CAAT, however, the joiner takes up half of the parameters of the decoder, which means the training complexity of CAAT is about $|x|/4$ times higher than that of the conventional encoder-decoder framework.
RNN-T needs to ensure that the output timing of $y_j$ is exactly the corresponding source frame, $a_j = \mathrm{align}(x, y_j)$. Thanks to the attention mechanism, CAAT only needs to ensure that the output timing is not earlier than the corresponding position, $a_j \ge \mathrm{align}(x, y_j)$. It is therefore no longer necessary to make a decision at every encoder frame; a decision step size $d > 1$ is appropriate for CAAT, which reduces the complexity of the joiner from $O(|x| \cdot |y|)$ to $O(|x| \cdot |y| / d)$. Besides, the decision step size is also an effective way to adjust the latency-quality trade-off.
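A back-of-the-envelope sketch (ours, not from the paper) of how the decision step size shrinks the joiner's lattice:

```python
import math

def joiner_calls(src_len, tgt_len, d):
    """Number of joiner evaluations over the decision lattice:
    one call per (decision step, target prefix) pair, with
    ceil(src_len / d) decision steps and tgt_len + 1 prefixes."""
    return math.ceil(src_len / d) * (tgt_len + 1)
```

For example, with 1000 encoder frames and 20 target tokens, moving from $d = 1$ to $d = 32$ reduces the joiner evaluations from 21000 to 672.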

Latency Loss
CAAT relaxes the restriction on output timing via the attention mechanism, which means any source step $i \ge \mathrm{align}(x, y_j)$ is appropriate for outputting $y_j$, including the fully offline path ($\forall j: a_j = |x|$). To prevent the CAAT model from bypassing the online policy by choosing the offline path, the latency loss $L_{latency}(x, y)$ defined in Eq. (4) is required.
Motivated by Sequence Discriminative Training in ASR (Povey, 2005), we optimize the latency loss with the forward-backward algorithm. To calculate the expectation of the latency loss over all paths $\hat{y}$, mergeability is again required, which means the latency loss of a path should decompose as $l(\hat{y}) = \sum_{k=1}^{|x|+|y|} l(\hat{y}_k)$ with $l(\hat{y}_k)$ independent of $l(\hat{y}_{k' \ne k})$. However, neither Average Lagging (Ma et al., 2019) nor Differentiable Average Lagging (Arivazhagan et al., 2019) meets this requirement. We hence introduce a novel latency function

$$l(\hat{y}_k) = \max\left(0,\; r_k - w_k \cdot \frac{|x|}{|y|}\right),$$

where $r_k$ and $w_k$ denote the number of READ and WRITE actions before $\hat{y}_k$, respectively. The maximization operation is used to avoid encouraging over-aggressive decision paths. This latency definition is not rigorous enough to serve as an evaluation metric, due to the under-estimation after the source has ended, as analyzed in (Arivazhagan et al., 2019), but it can still be used as a loss function.
By defining the forward latency variable $\alpha_{lat}(i, j)$ as the expected latency of outputting $y_{[1:j]}$ during decision steps $[1, \cdots, i]$, and the backward latency variable $\beta_{lat}(i, j)$ as the expected latency of outputting $y_{[j+1:|y|]}$ during decision steps $[i, \cdots, |x|]$, the latency loss can be written as

$$L_{latency}(x, y) = \sum_{(i,j):\, i+j=m} \frac{\alpha(i, j)\, \beta(i, j)}{\Pr(y|x)} \left( \alpha_{lat}(i, j) + \beta_{lat}(i, j) \right),$$

where $1 \le m \le |x| + |y|$. The detailed derivation of the latency loss of CAAT can be found in Appendix A.2.
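On toy-sized lattices, the path-posterior-weighted expectation that the forward-backward recursion computes can be checked by brute-force enumeration. This reference implementation (ours, exponential in the lattice size, for sanity checks only) enumerates every expanded path, weighting the per-WRITE latency `lat(i, j)` by the path probability:

```python
def expected_latency(p_tok, p_blank, I, J, lat):
    """Brute-force expectation of an additive per-step latency over all
    expanded paths, normalised by the total path probability.
    p_tok / p_blank follow the same 1-indexed-in-i grid convention as the
    forward-backward recursion; lat(i, j) is the cost of a WRITE leaving
    node (i, j), i.e. r_k = i reads and w_k = j writes so far."""
    results = []
    def walk(i, j, prob, cum):
        if i == I and j == J:
            results.append((prob * p_blank[I][J], cum))  # terminal blank
            return
        if i < I:  # READ
            walk(i + 1, j, prob * p_blank[i][j], cum)
        if j < J:  # WRITE, accruing latency
            walk(i, j + 1, prob * p_tok[i][j], cum + lat(i, j))
    walk(1, 0, 1.0, 0.0)
    z = sum(p for p, _ in results)
    return sum(p * c for p, c in results) / z
```

Agreement between this enumeration and the diagonal formula above is a useful unit test when implementing the recursion.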

Offline Auxiliary Loss
We add the negative log-likelihood loss of the offline translation path as an auxiliary loss for CAAT model training, for two reasons: first, we hope the CAAT model falls back to offline translation in the worst case; second, CAAT translation proceeds as offline translation once a source sentence finishes. The final loss function for CAAT training is

$$L(x, y) = L_{NLL}(x, y) + \lambda_{latency} L_{latency}(x, y) + \lambda_{offline} L_{offline}(x, y),$$

where $\lambda_{latency}$ and $\lambda_{offline}$ are the scaling factors for $L_{latency}$ and $L_{offline}$, respectively. We set $\lambda_{latency} = \lambda_{offline} = 1.0$ if not otherwise specified.

Streaming Encoder
A unidirectional Transformer encoder (Arivazhagan et al., 2019; Ma et al., 2020c) is not effective for speech data processing, because a speech feature $x_i$ is closely related to its right context. Block processing (Dong et al., 2019; Wu et al., 2020) has been introduced for online ASR, but it lacks direct access to the infinite left context. We build our streaming encoder for speech data with block processing equipped with a right context and infinite left context. First, the input representations $h$ are divided into overlapping blocks with block shift $m$ and block size $m + r$. Each block consists of two parts: the main context $m_n = (h_{mn+1}, \cdots, h_{m(n+1)})$ and the right context $r_n = (h_{(n+1)m+1}, \cdots, h_{(n+1)m+r})$. In self-attention, the queries of block $b_n$ come from the block itself, while its keys and values cover the whole left history up to the end of the block's right context. By reorganizing the input sequence and designing the self-attention mask accordingly, training remains efficient by reusing conventional Transformer encoder layers, and a unidirectional Transformer can be regarded as a special case of our method with $\{m = 1, r = 0\}$. Note that the look-ahead window size in our method is fixed, which enables us to stack more Transformer layers without increasing latency. We set the main context size and right context size to 8 and 4, respectively, for our experiments on speech-to-text simultaneous translation, and use the conventional unidirectional Transformer encoder $\{m = 1, r = 0\}$ for experiments on text-to-text simultaneous translation.
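A minimal sketch (our own, under the $\{m, r\}$ convention above, with the elided mask equations approximated in words) of the boolean self-attention mask this block scheme induces; each query position sees the full left history plus its own block's fixed look-ahead:

```python
def block_attention_mask(seq_len, m, r):
    """mask[q][k] = True where query position q may attend key position k.
    A query in block n (main context shift m, right context r) attends all
    keys up to the end of its block's right context: infinite left context
    plus a fixed look-ahead of at most m + r - 1 positions."""
    mask = [[False] * seq_len for _ in range(seq_len)]
    for q in range(seq_len):
        horizon = min(seq_len, (q // m + 1) * m + r)  # end of look-ahead
        for k in range(horizon):
            mask[q][k] = True
    return mask
```

With $\{m = 1, r = 0\}$ the mask degenerates to the standard causal mask of a unidirectional Transformer, matching the special case noted above.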

Inference of CAAT Simultaneous Translation
The online inference for CAAT is adapted from beam search for RNN-T (Graves, 2012), with the following changes: (1) We only merge paths between decision steps, as the joiner of CAAT is significantly more expensive than that of RNN-T.
(2) We extract the common prefix of the existing hypotheses as the determined target output at each decision step.
(3) Different beam sizes are introduced for intra-decision ($b_1$) and inter-decision ($b_2$) pruning, to ensure timely determination of outputs. $b_1$ and $b_2$ are set to 5 and 1, respectively, if not otherwise specified.

Model Settings Our models are based on Transformer (Vaswani et al., 2017). Since the variance of the length of speech frames is more significant than that of text length, we use both cosine positional embedding (Vaswani et al., 2017) and relative positional attention (Shaw et al., 2018) for the speech encoder, and only cosine positional embedding for the decoder. Detailed hyper-parameters of our models can be found in Appendix C.1.
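The common-prefix commitment in change (2) above can be sketched as follows (a hypothetical helper, not the paper's implementation; hypotheses are token lists):

```python
def common_prefix(hyps):
    """Longest common prefix of all hypothesis token lists; these tokens
    are safe to commit as output at the current decision step, since every
    surviving hypothesis agrees on them."""
    if not hyps:
        return []
    prefix = []
    for tokens in zip(*hyps):  # stops at the shortest hypothesis
        if all(t == tokens[0] for t in tokens):
            prefix.append(tokens[0])
        else:
            break
    return prefix
```

With an inter-decision beam size $b_2 = 1$ the whole surviving hypothesis is its own common prefix, which is what makes outputs timely in the default setting.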
Training and Inference Training speech translation models is often regarded as more difficult than training text machine translation or ASR models. We use two methods to improve the performance and stability of model training.
The first is to pre-train the encoder on an ASR task (Ma et al., 2020b), and the second is to leverage sequence-level knowledge distillation from a text machine translation model (Ren et al., 2020). Training CAAT models requires significantly more GPU memory than conventional Transformer due to the $O(|x| \cdot |y| / d)$ spatial complexity of the joiner module; we solve this by splitting hidden states into small pieces before sending them into the joiner and recombining them during backpropagation.
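The memory-saving trick can be illustrated shape-wise: computing the joint grid in pieces along one axis yields the same result while bounding the number of joint states alive at once. This toy sketch (ours) ignores the autograd recombination that the actual implementation handles during backpropagation:

```python
def joiner_full(enc, pred, combine):
    """Materialise the full I x J grid of joint representations at once:
    peak memory proportional to len(enc) * len(pred)."""
    return [[combine(h, g) for g in pred] for h in enc]

def joiner_chunked(enc, pred, combine, chunk=2):
    """Process the predictor axis in pieces of size `chunk`, so only
    len(enc) * chunk joint states need to be produced per piece; the
    outputs are stitched back together in order."""
    out = [[] for _ in enc]
    for start in range(0, len(pred), chunk):
        piece = pred[start:start + chunk]
        for i, h in enumerate(enc):
            out[i].extend(combine(h, g) for g in piece)
    return out
```

The chunked and full computations are exactly equivalent; only the peak size of the intermediate grid changes.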
Our implementation is based on the Fairseq library (Ott et al., 2019); the NLL and latency losses for CAAT are implemented based on warp-rnnt.
Evaluation We evaluate our models with SimulEval (Ma et al., 2020a). Translation quality is measured by detokenized case-sensitive BLEU (Papineni et al., 2002); latency is measured with the adapted version of word-level Average Lagging (AL) (Ma et al., 2020a).
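For reference, a simplified token-level Average Lagging following Ma et al. (2019) can be written as below; SimulEval's word-level adaptation used in our evaluation differs in detail (word boundaries, source units in milliseconds for speech):

```python
def average_lagging(g, src_len, tgt_len):
    """Average Lagging: g[j] is the number of source units read when
    target unit j+1 is emitted. Averaged up to tau, the first target
    position whose emission waited for the full source."""
    gamma = tgt_len / src_len  # ideal writing rate
    tau = next((j + 1 for j, gj in enumerate(g) if gj >= src_len), tgt_len)
    return sum(g[j] - j / gamma for j in range(tau)) / tau
```

For example, a wait-1 policy on equal-length sequences lags by exactly one source unit, while a fully offline policy lags by the whole source length.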

Results
We compare CAAT to the current state-of-the-art model in speech-to-text simultaneous translation (Ma et al., 2020b), which uses wait-k with a fixed pre-decision step size of 320ms. All our simultaneous speech translation models, both wait-k and CAAT, are trained with an encoder pre-trained on the ASR task and with sequence-level knowledge distillation from a text translation model. Two inference methods are used for wait-k: conventional beam search only on the target tail (when the source finishes) and speculative beam search (SBS) (Zheng et al., 2019b), both with a beam size of 5; the number of forecast steps in SBS is set to 2. For CAAT we set the intra-decision beam size $b_1 = 5$ and inter-decision beam size $b_2 = 1$ as described in Sec. 4.3. The latency-quality curves of CAAT are produced by varying the decision step size $d \in \{8, 16, 32, 48, 64, 80, +\infty\}$, and those of wait-k by varying $k \in \{1, 2, 4, 6, 8, 10, 12, +\infty\}$. The AL-BLEU curves on the MuST-C EN→DE and EN→ES test sets are shown in Figure 4. From the figure we can observe that: (1) In general, CAAT significantly outperforms wait-k (with and without SBS) on both the EN→DE and EN→ES tasks. Especially in the low-latency region (AL < 1000ms) (Ansari et al., 2020), CAAT outperforms wait-k with SBS by more than 3 BLEU points.
(2) The offline models of CAAT and wait-k obtain similar BLEU, suggesting that the adapted architecture of CAAT performs comparably with conventional Transformer in an offline scenario. (3) With the same wait step k, SBS produces lower latency. This is because the word-level latency metric we use requires an additional token to ensure a complete word is committed, which can be offset by the forward exploration in SBS.

Ablation Study
Effectiveness of Streaming Encoder The performance of our offline models with a full-sentence encoder is compared to the state-of-the-art offline speech translation systems (Wang et al., 2020; Inaguma et al., 2020) in Table 1. We also show ablation analyses on sequence-level knowledge distillation from a text translation model (KD) and encoder pre-training on the ASR task (Pretrain). We further compare offline translation models with streaming encoders to those with the conventional full-sentence encoder. As shown in Table 1, the performance of the translation model with a unidirectional encoder drops 2-3 BLEU points compared to that with a full-sentence encoder, and the gap is gradually narrowed by increasing the main block size $m$ and introducing the right context. Considering the effect on latency, we choose $\{m = 8, r = 4\}$.

Effectiveness of $\lambda_{latency}$ and $\lambda_{offline}$ The effectiveness of $\lambda_{offline}$ is demonstrated in Table 3. Furthermore, as shown in Figure 5, though $\lambda_{latency}$ can affect the trade-off between translation quality and latency, varying $\lambda_{latency}$ is not as effective as varying the decision step size $d$, and we found that model training becomes unstable when $\lambda_{latency} \ge 2.0$.

Effectiveness of Beam Search
The effectiveness of the intra-decision beam size $b_1$ and inter-decision beam size $b_2$ on simultaneous translation performance is shown in Table 4. We find that beam search within one decision step brings an improvement of about 0.7 BLEU over greedy search. If we additionally allow multiple hypotheses between decision steps, we gain another 0.5 BLEU at the cost of latency (AL increases from 1114.9 to 2433.5). However, this may be useful in scenarios where revision is allowed (Arivazhagan et al., 2020), e.g., simultaneous translation for subtitles.

Case Study We perform a case study to demonstrate the advantages of the CAAT model over wait-k with SBS, comparing wait-k with $k = 2$ against CAAT with $d = 32$, as they have similar AL latency. As shown in Figure 7, wait-k generates meaningless translations by 'predicting' at pauses and changes in speech rate, while CAAT does not suffer from this problem. As a result, CAAT outperforms wait-k with SBS.

Text-to-Text Simultaneous Translation
We further performed experiments on the text-to-text simultaneous translation task. Experiments are carried out on the WMT15 German-English (DE→EN) dataset with newstest2013 as the validation set and newstest2015 as the test set. We strictly follow the same settings as (Arivazhagan et al., 2019). We can see that CAAT outperforms wait-k and wait-k with SBS, but the gap is narrower than that of S2T simultaneous translation in Figure 4. Considering the case study in Sec. 5.1.3, we believe that a flexible policy is more important for speech translation because of speech rate variation.

Conclusions
This paper proposes Cross Attention Augmented Transducer (CAAT), a novel simultaneous translation model that jointly optimizes the policy and translation model by considering all possible READ-WRITE action paths. Crucial to the model is a novel latency loss whose expectation, like the NLL loss, can be optimized by a forward-backward algorithm. Experiments on both speech-to-text and text-to-text simultaneous translation show that CAAT achieves significantly better latency-quality trade-offs than the state-of-the-art simultaneous translation approaches.

A.1 Derivation of CAAT NLL loss
Given the encoder representation $h_n$ ($1 \le n \le |x|$), the predictor vector $h^{pred}_j$ ($0 \le j \le J$, with $J = |y|$), and a decision step size $d \ge 1$, the maximum decision step is $I = \lceil |x| / d \rceil$, and the output logits at decision step $i$ and target position $j$ form a vector $s(i, j)$ of dimension $|V| + 1$, corresponding to $V$ and the blank symbol $\varnothing$. Let $s(k, i, j)$ denote the $k$-th dimension of $s(i, j)$. The conditional output distribution is

$$\Pr(k | i, j) = \mathrm{softmax}_k\, s(\cdot, i, j).$$

To simplify notation, define $y(i, j) := \Pr(y_{j+1} | i, j)$ and $\varnothing(i, j) := \Pr(\varnothing | i, j)$.

Define the forward variable $\alpha(i, j)$ as the probability of outputting $y_{[1:j]}$ during decision steps $[1, \cdots, i]$. The forward variables for all $1 \le i \le I$ and $0 \le j \le J$ can be calculated recursively using

$$\alpha(i, j) = \alpha(i-1, j)\, \varnothing(i-1, j) + \alpha(i, j-1)\, y(i, j-1),$$

with initial condition $\alpha(1, 0) = 1$. The total output sequence probability is obtained at the terminal node:

$$\Pr(y|x) = \alpha(I, J)\, \varnothing(I, J).$$

Define the backward variable $\beta(i, j)$ as the probability of outputting $y_{[j+1:J]}$ during decision steps $[i, \cdots, I]$. Then

$$\beta(i, j) = \beta(i+1, j)\, \varnothing(i, j) + \beta(i, j+1)\, y(i, j),$$

with initial condition $\beta(I, J) = \varnothing(I, J)$. $\Pr(y|x)$ is equal to the sum of $\alpha(i, j)\beta(i, j)$ over the nodes of any top-right to bottom-left diagonal; that is, for all $m$ with $1 \le m \le I + J$,

$$\Pr(y|x) = \sum_{(i,j):\, i+j=m} \alpha(i, j)\, \beta(i, j).$$

From the equations above, the gradients of the loss $L = -\log \Pr(y|x)$ with respect to $y(i, j)$ and $\varnothing(i, j)$ follow as in Graves (2012):

$$\frac{\partial \Pr(y|x)}{\partial y(i, j)} = \alpha(i, j)\, \beta(i, j+1), \qquad \frac{\partial \Pr(y|x)}{\partial \varnothing(i, j)} = \alpha(i, j)\, \beta(i+1, j).$$

A.2 Derivation of CAAT Latency Loss
To calculate the marginal expectation in Eq. 23, we define the forward latency variable $\alpha_{lat}(i, j)$ as the expected latency of outputting $y_{[1:j]}$ during decision steps $[1, \cdots, i]$, and the backward latency variable $\beta_{lat}(i, j)$ as the expected latency of outputting $y_{[j+1:J]}$ during decision steps $[i, \cdots, I]$. We denote by $l(i, j)$ the latency function for output $y_j$ at decision step $i$.

The forward latency variables can be calculated recursively using

$$\alpha_{lat}(i, j) = p_1(i, j)\, f_1(i, j) + p_2(i, j)\, f_2(i, j),$$

with initial condition $\alpha_{lat}(1, 0) = 0$, where $p_1(i, j)$ and $p_2(i, j)$ are the posterior probabilities of reaching node $(i, j)$ from $(i-1, j)$ (a READ) and from $(i, j-1)$ (a WRITE), respectively,

$$p_1(i, j) = \frac{\alpha(i-1, j)\, \varnothing(i-1, j)}{\alpha(i, j)}, \qquad p_2(i, j) = \frac{\alpha(i, j-1)\, y(i, j-1)}{\alpha(i, j)},$$

and $f_1(i, j) = \alpha_{lat}(i-1, j)$, $f_2(i, j) = \alpha_{lat}(i, j-1) + l(i, j)$ accumulate the latency along each incoming edge. The backward latency variables are computed symmetrically, with initial condition $\beta_{lat}(I, J) = 0$. To simplify notation, define the latency expectation of all paths going through node $(i, j)$ as

$$\hat{c}(i, j) = \frac{\alpha(i, j)\, \beta(i, j)}{\Pr(y|x)} \left( \alpha_{lat}(i, j) + \beta_{lat}(i, j) \right).$$

The expected latency over all paths $\hat{y} \in H(x, y)$ is equal to the expectation through any diagonal; that is, for all $m$ with $1 \le m \le I + J$,

$$L_{latency}(x, y) = \sum_{(i,j):\, i+j=m} \hat{c}(i, j).$$

The gradients of $L_{latency}$ with respect to $y(i, j)$ and $\varnothing(i, j)$ then follow by differentiating this expression, analogously to the NLL case.

B Beam Search Algorithm for CAAT
The pseudo code of beam search algorithm for CAAT is described in Algorithm 1.

C Hyper-parameters

C.1 Hyper-parameters on Speech-to-Text Simultaneous Translation
Our experiments on speech-to-text simultaneous translation are based on Transformer. For speech processing, two 2D convolution blocks are introduced before the stacked Transformer encoder layers. Each convolution block consists of a 3-by-3 convolution layer with 64 channels and stride 2, followed by a ReLU activation. Input speech features are thus downsampled 4 times by the convolution blocks and flattened to a 1D sequence as input to the Transformer layers. Cosine positional embedding is added to the speech representations after the convolutions, and relative positional attention is employed for encoder self-attention. The dropout ratio is set to 0.3. Our CAAT model shares the same hyper-parameters with the conventional Transformer model, except that the feed-forward hidden sizes of the predictor and joiner are set to 1024 to keep the total number of parameters identical. All speech translation models were trained with the Adam optimizer with an initial learning rate of 5 × 10^-4 and the inverse_sqrt scheduler. Each model was trained on 2 V100 GPUs with 32GB memory each, using a batch size of 20000 frames and an update frequency of 8.

C.2 Hyper-parameters on Text-to-Text Simultaneous Translation
Our experiments on text-to-text simultaneous translation are based on Transformer-Base; that is, the hidden size, feed-forward hidden size, number of heads, and numbers of encoder and decoder layers are set to 512, 2048, 8, 6, and 6, respectively. The dropout ratio is set to 0.3. The MMA-IL model is trained with the transformer_monotonic architecture, except that the noise variance is set to 2. Our CAAT model matches the parameter count of the Transformer model by setting the feed-forward hidden size of the predictor and joiner to 1024 (half of the Transformer decoder feed-forward hidden size). All text translation models were trained with the Adam optimizer with an initial learning rate of 5 × 10^-4 and the inverse_sqrt scheduler. Each model was trained on 2 V100 GPUs with 32GB memory each, with a batch size of 4096 tokens and an update frequency of 8.

D Expanded Results
We also evaluate our work with the latency metrics Average Proportion (AP) and Differentiable Average Lagging (DAL). The full-size translation quality against latency (AL, AP, and DAL) curves on the MuST-C EN→DE and EN→ES speech-to-text simultaneous translation tasks are shown in Figure 8, and the quality-latency curves on the WMT15 DE→EN text-to-text translation task are shown in Figure 9. We also provide complete tables of results in Tables 5, 6, and 7.