FedID: Federated Interactive Distillation for Large-Scale Pretraining Language Models



Introduction
The remarkable success of natural language processing (NLP) is highly dependent on large-scale pre-trained language models (PLMs; Devlin et al. 2019; Liu et al. 2019; Yang et al. 2019; Clark et al. 2020). To fully realize the potential of PLMs, they are typically trained on large amounts of combined data collected from multiple distributed user devices (a.k.a. clients) and transmitted to a single data center (a.k.a. server). With growing concerns about privacy protection, data regulations such as the Personal Data Protection Act (PDPA; Chik 2013) and the General Data Protection Regulation (GDPR; Voigt and von dem Bussche 2017) have imposed strict requirements on preserving user data privacy, making it impractical to aggregate such data in a centralized location for training. Federated learning (FL; McMahan et al. 2017) has emerged as a privacy-preserving decentralized training paradigm, in which a federation of clients is orchestrated by a central server to collaboratively train a shared global model by aggregating the local models trained on their respective data. As a result, the private data of massive numbers of clients is effectively exploited through model parameter exchange to train a unified model that performs better than any client working individually.
Previous work on federated NLP mainly targets either word-level language modeling applications such as mobile keyboard suggestion (Ji et al., 2019) and recommendation (Lin et al., 2020), or biomedical named entity recognition (Liu and Miller, 2020; Ge et al., 2020; Sui et al., 2020). More recently, Lin et al. (2022) provide a research-oriented benchmarking framework for advancing FL in NLP. However, these federated NLP frameworks are limited to identical architectures across the server and clients, making it impossible for clients to design their models independently according to their heterogeneous system resources and non-independent and identically distributed (non-IID) data. Moreover, the frequent exchange of model parameters entails expensive communication costs. These obstacles significantly hinder the applicability and scalability of FL for large-scale PLMs.
Instead, federated distillation (FD) eliminates the need to share model parameters by transferring knowledge from the clients to the server via an unlabeled public proxy dataset (Jeong et al., 2018; Li and Wang, 2019; Chang et al., 2019; Gong et al., 2022; Itahara et al., 2021; Hu et al., 2021), thereby allowing collaboration between heterogeneous models at lower communication cost. However, FD suffers from confirmation bias (Arazo et al., 2020; Pham et al., 2021) induced by incorporating incorrect or even biased predictions on unlabeled data into knowledge transfer, which causes the central model to degenerate.
To tackle this challenge, we propose Federated Interactive Distillation (FedID), in which a small amount of labeled data is retained on the server to provide feedback to the local models and debias their predictions.
In addition, previous studies on FD tend to design different partitioning strategies on different datasets, and only small-scale models are taken into consideration, which makes it difficult to evaluate and compare various FD approaches scaled to the NLP domain in a systematic and fair manner. For this reason, based on the General Language Understanding Evaluation (GLUE) benchmark (Wang et al., 2019), we create a unified benchmarking framework across multiple tasks with diverse data distributions to simulate a variety of federated scenarios for evaluating the effectiveness of these methods on the decentralized training of large-scale PLMs, advancing the research of FD in NLP. Empirical experiments show that our proposed FedID achieves the best results in both homogeneous and heterogeneous federated scenarios.
The contributions of this paper are summarized as follows: • To the best of our knowledge, we are the first to investigate the application of FD to the decentralized learning of large-scale PLMs in homogeneous and heterogeneous settings.
• We present a novel Federated Interactive Distillation framework to mitigate the problem of misleading privileged knowledge caused by confirmation bias in conventional FD.
• We provide a unified benchmarking framework across multiple NLP tasks with diverse data distributions to contribute to the research of FD in the NLP community.
Related Work

Federated Learning
FL has gained significant interest and attention in the NLP field due to its potential for collaborative training on distributed data sources while preserving data privacy (Liu et al., 2021). Recent efforts have made preliminary explorations of parameter averaging-based FL (e.g., FedAvg (McMahan et al., 2017)) in the context of NLP (Tian et al., 2022; Dong et al., 2022; Zhang et al., 2022; Lin et al., 2022). Despite some success, several system-oriented challenges must be addressed to make FL widely available in NLP, including extensive communication overhead, inability to handle heterogeneity, and vulnerability to white-box inference attacks. Several variants of FL have emerged to alleviate these issues. FedDF (Lin et al.) builds prototypical models with the same structure as the client models on the server side to enable model heterogeneity, and performs server-side ensemble distillation on unlabeled data from other domains to enhance model aggregation. FedED (Sui et al., 2020) reduces uplink communication costs by uploading the predictions of the local models instead of their parameters to train the central model, but still requires broadcasting the parameters of the central model over the downlink. Accordingly, these solutions still rely on exchanging model parameters and therefore cannot completely address these limitations.

Federated Distillation
FD is a new algorithmic paradigm for FL with fundamentally different communication properties: the knowledge obtained during local training is exchanged in the form of model outputs rather than model parameters. This shared knowledge can be an aggregated statistic of model outputs on local private data (Jeong et al., 2018) or an ensemble of local model outputs computed on a publicly available proxy dataset (Li and Wang, 2019; Chang et al., 2019; Gong et al., 2022; Itahara et al., 2021; Hu et al., 2021). Existing efforts on FD fall into two main categories:
• The server does not hold any model and is only used as an aggregator. FedMD (Li and Wang, 2019) adopts a labeled public dataset for transfer learning among clients to seek fast improvement across all participants. Cronus (Chang et al., 2019) combines the local private dataset and the pseudo-labeled public dataset for local training, where the pseudo-labels are ensembled with more robust aggregation rules.
• The server holds a central model that acts as the target for collaborative training. FedKD (Gong et al., 2022) adopts a privacy-preserving ensemble strategy on cross-domain unlabeled data for one-way, one-shot distillation of the central model. In addition to server-side distillation, DS-FL (Itahara et al., 2021) also performs client-side distillation using the ensemble predictions on the unlabeled public dataset. Instead of transferring an ensemble of predictions, MHAT (Hu et al., 2021) achieves information aggregation by directly using the predictions from multiple clients to train the central model simultaneously. However, these methods are generally subject to confirmation bias caused by transferring knowledge over unlabeled data, which greatly limits their performance.

Problem Definition
Consider a federated training environment with K clients, where the k-th client holds a labeled private dataset D^k = {(x_i^k, y_i^k)}_{i=1}^{N_k} drawn from the same or a distinct distribution, along with a homogeneous or heterogeneous local model f_k parameterized by θ_k. The goal is to train a central model f parameterized by θ on the server, without direct access to these private data.

Federated Learning for NLP
In a general FL framework, the training process is divided into T communication rounds through a server-client paradigm, where all clients share the same model architecture coordinated by a central server. At the beginning of federated training, the server initializes the global model parameters θ_0. At each communication round t, training proceeds as follows:
• Broadcast: A subset of the client population C_t ⊆ {1, 2, ..., K} is sampled to participate in training, where |C_t| = ε · K and ε is the sampling rate. The server then distributes the current global model parameters θ_{t-1} to the participating clients.
• Local training: Each participating client k ∈ C_t initializes its local model with the received parameters and updates it for several epochs on its own private data D^k,

θ_t^k ← θ_{t-1} − η ∇ L_CE(D^k; θ_{t-1}),

where η is the learning rate and L_CE denotes the loss function, which is usually a categorical cross-entropy for classification tasks.
• Upload: The updated local model parameters θ_t^k are sent back to the server.
• Aggregation: The server collects and aggregates the parameters from the clients to obtain the global model parameters for the next round,

θ_t = Σ_{k∈C_t} (|D^k| / |D|) θ_t^k,

where |D^k| and |D| = Σ_{k∈C_t} |D^k| are the numbers of local examples held by the k-th client and by all participating clients, respectively.
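The broadcast-train-upload-aggregate cycle above can be sketched in a few lines of Python. This is a toy illustration, not the paper's implementation: a single scalar stands in for the parameter vector θ, and `local_update` is a hypothetical stand-in for the client's several epochs of cross-entropy training.

```python
# One FedAvg round on a toy scalar model: each client fits the mean of its
# private data, and the server averages results weighted by |D_k| / |D|.

def local_update(theta, data, lr=0.5):
    # Stand-in for several epochs of local training on private data D_k:
    # one pass of SGD on the squared loss (theta - x)^2 / 2.
    for x in data:
        theta -= lr * (theta - x)
    return theta

def fedavg_round(theta_global, client_datasets):
    # Broadcast theta_global, run local training, then aggregate:
    # theta_t = sum_k (|D_k| / |D|) * theta_t^k
    total = sum(len(d) for d in client_datasets)  # |D|
    return sum(len(d) / total * local_update(theta_global, d)
               for d in client_datasets)
```

For example, with two single-example clients holding 1.0 and 3.0, one round moves the global parameter toward the overall mean, and a client with more data pulls the average proportionally harder.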

Federated Distillation for NLP
In a general FD framework, an unlabeled public dataset D^0 = {x_i^0}_{i=1}^{N_0} is hosted by the server and transmitted to all clients for knowledge transfer before federated training starts. At each communication round t, the training process consists of the following steps:
• Local training: Each participating client trains its local model θ_{t-1}^k on its own private data D^k for several epochs,

θ_t^k ← θ_{t-1}^k − η_k ∇ L_CE(D^k; θ_{t-1}^k),

where η_k is the learning rate of the k-th local model.
• Local prediction: Each participating client computes local predictions on the entire public proxy dataset D^0 using its updated local model.
• Upload: Participating clients upload their local predictions to the server.
• Aggregation: The predictions from the clients are collected and aggregated by the server into ensemble predictions.
• Server distillation: The ensemble predictions are treated as teacher knowledge to train the central model for several epochs.
• Broadcast: The server broadcasts the ensemble predictions to the participating clients.
• Local distillation: Each participating client distills its local model using the received ensemble predictions on the entire public proxy dataset.
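A minimal sketch of the prediction-exchange loop above follows. It is hedged: toy callables stand in for the local models, and the aggregation rule is assumed to be a simple average of per-example class-probability vectors, which is one common choice rather than the only one used in the literature.

```python
# One FD aggregation step: clients predict on the public proxy set D^0 and
# the server averages the per-example probability vectors into teacher
# targets for distillation.

def aggregate(client_preds):
    # client_preds[k][i] is client k's probability vector for example i.
    n_clients = len(client_preds)
    n_examples = len(client_preds[0])
    return [[sum(p[i][c] for p in client_preds) / n_clients
             for c in range(len(client_preds[0][i]))]
            for i in range(n_examples)]

def fd_round(client_models, public_data):
    # Local prediction on the entire public dataset, then upload + aggregate.
    preds = [[model(x) for x in public_data] for model in client_models]
    return aggregate(preds)
```

Note that only prediction vectors cross the network here; the model parameters never leave the clients, which is the source of FD's communication savings.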

Federated Interactive Distillation
In existing FD approaches, the central model is only allowed to passively mimic the local models through one-way knowledge transfer, leading to confirmation bias that heavily erodes the advantages of FD. Instead of directly transmitting the entire public dataset and its predictions between the server and clients, the proposed FedID slices the unlabeled public dataset into multiple smaller batches for training and handles only a small batch of data and predictions in each communication. This allows an interaction between the central model and the local models during the knowledge transfer process, while significantly reducing the load of a single communication. After each server distillation, the central model feeds its performance on a small amount of labeled data held by the server back to each client, which adapts its local model accordingly to rectify confirmation bias. The overall framework of FedID is presented in Figure 1.

Server Interactive Distillation
The server samples a batch of unlabeled public data x^0 from D^0 and distributes it to each participating client for local prediction. The predictions from the clients are uploaded to the server and aggregated with the same strategy as in Eq. (6); together with the batch input x^0, they are used to train the central model for knowledge transfer. The updated central model θ_t is then evaluated on a batch of data (x^val, y^val) sampled from the labeled dataset D^val held by the server. In addition to the ensemble predictions ŷ_t, this validation loss is broadcast to each participating client as feedback.
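The server side of one FedID interaction might look like the following sketch. This is our own toy rendering, not the paper's code: predictions are scalars, the central model is a constant-output function of a single parameter, and mean-squared error stands in for the distillation and validation losses.

```python
import random

class ConstModel:
    """Toy central model: predicts a single scalar theta for every input."""
    def __init__(self, theta):
        self.theta = theta
    def __call__(self, x):
        return self.theta

def server_step(central, client_models, public_pool, val_batch, lr=0.1, batch=4):
    # Sample a small batch x^0 from D^0 and collect local predictions on it.
    x0 = random.sample(public_pool, k=min(batch, len(public_pool)))
    preds = [[m(x) for x in x0] for m in client_models]
    y_t = [sum(col) / len(col) for col in zip(*preds)]   # ensemble predictions
    # Server distillation: pull the central model toward the ensemble targets.
    for x, y in zip(x0, y_t):
        central.theta -= lr * (central(x) - y)
    # Evaluate on labeled server data (x^val, y^val) to produce feedback.
    x_val, y_val = val_batch
    val_loss = sum((central(x) - y) ** 2
                   for x, y in zip(x_val, y_val)) / len(x_val)
    return y_t, val_loss   # both are broadcast back to the clients
```

The returned pair mirrors the broadcast step: clients receive the ensemble targets for distillation and the validation loss as the server's feedback signal.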

Client Interactive Distillation
For each participating client k ∈ C_t, the gradients on the ensemble predictions ŷ_t are computed to learn knowledge from the other clients and alleviate data heterogeneity. In addition, the feedback gradients from the server are computed from the validation loss and added to further rectify the local model. In this way, FedID establishes interactive distillation between the server and clients: the client-to-server interaction transfers the knowledge learned by the local models during local training on their respective private data to the central model, while the server-to-client interaction rectifies confirmation bias by allowing the local models to learn from the central model's feedback. The detailed procedure is summarized in Algorithm 1.

Datasets
We conduct experiments on tasks from the GLUE benchmark, including RTE (Dagan et al., 2005; Bar-Haim et al., 2006; Giampiccolo et al., 2007; Bentivogli et al., 2009), MRPC (Dolan and Brockett, 2005), CoLA (Warstadt et al., 2019), SST-2 (Socher et al., 2013), QNLI (Rajpurkar et al., 2016), QQP, and MNLI (Williams et al., 2018). See Appendix A for more details about GLUE.

Data Partitioning
For each task, the original development set is employed to evaluate the performance of the central and local models, while the original training set is divided into private and public datasets at a ratio of 1:1, used for client training and for knowledge transfer between the server and clients, respectively. In particular, from the resulting public dataset we further sample 10% as the labeled dataset reserved for the server, and use the rest, after discarding its labels, as the unlabeled public dataset.
Furthermore, to create disjoint client training data from the private dataset, the training instances of each client are drawn independently with class labels following a categorical distribution over N classes parameterized by a vector q (q_i ≥ 0, i ∈ [1, N], and ‖q‖_1 = 1). Meanwhile, to simulate varying data distributions across clients, we draw q ∼ Dir(αp) from a Dirichlet distribution (Hsu et al., 2019), where p is a prior class distribution over the N classes and α is a concentration parameter that controls the degree of data heterogeneity among clients. As α → ∞, clients tend to be assigned identical data distributions; conversely, as α → 0, clients are more likely to hold examples from only one random class. In our experiments, we set α to 100 and 1 to generate IID and non-IID data, respectively.
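The Dirichlet-based partition above can be sketched as follows. This follows a common implementation pattern rather than the paper's exact code: for each class, per-client proportions are drawn from a symmetric Dirichlet and the class's (shuffled) indices are split accordingly; function and argument names are ours.

```python
import numpy as np

def dirichlet_partition(labels, num_clients, alpha, seed=0):
    """Split example indices across clients with Dirichlet-controlled skew:
    large alpha -> near-IID shards, small alpha -> near-single-class shards."""
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    client_indices = [[] for _ in range(num_clients)]
    for c in np.unique(labels):
        idx = rng.permutation(np.flatnonzero(labels == c))
        # Per-client share of class c, drawn from a Dirichlet distribution.
        proportions = rng.dirichlet(alpha * np.ones(num_clients))
        cuts = (np.cumsum(proportions)[:-1] * len(idx)).astype(int)
        for client, shard in zip(client_indices, np.split(idx, cuts)):
            client.extend(shard.tolist())
    return client_indices
```

Because each class's index array is split without overlap, the resulting client shards are disjoint and jointly cover the private dataset, matching the disjointness requirement above.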

Homogeneous setting
For a homogeneous federated scenario, the model architectures of the clients are restricted to be the same as that of the server. To be compatible with the FL methods used for comparison, we adopt BERT-base (Devlin et al., 2019) as the central model, since FL usually cannot be applied to larger PLMs due to communication bottlenecks.

Heterogeneous setting
For a heterogeneous federated scenario, the central model is initialized with BERT-base, while each local model is selected from BERT-base, BERT-large, RoBERTa-base (Liu et al., 2019), or RoBERTa-large.

Implementation Details
We adopt the AdamW optimizer (Loshchilov and Hutter, 2019) with an initial learning rate of 2e-5 to update the model parameters. For single-sentence or sentence-pair inputs to the model, the maximum sequence length is set to 128 and the batch size to 32. For the hyperparameters of federated training, the numbers of epochs for local training, local distillation, and server distillation are all set to 3, the number of clients K is set to 10, the client sampling fraction ε is set to 1, and the number of communication rounds T is set to 10.
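For reference, these hyperparameters can be collected into a single configuration object, as in the sketch below (the class and field names are of our choosing, not from the paper's codebase):

```python
from dataclasses import dataclass

@dataclass
class FedTrainConfig:
    learning_rate: float = 2e-5     # AdamW initial learning rate
    max_seq_length: int = 128       # single-sentence or sentence-pair input
    batch_size: int = 32
    local_epochs: int = 3           # local training
    local_distill_epochs: int = 3   # client-side distillation
    server_distill_epochs: int = 3  # server-side distillation
    num_clients: int = 10           # K
    sampling_fraction: float = 1.0  # epsilon
    rounds: int = 10                # T

    def clients_per_round(self) -> int:
        # |C_t| = epsilon * K, per the broadcast step.
        return int(self.sampling_fraction * self.num_clients)
```

With ε = 1, every client participates in every round, so |C_t| = K = 10 throughout training.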

Baselines
We compare FedID with FL algorithms including FedAvg (McMahan et al., 2017), FedDF (Lin et al.), and FedED (Sui et al., 2020), as well as FD algorithms including FedKD (Gong et al., 2022), MHAT (Hu et al., 2021), and DS-FL (Itahara et al., 2021). We also include models with centralized training (denoted Centralized), which have access to all private data held by the clients, as an upper bound on model performance.

Homogeneous setting
Table 1 shows the performance of the models in the homogeneous setting. Without considering data privacy, the centralized models always exhibit the best performance, while the decentralized models sacrifice performance in exchange for better privacy protection. However, this performance gap gradually narrows as the training data increases. In addition, the performance of the FL and FD models degrades significantly on non-IID data. Also, when sufficient public data is available, the FD models can be comparable to the FL models, with lower communication costs.

Heterogeneous setting
Table 2 shows the performance of the models in the heterogeneous setting. The proposed FedID outperforms the other baselines, demonstrating the benefit of tackling confirmation bias. In particular, FedID exhibits strong robustness when only a small amount of training data is available: there is not enough private data to adequately train the local models, so the confirmation bias becomes more pronounced.

Cross-domain setting
We also use the original training sets of IMDB (Maas et al., 2011) and PAWS (Zhang et al., 2019) as the unlabeled public data for SST-2 and QQP, respectively, to construct cross-domain knowledge-transfer environments, where confirmation bias is more likely to occur. The experimental results on the dev sets of SST-2 and QQP are shown in Table 3, where the greater performance gap between FedID and the other baselines further confirms our claim.

Ablation Study
We remove the feedback gradient and the knowledge-transfer gradient from Eq. (15) to conduct ablation experiments on the small dataset RTE, the medium dataset SST-2, and the large dataset QQP. As shown in Table 4, performance worsens without either the feedback gradient or the knowledge-transfer gradient, with the feedback gradient contributing more.

Communication Cost
Communication costs between the server and client models across baselines are presented in Table 5.
The communication costs of FedAvg, FedDF, and FedED are much higher than those of FedKD, DS-FL, MHAT, and FedID as they entail extensive communication to share the model parameters.
FedKD exhibits the lowest communication costs, since the clients' predictions are aggregated without being sent back to the clients and only one round of communication is executed, while DS-FL and MHAT need to broadcast the ensemble predictions from the server back to each client. Similarly, FedID transmits ensemble predictions and the validation loss to the clients in batches as feedback, but the communication cost of the validation loss is negligible compared to that of the ensemble predictions. As a result, the total communication cost of FedID remains in line with DS-FL and MHAT, although the communication between the server and clients is more frequent.
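A back-of-the-envelope calculation illustrates why exchanging predictions is so much cheaper than exchanging parameters. The numbers below are illustrative assumptions of ours (4-byte floats, a round-number parameter count); Table 5 gives the exact formulations.

```python
def parameter_payload_bytes(num_params, bytes_per_value=4):
    # FedAvg/FedDF/FedED-style uplink: every model weight, every round.
    return num_params * bytes_per_value

def prediction_payload_bytes(num_public_examples, num_classes, bytes_per_value=4):
    # FedKD/DS-FL/MHAT/FedID-style uplink: one class-probability vector
    # per public proxy example.
    return num_public_examples * num_classes * bytes_per_value

# BERT-base has roughly 110M parameters; suppose 10k public examples, 2 classes.
weights = parameter_payload_bytes(110_000_000)   # 440,000,000 bytes (~440 MB)
preds = prediction_payload_bytes(10_000, 2)      # 80,000 bytes (~80 KB)
```

Under these assumptions, one parameter upload costs several thousand times more than one round of prediction uploads, which is the gap the table quantifies.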

Effect of Unlabeled Public Dataset Size
In our experimental setup, we partition the original training set into a private dataset and a public dataset. To further investigate the effect of the proportion of the public dataset on performance, we keep the size of the private dataset constant while conducting experiments with 10%, 20%, 40%, 80%, and 100% of the public dataset. The results in Figure 2 show that the performance of the central model improves to some extent as the size of the public dataset increases, with FedID still exhibiting superior performance and robustness.

Effect of Labeled Dataset Size
To investigate the effect of the size of the labeled dataset retained by the server, we experiment with 10%, 20%, 40%, 80%, and 100% of the labeled dataset. For FedID, the labeled data is used to rectify the confirmation bias in the client models' predictions, while for the other FD methods, the labeled data is added to the training of the central model. The results in Figure 3 show that FedID is the least sensitive to the size of the labeled dataset, since this data does not directly participate in the training of the central model. Moreover, although the other FD methods use the labeled data directly for additional training of the central model, no significant performance improvement is observed, because the proportion of labeled data is far lower than that of unlabeled data and its supervision of the central model is therefore limited. Our solution makes better use of the small amount of labeled data by leveraging it to rectify confirmation bias in the predictions on unlabeled data.

Effect of Number of Clients
The number of clients usually has a significant impact on performance, as the entire training dataset is partitioned and distributed across the clients. To investigate this, we increase the number of clients from 5 to 10 and 20 while keeping the size of the entire training dataset constant, so each client holds less private data as the number of clients grows; the results are shown in Figure 4.

Conclusions
This study explores the application of FD to the decentralized training of large-scale PLMs in homogeneous and heterogeneous settings, and further presents an interactive FD scheme to mitigate the confirmation bias caused by transferring knowledge on an unlabeled public dataset. Moreover, a benchmarking framework across multiple tasks with diverse data distributions is developed to contribute to FD research in the NLP community. Future work will aggregate differentially private local predictions for a stronger privacy guarantee, enhancing the resilience of FedID against malicious servers or clients.

Limitations
There are two main limitations of our work compared to previous efforts. 1) We assume that a small amount of labeled data is retained on the server. However, this situation may be common in real life. For instance, an institution may possess only a small amount of training data, not enough to train a well-performing model, and may therefore resort to collaborative training with other institutions via FD on a large amount of unlabeled public data. Directly transferring knowledge on the unlabeled data may not yield satisfactory performance, while the small amount of training data retained by the institution can be used as labeled data by the proposed FedID to maximize performance. In addition, our approach is better suited to the case where one client in the federation acts as the server. 2) Compared with other FD approaches, our solution slices the unlabeled public dataset into multiple smaller batches for training, thus entailing more frequent communication between the server and clients. However, the increase in communication frequency may be tolerable considering the similar total communication costs, and transmitting smaller packets avoids potential network congestion when the public dataset is very large.

Appendix A: GLUE Benchmark
• WNLI The Winograd Natural Language Inference (Levesque et al., 2012) task is a sentence-pair binary classification task that requires the model to determine whether two sentences in a given sentence pair are in an entailment relation, with the evaluation metric of accuracy.
• RTE The Recognizing Textual Entailment task (Dagan et al., 2005; Bar-Haim et al., 2006; Giampiccolo et al., 2007; Bentivogli et al., 2009) is a sentence-pair binary classification task that requires the model to determine whether two sentences in a given sentence pair are in an entailment relation, with the evaluation metric of accuracy.
• MRPC The Microsoft Research Paraphrase Corpus (Dolan and Brockett, 2005) is a sentence-pair binary classification task that requires the model to determine whether two sentences in a given sentence pair are semantically equivalent, with evaluation metrics of accuracy and F1-score.
• STS-B The Semantic Textual Similarity Benchmark (Cer et al., 2017) is a sentence-pair regression task that requires the model to rate how similar two sentences in a given sentence pair are on a floating-point scale from 0 to 5, with evaluation metrics of Pearson and Spearman correlations.
• CoLA The Corpus of Linguistic Acceptability (Warstadt et al., 2019) is a single-sentence binary classification task that requires the model to determine whether a given English sentence is grammatically correct, with the evaluation metric of the Matthews correlation.
• SST-2 The Stanford Sentiment Treebank (Socher et al., 2013) is a single-sentence binary classification task that requires the model to determine whether a given movie review is positive or negative in sentiment, with the evaluation metric of accuracy.
• QNLI The Question Natural Language Inference (Rajpurkar et al., 2016) is a sentence-pair binary classification task. Given a question and a context, the model is required to determine whether the context contains the answer to the question, with the evaluation metric of accuracy.
• QQP The Quora Question Pairs is a sentence-pair binary classification task. Given a pair of questions, the model is required to determine whether the two questions are semantically equivalent, with evaluation metrics of accuracy and F1-score.
• MNLI The Multi-genre Natural Language Inference (Williams et al., 2018) is a sentence-pair three-way classification task. Given a premise and a hypothesis, the model is required to determine whether the hypothesis is an entailment, contradiction, or neutral with respect to the premise. The task is divided into matched and mismatched versions, with evaluation metrics of matched accuracy and mismatched accuracy, respectively.
The statistics of these tasks are presented in Table 6.


Figure 2 :
Figure 2: Performance of the central model on SST-2 with different sizes of public dataset.

Figure 3 :
Figure 3: Performance of the central model on SST-2 with different sizes of labeled dataset.

Figure 4 :
Figure 4: Performance of the central model on SST-2 with different numbers of clients.
Algorithm 1: Federated Interactive Distillation (FedID)
Input: labeled private datasets {D^k}_{k=1}^K; unlabeled public dataset D^0; a handful of labeled data D^val held by the server; local models {θ^k}_{k=1}^K; central model θ; communication rounds T
Output: decentrally trained θ
1: Each client initializes its local model θ^k
...
15: Server aggregates local predictions into the ensemble predictions ŷ_t via Eq. (10)
16: Server updates the central model parameters θ_t via Eq. (11)
17: Server samples a mini-batch of labeled data (x^val, y^val) ∼ D^val
18: Server computes the validation loss L_CE(y^val, f(x^val; θ_t)) via Eq. (12)
19: Server broadcasts the validation loss and the ensemble predictions ŷ_t to all participants C_t
20: for each client k ∈ C_t in parallel do
21:   Update the local model parameters θ^k_t via Eq. (15)

Table 1 :
Experiment results of the homogeneous setting on the GLUE dev sets.

Table 2 :
Experiment results of the heterogeneous setting on the GLUE dev sets.

Table 3 :
Results of models in the cross-domain setting.

Table 5 :
Formulations of communication costs.
All datasets used in this work are publicly available and widely used.

References
Yuan Zhang, Jason Baldridge, and Luheng He. 2019. PAWS: Paraphrase Adversaries from Word Scrambling. In Proceedings of the 2019 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies (NAACL-HLT 2019), pages 1298-1308.

Table 6 :
Statistics of the GLUE benchmark.