An Invariant Learning Characterization of Controlled Text Generation

Controlled generation refers to the problem of creating text that contains stylistic or semantic attributes of interest. Many approaches reduce this problem to training a predictor of the desired attribute. For example, researchers hoping to deploy a large language model to produce non-toxic content may use a toxicity classifier to filter generated text. In practice, the generated text to classify, which is determined by user prompts, may come from a wide range of distributions. In this paper, we show that the performance of controlled generation may be poor if the distributions of text in response to user prompts differ from the distribution the predictor was trained on. To address this problem, we cast controlled generation under distribution shift as an invariant learning problem: the most effective predictor should be invariant across multiple text environments. We then discuss a natural solution that arises from this characterization and propose heuristics for selecting natural environments. We study this characterization and the proposed method empirically using both synthetic and real data. Experiments demonstrate both the challenge of distribution shift in controlled generation and the potential of invariance methods in this setting.


Introduction
The development of large language models (LLMs) has been paradigm-shifting. Simply by conditioning on some well-thought-out prompts, LLMs can be adapted to new tasks or distributions [28,20,31,24,3,5]. This increase in adaptability has led to a greater need for control: in order to deploy these models safely, we need to be able to control their generation. Increased adaptability also presents new challenges to control, as we now need control methods that work across different tasks and distributions.
A major challenge of controlled text generation is attribute misalignment, in which the controlled model outputs text that is incompatible with the desired attribute. Many methods have been proposed for controlled generation, ranging from re-training [8,15,22], finetuning [37,29], weighted decoding [6,34], to filtering at inference time [30]. Unfortunately, given certain prompts, controlled models can still produce text that is not aligned with the desired attributes [9]. Thus, it remains unclear when we can expect these control methods to work.
The purpose of this paper is to take a step toward a principled understanding of the attribute misalignment problem in controlled generation. We start from a simple probabilistic formulation of controlled generation, where rejection sampling is used to obtain the controlled output. We further posit that the problem of attribute misalignment could be caused by distribution shift. We highlight that each prompt effectively induces a new distribution over text and there may be an exponential number of possible distributions. Building on the proposed characterization, we show that solving attribute alignment hinges on solving an invariant learning problem between the text representation and the desired control variable. Finally, we employ a commonly used method for invariant learning [17], demonstrating the challenge of successfully learning an invariant representation in text.
While this paper is only a first attempt to connect controlled generation with invariant learning, establishing this connection has two important benefits. First, we can apply principled methods from the invariant learning literature to controlled generation. Furthermore, controlled generation can provide new datasets and application areas for these invariant algorithms.
Our contributions are (1) Identifying and characterizing distribution shift problems in controlled text generation; (2) Providing a solution using methods from invariant learning; (3) Providing a proof of concept for controlling LLMs by using invariant text classifiers.

Controlled Generation
The goal of controlled generation is to produce text that is consistent with certain attributes. Formally, define a target distribution of text p(x) and a binary attribute y that relates to the text by p(y | x). Throughout this paper, we focus on the case where y corresponds to toxicity: y = 1 denotes that text contains toxic content, while y = 0 denotes non-toxic text. Assume that text samples and attribute labels have been collected from a training distribution (x_i, y_i) ∼ q(x, y). The goal of controlled generation is to parameterize a distribution p_θ such that

p_θ(x) ≈ p(x | y = 0).   (1)

There are many ways to approximate Eq. 1. Most prior work has focused on the case where the training and target distributions are the same (i.e., p(x, y) = q(x, y)). In this case, one line of work has focused on strategies for modeling p(x | y = 0) directly [13,36,15,37,9,11].
We focus on another approach, which makes use of Bayes' rule:

p(x | y = 0) ∝ p(y = 0 | x) p(x).   (2)

Prior work that uses this paradigm either modifies the model activations [6,18] or develops weighted decoding methods [6,16,18,34]. This perspective is useful if the target distribution p(x) is large or difficult to modify, because controlled generation reduces to modeling the binary distribution p(y | x).
Specifically, if f_θ(x) is a binary classifier that models p(y | x), it can be used to filter toxic samples from the target distribution p(x):

p_θ(x) ∝ p(x) · 1{f_θ(x) ≤ δ},   (3)

where f_θ(x) denotes the predicted probability of toxicity and δ is some predetermined threshold. Eq. 3 can be approximated by rejection sampling [33,30].
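The filtering scheme in Eq. 3 can be sketched as a simple rejection-sampling loop. This is a minimal illustration, not the paper's implementation: `sample_text` and `toxicity_score` are hypothetical stand-ins for the base language model p(x) and the classifier f_θ.

```python
import random

random.seed(0)

def toxicity_score(text):
    # Hypothetical stand-in for a learned classifier f_theta(x):
    # returns the predicted probability that `text` is toxic.
    return 0.9 if "toxic" in text else 0.1

def sample_text():
    # Hypothetical stand-in for drawing x ~ p(x) from the base LM.
    return random.choice(["a toxic rant", "a friendly greeting"])

def controlled_sample(delta=0.5, max_tries=100):
    """Approximate p_theta(x) in Eq. 3 by rejection sampling:
    draw from p(x) and keep the first sample the classifier accepts."""
    for _ in range(max_tries):
        x = sample_text()
        if toxicity_score(x) <= delta:
            return x
    return None  # no acceptable sample found within the budget

print(controlled_sample())  # almost always "a friendly greeting"
```

Because the base model's likelihood is never modified, only reweighted by acceptance, samples that pass the filter remain fluent draws from p(x).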
We focus our analysis on the Bayesian perspective in Eq. 2. Reasoning about distribution shift under this setting has two advantages. First, as discussed above, the controlled generation problem reduces to building a classifier. Thus, practitioners can reason about how errors from the predictor might propagate to the controlled distribution. A second benefit of this formulation is that if Eq. 3 is approximated by rejection sampling, fluency is preserved because the model's likelihood is not being modified. We can reason about control without worrying about its trade-off with fluency.

An Invariant Learning Characterization of Controlled Generation
In this section, we examine how distribution shift can cause a controlled language model to output text that is not aligned with the desired attributes. We discuss a possible solution to the problem, the conditions under which the solution is valid, and the challenges that remain to be addressed.
The Problem. Using the toxicity example, a simple measure of attribute alignment is the likelihood that toxic text will appear from the controlled distribution [9]:

Σ_x Σ_y 1{y = 1} · p(y | x) · p_θ(x).   (4)

We can decompose this probability as follows: the first term is an indicator of toxicity, the second term is the probability the text is toxic, and the third term is the controlled distribution.
Recall that in the setup in § 2, we use data from q(x, y) to fit the predictor f_θ. Now, we replace the indicator function in Eq. 4 with a loss function l(x, y). Let R_p(θ) = E_{p(x,y)}[l(f_θ(x), y)] denote the risk of the predictor under p(x, y). If p(x, y) is different from q(x, y), the predictor that minimizes the risk under p(x, y) could be different from the optimal predictor for q(x, y):

argmin_θ R_p(θ) ≠ argmin_θ R_q(θ).   (5)

An implication of Eq. 5 is that spurious correlations learned during training might introduce unintended biases into the controlled generation process. For example, toxicity classification methods learn spurious correlations between minority groups and toxicity, which consequently leads to a reduction in the LM's ability to generate non-toxic text about minorities [32].
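A toy numerical illustration of Eq. 5 (all numbers invented): two candidate predictors are compared under a training distribution q and a shifted target distribution p, and the risk-minimizing choice flips.

```python
# Toy illustration of Eq. 5 with invented numbers: the risk-minimizing
# predictor under the training distribution q differs from the one under
# the shifted target distribution p.

# Expected losses of two candidate predictors on two text "regions":
# region A (common in training) and region B (rare in training).
losses = {
    "f1": {"A": 0.05, "B": 0.60},  # fits region A well, fails on B
    "f2": {"A": 0.20, "B": 0.20},  # mediocre but uniform everywhere
}

q = {"A": 0.9, "B": 0.1}  # training distribution over regions
p = {"A": 0.2, "B": 0.8}  # target distribution after a prompt shift

def risk(f, dist):
    """Expected loss of predictor f under a distribution over regions."""
    return sum(dist[r] * losses[f][r] for r in dist)

best_q = min(losses, key=lambda f: risk(f, q))
best_p = min(losses, key=lambda f: risk(f, p))
print(best_q, best_p)  # f1 f2: the optimal predictor changes under shift
```

Here f1 plays the role of a predictor relying on a spurious feature that dominates the training data, while f2 is the invariant predictor one would prefer under shift.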
A Solution. Finding a predictor that minimizes R p (θ) is challenging because we only observe p(x). Invariant learning [21,1,25,17,19] is a class of methods that address this problem. It posits that our observed training data often contains samples from multiple distributions, sometimes also called "environments." If we can learn a predictor that is equally optimal across environments, i.e., the performance of the predictor is invariant to which environment it is in, the predictor may also generalize to the target environment p(x).
In more detail, let E = {e_1, ..., e_m} denote a set of training distributions, and let R_e(θ) denote the empirical risk of f_θ under the distribution q_e(x, y). The goal is to find a predictor f_θ that is invariant and optimal across environments. The corresponding optimization objective is

min_θ Σ_{e ∈ E} R_e(θ)   subject to   θ ∈ argmin_{θ'} R_e(θ') for all e ∈ E.   (6)

Conditions for Generalization. When the target and the training environments overlap, we might expect the predictor to generalize if there is an invariant relationship between the target y and the text x. A1 and A2 formally define this intuition.

A1. Causal Sufficiency
There exists a function f_θ such that

p(y | x) = q_e(y | x) = f_θ(x) for all e ∈ E.   (7)

A1 effectively assumes that the second term in Eq. 4, p(y | x), does not change across environments.
A2. Overlap
Let supp(p) be the support of p and supp(q_e) be the support of q_e. We assume

supp(p) ⊆ ∪_{e ∈ E} supp(q_e).   (8)

A2 assumes that text in the target distribution has a non-zero probability of appearing in the training corpus. For example, if the training distribution contains only English text but the target distribution contains Chinese characters, then the toxicity predictor may not generalize.
Challenges. We have cast the controlled generation problem as an invariant prediction problem, but there are still many conceptual and technical challenges to overcome.
Invariant predictors are usually developed when variables are well-defined or when some features or relationships are known to be spurious. However, in controlled generation, the attributes we wish to align our model to are often subjective and poorly defined. This has two implications. First, it is difficult to determine when A1 and A2 might hold. For example, Sap et al. [26] found that perceived language toxicity may be influenced by context, identity, and belief, indicating that an invariant predictor based on text alone may not exist. Second, it is unclear what counts as a valid environment. Valid environments are defined with respect to the causal generation process of the target attributes [21,1,27,4,35]. An under-defined attribute leads to an under-defined causal graph, making it difficult to reason about valid environments.
On a technical level, solving the optimization problem in Eq. 6 is challenging. Various algorithms have been proposed to approximate Eq. 6. However, the performance of these methods varies across tasks and across different deployment distributions [10].

Table 1: The invariant predictors are more similar to Perspective API and are more effective at filtering out toxic text. We report the average toxicity score after filtering at various thresholds and the cross-entropy loss for ERM and V-REx based on different environment splits.
For the experiments in § 4, we use the Variance Risk Extrapolation (V-REx) algorithm [17]. The optimization objective is

R_VREx(θ) = Σ_{e ∈ E} R_e(θ) + β · Var({R_e1(θ), ..., R_em(θ)}).   (9)

Eq. 9 is a widely applied approximation to the constrained optimization in Eq. 6.
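A minimal sketch of the V-REx objective in Eq. 9, assuming the per-environment risks have already been computed; the function and variable names are our own, not from the paper's code.

```python
def vrex_objective(env_risks, beta):
    """V-REx objective (Eq. 9): the sum of per-environment risks plus
    beta times the variance of the risks across environments."""
    m = len(env_risks)
    mean = sum(env_risks) / m
    variance = sum((r - mean) ** 2 for r in env_risks) / m
    return sum(env_risks) + beta * variance

# With beta = 0 this reduces to pooled empirical risk minimization (ERM);
# larger beta pushes the risks to be equal across environments.
print(vrex_objective([0.3, 0.1], beta=10.0))  # 0.4 + 10 * 0.01 = 0.5
```

In practice this scalar would be computed from differentiable per-environment losses and minimized by gradient descent, so the variance penalty directly discourages predictors whose performance depends on the environment.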

Empirical Studies
We study how distribution shift affects attribute alignment empirically using a toxicity dataset.
Experiment Setup. To approximate the training distribution, we use the CivilComments dataset [2]. The dataset contains the archives of the CivilComments platform, where comments posted by users are annotated for toxicity. In addition to toxicity, this dataset also contains metadata for each comment, such as identity attributes mentioned, comment created date, and the number of identity attribute annotators. We use the metadata to create three specifications for binary environments q e .
To approximate the target distribution p(x), we select 40 prompts of varying toxicity levels from the RealToxicityPrompts dataset [9]. The dataset contains 100K natural sentence-level prompts from a web corpus, paired with toxicity scores computed by Perspective API, a widely used commercial toxicity model. Specifically, we sample 10 prompts from each quartile of toxicity scores. Given each prompt, we generate K = 100 continuations using GPT-2 [23]. Following Gehman et al. [9], we use nucleus sampling [12] with p = 0.9 to generate up to 20 tokens.
The predictors are trained by finetuning BERT [7] on a subset of the CivilComments dataset. The ERM predictor uses cross-entropy loss. The invariant predictors optimize the V-REx objective in Eq. 9. For the β parameter, we consider four values, (0, 10, 50, 100). Table 1 reports the results when β = 10. More experiment details and additional results are in Appendix A.
Evaluation. We use the 4K generated continuations to evaluate the predictors on their ability to detect out-of-distribution toxicity. For automatic evaluation, we use Perspective API as a proxy for ground truth. The two performance metrics we consider are (1) cross-entropy loss and (2) the average toxicity score after filtering out generations whose predicted toxicity exceeds a threshold δ. The second metric connects training a classifier back to generation via Eq. 4. We additionally evaluate Perspective API itself as a predictor to estimate a lower bound on the ideal performance.
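The filtered-toxicity metric can be sketched as follows. The scores below are invented for illustration; in the experiments, the reference scores come from Perspective API.

```python
def avg_toxicity_after_filtering(pred_scores, ref_scores, delta):
    """Average reference toxicity of the generations a predictor keeps:
    a generation is kept if its predicted toxicity is at most delta."""
    kept = [ref for pred, ref in zip(pred_scores, ref_scores) if pred <= delta]
    if not kept:
        return 0.0  # the predictor filtered out every generation
    return sum(kept) / len(kept)

# Invented scores for four generated continuations.
pred = [0.9, 0.2, 0.6, 0.1]   # predictor f_theta
ref  = [0.8, 0.1, 0.7, 0.2]   # proxy ground truth (e.g. Perspective API)
print(avg_toxicity_after_filtering(pred, ref, delta=0.5))  # approximately 0.15
```

A lower value means the predictor's accepted set is less toxic according to the proxy ground truth, which is exactly what rejection sampling against that predictor would produce.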
Analysis. Table 1 illustrates the promise and challenges of applying invariant learning methods to controlled generation. By simply splitting the dataset using metadata that is otherwise discarded in the prediction task, we can train invariant predictors that are better at filtering out toxic content and more similar to Perspective API, the proxy ground truth. The invariant predictors, however, differ from one another. This suggests that "environment" plays an important role in the generalizability of invariant methods. Defining what constitutes a valid environment when the target attribute is underdefined is an important area of future research.

A Experiment Details & Additional Results
CivilComments. To fit the predictors, we use the CivilComments dataset [2]. This dataset contains the archives of the CivilComments platform, where comments posted by users are annotated for toxicity. We selected a subset of examples for the experiments by randomly sampling 38K examples that had labeled identity attributes. We created train, validation, and test sets according to an 80-10-10 split. The hyperparameters were cross-validated on the validation set.
In addition to toxicity, this dataset also contains metadata for each comment, such as identity attributes mentioned (e.g., binary variables for male, female, LGBT, black, white, Asian, etc.), comment created date, and the number of identity attribute annotators. We use the metadata to create three specifications for binary environments. To obtain a scalar representation of the identity attributes mentioned for each comment, we sum the identity variables to create the "identity attribute sum" feature. Our other two specifications use the "created date" and "identity annotator count" features, respectively. For each feature corresponding to a specification, we use the feature's median value in the training set to partition the data into two environments.
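The median-split environment construction described above can be sketched as follows; the feature values are invented, standing in for one of the scalar metadata features (e.g., the "identity attribute sum").

```python
def median_split(values):
    """Partition examples into two environments by comparing a scalar
    metadata feature to its median over the training set."""
    ordered = sorted(values)
    median = ordered[len(ordered) // 2]
    env0 = [i for i, v in enumerate(values) if v <= median]
    env1 = [i for i, v in enumerate(values) if v > median]
    return env0, env1

# Invented "identity attribute sum" values for six comments.
sums = [0, 3, 1, 5, 2, 4]
env0, env1 = median_split(sums)
print(env0, env1)  # indices at or below the median vs. above it
```

Each of the three environment specifications applies this same split to a different metadata feature, so the invariance penalty sees a different partition of the same training comments.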
Hyperparameters. To train the ERM predictor, we initialize from pre-trained BERT-base and finetune on a subset of CivilComments using 4 GPUs with a batch size of 16 and a learning rate of 5e-5 (AdamW optimizer) for 2 epochs. We use a linear warm-up for the first 10% of steps and then linearly decay the learning rate to zero over the remaining steps. We use the same settings as the baseline to train the invariant predictors. We did not do a hyperparameter search for the invariant regularizer β, as we only had two environments.

B Detoxified Text Examples
In Table 3, we randomly select prompts and GPT-2 continuations out of the 100 with the largest positive difference between the V-REx prediction and the baseline prediction. We additionally condition on the toxicity score of the prompt being in [0.25, 0.5), as we found that without this conditioning, many of GPT-2's continuations were extremely toxic. Two example continuations for the same prompt:

"We are obliged to fight those who are burning our dwellings, to hit with full force those who, are destroying our homes," he wrote, according to his notes. "I will defend and defend my people.

"We are obliged to fight those who are burning our dwellings, to hit with full force those who, are burning our forests." He said that this was one of the biggest political protests in modern times and he expressed great respect for the protesters and

Table 3: Five examples of GPT-2 completions and corresponding predicted toxicity.