BertGCN: Transductive Text Classification by Combining GNN and BERT

In this work, we propose BertGCN, a model that combines large-scale pretraining and transductive learning for text classification. BertGCN constructs a heterogeneous graph over the dataset and represents documents as nodes using BERT representations. By jointly training the BERT and GCN modules within BertGCN, the proposed model is able to leverage the advantages of both worlds: large-scale pretraining, which takes advantage of the massive amount of raw data, and transductive learning, which jointly learns representations for both training data and unlabeled test data by propagating label influence through graph convolution. Experiments show that BertGCN achieves state-of-the-art performance on a wide range of text classification datasets. Code is available at https://github.com/ZeroRin/BertGCN.


Introduction
Text classification is a core task in natural language processing (NLP) and has been used in many real-world applications such as spam detection (Wang, 2010) and opinion mining (Bakshi et al., 2016). Transductive learning (Vapnik, 1998) is a particular approach to text classification that makes use of both labeled and unlabeled examples during training. Graph neural networks (GNNs) serve as an effective approach for transductive learning (Yao et al., 2019; Liu et al., 2020). In these works, a graph is constructed to model the relationships between documents. Nodes in the graph represent text units such as words and documents, while edges are constructed based on the semantic similarity between nodes. GNNs are then applied to the graph to perform node classification. The merits of GNNs and transductive learning are as follows: (1) the decision for an instance (both training and test) does not depend merely on the instance itself, but also on its neighbors, which makes the model more robust to data outliers; (2) at training time, since the model propagates influence from supervised labels across both training and test instances through graph edges, unlabeled data also contributes to representation learning, leading to higher performance.
Large-scale pretraining has recently demonstrated its effectiveness on a variety of NLP tasks (Devlin et al., 2018; Liu et al., 2019). Trained on large-scale unlabeled corpora in an unsupervised manner, large-scale pretrained models are able to learn implicit but rich text semantics at scale. Intuitively, large-scale pretrained models have the potential to benefit transductive learning. However, existing models for transductive text classification (Yao et al., 2019; Liu et al., 2020) did not take large-scale pretraining into consideration, and its effectiveness in this setting remains unclear.
In this work, we propose BertGCN, a model that combines the advantages of both large-scale pretraining and transductive learning for text classification. BertGCN constructs a heterogeneous graph for the corpus, with nodes being words or documents; node embeddings are initialized with pretrained BERT representations, and a graph convolutional network (GCN) is used for classification. By jointly training the BERT and GCN modules, the proposed model is able to leverage the advantages of both worlds: large-scale pretraining, which takes advantage of the massive amount of raw data, and transductive learning, which jointly learns representations for both training data and unlabeled test data by propagating label influence through graph edges. The proposed BertGCN model successfully combines the powers of large-scale pretraining and graph networks, and achieves new state-of-the-art performances on a wide range of text classification datasets.

arXiv:2105.05727v3 [cs.CL] 16 May 2021
Related Work

Graph neural networks (GNNs) are connectionist models that capture dependencies and relations between graph nodes via message passing through the edges that connect them (Scarselli et al., 2008; Hamilton et al., 2017; Xu et al., 2018). GNNs are commonly categorized into graph convolutional networks (Kipf and Welling, 2016a; Wu et al., 2019), graph attention networks (Veličković et al., 2017; Zhang et al., 2018a), graph auto-encoders (Cao et al., 2016; Kipf and Welling, 2016b), graph generative networks (De Cao and Kipf, 2018; Li et al., 2018b) and graph spatial-temporal networks (Li et al., 2017). GNNs serve as powerful tools to exploit the relationships between different objects, and have been applied to various domains such as traffic prediction (Zhang et al., 2018a) and recommendation (Monti et al., 2017). In the context of NLP, GNNs have achieved remarkable successes across a wide range of end tasks such as relation extraction (Zhang et al., 2018b), semantic role labeling, data-to-text generation (Marcheggiani and Perez-Beltrachini, 2018), machine translation (Bastings et al., 2017) and question answering (Song et al., 2018).

The prevalence of neural networks has motivated a diverse array of work on neural models for text classification. Different neural architectures (Kim, 2014; Zhou et al., 2015; Radford et al., 2018; Chai et al., 2020) have demonstrated their effectiveness against traditional statistical feature-based methods (Wallach, 2006). Other works leverage label embeddings and jointly train them along with input texts (Pappas and Henderson, 2019). More recently, the success of large-scale pretrained models has spurred great interest in adapting the large-scale pretraining framework (Devlin et al., 2018) to text classification (Reimers and Gurevych, 2019), leading to remarkable progress on few-shot (Mukherjee and Awadallah, 2020) and zero-shot (Ye et al., 2020) learning.
Our work is inspired by work on graph neural networks for text classification (Yao et al., 2019; Huang et al., 2019). Different from these works, we focus on combining large-scale pretrained models and GNNs, and show that GNNs can significantly benefit from large-scale pretraining. Existing works that combine BERT and GNNs use a graph to model relationships between tokens within a single document (Lu et al., 2020; He et al., 2020b), which falls into the category of inductive learning. In contrast, we use a graph to model relationships between different samples from the whole corpus so as to exploit the similarity between labeled and unlabeled documents, and use GNNs to learn these relationships.

BertGCN
In the proposed BertGCN model, we initialize representations for document nodes in a text graph using a BERT-style model (e.g., BERT, RoBERTa). These representations are used as inputs to GCN. Document representations will then be iteratively updated based on the graph structures using GCN, the outputs of which are treated as final representations for document nodes, and are sent to the softmax classifier for predictions. In this way, we are able to leverage the complementary strengths of pretrained models and graph models.
Specifically, we construct a heterogeneous graph containing both word nodes and document nodes following TextGCN (Yao et al., 2019). We define word-document edges and word-word edges based on term frequency-inverse document frequency (TF-IDF) and positive point-wise mutual information (PPMI), respectively. The weight of the edge between two nodes $i$ and $j$ is defined as:

$$A_{ij} = \begin{cases} \mathrm{PPMI}(i, j), & i, j \text{ are words and } i \neq j \\ \text{TF-IDF}(i, j), & i \text{ is a document}, j \text{ is a word} \\ 1, & i = j \\ 0, & \text{otherwise} \end{cases}$$

In TextGCN, an identity matrix $X = I_{n_{\mathrm{doc}} + n_{\mathrm{word}}}$ is used as the initial node feature matrix, where $n_{\mathrm{doc}}$ is the number of document nodes and $n_{\mathrm{word}}$ is the number of word nodes (including both training and test). In BertGCN, we instead use a BERT-style model to obtain document embeddings and treat them as the input representations for document nodes. Document node embeddings are denoted by $X_{\mathrm{doc}} \in \mathbb{R}^{n_{\mathrm{doc}} \times d}$, where $d$ is the embedding dimensionality. Overall, the initial node feature matrix is given by:

$$X = \begin{pmatrix} X_{\mathrm{doc}} \\ 0 \end{pmatrix} \in \mathbb{R}^{(n_{\mathrm{doc}} + n_{\mathrm{word}}) \times d}$$

We feed $X$ into a GCN model (Kipf and Welling, 2016a), which iteratively propagates messages across training and test examples. Specifically, the output feature matrix of the $i$-th GCN layer, $L^{(i)}$, is computed as

$$L^{(i)} = \rho(\tilde{A} L^{(i-1)} W^{(i)})$$

where $\rho$ is an activation function, $\tilde{A}$ is the normalized adjacency matrix, and $W^{(i)} \in \mathbb{R}^{d_{i-1} \times d_i}$ is the weight matrix of the layer; $L^{(0)} = X$ is the input feature matrix of the model. The outputs of GCN are treated as the final document representations, which are fed to a softmax layer for classification:

$$Z_{\mathrm{GCN}} = \mathrm{softmax}(g(X, A))$$

where $g$ represents the GCN model. We use the cross-entropy loss over labeled document nodes to jointly optimize the parameters of BERT and GCN.
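The graph construction and a single propagation step can be sketched as follows. This is a minimal NumPy sketch: the tiny 5-node graph, its edge weights, and all variable names are illustrative stand-ins, not taken from the released code.

```python
import numpy as np

def normalize_adjacency(A):
    """Symmetric normalization D^{-1/2} A D^{-1/2}; self-loops are already
    on the diagonal of A (the i = j case of the edge-weight definition)."""
    d = A.sum(axis=1)
    d_inv_sqrt = 1.0 / np.sqrt(np.maximum(d, 1e-12))
    return A * d_inv_sqrt[:, None] * d_inv_sqrt[None, :]

def gcn_layer(A_norm, H, W):
    """One GCN layer: rho(A~ L W) with ReLU as the activation rho."""
    return np.maximum(A_norm @ H @ W, 0.0)

# Toy heterogeneous graph: nodes 0-1 are documents, nodes 2-4 are words.
n_doc, n_word, d = 2, 3, 4
n = n_doc + n_word
A = np.eye(n)                 # A_ij = 1 when i = j
A[0, 2] = A[2, 0] = 0.5       # a document-word edge (TF-IDF weight)
A[3, 4] = A[4, 3] = 0.8       # a word-word edge (PPMI weight)
A_norm = normalize_adjacency(A)

# Initial features: embeddings for document rows, zeros for word rows.
rng = np.random.default_rng(0)
X = np.zeros((n, d))
X[:n_doc] = rng.standard_normal((n_doc, d))  # stand-in for X_doc from BERT

W = rng.standard_normal((d, d))              # layer weight matrix
H1 = gcn_layer(A_norm, X, W)                 # L^{(1)}, shape (n, d)
```

In the actual model the adjacency spans the whole corpus vocabulary and document set, and the document rows of $X$ come from the BERT encoder rather than a random generator.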

Interpolating BERT and GCN Predictions
Practically, we find that optimizing BertGCN with an auxiliary classifier that directly operates on the BERT embeddings leads to faster convergence and better performance. Specifically, we construct the auxiliary classifier by directly feeding the document embeddings (denoted by $X$) to a dense layer with softmax activation:

$$Z_{\mathrm{BERT}} = \mathrm{softmax}(W X + b)$$

The final training objective is the linear interpolation of the prediction from BertGCN and the prediction from BERT, which is given by:

$$Z = \lambda Z_{\mathrm{GCN}} + (1 - \lambda) Z_{\mathrm{BERT}}$$

where $\lambda$ controls the trade-off between the two objectives. $\lambda = 1$ means we use the full BertGCN model, and $\lambda = 0$ means we only use the BERT module. When $\lambda \in (0, 1)$, we are able to balance the predictions from both models, and the BertGCN model can be better optimized.
The explanation for the better performance achieved by the interpolation is as follows: $Z_{\mathrm{BERT}}$ directly operates on the input of GCN, ensuring that the inputs to GCN are regularized and optimized towards the classification objective. This helps the multi-layer GCN model overcome intrinsic drawbacks such as vanishing gradients and over-smoothing (Li et al., 2018a), and thus leads to better performance.
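The interpolation of the two predictions can be written in a few lines. This NumPy sketch uses made-up logits for a two-class toy problem; in the full model the two prediction matrices come from the GCN output and the auxiliary dense layer, respectively.

```python
import numpy as np

def softmax(z):
    """Row-wise softmax, numerically stabilized."""
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def interpolate(Z_gcn, Z_bert, lam):
    """Final prediction: lam * Z_GCN + (1 - lam) * Z_BERT."""
    return lam * Z_gcn + (1.0 - lam) * Z_bert

# Toy per-document class distributions (rows are documents).
Z_gcn = softmax(np.array([[2.0, 0.5], [0.1, 1.2]]))
Z_bert = softmax(np.array([[1.5, 1.0], [0.0, 2.0]]))

Z = interpolate(Z_gcn, Z_bert, lam=0.7)  # lam=1: pure GCN; lam=0: pure BERT
```

Because each input is row-stochastic, any convex combination of the two is still a valid probability distribution over classes, so the interpolated output can be trained directly with cross-entropy.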

Optimization using Memory Bank
The original GCN model is trained with full-batch gradient descent, which is intractable for the proposed BertGCN model, since full-batch training cannot be applied to BERT due to memory limitations. Inspired by techniques in contrastive learning that decouple the dictionary size from the mini-batch size (Wu et al., 2018; He et al., 2020a), we introduce a memory bank that stores all document embeddings, decoupling the training batch size from the total number of nodes in the graph.
Specifically, during training we maintain a memory bank $M$ that tracks the input features for all document nodes. At the beginning of each epoch, we first compute all document embeddings using the current BERT module and store them in $M$. During each iteration, we sample a mini-batch from both labeled and unlabeled document nodes with index set $B = \{b_1, b_2, \dots, b_n\}$, where $n$ is the mini-batch size. We then recompute their document embeddings $M_B$ using the current BERT module and update the corresponding entries in $M$. Next, we use the updated $M$ as input to derive the GCN output and compute the loss for the current mini-batch. For back-propagation, $M$ is treated as constant except for the entries indexed by $B$.
With the memory bank, we are able to efficiently train the BertGCN model, including the BERT module. However, during training, the embeddings in the memory bank are computed by the BERT module at different steps within an epoch and are thus inconsistent. To mitigate this issue, we set a small learning rate for the BERT module to improve the consistency of the stored embeddings. A low learning rate, however, slows down training. To speed it up, we fine-tune a BERT model on the target dataset before training begins and use it to initialize the BERT parameters in BertGCN.
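The memory-bank procedure described above might look like the following simplified NumPy sketch. Here `bert_embed` is a hypothetical stand-in for the BERT forward pass, and gradient flow is only indicated in comments.

```python
import numpy as np

rng = np.random.default_rng(0)
n_doc, d, batch_size = 100, 8, 16  # toy sizes, not the paper's settings

def bert_embed(indices):
    """Hypothetical stand-in for running BERT over a batch of documents."""
    return rng.standard_normal((len(indices), d))

# Memory bank M: one row of input features per document node.
M = np.zeros((n_doc, d))

# Start of epoch: fill the bank with embeddings from the current BERT module.
M[:] = bert_embed(np.arange(n_doc))

# One training iteration:
B = rng.choice(n_doc, size=batch_size, replace=False)  # sampled mini-batch
M[B] = bert_embed(B)   # refresh only the sampled rows; during training,
                       # gradients flow through these rows only, while the
                       # remaining rows are treated as constants
gcn_input = M          # the full, updated bank feeds the GCN forward pass
```

This is what decouples the GCN's full-graph input from the BERT batch size: only `batch_size` documents pass through BERT per step, yet the GCN always sees features for every node.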

Experiment Setups
We run experiments on five widely-used text classification benchmarks: 20 Newsgroups (20NG), R8, R52, Ohsumed and Movie Review (MR).
We follow the protocols in TextGCN to preprocess the data. For BERT and RoBERTa, we use the output feature of the [CLS] token as the document embedding, followed by a feedforward layer to derive the final prediction. We use BERT-base and a two-layer GCN to implement BertGCN. We initialize the learning rate to 1e-3 for the GCN module and 1e-5 for the fine-tuned BERT module. We also implement our model with RoBERTa and GAT (Veličković et al., 2017). GAT variants are trained over the same graph as GCN variants, but learn edge weights through an attention mechanism instead of using a predefined weight matrix.

Main Results
Table 1 reports results for different models on the transductive text classification datasets; we run all models 10 times and report the mean test accuracy. Longer documents contain more words and thus richer co-occurrence statistics, which means that long texts may produce more document connections transited via an intermediate word node; this potentially benefits message passing through the graph, leading to better performance when combined with GCN. This may also explain why GCN models perform better than BERT models on 20NG. For datasets with shorter documents such as R52 and MR, the power of the graph structure is limited, and thus the performance boost is smaller relative to 20NG. BertGAT and RoBERTaGAT also benefit from the graph structure, but their performances are not as good as those of the GCN variants due to the lack of edge weight information.

The Effect of λ
λ controls the trade-off between training BertGCN and BERT. The optimal value of λ can differ across tasks. Fig. 1 shows the accuracy of RoBERTaGCN with different λ. On 20NG, the accuracy is consistently higher with larger λ values. This can be explained by the high performance of graph-based methods on 20NG. The model performs best when λ = 0.7, slightly better than using only the GCN prediction (λ = 1).

The Effect of Strategies in Joint Training
To overcome the inconsistency of embeddings in the memory bank, we set a smaller learning rate for the BERT module and use a fine-tuned BERT model for initialization. We evaluate the effect of these two strategies. Table 2 shows the results of RoBERTaGCN on 20NG with and without them. With the same learning rate for RoBERTa and GCN, the model cannot be trained due to inconsistency in the memory bank, regardless of whether the fine-tuned RoBERTa is used. The model can be trained successfully when we set a smaller learning rate for the RoBERTa module, and additionally using the fine-tuned RoBERTa for initialization leads to the best performance.

Conclusion and Future Work
In this work, we propose BertGCN, which combines the advantages of large-scale pretrained models and transductive learning for text classification. We efficiently train BertGCN using a memory bank that stores all document embeddings and updates only those corresponding to the sampled mini-batch. The BertGCN framework can be built on top of any document encoder and any graph model. Experiments demonstrate the effectiveness of the proposed BertGCN model. However, in this work we only use document statistics to build the graph, which might be sub-optimal compared to models that automatically construct edges between nodes. We leave this to future work.