Virtual Data Augmentation: A Robust and General Framework for Fine-tuning Pre-trained Models

Recent works have shown that powerful pre-trained language models (PLMs) can be fooled by small perturbations or intentional attacks. To solve this issue, various data augmentation techniques have been proposed to improve the robustness of PLMs. However, it is still challenging to augment semantically relevant examples with sufficient diversity. In this work, we present Virtual Data Augmentation (VDA), a general framework for robustly fine-tuning PLMs. Based on the original token embeddings, we construct a multinomial mixture for augmenting virtual data embeddings, where a masked language model guarantees semantic relevance and Gaussian noise provides augmentation diversity. Furthermore, a regularized training strategy is proposed to balance the two aspects. Extensive experiments on six datasets show that our approach is able to improve the robustness of PLMs and alleviate the performance degradation under adversarial attacks. Our code and data are publicly available at https://github.com/RUCAIBox/VDA.


Introduction
Recently, pre-trained language models (PLMs) such as BERT (Devlin et al., 2019) and RoBERTa have achieved remarkable success in various natural language processing (NLP) tasks (Rajpurkar et al., 2016; Zhou et al., 2020b). As a general and effective approach, fine-tuning PLMs on specific datasets has become the mainstream paradigm for developing NLP applications. Despite the success, researchers have found that PLMs can be easily fooled by adversarial attacks (Jin et al., 2020; Li et al., 2020b). Although the model is encapsulated in a black box, these attack strategies can detect the vulnerabilities of a PLM via intentional queries (He et al., 2021; Li et al., 2020a), and then add small perturbations (e.g., synonym substitution) to the input texts to mislead PLMs into incorrect predictions.
As found in previous works (Schmidt et al., 2018; Yin et al., 2019; Jiang et al., 2020), a possible reason for this vulnerability is that PLMs do not generalize well to the semantic neighborhood around each example in the representation space. To solve this issue, adversarial data augmentation (ADA) methods (Jia and Liang, 2017; Wang and Bansal, 2018; Michel et al., 2019) have been proposed, which revise original data to augment attack-related data for training. However, due to the discrete nature of language, it is challenging to generate semantically relevant and sufficiently diverse augmentations. Although attempts that leverage expert knowledge (Ren et al., 2019; Li et al., 2019b) and victim models (Jin et al., 2020; Li et al., 2020b) have achieved better performance, their generalizability and flexibility are highly limited.
Recently, virtual adversarial training (Miyato et al., 2017; Madry et al., 2018) has been applied to various NLP models to improve performance and robustness (Jiang et al., 2020); it usually generates gradient-based perturbations in the embedding space as virtual adversarial samples. However, it is hard to explicitly constrain the gradient-based perturbation within the same semantic space as the original sample. In addition, unlike attacks in computer vision (Zheng et al., 2016; Miyato et al., 2019), textual adversarial attacks are discrete (e.g., word replacement) and are hard to capture with gradient-based perturbations.
To solve these challenges, we propose Virtual Data Augmentation (VDA), a robust and general framework for fine-tuning pre-trained models. Our idea is to generate data augmentations at the embedding layer of PLMs. To guarantee semantic relevance, we consider a multinomial mixture of the original token embeddings as the augmented embedding for each position of the input. In the mixture, each token embedding is weighted according to its likelihood estimated by a masked language model conditioned on the input. To provide sufficient diversity, we further incorporate Gaussian noise into the multinomial mixture, which enhances the randomness of the augmentations. As shown in Figure 1, for a target token "good", we first predict the substitution probabilities of candidate tokens via a masked language model, then inject Gaussian noise to produce multiple multinomial mixtures. After that, we aggregate the candidate embeddings with the multinomial mixtures to generate new embeddings (virtual data embeddings) that replace the original embedding of "good".
There are two major advantages of our VDA approach. First, with the original token embeddings as the representation basis, the augmented embeddings stay close to the existing embeddings, which avoids unexpected drift of the semantic space. Second, with the injected Gaussian noise, we are able to generate diverse variations for augmentation. To preserve semantic relevance in the presence of the injected Gaussian noise, we further design a regularized training strategy that guides the learning of the augmented virtual data towards the original predictions of the PLM. In this way, our approach considers both semantic relevance and sufficient diversity. Besides, since VDA only revises the input embeddings, it is agnostic to downstream tasks, model architectures and learning strategies.
To evaluate the effectiveness of our proposed VDA framework, we conduct extensive experiments on six datasets. Results show that VDA can boost the robustness of all the baseline models without performance degradation. We also find that our approach can be further improved by combining it with traditional adversarial data augmentation.
Our contributions are summarized as follows:
• We propose a new data augmentation framework for resisting discrete adversarial attacks on PLMs, which is general enough to improve the robustness of various PLMs on downstream tasks.
• Our approach utilizes a masked language model with Gaussian noise to augment virtual examples for improving robustness, and also adopts regularized training to further balance semantic relevance and diversity.
• Extensive experiments on six datasets have demonstrated that the proposed approach effectively improves the robustness of PLMs, and can be further improved by combining it with existing adversarial data augmentation strategies.

Related Work
We review the related work in the following three aspects.

Adversarial Attack in NLP
Inspired by the success in computer vision (Goodfellow et al., 2015; Kurakin et al., 2017), adversarial attack on NLP tasks has become an emerging research topic in recent years (Gao et al., 2018; Yang et al., 2020; Chen et al., 2020). Early works usually adopt heuristic rules to revise the input text for producing adversarial samples, including character modification (Ebrahimi et al., 2018), synonym replacement (Alzantot et al., 2018), and word insertion or deletion. However, with the revolution of large-scale PLMs, these attack strategies can be defended against (Jones et al., 2020; Gui et al., 2021; Zhou et al., 2020a) to some extent. To attack PLMs, TextFooler (Jin et al., 2020) designs an attack algorithm that revises the input data, querying the PLM several times to find important words for replacement, which greatly reduces the accuracy of BERT. Following it, recent works (Li et al., 2020b; He et al., 2021) continuously improve the quality of the adversarial samples and the attack success ratio. In our approach, we consider improving the robustness of PLMs against these adversarial attack methods via a new fine-tuning framework, VDA.

Data Augmentation
Data augmentation has been extensively studied in NLP for improving robustness (Wang and Yang, 2015; Fadaee et al., 2017; Wei and Zou, 2019). Similar to adversarial attack, early works mostly use heuristic rules to revise the input data for augmentation, such as synonym replacement (Wang and Bansal, 2018), grammar induction (Min et al., 2020), and word insertion and deletion (Wei and Zou, 2019). With the development of text generation techniques, back translation (Xie et al., 2020; Ribeiro et al., 2018) and variational autoencoders (Li et al., 2019c) have been used to augment new data. Besides, a surge of works (Hou et al., 2018; Li et al., 2019a) focuses on augmentation for specific tasks with special rules or models. Although they perform well, these methods lose generality. In this paper, we propose a new data augmentation framework, VDA, that utilizes a masked language model with Gaussian noise to augment virtual examples for improving robustness. Our VDA is agnostic to downstream tasks, model architectures and learning strategies.

Virtual Adversarial Training
To improve the robustness of neural networks against adversarial examples, virtual adversarial training (VAT) (Miyato et al., 2015; Kurakin et al., 2017; Qin et al., 2019) has been widely used in computer vision. It formulates a class of adversarial training algorithms as solving a minimax problem, which can be solved reliably through multiple projected gradient ascent steps (Qin et al., 2019). Recently, VAT has shown its effectiveness on NLP tasks, where the gradient-based noise is able to improve the performance and smoothness of pre-trained models (Jiang et al., 2020). However, due to the discrete nature of language, it has been shown that VAT methods are not very effective in defending against adversarial attacks (Si et al., 2020; Li and Qiu, 2021).

Preliminary
This work seeks to improve the fine-tuning of pre-trained language models (PLMs), so that the fine-tuned model becomes more robust to data perturbations or attacks. Specifically, we take text classification as an example task to illustrate our approach, where a set of n labeled texts {(x_i, y_i)} is available. Each labeled text consists of a text x_i and a label y_i from the label set Y. We refer to the adversarial example generated from a text x_i as an adversarial text, denoted by x̃_i. The purpose of adversarial examples is to enhance the model's robustness against intentional data perturbations or attacks.
Let f denote a PLM parameterized by θ. Following (Jia and Liang, 2017; Michel et al., 2019), we incorporate adversarial examples to improve the fine-tuning of PLMs. To conduct the adversarial learning, we formulate the learning objective as follows:

L(θ) = Σ_{i=1}^{n} [ L_c(f(x_i), y_i) + (λ/m) Σ_{j=1}^{m} L_reg(f(x_i), f(x̃_i^(j))) ],    (1)

where m is the number of adversarial texts that we use, λ is a trade-off parameter, and L_c and L_reg minimize the classification loss and the prediction difference between original and adversarial texts, respectively.
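The objective above can be sketched in code. This is a minimal illustration, not the authors' implementation; it assumes the regularizer is the symmetric KL divergence between the original and augmented predictions described later in the paper, and the function name and arguments are our own.

```python
import torch
import torch.nn.functional as F

def vda_loss(logits_orig, logits_aug_list, labels, lam=1.0):
    """Sketch of the fine-tuning objective (Eq. 1): a classification loss on
    the original input plus a regularizer keeping predictions on the m
    augmented (virtual) inputs close to the original predictions."""
    loss_c = F.cross_entropy(logits_orig, labels)
    p = F.log_softmax(logits_orig, dim=-1)
    loss_reg = 0.0
    for logits_aug in logits_aug_list:
        q = F.log_softmax(logits_aug, dim=-1)
        # symmetric KL divergence between original and augmented predictions
        loss_reg = loss_reg + F.kl_div(q, p, log_target=True, reduction="batchmean") \
                            + F.kl_div(p, q, log_target=True, reduction="batchmean")
    loss_reg = loss_reg / max(len(logits_aug_list), 1)
    return loss_c + lam * loss_reg
```

When the augmented logits equal the original ones, the regularizer vanishes and the loss reduces to the plain cross-entropy term.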
For the PLM f, we assume that it has already been pre-trained on general-purpose large-scale text data, and we would like to fine-tune its parameters θ on a downstream task. PLMs are usually built on the multi-layer Transformer architecture, such as BERT (Devlin et al., 2019) and RoBERTa (Liu et al., 2019), which encode a sequence of tokens into a sequence of contextual representations. Here, we take the representation of the first token (i.e., [CLS]) as the input to the classifier, and optimize the classification performance with the cross-entropy loss.
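The [CLS]-based classification head can be sketched as follows; hidden size and label count are illustrative defaults, not values from the paper.

```python
import torch
import torch.nn as nn

class ClsHead(nn.Module):
    """Sketch of the classifier on top of a PLM encoder: the contextual
    representation of the first token ([CLS]) feeds a linear layer whose
    logits are trained with cross-entropy."""
    def __init__(self, hidden_size=768, num_labels=2):
        super().__init__()
        self.classifier = nn.Linear(hidden_size, num_labels)

    def forward(self, hidden_states):
        # hidden_states: (batch, seq_len, hidden) from the PLM encoder
        cls_repr = hidden_states[:, 0]       # representation of [CLS]
        return self.classifier(cls_repr)     # logits over the label set
```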

Our Approach
In this section, we describe our proposed framework Virtual Data Augmentation (VDA) for robustly fine-tuning PLMs. Our framework consists of two important ingredients, namely embedding augmentation and regularized training.

Embedding Augmentation
To improve the model robustness, a good adversarial example should adhere to the original semantic space, as well as incorporate sufficient variations in meanings. However, existing studies cannot make a good trade-off between the two aspects.
Considering this difficulty, we generate adversarial texts at the embedding layer of PLMs. For adversarial training, continuous embeddings are easier to optimize and can encode more semantic variations than discrete tokens. The key idea of embedding augmentation is inspired by the word replacement strategy in previous data augmentation methods (Kobayashi, 2018; Wei and Zou, 2019). Instead of selecting some tokens for replacement, we use an augmented embedding to replace the original contextual embedding of a specific token in the input sentence. To adhere to the original semantic space, the augmented embedding is derived as a probabilistic mixture of the embeddings of the vocabulary terms, where each term is weighted according to its substitution probability (i.e., the probability of replacing the original token with the candidate term) calculated by a masked language model (MLM).
To simplify the presentation, we only discuss the augmentation for a specific token w̃ from an input sentence S; the same procedure is applied to each position of S. Specifically, we utilize the MLM to evaluate the substitution probabilities of all terms in the vocabulary. For each chosen token, we predict the probability that it is replaced by each word in the whole vocabulary via the MLM, denoted as p(ŵ_i | S). We thus obtain the substitution probabilities of all terms as

p = [ p(ŵ_1 | S), ..., p(ŵ_V | S) ],    (2)

where V is the vocabulary size. Different from previous masked prediction (Devlin et al., 2019), we do not mask the chosen token but keep it in the input when computing the substitution probabilities. In this way, we aim to generate highly relevant embeddings for augmentation. This strategy is also efficient in practice, since it no longer performs the costly mask-and-completion operation for each token.
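The unmasked prediction step can be sketched as a single forward pass. `mlm` here is a stand-in for any model mapping token ids to vocabulary logits; the function name and interface are illustrative, not from the paper.

```python
import torch

def substitution_probs(mlm, input_ids):
    """Sketch of computing substitution probabilities in one pass: unlike
    standard masked prediction, the chosen token is NOT masked out, so a
    single MLM call yields p(w_hat_i | S) for every position at once.
    `mlm` maps token ids to vocabulary logits of shape (batch, seq_len, V)."""
    with torch.no_grad():
        logits = mlm(input_ids)              # no per-token mask-and-complete
    return torch.softmax(logits, dim=-1)     # each row sums to 1 over V
```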
To augment diverse virtual data embeddings, we further draw a random noise from the Gaussian distribution as

ε ∼ N(0, σ²),    (3)

where the randomness can be controlled by the standard deviation σ. By mixing the random noise with the substitution probabilities, we can produce multiple different probability distributions for each instance as

p̃(ŵ_i | S) = softmax( p(ŵ_i | S) + ε ).    (4)

Then, for each target token w̃, we obtain its substituted embedding by aggregating the token embedding matrix according to the noised substitution probabilities as

ê_w̃ = p̃ · M_E,    (5)

where M_E ∈ R^{V×d} is the token embedding matrix of the MLM. Note that by using the output of the MLM, our approach can augment more "real" embeddings from the semantic space spanned by the original token embeddings. Besides, mixing in Gaussian noise brings additional semantic diversity for augmentation.
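The noise-mixing and aggregation steps can be sketched as below. The renormalization via softmax is our reading of the mixing step and the function name is our own; treat this as a sketch under those assumptions rather than the paper's exact implementation.

```python
import torch
import torch.nn.functional as F

def augment_virtual_embeddings(probs, M_E, sigma=1e-2, num_aug=1):
    """Sketch of Eqs. 3-5: perturb the substitution probabilities with
    Gaussian noise and use the result as mixture weights over the token
    embedding matrix M_E (V x d), producing virtual data embeddings."""
    augmented = []
    for _ in range(num_aug):
        eps = torch.randn_like(probs) * sigma       # eps ~ N(0, sigma^2)
        noisy = F.softmax(probs + eps, dim=-1)      # Eq. 4: mixed distribution
        augmented.append(noisy @ M_E)               # Eq. 5: weighted aggregation
    return augmented                                # list of (batch, seq, d)
```

Each augmented embedding is a convex combination of rows of M_E, so it stays within the space spanned by the original token embeddings.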

Regularized Training
The above augmentation strategy is able to enhance semantic variations via continuous embeddings. However, the augmented data is likely to incorporate unexpected semantic drift in the representations. To further improve model robustness, instead of directly using the augmented embeddings as positive examples, we propose a regularized training strategy that prevents large changes between the predictions given real and augmented embeddings. Formally, given the original data point (E_i, y_i) and the augmented virtual data Ê_i, where E_i and Ê_i denote the original and augmented embeddings of the instance respectively, we set the regularization loss in Equation 1 as the symmetric KL divergence between the two predictions:

L_reg = KL( f(E_i) || f(Ê_i) ) + KL( f(Ê_i) || f(E_i) ).    (6)

This regularizer enforces the model f to produce similar scores for the original data and the augmented data, which lies in the semantic neighborhood of the original embeddings. Furthermore, we instantiate the classification loss in Equation 1 as

L_c = CE( f(E_i), y_i ),    (7)

where CE(·, ·) is the cross-entropy loss function, which can be changed according to the specific task.

Algorithm 1: The training procedure of VDA.
Input: training set {x_i, y_i}, PLM f with parameters θ, MLM, noise variance σ², number of augmentations k.
1: for each training epoch do
2:   for minibatch B ∈ {x_i, y_i} do
3:     Tokenize the input sentences in B into {w_1, ..., w_m}.
4:     Generate the substitution probabilities of all tokens via Eq. 2.
5:     for j = 1 ... k do
6:       Sample ε from N(0, σ²).
7:       Produce p̃(ŵ_i | S) using Eq. 4.
8:       Augment virtual data using Eq. 5.
9:       Optimize θ using Eq. 1.
10:     end for
11:   end for
12: end for
13: return θ

Overview and Discussion
In this part, we present the overview and discussions of our VDA approach.
Overview The overall framework of VDA consists of two important parts, namely embedding augmentation and regularized training. We present the overall training procedure in Algorithm 1. For embedding augmentation, we utilize the output of an MLM as the multinomial mixture to augment a new embedding for each token in the input sentence. It is called virtual data augmentation, since the augmented embeddings do not correspond to actual text or tokens, but to a probabilistic mixture of all the token embeddings. Then, for regularized training, we leverage the original predictions to guide the learning of the augmented embeddings, which reduces the influence of noisy or incorrect perturbations in the augmentations.

Discussion
In the machine learning literature (Schmidt et al., 2018; Yin et al., 2019), robustness corresponds to the ability to resist data drift, perturbation and attack. To improve robustness, a key point is that the model should generalize to the semantic neighborhood of the training instances (Schmidt et al., 2018). However, discrete augmentation methods (Wei and Zou, 2019; Wang and Bansal, 2018) (e.g., inserting, deleting or replacing tokens) do not provide good generalization ability for model optimization, while virtual adversarial training methods (Jiang et al., 2020) cannot well constrain the augmentations to the original semantic space. As a comparison, our approach utilizes the original token embeddings to augment new embeddings, so that the augmentations stay close to the existing embeddings in the same semantic space. For relevance, we adopt an MLM to generate the multinomial mixture according to the likelihood of each candidate given the input. For diversity, we inject Gaussian noise to enhance the randomness. To further balance the two aspects, we design a regularized strategy that guides the augmentation learning towards the original predictions. By only revising the embeddings, our approach is model-agnostic and domain-agnostic, and can be generally applied to various PLMs on different downstream tasks.

Experiment -Main Results
We demonstrate the effectiveness of VDA for fine-tuning PLMs on text classification tasks.

Dataset
We conduct experiments on the sentence classification task and the sentence-pair classification task. The dataset statistics are summarized in Table 1.

Sentence Classification
We use four sentence classification datasets for evaluation.
• IMDB: a binary document-level sentiment classification dataset of movie reviews.
• Yelp (Zhang et al., 2015): a binary sentiment classification dataset of Yelp reviews.
• AG's News (Zhang et al., 2015): a news topic classification dataset covering four types of news: World, Sports, Business, and Science.
• MR (Pang and Lee, 2005): a binary sentiment classification dataset based on movie reviews.

Sentence-Pair Classification
We also use two sentence-pair classification datasets for evaluation.
• QNLI (Demszky et al., 2018): a question-answering dataset consisting of question-paragraph pairs. The task is to determine whether the context sentence contains the answer to the question.
• MRPC (Dolan and Brockett, 2005): a corpus of sentence pairs with human annotations about the semantic equivalence.

Baselines
To evaluate the generalization of our framework, we implement VDA on the following models.
• FreeLB (Zhu et al., 2020) is an adversarial training approach for fine-tuning PLMs, which adds gradient-based perturbations to token embeddings. We implement it on BERT-Base.
• SMART (Jiang et al., 2020) is a robust and efficient computation framework for fine-tuning PLMs. Limited by GPU resources, we only implement its smoothness-inducing adversarial regularization on BERT-Base and remove the Bregman proximal point optimization.
• SMix (Si et al., 2020) applies mixup on the [CLS] tokens of the PLM to cover a larger attack space. We implement it on BERT-Base. For a fair comparison, we remove the adversarial data augmentation strategy here and leave it to Section 6.2.
• RoBERTa-Large (Liu et al., 2019) is a robustly optimized BERT model trained with more data and time. It has 24 layers, 1024 hidden units and 16 heads, totaling 355M parameters.

Evaluation Metrics
We set up various metrics for measuring accuracy and robustness. Original accuracy is the accuracy of models on the original test set. Attack accuracy is the accuracy on the test set after the adversarial attack; it is the core metric for measuring robustness, and larger attack accuracy reflects better robustness.

Table 3: Main results on the sentence-pair classification task. The subscript "VDA" denotes that the model is trained with our proposed VDA framework. The best results in each group are highlighted in bold.
In this paper, we adopt BERT-Attack (Li et al., 2020b) as the attack method, since it can generate fluent and semantically preserved samples. For the AG, MR, QNLI and MRPC datasets, we follow previous works (Jin et al., 2020; Li et al., 2020b) and randomly sample 1000 instances for robustness evaluation. For Yelp and IMDB, we randomly sample 300 instances, since the long sentences in these two datasets make attacks more time-consuming. Note that for the sentence-pair classification datasets (i.e., QNLI and MRPC), we attack the second sentence in evaluation. Besides, we also use the query number and the perturbed percentage per sample for evaluation. Under the black-box setting, queries to the target model are the only way for attack methods to obtain information. A larger query number indicates that the vulnerability of the target model is harder to detect, which reflects better robustness. The perturbed percentage is the ratio of the number of perturbed words to the text length; a larger percentage likewise indicates that the model is harder to attack successfully.
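The three robustness metrics can be computed as below; the per-sample record fields are illustrative names of our own, not from the paper or from any attack toolkit.

```python
def robustness_metrics(records):
    """Sketch of the robustness metrics: attack accuracy (fraction of
    samples still classified correctly after attack), average query
    number, and average perturbed percentage per sample."""
    n = len(records)
    attack_acc = sum(r["correct_after_attack"] for r in records) / n
    avg_queries = sum(r["num_queries"] for r in records) / n
    perturb_pct = 100.0 * sum(r["num_perturbed"] / r["num_words"] for r in records) / n
    return {"attack_acc": attack_acc,
            "avg_queries": avg_queries,
            "perturb_pct": perturb_pct}
```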

Implementation Details
We implement all baseline models based on HuggingFace Transformers (https://huggingface.co/transformers/), and their hyper-parameters are set following the suggestions of the original papers. For our proposed VDA, we reuse the same hyper-parameter settings as the original baseline models. All models are trained on a GeForce RTX 3090. For the hyper-parameters of VDA, the sampling number m is set to 1 and the learning rate to 1e-5. We use 5% of the steps to warm up PLMs during training. The variance of the Gaussian noise is mostly set to 1e-2 and tuned in {1e-3, 4e-3, 1e-2, 4e-2}; the weight λ is mostly set to 1.0 and tuned in {0.04, 0.1, 0.4, 1.0, 4.0}.

Main Results

Table 2 reports the evaluation results of our proposed VDA framework and the baseline models on the sentence classification datasets, and the results on the sentence-pair classification datasets are shown in Table 3. Based on these results, we can find:

First, FreeLB and SMART mostly outperform the BERT-base model on the original accuracy metric, but do not perform well on the robustness-related metrics, especially on the Yelp and IMDB datasets. These methods adopt gradient-based perturbations and smoothness-inducing regularization, respectively, which improve classification accuracy but may not be effective in defending against adversarial attacks. A potential reason is that textual adversarial attacks are discrete and cannot be captured by virtual adversarial training.
Second, SMix improves the robustness of BERT-base on all datasets, but does not perform well on original accuracy. It mixes hidden representations of the BERT-base model, which increases the coverage of the attack space for PLMs but may introduce noisy examples into the training data. Besides, RoBERTa-large outperforms all other baselines on both performance and robustness metrics. The reason is that RoBERTa-large is pre-trained on more data for more time, which directly improves its generalization and robustness to adversarial attack samples.

Finally, we compare our proposed framework with these baseline models. After being combined with VDA, we see a significant improvement on the robustness metrics on most datasets. Our VDA utilizes a masked language model to generate substitution probabilities and then adds Gaussian noise. In this way, we can augment diverse and semantically consistent examples, which improve the robustness of PLMs. Furthermore, most of the baseline models combined with VDA achieve a marginal improvement in original accuracy, which indicates that our approach can better balance the performance and robustness of PLMs. Among the datasets, VDA brings larger improvements on MRPC and QNLI. The reason may be that these two tasks are more difficult and require more training data; the virtual data augmented by our approach is semantically consistent and diverse, and hence more helpful for these tasks.

Experiment -Analysis and Extension
In this section, we continue to study and analyze the effectiveness of our proposed VDA.

Ablation and Variation Study
We devise four variations to explore the effectiveness of the key components of VDA. BERT+VDA− is the variation that removes the Gaussian noise in Eq. 4. BERT+CE_VDA replaces the symmetric KL divergence with a cross-entropy loss. BERT+Argmax and BERT+Sample adopt argmax and sampling operators, respectively, to select the substituted token according to the substitution probability. We conduct the experiments on the AG and QNLI datasets. As shown in Table 4, most of the variations perform better than BERT on the robustness metrics, since they all augment virtual data for improving robustness. Among them, BERT+VDA outperforms most of the variations on both accuracy and robustness metrics. It indicates that the Gaussian noise, the symmetric KL-divergence loss and the weighted aggregated embeddings are all useful for improving robustness while stabilizing accuracy. However, BERT+Argmax and BERT+Sample achieve better results than BERT+VDA on some metrics but cause a dramatic drop on others, which indicates that these two variations cannot balance the trade-off between accuracy and robustness well.
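The BERT+Argmax and BERT+Sample variations can be sketched as below; unlike the weighted mixture of VDA, each position gets a single hard substitute token id. The function is an illustration of the two selection operators, not the authors' code.

```python
import torch

def select_substitute(probs, strategy):
    """Sketch of the two hard-substitution variants: choose one substitute
    token id per position from the substitution distribution, which has
    shape (batch, seq_len, V)."""
    if strategy == "argmax":
        return probs.argmax(dim=-1)                      # deterministic pick
    if strategy == "sample":
        flat = probs.reshape(-1, probs.size(-1))
        ids = torch.multinomial(flat, num_samples=1)     # stochastic pick
        return ids.view(probs.shape[:-1])
    raise ValueError(f"unknown strategy: {strategy}")
```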

Virtual Data Augmentation with Adversarial Data Augmentation
Our proposed VDA is general and can be combined with other methods, including conventional adversarial data augmentation (ADA). In this part, we collect adversarial examples generated from the MR and MRPC training sets and add them to the corresponding original training sets. Then we test the accuracy and robustness of the BERT-base model and our VDA after training with the adversarial data. As shown in Table 5, although the augmented adversarial data improves the robustness of BERT, its original accuracy drops. The reason may be that there are noisy instances in the adversarial data. As a comparison, our proposed VDA can augment diverse and semantically consistent virtual data, which better balances accuracy and robustness. Besides, after combining with ADA, our VDA can be further improved on both accuracy and robustness metrics, which indicates that our approach is also compatible with ADA methods.

Hyper-parameter Analysis
Our framework includes a few hyper-parameters to tune. Here, we report the tuning results of two of them on the MR and MRPC datasets, i.e., the variance σ² of the Gaussian noise and the number of augmented virtual examples. We show the curves of original accuracy and attack accuracy in Figure 3. Our model achieves the best performance when the variance is near 0.05, which indicates that too small or too large a noise may hurt the quality of the augmented virtual data. Besides, our model achieves the best performance when the sampling number is near 3, which shows that augmenting 3 examples per sample is enough to improve robustness.

Performance Change during Regularized Fine-tuning
In this part, we investigate how the accuracy and robustness change during regularized fine-tuning with our VDA. We conduct experiments on the AG and MRPC datasets and report the original accuracy and attack accuracy metrics. As shown in Figure 4, the original and attack accuracy of the model improve as the number of training epochs increases. After reaching the optimal point, the accuracy and robustness start to fluctuate and even decrease to some extent; the reason may be that the model has overfitted. An interesting finding is that the optimal points of original accuracy and attack accuracy are usually not the same. A possible reason is that accuracy and robustness are not always consistent objectives for deep models. Besides, after being combined with our VDA, BERT is able to reach a better optimum with higher original and attack accuracy, which indicates that VDA is an effective regularization approach for BERT.

Conclusion
In this work, we proposed Virtual Data Augmentation (VDA), a framework for robustly fine-tuning pre-trained language models. It is a general framework, agnostic to downstream tasks, model architectures and learning strategies. In VDA, we augmented new embeddings by a weighted aggregation over the token embedding matrix according to a multinomial mixture distribution. To construct the mixture distribution, we utilized a masked language model to generate the substitution probabilities, which guarantees semantic consistency, and Gaussian noise to provide diversity. We also adopted a regularized training strategy to further enhance robustness. Extensive experiments on six datasets have demonstrated that the proposed approach can effectively improve the robustness of various PLMs.