Once is Enough: A Light-Weight Cross-Attention for Fast Sentence Pair Modeling

Transformer-based models have achieved great success on sentence pair modeling tasks, such as answer selection and natural language inference (NLI). These models generally perform cross-attention over input pairs, leading to prohibitive computational costs. Recent studies propose dual-encoder and late interaction architectures for faster computation. However, the balance between the expressiveness of cross-attention and computational speedup still needs to be better coordinated. To this end, this paper introduces a novel paradigm, MixEncoder, for efficient sentence pair modeling. MixEncoder involves a light-weight cross-attention mechanism. It conducts query encoding only once while modeling the query-candidate interactions in parallel. Extensive experiments conducted on four tasks demonstrate that our MixEncoder can speed up sentence pair modeling by over 113x while achieving performance comparable to the more expensive cross-attention models.


Introduction
Transformer-based models (Vaswani et al., 2017; Devlin et al., 2019) have shown promising performance on sentence pair modeling tasks, such as natural language inference, question answering, and information retrieval (Nogueira and Cho, 2019). Most pair modeling tasks can be depicted as a procedure of scoring candidates given a query. A fundamental component of these models is the pre-trained cross-encoder, which models the interaction between the query and the candidates. As shown in Figure 1(a), the cross-encoder takes a query-candidate pair as input, and calculates the interaction between them at each layer via the input-wide self-attention mechanism. This interaction will be calculated N times if there are N candidates. Despite the effective text representation power, this leads to an exhaustive computation cost, especially when the number of candidates is very large. This computation cost therefore restricts the use of cross-encoder models in many real-world applications.
Extensive studies, including dual-encoder (Huang et al., 2013; Reimers and Gurevych, 2019) and late interaction models (MacAvaney et al., 2020; Gao et al., 2020; Khattab and Zaharia, 2020), have been proposed to accelerate Transformer inference on sentence pair modeling tasks. As shown in Figure 1(b), the query and candidates are processed separately in dual-encoders, so the candidates can be pre-computed and cached for online inference, resulting in fast inference speed. However, this speedup is built upon sacrificing the expressiveness of cross-attention (Luan et al., 2021; Hu et al., 2021). Alternatively, late-interaction models adjust dual-encoders by appending an interaction component, such as a stack of Transformer layers (Cao et al., 2020; Nie et al., 2020), for modeling the interaction between the query and the cached candidates, as illustrated in Figure 1(c). Although these interaction components better preserve the effectiveness of cross-attention than dual-encoders, they still suffer from the heavy costs of the interaction component. Clearly, the computation cost of late-interaction models still increases dramatically as the number of candidates grows (Zhang et al., 2021).
To tackle the above issues, we propose a new paradigm named MixEncoder to speed up the inference while maintaining the expressiveness of cross-attention. In particular, MixEncoder involves a light-weight cross-attention mechanism which mostly disentangles the query encoding from the query-candidate interaction. Specifically, MixEncoder encodes the query along with the pre-computed candidates at runtime, and conducts the light-weight cross-attention at dedicated layers (named interaction layers), as illustrated in Figure 1.

Figure 1: Architecture illustration of three popular sentence pair approaches and the proposed MixEncoder, where N denotes the number of candidates and s denotes the relevance score of query-candidate pairs. The cache is used to store the pre-computed embeddings.
This design of light-weight cross-attention allows the interaction layer to process all the candidates in parallel. Thus, MixEncoder is able to encode the query only once, regardless of the number of candidates.
MixEncoder accelerates online inference in two ways. Firstly, MixEncoder processes each candidate into k dense context embeddings offline and caches them, where k is a hyper-parameter. This speeds up online inference by using pre-computed representations. Secondly, our interaction layer performs attention only from the candidates to the query. This disentangles the query encoding from the query-candidate interaction, thus avoiding repeated query encoding and supporting the processing of multiple candidates in parallel.
We evaluate the capability of MixEncoder for sentence pair modeling on four benchmark datasets, covering natural language inference, dialogue and information retrieval. The results demonstrate that MixEncoder better balances effectiveness and efficiency. For example, MixEncoder achieves a substantial speedup of more than 113x over the cross-encoder while providing competitive performance.
In summary, our main contributions are as follows:

• A novel framework, MixEncoder, is proposed for fast and accurate sentence pair modeling. MixEncoder involves a light-weight cross-attention mechanism which allows us to encode the query once and process all the candidates in parallel.
• Extensive experiments on four public datasets demonstrate that the proposed MixEncoder provides better trade-offs between effectiveness and efficiency than state-of-the-art models.

Background and Related Work
Neural ranking models. These models focus on measuring the relevance of sentence pairs. A common practice is to map each sentence to a dense vector separately, and then measure their relevance with a similarity function (Huang et al., 2013; Karpukhin et al., 2020). These models are known as dual-encoder models. Dual-encoder models can pre-compute the candidate representations offline, since candidate encoding is conducted independently of the query. Recently, pre-trained Transformer-based models (cross-encoders) have achieved great success on many sentence pair tasks (Li et al., 2022; Guo et al., 2022). These models take the concatenation of a sentence pair as input and perform cross-attention at each layer, which brings deep interactions between the input query and the candidate. Despite the promising performance, cross-encoder models face significant latency in online inference since all the candidates are encoded online.

Late-interaction models. Various late interaction models have been proposed to combine the advantages of the dual-encoder and the cross-encoder. Specifically, these models disentangle sentence pair modeling into separate encoding followed by a late interaction. They can pre-compute candidate representations offline, and model the relationship of query-candidate pairs by cross-attention online. For instance, late-interaction models such as Deformer (Cao et al., 2020) and PreTTR (MacAvaney et al., 2020) are based on a decomposed Transformer, where lower layers encode the query and candidate separately and higher layers process them jointly. As shown in Figure 1(c), given N candidates, the late Transformer layers have to encode the query N times, which results in extensive computation costs. Other models adopt a light-weight interaction mechanism, such as poly-attention (Humeau et al., 2020) and MaxSim (Khattab and Zaharia, 2020), instead of Transformer layers to speed up online inference.
Our MixEncoder can behave as a late interaction model by replacing the upper Transformer layers of a dual-encoder with our interaction layer. The novelty of the MixEncoder lies in the light-weight cross-attention mechanism and pre-computed context embeddings.

Method
In this section, we first introduce the details of the proposed MixEncoder, which mainly includes two stages, i.e., the candidate pre-computation stage and the query encoding stage. Figure 2 provides the architecture of MixEncoder. We then describe how to apply MixEncoder to different tasks, such as classification and ranking tasks.

Problem Statement
Given a sentence pair, models are required to generate either a prediction or a ranking score. The former is known as a linear-probe classification task (Conneau and Kiela, 2018) and the latter as a multi-candidate ranking task (Nguyen et al., 2016). For the classification task, the training set consists of paired samples {(q_i, p_i, y_i)}_{i=1}^N, where y_i is the label of the sentence pair, N is the size of the dataset, and q_i, p_i denote the query and the candidate, respectively. For the ranking task, the samples in the training set can be denoted as {(q_i, p_i, C_i)}_{i=1}^N, where p_i is the positive candidate for q_i and C_i is a set of negative candidates.

Candidate Pre-computation
We describe how MixEncoder pre-computes each existing candidate into several context embeddings offline. Let the token embeddings of one candidate be T_i = [t_1, · · · , t_d]. We experiment with two strategies to obtain k context embeddings from these token embeddings: (1) prepending k special tokens {S_i}_{i=1}^k to T_i before feeding T_i into a Transformer encoder (Vaswani et al., 2017; Devlin et al., 2019), and using the outputs at these special tokens as the context embeddings (S-strategy); (2) maintaining k context codes (Humeau et al., 2020) that extract global features from the last-layer output of the encoder via an attention mechanism (C-strategy). The default configuration is the S-strategy, as it provides slightly better performance.
Suppose there are N candidates; we use E_0 ∈ R^{N×k×d} to denote the pre-computed context embeddings of these candidates, where d indicates the embedding size.
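The S-strategy can be sketched as follows. This is a minimal illustration, not the paper's implementation: the toy single-pass attention "encoder" and all function names are hypothetical stand-ins for a real pre-trained Transformer.

```python
import numpy as np

def encode_candidate_s_strategy(token_emb, special_emb, encoder):
    """S-strategy sketch: prepend k special tokens to the candidate's
    token embeddings, run the encoder, and keep the outputs at the
    special-token positions as the k context embeddings."""
    k = special_emb.shape[0]
    x = np.concatenate([special_emb, token_emb], axis=0)  # (k + L, d)
    out = encoder(x)                                      # (k + L, d)
    return out[:k]                                        # (k, d)

def toy_encoder(x):
    """Single self-attention pass standing in for a full BERT encoder."""
    scores = x @ x.T / np.sqrt(x.shape[-1])
    weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
    weights /= weights.sum(axis=-1, keepdims=True)
    return weights @ x

rng = np.random.default_rng(0)
k, L, d = 2, 5, 8
ctx = encode_candidate_s_strategy(rng.normal(size=(L, d)),
                                  rng.normal(size=(k, d)), toy_encoder)
```

Because this runs offline, the resulting `ctx` matrices for all N candidates can be stacked into the cached tensor E_0.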

Query Encoding
During the online inference stage, for a query with N candidates, models have to measure the relevance of N query-candidate pairs. A typical cross-encoder repeatedly concatenates the query with each candidate and encodes the pair N times, which leads to prohibitive computation costs. One of the most effective ways to reduce the computation is to reduce the number of times the query is encoded.
In this section, we first give an overview of the query encoder. Then, we introduce the core component of our MixEncoder: the interaction layer. It performs a light-weight candidate-to-query cross-attention to estimate relevance scores in a single pass of query encoding, no matter how many candidates the query has.

Overview of Encoder
Take an encoder that consists of five Transformer layers L_1, L_2, . . . , L_5 as an example. When encoding the incoming query online, we replace the second and fifth Transformer layers L_2, L_5 with two interaction layers, denoted as I_2^1, I_5^2. Now the encoder can be depicted as {L_1, I_2^1, L_3, L_4, I_5^2}, as shown in Figure 2(b). These layers are applied to the incoming query sequentially to produce contextualized representations of the query and the candidates.
Formally, each Transformer layer L_i(·) takes the query token representations q_{i−1} ∈ R^{m×d} from the previous layer and produces a new representation matrix q_i = L_i(q_{i−1}), where m denotes the query length and q_i ∈ R^{m×d}.
Each interaction layer I_i^j(·) takes the query token representations q_{i−1} from the previous layer as input, along with the context embeddings E_{j−1} and a set of state vectors H_{j−1} ∈ R^{N×d} from the previous interaction layer (or the cache):

(q_i, E_j, H_j) = I_i^j(q_{i−1}, E_{j−1}, H_{j−1}).

The outputs E, H of the last interaction layer are fed into a classifier to generate predictions for each query-candidate pair.
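The alternation between Transformer layers (query-only) and interaction layers (query plus all candidates) can be sketched as follows; the `layer` callables are hypothetical stand-ins for trained layers, not the paper's code.

```python
def encode_query(q, E, H, layers):
    """Run the query through the mixed stack: plain Transformer layers
    update only the query representation q; interaction layers update q
    together with the context embeddings E and query states H of all N
    candidates in parallel. `layers` is a list of (kind, fn) pairs."""
    for kind, layer in layers:
        if kind == "transformer":
            q = layer(q)                 # query-only update, candidate-agnostic
        else:                            # "interaction"
            q, E, H = layer(q, E, H)     # joint update over all candidates
    return q, E, H
```

The key property is that the loop runs once per query regardless of N: candidates never re-enter the Transformer layers.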

Interaction Layer
This section describes the details of how the interaction layer generates candidate and query representations.
Candidate Representation. Given q_{i−1} and E_{j−1}, layer I_i^j performs a self-attention over q_{i−1} and a candidate-to-query cross-attention over (q_{i−1}, E_{j−1}) simultaneously, as shown in Figure 2(b). Formally, the query self-attention is conducted as

Q_{i−1}, K_{i−1}, V_{i−1} = LN(q_{i−1}), (2)
q_i = FFN(Att(Q_{i−1}, K_{i−1}, V_{i−1})), (3)

where we write LN(·) for a linear transformation, FFN(·) for a feed-forward network and Att(Q, K, V) for the attention operation (Vaswani et al., 2017). The cross-attention is formulated as

Q_{j−1}, K_{j−1}, V_{j−1} = LN(E_{j−1}), (4)
E_j = FFN(Att(Q_{j−1}, [K_{i−1}; K_{j−1}], [V_{i−1}; V_{j−1}])), (5)

By simply concatenating K_{i−1}, V_{i−1} generated from the query with K_{j−1}, V_{j−1} generated from the candidates, the cross-attention operation driven by Q_{j−1} aggregates the semantics of each query-candidate pair and produces new context embeddings E_j ∈ R^{N×k×d}.
As shown in Eq. (3) and (5), the interaction layer separates the query encoding from the cross-attention, so the candidate embeddings are transparent to the query. This design allows encoding the query only once regardless of the number of its candidates.
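The concatenation-based candidate-to-query cross-attention described above can be sketched in NumPy. All shapes and names are illustrative assumptions; the query's keys and values are computed once and shared (broadcast) across all N candidates.

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def candidate_to_query_cross_attention(Qc, Kq, Vq, Kc, Vc):
    """Each candidate's k context embeddings attend over the
    concatenation of the query's keys/values and their own.
    Shapes (illustrative): Qc, Kc, Vc: (N, k, d); Kq, Vq: (m, d)."""
    N, k, d = Qc.shape
    # Broadcast the single query K/V to all N candidates, then concat.
    K = np.concatenate([np.broadcast_to(Kq, (N,) + Kq.shape), Kc], axis=1)  # (N, m+k, d)
    V = np.concatenate([np.broadcast_to(Vq, (N,) + Vq.shape), Vc], axis=1)  # (N, m+k, d)
    attn = softmax(Qc @ K.transpose(0, 2, 1) / np.sqrt(d))                  # (N, k, m+k)
    return attn @ V                                                         # (N, k, d)
```

Because the query appears only on the key/value side, all N candidates are processed in one batched operation, which is the source of the parallelism claimed above.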
Query Representation. As shown in Eq. (5), the context embedding matrix E contains semantics from both the query and the candidates. It can be used to estimate the relevance scores of the query-candidate pairs as

s = LN(Pool(E)), (6)

where Pool(·) pools the k context embeddings of each candidate and s ∈ R^N. Since E may not be sufficient to represent the semantics of each query-candidate pair, we choose to maintain a separate embedding h to represent the query. Concretely, we conduct an attention operation at each interaction layer and obtain a unique query state for each candidate.
We first employ a pooling operation followed by a linear transformation on E_{j−1} to obtain Q* ∈ R^{N×d}. Then, the query semantics w.r.t. the candidates are extracted as

H* = Att(Q*, K_{i−1}, V_{i−1}), (7)

where K_{i−1}, V_{i−1} are generated by Eq. (2).
Next, the gate proposed by Cho et al. (2014) is utilized to fuse H* with the previous query states H_{j−1}:

H_j = g ⊙ H* + (1 − g) ⊙ H_{j−1}, (8)

where g is an update gate computed from H* and H_{j−1}, and H_j ∈ R^{N×d}. Each row of H_j represents the incoming query with respect to one candidate.
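One plausible form of this gated fusion is a GRU-style update gate, sketched below; the exact parameterization in the paper may differ, and the weight names are assumptions.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gated_fusion(h_new, h_prev, Wg, bg):
    """GRU-style gate (after Cho et al., 2014): blend the freshly
    extracted query semantics h_new with the previous query states
    h_prev, one state per candidate. h_new, h_prev: (N, d);
    Wg: (2d, d); bg: (d,)."""
    g = sigmoid(np.concatenate([h_new, h_prev], axis=-1) @ Wg + bg)  # (N, d)
    return g * h_new + (1.0 - g) * h_prev
```

When the gate saturates at 1 the layer keeps only the new semantics; at 0 it passes the previous state through, so the network can learn how much each interaction layer should overwrite the running query state.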

Classifier
Let H and E denote the query states and the candidate context embeddings generated by the last interaction layer, respectively. For the i-th candidate, the representation of the query is the i-th row of H, denoted as h_i. The representation of this candidate is the mean of the i-th row of context embeddings E, denoted as e_i.

Classification Task: For a classification task such as NLI, we concatenate the embeddings h_i and e_i with their element-wise difference |h_i − e_i| (Reimers and Gurevych, 2019) and feed them into a feed-forward network:

y_i = FFN([h_i; e_i; |h_i − e_i|]).

The network is trained to minimize a cross-entropy loss.
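The classifier's input features can be sketched as follows (a minimal sketch of the concatenation from Reimers and Gurevych, 2019; the function name is illustrative):

```python
import numpy as np

def pair_features(h_i, e_i):
    """Feature vector for the pair classifier: query state, candidate
    embedding, and their element-wise absolute difference."""
    return np.concatenate([h_i, e_i, np.abs(h_i - e_i)], axis=-1)  # (3d,)
```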
Ranking Task: For ranking tasks such as passage retrieval, we estimate the relevance score of a query-candidate pair as

s_i = h_i · e_i,

where · denotes the dot product. The network is optimized by minimizing a cross-entropy loss over the logits s_1, · · · , s_N.

Table 1 presents the time complexity of Dual-BERT, Cross-BERT, and our proposed MixEncoder. We can observe that the dual-encoder and MixEncoder support offline pre-computation to reduce the online time complexity. During online inference, the query encoding term (hq^2 + h^2q) of both Dual-BERT and MixEncoder does not increase with the number of candidates, since they conduct query encoding only once. Moreover, MixEncoder's query-candidate term N_c(c + q + h)hc can be reduced by setting c to a small value, which can further speed up the inference.
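The ranking score above can be sketched as a batched dot product over all N candidates at once; `H` holds one query state per candidate and `E_mean` the mean-pooled candidate context embeddings (names are illustrative):

```python
import numpy as np

def ranking_scores(H, E_mean):
    """Relevance scores: dot product between each per-candidate query
    state h_i and the mean-pooled candidate embedding e_i.
    H, E_mean: (N, d) -> scores: (N,)."""
    return np.einsum("nd,nd->n", H, E_mean)
```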

Datasets
To fully evaluate the proposed MixEncoder, we conduct an empirical evaluation on four paired-input datasets, covering natural language inference (NLI), information retrieval, and utterance selection for dialogue.

MNLI (Multi-Genre Natural Language Inference) (Williams et al., 2018) is a crowd-sourced classification dataset. It contains sentence pairs annotated with textual entailment information.
MS MARCO Passage Reranking (Nguyen et al., 2016) is a large collection of passages collected from Bing search logs. Given a query, the goal is to rank the provided 1000 passages. We use a subset of the training data. Following previous work, we evaluate the models on the 6980 development queries (Khattab and Zaharia, 2020; Gao et al., 2020).
DSTC7 (Yoshino et al., 2019) is a chat-log corpus from Track 1 of the DSTC7 challenge. It consists of multi-turn conversations in which one partner seeks technical support from the other.
Ubuntu V2 (Lowe et al., 2015) is a popular corpus similar to DSTC7. It was proposed earlier and contains more data than DSTC7.
These four datasets share the same form: every sample contains one text and several candidates. The statistics of these datasets are detailed in Table 2. We use accuracy to evaluate classification performance on MNLI. For the other datasets, MRR and recall are used as evaluation metrics.

Baselines
MixEncoder is compared to the following baselines:

Cross-BERT refers to the original BERT (Devlin et al., 2019). We take the output at the [CLS] token as the representation of the pair. This embedding is fed into a feed-forward network to generate logits for either classification or matching tasks.

Dual-BERT (Sentence-BERT) is proposed by Reimers and Gurevych (2019). This model uses a Siamese architecture and encodes the two texts separately.

Deformer (Cao et al., 2020) is a decomposed Transformer, which utilizes lower layers to encode the query and candidates separately and then uses upper layers to encode the pair jointly. We followed the settings reported in the original paper and split BERT-base into nine lower layers and three upper layers.

Poly-Encoder (Humeau et al., 2020) encodes the query and its candidates separately and performs a light-weight late interaction. Before the interaction layer, the query is compressed into several context vectors. We set the number of these context vectors to 64 and 360, respectively.

ColBERT (Khattab and Zaharia, 2020) is a late interaction model for information retrieval. It adopts the MaxSim operation to obtain relevance scores after encoding the sentences of a pair separately. Note that the design of ColBERT prohibits its use on classification tasks.

Table 1: Time complexity of the attention module in MixEncoder, Dual-BERT and Cross-BERT. We use q and d to denote the query and candidate length, respectively; h indicates the hidden layer dimension, N_c the number of candidates for each query, and k the number of context embeddings for each candidate.

Training Details
When training models on MNLI, we follow the conventional practice of using the labels provided in the dataset. When training models on the other three datasets, we use in-batch negatives (Karpukhin et al., 2020), which treat the positive candidates of other queries in a training batch as negative candidates. For Cross-BERT and Deformer, which require exhaustive computation, we set the batch size to 16 due to the limitation of computation resources. For the other models, we set the batch size to 64. All the models use one BERT (base, uncased) with 12 layers and fine-tune it for up to 50 epochs with a learning rate of 1e-5 and linear scheduling. All experiments are conducted on a server with 4 Nvidia Tesla A100 GPUs, each with 40 GB of graphic memory.

Table 4 shows the experimental results of Dual-BERT, Cross-BERT, existing late interaction models and three variants of MixEncoder on the four datasets. We measure the inference time of all the baseline models and present the results in Table 3.
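In-batch negative training can be sketched as follows. This is a simplified illustration after Karpukhin et al. (2020): a plain dot-product scorer stands in for the actual model, and the gold pairs sit on the diagonal of the batch similarity matrix.

```python
import numpy as np

def in_batch_negative_loss(q_emb, pos_emb):
    """Score every query against every positive candidate in the batch;
    diagonal entries are gold pairs, off-diagonal entries act as
    negatives. q_emb, pos_emb: (B, d). Returns mean cross-entropy."""
    logits = q_emb @ pos_emb.T                       # (B, B) similarity matrix
    logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = logits - np.log(np.exp(logits).sum(axis=1, keepdims=True))
    return -np.mean(np.diag(log_probs))              # gold pairs on the diagonal
```

Larger batches provide more negatives per query for free, which is consistent with the observation below that MixEncoder benefits from a large batch size.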

Performance Comparison
Variants of MixEncoder. To study the effect of the number of interaction layers and the number of context embeddings per candidate, we present three variants in the tables, denoted as MixEncoder-a, -b and -c. Specifically, MixEncoder-a and -b set k to 1. The former has the interaction layer I_12^1 and the latter has layers {I_10^1, I_11^2, I_12^3}. MixEncoder-c has the same layers as MixEncoder-b but with k = 2.

Inference Speed. We conduct speed experiments to measure the online inference speed of all the baselines. Concretely, we sample 100 queries from MS MARCO, each with roughly 1000 candidates. We measure the time for computations on the GPU and exclude the time for text preprocessing and moving data to the GPU.

Dual-BERT and Cross-BERT. The performance of Dual-BERT and Cross-BERT is reported in the first two rows of Table 4. We can observe that MixEncoder consistently outperforms Dual-BERT. The variants with more interaction layers or more context embeddings generally yield larger improvements. For example, on DSTC7, MixEncoder-a and MixEncoder-b achieve improvements of 1.0% (absolute) and 2.4% over Dual-BERT, respectively. Moreover, MixEncoder-a provides performance comparable to Cross-BERT on both Ubuntu and DSTC7. MixEncoder-b can even outperform Cross-BERT on DSTC7 (+0.9), since MixEncoder can benefit from a large batch size (Humeau et al., 2020). On MNLI, MixEncoder-a retains 92.6% of the effectiveness of Cross-BERT and MixEncoder-c retains 93.7%. However, the improvement of MixEncoder on MS MARCO is slight. We also find that the difference in inference time for processing samples with 1k candidates between Dual-BERT and MixEncoder is minimal, while Cross-BERT is two orders of magnitude slower than these models.

Late Interaction Models. From Tables 3 and 4, we make the following observations.
First, among all the late interaction models, Deformer, which adopts a stack of Transformer layers as the late interaction component, consistently shows the best performance on all the datasets. This demonstrates the effectiveness of the cross-attention in Transformer layers. In exchange, Deformer shows limited speedup (1.9x) compared to Cross-BERT. Compared to ColBERT and Poly-Encoder, our MixEncoder outperforms them on all the datasets except MS MARCO. Although ColBERT consumes more computation than MixEncoder, it shows worse performance on DSTC7 and Ubuntu. This demonstrates the effectiveness of the light-weight cross-attention, which achieves a good trade-off between efficiency and effectiveness. However, on MS MARCO, our MixEncoder and Poly-Encoder lag behind ColBERT by a large margin. We conjecture that MixEncoder falls short of handling token-level matching, which we elaborate on in Section 5.5.

Effectiveness of Interaction Layer
Representations. We conduct ablation studies to quantify the impact of the two key components (E and H) utilized in MixEncoder. The results are shown in Table 5. Each component yields a gain in performance compared to Dual-BERT, which demonstrates that our simplified cross-attention produces effective representations for both the candidate and the query. An interesting observation is that removing E can lead to a slight improvement on DSTC7. Moreover, we also implement MixEncoder based on Eq. (6), where a linear transformation is applied to E to estimate relevance scores, which leads to a drop in performance.

Varying the Interaction Layers. To verify the impact of the interaction layer, we perform ablation studies varying the number and position of the layers. First, we use two interaction layers {I_i^1, I_12^2}, and choose i from the set {1, 2, 4, 6, 8, 10, 11}. The results are shown in Figure 3(b). We find that MixEncoder on Ubuntu is insensitive to i, while MixEncoder on DSTC7 is enhanced with i = 11. Moreover, Figure 3(c) shows the results when MixEncoder has interaction layers {I_i^1, I_{i+1}^2, · · · , I_12^{13−i}}. Increasing the number of interaction layers cannot always improve the ranking quality. On Ubuntu, replacing all the Transformer layers provides performance close to that of replacing only the last layer. On DSTC7, the performance of MixEncoder peaks with the last three layers replaced by our interaction layers.

Candidate Pre-computation
We study the effect of the number of candidate context embeddings, denoted as k, and the pre-computation strategies introduced in Section 3.2. Specifically, we choose the value of k from the set {1, 2, 3, 10} with one interaction layer I_12^1. From Figure 3(c), we can observe that as k gets larger, the performance of MixEncoder first increases and then declines. Moreover, the two pre-computation strategies have different impacts on model performance: the S-strategy generally outperforms the C-strategy for the same k.

In-batch Negative Training
We vary the batch size and show the results in Figure 3(d). It can be observed that increasing the batch size contributes to good performance. Moreover, we observe that models may fail to converge with a small batch size. Due to the limitation of computation resources, we set the batch size to 64 for our training.

Error Analysis
In this section, we take a sample from MS MARCO to analyze our errors. We observe that MixEncoder falls short of detecting token overlap. Given the query "foods and supplements to lower blood sugar", MixEncoder fails to pay attention to the keyword "supplements", which appears in both the query and the positive candidate. We conjecture that this drawback is due to the pre-computation that compresses each candidate into k context embeddings, which loses the token-level features of the candidates. On the contrary, ColBERT caches all the token embeddings of the candidates and estimates relevance scores based on token-level similarity.

Conclusion
In this paper, we propose MixEncoder, which provides a good trade-off between performance and efficiency. MixEncoder involves a light-weight cross-attention mechanism which allows us to encode the query once and process all the candidates in parallel. We evaluate MixEncoder on four datasets. The results demonstrate that MixEncoder can speed up sentence pair modeling by over 113x while achieving performance comparable to the more expensive cross-attention models.

Limitations
Although MixEncoder is demonstrated to be effective, we recognize that it does not perform well on MS MARCO. This indicates that MixEncoder falls short of detecting token overlap, since it may lose token-level semantics of candidates during pre-computation. Moreover, MixEncoder is not evaluated on a large-scale evaluation setting, such as an end-to-end retrieval task, which requires a model to retrieve the top-k candidates from millions of candidates (Khattab and Zaharia, 2020).