Cluster-Former: Clustering-based Sparse Transformer for Question Answering

Transformer has become ubiquitous in the deep learning field. One of the key ingredients behind its success is the self-attention mechanism, which allows fully-connected contextual encoding over input tokens. However, despite its effectiveness in modeling short sequences, self-attention suffers when handling inputs with extreme long-range dependencies, as its complexity grows quadratically w.r.t. the sequence length. Therefore, long sequences are often encoded by Transformer in chunks using a sliding window. In this paper, we propose Cluster-Former, a novel clustering-based sparse Transformer that performs attention across chunked sequences. The proposed framework is pivoted on two unique types of Transformer layer: the Sliding-Window Layer and the Cluster-Former Layer, which encode local sequence information and global context jointly and iteratively. This new design allows information integration beyond local windows, which is especially beneficial for question answering (QA) tasks that rely on long-range dependencies. Experiments show that Cluster-Former achieves state-of-the-art performance on several major QA benchmarks.


Introduction
Long-range contextual understanding has proven critical in many natural language processing (NLP) tasks. For example, the relevant context for correctly answering an open-domain question can arch over thousands of words (Chen et al., 2017). Encoding long sequences via deep neural networks, however, remains an expensive and challenging task, due to the high demand on training time and GPU memory. Traditional sequence modeling methods (Hochreiter and Schmidhuber, 1997) encode long sequences in chronological order, which suffers from high latency. In place of sequential encoding, recent models such as Transformer (Vaswani et al., 2017) instead use simultaneous self-attention over the entire input, which has been successfully adopted in many NLP tasks such as textual entailment (Devlin et al., 2019), dependency parsing (Zhou and Zhao, 2019), and summarization (Lewis et al., 2019). A caveat with Transformer, though, is that building full connections over long sequences translates to quadratic growth of memory demand and computational complexity w.r.t. sequence length.
One way to efficiently encode long sequences is to first chunk a sequence into much shorter ones with a sliding window, then build connections between the shorter sequences (Figure 1(a)). For example, Child et al. (2019), Beltagy et al. (2020) and Zaheer et al. (2020) apply sparse attention to chunked sequences in hand-designed patterns in order to gather information across the chunks (Figure 1(b)). Choi et al. (2017) and Wang et al. (2019) first use a simpler model to filter chunked sequences, then process the selected sequences with fully-connected self-attention. Rae et al. (2019) make use of shared memory across chunked sequences to build connections between them. However, these methods cannot encode long-range dependencies with as much flexibility or accuracy as fully-connected self-attention, due to their reliance on hand-designed patterns.
Recently, several studies (Kitaev et al., 2020; Tay et al., 2020a) propose to further improve the sparse attention mechanism by hashing or sorting the hidden states into different buckets (Figure 1(c)). These works mainly explore tasks with relatively short sequences, such as sentence-level machine translation, where the number of hashing vectors is relatively small (fewer than 16 in Kitaev et al. (2020)), allowing randomly initialized hashing vectors to hash hidden states into correct buckets. However, how to use hashing-based attention in the context of long sequences (e.g., up to thousands of words) remains unexplored.
Our proposed framework for efficient long sequence encoding, Cluster-Former, marries the sliding-window and hashing-based methods to achieve effective local and long-range dependency encoding. Cluster-Former consists of two types of encoding layer. The first (denoted as Sliding-Window Layer) focuses on extracting local information within a sliding window. It applies Transformer to the hidden states of each chunked sequence independently, as shown in Figure 1(a). The other (denoted as Cluster-Former Layer) learns to encode global information beyond the initial chunked sequences. Specifically, we first apply clustering to the input hidden states so that similar hidden states are assigned to the same cluster, as shown in Figure 1(d). The clustered and sorted input is then divided uniformly into chunks, each encoded by a Transformer layer. Note that to make model training more efficient, the cluster centroids are not computed online but updated periodically (every epoch or every few epochs). We accumulate the hidden states from the layer prior to the Cluster-Former layer in a memory bank, and apply the K-Means algorithm to form cluster centroids during each update cycle. Compared to previously discussed sparse attention based on pre-selected positions (Figure 1(b)) or randomly-initialized hashing vectors (Figure 1(c)), experimental results show that our method can encode dependencies across chunked sequences more effectively.
Among these works, our method is closest to Set Transformer (Lee et al., 2019), Routing Transformer (Roy et al., 2020), and Fast Transformer (Vyas et al., 2020), which all use cluster centroids to learn attention patterns. However, we target a different task, question answering, which leads to a significantly different framework that encodes a short question together with a long context, rather than a single long sequence as in language modeling. Moreover, our cluster centroids are updated in a very different way, through periodic K-Means updates over a memory bank, rather than memory-based centroids (Lee et al., 2019), exponentially moving centroids (Roy et al., 2020), or online clustering (Vyas et al., 2020).
Long Sequence in Question Answering For tasks such as open-domain question answering (Chen et al., 2017), a large volume of documents or paragraphs is usually retrieved to infer the answer, yielding extremely long context. Despite the fact that state-of-the-art NLP models are capable of extracting answers amid complex context, they still struggle with extremely long input sequences. Recent advances that advocate the use of large-scale pre-trained models (Lewis et al., 2019; Liu et al., 2019; Lan et al., 2020) for question answering make this problem more prominent, due to tremendous memory consumption. To process long sequences, the most widely-used method is to first use a lightweight model to filter out redundant text, then use sliding-window-based approaches to encode the remaining sequences with a more sophisticated model. Chen et al. (2017) integrate bi-gram features into Information Retrieval (IR) methods to retrieve related documents more accurately. Wang et al. (2018) train a paragraph selector using as the reward whether the entire system obtains the correct answer or not. Asai et al. (2020) train a recurrent retriever to select paragraphs for multi-hop question answering. Izacard and Grave (2021) propose to fuse locally encoded information into a decoder for answer generation. Besides the above methods, directly applying efficient Transformers to process long sequences in question answering is another option. In this paper, we focus on this direction by directly training our Cluster-Former on the long context, without using a lightweight model for context filtering.

Proposed Approach
The proposed framework for handling long sequences is pivoted on two types of Transformer layer: (i) Sliding-Window Layer; and (ii) Cluster-Former Layer. The former focuses on encoding local sequence information, while the latter encodes global context and is always built on top of a Sliding-Window layer. An overview of the two layers is illustrated in Figure 2.

Sliding-Window Layer
Despite our focus on capturing long-range dependencies for global context, local information also plays a critical role in knowledge propagation. Therefore, in the lower section of our network, we adopt the traditional sliding-window encoding mechanism. A sliding window segments a long sequence X into short, overlapping ones with window size l and stride m, as illustrated in Figure 2(a). Note that in this paper, we focus on question answering tasks, for which we concatenate the question Q with each sequence chunked from the document:

H^0_k = [Q; X[(k − 1) · m : (k − 1) · m + l]],

where Q ∈ R^{q×d} denotes the question embeddings, q is the number of tokens in the question, X ∈ R^{x×d} is the embeddings for the whole context, and x is the number of tokens in the context. k is the ID of the chunked sequence, l is the window size, and m is the stride of the sliding window.
[idx_1 : idx_2] indicates selecting rows between index idx_1 and idx_2 of the matrix, and [·; ·] means concatenating the matrices along the row. We first use Transformer to encode each sequence in the sliding window as follows:

Ĥ^{n+1}_k = Transformer(H^n_k),

where Ĥ^{n+1}_k ∈ R^{(q+l)×d} is the output of Transformer on the k-th sequence in the n-th layer; it is not yet the final output of the n-th layer. As we expect neighbouring sequences to share useful information in their hidden states, we always set m < l to allow overlap between sequences. We use the mean of the Transformer hidden states at the overlapped tokens between windows as the final outputs. To merge the representations from the (k − 1)-th sequence:

Ĥ^{n+1}_k[q : q + l − m] += Ĥ^{n+1}_{k−1}[q + m : q + l],   Ĥ^{n+1}_k[q : q + l − m] /= 2,

and to merge the representations from the (k + 1)-th sequence:

Ĥ^{n+1}_k[q + m : q + l] += Ĥ^{n+1}_{k+1}[q : q + l − m],   Ĥ^{n+1}_k[q + m : q + l] /= 2,

where += adds matrices in place and /= divides a matrix by a scalar in place. The merged hidden states H^{n+1}_k ∈ R^{(q+l)×d} are the final outputs of the n-th layer. If the next layer is a Cluster-Former layer, the output hidden states H^{n+1}_k of this layer are saved into a memory bank for computing the cluster centroids.
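The chunk-and-merge scheme above can be sketched in a few lines of numpy. This is an illustrative toy, not the authors' code: each per-chunk output carries q question states followed by l context states, consecutive chunks overlap by l − m context tokens, and overlapped states are averaged between neighbours.

```python
import numpy as np

def sliding_window_merge(hidden, q, l, m):
    """Average the hidden states of overlapped context tokens.

    `hidden` is a list of per-chunk outputs, each of shape (q + l, d):
    q question states followed by l context states. Assumes stride
    m >= l / 2, so each context token is shared by at most two chunks.
    Names and shapes are illustrative only.
    """
    overlap = l - m
    merged = [h.copy() for h in hidden]
    for k in range(1, len(merged)):
        # The last `overlap` context tokens of chunk k-1 coincide with
        # the first `overlap` context tokens of chunk k: average them.
        left = merged[k - 1][q + m: q + l]
        right = merged[k][q: q + overlap]
        avg = (left + right) / 2.0
        merged[k - 1][q + m: q + l] = avg
        merged[k][q: q + overlap] = avg
    return merged
```

In the paper's setting (l = 256, m = 224) the overlap is 32 tokens, so the m ≥ l/2 assumption in this sketch holds.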

Cluster-Former Layer
We introduce a Cluster-Former layer to add global representational power to Transformer beyond sliding windows.An in-depth visualization of the layer is illustrated in Figure 2(b).
The input of the Cluster-Former layer comes from the hidden states of the prior layer (in our case a Sliding-Window layer). After merging the overlaps between sequence chunks, the input of this layer is defined as:

H^n = [H^n_1[0 : q + m]; H^n_2[0 : q + m]; . . . ; H^n_{⌈x/m⌉}[0 : q + m]],

where H^n ∈ R^{(q⌈x/m⌉+x)×d} is the hidden states to cluster, and x is the number of tokens in the context. As hidden states with larger cosine similarity are more likely to have higher attention weights, we build sparse self-attention only over the hidden states within the same cluster. In this work, we use K-Means as the clustering method for simplicity; more advanced clustering algorithms have the potential to yield better performance. Since running K-Means on the fly in each training iteration is computationally expensive, we recompute the cluster centroids at low frequency (every epoch or every few epochs).
In addition, to avoid dramatic changes in the cluster centroids due to limited hidden state inputs, we maintain a memory bank of the most recent hidden states. The entire procedure is depicted in Algorithm 1. Once the cluster centroids are computed, we can directly use them for hidden state clustering as follows:

v^n = argmax(Ĥ^n (Ĉ^n)^⊤),

where Ĥ^n and Ĉ^n are the row-wise L2-normalized hidden states and centroids (so the product gives cosine similarities), C^n ∈ R^{p×d} are the cluster centroids for layer n, and p is the pre-defined number of clusters.
The function argmax(·) operates on the last dimension and assigns each input hidden state to a cluster based on the maximum cosine similarity between the hidden state and the cluster centroids. v^n ∈ R^{q⌈x/m⌉+x} contains the assigned cluster IDs of all the input hidden states.
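The assignment step can be written as a minimal numpy sketch, assuming precomputed centroids (function and variable names are illustrative, not from the authors' implementation):

```python
import numpy as np

def assign_clusters(H, C):
    """Assign each hidden state (a row of H) to the centroid in C with
    the highest cosine similarity, as in the equation above.

    H: (num_states, d) hidden states; C: (p, d) cluster centroids.
    Returns a vector of cluster IDs of length num_states.
    """
    # Row-wise L2 normalization turns dot products into cosine similarity.
    Hn = H / np.linalg.norm(H, axis=1, keepdims=True)
    Cn = C / np.linalg.norm(C, axis=1, keepdims=True)
    return np.argmax(Hn @ Cn.T, axis=1)
```

In the real model this runs over GPU tensors, and the centroids come from the periodic K-Means update over the memory bank rather than being recomputed per batch.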
Since the number of hidden states in different clusters can vary substantially, padding them to the maximum length for Transformer training would significantly increase the computational time. To make the extraction of global context more efficient, we greedily order the cluster centroids by nearest neighbour (measured by cosine similarity), as shown in the function GETCENTROIDS in Algorithm 1. Thus, hidden states with similar cluster IDs are also close to each other. We can then sort the cluster IDs of the hidden states and uniformly chunk the sorted hidden states (with the same chunk size and stride m):

a^n_k = argsort(v^n)[(k − 1) · m : k · m],

where the function argsort(·) obtains the indexes of the input values sorted in ascending order (ties broken by the position of the corresponding hidden states).
a^n_k ∈ R^m is the chunked indexes of the hidden states, and E^n_k = H^n[a^n_k] ∈ R^{m×d} is the k-th chunk of clustered hidden states. We run Transformer on top of it to build connections beyond the words in the initial sliding window:

Ê^{n+1}_k = Transformer(E^n_k).

After updating the hidden states, we map them back to the order before clustering:

H^{n+1}[a^n_k] = Ê^{n+1}_k,

where H^{n+1} is the final output hidden states of this layer and has the same word order as the input H^n.
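The sort-chunk-restore logic above can be sketched as follows. This is a hedged toy: `layer_fn` stands in for a full Transformer layer, and the greedy centroid ordering from Algorithm 1 is assumed to have already been applied to the cluster IDs.

```python
import numpy as np

def cluster_former_pass(H, cluster_ids, m, layer_fn):
    """Reorder hidden states by cluster ID, process fixed-size chunks
    independently, and scatter results back to the original order.

    H: (num_states, d); cluster_ids: (num_states,) from the assignment
    step; m: chunk size. Assumes len(H) is a multiple of m for brevity.
    """
    # Stable sort keeps same-cluster states in their original order.
    order = np.argsort(cluster_ids, kind="stable")
    out = np.empty_like(H)
    for k in range(0, len(order), m):
        idx = order[k: k + m]        # a_k: indexes of the k-th chunk
        out[idx] = layer_fn(H[idx])  # E_k -> Transformer(E_k), restored in place
    return out
```

Because the scatter writes through `idx`, the output keeps the input's word order, matching the H^{n+1}[a^n_k] = Ê^{n+1}_k mapping above.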
In experiments, we interleave these two types of layers to capture both global and local context efficiently.

Datasets
We evaluate our proposed approach on multiple question answering benchmarks. The statistics of all the datasets are summarized in Table 1.
• Quasar-T (Dhingra et al., 2017): The goal of this task is to answer open-domain questions from the Trivia Challenge. All the passages harvested through information retrieval can be used to answer questions. The task requires the model to generate answers as phrases. Evaluation on this dataset is based on Exact Match and the F1 score of bag-of-words matching. Our evaluation tool comes from the SQuAD dataset.
• SearchQA (Dunn et al., 2017): The setting of this dataset is the same as Quasar-T, except that the questions are sourced from Jeopardy! instead.

• Natural Questions (Kwiatkowski et al., 2019): False positives are incorrect answer predictions, and false negatives are incorrect "no answer" predictions. As the test set is hidden, we split 5% of the training set for validation, and use the original validation set for testing. We use the official tool from the dataset to evaluate our models. We also submit our best model to the leaderboard.

Implementation Details
All the models are trained on 8 Nvidia V100 GPUs. For clustering, we adopt "Yinyang K-Means" (Ding et al., 2015) (implementation: https://github.com/src-d/kmcuda), which takes less than 5 seconds for clustering in all our experimental settings. We set the memory size for clustering to M = 100,000 in Algorithm 1. Based on our experiments, results differ little between memory banks of 50k and 100k states, or between update cycles of one iteration and half an iteration. We use the cluster centroids that perform best on the validation set for test set experiments. As the cluster centroids are computed offline, inference time is the same as for the sliding-window-based method. We initialize our models with RoBERTa-large (Liu et al., 2019). As the number of position embeddings of RoBERTa is limited to 512, we cannot assign distinct position embeddings to all tokens; instead, we assign the same position embeddings to each chunked sequence.
The majority of our model is made up of Sliding-Window Layers, as local information is essential for QA tasks. We adopt the proposed Cluster-Former Layer in layers 15 and 20 to further capture long-range information. We set the sliding window size l to 256 and the stride m to 224, and vary the number of clusters over {64, 256, 512} to analyze its impact on the final performance. We prepend a special token to the beginning of all the given/retrieved paragraphs and directly concatenate all the paragraphs as the final context sequence. Due to memory constraints, we set the max length to 5000 during training and 10000 during inference. During fine-tuning on each dataset, we use Adam (Kingma and Ba, 2015) to optimize the model. We set warm-up updates to 2,220, maximal updates to 22,200, the learning rate to 5 × 10^−5, and the batch size to 160. We tune the dropout rate over {0.1, 0.15, 0.2} for all methods, including baselines, and report the best results. The model converges within one day on all the QA datasets.
For Quasar-T and SearchQA, we predict the start and end positions of the answer. For Natural Questions, we first identify whether the question has a short/long answer or not, based on the mean of the first hidden state over all the chunked sequences, (1/K) Σ_k H^N_k[0], where K is the number of chunks and N is the number of layers. If answerable, we rank all the candidates for long answer selection, and predict the start and end positions of the short answer. Our model submitted to the Natural Questions leaderboard is an ensemble of 3 models with 512 clusters; only these models are first trained on SQuAD 2.0 and then fine-tuned on the Natural Questions dataset.
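The answerability feature described above can be sketched as a small helper (hypothetical names; the classifier head that consumes this vector is not shown):

```python
import numpy as np

def answerability_feature(chunk_outputs):
    """Mean over chunks of each chunk's first hidden state (the special
    token position) from the top layer.

    chunk_outputs: list of K arrays, each (seq_len, d).
    Returns a (d,) vector to be fed into an answerability classifier.
    """
    firsts = np.stack([h[0] for h in chunk_outputs])  # (K, d)
    return firsts.mean(axis=0)                        # (d,)
```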

Baselines
We compare our models with several strong baselines: R3 (Wang et al., 2018) uses reinforcement learning to jointly train a passage ranker and reader. DS-QA (Lin et al., 2018) first uses paragraph selection to filter out noisy data and then trains the model on the denoised data. Multi-passage BERT (Wang et al., 2019) filters the passages and then merges multiple useful passages into one sequence, which can be encoded by BERT. DrQA (Chen et al., 2017) makes use of an attention mechanism across the question and the document for answer phrase extraction. DecAtt and DocReader (Kwiatkowski et al., 2019) is based on a pipeline approach that first uses a simpler model to select long answers and then a reading comprehension model to extract short answers from the long answers. BERT joint (Alberti et al., 2019) jointly trains short and long answer extraction in a single model rather than using a pipeline approach. BERT wwm +SQuAD2 (Pan et al., 2019) makes use of multi-task learning to further boost performance. RikiNet-RoBERTa (Liu et al., 2020) proposes a dynamic paragraph dual-attention reader and a multi-level cascaded answer predictor. BigBird-ETC (Zaheer et al., 2020) makes use of a sparse attention mechanism to encode long sequences.
We also re-implement several strong baselines that have not previously been applied to long-context question answering: • Sliding Window: This baseline is fully made up of Sliding-Window Layers and can only attend to local information. To make a fair comparison among different methods of long-range information collection, we replace several layers of this sliding-window baseline with Sparse Attention, Locality-Sensitive Hashing, or Cluster-Former.
• Sparse Attention (Child et al., 2019): This method replaces several layers in the previous baseline with a Transformer layer trained across sequences on pre-selected positions. We run this sparse Transformer on all the hidden states at the same position across sequences, so that its output can merge the information from different sequences.

Experimental Results
State-of-the-Art Results on QA Tables 2 and 3 show that our proposed method outperforms several strong baselines, thanks to its ability to encode both local and global information. Cluster-Former with 512 clusters achieves new state-of-the-art results on Quasar-T, SearchQA, and Natural Questions (long answer).

Effect of Cluster-Former
We also test the ability of Cluster-Former to model long-range dependencies. Note that Sparse Attention (Child et al., 2019) and Locality-Sensitive Hashing (Kitaev et al., 2020) have never been tested on question answering tasks with long context. For a fair comparison, we set layers 15 and 20 to either Sparse Attention, Locality-Sensitive Hashing, or our Cluster-Former, while the remaining layers are Sliding-Window layers.
As shown, Sparse Attention performs worse than our Cluster-Former. The loss may come from the noise introduced by pre-selected positions, whose corresponding words may be unrelated. We set the number of hashing vectors in Locality-Sensitive Hashing (LSH) to 64, the same as the number of clusters in Cluster-Former. LSH outperforms the baseline slightly on QA but consistently underperforms our Cluster-Former (#C=64). Overall, our Cluster-Former performs the best.

Effect of Number of Cluster Centroids
We also test the effect of different numbers of cluster centroids (C) on model performance. We observe that the model with 512 clusters works significantly better than the model with 64 clusters on most of the tasks. However, for the Natural Questions long answer setting, the improvement is marginal. This is likely because we mainly rely on the hidden states of the special token "<s>" for long answer selection, and these identical tokens are easily assigned to the same chunk even with a smaller number of clusters.

Selection of Cluster-Former Layers
We also analyze which layers are best suited for the Cluster-Former layer. As shown in Table 4, we conduct a hyper-parameter search and find that better performance is obtained with at least one Cluster-Former layer in the middle of the network (layers 8-16). The worst results come from using only one Cluster-Former layer, at layer 22 or 23.
Language Modeling Although we focus on QA tasks, to demonstrate the versatility of Cluster-Former, we conduct additional experiments on language modeling using the Wikitext-103 (Merity et al., 2017) and Enwik8 (Mahoney, 2011) benchmarks. All the models are trained from scratch. We set the number of layers to 16, with 8 heads per layer. Our Cluster-Former Layer is used in layers 11 and 15, as in the QA models. We segment long input into short sequences of 3072 tokens, and set the sliding window size l to 256 and the stride m to 128. SGD is used to optimize the models. We set the gradient clipping threshold to 0.1, warm-up updates to 16,000, maximal updates to 286,000, the dropout rate to 0.3, the learning rate to 0.1, and the batch size to 16. The model converges within 3 days on all the LM datasets. As shown in Table 5, Cluster-Former outperforms strong state-of-the-art baselines.

Qualitative Analysis
We perform qualitative analysis of how the hidden states are clustered by visualizing the corresponding words and positions of the hidden states in Table 6. From the first row, we can see that the special tokens "<s>" tend to belong to the same cluster. Note that "<s>" is the start token of each long answer candidate, and its hidden state is used for final long answer selection. Therefore, the Transformer over this cluster can compare across the candidates to make the final prediction.
We further observe that tokens of the same type are more likely to appear in the same cluster. For example, the words from the second row to the fourth row cover the topics of time, stopwords, and organization & geopolitical entities, respectively.
Finally, we randomly sample a cluster and list the positions of its clustered hidden states in the last row of the table. We find that states far apart, such as the 50-th and 6060-th states (over 6000 tokens apart), can fall in one cluster, which demonstrates the ability of Cluster-Former to detect long-range dependencies. Further, we observe that states tend to cluster in phrases. For example, we see consecutive positions such as "49, 50, 51, 52, 53, 54, 55", which likely results from the sliding-window encoding.

Conclusion
In this paper, we present Cluster-Former, a new method for encoding global information in long sequences. We achieve new state of the art on three question answering datasets: Quasar-T, SearchQA, and Natural Questions. Further, we observe that a larger number of clusters in Cluster-Former can lead to better performance on question answering tasks. Cluster-Former is a generic approach, and we believe it can benefit other NLP tasks that rely on long-range dependencies as well.

Figure 1 :
Figure 1: Illustration of different methods for processing long sequences. Each square represents a hidden state. The black-dotted boxes are Transformer layers. (a) is the sliding-window-based method that chunks a long sequence into short ones with window size 3 and stride 2. (b) builds cross-sequence attention over sliding windows at pre-selected positions (red-dotted boxes). (c) hashes the hidden states into different buckets via randomly-initialized vectors. (d) is our proposed approach that clusters the hidden states. Our final model is a combination of (a) and (d) that processes both local and global context.
Our contributions can be summarized as follows. (i) We propose Cluster-Former, a novel approach to capturing long-range dependencies more effectively than locality-sensitive hashing. (ii) We propose a new Transformer-based framework to process long sequences by combining Sliding-Window and Cluster-Former layers to extract both local and global contextual information. (iii) Our model achieves the best performance on the question answering datasets of Natural Questions (long answer), SearchQA, and Quasar-T.

Figure 2 :
Figure 2: An overview of the proposed Transformer layers. (a) Sliding-Window layer over a sequence. (b) Cluster-Former layer over clustered hidden states from the output of (a). Cluster centroids are periodically updated based on the memory bank of hidden states in the corresponding layer.

Table 2 :
Results on Quasar-T, SearchQA test sets and NQ dev set.#C: number of clusters.

Table 3 :
Results on the Natural Questions (NQ) leaderboard (test set). We show two published results here from over 40 submissions. Our model achieves No. 1 for long answer and No. 4 for short answer.

Table 6 :
An example from the Natural Questions dataset. The rows in the middle section show the corresponding words of the clustered hidden states, and the bottom row shows the positions of the clustered hidden states. "<s>" refers to the start token of a long answer candidate.