Increasing Robustness to Spurious Correlations using Forgettable Examples

Neural NLP models tend to rely on spurious correlations between labels and input features to perform their tasks. Minority examples, i.e., examples that contradict the spurious correlations present in the majority of data points, have been shown to increase the out-of-distribution generalization of pre-trained language models. In this paper, we first propose using example forgetting to find minority examples without prior knowledge of the spurious correlations present in the dataset. Forgettable examples are instances either learned and then forgotten during training or never learned. We show empirically how these examples are related to minorities in our training sets. Then, we introduce a new approach to robustify models by fine-tuning them twice: first on the full training data and second on the minorities only. We obtain substantial improvements in out-of-distribution generalization when applying our approach to the MNLI, QQP and FEVER datasets.


Introduction
Despite the impressive performance of current NLP models, these models often exploit spurious correlations: they tend to capture prediction correlations that hold for most examples but do not hold in general. For instance, in natural language inference (NLI) datasets, word-overlap between hypothesis and premise is highly correlated with the entailment label (McCoy et al., 2019; Zhang et al., 2019). Therefore, these models are brittle when tested on examples that cannot be solved by resorting to these correlations, limiting their application in real-world scenarios. Out-of-distribution or challenging sets are benchmarks carefully designed to break systems that rely on such correlations.
The paradigm of fine-tuning pre-trained language models (PLM) has pushed the state of the art in a large variety of tasks involving natural language understanding (NLU) (Devlin et al., 2019). This is achieved by self-supervised learning from an enormous amount of text. PLMs also show increased robustness on challenging datasets (Hendrycks et al., 2019). This increase is attributed to an empirical finding that PLMs perform better on minority examples present in the training data (Tu et al., 2020). These minority examples violate the spurious correlations and therefore likely support the examples in challenging datasets. Tu et al. (2020) find minority examples by manually dividing the training data into two groups, according to the known spurious correlations (e.g., word-overlap in NLI). They present an analysis of the robustness of PLMs and its connection to minority examples. In this work, we first introduce a systematic way to find minority examples that does not need prior knowledge of spurious correlations, a major limitation of earlier work. We then present a simple approach that increases the robustness of PLMs further by fine-tuning models more on these examples.
To identify the set of minority examples, we adopt example forgetting (Toneva et al., 2019). This statistic has been shown to relate to the hardness of examples, so we assume it is useful for finding minorities in the training data. Based on the definition presented in Toneva et al. (2019), we consider an example forgettable if, during training, it is either properly classified at some point and misclassified later, or if it is never properly classified. This method is model- and task-agnostic. We show in our datasets that minority examples w.r.t. spurious correlations, such as word-overlap in NLI, are well represented in forgettable examples.
After finding minorities through forgettable examples, we propose a simple method to increase the robustness of PLMs further. We perform an additional fine-tuning on the minorities exclusively, after fine-tuning on the whole training data. We find this strategy effective, as it increases robust accuracy, i.e., performance on out-of-distribution data, while minimally impacting performance on in-distribution examples. We evaluate our proposed methods on three tasks, including NLI (MNLI, Williams et al., 2017), paraphrase identification (QQP, Iyer et al., 2017) and fact verification (FEVER, Thorne et al., 2018). For each task, recent work has introduced out-of-distribution test sets targeting specific spurious correlations.
Our contributions are the following:
• We propose fine-tuning PLMs a second time on the forgettable examples of the training set, and show substantial gains in out-of-distribution accuracy on HANS (McCoy et al., 2019), PAWS (Zhang et al., 2019) and FEVER-Symmetric (Schuster et al., 2019). Our method performs effectively when applied to both base and large versions of PLMs (e.g., BERT BASE and BERT LARGE).
• We observe that finding minorities using a network shallower than the PLM is more effective for robustifying it via fine-tuning.
• We show that training models only on forgettable examples leads to poor performance on our datasets, which contrasts with the vision results of Toneva et al. (2019).

Our code is available at github.com/sordonia/hans-forgetting

Datasets
We consider three sentence pair classification tasks, namely natural language inference, paraphrase identification, and fact verification. In the following, we introduce each task and describe the datasets we use for it.

Natural Language Inference
For the first task, natural language inference, we use MNLI (Williams et al., 2017), a common dataset containing more than 400,000 premise and hypothesis pairs annotated with textual entailment information (neutral, entailment or contradiction). Models trained on this dataset have been shown to capture spurious correlations, such as word-overlap between hypothesis and premise as a strong signal for the entailment label (Naik et al., 2018; McCoy et al., 2019). A series of diagnostic out-of-distribution test sets have been devised to test robustness against such heuristics, e.g., HANS. HANS (McCoy et al., 2019, Heuristic Analysis for NLI Systems) is composed of both entailment and non-entailment examples that have high word-overlap between hypothesis and premise (e.g. "The president advised the doctor" → "The doctor advised the president"). A model relying exclusively on the word-overlap feature would not achieve above-chance classification accuracy on HANS. As a matter of fact, BERT (Devlin et al., 2019) performance on this dataset is only slightly better than chance (McCoy et al., 2019). We consider HANS (size: 30k examples) and the MNLI matched dev set (Williams et al., 2017) (size: 9,815 examples) as our out- and in-distribution test sets for MNLI.

Paraphrase Identification
QQP (Iyer et al., 2017) is a widely used dataset for paraphrase identification containing over 400,000 pairs of questions annotated as either paraphrase or non-paraphrase. As a consequence of the dataset design, pairs with high lexical overlap have a high probability of being paraphrases. Similarly to MNLI, models trained on QQP are thus prone to learning lexical overlap as a highly informative feature and do not capture the common sense underlying paraphrasing. The PAWS dataset is designed to test this.
PAWS (Zhang et al., 2019, Paraphrase Adversaries from Word Scrambling) is a question paraphrase dataset, well-balanced with respect to the lexical overlap heuristic. The accuracy of BERT is around 91.3% on QQP and only 32.2% on PAWS (Table 5). This makes it an interesting test-bed for our method. We use PAWS-QQP as our out-of-distribution set, which contains 677 question pairs. Training examples from PAWS were never used to update our models. Following Zhang et al. (2019) and Utama et al. (2020), our QQP training and testing splits are based on Wang et al. (2017).

Fact Verification
The task of fact verification aims to verify a claim given evidence. The labels are supports, refutes, and not enough information. This task is defined as part of the Fact Extraction and Verification (FEVER) challenge (Thorne et al., 2018). Schuster et al. (2019) show that models ignoring the evidence can still achieve high accuracy on FEVER. They introduce an evaluation test set that challenges this bias. Following Utama et al. (2020), we use the FEVER-Symmetric datasets (Symm-v1 and Symm-v2, with 717 and 712 examples, respectively) for out-of-distribution evaluation (available at https://github.com/TalSchuster/FeverSymmetric).

Finding Minorities with Forgettables
We first define example forgetting and how to compute it. We then show that it can be used to find minority examples in the training data.

Forgettable examples
An example is forgotten if it goes from being correctly to incorrectly classified during training (each such occurrence is called a forgetting event). This happens due to the stochastic nature of gradient descent, in which gradient updates performed on certain examples can hurt performance on others. If an example is forgotten at least once or is never learned during training, it is dubbed forgettable.
Finding forgettable examples entails training a model on the training data and tracking the accuracy of each example at each presentation during training. The algorithm for computing forgettability is cheap (Toneva et al., 2019) and only requires storing the accuracy of each particular example at each epoch.
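A minimal PyTorch-style sketch of this procedure (not the authors' released implementation); `model`, `optimizer` and a `train_loader` yielding (example ids, inputs, labels) are assumed to exist:

```python
# Sketch: track forgetting statistics while training and return forgettable examples.
from collections import defaultdict
import torch.nn.functional as F

def find_forgettables(model, optimizer, train_loader, num_epochs=5):
    prev_correct = defaultdict(bool)      # was example i correct at its last presentation?
    forgetting_events = defaultdict(int)  # number of correct -> incorrect transitions
    ever_learned = defaultdict(bool)      # was example i ever classified correctly?

    for _ in range(num_epochs):
        for example_ids, inputs, labels in train_loader:
            logits = model(inputs)
            loss = F.cross_entropy(logits, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            correct = logits.argmax(dim=-1) == labels
            for i, c in zip(example_ids.tolist(), correct.tolist()):
                if prev_correct[i] and not c:   # a forgetting event
                    forgetting_events[i] += 1
                ever_learned[i] |= c
                prev_correct[i] = c

    # Forgettable = forgotten at least once, or never learned at all.
    return [i for i in prev_correct
            if forgetting_events[i] > 0 or not ever_learned[i]]
```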
Toneva et al. (2019) extract forgettable examples from a network shallower than their target model. This makes finding forgettables more efficient and also results in a more diverse set of examples, as the number of forgettable examples is usually higher for weaker models. Another factor is that shallow models exhibit less memorization due to their smaller number of parameters (Sagawa et al., 2020b), and therefore their forgettables are potentially more representative of the minorities.
We compute forgettable examples using two models with significantly lower capacity than PLMs. The first is a "siamese" BoW classifier in which hypothesis and premise are independently encoded as a mean of word embeddings. This common NLP model performs surprisingly well while relying only on bag-of-words lexical features. We also consider a siamese BiLSTM model. More details can be found in Appendix A. Finally, for comparison, we also experiment with the biased model used for HANS in state-of-the-art baselines (Clark et al., 2019; Utama et al., 2020) (see also §4.2), as well as BERT BASE for NLI.
We train the shallow models for five epochs and track forgetting statistics after each epoch. Table 1 shows the number of forgettable examples for BoW, BiLSTM and BERT BASE on the MNLI, QQP and FEVER training sets. The performance of the models on the MNLI dev set is also included.

Table 2: Average Jaccard index as a measure of word-overlap between two sentences, grouped by P (positive) and ¬P (non-positive).

Forgettable and minority examples
We focus on two important spurious correlations: word-overlap and contradiction-word. These correlations, or biases, are addressed in related work for MNLI and QQP (Tu et al., 2020; Zhou and Bansal, 2020). For convenience, we use "positive" for either entailment, supports or paraphrase, and "negative" for contradiction, refutes or non-paraphrase. High word-overlap between two sentences is spuriously correlated with the positive label in all three datasets. In Table 2, we show that on average, positive examples have higher word-overlap than non-positive ones. In other words, minorities w.r.t. word-overlap correspond to non-positive examples with high word-overlap and positive examples with low word-overlap. For MNLI and QQP, the distributions in F BOW and F BILSTM exhibit an interesting behavior: on average, non-positive examples have higher word-overlap. F BERT has the same average for both labels. For FEVER, the difference is also clear, as the gap in word-overlap between positive and non-positive examples is lower for forgettables. The table allows us to conclude that forgettable examples contain more minority examples than a random subset of the same size.
In Table 3, we perform a similar analysis for the presence of contradiction words in the second sentence, which has been shown to correlate with the negative class in MNLI (Naik et al., 2018; Zhou and Bansal, 2020) and FEVER (Schuster et al., 2019). We choose the following contradiction words: {"not", "no", "doesn't", "don't", "never", "any"}, and analyze all three datasets. We observe here as well that forgettables contain more minority examples, as their percentage of examples with a contradiction word is lower for negative examples, which is the opposite of what we see in the overall dataset (with the exception of F BERT and FEVER).
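To make the two minority criteria concrete, the sketch below flags examples that violate either heuristic; this is an illustration rather than the exact procedure behind Tables 2 and 3, and the whitespace tokenization and fixed overlap threshold are simplifying assumptions:

```python
# Flag minority examples w.r.t. word-overlap and contradiction-word heuristics.
CONTRADICTION_WORDS = {"not", "no", "doesn't", "don't", "never", "any"}

def jaccard(s1: str, s2: str) -> float:
    """Word-overlap between two sentences as the Jaccard index of their token sets."""
    t1, t2 = set(s1.lower().split()), set(s2.lower().split())
    return len(t1 & t2) / max(len(t1 | t2), 1)

def is_minority(sent1: str, sent2: str, label: str, overlap_threshold: float = 0.5) -> bool:
    """A minority example contradicts a heuristic: a non-positive pair with high
    word-overlap, a positive pair with low overlap, or a non-negative pair whose
    second sentence contains a contradiction word."""
    high_overlap = jaccard(sent1, sent2) > overlap_threshold
    violates_overlap = high_overlap != (label == "positive")
    has_contradiction_word = bool(CONTRADICTION_WORDS & set(sent2.lower().split()))
    violates_contradiction = has_contradiction_word and label != "negative"
    return violates_overlap or violates_contradiction
```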

PLMs
We are interested in the robustness of large PLMs.
In this work, we focus on two such models, BERT and XLNet, and experiment with both their base and large versions. Since BERT BASE is the model of choice in previous work (Clark et al., 2019; Zhang et al., 2019; Utama et al., 2020), it serves as our default architecture. We adopt the Transformers library (Wolf et al., 2019). Our robust models are obtained by fine-tuning PLMs on the full training set for 3 epochs (using the default hyperparameters for each task) and then on the forgettable examples only, for 3 more epochs with a smaller learning rate. See Appendix B for more details.
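A minimal sketch of this two-stage fine-tuning, assuming Hugging Face Transformers and a tokenized `datasets.Dataset`; `train_dataset` and `forgettable_ids` (computed as in §3.1) are placeholders, and the learning rates follow Appendix B for MNLI/QQP:

```python
from transformers import (AutoModelForSequenceClassification, Trainer,
                          TrainingArguments)

model = AutoModelForSequenceClassification.from_pretrained(
    "bert-base-uncased", num_labels=3)

# Stage 1: standard fine-tuning on the full training set.
stage1_args = TrainingArguments(output_dir="stage1", num_train_epochs=3,
                                learning_rate=5e-5)
Trainer(model=model, args=stage1_args, train_dataset=train_dataset).train()

# Stage 2: continue fine-tuning on the forgettable examples only,
# with a smaller learning rate.
forgettable_dataset = train_dataset.select(forgettable_ids)
stage2_args = TrainingArguments(output_dir="stage2", num_train_epochs=3,
                                learning_rate=1e-5)
Trainer(model=model, args=stage2_args, train_dataset=forgettable_dataset).train()
```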

Baselines
Recently, multiple methods have been proposed to learn more robust models by mitigating biases (Clark et al., 2019; Utama et al., 2020). These methods rely on a biased model designed with prior knowledge of the targeted spurious correlation; for HANS, the biased model is built on hand-crafted features such as the word-overlap between hypothesis and premise. To highlight the generality of our approach, we also add this biased model to the set of our shallow models for HANS and fine-tune on its forgettables (F HANS, with around 200k examples). For FEVER-Symmetric, Utama et al. (2020) consider an LSTM model that takes only the "claim" as input and ignores the "evidence". These baselines re-weight or confidence-regularize training examples using the biased models' performance.

MNLI and HANS
In Table 4, we present the results of our models and four recent baselines. The first line reports the performance of BERT on MNLI and HANS. The following lines report the results obtained by fine-tuning BERT on the sets of forgettable examples obtained using different shallow models. We also report the average performance between MNLI and HANS. The results confirm that tuning the model towards minority examples improves robustness with a slight drop in MNLI accuracy. Our best model is obtained by fine-tuning on F BOW, achieving a HANS mean accuracy of 70.5% (with a max of 71.3% over five seeds, which constitutes a +8.4% absolute improvement w.r.t. the initial BERT). To assess whether F BOW is indeed responsible for the improvement, we also fine-tune BERT on the same number of randomly chosen examples (BERT + Rand 63,390), which leads to a negligible improvement.
Fine-tuning on F BILSTM is comparable to fine-tuning on F BOW, which demonstrates that both the BoW and BiLSTM models learn similar spurious correlations. We also add results of fine-tuning BERT on its own forgettables for this task. Note that while it provides less improvement in robustness than F BILSTM or F BOW,² it still yields a significant 6.0% increase in performance. Finally, we also report fine-tuning results on F HANS, from the biased model designed for HANS, and observe that it performs well, with a smaller loss on MNLI and a smaller gain on HANS compared to F BOW and F BILSTM.
Compared to the other baselines, our approach achieves a comparable or better average accuracy over MNLI and HANS, despite its simplicity. In Fig. 3, we break down the results of our best performing model over the three heuristics HANS was built upon. Our method does not suffer as much as other baselines in the entailment class, and still provides a significant improvement for non-entailment. (More analysis is presented in the Appendix.)

QQP and PAWS
Here we report the results of our method applied to QQP, with PAWS as the out-of-distribution dataset. Results can be found in Table 5. We observe that our method improves out-of-distribution accuracy substantially. It is worth noting that the ground-truth labels in QQP contain noisy annotations (Iyer et al., 2017); a portion of the performance loss on QQP could be attributed to that.
Our method outperforms Reg-conf hans, while being simpler in terms of both the biased model and the training regime. We notice that Reg-conf hans also loses in-distribution performance.³

² To eliminate the effect of the forgettable set's size and focus on the type of model instead, we ran an experiment where we sample from F BOW the same number of examples as in F BERT. Fine-tuning on that smaller F BOW was still significantly better than on F BERT.

³ The authors report accuracy on each label individually and not the overall accuracy; we compute the latter based on their reported per-label numbers.

Figure 3: Performance of our BERT fine-tuned on the BiLSTM forgettables F BILSTM, and baselines, on the "entailment" and "non-entailment" categories for each heuristic HANS was designed to capture.

FEVER
In Table 6, we report the results of our method applied to the FEVER development and symmetric evaluation sets (see §2.3). Our approach again works well for both F BOW and F BILSTM, but here we also gain on the original dev set when compared to the initial BERT BASE results. The gains of our method are larger than those of the Reg-conf claim baseline, which uses a biased model tailored to FEVER-Symmetric.

Forgettables vs. high-loss examples

A natural alternative to forgettable examples is to select the examples with the highest loss at the end of training. The two criteria are obviously related, as the examples that are never learned rank highest w.r.t. the loss and are counted among the forgettables. However, Fig. 4 shows, for MNLI and HANS, that using forgettables produces better performance both in- and out-of-distribution. One additional issue with using the final loss to pick examples is the need to determine either a threshold value α on the loss (keep examples with a loss larger than α) or a number N of examples to retain. The optimal α or N might yield better performance, but finding them implies using the out-of-distribution set.
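As a rough illustration of this alternative (not the authors' code), selecting by final loss requires choosing α or N explicitly; `final_losses` is an assumed per-example array computed after training:

```python
import numpy as np

def high_loss_examples(final_losses, alpha=None, n=None):
    if alpha is not None:
        return np.where(final_losses > alpha)[0]  # loss-threshold variant
    return np.argsort(-final_losses)[:n]          # top-N variant
```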

Robustness of larger models
We examine the performance of our method when applied to other PLMs and to larger networks by training BERT LARGE and XLNet. Fig. 5 shows the MNLI and HANS performance of those networks. Firstly, XLNet is noticeably more robust than BERT, consistent with its superior in-distribution performance (Yang et al., 2019). Secondly, we observe that the large versions generalize on HANS significantly better than their base counterparts.

Table 9: Average accuracy over seeds on FEVER and FEVER-Symm-v1 for BERT and XLNet base and large models, before and after fine-tuning on F BOW.

Training on forgettables only In contrast with the vision results of Toneva et al. (2019), training models only on forgettable examples leads to poor performance on our datasets. These results suggest an intrinsic difficulty in F BERT that makes it hard for BERT BASE to generalize from it. However, as we showed previously, when starting from an already trained model, forgettables increase the out-of-distribution performance.
Calibration of models We look into the confidence of the entailment prediction when BERT BASE and BERT BASE + F BOW trained on MNLI are applied to HANS. In Fig. 6, we show that BERT BASE can discriminate HANS entailments from non-entailments, but only with a very large classification threshold. Fine-tuning on forgettables recalibrates the classification threshold on HANS and brings the optimal value close to 0.5.
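A small sketch of the threshold sweep behind Fig. 6 (an illustration, not the paper's evaluation script); `p_entail` and `is_entail` are assumed precomputed arrays of predicted entailment probabilities and gold binary labels:

```python
import numpy as np

def accuracy_vs_threshold(p_entail, is_entail, num_points=99):
    # Accuracy when predicting "entailment" whenever p_entail >= threshold.
    thresholds = np.linspace(0.01, 0.99, num_points)
    accuracies = np.array([((p_entail >= t) == is_entail).mean() for t in thresholds])
    return thresholds, accuracies
```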
Figure 6: HANS accuracy vs. the classification threshold used to predict entailment/non-entailment. The base BERT model is overconfident in the entailment class; after fine-tuning on forgettables, model calibration improves.

Other diagnostic evaluations Fine-tuning on the forgettable examples of simple biased models improves robustness on the three challenging benchmarks HANS, FEVER-Symmetric and PAWS. We additionally evaluate the trained models listed in Table 4 on Stress tests (Naik et al., 2018), adversarial NLI (Nie et al., 2019) and MNLI-matched-hard (Gururangan et al., 2018). For these test sets, we do not observe improvements when evaluating the model robustified using F BOW. We posit that specific biased models might be needed in some of these cases. As a validation, for MNLI-matched-hard, we design a BiLSTM model that only takes the hypothesis as input, and apply our method using the forgettables of that model to fine-tune BERT BASE. We observe an increase in performance from 76.5% to 78.0% (averaged across five seeds). These results suggest that the forgettable examples of simple biased models like BoW or BiLSTM capture the more informative heuristics, like word-overlap, well. However, for less informative heuristics like hypothesis-only features, a biased model designed for the specific heuristic is a better choice, since its forgettables likely violate it.

Related Work
A growing body of literature has recently focused on out-of-distribution generalization, showing that it is far from being attained, even in seemingly simple cases (Geirhos et al., 2019; Jia and Liang, 2017; Dasgupta et al., 2018). In particular, and in contrast with what Mitchell et al. (2018) recommend, NLP models do not seem to "embody the symmetries that allow the same meaning to be expressed within multiple grammatical structures". Supervised models seem to exhibit poor systematic generalization capabilities (Loula et al., 2018; Baan et al., 2019; Hupkes et al., 2018), thus seemingly lacking compositional behavior (Montague, 1970). While this might seem at odds with the common belief that high-level semantic representations of the input data are formed (Bengio et al., 2009b), the reliance on highly predictive but brittle features is not confined to NLU tasks: it is also a perceived shortcoming of image classification models (Geirhos et al., 2019). To test systematically whether machine learning models generalize beyond their training distribution, several challenging datasets have been introduced in NLP and other ML applications (Kalpathy-Cramer et al., 2015; Peng et al., 2019; Clark et al., 2019). Those test sets are built automatically from designed grammars (McCoy et al., 2019) and/or by human annotators (Zhang et al., 2019; Schuster et al., 2019).
Dataset re-sampling and weighting These techniques have been studied to address the class imbalance problem (Chawla et al., 2002) or covariate shift (Sugiyama et al., 2007). Katharopoulos and Fleuret (2018), Kim and Choi (2018), and Jiang et al. (2018) have shown that the concept can be quite successful in a variety of areas. Our robustifying method is related to this concept; however, our models are first trained using i.i.d. samples from the whole dataset and then fine-tuned on the more difficult cases, i.e., the minorities.
Spurious correlations in NLU datasets like MNLI or FEVER are the subject of many works. They include (i) the presence of specific words in the hypothesis or claim, for example, negation words like "not" being correlated with the contradiction label in entailment tasks (Naik et al., 2018; Gururangan et al., 2018), or bigrams like "did not" with the refutes label in fact verification (Schuster et al., 2019), and (ii) high word-overlap between the two sentences being correlated with the positive label (McCoy et al., 2019; Zhang et al., 2019). Feldman and Zhang (2020) show that when datasets are long-tailed, rare and atypical instances make up a significant fraction of the data distribution and memorizing them leads to better in-domain generalization. They find those rare and atypical examples using influence estimation. We instead study forgettable examples and their impact on out-of-distribution generalization. An interesting experiment would be to mine minority examples by influence estimation and compare them with forgettable examples.

Conclusion
We introduced a novel approach, based on example forgetting, to extract minority examples and build more robust models systematically. Via example forgetting, we built a set of minority examples on which a pre-trained model is fine-tuned. We evaluated our method on large-scale models such as BERT and XLNet and showed a consistent improvement in robustness on three challenging test sets. We also showed that the larger versions obtain higher out-of-distribution performance than the base ones but still benefit from our method.
A Details of biased models (BoW and BiLSTM)

Both models are siamese networks, with similar input representations and classification layers. For the input layer, we lower-case and tokenize the inputs into words and initialize their representations with GloVe, 300-dimensional pretrained embeddings (Pennington et al., 2014). For classification, from the premise and hypothesis vectors p and h, we build the concatenated vector s = [p, h, |p − h|, p ⊙ h] and pass it to a 2-layer feedforward network. To compute p or h, the BoW model max-pools the bag of word embeddings, while the BiLSTM model max-pools the top-layer hidden states of a 2-layer bidirectional LSTM. The hidden size of the LSTMs is set to 200. Overall, the BoW and BiLSTM models contain 560K and 2M parameters, respectively.
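Below is a minimal PyTorch sketch of the siamese BoW classifier described above; the feed-forward hidden size and the absence of padding/masking are simplifying assumptions, and the BiLSTM variant would replace the encoder with a 2-layer bidirectional LSTM whose top-layer states are max-pooled:

```python
import torch
import torch.nn as nn

class SiameseBoWClassifier(nn.Module):
    def __init__(self, embeddings: torch.Tensor, num_labels: int, hidden: int = 200):
        super().__init__()
        # GloVe 300-dimensional pretrained embeddings, fine-tuned during training.
        self.embed = nn.Embedding.from_pretrained(embeddings, freeze=False)
        dim = embeddings.size(1)
        # 2-layer feed-forward classification head over the combined features.
        self.classifier = nn.Sequential(
            nn.Linear(4 * dim, hidden), nn.ReLU(), nn.Linear(hidden, num_labels))

    def encode(self, token_ids: torch.Tensor) -> torch.Tensor:
        # BoW encoder: max-pool the bag of word embeddings.
        return self.embed(token_ids).max(dim=1).values

    def forward(self, premise_ids, hypothesis_ids):
        p, h = self.encode(premise_ids), self.encode(hypothesis_ids)
        # s = [p, h, |p - h|, p ⊙ h] (element-wise product)
        s = torch.cat([p, h, (p - h).abs(), p * h], dim=-1)
        return self.classifier(s)
```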

B Hyperparameters and training time
We use a learning rate of 5e-5 for MNLI and QQP when training the PLMs on the full training set, and a learning rate of 1e-5 when fine-tuning on forgettables. For FEVER, we use 2e-5 and 5e-6 for the full training and the fine-tuning on forgettables, respectively. On a machine with 4x Tesla P100 GPUs and a batch size of 256 per GPU, one epoch of training on the full training set takes around 4-6 minutes for the BoW and BiLSTM models on each of the three tasks.
For BERT BASE, with a batch size of 32 per GPU, one epoch of training on the full training set takes around 30 / 20 / 30 minutes (per task). The maximum input length after tokenization is set to 128 in all experiments.

In Table 10, we show the performance of our method on the MNLI dev set as a function of word-overlap, the main heuristic HANS was designed against. We split the evaluation set into High (> mean) and Low (< mean) word-overlap examples, where word-overlap is measured using the Jaccard index between hypothesis and premise. We see in particular that entailment pairs with high word-overlap suffer from the fine-tuning on forgettables, while non-entailment improves (we observe a similar trend for QQP; see App. C). This supports the observations in §3.2 that the initial model relied on the spurious correlation between word-overlap and entailment to classify pairs, and that fine-tuning on forgettable examples increases the performance on minorities.

C Forgettables and word-overlap in QQP
In Table 11, we show the performance of our method on the QQP evaluation set as a function of word-overlap, the main heuristic PAWS was designed against. We see in particular that paraphrase pairs with high word-overlap suffer from the fine-tuning, while non-paraphrase pairs improve. This supports the intuition that the initial model relied on word-overlap to classify pairs as paraphrases, and that forgettables help mitigate this phenomenon to some extent.