PruMUX: Augmenting Data Multiplexing with Model Compression

As language models increase in size by the day, methods for efficient inference are critical to leveraging their capabilities for various applications. Prior work has investigated techniques like model pruning, knowledge distillation, and data multiplexing to increase model throughput without sacrificing accuracy. In this paper, we combine two such methods -- structured pruning and data multiplexing -- to compound the speedup gains obtained by either method. Our approach, PruMUX, obtains 7.5-29.5X throughput improvement over the BERT-base model for accuracy thresholds from 80% to 74%. We further study various combinations of parameters (such as sparsity and multiplexing factor) in the two techniques to provide a comprehensive analysis of the tradeoff between accuracy and throughput in the resulting models. We then propose Auto-PruMUX, a meta-level model that can predict high-performance parameters for pruning and multiplexing given a desired accuracy loss budget, providing a practical method to leverage the combination effectively.


Introduction
Large language models (LLMs) have achieved state-of-the-art performance across various NLP tasks and resulted in impressive user-facing demonstrations such as ChatGPT. However, their large size necessitates the use of enormous amounts of compute and memory at inference time, which limits their widespread use.
Two types of techniques have been explored to reduce the cost of model inference. The first is model compression, including network pruning (LeCun et al., 1989; Han et al., 2015b; Frankle and Carbin, 2019), quantization (Han et al., 2016), knowledge distillation (Hinton et al., 2015), and combinations of multiple methods (Xia et al., 2022). The second is the recently proposed data multiplexing (Murahari et al., 2023), which multiplexes multiple inputs into a single input for model inference.

Figure 1: Throughput improvement and accuracy of PruMUX, CoFi, and DataMUX over BERT-base (Devlin et al., 2018) on the MNLI task (Williams et al., 2017). The sparsity of a CoFi data point is labeled s, the width of multiplexing of a DataMUX data point is labeled N, and the parameter pair of a PruMUX data point is labeled (N, s).
While both types of methods leverage the overparameterization effect (Allen-Zhu et al., 2019; Radhakrishnan et al., 2020) in modern deep neural networks to improve the throughput-to-compute cost ratio, the manner in which they do so is different. Model compression aims at reducing the number of parameters in the model, hence reducing the overall compute cost (denominator) to improve the ratio. Data multiplexing, on the other hand, compresses multiple inputs into one to improve throughput (numerator) while keeping the model size fixed. This observation naturally leads us to hypothesize that the two types of methods could be complementary and can be combined for maximal gain in the throughput-to-compute cost ratio.
There are two challenges to this hypothesis. The first is that both model compression and data multiplexing aim at trading a small accuracy loss for large throughput improvement. Intuitively, the combination may incur an accuracy loss larger than either method and it is not clear how they interact with each other when combining them together. A research question is how to combine the two methods such that the combination achieves better throughput than each type of method individually, given any accuracy loss budget or accuracy threshold.
The second challenge is to efficiently find the best parameters pair (N, s) where N is the width of the data multiplexing and s is the sparsity of the model compression method. Training and testing with each parameter combination is costly and time-consuming. A research question is how to automatically predict and find top parameters based on the model's performance on one set of parameters.
To address the first research question, we present PruMUX, a combination of model compression and data multiplexing. Our method is simple and consists of three phases: multiplexed model pre-training, task-specific fine-tuning, and task-specific model compression. In our implementation, we make use of CoFi (Xia et al., 2022), a state-of-the-art model compression method that includes intermediate knowledge distillation steps that help minimize accuracy hits, and DataMUX (Murahari et al., 2023), which performs vector-based input multiplexing over instances.
Our results on four datasets (MNLI, QNLI, QQP and SST-2) demonstrate that PruMUX achieves significantly higher throughput than CoFi and DataMUX individually for a large range of accuracy thresholds. As an example, Figure 1 shows the throughput improvements over the BERT-base model on the MNLI task, providing a better Pareto frontier in the tradeoff between accuracy and throughput.
To address the second research question, we propose Auto-PruMUX, a meta-model to automatically predict and find the high-performance parameter combinations for a desired accuracy loss budget on a task, based on the model's performance on one set of parameters, without running additional experiments. We use interpolation and estimation models over a set of data points to predict the accuracy and throughput of a PruMUX model based on sparsity and multiplexing factor. We show promise in modeling the tradeoffs accurately, and Auto-PruMUX can find high-performance combinations of known parameters as well as unknown parameters, providing a practical method for choosing a high-performance PruMUX model for a downstream task.
Our key insight for why PruMUX can achieve better throughput than model compression and data multiplexing individually is that they improve the throughput of a model in two different dimensions: reducing the latency of an inference and compressing multiple inferences. In addition, both methods lead to non-linear drops in model accuracy at some points. PruMUX can achieve high throughput while avoiding each method's limitations.

CoFi Pruning
CoFi is a state-of-the-art model compression method (Xia et al., 2022) that uses distillation and structured pruning to jointly prune a Transformer network (Devlin et al., 2018). Its key idea is to distill the knowledge from the base model into the pruned model during training. A layer-wise distillation approach is used to guide the pruning from the teacher model (the dense model) to the student model (the pruned model), with a loss defined as

L_layer = Σ_i MSE(W_layer · H_s^{m(i)}, H_t^i),

where H_s^{m(i)} and H_t^i are the hidden representations of the m(i)-th feed-forward layer of the student model and the i-th feed-forward layer of the teacher model, i is the teacher model's closest layer to layer m(i) of the student model, and W_layer is a linear transformation matrix, initialized as an identity matrix.
CoFi prunes both coarse-grained and fine-grained units of the distilled network. The coarse-grained units include multi-head attention layers, fully-connected layers, and attention heads. The fine-grained units include hidden dimensions and intermediate dimensions of the Transformer model. Different masks are used for different pruning units and are learned via ℓ0 regularization during training. The units with mask variables smaller than a threshold are pruned away before inference.
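The two mechanics above, layer-wise distillation and threshold-based mask pruning, can be sketched in a few lines of NumPy. This is a toy illustration rather than CoFi's implementation: the shapes, the mask values, and the 0.5 threshold are hypothetical, and in the real method the masks are learned via ℓ0 regularization during training.

```python
import numpy as np

def layer_distillation_loss(h_student, h_teacher, w_layer):
    """One term of the layer-wise distillation loss: MSE between the
    linearly transformed student hidden states of layer m(i) and the
    teacher hidden states of its closest layer i."""
    return np.mean((h_student @ w_layer - h_teacher) ** 2)

def apply_masks(weights, masks, threshold=0.5):
    """Prune units whose learned mask variable falls below the
    threshold by zeroing them out before inference."""
    keep = (masks >= threshold).astype(weights.dtype)
    return weights * keep

rng = np.random.default_rng(0)
h_s = rng.normal(size=(3, 4))          # student hidden states (toy: 3 tokens, dim 4)
h_t = rng.normal(size=(3, 4))          # teacher hidden states
w = np.eye(4)                          # W_layer, initialized as an identity matrix
loss = layer_distillation_loss(h_s, h_t, w)

heads = rng.normal(size=(4, 4))        # toy per-unit weights (4 "heads")
mask = np.array([0.9, 0.1, 0.7, 0.0])  # hypothetical learned mask variables
pruned = apply_masks(heads, mask[:, None])  # heads 1 and 3 are pruned away
```

With W_layer as the identity, the loss reduces to a plain MSE between the two hidden states; the learned transformation matters when student and teacher dimensions diverge during pruning.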

DataMUX
Data multiplexing (DataMUX) is a recently proposed method (Murahari et al., 2022, 2023) to compress multiple inputs into a single "mixed" representation of the same size as a single input to a network, in order to improve inference throughput. DataMUX introduces multiplexing layers, which combine different sequences into a single sequence of multiplexed representations, and demultiplexing layers, which demultiplex/decompress the multiplexed representations. The multiplexing layer first compresses multiple input sequences into a single sequence of representations. These representations are then processed by a Transformer model, and the resulting representations are disentangled into independent representations by the demultiplexer layer, which are then used to make predictions. DataMUX therefore leads to a many-fold increase in inference throughput, as N inputs require just a single pass through the large Transformer model.
The multiplexing layer is defined as

x^{1:N} = (1/N) Σ_{i=1}^{N} φ^i(x^i),

where φ^i is the Hadamard product with a fixed Gaussian random vector and N is the number of input sequences that get multiplexed. The multiplexed representations, x^{1:N}, are then processed by the Transformer model to generate hidden multiplexed representations, h^{1:N}. The demultiplexer layer, in order to disentangle the hidden multiplexed representation h^{1:N} into independent representations, learns N parameterized demultiplexing functions, ψ^i. The independent representations, h^i, are then used to make predictions.
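The multiplexing step can be sketched in NumPy as below. This is a minimal illustration under toy sizes, not DataMUX's actual configuration; the demultiplexer is omitted because its functions ψ^i are learned.

```python
import numpy as np

rng = np.random.default_rng(0)
N, L, d = 4, 8, 16   # multiplexing width, sequence length, embedding dim (toy)

# N input sequences of token embeddings
inputs = rng.normal(size=(N, L, d))

# one fixed Gaussian random vector per input index i, applied via
# Hadamard (element-wise) product
keys = rng.normal(size=(N, d))

def multiplex(inputs, keys):
    """Compress N sequences into one sequence of the same shape as a
    single input: x^{1:N} = (1/N) * sum_i (x^i ⊙ v^i)."""
    return (inputs * keys[:, None, :]).mean(axis=0)

x_mux = multiplex(inputs, keys)
assert x_mux.shape == (L, d)   # same size as one input sequence
```

The distinct random keys make the superimposed inputs approximately separable, which is what allows the learned demultiplexing functions to recover per-input representations downstream.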

Observations
Both model compression and data multiplexing aim at trading small accuracy losses for large inference throughput improvements. When CoFi prunes a Transformer at relatively low sparsities, its accuracy loss is minimal and throughput improvement is significant, but at 95% sparsity, its accuracy loss becomes relatively significant (Xia et al., 2022). DataMUX also shares this nonlinear property, as shown in Figure 1. In other words, the trade-off of each method is good only up to a certain point.
The two methods improve the throughput of a model in two dimensions. CoFi reduces the latency of an inference, whereas DataMUX compresses multiple inferences into one. A natural question is whether combining the two methods can achieve higher throughput with a smaller accuracy loss than each method individually.

PruMUX
Figure 2: Illustration of PruMUX showing a multiplexer, sparse Transformer, and a demultiplexer, with multiplexing width of 10, where 10 input sequences are mixed into 1 input sequence. The multiplexed Transformer model is pruned to reduce inference time. The training for PruMUX consists of three steps including retrieval warm-up, multiplexed model training, and Transformer pruning.
Our key motivational question is the following: given an accuracy loss budget, can the combination of model compression and data multiplexing achieve better throughput than each method individually? In this section, we first present PruMUX, a method to combine the two methods, and then show that PruMUX achieves substantially better throughput than each method alone for various accuracy thresholds in our experimental results.

Method
PruMUX is a method to convert any Transformer into a high throughput model, capable of compressing multiple inference inputs into a single input and executing it at a low latency.
For multiplexing, PruMUX uses the recently proposed DataMUX (Murahari et al., 2023), which appends a multiplexer and demultiplexer as described in Sec 2.2. With width N , the inference throughput of the Transformer can be improved by a factor of up to N , as each multiplexed input takes the same amount of computing resources as performing inference over a single input.
For model compression, PruMUX can use any method such as network pruning, distillation, or a combination of the two (such as CoFi). The goal is to substantially reduce the latency of processing an inference. For our experiments, PruMUX uses CoFi as the model compression method.
Training a model with PruMUX consists of three phases, as shown in Figure 2. Phase 1: Priming the multiplexed model with the token retrieval objective We first prime the multiplexed Transformer model with a token retrieval task. Murahari et al. (2022) introduced this "retrieval warm-up" self-supervised objective, in which the model is trained to retrieve the tokens of a randomly selected input sentence from the multiplexed representation, and found it to be critical to the performance of multiplexed models. Here L is the length of each input sentence and I is the index of the randomly selected sentence from the input batch.
Phase 2: Pre-training and fine-tuning multiplexed models The multiplexed models from the previous stage are then pre-trained on large-scale text corpora with the masked language modeling (MLM) objective. The pre-trained multiplexed models are then fine-tuned on downstream tasks to yield task-specific multiplexed models.

Phase 3: Task-specific model compression Finally, the fine-tuned multiplexed models are compressed with CoFi's structured pruning and distillation objective (Section 2.1) to reach a target sparsity s.

Implementation Details
We use the pre-trained multiplexed BERT-base models (Murahari et al., 2023), pre-trained with the standard BERT recipe and the masked language modeling objective for N = 2, 5, 10 on the Wikipedia (Foundation) and BooksCorpus (Zhu et al., 2015) datasets. We prime the multiplexed model before pre-training with the token-retrieval task of Section 2.2 on the same datasets. We then fine-tune the pre-trained multiplexed models on the four largest GLUE tasks: MNLI (Williams et al., 2018), QNLI, QQP, and SST-2 (Socher et al., 2013). We then use the CoFi structured pruning objective to obtain a pruned multiplexed model on each task dataset. The hyperparameters used for training are listed in Appendix A.1. We perform a single training run for each setting, i.e., each combination of task, multiplexer width N, and model sparsity s.

Experiments
Setup We would like to answer the following question: given an accuracy threshold, can PruMUX achieve higher throughput than either CoFi or DataMUX alone? We compare the PruMUXed BERT-base model to three baselines: BERT-base (trained without data multiplexing or model compression), DataMUX, and CoFi.
We have applied PruMUX to the BERT-base model with all combinations of (N, s) for all four tasks. We follow the procedure in Xia et al. (2022) to calculate throughput improvements for PruMUXed Transformers and all three baselines, i.e., BERT-base, DataMUX, and CoFi. The evaluation batch size is 128*N, where N is the multiplexer width.
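The throughput calculation can be sketched as follows. This is a generic timing harness with a hypothetical stand-in model, not the measurement code of Xia et al. (2022); the key point is that a multiplexed batch of B sequences is counted as B*N inputs.

```python
import time
import numpy as np

def measure_throughput(model_fn, x, n_batches=20, width=1):
    """Rough throughput estimate in inputs/second. With multiplexing
    width N, each batch of B multiplexed sequences carries B*N original
    inputs (the paper evaluates with batch size 128*N)."""
    model_fn(x)                          # warm-up call
    start = time.perf_counter()
    for _ in range(n_batches):
        model_fn(x)
    elapsed = time.perf_counter() - start
    return n_batches * x.shape[0] * width / elapsed

# stand-in "model": a single dense projection (hypothetical, for illustration)
rng = np.random.default_rng(0)
w = rng.normal(size=(64, 64)).astype(np.float32)
x = rng.normal(size=(32, 16, 64)).astype(np.float32)
model = lambda h: h @ w

t1 = measure_throughput(model, x, width=1)   # unmultiplexed
t5 = measure_throughput(model, x, width=5)   # same compute, 5x inputs counted
```

Comparing t5 to t1 for the same stand-in model shows how multiplexing raises the numerator of the throughput-to-compute ratio without changing the per-pass cost.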
Results Figure 3 shows the throughput improvements and accuracies of PruMUXed, DataMUXed, and CoFi-Pruned Transformers over the Transformer base model on the MNLI, QNLI, QQP, and SST-2 tasks with all available parameters.
The main takeaway is that PruMUX achieves higher throughput than either CoFi or DataMUX individually in all cases beyond certain accuracy thresholds: • For MNLI, with accuracy thresholds from 80% to 74%, PruMUX achieves 7.5-29.5X throughput improvement over the BERT-base model, whereas CoFi improves by 4.0-10.6X and DataMUX by 2.0-4.9X.
The results also confirm the intuition that PruMUX with (N, s) incurs an accuracy loss, loosely speaking, close to the sum of the accuracy loss of DataMUX with N and that of CoFi with s. In general, PruMUX can achieve substantial throughput improvement when there is a decent accuracy loss budget.

Discussion
The results above identify the top PruMUX parameter pairs (N, s), where N = 2, 5, 10 and s = 0.60, 0.70, 0.80, 0.90, 0.95, for each accuracy loss budget. Searching for top PruMUX parameters at a finer parameter granularity would require training and testing on all additional parameter pairs, and exhaustive tests are impractical. First, for each N, pre-training a DataMUX model with multiplexing width N is time-consuming. Second, given each pre-trained model with multiplexer width N, different sparsities s provide different throughput and accuracy trade-offs. In order to find the sparsity s with the highest throughput given an accuracy budget, one has to train the model for all possible sparsities. The total training time for the sparsities from 0.60 to 0.95 at a granularity of 0.05 for each N takes over six thousand GPU hours on commodity GPUs, even for a small BERT-base model. A key question is whether one can automatically find a high-throughput (N, s) with a small number of PruMUX experiments.

Auto-PruMUX
To address the question above, we propose Auto-PruMUX, a method to search for top (N, s) parameters, to help practitioners balance the performance vs throughput trade-off.
Our research question is the following: given some experimental data for PruMUX, DataMUX, and CoFi, how can we find and predict the top parameters (N, s) for a given accuracy loss budget?
Our approach is to develop performance models for the accuracy and throughput of PruMUX. We first train PruMUX models for a set of (N, s) combinations and measure both the accuracy and the throughput improvement. We then use this data to fit a throughput model and an accuracy model to predict throughput and accuracy respectively given (N, s) parameters.
We first discuss how we fit the accuracy and throughput models with a set of sparse data points. Given that we are working with a limited set of data points, we opt for a simple class of interpolation models for modeling PruMUX accuracy and an estimation model for modeling throughput. We then outline how we leverage these models to predict top (N, s) parameters given an accuracy loss budget, and demonstrate the effectiveness of Auto-PruMUX in predicting the top parameters across a wide range of accuracy loss budgets.

Task Accuracy Model
We use piecewise linear interpolation for our task accuracy model. Within each cell of the parameter grid, the prediction is a linear combination of the data multiplexer width and the model sparsity:

f_A(N, s) = a + b·N + c·s + d·N·s, for N ∈ [N_i, N_{i+1}] and s ∈ [s_j, s_{j+1}],

where N_i, N_{i+1} and s_j, s_{j+1} are adjacent values in the ranges of N and s used to fit the model. The coefficients are fitted on the gathered data of model task accuracy at the different multiplexer widths and sparsities.
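The interpolation model can be sketched directly in NumPy. The accuracy values below are hypothetical placeholders, not our measured results; only the grid of N and s values matches the experimental setting.

```python
import numpy as np

# measured PruMUX accuracies on the training grid (hypothetical numbers)
Ns = np.array([2.0, 5.0, 10.0])
Ss = np.array([0.60, 0.70, 0.80, 0.90, 0.95])
acc = np.array([
    [0.82, 0.81, 0.80, 0.78, 0.75],   # N = 2
    [0.80, 0.79, 0.78, 0.75, 0.72],   # N = 5
    [0.77, 0.76, 0.74, 0.71, 0.68],   # N = 10
])

def f_A(N, s):
    """Piecewise-bilinear interpolation of accuracy over (N, s): within
    each grid cell the prediction is linear in N, in s, and in N*s."""
    i = np.clip(np.searchsorted(Ns, N) - 1, 0, len(Ns) - 2)
    j = np.clip(np.searchsorted(Ss, s) - 1, 0, len(Ss) - 2)
    tN = (N - Ns[i]) / (Ns[i + 1] - Ns[i])
    ts = (s - Ss[j]) / (Ss[j + 1] - Ss[j])
    return ((1 - tN) * (1 - ts) * acc[i, j]
            + tN * (1 - ts) * acc[i + 1, j]
            + (1 - tN) * ts * acc[i, j + 1]
            + tN * ts * acc[i + 1, j + 1])
```

At grid points the model reproduces the measured accuracies exactly (e.g. f_A(2, 0.60) returns 0.82 above), while off-grid queries such as f_A(3.5, 0.65) blend the four surrounding measurements.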

Throughput Model
We collect the throughput values for all N and s on one task (task_0) and use these values as the throughput estimates for all tasks.

Predicting (N, s)
We use our models, f_A(N, s) and f_T(N, s), to model the accuracy and the throughput of PruMUX with N > 1 and s > 0%. Acc(1, s) and Throu(1, s) are the measured accuracy and throughput of CoFi-pruned models. Acc(N, 0) and Throu(N, 0) are the measured accuracy and throughput of DataMUX models. Acc(1, 0) and Throu(1, 0) are the performance of the BERT-base model. We search for (N, s) parameters that maximize ζ_f, defined below.
ζ_f(N, s) = Throu(N, s) · g(Acc(N, s))     (1)

Intuitively, ζ_f trades off task performance and throughput, given an accuracy loss budget ξ, with the goal of maximizing throughput. g(x) provides a mechanism for a strict accuracy threshold: g(x) = 1 if x meets the minimum required accuracy Acc(1, 0) − ξ and 0 otherwise, so a model that does not meet the minimum required accuracy has ζ_f = 0.
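The search for the maximizing (N, s) can then be sketched as a simple grid search over candidate pairs. The predicted accuracies, throughputs, base accuracy, and budget below are hypothetical placeholders used only to make the objective concrete.

```python
# grid search for the (N, s) pair maximizing zeta_f under a budget xi
base_acc = 0.84        # Acc(1, 0) of the base model (hypothetical)
xi = 0.03              # accuracy loss budget

# (N, s) -> (predicted accuracy f_A, predicted throughput improvement f_T)
predictions = {
    (2, 0.60): (0.825, 7.0),
    (2, 0.95): (0.790, 15.0),
    (5, 0.70): (0.815, 14.0),
    (10, 0.60): (0.805, 18.0),
    (10, 0.95): (0.740, 30.0),
}

def g(acc):
    """Strict accuracy threshold: below-budget models get zeta_f = 0."""
    return 1.0 if acc >= base_acc - xi else 0.0

def zeta(acc, throughput):
    return throughput * g(acc)

# pick the candidate with the highest objective value
best = max(predictions, key=lambda p: zeta(*predictions[p]))
```

With these placeholder numbers the budget excludes the faster but less accurate pairs, and the search selects (5, 0.70): the highest-throughput pair whose predicted accuracy stays within the budget.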

Experimental Results
Experimental setting In this section, we show Auto-PruMUX's prediction results obtained by fitting the performance models on a small set of parameters and predicting top parameters on a larger parameter space (test set). We fit the accuracy model with the model accuracies at (N, s) for all N ∈ {2, 5, 10} and s ∈ {0.60, 0.70, 0.80, 0.90, 0.95} (training set). We fit the throughput model with the throughput of one task on all parameter pairs.
Our goal is to evaluate the task accuracy model, the throughput model, and parameter prediction performance.
Performance Model Accuracy To evaluate the accuracy of the task performance models on the training set, we perform leave-one-out cross-validation for each task. We show the fraction M_A of accuracy predictions with error falling within ∆ξ = 1.5% of the real accuracy in Table 1. To evaluate the accuracy of the throughput model on the training set, we fit the model using PruMUX's performance on the QQP task. We show the fraction M_T of throughput predictions with error within 20% of the real throughput improvement in Table 1. Across different tasks, our accuracy and throughput models are accurate over a broad set of parameter combinations.
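The leave-one-out protocol and the M_A metric can be sketched as follows. The data points and the mean-predictor stand-in are hypothetical, used only to make the procedure concrete; in our evaluation the fitted model is the interpolation accuracy model.

```python
import numpy as np

def loocv_fraction(points, values, fit, predict, tol=0.015):
    """Leave-one-out cross-validation: refit on all-but-one data point,
    predict the held-out point, and report the fraction of predictions
    within tol (here 1.5 accuracy points) of the measured value."""
    hits = 0
    for k in range(len(points)):
        train_p = [p for i, p in enumerate(points) if i != k]
        train_v = [v for i, v in enumerate(values) if i != k]
        model = fit(train_p, train_v)
        if abs(predict(model, points[k]) - values[k]) <= tol:
            hits += 1
    return hits / len(points)

# toy stand-in model: predict the mean of the training accuracies
points = [(2, 0.6), (2, 0.7), (5, 0.6), (5, 0.7)]
values = [0.82, 0.81, 0.80, 0.79]
fit = lambda ps, vs: float(np.mean(vs))
predict = lambda model, p: model
m_a = loocv_fraction(points, values, fit, predict)
```

Each data point is held out exactly once, so the reported fraction reflects prediction quality on unseen parameter pairs rather than fitting error.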

Top Parameter Prediction
We show Auto-PruMUX's prediction results by fitting the accuracy model on the training set, fitting the throughput model using the throughput of the QQP task, and predicting top parameters on the test set. We show Auto-PruMUX's top parameter predictions for an accuracy loss budget of 3% in Table 2. Auto-PruMUX predicts the actual best parameter pairs within its top 3 predictions.

Table 2: Auto-PruMUX top 3 (N, s) predictions for different tasks with an accuracy loss budget of 3%, along with their predicted throughput improvements. The actual best parameters (N, s) and their throughput improvements are shown in the last column.

In Table 3, we use Auto-PruMUX to predict parameters for accuracy loss budgets in {0%, 0.5%, ..., 10%} and show the percentage of accuracy loss budgets for which Auto-PruMUX predicts the actual best parameter in its top 3 predictions. Auto-PruMUX is able to predict top parameters in most cases.

Table 3: Percentage of accuracy loss budgets in {0%, 0.5%, ..., 9.5%, 10%} for which Auto-PruMUX predicts the actual best (N, s) parameter in its top 3 predictions.

Related Work Model Compression
Model compression reduces the number of model parameters with minimal loss in task performance. A well-studied method is network pruning, which removes unimportant connections from a network with minimal or no accuracy loss (LeCun et al., 1989;Hanson and Pratt, 1989;Hassibi et al., 1993). Unstructured pruning (Han et al., 2015b,a;Zhu and Gupta, 2017;Frankle and Carbin, 2019;Chen et al., 2020a;Huang et al., 2021;Sanh et al., 2020) does not impose any constraints on the locations of non-zero weights. The resulting network can achieve high sparsity but may not run efficiently on common hardware such as GPUs.
Pruning with distillation objectives has been explored (Sanh et al., 2020; Lagunas et al., 2021). Xia et al. (2022) propose structured pruning with a distillation objective to reduce Transformer parameters by up to 95% and achieve over 10x speedups with small accuracy drops.

Multi-input Multi-output Models
Multi-input multi-output models concurrently process multiple inputs within one neural network to exploit network over-parameterization. Havasi et al. (2021) and Ramé et al. (2021) train independent sub-networks and ensemble them into a multi-input multi-output model to obtain better accuracy and uncertainty estimation at an inference cost similar to a single network. Murahari et al. (2022) propose the data multiplexing technique, which multiplexes multiple input sequences into one input sequence for a Transformer model, leading to up to 18x inference speedup. Murahari et al. (2023) develop pre-trained multiplexed language models to improve model throughput.

Performance Modeling
Various methods have been proposed to estimate the performance of machine learning models. Justus et al. (2018) propose a method to predict CNN execution time for training: they decompose CNN training into several components, estimate the time for each component, and predict the model execution time as the combination of the components. Qi et al. (2017) and Cai et al. (2017) predict the performance of deep neural networks based on the models' architecture. Stamoulis et al. (2018) propose predictive models for the power and memory of neural networks executing on GPUs. Machine-learning-based cost models (Chen et al., 2018; Bouzidi et al., 2020) have been explored to predict program running time.
Interpolation (Davis, 1975) is widely used in engineering and science (Oliver and Webster, 1990;Keys, 1981;Lehmann et al., 1999), where function values at discrete data points are collected in experiments and the function values at the intervals between discrete data points are estimated using interpolation methods.

Conclusion
We propose PruMUX, a method that combines model compression and data multiplexing to build high-throughput Transformers. Our implementation of PruMUX makes use of CoFi and DataMUX, and we show that it achieves substantial throughput improvement over either CoFi or DataMUX for a large range of accuracy thresholds.
We conclude that the reason PruMUX performs well over a certain range of accuracy loss budgets is that CoFi and DataMUX improve the throughput of a model along two different dimensions: reducing the latency of an inference and compressing multiple inferences into one. When the accuracy loss budget is large, both methods individually suffer non-linear drops in model accuracy; PruMUX can achieve much better performance than either approach because it uses more conservative parameters for CoFi and DataMUX, keeping each method before its bad trade-off point.
We also present Auto-PruMUX, a meta-model to automatically predict high-performance parameter combinations for a desired accuracy loss budget on a task. We show it is promising at predicting high-performance parameters without training and testing models at every parameter combination.

Limitations
Our experiments are limited to three DataMUX pre-trained models (N = 2, 5, and 10) due to compute constraints. More pre-trained models with different N's would provide PruMUX with more options to improve throughput and would allow us to conduct a more detailed evaluation of Auto-PruMUX.
PruMUX uses CoFi as its model compression method. Experiments with other methods could improve our understanding of the interactions between model compression and data multiplexing.

A.2 Datasets

Table 6 shows the sizes and metrics of the datasets in our experiments.

A.3 Potential Risks
Multiplexing with model compression may lead to information leakage between different instances, which can potentially raise privacy concerns if used in a public API serving these models.