Finetuning Pretrained Transformers into RNNs

Transformers have outperformed recurrent neural networks (RNNs) in natural language generation. But this comes with a significant computational cost, as the attention mechanism’s complexity scales quadratically with sequence length. Efficient transformer variants have received increasing interest in recent works. Among them, a linear-complexity recurrent variant has proven well suited for autoregressive generation. It approximates the softmax attention with randomized or heuristic feature maps, but can be difficult to train and may yield suboptimal accuracy. This work aims to convert a pretrained transformer into its efficient recurrent counterpart, improving efficiency while maintaining accuracy. Specifically, we propose a swap-then-finetune procedure: in an off-the-shelf pretrained transformer, we replace the softmax attention with its linear-complexity recurrent alternative and then finetune. With a learned feature map, our approach provides an improved tradeoff between efficiency and accuracy over the standard transformer and other recurrent variants. We also show that the finetuning process has lower training cost relative to training these recurrent variants from scratch. As many models for natural language tasks are increasingly dependent on large-scale pretrained transformers, this work presents a viable approach to improving inference efficiency without repeating the expensive pretraining process.


Introduction
Transformer models (Vaswani et al., 2017) have advanced the state of the art beyond recurrent neural network models (e.g., LSTMs, Hochreiter and Schmidhuber, 1997; GRUs, Cho et al., 2014) across a wide range of natural language processing tasks. In particular, the transformer architecture has been widely used in autoregressive generation such as language modeling (Baevski and Auli, 2019) and machine translation (Vaswani et al., 2017). The transformer makes crucial use of interactions between feature vectors over the input sequence through the attention mechanism (Bahdanau et al., 2015). However, this comes with significant computation and memory footprints during text generation. Since the output words are incrementally predicted conditioned on the prefix, generation steps cannot be parallelized over time steps and require quadratic time complexity in sequence length. The memory consumption in every generation step also grows linearly as the sequence becomes longer. This bottleneck for long sequence generation limits the usage of large-scale pretrained generation models, such as GPT-3 (Brown et al., 2020), Image Transformer (Parmar et al., 2018), and DALL-E (Ramesh et al., 2021).
Recent works aim at reducing the overhead of autoregressive transformers (Child et al., 2019; Kitaev et al., 2020; Beltagy et al., 2020, inter alia). Among them are recurrent alternatives that approximate the standard softmax attention (Katharopoulos et al., 2020; Choromanski et al., 2021; Schlag et al., 2021). Similar to recurrent neural networks (RNNs), those models represent the context by a recurrent state with a fixed size, thereby achieving linear time and constant memory complexity in generation sequence length. When the recurrent state size is smaller than the sequence length, these variants provide substantial speed and memory advantages over the transformer. A small state size, however, tends to deteriorate the generation quality, leading to a tradeoff between efficiency and accuracy.
This work improves the balance between efficiency and accuracy by a conversion approach: instead of training a recurrent alternative from scratch, we develop a method to convert a pretrained transformer into an efficient RNN that speeds up generation and reduces memory footprints. Our conversion proceeds with a swap-then-finetune process. Specifically, we change the exponential similarity function in the attention mechanism to a single-layer MLP feature map. We then finetune the MLP parameters and the other network parameters. Our experiments in language modeling and machine translation show that the conversion can compress the context into a much smaller recurrent state than the sequence length (e.g., 1/16 of the sequence length in WikiText-103 language modeling) while retaining high accuracy. In addition, this conversion requires much less GPU time than training randomly initialized models from scratch.
State-of-the-art models in many natural language tasks are increasingly dependent on numerous large-scale pretrained transformer models (e.g., GPT-2, Radford et al., 2019; BERT, Devlin et al., 2019; RoBERTa, Liu et al., 2019; T5, Raffel et al., 2020; BART, Lewis et al., 2020; DeBERTa, He et al., 2021). Converting a large off-the-shelf transformer to a lightweight inference model without repeating the whole training procedure is particularly useful in many downstream applications. Our work focuses on text generation and presents a viable approach towards efficient inference with high accuracy.

Convert a Transformer into an RNN
The transformer architecture consists of multihead attention, feedforward layers, and layer normalization modules (Vaswani et al., 2017). When a transformer is trained for a sequence generation task with teacher forcing (Williams and Zipser, 1989), the attention can be parallelized over positions because the target sequence is fully available. During generation, on the other hand, the output is incrementally constructed. As a result, the attention becomes an inference bottleneck for long sequences. Here we present a method to eliminate this bottleneck by converting a pretrained transformer into an efficient RNN of linear time and constant space complexity.

Multihead Attention
The attention module takes as input sequences of source and target vectors. The source vectors are used to produce key and value features, while the target vectors are mapped to query vectors. More formally, denote by {x_i^tgt} for i = 1, …, N and {x_j^src} for j = 1, …, M the target and source vectors, where x_i^tgt, x_j^src ∈ R^h and h is the model dimensionality. We assume r attention heads of d dimensions (h = dr). For each head, the input vectors are first mapped to d dimensional query, key, and value features by learned affine transformations with W_* ∈ R^{d×h} and b_* ∈ R^d:

q_i = W_q x_i^tgt + b_q,  k_j = W_k x_j^src + b_k,  v_j = W_v x_j^src + b_v.   (1)

The similarities of each query vector q_i with all M key vectors are computed and normalized to produce attention coefficients, which are then used to output a weighted average of the value vectors (Vaswani et al., 2017):

x_i^out = Σ_{j=1}^{M} [ sim(q_i, k_j) / Σ_{j'=1}^{M} sim(q_i, k_{j'}) ] v_j,   (2a)

where

sim(x, y) = exp(x ⋅ y / √d).   (2b)

Multihead attention runs this procedure for each of the r heads in parallel and concatenates the r output vectors to get the final h dimensional vector.

Generation Speed Overhead: Fig. 1 depicts the attention computation during incremental generation. Every generation step attends to all previous positions, so the attention takes quadratic time in sequence length over the whole sequence and its memory grows linearly per step.
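To make the computation concrete, the following is a minimal NumPy sketch of a single attention head following Eqs. 1–2. It is an illustrative reimplementation rather than the code used in this paper; the toy dimensions and variable names are ours.

```python
import numpy as np

def softmax_attention_head(x_tgt, x_src, W_q, b_q, W_k, b_k, W_v, b_v):
    """One attention head following Eqs. 1-2: affine projections, exponential
    similarities, normalization, and a weighted average of the values."""
    d = W_q.shape[0]
    q = x_tgt @ W_q.T + b_q                    # (N, d) queries, Eq. 1
    k = x_src @ W_k.T + b_k                    # (M, d) keys
    v = x_src @ W_v.T + b_v                    # (M, d) values
    sim = np.exp(q @ k.T / np.sqrt(d))         # sim(q_i, k_j), Eq. 2b
    weights = sim / sim.sum(axis=-1, keepdims=True)   # attention coefficients
    return weights @ v                          # (N, d) outputs, Eq. 2a

# Toy sizes: h = 4 model dimensions, d = 2 per head; multihead attention would
# run r such heads in parallel and concatenate their outputs.
rng = np.random.default_rng(0)
N, M, h, d = 3, 5, 4, 2
x_tgt, x_src = rng.normal(size=(N, h)), rng.normal(size=(M, h))
W_q, W_k, W_v = rng.normal(size=(3, d, h))
b_q = b_k = b_v = np.zeros(d)
print(softmax_attention_head(x_tgt, x_src, W_q, b_q, W_k, b_k, W_v, b_v).shape)  # (3, 2)
```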

Converting Transformers to RNNs
To address this generation bottleneck of quadratic time and linear space, we propose Transformer-to-RNN (T2R), a method to convert a pretrained transformer to an RNN inference model of linear time and constant memory complexity in sequence length (Fig. 1). T2R follows a swap-then-finetune procedure that modifies the attention computation of a pretrained transformer, and finetunes the model with the task objective. We first replace the dot-then-exponential similarity function in a pretrained transformer (Eq. 2b) by

sim(x, y) = φ(x) ⋅ φ(y),   (3a)

where

φ(x) = relu(W_φ x + b_φ).   (3b)

Here W_φ ∈ R^{k×d} and b_φ ∈ R^k are learned parameters of a single-layer MLP. They map a d dimensional vector to a k dimensional kernel feature space. The relu activation (Fukushima, 1980) ensures that the features are all nonnegative. Different MLP parameters are used for different attention heads, and thus we add a total of rk(d + 1) learnable parameters per layer (less than 0.2% parameter increase in our language model, §3). We then finetune this modified network, including the MLP parameters, with the original task objective.
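The swapped similarity function can be sketched in a few lines of NumPy. This is only an illustration of Eqs. 3a–3b (the finetuning step is not shown), and the function and parameter names are ours; in a full model, a small epsilon would typically guard the normalization against an all-zero similarity row.

```python
import numpy as np

def phi(x, W_phi, b_phi):
    """T2R feature map (Eq. 3b): a single-layer MLP with relu, mapping a
    d-dimensional head vector to a k-dimensional nonnegative feature."""
    return np.maximum(x @ W_phi.T + b_phi, 0.0)

def t2r_attention(q, k, v, W_phi, b_phi):
    """Attention with the swapped similarity (Eq. 3a): sim(q_i, k_j) becomes
    phi(q_i) . phi(k_j); the normalization of Eq. 2a is unchanged."""
    sim = phi(q, W_phi, b_phi) @ phi(k, W_phi, b_phi).T   # (N, M) nonnegative scores
    weights = sim / sim.sum(axis=-1, keepdims=True)
    return weights @ v
```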
During inference generation, we reformulate the attention computation (Eq. 2a) as

x_i^out = [ φ(q_i)^T Σ_{j=1}^{M} φ(k_j) v_j^T ] / [ φ(q_i)^T Σ_{j=1}^{M} φ(k_j) ]   (4)

by the associativity of matrix multiplication. This formulation lends itself to recurrent computation. For the query vector at each position i, define states S_i and z_i:

S_i = Σ_{j=1}^{M} φ(k_j) v_j^T,   (5a)
z_i = Σ_{j=1}^{M} φ(k_j).   (5b)

In causal attention where each query only attends to its prefix to predict the next word (M = i), S_i and z_i define recurrent states (Katharopoulos et al., 2020):

S_i = S_{i-1} + φ(k_i) v_i^T,   (6a)
z_i = z_{i-1} + φ(k_i).   (6b)

Here S_i ∈ R^{k×d} and z_i ∈ R^k. When M does not vary across all queries, such as self-attention and encoder-to-decoder (cross) attention in a sequence-to-sequence model, S and z define fixed states. Given the two recurrent states at position i, we can compute the output vector:

x_i^out = (φ(q_i)^T S_i) / (φ(q_i)^T z_i).   (7)

This avoids quadratic computation with respect to the input sequence length. We also speed up inference by merging the MLP feature map with the affine feature maps that produce query and key vectors:

φ(q_i) = relu(W'_q x_i^tgt + b'_q),   (8a)
φ(k_j) = relu(W'_k x_j^src + b'_k),   (8b)

where

W'_q = W_φ W_q,  b'_q = W_φ b_q + b_φ,   (8c)
W'_k = W_φ W_k,  b'_k = W_φ b_k + b_φ.   (8d)

After the model is trained, Eqs. 8c–8d are computed once before generation; the intermediate features of q_i and k_j are never computed during inference.
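The recurrent reformulation (Eqs. 6–7) can be sketched as follows. This is an illustrative NumPy version, not the released implementation; the toy inputs are kept positive so the normalizers never vanish, and the sanity check confirms the recurrence matches the quadratic formulation with a causal mask.

```python
import numpy as np

def causal_t2r_generation(q, k, v, W_phi, b_phi):
    """Causal T2R attention as an RNN (Eqs. 6-7): the prefix is summarized by
    S (k x d) and z (k), each updated in constant time per position."""
    phi = lambda x: np.maximum(x @ W_phi.T + b_phi, 0.0)
    N, d = q.shape
    k_dim = W_phi.shape[0]
    S, z = np.zeros((k_dim, d)), np.zeros(k_dim)
    outputs = []
    for i in range(N):                                     # one step per generated token
        S = S + np.outer(phi(k[i]), v[i])                  # Eq. 6a
        z = z + phi(k[i])                                  # Eq. 6b
        outputs.append((phi(q[i]) @ S) / (phi(q[i]) @ z))  # Eq. 7
    return np.stack(outputs)

# Sanity check against the quadratic formulation with a causal mask.
rng = np.random.default_rng(1)
N, d, k_dim = 6, 4, 3
q, k, v = rng.uniform(0.1, 1.0, size=(3, N, d))            # positive toy inputs
W_phi, b_phi = rng.uniform(0.1, 1.0, size=(k_dim, d)), np.zeros(k_dim)
phi = lambda x: np.maximum(x @ W_phi.T + b_phi, 0.0)
sim = np.tril(phi(q) @ phi(k).T)                            # each query attends to its prefix
quadratic = (sim / sim.sum(axis=-1, keepdims=True)) @ v
assert np.allclose(causal_t2r_generation(q, k, v, W_phi, b_phi), quadratic)
```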

Generation Speed Overhead
The time complexity of each step in a T2R model is shown in Fig. 1. Similar to the transformer, it proceeds over two stages: a feature mapping stage that computes φ(q_i) and φ(k_i) and updates the recurrent states S and z (Eqs. 6a–6b), followed by an attention stage that computes the output vector from those states (Eq. 7). Both stages take constant time per step with respect to the sequence length.

Generation Memory Overhead

T2R only needs to store the RNN state, and thus its space complexity is O(hk), constant in sequence length. This implies a reduction in memory footprint when k << M, compared to the transformer's O(Mh).
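As a back-of-the-envelope illustration (our own arithmetic, not a measurement reported in this paper), the snippet below compares the per-layer attention-state size of T2R with the transformer's growing key/value cache, using the WikiText-103 configuration from §3.

```python
# Per-layer attention-state sizes in float32 bytes, using this paper's
# WikiText-103 configuration (h = 1024, r = 8 heads, d = 128, k = 32).
h, r, d, k, bytes_per_float = 1024, 8, 128, 32, 4

t2r_state = r * (k * d + k) * bytes_per_float      # S (k x d) and z (k) per head; constant in M
for M in (128, 512, 2048):
    kv_cache = r * 2 * M * d * bytes_per_float     # cached keys and values per head grow with M
    print(f"M={M}: T2R state {t2r_state / 2**10:.0f} KiB vs. transformer cache {kv_cache / 2**10:.0f} KiB")
```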

Autoregressive Linear Transformers
In principle, any kernel function can be used as the similarity function in Eq. 2a (Tsai et al., 2019). Previous work proposed several untrainable feature map functions φ and developed autoregressive transformer variants with linear time and constant space complexity in sequence length. Katharopoulos et al. (2020) proposed φ(x) = elu(x) + 1 and applied it to image generation. In language modeling and machine translation tasks, RFA and Performer (Choromanski et al., 2021) used random features that approximate the softmax attention via Monte Carlo sampling (Rahimi and Recht, 2007; Yu et al., 2016).
While those models follow similar computation steps to T2R, there are several differences in generation efficiency. Since the elu feature map (Katharopoulos et al., 2020) preserves input dimensions, the feature size is always the same as the head dimensions (k = d). This means that the speedup and memory savings from using a small feature size are restricted by design. In our experiments (§3.3), our T2R models gain further efficiency by using a feature size that is even smaller than the head dimensions (k = 32 and d = 128 for language modeling). RFA and Performer scale query and key vectors by their norms before the random approximation to bound the approximation error. This means that the feature mapping stage needs additional steps of producing intermediate q and k and scaling them. T2R models suppress these intermediate steps and speed up generation further (§3.3).
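To make the contrast concrete, the sketch below places the untrainable elu map next to T2R's learned map. The MLP parameters here are zero-filled placeholders purely to show shapes; in practice they are learned during finetuning.

```python
import numpy as np

def phi_elu(x):
    """Untrainable map of Katharopoulos et al. (2020): elu(x) + 1.
    The output size necessarily equals the head size d."""
    return np.where(x > 0, x + 1.0, np.exp(x))              # elu(x) + 1 elementwise

def phi_mlp(x, W_phi, b_phi):
    """T2R's learned map: the feature size k is a free design choice (k < d)."""
    return np.maximum(x @ W_phi.T + b_phi, 0.0)

d, k = 128, 32                                              # LM head size vs. T2R feature size
x = np.random.default_rng(0).normal(size=d)
W_phi, b_phi = np.zeros((k, d)), np.zeros(k)                # placeholders; learned during finetuning
print(phi_elu(x).shape, phi_mlp(x, W_phi, b_phi).shape)     # (128,) vs. (32,)
```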

Experiments
We present extensive experiments on standard benchmarks for language modeling and machine translation. Our results show that T2R achieves efficient autoregressive generation while retaining high accuracy.

Baselines and Comparison
We compare performance with previous transformer models for autoregressive generation with linear time and constant space complexity in input sequence length (Katharopoulos et al., 2020; Choromanski et al., 2021). As discussed in §2.3, those prior methods correspond to two different untrainable feature maps φ. We experiment with two types of feature maps for comparisons: ELU (φ(x) = elu(x) + 1, Katharopoulos et al., 2020) and RFA (random feature approximation with softmax temperature reparameterization). Each feature map is evaluated in two settings: random initialization and pretrain. Random initialization is our reimplementation of the experiments in Katharopoulos et al. (2020) and the RFA work. The pretrain setting follows the same protocol as T2R except that we use different feature maps φ than our proposed one-layer MLP with relu activation. Positive orthogonal random features (Performer, Choromanski et al., 2021) provide a similar random approximation to RFA and were evaluated in the biology domain, but we found that this method caused training divergence in the language modeling task.

Setup and Implementations
We apply our method to causal attention in language models and to both cross and causal attention in machine translation. For language modeling, we use a 32-dimensional feature map function. We do not modify the encoder in machine translation, as its generation speed overhead is much less significant than that of the decoder (Kasai et al., 2021). Our exploration showed that reducing the feature size of causal attention tends to have less impact on the final translation accuracy than reducing that of cross attention; we therefore use feature sizes of 32 and 4 for cross and causal attention, respectively. This observation is consistent with previous work demonstrating that causal attention can be more drastically simplified than cross attention in transformer machine translation models (You et al., 2020; Tay et al., 2020a).

Language Modeling
We use the WikiText-103 benchmark, which consists of 103M tokens sampled from English Wikipedia (Merity et al., 2017). We choose similar hyperparameters to prior work (Baevski and Auli, 2019; Fan et al., 2020): 32 layers, 8 heads, 128 head dimensions, 1024 model dimensions, 4096 fully connected dimensions, and dropout (Srivastava et al., 2014) and layer dropout rates of 0.2. The word embedding and softmax matrices are tied (Press and Wolf, 2017). We partition the training data into non-overlapping blocks of 512 contiguous tokens, ignoring document boundaries, and train the model to predict each token from left to right (Baevski and Auli, 2019). Validation and test perplexity are measured by predicting the last 256 words out of the input of 512 consecutive words to avoid evaluating tokens in the beginning with limited context (early token curse, Press et al., 2021); a sketch of this protocol is given below. We generally follow the optimization method from Baevski and Auli (2019), but some hyperparameters, such as the learning rate for the T2R finetuning, are adjusted for better convergence than randomly initialized training. See Appendix A.1 for more details and a complete list of hyperparameters.
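The following sketch shows one way to implement this data layout and evaluation protocol as described above; it is our reading of the setup, and details such as how the very first tokens of the corpus are scored are assumptions, not specified by the paper.

```python
import numpy as np

def training_blocks(token_ids, block_size=512):
    """Training setup: partition the corpus into non-overlapping blocks of 512
    contiguous tokens, ignoring document boundaries."""
    n = (len(token_ids) // block_size) * block_size
    return np.asarray(token_ids[:n]).reshape(-1, block_size)

def eval_windows(num_tokens, context=512, scored=256):
    """Evaluation setup: slide a 512-token window with stride 256 and score only
    the last 256 positions of each window, so every scored token conditions on
    at least 256 preceding tokens (mitigating the early token curse)."""
    windows = []
    for start in range(0, num_tokens - context + 1, scored):
        windows.append((start, start + context, range(context - scored, context)))
    return windows  # (window start, window end, positions scored within the window)
```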

Machine Translation
We experiment with three translation benchmarks covering EN-DE, EN-FR, and EN-ZH. Following prior work, we use an increased batch size of approximately 460K tokens by accumulating gradients without updating parameters. Each randomly initialized model is trained for 30K (60K for the large EN-FR dataset) steps using Adam with a learning rate of 5⋅10^−4 and β = (0.9, 0.98) (Kingma and Ba, 2015). We observed that convergence of the T2R conversion can be achieved with 20K (40K for EN-FR) steps and a reduced learning rate of 2⋅10^−4. We average the checkpoints from the last five epochs to obtain the final model (Vaswani et al., 2017); a sketch of this averaging is shown below. In inference, we apply beam search decoding with beam size 5 and length penalty 0.6. Consistent with previous practice, we use tokenized BLEU (Papineni et al., 2002) for evaluation. Further details are described in Appendix A.1.
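Checkpoint averaging can be sketched as below. This is an illustrative PyTorch snippet, not the exact script we used; the "model" key assumes the usual fairseq checkpoint layout.

```python
import torch

def average_checkpoints(paths):
    """Average parameters over saved checkpoints (here, those of the last five
    epochs) to obtain the final model, as in Vaswani et al. (2017)."""
    avg = None
    for path in paths:
        # The "model" key follows the usual fairseq checkpoint layout (an assumption here).
        state = torch.load(path, map_location="cpu")["model"]
        if avg is None:
            avg = {name: p.float().clone() for name, p in state.items()}
        else:
            for name, p in state.items():
                avg[name] += p.float()
    return {name: p / len(paths) for name, p in avg.items()}
```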

Results
Language Modeling Seen in Table 1 are language modeling results in perplexity. We observe that T2R with the learnable MLP feature map outperforms the other two linear transformer models by more than 2.0 perplexity points in the pretrain setting. Unlike the other linear transformer models, T2R greatly benefits from pretraining (T2R + Pretrain: 19.6 vs. T2R + Random Init.: 20.8 test perplexity points). This difference suggests that using a trainable feature map is crucial in our swap-then-finetune approach. We attribute this advantage to the fact that the MLP feature map is able to learn attention patterns that are similar to those of the pretrained transformer, as evidenced in §4.2. Notice also that the T2R conversion is ∼5x faster (measured in GPU hours) than training a model from scratch. These results illustrate that a lightweight model can be obtained without repeating the expensive training of large-scale pretrained language models such as GPT-2 and GPT-3 (Radford et al., 2019; Brown et al., 2020). There remains a gap of 1.1 perplexity points between the T2R and pretrained transformer models (19.6 vs. 18.5). Nonetheless, the gap can be closed when every fourth layer from the top is kept as the original transformer layer and the model is finetuned in the same way (T2R 75%). This suggests that keeping a small fraction of the quadratic attention layers can provide an effective middle ground between efficiency and accuracy.

Machine Translation Seen in Table 2 are machine translation results in BLEU from various configurations. Departing from the language modeling experiments, the T2R model underperforms the other two linear transformer models when initialized randomly. However, consistent with the language modeling results, the T2R model substantially benefits from pretraining (e.g., 28.7 vs. 27.5 BLEU points in EN-DE). As a result, the T2R model achieves similar BLEU scores to the original transformer across all language pairs. ELU trained from the pretrained transformer yields comparable performance to T2R, but its feature size is much larger (64 vs. 32 and 64 vs. 4 in cross and causal attention), thus leading to increased overhead, as shown later. Note that the T2R finetuning time is only moderately smaller than that of randomly initialized training here, but further speedup in conversion can potentially be achieved with more extensive hyperparameter tuning.

(Table 2 caption note: Pretrain indicates initialization with a trained transformer-large model; *: diverged even when running with multiple random seeds and smaller learning rates.)

Speedup and Memory Savings in Generation
We run a conditional generation experiment to compare sequence-to-sequence decoding speed (Fig. 2).
Here we assume the input and output sequences are of the same length. The compared models are of the same size as those in Table 2. All models are tested using greedy decoding with the same batch size of 16 on a TPU v2 accelerator (JAX, https://opensource.google/projects/jax). We see that the linear transformer models can indeed generate an almost constant number of tokens per second regardless of the sequence length and outpace the transformer model dramatically as the sequence becomes longer. The T2R model achieves a 15%+ speedup over ELU and RFA due to its smaller feature sizes and faster feature mapping, respectively; this confirms our analysis of T2R's speed advantage over them (§2.3). Fig. 3 plots memory consumption from the attention computation during decoding for machine translation. Since the T2R models compress keys and values into a k×d matrix S and a k dimensional vector z (§2.2), the required memory at each decoding step is constant over varying sequence lengths. It is also roughly proportional to the feature size k. The MLP feature map in the T2R model allows for smaller feature dimensions than the ELU feature map, whose size is tied to the head dimensions, resulting in a 70% memory reduction. The attention computation in the standard transformer, on the other hand, consumes memory linearly in sequence length at each decoding step because all previous key and value vectors have to be stored. We also found similar speedup and memory savings in unconditional generation with the T2R language model (∼4x speedup in generating 512 consecutive tokens with batch size 16 and beam size 1 over the standard transformer).

Analysis and Ablations
We presented T2R, a method to convert a pretrained transformer into an efficient RNN. In this section, we analyze our conversion approach by examining the impact of the feature size and induced attention weight distributions. Our analysis shows that T2R implicitly learns attention distributions similar to the original transformer.

Feature Size and Pretraining
We saw that T2R benefits substantially from transformer pretraining. Fig. 4 shows the relation between the validation perplexity from WikiText-103 and the feature sizes. We see that as the feature size (RNN state size) becomes smaller, pretraining becomes particularly important to achieve low perplexity. Transformer pretraining achieves a Pareto improvement over random initialization in the tradeoff between efficiency (small feature size) and accuracy (low perplexity).

(Figure 5 caption: Average Euclidean distance of T2R models from the transformer attention weights with varying feature sizes. The distances are computed on the WikiText-103 validation data for predicting a word given the preceding 512 words. All models are initialized with a pretrained transformer model.)

Attention Distribution
Our T2R conversion runs a swap-then-finetune procedure; we modify the similarity function and finetune the model and the MLP parameters with the task objective. This means that the T2R model is not explicitly trained to mimic the original attention distributions, and there is no guarantee that the MLP feature map approximates the exponential similarity function, unlike previous approximation approaches (e.g., Choromanski et al., 2021). Here, we analyze the properties of the attention weight distributions that are induced by finetuning. We use the validation data from WikiText-103 and run language models to predict the next word given the input of 512 contiguous words. We compute the attention weight distribution over the 512 words for each attention head in the model layers.
Distance from Softmax Fig. 5 compares the attention distributions from T2R in various configurations. T2R MLP frozen indicates a model that is finetuned with the MLP parameters frozen. Euclidean distances in attention distributions between the original transformer and each model are averaged across validation samples, model layers, and attention heads. Comparing T2R before finetuning and the full T2R model, we see that the finetuning process induces much more similar attention distributions, and the distance diminishes as the feature size increases (and the perplexity approaches that of the original transformer, Fig. 4). We also observed that when the MLP parameters are not trained (T2R MLP frozen), the distance from the original attention distributions increases. These results suggest that finetuning of the whole network in T2R implicitly develops similar attention distributions to the original transformer even though the training supervision comes solely from language modeling.
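The distance statistic can be sketched as below. This reflects our reading of the description above (one normalized distribution over the 512 context words per sample, layer, and head); the array layout is an assumption for illustration.

```python
import numpy as np

def mean_attention_distance(attn_model, attn_transformer):
    """Both arrays hold one attention distribution over the 512 context words
    per (validation sample, layer, head): shape (samples, layers, heads, 512).
    Returns the Euclidean distance per distribution, averaged over samples,
    layers, and heads."""
    per_dist = np.linalg.norm(attn_model - attn_transformer, axis=-1)
    return per_dist.mean()
```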

Further Related Work
In addition to the work we already discussed, we highlight here related methods from prior work that make transformer models efficient.

Knowledge Distillation
Knowledge distillation (Hinton et al., 2015) is closely related to our T2R conversion method and uses a similar pipeline: a teacher model with large capacity is first trained and is typically used to generate silver training data for a new lightweight inference model. It has been successfully applied to machine translation (e.g., Kim and Rush, 2016; Gu et al., 2018) and contextual word representations (e.g., Sanh et al., 2019; Wang et al., 2020c,b) to make inference efficient. In particular, several prior works distill a transformer translation model into a recurrent neural network for efficient inference (e.g., Senellart et al., 2018). We share the same motivation toward fast generation with light memory, but our approach differs in two ways: the original training data are used for finetuning an RNN model, and its model parameters are initialized with the "teacher" transformer. These differences have several crucial implications in practice. Firstly, our method does not use the computationally expensive teacher model to generate new training data. While this procedure is a one-time computational cost, it becomes expensive as the teacher model size and training data increase. In addition to this computational cost, it is challenging to apply knowledge distillation to an autoregressive language model such as GPT-3 (Brown et al., 2020), since we need to sample diverse teacher outputs without explicit conditioning, unlike in machine translation. Lastly, as shown in our experiments (§3.3), since the pretrained parameters can be directly used, conversion requires fewer GPU hours than training a brand new lightweight model from scratch.

Efficient Transformers
Prior work suggested many other strategies to improve efficiency in transformers, such as weight sharing and factorization (Dehghani et al., 2019; Lan et al., 2020), weight and layer pruning (Michel et al., 2019; Fan et al., 2020), quantization (Zafrir et al., 2019; Shen et al., 2020), training on short sequences (Press et al., 2021), gradually reducing the input sequence length from layer to layer (Dai et al., 2020), clustering query vectors (Vyas et al., 2020), and modifying the combination of attention and feedforward sublayers (Press et al., 2020; Mandava et al., 2020). Some of these methods present orthogonal design choices and can be integrated into our T2R model to gain further efficiency. For a more comprehensive survey of efficient transformers, see Tay et al. (2020c). Below we describe several prior works along two major strategies that reduce time and memory overhead in the attention mechanism: compressing the attention context and sparsifying the attention patterns.
Attention Context Compression This strand of methods compresses the context that is attended to, thereby reducing the time and memory overhead in the attention. The RNN-style models that we convert pretrained transformers into (Katharopoulos et al., 2020; Choromanski et al., 2021) compress the context into a recurrent state. Other approaches include low-rank approximation of the attention computation (Wang et al., 2020a; Tay et al., 2020a) and adding a memory module that can access multiple tokens at once (Liu et al., 2018; Dai et al., 2019; Lee et al., 2019; Ainslie et al., 2020; Rae et al., 2020; Beltagy et al., 2020; Zaheer et al., 2020).
Sparse Attention Patterns Another approach to reducing the time and memory overhead from the attention computation is to limit the tokens that are attended to by sparsifying the attention patterns. These patterns can be set in advance or learned during training (Tay et al., 2020c). For example, prior works introduced fixed patterns of blockwise attention (Qiu et al.) and strided attention (Child et al., 2019; Beltagy et al., 2020; Zaheer et al., 2020). Other previous works, on the other hand, presented methods to learn attention patterns from data (Sukhbaatar et al., 2019; Roy et al., 2020; Tay et al., 2020b).
It should be noted that significant modifications are necessary to apply many of these methods to autoregressive generation tasks such as language modeling and machine translation, and their empirical evaluation in these generation settings has yet to be conducted. This work presents extensive empirical evaluation in autoregressive generation settings.

Conclusion and Future Work
We present T2R, a method that converts a pretrained transformer to a recurrent neural network that reduces the time and memory cost of autoregressive generation. Our experiments in language modeling and machine translation demonstrated that our model produces an improved tradeoff between efficiency and accuracy over randomly initialized training and previous models with lightweight attention. Our work provides further support for the claim that large-scale pretrained models can be compressed into efficient inference models that facilitate downstream applications.

A.1.1 Language Modeling
We generally follow the optimization method from Baevski and Auli (2019). For optimizing a model from random initialization, the learning rate is linearly warmed up from 10^−7 to 1 for the initial 16K steps and then annealed using a cosine learning rate schedule with cycles (Loshchilov and Hutter, 2017). Each period lasts for twice as many updates as the previous cycle, and we lower the maximum and minimum learning rates by 25% compared to the previous cycle. The initial minimum and maximum learning rates are 10^−5 and 1, respectively (Baevski and Auli, 2019). We train the model with a batch size of about 74K tokens for a total of 286K steps (Baevski and Auli, 2019). When we convert a pretrained transformer to an RNN model by finetuning, we found that we could speed up training by reducing the warm-up steps, total update steps, minimum and maximum learning rates, and batch size to 8K steps, 142K steps, 5⋅10^−6, 0.5, and 25K tokens, respectively, without loss in validation perplexity. A sketch of this schedule is given below.
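The schedule described above can be sketched as follows. This is an illustrative implementation under our reading of the text; in particular, the length of the first cosine cycle (`first_period`) is not stated above and is a placeholder value here.

```python
import math

def lr_schedule(step, warmup=16_000, first_period=16_000, lr_start=1e-7,
                lr_max=1.0, lr_min=1e-5, shrink=0.75, period_mult=2):
    """Linear warmup from lr_start to lr_max over `warmup` steps, then cosine
    annealing with cycles: each cycle runs from the current maximum down to the
    current minimum, lasts twice as many updates as the previous one, and both
    bounds are lowered by 25% per cycle. `first_period` is a placeholder."""
    if step < warmup:
        return lr_start + (lr_max - lr_start) * step / warmup
    t = step - warmup
    period, hi, lo = first_period, lr_max, lr_min
    while t >= period:                      # advance to the current cycle
        t -= period
        period *= period_mult
        hi *= shrink
        lo *= shrink
    return lo + 0.5 * (hi - lo) * (1.0 + math.cos(math.pi * t / period))
```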
Randomly Initialized Training We generally follow the hyperparameters chosen in Baevski and Auli (2019) and Fan et al. (2020). Specifically, we list the hyperparameters in Table 3 for easy replication. All other hyperparameter options are left as default values in fairseq.

T2R Finetuning Listed in Table 4 are the hyperparameters for finetuning a pretrained transformer to RNN models. The learning rates, the max number of updates, and the learning period length are all reduced.

A.1.2 Machine Translation
We experiment with three translation benchmarks covering EN-DE, EN-FR, and EN-ZH. These datasets are all encoded into subwords by BPE (Sennrich et al., 2016). We run joint BPE on all language pairs except EN-ZH. We use the hyperparameters of the large-sized transformer (Vaswani et al., 2017): 6 layers, 16 attention heads, 1024 model dimensions, and 4096 hidden dimensions for both the encoder and decoder. Similar to the language models, all subword embedding and softmax matrices are tied. We apply dropout with 0.3, weight decay with 0.01, and label smoothing with ε = 0.1. Following prior work, we use an increased batch size of approximately 460K tokens by accumulating gradients without updating parameters.
Randomly Initialized Training We generally follow the hyperparameters chosen in Vaswani et al. (2017). Specifically, we list the hyperparameters in Table 5 for easy replication. All other hyperparameter options are left as default values in fairseq. The parameters from the last five epochs were averaged to obtain the final model.

T2R Finetuning Listed in Table 6 are the hyperparameters for finetuning a pretrained transformer to RNN models. The learning rate and the max number of updates are reduced.