Adversarial Regularization as Stackelberg Game: An Unrolled Optimization Approach

Adversarial regularization has been shown to improve the generalization performance of deep learning models in various natural language processing tasks. Existing works usually formulate the method as a zero-sum game, which is solved by alternating gradient descent/ascent algorithms. Such a formulation treats the adversarial and the defending players equally, which is undesirable because only the defending player contributes to the generalization performance. To address this issue, we propose Stackelberg Adversarial Regularization (SALT), which formulates adversarial regularization as a Stackelberg game. This formulation induces a competition between a leader and a follower, where the follower generates perturbations, and the leader trains the model subject to the perturbations. Different from conventional approaches, in SALT, the leader is in an advantageous position. When the leader moves, it recognizes the strategy of the follower and takes the anticipated follower’s outcomes into consideration. Such a leader’s advantage enables us to improve the model fitting to the unperturbed data. The leader’s strategic information is captured by the Stackelberg gradient, which is obtained using an unrolling algorithm. Our experimental results on a set of machine translation and natural language understanding tasks show that SALT outperforms existing adversarial regularization baselines across all tasks. Our code is publicly available.


Introduction
Adversarial regularization (Miyato et al., 2017) has been shown to improve the generalization performance of deep learning models in various natural language processing (NLP) tasks, such as language modeling (Wang et al., 2019b), machine translation (Sato et al., 2019), natural language understanding, and reading comprehension (Jia and Liang, 2017). However, even though significant progress has been made, the power of adversarial regularization is not fully harnessed.
Conventional adversarial regularization is formulated as a zero-sum game (a min-max optimization problem), where two players seek to minimize/maximize their utility functions. In this formulation, an adversarial player composes perturbations, and a defending player solves for the model parameters subject to the perturbed inputs. Existing algorithms find the equilibrium of this zero-sum game using alternating gradient descent/ascent (Madry et al., 2018). For example, in a classification problem, the adversarial player first generates the input perturbations by running projected gradient ascent to maximize a loss function, and then the defending player updates the model using gradient descent, trying to decrease the classification error. Notice that in this case, neither of the players knows the strategy of its competitor, i.e., the model does not know how the perturbations are generated, and vice versa. In other words, the two players are of the same priority, and either one of them can be advantageous in the game. It is possible that the adversarial player generates over-strong perturbations that hinder the generalization of the model.
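As a concrete illustration, the alternating procedure can be sketched in PyTorch as follows. This is a minimal sketch with hypothetical names (e.g., `adversarial_step`); it assumes the model consumes input embeddings directly and uses an $\ell_\infty$ projection for simplicity, whereas real implementations use task-specific losses and projections:

```python
import torch

def adversarial_step(model, loss_fn, x_embed, y, eps=1e-3, eta=1e-3, k_steps=1):
    """One round of alternating gradient ascent (adversary) / descent (model)."""
    delta = torch.zeros_like(x_embed, requires_grad=True)
    # Adversarial player: projected gradient ascent on the perturbation.
    for _ in range(k_steps):
        loss = loss_fn(model(x_embed + delta), y)
        grad, = torch.autograd.grad(loss, delta)
        with torch.no_grad():
            delta += eta * grad
            delta.clamp_(-eps, eps)  # project back onto the l_inf-ball
    # Defending player: one gradient descent step on the perturbed input.
    loss = loss_fn(model(x_embed + delta.detach()), y)
    loss.backward()  # followed by an optimizer step on the model parameters
```

Note that the two updates are decoupled: the descent step treats the perturbation as a constant (`delta.detach()`), which is exactly the lack of interaction between the two players that SALT addresses.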
To resolve this issue, we grant the defending player (i.e., the model) a higher priority than the adversarial player by letting the defender recognize its competitor's strategy, such that it is advantageous in the game. Consequently, we propose Stackelberg Adversarial Regularization (SALT), where we formulate adversarial regularization as a Stackelberg game (Von Stackelberg, 2010). The concept arises from economics, where two firms compete in a market, and one of them is in the leading position by acknowledging the opponent's strategy. In Stackelberg adversarial regularization, a leader solves for the model parameters, and a follower generates input perturbations. The leader procures its advantage by considering what the best response of the follower is, i.e., how the follower will respond after observing the leader's decision. Then, the leader minimizes its loss, anticipating the predicted response of the follower.
The SALT framework identifies the interaction between the leader and the follower by treating the follower's strategy (i.e., the input perturbations) as an operator of the leader's decision (i.e., the model parameters). Then we can solve for the model parameters using gradient descent. One caveat is that computing the gradient term, which we call the Stackelberg gradient, requires differentiating the interaction operator. To rigorously define this operator, recall that the follower can be approximately solved using gradient ascent. We can treat the perturbations in each iteration as an operator of the model parameters, and the interaction operator is then the composition of such update-induced operators. Correspondingly, the Stackelberg gradient is obtained by differentiating through these updates. This procedure is referred to as unrolling (Pearlmutter and Siskind, 2008), and the only computational overhead it causes is computing Hessian-vector products. As a result, when applying the finite difference method, computing the Stackelberg gradient requires two backpropagations and an extra O(d) operation, where d is the embedding dimension. Therefore, the unrolling algorithm computes the Stackelberg gradient without causing much computational overhead.
We conduct experiments on neural machine translation (NMT) and natural language understanding (NLU) tasks. For the NMT tasks, we experiment on four low-resource datasets and one rich-resource dataset. SALT improves upon existing adversarial regularization algorithms by notable margins, especially on low-resource datasets, where it achieves up to 2 BLEU score improvements. To test performance on NLU tasks, we evaluate SALT on the GLUE (Wang et al., 2019a) benchmark. SALT outperforms state-of-the-art models, such as BERT (Devlin et al., 2019), FreeAT (Shafahi et al., 2019), FreeLB (Zhu et al., 2019), and SMART (Jiang et al., 2020). We build SALT on the BERT-base architecture, and we achieve an average score of 84.5 on the GLUE development set, which is at least 0.7 higher than existing methods. Moreover, even though we adapt SALT to BERT-base, its performance is noticeably higher than that of the vanilla BERT-large model (84.5 vs. 84.0).
The unrolling procedure was first proposed for auto-differentiation (Pearlmutter and Siskind, 2008), and later applied in various contexts, such as hyper-parameter optimization (Maclaurin et al., 2015; Finn et al., 2017), meta-learning (Andrychowicz et al., 2016), and Generative Adversarial Networks (Metz et al., 2017). To the best of our knowledge, we are the first to apply the unrolling technique to adversarial regularization to improve generalization performance.
We summarize our contributions as the following: (1) We propose SALT, which employs a Stackelberg game formulation of adversarial regularization. (2) We use an unrolling algorithm to find the equilibrium of the Stackelberg game. (3) Extensive experiments on NMT and NLU tasks verify the efficacy of our method.
Notation. We use df(x)/dx to denote the full gradient of f with respect to x. We use ∂f(x, y)/∂x to denote the partial derivative of f with respect to x, with the other argument y held fixed.

Background and Related Works
Neural machine translation has achieved superior empirical performance (Bahdanau et al., 2015; Gehring et al., 2017; Vaswani et al., 2017). We focus on the Transformer architecture (Vaswani et al., 2017), which integrates the attention mechanism into an encoder-decoder structure. The encoder in a Transformer model first maps a source sentence into an embedding space, and the embeddings are then fed into several encoding layers to generate hidden representations, where each encoding layer contains a self-attention mechanism and a feed-forward neural network (FFN). Afterwards, the Transformer decoder layers, each of which contains a self-attention mechanism, an encoder-decoder attention, and an FFN, decode the hidden representations.
Adversarial training was originally proposed for training adversarially robust classifiers in image classification (Szegedy et al., 2014; Goodfellow et al., 2015; Madry et al., 2018). The idea is to synthesize strong adversarial samples, and the classifier is trained to be robust to them. Theoretical understandings of adversarial training (Li et al., 2019) and various algorithms to generate the adversarial samples, such as learning-to-learn (Jiang et al., 2021), have been proposed. Besides computer vision, adversarial training can also benefit reinforcement learning (Shen et al., 2020). Different from the above fields, in NLP, the goal of adversarial training is to build models that generalize well on the unperturbed test data. Note that robustness and generalization are different concepts. Recent works (Raghunathan et al., 2020; Min et al., 2020) show that adversarial training can hurt generalization performance, i.e., accuracy on clean data. As such, adversarial training needs to be treated with great caution. Therefore, in NLP, this technique requires refined tuning of, for example, the training algorithm and the perturbation strength.

Method
Natural language inputs are discrete symbols (e.g., words), instead of continuous ones. Therefore, a common approach to generating perturbations is to learn continuous embeddings of the inputs and operate on the embedding space (Miyato et al., 2017; Clark et al., 2018; Sato et al., 2018, 2019; Stutz et al., 2019). Let f(x, θ) be our model, where x is the input embedding, and θ is the model parameter. Further let y be the ground-truth output corresponding to x. For example, in NMT, f is a sequence-to-sequence model, x is the embedding of the source sentence, and y is the target sentence.
In classification tasks, f is a classifier, x is the input sentence/document embedding, and y is the label. In both of these cases, the model is trained by minimizing the empirical risk over the training data, i.e.,
$$\min_\theta \ \frac{1}{n}\sum_{i=1}^{n} \ell\big(f(x_i, \theta), y_i\big).$$
Here $\{(x_i, y_i)\}_{i=1}^{n}$ is our dataset, and $\ell$ is a task-specific loss function, e.g., the cross-entropy loss.

Adversarial Regularization
Adversarial Regularization (Miyato et al., 2017) is a regularization technique that encourages smoothness of the model outputs around each input data point. Concretely, we define the adversarial regularizer for non-regression tasks as
$$\ell_v(x, \delta, \theta) = \mathrm{KL}\big(f(x, \theta)\,\|\,f(x + \delta, \theta)\big).$$
Here KL(·||·) is the Kullback-Leibler (KL) divergence, δ is the perturbation corresponding to x, and f(·, θ) is the prediction probability simplex given model parameters θ. In regression tasks, the model output f(·, θ) is a scalar, and the adversarial regularizer is defined as
$$\ell_v(x, \delta, \theta) = \big(f(x + \delta, \theta) - f(x, \theta)\big)^2.$$
Then the training objective is
$$\min_\theta \max_{\{\delta_i:\,\|\delta_i\| \le \epsilon\}} \ \frac{1}{n}\sum_{i=1}^{n}\Big[\ell\big(f(x_i, \theta), y_i\big) + \alpha\, \ell_v(x_i, \delta_i, \theta)\Big], \quad (1)$$
where α is a tuning parameter, ε is a pre-defined perturbation strength, and ‖·‖ is either the $\ell_2$ norm or the $\ell_\infty$ norm. The min and max problems are solved using alternating gradient descent/ascent. We first generate the perturbations δ by solving the maximization problem using several steps of projected gradient ascent, and then we update the model parameters θ with gradient descent, subject to the perturbed inputs. More details are deferred to Appendix A.
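As a minimal sketch of the non-regression regularizer $\ell_v$ (the name `adv_regularizer` is illustrative, and we assume a model that consumes embeddings and returns logits):

```python
import torch.nn.functional as F

def adv_regularizer(model, x_embed, delta):
    """KL-based smoothness regularizer l_v(x, delta, theta)."""
    log_p_clean = F.log_softmax(model(x_embed), dim=-1)
    log_p_pert = F.log_softmax(model(x_embed + delta), dim=-1)
    # KL(f(x, theta) || f(x + delta, theta)), averaged over the batch.
    return F.kl_div(log_p_pert, log_p_clean, log_target=True,
                    reduction="batchmean")
```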
One major drawback of the zero-sum game formulation (Eq. 1) is that it fails to consider the interaction between the perturbations δ and the model parameters θ. This is problematic because a small change in δ may lead to a significant change in θ, which renders the optimization ill-conditioned. Thus, the model is susceptible to underfitting and generalizes poorly on unperturbed test data.

Adversarial Regularization as Stackelberg Game

We formulate adversarial regularization as a Stackelberg game (Von Stackelberg, 2010):
$$\min_\theta \ F(\theta) = \frac{1}{n}\sum_{i=1}^{n}\Big[\ell\big(f(x_i, \theta), y_i\big) + \alpha\, \ell_v\big(x_i, \delta_K^{i}(\theta), \theta\big)\Big], \quad \text{s.t.} \quad \delta_K^{i}(\theta) = U^K(\delta_0^{i}) = U_K \circ U_{K-1} \circ \cdots \circ U_1(\delta_0^{i}). \quad (2)$$
Here "∘" denotes operator composition, i.e., f ∘ g(·) = f(g(·)). Following convention, in this Stackelberg game, we call the optimization problem in Eq. 2 the leader. Further, the follower in Eq. 2 is described using an equality constraint. Note that $U^K$ is the follower's K-step composite strategy, which is the composition of K one-step strategies $\{U_k\}_{k=1}^{K}$. In practice, K is usually small. This is because in NLP, we target generalization, instead of robustness, and choosing a small K prevents over-strong adversaries.
In Eq. 2, the $U_k$'s are the follower's one-step strategies, and we call them update operators, e.g., $U_1$ updates $\delta_0$ to $\delta_1$ using a pre-selected algorithm. For example, projected gradient ascent can be applied as the update procedure, that is,
$$\delta_k(\theta) = U_k\big(\delta_{k-1}(\theta)\big) = \Pi_{\|\delta\| \le \epsilon}\Big(\delta_{k-1}(\theta) + \eta\, \nabla_\delta\, \ell_v\big(x, \delta_{k-1}(\theta), \theta\big)\Big), \quad (3)$$
where $\delta_0 \sim N(0, \sigma^2 I)$ is an initial random perturbation drawn from a normal distribution with covariance $\sigma^2 I$, η is a pre-defined step size, and Π denotes projection onto the $\ell_2$-ball or the $\ell_\infty$-ball.
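A sketch of one such update (Eq. 3) with a sentence-level $\ell_2$ projection, reusing the `adv_regularizer` sketched above (the function name and the `(batch, length, dim)` shape convention are our assumptions):

```python
import torch

def pga_step(model, x_embed, delta, eta=1e-3, eps=1.0):
    """One projected-gradient-ascent update U_k on the perturbation."""
    delta = delta.detach().requires_grad_(True)
    grad, = torch.autograd.grad(adv_regularizer(model, x_embed, delta), delta)
    with torch.no_grad():
        delta = delta + eta * grad
        # Project onto the sentence-level l2-ball of radius eps.
        norm = torch.linalg.vector_norm(delta, dim=(-2, -1), keepdim=True)
        delta = delta * torch.clamp(eps / (norm + 1e-12), max=1.0)
    return delta
```

Note that the `detach` breaks the dependence of $\delta_k$ on θ, which is exactly what conventional adversarial regularization does; SALT instead keeps this dependence, as described next.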
To model how the follower will react to a leader's decision θ, we consider the function $\delta_K(\theta)$. Then, adversarial training can be viewed solely in terms of the leader's decision θ.
We highlight that in our formulation, the leader knows the strategy, instead of only the outcome, of the follower. This information is captured by the Stackelberg gradient dF(θ)/dθ, defined as the following:
$$\frac{dF(\theta)}{d\theta} = \underbrace{\alpha \left(\frac{d\delta_K(\theta)}{d\theta}\right)^{\!\top} \frac{\partial\, \ell_v(x, \delta_K, \theta)}{\partial \delta_K}}_{\text{leader-follower interaction}} + \underbrace{\frac{\partial\, \ell\big(f(x, \theta), y\big)}{\partial \theta} + \alpha\, \frac{\partial\, \ell_v(x, \delta_K, \theta)}{\partial \theta}}_{\text{leader}}. \quad (4)$$
The underlying idea behind Eq. 4¹ is that given a leader's decision θ, we take the follower's strategy into account (i.e., the "leader-follower interaction" term) and find a direction along which the leader's loss decreases the most. Then we update θ in that direction. Note that the gradient used in standard adversarial training (Eq. 1) only contains the "leader" term, such that the "leader-follower interaction" is not taken into account.

¹The second term in "leader" is written as $\partial \ell_v(x, \delta_K, \theta)/\partial\theta$, instead of $\partial \ell_v(x, \delta_K(\theta), \theta)/\partial\theta$. This is because the partial derivative with respect to θ is only taken w.r.t. the third argument of $\ell_v(x, \delta_K, \theta)$. We drop the θ in $\delta_K(\theta)$ to avoid causing any confusion.

Algorithm 1 (SALT). Input: D: dataset; T: total number of training epochs; σ²: variance of initial perturbations; K: number of unrolling steps; Optimizer: optimizer to update θ.

SALT: Stackelberg Adversarial Regularization

We propose to use an unrolling method (Pearlmutter and Siskind, 2008) to compute the Stackelberg gradient (Eq. 4). The general idea is that since the interaction operator is defined as the composition of the $\{U_k\}$ operators, all of which are known, we can directly compute the derivative of $\delta_K(\theta)$ with respect to θ. Concretely, we first run a forward iteration to update δ, and then we differentiate through this update to acquire the Stackelberg gradient.
Note that the updates of δ can take any form, such as projected gradient ascent in Eq. 3, or more complicated alternatives like Adam (Kingma and Ba, 2015). For notational simplicity, we denote $\Delta(x, \delta_{k-1}(\theta), \theta) = \delta_k(\theta) - \delta_{k-1}(\theta)$. Accordingly, Eq. 3 can be rewritten as
$$\delta_k(\theta) = \delta_{k-1}(\theta) + \Delta\big(x, \delta_{k-1}(\theta), \theta\big). \quad (5)$$
The most expensive part in computing the Stackelberg gradient (Eq. 4) is to calculate $d\delta_K(\theta)/d\theta$, which involves differentiating through the composition form of the follower's strategy:
$$\frac{d\delta_k(\theta)}{d\theta} = \Big(I + \frac{\partial \Delta\big(x, \delta_{k-1}(\theta), \theta\big)}{\partial \delta_{k-1}(\theta)}\Big)\frac{d\delta_{k-1}(\theta)}{d\theta} + \frac{\partial \Delta\big(x, \delta_{k-1}(\theta), \theta\big)}{\partial \theta}, \quad k = 1, \dots, K, \quad (6)$$
where $d\delta_0(\theta)/d\theta = 0$. We can compute Eq. 6 efficiently using deep learning libraries, such as PyTorch (Paszke et al., 2019). Notice that $\Delta(x, \delta_{k-1}(\theta), \theta)$ already contains the first-order derivative of $\ell_v$ with respect to the perturbations. Therefore, the term $\partial \Delta(x, \delta_{k-1}(\theta), \theta)/\partial \delta_{k-1}(\theta)$ contains the Hessian of $\ell_v$ with respect to $\delta_{k-1}(\theta)$. As a result, in Eq. 4, the most expensive operation is the Hessian-vector product (Hvp). Using the finite difference method, computing an Hvp only requires two backpropagations and an extra O(d) operation. This indicates that in comparison with conventional adversarial training, SALT does not introduce significant computational overhead. The training algorithm is summarized in Algorithm 1.
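In practice, one need not form Eq. 6 explicitly: retaining the computation graph of the K updates and calling backward on the leader's loss yields the Stackelberg gradient automatically. A minimal PyTorch sketch (illustrative names, reusing `adv_regularizer` from above; projection omitted for brevity):

```python
import torch

def salt_step(model, loss_fn, x_embed, y, alpha=1.0, eta=1e-3, sigma=1e-4, K=2):
    """One SALT update: unroll K perturbation steps, then backprop through them."""
    delta = (sigma * torch.randn_like(x_embed)).requires_grad_(True)
    for _ in range(K):
        reg = adv_regularizer(model, x_embed, delta)
        # create_graph=True keeps each update differentiable w.r.t. theta,
        # so d(delta_K)/d(theta) in Eq. 6 is formed implicitly.
        grad, = torch.autograd.grad(reg, delta, create_graph=True)
        delta = delta + eta * grad
    # Leader objective: clean loss plus the regularizer at delta_K (Eq. 2).
    loss = loss_fn(model(x_embed), y) + alpha * adv_regularizer(model, x_embed, delta)
    loss.backward()  # includes the leader-follower interaction term of Eq. 4
```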

Experiments
In all the experiments, we use PyTorch (Paszke et al., 2019) as the backend. All the experiments are conducted on NVIDIA V100 32GB GPUs. We use the Higher package (Grefenstette et al., 2019) to implement the proposed algorithm.
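For illustration, one possible way to realize the unrolled inner loop with Higher is to wrap the perturbation in a small nn.Module so that its updates are tracked differentiably. This is a sketch under our assumptions (the `Perturbation` and `salt_loss` names are hypothetical, `adv_regularizer` is the sketch from the Method section, and the released implementation may differ):

```python
import torch
import higher

class Perturbation(torch.nn.Module):
    """Holds delta as a parameter so Higher can track its inner-loop updates."""
    def __init__(self, shape, sigma=1e-4):
        super().__init__()
        self.delta = torch.nn.Parameter(sigma * torch.randn(shape))

def salt_loss(model, loss_fn, x_embed, y, alpha=1.0, eta=1e-3, K=2):
    pert = Perturbation(x_embed.shape).to(x_embed.device)
    inner_opt = torch.optim.SGD(pert.parameters(), lr=eta)
    with higher.innerloop_ctx(pert, inner_opt) as (fpert, diffopt):
        for _ in range(K):
            delta = next(iter(fpert.parameters()))
            # Negating the regularizer turns SGD's descent into ascent on l_v.
            diffopt.step(-adv_regularizer(model, x_embed, delta))
        delta_K = next(iter(fpert.parameters()))
        # Backprop through this loss traverses the K unrolled updates.
        return loss_fn(model(x_embed), y) + alpha * adv_regularizer(model, x_embed, delta_K)
```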

Baselines
We adopt several baselines in the experiments.
BERT (Devlin et al., 2019) is a pre-trained language model that exhibits outstanding performance after being fine-tuned on downstream NLU tasks.
Adversarial training (Adv, Sato et al. 2019) in NMT can improve models' generalization by training the model to defend against adversarial attacks.
FreeAT (Shafahi et al., 2019) enables "free" adversarial training by recycling the gradient information generated when updating the model parameters. This method was proposed for computer vision tasks, but was later modified for NLU. We further adjust the algorithm for NMT tasks.

FreeLB (Zhu et al., 2019) is a "free" large-batch adversarial training method. We modify FreeLB into an adversarial regularization method that better fits our need. This algorithm was originally proposed for NLU; we modify it so that it is also suitable for NMT tasks.

SMART (Jiang et al., 2020) is a state-of-the-art fine-tuning method that utilizes smoothness-inducing regularization and Bregman proximal point optimization.

All the baseline results are from our re-implementation, and we report the mean of three runs.
We highlight that we focus on model generalization on clean data, instead of adversarial robustness (a model's ability to defend against adversarial attacks). As we will see in the experiments, adversarial training methods (e.g., Adv, FreeAT) suffer from label leakage, and do not generalize as well as adversarial regularization methods.

Neural Machine Translation
Datasets. We adopt three low-resource datasets and a rich-resource dataset. Dataset statistics are summarized in Table 1.

Results. Notice that Adv does not perform as well as the adversarial regularization methods, despite the effectiveness reported in Sato et al. (2019). This is because Adv generates perturbations using the correct examples; thus, the label information is "leaked" (Kurakin et al., 2017). Additionally, we can see that SALT is particularly effective in this low-resource setting, where it outperforms all the baselines by large margins. In comparison with the vanilla Transformer model, SALT achieves up to 2 BLEU score improvements on the three low-resource datasets.

Table 5 summarizes experiment results on the WMT'16 En-De dataset. We report the sacreBLEU (Post, 2018) score, which is a detokenized version of the BLEU score that better reflects translation quality. We can see that SALT outperforms all the baseline methods by notable margins, and it improves upon the vanilla Transformer model by 1.2 BLEU score.

Table 5: Results on the WMT'16 En-De dataset.
  Method                               BLEU
  Transformer (Vaswani et al., 2017)   28.4
  FreeAT (Shafahi et al., 2019)        29.0
  FreeLB (Zhu et al., 2019)            29.0
  SMART (Jiang et al., 2020)           29.1
  SALT                                 29.6

Natural Language Understanding

Results. Table 3 summarizes experiment results on the GLUE development set. We can see that SALT outperforms BERT-base on all the tasks. Further, our method is particularly effective for small datasets, such as RTE, MRPC, and CoLA, where we achieve 9.4, 4.3, and 6.3 absolute improvements, respectively. Compared with the other adversarial training baselines, i.e., FreeAT, FreeLB, and SMART, our method achieves notable improvements on all the tasks. We highlight that SALT achieves an 84.5 average score, which is significantly higher than that of the vanilla BERT-base (+3.0) fine-tuning approach. Also, our average score is higher than the scores of the baseline adversarial training methods (+1.9, +1.2, and +0.7 for FreeAT, FreeLB, and SMART, respectively). Moreover, the 84.5 average score is even higher than that of fine-tuning BERT-large (+0.5), which contains three times more parameters than the backbone of SALT. Table 4 summarizes results on the GLUE test set. We can see that SALT consistently outperforms BERT-base and FreeLB across all the tasks.

Parameter Study
Robustness to the number of unrolling steps. From Figure 1a, we can see that SALT is robust to the number of unrolling steps. As such, setting the unrolling steps K = 1 or 2 suffices to build models that generalize well.
Robustness to the perturbation strength. Unrolling is robust to the perturbation strength ε within a wide range, as indicated in Figure 1b. Meanwhile, the performance of SMART consistently drops when we increase ε from 0.01 to 0.5. This indicates that the unrolling algorithm can withstand stronger perturbations than conventional approaches.

$\ell_2$ constraints vs. $\ell_\infty$ constraints. Figure 1c illustrates model performance with respect to different perturbation strengths in the $\ell_\infty$ case. Notice that in comparison with the $\ell_2$ case (Figure 1b), SALT achieves the same level of performance, but the behavior of SMART is unstable. Additionally, SALT is stable within a wider range of perturbation strengths in the $\ell_2$ case than in the $\ell_\infty$ case, which is the reason that we adopt $\ell_2$ constraints in the experiments.
We highlight that SALT does not introduce any additional tuning parameters compared with conventional adversarial regularization approaches.

Analysis
Unrolling reduces bias. In Figure 3, we visualize the training and the validation errors on the STS-B and the SST datasets from the GLUE benchmark. As mentioned, conventional adversarial regularization suffers from over-strong perturbations, such that the model cannot fit the unperturbed data well. This is supported by the fact that the training loss of SALT is smaller than that of SMART, which means SALT fits the data better. SALT also yields a smaller loss than SMART on the validation data, indicating that the Stackelberg game-formulated model exhibits better generalization performance.
Adversarial robustness. Even though the primary focus of SALT is model generalization, we still test its robustness on the Adversarial-NLI (ANLI, Nie et al. 2020) dataset. The dataset contains 163k examples, which are collected via a human-and-model-in-the-loop approach.

Representation quality. We also train classifiers on downstream datasets while keeping the learned representations fixed. Such a method directly measures the quality of representations generated by different models. As illustrated in Fig. 4, SALT outperforms the baseline methods by large margins.
Classification Model Calibration. Adversarial regularization also helps model calibration (Stutz et al., 2020). A well-calibrated model produces reliable confidence estimates (i.e., the confidence matches the actual accuracy), where the confidence is defined as the maximum output probability calculated by the model. We evaluate the calibration performance of BERT-base, SMART, and SALT by the Expected Calibration Error (ECE, Niculescu-Mizil and Caruana 2005). We plot the reliability diagram (confidence vs. accuracy) on the SST task in Fig. 2 (see Appendix C for details). As we can see, BERT-base and SMART are more likely to make overconfident predictions. SALT reduces the ECE, and its corresponding reliability diagram aligns better with the perfect calibration curve.
Comparison with Unrolled-GAN. The unrolling technique has been applied to train GANs (Unrolled-GAN, Metz et al. 2017). However, subsequent works find that this approach does not necessarily improve training (Grnarova et al., 2018; Tran et al., 2019; Doan et al., 2019). This is because Unrolled-GAN unrolls its discriminator, which has a significant number of parameters. Consequently, the unrolling algorithm operates on a very large space, rendering the stochastic gradients used to update the discriminator considerably noisy. In SALT, the unrolling space is the sample embedding space, whose dimension is much smaller than the unrolling space of GANs. Therefore, unrolling is more effective for NLP tasks.

Conclusion
We propose SALT, an adversarial regularization method that employs a Stackelberg game formulation. Such a formulation induces a competition between a leader (the model) and a follower (the adversary). In SALT, the leader is in an advantageous position by recognizing the follower's strategy, and this strategic information is captured by the Stackelberg gradient. We compute the Stackelberg gradient, and hence find the equilibrium of the Stackelberg game, using an unrolled optimization approach. Empirical results on NMT and NLU tasks suggest the superiority of SALT over existing adversarial regularization methods.

Broader Impact
This paper proposes Stackelberg Adversarial Regularization (SALT), an adversarially regularized training framework for NLP tasks. Different from Generative Adversarial Networks (GANs), where the target is to attack existing neural network models or to improve models' robustness to adversarial attacks, we seek to improve the generalization performance of deep learning models. We demonstrate that the SALT framework can be used for neural machine translation and natural language understanding tasks. In all the experiments, we use publicly available data, and we build our algorithms using public code bases. We do not find any ethical concerns.

B.1 Neural Machine Translation
For low-resource translation, we set the tokens-per-GPU to 8,000, and we accumulate gradients for 2 steps. For rich-resource translation, we set the batch size to be equivalent to 450k tokens. In all the experiments, we constrain each perturbation according to its sentence-level $\ell_2$ norm, i.e., $\|\delta\|_2 \le \epsilon$. Other hyper-parameters are specified in Table 8.

Table 8: Hyper-parameters for machine translation. Here, σ is the standard deviation of the initial perturbations, ε is the perturbation strength, K is the number of unrolling steps, Beam is the size of beam search, and Len-Pen is the length penalty parameter during beam search.

  Dataset            Batch   lr (leader)  lr (follower)  σ          ε    K   Beam  Len-Pen
                     64k     1 × 10^-3    1 × 10^-4      1 × 10^-4  0.3  1   9     1.5
  Fr-En (IWSLT'16)   64k     1 × 10^-3    1 × 10^-5      1 × 10^-5  0.3  1   10    2.0
  En-De (WMT'16)     450k    1 × 10^-3    1 × 10^-4      1 × 10^-4  0.3  1   4     0.6

B.2 Natural Language Understanding
Details of the GLUE benchmark, including tasks, statistics, and evaluation metrics, are summarized in Table 7. We use Adam as both the leader's and the follower's optimizer, and we set β = (0.9, 0.98). The learning rate of the leader, lr_leader, is chosen from {5 × 10^-5, 1 × 10^-4, 5 × 10^-4}, and the follower's learning rate is chosen from {1 × 10^-5, lr_leader}. We choose the batch size from {4, 8, 16, 32}, and we train for a maximum of 6 epochs with early stopping based on the results on the development set. We apply gradient norm clipping with a threshold of 1.0. We set the dropout rate in the task-specific layers to 0.1. We choose the standard deviation of the initial perturbations σ from {1 × 10^-5, 1 × 10^-4}, and $\ell_2$ constraints with perturbation strength ε = 1.0 are applied. We set the number of unrolling steps K = 2. We report the best performance on each dataset individually.

C Model Calibration
Many applications require trustworthy predictions that need to be not only accurate but also well calibrated (Kong et al., 2020). A well-calibrated model is expected to output prediction confidence comparable to its classification accuracy. For example, given 100 data points with prediction confidence 0.6, we expect 60 of them to be correctly classified. More precisely, for a data point X, we denote by $Y(X)$ the ground-truth label, $\hat{Y}(X)$ the label predicted by the model, and $\hat{P}(X)$ the output probability associated with the predicted label. The calibration error of the predictive model for a given confidence $p \in (0, 1)$ is defined as:
$$\mathcal{E}_p = \Big| \mathbb{P}\big(\hat{Y}(X) = Y(X) \mid \hat{P}(X) = p\big) - p \Big|. \quad (8)$$
Since Eq. 8 involves population quantities, we usually adopt empirical approximations (Guo et al., 2017) to estimate the calibration error. Specifically, we partition all data points into 10 bins of equal size according to their prediction confidence. Let $B_m$ denote the bin with prediction confidence bounded between $l_m$ and $u_m$. Then, for any $p \in [l_m, u_m)$, we define the empirical calibration error as:
$$\hat{\mathcal{E}}_p = \bigg| \frac{1}{|B_m|} \sum_{i \in B_m} \mathbb{1}\{\hat{y}_i = y_i\} - \frac{1}{|B_m|} \sum_{i \in B_m} \hat{p}_i \bigg|,$$
where $y_i$, $\hat{y}_i$, and $\hat{p}_i$ are the true label, predicted label, and confidence for sample i.
Reliability Diagram is a bar plot that compares the accuracy of each bin against its confidence p. A perfectly calibrated model would have a per-bin accuracy of $(l_m + u_m)/2$ for each bin, i.e., the bars would align with the diagonal.
Expected Calibration Error (ECE) is the weighted average of the calibration errors of all bins (Naeini et al., 2015), defined as:
$$\mathrm{ECE} = \sum_{m=1}^{10} \frac{|B_m|}{n}\, \hat{\mathcal{E}}_{p_m},$$
where n is the sample size and $\hat{\mathcal{E}}_{p_m}$ is the empirical calibration error of bin $B_m$. We remark that the goal of calibration is to minimize the calibration error without significantly sacrificing prediction accuracy. Otherwise, a random-guess classifier can achieve zero calibration error.
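A small sketch of this computation (using equal-width bins, one common convention (Guo et al., 2017); the function name and NumPy-based interface are our own, for illustration):

```python
import numpy as np

def expected_calibration_error(confidences, predictions, labels, n_bins=10):
    """ECE: size-weighted average of per-bin |accuracy - mean confidence|."""
    confidences = np.asarray(confidences, dtype=float)
    correct = (np.asarray(predictions) == np.asarray(labels)).astype(float)
    edges = np.linspace(0.0, 1.0, n_bins + 1)
    n, ece = len(confidences), 0.0
    for lo, hi in zip(edges[:-1], edges[1:]):
        in_bin = (confidences > lo) & (confidences <= hi)
        if in_bin.any():
            # Empirical calibration error of this bin, weighted by its size.
            gap = abs(correct[in_bin].mean() - confidences[in_bin].mean())
            ece += (in_bin.sum() / n) * gap
    return ece
```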