Aligned Weight Regularizers for Pruning Pretrained Neural Networks

Pruning aims to reduce the number of parameters while maintaining performance close to the original network. This work proposes a novel self-distillation based pruning strategy, whereby the representational similarity between the pruned and unpruned versions of the same network is maximized. Unlike previous approaches that treat distillation and pruning separately, we use distillation to inform the pruning criteria, without requiring a separate student network as in knowledge distillation. We show that the proposed cross-correlation objective for self-distilled pruning implicitly encourages sparse solutions, naturally complementing magnitude-based pruning criteria. Experiments on the GLUE and XGLUE benchmarks show that self-distilled pruning increases mono- and cross-lingual language model performance. Self-distilled pruned models also outperform smaller Transformers with an equal number of parameters and are competitive against distilled networks 6 times their size. We also observe that self-distillation (1) maximizes class separability, (2) increases the signal-to-noise ratio, and (3) converges faster after pruning steps, providing further insight into why self-distilled pruning improves generalization.


Introduction
Deep neural networks (DNNs) have grown increasingly large in recent years. This has led to models requiring more storage, more resources for training and inference (e.g., GPUs and TPUs), longer compute times and larger carbon footprints. This is largely due to the rise of masked self-supervised learning (SSL), which trains DNNs (e.g., Transformers in NLP) on a large collection of samples that do not have task labels but instead use a subset of the inputs as labels. Given the aforementioned challenges, it has become more difficult for machine learning practitioners to use these SSL-pretrained models for fine-tuning on downstream tasks. While training tricks such as effective batch sizes, gradient accumulation and dynamic learning rate schedules (Howard and Ruder, 2018) have improved the efficiency of fine-tuning DNNs under resource constraints, they can still come at a cost, e.g., gradient accumulation leads to fewer updates.
Pruning (LeCun et al., 1990; Reed, 1993) is a type of model compression method (Buciluǎ et al., 2006) that aims to address these shortcomings by zeroing out a subset of weights in the DNN while maintaining performance close to the original model. Retraining is often carried out directly after each pruning step to recover from pruning-induced performance drops; this process is referred to as iterative pruning. Although iterative pruning has been extensively studied in the SSL setting (Hassibi and Stork, 1993; Han et al., 2016; Ding et al., 2018) and the transfer learning setting (Molchanov et al., 2016; Gordon et al., 2020; Sanh et al., 2020), little is known about pruning DNNs in the zero-shot setting, where a model is required to make predictions on samples from classes that are unobserved during training. One salient example is pretrained cross-lingual language models (XLMs) (Conneau and Lample, 2019; Conneau et al., 2020a), which are trained with a masked/translation language model (MLM/TLM) objective to predict tokens for a large set of languages; this objective forces the XLM to learn similar representations across languages. After cross-lingual pretraining, the model is fine-tuned on a downstream task in one language (e.g., English) and then evaluated on other languages in the zero-shot setting (e.g., Spanish, French, Chinese, etc.). In this context, applying current pruning methods can damage the cross-lingual alignment that has been learned during pretraining. Ideally, we would prune XLMs in a way that avoids this alignment distortion, which affects the zero-shot performance of pruned XLMs. Additionally, overfitting to the language used for fine-tuning becomes more of an issue as parameters are progressively removed throughout iterative pruning, since the remaining weights grow relatively large and move away from the "aligned" XLM state. This is an important problem to address, as applying large pretrained models in the zero-shot setting for natural language and other modalities (e.g., images and audio) is of practical importance: e.g., using XLMs in production for multiple languages while requiring annotations in only a single language for fine-tuning, making predictions on unseen classes at test time from pretrained visual representations (Bucher et al., 2017) using only semantic descriptions (i.e., label similarity to known classes), or zero-shot predictions with pretrained multimodal models such as CLIP (Radford et al., 2021).
Hence, this work addresses the alignment-distortion pruning problem by introducing AlignReg, a class of weight regularizers for magnitude-based pruning that force pruned models to have parameters that point in a similar direction, or have a similar distribution, to the parameters of the original pretrained network. To our knowledge, this is the first study of how iteratively pruned models perform in the zero-shot setting and of how the resulting solutions differ from those found in the non-zero-shot setting. We believe our findings have strong practical implications, as well-established pruning criteria may not be suitable given the observed discrepancy between zero-shot performance and the typically reported non-zero-shot performance. Moreover, our proposed weight regularizer improves overall pruning generalization in zero-shot cross-lingual transfer. Below, we summarize our contributions.
• The first analysis of pruning cross-lingual language models, how this affects zero-shot cross-lingual transfer, and how performance differs from pruning in the SSL setup.
• A weight regularizer that mitigates alignment distortion by minimizing the layer-wise Frobenius norm or unit similarity between the pruned model and unpruned model, avoiding overfitting to single language task fine-tuning.
• A post-analysis of weight distributions after pruning and how they differ across module types in Transformers.

Related Work
Below we describe regularization-based pruning, importance-based pruning, and how masked language modeling (MLM) implicitly learns to align cross-lingual representations.
Regularization-based pruning. Pruning can be achieved by using a weight regularizer that encourages network sparsity. Three well-established regularizers are L0 (Louizos et al., 2018), L1 regularization (Liu et al., 2017; Ye et al., 2018) and the commonly used L2 regularization for weight sparsity (Han et al., 2015, 2016). Wang et al. propose an L2 regularizer whose influence increases throughout retraining and show that this increasing regularization improves pruning performance. For structured pruning, where whole blocks of weights are removed, Group-wise Brain Damage (Lebedev and Lempitsky, 2016) and SSL (Wen et al., 2016) use Group LASSO (Yuan and Lin, 2006) to learn structured solutions.
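To make the form of these sparsity regularizers concrete, a minimal PyTorch-style sketch of an L1/L2 penalty added to the task loss is given below; the function name and the rule for selecting prunable tensors are our own illustrative assumptions rather than the formulation of any cited work.

```python
import torch


def sparsity_penalty(model: torch.nn.Module, norm: int = 1) -> torch.Tensor:
    """L1 (norm=1) or squared-L2 (norm=2) penalty over weight matrices."""
    terms = []
    for param in model.parameters():
        if param.dim() >= 2:  # skip biases and LayerNorm parameters (1-D tensors)
            terms.append(param.abs().sum() if norm == 1 else param.pow(2).sum())
    return torch.stack(terms).sum()


# typical use inside a training step, alongside the task loss:
# loss = task_loss + lam * sparsity_penalty(model, norm=1)
```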
Importance-based pruning. Magnitude-based pruning (MBP) relies on the assumption that weight or gradient magnitudes correlate with a weight's importance to the overall output of the network. Mozer and Smolensky instead use a learnable gating mechanism that approximates layer importance, finding that weight magnitudes reflect importance statistics. To measure weight importance as the difference in loss between the pruned and unpruned network, LeCun et al. approximate this difference with a Taylor series up to the second order: the first term involves the product of the gradient and the weight magnitude, and the second term involves an approximation of the Hessian and the square of the weight magnitude. However, computing the Hessian, and even its approximations (LeCun et al., 1990; Hassibi and Stork, 1993; Dong et al., 2017; Wang et al., 2019; Singh and Alistarh, 2020), can significantly slow down retraining. In our work, we avoid computing the Hessian or approximations thereof, as this is not scalable for models such as XLM-R (Conneau et al., 2020a). Park et al. extend MBP to block approximations to avoid pruning low-magnitude weights that may be connected to high-magnitude weights in adjacent layers. Lee et al. provide a method to automatically choose the sparsity of each layer by using a rescaled version of the weight magnitude that incorporates the model-level distortion incurred by pruning.
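For reference, the first-order term of this Taylor approximation (the gradient-weight product mentioned above) can be scored per weight as in the short sketch below; this is a generic illustration under our own naming, not code from the cited works.

```python
import torch


def first_order_importance(weight: torch.Tensor) -> torch.Tensor:
    """First-order Taylor importance |dL/dw * w| per weight.

    Assumes a backward pass has already populated weight.grad.
    """
    assert weight.grad is not None, "run loss.backward() before scoring"
    return (weight.grad * weight).abs()
```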

Implicit Alignment in Pretrained MLMs
In the context of multi-task learning, Chen et al. (2020) minimize the mean squared error between pretrained weights and the weights being learned for a set of different source tasks to avoid catastrophic forgetting in the continual learning setting. Conneau et al. (2020b) find that multilingual MLM (i.e., training with an MLM objective on concatenated text from multiple languages) naturally leads to models with strong cross-lingual transfer capabilities. They also find that this holds for monolingual models that do not share vocabulary across monolingual corpora; the only requirement is weight sharing in the top layers of the multilingual encoder. In the context of our work, we want to bias our fine-tuned and iteratively pruned model to have similar geometric properties and symmetries to these pretrained MLMs in order to preserve zero-shot cross-lingual transfer.

Methodology
In this section, we describe how our proposed AlignReg weight regularizers can improve pruning performance in both the supervised learning and zero-shot pruning settings. We focus on two regularizers, namely a neuron correlation-based regularizer (cosine-MBP) and a Frobenius layer-norm regularizer (frobenius-MBP). Let $\{(X_i, y_i)\}_{i=1}^{D}$ be the training data, where each of the $D$ training samples $X_i$ consists of a sequence of vectors $X_i := (x_1, \ldots, x_n)$ with $x_i \in \mathbb{R}^d$ (e.g., $d = 512$). For structured prediction (e.g., NER and POS), $y_i \in \mathbb{R}^{n \times c}$, and for single and pairwise sentence classification, $y_i \in \mathbb{R}^{c}$, where $c$ is the number of classes. Let $\theta = (\theta_1, \ldots, \theta_L)$ be the parameters of a pretrained network $f$ with $L$ layers, where $\theta_l$ refers to the parameters, including the weight matrix $W_l$ and bias $b_l$, at layer $l$. Let $f_{\tilde{\theta}}$ be a pruned network with parameters $\tilde{\theta}$ consisting of weights $\tilde{W}_l \in \mathbb{R}^{N_{l-1} \times N_l}$ and biases $\tilde{b}_l \in \mathbb{R}^{N_l}$, where $N_l$ is the number of units in the $l$-th layer. Here, $\tilde{W}_l := W_l \odot M_l$, where $M_l$ is the pruning mask. For MBP (Karnin, 1990) we remove the weights of $W_l$, $\forall l \in L$, with the smallest absolute magnitude until a specified percentage $p$ of weights is removed. Note that this is a layer-wise process and requires the pruned weights to be masked with $M_l$, which has 0 entries for weights to be pruned and 1 entries for unpruned weights. Global MBP can also be used, whereby the weights $\{W_l\}_{l=1}^{L}$ are first vectorized and concatenated prior to choosing the $p$ lowest weight magnitudes; unlike layer-wise MBP, the percentage of weights removed in each layer can then vary. Typically, weight regularization is used with MBP to encourage weight sparsity. The objective for iterative pruning can thus be expressed as

$$\mathcal{L}(\tilde{\theta}) = \ell_{ce}\big(f_{\tilde{\theta}}(X), y\big) + \lambda \sum_{l=1}^{L} \|\tilde{W}_l\|_2^2, \quad (1)$$

where $\lambda$ controls the influence of the weight magnitude regularization. We now describe our proposed AlignReg.
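The layer-wise and global MBP variants described above can be sketched as follows; this is a simplified PyTorch illustration under our own function names, where `weights` is a dictionary of 2-D weight tensors and `p` is the fraction of weights to remove.

```python
import torch


def layerwise_mbp_masks(weights: dict, p: float) -> dict:
    """Layer-wise MBP: zero the p fraction of smallest-magnitude entries per tensor."""
    masks = {}
    for name, w in weights.items():
        k = int(p * w.numel())
        if k == 0:
            masks[name] = torch.ones_like(w)
            continue
        threshold = w.abs().flatten().kthvalue(k).values
        masks[name] = (w.abs() > threshold).float()
    return masks


def global_mbp_masks(weights: dict, p: float) -> dict:
    """Global MBP: one threshold over all layers, so per-layer sparsity can vary."""
    all_mags = torch.cat([w.abs().flatten() for w in weights.values()])
    k = max(int(p * all_mags.numel()), 1)
    threshold = all_mags.kthvalue(k).values
    return {name: (w.abs() > threshold).float() for name, w in weights.items()}


# pruned weights are then the Hadamard product W~_l = M_l * W_l:
# pruned = {name: masks[name] * w for name, w in weights.items()}
```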

AlignReg - Pruning-Aware Regularization
AlignReg can be used to align weights unit-wise or layer-wise between unpruned and pruned networks. We initially discuss the cosine-MBP regularizer.
cosine-MBP aims to preserve the inherent cross-lingual alignment during iterative pruning by minimizing the angle between parameter vectors of the same unit in the pruned and unpruned network. The intuition is that cross-lingual alignment relies more on the direction of parameter vectors than on their magnitudes; moreover, as the network is pruned, the remaining weights change magnitude to account for the information loss. To apply AlignReg to linear layers within Transformers, we compute the pairwise cosine similarity between pruned weights $\tilde{W}_l \subset f_{\tilde{\theta}}$ and unpruned weights $W_l \subset f$ for all layers $l$. For $W_l \in \mathbb{R}^{N_{l-1} \times N_l}$ of the $l$-th layer, the average weight correlation is

$$\rho(W_l, \tilde{W}_l) = \frac{1}{N_l} \sum_{i=1}^{N_l} \frac{W_{l,i}^{\top} \tilde{W}_{l,i}}{\|W_{l,i}\|_2 \, \|\tilde{W}_{l,i}\|_2}, \quad (2)$$

where $W_{l,i}$ is the $i$-th column of the matrix, corresponding to the $i$-th unit of the $l$-th layer. Intuitively, $\rho(W_l, \tilde{W}_l)$ is the average cosine similarity between weight vectors of the same unit at the $l$-th layer of the pruned and unpruned network. Adding AlignReg to the objective results in

$$\mathcal{L}(\tilde{\theta}) = \ell_{ce}\big(f_{\tilde{\theta}}(X), y\big) - \lambda \sum_{l=1}^{L} \rho(W_l, \tilde{W}_l), \quad (3)$$

where $\lambda \in [0, \infty)$ controls the importance of AlignReg relative to the main cross-entropy loss $\ell_{ce}(\cdot, \cdot)$. The gradient of the loss with respect to $\tilde{\theta}$ then follows from Equation (3) as

$$\frac{\partial \mathcal{L}}{\partial \tilde{W}_{l,(\cdot,j)}} = \frac{\partial \ell_{ce}}{\partial \tilde{W}_{l,(\cdot,j)}} - \frac{\lambda}{N_l}\left(\frac{W_{l,(\cdot,j)}}{\|W_{l,(\cdot,j)}\|_2 \, \|\tilde{W}_{l,(\cdot,j)}\|_2} - \rho_j \frac{\tilde{W}_{l,(\cdot,j)}}{\|\tilde{W}_{l,(\cdot,j)}\|_2^2}\right), \quad (4)$$

where $W_{l,(\cdot,j)}$ and $\tilde{W}_{l,(\cdot,j)}$ are the $j$-th columns of $W_l$ and $\tilde{W}_l$, respectively, and $\rho_j$ is the cosine similarity of that column pair. Thus, this regularization favors solutions with high cosine similarity between units of the pruned and unpruned networks. We also consider a layer-wise $\rho(W, \tilde{W})$ that relaxes the unit-level alignment to whole layers. This is partially motivated by the fact that neural networks can exhibit similar output activation behavior even when neuron weights have been permuted within a layer (Brea et al., 2019). To do this, we simply apply Equation (2) to vectorized weights, $\rho(\mathrm{vec}(W_l), \mathrm{vec}(\tilde{W}_l))$, and apply the corresponding partial derivatives when updating $\tilde{W}_l$. In our experiments we did not see a significant difference when using vectorized weights and thus use unit-wise cosine similarity.

Algorithm 1 shows how AlignReg is applied for a single mini-batch update during an iterative pruning epoch.

Algorithm 1: AlignReg Pruning
1: Input: weight tensors $W_1, \ldots, W_L$ of a fine-tuned network; percentage $p$ of weights to remove per layer
2: Output: pruned weight tensors $\tilde{W}_1, \ldots, \tilde{W}_L$
3: for $l = 1, \ldots, L$ do
4:   Compute $\rho(W_l, \tilde{W}_l)$ with Eq. (2).
5:   Set the mask $M_l$ to zero out the $p$ smallest-magnitude entries of $W_l$.
6:   Set $\tilde{W}_l \leftarrow M_l \odot W_l$.
7: end for
8: Compute $\mathcal{L}(\tilde{\theta})$ according to Eq. (3) and update $\tilde{W}_l$ via Eq. (4).
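A PyTorch-style sketch of a single AlignReg (cosine-MBP) mini-batch step, following Equations (2)-(3) and Algorithm 1, is given below. The function names, the use of a loss closure, and the parameter-selection rule are our own simplifying assumptions, not the authors' released implementation.

```python
import torch
import torch.nn.functional as F


def alignreg_cosine(w_pruned: torch.Tensor, w_ref: torch.Tensor) -> torch.Tensor:
    """rho(W_l, W~_l) of Eq. (2): mean cosine similarity between corresponding
    unit (column) weight vectors of the pruned and reference layer."""
    return F.cosine_similarity(w_pruned, w_ref, dim=0).mean()


def alignreg_step(model, ref_weights, p, lam, compute_task_loss):
    """One mini-batch step of AlignReg pruning (cf. Algorithm 1), layer-wise.

    ref_weights: fixed unpruned reference weight tensors, keyed by parameter name.
    compute_task_loss: closure that runs the forward pass and returns ell_ce.
    """
    # (1) layer-wise magnitude pruning: zero the p smallest-magnitude entries
    with torch.no_grad():
        for name, param in model.named_parameters():
            if name not in ref_weights:
                continue
            k = max(int(p * param.numel()), 1)
            thresh = param.abs().flatten().kthvalue(k).values
            param.mul_((param.abs() > thresh).float())  # W~_l = M_l * W_l

    # (2) Eq. (3): task loss minus lambda times the summed alignment term
    reg = sum(alignreg_cosine(param, ref_weights[name])
              for name, param in model.named_parameters() if name in ref_weights)
    return compute_task_loss() - lam * reg
```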

Relaxing Unit-Wise AlignReg To A Layer-Wise Frobenius Distortion Formulation

Thus far we have described the use of cosine similarity as a measure of similarity between the unpruned and pruned weights of the same units. However, this may be too strict a constraint, particularly at high compression rates where the remaining weights of a unit are forced to take higher norms to compensate for the zeroed weights. Hence, an alternative is a layer-wise Frobenius norm regularizer (Frobenius-MBP) based on the difference between weights, $\|W - \tilde{W}\|_F$. MBP itself can be viewed in terms of minimizing the Frobenius distortion (Han et al., 2016; Dong et al., 2017) as

$$\min_{M: \|M\|_0 = p} \|W - M \odot W\|_F, \quad (5)$$

where $\odot$ is the Hadamard product, $\|\cdot\|_0$ denotes the entrywise 0-norm, and $p$ constrains the number of weights to remove as a percentage of the total number of weights in that layer. In the zero-shot setting, we also need to account for out-of-distribution Frobenius distortions, such as the alignment distortion in XLMs caused by pruning and by overfitting to a single language. Taking the view of Frobenius distortion minimization when using our weight regularizer, we reformulate the objective to include Frobenius-MBP as

$$\mathcal{L}(\tilde{\theta}) = \ell_{ce}\big(f_{\tilde{\theta}}(X), y\big) + \lambda \sum_{l=1}^{L} \|W_{T,l} - \tilde{W}_l\|_F, \quad (6)$$

where $W_T$ are the weights of the pretrained model prior to fine-tuning, which are cross-lingually aligned by the masked language modeling (MLM) pretraining objective. In our experiments, $\lambda = 5 \times 10^{-4}$.
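The corresponding Frobenius-MBP penalty, pulling the pruned fine-tuned weights towards the cross-lingually aligned pretrained weights $W_T$, can be sketched as below; `pretrained_weights` is assumed to be a snapshot of the model taken before fine-tuning, and the helper name is ours.

```python
import torch


def alignreg_frobenius(model, pretrained_weights: dict, lam: float = 5e-4) -> torch.Tensor:
    """Sum of layer-wise Frobenius distances ||W_T,l - W~_l||_F between the
    pretrained weights and the current (pruned) weights, scaled by lambda."""
    reg = 0.0
    for name, param in model.named_parameters():
        if name in pretrained_weights:
            reg = reg + torch.linalg.norm(param - pretrained_weights[name])
    return lam * reg


# usage when retraining after a pruning step:
# loss = task_loss + alignreg_frobenius(model, pretrained_weights)
```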

frobenius-MBP Implicitly Aligns Eigenvectors

To show explicitly that Frobenius distortion minimization aligns the fine-pruned and pretrained parameter vectors, we expect their eigenvectors to also be close. Using the Eckart-Young-Mirsky theorem (Golub et al., 1987), the Frobenius distortion minimization can be expressed as

$$\min_{A_k} \|W_T - A_k\|^2_F = \min_{A_k} \|\Sigma - U^{\top} A_k V\|^2_F, \quad (7)$$

where $W_T = U \Sigma V^{\top}$ is the singular value decomposition of the pretrained weights; the unitary invariance of the Frobenius norm under $U$ and $V$ leaves the singular value matrix $\Sigma$ to be approximated, hence its appearance in Equation (7). Writing $X = U_k \Sigma_k^{1/2}$, $Y = \Sigma_k^{1/2} V_k^{\top}$ and $A_k = XY$, the minimization can be further described as $\|\Sigma - U^{\top} A_k V\|^2_F$, and since $U$ and $V$ are unitary this reduces to $\|\Sigma - \Sigma_k\|^2_F$.
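The reduction used above can be checked numerically: for a random matrix, the Frobenius distortion of its best rank-k approximation equals the norm of the discarded singular values. A small NumPy sketch (purely illustrative, with arbitrary sizes):

```python
import numpy as np

rng = np.random.default_rng(0)
W_T = rng.standard_normal((64, 32))  # stand-in for a pretrained weight matrix
k = 8

# SVD of W_T and its best rank-k approximation A_k (Eckart-Young-Mirsky)
U, s, Vt = np.linalg.svd(W_T, full_matrices=False)
A_k = U[:, :k] @ np.diag(s[:k]) @ Vt[:k, :]

# ||W_T - A_k||_F equals ||Sigma - Sigma_k||_F, i.e. the norm of the
# singular values that were dropped
lhs = np.linalg.norm(W_T - A_k, "fro")
rhs = np.linalg.norm(s[k:])
assert np.isclose(lhs, rhs)
```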

Connections to Knowledge Distillation
Knowledge distillation (KD) uses the outputs of the last layer (Hinton et al., 2015) or of intermediate layers (Romero et al., 2015) of a teacher network as targets for a student network. In contrast, the AlignReg regularizers operate directly on minimizing a divergence or distance between weight tensors, as opposed to their corresponding output activations. Hence, AlignReg does not necessarily need training data, as it operates directly on aligning weight tensors. Since the networks being aligned are architecturally identical, we can show that maximizing weight similarity is equivalent to minimizing the distance between their corresponding output activations (Romero et al., 2015) when the norm of the input Z is smaller than the output range of σ.
For our experiments, we use XLM-RoBERTa Base, which contains Gaussian Error Linear Unit (GELU) activation functions, formulated as $\sigma(Z_{li}) := \frac{Z_{li}}{2}\big(1 + \mathrm{erf}(Z_{li}/\sqrt{2})\big)$, where erf is the error function, $\sigma(\cdot)$ is a monotonic activation function and $Z_{li}$ is the input vector. The GELU activation behaves like the ReLU activation for large positive $Z_{li}$ and tends to 0 as $Z_{li} \to -\infty$. For $Z_{li} > 0$, $\|Z_{li}\|_2 \le 1$ and a monotonic piecewise linear function $\sigma(\cdot)$, the inequality in Equation (8) holds.
Layer normalization leads to features with zero mean and unit variance, and hence $\|Z_{li}\|_2 \le 1$. Hence, minimizing the Frobenius distortion between pruned and unpruned weights is equivalent to minimizing the mean squared error (MSE) between output activations, as in the knowledge distillation method used for FitNets (Romero et al., 2015). In contrast, KD with FitNets encourages the student network to produce activation outputs matching the teacher's, with permutation invariance over a unit's incoming weights, without restricting the weights themselves to be similar. Unlike KD, this minimization can be performed without any data.
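For reference, the exact GELU used in XLM-R can be written out and evaluated as below; the print statements simply illustrate the limiting behaviour discussed above (identity-like for large positive inputs, vanishing for large negative inputs). This is a standalone illustration, not the model's implementation.

```python
import math


def gelu(z: float) -> float:
    """Exact GELU: (z / 2) * (1 + erf(z / sqrt(2)))."""
    return 0.5 * z * (1.0 + math.erf(z / math.sqrt(2.0)))


print(gelu(5.0))   # ~5.0 -> behaves like the identity (ReLU-like) for large positive z
print(gelu(-5.0))  # ~0.0 -> vanishes for large negative z
print(gelu(0.0))   # 0.0
```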
Iterative Pruning Details. Texts are tokenized using the SentencePiece BPE tokenizer (Sennrich et al., 2016) with a vocabulary of 250K tokens. For structured prediction tasks (POS and NER), a single-layer feed-forward (SLFF) token-level classifier is used on top of XLM-R Base, and for sentence-level tasks an SLFF sentence-level classifier is used. The batch size is 32, the learning rate is $5 \times 10^{-6}$ and the maximum sequence length is set to 256 for all tasks, except for POS, for which we use a learning rate of $2 \times 10^{-5}$ with the Adam optimizer (Kingma and Ba, 2015) with weight decay (AdamW) and a maximum sequence length of 128. We carry out a pruning step every 15 training epochs, uniformly pruning 10% of the parameters at each pruning step. We omit the embedding layers, layer normalization parameters and the classification layer from pruning, as they account for a relatively small fraction of the total parameter count (< 1%) and play an important role in XLM generalization. Although prior work has suggested non-uniform pruning schedules (e.g., the cubic schedule of Zhu and Gupta (2017)), we did not see major differences from uniform pruning in preliminary experiments. Each task is trained with English data only and evaluated on all available languages for that task. Hence, we expect the achievable compression to be lower, as the zero-shot cross-lingual setting is more difficult than the monolingual setting (e.g., GLUE tasks).
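A compressed sketch of the iterative pruning loop implied by these details (uniform 10% layer-wise pruning per step, skipping embeddings, layer normalization and the classifier) is shown below; the skip rules, names and scheduling comment are our paraphrase of the description rather than the released training code.

```python
import torch


def is_prunable(name: str, param: torch.Tensor) -> bool:
    """Skip embeddings, LayerNorm and the classification head, as described above."""
    skipped = ("embed" in name) or ("LayerNorm" in name) or ("classifier" in name)
    return param.dim() >= 2 and not skipped


def prune_step(model: torch.nn.Module, fraction: float = 0.10) -> None:
    """Layer-wise MBP step: zero out `fraction` of each prunable tensor."""
    with torch.no_grad():
        for name, param in model.named_parameters():
            if not is_prunable(name, param):
                continue
            k = max(int(fraction * param.numel()), 1)
            thresh = param.abs().flatten().kthvalue(k).values
            param.mul_((param.abs() > thresh).float())


# schedule sketch: a pruning step every 15 training epochs, 10% of weights per step
# for epoch in range(num_epochs):
#     train_one_epoch(model, ...)          # retraining with the AlignReg loss
#     if (epoch + 1) % 15 == 0:
#         prune_step(model, fraction=0.10)
```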

Empirical Results
We now discuss results on the XGLUE tasks.
News Classification (NC) Figure 1 shows the results for news classification, where the category of a news article is predicted; the model is trained and iteratively pruned on English text and evaluated in 5 languages. Firstly, we observe that the iterative pruning performance degradation is somewhat volatile. In preliminary experiments we found that news classification requires only 3 epochs to converge with standard fine-tuning on XLM-RoBERTa Base. We find this task is relatively "similar" to the pretraining task, and the model is therefore able to recover more easily from pruning steps. Overall, both Cosine-MBP and Frobenius-MBP consistently lead to the best zero-shot test performance across both pruning steps and languages.
Question Answer Matching (QAM) Figure 3 shows the test accuracy on English and the zero-shot test accuracy on French and German for Question-Answer Matching (QAM), which involves predicting whether a question is answered correctly given a question-answer pair. We find that Frobenius-MBP and Cosine-MBP maintain higher accuracy across multiple pruning steps, outperforming the baselines. More generally, there is close to a 2% drop in average test accuracy on French and German compared to testing on samples from the same language used in training.

Named Entity Recognition (NER) The NER cross-lingual dataset is made up of CoNLL-2002 and CoNLL-2003 (Sang and De Meulder, 2003), covering English, Dutch, German and Spanish with 4 named entity types. From Figure 2 we find that cross-lingual transfer of pruned models is most difficult for German and Dutch, which both come from the same language family, sharing commonalities such as word order and having similar vocabularies. The primary reason for the difficulty in maintaining performance at high compression rates on this NER dataset is that there are only 15k training samples, significantly fewer than for the remaining XGLUE tasks (the majority contain 100k training samples). Thus, not only is there less training data to recover with directly after each pruning step, but the pruning step interval itself is shorter. In contrast, English test performance stays close to the original performance down to 25% of remaining weights, unlike the remaining languages. We find that gradient-MBP eventually overtakes MBP approaches past 20% remaining weights, although accuracy has already degraded substantially at this compression level. The Cosine-MBP and Frobenius-MBP weight regularizers achieve the best pruned-model performance above 20% remaining weights, with Lookahead pruning and L0-regularized MBP being competitive in zero-shot performance.
Part of Speech Tagging (POS) The Part of Speech (POS) tagging dataset consists of a subset of the Universal Dependencies treebank (Nivre et al., 2020) and covers 18 languages. In Figure 4, both Cosine-MBP and Frobenius-MBP tend to outperform the baselines, although L0-based pruning (Louizos et al., 2018) has similar zero-shot accuracy to Cosine-MBP. There is also a clear discrepancy between SSL accuracy (English) and zero-shot accuracy (average), the latter following a near-linear decay after 40-50% of weights remain. Generally, both Cosine-MBP and Frobenius-MBP outperform the baselines, with the exception of Thai and Urdu at higher compression rates (< 40% remaining weights), both being among the most under-resourced of the 18 languages.
Web Page Ranking aims to predict whether a web page is relevant (1-5 ratings, "bad" to "perfect") to an input query, and it is evaluated for 7 languages using the Normalized Discounted Cumulative Gain (nDCG). From Figure 5, we see that in the 15%-45% region the average zero-shot performance degrades faster than performance on English, the language used for training. In contrast, languages semantically and syntactically distant from English, such as Chinese, already suffer from loss of alignment due to pruning, as the performance gap between the proposed methods (and the baselines) and random pruning is narrowed.

Cross-Lingual Natural Language Inference (XNLI) Figure 6 shows zero-shot cross-lingual transfer for various unstructured pruning methods. We find that both the accuracy on the English test set (i.e., SSL generalization) and the average zero-shot test accuracy are consistently improved using Cosine-MBP and Frobenius-MBP, outperforming L0 pruning, Lookahead pruning and LAMP. Morphologically rich languages such as Arabic, Swahili and Turkish degrade in performance linearly once performance begins to drop, after roughly 60% of the weights remain. This trend is roughly followed by all MBP-based pruning methods. Additionally, test accuracy on English can be maintained within a 10% drop of the original accuracy down to 20% of remaining weights for MBP, while Swahili stays within a 10% accuracy drop only down to 55% of remaining weights. Hence, iterative pruning in the zero-shot setting leads to faster performance degradation for languages that are typologically or etymologically further from the language used for fine-tuning. Comparing English and the average zero-shot test accuracy, the slope is steeper after the inflection point for all pruning methods, in addition to the greater-than-10% accuracy gap across pruning steps.

XGLUE Average Result Finally, in Table 1 we show the overall and average task understanding scores on the XGLUE benchmark for our proposed AlignReg weight regularizers and the pruning baselines. We find that AlignReg Cosine-MBP and Frobenius-MBP better preserve cross-lingual alignment during pruning and thereby outperform the other MBP baselines, including LAMP and Lookahead pruning, in zero-shot cross-lingual performance.

Discussion From our experiments, we find that layer-wise pruning tends to outperform global pruning. This can be explained by the clear discrepancy between weight norms of different layer types within each self-attention block: global pruning chooses the majority of weights to prune from the layer type with the smallest norm, leading to an information bottleneck, or layer collapse (Lee et al., 2018). This discrepancy is due to layer normalization being applied after the query, key and value (QKV) parameters, rescaling features such that weight magnitudes remain low. Hence, this motivates our focus on applying AlignReg to layer-wise MBP. This is reflected in Figure 7, which shows the weight norm by layer type for each layer under MBP. QKV weight norms are distinctly higher than those of the remaining fully-connected layers (attention output layer, intermediate position-wise feed-forward layer and the block's output layer), with the exception that the attention output layer norm becomes higher between layers 3-8. For the majority of tasks, the drop in zero-shot test performance accelerates close to 30% of remaining weights. This is consistent across all pruning methods, and therefore our analysis has focused on this operating region.
We also note that the effect of MBP (including our AlignReg regularization-based MBP) on zero-shot performance for different languages depends heavily on the semantic distance between the evaluated language and the single language used for training. For example, in Figure 6, Arabic, Bulgarian, Swahili and Hindi have the largest drops in test accuracy around 20-60% remaining weights. Similarly, Arabic, Thai and Hindi suffer most around 20%-60% for POS tagging in Figure 4. However, we also acknowledge that this is partly dependent on the proportion of training data per language used to pretrain the underlying language model, in our case XLM-R Base.
Lastly, to show the representational degradation of pruned networks, Figure 8 visualizes class separability via a two-component t-SNE plot of the last hidden representation corresponding to the [CLS] token of an iteratively pruned XLM-R Base on PAWS-X. Even from only two components of a single token's representation, we clearly see a change in class separability from 31% to 28% remaining weights, reflecting the loss of linear separation.

Conclusion
In this paper, we analysed iterative pruning in the zero-shot setting, where a pretrained masked language model is trained with self-supervised learning on text from various languages but only a single language is available for downstream task fine-tuning. We find that, for some tasks (NER and XNLI), some languages degrade faster than others under iterative pruning, and we propose a weight regularizer that biases the iteratively pruned model towards weight distributions close to the cross-lingually aligned pretrained state. This improves over well-established weight regularization methods for magnitude-based pruning in both the standard supervised learning setting and the zero-shot setting.