DEMix Layers: Disentangling Domains for Modular Language Modeling

We introduce a new domain expert mixture (DEMix) layer that enables conditioning a language model (LM) on the domain of the input text. A DEMix layer is a collection of expert feedforward networks, each specialized to a domain, that makes the LM modular: experts can be mixed, added or removed after initial training. Extensive experiments with autoregressive transformer LMs (up to 1.3B parameters) show that DEMix layers reduce test-time perplexity, increase training efficiency, and enable rapid adaptation with little overhead. We show that mixing experts during inference, using a parameter-free weighted ensemble, allows the model to better generalize to heterogeneous or unseen domains. We also show that experts can be added to iteratively incorporate new domains without forgetting older ones, and that experts can be removed to restrict access to unwanted domains, without additional training. Overall, these results demonstrate benefits of explicitly conditioning on textual domains during language modeling.


Introduction
Conventional language model (LM) training algorithms assume data homogeneity: all parameters are updated to minimize the loss on all of the data. We refer to this approach as dense training. Yet human language is as varied as human experience, a fact researchers often refer to obliquely when they use the term domain to describe distinct underlying subpopulations of the corpus. Dense training leaves variation in the data to be implicitly discovered (Aharoni and Goldberg, 2020), assuming that models will be able to fit all domains equally well.
While dense training is convenient, and densely trained LMs achieve impressive results (Brown et al., 2020), the approach has drawbacks with respect to generalization, efficiency, and flexibility. Even if training data is sourced from many domains, dense training can in practice emphasize subsets of the data in proportion to their ease of access (Oren et al., 2019; Fan et al., 2020), limiting generalization to less prevalent domains. Updating all parameters of the network gets substantially more expensive as model size grows (Strubell et al., 2019), making fine-tuning or domain-adaptive pretraining (DAPT) harder to perform with smaller computational budgets. It is also difficult to adapt to new domains without forgetting the original data (McCloskey and Cohen, 1989; Aghajanyan et al., 2021) or to restrict access to certain domains the LM has been exposed to during training (e.g., those that contain hate speech; Bender et al. 2021), leading to risks of unwanted behavior (Gehman et al., 2020).
To address these limitations of dense training, we argue that LMs should be designed with modularity. We propose a modular LM that has components specialized to distinct domains in the training data, and can be customized at inference-time by mixing, adding, or removing these separated components as needed. This design principle emphasizes the ability to rapidly adapt the LM after training, a need that has been broadly advocated for language systems (Dinan et al., 2021;Lazaridou et al., 2021).
We introduce modularity into an LM with a new domain expert (DEMIX) layer that explicitly conditions the LM on the domain of the input text (when it is known), or estimates the input domain during inference (when it is not known). A DEMIX layer is a drop-in substitute for a feedforward layer in a transformer LM (e.g., GPT-3), creating a specialized version of the layer (or expert) per domain (see Figure 1; §3). We find that replacing every feedforward layer in the transformer with a DEMIX layer offers new affordances for modularity, addressing the challenges above, while improving performance in both training domains and novel test-time domains.
Figure 1: Illustration of a DEMIX layer in a single transformer block. During training, expert feedforward networks are conditionally activated based on the domain (here, document provenance) of the input sequence (i.e., scientific papers or court opinions). At inference time, the language model has new modular functions: domain experts can be mixed to handle heterogeneous domains, added to adapt to novel domains, or removed to "forget" unwanted domains. Image attribution: news icon from emojipedia.org; all other icons from istockphoto.com.
Although the concept of a domain lacks a rigorous definition in NLP, we use coarse provenance categories (e.g., whether a document is a medical research paper or a Reddit post) as a conditioning variable when training an LM with DEMIX layers ( §2). Training on data from eight different domains, we find that DEMIX layers consistently improve in-domain performance ( §4). However, because these categories may not be an optimal segmentation of the training data, or may lack coverage of test-time domains, naively selecting a single domain expert at test time can hurt generalization. Instead, we introduce a parameter-free probabilistic approach to dynamically estimate a weighted mixture of domains during inference ( §5). Mixing experts improves DEMIX performance not only on novel test-time domains, but also on test data from the training domains, which may themselves be heterogeneous. Our results suggest that introducing modularity into an LM need not come at a cost to generalization performance.
Because DEMIX forces experts to specialize to domains, the overall model can be (partially) disentangled after training. Beyond mixing, we can add (§6) or remove (§7) domain experts, resulting in predictable changes in model behavior at inference time: adding experts allows for model adaptation without updating all parameters (hence avoiding forgetting), and removing experts allows for simulating the removal of training domains without additional training. Overall, DEMIX layers demonstrate benefits of explicitly conditioning on textual domains during language modeling, and our results suggest that these benefits persist at scale. Our code is publicly available.

Multi-Domain Corpus
We center this study around a large, multi-domain corpus we constructed with explicit provenance metadata (Table 1). While other multi-domain corpora (Koh et al., 2021; Gao et al., 2020) cover many more domains and tasks, the corpus we introduce contains substantial metadata-tagged text for language modeling, as well as datasets with friendly licensing to support reproducibility.

Document Provenance as a Domain Label
While a growing body of work has attempted to address the structure and composition of language domains (Eisenstein et al., 2014; Plank, 2016; Aharoni and Goldberg, 2020), what a domain fundamentally is remains a matter of debate. In this work, we focus on the provenance of a document, operationalized coarsely by the dataset we used to access it, which approximates the social process that produced it. Defining domains this way is easy and intuitive, conveys a great deal about the variation in a document's language, and aligns with common practice in NLP research. However, other accounts of variation in language (e.g., Lucy and Bamman, 2021), and richer notions of relationships among domains (e.g., hierarchies; Gururangan et al., 2020), may be studied in future work.

Corpus Description
The multi-domain corpus we use in this study consists of two parts. The first is a collection of training domains: text from eight domains of largely English text, listed at the top of Table 1, each of which varies in complexity and coverage and has been the subject of study in NLP. 3 The second part is a collection of novel domains: text from eight domains also of largely English text, listed at the bottom of Table 1, which may or may not align with the training domains. The novel domains allow us to measure how models generalize to a more challenging data distribution shift, where domain boundaries may be less clear. See Appendix §A.1 for more details on how these data were collected. To support future work with the data, we also release a standard API to download and preprocess it into a format compatible with Fairseq (Ott et al., 2019). 4 We replace user identifiable information (e.g., email addresses, user handles, social security numbers, credit card numbers, phone numbers) with dummy tokens. 5
3 The metadata for each document includes at least its provenance, and in some cases more information (e.g., URLs, publication venue, or legal jurisdiction). Future work might explore more fine-grained notions of domain.
4 https://github.com/kernelmachine/demix-data
5 While it is difficult to anonymize data perfectly, especially at scale, we use a suite of regexes to identify commonly occurring identifiable information on the Internet. See Appendix §A.2 for more details.

Background: Mixture-of-Experts Transformers
The transformer architecture is composed of interleaved multi-head self-attention, layer-norms, and feedforward networks (Vaswani et al., 2017). Each of these layers produces a vector representation for each of the input tokens. Our focus is on the feedforward component:
$$h_{t,\ell} = \mathrm{FFN}(h_{t,\ell-1}) \qquad (1)$$
where $h_{t,\ell}$ is the vector for the $t$th token produced by layer $\ell$. Prior mixture-of-experts work proposes a formulation of one or more feedforward layers as an ensemble of $n$ experts $\mathrm{FFN}_1, \ldots, \mathrm{FFN}_n$, assigned weights respectively by functions $g_1, \ldots, g_n$:
$$\mathrm{FFN}(h_{t,\ell-1}) = \sum_{j=1}^{n} g_j(h_{t,\ell-1}) \cdot \mathrm{FFN}_j(h_{t,\ell-1}) \qquad (2)$$
The $g$ function routes tokens to different experts, usually each a separate instance of the original feedforward network. If $g$ routes to a single expert, then the computational cost (in floating-point operations; FLOPs) will be the same as the original feedforward network, even though it has slightly more than $n$ times as many parameters.
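The weighted-ensemble view of a feedforward layer described above can be illustrated with a minimal sketch. The toy experts (`make_ffn`), the vector type, and the gate functions below are hypothetical stand-ins chosen for readability, not the networks used in this study; the sketch only shows how gate weights combine expert outputs and why a gate that selects a single expert evaluates only that expert.

```python
from typing import Callable, List

Vector = List[float]
Expert = Callable[[Vector], Vector]
Gate = Callable[[Vector], List[float]]


def make_ffn(scale: float) -> Expert:
    """Hypothetical toy expert: scale the input, then apply a ReLU."""
    def ffn(h: Vector) -> Vector:
        return [max(0.0, scale * x) for x in h]
    return ffn


def moe_ffn(h: Vector, experts: List[Expert], gate: Gate) -> Vector:
    """Weighted ensemble of experts: sum_j g_j(h) * FFN_j(h)."""
    weights = gate(h)
    out = [0.0] * len(h)
    for w, expert in zip(weights, experts):
        if w == 0.0:
            continue  # experts with zero weight are never evaluated
        y = expert(h)
        out = [o + w * yi for o, yi in zip(out, y)]
    return out


experts = [make_ffn(1.0), make_ffn(2.0)]
# A gate that routes everything to the second expert (one expert evaluated).
single = moe_ffn([1.0, -1.0, 0.5], experts, lambda h: [0.0, 1.0])
# A gate that mixes both experts equally (two experts evaluated).
mixed = moe_ffn([1.0, -1.0, 0.5], experts, lambda h: [0.5, 0.5])
```

With the single-expert gate, only one toy expert runs per call, which mirrors the observation that compute stays at the level of one feedforward network while the parameter count grows with the number of experts.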

DEMIX Routing
Previous approaches learn the weighting functions g at the token level, and assign either at most one (Fedus et al., 2021) or two (Lepikhin et al., 2020) experts per token. This necessitates load balancing and other techniques to encourage the model to use all experts instead of relying on just a few (Fedus et al., 2021; Lewis et al., 2021).
We instead use domain metadata provided with training documents to route data to experts at the document (i.e., sequence) level. During training, every token in the same sequence is assigned to the same expert based on the domain label.
Let $D$ denote the set of domain labels (i.e., the eight labels in Table 1). If we index the experts by $D$ and $d \in D$ is the domain label for the current training instance, then
$$g_j(h_{t,\ell}) = \begin{cases} 1 & \text{if } j = d \\ 0 & \text{otherwise} \end{cases} \qquad (3)$$
While we assume that each training document is associated with a single domain label, we relax this requirement at inference time (§5), which improves model performance in mixed and unknown domain scenarios.
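As a rough illustration of document-level routing, the sketch below builds a one-hot gate from a domain label and assigns every token of a sequence to the same expert. The three-label list is an illustrative subset of domain names that appear in this paper, and the function names are hypothetical.

```python
from typing import List

# Illustrative subset of training-domain labels (assumption: any ordered
# list of labels would work the same way).
DOMAINS = ["CS", "MED", "REDDIT"]


def domain_gate(domain: str, domains: List[str]) -> List[float]:
    """One-hot gate: weight 1 for the expert matching the label, else 0."""
    if domain not in domains:
        raise KeyError(f"unknown domain label: {domain}")
    return [1.0 if d == domain else 0.0 for d in domains]


def route_sequence(tokens: List[str], domain: str, domains: List[str]) -> List[int]:
    """Assign every token in the sequence to the expert of its document's domain."""
    expert_index = domain_gate(domain, domains).index(1.0)
    return [expert_index for _ in tokens]
```

Because the gate depends only on the document's label and not on token representations, there is nothing to learn in the router and no load-balancing term is needed in this sketch.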

DEMIX Architecture
Our design results in one expert in a DEMIX layer per domain (i.e., eight experts for eight training domains in our multi-domain corpus).
We replace every feedforward layer in the transformer with a DEMIX layer, in contrast to previous work (Fedus et al., 2021; Lepikhin et al., 2020) that interleaves shared and expert layers. Preliminary experiments showed that interleaving led to worse in-domain performance with DEMIX layers. We hypothesize that shared layers may serve as a bottleneck to find shared features between domains, and may impact performance adversely when training domains are highly different from one another. Future work might perform careful comparisons of different architectural choices.
In this study, each expert $\mathrm{FFN}_j$ is a two-layer MLP with the same dimensions as the original FFN layer of the transformer. As with other conditional computation models (Fedus et al., 2021; Lepikhin et al., 2020), this means that the effective number of parameters in the overall DEMIX LM increases (Table 2). While this incurs memory costs, the computational budget we consider in this study centers around runtime costs. DEMIX layers decrease the runtime costs of training the LM.

DEMIX Training
DEMIX layers increase the total parameters of the LM while also reducing GPU latency costs during training, effectively reducing runtime costs of training the LM.
DENSE training (also referred to as data-parallel) is usually implemented by copying model parameters to every GPU, feeding a different mini-batch of shuffled data to each GPU, computing a stochastic gradient for each mini-batch, and updating all parameters synchronously with the average stochastic gradient from across all GPUs.
To train an LM with DEMIX layers, we instead partition the GPUs among the domains, so that each GPU is assigned a single domain (along with its corresponding expert). During training, we fill a mini-batch with k sequences, where each sequence represents data from a particular domain, and we send each mini-batch to its dedicated domain expert. We use larger batch sizes by performing data-parallel training between expert parameters on GPUs assigned to the same domain; we assign n/8 GPUs to each domain (Table 2). To reduce overfitting, we ensure that each of these n/8 GPUs is assigned to different shards of their domain's training data.
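The GPU partitioning described above can be sketched as simple bookkeeping. This is a minimal sketch under stated assumptions: the domain names `domain_0` through `domain_7`, the total of 32 GPUs, and both helper functions are hypothetical, and no actual distributed communication is performed. Each domain receives an equal share of GPU ranks, and each rank within a domain group is given a distinct shard index of that domain's data.

```python
from typing import Dict, List, Tuple


def partition_gpus(n_gpus: int, domains: List[str]) -> Dict[str, List[int]]:
    """Split GPU ranks into equally sized, contiguous groups, one per domain."""
    if n_gpus % len(domains) != 0:
        raise ValueError("n_gpus must be divisible by the number of domains")
    per_domain = n_gpus // len(domains)
    return {
        d: list(range(i * per_domain, (i + 1) * per_domain))
        for i, d in enumerate(domains)
    }


def shard_for_gpu(rank: int, groups: Dict[str, List[int]]) -> Tuple[str, int]:
    """Return (domain, shard index) so GPUs of one domain read different shards."""
    for domain, ranks in groups.items():
        if rank in ranks:
            return domain, ranks.index(rank)
    raise KeyError(rank)


# Hypothetical setup: eight placeholder domains on 32 GPUs (n/8 = 4 per domain).
DOMAINS = [f"domain_{i}" for i in range(8)]
GROUPS = partition_gpus(32, DOMAINS)
```

In a real training job, each group would correspond to the set of workers that average gradients for one expert's parameters, while shared parameters would still be synchronized across all workers.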
We compare the training efficiency of DENSE and DEMIX models up to 1.3B parameters per GPU in Table 2. Compared to DENSE LMs, DEMIX layers achieve the same or slightly higher throughput (measured in TFLOPs/GPU) for the same total FLOPs per update, despite adding significantly more parameters.
DEMIX achieves higher throughput because we only synchronize expert parameters allocated to the same domain. As we increase model size, this results in a reduction of latency costs between GPUs, and hence, faster training; instead of synchronizing parameters over n GPUs, we perform eight synchronizations over n/8 GPUs. In this work, we assume that there is sufficient data for each training domain that each expert can be exposed to the same amount of data, and load balancing between experts is not necessary. Future work may consider how varying the amount of data per domain influences absolute and relative performance across domains, especially in the long tail of rare domains.
While the total number of parameters of DEMIX LMs is substantially larger than that of their DENSE counterparts, the practical training costs are essentially the same, so we compare baselines in all subsequent experiments based on parameters per GPU, as we do in Table 2.

In-Domain Performance
The first set of experiments in this study considers the impact of replacing the conventional feedforward layers in a transformer LM with DEMIX layers. We run all experiments in this section with the training domains (Table 1).

Experimental Setup
Architecture and Input The model architecture is a randomly-initialized LM with the GPT-3 (Brown et al., 2020) architecture implemented in Fairseq (Ott et al., 2019). We experiment with multiple architectures (i.e., those of GPT-3 small, medium, large, and XL), at a maximum size of 1.3B parameters per GPU (Table 2).
Hyperparameters We set the total number of training steps based on the allocated runtime, set 8% of these steps to be warm-up, and use the Adam optimizer (Kingma and Ba, 2017) with a polynomial learning rate decay. Learning rates are tuned for each model separately over {0.0001, 0.0003, 0.0005}, taking the fastest learning rate that avoids divergence. Each worker processes two sequences of length 1,024, and gradients are accumulated over 8 updates. We clip gradients if their L2 norm exceeds 0.1. See Appendix §A.4 for more details. These settings are inspired by Lewis et al. (2021).
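The optimization settings above can be made concrete with a small sketch. The 8% warm-up fraction, the candidate peak learning rates, the clipping threshold of 0.1, and the per-worker batch shape come from the text; the linear warm-up shape, the decay power of 1.0, and the final learning rate of 0 are assumptions, since the text specifies only "polynomial learning rate decay". The function names are hypothetical and do not correspond to Fairseq options.

```python
import math
from typing import List

# Tokens contributing to one update on a single worker:
# 2 sequences x 1,024 tokens x 8 accumulation steps.
TOKENS_PER_WORKER_UPDATE = 2 * 1024 * 8


def learning_rate(step: int, total_steps: int, peak_lr: float,
                  warmup_frac: float = 0.08, power: float = 1.0,
                  end_lr: float = 0.0) -> float:
    """Linear warm-up to peak_lr, then polynomial decay to end_lr (assumed shape)."""
    warmup_steps = max(1, int(round(warmup_frac * total_steps)))
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    remaining = (total_steps - step) / max(1, total_steps - warmup_steps)
    return end_lr + (peak_lr - end_lr) * max(0.0, remaining) ** power


def clip_by_l2_norm(grads: List[float], max_norm: float = 0.1) -> List[float]:
    """Rescale the gradient vector if its L2 norm exceeds max_norm."""
    norm = math.sqrt(sum(g * g for g in grads))
    if norm <= max_norm:
        return list(grads)
    scale = max_norm / norm
    return [g * scale for g in grads]
```

Under these assumptions, a run with 1,000 total steps would warm up over the first 80 steps and then decay linearly to zero; other decay powers change only the shape of the second phase.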
Computational Budget We follow previous work in using runtime as the primary computational budget, which provides a better comparison of the practical costs of training conditional compute and dense models (Lewis et al., 2021). We assume a fixed budget of about 48 hours on NVIDIA V100 32GB GPUs. We display the number of GPUs used for each model size in Table 2; we chose these GPU budgets because larger models require more compute to train properly (Lewis et al., 2021; Kaplan et al., 2020), and found these GPU budgets to result in stable training for each model size given mostly fixed hyperparameters.

+DOMAIN-TOKEN
Following work that prepends document metadata to the input text (Zellers et al., 2019; Keskar et al., 2019), this baseline provides domain information to the language model in the form of input supervision. We ignore the domain token when computing perplexity during evaluation.

DEMIX (naive)
We replace every feedforward layer in the transformer with a DEMIX layer, as detailed in §3. Under this setting, the domain of the test data is known and revealed to the model (e.g., the CS expert is used for CS test data), which we refer to as naive. We also ensure that the model is exposed to an equal amount of data from each domain.
Next, we observe that the benefits of additional domain information (i.e., domain tokens or DEMIX layers) are clearest for the smallest model; for larger models, the benefits are smaller but consistent. This result suggests that domain-specific information enables the model to better specialize to different domains in its training data. However, as the model size grows, the DENSE baseline becomes increasingly better at fitting the training domains, catching up to models with additional domain information, in the average case.

Domain Heterogeneity
A more complete view of the experiments with the largest model is shown in Table 4. We see that even at scale, most training domains benefit from DEMIX layers in a naive setting (where the domain label is revealed at test time), but some do not; WEBTEXT, REALNEWS, and REDDIT fare worse than the DENSE baseline. We believe that this variation can be explained by heterogeneity within domains and varying degrees of similarity between them. DENSE training may be advantageous for domains that have a higher degree of overlap with other domains in the corpus (and therefore, benefit from parameter sharing).
To provide further evidence for this explanation, we measure the heterogeneity of domains in the multi-domain corpus, according to a DEMIX LM. We plot a matrix of the perplexity changes across all domain experts in Figure 2, comparing all experts against the expert explicitly trained for each domain. The lower the perplexity change, the higher the affinity of the corresponding expert to the target domain.
First, we observe that domain experts have the highest affinity to their assigned domain, indicating that they do specialize. We also observe that some experts, e.g., WEBTEXT, REALNEWS, and REDDIT, have relatively high affinities to many domains, suggesting that these domains are heterogeneous. Separately, we observe that an expert's affinity to a domain correlates positively with bigram overlap between the expert domain and target domain (r=0.40, t=3.45, p=0.001). This further suggests that similar domains have more closely aligned domain experts.
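As a sanity check on the reported correlation statistics, the sketch below computes a Pearson correlation and the usual t-statistic for testing whether it differs from zero, t = r * sqrt(n - 2) / sqrt(1 - r^2). The sample size is not stated in the text; the value n = 64 used in the example (one observation per expert and target domain pair, 8 x 8) is purely an assumption, under which r = 0.40 gives a t-statistic close to the reported 3.45.

```python
import math
from typing import Sequence


def pearson_r(xs: Sequence[float], ys: Sequence[float]) -> float:
    """Pearson correlation coefficient between two equal-length samples."""
    n = len(xs)
    mean_x = sum(xs) / n
    mean_y = sum(ys) / n
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var_x = sum((x - mean_x) ** 2 for x in xs)
    var_y = sum((y - mean_y) ** 2 for y in ys)
    return cov / math.sqrt(var_x * var_y)


def t_statistic(r: float, n: int) -> float:
    """t-statistic for a Pearson correlation r computed from n paired samples."""
    return r * math.sqrt(n - 2) / math.sqrt(1.0 - r * r)


# Assumed sample size: 8 experts x 8 target domains = 64 pairs.
t_assumed = t_statistic(0.40, 64)
```

The small gap between this value and the reported statistic is what one would expect if the correlation was rounded to two decimals before being reported.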
These findings suggest that a discrete notion of domain, while usually helpful on average (in our artificially constructed population of eight training domains), is too rigid. In the next section, we introduce new ways of softening Equation 3 into a mixture over domain experts, to improve performance on heterogeneous domains.

Mixing Experts at Inference Time
The previous section establishes that incorporating DEMIX layers improves LM performance on test data from known training domains. At inference time, the domain label was revealed to the model and used to select an expert within each DEMIX layer. In practice, however, text may not come with a domain label, may straddle multiple domains, or may not belong to any of the domains constructed at training time; the provenance of the data may even be unknown.
In these cases, rather than a hard choice among experts (Equation 3), we propose to treat $g_1, \ldots, g_n$ as mixture coefficients, transforming the domain membership of an input text into a matter of probabilistic belief. Unlike previously proposed mixture-of-experts formulations (Lepikhin et al., 2020), this approach introduces no new parameters and the weights are computed only at test time. 10 To analyze inference-time behavior in mixed or unknown domain scenarios, we turn to the corpus of novel domains in the multi-domain corpus (Table 1). As mentioned in §2, these domains have fuzzier boundaries, compared to the training domains.

Dynamically Estimating Domain Membership
Consider the probabilistic view of language modeling, where we estimate $p(X_t \mid x_{<t})$. We introduce a domain variable, $D_t$, alongside each word. We assume that this hidden variable depends on the history, $x_{<t}$, so that:
$$p(X_t \mid x_{<t}) = \sum_{j=1}^{n} p(X_t \mid x_{<t}, D_t = j) \cdot p(D_t = j \mid x_{<t}) \qquad (4)$$
This model is reminiscent of class-based n-gram LMs (Brown et al., 1992) and their derivatives (Saul and Pereira, 1997).
10 We choose to explore inference-time mechanisms instead of training mechanisms to mix experts because 1) we want to avoid substantially increasing training costs, i.e., GPU communication between domain experts and 2) we want to maintain the modularity of experts. Exploring mechanisms for training expert mixtures while satisfying these desiderata is a rich area for future work.
Figure 3: Illustration of inference with domain expert mixing. For a given input text $x_{<t}$ from CORD-19 (in the figure, "The COVID-19 pandemic is caused by severe acute respiratory syndrome coronavirus-2 (SARS-CoV-2) and has spread worldwide…"), we estimate posterior domain probabilities $p(D_t \mid x_{<t})$, informed by a prior that is either iteratively updated during inference, or is precomputed and cached on held-out data. In this example, the model assigns highest domain probabilities to the medical and news domains. We use these probabilities in a weighted mixture of expert outputs to compute the hidden representation $h_t$.
We have already designed the DEMIX LM to condition on a domain label, giving a form for $p(X_t \mid x_{<t}, D_t = j)$. The modification is to treat $g_1, \ldots, g_n$ as a posterior probability over domains, calculated at each timestep, given the history so far.
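To do this, we apply Bayes' rule:
$$p(D_t = j \mid x_{<t}) = \frac{p(x_{<t} \mid D_t = j) \cdot p(D_t = j)}{p(x_{<t})} = \frac{p(x_{<t} \mid D_t = j) \cdot p(D_t = j)}{\sum_{j'=1}^{n} p(x_{<t} \mid D_t = j') \cdot p(D_t = j')} \qquad (5)$$
The conditional probabilities of word sequences given a domain label, as noted above, are already defined by the DEMIX LM. For the prior over domain labels, we consider three alternatives:
Uniform Fix the prior to be uniform across the known domains.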
Updating Set the prior at timestep $t$ to be an exponentially-weighted moving average of the posteriors from previous timesteps:
$$p(D_t = j) \propto \sum_{t'=1}^{t-1} \lambda^{t-t'} \cdot p(D_{t'} = j \mid x_{<t'}) \qquad (6)$$
During evaluation, this moving average is calculated over the posterior at the end of each sequence block. The decay factor $\lambda$ avoids putting too much weight on calculations made early in the dataset, when posterior calculations are noisier (Appendix §A.6). We performed a small grid search over {0.1, 0.3, 0.5, 1.0} to set the value of $\lambda$, and found that 0.3 worked well for most settings.
Cached If, prior to testing, some data from the test distribution is available, we calculate the posterior over domain labels from that data, and fix the prior to that estimate. Under this setting, we use 100 sequences (i.e., 102,400 tokens) from the development set to estimate the prior, which we found to result in stable posterior probabilities (see Appendix §A.6 for more details).  We display an illustration of the mixture technique in Figure 3.
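The posterior computation and the three priors described above can be sketched as follows. This is a minimal illustration under stated assumptions, not the released implementation: per-domain log-likelihoods of the history are passed in as plain numbers, the moving average is renormalized to sum to one, and the cached prior is taken to be the mean of posteriors computed on held-out sequences (the text does not specify how those posteriors are aggregated). All function names are hypothetical.

```python
import math
from typing import List, Sequence


def posterior(log_likelihoods: Sequence[float], prior: Sequence[float]) -> List[float]:
    """Bayes' rule in log space: p(D=j | history) is proportional to
    p(history | D=j) * p(D=j). At least one prior entry must be positive."""
    scores = [
        ll + (math.log(p) if p > 0.0 else float("-inf"))
        for ll, p in zip(log_likelihoods, prior)
    ]
    top = max(scores)
    unnorm = [math.exp(s - top) for s in scores]  # subtract max for stability
    total = sum(unnorm)
    return [u / total for u in unnorm]


def uniform_prior(n: int) -> List[float]:
    return [1.0 / n] * n


def updating_prior(past_posteriors: Sequence[Sequence[float]], decay: float = 0.3) -> List[float]:
    """Exponentially weighted average of earlier posteriors (most recent weighted most)."""
    if not past_posteriors:
        raise ValueError("need at least one earlier posterior")
    steps = len(past_posteriors)
    acc = [0.0] * len(past_posteriors[0])
    for i, post in enumerate(past_posteriors):
        weight = decay ** (steps - i)
        acc = [a + weight * p for a, p in zip(acc, post)]
    total = sum(acc)
    return [a / total for a in acc]  # renormalization is an assumption


def cached_prior(heldout_posteriors: Sequence[Sequence[float]]) -> List[float]:
    """Assumed aggregation: mean posterior over held-out sequences."""
    count = len(heldout_posteriors)
    return [sum(col) / count for col in zip(*heldout_posteriors)]


def mix_outputs(expert_outputs: Sequence[Sequence[float]], weights: Sequence[float]) -> List[float]:
    """Weighted mixture of expert output vectors."""
    return [sum(w * out[k] for w, out in zip(weights, expert_outputs))
            for k in range(len(expert_outputs[0]))]
```

Since the weights come entirely from likelihoods the experts already define, this procedure introduces no new parameters, which matches the parameter-free framing of the mixture.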

Visualizing Domain Membership
In Figure 4, we plot the posteriors, calculated using the updating method above after 100 sequences of development data, each from training and novel domains. This evaluation is carried out using the DEMIX LM with 1.3B parameters per GPU from §4, with no modifications.
For known domains (top heatmap of Figure 4), the correct label has the highest posterior, but these datasets do not appear to be as distinct or mutually exclusive as we assume. For example, Reddit data is estimated to be around 80% REDDIT, 11% WEBTEXT, and 8% REALNEWS. More variation in the estimates is expected and observed for the new domains (bottom heatmap of Figure 4). While ACL PAPERS is mostly associated with the CS domain, and BREAKING NEWS mostly with the WEBTEXT and REALNEWS domains, CORD-19 is spread across MED, REALNEWS, and 1B; YELP REVIEWS across REVIEWS, WEBTEXT, and REDDIT. The alignment of multiple domains like GITHUB and CONTRACTS primarily to WEBTEXT suggests the benefit of including a relatively heterogeneous domain in training.

Experimental Setup
We experiment with the corpus of novel domains (Table 1) to test out-of-distribution performance. We evaluate the three mixture treatments of DEMIX layers (i.e., uniform, updating, and cached priors) against five baselines. Note that no new models are trained for this experiment beyond those used in §4.

DENSE and DENSE (Balanced)
These are the basic baselines trained as in §4; there is no explicit reasoning about domain.

+DOMAIN-TOKEN
Here test data is evaluated using each domain label token, and we choose the lowest among these perplexity values per test set.

DEMIX (naive)
Similar to +DOMAIN-TOKEN, we evaluate the data separately with each of the eight experts, and report the lowest among these perplexity values per test set.

DEMIX (average)
At every timestep, we take a simple average of the eight experts' predictions.
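The difference between the naive and average baselines can be illustrated with a toy perplexity calculation. This sketch makes a simplifying assumption: each expert is treated as if it produced its own probability for every observed token, and averaging is done over those probabilities, whereas in the model the experts are combined inside each DEMIX layer. The function names and the numbers in the example are hypothetical.

```python
import math
from typing import List, Sequence


def perplexity(token_probs: Sequence[float]) -> float:
    """exp of the average negative log-probability of the observed tokens."""
    return math.exp(-sum(math.log(p) for p in token_probs) / len(token_probs))


def naive_perplexity(per_expert_probs: Sequence[Sequence[float]]) -> float:
    """Evaluate with each expert separately and report the lowest perplexity."""
    return min(perplexity(probs) for probs in per_expert_probs)


def average_perplexity(per_expert_probs: Sequence[Sequence[float]]) -> float:
    """Perplexity under a uniform average of the experts' token probabilities."""
    n = len(per_expert_probs)
    mixed: List[float] = [sum(col) / n for col in zip(*per_expert_probs)]
    return perplexity(mixed)


# Hypothetical probabilities from two experts for the same two tokens.
example = [[0.5, 0.5], [0.25, 0.25]]
```

In this toy case the uniform average is pulled down by the weaker expert, which is the kind of effect a posterior-weighted mixture is meant to avoid.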

Results
Novel Domain Performance Results averaged across the eight novel domains are summarized in Table 5. Ensembling DEMIX experts outperforms DENSE baselines and using experts individually (i.e., the "naive" baseline), and caching a prior before evaluation results in the best average performance. While +DOMAIN-TOKEN is competitive with naively using DEMIX layers in-domain (Table 3), it consistently underperforms DEMIX with a weighted mixture on the novel domains. We observe that ensembling DEMIX experts with a cached prior allows smaller models to match or outperform much larger DENSE models. We also find that weighted ensembling outperforms simple averaging, confirming the importance of sparsity in the expert mixture.
Examining per-domain performance (Appendix §A.5), we find that DEMIX LMs with a cached prior either outperform DENSE baselines or closely match them. The largest improvement against DENSE baselines comes from the TWEETS domain, where DEMIX LMs are on average 67% better across all model sizes. This domain is heterogeneous according to the DEMIX model (Figure 4), confirming the importance of mixing experts for heterogeneous domains. These results demonstrate that conditioning the LM on domains during training need not come at a large cost to generalization to new domains, and in many cases can provide large boosts in performance over DENSE baselines.
In-Domain Performance We can also apply the expert mixture variant of inference (using a cached prior) to the training domains. We find that doing so is beneficial; see the last line of Table 3.
We see improvements in performance across all domains for every scale, though the largest improvements seem to come from heterogeneous domains (across all model sizes, REDDIT improves on average 10.7%, WEBTEXT 2.4%, REALNEWS 1.9%), again confirming our intuition that domain metadata may not perfectly align with the most effective domain boundaries.
[Figure 5: illustration of DEMIX-DAPT; panel text: "Adapt new expert, freezing all other parameters."]

Adaptive Pretraining with New Experts
Domain-adaptive, continued pretraining of a language model (DAPT) is a way to use unannotated, in-domain text to improve task performance. However, for a large model, DAPT with DENSE training (which we refer to as DENSE-DAPT) is expensive and may not be feasible on some computational budgets. Furthermore, DENSE-DAPT may result in forgetting what was learned during earlier training phases, limiting reusability. The modular approach of DEMIX LMs allows the model to avoid forgetting training domains and adapt cheaply: we can train a new expert and add it to the DEMIX layers of the network without updating the other experts or the shared parameters. Because the original model is not changed, forgetting is impossible. We refer to this method of adaptation as DEMIX-DAPT. 12 We display an illustration of DEMIX-DAPT in Figure 5. We instantiate a new expert in each DEMIX feedforward layer and initialize it with the parameters of the pretrained expert nearest to the new domain. We use the posterior calculations from §5 on a held-out sample to choose the most probable expert. We then train the added expert on target data, updating only the new expert parameters. For inference, we use the weighted mixture of domain experts with a cached prior (§5).
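The adaptation recipe above (pick the most probable existing expert, copy its parameters into a new expert, and update only the new expert) can be sketched with a toy layer. Everything in this sketch is a hypothetical stand-in: experts are flat lists of numbers, the `DemixLayer` class and its methods are invented for illustration, the posterior values are made up, and a plain gradient step replaces the actual optimizer.

```python
import copy
from typing import Dict, List, Set


class DemixLayer:
    """Toy layer holding one parameter vector per domain expert."""

    def __init__(self, experts: Dict[str, List[float]]) -> None:
        self.experts = experts
        self.trainable: Set[str] = set()

    def add_expert(self, new_domain: str, init_from: str) -> None:
        """Create a new expert as a copy of an existing one; only it is trainable."""
        self.experts[new_domain] = copy.deepcopy(self.experts[init_from])
        self.trainable = {new_domain}


def nearest_expert(posterior: Dict[str, float]) -> str:
    """Choose the existing expert with the highest posterior on held-out target data."""
    return max(posterior, key=lambda name: posterior[name])


def sgd_step(layer: DemixLayer, grads: Dict[str, List[float]], lr: float) -> None:
    """Update trainable experts only; frozen experts are left untouched."""
    for name in layer.trainable:
        layer.experts[name] = [
            p - lr * g for p, g in zip(layer.experts[name], grads[name])
        ]
```

Because the frozen experts are never written to, their behavior on the original domains is unchanged after adaptation, which is the sense in which forgetting is avoided.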

Experimental Setup
We compare DEMIX-DAPT to DENSE-DAPT on all novel domains. We report final test-set perplexity after adapting to each domain for 1 hour with 8 NVIDIA V100 32GB GPUs, tracking validation perplexity every 10 minutes for early stopping. We adapt to each novel domain with the same hyperparameters as the original phase of training ( §4), except for a 10x smaller learning rate.

Results
Adding one expert We display examples of DEMIX-DAPT and DENSE-DAPT on a single additional domain in Figure 6. We observe that while DENSE-DAPT reduces perplexity on the novel domain, its performance on the training domains progressively worsens, displaying the forgetting effect (we show similar results in larger models in Appendix §A.7). In contrast, DEMIX-DAPT reduces perplexity on the novel domain without forgetting.
We generally observe that DEMIX-DAPT outperforms DENSE-DAPT for some domains (e.g., CORD-19 and ACL PAPERS), while it closely approaches DENSE-DAPT for others (e.g., GUTENBERG; Appendix §A.5). Overall, the parameters for the additional expert comprise about 10% of the total parameters in the DEMIX model, and DENSE-DAPT involves updating all the parameters of the model towards the target domain, so we would expect that DENSE-DAPT outperforms DEMIX-DAPT in some cases. The strong performance of DEMIX-DAPT on domains like CORD-19 and ACL PAPERS suggests that DEMIX-DAPT is especially helpful when the target domain strongly aligns with one of the experts (Figure 4).
12 Our proposed technique is reminiscent of Progressive Neural Networks (Rusu et al., 2016).
Table 6: Average perplexity in training and novel domains before and after adding 8 experts adapted to the novel domains (via DEMIX-DAPT). Adding experts reduces perplexity on all domains, even those previously seen.
Adding eight experts With expert mixing (§5), newly added experts can be combined with existing ones in the model at test time. To more thoroughly understand the effect of adding more experts to the system, we add all experts adapted to novel domains to the DEMIX model from §4. We display the performance of a DEMIX LM with 16 experts (8 experts trained on training domains, 8 additional experts adapted to novel domains) in Table 6. We generally observe that DEMIX-DAPT reduces perplexity on all domains for all model sizes, again without forgetting. Adding the eight additional experts in fact reduces perplexity on previously seen domains. For example, across all model sizes, on average, we see a 2.4% reduction on MED, a 1.8% reduction on REALNEWS, and a 2% reduction on REDDIT (Appendix §A.5). These improvements are small, which is expected given that we only performed DEMIX-DAPT for at most one hour with eight GPUs. Even so, these results suggest that DEMIX layers can enable the LM to incorporate knowledge from novel domains to improve its performance on previously seen domains.

Language Models with Removable Parts
Current LM pretraining datasets are rife with undesirable content, from hate speech to extremism (Gehman et al., 2020; Bender et al., 2021). Another consequence of DENSE training is that it is difficult to restrict the model's access to these problematic domains after training, as might be desirable for many user-facing tasks (Xu et al., 2020; Dinan et al., 2021). DEMIX layers offer new capabilities for lightweight control over the domains in the training data that LMs use to make predictions at inference time. In particular, since DEMIX layer experts specialize to their domain (Figure 2), experts that are assigned to domains that are unwanted at test time can simply be disabled and left unused.
A key question is whether disabling an expert can simulate a model that has not been exposed to that domain, which we study in this section. However, since the self-attention and input embedding parameters in the DEMIX LM are shared across domains, removing an expert offers no guarantee of having fully forgotten content from the removed domain. Establishing such bounds is an important avenue for future work.

Experimental Setup
To evaluate whether we can simulate models that have not been exposed to a particular domain, we compare three settings:
+EXPERT A DEMIX LM with all experts active.
-EXPERT A DEMIX LM with a domain expert deactivated.
-DOMAIN A DEMIX LM retrained from scratch without a particular domain. We replace the removed domain with GUTENBERG.
We evaluate expert removal (+EXPERT and -EXPERT) with the DEMIX LM with 125M parameters per GPU from §4, with no modifications. For all baselines, we use expert mixing with a cached prior (§5).

Results
Removing a domain expert harms model performance on the associated domain, in most cases approaching the performance of a model that has not been exposed to data from that domain (Table 7). In some cases (e.g., WEBTEXT and REALNEWS), -EXPERT even underperforms -DOMAIN. This leads us to conjecture that most domain-specific learning happens within the DEMIX layer, despite the fact that other parts of the model are affected by all training domains.

Related Work
Incorporating Metadata Document metadata has been commonly used to improve the quality of topic models (Mimno and McCallum, 2012; Ramage et al., 2009; Zhu et al., 2012), and previous works have used metadata for adapting RNN-based language models (Jaech and Ostendorf, 2018) or learning better document representations (Card et al., 2018). Zellers et al. (2019) and Keskar et al. (2019) prepend document metadata to the input text (similar to our +DOMAIN-TOKEN setting) while training transformer LMs to provide better inference-time control of text generation.
Inference-time Control DEMIX layers provide a simple mechanism for inference-time control of language model behavior. Previously proposed methods for inference-time control are either expensive to use (Dathathri et al., 2020), or rely on densely trained models (e.g., Keskar et al., 2019). Liu et al. (2021) use multiple experts for inference-time text generation control. This method may be applied to DEMIX layers to steer text generation with experts trained on different domains.
Multilinguality Related to variation across domains is crosslingual variation. Past work has suggested that multilingual models benefit from language-specific parameters (Fan et al., 2020; Pfeiffer et al., 2020; Chau et al., 2020). Here, we investigate the effect of incorporating domain-specific parameters into the LM. Though the boundaries between languages are (often) clearer than those among domains, DEMIX layers draw inspiration from multilingual research, and future work might explore a compositional approach with both language experts and domain experts.
Continual Learning DEMIX-DAPT is a type of continual learning, in which the model learns incrementally on new data (Chen et al., 2018). Previously proposed techniques to support continual learning include regularization (Kirkpatrick et al., 2017), meta-learning (Munkhdalai and Yu, 2017), episodic memory modules (Lopez-Paz and Ranzato, 2017; de Masson d'Autume et al., 2019), and data replay (Sun et al., 2019), all of which may be combined with DEMIX layers. Model expansion techniques to incorporate new reinforcement learning or visual tasks (Rusu et al., 2016; Draelos et al., 2017) are especially related to DEMIX-DAPT. Our results suggest that continual learning in LMs is naturally enabled with modular domain experts; this may be further explored using temporally-relevant domains (Lazaridou et al., 2021).
LM Adapters Also related to DEMIX-DAPT is the line of work into adapter modules for pretrained LMs (Houlsby et al., 2019;Pfeiffer et al., 2020). Similar to the setting in which we add experts for new domains, adapter modules involve freezing the pretrained language model and updating a small number of additional parameters that are appended to certain parts of the network. This study confirms previous findings that only a subset of LM parameters need to be fine-tuned to a target dataset (Zaken et al., 2021). Expert addition may be performed with adapter modules to further improve efficiency.
Multi-Domain Models Multi-domain models have been studied extensively in the context of machine translation, first with statistical systems (Banerjee et al., 2010; Sennrich et al., 2013), and more recently with neural networks (Pham et al., 2021). Other works have explored multi-domain settings with smaller models and explicit domain labels, using supervision (e.g., Wright and Augenstein, 2020; Guo et al., 2018; Zeng et al., 2018) or dense training (e.g., Maronikolakis and Schütze, 2021). Previous studies have shown the importance of considering domains when adapting LMs (Ramponi and Plank, 2020). Our study establishes the importance of considering domains when training LMs from scratch.

Conclusion
We introduce DEMIX layers for language models, which provide modularity at inference time, addressing limitations of dense training by providing a rapidly adaptable system. DEMIX layer experts can be mixed to handle heterogeneous or unseen domains, added to iteratively incorporate new domains, and removed to restrict unwanted domains.
There are many exciting directions for future work, in addition to those described throughout the paper. They include combining domain and token-level routing, to realize the benefits of modularity while scaling models efficiently. The design of DEMIX layers assumes access to coarse provenance labels (or other metadata) to identify domains in pretraining data; an alternative option is to use unsupervised learning to discover domains in the corpus, which, in concert with domain metadata, may lead to better DEMIX expert assignments. Furthermore, in this work, we study DEMIX layers with a dataset that has a few large domains. In practice, textual domains usually contain many diverse subdomains of varying prevalence. Training DEMIX layers on a dataset with a long tail of domains may require automatic measures to cluster smaller domains, or hierarchical experts that are specialized to progressively narrower data distributions.

A.1 Collecting Domains
For most domains, we use the associated sources, listed in Table 1, without modification. For TWEETS, we use the Twitter Academic API. For GUTENBERG, we use the scraping tool provided in https://github.com/aparrish/gutenberg-dammit. For BREAKING NEWS, we identify a list of factually reliable English news sources, using the list curated by Baly et al. (2018).
Specifically, we filter on "high" factuality in the data provided in this repository: https://github.com/ramybaly/News-Media-Reliability. We then use Newspaper3K (https://newspaper.readthedocs.io/en/latest/) to scrape the latest 1000 articles from each site. After dropping duplicates, we arrive at about 20K articles from 400 news sources.
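The filtering and deduplication steps can be sketched as follows, assuming the reliability data has been loaded as a list of dictionaries. The column names `source_url` and `fact`, both function names, and the whitespace-normalized comparison are assumptions made for illustration; the scraping step itself is omitted.

```python
from typing import Dict, Iterable, List


def high_factuality_sources(rows: Iterable[Dict[str, str]]) -> List[str]:
    """Return source URLs whose factuality label is 'high' (column names assumed)."""
    return [row["source_url"] for row in rows
            if row.get("fact", "").strip().lower() == "high"]


def drop_duplicates(articles: Iterable[str]) -> List[str]:
    """Keep the first occurrence of each article, comparing whitespace-normalized text."""
    seen = set()
    unique = []
    for text in articles:
        key = " ".join(text.split())
        if key and key not in seen:
            seen.add(key)
            unique.append(text)
    return unique


# Toy inputs (made-up sources and articles).
rows = [
    {"source_url": "reliable.example", "fact": "HIGH"},
    {"source_url": "mixed.example", "fact": "mixed"},
]
sources = high_factuality_sources(rows)
articles = drop_duplicates(["A  story.", "A story.", "Another story."])
```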

A.2 Dataset Anonymization
To anonymize certain datasets, we apply a suite of regexes that aim to identify common patterns of user-identifiable data and substitute them with dummy tokens. We display anonymization regexes and associated dummy tokens in Table 8.
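A regex-based substitution of this kind can be sketched as follows. The three patterns and the dummy tokens `<EMAIL>`, `<URL>`, and `<USER>` are illustrative assumptions and are not the entries of Table 8; the email pattern is applied before the handle pattern so that addresses are not partially rewritten as user mentions.

```python
import re

# Illustrative patterns only; these are NOT the regexes listed in Table 8.
ANONYMIZERS = [
    (re.compile(r"[\w.+-]+@[\w-]+(?:\.[\w-]+)+"), "<EMAIL>"),
    (re.compile(r"https?://\S+"), "<URL>"),
    (re.compile(r"(?<!\w)@\w+"), "<USER>"),
]


def anonymize(text: str) -> str:
    """Apply each pattern in order, replacing matches with its dummy token."""
    for pattern, token in ANONYMIZERS:
        text = pattern.sub(token, text)
    return text


example = anonymize(
    "Contact @alice at alice@example.com or https://example.com/u/alice"
)
```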

A.3 Calculating TFLOPs/GPU
We use the formula presented in Narayanan et al. (2021)

A.5 Per-Domain Results
We display per-domain test results in the spreadsheets at the following link:

A.6 Domain Posterior Calculations
We track calculated domain posteriors over blocks of development data in Figure 7 (training domains) and Figure 8 (novel domains). The calculated domain posteriors are noisier for earlier blocks, usually stabilizing after around 50 blocks. For all experiments, we conservatively use 100 blocks of data to compute the domain posterior, though one may be able to accurately calculate the domain posterior for some domains with less data.
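One way to obtain such a running estimate is sketched below, assuming a uniform prior over domains and access to each expert's log-likelihood of every block. The function names, the simple running average, and the toy log-likelihoods are illustrative choices rather than a description of the exact procedure used for Figures 7 and 8.

```python
import math
from typing import List


def block_posterior(log_likelihoods: List[float]) -> List[float]:
    """Posterior over domains for one block under a uniform prior (a softmax)."""
    m = max(log_likelihoods)
    exps = [math.exp(ll - m) for ll in log_likelihoods]  # subtract max for stability
    total = sum(exps)
    return [e / total for e in exps]


def running_posterior(blocks: List[List[float]]) -> List[List[float]]:
    """Running average of per-block posteriors, recorded after each block."""
    n_domains = len(blocks[0])
    sums = [0.0] * n_domains
    history = []
    for count, lls in enumerate(blocks, start=1):
        post = block_posterior(lls)
        sums = [s + p for s, p in zip(sums, post)]
        history.append([s / count for s in sums])
    return history


# Toy log-likelihoods for three blocks and two domains (made-up numbers).
history = running_posterior([[-10.0, -12.0], [-11.0, -11.0], [-9.0, -13.0]])
```

With few blocks a single atypical block can shift the estimate substantially, which is consistent with the noisier early estimates described above.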

A.7 Perplexity changes after DENSE-DAPT
In Table 13, we display the average perplexity change after performing DENSE-DAPT on a new domain. We observe that across all model sizes, DENSE-DAPT improves performance in the novel domain, at the cost of a large performance hit in the training domains.