SynJax: Structured Probability Distributions for JAX
Miloš
Stanojević
author
Laurent
Sartran
author
2023-12
text
Proceedings of the 2023 Conference on Empirical Methods in Natural Language Processing: System Demonstrations
Yansong
Feng
editor
Els
Lefever
editor
Association for Computational Linguistics
Singapore
conference publication
The development of deep learning software libraries has enabled significant progress in the field by allowing users to focus on modeling while letting the library take care of the tedious and time-consuming task of optimizing execution for modern hardware accelerators. However, this has benefited only particular types of deep learning models, such as Transformers, whose primitives map easily to vectorized computation. Models that explicitly account for structured objects, such as trees and segmentations, have not benefited equally because they require custom algorithms that are difficult to implement in vectorized form. SynJax directly addresses this problem by providing an efficient vectorized implementation of inference algorithms for structured distributions covering alignment, tagging, segmentation, constituency trees and spanning trees. This is done by exploiting the connection between algorithms for automatic differentiation and probabilistic inference. With SynJax we can build large-scale differentiable models that explicitly model structure in the data. The code is available at https://github.com/google-deepmind/synjax
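The connection between differentiation and probabilistic inference mentioned in the abstract can be illustrated on a small linear-chain tagging model: the gradient of the log-partition function with respect to a log-potential equals the marginal probability of the corresponding part. The sketch below is not SynJax code and does not use its API. All function names and the toy potentials are hypothetical, and a central finite difference stands in for automatic differentiation so that the example needs only the Python standard library.

```python
import itertools
import math


def logsumexp(values):
    """Numerically stable log(sum(exp(v))) over a list of floats."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def log_partition(unary, trans):
    """Forward algorithm for a linear-chain model; returns log Z."""
    alpha = list(unary[0])
    for t in range(1, len(unary)):
        alpha = [
            unary[t][y]
            + logsumexp([alpha[p] + trans[p][y] for p in range(len(alpha))])
            for y in range(len(unary[t]))
        ]
    return logsumexp(alpha)


def brute_force_marginal(unary, trans, t, y):
    """P(label at position t is y), by enumerating every label sequence."""
    n, k = len(unary), len(unary[0])
    total, hit = 0.0, 0.0
    for seq in itertools.product(range(k), repeat=n):
        score = sum(unary[i][seq[i]] for i in range(n))
        score += sum(trans[seq[i - 1]][seq[i]] for i in range(1, n))
        weight = math.exp(score)
        total += weight
        if seq[t] == y:
            hit += weight
    return hit / total


def grad_log_partition(unary, trans, t, y, eps=1e-5):
    """Central finite difference of log Z with respect to unary[t][y].

    Assumption: this replaces the automatic differentiation step purely
    to keep the sketch dependency-free.
    """
    def shifted(delta):
        u = [row[:] for row in unary]
        u[t][y] += delta
        return log_partition(u, trans)

    return (shifted(eps) - shifted(-eps)) / (2 * eps)


# Hypothetical toy log-potentials: 3 positions, 2 labels.
unary = [[0.5, -1.0], [0.2, 0.3], [-0.7, 1.1]]
trans = [[0.4, -0.6], [-0.2, 0.9]]

for t in range(len(unary)):
    for y in range(len(unary[0])):
        g = grad_log_partition(unary, trans, t, y)
        m = brute_force_marginal(unary, trans, t, y)
        print(f"t={t} y={y} grad={g:.6f} marginal={m:.6f}")
```

In an automatic differentiation framework the same gradient would come from differentiating the forward algorithm directly, so a single implementation of the log-partition function is enough to obtain marginals.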
stanojevic-sartran-2023-synjax
10.18653/v1/2023.emnlp-demo.32
https://aclanthology.org/2023.emnlp-demo.32
2023-12
353
364