Chapter 4: Explainable AI (XAI)
This chapter covers Explainable AI (XAI). You will learn why interpretability matters, how to quantify feature contributions with SHAP (Shapley values), how LIME builds local linear explanations, how to visualize Attention, and what real-world applications (Toyota, IBM, Citrine) look like.
Learning Objectives
By reading this chapter, you will master:
- ✅ Understanding the importance of interpretability and the black-box problem
- ✅ Quantitative interpretation of predictions using SHAP (Shapley values)
- ✅ Local linear approximation and explanation generation using LIME
- ✅ Neural network interpretation through Attention visualization
- ✅ Learning from real-world applications (Toyota, IBM, Citrine)
- ✅ Career paths and salary information for materials data scientists
4.1 The Importance of Interpretability
Understanding machine learning model predictions and extracting physical meaning is essential in materials science.
The Black-Box Problem
# Requirements:
# - Python 3.9+
# - matplotlib>=3.7.0
# - numpy>=1.24.0, <2.0.0
# - pandas>=2.0.0, <2.2.0
# - seaborn>=0.12.0
"""
Example: The Black-Box Problem
Purpose: Demonstrate data visualization techniques
Target: Advanced
Execution time: 2-5 seconds
Dependencies: None
"""
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import Ridge
# Sample data
np.random.seed(42)
X = np.random.randn(200, 10)
y = 2*X[:, 0] + 3*X[:, 1] - 1.5*X[:, 2] + np.random.normal(0, 0.5, 200)
# Interpretable model vs black-box model
ridge = Ridge(alpha=1.0)
rf = RandomForestRegressor(n_estimators=100, random_state=42)
ridge.fit(X, y)
rf.fit(X, y)
# Ridge coefficients (interpretable)
ridge_coefs = ridge.coef_
# Visualization: Model interpretability differences
fig, axes = plt.subplots(1, 2, figsize=(14, 6))
# Ridge: Clear from linear coefficients
axes[0].bar(range(len(ridge_coefs)), ridge_coefs,
color='steelblue', alpha=0.7)
axes[0].set_xlabel('Feature Index', fontsize=11)
axes[0].set_ylabel('Coefficient', fontsize=11)
axes[0].set_title('Ridge Regression (Interpretable)', fontsize=12, fontweight='bold')
axes[0].axhline(y=0, color='red', linestyle='--', linewidth=1)
axes[0].grid(alpha=0.3)
# Random Forest: Complex non-linear relationships (black-box)
axes[1].text(0.5, 0.5, '❓\nBlack Box\n\n100 decision trees\nComplex non-linear relationships\nDifficult to interpret',
ha='center', va='center', fontsize=16,
bbox=dict(boxstyle='round', facecolor='gray', alpha=0.3),
transform=axes[1].transAxes)
axes[1].set_title('Random Forest (Black-Box)',
fontsize=12, fontweight='bold')
axes[1].axis('off')
plt.tight_layout()
plt.show()
print("Interpretability challenges:")
print("- Linear models: Influence is clear from coefficients, but accuracy is low")
print("- Non-linear models: High accuracy, but why predictions are made is unclear")
print("→ XAI (Explainable AI) aims for both")
The Need for Physical Interpretation in Materials Science
# Use cases for interpretability in materials science
use_cases = pd.DataFrame({
'Use Case': [
'New material discovery',
'Synthesis condition optimization',
'Process anomaly detection',
'Property prediction',
'Material design guidelines'
],
'Importance of Interpretability': [10, 9, 8, 7, 10],
'Reason': [
'Understanding physical mechanisms leads to new discoveries',
'Identify which parameters are critical',
'Root cause of anomalies must be identified',
'Verify prediction basis',
'Extract design principles'
]
})
# Visualization
fig, ax = plt.subplots(figsize=(12, 6))
colors = plt.cm.YlOrRd(np.linspace(0.3, 0.9, len(use_cases)))
bars = ax.barh(use_cases['Use Case'],
use_cases['Importance of Interpretability'],
color=colors, alpha=0.7)
ax.set_xlabel('Importance of Interpretability (1-10)', fontsize=12)
ax.set_xlim(0, 10)
ax.set_title('Importance of Interpretability in Materials Science',
fontsize=13, fontweight='bold')
ax.grid(axis='x', alpha=0.3)
# Add reason annotations
for idx, row in use_cases.iterrows():
ax.text(row['Importance of Interpretability'] + 0.3, idx,
row['Reason'], va='center', fontsize=9, style='italic')
plt.tight_layout()
plt.show()
print("Why XAI is needed in materials science:")
print("1. Verification of physical law consistency")
print("2. Reflection in experimental planning")
print("3. Integration of expert knowledge")
print("4. Accountability in papers and patents")
Trustworthiness and Debugging
# Example of discovering model prediction errors through interpretation
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_absolute_error
# Data generation (intentionally including noise)
X_data = np.random.randn(300, 5)
# Correct relationship: y = 2*X0 + 3*X1
y_true = 2*X_data[:, 0] + 3*X_data[:, 1] + np.random.normal(0, 0.3, 300)
# Add noise to some samples (simulate measurement error)
noise_idx = np.random.choice(300, 30, replace=False)
y_data = y_true.copy()
y_data[noise_idx] += np.random.normal(0, 5, 30)
# Training
X_train, X_test, y_train, y_test = train_test_split(
X_data, y_data, test_size=0.2, random_state=42
)
model = RandomForestRegressor(n_estimators=100, random_state=42)
model.fit(X_train, y_train)
# Prediction
y_pred = model.predict(X_test)
mae = mean_absolute_error(y_test, y_pred)
# Identify samples with large errors
errors = np.abs(y_test - y_pred)
high_error_idx = np.where(errors > np.percentile(errors, 90))[0]
print(f"Model MAE: {mae:.4f}")
print(f"Number of high-error samples: {len(high_error_idx)}")
print("\n→ Use XAI to analyze causes of high-error samples")
print(" - Discover data quality issues")
print(" - Identify model weaknesses")
print(" - Verify physical validity")
4.2 SHAP (SHapley Additive exPlanations)
An interpretation method based on Shapley values from cooperative game theory.
Shapley Values Theory
# Requirements:
# - Python 3.9+
# - shap>=0.42.0
"""
Example: Shapley Values Theory
Purpose: Demonstrate machine learning model training and evaluation
Target: Beginner
Execution time: 1-5 minutes
Dependencies: None
"""
import shap
# Visualize SHAP basic concept
shap.initjs()
# Model training
model_shap = RandomForestRegressor(n_estimators=100, random_state=42)
model_shap.fit(X_train, y_train)
# SHAP Explainer
explainer = shap.TreeExplainer(model_shap)
shap_values = explainer.shap_values(X_test)
print("Meaning of SHAP values:")
print("- How much each feature contributed to the prediction")
print("- Shapley value: Fair distribution from cooperative game theory")
print("- Expressed as deviation from base value")
print(f"\nSHAP values shape: {shap_values.shape}")
print(f" Number of samples: {shap_values.shape[0]}")
print(f" Number of features: {shap_values.shape[1]}")
# Explain single sample
sample_idx = 0
base_value = explainer.expected_value
prediction = model_shap.predict(X_test[sample_idx:sample_idx+1])[0]
print(f"\nPrediction for sample {sample_idx}:")
print(f"Base value: {base_value:.4f}")
print(f"Sum of SHAP values: {shap_values[sample_idx].sum():.4f}")
print(f"Prediction: {prediction:.4f}")
print(f"Verification: {base_value + shap_values[sample_idx].sum():.4f} ≈ {prediction:.4f}")
SHAP Value Computation (Tree SHAP, Kernel SHAP)
# Tree SHAP (fast, tree-based models only)
explainer_tree = shap.TreeExplainer(model_shap)
shap_values_tree = explainer_tree.shap_values(X_test)
# Kernel SHAP (model-agnostic, slow)
# Demo with small sample
X_test_small = X_test[:10]
explainer_kernel = shap.KernelExplainer(
model_shap.predict,
shap.sample(X_train, 50)
)
shap_values_kernel = explainer_kernel.shap_values(X_test_small)
print("Comparison of SHAP computation methods:")
print("\nTree SHAP:")
print(f" Target models: Tree-based (RF, XGBoost, LightGBM)")
print(f" Computation speed: Fast")
print(f" Accuracy: Exact solution")
print("\nKernel SHAP:")
print(f" Target models: Any (including neural networks)")
print(f" Computation speed: Slow")
print(f" Accuracy: Approximate solution (sampling-based)")
# Simple computation time comparison
import time
start = time.time()
_ = explainer_tree.shap_values(X_test)
tree_time = time.time() - start
print(f"\nTree SHAP computation time: {tree_time:.3f} seconds ({len(X_test)} samples)")
Global vs Local Interpretation
# Global interpretation: Average importance across all samples
mean_abs_shap = np.abs(shap_values).mean(axis=0)
fig, axes = plt.subplots(1, 2, figsize=(14, 6))
# Global interpretation
axes[0].bar(range(len(mean_abs_shap)), mean_abs_shap,
color='steelblue', alpha=0.7)
axes[0].set_xlabel('Feature Index', fontsize=11)
axes[0].set_ylabel('Mean |SHAP value|', fontsize=11)
axes[0].set_title('Global Interpretation (Overall Importance)',
fontsize=12, fontweight='bold')
axes[0].grid(alpha=0.3)
# Local interpretation: Specific sample
sample_idx = 0
axes[1].bar(range(len(shap_values[sample_idx])),
shap_values[sample_idx],
color='coral', alpha=0.7)
axes[1].axhline(y=0, color='black', linestyle='-', linewidth=1)
axes[1].set_xlabel('Feature Index', fontsize=11)
axes[1].set_ylabel('SHAP value', fontsize=11)
axes[1].set_title(f'Local Interpretation (Sample {sample_idx} explanation)',
fontsize=12, fontweight='bold')
axes[1].grid(alpha=0.3)
plt.tight_layout()
plt.show()
print("Global vs Local interpretation:")
print("\nGlobal:")
print(" - Average feature importance across all samples")
print(" - Understand overall model behavior")
print(" - General guidelines for new material design")
print("\nLocal:")
print(" - Basis for individual predictions")
print(" - Identify causes of anomalous samples")
print(" - Optimization direction for specific materials")
Summary Plot, Dependence Plot
# Summary plot (overview)
plt.figure(figsize=(10, 8))
shap.summary_plot(shap_values, X_test, plot_type="dot", show=False)
plt.title('SHAP Summary Plot', fontsize=13, fontweight='bold', pad=20)
plt.tight_layout()
plt.show()
print("How to read Summary Plot:")
print("- Vertical axis: Features (ordered by importance)")
print("- Horizontal axis: SHAP value (influence on prediction)")
print("- Color: Feature value (red=high, blue=low)")
print("- Distribution: Diversity of influence for each feature")
# Dependence plot (detailed for individual features)
feature_idx = 0
plt.figure(figsize=(10, 6))
shap.dependence_plot(
feature_idx,
shap_values,
X_test,
show=False
)
plt.title(f'SHAP Dependence Plot (Feature {feature_idx})',
fontsize=13, fontweight='bold')
plt.tight_layout()
plt.show()
print("\nHow to read Dependence Plot:")
print("- Horizontal axis: Feature value")
print("- Vertical axis: SHAP value (influence on prediction)")
print("- Color: Another feature that interacts")
print("- Trend: Visualization of non-linear relationships")
4.3 LIME (Local Interpretable Model-agnostic Explanations)
An explanation generation method based on local linear approximation.
Local Linear Approximation
# Requirements:
# - Python 3.9+
# - lime>=0.2.0
"""
Example: Local Linear Approximation
Purpose: Demonstrate data visualization techniques
Target: Beginner
Execution time: 1-5 minutes
Dependencies: None
"""
from lime import lime_tabular
# LIME Explainer
lime_explainer = lime_tabular.LimeTabularExplainer(
X_train,
mode='regression',
feature_names=[f'Feature_{i}' for i in range(X_train.shape[1])],
verbose=False
)
# Explain single sample
sample_idx = 0
explanation = lime_explainer.explain_instance(
X_test[sample_idx],
model_shap.predict,
num_features=5
)
# Visualization
fig = explanation.as_pyplot_figure()
plt.title(f'LIME Explanation (Sample {sample_idx})',
fontsize=13, fontweight='bold')
plt.tight_layout()
plt.show()
print("How LIME works:")
print("1. Random sampling around target sample")
print("2. Predictions by black-box model")
print("3. Distance-based weighting")
print("4. Train local linear model")
print("5. Explain using linear coefficients")
# Display explanation numerically
print("\nExplanations (by importance):")
for feature, weight in explanation.as_list():
print(f" {feature}: {weight:.4f}")
Tabular LIME
# Run LIME on multiple samples
n_samples_lime = 5
lime_results = []
for i in range(n_samples_lime):
exp = lime_explainer.explain_instance(
X_test[i],
model_shap.predict,
num_features=X_train.shape[1]
)
# Convert explanation to dictionary
exp_dict = dict(exp.as_list())
lime_results.append(exp_dict)
# Convert to DataFrame
lime_df = pd.DataFrame(lime_results)
print(f"\nLIME explanations for {n_samples_lime} samples:")
print(lime_df.head())
# Evaluate consistency (are same features always important?)
feature_importance_consistency = lime_df.abs().mean()
print("\nAverage feature importance (LIME):")
print(feature_importance_consistency.sort_values(ascending=False))
Generating Prediction Explanations
# SHAP vs LIME comparison
def compare_shap_lime(sample_idx):
"""
Compare SHAP vs LIME explanations for the same sample
"""
# SHAP
shap_exp = shap_values[sample_idx]
# LIME
lime_exp = lime_explainer.explain_instance(
X_test[sample_idx],
model_shap.predict,
num_features=X_train.shape[1]
)
lime_dict = dict(lime_exp.as_list())
# Align LIME explanations to same order as SHAP
lime_exp_ordered = []
for i in range(len(shap_exp)):
feature_name = f'Feature_{i}'
# Search for corresponding feature in LIME explanation
for key, value in lime_dict.items():
if feature_name in key:
lime_exp_ordered.append(value)
break
else:
lime_exp_ordered.append(0)
return shap_exp, np.array(lime_exp_ordered)
# Compare
sample_idx = 0
shap_exp, lime_exp = compare_shap_lime(sample_idx)
# Visualization
fig, ax = plt.subplots(figsize=(12, 6))
x_pos = np.arange(len(shap_exp))
width = 0.35
ax.bar(x_pos - width/2, shap_exp, width,
label='SHAP', color='steelblue', alpha=0.7)
ax.bar(x_pos + width/2, lime_exp, width,
label='LIME', color='coral', alpha=0.7)
ax.set_xlabel('Feature Index', fontsize=12)
ax.set_ylabel('Importance', fontsize=12)
ax.set_title(f'SHAP vs LIME (Sample {sample_idx})',
fontsize=13, fontweight='bold')
ax.set_xticks(x_pos)
ax.legend()
ax.grid(alpha=0.3)
ax.axhline(y=0, color='black', linestyle='-', linewidth=1)
plt.tight_layout()
plt.show()
# Correlation analysis
correlation = np.corrcoef(shap_exp, lime_exp)[0, 1]
print(f"\nSHAP-LIME correlation: {correlation:.4f}")
print("High correlation → Consistent explanations from both methods")
4.4 Attention Visualization (for NN/GNN)
Visualizing the Attention mechanism in neural networks.
Visualizing Attention Weights
# Simple demo of Attention mechanism
from sklearn.neural_network import MLPRegressor
# Train neural network
nn_model = MLPRegressor(
hidden_layer_sizes=(50, 50),
max_iter=1000,
random_state=42
)
nn_model.fit(X_train, y_train)
# Get activation of intermediate layer (simple version)
def get_activation(model, X, layer_idx=0):
"""
Get activation of specified layer
"""
# Weights and biases
W = model.coefs_[layer_idx]
b = model.intercepts_[layer_idx]
# Activation (ReLU)
activation = np.maximum(0, X @ W + b)
return activation
# Activation of first layer
activation_layer1 = get_activation(nn_model, X_test, layer_idx=0)
# Attention-like weights (treat activation magnitude as weights)
attention_weights = np.abs(activation_layer1).mean(axis=1)
# Visualization
plt.figure(figsize=(12, 6))
plt.scatter(range(len(attention_weights)), attention_weights,
c=y_test, cmap='viridis', s=100, alpha=0.6)
plt.colorbar(label='Target Value')
plt.xlabel('Sample Index', fontsize=12)
plt.ylabel('Attention Weight (Activation Strength)', fontsize=12)
plt.title('Attention-like Weights (First Layer Activation)',
fontsize=13, fontweight='bold')
plt.grid(alpha=0.3)
plt.tight_layout()
plt.show()
print("Significance of Attention visualization:")
print("- Which inputs is the model attending to?")
print("- Identification of important samples and features")
print("- Understanding internal behavior of neural networks")
Grad-CAM for Materials
# Gradient-based importance (simplified version)
def gradient_based_importance(model, X_sample):
"""
Gradient-based feature importance
"""
# Approximate with numerical differentiation
epsilon = 1e-5
base_pred = model.predict(X_sample.reshape(1, -1))[0]
importances = []
for i in range(len(X_sample)):
X_perturbed = X_sample.copy()
X_perturbed[i] += epsilon
perturbed_pred = model.predict(X_perturbed.reshape(1, -1))[0]
# Gradient approximation
gradient = (perturbed_pred - base_pred) / epsilon
importances.append(gradient)
return np.array(importances)
# Execute on sample
sample_idx = 0
grad_importances = gradient_based_importance(nn_model, X_test[sample_idx])
# Compare SHAP, LIME, and Gradient
fig, axes = plt.subplots(1, 3, figsize=(18, 5))
# SHAP
axes[0].bar(range(len(shap_exp)), shap_exp,
color='steelblue', alpha=0.7)
axes[0].axhline(y=0, color='black', linestyle='-', linewidth=1)
axes[0].set_xlabel('Feature', fontsize=11)
axes[0].set_ylabel('Importance', fontsize=11)
axes[0].set_title('SHAP', fontsize=12, fontweight='bold')
axes[0].grid(alpha=0.3)
# LIME
axes[1].bar(range(len(lime_exp)), lime_exp,
color='coral', alpha=0.7)
axes[1].axhline(y=0, color='black', linestyle='-', linewidth=1)
axes[1].set_xlabel('Feature', fontsize=11)
axes[1].set_ylabel('Importance', fontsize=11)
axes[1].set_title('LIME', fontsize=12, fontweight='bold')
axes[1].grid(alpha=0.3)
# Gradient
axes[2].bar(range(len(grad_importances)), grad_importances,
color='green', alpha=0.7)
axes[2].axhline(y=0, color='black', linestyle='-', linewidth=1)
axes[2].set_xlabel('Feature', fontsize=11)
axes[2].set_ylabel('Gradient', fontsize=11)
axes[2].set_title('Gradient-based', fontsize=12, fontweight='bold')
axes[2].grid(alpha=0.3)
plt.tight_layout()
plt.show()
print("Characteristics of three methods:")
print("SHAP: Game-theoretic fairness, works with all models")
print("LIME: Local linear approximation, intuitive")
print("Gradient: Gradient information, specialized for neural networks")
Which Atoms/Bonds are Important?
# Application example in materials science: Element importance
composition_features = ['Li', 'Co', 'Ni', 'Mn', 'O']
# Simulation data
X_composition = pd.DataFrame({
'Li': np.random.uniform(0.9, 1.1, 100),
'Co': np.random.uniform(0, 0.6, 100),
'Ni': np.random.uniform(0, 0.8, 100),
'Mn': np.random.uniform(0, 0.4, 100),
'O': np.random.uniform(1.9, 2.1, 100)
})
# Capacity (Ni is important)
y_capacity = (
150 * X_composition['Ni'] +
120 * X_composition['Co'] +
80 * X_composition['Mn'] +
np.random.normal(0, 5, 100)
)
# Model training
model_comp = RandomForestRegressor(n_estimators=100, random_state=42)
model_comp.fit(X_composition, y_capacity)
# SHAP analysis
explainer_comp = shap.TreeExplainer(model_comp)
shap_values_comp = explainer_comp.shap_values(X_composition)
# Element-wise importance
mean_abs_shap_comp = np.abs(shap_values_comp).mean(axis=0)
# Visualization
plt.figure(figsize=(10, 6))
plt.bar(composition_features, mean_abs_shap_comp,
color=['#FFD700', '#4169E1', '#32CD32', '#FF69B4', '#FF6347'],
alpha=0.7, edgecolor='black', linewidth=1.5)
plt.xlabel('Element', fontsize=12)
plt.ylabel('Mean |SHAP value|', fontsize=12)
plt.title('Element Contribution to Battery Capacity (SHAP Analysis)',
fontsize=13, fontweight='bold')
plt.grid(axis='y', alpha=0.3)
plt.tight_layout()
plt.show()
print("Element-wise importance:")
for elem, importance in zip(composition_features, mean_abs_shap_comp):
print(f" {elem}: {importance:.2f}")
print("\nImplications for material design:")
print("→ Increasing Ni content is expected to improve capacity")
4.5 Real-World Applications and Career Paths
We introduce industry applications of XAI and career information for materials data scientists.
Toyota: XAI Application in Materials Development
# Toyota case study (simulation)
print("=== Toyota Automotive Materials Development Case ===")
print("\nChallenge:")
print(" - Clarify battery material degradation mechanisms")
print(" - Select optimal materials from thousands of candidates")
print("\nXAI Application:")
print(" - Identify degradation contributing factors with SHAP analysis")
print(" - Visualize interactions between temperature, voltage, and cycle count")
print(" - Verify consistency with physical models")
print("\nResults:")
print(" - Development time reduced by 40%")
print(" - Battery lifetime improved by 20%")
print(" - Researchers gained physical insights")
# Simulation: Battery degradation prediction
battery_aging = pd.DataFrame({
'Temperature': np.random.uniform(20, 60, 200),
'Voltage': np.random.uniform(3.0, 4.5, 200),
'Cycle Count': np.random.uniform(0, 1000, 200),
'Charge Rate': np.random.uniform(0.5, 2.0, 200)
})
# Degradation rate (temperature and cycles are main factors)
degradation = (
0.5 * battery_aging['Temperature'] +
0.3 * battery_aging['Cycle Count'] / 100 +
0.2 * battery_aging['Voltage'] * battery_aging['Charge Rate'] +
np.random.normal(0, 2, 200)
)
# Model
model_aging = RandomForestRegressor(n_estimators=100, random_state=42)
model_aging.fit(battery_aging, degradation)
# SHAP analysis
explainer_aging = shap.TreeExplainer(model_aging)
shap_values_aging = explainer_aging.shap_values(battery_aging)
# Visualization
plt.figure(figsize=(10, 8))
shap.summary_plot(shap_values_aging, battery_aging, show=False)
plt.title('SHAP Analysis of Battery Degradation Factors (Toyota-style Case)',
fontsize=13, fontweight='bold', pad=20)
plt.tight_layout()
plt.show()
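Because the case emphasizes interactions between temperature, voltage, and cycle count, a dependence plot is a natural follow-up. The sketch below reuses the model and SHAP values computed above; the specific feature pairing is an illustrative choice.
# Interaction view: how the temperature effect changes with cycle count
plt.figure(figsize=(10, 6))
shap.dependence_plot(
    'Temperature',
    shap_values_aging,
    battery_aging,
    interaction_index='Cycle Count',
    show=False
)
plt.title('Temperature Effect on Degradation, Colored by Cycle Count',
          fontsize=13, fontweight='bold')
plt.tight_layout()
plt.show()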
IBM Research: Interpretability in AI Materials Design
# Requirements:
# - Python 3.9+
# - networkx>=3.1.0
"""
Example: IBM Research: Interpretability in AI Materials Design
Purpose: Demonstrate data visualization techniques
Target: Advanced
Execution time: 2-5 seconds
Dependencies: None
"""
print("\n=== IBM Research Materials Design Case ===")
print("\nProject: RoboRXN (Automated Chemistry Experiments)")
print("\nCharacteristics:")
print(" - XAI integrated into reaction condition optimization")
print(" - SHAP + Attention for reaction mechanism prediction")
print(" - Generate chemist-understandable recommendations")
print("\nTechnology Stack:")
print(" - Graph Neural Network (GNN)")
print(" - Attention mechanism")
print(" - SHAP for molecular graphs")
print("\nResults:")
print(" - Reaction yield prediction accuracy 95%")
print(" - Gained chemist trust")
print(" - Discovered novel reaction pathways")
# Molecular graph importance visualization (conceptual)
fig, ax = plt.subplots(figsize=(10, 8))
# Dummy molecular graph
import networkx as nx
G = nx.Graph()
G.add_edges_from([
(0, 1), (1, 2), (2, 3), (3, 4), (4, 0),
(1, 5), (3, 6)
])
pos = nx.spring_layout(G, seed=42)
# Node importance (Attention weights-like)
node_importance = np.random.rand(len(G.nodes))
node_importance = node_importance / node_importance.sum()
nx.draw(
G, pos,
node_color=node_importance,
node_size=1000 * node_importance / node_importance.max(),
cmap='YlOrRd',
with_labels=True,
font_size=12,
font_weight='bold',
edge_color='gray',
width=2,
ax=ax
)
sm = plt.cm.ScalarMappable(
cmap='YlOrRd',
norm=plt.Normalize(vmin=0, vmax=node_importance.max())
)
sm.set_array([])
cbar = plt.colorbar(sm, ax=ax, label='Attention Weight')
ax.set_title('Molecular Graph Attention Visualization (IBM-style)',
fontsize=13, fontweight='bold')
plt.tight_layout()
plt.show()
Startup: Citrine Informatics (Explainable AI)
print("\n=== Citrine Informatics Case ===")
print("\nBusiness Model:")
print(" - Provides materials development platform")
print(" - Core technology: Explainable AI")
print(" - SaaS deployment to major manufacturers")
print("\nTechnical Features:")
print(" - Bayesian optimization + XAI")
print(" - Uncertainty quantification")
print(" - Integration of physical constraints")
print("\nCustomer Cases:")
print(" - Panasonic: 50% faster battery material development")
print(" - 3M: 30% improvement in adhesive performance")
print(" - Michelin: Tire rubber optimization")
print("\nDifferentiation Factors:")
print(" - Gain expert trust through explainability")
print(" - Integration with physical models")
print(" - High accuracy with small datasets")
# Citrine approach (simulation)
# Predictions with uncertainty + SHAP
from sklearn.ensemble import GradientBoostingRegressor
# Model (quantile regression-style)
model_citrine_lower = GradientBoostingRegressor(
loss='quantile', alpha=0.1, n_estimators=100, random_state=42
)
model_citrine_median = GradientBoostingRegressor(
n_estimators=100, random_state=42
)
model_citrine_upper = GradientBoostingRegressor(
loss='quantile', alpha=0.9, n_estimators=100, random_state=42
)
X_citrine = X_composition
y_citrine = y_capacity
model_citrine_lower.fit(X_citrine, y_citrine)
model_citrine_median.fit(X_citrine, y_citrine)
model_citrine_upper.fit(X_citrine, y_citrine)
# Prediction
X_new = X_citrine.iloc[:20]
y_pred_lower = model_citrine_lower.predict(X_new)
y_pred_median = model_citrine_median.predict(X_new)
y_pred_upper = model_citrine_upper.predict(X_new)
# Visualization
fig, ax = plt.subplots(figsize=(12, 6))
x_axis = range(len(X_new))
ax.fill_between(x_axis, y_pred_lower, y_pred_upper,
alpha=0.3, color='steelblue',
label='80% Prediction Interval')
ax.plot(x_axis, y_pred_median, 'o-',
color='steelblue', linewidth=2, label='Prediction Median')
ax.set_xlabel('Material Sample', fontsize=12)
ax.set_ylabel('Capacity (mAh/g)', fontsize=12)
ax.set_title('Citrine-style Prediction with Uncertainty',
fontsize=13, fontweight='bold')
ax.legend()
ax.grid(alpha=0.3)
plt.tight_layout()
plt.show()
print("\nBenefits of uncertainty:")
print(" - Risk assessment")
print(" - Prioritization of additional experiments")
print(" - Increased confidence in decision-making")
Career Path: Materials Data Scientists, XAI Researchers
# Career path information
career_paths = pd.DataFrame({
'Career Path': [
'Materials Data Scientist',
'XAI Researcher (Academia)',
'ML Engineer (Materials-specialized)',
'R&D Manager (AI adoption)',
'Technical Consultant'
],
'Required Skills': [
'Materials Science + ML + Python',
'Statistics + ML Theory + Paper Writing',
'ML Implementation + MLOps',
'Materials Science + Project Management',
'Materials Science + ML + Business'
],
'Example Employers': [
'Toyota, Panasonic, Mitsubishi Chemical',
'Universities, AIST, RIKEN',
'Citrine, Materials Zone',
'Large Manufacturing R&D Divisions',
'Accenture, Deloitte'
]
})
print("\n=== Career Paths ===")
print(career_paths.to_string(index=False))
Salary: ¥5-25 million (Japan), $70-300K (US)
# Salary data
salary_data = pd.DataFrame({
'Position': [
'Junior (0-3 years)',
'Mid-level (3-7 years)',
'Senior (7-15 years)',
'Lead Scientist',
'Manager'
],
'Japan_Low': [500, 700, 1000, 1200, 1500],
'Japan_High': [700, 1000, 1500, 2000, 2500],
'US_Low': [70, 90, 130, 150, 180],
'US_High': [90, 130, 180, 220, 300]
})
# Visualization
fig, axes = plt.subplots(1, 2, figsize=(16, 6))
# Japan
axes[0].barh(salary_data['Position'],
salary_data['Japan_High'] - salary_data['Japan_Low'],
left=salary_data['Japan_Low'],
color='steelblue', alpha=0.7)
for idx, row in salary_data.iterrows():
axes[0].text(row['Japan_Low'] - 50, idx,
f"{row['Japan_Low']}", va='center', ha='right', fontsize=9)
axes[0].text(row['Japan_High'] + 50, idx,
f"{row['Japan_High']}", va='center', ha='left', fontsize=9)
axes[0].set_xlabel('Annual Salary (ten thousand yen)', fontsize=12)
axes[0].set_title('Japan Salary Range', fontsize=13, fontweight='bold')
axes[0].grid(axis='x', alpha=0.3)
# US
axes[1].barh(salary_data['Position'],
salary_data['US_High'] - salary_data['US_Low'],
left=salary_data['US_Low'],
color='coral', alpha=0.7)
for idx, row in salary_data.iterrows():
axes[1].text(row['US_Low'] - 5, idx,
f"${row['US_Low']}K", va='center', ha='right', fontsize=9)
axes[1].text(row['US_High'] + 5, idx,
f"${row['US_High']}K", va='center', ha='left', fontsize=9)
axes[1].set_xlabel('Annual Salary (thousand dollars)', fontsize=12)
axes[1].set_title('US Salary Range', fontsize=13, fontweight='bold')
axes[1].grid(axis='x', alpha=0.3)
plt.tight_layout()
plt.show()
print("\nFactors affecting salary:")
print(" - Degree (Master's vs PhD)")
print(" - Industry (Manufacturing vs IT)")
print(" - Location (Tokyo vs Regional, Silicon Valley vs Other)")
print(" - Skillset (Materials Science + ML + Domain Knowledge)")
print(" - Track Record (Papers, Patents, Project Success)")
print("\nSkill Development Strategy:")
print(" 1. Strengthen materials science fundamentals (degree recommended)")
print(" 2. Practical ML/DL skills (Kaggle, GitHub)")
print(" 3. Master XAI methods (SHAP, LIME)")
print(" 4. Publish papers and contribute to OSS")
print(" 5. Network (conferences, meetups)")
Exercises
Problem 1 (Difficulty: Easy)
Using SHAP and LIME, generate explanations for the same sample and calculate the correlation of feature importance. Discuss what high and low correlations mean.
Solution Example
# Requirements:
# - Python 3.9+
# - lime>=0.2.0
# - numpy>=1.24.0, <2.0.0
# - shap>=0.42.0
"""
Example: Using SHAP and LIME, generate explanations for the same sample
Purpose: Demonstrate machine learning model training and evaluation
Target: Advanced
Execution time: 30-60 seconds
Dependencies: None
"""
import shap
from lime import lime_tabular
from sklearn.ensemble import RandomForestRegressor
import numpy as np
# Model training
model = RandomForestRegressor(n_estimators=100, random_state=42)
model.fit(X_train, y_train)
# SHAP
explainer_shap = shap.TreeExplainer(model)
shap_values = explainer_shap.shap_values(X_test)
# LIME
explainer_lime = lime_tabular.LimeTabularExplainer(
    X_train, mode='regression',
    feature_names=[f'Feature_{i}' for i in range(X_train.shape[1])]
)
sample_idx = 0
# LIME explanation
lime_exp = explainer_lime.explain_instance(
    X_test[sample_idx], model.predict, num_features=X_train.shape[1]
)
lime_dict = dict(lime_exp.as_list())
# Calculate correlation
# LIME keys contain binning conditions (e.g. 'Feature_0 <= -0.68'),
# so match by substring rather than by exact key
shap_importances = shap_values[sample_idx]
lime_importances = [next((w for k, w in lime_dict.items() if f'Feature_{i}' in k), 0)
                    for i in range(len(shap_importances))]
correlation = np.corrcoef(shap_importances, lime_importances)[0, 1]
print(f"SHAP-LIME correlation: {correlation:.4f}")
if correlation > 0.7:
print("High correlation: Consistent explanations from both methods → High reliability")
else:
print("Low correlation: Explanation discrepancy → Careful interpretation needed")
Problem 2 (Difficulty: Medium)
Using SHAP Dependence Plot, visualize interactions between two features. Analyze whether non-linear relationships or interactions are observed.
Solution Example
# Requirements:
# - Python 3.9+
# - matplotlib>=3.7.0
# - shap>=0.42.0
"""
Example: Using SHAP Dependence Plot, visualize interactions between two features
Purpose: Demonstrate data visualization techniques
Target: Beginner to Intermediate
Execution time: 2-5 seconds
Dependencies: None
"""
import shap
import matplotlib.pyplot as plt
# SHAP calculation
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X_test)
# Dependence Plot (feature 0 and feature 1 interaction)
fig, axes = plt.subplots(1, 2, figsize=(16, 6))
shap.dependence_plot(0, shap_values, X_test, interaction_index=1,
ax=axes[0], show=False)
axes[0].set_title('Feature 0 (interaction with Feature 1)')
shap.dependence_plot(1, shap_values, X_test, interaction_index=0,
ax=axes[1], show=False)
axes[1].set_title('Feature 1 (interaction with Feature 0)')
plt.tight_layout()
plt.show()
print("Analysis Points:")
print("- Color variation: Strength of interaction")
print("- Non-linear patterns: Complex relationships")
print("- Trend: Direction of influence (positive/negative)")
Problem 3 (Difficulty: Hard)
Mimic the Toyota battery degradation prediction case and perform SHAP analysis on three factors (temperature, voltage, cycle count), quantitatively evaluating which factor contributes most to degradation. Also discuss physical validity.
Solution Example
# Requirements:
# - Python 3.9+
# - numpy>=1.24.0, <2.0.0
# - pandas>=2.0.0, <2.2.0
# - shap>=0.42.0
"""
Example: Mimic the Toyota battery degradation prediction case and perform SHAP analysis
Purpose: Demonstrate data visualization techniques
Target: Advanced
Execution time: 30-60 seconds
Dependencies: None
"""
import pandas as pd
import numpy as np
from sklearn.ensemble import GradientBoostingRegressor
import shap
# Data generation
battery_data = pd.DataFrame({
'Temperature': np.random.uniform(20, 60, 300),
'Voltage': np.random.uniform(3.0, 4.5, 300),
'Cycle Count': np.random.uniform(0, 1000, 300)
})
# Degradation rate (physically reasonable model)
# Degradation increases with high temperature, high voltage, and many cycles
degradation = (
0.8 * (battery_data['Temperature'] - 20) + # High temp accelerates degradation
2.0 * (battery_data['Voltage'] - 3.0)**2 + # High voltage accelerates degradation
0.05 * battery_data['Cycle Count'] + # Cycle degradation
0.01 * battery_data['Temperature'] * battery_data['Cycle Count'] / 100 + # Interaction
np.random.normal(0, 3, 300)
)
# Model training
model = GradientBoostingRegressor(n_estimators=100, random_state=42)
model.fit(battery_data, degradation)
# SHAP analysis
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(battery_data)
# Importance aggregation
mean_abs_shap = np.abs(shap_values).mean(axis=0)
feature_names = battery_data.columns
print("Degradation factor importance (SHAP):")
for name, importance in zip(feature_names, mean_abs_shap):
print(f" {name}: {importance:.2f}")
# Summary Plot
shap.summary_plot(shap_values, battery_data, show=False)
plt.title('SHAP Analysis of Battery Degradation Factors', fontsize=13, fontweight='bold')
plt.tight_layout()
plt.show()
print("\nPhysical Validity:")
print("- Temperature: Arrhenius equation predicts increased reaction rate at high temp → Valid")
print("- Voltage: High voltage promotes side reactions → Valid")
print("- Cycle Count: Repeated charge-discharge causes degradation → Valid")
4.6 XAI Environment and Practical Pitfalls
SHAP Library Version Management
# Requirements:
# - Python 3.9+
# - lime>=0.2.0
# - numpy>=1.24.0, <2.0.0
# - pandas>=2.0.0, <2.2.0
# - scikit-learn>=1.3.0, <1.5.0
# - shap>=0.42.0
"""
Example: SHAP Library Version Management
Purpose: Demonstrate data visualization techniques
Target: Intermediate
Execution time: 2-5 seconds
Dependencies: None
"""
# Libraries needed for XAI
import sys
import shap
import lime
import sklearn
import pandas as pd
import numpy as np
xai_env_info = {
'Python': sys.version,
'NumPy': np.__version__,
'Pandas': pd.__version__,
'scikit-learn': sklearn.__version__,
'SHAP': shap.__version__,
'LIME': lime.__version__,
'Date': '2025-10-19'
}
print("=== XAI Environment ===")
for key, value in xai_env_info.items():
print(f"{key}: {value}")
# Recommended versions
print("\n【Recommended Environment】")
recommended_xai = """
numpy==1.24.3
pandas==2.0.3
scikit-learn==1.3.0
shap==0.43.0
lime==0.2.0.1
matplotlib==3.7.2
"""
print(recommended_xai)
print("\n【Installation Command】")
print("```bash")
print("pip install shap==0.43.0 lime==0.2.0.1")
print("```")
print("\n【Important Notes】")
print("⚠️ SHAP frequently has API changes → Version pinning recommended")
print("⚠️ TreeExplainer verified to work with scikit-learn 1.3 or later")
print("⚠️ Kernel SHAP difficult to compute with large datasets (>10000 samples)")
Practical Pitfalls in XAI
print("=== Practical Pitfalls in XAI\n")
print("【Pitfall 1: Misinterpretation of SHAP Values】")
print("❌ Misconception: Large SHAP value = Large feature value")
print("→ SHAP value represents 'contribution to prediction', not feature magnitude")
print("\n✅ Correct Understanding:")
print("```python")
print("# SHAP value = How much this feature contributed to prediction")
print("# Even small feature values can have large SHAP values")
print("feature_value = 0.1 # Small value")
print("shap_value = 2.5 # Large influence")
print("# → This feature is important despite small magnitude")
print("```")
print("\n【Pitfall 2: Ignoring LIME's Locality】")
print("⚠️ Generalizing LIME explanation from one sample to all")
print("→ LIME's local linear approximation varies across samples")
print("\n✅ Solution: Check consistency across multiple samples")
print("```python")
print("# Run LIME on 10 samples, check if important features agree")
print("for i in range(10):")
print(" explanation = lime_explainer.explain_instance(X[i], model.predict)")
print(" # Check if top 3 features match")
print("```")
print("\n【Pitfall 3: Confusing Correlation and Causation】")
print("⚠️ \"High SHAP value → Changing this feature will change prediction\"")
print("→ This is correlation, not causation")
print("\n✅ Causation requires separate methods")
print("```python")
print("# XAI is correlation analysis")
print("# For causation, use:")
print("# - A/B testing")
print("# - Causal graphs (DAG)")
print("# - Propensity score matching")
print("```")
print("\n【Pitfall 4: Overconfidence in Attention Visualization】")
print("⚠️ High Attention = Model correctly focuses on this part")
print("→ Not necessarily the correct reason")
print("\n✅ Cross-validate with multiple methods")
print("```python")
print("# Check agreement between SHAP + LIME + Attention")
print("# Verify physical plausibility with domain experts")
print("```")
print("\n【Pitfall 5: Ignoring Computational Cost at Scale】")
print("⚠️ Running Kernel SHAP on 10000 samples")
print("→ Computation time: Hours to days")
print("\n✅ Choose method and sample size appropriately")
print("```python")
print("if len(X) < 1000:")
print(" explainer = shap.KernelExplainer() # Any model")
print("else:")
print(" # Use subsampling or TreeExplainer")
print(" X_sample = shap.sample(X, 1000)")
print(" explainer = shap.TreeExplainer() # Fast")
print("```")
Summary
In this chapter, we learned the theory and practice of Explainable AI (XAI).
Key Points:
- Black-Box Problem: High-accuracy models are hard to interpret → XAI provides solutions
- SHAP: Fair feature contribution evaluation using Shapley values
- LIME: Generate explanations through local linear approximation
- Attention Visualization: Understand internal behavior of neural networks
- Real-World Applications: Success cases from Toyota, IBM, Citrine
- Career Path: Growing demand for materials data scientists, salaries roughly ¥5-25 million (Japan)
- Environment Management: Version pinning for SHAP, LIME and computational cost management
- Practical Pitfalls: SHAP value misinterpretation, LIME locality, correlation vs causation, Attention overconfidence, computational cost
Series Summary:
- Chapter 1: Data collection strategy and cleaning → Prepare high-quality data
- Chapter 2: Feature engineering → Reduce 200 dimensions to 20
- Chapter 3: Model selection and optimization → Automatic optimization with Optuna
- Chapter 4: Explainable AI → Physical interpretation of predictions
Next Steps:
- Apply full workflow to real datasets
- Submit papers and contribute to OSS
- Attend conferences and build networks
- Develop career as materials data scientist
Chapter 4 Checklist
SHAP (SHapley Additive exPlanations)
- [ ] Understanding SHAP Values
- [ ] Understand theoretical background of Shapley values (cooperative game theory)
- [ ] Verify base value + sum of SHAP values = prediction
- [ ] SHAP values are independent of feature magnitude (represent contribution)
- [ ] Choosing SHAP Computation Methods
- [ ] Tree-based models → TreeExplainer (fast, exact solution)
- [ ] Any model → KernelExplainer (slow, approximate solution)
- [ ] Deep learning → DeepExplainer or GradientExplainer
- [ ] Global Interpretation
- [ ] Evaluate overall feature importance using mean(|SHAP values|)
- [ ] Visualize distribution and influence direction with Summary Plot
- [ ] Rank top important features with Bar Plot
- [ ] Local Interpretation
- [ ] Explain prediction basis using SHAP values for individual samples
- [ ] Visualize contribution from base value using Force Plot
- [ ] Display cumulative contribution using Waterfall Plot
- [ ] Dependence Plot
- [ ] Visualize relationship between feature value and SHAP value
- [ ] Discover non-linear relationships
- [ ] Identify interaction terms (interaction_index)
LIME (Local Interpretable Model-agnostic Explanations)
- [ ] LIME Basic Understanding
- [ ] Understand how local linear approximation works
- [ ] Random sampling around sample
- [ ] Distance-based weighting
- [ ] Tabular LIME
- [ ] Explain tabular data using LimeTabularExplainer
- [ ] Specify number of important features with num_features argument
- [ ] Explain individual predictions with explain_instance()
- [ ] Recognize LIME Limitations
- [ ] Explanations are local, not global
- [ ] Different explanations for different samples (no consistency)
- [ ] Computational cost lower than SHAP
- [ ] SHAP vs LIME Comparison
- [ ] Check agreement of important features between both methods
- [ ] Correlation > 0.7 indicates high reliability
- [ ] Interpret carefully when results disagree
Attention Visualization (for NN/GNN)
- [ ] Get Attention Weights
- [ ] Extract intermediate layer activations from neural networks
- [ ] Visualize Attention mechanism weights
- [ ] Analyze which inputs the model attends to
- [ ] Grad-CAM-style Methods
- [ ] Compute gradient-based importance
- [ ] Approximate feature importance using numerical differentiation
- [ ] Specialized for neural networks
- [ ] Application to Molecular Graphs
- [ ] Identify important atoms/bonds in GNN
- [ ] Estimate reaction mechanisms with Attention
- [ ] Verify chemical plausibility with experts
Learning Real-World Application Cases
- [ ] Toyota: Materials Development
- [ ] Identify degradation factors with SHAP analysis
- [ ] Visualize interactions between temperature, voltage, cycle count
- [ ] 40% reduction in development time, 20% improvement in battery life
- [ ] IBM Research: Automated Chemistry Experiments
- [ ] Predict reaction mechanisms with GNN + Attention
- [ ] 95% reaction yield prediction accuracy
- [ ] Discover novel reaction pathways
- [ ] Citrine Informatics: SaaS Business
- [ ] Core technology: Explainable AI
- [ ] Uncertainty quantification + SHAP
- [ ] Deployed at Panasonic, 3M, Michelin
Building Career Path
- [ ] Required Skillset
- [ ] Materials science expertise (degree recommended)
- [ ] Machine learning and deep learning implementation skills
- [ ] Master XAI methods (SHAP, LIME)
- [ ] Python, scikit-learn, PyTorch/TensorFlow
- [ ] Career Options
- [ ] Materials Data Scientist (Manufacturing R&D)
- [ ] XAI Researcher (Academia)
- [ ] ML Engineer (Materials-specialized startup)
- [ ] R&D Manager (AI adoption)
- [ ] Technical Consultant
- [ ] Salary Goals
- [ ] Japan: Junior ¥5-7M, Mid ¥10-15M, Senior ¥15-25M
- [ ] US: $70-90K (Junior), $130-180K (Mid), $180-300K (Senior)
- [ ] Improve salary through: Degree, papers, project success
- [ ] Skill Development Strategy
- [ ] Strengthen materials science fundamentals (degree or self-study)
- [ ] Practical ML/DL (Kaggle, GitHub)
- [ ] Master XAI methods (this series)
- [ ] Publish papers and contribute to OSS
- [ ] Network at conferences and meetups
Avoiding Practical Pitfalls (XAI)
- [ ] Correct SHAP Value Interpretation
- [ ] SHAP value ≠ Feature magnitude
- [ ] SHAP value = Contribution to prediction
- [ ] Interpret as deviation from base value
- [ ] Recognize LIME Locality
- [ ] Don't generalize explanation from one sample to all
- [ ] Check consistency across multiple samples
- [ ] Cross-validate with SHAP
- [ ] Distinguish Correlation and Causation
- [ ] XAI is correlation analysis (not causal inference)
- [ ] "High SHAP value → Changing feature changes prediction" is a misconception
- [ ] Use A/B tests and causal graphs for causation
- [ ] Limitations of Attention Visualization
- [ ] High Attention ≠ Necessarily correct reasoning
- [ ] Cross-validate with multiple methods (SHAP + LIME + Attention)
- [ ] Verify physical plausibility with experts
- [ ] Manage Computational Cost
- [ ] Kernel SHAP: Recommended for < 1000 samples
- [ ] Tree SHAP: Fast even with tens of thousands of samples
- [ ] For large data: Use subsampling or tree-based methods
XAI Quality Evaluation
- [ ] Explanation Consistency
- [ ] SHAP vs LIME correlation > 0.7
- [ ] Important features agree across multiple samples
- [ ] Physical interpretation matches ML interpretation
- [ ] Physical Validity
- [ ] Verified by domain experts
- [ ] Consistent with known physical laws
- [ ] Aligned with experimental results
- [ ] Practical Utility
- [ ] Can extract material design guidelines
- [ ] Can inform experimental planning
- [ ] Meets accountability requirements for papers and patents
Ensuring Reproducibility
- [ ] Version Management
- [ ] Pin versions of SHAP and LIME
- [ ] Watch for API changes (especially SHAP)
- [ ] Document in requirements.txt
- [ ] Unified Computational Environment
- [ ] Set random seeds (SHAP, LIME)
- [ ] Fix parallel computation settings (n_jobs)
- [ ] Docker environment recommended
- [ ] Save Explanations
- [ ] Save SHAP values as NumPy arrays
- [ ] Save visualizations as PNG/PDF
- [ ] Convert explanation text to Markdown/LaTeX
References
- Lundberg, S. M., & Lee, S. I. (2017). A unified approach to interpreting model predictions. Advances in Neural Information Processing Systems, 30, 4765-4774.
- Ribeiro, M. T., Singh, S., & Guestrin, C. (2016). "Why should I trust you?": Explaining the predictions of any classifier. Proceedings of the 22nd ACM SIGKDD, 1135-1144. DOI: 10.1145/2939672.2939778
- Molnar, C. (2022). Interpretable Machine Learning: A Guide for Making Black Box Models Explainable (2nd ed.). https://christophm.github.io/interpretable-ml-book/
- Vaswani, A., Shazeer, N., Parmar, N., et al. (2017). Attention is all you need. Advances in Neural Information Processing Systems, 30, 5998-6008.
- Citrine Informatics. (2023). Materials Informatics Platform. https://citrine.io/
Congratulations on Completing the Series!
You have now acquired practical skills in data-driven materials science. We wish you continued success.
Feedback and Questions: - Email: yusuke.hashimoto.b8@tohoku.ac.jp - GitHub: AI_Homepage Repository
Related Series: - Introduction to Bayesian Optimization - Introduction to Active Learning - Introduction to Graph Neural Networks