Efficient Shapley Values Estimation by Amortization for Text Classification

Despite the popularity of Shapley Values in explaining neural text classification models, computing them is prohibitive for large pretrained models due to the large number of model evaluations required. In practice, Shapley Values are often estimated with a small number of stochastic model evaluations. However, we show that the estimated Shapley Values are sensitive to random seed choices – the top-ranked features often have little overlap across different seeds, especially on examples with longer input texts. This can only be mitigated by aggregating thousands of model evaluations, which, in turn, incurs substantial computational overhead. To achieve a better trade-off between stability and efficiency, we develop an amortized model that directly predicts each input feature's Shapley Value without additional model evaluations. It is trained on a set of examples whose Shapley Values are estimated from a large number of model evaluations to ensure stability. Experimental results on two text classification datasets demonstrate that our amortized model estimates Shapley Values accurately, with up to 60 times speedup compared to traditional methods. Further, our model does not suffer from stability issues, as inference is deterministic. We release our code at https://github.com/yangalan123/Amortized-Interpretability.


Introduction
Many powerful natural language processing (NLP) models used in commercial systems only allow users to access model outputs. When these systems are applied in high-stakes domains, such as healthcare, finance, and law, it is essential to interpret how these models come to their decisions. To this end, post-hoc black-box explanation methods have been proposed to identify the input features that are most critical to model predictions (Ribeiro et al., 2016; Lundberg and Lee, 2017). A famous class of post-hoc black-box local explanation methods takes advantage of Shapley Values (Shapley, 1953) to identify important input features, such as Shapley Value Sampling (SVS) (Strumbelj and Kononenko, 2010) and KernelSHAP (KS) (Lundberg and Lee, 2017). These methods typically start by sampling permutations of the input features ("perturbation samples") and aggregating model output changes over the perturbation samples. Then, they assign an explanation score to each input feature to indicate its contribution to the prediction.

[Footnote: Work done during full-time work at AWS AI.]

Figure 1: Heatmaps of explanation scores of an example from Yelp-Polarity based on two runs of KernelSHAP (KS) using different random seeds (Seed=1 and Seed=2). KS is run on a fine-tuned BERT model using 200 samples per instance (approx. 3.47s per instance on average using a single A100 GPU, more than 150 times slower than one forward inference of the BERT model). The darker a token is, the higher its explanation score. Clearly, the interpretation results differ significantly across seeds.
Despite the widespread usage of Shapley Values methods, we observe that when they are applied to text data, the estimated explanation score for each token varies significantly with the random seeds used for sampling. Figure 1 shows an example of interpreting a BERT-based sentiment classifier (Devlin et al., 2019) on the Yelp-Polarity dataset, a restaurant review dataset (Zhang et al., 2015): the estimated explanation scores vary significantly when using different random seeds. They become stable only when the number of perturbation samples increases to more than 2,000. As KS requires a model prediction for each perturbation sample, the inference cost can be substantial. For example, it takes about 183 seconds to interpret each instance in Yelp-Polarity using the KS Captum implementation (Kokhlikyan et al., 2020) on an A100 GPU. In addition, this issue becomes more severe as the input text gets longer, since more perturbation samples are needed for reliable estimation of Shapley Values. This sensitivity to the sampling process leads to unreliable interpretations of the model predictions and hinders developers from understanding model behavior.
To achieve a better trade-off between efficiency and stability, we propose a simple yet effective amortization method to estimate the explanation scores. Motivated by the observation that different instances might share a similar set of important words (e.g., in sentiment classification, emotional words are strong label indicators (Taboada et al., 2011)), an amortized model can leverage similar interpretation patterns across instances when predicting the explanation scores. Specifically, we amortize the cost of computing explanation scores by precomputing them on a set of training examples and training an amortized model to predict the explanation scores given the input. At inference time, our amortized model directly outputs explanation scores for new instances. Although we need to collect a training set for every model we wish to interpret, our experiments show that with as few as 5,000 training instances, the amortized model achieves high estimation accuracy. We show our proposed amortized model in Figure 2.
The experimental results demonstrate the efficiency and effectiveness of our approach. First, our model reduces the computation time from about 3.47s per instance to less than 50ms, which is 60 times faster than the baseline methods. Second, our model is robust to randomness in training (e.g., random initialization and the random seeds used for generating reference explanation scores in the training dataset), and produces stable estimations across different random seeds. Third, we show that the amortized model can be used together with SVS to perform local adaption, i.e., adapting to specific instances at inference time, further improving performance when more computation is available (Section 6.3). Finally, we evaluate our model from the functionality perspective (Doshi-Velez and Kim, 2017; Ye and Durrett, 2022) by examining the quality of the explanations in downstream tasks. We perform case studies on feature selection and domain calibration using the estimated explanation scores, and show that our method outperforms the computationally expensive KS method.

Related Work
Post-Hoc Local Explanation Methods. Post-hoc local explanations are proposed to understand the prediction process of neural models (Simonyan et al., 2014; Ribeiro et al., 2016; Lundberg and Lee, 2017; Shrikumar et al., 2017). They work by assigning an explanation score to each feature (e.g., a token) in an instance ("local") to indicate its contribution to the model prediction. In this paper, we focus on studying KernelSHAP (KS) (Lundberg and Lee, 2017), an additive feature attribution method that estimates the Shapley Value (Shapley, 1953) for each feature.
There are other interpretability methods in NLP. For example, gradient-based methods (Simonyan et al., 2014; Li et al., 2016) use the gradient w.r.t. each input dimension as a measure of its saliency. Reference-based methods (Shrikumar et al., 2017; Sundararajan et al., 2017) consider the model output difference between the original input and a reference input (e.g., zero embedding vectors).

Shapley Values Estimation
Shapley Values are a concept from game theory for attributing a total contribution to individual features. However, in practice, estimating Shapley Values is prohibitively expensive, especially when explaining predictions on long documents in NLP. KS serves as an efficient way to approximate Shapley Values. Previous work on estimating Shapley Values mainly focuses on accelerating the sampling process (Jethani et al., 2021; Parvez and Chang, 2021; Mitchell et al., 2022) or removing redundant features (Aas et al., 2021). In this work, we propose a new method to combat this challenge by training an amortized model.

Robustness of Local Explanation Methods
Despite their wide adoption, there has been a long-standing discussion about the actual quality of explanation methods. Recently, researchers have found that explanation methods can assign substantially different attributions to similar inputs (Alvarez-Melis and Jaakkola, 2018; Ghorbani et al., 2019; Kindermans et al., 2019; Yeh et al., 2019; Slack et al., 2021; Yin et al., 2022), i.e., they are not robust enough, which adds to concerns about how faithful these explanations are (Doshi-Velez and Kim, 2017; Adebayo et al., 2018; Jacovi and Goldberg, 2020). While previous work focuses on robustness against input perturbations, we demonstrate that merely changing the random seed can cause the estimated Shapley Values to be only weakly correlated with one another, unless a large number of perturbation samples is used (which incurs high computational cost).
Amortized Explanation Methods. Our method is similar to recent work on amortized explanation models, including CXPlain (Schwab and Karlen, 2019) and FastSHAP (Jethani et al., 2021), which also aim to improve the computational efficiency of explanation methods. The key differences are: 1) we do not make causal assumptions between input features and model outputs; and 2) we focus on the text domain, where each feature is a discrete token (typical optimization methods for continuous variables do not directly apply).

Background
In this section, we briefly review the basics of Shapley Values, focusing on their application to the text classification task.

Local explanation of black-box text classification models. In text classification tasks, inputs are usually sequences of discrete tokens X = [w_1, w_2, ..., w_L]. Here L is the length of X and may vary across examples; w_j is the j-th token of X. The classification model M_CLF takes the input X and predicts the label as ŷ = argmax_{y∈Y} M_CLF(X)[y]. Local explanation methods treat each data instance independently and compute an explanation score ϕ(j, y), representing the contribution of w_j to the label y. Usually, we care about the explanation scores when y = ŷ.

Shapley Values (SV) are concepts from game theory originally developed to assign credits in cooperative games (Shapley, 1953; Strumbelj and Kononenko, 2010; Lundberg and Lee, 2017). Let s ∈ {0,1}^L be a masking of the input and define X_s := {w_i : s_i = 1} as the perturbed input consisting of the unmasked tokens w_i (those whose mask value s_i is 1). In this paper, we follow the common practice (Ye et al., 2021; Ye and Durrett, 2022; Yin et al., 2022) of replacing masked tokens with [PAD] in the input before sending it to the classifier. Let |s| denote the number of non-zero entries in s. Shapley Values ϕ_SV(i, y) (Shapley, 1953) are computed by

    ϕ_SV(i, y) = Σ_{s : s_i = 0} [ |s|! (L − |s| − 1)! / L! ] ( M_CLF(X_{s ∪ {i}})[y] − M_CLF(X_s)[y] ),

where s ∪ {i} denotes the mask s with position i additionally unmasked. Intuitively, ϕ_SV(i, y) computes the marginal contribution of each token to the model prediction, averaged over all possible coalitions of the remaining tokens.
Computing SV is known to be NP-hard (Deng and Papadimitriou, 1994). In practice, we estimate Shapley Values approximately for efficiency. Shapley Value Sampling (SVS) (Castro et al., 2009; Strumbelj and Kononenko, 2010) is a widely-used Monte-Carlo estimator of SV:

    ϕ_SVS(i, y) = (1/m) Σ_{j=1}^{m} ( M_CLF(X_{[σ_j]_{i−1} ∪ {i}})[y] − M_CLF(X_{[σ_j]_{i−1}})[y] ).

Here σ_j ∈ Π(L) is a sampled ordering, [σ_j] is the unordered set of indices in σ_j, and [σ_j]_{i−1} represents the set of indices ranked lower than i in σ_j. m is the number of perturbation samples used for computing SVS.
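The SVS estimator above can be sketched in a few lines. This is a minimal illustration, not the paper's implementation: the `model` interface (a callable returning label probabilities) is our assumption, and real implementations such as Captum batch the perturbed inputs rather than calling the model one mask at a time.

```python
import random

def svs_estimate(model, tokens, label, m, pad="[PAD]", seed=0):
    """Monte-Carlo Shapley Value Sampling (Castro et al., 2009;
    Strumbelj and Kononenko, 2010). Masked tokens are replaced with
    `pad`, following the masking convention described in the paper."""
    rng = random.Random(seed)
    L = len(tokens)
    phi = [0.0] * L
    for _ in range(m):
        order = list(range(L))
        rng.shuffle(order)                 # sampled ordering sigma_j
        masked = [pad] * L                 # start from the fully masked input
        prev = model(masked)[label]
        for i in order:                    # reveal tokens one at a time
            masked[i] = tokens[i]
            cur = model(masked)[label]
            phi[i] += cur - prev           # marginal contribution of token i
            prev = cur
    return [p / m for p in phi]
```

For an additive toy classifier the estimator is exact for any m, which makes the sketch easy to sanity-check; for real classifiers the variance of the estimate is exactly the instability the next sections measure.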
KernelSHAP. Although SVS reduces the exponential time complexity to polynomial, it still requires sampling permutations, performing sequential updates along the sampled orderings, and computing the explanation scores token by token, which is an apparent efficiency bottleneck. Lundberg and Lee (2017) introduce a more efficient estimator, KernelSHAP (KS), which allows better parallelism and computes the explanation scores for all tokens at once using linear regression. This is achieved by showing that computing SV is equivalent to solving the following weighted least-squares problem:

    ϕ = argmin_{ϕ_0, ϕ} Σ_{k=1}^{m} π(s(k)) ( M_CLF(X_{s(k)})[y] − ϕ_0 − ⃗s(k)^T ϕ )²,

where ⃗s(k) is the binary vector corresponding to the mask s(k) sampled from the Shapley Kernel π(·), and m is again the number of perturbation samples. We will use "SVS-m" and "KS-m" in the rest of the paper to indicate the sample size for SVS and KS. In practice, the specific perturbation samples depend on the random seed of the sampler, and we will show that the explanation scores are highly sensitive to the random seed under a small sample size. Note that the larger the number of perturbation samples, the more model evaluations are required for a single instance, which can be computationally expensive for large Transformer models. Therefore, the main performance bottleneck is the number of model evaluations.
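A minimal sketch of KS as a weighted regression follows. Two simplifications are our assumptions, not the paper's or the SHAP library's exact procedure: coalition sizes are drawn uniformly and then reweighted by the Shapley kernel, and the intercept is left unconstrained instead of enforcing the efficiency constraint.

```python
import math
import random

def shapley_kernel_weight(L, k):
    # Shapley kernel pi(s) for a coalition of size k, 0 < k < L
    return (L - 1) / (math.comb(L, k) * k * (L - k))

def kernel_shap(model, tokens, label, m, pad="[PAD]", seed=0):
    """KernelSHAP sketch: weighted least squares on sampled masks,
    solved via normal equations with Gauss-Jordan elimination."""
    rng = random.Random(seed)
    L = len(tokens)
    rows, ys, ws = [], [], []
    for _ in range(m):
        k = rng.randint(1, L - 1)                   # coalition size
        idx = set(rng.sample(range(L), k))
        s = [1.0 if i in idx else 0.0 for i in range(L)]
        masked = [t if si else pad for t, si in zip(tokens, s)]
        rows.append([1.0] + s)                      # intercept + mask vector
        ys.append(model(masked)[label])
        ws.append(shapley_kernel_weight(L, k))
    d = L + 1
    # weighted normal equations A coef = b
    A = [[sum(w * r[i] * r[j] for w, r in zip(ws, rows)) for j in range(d)]
         for i in range(d)]
    b = [sum(w * r[i] * y for w, r, y in zip(ws, rows, ys)) for i in range(d)]
    for col in range(d):                            # Gauss-Jordan elimination
        piv = max(range(col, d), key=lambda r: abs(A[r][col]))
        A[col], A[piv] = A[piv], A[col]
        b[col], b[piv] = b[piv], b[col]
        for r in range(d):
            if r != col and A[col][col]:
                f = A[r][col] / A[col][col]
                A[r] = [a - f * c for a, c in zip(A[r], A[col])]
                b[r] -= f * b[col]
    coef = [b[i] / A[i][i] for i in range(d)]
    return coef[1:]                                 # drop the intercept
```

Note that each sampled mask costs one model evaluation, which is exactly the bottleneck discussed above: the regression itself is cheap, but the m forward passes are not.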

Stability of Local Explanation
One of the most common applications of SV is feature selection, which selects the most important features by following the order of the explanation scores. In practice, people commonly run KS with an affordable number of perturbation samples (the typical numbers used in the literature are around 25, 200, and 2,000). However, as we see in Figure 1, the ranking of the scores can be quite sensitive to random seeds when using stochastic estimation of SV. In this section, we investigate this stability issue. We demonstrate that stochastic approximation of SV is unstable in text classification tasks under common settings, especially with long texts. In particular, when ranking input tokens based on explanation scores, Spearman's correlation between rankings across different runs is low.

Measuring ranking stability. Given explanation scores produced with different random seeds by an SV estimator, we want to measure the difference between these scores. Specifically, we are interested in the difference in the rankings of the scores, as this is what we use for feature selection. To measure the ranking stability of multiple runs with different random seeds, we compute Spearman's correlation between every pair of runs and use the average Spearman's correlation as the measure of ranking stability. In addition, we follow Ghorbani et al. (2019) and report Top-K intersections between two rankings, since in many applications only the top features are of explanatory interest: we measure the size of the intersection of the Top-K features from two different runs.

Setup. We conduct our experiments on the validation sets of the Yelp-Polarity dataset (Zhang et al., 2015) and the MNLI dataset (Williams et al., 2018). Yelp-Polarity is a binary sentiment classification task and MNLI is a three-way textual entailment classification task. We conduct experiments on 500 random samples with balanced labels (we refer to these subsets as "Stability Evaluation Sets" subsequently).
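The two stability measures can be computed as follows. This is a small self-contained sketch; the function names and interfaces are ours, not from the paper's released code.

```python
from itertools import combinations

def _ranks(scores):
    """Average ranks (1-based), with ties sharing their mean rank."""
    order = sorted(range(len(scores)), key=scores.__getitem__)
    ranks = [0.0] * len(scores)
    i = 0
    while i < len(order):
        j = i
        while j + 1 < len(order) and scores[order[j + 1]] == scores[order[i]]:
            j += 1
        avg = (i + j) / 2 + 1
        for k in range(i, j + 1):
            ranks[order[k]] = avg
        i = j + 1
    return ranks

def spearman(a, b):
    """Spearman's rank correlation between two score lists."""
    ra, rb = _ranks(a), _ranks(b)
    n = len(a)
    ma, mb = sum(ra) / n, sum(rb) / n
    cov = sum((x - ma) * (y - mb) for x, y in zip(ra, rb))
    va = sum((x - ma) ** 2 for x in ra) ** 0.5
    vb = sum((y - mb) ** 2 for y in rb) ** 0.5
    return cov / (va * vb)

def topk_intersection(a, b, k):
    """Size of the intersection of the Top-K features of two runs."""
    top = lambda s: set(sorted(range(len(s)), key=s.__getitem__,
                               reverse=True)[:k])
    return len(top(a) & top(b))

def avg_pairwise_spearman(runs):
    """Average Spearman's correlation over all pairs of runs."""
    pairs = list(combinations(runs, 2))
    return sum(spearman(a, b) for a, b in pairs) / len(pairs)
```

In practice one would feed `runs` with the per-token explanation scores produced by the same estimator under different seeds, exactly the quantity reported in Tables 1 and 2.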
Results are averaged over 5 different random seeds. (In total, the experiments in this section take more than 2,000 hours on a single A100 GPU.) We use the publicly available fine-tuned BERT-base-uncased checkpoints (Morris et al., 2020) as the target models to interpret and use the Captum implementation (Kokhlikyan et al., 2020) to compute the explanation scores for both KS and SVS. For each explanation method, we test with the recommended numbers of perturbation samples used to compute the explanation scores for every instance. For Top-K intersections, we report results with K = 5 and K = 10.
Trade-off between stability and computation cost. The ranking stability results are listed in Table 1 and Table 2 for the Yelp-Polarity and MNLI datasets. We observe that with 25 to 200 perturbation samples, the stability of the explanation scores is low (Spearman's correlation is only 0.16). Sampling more perturbed inputs makes the scores more stable. However, the computational cost explodes at the same time, going from one second to two minutes per instance. To reduce the sensitivity to an acceptable level (i.e., making the Spearman's correlation between two different runs above 0.40, which indicates moderate correlation (Akoglu, 2018)), we usually need thousands of model evaluations and roughly 33.40 seconds per instance.

Low MSE does not imply stability. Mean Squared Error (MSE) is commonly used to evaluate the distance between two lists of explanation scores. In Table 1, we observe that MSE only weakly correlates with ranking stability (e.g., R = −0.41, p < 0.05 for Yelp-Polarity). Even when the difference in MSE between settings is as low as 0.01, the correlation between the rankings produced by the explanations can still be low. Therefore, from a user's perspective, low MSE does not mean the explanations are reliable, as they can still suggest distinct rankings.
[Footnote: For SVS, the recommended number of perturbation samples in Captum is 25. For KS, to the best of our knowledge, the typical numbers of perturbation samples used in previous work are 25, 200, and 2,000. We also include KS-8000 to see how stable KS can be given a much longer running time.]

Longer input suffers more from instability. We also plot Spearman's correlation decomposed by input length in Figure 3. Here, we observe that the ranking stability already degrades significantly at an input length of 20 tokens, and the general trend is that the longer the input, the worse the ranking stability. The same trend holds across datasets. As many NLP tasks involve sentences longer than 20 tokens (e.g., SST-2 (Socher et al., 2013), MNLI (Williams et al., 2018)), obtaining stable explanations to analyze NLP models can be quite challenging.

Discussion: why is Shapley Values estimation unstable in the text domain? One of the most prominent characteristics of the text domain is that individual tokens/n-grams can have a large impact on the label. Thus, they all need to be included in the perturbation samples for an accurate estimate. When the input length grows, the number of n-grams grows quickly. As shown in Section 3, the probability of a given n-gram being sampled drops drastically, since every n-gram is sampled with equal probability. Therefore, the observed model output will have large variance, as certain n-grams may not get sampled. A concurrent work (Kwon and Zou, 2022) presents a related theoretical analysis of why uniform sampling in SV computation can lead to suboptimal attribution.

Amortized Inference for Shapley Values
Motivated by the above observation, we propose to train an amortized model to predict the explanation scores given an input without any model evaluation on perturbation samples. The inference cost is thus amortized by training on a set of pre-computed reliable explanation scores. We build an amortized explanation model for text classification in two stages. In the first stage, we construct a training set for the amortized model. We compute reliable explanation scores as the reference scores for training using the existing SV estimator. As shown in Section 4, SVS-25 is the most stable SV estimator and we use it to obtain reference scores. In the second stage, we train a BERT-based amortized model that takes the text as input and outputs the explanation scores using MSE loss.
Specifically, given input tokens X, we use a pretrained language model M_LM to encode the tokens into d-dimensional embeddings ⃗e = M_LM(X) = [⃗e_1, ..., ⃗e_L] ∈ R^{L×d}. Then, we use a linear layer to transform each ⃗e_i into the predicted explanation score ϕ_AM(i, ŷ) = W⃗e_i + b. To train the model, we use an MSE loss to fit ϕ_AM(i, ŷ) to the precomputed reference scores ϕ(i, ŷ) over the training set X_Train. This is an amortized model in the sense that, unlike SVS and KS, there is no per-example sampling or querying of the target model. When a new sample comes in, the amortized model makes a single forward pass over the input tokens to predict their explanation scores.
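The training stage can be sketched as below. To keep the sketch self-contained we use a learnable per-token embedding table as a stand-in for the BERT encoder M_LM; the dimensions, learning rate, and plain SGD loop are illustrative assumptions and not the paper's actual training setup.

```python
import random

def train_amortized(train_set, dim=8, epochs=500, lr=0.05, seed=0):
    """Toy amortized explainer: per-token embeddings (standing in for
    M_LM) plus a linear head W e_i + b, fit with MSE against the
    precomputed reference scores (SVS-25 in the paper).
    train_set: list of (tokens, reference_scores) pairs."""
    rng = random.Random(seed)
    vocab = {t for tokens, _ in train_set for t in tokens}
    emb = {t: [rng.gauss(0, 0.1) for _ in range(dim)] for t in vocab}
    W = [rng.gauss(0, 0.1) for _ in range(dim)]
    b = 0.0
    for _ in range(epochs):
        for tokens, ref in train_set:
            for t, target in zip(tokens, ref):
                e = emb[t]
                pred = sum(w * x for w, x in zip(W, e)) + b
                g = 2 * (pred - target)     # gradient of MSE w.r.t. pred
                for i in range(dim):        # SGD on head and embedding
                    W[i], e[i] = W[i] - lr * g * e[i], e[i] - lr * g * W[i]
                b -= lr * g

    def predict(tokens):
        # one deterministic forward pass: no sampling, no target-model calls
        return [sum(w * x for w, x in zip(W, emb.get(t, [0.0] * dim))) + b
                for t in tokens]
    return predict
```

The key property the sketch preserves is the one the paper relies on: once trained, producing explanation scores for a new instance is a single deterministic forward pass.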

Algorithm 1 Local Adaption
Require: m: the desired number of local adaption perturbation samples; M_AM: the trained amortized explanation model; X: the target data instance of length L; ŷ: the predicted label; M_CLF: the target model
1: ϕ ← M_AM(X)
2: for j = 1 to m do
3:     sample an ordering σ_j from the permutations Π(L)
4:     update ϕ with the marginal contributions of each token along σ_j, as in SVS
5: end for
6: return ϕ

Better Fit via Local Adaption
By amortization, our model can learn to capture the shared feature attribution patterns across data to achieve a good efficiency-stability trade-off. We further show that the explanations generated by our amortized model can be used to initialize the explanation scores of SVS. This way, the evaluation of SVS can be significantly sped up compared with using random initialization. Conversely, applying SVS on top of the amortized model improves the latter's performance, as some important tokens might not be captured by the amortized model but can be identified by SVS through additional sampling (e.g., low-frequency tokens). The detailed algorithm is shown in Algorithm 1. Note that we can recover the original SVS computation (Strumbelj and Kononenko, 2010) by replacing ϕ ← M_AM(X) with ϕ ← 0. M_AM is the amortized model trained with MSE as explained earlier.
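Algorithm 1 can be sketched as follows. Since the excerpt does not spell out how the SVS updates are averaged with the amortized initialization, we make the ASSUMPTION that the amortized scores count as `n0` virtual perturbation samples that the m real SVS samples then refine; this is our reading, not necessarily the paper's exact update rule.

```python
import random

def local_adaption(model, tokens, label, phi_init, m, n0=5,
                   pad="[PAD]", seed=0):
    """Local adaption sketch: SVS warm-started from the amortized
    model's scores phi_init. ASSUMPTION: phi_init is treated as n0
    virtual samples in the running average. Setting phi_init to zeros
    and n0 to 0 recovers plain SVS."""
    rng = random.Random(seed)
    L = len(tokens)
    total = [n0 * p for p in phi_init]     # virtual-sample running sums
    for _ in range(m):
        order = list(range(L))
        rng.shuffle(order)                 # sampled ordering sigma_j
        masked = [pad] * L
        prev = model(masked)[label]
        for i in order:                    # standard SVS marginal updates
            masked[i] = tokens[i]
            cur = model(masked)[label]
            total[i] += cur - prev
            prev = cur
    return [t / (n0 + m) for t in total]
```

With a good amortized initialization, a small m already yields accurate scores, which is the sample-efficiency gain reported for Adapt-X versus SVS-X in Section 6.3.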

Experiments
In this section, we present experiments to demonstrate the properties of the proposed approach in terms of accuracy against reference scores (Section 6.1) and sensitivity to training-time randomness (Section 6.2). We also show that we achieve a better fit via a local adaption method that combines our approach with SVS (Section 6.3). Then, we evaluate the quality of the explanations generated by our amortized model on two downstream applications (Section 6.5). Setup. We conduct experiments on the validation sets of the Yelp-Polarity and MNLI datasets. To generate reference explanation scores, we leverage the Thermostat dataset (Feldhus et al., 2021), which contains 9,815 pre-computed explanation scores of SVS-25 on MNLI. We also compute explanation scores of SVS-25 for 25,000 instances on Yelp-Polarity. We use BERT-base-uncased (Devlin et al., 2019) for M_LM. For dataset preprocessing and other experiment details, we refer readers to Appendix C.
To the best of our knowledge, FastSHAP (Jethani et al., 2021) is the work most relevant to ours that also takes an amortization approach to estimate SV, on tabular or image data. We adapt it to explain the text classifier and use it as a baseline to compare with our approach. We find it non-trivial to adapt FastSHAP to the text domain. As pre-trained language models occupy a large amount of GPU memory, we can only use a small batch size with limited perturbation samples (i.e., 32 perturbation samples per instance). This is equivalent to approximating KS-32, and the corresponding reference explanation scores computed by FastSHAP are unstable. More details can be found in Appendix A.

Shapley Values Approximation
To examine how well our model fits the precomputed SV (SVS-25), we compute both Spearman's correlation and MSE over the test set. As it is intractable to compute exact Shapley Values for ground truth, we use SVS-25 as a proxy. We also include results for different KS settings over the same test set; KS is also an approximation to permutation-based SV computation (Lundberg and Lee, 2017). Table 3 shows the correlation and MSE of the aforementioned methods against SVS-25. First, we find that despite its simplicity, the proposed amortized model achieves a high correlation with the reference scores. We also find that the amortized model achieves the best MSE among all approximation methods. Note that the two metrics, Spearman's correlation and MSE, do not convey the same information: MSE measures how well the reference explanation scores are fitted, while Spearman's correlation reflects how well the ranking information is learned. We advocate for reporting both metrics.

Cost of training the amortized models. To produce the training set, we need to pre-compute the explanation scores on a set of data. Although this is a one-time cost (for each model), one might wonder how time-consuming this step is, as we need to run the standard sample-based estimation. As the learning curve in Figure 4 shows, the model achieves good performance with about 25% (≈5,000 on Yelp-Polarity) of the instances. Additionally, in Section 6.4, we show this one-time training results in a model transferable to other domains, so we may not need to train a new amortized model for each new domain.

Sensitivity Analysis
Given a trained amortized model, there is no randomness when generating explanation scores. However, there is still some randomness in the training process, including the training data, the random initialization of the output layer, and randomness during updates such as dropout. Therefore, similar to Section 4, we study the sensitivity of the amortized model. Table 4 shows the results with different training data and random seeds. We observe that: 1) when using the same data (100%), random initialization does not affect the outputs of the amortized models, as the correlation between different runs is high (i.e., 0.77 on MNLI and 0.76 on Yelp-Polarity); 2) with more training samples, the model is more stable.

Table 4: Training time sensitivity study. To evaluate how much the amortized model is influenced by randomness during training, we sample training data 5 times with different random seeds and then compute the average Spearman's correlation among all pairs of runs. The standard deviation is less than 1e-2. Our amortized model is stable against training-time randomness with only 10% of the data.

Local Adaption
The experimental results for Local Adaption (Section 5.1) are shown in Table 5. We can see that: 1) with local adaption, we further improve the approximation results of our amortized model; 2) using our amortized model as initialization significantly improves the sample efficiency of SVS (compare the performance of SVS-X and Adapt-X). These findings hold across datasets.

Domain Transferability
To see how well our model performs on out-of-domain data, we train a classification model and its amortized explanation model on Yelp-Polarity and then explain its predictions on the SST-2 (Socher et al., 2013) validation set. Both tasks are two-way sentiment classification but have significant domain differences. Our amortized model achieves a Spearman's correlation of approximately 0.50 with the ground truth SV (SVS-25) while requiring only 0.017s per instance. In comparison, KS-100 achieves a lower Spearman's correlation of 0.46 with the ground truth and takes 1.6s per instance; KS-200 performs slightly better in Spearman's correlation but requires significantly more time. Thus, our amortized model is more than 90 times faster and more correlated with the ground truth Shapley Values. This shows that, once trained, our amortized model can provide efficient and stable estimations of SV even for out-of-domain data.
In practice, we do not recommend directly explaining model predictions on out-of-domain data without verification, because the results may be misaligned with user expectations for explanations, and the out-of-domain explanations may not be reliable (Hase et al., 2021; Denain and Steinhardt, 2022). More exploration in this direction is needed but is orthogonal to this work.

Evaluating the Quality of Explanation
Feature Selection. The first case study is feature selection, a straightforward application of local explanation scores. The goal is to find decision-critical features by gradually removing input features according to the rank given by the explanation method. Following previous work (Zaidan et al., 2007; Jain and Wallace, 2019; DeYoung et al., 2020), we measure faithfulness by the change in model output after masking tokens identified as important by the explanation method. The more faithful the explanation method is to the target model, the larger the performance drop incurred by masking important tokens.
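This masking-based faithfulness protocol can be sketched as below. The `model` and `scores_fn` interfaces are hypothetical stand-ins for the classifier and the explanation method; the paper's evaluation uses the fine-tuned BERT classifiers and the estimators discussed above.

```python
def faithfulness_curve(model, dataset, scores_fn,
                       alphas=(0.01, 0.05, 0.10, 0.20), pad="[PAD]"):
    """Mask the top-alpha fraction of tokens ranked by an explanation
    method and report accuracy on the corrupted inputs; a faithful
    method should cause a larger accuracy drop.
    model(tokens) -> {label: prob}; scores_fn(tokens, label) -> scores."""
    accs = []
    for alpha in alphas:
        correct = 0
        for tokens, gold in dataset:
            probs = model(tokens)
            pred = max(probs, key=probs.get)        # explain the predicted label
            scores = scores_fn(tokens, pred)
            k = max(1, round(alpha * len(tokens)))  # number of tokens to mask
            top = set(sorted(range(len(tokens)),
                             key=scores.__getitem__, reverse=True)[:k])
            corrupted = [pad if i in top else t for i, t in enumerate(tokens)]
            new_probs = model(corrupted)
            correct += max(new_probs, key=new_probs.get) == gold
        accs.append(correct / len(dataset))
    return accs
```

Lower accuracy on the corrupted inputs at a given masking budget indicates that the explanation method ranked genuinely decision-critical tokens at the top.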
We gradually mask Top-α tokens (α = 1%, 5%, 10%, 20%) and compute the accuracy over corrupted results using the stability evaluation sets for MNLI and Yelp-Polarity datasets as mentioned in Section 4. As the results show in Figure

Conclusion
In this paper, we empirically demonstrated that it is challenging to obtain stable explanation scores on long text inputs. Inspired by the fact that different instances can share similarly important features, we proposed to efficiently estimate the explanation scores through an amortized model trained to fit pre-computed reference explanation scores.
In the future, we plan to explore model architectures and training losses for developing effective amortized models. In particular, we may incorporate a sorting-based loss to learn the ranking order of features. Additionally, we could investigate the transferability of the amortized model across different domains, as well as explore other SHAP-based methods, instead of the time-consuming SVS-25, in the data collection process to improve efficiency further.

Limitations
In this paper, we mainly focus on developing an amortized model to efficiently achieve a reliable estimation of SV. Though not experimented with in the paper, our method can be widely applied to other black-box post-hoc explanation methods, including LIME (Ribeiro et al., 2016). Also, due to a limited budget, we only run experiments on BERT-based models. However, as we do not make any assumptions about the model, like other black-box explanation methods, our amortized model can be easily applied to other large language models. We only need to collect the model outputs, and our model can be trained offline with just thousands of examples, as shown in our method and experiments.

Comparison and Training with Exact Shapley Values

Computing exact SV is computationally prohibitive for large language models (LLMs) on lengthy text inputs, as it necessitates evaluating the LLM on a number of perturbation samples per instance that is exponential in the sequence length. As a result, we resort to SVS-25, which serves as a reliable approximation, for training our amortized models.

A Adaption for FastSHAP Baseline
As mentioned in Section 6, we build our amortized models upon a pre-trained BERT encoder (Devlin et al., 2019). However, using the pre-trained encoder significantly increases the memory footprint when running FastSHAP. In particular, we have to host two language models on the GPU, one for the amortized model and the other for the target model. Therefore, we can only use a batch size of 1 with 32 perturbation samples per instance. Following the proof in FastSHAP, this is equivalent to teaching the amortized model to approximate KS-32, which is an unreliable interpretation method (see Section 6.2).
In our experiments, we find that the optimization of FastSHAP is unstable. After an extensive hyperparameter search, we set the learning rate to 1e-6 and increase the number of epochs to 30. Even so, FastSHAP requires 3 days of training on a single A100 GPU to converge.

B Scientific Artifacts License
For the datasets used in this paper, MNLI (Williams et al., 2018) is released under the OANC license. The Yelp-Polarity (Zhang et al., 2015) and SST-2 (Socher et al., 2013) datasets do not provide detailed licenses.
The model checkpoints used in this paper all come from the TextAttack project (Morris et al., 2020) and are open-sourced under the MIT license.

C Training Details
In this section, we introduce our dataset preprocessing, hyperparameter settings and how we train the models.
For both MNLI and Yelp-Polarity datasets, we split them into 8:1:1 for training, validation, and test sets.
The hyperparameters of the amortized models are tuned on the validation set. We use the Adam optimizer (Kingma and Ba, 2015) with a learning rate of 5e-5, train for at most 10 epochs, and use early stopping to select the best model checkpoint.