Scaling Deep Contrastive Learning Batch Size under Memory Limited Setup

Luyu Gao, Yunyi Zhang, Jiawei Han, Jamie Callan


Abstract
Contrastive learning has been applied successfully to learn vector representations of text. Previous research demonstrated that learning high-quality representations benefits from batch-wise contrastive loss with a large number of negatives. In practice, the technique of in-batch negatives is used: for each example in a batch, the positives of the other batch examples are taken as its negatives, which avoids encoding extra negatives. This, however, still conditions each example's loss on all batch examples and requires fitting the entire large batch into GPU memory. This paper introduces a gradient caching technique that decouples backpropagation between the contrastive loss and the encoder, removing the encoder's backward-pass data dependency along the batch dimension. As a result, gradients can be computed for one subset of the batch at a time, leading to almost constant memory usage.
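A minimal PyTorch sketch of the idea follows: the encoder first runs a graph-free forward pass over the whole batch, the contrastive loss is backpropagated only as far as the representations to build a gradient cache, and the encoder is then re-run and backpropagated one small sub-batch at a time against those cached gradients. Function and variable names here are illustrative rather than the paper's API; the sketch assumes plain tensor inputs and a single shared encoder, and the reference implementation is the luyug/GradCache repository linked below.

```python
import torch
import torch.nn.functional as F


def grad_cache_step(encoder, queries, passages, optimizer, chunk_size=8):
    """One large-batch contrastive update under a small memory budget.

    A minimal sketch of gradient caching (hypothetical names; see
    luyug/GradCache for the reference implementation). `queries` and
    `passages` are aligned tensors; `encoder` maps inputs to dense vectors.
    """
    # Step 1: graph-free forward pass over the full batch. Only the output
    # representations are kept; no intermediate activations are stored.
    with torch.no_grad():
        q_reps = torch.cat([encoder(c) for c in queries.split(chunk_size)])
        p_reps = torch.cat([encoder(c) for c in passages.split(chunk_size)])

    # Step 2: compute the in-batch-negative loss on the detached
    # representations and backpropagate through the loss alone, producing a
    # gradient "cache" with one gradient vector per representation.
    q_reps.requires_grad_(True)
    p_reps.requires_grad_(True)
    scores = q_reps @ p_reps.T                    # all query-passage similarities
    labels = torch.arange(scores.size(0), device=scores.device)
    loss = F.cross_entropy(scores, labels)
    q_cache, p_cache = torch.autograd.grad(loss, [q_reps, p_reps])

    # Step 3: re-encode one small sub-batch at a time with gradients enabled
    # and backpropagate its representations against the cached gradients.
    # Peak memory now depends on chunk_size, not on the full batch size.
    for inputs, cache in ((queries, q_cache), (passages, p_cache)):
        for chunk, grad in zip(inputs.split(chunk_size), cache.split(chunk_size)):
            encoder(chunk).backward(gradient=grad)

    optimizer.step()
    optimizer.zero_grad()
    return loss.item()
```

Because the cached gradients are exactly the loss gradients with respect to each representation, accumulating the per-chunk backward passes reproduces the gradient of the full-batch loss, so the large effective batch is preserved while peak memory stays roughly constant.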
Anthology ID:
2021.repl4nlp-1.31
Volume:
Proceedings of the 6th Workshop on Representation Learning for NLP (RepL4NLP-2021)
Month:
August
Year:
2021
Address:
Online
Editors:
Anna Rogers, Iacer Calixto, Ivan Vulić, Naomi Saphra, Nora Kassner, Oana-Maria Camburu, Trapit Bansal, Vered Shwartz
Venue:
RepL4NLP
Publisher:
Association for Computational Linguistics
Pages:
316–321
URL:
https://aclanthology.org/2021.repl4nlp-1.31
DOI:
10.18653/v1/2021.repl4nlp-1.31
Cite (ACL):
Luyu Gao, Yunyi Zhang, Jiawei Han, and Jamie Callan. 2021. Scaling Deep Contrastive Learning Batch Size under Memory Limited Setup. In Proceedings of the 6th Workshop on Representation Learning for NLP (RepL4NLP-2021), pages 316–321, Online. Association for Computational Linguistics.
Cite (Informal):
Scaling Deep Contrastive Learning Batch Size under Memory Limited Setup (Gao et al., RepL4NLP 2021)
PDF:
https://aclanthology.org/2021.repl4nlp-1.31.pdf
Video:
https://aclanthology.org/2021.repl4nlp-1.31.mp4
Code:
luyug/GradCache + additional community code