HRKD: Hierarchical Relational Knowledge Distillation for Cross-domain Language Model Compression

On many natural language processing tasks, large pre-trained language models (PLMs) have shown superior performance compared with traditional neural network methods. Nevertheless, their huge model size and low inference speed have hindered their deployment on resource-limited devices in practice. In this paper, we aim to compress PLMs with knowledge distillation, and propose a hierarchical relational knowledge distillation (HRKD) method to capture both hierarchical and domain relational information. Specifically, to enhance the model capability and transferability, we leverage the idea of meta-learning and set up domain-relational graphs to capture the relational information across different domains. Further, to dynamically select the most representative prototypes for each domain, we propose a hierarchical compare-aggregate mechanism to capture hierarchical relationships. Extensive experiments on public multi-domain datasets demonstrate the superior performance of our HRKD method as well as its strong few-shot learning ability. For reproducibility, we release the code at https://github.com/cheneydon/hrkd.

To solve the above problem, many compression techniques for PLMs have been proposed, such as quantization (Shen et al., 2020), weight pruning (Michel et al., 2019), and knowledge distillation (KD) (Sun et al., 2019; Jiao et al., 2020). Due to its plug-and-play feasibility, KD is the most commonly used method in practice, and we focus on it in this work. The purpose of KD is to transfer knowledge from a larger teacher model to a smaller student model (Hinton et al., 2015). Traditional KD methods only leverage single-domain knowledge, i.e., they transfer the knowledge of the teacher model to the student model domain by domain. However, as stated in the purpose of transfer learning, the model performance on target domains can be improved by transferring knowledge from different but related source domains (Lu et al., 2015); thus, cross-domain knowledge also plays an important role. In addition, several recent works have proved the advantage of cross-domain knowledge, and many multi-domain KD methods have been proposed. For example, Peng et al. (2020) and Yang et al. (2020) demonstrate the effectiveness of distilling knowledge from multiple teachers in different domains; Liu et al. (2019a,b) show that jointly distilling the student models of different domains can enhance the performance.
Nevertheless, these methods fail to capture the relational information across different domains and might have poor generalization ability. To enhance the transferability of the multi-domain KD framework, some researchers have recently adopted the idea of meta-learning. Studies have pointed out that meta-learning can improve the transferability of models between different domains (Finn et al., 2017; Javed and White, 2019). For example, Meta-KD (Pan et al., 2020) introduces an instance-specific domain-expertise weighting technique to distill the knowledge from a meta-teacher trained across multiple domains to the student model. However, the Meta-KD framework trains student models in different domains separately, which is inconvenient in real-world applications and might not have enough capability to capture multi-domain correlations.
In this paper, we aim to simultaneously capture the relational information across different domains to make our framework more convenient and effective. Specifically, we set up several domain-relational graphs to adequately learn the relations of different domains and generate a set of domain-relational ratios to re-weight each domain during the KD process. Moreover, since different domains might have different preferences of layer prototypes, motivated by the Riesz representation theorem (Hartig, 1983), we first construct a set of reference prototypes for each domain, calculated by a self-attention mechanism that integrates the information of different domains. Then we introduce a hierarchical compare-aggregate mechanism to compare each layer prototype with the corresponding reference prototype and make an aggregation based on their similarities. The aggregated prototypes are finally sent to the corresponding domain-relational graphs. We refer to our framework as hierarchical relational knowledge distillation (HRKD).
We evaluate the HRKD framework on two multidomain NLP datasets, including the MNLI dataset (Williams et al., 2018) and the Amazon Reviews dataset (Blitzer et al., 2007). Experiments show that our HRKD method can achieve better performance compared with several multi-domain KD methods. We also evaluate our approach under the few-shot learning setting, and it can still achieve better results than the competing baselines.

Method
In this section, we describe the proposed HRKD framework in detail. HRKD aims to simultaneously capture the relational information across different domains with both hierarchical and domain meta-knowledge. To achieve this goal, we introduce a hierarchical compare-aggregate mechanism to dynamically identify more representative prototypes for each domain, and construct a set of domain-relational graphs to generate re-weighting KD ratios. The overview of HRKD is shown in Figure 1. We first introduce the basic multi-domain KD method in Section 2.1, which is a naive framework lacking the ability to capture cross-domain relations. Then we describe the domain-relational graph and the compare-aggregate mechanism in Sections 2.2 and 2.3, respectively, which are the primary modules of our HRKD method for discovering the relational information.

Multi-domain Knowledge Distillation
Similar to Jiao et al. (2020), we jointly distill the embeddings, attention matrices, transformer layer outputs, and predicted logits between the teacher and student models. Inspired by Liu et al. (2019c), we use a multi-task training strategy to perform multi-domain KD. Specifically, we share the weights of the embedding and transformer layers across all domains while assigning a separate prediction layer to each domain. Notably, we optimize models in different domains simultaneously rather than sequentially.
In detail, the embedding loss $\mathcal{L}_{\mathrm{embd}}^{d}$ and prediction loss $\mathcal{L}_{\mathrm{pred}}^{d}$ of the $d$-th domain are formulated as:

$$\mathcal{L}_{\mathrm{embd}}^{d} = \mathrm{MSE}\big(E^{S} W_{\mathrm{embd}},\, E^{T}_{d}\big), \qquad \mathcal{L}_{\mathrm{pred}}^{d} = \mathrm{CE}\big(z^{T}_{d}/t,\, z^{S}_{d}/t\big),$$

where MSE and CE represent the mean squared error loss and the cross-entropy loss, respectively. $E^{S}$ and $E^{T}_{d}$ represent the embeddings of the student model and the teacher model of the $d$-th domain, respectively. $z^{S}_{d}$ and $z^{T}_{d}$ represent the predicted logits of the student model and the teacher model of the $d$-th domain, respectively. $W_{\mathrm{embd}}$ is a learnable transformation matrix that aligns the student embedding dimension with the (mismatched) teacher embedding dimension, and $t$ is the temperature factor.
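To make the two losses concrete, here is a minimal NumPy sketch (the shapes, random seed, and variable names are ours, purely for illustration): the student embedding is projected by W_embd before the MSE, and both sets of logits are softened by the temperature t before the cross-entropy.

```python
import numpy as np

def mse(a, b):
    # Mean squared error over all elements.
    return float(np.mean((a - b) ** 2))

def soft_cross_entropy(teacher_logits, student_logits, t=1.0):
    # Cross-entropy between temperature-softened teacher and student distributions.
    p = np.exp(teacher_logits / t)
    p /= p.sum(axis=-1, keepdims=True)
    zs = student_logits / t
    log_q = zs - np.log(np.exp(zs).sum(axis=-1, keepdims=True))
    return float(-(p * log_q).sum(axis=-1).mean())

rng = np.random.default_rng(0)
# Toy shapes: batch 2, length 4, student dim 3, teacher dim 5, 3 classes.
E_s = rng.normal(size=(2, 4, 3))       # student embeddings E^S
E_t = rng.normal(size=(2, 4, 5))       # teacher embeddings E^T_d
W_embd = rng.normal(size=(3, 5))       # aligns the student dim to the teacher dim
loss_embd = mse(E_s @ W_embd, E_t)

z_t = rng.normal(size=(2, 3))          # teacher logits z^T_d
z_s = rng.normal(size=(2, 3))          # student logits z^S_d
loss_pred = soft_cross_entropy(z_t, z_s, t=1.0)
```

Both losses reduce to scalars that are summed into the overall KD objective below.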
The attention loss $\mathcal{L}_{\mathrm{attn}}^{m,d}$ and transformer layer output loss $\mathcal{L}_{\mathrm{hidn}}^{m,d}$ at the $m$-th student layer and $d$-th domain are formulated as:

$$\mathcal{L}_{\mathrm{attn}}^{m,d} = \frac{1}{h}\sum_{i=1}^{h} \mathrm{MSE}\big(A^{S}_{i,m},\, A^{T}_{i,n,d}\big), \qquad \mathcal{L}_{\mathrm{hidn}}^{m,d} = \mathrm{MSE}\big(H^{S}_{m} W^{\mathrm{hidn}}_{m},\, H^{T}_{n,d}\big),$$

where $h$ is the number of attention heads, and $A^{S}_{i,m}$ and $A^{T}_{i,n,d}$ are the $i$-th heads of the attention matrices at the $m$-th student layer and its matching $n$-th teacher layer of the $d$-th domain, respectively. $H^{S}_{m}$ and $H^{T}_{n,d}$ are the transformer layer outputs at the $m$-th student layer and the $n$-th teacher layer of the $d$-th domain, respectively. $W^{\mathrm{hidn}}_{m}$ is a transformation matrix that aligns the $m$-th student layer output dimension with the (mismatched) $n$-th teacher layer output dimension. We use a uniform strategy to match the layers between the student and teacher models.

Figure 1: An overview of the proposed HRKD method. We use knowledge distillation (KD) to transfer the knowledge from the teacher model to the student model. During KD, we set up several domain-relational graphs to generate domain-relational ratios for re-weighting each domain. We then introduce a hierarchical compare-aggregate mechanism. The prototypes of different layers are dynamically aggregated based on their similarity ratios against the corresponding reference prototypes, and the aggregated prototypes are then fed into the domain-relational graphs.
Finally, the overall KD loss is formulated as:

$$\mathcal{L}_{\mathrm{KD}} = \sum_{d=1}^{D}\Big[\mathcal{L}_{\mathrm{embd}}^{d} + \sum_{m=1}^{M}\big(\mathcal{L}_{\mathrm{attn}}^{m,d} + \mathcal{L}_{\mathrm{hidn}}^{m,d}\big) + \gamma\,\mathcal{L}_{\mathrm{pred}}^{d}\Big],$$

where $D$ is the total number of domains, $M$ is the number of transformer layers in the student model, and $\gamma$ controls the weight of the prediction loss $\mathcal{L}_{\mathrm{pred}}^{d}$.
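As a toy illustration of how the per-domain terms combine (the loss values below are made up, not measured), the overall KD loss can be accumulated as:

```python
# Hypothetical per-domain losses for D domains and M student layers, gamma = 1.
D, M, gamma = 3, 4, 1.0
l_embd = [0.10, 0.12, 0.08]               # embedding loss per domain
l_pred = [0.50, 0.45, 0.60]               # prediction loss per domain
l_attn = [[0.2] * M for _ in range(D)]    # attention loss, indexed [d][m]
l_hidn = [[0.3] * M for _ in range(D)]    # hidden-state loss, indexed [d][m]

# Sum layer losses within each domain, then sum over all domains.
L_kd = sum(
    l_embd[d]
    + sum(l_attn[d][m] + l_hidn[d][m] for m in range(M))
    + gamma * l_pred[d]
    for d in range(D)
)
```

In the basic framework every domain contributes with equal weight; the domain-relational ratios introduced in Section 2.2 replace this uniform weighting.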

Prototype-based Domain-relational Graph
Although the basic multi-domain KD method described in Section 2.1 can distill the student models across different domains, the relational information between different domains is neglected, which is important for enhancing model transferability as pointed out by previous studies (Finn et al., 2017; Javed and White, 2019). To solve this problem, we attempt to leverage meta-learning to enhance the performance and transferability of our student model. Inspired by the metric-based methods of meta-learning (Snell et al., 2017; Sung et al., 2018), we use prototype representations rather than raw samples to reflect the characteristics of each domain's data. This helps to alleviate the negative impact of abnormal samples when there are few training samples (e.g., overfitting) and makes it easier for the meta-learner to learn transferable cross-domain knowledge. Moreover, since we conduct KD over all of the student layers, we calculate different prototypes for different student layers to explicitly distinguish their characteristics. Specifically, the prototype $h_{m,d}$ of the $m$-th layer of the student model at the $d$-th domain is calculated by:

$$h_{m,d} = \frac{1}{|\mathcal{D}_d|\,L} \sum_{i \in \mathcal{D}_d} \sum_{l=1}^{L} H^{S}_{m,i,l},$$

where $\mathcal{D}_d$ refers to the training set of the $d$-th domain, $L$ refers to the sentence length (i.e., the number of tokens), $E^{S}_{i,l}$ represents the $l$-th token of the $i$-th sampled student embedding in $\mathcal{D}_d$ (used in place of $H^{S}_{m,i,l}$ for the embedding-layer prototype), and $H^{S}_{m,i,l}$ represents the $l$-th token output by the $m$-th student transformer layer for the $i$-th sample in $\mathcal{D}_d$. In practice, we calculate different prototypes for different batches of training samples.
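A minimal sketch of the prototype computation for one hypothetical batch of one domain (shapes are illustrative): the token-level layer outputs are simply averaged over tokens and over sampled sentences.

```python
import numpy as np

rng = np.random.default_rng(0)
N, L, F = 8, 16, 312                  # sampled sentences, tokens, channels
H_m = rng.normal(size=(N, L, F))      # m-th student layer outputs for domain d

# Prototype of layer m at domain d: average over tokens, then over samples.
h_md = H_m.mean(axis=(0, 1))          # shape (F,)
```

The same reduction applied to the embedding outputs yields the embedding-layer prototype.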
Afterward, these domain prototypes are leveraged to probe the relations across different domains. Although many multi-domain text mining methods have been proposed recently (Pan et al., 2020), they capture the relations separately for each given domain, which might be inconvenient and time-consuming in practice. Meanwhile, the learning process is not effective enough, since the other domains cannot learn from each other while a specific domain is being optimized. To solve this problem, we aim to discover the cross-domain relations simultaneously to make our framework more convenient and effective. To achieve this goal, we propose to use the graph attention network (GAT) (Veličković et al., 2018) to process the prototypes of all domains at the same time. In the graph, each node represents a domain prototype, and each edge weight represents the similarity of the two connected prototypes. In this way, the relations across different domains can be captured simultaneously. In detail, we set up a two-layer domain-relational graph for each layer of the student model (except for the prediction layer). The input $h_m$ of the $m$-th graph is a set of node features containing all of the domain prototypes at the $m$-th student layer, i.e., $h_m = \{h_{m,1}, \ldots, h_{m,D}\} \in \mathbb{R}^{D \times F}$, where $D$ is the total number of domains and $F$ is the channel number of each prototype.
In the first-layer domain-relational graph of the $m$-th student layer, a shared weight matrix $W_m \in \mathbb{R}^{F' \times F}$ is first applied to each node, followed by a self-attention mechanism, where $F'$ is the intermediate channel number. A multi-head concatenation mechanism with $K$ heads is then employed to stabilize the training process. Specifically, each input prototype $h_{m,d}$ is first transformed by the weight matrix $W_m$; then the attention coefficient $\alpha_{i,j,m}$ between two nodes $i, j$ is calculated by applying a weight vector $a_m \in \mathbb{R}^{2F' \times 1}$ to the concatenation of their transformed features, followed by the LeakyReLU nonlinearity and the softmax function:

$$\alpha_{i,j,m} = \frac{\exp\big(\mathrm{LeakyReLU}\big(a_m^{\top}\,[W_m h_{m,i} \oplus W_m h_{m,j}]\big)\big)}{\sum_{j' \in \mathcal{N}_i} \exp\big(\mathrm{LeakyReLU}\big(a_m^{\top}\,[W_m h_{m,i} \oplus W_m h_{m,j'}]\big)\big)},$$

where $\oplus$ represents the concatenation operation and $\mathcal{N}_i$ is the set of first-order neighbors of node $i$ (including node $i$ itself). Then the final output $h'_{m,i} \in \mathbb{R}^{KF'}$ of node $i$ is obtained as the weighted sum of the transformed features of node $i$ and its neighbors based on their attention coefficients, followed by the ELU nonlinearity and the multi-head concatenation mechanism:

$$h'_{m,i} = \bigoplus_{k=1}^{K} \mathrm{ELU}\Big(\sum_{j \in \mathcal{N}_i} \alpha^{k}_{i,j,m}\, W^{k}_m\, h_{m,j}\Big),$$

where $k$ represents the head index.
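A NumPy sketch of the first-layer graph, assuming a fully connected domain graph in which every node is its own neighbor (all names and shapes here are illustrative, not the paper's implementation):

```python
import numpy as np

def leaky_relu(x, slope=0.2):
    return np.where(x > 0, x, slope * x)

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def elu(x):
    return np.where(x > 0, x, np.exp(x) - 1.0)

def gat_first_layer(h, Ws, As):
    """h: (D, F) domain prototypes; Ws: K head matrices (F', F); As: K vectors (2F',)."""
    head_outs = []
    for W, a in zip(Ws, As):
        z = h @ W.T                                   # (D, F') transformed prototypes
        a_src, a_dst = np.split(a, 2)
        # e[i, j] = LeakyReLU(a^T [z_i (+) z_j]), via the usual additive GAT trick.
        e = leaky_relu((z @ a_src)[:, None] + (z @ a_dst)[None, :])
        alpha = softmax(e, axis=1)                    # attention over neighbors
        head_outs.append(elu(alpha @ z))              # weighted sum + ELU
    return np.concatenate(head_outs, axis=1)          # (D, K*F') multi-head concat

rng = np.random.default_rng(0)
D, F, Fp, K = 4, 6, 5, 2
h_m = rng.normal(size=(D, F))
Ws = [rng.normal(size=(Fp, F)) for _ in range(K)]
As = [rng.normal(size=(2 * Fp,)) for _ in range(K)]
h_prime = gat_first_layer(h_m, Ws, As)                # shape (4, 10)
```

With all domains connected, every attention row spans all D prototypes, so cross-domain relations are captured in a single pass.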
In the second-layer domain-relational graph of the $m$-th student layer, targeting the domain-relational ratios, we reformulate the parameters $W_m, a_m$ used in the first-layer graph as $W'_m \in \mathbb{R}^{1 \times KF'}$ and $a'_m \in \mathbb{R}^{2 \times 1}$, respectively, and do not apply the multi-head mechanism. We use the softmax operation to normalize the output and finally derive the domain-relational ratios $r_m \in \mathbb{R}^{D}$, formulated as below:

$$r_m = \mathrm{softmax}\big([s_{m,1}, \ldots, s_{m,D}]\big), \qquad s_{m,i} = \mathrm{ELU}\Big(\sum_{j \in \mathcal{N}_i} \alpha'_{i,j,m}\, W'_m\, h'_{m,j}\Big),$$

where $\alpha'_{i,j,m}$ is calculated by:

$$\alpha'_{i,j,m} = \frac{\exp\big(\mathrm{LeakyReLU}\big(a'^{\top}_m\,[W'_m h'_{m,i} \oplus W'_m h'_{m,j}]\big)\big)}{\sum_{j' \in \mathcal{N}_i} \exp\big(\mathrm{LeakyReLU}\big(a'^{\top}_m\,[W'_m h'_{m,i} \oplus W'_m h'_{m,j'}]\big)\big)}.$$
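A matching sketch of the second-layer graph (again assuming a fully connected domain graph with illustrative shapes), which collapses each node to a scalar and normalizes over domains to produce the domain-relational ratios:

```python
import numpy as np

def leaky_relu(x, slope=0.2):
    return np.where(x > 0, x, slope * x)

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def elu(x):
    return np.where(x > 0, x, np.exp(x) - 1.0)

def domain_relational_ratios(h_prime, W2, a2):
    """h_prime: (D, K*F') first-layer outputs; W2: (1, K*F'); a2: (2,).
    Single-head second graph layer, then softmax over domains -> ratios r_m."""
    z = (h_prime @ W2.T).ravel()                      # (D,) one scalar per domain
    e = leaky_relu(a2[0] * z[:, None] + a2[1] * z[None, :])
    alpha = softmax(e, axis=1)                        # attention coefficients
    s = elu(alpha @ z)                                # (D,) aggregated scalars
    return softmax(s)                                 # domain-relational ratios

rng = np.random.default_rng(0)
D, KFp = 4, 10
h_prime = rng.normal(size=(D, KFp))
r_m = domain_relational_ratios(h_prime,
                               rng.normal(size=(1, KFp)),
                               rng.normal(size=(2,)))
```

The final softmax guarantees that the ratios of one layer sum to one across domains, so they act as a normalized re-weighting of the KD losses.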

Hierarchical Compare-aggregate Mechanism
As different domains might have different preferences towards different layer prototypes, we propose a hierarchical compare-aggregate mechanism to dynamically select the most representative prototype for each domain. Our compare-aggregate mechanism is motivated by the Riesz representation theorem (Hartig, 1983), which indicates that an element can be evaluated by comparing it with a specific reference element and the quality of the element is the same as that of the selected reference element. Based on this, we establish a set of reference prototypes for each domain and hierarchically aggregate the current and previous layer prototypes based on their similarities with the corresponding reference prototypes.
Reference prototype. For each student layer, a simple choice is to use the original domain prototypes of the current layer as the reference prototypes for the current and previous layer prototypes. However, this does not integrate the information of other domains, which plays an important role in enhancing model transferability across different domains. To handle this, we introduce a self-attention mechanism over all of the domain prototypes in the same layer to inject the information of different domains. Specifically, the reference prototype $RP_m \in \mathbb{R}^{D \times F}$ of the $m$-th student layer is calculated by:

$$RP_m = \alpha^{D}_m\, h_m, \qquad \alpha^{D}_m = \mathrm{softmax}\big((h_m W^{D}_m)\, h_m^{\top}\big),$$

where $\alpha^{D}_m \in \mathbb{R}^{D \times D}$ refers to the attention matrix of the $m$-th layer, $h_m \in \mathbb{R}^{D \times F}$ refers to the prototypes of all domains at the $m$-th layer, $W^{D}_m \in \mathbb{R}^{F \times F}$ refers to a learnable parameter matrix at the $m$-th layer, and the softmax operation is performed over the last vector dimension.
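A NumPy sketch of this reference-prototype computation, under our assumption that the attention matrix comes from pairwise inner products of the projected prototypes (shapes and names are illustrative):

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

rng = np.random.default_rng(0)
D, F = 4, 6
h_m = rng.normal(size=(D, F))        # prototypes of all domains at layer m
W_dm = rng.normal(size=(F, F))       # learnable matrix W^D_m

# Self-attention over domain prototypes; softmax over the last dimension.
alpha_dm = softmax((h_m @ W_dm) @ h_m.T, axis=-1)   # (D, D) attention matrix
RP_m = alpha_dm @ h_m                                # (D, F) reference prototypes
```

Each row of RP_m mixes all D domain prototypes, so every domain's reference carries cross-domain information.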
Compare-aggregate mechanism. After obtaining the reference prototypes, we propose a compare-aggregate mechanism that hierarchically aggregates the layer prototypes by comparing them with the corresponding reference prototypes, which makes the model attend to the more representative layer prototypes for each domain. In detail, the aggregated prototype $AP_{m,d} \in \mathbb{R}^{F}$ of the $m$-th layer and $d$-th domain is formulated as:

$$AP_{m,d} = \big(\alpha^{H}_{m,d}\big)^{\top} h_{\leq m,d}, \qquad \alpha^{H}_{m,d} = \mathrm{softmax}\big(h_{\leq m,d}\, W^{H}_{m,d}\, RP_{m,d}\big),$$

where $\alpha^{H}_{m,d} \in \mathbb{R}^{m+1}$ represents the similarity ratios of the $m$-th layer and $d$-th domain, $h_{\leq m,d} \in \mathbb{R}^{(m+1) \times F}$ represents the prototypes of the $m$-th layer and its previous layers at the $d$-th domain, $W^{H}_{m,d} \in \mathbb{R}^{F \times F}$ is a learnable parameter matrix of the $m$-th layer and $d$-th domain, and $RP_{m,d} \in \mathbb{R}^{F}$ is the reference prototype of the $m$-th layer and $d$-th domain. The aggregated prototypes $AP$ are then sent to the domain-relational graphs to obtain the domain-relational ratios $r \in \mathbb{R}^{(M+1) \times D}$, as formulated by Equation (7)-(11).
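A minimal sketch of the compare-aggregate step for one layer and one domain (illustrative shapes; we assume the similarity is the projected layer prototypes' inner product with the reference prototype):

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())
    return e / e.sum()

rng = np.random.default_rng(0)
m, F = 3, 6
h_le_m = rng.normal(size=(m + 1, F))   # prototypes of layers 0..m at domain d
W_h = rng.normal(size=(F, F))          # learnable matrix W^H_{m,d}
RP_md = rng.normal(size=(F,))          # reference prototype of layer m, domain d

# Compare each layer prototype with the reference, then aggregate by similarity.
alpha_h = softmax((h_le_m @ W_h) @ RP_md)   # (m+1,) similarity ratios
AP_md = alpha_h @ h_le_m                    # (F,) aggregated prototype
```

Because the similarity ratios are normalized over layers 0..m, the aggregated prototype leans toward whichever layer best matches the reference for that domain.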
Finally, the overall loss of our HRKD can be represented as:

$$\mathcal{L}_{\mathrm{HRKD}} = \sum_{d=1}^{D}\Big[r_{0,d}\,\mathcal{L}_{\mathrm{embd}}^{d} + \sum_{m=1}^{M} r_{m,d}\big(\mathcal{L}_{\mathrm{attn}}^{m,d} + \mathcal{L}_{\mathrm{hidn}}^{m,d}\big) + \gamma\,\mathcal{L}_{\mathrm{pred}}^{d}\Big],$$

where $r_{m,d}$ is the domain-relational ratio at the $m$-th student layer and $d$-th domain.
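One plausible reading of the re-weighted loss as a sketch (ours, with made-up loss values): since r has M+1 rows, we assume its first row weights the embedding loss and the remaining rows weight the transformer-layer losses, while the prediction loss keeps its fixed weight gamma.

```python
import numpy as np

D, M, gamma = 3, 4, 1.0
r = np.full((M + 1, D), 1.0 / D)         # uniform ratios as a degenerate example
l_embd = [0.1] * D                       # embedding loss per domain
l_pred = [0.5] * D                       # prediction loss per domain
l_attn = [[0.2] * D for _ in range(M)]   # attention loss, indexed [m][d]
l_hidn = [[0.3] * D for _ in range(M)]   # hidden-state loss, indexed [m][d]

# Re-weight each per-layer, per-domain KD term by its domain-relational ratio.
L_hrkd = 0.0
for d in range(D):
    L_hrkd += r[0, d] * l_embd[d] + gamma * l_pred[d]
    for m in range(M):
        L_hrkd += r[m + 1, d] * (l_attn[m][d] + l_hidn[m][d])
```

With uniform ratios this reduces to the basic multi-domain KD loss up to a constant factor; learned ratios shift weight toward the more related domains.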

Experiment
In this section, we conduct extensive experiments on two multi-domain datasets, namely MNLI and Amazon Reviews, to demonstrate the effectiveness of our HRKD method.

Datasets and Model Settings
We evaluate our method on two multi-domain datasets: the multi-genre natural language inference (MNLI) dataset (Williams et al., 2018) and the Amazon Reviews dataset (Blitzer et al., 2007). In detail, MNLI is a natural language inference dataset with five domains for the task of predicting the entailment relation between two sentences. In our setting, we randomly sample 10% of the original training data as our development set and use the original development set as our test set. Amazon Reviews is a sentiment analysis dataset with four domains for predicting whether reviews are positive or negative. Following Pan et al. (2020), we randomly split the original data into train, development, and test sets. The statistics of these two datasets are listed in Table 1. We use BERT_B (number of layers N=12, hidden size d'=768, FFN intermediate hidden size d'_i=3072, number of attention heads h=12, number of parameters #params=109M) as the architecture of our teacher model, and BERT_S (M=4, d'=312, d'_i=1200, h=12, #params=14.5M) as our student model.
Our teacher model, HRKD-teacher, is trained in a multi-domain manner as described in Section 2.1, and our student model BERT_S is initialized with the general distillation weights of TinyBERT.

Baselines
We mainly compare our KD method with several KD baselines distilled from four teacher models: BERT_B-single, BERT_B-mix, BERT_B-mtl, and the Meta-teacher in Meta-KD (Pan et al., 2020). Specifically, BERT_B-single trains a teacher model for each domain separately on the single-domain dataset; BERT_B-mix trains a single teacher model on the combined dataset of all domains; BERT_B-mtl adopts the multi-task training method proposed by Liu et al. (2019c) to train the teacher model; Meta-teacher trains the teacher model with several meta-learning strategies, including prototype-based instance weighting and domain corruption.

Implementation Details
For the teacher model, we train the HRKD-teacher for three epochs with a learning rate of 5e-5. For the student model, we train it for ten epochs with a learning rate of 1e-3 on MNLI and 5e-4 on Amazon Reviews. Both γ and t are set to 1. For few-shot learning, the learning rate for the student model is 5e-5, while the other hyper-parameters are kept the same. The few-shot training data is selected from the beginning of our original training set with different sample ratios, while the dev and test data are the same as our original dev and test sets without sampling, to make a fair comparison. In all experiments, the sequence length is set to 128 and the batch size to 32. The hyper-parameters are tuned on the development set, and the results are averaged over five runs. Our experiments are conducted on 4 GeForce RTX 3090 GPUs.

General Experimental Results
The experimental results of our method are shown in Table 2 and demonstrate its superior performance. Specifically, with the HRKD method, the average score of the student model is 0.5% higher than that of models distilled with both the base TinyBERT-KD method and its counterpart Meta-KD method (see Table 2). It can also be observed that the improvement of our HRKD method on the Telephone domain is the most significant, which is probably due to the amount of training data. From Table 1, we can see that the Telephone domain has much more training data than the other domains, indicating that the Telephone domain can derive more relationship information from the other domains and thus achieve higher improvement. Meanwhile, as shown in the results on the Amazon Reviews dataset in Table 3, the performance of the HRKD-teacher model is only slightly better than that of the other teacher models, but the student model distilled by the HRKD method largely outperforms the models distilled by the TinyBERT-KD and Meta-KD methods, with average gains of 2.6% and 1.1% respectively, which again proves the excellent performance of our method. Note that our HRKD method significantly outperforms the base TinyBERT-KD method on both the MNLI and Amazon Reviews datasets (t-test with p < 0.1). And since the performances of the Meta-teacher and our HRKD-teacher are similar on both datasets, the impact of the teacher is negligible, making the comparison between our HRKD and its counterpart Meta-KD relatively fair.

Few-shot Learning Results
As a large amount of training data is hard to collect in reality, the few-shot learning ability of our method is worth evaluating, where both the teacher and student models are trained with few training data in each domain. We randomly sample a part of the training data in the MNLI dataset for evaluation, with chosen sample rates of 2%, 5%, 10%, and 20%. We mainly compare the performance improvement over distilling from BERT_B-single to BERT_S with TinyBERT-KD. From Figure 2, we can observe that the improvement gets more prominent as the training data gets fewer, and the average improvement rate is largest, at 10.1%, when there is only 2% of the MNLI training data. In addition, the improvement rates of our method are higher than those of Meta-KD under most of the sample rates, especially when there is only 2% training data. These results demonstrate the strong learning ability of our HRKD method under the few-shot setting.

Ablation Studies
In this section, we progressively remove each module of our KD method to evaluate the effect of each module.
The results are shown in Table 4. We first remove the self-attention mechanism across different domain prototypes (-Self-attention), and the average score on Amazon Reviews drops by 0.2%, which proves its effectiveness. Next, we replace the hierarchical compare-aggregate mechanism with a simple average operation (-Comp-Agg), and the average score drops by 0.4%, which demonstrates the effectiveness of the compare-aggregate mechanism. Then we remove the hierarchical graph structure (-Hierarchical Rel.), where the input of each domain-relational graph comes from a single student layer. As can be seen, the average score drops by 0.4%, which proves the importance of the hierarchical relationship. Finally, we remove the domain-relational graph in each layer (-Domain Rel.), and the performance significantly drops by 1.6%, which strongly demonstrates the advantage of the domain relationship.

Case Studies
We further provide some case studies to intuitively explain the effectiveness of the domain-relational ratios and hierarchical similarity ratios calculated by our HRKD method (see Table 5 and 6).
In Tables 5 and 6, we use the label to denote the categories of sampled domain examples, and we assume that if the learned domain-relational ratios and hierarchical similarity ratios are similar for domain examples of the same category but different for those of different categories, then the model has captured the cross-domain and hierarchical relational information relatively correctly. We select two typical types of cases from Amazon Reviews across four domains, in which we adjust the number of domains in each category under two settings: (i) three domains of the same category (i.e., POS) with one of the other category (i.e., NEG), as in Table 5; and (ii) two domains of the same category (i.e., POS) with two of the other category (i.e., NEG), as in Table 6.
We find the results intuitive: review texts with the same labels have similar domain-relational ratios and hierarchical similarity ratios, while different layers indeed have different domain weighting preferences and different preferences of layer prototypes for the graph input. For example, in Tables 5 and 6, positive samples tend to have higher domain-relational ratios in the middle layers (i.e., 2-4), while negative samples have higher ratios in the marginal layers (i.e., 1, 5). Meanwhile, in the second and third layers of Table 5 as well as the first layer of Table 6, the lower positive layer prototypes tend to have higher similarity ratios, and the higher positive layer prototypes in the third layer of Table 6 also tend to have higher similarity ratios, whereas the negative layer prototypes show the opposite trend. These results show that our HRKD method has distinctively and correctly captured the hierarchical and domain meta-knowledge, leading to better performance.

Related Work
Pre-trained Language Model (PLM) Compression. Due to their large size and slow inference speed, PLMs are hard to deploy on edge devices for practical usage. To solve this problem, many PLM compression methods have been proposed, including quantization (Shen et al., 2020), weight pruning (Michel et al., 2019), and knowledge distillation (KD) (Sun et al., 2019; Jiao et al., 2020). Among them, KD (Hinton et al., 2015) has been widely adopted due to its plug-and-play feasibility, aiming to distill the knowledge from a larger teacher model to a smaller student model without losing too much performance. For example, BERT-PKD (Sun et al., 2019) distills both intermediate and output layers during fine-tuning. TinyBERT (Jiao et al., 2020) additionally distills the embedding layer and attention matrices during pre-training and fine-tuning. Meta-KD (Pan et al., 2020) proposes to distill knowledge from a cross-domain meta-teacher through an instance-specific domain-expertise weighting technique.
In this paper, we propose a novel cross-domain KD framework that captures the relational information across different domains with both domain and hierarchical meta-knowledge, giving it a better capability for capturing multi-domain correlations.
Transfer Learning and Meta-learning. Transfer learning focuses on transferring knowledge from source domains to boost model performance on the target domain. Among transfer learning methods, the shared-private architecture (Liu et al., 2017, 2019c) is most commonly applied to NLP tasks; it consists of a shared network to store domain-invariant knowledge and a private network to capture domain-specific information. There are also many works applying adversarial training strategies (Shen et al., 2018; Li et al., 2019), which introduce domain adversarial classifiers to learn domain-invariant features. Besides, research on multi-domain learning has gained more and more attention recently; it is a particular case of transfer learning that targets transferring knowledge across different domains to comprehensively enhance model performance (Cai and Wan, 2019). Unlike transfer learning, the goal of meta-learning is to train a meta-learner that can easily adapt to a new task with few training data and iterations (Finn et al., 2017). Traditional meta-learning typically contains three categories of methods: metric-based (Snell et al., 2017; Sung et al., 2018), model-based (Santoro et al., 2016; Munkhdalai and Yu, 2017), and optimization-based (Ravi and Larochelle, 2017; Finn et al., 2017). In addition, the meta-learning technique can benefit the multi-domain learning task by learning the relationship information among different domains (Franceschi et al., 2017).
In this paper, we leverage meta-learning to solve the multi-domain learning task, where we consider cross-domain KD to simultaneously capture the correlation between different domains, aiming to train a better student meta-learner.

Conclusion
In this paper, we present a hierarchical relational knowledge distillation (HRKD) framework to simultaneously capture the cross-domain relational information. We build several domain-relational graphs to capture domain meta-knowledge and introduce a hierarchical compare-aggregate mechanism to capture hierarchical meta-knowledge. The learnt domain-relational ratios are leveraged to measure domain importance during the KD process. Extensive experiments on public datasets demonstrate the superior performance and solid few-shot learning ability of our HRKD method.