Improving word mover’s distance by leveraging self-attention matrix



Introduction
The task of measuring the semantic textual similarity (STS) of two sentences is important for natural language processing with a variety of applications, including machine translation, summarization, text generation, and question answering (Cer et al., 2017). There are several methods for measuring STS, among which many methods using Optimal Transport (OT) distance have been proposed and have shown good performance (Kusner et al., 2015; Huang et al., 2016; Chen et al., 2019; Yokoi et al., 2020).
OT theory gives a method to measure the difference between two distributions by setting the transportation cost of a unit mass and considering the allocation of the transported mass to minimize the total cost. This allocation is called the optimal transport, and the total cost is called the OT distance. A basic form of OT distance is the Wasserstein distance (also known as the Kantorovich-Rubinstein distance), which measures the similarity between sets using the distance between the elements in the sets (Kantorovich, 1960). Word Mover's Distance (WMD) computes the Wasserstein distance by considering a sentence as a set of word embeddings (Kusner et al., 2015).
However, WMD does not take into account the word order of sentences, making it difficult to identify paraphrases. Let us see the following illustrative example of paraphrase adapted from (Kusner et al., 2015) and (Zhang et al., 2019):

(a) Obama speaks to the media in Illinois.
(b) The President greets the press in Chicago.
(c) The press greets the President in Chicago.

Here (b) is a paraphrase of (a), while (c) has a very different meaning from (a). However, all these sentences have a high overlap of words, thus WMD cannot distinguish these pairs. Figures 1 and 2 illustrate how our proposed method handles this example; see Section 2.
To account for sentence structure, we focused on BERT (Devlin et al., 2019), an attention-based model that has recently achieved remarkable performance in natural language processing tasks. BERT is a masked language model that performs task-specific fine-tuning after pre-training on a large dataset. By inputting a sentence into the pre-trained BERT, we can extract the Self-Attention Matrix (SAM), which represents the dependencies between words in the sentence. The SAM is known to encode information on sentence structure, such as syntactic information (Clark et al., 2019; Htut et al., 2019; Luo, 2021). Therefore, sentence structure can be considered in STS by measuring the distance between SAMs.
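As a concrete illustration (a minimal sketch, not the authors' code), a SAM can be extracted from a pre-trained BERT with the huggingface transformers library as follows. The layer and head indices are placeholder assumptions; the experiments below tune this choice on a dev set, and the paper additionally removes [CLS], [SEP], and stopword tokens, which would require dropping the corresponding rows and columns.

```python
# A minimal sketch of SAM extraction from pre-trained BERT.
import torch
from transformers import AutoModel, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModel.from_pretrained("bert-base-uncased")
model.eval()

def extract_sam(sentence, layer=6, head=1):
    # layer and head are 0-indexed placeholders; how they map to phrases
    # like "the second head of the 7th layer" is a convention choice.
    inputs = tokenizer(sentence, return_tensors="pt")
    with torch.no_grad():
        out = model(**inputs, output_attentions=True)
    # out.attentions is a tuple of 12 tensors, one per layer, each of
    # shape (batch, num_heads, seq_len, seq_len); each row sums to 1
    # because of the softmax in self-attention.
    return out.attentions[layer][0, head].numpy()

A = extract_sam("obama speaks to the media in illinois .")
```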
We propose a novel method called sentence Structure Mover's Distance (SMD) in Section 5 that measures the optimal transport distance between sentence structures represented in SAMs. SMD aims to find the optimal word alignment by considering the transport cost between elements of SAMs, while WMD considers the transport cost between word embeddings. To compute SMD, we employ the Gromov-Wasserstein (GW) distance (Memoli, 2011), the optimal transport distance for measuring structural similarity between sets.
SMD is not effective by itself, because it measures only the difference between sentence structures; thus, SMD is used in combination with WMD. We improve WMD-like methods (WMD and its variants based on the Wasserstein distance between sets of word embeddings) by combining SMD with them in Section 6. This proposed method is called Word and sentence Structure Mover's Distance (WSMD). WSMD measures simultaneously the difference between word embeddings and the difference between sentence structures. WSMD employs the Fused Gromov-Wasserstein (FGW) distance (Vayer et al., 2019, 2020), known as the optimal transport distance that combines both the Wasserstein and GW distances.
In the following, we denote two sentences as $s = (w_1, \dots, w_n)$ and $s' = (w'_1, \dots, w'_m)$, where $w_i, w'_j \in \mathbb{R}^d$ are word embeddings, and $n, m$ are sentence lengths. The SAMs are denoted $A = (A_{ii'}) \in \mathbb{R}^{n \times n}$ and $A' = (A'_{jj'}) \in \mathbb{R}^{m \times m}$, respectively, for $s$ and $s'$. The element $A_{ii'}$ is the attention weight from $w_i$ to $w_{i'}$, and the element $A'_{jj'}$ is the attention weight from $w'_j$ to $w'_{j'}$. Each row of a SAM is normalized as $\sum_{i'=1}^n A_{ii'} = \sum_{j'=1}^m A'_{jj'} = 1$. Two probability distributions on $s$ and $s'$ are specified by $(u_1, \dots, u_n)$ and $(u'_1, \dots, u'_m)$, where $u_i, u'_j$ are weights on $w_i, w'_j$, respectively. In this paper, unless otherwise stated, the weights on words in a sentence are equal, i.e., the uniform distribution specified as
$$u_i = 1/n, \quad u'_j = 1/m. \tag{1}$$

Illustrative example
Here we explain how our proposed method works for the paraphrase example of Section 1. The original WMD and its improvement with our proposal are computed for this example and shown in Table 1. We computed these values using the word embeddings taken from the 0th layer of BERT and the SAMs taken from the second head of the 7th layer of BERT. The similarity should be high for the case (1), i.e., sentence pair (a) vs. (b), while the similarity should be low for the case (2), i.e., sentence pair (a) vs. (c).
Looking at the values of the original WMD for the two cases, we find that they are about the same. Thus WMD failed to give a reasonable sentence similarity. This failure is explained in Fig. 1 by illustrating word embeddings for the case (2). For each word, WMD tries to find the closest word, matching obama to president, media to press, and so on; the word alignment is indicated as the short dotted arrows. WMD is computed by averaging the length of these short arrows, indicating a high sentence similarity contrary to the fact that the similarity is low for the case (2). Note that WMD is actually computed with real-valued transport weights $P_{ij}$ in Section 4; nevertheless, in this example, OT gave $P_{ij} = 0$ for word pairs without arrows. This matching of WMD is not appropriate given the word order and sentence structure.

Our proposed method correctly matches obama to press and media to president. This word alignment is shown as the long arrows with solid lines in Fig. 1. $\mathrm{WMD}_\lambda$ in Table 1 is computed by averaging the length of these arrows; i.e., the same formula as WMD but using the OT that takes account of the sentence structure. We find that $\mathrm{WMD}_\lambda$ increases, correctly indicating a low sentence similarity for the case (2).

Fig. 2 explains how our proposed method takes account of the sentence structure. Given a word alignment between two sentences, we can think of a matching of elements between the two SAMs by applying the word alignment to both the rows and the columns. For example, the word alignment of WMD in Fig. 1 induces the matching from (obama, speaks) to (president, greets), (media, speaks) to (press, greets), and so on. SMD minimizes the average difference of SAM values between the matched elements. This recovers the correct word matching with respect to the word order in this example. WSMD minimizes the weighted sum of the objectives for WMD and SMD; we used the mixing ratio $\lambda = 0.5$ here. As shown in Table 1, WSMD correctly indicates that the sentences in case (2) are less similar than those in case (1).

Related Work
In this paper, we focus on sentence similarity measures based on OT of word embeddings. WMD (Kusner et al., 2015) uses the uniform distribution for the word weight and the $L_2$ distance of word embeddings for the transportation cost. By modifying the word weight and the transportation cost, several WMD-like methods are obtained for computing sentence similarity based on the Wasserstein distance between sets of word embeddings. For example, Word Rotator's Distance (WRD) (Yokoi et al., 2020) uses the vector norm of word embedding for the weight and the cosine similarity for the cost.
There are many attempts to improve WMD by incorporating word order and sentence structure information into the weight, cost, and penalty terms as explained below, but none of them considers the optimal transport of sentence structures nor utilizes SAMs from BERT.
In Order-Preserving Wasserstein Distance (OPWD) (Su and Hua, 2019), the temporal difference between $w_i$ and $w'_j$ is measured by the distance between their relative temporal positions $(i/n - j/m)^2$, which is then incorporated in the transportation cost and an additional regularization term. Ordered WMD (OWMD) (Liu et al., 2018) computes OPWD between the normalized representations of reordered sentences via hierarchical semantic trees. WMDo (Chow et al., 2019) identifies consecutive words common to two sentences as a chunk, and introduces a penalty term according to the number of chunks; this adds a notion of fluency in machine translation. In Syntax-aware WMD (SynWMD) (Wei et al., 2022), the word weight is computed from the word co-occurrence extracted from the syntactic parse tree of sentences, and a word embedding incorporates those from its subtree of the parse tree. The weight and the cost in SynWMD are called Syntax-aware Word Flow (SWF) and Syntax-aware Word Distance (SWD), respectively. In MoverScore (Zhao et al., 2019), the inverse document frequency (idf) is used for the word weight, and the word embeddings from BERT are used for computing the cost, expecting that BERT embeddings encode information from the whole sentence. BERTscore (Zhang et al., 2020) also uses the idf and BERT embeddings, but it employs a greedy matching instead of OT for word alignment.

Optimal Transport of Words
We review the computation of WMD.

Wasserstein distance
Sentences $s$ and $s'$ cannot be naively compared because they are generally different in length, and the corresponding words are unknown. We then consider the word transport from $s$ to $s'$ to obtain the word alignment. Interpret the sentence $s$ as the amount of mass $u_i$ placed at position $w_i$. First, we denote by $P_{ij} \in [0, 1]$ the amount of mass transported from position $w_i$ to position $w'_j$, and consider the transport matrix $P = (P_{ij}) \in \mathbb{R}^{n \times m}_{\ge 0}$ with nonnegative elements. Next, we define the distance function $c(w_i, w'_j) \in \mathbb{R}_{\ge 0}$ as the cost of transporting a unit mass from $w_i$ to $w'_j$, and specify the distance matrix $C = (C_{ij}) \in \mathbb{R}^{n \times m}_{\ge 0}$ with $C_{ij} = c(w_i, w'_j)$. Given $P$ and $C$, the transport from $w_i$ to $w'_j$ costs $C_{ij} P_{ij}$, and the total cost of transport is
$$\sum_{i=1}^n \sum_{j=1}^m C_{ij} P_{ij}. \tag{2}$$
Finding the optimal transport matrix $\hat{P} = (\hat{P}_{ij})$ that minimizes the total cost, we compute the Wasserstein distance between $s$ and $s'$ as the minimum value of the total cost $\sum_{i=1}^n \sum_{j=1}^m C_{ij} \hat{P}_{ij}$.
In the original WMD, the weight on words is the uniform distribution (1), and the distance function is the Euclidean distance between word embeddings $w_i$ and $w'_j$. The distance matrix is defined as
$$C_{ij} = \| w_i - w'_j \|_2. \tag{3}$$
Therefore, WMD between $s$ and $s'$ is
$$\mathrm{WMD}(s, s') = \min_{P \in \Pi(u, u')} \sum_{i=1}^n \sum_{j=1}^m C_{ij} P_{ij}. \tag{4}$$
Here, $\Pi(u, u')$ is the set of all possible values of the transport matrix $P$:
$$\Pi(u, u') = \Big\{ P \in \mathbb{R}^{n \times m}_{\ge 0} : \sum_{j=1}^m P_{ij} = u_i, \; \sum_{i=1}^n P_{ij} = u'_j \Big\},$$
and $u, u'$ are omitted on the left side of (4).
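For concreteness, Eqs. (1)-(4) can be computed with the POT library that is also used in our experiments; the following is a minimal sketch with uniform weights, not the exact experimental code.

```python
# A minimal sketch of WMD (Eqs. (1)-(4)) using POT.
# X, Y: arrays of shape (n, d) and (m, d) holding word embeddings.
import numpy as np
import ot  # Python Optimal Transport

def wmd(X, Y):
    n, m = len(X), len(Y)
    u = np.full(n, 1.0 / n)                # uniform weights, Eq. (1)
    v = np.full(m, 1.0 / m)
    C = ot.dist(X, Y, metric="euclidean")  # C_ij = ||w_i - w'_j||_2, Eq. (3)
    P_hat = ot.emd(u, v, C)                # optimal plan over Π(u, u')
    return float(np.sum(C * P_hat))        # minimum of Eq. (2), i.e. Eq. (4)
```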
WMD treats a sentence as a set of words

WMD cannot account for the word order of the sentences because it treats a sentence as a set of word embeddings to find the optimal transport distance. Therefore, it is hard to distinguish sentence pairs with significant word overlaps in the case of static word embeddings. For some models such as BERT and ELMo, the problem remains even for dynamic word embeddings that depend on the context, because the similarity is still high between word embeddings of the same word (Ethayarajh, 2019). In fact, WMD does not distinguish case (1) and case (2) of Table 1 in Section 2 using the 0th layer of BERT as a static word embedding. Moreover, the difference is slight even if the 12th layer of BERT is used as a dynamic word embedding: WMD = 9.17 for case (1) and WMD = 9.26 for case (2). Note, however, that for other models such as GPT-2, the similarity between word embeddings of the same word behaves as if they were two random words (Ethayarajh, 2019).

Optimal Transport of Sentence Structures
In this section, we propose to apply the GW distance to the SAM of BERT to measure the optimal transport distance between sentence structures. Later in Section 6, we will attempt to improve WMD using this method.

Gromov-Wasserstein distance
As seen in Section 2, the transport of sentence structure from $s$ to $s'$, i.e., the transport from $A$ to $A'$, is induced by the transport of words. Because it is induced in this way, the transport between SAMs satisfies constraints consistent with the word transport, which allows it to be combined with WMD later.
For the $i$-th word of $s$, $p_i = (P_{i1}, \dots, P_{im})$ denotes the array of transport amounts $P_{ij}$ to the $j$-th word of $s'$. Then, as illustrated in Fig. 3, the transport from $A_{ii'}$ to $A'$ is defined by the outer product of $p_i$ and $p_{i'}$; the transport amount from position $A_{ii'}$ to position $A'_{jj'}$ is $P_{ij} P_{i'j'}$. This defines transport between SAMs in a manner consistent with the word transport: the mass at $A_{ii'}$ is $\sum_{j,j'=1}^m P_{ij} P_{i'j'} = u_i u_{i'}$, and the mass at $A'_{jj'}$ is $\sum_{i,i'=1}^n P_{ij} P_{i'j'} = u'_j u'_{j'}$. In this paper, we specify the cost of transporting a unit mass as $|A_{ii'} - A'_{jj'}|^2$, and the total cost of transport from $A$ to $A'$ is
$$\sum_{i,i'=1}^{n} \sum_{j,j'=1}^{m} |A_{ii'} - A'_{jj'}|^2 P_{ij} P_{i'j'}. \tag{5}$$
Finding the optimal transport matrix $\hat{P}$ that minimizes the total cost, we compute the Gromov-Wasserstein (GW) distance (Memoli, 2011) between $A$ and $A'$ as the minimum value of (5).

Sentence structure mover's distance
Applying the GW distance to SAMs, we propose the sentence Structure Mover's Distance (SMD), which is the optimal transport distance considering the dependency of words in a sentence.
Note that SMD is not a metric in a strict sense. The general definition of the GW distance considers the cost $|A_{ii'} - A'_{jj'}|^p$ with $A$ assumed to be a symmetric distance matrix; since a SAM is not symmetric, SMD is not a GW distance in a strict sense. The reason we consider the case $p = 2$ is faster computation: for general $p \ge 1$, the computational complexity of the GW distance is $O(n^2 m^2)$, while for $p = 2$ there is a known algorithm (Peyré et al., 2016) that improves it to $O(n^2 m + n m^2)$.
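Under the assumption that the GW solver accepts the asymmetric SAMs (recent versions of POT detect asymmetry automatically, in line with the caveat above), SMD can be sketched as follows.

```python
# A minimal sketch of SMD: the GW distance of Eq. (5) between two SAMs.
import numpy as np
import ot

def smd(A1, A2):
    n, m = A1.shape[0], A2.shape[0]
    u = np.full(n, 1.0 / n)  # uniform weights, Eq. (1)
    v = np.full(m, 1.0 / m)
    # 'square_loss' is the p = 2 case, for which POT implements the
    # O(n^2 m + n m^2) factorization of Peyré et al. (2016).
    return ot.gromov.gromov_wasserstein2(A1, A2, u, v, loss_fun="square_loss")
```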

Optimal Transport of Words and Sentence Structures
Here, we propose an optimal transport distance that simultaneously considers the word embeddings and sentence structure.

Word and sentence structure mover's distance
As seen in Section 4.3, WMD computes the Wasserstein distance using the Euclidean distance of word embeddings, but it cannot handle sentences whose meanings differ depending on word order. On the other hand, as seen in Section 5.2, SMD computes the GW distance using the SAM of BERT, which encodes the sentence structure, but it cannot handle individual word information like word embeddings. Therefore, by combining WMD-like methods (WMD and its variants obtained by modifying $C$, $u$ and $u'$) with SMD, we propose an optimal transport distance, Word and sentence Structure Mover's Distance (WSMD), that utilizes word features and considers word dependency within a sentence. By specifying the mixing ratio parameter $\lambda \in [0, 1]$, we obtain
$$\mathrm{WSMD}(s, s') = \min_{P \in \Pi(u, u')} \Big[ (1 - \lambda) \sum_{i=1}^n \sum_{j=1}^m C_{ij} P_{ij} + \lambda k \sum_{i,i'=1}^n \sum_{j,j'=1}^m |A_{ii'} - A'_{jj'}|^2 P_{ij} P_{i'j'} \Big], \tag{7}$$
where $k = C_M / A_{\mathrm{MSE}}$ is computed from
$$C_M = \sum_{i=1}^n \sum_{j=1}^m C_{ij} u_i u'_j, \qquad A_{\mathrm{MSE}} = \sum_{i,i'=1}^n \sum_{j,j'=1}^m |A_{ii'} - A'_{jj'}|^2 u_i u_{i'} u'_j u'_{j'}.$$
By noting $\sum_{i'=1}^n \sum_{j'=1}^m P_{i'j'} = 1$, WSMD = WMD for $\lambda = 0$. For $\lambda = 1$, WSMD = $k$SMD. For an intermediate value $\lambda \in (0, 1)$, WSMD considers both WMD and SMD. We normalized $\lambda$ by introducing the factor $k$ in (7). $C_M$ and $A_{\mathrm{MSE}}$ are interpreted as (2) and (5), respectively, by specifying $P_{ij} = u_i u'_j$ with the uniform weight (1).
The optimal transport distance that simultaneously considers the Wasserstein and GW distances, as in WSMD, is known as the Fused Gromov-Wasserstein (FGW) distance (Vayer et al., 2019, 2020). However, like SMD, WSMD uses an asymmetric SAM, so WSMD is not a metric in a strict sense.
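A minimal sketch of WSMD with POT's FGW solver is given below, under two assumptions: POT's FGW objective weights the linear term by $(1-\alpha)$ and the quadratic term by $\alpha$, so setting $\alpha = \lambda$ and rescaling both SAMs by $\sqrt{k}$ (which multiplies the squared structure loss by $k$) reproduces Eq. (7); and the uniform weights (1) are used.

```python
# A minimal sketch of WSMD (Eq. (7)) via fused Gromov-Wasserstein.
# C: (n, m) word cost matrix; A1, A2: SAMs; lam: mixing ratio λ.
import numpy as np
import ot

def wsmd(C, A1, A2, lam=0.5):
    n, m = A1.shape[0], A2.shape[0]
    u = np.full(n, 1.0 / n)
    v = np.full(m, 1.0 / m)
    # Normalizing factor k = C_M / A_MSE, evaluated at P_ij = u_i u'_j.
    C_M = float(np.sum(C * np.outer(u, v)))
    diff2 = (A1[:, :, None, None] - A2[None, None, :, :]) ** 2
    A_MSE = float(np.mean(diff2))  # uniform weights turn the sum into a mean
    k = C_M / A_MSE
    # Rescaling the SAMs by sqrt(k) multiplies the squared loss by k.
    return ot.gromov.fused_gromov_wasserstein2(
        C, np.sqrt(k) * A1, np.sqrt(k) * A2, u, v,
        loss_fun="square_loss", alpha=lam)
```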

Decomposition of WSMD into WMD and SMD components
Let $\hat{P}$ be the optimal transport matrix, i.e., the $P$ that attains the minimum of (7). Substitute this $\hat{P}$ into (2) and (5), and denote the resulting values as $\mathrm{WMD}_\lambda$ and $\mathrm{SMD}_\lambda$, respectively. Then we can write
$$\mathrm{WSMD} = (1 - \lambda)\, \mathrm{WMD}_\lambda + \lambda k\, \mathrm{SMD}_\lambda.$$
In the case (2) of Table 1, we computed WSMD = 8.03, $k$ = 688, $\lambda$ = 0.5. This is decomposed into the components $\mathrm{WMD}_\lambda$ = 13.30 and $k\mathrm{SMD}_\lambda$ = 2.76. On the other hand, we can also compute WMD = 8.03 and $k$SMD = 2.76 in the same setting. You see that WSMD is not a simple interpolation of WMD and $k$SMD. This is because the optimal transport matrices in the computation of WMD, SMD, and WSMD are different, and in WSMD the optimization considers the two components at the same time.
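The decomposition can be checked numerically by extracting the optimal plan from the FGW solver and substituting it into (2) and (5); a sketch under the same assumptions as before:

```python
# A sketch of WSMD = (1 - λ) WMD_λ + λ k SMD_λ.
import numpy as np
import ot

def wsmd_components(C, A1, A2, u, v, k, lam=0.5):
    P = ot.gromov.fused_gromov_wasserstein(  # returns the plan itself
        C, np.sqrt(k) * A1, np.sqrt(k) * A2, u, v,
        loss_fun="square_loss", alpha=lam)
    wmd_lam = float(np.sum(C * P))           # Eq. (2) at the FGW optimum
    diff2 = (A1[:, :, None, None] - A2[None, None, :, :]) ** 2
    # Eq. (5) at the FGW optimum: Σ |A_ii' - A'_jj'|² P_ij P_i'j'.
    smd_lam = float(np.einsum("iIjJ,ij,IJ->", diff2, P, P))
    wsmd = (1 - lam) * wmd_lam + lam * k * smd_lam
    return wmd_lam, k * smd_lam, wsmd
```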

Experiments
We compare the performance of our proposed method and existing baseline methods on the task of measuring semantic textual similarity (STS) between two sentences.

PAWS dataset
Paraphrase Adversaries from Word Scrambling (PAWS) (Zhang et al., 2019) has a binary label for a sentence pair indicating paraphrase (i.e., the sentence pair has the same meaning) or non-paraphrase. The binary classification of these labels can be considered an STS task. If the optimal transport distance for paraphrase pairs is smaller than that for non-paraphrase pairs, we consider that we have successfully measured the sentence similarity. AUC is used as the evaluation index for this binary classification.
There are two types of PAWS: PAWS$_{\mathrm{Wiki}}$ and PAWS$_{\mathrm{QQP}}$, constructed from sentences in Wikipedia and Quora, respectively. Table 2 shows the number of sentence pairs and the percentage of paraphrases used in our experiments. For PAWS$_{\mathrm{Wiki}}$, the first 1536 pairs were selected from the dev set. In PAWS$_{\mathrm{QQP}}$, since the test set is not provided, the first 1536 pairs were selected from the train set to make a new dev set, and the original dev set was considered as a test set.
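The evaluation can be sketched as follows: since smaller distances should indicate paraphrases, the negated distance serves as the classification score. The pair list and distance function below are placeholders.

```python
# A sketch of the PAWS evaluation: AUC of the negated OT distance.
from sklearn.metrics import roc_auc_score

def evaluate_auc(pairs, labels, distance_fn):
    # pairs: list of (sentence1, sentence2); labels: 1 = paraphrase.
    scores = [-distance_fn(s1, s2) for s1, s2 in pairs]
    return roc_auc_score(labels, scores)
```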

STS Benchmark dataset
STS-B dataset (Cer et al., 2017) has a human-annotated gold score for a sentence pair, which is the average value of the similarity of the sentence pair evaluated by multiple annotators on a six-point scale. Table 2 shows the number of sentence pairs. Spearman's rank correlation coefficient between the optimal transport distances and the gold scores is used as the evaluation index.
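A corresponding sketch of the STS-B evaluation; negating the distances makes higher scores mean higher similarity, and only flips the sign of the rank correlation.

```python
# A sketch of the STS-B evaluation with Spearman's rank correlation.
from scipy.stats import spearmanr

def evaluate_spearman(distances, gold_scores):
    rho, _ = spearmanr([-d for d in distances], gold_scores)
    return rho
```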

Word embeddings and SAM
We used the bert-base-uncased model of BERT from the huggingface transformers library (Wolf et al., 2020). We input the sentences into BERT and removed the stopwords, [CLS], and [SEP] from the output tokens. For one of the methods (SynWMD), however, we did not remove the stopwords, since removing them would have deteriorated the performance.
The word embeddings used in the experiments are the static embeddings taken from the 0th layer (BERT0) of BERT and the dynamic embeddings taken from the 12th layer (BERT12), the final layer of BERT. We performed whitening for all the word embeddings, because the representation of BERT is known to be anisotropic (Ethayarajh, 2019). Note that BERT0 gives only an approximation of static embeddings, because the segment embeddings and the position embeddings are added. BERT computes many SAMs internally. Since the bert-base-uncased model has a 12-layer, 12-head attention mechanism, there are 12 × 12 = 144 SAMs that can be extracted. The SAMs and other hyperparameters used were determined by tuning with the dev set.
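A minimal sketch of the extraction and whitening steps follows. Since the exact whitening variant is not pinned down above, PCA whitening is assumed here; in practice the whitening statistics would be computed over embeddings from the whole dataset, not a single sentence.

```python
# hidden_states[0] is the embedding-layer output (BERT0) and
# hidden_states[12] is the final layer (BERT12).
import numpy as np
import torch

def layer_embeddings(model, tokenizer, sentence, layer=0):
    inputs = tokenizer(sentence, return_tensors="pt")
    with torch.no_grad():
        out = model(**inputs, output_hidden_states=True)
    return out.hidden_states[layer][0].numpy()  # (seq_len, hidden_dim)

def whiten(X, eps=1e-8):
    # PCA whitening (an assumed variant): center, rotate onto the
    # covariance eigenvectors, and rescale to unit variance so the
    # embeddings become isotropic.
    Xc = X - X.mean(axis=0, keepdims=True)
    w, V = np.linalg.eigh(np.cov(Xc, rowvar=False))
    return (Xc @ V) / np.sqrt(w + eps)
```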

Baseline methods
The following simple baseline methods for representing a sentence were selected; the similarity for a sentence pair is computed by the cosine similarity of the two vectors. Bag-of-Words (BoW) represents a sentence as a high-dimensional vector whose elements are the occurrence frequencies of all vocabulary words in the sentence.
Sentence embedding (Sent. Emb.) is simply the average vector of the word embeddings in a sentence.
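Both baselines reduce to a cosine similarity between sentence vectors; a minimal sketch:

```python
# A minimal sketch of the BoW and Sent. Emb. baselines.
import numpy as np

def cosine(a, b):
    return float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b)))

def bow_similarity(tokens1, tokens2):
    vocab = {w: i for i, w in enumerate(sorted(set(tokens1) | set(tokens2)))}
    v1, v2 = np.zeros(len(vocab)), np.zeros(len(vocab))
    for w in tokens1:
        v1[vocab[w]] += 1
    for w in tokens2:
        v2[vocab[w]] += 1
    return cosine(v1, v2)

def sent_emb_similarity(X1, X2):
    # X1, X2: (n_words, d) arrays of word embeddings.
    return cosine(X1.mean(axis=0), X2.mean(axis=0))
```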

Benchmarking OT methods
The following OT methods computed from word embeddings were selected: WMD (Kusner et al., 2015) and WRD (Yokoi et al., 2020) with the uniform weight as well as the idf weight; OPWD (Su and Hua, 2019) with the cost defined by the $L_2$ distance and by the cosine similarity; WMDo (Chow et al., 2019); and SynWMD (Wei et al., 2022) with the cost defined by the cosine similarity and by SWD.
We used the software code from Python Optimal Transport (POT) (Flamary et al., 2021) for implementing simple WMD-like methods. We also used the OPWD part of the publicly available code for OWMD (Liu et al., 2018), but the other part was not found there.
There are other interesting OT methods, such as Structured Optimal Transport (Alvarez-Melis et al., 2018), WE_WPI (Echizen'ya et al., 2019), and Recursive Optimal Transport (Wang et al., 2020), but these methods were not included in the experiment because their implementation code could not be found.

Results
Table 3 shows the results for the embedding-independent methods, and Table 4 shows the results for the methods using word embeddings extracted from BERT0 and BERT12. Performances of the methods are measured by AUC on the PAWS dataset and by Spearman's rank correlation coefficient on the STS-B dataset. Some WMD-like methods (WMD, WRD, SynWMD, WMDo) are combined with SMD (indicated as WSMD) and compared with the original method to see if introducing the WSMD distance yields an improvement. For example, the WSMD obtained from the original WMD corresponds to WMD+WSMD+uniform in the table. WSMD is not attempted for OPWD, because it is not obtained from WMD by simply modifying the weight and cost.

PAWS dataset
The proposed method WSMD, which combines SMD with WMD-like methods, boosts the performance of existing methods; in the few exceptional cases, there was only a slight performance degradation. For example, the AUC score on PAWS$_{\mathrm{Wiki}}$ with BERT0 is 54.75 for WMD+uniform and 75.23 for WMD+WSMD+uniform, an improvement of 20.48 points. The performance improvement is especially large for the static embedding from BERT0. On the other hand, the performance improvement is smaller for the dynamic embedding from BERT12. This may be due to the fact that the sentence structure information is already included in BERT12 embeddings, so WMD-like methods already give high performance. For PAWS$_{\mathrm{Wiki}}$, the best performance was obtained when WSMD was applied to WMDo for both BERT0 and BERT12. For PAWS$_{\mathrm{QQP}}$, the best performance was obtained for BERT0 when WSMD was applied to WRD+idf, and for BERT12 when WSMD was applied to WMDo. It is interesting to note that SMD, which does not use individual word information, performs better than several existing methods using BERT0. This indicates that SMD performs to some extent on the basis of sentence structure alone, but that the combination of information from words and sentence structure yields better results.

STS-B dataset
For most of the WMD variants and all of the WRD variants in both BERT0 and BERT12, the proposed method WSMD improved performance. However, Sent. Emb. is a strong baseline: among the WMD-like methods without WSMD, only WRD+idf and SynWMD beat Sent. Emb. for BERT0, and only WRD+idf, SynWMD, and WMDo beat it for BERT12. On the other hand, WRD combined with SMD outperforms Sent. Emb. with a consistent score improvement. In addition, SynWMD did not show any score improvement even when combined with SMD; this may be because SynWMD already uses syntax information, so the SAM did not play a significant role.

Conclusion
Since WMD treats a sentence as a set of word embeddings and computes sentence similarity, it cannot take into account the word order in the sentence. Therefore, we focused on the fact that the SAM of the input sentence obtained from the pre-trained BERT represents the relationship between words in the sentence and has information on the sentence structure. We proposed an optimal transport distance, WSMD, that improves existing WMD-like methods by using the FGW distance, which measures simultaneously the difference between word embeddings and the difference between sentence structures. We conducted experiments on paraphrase identification on the PAWS dataset, which contains many overlapping words between two sentences, and confirmed that the proposed method improves the performance. We also observed that the proposed method can improve the performance of existing OT methods on STS-B. Future work includes faster SAM selection and the simultaneous use of multiple SAMs.
Jacob Devlin, Ming-Wei Chang, Kenton Lee, and Kristina Toutanova. 2019. BERT: Pre-training of deep bidirectional transformers for language understanding. In Proceedings of the 2019 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long and Short Papers), pages 4171-4186, Minneapolis, Minnesota. Association for Computational Linguistics.
Hiroshi Echizen'ya, Kenji Araki, and Eduard Hovy. 2019. Word embedding-based automatic MT evaluation metric using word position information. In Proceedings of the 2019 Conference of the North American Chapter of the Association for Computational Linguistics: Human Language Technologies, Volume 1 (Long and Short Papers), pages 1874-1883, Minneapolis, Minnesota. Association for Computational Linguistics.

Figure 1: An illustration of OT for word embeddings from sentence 1 to sentence 2. Words are aligned by word similarity in WMD; e.g., obama matches president. Words are aligned by sentence structure in SMD, or by word similarity and sentence structure simultaneously in WSMD; e.g., obama matches press. See Section 2.

Figure 2: An illustration of OT for SAMs from sentence 1 to sentence 2. To avoid crowding the diagram, arrows are shown for only two SAM elements. The darker the color, the higher the value of the SAM element (reflecting a real SAM used in Table 1). The differences of SAM values are small for the word alignment in SMD or WSMD; e.g., (obama, speaks) matches (press, greets). The differences are large in WMD; e.g., (obama, speaks) matches (president, greets). See Section 2.

Figure 3: Transportation of sentence structure. The array of transportation from the $i$-th word is $p_i$ and that from the $i'$-th word is $p_{i'}$. The amount of transportation from the $(i, i')$-th element of the SAM is defined by the outer product of $p_i$ and $p_{i'}$; the sentence structure transportation is induced by the word transportation.

Table 2: The number of sentence pairs used in our experiments. Train sets are not used. The percentage of paraphrases for PAWS is shown in parentheses.

Table 3: Experimental results for methods without word embeddings. Scores are AUC on PAWS and Spearman's rank correlation on STS-B (the higher, the better).

Table 4: Experimental results for methods with word embeddings. Scores are AUC on PAWS and Spearman's rank correlation on STS-B (the higher, the better). The distributed representations taken from the 0th layer and the 12th layer of BERT are used. WSMD indicates the score when each method is combined with SMD; the value in parentheses shows the difference from the score without SMD.
Kawin Ethayarajh. 2019. How contextual are contextualized word representations? Comparing the geometry of BERT, ELMo, and GPT-2 embeddings. In Proceedings of the 2019 Conference on Empirical Methods in Natural Language Processing and the 9th International Joint Conference on Natural Language Processing (EMNLP-IJCNLP), pages 55-65, Hong Kong, China. Association for Computational Linguistics.