Efficient Long-Range Transformers: You Need to Attend More, but Not Necessarily at Every Layer

Pretrained transformer models have demonstrated remarkable performance across various natural language processing tasks. These models leverage the attention mechanism to capture long- and short-range dependencies in the sequence. However, the (full) attention mechanism incurs high computational cost, quadratic in the sequence length, which is unaffordable in tasks with long sequences, e.g., inputs with 8k tokens. Although sparse attention can be used to improve computational efficiency, as suggested in existing work, it has limited modeling capacity and often fails to capture complicated dependencies in long sequences. To tackle this challenge, we propose MASFormer, an easy-to-implement transformer variant with Mixed Attention Spans. Specifically, MASFormer is equipped with full attention to capture long-range dependencies, but only at a small number of layers. For the remaining layers, MASFormer employs only sparse attention to capture short-range dependencies. Our experiments on natural language modeling and generation tasks show that a decoder-only MASFormer model of 1.3B parameters can achieve competitive performance with vanilla transformers with full attention while significantly reducing computational cost (by up to 75%). Additionally, we investigate the effectiveness of continual training with long sequence data and how sequence length impacts downstream generation performance, which may be of independent interest.

These models leverage the attention mechanism (Vaswani et al., 2017) to compute the dependency score for each pair of tokens in an input sequence.
Some practical tasks require these transformer models to handle long-sequence inputs, e.g., 8k tokens. For example, chatbot systems gather long-term contexts of user interactions to generate informative texts (Roller et al., 2021). Summarization of news, government reports, and academic papers requires models to take long-sequence inputs to generate comprehensive summaries (Shaham et al., 2022); otherwise, models often miss important information. Note that typical transformer models apply full attention to capture token dependencies pair-wise, which leads to quadratic time and space complexity w.r.t. the input length. Such a complexity is prohibitive for long sequences. In particular, it incurs massive memory consumption during backpropagation. For example, a transformer model with 250M parameters consumes over 80GB of GPU memory when the sequence length is 8k (Zuo et al., 2022).
To address this scalability issue, various approaches have been proposed to reduce the complexity. One approach is sparse attention, which restricts each token to attend to a subset of tokens based on predefined sparsity patterns (Beltagy et al., 2020; Zaheer et al., 2020; Ainslie et al., 2020). For instance, block sparse attention (Kitaev et al., 2020; Ma et al., 2023) divides the input sequence into several blocks, and only intra-block attention is performed. Similarly, sliding-window attention (Beltagy et al., 2020; Zaheer et al., 2020; Ainslie et al., 2020) allows each token to attend to its neighboring tokens within a sliding window. These methods, though reducing the complexity of full attention, cannot sufficiently capture long-range dependencies. Other variants, such as kernel approximation (Peng et al., 2021) and low-rank approximation (Wang et al., 2020; Chen et al., 2021) methods, share a similar spirit and similar drawbacks. To compensate for the lack of long-range dependencies, LongT5 (Guo et al., 2021) introduces global tokens that are obtained by average pooling over every block of tokens (Ainslie et al., 2020). However, the block pooling operations can weaken the signal of crucial tokens and prevent long-range dependencies from being detected.
In addition to these methods, state space models (SSMs) prespecify global dependency patterns to capture only long-range dependencies (Gu et al., 2020, 2021; Li et al., 2022; Zuo et al., 2022; Ma et al., 2023; Smith et al., 2023). These models can be regarded as linear recurrent neural networks with specifically designed fixed weights. As they are tailored for global dependencies, SSMs fail to effectively capture local dependencies. To combine both local and global dependencies, SPADE (Zuo et al., 2022) and MEGA (Ma et al., 2023) augment transformer layers equipped with local attention with SSM layers. However, state space methods require sophisticated implementation and often encounter computational instability during backpropagation, especially when scaling up to large model sizes (Gupta et al., 2022). SPADE and MEGA hence inherit these drawbacks.
Note that the aforementioned methods apply the same attention mechanism at every layer. We challenge this conventional wisdom and propose a transformer variant, MASFormer (Mixed Attention Span transFormer). MASFormer utilizes full attention only at a subset of layers and employs sparse attention at the remaining layers. Our design is motivated by the phenomenon that most contexts in NLP data display a great deal of locality of reference (Zaheer et al., 2020; Beltagy et al., 2020). That is, most of the information about a token can be derived from its neighboring tokens. In contrast, long-range dependencies among tokens are sparse and infrequent. Consider an academic paper as an example. Within a paragraph, there exist numerous short-term dependencies: neighboring tokens are closely connected to convey meaningful semantics. Across paragraphs, there can be a small number of long-range dependencies; for example, tokens associated with the primary theme of the paper exhibit rare and weak dependencies across a long span. Since long-range dependencies occur much less frequently, a few layers of full attention are adequate to capture them. In stark contrast, short-term dependencies are more frequent, necessitating local attention in the majority of layers to fully extract these signals.
To demonstrate the effectiveness of MASFormer, we conduct experiments on natural language modeling (ArXiv and PubMed, Cohan et al. (2018)) and natural language generation (ArXiv, Cohan et al. (2018), and SCROLLS, Shaham et al. (2022)) tasks. Specifically, we compare the performance of MASFormer to other attention methods using a pretrained GPT-2 model (Radford et al., 2019) of 1.3 billion parameters. Our empirical results demonstrate that MASFormer consistently outperforms baseline methods across different attention costs (i.e., the total number of computed attention scores). In particular, MASFormer can achieve comparable performance to full attention while significantly reducing the computational cost. For example, with 27% of its attention cost, MASFormer achieves an R2 score close to that of full attention on the QMSUM dataset.
We also make additional discoveries with MASFormer, which are of independent interest. Firstly, we investigate the effectiveness of continual training for long sequence modeling. Many publicly available models are pre-trained with sequences shorter than 2048 tokens and often fail to perform well on longer sequences (e.g., 8k/16k tokens). To bridge the gap, we explore the option of continual training to adapt these models to long sequences, thereby avoiding pre-training from scratch. We discuss its effectiveness with MASFormer in Section 4.3. Secondly, we showcase that increasing sequence length can yield more performance gains on downstream tasks than on NLM tasks evaluated by perplexity. We are aware of the recent findings by Sun et al. (2021) that increasing context length exhibits limited impact on NLM perplexity. Nevertheless, when applying MASFormer to downstream tasks like long-context summarization, we find that model performance benefits significantly from extended context length. This difference arises from the fact that predicting the next tokens in NLM primarily relies on locality of reference; capturing infrequent long-range tokens can improve perplexity, but not significantly. Therefore, we emphasize the necessity of evaluating model performance on downstream tasks that require long-range dependencies. Furthermore, our empirical evidence suggests that increasing the length can improve performance only if models possess sufficient capability to handle the additional long-range information. Local attention, as a counterexample, often fails to capture long-range signals and hence benefits much less from long sequences.

Pretrained Language Models
Pre-trained transformer models (Devlin et al., 2019; Liu et al., 2019; Brown et al., 2020; Dosovitskiy et al., 2020; He et al., 2021b,a) have manifested superior performance in various NLP tasks. These models are often pre-trained on enormous amounts of unlabeled data in an unsupervised/self-supervised manner such that they can learn rich semantic knowledge. By further fine-tuning these pre-trained models, we can effectively transfer such knowledge to benefit downstream tasks (Zhang et al., 2023).
Existing research on long-range transformers commonly requires pre-training the proposed models from scratch to accommodate new architectures and long inputs (Guo et al., 2021; Zuo et al., 2022). However, the significant training overhead raises a barrier to the widespread adoption of these methods across different language models. Motivated by this, we explore the possibility of leveraging existing pre-trained models and adapting them to long sequences through continual training.

Attention Mechanism
Suppose the input to a layer is X ∈ R^{n×d}, where n is the input sequence length and d is the embedding dimension. The self-attention mechanism outputs

Attn(X) = softmax(QK⊤/√d)V, where Q = XW_Q, K = XW_K, V = XW_V.

Such full attention can simultaneously evaluate the alignment between any pair of tokens in the sequence. Specifically, denote the attention score matrix A = softmax(QK⊤/√d); then A_ij captures the alignment between tokens i and j. A typical transformer model applies full attention at every layer. Denote the number of layers as L; then its attention cost is Ln².
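As a reference, the full attention computation above can be sketched in a few lines of NumPy (a minimal single-head version, without masking or batching):

```python
import numpy as np

def full_attention(X, Wq, Wk, Wv):
    """Single-head full self-attention: softmax(Q K^T / sqrt(d)) V."""
    d = Wq.shape[1]
    Q, K, V = X @ Wq, X @ Wk, X @ Wv
    scores = Q @ K.T / np.sqrt(d)                 # (n, n) pairwise alignments A_ij
    scores -= scores.max(axis=-1, keepdims=True)  # subtract row max for stability
    A = np.exp(scores)
    A /= A.sum(axis=-1, keepdims=True)            # row-wise softmax
    return A @ V                                  # (n, d) output
```

Materializing the (n, n) score matrix at every one of the L layers is exactly what yields the Ln² attention cost.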
Sparse attention variants have been introduced to mitigate the computational cost of full attention. Figures 1a and 1b illustrate the attention patterns of block sparse attention and sliding-window attention. For instance, block sparse attention divides tokens into blocks of size b and performs intra-block attention only, resulting in an attention cost of bn per layer. Sliding-window attention allows each token to attend to its left/right neighboring tokens within a local window of size w. In most cases, block sparse attention exhibits similar performance to sliding-window attention (Zuo et al., 2022).
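Both sparsity patterns can be described by boolean masks over the (n, n) score matrix. A minimal NumPy sketch (the helper names are ours, for illustration only):

```python
import numpy as np

def block_sparse_mask(n, b):
    """Tokens may attend only within their contiguous block of size b
    (roughly b*n allowed scores per layer)."""
    block_id = np.arange(n) // b
    return block_id[:, None] == block_id[None, :]

def sliding_window_mask(n, w):
    """Each token attends to neighbors within distance w on each side
    (roughly 2*w*n allowed scores per layer)."""
    idx = np.arange(n)
    return np.abs(idx[:, None] - idx[None, :]) <= w

# Number of allowed attention scores for a short sequence:
print(block_sparse_mask(8, 4).sum())    # 32 = b * n
print(sliding_window_mask(8, 1).sum())  # 22: (2w+1)*n minus boundary effects
```

In practice, the masked-out scores are simply never computed, which is where the savings come from.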

Our Approach
We present our method, MASFormer, a long-range transformer variant that mixes different attention spans across layers.

MASFormer: Mixed Attention Span
MASFormer leverages full attention exclusively at a subset of transformer layers, whereas it employs block sparse attention at the remaining layers. The structure of MASFormer is illustrated in Figure 1c. We choose full attention to encode long-range information for the following reasons: (i) full attention exhibits superior capability to capture long-range dependencies compared to sparse attention; (ii) full attention does not require sophisticated implementation and hence is computationally stable compared to SSMs (Zuo et al., 2022; Gupta et al., 2022); (iii) full attention is compatible with existing pre-trained transformer models, enabling us to conduct continual training, which we elaborate on in Section 3.2. To mitigate the computational cost, we restrict the number of layers using full attention.
MASFormer is motivated by empirical investigations comparing the performance of models that apply the same attention span at every layer. Figure 2 presents the performance of block sparse attention and full attention on language modeling and summarization tasks. We find that, given long-sequence inputs, sparse attention is often insufficient to capture long-range dependencies beyond its attention span and, as a result, shows unsatisfactory performance. To remedy this, one can either increase the attention span or switch to full attention to improve the model's capability of capturing sophisticated dependencies. Though improving model performance, both options incur high computational cost.
Confronting such a trade-off between computational cost and model performance, we challenge the common practice of applying the same attention span at every layer. MASFormer provides an alternative solution. Instead of increasing the attention span evenly, MASFormer allocates a large portion of attention computations to a subset of l layers by equipping them with full attention. Specifically, equipping the bottom layers with full attention yields the best performance, as suggested by our empirical analysis in Section 4.3. At the remaining layers, MASFormer utilizes block attention with a small block size m, resulting in a controlled attention cost of (L − l)mn + ln². As mentioned in Section 1, such a design is inspired by the phenomenon that most contexts in NLP data exhibit a great deal of locality of reference. Long-range dependencies, in contrast, are less frequent. Therefore, it is not necessary to enhance the attention span at every layer. Instead, a few layers of full attention are sufficient to capture infrequent long-range signals.
The majority of layers can maintain small attention spans to adequately extract local dependencies and control the attention cost.
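The resulting attention budget is easy to compute. For instance, plugging in the configuration used in our experiments (L = 24, n = 8192, m = 1024) gives a quick back-of-the-envelope check:

```python
def masformer_cost(L, n, l, m):
    """Attention cost of MASFormer: l layers of full attention (n^2 each)
    plus L - l layers of block attention with block size m (m*n each)."""
    return (L - l) * m * n + l * n * n

L, n, m = 24, 8192, 1024
full_cost = L * n * n                      # every layer uses full attention
mas_cost = masformer_cost(L, n, l=4, m=m)
print(round(mas_cost / full_cost, 2))      # 0.27, i.e., roughly a 73% reduction
```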
Our empirical results demonstrate that, with the same attention cost, MASFormer significantly outperforms sparse attention. Remarkably, MASFormer can achieve comparable performance to full attention while substantially reducing the computational cost. Therefore, by mixing different attention spans, MASFormer strikes a better balance between computational cost and model performance.
Moreover, MASFormer offers additional implementation advantages. As it uses the same attention function, MASFormer is easy to implement and compatible with existing pre-trained models. We can build MASFormer upon pre-trained transformers by changing their attention patterns, which does not involve modifying model architectures or pre-trained weights. Meanwhile, acceleration packages such as FlashAttention (Dao et al., 2022) and xFormers (Lefaudeux et al., 2022) are applicable to further accelerate the computation of block attention and full attention in MASFormer.

Continual Training with Long Sequences
As mentioned, MASFormer can be implemented upon the majority of pre-trained transformers by modifying their attention patterns. However, most publicly available models are pre-trained with sequences shorter than 2048 tokens and often exhibit subpar performance on longer sequences such as 8k/16k. To bridge this gap, we propose continual training to adapt the revised model to long sequences and the new attention pattern. As such, we can preserve existing pre-trained knowledge and circumvent the intensive overhead of pre-training from scratch. In particular, we first modify the attention pattern of the target model as proposed by MASFormer. If the pre-trained model uses absolute position embeddings, we duplicate them to accommodate long sequences. Subsequently, we provide the revised model with long sequences (e.g., 8k tokens) from a pretraining corpus like PILE (Gao et al., 2020). Then we conduct continual pre-training using the causal language modeling (CLM) objective. We discuss the effectiveness of continual training in Section 4.3.
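The position-embedding duplication step can be sketched as follows. This is a simplified illustration with a hypothetical helper name; an actual implementation would operate on the model's embedding weight tensors:

```python
import numpy as np

def extend_position_embeddings(pos_emb, new_len):
    """Tile a pretrained absolute position embedding table of shape
    (max_len, d) until it covers new_len positions, then truncate."""
    max_len, _ = pos_emb.shape
    reps = -(-new_len // max_len)  # ceiling division
    return np.tile(pos_emb, (reps, 1))[:new_len]

pretrained = np.random.randn(1024, 64)   # e.g., a 1024-position table
extended = extend_position_embeddings(pretrained, 8192)
print(extended.shape)                    # (8192, 64)
```

Continual training then lets the model adapt to the repeated position signal along with the new attention pattern.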

Experiments
We evaluate the effectiveness and efficiency of MASFormer on natural language modeling (ArXiv and PubMed, Cohan et al. (2018)) and natural language generation (ArXiv, Cohan et al. (2018); QMSUM and GovReport, Shaham et al. (2022)). We choose the GPT-3 XL model architecture (Brown et al., 2020) as our base model, which consists of 1.3 billion parameters and 24 layers and is pre-trained on PILE (Gao et al., 2020) for 300 billion tokens. GPT is a general-purpose model that can be applied to many tasks rather than being tailored for specific tasks. As such, it makes it easy to control experiments and showcase the differences among various methods.

Implementation Details. Our base model uses absolute positional embeddings with a maximum length of 1024. To accommodate longer inputs, we duplicate its positional embeddings to reach a maximum length of 8192, such that the model can handle sequences containing up to 8192 tokens. Then, we implement different attention methods by modifying the attention pattern of the base model. We implement all models with PyTorch (Paszke et al., 2019). All experiments are conducted on NVIDIA A100 GPUs.

Continual Training Details. After changing the attention pattern, we conduct continual training for MASFormer and the baseline methods on the PILE corpus (Gao et al., 2020) to adapt the revised models to the new attention patterns and long-sequence inputs. We leverage the causal language modeling (CLM) objective to train each model for 50,000 steps with a warmup of 2,000 steps. We set the input length to 8192 and use a batch size of 128 such that the models are optimized with 1M tokens per step. We use a constant learning rate of 0.0001 for all methods.

Baselines. We compare MASFormer with the following methods:
• All full attention applies full attention at every layer. It has been adopted by most existing transformer models as the default. Although it incurs the maximum attention cost, it achieves the best performance for most of our tasks and hence acts as an upper bound for other methods.
• All block sparse attention applies block attention at every layer, which is an effective method to reduce computational cost when modeling long sequences. It sets the attention span of each layer to be identical such that the budget of attention computation is distributed evenly across layers.
• All sliding-window attention applies sliding-window attention at every layer, which is another variant of sparse attention. It shares a similar spirit with, and often performs similarly to, block attention.
In the following experiments, we compare MASFormer and the baseline methods across different attention costs C. That is, for all block sparse attention, we set the block size as b = C/(Ln). For all sliding-window attention, we choose the window size as w = C/(2Ln). For MASFormer, we apply a small block size m = 1024 for its block attention and set l as (C − Lmn)/(n² − mn). Then we observe how their performance evolves when increasing the attention cost C or the input length n.

Experiment Overview. We briefly summarize the experimental contents as follows:
• Section 4.1 presents the perplexity evaluation of all models on ArXiv and PubMed after continual training.
• Section 4.2 compares the summarization performance of the models on ArXiv, QMSUM, and GovReport after fine-tuning. Besides, we also discuss the difference between perplexity and downstream evaluation in reflecting model capacity to capture long-range dependencies.
• Section 4.3 provides three crucial analyses: (i) we evaluate the benefits of increasing the input length and discuss the requirements for attaining these gains; (ii) we analyze the effectiveness of continual training for long-sequence modeling; (iii) we conduct an ablation study to demonstrate that equipping the bottom layers with full attention yields more significant performance gains than other options, and we provide explanations.
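The budget-matching rules above can be written out directly. Below is a small sketch using integer division as a simplification; the budget C is chosen here so that l comes out as an integer, roughly matching the 671M cost mentioned in Section 4.2:

```python
def budget_match(C, L, n, m=1024):
    """Map a total attention budget C to each method's hyperparameter:
    block size b for all-block attention, window size w for all-window
    attention, and the number of full-attention layers l for MASFormer
    (whose remaining L - l layers use block attention of size m)."""
    b = C // (L * n)                        # all-block cost: b * L * n
    w = C // (2 * L * n)                    # all-window cost: 2 * w * L * n
    l = (C - L * m * n) // (n * n - m * n)  # MASFormer cost: (L-l)*m*n + l*n^2
    return b, w, l

print(budget_match(C=671_088_640, L=24, n=8192))  # (3413, 1706, 8)
```

This makes the comparison fair: every method spends (approximately) the same number of attention-score computations C, distributed differently across layers.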

Datasets and Evaluation Details
Datasets. We evaluate the perplexity of the updated GPT-2 for each attention method after continual training. The evaluation is conducted on the test sets of ArXiv and PubMed (Cohan et al., 2018). Table 5 presents the statistics of these two datasets. PubMed consists of scientific documents, with each document's content used as input and its corresponding abstract as the target summary. ArXiv is similar to PubMed, with documents from arXiv.

Evaluation Details. We conduct the perplexity evaluation under two settings. (i) We calculate the perplexity (ppl.) with all documents from the test sets.

Results
Table 1 compares the overall perplexity on the test sets of ArXiv and PubMed. The results suggest that, with l = 4 layers of full attention, MASFormer achieves comparable performance to all full attention while reducing 72% of its attention cost. With a similar attention cost C, MASFormer outperforms all block attention, which evenly distributes the budget of attention computation. For example, MASFormer with l = 2 achieves 8.75 ppl. on PubMed, which is 1.37 lower than that of block attention with b = 2048.

Figure 3 illustrates the perplexity variation of each method given examples of different length. We can see that MASFormer and full attention show better performance on longer documents, suggesting that increasing context length can improve their prediction performance. Full attention, though incurring the highest attention cost, always achieves the best performance due to its outstanding capability to handle sophisticated dependencies. Notably, with 27% of its attention cost, MASFormer exhibits a perplexity-vs.-length curve that closely resembles that of full attention. This demonstrates the effectiveness and efficiency of MASFormer in capturing long-range dependencies. In contrast, block sparse attention benefits much less from long contexts and underperforms both of them because of its inability to encode long-range signals. For example, when b = 1024, block attention achieves similar perplexity on PubMed examples of different length.

Datasets and Training Details
Datasets. We evaluate the downstream performance of the models on several abstractive summarization tasks to compare their capability of handling long sequences in practice. Specifically, we fine-tune models on ArXiv (Cohan et al., 2018), and on QMSUM and GovReport (from the SCROLLS benchmark, Shaham et al. (2022)). Their statistics are summarized in Table 5. We mainly use the ROUGE-2 (R2) score (Lin, 2004) as the evaluation metric, which is more important and sensitive than R1 and RL.

Training Details. After continual training, we fine-tune each model and report R2 scores on the validation sets. Specifically, we fine-tune models for 3,000 steps on QMSUM, 8,000 steps on GovReport, and 12,000 steps on ArXiv. We set the batch size as 64 for ArXiv and 32 for QMSUM and GovReport. We pick the learning rates from {1 × 10⁻⁵, 5 × 10⁻⁵, 1 × 10⁻⁴, 5 × 10⁻⁴} and choose the optimal one to report the performance of each method. Moreover, the input length is fixed at 8192. We apply greedy decoding for generation. Please see Appendix B for more details.

Results
In Table 2 and Figure 4, we present the fine-tuning results on QMSUM, ArXiv, and GovReport across different attention costs. The results demonstrate that, with a similar attention cost, MASFormer significantly outperforms the sparse attention variants. Furthermore, when increasing the attention cost, MASFormer achieves greater performance gains than the sparse attention methods. This is evident from the steeper slope of its R2 curve versus attention cost, in contrast to the baseline methods. For example, when increasing C from 553M to 671M, the R2 score of MASFormer on QMSUM improves substantially, from 7.46 to 8.70. Remarkably, this score surpasses even that of full attention. Therefore, MASFormer addresses the trade-off between computational cost and performance gains in a more efficient and effective way.
Notice that, in order to achieve comparable summarization performance to full attention, MASFormer needs at least l = 8 layers of full attention, and providing more can lead to further gains. This observation differs from the findings in NLM (Figure 3), where increasing l beyond 4 provides limited improvement in perplexity. The different capacity requirements arise from the fact that predicting next tokens in NLM primarily relies on local dependencies, and capturing infrequent long-range tokens does not significantly improve perplexity. Thus, this discrepancy emphasizes the necessity of evaluating long-range models on downstream tasks.

Benefits of Increasing Sequence Length
In this section, we investigate the benefits of increasing the input length for downstream performance. Specifically, we select the input length from {2048, 4096, 6144, 8192} and present the fine-tuning performance of full attention in Figure 5. The results consistently demonstrate that as the input length increases, the model's performance improves. That is, downstream performance benefits significantly from long-sequence inputs. In contrast, increasing example length beyond 6k results in marginal improvements in perplexity (see Figure 3), highlighting again the importance of downstream evaluation.
In addition, when comparing the behaviors of block attention in Figures 2c and 2d, we find that sparse attention often fails to fully capitalize on the benefits offered by longer inputs. For instance, with a block size of 4096, its performance on ArXiv remains nearly unchanged when increasing the input length from 4096 (R2 = 15.52 in Figure 5a) to 8192 (R2 = 14.49 in Figure 2c). This finding suggests that increasing the input length can only improve model performance if the model possesses sufficient capability to handle long-range dependencies.

Where to Use Full Attention
To answer where to apply full attention, we compare the fine-tuning performance of MASFormers that apply full attention at (i) the bottom layers; (ii) the middle layers; (iii) the top layers; (iv) every L/l layers. The results in Table 4 demonstrate that equipping the bottom layers with full attention yields the best performance. This is because long-range dependencies can be continually captured and reinforced by the bottom layers before being propagated to the upper layers. As such, these long-range signals can be effectively incorporated into the upper layers with local attention, facilitating their encoding of local information. In contrast, when the bottom layers are equipped with local attention, long-range tokens are first aggregated with neighboring tokens by local attention, thereby weakening their long-range signals. Moreover, if full and local attention are alternated every L/l layers, the long-range signals can be neither continually reinforced nor efficiently captured.

Discussion
GPT-Neo (Black et al., 2021) introduces an attention pattern that alternates full and window attention. However, this model is not tailored for long sequences. It sets the local window size as 256 and has a maximum input length of 2048, and is thus unable to handle long sequences. Instead, this attention pattern is applied heuristically in an attempt to reduce computational cost. However, as discussed in Section 4.3.3, this approach is neither as effective nor as efficient as MASFormer when handling long sequences. As shown in Table 4, applying full attention at every 2 or 3 layers underperforms applying it at the bottom 12 or 8 layers. Therefore, alternating between full and block attention results in additional computational cost and performance degradation.
In contrast, MASFormer presents an effective solution for efficient long-sequence modeling. It provides guidance on adapting existing pre-trained transformers to long inputs. Meanwhile, it provides insights for designing large long-range models, especially deeper ones. By equipping only a subset of bottom layers with full attention, we can substantially mitigate the computational cost. Additionally, the computation of MASFormer can be further optimized by leveraging system-level acceleration techniques (e.g., FlashAttention and xFormers) that support both block and full attention.

Conclusion
We propose MASFormer, an efficient long-range transformer that utilizes full attention at a few bottom layers and employs sparse attention at the remaining layers. Our empirical results on natural language modeling and generation tasks demonstrate that MASFormer can address the trade-off between computational cost and performance gains in a more efficient and effective way.

Figure 1: Illustration of attention patterns of (a) block sparse attention with block size b = 3; (b) sliding-window attention with window size w = 1 (on each side); (c) MASFormer, which integrates full and sparse attention.
Figure 2: (a, b): We evaluate the perplexity of a pre-trained GPT-2 model with block attention of different block sizes after continual training. (c, d): We fine-tune a GPT-2 model with block attention and compare the summarization performance on ArXiv and GovReport under different block sizes. Here the input length n is 8192.

Figure 3: Perplexity evaluation on ArXiv and PubMed with examples of different length. Here the x-axis is the maximum document length of each subset, i.e., k × 1024 (k = 1, 2, 3, . . .).

Table 1 presents the overall perplexity of the different models on the two datasets. (ii) To showcase the varying behaviors of models on documents of different length, we divide all documents into several subsets according to their length. Each subset consists of examples whose length is within ((k − 1) × 1024, k × 1024] (k = 1, 2, 3, . . .). Then, we evaluate the perplexity on each subset. Figure 3 presents the perplexity of the models on the different subsets of examples.

Table 1: Perplexity evaluation on ArXiv and PubMed.

Table 2: Summarization performance of models with different attention methods. The best results are shown in bold.

Table 3: We report R1/R2/RL for the above results.