
📖 Reading time: 20-25 minutes 📊 Difficulty: Beginner 💻 Code examples: 20+ 📝 Exercises: 3

Chapter 4: Explainable AI (XAI)

This chapter covers Explainable AI (XAI). You will learn why interpretability matters in materials science, how SHAP and LIME explain predictions through Shapley values and local linear approximation, and how XAI is applied in real-world materials development (Toyota, IBM Research, Citrine Informatics).


Learning Objectives

By reading this chapter, you will master:

  1. Understanding why interpretability matters in materials science
  2. Explaining model predictions with SHAP (Shapley values)
  3. Generating local explanations with LIME (local linear approximation)
  4. Visualizing Attention weights in neural networks and GNNs
  5. Learning from real-world XAI applications (Toyota, IBM Research, Citrine Informatics) and related career paths

4.1 The Importance of Interpretability

Understanding machine learning model predictions and extracting physical meaning is essential in materials science.

The Black-Box Problem

# Requirements:
# - Python 3.9+
# - matplotlib>=3.7.0
# - numpy>=1.24.0, <2.0.0
# - pandas>=2.0.0, <2.2.0
# - scikit-learn>=1.3.0, <1.5.0
# - seaborn>=0.12.0

"""
Example: The Black-Box Problem

Purpose: Contrast an interpretable linear model with a black-box ensemble
Target: Advanced
Execution time: 2-5 seconds
Dependencies: None
"""

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.ensemble import RandomForestRegressor
from sklearn.linear_model import Ridge

# Sample data
np.random.seed(42)
X = np.random.randn(200, 10)
y = 2*X[:, 0] + 3*X[:, 1] - 1.5*X[:, 2] + np.random.normal(0, 0.5, 200)

# Interpretable model vs black-box model
ridge = Ridge(alpha=1.0)
rf = RandomForestRegressor(n_estimators=100, random_state=42)

ridge.fit(X, y)
rf.fit(X, y)

# Ridge coefficients (interpretable)
ridge_coefs = ridge.coef_

# Visualization: Model interpretability differences
fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Ridge: Clear from linear coefficients
axes[0].bar(range(len(ridge_coefs)), ridge_coefs,
            color='steelblue', alpha=0.7)
axes[0].set_xlabel('Feature Index', fontsize=11)
axes[0].set_ylabel('Coefficient', fontsize=11)
axes[0].set_title('Ridge Regression (Interpretable)', fontsize=12, fontweight='bold')
axes[0].axhline(y=0, color='red', linestyle='--', linewidth=1)
axes[0].grid(alpha=0.3)

# Random Forest: Complex non-linear relationships (black-box)
axes[1].text(0.5, 0.5, '❓\nBlack Box\n\n100 decision trees\nComplex non-linear relationships\nDifficult to interpret',
             ha='center', va='center', fontsize=16,
             bbox=dict(boxstyle='round', facecolor='gray', alpha=0.3),
             transform=axes[1].transAxes)
axes[1].set_title('Random Forest (Black-Box)',
                  fontsize=12, fontweight='bold')
axes[1].axis('off')

plt.tight_layout()
plt.show()

print("Interpretability challenges:")
print("- Linear models: Influence is clear from coefficients, but accuracy is low")
print("- Non-linear models: High accuracy, but why predictions are made is unclear")
print("→ XAI (Explainable AI) aims for both")

The Need for Physical Interpretation in Materials Science

# Use cases for interpretability in materials science
use_cases = pd.DataFrame({
    'Use Case': [
        'New material discovery',
        'Synthesis condition optimization',
        'Process anomaly detection',
        'Property prediction',
        'Material design guidelines'
    ],
    'Importance of Interpretability': [10, 9, 8, 7, 10],
    'Reason': [
        'Understanding physical mechanisms leads to new discoveries',
        'Identify which parameters are critical',
        'Root cause of anomalies must be identified',
        'Verify prediction basis',
        'Extract design principles'
    ]
})

# Visualization
fig, ax = plt.subplots(figsize=(12, 6))

colors = plt.cm.YlOrRd(np.linspace(0.3, 0.9, len(use_cases)))

bars = ax.barh(use_cases['Use Case'],
               use_cases['Importance of Interpretability'],
               color=colors, alpha=0.7)

ax.set_xlabel('Importance of Interpretability (1-10)', fontsize=12)
ax.set_xlim(0, 10)
ax.set_title('Importance of Interpretability in Materials Science',
             fontsize=13, fontweight='bold')
ax.grid(axis='x', alpha=0.3)

# Add reason annotations
for idx, row in use_cases.iterrows():
    ax.text(row['Importance of Interpretability'] + 0.3, idx,
            row['Reason'], va='center', fontsize=9, style='italic')

plt.tight_layout()
plt.show()

print("Why XAI is needed in materials science:")
print("1. Verification of physical law consistency")
print("2. Reflection in experimental planning")
print("3. Integration of expert knowledge")
print("4. Accountability in papers and patents")

Trustworthiness and Debugging

# Example of discovering model prediction errors through interpretation
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_absolute_error

# Data generation (intentionally including noise)
X_data = np.random.randn(300, 5)
# Correct relationship: y = 2*X0 + 3*X1
y_true = 2*X_data[:, 0] + 3*X_data[:, 1] + np.random.normal(0, 0.3, 300)

# Add noise to some samples (simulate measurement error)
noise_idx = np.random.choice(300, 30, replace=False)
y_data = y_true.copy()
y_data[noise_idx] += np.random.normal(0, 5, 30)

# Training
X_train, X_test, y_train, y_test = train_test_split(
    X_data, y_data, test_size=0.2, random_state=42
)

model = RandomForestRegressor(n_estimators=100, random_state=42)
model.fit(X_train, y_train)

# Prediction
y_pred = model.predict(X_test)
mae = mean_absolute_error(y_test, y_pred)

# Identify samples with large errors
errors = np.abs(y_test - y_pred)
high_error_idx = np.where(errors > np.percentile(errors, 90))[0]

print(f"Model MAE: {mae:.4f}")
print(f"Number of high-error samples: {len(high_error_idx)}")
print("\n→ Use XAI to analyze causes of high-error samples")
print("  - Discover data quality issues")
print("  - Identify model weaknesses")
print("  - Verify physical validity")

4.2 SHAP (SHapley Additive exPlanations)

An interpretation method based on Shapley values from cooperative game theory.
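
For reference, the Shapley value of feature $i$ is its contribution to the prediction, averaged over all possible coalitions $S$ of the remaining features ($N$ is the full feature set and $v(S)$ is the model's prediction when only the features in $S$ are known):

$$
\phi_i = \sum_{S \subseteq N \setminus \{i\}} \frac{|S|!\,(|N|-|S|-1)!}{|N|!}\,\bigl[v(S \cup \{i\}) - v(S)\bigr]
$$

SHAP explanations are additive: the prediction decomposes as $f(x) = \phi_0 + \sum_i \phi_i$, where $\phi_0$ is the base value (the expected prediction). The code below verifies this additivity numerically.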

Shapley Values Theory

# Requirements:
# - Python 3.9+
# - shap>=0.42.0

"""
Example: Shapley Values Theory

Purpose: Compute SHAP values with TreeExplainer and verify their additivity
Target: Beginner
Execution time: 1-5 minutes
Dependencies: None
"""

import shap

# Visualize SHAP basic concept
shap.initjs()

# Model training
model_shap = RandomForestRegressor(n_estimators=100, random_state=42)
model_shap.fit(X_train, y_train)

# SHAP Explainer
explainer = shap.TreeExplainer(model_shap)
shap_values = explainer.shap_values(X_test)

print("Meaning of SHAP values:")
print("- How much each feature contributed to the prediction")
print("- Shapley value: Fair distribution from cooperative game theory")
print("- Expressed as deviation from base value")
print(f"\nSHAP values shape: {shap_values.shape}")
print(f"  Number of samples: {shap_values.shape[0]}")
print(f"  Number of features: {shap_values.shape[1]}")

# Explain single sample
sample_idx = 0
base_value = explainer.expected_value
prediction = model_shap.predict(X_test[sample_idx:sample_idx+1])[0]

print(f"\nPrediction for sample {sample_idx}:")
print(f"Base value: {base_value:.4f}")
print(f"Sum of SHAP values: {shap_values[sample_idx].sum():.4f}")
print(f"Prediction: {prediction:.4f}")
print(f"Verification: {base_value + shap_values[sample_idx].sum():.4f} ≈ {prediction:.4f}")

SHAP Value Computation (Tree SHAP, Kernel SHAP)

# Tree SHAP (fast, tree-based models only)
explainer_tree = shap.TreeExplainer(model_shap)
shap_values_tree = explainer_tree.shap_values(X_test)

# Kernel SHAP (model-agnostic, slow)
# Demo with small sample
X_test_small = X_test[:10]
explainer_kernel = shap.KernelExplainer(
    model_shap.predict,
    shap.sample(X_train, 50)
)
shap_values_kernel = explainer_kernel.shap_values(X_test_small)

print("Comparison of SHAP computation methods:")
print("\nTree SHAP:")
print(f"  Target models: Tree-based (RF, XGBoost, LightGBM)")
print(f"  Computation speed: Fast")
print(f"  Accuracy: Exact solution")

print("\nKernel SHAP:")
print(f"  Target models: Any (including neural networks)")
print(f"  Computation speed: Slow")
print(f"  Accuracy: Approximate solution (sampling-based)")

# Simple computation time comparison
import time

start = time.time()
_ = explainer_tree.shap_values(X_test)
tree_time = time.time() - start

print(f"\nTree SHAP computation time: {tree_time:.3f} seconds ({len(X_test)} samples)")

Global vs Local Interpretation
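
The global importance computed below is the mean absolute SHAP value of each feature over the test samples,

$$
I_j = \frac{1}{n}\sum_{i=1}^{n} \bigl|\phi_{ij}\bigr|
$$

where $\phi_{ij}$ is the SHAP value of feature $j$ for sample $i$. The local view instead plots the raw $\phi_{ij}$ for a single sample.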

# Global interpretation: Average importance across all samples
mean_abs_shap = np.abs(shap_values).mean(axis=0)

fig, axes = plt.subplots(1, 2, figsize=(14, 6))

# Global interpretation
axes[0].bar(range(len(mean_abs_shap)), mean_abs_shap,
            color='steelblue', alpha=0.7)
axes[0].set_xlabel('Feature Index', fontsize=11)
axes[0].set_ylabel('Mean |SHAP value|', fontsize=11)
axes[0].set_title('Global Interpretation (Overall Importance)',
                  fontsize=12, fontweight='bold')
axes[0].grid(alpha=0.3)

# Local interpretation: Specific sample
sample_idx = 0
axes[1].bar(range(len(shap_values[sample_idx])),
            shap_values[sample_idx],
            color='coral', alpha=0.7)
axes[1].axhline(y=0, color='black', linestyle='-', linewidth=1)
axes[1].set_xlabel('Feature Index', fontsize=11)
axes[1].set_ylabel('SHAP value', fontsize=11)
axes[1].set_title(f'Local Interpretation (Sample {sample_idx} explanation)',
                  fontsize=12, fontweight='bold')
axes[1].grid(alpha=0.3)

plt.tight_layout()
plt.show()

print("Global vs Local interpretation:")
print("\nGlobal:")
print("  - Average feature importance across all samples")
print("  - Understand overall model behavior")
print("  - General guidelines for new material design")

print("\nLocal:")
print("  - Basis for individual predictions")
print("  - Identify causes of anomalous samples")
print("  - Optimization direction for specific materials")

Summary Plot, Dependence Plot

# Summary plot (overview)
plt.figure(figsize=(10, 8))
shap.summary_plot(shap_values, X_test, plot_type="dot", show=False)
plt.title('SHAP Summary Plot', fontsize=13, fontweight='bold', pad=20)
plt.tight_layout()
plt.show()

print("How to read Summary Plot:")
print("- Vertical axis: Features (ordered by importance)")
print("- Horizontal axis: SHAP value (influence on prediction)")
print("- Color: Feature value (red=high, blue=low)")
print("- Distribution: Diversity of influence for each feature")

# Dependence plot (detailed for individual features)
feature_idx = 0

plt.figure(figsize=(10, 6))
shap.dependence_plot(
    feature_idx,
    shap_values,
    X_test,
    show=False
)
plt.title(f'SHAP Dependence Plot (Feature {feature_idx})',
          fontsize=13, fontweight='bold')
plt.tight_layout()
plt.show()

print("\nHow to read Dependence Plot:")
print("- Horizontal axis: Feature value")
print("- Vertical axis: SHAP value (influence on prediction)")
print("- Color: Another feature that interacts")
print("- Trend: Visualization of non-linear relationships")

4.3 LIME (Local Interpretable Model-agnostic Explanations)

An explanation generation method based on local linear approximation.
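
Formally, LIME chooses the interpretable surrogate $g$ (typically a sparse linear model) that best mimics the black-box model $f$ in the neighborhood of the instance $x$, weighted by a proximity kernel $\pi_x$, while penalizing the surrogate's complexity $\Omega(g)$:

$$
\xi(x) = \operatorname*{arg\,min}_{g \in G} \; \mathcal{L}(f, g, \pi_x) + \Omega(g)
$$

The coefficients of the fitted surrogate are then reported as the explanation.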

Local Linear Approximation

# Requirements:
# - Python 3.9+
# - lime>=0.2.0

"""
Example: Local Linear Approximation

Purpose: Generate and visualize a local explanation with LIME
Target: Beginner
Execution time: 1-5 minutes
Dependencies: None
"""

from lime import lime_tabular

# LIME Explainer
lime_explainer = lime_tabular.LimeTabularExplainer(
    X_train,
    mode='regression',
    feature_names=[f'Feature_{i}' for i in range(X_train.shape[1])],
    verbose=False
)

# Explain single sample
sample_idx = 0
explanation = lime_explainer.explain_instance(
    X_test[sample_idx],
    model_shap.predict,
    num_features=5
)

# Visualization
fig = explanation.as_pyplot_figure()
plt.title(f'LIME Explanation (Sample {sample_idx})',
          fontsize=13, fontweight='bold')
plt.tight_layout()
plt.show()

print("How LIME works:")
print("1. Random sampling around target sample")
print("2. Predictions by black-box model")
print("3. Distance-based weighting")
print("4. Train local linear model")
print("5. Explain using linear coefficients")

# Display explanation numerically
print("\nExplanations (by importance):")
for feature, weight in explanation.as_list():
    print(f"  {feature}: {weight:.4f}")

Tabular LIME

# Run LIME on multiple samples
n_samples_lime = 5
lime_results = []

for i in range(n_samples_lime):
    exp = lime_explainer.explain_instance(
        X_test[i],
        model_shap.predict,
        num_features=X_train.shape[1]
    )

    # Convert explanation to dictionary
    exp_dict = dict(exp.as_list())
    lime_results.append(exp_dict)

# Convert to DataFrame
lime_df = pd.DataFrame(lime_results)

print(f"\nLIME explanations for {n_samples_lime} samples:")
print(lime_df.head())

# Evaluate consistency (are same features always important?)
# Note: as_list() keys include binning conditions (e.g. 'Feature_0 <= 0.52'),
# so columns differ across samples and NaN entries are skipped in the mean
feature_importance_consistency = lime_df.abs().mean()
print("\nAverage feature importance (LIME):")
print(feature_importance_consistency.sort_values(ascending=False))

Generating Prediction Explanations

# SHAP vs LIME comparison
def compare_shap_lime(sample_idx):
    """
    Compare SHAP vs LIME explanations for the same sample
    """
    # SHAP
    shap_exp = shap_values[sample_idx]

    # LIME
    lime_exp = lime_explainer.explain_instance(
        X_test[sample_idx],
        model_shap.predict,
        num_features=X_train.shape[1]
    )
    lime_dict = dict(lime_exp.as_list())

    # Align LIME explanations to same order as SHAP
    lime_exp_ordered = []
    for i in range(len(shap_exp)):
        feature_name = f'Feature_{i}'
        # Search for corresponding feature in LIME explanation
        for key, value in lime_dict.items():
            if feature_name in key:
                lime_exp_ordered.append(value)
                break
        else:
            lime_exp_ordered.append(0)

    return shap_exp, np.array(lime_exp_ordered)

# Compare
sample_idx = 0
shap_exp, lime_exp = compare_shap_lime(sample_idx)

# Visualization
fig, ax = plt.subplots(figsize=(12, 6))

x_pos = np.arange(len(shap_exp))
width = 0.35

ax.bar(x_pos - width/2, shap_exp, width,
       label='SHAP', color='steelblue', alpha=0.7)
ax.bar(x_pos + width/2, lime_exp, width,
       label='LIME', color='coral', alpha=0.7)

ax.set_xlabel('Feature Index', fontsize=12)
ax.set_ylabel('Importance', fontsize=12)
ax.set_title(f'SHAP vs LIME (Sample {sample_idx})',
             fontsize=13, fontweight='bold')
ax.set_xticks(x_pos)
ax.legend()
ax.grid(alpha=0.3)
ax.axhline(y=0, color='black', linestyle='-', linewidth=1)

plt.tight_layout()
plt.show()

# Correlation analysis
correlation = np.corrcoef(shap_exp, lime_exp)[0, 1]
print(f"\nSHAP-LIME correlation: {correlation:.4f}")
print("High correlation → Consistent explanations from both methods")

4.4 Attention Visualization (for NN/GNN)

Visualizing the Attention mechanism in neural networks.
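
In Transformer-style models, attention weights come from scaled dot-product attention (Vaswani et al., 2017):

$$
\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}}\right)V
$$

The softmax produces non-negative weights that sum to one, so they can be read as "how much each input element contributes", which is exactly what attention visualization plots. The scikit-learn MLP used below has no true attention layer; the example treats the magnitude of the first hidden layer's activations as an attention-like proxy.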

Visualizing Attention Weights

# Simple demo of Attention mechanism
from sklearn.neural_network import MLPRegressor

# Train neural network
nn_model = MLPRegressor(
    hidden_layer_sizes=(50, 50),
    max_iter=1000,
    random_state=42
)
nn_model.fit(X_train, y_train)

# Get activation of intermediate layer (simple version)
def get_activation(model, X, layer_idx=0):
    """
    Get activation of specified layer
    """
    # Weights and biases
    W = model.coefs_[layer_idx]
    b = model.intercepts_[layer_idx]

    # Activation (ReLU)
    activation = np.maximum(0, X @ W + b)

    return activation

# Activation of first layer
activation_layer1 = get_activation(nn_model, X_test, layer_idx=0)

# Attention-like weights (treat activation magnitude as weights)
attention_weights = np.abs(activation_layer1).mean(axis=1)

# Visualization
plt.figure(figsize=(12, 6))
plt.scatter(range(len(attention_weights)), attention_weights,
            c=y_test, cmap='viridis', s=100, alpha=0.6)
plt.colorbar(label='Target Value')
plt.xlabel('Sample Index', fontsize=12)
plt.ylabel('Attention Weight (Activation Strength)', fontsize=12)
plt.title('Attention-like Weights (First Layer Activation)',
          fontsize=13, fontweight='bold')
plt.grid(alpha=0.3)
plt.tight_layout()
plt.show()

print("Significance of Attention visualization:")
print("- Which inputs is the model attending to?")
print("- Identification of important samples and features")
print("- Understanding internal behavior of neural networks")

Grad-CAM for Materials
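
For reference, Grad-CAM (originally defined for convolutional networks) weights each feature map $A^k$ by the spatially averaged gradient of the output $y$ with respect to that map and keeps only the positive contributions:

$$
\alpha_k = \frac{1}{Z}\sum_{i}\sum_{j}\frac{\partial y}{\partial A^{k}_{ij}}, \qquad L_{\text{Grad-CAM}} = \mathrm{ReLU}\!\left(\sum_k \alpha_k A^k\right)
$$

The tabular example below applies the same idea in a simplified form: instead of feature-map gradients, it approximates the gradient of the prediction with respect to each input feature by finite differences.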

# Gradient-based importance (simplified version)
def gradient_based_importance(model, X_sample):
    """
    Gradient-based feature importance
    """
    # Approximate with numerical differentiation
    epsilon = 1e-5
    base_pred = model.predict(X_sample.reshape(1, -1))[0]

    importances = []
    for i in range(len(X_sample)):
        X_perturbed = X_sample.copy()
        X_perturbed[i] += epsilon

        perturbed_pred = model.predict(X_perturbed.reshape(1, -1))[0]

        # Gradient approximation
        gradient = (perturbed_pred - base_pred) / epsilon
        importances.append(gradient)

    return np.array(importances)

# Execute on sample
sample_idx = 0
grad_importances = gradient_based_importance(nn_model, X_test[sample_idx])

# Compare SHAP and LIME (computed above for the random forest) with the NN gradient
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# SHAP
axes[0].bar(range(len(shap_exp)), shap_exp,
            color='steelblue', alpha=0.7)
axes[0].axhline(y=0, color='black', linestyle='-', linewidth=1)
axes[0].set_xlabel('Feature', fontsize=11)
axes[0].set_ylabel('Importance', fontsize=11)
axes[0].set_title('SHAP', fontsize=12, fontweight='bold')
axes[0].grid(alpha=0.3)

# LIME
axes[1].bar(range(len(lime_exp)), lime_exp,
            color='coral', alpha=0.7)
axes[1].axhline(y=0, color='black', linestyle='-', linewidth=1)
axes[1].set_xlabel('Feature', fontsize=11)
axes[1].set_ylabel('Importance', fontsize=11)
axes[1].set_title('LIME', fontsize=12, fontweight='bold')
axes[1].grid(alpha=0.3)

# Gradient
axes[2].bar(range(len(grad_importances)), grad_importances,
            color='green', alpha=0.7)
axes[2].axhline(y=0, color='black', linestyle='-', linewidth=1)
axes[2].set_xlabel('Feature', fontsize=11)
axes[2].set_ylabel('Gradient', fontsize=11)
axes[2].set_title('Gradient-based', fontsize=12, fontweight='bold')
axes[2].grid(alpha=0.3)

plt.tight_layout()
plt.show()

print("Characteristics of three methods:")
print("SHAP: Game-theoretic fairness, works with all models")
print("LIME: Local linear approximation, intuitive")
print("Gradient: Gradient information, specialized for neural networks")

Which Atoms/Bonds are Important?

# Application example in materials science: Element importance
composition_features = ['Li', 'Co', 'Ni', 'Mn', 'O']

# Simulation data
X_composition = pd.DataFrame({
    'Li': np.random.uniform(0.9, 1.1, 100),
    'Co': np.random.uniform(0, 0.6, 100),
    'Ni': np.random.uniform(0, 0.8, 100),
    'Mn': np.random.uniform(0, 0.4, 100),
    'O': np.random.uniform(1.9, 2.1, 100)
})

# Capacity (Ni is important)
y_capacity = (
    150 * X_composition['Ni'] +
    120 * X_composition['Co'] +
    80 * X_composition['Mn'] +
    np.random.normal(0, 5, 100)
)

# Model training
model_comp = RandomForestRegressor(n_estimators=100, random_state=42)
model_comp.fit(X_composition, y_capacity)

# SHAP analysis
explainer_comp = shap.TreeExplainer(model_comp)
shap_values_comp = explainer_comp.shap_values(X_composition)

# Element-wise importance
mean_abs_shap_comp = np.abs(shap_values_comp).mean(axis=0)

# Visualization
plt.figure(figsize=(10, 6))
plt.bar(composition_features, mean_abs_shap_comp,
        color=['#FFD700', '#4169E1', '#32CD32', '#FF69B4', '#FF6347'],
        alpha=0.7, edgecolor='black', linewidth=1.5)
plt.xlabel('Element', fontsize=12)
plt.ylabel('Mean |SHAP value|', fontsize=12)
plt.title('Element Contribution to Battery Capacity (SHAP Analysis)',
          fontsize=13, fontweight='bold')
plt.grid(axis='y', alpha=0.3)
plt.tight_layout()
plt.show()

print("Element-wise importance:")
for elem, importance in zip(composition_features, mean_abs_shap_comp):
    print(f"  {elem}: {importance:.2f}")

print("\nImplications for material design:")
print("→ Increasing Ni content is expected to improve capacity")

4.5 Real-World Applications and Career Paths

We introduce industry applications of XAI and career information for materials data scientists.

Toyota: XAI Application in Materials Development

# Toyota case study (simulation)
print("=== Toyota Automotive Materials Development Case ===")
print("\nChallenge:")
print("  - Clarify battery material degradation mechanisms")
print("  - Select optimal materials from thousands of candidates")

print("\nXAI Application:")
print("  - Identify degradation contributing factors with SHAP analysis")
print("  - Visualize interactions between temperature, voltage, and cycle count")
print("  - Verify consistency with physical models")

print("\nResults:")
print("  - Development time reduced by 40%")
print("  - Battery lifetime improved by 20%")
print("  - Researchers gained physical insights")

# Simulation: Battery degradation prediction
battery_aging = pd.DataFrame({
    'Temperature': np.random.uniform(20, 60, 200),
    'Voltage': np.random.uniform(3.0, 4.5, 200),
    'Cycle Count': np.random.uniform(0, 1000, 200),
    'Charge Rate': np.random.uniform(0.5, 2.0, 200)
})

# Degradation rate (temperature and cycles are main factors)
degradation = (
    0.5 * battery_aging['Temperature'] +
    0.3 * battery_aging['Cycle Count'] / 100 +
    0.2 * battery_aging['Voltage'] * battery_aging['Charge Rate'] +
    np.random.normal(0, 2, 200)
)

# Model
model_aging = RandomForestRegressor(n_estimators=100, random_state=42)
model_aging.fit(battery_aging, degradation)

# SHAP analysis
explainer_aging = shap.TreeExplainer(model_aging)
shap_values_aging = explainer_aging.shap_values(battery_aging)

# Visualization
plt.figure(figsize=(10, 8))
shap.summary_plot(shap_values_aging, battery_aging, show=False)
plt.title('SHAP Analysis of Battery Degradation Factors (Toyota-style Case)',
          fontsize=13, fontweight='bold', pad=20)
plt.tight_layout()
plt.show()

IBM Research: Interpretability in AI Materials Design

# Requirements:
# - Python 3.9+
# - networkx>=3.1.0

"""
Example: IBM Research: Interpretability in AI Materials Design

Purpose: Demonstrate data visualization techniques
Target: Advanced
Execution time: 2-5 seconds
Dependencies: None
"""

print("\n=== IBM Research Materials Design Case ===")
print("\nProject: RoboRXN (Automated Chemistry Experiments)")
print("\nCharacteristics:")
print("  - XAI integrated into reaction condition optimization")
print("  - SHAP + Attention for reaction mechanism prediction")
print("  - Generate chemist-understandable recommendations")

print("\nTechnology Stack:")
print("  - Graph Neural Network (GNN)")
print("  - Attention mechanism")
print("  - SHAP for molecular graphs")

print("\nResults:")
print("  - Reaction yield prediction accuracy 95%")
print("  - Gained chemist trust")
print("  - Discovered novel reaction pathways")

# Molecular graph importance visualization (conceptual)
fig, ax = plt.subplots(figsize=(10, 8))

# Dummy molecular graph
import networkx as nx

G = nx.Graph()
G.add_edges_from([
    (0, 1), (1, 2), (2, 3), (3, 4), (4, 0),
    (1, 5), (3, 6)
])

pos = nx.spring_layout(G, seed=42)

# Node importance (Attention weights-like)
node_importance = np.random.rand(len(G.nodes))
node_importance = node_importance / node_importance.sum()

nx.draw(
    G, pos,
    node_color=node_importance,
    node_size=1000 * node_importance / node_importance.max(),
    cmap='YlOrRd',
    with_labels=True,
    font_size=12,
    font_weight='bold',
    edge_color='gray',
    width=2,
    ax=ax
)

sm = plt.cm.ScalarMappable(
    cmap='YlOrRd',
    norm=plt.Normalize(vmin=0, vmax=node_importance.max())
)
sm.set_array([])
cbar = plt.colorbar(sm, ax=ax, label='Attention Weight')

ax.set_title('Molecular Graph Attention Visualization (IBM-style)',
             fontsize=13, fontweight='bold')
plt.tight_layout()
plt.show()

Startup: Citrine Informatics (Explainable AI)

print("\n=== Citrine Informatics Case ===")
print("\nBusiness Model:")
print("  - Provides materials development platform")
print("  - Core technology: Explainable AI")
print("  - SaaS deployment to major manufacturers")

print("\nTechnical Features:")
print("  - Bayesian optimization + XAI")
print("  - Uncertainty quantification")
print("  - Integration of physical constraints")

print("\nCustomer Cases:")
print("  - Panasonic: 50% faster battery material development")
print("  - 3M: 30% improvement in adhesive performance")
print("  - Michelin: Tire rubber optimization")

print("\nDifferentiation Factors:")
print("  - Gain expert trust through explainability")
print("  - Integration with physical models")
print("  - High accuracy with small datasets")

# Citrine approach (simulation)
# Predictions with uncertainty + SHAP

from sklearn.ensemble import GradientBoostingRegressor

# Model (quantile regression-style)
model_citrine_lower = GradientBoostingRegressor(
    loss='quantile', alpha=0.1, n_estimators=100, random_state=42
)
model_citrine_median = GradientBoostingRegressor(
    loss='quantile', alpha=0.5, n_estimators=100, random_state=42
)
model_citrine_upper = GradientBoostingRegressor(
    loss='quantile', alpha=0.9, n_estimators=100, random_state=42
)

X_citrine = X_composition
y_citrine = y_capacity

model_citrine_lower.fit(X_citrine, y_citrine)
model_citrine_median.fit(X_citrine, y_citrine)
model_citrine_upper.fit(X_citrine, y_citrine)

# Prediction
X_new = X_citrine.iloc[:20]
y_pred_lower = model_citrine_lower.predict(X_new)
y_pred_median = model_citrine_median.predict(X_new)
y_pred_upper = model_citrine_upper.predict(X_new)

# Visualization
fig, ax = plt.subplots(figsize=(12, 6))

x_axis = range(len(X_new))

ax.fill_between(x_axis, y_pred_lower, y_pred_upper,
                alpha=0.3, color='steelblue',
                label='80% Prediction Interval')
ax.plot(x_axis, y_pred_median, 'o-',
        color='steelblue', linewidth=2, label='Prediction Median')

ax.set_xlabel('Material Sample', fontsize=12)
ax.set_ylabel('Capacity (mAh/g)', fontsize=12)
ax.set_title('Citrine-style Prediction with Uncertainty',
             fontsize=13, fontweight='bold')
ax.legend()
ax.grid(alpha=0.3)

plt.tight_layout()
plt.show()

print("\nBenefits of uncertainty:")
print("  - Risk assessment")
print("  - Prioritization of additional experiments")
print("  - Increased confidence in decision-making")

Career Path: Materials Data Scientists, XAI Researchers

# Career path information
career_paths = pd.DataFrame({
    'Career Path': [
        'Materials Data Scientist',
        'XAI Researcher (Academia)',
        'ML Engineer (Materials-specialized)',
        'R&D Manager (AI adoption)',
        'Technical Consultant'
    ],
    'Required Skills': [
        'Materials Science + ML + Python',
        'Statistics + ML Theory + Paper Writing',
        'ML Implementation + MLOps',
        'Materials Science + Project Management',
        'Materials Science + ML + Business'
    ],
    'Example Employers': [
        'Toyota, Panasonic, Mitsubishi Chemical',
        'Universities, AIST, RIKEN',
        'Citrine, Materials Zone',
        'Large Manufacturing R&D Divisions',
        'Accenture, Deloitte'
    ]
})

print("\n=== Career Paths ===")
print(career_paths.to_string(index=False))

Salary: ¥7-15 million (Japan), $90-180K (US)

# Salary data
salary_data = pd.DataFrame({
    'Position': [
        'Junior (0-3 years)',
        'Mid-level (3-7 years)',
        'Senior (7-15 years)',
        'Lead Scientist',
        'Manager'
    ],
    'Japan_Low': [500, 700, 1000, 1200, 1500],
    'Japan_High': [700, 1000, 1500, 2000, 2500],
    'US_Low': [70, 90, 130, 150, 180],
    'US_High': [90, 130, 180, 220, 300]
})

# Visualization
fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# Japan
axes[0].barh(salary_data['Position'],
             salary_data['Japan_High'] - salary_data['Japan_Low'],
             left=salary_data['Japan_Low'],
             color='steelblue', alpha=0.7)

for idx, row in salary_data.iterrows():
    axes[0].text(row['Japan_Low'] - 50, idx,
                 f"{row['Japan_Low']}", va='center', ha='right', fontsize=9)
    axes[0].text(row['Japan_High'] + 50, idx,
                 f"{row['Japan_High']}", va='center', ha='left', fontsize=9)

axes[0].set_xlabel('Annual Salary (ten thousand yen)', fontsize=12)
axes[0].set_title('Japan Salary Range', fontsize=13, fontweight='bold')
axes[0].grid(axis='x', alpha=0.3)

# US
axes[1].barh(salary_data['Position'],
             salary_data['US_High'] - salary_data['US_Low'],
             left=salary_data['US_Low'],
             color='coral', alpha=0.7)

for idx, row in salary_data.iterrows():
    axes[1].text(row['US_Low'] - 5, idx,
                 f"${row['US_Low']}K", va='center', ha='right', fontsize=9)
    axes[1].text(row['US_High'] + 5, idx,
                 f"${row['US_High']}K", va='center', ha='left', fontsize=9)

axes[1].set_xlabel('Annual Salary (thousand dollars)', fontsize=12)
axes[1].set_title('US Salary Range', fontsize=13, fontweight='bold')
axes[1].grid(axis='x', alpha=0.3)

plt.tight_layout()
plt.show()

print("\nFactors affecting salary:")
print("  - Degree (Master's vs PhD)")
print("  - Industry (Manufacturing vs IT)")
print("  - Location (Tokyo vs Regional, Silicon Valley vs Other)")
print("  - Skillset (Materials Science + ML + Domain Knowledge)")
print("  - Track Record (Papers, Patents, Project Success)")

print("\nSkill Development Strategy:")
print("  1. Strengthen materials science fundamentals (degree recommended)")
print("  2. Practical ML/DL skills (Kaggle, GitHub)")
print("  3. Master XAI methods (SHAP, LIME)")
print("  4. Publish papers and contribute to OSS")
print("  5. Network (conferences, meetups)")

Exercises

Problem 1 (Difficulty: Easy)

Using SHAP and LIME, generate explanations for the same sample and calculate the correlation of feature importance. Discuss what high and low correlations mean.

Solution Example
# Requirements:
# - Python 3.9+
# - lime>=0.2.0
# - numpy>=1.24.0, <2.0.0
# - scikit-learn>=1.3.0, <1.5.0
# - shap>=0.42.0

"""
Example: Generate SHAP and LIME explanations for the same sample and compare them

Purpose: Compare SHAP and LIME feature importances for a single prediction
Target: Advanced
Execution time: 30-60 seconds
Dependencies: None
"""

import shap
from lime import lime_tabular
from sklearn.ensemble import RandomForestRegressor
import numpy as np

# Model training
model = RandomForestRegressor(n_estimators=100, random_state=42)
model.fit(X_train, y_train)

# SHAP
explainer_shap = shap.TreeExplainer(model)
shap_values = explainer_shap.shap_values(X_test)

# LIME
explainer_lime = lime_tabular.LimeTabularExplainer(
    X_train, mode='regression',
    feature_names=[f'Feature_{i}' for i in range(X_train.shape[1])]
)

sample_idx = 0

# LIME explanation
lime_exp = explainer_lime.explain_instance(
    X_test[sample_idx], model.predict, num_features=X_train.shape[1]
)
lime_dict = dict(lime_exp.as_list())

# Calculate correlation
# as_list() keys are conditions such as 'Feature_0 <= 0.52', so match by substring
shap_importances = shap_values[sample_idx]
lime_importances = [next((w for k, w in lime_dict.items() if f'Feature_{i}' in k), 0)
                    for i in range(len(shap_importances))]

correlation = np.corrcoef(shap_importances, lime_importances)[0, 1]
print(f"SHAP-LIME correlation: {correlation:.4f}")

if correlation > 0.7:
    print("High correlation: Consistent explanations from both methods → High reliability")
else:
    print("Low correlation: Explanation discrepancy → Careful interpretation needed")

Problem 2 (Difficulty: Medium)

Using SHAP Dependence Plot, visualize interactions between two features. Analyze whether non-linear relationships or interactions are observed.

Solution Example
# Requirements:
# - Python 3.9+
# - matplotlib>=3.7.0
# - shap>=0.42.0

"""
Example: Use SHAP dependence plots to visualize interactions between two features

Purpose: Visualize feature interactions with SHAP dependence plots
Target: Beginner to Intermediate
Execution time: 2-5 seconds
Dependencies: None
"""

import shap
import matplotlib.pyplot as plt

# SHAP calculation
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(X_test)

# Dependence Plot (feature 0 and feature 1 interaction)
fig, axes = plt.subplots(1, 2, figsize=(16, 6))

shap.dependence_plot(0, shap_values, X_test, interaction_index=1,
                     ax=axes[0], show=False)
axes[0].set_title('Feature 0 (interaction with Feature 1)')

shap.dependence_plot(1, shap_values, X_test, interaction_index=0,
                     ax=axes[1], show=False)
axes[1].set_title('Feature 1 (interaction with Feature 0)')

plt.tight_layout()
plt.show()

print("Analysis Points:")
print("- Color variation: Strength of interaction")
print("- Non-linear patterns: Complex relationships")
print("- Trend: Direction of influence (positive/negative)")

Problem 3 (Difficulty: Hard)

Mimic the Toyota battery degradation prediction case and perform SHAP analysis on three factors (temperature, voltage, cycle count), quantitatively evaluating which factor contributes most to degradation. Also discuss physical validity.

Solution Example
# Requirements:
# - Python 3.9+
# - matplotlib>=3.7.0
# - numpy>=1.24.0, <2.0.0
# - pandas>=2.0.0, <2.2.0
# - scikit-learn>=1.3.0, <1.5.0
# - shap>=0.42.0

"""
Example: Mimic the Toyota battery degradation case and perform SHAP analysis of the degradation factors

Purpose: Quantify which factors contribute most to battery degradation with SHAP
Target: Advanced
Execution time: 30-60 seconds
Dependencies: None
"""

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.ensemble import GradientBoostingRegressor
import shap

# Data generation
battery_data = pd.DataFrame({
    'Temperature': np.random.uniform(20, 60, 300),
    'Voltage': np.random.uniform(3.0, 4.5, 300),
    'Cycle Count': np.random.uniform(0, 1000, 300)
})

# Degradation rate (physically reasonable model)
# Degradation increases with high temperature, high voltage, and many cycles
degradation = (
    0.8 * (battery_data['Temperature'] - 20) +  # High temp accelerates degradation
    2.0 * (battery_data['Voltage'] - 3.0)**2 +  # High voltage accelerates degradation
    0.05 * battery_data['Cycle Count'] +  # Cycle degradation
    0.01 * battery_data['Temperature'] * battery_data['Cycle Count'] / 100 +  # Interaction
    np.random.normal(0, 3, 300)
)

# Model training
model = GradientBoostingRegressor(n_estimators=100, random_state=42)
model.fit(battery_data, degradation)

# SHAP analysis
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(battery_data)

# Importance aggregation
mean_abs_shap = np.abs(shap_values).mean(axis=0)
feature_names = battery_data.columns

print("Degradation factor importance (SHAP):")
for name, importance in zip(feature_names, mean_abs_shap):
    print(f"  {name}: {importance:.2f}")

# Summary Plot
shap.summary_plot(shap_values, battery_data, show=False)
plt.title('SHAP Analysis of Battery Degradation Factors', fontsize=13, fontweight='bold')
plt.tight_layout()
plt.show()

print("\nPhysical Validity:")
print("- Temperature: Arrhenius equation predicts increased reaction rate at high temp → Valid")
print("- Voltage: High voltage promotes side reactions → Valid")
print("- Cycle Count: Repeated charge-discharge causes degradation → Valid")

4.6 XAI Environment and Practical Pitfalls

SHAP Library Version Management

# Requirements:
# - Python 3.9+
# - lime>=0.2.0
# - numpy>=1.24.0, <2.0.0
# - pandas>=2.0.0, <2.2.0
# - scikit-learn>=1.3.0, <1.5.0
# - shap>=0.42.0

"""
Example: SHAP Library Version Management

Purpose: Record and pin library versions for reproducible XAI analyses
Target: Intermediate
Execution time: 2-5 seconds
Dependencies: None
"""

# Libraries needed for XAI
import sys
import shap
import lime
import sklearn
import pandas as pd
import numpy as np

xai_env_info = {
    'Python': sys.version,
    'NumPy': np.__version__,
    'Pandas': pd.__version__,
    'scikit-learn': sklearn.__version__,
    'SHAP': shap.__version__,
    'LIME': lime.__version__,
    'Date': '2025-10-19'
}

print("=== XAI Environment ===")
for key, value in xai_env_info.items():
    print(f"{key}: {value}")

# Recommended versions
print("\n【Recommended Environment】")
recommended_xai = """
numpy==1.24.3
pandas==2.0.3
scikit-learn==1.3.0
shap==0.43.0
lime==0.2.0.1
matplotlib==3.7.2
"""
print(recommended_xai)

print("\n【Installation Command】")
print("```bash")
print("pip install shap==0.43.0 lime==0.2.0.1")
print("```")

print("\n【Important Notes】")
print("⚠️ SHAP frequently has API changes → Version pinning recommended")
print("⚠️ TreeExplainer verified to work with scikit-learn 1.3 or later")
print("⚠️ Kernel SHAP difficult to compute with large datasets (>10000 samples)")

Practical Pitfalls in XAI

print("=== Practical Pitfalls in XAI\n")

print("【Pitfall 1: Misinterpretation of SHAP Values】")
print("❌ Misconception: Large SHAP value = Large feature value")
print("→ SHAP value represents 'contribution to prediction', not feature magnitude")

print("\n✅ Correct Understanding:")
print("```python")
print("# SHAP value = How much this feature contributed to prediction")
print("# Even small feature values can have large SHAP values")
print("feature_value = 0.1  # Small value")
print("shap_value = 2.5     # Large influence")
print("# → This feature is important despite small magnitude")
print("```")

print("\n【Pitfall 2: Ignoring LIME's Locality】")
print("⚠️ Generalizing LIME explanation from one sample to all")
print("→ LIME's local linear approximation varies across samples")

print("\n✅ Solution: Check consistency across multiple samples")
print("```python")
print("# Run LIME on 10 samples, check if important features agree")
print("for i in range(10):")
print("    explanation = lime_explainer.explain_instance(X[i], model.predict)")
print("    # Check if top 3 features match")
print("```")

print("\n【Pitfall 3: Confusing Correlation and Causation】")
print("⚠️ \"High SHAP value → Changing this feature will change prediction\"")
print("→ This is correlation, not causation")

print("\n✅ Causation requires separate methods")
print("```python")
print("# XAI is correlation analysis")
print("# For causation, use:")
print("# - A/B testing")
print("# - Causal graphs (DAG)")
print("# - Propensity score matching")
print("```")

print("\n【Pitfall 4: Overconfidence in Attention Visualization】")
print("⚠️ High Attention = Model correctly focuses on this part")
print("→ Not necessarily the correct reason")

print("\n✅ Cross-validate with multiple methods")
print("```python")
print("# Check agreement between SHAP + LIME + Attention")
print("# Verify physical plausibility with domain experts")
print("```")

print("\n【Pitfall 5: Ignoring Computational Cost at Scale】")
print("⚠️ Running Kernel SHAP on 10000 samples")
print("→ Computation time: Hours to days")

print("\n✅ Choose method and sample size appropriately")
print("```python")
print("if len(X) < 1000:")
print("    explainer = shap.KernelExplainer()  # Any model")
print("else:")
print("    # Use subsampling or TreeExplainer")
print("    X_sample = shap.sample(X, 1000)")
print("    explainer = shap.TreeExplainer()  # Fast")
print("```")

Summary

In this chapter, we learned the theory and practice of Explainable AI (XAI).

Key Points:

  1. Black-Box Problem: High-accuracy models are hard to interpret → XAI provides solutions
  2. SHAP: Fair feature contribution evaluation using Shapley values
  3. LIME: Generate explanations through local linear approximation
  4. Attention Visualization: Understand internal behavior of neural networks
  5. Real-World Applications: Success cases from Toyota, IBM, Citrine
  6. Career Path: Growing demand for materials data scientists, salaries ¥7-25 million
  7. Environment Management: Version pinning for SHAP, LIME and computational cost management
  8. Practical Pitfalls: SHAP value misinterpretation, LIME locality, correlation vs causation, Attention overconfidence, computational cost

Series Summary:

Next Steps:

  1. Apply full workflow to real datasets
  2. Submit papers and contribute to OSS
  3. Attend conferences and build networks
  4. Develop career as materials data scientist

Chapter 4 Checklist

SHAP (SHapley Additive exPlanations)

LIME (Local Interpretable Model-agnostic Explanations)

Attention Visualization (for NN/GNN)

Learning Real-World Application Cases

Building Career Path

Avoiding Practical Pitfalls (XAI)

XAI Quality Evaluation

Ensuring Reproducibility


References

  1. Lundberg, S. M. & Lee, S. I. (2017). A unified approach to interpreting model predictions. Advances in Neural Information Processing Systems, 30, 4765-4774.

  2. Ribeiro, M. T., Singh, S., & Guestrin, C. (2016). "Why should I trust you?": Explaining the predictions of any classifier. Proceedings of the 22nd ACM SIGKDD, 1135-1144. DOI: 10.1145/2939672.2939778

  3. Molnar, C. (2022). Interpretable Machine Learning: A Guide for Making Black Box Models Explainable (2nd ed.). https://christophm.github.io/interpretable-ml-book/

  4. Vaswani, A., Shazeer, N., Parmar, N., et al. (2017). Attention is all you need. Advances in Neural Information Processing Systems, 30, 5998-6008.

  5. Citrine Informatics. (2023). Materials Informatics Platform. https://citrine.io/




Congratulations on Completing the Series!

You have now acquired practical skills in data-driven materials science. We wish you continued success.

Feedback and Questions: - Email: yusuke.hashimoto.b8@tohoku.ac.jp - GitHub: AI_Homepage Repository

Related Series: - Introduction to Bayesian Optimization - Introduction to Active Learning - Introduction to Graph Neural Networks
