Muppet: Massive Multi-task Representations with Pre-Finetuning

We propose pre-finetuning, an additional large-scale learning stage between language model pre-training and fine-tuning. Pre-finetuning is massively multi-task learning (around 50 datasets, over 4.8 million total labeled examples), and is designed to encourage learning of representations that generalize better to many different tasks. We show that pre-finetuning consistently improves performance for pretrained discriminators (e.g. RoBERTa) and generation models (e.g. BART) on a wide range of tasks (sentence prediction, commonsense reasoning, MRC, etc.), while also significantly improving sample efficiency during fine-tuning. We also show that large-scale multi-tasking is crucial; pre-finetuning can hurt performance when few tasks are used up until a critical point (usually above 15) after which performance improves linearly in the number of tasks.


Introduction
The recent success of language model pre-training (Devlin et al., 2018; Liu et al., 2019b; Lewis et al., 2019; Raffel et al., 2019; Radford et al., 2019) is remarkable, at least in part, due to the exclusive use of self-supervision, without any manually labeled data. For many tasks, however, we already have training examples for related problems, which we should be able to leverage. Recent work has shown gains from fine-tuning schemes that are multi-task (Raffel et al., 2019; Khashabi et al., 2020) and multi-stage (Liu et al., 2019a), but it can be difficult to know which intermediate tasks will best transfer (Raffel et al., 2019). In this paper, we show that multi-task supervised tuning, if done at a sufficiently large scale with many different tasks, can be an effective second stage of task-agnostic pre-training, removing the need to pre-select the best intermediate tasks.
More specifically, in addition to the standard pre-training/fine-tuning methodology of learning language tasks, we introduce a new intermediate stage, pre-finetuning. Pre-finetuning involves a massive multi-task learning step (4.8 million total training examples) performed on around 50 classification, summarization, question answering, and common sense reasoning tasks. We believe we are the first to investigate multi-task learning at this scale in terms of both number and types of tasks. We show, in particular, that standard multi-tasking schemes can be unstable and often fail to learn high quality representations. However, we introduce a new training scheme which uses loss scaling and task-heterogeneous batches so that gradient steps are more evenly balanced across multiple different competing tasks, greatly improving training stability and overall performance. We call our prefinetuned models MUPPET; Massive Multi-task RePresentation with PrE-fineTuning.
Through extensive experiments, we show that incorporating pre-finetuning into RoBERTa (Liu et al., 2019b) and BART (Lewis et al., 2019) models yields consistent improvements, including new state-of-the-art performance for RTE and HellaSWAG, without having to specify intermediate transfer tasks. These gains are particularly strong in the low-resource regime, where there is relatively little labeled data for fine-tuning. We also study why pre-finetuning outperforms previous multi-tasking schemes. We first compare different optimization techniques to stabilize training, and find it important to use task-heterogeneous batches with task-rebalancing loss scaling. We also show that scale is crucial for effective multi-task learning. We empirically see a critical point in terms of the number of tasks (usually over 15); having fewer tasks degrades representations, while having more seems to improve performance linearly as far as we were able to scale.
To summarize, our contributions include:
• We show that we can further improve pre-trained representations with an additional stage we call pre-finetuning, which utilizes massively multi-task learning. We show that standard pre-trained representations, when further refined with pre-finetuning, consistently improve performance on downstream tasks.
• We introduce a new multi-task training scheme for effective learning at scale, which uses loss scaling and task-heterogeneous batches.
• We explore the effects of scale on multi-task learning and show the existence of critical points in multi-task training, beyond which increasing the number of tasks improves generalizable representations.
• We conduct a study surrounding the data efficiency of standard pre-trained representations and their respective pre-finetuned counterparts. We show that the pre-finetuned models consistently require less data for fine-tuning.

Related Work
Multi-task learning has been an increasingly active topic in recent literature. Recent advances such as MT-DNN show that by leveraging multi-task learning, we can further improve performance on several language benchmarks on top of traditional pre-training (Liu et al., 2019a). However, T5 (Raffel et al., 2019) shows that incorporating multi-task learning on top of larger models does not improve upon standard pre-training and fine-tuning. Thus the effect of multi-task learning across different pre-training methods is not fully understood.
Recently Khashabi et al. (2020) showed how doing MTL training on a range of QA tasks can improve the performance of T5 by taking advantage of cross dataset transfer. Unlike our approach, they convert all the data to a seq2seq format, operate on a smaller MTL scale, have a different batching strategy, and focus solely on improving QA tasks. Our work shows how even seemingly very different datasets, for example, summarization and extractive QA, can help each other by improving the model's representations.
Our work aims to explore multi-task learning at a much larger scale; by incorporating a larger number of tasks, we show that we can consistently improve several language benchmarks from several domains. Contrary to T5, we show that incorporating a secondary stage of multi-task learning does lead to better representations. In §5 we demonstrate that the effectiveness of multi-task learning comes from the large scale of our MTL setup.

Pre-Finetuning Through Massive Multitask Learning
Previous work has reported mixed results from experiments on multi-task learning (Liu et al., 2019a; Raffel et al., 2019). In general, it can be challenging to balance the losses from different tasks; upsampling can lead to overfitting low-resource tasks, and downsampling can lead to improper learning of specific tasks. This difficulty is particularly pronounced when operating at the scale of experiments we show in Section 5.1, where there are more diverse tasks than previously considered. This section presents our pre-finetuning approach that leads to more stable and accurate multi-task training by introducing new optimization, loss scaling, and task sampling schemes to better balance each minibatch's updates.

Tasks and Losses
Diverse Tasks To learn general language representations, we include a variety of tasks across many domains. We select language tasks across four different domains: classification, commonsense reasoning, machine reading comprehension, and summarization. In Table 1, we show the breakdown of each of the task types along with the number of samples used from each during pre-finetuning. In total, our multi-task setup learns over 4.8 million supervised samples across 4 families of tasks. A full list of all of the datasets we leverage for pre-finetuning is described in Appendix §A.1.
Standard Losses To train on several datasets, our model contains task-specific heads, each optimizing for a task-specific loss. The loss functions are summarized in Table 2. Each loss is scaled with the loss scaling described in §3.3. After loss scaling, the gradients from each task are averaged before doing the model update step.

Optimization
We show two strategies to learn multi-task representations at scale: Accumulating Gradients Across Tasks (Heterogeneous Batches) and Leveraging Better Finetuning.
Accumulating Gradients Across Tasks Our model is trying to optimize not a single objective but several potentially competing objectives to create a unified representation across several tasks during model training. During gradient descent, moving along the gradient of a single task may not be the optimal direction for the model to move to learn a single unified representation across tasks. To overcome this, we ensure each batch our model optimizes consists of several tasks. Each worker samples a random batch from our set of tasks and computes a gradient, accumulated for the final update. Empirically we use 64 GPUs for pre-finetuning, resulting in each batch consisting of gradients across 64 sampled tasks. In §5.2 we show how such a strategy allows for our model to arrive at a better representation for end task finetuning.
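The accumulation step described above can be illustrated with a toy sketch. It is not the training code used for MUPPET: the quadratic per-task loss, the task names, the learning rate, and the function names `task_gradient` and `heterogeneous_step` are all assumptions made for illustration. Each simulated worker samples one task, computes a gradient for it, and the gradients are averaged into a single update.

```python
import random

def task_gradient(theta, target):
    """Gradient of a toy per-task loss 0.5 * ||theta - target||^2 (hypothetical)."""
    return [t - c for t, c in zip(theta, target)]

def heterogeneous_step(theta, task_targets, num_workers, lr, rng):
    """One update: every worker samples a task and the gradients are averaged."""
    names = list(task_targets)
    sampled = [rng.choice(names) for _ in range(num_workers)]
    accumulated = [0.0] * len(theta)
    for name in sampled:
        grad = task_gradient(theta, task_targets[name])
        accumulated = [a + g for a, g in zip(accumulated, grad)]
    mean_grad = [a / num_workers for a in accumulated]
    new_theta = [t - lr * g for t, g in zip(theta, mean_grad)]
    return new_theta, sampled

# Three toy "tasks" pulling the shared parameters in competing directions.
task_targets = {
    "classification": [1.0, 0.0],
    "mrc": [0.0, 1.0],
    "summarization": [-1.0, -1.0],
}
rng = random.Random(0)
theta = [5.0, 5.0]
for _ in range(200):
    theta, sampled = heterogeneous_step(theta, task_targets, num_workers=64, lr=0.1, rng=rng)
```

Because every update averages over many sampled tasks, the shared parameters settle near a compromise between the task optima rather than chasing whichever task was seen last.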
Better Finetuning Instead of starting from scratch, pre-finetuning initializes our model with representations learned from self-supervised pre-training. This lets the model inherit the knowledge captured in the pre-trained representations and speeds up training. Mosbach et al. (2020) show that standard fine-tuning of pre-trained models can be unstable, which may be aggravated in our case as we are training on a diverse set of tasks simultaneously. Therefore, we employ the R3F/R4F methods (Aghajanyan et al., 2020) to combat this issue.
In particular, R3F/R4F consists of an additional loss term, ensuring that small perturbations to the input space result in similar representations, which can be used to learn more robust representations during pre-finetuning.
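For reference, the additional loss term can be written as follows. This is our reading of the objective in the cited R3F work rather than something stated in this section, and the symbols are assumptions of this sketch: $f$ is the pre-trained encoder, $g$ the task head, $z$ the input noise, $\sigma$ its scale, and $\lambda$ the weight of the regularizer.

```latex
\mathcal{L}_{\mathrm{R3F}}(f, g, \theta) = \mathcal{L}(\theta)
  + \lambda \, \mathrm{KL}_S\big( g \cdot f(x) \,\|\, g \cdot f(x + z) \big),
  \qquad z \sim \mathcal{N}(0, \sigma^2 I) \ \text{or} \ z \sim \mathcal{U}(-\sigma, \sigma)

\mathrm{KL}_S(p \,\|\, q) = \mathrm{KL}(p \,\|\, q) + \mathrm{KL}(q \,\|\, p)
```

The symmetric KL term is small when the perturbed and unperturbed inputs produce similar output distributions, which is the robustness property described above.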
In early experimentation, we found that R3F was pivotal in getting MUPPET to work for BART. All other fine-tuning and pre-finetuning was done using standard SGD.

Loss Scaling
Loss scaling methods introduce a multiplicative reweighting of individual losses per data-point. Various loss scaling techniques have been proposed, from dynamic scaling by inverse training loss to simple scaling by the number of data-points in respective datasets (Chen et al., 2018).
As pre-finetuning optimizes several different types of tasks and datasets, each having its own output spaces, loss scaling becomes essential to ensure stable training. We attempted various forms of loss-scaling throughout initial experimentation, but the most effective was the novel method we describe below.
Let us denote $L_i(x_i, y_i; \theta)$ as the loss for data-point $i$ for a model parameterized by $\theta$. Remember that the loss depends on the type of task (commonsense loss is different from binary classification). Furthermore, let $n : \mathbb{N} \to \mathbb{N}$ be a function which, for each data-point, returns the number of predictions $L$ operates over. For example, for binary classification, $n$ would return two, while for generation, $n$ would return the size of the vocabulary (since we average across loss per token generated). We scale data-point loss so that, if the class distribution were uniformly distributed along with our model's predictions, all of our losses would have equivalent values.
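The scaled loss itself is not written out in the text above. One form consistent with the description, given here as an assumption, divides each data-point loss by $\log n(i)$; the superscript "scaled" is notation introduced for this sketch.

```latex
L_i^{\text{scaled}}(x_i, y_i; \theta) = \frac{L_i(x_i, y_i; \theta)}{\log n(i)}
```

The reasoning is that a cross-entropy loss over $n(i)$ outcomes, with uniform targets and uniform predictions, evaluates to $\log n(i)$:

```latex
-\sum_{k=1}^{n(i)} \frac{1}{n(i)} \log \frac{1}{n(i)} = \log n(i)
```

so after division every task's loss equals 1 in that uniform case, regardless of the size of its output space.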
We found that this static scaling worked incredibly well, outperforming other loss scaling methods in early experimentation.
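A small numeric check makes the effect of such static scaling concrete. The sketch assumes the scaling divides each loss by the logarithm of the number of predictions, which matches the uniform-prediction property described above but is not spelled out in this section; the vocabulary size of 50,000 and the function names are hypothetical.

```python
import math

def cross_entropy(probs, target_index):
    """Negative log-likelihood of the target class."""
    return -math.log(probs[target_index])

def scaled_loss(loss, num_predictions):
    """Assumed static scaling: divide the raw loss by log(n)."""
    return loss / math.log(num_predictions)

def uniform(n):
    """A model that spreads probability evenly over n outcomes."""
    return [1.0 / n] * n

# Raw losses differ widely between a binary task and a generation task.
binary_raw = cross_entropy(uniform(2), 0)
vocab_raw = cross_entropy(uniform(50_000), 0)  # hypothetical vocabulary size

# After scaling, both are on the same footing.
binary_scaled = scaled_loss(binary_raw, 2)
vocab_scaled = scaled_loss(vocab_raw, 50_000)
```

Without scaling, the generation loss in this toy case is more than an order of magnitude larger than the binary one and would dominate an averaged gradient.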

Sampling
Another approach to balancing various tasks in a multi-task set up is to up-sample smaller datasets and down-sample larger ones to achieve more uniformity between dataset sizes.
Existing results for dataset sampling methods in multi-task learning are conflicting, but recent work has shown that sampling does not work well for multi-task learning of pre-trained representations. For example, T5 showed that none of the various forms of sampling improved over using the natural size of datasets (Raffel et al., 2019).
We also found that sampling datasets was consistently detrimental for multi-task learning over pre-trained representations during initial experimentation. Specifically, we saw unmanageable over-fitting and stability issues. Therefore we opt for maintaining the natural distribution of the datasets throughout all of our experiments.
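To make the contrast concrete, the sketch below compares the natural distribution, where a dataset is drawn in proportion to its size, with fully uniform sampling, which up-samples small datasets and down-samples large ones. The dataset names and sizes are hypothetical and the function is illustrative only.

```python
def sampling_probabilities(dataset_sizes, mode="natural"):
    """Probability of drawing the next example from each dataset."""
    if mode == "natural":
        total = sum(dataset_sizes.values())
        return {name: size / total for name, size in dataset_sizes.items()}
    if mode == "uniform":
        return {name: 1.0 / len(dataset_sizes) for name in dataset_sizes}
    raise ValueError(f"unknown mode: {mode}")

# Hypothetical sizes, chosen only to show the imbalance.
dataset_sizes = {"large": 400_000, "medium": 90_000, "small": 10_000}
natural = sampling_probabilities(dataset_sizes, mode="natural")
uniform = sampling_probabilities(dataset_sizes, mode="uniform")
```

Under uniform sampling the small dataset would be revisited far more often per epoch of the large one, which is one plausible route to the over-fitting reported above.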

Experimental Setup
We selected RoBERTa (Liu et al., 2019b) and BART (Lewis et al., 2019) as our initial pre-trained models to further pre-finetune. For each task type we use a different prediction scheme. Every Sentence Prediction dataset gets a separate classification head, while for Commonsense and MRC we utilize a separate unified head for each task. For Summarization, we do not add any parameters and use the BART decoder and output layer as is. Experimentally, we saw that using a different head per individual Commonsense and MRC dataset led to severe overfitting.
For both models, we do the pre-finetuning procedure for both the Base and Large models. We trained each model configuration with 64 GPUs until convergence. Depending on the configuration, this took from one to four days. We include the hyper-parameters used per pre-finetuning run in Appendix §A.2.
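The head assignment described above can be sketched as a simple routing function. The head identifiers and the summarization dataset name are hypothetical; only the routing rule (one head per Sentence Prediction dataset, one shared head per Commonsense or MRC task type, no extra parameters for Summarization) comes from the text.

```python
def head_for(dataset, task_type):
    """Return an identifier for the output head a dataset would use."""
    if task_type == "sentence_prediction":
        # Every sentence prediction dataset gets its own classification head.
        return f"cls_head/{dataset}"
    if task_type in ("commonsense", "mrc"):
        # All datasets of the same task type share one unified head.
        return f"unified_head/{task_type}"
    if task_type == "summarization":
        # No new parameters: reuse the existing decoder and output layer.
        return "bart_decoder"
    raise ValueError(f"unknown task type: {task_type}")

routing = {
    name: head_for(name, task_type)
    for name, task_type in [
        ("MNLI", "sentence_prediction"),
        ("RTE", "sentence_prediction"),
        ("SQuAD", "mrc"),
        ("ReCoRD", "mrc"),
        ("HellaSwag", "commonsense"),
        ("summarization_set", "summarization"),  # hypothetical dataset name
    ]
}
```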

Empirical Results
We first show that pre-finetuning improves the representations of pre-training models. To do so, we fine-tune our pre-finetuned models on a large set of tasks.
For each of the individual downstream tasks, we use a fixed hyper-parameter search to optimize over simple hyperparameters such as learning rate, Adam (Kingma and Ba, 2014) and dropout (Srivastava et al., 2014). We present our results in two tables. Table 3 shows our results on the GLUE benchmark (Wang et al., 2018) as well as two MRC tasks: SQuAD and ReCoRD. Table 4 reports results on other Sentence Prediction tasks as well as Commonsense tasks. We also include results from MT-DNN (Liu et al., 2019a), ELECTRA, and RoBERTa (Liu et al., 2019b) models. For Summarization tasks we show that our pre-finetuned BART model outperforms all other summarization baselines. Both of these tables report results over datasets available during the pre-finetuning stage.
Given that our pre-finetuned models now have an understanding of the task at hand through the use of classification heads, we have a choice during finetuning on whether or not to use these heads. In general we found re-using heads to be beneficial for MRC, Commonsense and Sentence Prediction tasks with small dataset size.
Across the board, pre-trained representations that were further refined with pre-finetuning outperformed standard pre-trained representations. We see more modest gains on larger datasets, most likely because we do not need to refine representations beforehand if the fine-tuning dataset is large. On smaller datasets, we see substantial gains. For example, the pre-finetuned RoBERTa-Base model on RTE improves by close to 9 points, rivaling the RoBERTa-Large accuracy, while the pre-finetuned RoBERTa-Large model achieves a new state of the art on RTE, rivaling models an order of magnitude larger.
We improve not only on sentence prediction tasks but on every set of tasks that we measured. For example, we reach a new state of the art on the HellaSwag dataset, which was previously achieved by utilizing a new fine-tuning approach. Our methods do not increase parameter count or any complexity measures but are quite successful at refining features and preparing them for downstream fine-tuning.

Finetuning Outside of Pre-Finetuning Domain
We also report the performance on tasks not included in the pre-finetuning data. To do so, we finetune our models on a set of tasks including (3) Chunking, Constituency Parsing and Part-Of-Speech tagging for structured prediction from the Penn Treebank dataset (Marcus et al., 1993). We present these results in Table 5 and Table 6.
We see that the MUPPET variants of our models outperform the baselines consistently across task type and dataset. As a special case, we do an in-depth analysis of the MUPPET variant of RoBERTa on the notoriously tough ANLI dataset and see the same pattern. Pre-finetuned models consistently outperform their base counterparts.

Importance of Scale
The first axis we would like to explore is the scale on which multi-task learning is done. Previous work, such as T5 and MT-DNN, focused on an MTL scale of around a dozen datasets. To the best of our knowledge, our paper has the largest MTL setup to date. Accordingly, we are interested in empirically exploring the effects of scaling up the number of datasets on the representations learned during MTL.
We pre-finetune a collection of RoBERTa-Base models with varying numbers of datasets. We train seven models, six uniformly chosen between 10 and 40, ensuring that at each point, the selected datasets are a superset of the datasets from prior points. The last model is fully trained on all datasets. Concretely given two models trained with a different number of datasets a, b : a > b, model a will contain all datasets used to train model b and more.
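One simple way to obtain such nested dataset collections is to fix a single ordering and take prefixes of it. The sketch below assumes 50 placeholder dataset names, evenly spaced sizes of 10, 16, 22, 28, 34, and 40, and a set of pinned evaluation datasets that always appear in the smallest collection; these specifics, and the function name, are illustrative assumptions rather than the procedure actually used.

```python
import random

def nested_dataset_subsets(datasets, sizes, pinned=(), seed=0):
    """Return one dataset list per size; each list is a prefix of the next."""
    rest = [d for d in datasets if d not in pinned]
    random.Random(seed).shuffle(rest)
    # Pinned datasets come first so they are part of every collection.
    order = list(pinned) + rest
    return [order[:k] for k in sorted(sizes)]

all_datasets = [f"dataset_{i:02d}" for i in range(50)]  # placeholder names
pinned = all_datasets[:5]  # stands in for the datasets used for evaluation
sizes = [10, 16, 22, 28, 34, 40, len(all_datasets)]
subsets = nested_dataset_subsets(all_datasets, sizes, pinned=pinned)
```

Taking prefixes of one shuffled order guarantees the superset property for any pair of sizes, so differences between models can be attributed to the added datasets rather than to a different selection.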
Table 6: We show the performance of the RoBERTa model and the pre-finetuned RoBERTa-MUPPET model on the ANLI benchmark. Bolded numbers signify MUPPET vs base model, underline signifies best number. 'S' refers to SNLI, 'M' to MNLI dev (-m=matched, -mm=mismatched), and 'F' to FEVER; 'A1-A3' refer to the rounds respectively and 'ANLI' refers to A1+A2+A3.
For each version of the model, we fine-tune five datasets and plot the results in Figure 1. Specifically, we finetune STS-B (Cer et al., 2017), RACE, SQuAD, and MNLI. We include these five datasets in the first MTL run (10 datasets) to remove any bias from adding them in a later stage. We see a couple of interesting patterns. First, for individual tasks such as RTE, increasing the pre-finetuning scale monotonically improves performance. This is aligned with other papers that have seen benefits from first training on MNLI and then fine-tuning on RTE (Liu et al., 2019b). For other datasets, we see that doing MTL in the < 15 datasets regime is detrimental for end-task fine-tuning. This is also aligned with other empirical observations, i.e., T5 reported that doing MTL did not improve over only fine-tuning. Nevertheless, it seems that as we increase the number of tasks past some critical point, our pre-trained representations become more generalizable. Furthermore, although dependent on the dataset, this critical point is roughly between 10 and 25 tasks.
This suggests that previously observed MTL limitations were not fundamental and can instead be attributed to the lack of sufficient scale.

Importance of Heterogeneous Batches
Another critical factor in getting MTL to learn generalizable representations is the method through which MTL is implemented, specifically the selection of batches. To better quantify this effect, we experimented with three balancing schemes: dataset homogeneous, batch homogeneous, and batch heterogeneous.
We refer to dataset homogeneous as selecting batches from datasets sequentially. So we first train on dataset A, then train on dataset B, etc. On the other hand, batch homogeneous refers to selecting batches containing only data from the same task; therefore, all gradients are from the same dataset. This is implemented by selecting all datasets, batching on a dataset level, and selecting those same batches randomly during training. Finally, batch heterogeneous refers to a single update containing a batch from multiple different datasets spanning different tasks. We implemented this by first creating homogeneous sub-batches, calculating loss per sub-batch per GPU, and then aggregating across GPUs, resulting in a gradient update that contains various datasets and, therefore, tasks.
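The three schemes can be contrasted with a small scheduling sketch. Each function returns a list of updates, where an update is a list of (dataset name, examples) sub-batches; the toy datasets, batch size, and worker count are assumptions, and the simulated workers stand in for GPUs.

```python
import random

def make_batches(datasets, batch_size):
    """Split every dataset into homogeneous sub-batches tagged with its name."""
    batches = []
    for name, examples in datasets.items():
        for start in range(0, len(examples), batch_size):
            batches.append((name, examples[start:start + batch_size]))
    return batches

def dataset_homogeneous(datasets, batch_size):
    """Train on dataset A, then dataset B, and so on: one sub-batch per update."""
    return [[b] for b in make_batches(datasets, batch_size)]

def batch_homogeneous(datasets, batch_size, rng):
    """Shuffle sub-batches across datasets, still one dataset per update."""
    batches = make_batches(datasets, batch_size)
    rng.shuffle(batches)
    return [[b] for b in batches]

def batch_heterogeneous(datasets, batch_size, num_workers, rng):
    """Group shuffled sub-batches so a single update spans several datasets."""
    batches = make_batches(datasets, batch_size)
    rng.shuffle(batches)
    return [batches[i:i + num_workers] for i in range(0, len(batches), num_workers)]

# Toy data: three datasets of eight examples each.
datasets = {name: [f"{name}{i}" for i in range(8)] for name in ("A", "B", "C")}
sequential = dataset_homogeneous(datasets, batch_size=2)
shuffled = batch_homogeneous(datasets, batch_size=2, rng=random.Random(0))
mixed = batch_heterogeneous(datasets, batch_size=2, num_workers=6, rng=random.Random(0))
```

In the heterogeneous schedule, the gradient of one update would be the aggregate over all sub-batches in that update, so no single dataset determines the direction of a step.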
To dissect the importance of heterogeneous batches, we train a RoBERTa-Base model on 35 randomly selected tasks using the three data selection methodologies outlined above. We then finetune these three models on the same five data-sets mentioned in the previous section.
We present our results in Figure 2. We see the importance of properly defining a batching strategy for effective multi-task learning. Our findings are also consistent with Aghajanyan et al. (2020), who saw that sequential training of datasets degrades generalizable representations.

Low Resource Experiments
We noticed in Section §4 that smaller datasets tended to improve more from MTL training. To test this hypothesis, we look at two factors: the scale of pre-finetuning and the scale of fine-tuning (size of the fine-tuning dataset).
We select three datasets that were not used in pre-finetuning in Section §5.1. We also select nine partitions per fine-tuning dataset, with sizes sampled uniformly between 10% and 100% of the dataset. The low-resource splits were selected through random sampling.
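The construction of the low-resource splits can be sketched as follows. The sketch assumes that "sampled uniformly" means nine evenly spaced fractions from 10% to 100%, and that each split is an independent random sample of the training set; the function names, the seed, and the stand-in data are hypothetical.

```python
import random

def partition_fractions(num_partitions=9, low=0.10, high=1.00):
    """Evenly spaced fractions of the dataset, from low to high inclusive."""
    step = (high - low) / (num_partitions - 1)
    return [low + i * step for i in range(num_partitions)]

def low_resource_split(examples, fraction, seed=0):
    """Randomly sample the given fraction of the examples without replacement."""
    k = min(len(examples), max(1, round(fraction * len(examples))))
    return random.Random(seed).sample(examples, k)

examples = list(range(1000))  # stand-in for a fine-tuning training set
fractions = partition_fractions()
splits = [low_resource_split(examples, f) for f in fractions]
```

Each split would then be paired with every pre-finetuning checkpoint, giving one fine-tuning run per (split, checkpoint) cell.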
We then fine-tune every low-resource split with every pre-finetuning checkpoint from Section §5.1. We plot the heatmaps generated from these runs in Figure 3. Multiple patterns emerge. First, we see a clear visualization of the critical point mentioned when doing pre-finetuning. As we increase the scale of MTL, better representations are available for downstream finetuning. Furthermore, we see that prefinetuned models at a larger scale are much more data-efficient than standard pre-trained models.
Specifically, looking at the 34/40 pre-finetuning scale in Figure 3, we see that we reach higher evaluation accuracies much sooner than the base RoBERTa model (row 0).

Conclusion
In this work, we propose pre-finetuning, a stage after pre-training to further refine representations before end-task finetuning. We show that we can effectively learn more robust representations through multi-task learning (MTL) at scale. Our MTL models outperform their vanilla pre-trained counterparts across several tasks. Our analysis shows that properly scaling MTL with heterogeneous batches and loss scaling is critical to leveraging better representations. We also show a critical point regarding the number of tasks when doing multi-task learning, where fewer tasks degrade representations compared to the pre-trained model, but more tasks than this point improve representations.
We discussed a practical setting in which doing this massive multi-task learning is stable and effective through simple loss scaling and heterogeneous batches. With our method, we improve upon prior state-of-the-art methods for RTE and HellaSWAG, as well as improve upon vanilla pre-trained representations for MNLI, SQuAD, BoolQ, and Common Sense QA. We also examine our MTL model performance with low-resource experiments. We show that on held-out datasets, leveraging representations from our pre-finetuned models with 34-40 tasks, we reach higher evaluation accuracies with much less data than the RoBERTa model.
Figure 3: We fine-tune every low-resource split with every pre-finetuning checkpoint from Section §5.1 for two datasets not available in any of the pre-finetuning MTL datasets: QNLI and CoLA. The pre-finetuning scale is reported in terms of the number of datasets.