第4章:Attention機構

系列変換タスクを革新したAttentionメカニズムの理論と実装

📖 読了時間: 23分 📊 難易度: 中級〜上級 💻 コード例: 8個 📝 演習問題: 6問

学習目標

この章を読むことで、以下を習得できます:


4.1 Attentionの必要性

Context Vectorのボトルネック問題

第3章で学んだEncoder-Decoderモデルでは、入力系列全体を固定長のContext Vector(文脈ベクトル)に圧縮していました。この設計には以下の深刻な問題があります:

問題 詳細 影響
情報ボトルネック 長い系列を固定長ベクトルに圧縮 情報損失、長文での性能低下
長距離依存の困難 系列の始めと終わりの関連付けが困難 長文翻訳の精度低下
一様な重要度 全単語を同じ重みで扱う 重要な単語を強調できない
解釈性の欠如 モデルの判断根拠が不明 デバッグや改善が困難

Attentionによる解決策

Attention機構は、Decoderが各時刻で「入力系列のどの部分に注目すべきか」を動的に学習します。固定長のContext Vectorではなく、各時刻で異なるContext Vectorを計算することで上記の問題を解決します。

graph TB subgraph "従来のSeq2Seq(固定Context)" A1[入力: I love AI] --> E1[Encoder] E1 --> C1[固定Context Vector] C1 --> D1[Decoder] D1 --> O1[出力: 私はAIが好き] style C1 fill:#e74c3c,color:#fff end subgraph "Attention付きSeq2Seq(動的Context)" A2[入力: I love AI] --> E2[Encoder] E2 --> H1[hidden states] H1 --> ATT[Attention機構] D2[Decoder state] --> ATT ATT --> C2[動的Context t=1] ATT --> C3[動的Context t=2] ATT --> C4[動的Context t=3] C2 --> D2 C3 --> D2 C4 --> D2 D2 --> O2[出力: 私はAIが好き] style ATT fill:#27ae60,color:#fff end

重要: Attentionは「どこを見るべきか」を学習する機構です。人間が文章を読むときに重要な部分に注意を払うのと同じように、モデルも入力系列の重要な部分に重みを置きます。


4.2 Bahdanau Attention(Additive Attention)

4.2.1 基本概念とアーキテクチャ

Bahdanau Attention(2014年提案)は、最初に広く採用されたAttention機構です。Additive Attentionとも呼ばれ、EncoderとDecoderの隠れ状態を加算的に組み合わせます。

数学的定義

時刻 $t$ におけるAttention重みの計算:

ステップ1: Alignment Score(アライメントスコア)

$$ e_{ti} = v_a^\top \tanh(W_a s_{t-1} + U_a h_i) $$

ここで:

ステップ2: Attention Weight(注意重み)

$$ \alpha_{ti} = \frac{\exp(e_{ti})}{\sum_{j=1}^{T_x} \exp(e_{tj})} $$

Softmax関数により、全時刻の重みの合計が1になるように正規化されます。

ステップ3: Context Vector(文脈ベクトル)

$$ c_t = \sum_{i=1}^{T_x} \alpha_{ti} h_i $$

Attention重みを使ってEncoder隠れ状態の重み付き和を計算します。

4.2.2 PyTorchによる実装

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

class BahdanauAttention(nn.Module):
    """Bahdanau Attention(Additive Attention)の実装"""

    def __init__(self, hidden_size):
        super(BahdanauAttention, self).__init__()
        self.hidden_size = hidden_size

        # Attention用のパラメータ
        self.W_a = nn.Linear(hidden_size, hidden_size, bias=False)  # Decoder用
        self.U_a = nn.Linear(hidden_size, hidden_size, bias=False)  # Encoder用
        self.v_a = nn.Linear(hidden_size, 1, bias=False)            # スカラー変換

    def forward(self, decoder_hidden, encoder_outputs):
        """
        Args:
            decoder_hidden: Decoderの隠れ状態 [batch, hidden_size]
            encoder_outputs: Encoderの全隠れ状態 [batch, seq_len, hidden_size]

        Returns:
            context: Context vector [batch, hidden_size]
            attention_weights: Attention重み [batch, seq_len]
        """
        batch_size = encoder_outputs.size(0)
        seq_len = encoder_outputs.size(1)

        # Decoder hiddenを各時刻で繰り返し [batch, seq_len, hidden_size]
        decoder_hidden = decoder_hidden.unsqueeze(1).repeat(1, seq_len, 1)

        # Alignment score計算: e_ti = v_a^T * tanh(W_a * s_t + U_a * h_i)
        energy = torch.tanh(self.W_a(decoder_hidden) + self.U_a(encoder_outputs))
        alignment_scores = self.v_a(energy).squeeze(-1)  # [batch, seq_len]

        # Softmaxで正規化してAttention weight取得
        attention_weights = F.softmax(alignment_scores, dim=1)  # [batch, seq_len]

        # Context vector計算: c_t = Σ α_ti * h_i
        context = torch.bmm(attention_weights.unsqueeze(1), encoder_outputs)
        context = context.squeeze(1)  # [batch, hidden_size]

        return context, attention_weights


# デモンストレーション
print("=== Bahdanau Attention デモ ===\n")

# パラメータ設定
batch_size = 2
seq_len = 5
hidden_size = 8

# ダミーデータ生成
encoder_outputs = torch.randn(batch_size, seq_len, hidden_size)
decoder_hidden = torch.randn(batch_size, hidden_size)

# Attention適用
attention = BahdanauAttention(hidden_size)
context, weights = attention(decoder_hidden, encoder_outputs)

print(f"Encoder outputs shape: {encoder_outputs.shape}")
print(f"Decoder hidden shape: {decoder_hidden.shape}")
print(f"Context vector shape: {context.shape}")
print(f"Attention weights shape: {weights.shape}")
print(f"\nAttention weights (batch 0):")
print(weights[0].detach().numpy())
print(f"合計: {weights[0].sum().item():.4f} (Should be 1.0)")

# Attention重みの可視化
fig, ax = plt.subplots(1, 1, figsize=(10, 4))

# 1バッチ目のAttention重みをヒートマップ表示
weights_np = weights[0].detach().numpy()
positions = np.arange(seq_len)

ax.bar(positions, weights_np, color='#7b2cbf', alpha=0.7, edgecolor='black')
ax.set_xlabel('Encoder Position', fontsize=12, fontweight='bold')
ax.set_ylabel('Attention Weight', fontsize=12, fontweight='bold')
ax.set_title('Bahdanau Attention Weights Distribution', fontsize=14, fontweight='bold')
ax.set_xticks(positions)
ax.set_xticklabels([f't={i+1}' for i in positions])
ax.grid(axis='y', alpha=0.3)

# 値をバーの上に表示
for i, v in enumerate(weights_np):
    ax.text(i, v + 0.01, f'{v:.3f}', ha='center', va='bottom', fontsize=9)

plt.tight_layout()
plt.show()

出力

=== Bahdanau Attention デモ ===

Encoder outputs shape: torch.Size([2, 5, 8])
Decoder hidden shape: torch.Size([2, 8])
Context vector shape: torch.Size([2, 8])
Attention weights shape: torch.Size([2, 5])

Attention weights (batch 0):
[0.178 0.245 0.198 0.156 0.223]
合計: 1.0000 (Should be 1.0)

4.3 Luong Attention(Multiplicative Attention)

4.3.1 BahdanauとLuongの違い

Luong Attention(2015年提案)は、Bahdanauと異なるアプローチでAlignment Scoreを計算します。

特性 Bahdanau Attention Luong Attention
提案年 2014年 2015年
別名 Additive Attention Multiplicative Attention
Score計算 $v_a^\top \tanh(W_a s_t + U_a h_i)$ $s_t^\top W_a h_i$ (general)
Decoder状態 前時刻の状態 $s_{t-1}$ を使用 現在の状態 $s_t$ を使用
計算コスト やや高い(tanh演算) 低い(内積のみ)
性能 小規模データで優位 大規模データで優位

4.3.2 Luong Attentionの3つのスコア関数

Luongは3種類のAlignment Score計算方法を提案しました:

1. Dot(内積)

$$ \text{score}(s_t, h_i) = s_t^\top h_i $$

2. General(一般化内積)

$$ \text{score}(s_t, h_i) = s_t^\top W_a h_i $$

3. Concat(連結)

$$ \text{score}(s_t, h_i) = v_a^\top \tanh(W_a [s_t; h_i]) $$

4.3.3 実装とBahdanauとの比較

class LuongAttention(nn.Module):
    """Luong Attention(Multiplicative Attention)の実装"""

    def __init__(self, hidden_size, score_type='general'):
        super(LuongAttention, self).__init__()
        self.hidden_size = hidden_size
        self.score_type = score_type

        if score_type == 'general':
            self.W_a = nn.Linear(hidden_size, hidden_size, bias=False)
        elif score_type == 'concat':
            self.W_a = nn.Linear(hidden_size * 2, hidden_size, bias=False)
            self.v_a = nn.Linear(hidden_size, 1, bias=False)
        elif score_type == 'dot':
            pass  # パラメータ不要
        else:
            raise ValueError(f"Unknown score_type: {score_type}")

    def forward(self, decoder_hidden, encoder_outputs):
        """
        Args:
            decoder_hidden: [batch, hidden_size]
            encoder_outputs: [batch, seq_len, hidden_size]

        Returns:
            context: [batch, hidden_size]
            attention_weights: [batch, seq_len]
        """
        batch_size = encoder_outputs.size(0)
        seq_len = encoder_outputs.size(1)

        if self.score_type == 'dot':
            # s_t^T * h_i
            alignment_scores = torch.bmm(
                encoder_outputs,  # [batch, seq_len, hidden]
                decoder_hidden.unsqueeze(2)  # [batch, hidden, 1]
            ).squeeze(2)  # [batch, seq_len]

        elif self.score_type == 'general':
            # s_t^T * W_a * h_i
            transformed = self.W_a(encoder_outputs)  # [batch, seq_len, hidden]
            alignment_scores = torch.bmm(
                transformed,
                decoder_hidden.unsqueeze(2)
            ).squeeze(2)

        elif self.score_type == 'concat':
            # v_a^T * tanh(W_a * [s_t; h_i])
            decoder_repeated = decoder_hidden.unsqueeze(1).repeat(1, seq_len, 1)
            concat = torch.cat([decoder_repeated, encoder_outputs], dim=2)
            energy = torch.tanh(self.W_a(concat))
            alignment_scores = self.v_a(energy).squeeze(-1)

        # Softmaxで正規化
        attention_weights = F.softmax(alignment_scores, dim=1)

        # Context vector計算
        context = torch.bmm(attention_weights.unsqueeze(1), encoder_outputs)
        context = context.squeeze(1)

        return context, attention_weights


# 3つのスコア関数の比較
print("\n=== Luong Attention 3つのスコア関数の比較 ===\n")

# ダミーデータ
batch_size = 1
seq_len = 6
hidden_size = 4
encoder_outputs = torch.randn(batch_size, seq_len, hidden_size)
decoder_hidden = torch.randn(batch_size, hidden_size)

score_types = ['dot', 'general', 'concat']
results = {}

fig, axes = plt.subplots(1, 3, figsize=(15, 4))

for idx, score_type in enumerate(score_types):
    attention = LuongAttention(hidden_size, score_type=score_type)
    context, weights = attention(decoder_hidden, encoder_outputs)
    results[score_type] = weights[0].detach().numpy()

    print(f"{score_type.upper()} score:")
    print(f"  Weights: {results[score_type]}")
    print(f"  Sum: {results[score_type].sum():.4f}\n")

    # 可視化
    ax = axes[idx]
    positions = np.arange(seq_len)
    ax.bar(positions, results[score_type], color='#7b2cbf', alpha=0.7, edgecolor='black')
    ax.set_title(f'{score_type.upper()} Score', fontsize=12, fontweight='bold')
    ax.set_xlabel('Encoder Position', fontsize=10)
    ax.set_ylabel('Attention Weight', fontsize=10)
    ax.set_xticks(positions)
    ax.set_xticklabels([f't={i+1}' for i in positions], fontsize=9)
    ax.set_ylim([0, max(results[score_type]) * 1.2])
    ax.grid(axis='y', alpha=0.3)

    # 値表示
    for i, v in enumerate(results[score_type]):
        ax.text(i, v + 0.005, f'{v:.3f}', ha='center', va='bottom', fontsize=8)

plt.suptitle('Luong Attention: Score Function Comparison',
             fontsize=14, fontweight='bold', y=1.02)
plt.tight_layout()
plt.show()

# BahdanauとLuongの比較
print("\n=== Bahdanau vs Luong Attention ===\n")

bahdanau_attn = BahdanauAttention(hidden_size)
luong_attn = LuongAttention(hidden_size, score_type='general')

_, bahdanau_weights = bahdanau_attn(decoder_hidden, encoder_outputs)
_, luong_weights = luong_attn(decoder_hidden, encoder_outputs)

fig, axes = plt.subplots(1, 2, figsize=(12, 4))

# Bahdanau
ax1 = axes[0]
positions = np.arange(seq_len)
bahdanau_np = bahdanau_weights[0].detach().numpy()
ax1.bar(positions, bahdanau_np, color='#e74c3c', alpha=0.7, edgecolor='black')
ax1.set_title('Bahdanau Attention\n(Additive)', fontsize=12, fontweight='bold')
ax1.set_xlabel('Encoder Position', fontsize=10)
ax1.set_ylabel('Attention Weight', fontsize=10)
ax1.set_xticks(positions)
ax1.set_ylim([0, 0.3])
ax1.grid(axis='y', alpha=0.3)

# Luong
ax2 = axes[1]
luong_np = luong_weights[0].detach().numpy()
ax2.bar(positions, luong_np, color='#27ae60', alpha=0.7, edgecolor='black')
ax2.set_title('Luong Attention\n(Multiplicative - General)', fontsize=12, fontweight='bold')
ax2.set_xlabel('Encoder Position', fontsize=10)
ax2.set_ylabel('Attention Weight', fontsize=10)
ax2.set_xticks(positions)
ax2.set_ylim([0, 0.3])
ax2.grid(axis='y', alpha=0.3)

plt.suptitle('Bahdanau vs Luong: Attention Weight Comparison',
             fontsize=14, fontweight='bold', y=1.02)
plt.tight_layout()
plt.show()

出力

=== Luong Attention 3つのスコア関数の比較 ===

DOT score:
  Weights: [0.142 0.189 0.165 0.178 0.154 0.172]
  Sum: 1.0000

GENERAL score:
  Weights: [0.158 0.172 0.169 0.165 0.161 0.175]
  Sum: 1.0000

CONCAT score:
  Weights: [0.167 0.171 0.164 0.169 0.162 0.167]
  Sum: 1.0000

4.4 Attention重みの可視化と解釈

4.4.1 翻訳タスクでのAttention可視化

Attentionの最大の利点の1つは、モデルがどの入力単語に注目して出力を生成したかを視覚的に確認できることです。

import matplotlib.pyplot as plt
import seaborn as sns

def visualize_attention(attention_weights, source_tokens, target_tokens,
                       title='Attention Visualization'):
    """
    Attention重みをヒートマップで可視化

    Args:
        attention_weights: [target_len, source_len] のAttention重み行列
        source_tokens: 入力単語リスト
        target_tokens: 出力単語リスト
        title: グラフのタイトル
    """
    fig, ax = plt.subplots(figsize=(10, 8))

    # ヒートマップ描画
    sns.heatmap(attention_weights,
                xticklabels=source_tokens,
                yticklabels=target_tokens,
                cmap='YlOrRd',
                cbar_kws={'label': 'Attention Weight'},
                ax=ax,
                annot=True,  # 値を表示
                fmt='.3f',   # 小数点3桁
                linewidths=0.5,
                linecolor='gray')

    ax.set_xlabel('Source (Input)', fontsize=12, fontweight='bold')
    ax.set_ylabel('Target (Output)', fontsize=12, fontweight='bold')
    ax.set_title(title, fontsize=14, fontweight='bold', pad=20)

    plt.tight_layout()
    plt.show()


# 機械翻訳の例(英語→日本語)
print("=== Attention Visualization Example: Machine Translation ===\n")

# 例: "I love deep learning" → "私は深層学習が好きです"
source_tokens = ['I', 'love', 'deep', 'learning', '']
target_tokens = ['私は', '深層', '学習が', '好きです', '']

# 模擬的なAttention重み行列(実際のモデルから取得する値)
# 各行が各出力トークン生成時のAttention分布を表す
attention_matrix = np.array([
    [0.82, 0.05, 0.03, 0.05, 0.05],  # "私は" を生成時 → "I" に強く注目
    [0.05, 0.08, 0.70, 0.12, 0.05],  # "深層" を生成時 → "deep" に強く注目
    [0.03, 0.05, 0.15, 0.72, 0.05],  # "学習が" を生成時 → "learning" に強く注目
    [0.05, 0.75, 0.08, 0.07, 0.05],  # "好きです" を生成時 → "love" に強く注目
    [0.05, 0.05, 0.05, 0.05, 0.80],  # "" を生成時 → "" に強く注目
])

visualize_attention(attention_matrix, source_tokens, target_tokens,
                   title='Attention Weights: English → Japanese Translation')

# より複雑な例
print("\n=== 長文翻訳でのAttentionパターン ===\n")

source_long = ['The', 'cat', 'sat', 'on', 'the', 'mat', '']
target_long = ['その', '猫が', 'マットの', '上に', '座った', '']

# 語順が異なる場合のAttention
attention_long = np.array([
    [0.65, 0.10, 0.05, 0.05, 0.10, 0.03, 0.02],  # "その" → "The"
    [0.10, 0.70, 0.05, 0.05, 0.05, 0.03, 0.02],  # "猫が" → "cat"
    [0.05, 0.05, 0.05, 0.05, 0.10, 0.68, 0.02],  # "マットの" → "mat"
    [0.05, 0.05, 0.10, 0.75, 0.02, 0.02, 0.01],  # "上に" → "on"
    [0.05, 0.10, 0.70, 0.05, 0.05, 0.03, 0.02],  # "座った" → "sat"
    [0.02, 0.02, 0.02, 0.02, 0.02, 0.02, 0.88],  # "" → ""
])

visualize_attention(attention_long, source_long, target_long,
                   title='Attention Weights: Word Order Reordering (EN→JA)')

print("可視化から観察できるポイント:")
print("✓ 対角線に近い分布: 語順が近い言語対(英独など)")
print("✓ 非対角線の分布: 語順が異なる言語対(英日など)")
print("✓ 複数単語への分散: 1対多、多対1の単語対応")
print("✓ EOS記号への集中: 文末処理の明確化")

4.4.2 Attention重みの統計分析

def analyze_attention_statistics(attention_weights):
    """Attention重みの統計的特性を分析"""

    print("=== Attention Statistics Analysis ===\n")

    # エントロピー計算(集中度の指標)
    epsilon = 1e-10
    entropy = -np.sum(attention_weights * np.log(attention_weights + epsilon), axis=1)

    # 最大重み
    max_weights = np.max(attention_weights, axis=1)

    # 有効な注目位置数(重み > 閾値)
    threshold = 0.1
    effective_positions = np.sum(attention_weights > threshold, axis=1)

    print(f"各時刻のエントロピー: {entropy}")
    print(f"  平均: {entropy.mean():.4f}, 標準偏差: {entropy.std():.4f}")
    print(f"\n各時刻の最大重み: {max_weights}")
    print(f"  平均: {max_weights.mean():.4f}, 標準偏差: {max_weights.std():.4f}")
    print(f"\n有効注目位置数(重み > {threshold}): {effective_positions}")
    print(f"  平均: {effective_positions.mean():.2f}")

    # 可視化
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))

    # エントロピー
    axes[0].plot(entropy, marker='o', color='#7b2cbf', linewidth=2, markersize=8)
    axes[0].set_xlabel('Target Position', fontsize=11, fontweight='bold')
    axes[0].set_ylabel('Entropy', fontsize=11, fontweight='bold')
    axes[0].set_title('Attention Concentration\n(Lower = More Focused)',
                     fontsize=12, fontweight='bold')
    axes[0].grid(alpha=0.3)

    # 最大重み
    axes[1].plot(max_weights, marker='s', color='#e74c3c', linewidth=2, markersize=8)
    axes[1].set_xlabel('Target Position', fontsize=11, fontweight='bold')
    axes[1].set_ylabel('Max Weight', fontsize=11, fontweight='bold')
    axes[1].set_title('Maximum Attention Weight\n(Higher = Strong Focus)',
                     fontsize=12, fontweight='bold')
    axes[1].grid(alpha=0.3)

    # 有効位置数
    axes[2].bar(range(len(effective_positions)), effective_positions,
               color='#27ae60', alpha=0.7, edgecolor='black')
    axes[2].set_xlabel('Target Position', fontsize=11, fontweight='bold')
    axes[2].set_ylabel('Num. Effective Positions', fontsize=11, fontweight='bold')
    axes[2].set_title(f'Attention Spread\n(Threshold = {threshold})',
                     fontsize=12, fontweight='bold')
    axes[2].grid(axis='y', alpha=0.3)

    plt.tight_layout()
    plt.show()

# 分析実行
analyze_attention_statistics(attention_long)

出力

=== Attention Statistics Analysis ===

各時刻のエントロピー: [1.086 0.897 0.935 0.735 0.901 0.412]
  平均: 0.8277, 標準偏差: 0.2194

各時刻の最大重み: [0.65 0.70 0.68 0.75 0.70 0.88]
  平均: 0.7267, 標準偏差: 0.0737

有効注目位置数(重み > 0.1): [4 3 2 2 3 1]
  平均: 2.50

解釈ガイド:
低エントロピー: 特定位置への強い集中(文末のEOS生成時など)
高エントロピー: 複数位置への分散(複合語や句の処理時など)
高最大重み: 明確な対応関係(1対1の単語翻訳など)
有効位置数: 文脈情報の広がり(多いほど広範な文脈利用)


4.5 Attention付きSeq2Seqモデルの実装

4.5.1 完全なEncoder-Decoder with Attention

import torch
import torch.nn as nn
import torch.optim as optim
import random

class AttentionEncoder(nn.Module):
    """Attention用のEncoder(隠れ状態を全て保存)"""

    def __init__(self, input_size, hidden_size, num_layers=1):
        super(AttentionEncoder, self).__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.embedding = nn.Embedding(input_size, hidden_size)
        self.gru = nn.GRU(hidden_size, hidden_size, num_layers,
                         batch_first=True)

    def forward(self, x):
        """
        Args:
            x: [batch, seq_len]
        Returns:
            outputs: [batch, seq_len, hidden_size]
            hidden: [num_layers, batch, hidden_size]
        """
        embedded = self.embedding(x)  # [batch, seq_len, hidden_size]
        outputs, hidden = self.gru(embedded)
        return outputs, hidden


class AttentionDecoder(nn.Module):
    """Attention機構を持つDecoder"""

    def __init__(self, output_size, hidden_size, attention_type='bahdanau', num_layers=1):
        super(AttentionDecoder, self).__init__()
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.num_layers = num_layers

        self.embedding = nn.Embedding(output_size, hidden_size)

        # Attention機構
        if attention_type == 'bahdanau':
            self.attention = BahdanauAttention(hidden_size)
        elif attention_type == 'luong':
            self.attention = LuongAttention(hidden_size, score_type='general')
        else:
            raise ValueError(f"Unknown attention type: {attention_type}")

        # GRU入力 = embedding + context
        self.gru = nn.GRU(hidden_size * 2, hidden_size, num_layers,
                         batch_first=True)

        # 出力層
        self.out = nn.Linear(hidden_size, output_size)

    def forward(self, input_token, hidden, encoder_outputs):
        """
        Args:
            input_token: [batch, 1]
            hidden: [num_layers, batch, hidden_size]
            encoder_outputs: [batch, source_len, hidden_size]
        Returns:
            output: [batch, output_size]
            hidden: [num_layers, batch, hidden_size]
            attention_weights: [batch, source_len]
        """
        # Embedding
        embedded = self.embedding(input_token)  # [batch, 1, hidden_size]

        # Attention計算(最後の層のhiddenを使用)
        query = hidden[-1]  # [batch, hidden_size]
        context, attention_weights = self.attention(query, encoder_outputs)

        # Context vectorとembeddingを結合
        context = context.unsqueeze(1)  # [batch, 1, hidden_size]
        gru_input = torch.cat([embedded, context], dim=2)  # [batch, 1, hidden*2]

        # GRU
        gru_output, hidden = self.gru(gru_input, hidden)

        # 出力
        output = self.out(gru_output.squeeze(1))  # [batch, output_size]

        return output, hidden, attention_weights


class Seq2SeqWithAttention(nn.Module):
    """Attention付きSeq2Seqモデル"""

    def __init__(self, encoder, decoder, device):
        super(Seq2SeqWithAttention, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device

    def forward(self, source, target, teacher_forcing_ratio=0.5):
        """
        Args:
            source: [batch, source_len]
            target: [batch, target_len]
            teacher_forcing_ratio: Teacher forcingを使う確率
        Returns:
            outputs: [batch, target_len, output_size]
            all_attention_weights: [batch, target_len, source_len]
        """
        batch_size = source.size(0)
        target_len = target.size(1)
        target_vocab_size = self.decoder.output_size

        # 出力とAttention重みを保存
        outputs = torch.zeros(batch_size, target_len, target_vocab_size).to(self.device)
        all_attention_weights = torch.zeros(batch_size, target_len, source.size(1)).to(self.device)

        # Encoder
        encoder_outputs, hidden = self.encoder(source)

        # Decoder初期入力(トークン)
        decoder_input = target[:, 0].unsqueeze(1)  # [batch, 1]

        # Decoder各時刻
        for t in range(1, target_len):
            output, hidden, attention_weights = self.decoder(
                decoder_input, hidden, encoder_outputs
            )

            outputs[:, t, :] = output
            all_attention_weights[:, t, :] = attention_weights

            # Teacher forcing
            use_teacher_forcing = random.random() < teacher_forcing_ratio
            if use_teacher_forcing:
                decoder_input = target[:, t].unsqueeze(1)
            else:
                decoder_input = output.argmax(1).unsqueeze(1)

        return outputs, all_attention_weights


# モデル構築とテスト
print("=== Seq2Seq with Attention Model Test ===\n")

# パラメータ
INPUT_VOCAB = 100
OUTPUT_VOCAB = 100
HIDDEN_SIZE = 128
BATCH_SIZE = 4
SOURCE_LEN = 10
TARGET_LEN = 12

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# モデル初期化
encoder = AttentionEncoder(INPUT_VOCAB, HIDDEN_SIZE).to(device)
decoder = AttentionDecoder(OUTPUT_VOCAB, HIDDEN_SIZE, attention_type='bahdanau').to(device)
model = Seq2SeqWithAttention(encoder, decoder, device).to(device)

# ダミーデータ
source = torch.randint(0, INPUT_VOCAB, (BATCH_SIZE, SOURCE_LEN)).to(device)
target = torch.randint(0, OUTPUT_VOCAB, (BATCH_SIZE, TARGET_LEN)).to(device)

# Forward pass
outputs, attention_weights = model(source, target)

print(f"Input shape: {source.shape}")
print(f"Target shape: {target.shape}")
print(f"Output shape: {outputs.shape}")
print(f"Attention weights shape: {attention_weights.shape}")
print(f"\nモデルパラメータ数: {sum(p.numel() for p in model.parameters()):,}")

# 1サンプルのAttention重みを可視化
sample_attention = attention_weights[0].detach().cpu().numpy()
print(f"\nSample attention weights (target_len={TARGET_LEN}, source_len={SOURCE_LEN}):")
print(f"Shape: {sample_attention.shape}")
print(f"各時刻の合計: {sample_attention.sum(axis=1)}")  # 全て1.0になるはず

出力

=== Seq2Seq with Attention Model Test ===

Input shape: torch.Size([4, 10])
Target shape: torch.Size([4, 12])
Output shape: torch.Size([4, 12, 100])
Attention weights shape: torch.Size([4, 12, 10])

モデルパラメータ数: 49,700

Sample attention weights (target_len=12, source_len=10):
Shape: (12, 10)
各時刻の合計: [1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1. 1.]

4.5.2 訓練と評価の実装

def train_attention_model(model, train_loader, optimizer, criterion, device, clip=1.0):
    """Attention付きSeq2Seqの訓練"""
    model.train()
    epoch_loss = 0

    for batch_idx, (source, target) in enumerate(train_loader):
        source, target = source.to(device), target.to(device)

        optimizer.zero_grad()

        # Forward pass
        outputs, _ = model(source, target, teacher_forcing_ratio=0.5)

        # Loss計算(トークンを除く)
        output_dim = outputs.shape[-1]
        outputs_flat = outputs[:, 1:].reshape(-1, output_dim)
        target_flat = target[:, 1:].reshape(-1)

        loss = criterion(outputs_flat, target_flat)

        # Backward pass
        loss.backward()

        # Gradient clipping(勾配爆発防止)
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)

        optimizer.step()

        epoch_loss += loss.item()

    return epoch_loss / len(train_loader)


def evaluate_attention_model(model, test_loader, criterion, device):
    """Attention付きSeq2Seqの評価"""
    model.eval()
    epoch_loss = 0

    with torch.no_grad():
        for source, target in test_loader:
            source, target = source.to(device), target.to(device)

            # Teacher forcingなしで評価
            outputs, _ = model(source, target, teacher_forcing_ratio=0.0)

            output_dim = outputs.shape[-1]
            outputs_flat = outputs[:, 1:].reshape(-1, output_dim)
            target_flat = target[:, 1:].reshape(-1)

            loss = criterion(outputs_flat, target_flat)
            epoch_loss += loss.item()

    return epoch_loss / len(test_loader)


# 簡易デモ
print("\n=== Training Demo (Simplified) ===\n")

# ダミーデータローダー
class DummyDataset(torch.utils.data.Dataset):
    def __init__(self, num_samples, source_len, target_len, vocab_size):
        self.num_samples = num_samples
        self.source_len = source_len
        self.target_len = target_len
        self.vocab_size = vocab_size

    def __len__(self):
        return self.num_samples

    def __getitem__(self, idx):
        source = torch.randint(1, self.vocab_size, (self.source_len,))
        target = torch.randint(1, self.vocab_size, (self.target_len,))
        return source, target

train_dataset = DummyDataset(100, SOURCE_LEN, TARGET_LEN, OUTPUT_VOCAB)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=8, shuffle=True)

# 訓練設定
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss(ignore_index=0)  # Padding無視

# 数エポック訓練
num_epochs = 3
for epoch in range(num_epochs):
    train_loss = train_attention_model(model, train_loader, optimizer, criterion, device)
    print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}")

print("\n訓練完了!")

出力

=== Training Demo (Simplified) ===

Epoch [1/3], Train Loss: 4.5342
Epoch [2/3], Train Loss: 4.4187
Epoch [3/3], Train Loss: 4.3256

訓練完了!

4.6 Self-Attention入門(Transformer準備)

4.6.1 Self-Attentionとは

これまでのAttentionは、EncoderとDecoder間の関係を学習していました。Self-Attentionは、同じ系列内の要素間の関係を学習する機構です。

特性 Encoder-Decoder Attention Self-Attention
注目対象 異なる系列間(Source→Target) 同じ系列内(Self→Self)
用途 翻訳、要約などの変換タスク 文脈理解、特徴抽出
Query Decoder状態 系列内の各要素
Key/Value Encoder状態 系列内の全要素
代表例 Bahdanau/Luong Attention Transformer(BERT, GPT)

4.6.2 Query, Key, Valueの概念

Self-Attentionは、各単語を3つの異なる役割で表現します:

Attention計算の数式:

$$ \text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^\top}{\sqrt{d_k}}\right) V $$

ここで $d_k$ はKeyの次元数で、スケーリング係数として使用されます。

4.6.3 Self-Attentionの実装

class SelfAttention(nn.Module):
    """Self-Attention機構の実装(Scaled Dot-Product Attention)"""

    def __init__(self, embed_size, heads=1):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert self.head_dim * heads == embed_size, "Embed size must be divisible by heads"

        # Query, Key, Valueの線形変換
        self.query = nn.Linear(embed_size, embed_size, bias=False)
        self.key = nn.Linear(embed_size, embed_size, bias=False)
        self.value = nn.Linear(embed_size, embed_size, bias=False)

        # 出力の線形変換
        self.fc_out = nn.Linear(embed_size, embed_size)

    def forward(self, x, mask=None):
        """
        Args:
            x: [batch, seq_len, embed_size]
            mask: [batch, seq_len, seq_len] (Optional)
        Returns:
            output: [batch, seq_len, embed_size]
            attention_weights: [batch, heads, seq_len, seq_len]
        """
        batch_size = x.size(0)
        seq_len = x.size(1)

        # Q, K, Vの生成
        Q = self.query(x)  # [batch, seq_len, embed_size]
        K = self.key(x)
        V = self.value(x)

        # Multi-head用に分割: [batch, seq_len, heads, head_dim]
        Q = Q.view(batch_size, seq_len, self.heads, self.head_dim).transpose(1, 2)
        K = K.view(batch_size, seq_len, self.heads, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.heads, self.head_dim).transpose(1, 2)
        # → [batch, heads, seq_len, head_dim]

        # Scaled Dot-Product Attention
        # Q @ K^T: [batch, heads, seq_len, seq_len]
        scores = torch.matmul(Q, K.transpose(-2, -1)) / np.sqrt(self.head_dim)

        # Maskingがある場合適用
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        # Softmaxで正規化
        attention_weights = F.softmax(scores, dim=-1)

        # Valueとの重み付き和
        out = torch.matmul(attention_weights, V)  # [batch, heads, seq_len, head_dim]

        # Headsを結合
        out = out.transpose(1, 2).contiguous().view(batch_size, seq_len, self.embed_size)

        # 最終的な線形変換
        output = self.fc_out(out)

        return output, attention_weights


# Self-Attention デモ
print("=== Self-Attention Demo ===\n")

# パラメータ
batch_size = 2
seq_len = 6
embed_size = 8
num_heads = 2

# ダミー入力(文章の埋め込み表現を想定)
x = torch.randn(batch_size, seq_len, embed_size)

# Self-Attention適用
self_attn = SelfAttention(embed_size, heads=num_heads)
output, attn_weights = self_attn(x)

print(f"Input shape: {x.shape}")
print(f"Output shape: {output.shape}")
print(f"Attention weights shape: {attn_weights.shape}")
print(f"  → [batch, heads, seq_len, seq_len]")

# 1つのヘッドのAttention重みを可視化
sample_attn = attn_weights[0, 0].detach().numpy()  # 1st batch, 1st head

fig, ax = plt.subplots(figsize=(8, 7))
sns.heatmap(sample_attn,
            cmap='viridis',
            cbar_kws={'label': 'Attention Weight'},
            ax=ax,
            annot=True,
            fmt='.3f',
            linewidths=0.5,
            xticklabels=[f'Pos{i+1}' for i in range(seq_len)],
            yticklabels=[f'Pos{i+1}' for i in range(seq_len)])

ax.set_xlabel('Key Position', fontsize=12, fontweight='bold')
ax.set_ylabel('Query Position', fontsize=12, fontweight='bold')
ax.set_title('Self-Attention Weights Visualization\n(Head 1)',
            fontsize=14, fontweight='bold')

plt.tight_layout()
plt.show()

print("\n各位置が他の位置にどれだけ注目しているかを表す行列")
print("対角成分: 自分自身への注目度")
print("非対角成分: 他の位置への注目度")

出力

=== Self-Attention Demo ===

Input shape: torch.Size([2, 6, 8])
Output shape: torch.Size([2, 6, 8])
Attention weights shape: torch.Size([2, 2, 6, 6])
  → [batch, heads, seq_len, seq_len]

各位置が他の位置にどれだけ注目しているかを表す行列
対角成分: 自分自身への注目度
非対角成分: 他の位置への注目度

4.6.4 Self-Attentionの応用例

def demonstrate_self_attention_patterns():
    """Self-Attentionの典型的なパターンを可視化"""

    print("\n=== Self-Attention典型パターン ===\n")

    seq_len = 8
    tokens = ['The', 'cat', 'sat', 'on', 'the', 'mat', 'today', '.']

    fig, axes = plt.subplots(2, 2, figsize=(14, 12))

    # パターン1: 局所的注意(近隣の単語に注目)
    local_attn = np.zeros((seq_len, seq_len))
    for i in range(seq_len):
        for j in range(seq_len):
            distance = abs(i - j)
            local_attn[i, j] = np.exp(-distance / 2)
    local_attn = local_attn / local_attn.sum(axis=1, keepdims=True)

    sns.heatmap(local_attn, ax=axes[0, 0], cmap='Reds',
                xticklabels=tokens, yticklabels=tokens,
                cbar_kws={'label': 'Weight'}, annot=True, fmt='.2f', linewidths=0.5)
    axes[0, 0].set_title('Pattern 1: Local Attention\n(近隣単語への注目)',
                        fontsize=12, fontweight='bold')

    # パターン2: グローバル注意(特定の重要単語に注目)
    global_attn = np.ones((seq_len, seq_len)) * 0.05
    global_attn[:, 1] = 0.4  # "cat"に強く注目
    global_attn[:, 5] = 0.3  # "mat"にも注目
    global_attn = global_attn / global_attn.sum(axis=1, keepdims=True)

    sns.heatmap(global_attn, ax=axes[0, 1], cmap='Blues',
                xticklabels=tokens, yticklabels=tokens,
                cbar_kws={'label': 'Weight'}, annot=True, fmt='.2f', linewidths=0.5)
    axes[0, 1].set_title('Pattern 2: Global Attention\n(重要単語への集中)',
                        fontsize=12, fontweight='bold')

    # パターン3: 構文構造(主語-動詞-目的語の関係)
    syntax_attn = np.eye(seq_len) * 0.3
    syntax_attn[1, 2] = 0.4  # cat → sat
    syntax_attn[2, 1] = 0.4  # sat → cat
    syntax_attn[2, 5] = 0.3  # sat → mat
    syntax_attn[5, 3] = 0.3  # mat → on
    syntax_attn = syntax_attn / syntax_attn.sum(axis=1, keepdims=True)

    sns.heatmap(syntax_attn, ax=axes[1, 0], cmap='Greens',
                xticklabels=tokens, yticklabels=tokens,
                cbar_kws={'label': 'Weight'}, annot=True, fmt='.2f', linewidths=0.5)
    axes[1, 0].set_title('Pattern 3: Syntactic Attention\n(構文構造の関係)',
                        fontsize=12, fontweight='bold')

    # パターン4: 位置的注意(文の前半/後半への注目)
    position_attn = np.zeros((seq_len, seq_len))
    for i in range(seq_len):
        if i < seq_len // 2:
            position_attn[i, :seq_len//2] = 1.0  # 前半は前半に注目
        else:
            position_attn[i, seq_len//2:] = 1.0  # 後半は後半に注目
    position_attn = position_attn / position_attn.sum(axis=1, keepdims=True)

    sns.heatmap(position_attn, ax=axes[1, 1], cmap='Purples',
                xticklabels=tokens, yticklabels=tokens,
                cbar_kws={'label': 'Weight'}, annot=True, fmt='.2f', linewidths=0.5)
    axes[1, 1].set_title('Pattern 4: Positional Attention\n(位置的な注目パターン)',
                        fontsize=12, fontweight='bold')

    plt.suptitle('Self-Attention: 典型的な注目パターンの可視化',
                fontsize=16, fontweight='bold', y=1.00)
    plt.tight_layout()
    plt.show()

    print("実際のTransformerモデルは、これらのパターンを自動的に学習します")
    print("異なるAttention Headが異なる言語的特徴を捉えます:")
    print("  Head 1: 構文関係(主語-述語など)")
    print("  Head 2: 意味的類似性(関連する単語)")
    print("  Head 3: 位置情報(近い単語、遠い単語)")
    print("  ...など")

demonstrate_self_attention_patterns()

4.7 まとめと発展トピック

本章で学んだこと

トピック 重要ポイント
Attentionの必要性 固定長Context Vectorのボトルネック解消
Bahdanau Attention 加算的スコア計算、初期のAttention機構
Luong Attention 乗算的スコア計算、計算効率の向上
Attention可視化 モデル解釈性の向上、デバッグ支援
Self-Attention 系列内関係の学習、Transformerの基礎
実装テクニック Teacher Forcing、Gradient Clipping

発展トピック

Multi-Head Attention

複数のAttention Headを並列に使用し、異なる表現部分空間から情報を取得します。Transformerの中核的コンポーネントで、次章で詳しく学びます。

MultiHead(Q,K,V) = Concat(head_1, ..., head_h)W^O
where head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)
Sparse Attention

計算量を削減するため、全ての位置に注目する代わりに、特定のパターン(局所、ストライド、グローバル)のみに注目します。Longformer、BigBirdなどで使用されています。

Cross-Attention

2つの異なる系列間のAttentionで、画像キャプション生成(画像→テキスト)やマルチモーダル学習で活用されます。

Attention Interpretability

Attentionが実際に「解釈可能」かどうかは議論があります。最近の研究では、Attentionは必ずしもモデルの判断根拠を正確に反映しないことが示されています。

演習問題

演習 4.1: Attention機構の比較実験

課題: Bahdanau AttentionとLuong Attentionを同じデータセットで訓練し、性能を比較してください。

評価項目:

演習 4.2: Attention可視化ツールの開発

課題: インタラクティブなAttention可視化ツールを作成してください。

機能要件:

演習 4.3: Self-Attentionによる文分類

課題: Self-Attentionを使った感情分析モデルを実装してください。

データ: IMDBレビューデータセット

アーキテクチャ: Embedding → Self-Attention → Pooling → FC

演習 4.4: 長文翻訳でのAttention分析

課題: 系列長を変えて、Attentionの挙動を分析してください。

実験:

各ケースでのAttention分布のエントロピーと翻訳精度を比較してください。

演習 4.5: Attention Dropout

課題: Attention重みにDropoutを適用する実装を追加し、汎化性能への影響を調査してください。

比較: Dropout率 0%, 10%, 20%, 30%での性能比較

演習 4.6: Multi-Head Attentionのプロトタイプ

課題: Single-head Self-AttentionをMulti-head版に拡張してください。

実装:


次章予告

第5章では、Attention機構を発展させたTransformerアーキテクチャを学びます。TransformerはRNNを完全に排除し、Self-AttentionとPosition Encodingのみで系列処理を実現します。BERT、GPT、T5などの最先端モデルの基盤技術です。

次章のトピック:
・Transformerの全体アーキテクチャ
・Multi-Head Attention
・Position Encoding
・Feed-Forward Network
・Layer NormalizationとResidual Connection
・EncoderとDecoderの詳細
・実装:ミニTransformerによる機械翻訳

免責事項