第2章: Transformerアーキテクチャ

Self-AttentionからMixture of Expertsまで

読了時間: 30-35分 難易度: 中級 最終更新: 2026年1月
免責事項: 本章は2026年初頭時点の最新のTransformerアーキテクチャを解説しています。この分野は急速に進化しており、実装の詳細については公式ドキュメントを参照してください。

1. Self-Attention機構

Self-Attention(自己注意機構)は、Transformerアーキテクチャの中核をなすメカニズムです。入力シーケンス内の各トークンが、他のすべてのトークンとの関連性を動的に計算することで、長距離の依存関係を効果的に捉えることができます。

1.1 Attentionの直感的理解

Attentionの本質

「The cat sat on the mat because it was tired」という文において、「it」が何を指すかを理解するには、文全体の文脈を考慮する必要があります。Self-Attentionは、「it」と「cat」の間の強い関連性を学習し、適切な参照解決を可能にします。

flowchart LR subgraph "入力" T1[The] T2[cat] T3[sat] T4[...] T5[it] end T5 -->|強い注意| T2 T5 -->|弱い注意| T1 T5 -->|弱い注意| T3 style T2 fill:#e3f2fd style T5 fill:#fff3e0

1.2 数学的定式化

Self-Attentionは、Query(Q)、Key(K)、Value(V)の3つの線形変換を通じて計算されます:

$\text{Attention}(Q, K, V) = \text{softmax}\left(\frac{QK^T}{\sqrt{d_k}}\right)V$

ここで、$d_k$はKey次元であり、$\sqrt{d_k}$によるスケーリングは勾配の安定化のために重要です。

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class SelfAttention(nn.Module):
    """
    Self-Attentionの基本実装

    パラメータ:
        embed_dim: 埋め込み次元
        num_heads: アテンションヘッド数(オプション、Multi-Head用)
    """

    def __init__(self, embed_dim: int, num_heads: int = 1):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        assert embed_dim % num_heads == 0, "embed_dimはnum_headsで割り切れる必要があります"

        # Query, Key, Value の線形変換
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)

        # 出力投影
        self.out_proj = nn.Linear(embed_dim, embed_dim)

        self.scale = math.sqrt(self.head_dim)

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor = None
    ) -> torch.Tensor:
        """
        Args:
            x: 入力テンソル [batch_size, seq_len, embed_dim]
            mask: オプションの注意マスク [batch_size, seq_len, seq_len]

        Returns:
            出力テンソル [batch_size, seq_len, embed_dim]
        """
        batch_size, seq_len, _ = x.shape

        # Q, K, V を計算
        Q = self.q_proj(x)  # [batch, seq, embed]
        K = self.k_proj(x)
        V = self.v_proj(x)

        # 注意スコアを計算
        # [batch, seq, embed] @ [batch, embed, seq] -> [batch, seq, seq]
        attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale

        # マスクを適用(因果マスクなど)
        if mask is not None:
            attention_scores = attention_scores.masked_fill(mask == 0, float('-inf'))

        # Softmaxで確率に変換
        attention_weights = F.softmax(attention_scores, dim=-1)

        # Valueの重み付け和
        output = torch.matmul(attention_weights, V)

        return self.out_proj(output)

# 使用例
embed_dim = 512
seq_len = 100
batch_size = 32

attention = SelfAttention(embed_dim)
x = torch.randn(batch_size, seq_len, embed_dim)
output = attention(x)
print(f"出力形状: {output.shape}")  # [32, 100, 512]

2. Multi-Head Attention

Multi-Head Attention(MHA)は、複数の「ヘッド」を並列に使用することで、異なる表現部分空間で異なるタイプの依存関係を学習します。

flowchart TB Input[入力] --> Split[分割] Split --> H1[Head 1] Split --> H2[Head 2] Split --> H3[Head 3] Split --> H4[Head ...] H1 --> Concat[連結] H2 --> Concat H3 --> Concat H4 --> Concat Concat --> Output[出力投影]
class MultiHeadAttention(nn.Module):
    """
    Multi-Head Attention の実装

    各ヘッドは異なる表現部分空間で注意を学習し、
    最終的に連結されて出力される。
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
        bias: bool = True
    ):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = math.sqrt(self.head_dim)

        # 統合されたQKV投影(効率化のため)
        self.qkv_proj = nn.Linear(embed_dim, 3 * embed_dim, bias=bias)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        x: torch.Tensor,
        attention_mask: torch.Tensor = None,
        is_causal: bool = False
    ) -> tuple:
        batch_size, seq_len, _ = x.shape

        # QKVを一度に計算
        qkv = self.qkv_proj(x)
        qkv = qkv.reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # [3, batch, heads, seq, head_dim]
        Q, K, V = qkv[0], qkv[1], qkv[2]

        # 注意スコア計算
        attention_scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale

        # 因果マスク(デコーダ用)
        if is_causal:
            causal_mask = torch.triu(
                torch.ones(seq_len, seq_len, device=x.device),
                diagonal=1
            ).bool()
            attention_scores = attention_scores.masked_fill(causal_mask, float('-inf'))

        # 追加のマスク
        if attention_mask is not None:
            attention_scores = attention_scores + attention_mask

        # Softmax と Dropout
        attention_weights = F.softmax(attention_scores, dim=-1)
        attention_weights = self.dropout(attention_weights)

        # 出力計算
        output = torch.matmul(attention_weights, V)
        output = output.transpose(1, 2).reshape(batch_size, seq_len, self.embed_dim)
        output = self.out_proj(output)

        return output, attention_weights

3. 位置エンコーディング

Transformerは本質的に順序を認識しないため、位置情報を明示的に注入する必要があります。位置エンコーディング手法は大きく進化しており、現代のLLMでは回転位置エンコーディング(RoPE)が標準となっています。

3.1 位置エンコーディングの比較

手法 タイプ 長所 使用例
正弦波エンコーディング 絶対位置 シンプル、学習不要 オリジナルTransformer
学習位置エンコーディング 絶対位置 タスク適応的 BERT, GPT-2
RoPE 相対位置 外挿性、効率的 Llama, Qwen, Mistral
ALiBi 相対位置 学習不要、高速 BLOOM, MPT

3.2 RoPE(回転位置エンコーディング)

RoPEは、位置情報をQueryとKeyベクトルに回転行列として適用します。これにより、注意計算が自然に相対位置をエンコードするようになります。

class RotaryPositionEmbedding(nn.Module):
    """
    回転位置エンコーディング(RoPE)の実装

    RoPEは位置情報をベクトルの回転として適用し、
    相対位置情報を注意計算に自然に組み込む。

    利点:
    - 学習パラメータなし
    - 相対位置のエンコード
    - 長いシーケンスへの優れた外挿性
    """

    def __init__(self, dim: int, max_seq_len: int = 8192, base: int = 10000):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.base = base

        # 周波数を事前計算
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

        # 位置インデックスを事前計算
        self._set_cos_sin_cache(max_seq_len)

    def _set_cos_sin_cache(self, seq_len: int):
        """cos/sinキャッシュを事前計算"""
        positions = torch.arange(seq_len).float()
        freqs = torch.outer(positions, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)

        self.register_buffer("cos_cached", emb.cos())
        self.register_buffer("sin_cached", emb.sin())

    def rotate_half(self, x: torch.Tensor) -> torch.Tensor:
        """ベクトルの半分を回転"""
        x1 = x[..., : x.shape[-1] // 2]
        x2 = x[..., x.shape[-1] // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        positions: torch.Tensor = None
    ) -> tuple:
        """
        QueryとKeyにRoPEを適用

        Args:
            q: Query テンソル [batch, heads, seq, head_dim]
            k: Key テンソル [batch, heads, seq, head_dim]
            positions: オプションの位置インデックス

        Returns:
            回転後のq, kのタプル
        """
        seq_len = q.shape[2]

        if positions is None:
            cos = self.cos_cached[:seq_len]
            sin = self.sin_cached[:seq_len]
        else:
            cos = self.cos_cached[positions]
            sin = self.sin_cached[positions]

        # ブロードキャスト用に次元を調整
        cos = cos.unsqueeze(0).unsqueeze(0)  # [1, 1, seq, dim]
        sin = sin.unsqueeze(0).unsqueeze(0)

        # 回転を適用
        q_rotated = (q * cos) + (self.rotate_half(q) * sin)
        k_rotated = (k * cos) + (self.rotate_half(k) * sin)

        return q_rotated, k_rotated

# 使用例
rope = RotaryPositionEmbedding(dim=64, max_seq_len=4096)
q = torch.randn(2, 8, 100, 64)  # [batch, heads, seq, head_dim]
k = torch.randn(2, 8, 100, 64)
q_rot, k_rot = rope(q, k)
print(f"回転後の形状: {q_rot.shape}")

4. アーキテクチャの変種

4.1 Encoder-Decoder vs Decoder-Only

flowchart TB subgraph "Encoder-Decoder (T5, BART)" E1[入力] --> E2[エンコーダ] E2 --> E3[隠れ状態] E3 --> D1[デコーダ] D1 --> O1[出力] end subgraph "Decoder-Only (GPT, Llama)" I2[入力] --> D2[デコーダ層] D2 --> D3[デコーダ層] D3 --> O2[出力] end
アーキテクチャ 代表モデル 最適な用途
Encoder-Only BERT, RoBERTa 分類、NER、埋め込み
Decoder-Only GPT-4, Llama, Claude テキスト生成、会話
Encoder-Decoder T5, BART 翻訳、要約
業界トレンド: 2024-2026年において、Decoder-Onlyアーキテクチャが大規模LLMの標準となっています。これは、より単純なアーキテクチャでスケーリングが容易であり、生成タスクへの統一的なアプローチが可能なためです。

5. Mixture of Experts(MoE)

MoE(Mixture of Experts)は、Transformerのフィードフォワード層を複数の「エキスパート」ネットワークに置き換え、各入力に対して少数のエキスパートのみを活性化することで、パラメータ効率とスケーラビリティを向上させます。

5.1 MoEの仕組み

flowchart TB Input[入力トークン] --> Router[ルーター] Router -->|選択| E1[Expert 1] Router -->|選択| E2[Expert 2] Router -.->|非活性| E3[Expert 3] Router -.->|非活性| E4[Expert 4] E1 --> Combine[重み付け結合] E2 --> Combine Combine --> Output[出力] style E3 fill:#f0f0f0 style E4 fill:#f0f0f0

MoEの主な利点

5.2 代表的なMoEモデル

モデル 総パラメータ 活性パラメータ エキスパート数
Mixtral 8x22B 176B 39B 8
DeepSeek-V3 671B 37B 256
Llama 4 Scout 109B 17B 16
Qwen3-235B-A22B 235B 22B 128
class MoELayer(nn.Module):
    """
    Mixture of Expertsレイヤーの実装

    Transformerのフィードフォワード層を複数のエキスパートに置き換え、
    ルーターが各トークンに対して最適なエキスパートを選択する。
    """

    def __init__(
        self,
        embed_dim: int,
        num_experts: int = 8,
        top_k: int = 2,
        expert_dim: int = None
    ):
        super().__init__()
        self.num_experts = num_experts
        self.top_k = top_k
        expert_dim = expert_dim or embed_dim * 4

        # ルーターネットワーク
        self.router = nn.Linear(embed_dim, num_experts, bias=False)

        # エキスパートネットワーク(各々がFFN)
        self.experts = nn.ModuleList([
            nn.Sequential(
                nn.Linear(embed_dim, expert_dim),
                nn.GELU(),
                nn.Linear(expert_dim, embed_dim)
            )
            for _ in range(num_experts)
        ])

    def forward(self, x: torch.Tensor) -> tuple:
        """
        Args:
            x: 入力テンソル [batch_size, seq_len, embed_dim]

        Returns:
            出力テンソルとルーター損失(負荷分散用)のタプル
        """
        batch_size, seq_len, embed_dim = x.shape
        x_flat = x.view(-1, embed_dim)  # [batch*seq, embed]

        # ルーティングスコアを計算
        router_logits = self.router(x_flat)  # [batch*seq, num_experts]
        router_probs = F.softmax(router_logits, dim=-1)

        # Top-kエキスパートを選択
        top_k_probs, top_k_indices = torch.topk(router_probs, self.top_k, dim=-1)

        # 選択されたエキスパートの確率を正規化
        top_k_probs = top_k_probs / top_k_probs.sum(dim=-1, keepdim=True)

        # 出力を初期化
        output = torch.zeros_like(x_flat)

        # 各エキスパートを処理
        for expert_idx in range(self.num_experts):
            # このエキスパートに割り当てられたトークンを見つける
            expert_mask = (top_k_indices == expert_idx).any(dim=-1)

            if expert_mask.any():
                expert_input = x_flat[expert_mask]
                expert_output = self.experts[expert_idx](expert_input)

                # このエキスパートの重みを取得
                expert_weights = torch.where(
                    top_k_indices[expert_mask] == expert_idx,
                    top_k_probs[expert_mask],
                    torch.zeros_like(top_k_probs[expert_mask])
                ).sum(dim=-1, keepdim=True)

                output[expert_mask] += expert_weights * expert_output

        # 負荷分散損失(エキスパート利用の均等化)
        load_balance_loss = self._compute_load_balance_loss(router_probs)

        return output.view(batch_size, seq_len, embed_dim), load_balance_loss

    def _compute_load_balance_loss(self, router_probs: torch.Tensor) -> torch.Tensor:
        """エキスパート間の負荷分散を促す補助損失"""
        # 各エキスパートの平均利用率
        expert_usage = router_probs.mean(dim=0)
        # 均等分布からの乖離
        target = 1.0 / self.num_experts
        return ((expert_usage - target) ** 2).sum()

# 使用例
moe = MoELayer(embed_dim=512, num_experts=8, top_k=2)
x = torch.randn(4, 100, 512)
output, aux_loss = moe(x)
print(f"出力形状: {output.shape}, 補助損失: {aux_loss.item():.4f}")

6. 効率的なAttention機構

6.1 FlashAttention

FlashAttentionは、GPUのメモリ階層を最適化することで、標準的なAttentionと比較して2-4倍の高速化と大幅なメモリ削減を実現します。

FlashAttentionの原理

# FlashAttentionの使用(PyTorch 2.0+)
import torch
import torch.nn.functional as F

def flash_attention_example(q, k, v, is_causal=True):
    """
    PyTorch 2.0+のscaled_dot_product_attentionを使用した
    FlashAttention互換の実装

    FlashAttention v2は以下を自動的に最適化:
    - メモリ効率(O(N)のメモリ使用)
    - 計算効率(タイリングによる高速化)
    - 因果マスクの最適化
    """
    # PyTorch 2.0+では自動的に最適な実装を選択
    output = F.scaled_dot_product_attention(
        q, k, v,
        attn_mask=None,
        dropout_p=0.0,
        is_causal=is_causal,
        scale=None  # 自動計算
    )
    return output

# 実装比較
def benchmark_attention():
    """標準Attention vs FlashAttentionのベンチマーク"""
    import time

    batch_size = 4
    num_heads = 32
    seq_len = 4096
    head_dim = 128

    q = torch.randn(batch_size, num_heads, seq_len, head_dim, device='cuda')
    k = torch.randn(batch_size, num_heads, seq_len, head_dim, device='cuda')
    v = torch.randn(batch_size, num_heads, seq_len, head_dim, device='cuda')

    # ウォームアップ
    _ = flash_attention_example(q, k, v)
    torch.cuda.synchronize()

    # FlashAttention
    start = time.time()
    for _ in range(10):
        _ = flash_attention_example(q, k, v)
    torch.cuda.synchronize()
    flash_time = (time.time() - start) / 10

    print(f"FlashAttention: {flash_time*1000:.2f}ms")
    print(f"シーケンス長: {seq_len}, メモリ使用量: O(N)")

# GPUが利用可能な場合実行
if torch.cuda.is_available():
    benchmark_attention()

6.2 Grouped-Query Attention(GQA)

GQAは、複数のQueryヘッドで1つのKey-Valueヘッドを共有することで、KVキャッシュのメモリ使用量を削減します。

タイプ Q:KV比率 KVキャッシュ削減 使用モデル
MHA 1:1 0% GPT-3
GQA 8:1 87.5% Llama 2/3
MQA N:1 最大 PaLM

まとめ

第2章の重要ポイント

前へ: LLMの基礎 次へ: LLMの訓練と整合性
English