ARCH: Efficient Adversarial Regularized Training with Caching

Adversarial regularization can improve model generalization in many natural language processing tasks. However, conventional approaches are computationally expensive, since they need to generate a perturbation for each sample in each epoch. We propose a new adversarial regularization method, ARCH (adversarial regularization with caching), where perturbations are generated and cached once every several epochs. As caching all the perturbations raises memory usage concerns, we adopt a K-nearest-neighbors-based strategy to tackle this issue. The strategy only requires caching a small number of perturbations, without introducing additional training time. We evaluate the proposed method on a set of neural machine translation and natural language understanding tasks. We observe that ARCH significantly eases the computational burden, saving up to 70% of computational time in comparison with conventional approaches. More surprisingly, by reducing the variance of the stochastic gradients, ARCH yields notably better (in most tasks) or comparable model generalization. Our code is available at https://github.com/SimiaoZuo/Caching-Adv.


Introduction
Adversarial regularization (Miyato et al., 2017) can improve model generalization in many natural language processing tasks, such as neural machine translation (Cheng et al., 2019), natural language understanding, language modeling (Wang et al., 2019b), and reading comprehension (Jia and Liang, 2017). Even though the method has demonstrated its power in many scenarios, its computational efficiency remains unsatisfactory.
Conventional adversarial regularization (Miyato et al., 2017) methods involve a min-max optimization problem. Specifically, a perturbation is generated for each sample by solving a maximization problem, and the model parameters are subsequently updated through a minimization problem, subject to the generated perturbations. A popular algorithm (Madry et al., 2018) for such optimization alternates between several projected gradient descent steps (PGD, for the maximization) and a gradient descent step (for the minimization).
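As a concrete illustration, the inner maximization can be approximated by a few projected gradient ascent steps on the perturbation. The sketch below is a minimal PyTorch version under an ℓ2 constraint; the helper name `pgd_perturbation` and the default step size are our own assumptions, not the paper's exact implementation.

```python
import torch

def pgd_perturbation(model, loss_fn, x, y, eps=0.1, eta=0.1, steps=3):
    """Approximate the inner maximization with `steps` projected
    gradient ascent updates on the perturbation delta.
    Hypothetical helper; `model` maps input embeddings to predictions."""
    delta = torch.zeros_like(x, requires_grad=True)
    for _ in range(steps):
        loss = loss_fn(model(x + delta), y)          # one extra forward pass
        grad, = torch.autograd.grad(loss, delta)     # one extra backward pass
        with torch.no_grad():
            delta += eta * grad                      # ascent step
            norm = delta.norm().clamp_min(1e-12)
            delta *= (eps / norm).clamp(max=1.0)     # project into the l2 ball
    return delta.detach()
```

Each loop iteration costs one extra forward and one extra backward pass, which is exactly the per-step overhead of alternating descent/ascent.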
There are two drawbacks to the alternating gradient descent/ascent method. First, the procedure requires significant computational effort. Suppose we run PGD for S steps; then we introduce S extra forward passes and S extra backward passes in each iteration. As such, training with adversarial regularization is significantly slower than standard training. Second, optimizing the min-max problem is hard. This is because the perturbations are model- and data-dependent, and thus their variance is large. That is, the model needs to adapt to drastically different "noisy data" (i.e., clean data with perturbations), such that the stochastic gradients vary significantly during training. Such large variance imposes optimization challenges.
We propose ARCH (Adversarial Regularization with CacHing), which alleviates the aforementioned issues by reusing perturbations. Recall that in conventional adversarial regularization methods, a different perturbation is generated for each sample in each epoch. In contrast, we propose to generate perturbations less frequently. For example, for a given sample, we can generate a new perturbation every 20 epochs, and the sample's perturbation remains unchanged in the other epochs. We call this method "caching". The method has two advantages. First, it alleviates the computational burden. By reusing the perturbations, we avoid the extra forward and backward passes caused by PGD in most iterations. Second, caching stabilizes the stochastic gradients. Notice that in our method, the model is optimized with respect to the same noisy data multiple times, instead of only once. In this way, the variance of the stochastic gradients is reduced.
One caveat of the caching method is its memory overhead. This is because a sample's perturbation is significantly larger than itself (the perturbation has an extra embedding dimension). We propose a K-nearest neighbors-based approach to tackle this problem. Specifically, instead of caching perturbations for all the samples, we only cache a small proportion of them. Each uncached perturbation can then be constructed using the cached ones in its neighborhood. Such a construction procedure can be executed in parallel with model training. Therefore, training time will not be prolonged because of this memory saving strategy.
We use a moving average approach to boost model generalization. Specifically, when generating a new perturbation, we integrate information from both the current model and the current perturbation. This is different from conventional approaches, where the new perturbation only depends on the current model. The moving average approach has a smoothing effect that boosts model generalization, as demonstrated both theoretically and empirically by previous works (Izmailov et al., 2018; Athiwaratkun et al., 2019). Arguably, the perturbations introduced by our method may not constitute strong adversarial attacks, because of the "staleness" caused by infrequent updates. However, we highlight that the focus of this work is model generalization over clean data, instead of adversarial robustness (the ability to defend against attacks). As we will demonstrate in the experiments, the "weak" perturbations yield notable improvements in model generalization. And somewhat surprisingly, ARCH also exhibits on-par or even better robustness compared with conventional approaches.
We conduct extensive experiments on neural machine translation (NMT) and natural language understanding (NLU) tasks. In comparison with conventional adversarial regularization approaches, ARCH saves up to 70% computational time. Moreover, in NMT tasks, our method improves over baseline methods by about 0.5 BLEU across seven datasets. ARCH also achieves a 0.7 average score improvement on the GLUE (Wang et al., 2019a) development set over existing methods.
We summarize our contributions as follows: (1) We propose a caching method that requires drastically less computational effort. The method also improves model generalization by reducing the variance of stochastic gradients. (2) We propose a memory saving strategy to efficiently implement the caching method. (3) Extensive experiments on neural machine translation and natural language understanding demonstrate the efficiency and effectiveness of the proposed method.

Background
Neural machine translation has achieved superior empirical performance (Bahdanau et al., 2015; Gehring et al., 2017; Vaswani et al., 2017). Recently, the Transformer (Vaswani et al., 2017) architecture has come to dominate the field. This sequence-to-sequence model employs an encoder-decoder structure and integrates the attention mechanism. During the encoding phase, a Transformer model first computes an embedding for each sentence, after which the embeddings are fed into several layers of encoding blocks. Each of these blocks contains a self-attention mechanism and a feed-forward neural network (FFN). Subsequently, after encoding, the hidden representations are fed into the decoding blocks, each consisting of a self-attention, an encoder-decoder attention, and an FFN.
Fine-tuning pre-trained language models (Peters et al., 2018; Devlin et al., 2019; Radford et al., 2019) is a state-of-the-art method for natural language understanding tasks such as the GLUE (Wang et al., 2019a) benchmark. Adversarial regularization has also been incorporated into the fine-tuning approach. For example, Liu et al. (2020a) combine adversarial pre-training and fine-tuning, Zhu et al. (2020) adopt trust region-based methods, and Aghajanyan et al. (2020) aim for a more efficient computation.
Adversarial training was originally proposed for computer vision tasks (Szegedy et al., 2014; Goodfellow et al., 2015; Madry et al., 2018), where the goal is to train robust classifiers. Such methods synthesize adversarial samples, such that the classifier is trained to be robust against them. This strategy is also effective for tasks beyond computer vision, such as reinforcement learning (Shen et al., 2020). Various algorithms have been proposed to craft the adversarial samples, e.g., learning-to-learn and Stackelberg adversarial training (Zuo et al., 2021). Moreover, adversarial training is also well-studied theoretically (Li et al., 2019). In natural language processing, the goal is no longer adversarial robustness; instead, we use adversarial regularization to boost model generalization. Note that adversarial training and adversarial regularization are different concepts. The former focuses on defending against adversarial attacks, and the latter focuses on encouraging smooth model predictions (Miyato et al., 2017). These two goals are usually treated as mutually exclusive (Raghunathan et al., 2020; Min et al., 2020).

Method
Generating perturbations for natural language inputs faces the difficulty of discreteness, i.e., words are defined in a discrete space. A common approach to tackle this is to work in the continuous embedding space (Miyato et al., 2017; Sato et al., 2019). Denote by f(x, θ) a neural network parameterized by θ, where x is the input embedding. Further denote by y the ground truth corresponding to x. For example, in classification tasks, x is the sentence embedding and y is its label. In sequence-to-sequence learning, x is the source sentence embedding and y is the target sentence. In both cases, the model is trained by minimizing the empirical risk over the training data, i.e.,

\min_\theta \frac{1}{n} \sum_{i=1}^n \ell(f(x_i, \theta), y_i).

Here {(x_i, y_i)}_{i=1}^n is the dataset, and \ell is a task-specific loss, e.g., cross-entropy loss for classification and mean-squared error for regression.
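As a small sketch of the objective above (the helper name is illustrative), the empirical risk is simply the average per-sample task loss:

```python
import torch

def empirical_risk(model, loss_fn, data):
    """Empirical risk: (1/n) * sum of per-sample losses.
    `data` yields (x, y) mini-batches; `loss_fn` is the task loss."""
    losses = [loss_fn(model(x), y) for x, y in data]
    return torch.stack(losses).mean()
```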
We consider the worst-case perturbation to encourage the model to make smooth predictions.
Specifically, at epoch t, we solve

\min_\theta \frac{1}{n} \sum_{i=1}^n \Big[ \ell(f(x_i, \theta), y_i) + \lambda \max_{\|\delta_i^t\| \le \epsilon} \ell(f(x_i + \delta_i^t, \theta), y_i) \Big]. \quad (1)

Here λ is the weight of the regularizer, ε is a pre-defined perturbation strength, and ‖·‖ is either the ℓ2 norm or the ℓ∞ norm. Notice that the perturbation δ_i^t of sample x_i is different in each epoch. The min-max optimization problem in Eq. 1 is notoriously difficult to solve. Previous works (Miyato et al., 2017; Sato et al., 2019; Zhu et al., 2020) employ variations of alternating gradient descent/ascent. That is, we first solve the maximization problem using several iterations of projected gradient ascent, and then we run a gradient descent step on the loss of the minimization problem, subject to the generated perturbations. These steps are repeated iteratively.
One major drawback of the alternating gradient descent/ascent approach is that the stochastic gradients are unstable. Specifically, the norms of the gradients vary significantly during training (Fig. 3). This is because the perturbations are generated based on the current model parameters, i.e., by maximizing \ell(f(x_i + \delta_i^t, \theta^t), y_i), where θ^t changes in each epoch. Therefore, the perturbations exhibit large variance. This causes instability of the stochastic gradients, because the model needs to adapt to drastically different adversarial directions (i.e., δ_i^t).

Adversarial Regularization with Caching
To alleviate the gradient instability problem, we propose to reuse the perturbations. Specifically, instead of optimizing with respect to different perturbations {δ_i^t}_{i=1}^n in each epoch, we optimize with respect to the same ones for several epochs.
Concretely, the training objective is now

\min_\theta \frac{1}{n} \sum_{i=1}^n \Big[ \ell(f(x_i, \theta), y_i) + \lambda\, \ell(f(x_i + \delta_i^t, \theta), y_i) \Big], \quad (2)

\text{where} \quad \delta_i^t = \begin{cases} \alpha\, \delta_i^{t-1} + (1-\alpha) \arg\max_{\|\delta\| \le \epsilon} \ell(f(x_i + \delta, \theta^t), y_i), & t \,\%\, T_c = 0, \\ \delta_i^{t-1}, & \text{otherwise}. \end{cases}

Here, % is the mod operator, and T_c is a pre-defined gap between re-computations of the perturbations. Notice that we use an exponential moving average (EMA) approach with parameter α when updating the perturbations. The EMA strategy integrates past information into the current epoch, and induces a smoothing effect that boosts model generalization. This strategy has demonstrated its effectiveness in many previous works (Izmailov et al., 2018; Athiwaratkun et al., 2019).
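The caching rule in Eq. 2 can be sketched as follows. This is a minimal sketch: the convex-combination weighting of the EMA and the `generate` callback (e.g., a few PGD steps) are our own assumptions about the implementation.

```python
import torch

def cached_perturbation(cache, i, epoch, T_c, alpha, generate):
    """Return sample i's perturbation at `epoch` (Eq. 2's caching rule).

    A fresh perturbation is generated only when epoch % T_c == 0 and is
    blended with the cached one via an exponential moving average with
    parameter alpha; in all other epochs the cached value is reused.
    """
    if epoch % T_c == 0 or i not in cache:
        delta_new = generate(i)                      # e.g., PGD on sample i
        if i in cache:
            delta_new = alpha * cache[i] + (1 - alpha) * delta_new
        cache[i] = delta_new
    return cache[i]
```

Between caching epochs, the lookup is free: no extra forward or backward passes are spent on the maximization problem.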
In comparison with Eq. 1, the formulation in Eq. 2 indicates that the perturbations are generated T/T_c times instead of T times when we train for T epochs. As such, the model is optimized with respect to the same {δ_i}_{i=1}^n for T_c epochs, instead of only one. In this way, the model can better adapt to the perturbed data, and thus the variance of the gradient norms is reduced. Intuitively, this is because optimization is more stable when the model is trained on the same data for multiple epochs, compared with training on different noisy data in each epoch. The algorithm implementing the caching strategy is summarized in Algorithm 1.
In conventional adversarial regularization (e.g., SMART), the perturbations are found by optimization algorithms such as projected gradient descent at every iteration. Recently, R3F (Aghajanyan et al., 2020) proposes to use random perturbations instead, i.e., it directly draws δ from a normal distribution, and the generalization of R3F can match SMART in some cases. However, because the random noise (as opposed to optimized perturbations) is not data-dependent, the generalization of R3F is subpar in some scenarios, e.g., machine translation (see our experiments). Our approach enjoys the advantages of both methods. Specifically, ARCH is efficient since it removes the maximization problem most of the time. Moreover, the perturbations generated by our method are informative, unlike those of R3F. Empirically, our proposed method is just as efficient as R3F, and somewhat surprisingly, we find that the generalization of ARCH can not only match, but even surpass conventional approaches in most tasks (see our experiments).

Memory Saving with KNN
One caveat of Algorithm 1 is the increased memory usage. For example, there are about 4.5 million sentence pairs in the WMT'16 En-De dataset, so that simply caching the adversarial samples takes about 100GB of memory. We propose a memory saving strategy based on K-nearest neighbors (KNN) to address this issue.
The idea is to only cache perturbations of some samples, and perturbations of the other samples are constructed using the cached ones on the fly.
[Algorithm 1 — inputs: T, the number of training epochs; T_c, the number of epochs between caching; α, the moving average parameter.]
Specifically, whenever t % T_c = 0, i.e., when we need to re-compute and re-cache the perturbations, we only cache δ_i^t such that i ∈ X. Here, X ⊂ {1, ..., n} is a pre-defined cache set and |X| ≪ n. This strategy significantly reduces memory overhead. Consequently, in each epoch t where t % T_c ≠ 0, perturbations δ_i^t with i ∈ X are directly retrieved from the cache, and perturbations δ_i^t with i ∈ {1, ..., n} \ X are defined as follows:

\delta_{i,\ell}^t = \frac{1}{K} \sum_{j \in \mathcal{K}_i} \frac{1}{\ell_j} \sum_{m=1}^{\ell_j} \delta_{j,m}^t, \quad \ell = 1, \dots, \ell_i. \quad (3)

Here, ℓ_i is the length of sentence x_i, δ_{i,ℓ}^t ∈ R^d is the perturbation for the ℓ-th word in sentence x_i, and K_i is the nearest neighbor set for x_i (which we present later). We remark that constructing the perturbations does not impose extra training time, because such computation can be performed in parallel with training.
We remark that every word of a sentence receives an identical perturbation in Eq. 3, i.e., δ_i^t ∈ R^{ℓ_i × d} has identical rows. We choose this design because a perturbation in the neighborhood of δ_i^t may have a different dimension, i.e., if δ_j^t ∈ R^{ℓ_j × d} is in the neighborhood of δ_i^t, it is possible that ℓ_i ≠ ℓ_j. To resolve this issue, we compute the word-level mean of all the perturbations in the neighborhood of δ_i^t and assign it to each row of δ_i^t. What remains is to find the K nearest neighbors in X for each sentence x_i with i ∈ {1, ..., n} \ X. Suppose we have a word embedding matrix W ∈ R^{d×|V|}, where |V| is the vocabulary size and d is the embedding dimension. Note that W can be obtained from pre-trained models such as BERT (Devlin et al., 2019). For each sentence x_i, we compute its sentence representation

s_i = \frac{1}{\ell_i} \sum_{\ell=1}^{\ell_i} W x_{i,\ell}.

Here, x_{i,ℓ} ∈ R^{|V|} is the one-hot vector of the ℓ-th word in sentence x_i. Then, we find the K nearest neighbors K_i for sample x_i using the KNN algorithm, where the distance between two samples is defined as their cosine similarity. Notice that finding {K_i}_{i=1}^n is a pre-processing step, i.e., we can find the neighbors before training the model.
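A minimal sketch of this memory-saving pipeline (sentence representations, cosine-similarity KNN, and the Eq. 3 construction). The function names and the (vocab × d) layout of the embedding matrix are our own conventions, not the paper's code.

```python
import torch

def sentence_rep(token_ids, W):
    """Sentence representation: mean of word embeddings.
    W has shape (vocab_size, d); indexing replaces the one-hot product."""
    return W[token_ids].mean(dim=0)

def knn_cosine(query, cached_reps, k):
    """Indices of the k cached sentences most cosine-similar to `query`."""
    sims = torch.nn.functional.cosine_similarity(
        query.unsqueeze(0), cached_reps, dim=1)
    return sims.topk(k).indices

def construct_perturbation(neighbor_deltas, length):
    """Eq. 3: average the word-level mean of each neighbor's perturbation,
    then copy the result to every position of the target sentence."""
    row = torch.stack([d.mean(dim=0) for d in neighbor_deltas]).mean(dim=0)
    return row.expand(length, -1).clone()
```

Because the neighbor sets are found in pre-processing and the construction involves only cheap tensor averages, it can run in parallel with training.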
The memory saving algorithm is summarized in Algorithm 2, and an extended version that combines caching and memory saving is presented in Algorithm 3 in the appendix.

Computational Efficiency
Computational costs of various methods are summarized in Table 1. In conventional adversarial regularization algorithms, such as FreeLB (Zhu et al., 2020) and SMART, if we solve the inner maximization problem for S steps, then we incur S extra forward passes and S extra backward passes in each iteration. In contrast, R3F (Aghajanyan et al., 2020) removes the maximization problem and directly samples perturbations from a normal distribution. Thus, R3F only introduces one extra forward pass to compute the regularization term. Using Algorithm 1, our method shares similar efficiency with R3F. Specifically, suppose we cache the perturbations every T_c epochs; then the average number of forward passes and backward passes per iteration is 2 + (S − 1)/T_c and 1 + S/T_c, respectively. In practice, S/T_c is usually small, such that the computational costs of ARCH and R3F are close. A wall-time comparison is illustrated in Fig. 1. Notice that in the left subfigure, both our method and R3F save about 70% computational time in comparison with FreeLB and SMART. In the right subfigure, the time saving is about 50%. The absolute time saving is more significant on large models and large datasets. For example, when training a Transformer-big model on the WMT'16 En-De dataset, our method costs about 176 GPU hours, while SMART uses 576 GPU hours.
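The per-iteration pass counts above can be checked with a few lines; this simply restates the amortization formulas from the text.

```python
def avg_passes(S, T_c):
    """Average forward/backward pass counts per iteration under ARCH:
    the S-step maximization runs only once every T_c epochs, so its
    cost is amortized over T_c."""
    forward = 2 + (S - 1) / T_c
    backward = 1 + S / T_c
    return forward, backward
```

With the NMT settings reported in the appendix (S = 3, T_c = 15), this gives about 2.13 forward and 1.2 backward passes per iteration, close to R3F's 2 and 1.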

[Table 2: BLEU scores on three low-resource datasets (En-Vi, Vi-En, En-De, De-En, En-Fr, Fr-En). All the baseline results are from our re-implementation. We report the mean over three runs using different random seeds. ARCH saves about 70% computational time compared with SMART.]
[Table 4: Dataset source and statistics. Here "k" stands for thousand, and "m" stands for million.]

Experiments
In all the experiments, we use PyTorch (Paszke et al., 2019, https://pytorch.org/) as the backend. All the experiments are conducted on NVIDIA V100 GPUs.

Baselines
We adopt several baselines in the experiments.
BERT (Devlin et al., 2019) exhibits outstanding performance when fine-tuned on natural language understanding tasks.
FreeAT (Shafahi et al., 2019) enables "free" adversarial training by recycling the gradient information generated when updating the model.
FreeLB (Zhu et al., 2020) treats the intermediate perturbations during the projected gradient ascent steps as virtual batches. As such, the method achieves "free" large-batch adversarial training.
SMART achieves state-of-the-art performance in natural language understanding. The method utilizes smoothness-inducing regularization and Bregman proximal point optimization.
R3F (Aghajanyan et al., 2020) replaces the maximization problem in conventional adversarial regularization with random noise.

Machine Translation
Datasets. We use three low-resource datasets: English-German from IWSLT'14, English-Vietnamese from IWSLT'15, and English-French from IWSLT'16. We also use a rich-resource dataset: English-German from WMT'16. Dataset statistics are summarized in Table 4.
Implementation. In NMT tasks, we have source-side and target-side inputs. We add perturbations to the embeddings of both (Sato et al., 2019), which has been demonstrated to be more effective than perturbing a single side. We use Fairseq (Ott et al., 2019) to implement our algorithms. For the En-Vi and En-Fr experiments, we use the Transformer-base architecture (Vaswani et al., 2017). For the En-De (IWSLT'14) experiments, we modify the Transformer-base architecture by decreasing the hidden dimension from 2048 to 1024 and decreasing the number of heads from 8 to 4 (while the dimension of each head doubles). For the En-De (WMT'16) experiments, we use the Transformer-big (Vaswani et al., 2017) architecture. The training details are presented in Appendix B.1.
Results. Experimental results on the low-resource datasets are summarized in Table 2. We can see that ARCH outperforms all the baselines in all the experiments. We remark that our method saves about 70% computational time in comparison with SMART and FreeLB, and has the same level of efficiency as R3F (Fig. 1). Even though R3F is efficient because it eliminates the maximization problem, it does not generalize as well as SMART, i.e., R3F has a worse BLEU score than SMART in 5/6 of the experiments. Experimental results on the WMT'16 En-De dataset are summarized in Table 3. We report both the BLEU score and the sacreBLEU (Post, 2018) score. The former is standard for machine translation tasks, and the latter is a detokenized version of BLEU. The absolute computational time saving is more significant for larger datasets (e.g., WMT) and larger models (e.g., Transformer-big). In these experiments, ARCH takes about 176 GPU hours to train, while SMART takes about 576 hours. Performance of ARCH is better than or on par with all the baselines. Notice that, as in Table 2, the performance of R3F is worse than that of SMART.

Natural Language Understanding
Datasets. We conduct experiments on the General Language Understanding Evaluation (GLUE) benchmark (Wang et al., 2019a), which is a collection of nine natural language inference tasks. Implementation. We implement our algorithm using the MT-DNN (Liu et al., 2019a, 2020b) and the Transformers (Wolf et al., 2020) code-bases. The training details are presented in Appendix B.2.
Results. Table 5 summarizes experimental results on the GLUE development set. We can see that ARCH is on par with or outperforms all the baselines in all the tasks. Notice that the generalization of R3F is comparable with SMART. Our proposed method shares the advantages of both: the efficiency of R3F and the informative perturbations of SMART, and thus ARCH behaves better than both of these methods. We highlight that our method is 50%-70% faster than SMART and FreeLB.

Parameter Study
Moving average helps. As indicated in Fig. 2a, without the exponential moving average, model performance drops about 0.3 BLEU. Also, the model is robust to the moving average parameter, as increasing it from 0.01 to 0.1 does not change model performance.
Number of epochs between caching is important. If we cache the perturbations too frequently (i.e., 5 in Fig. 2b), the model cannot adapt to the perturbations well; and if we cache the perturbations too infrequently (i.e., inf in Fig. 2b), staleness of the perturbations hinders model generalization.
Robustness to the number of neighbors. In Fig. 2c, notice that ARCH is robust to the number of neighbors. We also examine a variant of the KNN memory-saving strategy (R-1-NN): namely, in Algorithm 1, the nearest neighbor set K_i for sample x_i is randomly constructed instead of based on word embeddings. We can see that model performance drops, and the method also exhibits drastically larger variance.
Robustness to the number of cached samples. From Fig. 2d, notice that the model generalizes well even when caching only 1% of the perturbations (i.e., only 1400 samples for the IWSLT'14 De-En dataset). Moreover, the KNN memory-saving strategy does not hinder model performance, i.e., the BLEU score is consistent between caching all the samples and caching only 10% of them.
We highlight that in practice ARCH does not need much tuning, because the method is robust to the introduced hyper-parameters.

Analysis
Caching reduces gradient norm variance. As demonstrated in Fig. 3, the variance of the gradient norms is significantly reduced compared with SMART and R3F. This meets our expectation that by reusing perturbations, the model can better adapt to the noisy data (i.e., clean data with perturbations). Notice that R3F has even larger gradient norm variance than SMART, because R3F uses random noise instead of data-dependent perturbations.
Adversarial robustness. We remark that the focus of ARCH is model generalization. Nevertheless, we investigate model robustness on the Adversarial-NLI (ANLI, Nie et al. 2020) dataset. The dataset contains 163k samples, which are collected via a human-and-model-in-the-loop approach. Surprisingly, from Table 6, we can see that R3F and ARCH achieve on-par robustness with SMART. This indicates that reusing perturbations, or even constructing random perturbations, can increase robustness (relative to BERT) to the same level as computing optimized perturbations (i.e., SMART).
Probing experiments. We first fine-tune a BERT-base model on the SST-2 dataset using different methods, and then we freeze the representations and only tune a prediction head on other datasets. The probing method directly measures the quality of the representations generated by different models. As illustrated in Fig. 4, ARCH consistently outperforms the baseline methods.
[Figure 4: Probing experiments. Each violin plot is based on 10 runs with different random seeds.]

Conclusion
We propose a new caching method to speed up the training of neural models with adversarial regularization. By reusing the generated perturbations, our proposed method significantly amortizes the computational cost of the extra forward and backward passes at each iteration. Our thorough experiments show that the proposed method not only improves computational efficiency, but also reduces the variance of the stochastic gradients, which leads to better model generalization.

Broader Impact
This paper proposes a caching method to speed up adversarially regularized training for NLP tasks. Our proposed method provides a fundamental way to address the efficiency issue that commonly exists in conventional adversarial regularization methods. We use publicly available data to conduct neural machine translation and natural language understanding experiments. Our framework is built using public code bases. We do not find any ethical concerns.

[Algorithm 3, in the appendix, combines caching and memory saving: compute δ_i^t using K_i, C, and Eq. 3; perform a one-step gradient descent on Eq. 2 to update the model parameters; output the trained model.]

B.1 Machine Translation Experiments
For the low-resource experiments, we use a batch size of 64k tokens. For example, when running the experiments on 4 GPUs, we set the tokens-per-GPU to 8k, and we accumulate gradients for 2 steps. We use Adam (Kingma and Ba, 2015) as the optimizer, and we set (β_1, β_2) = (0.9, 0.98). The learning rate is set to 1 × 10^−3 in all the experiments. We choose the model with the best validation performance to test on the test set. Other training details are the same as Ott et al. (2019).
For the rich-resource experiments, we use a batch size of 450k tokens. That is, we set tokens-per-GPU to 7k with 8 GPUs, and we further accumulate gradients for 8 steps. We set the learning rate to 1 × 10^−3. For other training setups, please refer to Ott et al. (2018).
To implement our proposed method, we sample the initial perturbation from a uniform distribution. We use sentence-level ℓ2 constraints on the perturbations, and we set the perturbation strength ε = 0.1. We run a modified version of projected gradient ascent for 3 steps to compute the perturbations, and the learning rate is set to 0.1. Concretely, in each iteration to compute the perturbations, we apply the following update rule:

\delta \leftarrow \Pi_{\|\delta\|_2 \le \epsilon} \big( \delta + \eta \nabla_\delta \ell(f(x + \delta, \theta), y) \big),

where η is the learning rate and Π denotes the projection onto the ℓ2 ball. We set the number of epochs between caching to 15, and the exponential moving average parameter α = 0.01. We cache 10% of the perturbations, and we use the nearest neighbor (i.e., 1-NN) to construct uncached perturbations.
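The projected update above can be written compactly in code (a sketch; `grad` stands for the gradient of the loss with respect to the current δ, computed elsewhere):

```python
import torch

def pgd_step(delta, grad, eta, eps):
    """delta <- Pi_{||.||_2 <= eps}(delta + eta * grad): one gradient
    ascent step followed by projection onto the l2 ball of radius eps."""
    delta = delta + eta * grad
    norm = delta.norm().clamp_min(1e-12)
    return delta * (eps / norm).clamp(max=1.0)
```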
Inference settings are presented in Table 7.
[Table 7: inference settings (beam size and length penalty).]

B.2 Natural Language Understanding Experiments
Statistics and descriptions of the GLUE benchmark are summarized in Table 8. We fine-tune a pre-trained BERT-base model. For each task, we choose the batch size from {8, 16, 32, 64, 128}, and the learning rate from {5 × 10^−5, 8 × 10^−5, 1 × 10^−4, 2 × 10^−4}. We use a linear learning rate warm-up schedule for 10% of the training iterations. We set the dropout rate of the task-specific layer (i.e., the classification head) to 0.1, and the dropout rate of BERT is chosen from {0.0, 0.1}. We train the model for 10 epochs. We report the best performance on each dataset individually.
To implement the adversarial regularization method, we sample the initial perturbation from a normal distribution with mean 0 and standard deviation 10^−5. We use word-level ℓ∞ constraints, and the perturbation strength is set to 1.0. We run standard projected gradient ascent to compute the perturbations, where the number of steps is chosen from {1, 2}, and the learning rate is chosen from {10^−4, 10^−5}. Because of the limited number of training samples, we only cache the perturbations once for fine-tuning tasks. We refer to the MT-DNN code-base for other details.