Neural Retrieval for Question Answering with Cross-Attention Supervised Data Augmentation

Early fusion models with cross-attention have shown better-than-human performance on some question answering benchmarks, but they are a poor fit for retrieval because they prevent pre-computation of the answer representations. We present a supervised data mining method that uses an accurate early fusion model to improve the training of an efficient late fusion retrieval model. We first train an accurate classification model with cross-attention between questions and answers. The cross-attention model is then used to annotate additional passages in order to generate weighted training examples for a neural retrieval model. The resulting retrieval model, trained with the additional data, significantly outperforms retrieval models trained directly on gold annotations on Precision at N (P@N) and Mean Reciprocal Rank (MRR).


Introduction
Open domain question answering (QA) involves finding answers to questions from an open corpus (Surdeanu et al., 2008; Yang et al., 2015; Chen et al., 2017; Ahmad et al., 2019). The task has led to a growing interest in scalable end-to-end retrieval systems for question answering.
When QA is formulated as a reading comprehension task, cross-attention models like BERT (Devlin et al., 2019) have achieved better-than-human performance on benchmarks such as the Stanford Question Answering Dataset (SQuAD) (Rajpurkar et al., 2016). Cross-attention models are especially well suited for problems involving comparisons between paired textual inputs, as they provide early fusion of fine-grained information within the pair. This encourages careful comparison and integration of details across and within the two texts.
However, early fusion across questions and answers is a poor fit for retrieval, since it prevents pre-computation of the answer representations. Rather, neural retrieval models independently compute embeddings for questions and answers, typically using dual encoders for fast, scalable search (Henderson et al., 2017; Gillick et al., 2018; Yang et al., 2019b; Karpukhin et al., 2020). Using dual encoders results in late fusion within a shared embedding space.
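The practical consequence of late fusion can be sketched as follows. This is a minimal illustration, not the system used in this paper: `toy_encode` is a hypothetical bag-of-words stand-in for a learned encoder, and `VOCAB`, `build_index`, and `retrieve` are names introduced here only for the example. The point is that answer embeddings are computed once, offline, and each query then costs one encoding plus dot products.

```python
from typing import Callable, List, Sequence, Tuple

Vector = Sequence[float]

# Hypothetical toy vocabulary; a real dual encoder would use learned embeddings.
VOCAB = ["capital", "france", "germany", "paris", "berlin"]


def toy_encode(text: str) -> List[float]:
    """Stand-in encoder: counts of vocabulary words in the text."""
    tokens = text.lower().replace("?", " ").replace(".", " ").split()
    return [float(tokens.count(word)) for word in VOCAB]


def dot(u: Vector, v: Vector) -> float:
    return sum(a * b for a, b in zip(u, v))


def build_index(
    answers: List[str], encode_answer: Callable[[str], Vector]
) -> List[Tuple[str, Vector]]:
    """Answer embeddings do not depend on the question, so compute them once."""
    return [(answer, encode_answer(answer)) for answer in answers]


def retrieve(
    question: str,
    encode_question: Callable[[str], Vector],
    index: List[Tuple[str, Vector]],
    k: int = 3,
) -> List[Tuple[str, float]]:
    """Late fusion: the only question-answer interaction is the final dot product."""
    q_vec = encode_question(question)
    scored = [(answer, dot(q_vec, a_vec)) for answer, a_vec in index]
    scored.sort(key=lambda pair: pair[1], reverse=True)
    return scored[:k]
```

A cross-attention model, by contrast, would have to re-encode every (question, answer) pair at query time, which is why it is used here only as an offline supervisor.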
For machine reading, early fusion using cross-attention introduces an inductive bias to compare fine-grained text spans within questions and answers. This inductive bias is missing from the single dot-product scoring operation of dual encoder retrieval models. Thus, late fusion is expected to require more training data to learn the necessary representations for fine-grained comparisons.
To support learning improved representations for retrieval, we explore a supervised data augmentation approach that leverages a complex classification model with cross-attention between question-answer pairs. Given gold question-passage pairs, we first train a cross-attention classification model as the supervisor. Then any collection of questions can be used to mine potential question-passage pairs under the supervision of the cross-attention model. Retrieval model training benefits from the additional training pairs, which are annotated with graded predictions from the cross-attention model and augment the existing gold data. Experiments on MultiReQA-SQuAD and MultiReQA-NQ establish significant improvements on Precision at N (P@N) and Mean Reciprocal Rank (MRR).
The supervised mining approach is closely connected to the recently studied hard negative mining for neural retrieval models (Xiong et al., 2020; Lu et al., 2020). The key difference is that the proposed approach finds positive training examples, while negative mining approaches find negative examples for training. The two approaches are complementary and can be combined.

Neural Passage Retrieval for Open Domain Question Answering
Open domain question answering systems usually follow a two-step approach: first retrieve question-relevant passages, and then scan the returned text to identify the answer span using a reading comprehension model (Jurafsky and Martin, 2018; Kratzwald and Feuerriegel, 2018; Yang et al., 2019a). Prior work has focused on the answer span annotation task and has even achieved superhuman performance on some datasets. However, these evaluations implicitly assume the trivial availability of passages for each question that are likely to contain the correct answer. While the retrieval task can be approached using traditional keyword-based retrieval methods such as BM25, there is a growing interest in developing more sophisticated neural retrieval methods (Guu et al., 2020; Karpukhin et al., 2020).

Retrieval Question-Answering (ReQA)
Ahmad et al. (2019) introduced Retrieval Question-Answering (ReQA), a task that has been rapidly adopted by the community (Guo et al., 2020; Zhao and Lee, 2020; Roy et al., 2020). Given a question, the task is to retrieve the answer sentence from a corpus of candidates. Compared to open domain QA, ReQA evaluates the retrieval component directly and, by construction, avoids the need for span annotation. We explore the proposed approach on MultiReQA-NQ and MultiReQA-SQuAD (Guo et al., 2020).

Methodology
In this section we describe the proposed approach using a neural retrieval model augmented with supervised data mining. Figure 1 illustrates our approach, which uses a cross-attention classifier to supervise the data augmentation process for training a retrieval model. After training the cross-attention model, we retrieve additional potential answers to questions using an off-the-shelf retrieval system. The predicted scores from our cross-attention classifier are then used to weight and filter the retrieved candidates, with positive examples serving as additional training data for the dual encoder based retrieval model.
Figure 1: Use of a cross-attention model for the supervised mining of additional QA pairs. Our accurate cross-attention model supervises the mining process by identifying new, previously unannotated positive pairs. Mined QA pairs augment the original training data for the dual encoder based neural passage retrieval model.

BERT Classification Model
Cross-attention models like BERT are often used for re-ranking after retrieval and can significantly improve performance, as they allow for fine-grained interactions between paired inputs (Nogueira et al., 2019; Han et al., 2020). Here we formalize a binary classification task for predicting question-answer relatedness. We use the question-answer pairs from the training set as our positive examples. Negatives are sampled for each question using the following strategies with a 1:1:1 ratio: (1) a sentence from the top 10 nearest neighbors returned by term-based BM25 (Robertson and Zaragoza, 2009) over a sentence pool containing all supporting documents in a corpus; (2) a sentence from the top 10 nearest neighbors using the Universal Sentence Encoder-QA (USE-QA) (Yang et al., 2019b); (3) a sentence randomly sampled from the question's supporting documents, excluding the question's gold answer. The sampled non-answer sentences are paired with their questions as negative examples. A BERT model is fine-tuned following the default setup from Devlin et al. (2019).
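The 1:1:1 negative sampling scheme can be sketched as below. This is only an illustration under stated assumptions: the BM25 and USE-QA neighbor lists are assumed to be precomputed and passed in as plain lists, and the function name `sample_negatives` and its parameters are hypothetical, not taken from our implementation. One sentence is drawn per strategy, and the gold answer is always excluded so that every sampled sentence is a non-answer.

```python
import random
from typing import List, Optional, Tuple

# (question, sentence, label); label 0 marks a negative example.
Example = Tuple[str, str, int]


def sample_negatives(
    question: str,
    gold_answer: str,
    bm25_top10: List[str],
    useqa_top10: List[str],
    supporting_sentences: List[str],
    rng: random.Random,
) -> List[Example]:
    """Draw one negative from each of the three strategies (1:1:1 ratio)."""

    def pick(candidates: List[str]) -> Optional[str]:
        # Never sample the gold answer as a negative.
        pool = [s for s in candidates if s != gold_answer]
        return rng.choice(pool) if pool else None

    picks = [
        pick(bm25_top10),            # (1) term-based nearest neighbors
        pick(useqa_top10),           # (2) neural nearest neighbors
        pick(supporting_sentences),  # (3) random sentence from supporting documents
    ]
    return [(question, sentence, 0) for sentence in picks if sentence is not None]
```

Passing an explicit `random.Random` instance keeps the sampling reproducible, which is a design choice of this sketch rather than something stated above.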

Dual-Encoder Retrieval Model
We follow Guo et al. (2020) and employ a BERT based dual-encoder model for retrieval. The model architecture is illustrated in figure ??. The dual-encoder model critically differs from the cross-attention model in that there are no early interactions (cross-attention) between the question and answer. The resulting independent encodings are only combined in the final dot product that scores a pair. The same BERT encoder is used for questions and answers, with the output of the CLS token taken as the output encoding. For answers, the answer and context are concatenated and segmented using the segment IDs from the original BERT model. A learned input type embedding is added to each input token representation to distinguish questions and answers within the encoding model. The BERT dual-encoder model can be fine-tuned using the in-batch sampled softmax loss (Gillick et al., 2018):
$$\mathcal{L}(x, y) = -\log \frac{e^{\phi(x, y)}}{\sum_{\bar{y} \in \mathcal{Y}} e^{\phi(x, \bar{y})}}$$
where $x$ is the question, $y$ is the correct answer, $\mathcal{Y}$ is the set of all answers in the same batch, which are used as sampled negatives, and $\phi(x, y)$ is the dot product of the question and answer representations. Note that the dot product is scaled by a factor of 100 during training, which is critical when applying $\ell_2$ normalization to the embeddings.
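The in-batch sampled softmax can be sketched in a few lines. This is a minimal pure-Python illustration, assuming the loss is the negative log-softmax of the scaled dot product, averaged over the batch, with answer i being the correct answer for question i. The function names are hypothetical, and a real implementation would use a tensor library. The sketch also shows why the scale matters: after l2 normalization every dot product lies in [-1, 1], so without scaling the softmax cannot become sharp.

```python
import math
from typing import List, Sequence


def l2_normalize(vec: Sequence[float]) -> List[float]:
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / norm for x in vec] if norm > 0 else list(vec)


def in_batch_softmax_loss(
    question_embs: List[Sequence[float]],
    answer_embs: List[Sequence[float]],
    scale: float = 100.0,
) -> float:
    """Mean negative log-softmax; answer_embs[i] is the positive for question_embs[i]."""
    questions = [l2_normalize(q) for q in question_embs]
    answers = [l2_normalize(a) for a in answer_embs]
    total = 0.0
    for i, q in enumerate(questions):
        # Every answer in the batch is a candidate; the others act as negatives.
        logits = [scale * sum(x * y for x, y in zip(q, a)) for a in answers]
        max_logit = max(logits)  # subtract the max for numerical stability
        log_z = max_logit + math.log(sum(math.exp(l - max_logit) for l in logits))
        total += log_z - logits[i]
    return total / len(questions)
```

With two perfectly aligned, orthogonal pairs, the loss at `scale=1.0` is still about 0.31, because the best achievable logit gap is only 1; at `scale=100.0` it is effectively zero.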

Mining Augmented Training Pairs
We create an augmented training set for the retrieval model using our cross-attention based QA model. For each question in the training set, we employ USE-QA to mine the top 10 nearest neighbors from the entire training set, and then remove those retrieved pairs that are true positives. Next, the cross-attention based QA model is used to score the remaining retrieved pairs. The dual-encoder based neural retrieval model is then trained on the combination of the additional scored positive pairs and the original QA pairs from the training set. The original pairs are assigned a score of 1.
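The mining procedure can be summarized in the following sketch. It is illustrative only: `retrieve_top_k` stands in for the USE-QA nearest-neighbor search and `score_pair` for the cross-attention classifier, and both are hypothetical callables supplied by the caller. The 0.5 threshold is the cut-off we use for keeping silver pairs.

```python
from typing import Callable, Dict, List, Set, Tuple

# (question, answer, score); gold pairs carry a score of 1.0.
ScoredPair = Tuple[str, str, float]


def mine_silver_pairs(
    gold: Dict[str, Set[str]],
    retrieve_top_k: Callable[[str, int], List[str]],
    score_pair: Callable[[str, str], float],
    k: int = 10,
    threshold: float = 0.5,
) -> List[ScoredPair]:
    """Retrieve candidates, drop known positives, keep pairs the scorer accepts."""
    silver: List[ScoredPair] = []
    for question, gold_answers in gold.items():
        for candidate in retrieve_top_k(question, k):
            if candidate in gold_answers:
                continue  # already a true positive in the gold data
            probability = score_pair(question, candidate)
            if probability >= threshold:
                silver.append((question, candidate, probability))
    return silver


def build_training_set(
    gold: Dict[str, Set[str]], silver: List[ScoredPair]
) -> List[ScoredPair]:
    """Combine the original pairs (score 1.0) with the mined, scored pairs."""
    examples = [(q, a, 1.0) for q, answers in gold.items() for a in sorted(answers)]
    examples.extend(silver)
    return examples
```

Keeping the classifier probability alongside each mined pair lets the retrieval model later down-weight uncertain pairs instead of treating them as equal to gold annotations.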

Weighted In-batch Softmax for Dual-Encoder Retrieval Model
The neural retrieval model is trained using the batch negative sampling loss (Gillick et al., 2018) in equation 2. We modify the standard formulation to include a weight, w(x, y), for each pair.
We set w(x, y) to 1 if (x, y) is a ground truth positive pair and to $p(x, y)^2$ otherwise, where p(x, y) is the probability from the cross-attention model.
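One plausible reading of the weighted loss is sketched below. It assumes, as an illustration rather than a statement of the exact formulation, that w(x, y) multiplies each pair's negative log-softmax term and that the batch loss is the plain mean of these weighted terms. The function names are hypothetical, and `scores[i][j]` is taken to be the already scaled dot product between question i and answer j.

```python
import math
from typing import List


def pair_weight(is_gold: bool, probability: float) -> float:
    """w(x, y): 1 for gold pairs, squared classifier probability for mined pairs."""
    return 1.0 if is_gold else probability ** 2


def weighted_in_batch_softmax_loss(
    scores: List[List[float]], weights: List[float]
) -> float:
    """Weighted mean of per-pair negative log-softmax terms.

    scores[i][i] is the score of the positive pair for question i;
    weights[i] is w(x_i, y_i) for that pair.
    """
    total = 0.0
    for i, row in enumerate(scores):
        max_score = max(row)  # subtract the max for numerical stability
        log_z = max_score + math.log(sum(math.exp(s - max_score) for s in row))
        total += weights[i] * (log_z - row[i])
    return total / len(scores)
```

Squaring the probability shrinks the contribution of borderline pairs quickly: a mined pair scored at 0.5 receives a weight of 0.25, while one scored at 0.9 receives 0.81.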

Evaluation
In this section we evaluate the proposed approach using the MultiReQA evaluation splits for NQ and SQuAD. Models are assessed using Precision at N (P@N) and Mean Reciprocal Rank (MRR). Following the ReQA setup (Ahmad et al., 2019), we report P@N for N = 1, 5, 10. P@N evaluates whether the true answer sentence appears in the top-N ranked candidates. MRR is calculated as $\mathrm{MRR} = \frac{1}{N} \sum_{i=1}^{N} \frac{1}{\mathrm{rank}_i}$, where $N$ is the total number of questions and $\mathrm{rank}_i$ is the rank of the first correct answer for the $i$-th question.
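Both metrics are straightforward to compute from ranked candidate lists, as the following sketch shows. The function names and the convention that a question with no correct answer in its ranked list contributes 0 to MRR are assumptions of this example.

```python
from typing import List, Set


def precision_at_n(ranked: List[List[str]], gold: List[Set[str]], n: int) -> float:
    """Fraction of questions with a correct answer among the top-n candidates."""
    hits = sum(
        1
        for candidates, answers in zip(ranked, gold)
        if any(c in answers for c in candidates[:n])
    )
    return hits / len(ranked)


def mean_reciprocal_rank(ranked: List[List[str]], gold: List[Set[str]]) -> float:
    """Mean of 1 / rank of the first correct answer (0 if none is retrieved)."""
    total = 0.0
    for candidates, answers in zip(ranked, gold):
        for rank, candidate in enumerate(candidates, start=1):
            if candidate in answers:
                total += 1.0 / rank
                break
    return total / len(ranked)
```

For example, if the first correct answer appears at rank 2 for one question and rank 3 for another, P@1 is 0, P@2 is 0.5, and MRR is (1/2 + 1/3) / 2.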

Configurations
Our cross-attention QA models are fine-tuned from the public English BERT for 10 epochs, using a batch size of 256 and a weighted Adam optimizer with learning rate 3e-5. We experiment with both BERT Base and BERT Large. All hyper-parameters are set using a dev set split out from the training data (10%). When mining for silver data, we only keep candidate pairs with positive cross-attention QA model scores (≥ 0.5). The BERT Base model is used to initialize the dual encoder retrieval model. During training we use a batch size of 64 and a weighted Adam optimizer with learning rate 1e-4. The maximum input length is set to 96 for questions and 384 for answers. Models are trained for 200 epochs. The embeddings are $\ell_2$ normalized. Hyper-parameters are manually tuned on a held-out development set.

Performance for the Classification Task
The classification data created using the method from section 4.1 contains a total of 531k and 469k training examples for NQ and SQuAD, respectively. Test sets extracted from the SQuAD and NQ test splits contain 15k and 41k examples, respectively. Table 2 provides the performance of the cross-attention models, compared to a majority baseline that always predicts false and a BERT dual encoder retrieval model, trained without any mined examples, that uses cosine similarity for prediction. Cross-attention based models outperform the baselines by a wide margin, with BERT Large achieving the highest performance on all metrics. This is consistent with our hypothesis that early fusion models outperform late fusion based retrieval models. Both models achieve better performance on SQuAD than on NQ. The SQuAD task has higher token overlap, as described in section 3, making the task somewhat easier. We use the BERT Large model to supervise the data augmentation in the next section.

Mined Examples
We mined the SQuAD and NQ training data to construct additional QA pairs. After collecting and scoring additional pairs using the method described in section 4.3, we obtained 53% (56,148) and 12% (10,198) more examples for NQ and SQuAD, respectively. Much less data is mined for SQuAD than for NQ. We believe this is because of the way SQuAD was created, whereby workers write the questions based on the content of a particular article. The resulting questions are much more specific and biased toward particular question types, e.g., what questions (Ahmad et al., 2019). Additionally, the candidate pool for SQuAD is only half that of NQ, resulting in questions having fewer opportunities to be matched to good additional answers.

Results on the Retrieval QA
Table 3 gives P@N and MRR@100 for retrieval models on MultiReQA-SQuAD and MultiReQA-NQ. The first rows show the results from the simple baselines: BM25 (Robertson and Zaragoza, 2009), USE-QA, and USE-QA finetune, as reported by Guo et al. (2020). BM25 remains a strong baseline, especially on SQuAD, with 62.8% P@1 and 70.5% MRR. BM25's performance on NQ is much lower, as there is much less token overlap between NQ questions and answers. USE-QA matches the performance of BM25 on NQ but performs worse on SQuAD. The BERT dual encoder performs well compared to the other baselines, especially on NQ, with a +6.6 point improvement compared to the USE-QA finetune model. Its P@1 on SQuAD is better than that of USE-QA and BM25, but its MRR is 3.1 points worse than that of USE-QA finetune. On average, the BERT dual encoder is the best among these baselines.
Performance improves by a large margin using augmented training data from our cross-attention QA model, with +8.6 and +7.0 point improvements on NQ P@1 and MRR, respectively. Compared to NQ, the improvement on SQuAD is rather marginal. The augmented BERT dual encoder retrieval model achieves only slightly improved performance on SQuAD, with +1 point for both P@1 and MRR. As discussed in section 5.3, we mine much less data from SQuAD than from NQ, with only 12% more data than the original training set. As demonstrated by the strong BM25 performance and shown in Guo et al. (2020), the SQuAD QA pairs have high token overlap between questions and answers.
Effectiveness of Weighted Softmax. We further evaluated the retrieval QA tasks with a model trained on the augmented data using the unmodified softmax. All other configurations are kept the same. The MRR of the model using the unmodified softmax is 60.1 on MultiReQA-NQ and 71.9 on MultiReQA-SQuAD, both much worse than the model using the weighted softmax. This result indicates that the weighted softmax is important for the proposed approach.

Conclusion
In this paper, we propose a novel approach for making use of an early fusion classification model to improve late fusion retrieval models. The early fusion model is used for data mining to augment the training set for the late fusion model. The proposed approach mines 53% (56,148) and 12% (10,198) more examples for MultiReQA-NQ and MultiReQA-SQuAD, respectively. Compared to models trained directly with gold annotations, the resulting retrieval models improve P@1 by +8.6 and +1.0 points on NQ and SQuAD, respectively. The current pipeline assumes that annotated in-domain question-answer pairs exist to train the cross-attention model. With a strong general purpose cross-attention model, our method could be modified to train in-domain retrieval models without gold data. We leave this to future work.