Training Language Models with Memory Augmentation

Recent work has improved language models (LMs) remarkably by equipping them with a non-parametric memory component. However, most existing approaches only introduce memories at testing time or represent them using a separately trained encoder, resulting in suboptimal training of the language model. In this work, we present TRIME, a novel yet simple training approach designed for training LMs with memory augmentation. Our approach uses a training objective that directly takes in-batch examples as accessible memory. We also present new methods for memory construction and data batching, which are used for adapting to different sets of memories—local, long-term, and external memory—at testing time. We evaluate TRIME on multiple language modeling and machine translation benchmarks and show that it achieves significant improvements across all the settings. Concretely, TRIME reduces the perplexity from 18.70 to 15.37 on WIKITEXT-103 by effectively leveraging a large memory set from the training corpus. Compared to standard LM training, TRIME adds negligible computational overhead and is compatible with different neural architectures, making it a versatile solution for training memory-augmented LMs.


Introduction
Memory augmentation has become a remarkable approach to enhancing language modeling performance without significantly increasing the number of parameters or the amount of computation. By accessing memory units such as a neural cache of recent inputs (Merity et al., 2017; Grave et al., 2017b) or an external look-up table (Khandelwal et al., 2020), a memory-augmented language model (LM) enjoys increased memorization capacity and sets new state-of-the-art records on various language modeling benchmarks. A major limitation of existing approaches, however, is that the memory units are either introduced only at testing time (Grave et al., 2017b; Khandelwal et al., 2020) or taken from a separately trained model. As a consequence, they are not directly optimized during the training process, resulting in a missed opportunity to achieve even stronger results.

Table 1: A comparison between our TRIME language models and previous approaches: vanilla LM, continuous cache (Grave et al., 2017b,a), kNN-LM (Khandelwal et al., 2020). M_local, M_long, M_ext denote three types of memories (see §2.2 for more details).

In this paper, we present a novel yet simple training approach, TRIME (Training with In-batch Memories), that is well-suited for memory augmentation in language modeling. Our code and pre-trained models are publicly available at https://github.com/princeton-nlp/TRIME. Our approach makes two major departures from standard language model training.

Training objective. Inspired by contrastive representation learning, we propose a training objective that directly leverages in-batch examples as accessible memory. Our training objective is closely connected to neural cache models (Grave et al., 2017b; Merity et al., 2017) and nearest-neighbor language models (Khandelwal et al., 2020), where the next-token probabilities are calculated by comparing encoder outputs against static token embeddings and memory representations. However, previous work only incorporates memories at testing time, while we do so during both training and testing.

In-batch memory construction. With this training objective in mind, the key challenge is how to construct memories effectively during training while keeping training efficient. We identify three types of memories that can be leveraged at testing time and have been explored in the literature: (a) local memory denotes the words that appear in the recent past and are modeled using attention (Vaswani et al., 2017); (b) long-term memory denotes long-range context from the same document that cannot be directly accessed by attention due to the limit on input length; (c) external memory stores the entire training set or any additional corpus (Khandelwal et al., 2020; Borgeaud et al., 2021). To better leverage these memories at testing time, we devise new data batching strategies to improve the construction of training memories (§4). By packing consecutive segments from the same document in one training batch, our model can access long-term memories beyond the attention context. Additionally, we pack segments from other documents that have high lexical overlap as a proxy for external memory units. Importantly, these working memories are generated on the fly during training, allowing us to back-propagate to all memory representations.
We instantiate TRIME in three models by considering different sets of training and testing memories (Table 1) and evaluate them on multiple language modeling benchmarks (Merity et al., 2017; Mahoney, 2009) and an IWSLT'14 machine translation task. We highlight our results as follows:
• We first show that we can simply optimize a language model using our training objective without long-term or external memory. Without any other modifications, we demonstrate that a 247M Transformer-based model (Vaswani et al., 2017) improves perplexity from 18.70 to 17.76 on WIKITEXT-103 with negligible overhead. This model can be viewed as a simple replacement for vanilla language models.
• By training with consecutive segments in the same batch, our approach is capable of leveraging very long context at testing time—up to 15k–25k tokens on WIKITEXT-103 and ENWIK8. Our approach achieves performance at least competitive with previous works (Dai et al., 2019; Martins et al., 2022; Ji et al., 2022) that modify the Transformer architecture to incorporate memories from previous segments, yet our solution is conceptually simpler and computationally cheaper.
• Finally, we train language models by incorporating all other segments in the same batch as memories. Our model works better with a large datastore at testing time and improves over the kNN-LM model (Khandelwal et al., 2020) by reducing the test perplexity from 16.23 to 15.47 on WIKITEXT-103. We also demonstrate improvements over the kNN-MT baseline (Khandelwal et al., 2021) on an IWSLT'14 De-En machine translation task.
In summary, we propose a simple approach TRIME for optimizing language models with memory augmentation and demonstrate consistent and significant gains in multiple experimental settings. Our approach does not modify the model architecture and only uses memories at the final prediction step, and hence adds very little computational overhead during both training and inference. As a result, it is compatible with other neural models and techniques such as recurrent networks and compressed attention (Dai et al., 2019;Rae et al., 2020). We hope that our work can encourage the research community to think about better training objectives for language models, given their significant societal impacts (Brown et al., 2020;Chowdhery et al., 2022;Zhang et al., 2022).

Language Modeling
In this paper, we mainly focus on improving language models, although we believe that our solutions may extend to most text generation tasks; later on, we will demonstrate one application in machine translation (§5.4). Neural language models take a sequence of tokens as context c_t = x_1, ..., x_{t−1} and map it to a vector representation f_θ(c_t) ∈ R^d, where f_θ(·) is parameterized by a neural network. The next-token probability is

P(x_t = w | c_t) = exp(E_w · f_θ(c_t)) / Σ_{w'∈V} exp(E_{w'} · f_θ(c_t)),  (1)

where E_w ∈ R^d denotes the output embedding of token w ∈ V. The parameters are optimized to minimize the negative log-likelihood of the ground truth x_t during training.
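For concreteness, the following is a minimal PyTorch-style sketch of this parametric softmax and its training loss; the function and tensor names are ours for illustration and are not taken from any released code.

    import torch
    import torch.nn.functional as F

    def vanilla_lm_loss(f_c, E, targets):
        # f_c:     (N, d)  encoder outputs f_theta(c_t) for N positions
        # E:       (V, d)  output token embeddings E_w
        # targets: (N,)    gold next tokens x_t
        logits = f_c @ E.t()                      # E_w . f_theta(c_t) for every w
        return F.cross_entropy(logits, targets)   # negative log-likelihood of x_t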

Memory Augmentation
We consider memory as a set of context–target pairs {(c_i, x_i)}, following Grave et al. (2017b) and Khandelwal et al. (2020). These context–target pairs can be aggregated to obtain the next-token probability, weighted by the similarity between hidden representations. We formalize three types of context–target memories as follows.

Local memory. The local memory simply consists of the most recent preceding tokens in the same input. Specifically, for c_t = x_1, ..., x_{t−1}, it is defined as

M_local(c_t) = {(c_j, x_j) : 1 ≤ j ≤ t − 1}.  (2)

Grave et al. (2017b) use the local memory at testing time, in what they call the "continuous cache" model. However, it has been argued to be less effective for Transformer-based models because they can already learn to leverage recent tokens in the self-attention layers (Khandelwal et al., 2020). Interestingly, we show that using local memory is still beneficial if we consider it during training.
Long-term memory. Long-term memory denotes long-range context from the same document that cannot be directly accessed by attention. For example, if a document contains 10k tokens, only a short segment of text (e.g., 100–1k tokens) can be fed into a Transformer model because the complexity scales quadratically with the input length. Formally, we divide a document into consecutive segments s^(1), ..., s^(T), where a segment s^(i) = {c^(i)_1, ..., c^(i)_L} contains L contexts. For a context in segment s^(i), the long-term memory consists of the context–target pairs from all earlier segments of the same document:

M_long = {(c_j, x_j) : c_j ∈ s^(1) ∪ ... ∪ s^(i−1)}.  (3)

Previous works (Dai et al., 2019; Rae et al., 2020; Martins et al., 2022; Ji et al., 2022; Wu et al., 2022) leverage hidden representations from previous segments with modified Transformer architectures to learn long-range dependencies. Our approach does not modify the model architecture and is compatible with these neural architectures. Note that the continuous cache can be naturally extended to long-term memory, as we will experiment with later.

External memory. Finally, external memory assumes a large corpus D, and the external memory set is defined as all context–target pairs in D:

M_ext = {(c_j, x_j) : c_j is a context in D}.  (4)

D can simply be the training corpus (as is the case in our experiments), or a domain-specific corpus when the testing domain shifts (Khandelwal et al., 2020). Note that |M_ext| is usually several orders of magnitude larger than the previous two types (e.g., 10^8 pairs); accessing all the memories is computationally expensive and requires approximate nearest neighbor search.
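As a rough illustration of these three definitions, the sketch below enumerates (context, target) pairs for one prediction position; the segmentation scheme and variable names are simplifying assumptions for exposition, not the authors' implementation.

    def build_memories(doc, seg_len, seg_idx, pos, corpus):
        # doc:     token ids of the current document
        # seg_len: segment length L fed to the Transformer
        # seg_idx: index i of the segment containing the current position
        # pos:     position t inside the current segment (we predict token t)
        # corpus:  list of token-id lists, e.g. the whole training set D
        start = seg_idx * seg_len
        # Local memory: preceding tokens inside the same segment.
        m_local = [(doc[start:start + j], doc[start + j]) for j in range(pos)]
        # Long-term memory: pairs from earlier segments of the same document.
        m_long = [(doc[:j], doc[j]) for j in range(start)]
        # External memory: every context-target pair in the large corpus D.
        m_ext = [(d[:j], d[j]) for d in corpus for j in range(len(d))]
        return m_local, m_long, m_ext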

Training with In-batch Memories
In this section, we propose a new training approach, TRIME, for language model training. Compared to standard language model training, our training objective assumes a set of training memories M_train. We differentiate training memories from testing memories, as they are constructed on the fly during training and may deviate from the testing memories used during inference. Importantly, the training memories are constructed from the same training batch, which enables back-propagating the training signal to the current hidden representation as well as all the memory representations. We will discuss how to construct training memories in the next section (§4) and only discuss the training objective in a general form here. Our training objective is illustrated in Figure 1. Given a memory set M_train and a context c, TRIME defines the next-token probability distribution as

P(w | c) ∝ exp(E_w · f_θ(c)) + Σ_{(c_j, x_j) ∈ M_train : x_j = w} exp(sim(g_θ(c), g_θ(c_j))).  (5)

Here, f_θ(c) is the output representation of a Transformer model and E_w is the token embedding as defined in §2.1. g_θ(·) denotes the representations used to compute the similarity between c and the contexts c_j in the memory M_train. It is possible to simply take g_θ = f_θ; however, we find that taking g_θ to be the input of the final feed-forward layer in the Transformer works better, which is consistent with the observation in Khandelwal et al. (2020). In addition, sim(·, ·) is a similarity function; we found that using the scaled dot product sim(q, k) = (q · k) / √d (Vaswani et al., 2017) leads to stable training and better performance in our preliminary experiments.
This training objective can be viewed as a contrastive loss (Hadsell et al., 2006): for a context–target pair (c, w*), the goal is to align the query representation f_θ(c) (and g_θ(c)) with the static token representation E_{w*} and with the contextualized representations that share the same next token, i.e., g_θ(c_j) for x_j = w*. Our objective handles rare words nicely: if w* does not appear in the training memory, the objective falls back to aligning f_θ(c) with only the word embedding E_{w*}. Similar to the vanilla training loss (Eq. 1), our TRIME loss minimizes the negative log-likelihood of the next token w*, and all the parameters θ and E_w are updated during training.
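To make Eq. 5 concrete, here is a minimal PyTorch-style sketch of the TRIME loss over one batch of positions. It assumes the in-batch memory keys g_θ(c_j), their target tokens, and an accessibility mask have already been gathered; all names are ours, not taken from the released code.

    import torch
    import torch.nn.functional as F

    def trime_loss(f_c, g_c, E, mem_keys, mem_targets, mem_mask, targets):
        # f_c:         (N, d)  output representations f_theta(c)
        # g_c:         (N, d)  representations g_theta(c) used for memory similarity
        # E:           (V, d)  output token embeddings
        # mem_keys:    (M, d)  g_theta(c_j) for all in-batch memory slots
        # mem_targets: (M,)    next tokens x_j of the memory slots
        # mem_mask:    (N, M)  True where slot j is accessible to position i
        # targets:     (N,)    gold next tokens w*
        d = f_c.size(-1)
        vocab_logits = f_c @ E.t()                                    # (N, V)
        mem_logits = (g_c @ mem_keys.t()) / d ** 0.5                  # scaled dot product
        mem_logits = mem_logits.masked_fill(~mem_mask, float('-inf'))
        log_probs = F.log_softmax(torch.cat([vocab_logits, mem_logits], -1), -1)
        # P(w* | c) sums the vocabulary entry and all memory slots with x_j == w*.
        gold_vocab = log_probs.gather(1, targets[:, None])            # (N, 1)
        match = (mem_targets[None, :] == targets[:, None]) & mem_mask # (N, M)
        gold_mem = log_probs[:, E.size(0):].masked_fill(~match, float('-inf'))
        gold = torch.logsumexp(torch.cat([gold_vocab, gold_mem], -1), -1)
        return -gold.mean()

Note that with an empty memory (an all-False mask) this reduces to the vanilla loss of Eq. 1, matching the fallback behavior described above.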
Our training objective is inspired by the success of contrastive learning in dense retrieval (Karpukhin et al., 2020). As we will show in §6, it helps the model retrieve contexts that share the same next token more effectively when the set of testing memories is large. Our objective is also closely connected to the objective used in Grave et al. (2017b) and Khandelwal et al. (2020), which linearly interpolates two distributions, the standard language model distribution and a distribution defined by a cache or an external datastore, e.g.,

P(w | c) = (1 − λ) P_LM(w | c) + λ P_mem(w | c).

Our work differs from previous works most in that we use this objective for training (and testing), while they only used it at testing time; the key is how to construct training memories, which we elaborate on next. (Grave et al. (2017b) also described a "global normalization" variant, which is similar to our objective; however, they only used it at testing time and only considered short-term contexts in calculating the distribution. Another earlier work (Merity et al., 2017) trained a pointer network with a learned gating component for the interpolation; we attempted training with a similar objective and found it to perform worse than our current objective.)

Adaptation to Different Memories
Inference. We are interested in incorporating the three types of memories defined in §2.2, and their combinations, at testing time. The testing objective is basically the same as the training objective (Eq. 5), except that we take the testing memories to be a combination of M_local, M_long, and M_ext, and we tune a temperature term τ to adjust the weight of the memory component. See Appendix A for details about the testing objective.
Notation. Throughout this section, we use L to denote the segment length, B to denote the total number of segments in one training batch, and m to denote the number of consecutive segments taken from each document in the batch. Correspondingly, each batch contains b ≈ B/m different documents. L, B, and m are hyper-parameters that we choose for training, and they vary as we consider different memories during inference.
A key challenge is that the testing memories can be very large (e.g., |M_long| ~ 10^4 and |M_ext| ~ 10^8 in our experiments) and it is computationally infeasible to keep the training memories the same as the testing memories. In the following, we discuss three ways of constructing training memories and batching data, aiming to reduce the discrepancy between training and testing. Along the way, we present three major model instantiations, TRIMELM, TRIMELM_long, and TRIMELM_ext (Table 1), which combine the training strategies with different sets of testing memories.

Local Memory
M_local only considers the previous tokens in the same segment, so it is straightforward to simply use M_train = M_local. As shown in Fig. 2(a), we basically do not need to make any modifications compared to standard language model training. All we need is to replace the training objective of Eq. 1 with our objective in Eq. 5, incorporating (c_j, x_j) for all j < i in the memory during both training and testing. The computational overhead is also negligible compared to running the neural encoder on the segment x_1, ..., x_L itself. We denote this model as TRIMELM, which can be viewed as a lightweight replacement for vanilla language models. As we will show in the experiments, simply incorporating local memory provides a notable gain on multiple LM benchmarks, showing the effectiveness of training with memories explicitly.

Long-term Memory
In order to enable long-term memory augmentation, we pack multiple consecutive segments from the same document in a training batch (i.e., m > 1). For a specific context–target pair (c, w) in the training batch, its accessible memory M_train includes tokens from previous segments as well as the preceding tokens in the same segment. Figure 2(b) illustrates the training batch construction and the training memory for a given token (and its context). Note that at testing time we can use a much longer context: we simply enumerate the number of segments used in M_eval and choose the optimum based on the development set. We denote this model as TRIMELM_long. It shares a similar motivation with many previous works that aim to leverage memory from previous segments through attention recurrence (Dai et al., 2019; Ji et al., 2022) or memory compression (Rae et al., 2020; Martins et al., 2022; Wu et al., 2022). However, our solution deviates significantly from these approaches. First, previous works need to store the hidden representations (of every layer) from previous segments and modify the self-attention layers to incorporate them, while our approach does not modify the architecture and only uses the outputs from the last layer. In addition, previous works use stale memory representations and do not back-propagate gradients to the representations of previous segments, whereas our batching method enables gradient propagation to the memory and the previous segments. As we will show in the experiments, our approach is competitive with previous works while being conceptually simpler and computationally cheaper.
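A sketch of this batching scheme follows, under our own simplified assumptions (documents are lists of token ids and a batch simply collects B segments in order); it is meant to illustrate the grouping, not reproduce the Fairseq data pipeline.

    def batch_consecutive_segments(documents, seg_len, B, m):
        # Pack m consecutive segments per document into each batch of B segments,
        # so later segments can treat earlier ones as long-term memory M_long.
        groups = []
        for doc in documents:
            segs = [doc[i:i + seg_len] for i in range(0, len(doc), seg_len)]
            groups += [segs[i:i + m] for i in range(0, len(segs), m)]
        batches, batch = [], []
        for group in groups:
            batch.extend(group)
            if len(batch) >= B:
                batches.append(batch[:B])
                batch = batch[B:]
        return batches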

External Memory
Finally, we consider the external memory M_ext. Since M_ext contains the context–target pairs from a large corpus such as the entire training set, we need to retrieve the top-K pairs from M_ext, measured by sim(g_θ(c), g_θ(c_j)), through (approximate) similarity search (more details are given in §5.2).
Since the retrieved contexts at testing time are expected to be similar to the query context, we propose a simple heuristic for constructing training memories M_train: we pack segments that have large lexical overlap into the same batch using BM25 (Robertson and Zaragoza, 2009). Specifically, we start with a single segment and repeatedly add the segments with the highest BM25 scores into the same batch. A high BM25 score indicates that two segments have high lexical overlap and can serve as a good proxy for nearest neighbors in the external memory, which improves our model predictions at testing time. Figure 2(c) illustrates our method. M_train contains all tokens from other segments as well as the previous tokens in the same segment. We set m = 1 during training, as many segments from the same document tend to have high lexical overlap, and denote this model by TRIMELM_ext.
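The greedy BM25 packing can be sketched as below; we use the third-party rank_bm25 package as a stand-in for whatever BM25 implementation is actually used, and the seeding strategy is a simplification of the procedure described above.

    from rank_bm25 import BM25Okapi

    def bm25_batches(segments, B):
        # segments: list of tokenized segments; B: number of segments per batch.
        bm25 = BM25Okapi(segments)
        unused = set(range(len(segments)))
        batches = []
        while unused:
            seed = unused.pop()
            scores = bm25.get_scores(segments[seed])
            # The highest-scoring unused segments act as a proxy for the nearest
            # neighbors retrieved from external memory at test time.
            neighbors = sorted(unused, key=lambda i: scores[i], reverse=True)[:B - 1]
            unused -= set(neighbors)
            batches.append([seed] + neighbors)
        return batches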
In practice, when considering tokens from both the current segment and other segments in the batch, we observe that the model tends to rely on the local memory and ignore the other segments.
To encourage the use of information from other segments, we exclude the local memory from M_train with probability p during training (we find that p = 90% works best; see Appendix E). This significantly improves performance when the model is evaluated with a large set of external memory.
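In terms of the accessibility mask used in the training-loss sketch above, this exclusion amounts to hiding the local-memory slots with probability p; whether the coin is flipped per batch or per position is our assumption here.

    import torch

    def maybe_drop_local(mem_mask, is_local, p=0.9):
        # mem_mask: (N, M) accessibility mask; is_local: (M,) bool marking slots
        # that come from the current segment. With probability p we hide them,
        # forcing the model to rely on memories from other segments.
        if torch.rand(()).item() < p:
            mem_mask = mem_mask & ~is_local[None, :]
        return mem_mask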

Datasets and Tasks
We evaluate our approach on two popular language modeling benchmarks, WIKITEXT-103 (Merity et al., 2017) and ENWIK8 (Mahoney, 2009), and a machine translation benchmark, IWSLT'14 DE→EN (see Appendix B for data statistics and detailed task setups, and Appendix C for detailed model configurations). WIKITEXT-103 is a word-level language modeling dataset consisting of 103M training tokens. We evaluate on two model configurations: one uses a 247M Transformer model and a segment length L = 3,072, and the other uses a 150M Transformer model with segment length L = 150.
ENWIK8 is a character-level language modeling dataset that contains a total of 100M characters. We use a 12-layer Transformer model with a hidden dimension 512 and segment length L = 512.
IWSLT'14 DE→EN is a machine translation task, which consists of 170K translation pairs. We use a Transformer encoder-decoder model. See Appendix B for how we adjust our approach to the machine translation task.

Training and Inference Details
We implement our approach using the Fairseq library (Ott et al., 2019). For TRIMELM_long and TRIMELM_ext, we tune the number of segments used in M_long on the development set during evaluation. For our TRIMELM_ext model, which requires building a large datastore at testing time, we use the FAISS library (Johnson et al., 2019) for approximate nearest neighbor search. See our hyperparameters in Appendix C.
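For the datastore and nearest-neighbor search, a minimal FAISS usage sketch follows; the index type, metric, and parameters are illustrative choices on our part rather than the paper's exact configuration.

    import faiss

    def build_datastore(keys, nlist=4096):
        # keys: (num_pairs, d) float32 matrix of g_theta(c_j) for all training pairs.
        d = keys.shape[1]
        quantizer = faiss.IndexFlatIP(d)
        index = faiss.IndexIVFFlat(quantizer, d, nlist, faiss.METRIC_INNER_PRODUCT)
        index.train(keys)   # learn the coarse clustering
        index.add(keys)     # add all memory keys
        return index

    def retrieve(index, queries, k=1024, nprobe=32):
        # queries: (n, d) float32 g_theta(c) for test contexts.
        # Returns top-k similarities and ids of the retrieved context-target pairs.
        index.nprobe = nprobe
        return index.search(queries, k)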
We first train our model with the standard LM objective (Eq. 1) for the first 5% of updates. Without this warm-up stage, we observe the training process to be unstable, probably due to a large variance in the estimated distributions. We find that when a large set of external memory M_ext is considered during inference, the performance can be improved by linearly interpolating the output distribution with a distribution over the memory, similarly to kNN-LM (Khandelwal et al., 2020). Thus, we apply an additional linear interpolation to our output probability distribution when considering external memory M_ext (see Appendix F for details).

Results: Language Modeling
We present our language modeling results in Table 2 (WIKITEXT-103, 247M model, L = 3,072), Table 3 (WIKITEXT-103, 150M model, L = 150), and Table 4 (ENWIK8). These results suggest that even though the attention mechanism can "see" the local context, using local memory during both training and testing can still improve model performance. TRIMELM has no computational overhead compared to the vanilla LM (indicated by the "speed" column), making it a simple and better replacement for vanilla language models. Similar trends can be observed in Table 3 and Table 4 (25.87 vs. 25.60 and 1.16 vs. 1.12), although the improvement is much smaller due to a much smaller segment length L. More analyses are given in §6.
We also evaluate a variant that uses long-term memory only at testing time on top of a vanilla Transformer model and find it to underperform our model, demonstrating the importance of joint training using our approach. Compared to previous methods that explicitly leverage hidden representations from previous segments (Dai et al., 2019; Rae et al., 2020; Martins et al., 2022; Ji et al., 2022), our approach achieves better or at least competitive performance. Unlike these approaches, which need to store the hidden representations of every layer and modify the model architecture, we only incorporate the outputs from the last layer, requiring less computation and GPU memory. We also believe that our approach is orthogonal and can be applied on top of these models; we leave this to future work.

TRIMELM_ext vs. kNN-LM
Finally, our model TRIMELM_ext outperforms the kNN-LM model (Khandelwal et al., 2020), which uses external memory only at testing time, improving the perplexity from 16.23 to 15.47 on WIKITEXT-103 (Table 2). We also evaluate a model that does not use long-term memory (denoted by TRIMELM_ext w/o M_long) for a fair comparison with kNN-LM with a continuous cache, and the difference is very small (15.55 vs. 15.47). Our results suggest that by using the contrastive loss and BM25 batching (§4.3), the model learns to better retrieve and leverage information from a large external memory.

Results: Machine Translation
To showcase the generality of our training approach TRIME on other generation tasks, we evaluate it on the IWSLT'14 German-English translation task. Since it is a sentence-level task, we do not use any local or long-term memory (M_local, M_long), as there are few repetitive tokens. We denote our model as TRIMEMT_ext. As shown in Table 5, our approach improves the vanilla Transformer by 0.82 BLEU and outperforms kNN-MT (Khandelwal et al., 2021). This demonstrates that our approach is able to improve performance on other language generation tasks with different memory access.

Model                 BLEU (↑)
Transformer enc-dec   32.58
kNN-MT                33.15
TRIMEMT_ext           33.40

Table 5: Adapting TRIME to kNN-MT (Khandelwal et al., 2021) on an IWSLT'14 machine translation task.

Analysis
We conduct ablation studies and analyses to further understand individual components of our approach. Due to the limited computation budget, some experiments on WIKITEXT-103 in this section are conducted with a small 7M Transformer model (8 layers, hidden dimension 128); the trends for smaller models are generally similar (see Appendix C and Appendix D for details).
Batching and memory construction. We first study how different data batching and memory construction strategies affect the performance when different testing memories are used. We compare our three models (TRIMELM, TRIMELM_long, TRIMELM_ext) in Table 6. This ablation study clearly shows that packing consecutive segments and segments with high BM25 scores into the same training batch improves performance when the long-range and external memory are used, demonstrating the importance of closing the gap between training and inference.
Effectiveness of using local memory. We study the effectiveness of our model TRIMELM, which uses only local memory, with different segment lengths L. As shown in Table 7, our model significantly outperforms the baselines in all the settings. This suggests that our model can leverage local memory very effectively to improve performance.
Table 7: Performance on the WIKITEXT-103 development set (7M models). We vary the segment length L to study the effectiveness of using local memory.

Leveraging long-range contexts. We study whether our model is able to handle a large long-term memory. As Figure 3 shows, our model is able to effectively handle long-range context (more than 10k tokens), which goes beyond the typical attention context. Compared to the continuous cache (Grave et al., 2017b,a), the improvement of our approach becomes larger as more long-term memory is incorporated. This suggests that our model leverages long-range context much more effectively.

Retrieval performance on external memory
When external memory is used in our experiments, we perform nearest-neighbor search over the entire memory set M_ext to retrieve the top-K keys (we use K = 1024). Table 9 compares the retrieval accuracy of our approach and kNN-LM (Khandelwal et al., 2020) for different K. Our approach outperforms kNN-LM in terms of retrieval accuracy, which explains why our final perplexity surpasses kNN-LM's when incorporating external memory.
Perplexity breakdown for different frequencies. Finally, we aim to understand which type of memory improves the perplexity of tokens in different frequency groups. We group tokens into 5 buckets according to their frequency on the development set. Table 8 shows the results for different models. Interestingly, TRIMELM and TRIMELM_long improve the perplexity of rare words (i.e., frequency ≤ 1k) while achieving similar or slightly worse results on frequent words. TRIMELM_ext improves perplexity in all the buckets.

Related Work
Memory-augmented language models. We have discussed the continuous cache (Grave et al., 2017b,a), kNN-LM (Khandelwal et al., 2020), and models that leverage representations from long-range context (e.g., Dai et al., 2019; Rae et al., 2020; Wu et al., 2022) in the previous sections. Other work also aims to combine several types of memories by learning an adaptive gating function; however, the external memory there relies on a pre-trained vanilla language model. Borgeaud et al. (2021) demonstrate remarkable performance by augmenting LMs with an external datastore of trillions of tokens; their datastore is built from chunks of text using off-the-shelf BERT embeddings (Devlin et al., 2019).
A related line of work designs efficient Transformer architectures for long inputs; comprehensive surveys of efficient Transformers are available. Our approach is orthogonal, as we only change the training objective and data batching to enable models to use large contexts during inference.
Retrieval-augmented models for downstream tasks. While our paper focuses on improving language models with memory augmentation, other works improve models on downstream tasks with a retrieval component, such as question answering (de Masson d'Autume et al., 2019; Borgeaud et al., 2021; Guu et al., 2020), dialogue, and other knowledge-intensive NLP tasks (Petroni et al., 2021).

Conclusion
In this work, we propose TRIME, a training approach for language modeling. Our training objective is inspired by contrastive learning and combines in-batch representations and static token representations. We also present three model instantiations TRIMELM, TRIMELM long , TRIMELM ext . Through carefully-designed data batching and memory construction during training, we show that our models can leverage long-range contexts and external memory effectively at testing time. Our approach adds little computational overhead and does not modify model architecture, making it compatible with other neural models and techniques.
In particular, we believe that our TRIMELM model can serve as an alternative to state-of-the-art auto-regressive language models. For future work, we are interested in training TRIME with large language models and on other text generation tasks.

A Inference Method
Formally, our testing objective is basically the same as the training objective (Eq. 5):

P(w | c) ∝ exp(E_w · f_θ(c)) + Σ_{(c_j, x_j) ∈ M_eval : x_j = w} exp(sim(g_θ(c), g_θ(c_j)) / τ),  (6)

except that we take M_eval to be a combination of M_local, M_long, and M_ext. Because M_eval may differ from the training memories, we tune the temperature term τ, which adjusts the weight of the memory component when calibrating the distribution, on the development set.

B Dataset Statistics and Tasks
We evaluate our approach on three benchmarks: WIKITEXT-103, ENWIK8, and IWSLT'14. Table 10 shows the statistics.
WIKITEXT-103 is a word-level language modeling dataset consisting of 103M training tokens. Following standard practice, we use an adaptive softmax and adaptive token embeddings in our model and report perplexity. In order to better compare with previous work, we evaluate on two model configurations: one uses a 247M Transformer model and a segment length L = 3,072, following Baevski and Auli (2019) and Khandelwal et al. (2020), and the other uses a 150M Transformer model with segment length L = 150, following Dai et al. (2019). More details are provided in Appendix C.
ENWIK8 is a character-level language modeling dataset that contains a total of 100M characters. Following previous work, we report bit-percharacter (bpc) on this dataset. We use a 12-layer Transformer model with a hidden dimension 512 and segment length L = 512.
We also evaluate on the IWSLT'14 DE→EN machine translation task, which consists of 170K translation pairs. Following Khandelwal et al. (2021), we build an external memory by taking all the translation contexts and the corresponding target tokens ((x, y_<t), y_t) from the training set. We use the output representation as f((x, y_<t)) and the input representation of the last FFN layer as g((x, y_<t)) to compute the loss. Similarly, we use BM25 to batch the training data: we encourage two target sentences with a high BM25 score to be in the same training batch. We use the default model configuration in the Fairseq library (Ott et al., 2019) and sacrebleu (Post, 2018) to compute BLEU scores (Papineni et al., 2002).

Table 10: Statistics of the 3 datasets used in our paper. WIKITEXT-103 is a word-level LM task, ENWIK8 is a character-level LM task, and IWSLT'14 is a German-English machine translation task. len denotes the average document (sentence) length in each dataset. IWSLT'14 is a sentence-level task, so incorporating long-range context will not help.

E Effect of Excluding Local Memory During Training

As described in §4.3, when training TRIMELM_ext we exclude the local memory from M_train with a probability of p. Here we study how p affects the final performance of our model. The results of using different p are shown in Table 13. We find that when p = 0, the model performs poorly with external memory, as it learns to only leverage local memory and ignores external memory during training. Increasing p mitigates this issue. We set p = 0.9 in our main experiments.

F Linear Interpolation When Using M_ext
We find that when a large set of external memory M_ext is considered during inference, the performance can be improved by calibrating a separate distribution over the memory and interpolating the output distribution with the memory distribution, similarly to kNN-LM (Khandelwal et al., 2020). We think this is because the distribution of similarity values shifts significantly during inference while the relative ranking is preserved; as a result, putting values from two different distributions into one softmax normalization is sub-optimal compared to computing two separate probabilities and interpolating them. Thus, we apply an additional linear interpolation to our output probability distribution. Specifically, we first use Eq. 6 to compute the distribution P(w | c). Then, we compute a probability distribution P_mem(w | c) over the tokens in memory as follows:

P_mem(w | c) ∝ Σ_{(c_j, x_j) ∈ M_eval : x_j = w} exp(sim(g_θ(c), g_θ(c_j)) / τ).

We linearly interpolate these two probability distributions with a coefficient λ to get the final output:

P_final(w | c) = λ P_mem(w | c) + (1 − λ) P(w | c).

We tune the temperature terms and λ on the development set.
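A sketch of this interpolation for a single context follows, with all names ours; the convention of placing λ on the memory distribution matches the formula above and is the standard kNN-LM convention.

    import torch

    def interpolate(log_p, mem_sims, mem_targets, vocab_size, tau, lam):
        # log_p:       (V,) log P(w | c) from Eq. 6
        # mem_sims:    (K,) sim(g(c), g(c_j)) of the K retrieved pairs
        # mem_targets: (K,) their target tokens x_j
        p_mem = torch.zeros(vocab_size).scatter_add_(
            0, mem_targets, torch.softmax(mem_sims / tau, dim=-1))
        p_final = lam * p_mem + (1 - lam) * log_p.exp()
        return p_final.log()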