R2D2: Recursive Transformer based on Differentiable Tree for Interpretable Hierarchical Language Modeling

Human language understanding operates at multiple levels of granularity (e.g., words, phrases, and sentences) with increasing levels of abstraction that can be hierarchically combined. However, existing deep models with stacked layers do not explicitly model any sort of hierarchical process. This paper proposes a recursive Transformer model based on differentiable CKY-style binary trees to emulate the composition process. We extend the bidirectional language model pre-training objective to this architecture, attempting to predict each word given its left and right abstraction nodes. To scale up our approach, we also introduce an efficient pruned tree induction algorithm to enable encoding in just a linear number of composition steps. Experimental results on language modeling and unsupervised parsing show the effectiveness of our approach.


Introduction
The idea of devising a structural model of language capable of learning both representations and meaningful syntactic structure without any human-annotated trees has been a long-standing but challenging goal. Across a diverse range of linguistic theories, human language is assumed to possess a recursive hierarchical structure (Chomsky, 1956, 2014; de Marneffe et al., 2006) such that lower-level meaning is combined to infer higher-level semantics. Humans possess notions of characters, words, phrases, and sentences, which children naturally learn to segment and combine.
Pretrained language models such as BERT (Devlin et al., 2019) have achieved substantial gains across a range of tasks. However, they simply apply layer stacking with a fixed depth to increase modeling power (Bengio, 2009; Salakhutdinov, 2014). Moreover, as the core Transformer component (Vaswani et al., 2017) does not capture positional information, one also needs to incorporate additional positional embeddings. Thus, pretrained language models do not explicitly reflect the hierarchical structure of linguistic understanding.

* Equal contribution. The code is available at: https://github.com/alipay/StructuredLM_RTDT
Inspired by Le and Zuidema (2015), Maillard et al. (2017) proposed a fully differentiable CKY parser to model the hierarchical process explicitly. To make their parser differentiable, they introduce an energy function that combines all possible derivations when constructing each cell representation. However, their model is based on Tree-LSTMs (Tai et al., 2015; Zhu et al., 2015) and requires O(n^3) time complexity, making it hard to scale up to large training data.
In this paper, we revisit these ideas, and propose a model applying recursive Transformers along differentiable trees (R2D2). To obtain differentiability, we adopt Gumbel-Softmax estimation (Jang et al., 2017) as an elegant solution. Our encoder parser operates in a bottom-up fashion akin to CKY parsing, yet runs in linear time with regard to the number of composition steps, thanks to a novel pruned tree induction algorithm. As a training objective, the model seeks to recover each word in a sentence given its left and right syntax nodes. Thus, our model does not require any positional embedding and does not need to mask any words during training. Figure 1 presents an example binary tree induced by our method: Without any syntactic supervision, it acquires a model of hierarchical construction from the word-piece level to words, phrases, and finally the sentence level.
[Figure 1: An example binary tree induced by our method for the word-piece sequence "what ' s more , such short - term cat #ac #ly #sms are sur #vi #vable and are no cause for panic selling ."]

We make the following contributions:
• Our novel CKY-based recursive Transformer over differentiable trees learns both representations and tree structures (Section 2.1).
• We propose an efficient optimization algorithm that scales our approach to a linear number of composition steps (Section 2.2).
• We design an effective pre-training objective, which predicts each word given its left and right syntactic nodes (Section 2.3).
For simplicity and efficiency reasons, in this paper we conduct experiments only on the tasks of language modeling and unsupervised tree induction. The experimental results on language modeling show that our model significantly outperforms baseline models with the same parameter size, even with fewer training epochs. On unsupervised parsing, our model likewise obtains competitive results.
Methodology

Differentiable Tree. We follow Maillard et al. (2017) in defining a differentiable binary parser with a CKY-style (Cocke, 1969; Kasami, 1966; Younger, 1967) encoder. Informally, given a sentence S = {s_1, s_2, ..., s_n} with n words or word-pieces, Figure 2 shows the chart data structure T, where each cell T_{i,j} is a tuple ⟨e_{i,j}, p_{i,j}, p̄_{i,j}⟩: e_{i,j} is a vector representation, p_{i,j} is the probability of a single composition step, and p̄_{i,j} is the probability of the subtree at span [i, j] over the sub-string s_{i:j}. At the lowest level, we have terminal nodes T_{i,i} with e_{i,i} initialized as the embedding of input s_i, while p_{i,i} and p̄_{i,i} are set to one. When j > i, the representation e_{i,j} is a weighted sum of intermediate combinations c^k_{i,j}, defined as:

c^k_{i,j}, p^k_{i,j} = f(e_{i,k}, e_{k+1,j}),
p̄^k_{i,j} = p^k_{i,j} · p̄_{i,k} · p̄_{k+1,j},
w_{i,j} = GUMBEL(log [p̄^i_{i,j}, ..., p̄^{j-1}_{i,j}]),
e_{i,j} = [c^i_{i,j}, ..., c^{j-1}_{i,j}] · w_{i,j},
p_{i,j} = [p^i_{i,j}, ..., p^{j-1}_{i,j}] · w_{i,j},   p̄_{i,j} = [p̄^i_{i,j}, ..., p̄^{j-1}_{i,j}] · w_{i,j}.   (1)

Here, k is a split point from i to j−1, f(·) is a composition function that we further define below, p^k_{i,j} and p̄^k_{i,j} denote the single-step combination probability and the subtree probability, respectively, at split point k, the [·, ·] notation denotes stacking of tensors (so [p^i_{i,j}, ..., p^{j-1}_{i,j}] is the concatenation of all p^k_{i,j} values), and GUMBEL is the Straight-Through Gumbel-Softmax operation of Jang et al. (2017) with temperature set to one.

Recursive Transformer. Figure 3 provides a schematic overview of the composition function f(·), comprising N Transformer layers. Taking c^k_{i,j} and p^k_{i,j} as an example, the input is a concatenation of two special tokens [SUM] and [CLS], the left cell e_{i,k}, and the right cell e_{k+1,j}. We also add role embeddings ([LEFT] and [RIGHT]) to the left and right inputs, respectively. Thus, the input consists of four vectors in R^d. We denote by h_[SUM], h_[CLS], h_{i,k}, h_{k+1,j} ∈ R^d the hidden states output by the N Transformer layers. A linear layer over h_[SUM] then yields

p^k_{i,j} = σ(W_p h_[SUM] + b_p),

where W_p ∈ R^{1×d}, b_p ∈ R, and σ refers to the sigmoid activation.
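The cell-filling step of Equation 1 can be sketched in NumPy. This is a minimal forward-pass illustration, not the authors' implementation: toy_compose is a hypothetical stand-in for the N-layer Transformer f(·), and only the hard forward pass of the straight-through estimator is shown (the soft gradient path is omitted).

```python
import numpy as np

rng = np.random.default_rng(0)

def toy_compose(left, right):
    """Toy stand-in for the Transformer composition f(.): returns a
    combined vector and a single-step composition probability."""
    c = 0.5 * (left + right)                 # combined representation
    p = 1.0 / (1.0 + np.exp(-c.sum()))       # sigmoid over a toy score
    return c, p

def st_gumbel_softmax(logits, tau=1.0):
    """Straight-Through Gumbel-Softmax, forward pass only: sample a
    one-hot vector from softmax((logits + gumbel noise) / tau)."""
    g = -np.log(-np.log(rng.uniform(size=logits.shape)))
    y = np.exp((logits + g) / tau)
    y /= y.sum()
    hard = np.zeros_like(y)
    hard[np.argmax(y)] = 1.0
    return hard  # during training, gradients flow through the soft y

def combine_cell(e, p_bar, i, j):
    """Fill cell (i, j) from all split points k in [i, j-1]."""
    cands, step_ps, subtree_ps = [], [], []
    for k in range(i, j):
        c, p = toy_compose(e[i][k], e[k + 1][j])
        cands.append(c)
        step_ps.append(p)
        # subtree probability: step prob times child subtree probs
        subtree_ps.append(p * p_bar[i][k] * p_bar[k + 1][j])
    w = st_gumbel_softmax(np.log(np.array(subtree_ps)))
    e_ij = (w[:, None] * np.array(cands)).sum(axis=0)
    return e_ij, float(w @ np.array(step_ps)), float(w @ np.array(subtree_ps))
```

Filling the chart bottom-up (shorter spans first) then reduces to calling combine_cell for every span of length ≥ 2.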
Then, c^k_{i,j} is computed as

w^k_{i,j} = softmax(W_w h_[CLS]),
c^k_{i,j} = [h_{i,k}, h_{k+1,j}] · w^k_{i,j},

where W_w ∈ R^{2×d}, w^k_{i,j} ∈ R^2 captures the respective weights of the left and right hidden states h_{i,k} and h_{k+1,j}, and the final c^k_{i,j} is a weighted sum of h_{i,k} and h_{k+1,j}.
Tree Recovery. As the Straight-Through Gumbel-Softmax in practice picks the optimal split point k at each cell, it is straightforward to recover the complete derivation tree, Tree(T_{1,n}), from the root node T_{1,n} recursively in a top-down manner.
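The top-down recovery can be sketched in a few lines. Here `split` is a hypothetical mapping from each cell (i, j) to the split point chosen by the argmax of its Gumbel-Softmax weights:

```python
def recover_tree(split, i, j):
    """Recover the derivation tree for span [i, j] top-down, following
    the split point chosen at each cell."""
    if i == j:
        return i                       # terminal node (token index)
    k = split[(i, j)]                  # best split point for this cell
    return (recover_tree(split, i, k), recover_tree(split, k + 1, j))
```

For a three-token span with split[(0, 2)] = 1 and split[(0, 1)] = 0, this yields the left-branching tree ((0, 1), 2).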

Complexity Optimization
Algorithm 1 Pruned chart table construction

function PRUNING(T, m)
    u ← FIND(T, m)                       ▷ Best merge point (Algorithm 2)
    n ← T.len
    T′ ← new table of length n − 1
    for i ∈ 1 to n − 1 do
        for j ∈ i to n − 1 do
            i′ ← i ≥ u + 1 ? i + 1 : i
            j′ ← j ≥ u ? j + 1 : j
            T′_{i,j} ← T_{i′,j′}         ▷ Skip dark gray cells in Fig. 4
    return T′

function TREEINDUCTION(T, m)
    for t ∈ 1 to T.len − 1 do
        if t ≥ m then
            T ← PRUNING(T, m)
        l ← min(t + 1, m)                ▷ Clamp the span length
        for i ∈ 1 to T.len − l + 1 do
            j ← i + l − 1
            if T_{i,j} is empty then
                Compute cell T_{i,j} with Equation 1
    return T

As the core computation comes from the composition function f(·), our pruned tree induction algorithm aims to reduce the number of composition calls from O(n^3) in the original CKY algorithm to linear.
Our intuition is based on the conjecture that locally optimal compositions are likely to be retained and to participate in higher-level feature combination. Specifically, taking T^2 in Figure 4 (c) as an example, we only pick locally optimal nodes from the second row of T^2. If T^2_{4,5} is locally optimal and non-splittable, then all the cells highlighted in dark gray in (d) may be pruned, as they break the span [4, 5]. For any later encoding, including higher-level ones, we can merge the nodes and treat T^2_{4,5} as a new non-splittable terminal node (see (e) to (g)).
Algorithm 2 Find the best merge point

function FIND(T, m)
    n ← T.len
    l ← []                               ▷ Create an array
    c ← ∅
    for i ∈ 1 to n − 1 do
        c ← c ∪ {T_{i,i+1}}              ▷ Collect cells on the 2nd row
    τ ← ∅
    for i ∈ 1 to n − m + 1 do            ▷ Iterate to the m-th row
        τ ← τ ∪ nodes(Tree(T_{i,i+m−1}))
    c ← c ∩ τ
    for x ∈ c do                         ▷ x spans [x.i, x.j]
        p_l ← 1 − T_{x.i−1,x.i}.p        ▷ If an index is out of boundary, the
        p_r ← 1 − T_{x.j,x.j+1}.p        ▷ adjacent probability is set to 0
        l.append(x.p · p_l · p_r)
    return argmax_i l[i]

Figure 4 walks through the steps of processing a sentence of length 6, where s_{i:j} denotes the sub-string from s_i to s_j. Algorithm 1 constructs our chart table T sequentially row by row. Let t be the time step and m the pruning threshold. First, we invoke TREEINDUCTION(T, m) and compute a row of cells at each time step while t < m, as in regular CKY parsing, leading to result (b) in Figure 4. When t ≥ m, we call PRUNING(T, m) in Line 15. As mentioned, the PRUNING function finds the locally optimal combination node in T, prunes some cells, and returns a new table omitting the pruned cells. Algorithm 2 shows how we FIND the locally optimal combination node. The candidate set for the locally optimal node is the second row of T, and we also take advantage of the subtrees derived from all nodes in the m-th row to limit the candidate set. Lines 6 to 9 in Algorithm 2 generate the candidate set: each candidate must be in the second row of T and must also be used in a subtree of some node in the m-th row. Given the candidate set, we select the least ambiguous candidate (Lines 11 to 17), i.e., the node with maximum own probability whose adjacent bi-gram node probabilities (Lines 13 and 14) are as low as possible.

[Figure 4 (c): For each cell in the m-th row, recover its subtree and collect candidate nodes, each of which must appear in the subtree and also be in the 2nd row; e.g., the tree of T^2_{3,5} is within the dark line, and the candidate node is T^2_{4,5}. (d): Find the locally optimal node, here T^2_{4,5}, and treat span s_{4:5} as non-splittable; the dark gray cells thus become prunable.]
After selecting the best merge point u, all cells whose spans break [u, u+1] are pruned (highlighted in dark gray in (d)), and we generate a new table T^{t+1} by removing the pruned nodes (Lines 4 to 9 in Algorithm 1). We then obtain (e), and compute the empty cells on the m-th row of T^3 to obtain (f). We continue with the loop in Line 13, trigger PRUNING again to obtain a new table T^{t+1}, and fill the empty cells on its m-th row. Continuing this process until all cells are computed, as shown in (g), we finally obtain a discrete chart table as given in (h).
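The index shifting in Lines 4 to 9 of Algorithm 1 can be sketched as follows. This is an illustrative reimplementation, not the released code: the chart is a dict keyed by 1-indexed (i, j) spans, and cells whose spans would break the merged pair [u, u+1] are simply never copied into the new table.

```python
def prune_table(T, n, u):
    """Re-index a length-n chart after merging adjacent terminals u and
    u+1 (1-indexed). Returns the new, smaller chart."""
    T2 = {}
    for i in range(1, n):          # the new table has n - 1 positions
        for j in range(i, n):
            i2 = i + 1 if i >= u + 1 else i
            j2 = j + 1 if j >= u else j
            if (i2, j2) in T:
                T2[(i, j)] = T[(i2, j2)]
    return T2
```

After the merge, the old cell spanning [u, u+1] becomes the new terminal at position u, and everything to its right shifts left by one.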
In terms of time complexity, when t ≥ m, there are at most m cells to update, so the complexity of each step is at most O(m^2). When t < m, the complexity is O(t^3) ≤ O(m^2 t). Thus, the overall number of calls to the composition function is O(m^2 n), which is linear since m is a constant.
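Under the simplifying assumption stated above (full CKY rows while t < m, then at most m cells with at most m − 1 split points each per pruned step), the call count can be checked numerically:

```python
def composition_calls(n, m):
    """Upper bound on composition-function calls for a sentence of
    length n with pruning threshold m, mirroring the complexity
    analysis in the text."""
    calls = 0
    for t in range(1, n):
        if t < m:
            calls += (n - t) * t     # regular CKY row: n - t cells, t splits
        else:
            calls += m * (m - 1)     # pruned step: at most m cells to refill
    return calls
```

For fixed m, the count grows by the same constant amount for each additional token, i.e., linearly in n, whereas full CKY needs on the order of n^3 calls.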

Pretraining
Different from the masked language model training of BERT, we directly minimize the sum of the negative log-probabilities of all words or word-pieces given their bidirectional context:

L(θ) = − Σ_{i=1}^{n} log P(s_i | s_{1:i−1}, s_{i+1:n}, θ).

As shown in Figure 5, after invoking our recursive encoder on a sentence S, we directly use e_{1,i−1} and e_{i+1,n} as the left and right contexts, respectively, for each word s_i. To distinguish this from the encoding task, the input consists of a concatenation of a special token [MASK], e_{1,i−1}, and e_{i+1,n}. We apply the same composition function f(·) as in Figure 3, and feed h_[MASK] through an output softmax to predict the distribution of s_i over the complete vocabulary. Finally, we compute the cross-entropy between the predicted and ground-truth distributions.
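The prediction head can be sketched as follows. This is a toy illustration under stated assumptions: `compose` is a hypothetical stand-in for the shared Transformer f(·) applied to the [MASK] input, and W_out is an illustrative output projection, not a parameter name from the paper.

```python
import numpy as np

rng = np.random.default_rng(0)

def predict_word(e_left, e_right, W_out, compose):
    """Compose the left- and right-context roots into a [MASK]
    representation, then project onto the vocabulary and normalize."""
    h_mask = compose(e_left, e_right)
    logits = W_out @ h_mask
    probs = np.exp(logits - logits.max())    # numerically stable softmax
    return probs / probs.sum()

def word_nll(probs, gold_id):
    """Negative log-likelihood of the ground-truth word."""
    return -np.log(probs[gold_id])
```

Summing word_nll over every position of every sentence gives the training loss above.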
In cases where e_{1,i−1} or e_{i+1,n} is missing due to the pruning algorithm in Section 2.2, we simply use the longest adjacent non-empty cell on the left or right. For example, T_{x,i−1} denotes the longest non-empty left cell, i.e., no non-empty T_{x′,i−1} exists for any x′ < x. Analogously, T_{i+1,y} is defined as the longest non-empty right cell. Note that although the final table is sparse, the sentence representation e_{1,n} is always established.
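The left-context fallback can be sketched over a sparse chart stored as a set of non-empty (start, end) spans (0-indexed here for illustration):

```python
def longest_left_cell(nonempty, i):
    """Return the longest non-empty cell ending at position i - 1,
    i.e., the one with the smallest start index, or None if the word
    has no left context."""
    for x in range(i):
        if (x, i - 1) in nonempty:
            return (x, i - 1)
    return None
```

The right-context fallback is symmetric, scanning end indices from the last position downward.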

Experiments
As our approach (R2D2) is able to learn both representations and intermediate structure, we evaluate its representation learning ability on bidirectional language modeling and evaluate the intermediate structures on unsupervised parsing.

Setup
Baselines and Evaluation. As the objective of our model is to predict each word with its left and right context, we use the pseudo-perplexity (PPPL) metric of Salazar et al. (2020) to evaluate bidirectional language modeling.
Formally, for a sentence S of n tokens,

PPPL(S) = exp( − (1/n) Σ_{i=1}^{n} log P(s_i | s_{1:i−1}, s_{i+1:n}, θ) ).

PPPL is a bidirectional version of perplexity, establishing a macroscopic assessment of the model's ability to deal with diverse linguistic phenomena. We compare our approach with state-of-the-art autoencoding and autoregressive language models capable of capturing bidirectional contexts, including BERT, XLNet (Yang et al., 2019), and ALBERT (Lan et al., 2020). For a fair comparison, all models use the same vocabulary and are trained from scratch on the same language modeling corpus. The models are all based on the open-source Transformers library. To compute PPPL for models based on sequential Transformers, for each word s_i, we mask only s_i while all other tokens remain visible. When evaluating our R2D2 model, for each word s_i, we treat the left context s_{1:i−1} and the right context s_{i+1:n} as two separate sentences, encode them independently, and pick the root nodes as the final representations of the left and right contexts. We then predict word s_i by running our Transformer as in Figure 5.
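Given per-token conditional log-probabilities (one per token, each conditioned on the full bidirectional context), the corpus-level metric reduces to an exponentiated average:

```python
import math

def pseudo_perplexity(sentence_logprobs):
    """Corpus-level pseudo-perplexity: the exponentiated mean negative
    log-probability of each token given both its left and right
    context. Input: a list of sentences, each a list of log P(s_i | .)."""
    n = sum(len(s) for s in sentence_logprobs)
    total = sum(lp for s in sentence_logprobs for lp in s)
    return math.exp(-total / n)
```

For instance, if every token is predicted with probability 1/4, the pseudo-perplexity is exactly 4.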
Data. The English-language WikiText-2 corpus (Merity et al., 2017) serves as training data. The dataset is split at the sentence level, and sentences longer than 128 tokens after tokenization are discarded (about 0.03% of the original data). The total number of sentences is 68,634, and the average sentence length is 33.4.
Hyperparameters. The tree encoder of our model uses 3-layer Transformers with 768-dimensional embeddings, 3,072-dimensional hidden layer representations, and 12 attention heads.
Other Transformer-based models share the same setting but vary in the number of layers. Training is conducted using Adam optimization with weight decay and a learning rate of 5 × 10^−5. The batch size is set to 8 for m=8 and 32 for m=4, though we also limit the maximum total length per batch, such that excess sentences are moved to the next batch. The limit is set to 128 for m=8 and 512 for m=4. Ten epochs of training take about 43 hours with m=8 and about 9 hours with m=4, on 8 V100 GPUs.

Table 1 reports pseudo-perplexity for different numbers of training epochs. These results suggest that our model architecture utilizes the training data more efficiently. Comparing the pruning thresholds m=4 and m=8 (last two rows), the two models converge to a similar place after 60 epochs, confirming the effectiveness of the pruned tree induction algorithm. We also replace the Transformers with Tree-LSTMs as in Jang et al. (2017), denoted T-LSTM, finding that its perplexity is significantly higher than that of the other models. The best score is from the BERT model with 12 layers at epoch 60. Although our model has linear time complexity, it is still a sequential encoding model, and hence its training time is not comparable to that of fully parallelizable models; we therefore do not report results for a 12-layer version of our model in Table 1. The results comparing models with the same parameter size suggest that our model may perform even better with deeper layers. Table 2 shows the training time of our R2D2 with and without pruning; the last row is proportionally estimated by running the small setting (12×12×1). It is clear that running our R2D2 without pruning is not feasible.

Unsupervised Constituency Parsing
We next assess to what extent the trees that naturally arise in our model bear similarities with human-specified parse trees.

Setup
Baselines and Evaluation. For comparison, we further include four recent strong models for unsupervised parsing. Following prior work, we train all systems on a training set consisting of raw text, and evaluate and report results on an annotated test set. As the evaluation metric, we adopt sentence-level unlabeled F1, computed using the script from Kim et al. (2019a). Per convention, we compare against the non-binarized gold trees. The best checkpoint for each system is picked based on validation-set scores.
As our model is pretrained on word-pieces, for a fair comparison we test all models with two types of input: word level (W) and word-piece level (WP). To support word-piece level evaluation, we convert gold trees to word-piece level trees by simply breaking each terminal node into a non-terminal node with its word-pieces as terminals, e.g., (NN discrepancy) becomes (NN (WP disc) (WP ##re) (WP ##pan) (WP ##cy)). We set the pruning threshold m to 8 for our tree encoder.
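The gold-tree conversion can be sketched over trees encoded as nested tuples with string leaves. The tokenizer here is a toy stand-in, not an actual WordPiece implementation:

```python
def to_wordpiece_tree(tree, tokenize):
    """Convert a word-level gold tree (nested tuples, string leaves)
    into a word-piece level tree: each multi-piece terminal becomes a
    node over its pieces; single-piece words stay as leaves."""
    if isinstance(tree, str):
        pieces = tokenize(tree)
        return pieces[0] if len(pieces) == 1 else tuple(pieces)
    return tuple(to_wordpiece_tree(child, tokenize) for child in tree)
```

Any real word-piece tokenizer can be plugged in for `tokenize`; the tree shape only changes where a word splits into multiple pieces.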
To support a word-level evaluation, since our model uses word-pieces, we force it to not prune or select spans that conflict with word spans during prediction, and then merge word-pieces into words in the final output. However, note that this constraint is only used for word-level prediction.
For training, we use the same hyperparameters as in Section 3.1.1. Our model pretrained on WikiText-2 is finetuned on the training set with the same unsupervised loss objective. For Chinese, we use a subset of Chinese Wikipedia for pretraining, specifically the first 100,000 sentences shorter than 150 characters.
Data. We test our approach on the Penn Treebank (PTB) (Marcus et al., 1993) with the standard splits (2-21 for training, 22 for validation, 23 for test) and the same preprocessing as in recent work (Kim et al., 2019a), where punctuation is discarded and all tokens are lower-cased. To explore the universality of the model across languages, we also run experiments on the Chinese Penn Treebank (CTB) 8 (Xue et al., 2005), from which we also remove punctuation. Note that in all settings, training is conducted entirely on raw unannotated text.

[Table 3: Unlabeled F1 scores, following the evaluation setup of Kim et al. (2019a). F1 (M) is the maximum score over 4 runs with different random seeds; the F1 column shows the result of a run with a random seed. The bottom three systems take word-pieces as input and are measured against word-piece level gold trees.]

Results and Discussion
Table 3 shows that our model obtains competitive F1 scores; considering that our training objective is not tailored to parsing, this is a remarkable result. Note that models such as C-PCFG are specially designed for unsupervised parsing, e.g., adopting 30 nonterminals, 60 preterminals, and a training objective well aligned with unsupervised parsing. In contrast, the objective of our model is bi-directional language modeling, and the derived binary trees are merely a by-product, emerging naturally from the model's preference for structures conducive to better language modeling. Another factor is the mismatch between our training and evaluation: we train at the word-piece level but evaluate against word-level gold trees. For comparison, we therefore also consider DIORA (WP), C-PCFG (WP), and our system all trained on word-piece inputs and evaluated against word-piece level gold trees. The last three lines show the results, with our system achieving the best F1. Since breaking words into word-pieces introduces word boundaries as new spans, and word boundaries are comparatively easy to recognize, the overall F1 score may increase, especially on Chinese.
Analysis. To better understand why our model works better when evaluated on word-piece level gold trees, we compute the recall of constituents following Kim et al. (2019b) and Drozdov et al. (2020). Besides standard constituents, we also measure the recall of word-piece chunks and proper noun chunks; the latter are extracted by finding adjacent unary nodes with the same parent and tag NNP. Table 4 reports recall scores for constituents and words on the WSJ and CTB test sets. Our model and DIORA perform better on small semantic units, while C-PCFG better matches larger semantic units such as VP and SBAR. The recall of word chunks (WD) of our system is almost perfect and significantly better than that of the other algorithms. Note that all word-piece level models are trained fairly, without using any boundary information. Although it is trivial to recognize English word boundaries among word-pieces using rules, this is non-trivial for Chinese. Additionally, the recall of proper noun segments is likewise significantly better for our model than for the other algorithms.

Dependency Tree Compatibility
We compared examples of trees inferred by our model with the corresponding ground-truth constituency trees (see Appendix), encountering reasonable structures that differ from the constituent structure posited by the manually defined gold trees. Experimental results of previous work (Drozdov et al., 2020; Kim et al., 2019a) also show significant variance across random seeds. Thus, we hypothesize that an isomorphy-focused F1 evaluation against gold constituency trees is insufficient to evaluate how reasonable the induced structures are. In contrast, dependency grammar encodes semantic and syntactic relations directly, and has the best interlingual phrasal cohesion properties (Fox, 2002). Therefore, we introduce dependency compatibility as an additional metric and re-evaluate all system outputs.

Setup
Baselines and Data. As our approach is a word-piece level pretrained model, to enable a fair comparison, we train all models on word-pieces with the same settings as in the original papers. Evaluation at the word-piece level reveals a model's ability to learn structure at a smaller granularity. In this section, we keep the word-level gold trees unchanged and invoke Stanford CoreNLP (Manning et al., 2014) to convert the WSJ and CTB trees into dependency trees.
Evaluation. Our metric quantifies the compatibility of a predicted tree by counting how many of its spans comply with the dependency relations in the gold dependency tree. Specifically, as illustrated in Figure 6, a span is deemed compatible with the ground truth if and only if it forms an independent subtree. Formally, given a gold dependency tree D, we denote by S(D) the raw token sequence of D. For a binary tree predicted over word-level input, let Z denote its predicted spans. For any span z ∈ Z, the subgraph of D induced by the nodes in z and the directed edges between them is referred to as G_z. O(G_z) is defined as the set of nodes whose parent is not in G_z, and I(G_z) denotes the set of nodes with a child not in G_z; thus |O(G_z)| and |I(G_z)| are the out-degree and in-degree of the subgraph G_z. Let I(z) denote whether z is valid, defined as

I(z) = 1 if |O(G_z)| = 1 and I(G_z) ⊆ O(G_z), and 0 otherwise.
For binary trees over word-piece level input, if z breaks a word-piece span, then I(z) = 0; otherwise, word-pieces are merged into words and the word-level logic applies. Additionally, to make the results at the word and word-piece levels comparable, I(z) is forced to zero if z covers only a single word. The final compatibility of Z is the fraction of spans z ∈ Z with I(z) = 1.

Table 5 lists system results on the WSJ and CTB test sets. "% all" refers to the accuracy over all test sentences, while "% n≤x" is the accuracy on sentences with up to x words. It is clear that the smaller granularity at the word-piece level makes this task harder. Our model performs better than the other systems at the word-piece level on both English and Chinese, and even outperforms the baselines in many cases at the word level. It is worth noting that these results are computed on the same predicted binary trees as used for unsupervised constituency parsing, yet here our model outperforms baselines that score better in Table 3. One possible interpretation is that our approach learns to prefer structures that differ from human-defined phrase structure grammar but are self-consistent and compatible with a tree structure.
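The word-level validity test I(z) described above can be sketched as follows; `heads` encodes the gold dependency tree (head index per token, −1 for the root), and token indices are 0-based here:

```python
def span_compatible(heads, i, j):
    """Check whether span [i, j] (inclusive token indices) forms an
    independent subtree of the gold dependency tree: exactly one node
    has its head outside the span (|O(G_z)| = 1), and every node with
    a child outside the span is that same node (I(G_z) is a subset of
    O(G_z))."""
    nodes = set(range(i, j + 1))
    out_nodes = {t for t in nodes if heads[t] not in nodes}
    in_nodes = {t for t in nodes
                if any(heads[c] == t and c not in nodes
                       for c in range(len(heads)))}
    return len(out_nodes) == 1 and in_nodes <= out_nodes
```

For "the cat sat" with the chain the → cat → sat, the span "the cat" is compatible; if both "the" and "cat" attach directly to "sat", it is not.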

Results and Discussion
To further understand the strengths and weaknesses of each baseline, we analyze compatibility across different sentence-length ranges. Interestingly, we find that our approach outperforms C-PCFG on long sentences at the word-piece level. This shows that a bidirectional language modeling objective can learn to induce accurate structures even on very long sentences, on which custom-tailored methods may not work as well.
Related Work

Pre-trained models. Pre-trained models have achieved significant success across numerous tasks. ELMo (Peters et al., 2018), pretrained with a bidirectional language modeling objective over bi-LSTMs, was the first model to show significant improvements across many downstream tasks. GPT (Radford et al., 2018) replaces the bi-LSTMs with a Transformer (Vaswani et al., 2017); as the global attention mechanism may reveal contextual information, it uses a left-to-right Transformer to predict the next word given the previous context. BERT (Devlin et al., 2019) proposes masked language modeling (MLM) to enable bidirectional modeling while avoiding contextual information leakage by directly masking part of the input tokens. As masking input tokens results in missing semantics, XLNet (Yang et al., 2019) proposes permuted language modeling (PLM), in which all bidirectional tokens are visible when predicting masked tokens. However, none of the aforementioned Transformer-based models naturally captures positional information on its own, and none has explicitly interpretable structural information, an essential feature of natural language. To alleviate these shortcomings, we extend pre-training and the Transformer model to structural language models.
Representation with structures. In the line of work on learning sentence representations with structures, Socher et al. (2011) proposed the first neural network model applying recursive autoencoders to learn sentence representations; however, their approach constructs trees greedily, and it remains unclear how autoencoders compare against large pre-trained models (e.g., BERT). Later work jointly trains a shift-reduce parser and sentence embedding components; as the parser is not differentiable, training has to resort to reinforcement learning, and the learned structures collapse to trivial left- or right-branching trees. URNNG (Kim et al., 2019b) applies variational inference over latent trees to perform unsupervised optimization of the RNNG (Dyer et al., 2016), an RNN model that estimates a joint distribution over sentences and trees based on shift-reduce operations. Maillard et al. (2017) propose an alternative approach based on CKY parsing; the algorithm is made differentiable by a soft-gating approach, which approximates discrete candidate selection by a probabilistic mixture of the constituents available in a given cell of the chart, making it possible to train with backpropagation. However, their model runs in O(n^3) time and relies on Tree-LSTMs.

Conclusion and Outlook
In this paper, we have proposed an efficient CKY-based recursive Transformer to directly model hierarchical structure in linguistic utterances. We have ascertained the effectiveness of our approach on language modeling and unsupervised parsing.
With the help of our efficient linear pruned tree induction algorithm, our model quickly learns interpretable tree structures without any syntactic supervision, which nonetheless prove highly compatible with human-annotated trees. As future work, we are investigating pre-training our model on billion-word corpora as done for BERT, and fine-tuning it on downstream tasks.