Memory-Efficient Differentiable Transformer Architecture Search

Differentiable architecture search (DARTS) has been successfully applied to many vision tasks. However, directly applying DARTS to Transformers is memory-intensive, which renders the search process infeasible. To this end, we propose a multi-split reversible network and combine it with DARTS. Specifically, we devise a backpropagation-with-reconstruction algorithm so that only the last layer's outputs need to be stored. Relieving the memory burden of DARTS allows us to search with a larger hidden size and more candidate operations. We evaluate the searched architecture on three sequence-to-sequence datasets, i.e., WMT'14 English-German, WMT'14 English-French, and WMT'14 English-Czech. Experimental results show that our network consistently outperforms standard Transformers across the tasks. Moreover, our method compares favorably with big-size Evolved Transformers, reducing search computation by an order of magnitude.


Introduction
Current neural architecture search (NAS) studies have produced models that surpass the performance of those designed by humans (Real et al., 2019; Lu et al., 2020). For sequence tasks, efforts have been made in reinforcement learning-based (Pham et al., 2018) and evolution-based (So et al., 2019) methods, which suffer from a huge computational cost. In contrast, gradient-based methods (Jiang et al., 2019; Yang et al., 2020) are less demanding in computing resources and easy to implement, and have attracted much attention recently.
The idea of gradient-based NAS is to train a super network covering all candidate operations. Different sub-graphs of the super network form the search space. To find a well-performing sub-graph, DARTS introduces search parameters that are trained jointly with the network weights. Operations corresponding to the largest search parameters are kept for each intermediate node after searching. A limitation of DARTS is its memory inefficiency: it needs to store the intermediate outputs of all its candidate operations. This is much more pronounced when we use Transformers (Vaswani et al., 2017) as the backbone of DARTS (the operation set is detailed in Section 2.5). As shown in Figure 1, memory consumption grows extremely fast as we increase the hidden size d, and the search runs out of memory once d > 400. As a result, we can only use a limited operation set or a small hidden size, which may lead to worse model performance.
[Figure 1 caption, partially recovered] Experiments are run on a single forward-backward pass over a batch of 3584 tokens on an NVIDIA P100 GPU. Limited by GPU memory, DARTS in Transformers has to search at small sizes while evaluating at large sizes, which causes performance gaps.
To address the unfavorable memory consumption of DARTS, we propose a variant of reversible networks. Each input of a reversible network layer can be reconstructed from its outputs. Thus, it is unnecessary to store intermediate outputs except for those of the last layer, because we can reconstruct them during backpropagation (BP). Inspired by RevNets (Gomez et al., 2017), we devise a multi-split reversible network. Each split contains a mixed operation search node to enable DARTS. Only a small modification of BP is needed to enable gradient calculation with input reconstruction. We show the memory consumption of our method in Figure 1: on average, it halves the amount of memory required by vanilla DARTS. We can therefore search larger, deeper networks with a richer candidate operation set under the same memory constraint.
Our method is generic enough to handle various network structures. In this work, we focus on the sequence-to-sequence task. We first perform the architecture search on the WMT'14 English-German translation task. The resulting architecture is then re-trained on three datasets: WMT'14 English-German, WMT'14 English-French, and WMT'14 English-Czech. We achieve consistent improvements over standard Transformers in all tasks. At a medium model size, we match the translation quality of the original "big" Transformer with 69% fewer parameters. At a big model size, we exceed the performance of the Evolved Transformer (So et al., 2019), with the computational cost lowered by an order of magnitude. We will make our code and models publicly available.

Methodology
We give a detailed description of our method. In Section 2.1, we introduce DARTS and its memory inefficiency when applied to Transformers. In Section 2.2, we propose a multi-split reversible network, which works as the backbone of our memory-efficient architecture search approach. Section 2.3 presents a backpropagation-with-reconstruction algorithm. In Section 2.4, we combine DARTS with our reversible networks. Finally, in Section 2.5, we summarize the proposed algorithms in more detail.

Differentiable Architecture Search in Transformers
Following the original DARTS formulation, we explain the idea of differentiable architecture search (DARTS) within a one-layer block. Let O be the candidate operation set (e.g., Self Attention, FFN, Zero). Each operation o ∈ O represents some function that can be applied to the layer inputs or hidden states (denoted X). The key of DARTS is to use a mixed operation search node f(X) to relax the categorical choice of a specific operation to a softmax over all candidate operations:

$$f(X) = \sum_{o \in O} \frac{\exp(\alpha_o)}{\sum_{o' \in O} \exp(\alpha_{o'})}\, o(X), \tag{1}$$

where α is a vector of |O| trainable parameters that determine the mixing weights. During searching, a one-layer block contains several search nodes. The task is to find a suitable α for each search node. At the end of the search, the resulting operation in each node is determined by:

$$o^{*} = \arg\max_{o \in O} \alpha_o.$$

We optimize α together with the network weights θ by gradient descent. A good architecture should perform well on the search validation set, so we optimize α with the validation loss L_val and θ with the training loss L_train:

$$\min_{\alpha}\; L_{val}\big(\theta^{*}(\alpha), \alpha\big) \quad \text{s.t.} \quad \theta^{*}(\alpha) = \arg\min_{\theta}\; L_{train}(\theta, \alpha). \tag{2}$$

In practice, we update α by ∇_α L_val and θ by ∇_θ L_train in each step.
It is easy to directly apply DARTS to Transformers by replacing some or all operations in a Transformer block with mixed operation search nodes. For example, we can change the Transformer decoder block from Self Attn → Cross Attn → FFN to Search Node 1 → Cross Attn → Search Node 2. Note that a search node outputs a weighted sum of different operations. To enable gradient calculation in the backward pass, we need to store every operation's output, which results in a steep rise in memory consumption during searching. Figure 1 shows the memory consumption of using 2 search nodes in both the Transformer encoder and decoder. DARTS runs out of memory easily, even at a small hidden size.
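As an illustration of the search node described above, the following is a minimal sketch in plain Python, not the authors' implementation. The three toy operations and the names `CANDIDATES`, `mixed_op`, and `discretize` are hypothetical stand-ins chosen for this example; the real candidates would be operations such as Self Attention or FFN.

```python
import math

def softmax(alpha):
    """Numerically stable softmax over a list of floats."""
    m = max(alpha)
    exps = [math.exp(a - m) for a in alpha]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical toy candidate set acting on a list of floats. These are
# placeholders for illustration, NOT the operations used in the paper.
CANDIDATES = {
    "identity": lambda x: list(x),
    "double": lambda x: [2.0 * v for v in x],
    "zero": lambda x: [0.0 for _ in x],
}

def mixed_op(x, alpha):
    """Search node: softmax(alpha)-weighted sum of every candidate's output."""
    weights = softmax(alpha)
    # Every candidate output is materialized, which is the memory cost of DARTS.
    outputs = [op(x) for op in CANDIDATES.values()]
    return [sum(w * out[i] for w, out in zip(weights, outputs))
            for i in range(len(x))]

def discretize(alpha):
    """After searching, keep the operation with the largest alpha."""
    names = list(CANDIDATES)
    return names[max(range(len(alpha)), key=lambda i: alpha[i])]

# With equal alphas the weights are uniform: (x + 2x + 0) / 3 = x.
y = mixed_op([1.0, 2.0], [0.0, 0.0, 0.0])
chosen = discretize([0.1, 1.5, -0.3])
```

Because `mixed_op` keeps one output per candidate, the activation memory of a search node grows linearly with the size of the candidate set, which is the bottleneck the reversible design is meant to relieve.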

Multi-split Reversible Networks
To relieve the memory burden of DARTS in Transformers, we use reversible networks. A reversible network layer's input can be reconstructed from its output. Suppose a network consists of several reversible layers. We do not need to store intermediate outputs except those of the last layer, because we can reconstruct them from top to bottom during backpropagation (BP). Denote by X and f(X) the layer input and the layer output, respectively. X is first split along the embedding/channel dimension into n equal parts {X_1, ..., X_n}. A RevNet-like (Gomez et al., 2017) operation is applied to each X_k, which yields Y_k. f(X) is the concatenation of {Y_1, ..., Y_n} along the split dimension:

$$Y_k = X_k + G_k(Y_1, \dots, Y_{k-1}, X_{k+1}, \dots, X_n), \quad k = 1, \dots, n, \qquad f(X) = [Y_1, \dots, Y_n]. \tag{3}$$

G_k is a mixed operation node during the architecture search process. After searching, G_k is a deterministic operation given by arg max_{o∈O} α_o. Detailed discussions can be found in Section 2.4. Eq. (3) is reversible by construction: the input X can be reconstructed from f(X) by recovering the splits in the order k = n, ..., 1:

$$X_k = Y_k - G_k(Y_1, \dots, Y_{k-1}, X_{k+1}, \dots, X_n), \quad k = n, \dots, 1. \tag{4}$$

Part (a) of Figure 2 illustrates a 3-split reversible network, which we frequently employ in our experiments for simplicity.
[Figure caption, partially recovered] Each X_k and Y_k is in R^{l×d}. The k-th pooling takes the concatenation of X_{i>k} and Y_{i<k} as the input, and outputs a tensor in R^{l×d}. The operation search gives a weighted average of the outputs of each candidate operation.
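The forward rule and its inverse can be sketched as follows. This is a minimal illustration under stated assumptions: the k-th split is updated from the already computed outputs Y_{i<k} and the not yet updated inputs X_{i>k}, and `g` is a hypothetical stand-in for G_k (any deterministic function of the other splits would do). Splits are plain Python lists rather than tensors.

```python
import math

def g(k, others):
    """Hypothetical stand-in for G_k: a deterministic function of the other splits."""
    dim = len(others[0])
    return [math.tanh((k + 1) * sum(t[i] for t in others) / len(others))
            for i in range(dim)]

def forward(xs):
    """Y_k = X_k + G_k(Y_1..Y_{k-1}, X_{k+1}..X_n), computed for k = 1..n."""
    ys = []
    for k in range(len(xs)):
        others = ys[:k] + xs[k + 1:]
        gk = g(k, others)
        ys.append([x + d for x, d in zip(xs[k], gk)])
    return ys

def reconstruct(ys):
    """X_k = Y_k - G_k(Y_1..Y_{k-1}, X_{k+1}..X_n), recovered for k = n..1."""
    n = len(ys)
    xs = [None] * n
    for k in reversed(range(n)):
        others = ys[:k] + xs[k + 1:]   # xs[k+1:] has already been recovered
        gk = g(k, others)
        xs[k] = [y - d for y, d in zip(ys[k], gk)]
    return xs

xs = [[0.5, -1.0], [2.0, 0.25], [-0.75, 1.5]]   # a 3-split input
ys = forward(xs)
xs_rec = reconstruct(ys)
max_err = max(abs(a - b) for x, r in zip(xs, xs_rec) for a, b in zip(x, r))
```

The reverse order matters: recovering X_k needs X_{k+1}, ..., X_n, so the last split, whose G_n depends only on outputs, has to be recovered first.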
We show BP-with-reconstruction through a single layer in Algorithm 1, where [·] denotes Concat(·) for simplicity. In Line 9 of Algorithm 1, dθ_k is calculated as a side effect. Line 10 shows the reconstruction process, where each split X_k is recovered in the order of n to 1. In Algorithm 1, grad_k works as a gradient accumulator, which keeps track of all derivatives associated with X_k. Repeated application of Algorithm 1 enables us to backpropagate through a sequence of reversible layers. Only the top layer's outputs need to be stored, which makes the procedure much more memory-efficient.
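The procedure described above can be sketched for a single layer as follows. This is our own illustration and not a transcription of Algorithm 1: each split is a scalar, and G_k is assumed to be the hypothetical function θ_k · tanh(sum of the other splits), so that all derivatives can be written by hand. A central finite-difference estimate serves as an independent check.

```python
import math

def forward(xs, thetas):
    """Scalar splits: Y_k = X_k + theta_k * tanh(sum of Y_{i<k} and X_{i>k})."""
    ys = []
    for k in range(len(xs)):
        s = sum(ys[:k]) + sum(xs[k + 1:])
        ys.append(xs[k] + thetas[k] * math.tanh(s))
    return ys

def backward_with_reconstruction(ys, dys, thetas):
    """Return (xs, dxs, dthetas) using only the layer outputs and their gradients."""
    n = len(ys)
    xs = [0.0] * n
    grad = list(dys)            # grad[k] starts as dL/dY_k and ends as dL/dX_k
    dthetas = [0.0] * n
    for k in reversed(range(n)):
        s = sum(ys[:k]) + sum(xs[k + 1:])    # xs[k+1:] is already reconstructed
        t = math.tanh(s)
        dthetas[k] = grad[k] * t             # parameter gradient
        d_in = grad[k] * thetas[k] * (1.0 - t * t)
        for j in range(n):                   # every other split is an input of G_k
            if j != k:
                grad[j] += d_in
        xs[k] = ys[k] - thetas[k] * t        # reconstruction of X_k
    return xs, grad, dthetas

def loss(xs, thetas):
    return 0.5 * sum(y * y for y in forward(xs, thetas))

xs = [0.3, -0.8, 1.1]
thetas = [0.5, -1.2, 0.7]
ys = forward(xs, thetas)
# For this loss, dL/dY_k = Y_k.
xs_rec, dxs, dthetas = backward_with_reconstruction(ys, list(ys), thetas)

# Central finite differences as an independent check of dL/dX_k.
eps = 1e-6
num_dxs = []
for k in range(len(xs)):
    hi = list(xs); hi[k] += eps
    lo = list(xs); lo[k] -= eps
    num_dxs.append((loss(hi, thetas) - loss(lo, thetas)) / (2 * eps))
```

The single accumulator changes meaning as the loop proceeds: when split k is processed, `grad[k]` holds the total derivative with respect to Y_k, and afterwards it collects the contributions that reach X_k through G_j for j < k.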
Roughly speaking, for a network with N connections, the forward and backward passes require approximately N and 2N add-multiply operations, respectively. Since we need to reconstruct X from f(X), the re-calculation requires another N add-multiply operations, making a training step about 33% slower. Fortunately, we only need Algorithm 1 for architecture search and re-train the resulting network with ordinary BP. The search process also turns out to converge fast, so the computational overhead is not a severe problem.
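The 33% figure follows directly from the operation counts stated above, comparing one step with reconstruction against one ordinary step:

$$\frac{N_{\text{forward}} + N_{\text{reconstruct}} + 2N_{\text{backward}}}{N_{\text{forward}} + 2N_{\text{backward}}} = \frac{N + N + 2N}{N + 2N} = \frac{4}{3} \approx 1.33.$$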

DARTS with Multi-split Reversible Networks
Performing DARTS based on n-split reversible networks only requires specifying each G_k in Eq. (3). Suppose that each X_k ∈ R^{l×d_n} (l is the sequence length, d is the hidden size, and d_n = d/n), and that each Y_k has the same size as X_k. The input of G_k consists of n−1 tensors in R^{l×d_n}. To enable element-wise addition with X_k, the output of G_k must also be in R^{l×d_n}.
G_k is factorized into two parts. The first part is a pooling operation, which takes an l × d_n × (n − 1) tensor as input and outputs an l × d_n × 1 tensor. The second part is a mixed operation search node. G_k is calculated as follows:

$$P_k = \mathrm{Pooling}\big([Y_1, \dots, Y_{k-1}, X_{k+1}, \dots, X_n]\big), \qquad G_k = \sum_{o \in O} \frac{\exp(\alpha_{k,o})}{\sum_{o' \in O} \exp(\alpha_{k,o'})}\, o(P_k), \tag{5}$$

where α_k is randomly initialized. Figure 3 shows the design of G_k. By substituting each G_k in Eq. (3) with Eq. (5), we are able to use Algorithm 1 to perform memory-efficient DARTS. We call this method DARTSformer, which is illustrated by Part (b) of Figure 2 in a 3-split case.
The overall search space size is critical to the performance of DARTSformer. In our experiments, we focus on sequence-to-sequence tasks where the encoder and the decoder are searched simultaneously. Suppose that we have an m-split encoder and an n-split decoder, and that we search s consecutive layers. For example, s = 2 means that we search within a 2-layer encoder block, where each layer in the block is an m-split reversible layer. The encoder contains several identical 2-layer blocks, and the same holds for the decoder. The search space is of size |O|^{s(m+n)}. If |O| is large, the search space can easily become very large even with small m, n, and s.
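To give a feel for how fast this count grows, the sketch below evaluates the formula with separate operation counts for the encoder and the decoder, which is our own small generalization. The numbers plugged in (13 and 14 operations) are taken from the Search Configuration section; counting only 2 searchable splits in the decoder is an assumed reading, based on its last split being fixed to Cross Attention.

```python
def search_space_size(num_ops_enc, num_ops_dec, enc_splits, dec_splits, layers):
    """|O_enc|^(s*m) * |O_dec|^(s*n): one categorical choice per searchable split."""
    return (num_ops_enc ** (layers * enc_splits)) * (num_ops_dec ** (layers * dec_splits))

# Assumed reading of the reported setup: 13 encoder ops, 14 decoder ops, and
# 2 searchable splits on each side (the decoder's Cross Attention split is fixed).
size_s1 = search_space_size(13, 14, enc_splits=2, dec_splits=2, layers=1)
size_s2 = search_space_size(13, 14, enc_splits=2, dec_splits=2, layers=2)
```

Under these assumptions, `size_s1` is 33,124 and `size_s2` is 1,097,199,376, which agrees with the "around 1 billion" figure quoted for the largest search space.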

Instantiation
We describe the instantiation of DARTSformer in this section.
• Cross Attention: Only available to the decoder.
• Zero: Returns a zero tensor of the input size.
Residual connections (He et al., 2016) and layer normalization (Ba et al., 2016) are crucial for convergence in training Transformers (Vaswani et al., 2017). To keep our network fully reversible, these two techniques cannot be used directly. Instead, we put the residual connection and layer normalization inside each operation, õ(X) = LayerNorm(X + o(X)), except for Zero and Identity.
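The wrapping õ(X) = LayerNorm(X + o(X)) can be sketched as a higher-order function. This is only an illustration on a single feature vector: the layer normalization here has no learnable gain or bias, and the names `layer_norm`, `wrap`, and the toy scaling operation are hypothetical.

```python
import math

def layer_norm(x, eps=1e-5):
    """LayerNorm over one feature vector, without learnable gain and bias."""
    mean = sum(x) / len(x)
    var = sum((v - mean) ** 2 for v in x) / len(x)
    return [(v - mean) / math.sqrt(var + eps) for v in x]

def wrap(op):
    """Turn an operation o into the wrapped form LayerNorm(X + o(X))."""
    def wrapped(x):
        return layer_norm([a + b for a, b in zip(x, op(x))])
    return wrapped

scale = wrap(lambda x: [0.5 * v for v in x])   # hypothetical toy operation
out = scale([1.0, 2.0, 3.0, 4.0])
```

Keeping the residual connection and normalization inside the operation means the outer layer still has the additive form Y_k = X_k + G_k(·), so the reconstruction by subtraction remains exact.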

Encoder and Decoder
We use an n-split encoder and an (n+1)-split decoder for DARTSformer. Each G_k in the encoder takes the form of Eq. (5). In the decoder, G_k for k < n+1 still follows Eq. (5), but the operation for the last split, G_{n+1}, is fixed as Cross Attention. Our experiments show that this constraint on the decoder yields architectures with better performance.

Search and Re-train
We summarize the entire framework of DARTSformer in Algorithm 2. Note that the search process is the most memory-intensive part, so we use BP-with-reconstruction there, as shown in Lines 2-5 of Algorithm 2.

Datasets
We use the same three standard datasets as So et al. (2019) for our experiments: (1) WMT'18 English-German (En-De) without ParaCrawl, which consists of 4.5 million training sentence pairs. (2) WMT'14 French-English (En-Fr), which consists of 36 million training sentence pairs. (3) WMT'18 English-Czech (En-Cs), again without ParaCrawl, which consists of 15.8 million training sentence pairs. Tokenization is done with Moses. We employ BPE (Sennrich et al., 2016) to generate a shared vocabulary for each language pair. The numbers of BPE merge operations are 32K (WMT'18 En-De), 40K (WMT'14 En-Fr), and 32K (WMT'18 En-Cs). We discard sentences longer than 250 tokens. For the re-training validation set, we randomly choose 3300 sentence pairs from the training set. The evaluation metric is BLEU (Papineni et al., 2002). We use beam search on the test sets with a beam size of 5, and we tune the length penalty parameter between 0.5 and 1.0. For an input of length m, the maximum output length is 1.2m + 10.

Search Configuration
The architecture searches are all run on WMT'14 En-De. DARTS is a bilevel optimization process, which updates the network weights θ on one dataset and the search parameters α on another. We split the 4.5 million sentence pairs into 2.5 million for θ and 2.0 million for α. Both L_train and L_val are cross-entropy losses with a label smoothing factor of 0.1. The split number n is 2 for the encoder and 3 for the decoder. We set s to 1 or 2, which means the super network contains several identical 1-layer or 2-layer blocks. The candidate operations are detailed in Section 2.5, where |O| = 13 for the encoder and |O| = 14 for the decoder. Following the analysis in Section 2.4, the largest search space has a size of around 1 billion. We use a factorized word embedding matrix to save memory. Let |V| be the vocabulary size, d the hidden size, and e the embedding size. The original word embedding matrix E ∈ R^{|V|×d} is factorized into a product of two matrices of size |V| × e and e × d, where e ≪ d. We set e = 256 and d = 960. During searching, we set the dropout probability to 0.1. Two Adam optimizers (Kingma and Ba, 2015) are used for updating θ and α, with β_1 = 0.9 and β_2 = 0.98. For θ, we use the same learning rate scheduling strategy as Vaswani et al. (2017) with 10,000 warmup steps. The maximum learning rate is set to 5 × 10^−4. For α, we fix the learning rate to 3 × 10^−4 with a weight decay of 1 × 10^−3, following the original DARTS setup.
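The saving from the factorized embedding can be estimated with a short parameter count. The sketch below uses e = 256 and d = 960 from the text, while the vocabulary size of 32,000 is a hypothetical round number chosen only for illustration, since the exact |V| is not stated here.

```python
def embedding_params(vocab_size, hidden, embed=None):
    """Parameter count of a full (|V| x d) or factorized (|V| x e, e x d) embedding."""
    if embed is None:
        return vocab_size * hidden
    return vocab_size * embed + embed * hidden

V = 32_000   # hypothetical vocabulary size, for illustration only
full = embedding_params(V, hidden=960)
factorized = embedding_params(V, hidden=960, embed=256)
saving = 1 - factorized / full
```

With these assumed values, the full matrix has 30,720,000 parameters and the factorized version 8,437,760, a reduction of roughly 73%.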
DARTSformer requires us to specify a pooling operation as stated in Eq. (5). We experiment with both max pooling and average pooling. All searches run on the same 8 NVIDIA V100 hardware. We use a batch size of 5000 tokens per GPU and save a checkpoint every 10,000 updates (5000 for θ and 5000 for α). Our search process finalizes after 60,000 updates.

Training Details
All the networks derived from the saved checkpoints are re-trained on WMT'14 En-De to select the best performing one. We then train the selected network on all datasets in Section 3.

Comparison Between Search Setups
We search through a different number of consecutive layers with different pooling operations. For re-training, we use the same learning rate scheduling strategy as in searching. We also keep the dropout rate unchanged. Results are summarized in Table 1. DARTSformer yields better results than standard Transformers in all experimental setups. The maximum performance gain is 0.7 BLEU with max pooling when searching through 2 consecutive layers. Also, DARTSformer achieves slightly better results than the Evolved Transformer in three out of four runs.
We compare the search cost of the Evolved Transformer and DARTSformer from various aspects. DARTSformer takes about 40 hours to run on an AWS p3dn.24xlarge node. The price for a single run of search is about $1.25k. As reported by Strubell et al. (2019), the search process of the Evolved Transformer costs up to $150k, which is extremely expensive. As for hardware, the evolutionary search employs 200 TPU v2 chips, while our method only uses 8 NVIDIA V100 cards. The reason for the evolutionary search algorithm's huge cost is that it requires training multiple candidate networks from scratch. We compare the number of parameter update steps in Table 3. The evolutionary search needs approximately 874 times more update steps than our method.

Table 3: Search cost of the Evolved Transformer (ET) and DARTSformer.
Model        Price    Steps        Hardware
ET           $150k    4.2 × 10^8   200 TPUs
DARTSformer  $1.25k   4.8 × 10^5   8 V100
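As a quick check, the ratios implied by the rounded figures for the two methods (4.2 × 10^8 versus 4.8 × 10^5 update steps, and $150k versus $1.25k) are:

$$\frac{4.2 \times 10^{8}}{4.8 \times 10^{5}} = 875, \qquad \frac{\$150\text{k}}{\$1.25\text{k}} = 120.$$

The step ratio computed from the rounded values is in line with the reported factor of approximately 874; the small difference presumably comes from rounding of the step counts.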
A simple sampling-based NAS method (Guo et al., 2020) can also reduce memory consumption.
For each batch of training data, we set G k in Eq. (5) as a uniformly sampled operation from the candidate set O. The search parameters α are discarded, and the resulting network is produced from an evolutionary search by evaluating on the re-training validation set. This method performs poorly in machine translation, as shown in Table 1. We find that sampling-based methods favor large-kernel convolutions and that the resulting architectures tend to generate repetitive sentences.
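The per-batch sampling step of this baseline can be sketched as follows. This only illustrates the uniform sampling described above, not the later evolutionary selection; the operation names and the function `sample_architecture` are hypothetical.

```python
import random

def sample_architecture(num_nodes, candidates, rng):
    """Pick one operation uniformly at random for every search node."""
    return [rng.choice(candidates) for _ in range(num_nodes)]

rng = random.Random(0)   # fixed seed so that runs are repeatable
candidates = ["self_attn", "ffn", "conv", "zero"]   # hypothetical names
# One freshly sampled sub-network per training batch.
archs = [sample_architecture(4, candidates, rng) for _ in range(3)]
```

Since only the sampled operation is executed for each node, a training step stores one output per node rather than one per candidate, which is where the memory saving of this baseline comes from.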
We also experiment with increased split numbers. As shown in Table 2, an increased split number hurts translation performance. The best results are all achieved with the smallest split number. The search process is also harder to converge when the search space becomes too large. In addition, re-training and inference slow down as the split number increases, because more recurrence is introduced in the calculation, as shown in Eq. (3).
In the following sections, we evaluate the best search result (DARTSformer + search 2 layers + 2 split + max pooling) on various sequence-to-sequence tasks to assess its generalization ability. We show this searched architecture in Figure 4.

Performance of DARTSformer on Other Datasets
First, we train DARTSformer with a base model size on the three translation tasks in Section 3.1. We would like to see whether DARTSformer only performs well on the task used for architecture search or generalizes to related tasks. Second, we scale up the model size and the batch size to see whether the performance gain of DARTSformer still exists. We compare DARTSformer with standard Transformers and Evolved Transformers of similar model sizes. Following Vaswani et al. (2017), the parameter size is around 62.5M for the base model and 214.7M for the big model. To match the settings of So et al. (2019) when training big models, we increase the dropout rate to 0.3 and the learning rate to 1 × 10^−3. We also accumulate gradients for two batches. Results are shown in Table 4. At the base model size, DARTSformer steadily outperforms standard Transformers. We achieve the same translation quality (28.4 BLEU, reported by Vaswani et al. (2017)) as the original big Transformer on WMT'14 En-De, with about 69% fewer parameters. Also, the maximum BLEU gain is 0.9 on WMT'14 En-Cs, which is not the dataset on which we conduct our architecture search.

Performance of DARTSformer vs. Parameter Size
In Section 4.2, DARTSformer consistently improves performance at model sizes comparable to the base and big Transformers. We now ask whether the performance gain also exists at smaller model sizes. We experiment with a spectrum of model sizes for standard Transformers and DARTSformer on WMT'14 En-De. Specifically, we use four embedding sizes for standard Transformers, [small: 128, medium: 256, base: 512, big: 1024], where the hidden size is identical to the embedding size. We also adjust the model size of DARTSformer accordingly. For base and big models, we use the results from Section 4.2. For small and medium models, we set the learning rate to 5 × 10^−4 and the dropout probability to 0.1, and update the model parameters for 200,000 steps on the same 8 NVIDIA V100 hardware. Figure 5 shows the results for both architectures. DARTSformer performs better than standard Transformers at all sizes. The BLEU increase is [1.3/0.9/0.7/0.7] for [small/medium/base/big] models. Interestingly, the performance gap between the two models tends to shrink as we increase the model size, which is also observed by So et al. (2019). Based on this observation, the benefit of DARTSformer is more pronounced in resource-limited environments, such as mobile phones. A possible reason for the decreased performance gap at larger model sizes is that the effect of overfitting becomes more important. We expect that some data augmentation techniques (Sennrich et al., 2015; Edunov et al., 2018; Qu et al., 2020) might help.

The Impact of Search Hidden Size
The main motivation for our method is that we want to search with a large hidden size to reduce the performance gap between searching and re-training. The results are reported in Table 5, which shows that the translation quality improves as the search hidden size gets larger. Also, note that when searching with the tiny, small, and medium settings, the final BLEU scores fall behind those of standard Transformers. We argue that if one wants to evaluate the searched model at large model sizes, it is important to search with large hidden sizes. Furthermore, we directly apply DARTS with a standard Transformer as the backbone model. We set e = 320, d = 320. A larger search hidden size often causes memory failure due to the storage

Related Work
Architecture Search
The field of neural architecture search (NAS) has seen advances in recent years. In the early stage, researchers focused on reinforcement learning-based approaches (Baker et al., 2016; Zoph and Le, 2016; Cai et al., 2018a; Zhong et al., 2018) and evolution-based approaches (Liu et al., 2017; Real et al., 2017; Miikkulainen et al., 2019; So et al., 2019). These methods can produce architectures that outperform human-designed ones (Real et al., 2019). However, the computational cost is almost unbearable, since every candidate network found in the search process needs to be fully trained and evaluated. Weight sharing (Brock et al., 2017; Pham et al., 2018) is a practical solution in which a super network is trained and its sub-graphs form the search space. DARTS was proposed to use search parameters together with a super network, which allows searching with gradient descent. Gradient-based methods (Cai et al., 2018b; Xie et al., 2018; Xu et al., 2019; Yao et al., 2020) attract researchers' attention since they are computationally efficient and easy to implement. We base our method on DARTS and take one step further to reduce the memory consumption of training the super network. Another recent trend is one-stage NAS (Cai et al., 2019; Mei et al., 2019; Hu et al., 2020; Yang et al., 2020). Many NAS algorithms have two stages. In the first stage, one searches for a good candidate network. In the second stage, the resulting network is re-initialized and re-trained. One-stage NAS tries to search the architecture and optimize the network weights simultaneously, so that a ready-to-run network is available after searching. We use a simple one-stage NAS algorithm (Guo et al., 2020) as a baseline in Section 4.1.

Conclusion
We have proposed a memory-efficient differentiable architecture search (DARTS) method for sequence-to-sequence tasks. In particular, we have first devised a multi-split reversible network whose intermediate layer outputs can be reconstructed from top to bottom from the last layer's output. We have then combined this reversible network with DARTS and developed a backpropagation-with-reconstruction algorithm to significantly relieve the memory burden during the gradient-based architecture search process. We have validated the best searched architecture on three translation tasks. Our method consistently outperforms standard Transformers. We achieve the same BLEU score as the original big Transformer with 69% fewer parameters. At a large model size, we surpass Evolved Transformers with a search cost lower by an order of magnitude. Our method is generic enough to handle other architectures, and we plan to explore more tasks in the future.