Counterfactuals to Control Latent Disentangled Text Representations for Style Transfer

Disentanglement of latent representations into content and style spaces has been a commonly employed method for unsupervised text style transfer. These techniques aim to learn the disentangled representations and tweak them to modify the style of a sentence. In this paper, we propose a counterfactual-based method to modify the latent representation, by posing a ‘what-if’ scenario. This simple and disciplined approach also enables a fine-grained control on the transfer strength. We conduct experiments with the proposed methodology on multiple attribute transfer tasks like Sentiment, Formality and Excitement to support our hypothesis.


Introduction
Counterfactual Reasoning (Bottou et al., 2013) is leveraged in structured data analysis and econometrics for generating alternatives and estimating alternate scenarios. Counterfactuals describe a causal situation of the form 'If X had (not) occurred, Y would have (not) occurred' (Molnar, 2019). In interpretable machine learning, counterfactuals have been used to explain predictions of individual instances across various types of datasets and tasks (Neal et al., 2018; Martens and Provost, 2014; Wachter et al., 2017). Laugel et al. (2018) and Neal et al. (2018) use counterfactuals for generating training data. Counterfactual reasoning also provides a unique ability to generate explanations and perform causal analysis on the latent space. However, this technique has not previously been explored in natural language generation tasks. Here, we plug the concept of counterfactuals into the text style transfer task, to enable the manipulation of latent spaces towards controlled transfer of style.

* Work done while the authors were at Adobe Research.
Existing works in text style transfer focus on transferring a specific target attribute. Unsupervised methods based on adversarial learning (Fu et al., 2018; Shen et al., 2017), back translation (Prabhumoye et al., 2018), and learning disentangled representations (John et al., 2019) have been popular in this domain. Other techniques delete style-specific words and conditionally generate sentences in the target style (Li et al., 2018; Sudhakar et al., 2019). However, all of them fail to provide control over the target style strength, i.e., a clever manipulation of the latent space is non-trivial.
Recent works on controlled text generation include Wang et al. (2019), who introduce a transformer-based model that modifies the gradient functions, leading to controlled generation in the output space. Jin et al. (2019) propose an unsupervised approach integrated during end-to-end model training. The drawback in all these efforts is the lack of a prefixed logic for controlling the latent space. Our proposed method of counterfactuals fills this gap and provides a logical method to control the latent spaces, enabling a smooth style transfer.
Our approach is based on the premise of disentangled representation spaces, inspired by John et al. (2019). Separating out the style and content representations introduces an opportunity to fine-tune, resulting in the ability to control the output sentences specific to style. We introduce a counterfactual reasoning module for controlling latent disentangled spaces for style transfer. Figure 1 shows an illustrative example of the variants generated through our approach. To the best of our knowledge, this is the first work leveraging such a concept for controlled text generation. Through extensive quantitative and qualitative experiments, across attributes and datasets, we conclude that the proposed approach is effective in providing control over the style strength, and that its best transfer performance is on par with existing baseline style transfer techniques.

Approach
Figure 2 illustrates our proposed approach, which incorporates counterfactual reasoning over latent disentangled representations for manipulating style in text. It consists of (1) a Variational Autoencoder (VAE) model to learn the disentangled style and content representations for different stylistic attributes, and (2) a Counterfactual Reasoning Module to control the latent representations for generating style variants.

Learning Disentangled Representations
We adopt the model described in John et al. (2019) for learning the disentangled content and style representations. Here, a VAE with an encoder-decoder is used to encode a sentence x into a latent distribution H = q_E(h|x), guided by the loss function:

J_VAE(θ_E, θ_D) = J_REC(θ_E, θ_D) + λ_kl KL(q_E(h|x) || p(h))

where θ_E and θ_D are the encoder and decoder parameters, respectively. The first term encourages reconstruction, while the second term regularizes the latent space to a prior distribution p(h) = N(0, 1). We experiment with some variations of this architecture, which are detailed in Section 3.
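As a concrete illustration of the regularization term, the KL divergence between a diagonal Gaussian posterior and the standard normal prior N(0, I) has a closed form. A minimal numpy sketch (the function name and shapes are ours, not from the paper):

```python
import numpy as np

def kl_to_standard_normal(mu, logvar):
    """KL( N(mu, diag(sigma^2)) || N(0, I) ) for a diagonal Gaussian posterior.
    mu, logvar: arrays of shape (latent_dim,), with logvar = log(sigma^2)."""
    return 0.5 * np.sum(mu ** 2 + np.exp(logvar) - 1.0 - logvar)
```

The penalty is zero exactly when the posterior equals the prior (mu = 0, sigma = 1) and grows as the encoder drifts away from N(0, I).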
Additionally, Multi-Task (J_mul(s), J_mul(c)) and Adversarial losses (J_adv(s), J_adv(c)) are imposed on the latent space h to disentangle the embeddings into content c and style s representations, i.e., h = [s; c], where [;] denotes concatenation. These four losses ensure that the style and content information are present in, and only in, their respective style (s) and content (c) embeddings.
Once we have the disentangled representations, our basic idea is to feed the generative model with the same content and a different style embedding to produce sentences of altered style. In John et al. (2019), the average style embedding of the target style is fed to the decoder. Intuitively, changing these style embeddings will produce different variants of target-style sentences, but a disciplined approach to generate smooth style variants of the sentence is missing. We propose counterfactual reasoning for this purpose.
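The swap itself is a simple operation on the disentangled latent vector. A sketch with the style and content dimensions from our implementation (16 and 128), where the decoder input for transfer replaces the style half of h = [s; c] with the average target-style embedding (the function name is our illustration):

```python
import numpy as np

STYLE_DIM, CONTENT_DIM = 16, 128  # dims used in our implementation

def transfer_latent(c_src, target_style_embs):
    """Build the decoder input h = [s; c]: average the style embeddings of
    target-style training sentences and concatenate the source content."""
    s_avg = target_style_embs.mean(axis=0)   # shape (STYLE_DIM,)
    return np.concatenate([s_avg, c_src])    # shape (STYLE_DIM + CONTENT_DIM,)
```

The counterfactual module below replaces this fixed average with an optimized style vector.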

Counterfactual Reasoning Module
Counterfactuals (CFs) are used for gradually changing the style representation along the target-style axis. A counterfactual explanation of an outcome Y takes the form 'if X had not occurred, Y would not have occurred'. We leverage this notion here. A Multi-Layer Perceptron (MLP) classifier is trained on the disentangled style latent representations learnt by the VAE, such that every instance of a style embedding s predicts a target style (T) of a sentence. Now, the aim is to find s' such that it is close to s in the latent space but leads to a different prediction, i.e., the target class. The CF generation loss is given by

L(s'|s) = λ (f_t(s') − p_t)^2 + L_1(s', s)

where t is the desired target style class for s', p_t is the probability with which we want to predict this target class (perfect transfer would mean p_t = 1), f_t is the model prediction on class t, and L_1 is the distance between s' and s. The first term in the loss guides towards finding an s' that changes the model prediction to the target class, and the use of the L_1 distance ensures that a minimum number of features are changed in order to change the prediction. λ is the weighting term. The resulting set of CFs is obtained by optimizing (Wachter et al., 2017) the following equation: arg min_{s'} max_λ L(s'|s), subject to |f_t(s') − p_t| ≤ ε (tolerance parameter).
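To make the optimization concrete, the following is a minimal numpy sketch of the Wachter-style search (in practice we use the Alibi library; the logistic classifier f_t, the λ schedule, and the learning rates here are illustrative assumptions):

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def generate_counterfactual(s, w, b, p_t, tol=0.1, steps=2000):
    """Minimize lam * (f_t(s') - p_t)^2 + ||s' - s||_1 by gradient descent,
    growing lam until the prediction is within tol of the target confidence.
    f_t is a logistic classifier with weights w, bias b (an assumption here)."""
    s_cf = s.copy()
    for lam in (1.0, 10.0, 100.0, 1000.0):
        lr = 1.0 / (10.0 + lam)          # smaller steps as lam grows
        for _ in range(steps):
            f = sigmoid(w @ s_cf + b)    # classifier prediction f_t(s')
            # gradient of lam*(f - p_t)^2 plus subgradient of the L1 term
            grad = lam * 2.0 * (f - p_t) * f * (1.0 - f) * w + np.sign(s_cf - s)
            s_cf -= lr * grad
        if abs(sigmoid(w @ s_cf + b) - p_t) <= tol:
            break                        # tolerance constraint satisfied
    return s_cf
```

Varying p_t over a grid with the same s yields the set of counterfactual style vectors that are decoded into the sentence variants.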
The CF generator is generalizable across different stylistic attributes. To generate multiple variants for a target style, CFs are generated by varying the probability of target-specific generation (or confidence), p_t. This results in different sentence variants with a similar target style but varied degrees of transfer strength. Finally, the disentangled representations enable finer control over the style dimensions with no risk of content loss during the counterfactual reasoning stage (as the content representations are retained).

Proposed models
The VAE model adapted from John et al. (2019), with RNN encoder-decoder blocks, is R-VAE. We experiment with a variation that replaces the RNNs with transformer blocks (T-VAE). T-VAE-CF uses counterfactuals for generating variants, while models with the -AVG suffix use the average style embedding of the target style to enable transfer. For T-VAE, we experimented with different loss combinations: -1, -2, -3, -4 refer to the inclusion of J_mul(s); J_mul(s) + J_adv(s); J_mul(s) + J_adv(s) + J_mul(c); and J_mul(s) + J_adv(s) + J_mul(c) + J_adv(c), respectively, along with J_VAE in the overall loss function.

Baselines
We compare our best transfer models (with p_t ≈ 1) against standard unsupervised style transfer approaches. CrossAligned (CA) (Fu et al., 2018) aligns the hidden representations of original and style-transferred sentences. T-D and T-DRG (Sudhakar et al., 2019) delete attribute-related words and conditionally generate words in the target style through a transformer architecture.

Implementation
The counterfactual module has a linear classifier with a sigmoid activation, taking an input dimension of 16 (s) and an output dimension of 2 (style label). It is trained with the Adam optimizer and a learning rate of 0.001 to minimize the categorical cross-entropy (CCE) loss. The transfer strength in the CF module, p_t, is varied from 0 to 1. Experiments with the following values (0.2, 0.3, 0.5, 0.5, 0.8, 0.9, 0.95, 1.0) are reported.

Datasets
We experiment with varied style attributes using 5 datasets. YELP is used for sentiment. Human gold standard references of these datasets from (

Results and Analysis
Transfer Control. Figure 3 shows the performance of CF variants across metrics for different styles. The CF-generated variants from T-VAE-CF (solid lines) are compared against the reference values which use average embeddings (T-VAE-AVG) for the target style (dotted lines). To recollect, the higher the CF transfer confidence (strength), the closer the generated variant is to the target attribute. Thus, the ideal performance is to have the highest accuracies for the highest CF confidence values (see Figure 3(a)). Note that CF strength = 1 alludes to perfect transfer. This is difficult to achieve, as a CF in the representation space may not be generated for such a strict target. Hence, the variants generated with a near-perfect transfer target (CF strength = 0.8, 0.9, 0.95) show the best performance across metrics. The low transfer accuracies for models with low CF confidence establish the ability of the model to stay near the source when the target strength is low. All models implemented with transfer control report improved performance w.r.t. BLEU scores, establishing the utility of the alternatives generator.

Tables 2 and 3 compare the baselines with the proposed models. Note that the evaluation metrics for text style transfer cannot be compared in isolation; there is always a trade-off between content preservation and transfer accuracy. Amongst the baselines, we observe that T-D and T-DRG report high content preservation with some loss in accuracy, but these models only cater to generating a single output sentence and there is no provision to generate the variants. Note that in most style dimensions, T-VAE-based models show the highest performance in transfer accuracy with good content preservation (CP), but a lower BLEU-S score. The lower BLEU-S scores indicate the ability of our model to generate variants that are not mere repetitions of the input samples. R-VAE models show impressive perplexity values.
For the POLITICAL dataset, the R-VAE baseline shows very high transfer accuracy but takes a tremendous hit in content preservation (BLEU), which is improved with the use of counterfactuals. Examples in Table 1 illustrate the gradual changes introduced by T-VAE-CF across different styles.
Human Evaluation: We conducted a crowdsourcing-based experiment (through Amazon Mechanical Turk) to understand both (A) how the baselines compare to the generated text, and (B) the interpretation of control as seen by human annotators. For the first experiment, the annotators were presented with sentences generated by our model, the baselines, and the ground truth to evaluate and rank. Specifically, they were asked to score each of the output sentences on a Likert scale of 1-5 across three aspects: transfer strength, content preservation and fluency. The key takeaways highlight that the sentences generated by our model are on par in terms of grammar and fluency and are better in terms of transfer control. Compared with text generated by the baselines, the text generated by our proposed models is preferred by humans 70% of the time (inter-annotator agreement 0.42).
For the second experiment, to evaluate the control, we presented the sentence variants generated through different CFs (by varying p_t) and asked the annotators to rank them from best to worst based on their transfer strength. On average, 60% of individuals could grade the gradual control as intended by the model. If we bucket the sentences into low (p_t < 0.4) and high (p_t > 0.7) groups, the annotators' preference for bucketing the output into the right confidence goes up to 73% on average (68% for low, and 81% for high), confirming our hypothesis towards using CFs for controlled generation.

Conclusion
We introduce the use of counterfactual reasoning for controlling the latent disentangled representations for text style transfer. Experiments not only establish the superiority of the proposed models across standard metrics for a multitude of styles but also illustrate the utility of the gradual control variable in this model. We further validate the use of CFs via a human evaluation establishing improved text attribute transfer.

A VAE Models - Further Details

RNN-based (R-VAE). We adopt the model described in John et al. (2019) to disentangle the content and style representations with a recurrent neural network (RNN)-based VAE. The RNN encoder with Bi-GRUs (Cho et al., 2014) learns the hidden representation q_E(h|x) by reading the input x = (x_1, x_2, ..., x_n) sequentially. The RNN decoder then decodes sequentially over time, predicting the probabilities of each token conditioned on the previous tokens and the latent representation. The reconstruction loss, which is the key loss for the generation objective, is the negative log-likelihood loss:

J_REC(θ_E, θ_D) = − Σ_{t=1}^{n} log p(x_t | h, x_1, ..., x_{t−1})

The hidden space h is separated into 2 spaces while disentangling the style (s) and content (c) representations. Disentanglement is achieved using well-defined auxiliary losses.
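For concreteness, the token-level negative log-likelihood above can be computed directly from the decoder's per-step distributions; a small numpy sketch (the function name and array layout are our illustration):

```python
import numpy as np

def reconstruction_nll(step_probs, target_ids):
    """J_REC = -sum_t log p(x_t | h, x_<t).
    step_probs: (seq_len, vocab) rows of decoder output probabilities;
    target_ids: (seq_len,) gold token ids."""
    rows = np.arange(len(target_ids))
    # pick the probability assigned to each gold token; epsilon avoids log(0)
    return -np.sum(np.log(step_probs[rows, target_ids] + 1e-12))
```

In training, this loss is the reconstruction term of J_VAE, with the KL penalty added on top.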

Transformer-based (T-VAE).
Transformers (Vaswani et al., 2017) have gained popularity for text generation due to their robust architectures. We introduce a transformer-based VAE inspired by Wang et al. (2019). The transformer encoder has a multi-headed self-attention block followed by a feed-forward network (FFN). The decoder is similar to the encoder, with an additional encoder-decoder attention block. Given an input sentence x = (x_1, x_2, ..., x_n), the transformer encoder E_trans learns hidden word representations (z_1, z_2, ..., z_n). These are pooled to get a sentence representation z, which is further encoded into a probabilistic latent space q_E(h|x). A sample from this latent representation is given as input to the encoder-decoder attention block in the decoder. The decoder reconstructs the input sentence x conditioned on h. We adopt label smoothing regularization (Li et al., 2020) while training, for performance improvement. The reconstruction loss (J_REC) is:

J_REC = − Σ_t Σ_{i=1}^{v} p̂_i log(p_i)

where v is the vocabulary size, ε is the label smoothing parameter, and p_i and p̂_i are the predicted and the (smoothed) ground-truth probabilities over the vocabulary at every time step for word-wise decoding.

KL Annealing. We also use an Adam optimiser and the KL cost annealing technique (Bowman et al., 2016) to train our model. KL cost annealing refers to a slow increase in the weight of the KL term (λ_kl) in the loss function from 0 to 1. This aids the training process, as the model is warm-started to minimize the reconstruction loss in the initial iterations, followed by a gradual inclusion of the KL loss term in the subsequent iterations.
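Both ingredients can be written in a few lines. A numpy sketch of a sigmoid annealing schedule (the slope k and midpoint are our illustrative choices, not values from the paper) together with the smoothed target distribution used by the label smoothing loss:

```python
import numpy as np

def kl_weight(step, total_steps, k=10.0, midpoint=0.5):
    """Sigmoid KL annealing: lambda_kl rises smoothly from ~0 to ~1
    as training progresses from step 0 to total_steps."""
    return 1.0 / (1.0 + np.exp(-k * (step / total_steps - midpoint)))

def smoothed_targets(true_id, vocab_size, eps=0.1):
    """Label-smoothed ground truth p_hat: 1 - eps on the gold token,
    eps spread uniformly over the remaining vocab_size - 1 tokens."""
    p_hat = np.full(vocab_size, eps / (vocab_size - 1))
    p_hat[true_id] = 1.0 - eps
    return p_hat
```

The warm start corresponds to the near-zero weight in early steps, so the model first learns to reconstruct before the KL term is felt.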

A.1 Loss Functions
Auxiliary loss functions are used to achieve the text rewriting objectives. Note that the reconstruction loss is the primary loss for generation, but it does not take into consideration the style or the controlled generation.
We use Multi-task and Adversarial losses on the latent space h to disentangle the embeddings into content c and style s representations (i.e., h = [s; c], where [;] denotes concatenation).

Style-oriented losses. The multitask loss ensures that the style space s is discriminative for the style. We train a style classifier on s jointly with the autoencoder loss:
J_mul(s)(θ_E; θ_mul(s)) = − Σ_{l∈labels} t_s(l) log(y_s(l))

where θ_mul(s) are the parameters of the style multitask classifier, y_s is the style probability distribution predicted by the classifier, and t_s is the ground-truth style distribution. The adversarial loss for style is introduced to ensure that the content space c is not discriminative of the style. An adversarial classifier is trained that deliberately discriminates the true style label using the content vector c, with the following loss:
J_dis(s)(θ_dis(s)) = − Σ_{l∈labels} t_s(l) log(y_s(l))

where θ_dis(s) are the parameters of the style adversary and y_s is the style probability distribution predicted by the adversary from the content space. The encoder is then trained to learn a content vector space c from which its adversary cannot predict style information. The objective is to maximize the entropy H(p) = − Σ_{i∈labels} p_i log(p_i) with:

J_adv(s)(θ_E) = H(y_s | c; θ_dis(s))

Content-oriented losses. The multi-task loss aims to ensure that all content information is in the content space c. We define the content information using a bag-of-words (BoW) concept. Here, part-of-speech tags, i.e., nouns, are used; Liu et al. (2020) argue that nouns in the text can be considered attribute-independent content. This definition allows a generic content loss for all style dimensions, as against previous work where content is defined as the bag of words in a sentence, excluding stopwords and specific style (sentiment) related lexicon. The content multitask loss is analogous to the style multitask loss.
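The interplay of the two style losses can be sketched numerically: the multitask classifier on s is trained with cross-entropy, while the encoder's adversarial objective pushes the adversary's prediction from c towards the uniform, maximum-entropy distribution. A minimal numpy illustration (function names are ours):

```python
import numpy as np

def cross_entropy(t, y):
    """Multitask loss: -sum_l t(l) log y(l); epsilon avoids log(0)."""
    return -np.sum(t * np.log(y + 1e-12))

def entropy(y):
    """Encoder's adversarial objective H(y_s | c): maximized when the
    adversary's style prediction from the content space is uniform."""
    return -np.sum(y * np.log(y + 1e-12))
```

H is maximized at log(|labels|) by the uniform prediction, so an encoder that drives the adversary to uniform leaks no style information into c.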

B Dataset details
The brief descriptions of the datasets are as follows:

YELP: Reviews from Yelp. Each review is labeled with a sentiment class, positive or negative. The task is to change the label while rewriting.

GYAFC: A corpus created from a subset of Yahoo Answers. Each sample is tagged either formal or informal. The task is to switch the label.

GYAFC-Excitement: The task here is to convert sentences from 'exciting' to 'non-exciting'. We create a subset of the GYAFC data where annotators (on Amazon Mechanical Turk) were asked to tag each sentence as either showing excitement or not. Excitement follows the definition given by Aaker (1997). We follow the annotation scheme provided by Rao (2017).

POLITICAL: Comments on Facebook posts by United States Senate and House members. Each comment is labelled with either a Republican or a Democrat tag. The task is to interchange between the two.

GENDER: Reviews from Yelp for food businesses.
Each review is labeled as either male or female based on the author of the review. The task is to switch between the two. Table 4 lists the number of sentences in the train-dev-test split available for each dataset. The URL links to the data files are also provided for each of them.

C Implementation details
The dimensions of c and s are set to 128 and 16 respectively. The posterior probability distributions (µ, σ) learnt for the respective content and style also have the same dimensions. The learnt hidden state representation is converted to 128 (c) and 16 (s) with a linear layer.
For R-VAE, hidden state dimension is set to 256. For the T-VAE, the embedding size, latent layer and the self-attention layers all are set to 256. The inner dimension of FFN in the transformer is set to 1024. Each of the encoder and decoder is stacked with two layers of transformer blocks. We used the Adam optimizer for the VAE and the RMSProp optimizer for the discriminators, following stability tricks in adversarial training (Arjovsky and Bottou, 2017). Each optimizer has an initial learning rate of 10 −3 . Models are trained for 50 epochs. Figure  4 illustrates the architecture of T-VAE.
Word embeddings, initialized with word2vec (Mikolov et al., 2013), are trained on the respective training sets. Both the autoencoder and the discriminators are trained once per mini-batch. The label smoothing parameter ε in the transformer loss is set to 0.1. The KL-divergence penalty is weighted by λ_kl(s) and λ_kl(c) on style and content, respectively; during training, we use the sigmoid KL annealing schedule. The hyper-parameter weights in the loss function λ_mul(s), λ_mul(c), λ_adv(s), and λ_adv(c) are all set to 1, as these values were observed to converge over iterations.
We implement our model based on Pytorch 0.4. We trained our models on a machine with 4 NVIDIA Tesla V100-SXM2-16GB GPUs. On a single GPU, our transformer model with all the losses (T-VAE-4) took approximately 0.4 s to train for one step with a batch of size 128. It takes around 10 hours to train our model on 1 GPU. Table 5 depicts the runtime details for all the model variations.
For our counterfactual generator model, we use the counterfactual module from the Alibi library in Python. On average, it takes 3 seconds to generate a counterfactual for a given input representation and transfer strength (p_t). Further details of our model summary and generated sentences are available here: https://bit.ly/34DYHP5