How Many Layers and Why? An Analysis of the Model Depth in Transformers

In this study, we investigate the role of the multiple layers in deep transformer models. We design a variant of ALBERT that dynamically adapts the number of layers applied to each token of the input. The key specificity of ALBERT is that weights are tied across layers; the stack of encoder layers therefore iteratively repeats the application of the same transformation function to the input. We interpret this repetition as an iterative process in which the contextualized token representations are progressively refined. We analyze this process at the token level during pre-training, fine-tuning, and inference. We show that tokens do not all require the same number of iterations and that tokens that are difficult or crucial for the task are subject to more iterations.


Introduction
Transformers are admittedly over-parametrized (Hou et al., 2020; Voita et al., 2019). Yet the role of this over-parametrization is not well understood. In particular, transformers consist of a fixed number of stacked layers, which are suspected to be highly redundant and to cause over-fitting (Fan et al., 2020). In this paper, we provide a study of the role of the multiple layers traditionally used. The mechanism of transformer layers is often compared to an intuitive NLP pipeline (Tenney et al., 2019): lower layers encode surface information, middle layers encode syntax, and higher layers encode semantics (Jawahar et al., 2019; Peters et al., 2018). Transformers progressively refine the features, which become more fine-grained at each iteration. However, ALBERT (Lan et al., 2020) highlights that it is possible to tie weights across layers and repeat the application of the same function. Consequently, we hypothesize that it is the number of layer applications that gradually abstracts surface information into semantic knowledge.
To better study the transformation of token representations across layers, we propose a variant of ALBERT. Our model implements ALBERT's key specificity of weight tying across layers, but also dynamically adapts the number of layers applied to each token. Since all layers share the same weights, we refer to the application of the layer to the hidden states as an iteration.
After reviewing the related work (Section 2), we detail the model and the training methodology in Section 3. In particular, we encourage our model to be parsimonious and to limit the total number of iterations performed on each token. In Section 4, we analyze the iterations of the model during pre-training, fine-tuning, and inference.

Related Work
Adapting the transformer depth is an active subject of research. In particular, deep transformer models are suspected to struggle to adapt to different levels of difficulty: while large models correctly predict difficult examples, they over-compute on simpler inputs. This issue can be addressed using early exiting, since some samples might be simple enough to classify using intermediate features. Some models couple a classifier to each layer; after each layer, given the classifier output, the model either immediately returns the output or passes the sample to the next layer. Exiting too late may even have negative impacts due to the network "over-thinking" the input (Kaya et al., 2019).
Ongoing research also refines the application of layers at the token level. Wang and Kuo (2020) build sentence embeddings by combining token representations from distinct layers. Elbayad et al. (2020) and Dehghani et al. (2019) successfully use dynamic layer depth at the token level for full (encoder-decoder) transformers. However, to the best of our knowledge, ours is the first attempt to apply such a mechanism to encoder-only transformers and to provide an analysis of the process.

Method
In this section, we detail the model architecture, illustrated in Figure 1, and the pre-training procedure.

Model architecture
We use a multi-layer transformer encoder (Devlin et al., 2019) which transforms a context vector of tokens $(u_1, \ldots, u_T)$ through a stack of $L$ transformer encoder layers (Eq. 1, 2). We use weight tying across layers and apply the same transformation function at each iteration (Lan et al., 2020):

$$h_t^0 = W_e u_t + W_p[t] \qquad (1)$$

$$(\tilde{h}_1^n, \ldots, \tilde{h}_T^n) = \mathrm{LAYER}(h_1^{n-1}, \ldots, h_T^{n-1}) \qquad (2)$$

For the first layer, $W_e$ is the token embedding matrix and $W_p$ the position embedding matrix. We augment the model with a halting mechanism, which dynamically adjusts the number of layers applied to each token (Eq. 3 to 7). We directly adapt this mechanism from Graves (2016); the main distinction with the original version is the use of a transformer model instead of a recurrent state transition model. The mechanism works as follows. At each iteration $n$, we add the following operations after Eq. 2. We assign a halting probability $p_t^n$ to each token at index $t$ (Eq. 3), with $\sigma$ the sigmoid function. Given this probability, we compute an update weight $\lambda_t^n$ (Eq. 4), which we use to compute the final state as a convex combination of the previous and current hidden states (Eq. 5):

$$p_t^n = \sigma(W_h \tilde{h}_t^n + b_h) \qquad (3)$$

$$\lambda_t^n = \begin{cases} p_t^n & \text{if } n < N_t \\ R_t & \text{if } n = N_t \\ 0 & \text{if } n > N_t \end{cases} \qquad (4)$$

$$h_t^n = \lambda_t^n \, \tilde{h}_t^n + (1 - \lambda_t^n) \, h_t^{n-1} \qquad (5)$$

We define the number of iterations $N_t$ and the remainder $R_t$ for the token at index $t$ as:

$$N_t = \min \big\{ n : \textstyle\sum_{n'=1}^{n} p_t^{n'} \geq 1 - \epsilon \big\} \qquad (6)$$

$$R_t = 1 - \textstyle\sum_{n=1}^{N_t - 1} p_t^n \qquad (7)$$

As soon as the cumulative halting probability reaches $1 - \epsilon$, the update weights $\lambda_t^n$ are set to 0 and the token is not updated anymore (Eq. 4). The small factor $\epsilon$ ensures that the network can stop after the first iteration (Eq. 6).

Figure 1: As in ALBERT, tokens are transformed through the iterative application of a transformer encoder layer. Our model's key specificity is the halting mechanism, which dynamically adjusts the number of iterations for each token.
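To make the halting mechanism concrete, the following PyTorch sketch implements a per-token halting loop of this kind over a single shared encoder layer. It is a minimal illustration under the notation above, not the exact implementation: the class name, the defaults (max_iters, epsilon), and the forced halt at the depth limit are our own assumptions.

```python
import torch
import torch.nn as nn

class AdaptiveDepthEncoder(nn.Module):
    """Per-token halting over one shared encoder layer (cf. Eq. 1-7).
    Minimal sketch adapted from Graves (2016); names and defaults are
    illustrative assumptions, not the paper's released code."""

    def __init__(self, hidden_size=768, num_heads=12, max_iters=12, epsilon=0.01):
        super().__init__()
        # One layer, reused at every iteration (ALBERT-style weight tying).
        self.layer = nn.TransformerEncoderLayer(
            d_model=hidden_size, nhead=num_heads, batch_first=True)
        self.halting = nn.Linear(hidden_size, 1)  # W_h, b_h of Eq. 3
        self.max_iters = max_iters
        self.epsilon = epsilon

    def forward(self, h):
        # h: (batch, seq_len, hidden) embedded input, i.e. h^0 of Eq. 1
        batch, seq_len, _ = h.shape
        halt_sum = h.new_zeros(batch, seq_len)    # cumulative sum of p_t^n
        n_iters = h.new_zeros(batch, seq_len)     # N_t
        remainders = h.new_zeros(batch, seq_len)  # R_t
        running = h.new_ones(batch, seq_len)      # 1 while a token is active

        for n in range(self.max_iters):
            h_tilde = self.layer(h)                                # Eq. 2
            p = torch.sigmoid(self.halting(h_tilde)).squeeze(-1)   # Eq. 3
            crossed = (halt_sum + p >= 1 - self.epsilon).float()   # Eq. 6
            if n == self.max_iters - 1:
                crossed = torch.ones_like(crossed)  # force halt at the limit
            halting_now = running * crossed
            running = running * (1 - crossed)
            remainder = (1 - halt_sum) * halting_now               # Eq. 7
            # lambda_t^n: p_t^n while running, R_t when halting, 0 after (Eq. 4)
            lam = p * running + remainder
            halt_sum = halt_sum + p * running
            remainders = remainders + remainder
            n_iters = n_iters + running + halting_now
            # Convex combination of current and previous states (Eq. 5);
            # halted tokens have lam = 0 and are no longer updated.
            h = lam.unsqueeze(-1) * h_tilde + (1 - lam.unsqueeze(-1)) * h
            if running.sum() == 0:
                break
        return h, n_iters, remainders
```

Note that the shared layer still attends over halted tokens at every iteration; only their states stop being updated, which preserves the cyclic dependency discussed in Section 4.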

Pre-training objective
During the pre-training phase, we train the model with the sentence order prediction (sop) task, introduced in Lan et al. (2020), which classifies whether two segments of the input sequence follow the original order or were swapped, and with the masked language model (mlm) task (Devlin et al., 2019). We also encourage the network to minimize the number of iterations by adding the ponder cost directly to the ALBERT pre-training objective. Given an input sequence $u$ of length $T$, Graves (2016) defines the ponder cost $\mathcal{P}(u)$ as:

$$\mathcal{P}(u) = \sum_{t=1}^{T} (N_t + R_t) \qquad (8)$$

We define the final pre-training loss as the following sum:

$$\mathcal{L} = \mathcal{L}_{mlm} + \mathcal{L}_{sop} + \tau \, \mathcal{P}(u)$$

where $\tau$ is a time penalty parameter that weights the relative cost of computation versus error.
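Continuing the sketch above, the combined objective could be assembled as follows; the function name is ours, and the task losses are assumed to be computed elsewhere.

```python
def pretraining_loss(mlm_loss, sop_loss, n_iters, remainders, tau):
    """Combine the task losses with the ponder penalty of Eq. 8.
    n_iters and remainders follow the AdaptiveDepthEncoder sketch above."""
    # Ponder cost: iterations plus remainder, summed over token positions,
    # then averaged over the batch.
    ponder_cost = (n_iters + remainders).sum(dim=-1).mean()
    return mlm_loss + sop_loss + tau * ponder_cost
```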

Data and infrastructure
We follow the protocol from ALBERT and pre-train the model with BOOKCORPUS (Zhu et al., 2015) and English Wikipedia. We reduce the maximum input length to 128 and the number of training steps to 112,500 [1]. We use a lowercase vocabulary of size 30,000, tokenized using SentencePiece. We train all our models on a single TPU v2-8 from Google Colab Pro and accumulate gradients to preserve a batch size of 4,096. We optimize the parameters using LAMB with a learning rate of 1.76e-3.
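As a rough illustration of the gradient accumulation used here, the generic loop below emulates a large effective batch with smaller micro-batches; the micro-batch size of 512 and the assumption that the model returns its loss directly are ours, and the optimizer (LAMB in our setup) is treated abstractly.

```python
def train_with_accumulation(model, optimizer, batches, accum_steps=8):
    """Sketch of gradient accumulation: with a micro-batch of 512, eight
    accumulation steps emulate the 4,096 effective batch (512 x 8 = 4,096).
    `model(batch)` is assumed to return the training loss directly."""
    optimizer.zero_grad()
    for step, batch in enumerate(batches):
        loss = model(batch) / accum_steps  # scale so gradients average out
        loss.backward()                    # gradients accumulate in .grad
        if (step + 1) % accum_steps == 0:
            optimizer.step()               # one update per effective batch
            optimizer.zero_grad()
```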

Experiments
We now analyze the properties of our iterative model during pre-training (Section 4.1) and fine-tuning (Section 4.2). We start by describing the setup for each of the subtasks.
mlm task We generate masked inputs following ALBERT n-gram masking. We mask 20% of all WordPiece tokens, but do not always replace masked words with the [MASK] token, to avoid a discrepancy between pre-training and fine-tuning: we replace 80% of the masked positions with the [MASK] token, 10% with a random token, and leave the remaining 10% unchanged.

sop task We format our inputs as "[CLS] x_1 [SEP] x_2 [SEP]". In 50% of the cases, the two segments x_1 and x_2 are effectively consecutive in the text; in the other 50%, the segments are swapped.
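For illustration, a single-token version of this corruption scheme could look like the sketch below; the paper's setup uses ALBERT-style n-gram masking, which is not shown, and the helper name and signature are our own.

```python
import random

def mask_tokens(tokens, vocab, mask_rate=0.20):
    """Illustrative 80/10/10 corruption of the kind described above.
    Single-token version only; n-gram masking is omitted for brevity."""
    corrupted, labels = list(tokens), [None] * len(tokens)
    for i, tok in enumerate(tokens):
        if random.random() < mask_rate:
            labels[i] = tok                          # mlm target: original token
            r = random.random()
            if r < 0.8:
                corrupted[i] = "[MASK]"              # 80%: replace with [MASK]
            elif r < 0.9:
                corrupted[i] = random.choice(vocab)  # 10%: random token
            # remaining 10%: keep the token unchanged
    return corrupted, labels
```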
Ponder cost We fix the time penalty factor τ empirically such that the ponder penalty represents around 10% of the total loss. To estimate the ponder cost, we discard the remainder, as $R_t \leq 1$ is negligible for sufficiently large values of $N_t$. Given Eq. 8, the ponder cost then corresponds to the total number of iterations in the sentence, which is given by $l \times T$, with $T$ the number of tokens in the sequence and $l$ the average number of iterations per token. We observe that the ALBERT base loss converges to around 3.5. We therefore calibrate τ such that $\tau \mathcal{P}(u) \approx 0.35 \approx \tau \times l \times T$. We train distinct models, listed in Table 1, calibrated such that their average number of iterations per token $l$ is respectively 3, 6, and 12. We refer to these models as tiny, small, and base.
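As a worked example of this calibration, with the numbers quoted above for the base model:

```python
T = 128        # maximum sequence length
l = 12         # target average iterations per token (base model)
target = 0.35  # ~10% of the ~3.5 converged ALBERT-base loss
tau = target / (l * T)
print(f"tau ~ {tau:.2e}")  # ~2.28e-04; smaller l (tiny, small) yields a larger tau
```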
[1] As emphasized in https://github.com/google-research/bert, longer sequences are disproportionately expensive computationally. To lighten the pre-training process, the authors advise using a sequence length of 128 and increasing it to 512 only for the last 10% of the training, to train the positional embeddings. In this work, we only perform the first 90% of the steps, as we are not looking for raw performance.

We observe that the [CLS] token receives far more iterations than other tokens. This observation is in line with Clark et al. (2019), who analyze BERT attention and report systematic and broad attention to special tokens. We interpret this as the [CLS] token being used as input for the sop task and thus aggregating a representation of the entire input. On the contrary, the [SEP] token usually receives few iterations. Again, this backs the observation, emerging from the analysis of attention, that interprets [SEP] as a no-op for attention heads (Clark et al., 2019).
We also observe an interesting behavior for the [MASK] token, which also receives more iterations than average tokens.
The model thus seems to have an intuitive mechanism that distributes iterations to tokens that are either crucial for the pre-training task or present a certain level of difficulty. This appears in line with the early-exit mechanisms cited in Section 2, which adapt the number of layers, for the whole example, to better match each sample's level of difficulty.
Natural fixed point We now analyze how the tokens' hidden states evolve during our model's iterative transformations. At each iteration $n$, the self-attention mechanism (Vaswani et al., 2017) computes the updated state $n+1$ as a weighted sum of the current states. This introduces a cyclic dependency, as every token depends on the others during the iterative process. As convergence within a loopy structure is not guaranteed, we encourage the model to converge towards a fixed point (Bai et al., 2019). We obtain this property "for free" thanks to our architecture's specificity. Indeed, at each iteration, the hidden state is computed as a convex combination of the previous ($n$) and current ($n+1$) hidden states, controlled by $\lambda_t^n$ (Eq. 5). If $\lambda_t^n$ is close to 0, then $h_t^n \approx h_t^{n+1}$, and by definition (Eq. 4, 6) $\lambda_t^n$ will eventually be set to 0 at a certain iteration. Figure 2 represents the evolution of the mean cosine similarity between hidden states from two consecutive iterations, $h_t^n$ and $h_t^{n+1}$. The network indeed reaches a fixed point for every token.
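The diagnostic behind Figure 2 can be sketched as follows, assuming hidden_states is a list of (batch, seq_len, hidden) tensors collected at each iteration; the function name is ours.

```python
import torch.nn.functional as F

def convergence_curve(hidden_states):
    """Mean cosine similarity between hidden states of consecutive
    iterations, as plotted in Figure 2. Values approaching 1.0 mean
    h^{n+1} ~ h^n, i.e. the network has reached its fixed point."""
    return [
        F.cosine_similarity(h_prev, h_next, dim=-1).mean().item()
        for h_prev, h_next in zip(hidden_states, hidden_states[1:])
    ]
```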

Application to downstream tasks
During the pre-training phase, the model focuses on tokens that are either crucial for the pre-training task or present a certain level of difficulty. We now study our model's behavior during fine-tuning on downstream syntactic and semantic tasks.
Control test To verify that our setup achieves reasonable performance, we evaluate it on the GLUE benchmark (Wang et al., 2019). Results in Table 2 are scored by the evaluation server. As in Devlin et al. (2019), we discard results for the WNLI task. For each task, we fine-tune the model on the train set and select the hyperparameters on the dev set using a grid search. We tune the learning rate between 5e-5, 3e-5, and 2e-5; the batch size between 16 and 32; and the number of epochs between 2, 3, and 4. For a fair comparison, we also pre-train BERT and ALBERT models using our configuration, infrastructure, and data.

We present results on the test set in Table 2. As expected, the average score decreases with the number of iterations, since we limit the number of computation operations performed by our model. Moreover, we build our model on top of ALBERT, which shares parameters across layers, thus reducing the number of parameters compared with the original BERT architecture. However, despite these additional constraints, results stay in a reasonable range. In particular, ALBERT-base with adaptive depth is very close to the version with fixed depth.

Table 3: Distribution of the iterations across token dependency types. We fine-tune our base model on each probing task. We then perform inference on the Penn Treebank dataset and report the number of iterations given the token dependency type. The number in parentheses denotes the number of dependency tags. We only display the 10 most frequent tags. We indicate in bold the tags for which the number of iterations is above avg + std. We include a baseline accuracy, obtained with the ALBERT-base version without the adaptive depth mechanism, which therefore performs 12 iterations for each token.

Probing tasks
Conneau and Kiela (2018) introduce probing tasks, which assess whether a model encodes elementary linguistic properties. We consider semantic and syntactic tasks that do not introduce random replacements: a task that predicts the sequence of top constituents immediately below the sentence node (TopConst), a task that predicts the tense of the main-clause verb (Tense), and two tasks that predict the subject (resp. direct object) number in the main clause (SubjNum, resp. ObjNum).

In our setup, we fine-tune the model on the task train set and select the hyperparameters on the dev set using a grid search. We use a 5e-5 learning rate and a batch size of 32, and search the number of epochs between 1 and 5. Finally, we compare in Table 3 the number of iterations performed for each token on the Penn Treebank (Marcus et al., 1993) converted to Stanford dependencies [7, 8].
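A hypothetical sweep matching this setup is sketched below; train_and_evaluate is an assumed helper that fine-tunes the model and returns dev-set accuracy, not part of our code.

```python
def select_epochs(train_and_evaluate):
    """Hypothetical grid search matching the probing-task setup above:
    fixed learning rate 5e-5, batch size 32, 1 to 5 epochs on the dev set."""
    scores = {
        epochs: train_and_evaluate(lr=5e-5, batch_size=32, num_epochs=epochs)
        for epochs in range(1, 6)
    }
    return max(scores, key=scores.get)  # epoch count with best dev accuracy
```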
We provide an accuracy baseline, obtained with the same setup but using ALBERT without the dynamic halting mechanism. As in the previous experiment, we observe that on these tasks our model achieves competitive performance despite using fewer computational operations.

[7] Since we use a SentencePiece vocabulary, we assign to each piece the dependency tag of the whole token.
[8] We present the tables for other model configurations in Appendix B.
Although all tasks achieve significant and comparable accuracies, they require distinct global means of iterations. The Tense task, which can be solved from the verb alone, is completed in only 5.4 iterations, while the TopConst task, which requires inferring some of the sentence structure, takes 7.2 iterations. This suggests the model can adapt itself to the complexity of the task and globally spare unnecessary iterations.
Looking at the token level, as during pre-training (Section 4.1), the iterations are unevenly distributed across tokens. The model seems to iterate more on tokens that are crucial for the task. For SubjNum, the subj tokens receive the maximum number of iterations, while for the ObjNum task, the obj and root tokens iterate more. Similarly, all tasks present a high number of iterations on the main verb (root), which is crucial for each prediction.

Conclusion
We investigated the role of layers in deep transformers. We designed an original model that progressively transforms each token through a dynamic number of iterations. We analyzed the distribution of these iterations during pre-training and confirmed results previously obtained by analyzing the distribution of attention across BERT layers, in particular the specific role played by special tokens. Moreover, we observed that tokens that are key to the prediction task benefit from more iterations. We confirmed this observation during fine-tuning, where tokens receiving a large number of iterations are also suspected to be key for achieving the task.
Our experiments provide a new interpretation path for the role of layers in deep transformer models. Rather than extracting specific features at each stage, layers could be interpreted as the steps of an iterative and convergent process. We hope this can help to better understand the convergence mechanisms of transformers, reduce their computational footprint, or provide new regularization methods.

A Natural fixed point
We present here the evolution of the mean cosine similarity between hidden states from two consecutive iterations for our small (Figure 3) and tiny (Figure 4) models. As presented in Section 3.2, we fix the maximum number of iterations at 6 and 12 for the tiny and small models, respectively.

B Probing tasks
We give here the probing task results from Section 4.2 for our small (Table 4) and tiny (Table 5) models.

Table 4: Distribution of the iterations across token dependency types. We fine-tune our small model on each probing task. We then perform inference on the Penn Treebank dataset and report the number of iterations given the token dependency type. The number in parentheses denotes the number of dependency tags. We only display the 10 most frequent tags. We indicate in bold the tags for which the number of iterations is above avg + std. We include a baseline accuracy, obtained with the ALBERT-base version without the adaptive depth mechanism, which therefore performs 12 iterations for each token.

Table 5: Distribution of the iterations across token dependency types. We fine-tune our tiny model on each probing task. We then perform inference on the Penn Treebank dataset and report the number of iterations given the token dependency type. The number in parentheses denotes the number of dependency tags. We only display the 10 most frequent tags. We indicate in bold the tags for which the number of iterations is above avg + std. We include a baseline accuracy, obtained with the ALBERT-base version without the adaptive depth mechanism, which therefore performs 12 iterations for each token.