Gradient-based Intra-attention Pruning on Pre-trained Language Models

Pre-trained language models achieve superior performance but are computationally expensive. Techniques such as pruning and knowledge distillation have been developed to reduce their sizes and latencies. In this work, we propose a structured pruning method GRAIN (gradient-based intra-attention pruning), which performs task-specific pruning with knowledge distillation and yields highly effective models. Different from common approaches that prune each attention head as a whole, GRAIN inspects and prunes intra-attention structures, which greatly expands the structure search space and enables more flexible models. We also propose a gradient separation strategy that reduces the interference of distillation on pruning for a better combination of the two approaches. Experiments on GLUE, SQuAD, and CoNLL 2003 show that GRAIN notably outperforms other methods, especially in the high sparsity regime, and achieves 6~7x speedups while maintaining 93%~99% performance. Under extreme compression where only 3% of the transformer weights remain, the pruned model is still competitive with larger models.


Introduction
Transformer-based (Vaswani et al., 2017) pre-trained language models (PLMs) have achieved great success and become the backbone of various natural language processing tasks. However, these models are computationally expensive and slow in inference due to their large size, which limits their applications in real-world scenarios. To overcome this shortcoming, many works have been devoted to the compression and acceleration of PLMs.
One of the mainstream approaches is pruning, which compresses the model by identifying and removing redundant neurons (substructures). Based on the type of pruning unit, pruning methods can be divided into unstructured pruning, which selects and removes each weight individually (Han et al., 2015b; Zhu and Gupta, 2018; Sanh et al., 2020; Gordon et al., 2020), and structured pruning, where entire rows or coarser units are removed (Xia et al., 2022; Lagunas et al., 2021). Unstructured pruning yields models with higher sparsities, but it is hard to improve the inference speed without specialized devices. In contrast, structured pruning is more effective at acceleration on common devices.
Another mainstream approach is knowledge distillation (KD). In the context of KD, the student model learns from the teacher model by mimicking the teacher's outputs and intermediate representations (Jiao et al., 2020; Sun et al., 2020). Unlike pruning, the structure of the student model is specified in advance. Choosing a suitable student structure is crucial for effective distillation.
When applying structured pruning on PLMs, the most common pruning units are the hidden dimensions in feed-forward networks and the attention heads in the multi-head attention layers. Recent works have introduced some new kinds of pruning units, such as whole multi-head attention and feed-forward layers (Xia et al., 2022), and blocks of weights (Lagunas et al., 2021).
However, we argue that the basic pruning units (attention heads and feed-forward dimensions) only span a small model structure space and limit the structures that the pruning algorithm can explore. Take BERT-base for example: there are 12 × 3072 feed-forward dimensions and 12 × 12 attention heads that can be pruned in total. Any pruned model has to choose its units from these two sets, and the choices of attention heads are rather limited.
In this work, we propose GRAIN (Gradient-based Intra-attention pruning), a structured pruning method that inspects fine-grained intra-attention structures and prunes the dimensions of each head individually. Parameters can thus be allocated more freely among different heads, which greatly expands the search space of model structures. With a larger search space, the pruning algorithm is more likely to find better structures.
Directly applying intra-attention pruning yields models with many fragmented pruning units, which may reduce running efficiency on devices like GPUs and bring additional overhead. We devise a structure regularization strategy, which provides a mechanism for achieving variable speedups: it encourages the pruning process to prioritize the pruning of fragmented units so as to produce more regular structures. The speedups can be significantly improved with little or no performance drop.
We adapt gradient-based pruning (Michel et al., 2019) to integrate the above approaches. Gradient-based pruning is a lightweight method that measures the importance of the pruning units with gradient-based scores and then prunes the least important ones. Following Zhu and Gupta (2018), we prune the model gradually during fine-tuning.
Pruning and KD are complementary and well suited for working together: pruning determines the model structure, while KD provides efficient training objectives. Previous works have also shown that pruning with a distillation objective improves performance (Sanh et al., 2020; Xia et al., 2022). However, the knowledge distillation objective may inject noise and disturb gradient-based pruning. Therefore, we propose a gradient separation strategy that mitigates the negative effects of KD on pruning by optimizing the parameters and the model structure with different sets of gradients.
In the experiments, we compare GRAIN with strong pruning and distillation baselines on GLUE, SQuAD, and CoNLL 2003 tasks. GRAIN achieves promising performance at different compression ratios and notably outperforms the baselines on all tasks at the same or similar model sizes. Furthermore, even under extreme compression where only 3% of the transformer weights remain, GRAIN still produces competitive results.

Transformers
A Transformer block (Vaswani et al., 2017) is mainly composed of a multi-head attention (MHA) layer and a feed-forward network (FFN) layer.
Let X ∈ R^{n×d} be the input to the transformer, where n is the sequence length and d is the hidden size. An attention head is parameterized by the matrices W^Q_i, W^K_i, W^V_i ∈ R^{d×d_h} and W^O_i ∈ R^{d_h×d}, where d_h is the head size and i is the head index. An MHA layer contains N_h attention heads:

MHA(X) = Σ_{i=1}^{N_h} Att_i(X),   Att_i(X) = softmax(X W^Q_i (X W^K_i)^T / √d_h) X W^V_i W^O_i.

Usually, we have d_h = d/N_h. In the actual implementation, the parameters of the N_h heads are gathered together and stored in four matrices W^Q, W^K, W^V, W^O ∈ R^{d×d}. Following the MHA layer is the feed-forward network layer. It consists of two linear layers and a GeLU activation (Hendrycks and Gimpel, 2016):

FFN(X) = GeLU(X W_1 + b_1) W_2 + b_2,

where W_1 ∈ R^{d×d_f} and W_2 ∈ R^{d_f×d}. A transformer block contains other components, such as LayerNorm and residual connections, but they only take up a few parameters.
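To make the shapes concrete, the following is a minimal PyTorch sketch of the MHA and FFN computations described above; the function and variable names are illustrative and not taken from any released implementation.

```python
# Minimal sketch of MHA and FFN, following the notation above
# (n: sequence length, d: hidden size, N_h heads of size d_h = d / N_h).
import math
import torch
import torch.nn.functional as F

def mha(X, W_Q, W_K, W_V, W_O, n_heads):
    """X: (n, d); W_Q, W_K, W_V, W_O: (d, d), storing all N_h heads together."""
    n, d = X.shape
    d_h = d // n_heads
    # Project and split into heads: (n_heads, n, d_h)
    Q = (X @ W_Q).view(n, n_heads, d_h).transpose(0, 1)
    K = (X @ W_K).view(n, n_heads, d_h).transpose(0, 1)
    V = (X @ W_V).view(n, n_heads, d_h).transpose(0, 1)
    attn = torch.softmax(Q @ K.transpose(-1, -2) / math.sqrt(d_h), dim=-1)
    heads = (attn @ V).transpose(0, 1).reshape(n, d)  # concatenate head outputs
    return heads @ W_O                                # output projection

def ffn(X, W_1, b_1, W_2, b_2):
    """Two linear layers with a GeLU activation; W_1: (d, d_f), W_2: (d_f, d)."""
    return F.gelu(X @ W_1 + b_1) @ W_2 + b_2
```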

Gradient-based Pruning
Generally, pruning methods remove neurons based on their importance scores, and different methods define these scores differently. In magnitude pruning, the score of each weight is its absolute value (Han et al., 2015a; Zhu and Gupta, 2018). In L0-regularization (Louizos et al., 2017) and movement pruning (Sanh et al., 2020), the scores are additional trainable parameters that are optimized during training.
Gradient-based pruning, proposed in Michel et al. (2019), defines the importance score of a set of neurons Θ as the variation of the loss if the neurons Θ are removed:

IS(Θ) = E_{x∼X} | Θ^T ∂L(x)/∂Θ |,

where X is the data distribution. The expression inside the absolute value is the first-order Taylor approximation of the loss L around Θ = 0. For example, by setting Θ to be the weights of the attention head h_i or the weights in the i-th row of W_2, IS(Θ) gives the importance score of the head h_i or of the i-th hidden dimension in the FFN layer. A lower importance score means the loss is less sensitive to those neurons; therefore, the neurons are pruned in order of increasing scores. Empirically, instead of pruning the model to the target size at once, pruning gradually results in better performance (Yang et al., 2022). After pruning, further fine-tuning is employed to recover performance (Sanh et al., 2020; Zhu and Gupta, 2018).
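As an illustration of how such scores can be computed in practice, the sketch below accumulates first-order importance scores for the FFN hidden dimensions of one layer over a data loader; the helper names and the `output.dense` attribute path are hypothetical stand-ins for a BERT-like implementation, not GRAIN's released code.

```python
# Sketch: accumulate |Θ^T ∂L/∂Θ| for the FFN hidden dimensions of one layer.
# `model`, `layer`, `loader`, and `loss_fn` are hypothetical; nn.Linear stores
# its weight as (out_features, in_features), so W_2 here has shape (d, d_f).
import torch

def ffn_hidden_importance(model, layer, loader, loss_fn):
    scores = None
    for batch in loader:
        model.zero_grad()
        loss_fn(model, batch).backward()              # task loss L on this batch
        W_2 = layer.output.dense.weight               # (d, d_f): maps d_f -> d
        contrib = (W_2 * W_2.grad).sum(dim=0).abs()   # one score per hidden dimension
        scores = contrib if scores is None else scores + contrib
    return scores / len(loader)                       # average over batches ≈ E over X
```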

Methodology
GRAIN performs task-specific pruning together with knowledge distillation. We adapt gradient-based pruning to integrate intra-attention structure pruning and structure regularization. Gradient separation is applied to resolve the conflict between gradient-based pruning and KD. We also explore the settings of embedding compression.
Following previous work, the word embedding matrix is excluded when counting the model size unless otherwise specified. We use the term model density to refer to the size of the pruned model relative to the size of the unpruned model; sparsity is equal to one minus the model density.

Intra-attention Pruning
Two kinds of pruning units have been widely used in prior works: FFN hidden dimensions and attention heads. These pruning units have been treated as atomic in structured PLM pruning.
In intra-attention pruning, we introduce two new kinds of pruning units. The key observation is that attention heads allow for finer structural pruning. First, different heads do not have to have the same head size, so the output dimensions of W^Q_i, W^K_i, and W^V_i can be pruned individually for each head. Second, within a head, W^Q_i and W^K_i have to share the same output dimension, and the output dimension of W^V_i has to match the input dimension of W^O_i, but the dimensions of these two groups can be different.
Based on the above discussion, we introduce two new pruning units: query dimensions for each head, namely the shared output dimensions of W^Q_i and W^K_i; and value dimensions for each head, namely the output dimensions of W^V_i together with the corresponding input dimensions of W^O_i. We further replace attention heads as pruning units with query dimensions and value dimensions, since the latter are more fundamental units than attention heads. We still use FFN hidden dimensions as pruning units in our gradient-based pruning. Pruning one unit from this new set removes 2d parameters. The new set of pruning units greatly enriches the model structure space, thus allowing for more efficient model structures.
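For concreteness, the sketch below shows what pruning these units means for the weight matrices of a single head i; the tensors and kept sizes are purely illustrative.

```python
# Illustration of the two new pruning units for a single head i. Pruning one
# query dimension removes the matching column of W_Q_i and W_K_i; pruning one
# value dimension removes the matching column of W_V_i and the matching row of
# W_O_i. Either way, 2d parameters are removed. A sketch, not GRAIN's code.
import torch

def prune_query_dims(W_Q_i, W_K_i, keep):
    """W_Q_i, W_K_i: (d, d_h); keep: indices of query dimensions to retain."""
    return W_Q_i[:, keep], W_K_i[:, keep]

def prune_value_dims(W_V_i, W_O_i, keep):
    """W_V_i: (d, d_h); W_O_i: (d_h, d); keep: indices of value dimensions to retain."""
    return W_V_i[:, keep], W_O_i[keep, :]

# Example: a head with d=768, d_h=64, keeping 40 query dims and 16 value dims,
# so different heads (and the two groups within one head) may end up with
# different widths.
d, d_h = 768, 64
W_Q_i, W_K_i, W_V_i = (torch.randn(d, d_h) for _ in range(3))
W_O_i = torch.randn(d_h, d)
W_Q_i, W_K_i = prune_query_dims(W_Q_i, W_K_i, torch.arange(40))
W_V_i, W_O_i = prune_value_dims(W_V_i, W_O_i, torch.arange(16))
```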

Structure Regularization
With intra-attention pruning, there are fewer constraints on the pruned model structures. However, intra-attention pruning tends to generate heterogeneous structures with fragmented modules. For example, an MHA layer containing three heads of sizes 4, 20, and 40 is more fragmented than an MHA layer containing only one head of size 64. Fragmentation has two negative consequences. First, it may reduce running efficiency on devices like GPUs. Second, it makes it hard for the pruning method to find whole layers that can be pruned, since a few pruning units are left in every layer, which also slows down inference.
To remedy this, we introduce structure regularization as a correction, which encourages the pruning process to generate more regular models with fewer fragmented units.
We define D(M, W) as the density of the pruning unit W in module M, i.e., the ratio of remaining pruning units in M. The regularized importance score is:

IS^r(W) = D(M, W)^α · IS(W),

where α is the regularization strength. The lower the density of the pruning unit W in M, the lower the importance score. Thus, modules with low density tend to be pruned with priority, and fewer low-density (fragmented) modules will be left in the pruned model.
We apply structure regularization to the intra-attention structures: W is a query or value dimension (the corresponding columns of W^Q_i, W^K_i, W^V_i and rows of W^O_i), and M is the attention head it belongs to. With structure regularization, the pruned model will contain fewer fragmented attention heads.
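The sketch below applies the regularized score to per-head intra-attention units, assuming the multiplicative form given above; the shapes and names are illustrative.

```python
# Sketch of structure regularization over intra-attention units, assuming
# IS_r(W) = D(M, W)^alpha * IS(W): units in low-density (fragmented) heads get
# their scores scaled down and are therefore pruned first.
import torch

def regularized_scores(scores, kept_mask, alpha):
    """scores, kept_mask: (n_heads, d_h) per-head unit scores and 0/1 keep mask."""
    density = kept_mask.float().mean(dim=1, keepdim=True)  # D(M, W) for each head M
    return (density ** alpha) * scores
```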

Gradient Separation
Knowledge distillation provides effective objectives for transferring knowledge from a large model to a small model. In knowledge distillation, the teacher model is a fixed fine-tuned model, and the student model is optimized to learn from the teacher. The simplest KD involves a cross-entropy loss between the student's and the teacher's prediction probabilities:

L_ce = CE(p^T_τ, p^S_τ),    (5)

where T and S denote the teacher and student respectively, and p_τ = softmax(z/τ) is the scaled probability with temperature τ and logits z. The performance of KD can be improved by employing a hidden layer matching objective that distills knowledge from the intermediate hidden states:

L_hidden = Σ_{(i,j)∈I} MSE(H^T_i, H^S_j W_i),    (6)

where I is the set of layer index pairs to match, H_i (i > 0) is the hidden representation from the i-th transformer block (H_0 represents the output of the embedding layer), and W_i is a trainable linear mapping. We apply both objectives (5) and (6). The total loss is

L = L_ce + L_hidden.    (7)
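A minimal PyTorch sketch of the two objectives (5) and (6) is given below; the layer pairing, the mapping modules, and all names are assumptions rather than the exact implementation.

```python
# Sketch of the two distillation objectives: temperature-scaled cross-entropy
# between teacher and student predictions, and MSE matching of hidden states.
import torch
import torch.nn.functional as F

def kd_ce_loss(logits_s, logits_t, tau=8.0):
    # Cross-entropy between the teacher's and student's scaled distributions.
    p_t = F.softmax(logits_t / tau, dim=-1)
    log_p_s = F.log_softmax(logits_s / tau, dim=-1)
    return -(p_t * log_p_s).sum(dim=-1).mean()

def hidden_matching_loss(hidden_s, hidden_t, pairs, mappings):
    # hidden_s / hidden_t: lists of hidden states (index 0 is the embedding output);
    # pairs: list of (teacher_layer, student_layer); mappings: trainable nn.Linear.
    loss = 0.0
    for (i, j), proj in zip(pairs, mappings):
        loss = loss + F.mse_loss(proj(hidden_s[j]), hidden_t[i])
    return loss
```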
For simplicity, we do not use any additional hyperparameter to balance L_ce and L_hidden in (7). However, when applying KD together with gradient-based pruning, the hidden layer matching loss L_hidden should be treated carefully: in gradient-based pruning, units are pruned based on how significantly they affect the model predictions.
Thus, the importance score should be calculated solely from the cross-entropy loss, and the gradients from other losses such as L_hidden should not affect the estimation of unit importance. Therefore, we apply gradient separation (GS): the gradient from the task objective L_ce is used both for optimizing the model parameters and for computing the importance scores that guide the pruning process, while the gradient from the auxiliary loss L_hidden is only used for optimizing the model parameters. The gradient flows of the different losses are illustrated in Figure 1.
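One simple way to realize gradient separation is to run two backward passes and read the importance scores after the first. The sketch below reuses kd_ce_loss and hidden_matching_loss from the previous sketch; `update_importance` and the student forward signature are hypothetical assumptions.

```python
# Sketch of one training step with gradient separation. `update_importance` is a
# hypothetical hook that reads the current .grad fields (which at that point
# contain only the L_ce gradients) to update the importance scores.
def train_step(model, batch, teacher_out, optimizer, update_importance, pairs, mappings):
    logits_s, hidden_s = model(batch)          # hypothetical student forward
    logits_t, hidden_t = teacher_out           # teacher outputs (fixed, precomputed)

    loss_ce = kd_ce_loss(logits_s, logits_t)
    loss_hidden = hidden_matching_loss(hidden_s, hidden_t, pairs, mappings)

    optimizer.zero_grad()
    loss_ce.backward(retain_graph=True)        # gradients from L_ce only
    update_importance(model)                   # pruning scores see only these gradients
    loss_hidden.backward()                     # .grad now accumulates L_ce + L_hidden
    optimizer.step()                           # parameters updated with the full gradient
```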

Embedding Factorization
The aforementioned pruning process reduces the number of parameters in the transformer blocks. Another large fraction of parameters is stored in the word embedding matrix E ∈ R^{q×d}, where q is the vocabulary size and d is the hidden size. The word embedding matrix E accounts for about 22% of the parameters in BERT-base, and thus becomes dominant when the transformers are heavily pruned.
A common approach to reducing the embedding size is matrix factorization (Lan et al., 2020). We apply singular value decomposition (SVD) to the word embedding matrix E to reduce its size:

E = U Σ V,

where U ∈ R^{q×d}, V ∈ R^{d×d}, and Σ is a diagonal matrix composed of the singular values. We approximate E with E_r by selecting the top r singular values and the corresponding r columns of U and r rows of V:

E_r = U_r Σ_r V_r = W_r V_r,

where W_r = U_r Σ_r ∈ R^{q×r} and V_r ∈ R^{r×d}. The original embedding E is now replaced by W_r and V_r, and the embedding size is reduced from qd to (q + d)r.
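The factorization can be done with an off-the-shelf SVD routine; a PyTorch sketch is shown below. The sizes roughly correspond to BERT-base's vocabulary and hidden size and the r = 192 used later, but are only illustrative here.

```python
# Sketch of rank-r embedding factorization with SVD. The factorized embedding
# replaces the lookup E[ids] by W_r[ids] @ V_r.
import torch

def factorize_embedding(E, r):
    """E: (q, d) word embedding matrix; returns W_r (q, r) and V_r (r, d)."""
    U, S, Vh = torch.linalg.svd(E, full_matrices=False)   # E = U diag(S) Vh
    W_r = U[:, :r] * S[:r]          # absorb the top-r singular values into U_r
    V_r = Vh[:r, :]
    return W_r, V_r

q, d, r = 30522, 768, 192
E = torch.randn(q, d)
W_r, V_r = factorize_embedding(E, r)
# Storage drops from q*d to (q + d)*r parameters.
```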
Embedding factorization has little effect on inference speed but significantly reduces the model size. Some works (Xia et al., 2022; Lagunas et al., 2021) exclude the word embedding matrix when calculating the number of parameters. To compare with those works, we also study and test the proposed method without embedding factorization. We refer to this setting as GRAIN w/o EF.

GRAIN Procedure
Similar to Zhu and Gupta (2018) and Sanh et al. (2020), we take an iterative approach to prune the model. The whole process can be divided into three stages, as depicted in Figure 2.
We denote the total number of training steps by N. The first stage is the warm-up stage, in which we train the student model for N·p_s steps with the KD objective, where 0 < p_s < 1 is a hyperparameter.
In the second stage, we gradually prune the model together with distillation for N·(p_e − p_s) steps. The model density s decreases from the initial density (100%) to the target density s_f under the control of a cubic scheduler (Zhu and Gupta, 2018):

s(t) = s_f + (1 − s_f)·(1 − (t − p_s)/(p_e − p_s))^3,   p_s ≤ t ≤ p_e,    (10)

where t ∈ [0, 1] is the training progress. At each step i = N·t, the pruning is guided by the exponentially smoothed importance score

ĪS_i(W) = β·ĪS_{i−1}(W) + (1 − β)·IS^r_i(W),

where IS^r_i(W) is the regularized importance score of the pruning unit W calculated at step i, and β is the smoothing factor. The smoothed score avoids large variance and leads to more stable pruning.
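Both the density schedule (10) and the score smoothing are straightforward to implement; the following is a short sketch written from the formulas above, with default hyperparameters as illustrations.

```python
# Sketch of the cubic density schedule and the exponential smoothing of the
# regularized importance scores.
def target_density(t, s_f, p_s=0.2, p_e=0.4):
    """t: training progress in [0, 1]; s_f: final (target) density."""
    if t < p_s:
        return 1.0                                   # warm-up: nothing pruned yet
    if t > p_e:
        return s_f                                   # target density reached
    frac = (t - p_s) / (p_e - p_s)
    return s_f + (1.0 - s_f) * (1.0 - frac) ** 3

def smooth_scores(prev, current, beta=0.998):
    """Exponential moving average of the regularized importance scores."""
    return beta * prev + (1.0 - beta) * current
```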
In the last stage, the model reaches the target density and its structure is fixed. We continually train the model with distillation to recover performance.
The three stages take place consecutively, and the whole process is efficiently done in a single run of fine-tuning.

Experiment Setup
Datasets We evaluate our approach on the machine reading comprehension task SQuAD 1.1 (Rajpurkar et al., 2016), the named entity recognition task CoNLL 2003 (Tjong Kim Sang and De Meulder, 2003), and four text classification tasks (SST-2, QNLI, MNLI, and QQP) from GLUE benchmark (Wang et al., 2018). We select the above four tasks because they have large training data, so the results are more stable.
Training Settings We use BERT-base (uncased) as the backbone model for both teachers and students. We prune the model with different target densities, ranging from 3% to 20%. We report the model size with and without the word embedding matrix for easy comparison with other works.
Hyperparameters The batch size is 32 for all tasks. We train the model for 20 epochs on all tasks except CoNLL 2003, on which we train for 40 epochs. The distillation temperature τ is set to 8; the start and end of pruning, p_s and p_e, are 0.2 and 0.4, respectively. The smoothing factor β in the exponential smoothing is 0.998. The reduced embedding size r is 192, which leads to about a 75% reduction in the word embedding matrix. The regularization strength α ranges from 0 to 0.3.
All the experiments are conducted with a single NVIDIA V100 GPU. We run each experiment three times and report the average score.
Baselines We compare our method with the following baselines: (1) CoFi (Xia et al., 2022), a state-of-the-art task-specific structured pruning method that prunes attention heads, FFN hidden dimensions, and whole MHA/FFN layers; distillation has also been integrated into CoFi. (2) TinyBERT (Jiao et al., 2020), a strong small PLM baseline. TinyBERT is distilled from BERT-base with general distillation for pre-training and task-specific distillation for downstream tasks.

Main Results
The main results are shown in Table 1, where we report model sizes both without embeddings (Encoder) and with embeddings (Total).
For CoFi and TinyBERT, we show both the results extracted from previous works and the results obtained by our reimplementations. We use the same teacher model for CoFi, TinyBERT, and GRAIN in our reimplementation. TinyBERT is initialized from the publicly released general-distillation weights and further distilled on the tasks with task-specific data. We do not use data augmentation when training TinyBERT, for a fair comparison.
At 20% model size, the performance of GRAIN drops by less than 1% compared with the teacher on all tasks, and GRAIN outperforms CoFi on all tasks except SST-2. At 5% model size, GRAIN achieves notable improvements over CoFi and TinyBERT on all tasks.
Under extreme compression with 3% model size, as shown in the last two lines of Table 1, GRAIN (2.6M) still outperforms or is comparable with TinyBERT (4.7M) and CoFi (4.8M) on most tasks, even though GRAIN has fewer parameters.
We also show the results of GRAIN without embedding factorization, as some previous works only conduct transformer pruning and leave the embeddings unpruned. Without embedding factorization, the pruned model has more parameters. However, Table 1 shows that the pruned model does not always benefit from having larger embeddings: on QNLI and SQuAD, embedding factorization leads to improved performance, while on SST-2, the large embedding matrix is better. We also note that the gaps narrow as the model size increases.

Ablation Study
Gradient Separation We first remove gradient separation (GS) to see the effect of including the gradients from L_hidden in the calculation of the importance scores. In Table 2, we see that MNLI is not affected by removing gradient separation, while the performance drops on the remaining tasks. The machine reading comprehension task SQuAD is affected most notably.
Hidden Layer Matching By removing the hidden layer matching loss, knowledge distillation only optimizes the cross-entropy objective. The results in Table 2 indicate that removing hidden layer matching leads to significant performance drops, showing the necessity of using both distillation objectives for effective pruning.
Gradient-based pruning Is gradient-based pruning effective, or could we simply choose a random small model structure? To answer these questions, we use random scores instead of gradient-based scores at each pruning step, so the pruned model still has the same size, but its units are randomly chosen from the unpruned model. The results are shown in the last line of Table 2. The drops on all tasks indicate that gradient-based pruning is effective and indispensable to GRAIN, and that the pruned model structures are significantly better than randomly selected ones.

Conclusion
In this paper, we propose GRAIN, a gradient-based structured pruning method that inspects intra-attention structures and allows a large search space of model structures. We also provide a structure regularization strategy that allows for variable speedups by encouraging regular model structures.
To integrate KD into pruning, we use gradient separation to mitigate the negative effects of KD on pruning. Compared with other pruning methods, such as L0-regularization, GRAIN is easy to implement and does not introduce additional trainable parameters. GRAIN is also training-efficient, as it requires neither extensive data augmentation nor any pre-training stage. Experiments show that GRAIN achieves high performance at different compression ratios on various natural language understanding tasks.