🌐 EN | πŸ‡―πŸ‡΅ JP | Last sync: 2025-11-16

Chapter 6: PyTorch Geometric Workflow

In this chapter, we will learn practical workflows using PyTorch Geometric and the Materials Project API. From creating custom datasets, distributed training, GPU optimization, to deployment in production environments, you will comprehensively master the techniques required for actual projects.

6.1 Data Acquisition with Materials Project API

Materials Project is one of the largest open databases in materials science, providing over 148,000 crystal structures and property data. We will learn how to efficiently acquire and process this data using the pymatgen library and mp-api.

6.1.1 Materials Project API Authentication

To use the Materials Project API, you need a free API key. Please create an account on the Materials Project website and obtain your API key.

# Requirements:
# - Python 3.9+
# - pandas>=2.0.0, <2.2.0

"""
Example: To use the Materials Project API, you need a free API key. P

Purpose: Demonstrate data manipulation and preprocessing
Target: Beginner to Intermediate
Execution time: 5-10 seconds
Dependencies: None
"""

# Code Example 1: Materials Project API Authentication and Basic Data Acquisition
# Executable in Google Colab

# Install required libraries
!pip install mp-api pymatgen -q

from mp_api.client import MPRester
from pymatgen.core import Structure
import pandas as pd

# Set API key (replace with your own API key)
API_KEY = "your_api_key_here"

# Initialize MPRester
with MPRester(API_KEY) as mpr:
    # Search for perovskite structures (ABX3)
    # Materials with negative formation energy (stable) and band gap 1-3 eV
    docs = mpr.materials.summary.search(
        formula="*3",  # ABX3 format
        num_elements=(3, 3),  # Ternary systems
        energy_above_hull=(0, 0.01),  # Nearly stable phase
        band_gap=(1.0, 3.0),  # Semiconductor range
        fields=["material_id", "formula_pretty", "band_gap",
                "energy_above_hull", "formation_energy_per_atom"]
    )

# Convert results to DataFrame
data = []
for doc in docs:
    data.append({
        "material_id": doc.material_id,
        "formula": doc.formula_pretty,
        "band_gap": doc.band_gap,
        "e_hull": doc.energy_above_hull,
        "formation_energy": doc.formation_energy_per_atom
    })

df = pd.DataFrame(data)
print(f"Search results: {len(df)} materials")
print(df.head())

# Statistical information
print("\n=== Statistics ===")
print(f"Band gap range: {df['band_gap'].min():.3f} - {df['band_gap'].max():.3f} eV")
print(f"Formation energy range: {df['formation_energy'].min():.3f} - {df['formation_energy'].max():.3f} eV/atom")
Example output:
Search results: 247 materials
Band gap range: 1.012 - 2.987 eV
Formation energy range: -2.345 - -0.128 eV/atom

6.1.2 Acquiring Crystal Structure Data and Saving in CIF Format

Crystal structures obtained from Materials Project are handled as pymatgen Structure objects. These can be saved in CIF (Crystallographic Information File) format for use in visualization and as input to machine learning models.

# Code Example 2: Acquiring Crystal Structures and Saving in CIF Format
# Executable in Google Colab

from mp_api.client import MPRester
from pymatgen.io.cif import CifWriter
import os

API_KEY = "your_api_key_here"

# Create save directory
os.makedirs("structures", exist_ok=True)

with MPRester(API_KEY) as mpr:
    # Example: Get crystal structure for mp-1234 (sample Material ID)
    # Replace with actual Material ID
    structure = mpr.get_structure_by_material_id("mp-1234")

    # Display structure information
    print("=== Crystal Structure Information ===")
    print(f"Chemical formula: {structure.composition.reduced_formula}")
    print(f"Space group: {structure.get_space_group_info()}")
    print(f"Lattice constants: {structure.lattice.abc}")
    print(f"Lattice angles: {structure.lattice.angles}")
    print(f"Number of atoms: {len(structure)}")
    print(f"Volume: {structure.volume:.3f} Γ…Β³")

    # Atomic site information
    print("\n=== Atomic Sites ===")
    for i, site in enumerate(structure):
        print(f"Site {i+1}: {site.species_string} at {site.frac_coords}")

    # Save in CIF format
    cif_writer = CifWriter(structure)
    cif_writer.write_file(f"structures/mp-1234.cif")
    print("\nSaved CIF file: structures/mp-1234.cif")

# Batch acquisition of multiple materials
material_ids = ["mp-1234", "mp-5678", "mp-9012"]  # Replace with actual IDs

with MPRester(API_KEY) as mpr:
    for mat_id in material_ids:
        try:
            structure = mpr.get_structure_by_material_id(mat_id)
            cif_writer = CifWriter(structure)
            cif_writer.write_file(f"structures/{mat_id}.cif")
            print(f"βœ“ {mat_id}: {structure.composition.reduced_formula}")
        except Exception as e:
            print(f"βœ— {mat_id}: Error - {e}")
About API Limits: The Materials Project API has daily request limits. For large-scale data acquisition, please respect rate limits using time.sleep() and consider batch processing.

6.2 PyTorch Geometric Custom Dataset

We will convert data obtained from Materials Project into PyTorch Geometric Data objects and create training datasets. By inheriting the InMemoryDataset class, we can achieve efficient data loading.

6.2.1 Converting Materials Project to PyG Data

# Requirements:
# - Python 3.9+
# - numpy>=1.24.0, <2.0.0
# - torch>=2.0.0, <2.3.0

# Code Example 3: Converting Materials Project Structures to PyTorch Geometric Data
# Executable in Google Colab (GPU recommended)

import torch
from torch_geometric.data import Data, InMemoryDataset
from pymatgen.core import Structure
from mp_api.client import MPRester
import numpy as np
from typing import List, Tuple

class StructureToGraph:
    """
    Convert pymatgen Structure objects to graph representation
    """
    def __init__(self, cutoff: float = 5.0):
        """
        Args:
            cutoff: Cutoff radius for atomic distances (Γ…)
        """
        self.cutoff = cutoff

    def convert(self, structure: Structure) -> Data:
        """
        Structure β†’ PyG Data conversion

        Args:
            structure: pymatgen Structure object

        Returns:
            PyG Data object
        """
        # Node features: One-hot representation of atomic numbers (max atomic number 92: U)
        atom_numbers = [site.specie.Z for site in structure]
        x = torch.zeros((len(atom_numbers), 92))
        for i, z in enumerate(atom_numbers):
            x[i, z-1] = 1.0  # Index starts from 0

        # Edge construction: Atom pairs within cutoff radius
        edge_index = []
        edge_attr = []

        for i, site_i in enumerate(structure):
            # Neighbor search considering periodic boundary conditions
            neighbors = structure.get_neighbors(site_i, self.cutoff)

            for neighbor in neighbors:
                j = neighbor.index
                distance = neighbor.nn_distance

                # Add edge (bidirectional for undirected graph)
                edge_index.append([i, j])

                # Edge features: Gaussian expansion of distance
                edge_feature = self._gaussian_expansion(distance)
                edge_attr.append(edge_feature)

        edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
        edge_attr = torch.tensor(edge_attr, dtype=torch.float)

        # Create graph data
        data = Data(
            x=x,
            edge_index=edge_index,
            edge_attr=edge_attr
        )

        return data

    def _gaussian_expansion(self, distance: float, num_centers: int = 41) -> np.ndarray:
        """
        Expand distance using Gaussian basis functions

        Args:
            distance: Atomic distance (Γ…)
            num_centers: Number of Gaussian basis functions

        Returns:
            Expansion coefficient vector
        """
        centers = np.linspace(0, self.cutoff, num_centers)
        width = 0.5  # Gaussian width

        gamma = -0.5 / (width ** 2)
        return np.exp(gamma * (distance - centers) ** 2)

# Usage example
API_KEY = "your_api_key_here"
converter = StructureToGraph(cutoff=5.0)

with MPRester(API_KEY) as mpr:
    structure = mpr.get_structure_by_material_id("mp-1234")
    data = converter.convert(structure)

    print("=== Graph Representation ===")
    print(f"Number of nodes: {data.x.size(0)}")
    print(f"Node feature dimension: {data.x.size(1)}")
    print(f"Number of edges: {data.edge_index.size(1)}")
    print(f"Edge feature dimension: {data.edge_attr.size(1)}")

6.2.2 Custom InMemoryDataset Implementation

By inheriting InMemoryDataset, we can automate preprocessing and caching of Materials Project data. This significantly reduces data acquisition time on subsequent runs.

# Requirements:
# - Python 3.9+
# - torch>=2.0.0, <2.3.0

# Code Example 4: Custom InMemoryDataset for Materials Project
# Executable in Google Colab (GPU recommended)

import os
import torch
from torch_geometric.data import InMemoryDataset, Data
from mp_api.client import MPRester
import pickle

class MaterialsProjectDataset(InMemoryDataset):
    """
    Create dataset for material property prediction from Materials Project
    """
    def __init__(self, root, api_key, material_ids=None,
                 property_name="band_gap", cutoff=5.0,
                 transform=None, pre_transform=None, pre_filter=None):
        """
        Args:
            root: Dataset save directory
            api_key: Materials Project API key
            material_ids: List of Material IDs to retrieve (if None, search)
            property_name: Target property for prediction ('band_gap', 'formation_energy_per_atom', etc.)
            cutoff: Graph construction cutoff radius (Γ…)
        """
        self.api_key = api_key
        self.material_ids = material_ids
        self.property_name = property_name
        self.converter = StructureToGraph(cutoff=cutoff)

        super().__init__(root, transform, pre_transform, pre_filter)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return ['materials.pkl']

    @property
    def processed_file_names(self):
        return ['data.pt']

    def download(self):
        """
        Download data from Materials Project API
        """
        with MPRester(self.api_key) as mpr:
            if self.material_ids is None:
                # Search if Material IDs are not specified
                docs = mpr.materials.summary.search(
                    energy_above_hull=(0, 0.05),  # Include metastable phases
                    num_elements=(1, 5),  # 1-5 element systems
                    fields=["material_id", self.property_name]
                )
                self.material_ids = [doc.material_id for doc in docs
                                     if getattr(doc, self.property_name) is not None]
                print(f"Search results: {len(self.material_ids)} materials")

            # Retrieve structure and property data
            materials_data = []
            for i, mat_id in enumerate(self.material_ids):
                try:
                    structure = mpr.get_structure_by_material_id(mat_id)
                    doc = mpr.materials.summary.search(
                        material_ids=[mat_id],
                        fields=[self.property_name]
                    )[0]

                    property_value = getattr(doc, self.property_name)

                    materials_data.append({
                        'material_id': mat_id,
                        'structure': structure,
                        'property': property_value
                    })

                    if (i + 1) % 100 == 0:
                        print(f"Download progress: {i+1}/{len(self.material_ids)}")

                except Exception as e:
                    print(f"Error ({mat_id}): {e}")

            # Save
            os.makedirs(self.raw_dir, exist_ok=True)
            with open(self.raw_paths[0], 'wb') as f:
                pickle.dump(materials_data, f)

            print(f"βœ“ Download complete: {len(materials_data)} entries")

    def process(self):
        """
        Convert raw data to PyG Data format
        """
        # Load raw data
        with open(self.raw_paths[0], 'rb') as f:
            materials_data = pickle.load(f)

        # Convert to PyG Data format
        data_list = []
        for item in materials_data:
            # Graph conversion
            data = self.converter.convert(item['structure'])

            # Add label (property value)
            data.y = torch.tensor([item['property']], dtype=torch.float)
            data.material_id = item['material_id']

            data_list.append(data)

        # Filtering (optional)
        if self.pre_filter is not None:
            data_list = [d for d in data_list if self.pre_filter(d)]

        # Preprocessing (optional)
        if self.pre_transform is not None:
            data_list = [self.pre_transform(d) for d in data_list]

        # Save
        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])
        print(f"βœ“ Processing complete: {len(data_list)} entries")

# Usage example
API_KEY = "your_api_key_here"

# Create dataset (automatic download & processing on first run)
dataset = MaterialsProjectDataset(
    root='./data/mp_band_gap',
    api_key=API_KEY,
    property_name='band_gap',
    cutoff=5.0
)

print(f"Dataset size: {len(dataset)}")
print(f"Sample: {dataset[0]}")
Caching Feature: InMemoryDataset automatically saves processed data. On subsequent runs, it simply loads the saved data for fast startup.

6.3 Distributed Training and GPU Optimization

To efficiently train large datasets and complex GNN models, we utilize PyTorch's distributed training features and GPU optimization techniques.

6.3.1 Multi-GPU Training with DataParallel

# Requirements:
# - Python 3.9+
# - torch>=2.0.0, <2.3.0

# Code Example 5: Multi-GPU Parallel Training with DataParallel
# Executable in Google Colab Pro/Pro+ (multi-GPU environment)

import torch
import torch.nn as nn
from torch_geometric.nn import CGConv, global_mean_pool
from torch_geometric.loader import DataLoader
import time

class CGCNNModel(nn.Module):
    """
    CGCNN (Crystal Graph Convolutional Neural Network)
    """
    def __init__(self, atom_fea_len=92, nbr_fea_len=41,
                 hidden_dim=128, n_conv=3):
        super(CGCNNModel, self).__init__()

        # Atom embedding
        self.atom_embedding = nn.Linear(atom_fea_len, hidden_dim)

        # CGConv layers
        self.conv_layers = nn.ModuleList([
            CGConv(hidden_dim, nbr_fea_len) for _ in range(n_conv)
        ])

        self.bn_layers = nn.ModuleList([
            nn.BatchNorm1d(hidden_dim) for _ in range(n_conv)
        ])

        # Prediction head
        self.fc1 = nn.Linear(hidden_dim, 64)
        self.fc2 = nn.Linear(64, 1)

        self.activation = nn.Softplus()

    def forward(self, data):
        x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch

        # Atom embedding
        x = self.atom_embedding(x)

        # CGConv layers (with residual connections)
        for conv, bn in zip(self.conv_layers, self.bn_layers):
            x_new = conv(x, edge_index, edge_attr)
            x_new = bn(x_new)
            x_new = self.activation(x_new)
            x = x + x_new  # Residual connection

        # Graph-level pooling
        x = global_mean_pool(x, batch)

        # Prediction
        x = self.activation(self.fc1(x))
        x = self.fc2(x)

        return x.squeeze()

# Multi-GPU training
def train_multigpu(dataset, epochs=100, batch_size=64, lr=0.001):
    """
    Multi-GPU parallel training with DataParallel
    """
    # Data loader
    train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

    # Model construction
    model = CGCNNModel()

    # Check GPU devices
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    gpu_count = torch.cuda.device_count()
    print(f"Available GPUs: {gpu_count}")

    if gpu_count > 1:
        # Multi-GPU parallelization
        model = nn.DataParallel(model)
        print(f"DataParallel mode: Using {gpu_count} GPUs")

    model = model.to(device)

    # Optimization setup
    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=10, verbose=True
    )

    # Training loop
    for epoch in range(epochs):
        model.train()
        total_loss = 0
        start_time = time.time()

        for batch in train_loader:
            batch = batch.to(device)

            optimizer.zero_grad()
            output = model(batch)
            loss = criterion(output, batch.y)
            loss.backward()
            optimizer.step()

            total_loss += loss.item() * batch.num_graphs

        avg_loss = total_loss / len(dataset)
        epoch_time = time.time() - start_time

        # Learning rate adjustment
        scheduler.step(avg_loss)

        if (epoch + 1) % 10 == 0:
            print(f"Epoch {epoch+1}/{epochs}, Loss: {avg_loss:.4f}, Time: {epoch_time:.2f}s")

    return model

# Execution example
# dataset = MaterialsProjectDataset(...) dataset created earlier
# model = train_multigpu(dataset, epochs=200, batch_size=64)
How DataParallel Works: Batches are split across GPUs, and forward and backward propagation are executed in parallel on each GPU. Gradients are gathered on GPU 0, and parameter updates are performed.

6.3.2 Mixed Precision Training

Using PyTorch's torch.cuda.amp, we train with a mix of FP16 (half-precision floating point) and FP32 (single-precision). This reduces memory usage and can speed up training by up to 2x.

# Requirements:
# - Python 3.9+
# - torch>=2.0.0, <2.3.0

# Code Example 6: Mixed Precision Training
# Executable in Google Colab (GPU environment)

import torch
from torch.cuda.amp import autocast, GradScaler
from torch_geometric.loader import DataLoader
import time

def train_mixed_precision(model, dataset, epochs=100, batch_size=64, lr=0.001):
    """
    Fast training with Mixed Precision Training
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    # Data loader
    train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)

    # Optimization setup
    criterion = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    # Gradient Scaler (gradient scaling)
    scaler = GradScaler()

    print("=== Starting Mixed Precision Training ===")

    for epoch in range(epochs):
        model.train()
        total_loss = 0
        start_time = time.time()

        for batch in train_loader:
            batch = batch.to(device)

            optimizer.zero_grad()

            # Mixed Precision: Forward pass in FP16
            with autocast():
                output = model(batch)
                loss = criterion(output, batch.y)

            # Backward pass with scaling
            scaler.scale(loss).backward()

            # Gradient clipping (optional)
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            # Parameter update
            scaler.step(optimizer)
            scaler.update()

            total_loss += loss.item() * batch.num_graphs

        avg_loss = total_loss / len(dataset)
        epoch_time = time.time() - start_time

        if (epoch + 1) % 10 == 0:
            # Display memory usage
            if torch.cuda.is_available():
                memory_allocated = torch.cuda.memory_allocated() / 1024**3  # GB
                memory_reserved = torch.cuda.memory_reserved() / 1024**3
                print(f"Epoch {epoch+1}: Loss={avg_loss:.4f}, Time={epoch_time:.2f}s, "
                      f"Memory={memory_allocated:.2f}GB/{memory_reserved:.2f}GB")
            else:
                print(f"Epoch {epoch+1}: Loss={avg_loss:.4f}, Time={epoch_time:.2f}s")

    return model

# Usage example
model = CGCNNModel()
# model = train_mixed_precision(model, dataset, epochs=200)
Benefits of Mixed Precision: On V100 GPU, approximately 1.5-2x training speed improvement and about 40% reduction in memory usage. Impact on accuracy is negligible (MAE difference < 0.001).

6.4 Model Saving and Loading

We will learn best practices for saving trained models and loading them later for inference.

6.4.1 Checkpoint Saving

# Requirements:
# - Python 3.9+
# - torch>=2.0.0, <2.3.0

# Code Example 7: Model Checkpoint Saving and Loading
# Executable in Google Colab

import torch
import os
from datetime import datetime

class ModelCheckpoint:
    """
    Model checkpoint management
    """
    def __init__(self, save_dir='checkpoints', monitor='val_loss', mode='min'):
        """
        Args:
            save_dir: Save directory
            monitor: Metric to monitor ('val_loss', 'val_mae', etc.)
            mode: 'min' (minimize) or 'max' (maximize)
        """
        self.save_dir = save_dir
        self.monitor = monitor
        self.mode = mode
        self.best_score = float('inf') if mode == 'min' else float('-inf')

        os.makedirs(save_dir, exist_ok=True)

    def save(self, model, optimizer, epoch, metrics, filename=None):
        """
        Save checkpoint

        Args:
            model: PyTorch model
            optimizer: Optimizer
            epoch: Current epoch
            metrics: Metrics dictionary (e.g., {'val_loss': 0.025, 'val_mae': 0.18})
            filename: Save filename (auto-generated if None)
        """
        if filename is None:
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            filename = f"checkpoint_epoch{epoch}_{timestamp}.pt"

        filepath = os.path.join(self.save_dir, filename)

        # Checkpoint data
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'metrics': metrics
        }

        torch.save(checkpoint, filepath)
        print(f"βœ“ Checkpoint saved: {filepath}")

        # Save separately if best model
        current_score = metrics.get(self.monitor)
        if current_score is not None:
            is_best = (self.mode == 'min' and current_score < self.best_score) or \
                      (self.mode == 'max' and current_score > self.best_score)

            if is_best:
                self.best_score = current_score
                best_path = os.path.join(self.save_dir, 'best_model.pt')
                torch.save(checkpoint, best_path)
                print(f"βœ“ Best model updated: {self.monitor}={current_score:.4f}")

    @staticmethod
    def load(filepath, model, optimizer=None):
        """
        Load checkpoint

        Args:
            filepath: Checkpoint file path
            model: Model to load into
            optimizer: Optimizer to load into (optional)

        Returns:
            epoch, metrics
        """
        checkpoint = torch.load(filepath)

        model.load_state_dict(checkpoint['model_state_dict'])

        if optimizer is not None and 'optimizer_state_dict' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        epoch = checkpoint.get('epoch', 0)
        metrics = checkpoint.get('metrics', {})

        print(f"βœ“ Checkpoint loaded: {filepath}")
        print(f"  Epoch: {epoch}, Metrics: {metrics}")

        return epoch, metrics

# Usage example: Checkpoint saving in training loop
checkpoint_manager = ModelCheckpoint(save_dir='./checkpoints', monitor='val_mae', mode='min')

for epoch in range(100):
    # Training process
    train_loss = 0.0  # Calculated in actual training

    # Validation process
    val_loss = 0.0  # Calculated in actual validation
    val_mae = 0.0

    # Save checkpoint every 10 epochs
    if (epoch + 1) % 10 == 0:
        metrics = {
            'train_loss': train_loss,
            'val_loss': val_loss,
            'val_mae': val_mae
        }
        checkpoint_manager.save(model, optimizer, epoch + 1, metrics)

# Load best model
model_new = CGCNNModel()
checkpoint_manager.load('./checkpoints/best_model.pt', model_new)

6.4.2 ONNX Format Export (Inference Optimization)

By exporting to ONNX (Open Neural Network Exchange) format, you can maximize inference speed and use the model in different frameworks (TensorFlow, C++, etc.).

# Requirements:
# - Python 3.9+
# - numpy>=1.24.0, <2.0.0
# - torch>=2.0.0, <2.3.0

# Code Example 8: ONNX Export and Inference
# Executable in Google Colab

import torch
import torch.onnx
from torch_geometric.data import Batch

def export_to_onnx(model, sample_data, onnx_path='model.onnx'):
    """
    Export PyTorch Geometric model to ONNX format

    Args:
        model: PyTorch model
        sample_data: Sample input data (Data type)
        onnx_path: Save path
    """
    model.eval()

    # Convert sample data to batch format
    batch = Batch.from_data_list([sample_data])

    # Create dummy input (required for ONNX export)
    dummy_input = (
        batch.x,
        batch.edge_index,
        batch.edge_attr,
        batch.batch
    )

    # ONNX export
    torch.onnx.export(
        model,
        dummy_input,
        onnx_path,
        export_params=True,
        opset_version=14,
        do_constant_folding=True,
        input_names=['x', 'edge_index', 'edge_attr', 'batch'],
        output_names=['output'],
        dynamic_axes={
            'x': {0: 'num_nodes'},
            'edge_index': {1: 'num_edges'},
            'edge_attr': {0: 'num_edges'},
            'batch': {0: 'num_nodes'}
        }
    )

    print(f"βœ“ ONNX export completed: {onnx_path}")

    # Validate ONNX model
    import onnx
    onnx_model = onnx.load(onnx_path)
    onnx.checker.check_model(onnx_model)
    print("βœ“ ONNX model validation passed")

# ONNX Runtime inference (fast inference)
def inference_onnx(onnx_path, data):
    """
    Fast inference using ONNX Runtime

    Args:
        onnx_path: ONNX model path
        data: Input Data

    Returns:
        Prediction value
    """
    import onnxruntime as ort
    import numpy as np

    # Create ONNX Runtime session
    ort_session = ort.InferenceSession(onnx_path)

    # Batchify
    batch = Batch.from_data_list([data])

    # Convert to NumPy arrays
    ort_inputs = {
        'x': batch.x.numpy(),
        'edge_index': batch.edge_index.numpy(),
        'edge_attr': batch.edge_attr.numpy(),
        'batch': batch.batch.numpy()
    }

    # Inference
    ort_outputs = ort_session.run(None, ort_inputs)
    prediction = ort_outputs[0]

    return prediction[0]

# Usage example
# model = CGCNNModel()  # Trained model
# sample_data = dataset[0]  # Sample data

# export_to_onnx(model, sample_data, 'cgcnn_model.onnx')
# prediction = inference_onnx('cgcnn_model.onnx', sample_data)
# print(f"ONNX prediction: {prediction:.4f}")
Benefits of ONNX Runtime: Compared to PyTorch native inference, 1.5-3x inference speed improvement can be expected. The effect is particularly noticeable in CPU environments.

6.5 Production Environment Deployment

We will publish the trained model as a REST API and make it available from web applications and other systems. We will show an implementation example using FastAPI.

graph LR A[Web Client] -->|POST /predict| B[FastAPI Server] B --> C[Load Model] C --> D[PyG Data Conversion] D --> E[Execute Inference] E --> F[Result JSON] F -->|Response| A style B fill:#667eea,color:#fff style E fill:#764ba2,color:#fff
# Requirements:
# - Python 3.9+
# - fastapi>=0.100.0
# - requests>=2.31.0
# - torch>=2.0.0, <2.3.0

"""
Example: We will publish the trained model as a REST API and make it 

Purpose: Demonstrate core concepts and implementation patterns
Target: Advanced
Execution time: 10-20 seconds
Dependencies: None
"""

# Code Example 9: FastAPI REST API Deployment
# Run in local environment or server

# requirements.txt:
# fastapi
# uvicorn
# torch
# torch-geometric
# pymatgen

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
from pymatgen.core import Structure
import json

app = FastAPI(title="Materials Property Prediction API")

# Load model as global variable
MODEL = None
DEVICE = None

class CrystalInput(BaseModel):
    """
    Input data schema
    """
    structure: dict  # pymatgen Structure dictionary representation
    # or
    cif_string: str = None  # CIF string

class PredictionResponse(BaseModel):
    """
    Prediction result schema
    """
    prediction: float
    uncertainty: float = None
    material_id: str = None

@app.on_event("startup")
async def load_model():
    """
    Load model on server startup
    """
    global MODEL, DEVICE

    DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load model
    MODEL = CGCNNModel()
    checkpoint = torch.load('checkpoints/best_model.pt', map_location=DEVICE)
    MODEL.load_state_dict(checkpoint['model_state_dict'])
    MODEL.to(DEVICE)
    MODEL.eval()

    print(f"βœ“ Model loaded on {DEVICE}")

@app.post("/predict", response_model=PredictionResponse)
async def predict_property(input_data: CrystalInput):
    """
    Material property prediction endpoint

    Args:
        input_data: Crystal structure data

    Returns:
        Prediction result
    """
    try:
        # Parse structure data
        if input_data.cif_string:
            structure = Structure.from_str(input_data.cif_string, fmt='cif')
        else:
            structure = Structure.from_dict(input_data.structure)

        # Graph conversion
        converter = StructureToGraph(cutoff=5.0)
        data = converter.convert(structure)
        data = data.to(DEVICE)

        # Inference
        with torch.no_grad():
            prediction = MODEL(data).item()

        return PredictionResponse(
            prediction=prediction,
            material_id=input_data.structure.get('material_id', 'unknown')
        )

    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))

@app.get("/health")
async def health_check():
    """
    Health check endpoint
    """
    return {
        "status": "healthy",
        "model_loaded": MODEL is not None,
        "device": str(DEVICE)
    }

@app.get("/")
async def root():
    """
    API root
    """
    return {
        "message": "Materials Property Prediction API",
        "endpoints": {
            "POST /predict": "Predict material property from structure",
            "GET /health": "Health check",
            "GET /docs": "API documentation (Swagger UI)"
        }
    }

# Server startup command:
# uvicorn api:app --host 0.0.0.0 --port 8000 --reload

# Client usage example (Python):
"""
import requests
import json

# CIF string (example)
cif_string = '''
data_mp-1234
_cell_length_a    3.905
_cell_length_b    3.905
_cell_length_c    3.905
_cell_angle_alpha 90.0
_cell_angle_beta  90.0
_cell_angle_gamma 90.0
_symmetry_space_group_name_H-M 'P 1'
loop_
_atom_site_label
_atom_site_type_symbol
_atom_site_fract_x
_atom_site_fract_y
_atom_site_fract_z
Ti1 Ti 0.0 0.0 0.0
O1  O  0.5 0.5 0.0
O2  O  0.5 0.0 0.5
O3  O  0.0 0.5 0.5
'''

# API call
response = requests.post(
    'http://localhost:8000/predict',
    json={'cif_string': cif_string}
)

result = response.json()
print(f"Predicted band gap: {result['prediction']:.3f} eV")
"""
Benefits of FastAPI: Automatic API documentation generation (Swagger UI), fast asynchronous processing, type checking, and concise code writing are possible. You can check interactive API documentation at http://localhost:8000/docs.

6.6 Chapter Summary

In this chapter, we learned practical workflows using PyTorch Geometric and the Materials Project API. From creating custom datasets to production deployment, we comprehensively mastered the techniques required for actual projects.

Key Points

Practical Workflow

graph TD A[Materials Project API] --> B[Data Acquisition & CIF Save] B --> C[PyG Dataset Creation] C --> D[Distributed Training / GPU Optimization] D --> E[Checkpoint Save] E --> F[ONNX Conversion] F --> G[FastAPI Deploy] G --> H[REST API Publication] style A fill:#667eea,color:#fff style D fill:#764ba2,color:#fff style H fill:#28a745,color:#fff

Next Steps

Using the knowledge learned in this series, try challenging practical projects such as:

  1. Custom Property Prediction Model: Build a model to predict other properties from Materials Project (formation energy, elastic modulus, etc.)
  2. Ensemble Model: High-accuracy prediction through ensemble of composition-based and GNN models
  3. Active Learning Pipeline: Efficient data collection strategy using uncertainty estimation
  4. Web Application: Interactive material exploration tool using Streamlit, etc.

Exercises

Exercise 6.1 (Easy): Basic Data Acquisition with Materials Project API

Using the Materials Project API, create code to retrieve 100 stable oxide materials (containing O) with formation energy of -2.0 eV/atom or less and display the following statistical information.

Solution Example:
# Requirements:
# - Python 3.9+
# - pandas>=2.0.0, <2.2.0

"""
Example: Using the Materials Project API, create code to retrieve 100

Purpose: Demonstrate core concepts and implementation patterns
Target: Beginner to Intermediate
Execution time: 10-30 seconds
Dependencies: None
"""

from mp_api.client import MPRester
import pandas as pd
from collections import Counter

API_KEY = "your_api_key_here"

with MPRester(API_KEY) as mpr:
    docs = mpr.materials.summary.search(
        elements=["O"],  # Containing oxygen
        formation_energy_per_atom=(None, -2.0),  # -2.0 eV/atom or less
        num_elements=(2, 5),  # 2-5 element systems
        fields=["material_id", "formula_pretty", "formation_energy_per_atom",
                "elements", "symmetry"]
    )

# Calculate statistics
formation_energies = [doc.formation_energy_per_atom for doc in docs]
all_elements = []
space_groups = []

for doc in docs:
    all_elements.extend([str(el) for el in doc.elements])
    space_groups.append(doc.symmetry.symbol)

print(f"=== Statistics ({len(docs)} entries) ===")
print(f"Average formation energy: {sum(formation_energies)/len(formation_energies):.3f} eV/atom")
print(f"\nElement frequency (top 10):")
for elem, count in Counter(all_elements).most_common(10):
    print(f"  {elem}: {count} times")
print(f"\nSpace group distribution (top 5):")
for sg, count in Counter(space_groups).most_common(5):
    print(f"  {sg}: {count} entries")

Exercise 6.2 (Easy): Data Saving and Loading in CIF Format

Create code to retrieve the crystal structure of any material (specified by material_id) from Materials Project, save it in CIF format, reload it, and display atomic site information.

Solution Example:
from mp_api.client import MPRester
from pymatgen.io.cif import CifWriter, CifParser

API_KEY = "your_api_key_here"
material_id = "mp-1234"  # Replace with actual ID

# 1. Data acquisition and CIF save
with MPRester(API_KEY) as mpr:
    structure = mpr.get_structure_by_material_id(material_id)

    cif_writer = CifWriter(structure)
    cif_writer.write_file(f"{material_id}.cif")
    print(f"βœ“ CIF saved: {material_id}.cif")

# 2. CIF load
parser = CifParser(f"{material_id}.cif")
structure_loaded = parser.get_structures()[0]

# 3. Display atomic site information
print(f"\n=== Atomic Site Information ===")
print(f"Chemical formula: {structure_loaded.composition.reduced_formula}")
for i, site in enumerate(structure_loaded):
    print(f"Site {i+1}: {site.species_string} at fractional coords {site.frac_coords}")

Exercise 6.3 (Medium): Extending Custom InMemoryDataset

Extend the MaterialsProjectDataset from Code Example 4 and add the following features:

  1. Retrieve multiple properties simultaneously (band_gap, formation_energy_per_atom)
  2. Explicit implementation of __len__() and __getitem__() methods
  3. Add statistics() method that returns dataset statistics
Solution Example:
class MultiPropertyDataset(InMemoryDataset):
    def __init__(self, root, api_key, property_names=['band_gap', 'formation_energy_per_atom'],
                 cutoff=5.0, transform=None, pre_transform=None, pre_filter=None):
        self.api_key = api_key
        self.property_names = property_names
        self.converter = StructureToGraph(cutoff=cutoff)

        super().__init__(root, transform, pre_transform, pre_filter)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return ['materials.pkl']

    @property
    def processed_file_names(self):
        return ['data.pt']

    def download(self):
        with MPRester(self.api_key) as mpr:
            docs = mpr.materials.summary.search(
                energy_above_hull=(0, 0.05),
                fields=["material_id"] + self.property_names
            )

            materials_data = []
            for doc in docs:
                try:
                    structure = mpr.get_structure_by_material_id(doc.material_id)
                    properties = {prop: getattr(doc, prop) for prop in self.property_names}

                    materials_data.append({
                        'material_id': doc.material_id,
                        'structure': structure,
                        'properties': properties
                    })
                except:
                    pass

            with open(self.raw_paths[0], 'wb') as f:
                pickle.dump(materials_data, f)

    def process(self):
        with open(self.raw_paths[0], 'rb') as f:
            materials_data = pickle.load(f)

        data_list = []
        for item in materials_data:
            data = self.converter.convert(item['structure'])

            # Tensorize multiple properties
            y = torch.tensor([item['properties'][prop] for prop in self.property_names],
                             dtype=torch.float)
            data.y = y
            data.material_id = item['material_id']

            data_list.append(data)

        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])

    def __len__(self):
        return len(self.slices['x']) - 1

    def __getitem__(self, idx):
        data = self.get(idx)
        return data

    def statistics(self):
        """Return dataset statistics"""
        stats = {
            'num_samples': len(self),
            'properties': {}
        }

        for i, prop in enumerate(self.property_names):
            values = [self[j].y[i].item() for j in range(len(self))]
            stats['properties'][prop] = {
                'mean': np.mean(values),
                'std': np.std(values),
                'min': np.min(values),
                'max': np.max(values)
            }

        return stats

Exercise 6.4 (Medium): Verification of Mixed Precision Training Effects

Create code to compare normal FP32 training and Mixed Precision Training (FP16) and verify the following:

Solution Example:
# Requirements:
# - Python 3.9+
# - torch>=2.0.0, <2.3.0

"""
Example: Create code to compare normal FP32 training and Mixed Precis

Purpose: Demonstrate optimization techniques
Target: Advanced
Execution time: 1-5 minutes
Dependencies: None
"""

import torch
from torch.cuda.amp import autocast, GradScaler
import time

def compare_training_precision(model, dataset, epochs=50):
    device = torch.device('cuda')
    train_loader = DataLoader(dataset, batch_size=64, shuffle=True)

    results = {}

    # 1. FP32 training
    print("=== FP32 Training ===")
    model_fp32 = model.to(device)
    optimizer = torch.optim.Adam(model_fp32.parameters(), lr=0.001)
    criterion = torch.nn.MSELoss()

    torch.cuda.reset_peak_memory_stats()
    start_time = time.time()

    for epoch in range(epochs):
        for batch in train_loader:
            batch = batch.to(device)
            optimizer.zero_grad()
            output = model_fp32(batch)
            loss = criterion(output, batch.y)
            loss.backward()
            optimizer.step()

    fp32_time = time.time() - start_time
    fp32_memory = torch.cuda.max_memory_allocated() / 1024**3  # GB

    results['fp32'] = {'time': fp32_time, 'memory': fp32_memory}

    # 2. Mixed Precision training
    print("\n=== Mixed Precision Training ===")
    model_fp16 = model.to(device)
    optimizer = torch.optim.Adam(model_fp16.parameters(), lr=0.001)
    scaler = GradScaler()

    torch.cuda.reset_peak_memory_stats()
    start_time = time.time()

    for epoch in range(epochs):
        for batch in train_loader:
            batch = batch.to(device)
            optimizer.zero_grad()

            with autocast():
                output = model_fp16(batch)
                loss = criterion(output, batch.y)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

    fp16_time = time.time() - start_time
    fp16_memory = torch.cuda.max_memory_allocated() / 1024**3

    results['fp16'] = {'time': fp16_time, 'memory': fp16_memory}

    # Display results
    print("\n=== Comparison Results ===")
    print(f"Training time: FP32={fp32_time:.2f}s, FP16={fp16_time:.2f}s (Speedup: {fp32_time/fp16_time:.2f}x)")
    print(f"GPU memory: FP32={fp32_memory:.2f}GB, FP16={fp16_memory:.2f}GB (Reduction: {(1-fp16_memory/fp32_memory)*100:.1f}%)")

    return results

Exercise 6.5 (Medium): Resume Training from Checkpoint

Create code to interrupt training midway, then resume training from a saved checkpoint. Correctly restore epoch number, loss, and optimizer state.

Solution Example:
# Requirements:
# - Python 3.9+
# - torch>=2.0.0, <2.3.0

"""
Example: Create code to interrupt training midway, then resume traini

Purpose: Demonstrate optimization techniques
Target: Advanced
Execution time: 1-5 minutes
Dependencies: None
"""

import torch

def train_with_resume(model, dataset, total_epochs=100, checkpoint_path=None):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = torch.nn.MSELoss()
    train_loader = DataLoader(dataset, batch_size=64, shuffle=True)

    start_epoch = 0

    # Resume from checkpoint
    if checkpoint_path and os.path.exists(checkpoint_path):
        checkpoint = torch.load(checkpoint_path)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        start_epoch = checkpoint['epoch']
        print(f"βœ“ Resuming from checkpoint: Epoch {start_epoch}")

    # Training loop
    for epoch in range(start_epoch, total_epochs):
        model.train()
        total_loss = 0

        for batch in train_loader:
            batch = batch.to(device)
            optimizer.zero_grad()
            output = model(batch)
            loss = criterion(output, batch.y)
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        avg_loss = total_loss / len(train_loader)
        print(f"Epoch {epoch+1}/{total_epochs}, Loss: {avg_loss:.4f}")

        # Save checkpoint every 10 epochs
        if (epoch + 1) % 10 == 0:
            checkpoint = {
                'epoch': epoch + 1,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss': avg_loss
            }
            torch.save(checkpoint, f'checkpoint_epoch{epoch+1}.pt')
            print(f"βœ“ Checkpoint saved")

    return model

# Usage example
# model = train_with_resume(CGCNNModel(), dataset, total_epochs=100, checkpoint_path='checkpoint_epoch50.pt')

Exercise 6.6 (Hard): Batch Prediction and ONNX Inference Speed Benchmark

Create benchmark code to compare the speed of PyTorch native inference and ONNX Runtime inference. Measure under the following conditions:

Solution Example:
# Requirements:
# - Python 3.9+
# - numpy>=1.24.0, <2.0.0
# - torch>=2.0.0, <2.3.0

"""
Example: Create benchmark code to compare the speed of PyTorch native

Purpose: Demonstrate core concepts and implementation patterns
Target: Advanced
Execution time: 10-30 seconds
Dependencies: None
"""

import torch
import onnxruntime as ort
import time
import numpy as np
from torch_geometric.data import DataLoader, Batch

def benchmark_inference(model, dataset, onnx_path, batch_sizes=[1, 32, 64, 128], n_iterations=100):
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    model.eval()

    # ONNX Runtime session
    ort_session = ort.InferenceSession(onnx_path)

    results = []

    for batch_size in batch_sizes:
        print(f"\n=== Batch Size: {batch_size} ===")

        loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)

        # PyTorch inference
        torch_times = []
        for _ in range(n_iterations):
            batch = next(iter(loader))
            batch = batch.to(device)

            start = time.time()
            with torch.no_grad():
                _ = model(batch)
            torch.cuda.synchronize() if torch.cuda.is_available() else None
            torch_times.append((time.time() - start) * 1000)  # ms

        torch_avg = np.mean(torch_times)
        torch_throughput = batch_size * 1000 / torch_avg

        # ONNX Runtime inference
        onnx_times = []
        for _ in range(n_iterations):
            batch = next(iter(loader))

            ort_inputs = {
                'x': batch.x.numpy(),
                'edge_index': batch.edge_index.numpy(),
                'edge_attr': batch.edge_attr.numpy(),
                'batch': batch.batch.numpy()
            }

            start = time.time()
            _ = ort_session.run(None, ort_inputs)
            onnx_times.append((time.time() - start) * 1000)

        onnx_avg = np.mean(onnx_times)
        onnx_throughput = batch_size * 1000 / onnx_avg

        # Save results
        results.append({
            'batch_size': batch_size,
            'pytorch_ms': torch_avg,
            'onnx_ms': onnx_avg,
            'speedup': torch_avg / onnx_avg,
            'pytorch_throughput': torch_throughput,
            'onnx_throughput': onnx_throughput
        })

        print(f"PyTorch: {torch_avg:.2f} ms/batch ({torch_throughput:.1f} samples/sec)")
        print(f"ONNX: {onnx_avg:.2f} ms/batch ({onnx_throughput:.1f} samples/sec)")
        print(f"Speedup: {torch_avg/onnx_avg:.2f}x")

    return results

# Usage example
# results = benchmark_inference(model, dataset, 'cgcnn_model.onnx')

Exercise 6.7 (Hard): FastAPI Asynchronous Batch Inference

Using FastAPI's background task feature, implement an asynchronous API that batch processes multiple prediction requests. Meet the following requirements:

Solution Example:
# Requirements:
# - Python 3.9+
# - fastapi>=0.100.0
# - torch>=2.0.0, <2.3.0

"""
Example: Using FastAPI's background task feature, implement an asynch

Purpose: Demonstrate core concepts and implementation patterns
Target: Advanced
Execution time: 10-30 seconds
Dependencies: None
"""

from fastapi import FastAPI, BackgroundTasks
from pydantic import BaseModel
import asyncio
import uuid
from collections import defaultdict
import torch

app = FastAPI()

# Global state
pending_requests = []
results_store = {}
MODEL = None

class PredictionRequest(BaseModel):
    structure: dict

class JobResponse(BaseModel):
    job_id: str
    status: str

class ResultResponse(BaseModel):
    job_id: str
    prediction: float = None
    status: str

async def batch_processor():
    """Execute batch processing in background"""
    while True:
        await asyncio.sleep(1.0)  # Batch process every 1 second

        if len(pending_requests) == 0:
            continue

        # Get buffered requests
        batch_requests = pending_requests.copy()
        pending_requests.clear()

        # Batch inference
        job_ids = [req['job_id'] for req in batch_requests]
        structures = [req['structure'] for req in batch_requests]

        # Graph conversion (parallel processing)
        data_list = []
        for structure in structures:
            converter = StructureToGraph(cutoff=5.0)
            data = converter.convert(Structure.from_dict(structure))
            data_list.append(data)

        # Batchify
        from torch_geometric.data import Batch
        batch = Batch.from_data_list(data_list)
        batch = batch.to('cuda' if torch.cuda.is_available() else 'cpu')

        # Inference
        with torch.no_grad():
            predictions = MODEL(batch).cpu().numpy()

        # Save results
        for job_id, pred in zip(job_ids, predictions):
            results_store[job_id] = {
                'status': 'completed',
                'prediction': float(pred)
            }

@app.on_event("startup")
async def startup_event():
    global MODEL
    MODEL = CGCNNModel()
    checkpoint = torch.load('checkpoints/best_model.pt')
    MODEL.load_state_dict(checkpoint['model_state_dict'])
    MODEL.eval()

    # Start background task
    asyncio.create_task(batch_processor())

@app.post("/predict/async", response_model=JobResponse)
async def predict_async(request: PredictionRequest):
    """Asynchronous prediction request"""
    job_id = str(uuid.uuid4())

    # Add request to buffer
    pending_requests.append({
        'job_id': job_id,
        'structure': request.structure
    })

    # Save initial state to result store
    results_store[job_id] = {'status': 'pending'}

    return JobResponse(job_id=job_id, status='pending')

@app.get("/result/{job_id}", response_model=ResultResponse)
async def get_result(job_id: str):
    """Get result"""
    if job_id not in results_store:
        return ResultResponse(job_id=job_id, status='not_found')

    result = results_store[job_id]

    return ResultResponse(
        job_id=job_id,
        prediction=result.get('prediction'),
        status=result['status']
    )

Exercise 6.8 (Hard): Prediction API with Uncertainty Estimation

Implement an API that estimates prediction uncertainty using Monte Carlo Dropout (MC dropout). Include the following:

Solution Example:
# Requirements:
# - Python 3.9+
# - fastapi>=0.100.0
# - numpy>=1.24.0, <2.0.0
# - torch>=2.0.0, <2.3.0

import torch
import torch.nn as nn
from torch_geometric.nn import CGConv, global_mean_pool
import numpy as np

class CGCNNWithDropout(nn.Module):
    """CGCNN with dropout layers"""
    def __init__(self, atom_fea_len=92, nbr_fea_len=41,
                 hidden_dim=128, n_conv=3, dropout=0.1):
        super().__init__()

        self.atom_embedding = nn.Linear(atom_fea_len, hidden_dim)

        self.conv_layers = nn.ModuleList([
            CGConv(hidden_dim, nbr_fea_len) for _ in range(n_conv)
        ])

        self.bn_layers = nn.ModuleList([
            nn.BatchNorm1d(hidden_dim) for _ in range(n_conv)
        ])

        self.dropout = nn.Dropout(p=dropout)

        self.fc1 = nn.Linear(hidden_dim, 64)
        self.fc2 = nn.Linear(64, 1)
        self.activation = nn.Softplus()

    def forward(self, data):
        x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch

        x = self.atom_embedding(x)

        for conv, bn in zip(self.conv_layers, self.bn_layers):
            x_new = conv(x, edge_index, edge_attr)
            x_new = bn(x_new)
            x_new = self.activation(x_new)
            x_new = self.dropout(x_new)  # Dropout
            x = x + x_new

        x = global_mean_pool(x, batch)

        x = self.dropout(self.activation(self.fc1(x)))  # Dropout
        x = self.fc2(x)

        return x.squeeze()

    def predict_with_uncertainty(self, data, n_samples=30):
        """
        Uncertainty estimation with MC Dropout

        Returns:
            mean, std (mean and standard deviation of predictions)
        """
        self.train()  # Enable dropout

        predictions = []
        with torch.no_grad():
            for _ in range(n_samples):
                pred = self.forward(data).item()
                predictions.append(pred)

        mean = np.mean(predictions)
        std = np.std(predictions)

        return mean, std

# FastAPI integration
from fastapi import FastAPI
from pydantic import BaseModel

app = FastAPI()
MODEL = None

class UncertaintyResponse(BaseModel):
    prediction: float
    uncertainty: float
    confidence_interval_95: tuple

@app.post("/predict/uncertainty", response_model=UncertaintyResponse)
async def predict_with_uncertainty(request: CrystalInput):
    """Prediction with uncertainty estimation"""
    # Parse structure data
    structure = Structure.from_dict(request.structure)

    # Graph conversion
    converter = StructureToGraph(cutoff=5.0)
    data = converter.convert(structure)
    data = data.to('cuda' if torch.cuda.is_available() else 'cpu')

    # MC Dropout inference
    mean, std = MODEL.predict_with_uncertainty(data, n_samples=30)

    # 95% confidence interval
    ci_lower = mean - 1.96 * std
    ci_upper = mean + 1.96 * std

    return UncertaintyResponse(
        prediction=mean,
        uncertainty=std,
        confidence_interval_95=(ci_lower, ci_upper)
    )

References

  1. Jain, A., Ong, S. P., Hautier, G., et al. (2013). Commentary: The Materials Project: A materials genome approach to accelerating materials innovation. APL Materials, 1(1), 011002. DOI: 10.1063/1.4812323, pp. 1-11. (Foundational paper for Materials Project API)
  2. Ong, S. P., Richards, W. D., Jain, A., et al. (2013). Python Materials Genomics (pymatgen): A robust, open-source python library for materials analysis. Computational Materials Science, 68, 314-319. DOI: 10.1016/j.commatsci.2012.10.028, pp. 314-319. (Official paper for pymatgen library)
  3. Fey, M., & Lenssen, J. E. (2019). Fast Graph Representation Learning with PyTorch Geometric. In ICLR Workshop on Representation Learning on Graphs and Manifolds. arXiv:1903.02428, pp. 1-5. (Official paper for PyTorch Geometric)
  4. Micikevicius, P., Narang, S., Alben, J., et al. (2018). Mixed Precision Training. In International Conference on Learning Representations (ICLR). arXiv:1710.03740, pp. 1-12. (Proposal paper for Mixed Precision Training)
  5. Bingham, E., Chen, J. P., Jankowiak, M., et al. (2019). Pyro: Deep Universal Probabilistic Programming. Journal of Machine Learning Research, 20(28), 1-6. (Theoretical background for uncertainty estimation)
  6. RamΓ­rez, S. (2021). FastAPI: Modern Python Web Development. O'Reilly Media, pp. 1-350. (Comprehensive guide for FastAPI, especially Chapters 5-7 are useful for production deployment)
  7. ONNX Runtime Development Team. (2020). ONNX Runtime Performance Tuning. Microsoft Technical Report. https://onnxruntime.ai/docs/performance/ (Official documentation for ONNX Runtime optimization)
← Back to Series Top

Disclaimer