Empowering Language Understanding with Counterfactual Reasoning

Current language understanding methods have demonstrated an extraordinary ability to recognize patterns in text via machine learning. However, existing methods indiscriminately apply the recognized patterns in the testing phase. This is inherently different from humans, who use counterfactual thinking, e.g., to scrutinize hard testing samples. Inspired by this, we propose a Counterfactual Reasoning Model, which mimics counterfactual thinking by learning from a few counterfactual samples. In particular, we devise a generation module to generate representative counterfactual samples for each factual sample, and a retrospection module to retrospect the model prediction by comparing the counterfactual and factual samples. Extensive experiments on sentiment analysis (SA) and natural language inference (NLI) validate the effectiveness of our method.


Introduction
Language understanding (Ke et al., 2020) is a central theme of artificial intelligence (Chomsky, 2002), which empowers a wide spectrum of applications such as sentiment evaluation (Feldman, 2013) and commonsense inference (Bowman et al., 2015). Models are trained on labeled data to recognize the textual patterns closely correlated with different labels. Owing to the extraordinary representational capacity of deep neural networks, these models can recognize such patterns well and make predictions accordingly (Devlin et al., 2019). However, the cognitive ability of these data-driven models is still far from that of human beings because they lack counterfactual thinking (Pearl, 2019).
Counterfactual thinking is a high-level cognitive ability beyond pattern recognition (Pearl, 2019). In addition to observing the patterns within factual samples, counterfactual thinking calls for comparing the fact with imagined alternatives so as to make better decisions. For instance, given the factual sample "What do lawyers do when they die? Lie still.", an intuitive evaluation of its sentiment based on textual patterns will recognize "Lie still" as an objective, and therefore neutral, description of body posture. By noticing that "still" could have been intentionally postposed, we can imagine the counterfactual sample "What do lawyers do when they die? Still lie." and uncover the negative sarcastic pun, which yields a more accurate judgment of the sentiment.
Recent work (Kaushik et al., 2019; Zeng et al., 2020) shows that incorporating counterfactual samples into model training improves generalization. However, these methods follow the standard machine learning paradigm, which uses the same procedure (e.g., a forward propagation) to make predictions in the testing phase; that is, a testing sample is classified according to its position relative to the model's decision boundary. This indiscriminate procedure focuses on the textual patterns occurring in the testing sample and treats all testing samples equally, which easily fails on hard samples (cf. Figure 1). In contrast, humans can recognize hard samples and ponder the decision with a rational system (Daniel, 2017), which imagines counterfactuals and adjusts the decision.
The key to bridging this gap lies in imitating the human ability of counterfactual thinking, i.e., learning a decision-making procedure for the testing phase. Such a procedure consists of: 1) constructing counterfactual samples for a target factual sample; 2) calling the trained language understanding model to make predictions for the counterfactual samples; and 3) comparing the counterfactual and factual samples to retrospect the model prediction. However, achieving this procedure is non-trivial for two reasons. 1) The space of counterfactual samples is huge, since any variant of the target factual sample can be a counterfactual sample; it is thus challenging to search for suitable counterfactual samples that facilitate decision making. 2) The mechanism by which humans retrospect a decision is still unclear, making it hard to imitate.
To this end, we propose a Counterfactual Reasoning Model (CRM), a two-phase procedure consisting of a generation module and a retrospection module. In particular, given a factual sample in the testing phase, the generation module constructs representative counterfactual samples by imagining what the content would be if the sample belonged to a given class. To imitate the unknown human retrospection mechanism, we build the retrospection module as a carefully designed deep neural network that separately compares the latent representations and the predictions of the factual and counterfactual samples. The proposed CRM forms a general paradigm that can be applied to most existing language understanding models without constraints on the format of the language understanding task. We select two language understanding tasks, SA and NLI, and test CRM on three representative models for each task. Extensive experiments on benchmark datasets validate the effectiveness of CRM, which achieves performance gains ranging from 5.1% to 15.6%.
The main contributions are as follows:
• We propose the Counterfactual Reasoning Model to equip language understanding models with counterfactual thinking.
• We devise a generation module and a retrospection module that are task- and model-agnostic.
• We conduct extensive experiments, which validate the rationality and effectiveness of the proposed method.

Pilot Study
Decisions are usually accompanied by confidence, a feeling of being right or wrong (Boldt et al., 2019). From the perspective of model confidence, we investigate the performance of language understanding models across different testing samples. We estimate the model confidence on a sample with the widely used Maximum Class Probability (MCP) (Corbière et al., 2019), i.e., the probability of the predicted class. A lower MCP value means lower confidence and thus a "hard" sample. According to the MCP value, we rank the testing samples in ascending order and split them into ten groups, i.e., confidence levels 1 to 10. Figure 1 shows the performance of representative models over samples at different confidence levels on the SA and NLI tasks (see Section 4.1 for model and dataset descriptions). From the figures, we observe a clear increasing trend of classification accuracy as the confidence level increases from 1 to 10 in all cases. In other words, these models fail to predict accurately for hard samples. It is thus essential to enhance standard inference with a more precise decision-making procedure.
Figure 1: Prediction performance of the language understanding models over testing samples at different confidence levels. (a) Sentiment analysis; (b) Natural language inference.
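The MCP-based grouping used in this pilot study can be sketched in plain Python. This is a minimal illustration under our own assumptions: the model exposes raw logits, MCP is the maximum softmax probability, and the helper names `mcp`, `confidence_levels`, and `accuracy_per_level` are hypothetical. Uneven group sizes are handled by integer division, which is our choice rather than a rule stated in the paper.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [v / total for v in exps]


def mcp(logits):
    """Maximum Class Probability: the probability assigned to the predicted class."""
    return max(softmax(logits))


def confidence_levels(all_logits, num_levels=10):
    """Rank samples by MCP in ascending order and split them into num_levels
    groups of (near-)equal size; level 1 holds the least confident samples."""
    n = len(all_logits)
    order = sorted(range(n), key=lambda i: mcp(all_logits[i]))
    levels = [0] * n
    for rank, idx in enumerate(order):
        levels[idx] = rank * num_levels // n + 1
    return levels


def accuracy_per_level(all_logits, labels, num_levels=10):
    """Classification accuracy within each confidence level."""
    levels = confidence_levels(all_logits, num_levels)
    correct = [0] * (num_levels + 1)
    count = [0] * (num_levels + 1)
    for logits, y, lvl in zip(all_logits, labels, levels):
        pred = max(range(len(logits)), key=lambda c: logits[c])
        count[lvl] += 1
        correct[lvl] += int(pred == y)
    return {lvl: correct[lvl] / count[lvl]
            for lvl in range(1, num_levels + 1) if count[lvl]}
```

Plotting `accuracy_per_level` for a trained model gives curves of the kind summarized in Figure 1.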

Methodology
In this section, we first formulate the task of learning a decision making procedure for the testing phase (Section 3.1), followed by introducing the proposed CRM (Section 3.2) and the paradigm of building language understanding solutions with CRM (Section 3.3).

Problem Formulation
As discussed in previous work (Li et al., …, 2019), language understanding tasks can be abstracted as a classification problem where the input is a text and the target is to make a decision across a set of candidates of interest. We follow the problem setting that considers counterfactual samples (Kaushik et al., 2019; Liang et al., 2020), where the training data are twofold: 1) factual samples T = {(x, y)}, where y ∈ [1, C] denotes the class, i.e., the target decision for the text, and x ∈ R^D is the latent representation of the text, which encodes its textual content¹; 2) counterfactual samples T* = {(x*_c, c)}, where x*_c is a counterfactual sample in class c (c ≠ y) corresponding to the factual sample (x, y)². We assume that a classification model (e.g., BERT (Devlin et al., 2019)) has been trained over the labeled data. Formally,

$$\bar{\theta} = \arg\min_{\theta} \sum_{(x, y) \in \mathcal{T}} l\big(f(x \mid \theta), y\big) + \alpha \|\theta\|,$$

where θ̄ denotes the learned parameters of the model f(·), l(·) is a classification loss such as cross-entropy (Kullback, 1997), and α is a hyperparameter that adjusts the regularization. The target is to build a decision-making procedure that performs counterfactual reasoning in the testing phase. Given a testing sample x, the core is a policy for generating counterfactual samples and retrospecting the decision, which is formulated as:

$$\hat{y} = h\big(x, \{x^*\} \mid \eta, \bar{\theta}\big), \quad \text{s.t. } \{x^*\} = g(x \mid \omega),$$

where ŷ ∈ R^C denotes the final prediction for the testing sample x, which is a distribution over the classes, and x* is one of the generated counterfactual samples for x. The generation module g(·), parameterized by ω, is expected to construct a set of representative counterfactual samples for the target factual sample. These samples provide signals for the retrospection module h(·), parameterized by η, to retrospect the prediction f(x|θ̄) given by the trained classification model. In particular, h(·) and g(·) are learned from the factual and counterfactual training samples, respectively.
Figure 2 illustrates the process of CRM where the arrows in grey color represent the standard inference of trained classification model, and arrows in red color represent the retrospection with consideration of counterfactual samples.

Retrospection Module
We devise the retrospection module with one key consideration: distilling signals for making the final decision by comparing both the latent representation and the prediction of each counterfactual sample with those of the factual sample. To achieve this, we devise three building blocks for retrospection, which successively perform representation comparison, prediction comparison, and fusion. In particular, the module first compares the representation of each counterfactual sample with that of the factual sample, then compares their predictions accordingly, and finally fuses the comparisons across the counterfactual samples.
¹ … encoder for brevity, since we focus on the decision making.
² Given a labeled factual sample, counterfactual samples can be constructed either manually (Kaushik et al., 2019) or automatically (Chen et al., 2020) by making minimal changes to x that swap its label from y to c.
Representation comparison. Given a pair of a counterfactual sample x* and a factual sample x, we believe that the signals meaningful for the final decision lie in the difference between the samples and in how this difference affects the classification. To distill such signals, we devise the representation comparison block as y_∆ = f(x − x*|θ̄), where y_∆ ∈ R^C denotes the prediction for the representation difference x − x* given by the trained classification model. We leverage the trained model to reveal how the content difference affects the classification, since the model is trained to capture the connection between textual patterns and classes. Note that we use a duplicate of the trained classification model for the representation comparison; that is, training the retrospection module does not affect the classification model.
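The representation comparison block can be illustrated with a toy model. This sketch assumes, purely for illustration, that the trained model f(·|θ̄) is a linear softmax classifier with weights `W` and bias `b`; the names `classifier` and `representation_comparison` are hypothetical. The weights are only read, never updated, mirroring the frozen duplicate described above.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of logits."""
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    return [v / total for v in exps]


def classifier(x, W, b):
    """Hypothetical stand-in for the trained model f(x | theta): softmax(W x + b).
    W (C x D) and b (length C) are only read here, never modified."""
    logits = [sum(w_d * x_d for w_d, x_d in zip(row, x)) + b_c
              for row, b_c in zip(W, b)]
    return softmax(logits)


def representation_comparison(x, x_cf, W, b):
    """y_delta = f(x - x* | theta): classify the representation difference."""
    diff = [a - c for a, c in zip(x, x_cf)]
    return classifier(diff, W, b)


# Usage: a 2-class toy model whose class axes are the two coordinates.
W, b = [[1.0, 0.0], [0.0, 1.0]], [0.0, 0.0]
y_delta = representation_comparison([2.0, 0.0], [0.0, 2.0], W, b)
```

Here the difference x − x* points towards class 1, so y_delta places most of its mass on that class.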

Prediction comparison.
To retrospect the prediction f(x|θ̄), we devise a prediction comparison block that compares the predictions of each counterfactual and factual sample pair and distills patterns from f(x|θ̄), f(x*|θ̄), and y_∆. Inspired by the success of convolutional neural networks (CNNs) in capturing local-region patterns, the block is devised as a CNN, formulated as:

$$y^* = \mathrm{CNN}\big(f(x \mid \bar{\theta}),\, f(x^* \mid \bar{\theta}),\, y_\Delta\big),$$

where y* denotes the retrospected prediction when comparing with x*. In particular, a stack layer first stacks the three predictions into a matrix, which serves as an "image" to facilitate "observing" patterns:

$$Y = \big[f(x \mid \bar{\theta}),\, f(x^* \mid \bar{\theta}),\, y_\Delta\big]^\top \in \mathbb{R}^{3 \times C}.$$

Y is then fed into a 1D convolution layer to capture the intra-class patterns across the predictions, which is formulated as:

$$H_{ij} = \sigma\big(Y_{:i}^\top F_{:j}\big), \quad i \in [1, C],\ j \in [1, K], \tag{3}$$

where F ∈ R^{3×K} denotes the filters in the convolution layer, and σ(·) is an activation function such as GELU (Hendrycks and Gimpel, 2016). Y_{:i} and F_{:j} represent the i-th column of Y (the three predictions for class i) and the j-th column of F, respectively. Each filter F_{:j} can learn a rule for conducting retrospection. For instance, the filter [1, −1, 0] deducts the prediction of the counterfactual sample from that of the factual sample. The output H ∈ R^{C×K} is then flattened into a vector and fed into a fully connected (FC) layer to capture the inter-class patterns. Formally,

$$y^* = W\, \mathrm{vec}(H) + b, \tag{4}$$

where W ∈ R^{C×CK} and b ∈ R^C are model parameters.
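The prediction comparison block can be sketched in plain Python as follows. It assumes that σ is the exact GELU and that vec(H) flattens H row by row; the function name and argument layout are illustrative choices, not the authors' code. The block returns unnormalized scores, and a softmax could be applied afterwards if a distribution is needed.

```python
import math


def gelu(v):
    """Exact GELU activation: v * Phi(v)."""
    return 0.5 * v * (1.0 + math.erf(v / math.sqrt(2.0)))


def prediction_comparison(p_fact, p_cf, y_delta, F, W, b):
    """Sketch of y* = CNN(f(x), f(x*), y_delta).

    F: 3 x K filters, W: C x (C*K) FC weights, b: length-C bias.
    Returns unnormalized scores for the C classes.
    """
    Y = [p_fact, p_cf, y_delta]                  # stack layer: 3 x C
    C, K = len(p_fact), len(F[0])
    # 1D convolution over the three predictions of each class i (Eq. 3).
    H = [[gelu(sum(Y[r][i] * F[r][j] for r in range(3))) for j in range(K)]
         for i in range(C)]                      # C x K
    h_flat = [H[i][j] for i in range(C) for j in range(K)]   # vec(H), row-major
    # Fully connected layer across classes (Eq. 4).
    return [sum(w * v for w, v in zip(W[c], h_flat)) + b[c] for c in range(C)]


# Usage: a single filter [1, -1, 0] computes f(x) - f(x*) for each class.
F = [[1.0], [-1.0], [0.0]]
y_star = prediction_comparison([0.7, 0.3], [0.2, 0.8], [0.5, 0.5],
                               F, [[1.0, 0.0], [0.0, 1.0]], [0.0, 0.0])
```

With the identity FC layer, y_star is simply GELU applied to the per-class difference between the factual and counterfactual predictions.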
Fusion. The target is to fuse the retrospected predictions {y*} into a final decision ŷ. Inspired by the success of pooling functions in reading out patterns, we devise this block as ŷ = pooling({y*}).
As the fusion is performed after the pairwise comparison, we term it late fusion.
Training. We update the parameters of the retrospection module by minimizing the classification loss over the factual training samples:

$$\hat{\eta} = \arg\min_{\eta} \sum_{(x, y) \in \mathcal{T}} l\big(h(x, \{x^*\} \mid \eta, \bar{\theta}), y\big) + \lambda \|\eta\|,$$

where {x*} are the counterfactual samples of x and λ denotes the hyper-parameter that adjusts the weight of the regularization term. Note that no existing research has uncovered the specific mechanism of retrospection in the human brain, i.e., the order of comparison and fusion is unclear. As such, we further devise two fusion strategies, middle fusion and early fusion, which perform fusion within the CNN (i.e., during comparison) and before the CNN, respectively.
• Middle fusion performs aggregation between the convolution layer and the FC layer. This fusion first calculates the latent comparison signals H for each pair of counterfactual and factual samples according to Equation 3. The aggregated signals pooling({H}) are then fed into the FC layer (Equation 4) to obtain the final decision y.
• Early fusion aggregates the counterfactual samples before performing comparison, i.e., x̄* = pooling({x*}). In this way, the retrospection module is formulated as:

$$\hat{y} = \mathrm{CNN}\big(f(x \mid \bar{\theta}),\, f(\bar{x}^* \mid \bar{\theta}),\, f(x - \bar{x}^* \mid \bar{\theta})\big).$$

For all three fusion methods, we can use either a regular pooling function without parameters or a parameterized pooling function (Ying et al., 2018) to enhance the expressiveness of the retrospection module. In our experiments, simple mean pooling achieves performance comparable to the parameterized one in most cases (cf. Table 3).
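The three fusion strategies differ only in where mean pooling is applied. The sketch below treats the trained model `f`, the convolution part `conv` (returning a flattened vec(H)), and the FC layer `fc` as caller-supplied placeholders. The toy components in the usage part are hypothetical and chosen only to make the difference visible: an identity "classifier", a ReLU comparison, and an affine FC layer.

```python
def mean_pool(vectors):
    """Element-wise mean over a list of equal-length vectors."""
    n = len(vectors)
    return [sum(col) / n for col in zip(*vectors)]


def diff(a, b):
    return [ai - bi for ai, bi in zip(a, b)]


def late_fusion(x, x_cfs, f, conv, fc):
    """Compare each (x, x*) pair through conv + FC, then pool the y* vectors."""
    outs = [fc(conv(f(x), f(x_cf), f(diff(x, x_cf)))) for x_cf in x_cfs]
    return mean_pool(outs)


def middle_fusion(x, x_cfs, f, conv, fc):
    """Pool the flattened comparison signals vec(H) before the FC layer."""
    hs = [conv(f(x), f(x_cf), f(diff(x, x_cf))) for x_cf in x_cfs]
    return fc(mean_pool(hs))


def early_fusion(x, x_cfs, f, conv, fc):
    """Pool the counterfactual representations first, then compare once."""
    x_bar = mean_pool(x_cfs)
    return fc(conv(f(x), f(x_bar), f(diff(x, x_bar))))


# Toy, hypothetical components.
def f(v):
    return v


def conv(p, q, d):
    return [max(0.0, pi - qi) for pi, qi in zip(p, q)]


def fc(h):
    return [2 * h[0] + 1, h[1] - h[0]]


x, x_cfs = [1.0, 0.0], [[0.0, 1.0], [2.0, -1.0]]
print(late_fusion(x, x_cfs, f, conv, fc))    # [2.0, 0.0]
print(middle_fusion(x, x_cfs, f, conv, fc))  # [2.0, 0.0]
print(early_fusion(x, x_cfs, f, conv, fc))   # [1.0, 0.0]
```

Because `fc` is affine, late and middle fusion coincide under mean pooling, while early fusion can differ once the comparison is nonlinear.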

Generation Module
The target is to construct counterfactual samples that are informative for retrospecting the decision on the target factual sample x. As the task involves making a decision among C candidate classes, we believe that the key to generating representative counterfactual samples lies in imagining "what the content would be if the sample belonged to class c", i.e., generating C counterfactual samples {x*_c}. With the C classes as targets, the search space of samples is also largely narrowed down. Toward this end, we devise the generation module with two main considerations: 1) decomposing the factual sample x to distill the content irrelevant to its label, u = d(x|ω); and 2) injecting class c into u to form the counterfactual sample x*_c.
Decomposition. To distill u, we need to recognize the connection between the content of the factual sample and each class. We thus account for class representations in the decomposition function.
To align the sample space of the generation module with the retrospection module h(·) and the classification model f(·), we extract the parameters of the prediction layer of the trained classification model as the class representations. In particular, we extract the mapping matrix W ∈ R^{C×D}, where the c-th row corresponds to class c. Note that we assume the prediction layer has the same dimensionality as the latent representation, which is a common setting in most cutting-edge language understanding models. The decomposition function is devised as a CNN to capture both the intra-dimension and inter-dimension connections between the factual sample and the classes.
• Stack layer. The stack layer stacks the factual sample, the class representations, and the element-wise product between the sample and each class, which is formulated as:

$$X = \big[x,\ W^\top,\ x \odot W^\top\big] \in \mathbb{R}^{D \times (2C+1)},$$

where x ⊙ W^T ∈ R^{D×C} sheds light on how closely each dimension of x connects to each class; a large absolute value indicates a closer connection.
• Convolution layer. This layer uses 1D horizontal filters to learn patterns for deducting class-relevant content from the factual sample, which is formulated as h = pooling(σ(X ∗ F_g)). F_g ∈ R^{(2C+1)×L} denotes the filters, where L is the total number of filters. The output h ∈ R^D is a hidden representation.
• FC layers. We use two FC layers to capture the inter-dimension connections. Formally, u = W_2 σ(W_1 h + b_1) + b_2, where W_1 ∈ R^{M×D}, W_2 ∈ R^{D×M}, b_1, and b_2 are model parameters, and M is a hyper-parameter that adjusts the complexity of the decomposition function. Note that more layers could be stacked to enhance the expressiveness of the function; we use two layers according to the universal approximation theorem (Hornik, 1991).
We learn the parameters of the decomposition function from the counterfactual training samples by optimizing an objective in which u*_c = d(x*_c|ω) and u = d(x|ω) are the decomposition results of the counterfactual sample x*_c and the corresponding factual sample x, and ũ_c = ½(x + x*_c) denotes the target value of the decomposition. The objective combines two terms, r(·) and l(·), which are the Euclidean distance (Dattorro, 2010) and the classification loss, respectively. By minimizing the two terms, we encourage the decomposition result 1) to be close to the target value ũ_c; and 2) to leave the classification unchanged when it is deducted from the original sample (e.g., x − u). γ is a hyper-parameter that balances the two terms.
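A forward pass of the decomposition function d(x|ω) can be sketched as stack, then convolution, then two FC layers. Several details here are our own assumptions: σ is taken to be GELU, pooling is a mean over the L filters, and the parameter names `F_g`, `W1`, `b1`, `W2`, `b2` are hypothetical. The sketch covers the forward computation only, not the training objective.

```python
import math


def gelu(v):
    """Exact GELU activation (assumed choice for sigma)."""
    return 0.5 * v * (1.0 + math.erf(v / math.sqrt(2.0)))


def matvec(M, v):
    return [sum(a * c for a, c in zip(row, v)) for row in M]


def stack_layer(x, W_cls):
    """X = [x, W^T, x * W^T]: row d is [x_d, W_1d..W_Cd, x_d*W_1d..x_d*W_Cd],
    so X has shape D x (2C + 1)."""
    C = len(W_cls)
    return [[x[d]] + [W_cls[c][d] for c in range(C)]
            + [x[d] * W_cls[c][d] for c in range(C)] for d in range(len(x))]


def conv_layer(X, F_g):
    """Apply each of the L filters (columns of F_g, length 2C + 1) to every row
    of X, then mean-pool over the filters, giving h of length D."""
    L = len(F_g[0])
    return [sum(gelu(sum(row[k] * F_g[k][l] for k in range(len(row))))
                for l in range(L)) / L
            for row in X]


def fc_layers(h, W1, b1, W2, b2):
    """u = W2 gelu(W1 h + b1) + b2, with W1: M x D and W2: D x M."""
    z = [gelu(v + bb) for v, bb in zip(matvec(W1, h), b1)]
    return [v + bb for v, bb in zip(matvec(W2, z), b2)]


def decompose(x, W_cls, F_g, W1, b1, W2, b2):
    """Sketch of u = d(x | omega): stack -> convolution -> two FC layers."""
    return fc_layers(conv_layer(stack_layer(x, W_cls), F_g), W1, b1, W2, b2)
```

Keeping u in R^D, the same space as x, is what allows ũ_c = ½(x + x*_c) to serve as a regression target.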
The rationale for setting ũ_c = ½(x + x*_c) as the class-irrelevant content of x and x*_c comes from the parallelogram law (Nash, 2003). Note that this pair of samples belongs to two different classes, y and c, with a decision boundary (a hyperplane) lying between them. Considering that each sample corresponds to a vector in the hidden space, we can decompose the vector into two components that are orthogonal and parallel to the decision boundary, i.e., x*_c = o*_c + p*_c and x = o + p. Since the two samples belong to different classes, their orthogonal components point in opposite directions, and their addition only retains the parallel components, which are irrelevant to judging between classes y and c³.
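The argument can be written out as a short derivation. This is a sketch under an idealized assumption that the text leaves implicit: the decision boundary passes through the origin, and the two orthogonal components have equal magnitude, so that o*_c = −o.

```latex
\begin{aligned}
\tilde{u}_c &= \tfrac{1}{2}\,(x + x^*_c)
             = \tfrac{1}{2}\,(o + o^*_c) + \tfrac{1}{2}\,(p + p^*_c) \\
            &= \tfrac{1}{2}\,(p + p^*_c) \qquad \text{if } o^*_c = -o .
\end{aligned}
```

When the orthogonal components do not cancel exactly, ũ_c retains a residual orthogonal part ½(o + o*_c), so the target is only approximately class-irrelevant.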
Injection. Accordingly, given a testing sample x, we can inject the orthogonal component towards class c via x*_c = 2d(x|ω_c) − x, which is the imagined content of the sample if it belonged to class c. In this way, for each testing sample, we conduct the injection over all the classes and construct C counterfactual samples {x*_c}, which are then used in the retrospection module⁴.
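Geometrically, injection is a reflection across the decision boundary when the decomposition is ideal. The sketch below assumes a linear boundary {v : w·v = 0} and replaces the learned d(x|ω_c) with an exact projection onto that hyperplane. Both `ideal_decompose` and `predicted_side` are hypothetical helpers for illustration, not the learned module.

```python
def dot(a, b):
    return sum(ai * bi for ai, bi in zip(a, b))


def ideal_decompose(x, w):
    """Hypothetical ideal d(x): keep only the component of x parallel to the
    decision boundary {v : w . v = 0} by removing the part along the normal w."""
    scale = dot(w, x) / dot(w, w)
    return [xi - scale * wi for xi, wi in zip(x, w)]


def inject(x, w):
    """x*_c = 2 d(x) - x: keeps the parallel part and flips the orthogonal part,
    i.e. reflects x across the boundary."""
    u = ideal_decompose(x, w)
    return [2 * ui - xi for ui, xi in zip(u, x)]


def predicted_side(x, w):
    """+1 for one class, -1 for the other, under the linear boundary w . v = 0."""
    return 1 if dot(w, x) > 0 else -1


w = [1.0, 0.0]
x = [2.0, 3.0]
x_cf = inject(x, w)
print(x_cf, predicted_side(x, w), predicted_side(x_cf, w))  # [-2.0, 3.0] 1 -1
```

The midpoint of x and x_cf equals ideal_decompose(x, w), which matches the target ũ_c = ½(x + x*_c) used to train the decomposition.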

Learning Paradigm with CRM
Existing work (Kaushik et al., 2019; Zeng et al., 2020) on language understanding typically follows the standard learning paradigm, i.e., training a classification model over labeled data. Applying the proposed CRM forms a new learning paradigm for constructing language understanding solutions. Algorithm 1 illustrates the procedure of the new paradigm.
Algorithm 1
…
5: … ; Classification model inference
6: for c = 1 → C do
7:     x*_c = 2 · g(x|ω_c) − x; Generation
8: end for
9: Calculate h(x, {x*_c}|η, θ̄); Retrospection
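The recoverable steps of Algorithm 1 can be sketched as a small driver function. The callables `f`, `g`, and `h` are hypothetical placeholders for the trained classifier, the per-class decomposition, and the retrospection module; their learned parameters (θ̄, ω_c, η) are assumed to be captured inside them.

```python
def crm_inference(x, num_classes, f, g, h):
    """Sketch of CRM inference following Algorithm 1.

    f: trained (frozen) classification model, g(x, c): decomposition for class c,
    h(x, x_cfs): retrospection module. All three are caller-supplied placeholders.
    """
    p_fact = f(x)                                   # classification model inference
    x_cfs = []
    for c in range(num_classes):                    # generation, one sample per class
        u = g(x, c)
        x_cfs.append([2 * ui - xi for ui, xi in zip(u, x)])
    return p_fact, h(x, x_cfs)                      # retrospection


# Toy usage with hypothetical components.
def toy_f(v):
    return sum(v)


def toy_g(v, c):
    return [float(c)] * len(v)


def toy_h(v, cfs):
    return cfs


p, cfs = crm_inference([1.0, 2.0], 2, toy_f, toy_g, toy_h)
```

Only the testing-phase procedure is shown; the classification model, decomposition, and retrospection module are assumed to have been trained beforehand.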

Experiments
We conduct experiments on two representative language understanding tasks, SA and NLI, to answer the following research questions:
• RQ1: To what extent does counterfactual reasoning improve language understanding?
• RQ2: How does the design of the retrospection module affect the proposed CRM?
• RQ3: How effective are the counterfactual samples generated by the proposed generation module?

Experiment Settings
Datasets. We adopt the same datasets as (Kaushik et al., 2019) for both tasks. The SA data are IMDb reviews labeled as either positive or negative. For each factual review, the dataset contains a manually constructed counterfactual sample, for which crowd workers were asked to edit the text to reverse the label without gratuitous changes. NLI is a three-way classification task that takes two sentences as input and aims to detect their relation among entailment, contradiction, and neutral. For each factual sample, four counterfactual samples are given, constructed by editing either the first or the second sentence toward target relations different from the label of the factual sample.
Classification models. Owing to the extraordinary representational capacity of language models, fine-tuning pre-trained language models has become the prevalent technique for solving language understanding tasks (Devlin et al., 2019). We select the widely used RoBERTa-base⁵ and RoBERTa-large⁶, considering the robustness of RoBERTa (Liu et al., 2019) and our limited computational resources. For SA, we also test the classical Multi-Layer Perceptron (MLP) (Teney et al., 2020) with tf-idf text features (Schütze et al., 2008) as inputs. For NLI, we further test RoBERTa-largenli⁷, which has been fine-tuned on the large-scale MultiNLI dataset (Williams et al., 2018).
Baselines. As the proposed CRM leverages counterfactual samples, we compare CRM with three representative methods using counterfactual samples in language understanding tasks: 1) +CF (Kaushik et al., 2019), which uses counterfactual samples as data augmentation for model training; 2) +GS (Teney et al., 2020), which compares the factual and counterfactual samples in model training through regularizing their gradients; and 3) +CL (Liang et al., 2020), which compares the factual and counterfactual samples through a contrastive loss. Moreover, we report the performance of the testing model under Normal Training, i.e., training over factual samples only.
Implementation. We implement the proposed CRM with PyTorch 1.7.0 based on Hugging Face Transformers⁸ and release it at: https://github.com/fulifeng/Counterfactual Reasoning Model. In all cases, we follow the setting of +CF for training the classification model, which is the standard fine-tuning of (Liu et al., 2019). We then use Adam (Kingma and Ba, 2014) with a learning rate of 0.001 to optimize the retrospection module and the generation module. For the retrospection module, we set the number of filters in the convolution layer K to 10 and the weight for regularization λ to 0. For the generation module, we set the number of convolution filters to 10, the size of the hidden layer M to 256, and the weight γ balancing the Euclidean distance and the classification loss to 15. We report the average classification accuracy over 5 different runs. For each run, we train the model for 20 epochs and select the model with the best performance on the validation set.

Performance Comparison (RQ1)
We first use the handcrafted counterfactual samples to demonstrate the effectiveness of counterfactual reasoning in the inference stage of language understanding models, which can be seen as using a gold-standard generation module to provide counterfactual samples for the retrospection module. Note that we do not use the labels of the counterfactual samples in the testing set. Table 1 shows the performance of the compared methods on the two tasks. From the table, we observe that:
• +CRM largely outperforms all the baseline methods in all cases. Compared to +CF, i.e., the same classification model without CRM in the testing phase, +CRM achieves a relative performance improvement of up to 15.6%. The performance gain is attributed to the retrospection module, which justifies the rationality and effectiveness of incorporating counterfactual thinking into the inference stage of language understanding models. In other words, by comparing the factual sample with its counterfactual samples, the retrospection module indeed makes more accurate decisions.
• … huge gap of representational capacity between MLP and RoBERTa-large.
• The performance of the baseline methods is comparable in most cases, i.e., incorporating counterfactual samples into model training does not necessarily improve the testing performance on factual samples. This result is consistent with (Kaushik et al., 2019) and is reasonable, since these methods are devised to enhance generalization, especially to out-of-distribution testing samples, which can sacrifice performance on normal testing samples. Besides, the result indicates that training with counterfactual samples is insufficient for achieving counterfactual thinking, which supports enhancing the inference paradigm with a decision-making procedure.
Performance on hard samples. Furthermore, we investigate whether the proposed CRM facilitates dealing with hard samples. Recall that we split the testing samples into 10 groups according to the confidence of the classification model, i.e., +CF (cf. Section 2). We perform a group-wise comparison between +CF and +CRM. Figure 3 shows the performance of all the classification models with +CF and +CRM. From the figures, we observe that the performance of +CRM is stable across different confidence levels, whereas the performance of the classification model shows a clear decreasing trend as the confidence level decreases from 10 to 1.
CRM vs. implicit modeling. According to the universal approximation theorem (Hornik, 1991), CRM can also be approximated by a deep neural network. We thus investigate whether counterfactual thinking can be learned implicitly. In particular, we evaluate a model that takes both the factual sample and its counterfactual samples as inputs to make a prediction for the factual one.
Table 2 shows the performance, from which we have the following observations: 1) The implicit modeling performs much worse than the proposed CRM in most cases, which justifies the effectiveness of the retrospection module and the rationality of modeling the comparison explicitly. 2) On the NLI task, RoBERTa-base+CRM outperforms RoBERTa-large (implicit), which means that the superior performance of CRM does not come from the additional model parameters introduced by the retrospection module, but from the explicit comparison between factual and counterfactual samples.

In-depth Analysis
Effects of retrospection module design (RQ2). Recall that the order of comparison and fusion in the human retrospection mechanism is still unclear. We investigate how the fusion strategies influence the effectiveness of the proposed CRM. Table 3 shows the performance of CRM based on early fusion (EF), late fusion (LF), and middle fusion (MF) on the NLI task. We omit the comparison on the SA task, since that dataset has only one counterfactual sample per factual sample. For both EF and LF, we use mean pooling as the pooling function. For MF, we use a pooling function equipped with self-attention (Vaswani et al., 2017). The reasons for this setting are twofold: 1) using mean pooling would make LF and MF equivalent, since the FC layer in the retrospection module is a linear mapping (LF performs pooling after the FC layer, while MF pools just before it); and 2) the comparison between LF and MF can thus shed light on whether a parameterized pooling function benefits the retrospection.
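The equivalence in reason 1) follows from the linearity of the FC layer. Writing H_n for the comparison signals of the n-th of N counterfactual samples, with W and b as in the FC layer of the retrospection module, mean pooling commutes with the affine map:

```latex
\underbrace{\frac{1}{N}\sum_{n=1}^{N}\big(W\,\mathrm{vec}(H_n) + b\big)}_{\text{LF with mean pooling}}
= \underbrace{W\Big(\frac{1}{N}\sum_{n=1}^{N}\mathrm{vec}(H_n)\Big) + b}_{\text{MF with mean pooling}}
```

A parameterized pooling function, such as attention weights that depend on H_n, breaks this identity. This is why MF is paired with self-attention pooling here.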
From the table, we observe that, in most cases, CRM based on different fusion strategies achieves comparable performance. This indicates that the retrospection is insensitive to the order of fusion and comparison between counterfactual and factual samples. Considering that MF with mean pooling is equivalent to LF, the benefit of a parameterized pooling function appears limited; in particular, MF outperforms LF on only one of the three testing models.
Effects of generation module (RQ3). We then investigate whether the proposed generation module constructs useful counterfactual samples for retrospection. We train and test the retrospection module (using EF) with the generated samples on RoBERTa-large on the SA task; we omit the other settings to save computational resources. In this setting, the model achieves an accuracy of 94.5, which is better than +CF (93.4) but worse than +CRM with manually constructed counterfactual samples (98.2) (cf. Table 1). The result indicates that the generated samples indeed facilitate the retrospection, while the generation quality can be further improved. Moreover, on the testing samples at confidence level 1, using the generated samples achieves an accuracy of 81.3, which is much better than +CF (70.8) (cf. Figure 3). The generated samples thus benefit decision making on hard testing samples.
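The accuracies above can be put on the same relative scale as the RI column of Table 3. The sketch assumes that RI is the standard relative improvement, (acc − acc_CF) / acc_CF × 100; the section does not spell out the formula. The function name is hypothetical.

```python
def relative_improvement(acc_new, acc_base):
    """Relative improvement in percent: (acc_new - acc_base) / acc_base * 100."""
    return (acc_new - acc_base) / acc_base * 100.0


# Accuracies reported above (RoBERTa-large, SA task).
print(round(relative_improvement(94.5, 93.4), 2))  # generated samples vs. +CF: 1.18
print(round(relative_improvement(98.2, 93.4), 2))  # manual samples vs. +CF: 5.14
print(round(relative_improvement(81.3, 70.8), 2))  # confidence level 1, generated vs. +CF: 14.83
```

On this scale, the gain from generated samples is small overall but large on the hardest confidence group.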

Related Work
Counterfactual sample. Constructing counterfactual samples has become an emerging data augmentation technique in natural language processing and has been used in a wide spectrum of language understanding tasks, including SA (Kaushik et al., 2019; …), NLI (Kaushik et al., 2019), named entity recognition (Zeng et al., 2020), question answering (Chen et al., 2020), dialogue systems (…), and vision-language navigation (Fu et al., 2020). Beyond data augmentation under the standard supervised learning paradigm, a line of research explores incorporating counterfactual samples into other learning paradigms such as adversarial training (…; Fu et al., 2020; Teney et al., 2020) and contrastive learning (Liang et al., 2020). This work lies in an orthogonal direction that incorporates counterfactual samples into the decision-making procedure of model inference.
Counterfactual inference. A line of research attempts to equip deep neural networks with counterfactual thinking by incorporating counterfactual inference (Yue et al., 2021; Wang et al., 2021; Niu et al., 2021; Tang et al., 2020; Feng et al., 2021). These methods perform counterfactual inference over the model predictions according to a pre-defined causal graph. Because they require a causal graph, such methods are hard to generalize to different tasks. Our method does not suffer from this limitation, since it works on counterfactual samples, which can be generated without a comprehensive causal graph.
Hard sample. A wide spectrum of machine learning techniques relate to dealing with hard samples in language understanding. For instance, adversarial training (Khashabi et al., 2020) enhances model robustness against perturbations and attacks, which are hard samples for normally trained models. Debiased training (Tu et al., 2020; Utama et al., 2020) eliminates spurious correlations or biases in the training data to enhance generalization and handle out-of-distribution samples. Beyond the training phase, a few inference techniques might improve model performance on hard samples, including posterior regularization (Srivastava et al., 2018) and causal inference (Yu et al., 2020; Niu et al., 2021). However, both techniques require domain knowledge, such as a prior or a causal graph tailored to specific applications. In contrast, this work provides a general paradigm that can be used for most language understanding tasks.
Table 3: Performance of the proposed CRM based on early fusion (EF), late fusion (LF), or middle fusion (MF) on the NLI task. RI represents the relative performance improvement over the +CF method.

Conclusion
In this work, we pointed out a limitation of the standard inference procedure of existing language understanding models. We proposed a Counterfactual Reasoning Model, which empowers the trained model with a high-level cognitive ability: counterfactual thinking. By applying the proposed CRM, we formed a new paradigm for building language understanding solutions. We conducted extensive experiments, which validate the effectiveness of our proposal, especially in dealing with hard samples.
This work opens up a new research direction on the decision-making procedure in the testing phase. In the future, we will explore sequential decision procedures to relax the constraint on the number of constructed counterfactual samples. In addition, we will investigate generation modules for language understanding based on unsupervised generative techniques (Sauer and Geiger, 2021).