SynJax: Structured Probability Distributions for JAX

Introduction
In many domains, data can be seen as having some structure that explains how its parts fit into a larger whole. This structure is often latent, and it varies depending on the task. For examples of discrete structures in natural language, consider Figure 1. The words together form a sequence. Each word in the sequence is assigned a part-of-speech tag. These tags depend on each other, forming a linear chain, marked in red. The words in the sentence can be grouped into small disjoint contiguous groups by sentence segmentation, shown with bubbles. A deeper analysis of language would show that the groupings can be applied recursively, producing a syntactic tree structure. Structures can also relate two languages: in the same figure, a Japanese translation is mapped to its English source by an alignment. These structures are not specific to language; similar structures appear in biology. Nucleotides in RNA sequences are matched with monotone alignment (Needleman and Wunsch, 1970; Wang and Xu, 2011), genomic data is segmented into contiguous groups (Day et al., 2007), and tree-based models of RNA capture the hierarchical nature of the protein folding process (Sakakibara et al., 1994; Hockenmaier et al., 2007; Huang et al., 2019).
Most contemporary deep learning models attempt to predict output variables directly from the input, without any explicit modeling of the intermediate structure. Modeling structure explicitly could improve these models in multiple ways. First, it could allow for better generalization through the right inductive biases (Sartran et al., 2022), improving not only sample efficiency but also downstream performance (Bastings et al., 2017; Nȃdejde et al., 2017; Bisk and Tran, 2018). Explicit modeling of structure also enables the incorporation of problem-specific algorithms (e.g. finding shortest paths; Pogančić et al., 2020; Niepert et al., 2021) or constraints (e.g. enforcing alignment, Mena et al., 2018, or enforcing compositional computation, Havrylov et al., 2019). Discrete structure also allows for better interpretability of the model's decisions (Bastings et al., 2019). Finally, sometimes structure is the end goal of learning itself: we may know that there is a hidden structure of a particular form explaining the data, but its specifics are unknown and need to be discovered (Kim et al., 2019; Paulus et al., 2020).
Auto-regressive models are the main approach to modeling sequences. Non-sequential structures are sometimes linearized and approximated with a sequential structure (Choe and Charniak, 2016). These models are powerful because they make no independence assumptions and can be trained on large amounts of data. While sampling from auto-regressive models is typically tractable, other common inference problems, such as finding the optimal structure or marginalizing over hidden variables, are not. Solving these tasks approximately with auto-regressive models requires biased or high-variance approximations that are often computationally expensive, making them difficult to deploy in large-scale models.
An alternative to auto-regressive models are models over factor graphs that factorize in the same way as the target structure. These models can solve all inference problems of interest exactly and efficiently using specialized algorithms. Although each structure needs a different algorithm, we do not need a specialized algorithm for each inference task (argmax, sampling, marginals, entropy, etc.). As we will show later, SynJax uses automatic differentiation to derive many quantities from just a single function per structure type.
Large-scale deep learning has been enabled by easy-to-use libraries that run on hardware accelerators. Research into structured distributions for deep learning has been held back by the lack of ergonomic libraries providing accelerator-friendly implementations of structure components, especially since these components depend on algorithms that, unlike Transformer models, often do not map directly onto available deep learning primitives. SynJax addresses this problem by providing easy-to-use structure primitives that compose within the JAX machine learning framework.
If, as users, we want to change the type of trees slightly by requiring that they satisfy the projectivity constraint, we only need to change one flag; SynJax will, in the background, use completely different algorithms appropriate for that structure: Kuhlmann et al.'s (2011) algorithm for argmax and variations of Eisner's (1996) algorithm for other quantities. The user does not need to implement those algorithms, or even be aware of their specifics, and can focus on the modeling side of their problem.
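A sketch of what such a swap might look like is shown below; the constructor and flag names follow the paper's description of the unified spanning-tree interface, but the exact names may differ from the released API:

```python
import jax
import synjax

key = jax.random.PRNGKey(0)
# Log-potentials for every candidate edge among 9 nodes (root included).
log_potentials = jax.random.normal(key, (9, 9))

# Hypothetical flag names following the paper's description: flipping
# `projective` is all it takes to switch the underlying algorithms.
dist = synjax.SpanningTreeCRF(log_potentials, directed=True,
                              projective=True, single_root_edge=True)
best_tree = dist.argmax()     # Kuhlmann-style tabulation behind the scenes
marginals = dist.marginals()  # Eisner-style algorithm behind the scenes
```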

Structured Distributions
Distributions over most structures can be expressed with factor graphs: bipartite graphs connecting random variables and factors. To each factor we associate a non-negative scalar, called a potential, for each possible assignment of the random variables in its neighbourhood. The potential of a structure is the product of its factors:

$$\phi(t) = \prod_{e \in t} \phi(e) \qquad (1)$$

where $t$ is a structure, $e$ is a factor, and $\phi(\cdot)$ is the potential function. The probability of a structure can be found by normalizing its potential:

$$p(t) = \frac{\phi(t)}{Z}, \qquad Z = \sum_{t' \in T} \phi(t') \qquad (2)$$

where $T$ is the set of all possible structures and $Z$ is a normalization constant often called the partition function. It can be thought of as a softmax equivalent over an extremely large set of structured outputs that share sub-structures (Sutton and McCallum, 2007).

Computing Probability of a Structure and Partition Function
Equation 2 defines the probability of a structure in a factor graph. Computing the numerator is often trivial. However, computing the denominator, the partition function, is the complicated and computationally demanding part: the set of valid structures $T$ is usually exponentially large, and summing over it requires specialized algorithms for each type of structure. As we will see later, the algorithm implementing the partition function accounts for the majority of the code needed to support a structured distribution, since most other properties can be derived from it. Here we document the algorithms for each structure.

Sequence Tagging
Sequence tagging can be modelled with a Linear-Chain CRF (Lafferty et al., 2001). The partition function for linear-chain models is computed with the forward algorithm (Rabiner, 1990). The computational complexity is $O(m^2 n)$ for $m$ tags and a sequence of length $n$. Särkkä and García-Fernández (2021) have proposed a parallel version of this algorithm with parallel computational complexity $O(m^3 \log n)$, which is efficient when $m \ll n$. Rush (2020) reports a speedup from this parallel method in Torch-Struct; in our case, however, the original forward algorithm gave better performance in terms of both speed and memory. The Linear-Chain CRF implementation supports a different transition matrix for each time step, which gives greater flexibility for models in which a neural network predicts CRF parameters for each position, such as LSTM-CNN-CRF (Ma and Hovy, 2016) and the Neural Hidden Markov Model (Tran et al., 2016).
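As an illustration, here is a minimal JAX sketch of the forward algorithm in a transition-only parametrization (emission scores folded into the per-step transition matrices); SynJax's own implementation is more general:

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def crf_log_partition(log_potentials):
    # log_potentials: [n-1, m, m]; entry [t, i, j] scores tag i at
    # position t followed by tag j at position t+1.
    def step(alpha, phi):
        # alpha: [m] log-scores of all prefixes ending in each tag.
        return logsumexp(alpha[:, None] + phi, axis=0), None
    alpha0 = jnp.zeros(log_potentials.shape[1])
    alpha, _ = jax.lax.scan(step, alpha0, log_potentials)
    return logsumexp(alpha)

# Example: 5 tags, sequence of length 8, random potentials.
theta = jax.random.normal(jax.random.PRNGKey(0), (7, 5, 5))
log_z = crf_log_partition(theta)
```

Each scan step performs an $m \times m$ logsumexp, giving the $O(m^2 n)$ total cost mentioned above.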

Segmentation with Semi-Markov CRF
Joint segmentation and tagging can be done with a generalization of the linear chain called Semi-Markov CRF (Sarawagi and Cohen, 2004; Abdel-Hamid et al., 2013; Lu et al., 2016). It has a similar parametrization with transition matrices, except that here transitions can jump over multiple tokens. The partition function is computed with an adjusted version of the forward algorithm that runs in $O(s m^2 n)$, where $s$ is the maximal size of a segment.

Alignment Distributions
Alignment distributions are used in time series analysis (Cuturi and Blondel, 2017), RNA sequence alignment (Wang and Xu, 2011), semantic parsing (Lyu and Titov, 2018) and many other areas.

Monotone Alignment
Monotone alignment between two sequences of lengths $n$ and $m$ admits a tractable partition function that can be computed in $O(nm)$ time using the Needleman-Wunsch (1970) algorithm.
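The following is a minimal, unvectorized sketch of the corresponding dynamic program, summing over all monotone alignments with a shared gap log-potential; a practical implementation would vectorize the computation, e.g. along anti-diagonals:

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def monotone_alignment_log_partition(match, gap=0.0):
    # match: [n, m] log-potentials for aligning element i to element j;
    # `gap` is a shared log-potential for skipping an element.
    n, m = match.shape
    chart = jnp.full((n + 1, m + 1), -jnp.inf).at[0, 0].set(0.0)
    for i in range(n + 1):
        for j in range(m + 1):
            if i == 0 and j == 0:
                continue
            # Three predecessor cells: diagonal (match), up and left (gaps).
            cands = jnp.array([
                chart[i - 1, j - 1] + match[i - 1, j - 1]
                if i > 0 and j > 0 else -jnp.inf,
                chart[i - 1, j] + gap if i > 0 else -jnp.inf,
                chart[i, j - 1] + gap if j > 0 else -jnp.inf,
            ])
            chart = chart.at[i, j].set(logsumexp(cands))
    return chart[n, m]
```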

CTC
Connectionist Temporal Classification (CTC; Graves et al., 2006; Hannun, 2017) is a monotone alignment model widely used for speech recognition and non-auto-regressive machine translation. It differs from standard monotone alignment in that it requires special treatment of the blank symbol, which provides jumps in the alignment table. It is implemented with an adjusted version of the Needleman-Wunsch algorithm.

Non-Monotone 1-on-1 Alignment
This is a bijective alignment that directly maps elements between two sets given their matching scores. Computing the partition function for this distribution is intractable (Valiant, 1979), but we can compute some other useful quantities (see Section 5).

Tree-CRF
Today's most popular constituency parser, by Kitaev et al. (2019), uses a global model with factors defined over labelled spans. Stern et al. (2017) have shown that inference in this model can be done efficiently with a custom version of the CKY algorithm in $O(m n^2 + n^3)$, where $m$ is the number of non-terminals and $n$ is the sentence length.

PCFG
Probabilistic Context-Free Grammars (PCFGs) are a generative model over constituency trees in which each grammar rule is associated with a locally normalized probability. These rules serve as a template which, when expanded, jointly generates a constituency tree together with the words as its leaves.
SynJax computes the partition function using a vectorized form of the CKY algorithm that runs in cubic time. Computing the probability of a tree is in principle simple: enumerate the rules used in the tree, look up their probabilities in the grammar, and multiply them. However, extracting rules from a set of labelled spans requires many sparse operations that are non-trivial to vectorize. We instead use sticky span log-potentials that serve as a mask for each constituent: constituents that are part of the tree have sticky log-potential 0, while those that are not have $-\infty$. With sticky log-potentials set this way, computing the log-partition function yields the log-probability of the tree of interest.
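The sketch below shows an unvectorized CKY inside pass for a grammar in Chomsky normal form, with the sticky span log-potentials threaded through; `pcfg_inside` and its argument layout are illustrative, not SynJax's API:

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def pcfg_inside(rule_logits, word_logits, sticky=None):
    # rule_logits: [NT, NT, NT], log P(A -> B C);
    # word_logits: [n, NT], log P(A -> w_i);
    # sticky: [n, n+1] span mask (0 everywhere recovers log Z).
    n, NT = word_logits.shape
    if sticky is None:
        sticky = jnp.zeros((n, n + 1))
    chart = jnp.full((n, n + 1, NT), -jnp.inf)
    for i in range(n):
        chart = chart.at[i, i + 1].set(word_logits[i] + sticky[i, i + 1])
    for width in range(2, n + 1):
        for i in range(n - width + 1):
            j = i + width
            ks = jnp.arange(i + 1, j)
            left = chart[i, ks]              # [splits, B]
            right = chart[ks, j]             # [splits, C]
            pair = left[:, :, None] + right[:, None, :]   # [splits, B, C]
            span = logsumexp(rule_logits[None] + pair[:, None],
                             axis=(0, 2, 3))               # [A]
            chart = chart.at[i, j].set(span + sticky[i, j])
    return chart[0, n, 0]  # assume nonterminal 0 is the start symbol

# With sticky[i, j] = 0 on a given tree's spans and -inf elsewhere,
# the same call returns that tree's log-probability instead of log Z.
```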

TD-PCFG
Tensor-Decomposition PCFG (TD-PCFG; Cohen et al., 2013; Yang et al., 2022) uses a low-rank tensor approximation of a PCFG that makes inference with a much larger number of non-terminals feasible.

Spanning Trees
Spanning trees appear in the literature in many different forms and definitions. We take a spanning tree to be any subgraph that connects all nodes and has no cycles. We divide spanning-tree CRF distributions by the following three properties:

Directed or undirected. Undirected spanning trees are defined over symmetric weighted adjacency matrices, i.e. over undirected graphs. Directed spanning trees are defined over directed graphs with a special root node.

Projective or non-projective. Projectivity is a constraint that appears often in NLP. It requires that the spanning tree over the words has no crossing edges. A non-projective spanning tree is just a regular spanning tree, i.e. it need not satisfy the projectivity constraint.

Single root edge or multiple root edges. NLP applications usually require that only one edge leaves the root node (Zmigrod et al., 2020). Single-root-edge spanning trees satisfy that constraint.
Each of these choices directly determines which algorithm should be used for probabilistic inference. SynJax abstracts this away from the user and offers a unified interface: the user only needs to provide the weighted adjacency matrix and set the three boolean flags above, and SynJax picks the correct and most efficient algorithm. In total, these parameters define distributions over 8 different types of spanning-tree structures, all unified in the same interface.
We are not aware of any other library providing this set of unified features for spanning trees.
We reduce the undirected case to the rooted directed case via a bijection. For projective rooted directed spanning trees we use Eisner's algorithm to compute the partition function (Eisner, 1996). The partition function of non-projective spanning trees is computed using the Matrix-Tree Theorem (Tutte, 1984; Koo et al., 2007; Smith and Smith, 2007).
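For illustration, here is a compact sketch of a Matrix-Tree-Theorem log-partition computation for rooted directed spanning trees, with node 0 acting as the root; the single-root constraint and the more careful numerical treatment of Koo et al. (2007) are omitted:

```python
import jax.numpy as jnp

def mtt_log_partition(log_w):
    # log_w: [n, n]; log_w[h, d] is the log-potential of edge h -> d,
    # with node 0 taken as the root.
    n = log_w.shape[0]
    # Shift for numerical stability before exponentiating.
    c = jnp.max(log_w)
    w = jnp.exp(log_w - c) * (1.0 - jnp.eye(n))   # no self-loops
    indeg = jnp.diag(jnp.sum(w, axis=0))          # weighted in-degrees
    laplacian = indeg - w
    # Delete the root's row and column, take the log-determinant.
    _, logdet = jnp.linalg.slogdet(laplacian[1:, 1:])
    # Undo the shift: every tree has n-1 edges, each carrying e^{-c}.
    return logdet + (n - 1) * c
```

Because `slogdet` is differentiable, this function also supports the gradient-based computation of marginals described in the next section.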

Computing Marginals
In many cases we would like to know the probability of a particular part of a structure appearing, regardless of the structure that contains it. In other words, we want to marginalize (i.e. sum) the probabilities of all the structures that contain that part:

$$p(e) = \sum_{t \in T} \mathbb{1}[e \in t]\, p(t) = \sum_{t \in T_e} p(t),$$

where $\mathbb{1}[\cdot]$ is the indicator function, $T$ is the set of all structures, and $T_e$ is the set of structures that contain factor/part $e$.
Computing these marginals was usually done with specialized algorithms such as inside-outside or forward-backward. However, those solutions do not work for distributions that cannot use belief propagation, such as non-projective spanning trees. A more general solution is to use an identity that relates the marginals to the gradient of the log-partition function with respect to the factors' log-potentials:

$$p(e) = \frac{\partial \log Z}{\partial \log \phi(e)}$$

This means that we can use any differentiable implementation of the log-partition function as the forward pass and apply backpropagation to compute the marginal probabilities (Darwiche, 2003). Eisner (2016) has made this connection explicit: "inside-outside and forward-backward algorithms are just backprop". The approach also works for non-projective spanning trees, which do not fit the belief propagation framework (Zmigrod et al., 2021).
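With the forward-algorithm sketch from the sequence-tagging section, the whole computation is one call to `jax.grad`:

```python
import jax

# Reusing `crf_log_partition` and `theta` from the earlier sketch:
# differentiating log Z with respect to the log-potentials yields the
# marginal probability of every (position, tag, tag) transition.
edge_marginals = jax.grad(crf_log_partition)(theta)
```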
For template models like PCFGs, we again use the sticky log-potentials, because we are usually interested not in the marginal probabilities of the rules but in the marginal probabilities of the instantiated constituents. The derivative of the log-partition function with respect to a constituent's sticky log-potential gives the marginal probability of that constituent.
To find the log-potential of the most probable structure we can run the same belief propagation algorithm, but with the max-plus semiring (Goodman, 1999). To get the most probable structure itself, and not just its potential, we can compute the gradient of the structure's log-potential with respect to the potentials of the parts (Rush, 2020).
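In the linear-chain sketch this amounts to swapping logsumexp for max and differentiating; assuming a unique maximizer, the gradient is a one-hot encoding of the best structure:

```python
import jax
import jax.numpy as jnp

def best_score(log_potentials):
    # The forward recursion with logsumexp replaced by max
    # (the max-plus semiring).
    def step(alpha, phi):
        return jnp.max(alpha[:, None] + phi, axis=0), None
    alpha, _ = jax.lax.scan(step, jnp.zeros(log_potentials.shape[1]),
                            log_potentials)
    return jnp.max(alpha)

theta = jax.random.normal(jax.random.PRNGKey(0), (7, 5, 5))
# One-hot indicator of the transitions used by the best tag sequence.
best_structure = jax.grad(best_score)(theta)
```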
The only exceptions to this process are non-monotone alignments and spanning trees. For the highest-scoring non-monotone alignment, we use the Jonker-Volgenant algorithm as implemented in SciPy (Crouse, 2016; Virtanen et al., 2020). The maximal projective spanning tree can be found by combining Eisner's algorithm with the max-plus semiring, but we have found Kuhlmann et al.'s (2011) tabulated arc-hybrid algorithm to be much faster (see Figure 4 in the appendix). This algorithm cannot be used for any inference task other than argmax because it allows spurious derivations. To enforce the single-root constraint with Kuhlmann's algorithm we use the reweighting trick from Stanojević and Cohen (2021). For non-projective maximum spanning trees SynJax uses the combination of reweighting and Tarjan's algorithm proposed by Stanojević and Cohen (2021).

Sampling a Structure
Strictly speaking, there is no proper sampling semiring, because semirings cannot have non-deterministic output. However, we can abuse the semiring framework and make some aspects of it non-deterministic. The first approach, by Aziz (2015), tracks two numbers: the inside probability (as in the log-semiring) and a sampled structure score (similar to the max-plus semiring, except that the structure is sampled using inside scores). The second approach, by Rush (2020), behaves like the log-semiring in the forward pass, while in the backward pass it samples instead of computing the gradient. This is in line with how the forward-filtering backward-sampling algorithm works (Murphy, 2012, §17.4.5). We use Rush's version of the sampling semiring, as it was faster in our experiments.
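Here is a sketch of forward-filtering backward-sampling for the linear-chain case, in the same transition-only parametrization as the earlier forward-algorithm example:

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def sample_tag_sequence(key, log_potentials):
    # log_potentials: [n-1, m, m], as in the forward-algorithm sketch.
    def fwd(alpha, phi):
        # Emit alpha before consuming phi so alphas[t] covers prefix t.
        return logsumexp(alpha[:, None] + phi, axis=0), alpha
    alpha0 = jnp.zeros(log_potentials.shape[1])
    alpha_last, alphas = jax.lax.scan(fwd, alpha0, log_potentials)

    # Sample the final tag from the filtered distribution, then walk
    # backwards, sampling each tag conditioned on its successor.
    key, sub = jax.random.split(key)
    last_tag = jax.random.categorical(sub, alpha_last)

    def bwd(carry, inputs):
        next_tag, key = carry
        alpha, phi = inputs
        key, sub = jax.random.split(key)
        tag = jax.random.categorical(sub, alpha + phi[:, next_tag])
        return (tag, key), next_tag

    (first_tag, _), later_tags = jax.lax.scan(
        bwd, (last_tag, key), (alphas, log_potentials), reverse=True)
    return jnp.concatenate([first_tag[None], later_tags])
```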
Non-projective spanning trees do not support the semiring framework, so we use the custom algorithms described in Stanojević (2022): Colbourn's algorithm, which has a fixed runtime of $O(n^3)$ but is prone to numerical issues because it requires matrix inversion (Colbourn et al., 1996), and Wilson's algorithm, which is more numerically stable but has a runtime that depends on the concrete values of the log-potentials (Wilson, 1996). SynJax also supports vectorized sampling without replacement (SWOR) from Stanojević (2022).

Entropy and KL Divergence
To compute the cross-entropy and KL divergence, we assume that the two distributions factorize in exactly the same way. Like some other properties, cross-entropy can be computed with appropriate semirings (Hwa, 2000; Eisner, 2002; Cortes et al., 2008; Chang et al., 2023), but those approaches do not work for non-projective spanning tree distributions. There is a surprisingly simple solution that works across all distributions that factorize in the same way, and it has appeared in a couple of works in the past (Li and Eisner, 2009; Martins et al., 2010; Zmigrod et al., 2021). Here we give a full derivation for cross-entropy:

$$H(p, q) = -\sum_{t \in T} p(t) \log q(t) = -\sum_{t \in T} p(t) \Big( \sum_{e \in t} \log \phi_q(e) - \log Z_q \Big) = \log Z_q - \sum_{e} p(e) \log \phi_q(e)$$

This reduces the computation of cross-entropy to finding the marginal probabilities $p(e)$ of one distribution and the log-partition function of the other, both of which can be computed efficiently for all distributions in SynJax. Given the method for computing cross-entropy, finding entropy is trivial:

$$H(p) = H(p, p),$$

and KL divergence is easy to compute too:

$$\mathrm{KL}(p \,\|\, q) = H(p, q) - H(p).$$
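In code, given any differentiable log-partition function over log-potentials (e.g. the `crf_log_partition` sketch from earlier), all three quantities are a few lines each:

```python
import jax
import jax.numpy as jnp

# H(p, q) = log Z_q - sum_e p(e) * theta_q(e), where p(e) are the
# marginals of p and theta_q are q's log-potentials.
def cross_entropy(theta_p, theta_q, log_partition):
    marginals_p = jax.grad(log_partition)(theta_p)
    return log_partition(theta_q) - jnp.sum(marginals_p * theta_q)

def entropy(theta, log_partition):
    return cross_entropy(theta, theta, log_partition)

def kl_divergence(theta_p, theta_q, log_partition):
    return (cross_entropy(theta_p, theta_q, log_partition)
            - entropy(theta_p, log_partition))
```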

Library Design
Each distribution has different, complex shape constraints, which makes it complicated to document and implement all the checks verifying that the user has provided the right arguments. The jaxtyping library was extremely valuable in keeping SynJax code concise, documented, and automatically checked.
Structured algorithms require complex broadcasting, reshaping operations, and applications of semirings. To keep this code simple, we took the einsum implementation from the core JAX code and modified it to support arbitrary semirings. This made the code shorter and easier to read.
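The core idea, shown here for a single matrix-matrix contraction in the log-semiring rather than SynJax's general semiring einsum:

```python
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def log_matmul(a, b):
    # The log-semiring analogue of jnp.einsum("ik,kj->ij", a, b):
    # multiplication becomes addition, addition becomes logsumexp.
    return logsumexp(a[:, :, None] + b[None, :, :], axis=1)
```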
Most inference algorithms apply a large number of elementwise and reshaping operations that are in general fast but create many intermediate tensors that occupy memory. To address this we use checkpointing (Griewank, 1992) to avoid storing tensors that can be recomputed quickly. This improved both memory usage and speed, especially on TPU.
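A minimal example of the mechanism, using `jax.checkpoint` on a stand-in for one layer of a dynamic program:

```python
import jax
import jax.numpy as jnp

@jax.checkpoint  # recompute intermediates in the backward pass
def chart_update(chart, potentials):
    # Stand-in for one cheap-to-recompute layer of a dynamic program:
    # its intermediates need not be kept in memory for backprop.
    return jnp.tanh(chart + potentials)

loss = lambda c, p: chart_update(c, p).sum()
grads = jax.grad(loss)(jnp.zeros((8, 8)), jnp.ones((8, 8)))
```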
All functions that can be vectorized are written in pure JAX. Those that cannot, like Wilson's (1996) sampling algorithm and Tarjan's (1977) algorithm, are implemented with Numba (Lam et al., 2015).
All SynJax distributions inherit from Equinox modules (Kidger and Garcia, 2021), which makes them simultaneously PyTrees and dataclasses. Thereby all SynJax distributions can be transformed with jax.vmap and are compatible with any JAX neural network framework, e.g. Haiku and Flax.
Comparison to Alternative Libraries
JAX has a couple of libraries for probabilistic modeling. Distrax (Babuschkin et al., 2020) and the TensorFlow Probability JAX substrate (Dillon et al., 2017) provide continuous distributions. NumPyro (Phan et al., 2019) and Oryx provide probabilistic programming. DynaMax (Chang et al., 2022) brings state-space models to JAX and includes an implementation of HMMs.
PGMax (Zhou et al., 2023) is a belief-propagation library for JAX that supports structured inference, but it expects the user to implement the dynamic programming algorithms for each structure. SynJax, on the other hand, implements all the necessary algorithms for the supported structured distributions, and the user only needs to provide the distribution's parameters. Additionally, SynJax provides unbiased samples, while PGMax provides biased samples via the perturb-and-map framework (Papandreou and Yuille, 2011).
Optax (Babuschkin et al., 2020) provides a CTC loss implementation for JAX, but without support for computing the optimal alignment, marginals over alignment links, sampling alignments, etc.
All the mentioned JAX libraries focus on continuous or categorical distributions and, with the exception of HMMs and the CTC loss, do not contain the distributions provided by SynJax. SynJax fills this gap in the JAX ecosystem and enables easier construction of structured probability models. The most comparable library in terms of features is Torch-Struct (Rush, 2020), which targets PyTorch as its underlying framework. Torch-Struct, just like SynJax, uses automatic differentiation for efficient inference. We point out here only the main differences relevant to users. SynJax supports a larger number of distributions and algorithms and gives a unified interface to all of them. It also provides reproducible sampling through controlled randomness seeds. SynJax has a more general approach to computing entropy that does not depend on semirings and therefore applies to all distributions. SynJax is fully implemented in Python and compiled with jax.jit and numba.jit, while Torch-Struct does not use any compiler optimizations except a custom CUDA kernel for semiring matrix multiplication. Comparing lines of code and speed (Table 1), SynJax is much more concise and faster than Torch-Struct (see Appendix A for details).
Non-projective spanning trees are not covered by any of the alternatives above. The main libraries for these types of trees are by Zmigrod et al. (2020) and Stanojević and Cohen (2021). SynJax builds on Stanojević and Cohen's code and annotates it with Numba instructions, which makes it many times faster than any other alternative (see Figure 3 in the appendix).

Conclusion
One of the main challenges in creating deep neural models over structured distributions is the difficulty of implementing them on modern hardware accelerators. SynJax addresses this problem and makes large-scale training of structured models feasible and easy in JAX. We hope that this will encourage research into alternatives to auto-regressive modeling of structured data.

Limitations
SynJax is quite fast, but there are still areas where improvements could be made. One of the main speed and memory bottlenecks is the use of big temporary tensors in the dynamic programming algorithms needed to compute the log-partition function. This could be optimized with custom kernels written in JAX-Triton. Some speed gains would be conceptually simple but depend on specialized hardware. For instance, matrix multiplication with semirings currently does not use hardware acceleration for matrix multiplication, such as TensorCores on GPU, and instead computes with regular CUDA cores. We tried to address this with the log-einsum-exp trick (Peharz et al., 2020), but the resulting computation was less numerically precise than a regular log-semiring with broadcasting. The maximum spanning tree algorithm would be much faster if it could be vectorized; currently it executes as optimized Numba CPU code.

A.1 Comparison with Torch-Struct
The results are shown in Table 1 in the main text. Table 2 shows the sizes of the distributions being tested.
A.2 Comparison with Zmigrod et al.
Non-projective spanning trees present a particular challenge because they cannot be vectorized easily, due to the dynamic structures involved in the algorithm. The main algorithms and libraries for parsing this type of tree are from Zmigrod et al. (2020) (https://github.com/rycolab/spanningtrees) and Stanojević and Cohen (2021) (https://github.com/stanojevic/Fast-MST-Algorithm). The first is expressed as a recursive algorithm, while the second operates over arrays of fixed size in an iterative way. This makes Stanojević and Cohen's algorithm much more amenable to Numba optimization. We took that code and annotated it with Numba primitives. This made the algorithm significantly faster, especially for big graphs, as can be seen in Figure 3.

A.3 Comparison of Maximum Projective Spanning Tree Algorithms
Eisner's algorithm is virtually the only projective parsing algorithm in active use, if we do not count transition-based parsers. We have found that replacing Eisner's algorithm with Kuhlmann et al.'s (2011) tabulation of the arc-hybrid algorithm provides large speed gains on both CPU and GPU (see Figure 4). In this implementation graph size does not make a big difference, because the algorithm is implemented in a vectorized way and most operations are parallelized.