Value-aware Approximate Attention

Following the success of dot-product attention in Transformers, numerous approximations have recently been proposed to address its quadratic complexity with respect to the input length. However, all approximations thus far have ignored the contribution of the *value vectors* to the quality of approximation. In this work, we argue that research efforts should be directed towards approximating the true output of the attention sub-layer, which includes the value vectors. We propose a value-aware objective, and show theoretically and empirically that an optimal approximation of a value-aware objective substantially outperforms an optimal approximation that ignores values, in the context of language modeling. Moreover, we show that the choice of kernel function for computing attention similarity can substantially affect the quality of sparse approximations, where kernel functions that are less skewed are more affected by the value vectors.


Introduction
The Transformer architecture (Vaswani et al., 2017) has been successful in a wide range of natural language processing tasks, including machine translation (Edunov et al., 2018), language modeling (Roy et al., 2020), question answering (Karpukhin et al., 2020), and many more. Transformers pre-trained on large amounts of text with a language modeling (LM) objective have become the standard in NLP, exhibiting surprising amounts of linguistic and world knowledge (Peters et al., 2018; Devlin et al., 2019; Petroni et al., 2019; Hewitt and Manning, 2019; Roberts et al., 2020).
The contextualizing component of the Transformer is the attention layer, in which all positions in an input sequence of length L aggregate information from the entire sequence in parallel. At its core, given L query, key and value vectors, the dot-product attention function outputs softmax(QK^T)V, where the softmax function is applied row-wise on the matrix QK^T ∈ R^{L×L}, consisting of similarity scores of the query-key pairs. (Usually the term is softmax(QK^T/√d)V, but the √d can be dropped via scaling of the queries.) Unfortunately, computing Ω(L²) similarity scores is prohibitive for long sequences.
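As a concrete reference point, the full attention computation can be sketched in a few lines of NumPy (a minimal sketch with our own naming, not the paper's code):

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)   # subtract max for stability
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention(Q, K, V):
    """Full dot-product attention softmax(QK^T)V; the 1/sqrt(d)
    temperature is folded into the queries, as noted in the text."""
    d = Q.shape[-1]
    A = softmax((Q / np.sqrt(d)) @ K.T)   # (L, L) attention weights
    return A @ V                          # (L, d) output H
```

Materializing the (L, L) matrix A is exactly the Ω(L²) cost that approximate attention methods try to avoid.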
To alleviate this, past work proposed to compute an approximation of softmax(QK^T). One major line of research focused on sparse attention variants, where only a few similarity scores are computed per position, and the rest are ignored. Methods differ in which query-key pairs are selected (Ye et al., 2019; Roy et al., 2020; Kitaev et al., 2020; Beltagy et al., 2020; Gupta and Berant, 2020; Vyas et al., 2020). A second line of research explored dense variants (Katharopoulos et al., 2020; Wang et al., 2020; Bello, 2021; Tay et al., 2020a) (cf. Tay et al. (2020b) for a survey). E.g., instead of computing the attention scores exactly for only a few query-key pairs, Choromanski et al. (2020) compute an approximation of the scores for all pairs.
In this work, we point to a lacuna in current research on efficient Transformers. While recent work has focused on approximating the attention scores softmax(QK^T), the true target of approximation should be the output of the attention sub-layer, namely H = softmax(QK^T)V, which also includes the value vectors V. We show that ignoring the value vectors has undesirable consequences, both theoretically and empirically.
To demonstrate the importance of value-aware approximation, we analyze optimal sparse attention, that is, the case where, in hindsight, the model computes dot-product similarity only with the most similar key vectors, while still ignoring the value vectors. We show that in the popular masked language modeling (MLM) setup, optimal sparse attention dramatically under-performs an optimal approximation of the true output of the attention sub-layer, H, leading to an error increase of 8-20 points. Next, by theoretically focusing on the case where queries compute similarity to the single most similar key vector, we show that approximating softmax(QK^T) is equivalent to approximating H only when the value vectors V satisfy strong orthogonality and norm constraints. Conversely, when they do not, ignoring V can lead to an unbounded approximation error.
Second, we discuss the kernel-based view of attention, where efficiency is gained by replacing the exponential kernel (corresponding to softmax) with other kernel functions (Katharopoulos et al., 2020). We show theoretically that while in the exponential-kernel case the effect of the norm of the value vectors is potentially small, switching to other kernels can dramatically increase the importance of the value vectors. We test this empirically by comparing optimal sparse attention under different kernel functions, and indeed observe that approximation quality decreases when the exponential kernel is replaced.

To conclude, we show theoretically and empirically that approximating the attention score matrix alone is insufficient, and propose that the research community should instead approximate the true output of the attention sub-layer, which importantly includes the value vectors. Our code and trained models are available at https://github.com/ag1988/value_aware_attn.

Background
We review the kernel-based view of attention (Tsai et al., 2019), which will be instructive in §3.
Given L queries, keys and values in R^d, for each query q the attention function outputs o = Σ_{i=1}^{L} α_i v_i, where α_i = κ(q, k_i) / Σ_j κ(q, k_j) for a kernel function κ (Eq. 1); standard softmax attention corresponds to the exponential kernel. Computing the L · L similarity scores of all query-key pairs is prohibitive for long sequences. Sparse attention variants relax this requirement and compute only a few similarity scores, ignoring the rest: for some S ⊆ {1, . . . , L} with |S| ≪ L, the output is õ = Σ_{i∈S} α̃_i v_i, where α̃_i = κ(q, k_i) / Σ_{j∈S} κ(q, k_j) (Eq. 2). Methods differ in how S is determined given the queries and keys, and include the use of locality bias (Beltagy et al., 2020), global memory (Gupta and Berant, 2020), and LSH hashing (Kitaev et al., 2020), among others. Conversely, instead of exactly computing the attention scores on only a few query-key pairs, dense variants compute an approximation of the true kernel values for all pairs. Such methods output Σ_i β_i v_i for some approximation β of the true attention distribution α (Choromanski et al., 2020; Peng et al., 2021).
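A sparse variant restricted to a key subset S can be sketched as follows (a per-query sketch with our own naming, not the paper's implementation):

```python
import numpy as np

def sparse_attention_row(q, K, V, S):
    """Sparse attention for one query: similarity scores are computed
    only for the key indices in S and renormalized over S."""
    scores = (K[S] @ q) / np.sqrt(K.shape[1])   # scores for selected keys only
    alpha = np.exp(scores - scores.max())       # exponential kernel, stabilized
    alpha = alpha / alpha.sum()                 # renormalize over S
    return alpha @ V[S]
```

With S = {1, . . . , L} this reduces to full attention; sparse methods differ only in how S is chosen per query.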

Optimal Sparse Attention
Prior methods for approximating attention have ignored the contribution of the value vectors V. As the true output of the attention sub-layer also depends on V, a natural question is whether it is possible to design better approximation methods by incorporating V, and if so, how much improvement is even possible?
To answer this, we focus on sparse attention and analyze the difference between an oracle sparse approximation that considers the value vectors and an oracle approximation that does not. That is, we compare the two approximations from the perspective of expressivity, ignoring any memory and computational constraints. We denote an optimal value-aware approximation that uses r key vectors per query by optimal-v-aware-r, and an optimal approximation that ignores value vectors by optimal-v-oblivious-r. We define optimal-v-oblivious-r as the output of Eq. 2 in which S is selected to be the r indices with the highest attention scores α_i. This is a natural baseline, since it is what current sparse methods try to emulate. We now explicitly derive and analyze the value-aware objective.
Let C_r = {Σ_i β_i v_i : β_i ≥ 0, Σ_i β_i = 1, |{i : β_i > 0}| ≤ r} denote the set of points in the polytope of the v_i's that can be expressed as a convex combination of at most r value vectors. The goal of value-aware approximation is to find the point in the constrained region C_r closest to the true output o, i.e., argmin_{õ∈C_r} ||o − õ||². As mentioned, this solution is termed optimal-v-aware-r.
We consider two extreme cases of r: r = 1 and r ≥ d + 1. For r ≥ d + 1, by the Carathéodory Theorem (Bárány and Karasev, 2012), o ∈ C_r and the optimal approximation error is 0. In most popular architectures, such as BERT (Devlin et al., 2019), d = 64 ≪ L. This means that, from the point of view of expressivity, optimal-v-aware-65 can obtain a perfect approximation. Conversely, we will show in §4 that the performance of optimal-v-oblivious-65 is substantially lower.
At the other extreme, when r = 1 (a single value vector), the above objective is equivalent to argmin_{i∈{1,...,L}} ||o − v_i||², which can be simplified to

argmin_i ||v_i||² (0.5 − α_i) − Σ_{j≠i} α_j ⟨v_i, v_j⟩.    (3)

This equation induces a ranking over value vectors that depends on the value vectors themselves, in contrast to a value-oblivious ranking induced solely by the attention weights α.
If v_1, . . . , v_L are orthogonal, the above equation further simplifies to argmin_i ||v_i||² (0.5 − α_i) − Σ_{j≠i} α_j · 0 = argmin_i ||v_i||² (0.5 − α_i). In this case, if some α_i ≥ 0.5 or if v_1, . . . , v_L have equal norms, this further simplifies to argmax_i α_i, and is therefore independent of the value vectors v_i, implying that a value-oblivious approximation would work well.
But such assumptions on v_1, . . . , v_L do not hold in general, and thus an approximation that depends only on the α_i's can be sub-optimal. E.g., let v_1, v_2, v_3 be the orthogonal vectors (1, 0, 0), (0, 2, 0), (0, 0, 3), respectively, and let α_1, α_2, α_3 be 0.25, 0.35, 0.4. Then v_3, the vector with the highest attention weight α_3, has a squared distance of 3.79 from the true output Σ_i α_i v_i, whereas v_1, the vector with the lowest attention weight α_1, has a squared distance of only 2.49. In this case, optimal-v-aware-1 induces exactly the opposite ranking of value vectors compared to optimal-v-oblivious-1. Moreover, if we increase the entry 3 in v_3 to infinity, the approximation error also grows infinitely. This example, and Eq. 3 in general, also show that the optimal ranking can differ significantly from the one induced by α_i ||v_i||, proposed recently by Kobayashi et al. (2020) for better interpretability of attention models.
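The example above, and the equivalence between Eq. 3 and the direct objective ||o − v_i||², can be checked numerically (a sanity-check sketch; the variable names are ours):

```python
import numpy as np

V = np.array([[1.0, 0.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 3.0]])
alpha = np.array([0.25, 0.35, 0.4])
o = alpha @ V                                # true output (0.25, 0.7, 1.2)

# Direct objective: squared distances ||o - v_i||^2 (approx. 2.49, 3.19, 3.79).
d2 = ((V - o) ** 2).sum(axis=1)

# Eq. 3 scores, computed from the Gram matrix G_ij = <v_i, v_j>.
G = V @ V.T
norms2 = np.diag(G)
cross = G @ alpha - alpha * norms2           # sum_{j != i} alpha_j <v_i, v_j>
scores = norms2 * (0.5 - alpha) - cross

# Both rankings pick v_1, the vector with the *lowest* attention weight.
assert int(np.argmin(d2)) == int(np.argmin(scores)) == 0
assert int(np.argmax(alpha)) == 2
```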

Effect of kernel function

Recently, the Linear Transformer (Katharopoulos et al., 2020) proposed to replace the exponential kernel with more efficient kernels. We now show that replacing the exponential kernel with a polynomial kernel can lead to a drop in quality for current sparse approximation methods.
Intuitively, because the kernel function affects the skewness of α, it also affects the difference between the ranking induced by the optimal value-aware approximation and the optimal value-oblivious one. For simplicity, consider the case of orthogonal value vectors, in which Eq. 3 simplifies to argmin_i ||v_i||² (0.5 − α_i). From Eq. 1, we have α_i = κ(q, k_i) / Σ_j κ(q, k_j), which is ⟨q, k_i⟩^C / Σ_j ⟨q, k_j⟩^C for the degree-C polynomial kernel. For C = 0, we have α_i = 1/L, which gives argmin_i ||v_i||²: when α is uniform, the value vectors alone determine the optimal choice. On the other hand, assuming distinct inner products, for C ≫ 0 we obtain max_i α_i ≥ 0.5, reducing the objective to argmax_i α_i, where the value vectors do not affect the approximation. The complexity of the Transformer grows exponentially with the degree C, so in practice a low C must be used (e.g., a degree-2 polynomial). In that case, α is likely to be less skewed than with the exponential kernel, and more likely to induce a sub-optimal ranking.
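The effect of the kernel on skewness is easy to see numerically. Below, the keys are constructed so that the query-key inner products are 4, 3, 2 (a toy illustration, not from the paper):

```python
import numpy as np

def attn_dist(q, K, kernel):
    """alpha_i = kappa(q, k_i) / sum_j kappa(q, k_j) for a kernel kappa."""
    s = np.array([kernel(q, k) for k in K])
    return s / s.sum()

q = np.array([1.0, 0.0])
K = np.array([[4.0, 0.0], [3.0, 0.0], [2.0, 0.0]])    # inner products 4, 3, 2
a_exp = attn_dist(q, K, lambda q, k: np.exp(q @ k))   # exponential kernel
a_poly = attn_dist(q, K, lambda q, k: (q @ k) ** 2)   # degree-2 polynomial
# The exponential kernel concentrates more mass on the top key.
print(round(float(a_exp.max()), 2), round(float(a_poly.max()), 2))  # 0.67 0.55
```

The flatter the distribution, the more the optimal choice in Eq. 3 is driven by the value norms rather than by α.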
In the next section, we empirically verify the above observations and show a significant performance gap between value-oblivious approximations and value-aware ones.

Experiments
We empirically verify our observations in the context of training causal and masked language models, which are known to strongly correlate with performance on downstream applications (Devlin et al., 2019).

Masked LM task
We form examples by sampling sequences and replacing sub-words with <mask>, following the procedure in (Devlin et al., 2019). The model is trained to maximize the log probability of the masked-out tokens, and we evaluate the error of the model as the percentage of masked tokens predicted incorrectly. As approximate attention becomes increasingly relevant for long sequences, we train ROBERTA-4096 on sequences of length 4096 (Fig. 1). Training was warm-started from ROBERTA-base (Liu et al., 2019); full details on the experimental setup are in §A.1. After training the model for ∼2.5M steps, its error (that is, the proportion of incorrect predictions) on the evaluation set was 24.2, compared to 26.6 for an analogous training on 512-long sequences, confirming that tokens in ROBERTA-4096 indeed attend over longer distances and yield higher-quality representations.

We then replace the attention function of the trained model with various approximation schemes and evaluate the resulting model on the evaluation set. We first compare optimal-v-oblivious-r to optimal-v-aware-r. We know that the approximation error of the value-aware approximation is 0 for r > 64. For r = 1, we exhaustively go through all possible values and choose the one that minimizes the value-aware objective. As seen in Fig. 2 and Table 1, there is a substantial gap between the two approximations. For instance, optimal-v-oblivious-65 gives an MLM error of 43.5, whereas the error of optimal-v-aware-65 is 24.2, since it can perfectly approximate full attention. Moreover, we compare optimal-v-oblivious-r to existing approximations: (a) sliding-window-r, where a position attends to r/2 positions on each side; (b) LSH attention (Kitaev et al., 2020); and (c) Performer attention (Choromanski et al., 2020). Fig. 2 shows that sliding-window-r trails behind optimal-v-oblivious-r.
LSH attention, which tries to emulate optimal-v-oblivious-r, either requires a large number of hash rounds or a large chunk size. Similarly, Performer attention provides an unbiased approximation of the exponential kernel but suffers from high variance in practice.

Table 1: MLM error of ROBERTA-4096 on the evaluation set using the approximate attention variants described in §4. OVO-r: optimal-v-oblivious-r; OVA-r: optimal-v-aware-r. In LSH, each query attends to a total of r keys per hash round.

Causal LM task

To investigate the effect of the kernel function on the quality of value-oblivious methods, we train a 6-layer Transformer LM over 512 tokens on WikiText-103 (Merity et al., 2017) (details in §A.2). We train 3 models with identical hyperparameters using the exponential, degree-2 polynomial, and elu kernels, respectively, and evaluate the trained models with value-aware and value-oblivious approximations. Again, optimal-v-aware-r substantially outperforms optimal-v-oblivious-r (Table 2), pointing to the potential of approximating the value-aware objective.

Table 2: Evaluation perplexity of models using approximate attention. OVO-r: optimal-v-oblivious-r; OVA-r: optimal-v-aware-r.
More importantly, comparing the approximation quality across the different kernel functions (Fig. 3), we see that the gap between the three kernels is small when using full attention (512 keys). However, convergence to full-attention quality as r grows is much slower for the elu kernel, and especially for the degree-2 polynomial, demonstrating that approximation based on the top-r key vectors is sub-optimal when switching to a less skewed kernel, which is more affected by the value vectors.

Conclusions
In this work, we provide theoretical and empirical evidence against the current practice of approximating the attention matrix in Transformers while ignoring the value vectors. We propose a value-aware objective and argue that efforts to develop more efficient Transformers should consider this objective function as a research target.

A Supplemental Material
A.1 Masked LM task

The instances for the MLM task (§4) were formed separately using the corpora listed in Table 3. For each dataset, after appending an </s> token at the end of each document, the documents were arranged in a random order and concatenated into a single long text, which was then tokenized into a list of sub-words. Depending on the final input sequence length L of the experiment (512/4096), this list was chunked into full-length L − 2 sequences, which were then masked randomly following (Devlin et al., 2019) and enclosed within <s> and </s> tokens. To handle sequences longer than 512 tokens, positional embeddings were used following (Gupta and Berant, 2020).

Table 3: Corpora used for the MLM task.
BookCorpus (Zhu et al., 2015): 1.06B / 1.02B / 1.02M
ArXiv (Cohan et al., 2018): 1.78B / 1.53B / 1.02M
PubMed (Cohan et al., 2018): 0.47B / 510M / 1.02M
PG19 (Rae et al., 2020): 3.06B / 510M / 1.02M

Details of LSH attention: Given L queries and L keys in R^d, in each hash round we sample a new matrix R ∈ R^{C/2 × d} of standard Gaussians and hash the queries and keys as H_R(x) = argmax([−Rx; Rx]) ∈ {1, . . . , C}. We rearrange the queries (and similarly the keys) according to their hash value, breaking ties using the original position, and then chunk them into L/B chunks of B vectors each. Denoting these chunks as Q_1, . . . , Q_{L/B} and K_1, . . . , K_{L/B}, for each query in Q_i we compute its similarity scores with respect to all keys in K_{i−1}, K_i, i.e., in each hash round a query attends to r = 2B keys. For each query, these similarity scores are accumulated over the hash rounds and finally normalized by their sum to obtain attention scores over the keys. As recommended in the original paper (Kitaev et al., 2020), we use C = 2L/B = 4L/r, which in practice can be sub-optimal, as the rearrangement destroys the original locality structure.
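The hashing step described above can be sketched as follows (our naming; a sketch of the bucketing only, without the chunked attention):

```python
import numpy as np

def lsh_hash(X, R):
    """Bucket ids H_R(x) = argmax([-Rx; Rx]): R has shape (C/2, d),
    so each vector lands in one of C buckets."""
    P = X @ R.T                                        # (L, C/2) projections
    return np.argmax(np.concatenate([-P, P], axis=-1), axis=-1)

rng = np.random.default_rng(0)
L, d, C = 16, 8, 8
R = rng.normal(size=(C // 2, d))          # fresh standard Gaussians per round
buckets = lsh_hash(rng.normal(size=(L, d)), R)   # one bucket id per vector
```

Queries and keys are then sorted by bucket id (ties broken by original position) and chunked into blocks of B, as described above. Note that x and −x always fall in buckets that differ by C/2 (mod C).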
Details of Performer attention: Given L queries and L keys in R^d, we divide each vector by d^{1/4} to account for the temperature term in dot-product attention. For a given number F of features, we sample a random orthogonal matrix R ∈ R^{F×d} as described in (Saxe et al., 2013) and provided as a tensor initialization option in PyTorch. We then map each vector to the feature space as Φ(x) = (1/√F) · exp(Rx − ||x||²/2) ∈ R^F, where the subtraction and exp operations are applied element-wise. The similarity score of a query-key pair (q, k) is computed as ⟨Φ(q), Φ(k)⟩ and is normalized by the sum of the similarity scores of q with all the keys. Computing this directly leads to numerical instability, so we instead compute Φ(q) = (1/√F) · exp(Rq − ||q||²/2 − max(Rq)) for queries and Φ(k) = (1/√F) · exp(Rk − ||k||²/2 − max(RK)) for keys, where K is the matrix of all keys and max is taken over all elements of its input.
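The stabilized feature map can be sketched as below (our naming; for brevity R here is plain Gaussian rather than orthogonal, which only affects the variance of the estimate):

```python
import numpy as np

def phi(X, R, shift):
    """Positive random features Phi(x) = exp(Rx - ||x||^2/2 - shift) / sqrt(F),
    applied row-wise; `shift` is the stabilizing constant from the text."""
    F = R.shape[0]
    return np.exp(X @ R.T - 0.5 * (X ** 2).sum(-1, keepdims=True) - shift) / np.sqrt(F)

rng = np.random.default_rng(0)
L, d, F = 8, 4, 256
Q = rng.normal(size=(L, d)) / d ** 0.25        # temperature: divide by d^(1/4)
K = rng.normal(size=(L, d)) / d ** 0.25
R = rng.normal(size=(F, d))
Pq = phi(Q, R, (Q @ R.T).max(axis=-1, keepdims=True))  # per-query shift max(Rq)
Pk = phi(K, R, (K @ R.T).max())                        # global key shift max(RK)
S = Pq @ Pk.T                                  # estimates of exp(<q, k>)
A = S / S.sum(axis=-1, keepdims=True)          # rows: approximate attention
```

The shifts rescale each Φ(q) by a per-query constant and all Φ(k) by one shared constant, so they cancel after row normalization; this is why subtracting max(Rq) and max(RK) is safe.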

A.2 Causal LM task
For this task, we used the language modeling framework provided by Fairseq.