Multi-Vector Attention Models for Deep Re-ranking

Large-scale document retrieval systems often utilize two styles of neural network models which live at two different ends of the joint computation vs. accuracy spectrum. The first style is dual encoder (or two-tower) models, where the query and document representations are computed completely independently and combined with a simple dot product operation. The second style is cross-attention models, where the query and document features are concatenated in the input layer and all computation is based on the joint query-document representation. Dual encoder models are typically used for retrieval and deep re-ranking, while cross-attention models are typically used for shallow re-ranking. In this paper, we present a lightweight architecture that explores this joint cost vs. accuracy trade-off based on multi-vector attention (MVA). We thoroughly evaluate our method on the MS-MARCO passage retrieval dataset and show how to efficiently trade off retrieval accuracy with joint computation and offline document storage cost. We show that a highly compressed document representation and inexpensive joint computation can be achieved through a combination of learned pooling tokens and aggressive downprojection. Our code and model checkpoints are open-source and available on GitHub.


Introduction
Classical information retrieval systems used weighted sparse keyword matching to retrieve relevant documents for incoming search queries. A common approach has been to re-rank these documents with a neural network model which takes the concatenation of the query text and document text as input and emits a relevance score. We refer to this as a cross-attention network, depicted in Figure 1a. In modern NLP, these systems are typically pre-trained using a technique such as BERT (Devlin et al., 2019) and then fine-tuned on human-labeled relevance data (Han et al., 2020; Ding et al., 2020; Nogueira et al., 2019).

* Work done while at Google.
However, even if the re-ranking model is a state-of-the-art neural system, it is still limited by the documents that were produced by the retrieval system. Because of this, end-to-end neural approaches to retrieval have become popular to improve the relevance of the documents produced in the retrieval stage. These models typically take the form of a dual encoder network (also called a "two-tower network," "Siamese network," or "DSSM" (Huang et al., 2013), as depicted in Figure 1b), which emits a "query vector" and "document vector" conditioned on the query text and document text respectively. The relevance score is typically defined as the dot product (or cosine distance) between these vectors.
With a dual encoder model, each document vector in the corpus can be precomputed offline, and the query vector only needs to be computed once per incoming query. Since scoring each query-document pair is a simple dot product, a commodity CPU can score thousands of documents in a few milliseconds. Moreover, retrieving the highest scoring documents for a query can be done in sub-linear time using approximate nearest neighbor algorithms such as hierarchical k-means clustering (Johnson et al., 2017; Guo et al., 2020).
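Scoring against a precomputed dual-encoder index can be sketched as follows. This is a minimal numpy illustration, not the paper's implementation; the corpus size, hidden size, and function names are placeholders:

```python
import numpy as np

rng = np.random.default_rng(0)
h = 768  # hidden size (BERT-Base)

# Document vectors are precomputed offline and held in RAM.
doc_vectors = rng.standard_normal((10_000, h)).astype(np.float32)

def score_corpus(query_vector: np.ndarray) -> np.ndarray:
    """Relevance score for every document: one matrix-vector product."""
    return doc_vectors @ query_vector

# At query time, the query vector is computed once, then scoring is cheap.
query_vector = rng.standard_normal(h).astype(np.float32)
scores = score_corpus(query_vector)
top_k = np.argsort(-scores)[:1000]  # top-1000 candidates for re-ranking
```

In production, the exhaustive `argsort` over scores would be replaced by an approximate nearest neighbor index so retrieval stays sub-linear in corpus size.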
In practice, modern IR systems are not restricted to this simple retrieve-then-rerank framework, and often use multi-stage re-ranking. For example, the system might first retrieve the top 1000 documents using a combination of a dual encoder and keyword retrieval, then re-rank with a very cheap cross-attention model, then finally re-score the top 50 with a more expensive cross-attention model.
There are two key costs to consider for these types of deep re-ranking models. The first is the number of bytes necessary to store the precomputed document representation, since this must be stored in fast-access memory (typically RAM) for every document in the corpus. The second is the cost of the "joint computation," which is the part of the model that combines the query and document representations in order to generate a relevance score.
In this work, we explore the Multi-Vector Attention (MVA) architecture as an extension of dual encoder networks. The MVA network produces a query matrix (rather than a vector) which attends to a document matrix to produce a query-dependent document representation. The scalar relevance score is then computed in a manner similar to a standard dual encoder model. A mathematical description is given in Section 2.1.

Related Work. Several retrieval approaches have been proposed that apply lightweight query-document scoring on last-layer Transformer features. These consist of multi-vector dual encoders (Luan et al., 2020; Khattab and Zaharia, 2020; Li et al., 2020) that emit multiple query and document vectors which interact via dot products, and multi-layer attention architectures (Gao et al., 2020; Chen et al., 2020). The work presented here can be thought of as an extension to ColBERT (Khattab and Zaharia, 2020), where we explore various aspects of the output layer in order to compress the document representation even further. We found the max() operation to be unstable when used in conjunction with more aggressive pooling and downsampling, so we instead use a differentiable attention operation.


Multi-Vector Attention Network

Model Architecture
Our architecture employs the standard form of "Transformer-style" dot product attention. Given input matrices Q, K, V ∈ R^(n×h):

    Attention(Q, K, V) = softmax(QK^T / √h) V

Single-headed attention is applied to query vectors X ∈ R^(q×h) and document vectors Y ∈ R^(d×h) using the learned projection matrices W_QK, W_QV, W_DK, W_DV. These parameters are used to generate the intermediate query key, query value, document key, and document value matrices:

    Q_K = X W_QK,   Q_V = X W_QV,   D_K = Y W_DK,   D_V = Y W_DV

These matrices are passed into the operation Attention(Q_K, D_K, D_V) to perform query-token-dependent attention over document tokens for each individual query token X_i. The final relevance score between Q and D is given by:

    s(X, Y) = (1/q) Σ_i Q_V,i · Attention(Q_K, D_K, D_V)_i

which is the average dot product between query value vectors and their corresponding attention-averaged document vectors.
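The scoring path above can be sketched in numpy. This is a minimal single-headed illustration under the stated definitions; the shapes and weight names (W_QK etc.) follow the text, but everything else (dimensions, initialization) is an assumption for the sketch:

```python
import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def attention(Q, K, V):
    """Transformer-style scaled dot-product attention."""
    return softmax(Q @ K.T / np.sqrt(Q.shape[-1])) @ V

def mva_score(X, Y, W_QK, W_QV, W_DK, W_DV):
    """Average dot product between query value vectors and their
    attention-averaged document vectors."""
    Q_K, Q_V = X @ W_QK, X @ W_QV      # (q, h') query keys/values
    D_K, D_V = Y @ W_DK, Y @ W_DV      # (d, h') document keys/values
    D_att = attention(Q_K, D_K, D_V)   # (q, h'): per-query-token doc summary
    return float(np.mean(np.sum(Q_V * D_att, axis=-1)))

# Toy usage with illustrative dimensions.
rng = np.random.default_rng(0)
q, d, h, hp = 4, 16, 32, 8
X, Y = rng.standard_normal((q, h)), rng.standard_normal((d, h))
W_QK, W_QV, W_DK, W_DV = (rng.standard_normal((h, hp)) for _ in range(4))
score = mva_score(X, Y, W_QK, W_QV, W_DK, W_DV)
```

Note that D_att depends on the query, so unlike a dual encoder the final reduction cannot be precomputed offline; only the document keys and values can.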

Pooling Architectures
First-K Tokens. In this pooling method, we truncate sequences to the first K tokens.

Multiple [CLS] Embeddings. We prepend query and document sequences with [CLS] embeddings Q_CLS ∈ R^h and D_CLS ∈ R^h that are retained following the encoder; this can be viewed as a generalization of BERT's single [CLS] embedding. In prior work, ColBERT (Khattab and Zaharia, 2020) applies this to the query and uses the untruncated sequence during scoring, referring to this as "query augmentation."

Temporal Pooling. We also explore a projection-based pooling approach that reduces sequences by a specified pooling factor ρ. We reshape the input sequence X ∈ R^(n×h) into X̃ ∈ R^((n/ρ)×ρh), which concatenates every ρ consecutive elements into a single composite vector. Applying the Attention architecture's projection layers to these composite vectors completes the pooling operation.
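The temporal pooling reshape is a one-liner in numpy; row-major ordering makes each output row the concatenation of ρ consecutive token vectors. A minimal sketch (the padding assumption is ours):

```python
import numpy as np

def temporal_pool(X: np.ndarray, rho: int) -> np.ndarray:
    """Concatenate every rho consecutive token vectors into one composite
    vector: (n, h) -> (n // rho, rho * h). Assumes the sequence has been
    padded so that n is a multiple of rho."""
    n, h = X.shape
    assert n % rho == 0, "sequence length must be a multiple of rho"
    return X.reshape(n // rho, rho * h)

# Toy usage: 6 tokens of dimension 2, pooled by a factor of 3.
X = np.arange(12.0).reshape(6, 2)
X_pooled = temporal_pool(X, rho=3)   # shape (2, 6)
```

The composite vectors are then fed through the attention projection layers, so the projections themselves learn how to mix the ρ concatenated tokens.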

Losses
We pre-train using standard BERT objectives, which include the Masked Language Modeling (MLM) task as well as the Next-Sentence Prediction (NSP) task. Our models are fine-tuned using a softmax loss over the <query, positive document, negative document> training triples provided by MS-MARCO.
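Over a single <query, positive, negative> triple, the softmax loss reduces to two-way cross-entropy on the pair of relevance scores. A numerically stable sketch (our own formulation of the standard loss, not the paper's training code):

```python
import numpy as np

def triplet_softmax_loss(pos_score: float, neg_score: float) -> float:
    """-log softmax of the positive document's score against the negative's.
    Uses the log-sum-exp trick for numerical stability."""
    m = max(pos_score, neg_score)
    log_z = m + np.log(np.exp(pos_score - m) + np.exp(neg_score - m))
    return float(log_z - pos_score)

# When both scores are equal the model is maximally uncertain: loss = ln 2.
loss_uncertain = triplet_softmax_loss(0.0, 0.0)
# A well-separated positive drives the loss toward zero.
loss_confident = triplet_softmax_loss(5.0, 0.0)
```

In batched training, the same idea extends to one positive against many negatives via a softmax over the whole score vector.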

Experimental Setup
Datasets. We pre-train all of our models on the Colossal Clean Crawled Corpus (C4). For retrieval evaluation, we use the MS-MARCO (Bajaj et al., 2016) passage re-ranking dataset. We truncate query sequences to length 32 and documents to length 112. We observe that the average lengths of queries and documents in MS-MARCO are around 6 and 70 respectively.
Training setup. We use a 12-layer Transformer model with 12 attention heads and hidden size 768, equivalent in architecture to BERT-Base. We pre-train dual encoder and cross-attention models on C4 for 100,000 iterations on a v3-128 Cloud TPU, with batch size 8,192 and Adam with learning rate 3e-4. Our Dual Encoder is pre-trained directly on the MLM and NSP tasks rather than initialized as two identical BERT models (the conventional practice), and is therefore a stronger baseline than is typically found in the literature. We initialize MVA models from the Dual Encoder checkpoint and pre-train on the same MLM and NSP tasks for 20,000 iterations before fine-tuning on downstream tasks. We fine-tune on MS-MARCO using a batch size of 256 triples and Adam with learning rate 3e-5. We reserve a 10% split of the MS-MARCO training set as a validation set.

Ablation Experiments
Our results center around two important considerations that impact model deployment: (1) the amount of joint computation needed to score every query-document pair and (2) the cost to store the document representations offline. We first conduct ablation experiments to evaluate how independent architecture variations affect downstream accuracy.

Multiple [CLS] pooling for short document representations, and no pooling for query sequences. In Table 1, we compare three approaches for reducing the length of document sequences: truncation, multiple [CLS] embeddings, and temporal pooling. We find that [CLS] embeddings are clearly superior for producing short document representations, but that all approaches perform similarly for longer representations.
In Table 3, we examine the effect of the number of [CLS] tokens on the query side. In the first row, "All Tokens", the query representation corresponds to one [CLS] token plus each actual (non-padding) token in the query. Our results suggest that short queries do not benefit much from query pooling.

Projections lower costs with little quality drop.
Down-projection is implemented by reducing the size of the attention head, which leads to a linear reduction in both joint computation and offline document storage. In Table 2, we show that hidden-size projections offer a better accuracy trade-off than a comparable reduction in document length.
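The linear storage reduction can be made concrete with a back-of-the-envelope calculation. The byte width and example configurations below are illustrative assumptions, not the paper's exact settings:

```python
def doc_storage_bytes(num_timesteps: int, projection_size: int,
                      bytes_per_value: int = 4) -> int:
    """Offline storage for one document representation:
    D timesteps x P projection dims, float32 by default."""
    return num_timesteps * projection_size * bytes_per_value

# A full 112-token document at hidden size 768 (no pooling, no projection)
# vs. an aggressively pooled and projected representation:
full = doc_storage_bytes(112, 768)   # 344,064 bytes per document
small = doc_storage_bytes(8, 32)     # 1,024 bytes per document
```

Because both D and P enter the product linearly, halving either halves the RAM footprint of the corpus; the ablations ask which axis costs less accuracy.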
Multi-head Attention is unnecessary. We test whether the single attention head used by MVA can be replaced by multiple smaller attention heads with the same combined dimensionality. We do so by examining a range of document sequence lengths where the length is varied using [CLS] embeddings, but where (length × projection_size) is held constant. We find that multiple projection heads worsen the results across all sequence lengths, and especially so for shorter document sequences.
Joint pre-training improves downstream performance. We find it helpful to first pre-train the MVA parameters for a small number of additional steps on the MLM and NSP tasks before fine-tuning on the downstream dataset. Pre-training a MVA model from scratch, however, is not necessary since the improvement in downstream accuracy is marginal.

Results
Optimal operating points. We present several optimal operating points that maximize MRR with respect to different amounts of allowed joint computation and storage cost. The configuration that achieves the highest MRR at reasonable cost uses the full query-document sequence and a projection size of 128. We also present two cheaper architecture variants that use aggressive projections and [CLS] pooling to lower computation, yielding a large improvement over dual encoders at small additional cost. Our approach attains comparable results to ColBERT (Khattab and Zaharia, 2020) without having to extensively pad the query with [CLS] tokens.

Table 6: Multi-pass re-ranking using a cross-attention model. D = timesteps in document representation, P = projection size.
Improved Re-ranking with MVA. We simulate a three-stage re-ranking pipeline where the MS-MARCO top-1000 candidates (originally retrieved by BM25) are first re-ranked using MVA (as well as the Dual Encoder baseline) and the top-K candidates are then further re-ranked with the cross-attention model. We show in Table 6 that a comparatively cheap MVA model substantially outperforms a Dual Encoder. Moreover, the amortized cost of cross-attention re-ranking enables a better operating point on the joint computation-accuracy curve.
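The three-stage pipeline can be sketched as follows. The scoring callables are placeholders standing in for MVA and cross-attention scorers over BM25-retrieved candidates; the structure, not the scores, is the point:

```python
def rerank(candidates, cheap_score, expensive_score, k=50):
    """Stage 2: cheap (e.g. MVA) re-ranking of all candidates.
    Stage 3: expensive (cross-attention) re-scoring of the top-k only."""
    stage2 = sorted(candidates, key=cheap_score, reverse=True)
    head = sorted(stage2[:k], key=expensive_score, reverse=True)
    return head + stage2[k:]

# Toy usage: 1000 BM25 candidates with stand-in scoring functions.
candidates = list(range(1000))
cheap = lambda doc_id: doc_id        # stand-in for MVA scores
expensive = lambda doc_id: -doc_id   # stand-in for cross-attention scores
reranked = rerank(candidates, cheap, expensive, k=50)
```

The expensive model touches only k of the 1000 candidates, which is why improving the cheap stage shifts the whole cost-accuracy curve: a better stage-2 ranking puts more truly relevant documents inside the top-k window.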

Conclusion
We presented a Multi-Vector Attention (MVA) architecture for deep re-ranking that extends previous work on multi-vector dual encoders and attention architectures.