Meta-Learning Adversarial Domain Adaptation Network for Few-Shot Text Classification

Meta-learning has emerged as a popular technique for tackling few-shot text classification and has achieved state-of-the-art performance. However, existing solutions heavily rely on the exploitation of lexical features and their distributional signatures on training data, while neglecting to strengthen the model's ability to adapt to new tasks. In this paper, we propose a novel meta-learning framework integrated with an adversarial domain adaptation network, aiming to improve the adaptive ability of the model and generate high-quality text embeddings for new classes. Extensive experiments are conducted on four benchmark datasets, and our method demonstrates clear superiority over the state-of-the-art models on all of them. In particular, the accuracy of 1-shot and 5-shot classification on the 20 Newsgroups dataset is boosted from 52.1% to 59.6% and from 68.3% to 77.8%, respectively.


Introduction
Few-shot text classification (Yu et al., 2018; Geng et al., 2019) is a task in which a model is adapted to predict new classes not seen in training. For each of these new classes, we only have a few labeled examples. To be specific, we are given abundant training data with a set of classes Y_train. After training, our goal is to obtain accurate classification results on testing data with a set of new classes Y_test, which is disjoint from Y_train. Only a small labeled support set is available in the testing stage. If the support set contains K labeled examples for each of the N unique classes, we refer to the task as an N-way K-shot classification.
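As a concrete illustration, sampling one N-way K-shot episode can be sketched as follows (a minimal sketch; the function and variable names are our own, not from the paper):

```python
import random

def sample_episode(data_by_class, n_way=5, k_shot=1, n_query=25):
    """Sample one N-way K-shot episode from a {class: [examples]} dict.

    Returns a support set of n_way * k_shot labeled examples and a query
    set of n_query examples drawn from the same n_way classes.
    """
    classes = random.sample(sorted(data_by_class), n_way)
    support, query_pool = [], []
    for label in classes:
        # Draw disjoint support and query examples for this class.
        examples = random.sample(data_by_class[label],
                                 k_shot + n_query // n_way)
        support += [(x, label) for x in examples[:k_shot]]
        query_pool += [(x, label) for x in examples[k_shot:]]
    query = random.sample(query_pool, min(n_query, len(query_pool)))
    return support, query
```

Meta-training repeats this sampling many times, so the model sees a different small classification task in every episode.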
Existing approaches for few-shot text classification mainly fall into two categories: (1) transfer-learning based methods (Howard and Ruder, 2018; Pan et al., 2019; Gupta et al., 2020), which aim to transfer knowledge learned from one task to a new task or leverage general-domain pretraining and fine-tuning techniques for few-shot classification;
(2) meta-learning based methods (Jamal et al., 2018; Yu et al., 2018; Geng et al., 2019, 2020; Bao et al., 2020), which aim to learn generic information (meta-knowledge) by recreating training episodes, so that the model can classify new classes from only a few labeled examples. Among these methods, Bao et al. (2020) leveraged distributional signatures (e.g., word frequency and information entropy) to train a model within a meta-learning framework, and achieved state-of-the-art performance. However, the method focuses on statistical information and ignores other implicit information, such as correlations between words. Furthermore, existing meta-learning methods heavily rely on the exploitation of lexical features and their distributional signatures on training data, while neglecting to strengthen the model's ability to adapt to new tasks.
In this paper, we propose an adversarial domain adaptation network to enhance the meta-learning framework, with the objective of improving the model's adaptive ability for new tasks in new domains. We first utilize two neural networks competing against each other, separately playing the roles of a domain discriminator and a meta-knowledge generator. The adversarial network strengthens the adaptability of the meta-learning architecture. Moreover, we aggregate transferable features generated by the meta-knowledge generator with sentence-specific features to produce high-quality sentence embeddings. Finally, we utilize a ridge regression classifier to obtain the final classification results. To the best of our knowledge, we are the first to combine adversarial domain adaptation with meta-learning for few-shot text classification. We evaluate our model on four popular datasets for few-shot text classification. Experimental results demonstrate that our method outperforms state-of-the-art models on all datasets, in both 1-shot and 5-shot classification tasks. Especially on the 20 Newsgroups dataset, our model outperforms DS-FSL (Bao et al., 2020) by 7.5% in 1-shot classification and 9.5% in 5-shot classification. In addition, we conduct visualization analysis to verify the adaptability of our model and its capability to recognize important lexical features for unseen classes.

Related Work
The mainstream approaches for few-shot text classification are based on meta-learning or transfer learning. In this section, we first briefly introduce the preliminary background of these two technologies, and then review how they are applied to support few-shot text classification.
Meta-learning Meta-learning, also known as "learning to learn", refers to improving the learning ability of a model through multiple training episodes so that it can learn new tasks or adapt to new environments quickly with a few training examples. Existing approaches mainly fall into two categories: (1) Optimization-based methods, including developing a meta-learner as an optimizer to output search steps for each learner directly (Andrychowicz et al., 2016; Ravi and Larochelle, 2017; Mishra et al., 2018; Gordon et al., 2019) and learning an optimized initialization of model parameters, which can later be adapted to new tasks by a few steps of gradient descent (Finn et al., 2017; Yoon et al., 2018; Grant et al., 2018; Bao et al., 2020). (2) Metric-based methods, including Matching Network (Vinyals et al., 2016), PROTO (Snell et al., 2017), Relation Network (Sung et al., 2018), TapNet (Yoon et al., 2019) and Induction Network (Geng et al., 2019), which aim to learn an appropriate distance metric to compare validation points with training points and make predictions by matching against the training points.
Transfer learning Few-shot text classification relates closely to transfer learning (Zhuang et al., 2021), which aims to leverage knowledge from a related domain (a.k.a. source domain) to improve learning performance and reduce the number of labeled examples required in a target domain. Compared to meta-learning, which is designed to aggregate the knowledge learned from many tasks, transfer learning typically involves only a few tasks. In addition, transfer learning aims to directly reuse or fine-tune some existing representation, while a meta-learner is explicitly optimized for adapting to new tasks. Domain adaptation (Ganin et al., 2016; Tzeng et al., 2017; Khaddaj and Hajj, 2020) is a type of transfer learning that aims to bridge the gap between the source and target domains by learning domain-invariant feature representations. Pre-trained models (Devlin et al., 2019; Yang et al., 2019; Brown et al., 2020) can also be viewed as a type of transfer learning: parameters pretrained in the source domain are fine-tuned in the target domain, leading to faster training convergence.
Few-shot text classification To tackle few-shot text classification, a straightforward idea is to apply BERT (Devlin et al., 2019) or XLNet (Yang et al., 2019), which have achieved strong performance in text classification by fine-tuning with a small number of training examples; their performance is less dependent on the number of training samples for the new classes. Some other approaches are based on transfer learning. Pan et al. (2019) proposed a modified hierarchical pooling strategy over pre-trained word embeddings to transfer knowledge obtained from source domains to the target domain. Gupta et al. (2020) developed a binary classifier on the source domain to classify new classes by prefixing class identifiers to input texts.
Meta-learning (Jamal et al., 2018; Yu et al., 2018; Geng et al., 2019, 2020; Bao et al., 2020) can also be utilized to solve few-shot text classification, and has achieved state-of-the-art performance. Yu et al. (2018) proposed an adaptive metric learning approach that automatically determines the best weighted combination from meta-training tasks for few-shot tasks. Geng et al. (2019, 2020) leveraged the dynamic routing algorithm in meta-learning for few-shot text classification. Bao et al. (2020) leveraged distributional signatures (e.g., word frequency and information entropy) to train a model within a meta-learning framework.

Method
In this section, we first present the preliminary background on episode-based meta-learning framework (Vinyals et al., 2016). After that, we explicitly describe the proposed MLADA (Meta-Learning Adversarial Domain Adaptation) Network.

Episode-based meta-learning
The goal of meta-training is to train a classifier that can learn meta-knowledge from training data. In this way, the classifier can quickly learn from a few annotations when classifying unseen classes. The "episode" training strategy that Vinyals et al. (2016) proposed has proved to be effective. Episode-based meta-learning consists of two main stages: Meta-training Firstly, N classes are sampled from the training classes Y_train. For each of these N classes, two subsets of examples are sampled separately as the support set S and the query set Q. Next, the support set S and the query set Q are fed into the model, and the parameters are updated by minimizing the loss on the query set Q. The procedure above is called a training episode and is repeated multiple times.
Meta-testing After meta-training is finished, the performance of the model is evaluated through the same episode-based mechanism. In a testing episode, N new classes are sampled from Y_test, which is disjoint from Y_train. Then the support set and the query set are sampled from these N classes. The model parameters can be fine-tuned on the small support set. The performance of the model is evaluated through the average classification accuracy on the query set across all testing episodes.
We found that only a small subset of the training data is accessible per training episode in standard episode-based meta-training (Vinyals et al., 2016). To solve this problem, we build domain adversarial tasks to utilize more training data per training episode. Details of our model are described in the next section.

Meta-Learning Adversarial Domain Adaptation Network (MLADA)

Overview Our goal is to improve the performance of few-shot classification by combining adversarial domain adaptation and episode-based meta-learning. Figure 1 gives an overview of our model. In the rest of this section, we will introduce the main components of the model.

Word Representation Layer
The goal of this layer is to represent each word with a d-dimensional vector. Following Bao et al. (2020), we construct the d-dimensional vector with word embeddings pre-trained with fastText (Joulin et al., 2016).

Domain Discriminator
We refer to the support set and the query set as the target domain and the rest of the training data as the source domain. We sample a subset of examples from the source domain as the source set. The goal of this module is to distinguish whether a sample comes from the source domain or the target domain. The discriminator is a three-layer feed-forward neural network. We apply the softmax function in the output layer to evaluate the probability distribution Pr(y|λ), where y = 0 or 1 indicates that the sample comes from the query set or the source set, respectively.
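A minimal sketch of such a three-layer feed-forward discriminator, assuming ReLU hidden activations (the paper does not state the activation function) and plain NumPy in place of a deep-learning framework:

```python
import numpy as np

def softmax(z):
    """Numerically stable softmax over the last axis."""
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def discriminator(h, params):
    """Three-layer feed-forward domain discriminator.

    h: (batch, d) meta-knowledge vectors from the generator.
    Returns Pr(y | h) over {0: query set, 1: source set}.
    """
    W1, b1, W2, b2, W3, b3 = params
    a1 = np.maximum(0.0, h @ W1 + b1)   # hidden layer 1 (ReLU)
    a2 = np.maximum(0.0, a1 @ W2 + b2)  # hidden layer 2 (ReLU)
    return softmax(a2 @ W3 + b3)        # probabilities over the two domains
```

With the hidden sizes reported later in the paper (256 and 128 units), `W1`, `W2`, and `W3` would have shapes (d, 256), (256, 128), and (128, 2).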
Meta-knowledge Generator This module is mainly composed of a bi-directional LSTM (BiLSTM) and a fully connected layer. We utilize the BiLSTM to encode contextual embeddings for each time step. The input of the module is a sequence of word vectors P: [p_1, ..., p_m], where m represents the number of words in a sentence. The output is a matrix h^p ∈ R^(d×m) composed of contextual embeddings. The goal of the meta-knowledge generator is not only to improve the final classification results, but also to confuse the domain discriminator as much as possible, so that the discriminator cannot distinguish between samples from the query set and the source set. The theory of domain adaptation suggests that, for effective domain transfer to be achieved, predictions must be made based on features that cannot discriminate between the source domain and the target domain, which is the motivation for us to build the meta-knowledge generator.
Interaction Layer We regard the vector generated by the meta-knowledge generator as the transferable features, and the word embeddings as the sentence-specific features. The role of the interaction layer is to fuse the transferable features and the sentence-specific features into sentence embeddings, which are used as the input of the classifier to obtain the final classification results. Suppose that the length of sentence p is m, the word vectors are w_i^p (i ∈ [1, m]) with dimension d, and the meta-knowledge of the sentence is k^p; then the final sentence vector s^p is obtained by fusing k^p with W^p = [w_1^p, w_2^p, ..., w_m^p].
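The exact fusion formula is not reproduced here; one plausible attention-style sketch, in which words aligned with the meta-knowledge vector receive larger weights in the sentence embedding (a hypothetical illustration, not the paper's exact equation), is:

```python
import numpy as np

def softmax(z):
    """Numerically stable softmax over a 1-D array."""
    z = z - z.max()
    e = np.exp(z)
    return e / e.sum()

def fuse(W_p, k_p):
    """Fuse sentence-specific word vectors with transferable meta-knowledge.

    W_p: (m, d) word vectors of a sentence; k_p: (d,) meta-knowledge vector.
    Hypothetical attention-style fusion: each word is scored against the
    meta-knowledge vector, and the sentence embedding is the attention-
    weighted average of the word vectors.
    """
    scores = W_p @ k_p        # (m,) relevance of each word to the meta-knowledge
    alpha = softmax(scores)   # attention distribution over the m words
    return alpha @ W_p        # (d,) sentence embedding s^p
```

Any fusion with this signature (words plus a meta-knowledge vector in, one sentence vector out) would fit the description above; the scoring function is the assumption here.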
Classifier The classifier is trained on the support set from scratch for each episode. We choose ridge regression as the classifier, for the following reasons: 1) a neural-network classifier would be trained inadequately because the number of samples in the support set is too small; 2) ridge regression admits a closed-form solution and reduces over-fitting on the small support set through proper regularization. Specifically, we minimize the regularized squared loss

L_C(θ) = Σ_{i=1}^{m} (f_θ(x^{(i)}) − y^{(i)})² + λ Σ_{j=1}^{n} θ_j²,

where m represents the number of samples in the support set, f_θ(x^{(i)}) the prediction of the ridge regressor, y^{(i)} the label of the sample, Σ_{j=1}^{n} θ_j² the squared Frobenius norm of the parameters, and λ > 0 controls the extent of the regularization.
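The closed-form solution mentioned above can be sketched as follows (the standard ridge-regression formula θ = (XᵀX + λI)⁻¹XᵀY; the episode-specific variable names are our own):

```python
import numpy as np

def ridge_fit(X, Y, lam=1.0):
    """Closed-form ridge regression: theta = (X^T X + lam I)^(-1) X^T Y.

    X: (m, d) support-set embeddings; Y: (m, n) one-hot labels.
    The closed form avoids iterative training on the tiny support set.
    """
    d = X.shape[1]
    return np.linalg.solve(X.T @ X + lam * np.eye(d), X.T @ Y)

def ridge_predict(X, theta):
    """Class scores for query embeddings; argmax over columns gives labels."""
    return X @ theta
```

Because fitting is a single linear solve, the classifier can be retrained from scratch in every episode at negligible cost, and the regularizer λ keeps it from over-fitting the K examples per class.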

Loss Function
In each training episode, we first fix the parameters of the generator and the discriminator and update the classifier's parameters on the support set. The classifier's loss function is shown in Eq. 7.
Next, we fix the parameters of the generator and the classifier and update the discriminator's parameters on the query set and the source set. We use the cross-entropy loss as the discriminator's loss function, shown in Eq. 8, where µ denotes the parameters of the discriminator, m represents the number of samples in the query set or the source set, y_d = 0 or 1 denotes whether the sample comes from the source set or the query set, and k represents the meta-knowledge vector.
Finally, we fix the parameters of the discriminator and the classifier and update the generator's parameters on the query set and the source set. The loss function of the generator is composed of two components: the first is a cross-entropy loss on the final classification results, and the second is the opposite of the discriminator's loss, which serves to confuse the discriminator.
where β represents the generator's parameters, f denotes the ridge regressor, W represents the matrix of word vectors in a sentence, and y denotes the true labels of the samples. L_D is shown in Eq. 8.
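The two loss components can be sketched as follows (NumPy, with the generator loss written as classification loss minus discriminator loss, per the description above; the exact weighting between the two terms in the paper is not reproduced here):

```python
import numpy as np

def cross_entropy(probs, labels):
    """Mean cross-entropy; probs: (m, C) probabilities, labels: (m,) ints."""
    m = len(labels)
    return -np.log(probs[np.arange(m), labels] + 1e-12).mean()

def discriminator_loss(domain_probs, domain_labels):
    """L_D: cross-entropy of the discriminator on query-vs-source samples."""
    return cross_entropy(domain_probs, domain_labels)

def generator_loss(class_probs, class_labels, domain_probs, domain_labels):
    """L_G: classification loss minus the discriminator's loss, so the
    generator is rewarded for features the discriminator cannot separate."""
    return (cross_entropy(class_probs, class_labels)
            - discriminator_loss(domain_probs, domain_labels))
```

When the discriminator is maximally confused (domain probabilities near 0.5) and the classifier is accurate, the generator's loss is driven down, which is exactly the adversarial objective described above.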
Training Procedure Notably, the meta-knowledge generator is optimized over all training episodes, while the classifier is trained from scratch for each episode. In each training episode, we first utilize the support set to update the parameters of the classifier. Next, we use the query set and the source set to update the parameters of the meta-knowledge generator and the domain discriminator. The details of the training procedure of our model are shown in Algorithm 1.

Experiments
In this section, we perform comprehensive experiments to compare our proposed model with five competitive baselines, and evaluate the performance on four text classification datasets.

Datasets
We use four benchmark datasets for text classification, whose statistics are summarized in Table 1. HuffPost headlines contains 41 classes of news headlines from the years 2012 to 2018, obtained from HuffPost (Misra, 2018). Its texts are less abundant (i.e., shorter) than those of the other datasets and are considered more challenging for text classification. Amazon product data contains product reviews from 24 product categories, comprising 142.8 million reviews spanning 1996-2014 (He and McAuley, 2016). Our task is to identify the product category of each review. Since the original dataset is prohibitively large, we sample a subset of 1,000 reviews from each category. Reuters-21578 is collected from Reuters newswire in 1987. We use the standard ApteMod version of the dataset. Following Bao et al. (2020), we consider 31 classes and remove multi-labeled articles. Each class contains at least 20 articles. 20 Newsgroups is a collection of approximately 20,000 newsgroup documents (Lang, 1995), partitioned (nearly) evenly across 20 different newsgroups.

Experiment Setup
Baselines We compare our MLADA with multiple competitive baselines, which are briefly summarized in the following:

Table 2: Mean accuracy (%) of 5-way 1-shot and 5-way 5-shot classification over four datasets.
• MAML (Finn et al., 2017) is trained by maximizing the sensitivity of the loss functions of new tasks, so that it can rapidly adapt to new tasks after the parameters have been updated through a few gradient steps.
• Prototypical Networks (Snell et al., 2017), abbreviated as PROTO, is a metric-based method for few-shot classification by using sample averages as class prototypes.
• Induction Networks (Geng et al., 2019) learns a class-wise representation by leveraging the dynamic routing algorithm in meta-learning.
• HATT (Gao et al., 2019) extends PROTO by adding a hybrid attention mechanism to the prototypical network.
• DS-FSL (Bao et al., 2020) is trained within a meta-learning framework to map distributional signatures into attention scores so as to extract more transferable features.

Implementation Details

Following Bao et al. (2020), we use pre-trained fastText (Joulin et al., 2016) for word embeddings. In the meta-knowledge generator, we use a BiLSTM with 128 hidden units. In the domain discriminator, the numbers of hidden units for the two feed-forward layers are set to 256 and 128, respectively. All parameters are optimized using Adam with a learning rate of 0.001 (Kingma and Ba, 2015). During meta-training, we perform 100 training episodes (T = 100) per epoch. Meanwhile, we apply early stopping when the accuracy on the validation set fails to improve for 20 epochs. We evaluate the model performance based on 1,000 testing episodes and report the average accuracy over 5 different random seeds. All experiments are conducted on an NVIDIA V100 GPU.

Experimental Results
The experimental results are reported in Table 2. Our model achieves the best performance across all datasets, with an average accuracy of 63.9% in 1-shot classification and 81.4% in 5-shot classification, outperforming the state-of-the-art model DS-FSL (Bao et al., 2020) by a notable 4% improvement. DS-FSL extracts transferable features via certain distributional signatures (e.g., word frequency or information entropy), but ignores other information in sentences, including implicit interaction between words. In contrast, we do not limit the transferable knowledge to statistical information. Our strategy is to combine the proposed domain adversarial network with meta-learning, generating more comprehensive transferable features. Furthermore, our model improves dramatically, by 7.5% and 9.5%, on 20 Newsgroups in 1-shot and 5-shot classification, respectively. The average length of texts in 20 Newsgroups is longer than in the other datasets.
The empirical results clearly demonstrate that our model is more suitable for longer texts, which contain more abundant text information.

Ablation Study
We conduct an ablation study to examine the effectiveness of the proposed domain adversarial network as well as the interaction layer and the source set. The results on the Amazon dataset are reported in Table 3. Firstly, we use a bi-directional LSTM instead of the proposed domain adversarial network (including the meta-knowledge generator and the domain discriminator) for sentence encoding. The performance in 1-shot and 5-shot classification decreases by 6.5% and 5.3%, respectively. This verifies the effectiveness of the proposed domain adversarial network.
Secondly, we study how the interaction layer contributes to the performance of our model. We concatenate the vector generated by the meta-knowledge generator directly with the average sentence embedding instead of using the interaction layer. From the result in Table 3, we can see that our proposed interaction layer, which combines the transferable features with the sentence-specific information, is indeed more effective.
Finally, we remove the source set and instead train the discriminator to distinguish the true classes of samples. We observe that the source set is also important to performance: without it, the model has access only to the support set and the query set in each training episode.

Visualization
We utilize visualization experiments to demonstrate that our model can generate high-quality sentence embeddings and identify important lexical features for unseen classes. We first use t-SNE (Van der Maaten and Hinton, 2008) to visualize the sentence embeddings generated by different methods on the query set, as shown in Figure 2. Compared to average word embeddings (Figure 2(a)) and DS-FSL (Figure 2(b)), our method produces better separation in both 1-shot and 5-shot classification, demonstrating the effectiveness of MLADA in leveraging the supervised learning experience to generate high-quality sentence embeddings for few-shot text classification.
Moreover, we visualize the weight vectors generated by the meta-knowledge generator and compare them with DS-FSL, as shown in Figure 3. Our model reduces the weight of "committee" while increasing the weight of "Olympic", which demonstrates that our model can recognize important lexical features in the new task, rather than simply transferring features obtained from experience.

Conclusion
In this paper, we propose a novel meta-learning approach called Meta-Learning Adversarial Domain Adaptation Network (MLADA), which can recognize important lexical features and generate high-quality sentence embeddings for new classes (not seen in training data). Specifically, we design an adversarial domain adaptation network in meta-training episodes, which aims to extract domain-invariant features and improve the adaptability of the meta-learner on new classes. We demonstrate that our method outperforms the existing state-of-the-art approaches on four standard text classification datasets. Future work includes applying MLADA to other fields, including computer vision and speech recognition, and exploring the combination of the adversarial domain adaptation network with other FSL algorithms.