Fast Nearest Neighbor Machine Translation

Though nearest neighbor Machine Translation (kNN-MT) (Khandelwal et al., 2020) has been shown to bring significant performance boosts over standard neural MT systems, it is prohibitively slow because it uses the entire reference corpus as the datastore for nearest neighbor search. This means that each step for each beam in beam search has to search over the entire reference corpus, making kNN-MT two orders of magnitude slower than vanilla MT models and hard to apply in real-world applications, especially online services. In this work, we propose Fast kNN-MT to address this issue. Fast kNN-MT constructs a significantly smaller datastore for nearest neighbor search: for each word in a source sentence, Fast kNN-MT first selects its nearest token-level neighbors, restricted to tokens that are the same as the query token. Then, at each decoding step, instead of using the entire corpus as the datastore, the search space is limited to the target tokens corresponding to the previously selected reference source tokens. This strategy avoids searching through the whole datastore for nearest neighbors and drastically improves decoding efficiency. Without loss of performance, Fast kNN-MT is two orders of magnitude faster than kNN-MT and only two times slower than the standard NMT model. Fast kNN-MT enables the practical use of kNN-MT systems in real-world MT applications. The code is available at https://github.com/ShannonAI/fast-knn-nmt.


Introduction
Machine translation (MT) is a fundamental task in natural language processing (Brown et al., 1993; Och and Ney, 2003), and the prevalence of deep neural networks has spurred a diverse array of neural machine translation (NMT) models that improve translation quality (Sutskever et al., 2014; Bahdanau et al., 2014; Vaswani et al., 2017). The recently proposed k nearest neighbor (kNN) MT model (Khandelwal et al., 2020) has been shown to bring significant performance boosts over standard neural MT systems. The basic idea behind kNN-MT is that at each decoding step, the model can refer to reference target tokens with similar translation contexts in a large datastore of cached examples. These reference target tokens provide important evidence about which token is likely to appear next.
One notable limitation of kNN-MT is that it is prohibitively slow: it uses the entire reference corpus as the datastore for nearest neighbor search, so each step for each beam in beam search has to search over the entire reference corpus. As a result, kNN-MT is two orders of magnitude slower than vanilla MT models. The original kNN-MT paper (Khandelwal et al., 2020) suggests using fewer search clusters, smaller beams and smaller datastores to speed up generation, but its own analyses show that these factors still have to be carefully tuned for each task and dataset to achieve satisfactory results. The computational overhead of kNN-MT makes it hard to deploy in real-world online services, which usually require both strong model performance and runtime efficiency.
In this work, we propose a fast version of kNN-MT, Fast kNN-MT, to tackle the aforementioned issues. Fast kNN-MT constructs a significantly smaller datastore for nearest neighbor search: for each word in a source sentence, Fast kNN-MT first selects its nearest token-level neighbors, limited to tokens of the same token type. Then, at each decoding step, instead of consulting the entire corpus for nearest neighbor search, the datastore for the token currently being decoded is limited to the reference target tokens corresponding to the previously selected reference source tokens, as shown in Figure 1. The chain of mappings from the target token to the source token, then to its nearest reference source tokens, and finally to the corresponding reference target tokens can be obtained using FastAlign (Dyer et al., 2013).
Fast kNN-MT offers several important speed advantages over vanilla kNN-MT: (1) the datastore for kNN search is limited to target tokens corresponding to the previously selected reference source tokens instead of the entire corpus, which significantly improves decoding efficiency; (2) for source-side nearest neighbor retrieval, we restrict the reference source tokens to those that are the same as the query token, which further improves search efficiency. Without loss of performance, Fast kNN-MT is two orders of magnitude faster than kNN-MT and only two times slower than the standard MT model. In bilingual translation and domain adaptation, Fast kNN-MT achieves results comparable to kNN-MT, with SacreBLEU scores of 39.3 on WMT'19 De-En and 41.7 on WMT'14 En-Fr, and an average score of 41.4 on the domain adaptation task.

Related Work
Neural Machine Translation Neural machine translation systems (Vaswani et al., 2017; Gehring et al., 2017; Meng et al., 2019) are often implemented with the sequence-to-sequence framework (Sutskever et al., 2014) and enhanced with the attention mechanism (Bahdanau et al., 2014; Luong et al., 2015), which associates the current decoding token with the most semantically related part of the source side. At decoding time, beam search and its variants are used to find the optimal sequence (Sutskever et al., 2014; Li and Jurafsky, 2016). The development of self-attention (Vaswani et al., 2017) and pretraining (Devlin et al., 2018; ...) has greatly motivated a line of work on more expressive MT systems. These works include incorporating pretrained models (Zhu et al., 2020; Guo et al., 2020), designing lightweight model structures (Kasai et al., 2020; Lioutas and Guo, 2020; Tay et al., 2020; Kasai et al., 2021; Peng et al., 2021), handling multiple languages (Aharoni et al., 2019; Arivazhagan et al., 2019; ...), and mitigating structural issues in Transformers (Wang et al., 2019; Nguyen and Salazar, 2019; Xiong et al., 2020b) for more robust and efficient NMT systems.

Retrieval-Augmented Models
Retrieving and integrating auxiliary sentences has proven effective in improving the robustness and expressiveness of NMT systems. (...) up-weighted output tokens by collecting n-grams from the retrieved target sentences that align with words in the source sentence, and (...) similarly retrieved n-grams but incorporated the information using gated attention (Cao and Xiong, 2018). (Tu et al., 2018) updated and stored the hidden representations of recent translation history in a cache that is accessed when new tokens are generated, so that the model can dynamically adapt to different contexts. (Gu et al., 2018) leveraged an off-the-shelf search engine to retrieve a small subset of sentence pairs from the training set and then performed translation given the source sentence along with the retrieved pairs. (...; Farajian et al., 2017) proposed retrieving similar sentences from the training set to adapt the model to different input sentences. (Bulté and Tezcan, 2019; Jitao et al., 2020) used fuzzy matches to retrieve similar sentence pairs from translation memories and augmented the source sentence with the retrieved pairs. Our work is motivated by kNN-MT (Khandelwal et al., 2020) and targets improving the efficiency of kNN retrieval while achieving comparable translation performance.
Apart from machine translation, other NLP tasks have also benefited from retrieval-augmented models, such as language modeling (Khandelwal et al., 2019), question answering (Guu et al., 2020; Lewis et al., 2020b,a; Xiong et al., 2020a) and dialog generation (Weston et al., 2018; Thulke et al., 2021). Most of these works perform retrieval at the sentence level and treat the retrieved sentences as additional input for generation, whereas Fast kNN-MT retrieves the most relevant tokens on the source side and adjusts the output probability distribution using the aligned target tokens at each decoding step.

Background: kNN-MT
Given an input sequence of tokens $x = \{x_1, \dots, x_n\}$ of length $n$, an MT model translates it into a target sentence in another language, $y = \{y_1, \dots, y_m\}$, of length $m$. A common practice for producing each target token $y_i$ is to obtain a probability distribution over the vocabulary, $p_{\mathrm{MT}}(y_i \mid x, y_{1:i-1})$, from the decoder and to use beam search for generation. The combination of the complete source sentence and the prefix of the target sentence, $(x, y_{1:i-1})$, is called the translation context. kNN-MT interpolates this distribution with a multinomial distribution $p_{\mathrm{kNN}}(y_i \mid x, y_{1:i-1})$ derived from the $k$ nearest neighbors of the current translation context $(x, y_{1:i-1})$ in a large-scale datastore $S$:

$$p(y_i \mid x, y_{1:i-1}) = \lambda\, p_{\mathrm{kNN}}(y_i \mid x, y_{1:i-1}) + (1 - \lambda)\, p_{\mathrm{MT}}(y_i \mid x, y_{1:i-1}) \tag{1}$$

More specifically, kNN-MT first constructs the datastore $S$ from key-value pairs, where the key is the high-dimensional vector $f(x, y_{1:i-1})$ of the translation context produced by a trained MT model $f$, and the value is the corresponding gold target token $y_i$:

$$S = \{(f(x, y_{1:i-1}),\, y_i) \mid \forall y_i \in y,\ (x, y) \in \mathcal{D}\}$$

The context-target pairs may come from any parallel corpus $\mathcal{D}$. Then, using the dense representation of the current translation context as the query $q = f(x_{\mathrm{in}}, y_{1:i-1})$ and the $L_2$ distance as the measure, kNN-MT searches the entire datastore $S$ to retrieve the $k$ nearest translation contexts along with their target tokens, $\mathcal{N} = \{(k_j, v_j)\}_{j=1}^{k}$. Finally, the retrieved set is converted into a probability distribution by normalizing and aggregating the negative $L_2$ distances $-d$ with a softmax at temperature $T$:

$$p_{\mathrm{kNN}}(y_i \mid x, y_{1:i-1}) \propto \sum_{(k_j, v_j) \in \mathcal{N}} \mathbb{1}_{y_i = v_j} \exp\!\left(\frac{-d(k_j, q)}{T}\right) \tag{2}$$

Integrating Eq. (2) into Eq. (1) gives the final probability of generating token $y_i$ at time step $i$. Note that this search-and-interpolate process is applied at every decoding step of every beam, and each iteration runs over the full datastore $S$. This gives a total time complexity of $O(|S|Bm)$, where $B$ is the beam size and $m$ is the target length.
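To make Eqs. (1) and (2) concrete, here is a minimal pure-Python sketch. It assumes the $k$ retrieved neighbors are already given as (distance, target token) pairs; the function names `knn_distribution` and `interpolate`, and the use of dicts as sparse distributions, are illustrative choices rather than the paper's implementation.

```python
import math
from collections import defaultdict


def knn_distribution(neighbors, temperature):
    """Eq. (2): turn k retrieved (distance, target_token) pairs into p_kNN.

    Each neighbor gets weight exp(-d / T); weights of neighbors that share
    the same target token are summed, and the result is normalized.
    """
    logits = [-d / temperature for d, _ in neighbors]
    shift = max(logits)  # subtract the max logit for numerical stability
    weights = [math.exp(z - shift) for z in logits]
    total = sum(weights)
    probs = defaultdict(float)
    for w, (_, token) in zip(weights, neighbors):
        probs[token] += w / total
    return dict(probs)


def interpolate(p_mt, p_knn, lam):
    """Eq. (1): p = lam * p_kNN + (1 - lam) * p_MT over the union of tokens."""
    tokens = set(p_mt) | set(p_knn)
    return {t: lam * p_knn.get(t, 0.0) + (1.0 - lam) * p_mt.get(t, 0.0) for t in tokens}


# Toy usage: three neighbors, two of which share the target token "c".
p_knn = knn_distribution([(1.0, "c"), (2.0, "c"), (4.0, "d")], temperature=1.0)
p = interpolate({"c": 0.5, "d": 0.3, "e": 0.2}, p_knn, lam=0.5)
```

Because both inputs are normalized, the mixture stays a valid distribution for any $\lambda \in [0, 1]$; tokens never retrieved (here "e") keep only their scaled MT probability.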
For faster nearest neighbor search, kNN-MT leverages FAISS, a toolkit for efficient similarity search and clustering of dense vectors. Even so, kNN-MT remains slow when the size of the datastore $S$ or the beam size $B$ is large. We propose strategies to address this issue. Like vanilla kNN-MT, Fast kNN-MT is built upon a separately trained MT encoder-decoder model. To illustrate how Fast kNN-MT works, we give a toy example in Figure 1. We use capital letters to denote source tokens and lowercase letters to denote target tokens. An encoder-decoder model is first trained on the toy training set. Next, we wish to translate the source string {B,C,E} at test time.

Datastore Creation On the Source Side
Given a pretrained encoder-decoder model and the training corpus, we first obtain representations for all source and target tokens in the training set, taken as the last-layer outputs of the encoder-decoder model. In the toy example, the representations of the source tokens {A,B,C,D} in the first training example ({A,B,C,D}, {b,c,d,a}) are $h_{11}, h_{12}, h_{13}, h_{14}$, and those of the target tokens {b,c,d,a} are $z_{11}, z_{12}, z_{13}, z_{14}$. Given a test example to translate, {B,C,E} in the example, we also obtain the representation of each of its constituent tokens, denoted by $h_B, h_C, h_E$.

Next, we select nearest neighbor tokens for each source token in {B,C,E}. The candidates are first limited to source tokens of the same token type as the query token. For token B, the tokens of the same type are $x_{12}, x_{21}, x_{32}, x_{41}$. Similarly, for token C the tokens of the same type are $\{x_{13}, x_{22}\}$, and for token E they are $\{x_{34}, x_{43}, x_{52}\}$. One issue that stands out is that common words such as "the" can have tens of millions of same-type tokens in the training corpus, so we need to further limit the number of nearest neighbors. Let $c$ denote the hyperparameter that controls the number of nearest neighbors kept for each source token ($c = 2$ in the toy example). We rank all candidates by the distance between the source token representation (e.g., $h_B, h_C, h_E$) and the candidate token representations, and keep the top $c$. Suppose that in the toy example $x_{12}, x_{21}$ are selected for token B, $x_{13}, x_{23}$ for token C, and $x_{34}, x_{52}$ for token E. The concatenation of the selected candidates for all source tokens constitutes the datastore on the source side, which is $D_{\text{source}} = \{x_{12}, x_{21}, x_{13}, x_{23}, x_{34}, x_{52}\}$ in the toy example. Datastore creation for the individual source tokens (e.g., B, C and E) can be run in parallel.

Figure 1: Caching source and target tokens (left, blue): Given a trained NMT model $f$ and the training corpus $D_{\text{train}}$, we obtain representations for all source tokens $h$ and target tokens $z$ in the training set, which are the last-layer outputs of $f$. Datastore construction (right, green): Given a test example to translate, {B,C,E} in the example, we first map each source token to the tokens of the same type in the cache; e.g., $x_{12}, x_{21}, x_{32}$ and $x_{41}$ are identified for token B. Then, the top $c$ nearest neighbors of each source token are kept according to the distance between the source token representation and the candidate token representations; e.g., $x_{12}, x_{21}$ are selected for token B. Finally, the selected source tokens are aligned to their target tokens using FastAlign (Dyer et al., 2013). For token B, the aligned target tokens are $y_{11}, y_{24}$. The collection of all aligned target tokens (along with their representations) constitutes the datastore for the current input {B,C,E}.
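The source-side selection (same-type filtering followed by top-$c$ ranking) can be sketched in a few lines of Python. This is an illustrative sketch, not the released code: the cache layout (a dict from token type to (token id, representation) pairs), the helper name `build_source_datastore`, and the 2-d toy vectors for tokens B and E are all hypothetical, and plain $L_2$ distance stands in for whatever metric is configured.

```python
import math


def build_source_datastore(test_tokens, test_reprs, cache, c):
    """Select the top-c same-type neighbors for every test source token.

    cache maps a token type to a list of (token_id, representation) pairs
    collected from the training set. Each test token is handled
    independently, so this loop could run in parallel.
    """
    d_source = []
    for token, h in zip(test_tokens, test_reprs):
        candidates = cache.get(token, [])  # same token type only
        ranked = sorted(candidates, key=lambda item: math.dist(h, item[1]))
        d_source.extend(token_id for token_id, _ in ranked[:c])
    return d_source


# Hypothetical 2-d representations for the toy tokens B and E.
toy_cache = {
    "B": [("x12", [0.0, 1.0]), ("x21", [0.1, 0.9]), ("x32", [5.0, 5.0]), ("x41", [3.0, 3.0])],
    "E": [("x34", [2.0, 2.0]), ("x43", [9.0, 9.0]), ("x52", [2.1, 2.0])],
}
selected = build_source_datastore(["B", "E"], [[0.0, 1.0], [2.0, 2.0]], toy_cache, c=2)
```

With these made-up vectors, `selected` is `["x12", "x21", "x34", "x52"]`. A full sort is used for readability; at the scale of "the" (tens of millions of candidates) an approximate index would replace it.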

Datastore Creation On the Target Side
For decoding, the model needs to refer to reference target tokens rather than source tokens, so we need to transform $D_{\text{source}}$ into a list of target tokens. We use the FastAlign toolkit (Dyer et al., 2013) to achieve this. FastAlign maps source tokens to target tokens based on the IBM alignment models (Och and Ney, 2003). Source tokens in $D_{\text{source}}$ that have no correspondence on the target side are discarded. The target tokens output by FastAlign form the datastore on the target side, denoted by $D_{\text{target}}$. In the toy example, $x_{12}, x_{21}, x_{13}, x_{23}, x_{34}, x_{52}$ are mapped to $D_{\text{target}} = \{y_{11}, y_{24}, y_{12}, y_{21}, y_{35}, y_{52}\}$, respectively. Since unaligned source tokens are dropped, the size of $D_{\text{target}}$ is at most $c \times n$, where $n$ is the source length.
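The mapping from $D_{\text{source}}$ to $D_{\text{target}}$ can be sketched as follows. It assumes FastAlign's usual plain-text output of space-separated "i-j" pairs (source index first, 0-based) for each sentence pair; the helpers `parse_alignment` and `build_target_datastore`, the (sentence id, index) representation of tokens, and the rule of keeping only the first alignment for a source token are hypothetical simplifications.

```python
def parse_alignment(line):
    """Parse one line of "i-j" pairs (source index first) into a map from
    source index to a single aligned target index.

    Assumption: when a source token aligns to several target tokens, only
    the first pair on the line is kept.
    """
    mapping = {}
    for pair in line.split():
        src, tgt = (int(part) for part in pair.split("-"))
        mapping.setdefault(src, tgt)
    return mapping


def build_target_datastore(d_source, alignments):
    """Map (sentence_id, source_index) entries to (sentence_id, target_index).

    Source tokens without an aligned target token are dropped, so the
    result has at most len(d_source) entries.
    """
    d_target = []
    for sent_id, src_idx in d_source:
        tgt_idx = alignments[sent_id].get(src_idx)
        if tgt_idx is not None:
            d_target.append((sent_id, tgt_idx))
    return d_target


# Sentence 1 is ({A,B,C,D}, {b,c,d,a}) with 0-based indices: A-a, B-b, C-c, D-d.
alignments = {1: parse_alignment("0-3 1-0 2-1 3-2"), 2: parse_alignment("0-3")}
d_target = build_target_datastore([(1, 1), (1, 2), (2, 1)], alignments)
```

Here B (index 1) maps to b (index 0) and C to c, while the entry for sentence 2 has no alignment and is discarded, giving `[(1, 0), (1, 1)]`.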
In practice, we first iterate over all examples in the training set, extracting all source token representations and all target token representations. We then build a separate token-specific cache $D_v$ for each token $v$ in the vocabulary, consisting of (key, value) pairs in which the key is the high-dimensional representation $h$ and the value is a pair containing the corresponding aligned target token and its representation $z$. Each source token of a given input sentence is then mapped to its corresponding cache $D_v$, and the target-side datastore is built following the steps in Section 4.1 and Section 4.2. The process of caching source and target tokens is presented in Algorithm 1.
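A minimal sketch of building the token-specific caches $D_v$ in one pass over the training set. The callables `encode_source` and `encode_target` are hypothetical stand-ins for the NMT encoder and decoder, and skipping unaligned source tokens already at caching time (rather than when $D_{\text{target}}$ is built) is an assumption made for brevity.

```python
from collections import defaultdict


def build_token_caches(corpus, alignments, encode_source, encode_target):
    """Build one (key, value) cache D_v per source token type v.

    corpus: list of (source_tokens, target_tokens) pairs.
    alignments: per-sentence dicts mapping source index -> target index.
    encode_source(src) and encode_target(src, tgt) are hypothetical
    callables returning one representation per token.
    Key: source representation h. Value: (aligned target token, its z).
    """
    caches = defaultdict(list)
    for (src, tgt), align in zip(corpus, alignments):
        h_list = encode_source(src)
        z_list = encode_target(src, tgt)
        for i, (v, h) in enumerate(zip(src, h_list)):
            j = align.get(i)
            if j is None:  # unaligned source tokens are skipped (assumption)
                continue
            caches[v].append((h, (tgt[j], z_list[j])))
    return caches


# Dummy encoders standing in for the NMT model.
caches = build_token_caches(
    corpus=[(["A", "B", "C", "D"], ["b", "c", "d", "a"])],
    alignments=[{0: 3, 1: 0, 2: 1, 3: 2}],
    encode_source=lambda src: [[float(i)] for i in range(len(src))],
    encode_target=lambda src, tgt: [[10.0 + j] for j in range(len(tgt))],
)
```

Keying the caches by token type is what lets a test token jump straight to its same-type candidates without touching the rest of the corpus.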

Decoding
At decoding time, the datastore for every decoding step is limited to $D_{\text{target}}$, within which kNN search is performed. Since not all tokens in $D_{\text{target}}$ are relevant to the current decoding step, nearest neighbor search is performed to select the top $k$ candidates from $D_{\text{target}}$ at each step. For this search, we use the current decoder representation $h$ as the query against the target representations $z$ of the tokens in $D_{\text{target}}$. The selected nearest neighbors and their representations are used to compute the final word generation probability according to Eq. (1) and Eq. (2).
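The per-step search over $D_{\text{target}}$ is small enough that a brute-force scan is plausible. The sketch below assumes $D_{\text{target}}$ is held as a list of (target token, $z$) pairs and uses exact $L_2$ distance; the function name `decode_step_neighbors` is hypothetical.

```python
import heapq
import math


def decode_step_neighbors(query, d_target, k):
    """Return the k entries of D_target closest to the decoder query.

    d_target: list of (target_token, z) pairs built for the current input.
    The output is a list of (distance, target_token) pairs sorted by
    increasing distance, the form needed to compute p_kNN in Eq. (2).
    """
    scored = ((math.dist(query, z), token) for token, z in d_target)
    return heapq.nsmallest(k, scored)


d_target = [("b", [0.0, 0.0]), ("c", [1.0, 0.0]), ("d", [3.0, 0.0])]
neighbors = decode_step_neighbors([0.9, 0.0], d_target, k=2)
```

`heapq.nsmallest` avoids sorting all of $D_{\text{target}}$ when $k$ is much smaller than its size; in this toy call it returns the entries for "c" and then "b".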

Quantization
Although the prohibitive computational cost of kNN-MT has been addressed, the memory required for the datastore remains a problem, since we wish to cache all source and target representations of the entire training set. Additionally, frequently accessing terabytes of data is extremely time-intensive. To address this issue, we propose to use product quantization (PQ) (Jegou et al., 2010) to compress the token representations.

Algorithm 1: Constructing Datastore for a Test Input $x$.
Input: all sentence pairs in the training set, $(x^{(1)}, y^{(1)}), \dots, (x^{(N)}, y^{(N)})$; vocabulary $V$; NMT encoder $f_e$; NMT decoder $f_d$; word alignment for each sentence pair, $(A^{(1)}, \dots, A^{(N)})$.
Initialization: $D_v \leftarrow \emptyset$, the (key, value) datastore for each word $v$ in the vocabulary.
Caching source and target tokens: compute representations for each source and target word.

PQ splits each $D$-dimensional vector $x$ into $M$ sub-vectors of dimension $d = D/M$ and quantizes each sub-vector with its own subspace codebook $C_m$, so that each $x$ is mapped to its nearest codeword in the Cartesian product space $C = C_1 \times \dots \times C_M$. If each subspace codebook $C_m$ has $n$ codewords, then $C$ can represent $n^M$ $D$-dimensional codewords using only $n \times M$ $d$-dimensional vectors, which substantially eases the memory issue. With this technique, we compress each token representation to 128 bytes. For example, for the WMT19 En-De dataset, the memory footprint is reduced from 3.5TB to 108GB.
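A toy illustration of PQ encoding and decoding. The codebooks are given rather than learned (in practice each $C_m$ would be learned, for example with k-means) and the sizes are tiny. The 128-byte figure would correspond to, for instance, $M = 128$ sub-quantizers with $n = 256$ codewords each (one byte per code); that configuration is an assumption, not something stated above. If the original vectors are 1024-dimensional float32 (4096 bytes, also an assumption), 128 bytes is a 32x reduction, in line with 3.5TB shrinking to roughly 108GB.

```python
import math


def split(x, num_subspaces):
    """Split a D-dimensional vector into M contiguous sub-vectors of size d = D/M."""
    d = len(x) // num_subspaces
    return [x[m * d:(m + 1) * d] for m in range(num_subspaces)]


def pq_encode(x, codebooks):
    """Replace each sub-vector by the index of its nearest codeword in C_m."""
    codes = []
    for sub, book in zip(split(x, len(codebooks)), codebooks):
        codes.append(min(range(len(book)), key=lambda i: math.dist(sub, book[i])))
    return codes


def pq_decode(codes, codebooks):
    """Concatenate the chosen codewords: the nearest point in C_1 x ... x C_M."""
    out = []
    for code, book in zip(codes, codebooks):
        out.extend(book[code])
    return out


# M = 2 subspaces with n = 2 codewords each: C holds n**M = 4 codewords of
# dimension D = 4, while only n * M = 4 vectors of dimension d = 2 are stored.
codebooks = [[[0.0, 0.0], [1.0, 1.0]], [[0.0, 0.0], [5.0, 5.0]]]
codes = pq_encode([0.9, 1.2, 4.0, 6.0], codebooks)
approx = pq_decode(codes, codebooks)
```

Only `codes` (here `[1, 1]`) needs to be stored per token; distances can then be computed against the reconstructed vector `[1.0, 1.0, 5.0, 5.0]` or via per-subspace lookup tables.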

kNN Retrieval Details
In practice, exact nearest neighbor search over millions or even billions of tokens can be time-consuming. Hence, we use FAISS for fast approximate nearest neighbor search. All token representations are quantized to 128 bytes. Recall that we build a token-specific datastore $D_v$ for each token $v$ in the vocabulary. We use brute-force search for tokens whose frequency $n_v$ is lower than 30,000. For tokens whose frequency is larger than 30,000, the keys are stored in clusters to speed up search. The number of clusters for token $v$ is set to $\min(4\sqrt{n_v},\, n_v/30)$. To learn the cluster centroids, we use at most 5M keys for each token $v$. During inference, we query the datastore for $k = 512$ neighbors by searching the 32 nearest clusters.
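The per-token index settings above can be written down as a small configuration helper. This is a sketch under stated assumptions: truncating the cluster count to an integer and treating a frequency of exactly 30,000 as "clustered" are guesses, since the text leaves both unspecified; the function names are hypothetical.

```python
import math

BRUTE_FORCE_THRESHOLD = 30_000  # tokens rarer than this use exact search
MAX_TRAINING_KEYS = 5_000_000  # cap on keys used to learn centroids


def num_clusters(n_v):
    """Number of clusters for token v with frequency n_v (0 = brute force).

    Implements min(4 * sqrt(n_v), n_v / 30); truncation to int is an
    assumption, since the rounding rule is not specified.
    """
    if n_v < BRUTE_FORCE_THRESHOLD:
        return 0
    return int(min(4 * math.sqrt(n_v), n_v / 30))


def num_training_keys(n_v):
    """Keys sampled to learn the centroids of token v."""
    return min(n_v, MAX_TRAINING_KEYS)
```

Taken literally, $4\sqrt{n_v} \le n_v/30$ whenever $n_v \ge 14{,}400$, so for every clustered token (frequency above 30,000) the minimum resolves to $4\sqrt{n_v}$; for example, a token seen one million times gets 4,000 clusters.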

Discussions on Comparisons to Vanilla kNN-MT
The speedup of Fast kNN-MT comes from the following three aspects: (1) For nearest neighbor retrieval on the source side, we first restrict the reference tokens to those that are the same as the query token. This strategy shrinks the search space by a factor of roughly $|S|/\mathrm{mid}(F)$, where $|S|$ denotes the number of tokens in the corpus and $\mathrm{mid}(F)$ denotes the median word frequency in the corpus.
(2) The nearest neighbor searches for all source tokens on the source side can be run in parallel, which is another key speedup over kNN-MT. In vanilla kNN-MT, kNN search is performed on the target side and has to be autoregressive: the representation for the current decoding step, which is used for the kNN search over the entire corpus, depends on the previously generated tokens. Therefore, the kNN search for the current step has to wait until the kNN searches for all previous generation steps have finished.
(3) On the target side, the datastore in the kNN search is limited to target representations corresponding to selected reference source tokens.
Although the nearest neighbor search during decoding is autoregressive and thus cannot be run in parallel, its cost is fairly low: recall that the size of $D_{\text{target}}$ is at most $c \times n$. Across all settings, the largest value of $c$ is 512, and the size of $D_{\text{target}}$ is roughly 15k. Performing nearest neighbor search among 15k candidates is relatively cheap for NMT, and is actually cheaper than the softmax operation for word prediction, where the vocabulary size is usually around 50k. Together, these aspects make Fast kNN-MT two orders of magnitude faster than vanilla kNN-MT.
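A back-of-the-envelope check on the 15k figure. The bound $|D_{\text{target}}| \le c \times n$ with $c = 512$ yields about 15k candidates if the average source length $n$ is around 30 tokens; that length is an assumption inferred from the two numbers above, not a reported statistic. Assuming a brute-force scan within $D_{\text{target}}$, the target-side decoding cost then scales with $cn$ rather than $|S|$:

$$
\begin{aligned}
|D_{\text{target}}| &\le c \times n \approx 512 \times 30 = 15{,}360 \approx 15\text{k} \\
\frac{|D_{\text{target}}|}{|V|} &\approx \frac{15{,}360}{50{,}000} \approx 0.31 \\
\text{target-side search cost} &: \; O(c\,n\,B\,m) \quad \text{vs.} \quad O(|S|\,B\,m) \text{ for vanilla kNN-MT}
\end{aligned}
$$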

Bilingual Machine Translation
We conduct experiments on two bilingual machine translation datasets: WMT'14 English-French and WMT'19 German-English. To create the datastore, we follow prior work in applying language identification filtering, keeping only sentence pairs with the correct language on both sides. We also remove sentences longer than 250 tokens and sentence pairs with a source/target length ratio exceeding 1.5. For all datasets, we use the standard Transformer model provided by the FairSeq library. The model has 6 encoder layers and 6 decoder layers; the dimensionality of word representations is 1024, the number of attention heads is 16, and the inner dimensionality of the feed-forward layers is 8192. Following (Khandelwal et al., 2020), the model for WMT'19 German-English has additionally been trained on over 10 billion tokens of back-translation data and fine-tuned on newstest sets from previous years. We report SacreBLEU scores (Post, 2018) for comparison. Table 1 shows our results on the two NMT datasets. The proposed Fast kNN-MT model achieves slightly better results than the vanilla kNN-MT model on WMT'19 German-English and competitive results on WMT'14 English-French, at a much lower kNN search cost.
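The length-based filters can be sketched as a single predicate. Two details are assumptions: the ratio is checked in both directions (the longer side may be at most 1.5 times the shorter), and pairs with an empty side are dropped. Language identification filtering is not shown, as it depends on an external language-ID model.

```python
def keep_pair(src_tokens, tgt_tokens, max_len=250, max_ratio=1.5):
    """Length and length-ratio filter for one sentence pair.

    Assumption: the ratio is checked in both directions, i.e. the longer
    side may be at most max_ratio times the length of the shorter side.
    """
    ls, lt = len(src_tokens), len(tgt_tokens)
    if ls == 0 or lt == 0:
        return False  # empty side: drop (assumption)
    if ls > max_len or lt > max_len:
        return False
    return max(ls, lt) / min(ls, lt) <= max_ratio
```

For example, a 10-token source with a 15-token target is kept (ratio exactly 1.5), while a 16-token target is filtered out.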

Domain Adaptation
We also evaluate the effectiveness of the proposed Fast kNN-MT model on the domain adaptation task. We use the multi-domain datasets originally provided by Koehn and Knowles (2017) and further cleaned by Aharoni and Goldberg (2020). These datasets include German-English parallel data with train/validation/test splits in five domains: Medical, Law, IT, Koran and Subtitles. We use the trained German-English model introduced in Section 5.1 as our base model, and build domain-specific datastores to evaluate the performance of Fast kNN-MT on each domain. Table 2 shows the in-domain results. Following Khandelwal et al. (2020), we also carry out experiments under the out-of-domain and multi-domain settings and report those results in Table 2 as well. "+ WMT'19 datastore" shows the results of retrieving neighbors from 770M tokens of the WMT'19 data on which the model has been trained, and "+ all-domain datastore" shows the results of retrieving from a multi-domain datastore that combines all six settings. The BLEU improvement is much smaller in the out-of-domain setup than in the in-domain setup, indicating that the framework relies on in-domain data to retrieve valuable contexts. In the multi-domain setup, the performance on all six domains stays roughly the same, with only a small drop in the average score. This shows that Fast kNN-MT is robust to a massive amount of out-of-domain data and is able to retrieve context-related information from in-domain data.

Analysis
Examples To visualize the effectiveness of the proposed Fast kNN-MT model, we randomly choose an example from the test set of the Law domain. Table 3 shows the example, the nearest neighbor tokens on the source side, and the corresponding target tokens. The first panel of Figure 2 shows the similarity heatmap between the gold target tokens and the selected target neighbors. The retrieved target nearest tokens are highly correlated with the ground-truth target tokens, showing that Fast kNN-MT can accurately extract the nearest reference tokens at each decoding step.

Table 3 (surviving excerpt; BPE-segmented English sentences):
Fur@@ ther@@ more , two Community producers in Greece who took part in the previous investigation ce@@ ased their activity .
Cer@@ tain establi@@ sh@@ ments have ce@@ ased their activities .
Table 3: A test sentence pair from the Law domain. We show the original sentence pair for test (the first row), the nearest-neighbor tokens on the source side along with the sentences in which the retrieved tokens reside (the second column), and the aligned target tokens extracted by FastAlign, along with the sentences in which the target tokens reside (the third column). The retrieved tokens are in red.
The Effect of the Number of Neighbors per Token on the Source Side We query the datastore for the nearest $c$ neighbors of each source token. Intuitively, the larger $c$ is, the more likely the model is to recall the nearest neighbors on the target side. The second panel of Figure 2 confirms this: model performance increases sharply as $c$ grows from 8 to 64, and continues to increase as $c$ grows to 512.
The Effect of the Number of Neighbors per Token on the Target Side Fast kNN-MT selects the top $k$ nearest neighbors at each decoding step to compute $p_{\mathrm{kNN}}$ in Eq. (2). The third panel of Figure 2 shows that, with $c$ fixed at 512, model performance first increases and then decreases as $k$ grows, which is consistent with the observation in (Khandelwal et al., 2020). This is because neighbors that are too far from the ground-truth target token add noise to the model prediction and thus hurt performance.

In terms of decoding speed, kNN-MT is two orders of magnitude slower than both the base MT model and Fast kNN-MT. This is because Fast kNN-MT substantially restricts the search space during decoding, whereas vanilla kNN-MT has to run kNN search over the entire datastore at each decoding step.

Similarity function
We tried different similarity functions for retrieving the $c$ nearest neighbors on the source side and for computing the kNN distribution: cosine similarity, inner product and $L_2$ distance. Their SacreBLEU scores on WMT'19 German-English are 39.2, 39.1 and 38.8, respectively, suggesting that cosine similarity is a better measure of representation distance than inner product and $L_2$ distance.
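The three similarity functions compared above can be sketched as follows; the function names and toy vectors are illustrative. The example shows one concrete way they disagree: inner product favors a candidate with a large norm, whereas cosine similarity and (negated) $L_2$ distance here both prefer the candidate pointing in the query's direction. This does not by itself explain the BLEU ranking reported above.

```python
import math


def inner_product(a, b):
    return sum(x * y for x, y in zip(a, b))


def cosine_similarity(a, b):
    """Inner product of the L2-normalized vectors."""
    return inner_product(a, b) / (math.hypot(*a) * math.hypot(*b))


def negative_l2(a, b):
    """Higher means more similar, so the distance is negated."""
    return -math.dist(a, b)


def most_similar(query, candidates, similarity):
    """Name of the candidate with the highest similarity to the query."""
    return max(candidates, key=lambda name: similarity(query, candidates[name]))


query = [1.0, 0.0]
candidates = {"long": [10.0, 5.0], "aligned": [1.0, 0.1]}
```

Here `most_similar(query, candidates, inner_product)` picks "long", while `cosine_similarity` and `negative_l2` both pick "aligned".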
Effect of quantization Due to memory constraints, we apply quantization to compress the high-dimensional representation of each token in the training set. We investigate how quantization affects model performance. As shown in Table 4, quantization has only minor side effects on BLEU: when we use full precision instead of quantization, the average BLEU score increases by only 0.1, which suggests that computing similarity with compressed vectors is a viable trade-off between memory usage and model performance.

Footnote 5: $k$ plays a minor role in the overall time complexity because each search on the target side is performed within a total of $cn$ tokens, which is negligible compared to the time spent on the source side.

Conclusion
In this work, we propose a fast version of kNN-MT, Fast kNN-MT, to address the runtime complexity of vanilla kNN-MT. During decoding, Fast kNN-MT constructs a significantly smaller datastore for nearest neighbor search: for each word in a source sentence, Fast kNN-MT selects its nearest tokens from a large-scale cache, restricted to tokens that are the same as the query token. Then, at each decoding step, instead of using the entire datastore, the search space is limited to the target tokens corresponding to the previously selected reference source tokens. Experiments demonstrate that this strategy drastically improves decoding efficiency while maintaining performance comparable to vanilla kNN-MT across settings including bilingual machine translation and domain adaptation. Comprehensive ablation studies are performed to understand the behavior of each component of Fast kNN-MT. In future work, we plan to further improve the efficiency of Fast kNN-MT by applying clustering techniques to build the datastore.