Thermodynamics-inspired explanations of artificial intelligence
Interpretation unfaithfulness (\(\mathcal{U}\)) for surrogate model construction
Our starting point is a given dataset \(\mathcal{X}\) and the corresponding predictions g of a black-box model. For a particular element \(x\in\mathcal{X}\), we seek explanations that are as human-interpretable as possible while also being as faithful as possible to g in the vicinity of x. We address this problem by approximating g locally with a linear model, which is interpretable by construction. Specifically, we formulate F as a linear combination of an ordered set of representative features, s = {s1, s2, …, sn}. These features are typically domain-dependent, e.g., one-hot encoded superpixels for an image, keywords for text, and standardized values for tabular data. Equation (1) defines the linear approximation F, where f0 is a constant and fk comes from an ordered set of feature coefficients, f = {f1, f2, …, fn}.
$$F=f_0+\sum_{k=1}^{n}f_k s_k$$
(1)
Let’s consider a specific problem where x0 is a high-dimensional instance and g(x0) is the black-box prediction to be explained. We first generate a neighborhood {x1, x2, …, xN} of N samples by randomly perturbing the high-dimensional input space22. A detailed discussion of neighborhood generation is provided in “Methods.” Next, the black-box predictions g(x1), g(x2), …, g(xN) are obtained for every sample in the neighborhood. Finally, a local surrogate model is constructed by weighted linear regression with the loss function defined in Equation (2).
$$\mathcal{L}=\min_{f_k}\sum_{i=1}^{N}\Pi_i(\mathbf{x}_0,\mathbf{x}_i)\left[g(\mathbf{x}_i)-\left(\sum_{k=1}^{n}f_k s_{ik}\right)\right]^2$$
(2)
Here \(\Pi_i(\mathbf{x}_0,\mathbf{x}_i)=e^{-d(\mathbf{x}_0,\mathbf{x}_i)^2/\sigma^2}\) is a Gaussian similarity measure, where d is the distance between the explanation instance x0 and a neighborhood sample xi, and \(s_{ik}\) is the value of feature k for sample xi. In previous surrogate model construction approaches19, Euclidean distance in the continuous input feature space has been the typical choice for d. However, if the input space has several correlated or redundant features, a similarity measure based on Euclidean distance can be misleading52,53. TERP addresses this problem by computing a one-dimensional (1-d) projection of the neighborhood using linear discriminant analysis54 (LDA), which removes redundancy and yields a more accurate similarity measure. By minimizing within-class variance and maximizing between-class distance, this projection encourages the formation of two clusters in the 1-d space, corresponding to in-class and not-in-class data points, respectively. Because the projected space is one-dimensional, the bandwidth σ in \(\Pi_i\) does not need to be tuned as it might in established methods, and we set σ = 1. We demonstrate the advantages of LDA-based similarity on practical problems in a subsequent subsection.
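The LDA-based similarity described above can be sketched in a few lines of NumPy. This is a minimal illustration under our own assumptions, not TERP's actual implementation: the hypothetical helper `lda_similarity` uses a two-class Fisher discriminant (in-class vs. not-in-class labels taken from the black-box predictions), adds a small ridge term so the within-class scatter stays invertible when features are redundant, standardizes the 1-d projection so that σ = 1 is a sensible scale, and assumes x0 is row 0 of the neighborhood.

```python
import numpy as np

def lda_similarity(X, in_class, x0_index=0, sigma=1.0):
    """Gaussian similarity Pi_i = exp(-d^2 / sigma^2) on a 1-d Fisher LDA projection.

    X        : (N, m) neighborhood in input space; row x0_index is the instance x0.
    in_class : (N,) booleans, True if the black box assigns x_i to x0's class.
    """
    X = np.asarray(X, dtype=float)
    in_class = np.asarray(in_class, dtype=bool)
    X1, X0 = X[in_class], X[~in_class]
    m1, m0 = X1.mean(axis=0), X0.mean(axis=0)
    # Within-class scatter matrix; the ridge keeps it invertible for
    # redundant or perfectly correlated features (an assumption of this sketch).
    D1, D0 = X1 - m1, X0 - m0
    Sw = D1.T @ D1 + D0.T @ D0 + 1e-6 * np.eye(X.shape[1])
    w = np.linalg.solve(Sw, m1 - m0)       # Fisher discriminant direction
    z = X @ w                              # 1-d projection of every sample
    z = (z - z.mean()) / z.std()           # standardize so sigma = 1 is a natural scale
    d = np.abs(z - z[x0_index])            # distance to x0 along the projection
    return np.exp(-d**2 / sigma**2)
```

Because the distance is measured along a single discriminant direction, duplicated or noisy input columns mainly change the direction w rather than inflating d, which is the intuition behind using LDA here.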
Next, we introduce a measure of interpretation unfaithfulness (\(\mathcal{U}\)), computed from the correlation coefficient C between the linear surrogate model predictions (F) obtained using Equation (1) and the black-box predictions (g). For any interpretation, C(F, g) ∈ [−1, +1], and thus interpretation unfaithfulness is bounded, i.e., \(\mathcal{U}\in[0,1]\):
$$\mathcal{U}=1-|C(F,g)|$$
(3)
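A minimal NumPy sketch of Equations (1)–(3) follows. The helper names `fit_surrogate` and `unfaithfulness` are hypothetical, and two choices go beyond the text and are assumptions: the intercept f0 of Equation (1) is included in the fit, and C is taken to be the ordinary (unweighted) Pearson correlation. The weighted least-squares problem is solved by scaling each row by the square root of its similarity weight.

```python
import numpy as np

def fit_surrogate(S, g, pi):
    """Weighted least squares for F = f0 + sum_k f_k s_k.

    S : (N, n) representative features, g : (N,) black-box predictions,
    pi: (N,) similarity weights Pi_i. Returns (f0, f).
    """
    A = np.column_stack([np.ones(len(g)), S])   # first column carries the intercept f0
    sw = np.sqrt(pi)                            # sqrt-weights turn WLS into ordinary LSQ
    coef, *_ = np.linalg.lstsq(A * sw[:, None], g * sw, rcond=None)
    return coef[0], coef[1:]

def unfaithfulness(F, g):
    """U = 1 - |C(F, g)| with C the Pearson correlation coefficient."""
    C = np.corrcoef(F, g)[0, 1]
    return 1.0 - abs(C)
```

For example, if g is exactly linear in the features, the fitted F reproduces g and U is zero up to rounding.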
Using these definitions, we implement a forward feature selection scheme55,56 by first constructing n linear models, each with j = 1 non-zero coefficient. We use Equation (3) to identify the feature that yields the lowest \(\mathcal{U}^{j=1}\). Here, the superscript j = 1 indicates that \(\mathcal{U}\) was calculated for a model with one non-zero coefficient; we follow this notation for other relevant quantities throughout this manuscript.
Afterward, the selected feature is retained and combined with each remaining feature to identify the best set of two features, i.e., the set resulting in the lowest \(\mathcal{U}^{j=2}\), and the scheme is continued until \(\mathcal{U}^{j=n}\) is computed. Since a model with j + 1 non-zero coefficients, as defined in Equation (1), is at most as unfaithful as a model with j non-zero coefficients, \(\mathcal{U}\) decreases monotonically (more precisely, does not increase) with j. The overall scheme generates n distinct interpretations as j goes from 1 to n.
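The greedy forward selection above can be sketched as follows. This is a self-contained illustration with hypothetical names (`forward_selection`, `_unfaithfulness`), assuming NumPy, a weighted least-squares fit with an intercept, and the unweighted Pearson correlation for C; it is not the authors' code. At each step it adds the remaining feature that gives the lowest \(\mathcal{U}\) and records that value as \(\mathcal{U}^{j}\).

```python
import numpy as np

def _unfaithfulness(S_sub, g, pi):
    # Weighted least-squares fit with an intercept, then U = 1 - |corr(F, g)|.
    A = np.column_stack([np.ones(len(g)), S_sub])
    sw = np.sqrt(pi)
    coef, *_ = np.linalg.lstsq(A * sw[:, None], g * sw, rcond=None)
    F = A @ coef
    return 1.0 - abs(np.corrcoef(F, g)[0, 1])

def forward_selection(S, g, pi):
    """Return the greedy feature order and the list [U^1, ..., U^n]."""
    n = S.shape[1]
    selected, U = [], []
    for _ in range(n):
        remaining = [k for k in range(n) if k not in selected]
        # Try adding each remaining feature to the already selected set.
        scores = [_unfaithfulness(S[:, selected + [k]], g, pi) for k in remaining]
        best = int(np.argmin(scores))
        selected.append(remaining[best])
        U.append(scores[best])
    return selected, U
```

With uniform weights this is ordinary least squares, where |C(F, g)| equals the square root of R², so the recorded \(\mathcal{U}^{j}\) values cannot increase from one step to the next.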
Interpretation entropy (\(\mathcal{S}\)) for model selection
After identifying n interpretations, our goal is to determine the optimal interpretation from this family of models. To this end, we introduce interpretation entropy \(\mathcal{S}\) to quantify the degree of human interpretability of any linear model. Given a linear model with an ordered set of feature coefficients {f1, f2, …, fn}, among which j are non-zero, we define p1, p2, …, pn, where \(p_k:=\frac{|f_k|}{\sum_{k'=1}^{n}|f_{k'}|}\). Interpretation entropy is then defined as:
$$\mathcal{S}^{j}=-\sum_{k=1}^{n}p_k\log p_k\quad\text{with}\quad\log p_k=0\;\;\forall\,p_k=0$$
(4)
Here the superscript j indicates that \(\mathcal{S}\) is calculated for a model with j non-zero coefficients. It is easy to see that pk satisfies the properties of a probability distribution, namely pk ≥ 0 and \(\sum_{k=1}^{n}p_k=1\).
Analogous to self-information (surprisal) in information theory, the negative logarithm of pk from a fitted linear model can be defined as the self-interpretability penalty of that feature. Interpretation entropy is then the expectation value of the self-interpretability penalty over all features, as shown in Equation (5). Using Jensen’s inequality, it can be shown that \(\mathcal{S}\) has an upper limit of \(\log n\), so we normalize the definition by \(\log n\) to bound \(\mathcal{S}\) in [0, 1].
$$\mathcal{S}^{j}=-\frac{1}{\log n}\sum_{k=1}^{n}p_k\log p_k=\frac{1}{\log n}\mathbb{E}[-\log p]$$
(5)
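Equations (4) and (5) translate directly into code. The sketch below assumes NumPy, n ≥ 2 coefficients (so that \(\log n > 0\)), and the normalization \(p_k=|f_k|/\sum_{k'}|f_{k'}|\); `interpretation_entropy` is a hypothetical helper name.

```python
import numpy as np

def interpretation_entropy(f):
    """Normalized interpretation entropy S in [0, 1] for coefficients f (n >= 2)."""
    f = np.asarray(f, dtype=float)
    n = f.size
    p = np.abs(f) / np.abs(f).sum()     # p_k = |f_k| / sum_k' |f_k'|
    nz = p > 0                          # convention: log p_k = 0 when p_k = 0
    return float(-(p[nz] * np.log(p[nz])).sum() / np.log(n))
```

A single non-zero coefficient gives S = 0 (most interpretable), equal-magnitude coefficients give S = 1, and the sign of a coefficient does not affect S.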
This functional form of interpretation entropy (\(\mathcal{S}\)), i.e., the expected interpretability penalty, takes low values for a sharply peaked distribution of fitted weights, indicating high human interpretability, and high values for a flat distribution, indicating low human interpretability. Furthermore, if the features are independent, \(\mathcal{S}\) has two interesting properties expressed in the theorems below. The corresponding proofs are provided in Supplementary Notes 1 and 2 of the Supplementary Information (SI).
Theorem 1
\(\mathcal{S}^{j}\) is a monotonically increasing function of the number of features (j).
Theorem 2
\(\mathcal{S}\) monotonically increases as \(\mathcal{U}\) decreases (Supplementary Fig. S1).
Free energy (ζ) for optimal explanation
For an interpretation with j non-zero coefficients, we now define the free energy \(\zeta^{j}\) as a trade-off between \(\mathcal{U}^{j}\) and \(\mathcal{S}^{j}\), tunable by a parameter θ ≥ 0, as shown in Fig. 2 and Equation (6).
$$\zeta^{j}(f,\theta)=\mathcal{U}^{j}+\theta\,\mathcal{S}^{j}$$
(6)
By imposing the stationarity condition \(\Delta\zeta^{j}=\zeta^{j+1}-\zeta^{j}=0\), as shown in Equation (7), we define a characteristic temperature \(\theta^{j}\) for each j ∈ [1, n − 1]. Essentially, \(\theta^{j}=-\frac{\Delta\mathcal{U}^{j}}{\Delta\mathcal{S}^{j}}\) measures the change in unfaithfulness per unit change in interpretation entropy for a model with j non-zero coefficients. This closely resembles thermodynamic temperature, which is defined as the derivative of internal energy with respect to entropy. We then identify the interpretation with (j + 1) non-zero coefficients that minimizes \(\theta^{j+1}-\theta^{j}=-\left(\frac{\Delta\mathcal{U}^{j+1}}{\Delta\mathcal{S}^{j+1}}-\frac{\Delta\mathcal{U}^{j}}{\Delta\mathcal{S}^{j}}\right)\) as the optimal interpretation, since this guarantees that \(\zeta^{j+1}\) remains the lowest minimum among \(\zeta^{1},\zeta^{2},\ldots,\zeta^{n}\) over the widest range of temperatures. Finally, we calculate the optimal temperature \(\theta^{o}=\frac{\theta^{j+1}+\theta^{j}}{2}\) (any value strictly between \(\theta^{j+1}\) and \(\theta^{j}\) is equally valid, since the optimal interpretation does not change within that range) and report the weights of this model as the explanation. All \(\zeta^{j}\) vs. j plots shown in this manuscript are created using this definition of optimal temperature.
$$\begin{array}{rcl}\zeta^{j+1}-\zeta^{j}&=&(\mathcal{U}^{j+1}-\mathcal{U}^{j})+\theta(\mathcal{S}^{j+1}-\mathcal{S}^{j})\\ \Delta\zeta^{j}&=&\Delta\mathcal{U}^{j}+\theta\,\Delta\mathcal{S}^{j}\\ \theta^{j}&=&-\frac{\Delta\mathcal{U}^{j}}{\Delta\mathcal{S}^{j}}\quad[\text{by setting }\Delta\zeta^{j}=0]\end{array}$$
(7)
Thus,
$$\zeta^{j}=\mathcal{U}^{j}+\left(-\frac{\Delta\mathcal{U}^{j}}{\Delta\mathcal{S}^{j}}\bigg|_{\Delta\zeta^{j}=0}\right)\mathcal{S}^{j}$$
(8)
This is again reminiscent of classical thermodynamics, where a system’s equilibrium configuration generally varies with temperature, but the coarse-grained metastable state description remains robust over a well-defined range of temperatures (Supplementary Note 3). In our framework, when θ = 0, \(\zeta^{j}\) is minimized by the j = n interpretation, i.e., the model that minimizes unfaithfulness and completely ignores entropy. As θ is increased from zero, interpretation entropy contributes more to \(\zeta^{j}\). Here, \((\theta^{j+1}-\theta^{j})\) is a measure of the stability of the j non-zero coefficient interpretation. The complete TERP protocol is summarized as an algorithm in Fig. 3.
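The temperature-based selection rule of Equations (6)–(8) can be sketched as below, assuming NumPy and the sequences \(\mathcal{U}^{1..n}\) and \(\mathcal{S}^{1..n}\) produced by forward selection (n ≥ 3, with \(\mathcal{S}\) strictly increasing). The helper `optimal_interpretation` is hypothetical; it follows the indexing stated in the text, choosing the j that minimizes \(\theta^{j+1}-\theta^{j}\) and returning the model with j + 1 non-zero coefficients.

```python
import numpy as np

def optimal_interpretation(U, S):
    """Return (j_opt, theta_o, zeta) from lists [U^1..U^n] and [S^1..S^n]."""
    U, S = np.asarray(U, dtype=float), np.asarray(S, dtype=float)
    theta = -np.diff(U) / np.diff(S)    # theta[i] = theta^{i+1}, i = 0..n-2
    dtheta = np.diff(theta)             # dtheta[i] = theta^{i+2} - theta^{i+1}
    i = int(np.argmin(dtheta))          # j = i + 1 minimizes theta^{j+1} - theta^j
    j_opt = i + 2                       # optimal model has j + 1 non-zero coefficients
    theta_o = 0.5 * (theta[i] + theta[i + 1])
    zeta = U + theta_o * S              # free energy zeta^j at the optimal temperature
    return j_opt, theta_o, zeta
```

With the made-up values U = [0.50, 0.30, 0.05, 0.04, 0.035] and S = [0.0, 0.2, 0.35, 0.5, 0.6], the largest drop in θ occurs between j = 2 and j = 3, so the sketch returns j_opt = 3, θ^o = 26/30, and ζ is smallest at j = 3.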
Application to AI-augmented MD: VAMPnets
The variational approach for Markov processes implemented with neural networks (VAMPnets) is a popular technique for analyzing molecular dynamics (MD) trajectories36. VAMPnets can be used to featurize inputs, transform them into a lower-dimensional representation, and construct a Markov state model57 in an automated manner by maximizing the so-called VAMP score. Additional details on the implementation of VAMPnets are provided in “Methods.”
In this work, we trained a VAMPnets model on a standard toy system: alanine dipeptide in vacuum. An 8-dimensional input consisting of the sines and cosines of the four dihedral angles ϕ, ψ, θ, and ω was constructed and passed to VAMPnets, which identified three metastable states, I, II, and III, as shown in Fig. 4b, c.
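The 8-dimensional input can be built as sketched below, assuming NumPy and angles in radians; the interleaved column order (sin, cos per angle) is our assumption. Encoding each angle by its sine and cosine removes the artificial jump between −π and +π, so configurations on either side of that boundary map to nearby feature vectors.

```python
import numpy as np

def dihedral_features(angles):
    """angles: (T, 4) array of (phi, psi, theta, omega) in radians.

    Returns a (T, 8) array ordered [sin phi, cos phi, sin psi, cos psi, ...].
    """
    angles = np.asarray(angles, dtype=float)
    feats = np.empty((angles.shape[0], 2 * angles.shape[1]))
    feats[:, 0::2] = np.sin(angles)   # even columns: sines
    feats[:, 1::2] = np.cos(angles)   # odd columns: cosines
    return feats
```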
To explain VAMPnets model predictions using TERP, we picked 713 different configurations, some of which lie near transition states. We classify a data point as a transition state if the prediction probabilities for both classes (i.e., the two states being connected) exceed a threshold of 0.4. From a physics perspective, the behavior of such molecular systems near transition states is of central interest. Additionally, the class prediction probability is most sensitive at the transition state, so if our method generates a meaningful local neighborhood, it should include a broad distribution of probabilities, resulting in accurate approximations of the black-box behavior. Thus, a correct analysis of the transition state ensemble will validate our similarity metric and overall neighborhood generation scheme.
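The transition-state criterion can be expressed as a one-line test. This sketch assumes NumPy and interprets "both classes" for the three-state model as the two most probable states; `is_transition_state` is a hypothetical helper, not part of VAMPnets or TERP.

```python
import numpy as np

def is_transition_state(probs, threshold=0.4):
    """probs: (N, K) class probabilities. Flags rows whose two most probable
    states both exceed `threshold`."""
    top2 = np.sort(np.asarray(probs, dtype=float), axis=1)[:, -2:]
    return np.all(top2 > threshold, axis=1)
```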
We generated 5000 neighborhood samples for each configuration and performed TERP by following the algorithm in Fig. 3. In Fig. 4b, c, we highlight the first and second most dominant features identified by TERP for all 713 configurations using colored stars (⋆). The generated explanations are robust, and TERP identified various regions where different dihedral angles are relevant to predictions. The results agree with existing literature, e.g., the relevance of the θ dihedral angle at the transition state between I and III reported by Chandler et al.58. The results also make intuitive sense: e.g., the VAMPnets state definitions change rapidly near ϕ ≈ 0, and TERP learned that ϕ is the most dominant feature in that region. This shows that VAMPnets worked here for the correct reasons and can be trusted. In Fig. 4d–g, we show TERP results for a specific configuration (ϕ = 0.084, ψ = 0.007, θ = 0.237, ω = 2.990 radians) for which the model with j = 2 non-zero coefficients resulted in the optimal interpretation, with pϕ = 0.82 and pθ = 0.18. Figure 4f clearly shows that \((\theta^{j+1}-\theta^{j})\) is minimized at j = 2, and the average of \(\theta^{j+1}\) and \(\theta^{j}\) is taken as the optimal temperature \(\theta^{o}\) for calculating \(\zeta^{j}\) using Equation (8). Additional implementation details are provided in “Methods.”
In this section, we demonstrated the applicability of TERP for probing black-box models designed to analyze time-series data from MD simulations. In addition to assigning confidence to these models, TERP can be used to extract valuable insights (relevant degrees of freedom) learned by the model. In the future, we expect increased adoption of TERP-like methods in AI-enhanced MD simulations for investigating conformational dynamics, nucleation, target-drug interactions, and other relevant molecular phenomena39,40,41,42,43,44,45,46,47,48,49,50,51.
Dimensionality reduction (LDA) significantly improves neighborhood similarity
As discussed in the first subsection, neighborhood similarity evaluated using Euclidean distance can be incorrect and may lead to poor explanations. Here, we perform experiments to demonstrate the advantages of the LDA-based similarity measure. Figure 4h shows that the LDA projection successfully generated two clusters of data points, belonging to the in-explanation class (the predicted class of the instance requiring explanation) and the not-in-explanation classes (all other classes), respectively. These well-separated clusters yield a meaningful distance measure d. In Fig. 4i, j, we illustrate the robustness of the LDA implementation against noisy and correlated features and compare the results with a Euclidean similarity implementation. We generate pure white noise by drawing samples from a normal distribution \(\mathcal{N}(0,1)\) and generate correlated data as \(a_i x_i+b\,\mathcal{N}(0,1)\) (e.g., ai = 1.0, b = 0.2), where xi are standardized features from the actual data. As shown in Fig. 4i, j, we construct synthetic neighborhoods by combining actual data from the four dihedral angles with one pure-noise feature, four pure-noise features, or four correlated features. Since the synthetic features contain no new information, their addition should not change the similarity. We can therefore assess the robustness of a measure by the mean squared change in similarity per data point, which we call the similarity error, ΔΠ ∈ [0, 1], as shown in Equation (9).
$$\Delta\Pi=\frac{1}{N}\sum_{i=1}^{N}\left(\Pi_i^{o}-\Pi_i^{s}\right)^2$$
(9)
Here, the superscripts o and s denote similarities computed from the original and synthetic data points, respectively. Over 100 independent trials, LDA-based similarity performs significantly better than Euclidean similarity. Notably, adding even a single pure-noise feature introduces a significant similarity error for the Euclidean measure. We therefore conclude that adopting LDA over the Euclidean similarity measure yields a substantially more robust similarity and, in turn, improved explanations.
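The robustness test of Equation (9) can be reproduced in outline as below. This sketch assumes NumPy and shows only the Euclidean baseline with σ = 1 and x0 as row 0; the hypothetical helpers `euclidean_similarity`, `add_synthetic`, and `similarity_error` are ours, and an LDA-based similarity would be substituted for `euclidean_similarity` in the same place. Synthetic columns are pure noise \(\mathcal{N}(0,1)\) or correlated copies \(a x_i + b\,\mathcal{N}(0,1)\).

```python
import numpy as np

def euclidean_similarity(X, x0_index=0, sigma=1.0):
    """Gaussian similarity using Euclidean distance to x0 in input space."""
    d = np.linalg.norm(X - X[x0_index], axis=1)
    return np.exp(-d**2 / sigma**2)

def add_synthetic(X, rng, n_noise=0, correlated=False, a=1.0, b=0.2):
    """Append n_noise pure-noise columns and/or correlated copies a*x + b*N(0,1)."""
    cols = [X]
    if n_noise:
        cols.append(rng.normal(size=(X.shape[0], n_noise)))
    if correlated:
        cols.append(a * X + b * rng.normal(size=X.shape))
    return np.hstack(cols)

def similarity_error(pi_orig, pi_synth):
    """Delta Pi: mean squared change in similarity per data point."""
    return float(np.mean((np.asarray(pi_orig) - np.asarray(pi_synth)) ** 2))
```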
Application to image classification: vision transformers (ViTs)
Transformers are a type of machine learning model characterized by the presence of self-attention layers and are commonly used in natural language processing (NLP) tasks59. The more recently proposed vision transformers (ViTs)37 apply the transformer architecture directly to image data, eliminating the need for convolutional layers, and have become a popular choice in computer vision. By construction, ViTs are black-box models, and given their widespread practical use, it is desirable to employ an explanation scheme to validate their predictions before deploying them.
ViTs operate by segmenting input images into smaller patches and treating each patch as a token, similar to words in NLP. These patches are then embedded (patch embeddings) and passed to transformer layers that perform self-attention and feedforward operations. This design allows ViTs to capture long-range spatial dependencies within images and learn meaningful representations. Interestingly, ViTs are known to perform poorly with limited training data, but with sufficiently large datasets, they have been shown to outperform convolutional models. A typical ViT workflow therefore has two stages: first, a large dataset is used to learn meaningful representations and pre-train a transferable model; the model is then fine-tuned for specific tasks.
In this work, we employ a ViT pre-trained on the ImageNet-21k dataset, as released by the original authors37,60,61, and fine-tune it to predict human facial attributes on the publicly available Large-scale CelebFaces Attributes (CelebA)62 dataset. CelebA is a large collection of 202,599 human facial images, each labeled with 40 different attributes (e.g., ‘Smiling’, ‘Eyeglasses’, ‘Male’). During training, each 224 × 224 pixel CelebA image is split into 16 × 16 pixel patches, resulting in (224/16)² = 196 patches per image, as depicted in Fig. 5b. Other details of the architecture and training procedure are provided in “Methods.”
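The patch arithmetic can be checked with a short sketch. The row-major patch numbering starting from the top-left corner is our assumption for illustration, and `patch_index` is a hypothetical helper: a 224-pixel side holds 224 // 16 = 14 patches, giving 14 × 14 = 196 patches in total.

```python
def patch_index(row, col, image_size=224, patch_size=16):
    """Map pixel (row, col) to its patch/superpixel index, numbered row-major."""
    per_side = image_size // patch_size          # 224 // 16 = 14 patches per side
    return (row // patch_size) * per_side + (col // patch_size)

n_patches = (224 // 16) ** 2                     # 14 * 14 = 196 patches per image
```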
To explain the ViT prediction ‘Eyeglasses’ (prediction probability of 0.998) for the image shown in Fig. 5a using TERP, we first construct human-understandable representative features by dividing the image into 196 superpixels (collections of pixels) corresponding to the 196 ViT patches, as shown in Fig. 5b. A neighborhood of perturbed images was then generated by replacing randomly chosen superpixels with their average RGB color, following the neighborhood generation scheme outlined in “Methods.” Figure 5c–f shows \(\mathcal{U}^{j}\), \(\mathcal{S}^{j}\), \(\theta^{j}\), and \(\zeta^{j}\) as functions of j after implementing the TERP protocol (Fig. 3). The optimal TERP explanation, shown in Fig. 5g, appears at j = 3 because of the maximal decrease in \(\theta^{j}\) as j is increased from 2 to 3. Using Equations (7) and (8), \(\zeta^{j}\) is calculated, and its minimum occurs at j = 3. Thus, the TERP explanation enables us to conclude that the ViT prediction of ‘Eyeglasses’ was made for the correct reasons.
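One plausible reading of the image perturbation step is sketched below, assuming NumPy, square 16 × 16 superpixels numbered row-major, and a 50% chance of switching each superpixel off; the exact scheme is described in “Methods,” and `perturb_image` and `frac` are hypothetical. Each switched-off superpixel is replaced by its mean RGB color, and the returned binary mask is the one-hot feature vector s used for the linear surrogate.

```python
import numpy as np

def perturb_image(img, rng, patch_size=16, frac=0.5):
    """Return (perturbed image, mask) for an (H, W, 3) image.

    mask[k] = 1 if superpixel k is kept, 0 if it was replaced by its mean color.
    """
    H, W, _ = img.shape
    nr, nc = H // patch_size, W // patch_size
    mask = (rng.random(nr * nc) >= frac).astype(int)
    out = img.astype(float).copy()
    for k in np.flatnonzero(mask == 0):
        r, c = divmod(k, nc)                 # row-major superpixel numbering
        block = out[r * patch_size:(r + 1) * patch_size,
                    c * patch_size:(c + 1) * patch_size]
        block[...] = block.mean(axis=(0, 1)) # fill the view with its mean RGB color
    return out, mask
```

Calling this N times with one random generator yields the perturbed images to send to the ViT and the N × 196 binary feature matrix for fitting.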
Data and model parameter randomization experiments show TERP explanations are sensitive
To establish that TERP indeed takes both the input data and the black-box model into account when generating explanations, we subject our protocol to the sanity tests developed by Adebayo et al.63. Following their work, we randomize the parameters of the fine-tuned ViT model in a top-to-bottom cascading fashion to obtain corrupted models. Specifically, we randomize all parameters of ViT blocks 11 through 9, and of blocks 11 through 3, to obtain two corrupted models. TERP explanations for ‘Eyeglasses’ for these two models are shown in Fig. 5h, i. Plots showing \(\mathcal{U}^{j}\), \(\mathcal{S}^{j}\), \(\zeta^{j}\), and \(\theta^{j}\) as functions of j for these models are provided in the SI (Supplementary Fig. S2). Because randomization destroys the learned relationship between inputs and predictions, the explanations of the corrupted models should no longer match the ground truth, and a good AI explanation scheme should be sensitive to this randomization by producing explanations that differ from those of the fully trained model. Similarly, we implemented the data randomization test (Fig. 5j) proposed in the same work, in which the labels of the training data are randomized prior to training and a new ViT is trained on the corrupted data (training details provided in the SI). Again, the results of an AI explanation method should be sensitive to this randomization. From the corresponding TERP explanations shown in Fig. 5h–j, we conclude that TERP passes both randomization tests.
Baseline benchmark against saliency map shows TERP explanations are reliable
To assess the validity, robustness, and human interpretability of the explanations, we benchmarked TERP against a saliency map, LIME, and SHAP. In this subsection, we first show that TERP explanations are significantly better and more reasonable than those of a baseline method, i.e., a simple gradient-based saliency map (additional details in “Methods”), for the ‘Eyeglasses’ prediction of the previously trained ViT. A comparison with more advanced methods (LIME and SHAP), demonstrating how our work contributes to the existing field, is presented in the next subsection.
Figure 5k reveals the limitations of the saliency explanation: many pixels irrelevant to ‘Eyeglasses’ have high absolute values of the probability gradient across the RGB channels. This is not surprising, since saliency maps are known to respond to image features such as color changes and object edges rather than to a relationship between model inputs and class prediction63. We also generated TERP and saliency map explanations for the label ‘Male’, as shown in Fig. 5l, m (further details in the SI). Again, the saliency map explanation includes pixels that should be irrelevant to this predicted class. In contrast, TERP explanations involve pixels that are relevant to the respective classes, demonstrating the validity of the results.
Comparison with advanced methods demonstrates TERP explanations are unique
In this subsection, we compare TERP with state-of-the-art methods to show that it generates unique and highly human-interpretable explanations. To ensure a fair comparison, we focus on other widely used model-agnostic, post-hoc explanation schemes (LIME19 and SHAP20) that, like TERP, use only the inputs and outputs of a black-box model.
LIME generates a local, linear approximation (f) to black-box predictions (g) by solving \(\xi(x)=\mathrm{argmin}_{f}\,\mathcal{L}(g,f,\pi_{x})+\Omega(f)\), where \(\mathcal{L}\) is a fidelity function (typically root-mean-squared error), πx is the neighborhood similarity, and Ω is the complexity measure of the surrogate linear model. In practice, LIME is implemented either by (1) performing weighted linear regression and then selecting the top j features with the most extreme coefficients, or (2) directly performing Lasso regression with L1 regularization64 to construct sparse models, where the degree of sparsity is tuned by a hyperparameter α. Both j and α typically depend on the instance under investigation and must be set to reasonable values by the user. Thus, LIME lacks a mechanism grounded in human interpretability for generating unique explanations, and analyzing a large number of black-box predictions requires significant testing or human intervention.
While TERP and LIME use similar fidelity functions, the main difference is that TERP does not use model complexity or simplicity as a proxy for human interpretability. As discussed in the “Introduction,” such proxies can be misleading; instead, TERP directly quantifies the degree of human interpretability through interpretation entropy. A unique explanation is then generated by identifying the set of features causing the highest decrease in unfaithfulness per unit increase in entropy.
We applied LIME to explain the ViT prediction for ‘Eyeglasses’; the top 10 features contributing to the prediction are shown in Fig. 6a. We also implemented the second LIME approach, i.e., Lasso regression for sparse models, for 10 different values of α. As α is increased, the number of selected features in the explanation decreases, as shown in Fig. 6b. While the relevant superpixels identified by LIME are reasonable and overlap with those identified by TERP (Fig. 5g), LIME requires hyperparameter selection or human intervention, which can be infeasible for high-throughput experiments, e.g., when analyzing MD data.
After LIME, we implemented another widely used state-of-the-art method, SHAP, to explain the ‘Eyeglasses’ and ‘Male’ predictions, as shown in Fig. 6c, d. A feature with an extreme SHAP value contributes strongly to the black-box prediction. Specifically, the SHAP value of a feature j is \(\phi_j=\sum_{S\subseteq\{1,\ldots,N\}\setminus\{j\}}\frac{|S|!\,(N-|S|-1)!}{N!}\left[v(S\cup\{j\})-v(S)\right]\). Here, the prefactor is the weight of the marginal contribution (enclosed in brackets) of feature j to S, where S, |S|, and N represent a specific set of features (coalition), the number of features in that coalition, and the total number of features, respectively. The marginal contribution is the difference between the model output v for the coalition with feature j included and the output without it. After obtaining SHAP values for all features, a sparse explanation is typically obtained by taking the top j (user-defined) features with the most extreme SHAP values. Thus, similar to LIME, SHAP explanations are not unique. Comparing SHAP results with TERP (Fig. 5g, l), we again see that the relevant features overlap, which validates the TERP explanation.
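To make the Shapley weighting concrete, the formula above can be evaluated exactly for a small number of features. This is a brute-force illustration of the definition, not the SHAP library's estimator; `shapley_values` is a hypothetical helper, features are indexed 0..N−1, and the value function v takes a frozenset coalition. The cost grows as 2^N, which is why practical tools approximate it.

```python
from itertools import combinations
from math import factorial

def shapley_values(v, N):
    """Exact Shapley values phi_j for a value function v(frozenset) over N features."""
    phi = []
    for j in range(N):
        others = [k for k in range(N) if k != j]
        total = 0.0
        for size in range(N):                     # coalition sizes |S| = 0 .. N-1
            w = factorial(size) * factorial(N - size - 1) / factorial(N)
            for S in combinations(others, size):
                S = frozenset(S)
                total += w * (v(S | {j}) - v(S))  # weighted marginal contribution
        phi.append(total)
    return phi
```

For an additive value function each feature recovers its own weight, and for v(S) = 1 only when features 0 and 1 are both present, the credit is split as [0.5, 0.5, 0.0].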
In this section, we compared TERP with two widely used state-of-the-art model-agnostic, post-hoc approaches and demonstrated the validity of TERP explanations. Furthermore, by employing the theory developed in this work, TERP successfully generated highly human-interpretable, unique explanations, unlike the established methods. Implementation details of LIME and SHAP are provided in “Methods.”
Application to text classification: attention-based bidirectional long short-term memory (Att-BLSTM)
Classification tasks in natural language processing (NLP) often involve identifying semantic relations between units appearing at distant locations in a block of text. This challenging problem is known as relation classification, and models based on long short-term memory (LSTM)65, gated recurrent units (GRU)66, and transformers59 have been very successful in addressing it. In this work, we look at the widely used attention-based bidirectional long short-term memory38 (Att-BLSTM) classifier and apply TERP to explain its predictions.
First, we trained an Att-BLSTM model on Antonio Gulli’s (AG’s) news corpus67, a large collection of more than 1 million news articles curated from more than 2000 news sources. The label associated with each article indicates the section of the news source (e.g., World, Sports, Business, or Science and Technology) in which it was published. We then used the trained model to obtain a prediction for a story titled “AI predicts protein structures,” published in ‘Nature’s biggest news stories of 2022’68.
To apply TERP to a black-box prediction on text input (a sequence of sentences), the text is first passed through a tokenizer (nltk69), which generates a dictionary of the words/phrases it contains. These words are the representative features used in TERP. A neighborhood of perturbed text is then generated by randomly choosing words and removing all of their instances from the text. TERP converts the neighborhood into numerical values for linear model construction by creating a one-hot-encoded matrix whose columns represent the presence or absence of the different words in the perturbed text.
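The text neighborhood and its presence/absence matrix can be sketched with the standard library alone. As an assumption for this illustration, `str.split()` stands in for the nltk tokenizer, each vocabulary word is removed with a hypothetical probability `p_remove`, and `text_neighborhood` is a hypothetical helper name.

```python
import random

def text_neighborhood(text, n_samples, rng, p_remove=0.3):
    """Perturb `text` by deleting every occurrence of randomly chosen words.

    Returns (vocab, samples, X) where X[i][k] = 1 if vocab[k] survives in samples[i].
    """
    tokens = text.split()                 # stand-in for the nltk tokenizer
    vocab = sorted(set(tokens))           # one feature per distinct word
    samples, X = [], []
    for _ in range(n_samples):
        keep = [rng.random() >= p_remove for _ in vocab]
        kept = {w for w, k in zip(vocab, keep) if k}
        # Removing a word drops all of its instances from the text.
        samples.append(" ".join(t for t in tokens if t in kept))
        X.append([int(k) for k in keep])
    return vocab, samples, X

rng = random.Random(0)
vocab, samples, X = text_neighborhood("AI predicts protein structures", 3, rng)
```

Each perturbed string is then classified by the black box, and the rows of X serve as the feature matrix for the linear surrogate.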
As a specific instance, the Att-BLSTM classifier predicted that the story titled “AI predicts protein structures” is about Science and Technology, and we implemented TERP to generate the optimal explanation behind this prediction, as shown in Fig. 7. Here, the maximum decrease in \(\theta^{j}\) occurs when going from j = 1 to j = 2, and thus \(\zeta^{j}\) has a minimum at j = 2. The most influential keywords were identified as ‘species’ and ‘science’, with pk = 0.47 and 0.53, respectively. This gives confidence that the Att-BLSTM model classified the news story for the correct reasons.