Multi-split Reversible Transformers Can Enhance Neural Machine Translation

Large-scale transformers have achieved state-of-the-art performance on neural machine translation. However, training these increasingly wide and deep models can be tremendously memory intensive. We reduce the memory burden by employing the idea of reversible networks, in which a layer's input can be reconstructed from its output. We design three types of multi-split based reversible transformers. We also devise a corresponding backpropagation algorithm, which does not need to store activations for most layers. Furthermore, we present two fine-tuning techniques, splits shuffle and self ensemble, to boost translation accuracy. Specifically, our best models surpass the vanilla transformer by at least 1.4 BLEU points on three datasets. Our large-scale reversible models achieve 30.0 BLEU on WMT'14 En-De and 43.5 BLEU on WMT'14 En-Fr, beating several very strong baselines with less than half of the training memory.


Introduction
Transformers (Vaswani et al., 2017) and their variants (So et al., 2019; Dehghani et al., 2019; Fonollosa et al., 2019; Zhu et al., 2020) significantly enhance the performance of neural machine translation (NMT). But this often requires a large hidden layer (e.g., Raffel et al. (2019) used a dimension of 65K) or a deeper network built by stacking more building blocks (e.g., a 60-layer encoder). Training large networks can be extremely memory intensive and might even require model parallelization across multiple GPUs (Brown et al., 2020). As a result, reducing memory consumption is crucial for training wider and deeper networks efficiently.
Backpropagation (BP) is commonly used for training modern neural networks. BP needs to store layer activations to calculate the parameter gradients, which severely increases the memory burden.
The idea of reversible networks can be a solution. During training, a reversible network layer's input can be reconstructed from its output. BP is run together with the reconstruction process, removing the need to store activations for all layers except the last. We extend the hidden-dimension splitting approach of Gomez et al. (2017) and design three types of reversible transformers, namely, simple reversible transformers (SIM-REV), single dependent reversible transformers (SD-REV) and fully dependent reversible transformers (FD-REV). We also devise a corresponding BP algorithm for our reversible models, which is significantly more memory efficient than conventional BP.
Our reversible models rely on partitioning each layer's input and output into multiple equal splits. This multi-split structure inspires two fine-tuning techniques that further enhance translation accuracy. First, we randomly shuffle the output splits to encourage information sharing. Second, we train distinct translation models based on different output splits of the final decoder layer and run model ensemble during inference. Both techniques are applied after model convergence, and only a few epochs of fine-tuning are sufficient to boost translation performance. Moreover, neither technique breaks the reversibility of our proposed models.
We demonstrate that our reversible models can achieve similar or better performance than vanilla transformers while consuming less memory. Specifically, by employing reversible training and the fine-tuning techniques, our best models surpass vanilla transformers by 1.5 BLEU (IWSLT'14 De-En), 2.0 BLEU (WMT'19 En-Lt) and 1.4 BLEU (WMT'14 En-De). Our large-scale models also beat several very strong NMT models with less than half the training memory on WMT'14 En-De (30.0 BLEU) and WMT'14 En-Fr (43.5 BLEU).

Methodology
We introduce reversible transformers in this section. The definition and benefits of layer reversibility are given in Section 2.1. Section 2.2 shows three types of reversible architectures based on partitioning the layer input along the hidden/embedding dimension. Section 2.3 details the backpropagation algorithm we use. Finally, in Section 2.4, two techniques that can fit into the reversible training framework are introduced for further boosting model performance.

Reversible Architectures
A neural network layer is said to be reversible if its input can be reconstructed from its output. Usually, a network is trained in a forward-backward fashion: the activations in each layer are calculated in the forward pass and stored for gradient computation in the backward pass. The requirement to store activations is memory intensive and often becomes a bottleneck for network training. However, if a network has reversible building blocks, we do not need to store the activations for most layers, since they can be recomputed during the backward pass. A reversible layer can be designed in two ways. The first is to give the layer an analytical inverse (Gomez et al., 2017; Jacobsen et al., 2018; Chang et al., 2018; MacKay et al., 2018). The second is to compute the layer input via numerical methods, e.g., fixed-point iteration (Behrmann et al., 2019). We focus on the first way, following the dimension-splitting approach proposed by RevNets (Gomez et al., 2017). We give a brief review of RevNets. X is the network input, which is split into two halves X_1 and X_2. F and G are modules inside a layer (e.g., 3 × 3 convolutions). The forward process is as follows:

Y_1 = X_1 + F(X_2), Y_2 = X_2 + G(Y_1), (1)

and the input is recovered from the output by X_2 = Y_2 − G(Y_1), followed by X_1 = Y_1 − F(X_2).
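To make the additive coupling concrete, the following is a minimal PyTorch-style sketch of a two-split reversible block following Equation (1); the class and the F/G modules are illustrative placeholders, not code from the paper.

```python
import torch.nn as nn


class TwoSplitReversibleBlock(nn.Module):
    """RevNet-style additive coupling: the input can be recovered from the output."""

    def __init__(self, f: nn.Module, g: nn.Module):
        super().__init__()
        self.f = f  # e.g., a self-attention module or a 3x3 convolution
        self.g = g  # e.g., a feed-forward module

    def forward(self, x1, x2):
        # Equation (1): Y1 = X1 + F(X2), Y2 = X2 + G(Y1)
        y1 = x1 + self.f(x2)
        y2 = x2 + self.g(y1)
        return y1, y2

    def inverse(self, y1, y2):
        # Reconstruct the input from the output, so no activations need to be stored.
        x2 = y2 - self.g(y1)
        x1 = y1 - self.f(x2)
        return x1, x2
```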

Reversible Transformers
Transformers (Vaswani et al., 2017) achieve state-of-the-art performance on several tasks (Edunov et al., 2018; Brown et al., 2020). Despite their success, training transformers is memory intensive. We propose three reversible transformers inspired by RevNets (Gomez et al., 2017) to reduce the training memory consumption. X is the layer input and O is the layer output. X and O are partitioned into n equal splits along the hidden/embedding dimension: X = {X_1, X_2, ..., X_n}, O = {O_1, O_2, ..., O_n}. F_1, F_2, ..., F_n are modules inside a layer (e.g., self-attention or a fully connected layer).
Simple Reversible Transformer (SIM-REV) X is split into two halves X_1 and X_2. For each module F_i inside a layer, the forward process resembles RevNets, with both F and G in Equation (1) replaced by F_i. Part (b) of Figure 1 demonstrates the case of a layer with three modules (one module step is sketched below). This is the simplest way to introduce reversibility into a transformer layer, but the computational complexity is doubled compared with vanilla transformers, since each module F_i is applied twice.
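Reading "changing F and G in Equation (1) into F_i" literally, one module step of SIM-REV can plausibly be written as follows; this is a sketch inferred from the description rather than an equation quoted from the paper.

```latex
% One SIM-REV module step (sketch): F_i plays the role of both F and G in Eq. (1).
\begin{aligned}
Y_1 &= X_1 + F_i(X_2), \\
Y_2 &= X_2 + F_i(Y_1),
\end{aligned}
```

with (Y_1, Y_2) serving as the (X_1, X_2) of the next module; each F_i is evaluated twice per step, which is where the doubled computational complexity comes from.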
Single Dependent Reversible Transformer (SD-REV) We propose another reversible architecture to reduce the computational complexity. The i-th output split O_i depends only on X_i and O_{i-1}, so the forward pass proceeds split by split, and the reconstruction of X given O is equally straightforward, proceeding in reverse order (both directions are sketched below).
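A plausible instantiation of the SD-REV forward and reverse passes, written out from the dependency structure described above; the treatment of the first split (here simply passed through unchanged) is an assumption.

```latex
% SD-REV sketch: O_i depends only on X_i and O_{i-1}.
% Assumption: the first split is passed through unchanged.
\begin{aligned}
\text{Forward:}\quad & O_1 = X_1, \qquad O_i = X_i + F_i(O_{i-1}), \quad i = 2, \dots, n, \\
\text{Reverse:}\quad & X_1 = O_1, \qquad X_i = O_i - F_i(O_{i-1}), \quad i = 2, \dots, n.
\end{aligned}
```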
Part (c) of Figure 1 shows a 3-split example of SD-REV. With only half of SIM-REV's computational complexity, experiments show that SD-REV achieves similar or even better performance than SIM-REV.
Fully Dependent Reversible Transformer (FD-REV) SD-REV only encodes information from neighbouring splits, and the lack of interaction between distant splits may make the model less expressive. We therefore force each output split O_i to depend on all previous output splits O_{<i} and all subsequent input splits X_{≥i}, while preserving the reversibility of the network layer. Despite the increased computational complexity, we expect the model to generalize better. The FD-REV forward process and the reconstruction that ensures its reversibility are sketched below.
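Spelling out the dependency pattern above, the FD-REV forward pass and its inverse can plausibly take the following form; how the multiple arguments of F_i are combined (e.g., by concatenation) is an assumption.

```latex
% FD-REV sketch: O_i sees all earlier output splits and all later input splits.
\begin{aligned}
\text{Forward:}\quad & O_i = X_i + F_i\big(O_1, \dots, O_{i-1},\, X_{i+1}, \dots, X_n\big), \quad i = 1, \dots, n, \\
\text{Reverse:}\quad & X_i = O_i - F_i\big(O_1, \dots, O_{i-1},\, X_{i+1}, \dots, X_n\big), \quad i = n, \dots, 1,
\end{aligned}
```

where the reverse pass runs from i = n down to i = 1, so that every X_{>i} needed by F_i has already been reconstructed.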
Part (d) of Figure 1 illustrates FD-REV in a 3-split case. Experiments in Section 3 show that allowing interaction between distant splits is beneficial for translation performance.
Instantiation The building blocks of transformers are attention-based modules (self-attention and cross-attention) and position-wise feed-forward (FFN) layers. For SIM-REV, each of these modules is transformed into a reversible module, where S denotes an input/output split. In the decoder, F_1(S) = α(S + Self-Attn(S)), F_2(S) = α(S + Cross-Attn(S)), and F_3(S) = α(S + FFN(S)); the encoder is built analogously from its Self-Attn and FFN modules. For an n-split SD-REV or FD-REV (which in fact has an n-split encoder and an (n + 1)-split decoder), we use multiple Self-Attn modules and a single Cross-Attn/FFN module within each layer.

Applying layer normalization (LN) to the layer output O is crucial for good convergence when training transformers. However, LN requires extra storage to compute its inverse, so we use the ReZero technique (Bachlechner et al., 2020) as a substitute for LN. Each layer has a distinct re-scaling weight α, initialized to zero, which is trained together with the other network parameters using Algorithm 1 in Section 2.3 (a sketch of such a ReZero-style module is given below).
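Below is a minimal sketch of a ReZero-style reversible sub-module, following the formula quoted above, F(S) = α(S + Module(S)) with α initialized to zero; keeping a separate α per wrapped module (rather than one per layer) is a simplification made here.

```python
import torch
import torch.nn as nn


class ReZeroModule(nn.Module):
    """Wraps a sub-module as F(S) = alpha * (S + module(S)), with alpha starting at zero."""

    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module                        # e.g., self-attention, cross-attention or FFN
        self.alpha = nn.Parameter(torch.zeros(1))   # trainable re-scaling weight, initialized to 0

    def forward(self, s):
        # At initialization the module contributes nothing, which stabilizes training
        # without layer normalization.
        return self.alpha * (s + self.module(s))
```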

Backpropagation with Reconstructing Activations
In the backward pass, we are given the activations O = {O_1, ..., O_n} and their total derivatives dO = {dO_1, ..., dO_n}. We wish to compute the inputs X = {X_1, ..., X_n}, their total derivatives dX = {dX_1, ..., dX_n} and the derivatives of the model parameters in F_1, ..., F_n. For SIM-REV, the backpropagation (BP) algorithm is identical to that of Gomez et al. (2017), so our main focus is to derive the BP algorithm for SD-REV and FD-REV.
The forward passes of SD-REV and FD-REV can be combined into a more general reversible form with modules G_1, ..., G_n. Algorithm 1 defines the BP rule for this general form of reversible network; gradients for the model parameters are computed in line 9 of Algorithm 1 as a side effect (a sketch of the procedure is given below). Repeated application of Algorithm 1 lets us perform BP through a sequence of reversible layers while only requiring the activations and derivatives of the top layer. In this way, the storage cost for activations is small and independent of the network depth.
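The sketch below illustrates the reconstruct-then-backpropagate pattern that Algorithm 1 relies on, written for a generic reversible layer exposing forward() and inverse() over lists of splits; it is a simplified stand-in for Algorithm 1, not a verbatim transcription.

```python
import torch


def reversible_backward(layer, outputs, grad_outputs):
    """One memory-efficient BP step for a reversible layer.

    `outputs` and `grad_outputs` are lists of output splits and their gradients;
    the returned pair plays the same role for the layer below.
    """
    with torch.no_grad():
        inputs = layer.inverse(*outputs)                   # reconstruct inputs from outputs
    inputs = [x.detach().requires_grad_(True) for x in inputs]
    with torch.enable_grad():
        recomputed = layer.forward(*inputs)                # recompute activations for this layer only
        torch.autograd.backward(recomputed, grad_outputs)  # fills parameter grads and inputs[i].grad
    grad_inputs = [x.grad for x in inputs]
    return inputs, grad_inputs
```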

Splits Shuffle and Self Ensemble
In this section, we propose two multi-split based fine-tuning methods that can enhance model performance, namely, splits shuffle and self ensemble.
Algorithm 1 BP Algorithm for Multi-Split Reversible Networks. Input: layer output O = {O_1, ..., O_n}; derivatives of O: dO = {dO_1, ..., dO_n}; modules G_1, ..., G_n. Output: layer input X = {X_1, ..., X_n}; derivatives of X: dX = {dX_1, ..., dX_n}; parameter gradients of G_1, ..., G_n.

First, we train a reversible transformer until convergence. Then, a few epochs of fine-tuning with one of these techniques can improve model accuracy.
Splits Shuffle A reversible transformer consists of several reversible layers, and the inputs of a given layer are the outputs of its preceding layer, which we denote as O = {O_1, ..., O_n}. Note that if the order of the O_i is randomly shuffled, the whole network remains reversible as long as we keep a record of the shuffling order. This property inspires the following fine-tuning technique:
• For each layer in the reversible network, sample b ∼ Bernoulli(p).
• If b = 1, randomly permute the splits {O_1, ..., O_n} before they are passed to the next layer, and record the permutation so that the inputs can still be reconstructed.
Figure 2 shows the splits shuffle process; a minimal sketch is also given below. At inference time, we set p to 0. The idea behind splits shuffle is to apply dropout at the structure level: it provides a way to combine exponentially many network architectures efficiently. For each such architecture to perform well, every split is forced to become more expressive.
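A minimal sketch of the splits shuffle operation on a list of hidden-dimension splits; the function names and the exact place where the permutation is recorded are illustrative.

```python
import torch


def splits_shuffle(splits, p):
    """With probability p, randomly permute the layer-output splits (fine-tuning only)."""
    if torch.rand(1).item() < p:                 # b ~ Bernoulli(p)
        perm = torch.randperm(len(splits))
    else:
        perm = torch.arange(len(splits))         # identity order
    shuffled = [splits[i] for i in perm.tolist()]
    return shuffled, perm                        # keep perm so the shuffle can be undone


def splits_unshuffle(shuffled, perm):
    """Restore the original split order, preserving reversibility."""
    original = [None] * len(shuffled)
    for new_pos, old_pos in enumerate(perm.tolist()):
        original[old_pos] = shuffled[new_pos]
    return original
```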
Self Ensemble Model ensemble is a commonly used method for boosting translation performance (Zhou et al., 2017; Wang et al., 2020b). It usually requires multiple distinct models to output their probability distributions over the vocabulary, so the ensemble process is both computationally and memory intensive. Our multi-split model offers a new opportunity: we can view each split of the final output as an independent model. Our self ensemble technique works as follows:
• O_i is a split of the final decoder layer's output, y is the translation target, and FC stands for a fully connected layer; each split O_i is projected by the FC layer into a distribution over the vocabulary and trained with its own translation loss against y, and the per-split distributions are ensembled at inference time.

Experimental Setup

Architectures We experiment with all architectures proposed in Section 2.2. Multi-split based models have larger hidden dimensions than vanilla transformers, so we use an embedding size smaller than the hidden size by factorizing the word embedding matrix. Let N be the vocabulary size and d the hidden size. The original word embedding matrix E ∈ R^{N×d} is factorized into a product of two matrices of size N × l and l × d, where l ≪ d; we denote l as the embedding size (a sketch is given below). The embedding size for each language pair is 128 (IWSLT'14 De-En), 256 (WMT'19 En-Lt, WMT'14 En-De base models), and 512 (WMT'14 En-De and WMT'14 En-Fr large models). For a specific language pair, we keep the parameter sizes almost identical across the different model architectures; see Appendix A for details.
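A small sketch of the embedding factorization described above, with an N × l lookup table followed by an l × d projection; the class name and module layout are illustrative.

```python
import torch.nn as nn


class FactorizedEmbedding(nn.Module):
    """Factorizes the N x d embedding matrix into (N x l) and (l x d) with l << d."""

    def __init__(self, vocab_size, embed_size, hidden_size):
        super().__init__()
        self.lookup = nn.Embedding(vocab_size, embed_size)              # N x l table
        self.project = nn.Linear(embed_size, hidden_size, bias=False)   # l x d projection

    def forward(self, token_ids):
        return self.project(self.lookup(token_ids))
```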
Training All models are trained on 8 RTX 2080Ti GPU cards with a mini-batch of 3584 tokens unless otherwise stated. We use the same learning rate schedule as Vaswani et al. (2017), with 4000 warmup steps (a sketch is given below). The learning rates are set to 5 × 10^-4 (IWSLT'14 De-En) and 7 × 10^-4 (WMT'19 En-Lt, WMT'14 En-De base). The dropout probability and label smoothing factor are both set to 0.1. For training the large models in Section 3.4, we increase the dropout probability to 0.3 and the learning rate to 1 × 10^-3, and we accumulate gradients over 16 batches.
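For reference, a sketch of the inverse-square-root schedule with linear warmup used by Vaswani et al. (2017); scaling it so that the quoted learning rate (e.g., 5 × 10^-4) is the peak value reached at the end of warmup is our assumption.

```python
def transformer_lr(step, peak_lr=5e-4, warmup=4000):
    """Linear warmup to `peak_lr` over `warmup` steps, then inverse-square-root decay."""
    step = max(step, 1)
    return peak_lr * min(step / warmup, (warmup / step) ** 0.5)
```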

Machine Translation Results
To compare the various architectures, we carry out experiments on all corpora except WMT'14 En-Fr. Results are summarized in Table 1. In general, the best reversible architecture outperforms the transformer baseline by 1.1 (IWSLT'14 De-En), 1.5 (WMT'19 En-Lt) and 0.8 (WMT'14 En-De) BLEU points.
All models we propose deliver similar or superior performance to the vanilla transformer. SD-REV-2 (2 means the split number is 2) is almost as good as SIM-REV with only half the computational complexity. For SD-REV, translation performance increases as the split number grows; a higher split number means more interaction between separate splits, which may benefit translation quality. The good performance of FD-REV further indicates that interactions between splits should be encouraged. FD-REV translates best among the different architectures. Increasing the split number is not necessary for FD-REV, since it already allows interaction among all splits.

The remaining experiments are organized as follows: (1) in Section 3.3, we apply splits shuffle and self ensemble to the best models for each language pair; (2) in Section 3.4, we experiment with large model sizes on two large corpora, namely WMT'14 En-De and WMT'14 En-Fr, and we also apply splits shuffle and self ensemble to validate their effectiveness.

Splits Shuffle and Self Ensemble
In this section, we focus on the fine-tuning techniques that boost translation performance. After the model converges under training with Algorithm 1, we apply splits shuffle or self ensemble for fifteen epochs (IWSLT'14 De-En), five epochs (WMT'19 En-Lt) and one epoch (WMT'14 En-De). For the shuffle probability p, we use 0.3. Results are also summarized in Table 1.
Both fine-tuning techniques yield a performance gain over the original model. Splits shuffle is slightly better than self ensemble, and it does not increase inference-time computational cost while self ensemble does. Several interesting phenomena are worth mentioning. First, the final validation perplexity decreases under splits shuffle but increases under self ensemble; since both techniques are helpful, it is reasonable to think that splits shuffle genuinely enhances the model itself while self ensemble benefits more from the ensembling process. Second, a combination of splits shuffle and self ensemble fails to converge; the combined task may be too challenging to learn even though the model starts from an already converged state. Third, splits shuffle and self ensemble can also be used from the beginning of training; however, such complicated training objectives bring no performance gain either. Details can be found in Appendix B.

Performance and Memory Consumption of Large Models
In this section, we investigate large-scale reversible transformers. The experiments focus on two questions. First, is the performance of reversible models comparable to or even better than that of non-reversible models? Second, how much GPU memory can be saved by using Algorithm 1 for backpropagation (BP)? We choose FD-REV-2, which performs best on WMT'14 En-De in Section 3.2, and double its hidden dimension, resulting in a parameter size similar to the other large-scale models. As shown in Table 2, FD-REV-2 achieves comparable results on both datasets. The fine-tuning techniques in Section 2.4 offer a chance to further enhance model performance. We follow the WMT'14 En-De settings from Section 3.3 and find that large-scale models benefit even more from splits shuffle and self ensemble: we surpass various strong baselines by using splits shuffle for only one epoch of fine-tuning. Specifically, we achieve 30.0 BLEU points on WMT'14 En-De and 43.5 BLEU points on WMT'14 En-Fr.
We also compare the training memory consumption of three settings: (1) Transformer-Big trained with conventional BP, (2) FD-REV-2 trained with conventional BP, and (3) FD-REV-2 trained with reversible BP (Algorithm 1). Figure 4 illustrates the memory consumption of the three settings. Under conventional BP, Transformer-Big and FD-REV-2 consume similar amounts of GPU memory. Reversible BP with Algorithm 1 removes the need to store activations for most layers and requires about half of the GPU memory of conventional BP.

Analysis
Splits Shuffle Probability We study the impact of the splits shuffle probability p. For all language pairs, we use the best models from Section 3.3 and Section 3.4. The results are summarized in Figure 5. We find that a medium value (0.3 ≤ p ≤ 0.5) yields the largest BLEU increase. A small p is insufficient to enhance the model's generalization ability, since the model has already been optimized for several epochs, while a very large p makes the model highly unstable and hard to converge.
ReZero We study the impact of the ReZero technique, which serves as a substitute for layer normalization (LN). We cannot apply LN to each layer's output O because of the reversibility requirement; instead, LN can be applied to each output split O_i. We compare ReZero with LN applied to O_i, and Table 3 shows the results. With LN on O_i, translation performance drops severely and the model converges more slowly, so we use ReZero throughout our experiments.
Reversible Training vs. Normal Training Reversible training saves GPU memory. However, reconstructing activations over many layers can introduce numerical errors, and inaccurate gradients may hurt model performance, so it is important to compare model performance between reversible training and conventional backpropagation (BP). As shown in Table 3, reversible training does not hurt model performance. Also, since the model parameters are updated the same number of times, the convergence speed of reversible training is almost identical to that of conventional BP.

Memory Consumption of Deep Models
Reversible transformers become more memory efficient as the model gets deeper. We validate this argument with a simple experiment. First, we take a mini-batch of 2390 tokens. Then, one step of parameter update is performed with either conventional BP or reversible training. We gradually increase the model depth and record the corresponding memory consumption on a single RTX TITAN GPU card. Results are shown in Figure 6. Deeper models mean more activations to store under conventional BP, whereas reversible training only needs to store the extra model parameters. The memory consumption gap can reach 12.6 GB with a 30-layer encoder and a 30-layer decoder.
Computational Overhead Roughly speaking, our proposed network is composed mostly of fully connected (FC) layers. For an FC layer with N connections, the forward and backward passes require approximately N and 2N add-multiply operations, respectively. Since the reconstruction during backpropagation (BP) adds another N add-multiply operations, training with Algorithm 1 is roughly 33% slower. We compare the training speed of four structures, namely FD-REV-2-base, FD-REV-2-big, Transformer-base and Transformer-big. From Table 4, we can see that FD-REV-2 trains almost as fast as vanilla transformers when employing conventional BP. Applying Algorithm 1 adds about 33% to 38% to the training time, which in turn saves about half of the memory consumption.

Related Work

Some transformer variants operate on multi-dimensional tensors and apply multiple attentions, each along a single axis of the input tensor. Several works (Wu et al., 2019; Lioutas and Guo, 2020; Beltagy et al., 2020; Zaheer et al., 2020) incorporate convolution networks into transformers. Besides inventing new attention modules, weight sharing (Dehghani et al., 2019; Bai et al., 2019; Lan et al., 2020) is another practical approach to decreasing the memory burden. Reversible models are orthogonal to these approaches, and combining reversible models with such transformer variants can further reduce memory consumption.

Conclusion
We have presented three types of multi-split based reversible transformers that outperform vanilla transformers. During backpropagation, activations for most layers need not be stored in memory because they can be reconstructed. Furthermore, we have proposed two fine-tuning techniques, splits shuffle and self ensemble. Both techniques are easy to implement, and only a few fine-tuning epochs are sufficient to boost translation performance. Our approach beats several strong baselines on two large datasets with fewer model parameters and much less training memory; specifically, we achieve 30.0 BLEU points on WMT'14 En-De and 43.5 BLEU points on WMT'14 En-Fr. Our methods can also be applied to transform other network structures into reversible versions, and we plan to explore more computer vision and natural language processing tasks to widen the applicability of reversible networks.
A Model Hyper-parameters

Table 5 details the model hyper-parameters. As we use a factorized word embedding matrix, the embedding size l is smaller than the hidden dimension d. The hidden size d increases with the number of splits n to ensure a similar parameter size for a given dataset. Also note that SD-REV and FD-REV have identical parameter sizes as long as they share the same number of splits n, embedding size l and hidden size d; thus, we do not differentiate between SD-REV and FD-REV in Table 5.

B More Results on Splits Shuffle and Self Ensemble
We provide more experimental results on splits shuffle and self ensemble, as shown in Figure 7 and Figure 8. Using either technique from the beginning of training tends to hurt model performance.