Disentangling Representations of Text by Masking Transformers

Representations from large pretrained models such as BERT encode a range of features into monolithic vectors, affording strong predictive accuracy across a multitude of downstream tasks. In this paper we explore whether it is possible to learn disentangled representations by identifying existing subnetworks within pretrained models that encode distinct, complementary aspect representations. Concretely, we learn binary masks over transformer weights or hidden units to uncover subsets of features that correlate with a specific factor of variation; this eliminates the need to train a disentangled model from scratch for a particular task. We evaluate this method with respect to its ability to disentangle representations of sentiment from genre in movie reviews, "toxicity" from dialect in Tweets, and syntax from semantics. By combining masking with magnitude pruning we find that we can identify sparse subnetworks within BERT that strongly encode particular aspects (e.g., toxicity) while only weakly encoding others (e.g., race). Moreover, despite only learning masks, we find that disentanglement-via-masking performs as well as -- and often better than -- previously proposed methods based on variational autoencoders and adversarial training.


Introduction and Motivation
Large pretrained models such as ELMo (Peters et al., 2018), BERT (Devlin et al., 2019), and XL-Net (Yang et al., 2019) have come to dominate modern NLP. Such models rely on self-supervision over large datasets to learn general-purpose representations of text that achieve strong predictive performance across a spectrum of downstream tasks (Liu et al., 2019). A downside of such learned representations is that it is not obvious what information they encode, which hinders model robustness and interpretability. The opacity of embeddings produced by models such as BERT has motivated NLP research on designing probing tasks as a means of uncovering the properties of input texts that are encoded in learned representations (Rogers et al., 2020;Linzen et al., 2019;Tenney et al., 2019).
In this paper we investigate whether we can uncover disentangled representations from pretrained models. That is, rather than mapping inputs onto a single vector that captures arbitrary combinations of features, our aim is to extract a representation that factorizes into distinct, complementary properties of inputs. Explicitly factorizing representations aids interpretability, in the sense that it becomes more straightforward to determine which factors of variation inform predictions in downstream tasks.
A general motivation for learning disentangled representations is to minimize, or at least expose, model reliance on spurious correlations, i.e., relationships between (potentially sensitive) attributes and labels that exist in the training data but which are not causally linked (Kaushik et al., 2020). This is particularly important for large pretrained models like BERT, as we do not know what the representations produced by such models encode. Here, learning disentangled representations may facilitate increased robustness under distributional shifts by capturing a notion of invariance: If syntactic changes do not affect the representation of semantic features (and vice versa) then we can hope to learn models that are less sensitive to any incidental correlations between these factors.
As one example that we explore in this paper, consider the task of identifying Tweets that contain hate speech (Founta et al., 2018). Recent work shows that models trained over Tweets annotated on a toxicity scale exhibit a racial bias: They have a tendency to over-predict that Tweets written by users who self-identify as Black are "toxic", owing to the use of African American Vernacular English (AAVE; Sap et al. 2019). In principle, disentangled representations would allow us to isolate relevant signal from irrelevant or spurious factors (such as, in this case, the particular English dialect used), which might in turn reveal and allow us to mitigate unwanted system biases, and increase robustness.
To date, most research on disentangled representations has focused on applications in computer vision (Locatello et al., 2019b;Kulkarni et al., 2015;Chen et al., 2016;Higgins et al., 2017), where there exist comparatively clear independent factors of variation such as size, position, and orientation, which have physical grounding and can be formalized in terms of actions of symmetry subgroups (Higgins et al., 2018). A challenge in learning disentangled representations of text is that it is less clear which factors of variation should admit invariance. Still, we may hope to disentangle particular properties for certain applications -e.g., protected demographic information (Elazar and Goldberg, 2018) -and there are general properties of language that we might hope to disentangle, e.g., syntax and semantics.

Methods
We are interested in learning a disentangled representation that maps inputs x (text) onto vectors z (a) and z (b) that encode two distinct factors of variation. To do so, we will learn two sets of masks M (a) and M (b) that can be applied to either the weights or the intermediate representations in a pretrained model (in our case, BERT). We estimate only the mask parameters and do not finetune the weights of the pretrained model.
To learn M^(a) and M^(b), we assume access to triplets (x_0, x_1, x_2) in which x_0 and x_1 are similar with respect to aspect a but dissimilar with respect to aspect b, whereas x_0 and x_2 are similar with respect to aspect b but dissimilar with respect to aspect a. In some of our experiments (e.g., when disentangling sentiment from genre in movie reviews) we further assume that we have access to class labels y^(a) ∈ {0, 1} and y^(b) ∈ {0, 1} for aspects of interest. In such cases, we build triplets using these labels, defining (x_0, x_1, x_2) such that y_0^(a) = y_1^(a), y_0^(b) ≠ y_1^(b), y_0^(b) = y_2^(b), and y_0^(a) ≠ y_2^(a).

Figure 1 illustrates the two forms of masking that we consider in our approach (we depict only a single linear layer of the model). Here h are input activations, W are the weights in the pretrained model, and h′ = (h′^(a), h′^(b)) are output activations. We augment each layer of the original network with two (binary) masks M = (M^(a), M^(b)), applied in one of two ways:
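Given labeled data, the triplet construction just described can be sketched as follows. This is a hypothetical helper, not the paper's exact sampling code; `examples` is assumed to hold (x, y^(a), y^(b)) tuples with binary aspect labels.

```python
import random

def build_triplets(examples, n_triplets, seed=0):
    """examples: list of (x, y_a, y_b) with binary aspect labels.
    Returns triplets (x0, x1, x2) where x0/x1 share y_a but differ on y_b,
    and x0/x2 share y_b but differ on y_a."""
    rng = random.Random(seed)
    triplets = []
    for _ in range(n_triplets):
        x0 = rng.choice(examples)
        # x1: same label for aspect a, different label for aspect b
        pool1 = [e for e in examples if e[1] == x0[1] and e[2] != x0[2]]
        # x2: same label for aspect b, different label for aspect a
        pool2 = [e for e in examples if e[2] == x0[2] and e[1] != x0[1]]
        if pool1 and pool2:
            triplets.append((x0, rng.choice(pool1), rng.choice(pool2)))
    return triplets
```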

Masking Weights and Hidden Activations
1. Masking Weights Here masks M^(a) and M^(b) have the same shape as the weights W, and outputs are computed using the masked weight tensors:

h′^(a) = (M^(a) ⊙ W) h,   h′^(b) = (M^(b) ⊙ W) h   (1)
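A minimal NumPy sketch of this weight-masking operation for a single linear layer and a single example (the paper applies it inside BERT's transformer layers, one mask per aspect):

```python
import numpy as np

def masked_linear_weights(h, W, M):
    """Eq. (1)-style weight masking: h' = (M ⊙ W) h, with the binary mask M
    the same shape as W (out_dim x in_dim) and h the input activations."""
    return (M * W) @ h
```

Zeroing an entire row of M removes that output unit's dependence on all inputs, while zeroing individual entries removes single connections.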

2. Masking Hidden Activations Here masks M^(a) and M^(b) instead match the shape of the layer's output activations, and are applied after the linear transformation:

h′^(a) = M^(a) ⊙ (W h),   h′^(b) = M^(b) ⊙ (W h)
In both methods, we follow Zhao et al. (2020) and only mask the last several layers of BERT, leaving bottom layers unchanged. 2
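The alternative, masking hidden activations, can be sketched the same way; here the mask has one entry per output unit rather than one per weight, so it has far fewer parameters than a weight mask:

```python
import numpy as np

def masked_linear_activations(h, W, m):
    """Activation masking: h' = m ⊙ (W h), where the binary mask m has
    shape (out_dim,), i.e., one entry per hidden unit."""
    return m * (W @ h)
```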

Triplet Loss
To learn masks, we assume that we have access to supervision in the form of triplets, as introduced above. Passing (x_0, x_1, x_2) through our model yields two representations for each instance: z_i^(a) and z_i^(b) for i ∈ {0, 1, 2}. We define a margin loss over each aspect, encouraging z_0^(a) to be closer to z_1^(a) than to z_2^(a), and z_0^(b) to be closer to z_2^(b) than to z_1^(b):

L_trp = max(0, d(z_0^(a), z_1^(a)) − d(z_0^(a), z_2^(a)) + α) + max(0, d(z_0^(b), z_2^(b)) − d(z_0^(b), z_1^(b)) + α)

Here d is a distance between representations, and α is a hyperparameter specifying a margin for the loss, which we set to α = 2 in all experiments.
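A sketch of the margin loss for one aspect (Euclidean distance and this exact form are assumptions; the loss is applied once per aspect, with x_1 as the positive for aspect a and x_2 as the positive for aspect b):

```python
import numpy as np

def triplet_loss(z0, z_pos, z_neg, alpha=2.0):
    """Margin triplet loss for one aspect: pull anchor z0 toward z_pos and
    push it away from z_neg by at least the margin alpha."""
    d_pos = np.linalg.norm(z0 - z_pos)
    d_neg = np.linalg.norm(z0 - z_neg)
    return max(0.0, d_pos - d_neg + alpha)
```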

Supervised Loss
In some settings we may have access to more direct forms of supervision. For example, when learning representations for the genre and sentiment of a movie review, we have explicit class labels y^(a) and y^(b) for each aspect. To exploit such supervision when available, we add classification layers C^(a) and C^(b) and define a cross-entropy classification loss for each aspect:

L_cls = CE(C^(a)(z^(a)), y^(a)) + CE(C^(b)(z^(b)), y^(b))

Disentanglement Loss
To ensure that the two aspect representations are distinct, we encourage the masks to overlap as little as possible. To achieve this we add a term to the loss for each layer l ∈ L that penalizes positions at which both masks are active:

L_ovl = Σ_{l∈L} ||M_l^(a) ⊙ M_l^(b)||_1
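One plausible implementation of such an overlap penalty (the exact form, a summed elementwise product per layer, is an assumption; for binary masks it counts positions kept by both):

```python
import numpy as np

def overlap_loss(masks_a, masks_b):
    """Sum over layers of the elementwise product of the two masks.
    Zero iff the masks are disjoint at every layer."""
    return sum(float(np.sum(Ma * Mb)) for Ma, Mb in zip(masks_a, masks_b))
```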

Binarization and Gradient Estimation
The final loss of our model is

L = λ_trp · L_trp + λ_ovl · L_ovl (+ λ_cls · L_cls).   (8)

We parenthetically denote the classification loss, which we only include when labels are available. We minimize this loss to estimate M (and classifier parameters), keeping the pretrained BERT weights fixed. Because the loss is not differentiable with respect to a binary mask, we learn continuous masks M that are binarized during the forward pass by applying a threshold τ, a global hyperparameter. We then use a straight-through estimator (Hinton et al., 2012; Bengio et al., 2013) to approximate the derivative, which is to say that we evaluate the derivative of the loss with respect to the continuous mask M at the binarized values M = M*.

Table 1: Percentage of each class in the original dataset, and in the two subsets we sampled: one correlated training set and one uncorrelated test set. We train models on the former and test on the latter. This is meant to assess the robustness of models to shifts in spurious correlations that might exist in training data.
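The binarize-in-the-forward-pass, straight-through-in-the-backward-pass scheme can be sketched as follows (a simplified manual-gradient illustration; in practice this is implemented inside an autodiff framework):

```python
import numpy as np

def binarize(m_cont, tau=0.5):
    """Forward pass: threshold the continuous mask at tau to obtain M*."""
    return (m_cont > tau).astype(np.float64)

def ste_update(m_cont, grad_wrt_binary, lr=0.1):
    """Straight-through estimator: treat d(binarize)/d(m_cont) as the
    identity, so the continuous mask is updated directly with the
    gradient evaluated at the binarized values M*."""
    return m_cont - lr * grad_wrt_binary
```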

Experiments
We conduct a series of experiments to evaluate the degree to which the proposed masking strategy achieves disentanglement, as compared to existing methods for disentanglement in NLP. As a first illustrative example, we consider a corpus of movie reviews, in which sentiment is correlated with film genre (3.1). We treat this as a proxy for a spurious correlation, and evaluate the robustness of the models to shifts in conditional probabilities of one attribute (sentiment) given another (genre). We then consider a more consequential example: Hate speech classification on Twitter (3.2). Prior work (Sap et al., 2019) has shown that models exploit a spurious correlation between "toxicity" and African American Vernacular English (AAVE); we aim to explicitly disentangle these factors in service of fairness. We evaluate whether the model is able to achieve equalized odds, a commonly used fairness metric. Finally, following prior work, we investigate disentangling semantics from syntax (insofar as this is possible) in Section 3.3.

Disentangling Sentiment From Genre
Experimental Setup In this experiment we assume a setting in which each data point x has both a 'main' label y and a secondary (possibly sensitive) attribute z. We are interested in evaluating the degree to which explicitly disentangling representations corresponding to these may afford robustness to shifts in the conditional distribution of y given z. As a convenient, illustrative dataset with which to investigate this, we use a set of movie reviews from IMDB (Maas et al., 2011) in which each review has both a binary sentiment label and a genre label. We pick the two genres of movies that exhibit a strong correlation with review sentiment: Drama (reviews tend to be positive) and Horror (negative), excluding reviews corresponding to other genres and the (small) set of instances that belong to both genres. To investigate robustness to shifts in correlations between z and y we sampled two subsets from the training set such that in the first sentiment and genre are highly correlated, while in the second they are uncorrelated. We report the correlations between these variables in the two subsets in Table 1. We train models on the correlated subset, and then evaluate them on the uncorrelated set.
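An illustrative way to construct such correlated and uncorrelated subsets (a sketch assuming binary y and z; not the paper's exact sampling procedure):

```python
import random

def sample_subset(examples, p_match, n, seed=0):
    """Sample n examples from (x, y, z) tuples so that a fraction p_match
    have y == z: p_match near 1 yields a highly correlated subset,
    p_match = 0.5 an (approximately) uncorrelated one."""
    rng = random.Random(seed)
    match = [e for e in examples if e[1] == e[2]]
    mismatch = [e for e in examples if e[1] != e[2]]
    n_match = round(n * p_match)
    return rng.sample(match, n_match) + rng.sample(mismatch, n - n_match)
```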
We compare the proposed masking approaches to several baselines. Untuned is a dense classification layer on top of BERT representations (without finetuning). In the finetuned variant we omit masks and instead minimize the loss with respect to BERT weights. In the adversarial model we adopt 'adversarial debiasing': In addition to minimizing loss on the main task, we train an adversarial classifier to predict the non-target attribute, and the encoder is trained to mitigate the adversary's ability to do so. We implement this via gradient-reversal (Ganin and Lempitsky, 2015). We also compare to two variational autoencoder baselines: DRLST (John et al., 2019) is a VAE model with multi-task loss and adversarial loss; and DRLST-BERT is the same model, except we use BERT as the encoder in place of a GRU (Cho et al., 2014).

Table 2: Performance on the main task of sentiment analysis and genre information leakage. The DRLST baselines perform poorly on the main task; the proposed masking approaches achieve comparable results for sentiment, and expose less genre information.

Leakage of the Non-target Attribute
We evaluate the degree to which representations "leak" non-target information. Following Elazar and Goldberg (2018), we first train the model to predict the main task label on the correlated dataset. Then we fix the encoder and train a single-layer MLP on the uncorrelated dataset to probe the learned representations for the non-target attribute. Because this probe is trained and tested on only uncorrelated data, it cannot simply learn the main task and exploit the correlation. We report results for our proposed masking models and baselines in Table 2. We also report results with genre classification as the main task and sentiment as the protected attribute in the Appendix (Section A.1). The DRLST baselines generally underperform, which translates to low leakage numbers but also poor performance on the main task. Compared to the baselines, our masking variants perform comparably with respect to predicting the main task label, but do so with less leakage.
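The probing protocol can be sketched with a simple linear probe trained on frozen representations (we use logistic regression here for brevity; the paper uses a single-layer MLP):

```python
import numpy as np

def probe_leakage(Z_tr, y_tr, Z_te, y_te, lr=0.5, steps=500):
    """Train a logistic-regression probe on *frozen* representations to
    predict the non-target attribute; test accuracy serves as a leakage
    score (about 0.5 means little leakage for balanced binary labels)."""
    w = np.zeros(Z_tr.shape[1])
    b = 0.0
    for _ in range(steps):
        p = 1.0 / (1.0 + np.exp(-(Z_tr @ w + b)))
        g = p - y_tr  # gradient of the cross-entropy w.r.t. the logits
        w -= lr * (Z_tr.T @ g) / len(y_tr)
        b -= lr * g.mean()
    p_te = 1.0 / (1.0 + np.exp(-(Z_te @ w + b)))
    return float(((p_te > 0.5) == (y_te == 1)).mean())
```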
Worst Group Performance In addition to non-target attribute leakage, we measure how models perform on the main task for each subgroup: (Positive, Drama), (Positive, Horror), (Negative, Drama), and (Negative, Horror). Because the distribution of the four groups is unequal in the train set, we expect that models will perform better on attribute combinations that are over-represented in this set, and worse on those that are under-represented, suggesting that the model is implicitly exploiting the correlation between these attributes. We report both the average and worst performance on the four subgroups; the latter is a proxy for robustness when subgroup compositions shift between the train and the test set. Figure 2 plots the results. We observe that the masking variants realize similar average performance as the baselines, but consistently outperform these in terms of worst performance. This indicates that the proposed variants rely less on the correlation between the two attributes when predicting the main label.

Figure 3: In the upper row we expect points of the same color to be clustered together, but not points with the same marker shapes; in the lower row we expect points of the same marker shapes to be clustered together, but not those of the same colors.
Qualitative Evaluation In Figure 3 we plot t-SNE visualizations (Maaten and Hinton, 2008) of the representations induced by different models. If the representations are disentangled as desired, instances with different sentiment will be well separated, while those belonging to different genres within each sentiment will not be separated.
Similarly, for genre representations, instances of the same genre should co-locate, but clusters should not reflect sentiment. No method perfectly realizes these criteria, but the proposed masking approaches achieve better results than do the two baselines. For instance, in the sentiment embeddings from the adversarial and finetuned models, instances that have negative sentiment but different genres (• and ×) are separated, indicating that these sentiment representations still carry genre information.

Disentangling Toxicity from Dialect
Experimental Setup In this experiment we evaluate models on a more consequential task: Detecting hate speech in Tweets (Founta et al., 2018).
Prior work (Sap et al., 2019) has shown that existing hate speech datasets exhibit a correlation between African American Vernacular English (AAVE) and toxicity ratings, and that models trained on such datasets propagate these biases. This results in Tweets by Black individuals being more likely to be predicted as "toxic". Factorizing representations of Tweets into dialect and toxicity subvectors could ameliorate this problem.
We use (Founta et al., 2018) as a dataset for this task. This comprises 100k Tweets, each with a label indicating whether the Tweet is considered toxic, and self-reported information about the author. We focus on the self-reported race information. Specifically, we subset the data to include only users who self-reported as being either white or Black. The idea is that Tweets from Black individuals will sometimes use AAVE, which in turn could be spuriously associated with 'toxicity'. Similar to the above experiment, we sampled two subsets of the data such that in the first the (annotated) toxicity and self-reported race are highly correlated, while in the second they are uncorrelated (see Table 1). We train models on the correlated subset, and evaluate them on the uncorrelated set. This setup is intended to measure the extent to which models are prone to exploiting (spurious) correlations, and whether and which disentanglement methods render models robust to these.

Leakage of Race Information
We evaluate the degree to which representations of Tweets "leak"  Table 3: Performance on the main task of toxicity prediction, and leakage of race information. Compared to the baselines, the proposed approaches achieve performance competitive with or better than baselines, while minimizing leakage of the protected attribute.
information about the (self-reported) race of their authors using the same method as above, and report results in Table 3. We observe that the proposed masking variants perform comparably to baselines with respect to predicting the toxicity label, but leak considerably less information pertaining to the sensitive attribute (race).
Fairness Implications In addition to the degree to which representations encode race information, we are interested in how the model performs on instances comprising (self-identified) Black and white individuals, respectively. More specifically, we can measure the True Positive Rate (TPR) and the True Negative Rate (TNR) on these subgroups, which in turn inform equalized odds, a standard metric used in the fairness literature. We report the TPR and TNR of each model achieved over white and Black individuals, respectively, as well as the difference across the two groups in Figure 4. We observe that the proposed model variants achieve a smaller TPR and TNR gap across the two races (see rightmost subplots), indicating that performance is more equitable across the groups, compared to baselines.
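Computing the per-group rates and the gaps that inform equalized odds is straightforward; a minimal sketch:

```python
import numpy as np

def tpr_tnr_gaps(y_true, y_pred, group):
    """Per-group True Positive / True Negative Rates and the absolute gaps
    between two groups (0 and 1); smaller gaps mean the classifier is
    closer to satisfying equalized odds."""
    def rates(sel):
        yt, yp = y_true[sel], y_pred[sel]
        tpr = float((yp[yt == 1] == 1).mean())
        tnr = float((yp[yt == 0] == 0).mean())
        return tpr, tnr
    tpr0, tnr0 = rates(group == 0)
    tpr1, tnr1 = rates(group == 1)
    return abs(tpr0 - tpr1), abs(tnr0 - tnr1)
```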

Disentangling Semantics from Syntax
Experimental Setup As a final experiment, we follow prior work in attempting to disentangle semantic from syntactic information encoded in learned (BERT) representations of text. Because our approach exploits a triplet loss, we first construct triplets (x_0, x_1, x_2) such that x_0 and x_1 are similar semantically but differ in syntax, while x_0 and x_2 are syntactically similar but encode different semantic information. We follow prior work (Ravfogel et al., 2020) in deriving these triplets. Specifically, we obtain x_0, x_1 from the ParaNMT-50M (Wieting and Gimpel, 2018) dataset. Here x_1 is obtained by applying backtranslation to x_0, i.e., by translating x_0 from English to Czech and then back into English. To derive x_2 we keep all function words (from a list introduced in Ravfogel et al. 2020) in x_0, and replace content words by masking each in turn, running the resultant input forward through BERT, and randomly selecting one of the top predictions (that differs from the original word) as a replacement.
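The function-word-preserving substitution used to build x_2 can be sketched as follows, with the BERT fill-in-the-blank step abstracted into an injected `predict` callback (a hypothetical interface, not the paper's implementation):

```python
def syntactic_variant(tokens, function_words, predict):
    """Build x2 from x0: keep function words, replace each content word via
    a fill-in-the-blank predictor (BERT in the paper). predict(toks, i)
    returns candidate replacements for position i, best first."""
    out = list(tokens)
    for i, tok in enumerate(tokens):
        if tok.lower() in function_words:
            continue  # function words anchor the syntactic structure
        cands = [w for w in predict(out, i) if w != tok]
        if cands:
            out[i] = cands[0]  # the paper samples among top predictions
    return out
```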
We compare our disentanglement-via-masking strategies against models that represent state-of-the-art approaches to disentangling syntax and semantics. In particular, we compare against VGVAE, though we implement this on top of BERT-base to allow fair comparison. Following prior work that has used triplet loss for disentanglement, we also compare against a model in which we finetune BERT using the same triplet loss that we use to train our model, but in which we update all model parameters (as opposed to only estimating mask parameters). To evaluate learned representations with respect to the semantic and syntactic information that they encode, we evaluate them on four tasks. Two of these depend predominantly on semantic information, while the other two depend more heavily on syntax. 3 For the semantics tasks we use: (i) A word content (WC) (Conneau et al., 2018) task in which we probe sentence representations to assess whether the corresponding sentence contains a particular word; and (ii) A semantic textual similarity (STS) benchmark (Nakov et al., 2013), which includes human-provided similarity scores between pairs of sentences. We evaluate the former in terms of accuracy; for the latter (a ranking task) we use Spearman correlation. To evaluate whether representations encode syntax, we use: (i) A task in which the aim is to predict the length of the longest path in a sentence's parse tree from its embedding (Depth) (Conneau et al., 2018); and (ii) A task in which we probe sentence representations for the type of their top constituents immediately below the S node (TopConst).
Figure 5 shows the signed differences between the performance achieved on semantics- and syntax-oriented tasks by BERT embeddings (we mean-pool over token embeddings) and the 'syntax' representations from the disentangled models considered (see the Appendix for the analogous plot for the 'semantics' representations in Figure A.2). Ideally, syntax embeddings would do well on the syntax-oriented tasks (Depth and TopCon) and poorly on the semantic tasks (WC and STS). With respect to syntax-oriented tasks, the proposed masking methods outperform BERT base representations, as well as the alternative disentangled models considered. These methods also considerably reduce performance on semantics-oriented tasks, as we would hope.

Figure 4: True Positive and True Negative Rates achieved on white and Black individuals, respectively, and the (signed) difference between these (rightmost subplots). The proposed masked variants (cross-hatched) are more equitable in performance, while other methods tend to over-predict Tweets written by Black individuals as "toxic".
We emphasize that this is achieved only via masking, and without modifying the underlying model weights.

Identifying Sparse Disentangled Sub-networks for Semantic and Syntax
We next assess if we are able to identify sparse disentangled subnetworks by combining the proposed masking approaches with magnitude pruning (Han et al., 2015a). Specifically, we use the loss function defined in Equation 8 to finetune BERT for k iterations, and prune the m weights with the smallest magnitudes after training. We then initialize masks to the sparse subnetworks identified in this way, and continue refining these masks via the training procedure proposed above. We compare the resultant sparse network to networks similarly pruned (but not masked). Specifically, for the latter we consider: Standard magnitude pruning applied to BERT, without additional tuning (Pruned + Untuned), and a method in which after magnitude pruning we resume finetuning of the subnetwork until convergence, using the aforementioned loss function (Pruned + Finetuned).
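Magnitude pruning itself reduces to a thresholding operation on |W|; a sketch of deriving the initial mask (the finetuning loop around it is omitted):

```python
import numpy as np

def magnitude_prune_mask(W, sparsity):
    """Binary mask that zeroes the `sparsity` fraction of smallest-magnitude
    entries of W; the result can serve to initialize the learned masks."""
    k = int(round(W.size * sparsity))
    if k == 0:
        return np.ones_like(W)
    # k-th smallest magnitude is the pruning threshold
    thresh = np.partition(np.abs(W).ravel(), k - 1)[k - 1]
    return (np.abs(W) > thresh).astype(W.dtype)
```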
We compare the performance achieved on the semantic and syntax tasks by the subnetworks identified using the above strategies at varying levels of sparsity, namely after pruning {0, 20%, 40%, 60%, 80%, 85%, 90%, 95%} of weights. 5 We report full results in Appendix Figure A.3, but here observe that combining the proposed masking strategy with magnitude pruning consistently yields representations of semantics that perform comparatively strongly on the semantics-oriented tasks (STS, WC), even at very high levels of sparsity; these semantics representations also perform comparatively poorly on the syntax-oriented tasks (Depth, TopCon), as one would hope. Similarly, syntax representations perform poorly on semantics-oriented tasks, and outperform alternatives on the syntax-oriented tasks. In sum, this experiment suggests that we are indeed able to identify sparse disentangled subnetworks via masking.

Related Work
Disentangled and structured representations of images. The term disentangled representations has been used to refer to a range of methods with differing aims. Much of the initial focus in this space was on learning representations of images, in which certain dimensions correspond to interpretable factors of variation (Kulkarni et al., 2015;Higgins et al., 2017;Chen et al., 2016). In the context of variational autoencoders (Kingma and Welling, 2014) this motivated work that evaluates to what extent such representations can recover a set of ground-truth factors of variation when learned without supervision (Eastwood and Williams, 2018;Kim and Mnih, 2018;Chen et al., 2018). Other work has investigated representations with the explicit motivation of fairness (Locatello et al., 2019a;Creager et al., 2019), which disentanglement may help to facilitate.
Disentangling representations in NLP. Compared to vision, there has been relatively little work on learning disentangled representations of text.
Much of the prior work on disentanglement for NLP that does exist has focused on using such representations to facilitate controlled generation, e.g., manipulating sentiment (Larsson et al., 2017). A related notion is that of style transfer, for example, separating style from content in language models (Shen et al., 2017; Mir et al., 2019). There has also been prior work on learning representations of particular aspects to facilitate domain adaptation (Zhang et al., 2017), and aspect-specific information retrieval (Jain et al., 2018). Esmaeili et al. (2019) focus on disentangling user and item representations for product reviews. Moradshahi et al. (2019) combine BERT with Tensor-Product Representations to improve its transferability across different tasks. Recent work has proposed learning distinct vectors coding for semantic and syntactic properties of text (Ravfogel et al., 2020); these serve as baseline models in our experiments.
Finally, while not explicitly framed in terms of disentanglement, efforts to 'de-bias' representations of text are related to our aims. Some of this work has used adversarial training to attempt to remove sensitive information (Elazar and Goldberg, 2018;Barrett et al., 2019).
Network pruning. A final thread of relevant work concerns selective pruning of neural networks. This has often been done in the interest of model compression (Han et al., 2015a,b). Recent intriguing work has considered pruning from a different perspective: Identifying small subnetworks, or winning 'lottery tickets' (Frankle and Carbin, 2019), that, trained in isolation with the right initialization, can match the performance of the original networks from which they were extracted. Very recent work has demonstrated that winning tickets exist within BERT (Chen et al., 2020).

Discussion
We have presented a novel perspective on learning disentangled representations for natural language processing in which we attempt to uncover existing subnetworks within pretrained transformers (e.g., BERT) that yield disentangled representations of text. We operationalized this intuition via a masking approach, in which we estimate only binary masks over weights or hidden states within BERT, leaving all other parameters unchanged. We demonstrated that -somewhat surprisingly -we are able to achieve a level of disentanglement that often exceeds existing approaches (e.g., a variational autoencoder on top of BERT), which have the benefit of finetuning all model parameters.
Our experiments demonstrate the potential benefits of this approach. In Section 3.1 we showed that disentanglement via masking can yield representations that are comparatively robust to shifts in correlations between (potentially sensitive) attributes and target labels. Aside from increasing robustness, finding sparse subnetworks that induce disentangled representations constitutes a new direction to pursue in service of providing at least one type of model interpretability for NLP. Finally, we note that sparse masking (which does not mutate the underlying transformer parameters) may offer efficiency advantages over alternative approaches.

Acknowledgements
This work was supported by the National Science Foundation (NSF), grant 1901117.

A.1 Additional IMDB Results
In this section we report additional IMDB results, with genre classification as the main task and sentiment as the protected attribute (complementing the sentiment-as-main-task results in Section 3.1).

A.2 Distribution of Learned Masks Across BERT Layers
Here we inspect the subnetworks (i.e., the weights or hidden activations that are not masked) uncovered by our model, which may provide insights regarding where pretrained (masked) language models encode different sorts of linguistic information. Figure A.1 shows the distributions of the two types of masks (weights and hidden activations, respectively) over the layers within BERT for the semantics/syntax tasks. We observe that the learned 'semantic' mask zeros out fewer elements at higher layers in the network, while the 'syntax' mask prefers to keep non-zero entries in lower layers. This suggests that semantic information may be captured mostly in higher layers of BERT, while syntactic information may be encoded in lower layers, consistent with observations in prior work (Tenney et al., 2019).

A.3 Semantic Representation Performance (vs. BERT)
We show the signed differences between the performance achieved on semantics- and syntax-oriented tasks by BERT embeddings (we mean-pool over token embeddings) and the 'semantic' representations from the disentangled models in Figure A.2.

Figure A.2: Differences between the performance achieved via BERT embeddings and the disentangled model variants considered on semantics-oriented (WC, STS) and syntax-oriented (Depth, TopCon) tasks. We plot this difference with respect to the semantic embeddings induced by the models.

Figure A.3: Model performance as a function of the degree of pruning, shown separately for the semantic representations (top row, 'Representation: Semantic') and the syntax representations (bottom row, 'Representation: Syntax') learned by the disentangling strategies considered, after pruning to varying levels of sparsity. The x-axis corresponds to subnetwork sparsity (percent of weights dropped), while the y-axes are performance measures: accuracy for all tasks except STS, where we report Pearson's correlation.

A.4 Model performance with iterative magnitude pruning
We report full results of combining our method with magnitude pruning to uncover sparse subnetworks in Figure A.3. We compare our method to several alternative pruning strategies: Standard magnitude pruning applied to BERT, without additional tuning (Pruned + Untuned), and a method in which after magnitude pruning we resume finetuning of the subnetwork for a fixed number of steps, using the aforementioned loss function (Pruned + Finetuned).

A.5 Additional Experiments: Perturbation Study of Hyper-parameters
We report model performance when masking different numbers of layers of BERT (Table A.2) and when choosing different values for α (Table A.3).
A.6 Model performance with varying degrees of correlation in the training set

We compare our model (Masking Weights) with two baselines (finetuned and adversarially trained BERT) under varying degrees of correlation in the training set. The task is sentiment classification, and we control the correlation between sentiment and genre across three settings: strong, moderate, and weak (if any) correlation. We report the results in Table A.4. Our model significantly outperforms the baselines when the correlation is strong, and the advantage begins to diminish as the correlation becomes weaker, as we would expect.