Sharpness-Aware Minimization with Dynamic Reweighting

Deep neural networks are often overparameterized and may not easily achieve good generalization. Adversarial training has shown effectiveness in improving generalization by regularizing the change of loss under adversarially chosen perturbations. The recently proposed sharpness-aware minimization (SAM) algorithm conducts adversarial weight perturbation, encouraging the model to converge to a flat minimum. SAM finds a common adversarial weight perturbation per batch. Although per-instance adversarial weight perturbations are stronger adversaries and can potentially lead to better generalization performance, their computation requires a separate pass per instance, making them prohibitively expensive to use in SAM. In this paper, we tackle this efficiency bottleneck and propose sharpness-aware minimization with dynamic reweighting (δ-SAM). Our theoretical analysis shows that it is possible to approach the stronger, per-instance adversarial weight perturbations using reweighted per-batch weight perturbations. δ-SAM dynamically reweights the perturbation within each batch according to theoretically principled weighting factors, serving as a good approximation to per-instance perturbation. Experiments on various natural language understanding tasks demonstrate the effectiveness of δ-SAM.


Introduction
Although deep neural networks (DNNs) have demonstrated promising results in various fields such as natural language understanding (Devlin et al., 2019) and computer vision (Krizhevsky et al., 2012), they are often overparameterized and can easily overfit the training data (Zhang et al., 2021). Adversarial training has proven effective in improving both model generalization (Zhu et al., 2019; Zhang et al., 2020a) and adversarial robustness (Madry et al., 2018; Zhang et al., 2019). A general approach to adversarial training is to first augment the inputs with small perturbations that lead to the maximum possible change of loss, and then optimize the model parameters in the direction that minimizes this change.
Besides perturbing inputs, the recent work of sharpness-aware minimization (SAM; Foret et al. 2020) has further proposed to adversarially perturb model weights. The method works by first adversarially calculating a weight perturbation that maximizes the empirical risk and then minimizing the empirical risk of the perturbed model. It demonstrates improved model generalization across different datasets and models. In principle, each instance in a batch has its own worst-case weight perturbation, and the weight perturbations of different instances need to be calculated separately; they cannot be computed in a single forward/backward pass. This leads to a significant increase in computational and memory cost. To obtain a feasible algorithm, SAM approximates per-instance perturbations with a single per-batch perturbation, where the weight perturbation is calculated on the averaged loss of the batch and shared by all instances in the batch. However, as the per-batch perturbation represents the average of perturbations yielded by different instances, it is a weaker adversary than per-instance perturbations and may hinder the effectiveness of SAM.
In this paper, we study how to efficiently approximate per-instance weight perturbation for sharpness-aware minimization while maintaining a computational cost similar to per-batch perturbation. We first theoretically analyze the gradient posed by the optimization of per-instance perturbation, and find that it can be effectively approximated with a weighted-batch perturbation under some assumptions, where instances with a larger rate of gradient change are up-weighted. Based on this motivation, we propose sharpness-aware minimization with dynamic reweighting (δ-SAM). Specifically, we first estimate the Hessian and gradient norm of each instance by perturbing the model weights with random Gaussian noise. Next, δ-SAM dynamically reweights the loss within each batch of training instances, and then calculates a shared weight perturbation that maximizes the reweighted batch loss. Finally, we update the perturbed model on the original (unweighted) batch. Compared to SAM, δ-SAM only requires extra computation for estimating the rate of gradient change, which can be performed efficiently with three additional forward passes.
We evaluate δ-SAM on finetuning pretrained language models (PLMs). Experiments on the standard GLUE benchmark (Wang et al., 2018), self-supervised Semantic Textual Similarity (STS), and abstractive summarization tasks show that besides significantly outperforming base models, δ-SAM also consistently outperforms SAM with only 18% extra computational cost on average.
The main contributions of this paper are threefold. First, we analyze the training objective of per-instance weight perturbation and find that, under some assumptions, it can be approximated by a weighted-batch perturbation where instances are reweighted according to the rate of gradient change they induce. Second, we propose to use random perturbations as estimators to efficiently realize this weighting scheme. Third, we evaluate δ-SAM on a diverse set of datasets and find consistent improvements across the board.

Related Work
Model Generalization. Deep neural networks are often overparameterized and may suffer from poor generalization (Zhang et al., 2021). A lot of effort has been devoted to improving the generalization of neural models, leading to methods including data augmentation (Sennrich et al., 2016; Wei and Zou, 2019; Sun et al., 2020; Kumar et al., 2020; Thakur et al., 2021), regularization (Loshchilov and Hutter, 2018; Xuhong et al., 2018; Liang et al., 2021), and improved optimization processes (Izmailov et al., 2018; Mobahi et al., 2020; Heo et al., 2021). These methods consider different aspects of generalization and may be combined to achieve better performance. Among them, adversarial training (Goodfellow et al., 2014) has demonstrated its effectiveness in improving model generalization without the need for any extra data or external knowledge, and has been widely used to enhance NLP models (Zhu et al., 2019; Jiang et al., 2020a; Pereira et al., 2021; Li and Qiu, 2021). Adversarial training works by adversarially perturbing the input embedding and either minimizing the adversarial risk or regularizing the change of risk to be small. Specifically, FreeLB (Zhu et al., 2019) uses projected gradient descent (PGD; Madry et al. 2018) to generate adversarial perturbations on the input embedding, and recycles the computed gradients when updating model parameters in adjacent steps (Shafahi et al., 2019) to reduce the computational cost. TAT (Pereira et al., 2021) improves FreeLB by prioritizing the most frequently mispredicted classes in perturbation calculation. TAVAT (Li and Qiu, 2021) uses a token-level accumulated perturbation vocabulary to guide the initialization in PGD. However, these works only consider robustness with respect to the input feature representation, while we consider robustness with respect to all model weights.
Sharpness-Aware Minimization. Foret et al. (2020) leverage the correlation between flat minima and better model generalization and propose SAM for training deep neural models that are robust to adversarial weight perturbations. SAM has demonstrated effectiveness in tasks on both vision (Chen et al., 2022; Zheng et al., 2021) and language (Bahri et al., 2022) modalities. Several variants of SAM have been proposed to improve its efficiency or effectiveness. For efficiency, Brock et al. (2021) propose to speed up SAM by perturbing fewer instances in the batch; Du et al. (2022) introduce stochastic weight perturbation and sharpness-sensitive data selection to reduce the computational overhead. For effectiveness, Kwon et al. (2021) propose to adaptively set SAM's radius such that it is invariant to parameter scales; Zhuang et al. (2022) introduce a gradient ascent optimization step in the perturbed model's orthogonal direction to achieve better flatness. In contrast to the aforementioned studies, this paper, for the first time, tackles how to narrow the gap between per-instance and per-batch weight perturbation. Accordingly, we propose an efficient approximation to per-instance weight perturbation, which shows improved results on several NLP tasks while introducing little computational overhead.
Sharpness-Aware Minimization (SAM)

In this section, we briefly review the principle of SAM and discuss its limitations.
Literature has observed a direct correlation between flat minima and better model generalization, both empirically and theoretically (Keskar et al., 2016; Dziugaite and Roy, 2017; Li et al., 2018; Jiang et al., 2020b). To find a flat loss landscape, SAM (Foret et al., 2020) adversarially perturbs the model weights and optimizes the following min-max objective on a batch of size N:

min_w max_{∥ϵ∥₂≤ρ} (1/N) Σ_{i=1}^N l_i(w + ϵ),    (1)

where given the model weights w, the inner maximization seeks a perturbation ϵ with L₂-norm at most ρ that maximizes the empirical risk, and the outer minimization minimizes the empirical risk of the perturbed model. This training objective aims at finding model parameters whose neighborhood has uniformly low training loss. As finding the exact solution of ϵ is NP-hard, SAM estimates the solution ϵ* of the inner maximization with a single-step gradient ascent on the empirical risk of the batch:

ϵ* = ρ ∇l̄(w) / ∥∇l̄(w)∥₂,    where l̄(w) = (1/N) Σ_{i=1}^N l_i(w).

The outer minimization can be performed with a standalone optimizer (e.g., Adam; Kingma and Ba 2015). SAM roughly doubles the computational cost of training the model, requiring two forward and two backward passes for each batch. The SAM algorithm is outlined in Alg. 1.

Besides perturbing by batches, weight perturbation can also be performed on individual instances:

min_w (1/N) Σ_{i=1}^N max_{∥ϵ_i∥₂≤ρ} l_i(w + ϵ_i),    (2)

where each ϵ_i is calculated by single-step gradient ascent on the corresponding individual instance. This approach is similar to many adversarial training methods in NLP, such as VAT (Miyato et al., 2018) and FreeLB (Zhu et al., 2019), except that the perturbation is computed on the model weights instead of only the input embedding. We refer to the objectives of Eq. 1 and Eq. 2 as per-batch weight perturbation and per-instance weight perturbation, respectively. It is noted in the same paper by Foret et al. (2020) that per-instance weight perturbation produces a smaller test error and is a better predictor of model generalization. Despite its effectiveness, per-instance weight perturbation increases the computational and memory cost significantly, requiring 2N forward and 2N backward passes for a batch of size N. Because per-instance weight perturbation modifies all model weights independently, the perturbation for each individual instance needs to be computed on a distinct model copy. Therefore, per-instance weight perturbation can be computationally unaffordable for large-scale training.
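To make the two-pass structure concrete, below is a minimal PyTorch sketch of one SAM update (a simplified rendering of Alg. 1, not the authors' released code; `sam_step` and its signature are our own naming, and `loss_fn` is assumed to return a scalar batch loss):

```python
import torch

def sam_step(model, loss_fn, optimizer, x, y, rho=0.05):
    """One SAM update: ascend to w + eps*, take the gradient there, then step."""
    params = [p for p in model.parameters() if p.requires_grad]

    # First forward/backward pass: gradient of the batch loss at w.
    loss_fn(model(x), y).backward()
    grad_norm = torch.sqrt(sum((p.grad ** 2).sum() for p in params))
    # Single-step solution of the inner max: eps* = rho * grad / ||grad||_2.
    eps = [rho * p.grad / (grad_norm + 1e-12) for p in params]
    optimizer.zero_grad()

    with torch.no_grad():
        for p, e in zip(params, eps):
            p.add_(e)                      # move to w + eps*
    # Second forward/backward pass: gradient of the perturbed model.
    loss_fn(model(x), y).backward()
    with torch.no_grad():
        for p, e in zip(params, eps):
            p.sub_(e)                      # restore w
    optimizer.step()                       # descend at w using the perturbed gradient
    optimizer.zero_grad()
```

The per-instance variant of Eq. 2 would repeat this inner ascent once per instance, which is what makes it O(N) times more expensive.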

SAM with Dynamic Reweighting
In this paper, we seek to improve SAM with a stronger adversary for weight perturbation. As the per-batch weight perturbation adopted by SAM weakens the adversarial training, we propose a simple yet effective modification of SAM, δ-SAM (SAM with dynamic reweighting), that approximates per-instance weight perturbation without requiring much additional computational cost. We first motivate our reweighting approach with a theoretical analysis of how per-instance weight perturbation can be approximated, and then illustrate how δ-SAM is realized in implementation.

Theoretical Motivations
In this subsection, we motivate our dynamic reweighting approach by formally analyzing the training objective posed by per-instance perturbation, and show that it can be approximated with a weighted-batch perturbation, which motivates our δ-SAM algorithm.

Preliminary. We motivate our approach from the perspective of sharpness in SAM, which quantifies the flatness of the loss landscape as the increase of loss in the neighborhood region of the model weights. The sharpness under per-batch and per-instance weight perturbation is defined as:

R_batch = max_{∥ϵ∥₂≤ρ} (1/N) Σ_{i=1}^N [l_i(w + ϵ) − l_i(w)],
R_inst = (1/N) Σ_{i=1}^N max_{∥ϵ_i∥₂≤ρ} [l_i(w + ϵ_i) − l_i(w)].

Due to the non-shared ϵ_i, R_inst ≥ R_batch, suggesting a stronger regularization effect of R_inst. However, R_inst is expensive to compute, since the ϵ_i in the N inner maximization problems must be calculated by gradient ascent on N individual instances, for which O(N) backward passes through the network are needed. In the analysis below, we show how to approximate the stronger R_inst with a weighted per-batch weight perturbation.

We start by considering the second-order expansion of a general empirical risk l_i for instance i:

l_i(w + ϵ) ≈ l_i(w) + ϵ^⊺ ∇l_i(w) + (1/2) ϵ^⊺ H_i(w) ϵ.

To allow a tractable theoretical analysis, we assume that the Hessian is a low-rank, positive definite matrix H_i(w) = a_i ∇l_i(w) ∇l_i(w)^⊺ (a_i > 0). Then, we obtain the perturbations in R_batch and R_inst under the second-order approximation in closed form with one-step gradient ascent, to align with the practice of SAM:

ϵ = ρ ∇l̄(w) / ∥∇l̄(w)∥₂,    ϵ_i = ρ ∇l_i(w) / ∥∇l_i(w)∥₂,

where ∇l̄(w) = (1/N) Σ_{i=1}^N ∇l_i(w) is the average gradient of the batch.
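For intuition, plugging ϵ_i into the second-order expansion gives a closed form for each instance's contribution to R_inst:

ϵ_i^⊺ ∇l_i(w) = ρ ∥∇l_i(w)∥₂,    (1/2) ϵ_i^⊺ H_i(w) ϵ_i = (ρ² a_i / 2) ∥∇l_i(w)∥₂²,

so that R_inst ≈ (1/N) Σ_{i=1}^N [ρ ∥∇l_i(w)∥₂ + (ρ² a_i / 2) ∥∇l_i(w)∥₂²]. Instances with a large curvature scale a_i and a large gradient norm thus dominate the per-instance sharpness, which is exactly what the reweighting below up-weights.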
Training Objective. After R_batch or R_inst is obtained, SAM computes ∂R_batch/∂w or ∂R_inst/∂w and updates the model weights to minimize the loss. To aim for a more effective perturbation, we seek to align with the gradient of R_inst, which determines how the model weights would be updated under the strong per-instance adversary. We hope to update the model weights in a "similar" manner as the per-instance adversary, while not explicitly computing the expensive term R_inst. Here "similar" means that the cosine similarity between the gradient of R_inst and that of our new objective is positive.¹

We thereby first derive the gradient of per-instance weight perturbation. Specifically, after calculating ϵ_i in the inner maximization step (ϵ_i is not differentiated in the outer minimization), the gradient of per-instance perturbation is:

∂R_inst/∂w = (1/N) Σ_{i=1}^N [∇l_i(w + ϵ_i) − ∇l_i(w)] ≈ (1/N) Σ_{i=1}^N H_i(w) ϵ_i.    (3)

In this paper, we aim at finding a per-batch perturbation ϵ′ that produces a gradient whose direction is aligned with the per-instance gradient, so that the model weights will be updated similarly by gradient-based optimizers. Specifically, we consider a shared perturbation ϵ′ and R defined as

R := (1/N) Σ_{i=1}^N [l_i(w + ϵ′) − l_i(w)].

Compared to the ordinary R_batch, here we propose to use a different ϵ′. Our goal is that optimizing R also leads to a smaller R_inst; that is, we seek to find an ϵ′ such that

⟨∂R/∂w, ∂R_inst/∂w⟩ > 0.    (4)

An easy choice would be ϵ′ ∝ ∂R_inst/∂w. Because each H_i is positive definite under our assumptions, (1/N) Σ_{i=1}^N H_i(w) is also positive definite, and since ∂R/∂w ≈ (1/N) Σ_{i=1}^N H_i(w) ϵ′, we then have:

⟨∂R/∂w, ∂R_inst/∂w⟩ ∝ ⟨(1/N) Σ_{i=1}^N H_i(w) v, v⟩ > 0,    with v = ∂R_inst/∂w,

so Eq. 4 holds. Although ∂R_inst/∂w is the gradient that we aim to approximate and is not directly computable here, the key observation of Eq. 4 is that per-instance weight perturbation can be optimized by using a perturbation shared by all instances in the batch. Therefore, we attempt to perturb the model with only a (rough) estimation of ϵ′. The next challenge is how to efficiently derive such an estimation.
An important observation is that under our assumptions on H_i(w) and the use of second-order approximations, ∂R_inst/∂w can be calculated with a single backpropagation on a reweighted batch. To see this, note that by Eq. 3,

∂R_inst/∂w ≈ (1/N) Σ_{i=1}^N H_i(w) ϵ_i = (ρ/N) Σ_{i=1}^N a_i ∥∇l_i(w)∥₂ ∇l_i(w).

We thus define the weights g_i = a_i ∥∇l_i(w)∥₂, and the reweighted batch loss is:

l_rw(w) = (1/N) Σ_{i=1}^N g_i l_i(w).

Then, treating the g_i as constants, ∂l_rw/∂w ∝ ∂R_inst/∂w as defined in Eq. 3. Compared to per-batch SAM, this reweighting only requires a small extra computation cost for calculating the instance weights g_i. We introduce how to estimate these weights efficiently in the following section.

Implementation
In δ-SAM, the key implementation problem is how to efficiently estimate the instance weights g_i. We solve this problem by sampling random perturbations. Specifically, for a random perturbation r on the model weights, where r follows the Gaussian distribution N(0, σI), under the same assumptions as in Section 4.1 we have:

l_i(w + r) + l_i(w − r) − 2 l_i(w) ≈ r^⊺ H_i(w) r = a_i (r^⊺ ∇l_i(w))²,
l_i(w + r) − l_i(w − r) ≈ 2 r^⊺ ∇l_i(w),

with E_r[r^⊺ H_i(w) r] = σ tr(H_i(w)) = σ a_i ∥∇l_i(w)∥₂². Therefore, by sampling random perturbations and taking the expectation, we can get unbiased estimations of a_i and the gradient norm ∥∇l_i(w)∥₂. Each estimation takes three forward passes for calculating l_i(w), l_i(w + r), and l_i(w − r). As we do not need to save the intermediate states for backpropagation (no_grad in PyTorch), these forward passes are faster than normal ones. For the efficiency of the algorithm, we only sample one (shared) r ∼ N(0, σI) for each batch in δ-SAM, and then calculate the instance weights g_i by:

g_i = (l_i(w + r) + l_i(w − r) − 2 l_i(w)) / (|l_i(w + r) − l_i(w − r)| + η),    (5)

where η is a hyperparameter for avoiding division by zero. After deriving the instance weights, the weighted-batch weight perturbation can be computed by:

l_rw(w) = (1/N) Σ_{i=1}^N g_i l_i(w),    (6)
ϵ* = ρ ∇l_rw(w) / ∥∇l_rw(w)∥₂.    (7)

We hereby summarize our algorithm, as outlined in Alg. 1. Modifications made for δ-SAM are highlighted in blue. Given a batch B, we first dynamically reweight the instances by Eq. 5, then estimate the perturbation ϵ* that maximizes the reweighted loss by a single-step gradient ascent as shown in Eq. 6 and Eq. 7, and finally minimize the empirical risk of the perturbed model on the original (unweighted) batch.
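Below is a minimal PyTorch sketch of one δ-SAM step under the equations above, assuming `criterion` returns per-instance losses (reduction='none'); the function names, the non-negativity guard on the curvature estimate, and the mean-one normalization of the weights are our own choices, so the released implementation may differ in these details:

```python
import torch

@torch.no_grad()
def estimate_instance_weights(model, criterion, x, y, sigma=1.0, eta=1e-3):
    """Estimate g_i from three forward passes with one shared Gaussian probe r
    on the weights (Eq. 5). criterion must return per-instance losses."""
    params = [p for p in model.parameters() if p.requires_grad]
    r = [sigma * torch.randn_like(p) for p in params]

    l0 = criterion(model(x), y)                  # l_i(w)
    for p, n in zip(params, r):
        p.add_(n)
    lp = criterion(model(x), y)                  # l_i(w + r)
    for p, n in zip(params, r):
        p.sub_(2 * n)
    lm = criterion(model(x), y)                  # l_i(w - r)
    for p, n in zip(params, r):
        p.add_(n)                                # restore w

    curv = (lp + lm - 2 * l0).clamp_min(0)       # ~ r^T H_i r (our non-negativity guard)
    grad_proj = (lp - lm).abs()                  # ~ 2 |r^T grad l_i|
    g = curv / (grad_proj + eta)                 # Eq. 5
    return g * len(g) / (g.sum() + 1e-12)        # normalize to mean 1 (our choice)

def delta_sam_step(model, criterion, optimizer, x, y, rho=0.05):
    """One δ-SAM update: perturb on the reweighted batch, update on the original."""
    params = [p for p in model.parameters() if p.requires_grad]
    g = estimate_instance_weights(model, criterion, x, y)

    # Inner max (Eq. 6-7): one ascent step on the reweighted batch loss.
    (g * criterion(model(x), y)).mean().backward()
    norm = torch.sqrt(sum((p.grad ** 2).sum() for p in params))
    eps = [rho * p.grad / (norm + 1e-12) for p in params]
    optimizer.zero_grad()

    with torch.no_grad():
        for p, e in zip(params, eps):
            p.add_(e)                            # move to w + eps*
    criterion(model(x), y).mean().backward()     # outer grad on the unweighted batch
    with torch.no_grad():
        for p, e in zip(params, eps):
            p.sub_(e)                            # restore w
    optimizer.step()
    optimizer.zero_grad()
```

Note that the three probe passes run under no_grad and require no backward computation, which is why the estimation adds only a small fraction of the cost of SAM's two full forward/backward passes.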

Baseline Methods
We compare SAM and δ-SAM to the following baseline methods, which were all proposed for improving the generalization of PLMs:

• R-Drop (Liang et al., 2021) enforces the predictions of the same instance augmented by different dropout masks to be similar via a consistency term (KL divergence for classification and mean squared error for regression), which leads to improved performance on various language and vision tasks.
• R3F (Aghajanyan et al., 2020) also uses a consistency term to make the predictions of the same instance similar. Besides augmenting the instances with different dropout masks, it further adds random uniform or normal noise to the input embedding in PLMs. Therefore, R3F can be regarded as an extension of R-Drop.
• FreeLB (Zhu et al., 2019) adversarially perturbs the token embedding using a multi-step projected gradient descent (PGD; Madry et al. 2018) to maximize the empirical risk and regularizes the adversarial risk to be small.
• SMART (Jiang et al., 2020a) is a framework that combines multiple techniques for improving model generalization, including adversarial training as well as improved optimizer and regularization techniques. In terms of adversarial training, it perturbs the input embedding with PGD to maximize the empirical risk. It then uses a consistency term to regularize the change of risk to be small, where the consistency term is defined as the KL divergence for classification and mean squared error for regression.

GLUE Tasks
Task Setup. We first evaluate δ-SAM on the GLUE benchmark (Wang et al., 2018). In this experiment, we use both BERT-base and RoBERTa-large as the encoders. To ensure a fair comparison, for task-specific hyperparameters including batch sizes, optimizers, learning rates, training steps, weight decay, dropout rates, and learning rate scheduling, we strictly replicate the values from R-Drop (Liang et al., 2021). For SAM and δ-SAM, we search ρ in {0.01, 0.02, 0.05} and η in {1e-4, 2e-4, 5e-4, 1e-3}. For σ in the random perturbations, we find that rescaling the random Gaussian perturbation to an L₂-norm of ρ achieves promising results, so we simply set σ = 1 and rescale the random perturbation afterwards. Following the evaluation settings of R3F and R-Drop, we report the best result on the development set out of 5 runs of training with different random seeds.
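For reference, this rescaling amounts to a few lines (a sketch; `sample_probe` is our naming):

```python
import torch

def sample_probe(params, rho):
    """Draw one Gaussian probe over all trainable parameters and rescale
    it to a fixed L2-norm of rho, so sigma needs no separate tuning."""
    raw = [torch.randn_like(p) for p in params]
    norm = torch.sqrt(sum((n ** 2).sum() for n in raw))
    return [n * rho / (norm + 1e-12) for n in raw]
```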
Results and Discussion. Results are shown in Table 1. We observe that on average, SAM improves BERT-base and RoBERTa-large by 1.0% and 0.7%, respectively, showing that SAM improves the generalization of PLMs, consistent with the findings of recent work (Bahri et al., 2022). However, its performance is still worse than the other compared methods. On the other hand, δ-SAM improves BERT-base and RoBERTa-large by 1.8% and 1.2%, respectively, and also achieves better or comparable results compared to the other methods, demonstrating its effectiveness. In terms of individual tasks, the performance gain of δ-SAM over SAM is larger on smaller datasets (e.g., MRPC, RTE, CoLA, SST2), while it becomes less prominent on larger datasets. We hypothesize that due to the increased training steps and number of instances in large datasets, the gap between per-batch and per-instance perturbation becomes smaller. It is also possible that smaller datasets need better generalization, so δ-SAM helps more. Besides, we observe that the improved performance and generalization of δ-SAM come at a modest average extra computational cost of 18% relative to SAM (see Table 6 in the Appendix for the running time of models).
Taking BERT-base and the SST2 dataset as an example, the average running time is 118/132 min for SAM/δ-SAM, respectively, meaning that δ-SAM is only 12% slower than SAM while approximating per-instance perturbation.

Self-supervised STS
Model. To conduct the self-supervised STS evaluation, we apply δ-SAM to the training process of Mirror-BERT-base and Mirror-RoBERTa-base (Liu et al., 2021), which are state-of-the-art self-supervised sentence embedding frameworks. Similar to SimCSE (Gao et al., 2021), Mirror-BERT embeds a sentence x with the same encoder under different dropout masks to get two sentence embeddings h₁ and h₂, and optimizes h₁ and h₂ to be similar using a contrastive loss. This training objective resembles R-Drop and SMART. Empirically, we find that applying adversarial training (FreeLB, SAM, and δ-SAM) to only one embedding (e.g., h₂ only) achieves much better results than applying it to the contrastive loss, and we use this strategy in our experiments.

Summarization
Task Setup. We experiment with the abstractive summarization task on the CNN/DailyMail dataset (Hermann et al., 2015). Using the large version of BART (Lewis et al., 2020) as the encoder-decoder model, we compare δ-SAM to two regularization methods, R3F and R-Drop, that have been evaluated on this task. Besides, we also compare to PEGASUS (Zhang et al., 2020b), which introduces a self-supervised training objective specifically designed for summarization. Following the experimental settings of BART, we report the unigram ROUGE-1 and bigram ROUGE-2 metrics for evaluating informativeness, and the longest-common-subsequence ROUGE-L for evaluating fluency.
Results and Discussion. Results are shown in Table 3. We observe that δ-SAM achieves the best results in ROUGE-1 and ROUGE-L, outperforming the original BART by 0.54% and 0.91%, respectively. As for ROUGE-2, it also outperforms BART by 0.26% and achieves performance comparable to R3F and R-Drop. This experiment shows the effectiveness of δ-SAM in optimizing an encoder-decoder model for abstractive summarization.

Analysis
The previous experiments have demonstrated that δ-SAM achieves promising improvements on various tasks. In this section, we analyze whether δ-SAM derives a smaller adversarial risk and how well it approximates per-instance weight perturbation. We assess the adversarial risks and accuracies of four optimization approaches: vanilla training, SAM, δ-SAM, and per-instance weight perturbation. We measure the adversarial risk with Eq. 2 in §3, which is copied as follows:

L_adv = (1/N) Σ_{i=1}^N max_{∥ϵ_i∥₂≤ρ} l_i(w + ϵ_i).

For all compared methods, we set ρ = 0.05 in L_adv.
Due to the high computational cost of per-instance weight perturbation (about 7x that of δ-SAM with a batch size of 16), we only conduct experiments on two small datasets: MRPC and RTE.
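The measurement itself can be sketched as follows (our code, assuming `criterion` returns per-instance losses; the O(N) loop over instances is what makes this evaluation expensive):

```python
import torch

def adversarial_risk(model, criterion, x, y, rho=0.05):
    """Average per-instance adversarial risk (Eq. 2): each instance receives
    its own one-step weight perturbation; evaluation only."""
    params = [p for p in model.parameters() if p.requires_grad]
    risks = []
    for i in range(x.shape[0]):                        # one instance at a time
        xi, yi = x[i:i + 1], y[i:i + 1]
        criterion(model(xi), yi).mean().backward()
        norm = torch.sqrt(sum((p.grad ** 2).sum() for p in params))
        eps = [rho * p.grad / (norm + 1e-12) for p in params]
        model.zero_grad()
        with torch.no_grad():
            for p, e in zip(params, eps):
                p.add_(e)                              # w + eps_i*
            risks.append(criterion(model(xi), yi).mean().item())
            for p, e in zip(params, eps):
                p.sub_(e)                              # restore w
    return sum(risks) / len(risks)
```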
From the results shown in Table 4, we observe that δ-SAM achieves a smaller adversarial risk than SAM, showing that δ-SAM is indeed a better approximation to per-instance weight perturbation than SAM, consistent with our theoretical motivation (§4.1). When it comes to accuracy, we observe that: (1) per-instance weight perturbation generally achieves the highest accuracy except for the maximum accuracy on RTE, consistent with the observation in Foret et al. (2020); (2) although δ-SAM consistently outperforms SAM, its performance is often slightly lower than that of the much more costly per-instance weight perturbation, indicating room for further improvement.

Conclusion
This paper presents a new sharpness-aware minimization method with dynamic reweighting (δ-SAM). The proposed method represents the first successful attempt at realizing a per-instance weighting scheme. We achieve this by prioritizing instances with a larger rate of gradient change in adversarial weight perturbation, in contrast to previous approaches that adopt per-batch weight perturbation. We show that a perturbation calculated on the reweighted batch can serve as a better approximation to per-instance weight perturbation while requiring only a similar computational cost to per-batch perturbation. We conduct extensive experiments on the GLUE, STS, and abstractive summarization benchmarks. Across all 30 experimental setups that compare to SAM, δ-SAM achieves a consistent improvement over SAM in 27 of them. When compared to a set of other competitive regularization methods, δ-SAM achieves the best performance in 23 out of 33 setups. Further, we quantitatively analyze δ-SAM's impact on sharpness, finding that it indeed leads to a flatter loss landscape. Future work includes inventing new techniques to further reduce the computational cost of δ-SAM and demonstrating its effectiveness on more tasks such as sequence tagging and question answering.

Limitations
Like SAM and other training methods based on weight perturbation, the improved performance of δ-SAM comes at the cost of additional computational overhead relative to vanilla training. Specifically, although δ-SAM more precisely approximates per-instance weight perturbation with merely 18% extra computational cost over per-batch SAM, both SAM and δ-SAM are still roughly twice as expensive as vanilla training in practice. This may limit the application of such optimization algorithms to massive-scale training.

Figure 1:
δ-SAM performs close to the computation-intensive per-instance weight perturbation while adding only marginal computation overhead to the standard per-batch SAM. Results shown are from the MRPC dataset. See §5.5 for more detailed results.

Table 1:
Results on the development set of the GLUE benchmark. * denotes results derived from the model intermediately trained on the MNLI dataset (not comparable to other results), while others are derived by finetuning the original BERT/RoBERTa. The results of BERT-base are from the reimplementation by Liang et al. (2021).

Table 2:
Results on self-supervised STS tasks. For all datasets, we report the average Spearman's ρ of 5 runs of training using 5 fixed random seeds.

Table 3:
Results on CNN/Daily Mail summarization.

Table 4:
Adversarial risk and evaluation results on the MRPC and RTE datasets. We report the average adversarial risk and the median/max accuracy of 5 runs.