Computationally Efficient Wasserstein Loss for Structured Labels

The problem of estimating the probability distribution of labels has been widely studied as a label distribution learning (LDL) problem, whose applications include age estimation, emotion analysis, and semantic segmentation. We propose a tree-Wasserstein distance regularized LDL algorithm, focusing on hierarchical text classification tasks. We propose predicting the entire label hierarchy using neural networks, where the similarity between predicted and true labels is measured using the tree-Wasserstein distance. Through experiments using synthetic and real-world datasets, we demonstrate that the proposed method successfully considers the structure of labels during training, and it compares favorably with the Sinkhorn algorithm in terms of computation time and memory usage.


Introduction
Label distribution learning (LDL), which is a generalized framework for performing single- and multi-label classification and estimating the probability distribution over labels, is an important machine-learning problem (Geng, 2016). Its applications include age estimation (Geng et al., 2013), emotion estimation (Zhou et al., 2016), head-pose estimation (Geng and Xia, 2014), and semantic segmentation (Gao et al., 2017). In particular, multi-label classification is an important problem in many NLP areas and has several applications, including multi-label text classification (Banerjee et al., 2019; Chalkidis et al., 2019).
Typically, the Kullback-Leibler (KL) divergence is used to measure the similarity between two distributions. However, the KL divergence tends to infinity when the supports of the two distributions do not overlap, which can cause the model to fail.
To avoid this support problem, the Wasserstein distance is used instead of the KL divergence (Arjovsky et al., 2017). The Wasserstein distance is defined as the cost of optimally transporting one probability distribution to match another (Villani, 2009; Peyré and Cuturi, 2018). Because it compares two probability measures while accounting for the ground metric, it is more powerful than measures that ignore geometrical information.
An LDL framework with the Wasserstein distance has recently been proposed (Frogner et al., 2015; Zhao and Zhou, 2018). This framework employs the Sinkhorn algorithm (Cuturi, 2013) to calculate the Wasserstein distance, which requires quadratic computation time. Thus, when we consider extremely large label sets (e.g., 10^5 labels), the computation cost can be significant. In contrast, the Wasserstein distance on a tree (hereinafter called the tree-Wasserstein distance) can be written in closed form and calculated in linear time (Evans and Matsen, 2012; Le et al., 2019).
In this paper, we propose a tree-regularized LDL algorithm based on the tree-Wasserstein distance. The key advantage of the tree-Wasserstein distance is that it considers the hierarchical label information explicitly, whereas a Sinkhorn-based algorithm must first construct a cost matrix from the tree-structured data. Moreover, the tree-Wasserstein distance has an analytic form that can be computed in linear time with significantly less memory. We experimentally demonstrate that the proposed algorithm compares favorably with the Sinkhorn-based LDL algorithm (Frogner et al., 2015; Zhao and Zhou, 2018), with considerably lower memory consumption and computational cost.
Contribution: Our contributions are summarized as follows. (1) We propose training a model by minimizing the tree-Wasserstein distance for hierarchical labels, and (2) we demonstrate that the resulting loss can be computed far more efficiently than the existing Sinkhorn-based Wasserstein loss in terms of both computation time and memory usage.

Figure 1: Illustration of a tree-structured label with the root "animal".

Problem Setting
We observe n input and output samples {(x_1, y_1), …, (x_n, y_n)} from (X, Y), where X ⊂ R^d. We consider the problem of learning a map from the feature space X into P, the set of distributions over a finite set Y.
For example, multi-class classification is included in this problem: the label y representing the ℓ-th class is expressed as the one-hot vector

y = (0, …, 0, 1, 0, …, 0)^T,

whose ℓ-th element is 1, where L denotes the total number of classes and y^T 1_L = 1. Here, 1_L ∈ R^L denotes the vector whose elements are all 1. When multi-label classification is considered, y is a binary vector that indicates the existing labels. For example, if the sample x belongs to classes ℓ and ℓ′, y is given as the binary vector whose ℓ-th and ℓ′-th elements are 1, where y^T 1_L = 2. Accordingly, we can transform y into a probability vector as p_y = y / (y^T 1_L). Notably, we assume that Y is discrete and has a tree structure, as with hierarchical labels.
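The normalization p_y = y / (y^T 1_L) can be sketched in a few lines of Python (the label indices in the example are arbitrary):

```python
def to_probability_vector(y):
    """Normalize a multi-label indicator vector y into a probability
    vector p_y = y / (y^T 1_L), as described above."""
    total = sum(y)  # y^T 1_L: the number of active labels
    if total == 0:
        raise ValueError("y must contain at least one active label")
    return [v / total for v in y]

# A sample belonging to classes 1 and 3 (0-indexed) out of L = 5:
y = [0, 1, 0, 1, 0]
p_y = to_probability_vector(y)  # [0.0, 0.5, 0.0, 0.5, 0.0]
```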
We aim to estimate the conditional probability vector p_y for x, considering the structure information of Y, from {(x_1, p_{y_1}), …, (x_n, p_{y_n})}.

Proposed Method
In this study, we assume that Y has a tree structure. Accordingly, we propose LDL with the tree-Wasserstein distance.

Wasserstein distance on tree metrics
Let T be a tree with non-negative edge weights and N_T be the set of nodes of T. The shortest-path metric d_T : N_T × N_T → R associated with T is called the tree metric. Let v and v′ be nodes in T. Then d_T(v, v′) equals the sum of the edge weights along the shortest path between v and v′. Hence, M_T = (N_T, d_T) is a metric space naturally derived from T.
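The tree metric d_T can be evaluated without materializing a full distance matrix. The following sketch (node names and the parent-dict tree representation are illustrative) computes d_T(u, v) by climbing from both nodes to their lowest common ancestor:

```python
def tree_distance(parent, weight, u, v):
    """Tree metric d_T(u, v): the sum of edge weights on the unique
    path between u and v.  parent[x] is the parent of node x (None for
    the root); weight[x] is the weight of the edge (parent[x], x)."""
    def path_to_root(x):
        path = []
        while x is not None:
            path.append(x)
            x = parent[x]
        return path

    ancestors_of_u = set(path_to_root(u))
    # The first ancestor of v that is also an ancestor of u is the LCA.
    lca = next(x for x in path_to_root(v) if x in ancestors_of_u)

    def climb(x, stop):
        d = 0.0
        while x != stop:
            d += weight[x]
            x = parent[x]
        return d

    return climb(u, lca) + climb(v, lca)

# A chain r - a - b with unit edge weights:
parent = {"r": None, "a": "r", "b": "a"}
weight = {"a": 1.0, "b": 1.0}
print(tree_distance(parent, weight, "b", "r"))  # 2.0
```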
It is assumed that T is rooted at r. For each node v, the set of nodes in the subtree of T rooted at v is defined as Γ(v) = {v′ ∈ N_T | v ∈ P(v′)}, where P(v′) denotes the set of nodes on the unique path from a node v′ to the root r in T. For each edge e, v_e denotes the deeper-level node of the two nodes connected by e. Figure 1 illustrates a tree-structured label.
Given two probability measures μ and ν supported on M_T, the 1-Wasserstein distance between μ and ν is expressed in the following closed form (Evans and Matsen, 2012; Le et al., 2019):

W_{d_T}(μ, ν) = Σ_{e ∈ T} w_e |μ(Γ(v_e)) − ν(Γ(v_e))|,

where w_e denotes the weight of edge e. The key advantage of the tree-Wasserstein distance is that it can be computed in linear time, whereas the time complexity of the Sinkhorn algorithm is quadratic (Cuturi, 2013).
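The closed form above can be sketched as follows: accumulate the subtree masses μ(Γ(v_e)) and ν(Γ(v_e)) bottom-up, then sum one weighted term per edge (node names and the parent-dict tree representation are illustrative):

```python
def tree_wasserstein(parent, weight, mu, nu):
    """Closed-form tree-Wasserstein distance (Evans and Matsen, 2012;
    Le et al., 2019): sum_e w_e |mu(Gamma(v_e)) - nu(Gamma(v_e))|.
    parent[v] is the parent of node v (None for the root); weight[v]
    is the weight of the edge (parent[v], v); mu, nu map nodes to
    probability mass.  The closed form itself needs one pass over the
    edges; this sketch sorts nodes by depth for simplicity."""
    depth = {}
    def get_depth(v):
        if v not in depth:
            p = parent[v]
            depth[v] = 0 if p is None else get_depth(p) + 1
        return depth[v]

    sub_mu = {v: mu.get(v, 0.0) for v in parent}
    sub_nu = {v: nu.get(v, 0.0) for v in parent}
    # Deepest nodes first, so each node's subtree mass is complete
    # before it is added to its parent.
    for v in sorted(parent, key=get_depth, reverse=True):
        p = parent[v]
        if p is not None:
            sub_mu[p] += sub_mu[v]
            sub_nu[p] += sub_nu[v]
    # One term per edge; v_e is the deeper node v of edge (parent[v], v).
    return sum(weight[v] * abs(sub_mu[v] - sub_nu[v])
               for v in parent if parent[v] is not None)

# A chain r - a - b with unit edge weights: moving all mass from r to b
# must cross two edges, so the distance is 2.
parent = {"r": None, "a": "r", "b": "a"}
weight = {"a": 1.0, "b": 1.0}
print(tree_wasserstein(parent, weight, {"r": 1.0}, {"b": 1.0}))  # 2.0
```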

LDL with tree-Wasserstein distance
We define the tree-Wasserstein regularizer as follows.
Definition 1 (tree-Wasserstein regularizer). Let h_θ : X → P be a model with learnable parameters θ. Given an input x ∈ X and the ground-truth distribution p_y ∈ P of a label y, the tree-Wasserstein regularization term TW(x, p_y) is defined as

TW(x, p_y) = Σ_{e ∈ T} w_e |h_θ(x)(Γ(v_e)) − p_y(Γ(v_e))|,

where h_θ denotes the prediction model and h_θ(x)(Γ(v_e)) is the total predicted probability mass on the subtree Γ(v_e).
Using the tree-Wasserstein regularizer, we propose the following LDL objective:

min_θ Σ_{i=1}^n (1 − λ) KL(p_{y_i} || h_θ(x_i)) + λ TW(x_i, p_{y_i}),

where KL(p || q) = Σ_{ℓ=1}^L p_ℓ log(p_ℓ / q_ℓ) is the multi-class Kullback-Leibler loss function, and λ ∈ [0, 1] is the combination parameter.
Here, L denotes the number of labels. Unlike the Sinkhorn-Knopp algorithm, we need not compute or store a distance matrix. For tree-structured labels, including hierarchical labels, the tree structure can be used directly as the tree metric. If we have prior knowledge about the labels (e.g., their similarity), we can use it to set the edge weights.
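As an illustration, the following sketch combines the two loss terms for the special case where the label tree is a chain, in which the tree-Wasserstein term reduces to an L1 distance between cumulative sums. The convex-combination weighting shown here is one plausible reading of the objective, not necessarily the paper's exact form:

```python
import math

def kl_loss(p, q, eps=1e-12):
    """Multi-class KL divergence KL(p || q) between probability vectors."""
    return sum(pi * math.log((pi + eps) / (qi + eps))
               for pi, qi in zip(p, q) if pi > 0)

def tw_chain(p, q, w=1.0):
    """Tree-Wasserstein distance when the label tree is the chain
    0 - 1 - ... - L-1 with uniform edge weight w.  Rooted at node 0,
    the subtree below edge (k, k+1) is the tail {k+1, ..., L-1}, so the
    closed form reduces to w * sum_k |cumsum_p(k) - cumsum_q(k)|."""
    cp = cq = total = 0.0
    for pi, qi in zip(p[:-1], q[:-1]):
        cp += pi
        cq += qi
        total += w * abs(cp - cq)
    return total

def ldl_loss(p_pred, p_true, lam):
    """Hypothetical combined objective (1 - lam) * KL + lam * TW."""
    return (1 - lam) * kl_loss(p_true, p_pred) + lam * tw_chain(p_pred, p_true)
```

With λ = 0 this recovers the plain KL loss, and with λ = 1 the pure tree-Wasserstein loss, matching the baseline settings used later in the experiments.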
Related Work

Label distribution learning

LDL (Geng, 2016) is the task of estimating the distribution of labels for each input. Whereas age estimation (Geng et al., 2013), head-pose estimation (Geng and Xia, 2014), and semantic segmentation (Gao et al., 2017) are well-known LDL tasks, in this study we consider the task of estimating a distribution over a hierarchical structure. The key difference between LDL and a generative model is that the "true" distribution over labels is given in LDL.

Wasserstein distance
Given two probability vectors a, b ∈ R^n_{≥0} and a distance matrix D ∈ R^{n×n}_{≥0}, the 1-Wasserstein distance W_1(a, b) between a and b is defined as

W_1(a, b) = min_{P ∈ Π(a, b)} Σ_{i,j} P_{ij} D_{ij},

where Π(a, b) = {P ∈ R^{n×n}_{≥0} | P 1_n = a, P^T 1_n = b} denotes the set of transport plans. Because the Wasserstein distance can incorporate the ground metric when comparing probability distributions, it has been widely used in applications including domain adaptation, generative models (Arjovsky et al., 2017), and natural language processing (Kusner et al., 2015). A loss function based on the Wasserstein distance can improve predictions that depend on a structure of labels (Frogner et al., 2015; Zhao and Zhou, 2018). Additionally, an entropic optimal-transport loss can provide robustness against noisy labels by finding a coupling of the data samples and propagating their labels according to the coupling weights (Damodaran et al., 2020).

Frogner et al. (2015) proposed learning with a Wasserstein loss to account for geometric information when predicting a probability distribution. Because computing a sub-gradient of the exact Wasserstein loss is expensive, they estimated the sub-gradient by introducing an entropic-regularization term and using the Sinkhorn-Knopp algorithm. Although they also suggested extending the Wasserstein loss to unnormalized measures, we do not consider this case. Zhao and Zhou (2018) showed that a Wasserstein loss benefits LDL by simultaneously learning label correlations and the label distribution. In contrast, we propose learning with an exact Wasserstein distance that can be computed efficiently when the ground metric is represented by a tree. Le et al. (2019) suggested the tree-sliced Wasserstein distance, which approximates the Wasserstein distance on a continuous space by averaging Wasserstein distances on tree metrics constructed by dividing that space. An unbalanced variant of the tree-Wasserstein distance has also been proposed recently (Sato et al., 2020).
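For contrast with the tree-based closed form, a minimal dense Sinkhorn-Knopp sketch illustrates the entropic-regularized baseline computation; the regularization value and iteration count below are illustrative, and each iteration touches the full n × n kernel, which is the source of the quadratic cost:

```python
import math

def sinkhorn(a, b, D, reg=0.1, n_iters=100):
    """Entropic-regularized optimal transport via Sinkhorn-Knopp
    (Cuturi, 2013).  a, b: probability vectors; D: cost matrix as a
    list of lists.  Returns the transport cost <P, D> of the resulting
    plan P.  Every update scans the dense n x n kernel K, hence the
    quadratic time and memory cost discussed above."""
    n, m = len(a), len(b)
    K = [[math.exp(-D[i][j] / reg) for j in range(m)] for i in range(n)]
    u, v = [1.0] * n, [1.0] * m
    for _ in range(n_iters):
        # Alternate scaling so that P = diag(u) K diag(v) matches the
        # marginals a and b.
        u = [a[i] / sum(K[i][j] * v[j] for j in range(m)) for i in range(n)]
        v = [b[j] / sum(K[i][j] * u[i] for i in range(n)) for j in range(m)]
    P = [[u[i] * K[i][j] * v[j] for j in range(m)] for i in range(n)]
    return sum(P[i][j] * D[i][j] for i in range(n) for j in range(m))

# Transporting all mass across one unit of distance costs ~1.
print(sinkhorn([1.0, 0.0], [0.0, 1.0], [[0.0, 1.0], [1.0, 0.0]]))
```

In practice the log-domain variant (Schmitzer, 2019) mentioned in the experiments is preferred for numerical stability; this plain version is only for exposition.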

Experiments
We applied our proposed method to LDL on trees using a synthetic dataset, and to multi-label text classification with a hierarchical label structure using a real dataset. We implemented all methods in PyTorch (Paszke et al., 2019). Our models were optimized by a gradient method with the Adam optimizer (Kingma and Ba, 2015).
Baselines: We compared our proposed method with the Wasserstein-loss-based LDL framework (Frogner et al., 2015; Zhao and Zhou, 2018) and the multi-class KL loss. Notably, the original paper (Zhao and Zhou, 2018) did not include a KL loss and used only the Wasserstein loss, whereas Frogner et al. (2015) used a linear combination of the KL divergence and the Wasserstein distance as the loss. To ensure a fair comparison, we also report the combination of the Wasserstein loss and the multi-class KL loss as a strong baseline. Accordingly, we set the combination parameter λ ∈ {0, 1/2, 1} and the weight of all edges to 1. The Wasserstein loss was computed using the Sinkhorn-Knopp algorithm in the log domain (Schmitzer, 2019; Peyré and Cuturi, 2018) on GPUs. For the proposed method, we computed the tree-Wasserstein loss on the CPU and then passed it to the GPU to compute the gradient. We set the number of iterations of the Sinkhorn-Knopp algorithm to 10 and its regularization parameter to 50.

Synthetic data
We generated a synthetic dataset comprising pairs of a real vector and a target probability distribution on the nodes of a randomly generated tree. The dataset was created as follows. First, we defined a parametric distribution on a graph: given a graph G = (V, E) with shortest-path metric d_G, the probability distribution F_{v,u,σ} over V, parameterized by v, u ∈ V and σ > 0, is defined in terms of d_G. Algorithm 1 shows the procedure used to generate the dataset used in the experiments. In this experiment, we prepared datasets with distributions on a random tree with 1000 nodes generated using NetworkX (Hagberg et al., 2008). The training and testing datasets each contain 1000 samples. We set the number of epochs to 500 and the batch size to 10, and we fixed the learning rate at 0.001. We report the average scores over 10 runs with different random seeds.
Predictive model: We adopted a feed-forward model whose softmax output gives the predicted probability of each class ℓ.

Evaluation metric: To evaluate the predictions from various perspectives, we used the metrics listed in Table 3. Notably, we used the exact Wasserstein distance (denoted Wasserstein) between the predicted and ground-truth label distributions to assess the extent to which the ground metric was considered in the prediction. In these experiments, we used the Python Optimal Transport (POT) library (Flamary and Courty, 2017) to calculate the exact Wasserstein distance, and the weights of all edges were set to 1. The other evaluation metrics are the same as those used by Geng (2016). The scores of the experiment with synthetic data are presented in Table 1. The proposed linear combination of KL and TW outperformed the others in terms of the Wasserstein and Chebyshev metrics, but performed poorly on the other metrics.

BlurbGenreCollectionEN
In this study, we used the BlurbGenreCollectionEN dataset (Cortes and Vapnik, 1995; Lewis et al., 2004) for the experiments with real data. It comprises advertising descriptions of books from the Penguin Random House webpage. Each instance has one or more labels that are hierarchically structured. Because the hierarchical structure of these data is a forest rather than a tree, we added a root node to the hierarchy. Of the total 91,892 data samples, 64%, 16%, and 20% were used as the training, validation, and test sets, respectively. We set the number of epochs to 100 and the batch size to 100, and we fixed the learning rate at 0.001. We report the average scores and standard deviations over 10 runs with different random seeds.

Table 3: Evaluation metrics for LDL. h_θ(x) is the predicted distribution of x, and p_y is the ground-truth distribution of a label y.
Predictive model: We adopted a long short-term memory (LSTM) (Hochreiter and Schmidhuber, 1997) model with a hidden-state size of 200. Because LSTMs can efficiently learn long-term dependencies in sequential data, they have often been used in the natural-language-processing domain (Yin et al., 2017; Kuncoro et al., 2018). Additionally, we used fastText word embeddings. A fully connected layer precedes the output layer, and the output function is a softmax.
Evaluation metric: We evaluated the prediction accuracy using three metrics, namely pseudo-recall, top-k cost, and the area under the receiver operating characteristic curve (ROC-AUC). Pseudo-recall is defined as |P ∩ L| / |L|, where L denotes the set of ground-truth labels, and P is the set comprising the |L| labels with the highest predicted probability scores.
Top-k cost measures how close the predicted top-k labels are to the ground-truth labels, where p_k denotes the label with the k-th highest probability score. We calculated the ROC-AUC using the output distribution of each model as a score vector, evaluated against a binary target that assigns 1 to the ground-truth labels and 0 to the others. Table 2 presents the comparison results. Neither regularization term (W_1 or TW) had a significant impact on the results.
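Pseudo-recall, under the intersection-based definition above, can be sketched as follows (the scores and label indices in the example are illustrative):

```python
def pseudo_recall(scores, true_labels):
    """Pseudo-recall |P ∩ L| / |L|: P holds the |L| labels with the
    highest predicted scores, L is the set of ground-truth labels."""
    k = len(true_labels)
    # Indices of the k highest-scoring labels.
    top = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:k]
    return len(set(top) & set(true_labels)) / k

# Two ground-truth labels; the model ranks labels 1 and 2 highest.
print(pseudo_recall([0.1, 0.5, 0.3, 0.1], {1, 2}))  # 1.0
```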

Computational-efficiency comparison
In the computational-efficiency experiment, distributions with 10^2, 10^3, 10^4, and 10^5 supports were prepared. Subsequently, the computation time and memory required to calculate the loss for pairs of random probability distributions on these supports were measured. To avoid calculating a shortest-path distance matrix, we used the matrix 1 1^T − I, where I denotes the identity matrix, as the distance matrix when computing the Wasserstein loss. Additionally, we used a random tree with edge weights of 1 as the tree metric when computing the tree-Wasserstein loss. We report the average scores of three measurements.

Algorithm 1: Generating a synthetic dataset
1: Generate a random tree G = (V, E), where V = {s_1, …, s_l}
2: W_1 ← (n × m)-dimensional random matrix
3: W_2 ← (m × (l + 1))-dimensional random matrix
4: for i = 1 to N do
5:   x_i ← n-dimensional random vector

Table 4 presents the time and memory required to calculate the losses for various numbers of nodes. TW outperforms the other Wasserstein losses in terms of computation time and is significantly superior in terms of memory consumption. Although W_1 on a GPU is faster than the others with 10^3 supports, it cannot calculate the loss with 10^5 supports because the required memory cannot be allocated.

Conclusions
This study proposed the use of a tree-Wasserstein regularizer for learning. The experimental results indicate that our proposed method successfully predicts the distributions of structured labels and outperforms existing Wasserstein-loss computation methods in terms of both computational speed and memory consumption.