A Self-supervised Joint Training Framework for Document Reranking

Pretrained language models such as BERT have been successfully applied to a wide range of natural language processing tasks and have also achieved impressive performance in document reranking tasks. Recent works indicate that further pretraining the language models on task-specific datasets before fine-tuning helps improve reranking performance. However, pre-training tasks such as masked language modeling and next sentence prediction are based on the context of documents and do not encourage the model to understand the content of queries in the document reranking task. In this paper, we propose a new self-supervised joint training framework (SJTF) with a self-supervised method called Masked Query Prediction (MQP) to establish semantic relations between given queries and positive documents. The framework randomly masks a token of the query, encodes the masked query paired with positive documents, and uses a linear layer as a decoder to predict the masked token. In addition, MQP is used to jointly optimize the models with the supervised ranking objective during the fine-tuning stage, without an extra further pre-training stage. Extensive experiments on the MS MARCO passage ranking and TREC Robust datasets show that models trained with our framework obtain significant improvements compared to the original models.


Introduction
The document ranking task is to generate a ranked list of candidate documents based on their relevance scores to a given query posed in natural language. It is a longstanding problem that has been widely studied in natural language processing (NLP) and question answering. Pre-trained language models (PLMs) such as BERT (Devlin et al., 2019), RoBERTa (Liu et al., 2019) and ELECTRA (Clark et al., 2020) have achieved impressive results on various NLP tasks and have outperformed conventional document ranking methods (Hui et al., 2018; Mitra and Craswell, 2019) owing to their powerful contextual representation capability. In recent years, several studies (Karpukhin et al., 2020; Qu et al., 2021) have used pre-trained language models as dual encoders to separately encode queries and documents for dense document retrieval. One of the most common approaches (Nogueira and Cho, 2019) uses a PLM as an interaction-based reranker for passage ranking: it fine-tunes BERT with an extra linear layer on top and uses the special [CLS] vector to produce a relevance score for each query-document pair. Inspired by the fact that contextualized embeddings produced by PLMs are essential for the success of pre-trained models, CEDR (MacAvaney et al., 2019) incorporated the classification vector into existing neural ranking models, and PARADE (Li et al., 2020) used a transformer module for passage-level representation aggregation to obtain performance improvements.
* Corresponding author: Hai Liu.
In contrast to the approaches that use contextual representations for reranking, prior works (Gururangan et al., 2020) suggest that further pre-training PLMs on unsupervised within-task training data enables them to learn domain-specific and task-specific language patterns effectively. To better understand complex sentence relations, UED (Yan et al., 2021) transformed the original next sentence prediction (NSP) task in BERT into a new sentence relation prediction (NSR) task. Gu et al. (2020) proposed a novel selective masking strategy that focuses on masking the important tokens and then trains a model to reconstruct the input, further pre-training the PLMs to learn task-specific patterns. However, these approaches typically perform the pre-training task on a task-specific corpus to understand the context of passages, while failing to consider the passages as the context of the given query to capture semantic consistency. In addition, further pre-training on task-specific domain datasets entails additional time and computational cost.
Therefore, this paper proposes a self-supervised joint training framework that encourages the model to understand the content of queries based on the context of passages, where the auxiliary self-supervised method is combined with the ranking task to fine-tune the pre-trained model. Specifically, the self-supervised joint training framework (SJTF) extends the typical reranking pipeline with the auxiliary self-supervised method MQP and a decoder, placed after the pre-trained model, for predicting the masked token. On the one hand, the self-supervised approach enables the model to establish semantic relations between queries and positive passages and thus to better identify relevant passages among a large number of candidate passages. On the other hand, the proposed training framework reduces the training time by performing the self-supervised method and the ranking task simultaneously in the fine-tuning stage, while the inference time of the ranking model is not increased. We evaluate the proposed training framework on two widely used document ranking datasets, MS MARCO and Robust04. The experimental results indicate that the models trained with the proposed SJTF framework obtain a performance improvement over the original models.
In summary, the contributions of our paper are as follows:
• A self-supervised joint training framework (SJTF) is proposed to improve representation learning without additional pre-training.
• A strategy for integrating SJTF into existing passage reranking methods is proposed, requiring neither architecture modification nor increased inference time.
• Experiments on standard datasets show that reranking models with SJTF integration achieve significant performance improvements.

Related Work
In recent years, pretrained language models such as BERT (Devlin et al., 2019), RoBERTa (Liu et al., 2019) and ELECTRA (Clark et al., 2020) have substantially outperformed traditional neural ranking models such as DRMM (Guo et al., 2016), Co-PACRR (Hui et al., 2018) and Conv-KNRM (Dai et al., 2018). Nogueira and Cho (2019) first applied the pretrained language model BERT to passage reranking using the classification vector [CLS]. CEDR (MacAvaney et al., 2019) incorporated the contextualized embeddings of BERT into existing IR models for document ranking. PARADE (Li et al., 2020) utilized Transformer (Vaswani et al., 2017) blocks to aggregate passage-level relevance representations to predict a document ranking. These approaches provide different perspectives on score prediction using the relevance representation vectors produced by pretrained language models.
To improve the representation learning of pre-trained language models in the target domain, prior work further pre-trained BERT with the masked language model (MLM) and next sentence prediction (NSP) objectives. Gururangan et al. (2020) observed that the less relevant the pre-training corpus was to the target corpus, the more the pre-trained language model would benefit from further pre-training. In this paper, the proposed joint training framework can be used in conjunction with those reranking models in the fine-tuning phase, which reduces the time spent on further pre-training.
Representation learning has been shown to be critical in natural language tasks and has a significant influence on downstream tasks (Devlin et al., 2019; Peters et al., 2018; Yan et al., 2021). Devlin et al. (2019) adopted the self-supervised MLM task to encourage the model to learn contextual representations by predicting the masked token from the context. From the perspective of optimizing the masking strategy, RoBERTa (Liu et al., 2019) replaced the static masking strategy with dynamic masking of the input examples during the training stage, while Gu et al. (2020) proposed a selective masking strategy that masks important words rather than arbitrary words in the sentence.
For understanding the content of documents, Cross-Thought proposed a method for recovering masked words from documents that contain the most important information in a nearby sequence. However, understanding the content of both the query and the document is crucial in question answering tasks (Zhang et al., 2019; Mudrakarta et al., 2018; Nogueira et al., 2019b), whereas these methods focus on representation learning for understanding the content of the document. The aim of our self-supervised task MQP is to predict the masked query word based on the semantic consistency of the query and the positive passages, which forces the model to consider the positive passages as the context of the given query. MQP differs from standard self-supervised approaches that use unsupervised data (He et al., 2021; Hendrycks et al., 2019; You et al., 2021) in that the supervised signal is used to select the relevant passage as the context of the query.
Multi-task learning (MTL) is an effective training setting that allows a model to obtain shared knowledge from several related supervised tasks in document ranking. Fun et al. (2021) enhanced common representation learning using a retrieval optimized multi-task framework (ROM) that jointly trains the retriever, reader and self-supervised tasks with a single encoder. UED (Yan et al., 2021) jointly trained the ranking and query generation tasks to exploit the task relationships for enhancing the neural re-ranker. Liu et al. (2019) and Maillard et al. (2021) leveraged supervised data from related tasks to enhance the robustness of the model and generic knowledge representation learning. Although datasets of related tasks are available, there may be differences in language patterns and data distribution between datasets. Our proposed joint training framework uses relevance labels to construct data in the target domain of the ranking task without requiring external data, which is achievable on any question answering dataset.

The Approach
This section describes the proposed self-supervised joint training framework (SJTF), which employs a self-supervised method (MQP) to jointly fine-tune reranking models with the ranking task. In general, the whole model consists of three parts: a pretrained encoder for producing an interaction-based representation of the given query and passages, a scorer for calculating a precise relevance score for each query-passage pair, and an extra decoder for predicting the masked query token based on the interaction-based representation. The overall framework of our approach is shown in Figure 1.
In the task of passage reranking, a natural language question and a list of candidate passages retrieved by traditional methods or a dense retriever are provided. The question is denoted as a k-length sequence of tokens $Q = \langle q_1, q_2, \cdots, q_k \rangle$, while each candidate passage is denoted by an m-length sequence of tokens $P = \langle p_1, p_2, \cdots, p_m \rangle$. The passage reranking task requires the model to learn informative representations and produce a precise relevance score for each query-passage pair in order to return the best permutation of candidate passages.

Random Masking
In contrast to the MLM task, which randomly masks tokens of the passage, we assume that understanding the content of the given query is necessary and that the semantic information of the positive passage can be used to infer the masked token in the query, which ultimately allows the model to learn the semantic relation between the query and the positive passage. The token to be masked in the query is selected randomly following a uniform distribution and replaced with the special token [MASK]; for simplicity, each query token is considered here to have the same importance. The masked query and the positive passage are then concatenated and preprocessed into a sequence $I_{mask} = \langle \text{[CLS]}, q_1, \cdots, \text{[MASK]}, \cdots, q_k, \text{[SEP]}, p_1, \cdots, p_m, \text{[SEP]} \rangle$, where the special token [CLS] indicates the start of a sentence, [MASK] indicates the masked token and [SEP] is a separator symbol.
The reason for masking the query is that the original query can be reconstructed by understanding the content of the masked query and the positive passage. If the goal were instead to predict masked tokens in the passage, the semantic information of the query would be redundant, as those tokens can be inferred from the context of the passage alone. Without requiring extra data augmentation, the SJTF framework utilizes only the relevance label and a simple masking strategy to construct the masked input sequence $I_{mask}$, which can be easily implemented in question-answering tasks.
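The construction of the masked input sequence can be illustrated with a minimal sketch. It assumes the query and passage are already tokenized into lists of strings; the function name `build_masked_input` and the toy tokens are hypothetical and not taken from the paper's implementation.

```python
import random

def build_masked_input(query_tokens, passage_tokens, rng=random):
    """Mask one uniformly chosen query token and pack the pair BERT-style."""
    if not query_tokens:
        raise ValueError("query must contain at least one token")
    pos = rng.randrange(len(query_tokens))   # uniform over query positions
    target = query_tokens[pos]               # token the decoder must recover
    masked_query = list(query_tokens)
    masked_query[pos] = "[MASK]"
    sequence = ["[CLS]", *masked_query, "[SEP]", *passage_tokens, "[SEP]"]
    # +1 because [CLS] occupies index 0 of the packed sequence
    return sequence, pos + 1, target

# Hypothetical toy example (a positive query-passage pair).
query = ["what", "is", "bm25"]
passage = ["bm25", "is", "a", "ranking", "function"]
seq, mask_index, target = build_masked_input(query, passage, random.Random(0))
```

Only positive pairs would be passed to such a function, since the relevance label decides which passage serves as context for the masked query.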

Masked Query Prediction
The pre-trained language model BERT (Devlin et al., 2019) uses the MLM task to learn contextual language representations of individual texts in a large corpus. However, establishing contextual semantic relations between queries and candidate passages is critical in the question-answering domain. To achieve this, we propose a self-supervised auxiliary approach called Masked Query Prediction (MQP), in which the model uses positive passages as the context of the query and predicts the masked token in conjunction with the visible query tokens. This allows the model to extract the semantic relation between the query and the positive passage, and thus to pick relevant passages out of a large number of candidate passages.
After passing the masked input sequence $I_{mask}$ constructed in Section 3.1 through the BERT-like encoder, which is shared with the passage reranking task, the representation vector $T_{masked} \in \mathbb{R}^d$ of the masked query token in the last layer is obtained as:

$$T_{masked} = \mathrm{Encoder}(q_{masked}, p_{pos}) \qquad (1)$$

Finally, the masked token representation vector $T_{masked}$ is fed into a decoder implemented by a linear layer to predict the original token. The parameters of the MQP module are optimized by the cross-entropy loss function $L_{MQP}$, which is defined as:

$$L_{MQP} = -\sum_{m=1}^{M} \log P(t_m) \qquad (2)$$

where $M$ is the number of masked tokens and $P(t_m)$ denotes the probability that the token $t_m$ is predicted over the whole vocabulary. For each query and positive passage pair, a single query token is replaced with the special token [MASK] for the self-supervised method. Note that the MQP decoder is only used for predicting the masked token during the fine-tuning stage, while in the inference phase only the encoder and scorer are used for reranking passages. The decoder for the MQP task is implemented with one linear layer, so only a small number of additional neural network parameters need to be trained, which makes the MQP task easy to extend to existing passage reranking methods without significant modifications to the model architecture.
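A minimal pure-Python sketch of the MQP decoder and its cross-entropy loss follows. It assumes the encoder has already produced the hidden vector of the masked position; the toy weights, the function names, and the choice to sum (rather than average) over masked tokens are assumptions for illustration. With one masked token per pair, as described above, summing and averaging coincide.

```python
import math

def linear_decoder(hidden, weight, bias):
    """One linear layer: logits[v] = weight[v] . hidden + bias[v]."""
    return [sum(w * h for w, h in zip(row, hidden)) + b
            for row, b in zip(weight, bias)]

def log_softmax(logits):
    """Log-probabilities over the vocabulary, computed stably."""
    m = max(logits)  # subtract the max to avoid overflow in exp
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]

def mqp_loss(masked_hiddens, target_ids, weight, bias):
    """Sum of negative log-probabilities of the original (masked) tokens."""
    total = 0.0
    for hidden, target in zip(masked_hiddens, target_ids):
        log_probs = log_softmax(linear_decoder(hidden, weight, bias))
        total -= log_probs[target]
    return total

# Toy setup (hypothetical): hidden size d = 2, vocabulary of 3 tokens.
weight = [[1.0, 0.0], [0.0, 1.0], [-1.0, -1.0]]
bias = [0.0, 0.0, 0.0]
loss = mqp_loss([[2.0, 0.0]], [0], weight, bias)
```

Because the decoder is a single `d x |V|` linear map, dropping it at inference time leaves the encoder and scorer unchanged.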

Passage Ranking
Following the settings of Nogueira and Cho (2019), a pair of query $q$ and candidate passage $p_i$ is packed as an input sequence $I = \langle \text{[CLS]}, q_1, \cdots, q_k, \text{[SEP]}, p_1, \cdots, p_m, \text{[SEP]} \rangle$, and the BERT-like pretrained language model is employed as a passage encoder $E$ that produces a relevance representation vector for each query-passage pair. During fine-tuning, the [CLS] vector $T_{cls} \in \mathbb{R}^d$ from the last layer of the encoder is regarded as the final interaction-based representation:

$$T_{cls}^{i} = E(q, p_i) \qquad (3)$$

For the passage reranking task, the representation vector $T_{cls}$ of a query-passage pair is passed to a scorer $S_{ranker}$, which generates a relevance score to quantify their relevance. The relevance score of the i-th pair of query $q$ and candidate passage $p_i$ is denoted as:

$$Score_{cls}^{i} = S_{ranker}(T_{cls}^{i}; \theta) \qquad (4)$$

where the scorer $S_{ranker}$ can be implemented by a linear layer on top of BERT or by a more elaborate scoring module such as KNRM (Dai et al., 2018), SAN (Kingma and Ba, 2015) or PARADE (Li et al., 2020), and $\theta$ denotes the set of parameters of the scorer module.
Compared with the point-wise ranking loss used in Nogueira and Cho (2019), the pair-wise margin ranking loss discriminates between positive and negative examples by relative distance, allowing the model to learn the margin between the positive and negative examples and to give an appropriate relevance score. Therefore, the reranking module is optimized by the ranking loss $L_{rank}$ as in Equation (5):

$$L_{rank} = \max\left(0,\; y \cdot (Score_{cls}^{k} - Score_{cls}^{k+1}) + \gamma\right) \qquad (5)$$

where $y = -1$ if $Score_{cls}^{k}$ is higher than $Score_{cls}^{k+1}$, and vice versa for $y = 1$. $\gamma$ is a hyperparameter that controls the margin between positive and negative examples.
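The behaviour of the pairwise margin ranking loss of Equation (5) can be checked with a small sketch. The scores below are hypothetical, and y is read here as the label saying which of the two passages should be ranked higher (y = -1 when the first passage is the positive one), which is an interpretation of the text rather than something it states explicitly.

```python
def margin_ranking_loss(score_k, score_k1, y, gamma=1.0):
    """Pairwise hinge loss: max(0, y * (score_k - score_k1) + gamma)."""
    return max(0.0, y * (score_k - score_k1) + gamma)

# Positive passage already beats the negative one by more than the margin.
satisfied = margin_ranking_loss(2.5, 0.5, y=-1)

# Positive passage scored below the negative one: the pair is penalized.
violated = margin_ranking_loss(0.2, 0.5, y=-1)
```

The loss is zero once the positive score exceeds the negative score by at least gamma, so well-separated pairs stop contributing gradient.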

Joint Training
Different from the self-supervised tasks that are used for further pre-training on the downstream dataset (Fun et al., 2021; Liu et al., 2019), our proposed self-supervised method MQP is combined with the ranking task in the fine-tuning phase, since it utilizes the relevance label information to construct masked input sequences. As shown in Figure 2, the joint training strategy simplifies the training procedure, requiring neither a further pre-training phase nor additional training resources. In the fine-tuning stage, the loss is defined as a linear combination of the passage reranking loss and the masked query prediction loss:

$$L = L_{rank} + \alpha \cdot L_{MQP} \qquad (6)$$

The hyperparameter $\alpha$ is assigned different values to trade off between passage reranking and masked query prediction. Because MQP is a secondary task used to help the model understand the query content, the loss weight $\alpha$ is usually set to a low value, so that the parameters of the reranking model are still optimized primarily by the passage reranking task.
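As a rough illustration of how the weight α scales the auxiliary objective, the sketch below assumes the linear combination takes the form L_rank + α · L_MQP; the exact form is an assumption, and the loss values are made up for the example.

```python
def joint_loss(rank_loss, mqp_loss_value, alpha):
    """Ranking loss plus the alpha-weighted auxiliary MQP loss."""
    return rank_loss + alpha * mqp_loss_value

# Hypothetical per-batch loss values.
rank_loss, mqp_value = 0.8, 2.0
for alpha in (0.05, 0.2):
    total = joint_loss(rank_loss, mqp_value, alpha)
    share = alpha * mqp_value / total  # fraction of the total due to MQP
    print(f"alpha={alpha}: total={total:.3f}, MQP share={share:.2%}")
```

Even when the raw MQP loss is larger than the ranking loss, a small α keeps its contribution to the total a minority share.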

Dataset
The proposed method is applied to existing ranking models and evaluated on two widely used datasets: MS MARCO Passage Ranking (Nguyen et al., 2016) and TREC Robust 2004 (Voorhees). The statistics of these two datasets are shown in Table 1.
The MS MARCO Passage Ranking dataset is a large-scale dataset consisting of real anonymized questions from the Bing search engine and 8.8 million candidate passages for the passage reranking task. The training set contains about 500 thousand positive query-passage pairs, and each query has one relevant passage on average. The development set and evaluation set contain 6980 queries and 6837 queries respectively, where relevance labels are provided for the development set only.
The Robust04 dataset is a newswire collection used by the TREC 2004 Robust track, which comprises 250 queries and 0.5 million documents (TREC Disks 4 and 5). Following the setting of CEDR (MacAvaney et al., 2019), we use the same five-fold cross-validation, with three folds for training, one fold for validation and one fold for testing.
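The three/one/one use of the five folds can be sketched as a rotation. The round-robin assignment of queries to folds and the particular rotation order below are purely illustrative assumptions; the CEDR setting relies on its own fixed fold definitions, which are not reproduced here.

```python
def split_queries(query_ids, num_folds=5):
    """Deal query ids round-robin into folds (an illustrative choice)."""
    return [query_ids[i::num_folds] for i in range(num_folds)]

def fold_assignments(num_folds=5):
    """For each round, pick one test fold, one validation fold, rest train."""
    rounds = []
    for r in range(num_folds):
        test = r
        valid = (r + 1) % num_folds
        train = [f for f in range(num_folds) if f not in (test, valid)]
        rounds.append({"train": train, "valid": valid, "test": test})
    return rounds

folds = split_queries(list(range(250)))  # 250 Robust04 queries
rounds = fold_assignments()
```

Each fold serves as the test fold exactly once, so averaging over rounds covers every query.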

Evaluation Metrics
Following previous works, three widely used evaluation metrics are adopted to measure the performance of the proposed approach: MRR@10 for the MS MARCO Passage Ranking dataset, and P@20 and NDCG@20, evaluated by trec_eval, for the Robust04 dataset. The results reported for model performance on the Robust04 dataset are averaged over all test folds.
MRR@10 (Mean Reciprocal Rank) This metric takes the reciprocal rank of the first relevant passage in the ranked list for a given query as the measure of precision. Because the MS MARCO passage ranking task provides only binary labels and does not specify a relative ranking order between passages, the MRR metric is used for evaluation.
P@20 The top-20 precision is defined as the proportion of relevant documents among the top 20 ranked candidate documents.
NDCG@20 NDCG measures the discrepancy between the ranked list and the ideal ranked list, and thereby evaluates the ranking performance of models.
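The three metrics can be sketched in a few lines. This is not the trec_eval implementation: the sketch assumes each ranked list is given as relevance labels in ranked order, uses a linear gain for NDCG, and computes the ideal ordering from the labels of the ranked list only, all of which are simplifying assumptions.

```python
import math

def reciprocal_rank_at_k(ranked_labels, k=10):
    """1 / rank of the first relevant item within the top k, else 0."""
    for rank, label in enumerate(ranked_labels[:k], start=1):
        if label > 0:
            return 1.0 / rank
    return 0.0

def mean_reciprocal_rank(all_ranked_labels, k=10):
    """MRR@k: average reciprocal rank over all queries."""
    return sum(reciprocal_rank_at_k(r, k) for r in all_ranked_labels) / len(all_ranked_labels)

def precision_at_k(ranked_labels, k=20):
    """Fraction of the top-k positions holding a relevant item."""
    return sum(1 for label in ranked_labels[:k] if label > 0) / k

def dcg_at_k(labels, k):
    """Discounted cumulative gain with linear gain and log2 discount."""
    return sum(label / math.log2(i + 2) for i, label in enumerate(labels[:k]))

def ndcg_at_k(ranked_labels, k=20):
    """DCG of the ranking divided by the DCG of the ideal ordering."""
    ideal = dcg_at_k(sorted(ranked_labels, reverse=True), k)
    return dcg_at_k(ranked_labels, k) / ideal if ideal > 0 else 0.0
```

For instance, a query whose first relevant passage appears at rank 3 contributes 1/3 to MRR@10, and one whose first relevant passage appears beyond rank 10 contributes 0.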

Baselines
The proposed training framework is integrated with the following methods, and the reranking results are compared with those of the original methods:
BM25+BERT (Nogueira and Cho, 2019) utilizes the traditional unsupervised ranking method BM25 as a first-stage retriever to generate a ranked list of candidate passages, and the relevance scores between the query and candidate passages are produced by the BERT-base model with a linear layer.
BM25+BERT+FP (Gururangan et al., 2020) further pretrains the language model BERT with the MLM objective on the target datasets before fine-tuning, in order to learn domain-specific language patterns for the document reranking task.
PARADE (Li et al., 2020) aggregates passage-level relevance representations to predict a document relevance score, where the long document is split into several passages and each passage is encoded with the given query.
CEDR-KNRM (MacAvaney et al., 2019) incorporates the classification vector of a fine-tuned BERT into the existing neural model KNRM (Dai et al., 2018) and leverages contextual information to improve ad-hoc document ranking.

Implementation Detail
For both datasets, following the setting of the first retrieval stage (Nogueira et al., 2019a), we employ BM25 (Robertson et al., 2009) as the first-stage retriever to obtain a list of top-k candidate documents/passages for the subsequent reranking stage. The reranking models are optimized by Adam (Kingma and Ba, 2015) with a learning rate of 3e-6 and a batch size of 8. The maximum sequence length is limited to 128 tokens for the MS MARCO passage ranking dataset, while documents are truncated to 800 tokens for the Robust04 dataset. The encoders of CEDR and PARADE are base-size pre-trained language models, each consisting of 12 transformer blocks with a hidden size of 768 and 12 self-attention heads. For the Robust04 dataset, dropout is used with a rate of 0.1 to improve model robustness. We set the margin hyperparameter γ of the pairwise ranking loss to 1. A uniform distribution is used to randomly mask a token in a given query, where all query tokens are assumed to be of equal importance. Since the MQP loss serves as an auxiliary loss to enhance the interaction-based representation learning of the encoder, the weight hyperparameter α is set to 0.2 for MS MARCO and 0.05 for Robust04. The experiments are conducted on a single GeForce RTX 3090 GPU.

Results
The experimental results of model+SJTF on the MS MARCO passage ranking and Robust04 datasets are presented in Table 2, where model+SJTF refers to the joint training of a reranking model with SJTF during the fine-tuning stage. As Table 2 shows, the reranking models adopting SJTF obtain significant improvements over the original models under a paired t-test (p < 0.05).
Compared with BM25+BERT-base, the further pretraining method BM25+BERT+FP improves MRR@10 by 0.5% on MS MARCO and NDCG@20 by 0.8% on Robust04 by better capturing the domain-specific language patterns. By jointly training the BERT-base model with the SJTF framework during the fine-tuning stage, BERT+SJTF achieves a 1.3% improvement in MRR@10 on MS MARCO and a 1.1% improvement in NDCG@20 on Robust04. These results suggest that learning the semantic dependencies between queries and relevant documents is more effective than further learning the contextual semantic information of documents. BERT+PACRR+SJTF outperforms the original approach by 0.5% on the MS MARCO dataset and 1.3% on the Robust04 dataset, which indicates that traditional reranking methods benefit from an improved contextual representation. The results show that models fine-tuned with the joint training framework yield a better interaction-based representation based on the semantic relationships between the query and candidate passages, which leads to a significant performance improvement without changing the model structure.
The improvement for PARADE, which splits a long document into overlapping passages via sliding windows, is smaller (0.55% on the Robust04 dataset), suggesting that the semantic blocks in long documents are not all tightly associated with the query semantically, and that using them to predict masked query tokens only slightly enhances representation learning. In comparison with the results of PARADE, CEDR is significantly improved by applying the SJTF framework, which indicates that the document-level semantic relations captured by the SJTF framework are beneficial for representation learning. The overall experimental results verify the effectiveness and generality of a training framework that encourages the models to adopt the positive passage as the semantically consistent context for a given query.

Ablation Study
Inspired by the observations from previous works (Gu et al., 2020; Liu et al., 2019; Gururangan et al., 2020) that different masking strategies enable a model to learn knowledge from different perspectives during the pre-training stage, we design three different self-supervised methods, which differ in whether they focus on understanding the content of the query or of the passages and in whether the query or the passage is used as the context for the masked prediction task.
Similar to the traditional MLM approach, the first strategy randomly masks a token of a positive document and uses the corresponding query and the remaining document tokens as context to predict the masked token; this strategy is denoted as MPP. The aim of MPP is to encourage the model to learn task-specific language patterns and to understand that the query and its positive documents are semantically consistent. However, it is possible to reconstruct the masked token from the content of the document itself, without requiring the semantic information of the query. In addition, since MPP is a simple task for pre-trained models, it cannot effectively improve their ranking performance.
Considering that the documents retrieved by BM25 are relevant or partially relevant to the query at the word level, the second strategy treats each retrieved document (regardless of whether it is positive or negative) as the context of the masked query for the MQP task; it does not utilize the supervised signal and is denoted as MQP_ALL. In contrast to the first strategy, MQP_ALL allows the models to focus on understanding the content of the query and to capture the semantic relationships between queries and documents. Table 3 shows the slight improvement achieved by MQP_ALL compared to the original model and the first strategy.
The self-supervised method MQP proposed in this paper is based on the assumption that, although the documents retrieved by BM25 have word-level relevance to the query, documents containing the query terms are not always relevant. Therefore, we use supervised signals to consider only positive documents as the context of masked queries, since the semantic information of negative documents amounts to noise for the queries. As shown in Table 3, after filtering out negative documents using supervised signals, BERT-base and CEDR-KNRM achieve a significant improvement on the NDCG@20 and P@20 metrics. The experimental results illustrate that a reasonable use of label information to construct a self-supervised approach enables the models to obtain more accurate knowledge of the downstream task, which helps them identify the correct documents for each question from a large collection.

Hyper-parameter Study
As described in Equation (6), the hyper-parameter α controls the influence of the self-supervised approach on the representation learning of models. Figure 3 shows the results of different hyper-parameter settings on the Robust04 dataset. In general, a larger α gives the self-supervised method an excessive effect, which reduces the influence of the reranking task in optimizing the models and leads to sub-optimal performance. Accordingly, with α greater than 0.05, the ranking performance of all three methods shows a decreasing trend. For the CEDR method, a stable performance improvement is obtained using the proposed training framework, indicating that considering the positive documents as the contextual features of the query can significantly enhance the representation learning of the CEDR model. As α is gradually raised from 0.01 to 0.05, the performance of both BERT-base and BERT+PACRR also gradually increases to its highest NDCG@20, which suggests that encouraging these models to focus on query understanding can lead to better reranking performance. From Figure 3, it can be seen that an α value of 0.05 is the more robust choice for most of the models.

Conclusion
This paper proposes a self-supervised joint training framework, SJTF, to improve representation learning for document ranking. By integrating this training framework, models are able to better understand the content of queries and capture the semantic relation between queries and positive documents, which guides the models to identify the relevant documents among a large number of candidate documents. Experimental results show that our proposed framework enhances the representation learning of reranking models and achieves better performance than the baselines in the document reranking task.