Domain Aligned Prefix Averaging for Domain Generalization in Abstractive Summarization

Domain generalization remains underexplored in abstractive summarization. Moreover, most existing works on domain generalization rely on sophisticated training algorithms. In this paper, we propose a lightweight, weight-averaging-based approach, Domain Aligned Prefix Averaging (DAPA), to domain generalization for abstractive summarization. Given a number of source domains, our method first trains a prefix for each of them. These source prefixes then generate summaries for a small number of target domain documents. The similarity of the generated summaries to their corresponding documents is used to calculate the weights required to average the source prefixes. In DAPA, prefix tuning allows for lightweight finetuning, and weight averaging allows for the computationally efficient addition of new source domains. When evaluated on four diverse summarization domains, DAPA shows comparable or better performance than the baselines, demonstrating the effectiveness of its prefix averaging scheme.


Introduction
Abstractive document summarization aims to distill the most crucial information in a document to present a concise view of it (Nallapati et al., 2016; See et al., 2017). This document may take the form of a news article (Hermann et al., 2015), a scientific paper (Yasunaga et al., 2019), a dialogue (Gliwa et al., 2019), or a social media post (Kim et al., 2018). The advent of pretrained models (Raffel et al., 2020; Lewis et al., 2020) has significantly improved abstractive summarization on several of the aforementioned domains. However, these approaches require extensive manual labelling of data, which limits their use in domains without any labelled data. Given that real-world applications of summarization often face the problem of adapting to new domains, it becomes crucial to develop summarization systems that do well in data-scarce settings by leveraging information from the source domains.
Domain generalization is the task of learning, from a set of source domains, a model that is robust on unseen domains. This problem is closely related to transfer learning, multitask learning and domain adaptation, all of which involve learning a model from a set of source tasks/domains to perform well on a set of target tasks/domains. However, in the case of domain generalization, labelled data for the target domain is unavailable. Previous works on domain generalization mainly focus on learning domain-invariant features (Gulrajani and Lopez-Paz, 2020; Wang, 2020; Li et al., 2018). Such methods work well on classification tasks, where learning domain-invariant features is sufficient to predict target classes. However, they may be insufficient for language generation tasks, for which writing style and grammar are essential ingredients (Vu et al., 2022). Moreover, such methods involve sophisticated training algorithms and cannot be used for lightweight domain generalization.
Prefix tuning (Li and Liang, 2021) is a lightweight approach to adapting pretrained language models to downstream tasks. It augments the transformer self attention with prefix tokens learned through backpropagation on the task data while keeping the pretrained model's parameters frozen. Prompt tuning based approaches have been shown to do well on lifelong learning (Qin and Joty, 2021) and zero-shot domain adaptation (Zhao et al., 2022), which inspires us to adapt them to domain generalization for abstractive summarization. Concurrently, weight averaging has performed well on domain generalization tasks in Computer Vision (Cha et al., 2021; Ramé et al., 2022; Arpit et al., 2021). To improve functional diversity, these methods average model parameters from different runs and/or checkpoints. Matena and Raffel (2021) applied weight averaging to NLP tasks. They merged models trained on different tasks/domains through Fisher weight averaging. Their promising results motivate us to apply model merging through weight averaging for domain generalization. Keeping in mind the goal of a lightweight and parameter-efficient approach, and the benefits brought by weight averaging, we propose Domain Aligned Prefix Averaging (DAPA), a lightweight approach to domain generalization for abstractive summarization. Our algorithm consists of three stages. First, a prefix is trained for each source domain. In the second stage, these source prefixes generate summaries for a small number of unlabelled target domain documents. In the third stage, the target domain prefix is obtained through a weighted average of the source prefixes. A higher document-summary similarity score, calculated from the summaries generated in the second stage, assigns a greater weight to the corresponding source prefix. Through our prefix averaging scheme, we can identify the source prefixes essential to ensuring good performance on the target domain. Our extensive experimentation on four domains demonstrates the
benefits brought by DAPA. DAPA comes with the following advantages: i) It is a lightweight approach to domain generalization, since the backbone pretrained model is frozen and only the source prefixes are trained. ii) Through our novel prefix averaging scheme, DAPA generalizes well to target domains. Moreover, freezing the backbone model's parameters further preserves generalization. iii) Our approach supports the efficient addition of new source domains, since doing so only involves recomputing the prefix averaging weights.
To this end, we summarize our contributions as follows:
• To the best of our knowledge, we are the first to explore prefix averaging for domain generalization on a language generation task.
• We propose Domain Aligned Prefix Averaging (DAPA), a lightweight approach to domain generalization for abstractive summarization. DAPA first trains a prefix for each source domain, after which it utilizes the summary generation capabilities of these source prefixes to generalize to the target domain.
• Through our experimental setup, we demonstrate the effectiveness of DAPA on domain generalization for abstractive summarization.
The rest of the paper is structured as follows: We explore related works in Section 2. Section 3 describes our proposed approach, DAPA. Sections 4 and 5 provide results for our domain generalization experiments and a set of analyses we conduct on our approach. Section 6 provides the conclusion and thoughts for future work. Section 7 discusses limitations, and Section 8 sheds light on the ethical risks of our work.

Abstractive Summarization
Abstractive document summarization aims to distill the most critical information in a document to present a concise view of it. Nallapati et al. (2016) used an RNN-based sequence-to-sequence model for abstractive summarization. See et al. (2017) used pointer generator networks to copy words from the input document. Duan et al. (2019) augmented the transformer architecture with a contrastive attention mechanism to ignore the irrelevant parts of the document. Zhang et al. (2020) pretrained a transformer model for summarization. Liu and Liu (2021) used a contrastive loss for better re-ranking of summaries generated by pretrained models. Paulus et al. (2018) proposed the use of policy gradient reinforcement learning to alleviate exposure bias. Gehrmann et al. (2018) developed a bottom-up copy attention mechanism to determine the phrases in the document that should be included in the summary. Although great progress has been made in advancing the state of the art, few works have explored domain generalization for abstractive summarization. In this work, we develop a lightweight prefix-averaging-based method for domain generalization in abstractive summarization.

Prompt Learning in Language Generation
Prompts are task-specific instructions prepended to the pretrained model's input. These task-specific instructions are trained on the downstream task data while keeping the pretrained model's parameters frozen. Li and Liang (2021) proposed deep continuous prefixes that are prepended to the self attention layers of transformers. They demonstrated the effectiveness of using prefixes for language generation tasks such as summarization. Qin and Joty (2021) prepended prompts to model embeddings for lifelong learning on language generation tasks. Similar to our approach, they train a separate prompt for each domain; however, their work does not focus on domain generalization. Tan et al. (2021) developed a multistage prompting network for machine translation, where the encoder is prompted twice to refine the input representation and the decoder is prompted once to generate the translation. Schick and Schütze (2020) used manually crafted templates for fixed-prompt tuning of pretrained models for few-shot summarization. Zhao et al. (2021) and Dou et al. (2021) used learnable prompts as guiding instructions for summarization. Zhao et al. (2022), similar to our approach, used prefixes to adapt to the target domain. However, their approach involves pretraining prefix weights and cannot easily incorporate new source domains. In contrast, our work takes a weighted average of source prefixes and can easily add new source domains to the mix.

Weight Averaging and Domain Generalization
Domain generalization is the task of learning, from a set of source domains, a model that is robust on unseen domains. It has mostly been studied in Computer Vision. The main approaches are based on invariant feature learning (Li et al., 2018; Wang, 2020), data augmentation (Wang et al., 2020a), and meta learning (Wang et al., 2020b). For language generation, Vu et al. (2022) proposed a leave-one-domain-out strategy to fuse adaptors for machine translation.
Recently, weight averaging (Garipov et al., 2018) has been successfully applied to domain generalization. Cha et al. (2021), Ramé et al. (2022), and Arpit et al. (2021) average model parameters across different runs and/or checkpoints to improve generalization. Unlike these approaches, we develop a novel mechanism to generate weights for averaging source prefixes, and evaluate our approach on abstractive summarization. Matena and Raffel (2021) also developed a scheme to average model weights. They utilized the Fisher information in model parameters for weight averaging. However, they only evaluated on NLU tasks.

Method
We first describe the domain generalization problem in Section 3.1. Then we describe prefix tuning in Section 3.2. Finally, in Section 3.3, we describe our proposed approach, DAPA.

Problem Definition
Let D^S = {D^S_1, D^S_2, ..., D^S_n} be the set of source domains. We denote the target domain by D^T. Domain generalization aims to seek a network that generalizes well to D^T when trained on D^S. We require our model to generate fluent summaries for target domain documents when trained on source domain documents.

Prefix Tuning
We utilize prefix tuning to train a separate prefix for each source domain. We begin by restating the transformer attention:

Attention(Q, K, V) = softmax(QK^T / √d) V    (1)

Here, the query matrix Q, the key matrix K, and the value matrix V are obtained through independent linear transformations on the output of the previous layer/encoder, and d is the model dimension.
Note that we omit the multihead notation for clarity.
Prefix tuning modifies the transformer attention by adding tunable prefixes to K and V. Consequently, K and V are modified as

K' = [h_K; K],  V' = [h_V; V]

where h_K and h_V represent the key prefix and the value prefix respectively, and [·;·] denotes concatenation along the sequence dimension.
Following Li and Liang (2021), we model these prefixes using a two-layer MLP as follows:

[h^j_K; h^j_V] = MLP(E_j)    (2)

where E_j ∈ R^{C×d} is a trainable embedding matrix with C as the prefix length. Index j corresponds to source domain D^S_j. We detail the initialization of E_j in Section 4.3. Each source prefix is trained in an end-to-end fashion on its corresponding source domain data.
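The prefix-augmented attention can be sketched in a few lines of numpy. This is an illustrative single-head version with randomly initialized matrices (the multihead case applies the same operation per head, and the real prefixes come from the MLP reparameterization above), not the actual T5 implementation:

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def prefix_attention(Q, K, V, h_K, h_V):
    """Single-head attention with a length-C key/value prefix prepended:
    queries attend over [h_K; K] and the output mixes [h_V; V]."""
    K_aug = np.concatenate([h_K, K], axis=0)   # (C + L, d)
    V_aug = np.concatenate([h_V, V], axis=0)   # (C + L, d)
    d = Q.shape[-1]
    scores = Q @ K_aug.T / np.sqrt(d)          # (L_q, C + L)
    return softmax(scores, axis=-1) @ V_aug    # (L_q, d)

rng = np.random.default_rng(0)
d, L, C = 8, 5, 3
Q, K, V = (rng.normal(size=(L, d)) for _ in range(3))
h_K, h_V = rng.normal(size=(C, d)), rng.normal(size=(C, d))
out = prefix_attention(Q, K, V, h_K, h_V)
print(out.shape)  # (5, 8)
```

Note that the output shape is unchanged by the prefixes; only the attention context is extended by C positions.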

Domain Aligned Prefix Averaging
Having formulated our problem and described prefix tuning, we now describe our approach, DAPA.

Computing Weights to Average Source Prefixes
DAPA utilizes the summary generation capabilities of the source prefixes to generate the weights for averaging them. Let D^T_{m,sample} = {x_1, x_2, ..., x_m} be a set of m unlabelled documents from the target domain. In our experiments, we observe that a value of m as small as 50 suffices. Let P^S = {P^S_1, P^S_2, ..., P^S_n} represent the set of source prefixes, where P^S_j = {h^j_K, h^j_V}. For a target domain document x_i, DAPA first generates n summaries, one per source prefix, as follows:

y^j_i = M(x_i; P^S_j)    (3)

where M represents the frozen pretrained language model.
Next, it uses an encoder f to generate sentence representations for the summary y^j_i as r^j_i = f(y^j_i) and for the document x_i as t_i = f(x_i). We use SentenceBERT (Reimers and Gurevych, 2019) as our encoder. Following this, we compute average document-summary cosine similarity scores for each source prefix as follows:

s_j = (1/m) Σ_{i=1}^{m} cos(t_i, r^j_i)    (4)

The final weights for averaging source prefixes are generated by taking a softmax over the average document-summary similarity scores:

w_j = exp(s_j) / Σ_{k=1}^{n} exp(s_k)    (5)
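Equations 4 and 5 can be sketched as follows. The embeddings here are random stand-ins for SentenceBERT outputs, purely for illustration:

```python
import numpy as np

def cosine(a, b):
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b) + 1e-12))

def prefix_weights(doc_embs, summary_embs):
    """doc_embs: (m, d) document representations t_i.
    summary_embs: (n, m, d) representations r^j_i of the summaries produced
    by each of the n source prefixes. Returns the n averaging weights w_j."""
    # Equation 4: average document-summary cosine similarity per prefix
    scores = np.array([
        np.mean([cosine(t, r) for t, r in zip(doc_embs, per_prefix)])
        for per_prefix in summary_embs
    ])
    # Equation 5: softmax over the n average similarity scores
    e = np.exp(scores - scores.max())
    return e / e.sum()

rng = np.random.default_rng(1)
m, n, d = 4, 3, 16
docs = rng.normal(size=(m, d))
summaries = rng.normal(size=(n, m, d))
summaries[2] = docs            # prefix 2 "matches" the documents perfectly
w = prefix_weights(docs, summaries)
print(w.argmax())              # prefix 2 receives the largest weight
```

A prefix whose summaries closely track the documents receives a sharply larger weight; the softmax keeps the weights positive and summing to one.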

Prefix Averaging
Given a target domain document x, we wish to generate a summary y with a target prefix obtained by averaging the source prefixes using the weighting scheme described in Section 3.3.1. Using W = {w_1, w_2, ..., w_n}, we take a weighted average of the source prefixes as follows:

P^T = Σ_{j=1}^{n} w_j P^S_j    (6)

where P^T is the target prefix through which the target summary y = M(x; P^T) is generated. Note that test-time averaging requires recomputing h^j_K and h^j_V by replacing E_j with E^T in Equation 2. We detail the computation of E^T in Section 4.3.
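The weighted average itself is elementwise over the prefix parameters. A minimal sketch, assuming each prefix is materialized as key/value arrays (in practice, as noted above, the key and value prefixes are recomputed from the replaced embedding matrix through Equation 2's MLP):

```python
import numpy as np

def average_prefixes(weights, source_prefixes):
    """Weighted average of source prefixes. Each prefix is a dict holding the
    key/value prefix tensors {'h_K': (C, d), 'h_V': (C, d)}."""
    return {
        name: sum(w * p[name] for w, p in zip(weights, source_prefixes))
        for name in ('h_K', 'h_V')
    }

C, d = 2, 4
p1 = {'h_K': np.ones((C, d)), 'h_V': np.zeros((C, d))}
p2 = {'h_K': np.zeros((C, d)), 'h_V': np.ones((C, d))}
pt = average_prefixes([0.25, 0.75], [p1, p2])
print(pt['h_K'][0, 0], pt['h_V'][0, 0])  # 0.25 0.75
```

Because only this averaging step depends on the weights, adding a new source domain requires training its prefix and rerunning the weight computation, never retraining the backbone.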

Dataset and Metrics
We use four summarization datasets, each belonging to a different domain. For the news domain we use the CNN/Daily Mail dataset (Hermann et al., 2015); Samsum (Gliwa et al., 2019) for the chat domain; Reddit posts (Kim et al., 2018) for the social-media domain; and ScisummNet (Yasunaga et al., 2019) for training on the scientific domain, with Cl-SciSumm (Jaidka et al., 2018) for testing on it. Dataset statistics are presented in Table 1. For evaluation, we report the ROUGE-1, ROUGE-2 and ROUGE-L metrics (Lin, 2004).

Baselines
We use empirical risk minimization (ERM) as our primary baseline. It trains the model by minimizing the sum of errors across source domains and examples. For computer vision tasks, Gulrajani and Lopez-Paz (2021) have shown that a well-tuned ERM baseline performs competitively with several sophisticated methods for domain generalization. We define two variants of ERM: i) ERM-finetune, which finetunes the pretrained language model on a combination of all source domains; ii) ERM-prefix, which prefix-tunes the backbone language model on a combination of all source domains. To validate the efficacy of our weight averaging scheme, we create two variants of DAPA, namely DAPA-average and DAPA-max. DAPA-average sets w_j = 1/n, and DAPA-max takes a max pooling operation over the source prefixes instead of averaging them. We also present results for an instantaneous version of DAPA, DAPA-inst. Here, instead of using D^T_{m,sample}, we use the current test document to compute the weights W. We also consider four additional baselines as an upper bound to our method, results for which are presented in Appendix A.
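The ERM objective shared by both baseline variants can be sketched as follows; the function names are illustrative, not from the paper, and a scalar model stands in for the language model:

```python
def erm_loss(model, loss_fn, source_domains):
    """ERM: average the per-example loss over all examples pooled from all
    source domains, ignoring which domain each example came from."""
    total, count = 0.0, 0
    for examples in source_domains:        # one list of (x, y) pairs per domain
        for x, y in examples:
            total += loss_fn(model(x), y)
            count += 1
    return total / count

# Toy check with a scalar "model" and squared error
model = lambda x: 2 * x
sq_err = lambda pred, y: (pred - y) ** 2
domains = [[(1, 2), (2, 4)], [(3, 7)]]    # second domain deviates slightly
print(erm_loss(model, sq_err, domains))
```

Pooling all domains into one loss is what makes ERM simple but domain-agnostic, which is exactly the property DAPA's weighting scheme is designed to improve upon.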

Training Details
For our backbone pretrained model, we use T5-small (containing roughly 60M parameters) (Raffel et al., 2020). Prefix tuning adds roughly 922K test-time parameters to T5-small. The prefix length C is fixed to 50 unless otherwise specified; m is also set to 50. We verify our choices for C and m in Section 5. Both finetuning and prefix tuning experiments are optimized with Adafactor (Shazeer and Stern, 2018). Finetuning uses a maximum learning rate of 5e−4, a square root decay schedule, and a linear warmup of 5000 steps. Prefix tuning uses a constant learning rate of 5e−3. All other Adafactor-specific hyperparameters are left at their default values in HuggingFace-transformers (Wolf et al., 2020). We utilize OpenPrompt (Ding et al., 2022) and HuggingFace-transformers to implement prefix tuning, and use the sentence-transformers implementation of SentenceBERT.
For finetuning, we employ a batch size of 5 with gradient accumulation over 5 iterations. For prefix tuning, we use a batch size of 5 without any gradient accumulation. All our experiments are run on a single Nvidia RTX 2080 Ti machine. One finetuning weight update (via gradient accumulation) takes roughly 224 milliseconds, and one prefix tuning iteration takes roughly 139 milliseconds. All our models are trained for 10 epochs with early stopping based on validation ROUGE scores. For ERM-finetune and ERM-prefix, the training process is stopped if the in-domain validation scores for any of the three source domains start to fall. Each training experiment is carried out only once.
For prefix tuning, we initialize E^T with the T5 embeddings of the C most frequent sentencepiece (Kudo and Richardson, 2018) tokens of D^T_{m,sample}. For DAPA-inst, the same process is applied, but the current test document is used instead of D^T_{m,sample}. For source domains, the C most frequent tokens are extracted from the train set.
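The token-frequency step of this initialization can be sketched as below. A whitespace tokenizer stands in for sentencepiece, so this is only illustrative; the real pipeline would tokenize with the T5 sentencepiece model and then look up the embedding rows for the selected tokens:

```python
from collections import Counter

def most_frequent_tokens(docs, C, tokenize=str.split):
    """Return the C most frequent tokens across docs (whitespace split
    stands in for sentencepiece here)."""
    counts = Counter(tok for doc in docs for tok in tokenize(doc))
    return [tok for tok, _ in counts.most_common(C)]

docs = ["the cat sat", "the dog sat", "the cat ran"]
toks = most_frequent_tokens(docs, 2)
print(toks)
```

The resulting token list determines which embedding rows seed E^T, giving the target prefix some prior target-domain vocabulary before averaging.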
Summary generation uses a beam size of 10 and a repetition penalty of 2.5. All source documents are truncated to 512 sentencepiece tokens and all summaries are truncated to 200 sentencepiece tokens. For the scientific domain, we only include the document's abstract, introduction and conclusion in our input, to present the document's most crucial aspects within T5's maximum allowed sequence length. All our results are presented on the test sets of the four domains.

Main Results
Table 2 presents the results of our domain generalization experiments. DAPA outperforms all compared methods on the chat domain, and on two out of the three ROUGE scores on the news domain. Owing to our prefix averaging method, DAPA demonstrates better generalization capabilities than ERM-finetune and ERM-prefix. DAPA-max and DAPA-average do not utilize the summary generation capabilities of the source prefixes and thus fail to account for the aspects most crucial to summary generation for these two domains. Also, DAPA-inst performs significantly worse than DAPA, emphasizing the importance of using a greater number of target domain documents to better approximate the weights for averaging source prefixes.
On the contrary, ERM-finetune outperforms DAPA on the scientific domain. Yasunaga et al. (2019) demonstrate the superior performance of extractive approaches over abstractive summarization approaches for the scientific domain. Thus, the ability to copy phrases from the input document becomes imperative for good performance.
Possibly, since all model parameters are tuned for ERM-finetune, its ability to copy phrases from the input document exceeds that of DAPA, which only tunes the source prefixes. Despite this, DAPA outperforms the other compared methods for reasons similar to those stated previously. Also, all compared methods outperform DAPA on the social-media domain. We find that DAPA allocates a significant amount of weight to the scientific domain prefix. However, we observe that the scientific domain adversely affects performance on the social-media domain (refer to Section 5.1 for more details). This irregularity may have resulted from the encoder f; further investigation into it is left for future work. Only 27.94% of the maximal weights selected by DAPA-max belong to the scientific domain prefix. Also, DAPA-average assigns equal weights to all three source domains, which is less than the weight assigned by DAPA to the scientific domain. Thus, both DAPA-average and DAPA-max outperform DAPA. The noise added by DAPA-inst to the weight calculation process results in a smaller weight assigned to the scientific domain prefix (0.35 vs 1.00), as a result of which DAPA-inst outperforms DAPA. Owing to the larger dataset sizes for the chat and news domains, ERM-prefix and ERM-finetune are less impacted by the adverse effects of the scientific domain and thus outperform DAPA. ERM-finetune underperforms ERM-prefix, probably because of its larger capacity to retain scientific domain knowledge.

Analysis
To study the impact of various factors in DAPA, and to better understand its efficacy, we conduct a series of analyses.

Effect of Source Domains on the Target Domain
In Table 3 and Table 4, we analyze the effect of various source domains on the target domain.
Throughout these experiments, m = 50 and C = 50. Note that the experiments in this section only require recomputation of the prefix averaging weights, supporting the claim that DAPA allows for the computationally efficient addition of new source domains. For the scientific domain, we can see that performance is best when only using the news domain; in fact, it outperforms ERM-finetune from Table 2. Adding the chat and social-media domains only hampers performance. Performance is worst when using only the social-media domain. The improvement over this result indicates that DAPA is able to assign appropriate weights to the three source domains, allowing for a greater contribution from the news and chat domains. In real-world applications, where the number of domains is significantly greater than in our setting and there is no labelled data to measure performance on the target domain (to decide the optimal set of source prefixes to be averaged), DAPA offers an effective scheme for choosing the most appropriate source prefixes. In the case of the chat domain, ablating the scientific domain results in degraded performance. On the contrary, excluding the other two prefixes does not affect DAPA's performance. Similarly, for the news domain, removing the scientific domain's prefix results in a significant drop in all three ROUGE scores, and removing the chat domain results in a slight drop in the ROUGE-1 score. Here, we see that both the chat and the scientific domain contribute to the performance, since using only one of them underperforms the result in Table 2.
In the case of the social-media domain, excluding the scientific domain significantly improves DAPA's performance. Also, both the news and the chat domain contribute to the performance, since using only one of them underperforms their weighted average.

Effect of Prefix Length C
An analysis of the effect of C is presented in Figure 2. The C most frequent words in the target domain are extracted from m = 50 documents. In general, we observe that the target domain performance increases up to C = 50, after which it either drops or remains more or less the same. Thus, we select C = 50 for DAPA. The scientific domain's ROUGE scores drop significantly beyond 50 prefix tokens; we leave further exploration of this for future work.

Effect of m on W
An analysis of the effect of the number of documents used for computing the weights to average source prefixes is presented in Figure 3. Beyond m = 20, performance on the target domain remains more or less constant. Thus, we stick to our initial choice of using 50 documents to compute the weights W.
An analysis of the effect of the number of documents used for obtaining the C most frequent tokens is presented in Figure 4. Again, performance does not vary significantly beyond m = 20. Thus, we hold to our initial choice of m = 50 for our main experiments. Note that we do not include results on the scientific domain in this subsection, since its test set has only 10 instances and we use all of them in our experiments.

Does Averaging over Source E j s Help?
In our main method, we initialize E^T with the C most frequent sentencepiece tokens from the m unlabelled documents. Here, we explore an alternative way of initializing E^T, wherein we use the weights w_j to take a weighted average of the source embedding matrices E_j:

E^T = Σ_{j=1}^{n} w_j E_j

Results for this initialization scheme (DAPA-embed) are presented in Table 5. DAPA outperforms DAPA-embed across domains, demonstrating the benefit of supplying P^T with some prior target domain knowledge by initializing E^T with the C most frequent sentencepiece tokens from the m unlabelled documents.

Does Softmax before Summation Help?
Here, we propose an alternative way of computing w_j. Instead of summing over the document-summary cosine similarities (Equation 4) and then applying the softmax operation (Equation 5), we first apply the softmax operation to the document-summary similarity scores, after which we average over them. That is, we replace Equation 4 with:

s^i_j = exp(cos(t_i, r^j_i)) / Σ_{k=1}^{n} exp(cos(t_i, r^k_i))

and Equation 5 with:

w_j = (1/m) Σ_{i=1}^{m} s^i_j

By doing so, we flatten the weights w_j. This is evident from Figure 5c, where the target domain assigns near equal weights to all three source domains. This differs from DAPA, where the weights w_j are sharp, as depicted in Figure 5a. In Table 6, DAPA outperforms this alternative strategy (DAPA-alt) on three out of the four domains. A flattened w_j distribution approaches DAPA-average and thus does not benefit from DAPA's weight averaging scheme. DAPA-alt outperforms DAPA on the social-media domain since it assigns near equal weights to each source domain (refer to Section 4.4 for a detailed analysis).

Table 6:
ROUGE scores for the alternative way of computing w_j discussed in Section 5.5, i.e., DAPA-alt.
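The two orderings of averaging and softmax can be contrasted in a short numpy sketch; the similarity matrix here is made up for illustration:

```python
import numpy as np

def weights_sum_first(sims):
    """DAPA: average similarities over documents first, then softmax over
    the n source prefixes."""
    s = sims.mean(axis=0)
    e = np.exp(s - s.max())
    return e / e.sum()

def weights_softmax_first(sims):
    """DAPA-alt: softmax over the n prefixes separately for each document,
    then average the per-document distributions over the m documents."""
    e = np.exp(sims - sims.max(axis=1, keepdims=True))
    per_doc = e / e.sum(axis=1, keepdims=True)   # softmax per document
    return per_doc.mean(axis=0)                  # average over documents

# sims[i, j]: cosine similarity of document i with prefix j's summary
sims = np.array([[0.9, 0.1, 0.1],
                 [0.8, 0.2, 0.1]])
print(weights_sum_first(sims).round(3))
print(weights_softmax_first(sims).round(3))
```

Both produce valid probability vectors over the source prefixes; the paper's empirical finding is that the softmax-first variant tends toward flatter weights.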

Is DAPA correlated to Document Similarity?
We analyze how DAPA's weight assigning process aligns with source-target domain similarity at the document level. For this, we train BERT-base (Devlin et al., 2019) on 300 source domain documents (100 from each source domain) for source domain identification. We evaluate this model on the target domain's D^T_{50,sample}. We plot the average probabilities assigned to each source domain in Figure 5b. Also, in Figure 5a, we plot the weights W computed by DAPA. Note that BERT-base achieves perfect test accuracy when evaluated on an in-domain validation set for each training setting. Entries of the form [x, x] are always zero, owing to our domain generalization setting.
From the two plots, it is clear that DAPA's weight assignment process does not always correlate with source-target domain similarity at the document level. For the news domain, as per Table 3 and Table 4, the scientific domain contributes most to the model's performance; on the other hand, BERT assigns near equal probability to the social-media and chat domains, and near zero probability to the scientific domain. Similarly, for the chat domain, the scientific domain is vital to good performance, yet BERT assigns a low probability to it. For the social-media and scientific domains, BERT does a better job and assigns a higher probability to the news domain. These results lead us to conclude that DAPA does not rely on document level similarities and instead relies on the summary generation capabilities of the source prefixes.

Conclusion
In this paper, we present DAPA, a lightweight, domain aligned prefix averaging approach to domain generalization in abstractive summarization. DAPA utilizes source prefixes to generate summaries for a small number of target domain documents. The similarity of these summaries to their corresponding documents is used to calculate the weights required to average the source prefixes. DAPA can easily accommodate the addition of new source domains, since only the prefix averaging weights need to be recomputed. On four diverse summarization domains, DAPA performs comparably to or outperforms the baselines. We also perform an in-depth analysis of the various components of DAPA to further support our design choices. In the future, we would like to develop an improved similarity function f and analyze the loss landscapes of these models to corroborate our prefix averaging strategy.

Limitations
Our work focuses on domain generalization for abstractive summarization through prefix averaging. However, we do not experiment with larger backbone models due to computational constraints. Based on previous works, we expect our approach's performance to improve with model size. Also, a larger prefix sequence length increases computational costs at inference.
Another limitation of our work is that we do not test it on natural language understanding tasks. This can be part of future work.

Ethical Statement
We consider our approach to carry low ethical risk, since it does not exploit any biases in the data. Our approach could be extended to any natural language generation task and does not constrain the input/output structure. We therefore conclude that our method would not bring any harmful ethical impact.

Figure 1 :
Figure 1: Overview of the Domain Aligned Prefix Averaging model. Source prefixes are used for generating summaries for m target domain documents. The similarity of these summaries to their corresponding documents is used for computing the weights required to average the source prefixes.

Figure 2 :
Figure 2: Variation of ROUGE scores with prefix length C. Empirically, C = 50 is the optimal prefix length. Throughout this experiment, m = 50.

Figure 3 :
Figure 3: Variation of ROUGE scores with the number of documents used for computing the weights W. Throughout this experiment, C = 50 and E^T is initialized with the 50 most frequent tokens from D^T_{m,sample}.

Figure 4 :
Figure 4: Variation of ROUGE scores with the number of documents used for deriving the C most frequent words. Throughout this experiment, C = 50.
(c) Probabilities assigned to the source domains by DAPA-alt.

Figure 5 :
Figure 5: Source domain preferences for the target domain obtained through DAPA, a BERT model trained for source domain identification and DAPA-alt.

Table 1 :
Dataset statistics for each domain.

Table 3 :
ROUGE scores when using only two source domains to compute W. The second column names the domain whose prefix has been left out of the target domain's prefix computation.

Table 4 :
ROUGE scores when using a single source domain's prefix to generate target domain summaries. The second column names the source domain adopted to generate summaries on the target domain.