DoT: An efficient Double Transformer for NLP tasks with tables

Transformer-based approaches have been successfully used to obtain state-of-the-art accuracy on natural language processing (NLP) tasks with semi-structured tables. These model architectures are typically deep, resulting in slow training and inference, especially for long inputs. To improve efficiency while maintaining high accuracy, we propose a new architecture, DoT, a double transformer model that decomposes the problem into two sub-tasks: a shallow pruning transformer that selects the top-K tokens, followed by a deep task-specific transformer that takes those K tokens as input. Additionally, we modify the task-specific attention to incorporate the pruning scores. The two transformers are jointly trained by optimizing the task-specific loss. We run experiments on three benchmarks, including entailment and question answering. We show that, for a small drop in accuracy, DoT improves training and inference time by at least 50%. We also show that the pruning transformer effectively selects relevant tokens, enabling the end-to-end model to maintain an accuracy similar to that of slower baseline models. Finally, we analyse the pruning and give some insight into its impact on the task model.


Introduction
Recently, transfer learning with large-scale pre-trained language models has been successfully used to solve many NLP tasks (Devlin et al., 2019; Radford et al., 2019; Liu et al., 2019). In particular, transformer models have been used to solve tasks that include semi-structured table knowledge, such as table question answering (Herzig et al., 2020) and entailment (Wenhu et al., 2019; Eisenschlos et al., 2020), a binary classification task to support or refute a sentence based on the table's content.
While transformer models lead to significant improvements in accuracy, they suffer from high computation and memory cost, especially for large inputs. The total computational complexity per layer for self-attention is O(n^2 d) (Vaswani et al., 2017), where n is the input sequence length and d is the embedding dimension. Using longer sequence lengths translates into increased training and inference time.
Improving the computational efficiency of transformer models has recently become an active research topic. To the best of our knowledge, the only technique that has been applied to NLP tasks with semi-structured tables is heuristic pruning. Eisenschlos et al. (2020) show on the TABFACT dataset (Wenhu et al., 2019) that using heuristic pruning accelerates training while achieving similar accuracy. This raises the question of whether a better pruning strategy can be learned.
We propose DoT, a double transformer model (Figure 1): a first transformer, which we call the pruning transformer, selects K tokens given a query and a table, and a task-specific transformer solves the task based on these tokens. Decomposing the problem into two simpler tasks imposes additional structure that makes training more efficient: the first model is shallow, allowing the use of long input sequences at moderate cost, and the second model is deeper and solves the task using the shortened input. The combined model achieves a better efficiency-accuracy trade-off.
The pruning transformer is based on the TAPAS QA model (Herzig et al., 2020). TAPAS answers questions by selecting tokens from a given table, a problem quite similar to the pruning task. The second transformer is a task-specific model adapted to the task at hand: we use another TAPAS QA model for QA and a classification model (Eisenschlos et al., 2020) for entailment. In Section 2, we explain how we jointly learn both models by incorporating the pruning scores into the attention mechanism.
DoT achieves a better trade-off between efficiency and accuracy on three datasets. We show that the pruning transformer selects relevant tokens, resulting in higher accuracy for longer input sequences. We study the meaning of relevant tokens and show that the selection is deeply linked to solving the main task by studying the answer token scores. We open-source the code at http://github.com/google-research/tapas.

The DoT Model
As shown in Figure 1, the double transformer DoT is composed of two transformers: the pruning transformer selects the most relevant K tokens, followed by a task-specific model that operates on the selected tokens to solve the task. The two transformers are learned jointly. The DoT loss is detailed in Appendix A.2. We explore learning the pruning model using an additional loss in Appendix C.2. Let q be the query (or statement) and T the table. The transformer takes as input the embeddings E = [E_[CLS]; E_q; E_[SEP]; E_T], composed of the query and table embeddings. The pruning transformer computes the probability P(t|q, T) of the token t being relevant to solve the example. We derive the pruning score s_t = log(P(t|q, T)) and keep the top-K tokens. The pruning scores are then passed to the task transformer as shown in Figure 2.
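To make this step concrete, here is a minimal NumPy sketch of the score computation and top-K selection. The per-token probabilities are assumed to come from the pruning transformer's output head, and padding and special tokens are ignored for simplicity; this is an illustration, not the released TAPAS code.

import numpy as np

def prune_tokens(token_probs, k):
    # s_t = log P(t|q,T); clip to avoid log(0)
    scores = np.log(np.clip(token_probs, 1e-9, 1.0))
    top_k = np.argsort(-scores)[:k]  # indices of the k highest-scoring tokens
    top_k = np.sort(top_k)           # restore the original token order
    return top_k, scores[top_k]

# Toy usage: 8 tokens, keep K = 4.
probs = np.array([0.9, 0.05, 0.7, 0.1, 0.8, 0.01, 0.6, 0.3])
kept, kept_scores = prune_tokens(probs, k=4)
print(kept, kept_scores)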
To enable the joint learning, we change the attention scores of the task model. For a normal transformer (Vaswani et al., 2017), given the input embedding E_t at position t, for each layer and attention head, the self-attention output is given by a linear combination of the value vector projections using the attention matrix.
Each row of the attention matrix is obtained by a softmax over the attention scores

z_{<t,t'>} = (W_Q E_t) · (W_K E_{t'}) / sqrt(d_k),

where d_k is the key dimension and W_Q and W_K represent the query and key projections for that layer and head. In our task model we add a negative bias term and replace this equation with

z_{<t,t'>|s_{t'}} = z_{<t,t'>} + s_{t'},

where s_{t'} is the pruning score of the attended token t'. Thus, the attention scores provide a notion of token relevance, detailed in Appendix A.1, and enable end-to-end learning of both models, letting DoT define the top-K tokens. Unlike previous soft-masking methods (Bastings et al., 2019; De Cao et al., 2020), ours coincides exactly with removing the input token t when P(t|q, T) → 0. We prove this formally in Appendix A.3.
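The modified attention can be sketched as follows for a single head; shapes, scaling and multi-head handling are simplified, and the code only illustrates the biased softmax of the equation above (the pruning score of the attended token is added to its column of logits).

import numpy as np

def softmax(x, axis=-1):
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

def biased_self_attention(E, W_Q, W_K, W_V, s):
    # E: [n, d] token embeddings; W_Q, W_K, W_V: [d, d_h] projections;
    # s: [n] pruning scores s_t = log P(t|q,T) (close to 0 for kept tokens,
    # very negative for pruned ones).
    Q, K, V = E @ W_Q, E @ W_K, E @ W_V
    z = Q @ K.T / np.sqrt(K.shape[-1])  # z_{<t,t'>}
    z = z + s[None, :]                  # z_{<t,t'>|s_{t'}} = z_{<t,t'>} + s_{t'}
    return softmax(z, axis=-1) @ V

# Toy usage with random projections and two near-pruned tokens.
rng = np.random.default_rng(0)
n, d, dh = 6, 8, 8
E = rng.normal(size=(n, d))
W_Q, W_K, W_V = [rng.normal(size=(d, dh)) for _ in range(3)]
s = np.log(np.array([0.9, 0.8, 1e-9, 0.7, 1e-9, 0.95]))
out = biased_self_attention(E, W_Q, W_K, W_V, s)
print(out.shape)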
We explore two different pruning strategies: token selection, as discussed above, and column selection, where we average the token scores within each column.
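A minimal sketch of the column-selection variant, assuming a TAPAS-style vector of column ids aligned with the token sequence (a hypothetical encoding: 0 for query and special tokens, 1..C for table columns); the real featurization differs.

import numpy as np

def column_scores(token_scores, column_ids):
    # Average the per-token pruning scores within each table column.
    num_columns = int(column_ids.max())
    return np.array([
        token_scores[column_ids == c].mean() for c in range(1, num_columns + 1)
    ])

# Toy usage: 3 columns; the highest-scoring columns are kept instead of tokens.
scores = np.array([0.0, -0.1, -2.0, -0.2, -3.0, -0.3, -2.5])
cols   = np.array([0,    1,    1,    2,    2,    3,    3])
print(column_scores(scores, cols))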

Experimental Setup
We compare our approach against models using heuristic pruning. Cell concatenation (CC): the default TAPAS heuristic to limit the input tokens. The objective of the algorithm is to fit an equal number of tokens from each cell: it first selects the first token of each cell, then the second, and so on, until the desired limit is reached. Heuristic exact match (HEM) (Eisenschlos et al., 2020): this method scores the columns based on their similarity to the question, where similarity is defined by token overlap.
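For illustration, minimal sketches of the two heuristics operating on whitespace tokens; the actual implementations work on word pieces and reserve budget for the question and special tokens.

def cell_concatenation(cells, limit):
    # CC sketch: take the first token of every cell, then the second, and so on,
    # until the token budget is reached, giving each cell a roughly equal share.
    kept, depth = [], 0
    while len(kept) < limit and any(depth < len(c) for c in cells):
        for cell in cells:
            if depth < len(cell) and len(kept) < limit:
                kept.append(cell[depth])
        depth += 1
    return kept

def hem_column_scores(question, columns):
    # HEM sketch: score each column by its token overlap with the question;
    # low-scoring columns can then be dropped before tokenization.
    q = set(question)
    return [len(q & set(col)) / max(len(col), 1) for col in columns]

# Toy usage with hypothetical table content.
cells = [["new", "york"], ["2019"], ["san", "francisco", "bay"], ["2020"]]
print(cell_concatenation(cells, limit=5))
print(hem_column_scores(["which", "city", "hosted", "in", "2019"],
                        [["new", "york", "san", "francisco", "bay"], ["2019", "2020"]]))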
We introduce a notation to clarify the setup: we write, e.g., HEM 1024 → TAPAS(l) for a large TAPAS model whose 1024 input tokens are produced by the HEM pre-processing (and similarly with CC), and DoT(m) or DoT(s) for DoT with a medium or small pruning transformer. In all our experiments we report results for DoT using token selection for WIKISQL and TABFACT and column selection for WIKITQ.

Results
The baseline TAPAS model outperforms the previous state-of-the-art on all datasets (Table 1).

Efficiency-accuracy trade-off: Table 1 reports the test accuracy along with the average number of processed examples per second (NPE/s) computed at training time. Using HEM as the pre-processing step improves DoT models compared to CC for both WIKISQL and TABFACT. DoT(m) and DoT(s) reach a better efficiency-accuracy trade-off for WIKISQL: with a small accuracy drop of 0.4% (respectively 0.7%), they are 3.5 (respectively 4.6) times faster than the best baseline. For the TABFACT dataset, DoT is compared to a faster baseline than the one used for WIKISQL, as it takes only 512 input tokens instead of 1024. DoT(s) still achieves a good trade-off: with a decrease of 0.4% in accuracy it is 1.5 times faster. Unlike the previous datasets, WIKITQ is a harder task to solve and requires passing more data. By restricting DoT(m) to select only 256 tokens, we trade a larger accuracy drop of 3.9% for being 3.5 times faster than HEM 1024 → TAPAS(l).
Small task models: The previous results raise the question of whether a smaller task model alone could reach a similar accuracy. To answer this question, we compare DoT with smaller task-only baselines in Table 2. DoT outperforms the smaller models, showing the importance of using both transformers.

Analysis
Accuracy for long input sequences: To study long inputs, we bucketize the datasets by example input length. We compare DoT(m) with the 256 and 512 input-length baselines in Table 3. For the bucket > 1024, the DoT model outperforms the 256 and 512 length baselines on all tasks. This indicates that the pruning model selects more relevant tokens than the heuristic CC, even when CC keeps twice as many.
For the bucket [512, 1024], we expect all models to reach a higher accuracy, since applying CC loses less context than for the bucket > 1024. The results show that DoT reaches an accuracy similar to 512 → TAPAS for WIKISQL and TABFACT (within the margin of error) and slightly lower for WIKITQ: the pruning transformer selects only the top 256 tokens, whereas 512 → TAPAS keeps twice as many. The task-specific transformer therefore has access to fewer tokens, and thus possibly less context, which can lead to an accuracy drop. This drop is small compared to the efficiency gain.

Pruning relevant tokens: We inspect the pruning transformer on the WIKISQL and WIKITQ datasets, where the set of answer tokens is given.
We compute the difference between the answer token scores and the average score of the top-K tokens, and report the distribution in Figure 3. The pruning transformer tends to attribute high scores to the answer tokens (a positive difference), especially for WIKISQL, suggesting that it learns to answer the downstream question. The difference is smaller for WIKITQ as it is a harder task: the set of answer tokens is larger, especially for aggregation, making their scores closer to the average.
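The statistic behind Figure 3 can be sketched as follows, with illustrative scores and indices; the real evaluation iterates over the datasets.

import numpy as np

def answer_score_gap(scores, top_k_idx, answer_idx):
    # Mean pruning score of the answer tokens minus the mean score of the
    # selected top-K tokens; a positive gap means the pruning transformer
    # rates answer tokens above the tokens it keeps on average.
    return float(scores[answer_idx].mean() - scores[top_k_idx].mean())

# Toy usage with hypothetical scores and indices.
scores = np.log(np.array([0.9, 0.2, 0.8, 0.1, 0.95, 0.05]))
print(answer_score_gap(scores, top_k_idx=np.array([0, 2, 4]), answer_idx=np.array([4])))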

Pruning transformer depth
We study the impact of the pruning transformer's complexity on the efficiency-accuracy trade-off. Figure 4 compares the results of medium, small and mini models (complexity details in Appendix B.3). For all datasets, the mini model drastically drops the accuracy: the pruning transformer must be deep enough to learn the top-K tokens and to attribute token scores that the task-specific transformer can use. For both WIKISQL and TABFACT, the small model reaches a better accuracy-efficiency trade-off: using a small instead of a medium model (4 hidden layers instead of 8) drops the accuracy by less than 0.4%, within the margin of error, while making the model 1.3 times faster. In other words, there is no gain in using a more complex model to select the top-K tokens, especially when we restrict K to 256.

Figure 3: Distribution of the answer token scores minus the average scores of the top-K tokens. The difference is larger when the pruning transformer attributes a higher score to the answer tokens.
Restricting K can lead to a drop in accuracy, and even increasing the pruning complexity does not let DoT recover the full drop. This is the case for WIKITQ: the dataset is more complex and requires more reasoning, including operations that run over multiple cells in one column, so selecting the top 256 tokens is a harder task than for the previous datasets. We reduce the task complexity by using column selection instead of token selection. For this dataset, with a medium pruning transformer, DoT(m) reaches a better accuracy-efficiency trade-off: 2 points higher in accuracy compared to using a small transformer.
Effects of HEM and CC on DoT: Table 1 and Figure 4 compare the effect of using HEM and CC with DoT models. As both heuristics are applied in the pre-processing step, using HEM or CC with the same DoT model does not change the average number of processed examples per second (NPE/s) computed over training. For both WIKISQL and TABFACT we use token-based selection for the top-K tokens; combining it with HEM outperforms combining it with CC in terms of accuracy. For WIKITQ, the top-K pruning is column-based; unlike with token selection, column pruning combined with HEM gives a lower accuracy.

Related work
Efficient Transformers: Improving the computational efficiency of transformer models, especially for serving, is an active research topic. Proposed approaches fall into four categories. The first is to use knowledge distillation, either during the pre-training phase (Sanh et al., 2019), for building task-specific models (Sun et al., 2019), or for both (Jiao et al., 2020). The second category is to use quantization-aware training during the fine-tuning phase of BERT models (Zafrir et al., 2019). The third category is to modify the transformer architecture to improve the dependence on the sequence length (Choromanski et al., 2020; Wang et al., 2020). The fourth category is to use pruning strategies such as McCarley (2019), who studied structured pruning to reduce the number of parameters in each transformer layer, and Fan et al.
(2020), who used structured dropout to reduce transformer depth at inference time. Our method most closely resembles the last category, but we focus our efforts on shrinking the sequence length of the input instead of the model weights. Eisenschlos et al. (2020) explore heuristic methods based on lexical overlap and apply them to tasks involving tabular data, as we do, but our algorithm is learned end-to-end and more general in nature. Soft-masking methods (Bastings et al., 2019; De Cao et al., 2020) are based on reparametrization (Diederik and Max, 2014) to approximate the discrete choice of a rationale from an input text before using it as input for a classifier; partially masked tokens are then replaced at the input embedding layer by some linear interpolation. We instead rely on a soft attention mask as a way to partially reduce the information coming from some tokens during training. To the best of our knowledge, these methods have not been investigated in the context of semi-structured data such as tables, nor evaluated with a focus on efficiency.

Conclusion
We introduced the double transformer (DoT), in which an additional small model prunes the input of a larger second model. This accelerates training and inference at the cost of a small drop in accuracy. As future work we will explore hierarchical pruning and adapt DoT to other semi-structured NLP tasks.

A.1 Notion of token relevance

We study how the pruning scores are updated according to the needs of the attention scores. We denote the set of relevant tokens by R. The output probability given by the pruning transformer lies in (0, 1), so s_t lies in (−∞, 0). Suppose that the token t is not needed to answer the question: then the attention scores z_{<i,t',t>|s_t} are decreased towards −∞ for all tokens t' ∈ R and all layers i. The model updates both parts of z_{<i,t',t>}, making s_t converge to −∞, so that lim_{s_t → −∞} z_{<i,t',t>|s_t} = −∞. Thus, the meaning of a relevant token is defined by the attention score updates: the pruning scores decrease for non-relevant tokens and increase for relevant ones.

A.2 DoT loss
The DoT loss is similar to the TAPAS model loss, noted J_SA = J_aggr + β J_scalar in (Herzig et al., 2020), computed over the task-specific transformer whose attention scores are modified. More precisely, we modify only the scalar loss of the task-specific model, J_scalar: we incorporate the pruning scores S = {s_t : t ∈ T_top-K}, with K = 256, and denote the result J_scalar|S. The DoT loss is then computed only over the top-K tokens: J_DoT = J_aggr + β J_scalar|S. For the TABFACT dataset, Eisenschlos et al. (2020) modified the TAPAS loss, used for QA tasks, to adapt it to the entailment task: aggregation is not used; instead, one hidden layer is added on the output of the [CLS] token to compute the probability of entailment. We use a similar loss for TABFACT, with the attention scores modified in the same way.
A.3 Feed-forward pass: Safe use of shorter inputs for the task-specific transformer

The top-K selection enables the use of shorter inputs for the task-specific transformer. We prove that using an input length equal to K is equivalent to using an input length greater than K, without any loss of context. Note that the pruning scores are the same for both inputs: the top-K tokens receive a non-zero score and we impose a score of zero on all other tokens.
Theorem A.1. Given a transformer and a set of input tokens I, let t ∈ I be one of the input tokens, and suppose the transformer satisfies the following two conditions for all layers i: (1) z_{<i,t,t'>} = −∞ for every token t' ∈ I that t attends to; (2) z_{<i,t',t>} = −∞ for every token t' ∈ I that attends to t.
Then applying this transformer on I is equivalent to applying it on I − {t}.

Proof. We look at the different cases. For all layers i, any token t' ∈ I − {t} attending to any token t'' ∈ I − {t}: the softmax scores a_{<i,t',t''>} have the same formula using I or I − {t} as input, as made explicit below.
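The equality holds because, by the second condition, the pruned token t contributes exp(−∞) = 0 to the softmax normalization. Spelled out in the notation above (a derivation we add for clarity):

a_{<i,t',t''>} = exp(z_{<i,t',t''>}) / Σ_{u ∈ I} exp(z_{<i,t',u>})
             = exp(z_{<i,t',t''>}) / ( exp(z_{<i,t',t>}) + Σ_{u ∈ I−{t}} exp(z_{<i,t',u>}) )
             = exp(z_{<i,t',t''>}) / Σ_{u ∈ I−{t}} exp(z_{<i,t',u>}),

since exp(z_{<i,t',t>}) = exp(−∞) = 0.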
Now fix the attending token to be t, i.e., t attending to any token t' ∈ I: the first condition gives z_{<i,t,t'>} = −∞ for every t'. It follows that exp(z_{<i,t,t'>}) = 0 and a_{<i,t,t'>} = 0.
Similarly, for any token t' ∈ I attending to t: the second condition gives z_{<i,t',t>} = −∞ for every t'. It follows that exp(z_{<i,t',t>}) = 0 and a_{<i,t',t>} = 0.
Remark. Given a transformer and a set of input tokens I, let t ∈ I be a token that is not selected by the pruning transformer, i.e., scored zero and not among the top-K tokens. Using DoT, s_t = −∞, and it follows that z_{<i,t',t>|s_t} = −∞ for all t' and all layers i, so the second condition holds.
For the case of t attending to tokens t' ∈ I: for all layers i ≠ 0, the input at position t is E_t = Σ_{t'∈I} a_{<i,t',t>} V_{t'}, where V_{t'} is the value projection of t'. As z_{<i,t',t>} = −∞, we get E_t = 0; a zero E_t makes exp(z_{<i,t,t'>}) a constant, so a_{<i,t,t'>} is independent of t'. This is equivalent to saying that, for all layers i ≠ 0, t does not attend to any particular token t' ∈ I.
Only for the first layer i = 0 do we add an approximation and drop the attention of t attending to tokens t' ∈ I; we consider the impact of t on the full attention to be small, since we stack multiple layers. We experimented with a task-specific model whose input length is larger than K and compared it to a task-specific model with input length equal to K: the two models give similar accuracy. In our experiments we report only the results for the model with input length equal to K.
This makes the attention scores similar to the ones computed when t is not part of the input.

B Experiments
In all the experiments we report the median accuracy and the error margin computed over 3 runs. We estimate the error margin as half the interquartile range, that is, half the difference between the 25th and 75th percentiles.
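For reference, a small sketch of this aggregation with hypothetical accuracy values.

import numpy as np

def median_and_error_margin(accuracies):
    # Median accuracy and error margin over repeated runs; the margin is half
    # the interquartile range (half the gap between the 25th and 75th percentiles).
    a = np.asarray(accuracies, dtype=float)
    q25, q75 = np.percentile(a, [25, 75])
    return float(np.median(a)), float((q75 - q25) / 2)

# Toy usage over three hypothetical runs.
print(median_and_error_margin([72.1, 72.9, 73.4]))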

B.1 Models hyper-parameters
We do not perform a hyper-parameter search for DoT models; we use the same hyper-parameters as for the TAPAS baselines.

Table 4: Hyper-parameters used per dataset: the learning rate (lr), the warmup ratio (ρ), the hidden dropout, the attention dropout and the number of training steps (num steps) used for each dataset. These hyper-parameters are the same for all the baselines and DoT models.
For WIKISQL and WIKITQ we use the same hyper-parameters as Herzig et al. (2020), and for TABFACT the same as Eisenschlos et al. (2020). Baselines and DoT are initialized from models pre-trained with a MASK-LM task and on SQA (Iyyer et al., 2017), following Herzig et al. (2020). We report the model hyper-parameters used for the TAPAS baselines and DoT in Table 4. The hyper-parameters are fixed independently of the pre-processing step and the input size: for all pre-processing input lengths ({256, 512, 1024}) and for both CC and HEM we use the same hyper-parameters. Additionally, we use an Adam optimizer with weight decay for all the baselines and DoT models, with the same configuration as BERT.

B.2 State-of-the-art
We report the state-of-the-art results for the three datasets in Table 5.

B.3 Models complexity
In all our experiments we use different transformer sizes called large, medium, small and mini. These correspond to the open-sourced BERT model sizes described in Turc et al. (2019). We report the complexity of all models in Table 6. The sequence length changes the total number of used parameters; the formula for the parameter count is given in Table 7. The number of used parameters equals V × H + (2 + 3L) × I × H + I + (256 × 4 + 17 + 9L) × H + (1 + 2L × H) × Hi. The number of parameters of each model is reported in Table 8. The number of parameters is not proportional to the computational time, as multiple operations involve multiplying tensors of shape [I, H] × [H, H].

Table 7: Parameter counts. Let H be the hidden embedding size, L the number of layers, Hi the intermediate size, V = 30522 the vocabulary size and I the input size. We report the used parameters based on tensor shapes for TAPAS models.
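To make the formula concrete, a minimal sketch that evaluates it as written above; the (L, H) values are the standard BERT sizes of Turc et al. (2019) and Hi = 4H is assumed, so the exact numbers may differ slightly from Table 8.

def used_parameters(H, L, Hi, I, V=30522):
    # Used-parameter count following the formula reproduced above.
    return V * H + (2 + 3 * L) * I * H + I + (256 * 4 + 17 + 9 * L) * H + (1 + 2 * L * H) * Hi

# Hypothetical instantiation: mini (L=4, H=256), small (L=4, H=512),
# medium (L=8, H=512), large (L=24, H=1024), with Hi = 4H and I = 256 input tokens.
for name, (L, H) in {"mini": (4, 256), "small": (4, 512),
                     "medium": (8, 512), "large": (24, 1024)}.items():
    print(name, used_parameters(H=H, L=L, Hi=4 * H, I=256))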