Learning to Sample Replacements for ELECTRA Pre-Training

ELECTRA pretrains a discriminator to detect replaced tokens, where the replacements are sampled from a generator trained with masked language modeling. Despite the compelling performance, ELECTRA suffers from two issues. First, there is no direct feedback loop from the discriminator to the generator, which renders replacement sampling inefficient. Second, the generator's predictions become over-confident as training proceeds, biasing replacements toward the correct (original) tokens. In this paper, we propose two methods to improve replacement sampling for ELECTRA pre-training. Specifically, we augment sampling with a hardness prediction mechanism, so that the generator can encourage the discriminator to learn what it has not acquired. We also prove that efficient sampling reduces the training variance of the discriminator. Moreover, we propose to use a focal loss for the generator in order to relieve oversampling of correct tokens as replacements. Experimental results show that our method improves ELECTRA pre-training on various downstream tasks.


Introduction
One of the most successful language model pre-training tasks is masked language modeling (MLM; Devlin et al. 2019). First, we randomly mask some input tokens in a sentence. Then the encoder learns to recover the masked tokens given the corrupted input. ELECTRA (Clark et al., 2020a) argues that MLM only produces supervision signals at a small proportion of positions (usually 15%), and uses the replaced token detection task as an alternative. Specifically, ELECTRA contains a generator and a discriminator. The generator is a masked language model, which substitutes masks with tokens sampled from its MLM predictions. The discriminator learns to distinguish which tokens have been replaced and which have been kept the same. Experimental results on downstream tasks show that ELECTRA largely improves sample efficiency.
Despite achieving compelling performance, it is usually difficult to balance the training pace between the generator and the discriminator. As pre-training proceeds, the generator is expected to sample harder replacements for the detection task in a curriculum manner, while the discriminator learns to identify the corrupted positions. Although the two components are designed to compete with each other, there is no explicit feedback loop from the discriminator to the generator, so the two learning processes remain effectively independent. The absence of feedback results in inefficient learning, because many replaced tokens have already been mastered by the discriminator while the generator has no way of knowing how to sample more effective replacements. In addition, a well-trained generator achieves reasonably good MLM accuracy, so many sampled replacements are simply the correct (original) tokens. In order to relieve this oversampling of correct tokens, ELECTRA explored increasing the masking probability, raising the sampling temperature, and using a manual rule to avoid sampling original tokens.
In this paper, we propose two methods, namely hardness prediction and sampling smoothing, to tackle the above issues. First, the motivation of hardness prediction is to sample the replacements that the discriminator struggles to predict correctly. We elaborate on the benefit of a good replacement mechanism from the perspective of variance reduction. Theoretical derivations indicate that replacement sampling should be proportional to both the MLM probability (i.e., language frequency) and the corresponding discriminator loss (i.e., discrimination hardness). Based on this conclusion, we introduce a sampling head in the generator, which learns to sample by estimating the expected discriminator loss for each candidate replacement. In this way, the discriminator gives feedback to the generator, which helps the model learn what it has not yet acquired. Second, we propose a sampling smoothing method for the issue of oversampling original tokens. We adopt a focal loss (Lin et al., 2017) for the generator's MLM task, rather than the cross-entropy loss. The method adaptively down-weights the well-predicted tokens in MLM, which avoids sampling too many correct tokens as replacements.
We conduct pre-training experiments on the WikiBooks corpus for both small-size and base-size models. The proposed techniques are plugged into ELECTRA for training from scratch. Experimental results on various tasks show that our methods outperform ELECTRA despite their simplicity. Specifically, under the small-size setting, our model is 0.9 points higher than ELECTRA on MNLI (Williams et al., 2018) and 4.2 points higher on SQuAD 2.0 (Rajpurkar et al., 2016a). Under the base-size setting, our model is 0.26 points higher than ELECTRA on MNLI and 0.52 points higher on SQuAD 2.0.

Related Work
State-of-the-art NLP models are mostly pretrained on large unlabeled corpora with self-supervised objectives (Peters et al., 2018; Lan et al., 2020; Raffel et al., 2020). The most representative pretext task is masked language modeling (MLM), which was introduced to pretrain the bidirectional BERT (Devlin et al., 2019) encoder. RoBERTa (Liu et al., 2019) applies several strategies to enhance BERT, including training with more data and dynamic masking. UniLM (Dong et al., 2019; Bao et al., 2020) extends mask prediction to generation tasks by adding auto-regressive objectives. XLNet (Yang et al., 2019) proposes permuted language modeling to learn the dependencies among the masked tokens. Besides, ELECTRA (Clark et al., 2020a) proposes a novel training objective called replaced token detection, which is defined over all input tokens. Moreover, ELECTRIC (Clark et al., 2020b) extends the idea of ELECTRA with energy-based cloze models.
Some prior efforts demonstrate that sampling harder examples leads to more effective training. Lin et al. (2017) propose the focal loss in order to focus on harder examples. Generative adversarial networks (Goodfellow et al., 2014) train the generator to maximize the probability of the discriminator making a mistake, which is closely related to ELECTRA's training framework. In this work, we aim at guiding the generator of ELECTRA to sample replacements that are hard for the discriminator to predict correctly, so that the pre-training of the discriminator can be more efficient.
Background: ELECTRA

An overview of ELECTRA is shown in Figure 1. The model consists of a generator G and a discriminator D. The generator is trained by masked language modeling (MLM). Formally, given an input sequence x = x_1 · · · x_n, we first randomly mask k = 0.15n tokens at the positions m = m_1 · · · m_k with [MASK]. The perturbed sentence c is denoted as:

$$c = \text{REPLACE}(x, m, \texttt{[MASK]}),$$

where the REPLACE operation conducts masking at the positions m. The generator encodes c and performs MLM prediction. At each masked position i ∈ m, we sample a replacement from the MLM output distribution p_G:

$$\hat{x}_i \sim p_G(x_i \mid c), \qquad x^R = \text{REPLACE}(x, m, \hat{x}),$$

where masks are replaced with the sampled tokens. Next, the discriminator encodes the corrupted sentence x^R. A binary classification task learns to distinguish which tokens have been replaced and which have been kept the same, predicting the probability D(x^R_t, x^R) to indicate how likely x^R_t is to come from the true data distribution.
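To make the pipeline above concrete, the following is a minimal sketch (our illustration, not ELECTRA's released code) of how the masked input c, the corrupted input x^R, and the discriminator's binary labels could be constructed; the [MASK] id and the tensor shapes are assumptions.

```python
import torch

MASK_ID = 103  # assumed [MASK] id for a BERT-style vocabulary (illustrative)

def mask_input(input_ids, mask_prob=0.15):
    """Randomly mask ~15% of the tokens, returning the perturbed input c."""
    masked_pos = torch.rand(input_ids.shape) < mask_prob
    c = input_ids.clone()
    c[masked_pos] = MASK_ID
    return c, masked_pos

def build_discriminator_input(input_ids, masked_pos, generator_logits):
    """Replace masked positions with tokens sampled from p_G and derive labels.

    generator_logits: [batch, seq_len, vocab] MLM logits from the generator run on c.
    """
    # Sample one replacement per masked position from the MLM distribution p_G.
    probs = torch.softmax(generator_logits[masked_pos], dim=-1)
    sampled = torch.multinomial(probs, num_samples=1).squeeze(-1)
    x_r = input_ids.clone()
    x_r[masked_pos] = sampled
    # Label 1 where the token matches the original (kept), 0 where it was replaced.
    labels = (x_r == input_ids).long()
    return x_r, labels
```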
The overall pre-training objective is defined as

$$\min_{\theta_G, \theta_D} \sum_{x \in \mathcal{X}} \mathcal{L}_{\text{MLM}}(x, \theta_G) + \lambda\, \mathcal{L}_{\text{Disc}}(x, \theta_D),$$

where X represents the text corpus, and λ = 50, as suggested by Clark et al. (2020a), is a hyperparameter used to balance the training pace of the generator and the discriminator. Once pre-training is finished, only the discriminator is fine-tuned on downstream tasks.

Figure 2: An overview of our model. The generator has two prediction heads. The MLM head learns to perform MLM through the focal loss instead of the cross-entropy loss. The sampling head is trained to estimate the discriminator loss over the vocabulary. Our model samples the replacements from a new distribution, which is proportional to both the MLM probability and the corresponding discriminator loss. The discriminator is trained to distinguish input tokens, and the loss feedback is transferred to the generator for the sampling head to learn.

Hardness Prediction
The key idea of hardness prediction is to let the generator receive the discriminator's feedback and sample harder replacements. Figure 2 shows the overview of our method. Besides the original MLM head in the generator, there is an additional sampling head used to sample replaced tokens. Given a replacement candidate x' for a masked position of the input sequence c, let L_D(x', c) denote the discriminator loss for the replacement. Rather than directly sampling replacements from the MLM prediction p_G, we propose to sample from p_S:

$$p_S(x' \mid c) = \frac{p_G(x' \mid c)\, L_D(x', c)}{\sum_{x'' \in V} p_G(x'' \mid c)\, L_D(x'', c)}, \qquad (1)$$

where V denotes the vocabulary, and the corrupted sentence x^R is obtained by substituting the masked positions m with the sampled replacements x'. The first term p_G(x'|c) implies sampling from the data distribution. The second term L_D(x', c) encourages the model to sample more replacements that the discriminator has not successfully learned.
Notice that Equation (1) uses the actual discriminator loss L_D(x', c), which cannot be obtained without feeding x^R into the discriminator. As an alternative, we use the estimated loss value L̂_D(x', c) to sample replaced tokens, which approximates the actual loss for the candidate replacement. During pre-training, we use the actual loss as supervision and simultaneously train the sampling head. We describe the detailed implementations of loss estimation in Section 4.1.2.
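As an illustration of this sampling rule (a hedged sketch, not the authors' implementation), the snippet below combines the MLM probabilities with the estimated discriminator loss produced by a hypothetical sampling head and draws replacements from the resulting distribution.

```python
import torch

def sample_replacements(mlm_probs, est_disc_loss):
    """Sample one replacement token per masked position.

    mlm_probs:     [num_masked, vocab] MLM probabilities p_G(x'|c).
    est_disc_loss: [num_masked, vocab] estimated discriminator loss L_hat_D(x', c)
                   for every candidate replacement.
    """
    # Unnormalized sampling weights: p_G(x'|c) * L_hat_D(x', c).
    weights = mlm_probs * est_disc_loss
    # Normalize over the vocabulary to obtain p_S(x'|c), then sample.
    p_s = weights / weights.sum(dim=-1, keepdim=True)
    return torch.multinomial(p_s, num_samples=1).squeeze(-1)
```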
By considering detection hardness in replacement sampling and giving feedback from the discriminator to the generator, the two components are no longer learned independently. ELECTRA (Clark et al. 2020a; Appendix F) also attempts to achieve a similar goal by adversarially training the generator. However, it underperforms maximum-likelihood training because of the poor sample efficiency of reinforcement learning on discrete text data. More importantly, their generator is trained to fool the discriminator, rather than guiding the discriminator with the data distribution, which breaks the ELECTRA training objective. In contrast, we retain the MLM head and decouple it from replacement sampling, so we keep the advantage of the original training objective.

Perspective of Variance Reduction
We show that the proposed hardness prediction method is well supported from the perspective of variance reduction.
Proposition 1. Sampling replacements from p_S(x'|c) can minimize the estimation variance of the discriminator loss.
Proof. At each masked position, the expectation of the discriminator loss we aim to estimate can be summarized as

$$Z = \mathbb{E}_{x' \sim p_G}\bigl[L_D(x', c)\bigr] = \sum_{x' \in V} p_G(x' \mid c)\, L_D(x', c).$$

Under p_G, the estimation variance of the discriminator loss is

$$\mathrm{Var}_{p_G} = \mathbb{E}_{x' \sim p_G}\bigl[L_D(x', c)^2\bigr] - Z^2.$$

Similar to importance sampling, we can select an alternative distribution p_S different from p_G; then the expectation Z is rewritten as

$$Z = \mathbb{E}_{x' \sim p_S}\Bigl[\frac{p_G(x' \mid c)}{p_S(x' \mid c)}\, L_D(x', c)\Bigr].$$

By making a multiplicative adjustment to L_D, the estimation variance of Z under the new sampling distribution p_S is converted to

$$\mathrm{Var}_{p_S} = \sum_{x' \in V} \frac{p_G(x' \mid c)^2\, L_D(x', c)^2}{p_S(x' \mid c)} - Z^2.$$

Based on the above derivation, we obtain a zero-variance estimator when we choose p_S(x*|c) = p_G(x*|c) L_D(x*, c)/Z, which is exactly Equation (1). This theoretically optimal form provides insight into the design of the above sampling scheme.
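The zero-variance claim can also be checked numerically. The toy script below (not part of the paper) compares the variance of the importance-sampling estimator of Z when drawing from p_G versus from p_S ∝ p_G · L_D.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy setup: 10 candidate replacements with arbitrary MLM probabilities
# p_G and per-candidate discriminator losses L_D.
p_g = rng.dirichlet(np.ones(10))        # p_G(x'|c)
l_d = rng.uniform(0.1, 2.0, size=10)    # L_D(x', c)
z = np.sum(p_g * l_d)                   # true expectation Z

def is_estimates(p_s, n=10000):
    """Importance-sampling estimates of Z when drawing candidates from p_s."""
    idx = rng.choice(10, size=n, p=p_s)
    return p_g[idx] / p_s[idx] * l_d[idx]

# Estimator variance when sampling from p_G itself ...
print(np.var(is_estimates(p_g)))        # noticeably > 0
# ... versus sampling from p_S = p_G * L_D / Z (Equation (1)).
print(np.var(is_estimates(p_g * l_d / z)))  # ~0 up to floating-point error
```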

Two Implementations of Hardness Prediction
We design two variants of the sampling head. The first one explicitly estimates the discriminator loss (HP_Loss). The second one directly approximates the expected sampling distribution (HP_Dist).
HP_Loss guides the generator to estimate the probability, predicted by the discriminator, that the sampled token x' is an original token. In this case, the output layer of the sampling head is a sigmoid function of the same form as the discriminator's output layer, producing an estimate D̂(x', c) for every candidate token x'; here h_S(c) denotes the contextual representations projected by the sampling head, and w denotes the projection parameters. At each masked position, the sampling head is trained with the actual discriminator prediction for the sampled replacement as supervision. When sampling replacements over the vocabulary, the estimated discriminator probability D̂(x', c) can easily be rewritten as the estimated discriminator loss L̂_D(x', c):

$$\hat{L}_D(x', c) = -\mathbb{1}(x' = x)\log \hat{D}(x', c) - \mathbb{1}(x' \neq x)\log\bigl(1 - \hat{D}(x', c)\bigr),$$

where x is the original token. Multiplying by the MLM probability factor p_G, we obtain the sampling distribution p_S(x'|c) ∝ p_G(x'|c) L̂_D(x', c).

HP_Dist aims to directly approximate the expected sampling distribution in Equation (1), instead of the discriminator loss. In this case, the sampling head produces an output probability for each token x' with a softmax layer:

$$\hat{p}_S(x' \mid c) = \frac{\exp\bigl(e_{x'}^{\top} h_S(c)\bigr)}{\sum_{x'' \in V} \exp\bigl(e_{x''}^{\top} h_S(c)\bigr)}, \qquad (2)$$

where e represents the token embeddings. For the sampled token x', we define the loss of the sampling head as

$$L_S(x', c) = -L_D(x', c)\,\log \hat{p}_S(x' \mid c).$$

We then show that minimizing L_S(x', c) pushes the sampling distribution of Equation (2) towards our goal. Specifically, the loss expectation over the whole vocabulary is

$$\mathbb{E}\bigl[L_S\bigr] = -\sum_{x' \in V} p_G(x' \mid c)\, L_D(x', c)\,\log \hat{p}_S(x' \mid c).$$

According to the method of Lagrange multipliers, under the constraint that p̂_S sums to one over the vocabulary, the optimal solution p̂_S of this loss is consistent with Equation (1):

$$\hat{p}_S(x' \mid c) = \frac{p_G(x' \mid c)\, L_D(x', c)}{\sum_{x'' \in V} p_G(x'' \mid c)\, L_D(x'', c)}.$$
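The sketch below illustrates one plausible realization of the two variants under the assumptions above: hp_loss_sampling_dist converts a per-candidate estimate D_hat(x', c) into the estimated loss L_hat_D(x', c) and the resulting sampling distribution, while hp_dist_head_loss weights the sampled token's negative log-probability by the actual discriminator loss. Function names and tensor shapes are ours, not the authors'.

```python
import torch
import torch.nn.functional as F

def hp_loss_sampling_dist(mlm_probs, est_disc_prob, original_ids):
    """HP_Loss: turn estimated discriminator probabilities into p_S.

    mlm_probs:     [num_masked, vocab] p_G(x'|c).
    est_disc_prob: [num_masked, vocab] D_hat(x', c), the estimated probability that
                   the discriminator labels candidate x' as original.
    original_ids:  [num_masked] ids of the original (masked-out) tokens.
    """
    is_original = F.one_hot(original_ids, mlm_probs.size(-1)).bool()
    # Estimated discriminator loss L_hat_D(x', c): -log D_hat for the original
    # token, -log(1 - D_hat) for every other candidate.
    est_loss = torch.where(is_original,
                           -torch.log(est_disc_prob + 1e-8),
                           -torch.log(1.0 - est_disc_prob + 1e-8))
    weights = mlm_probs * est_loss
    return weights / weights.sum(dim=-1, keepdim=True)

def hp_dist_head_loss(sampling_logits, sampled_ids, actual_disc_loss):
    """HP_Dist: weight the sampled token's negative log-probability by the
    actual discriminator loss observed for that replacement.

    sampling_logits:  [num_masked, vocab] logits e_{x'}^T h_S(c).
    sampled_ids:      [num_masked] sampled replacement token ids.
    actual_disc_loss: [num_masked] L_D(x', c) from the discriminator.
    """
    log_p_s = F.log_softmax(sampling_logits, dim=-1)
    picked = log_p_s.gather(-1, sampled_ids.unsqueeze(-1)).squeeze(-1)
    return -(actual_disc_loss.detach() * picked).mean()
```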

Sampling Smoothing
Along with the learning process, the masked language modeling tends to achieve relatively high accuracy. As a consequence, the generator oversamples the correct tokens as replacements, which renders the discriminator learning inefficient. In order to address this issue, we apply an alternative loss function, the focal loss (Lin et al., 2017), for the MLM task of the generator. Compared with the vanilla cross-entropy loss, the focal loss adds a modulating factor for the weighting purpose:

$$\mathcal{L}_{\text{Focal}}(x, c) = -\bigl(1 - p_G(x \mid c)\bigr)^{\gamma} \log p_G(x \mid c),$$

where γ ≥ 0 is a tunable hyperparameter. Besides using a constant γ, we also try the piecewise function γ = 3 if p_G > 0.2 and γ = 5 otherwise in our experiments, as suggested by Mukhoti et al. (2020). In other words, the focal loss adaptively down-weights the well-classified easy examples and thus focuses on more difficult ones. When applying the focal loss to the MLM head of the generator, we notice that if a token is easy for the generator to predict correctly, i.e., p_G(x|c) → 1, the modulating factor greatly decreases the loss. In contrast, if a token is hard to predict, the focal loss approximates the original cross-entropy loss. Therefore, we employ the focal loss to smooth the sampling distribution, which in turn relieves oversampling correct tokens as replacements.
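For reference, here is a minimal PyTorch sketch of the focal loss applied to the generator's MLM head, together with the piecewise γ schedule described above; the function names are ours.

```python
import torch
import torch.nn.functional as F

def focal_mlm_loss(logits, targets, gamma=1.0):
    """Focal loss for the generator's MLM head (sketch).

    logits:  [num_masked, vocab] MLM logits at the masked positions.
    targets: [num_masked] original token ids.
    gamma:   focusing parameter (scalar or per-token tensor); gamma = 0
             recovers the standard cross-entropy loss.
    """
    log_probs = F.log_softmax(logits, dim=-1)
    log_p_t = log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
    p_t = log_p_t.exp()
    # Down-weight well-predicted tokens by the modulating factor (1 - p_t)^gamma.
    return -(((1.0 - p_t) ** gamma) * log_p_t).mean()

def piecewise_gamma(p_t):
    """Piecewise schedule from Mukhoti et al. (2020): gamma = 3 if p > 0.2 else 5."""
    return torch.where(p_t > 0.2, torch.full_like(p_t, 3.0), torch.full_like(p_t, 5.0))
```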

Pre-Training Objective
Adopting the above two strategies, we jointly train the generator and the discriminator, as in the original ELECTRA model. Their word embeddings are still tied during the pre-training stage. Formally, we minimize the combined loss over a large corpus X:

$$\min_{\theta_G, \theta_D} \sum_{x \in \mathcal{X}} \mathcal{L}_{\text{Focal}}(x, \theta_G) + \lambda_1 \mathcal{L}_{S}(x, \theta_G) + \lambda_2 \mathcal{L}_{\text{Disc}}(x, \theta_D),$$

where λ_1 and λ_2 are two hyperparameters that adjust the three parts of the loss. We only search over the λ_1 value and keep λ_2 = 50 for a fair comparison with ELECTRA. After pre-training, we discard the generator and only fine-tune the discriminator on downstream tasks.
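Assuming the loss sketches above, a training step would combine the three terms as follows (a schematic sketch; the λ values follow the configurations reported below):

```python
def pretraining_loss(focal_mlm, sampling_head, disc, lambda_1=5.0, lambda_2=50.0):
    """Combine the three loss terms of the joint objective (sketch).

    focal_mlm:     focal MLM loss of the generator.
    sampling_head: loss of the sampling head (HP_Loss or HP_Dist variant).
    disc:          replaced-token-detection loss of the discriminator.
    """
    return focal_mlm + lambda_1 * sampling_head + lambda_2 * disc
```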

Setup
We implement ELECTRA+HP_Loss/HP_Dist+Focal under both the small-size and the base-size settings. The two prediction heads share the generator body and the token embeddings, which avoids an unnecessary increase in model complexity. We follow most settings suggested in ELECTRA (Clark et al., 2020a). In order to enhance the ELECTRA baseline for a solid comparison, we add relative position embeddings (Raffel et al., 2020). Experimental results show that our methods improve performance even over this enhanced ELECTRA baseline. We pretrain our models on the same text corpus as ELECTRA, which is a combination of English Wikipedia and BooksCorpus (Zhu et al., 2015). We also adopt the N-gram masking strategy, which is beneficial for MLM tasks. The models are trained for 1M steps in the small-size setting and 765k steps in the base-size setting, so that the computation cost is similar to the baseline models (Clark et al., 2020a). The base-size models are pretrained with 16 V100 GPUs in less than five days. The small-size models are pretrained with 8 V100 GPUs in less than three days. We use the Adam (Kingma and Ba, 2015) optimizer (β_1 = 0.9, β_2 = 0.999) with a learning rate of 1e-4. The value of λ_2 in the training objective is kept fixed at 50 for a fair comparison with ELECTRA. For HP_Loss, we search λ_1 in {5, 10, 20}; 5 works best. For HP_Dist, we keep λ_1 = 1. We search the focal loss weight γ in {1, 4} on both the base-size and small-size models; the best configuration is γ = 1. The detailed pre-training configurations are provided in the supplemental materials.

Results on GLUE Benchmark
We evaluate our models on the General Language Understanding Evaluation (GLUE) benchmark. For small-size settings, we use the hyperparameter configuration suggested by Clark et al. (2020a). For base-size settings, we consider a limited hyperparameter search for each task, with learning rates ∈ {5e-5, 1e-4, 1.5e-4} and training epochs ∈ {3, 4, 5}. The remaining hyperparameters are the same as ELECTRA's. We report the median performance on the dev set over five different random seeds for each task. All results come from single-task fine-tuning. For more detailed fine-tuning configurations, please refer to the supplementary materials.
Results are shown in Table 1. With the same configuration and pre-training data, our methods outperform the strong reimplemented ELECTRA baseline by 0.6 points on average under the small-size setting and 0.4 points under the base-size setting. On the most widely reported task, MNLI, our models achieve 87.0/86.9 points on the matched/mismatched sets, obtaining 0.3/0.4 absolute improvements. The performance gains on the small-size models are more pronounced than on the base-size models; we speculate that this is because the small-size generator is trained less sufficiently and thus suffers more from the above issues. The results demonstrate that our proposed methods improve the pre-training of ELECTRA. In other words, sampling harder replacements is more efficient than sampling from the original masked language modeling distribution.

Results on SQuAD 2.0
The Stanford Question Answering Dataset (SQuAD; Rajpurkar et al. 2016a) is a reading comprehension dataset in which each example consists of a context and a question-answer pair. Given a context and a question, the task is to answer the question by extracting the relevant span from the context. We use version 2.0 for evaluation, where some questions are not answerable. We report both the Exact-Match (EM) and F1 scores. When fine-tuning on SQuAD, we add the question-answering module from XLNet on top of the discriminator, following Clark et al. (2020a). All hyperparameter configurations are the same as ELECTRA's. We report the median performance on the dev set over five different random seeds. Refer to the appendix for more details about fine-tuning. Results on SQuAD 2.0 are shown in Table 2. Consistently, our models perform better than the ELECTRA baseline under both the small-size and base-size settings. Under the base-size setting, our models improve over the reimplemented ELECTRA baseline by 0.6 points (EM) and 0.5 points (F1). Under the small-size setting in particular, our models outperform the baseline by a remarkable margin: ELECTRA+HP_Dist+Focal obtains 4.3 and 4.2 points of absolute improvement on the EM and F1 metrics.

Ablation Studies
We conduct ablation studies on the small-size ELECTRA+HP_Loss+Focal models. We investigate the effect of the loss weight λ_1 of the sampling head and the focal loss factor γ in order to better understand their relative importance. Results are presented in Table 3.
Table 3: Ablation studies on small-size models. We analyze the effect of the hardness prediction loss weight λ_1 and the focal loss factor γ. Reported results are medians over five random seeds.

We first disable the focal loss and examine only the effect of λ_1. As shown in Table 3, no matter what the value of λ_1 is, our models exceed the baseline by a substantial margin, which demonstrates that hardness prediction can indeed improve pre-training and that our method is not sensitive to the loss weight hyperparameter. Next, we fix λ_1 at 5 and examine the effect of the focal loss factor γ. We observe that applying the focal loss with either the piecewise schedule or γ = 1 improves performance on the two datasets, which demonstrates the effectiveness of the sampling smoothing.

Analysis
To better understand the main advantages of our models over ELECTRA, we conduct several analysis experiments.

Impacts on Sampling Distributions
We first provide a comparison between the sampling distributions of ELECTRA and our models to illustrate the effect of the proposed methods. We conduct evaluations on a subset of the pre-training corpus. Figure 3 shows the distribution of the maximum probability of the two sampling distributions at the masked positions. We observe that the proportion of maximum values falling in [0.9, 1] under the ELECTRA sampling distribution is much higher than under our models. In other words, the original distribution suffers from over-sampling high-probability tokens, and the discriminator is forced to learn from these easy examples repeatedly. In contrast, the distribution of the maximum value of our models in each interval is relatively more uniform than ELECTRA's, which indicates that our methods significantly reduce the probability of sampling well-classified tokens and smooth the whole sampling distribution.

Table 4: Correlation coefficient between the actual discriminator loss and the estimated value for ELECTRA+HP_Loss+Focal. "Original": positions where the correct token is sampled as the replacement. "Replaced": positions that are substituted with incorrect tokens.

Token Type      Original    Replaced    All
Corr. Coeff.    0.78        0.61        0.64

Estimation Quality
In order to measure the estimation quality of the discriminator loss, we evaluate our models on a held-out set of the pre-training corpus and compute the correlation coefficient between the actual discriminator loss L_D(x, c) and the estimated value L̂_D(x, c).
The results of ELECTRA+HP_Loss+Focal are shown in Table 4. We report the estimation quality for the original tokens and the replaced tokens separately. The correlation coefficient is 0.64 over the two types of tokens combined, which shows that L_D(x, c) and L̂_D(x, c) correlate well. Furthermore, we observe that the estimation quality for the original tokens is higher than for the replacements. We speculate that the sampling probability of the original tokens is generally higher than that of the replacements, so the sampling head tends to receive more feedback from these original tokens.

Prediction Accuracy of the Discriminator
In order to verify the claim that the sampling distribution of our models indeed considers detection difficulty, we evaluate the prediction accuracy of the discriminator under the two sampling schemes of ELECTRA and ELECTRA+HP_Loss+Focal. Results are listed in Table 5. Whether evaluating at all positions or only at the masked positions, the detection accuracy under our sampling distribution is lower than under the masked language modeling distribution of the original ELECTRA. Because the unmasked tokens constitute the majority of input tokens, the difference in all-token accuracy between the two models is less pronounced than the difference at the masked positions. This phenomenon is consistent with our original intention: it shows that our models sample more replacements that the discriminator struggles to predict correctly. In contrast, the replacements sampled by ELECTRA are easier to distinguish.

Conclusion
We propose to improve replacement sampling for ELECTRA pre-training. We introduce two methods, namely hardness prediction and sampling smoothing. Rather than sampling from the masked language model alone, we design a new sampling scheme that considers both the MLM probability and the prediction difficulty of the discriminator, so the generator can receive feedback from the discriminator. Moreover, we apply the focal loss to MLM, which adaptively down-weights the well-classified examples and smooths the entire distribution. The sampling smoothing technique relieves oversampling original tokens as replacements. Results show that our models outperform the ELECTRA baseline. In the future, we would like to apply our strategies to other pre-training frameworks and cross-lingual models. Moreover, we are exploring how to integrate the findings and insights of the proposed method into the masked language modeling task itself, which also seems quite promising.