Scaling Deep Contrastive Learning Batch Size under Memory Limited Setup

Contrastive learning has been applied successfully to learn vector representations of text. Previous research demonstrated that learning high-quality representations benefits from batch-wise contrastive loss with a large number of negatives. In practice, the technique of in-batch negatives is used: for each example in a batch, the positives of the other batch examples are taken as its negatives, avoiding the encoding of extra negatives. This, however, still conditions each example's loss on all batch examples and requires fitting the entire large batch into GPU memory. This paper introduces a gradient caching technique that decouples backpropagation between the contrastive loss and the encoder, removing the encoder backward pass's data dependency along the batch dimension. As a result, gradients can be computed for one subset of the batch at a time, leading to almost constant memory usage.


Introduction
Contrastive learning learns to encode data into an embedding space such that related data points have closer representations and unrelated ones have farther-apart representations. Recent works in NLP adopt deep neural nets as encoders and use unsupervised contrastive learning for sentence representation (Giorgi et al., 2020), text retrieval, and language model pre-training tasks. Supervised contrastive learning (Khosla et al., 2020) has also been shown effective in training dense retrievers (Karpukhin et al., 2020; Qu et al., 2020). These works typically use a batch-wise contrastive loss, sharing target texts as in-batch negatives. With such a technique, previous works have empirically shown that larger batches help learn better representations. However, computing the loss and updating model parameters with respect to a big batch require encoding all batch data and storing all activations, so batch size is limited by total available GPU memory. This limits the application and research of contrastive learning methods under memory-limited setups, e.g. in academia. For example, recent work pre-trains a BERT passage encoder with a batch size of 4096, while a high-end commercial GPU such as the RTX 2080ti can only fit a batch of 8. The gradient accumulation technique, which splits a large batch into chunks and sums gradients across several backward passes, cannot emulate a large batch because each smaller chunk has fewer in-batch negatives. (Our code is at github.com/luyug/GradCache.)
In this paper, we present a simple technique that caps peak memory usage for contrastive learning at a near-constant level regardless of batch size. For deep contrastive learning, the memory bottleneck is the deep neural network based encoder. We observe that the backpropagation of the contrastive loss can be separated into two parts, from loss to representation and from representation to model parameters, with the latter being independent across batch examples given the former, as detailed in subsection 3.2. We then show in subsection 3.3 that by separately pre-computing the representations' gradients and storing them in a cache, we can break the encoder update into multiple sub-updates that fit into GPU memory. This pre-computation of gradients allows our method to produce the exact same gradient update as training with the large batch. Experiments show that, with about a 20% increase in runtime, our technique enables a single consumer-grade GPU to reproduce state-of-the-art large-batch trained models that used to require multiple professional GPUs.

Related Work
Contrastive Learning First introduced for probabilistic language modeling (Mnih and Teh, 2012), Noise Contrastive Estimation (NCE) was later used by Word2Vec (Mikolov et al., 2013) to learn word embeddings. Recent works use contrastive learning to pre-train without supervision (Chang et al., 2020) as well as to train dense retrievers with supervision (Karpukhin et al., 2020), where the contrastive loss is used to estimate retrieval probability over the entire corpus. Inspired by SimCLR, contrastive learning is also used to learn better sentence representations (Giorgi et al., 2020) and pre-trained language models.
Deep Network Memory Reduction Many existing techniques deal with large and deep models. The gradient checkpoint method emulates training a deep network by training shallower sub-networks connected through gradient checkpoints, re-computing the intermediate activations during the backward pass (Chen et al., 2016). Some methods use reversible activation functions, allowing internal activations to be recovered throughout backpropagation (Gomez et al., 2017; MacKay et al., 2018). However, their effectiveness as part of contrastive encoders has not been confirmed. Recent work also removes the redundancy of optimizer-tracked parameters on each GPU (Rajbhandari et al., 2020). Compared with the aforementioned methods, our method is designed for scaling over the batch size dimension of contrastive learning.

Methodologies
In this section, we formally introduce the notation for contrastive loss and analyze the difficulty of computing it on limited hardware. We then show how a Gradient Cache technique can factor the loss so that a large-batch gradient update can be broken into several sub-updates.

Preliminaries
Under a general formulation, given two classes of data 𝒮, 𝒯, we want to learn encoders f and g such that, for s ∈ 𝒮, t ∈ 𝒯, the encoded representations f(s) and g(t) are close under some distance measure if the pair is related and far apart if not. As direct training over large 𝒮 and 𝒯 with deep neural network based f and g is intractable, a common approach is to use a contrastive loss: sample a batch of anchors S ⊂ 𝒮 and targets T ⊂ 𝒯, where each element s_i ∈ S has a related element t_{r_i} ∈ T as well as zero or more specially sampled hard negatives. The remaining random samples in T are used as in-batch negatives.
Define the loss based on dot product as follows:

$$\mathcal{L} = -\frac{1}{|S|} \sum_{s_i \in S} \log \frac{\exp\left(f(s_i)^\top g(t_{r_i}) / \tau\right)}{\sum_{t \in T} \exp\left(f(s_i)^\top g(t) / \tau\right)} \quad (1)$$

where each summation term depends on the entire set T and therefore requires fitting all of its encoded representations into memory.
We set temperature τ = 1 in the following discussion for simplicity as in general it only adds a constant multiplier to the gradient.
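The loss above can be sketched in a few lines of PyTorch. This is a minimal illustration of the in-batch negative formulation, not the released implementation; all names are ours, and we take the positive of s_i to be t_i so the positives lie on the diagonal of the score matrix:

```python
import torch
import torch.nn.functional as F

def contrastive_loss(reps_s, reps_t, tau=1.0):
    """In-batch contrastive loss of Eq. (1): reps_s[i] is paired with
    reps_t[i]; all other rows of reps_t serve as in-batch negatives."""
    scores = reps_s @ reps_t.T / tau           # |S| x |T| dot products
    labels = torch.arange(reps_s.size(0))      # positive of s_i is t_i
    return F.cross_entropy(scores, labels)     # mean of -log softmax terms

torch.manual_seed(0)
reps_s = torch.randn(4, 8)                     # stand-in f(s_i) vectors
reps_t = torch.randn(4, 8)                     # stand-in g(t_j) vectors
loss = contrastive_loss(reps_s, reps_t)
```

Note that `cross_entropy` over the score matrix is exactly the mean of the per-anchor log-softmax terms in Eq. (1).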

Analysis of Computation
In this section, we give a mathematical analysis of contrastive loss computation and its gradient. We show that the back propagation process can be divided into two parts, from loss to representation, and from representation to encoder model. The separation then enables us to devise a technique that removes data dependency in encoder parameter update. Suppose the function f is parameterized with Θ and g is parameterized with Λ.
As an extra notation, denote the normalized similarity

$$p_{ij} = \frac{\exp\left(f(s_i)^\top g(t_j)\right)}{\sum_{t \in T} \exp\left(f(s_i)^\top g(t)\right)}$$

By the chain rule, the loss gradients factor through the representations:

$$\frac{\partial \mathcal{L}}{\partial \Theta} = \sum_{s_i \in S} \frac{\partial \mathcal{L}}{\partial f(s_i)} \frac{\partial f(s_i)}{\partial \Theta}, \qquad \frac{\partial \mathcal{L}}{\partial \Lambda} = \sum_{t_j \in T} \frac{\partial \mathcal{L}}{\partial g(t_j)} \frac{\partial g(t_j)}{\partial \Lambda}$$

We note that the summation term for a particular s_i or t_j is a function of the entire batch, as

$$\frac{\partial \mathcal{L}}{\partial f(s_i)} = -\frac{1}{|S|} \left( g(t_{r_i}) - \sum_{t_j \in T} p_{ij}\, g(t_j) \right)$$

which prohibits the use of gradient accumulation. We make two observations here:
• The partial derivative ∂f(s_i)/∂Θ depends only on s_i and Θ, while ∂g(t_j)/∂Λ depends only on t_j and Λ; and
• Computing the partial derivatives ∂L/∂f(s_i) and ∂L/∂g(t_j) requires only the encoded representations, but not Θ or Λ.
These observations mean that backpropagation of f(s_i) for data s_i can be run independently, with its own computation graph and activations, if the numerical value of the partial derivative ∂L/∂f(s_i) is known. Meanwhile, the derivation of ∂L/∂f(s_i) requires only the numerical values of the two sets of representations {f(s_1), f(s_2), ...} and {g(t_1), g(t_2), ...}. A similar argument holds for g, where we can use the representation vectors to compute ∂L/∂g(t_j) and backpropagate for each g(t_j) independently. In the next section, we describe how to scale up batch size by precomputing these representation vectors.
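The second observation can be checked numerically: the closed-form gradient with respect to each representation, which uses only the encoded vectors, matches what automatic differentiation produces. A small sketch under our own naming, with τ = 1 and the positive of s_i placed at t_i:

```python
import torch

torch.manual_seed(0)
n, d = 4, 8
F_mat = torch.randn(n, d, requires_grad=True)   # rows are f(s_i)
G_mat = torch.randn(n, d, requires_grad=True)   # rows are g(t_j); positive of s_i is t_i

scores = F_mat @ G_mat.T
loss = -scores.log_softmax(dim=1).diagonal().mean()   # Eq. (1) with tau = 1
loss.backward()

# Closed-form gradient, computed from representation values alone:
# dL/df(s_i) = -(1/n) * (g(t_i) - sum_j p_ij g(t_j)), no encoder needed.
with torch.no_grad():
    p = scores.softmax(dim=1)                   # normalized similarity p_ij
    u = -(G_mat - p @ G_mat) / n                # dL/df(s_i), stacked as rows
```

Here `u` agrees with `F_mat.grad`, confirming that representation gradients can be derived without touching encoder parameters.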

Gradient Cache Technique
Given a large batch that does not fit into the available GPU memory for training, we first divide it into a set of sub-batches, each of which fits into memory for gradient computation, denoted 𝕊 = {Ŝ_1, Ŝ_2, …} and 𝕋 = {T̂_1, T̂_2, …}. The full-batch gradient update is computed in the following steps.
Step1: Graph-less Forward Before gradient computation, we first run an extra encoder forward pass for each batch instance to obtain its representation. Importantly, this forward pass runs without constructing the computation graph, so no activations are retained. We collect and store all computed representations.
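In PyTorch terms, Step1 corresponds to running the encoder under `torch.no_grad()`, sub-batch by sub-batch. A minimal sketch with a stand-in linear encoder (the real encoders are deep Transformers):

```python
import torch

torch.manual_seed(0)
encoder_f = torch.nn.Linear(32, 8)     # stand-in for a deep encoder f
batch_s = torch.randn(128, 32)         # full large batch of inputs
sub_batches = batch_s.split(16)        # pieces that each fit in memory

with torch.no_grad():                  # Step1: no computation graph is built
    reps_s = torch.cat([encoder_f(b) for b in sub_batches])
```

Because the forward runs without a graph, memory holds only the input sub-batch and the output vectors, not intermediate activations.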
Step2: Representation Gradient Computation and Caching We then compute the contrastive loss for the batch based on the representations from Step1, constructing the corresponding computation graph. Although we derived the gradients mathematically above, the actual implementation relies on automatic differentiation, which automatically supports variations of the contrastive loss. A backward pass is then run to populate the gradient for each representation. Note that the encoder is not included in this gradient computation. Let u_i = ∂L/∂f(s_i) and v_i = ∂L/∂g(t_i); we take these gradient tensors and store them as a Representation Gradient Cache, [u_1, u_2, .., v_1, v_2, ..].
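Step2 can be sketched by re-attaching the cached representations as leaf tensors, so that backward stops at the representations and the encoders take no part in the graph. A minimal illustration with stand-in representations and our own names:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
# Stand-ins for the representations collected in the graph-less forward.
reps_s = torch.randn(128, 8)
reps_t = torch.randn(128, 8)

# Re-attach as leaf tensors: backward stops here, never reaching the encoders.
rs = reps_s.detach().requires_grad_()
rt = reps_t.detach().requires_grad_()

scores = rs @ rt.T
loss = F.cross_entropy(scores, torch.arange(128))   # in-batch contrastive loss
loss.backward()                                     # cheap: graph excludes encoders

cache_u = rs.grad    # u_i = dL/df(s_i)
cache_v = rt.grad    # v_i = dL/dg(t_i)
```

The backward pass here is over the loss head only, so its memory cost is tiny compared with an encoder backward.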
Step3: Sub-batch Gradient Accumulation We run the encoder forward one sub-batch at a time to compute representations and build the corresponding computation graph. We take the sub-batch's representation gradients from the cache and run backpropagation through the encoder. Gradients are accumulated over the encoder parameters across all sub-batches. Effectively, for f we have

$$\frac{\partial \mathcal{L}}{\partial \Theta} = \sum_{\hat{S}_j \in \mathbb{S}} \sum_{s_i \in \hat{S}_j} u_i \frac{\partial f(s_i)}{\partial \Theta}$$

where the outer summation enumerates the sub-batches and each full inner summation corresponds to one step of accumulation. Similarly, for g, gradients accumulate according to

$$\frac{\partial \mathcal{L}}{\partial \Lambda} = \sum_{\hat{T}_j \in \mathbb{T}} \sum_{t_i \in \hat{T}_j} v_i \frac{\partial g(t_i)}{\partial \Lambda}$$

Collapsing the two summations into one shows the equivalence with a direct large-batch update.
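A standard way to realize Σ_i u_i ∂f(s_i)/∂Θ in an autograd framework is to backpropagate a surrogate scalar, the dot product between the fresh sub-batch representations and the cached gradients. A sketch with a stand-in linear encoder and random stand-in cache values:

```python
import torch

torch.manual_seed(0)
encoder_f = torch.nn.Linear(32, 8)      # stand-in for the deep encoder f
batch_s = torch.randn(128, 32)          # full large batch of inputs
cache_u = torch.randn(128, 8)           # stand-in for dL/df(s_i) from Step2

for sub_s, sub_u in zip(batch_s.split(16), cache_u.split(16)):
    sub_reps = encoder_f(sub_s)         # forward WITH graph this time
    # Backward of <sub_reps, sub_u> w.r.t. the parameters equals
    # sum_i u_i df(s_i)/dTheta; .grad fields accumulate across sub-batches.
    torch.dot(sub_reps.flatten(), sub_u.flatten()).backward()
```

Only one sub-batch's activations exist at any moment, which is what bounds peak memory.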
Step4: Optimization When all sub-batches are processed, we step the optimizer to update model parameters, as if the full batch were processed in a single forward-backward pass. Compared to directly updating with the full batch, which requires memory linear in the number of examples, our method fixes the number of examples in each encoder gradient computation to the sub-batch size and therefore requires constant memory for the encoder forward-backward pass. The only extra data our method keeps persistent across steps are the representations and their corresponding gradients, with the former turned into the latter after representation gradient computation. Consequently, in the general case with data from S and T each represented by d-dimensional vectors, we only need to store (|S| + |T|)d floating-point numbers in the cache on top of the computation graph. To remind our readers, this is several orders of magnitude smaller than the millions of model parameters.
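Putting the four steps together, the procedure reproduces the full-batch gradient exactly. A compact end-to-end sketch under simplifying assumptions of ours: a single shared linear encoder stands in for both f and g, and positives sit on the diagonal:

```python
import copy
import torch
import torch.nn.functional as F

def full_batch_grads(enc, S, T, labels):
    """Reference: one big forward-backward over the whole batch."""
    enc.zero_grad()
    F.cross_entropy(enc(S) @ enc(T).T, labels).backward()
    return [p.grad.clone() for p in enc.parameters()]

def grad_cache_grads(enc, S, T, labels, chunk=16):
    """Steps 1-3 of the gradient cache procedure."""
    enc.zero_grad()
    with torch.no_grad():                                    # Step1
        rs = torch.cat([enc(c) for c in S.split(chunk)])
        rt = torch.cat([enc(c) for c in T.split(chunk)])
    rs.requires_grad_(); rt.requires_grad_()
    F.cross_entropy(rs @ rt.T, labels).backward()            # Step2
    u, v = rs.grad, rt.grad                                  # the gradient cache
    for cs, cu in zip(S.split(chunk), u.split(chunk)):       # Step3
        torch.dot(enc(cs).flatten(), cu.flatten()).backward()
    for ct, cv in zip(T.split(chunk), v.split(chunk)):
        torch.dot(enc(ct).flatten(), cv.flatten()).backward()
    return [p.grad.clone() for p in enc.parameters()]        # Step4: optimizer.step()

torch.manual_seed(0)
enc = torch.nn.Linear(32, 8)
S, T = torch.randn(64, 32), torch.randn(64, 32)
labels = torch.arange(64)
g_full = full_batch_grads(copy.deepcopy(enc), S, T, labels)
g_cache = grad_cache_grads(copy.deepcopy(enc), S, T, labels)
```

Both functions return the same parameter gradients up to floating-point noise, which is the exactness claim of the method.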

Multi-GPU Training
When training on multiple GPUs, we need to compute gradients with the examples from all GPUs. This requires a single additional cross-GPU communication after Step1, once all representations are computed. We use an all-gather operation to make all representations available on all GPUs. Denote by F_n, G_n the representations on the n-th GPU, with N devices in total. Step2 then runs with the gathered representations F_all = F_1 ∪ .. ∪ F_N and G_all = G_1 ∪ .. ∪ G_N. While F_all and G_all are used to compute the loss, the n-th GPU only computes the gradients of its local representations F_n, G_n and stores them in its cache. No communication happens in Step3, where each GPU independently computes gradients for its local representations. Step4 then performs gradient reduction across GPUs, as in standard data-parallel training.
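The key property is that each device, after seeing the gathered representations, only needs the gradient of its own slice. The communication pattern can be simulated in a single process, with the all-gather mimicked by concatenation (no real GPUs or `torch.distributed` calls here; this is purely illustrative):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
N, local, d = 2, 4, 8      # N "devices", each holding `local` examples
# Step1 output on each device (stand-ins for encoder representations).
F_local = [torch.randn(local, d) for _ in range(N)]
G_local = [torch.randn(local, d) for _ in range(N)]

caches = []
for n in range(N):   # what the n-th GPU would run after the all-gather
    fl = F_local[n].clone().requires_grad_()   # only local slices track grads
    gl = G_local[n].clone().requires_grad_()
    # "All-gather": the loss sees the full batch of representations,
    # but backward only populates gradients for the local slice.
    F_all = torch.cat([fl if i == n else F_local[i] for i in range(N)])
    G_all = torch.cat([gl if i == n else G_local[i] for i in range(N)])
    loss = F.cross_entropy(F_all @ G_all.T, torch.arange(N * local))
    loss.backward()
    caches.append((fl.grad, gl.grad))          # local Representation Gradient Cache
```

Each device's cache equals the corresponding slice of the gradient a single machine would compute over the gathered batch, so Step3 can proceed locally with no further communication.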

Experiments
To examine the reliability and computation cost of our method, we implement it in the dense passage retriever (DPR; Karpukhin et al., 2020). We use the gradient cache to compute DPR's supervised contrastive loss on a single GPU. Following the DPR paper, we measure top-k hit accuracy on the Natural Questions dataset (Kwiatkowski et al., 2019) for different methods. We then examine the training speed at various batch sizes.

Retrieval Accuracy
Compared Systems 1) DPR: the reference numbers taken from the original paper, trained on 8 GPUs; 2) Sequential: update with the max batch size that fits into 1 GPU; 3) Accumulation: similar to Sequential but accumulating gradients until the number of examples matches the DPR setup; 4) Cache: training with the DPR setup using our gradient cache on 1 GPU. We also attempted to run with gradient checkpointing (using the implementation in the Huggingface transformers package) but found it cannot scale to the standard DPR batch size on our hardware.
Implementations All runs start from the same random seed and follow the DPR training hyperparameters except for batch size. Cache uses a batch size of 128, the same as DPR, and runs with a sub-batch size of 16 for questions and 8 for passages. We also run Cache with a batch size of 512 (BSZ=512) to examine the behavior of even larger batches. Sequential uses a batch size of 8, the largest that fits into memory. Accumulation accumulates 16 size-8 batches. Each question is paired with one positive and one BM25 negative passage. All experiments use a single RTX 2080ti. (Our implementation is at github.com/luyug/GC-DPR.)
Results Accuracy results are shown in Table 1.
We observe that Cache performs slightly better than the DPR reference, which we attribute to randomness in training. Further increasing the batch size to 512 brings some advantage at top 20/100. The Accumulation and Sequential results confirm the importance of a bigger batch with more negatives. For Accumulation, which matches the batch size but has fewer negatives per update, we see a performance drop that is larger towards the top ranks. In the Sequential case, the smaller batch incurs higher variance and performance drops further. In summary, our Cache method improves over the memory-constrained baselines and matches the performance of large-batch training.

Training Speed
Figure 1 compares the training speed of the methods across batch sizes.

Extend to Deep Distance Function
Previous discussion assumes a simple parameter-less dot product similarity. In general, it can also be a deep distance function Φ richly parameterized by Ω; formally,

$$d_{ij} = \Phi\left(f(s_i), g(t_j)\right)$$

This can still scale by introducing an extra Distance Gradient Cache. In the first forward pass we collect all representations as well as all distances. We compute the loss with the d_ij and backpropagate to get w_ij = ∂L/∂d_ij, storing them in the Distance Gradient Cache, [w_00, w_01, .., w_10, ..]. We can then update Ω in a sub-batch manner,

$$\frac{\partial \mathcal{L}}{\partial \Omega} = \sum_i \sum_j w_{ij} \frac{\partial d_{ij}}{\partial \Omega} \quad (11)$$

Additionally, with the constructed computation graph we simultaneously compute ∂d_ij/∂f(s_i) and ∂d_ij/∂g(t_j) and accumulate across sub-batches,

$$u_i = \frac{\partial \mathcal{L}}{\partial f(s_i)} = \sum_j w_{ij} \frac{\partial d_{ij}}{\partial f(s_i)}, \qquad v_j = \frac{\partial \mathcal{L}}{\partial g(t_j)} = \sum_i w_{ij} \frac{\partial d_{ij}}{\partial g(t_j)}$$

with which we can build up the Representation Gradient Cache. When all representations' gradients are computed and stored, the encoder gradients can be computed with Step3 described in subsection 3.3. In philosophy, this method links up two caches. Note this covers early interaction f(s) = s, g(t) = t as a special case.
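The two-cache scheme can be sketched as follows. All concrete choices here are ours for illustration: a small MLP stands in for Φ, a softmax cross-entropy over the distance matrix stands in for the loss, and the representations play the role of the already-cached encoder outputs:

```python
import torch

torch.manual_seed(0)
n, d = 6, 8
phi = torch.nn.Sequential(torch.nn.Linear(2 * d, 16), torch.nn.Tanh(),
                          torch.nn.Linear(16, 1))   # deep distance, parameters Omega
reps_s = torch.randn(n, d).requires_grad_()         # cached f(s_i)
reps_t = torch.randn(n, d).requires_grad_()         # cached g(t_j)

def dist(rs, rt):
    # d_ij = Phi(f(s_i), g(t_j)) for every pair, via broadcasting
    pairs = torch.cat(torch.broadcast_tensors(rs[:, None, :], rt[None, :, :]), dim=-1)
    return phi(pairs).squeeze(-1)

# First pass: collect all distances without a graph, then treat them as leaves.
with torch.no_grad():
    d_vals = dist(reps_s, reps_t)
d_vals.requires_grad_()
loss = torch.nn.functional.cross_entropy(d_vals, torch.arange(n))
loss.backward()
w = d_vals.grad                                     # Distance Gradient Cache w_ij

# Sub-batch pass: gradients flow into Omega and into the representations,
# accumulating u_i in reps_s.grad and v_j in reps_t.grad.
for i in range(0, n, 2):
    d_sub = dist(reps_s[i:i + 2], reps_t)
    torch.dot(d_sub.flatten(), w[i:i + 2].flatten()).backward()
```

After the loop, `phi`'s parameter gradients and the representation gradients match a direct full-batch backward, and the representation gradients would then feed Step3 for the encoders.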

Conclusion
In this paper, we introduce a gradient cache technique that breaks the GPU memory limitation on large-batch contrastive learning. We propose to construct a representation gradient cache that removes the in-batch data dependency in encoder optimization. Our method produces the exact same gradient update as training with a large batch. We show the method is efficient and capable of preserving accuracy on resource-limited hardware. We believe a critical contribution of our work is providing a large population of the NLP community with access to batch-wise contrastive learning. While many previous works come from groups with industry-grade hardware, researchers with limited hardware can now use our technique to reproduce state-of-the-art models and further advance the research without being constrained by available GPU memory.