Adaptive Nearest Neighbor Machine Translation

kNN-MT, recently proposed by Khandelwal et al. (2020a), successfully combines a pre-trained neural machine translation (NMT) model with token-level k-nearest-neighbor (kNN) retrieval to improve translation accuracy. However, the traditional kNN algorithm used in kNN-MT simply retrieves the same number of nearest neighbors for each target token, which may cause prediction errors when the retrieved neighbors include noise. In this paper, we propose Adaptive kNN-MT to dynamically determine the number of neighbors k for each target token. We achieve this by introducing a light-weight Meta-k Network, which can be efficiently trained with only a few training samples. On four benchmark machine translation datasets, we demonstrate that the proposed method effectively filters out noise in the retrieval results and significantly outperforms the vanilla kNN-MT model. Even more noteworthy, the Meta-k Network learned on one domain can be directly applied to other domains and obtain consistent improvements, illustrating the generality of our method. Our implementation is open-sourced at https://github.com/zhengxxn/adaptive-knn-mt.

kNN-MT, recently proposed by Khandelwal et al. (2020a), equips a pre-trained NMT model with a kNN classifier over a datastore of cached context representations and corresponding target tokens, providing a simple yet effective strategy for utilizing cached contextual information at inference time. However, the hyper-parameter k is fixed for all cases, which raises potential problems. Intuitively, the retrieved neighbors may include noise when the target token is relatively hard to determine (e.g., when relevant context is scarce in the datastore). Empirically, we find that the translation quality is very sensitive to the choice of k, resulting in poor robustness and generalization performance.
To tackle this problem, we propose Adaptive kNN-MT, which determines the choice of k for each target token adaptively. Specifically, instead of using a fixed k, we consider a set of possible values of k that are smaller than an upper bound K. Then, given the retrieval results for the current target token, a light-weight Meta-k Network estimates the importance of each possible k-nearest-neighbor result, and these results are aggregated accordingly to obtain the model's final decision. In this way, our method dynamically evaluates and utilizes the neighbor information conditioned on each target token, thereby improving translation performance.
We conduct experiments on multi-domain machine translation datasets. Across four domains, our approach achieves 1.44∼2.97 BLEU score improvements over the vanilla kNN-MT on average when K ≥ 4. The introduced light-weight Meta-k Network requires only thousands of parameters and can be easily trained with a few training samples. In addition, we find that the Meta-k Network learned on one domain can be directly applied to other domains and obtain consistent improvements.

Background: kNN-MT

In this section, we briefly introduce the background of kNN-MT, which involves two steps: creating a datastore and making predictions based on it.
Datastore Creation. The datastore consists of a set of key-value pairs. Formally, given a bilingual sentence pair in the training set $(x, y) \in (\mathcal{X}, \mathcal{Y})$, a pre-trained autoregressive NMT decoder translates the $t$-th target token $y_t$ based on the translation context $(x, y_{<t})$. Denoting the hidden representation of the translation context as $f(x, y_{<t})$, the datastore is constructed by taking $f(x, y_{<t})$ as keys and $y_t$ as values:
$$(\mathcal{K}, \mathcal{V}) = \{(f(x, y_{<t}),\, y_t),\ \forall y_t \in y,\ (x, y) \in (\mathcal{X}, \mathcal{Y})\}.$$
Therefore, the datastore can be created through a single forward pass over the training set $(\mathcal{X}, \mathcal{Y})$.
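To make this step concrete, here is a minimal PyTorch sketch of datastore creation. It is an illustration under assumptions, not the paper's actual code: `decoder_hidden_states` is a hypothetical accessor for the decoder's final-layer representations, and the data loader, padding id, and tensor shapes are assumed.

```python
import numpy as np
import torch

@torch.no_grad()
def build_datastore(model, data_loader):
    """Collect (f(x, y_<t), y_t) key-value pairs with one teacher-forced pass."""
    keys, values = [], []
    for src, tgt in data_loader:
        # hidden: [batch, tgt_len, dim], the context representations f(x, y_<t)
        hidden = model.decoder_hidden_states(src, tgt)  # hypothetical accessor
        mask = tgt.ne(model.pad_id)                     # skip padding positions
        keys.append(hidden[mask].cpu().numpy())
        values.append(tgt[mask].cpu().numpy())
    return np.concatenate(keys, axis=0), np.concatenate(values, axis=0)
```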
Prediction. During inference, at each decoding step $t$, the kNN-MT model aims to predict $\hat{y}_t$ given the already generated tokens $\hat{y}_{<t}$ as well as the context representation $f(x, \hat{y}_{<t})$, which is used to query the datastore for the $k$ nearest neighbors w.r.t. the $l_2$ distance. Denoting the retrieved neighbors as $N_t = \{(h_i, v_i),\, i \in \{1, 2, \ldots, k\}\}$, their distribution over the vocabulary is computed as:
$$p_{\mathrm{kNN}}(y_t|x, \hat{y}_{<t}) \propto \sum_{(h_i, v_i) \in N_t} \mathbb{1}_{y_t = v_i} \exp\!\left(\frac{-d(h_i, f(x, \hat{y}_{<t}))}{T}\right), \quad (1)$$
where $T$ is the temperature and $d(\cdot, \cdot)$ indicates the $l_2$ distance. The final probability when predicting $y_t$ is calculated as the interpolation of the two distributions with a hyper-parameter $\lambda$:
$$p(y_t|x, \hat{y}_{<t}) = \lambda\, p_{\mathrm{kNN}}(y_t|x, \hat{y}_{<t}) + (1 - \lambda)\, p_{\mathrm{NMT}}(y_t|x, \hat{y}_{<t}), \quad (2)$$
where $p_{\mathrm{NMT}}$ indicates the vanilla NMT prediction.
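The following sketch implements Equations (1) and (2) in PyTorch, assuming the retrieved distances and value tokens are already available as tensors; the scatter-add accumulates the softmax weights of neighbors that share the same target token.

```python
import torch
import torch.nn.functional as F

def knn_distribution(dists, vals, vocab_size, T=10.0):
    """Equation (1): p_kNN over the vocabulary.
    dists: [batch, k] l2 distances d(h_i, f(x, y_<t)); vals: [batch, k] token ids v_i (long).
    """
    weights = F.softmax(-dists / T, dim=-1)             # normalized exp(-d/T)
    p_knn = dists.new_zeros(dists.size(0), vocab_size)
    p_knn.scatter_add_(1, vals, weights)                # sum weights per token id
    return p_knn

def interpolate(p_knn, p_nmt, lam):
    """Equation (2): fixed-lambda interpolation with the NMT distribution."""
    return lam * p_knn + (1.0 - lam) * p_nmt
```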

Adaptive kNN-MT
The vanilla kNN-MT method utilizes a fixed number of translation contexts for every target token, which fails to exclude the noise contained in retrieved neighbors when there are not enough relevant items in the datastore. We show an example with k = 32 in Figure 1. The correct prediction spreadsheet has been retrieved among the top candidates. However, the model will ultimately predict table instead, because it appears more frequently in the datastore than the correct prediction. A naive way to filter the noise is to use a small k, but this causes over-fitting problems in other cases. In fact, the optimal choice of k varies across datastores in vanilla kNN-MT, leading to poor robustness and generalizability of the method, which is empirically discussed in Section 4.2.
To tackle this problem, we propose a dynamic method that allows each untranslated token to utilize a different number of neighbors. Specifically, we consider a set of possible values of k that are smaller than an upper bound K, and introduce a light-weight Meta-k Network to estimate the importance of each choice. In practice, we consider the powers of 2 as the choices of k for simplicity, as well as k = 0, which indicates ignoring kNN and only utilizing the NMT model, i.e., $k \in S$ where $S = \{0, 1, 2, 4, \ldots, K\}$. The Meta-k Network then evaluates the probability of each kNN result by taking the retrieved neighbors as input, as sketched below.
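For instance, with K = 32 the candidate set is S = {0, 1, 2, 4, 8, 16, 32}; a small helper makes this explicit:

```python
def candidate_ks(K):
    """S = {0} + powers of 2 up to K; k = 0 means 'use the NMT model only'."""
    ks, k = [0], 1
    while k <= K:
        ks.append(k)
        k *= 2
    return ks

assert candidate_ks(32) == [0, 1, 2, 4, 8, 16, 32]
```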
Concretely, at the $t$-th decoding step, we first retrieve $K$ neighbors $N_t$ from the datastore, and for each neighbor $(h_i, v_i)$, we calculate its distance from the current context representation as $d_i = d(h_i, f(x, \hat{y}_{<t}))$. Denoting $d = (d_1, \ldots, d_K)$ as the distances and $c = (c_1, \ldots, c_K)$ as the counts of distinct values among the top retrieved neighbors, we concatenate them as the input features to the Meta-k Network. The reasons for doing so are two-fold. Intuitively, the distance of each neighbor is the most direct evidence when evaluating its importance. In addition, the value distribution of the retrieved results is also crucial for making the decision: if the values of the retrieved results are all distinct, then the kNN predictions are less credible and we should depend more on the NMT prediction.
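A sketch of the feature construction, where we read the count feature as the number of distinct values among the top-i neighbors (our reading of the description above); the loop is written for clarity rather than speed.

```python
import torch

def meta_features(dists, vals):
    """Build [d; c] for the Meta-k Network.
    dists: [batch, K] distances d_i, sorted ascending; vals: [batch, K] token ids.
    c_i = number of distinct values among the top-i neighbors (assumed reading).
    """
    batch, K = vals.shape
    counts = torch.zeros(batch, K, device=dists.device)
    for b in range(batch):
        for i in range(K):
            counts[b, i] = vals[b, : i + 1].unique().numel()
    return torch.cat([dists.float(), counts], dim=-1)   # [batch, 2K]
```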
We construct the Meta-k Network $f_{\mathrm{Meta}}(\cdot)$ as two feed-forward networks with a non-linearity between them. Given $[d; c]$ as input, the probability of applying each kNN result is computed as:
$$p_{\mathrm{Meta}}(k) = \mathrm{softmax}(f_{\mathrm{Meta}}([d; c])). \quad (3)$$
Prediction. Instead of introducing the hyper-parameter $\lambda$ as in Equation (2), we aggregate the NMT model and the different kNN predictions with the output of the Meta-k Network to obtain the final prediction:
$$p(y_t|x, \hat{y}_{<t}) = \sum_{k_i \in S} p_{\mathrm{Meta}}(k_i) \cdot p_{k_i \mathrm{NN}}(y_t|x, \hat{y}_{<t}), \quad (4)$$
where $p_{k_i \mathrm{NN}}$ indicates the $k_i$-nearest-neighbor prediction calculated as in Equation (1), and $k_i = 0$ falls back to the vanilla NMT prediction.
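A minimal PyTorch sketch of the Meta-k Network and the aggregation of Equation (4); the hidden size and the Tanh non-linearity are assumptions, and `n_choices` equals |S| (e.g., 7 for K = 32, per the helper above).

```python
import torch
import torch.nn as nn

class MetaKNetwork(nn.Module):
    """Two feed-forward layers with a non-linearity in between (Equation (3))."""
    def __init__(self, K, n_choices, hidden=32):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(2 * K, hidden),      # input features [d; c]
            nn.Tanh(),                     # assumed non-linearity
            nn.Linear(hidden, n_choices),  # one logit per k_i in S
        )

    def forward(self, features):
        return torch.softmax(self.net(features), dim=-1)  # p_Meta(k_i)

def aggregate(p_meta, p_nmt, p_knn_per_k):
    """Equation (4): sum_i p_Meta(k_i) * p_{k_i NN}; index 0 is k = 0 (pure NMT)."""
    stacked = torch.stack([p_nmt] + p_knn_per_k, dim=1)  # [batch, |S|, vocab]
    return (p_meta.unsqueeze(-1) * stacked).sum(dim=1)
```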
Training. We fix the pre-trained NMT model and only optimize the Meta-k Network by minimizing the cross-entropy loss over the final prediction of Equation (4), which is very efficient and requires only hundreds of training samples.
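A hedged training sketch under the description above: the NMT model stays frozen and only the Meta-k parameters are updated with Adam (lr 3e-4, as reported in the experimental setup); the per-k kNN distributions are assumed to be precomputed with `knn_distribution` from the background sketch.

```python
import torch
import torch.nn.functional as F

meta_k = MetaKNetwork(K=32, n_choices=7)
optimizer = torch.optim.Adam(meta_k.parameters(), lr=3e-4)

def train_step(features, p_nmt, p_knn_per_k, gold):
    """One cross-entropy step on the aggregated prediction of Equation (4)."""
    p_final = aggregate(meta_k(features), p_nmt, p_knn_per_k)
    loss = F.nll_loss(torch.log(p_final + 1e-9), gold)  # cross entropy on p(y_t)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```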

Experimental Setup
We evaluate the proposed model on domain adaptation machine translation tasks, in which a pre-trained general-domain NMT model is used to translate domain-specific sentences with kNN search over an in-domain datastore. This is the most appealing application of kNN-MT, as it can achieve results comparable to an in-domain NMT model without training on any in-domain data. We denote the proposed model as Adaptive kNN-MT (A) and compare it with two baselines: vanilla kNN-MT (V) and uniform kNN-MT (U), in which we set equal confidence for each kNN prediction.
Datasets and Evaluation Metric. We use the same multi-domain dataset as the baseline (Khandelwal et al., 2020a), and consider the IT, Medical, Koran, and Law domains in our experiments. The sentence statistics of the datasets are given in Table 1. The Moses toolkit is used to tokenize the sentences, and words are split into subword units (Sennrich et al., 2016) with the bpecodes provided with the pre-trained NMT model. We use SacreBLEU to measure all results with case-sensitive detokenized BLEU (Papineni et al., 2002). For the Meta-k Network, we directly use the dev set (about 2k sentences) for training, running about 5k steps. We use Adam (Kingma and Ba, 2015) to optimize our model; the learning rate is set to 3e-4 and the batch size to 32 sentences.

Main Results
The experimental results are listed in Table 2. We observe that the proposed Adaptive kNN-MT significantly outperforms the vanilla kNN-MT on all domains, illustrating the benefit of dynamically determining and utilizing the neighbor information for each target token. In addition, the performance of the vanilla model is sensitive to the choice of K, while our proposed model is more robust, with smaller variance. More specifically, our model achieves better results when choosing a larger number of neighbors, while the vanilla model suffers performance degradation when K = 32, indicating that the proposed Meta-k Network is able to effectively evaluate and filter the noise in retrieved neighbors, while a fixed K cannot. We also compare our method with another naive baseline, uniform kNN-MT, which sets equal confidence for each kNN prediction and behaves similarly to vanilla kNN-MT with small k. This comparison further demonstrates that our method learns a useful aggregation rather than simply biasing towards smaller k.
Generality. To demonstrate the generality of our method, we directly use the Meta-k Network trained on the IT domain to evaluate other domains. For example, we use the Meta-k Network trained on the IT domain together with the Medical datastore to evaluate performance on the Medical test set. For comparison, we collect the in-domain results from Table 2. We set K = 32 for both settings. As shown in Table 3, the Meta-k Network trained on the IT domain achieves performance on all other domains comparable to re-training the Meta-k Network on in-domain data. These results also indicate that the mapping from our designed features to the confidence of retrieved neighbors is shared across different domains.
Robustness. We also evaluate the robustness of our method in a domain-mismatch setting, in which the user inputs an out-of-domain sentence (e.g., from the IT domain) to a domain-specific translation system (e.g., Medical domain). Specifically, in the IT ⇒ Medical setting, we first use the Medical dev set and datastore to tune the hyper-parameters of vanilla kNN-MT or train the Meta-k Network of Adaptive kNN-MT, and then test on the IT test set with the Medical datastore. We set K = 32 in this experiment. As shown in Table 4, the retrieved results are highly noisy, so vanilla kNN-MT suffers a drastic performance degradation. In contrast, our method effectively filters out the noise and therefore limits the degradation as much as possible.
Case Study. Table 5 shows a translation example selected from the Medical test set with K = 32.

Source: Wenn eine gleichzeitige Behandlung mit Vitamin K Antagonisten erforderlich ist, müssen die Angaben in Abschnitt 4.5 beachtet werden.
Reference: therapy with vitamin K antagonist should be administered in accordance with the information of Section 4.5.
Base NMT: If a simultaneous treatment with vitamin K antagonists is required, the information in section 4.5 must be observed.
kNN-MT: If concomitant treatment with vitamin K antagonists is required, please refer to section 4.5.
Adaptive kNN-MT: When required, concomitant therapy with vitamin K antagonist should be administered in accordance with the information of Section 4.5.

We can observe that the Meta-k Network determines the choice of k for each target token individually, based on which Adaptive kNN-MT better leverages the in-domain datastore to achieve proper word selection and language style.
Analysis. Finally, we study the effect of the two designed features, the number of training sentences, and the hidden size of the Meta-k Network. We conduct these ablation studies on the IT domain with K = 8. The results are summarized in Table 6 and Figure 2. Both designed features contribute significantly to the performance of our model, with the distance feature being the more important one. Surprisingly, our model outperforms the vanilla kNN-MT with only 100 training sentences, or with a hidden size of 8 that contains only around 0.6k parameters, showing the efficiency of our approach.

Conclusion and Future Work
In this paper, we propose the Adaptive kNN-MT model, which dynamically determines the utilization of retrieved neighbors for each target token by introducing a light-weight Meta-k Network. In experiments on domain adaptation machine translation tasks, we demonstrate that our model effectively filters the noise in retrieved neighbors and significantly outperforms the vanilla kNN-MT baseline. In addition, the generality and robustness of our method are also verified. In the future, we plan to extend our method to other tasks such as language modeling and question answering, which can also benefit from kNN search (Khandelwal et al., 2020b; Kassner and Schütze, 2020).

A.1 Datastore Creation
We first save the key-value pairs over the training sets as the datastore using numpy arrays. Then, faiss is used to build an index for each datastore to enable fast nearest-neighbor search. We utilize faiss to learn 4k cluster centroids for each domain, and search 32 clusters for each target token during decoding. The size of each datastore (count of target tokens) and the hard-disk space of the datastore and faiss index are shown in Table 7.
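A sketch of the index setup described above using the standard faiss IVF API; the released implementation may use a different (e.g., quantized) index variant, so treat this as an assumption-laden illustration with 4096 centroids learned and 32 clusters probed per query.

```python
import faiss
import numpy as np

def build_index(keys, n_centroids=4096, n_probe=32):
    """Train an IVF index on the datastore keys for fast l2 search."""
    keys = keys.astype(np.float32)
    dim = keys.shape[1]
    quantizer = faiss.IndexFlatL2(dim)
    index = faiss.IndexIVFFlat(quantizer, dim, n_centroids, faiss.METRIC_L2)
    index.train(keys)         # learn the 4k cluster centroids
    index.add(keys)
    index.nprobe = n_probe    # clusters searched per query at decoding time
    return index

# distances, ids = index.search(queries.astype(np.float32), k)
```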

A.2 Hyper-Parameter Tuning for kNN-MT
The performance of vanilla kNN-MT is highly dependent on the choice of the hyper-parameters k, T, and λ. We fix T to 10 for IT, Medical, and Law, and to 100 for Koran in all experiments. We then tuned k and λ for each domain when using kNN-MT; the optimal choices per domain are shown in Table 9. The performance of kNN-MT is unstable across hyper-parameter settings, while our Adaptive kNN-MT avoids this problem.
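For reference, the tuning described above amounts to a small grid search on the dev set; `bleu_on_dev` is a hypothetical helper that decodes the dev set with the given hyper-parameters and returns its BLEU score.

```python
# Hypothetical grid search over k and lambda with T fixed per domain.
def tune_knn_mt(bleu_on_dev, ks=(1, 2, 4, 8, 16, 32),
                lambdas=(0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9)):
    return max(((k, lam) for k in ks for lam in lambdas),
               key=lambda cfg: bleu_on_dev(k=cfg[0], lam=cfg[1]))
```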

A.3 Decoding Time
We compare the decoding time on the IT test set of the base NMT model, kNN-MT (our replication), and Adaptive kNN-MT under different batch sizes. In decoding, the beam size is set to 4 with length penalty 0.6. The results are summarized in Table 8.