Adaptive Attention for Sparse-based Long-sequence Transformer



Introduction
Transformer-based models (Vaswani et al., 2017) have achieved state-of-the-art performance on a wide variety of natural language processing tasks (Devlin et al., 2019; Liu et al., 2019; Yang et al., 2019). They have also gradually been applied to other research fields such as speech and computer vision (Dong et al., 2018; Li et al., 2019; Zhang et al., 2020; Dosovitskiy et al., 2021; Zhu et al., 2021; Touvron et al., 2021). Although the self-attention module, the core component of the Transformer, can capture global context from the whole sequence, its time and memory complexity are both quadratic in the sequence length. This makes it difficult for the Transformer to process long sequences efficiently and effectively.
Recently, a wide spectrum of efficient Transformers (Child et al., 2019; Ho et al., 2019; Rae et al., 2020; Zhao et al., 2019; Kitaev et al., 2020; Tay et al., 2020; Beltagy et al., 2020; Choromanski et al., 2020; Wang et al., 2020; Zaheer et al., 2020; Roy et al., 2021; Xiong et al., 2021; Tay et al., 2021a; Ma et al., 2021; Chen, 2021; Zhu and Soricut, 2021; Liu et al., 2022) have been proposed to tackle the problem of efficiency. They can be roughly divided into the following directions: sparse attention, low-rank methods, and kernel methods. Because sparse attention is intuitive and interpretable in addition to being efficient, we focus on this method in this paper. It usually utilizes some strategy or pattern to limit the number of tokens involved in the attention calculation. Different from traditional sparse Transformers (Martins and Astudillo, 2016; Correia et al., 2019; Peters et al., 2019), which use modified softmax functions and pattern-related quadratic computation, recent works mainly adopt sliding windows to achieve linear complexity. For example, Longformer (Beltagy et al., 2020) employs an attention pattern that combines local windowed attention with task-motivated global attention while scaling linearly with the sequence length. BigBird (Zaheer et al., 2020) incorporates random attention (queries attend to random keys) besides global tokens and local sliding windows. However, these hand-crafted attention patterns are usually selected empirically or randomly, which is not an ideal solution for modeling long sequences. How to adaptively select useful tokens for sparse attention according to the context remains an important open problem.
To address these issues, we propose A²-Former, which uses adaptive attention to model longer sequences. It selects useful tokens in sparse attention automatically via learnable position vectors, which consist of meta position and offset position vectors. Because the elements of the learnable offset position vector are not integers, we utilize linear interpolation to gather discrete vectors from the original input embedding matrix. Position visualization further shows that traditional attention patterns are not enough to cover the valuable positions automatically selected by models. Experiments on Long-Range Arena, a systematic and unified benchmark with different tasks, show that our model achieves further improvements in performance compared with other sparse-based Transformers.
Overall, the main contributions are as follows: • We propose a novel efficient Transformer, A²-Former, which replaces hand-crafted attention patterns with learnable adaptive attention in sparse attention. In addition, position visualization (Figure 3) shows that traditional attention patterns are not enough to cover the useful positions automatically selected by models.
• We adopt an interpolation technique to help the model gather discrete positions with a continuous weight matrix. By combining the meta position and the generated offset position, the positions of tokens can be selected dynamically according to the context.
• Experiments on different long-sequence tasks validate the effectiveness of our model. In particular, compared with the previous best sparse attention model, BigBird (Zaheer et al., 2020), our model achieves better results.
Related Work

Sparse attention methods usually limit the field of view to fixed or random patterns. These patterns can also be used in combination. For example, Sparse Transformer (Child et al., 2019) combines strided and fixed factorized attention, assigning half of its heads to each pattern, to reduce the complexity of the traditional Transformer. Longformer (Beltagy et al., 2020) integrates a windowed local-context self-attention and task-oriented global attention that encodes inductive bias about the corresponding task. BigBird (Zaheer et al., 2020) incorporates random attention besides global attention and local window attention, where random attention means that each query attends to a small number of random keys. However, it is still difficult for these hand-crafted, random, or combined attention patterns to select valuable pairs in the sparse attention calculation. Different from them, our proposed sparse attention mechanism can automatically and efficiently learn the positions that should be selected and calculated. Notably, our model also differs from traditional sparse Transformers (Martins and Astudillo, 2016; Correia et al., 2019; Peters et al., 2019), which focus only on sparse softmax and its threshold and still require quadratic computation to determine the sparsity pattern.
Low-rank and kernel methods are other solutions for improving the efficiency of the Transformer. Low-rank methods usually assume a low-rank structure in the self-attention matrix. For example, Linformer (Wang et al., 2020) decomposes the original scaled dot-product attention into multiple smaller attentions through linear projections, such that the combination of these operations forms a low-rank factorization of the original attention. Kernel methods rewrite the self-attention mechanism through kernelization. For example, Performer (Choromanski et al., 2020) scales linearly rather than quadratically in the number of tokens in the sequence; it is characterized by subquadratic space complexity and does not incorporate any sparsity pattern priors. Different from these mathematical and theoretical methods, our proposed method is still based on sparse attention but focuses more on how to find and learn attention patterns effectively and efficiently.

Methodology

Preliminary
Let x_q, x_k, and x_v denote the query, key, and value elements in the Transformer, respectively, where L is the sequence length and H is the dimension of the hidden states. Self-attention in the vanilla Transformer can be calculated by

$$\mathrm{Attn}(x_q, x) = \sum_{k \in \Psi_k} \alpha_{qk} \cdot W x_k, \tag{1}$$

where W is the learnable weight for x_v. The attention weights $\alpha_{qk} \propto \exp\left\{\frac{(W' x_q)^{\top}(W'' x_k)}{\sqrt{H}}\right\}$, where W′ and W′′ are learnable weight matrices for x_q and x_k. The attention weights are normalized as $\sum_{k \in \Psi_k} \alpha_{qk} = 1$, ensuring that they represent the relative importance of each key vector in the set Ψ_k for the query vector x_q.
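As a concrete reference, the following NumPy sketch computes Eq. 1 for a single query. The scaled dot-product form and all shapes are standard assumptions for illustration, not the paper's exact configuration.

```python
import numpy as np

L, H = 6, 3                        # sequence length, hidden size
rng = np.random.default_rng(0)

x = rng.standard_normal((L, H))    # input sequence
x_q = x[0]                         # one query element
W = rng.standard_normal((H, H))    # value projection (W in Eq. 1)
W1 = rng.standard_normal((H, H))   # query projection (W' in Eq. 1)
W2 = rng.standard_normal((H, H))   # key projection (W'' in Eq. 1)

scores = (W1 @ x_q) @ (W2 @ x.T) / np.sqrt(H)   # one score per key
alpha = np.exp(scores) / np.exp(scores).sum()   # normalized: alpha.sum() == 1
attn = alpha @ (x @ W.T)                        # weighted sum of projected values
```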
For sparse attention, we can also express previous models in a unified form. We only consider the query and key in the Transformer in the following discussion. Sparse attention can thus be represented as

$$\mathrm{SAttn}(x_q, x) = \sum_{k=1}^{K} \alpha_{qk} \cdot W x_{p_{qk}}, \tag{2}$$

where k indexes the sampled keys and K is the total number of sampled keys. Because only a small set of keys is utilized in sparse attention, K ≪ L. p_q represents the positions of the K sampled keys for the query x_q. Different models utilize different patterns to select each sampling position p_qk ∈ p_q, such as sliding windows or random generation. Our proposed adaptive attention is also based on sparse attention, and can be further refined into the following form:

$$\mathrm{AAttn}(x_q, x) = \sum_{k=1}^{K} \alpha_{qk} \cdot W x_{\bar{p}_q + \beta_{qk}}, \tag{3}$$

where β_qk represents the offset position of the selected key k for the query x_q, and p̄_q represents the meta position predefined for each query x_q according to its absolute index. That is to say, the final position of keys p_qk is obtained from the meta position p̄_q and the offset position β_qk. Because p̄_q + β_qk is a float, we adopt linear interpolation to compute x_{p̄_q+β_qk}. The detailed calculation process is described in the next subsection.
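The unified form in Eq. 2 separates the attention computation from the position rule p_q. A minimal sketch, assuming a sliding-window rule as a stand-in for the hand-crafted patterns discussed above, makes the O(L·K) cost explicit:

```python
import numpy as np

L, H, K = 8, 4, 3
x = np.random.default_rng(1).standard_normal((L, H))

def window_positions(q, K, L):
    # Hand-crafted pattern: K consecutive key positions centered on the query.
    return np.clip(np.arange(q - K // 2, q - K // 2 + K), 0, L - 1)

for q in range(L):
    p_q = window_positions(q, K, L)   # positions of the K sampled keys
    keys = x[p_q]                     # gathered keys, K << L
    assert keys.shape == (K, H)       # only K of the L keys are touched
```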

Adaptive Attention
As shown in Figure 1, we propose adaptive attention to learn sampling positions dynamically in sparse attention. The pipeline of our proposed adaptive attention is illustrated in Algorithm 1. For convenience, we describe it in the form of iteration rather than in batch form. We take L = 6, H = 3, K = 3 as an example to illustrate the whole process from input to output.
First, we assign the meta position p̄_q = {p̄_qk}^K in Eq. 3 according to the absolute index of the query token. As shown in Figure 1, the meta positions run from 0 to 5 for a sentence with 6 tokens. The positions of the sampled keys are generated according to the meta position of the query. We take the orange token (in Figure 1) as the query to obtain the corresponding representation after adaptive attention.
Then, we use a learnable weight matrix Ŵ ∈ R^{K×H} to obtain the offset position β_q ∈ R^K for the K sampled keys:

$$\beta_q = \hat{W} x_q. \tag{4}$$

As shown in Eq. 5, we can obtain the final position p_q ∈ R^K from the meta position p̄_q ∈ R^K by combining it with the offset position:

$$p_q = \bar{p}_q + \beta_q. \tag{5}$$
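A minimal sketch of Eqs. 4 and 5 follows. The shapes Ŵ ∈ R^{K×H} and β_q ∈ R^K are given above; applying Ŵ directly to the query vector x_q is the dimensionally consistent choice, but the exact input to Ŵ is our assumption.

```python
import numpy as np

L, H, K = 6, 3, 3
rng = np.random.default_rng(2)
x = rng.standard_normal((L, H))
W_hat = 0.1 * rng.standard_normal((K, H))  # learnable offset weights (W-hat)

q = 2                    # absolute index of the query token
p_bar = float(q)         # meta position from the query's absolute index
beta = W_hat @ x[q]      # offset positions beta_q in R^K (Eq. 4)
p = p_bar + beta         # final fractional positions p_q in R^K (Eq. 5)
```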
Because the final position p_q is not an integer vector, it cannot be used directly to select the sampled keys. Inspired by previous works (Dai et al., 2017; Zhu et al., 2021) in computer vision, we transform bilinear interpolation of two-dimensional images into linear interpolation of one-dimensional text. That is to say, we utilize linear interpolation to gather the vectors at the corresponding positions. After we rescale each element p_qk ∈ p_q to [0, L], we round it down and up to i = ⌊p_qk⌋ and j = ⌈p_qk⌉, respectively (j − i = 1). Then we can gather x_i and x_j according to the integers i, j from the input x. According to a variation of the linear interpolation formula, the representation of the sampled key x_{p̄_q+β_qk} can be calculated by

$$x_{p_{qk}} = (j - p_{qk}) \cdot x_i + (p_{qk} - i) \cdot x_j. \tag{6}$$

Next, we use the learnable weights α_qk to obtain the weights of the different sampled keys for the query x_q, and then obtain the final weighted representation AAttn(x_q, x) in Eq. 3.

We can further optimize the complexity for some sequence-level classification tasks without pre-training. Since the sequence-level representation is more useful than the token-level one in these tasks, we can convert x ∈ R^{L×H} to x′ ∈ R^{L′×H} by linear projection, where L′ can be set to half of L or even smaller. The detailed performance is further analyzed in the next section.
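The gather in Eq. 6 takes only a few lines. The sketch below clamps positions to [0, L−1] rather than [0, L] so that both integer neighbors stay in bounds, an implementation detail we assume:

```python
import numpy as np

def interp_gather(x, p):
    # Gather x at fractional positions p via linear interpolation (Eq. 6).
    L = x.shape[0]
    p = np.clip(p, 0.0, L - 1.0)     # clamp positions into the valid range
    i = np.floor(p).astype(int)      # lower integer neighbor
    j = np.minimum(i + 1, L - 1)     # upper integer neighbor (j - i = 1)
    w = p - i                        # blend weight toward x[j]
    return (1.0 - w)[:, None] * x[i] + w[:, None] * x[j]

x = np.arange(12.0).reshape(6, 2)    # toy input with L = 6, H = 2
print(interp_gather(x, np.array([0.5, 2.25, 4.9])))
```

For example, position 2.25 blends rows 2 and 3 with weights 0.75 and 0.25, matching (j − p)·x_i + (p − i)·x_j.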

Datasets
Long-Range Arena (LRA) (Tay et al., 2021b) is a systematic and unified benchmark for evaluating sequence models under the long-context scenario. It includes six tasks that assess different capabilities of efficient Transformers, such as their ability to model relations and hierarchical/spatial structures, their generalization capability, etc. These tasks cover different domains, such as math, language, image, and spatial reasoning. Following the original datasets, we use accuracy as the metric for these tasks.

Implementation Details
Because different tasks have different lengths and characteristics, we use the same hyper-parameters as those described in (Tay et al., 2021b) for a fair comparison. Specifically, the maximum length is set to 2,000, 4,000, and 4,000 for the ListOps, Text, and Matching tasks, respectively. The hidden state size in attention is set to 512, 256, and 128 for the ListOps, Text, and Matching tasks, respectively. In our experiments, Adamax (Kingma and Ba, 2014) is used as our optimizer with a 0.05 learning rate. The sampling size K for each token is ten in all tasks. To prevent overfitting, we use dropout with a rate of 0.1. We integrate our attention into the igloo framework (Sourkov, 2018) and run the experiments in Keras with the TensorFlow backend on NVIDIA V100 GPUs.
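For quick reference, the reported settings can be collected in a single configuration object; the key names below are purely illustrative, while the values come from the text above:

```python
# Hypothetical config layout; values are the reported hyper-parameters.
config = {
    "max_length":  {"ListOps": 2000, "Text": 4000, "Matching": 4000},
    "hidden_size": {"ListOps": 512, "Text": 256, "Matching": 128},
    "optimizer": "Adamax",
    "learning_rate": 0.05,
    "sampling_size_K": 10,
    "dropout": 0.1,
}
```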

Results
We compare our model with the following state-of-the-art methods as baselines, including sparse attention methods as well as low-rank and kernel methods. The sparse attention methods include Sparse Transformer (Child et al., 2019), Longformer (Beltagy et al., 2020), BigBird (Zaheer et al., 2020), and so on. The results on five tasks are summarized in Table 1. Our proposed A²-Former achieves 62.21 average accuracy, outperforming the best sliding-window-based sparse model, BigBird (Zaheer et al., 2020), by 7.2%. Thus, the adaptive attention approach proposed in this paper is superior to traditional hand-crafted, random, or combined patterns in sparse-based Transformers.

Analysis
As shown in Figure 2, we further analyze the impact of different configurations and parameters on five different tasks. As mentioned above, our proposed A²-Former achieves a substantial improvement over the previous best sparse attention model, BigBird (Zaheer et al., 2020), which shows that even models combining multiple manual attention patterns are still inferior to models that learn attention patterns automatically. We then halve the maximum input length L, shrink the hidden state size H, and reduce the sampling number K. We observe that the performance of A²-Former decreases in each case compared with the original model. Specifically, a shorter length means less time but also less content, so it is important to find a balance between efficiency and effectiveness for each task, although the impact of length on some sequence-level classification tasks is not significant. For adaptive sparse attention, K limits the number of tokens involved in the calculation in each row of the attention matrix, which is another factor that needs to be balanced.

Visualization
As shown in Figure 3, we randomly selected two examples for visualization. To study the distribution of positions, we only show the positions of the selected tokens in the sparse attention matrix. The maximum length of the long sequences is 2,000. It is obvious that previous hand-crafted attention patterns, such as sliding-window attention, are not enough to cover the positions automatically selected by the model. As a general trend, the selected positions are indeed distributed along the diagonal, but covering them with a sliding window would require a window size of about half the maximum length, which is unacceptable in terms of efficiency.

Conclusion
In this paper, we propose a novel sparse-based Transformer, A²-Former, which replaces hand-crafted attention patterns with learnable adaptive attention in sparse attention. We adopt an interpolation technique to help the model gather discrete positions with continuous position vectors. By combining the meta position and the generated offset position, the positions of tokens can be selected dynamically according to the context. Position visualization further shows that traditional attention patterns are not enough to cover the useful positions automatically selected by models. Experiments on LRA show that our model achieves significant improvements over previous sparse Transformers based on sliding windows.

Figure 1: The structure of adaptive attention.

Algorithm 1: Adaptive Attention
input: an input matrix x ∈ R^{L×H}
output: AAttn(x_q, x) after adaptive attention
1  begin
2      Generate the meta position p̄_q;
3      for each q ∈ [1, L) do
4          Calculate the offset position β_qk via Eq. 4;
5          Calculate the final position p_qk via Eq. 5;
6          Rescale each element p_qk to [0, L];
7          Round p_qk down and up to i = ⌊p_qk⌋ and j = ⌈p_qk⌉, respectively;
8          Gather x_i, x_j according to the integer vectors i, j from the input x;
9          Calculate the representation x_{p̄_q+β_qk} from x_i, x_j via Eq. 6;
10         Calculate the attention weights α_qk of the different sampled keys;
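To tie the steps together, here is a minimal end-to-end rendering of Algorithm 1 in NumPy. The offset rule (Ŵ applied to the query) and the prediction of α_qk from the query are assumptions consistent with the notation above, not necessarily the authors' exact implementation:

```python
import numpy as np

def adaptive_attention(x, W_hat, W_alpha, W_v):
    # One adaptive-attention pass over x of shape (L, H), per Algorithm 1.
    L, H = x.shape
    out = np.empty_like(x)
    for q in range(L):                        # iterate over queries (line 3)
        beta = W_hat @ x[q]                   # offset positions (Eq. 4)
        p = np.clip(q + beta, 0.0, L - 1.0)   # final positions, rescaled (Eq. 5)
        i = np.floor(p).astype(int)           # lower neighbors (line 7)
        j = np.minimum(i + 1, L - 1)          # upper neighbors
        w = p - i
        keys = (1 - w)[:, None] * x[i] + w[:, None] * x[j]  # Eq. 6 gather
        scores = W_alpha @ x[q]               # per-key logits from the query
        alpha = np.exp(scores) / np.exp(scores).sum()       # weights alpha_qk
        out[q] = alpha @ (keys @ W_v.T)       # weighted sum of projected keys
    return out

L, H, K = 6, 3, 3
rng = np.random.default_rng(3)
x = rng.standard_normal((L, H))
y = adaptive_attention(x,
                       0.1 * rng.standard_normal((K, H)),  # W_hat (offsets)
                       rng.standard_normal((K, H)),        # W_alpha (weights)
                       rng.standard_normal((H, H)))        # W_v (values)
```

Each query touches only K positions, so the loop body costs O(K·H) and the whole pass is linear in L rather than quadratic.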

Table 1: Experimental results on five different tasks, i.e., ListOps, Text, Retrieval, Image, and Pathfinder. The last four rows are the main sparse attention methods for comparison. (The best model is in boldface and the second best is underlined.)