Efficient Cluster-Based k-Nearest-Neighbor Machine Translation

k-Nearest-Neighbor Machine Translation (kNN-MT) has been recently proposed as a non-parametric solution for domain adaptation in neural machine translation (NMT). It aims to alleviate the performance degradation of advanced MT systems in translating out-of-domain sentences by coordinating with an additional token-level feature-based retrieval module constructed from in-domain data. Previous studies (Khandelwal et al., 2021; Zheng et al., 2021) have already demonstrated that non-parametric NMT is even superior to models fine-tuned on out-of-domain data. In spite of this success, kNN retrieval comes at the cost of high latency, in particular for large datastores. To make it practical, in this paper, we explore a more efficient kNN-MT and propose to use clustering to improve the retrieval efficiency. Concretely, we first propose a cluster-based Compact Network for feature reduction in a contrastive learning manner to compress context features into vectors with 90+% lower dimensionality. We then suggest a cluster-based pruning solution to filter out 10%~40% redundant nodes in large datastores while retaining translation quality. Our proposed methods achieve better or comparable performance while reducing up to 57% inference latency against the advanced non-parametric MT model on several machine translation benchmarks. Experimental results indicate that the proposed methods maintain the most useful information of the original datastore and that the Compact Network generalizes well to unseen domains. Code is available at https://github.com/tjunlp-lab/PCKMT.


Introduction
Recently, non-parametric approaches (Khandelwal et al., 2021; Zheng et al., 2021a,b; Jiang et al., 2021) have been successfully applied to neural machine translation (NMT) for domain adaptation with retrieval pipelines. Given an advanced MT model, they generally involve two steps:
• A cached memory, usually called a datastore, is built in advance by extracting the context representations of the penultimate layer of the given NMT model corresponding to each target token from in-domain data.
• At inference, it retrieves the k nearest neighbors of the context representation for each generated token from the constructed datastore and then integrates external kNN translation probabilities derived from these retrievals to adjust the translation.
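The two steps above can be sketched in a few lines. This is a toy illustration with an in-memory datastore and brute-force distance search (real systems use an approximate index such as Faiss); the function names and the temperature value are illustrative, not the released implementation:

```python
# Toy sketch of datastore creation and kNN retrieval for kNN-MT.
import numpy as np

def build_datastore(hidden_states, target_tokens):
    """Step 1: cache context representations as keys, target tokens as values."""
    keys = np.asarray(hidden_states, dtype=np.float32)   # shape (N, d)
    values = list(target_tokens)                          # length N
    return keys, values

def knn_probs(query, keys, values, k=4, temperature=10.0):
    """Step 2: retrieve k nearest neighbors and derive kNN translation probs."""
    d2 = ((keys - query) ** 2).sum(axis=1)               # squared l2 distances
    idx = np.argsort(d2)[:k]                             # k nearest neighbors
    weights = np.exp(-d2[idx] / temperature)             # distance-based kernel
    probs = {}
    for j, w in zip(idx, weights):
        probs[values[j]] = probs.get(values[j], 0.0) + w
    z = sum(probs.values())
    return {tok: w / z for tok, w in probs.items()}      # normalized distribution
```

In the full system this distribution is interpolated with the NMT model's own prediction rather than used alone.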
The accessibility of any provided datastore during translation makes these approaches interpretable. Meanwhile, their reliability depends on the quality of the datastore. In spite of significant translation improvements, the behavior of the datastore has not yet been fully analyzed. We empirically observe that the construction of the datastore is suboptimal for retrieval in two aspects: retrieval latency and semantic distribution.
Retrieval Latency. As shown in Table 1, we compare both translation performance and speed between a pre-trained NMT model (Ng et al., 2019) with 270M parameters and the adaptive kNN-MT (Zheng et al., 2021a) system built on top of it, on the same hardware (a P100-16GB GPU with an 18-core Intel Xeon Gold 6240 CPU @ 2.60GHz); the latter is the most advanced retrieval-based NMT model so far.¹ The results indicate that the heavy computation of retrieval within a datastore increases latency and makes the approach less practical in real-time scenarios. To address this problem, we propose an efficient pruning strategy that decreases datastore redundancy, thereby handling the trade-off between speed and quality.
Semantic Distribution. For robust token-to-token retrieval, tokens with similar context are expected to be distributed close to each other, forming separable and compact semantic clusters; otherwise, semantic noise may hurt retrieval effectiveness.
To explore the potential of k-nearest retrieval, we visualize the feature distribution of a datastore built on the IT-domain corpus (Koehn and Knowles, 2017) in Figure 1. For the datastore constructed in the traditional way, we have two important findings. One is that the majority of tokens are distributed in the overlapped area regardless of frequency. The other is that even though the overall distribution shows a clustering effect, only a few small clusters are correctly grouped with respect to frequency. Intuitively, these findings directly and negatively affect distance-based retrieval.
Moreover, as Zhang et al. (2021) suggest, the feature dimension is highly related to retrieval speed. Preliminary studies on kNN-LM (He et al., 2021) indicate that traditional feature reduction algorithms can only maintain the original performance until the context feature dimension is reduced to a minimum required size (e.g., for feature dimension 1024, PCA requires at least 512). For an NMT model, it remains challenging to reduce the feature dimension to 10% of its original size (e.g., from 1024 to <100). To tackle this problem, we design a cluster-based training strategy in which an external light-weight feature reduction network is learned in a contrastive manner to maximize the margin between context semantic clusters. In our experiments, we can cut out up to 93.75% of the original feature size.

¹ The speed comparison is based on the implementation released at https://github.com/zhengxxn/adaptive-knn-mt
In summary, our main contributions are twofold:
• We propose a cluster-based Compact Network to reduce the dimension of the semantic representations and to improve the translation performance by making different tokens separable, which refines the retrieval results.
• We further propose a cluster-based pruning strategy by filtering redundant representations in the datastore so that our proposed methods could significantly decrease the translation latency during inference.
Experiments on multi-domain machine translation benchmarks indicate that our proposed methods are superior to existing retrieval-based machine translation systems in terms of both speed and quality.

Related Work and Background
In this section, we briefly introduce the background of the adaptive kNN-MT (Zheng et al., 2021a). Adaptive kNN-MT is derived from kNN-MT (Khandelwal et al., 2021) by inserting a light-weight Meta-k Network that fuses kNN retrievals with various k to alleviate the possible noise induced by a single k. Formally, it is formulated as two steps: target-side datastore creation and Meta-k Network prediction.

Target-side Datastore Creation. The datastore consists of a set of key-value pairs. Given a bilingual sentence pair (s, t) in a corpus (S, T), a pre-trained general-domain NMT model autoregressively extracts the context representation h_i of the i-th target token conditioned on both the source and target context (s, t_<i), denoted as h_i = f(s, t_<i). The datastore is finally constructed by taking h_i as keys and t_i as values:

(K, V) = {(h_i, t_i) | ∀ t_i ∈ t, (s, t) ∈ (S, T)}.

Meta-k Network Prediction. The Meta-k Network (f_β) is a two-layer feed-forward network followed by a non-linear activation function. Based on the constructed datastore, it considers a set of different values of k that are smaller than an upper bound K. The standard setting for k is Q = {0} ∪ {k_r ∈ N | log_2 k_r ∈ N, k_r ≤ K}. At the i-th decoding step, the K nearest neighbors of the current context query ĥ_i are first retrieved from the datastore. The squared l_2 distance from ĥ_i to each neighbor (h_j, v_j) is denoted as d_j = ||h_j − ĥ_i||², and the number of distinct values among the top j neighbors is denoted as c_j. The normalized weight of each available k is computed as:

p_β(k_r) = softmax(f_β([d_1, ..., d_K; c_1, ..., c_K])),

where f_β denotes the Meta-k Network.

Figure 2: The diagram of the proposed approach. C-*("#") denotes the *-th cluster of token "#". First, the cluster-based Compact Network is used to reduce the keys' dimensionality of the original datastore and a new datastore is reconstructed. Then the cluster-based pruning is applied to reduce the datastore size.
For each k_r ∈ Q, the word prediction probability over the vocabulary w.r.t. the retrieved neighbors is computed via the Gaussian kernel function:

p_{k_r}(y_i | s, t_<i) ∝ Σ_{j ≤ k_r} 1(y_i = v_j) · exp(−d_j / T),

where T denotes the temperature hyper-parameter. The ultimate prediction probability is a weighted ensemble:

p(y_i | s, t_<i) = Σ_{k_r ∈ Q} p_β(k_r) · p_{k_r}(y_i | s, t_<i).

Note that a validation set is usually required to train the Meta-k Network before predicting on test sets. During training, only the parameters of the Meta-k Network are updated.
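The weighted ensemble above can be sketched as follows. Since the trained weights of f_β are not given, a stand-in scorer over the distance features replaces it, and the count feature c_j is omitted for brevity; K=4 and T=10 are illustrative:

```python
# Sketch of the Meta-k ensemble over Q = {0, 1, 2, 4} (for K = 4).
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

def meta_k_ensemble(d, v, p_nmt, vocab, T=10.0, K=4):
    """d: squared distances of the K neighbors, v: their values, p_nmt: NMT distribution."""
    Q = [0] + [k for k in (1, 2, 4, 8, 16, 32) if k <= K]
    # Stand-in for f_beta([d; c]): prefer smaller k when distances grow quickly.
    logits = np.array([0.0] + [-d[:k].mean() / T for k in Q[1:]])
    w = softmax(logits)                                  # p_beta(k_r), sums to 1
    p = w[0] * p_nmt                                     # k_r = 0 falls back to the NMT model
    for wi, k in zip(w[1:], Q[1:]):
        kern = np.exp(-d[:k] / T)                        # Gaussian kernel over distances
        p_k = np.zeros(len(vocab))
        for dist_w, val in zip(kern, v[:k]):
            p_k[vocab.index(val)] += dist_w
        p = p + wi * (p_k / kern.sum())                  # normalized per-k distribution
    return p
```

Because every component distribution is normalized, the ensemble remains a valid probability distribution over the vocabulary.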

Our Approach
As shown in Figure 2, our proposed approach focuses on datastore reconstruction from the perspectives of feature compression and size pruning by utilizing cluster-based signals.

Cluster-Based Feature Compression
From Figure 1, we observe that spatially close context representations may carry noisy and different semantics. During inference, this may yield unreliable neighbors for retrieval-based NMT (see examples in Appendix D "Case Analysis") due to entanglements in this noisy context space. We hypothesize that the reasons are three-fold. First, the NMT model pre-trained on the general domain lacks target-domain-specific knowledge. Second, the high-dimensional semantic space is too sparse and may contain noisy underlying components. Third, the likelihood-maximization objective computed from dot-product logits enforces the alignment of vector directions, which is inconsistent with the expectation of spatial closeness, which concerns both direction and length.
To address these issues, we propose a one-plus-one (f_α + f_θ) Compact Network on top of the pre-trained NMT model. The first "one" module transforms the coarse-grained semantics of the pre-trained NMT model into fine-grained semantic clusters. The second "one" module is used to calculate our designed loss functions.
To obtain coarse-grained semantic clusters, we first follow the method described in "Target-side Datastore Creation" of Section 2 to create the in-domain datastore. For context representations (keys) with the same target token (value), we conduct target-side clustering of the representations, shown as the left clusters in Figure 3. We denote the resulting clusters sharing the same value as the cluster family of the corresponding target token. Owing to the distance-based clustering, clusters within each cluster family are guaranteed not to overlap. However, different cluster families have large overlapping regions, as shown in Figure 1. Therefore, our main purpose is to construct a transform that makes the cluster families separable as well.
The proposed light-weight Compact Network in Figure 3 is designed to fulfill the above purpose and to compress the feature dimension. The first two-layer perceptron performs representation compression: f_α(·) = FFN_2(σ(FFN_1(·))), where σ(·) denotes the Sigmoid function. The last layer f_θ is attached to transfer the compressed representations into classification logits, where the output dimension depends on the number of designed categories. Note that the f_θ layer is discarded at inference.
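A minimal forward-pass sketch of this architecture in NumPy, assuming an input dimension of 1024, an output dimension of 64, a hidden width of 4× the output dimension (as set in Appendix A), and an illustrative single-logit f_θ:

```python
# Sketch of the one-plus-one Compact Network: f_alpha (compression) + f_theta (head).
import numpy as np

rng = np.random.default_rng(0)

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

class CompactNetwork:
    def __init__(self, d_in=1024, d_hidden=256, d_out=64, n_classes=1):
        self.W1 = rng.normal(0, 0.02, (d_in, d_hidden))
        self.b1 = np.zeros(d_hidden)
        self.W2 = rng.normal(0, 0.02, (d_hidden, d_out))
        self.b2 = np.zeros(d_out)
        self.W3 = rng.normal(0, 0.02, (d_out, n_classes))  # f_theta, training only
        self.b3 = np.zeros(n_classes)

    def f_alpha(self, h):
        """Compression: FFN_2(sigma(FFN_1(h))); its output rebuilds the datastore."""
        return sigmoid(h @ self.W1 + self.b1) @ self.W2 + self.b2

    def f_theta(self, z):
        """Classification head over compressed features; discarded at inference."""
        return z @ self.W3 + self.b3
```

At inference, only f_alpha is kept, so queries and keys live in the compressed 64-d space.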
In order to obtain separable cluster families after f_α, we consider several candidate contrastive regularizations to train the Compact Network.
Triplet Noise-Contrastive Estimation (NCE). For each cluster in one particular cluster family, two semantic representations are randomly sampled: one as the pivot example v* and the other as the positive example v+. From a cluster in a different cluster family, another semantic representation is randomly selected as the negative example v−. We then conduct NCE (Gutmann and Hyvärinen, 2010) with binary classification on {pivot, positive} and {pivot, negative} to predict which pair belongs to the same cluster:

ℓ_NCE = −log σ(f_θ([f_α(v*); f_α(v+)])) − log(1 − σ(f_θ([f_α(v*); f_α(v−)]))),

where the output dimension of f_θ is 1.
Triplet Distance Ranking. This is similar to the Triplet NCE. The differences are that (1) we remove the f_θ layer and (2) the objective is modified into a ranking loss that minimizes the l_2 distance between the pivot and positive examples while maximizing the distance between the pivot and negative ones:

ℓ_DR = max(0, ||f_α(v*) − f_α(v+)||_2 − ||f_α(v*) − f_α(v−)||_2 + margin).

Word Prediction Loss. To compensate for the loss of linguistic information that NCE may ignore, the traditional word prediction NMT loss is also used to train the Compact Network. In this scenario, the output dimension of f_θ is the vocabulary size of the corresponding target language.
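The two triplet objectives can be sketched on compressed features as follows. The concatenation-based pair scoring for NCE and the margin value are our assumptions, since the exact parameterizations are not fully specified here:

```python
# Sketches of the triplet NCE and distance-ranking losses on compressed features.
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def triplet_nce_loss(z_pivot, z_pos, z_neg, f_theta):
    """Binary NCE: f_theta scores a pair; {pivot, pos} -> label 1, {pivot, neg} -> label 0."""
    s_pos = sigmoid(f_theta(np.concatenate([z_pivot, z_pos])))
    s_neg = sigmoid(f_theta(np.concatenate([z_pivot, z_neg])))
    return -np.log(s_pos + 1e-9) - np.log(1.0 - s_neg + 1e-9)

def triplet_ranking_loss(z_pivot, z_pos, z_neg, margin=1.0):
    """Hinge loss: pull the positive toward the pivot, push the negative away."""
    d_pos = np.linalg.norm(z_pivot - z_pos)
    d_neg = np.linalg.norm(z_pivot - z_neg)
    return max(0.0, d_pos - d_neg + margin)
```

The ranking loss vanishes once the negative is farther from the pivot than the positive by at least the margin, which is the separability the Compact Network aims for.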
In addition, we find that dynamic pivot selection leads to unstable training, as the compressed representations are forced to update toward various directions. For each cluster, we therefore replace the dynamic pivot with a static pivot by fixing it at the centroid. After training converges, we construct a new feature-compressed datastore from the output of f_α, which is used for query retrieval during kNN-MT inference.
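Static pivot selection amounts to fixing each cluster's pivot at its centroid; a small sketch (function name illustrative):

```python
# Sketch of static pivot selection: one centroid pivot per cluster.
import numpy as np

def static_pivots(keys, cluster_ids):
    """Return a {cluster_id: centroid} map; the centroid is the fixed (static) pivot."""
    pivots = {}
    ids = np.asarray(cluster_ids)
    for cid in set(cluster_ids):
        members = keys[ids == cid]
        pivots[cid] = members.mean(axis=0)   # centroid of the cluster's keys
    return pivots
```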

Cluster-Based Pruning
Apart from feature reduction, the number of key-value pairs in the compressed datastore is also crucial for translation latency; hence, redundant tokens should be pruned. In the literature, phrase-level pruning strategies have proven efficient for statistical machine translation (SMT) (Ling et al., 2012; Zens et al., 2012). Each record in the phrase table reflects a similar semantic unit, hence one can prune records that share similar statistics, e.g., translation quality, translation cost, etc.
Inspired by SMT, we propose an efficient pruning strategy based on n-gram metrics in the original semantic representation space. Intuitively, a key-value pair in the datastore is redundant if there are other key-value pairs (with the same value) such that the difference of their perplexity (PPL) values is smaller than a given threshold (an example is presented in Figure 4).
To make this concrete, we describe the translation cost as follows. For a given n-gram phrase (t_{i−n+1}, t_{i−n+2}, ..., t_i) in the translation with the corresponding token-level translation probabilities p(t_j | s, t_<j) for all j ∈ {i − n + 1, ..., i}, we measure the translation cost of its last token (the value stored in the datastore) as the perplexity (PPL) of the n-gram phrase. However, when n is fixed, n-gram phrases are not always meaningful, because some translations are independent of their previous target-side context (Ling et al., 2012). Hence, we do not directly adopt the naive PPL as a stable translation cost but truncate it in a heuristic way: we search for the minimal PPL among all consecutive subsequences ending with that last token. Formally, given a bilingual sentence pair (s, t), we define the translation cost of each target token t_i as:

c_{t_i} = min_{1 ≤ m ≤ n} PPL(t_{i−m+1}, ..., t_i).

We then add the translation cost into the feature-compressed datastore.

Figure 4: An example of a redundant bigram "a man" with similar translation costs. "X" denotes that a node with similar PPL will be randomly deleted in pruning.
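The truncated cost follows directly from the definition. In this sketch, `token_probs` holds the per-token probabilities p(t_j | s, t_<j) for one sentence, and the function names are illustrative:

```python
# Sketch of the truncated translation cost: the minimal PPL over all
# consecutive subsequences (up to length n) ending at position i.
import math

def ngram_ppl(probs):
    """PPL of a token span given its per-token probabilities."""
    return math.exp(-sum(math.log(p) for p in probs) / len(probs))

def translation_cost(token_probs, i, n=2):
    """c_{t_i} = min PPL among suffix spans of length 1..n ending at position i."""
    spans = [token_probs[i - m + 1 : i + 1] for m in range(1, n + 1) if i - m + 1 >= 0]
    return min(ngram_ppl(span) for span in spans)
```

For example, with per-token probabilities [0.5, 0.1, 0.8] and n=3, the cost of the last token is the PPL of the single-token span [0.8], since that suffix has the lowest perplexity.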
For the augmented datastore described above, we apply propagation-based clustering (Ester et al., 1996; Zhang et al., 1996) only on the translation cost c_{t_i} to obtain cost-similar groups, and partition the semantic representations according to these groups. To obtain the pruned datastore, we perform uniform sampling within each group and collect the samples into a smaller key-value datastore. The procedure is summarized in Algorithm 1.
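A minimal sketch of this pruning step, with a simple threshold-based 1-D grouping of costs standing in for the DBSCAN/Birch clustering used in the paper, and uniform sampling within each group; the entry layout and the eps value are illustrative:

```python
# Sketch of cluster-based pruning on translation costs.
import random

def prune_datastore(entries, rate, eps=0.05, seed=0):
    """entries: list of (key, value, cost) tuples; keep (1 - rate) of each cost group."""
    random.seed(seed)
    groups, current = [], []
    for e in sorted(entries, key=lambda e: e[2]):
        if current and e[2] - current[-1][2] > eps:   # start a new group at a cost jump
            groups.append(current)
            current = []
        current.append(e)
    if current:
        groups.append(current)
    pruned = []
    for g in groups:                                   # uniform sampling per group
        keep = max(1, round(len(g) * (1.0 - rate)))
        pruned.extend(random.sample(g, keep))
    return [(k, v) for k, v, _ in pruned]              # drop costs in the final datastore
```

Sampling uniformly inside each cost group removes near-duplicate entries while every cost level stays represented in the pruned datastore.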
In brief, our efficient cluster-based k-nearest neighbor machine translation can be concluded into the following steps.
• We adopt the original datastore to train the Compact Network while the parameters of the NMT model are frozen.

Algorithm 1 Cluster-Based Pruning
Input: The expected pruning rate r; the translation cost threshold ε; a preprocessed datastore ((K, C), V).
Output: A new pruned datastore (K_new, V_new).
1. Greedy Clustering on Translation Costs.
2. Random Sampling on the Target-Side Clusters.
• We adopt the validation set to train the Meta-k Network while the parameters of NMT and Compact Network are fixed.
• We reconstruct the feature-compressed datastore and prune it into a small datastore using our proposed n-gram pruning algorithm that will be eventually used for testing.

Experiments
We carried out a series of experiments to evaluate the proposed non-parametric NMT against the previous advanced counterpart on several translation benchmarks.

Datasets
We followed (Zheng et al., 2021a) and conducted all experiments on five widely used machine translation benchmarks of distinct domains: IT, Koran, Medical, Law and Subtitles. The first four domains were also used in (Zheng et al., 2021a), while the Subtitles dataset contains a large number of target tokens and is hence suitable for exploring our pruning strategy. The statistics of these datasets are shown in Table 2. We tokenized sentences using Moses and split words into subword units (Sennrich et al., 2016) with the BPE codes provided by (Ng et al., 2019). We applied the product quantizer with the inverted file system based on Faiss to quantize the datastores and conduct retrieval. The hyper-parameters of Faiss are provided in Appendix B.

Clustering Algorithm Selection
The determination of clustering algorithms depends on computation complexity and clustering effectiveness.
• As semantic clusters in a large datastore are vague and the number of clusters is hard to determine in advance, clustering algorithms that require a fixed cluster quantity beforehand (e.g., k-Means (Hartigan and Wong, 1979)) are not suitable for datastore partitioning.
• Besides, clustering complexity is not tolerable in practice when it grows to O(N²) (e.g., Affinity Propagation (Frey and Dueck, 2007)), since N is usually extremely large for a high-quality datastore.
We eventually chose two classical clustering algorithms for exploration in our experiments: DBSCAN (Ester et al., 1996) and Birch (Zhang et al., 1996). DBSCAN was applied to datastores with fewer than 100M nodes, while Birch was applied to datastores with 100M+ nodes, for the sake of the computation-quality trade-off. In our experiments, we adopted the scikit-learn clustering implementations.

Baselines
We adopted the following models as our baselines.
• Base NMT. This is the winner model (Vaswani et al., 2017) of the WMT'19 German-English news translation task provided by (Ng et al., 2019), which is also used in (Zheng et al., 2021a). It is a Transformer model with hidden size 1024.
• Adaptive kNN-MT (Zheng et al., 2021a). This is the benchmark model of our work.

Table 3: The BLEU performance comparison of the feature reduction methods on the IT domain. The retrieval k is set to 4 in all cases. DR, NCE and WP denote the distance ranking, noise-contrastive estimation and word prediction objectives, respectively. CL denotes that all tokens are clustered and the triplets are then selected based on these clusters. [DY] denotes that the pivot is dynamically selected, while [ST] denotes static pivot selection.
In our modifications, expecting to reduce the dimension to <10% of its original size, we performed a greedy search over {16, 32, 64, 128} on the IT domain validation set and obtained the optimal value of 64 as f_α's output dimension; this setting was then used in all experiments. A detailed dimension-related analysis can be found in Appendix A. Similarly, we used grid search and selected bigrams for the cluster-based pruning algorithm.

Evaluation
All experiments were conducted on a P100-16GB GPU with an 18-core Intel(R) Xeon(R) Gold 6240 CPU @ 2.60GHz, except for the experiments in Subsection 4.5.2, where we used 2 GPU cards to load a larger datastore. All translation results were evaluated with case-sensitive detokenized BLEU using SacreBLEU (Post, 2018).

Results
For simplicity, we refer to the base NMT model equipped with the proposed Compact Network as CKMT and further equipped with the pruned datastore as PCKMT in this section.

Performance of the Compact Network
On the IT domain, we first evaluated the Compact Network settings mentioned in Section 3, as well as two traditional feature reduction algorithms: Principal Component Analysis (PCA), used in (He et al., 2021), and Singular Value Decomposition (SVD). We applied the PCA solution to learn a feature-wise linear projection, while the SVD solution learns a matrix-wise projection that decomposes the weight W of the last layer of the base NMT model into three matrices, W ≈ S_{1024×64} U_{64×64} V_{64×|V|}. Then f_α can be replaced by an FFN layer with the weight S_{1024×64} U_{64×64} but without bias.
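A rough sketch of this SVD baseline with NumPy, using conventional factor names; the mapping onto the paper's S/U/V notation, the 1024×|V| weight shape, and the function name are our assumptions:

```python
# Sketch of the SVD baseline: a truncated factorization of the last-layer
# weight yields a bias-free 1024 -> 64 projection to replace f_alpha.
import numpy as np

def svd_projection(W, d_out=64):
    """Truncate the SVD of W to rank d_out and return the projection weight."""
    L, s, Rt = np.linalg.svd(W, full_matrices=False)   # W = L @ diag(s) @ Rt
    return L[:, :d_out] * s[:d_out]                    # left vectors scaled by singular values
```

Applying `h @ svd_projection(W)` then plays the role of f_alpha without any training, which is why it serves as a baseline rather than the proposed method.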
As shown in Table 3, the best CKMT solution is equipped with the Compact Network trained using NCE+CL with static pivots. It outperforms the adaptive kNN-MT by 0.74 BLEU. Consistent with (He et al., 2021), we find it difficult for the 1024-to-64 feature-wise PCA to maintain the translation performance at such a low dimension. Basically, the distance ranking loss causes serious performance degradation. We assume that the distance minimization constraint is too strict for optimizing a small datastore, since both the direction and the length of a semantic vector have already been optimized. Although the word prediction (WP) loss can recover semantic information, its f_θ has too many parameters to be optimized on the limited IT domain dataset compared with NCE alone. Besides, we attribute the improvement obtained by clustering (CL) to the introduced semantic disambiguation. Finally, static pivot selection (ST) achieves an improvement of 0.46 BLEU over the dynamic method. We refer to the best setting, [ST] CKMT+NCE+CL, as CKMT*, and report the results against the adaptive kNN-MT on various domains in Table 4. CKMT* gains an average improvement of 0.70 BLEU over the adaptive kNN-MT, which indicates that our proposed Compact Network refines the retrieval for machine translation.
Training the Compact Network with Limited Data. It is unclear how much data is sufficient at training stage I. Hence, we gradually reduce the number of key-value pairs in the datastore used to train the Compact Network, as shown in Table 5. As this number decreases, the performance degrades slowly. When we use only 40% of the datastore for training, CKMT still outperforms the adaptive kNN-MT. This indicates that our proposed Compact Network is efficient and requires only a small number of key-value pairs to compress the semantic representations with the contrastive loss.
Cross-Domain Generalization. Is there a general Compact Network capable of generalizing to different domains? If so, we can save the cost of training a unique Compact Network for each target domain. To explore this, we trained the Compact Network on the general domain with the large-scale Wikimatrix corpus (Schwenk et al., 2021) and evaluated its behavior on various target domains. As the last row of Table 4 shows, it is interesting that the general CKMT* drops only 0.39 BLEU compared with the 4 domain-specific datastores, and it still outperforms the adaptive kNN-MT by 0.31 BLEU. Generally speaking, the Compact Network generalizes well across different domains.

Table 6: Translation BLEU results on 4 different domains with a 10% pruning rate. k was set to 4. Note that CKMT* in the first row used the full datastore.

Performance of Pruning Methods
We compared our language-wise PPL-based pruning method against several alternative pruning strategies, as follows.
• Spatial Pruning by Distance (SP). A naive distance-wise pruning strategy that cuts off low-probability nodes according to the distance from each node to its cluster center.
• Low Translation Probability Pruning (LTP). Tokens translated with low probabilities tend to have poor translation quality, and will be pruned for datastore stability.
• High Translation Probability Pruning (HTP). As kNN probabilities are most beneficial for hard-to-translate words that the NMT model cannot handle, it is preferable to retain the tokens wrongly translated by the base NMT. In this sense, tokens paired with high confidence will be pruned.
• Random Pruning (RP). We also perform the random pruning strategy alone on the target-side clusters, as in step 2 of Algorithm 1.
The results on 4 different domains are shown in Table 6. Since the datastore size remains the same (10% pruned) for all pruning methods in Table 6, there is little difference in retrieval speed among them. Our cluster-based pruning strategy generally suffers the smallest degradation. Although other strategies obtain impressive results on a few domains (e.g., 10%-pruned CKMT*+HTP outperforms non-pruned CKMT* by 0.18 BLEU on the Koran test set), our cluster-based pruning strategy performs the most stably on average. Note that the random pruning strategy is simple yet effective, which coincides with the findings of (He et al., 2021).
However, we find that the in-domain data of the tested domains have limited redundancy, since the average frequency of bigrams is too low (e.g., more than 0.4M unique bigrams were collected from the 3.6M-entry IT domain datastore; on average, each bigram occurs no more than 9 times in the datastore). Therefore, even a 10% pruning rate can lead to about 1 BLEU loss in Table 6. We leave reducing datastores with low n-gram redundancy to future work.
To further explore the potential of the pruning methods on a large datastore, we conducted pruning experiments on the Subtitles domain containing 154M keys. We also tested the random pruning strategy, as it is the second most competitive one. As Figure 5 illustrates, the proposed PCKMT*+Ours with a pruning rate of 30% can even outperform non-pruned CKMT*. As the pruning rate increases, PCKMT*+Ours generally outperforms PCKMT*+RP for the same k. The performance of PCKMT*+RP drops seriously (more than 1 BLEU point) when the pruning rate is ≥ 50%, whereas PCKMT*+Ours shows no clear drop until the pruning rate reaches 70%. When the pruning rate increases to 80+%, PCKMT*+RP even performs worse than the base NMT, but PCKMT*+Ours still outperforms it by a large margin. These results suggest that the proposed cluster-based pruning algorithm is effective for datastore reduction.

In Table 7, we further evaluate the computation cost of CKMT* at the same BLEU performance as the adaptive kNN-MT. With the same k and batch size, PCKMT* achieves 27%~57% lower latency than the adaptive kNN-MT. In addition, we compare our best-performing model with the baselines in Table 8. PCKMT (k=8) with a pruning rate of 30% performs best, obtaining an improvement of 0.36 BLEU and a 1.56x translation speed-up over the adaptive kNN-MT.
Cluster Visualization. We visualize the IT domain datastore in Figure 6 to verify our assumption that the Compact Network maps the original semantic representations to a separable distribution with fewer overlaps. Tokens represented by purple dots become more distinguishable with our method.

Conclusion
In this paper, we propose a cluster-based Compact Network for feature reduction in a contrastive learning manner, reducing the context feature dimension by 90+%, and suggest a cluster-based pruning strategy that prunes 10%~40% of redundant keys in the datastore while translation quality remains unchanged. Our proposed methods achieve better or comparable performance while reducing inference latency by up to 57% against the advanced non-parametric MT model on several benchmarks. For future work, it is promising to design effective feature reduction algorithms and pruning strategies based on more linguistic and cross-lingual information.

A Dimension Analysis

The output dimension of the first FFN in f_α was empirically set to 4 times the output dimension of the whole f_α. We then conducted a greedy search on the IT domain validation set to obtain the optimal output dimension of f_α in our Compact Network. As shown in Table 9, 64d was the optimal setting, superior to the adaptive kNN-MT.

B Hyper-parameters of Faiss
We followed the default implementation settings of (Zheng et al., 2021a). Concretely, we adopted FP16 precision to store keys. The number of partition-based quantization centroids was set to 1024, while the number of inverted lists probed at query time in the cell-probe method was set to 32. The size of each quantized vector in bytes was set to 64, except for CKMT with 16d/32d compact feature dimensions in Table 9, because the output size of the quantized vectors must be smaller than the size of the input features for quantization.

C Parameter Comparison

We compared the number of overall parameters of different systems in Table 10. Our optimal CKMT* requires only 0.1% more parameters than the adaptive kNN-MT while significantly decreasing latency. Hence CKMT* achieves a good speed-quality trade-off.